diff --git a/.gitignore b/.gitignore index 97b6af2f8d..4ab886933e 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,7 @@ outputs # VS Code .vscode +.devcontainer # HPC nautilus/*.yaml diff --git a/README.md b/README.md index 8462634f45..e98f35663a 100644 --- a/README.md +++ b/README.md @@ -418,6 +418,19 @@ Additionally, if you are using any of the particular policy architecture, pretra year={2024} } ``` + + +- [HIL-SERL](https://hil-serl.github.io/) +```bibtex +@Article{luo2024hilserl, +title={Precise and Dexterous Robotic Manipulation via Human-in-the-Loop Reinforcement Learning}, +author={Jianlan Luo and Charles Xu and Jeffrey Wu and Sergey Levine}, +year={2024}, +eprint={2410.21845}, +archivePrefix={arXiv}, +primaryClass={cs.RO} +} +``` ## Star History [![Star History Chart](https://api.star-history.com/svg?repos=huggingface/lerobot&type=Timeline)](https://star-history.com/#huggingface/lerobot&Timeline) diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 5e628dec37..ea80e82577 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -5,11 +5,23 @@ title: Installation title: Get started - sections: - - local: getting_started_real_world_robot - title: Getting Started with Real-World Robots + - local: il_robots + title: Imitation Learning for Robots + - local: il_sim + title: Imitation Learning in Sim - local: cameras title: Cameras + - local: integrate_hardware + title: Bring Your Own Hardware + - local: hilserl + title: Train a Robot with RL + - local: hilserl_sim + title: Train RL in Simulation title: "Tutorials" +- sections: + - local: smolvla + title: Finetune SmolVLA + title: "Policies" - sections: - local: so101 title: SO-101 @@ -20,6 +32,10 @@ - local: lekiwi title: LeKiwi title: "Robots" +- sections: + - local: notebooks + title: Notebooks + title: "Resources" - sections: - local: contributing title: Contribute to LeRobot diff --git a/docs/source/cameras.mdx b/docs/source/cameras.mdx index 5556660e9d..d8a49c1ee7 100644 --- a/docs/source/cameras.mdx +++ b/docs/source/cameras.mdx @@ -75,13 +75,13 @@ finally: ```python -from lerobot.common.cameras.intel.configuration_realsense import RealSenseCameraConfig -from lerobot.common.cameras.intel.camera_realsense import RealSenseCamera +from lerobot.common.cameras.realsense.configuration_realsense import RealSenseCameraConfig +from lerobot.common.cameras.realsense.camera_realsense import RealSenseCamera from lerobot.common.cameras.configs import ColorMode, Cv2Rotation # Create a `RealSenseCameraConfig` specifying your camera’s serial number and enabling depth. config = RealSenseCameraConfig( - serial_number="233522074606", + serial_number_or_name="233522074606", fps=15, width=640, height=480, diff --git a/docs/source/hilserl.mdx b/docs/source/hilserl.mdx new file mode 100644 index 0000000000..149b25c687 --- /dev/null +++ b/docs/source/hilserl.mdx @@ -0,0 +1,547 @@ +# HIL-SERL Real Robot Training Workflow Guide + +In this tutorial you will go through the full Human-in-the-Loop Sample-Efficient Reinforcement Learning (HIL-SERL) workflow using LeRobot. You will master training a policy with RL on a real robot in just a few hours. + +HIL-SERL is a sample-efficient reinforcement learning algorithm that combines human demonstrations with online learning and human interventions. The approach starts from a small set of human demonstrations, uses them to train a reward classifier, and then employs an actor-learner architecture where humans can intervene during policy execution to guide exploration and correct unsafe behaviors. In this tutorial, you'll use a gamepad to provide interventions and control the robot during the learning process. + +It combines three key ingredients: + 1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point. + 2. **On-robot actor / learner loop with human interventions:** a distributed Soft Actor Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour. + 3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing and WandB monitoring keep the data useful and the hardware safe. + +Together these elements let HIL-SERL reach near-perfect task success and faster cycle times than imitation-only baselines. + +

+ HIL-SERL workflow +

+ +

HIL-SERL workflow, Luo et al. 2024

+ +This guide provides step-by-step instructions for training a robot policy using LeRobot's HilSerl implementation to train on a real robot. + +## What do I need? + +- A gamepad (recommended) or keyboard to control the robot +- A Nvidia GPU +- A real robot with a follower and leader arm (optional if you use the keyboard or the gamepad) + +## What kind of tasks can I train? + +One can use HIL-SERL to train on a variety of manipulation tasks. Some recommendations: +- Start with a simple task to understand how the system works. + - Push cube to a goal region + - Pick and lift cube with the gripper +- Avoid extremely long horizon tasks. Focus on tasks that can be completed in 5-10 seconds. +- Once you have a good idea of how the system works, you can try more complex tasks and longer horizons. + - Pick and place cube + - Bimanual tasks to pick objects with two arms + - Hand-over tasks to transfer objects from one arm to another + - Go crazy! + +## Install LeRobot with HIL-SERL + +To install LeRobot with HIL-SERL, you need to install the `hilserl` extra. + +```bash +pip install -e ".[hilserl]" +``` + +## Real Robot Training Workflow + +### Understanding Configuration + +The training process begins with proper configuration for the HILSerl environment. The configuration class of interest is `HILSerlRobotEnvConfig` in `lerobot/common/envs/configs.py`. Which is defined as: + +```python +class HILSerlRobotEnvConfig(EnvConfig): + robot: RobotConfig | None = None # Main robot agent (defined in `lerobot/common/robots`) + teleop: TeleoperatorConfig | None = None # Teleoperator agent, e.g., gamepad or leader arm, (defined in `lerobot/common/teleoperators`) + wrapper: EnvTransformConfig | None = None # Environment wrapper settings; check `lerobot/scripts/server/gym_manipulator.py` + fps: int = 10 # Control frequency + name: str = "real_robot" # Environment name + mode: str = None # "record", "replay", or None (for training) + repo_id: str | None = None # LeRobot dataset repository ID + dataset_root: str | None = None # Local dataset root (optional) + task: str = "" # Task identifier + num_episodes: int = 10 # Number of episodes for recording + episode: int = 0 # episode index for replay + device: str = "cuda" # Compute device + push_to_hub: bool = True # Whether to push the recorded datasets to Hub + pretrained_policy_name_or_path: str | None = None # For policy loading + reward_classifier_pretrained_path: str | None = None # For reward model + number_of_steps_after_success: int = 0 # For reward classifier, collect more positive examples after a success to train a classifier +``` + + +### Finding Robot Workspace Bounds + +Before collecting demonstrations, you need to determine the appropriate operational bounds for your robot. + +This helps simplify the problem of learning on the real robot in two ways: 1) by limiting the robot's operational space to a specific region that solves the task and avoids unnecessary or unsafe exploration, and 2) by allowing training in end-effector space rather than joint space. Empirically, learning in joint space for reinforcement learning in manipulation is often a harder problem - some tasks are nearly impossible to learn in joint space but become learnable when the action space is transformed to end-effector coordinates. + +**Using find_joint_limits.py** + +This script helps you find the safe operational bounds for your robot's end-effector. Given that you have a follower and leader arm, you can use the script to find the bounds for the follower arm that will be applied during training. +Bounding the action space will reduce the redundant exploration of the agent and guarantees safety. + +```bash +python -m lerobot.scripts.find_joint_limits \ + --robot.type=so100_follower \ + --robot.port=/dev/tty.usbmodem58760431541 \ + --robot.id=black \ + --teleop.type=so100_leader \ + --teleop.port=/dev/tty.usbmodem58760431551 \ + --teleop.id=blue +``` + +**Workflow** + +1. Run the script and move the robot through the space that solves the task +2. The script will record the minimum and maximum end-effector positions and the joint angles and prints them to the console, for example: + ``` + Max ee position [0.2417 0.2012 0.1027] + Min ee position [0.1663 -0.0823 0.0336] + Max joint positions [-20.0, -20.0, -20.0, -20.0, -20.0, -20.0] + Min joint positions [50.0, 50.0, 50.0, 50.0, 50.0, 50.0] + ``` +3. Use these values in the configuration of your teleoperation device (TeleoperatorConfig) under the `end_effector_bounds` field + +**Example Configuration** + +```json +"end_effector_bounds": { + "max": [0.24, 0.20, 0.10], + "min": [0.16, -0.08, 0.03] +} +``` + +### Collecting Demonstrations + +With the bounds defined, you can safely collect demonstrations for training. Training RL with off-policy algorithm allows us to use offline datasets collected in order to improve the efficiency of the learning process. + +**Setting Up Record Mode** + +Create a configuration file for recording demonstrations (or edit an existing one like [env_config_so100.json](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_so100.json)): + +1. Set `mode` to `"record"` +2. Specify a unique `repo_id` for your dataset (e.g., "username/task_name") +3. Set `num_episodes` to the number of demonstrations you want to collect +4. Set `crop_params_dict` to `null` initially (we'll determine crops later) +5. Configure `robot`, `cameras`, and other hardware settings + +Example configuration section: +```json +"mode": "record", +"repo_id": "username/pick_lift_cube", +"dataset_root": null, +"task": "pick_and_lift", +"num_episodes": 15, +"episode": 0, +"push_to_hub": true +``` + +### Using a Teleoperation Device + +Along with your robot, you will need a teleoperation device to control it in order to collect datasets of your task and perform interventions during the online training. +We support using a gamepad or a keyboard or the leader arm of the robot. + +HIL-Serl learns actions in the end-effector space of the robot. Therefore, the teleoperation will control the end-effector's x,y,z displacements. + +For that we need to define a version of the robot that takes actions in the end-effector space. Check the robot class `SO100FollowerEndEffector` and its configuration `SO100FollowerEndEffectorConfig` for the default parameters related to the end-effector space. + +```python +class SO100FollowerEndEffectorConfig(SO100FollowerConfig): + """Configuration for the SO100FollowerEndEffector robot.""" + + # Default bounds for the end-effector position (in meters) + end_effector_bounds: dict[str, list[float]] = field( # bounds for the end-effector in x,y,z direction + default_factory=lambda: { + "min": [-1.0, -1.0, -1.0], # min x, y, z + "max": [1.0, 1.0, 1.0], # max x, y, z + } + ) + + max_gripper_pos: float = 50 # maximum gripper position that the gripper will be open at + + end_effector_step_sizes: dict[str, float] = field( # maximum step size for the end-effector in x,y,z direction + default_factory=lambda: { + "x": 0.02, + "y": 0.02, + "z": 0.02, + } + ) +``` + +The `Teleoperator` defines the teleoperation device. You can check the list of available teleoperators in `lerobot/common/teleoperators`. + +**Setting up the Gamepad** + +The gamepad provides a very convenient way to control the robot and the episode state. + +To setup the gamepad, you need to set the `control_mode` to `"gamepad"` and define the `teleop` section in the configuration file. + +```json + "teleop": { + "type": "gamepad", + "use_gripper": true + }, +``` + +

+ Figure shows the control mappings on a Logitech gamepad. +

+

Gamepad button mapping for robot control and episode management

+ +**Setting up the SO101 leader** + +The SO101 leader arm has reduced gears that allows it to move and track the follower arm during exploration. Therefore, taking over is much smoother than the gearless SO100. + +To setup the SO101 leader, you need to set the `control_mode` to `"leader"` and define the `teleop` section in the configuration file. + +```json + "teleop": { + "type": "so101_leader", + "port": "/dev/tty.usbmodem585A0077921", # check your port number + "use_degrees": true + }, +``` + +In order to annotate the success/failure of the episode, **you will need** to use a keyboard to press `s` for success, `esc` for failure. +During the online training, press `space` to take over the policy and `space` again to give the control back to the policy. + +
+Video: SO101 leader teleoperation + +
+ +
+ +

SO101 leader teleoperation example, the leader tracks the follower, press `space` to intervene

+
+ +**Recording Demonstrations** + +Start the recording process, an example of the config file can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_so100.json): + +```bash +python lerobot/scripts/rl/gym_manipulator.py --config_path lerobot/configs/env_config_so100.json +``` + +During recording: +1. The robot will reset to the initial position defined in the configuration file `fixed_reset_joint_positions` +2. Complete the task successfully +3. The episode ends with a reward of 1 when you press the "success" button +4. If the time limit is reached, or the fail button is pressed, the episode ends with a reward of 0 +5. You can rerecord an episode by pressing the "rerecord" button +6. The process automatically continues to the next episode +7. After recording all episodes, the dataset is pushed to the Hugging Face Hub (optional) and saved locally + + +### Processing the Dataset + +After collecting demonstrations, process them to determine optimal camera crops. +Reinforcement learning is sensitive to background distractions, so it is important to crop the images to the relevant workspace area. + +Visual RL algorithms learn directly from pixel inputs, making them vulnerable to irrelevant visual information. Background elements like changing lighting, shadows, people moving, or objects outside the workspace can confuse the learning process. Good ROI selection should: +- Include only the essential workspace where the task happens +- Capture the robot's end-effector and all objects involved in the task +- Exclude unnecessary background elements and distractions + +Note: If you already know the crop parameters, you can skip this step and just set the `crop_params_dict` in the configuration file during recording. + +**Determining Crop Parameters** + +Use the `crop_dataset_roi.py` script to interactively select regions of interest in your camera images: + +```bash +python lerobot/scripts/rl/crop_dataset_roi.py --repo-id username/pick_lift_cube +``` + +1. For each camera view, the script will display the first frame +2. Draw a rectangle around the relevant workspace area +3. Press 'c' to confirm the selection +4. Repeat for all camera views +5. The script outputs cropping parameters and creates a new cropped dataset + +Example output: +``` +Selected Rectangular Regions of Interest (top, left, height, width): +observation.images.side: [180, 207, 180, 200] +observation.images.front: [180, 250, 120, 150] +``` + +

+ +

+ +

Interactive cropping tool for selecting regions of interest

+ + +**Updating Configuration** + +Add these crop parameters to your training configuration: + +```json +"crop_params_dict": { + "observation.images.side": [180, 207, 180, 200], + "observation.images.front": [180, 250, 120, 150] +}, +"resize_size": [128, 128] +``` + +**Recommended image resolution** + +Most vision-based policies have been validated on square inputs of either **128×128** (default) or **64×64** pixels. We therefore advise setting the resize_size parameter to [128, 128] – or [64, 64] if you need to save GPU memory and bandwidth. Other resolutions are possible but have not been extensively tested. + + +### Training a Reward Classifier + +The reward classifier plays an important role in the HIL-SERL workflow by automating reward assignment and automatically detecting episode success. Instead of manually defining reward functions or relying on human feedback for every timestep, the reward classifier learns to predict success/failure from visual observations. This enables the RL algorithm to learn efficiently by providing consistent and automated reward signals based on the robot's camera inputs. + +This guide explains how to train a reward classifier for human-in-the-loop reinforcement learning implementation of LeRobot. Reward classifiers learn to predict the reward value given a state which can be used in an RL setup to train a policy. + +**Note**: Training a reward classifier is optional. You can start the first round of RL experiments by annotating the success manually with your gamepad or keyboard device. + +The reward classifier implementation in `modeling_classifier.py` uses a pretrained vision model to process the images. It can output either a single value for binary rewards to predict success/fail cases or multiple values for multi-class settings. + +**Collecting a Dataset for the reward classifier** + +Before training, you need to collect a dataset with labeled examples. The `record_dataset` function in `gym_manipulator.py` enables the process of collecting a dataset of observations, actions, and rewards. + +To collect a dataset, you need to modify some parameters in the environment configuration based on HILSerlRobotEnvConfig. + +```bash +python lerobot/scripts/rl/gym_manipulator.py --config_path lerobot/configs/reward_classifier_train_config.json +``` + +**Key Parameters for Data Collection** + +- **mode**: set it to `"record"` to collect a dataset +- **repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub +- **num_episodes**: Number of episodes to record +- **number_of_steps_after_success**: Number of additional frames to record after a success (reward=1) is detected +- **fps**: Number of frames per second to record +- **push_to_hub**: Whether to push the dataset to the hub + +The `number_of_steps_after_success` parameter is crucial as it allows you to collect more positive examples. When a success is detected, the system will continue recording for the specified number of steps while maintaining the reward=1 label. Otherwise, there won't be enough states in the dataset labeled to 1 to train a good classifier. + +Example configuration section for data collection: + +```json +{ + "mode": "record", + "repo_id": "hf_username/dataset_name", + "dataset_root": "data/your_dataset", + "num_episodes": 20, + "push_to_hub": true, + "fps": 10, + "number_of_steps_after_success": 15 +} +``` + +**Reward Classifier Configuration** + +The reward classifier is configured using `configuration_classifier.py`. Here are the key parameters: + +- **model_name**: Base model architecture (e.g., we mainly use `"helper2424/resnet10"`) +- **model_type**: `"cnn"` or `"transformer"` +- **num_cameras**: Number of camera inputs +- **num_classes**: Number of output classes (typically 2 for binary success/failure) +- **hidden_dim**: Size of hidden representation +- **dropout_rate**: Regularization parameter +- **learning_rate**: Learning rate for optimizer + +Example configuration for training the [reward classifier](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/reward_classifier_train_config.json): + +```json +{ + "policy": { + "type": "reward_classifier", + "model_name": "helper2424/resnet10", + "model_type": "cnn", + "num_cameras": 2, + "num_classes": 2, + "hidden_dim": 256, + "dropout_rate": 0.1, + "learning_rate": 1e-4, + "device": "cuda", + "use_amp": true, + "input_features": { + "observation.images.front": { + "type": "VISUAL", + "shape": [3, 128, 128] + }, + "observation.images.side": { + "type": "VISUAL", + "shape": [3, 128, 128] + } + } + } +} +``` + +**Training the Classifier** + +To train the classifier, use the `train.py` script with your configuration: + +```bash +python lerobot/scripts/train.py --config_path path/to/reward_classifier_train_config.json +``` + +**Deploying and Testing the Model** + +To use your trained reward classifier, configure the `HILSerlRobotEnvConfig` to use your model: + +```python +env_config = HILSerlRobotEnvConfig( + reward_classifier_pretrained_path="path_to_your_pretrained_trained_model", + # Other environment parameters +) +``` +or set the argument in the json config file. + +```json +{ + "reward_classifier_pretrained_path": "path_to_your_pretrained_model" +} +``` + +Run `gym_manipulator.py` to test the model. +```bash +python lerobot/scripts/rl/gym_manipulator.py --config_path path/to/env_config.json +``` + +The reward classifier will automatically provide rewards based on the visual input from the robot's cameras. + +**Example Workflow for training the reward classifier** + +1. **Create the configuration files**: + Create the necessary json configuration files for the reward classifier and the environment. Check the examples [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/tree/main). + +2. **Collect a dataset**: + ```bash + python lerobot/scripts/rl/gym_manipulator.py --config_path lerobot/configs/env_config.json + ``` + +3. **Train the classifier**: + ```bash + python lerobot/scripts/train.py --config_path lerobot/configs/reward_classifier_train_config.json + ``` + +4. **Test the classifier**: + ```bash + python lerobot/scripts/rl/gym_manipulator.py --config_path lerobot/configs/env_config.json + ``` + +### Training with Actor-Learner + +The LeRobot system uses a distributed actor-learner architecture for training. This architecture decouples robot interactions from the learning process, allowing them to run concurrently without blocking each other. The actor server handles robot observations and actions, sending interaction data to the learner server. The learner server performs gradient descent and periodically updates the actor's policy weights. You will need to start two processes: a learner and an actor. + +**Configuration Setup** + +Create a training configuration file (example available [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/train_config_hilserl_so100.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/configs/train.py`. + +1. Configure the policy settings (`type="sac"`, `device`, etc.) +2. Set `dataset` to your cropped dataset +3. Configure environment settings with crop parameters +4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/19bb621a7d0a31c20cd3cc08b1dbab68d3031454/lerobot/common/policies/sac/configuration_sac.py#L79). +5. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task. + +**Starting the Learner** + +First, start the learner server process: + +```bash +python lerobot/scripts/rl/learner.py --config_path lerobot/configs/train_config_hilserl_so100.json +``` + +The learner: +- Initializes the policy network +- Prepares replay buffers +- Opens a `gRPC` server to communicate with actors +- Processes transitions and updates the policy + +**Starting the Actor** + +In a separate terminal, start the actor process with the same configuration: + +```bash +python lerobot/scripts/rl/actor.py --config_path lerobot/configs/train_config_hilserl_so100.json +``` + +The actor: +- Connects to the learner via `gRPC` +- Initializes the environment +- Execute rollouts of the policy to collect experience +- Sends transitions to the learner +- Receives updated policy parameters + +**Training Flow** + +The training proceeds automatically: + +1. The actor executes the policy in the environment +2. Transitions are collected and sent to the learner +3. The learner updates the policy based on these transitions +4. Updated policy parameters are sent back to the actor +5. The process continues until the specified step limit is reached + +**Human in the Loop** + +- The key to learning efficiently is to have human interventions to provide corrective feedback and completing the task to aide the policy learning and exploration. +- To perform human interventions, you can press the upper right trigger button on the gamepad (or the `space` key on the keyboard). This will pause the policy actions and allow you to take over. +- A successful experiment is one where the human has to intervene at the start but then reduces the amount of interventions as the policy improves. You can monitor the intervention rate in the `wandb` dashboard. + +

+ Figure shows the control mappings on a Logitech gamepad. +

+ +

Example showing how human interventions help guide policy learning over time

+ +- The figure shows the plot of the episodic reward over interaction step. The figure shows the effect of human interventions on the policy learning. +- The orange curve is an experiment without any human interventions. While the pink and blue curves are experiments with human interventions. +- We can observe that the number of steps where the policy starts achieving the maximum reward is cut by a quarter when human interventions are present. + +**Monitoring and Debugging** + +If you have `wandb.enable` set to `true` in your configuration, you can monitor training progress in real-time through the [Weights & Biases](https://wandb.ai/site/) dashboard. + +### Guide to Human Interventions +The learning process is very sensitive to the intervention strategy. It will takes a few runs to understand how to intervene effectively. Some tips and hints: +- Allow the policy to explore for a few episodes at the start of training. +- Avoid intervening for long periods of time. Try to intervene in situation to correct the robot's behaviour when it goes off track. +- Once the policy starts achieving the task, even if its not perfect, you can limit your interventions to simple quick actions like a simple grasping commands. + +The ideal behaviour is that your intervention rate should drop gradually during training as shown in the figure below. + +

+ Intervention rate +

+ +

Plot of the intervention rate during a training run on a pick and lift cube task

+ +### Key hyperparameters to tune + +Some configuration values have a disproportionate impact on training stability and speed: + +- **`temperature_init`** (`policy.temperature_init`) – initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning. +- **`policy_parameters_push_frequency`** (`policy.actor_learner_config.policy_parameters_push_frequency`) – interval in *seconds* between two weight pushes from the learner to the actor. The default is `4 s`. Decrease to **1-2 s** to provide fresher weights (at the cost of more network traffic); increase only if your connection is slow, as this will reduce sample efficiency. +- **`storage_device`** (`policy.storage_device`) – device on which the learner keeps the policy parameters. If you have spare GPU memory, set this to `"cuda"` (instead of the default `"cpu"`). Keeping the weights on-GPU removes CPU→GPU transfer overhead and can significantly increase the number of learner updates per second. + + +Congrats 🎉, you have finished this tutorial! + +> [!TIP] +> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). + +Paper citation: +``` +@article{luo2024precise, + title={Precise and Dexterous Robotic Manipulation via Human-in-the-Loop Reinforcement Learning}, + author={Luo, Jianlan and Xu, Charles and Wu, Jeffrey and Levine, Sergey}, + journal={arXiv preprint arXiv:2410.21845}, + year={2024} +} +``` diff --git a/docs/source/hilserl_sim.mdx b/docs/source/hilserl_sim.mdx new file mode 100644 index 0000000000..3239ba91ac --- /dev/null +++ b/docs/source/hilserl_sim.mdx @@ -0,0 +1,120 @@ +# Train RL in Simulation + +This guide explains how to use the `gym_hil` simulation environments as an alternative to real robots when working with the LeRobot framework for Human-In-the-Loop (HIL) reinforcement learning. + +`gym_hil` is a package that provides Gymnasium-compatible simulation environments specifically designed for Human-In-the-Loop reinforcement learning. These environments allow you to: + +- Train policies in simulation to test the RL stack before training on real robots + +- Collect demonstrations in sim using external devices like gamepads or keyboards +- Perform human interventions during policy learning + +Currently, the main environment is a Franka Panda robot simulation based on MuJoCo, with tasks like picking up a cube. + + +## Installation + +First, install the `gym_hil` package within the LeRobot environment: + +```bash +pip install -e ".[hilserl]" +``` + +## What do I need? + +- A gamepad or keyboard to control the robot +- A Nvidia GPU + + + +## Configuration + +To use `gym_hil` with LeRobot, you need to create a configuration file. An example is provided [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/gym_hil_env.json). Key configuration sections include: + +### Environment Type and Task + +```json +{ + "type": "hil", + "name": "franka_sim", + "task": "PandaPickCubeGamepad-v0", + "device": "cuda" +} +``` + +Available tasks: +- `PandaPickCubeBase-v0`: Basic environment +- `PandaPickCubeGamepad-v0`: With gamepad control +- `PandaPickCubeKeyboard-v0`: With keyboard control + +### Gym Wrappers Configuration + +```json +"wrapper": { + "gripper_penalty": -0.02, + "control_time_s": 15.0, + "use_gripper": true, + "fixed_reset_joint_positions": [0.0, 0.195, 0.0, -2.43, 0.0, 2.62, 0.785], + "end_effector_step_sizes": { + "x": 0.025, + "y": 0.025, + "z": 0.025 + }, + "control_mode": "gamepad" + } +``` + +Important parameters: +- `gripper_penalty`: Penalty for excessive gripper movement +- `use_gripper`: Whether to enable gripper control +- `end_effector_step_sizes`: Size of the steps in the x,y,z axes of the end-effector +- `control_mode`: Set to `"gamepad"` to use a gamepad controller + +## Running with HIL RL of LeRobot + +### Basic Usage + +To run the environment, set mode to null: + +```python +python lerobot/scripts/rl/gym_manipulator.py --config_path path/to/gym_hil_env.json +``` + +### Recording a Dataset + +To collect a dataset, set the mode to `record` whilst defining the repo_id and number of episodes to record: + +```python +python lerobot/scripts/rl/gym_manipulator.py --config_path path/to/gym_hil_env.json +``` + +### Training a Policy + +To train a policy, checkout the configuration example available [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/train_gym_hil_env.json) and run the actor and learner servers: + +```python +python lerobot/scripts/rl/actor.py --config_path path/to/train_gym_hil_env.json +``` + +In a different terminal, run the learner server: + +```python +python lerobot/scripts/rl/learner.py --config_path path/to/train_gym_hil_env.json +``` + +The simulation environment provides a safe and repeatable way to develop and test your Human-In-the-Loop reinforcement learning components before deploying to real robots. + +Congrats 🎉, you have finished this tutorial! + +> [!TIP] +> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). + +Paper citation: +``` +@article{luo2024precise, + title={Precise and Dexterous Robotic Manipulation via Human-in-the-Loop Reinforcement Learning}, + author={Luo, Jianlan and Xu, Charles and Wu, Jeffrey and Levine, Sergey}, + journal={arXiv preprint arXiv:2410.21845}, + year={2024} +} +``` diff --git a/docs/source/getting_started_real_world_robot.mdx b/docs/source/il_robots.mdx similarity index 98% rename from docs/source/getting_started_real_world_robot.mdx rename to docs/source/il_robots.mdx index 85f2311dbd..d13e431c85 100644 --- a/docs/source/getting_started_real_world_robot.mdx +++ b/docs/source/il_robots.mdx @@ -1,4 +1,4 @@ -# Getting Started with Real-World Robots +# Imitation Learning on Real-World Robots This tutorial will explain how to train a neural network to control a real robot autonomously. @@ -273,6 +273,9 @@ python lerobot/scripts/train.py \ --resume=true ``` +#### Train using Collab +If your local computer doesn't have a powerful GPU you could utilize Google Collab to train your model by following the [ACT training notebook](./notebooks#training-act). + #### Upload policy checkpoints Once training is done, upload the latest checkpoint with: @@ -297,9 +300,6 @@ python -m lerobot.record \ --robot.port=/dev/ttyACM1 \ --robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \ --robot.id=my_awesome_follower_arm \ - --teleop.type=so100_leader \ - --teleop.port=/dev/ttyACM0 \ - --teleop.id=my_awesome_leader_arm \ --display_data=false \ --dataset.repo_id=$HF_USER/eval_so100 \ --dataset.single_task="Put lego brick into the transparent box" \ diff --git a/docs/source/il_sim.mdx b/docs/source/il_sim.mdx new file mode 100644 index 0000000000..625b2fc00d --- /dev/null +++ b/docs/source/il_sim.mdx @@ -0,0 +1,152 @@ +# Imitation Learning in Sim + +This tutorial will explain how to train a neural network to control a robot in simulation with imitation learning. + +**You'll learn:** +1. How to record a dataset in simulation with [gym-hil](https://github.com/huggingface/gym-hil) and visualize the dataset. +2. How to train a policy using your data. +3. How to evaluate your policy in simulation and visualize the results. + +For the simulation environment we use the same [repo](https://github.com/huggingface/gym-hil) that is also being used by the Human-In-the-Loop (HIL) reinforcement learning algorithm. +This environment is based on [MuJoCo](https://mujoco.org) and allows you to record datasets in LeRobotDataset format. +Teleoperation is easiest with a controller like the Logitech F710, but you can also use your keyboard if you are up for the challenge. + +## Installation + +First, install the `gym_hil` package within the LeRobot environment, go to your LeRobot folder and run this command: + +```bash +pip install -e ".[hilserl]" +``` + +## Teleoperate and Record a Dataset + +To use `gym_hil` with LeRobot, you need to use a configuration file. An example config file can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_gym_hil_il.json). + +To teleoperate and collect a dataset, we need to modify this config file and you should add your `repo_id` here: `"repo_id": "il_gym",` and `"num_episodes": 30,` and make sure you set `mode` to `record`, "mode": "record". + +If you do not have a Nvidia GPU also change `"device": "cuda"` parameter in the config file (for example to `mps` for MacOS). + +By default the config file assumes you use a controller. To use your keyboard please change the envoirment specified at `"task"` in the config file and set it to `"PandaPickCubeKeyboard-v0"`. + +Then we can run this command to start: + + + + +```bash +python lerobot/scripts/rl/gym_manipulator.py --config_path path/to/env_config_gym_hil_il.json +``` + + + + +```bash +mjpython lerobot/scripts/rl/gym_manipulator.py --config_path path/to/env_config_gym_hil_il.json +``` + + + + +Once rendered you can teleoperate the robot with the gamepad or keyboard, below you can find the gamepad/keyboard controls. + +Note that to teleoperate the robot you have to hold the "Human Take Over Pause Policy" Button `RB` to enable control! + +**Gamepad Controls** + +

+ Figure shows the control mappings on a Logitech gamepad. +

+

Gamepad button mapping for robot control and episode management

+ +**Keyboard controls** + +For keyboard controls use the `spacebar` to enable control and the following keys to move the robot: +```bash + Arrow keys: Move in X-Y plane + Shift and Shift_R: Move in Z axis + Right Ctrl and Left Ctrl: Open and close gripper + ESC: Exit +``` + +## Visualize a dataset + +If you uploaded your dataset to the hub you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id. + +

+ Figure shows the dataset visualizer +

+

Dataset visualizer

+ + +## Train a policy + +To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: +```bash +python lerobot/scripts/train.py \ + --dataset.repo_id=${HF_USER}/il_gym \ + --policy.type=act \ + --output_dir=outputs/train/il_sim_test \ + --job_name=il_sim_test \ + --policy.device=cuda \ + --wandb.enable=true +``` + +Let's explain the command: +1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/il_gym`. +2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset. +4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon. +5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`. + +Training should take several hours, 100k steps (which is the default) will take about 1h on Nvidia A100. You will find checkpoints in `outputs/train/il_sim_test/checkpoints`. + +#### Train using Collab +If your local computer doesn't have a powerful GPU you could utilize Google Collab to train your model by following the [ACT training notebook](./notebooks#training-act). + +#### Upload policy checkpoints + +Once training is done, upload the latest checkpoint with: +```bash +huggingface-cli upload ${HF_USER}/il_sim_test \ + outputs/train/il_sim_test/checkpoints/last/pretrained_model +``` + +You can also upload intermediate checkpoints with: +```bash +CKPT=010000 +huggingface-cli upload ${HF_USER}/il_sim_test${CKPT} \ + outputs/train/il_sim_test/checkpoints/${CKPT}/pretrained_model +``` + +## Evaluate your policy in Sim + +To evaluate your policy we have to use the config file that can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/eval_config_gym_hil.json). + +Make sure to replace the `repo_id` with the dataset you trained on, for example `pepijn223/il_sim_dataset` and replace the `pretrained_policy_name_or_path` with your model id, for example `pepijn223/il_sim_model` + +Then you can run this command to visualize your trained policy + + + + +```bash +python lerobot/scripts/rl/eval_policy.py --config_path=path/to/eval_config_gym_hil.json +``` + + + + +```bash +mjpython lerobot/scripts/rl/eval_policy.py --config_path=path/to/eval_config_gym_hil.json +``` + + + + +> [!WARNING] +> While the main workflow of training ACT in simulation is straightforward, there is significant room for exploring how to set up the task, define the initial state of the environment, and determine the type of data required during collection to learn the most effective policy. If your trained policy doesn't perform well, investigate the quality of the dataset it was trained on using our visualizers, as well as the action values and various hyperparameters related to ACT and the simulation. + +Congrats 🎉, you have finished this tutorial. If you want to continue with using LeRobot in simulation follow this [Tutorial on reinforcement learning in sim with HIL-SERL](https://huggingface.co/docs/lerobot/hilserl_sim) + +> [!TIP] +> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index acb2a7a59d..51474d8f7a 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -68,3 +68,5 @@ To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tra ```bash wandb login ``` + +You can now assemble your robot if it's not ready yet, look for your robot type on the left. Then follow the link below to use Lerobot with your robot. diff --git a/docs/source/integrate_hardware.mdx b/docs/source/integrate_hardware.mdx new file mode 100644 index 0000000000..f7de1cece8 --- /dev/null +++ b/docs/source/integrate_hardware.mdx @@ -0,0 +1,318 @@ +# Bring Your Own Hardware + +This tutorial will explain how to integrate your own robot design into the LeRobot ecosystem and have it access all of our tools (data collection, control pipelines, policy training and inference). + +To that end, we provide the [`Robot`](https://github.com/huggingface/lerobot/blob/main/lerobot/common/robots/robot.py) base class in the LeRobot which specifies a standard interface for physical robot integration. Let's see how to implement it. + +## Prerequisites + +- Your own robot which exposes a communication interface (e.g. serial, CAN, TCP) +- A way to read sensor data and send motor commands programmatically, e.g. manufacturer's SDK or API, or your own protocol implementation. +- LeRobot installed in your environment. Follow our [Installation Guide](./installation). + +## Choose your motors + +If you're using Feetech or Dynamixel motors, LeRobot provides built-in bus interfaces: + +- [`FeetechMotorsBus`](https://github.com/huggingface/lerobot/blob/main/lerobot/common/motors/feetech/feetech.py) – for controlling Feetech servos +- [`DynamixelMotorsBus`](https://github.com/huggingface/lerobot/blob/main/lerobot/common/motors/dynamixel/dynamixel.py) – for controlling Dynamixel servos + +Please refer to the [`MotorsBus`](https://github.com/huggingface/lerobot/blob/main/lerobot/common/motors/motors_bus.py) abstract class to learn about its API. +For a good example of how it can be used, you can have a look at our own [SO101 follower implementation](https://github.com/huggingface/lerobot/blob/main/lerobot/common/robots/so101_follower/so101_follower.py) + +Use these if compatible. Otherwise, you'll need to find or write a Python interface (not covered in this tutorial): +- Find an existing SDK in Python (or use bindings to C/C++) +- Or implement a basic communication wrapper (e.g., via pyserial, socket, or CANopen) + +You're not alone—many community contributions use custom boards or firmware! + +For Feetech and Dynamixel, we currently support these servos: + - Feetech: + - STS & SMS series (protocol 0): `sts3215`, `sts3250`, `sm8512bl` + - SCS series (protocol 1): `scs0009` + - Dynamixel (protocol 2.0 only): `xl330-m077`, `xl330-m288`, `xl430-w250`, `xm430-w350`, `xm540-w270`, `xc430-w150` + +If you are using Feetech or Dynamixel servos that are not in this list, you can add those in the [Feetech table](https://github.com/huggingface/lerobot/blob/main/lerobot/common/motors/feetech/tables.py) or [Dynamixel table](https://github.com/huggingface/lerobot/blob/main/lerobot/common/motors/dynamixel/tables.py). Depending on the model, this will require you to add model-specific information. In most cases though, there shouldn't be a lot of additions to do. + +In the next sections, we'll use a `FeetechMotorsBus` as the motors interface for the examples. Replace it and adapt to your motors if necessary. + +## Step 1: Subclass the `Robot` Interface + +You’ll first need to specify the config class and a string identifier (`name`) for your robot. If your robot has special needs that you'd like to be able to change easily, it should go here (e.g. port/address, baudrate). + +Here, we'll add the port name and one camera by default for our robot: +```python +from dataclasses import dataclass, field + +from lerobot.common.cameras import CameraConfig +from lerobot.common.cameras.opencv import OpenCVCameraConfig +from lerobot.common.robots import RobotConfig + + +@RobotConfig.register_subclass("my_cool_robot") +@dataclass +class MyCoolRobotConfig(RobotConfig): + port: str + cameras: dict[str, CameraConfig] = field( + default_factory={ + "cam_1": OpenCVCameraConfig( + index_or_path=2, + fps=30, + width=480, + height=640, + ), + } + ) +``` + +Have a look at our [Cameras tutorial](./cameras) to understand how to detect and add your camera. + +Next, we'll create our actual robot class which inherits from `Robot`. This abstract class defines a contract you must follow for your robot to be usable with the rest of the LeRobot tools. + +Here we'll create a simple 5-DoF robot with one camera. It could be a simple arm but notice that the `Robot` abstract class does not assume anything on your robot's form factor. You can let you imagination run wild when designing new robots! + +```python +from lerobot.common.cameras import make_cameras_from_configs +from lerobot.common.motors import Motor, MotorNormMode +from lerobot.common.motors.feetech import FeetechMotorsBus +from lerobot.common.robots import Robot + +class MyCoolRobot(Robot): + config_class = MyCoolRobotConfig + name = "my_cool_robot" + + def __init__(self, config: MyCoolRobotConfig): + super().__init__(config) + self.bus = FeetechMotorsBus( + port=self.config.port, + motors={ + "joint_1": Motor(1, "sts3250", MotorNormMode.RANGE_M100_100), + "joint_2": Motor(2, "sts3215", MotorNormMode.RANGE_M100_100), + "joint_3": Motor(3, "sts3215", MotorNormMode.RANGE_M100_100), + "joint_4": Motor(4, "sts3215", MotorNormMode.RANGE_M100_100), + "joint_5": Motor(5, "sts3215", MotorNormMode.RANGE_M100_100), + }, + calibration=self.calibration, + ) + self.cameras = make_cameras_from_configs(config.cameras) +``` + +## Step 2: Define Observation and Action Features + +These two properties define the *interface contract* between your robot and tools that consume it (such as data collection or learning pipelines). + +> [!WARNING] +> Note that these properties must be callable even if the robot is not yet connected, so avoid relying on runtime hardware state to define them. + +### `observation_features` + +This property should return a dictionary describing the structure of sensor outputs from your robot. The keys match what `get_observation()` returns, and the values describe either the shape (for arrays/images) or the type (for simple values). + +Example for our 5-DoF arm with one camera: +```python +@property +def _motors_ft(self) -> dict[str, type]: + return { + "joint_1.pos": float, + "joint_2.pos": float, + "joint_3.pos": float, + "joint_4.pos": float, + "joint_5.pos": float, + } + +@property +def _cameras_ft(self) -> dict[str, tuple]: + return { + cam: (self.cameras[cam].height, self.cameras[cam].width, 3) for cam in self.cameras + } + +@property +def observation_features(self) -> dict: + return {**self._motors_ft, **self._cameras_ft} +``` +In this case, observations consist of a simple dict storing each motor's position and a camera image. + +### `action_features` + +This property describes the commands your robot expects via `send_action()`. Again, keys must match the expected input format, and values define the shape/type of each command. + +Here, we simply use the same joints proprioceptive features (`self._motors_ft`) as with `observation_features`: the action sent will simply the goal position for each motor. +```python +def action_features(self) -> dict: + return self._motors_ft +``` + +## Step 3: Handle Connection and Disconnection + +These methods should handle opening and closing communication with your hardware (e.g. serial ports, CAN interfaces, USB devices, cameras). + +### `is_connected` + +This property should simply reflect that communication with the robot's hardware is established. When this property is `True`, it should be possible to read and write to the hardware using `get_observation()` and `send_action()`. + +```python +@property +def is_connected(self) -> bool: + return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) +``` + +### `connect()` + +This method should establish communication with the hardware. Moreover, if your robot needs calibration and is not calibrated, it should start a calibration procedure by default. If your robot needs some specific configuration, this should also be called here. + +```python +def connect(self, calibrate: bool = True) -> None: + self.bus.connect() + if not self.is_calibrated and calibrate: + self.calibrate() + + for cam in self.cameras.values(): + cam.connect() + + self.configure() +``` + +### `disconnect()` + +This method should gracefully terminate communication with the hardware: free any related resources (threads or processes), close ports, etc. + +Here, we already handle this in our `MotorsBus` and `Camera` classes so we just need to call their own `disconnect()` methods: +```python +def disconnect(self) -> None: + self.bus.disconnect() + for cam in self.cameras.values(): + cam.disconnect() +``` + +## Step 4: Support Calibration and Configuration + +LeRobot supports saving and loading calibration data automatically. This is useful for joint offsets, zero positions, or sensor alignment. + +> Note that depending on your hardware, this may not apply. If that's the case, you can simply leave these methods as no-ops: +> ```python +> @property +> def is_calibrated(self) -> bool: +> return True +> +> def calibrate(self) -> None: +> pass +> ``` + +### `is_calibrated` + +This should reflect whether your robot has the required calibration loaded. + +```python +@property +def is_calibrated(self) -> bool: + return self.bus.is_calibrated +``` + +### `calibrate()` + +The goal of the calibration is twofold: + - Know the physical range of motion of each motors in order to only send commands within this range. + - Normalize raw motors positions to sensible continuous values (e.g. percentages, degrees) instead of arbitrary discrete value dependant on the specific motor used that will not replicate elsewhere. + +It should implement the logic for calibration (if relevant) and update the `self.calibration` dictionary. If you are using Feetech or Dynamixel motors, our bus interfaces already include methods to help with this. + +```python +def calibrate(self) -> None: + self.bus.disable_torque() + for motor in self.bus.motors: + self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value) + + input(f"Move {self} to the middle of its range of motion and press ENTER....") + homing_offsets = self.bus.set_half_turn_homings() + + print( + "Move all joints sequentially through their entire ranges " + "of motion.\nRecording positions. Press ENTER to stop..." + ) + range_mins, range_maxes = self.bus.record_ranges_of_motion() + + self.calibration = {} + for motor, m in self.bus.motors.items(): + self.calibration[motor] = MotorCalibration( + id=m.id, + drive_mode=0, + homing_offset=homing_offsets[motor], + range_min=range_mins[motor], + range_max=range_maxes[motor], + ) + + self.bus.write_calibration(self.calibration) + self._save_calibration() + print("Calibration saved to", self.calibration_fpath) +``` + +### `configure()` + +Use this to set up any configuration for your hardware (servos control modes, controller gains, etc.). This should usually be run at connection time and be idempotent. + +```python +def configure(self) -> None: + with self.bus.torque_disabled(): + self.bus.configure_motors() + for motor in self.bus.motors: + self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value) + self.bus.write("P_Coefficient", motor, 16) + self.bus.write("I_Coefficient", motor, 0) + self.bus.write("D_Coefficient", motor, 32) +``` + +## Step 5: Implement Sensors Reading and Action Sending + +These are the most important runtime functions: the core I/O loop. + +### `get_observation()` + +Returns a dictionary of sensor values from the robot. These typically include motor states, camera frames, various sensors, etc. In the LeRobot framework, these observations are what will be fed to a policy in order to predict the actions to take. The dictionary keys and structure must match `observation_features`. + +```python +def get_observation(self) -> dict[str, Any]: + if not self.is_connected: + raise ConnectionError(f"{self} is not connected.") + + # Read arm position + obs_dict = self.bus.sync_read("Present_Position") + obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()} + + # Capture images from cameras + for cam_key, cam in self.cameras.items(): + obs_dict[cam_key] = cam.async_read() + + return obs_dict +``` + +### `send_action()` + +Takes a dictionary that matches `action_features`, and sends it to your hardware. You can add safety limits (clipping, smoothing) and return what was actually sent. + +For simplicity, we won't be adding any modification of the actions in our example here. + +```python +def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + goal_pos = {key.removesuffix(".pos"): val for key, val in action.items()} + + # Send goal position to the arm + self.bus.sync_write("Goal_Position", goal_pos) + + return action +``` + +## Adding a Teleoperator + +For implementing teleoperation devices, we also provide a [`Teleoperator`](https://github.com/huggingface/lerobot/blob/main/lerobot/common/teleoperators/teleoperator.py) base class. This class is very similar to the `Robot` base class and also doesn't assume anything on form factor. + +The main differences are in the I/O functions: a teleoperator allows you to produce action via `get_action` and can receive feedback actions via `send_feedback`. Feedback could be anything controllable on the teleoperation device that could help the person controlling it understand the consequences of the actions sent. Think motion/force feedback on a leader arm, vibrations on a gamepad controller for example. To implement a teleoperator, you can follow this same tutorial and adapt it for these two methods. + +## Wrapping Up + +Once your robot class is complete, you can leverage the LeRobot ecosystem: + +- Control your robot with available teleoperators or integrate directly your teleoperating device +- Record training data and visualize it +- Integrate it into RL or imitation learning pipelines + +Don't hesitate to reach out to the community for help on our [Discord](https://discord.gg/s3KuuzsPFb) 🤗 diff --git a/docs/source/notebooks.mdx b/docs/source/notebooks.mdx new file mode 100644 index 0000000000..729b31a99d --- /dev/null +++ b/docs/source/notebooks.mdx @@ -0,0 +1,29 @@ +# 🤗 LeRobot Notebooks + +This repository contains example notebooks for using LeRobot. These notebooks demonstrate how to train policies on real or simulation datasets using standardized policies. + +--- + +### Training ACT + +[ACT](https://huggingface.co/papers/2304.13705) (Action Chunking Transformer) is a transformer-based policy architecture for imitation learning that processes robot states and camera inputs to generate smooth, chunked action sequences. + +We provide a ready-to-run Google Colab notebook to help you train ACT policies using datasets from the Hugging Face Hub, with optional logging to Weights & Biases. + +| Notebook | Colab | +|:---------|:------| +| [Train ACT with LeRobot](https://github.com/huggingface/notebooks/blob/main/lerobot/training-act.ipynb) | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/lerobot/training-act.ipynb) | + +Expected training time for 100k steps: ~1.5 hours on an NVIDIA A100 GPU with batch size of `64`. + +### Training SmolVLA + +[SmolVLA](https://huggingface.co/papers/2506.01844) is a small but efficient Vision-Language-Action model. It is compact in size with 450 M-parameter and is developed by Hugging Face. + +We provide a ready-to-run Google Colab notebook to help you train SmolVLA policies using datasets from the Hugging Face Hub, with optional logging to Weights & Biases. + +| Notebook | Colab | +| :-------------------------------------------------------------------------------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| [Train SmolVLA with LeRobot](https://github.com/huggingface/notebooks/blob/main/lerobot/training-smolvla.ipynb) | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/lerobot/training-smolvla.ipynb) | + +Expected training time for 20k steps: ~5 hours on an NVIDIA A100 GPU with batch size of `64`. diff --git a/docs/source/smolvla.mdx b/docs/source/smolvla.mdx new file mode 100644 index 0000000000..58340baa0d --- /dev/null +++ b/docs/source/smolvla.mdx @@ -0,0 +1,93 @@ +# Finetune SmolVLA + +SmolVLA is Hugging Face’s lightweight foundation model for robotics. Designed for easy fine-tuning on LeRobot datasets, it helps accelerate your development! + +

+ SmolVLA architecture. +
+ Figure 1. SmolVLA takes as input (i) multiple cameras views, (ii) the robot’s current sensorimotor state, and (iii) a natural language instruction, encoded into contextual features used to condition the action expert when generating an action chunk. +

+ +## Set Up Your Environment + +1. Install LeRobot by following our [Installation Guide](./installation). +2. Install SmolVLA dependencies by running: + + ```bash + pip install -e ".[smolvla]" + ``` + +## Collect a dataset + +SmolVLA is a base model, so fine-tuning on your own data is required for optimal performance in your setup. +We recommend recording ~50 episodes of your task as a starting point. Follow our guide to get started: [Recording a Dataset](https://huggingface.co/docs/lerobot/getting_started_real_world_robot#record-a-dataset) + + + +In your dataset, make sure to have enough demonstrations per each variation (e.g. the cube position on the table if it is cube pick-place task) you are introducing. + +We recommend checking out the dataset linked below for reference that was used in the [SmolVLA paper](https://huggingface.co/papers/2506.01844): + +🔗 [SVLA SO100 PickPlace](https://huggingface.co/spaces/lerobot/visualize_dataset?path=%2Flerobot%2Fsvla_so100_pickplace%2Fepisode_0) + +In this dataset, we recorded 50 episodes across 5 distinct cube positions. For each position, we collected 10 episodes of pick-and-place interactions. This structure, repeating each variation several times, helped the model generalize better. We tried similar dataset with 25 episodes, and it was not enough leading to a bad performance. So, the data quality and quantity is definitely a key. +After you have your dataset available on the Hub, you are good to go to use our finetuning script to adapt SmolVLA to your application. + + +## Finetune SmolVLA on your data + +Use [`smolvla_base`](https://hf.co/lerobot/smolvla_base), our pretrained 450M model, and fine-tune it on your data. +Training the model for 20k steps will roughly take ~4 hrs on a single A100 GPU. You should tune the number of steps based on performance and your use-case. + +If you don't have a gpu device, you can train using our notebook on [![Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/lerobot/training-smolvla.ipynb) + +Pass your dataset to the training script using `--dataset.repo_id`. If you want to test your installation, run the following command where we use one of the datasets we collected for the [SmolVLA Paper](https://huggingface.co/papers/2506.01844). + +```bash +cd lerobot && python lerobot/scripts/train.py \ + --policy.path=lerobot/smolvla_base \ + --dataset.repo_id=${HF_USER}/mydataset \ + --batch_size=64 \ + --steps=20000 \ + --output_dir=outputs/train/my_smolvla \ + --job_name=my_smolvla_training \ + --policy.device=cuda \ + --wandb.enable=true +``` + + +You can start with a small batch size and increase it incrementally, if the GPU allows it, as long as loading times remain short. + + +Fine-tuning is an art. For a complete overview of the options for finetuning, run + +```bash +python lerobot/scripts/train.py --help +``` + +

+ Comparison of SmolVLA across task variations. +
+ Figure 2: Comparison of SmolVLA across task variations. From left to right: (1) pick-place cube counting, (2) pick-place cube counting, (3) pick-place cube counting under perturbations, and (4) generalization on pick-and-place of the lego block with real-world SO101. +

+ + +## Evaluate the finetuned model and run it in real-time + +Similarly for when recording an episode, it is recommended that you are logged in to the HuggingFace Hub. You can follow the corresponding steps: [Record a dataset](./getting_started_real_world_robot#record-a-dataset). +Once you are logged in, you can run inference in your setup by doing: + +```bash +python -m lerobot.record \ + --robot.type=so101_follower \ + --robot.port=/dev/ttyACM0 \ # <- Use your port + --robot.id=my_blue_follower_arm \ # <- Use your robot id + --robot.cameras="{ front: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \ # <- Use your cameras + --dataset.single_task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording + --dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub + --dataset.episode_time_s=50 \ + --dataset.num_episodes=10 \ + --policy.path=HF_USER/FINETUNE_MODEL_NAME # <- Use your fine-tuned model +``` + +Depending on your evaluation setup, you can configure the duration and the number of episodes to record for your evaluation suite. diff --git a/examples/lekiwi/record.py b/examples/lekiwi/record.py index 4f56213d7b..405a41bd38 100644 --- a/examples/lekiwi/record.py +++ b/examples/lekiwi/record.py @@ -23,7 +23,7 @@ dataset_features = {**action_features, **obs_features} dataset = LeRobotDataset.create( - repo_id="user/lekiwi" + str(int(time.time())), + repo_id="pepijn223/lekiwi" + str(int(time.time())), fps=10, features=dataset_features, robot_type=robot.name, @@ -36,7 +36,7 @@ if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected: exit() -print("Starting LeKiwi teleoperation") +print("Starting LeKiwi recording") i = 0 while i < NB_CYCLES_CLIENT_CONNECTION: arm_action = leader_arm.get_action() diff --git a/lerobot/common/constants.py b/lerobot/common/constants.py index e78e748baf..990f2aa1eb 100644 --- a/lerobot/common/constants.py +++ b/lerobot/common/constants.py @@ -22,6 +22,7 @@ OBS_IMAGE = "observation.image" OBS_IMAGES = "observation.images" ACTION = "action" +REWARD = "next.reward" ROBOTS = "robots" TELEOPERATORS = "teleoperators" diff --git a/lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py b/lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py index 41dd33b62b..9b21cf7ca4 100644 --- a/lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py +++ b/lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py @@ -36,7 +36,7 @@ "robot_config": AlohaRobotConfig(), "license": "mit", "url": "https://mobile-aloha.github.io/", - "paper": "https://arxiv.org/abs/2401.02117", + "paper": "https://huggingface.co/papers/2401.02117", "citation_bibtex": dedent(r""" @inproceedings{fu2024mobile, author = {Fu, Zipeng and Zhao, Tony Z. and Finn, Chelsea}, @@ -49,7 +49,7 @@ "robot_config": AlohaRobotConfig(), "license": "mit", "url": "https://tonyzhaozh.github.io/aloha/", - "paper": "https://arxiv.org/abs/2304.13705", + "paper": "https://huggingface.co/papers/2304.13705", "citation_bibtex": dedent(r""" @article{Zhao2023LearningFB, title={Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware}, @@ -57,13 +57,13 @@ journal={RSS}, year={2023}, volume={abs/2304.13705}, - url={https://arxiv.org/abs/2304.13705} + url={https://huggingface.co/papers/2304.13705} }""").lstrip(), } PUSHT_INFO = { "license": "mit", "url": "https://diffusion-policy.cs.columbia.edu/", - "paper": "https://arxiv.org/abs/2303.04137v5", + "paper": "https://huggingface.co/papers/2303.04137", "citation_bibtex": dedent(r""" @article{chi2024diffusionpolicy, author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song}, @@ -75,7 +75,7 @@ XARM_INFO = { "license": "mit", "url": "https://www.nicklashansen.com/td-mpc/", - "paper": "https://arxiv.org/abs/2203.04955", + "paper": "https://huggingface.co/papers/2203.04955", "citation_bibtex": dedent(r""" @inproceedings{Hansen2022tdmpc, title={Temporal Difference Learning for Model Predictive Control}, @@ -244,7 +244,7 @@ "tasks_col": "language_instruction", "license": "mit", "url": "https://ut-austin-rpl.github.io/BUDS-website/", - "paper": "https://arxiv.org/abs/2109.13841", + "paper": "https://huggingface.co/papers/2109.13841", "citation_bibtex": dedent(r""" @article{zhu2022bottom, title={Bottom-Up Skill Discovery From Unsegmented Demonstrations for Long-Horizon Robot Manipulation}, @@ -261,7 +261,7 @@ "tasks_col": "language_instruction", "license": "mit", "url": "https://ut-austin-rpl.github.io/sailor/", - "paper": "https://arxiv.org/abs/2210.11435", + "paper": "https://huggingface.co/papers/2210.11435", "citation_bibtex": dedent(r""" @inproceedings{nasiriany2022sailor, title={Learning and Retrieval from Prior Data for Skill-based Imitation Learning}, @@ -274,7 +274,7 @@ "tasks_col": "language_instruction", "license": "mit", "url": "https://ut-austin-rpl.github.io/sirius/", - "paper": "https://arxiv.org/abs/2211.08416", + "paper": "https://huggingface.co/papers/2211.08416", "citation_bibtex": dedent(r""" @inproceedings{liu2022robot, title = {Robot Learning on the Job: Human-in-the-Loop Autonomy and Learning During Deployment}, @@ -298,14 +298,14 @@ "tasks_col": "language_instruction", "license": "cc-by-4.0", "url": "https://sites.google.com/view/cablerouting/home", - "paper": "https://arxiv.org/abs/2307.08927", + "paper": "https://huggingface.co/papers/2307.08927", "citation_bibtex": dedent(r""" @article{luo2023multistage, author = {Jianlan Luo and Charles Xu and Xinyang Geng and Gilbert Feng and Kuan Fang and Liam Tan and Stefan Schaal and Sergey Levine}, title = {Multi-Stage Cable Routing through Hierarchical Imitation Learning}, journal = {arXiv pre-print}, year = {2023}, - url = {https://arxiv.org/abs/2307.08927}, + url = {https://huggingface.co/papers/2307.08927}, }""").lstrip(), }, "berkeley_fanuc_manipulation": { @@ -322,7 +322,7 @@ "berkeley_gnm_cory_hall": { "tasks_col": "language_instruction", "license": "mit", - "paper": "https://arxiv.org/abs/1709.10489", + "paper": "https://huggingface.co/papers/1709.10489", "citation_bibtex": dedent(r""" @inproceedings{kahn2018self, title={Self-supervised deep reinforcement learning with generalized computation graphs for robot navigation}, @@ -337,7 +337,7 @@ "tasks_col": "language_instruction", "license": "mit", "url": "https://sites.google.com/view/recon-robot", - "paper": "https://arxiv.org/abs/2104.05859", + "paper": "https://huggingface.co/papers/2104.05859", "citation_bibtex": dedent(r""" @inproceedings{shah2021rapid, title={Rapid Exploration for Open-World Navigation with Latent Goal Models}, @@ -351,7 +351,7 @@ "tasks_col": "language_instruction", "license": "mit", "url": "https://sites.google.com/view/SACSoN-review", - "paper": "https://arxiv.org/abs/2306.01874", + "paper": "https://huggingface.co/papers/2306.01874", "citation_bibtex": dedent(r""" @article{hirose2023sacson, title={SACSoN: Scalable Autonomous Data Collection for Social Navigation}, @@ -363,7 +363,7 @@ "berkeley_mvp": { "tasks_col": "language_instruction", "license": "mit", - "paper": "https://arxiv.org/abs/2203.06173", + "paper": "https://huggingface.co/papers/2203.06173", "citation_bibtex": dedent(r""" @InProceedings{Radosavovic2022, title = {Real-World Robot Learning with Masked Visual Pre-training}, @@ -375,7 +375,7 @@ "berkeley_rpt": { "tasks_col": "language_instruction", "license": "mit", - "paper": "https://arxiv.org/abs/2306.10007", + "paper": "https://huggingface.co/papers/2306.10007", "citation_bibtex": dedent(r""" @article{Radosavovic2023, title={Robot Learning with Sensorimotor Pre-training}, @@ -388,7 +388,7 @@ "tasks_col": "language_instruction", "license": "mit", "url": "https://human-world-model.github.io/", - "paper": "https://arxiv.org/abs/2308.10901", + "paper": "https://huggingface.co/papers/2308.10901", "citation_bibtex": dedent(r""" @inproceedings{mendonca2023structured, title={Structured World Models from Human Videos}, @@ -401,7 +401,7 @@ "tasks_col": "language_instruction", "license": "mit", "url": "https://play-fusion.github.io/", - "paper": "https://arxiv.org/abs/2312.04549", + "paper": "https://huggingface.co/papers/2312.04549", "citation_bibtex": dedent(r""" @inproceedings{chen2023playfusion, title={PlayFusion: Skill Acquisition via Diffusion from Language-Annotated Play}, @@ -414,7 +414,7 @@ "tasks_col": "language_instruction", "license": "mit", "url": "https://robo-affordances.github.io/", - "paper": "https://arxiv.org/abs/2304.08488", + "paper": "https://huggingface.co/papers/2304.08488", "citation_bibtex": dedent(r""" @inproceedings{bahl2023affordances, title={Affordances from Human Videos as a Versatile Representation for Robotics}, @@ -433,7 +433,7 @@ "tasks_col": "language_instruction", "license": "mit", "url": "https://diffusion-policy.cs.columbia.edu/", - "paper": "https://arxiv.org/abs/2303.04137v5", + "paper": "https://huggingface.co/papers/2303.04137", "citation_bibtex": dedent(r""" @inproceedings{chi2023diffusionpolicy, title={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion}, @@ -505,7 +505,7 @@ "tasks_col": "language_instruction", "license": "mit", "url": "https://droid-dataset.github.io/", - "paper": "https://arxiv.org/abs/2403.12945", + "paper": "https://huggingface.co/papers/2403.12945", "citation_bibtex": dedent(r""" @article{khazatsky2024droid, title = {DROID: A Large-Scale In-The-Wild Robot Manipulation Dataset}, @@ -517,7 +517,7 @@ "tasks_col": "language_instruction", "license": "cc-by-4.0", "url": "https://functional-manipulation-benchmark.github.io/", - "paper": "https://arxiv.org/abs/2401.08553", + "paper": "https://huggingface.co/papers/2401.08553", "citation_bibtex": dedent(r""" @article{luo2024fmb, title={FMB: a Functional Manipulation Benchmark for Generalizable Robotic Learning}, @@ -530,7 +530,7 @@ "tasks_col": "language_instruction", "license": "mit", "url": "https://openreview.net/forum?id=WuBv9-IGDUA", - "paper": "https://arxiv.org/abs/2401.14502", + "paper": "https://huggingface.co/papers/2401.14502", "citation_bibtex": dedent(r""" @inproceedings{saxena2023multiresolution, title={Multi-Resolution Sensing for Real-Time Control with Vision-Language Models}, @@ -575,7 +575,7 @@ "tasks_col": "language_instruction", "license": "mit", "url": "https://jyopari.github.io/VINN/", - "paper": "https://arxiv.org/abs/2112.01511", + "paper": "https://huggingface.co/papers/2112.01511", "citation_bibtex": dedent(r""" @misc{pari2021surprising, title={The Surprising Effectiveness of Representation Learning for Visual Imitation}, @@ -590,7 +590,7 @@ "tasks_col": "language_instruction", "license": "mit", "url": "https://play-to-policy.github.io/", - "paper": "https://arxiv.org/abs/2210.10047", + "paper": "https://huggingface.co/papers/2210.10047", "citation_bibtex": dedent(r""" @article{cui2022play, title = {From Play to Policy: Conditional Behavior Generation from Uncurated Robot Data}, @@ -603,7 +603,7 @@ "tasks_col": "language_instruction", "license": "mit", "url": "https://rot-robot.github.io/", - "paper": "https://arxiv.org/abs/2206.15469", + "paper": "https://huggingface.co/papers/2206.15469", "citation_bibtex": dedent(r""" @inproceedings{haldar2023watch, title={Watch and match: Supercharging imitation with regularized optimal transport}, @@ -633,7 +633,7 @@ "tasks_col": "language_instruction", "license": "mit", "url": "https://sites.google.com/view/hydra-il-2023", - "paper": "https://arxiv.org/abs/2306.17237", + "paper": "https://huggingface.co/papers/2306.17237", "citation_bibtex": dedent(r""" @article{belkhale2023hydra, title={HYDRA: Hybrid Robot Actions for Imitation Learning}, @@ -646,21 +646,21 @@ "tasks_col": "language_instruction", "license": "mit", "url": "https://sites.google.com/view/visionandtouch", - "paper": "https://arxiv.org/abs/1810.10191", + "paper": "https://huggingface.co/papers/1810.10191", "citation_bibtex": dedent(r""" @inproceedings{lee2019icra, title={Making sense of vision and touch: Self-supervised learning of multimodal representations for contact-rich tasks}, author={Lee, Michelle A and Zhu, Yuke and Srinivasan, Krishnan and Shah, Parth and Savarese, Silvio and Fei-Fei, Li and Garg, Animesh and Bohg, Jeannette}, booktitle={2019 IEEE International Conference on Robotics and Automation (ICRA)}, year={2019}, - url={https://arxiv.org/abs/1810.10191} + url={https://huggingface.co/papers/1810.10191} }""").lstrip(), }, "stanford_robocook": { "tasks_col": "language_instruction", "license": "mit", "url": "https://hshi74.github.io/robocook/", - "paper": "https://arxiv.org/abs/2306.14447", + "paper": "https://huggingface.co/papers/2306.14447", "citation_bibtex": dedent(r""" @article{shi2023robocook, title={RoboCook: Long-Horizon Elasto-Plastic Object Manipulation with Diverse Tools}, @@ -673,7 +673,7 @@ "tasks_col": "language_instruction", "license": "cc-by-4.0", "url": "https://www.kaggle.com/datasets/oiermees/taco-robot", - "paper": "https://arxiv.org/abs/2209.08959, https://arxiv.org/abs/2210.01911", + "paper": "https://huggingface.co/papers/2209.08959, https://huggingface.co/papers/2210.01911", "citation_bibtex": dedent(r""" @inproceedings{rosete2022tacorl, author = {Erick Rosete-Beas and Oier Mees and Gabriel Kalweit and Joschka Boedecker and Wolfram Burgard}, @@ -693,7 +693,7 @@ "tasks_col": "language_instruction", "license": "mit", "url": "URL", - "paper": "https://arxiv.org/abs/2107.05842", + "paper": "https://huggingface.co/papers/2107.05842", "citation_bibtex": dedent(r""" @Article{Osa22, author = {Takayuki Osa}, @@ -709,7 +709,7 @@ "tasks_col": "language_instruction", "license": "mit", "url": "https://toto-benchmark.org/", - "paper": "https://arxiv.org/abs/2306.00942", + "paper": "https://huggingface.co/papers/2306.00942", "citation_bibtex": dedent(r""" @inproceedings{zhou2023train, author={Zhou, Gaoyue and Dean, Victoria and Srirama, Mohan Kumar and Rajeswaran, Aravind and Pari, Jyothish and Hatch, Kyle and Jain, Aryan and Yu, Tianhe and Abbeel, Pieter and Pinto, Lerrel and Finn, Chelsea and Gupta, Abhinav}, @@ -733,7 +733,7 @@ "tasks_col": "language_instruction", "license": "mit", "url": "https://owmcorl.github.io/#", - "paper": "https://arxiv.org/abs/2310.16029", + "paper": "https://huggingface.co/papers/2310.16029", "citation_bibtex": dedent(r""" @preprint{Feng2023Finetuning, title={Finetuning Offline World Models in the Real World}, @@ -745,7 +745,7 @@ "tasks_col": "language_instruction", "license": "mit", "url": "https://robopil.github.io/d3fields/", - "paper": "https://arxiv.org/abs/2309.16118", + "paper": "https://huggingface.co/papers/2309.16118", "citation_bibtex": dedent(r""" @article{wang2023d3field, title={D^3Field: Dynamic 3D Descriptor Fields for Generalizable Robotic Manipulation}, @@ -758,7 +758,7 @@ "tasks_col": "language_instruction", "license": "mit", "url": "https://uscresl.github.io/dmfd/", - "paper": "https://arxiv.org/abs/2207.10148", + "paper": "https://huggingface.co/papers/2207.10148", "citation_bibtex": dedent(r""" @article{salhotra2022dmfd, author={Salhotra, Gautam and Liu, I-Chun Arthur and Dominguez-Kuhne, Marcus and Sukhatme, Gaurav S.}, @@ -775,7 +775,7 @@ "tasks_col": "language_instruction", "license": "mit", "url": "https://ut-austin-rpl.github.io/MUTEX/", - "paper": "https://arxiv.org/abs/2309.14320", + "paper": "https://huggingface.co/papers/2309.14320", "citation_bibtex": dedent(r""" @inproceedings{shah2023mutex, title={{MUTEX}: Learning Unified Policies from Multimodal Task Specifications}, @@ -811,7 +811,7 @@ "tasks_col": "language_instruction", "license": "mit", "url": "https://saytap.github.io/", - "paper": "https://arxiv.org/abs/2306.07580", + "paper": "https://huggingface.co/papers/2306.07580", "citation_bibtex": dedent(r""" @article{saytap2023, author = {Yujin Tang and Wenhao Yu and Jie Tan and Heiga Zen and Aleksandra Faust and @@ -847,7 +847,7 @@ "tasks_col": "language_instruction", "license": "mit", "url": "https://ut-austin-rpl.github.io/VIOLA/", - "paper": "https://arxiv.org/abs/2210.11339", + "paper": "https://huggingface.co/papers/2210.11339", "citation_bibtex": dedent(r""" @article{zhu2022viola, title={VIOLA: Imitation Learning for Vision-Based Manipulation with Object Proposal Priors}, diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py index c99fba811d..ea081e9fbf 100644 --- a/lerobot/common/envs/configs.py +++ b/lerobot/common/envs/configs.py @@ -14,10 +14,13 @@ import abc from dataclasses import dataclass, field +from typing import Any, Optional import draccus from lerobot.common.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE +from lerobot.common.robots import RobotConfig +from lerobot.common.teleoperators.config import TeleoperatorConfig from lerobot.configs.types import FeatureType, PolicyFeature @@ -155,3 +158,116 @@ def gym_kwargs(self) -> dict: "visualization_height": self.visualization_height, "max_episode_steps": self.episode_length, } + + +@dataclass +class VideoRecordConfig: + """Configuration for video recording in ManiSkill environments.""" + + enabled: bool = False + record_dir: str = "videos" + trajectory_name: str = "trajectory" + + +@dataclass +class EnvTransformConfig: + """Configuration for environment wrappers.""" + + # ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig) + control_mode: str = "gamepad" + display_cameras: bool = False + add_joint_velocity_to_observation: bool = False + add_current_to_observation: bool = False + add_ee_pose_to_observation: bool = False + crop_params_dict: Optional[dict[str, tuple[int, int, int, int]]] = None + resize_size: Optional[tuple[int, int]] = None + control_time_s: float = 20.0 + fixed_reset_joint_positions: Optional[Any] = None + reset_time_s: float = 5.0 + use_gripper: bool = True + gripper_quantization_threshold: float | None = 0.8 + gripper_penalty: float = 0.0 + gripper_penalty_in_reward: bool = False + + +@EnvConfig.register_subclass(name="gym_manipulator") +@dataclass +class HILSerlRobotEnvConfig(EnvConfig): + """Configuration for the HILSerlRobotEnv environment.""" + + robot: Optional[RobotConfig] = None + teleop: Optional[TeleoperatorConfig] = None + wrapper: Optional[EnvTransformConfig] = None + fps: int = 10 + name: str = "real_robot" + mode: str = None # Either "record", "replay", None + repo_id: Optional[str] = None + dataset_root: Optional[str] = None + task: str = "" + num_episodes: int = 10 # only for record mode + episode: int = 0 + device: str = "cuda" + push_to_hub: bool = True + pretrained_policy_name_or_path: Optional[str] = None + reward_classifier_pretrained_path: Optional[str] = None + # For the reward classifier, to record more positive examples after a success + number_of_steps_after_success: int = 0 + + def gym_kwargs(self) -> dict: + return {} + + +@EnvConfig.register_subclass("hil") +@dataclass +class HILEnvConfig(EnvConfig): + """Configuration for the HIL environment.""" + + type: str = "hil" + name: str = "PandaPickCube" + task: str = "PandaPickCubeKeyboard-v0" + use_viewer: bool = True + gripper_penalty: float = 0.0 + use_gamepad: bool = True + state_dim: int = 18 + action_dim: int = 4 + fps: int = 100 + episode_length: int = 100 + video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig) + features: dict[str, PolicyFeature] = field( + default_factory=lambda: { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)), + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(18,)), + } + ) + features_map: dict[str, str] = field( + default_factory=lambda: { + "action": ACTION, + "observation.image": OBS_IMAGE, + "observation.state": OBS_STATE, + } + ) + ################# args from hilserlrobotenv + reward_classifier_pretrained_path: Optional[str] = None + robot_config: Optional[RobotConfig] = None + teleop_config: Optional[TeleoperatorConfig] = None + wrapper: Optional[EnvTransformConfig] = None + mode: str = None # Either "record", "replay", None + repo_id: Optional[str] = None + dataset_root: Optional[str] = None + num_episodes: int = 10 # only for record mode + episode: int = 0 + device: str = "cuda" + push_to_hub: bool = True + pretrained_policy_name_or_path: Optional[str] = None + # For the reward classifier, to record more positive examples after a success + number_of_steps_after_success: int = 0 + ############################ + + @property + def gym_kwargs(self) -> dict: + return { + "use_viewer": self.use_viewer, + "use_gamepad": self.use_gamepad, + "gripper_penalty": self.gripper_penalty, + } diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 8450f84b95..4f5d59c698 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -17,7 +17,7 @@ import gymnasium as gym -from lerobot.common.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv +from lerobot.common.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv def make_env_config(env_type: str, **kwargs) -> EnvConfig: @@ -27,6 +27,8 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig: return PushtEnv(**kwargs) elif env_type == "xarm": return XarmEnv(**kwargs) + elif env_type == "hil": + return HILEnvConfig(**kwargs) else: raise ValueError(f"Policy type '{env_type}' is not available.") diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py index 83334f876d..66d6e5f93f 100644 --- a/lerobot/common/envs/utils.py +++ b/lerobot/common/envs/utils.py @@ -47,6 +47,10 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten # TODO(aliberts, rcadene): use transforms.ToTensor()? img = torch.from_numpy(img) + # When preprocessing observations in a non-vectorized environment, we need to add a batch dimension. + # This is the case for human-in-the-loop RL where there is only one environment. + if img.ndim == 3: + img = img.unsqueeze(0) # sanity check that images are channel last _, h, w, c = img.shape assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}" @@ -62,13 +66,18 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten return_observations[imgkey] = img if "environment_state" in observations: - return_observations["observation.environment_state"] = torch.from_numpy( - observations["environment_state"] - ).float() + env_state = torch.from_numpy(observations["environment_state"]).float() + if env_state.dim() == 1: + env_state = env_state.unsqueeze(0) + + return_observations["observation.environment_state"] = env_state # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing - # requirement for "agent_pos" - return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float() + agent_pos = torch.from_numpy(observations["agent_pos"]).float() + if agent_pos.dim() == 1: + agent_pos = agent_pos.unsqueeze(0) + return_observations["observation.state"] = agent_pos + return return_observations diff --git a/lerobot/common/model/kinematics.py b/lerobot/common/model/kinematics.py new file mode 100644 index 0000000000..367b609e19 --- /dev/null +++ b/lerobot/common/model/kinematics.py @@ -0,0 +1,483 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +from numpy.typing import NDArray +from scipy.spatial.transform import Rotation + + +def skew_symmetric(w: NDArray[np.float32]) -> NDArray[np.float32]: + """Creates the skew-symmetric matrix from a 3D vector.""" + return np.array([[0, -w[2], w[1]], [w[2], 0, -w[0]], [-w[1], w[0], 0]]) + + +def rodrigues_rotation(w: NDArray[np.float32], theta: float) -> NDArray[np.float32]: + """Computes the rotation matrix using Rodrigues' formula.""" + w_hat = skew_symmetric(w) + return np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat + + +def screw_axis_to_transform(s: NDArray[np.float32], theta: float) -> NDArray[np.float32]: + """Converts a screw axis to a 4x4 transformation matrix.""" + screw_axis_rot = s[:3] + screw_axis_trans = s[3:] + + # Pure translation + if np.allclose(screw_axis_rot, 0) and np.linalg.norm(screw_axis_trans) == 1: + transform = np.eye(4) + transform[:3, 3] = screw_axis_trans * theta + + # Rotation (and potentially translation) + elif np.linalg.norm(screw_axis_rot) == 1: + w_hat = skew_symmetric(screw_axis_rot) + rot_mat = np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat + t = ( + np.eye(3) * theta + (1 - np.cos(theta)) * w_hat + (theta - np.sin(theta)) * w_hat @ w_hat + ) @ screw_axis_trans + transform = np.eye(4) + transform[:3, :3] = rot_mat + transform[:3, 3] = t + else: + raise ValueError("Invalid screw axis parameters") + return transform + + +def pose_difference_se3(pose1: NDArray[np.float32], pose2: NDArray[np.float32]) -> NDArray[np.float32]: + """ + Calculates the SE(3) difference between two 4x4 homogeneous transformation matrices. + SE(3) (Special Euclidean Group) represents rigid body transformations in 3D space, + combining rotation (SO(3)) and translation. + + Each 4x4 matrix has the following structure: + [R11 R12 R13 tx] + [R21 R22 R23 ty] + [R31 R32 R33 tz] + [ 0 0 0 1] + + where R is the 3x3 rotation matrix and [tx,ty,tz] is the translation vector. + + Args: + pose1: A 4x4 numpy array representing the first pose. + pose2: A 4x4 numpy array representing the second pose. + + Returns: + A 6D numpy array concatenating translation and rotation differences. + First 3 elements are the translational difference (position). + Last 3 elements are the rotational difference in axis-angle representation. + """ + rot1 = pose1[:3, :3] + rot2 = pose2[:3, :3] + + translation_diff = pose1[:3, 3] - pose2[:3, 3] + + # Calculate rotational difference using scipy's Rotation library + rot_diff = Rotation.from_matrix(rot1 @ rot2.T) + rotation_diff = rot_diff.as_rotvec() # Axis-angle representation + + return np.concatenate([translation_diff, rotation_diff]) + + +def se3_error(target_pose: NDArray[np.float32], current_pose: NDArray[np.float32]) -> NDArray[np.float32]: + pos_error = target_pose[:3, 3] - current_pose[:3, 3] + + rot_target = target_pose[:3, :3] + rot_current = current_pose[:3, :3] + rot_error_mat = rot_target @ rot_current.T + rot_error = Rotation.from_matrix(rot_error_mat).as_rotvec() + + return np.concatenate([pos_error, rot_error]) + + +class RobotKinematics: + """Robot kinematics class supporting multiple robot models.""" + + # Robot measurements dictionary + ROBOT_MEASUREMENTS = { + "koch": { + "gripper": [0.239, -0.001, 0.024], + "wrist": [0.209, 0, 0.024], + "forearm": [0.108, 0, 0.02], + "humerus": [0, 0, 0.036], + "shoulder": [0, 0, 0], + "base": [0, 0, 0.02], + }, + "moss": { + "gripper": [0.246, 0.013, 0.111], + "wrist": [0.245, 0.002, 0.064], + "forearm": [0.122, 0, 0.064], + "humerus": [0.001, 0.001, 0.063], + "shoulder": [0, 0, 0], + "base": [0, 0, 0.02], + }, + "so_old_calibration": { + "gripper": [0.320, 0, 0.050], + "wrist": [0.278, 0, 0.050], + "forearm": [0.143, 0, 0.044], + "humerus": [0.031, 0, 0.072], + "shoulder": [0, 0, 0], + "base": [0, 0, 0.02], + }, + "so_new_calibration": { + "gripper": [0.33, 0.0, 0.285], + "wrist": [0.30, 0.0, 0.267], + "forearm": [0.25, 0.0, 0.266], + "humerus": [0.06, 0.0, 0.264], + "shoulder": [0.0, 0.0, 0.238], + "base": [0.0, 0.0, 0.12], + }, + } + + def __init__(self, robot_type: str = "so100"): + """Initialize kinematics for the specified robot type. + + Args: + robot_type: String specifying the robot model ("koch", "so100", or "moss") + """ + if robot_type not in self.ROBOT_MEASUREMENTS: + raise ValueError( + f"Unknown robot type: {robot_type}. Available types: {list(self.ROBOT_MEASUREMENTS.keys())}" + ) + + self.robot_type = robot_type + self.measurements = self.ROBOT_MEASUREMENTS[robot_type] + + # Initialize all transformation matrices and screw axes + self._setup_transforms() + + def _create_translation_matrix( + self, x: float = 0.0, y: float = 0.0, z: float = 0.0 + ) -> NDArray[np.float32]: + """Create a 4x4 translation matrix.""" + return np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]]) + + def _setup_transforms(self): + """Setup all transformation matrices and screw axes for the robot.""" + # Set up rotation matrices (constant across robot types) + + # Gripper orientation + self.gripper_X0 = np.array( + [ + [1, 0, 0, 0], + [0, 0, 1, 0], + [0, -1, 0, 0], + [0, 0, 0, 1], + ], + dtype=np.float32, + ) + + # Wrist orientation + self.wrist_X0 = np.array( + [ + [0, -1, 0, 0], + [1, 0, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1], + ], + dtype=np.float32, + ) + + # Base orientation + self.base_X0 = np.array( + [ + [0, 0, 1, 0], + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 0, 1], + ], + dtype=np.float32, + ) + + # Gripper + # Screw axis of gripper frame wrt base frame + self.S_BG = np.array( + [ + 1, + 0, + 0, + 0, + self.measurements["gripper"][2], + -self.measurements["gripper"][1], + ], + dtype=np.float32, + ) + + # Gripper origin to centroid transform + self.X_GoGc = self._create_translation_matrix(x=0.07) + + # Gripper origin to tip transform + self.X_GoGt = self._create_translation_matrix(x=0.12) + + # 0-position gripper frame pose wrt base + self.X_BoGo = self._create_translation_matrix( + x=self.measurements["gripper"][0], + y=self.measurements["gripper"][1], + z=self.measurements["gripper"][2], + ) + + # Wrist + # Screw axis of wrist frame wrt base frame + self.S_BR = np.array( + [0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]], dtype=np.float32 + ) + + # 0-position origin to centroid transform + self.X_RoRc = self._create_translation_matrix(x=0.0035, y=-0.002) + + # 0-position wrist frame pose wrt base + self.X_BR = self._create_translation_matrix( + x=self.measurements["wrist"][0], + y=self.measurements["wrist"][1], + z=self.measurements["wrist"][2], + ) + + # Forearm + # Screw axis of forearm frame wrt base frame + self.S_BF = np.array( + [ + 0, + 1, + 0, + -self.measurements["forearm"][2], + 0, + self.measurements["forearm"][0], + ], + dtype=np.float32, + ) + + # Forearm origin + centroid transform + self.X_ForearmFc = self._create_translation_matrix(x=0.036) + + # 0-position forearm frame pose wrt base + self.X_BF = self._create_translation_matrix( + x=self.measurements["forearm"][0], + y=self.measurements["forearm"][1], + z=self.measurements["forearm"][2], + ) + + # Humerus + # Screw axis of humerus frame wrt base frame + self.S_BH = np.array( + [ + 0, + -1, + 0, + self.measurements["humerus"][2], + 0, + -self.measurements["humerus"][0], + ], + dtype=np.float32, + ) + + # Humerus origin to centroid transform + self.X_HoHc = self._create_translation_matrix(x=0.0475) + + # 0-position humerus frame pose wrt base + self.X_BH = self._create_translation_matrix( + x=self.measurements["humerus"][0], + y=self.measurements["humerus"][1], + z=self.measurements["humerus"][2], + ) + + # Shoulder + # Screw axis of shoulder frame wrt Base frame + self.S_BS = np.array([0, 0, -1, 0, 0, 0], dtype=np.float32) + + # Shoulder origin to centroid transform + self.X_SoSc = self._create_translation_matrix(x=-0.017, z=0.0235) + + # 0-position shoulder frame pose wrt base + self.X_BS = self._create_translation_matrix( + x=self.measurements["shoulder"][0], + y=self.measurements["shoulder"][1], + z=self.measurements["shoulder"][2], + ) + + # Base + # Base origin to centroid transform + self.X_BoBc = self._create_translation_matrix(y=0.015) + + # World to base transform + self.X_WoBo = self._create_translation_matrix( + x=self.measurements["base"][0], + y=self.measurements["base"][1], + z=self.measurements["base"][2], + ) + + # Pre-compute gripper post-multiplication matrix + self._fk_gripper_post = self.X_GoGc @ self.X_BoGo @ self.gripper_X0 + + def forward_kinematics( + self, + robot_pos_deg: NDArray[np.float32], + frame: str = "gripper_tip", + ) -> NDArray[np.float32]: + """Generic forward kinematics. + + Args: + robot_pos_deg: Joint positions in degrees. Can be ``None`` when + computing the *base* frame as it does not depend on joint + angles. + frame: Target frame. One of + ``{"base", "shoulder", "humerus", "forearm", "wrist", "gripper", "gripper_tip"}``. + + Returns + ------- + NDArray[np.float32] + 4×4 homogeneous transformation matrix of the requested frame + expressed in the world coordinate system. + """ + frame = frame.lower() + if frame not in { + "base", + "shoulder", + "humerus", + "forearm", + "wrist", + "gripper", + "gripper_tip", + }: + raise ValueError( + f"Unknown frame '{frame}'. Valid options are base, shoulder, humerus, forearm, wrist, gripper, gripper_tip." + ) + + # Base frame does not rely on joint angles. + if frame == "base": + return self.X_WoBo @ self.X_BoBc @ self.base_X0 + + robot_pos_rad = robot_pos_deg / 180 * np.pi + + # Extract joint angles (note the sign convention for shoulder lift). + theta_shoulder_pan = robot_pos_rad[0] + theta_shoulder_lift = -robot_pos_rad[1] + theta_elbow_flex = robot_pos_rad[2] + theta_wrist_flex = robot_pos_rad[3] + theta_wrist_roll = robot_pos_rad[4] + + # Start with the world-to-base transform; incrementally add successive links. + transformation_matrix = self.X_WoBo @ screw_axis_to_transform(self.S_BS, theta_shoulder_pan) + if frame == "shoulder": + return transformation_matrix @ self.X_SoSc @ self.X_BS + + transformation_matrix = transformation_matrix @ screw_axis_to_transform( + self.S_BH, theta_shoulder_lift + ) + if frame == "humerus": + return transformation_matrix @ self.X_HoHc @ self.X_BH + + transformation_matrix = transformation_matrix @ screw_axis_to_transform(self.S_BF, theta_elbow_flex) + if frame == "forearm": + return transformation_matrix @ self.X_ForearmFc @ self.X_BF + + transformation_matrix = transformation_matrix @ screw_axis_to_transform(self.S_BR, theta_wrist_flex) + if frame == "wrist": + return transformation_matrix @ self.X_RoRc @ self.X_BR @ self.wrist_X0 + + transformation_matrix = transformation_matrix @ screw_axis_to_transform(self.S_BG, theta_wrist_roll) + if frame == "gripper": + return transformation_matrix @ self._fk_gripper_post + else: # frame == "gripper_tip" + return transformation_matrix @ self.X_GoGt @ self.X_BoGo @ self.gripper_X0 + + def compute_jacobian( + self, robot_pos_deg: NDArray[np.float32], frame: str = "gripper_tip" + ) -> NDArray[np.float32]: + """Finite differences to compute the Jacobian. + J(i, j) represents how the ith component of the end-effector's velocity changes wrt a small change + in the jth joint's velocity. + + Args: + robot_pos_deg: Current joint positions in degrees + fk_func: Forward kinematics function to use (defaults to fk_gripper) + """ + + eps = 1e-8 + jac = np.zeros(shape=(6, 5)) + delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64) + for el_ix in range(len(robot_pos_deg[:-1])): + delta *= 0 + delta[el_ix] = eps / 2 + sdot = ( + pose_difference_se3( + self.forward_kinematics(robot_pos_deg[:-1] + delta, frame), + self.forward_kinematics(robot_pos_deg[:-1] - delta, frame), + ) + / eps + ) + jac[:, el_ix] = sdot + return jac + + def compute_positional_jacobian( + self, robot_pos_deg: NDArray[np.float32], frame: str = "gripper_tip" + ) -> NDArray[np.float32]: + """Finite differences to compute the positional Jacobian. + J(i, j) represents how the ith component of the end-effector's position changes wrt a small change + in the jth joint's velocity. + + Args: + robot_pos_deg: Current joint positions in degrees + fk_func: Forward kinematics function to use (defaults to fk_gripper) + """ + eps = 1e-8 + jac = np.zeros(shape=(3, 5)) + delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64) + for el_ix in range(len(robot_pos_deg[:-1])): + delta *= 0 + delta[el_ix] = eps / 2 + sdot = ( + self.forward_kinematics(robot_pos_deg[:-1] + delta, frame)[:3, 3] + - self.forward_kinematics(robot_pos_deg[:-1] - delta, frame)[:3, 3] + ) / eps + jac[:, el_ix] = sdot + return jac + + def ik( + self, + current_joint_pos: NDArray[np.float32], + desired_ee_pose: NDArray[np.float32], + position_only: bool = True, + frame: str = "gripper_tip", + max_iterations: int = 5, + learning_rate: float = 1, + ) -> NDArray[np.float32]: + """Inverse kinematics using gradient descent. + + Args: + current_joint_state: Initial joint positions in degrees + desired_ee_pose: Target end-effector pose as a 4x4 transformation matrix + position_only: If True, only match end-effector position, not orientation + frame: Target frame. One of + ``{"base", "shoulder", "humerus", "forearm", "wrist", "gripper", "gripper_tip"}``. + max_iterations: Maximum number of iterations to run + learning_rate: Learning rate for gradient descent + + Returns: + Joint positions in degrees that achieve the desired end-effector pose + """ + # Do gradient descent. + current_joint_state = current_joint_pos.copy() + for _ in range(max_iterations): + current_ee_pose = self.forward_kinematics(current_joint_state, frame) + if not position_only: + error = se3_error(desired_ee_pose, current_ee_pose) + jac = self.compute_jacobian(current_joint_state, frame) + else: + error = desired_ee_pose[:3, 3] - current_ee_pose[:3, 3] + jac = self.compute_positional_jacobian(current_joint_state, frame) + delta_angles = np.linalg.pinv(jac) @ error + current_joint_state[:-1] += learning_rate * delta_angles + + if np.linalg.norm(error) < 5e-3: + return current_joint_state + return current_joint_state diff --git a/lerobot/common/optim/optimizers.py b/lerobot/common/optim/optimizers.py index 0cf4124ce6..903434f593 100644 --- a/lerobot/common/optim/optimizers.py +++ b/lerobot/common/optim/optimizers.py @@ -14,8 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc -from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, field from pathlib import Path +from typing import Any import draccus import torch @@ -44,7 +45,16 @@ def default_choice_name(cls) -> str | None: return "adam" @abc.abstractmethod - def build(self) -> torch.optim.Optimizer: + def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]: + """ + Build the optimizer. It can be a single optimizer or a dictionary of optimizers. + NOTE: Multiple optimizers are useful when you have different models to optimize. + For example, you can have one optimizer for the policy and another one for the value function + in reinforcement learning settings. + + Returns: + The optimizer or a dictionary of optimizers. + """ raise NotImplementedError @@ -94,7 +104,76 @@ def build(self, params: dict) -> torch.optim.Optimizer: return torch.optim.SGD(params, **kwargs) -def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None: +@OptimizerConfig.register_subclass("multi_adam") +@dataclass +class MultiAdamConfig(OptimizerConfig): + """Configuration for multiple Adam optimizers with different parameter groups. + + This creates a dictionary of Adam optimizers, each with its own hyperparameters. + + Args: + lr: Default learning rate (used if not specified for a group) + weight_decay: Default weight decay (used if not specified for a group) + optimizer_groups: Dictionary mapping parameter group names to their hyperparameters + grad_clip_norm: Gradient clipping norm + """ + + lr: float = 1e-3 + weight_decay: float = 0.0 + grad_clip_norm: float = 10.0 + optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict) + + def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]: + """Build multiple Adam optimizers. + + Args: + params_dict: Dictionary mapping parameter group names to lists of parameters + The keys should match the keys in optimizer_groups + + Returns: + Dictionary mapping parameter group names to their optimizers + """ + optimizers = {} + + for name, params in params_dict.items(): + # Get group-specific hyperparameters or use defaults + group_config = self.optimizer_groups.get(name, {}) + + # Create optimizer with merged parameters (defaults + group-specific) + optimizer_kwargs = { + "lr": group_config.get("lr", self.lr), + "betas": group_config.get("betas", (0.9, 0.999)), + "eps": group_config.get("eps", 1e-5), + "weight_decay": group_config.get("weight_decay", self.weight_decay), + } + + optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs) + + return optimizers + + +def save_optimizer_state( + optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path +) -> None: + """Save optimizer state to disk. + + Args: + optimizer: Either a single optimizer or a dictionary of optimizers. + save_dir: Directory to save the optimizer state. + """ + if isinstance(optimizer, dict): + # Handle dictionary of optimizers + for name, opt in optimizer.items(): + optimizer_dir = save_dir / name + optimizer_dir.mkdir(exist_ok=True, parents=True) + _save_single_optimizer_state(opt, optimizer_dir) + else: + # Handle single optimizer + _save_single_optimizer_state(optimizer, save_dir) + + +def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None: + """Save a single optimizer's state to disk.""" state = optimizer.state_dict() param_groups = state.pop("param_groups") flat_state = flatten_dict(state) @@ -102,11 +181,44 @@ def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> No write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS) -def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer: +def load_optimizer_state( + optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path +) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]: + """Load optimizer state from disk. + + Args: + optimizer: Either a single optimizer or a dictionary of optimizers. + save_dir: Directory to load the optimizer state from. + + Returns: + The updated optimizer(s) with loaded state. + """ + if isinstance(optimizer, dict): + # Handle dictionary of optimizers + loaded_optimizers = {} + for name, opt in optimizer.items(): + optimizer_dir = save_dir / name + if optimizer_dir.exists(): + loaded_optimizers[name] = _load_single_optimizer_state(opt, optimizer_dir) + else: + loaded_optimizers[name] = opt + return loaded_optimizers + else: + # Handle single optimizer + return _load_single_optimizer_state(optimizer, save_dir) + + +def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer: + """Load a single optimizer's state from disk.""" current_state_dict = optimizer.state_dict() flat_state = load_file(save_dir / OPTIMIZER_STATE) state = unflatten_dict(flat_state) - loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}} + + # Handle case where 'state' key might not exist (for newly created optimizers) + if "state" in state: + loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}} + else: + loaded_state_dict = {"state": {}} if "param_groups" in current_state_dict: param_groups = deserialize_json_into_object( diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 72d4df03a2..e7e74bf380 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -15,7 +15,7 @@ # limitations under the License. """Action Chunking Transformer Policy -As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705). +As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://huggingface.co/papers/2304.13705). The majority of changes here involve removing unused code, unifying naming, and adding helpful comments. """ @@ -41,7 +41,7 @@ class ACTPolicy(PreTrainedPolicy): """ Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost - Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act) + Hardware (paper: https://huggingface.co/papers/2304.13705, code: https://github.com/tonyzhaozh/act) """ config_class = ACTConfig @@ -161,7 +161,7 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for # each dimension independently, we sum over the latent dimension to get the total # KL-divergence per batch element, then take the mean over the batch. - # (See App. B of https://arxiv.org/abs/1312.6114 for more details). + # (See App. B of https://huggingface.co/papers/1312.6114 for more details). mean_kld = ( (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() ) @@ -175,7 +175,7 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: class ACTTemporalEnsembler: def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None: - """Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705. + """Temporal ensembling as described in Algorithm 2 of https://huggingface.co/papers/2304.13705. The weights are calculated as wᵢ = exp(-temporal_ensemble_coeff * i) where w₀ is the oldest action. They are then normalized to sum to 1 by dividing by Σwᵢ. Here's some intuition around how the diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index e73c65fe9a..c8841f06b9 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -81,7 +81,7 @@ class DiffusionConfig(PreTrainedConfig): n_groups: Number of groups used in the group norm of the Unet's convolutional blocks. diffusion_step_embed_dim: The Unet is conditioned on the diffusion timestep via a small non-linear network. This is the output dimension of that network, i.e., the embedding dimension. - use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning. + use_film_scale_modulation: FiLM (https://huggingface.co/papers/1709.07871) is used for the Unet conditioning. Bias modulation is used be default, while this parameter indicates whether to also use scale modulation. noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"]. diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 3edaf852bc..446e2cb6ef 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -48,7 +48,7 @@ class DiffusionPolicy(PreTrainedPolicy): """ Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" - (paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy). + (paper: https://huggingface.co/papers/2303.04137, code: https://github.com/real-stanford/diffusion_policy). """ config_class = DiffusionConfig @@ -370,7 +370,7 @@ def compute_loss(self, batch: dict[str, Tensor]) -> Tensor: class SpatialSoftmax(nn.Module): """ Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al. - (https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation. + (https://huggingface.co/papers/1509.06113). A minimal port of the robomimic implementation. At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass" of activations of each channel, i.e., keypoints in the image space for the policy to focus on. @@ -728,7 +728,7 @@ def __init__( self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups) - # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale. + # FiLM modulation (https://huggingface.co/papers/1709.07871) outputs per-channel bias and (maybe) scale. cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels)) diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 3aade06656..682bb8cee9 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -27,6 +27,8 @@ from lerobot.common.policies.pi0.configuration_pi0 import PI0Config from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.policies.sac.configuration_sac import SACConfig +from lerobot.common.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig @@ -60,6 +62,14 @@ def get_policy_class(name: str) -> PreTrainedPolicy: from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy return PI0FASTPolicy + elif name == "sac": + from lerobot.common.policies.sac.modeling_sac import SACPolicy + + return SACPolicy + elif name == "reward_classifier": + from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier + + return Classifier elif name == "smolvla": from lerobot.common.policies.smolvla.modeling_smolvla import SmolVLAPolicy @@ -81,8 +91,12 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return PI0Config(**kwargs) elif policy_type == "pi0fast": return PI0FASTConfig(**kwargs) + elif policy_type == "sac": + return SACConfig(**kwargs) elif policy_type == "smolvla": return SmolVLAConfig(**kwargs) + elif policy_type == "reward_classifier": + return RewardClassifierConfig(**kwargs) else: raise ValueError(f"Policy type '{policy_type}' is not available.") diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index b3255ec106..9cc94b9298 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -151,6 +151,7 @@ def __init__( # TODO(rcadene): should we remove torch.no_grad? @torch.no_grad def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + # TODO: Remove this shallow copy batch = dict(batch) # shallow copy avoids mutating the input batch for key, ft in self.features.items(): if key not in batch: @@ -252,3 +253,168 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: else: raise ValueError(norm_mode) return batch + + +# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization +# and remove the `Normalize` and `Unnormalize` classes. +def _initialize_stats_buffers( + module: nn.Module, + features: dict[str, PolicyFeature], + norm_map: dict[str, NormalizationMode], + stats: dict[str, dict[str, Tensor]] | None = None, +) -> None: + """Register statistics buffers (mean/std or min/max) on the given *module*. + + The logic matches the previous constructors of `NormalizeBuffer` and `UnnormalizeBuffer`, + but is factored out so it can be reused by both classes and stay in sync. + """ + for key, ft in features.items(): + norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY) + if norm_mode is NormalizationMode.IDENTITY: + continue + + shape: tuple[int, ...] = tuple(ft.shape) + if ft.type is FeatureType.VISUAL: + # reduce spatial dimensions, keep channel dimension only + c, *_ = shape + shape = (c, 1, 1) + + prefix = key.replace(".", "_") + + if norm_mode is NormalizationMode.MEAN_STD: + mean = torch.full(shape, torch.inf, dtype=torch.float32) + std = torch.full(shape, torch.inf, dtype=torch.float32) + + if stats and key in stats and "mean" in stats[key] and "std" in stats[key]: + mean_data = stats[key]["mean"] + std_data = stats[key]["std"] + if isinstance(mean_data, torch.Tensor): + # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated + # tensors anywhere (for example, when we use the same stats for normalization and + # unnormalization). See the logic here + # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. + mean = mean_data.clone().to(dtype=torch.float32) + std = std_data.clone().to(dtype=torch.float32) + else: + raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).") + + module.register_buffer(f"{prefix}_mean", mean) + module.register_buffer(f"{prefix}_std", std) + continue + + if norm_mode is NormalizationMode.MIN_MAX: + min_val = torch.full(shape, torch.inf, dtype=torch.float32) + max_val = torch.full(shape, torch.inf, dtype=torch.float32) + + if stats and key in stats and "min" in stats[key] and "max" in stats[key]: + min_data = stats[key]["min"] + max_data = stats[key]["max"] + if isinstance(min_data, torch.Tensor): + min_val = min_data.clone().to(dtype=torch.float32) + max_val = max_data.clone().to(dtype=torch.float32) + else: + raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).") + + module.register_buffer(f"{prefix}_min", min_val) + module.register_buffer(f"{prefix}_max", max_val) + continue + + raise ValueError(norm_mode) + + +class NormalizeBuffer(nn.Module): + """Same as `Normalize` but statistics are stored as registered buffers rather than parameters.""" + + def __init__( + self, + features: dict[str, PolicyFeature], + norm_map: dict[str, NormalizationMode], + stats: dict[str, dict[str, Tensor]] | None = None, + ): + super().__init__() + self.features = features + self.norm_map = norm_map + + _initialize_stats_buffers(self, features, norm_map, stats) + + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + batch = dict(batch) + for key, ft in self.features.items(): + if key not in batch: + continue + + norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) + if norm_mode is NormalizationMode.IDENTITY: + continue + + prefix = key.replace(".", "_") + + if norm_mode is NormalizationMode.MEAN_STD: + mean = getattr(self, f"{prefix}_mean") + std = getattr(self, f"{prefix}_std") + assert not torch.isinf(mean).any(), _no_stats_error_str("mean") + assert not torch.isinf(std).any(), _no_stats_error_str("std") + batch[key] = (batch[key] - mean) / (std + 1e-8) + continue + + if norm_mode is NormalizationMode.MIN_MAX: + min_val = getattr(self, f"{prefix}_min") + max_val = getattr(self, f"{prefix}_max") + assert not torch.isinf(min_val).any(), _no_stats_error_str("min") + assert not torch.isinf(max_val).any(), _no_stats_error_str("max") + batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8) + batch[key] = batch[key] * 2 - 1 + continue + + raise ValueError(norm_mode) + + return batch + + +class UnnormalizeBuffer(nn.Module): + """Inverse operation of `NormalizeBuffer`. Uses registered buffers for statistics.""" + + def __init__( + self, + features: dict[str, PolicyFeature], + norm_map: dict[str, NormalizationMode], + stats: dict[str, dict[str, Tensor]] | None = None, + ): + super().__init__() + self.features = features + self.norm_map = norm_map + + _initialize_stats_buffers(self, features, norm_map, stats) + + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + # batch = dict(batch) + for key, ft in self.features.items(): + if key not in batch: + continue + + norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) + if norm_mode is NormalizationMode.IDENTITY: + continue + + prefix = key.replace(".", "_") + + if norm_mode is NormalizationMode.MEAN_STD: + mean = getattr(self, f"{prefix}_mean") + std = getattr(self, f"{prefix}_std") + assert not torch.isinf(mean).any(), _no_stats_error_str("mean") + assert not torch.isinf(std).any(), _no_stats_error_str("std") + batch[key] = batch[key] * std + mean + continue + + if norm_mode is NormalizationMode.MIN_MAX: + min_val = getattr(self, f"{prefix}_min") + max_val = getattr(self, f"{prefix}_max") + assert not torch.isinf(min_val).any(), _no_stats_error_str("min") + assert not torch.isinf(max_val).any(), _no_stats_error_str("max") + batch[key] = (batch[key] + 1) / 2 + batch[key] = batch[key] * (max_val - min_val) + min_val + continue + + raise ValueError(norm_mode) + + return batch diff --git a/lerobot/common/policies/pi0/paligemma_with_expert.py b/lerobot/common/policies/pi0/paligemma_with_expert.py index 76e2ce6005..49c844c7bf 100644 --- a/lerobot/common/policies/pi0/paligemma_with_expert.py +++ b/lerobot/common/policies/pi0/paligemma_with_expert.py @@ -216,7 +216,11 @@ def to_bfloat16_like_physical_intelligence(self): param.data = param.data.to(dtype=torch.bfloat16) def embed_image(self, image: torch.Tensor): - return self.paligemma.get_image_features(image) + # Handle different transformers versions + if hasattr(self.paligemma, "get_image_features"): + return self.paligemma.get_image_features(image) + else: + return self.paligemma.model.get_image_features(image) def embed_language_tokens(self, tokens: torch.Tensor): return self.paligemma.language_model.model.embed_tokens(tokens) diff --git a/lerobot/common/policies/pi0fast/modeling_pi0fast.py b/lerobot/common/policies/pi0fast/modeling_pi0fast.py index 4996b1a083..7102bdded5 100644 --- a/lerobot/common/policies/pi0fast/modeling_pi0fast.py +++ b/lerobot/common/policies/pi0fast/modeling_pi0fast.py @@ -17,7 +17,7 @@ """ π0+FAST: Efficient Action Tokenization for Vision-Language-Action Models -[Paper](https://arxiv.org/abs/2501.09747) +[Paper](https://huggingface.co/papers/2501.09747) [Jax code](https://github.com/Physical-Intelligence/openpi) Designed by Physical Intelligence. Ported from Jax by Hugging Face. @@ -878,7 +878,11 @@ def generate_actions(self, batch: dict[str, Tensor]): return actions def embed_image(self, image: torch.Tensor): - return self.pi0_paligemma.get_image_features(image) + # Handle different transformers versions + if hasattr(self.pi0_paligemma, "get_image_features"): + return self.pi0_paligemma.get_image_features(image) + else: + return self.pi0_paligemma.model.get_image_features(image) def embed_inputs( self, diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py new file mode 100644 index 0000000000..db58beb2f0 --- /dev/null +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -0,0 +1,245 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_STATE +from lerobot.common.optim.optimizers import MultiAdamConfig +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import NormalizationMode + + +def is_image_feature(key: str) -> bool: + """Check if a feature key represents an image feature. + + Args: + key: The feature key to check + + Returns: + True if the key represents an image feature, False otherwise + """ + return key.startswith(OBS_IMAGE) + + +@dataclass +class ConcurrencyConfig: + """Configuration for the concurrency of the actor and learner. + Possible values are: + - "threads": Use threads for the actor and learner. + - "processes": Use processes for the actor and learner. + """ + + actor: str = "threads" + learner: str = "threads" + + +@dataclass +class ActorLearnerConfig: + learner_host: str = "127.0.0.1" + learner_port: int = 50051 + policy_parameters_push_frequency: int = 4 + queue_get_timeout: float = 2 + + +@dataclass +class CriticNetworkConfig: + hidden_dims: list[int] = field(default_factory=lambda: [256, 256]) + activate_final: bool = True + final_activation: str | None = None + + +@dataclass +class ActorNetworkConfig: + hidden_dims: list[int] = field(default_factory=lambda: [256, 256]) + activate_final: bool = True + + +@dataclass +class PolicyConfig: + use_tanh_squash: bool = True + std_min: float = 1e-5 + std_max: float = 10.0 + init_final: float = 0.05 + + +@PreTrainedConfig.register_subclass("sac") +@dataclass +class SACConfig(PreTrainedConfig): + """Soft Actor-Critic (SAC) configuration. + + SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy + reinforcement learning framework. It learns a policy and a Q-function simultaneously + using experience collected from the environment. + + This configuration class contains all the parameters needed to define a SAC agent, + including network architectures, optimization settings, and algorithm-specific + hyperparameters. + """ + + # Mapping of feature types to normalization modes + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.MEAN_STD, + "STATE": NormalizationMode.MIN_MAX, + "ENV": NormalizationMode.MIN_MAX, + "ACTION": NormalizationMode.MIN_MAX, + } + ) + + # Statistics for normalizing different types of inputs + dataset_stats: dict[str, dict[str, list[float]]] | None = field( + default_factory=lambda: { + OBS_IMAGE: { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + }, + OBS_STATE: { + "min": [0.0, 0.0], + "max": [1.0, 1.0], + }, + ACTION: { + "min": [0.0, 0.0, 0.0], + "max": [1.0, 1.0, 1.0], + }, + } + ) + + # Architecture specifics + # Device to run the model on (e.g., "cuda", "cpu") + device: str = "cpu" + # Device to store the model on + storage_device: str = "cpu" + # Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10) + vision_encoder_name: str | None = None + # Whether to freeze the vision encoder during training + freeze_vision_encoder: bool = True + # Hidden dimension size for the image encoder + image_encoder_hidden_dim: int = 32 + # Whether to use a shared encoder for actor and critic + shared_encoder: bool = True + # Number of discrete actions, eg for gripper actions + num_discrete_actions: int | None = None + # Dimension of the image embedding pooling + image_embedding_pooling_dim: int = 8 + + # Training parameter + # Number of steps for online training + online_steps: int = 1000000 + # Seed for the online environment + online_env_seed: int = 10000 + # Capacity of the online replay buffer + online_buffer_capacity: int = 100000 + # Capacity of the offline replay buffer + offline_buffer_capacity: int = 100000 + # Whether to use asynchronous prefetching for the buffers + async_prefetch: bool = False + # Number of steps before learning starts + online_step_before_learning: int = 100 + # Frequency of policy updates + policy_update_freq: int = 1 + + # SAC algorithm parameters + # Discount factor for the SAC algorithm + discount: float = 0.99 + # Initial temperature value + temperature_init: float = 1.0 + # Number of critics in the ensemble + num_critics: int = 2 + # Number of subsampled critics for training + num_subsample_critics: int | None = None + # Learning rate for the critic network + critic_lr: float = 3e-4 + # Learning rate for the actor network + actor_lr: float = 3e-4 + # Learning rate for the temperature parameter + temperature_lr: float = 3e-4 + # Weight for the critic target update + critic_target_update_weight: float = 0.005 + # Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1) + utd_ratio: int = 1 + # Hidden dimension size for the state encoder + state_encoder_hidden_dim: int = 256 + # Dimension of the latent space + latent_dim: int = 256 + # Target entropy for the SAC algorithm + target_entropy: float | None = None + # Whether to use backup entropy for the SAC algorithm + use_backup_entropy: bool = True + # Gradient clipping norm for the SAC algorithm + grad_clip_norm: float = 40.0 + + # Network configuration + # Configuration for the critic network architecture + critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) + # Configuration for the actor network architecture + actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig) + # Configuration for the policy parameters + policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig) + # Configuration for the discrete critic network + discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) + # Configuration for actor-learner architecture + actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig) + # Configuration for concurrency settings (you can use threads or processes for the actor and learner) + concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig) + + # Optimizations + use_torch_compile: bool = True + + def __post_init__(self): + super().__post_init__() + # Any validation specific to SAC configuration + + def get_optimizer_preset(self) -> MultiAdamConfig: + return MultiAdamConfig( + weight_decay=0.0, + optimizer_groups={ + "actor": {"lr": self.actor_lr}, + "critic": {"lr": self.critic_lr}, + "temperature": {"lr": self.temperature_lr}, + }, + ) + + def get_scheduler_preset(self) -> None: + return None + + def validate_features(self) -> None: + has_image = any(is_image_feature(key) for key in self.input_features) + has_state = OBS_STATE in self.input_features + + if not (has_state or has_image): + raise ValueError( + "You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features" + ) + + if "action" not in self.output_features: + raise ValueError("You must provide 'action' in the output features") + + @property + def image_features(self) -> list[str]: + return [key for key in self.input_features if is_image_feature(key)] + + @property + def observation_delta_indices(self) -> list: + return None + + @property + def action_delta_indices(self) -> list: + return None # SAC typically predicts one action at a time + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py new file mode 100644 index 0000000000..b588115ea0 --- /dev/null +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -0,0 +1,1111 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import asdict +from typing import Callable, Literal + +import einops +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F # noqa: N812 +from torch import Tensor +from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution + +from lerobot.common.policies.normalize import NormalizeBuffer +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.policies.sac.configuration_sac import SACConfig, is_image_feature +from lerobot.common.policies.utils import get_device_from_parameters + +DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension + + +class SACPolicy( + PreTrainedPolicy, +): + config_class = SACConfig + name = "sac" + + def __init__( + self, + config: SACConfig | None = None, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): + super().__init__(config) + config.validate_features() + self.config = config + + # Determine action dimension and initialize all components + continuous_action_dim = config.output_features["action"].shape[0] + self._init_normalization(dataset_stats) + self._init_encoders() + self._init_critics(continuous_action_dim) + self._init_actor(continuous_action_dim) + self._init_temperature() + + def get_optim_params(self) -> dict: + optim_params = { + "actor": [ + p + for n, p in self.actor.named_parameters() + if not n.startswith("encoder") or not self.shared_encoder + ], + "critic": self.critic_ensemble.parameters(), + "temperature": self.log_alpha, + } + if self.config.num_discrete_actions is not None: + optim_params["discrete_critic"] = self.discrete_critic.parameters() + return optim_params + + def reset(self): + """Reset the policy""" + pass + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select action for inference/evaluation""" + + observations_features = None + if self.shared_encoder and self.actor.encoder.has_images: + # Cache and normalize image features + observations_features = self.actor.encoder.get_cached_image_features(batch, normalize=True) + + actions, _, _ = self.actor(batch, observations_features) + + if self.config.num_discrete_actions is not None: + discrete_action_value = self.discrete_critic(batch, observations_features) + discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True) + actions = torch.cat([actions, discrete_action], dim=-1) + + return actions + + def critic_forward( + self, + observations: dict[str, Tensor], + actions: Tensor, + use_target: bool = False, + observation_features: Tensor | None = None, + ) -> Tensor: + """Forward pass through a critic network ensemble + + Args: + observations: Dictionary of observations + actions: Action tensor + use_target: If True, use target critics, otherwise use ensemble critics + + Returns: + Tensor of Q-values from all critics + """ + + critics = self.critic_target if use_target else self.critic_ensemble + q_values = critics(observations, actions, observation_features) + return q_values + + def discrete_critic_forward( + self, observations, use_target=False, observation_features=None + ) -> torch.Tensor: + """Forward pass through a discrete critic network + + Args: + observations: Dictionary of observations + use_target: If True, use target critics, otherwise use ensemble critics + observation_features: Optional pre-computed observation features to avoid recomputing encoder output + + Returns: + Tensor of Q-values from the discrete critic network + """ + discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic + q_values = discrete_critic(observations, observation_features) + return q_values + + def forward( + self, + batch: dict[str, Tensor | dict[str, Tensor]], + model: Literal["actor", "critic", "temperature", "discrete_critic"] = "critic", + ) -> dict[str, Tensor]: + """Compute the loss for the given model + + Args: + batch: Dictionary containing: + - action: Action tensor + - reward: Reward tensor + - state: Observations tensor dict + - next_state: Next observations tensor dict + - done: Done mask tensor + - observation_feature: Optional pre-computed observation features + - next_observation_feature: Optional pre-computed next observation features + model: Which model to compute the loss for ("actor", "critic", "discrete_critic", or "temperature") + + Returns: + The computed loss tensor + """ + # Extract common components from batch + actions: Tensor = batch["action"] + observations: dict[str, Tensor] = batch["state"] + observation_features: Tensor = batch.get("observation_feature") + + if model == "critic": + # Extract critic-specific components + rewards: Tensor = batch["reward"] + next_observations: dict[str, Tensor] = batch["next_state"] + done: Tensor = batch["done"] + next_observation_features: Tensor = batch.get("next_observation_feature") + + loss_critic = self.compute_loss_critic( + observations=observations, + actions=actions, + rewards=rewards, + next_observations=next_observations, + done=done, + observation_features=observation_features, + next_observation_features=next_observation_features, + ) + + return {"loss_critic": loss_critic} + + if model == "discrete_critic" and self.config.num_discrete_actions is not None: + # Extract critic-specific components + rewards: Tensor = batch["reward"] + next_observations: dict[str, Tensor] = batch["next_state"] + done: Tensor = batch["done"] + next_observation_features: Tensor = batch.get("next_observation_feature") + complementary_info = batch.get("complementary_info") + loss_discrete_critic = self.compute_loss_discrete_critic( + observations=observations, + actions=actions, + rewards=rewards, + next_observations=next_observations, + done=done, + observation_features=observation_features, + next_observation_features=next_observation_features, + complementary_info=complementary_info, + ) + return {"loss_discrete_critic": loss_discrete_critic} + if model == "actor": + return { + "loss_actor": self.compute_loss_actor( + observations=observations, + observation_features=observation_features, + ) + } + + if model == "temperature": + return { + "loss_temperature": self.compute_loss_temperature( + observations=observations, + observation_features=observation_features, + ) + } + + raise ValueError(f"Unknown model type: {model}") + + def update_target_networks(self): + """Update target networks with exponential moving average""" + for target_param, param in zip( + self.critic_target.parameters(), + self.critic_ensemble.parameters(), + strict=True, + ): + target_param.data.copy_( + param.data * self.config.critic_target_update_weight + + target_param.data * (1.0 - self.config.critic_target_update_weight) + ) + if self.config.num_discrete_actions is not None: + for target_param, param in zip( + self.discrete_critic_target.parameters(), + self.discrete_critic.parameters(), + strict=True, + ): + target_param.data.copy_( + param.data * self.config.critic_target_update_weight + + target_param.data * (1.0 - self.config.critic_target_update_weight) + ) + + def update_temperature(self): + self.temperature = self.log_alpha.exp().item() + + def compute_loss_critic( + self, + observations, + actions, + rewards, + next_observations, + done, + observation_features: Tensor | None = None, + next_observation_features: Tensor | None = None, + ) -> Tensor: + with torch.no_grad(): + next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features) + + # 2- compute q targets + q_targets = self.critic_forward( + observations=next_observations, + actions=next_action_preds, + use_target=True, + observation_features=next_observation_features, + ) + + # subsample critics to prevent overfitting if use high UTD (update to date) + # TODO: Get indices before forward pass to avoid unnecessary computation + if self.config.num_subsample_critics is not None: + indices = torch.randperm(self.config.num_critics) + indices = indices[: self.config.num_subsample_critics] + q_targets = q_targets[indices] + + # critics subsample size + min_q, _ = q_targets.min(dim=0) # Get values from min operation + if self.config.use_backup_entropy: + min_q = min_q - (self.temperature * next_log_probs) + + td_target = rewards + (1 - done) * self.config.discount * min_q + + # 3- compute predicted qs + if self.config.num_discrete_actions is not None: + # NOTE: We only want to keep the continuous action part + # In the buffer we have the full action space (continuous + discrete) + # We need to split them before concatenating them in the critic forward + actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX] + q_preds = self.critic_forward( + observations=observations, + actions=actions, + use_target=False, + observation_features=observation_features, + ) + + # 4- Calculate loss + # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. + td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]) + # You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up + critics_loss = ( + F.mse_loss( + input=q_preds, + target=td_target_duplicate, + reduction="none", + ).mean(dim=1) + ).sum() + return critics_loss + + def compute_loss_discrete_critic( + self, + observations, + actions, + rewards, + next_observations, + done, + observation_features=None, + next_observation_features=None, + complementary_info=None, + ): + # NOTE: We only want to keep the discrete action part + # In the buffer we have the full action space (continuous + discrete) + # We need to split them before concatenating them in the critic forward + actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone() + actions_discrete = torch.round(actions_discrete) + actions_discrete = actions_discrete.long() + + discrete_penalties: Tensor | None = None + if complementary_info is not None: + discrete_penalties: Tensor | None = complementary_info.get("discrete_penalty") + + with torch.no_grad(): + # For DQN, select actions using online network, evaluate with target network + next_discrete_qs = self.discrete_critic_forward( + next_observations, use_target=False, observation_features=next_observation_features + ) + best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True) + + # Get target Q-values from target network + target_next_discrete_qs = self.discrete_critic_forward( + observations=next_observations, + use_target=True, + observation_features=next_observation_features, + ) + + # Use gather to select Q-values for best actions + target_next_discrete_q = torch.gather( + target_next_discrete_qs, dim=1, index=best_next_discrete_action + ).squeeze(-1) + + # Compute target Q-value with Bellman equation + rewards_discrete = rewards + if discrete_penalties is not None: + rewards_discrete = rewards + discrete_penalties + target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q + + # Get predicted Q-values for current observations + predicted_discrete_qs = self.discrete_critic_forward( + observations=observations, use_target=False, observation_features=observation_features + ) + + # Use gather to select Q-values for taken actions + predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1) + + # Compute MSE loss between predicted and target Q-values + discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q) + return discrete_critic_loss + + def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor: + """Compute the temperature loss""" + # calculate temperature loss + with torch.no_grad(): + _, log_probs, _ = self.actor(observations, observation_features) + temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean() + return temperature_loss + + def compute_loss_actor( + self, + observations, + observation_features: Tensor | None = None, + ) -> Tensor: + actions_pi, log_probs, _ = self.actor(observations, observation_features) + + q_preds = self.critic_forward( + observations=observations, + actions=actions_pi, + use_target=False, + observation_features=observation_features, + ) + min_q_preds = q_preds.min(dim=0)[0] + + actor_loss = ((self.temperature * log_probs) - min_q_preds).mean() + return actor_loss + + def _init_normalization(self, dataset_stats): + """Initialize input/output normalization modules.""" + self.normalize_inputs = nn.Identity() + self.normalize_targets = nn.Identity() + if self.config.dataset_stats is not None: + params = _convert_normalization_params_to_tensor(self.config.dataset_stats) + self.normalize_inputs = NormalizeBuffer( + self.config.input_features, self.config.normalization_mapping, params + ) + stats = dataset_stats or params + self.normalize_targets = NormalizeBuffer( + self.config.output_features, self.config.normalization_mapping, stats + ) + + def _init_encoders(self): + """Initialize shared or separate encoders for actor and critic.""" + self.shared_encoder = self.config.shared_encoder + self.encoder_critic = SACObservationEncoder(self.config, self.normalize_inputs) + self.encoder_actor = ( + self.encoder_critic + if self.shared_encoder + else SACObservationEncoder(self.config, self.normalize_inputs) + ) + + def _init_critics(self, continuous_action_dim): + """Build critic ensemble, targets, and optional discrete critic.""" + heads = [ + CriticHead( + input_dim=self.encoder_critic.output_dim + continuous_action_dim, + **asdict(self.config.critic_network_kwargs), + ) + for _ in range(self.config.num_critics) + ] + self.critic_ensemble = CriticEnsemble( + encoder=self.encoder_critic, ensemble=heads, output_normalization=self.normalize_targets + ) + target_heads = [ + CriticHead( + input_dim=self.encoder_critic.output_dim + continuous_action_dim, + **asdict(self.config.critic_network_kwargs), + ) + for _ in range(self.config.num_critics) + ] + self.critic_target = CriticEnsemble( + encoder=self.encoder_critic, ensemble=target_heads, output_normalization=self.normalize_targets + ) + self.critic_target.load_state_dict(self.critic_ensemble.state_dict()) + + if self.config.use_torch_compile: + self.critic_ensemble = torch.compile(self.critic_ensemble) + self.critic_target = torch.compile(self.critic_target) + + if self.config.num_discrete_actions is not None: + self._init_discrete_critics() + + def _init_discrete_critics(self): + """Build discrete discrete critic ensemble and target networks.""" + self.discrete_critic = DiscreteCritic( + encoder=self.encoder_critic, + input_dim=self.encoder_critic.output_dim, + output_dim=self.config.num_discrete_actions, + **asdict(self.config.discrete_critic_network_kwargs), + ) + self.discrete_critic_target = DiscreteCritic( + encoder=self.encoder_critic, + input_dim=self.encoder_critic.output_dim, + output_dim=self.config.num_discrete_actions, + **asdict(self.config.discrete_critic_network_kwargs), + ) + + # TODO: (maractingi, azouitine) Compile the discrete critic + self.discrete_critic_target.load_state_dict(self.discrete_critic.state_dict()) + + def _init_actor(self, continuous_action_dim): + """Initialize policy actor network and default target entropy.""" + # NOTE: The actor select only the continuous action part + self.actor = Policy( + encoder=self.encoder_actor, + network=MLP(input_dim=self.encoder_actor.output_dim, **asdict(self.config.actor_network_kwargs)), + action_dim=continuous_action_dim, + encoder_is_shared=self.shared_encoder, + **asdict(self.config.policy_kwargs), + ) + + self.target_entropy = self.config.target_entropy + if self.target_entropy is None: + dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0) + self.target_entropy = -np.prod(dim) / 2 + + def _init_temperature(self): + """Set up temperature parameter and initial log_alpha.""" + temp_init = self.config.temperature_init + self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)])) + self.temperature = self.log_alpha.exp().item() + + +class SACObservationEncoder(nn.Module): + """Encode image and/or state vector observations.""" + + def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None: + super().__init__() + self.config = config + self.input_normalization = input_normalizer + self._init_image_layers() + self._init_state_layers() + self._compute_output_dim() + + def _init_image_layers(self) -> None: + self.image_keys = [k for k in self.config.input_features if is_image_feature(k)] + self.has_images = bool(self.image_keys) + if not self.has_images: + return + + if self.config.vision_encoder_name is not None: + self.image_encoder = PretrainedImageEncoder(self.config) + else: + self.image_encoder = DefaultImageEncoder(self.config) + + if self.config.freeze_vision_encoder: + freeze_image_encoder(self.image_encoder) + + dummy = torch.zeros(1, *self.config.input_features[self.image_keys[0]].shape) + with torch.no_grad(): + _, channels, height, width = self.image_encoder(dummy).shape + + self.spatial_embeddings = nn.ModuleDict() + self.post_encoders = nn.ModuleDict() + + for key in self.image_keys: + name = key.replace(".", "_") + self.spatial_embeddings[name] = SpatialLearnedEmbeddings( + height=height, + width=width, + channel=channels, + num_features=self.config.image_embedding_pooling_dim, + ) + self.post_encoders[name] = nn.Sequential( + nn.Dropout(0.1), + nn.Linear( + in_features=channels * self.config.image_embedding_pooling_dim, + out_features=self.config.latent_dim, + ), + nn.LayerNorm(normalized_shape=self.config.latent_dim), + nn.Tanh(), + ) + + def _init_state_layers(self) -> None: + self.has_env = "observation.environment_state" in self.config.input_features + self.has_state = "observation.state" in self.config.input_features + if self.has_env: + dim = self.config.input_features["observation.environment_state"].shape[0] + self.env_encoder = nn.Sequential( + nn.Linear(dim, self.config.latent_dim), + nn.LayerNorm(self.config.latent_dim), + nn.Tanh(), + ) + if self.has_state: + dim = self.config.input_features["observation.state"].shape[0] + self.state_encoder = nn.Sequential( + nn.Linear(dim, self.config.latent_dim), + nn.LayerNorm(self.config.latent_dim), + nn.Tanh(), + ) + + def _compute_output_dim(self) -> None: + out = 0 + if self.has_images: + out += len(self.image_keys) * self.config.latent_dim + if self.has_env: + out += self.config.latent_dim + if self.has_state: + out += self.config.latent_dim + self._out_dim = out + + def forward( + self, obs: dict[str, Tensor], cache: dict[str, Tensor] | None = None, detach: bool = False + ) -> Tensor: + obs = self.input_normalization(obs) + parts = [] + if self.has_images: + if cache is None: + cache = self.get_cached_image_features(obs, normalize=False) + parts.append(self._encode_images(cache, detach)) + if self.has_env: + parts.append(self.env_encoder(obs["observation.environment_state"])) + if self.has_state: + parts.append(self.state_encoder(obs["observation.state"])) + if parts: + return torch.cat(parts, dim=-1) + + raise ValueError( + "No parts to concatenate, you should have at least one image or environment state or state" + ) + + def get_cached_image_features(self, obs: dict[str, Tensor], normalize: bool = False) -> dict[str, Tensor]: + """Extract and optionally cache image features from observations. + + This function processes image observations through the vision encoder once and returns + the resulting features. + When the image encoder is shared between actor and critics AND frozen, these features can be safely cached and + reused across policy components (actor, critic, discrete_critic), avoiding redundant forward passes. + + Performance impact: + - The vision encoder forward pass is typically the main computational bottleneck during training and inference + - Caching these features can provide 2-4x speedup in training and inference + + Normalization behavior: + - When called from inside forward(): set normalize=False since inputs are already normalized + - When called from outside forward(): set normalize=True to ensure proper input normalization + + Usage patterns: + - Called in select_action() with normalize=True + - Called in learner.py's get_observation_features() to pre-compute features for all policy components + - Called internally by forward() with normalize=False + + Args: + obs: Dictionary of observation tensors containing image keys + normalize: Whether to normalize observations before encoding + Set to True when calling directly from outside the encoder's forward method + Set to False when calling from within forward() where inputs are already normalized + + Returns: + Dictionary mapping image keys to their corresponding encoded features + """ + if normalize: + obs = self.input_normalization(obs) + batched = torch.cat([obs[k] for k in self.image_keys], dim=0) + out = self.image_encoder(batched) + chunks = torch.chunk(out, len(self.image_keys), dim=0) + return dict(zip(self.image_keys, chunks, strict=False)) + + def _encode_images(self, cache: dict[str, Tensor], detach: bool) -> Tensor: + """Encode image features from cached observations. + + This function takes pre-encoded image features from the cache and applies spatial embeddings and post-encoders. + It also supports detaching the encoded features if specified. + + Args: + cache (dict[str, Tensor]): The cached image features. + detach (bool): Usually when the encoder is shared between actor and critics, + we want to detach the encoded features on the policy side to avoid backprop through the encoder. + More detail here `https://cdn.aaai.org/ojs/17276/17276-13-20770-1-2-20210518.pdf` + + Returns: + Tensor: The encoded image features. + """ + feats = [] + for k, feat in cache.items(): + safe_key = k.replace(".", "_") + x = self.spatial_embeddings[safe_key](feat) + x = self.post_encoders[safe_key](x) + if detach: + x = x.detach() + feats.append(x) + return torch.cat(feats, dim=-1) + + @property + def output_dim(self) -> int: + return self._out_dim + + +class MLP(nn.Module): + """Multi-layer perceptron builder. + + Dynamically constructs a sequence of layers based on `hidden_dims`: + 1) Linear (in_dim -> out_dim) + 2) Optional Dropout if `dropout_rate` > 0 and (not final layer or `activate_final`) + 3) LayerNorm on the output features + 4) Activation (standard for intermediate layers, `final_activation` for last layer if `activate_final`) + + Arguments: + input_dim (int): Size of input feature dimension. + hidden_dims (list[int]): Sizes for each hidden layer. + activations (Callable or str): Activation to apply between layers. + activate_final (bool): Whether to apply activation at the final layer. + dropout_rate (Optional[float]): Dropout probability applied before normalization and activation. + final_activation (Optional[Callable or str]): Activation for the final layer when `activate_final` is True. + + For each layer, `in_dim` is updated to the previous `out_dim`. All constructed modules are + stored in `self.net` as an `nn.Sequential` container. + """ + + def __init__( + self, + input_dim: int, + hidden_dims: list[int], + activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(), + activate_final: bool = False, + dropout_rate: float | None = None, + final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None, + ): + super().__init__() + layers: list[nn.Module] = [] + in_dim = input_dim + total = len(hidden_dims) + + for idx, out_dim in enumerate(hidden_dims): + # 1) linear transform + layers.append(nn.Linear(in_dim, out_dim)) + + is_last = idx == total - 1 + # 2-4) optionally add dropout, normalization, and activation + if not is_last or activate_final: + if dropout_rate and dropout_rate > 0: + layers.append(nn.Dropout(p=dropout_rate)) + layers.append(nn.LayerNorm(out_dim)) + act_cls = final_activation if is_last and final_activation else activations + act = act_cls if isinstance(act_cls, nn.Module) else getattr(nn, act_cls)() + layers.append(act) + + in_dim = out_dim + + self.net = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +class CriticHead(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dims: list[int], + activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(), + activate_final: bool = False, + dropout_rate: float | None = None, + init_final: float | None = None, + final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None, + ): + super().__init__() + self.net = MLP( + input_dim=input_dim, + hidden_dims=hidden_dims, + activations=activations, + activate_final=activate_final, + dropout_rate=dropout_rate, + final_activation=final_activation, + ) + self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1) + if init_final is not None: + nn.init.uniform_(self.output_layer.weight, -init_final, init_final) + nn.init.uniform_(self.output_layer.bias, -init_final, init_final) + else: + orthogonal_init()(self.output_layer.weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.output_layer(self.net(x)) + + +class CriticEnsemble(nn.Module): + """ + CriticEnsemble wraps multiple CriticHead modules into an ensemble. + + Args: + encoder (SACObservationEncoder): encoder for observations. + ensemble (List[CriticHead]): list of critic heads. + output_normalization (nn.Module): normalization layer for actions. + init_final (float | None): optional initializer scale for final layers. + + Forward returns a tensor of shape (num_critics, batch_size) containing Q-values. + """ + + def __init__( + self, + encoder: SACObservationEncoder, + ensemble: list[CriticHead], + output_normalization: nn.Module, + init_final: float | None = None, + ): + super().__init__() + self.encoder = encoder + self.init_final = init_final + self.output_normalization = output_normalization + self.critics = nn.ModuleList(ensemble) + + def forward( + self, + observations: dict[str, torch.Tensor], + actions: torch.Tensor, + observation_features: torch.Tensor | None = None, + ) -> torch.Tensor: + device = get_device_from_parameters(self) + # Move each tensor in observations to device + observations = {k: v.to(device) for k, v in observations.items()} + # NOTE: We normalize actions it helps for sample efficiency + actions: dict[str, torch.tensor] = {"action": actions} + # NOTE: Normalization layer took dict in input and outputs a dict that why + actions = self.output_normalization(actions)["action"] + actions = actions.to(device) + + obs_enc = self.encoder(observations, cache=observation_features) + + inputs = torch.cat([obs_enc, actions], dim=-1) + + # Loop through critics and collect outputs + q_values = [] + for critic in self.critics: + q_values.append(critic(inputs)) + + # Stack outputs to match expected shape [num_critics, batch_size] + q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0) + return q_values + + +class DiscreteCritic(nn.Module): + def __init__( + self, + encoder: nn.Module, + input_dim: int, + hidden_dims: list[int], + output_dim: int = 3, + activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(), + activate_final: bool = False, + dropout_rate: float | None = None, + init_final: float | None = None, + final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None, + ): + super().__init__() + self.encoder = encoder + self.output_dim = output_dim + + self.net = MLP( + input_dim=input_dim, + hidden_dims=hidden_dims, + activations=activations, + activate_final=activate_final, + dropout_rate=dropout_rate, + final_activation=final_activation, + ) + + self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=self.output_dim) + if init_final is not None: + nn.init.uniform_(self.output_layer.weight, -init_final, init_final) + nn.init.uniform_(self.output_layer.bias, -init_final, init_final) + else: + orthogonal_init()(self.output_layer.weight) + + def forward( + self, observations: torch.Tensor, observation_features: torch.Tensor | None = None + ) -> torch.Tensor: + device = get_device_from_parameters(self) + observations = {k: v.to(device) for k, v in observations.items()} + obs_enc = self.encoder(observations, cache=observation_features) + return self.output_layer(self.net(obs_enc)) + + +class Policy(nn.Module): + def __init__( + self, + encoder: SACObservationEncoder, + network: nn.Module, + action_dim: int, + std_min: float = -5, + std_max: float = 2, + fixed_std: torch.Tensor | None = None, + init_final: float | None = None, + use_tanh_squash: bool = False, + encoder_is_shared: bool = False, + ): + super().__init__() + self.encoder: SACObservationEncoder = encoder + self.network = network + self.action_dim = action_dim + self.std_min = std_min + self.std_max = std_max + self.fixed_std = fixed_std + self.use_tanh_squash = use_tanh_squash + self.encoder_is_shared = encoder_is_shared + + # Find the last Linear layer's output dimension + for layer in reversed(network.net): + if isinstance(layer, nn.Linear): + out_features = layer.out_features + break + # Mean layer + self.mean_layer = nn.Linear(out_features, action_dim) + if init_final is not None: + nn.init.uniform_(self.mean_layer.weight, -init_final, init_final) + nn.init.uniform_(self.mean_layer.bias, -init_final, init_final) + else: + orthogonal_init()(self.mean_layer.weight) + + # Standard deviation layer or parameter + if fixed_std is None: + self.std_layer = nn.Linear(out_features, action_dim) + if init_final is not None: + nn.init.uniform_(self.std_layer.weight, -init_final, init_final) + nn.init.uniform_(self.std_layer.bias, -init_final, init_final) + else: + orthogonal_init()(self.std_layer.weight) + + def forward( + self, + observations: torch.Tensor, + observation_features: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # We detach the encoder if it is shared to avoid backprop through it + # This is important to avoid the encoder to be updated through the policy + obs_enc = self.encoder(observations, cache=observation_features, detach=self.encoder_is_shared) + + # Get network outputs + outputs = self.network(obs_enc) + means = self.mean_layer(outputs) + + # Compute standard deviations + if self.fixed_std is None: + log_std = self.std_layer(outputs) + std = torch.exp(log_std) # Match JAX "exp" + std = torch.clamp(std, self.std_min, self.std_max) # Match JAX default clip + else: + std = self.fixed_std.expand_as(means) + + # Build transformed distribution + dist = TanhMultivariateNormalDiag(loc=means, scale_diag=std) + + # Sample actions (reparameterized) + actions = dist.rsample() + + # Compute log_probs + log_probs = dist.log_prob(actions) + + return actions, log_probs, means + + def get_features(self, observations: torch.Tensor) -> torch.Tensor: + """Get encoded features from observations""" + device = get_device_from_parameters(self) + observations = observations.to(device) + if self.encoder is not None: + with torch.inference_mode(): + return self.encoder(observations) + return observations + + +class DefaultImageEncoder(nn.Module): + def __init__(self, config: SACConfig): + super().__init__() + image_key = next(key for key in config.input_features if is_image_feature(key)) + self.image_enc_layers = nn.Sequential( + nn.Conv2d( + in_channels=config.input_features[image_key].shape[0], + out_channels=config.image_encoder_hidden_dim, + kernel_size=7, + stride=2, + ), + nn.ReLU(), + nn.Conv2d( + in_channels=config.image_encoder_hidden_dim, + out_channels=config.image_encoder_hidden_dim, + kernel_size=5, + stride=2, + ), + nn.ReLU(), + nn.Conv2d( + in_channels=config.image_encoder_hidden_dim, + out_channels=config.image_encoder_hidden_dim, + kernel_size=3, + stride=2, + ), + nn.ReLU(), + nn.Conv2d( + in_channels=config.image_encoder_hidden_dim, + out_channels=config.image_encoder_hidden_dim, + kernel_size=3, + stride=2, + ), + nn.ReLU(), + ) + + def forward(self, x): + x = self.image_enc_layers(x) + return x + + +def freeze_image_encoder(image_encoder: nn.Module): + """Freeze all parameters in the encoder""" + for param in image_encoder.parameters(): + param.requires_grad = False + + +class PretrainedImageEncoder(nn.Module): + def __init__(self, config: SACConfig): + super().__init__() + + self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config) + + def _load_pretrained_vision_encoder(self, config: SACConfig): + """Set up CNN encoder""" + from transformers import AutoModel + + self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True) + + if hasattr(self.image_enc_layers.config, "hidden_sizes"): + self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension + elif hasattr(self.image_enc_layers, "fc"): + self.image_enc_out_shape = self.image_enc_layers.fc.in_features + else: + raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN") + return self.image_enc_layers, self.image_enc_out_shape + + def forward(self, x): + enc_feat = self.image_enc_layers(x).last_hidden_state + return enc_feat + + +def orthogonal_init(): + return lambda x: torch.nn.init.orthogonal_(x, gain=1.0) + + +class SpatialLearnedEmbeddings(nn.Module): + def __init__(self, height, width, channel, num_features=8): + """ + PyTorch implementation of learned spatial embeddings + + Args: + height: Spatial height of input features + width: Spatial width of input features + channel: Number of input channels + num_features: Number of output embedding dimensions + """ + super().__init__() + self.height = height + self.width = width + self.channel = channel + self.num_features = num_features + + self.kernel = nn.Parameter(torch.empty(channel, height, width, num_features)) + + nn.init.kaiming_normal_(self.kernel, mode="fan_in", nonlinearity="linear") + + def forward(self, features): + """ + Forward pass for spatial embedding + + Args: + features: Input tensor of shape [B, C, H, W] where B is batch size, + C is number of channels, H is height, and W is width + Returns: + Output tensor of shape [B, C*F] where F is the number of features + """ + + features_expanded = features.unsqueeze(-1) # [B, C, H, W, 1] + kernel_expanded = self.kernel.unsqueeze(0) # [1, C, H, W, F] + + # Element-wise multiplication and spatial reduction + output = (features_expanded * kernel_expanded).sum(dim=(2, 3)) # Sum over H,W dimensions + + # Reshape to combine channel and feature dimensions + output = output.view(output.size(0), -1) # [B, C*F] + + return output + + +class RescaleFromTanh(Transform): + def __init__(self, low: float = -1, high: float = 1): + super().__init__() + + self.low = low + + self.high = high + + def _call(self, x): + # Rescale from (-1, 1) to (low, high) + + return 0.5 * (x + 1.0) * (self.high - self.low) + self.low + + def _inverse(self, y): + # Rescale from (low, high) back to (-1, 1) + + return 2.0 * (y - self.low) / (self.high - self.low) - 1.0 + + def log_abs_det_jacobian(self, x, y): + # log|d(rescale)/dx| = sum(log(0.5 * (high - low))) + + scale = 0.5 * (self.high - self.low) + + return torch.sum(torch.log(scale), dim=-1) + + +class TanhMultivariateNormalDiag(TransformedDistribution): + def __init__(self, loc, scale_diag, low=None, high=None): + base_dist = MultivariateNormal(loc, torch.diag_embed(scale_diag)) + + transforms = [TanhTransform(cache_size=1)] + + if low is not None and high is not None: + low = torch.as_tensor(low) + + high = torch.as_tensor(high) + + transforms.insert(0, RescaleFromTanh(low, high)) + + super().__init__(base_dist, transforms) + + def mode(self): + # Mode is mean of base distribution, passed through transforms + + x = self.base_dist.mean + + for transform in self.transforms: + x = transform(x) + + return x + + def stddev(self): + std = self.base_dist.stddev + + x = std + + for transform in self.transforms: + x = transform(x) + + return x + + +def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict: + converted_params = {} + for outer_key, inner_dict in normalization_params.items(): + converted_params[outer_key] = {} + for key, value in inner_dict.items(): + converted_params[outer_key][key] = torch.tensor(value) + if "image" in outer_key: + converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1) + + return converted_params diff --git a/lerobot/common/policies/sac/reward_model/configuration_classifier.py b/lerobot/common/policies/sac/reward_model/configuration_classifier.py new file mode 100644 index 0000000000..6e2a551d4d --- /dev/null +++ b/lerobot/common/policies/sac/reward_model/configuration_classifier.py @@ -0,0 +1,76 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field + +from lerobot.common.optim.optimizers import AdamWConfig, OptimizerConfig +from lerobot.common.optim.schedulers import LRSchedulerConfig +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import NormalizationMode + + +@PreTrainedConfig.register_subclass(name="reward_classifier") +@dataclass +class RewardClassifierConfig(PreTrainedConfig): + """Configuration for the Reward Classifier model.""" + + name: str = "reward_classifier" + num_classes: int = 2 + hidden_dim: int = 256 + latent_dim: int = 256 + image_embedding_pooling_dim: int = 8 + dropout_rate: float = 0.1 + model_name: str = "helper2424/resnet10" + device: str = "cpu" + model_type: str = "cnn" # "transformer" or "cnn" + num_cameras: int = 2 + learning_rate: float = 1e-4 + weight_decay: float = 0.01 + grad_clip_norm: float = 1.0 + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.MEAN_STD, + } + ) + + @property + def observation_delta_indices(self) -> list | None: + return None + + @property + def action_delta_indices(self) -> list | None: + return None + + @property + def reward_delta_indices(self) -> list | None: + return None + + def get_optimizer_preset(self) -> OptimizerConfig: + return AdamWConfig( + lr=self.learning_rate, + weight_decay=self.weight_decay, + grad_clip_norm=self.grad_clip_norm, + ) + + def get_scheduler_preset(self) -> LRSchedulerConfig | None: + return None + + def validate_features(self) -> None: + """Validate feature configurations.""" + has_image = any(key.startswith("observation.image") for key in self.input_features) + if not has_image: + raise ValueError( + "You must provide an image observation (key starting with 'observation.image') in the input features" + ) diff --git a/lerobot/common/policies/sac/reward_model/modeling_classifier.py b/lerobot/common/policies/sac/reward_model/modeling_classifier.py new file mode 100644 index 0000000000..f537e3aefd --- /dev/null +++ b/lerobot/common/policies/sac/reward_model/modeling_classifier.py @@ -0,0 +1,316 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import torch +from torch import Tensor, nn + +from lerobot.common.constants import OBS_IMAGE, REWARD +from lerobot.common.policies.normalize import Normalize, Unnormalize +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig + + +class ClassifierOutput: + """Wrapper for classifier outputs with additional metadata.""" + + def __init__( + self, + logits: Tensor, + probabilities: Tensor | None = None, + hidden_states: Tensor | None = None, + ): + self.logits = logits + self.probabilities = probabilities + self.hidden_states = hidden_states + + def __repr__(self): + return ( + f"ClassifierOutput(logits={self.logits}, " + f"probabilities={self.probabilities}, " + f"hidden_states={self.hidden_states})" + ) + + +class SpatialLearnedEmbeddings(nn.Module): + def __init__(self, height, width, channel, num_features=8): + """ + PyTorch implementation of learned spatial embeddings + + Args: + height: Spatial height of input features + width: Spatial width of input features + channel: Number of input channels + num_features: Number of output embedding dimensions + """ + super().__init__() + self.height = height + self.width = width + self.channel = channel + self.num_features = num_features + + self.kernel = nn.Parameter(torch.empty(channel, height, width, num_features)) + + nn.init.kaiming_normal_(self.kernel, mode="fan_in", nonlinearity="linear") + + def forward(self, features): + """ + Forward pass for spatial embedding + + Args: + features: Input tensor of shape [B, H, W, C] or [H, W, C] if no batch + Returns: + Output tensor of shape [B, C*F] or [C*F] if no batch + """ + + features = features.last_hidden_state + + original_shape = features.shape + if features.dim() == 3: + features = features.unsqueeze(0) # Add batch dim + + features_expanded = features.unsqueeze(-1) # [B, H, W, C, 1] + kernel_expanded = self.kernel.unsqueeze(0) # [1, H, W, C, F] + + # Element-wise multiplication and spatial reduction + output = (features_expanded * kernel_expanded).sum(dim=(2, 3)) # Sum H,W + + # Reshape to combine channel and feature dimensions + output = output.view(output.size(0), -1) # [B, C*F] + + # Remove batch dim + if len(original_shape) == 3: + output = output.squeeze(0) + + return output + + +class Classifier(PreTrainedPolicy): + """Image classifier built on top of a pre-trained encoder.""" + + name = "reward_classifier" + config_class = RewardClassifierConfig + + def __init__( + self, + config: RewardClassifierConfig, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): + from transformers import AutoModel + + super().__init__(config) + self.config = config + + # Initialize normalization (standardized with the policy framework) + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + + # Set up encoder + encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True) + # Extract vision model if we're given a multimodal model + if hasattr(encoder, "vision_model"): + logging.info("Multimodal model detected - using vision encoder only") + self.encoder = encoder.vision_model + self.vision_config = encoder.config.vision_config + else: + self.encoder = encoder + self.vision_config = getattr(encoder, "config", None) + + # Model type from config + self.is_cnn = self.config.model_type == "cnn" + + # For CNNs, initialize backbone + if self.is_cnn: + self._setup_cnn_backbone() + + self._freeze_encoder() + + # Extract image keys from input_features + self.image_keys = [ + key.replace(".", "_") for key in config.input_features if key.startswith(OBS_IMAGE) + ] + + if self.is_cnn: + self.encoders = nn.ModuleDict() + for image_key in self.image_keys: + encoder = self._create_single_encoder() + self.encoders[image_key] = encoder + + self._build_classifier_head() + + def _setup_cnn_backbone(self): + """Set up CNN encoder""" + if hasattr(self.encoder, "fc"): + self.feature_dim = self.encoder.fc.in_features + self.encoder = nn.Sequential(*list(self.encoder.children())[:-1]) + elif hasattr(self.encoder.config, "hidden_sizes"): + self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension + else: + raise ValueError("Unsupported CNN architecture") + + def _freeze_encoder(self) -> None: + """Freeze the encoder parameters.""" + for param in self.encoder.parameters(): + param.requires_grad = False + + def _create_single_encoder(self): + encoder = nn.Sequential( + self.encoder, + SpatialLearnedEmbeddings( + height=4, + width=4, + channel=self.feature_dim, + num_features=self.config.image_embedding_pooling_dim, + ), + nn.Dropout(self.config.dropout_rate), + nn.Linear(self.feature_dim * self.config.image_embedding_pooling_dim, self.config.latent_dim), + nn.LayerNorm(self.config.latent_dim), + nn.Tanh(), + ) + + return encoder + + def _build_classifier_head(self) -> None: + """Initialize the classifier head architecture.""" + # Get input dimension based on model type + if self.is_cnn: + input_dim = self.config.latent_dim + else: # Transformer models + if hasattr(self.encoder.config, "hidden_size"): + input_dim = self.encoder.config.hidden_size + else: + raise ValueError("Unsupported transformer architecture since hidden_size is not found") + + self.classifier_head = nn.Sequential( + nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim), + nn.Dropout(self.config.dropout_rate), + nn.LayerNorm(self.config.hidden_dim), + nn.ReLU(), + nn.Linear( + self.config.hidden_dim, + 1 if self.config.num_classes == 2 else self.config.num_classes, + ), + ) + + def _get_encoder_output(self, x: torch.Tensor, image_key: str) -> torch.Tensor: + """Extract the appropriate output from the encoder.""" + with torch.no_grad(): + if self.is_cnn: + # The HF ResNet applies pooling internally + outputs = self.encoders[image_key](x) + return outputs + else: # Transformer models + outputs = self.encoder(x) + return outputs.last_hidden_state[:, 0, :] + + def extract_images_and_labels(self, batch: dict[str, Tensor]) -> tuple[list, Tensor]: + """Extract image tensors and label tensors from batch.""" + # Check for both OBS_IMAGE and OBS_IMAGES prefixes + images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)] + labels = batch[REWARD] + + return images, labels + + def predict(self, xs: list) -> ClassifierOutput: + """Forward pass of the classifier for inference.""" + encoder_outputs = torch.hstack( + [self._get_encoder_output(x, img_key) for x, img_key in zip(xs, self.image_keys, strict=True)] + ) + logits = self.classifier_head(encoder_outputs) + + if self.config.num_classes == 2: + logits = logits.squeeze(-1) + probabilities = torch.sigmoid(logits) + else: + probabilities = torch.softmax(logits, dim=-1) + + return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs) + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]: + """Standard forward pass for training compatible with train.py.""" + # Normalize inputs if needed + batch = self.normalize_inputs(batch) + batch = self.normalize_targets(batch) + + # Extract images and labels + images, labels = self.extract_images_and_labels(batch) + + # Get predictions + outputs = self.predict(images) + + # Calculate loss + if self.config.num_classes == 2: + # Binary classification + loss = nn.functional.binary_cross_entropy_with_logits(outputs.logits, labels) + predictions = (torch.sigmoid(outputs.logits) > 0.5).float() + else: + # Multi-class classification + loss = nn.functional.cross_entropy(outputs.logits, labels.long()) + predictions = torch.argmax(outputs.logits, dim=1) + + # Calculate accuracy for logging + correct = (predictions == labels).sum().item() + total = labels.size(0) + accuracy = 100 * correct / total + + # Return loss and metrics for logging + output_dict = { + "accuracy": accuracy, + "correct": correct, + "total": total, + } + + return loss, output_dict + + def predict_reward(self, batch, threshold=0.5): + """Eval method. Returns predicted reward with the decision threshold as argument.""" + # Check for both OBS_IMAGE and OBS_IMAGES prefixes + batch = self.normalize_inputs(batch) + batch = self.normalize_targets(batch) + + # Extract images from batch dict + images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)] + + if self.config.num_classes == 2: + probs = self.predict(images).probabilities + logging.debug(f"Predicted reward images: {probs}") + return (probs > threshold).float() + else: + return torch.argmax(self.predict(images).probabilities, dim=1) + + def get_optim_params(self): + """Return optimizer parameters for the policy.""" + return self.parameters() + + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """ + This method is required by PreTrainedPolicy but not used for reward classifiers. + The reward classifier is not an actor and does not select actions. + """ + raise NotImplementedError("Reward classifiers do not select actions") + + def reset(self): + """ + This method is required by PreTrainedPolicy but not used for reward classifiers. + The reward classifier is not an actor and does not select actions. + """ + pass diff --git a/lerobot/common/policies/smolvla/modeling_smolvla.py b/lerobot/common/policies/smolvla/modeling_smolvla.py index 6ac2d3e7ee..5e0a9622e0 100644 --- a/lerobot/common/policies/smolvla/modeling_smolvla.py +++ b/lerobot/common/policies/smolvla/modeling_smolvla.py @@ -53,8 +53,11 @@ """ import math +import os +import re from collections import deque +import safetensors import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn @@ -73,6 +76,102 @@ ) from lerobot.common.utils.utils import get_safe_dtype +# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker +_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_") + + +def canonicalise(k: str) -> str: + """ + Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a + normalisation-buffer key. + """ + return _VARIANT_RE.sub(".buffer_", k) + + +def standardise_state_dict( + checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True +) -> tuple[dict[str, torch.Tensor], list[str]]: + """ + • Re-keys `checkpoint ` so that every entry matches the *reference* key set. + • If several variant keys collapse to the same canonical name we keep the + first one and log the collision. + • Returns the new dict + a list of entries that could not be matched. + """ + out, collisions, unmatched = {}, {}, [] + + for k, v in checkpoint.items(): + canon = canonicalise(k) + if canon in ref_keys: + if canon in out: # duplicate after collapsing + collisions.setdefault(canon, []).append(k) + else: + out[canon] = v + else: + unmatched.append(k) + + if verbose: + for canon, variants in collisions.items(): + print(f"[standardise_state_dict] '{canon}' ← {variants}") + if unmatched: + print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys") + + out.update({k: checkpoint[k] for k in unmatched}) + return out, unmatched + + +def rename_checkpoint_keys(checkpoint: dict, rename_str: str): + """ + Renames keys in a checkpoint dictionary based on the given rename string. + + Args: + checkpoint (dict): The checkpoint dictionary. + rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2". + + Returns: + dict: The modified checkpoint with renamed keys. + """ + + rename_dict = dict(pair.split("//") for pair in rename_str.split(",")) + + new_checkpoint = {} + for k, v in checkpoint.items(): + for old_key, new_key in rename_dict.items(): + if old_key in k: + k = k.replace(old_key, new_key) + new_checkpoint[k] = v + return new_checkpoint + + +def load_smolvla( + model: torch.nn.Module, + filename: str | os.PathLike, + *, + device: str = "cpu", + checkpoint_keys_mapping: str = "", +) -> torch.nn.Module: + state_dict = safetensors.torch.load_file(filename, device=device) + + # Optional user-supplied renames (e.g. "model._orig_mod.//model.") + if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping: + state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping) + + state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys())) + + # HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset + norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs") + state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)} + + missing, unexpected = model.load_state_dict(state_dict, strict=False) + + if not all(key.startswith(norm_keys) for key in missing) or unexpected: + raise RuntimeError( + "SmolVLA %d missing / %d unexpected keys", + len(missing), + len(unexpected), + ) + + return model + def create_sinusoidal_pos_embedding( time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu" @@ -264,6 +363,23 @@ def reset(self): ACTION: deque(maxlen=self.config.n_action_steps), } + # HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues + @classmethod + def _load_as_safetensor( + cls, + model: "SmolVLAPolicy", + model_file: str, + map_location: str, + strict: bool, + ): + safetensors.torch.load_model(model, model_file, strict=strict, device=map_location) + return load_smolvla( + model, + model_file, + device=map_location, + checkpoint_keys_mapping="model._orig_mod.//model.", + ) + def get_optim_params(self) -> dict: return self.parameters() @@ -387,10 +503,14 @@ def prepare_language(self, batch) -> tuple[Tensor, Tensor]: """Tokenize the text input""" device = batch[OBS_STATE].device tasks = batch["task"] + if isinstance(tasks, str): + tasks = [tasks] + if len(tasks) == 1: tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])] tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks] + tokenized_prompt = self.language_tokenizer.__call__( tasks, padding=self.config.pad_language_to, diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index 31220aa935..476e6decd2 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -17,8 +17,8 @@ """Implementation of Finetuning Offline World Models in the Real World. The comments in this code may sometimes refer to these references: - TD-MPC paper: Temporal Difference Learning for Model Predictive Control (https://arxiv.org/abs/2203.04955) - FOWM paper: Finetuning Offline World Models in the Real World (https://arxiv.org/abs/2310.16029) + TD-MPC paper: Temporal Difference Learning for Model Predictive Control (https://huggingface.co/papers/2203.04955) + FOWM paper: Finetuning Offline World Models in the Real World (https://huggingface.co/papers/2310.16029) """ # ruff: noqa: N806 diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py index 97a08e2f4f..44006a5b21 100644 --- a/lerobot/common/policies/vqbet/modeling_vqbet.py +++ b/lerobot/common/policies/vqbet/modeling_vqbet.py @@ -162,7 +162,7 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) batch = self.normalize_targets(batch) - # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181) + # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://huggingface.co/papers/2403.03181) if not self.vqbet.action_head.vqvae_model.discretized.item(): # loss: total loss of training RVQ # n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`. @@ -185,7 +185,7 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: class SpatialSoftmax(nn.Module): """ Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al. - (https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation. + (https://huggingface.co/papers/1509.06113). A minimal port of the robomimic implementation. At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass" of activations of each channel, i.e., keypoints in the image space for the policy to focus on. @@ -387,7 +387,7 @@ def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]: # only extract the output tokens at the position of action query: # Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models, - # mapping sequential observation to sequential action (please refer to section 2.2 in BeT paper https://arxiv.org/pdf/2206.11251). + # mapping sequential observation to sequential action (please refer to section 2.2 in BeT paper https://huggingface.co/papers/2206.11251). # Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional). if len_additional_action_token > 0: features = torch.cat( @@ -824,8 +824,8 @@ def get_action_from_latent(self, latent): return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0]) def get_code(self, state): - # in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181) - # this function outputs the `GT code` of given action using frozen encoder and quantization layers. (please refer to Figure 2. in the paper https://arxiv.org/pdf/2403.03181) + # in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://huggingface.co/papers/2403.03181) + # this function outputs the `GT code` of given action using frozen encoder and quantization layers. (please refer to Figure 2. in the paper https://huggingface.co/papers/2403.03181) state = einops.rearrange(state, "N T A -> N (T A)") with torch.no_grad(): state_rep = self.encoder(state) @@ -838,7 +838,7 @@ def get_code(self, state): return state_vq, vq_code def vqvae_forward(self, state): - # This function passes the given data through Residual VQ with Encoder and Decoder. Please refer to section 3.2 in the paper https://arxiv.org/pdf/2403.03181). + # This function passes the given data through Residual VQ with Encoder and Decoder. Please refer to section 3.2 in the paper https://huggingface.co/papers/2403.03181). state = einops.rearrange(state, "N T A -> N (T A)") # We start with passing action (or action chunk) at:t+n through the encoder ϕ. state_rep = self.encoder(state) diff --git a/lerobot/common/policies/vqbet/vqbet_utils.py b/lerobot/common/policies/vqbet/vqbet_utils.py index 139d119edc..09a86c07ba 100644 --- a/lerobot/common/policies/vqbet/vqbet_utils.py +++ b/lerobot/common/policies/vqbet/vqbet_utils.py @@ -336,7 +336,7 @@ class ResidualVQ(nn.Module): """ Residual VQ is composed of multiple VectorQuantize layers. - Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf + Follows Algorithm 1. in https://huggingface.co/papers/2107.03312 "Residual Vector Quantizer (a.k.a. multi-stage vector quantizer [36]) cascades Nq layers of VQ as follows. The unquantized input vector is passed through a first VQ and quantization residuals are computed. The residuals are then iteratively quantized by a sequence of additional Nq -1 vector quantizers, as described in Algorithm 1." @@ -1006,7 +1006,7 @@ def gumbel_sample( if not straight_through or temperature <= 0.0 or not training: return ind, one_hot - # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612 + # use reinmax for better second-order accuracy - https://huggingface.co/papers/2304.08612 # algorithm 2 if reinmax: @@ -1156,7 +1156,7 @@ def batched_embedding(indices, embeds): def orthogonal_loss_fn(t): - # eq (2) from https://arxiv.org/abs/2112.00384 + # eq (2) from https://huggingface.co/papers/2112.00384 h, n = t.shape[:2] normed_codes = F.normalize(t, p=2, dim=-1) cosine_sim = einsum("h i d, h j d -> h i j", normed_codes, normed_codes) diff --git a/lerobot/common/robots/lekiwi/config_lekiwi.py b/lerobot/common/robots/lekiwi/config_lekiwi.py index 9876ada210..022d09cdd6 100644 --- a/lerobot/common/robots/lekiwi/config_lekiwi.py +++ b/lerobot/common/robots/lekiwi/config_lekiwi.py @@ -20,10 +20,21 @@ from ..config import RobotConfig +def lekiwi_cameras_config() -> dict[str, CameraConfig]: + return { + "front": OpenCVCameraConfig( + index_or_path="/dev/video0", fps=30, width=640, height=480, rotation=Cv2Rotation.ROTATE_180 + ), + "wrist": OpenCVCameraConfig( + index_or_path="/dev/video2", fps=30, width=480, height=640, rotation=Cv2Rotation.ROTATE_90 + ), + } + + @RobotConfig.register_subclass("lekiwi") @dataclass class LeKiwiConfig(RobotConfig): - port = "/dev/ttyACM0" # port to connect to the bus + port: str = "/dev/ttyACM0" # port to connect to the bus disable_torque_on_disconnect: bool = True @@ -32,14 +43,7 @@ class LeKiwiConfig(RobotConfig): # the number of motors in your follower arms. max_relative_target: int | None = None - cameras: dict[str, CameraConfig] = field( - default_factory=lambda: { - "front": OpenCVCameraConfig(index_or_path="/dev/video0", fps=30, width=640, height=480), - "wrist": OpenCVCameraConfig( - index_or_path="/dev/video2", fps=30, width=640, height=480, rotation=Cv2Rotation.ROTATE_180 - ), - } - ) + cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config) # Set to `True` for backward compatibility with previous policies/dataset use_degrees: bool = False @@ -86,5 +90,7 @@ class LeKiwiClientConfig(RobotConfig): } ) + cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config) + polling_timeout_ms: int = 15 connect_timeout_s: int = 5 diff --git a/lerobot/common/robots/lekiwi/lekiwi.mdx b/lerobot/common/robots/lekiwi/lekiwi.mdx index 68082d8a22..dd39a90399 100644 --- a/lerobot/common/robots/lekiwi/lekiwi.mdx +++ b/lerobot/common/robots/lekiwi/lekiwi.mdx @@ -43,9 +43,69 @@ First, we will assemble the two SO100/SO101 arms. One to attach to the mobile ba - [Assemble SO101](./so101#step-by-step-assembly-instructions) - [Assemble LeKiwi](https://github.com/SIGRobotics-UIUC/LeKiwi/blob/main/Assembly.md) +### Find the USB ports associated with motor board + +To find the port for each bus servo adapter, run this script: +```bash +python lerobot/find_port.py +``` + + + + +Example output: + +``` +Finding all available ports for the MotorBus. +['/dev/tty.usbmodem575E0032081'] +Remove the USB cable from your MotorsBus and press Enter when done. + +[...Disconnect corresponding leader or follower arm and press Enter...] + +The port of this MotorsBus is /dev/tty.usbmodem575E0032081 +Reconnect the USB cable. +``` + +Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your board. + + + + +On Linux, you might need to give access to the USB ports by running: +```bash +sudo chmod 666 /dev/ttyACM0 +sudo chmod 666 /dev/ttyACM1 +``` + +Example output: + +``` +Finding all available ports for the MotorBus. +['/dev/ttyACM0'] +Remove the usb cable from your MotorsBus and press Enter when done. + +[...Disconnect corresponding leader or follower arm and press Enter...] + +The port of this MotorsBus is /dev/ttyACM0 +Reconnect the USB cable. +``` + +Where the found port is: `/dev/ttyACM0` corresponding to your board. + + + + ### Configure motors The instructions for configuring the motors can be found in the SO101 [docs](./so101#configure-the-motors). Besides the ids for the arm motors, we also need to set the motor ids for the mobile base. These need to be in a specific order to work. Below an image of the motor ids and motor mounting positions for the mobile base. Note that we only use one Motor Control board on LeKiwi. This means the motor ids for the wheels are 7, 8 and 9. +You can run this command to setup motors for LeKiwi. It will first setup the motors for arm (id 6..1) and then setup motors for wheels (9,8,7) + +```bash +python -m lerobot.setup_motors \ + --robot.type=lekiwi \ + --robot.port=/dev/tty.usbmodem58760431551 # <- paste here the port found at previous step +``` + Motor ID's for mobile robot ### Troubleshoot communication diff --git a/lerobot/common/robots/lekiwi/lekiwi.py b/lerobot/common/robots/lekiwi/lekiwi.py index a1c2ffa14b..f6a9b8bf13 100644 --- a/lerobot/common/robots/lekiwi/lekiwi.py +++ b/lerobot/common/robots/lekiwi/lekiwi.py @@ -23,7 +23,6 @@ import numpy as np from lerobot.common.cameras.utils import make_cameras_from_configs -from lerobot.common.constants import OBS_IMAGES, OBS_STATE from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode from lerobot.common.motors.feetech import ( @@ -65,8 +64,8 @@ def __init__(self, config: LeKiwiConfig): "arm_gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100), # base "base_left_wheel": Motor(7, "sts3215", MotorNormMode.RANGE_M100_100), - "base_right_wheel": Motor(8, "sts3215", MotorNormMode.RANGE_M100_100), - "base_back_wheel": Motor(9, "sts3215", MotorNormMode.RANGE_M100_100), + "base_back_wheel": Motor(8, "sts3215", MotorNormMode.RANGE_M100_100), + "base_right_wheel": Motor(9, "sts3215", MotorNormMode.RANGE_M100_100), }, calibration=self.calibration, ) @@ -249,7 +248,7 @@ def _body_to_wheel_raw( velocity_vector = np.array([x, y, theta_rad]) # Define the wheel mounting angles with a -90° offset. - angles = np.radians(np.array([240, 120, 0]) - 90) + angles = np.radians(np.array([240, 0, 120]) - 90) # Build the kinematic matrix: each row maps body velocities to a wheel’s linear speed. # The third column (base_radius) accounts for the effect of rotation. m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles]) @@ -295,10 +294,7 @@ def _wheel_raw_to_body( base_radius : Distance from the robot center to each wheel (meters). Returns: - A dict (x_cmd, y_cmd, theta_cmd) where: - OBS_STATE.x_cmd : Linear velocity in x (m/s). - OBS_STATE.y_cmd : Linear velocity in y (m/s). - OBS_STATE.theta_cmd : Rotational velocity in deg/s. + A dict (x.vel, y.vel, theta.vel) all in m/s """ # Convert each raw command back to an angular speed in deg/s. @@ -316,7 +312,7 @@ def _wheel_raw_to_body( wheel_linear_speeds = wheel_radps * wheel_radius # Define the wheel mounting angles with a -90° offset. - angles = np.radians(np.array([240, 120, 0]) - 90) + angles = np.radians(np.array([240, 0, 120]) - 90) m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles]) # Solve the inverse kinematics: body_velocity = M⁻¹ · wheel_linear_speeds. @@ -347,16 +343,15 @@ def get_observation(self) -> dict[str, Any]: arm_state = {f"{k}.pos": v for k, v in arm_pos.items()} - flat_states = {**arm_state, **base_vel} + obs_dict = {**arm_state, **base_vel} - obs_dict = {f"{OBS_STATE}": flat_states} dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read state: {dt_ms:.1f}ms") # Capture images from cameras for cam_key, cam in self.cameras.items(): start = time.perf_counter() - obs_dict[f"{OBS_IMAGES}.{cam_key}"] = cam.async_read() + obs_dict[cam_key] = cam.async_read() dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") diff --git a/lerobot/common/robots/lekiwi/lekiwi_client.py b/lerobot/common/robots/lekiwi/lekiwi_client.py index 927ed49f53..f79b7f81a0 100644 --- a/lerobot/common/robots/lekiwi/lekiwi_client.py +++ b/lerobot/common/robots/lekiwi/lekiwi_client.py @@ -25,7 +25,6 @@ import torch import zmq -from lerobot.common.constants import OBS_IMAGES, OBS_STATE from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot @@ -92,11 +91,8 @@ def _state_order(self) -> tuple[str, ...]: return tuple(self._state_ft.keys()) @cached_property - def _cameras_ft(self) -> dict[str, tuple]: - return { - "front": (480, 640, 3), - "wrist": (640, 480, 3), - } + def _cameras_ft(self) -> dict[str, tuple[int, int, int]]: + return {name: (cfg.height, cfg.width, 3) for name, cfg in self.config.cameras.items()} @cached_property def observation_features(self) -> dict[str, type | tuple]: @@ -199,7 +195,7 @@ def _remote_state_from_obs( self, observation: Dict[str, Any] ) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]: """Extracts frames, and state from the parsed observation.""" - flat_state = observation[OBS_STATE] + flat_state = {key: value for key, value in observation.items() if key in self._state_ft} state_vec = np.array( [flat_state.get(k, 0.0) for k in self._state_order], @@ -207,7 +203,11 @@ def _remote_state_from_obs( ) # Decode images - image_observation = {k: v for k, v in observation.items() if k.startswith(OBS_IMAGES)} + image_observation = { + f"observation.images.{key}": value + for key, value in observation.items() + if key in self._cameras_ft + } current_frames: Dict[str, np.ndarray] = {} for cam_name, image_b64 in image_observation.items(): frame = self._decode_image_from_b64(image_b64) diff --git a/lerobot/common/robots/lekiwi/lekiwi_host.py b/lerobot/common/robots/lekiwi/lekiwi_host.py index 014c965b7f..1155cf71c2 100644 --- a/lerobot/common/robots/lekiwi/lekiwi_host.py +++ b/lerobot/common/robots/lekiwi/lekiwi_host.py @@ -22,8 +22,6 @@ import cv2 import zmq -from lerobot.common.constants import OBS_IMAGES - from .config_lekiwi import LeKiwiConfig, LeKiwiHostConfig from .lekiwi import LeKiwi @@ -95,12 +93,12 @@ def main(): # Encode ndarrays to base64 strings for cam_key, _ in robot.cameras.items(): ret, buffer = cv2.imencode( - ".jpg", last_observation[f"{OBS_IMAGES}.{cam_key}"], [int(cv2.IMWRITE_JPEG_QUALITY), 90] + ".jpg", last_observation[cam_key], [int(cv2.IMWRITE_JPEG_QUALITY), 90] ) if ret: - last_observation[f"{OBS_IMAGES}.{cam_key}"] = base64.b64encode(buffer).decode("utf-8") + last_observation[cam_key] = base64.b64encode(buffer).decode("utf-8") else: - last_observation[f"{OBS_IMAGES}.{cam_key}"] = "" + last_observation[cam_key] = "" # Send the observation to the remote agent try: diff --git a/lerobot/common/robots/robot.py b/lerobot/common/robots/robot.py index e5af9e79f0..ec2b155f35 100644 --- a/lerobot/common/robots/robot.py +++ b/lerobot/common/robots/robot.py @@ -27,7 +27,16 @@ # TODO(aliberts): action/obs typing such as Generic[ObsType, ActType] similar to gym.Env ? # https://github.com/Farama-Foundation/Gymnasium/blob/3287c869f9a48d99454306b0d4b4ec537f0f35e3/gymnasium/core.py#L23 class Robot(abc.ABC): - """The main LeRobot class for implementing robots.""" + """ + The base abstract class for all LeRobot-compatible robots. + + This class provides a standardized interface for interacting with physical robots. + Subclasses must implement all abstract methods and properties to be usable. + + Attributes: + config_class (RobotConfig): The expected configuration class for this robot. + name (str): The unique robot name used to identify this robot type. + """ # Set these in ALL subclasses config_class: RobotConfig @@ -52,58 +61,124 @@ def __str__(self) -> str: @property @abc.abstractmethod def observation_features(self) -> dict: + """ + A dictionary describing the structure and types of the observations produced by the robot. + Its structure (keys) should match the structure of what is returned by :pymeth:`get_observation`. + Values for the dict should either be: + - The type of the value if it's a simple value, e.g. `float` for single proprioceptive value (a joint's position/velocity) + - A tuple representing the shape if it's an array-type value, e.g. `(height, width, channel)` for images + + Note: this property should be able to be called regardless of whether the robot is connected or not. + """ pass @property @abc.abstractmethod def action_features(self) -> dict: + """ + A dictionary describing the structure and types of the actions expected by the robot. Its structure + (keys) should match the structure of what is passed to :pymeth:`send_action`. Values for the dict + should be the type of the value if it's a simple value, e.g. `float` for single proprioceptive value + (a joint's goal position/velocity) + + Note: this property should be able to be called regardless of whether the robot is connected or not. + """ pass @property @abc.abstractmethod def is_connected(self) -> bool: + """ + Whether the robot is currently connected or not. If `False`, calling :pymeth:`get_observation` or + :pymeth:`send_action` should raise an error. + """ pass @abc.abstractmethod def connect(self, calibrate: bool = True) -> None: - """Connects to the robot.""" + """ + Establish communication with the robot. + + Args: + calibrate (bool): If True, automatically calibrate the robot after connecting if it's not + calibrated or needs calibration (this is hardware-dependant). + """ pass @property @abc.abstractmethod def is_calibrated(self) -> bool: + """Whether the robot is currently calibrated or not. Should be always `True` if not applicable""" pass @abc.abstractmethod def calibrate(self) -> None: - """Calibrates the robot.""" + """ + Calibrate the robot if applicable. If not, this should be a no-op. + + This method should collect any necessary data (e.g., motor offsets) and update the + :pyattr:`calibration` dictionary accordingly. + """ pass def _load_calibration(self, fpath: Path | None = None) -> None: + """ + Helper to load calibration data from the specified file. + + Args: + fpath (Path | None): Optional path to the calibration file. Defaults to `self.calibration_fpath`. + """ fpath = self.calibration_fpath if fpath is None else fpath with open(fpath) as f, draccus.config_type("json"): self.calibration = draccus.load(dict[str, MotorCalibration], f) def _save_calibration(self, fpath: Path | None = None) -> None: + """ + Helper to save calibration data to the specified file. + + Args: + fpath (Path | None): Optional path to save the calibration file. Defaults to `self.calibration_fpath`. + """ fpath = self.calibration_fpath if fpath is None else fpath with open(fpath, "w") as f, draccus.config_type("json"): draccus.dump(self.calibration, f, indent=4) @abc.abstractmethod def configure(self) -> None: + """ + Apply any one-time or runtime configuration to the robot. + This may include setting motor parameters, control modes, or initial state. + """ pass @abc.abstractmethod def get_observation(self) -> dict[str, Any]: - """Gets observation from the robot.""" + """ + Retrieve the current observation from the robot. + + Returns: + dict[str, Any]: A flat dictionary representing the robot's current sensory state. Its structure + should match :pymeth:`observation_features`. + """ + pass @abc.abstractmethod def send_action(self, action: dict[str, Any]) -> dict[str, Any]: - """Sends actions to the robot.""" + """ + Send an action command to the robot. + + Args: + action (dict[str, Any]): Dictionary representing the desired action. Its structure should match + :pymeth:`action_features`. + + Returns: + dict[str, Any]: The action actually sent to the motors potentially clipped or modified, e.g. by + safety limits on velocity. + """ pass @abc.abstractmethod def disconnect(self) -> None: - """Disconnects from the robot.""" + """Disconnect from the robot and perform any necessary cleanup.""" pass diff --git a/lerobot/common/robots/so100_follower/__init__.py b/lerobot/common/robots/so100_follower/__init__.py index 087fd64562..63c3e1c17a 100644 --- a/lerobot/common/robots/so100_follower/__init__.py +++ b/lerobot/common/robots/so100_follower/__init__.py @@ -1,2 +1,3 @@ -from .config_so100_follower import SO100FollowerConfig +from .config_so100_follower import SO100FollowerConfig, SO100FollowerEndEffectorConfig from .so100_follower import SO100Follower +from .so100_follower_end_effector import SO100FollowerEndEffector diff --git a/lerobot/common/robots/so100_follower/config_so100_follower.py b/lerobot/common/robots/so100_follower/config_so100_follower.py index 2a5a966ee2..b76675d26a 100644 --- a/lerobot/common/robots/so100_follower/config_so100_follower.py +++ b/lerobot/common/robots/so100_follower/config_so100_follower.py @@ -37,3 +37,27 @@ class SO100FollowerConfig(RobotConfig): # Set to `True` for backward compatibility with previous policies/dataset use_degrees: bool = False + + +@RobotConfig.register_subclass("so100_follower_end_effector") +@dataclass +class SO100FollowerEndEffectorConfig(SO100FollowerConfig): + """Configuration for the SO100FollowerEndEffector robot.""" + + # Default bounds for the end-effector position (in meters) + end_effector_bounds: dict[str, list[float]] = field( + default_factory=lambda: { + "min": [-1.0, -1.0, -1.0], # min x, y, z + "max": [1.0, 1.0, 1.0], # max x, y, z + } + ) + + max_gripper_pos: float = 50 + + end_effector_step_sizes: dict[str, float] = field( + default_factory=lambda: { + "x": 0.02, + "y": 0.02, + "z": 0.02, + } + ) diff --git a/lerobot/common/robots/so100_follower/so100_follower_end_effector.py b/lerobot/common/robots/so100_follower/so100_follower_end_effector.py new file mode 100644 index 0000000000..82e89305b3 --- /dev/null +++ b/lerobot/common/robots/so100_follower/so100_follower_end_effector.py @@ -0,0 +1,193 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from typing import Any + +import numpy as np + +from lerobot.common.cameras import make_cameras_from_configs +from lerobot.common.errors import DeviceNotConnectedError +from lerobot.common.model.kinematics import RobotKinematics +from lerobot.common.motors import Motor, MotorNormMode +from lerobot.common.motors.feetech import FeetechMotorsBus + +from . import SO100Follower +from .config_so100_follower import SO100FollowerEndEffectorConfig + +logger = logging.getLogger(__name__) +EE_FRAME = "gripper_tip" + + +class SO100FollowerEndEffector(SO100Follower): + """ + SO100Follower robot with end-effector space control. + + This robot inherits from SO100Follower but transforms actions from + end-effector space to joint space before sending them to the motors. + """ + + config_class = SO100FollowerEndEffectorConfig + name = "so100_follower_end_effector" + + def __init__(self, config: SO100FollowerEndEffectorConfig): + super().__init__(config) + self.bus = FeetechMotorsBus( + port=self.config.port, + motors={ + "shoulder_pan": Motor(1, "sts3215", MotorNormMode.DEGREES), + "shoulder_lift": Motor(2, "sts3215", MotorNormMode.DEGREES), + "elbow_flex": Motor(3, "sts3215", MotorNormMode.DEGREES), + "wrist_flex": Motor(4, "sts3215", MotorNormMode.DEGREES), + "wrist_roll": Motor(5, "sts3215", MotorNormMode.DEGREES), + "gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100), + }, + calibration=self.calibration, + ) + + self.cameras = make_cameras_from_configs(config.cameras) + + self.config = config + + # Initialize the kinematics module for the so100 robot + self.kinematics = RobotKinematics(robot_type="so_new_calibration") + + # Store the bounds for end-effector position + self.end_effector_bounds = self.config.end_effector_bounds + + self.current_ee_pos = None + self.current_joint_pos = None + + @property + def action_features(self) -> dict[str, Any]: + """ + Define action features for end-effector control. + Returns dictionary with dtype, shape, and names. + """ + return { + "dtype": "float32", + "shape": (4,), + "names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3}, + } + + def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + """ + Transform action from end-effector space to joint space and send to motors. + + Args: + action: Dictionary with keys 'delta_x', 'delta_y', 'delta_z' for end-effector control + or a numpy array with [delta_x, delta_y, delta_z] + + Returns: + The joint-space action that was sent to the motors + """ + + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + # Convert action to numpy array if not already + if isinstance(action, dict): + if all(k in action for k in ["delta_x", "delta_y", "delta_z"]): + delta_ee = np.array( + [ + action["delta_x"] * self.config.end_effector_step_sizes["x"], + action["delta_y"] * self.config.end_effector_step_sizes["y"], + action["delta_z"] * self.config.end_effector_step_sizes["z"], + ], + dtype=np.float32, + ) + if "gripper" not in action: + action["gripper"] = [1.0] + action = np.append(delta_ee, action["gripper"]) + else: + logger.warning( + f"Expected action keys 'delta_x', 'delta_y', 'delta_z', got {list(action.keys())}" + ) + action = np.zeros(4, dtype=np.float32) + + if self.current_joint_pos is None: + # Read current joint positions + current_joint_pos = self.bus.sync_read("Present_Position") + self.current_joint_pos = np.array([current_joint_pos[name] for name in self.bus.motors]) + + # Calculate current end-effector position using forward kinematics + if self.current_ee_pos is None: + self.current_ee_pos = self.kinematics.forward_kinematics(self.current_joint_pos, frame=EE_FRAME) + + # Set desired end-effector position by adding delta + desired_ee_pos = np.eye(4) + desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3] # Keep orientation + + # Add delta to position and clip to bounds + desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + action[:3] + if self.end_effector_bounds is not None: + desired_ee_pos[:3, 3] = np.clip( + desired_ee_pos[:3, 3], + self.end_effector_bounds["min"], + self.end_effector_bounds["max"], + ) + + # Compute inverse kinematics to get joint positions + target_joint_values_in_degrees = self.kinematics.ik( + self.current_joint_pos, desired_ee_pos, position_only=True, frame=EE_FRAME + ) + + target_joint_values_in_degrees = np.clip(target_joint_values_in_degrees, -180.0, 180.0) + # Create joint space action dictionary + joint_action = { + f"{key}.pos": target_joint_values_in_degrees[i] for i, key in enumerate(self.bus.motors.keys()) + } + + # Handle gripper separately if included in action + # Gripper delta action is in the range 0 - 2, + # We need to shift the action to the range -1, 1 so that we can expand it to -Max_gripper_pos, Max_gripper_pos + joint_action["gripper.pos"] = np.clip( + self.current_joint_pos[-1] + (action[-1] - 1) * self.config.max_gripper_pos, + 5, + self.config.max_gripper_pos, + ) + + self.current_ee_pos = desired_ee_pos.copy() + self.current_joint_pos = target_joint_values_in_degrees.copy() + self.current_joint_pos[-1] = joint_action["gripper.pos"] + + # Send joint space action to parent class + return super().send_action(joint_action) + + def get_observation(self) -> dict[str, Any]: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + # Read arm position + start = time.perf_counter() + obs_dict = self.bus.sync_read("Present_Position") + obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()} + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read state: {dt_ms:.1f}ms") + + # Capture images from cameras + for cam_key, cam in self.cameras.items(): + start = time.perf_counter() + obs_dict[cam_key] = cam.async_read() + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") + + return obs_dict + + def reset(self): + self.current_ee_pos = None + self.current_joint_pos = None diff --git a/lerobot/common/robots/utils.py b/lerobot/common/robots/utils.py index d100c8366c..ccc1c58e86 100644 --- a/lerobot/common/robots/utils.py +++ b/lerobot/common/robots/utils.py @@ -29,6 +29,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot: from .so100_follower import SO100Follower return SO100Follower(config) + elif config.type == "so100_follower_end_effector": + from .so100_follower import SO100FollowerEndEffector + + return SO100FollowerEndEffector(config) elif config.type == "so101_follower": from .so101_follower import SO101Follower diff --git a/lerobot/common/teleoperators/gamepad/__init__.py b/lerobot/common/teleoperators/gamepad/__init__.py new file mode 100644 index 0000000000..6f9f7fbd91 --- /dev/null +++ b/lerobot/common/teleoperators/gamepad/__init__.py @@ -0,0 +1,18 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_gamepad import GamepadTeleopConfig +from .teleop_gamepad import GamepadTeleop diff --git a/lerobot/common/teleoperators/gamepad/configuration_gamepad.py b/lerobot/common/teleoperators/gamepad/configuration_gamepad.py new file mode 100644 index 0000000000..b3a565c072 --- /dev/null +++ b/lerobot/common/teleoperators/gamepad/configuration_gamepad.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from ..config import TeleoperatorConfig + + +@TeleoperatorConfig.register_subclass("gamepad") +@dataclass +class GamepadTeleopConfig(TeleoperatorConfig): + use_gripper: bool = True diff --git a/lerobot/common/teleoperators/gamepad/gamepad_utils.py b/lerobot/common/teleoperators/gamepad/gamepad_utils.py new file mode 100644 index 0000000000..21a293c771 --- /dev/null +++ b/lerobot/common/teleoperators/gamepad/gamepad_utils.py @@ -0,0 +1,480 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + + +class InputController: + """Base class for input controllers that generate motion deltas.""" + + def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0): + """ + Initialize the controller. + + Args: + x_step_size: Base movement step size in meters + y_step_size: Base movement step size in meters + z_step_size: Base movement step size in meters + """ + self.x_step_size = x_step_size + self.y_step_size = y_step_size + self.z_step_size = z_step_size + self.running = True + self.episode_end_status = None # None, "success", or "failure" + self.intervention_flag = False + self.open_gripper_command = False + self.close_gripper_command = False + + def start(self): + """Start the controller and initialize resources.""" + pass + + def stop(self): + """Stop the controller and release resources.""" + pass + + def get_deltas(self): + """Get the current movement deltas (dx, dy, dz) in meters.""" + return 0.0, 0.0, 0.0 + + def should_quit(self): + """Return True if the user has requested to quit.""" + return not self.running + + def update(self): + """Update controller state - call this once per frame.""" + pass + + def __enter__(self): + """Support for use in 'with' statements.""" + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Ensure resources are released when exiting 'with' block.""" + self.stop() + + def get_episode_end_status(self): + """ + Get the current episode end status. + + Returns: + None if episode should continue, "success" or "failure" otherwise + """ + status = self.episode_end_status + self.episode_end_status = None # Reset after reading + return status + + def should_intervene(self): + """Return True if intervention flag was set.""" + return self.intervention_flag + + def gripper_command(self): + """Return the current gripper command.""" + if self.open_gripper_command == self.close_gripper_command: + return "stay" + elif self.open_gripper_command: + return "open" + elif self.close_gripper_command: + return "close" + + +class KeyboardController(InputController): + """Generate motion deltas from keyboard input.""" + + def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0): + super().__init__(x_step_size, y_step_size, z_step_size) + self.key_states = { + "forward_x": False, + "backward_x": False, + "forward_y": False, + "backward_y": False, + "forward_z": False, + "backward_z": False, + "quit": False, + "success": False, + "failure": False, + } + self.listener = None + + def start(self): + """Start the keyboard listener.""" + from pynput import keyboard + + def on_press(key): + try: + if key == keyboard.Key.up: + self.key_states["forward_x"] = True + elif key == keyboard.Key.down: + self.key_states["backward_x"] = True + elif key == keyboard.Key.left: + self.key_states["forward_y"] = True + elif key == keyboard.Key.right: + self.key_states["backward_y"] = True + elif key == keyboard.Key.shift: + self.key_states["backward_z"] = True + elif key == keyboard.Key.shift_r: + self.key_states["forward_z"] = True + elif key == keyboard.Key.esc: + self.key_states["quit"] = True + self.running = False + return False + elif key == keyboard.Key.enter: + self.key_states["success"] = True + self.episode_end_status = "success" + elif key == keyboard.Key.backspace: + self.key_states["failure"] = True + self.episode_end_status = "failure" + except AttributeError: + pass + + def on_release(key): + try: + if key == keyboard.Key.up: + self.key_states["forward_x"] = False + elif key == keyboard.Key.down: + self.key_states["backward_x"] = False + elif key == keyboard.Key.left: + self.key_states["forward_y"] = False + elif key == keyboard.Key.right: + self.key_states["backward_y"] = False + elif key == keyboard.Key.shift: + self.key_states["backward_z"] = False + elif key == keyboard.Key.shift_r: + self.key_states["forward_z"] = False + elif key == keyboard.Key.enter: + self.key_states["success"] = False + elif key == keyboard.Key.backspace: + self.key_states["failure"] = False + except AttributeError: + pass + + self.listener = keyboard.Listener(on_press=on_press, on_release=on_release) + self.listener.start() + + print("Keyboard controls:") + print(" Arrow keys: Move in X-Y plane") + print(" Shift and Shift_R: Move in Z axis") + print(" Enter: End episode with SUCCESS") + print(" Backspace: End episode with FAILURE") + print(" ESC: Exit") + + def stop(self): + """Stop the keyboard listener.""" + if self.listener and self.listener.is_alive(): + self.listener.stop() + + def get_deltas(self): + """Get the current movement deltas from keyboard state.""" + delta_x = delta_y = delta_z = 0.0 + + if self.key_states["forward_x"]: + delta_x += self.x_step_size + if self.key_states["backward_x"]: + delta_x -= self.x_step_size + if self.key_states["forward_y"]: + delta_y += self.y_step_size + if self.key_states["backward_y"]: + delta_y -= self.y_step_size + if self.key_states["forward_z"]: + delta_z += self.z_step_size + if self.key_states["backward_z"]: + delta_z -= self.z_step_size + + return delta_x, delta_y, delta_z + + def should_quit(self): + """Return True if ESC was pressed.""" + return self.key_states["quit"] + + def should_save(self): + """Return True if Enter was pressed (save episode).""" + return self.key_states["success"] or self.key_states["failure"] + + +class GamepadController(InputController): + """Generate motion deltas from gamepad input.""" + + def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0, deadzone=0.1): + super().__init__(x_step_size, y_step_size, z_step_size) + self.deadzone = deadzone + self.joystick = None + self.intervention_flag = False + + def start(self): + """Initialize pygame and the gamepad.""" + import pygame + + pygame.init() + pygame.joystick.init() + + if pygame.joystick.get_count() == 0: + logging.error("No gamepad detected. Please connect a gamepad and try again.") + self.running = False + return + + self.joystick = pygame.joystick.Joystick(0) + self.joystick.init() + logging.info(f"Initialized gamepad: {self.joystick.get_name()}") + + print("Gamepad controls:") + print(" Left analog stick: Move in X-Y plane") + print(" Right analog stick (vertical): Move in Z axis") + print(" B/Circle button: Exit") + print(" Y/Triangle button: End episode with SUCCESS") + print(" A/Cross button: End episode with FAILURE") + print(" X/Square button: Rerecord episode") + + def stop(self): + """Clean up pygame resources.""" + import pygame + + if pygame.joystick.get_init(): + if self.joystick: + self.joystick.quit() + pygame.joystick.quit() + pygame.quit() + + def update(self): + """Process pygame events to get fresh gamepad readings.""" + import pygame + + for event in pygame.event.get(): + if event.type == pygame.JOYBUTTONDOWN: + if event.button == 3: + self.episode_end_status = "success" + # A button (1) for failure + elif event.button == 1: + self.episode_end_status = "failure" + # X button (0) for rerecord + elif event.button == 0: + self.episode_end_status = "rerecord_episode" + + # RB button (6) for closing gripper + elif event.button == 6: + self.close_gripper_command = True + + # LT button (7) for opening gripper + elif event.button == 7: + self.open_gripper_command = True + + # Reset episode status on button release + elif event.type == pygame.JOYBUTTONUP: + if event.button in [0, 2, 3]: + self.episode_end_status = None + + elif event.button == 6: + self.close_gripper_command = False + + elif event.button == 7: + self.open_gripper_command = False + + # Check for RB button (typically button 5) for intervention flag + if self.joystick.get_button(5): + self.intervention_flag = True + else: + self.intervention_flag = False + + def get_deltas(self): + """Get the current movement deltas from gamepad state.""" + import pygame + + try: + # Read joystick axes + # Left stick X and Y (typically axes 0 and 1) + x_input = self.joystick.get_axis(0) # Left/Right + y_input = self.joystick.get_axis(1) # Up/Down (often inverted) + + # Right stick Y (typically axis 3 or 4) + z_input = self.joystick.get_axis(3) # Up/Down for Z + + # Apply deadzone to avoid drift + x_input = 0 if abs(x_input) < self.deadzone else x_input + y_input = 0 if abs(y_input) < self.deadzone else y_input + z_input = 0 if abs(z_input) < self.deadzone else z_input + + # Calculate deltas (note: may need to invert axes depending on controller) + delta_x = -y_input * self.y_step_size # Forward/backward + delta_y = -x_input * self.x_step_size # Left/right + delta_z = -z_input * self.z_step_size # Up/down + + return delta_x, delta_y, delta_z + + except pygame.error: + logging.error("Error reading gamepad. Is it still connected?") + return 0.0, 0.0, 0.0 + + +class GamepadControllerHID(InputController): + """Generate motion deltas from gamepad input using HIDAPI.""" + + def __init__( + self, + x_step_size=1.0, + y_step_size=1.0, + z_step_size=1.0, + deadzone=0.1, + ): + """ + Initialize the HID gamepad controller. + + Args: + step_size: Base movement step size in meters + z_scale: Scaling factor for Z-axis movement + deadzone: Joystick deadzone to prevent drift + """ + super().__init__(x_step_size, y_step_size, z_step_size) + self.deadzone = deadzone + self.device = None + self.device_info = None + + # Movement values (normalized from -1.0 to 1.0) + self.left_x = 0.0 + self.left_y = 0.0 + self.right_x = 0.0 + self.right_y = 0.0 + + # Button states + self.buttons = {} + self.quit_requested = False + self.save_requested = False + + def find_device(self): + """Look for the gamepad device by vendor and product ID.""" + import hid + + devices = hid.enumerate() + for device in devices: + device_name = device["product_string"] + if any(controller in device_name for controller in ["Logitech", "Xbox", "PS4", "PS5"]): + return device + + logging.error( + "No gamepad found, check the connection and the product string in HID to add your gamepad" + ) + return None + + def start(self): + """Connect to the gamepad using HIDAPI.""" + import hid + + self.device_info = self.find_device() + if not self.device_info: + self.running = False + return + + try: + logging.info(f"Connecting to gamepad at path: {self.device_info['path']}") + self.device = hid.device() + self.device.open_path(self.device_info["path"]) + self.device.set_nonblocking(1) + + manufacturer = self.device.get_manufacturer_string() + product = self.device.get_product_string() + logging.info(f"Connected to {manufacturer} {product}") + + logging.info("Gamepad controls (HID mode):") + logging.info(" Left analog stick: Move in X-Y plane") + logging.info(" Right analog stick: Move in Z axis (vertical)") + logging.info(" Button 1/B/Circle: Exit") + logging.info(" Button 2/A/Cross: End episode with SUCCESS") + logging.info(" Button 3/X/Square: End episode with FAILURE") + + except OSError as e: + logging.error(f"Error opening gamepad: {e}") + logging.error("You might need to run this with sudo/admin privileges on some systems") + self.running = False + + def stop(self): + """Close the HID device connection.""" + if self.device: + self.device.close() + self.device = None + + def update(self): + """ + Read and process the latest gamepad data. + Due to an issue with the HIDAPI, we need to read the read the device several times in order to get a stable reading + """ + for _ in range(10): + self._update() + + def _update(self): + """Read and process the latest gamepad data.""" + if not self.device or not self.running: + return + + try: + # Read data from the gamepad + data = self.device.read(64) + # Interpret gamepad data - this will vary by controller model + # These offsets are for the Logitech RumblePad 2 + if data and len(data) >= 8: + # Normalize joystick values from 0-255 to -1.0-1.0 + self.left_x = (data[1] - 128) / 128.0 + self.left_y = (data[2] - 128) / 128.0 + self.right_x = (data[3] - 128) / 128.0 + self.right_y = (data[4] - 128) / 128.0 + + # Apply deadzone + self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x + self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y + self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x + self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y + + # Parse button states (byte 5 in the Logitech RumblePad 2) + buttons = data[5] + + # Check if RB is pressed then the intervention flag should be set + self.intervention_flag = data[6] in [2, 6, 10, 14] + + # Check if RT is pressed + self.open_gripper_command = data[6] in [8, 10, 12] + + # Check if LT is pressed + self.close_gripper_command = data[6] in [4, 6, 12] + + # Check if Y/Triangle button (bit 7) is pressed for saving + # Check if X/Square button (bit 5) is pressed for failure + # Check if A/Cross button (bit 4) is pressed for rerecording + if buttons & 1 << 7: + self.episode_end_status = "success" + elif buttons & 1 << 5: + self.episode_end_status = "failure" + elif buttons & 1 << 4: + self.episode_end_status = "rerecord_episode" + else: + self.episode_end_status = None + + except OSError as e: + logging.error(f"Error reading from gamepad: {e}") + + def get_deltas(self): + """Get the current movement deltas from gamepad state.""" + # Calculate deltas - invert as needed based on controller orientation + delta_x = -self.left_y * self.x_step_size # Forward/backward + delta_y = -self.left_x * self.y_step_size # Left/right + delta_z = -self.right_y * self.z_step_size # Up/down + + return delta_x, delta_y, delta_z + + def should_quit(self): + """Return True if quit button was pressed.""" + return self.quit_requested + + def should_save(self): + """Return True if save button was pressed.""" + return self.save_requested diff --git a/lerobot/common/teleoperators/gamepad/teleop_gamepad.py b/lerobot/common/teleoperators/gamepad/teleop_gamepad.py new file mode 100644 index 0000000000..98a0647e21 --- /dev/null +++ b/lerobot/common/teleoperators/gamepad/teleop_gamepad.py @@ -0,0 +1,138 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from enum import IntEnum +from typing import Any + +import numpy as np + +from ..teleoperator import Teleoperator +from .configuration_gamepad import GamepadTeleopConfig + + +class GripperAction(IntEnum): + CLOSE = 0 + STAY = 1 + OPEN = 2 + + +gripper_action_map = { + "close": GripperAction.CLOSE.value, + "open": GripperAction.OPEN.value, + "stay": GripperAction.STAY.value, +} + + +class GamepadTeleop(Teleoperator): + """ + Teleop class to use gamepad inputs for control. + """ + + config_class = GamepadTeleopConfig + name = "gamepad" + + def __init__(self, config: GamepadTeleopConfig): + super().__init__(config) + self.config = config + self.robot_type = config.type + + self.gamepad = None + + @property + def action_features(self) -> dict: + if self.config.use_gripper: + return { + "dtype": "float32", + "shape": (4,), + "names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3}, + } + else: + return { + "dtype": "float32", + "shape": (3,), + "names": {"delta_x": 0, "delta_y": 1, "delta_z": 2}, + } + + @property + def feedback_features(self) -> dict: + return {} + + def connect(self) -> None: + # use HidApi for macos + if sys.platform == "darwin": + # NOTE: On macOS, pygame doesn’t reliably detect input from some controllers so we fall back to hidapi + from .gamepad_utils import GamepadControllerHID as Gamepad + else: + from .gamepad_utils import GamepadController as Gamepad + + self.gamepad = Gamepad() + self.gamepad.start() + + def get_action(self) -> dict[str, Any]: + # Update the controller to get fresh inputs + self.gamepad.update() + + # Get movement deltas from the controller + delta_x, delta_y, delta_z = self.gamepad.get_deltas() + + # Create action from gamepad input + gamepad_action = np.array([delta_x, delta_y, delta_z], dtype=np.float32) + + action_dict = { + "delta_x": gamepad_action[0], + "delta_y": gamepad_action[1], + "delta_z": gamepad_action[2], + } + + # Default gripper action is to stay + gripper_action = GripperAction.STAY.value + if self.config.use_gripper: + gripper_command = self.gamepad.gripper_command() + gripper_action = gripper_action_map[gripper_command] + action_dict["gripper"] = gripper_action + + return action_dict + + def disconnect(self) -> None: + """Disconnect from the gamepad.""" + if self.gamepad is not None: + self.gamepad.stop() + self.gamepad = None + + def is_connected(self) -> bool: + """Check if gamepad is connected.""" + return self.gamepad is not None + + def calibrate(self) -> None: + """Calibrate the gamepad.""" + # No calibration needed for gamepad + pass + + def is_calibrated(self) -> bool: + """Check if gamepad is calibrated.""" + # Gamepad doesn't require calibration + return True + + def configure(self) -> None: + """Configure the gamepad.""" + # No additional configuration needed + pass + + def send_feedback(self, feedback: dict) -> None: + """Send feedback to the gamepad.""" + # Gamepad doesn't support feedback + pass diff --git a/lerobot/common/teleoperators/so100_leader/so100_leader.py b/lerobot/common/teleoperators/so100_leader/so100_leader.py index 900346ad55..59b083e3fd 100644 --- a/lerobot/common/teleoperators/so100_leader/so100_leader.py +++ b/lerobot/common/teleoperators/so100_leader/so100_leader.py @@ -112,7 +112,7 @@ def calibrate(self) -> None: self.bus.write_calibration(self.calibration) self._save_calibration() - logger.info(f"Calibration saved to {self.calibration_fpath}") + print(f"Calibration saved to {self.calibration_fpath}") def configure(self) -> None: self.bus.disable_torque() diff --git a/lerobot/common/teleoperators/so101_leader/config_so101_leader.py b/lerobot/common/teleoperators/so101_leader/config_so101_leader.py index 5f2e110da1..8d91c32dfe 100644 --- a/lerobot/common/teleoperators/so101_leader/config_so101_leader.py +++ b/lerobot/common/teleoperators/so101_leader/config_so101_leader.py @@ -24,3 +24,5 @@ class SO101LeaderConfig(TeleoperatorConfig): # Port to connect to the arm port: str + + use_degrees: bool = False diff --git a/lerobot/common/teleoperators/so101_leader/so101_leader.py b/lerobot/common/teleoperators/so101_leader/so101_leader.py index 34ad31dafe..80ddfbb1d6 100644 --- a/lerobot/common/teleoperators/so101_leader/so101_leader.py +++ b/lerobot/common/teleoperators/so101_leader/so101_leader.py @@ -41,14 +41,15 @@ class SO101Leader(Teleoperator): def __init__(self, config: SO101LeaderConfig): super().__init__(config) self.config = config + norm_mode_body = MotorNormMode.DEGREES if config.use_degrees else MotorNormMode.RANGE_M100_100 self.bus = FeetechMotorsBus( port=self.config.port, motors={ - "shoulder_pan": Motor(1, "sts3215", MotorNormMode.RANGE_M100_100), - "shoulder_lift": Motor(2, "sts3215", MotorNormMode.RANGE_M100_100), - "elbow_flex": Motor(3, "sts3215", MotorNormMode.RANGE_M100_100), - "wrist_flex": Motor(4, "sts3215", MotorNormMode.RANGE_M100_100), - "wrist_roll": Motor(5, "sts3215", MotorNormMode.RANGE_M100_100), + "shoulder_pan": Motor(1, "sts3215", norm_mode_body), + "shoulder_lift": Motor(2, "sts3215", norm_mode_body), + "elbow_flex": Motor(3, "sts3215", norm_mode_body), + "wrist_flex": Motor(4, "sts3215", norm_mode_body), + "wrist_roll": Motor(5, "sts3215", norm_mode_body), "gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100), }, calibration=self.calibration, @@ -108,7 +109,7 @@ def calibrate(self) -> None: self.bus.write_calibration(self.calibration) self._save_calibration() - logger.info(f"Calibration saved to {self.calibration_fpath}") + print(f"Calibration saved to {self.calibration_fpath}") def configure(self) -> None: self.bus.disable_torque() diff --git a/lerobot/common/teleoperators/teleoperator.py b/lerobot/common/teleoperators/teleoperator.py index d8715a5524..a385173120 100644 --- a/lerobot/common/teleoperators/teleoperator.py +++ b/lerobot/common/teleoperators/teleoperator.py @@ -25,7 +25,16 @@ class Teleoperator(abc.ABC): - """The main LeRobot class for implementing teleoperation devices.""" + """ + The base abstract class for all LeRobot-compatible teleoperation devices. + + This class provides a standardized interface for interacting with physical teleoperators. + Subclasses must implement all abstract methods and properties to be usable. + + Attributes: + config_class (RobotConfig): The expected configuration class for this teleoperator. + name (str): The unique name used to identify this teleoperator type. + """ # Set these in ALL subclasses config_class: TeleoperatorConfig @@ -50,58 +59,122 @@ def __str__(self) -> str: @property @abc.abstractmethod def action_features(self) -> dict: + """ + A dictionary describing the structure and types of the actions produced by the teleoperator. Its + structure (keys) should match the structure of what is returned by :pymeth:`get_action`. Values for + the dict should be the type of the value if it's a simple value, e.g. `float` for single + proprioceptive value (a joint's goal position/velocity) + + Note: this property should be able to be called regardless of whether the robot is connected or not. + """ pass @property @abc.abstractmethod def feedback_features(self) -> dict: + """ + A dictionary describing the structure and types of the feedback actions expected by the robot. Its + structure (keys) should match the structure of what is passed to :pymeth:`send_feedback`. Values for + the dict should be the type of the value if it's a simple value, e.g. `float` for single + proprioceptive value (a joint's goal position/velocity) + + Note: this property should be able to be called regardless of whether the robot is connected or not. + """ pass @property @abc.abstractmethod def is_connected(self) -> bool: + """ + Whether the teleoperator is currently connected or not. If `False`, calling :pymeth:`get_action` + or :pymeth:`send_feedback` should raise an error. + """ pass @abc.abstractmethod def connect(self, calibrate: bool = True) -> None: - """Connects to the teleoperator.""" + """ + Establish communication with the teleoperator. + + Args: + calibrate (bool): If True, automatically calibrate the teleoperator after connecting if it's not + calibrated or needs calibration (this is hardware-dependant). + """ pass @property @abc.abstractmethod def is_calibrated(self) -> bool: + """Whether the teleoperator is currently calibrated or not. Should be always `True` if not applicable""" pass @abc.abstractmethod def calibrate(self) -> None: - """Calibrates the teleoperator.""" + """ + Calibrate the teleoperator if applicable. If not, this should be a no-op. + + This method should collect any necessary data (e.g., motor offsets) and update the + :pyattr:`calibration` dictionary accordingly. + """ pass def _load_calibration(self, fpath: Path | None = None) -> None: + """ + Helper to load calibration data from the specified file. + + Args: + fpath (Path | None): Optional path to the calibration file. Defaults to `self.calibration_fpath`. + """ fpath = self.calibration_fpath if fpath is None else fpath with open(fpath) as f, draccus.config_type("json"): self.calibration = draccus.load(dict[str, MotorCalibration], f) def _save_calibration(self, fpath: Path | None = None) -> None: + """ + Helper to save calibration data to the specified file. + + Args: + fpath (Path | None): Optional path to save the calibration file. Defaults to `self.calibration_fpath`. + """ fpath = self.calibration_fpath if fpath is None else fpath with open(fpath, "w") as f, draccus.config_type("json"): draccus.dump(self.calibration, f, indent=4) @abc.abstractmethod def configure(self) -> None: + """ + Apply any one-time or runtime configuration to the teleoperator. + This may include setting motor parameters, control modes, or initial state. + """ pass @abc.abstractmethod def get_action(self) -> dict[str, Any]: - """Gets the action to send to a teleoperator.""" + """ + Retrieve the current action from the teleoperator. + + Returns: + dict[str, Any]: A flat dictionary representing the teleoperator's current actions. Its + structure should match :pymeth:`observation_features`. + """ pass @abc.abstractmethod def send_feedback(self, feedback: dict[str, Any]) -> None: - """Sends feedback captured from a robot to the teleoperator.""" + """ + Send a feedback action command to the teleoperator. + + Args: + feedback (dict[str, Any]): Dictionary representing the desired feedback. Its structure should match + :pymeth:`feedback_features`. + + Returns: + dict[str, Any]: The action actually sent to the motors potentially clipped or modified, e.g. by + safety limits on velocity. + """ pass @abc.abstractmethod def disconnect(self) -> None: - """Disconnects from the teleoperator.""" + """Disconnect from the teleoperator and perform any necessary cleanup.""" pass diff --git a/lerobot/common/teleoperators/utils.py b/lerobot/common/teleoperators/utils.py index 4942084ac7..d7b7bcf0e6 100644 --- a/lerobot/common/teleoperators/utils.py +++ b/lerobot/common/teleoperators/utils.py @@ -45,5 +45,9 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator: from tests.mocks.mock_teleop import MockTeleop return MockTeleop(config) + elif config.type == "gamepad": + from .gamepad.teleop_gamepad import GamepadTeleop + + return GamepadTeleop(config) else: raise ValueError(config.type) diff --git a/lerobot/common/transport/services.proto b/lerobot/common/transport/services.proto new file mode 100644 index 0000000000..29d00005a6 --- /dev/null +++ b/lerobot/common/transport/services.proto @@ -0,0 +1,59 @@ +// Copyright 2024 The HuggingFace Inc. team. +// All rights reserved. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// To generate a classes for transport part (services_pb2.py and services_pb2_grpc.py) use the following command: +// +// python -m grpc_tools.protoc -I . --python_out=. --grpc_python_out=. lerobot/common/transport/services.proto +// +// The command should be launched from the root of the project. + +syntax = "proto3"; + +package transport; + +// LearnerService: the Actor calls this to push transitions. +// The Learner implements this service. +service LearnerService { + // Actor -> Learner to store transitions + rpc StreamParameters(Empty) returns (stream Parameters); + rpc SendTransitions(stream Transition) returns (Empty); + rpc SendInteractions(stream InteractionMessage) returns (Empty); + rpc Ready(Empty) returns (Empty); +} + +enum TransferState { + TRANSFER_UNKNOWN = 0; + TRANSFER_BEGIN = 1; + TRANSFER_MIDDLE = 2; + TRANSFER_END = 3; +} + +// Messages +message Transition { + TransferState transfer_state = 1; + bytes data = 2; +} + +message Parameters { + TransferState transfer_state = 1; + bytes data = 2; +} + +message InteractionMessage { + TransferState transfer_state = 1; + bytes data = 2; +} + +message Empty {} diff --git a/lerobot/common/transport/services_pb2.py b/lerobot/common/transport/services_pb2.py new file mode 100644 index 0000000000..727beb60de --- /dev/null +++ b/lerobot/common/transport/services_pb2.py @@ -0,0 +1,45 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: lerobot/common/transport/services.proto +# Protobuf Python Version: 5.29.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 29, + 0, + '', + 'lerobot/common/transport/services.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\'lerobot/common/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'lerobot.common.transport.services_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_TRANSFERSTATE']._serialized_start=305 + _globals['_TRANSFERSTATE']._serialized_end=401 + _globals['_TRANSITION']._serialized_start=54 + _globals['_TRANSITION']._serialized_end=130 + _globals['_PARAMETERS']._serialized_start=132 + _globals['_PARAMETERS']._serialized_end=208 + _globals['_INTERACTIONMESSAGE']._serialized_start=210 + _globals['_INTERACTIONMESSAGE']._serialized_end=294 + _globals['_EMPTY']._serialized_start=296 + _globals['_EMPTY']._serialized_end=303 + _globals['_LEARNERSERVICE']._serialized_start=404 + _globals['_LEARNERSERVICE']._serialized_end=661 +# @@protoc_insertion_point(module_scope) diff --git a/lerobot/common/transport/services_pb2_grpc.py b/lerobot/common/transport/services_pb2_grpc.py new file mode 100644 index 0000000000..5a7a924fd2 --- /dev/null +++ b/lerobot/common/transport/services_pb2_grpc.py @@ -0,0 +1,233 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + +from lerobot.common.transport import services_pb2 as lerobot_dot_common_dot_transport_dot_services__pb2 + +GRPC_GENERATED_VERSION = '1.71.0' +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in lerobot/common/transport/services_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + ) + + +class LearnerServiceStub: + """LearnerService: the Actor calls this to push transitions. + The Learner implements this service. + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.StreamParameters = channel.unary_stream( + '/transport.LearnerService/StreamParameters', + request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, + response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Parameters.FromString, + _registered_method=True) + self.SendTransitions = channel.stream_unary( + '/transport.LearnerService/SendTransitions', + request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Transition.SerializeToString, + response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + _registered_method=True) + self.SendInteractions = channel.stream_unary( + '/transport.LearnerService/SendInteractions', + request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString, + response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + _registered_method=True) + self.Ready = channel.unary_unary( + '/transport.LearnerService/Ready', + request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, + response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + _registered_method=True) + + +class LearnerServiceServicer: + """LearnerService: the Actor calls this to push transitions. + The Learner implements this service. + """ + + def StreamParameters(self, request, context): + """Actor -> Learner to store transitions + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendTransitions(self, request_iterator, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendInteractions(self, request_iterator, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Ready(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_LearnerServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'StreamParameters': grpc.unary_stream_rpc_method_handler( + servicer.StreamParameters, + request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Parameters.SerializeToString, + ), + 'SendTransitions': grpc.stream_unary_rpc_method_handler( + servicer.SendTransitions, + request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Transition.FromString, + response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, + ), + 'SendInteractions': grpc.stream_unary_rpc_method_handler( + servicer.SendInteractions, + request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.FromString, + response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, + ), + 'Ready': grpc.unary_unary_rpc_method_handler( + servicer.Ready, + request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'transport.LearnerService', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('transport.LearnerService', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class LearnerService: + """LearnerService: the Actor calls this to push transitions. + The Learner implements this service. + """ + + @staticmethod + def StreamParameters(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream( + request, + target, + '/transport.LearnerService/StreamParameters', + lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, + lerobot_dot_common_dot_transport_dot_services__pb2.Parameters.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SendTransitions(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_unary( + request_iterator, + target, + '/transport.LearnerService/SendTransitions', + lerobot_dot_common_dot_transport_dot_services__pb2.Transition.SerializeToString, + lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SendInteractions(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_unary( + request_iterator, + target, + '/transport.LearnerService/SendInteractions', + lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString, + lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def Ready(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/transport.LearnerService/Ready', + lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, + lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/lerobot/common/transport/utils.py b/lerobot/common/transport/utils.py new file mode 100644 index 0000000000..774721fc6d --- /dev/null +++ b/lerobot/common/transport/utils.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import logging +import pickle # nosec B403: Safe usage for internal serialization only +from multiprocessing import Event, Queue +from typing import Any + +import torch + +from lerobot.common.transport import services_pb2 +from lerobot.common.utils.transition import Transition + +CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB + + +def bytes_buffer_size(buffer: io.BytesIO) -> int: + buffer.seek(0, io.SEEK_END) + result = buffer.tell() + buffer.seek(0) + return result + + +def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True): + buffer = io.BytesIO(buffer) + size_in_bytes = bytes_buffer_size(buffer) + + sent_bytes = 0 + + logging_method = logging.info if not silent else logging.debug + + logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with") + + while sent_bytes < size_in_bytes: + transfer_state = services_pb2.TransferState.TRANSFER_MIDDLE + + if sent_bytes + CHUNK_SIZE >= size_in_bytes: + transfer_state = services_pb2.TransferState.TRANSFER_END + elif sent_bytes == 0: + transfer_state = services_pb2.TransferState.TRANSFER_BEGIN + + size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes) + chunk = buffer.read(size_to_read) + + yield message_class(transfer_state=transfer_state, data=chunk) + sent_bytes += size_to_read + logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}") + + logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB") + + +def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""): # type: ignore + bytes_buffer = io.BytesIO() + step = 0 + + logging.info(f"{log_prefix} Starting receiver") + for item in iterator: + logging.debug(f"{log_prefix} Received item") + if shutdown_event.is_set(): + logging.info(f"{log_prefix} Shutting down receiver") + return + + if item.transfer_state == services_pb2.TransferState.TRANSFER_BEGIN: + bytes_buffer.seek(0) + bytes_buffer.truncate(0) + bytes_buffer.write(item.data) + logging.debug(f"{log_prefix} Received data at step 0") + step = 0 + elif item.transfer_state == services_pb2.TransferState.TRANSFER_MIDDLE: + bytes_buffer.write(item.data) + step += 1 + logging.debug(f"{log_prefix} Received data at step {step}") + elif item.transfer_state == services_pb2.TransferState.TRANSFER_END: + bytes_buffer.write(item.data) + logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}") + + queue.put(bytes_buffer.getvalue()) + + bytes_buffer.seek(0) + bytes_buffer.truncate(0) + step = 0 + + logging.debug(f"{log_prefix} Queue updated") + else: + logging.warning(f"{log_prefix} Received unknown transfer state {item.transfer_state}") + raise ValueError(f"Received unknown transfer state {item.transfer_state}") + + +def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes: + """Convert model state dict to flat array for transmission""" + buffer = io.BytesIO() + + torch.save(state_dict, buffer) + + return buffer.getvalue() + + +def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]: + buffer = io.BytesIO(buffer) + buffer.seek(0) + return torch.load(buffer, weights_only=True) + + +def python_object_to_bytes(python_object: Any) -> bytes: + return pickle.dumps(python_object) + + +def bytes_to_python_object(buffer: bytes) -> Any: + buffer = io.BytesIO(buffer) + buffer.seek(0) + obj = pickle.load(buffer) # nosec B301: Safe usage of pickle.load + # Add validation checks here + return obj + + +def bytes_to_transitions(buffer: bytes) -> list[Transition]: + buffer = io.BytesIO(buffer) + buffer.seek(0) + transitions = torch.load(buffer, weights_only=True) + return transitions + + +def transitions_to_bytes(transitions: list[Transition]) -> bytes: + buffer = io.BytesIO() + torch.save(transitions, buffer) + return buffer.getvalue() diff --git a/lerobot/common/utils/buffer.py b/lerobot/common/utils/buffer.py new file mode 100644 index 0000000000..9ae231ad92 --- /dev/null +++ b/lerobot/common/utils/buffer.py @@ -0,0 +1,841 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from contextlib import suppress +from typing import Callable, Sequence, TypedDict + +import torch +import torch.nn.functional as F # noqa: N812 +from tqdm import tqdm + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.utils.transition import Transition + + +class BatchTransition(TypedDict): + state: dict[str, torch.Tensor] + action: torch.Tensor + reward: torch.Tensor + next_state: dict[str, torch.Tensor] + done: torch.Tensor + truncated: torch.Tensor + complementary_info: dict[str, torch.Tensor | float | int] | None = None + + +def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor: + """ + Perform a per-image random crop over a batch of images in a vectorized way. + (Same as shown previously.) + """ + B, C, H, W = images.shape # noqa: N806 + crop_h, crop_w = output_size + + if crop_h > H or crop_w > W: + raise ValueError( + f"Requested crop size ({crop_h}, {crop_w}) is bigger than the image size ({H}, {W})." + ) + + tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device) + lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device) + + rows = torch.arange(crop_h, device=images.device).unsqueeze(0) + tops.unsqueeze(1) + cols = torch.arange(crop_w, device=images.device).unsqueeze(0) + lefts.unsqueeze(1) + + rows = rows.unsqueeze(2).expand(-1, -1, crop_w) # (B, crop_h, crop_w) + cols = cols.unsqueeze(1).expand(-1, crop_h, -1) # (B, crop_h, crop_w) + + images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C) + + # Gather pixels + cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :] + # cropped_hwcn => (B, crop_h, crop_w, C) + + cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w) + return cropped + + +def random_shift(images: torch.Tensor, pad: int = 4): + """Vectorized random shift, imgs: (B,C,H,W), pad: #pixels""" + _, _, h, w = images.shape + images = F.pad(input=images, pad=(pad, pad, pad, pad), mode="replicate") + return random_crop_vectorized(images=images, output_size=(h, w)) + + +class ReplayBuffer: + def __init__( + self, + capacity: int, + device: str = "cuda:0", + state_keys: Sequence[str] | None = None, + image_augmentation_function: Callable | None = None, + use_drq: bool = True, + storage_device: str = "cpu", + optimize_memory: bool = False, + ): + """ + Replay buffer for storing transitions. + It will allocate tensors on the specified device, when the first transition is added. + NOTE: If you encounter memory issues, you can try to use the `optimize_memory` flag to save memory or + and use the `storage_device` flag to store the buffer on a different device. + Args: + capacity (int): Maximum number of transitions to store in the buffer. + device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu"). + state_keys (List[str]): The list of keys that appear in `state` and `next_state`. + image_augmentation_function (Optional[Callable]): A function that takes a batch of images + and returns a batch of augmented images. If None, a default augmentation function is used. + use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer. + storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored. + Using "cpu" can help save GPU memory. + optimize_memory (bool): If True, optimizes memory by not storing duplicate next_states when + they can be derived from states. This is useful for large datasets where next_state[i] = state[i+1]. + """ + if capacity <= 0: + raise ValueError("Capacity must be greater than 0.") + + self.capacity = capacity + self.device = device + self.storage_device = storage_device + self.position = 0 + self.size = 0 + self.initialized = False + self.optimize_memory = optimize_memory + + # Track episode boundaries for memory optimization + self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device) + + # If no state_keys provided, default to an empty list + self.state_keys = state_keys if state_keys is not None else [] + + self.image_augmentation_function = image_augmentation_function + + if image_augmentation_function is None: + base_function = functools.partial(random_shift, pad=4) + self.image_augmentation_function = torch.compile(base_function) + self.use_drq = use_drq + + def _initialize_storage( + self, + state: dict[str, torch.Tensor], + action: torch.Tensor, + complementary_info: dict[str, torch.Tensor] | None = None, + ): + """Initialize the storage tensors based on the first transition.""" + # Determine shapes from the first transition + state_shapes = {key: val.squeeze(0).shape for key, val in state.items()} + action_shape = action.squeeze(0).shape + + # Pre-allocate tensors for storage + self.states = { + key: torch.empty((self.capacity, *shape), device=self.storage_device) + for key, shape in state_shapes.items() + } + self.actions = torch.empty((self.capacity, *action_shape), device=self.storage_device) + self.rewards = torch.empty((self.capacity,), device=self.storage_device) + + if not self.optimize_memory: + # Standard approach: store states and next_states separately + self.next_states = { + key: torch.empty((self.capacity, *shape), device=self.storage_device) + for key, shape in state_shapes.items() + } + else: + # Memory-optimized approach: don't allocate next_states buffer + # Just create a reference to states for consistent API + self.next_states = self.states # Just a reference for API consistency + + self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) + self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) + + # Initialize storage for complementary_info + self.has_complementary_info = complementary_info is not None + self.complementary_info_keys = [] + self.complementary_info = {} + + if self.has_complementary_info: + self.complementary_info_keys = list(complementary_info.keys()) + # Pre-allocate tensors for each key in complementary_info + for key, value in complementary_info.items(): + if isinstance(value, torch.Tensor): + value_shape = value.squeeze(0).shape + self.complementary_info[key] = torch.empty( + (self.capacity, *value_shape), device=self.storage_device + ) + elif isinstance(value, (int, float)): + # Handle scalar values similar to reward + self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device) + else: + raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]") + + self.initialized = True + + def __len__(self): + return self.size + + def add( + self, + state: dict[str, torch.Tensor], + action: torch.Tensor, + reward: float, + next_state: dict[str, torch.Tensor], + done: bool, + truncated: bool, + complementary_info: dict[str, torch.Tensor] | None = None, + ): + """Saves a transition, ensuring tensors are stored on the designated storage device.""" + # Initialize storage if this is the first transition + if not self.initialized: + self._initialize_storage(state=state, action=action, complementary_info=complementary_info) + + # Store the transition in pre-allocated tensors + for key in self.states: + self.states[key][self.position].copy_(state[key].squeeze(dim=0)) + + if not self.optimize_memory: + # Only store next_states if not optimizing memory + self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0)) + + self.actions[self.position].copy_(action.squeeze(dim=0)) + self.rewards[self.position] = reward + self.dones[self.position] = done + self.truncateds[self.position] = truncated + + # Handle complementary_info if provided and storage is initialized + if complementary_info is not None and self.has_complementary_info: + # Store the complementary_info + for key in self.complementary_info_keys: + if key in complementary_info: + value = complementary_info[key] + if isinstance(value, torch.Tensor): + self.complementary_info[key][self.position].copy_(value.squeeze(dim=0)) + elif isinstance(value, (int, float)): + self.complementary_info[key][self.position] = value + + self.position = (self.position + 1) % self.capacity + self.size = min(self.size + 1, self.capacity) + + def sample(self, batch_size: int) -> BatchTransition: + """Sample a random batch of transitions and collate them into batched tensors.""" + if not self.initialized: + raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.") + + batch_size = min(batch_size, self.size) + high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size + + # Random indices for sampling - create on the same device as storage + idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device) + + # Identify image keys that need augmentation + image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else [] + + # Create batched state and next_state + batch_state = {} + batch_next_state = {} + + # First pass: load all state tensors to target device + for key in self.states: + batch_state[key] = self.states[key][idx].to(self.device) + + if not self.optimize_memory: + # Standard approach - load next_states directly + batch_next_state[key] = self.next_states[key][idx].to(self.device) + else: + # Memory-optimized approach - get next_state from the next index + next_idx = (idx + 1) % self.capacity + batch_next_state[key] = self.states[key][next_idx].to(self.device) + + # Apply image augmentation in a batched way if needed + if self.use_drq and image_keys: + # Concatenate all images from state and next_state + all_images = [] + for key in image_keys: + all_images.append(batch_state[key]) + all_images.append(batch_next_state[key]) + + # Optimization: Batch all images and apply augmentation once + all_images_tensor = torch.cat(all_images, dim=0) + augmented_images = self.image_augmentation_function(all_images_tensor) + + # Split the augmented images back to their sources + for i, key in enumerate(image_keys): + # Calculate offsets for the current image key: + # For each key, we have 2*batch_size images (batch_size for states, batch_size for next_states) + # States start at index i*2*batch_size and take up batch_size slots + batch_state[key] = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size] + # Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots + batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size] + + # Sample other tensors + batch_actions = self.actions[idx].to(self.device) + batch_rewards = self.rewards[idx].to(self.device) + batch_dones = self.dones[idx].to(self.device).float() + batch_truncateds = self.truncateds[idx].to(self.device).float() + + # Sample complementary_info if available + batch_complementary_info = None + if self.has_complementary_info: + batch_complementary_info = {} + for key in self.complementary_info_keys: + batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device) + + return BatchTransition( + state=batch_state, + action=batch_actions, + reward=batch_rewards, + next_state=batch_next_state, + done=batch_dones, + truncated=batch_truncateds, + complementary_info=batch_complementary_info, + ) + + def get_iterator( + self, + batch_size: int, + async_prefetch: bool = True, + queue_size: int = 2, + ): + """ + Creates an infinite iterator that yields batches of transitions. + Will automatically restart when internal iterator is exhausted. + + Args: + batch_size (int): Size of batches to sample + async_prefetch (bool): Whether to use asynchronous prefetching with threads (default: True) + queue_size (int): Number of batches to prefetch (default: 2) + + Yields: + BatchTransition: Batched transitions + """ + while True: # Create an infinite loop + if async_prefetch: + # Get the standard iterator + iterator = self._get_async_iterator(queue_size=queue_size, batch_size=batch_size) + else: + iterator = self._get_naive_iterator(batch_size=batch_size, queue_size=queue_size) + + # Yield all items from the iterator + with suppress(StopIteration): + yield from iterator + + def _get_async_iterator(self, batch_size: int, queue_size: int = 2): + """ + Create an iterator that continuously yields prefetched batches in a + background thread. The design is intentionally simple and avoids busy + waiting / complex state management. + + Args: + batch_size (int): Size of batches to sample. + queue_size (int): Maximum number of prefetched batches to keep in + memory. + + Yields: + BatchTransition: A batch sampled from the replay buffer. + """ + import queue + import threading + + data_queue: queue.Queue = queue.Queue(maxsize=queue_size) + shutdown_event = threading.Event() + + def producer() -> None: + """Continuously put sampled batches into the queue until shutdown.""" + while not shutdown_event.is_set(): + try: + batch = self.sample(batch_size) + # The timeout ensures the thread unblocks if the queue is full + # and the shutdown event gets set meanwhile. + data_queue.put(batch, block=True, timeout=0.5) + except queue.Full: + # Queue is full – loop again (will re-check shutdown_event) + continue + except Exception: + # Surface any unexpected error and terminate the producer. + shutdown_event.set() + + producer_thread = threading.Thread(target=producer, daemon=True) + producer_thread.start() + + try: + while not shutdown_event.is_set(): + try: + yield data_queue.get(block=True) + except Exception: + # If the producer already set the shutdown flag we exit. + if shutdown_event.is_set(): + break + finally: + shutdown_event.set() + # Drain the queue quickly to help the thread exit if it's blocked on `put`. + while not data_queue.empty(): + _ = data_queue.get_nowait() + # Give the producer thread a bit of time to finish. + producer_thread.join(timeout=1.0) + + def _get_naive_iterator(self, batch_size: int, queue_size: int = 2): + """ + Creates a simple non-threaded iterator that yields batches. + + Args: + batch_size (int): Size of batches to sample + queue_size (int): Number of initial batches to prefetch + + Yields: + BatchTransition: Batch transitions + """ + import collections + + queue = collections.deque() + + def enqueue(n): + for _ in range(n): + data = self.sample(batch_size) + queue.append(data) + + enqueue(queue_size) + while queue: + yield queue.popleft() + enqueue(1) + + @classmethod + def from_lerobot_dataset( + cls, + lerobot_dataset: LeRobotDataset, + device: str = "cuda:0", + state_keys: Sequence[str] | None = None, + capacity: int | None = None, + image_augmentation_function: Callable | None = None, + use_drq: bool = True, + storage_device: str = "cpu", + optimize_memory: bool = False, + ) -> "ReplayBuffer": + """ + Convert a LeRobotDataset into a ReplayBuffer. + + Args: + lerobot_dataset (LeRobotDataset): The dataset to convert. + device (str): The device for sampling tensors. Defaults to "cuda:0". + state_keys (Sequence[str] | None): The list of keys that appear in `state` and `next_state`. + capacity (int | None): Buffer capacity. If None, uses dataset length. + action_mask (Sequence[int] | None): Indices of action dimensions to keep. + image_augmentation_function (Callable | None): Function for image augmentation. + If None, uses default random shift with pad=4. + use_drq (bool): Whether to use DrQ image augmentation when sampling. + storage_device (str): Device for storing tensor data. Using "cpu" saves GPU memory. + optimize_memory (bool): If True, reduces memory usage by not duplicating state data. + + Returns: + ReplayBuffer: The replay buffer with dataset transitions. + """ + if capacity is None: + capacity = len(lerobot_dataset) + + if capacity < len(lerobot_dataset): + raise ValueError( + "The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset." + ) + + # Create replay buffer with image augmentation and DrQ settings + replay_buffer = cls( + capacity=capacity, + device=device, + state_keys=state_keys, + image_augmentation_function=image_augmentation_function, + use_drq=use_drq, + storage_device=storage_device, + optimize_memory=optimize_memory, + ) + + # Convert dataset to transitions + list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys) + + # Initialize the buffer with the first transition to set up storage tensors + if list_transition: + first_transition = list_transition[0] + first_state = {k: v.to(device) for k, v in first_transition["state"].items()} + first_action = first_transition["action"].to(device) + + # Get complementary info if available + first_complementary_info = None + if ( + "complementary_info" in first_transition + and first_transition["complementary_info"] is not None + ): + first_complementary_info = { + k: v.to(device) for k, v in first_transition["complementary_info"].items() + } + + replay_buffer._initialize_storage( + state=first_state, action=first_action, complementary_info=first_complementary_info + ) + + # Fill the buffer with all transitions + for data in list_transition: + for k, v in data.items(): + if isinstance(v, dict): + for key, tensor in v.items(): + v[key] = tensor.to(storage_device) + elif isinstance(v, torch.Tensor): + data[k] = v.to(storage_device) + + action = data["action"] + + replay_buffer.add( + state=data["state"], + action=action, + reward=data["reward"], + next_state=data["next_state"], + done=data["done"], + truncated=False, # NOTE: Truncation are not supported yet in lerobot dataset + complementary_info=data.get("complementary_info", None), + ) + + return replay_buffer + + def to_lerobot_dataset( + self, + repo_id: str, + fps=1, + root=None, + task_name="from_replay_buffer", + ) -> LeRobotDataset: + """ + Converts all transitions in this ReplayBuffer into a single LeRobotDataset object. + """ + if self.size == 0: + raise ValueError("The replay buffer is empty. Cannot convert to a dataset.") + + # Create features dictionary for the dataset + features = { + "index": {"dtype": "int64", "shape": [1]}, # global index across episodes + "episode_index": {"dtype": "int64", "shape": [1]}, # which episode + "frame_index": {"dtype": "int64", "shape": [1]}, # index inside an episode + "timestamp": {"dtype": "float32", "shape": [1]}, # for now we store dummy + "task_index": {"dtype": "int64", "shape": [1]}, + } + + # Add "action" + sample_action = self.actions[0] + act_info = guess_feature_info(t=sample_action, name="action") + features["action"] = act_info + + # Add "reward" and "done" + features["next.reward"] = {"dtype": "float32", "shape": (1,)} + features["next.done"] = {"dtype": "bool", "shape": (1,)} + + # Add state keys + for key in self.states: + sample_val = self.states[key][0] + f_info = guess_feature_info(t=sample_val, name=key) + features[key] = f_info + + # Add complementary_info keys if available + if self.has_complementary_info: + for key in self.complementary_info_keys: + sample_val = self.complementary_info[key][0] + if isinstance(sample_val, torch.Tensor) and sample_val.ndim == 0: + sample_val = sample_val.unsqueeze(0) + f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}") + features[f"complementary_info.{key}"] = f_info + + # Create an empty LeRobotDataset + lerobot_dataset = LeRobotDataset.create( + repo_id=repo_id, + fps=fps, + root=root, + robot_type=None, + features=features, + use_videos=True, + ) + + # Start writing images if needed + lerobot_dataset.start_image_writer(num_processes=0, num_threads=3) + + # Convert transitions into episodes and frames + episode_index = 0 + lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index=episode_index) + + frame_idx_in_episode = 0 + for idx in range(self.size): + actual_idx = (self.position - self.size + idx) % self.capacity + + frame_dict = {} + + # Fill the data for state keys + for key in self.states: + frame_dict[key] = self.states[key][actual_idx].cpu() + + # Fill action, reward, done + frame_dict["action"] = self.actions[actual_idx].cpu() + frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu() + frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu() + + # Add complementary_info if available + if self.has_complementary_info: + for key in self.complementary_info_keys: + val = self.complementary_info[key][actual_idx] + # Convert tensors to CPU + if isinstance(val, torch.Tensor): + if val.ndim == 0: + val = val.unsqueeze(0) + frame_dict[f"complementary_info.{key}"] = val.cpu() + # Non-tensor values can be used directly + else: + frame_dict[f"complementary_info.{key}"] = val + + # Add to the dataset's buffer + lerobot_dataset.add_frame(frame_dict, task=task_name) + + # Move to next frame + frame_idx_in_episode += 1 + + # If we reached an episode boundary, call save_episode, reset counters + if self.dones[actual_idx] or self.truncateds[actual_idx]: + lerobot_dataset.save_episode() + episode_index += 1 + frame_idx_in_episode = 0 + lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer( + episode_index=episode_index + ) + + # Save any remaining frames in the buffer + if lerobot_dataset.episode_buffer["size"] > 0: + lerobot_dataset.save_episode() + + lerobot_dataset.stop_image_writer() + + return lerobot_dataset + + @staticmethod + def _lerobotdataset_to_transitions( + dataset: LeRobotDataset, + state_keys: Sequence[str] | None = None, + ) -> list[Transition]: + """ + Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions. + + Args: + dataset (LeRobotDataset): + The dataset to convert. Each item in the dataset is expected to have + at least the following keys: + { + "action": ... + "next.reward": ... + "next.done": ... + "episode_index": ... + } + plus whatever your 'state_keys' specify. + + state_keys (Sequence[str] | None): + The dataset keys to include in 'state' and 'next_state'. Their names + will be kept as-is in the output transitions. E.g. + ["observation.state", "observation.environment_state"]. + If None, you must handle or define default keys. + + Returns: + transitions (List[Transition]): + A list of Transition dictionaries with the same length as `dataset`. + """ + if state_keys is None: + raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.") + + transitions = [] + num_frames = len(dataset) + + # Check if the dataset has "next.done" key + sample = dataset[0] + has_done_key = "next.done" in sample + + # Check for complementary_info keys + complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")] + has_complementary_info = len(complementary_info_keys) > 0 + + # If not, we need to infer it from episode boundaries + if not has_done_key: + print("'next.done' key not found in dataset. Inferring from episode boundaries...") + + for i in tqdm(range(num_frames)): + current_sample = dataset[i] + + # ----- 1) Current state ----- + current_state: dict[str, torch.Tensor] = {} + for key in state_keys: + val = current_sample[key] + current_state[key] = val.unsqueeze(0) # Add batch dimension + + # ----- 2) Action ----- + action = current_sample["action"].unsqueeze(0) # Add batch dimension + + # ----- 3) Reward and done ----- + reward = float(current_sample["next.reward"].item()) # ensure float + + # Determine done flag - use next.done if available, otherwise infer from episode boundaries + if has_done_key: + done = bool(current_sample["next.done"].item()) # ensure bool + else: + # If this is the last frame or if next frame is in a different episode, mark as done + done = False + if i == num_frames - 1: + done = True + elif i < num_frames - 1: + next_sample = dataset[i + 1] + if next_sample["episode_index"] != current_sample["episode_index"]: + done = True + + # TODO: (azouitine) Handle truncation (using the same value as done for now) + truncated = done + + # ----- 4) Next state ----- + # If not done and the next sample is in the same episode, we pull the next sample's state. + # Otherwise (done=True or next sample crosses to a new episode), next_state = current_state. + next_state = current_state # default + if not done and (i < num_frames - 1): + next_sample = dataset[i + 1] + if next_sample["episode_index"] == current_sample["episode_index"]: + # Build next_state from the same keys + next_state_data: dict[str, torch.Tensor] = {} + for key in state_keys: + val = next_sample[key] + next_state_data[key] = val.unsqueeze(0) # Add batch dimension + next_state = next_state_data + + # ----- 5) Complementary info (if available) ----- + complementary_info = None + if has_complementary_info: + complementary_info = {} + for key in complementary_info_keys: + # Strip the "complementary_info." prefix to get the actual key + clean_key = key[len("complementary_info.") :] + val = current_sample[key] + # Handle tensor and non-tensor values differently + if isinstance(val, torch.Tensor): + complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension + else: + # TODO: (azouitine) Check if it's necessary to convert to tensor + # For non-tensor values, use directly + complementary_info[clean_key] = val + + # ----- Construct the Transition ----- + transition = Transition( + state=current_state, + action=action, + reward=reward, + next_state=next_state, + done=done, + truncated=truncated, + complementary_info=complementary_info, + ) + transitions.append(transition) + + return transitions + + +# Utility function to guess shapes/dtypes from a tensor +def guess_feature_info(t, name: str): + """ + Return a dictionary with the 'dtype' and 'shape' for a given tensor or scalar value. + If it looks like a 3D (C,H,W) shape, we might consider it an 'image'. + Otherwise default to appropriate dtype for numeric. + """ + + shape = tuple(t.shape) + # Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image' + if len(shape) == 3 and shape[0] in [1, 3]: + return { + "dtype": "image", + "shape": shape, + } + else: + # Otherwise treat as numeric + return { + "dtype": "float32", + "shape": shape, + } + + +def concatenate_batch_transitions( + left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition +) -> BatchTransition: + """ + Concatenates two BatchTransition objects into one. + + This function merges the right BatchTransition into the left one by concatenating + all corresponding tensors along dimension 0. The operation modifies the left_batch_transitions + in place and also returns it. + + Args: + left_batch_transitions (BatchTransition): The first batch to concatenate and the one + that will be modified in place. + right_batch_transition (BatchTransition): The second batch to append to the first one. + + Returns: + BatchTransition: The concatenated batch (same object as left_batch_transitions). + + Warning: + This function modifies the left_batch_transitions object in place. + """ + # Concatenate state fields + left_batch_transitions["state"] = { + key: torch.cat( + [left_batch_transitions["state"][key], right_batch_transition["state"][key]], + dim=0, + ) + for key in left_batch_transitions["state"] + } + + # Concatenate basic fields + left_batch_transitions["action"] = torch.cat( + [left_batch_transitions["action"], right_batch_transition["action"]], dim=0 + ) + left_batch_transitions["reward"] = torch.cat( + [left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0 + ) + + # Concatenate next_state fields + left_batch_transitions["next_state"] = { + key: torch.cat( + [left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], + dim=0, + ) + for key in left_batch_transitions["next_state"] + } + + # Concatenate done and truncated fields + left_batch_transitions["done"] = torch.cat( + [left_batch_transitions["done"], right_batch_transition["done"]], dim=0 + ) + left_batch_transitions["truncated"] = torch.cat( + [left_batch_transitions["truncated"], right_batch_transition["truncated"]], + dim=0, + ) + + # Handle complementary_info + left_info = left_batch_transitions.get("complementary_info") + right_info = right_batch_transition.get("complementary_info") + + # Only process if right_info exists + if right_info is not None: + # Initialize left complementary_info if needed + if left_info is None: + left_batch_transitions["complementary_info"] = right_info + else: + # Concatenate each field + for key in right_info: + if key in left_info: + left_info[key] = torch.cat([left_info[key], right_info[key]], dim=0) + else: + left_info[key] = right_info[key] + + return left_batch_transitions diff --git a/lerobot/common/utils/import_utils.py b/lerobot/common/utils/import_utils.py index cd5f824502..5c29b5a847 100644 --- a/lerobot/common/utils/import_utils.py +++ b/lerobot/common/utils/import_utils.py @@ -28,6 +28,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b try: # Primary method to get the package version package_version = importlib.metadata.version(pkg_name) + except importlib.metadata.PackageNotFoundError: # Fallback method: Only for "torch" and versions containing "dev" if pkg_name == "torch": @@ -43,6 +44,9 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b except ImportError: # If the package can't be imported, it's not available package_exists = False + elif pkg_name == "grpc": + package = importlib.import_module(pkg_name) + package_version = getattr(package, "__version__", "N/A") else: # For packages other than "torch", don't attempt the fallback and set as not available package_exists = False diff --git a/lerobot/common/utils/process.py b/lerobot/common/utils/process.py new file mode 100644 index 0000000000..72438b6f98 --- /dev/null +++ b/lerobot/common/utils/process.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import signal +import sys + + +class ProcessSignalHandler: + """Utility class to attach graceful shutdown signal handlers. + + The class exposes a shutdown_event attribute that is set when a shutdown + signal is received. A counter tracks how many shutdown signals have been + caught. On the second signal the process exits with status 1. + """ + + _SUPPORTED_SIGNALS = ("SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT") + + def __init__(self, use_threads: bool, display_pid: bool = False): + # TODO: Check if we can use Event from threading since Event from + # multiprocessing is the a clone of threading.Event. + # https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Event + if use_threads: + from threading import Event + else: + from multiprocessing import Event + + self.shutdown_event = Event() + self._counter: int = 0 + self._display_pid = display_pid + + self._register_handlers() + + @property + def counter(self) -> int: # pragma: no cover – simple accessor + """Number of shutdown signals that have been intercepted.""" + return self._counter + + def _register_handlers(self): + """Attach the internal _signal_handler to a subset of POSIX signals.""" + + def _signal_handler(signum, frame): + pid_str = "" + if self._display_pid: + pid_str = f"[PID: {os.getpid()}]" + logging.info(f"{pid_str} Shutdown signal {signum} received. Cleaning up…") + self.shutdown_event.set() + self._counter += 1 + + # On a second Ctrl-C (or any supported signal) force the exit to + # mimic the previous behaviour while giving the caller one chance to + # shutdown gracefully. + # TODO: Investigate if we need it later + if self._counter > 1: + logging.info("Force shutdown") + sys.exit(1) + + for sig_name in self._SUPPORTED_SIGNALS: + sig = getattr(signal, sig_name, None) + if sig is None: + # The signal is not available on this platform (Windows for + # instance does not provide SIGHUP, SIGQUIT…). Skip it. + continue + try: + signal.signal(sig, _signal_handler) + except (ValueError, OSError): # pragma: no cover – unlikely but safe + # Signal not supported or we are in a non-main thread. + continue diff --git a/lerobot/common/utils/queue.py b/lerobot/common/utils/queue.py new file mode 100644 index 0000000000..ceb30e2bff --- /dev/null +++ b/lerobot/common/utils/queue.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from queue import Empty +from typing import Any + +from torch.multiprocessing import Queue + + +def get_last_item_from_queue(queue: Queue, block=True, timeout: float = 0.1) -> Any: + if block: + try: + item = queue.get(timeout=timeout) + except Empty: + return None + else: + item = None + + # Drain queue and keep only the most recent parameters + try: + while True: + item = queue.get_nowait() + except Empty: + pass + + return item diff --git a/lerobot/common/utils/transition.py b/lerobot/common/utils/transition.py new file mode 100644 index 0000000000..db413c388f --- /dev/null +++ b/lerobot/common/utils/transition.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TypedDict + +import torch + + +class Transition(TypedDict): + state: dict[str, torch.Tensor] + action: torch.Tensor + reward: float + next_state: dict[str, torch.Tensor] + done: bool + truncated: bool + complementary_info: dict[str, torch.Tensor | float | int] | None = None + + +def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition: + device = torch.device(device) + non_blocking = device.type == "cuda" + + # Move state tensors to device + transition["state"] = { + key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items() + } + + # Move action to device + transition["action"] = transition["action"].to(device, non_blocking=non_blocking) + + # Move reward and done if they are tensors + if isinstance(transition["reward"], torch.Tensor): + transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking) + + if isinstance(transition["done"], torch.Tensor): + transition["done"] = transition["done"].to(device, non_blocking=non_blocking) + + if isinstance(transition["truncated"], torch.Tensor): + transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking) + + # Move next_state tensors to device + transition["next_state"] = { + key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items() + } + + # Move complementary_info tensors if present + if transition.get("complementary_info") is not None: + for key, val in transition["complementary_info"].items(): + if isinstance(val, torch.Tensor): + transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking) + elif isinstance(val, (int, float, bool)): + transition["complementary_info"][key] = torch.tensor(val, device=device) + else: + raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]") + return transition + + +def move_state_dict_to_device(state_dict, device="cpu"): + """ + Recursively move all tensors in a (potentially) nested + dict/list/tuple structure to the CPU. + """ + if isinstance(state_dict, torch.Tensor): + return state_dict.to(device) + elif isinstance(state_dict, dict): + return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()} + elif isinstance(state_dict, list): + return [move_state_dict_to_device(v, device=device) for v in state_dict] + elif isinstance(state_dict, tuple): + return tuple(move_state_dict_to_device(v, device=device) for v in state_dict) + else: + return state_dict diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py index 08e9a3c06b..cba65ba456 100644 --- a/lerobot/common/utils/utils.py +++ b/lerobot/common/utils/utils.py @@ -20,9 +20,11 @@ import select import subprocess import sys -from copy import copy +import time +from copy import copy, deepcopy from datetime import datetime, timezone from pathlib import Path +from statistics import mean import numpy as np import torch @@ -109,11 +111,17 @@ def is_amp_available(device: str): raise ValueError(f"Unknown device '{device}.") -def init_logging(): +def init_logging(log_file: Path | None = None, display_pid: bool = False): def custom_format(record): dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") fnameline = f"{record.pathname}:{record.lineno}" - message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}" + + # NOTE: Display PID is useful for multi-process logging. + if display_pid: + pid_str = f"[PID: {os.getpid()}]" + message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.msg}" + else: + message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}" return message logging.basicConfig(level=logging.INFO) @@ -127,6 +135,12 @@ def custom_format(record): console_handler.setFormatter(formatter) logging.getLogger().addHandler(console_handler) + if log_file is not None: + # Additionally write logs to file + file_handler = logging.FileHandler(log_file) + file_handler.setFormatter(formatter) + logging.getLogger().addHandler(file_handler) + def format_big_number(num, precision=0): suffixes = ["", "K", "M", "B", "T", "Q"] @@ -247,3 +261,114 @@ def enter_pressed() -> bool: def move_cursor_up(lines): """Move the cursor up by a specified number of lines.""" print(f"\033[{lines}A", end="") + + +class TimerManager: + """ + Lightweight utility to measure elapsed time. + + Examples + -------- + ```python + # Example 1: Using context manager + timer = TimerManager("Policy", log=False) + for _ in range(3): + with timer: + time.sleep(0.01) + print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01 + ``` + + ```python + # Example 2: Using start/stop methods + timer = TimerManager("Policy", log=False) + timer.start() + time.sleep(0.01) + timer.stop() + print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01 + ``` + """ + + def __init__( + self, + label: str = "Elapsed-time", + log: bool = True, + logger: logging.Logger | None = None, + ): + self.label = label + self.log = log + self.logger = logger + self._start: float | None = None + self._history: list[float] = [] + + def __enter__(self): + return self.start() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() + + def start(self): + self._start = time.perf_counter() + return self + + def stop(self) -> float: + if self._start is None: + raise RuntimeError("Timer was never started.") + elapsed = time.perf_counter() - self._start + self._history.append(elapsed) + self._start = None + if self.log: + if self.logger is not None: + self.logger.info(f"{self.label}: {elapsed:.6f} s") + else: + logging.info(f"{self.label}: {elapsed:.6f} s") + return elapsed + + def reset(self): + self._history.clear() + + @property + def last(self) -> float: + return self._history[-1] if self._history else 0.0 + + @property + def avg(self) -> float: + return mean(self._history) if self._history else 0.0 + + @property + def total(self) -> float: + return sum(self._history) + + @property + def count(self) -> int: + return len(self._history) + + @property + def history(self) -> list[float]: + return deepcopy(self._history) + + @property + def fps_history(self) -> list[float]: + return [1.0 / t for t in self._history] + + @property + def fps_last(self) -> float: + return 0.0 if self.last == 0 else 1.0 / self.last + + @property + def fps_avg(self) -> float: + return 0.0 if self.avg == 0 else 1.0 / self.avg + + def percentile(self, p: float) -> float: + """ + Return the p-th percentile of recorded times. + """ + if not self._history: + return 0.0 + return float(np.percentile(self._history, p)) + + def fps_percentile(self, p: float) -> float: + """ + FPS corresponding to the p-th percentile time. + """ + val = self.percentile(p) + return 0.0 if val == 0 else 1.0 / val diff --git a/lerobot/common/utils/wandb_utils.py b/lerobot/common/utils/wandb_utils.py index 9e938e1917..ac4d223433 100644 --- a/lerobot/common/utils/wandb_utils.py +++ b/lerobot/common/utils/wandb_utils.py @@ -30,9 +30,10 @@ def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[st """Return a group name for logging. Optionally returns group name as list.""" lst = [ f"policy:{cfg.policy.type}", - f"dataset:{cfg.dataset.repo_id}", f"seed:{cfg.seed}", ] + if cfg.dataset is not None: + lst.append(f"dataset:{cfg.dataset.repo_id}") if cfg.env is not None: lst.append(f"env:{cfg.env.type}") return lst if return_list else "-".join(lst) @@ -92,6 +93,12 @@ def __init__(self, cfg: TrainPipelineConfig): resume="must" if cfg.resume else None, mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online", ) + run_id = wandb.run.id + # NOTE: We will override the cfg.wandb.run_id with the wandb run id. + # This is because we want to be able to resume the run from the wandb run id. + cfg.wandb.run_id = run_id + # Handle custom step key for rl asynchronous training. + self._wandb_custom_step_key: set[str] | None = None print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}") self._wandb = wandb @@ -108,9 +115,26 @@ def log_policy(self, checkpoint_dir: Path): artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE) self._wandb.log_artifact(artifact) - def log_dict(self, d: dict, step: int, mode: str = "train"): + def log_dict( + self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None + ): if mode not in {"train", "eval"}: raise ValueError(mode) + if step is None and custom_step_key is None: + raise ValueError("Either step or custom_step_key must be provided.") + + # NOTE: This is not simple. Wandb step must always monotonically increase and it + # increases with each wandb.log call, but in the case of asynchronous RL for example, + # multiple time steps is possible. For example, the interaction step with the environment, + # the training step, the evaluation step, etc. So we need to define a custom step key + # to log the correct step for each metric. + if custom_step_key is not None: + if self._wandb_custom_step_key is None: + self._wandb_custom_step_key = set() + new_custom_key = f"{mode}/{custom_step_key}" + if new_custom_key not in self._wandb_custom_step_key: + self._wandb_custom_step_key.add(new_custom_key) + self._wandb.define_metric(new_custom_key, hidden=True) for k, v in d.items(): if not isinstance(v, (int, float, str)): @@ -118,7 +142,18 @@ def log_dict(self, d: dict, step: int, mode: str = "train"): f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.' ) continue - self._wandb.log({f"{mode}/{k}": v}, step=step) + + # Do not log the custom step key itself. + if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key: + continue + + if custom_step_key is not None: + value_custom_step = d[custom_step_key] + data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step} + self._wandb.log(data) + continue + + self._wandb.log(data={f"{mode}/{k}": v}, step=step) def log_video(self, video_path: str, step: int, mode: str = "train"): if mode not in {"train", "eval"}: diff --git a/lerobot/configs/control.py b/lerobot/configs/control.py deleted file mode 100644 index 07b8d13523..0000000000 --- a/lerobot/configs/control.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from pathlib import Path - -import draccus - -from lerobot.common.robots import RobotConfig -from lerobot.configs import parser -from lerobot.configs.policies import PreTrainedConfig - - -@dataclass -class ControlConfig(draccus.ChoiceRegistry): - pass - - -@ControlConfig.register_subclass("calibrate") -@dataclass -class CalibrateControlConfig(ControlConfig): - # List of arms to calibrate (e.g. `--arms='["left_follower","right_follower"]' left_leader`) - arms: list[str] | None = None - - -@ControlConfig.register_subclass("teleoperate") -@dataclass -class TeleoperateControlConfig(ControlConfig): - # Limit the maximum frames per second. By default, no limit. - fps: int | None = None - teleop_time_s: float | None = None - # Display all cameras on screen - display_data: bool = False - - -@ControlConfig.register_subclass("record") -@dataclass -class RecordControlConfig(ControlConfig): - # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`). - repo_id: str - # A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.") - single_task: str - # Root directory where the dataset will be stored (e.g. 'dataset/path'). - root: str | Path | None = None - policy: PreTrainedConfig | None = None - # Limit the frames per second. By default, uses the policy fps. - fps: int | None = None - # Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize. - warmup_time_s: int | float = 10 - # Number of seconds for data recording for each episode. - episode_time_s: int | float = 60 - # Number of seconds for resetting the environment after each episode. - reset_time_s: int | float = 60 - # Number of episodes to record. - num_episodes: int = 50 - # Encode frames in the dataset into video - video: bool = True - # Upload dataset to Hugging Face hub. - push_to_hub: bool = True - # Upload on private repository on the Hugging Face hub. - private: bool = False - # Add tags to your dataset on the hub. - tags: list[str] | None = None - # Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only; - # set to ≥1 to use subprocesses, each using threads to write images. The best number of processes - # and threads depends on your system. We recommend 4 threads per camera with 0 processes. - # If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses. - num_image_writer_processes: int = 0 - # Number of threads writing the frames as png images on disk, per camera. - # Too many threads might cause unstable teleoperation fps due to main thread being blocked. - # Not enough threads might cause low camera fps. - num_image_writer_threads_per_camera: int = 4 - # Display all cameras on screen - display_data: bool = False - # Use vocal synthesis to read events. - play_sounds: bool = True - # Resume recording on an existing dataset. - resume: bool = False - - def __post_init__(self): - # HACK: We parse again the cli args here to get the pretrained path if there was one. - policy_path = parser.get_path_arg("control.policy") - if policy_path: - cli_overrides = parser.get_cli_overrides("control.policy") - self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) - self.policy.pretrained_path = policy_path - - -@ControlConfig.register_subclass("replay") -@dataclass -class ReplayControlConfig(ControlConfig): - # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`). - repo_id: str - # Index of the episode to replay. - episode: int - # Root directory where the dataset will be stored (e.g. 'dataset/path'). - root: str | Path | None = None - # Limit the frames per second. By default, uses the dataset fps. - fps: int | None = None - # Use vocal synthesis to read events. - play_sounds: bool = True - - -@ControlConfig.register_subclass("remote_robot") -@dataclass -class RemoteRobotConfig(ControlConfig): - log_interval: int = 100 - # Display all cameras on screen - display_data: bool = False - # Rerun configuration for remote robot (https://ref.rerun.io/docs/python/0.22.1/common/initialization_functions/#rerun.connect_tcp) - viewer_ip: str | None = None - viewer_port: str | None = None - - -@dataclass -class ControlPipelineConfig: - robot: RobotConfig - control: ControlConfig - - @classmethod - def __get_path_fields__(cls) -> list[str]: - """This enables the parser to load config from the policy using `--policy.path=local/dir`""" - return ["control.policy"] diff --git a/lerobot/configs/train.py b/lerobot/configs/train.py index 98826294ea..96a460bdf1 100644 --- a/lerobot/configs/train.py +++ b/lerobot/configs/train.py @@ -172,3 +172,8 @@ def from_pretrained( cli_args = kwargs.pop("cli_args", []) with draccus.config_type("json"): return draccus.parse(cls, config_file, args=cli_args) + + +@dataclass(kw_only=True) +class TrainRLServerPipelineConfig(TrainPipelineConfig): + dataset: DatasetConfig | None = None # NOTE: In RL, we don't need an offline dataset diff --git a/lerobot/configs/types.py b/lerobot/configs/types.py index 6b3d92e80d..6040ff70ba 100644 --- a/lerobot/configs/types.py +++ b/lerobot/configs/types.py @@ -23,6 +23,7 @@ class FeatureType(str, Enum): VISUAL = "VISUAL" ENV = "ENV" ACTION = "ACTION" + REWARD = "REWARD" class NormalizationMode(str, Enum): diff --git a/lerobot/find_cameras.py b/lerobot/find_cameras.py index 3b5c4af3c0..34f4865b1d 100644 --- a/lerobot/find_cameras.py +++ b/lerobot/find_cameras.py @@ -170,7 +170,7 @@ def create_camera_instance(cam_meta: Dict[str, Any]) -> Dict[str, Any] | None: instance = OpenCVCamera(cv_config) elif cam_type == "RealSense": rs_config = RealSenseCameraConfig( - serial_number_or_name=int(cam_id), + serial_number_or_name=cam_id, color_mode=ColorMode.RGB, ) instance = RealSenseCamera(rs_config) @@ -283,7 +283,7 @@ def save_images_from_all_cameras( print("\nFinalizing image saving...") executor.shutdown(wait=True) cleanup_cameras(cameras_to_use) - logger.info(f"Image capture finished. Images saved to {output_dir}") + print(f"Image capture finished. Images saved to {output_dir}") if __name__ == "__main__": diff --git a/lerobot/record.py b/lerobot/record.py index 531846f297..884a3fcd6d 100644 --- a/lerobot/record.py +++ b/lerobot/record.py @@ -139,7 +139,7 @@ class RecordConfig: resume: bool = False def __post_init__(self): - if bool(self.teleop) == bool(self.policy): + if self.teleop is not None and self.policy is not None: raise ValueError("Choose either a policy or a teleoperator to control the robot") # HACK: We parse again the cli args here to get the pretrained path if there was one. diff --git a/lerobot/scripts/find_joint_limits.py b/lerobot/scripts/find_joint_limits.py new file mode 100644 index 0000000000..95676dd359 --- /dev/null +++ b/lerobot/scripts/find_joint_limits.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Simple script to control a robot from teleoperation. + +Example: + +```shell +python -m lerobot.scripts.server.find_joint_limits \ + --robot.type=so100_follower \ + --robot.port=/dev/tty.usbmodem58760431541 \ + --robot.id=black \ + --teleop.type=so100_leader \ + --teleop.port=/dev/tty.usbmodem58760431551 \ + --teleop.id=blue +``` +""" + +import time +from dataclasses import dataclass + +import draccus +import numpy as np + +from lerobot.common.model.kinematics import RobotKinematics +from lerobot.common.robots import ( # noqa: F401 + RobotConfig, + koch_follower, + make_robot_from_config, + so100_follower, +) +from lerobot.common.teleoperators import ( # noqa: F401 + TeleoperatorConfig, + gamepad, + koch_leader, + make_teleoperator_from_config, + so100_leader, +) + + +@dataclass +class FindJointLimitsConfig: + teleop: TeleoperatorConfig + robot: RobotConfig + # Limit the maximum frames per second. By default, no limit. + teleop_time_s: float = 30 + # Display all cameras on screen + display_data: bool = False + + +@draccus.wrap() +def find_joint_and_ee_bounds(cfg: FindJointLimitsConfig): + teleop = make_teleoperator_from_config(cfg.teleop) + robot = make_robot_from_config(cfg.robot) + + teleop.connect() + robot.connect() + + start_episode_t = time.perf_counter() + robot_type = getattr(robot.config, "robot_type", "so101") + if "so100" in robot_type or "so101" in robot_type: + # Note to be compatible with the rest of the codebase, + # we are using the new calibration method for so101 and so100 + robot_type = "so_new_calibration" + kinematics = RobotKinematics(robot_type=robot_type) + + # Initialize min/max values + observation = robot.get_observation() + joint_positions = np.array([observation[f"{key}.pos"] for key in robot.bus.motors]) + ee_pos = kinematics.forward_kinematics(joint_positions, frame="gripper_tip")[:3, 3] + + max_pos = joint_positions.copy() + min_pos = joint_positions.copy() + max_ee = ee_pos.copy() + min_ee = ee_pos.copy() + + while True: + action = teleop.get_action() + robot.send_action(action) + + observation = robot.get_observation() + joint_positions = np.array([observation[f"{key}.pos"] for key in robot.bus.motors]) + ee_pos = kinematics.forward_kinematics(joint_positions, frame="gripper_tip")[:3, 3] + + # Skip initial warmup period + if (time.perf_counter() - start_episode_t) < 5: + continue + + # Update min/max values + max_ee = np.maximum(max_ee, ee_pos) + min_ee = np.minimum(min_ee, ee_pos) + max_pos = np.maximum(max_pos, joint_positions) + min_pos = np.minimum(min_pos, joint_positions) + + if time.perf_counter() - start_episode_t > cfg.teleop_time_s: + print(f"Max ee position {np.round(max_ee, 4).tolist()}") + print(f"Min ee position {np.round(min_ee, 4).tolist()}") + print(f"Max joint pos position {np.round(max_pos, 4).tolist()}") + print(f"Min joint pos position {np.round(min_pos, 4).tolist()}") + break + + +if __name__ == "__main__": + find_joint_and_ee_bounds() diff --git a/lerobot/scripts/rl/actor.py b/lerobot/scripts/rl/actor.py new file mode 100644 index 0000000000..da24d0dc58 --- /dev/null +++ b/lerobot/scripts/rl/actor.py @@ -0,0 +1,709 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Actor server runner for distributed HILSerl robot policy training. + +This script implements the actor component of the distributed HILSerl architecture. +It executes the policy in the robot environment, collects experience, +and sends transitions to the learner server for policy updates. + +Examples of usage: + +- Start an actor server for real robot training with human-in-the-loop intervention: +```bash +python lerobot/scripts/rl/actor.py --config_path lerobot/configs/train_config_hilserl_so100.json +``` + +**NOTE**: The actor server requires a running learner server to connect to. Ensure the learner +server is started before launching the actor. + +**NOTE**: Human intervention is key to HILSerl training. Press the upper right trigger button on the +gamepad to take control of the robot during training. Initially intervene frequently, then gradually +reduce interventions as the policy improves. + +**WORKFLOW**: +1. Determine robot workspace bounds using `find_joint_limits.py` +2. Record demonstrations with `gym_manipulator.py` in record mode +3. Process the dataset and determine camera crops with `crop_dataset_roi.py` +4. Start the learner server with the training configuration +5. Start this actor server with the same configuration +6. Use human interventions to guide policy learning + +For more details on the complete HILSerl training workflow, see: +https://github.com/michel-aractingi/lerobot-hilserl-guide +""" + +import logging +import os +import time +from functools import lru_cache +from queue import Empty + +import grpc +import torch +from torch import nn +from torch.multiprocessing import Event, Queue + +from lerobot.common.cameras import opencv # noqa: F401 +from lerobot.common.policies.factory import make_policy +from lerobot.common.policies.sac.modeling_sac import SACPolicy +from lerobot.common.robots import so100_follower # noqa: F401 +from lerobot.common.teleoperators import gamepad, so101_leader # noqa: F401 +from lerobot.common.transport import services_pb2, services_pb2_grpc +from lerobot.common.transport.utils import ( + bytes_to_state_dict, + python_object_to_bytes, + receive_bytes_in_chunks, + send_bytes_in_chunks, + transitions_to_bytes, +) +from lerobot.common.utils.process import ProcessSignalHandler +from lerobot.common.utils.queue import get_last_item_from_queue +from lerobot.common.utils.random_utils import set_seed +from lerobot.common.utils.robot_utils import busy_wait +from lerobot.common.utils.transition import ( + Transition, + move_state_dict_to_device, + move_transition_to_device, +) +from lerobot.common.utils.utils import ( + TimerManager, + get_safe_torch_device, + init_logging, +) +from lerobot.configs import parser +from lerobot.configs.train import TrainRLServerPipelineConfig +from lerobot.scripts.rl import learner_service +from lerobot.scripts.rl.gym_manipulator import make_robot_env + +ACTOR_SHUTDOWN_TIMEOUT = 30 + + +################################################# +# Main entry point # +################################################# + + +@parser.wrap() +def actor_cli(cfg: TrainRLServerPipelineConfig): + cfg.validate() + display_pid = False + if not use_threads(cfg): + import torch.multiprocessing as mp + + mp.set_start_method("spawn") + display_pid = True + + # Create logs directory to ensure it exists + log_dir = os.path.join(cfg.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"actor_{cfg.job_name}.log") + + # Initialize logging with explicit log file + init_logging(log_file=log_file, display_pid=display_pid) + logging.info(f"Actor logging initialized, writing to {log_file}") + + is_threaded = use_threads(cfg) + shutdown_event = ProcessSignalHandler(is_threaded, display_pid=display_pid).shutdown_event + + learner_client, grpc_channel = learner_service_client( + host=cfg.policy.actor_learner_config.learner_host, + port=cfg.policy.actor_learner_config.learner_port, + ) + + logging.info("[ACTOR] Establishing connection with Learner") + if not establish_learner_connection(learner_client, shutdown_event): + logging.error("[ACTOR] Failed to establish connection with Learner") + return + + if not use_threads(cfg): + # If we use multithreading, we can reuse the channel + grpc_channel.close() + grpc_channel = None + + logging.info("[ACTOR] Connection with Learner established") + + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + + concurrency_entity = None + if use_threads(cfg): + from threading import Thread + + concurrency_entity = Thread + else: + from multiprocessing import Process + + concurrency_entity = Process + + receive_policy_process = concurrency_entity( + target=receive_policy, + args=(cfg, parameters_queue, shutdown_event, grpc_channel), + daemon=True, + ) + + transitions_process = concurrency_entity( + target=send_transitions, + args=(cfg, transitions_queue, shutdown_event, grpc_channel), + daemon=True, + ) + + interactions_process = concurrency_entity( + target=send_interactions, + args=(cfg, interactions_queue, shutdown_event, grpc_channel), + daemon=True, + ) + + transitions_process.start() + interactions_process.start() + receive_policy_process.start() + + act_with_policy( + cfg=cfg, + shutdown_event=shutdown_event, + parameters_queue=parameters_queue, + transitions_queue=transitions_queue, + interactions_queue=interactions_queue, + ) + logging.info("[ACTOR] Policy process joined") + + logging.info("[ACTOR] Closing queues") + transitions_queue.close() + interactions_queue.close() + parameters_queue.close() + + transitions_process.join() + logging.info("[ACTOR] Transitions process joined") + interactions_process.join() + logging.info("[ACTOR] Interactions process joined") + receive_policy_process.join() + logging.info("[ACTOR] Receive policy process joined") + + logging.info("[ACTOR] join queues") + transitions_queue.cancel_join_thread() + interactions_queue.cancel_join_thread() + parameters_queue.cancel_join_thread() + + logging.info("[ACTOR] queues closed") + + +################################################# +# Core algorithm functions # +################################################# + + +def act_with_policy( + cfg: TrainRLServerPipelineConfig, + shutdown_event: any, # Event, + parameters_queue: Queue, + transitions_queue: Queue, + interactions_queue: Queue, +): + """ + Executes policy interaction within the environment. + + This function rolls out the policy in the environment, collecting interaction data and pushing it to a queue for streaming to the learner. + Once an episode is completed, updated network parameters received from the learner are retrieved from a queue and loaded into the network. + + Args: + cfg: Configuration settings for the interaction process. + shutdown_event: Event to check if the process should shutdown. + parameters_queue: Queue to receive updated network parameters from the learner. + transitions_queue: Queue to send transitions to the learner. + interactions_queue: Queue to send interactions to the learner. + """ + # Initialize logging for multiprocessing + if not use_threads(cfg): + log_dir = os.path.join(cfg.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"actor_policy_{os.getpid()}.log") + init_logging(log_file=log_file, display_pid=True) + logging.info("Actor policy process logging initialized") + + logging.info("make_env online") + + online_env = make_robot_env(cfg=cfg.env) + + set_seed(cfg.seed) + device = get_safe_torch_device(cfg.policy.device, log=True) + + torch.backends.cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True + + logging.info("make_policy") + + ### Instantiate the policy in both the actor and learner processes + ### To avoid sending a SACPolicy object through the port, we create a policy instance + ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters + policy: SACPolicy = make_policy( + cfg=cfg.policy, + env_cfg=cfg.env, + ) + policy = policy.eval() + assert isinstance(policy, nn.Module) + + obs, info = online_env.reset() + + # NOTE: For the moment we will solely handle the case of a single environment + sum_reward_episode = 0 + list_transition_to_send_to_learner = [] + episode_intervention = False + # Add counters for intervention rate calculation + episode_intervention_steps = 0 + episode_total_steps = 0 + + policy_timer = TimerManager("Policy inference", log=False) + + for interaction_step in range(cfg.policy.online_steps): + start_time = time.perf_counter() + if shutdown_event.is_set(): + logging.info("[ACTOR] Shutting down act_with_policy") + return + + if interaction_step >= cfg.policy.online_step_before_learning: + # Time policy inference and check if it meets FPS requirement + with policy_timer: + action = policy.select_action(batch=obs) + policy_fps = policy_timer.fps_last + + log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step) + + else: + action = online_env.action_space.sample() + + next_obs, reward, done, truncated, info = online_env.step(action) + + sum_reward_episode += float(reward) + # Increment total steps counter for intervention rate + episode_total_steps += 1 + + # NOTE: We override the action if the intervention is True, because the action applied is the intervention action + if "is_intervention" in info and info["is_intervention"]: + # NOTE: The action space for demonstration before hand is with the full action space + # but sometimes for example we want to deactivate the gripper + action = info["action_intervention"] + episode_intervention = True + # Increment intervention steps counter + episode_intervention_steps += 1 + + list_transition_to_send_to_learner.append( + Transition( + state=obs, + action=action, + reward=reward, + next_state=next_obs, + done=done, + truncated=truncated, # TODO: (azouitine) Handle truncation properly + complementary_info=info, + ) + ) + # assign obs to the next obs and continue the rollout + obs = next_obs + + if done or truncated: + logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}") + + update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device) + + if len(list_transition_to_send_to_learner) > 0: + push_transitions_to_transport_queue( + transitions=list_transition_to_send_to_learner, + transitions_queue=transitions_queue, + ) + list_transition_to_send_to_learner = [] + + stats = get_frequency_stats(policy_timer) + policy_timer.reset() + + # Calculate intervention rate + intervention_rate = 0.0 + if episode_total_steps > 0: + intervention_rate = episode_intervention_steps / episode_total_steps + + # Send episodic reward to the learner + interactions_queue.put( + python_object_to_bytes( + { + "Episodic reward": sum_reward_episode, + "Interaction step": interaction_step, + "Episode intervention": int(episode_intervention), + "Intervention rate": intervention_rate, + **stats, + } + ) + ) + + # Reset intervention counters + sum_reward_episode = 0.0 + episode_intervention = False + episode_intervention_steps = 0 + episode_total_steps = 0 + obs, info = online_env.reset() + + if cfg.env.fps is not None: + dt_time = time.perf_counter() - start_time + busy_wait(1 / cfg.env.fps - dt_time) + + +################################################# +# Communication Functions - Group all gRPC/messaging functions # +################################################# + + +def establish_learner_connection( + stub: services_pb2_grpc.LearnerServiceStub, + shutdown_event: Event, # type: ignore + attempts: int = 30, +): + """Establish a connection with the learner. + + Args: + stub (services_pb2_grpc.LearnerServiceStub): The stub to use for the connection. + shutdown_event (Event): The event to check if the connection should be established. + attempts (int): The number of attempts to establish the connection. + Returns: + bool: True if the connection is established, False otherwise. + """ + for _ in range(attempts): + if shutdown_event.is_set(): + logging.info("[ACTOR] Shutting down establish_learner_connection") + return False + + # Force a connection attempt and check state + try: + logging.info("[ACTOR] Send ready message to Learner") + if stub.Ready(services_pb2.Empty()) == services_pb2.Empty(): + return True + except grpc.RpcError as e: + logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}") + time.sleep(2) + return False + + +@lru_cache(maxsize=1) +def learner_service_client( + host: str = "127.0.0.1", + port: int = 50051, +) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]: + import json + + """ + Returns a client for the learner service. + + GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection. + So we need to create only one client and reuse it. + """ + + service_config = { + "methodConfig": [ + { + "name": [{}], # Applies to ALL methods in ALL services + "retryPolicy": { + "maxAttempts": 5, # Max retries (total attempts = 5) + "initialBackoff": "0.1s", # First retry after 0.1s + "maxBackoff": "2s", # Max wait time between retries + "backoffMultiplier": 2, # Exponential backoff factor + "retryableStatusCodes": [ + "UNAVAILABLE", + "DEADLINE_EXCEEDED", + ], # Retries on network failures + }, + } + ] + } + + service_config_json = json.dumps(service_config) + + channel = grpc.insecure_channel( + f"{host}:{port}", + options=[ + ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE), + ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE), + ("grpc.enable_retries", 1), + ("grpc.service_config", service_config_json), + ], + ) + stub = services_pb2_grpc.LearnerServiceStub(channel) + logging.info("[ACTOR] Learner service client created") + return stub, channel + + +def receive_policy( + cfg: TrainRLServerPipelineConfig, + parameters_queue: Queue, + shutdown_event: Event, # type: ignore + learner_client: services_pb2_grpc.LearnerServiceStub | None = None, + grpc_channel: grpc.Channel | None = None, +): + """Receive parameters from the learner. + + Args: + cfg (TrainRLServerPipelineConfig): The configuration for the actor. + parameters_queue (Queue): The queue to receive the parameters. + shutdown_event (Event): The event to check if the process should shutdown. + """ + logging.info("[ACTOR] Start receiving parameters from the Learner") + if not use_threads(cfg): + # Create a process-specific log file + log_dir = os.path.join(cfg.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"actor_receive_policy_{os.getpid()}.log") + + # Initialize logging with explicit log file + init_logging(log_file=log_file, display_pid=True) + logging.info("Actor receive policy process logging initialized") + + # Setup process handlers to handle shutdown signal + # But use shutdown event from the main process + _ = ProcessSignalHandler(use_threads=False, display_pid=True) + + if grpc_channel is None or learner_client is None: + learner_client, grpc_channel = learner_service_client( + host=cfg.policy.actor_learner_config.learner_host, + port=cfg.policy.actor_learner_config.learner_port, + ) + + try: + iterator = learner_client.StreamParameters(services_pb2.Empty()) + receive_bytes_in_chunks( + iterator, + parameters_queue, + shutdown_event, + log_prefix="[ACTOR] parameters", + ) + + except grpc.RpcError as e: + logging.error(f"[ACTOR] gRPC error: {e}") + + if not use_threads(cfg): + grpc_channel.close() + logging.info("[ACTOR] Received policy loop stopped") + + +def send_transitions( + cfg: TrainRLServerPipelineConfig, + transitions_queue: Queue, + shutdown_event: any, # Event, + learner_client: services_pb2_grpc.LearnerServiceStub | None = None, + grpc_channel: grpc.Channel | None = None, +) -> services_pb2.Empty: + """ + Sends transitions to the learner. + + This function continuously retrieves messages from the queue and processes: + + - Transition Data: + - A batch of transitions (observation, action, reward, next observation) is collected. + - Transitions are moved to the CPU and serialized using PyTorch. + - The serialized data is wrapped in a `services_pb2.Transition` message and sent to the learner. + """ + + if not use_threads(cfg): + # Create a process-specific log file + log_dir = os.path.join(cfg.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"actor_transitions_{os.getpid()}.log") + + # Initialize logging with explicit log file + init_logging(log_file=log_file, display_pid=True) + logging.info("Actor transitions process logging initialized") + + if grpc_channel is None or learner_client is None: + learner_client, grpc_channel = learner_service_client( + host=cfg.policy.actor_learner_config.learner_host, + port=cfg.policy.actor_learner_config.learner_port, + ) + + try: + learner_client.SendTransitions( + transitions_stream( + shutdown_event, transitions_queue, cfg.policy.actor_learner_config.queue_get_timeout + ) + ) + except grpc.RpcError as e: + logging.error(f"[ACTOR] gRPC error: {e}") + + logging.info("[ACTOR] Finished streaming transitions") + + if not use_threads(cfg): + grpc_channel.close() + logging.info("[ACTOR] Transitions process stopped") + + +def send_interactions( + cfg: TrainRLServerPipelineConfig, + interactions_queue: Queue, + shutdown_event: Event, # type: ignore + learner_client: services_pb2_grpc.LearnerServiceStub | None = None, + grpc_channel: grpc.Channel | None = None, +) -> services_pb2.Empty: + """ + Sends interactions to the learner. + + This function continuously retrieves messages from the queue and processes: + + - Interaction Messages: + - Contains useful statistics about episodic rewards and policy timings. + - The message is serialized using `pickle` and sent to the learner. + """ + + if not use_threads(cfg): + # Create a process-specific log file + log_dir = os.path.join(cfg.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"actor_interactions_{os.getpid()}.log") + + # Initialize logging with explicit log file + init_logging(log_file=log_file, display_pid=True) + logging.info("Actor interactions process logging initialized") + + # Setup process handlers to handle shutdown signal + # But use shutdown event from the main process + _ = ProcessSignalHandler(use_threads=False, display_pid=True) + + if grpc_channel is None or learner_client is None: + learner_client, grpc_channel = learner_service_client( + host=cfg.policy.actor_learner_config.learner_host, + port=cfg.policy.actor_learner_config.learner_port, + ) + + try: + learner_client.SendInteractions( + interactions_stream( + shutdown_event, interactions_queue, cfg.policy.actor_learner_config.queue_get_timeout + ) + ) + except grpc.RpcError as e: + logging.error(f"[ACTOR] gRPC error: {e}") + + logging.info("[ACTOR] Finished streaming interactions") + + if not use_threads(cfg): + grpc_channel.close() + logging.info("[ACTOR] Interactions process stopped") + + +def transitions_stream(shutdown_event: Event, transitions_queue: Queue, timeout: float) -> services_pb2.Empty: # type: ignore + while not shutdown_event.is_set(): + try: + message = transitions_queue.get(block=True, timeout=timeout) + except Empty: + logging.debug("[ACTOR] Transition queue is empty") + continue + + yield from send_bytes_in_chunks( + message, services_pb2.Transition, log_prefix="[ACTOR] Send transitions" + ) + + return services_pb2.Empty() + + +def interactions_stream( + shutdown_event: Event, + interactions_queue: Queue, + timeout: float, # type: ignore +) -> services_pb2.Empty: + while not shutdown_event.is_set(): + try: + message = interactions_queue.get(block=True, timeout=timeout) + except Empty: + logging.debug("[ACTOR] Interaction queue is empty") + continue + + yield from send_bytes_in_chunks( + message, + services_pb2.InteractionMessage, + log_prefix="[ACTOR] Send interactions", + ) + + return services_pb2.Empty() + + +################################################# +# Policy functions # +################################################# + + +def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device): + bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False) + if bytes_state_dict is not None: + logging.info("[ACTOR] Load new parameters from Learner.") + state_dict = bytes_to_state_dict(bytes_state_dict) + state_dict = move_state_dict_to_device(state_dict, device=device) + policy.load_state_dict(state_dict) + + +################################################# +# Utilities functions # +################################################# + + +def push_transitions_to_transport_queue(transitions: list, transitions_queue): + """Send transitions to learner in smaller chunks to avoid network issues. + + Args: + transitions: List of transitions to send + message_queue: Queue to send messages to learner + chunk_size: Size of each chunk to send + """ + transition_to_send_to_learner = [] + for transition in transitions: + tr = move_transition_to_device(transition=transition, device="cpu") + for key, value in tr["state"].items(): + if torch.isnan(value).any(): + logging.warning(f"Found NaN values in transition {key}") + + transition_to_send_to_learner.append(tr) + + transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner)) + + +def get_frequency_stats(timer: TimerManager) -> dict[str, float]: + """Get the frequency statistics of the policy. + + Args: + timer (TimerManager): The timer with collected metrics. + + Returns: + dict[str, float]: The frequency statistics of the policy. + """ + stats = {} + if timer.count > 1: + avg_fps = timer.fps_avg + p90_fps = timer.fps_percentile(90) + logging.debug(f"[ACTOR] Average policy frame rate: {avg_fps}") + logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {p90_fps}") + stats = { + "Policy frequency [Hz]": avg_fps, + "Policy frequency 90th-p [Hz]": p90_fps, + } + return stats + + +def log_policy_frequency_issue(policy_fps: float, cfg: TrainRLServerPipelineConfig, interaction_step: int): + if policy_fps < cfg.env.fps: + logging.warning( + f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.env.fps} at step {interaction_step}" + ) + + +def use_threads(cfg: TrainRLServerPipelineConfig) -> bool: + return cfg.policy.concurrency.actor == "threads" + + +if __name__ == "__main__": + actor_cli() diff --git a/lerobot/scripts/rl/crop_dataset_roi.py b/lerobot/scripts/rl/crop_dataset_roi.py new file mode 100644 index 0000000000..5b7038de30 --- /dev/null +++ b/lerobot/scripts/rl/crop_dataset_roi.py @@ -0,0 +1,314 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +from copy import deepcopy +from pathlib import Path +from typing import Dict, Tuple + +import cv2 + +# import torch.nn.functional as F # noqa: N812 +import torchvision.transforms.functional as F # type: ignore # noqa: N812 +from tqdm import tqdm # type: ignore + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + + +def select_rect_roi(img): + """ + Allows the user to draw a rectangular ROI on the image. + + The user must click and drag to draw the rectangle. + - While dragging, the rectangle is dynamically drawn. + - On mouse button release, the rectangle is fixed. + - Press 'c' to confirm the selection. + - Press 'r' to reset the selection. + - Press ESC to cancel. + + Returns: + A tuple (top, left, height, width) representing the rectangular ROI, + or None if no valid ROI is selected. + """ + # Create a working copy of the image + clone = img.copy() + working_img = clone.copy() + + roi = None # Will store the final ROI as (top, left, height, width) + drawing = False + index_x, index_y = -1, -1 # Initial click coordinates + + def mouse_callback(event, x, y, flags, param): + nonlocal index_x, index_y, drawing, roi, working_img + + if event == cv2.EVENT_LBUTTONDOWN: + # Start drawing: record starting coordinates + drawing = True + index_x, index_y = x, y + + elif event == cv2.EVENT_MOUSEMOVE: + if drawing: + # Compute the top-left and bottom-right corners regardless of drag direction + top = min(index_y, y) + left = min(index_x, x) + bottom = max(index_y, y) + right = max(index_x, x) + # Show a temporary image with the current rectangle drawn + temp = working_img.copy() + cv2.rectangle(temp, (left, top), (right, bottom), (0, 255, 0), 2) + cv2.imshow("Select ROI", temp) + + elif event == cv2.EVENT_LBUTTONUP: + # Finish drawing + drawing = False + top = min(index_y, y) + left = min(index_x, x) + bottom = max(index_y, y) + right = max(index_x, x) + height = bottom - top + width = right - left + roi = (top, left, height, width) # (top, left, height, width) + # Draw the final rectangle on the working image and display it + working_img = clone.copy() + cv2.rectangle(working_img, (left, top), (right, bottom), (0, 255, 0), 2) + cv2.imshow("Select ROI", working_img) + + # Create the window and set the callback + cv2.namedWindow("Select ROI") + cv2.setMouseCallback("Select ROI", mouse_callback) + cv2.imshow("Select ROI", working_img) + + print("Instructions for ROI selection:") + print(" - Click and drag to draw a rectangular ROI.") + print(" - Press 'c' to confirm the selection.") + print(" - Press 'r' to reset and draw again.") + print(" - Press ESC to cancel the selection.") + + # Wait until the user confirms with 'c', resets with 'r', or cancels with ESC + while True: + key = cv2.waitKey(1) & 0xFF + # Confirm ROI if one has been drawn + if key == ord("c") and roi is not None: + break + # Reset: clear the ROI and restore the original image + elif key == ord("r"): + working_img = clone.copy() + roi = None + cv2.imshow("Select ROI", working_img) + # Cancel selection for this image + elif key == 27: # ESC key + roi = None + break + + cv2.destroyWindow("Select ROI") + return roi + + +def select_square_roi_for_images(images: dict) -> dict: + """ + For each image in the provided dictionary, open a window to allow the user + to select a rectangular ROI. Returns a dictionary mapping each key to a tuple + (top, left, height, width) representing the ROI. + + Parameters: + images (dict): Dictionary where keys are identifiers and values are OpenCV images. + + Returns: + dict: Mapping of image keys to the selected rectangular ROI. + """ + selected_rois = {} + + for key, img in images.items(): + if img is None: + print(f"Image for key '{key}' is None, skipping.") + continue + + print(f"\nSelect rectangular ROI for image with key: '{key}'") + roi = select_rect_roi(img) + + if roi is None: + print(f"No valid ROI selected for '{key}'.") + else: + selected_rois[key] = roi + print(f"ROI for '{key}': {roi}") + + return selected_rois + + +def get_image_from_lerobot_dataset(dataset: LeRobotDataset): + """ + Find the first row in the dataset and extract the image in order to be used for the crop. + """ + row = dataset[0] + image_dict = {} + for k in row: + if "image" in k: + image_dict[k] = deepcopy(row[k]) + return image_dict + + +def convert_lerobot_dataset_to_cropper_lerobot_dataset( + original_dataset: LeRobotDataset, + crop_params_dict: Dict[str, Tuple[int, int, int, int]], + new_repo_id: str, + new_dataset_root: str, + resize_size: Tuple[int, int] = (128, 128), + push_to_hub: bool = False, + task: str = "", +) -> LeRobotDataset: + """ + Converts an existing LeRobotDataset by iterating over its episodes and frames, + applying cropping and resizing to image observations, and saving a new dataset + with the transformed data. + + Args: + original_dataset (LeRobotDataset): The source dataset. + crop_params_dict (Dict[str, Tuple[int, int, int, int]]): + A dictionary mapping observation keys to crop parameters (top, left, height, width). + new_repo_id (str): Repository id for the new dataset. + new_dataset_root (str): The root directory where the new dataset will be written. + resize_size (Tuple[int, int], optional): The target size (height, width) after cropping. + Defaults to (128, 128). + + Returns: + LeRobotDataset: A new LeRobotDataset where the specified image observations have been cropped + and resized. + """ + # 1. Create a new (empty) LeRobotDataset for writing. + new_dataset = LeRobotDataset.create( + repo_id=new_repo_id, + fps=original_dataset.fps, + root=new_dataset_root, + robot_type=original_dataset.meta.robot_type, + features=original_dataset.meta.info["features"], + use_videos=len(original_dataset.meta.video_keys) > 0, + ) + + # Update the metadata for every image key that will be cropped: + # (Here we simply set the shape to be the final resize_size.) + for key in crop_params_dict: + if key in new_dataset.meta.info["features"]: + new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size) + + # TODO: Directly modify the mp4 video + meta info features, instead of recreating a dataset + prev_episode_index = 0 + for frame_idx in tqdm(range(len(original_dataset))): + frame = original_dataset[frame_idx] + + # Create a copy of the frame to add to the new dataset + new_frame = {} + for key, value in frame.items(): + if key in ("task_index", "timestamp", "episode_index", "frame_index", "index", "task"): + continue + if key in ("next.done", "next.reward"): + # if not isinstance(value, str) and len(value.shape) == 0: + value = value.unsqueeze(0) + + if key in crop_params_dict: + top, left, height, width = crop_params_dict[key] + # Apply crop then resize. + cropped = F.crop(value, top, left, height, width) + value = F.resize(cropped, resize_size) + value = value.clamp(0, 1) + + new_frame[key] = value + + new_dataset.add_frame(new_frame, task=task) + + if frame["episode_index"].item() != prev_episode_index: + # Save the episode + new_dataset.save_episode() + prev_episode_index = frame["episode_index"].item() + + # Save the last episode + new_dataset.save_episode() + + if push_to_hub: + new_dataset.push_to_hub() + + return new_dataset + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.") + parser.add_argument( + "--repo-id", + type=str, + default="lerobot", + help="The repository id of the LeRobot dataset to process.", + ) + parser.add_argument( + "--root", + type=str, + default=None, + help="The root directory of the LeRobot dataset.", + ) + parser.add_argument( + "--crop-params-path", + type=str, + default=None, + help="The path to the JSON file containing the ROIs.", + ) + parser.add_argument( + "--push-to-hub", + type=bool, + default=False, + help="Whether to push the new dataset to the hub.", + ) + parser.add_argument( + "--task", + type=str, + default="", + help="The natural language task to describe the dataset.", + ) + args = parser.parse_args() + + dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root) + + images = get_image_from_lerobot_dataset(dataset) + images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()} + images = {k: (v * 255).astype("uint8") for k, v in images.items()} + + if args.crop_params_path is None: + rois = select_square_roi_for_images(images) + else: + with open(args.crop_params_path) as f: + rois = json.load(f) + + # Print the selected rectangular ROIs + print("\nSelected Rectangular Regions of Interest (top, left, height, width):") + for key, roi in rois.items(): + print(f"{key}: {roi}") + + new_repo_id = args.repo_id + "_cropped_resized" + new_dataset_root = Path(str(dataset.root) + "_cropped_resized") + + cropped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset( + original_dataset=dataset, + crop_params_dict=rois, + new_repo_id=new_repo_id, + new_dataset_root=new_dataset_root, + resize_size=(128, 128), + push_to_hub=args.push_to_hub, + task=args.task, + ) + + meta_dir = new_dataset_root / "meta" + meta_dir.mkdir(exist_ok=True) + + with open(meta_dir / "crop_params.json", "w") as f: + json.dump(rois, f, indent=4) diff --git a/lerobot/scripts/rl/eval_policy.py b/lerobot/scripts/rl/eval_policy.py new file mode 100644 index 0000000000..3762719bf4 --- /dev/null +++ b/lerobot/scripts/rl/eval_policy.py @@ -0,0 +1,74 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from lerobot.common.cameras import opencv # noqa: F401 +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.policies.factory import make_policy +from lerobot.common.robots import ( # noqa: F401 + RobotConfig, + make_robot_from_config, + so100_follower, +) +from lerobot.common.teleoperators import ( + gamepad, # noqa: F401 + so101_leader, # noqa: F401 +) +from lerobot.configs import parser +from lerobot.configs.train import TrainRLServerPipelineConfig +from lerobot.scripts.rl.gym_manipulator import make_robot_env + +logging.basicConfig(level=logging.INFO) + + +def eval_policy(env, policy, n_episodes): + sum_reward_episode = [] + for _ in range(n_episodes): + obs, _ = env.reset() + episode_reward = 0.0 + while True: + action = policy.select_action(obs) + obs, reward, terminated, truncated, _ = env.step(action) + episode_reward += reward + if terminated or truncated: + break + sum_reward_episode.append(episode_reward) + + logging.info(f"Success after 20 steps {sum_reward_episode}") + logging.info(f"success rate {sum(sum_reward_episode) / len(sum_reward_episode)}") + + +@parser.wrap() +def main(cfg: TrainRLServerPipelineConfig): + env_cfg = cfg.env + env = make_robot_env(env_cfg) + dataset_cfg = cfg.dataset + dataset = LeRobotDataset(repo_id=dataset_cfg.repo_id) + dataset_meta = dataset.meta + + policy = make_policy( + cfg=cfg.policy, + # env_cfg=cfg.env, + ds_meta=dataset_meta, + ) + policy.from_pretrained(env_cfg.pretrained_policy_name_or_path) + policy.eval() + + eval_policy(env, policy=policy, n_episodes=10) + + +if __name__ == "__main__": + main() diff --git a/lerobot/scripts/rl/gym_manipulator.py b/lerobot/scripts/rl/gym_manipulator.py new file mode 100644 index 0000000000..98445e6668 --- /dev/null +++ b/lerobot/scripts/rl/gym_manipulator.py @@ -0,0 +1,2171 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Robot Environment for LeRobot Manipulation Tasks + +This module provides a comprehensive gym-compatible environment for robot manipulation +with support for: +- Multiple robot types (SO100, SO101, Koch and Moss) +- Human intervention via leader-follower control or gamepad + +- End-effector and joint space control +- Image processing (cropping and resizing) + +The environment is built using a composable wrapper pattern where each wrapper +adds specific functionality to the base RobotEnv. + +Example: + env = make_robot_env(cfg) + obs, info = env.reset() + action = policy.select_action(obs) + obs, reward, terminated, truncated, info = env.step(action) +""" + +import logging +import time +from collections import deque +from threading import Lock +from typing import Annotated, Any, Sequence + +import gymnasium as gym +import numpy as np +import torch +import torchvision.transforms.functional as F # noqa: N812 + +from lerobot.common.cameras import opencv # noqa: F401 +from lerobot.common.envs.configs import EnvConfig +from lerobot.common.envs.utils import preprocess_observation +from lerobot.common.model.kinematics import RobotKinematics +from lerobot.common.robots import ( # noqa: F401 + RobotConfig, + make_robot_from_config, + so100_follower, +) +from lerobot.common.teleoperators import ( + gamepad, # noqa: F401 + make_teleoperator_from_config, + so101_leader, # noqa: F401 +) +from lerobot.common.teleoperators.gamepad.teleop_gamepad import GamepadTeleop +from lerobot.common.utils.robot_utils import busy_wait +from lerobot.common.utils.utils import log_say +from lerobot.configs import parser + +logging.basicConfig(level=logging.INFO) + + +def reset_follower_position(robot_arm, target_position): + current_position_dict = robot_arm.bus.sync_read("Present_Position") + current_position = np.array( + [current_position_dict[name] for name in current_position_dict], dtype=np.float32 + ) + trajectory = torch.from_numpy( + np.linspace(current_position, target_position, 50) + ) # NOTE: 30 is just an arbitrary number + for pose in trajectory: + action_dict = dict(zip(current_position_dict, pose, strict=False)) + robot_arm.bus.sync_write("Goal_Position", action_dict) + busy_wait(0.015) + + +class TorchBox(gym.spaces.Box): + """ + A version of gym.spaces.Box that handles PyTorch tensors. + + This class extends gym.spaces.Box to work with PyTorch tensors, + providing compatibility between NumPy arrays and PyTorch tensors. + """ + + def __init__( + self, + low: float | Sequence[float] | np.ndarray, + high: float | Sequence[float] | np.ndarray, + shape: Sequence[int] | None = None, + np_dtype: np.dtype | type = np.float32, + torch_dtype: torch.dtype = torch.float32, + device: str = "cpu", + seed: int | np.random.Generator | None = None, + ) -> None: + """ + Initialize the PyTorch-compatible Box space. + + Args: + low: Lower bounds of the space. + high: Upper bounds of the space. + shape: Shape of the space. If None, inferred from low and high. + np_dtype: NumPy data type for internal storage. + torch_dtype: PyTorch data type for tensor conversion. + device: PyTorch device for returned tensors. + seed: Random seed for sampling. + """ + super().__init__(low, high, shape=shape, dtype=np_dtype, seed=seed) + self.torch_dtype = torch_dtype + self.device = device + + def sample(self) -> torch.Tensor: + """ + Sample a random point from the space. + + Returns: + A PyTorch tensor within the space bounds. + """ + arr = super().sample() + return torch.as_tensor(arr, dtype=self.torch_dtype, device=self.device) + + def contains(self, x: torch.Tensor) -> bool: + """ + Check if a tensor is within the space bounds. + + Args: + x: The PyTorch tensor to check. + + Returns: + Boolean indicating whether the tensor is within bounds. + """ + # Move to CPU/numpy and cast to the internal dtype + arr = x.detach().cpu().numpy().astype(self.dtype, copy=False) + return super().contains(arr) + + def seed(self, seed: int | np.random.Generator | None = None): + """ + Set the random seed for sampling. + + Args: + seed: The random seed to use. + + Returns: + List containing the seed. + """ + super().seed(seed) + return [seed] + + def __repr__(self) -> str: + """ + Return a string representation of the space. + + Returns: + Formatted string with space details. + """ + return ( + f"TorchBox({self.low_repr}, {self.high_repr}, {self.shape}, " + f"np={self.dtype.name}, torch={self.torch_dtype}, device={self.device})" + ) + + +class TorchActionWrapper(gym.Wrapper): + """ + Wrapper that changes the action space to use PyTorch tensors. + + This wrapper modifies the action space to return PyTorch tensors when sampled + and handles converting PyTorch actions to NumPy when stepping the environment. + """ + + def __init__(self, env: gym.Env, device: str): + """ + Initialize the PyTorch action space wrapper. + + Args: + env: The environment to wrap. + device: The PyTorch device to use for tensor operations. + """ + super().__init__(env) + self.action_space = TorchBox( + low=env.action_space.low, + high=env.action_space.high, + shape=env.action_space.shape, + torch_dtype=torch.float32, + device=torch.device("cpu"), + ) + + def step(self, action: torch.Tensor): + """ + Step the environment with a PyTorch tensor action. + + This method handles conversion from PyTorch tensors to NumPy arrays + for compatibility with the underlying environment. + + Args: + action: PyTorch tensor action to take. + + Returns: + Tuple of (observation, reward, terminated, truncated, info). + """ + if action.dim() == 2: + action = action.squeeze(0) + action = action.detach().cpu().numpy() + return self.env.step(action) + + +class RobotEnv(gym.Env): + """ + Gym-compatible environment for evaluating robotic control policies with integrated human intervention. + + This environment wraps a robot interface to provide a consistent API for policy evaluation. It supports both relative (delta) + and absolute joint position commands and automatically configures its observation and action spaces based on the robot's + sensors and configuration. + """ + + def __init__( + self, + robot, + use_gripper: bool = False, + display_cameras: bool = False, + ): + """ + Initialize the RobotEnv environment. + + The environment is set up with a robot interface, which is used to capture observations and send joint commands. The setup + supports both relative (delta) adjustments and absolute joint positions for controlling the robot. + + Args: + robot: The robot interface object used to connect and interact with the physical robot. + display_cameras: If True, the robot's camera feeds will be displayed during execution. + """ + super().__init__() + + self.robot = robot + self.display_cameras = display_cameras + + # Connect to the robot if not already connected. + if not self.robot.is_connected: + self.robot.connect() + + # Episode tracking. + self.current_step = 0 + self.episode_data = None + + self._joint_names = [f"{key}.pos" for key in self.robot.bus.motors] + self._image_keys = self.robot.cameras.keys() + + # Read initial joint positions using the bus + self.current_joint_positions = self._get_observation()["agent_pos"] + + self.use_gripper = use_gripper + + self._setup_spaces() + + def _get_observation(self) -> np.ndarray: + """Helper to convert a dictionary from bus.sync_read to an ordered numpy array.""" + obs_dict = self.robot.get_observation() + joint_positions = np.array([obs_dict[name] for name in self._joint_names], dtype=np.float32) + + images = {key: obs_dict[key] for key in self._image_keys} + return {"agent_pos": joint_positions, "pixels": images} + + def _setup_spaces(self): + """ + Dynamically configure the observation and action spaces based on the robot's capabilities. + + Observation Space: + - For keys with "image": A Box space with pixel values ranging from 0 to 255. + - For non-image keys: A nested Dict space is created under 'observation.state' with a suitable range. + + Action Space: + - The action space is defined as a Box space representing joint position commands. It is defined as relative (delta) + or absolute, based on the configuration. + """ + example_obs = self._get_observation() + + observation_spaces = {} + + # Define observation spaces for images and other states. + if "pixels" in example_obs: + prefix = "observation.images" if len(example_obs["pixels"]) > 1 else "observation.image" + observation_spaces = { + f"{prefix}.{key}": gym.spaces.Box( + low=0, high=255, shape=example_obs["pixels"][key].shape, dtype=np.uint8 + ) + for key in example_obs["pixels"] + } + + observation_spaces["observation.state"] = gym.spaces.Box( + low=0, + high=10, + shape=example_obs["agent_pos"].shape, + dtype=np.float32, + ) + + self.observation_space = gym.spaces.Dict(observation_spaces) + + # Define the action space for joint positions along with setting an intervention flag. + action_dim = 3 + bounds = {} + bounds["min"] = -np.ones(action_dim) + bounds["max"] = np.ones(action_dim) + + if self.use_gripper: + action_dim += 1 + bounds["min"] = np.concatenate([bounds["min"], [0]]) + bounds["max"] = np.concatenate([bounds["max"], [2]]) + + self.action_space = gym.spaces.Box( + low=bounds["min"], + high=bounds["max"], + shape=(action_dim,), + dtype=np.float32, + ) + + def reset(self, seed=None, options=None) -> tuple[dict[str, np.ndarray], dict[str, Any]]: + """ + Reset the environment to its initial state. + This method resets the step counter and clears any episodic data. + + Args: + seed: A seed for random number generation to ensure reproducibility. + options: Additional options to influence the reset behavior. + + Returns: + A tuple containing: + - observation (dict): The initial sensor observation. + - info (dict): A dictionary with supplementary information, including the key "is_intervention". + """ + super().reset(seed=seed, options=options) + + self.robot.reset() + + # Capture the initial observation. + observation = self._get_observation() + + # Reset episode tracking variables. + self.current_step = 0 + self.episode_data = None + + return observation, {"is_intervention": False} + + def step(self, action) -> tuple[dict[str, np.ndarray], float, bool, bool, dict[str, Any]]: + """ + Execute a single step within the environment using the specified action. + + The provided action is processed and sent to the robot as joint position commands + that may be either absolute values or deltas based on the environment configuration. + + Args: + action: The commanded joint positions as a numpy array or torch tensor. + + Returns: + A tuple containing: + - observation (dict): The new sensor observation after taking the step. + - reward (float): The step reward (default is 0.0 within this wrapper). + - terminated (bool): True if the episode has reached a terminal state. + - truncated (bool): True if the episode was truncated (e.g., time constraints). + - info (dict): Additional debugging information including intervention status. + """ + self.current_joint_positions = self._get_observation()["agent_pos"] + + action_dict = {"delta_x": action[0], "delta_y": action[1], "delta_z": action[2]} + + # 1.0 action corresponds to no-op action + action_dict["gripper"] = action[3] if self.use_gripper else 1.0 + + self.robot.send_action(action_dict) + + if self.display_cameras: + self.render() + + self.current_step += 1 + + reward = 0.0 + terminated = False + truncated = False + + return ( + self._get_observation(), + reward, + terminated, + truncated, + {"is_intervention": False}, + ) + + def render(self): + """ + Render the current state of the environment by displaying the robot's camera feeds. + """ + import cv2 + + observation = self._get_observation() + image_keys = [key for key in observation if "image" in key] + + for key in image_keys: + cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) + cv2.waitKey(1) + + def close(self): + """ + Close the environment and clean up resources by disconnecting the robot. + + If the robot is currently connected, this method properly terminates the connection to ensure that all + associated resources are released. + """ + if self.robot.is_connected: + self.robot.disconnect() + + +class AddJointVelocityToObservation(gym.ObservationWrapper): + """ + Wrapper that adds joint velocity information to the observation. + + This wrapper computes joint velocities by tracking changes in joint positions over time, + and extends the observation space to include these velocities. + """ + + def __init__(self, env, joint_velocity_limits=100.0, fps=30, num_dof=6): + """ + Initialize the joint velocity wrapper. + + Args: + env: The environment to wrap. + joint_velocity_limits: Maximum expected joint velocity for space bounds. + fps: Frames per second used to calculate velocity (position delta / time). + num_dof: Number of degrees of freedom (joints) in the robot. + """ + super().__init__(env) + + # Extend observation space to include joint velocities + old_low = self.observation_space["observation.state"].low + old_high = self.observation_space["observation.state"].high + old_shape = self.observation_space["observation.state"].shape + + self.last_joint_positions = np.zeros(num_dof) + + new_low = np.concatenate([old_low, np.ones(num_dof) * -joint_velocity_limits]) + new_high = np.concatenate([old_high, np.ones(num_dof) * joint_velocity_limits]) + + new_shape = (old_shape[0] + num_dof,) + + self.observation_space["observation.state"] = gym.spaces.Box( + low=new_low, + high=new_high, + shape=new_shape, + dtype=np.float32, + ) + + self.dt = 1.0 / fps + + def observation(self, observation): + """ + Add joint velocity information to the observation. + + Args: + observation: The original observation from the environment. + + Returns: + The modified observation with joint velocities. + """ + joint_velocities = (observation["agent_pos"] - self.last_joint_positions) / self.dt + self.last_joint_positions = observation["agent_pos"] + observation["agent_pos"] = np.concatenate([observation["agent_pos"], joint_velocities], axis=-1) + return observation + + +class AddCurrentToObservation(gym.ObservationWrapper): + """ + Wrapper that adds motor current information to the observation. + + This wrapper extends the observation space to include the current values + from each motor, providing information about the forces being applied. + """ + + def __init__(self, env, max_current=500, num_dof=6): + """ + Initialize the current observation wrapper. + + Args: + env: The environment to wrap. + max_current: Maximum expected current for space bounds. + num_dof: Number of degrees of freedom (joints) in the robot. + """ + super().__init__(env) + + # Extend observation space to include joint velocities + old_low = self.observation_space["observation.state"].low + old_high = self.observation_space["observation.state"].high + old_shape = self.observation_space["observation.state"].shape + + new_low = np.concatenate([old_low, np.zeros(num_dof)]) + new_high = np.concatenate([old_high, np.ones(num_dof) * max_current]) + + new_shape = (old_shape[0] + num_dof,) + + self.observation_space["observation.state"] = gym.spaces.Box( + low=new_low, + high=new_high, + shape=new_shape, + dtype=np.float32, + ) + + def observation(self, observation): + """ + Add current information to the observation. + + Args: + observation: The original observation from the environment. + + Returns: + The modified observation with current values. + """ + present_current_observation = self.unwrapped._get_observation()["agent_pos"] + observation["agent_pos"] = np.concatenate( + [observation["agent_pos"], present_current_observation], axis=-1 + ) + return observation + + +class RewardWrapper(gym.Wrapper): + def __init__(self, env, reward_classifier, device="cuda"): + """ + Wrapper to add reward prediction to the environment using a trained classifier. + + Args: + env: The environment to wrap. + reward_classifier: The reward classifier model. + device: The device to run the model on. + """ + self.env = env + + self.device = device + + self.reward_classifier = torch.compile(reward_classifier) + self.reward_classifier.to(self.device) + + def step(self, action): + """ + Execute a step and compute the reward using the classifier. + + Args: + action: The action to take in the environment. + + Returns: + Tuple of (observation, reward, terminated, truncated, info). + """ + observation, _, terminated, truncated, info = self.env.step(action) + + images = {} + for key in observation: + if "image" in key: + images[key] = observation[key].to(self.device, non_blocking=(self.device == "cuda")) + if images[key].dim() == 3: + images[key] = images[key].unsqueeze(0) + + start_time = time.perf_counter() + with torch.inference_mode(): + success = ( + self.reward_classifier.predict_reward(images, threshold=0.7) + if self.reward_classifier is not None + else 0.0 + ) + info["Reward classifier frequency"] = 1 / (time.perf_counter() - start_time) + + reward = 0.0 + if success == 1.0: + terminated = True + reward = 1.0 + + return observation, reward, terminated, truncated, info + + def reset(self, seed=None, options=None): + """ + Reset the environment. + + Args: + seed: Random seed for reproducibility. + options: Additional reset options. + + Returns: + The initial observation and info from the wrapped environment. + """ + return self.env.reset(seed=seed, options=options) + + +class TimeLimitWrapper(gym.Wrapper): + """ + Wrapper that adds a time limit to episodes and tracks execution time. + + This wrapper terminates episodes after a specified time has elapsed, providing + better control over episode length. + """ + + def __init__(self, env, control_time_s, fps): + """ + Initialize the time limit wrapper. + + Args: + env: The environment to wrap. + control_time_s: Maximum episode duration in seconds. + fps: Frames per second for calculating the maximum number of steps. + """ + self.env = env + self.control_time_s = control_time_s + self.fps = fps + + self.last_timestamp = 0.0 + self.episode_time_in_s = 0.0 + + self.max_episode_steps = int(self.control_time_s * self.fps) + + self.current_step = 0 + + def step(self, action): + """ + Step the environment and track time elapsed. + + Args: + action: The action to take in the environment. + + Returns: + Tuple of (observation, reward, terminated, truncated, info). + """ + obs, reward, terminated, truncated, info = self.env.step(action) + time_since_last_step = time.perf_counter() - self.last_timestamp + self.episode_time_in_s += time_since_last_step + self.last_timestamp = time.perf_counter() + self.current_step += 1 + # check if last timestep took more time than the expected fps + if 1.0 / time_since_last_step < self.fps: + logging.debug(f"Current timestep exceeded expected fps {self.fps}") + + if self.current_step >= self.max_episode_steps: + terminated = True + return obs, reward, terminated, truncated, info + + def reset(self, seed=None, options=None): + """ + Reset the environment and time tracking. + + Args: + seed: Random seed for reproducibility. + options: Additional reset options. + + Returns: + The initial observation and info from the wrapped environment. + """ + self.episode_time_in_s = 0.0 + self.last_timestamp = time.perf_counter() + self.current_step = 0 + return self.env.reset(seed=seed, options=options) + + +class ImageCropResizeWrapper(gym.Wrapper): + """ + Wrapper that crops and resizes image observations. + + This wrapper processes image observations to focus on relevant regions by + cropping and then resizing to a standard size. + """ + + def __init__( + self, + env, + crop_params_dict: dict[str, Annotated[tuple[int], 4]], + resize_size=None, + ): + """ + Initialize the image crop and resize wrapper. + + Args: + env: The environment to wrap. + crop_params_dict: Dictionary mapping image observation keys to crop parameters + (top, left, height, width). + resize_size: Target size for resized images (height, width). Defaults to (128, 128). + """ + super().__init__(env) + self.env = env + self.crop_params_dict = crop_params_dict + print(f"obs_keys , {self.env.observation_space}") + print(f"crop params dict {crop_params_dict.keys()}") + for key_crop in crop_params_dict: + if key_crop not in self.env.observation_space.keys(): # noqa: SIM118 + raise ValueError(f"Key {key_crop} not in observation space") + for key in crop_params_dict: + new_shape = (3, resize_size[0], resize_size[1]) + self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape) + + self.resize_size = resize_size + if self.resize_size is None: + self.resize_size = (128, 128) + + def step(self, action): + """ + Step the environment and process image observations. + + Args: + action: The action to take in the environment. + + Returns: + Tuple of (observation, reward, terminated, truncated, info) with processed images. + """ + obs, reward, terminated, truncated, info = self.env.step(action) + for k in self.crop_params_dict: + device = obs[k].device + if obs[k].dim() >= 3: + # Reshape to combine height and width dimensions for easier calculation + batch_size = obs[k].size(0) + channels = obs[k].size(1) + flattened_spatial_dims = obs[k].view(batch_size, channels, -1) + + # Calculate standard deviation across spatial dimensions (H, W) + # If any channel has std=0, all pixels in that channel have the same value + # This is helpful if one camera mistakenly covered or the image is black + std_per_channel = torch.std(flattened_spatial_dims, dim=2) + if (std_per_channel <= 0.02).any(): + logging.warning( + f"Potential hardware issue detected: All pixels have the same value in observation {k}" + ) + + if device == torch.device("mps:0"): + obs[k] = obs[k].cpu() + + obs[k] = F.crop(obs[k], *self.crop_params_dict[k]) + obs[k] = F.resize(obs[k], self.resize_size) + # TODO (michel-aractingi): Bug in resize, it returns values outside [0, 1] + obs[k] = obs[k].clamp(0.0, 1.0) + obs[k] = obs[k].to(device) + + return obs, reward, terminated, truncated, info + + def reset(self, seed=None, options=None): + """ + Reset the environment and process image observations. + + Args: + seed: Random seed for reproducibility. + options: Additional reset options. + + Returns: + Tuple of (observation, info) with processed images. + """ + obs, info = self.env.reset(seed=seed, options=options) + for k in self.crop_params_dict: + device = obs[k].device + if device == torch.device("mps:0"): + obs[k] = obs[k].cpu() + obs[k] = F.crop(obs[k], *self.crop_params_dict[k]) + obs[k] = F.resize(obs[k], self.resize_size) + obs[k] = obs[k].clamp(0.0, 1.0) + obs[k] = obs[k].to(device) + return obs, info + + +class ConvertToLeRobotObservation(gym.ObservationWrapper): + """ + Wrapper that converts standard observations to LeRobot format. + + This wrapper processes observations to match the expected format for LeRobot, + including normalizing image values and moving tensors to the specified device. + """ + + def __init__(self, env, device: str = "cpu"): + """ + Initialize the LeRobot observation converter. + + Args: + env: The environment to wrap. + device: Target device for the observation tensors. + """ + super().__init__(env) + + self.device = torch.device(device) + + def observation(self, observation): + """ + Convert observations to LeRobot format. + + Args: + observation: The original observation from the environment. + + Returns: + The processed observation with normalized images and proper tensor formats. + """ + observation = preprocess_observation(observation) + observation = { + key: observation[key].to(self.device, non_blocking=self.device.type == "cuda") + for key in observation + } + return observation + + +class ResetWrapper(gym.Wrapper): + """ + Wrapper that handles environment reset procedures. + + This wrapper provides additional functionality during environment reset, + including the option to reset to a fixed pose or allow manual reset. + """ + + def __init__( + self, + env: RobotEnv, + reset_pose: np.ndarray | None = None, + reset_time_s: float = 5, + ): + """ + Initialize the reset wrapper. + + Args: + env: The environment to wrap. + reset_pose: Fixed joint positions to reset to. If None, manual reset is used. + reset_time_s: Time in seconds to wait after reset or allowed for manual reset. + """ + super().__init__(env) + self.reset_time_s = reset_time_s + self.reset_pose = reset_pose + self.robot = self.unwrapped.robot + + def reset(self, *, seed=None, options=None): + """ + Reset the environment with either fixed or manual reset procedure. + + If reset_pose is provided, the robot will move to that position. + Otherwise, manual teleoperation control is allowed for reset_time_s seconds. + + Args: + seed: Random seed for reproducibility. + options: Additional reset options. + + Returns: + The initial observation and info from the wrapped environment. + """ + start_time = time.perf_counter() + if self.reset_pose is not None: + log_say("Reset the environment.", play_sounds=True) + reset_follower_position(self.unwrapped.robot, self.reset_pose) + log_say("Reset the environment done.", play_sounds=True) + + if hasattr(self.env, "robot_leader"): + self.env.robot_leader.bus.sync_write("Torque_Enable", 1) + log_say("Reset the leader robot.", play_sounds=True) + reset_follower_position(self.env.robot_leader, self.reset_pose) + log_say("Reset the leader robot done.", play_sounds=True) + else: + log_say( + f"Manually reset the environment for {self.reset_time_s} seconds.", + play_sounds=True, + ) + start_time = time.perf_counter() + while time.perf_counter() - start_time < self.reset_time_s: + action = self.env.robot_leader.get_action() + self.unwrapped.robot.send_action(action) + + log_say("Manual reset of the environment done.", play_sounds=True) + + busy_wait(self.reset_time_s - (time.perf_counter() - start_time)) + + return super().reset(seed=seed, options=options) + + +class BatchCompatibleWrapper(gym.ObservationWrapper): + """ + Wrapper that ensures observations are compatible with batch processing. + + This wrapper adds a batch dimension to observations that don't already have one, + making them compatible with models that expect batched inputs. + """ + + def __init__(self, env): + """ + Initialize the batch compatibility wrapper. + + Args: + env: The environment to wrap. + """ + super().__init__(env) + + def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Add batch dimensions to observations if needed. + + Args: + observation: Dictionary of observation tensors. + + Returns: + Dictionary of observation tensors with batch dimensions. + """ + for key in observation: + if "image" in key and observation[key].dim() == 3: + observation[key] = observation[key].unsqueeze(0) + if "state" in key and observation[key].dim() == 1: + observation[key] = observation[key].unsqueeze(0) + if "velocity" in key and observation[key].dim() == 1: + observation[key] = observation[key].unsqueeze(0) + return observation + + +class GripperPenaltyWrapper(gym.RewardWrapper): + """ + Wrapper that adds penalties for inefficient gripper commands. + + This wrapper modifies rewards to discourage excessive gripper movement + or commands that attempt to move the gripper beyond its physical limits. + """ + + def __init__(self, env, penalty: float = -0.1): + """ + Initialize the gripper penalty wrapper. + + Args: + env: The environment to wrap. + penalty: Negative reward value to apply for inefficient gripper actions. + """ + super().__init__(env) + self.penalty = penalty + self.last_gripper_state = None + + def reward(self, reward, action): + """ + Apply penalties to reward based on gripper actions. + + Args: + reward: The original reward from the environment. + action: The action that was taken. + + Returns: + Modified reward with penalty applied if necessary. + """ + gripper_state_normalized = self.last_gripper_state / self.unwrapped.robot.config.max_gripper_pos + + action_normalized = action - 1.0 # action / MAX_GRIPPER_COMMAND + + gripper_penalty_bool = (gripper_state_normalized < 0.5 and action_normalized > 0.5) or ( + gripper_state_normalized > 0.75 and action_normalized < -0.5 + ) + + return reward + self.penalty * int(gripper_penalty_bool) + + def step(self, action): + """ + Step the environment and apply gripper penalties. + + Args: + action: The action to take in the environment. + + Returns: + Tuple of (observation, reward, terminated, truncated, info) with penalty applied. + """ + self.last_gripper_state = self.unwrapped.robot.bus.sync_read("Present_Position")["gripper"] + + gripper_action = action[-1] + obs, reward, terminated, truncated, info = self.env.step(action) + gripper_penalty = self.reward(reward, gripper_action) + + info["discrete_penalty"] = gripper_penalty + + return obs, reward, terminated, truncated, info + + def reset(self, **kwargs): + """ + Reset the environment and penalty tracking. + + Args: + **kwargs: Keyword arguments passed to the wrapped environment's reset. + + Returns: + The initial observation and info with gripper penalty initialized. + """ + self.last_gripper_state = None + obs, info = super().reset(**kwargs) + info["gripper_penalty"] = 0.0 + return obs, info + + +class GripperActionWrapper(gym.ActionWrapper): + """ + Wrapper that processes gripper control commands. + + This wrapper quantizes and processes gripper commands, adding a sleep time between + consecutive gripper actions to prevent rapid toggling. + """ + + def __init__(self, env, quantization_threshold: float = 0.2, gripper_sleep: float = 0.0): + """ + Initialize the gripper action wrapper. + + Args: + env: The environment to wrap. + quantization_threshold: Threshold below which gripper commands are quantized to zero. + gripper_sleep: Minimum time in seconds between consecutive gripper commands. + """ + super().__init__(env) + self.quantization_threshold = quantization_threshold + self.gripper_sleep = gripper_sleep + self.last_gripper_action_time = 0.0 + self.last_gripper_action = None + + def action(self, action): + """ + Process gripper commands in the action. + + Args: + action: The original action from the agent. + + Returns: + Modified action with processed gripper command. + """ + if self.gripper_sleep > 0.0: + if ( + self.last_gripper_action is not None + and time.perf_counter() - self.last_gripper_action_time < self.gripper_sleep + ): + action[-1] = self.last_gripper_action + else: + self.last_gripper_action_time = time.perf_counter() + self.last_gripper_action = action[-1] + + gripper_command = action[-1] + # Gripper actions are between 0, 2 + # we want to quantize them to -1, 0 or 1 + gripper_command = gripper_command - 1.0 + + if self.quantization_threshold is not None: + # Quantize gripper command to -1, 0 or 1 + gripper_command = ( + np.sign(gripper_command) if abs(gripper_command) > self.quantization_threshold else 0.0 + ) + gripper_command = gripper_command * self.unwrapped.robot.config.max_gripper_pos + + gripper_state = self.unwrapped.robot.bus.sync_read("Present_Position")["gripper"] + + gripper_action_value = np.clip( + gripper_state + gripper_command, 0, self.unwrapped.robot.config.max_gripper_pos + ) + action[-1] = gripper_action_value.item() + return action + + def reset(self, **kwargs): + """ + Reset the gripper action tracking. + + Args: + **kwargs: Keyword arguments passed to the wrapped environment's reset. + + Returns: + The initial observation and info. + """ + obs, info = super().reset(**kwargs) + self.last_gripper_action_time = 0.0 + self.last_gripper_action = None + return obs, info + + +class EEObservationWrapper(gym.ObservationWrapper): + """ + Wrapper that adds end-effector pose information to observations. + + This wrapper computes the end-effector pose using forward kinematics + and adds it to the observation space. + """ + + def __init__(self, env, ee_pose_limits): + """ + Initialize the end-effector observation wrapper. + + Args: + env: The environment to wrap. + ee_pose_limits: Dictionary with 'min' and 'max' keys containing limits for EE pose. + """ + super().__init__(env) + + # Extend observation space to include end effector pose + prev_space = self.observation_space["observation.state"] + + self.observation_space["observation.state"] = gym.spaces.Box( + low=np.concatenate([prev_space.low, ee_pose_limits["min"]]), + high=np.concatenate([prev_space.high, ee_pose_limits["max"]]), + shape=(prev_space.shape[0] + 3,), + dtype=np.float32, + ) + + # Initialize kinematics instance for the appropriate robot type + robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so101") + if "so100" in robot_type or "so101" in robot_type: + # Note to be compatible with the rest of the codebase, + # we are using the new calibration method for so101 and so100 + robot_type = "so_new_calibration" + self.kinematics = RobotKinematics(robot_type) + + def observation(self, observation): + """ + Add end-effector pose to the observation. + + Args: + observation: Original observation from the environment. + + Returns: + Enhanced observation with end-effector pose information. + """ + current_joint_pos = self.unwrapped._get_observation()["agent_pos"] + + current_ee_pos = self.kinematics.forward_kinematics(current_joint_pos, frame="gripper_tip")[:3, 3] + observation["agent_pos"] = np.concatenate([observation["agent_pos"], current_ee_pos], -1) + return observation + + +########################################################### +# Wrappers related to human intervention and input devices +########################################################### + + +class BaseLeaderControlWrapper(gym.Wrapper): + """ + Base class for leader-follower robot control wrappers. + + This wrapper enables human intervention through a leader-follower robot setup, + where the human can control a leader robot to guide the follower robot's movements. + """ + + def __init__( + self, + env, + teleop_device, + end_effector_step_sizes, + use_geared_leader_arm: bool = False, + use_gripper=False, + ): + """ + Initialize the base leader control wrapper. + + Args: + env: The environment to wrap. + teleop_device: The teleoperation device. + use_geared_leader_arm: Whether to use a geared leader arm setup. + use_gripper: Whether to include gripper control. + """ + super().__init__(env) + self.robot_leader = teleop_device + self.robot_follower = env.unwrapped.robot + self.use_geared_leader_arm = use_geared_leader_arm + self.use_gripper: bool = use_gripper + self.end_effector_step_sizes = np.array(list(end_effector_step_sizes.values())) + + # Set up keyboard event tracking + self._init_keyboard_events() + self.event_lock = Lock() # Thread-safe access to events + + # Initialize robot control + robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so101") + if "so100" in robot_type or "so101" in robot_type: + # Note to be compatible with the rest of the codebase, + # we are using the new calibration method for so101 and so100 + robot_type = "so_new_calibration" + self.kinematics = RobotKinematics(robot_type) + self.leader_torque_enabled = True + self.prev_leader_gripper = None + + # Configure leader arm + # NOTE: Lower the gains of leader arm for automatic take-over + # With lower gains we can manually move the leader arm without risk of injury to ourselves or the robot + # With higher gains, it would be dangerous and difficult to modify the leader's pose while torque is enabled + # Default value for P_coeff is 32 + self.robot_leader.bus.sync_write("Torque_Enable", 1) + for motor in self.robot_leader.bus.motors: + self.robot_leader.bus.write("P_Coefficient", motor, 16) + self.robot_leader.bus.write("I_Coefficient", motor, 0) + self.robot_leader.bus.write("D_Coefficient", motor, 16) + + self.leader_tracking_error_queue = deque(maxlen=4) + self._init_keyboard_listener() + + def _init_keyboard_events(self): + """ + Initialize the keyboard events dictionary. + + This method sets up tracking for keyboard events used for intervention control. + It should be overridden in subclasses to add additional events. + """ + self.keyboard_events = { + "episode_success": False, + "episode_end": False, + "rerecord_episode": False, + } + + def _handle_key_press(self, key, keyboard): + """ + Handle key press events. + + Args: + key: The key that was pressed. + keyboard: The keyboard module with key definitions. + + This method should be overridden in subclasses for additional key handling. + """ + try: + if key == keyboard.Key.esc: + self.keyboard_events["episode_end"] = True + return + if key == keyboard.Key.left: + self.keyboard_events["rerecord_episode"] = True + return + if hasattr(key, "char") and key.char == "s": + logging.info("Key 's' pressed. Episode success triggered.") + self.keyboard_events["episode_success"] = True + return + except Exception as e: + logging.error(f"Error handling key press: {e}") + + def _init_keyboard_listener(self): + """ + Initialize the keyboard listener for intervention control. + + This method sets up keyboard event handling if not in headless mode. + """ + from pynput import keyboard + + def on_press(key): + with self.event_lock: + self._handle_key_press(key, keyboard) + + self.listener = keyboard.Listener(on_press=on_press) + self.listener.start() + + def _check_intervention(self): + """ + Check if human intervention is needed. + + Returns: + Boolean indicating whether intervention is needed. + + This method should be overridden in subclasses with specific intervention logic. + """ + return False + + def _handle_intervention(self, action): + """ + Process actions during intervention mode. + + Args: + action: The original action from the agent. + + Returns: + Tuple of (modified_action, intervention_action). + """ + if self.leader_torque_enabled: + self.robot_leader.bus.sync_write("Torque_Enable", 0) + self.leader_torque_enabled = False + + leader_pos_dict = self.robot_leader.bus.sync_read("Present_Position") + follower_pos_dict = self.robot_follower.bus.sync_read("Present_Position") + + leader_pos = np.array([leader_pos_dict[name] for name in leader_pos_dict], dtype=np.float32) + follower_pos = np.array([follower_pos_dict[name] for name in follower_pos_dict], dtype=np.float32) + + self.leader_tracking_error_queue.append(np.linalg.norm(follower_pos[:-1] - leader_pos[:-1])) + + # [:3, 3] Last column of the transformation matrix corresponds to the xyz translation + leader_ee = self.kinematics.forward_kinematics(leader_pos, frame="gripper_tip")[:3, 3] + follower_ee = self.kinematics.forward_kinematics(follower_pos, frame="gripper_tip")[:3, 3] + + action = np.clip(leader_ee - follower_ee, -self.end_effector_step_sizes, self.end_effector_step_sizes) + # Normalize the action to the range [-1, 1] + action = action / self.end_effector_step_sizes + + if self.use_gripper: + if self.prev_leader_gripper is None: + self.prev_leader_gripper = np.clip( + leader_pos[-1], 0, self.robot_follower.config.max_gripper_pos + ) + + # Get gripper action delta based on leader pose + leader_gripper = leader_pos[-1] + gripper_delta = leader_gripper - self.prev_leader_gripper + + # Normalize by max angle and quantize to {0,1,2} + normalized_delta = gripper_delta / self.robot_follower.config.max_gripper_pos + if normalized_delta >= 0.3: + gripper_action = 2 + elif normalized_delta <= 0.1: + gripper_action = 0 + else: + gripper_action = 1 + + action = np.append(action, gripper_action) + + return action + + def _handle_leader_teleoperation(self): + """ + Handle leader teleoperation in non-intervention mode. + + This method synchronizes the leader robot position with the follower. + """ + + prev_leader_pos_dict = self.robot_leader.bus.sync_read("Present_Position") + prev_leader_pos = np.array( + [prev_leader_pos_dict[name] for name in prev_leader_pos_dict], dtype=np.float32 + ) + + if not self.leader_torque_enabled: + self.robot_leader.bus.sync_write("Torque_Enable", 1) + self.leader_torque_enabled = True + + follower_pos_dict = self.robot_follower.bus.sync_read("Present_Position") + follower_pos = np.array([follower_pos_dict[name] for name in follower_pos_dict], dtype=np.float32) + + goal_pos = {f"{motor}": follower_pos[i] for i, motor in enumerate(self.robot_leader.bus.motors)} + self.robot_leader.bus.sync_write("Goal_Position", goal_pos) + + self.leader_tracking_error_queue.append(np.linalg.norm(follower_pos[:-1] - prev_leader_pos[:-1])) + + def step(self, action): + """ + Execute a step with possible human intervention. + + Args: + action: The action to take in the environment. + + Returns: + Tuple of (observation, reward, terminated, truncated, info). + """ + is_intervention = self._check_intervention() + + # NOTE: + if is_intervention: + action = self._handle_intervention(action) + else: + self._handle_leader_teleoperation() + + # NOTE: + obs, reward, terminated, truncated, info = self.env.step(action) + + # Add intervention info + info["is_intervention"] = is_intervention + info["action_intervention"] = action if is_intervention else None + + self.prev_leader_gripper = np.clip( + self.robot_leader.bus.sync_read("Present_Position")["gripper"], + 0, + self.robot_follower.config.max_gripper_pos, + ) + + # Check for success or manual termination + success = self.keyboard_events["episode_success"] + terminated = terminated or self.keyboard_events["episode_end"] or success + + if success: + reward = 1.0 + logging.info("Episode ended successfully with reward 1.0") + + return obs, reward, terminated, truncated, info + + def reset(self, **kwargs): + """ + Reset the environment and intervention state. + + Args: + **kwargs: Keyword arguments passed to the wrapped environment's reset. + + Returns: + The initial observation and info. + """ + self.keyboard_events = dict.fromkeys(self.keyboard_events, False) + self.leader_tracking_error_queue.clear() + return super().reset(**kwargs) + + def close(self): + """ + Clean up resources, including stopping keyboard listener. + + Returns: + Result of closing the wrapped environment. + """ + if hasattr(self, "listener") and self.listener is not None: + self.listener.stop() + return self.env.close() + + +class GearedLeaderControlWrapper(BaseLeaderControlWrapper): + """ + Wrapper that enables manual intervention via keyboard. + + This wrapper extends the BaseLeaderControlWrapper to allow explicit toggling + of human intervention mode with keyboard controls. + """ + + def _init_keyboard_events(self): + """ + Initialize keyboard events including human intervention flag. + + Extends the base class dictionary with an additional flag for tracking + intervention state toggled by keyboard. + """ + super()._init_keyboard_events() + self.keyboard_events["human_intervention_step"] = False + + def _handle_key_press(self, key, keyboard): + """ + Handle key presses including space for intervention toggle. + + Args: + key: The key that was pressed. + keyboard: The keyboard module with key definitions. + + Extends the base handler to respond to space key for toggling intervention. + """ + super()._handle_key_press(key, keyboard) + if key == keyboard.Key.space: + if not self.keyboard_events["human_intervention_step"]: + logging.info( + "Space key pressed. Human intervention required.\n" + "Place the leader in similar pose to the follower and press space again." + ) + self.keyboard_events["human_intervention_step"] = True + log_say("Human intervention step.", play_sounds=True) + else: + self.keyboard_events["human_intervention_step"] = False + logging.info("Space key pressed for a second time.\nContinuing with policy actions.") + log_say("Continuing with policy actions.", play_sounds=True) + + def _check_intervention(self): + """ + Check if human intervention is active based on keyboard toggle. + + Returns: + Boolean indicating whether intervention mode is active. + """ + return self.keyboard_events["human_intervention_step"] + + +class GearedLeaderAutomaticControlWrapper(BaseLeaderControlWrapper): + """ + Wrapper with automatic intervention based on error thresholds. + + This wrapper monitors the error between leader and follower positions + and automatically triggers intervention when error exceeds thresholds. + """ + + def __init__( + self, + env, + teleop_device, + end_effector_step_sizes, + use_gripper=False, + intervention_threshold=10.0, + release_threshold=1e-2, + ): + """ + Initialize the automatic intervention wrapper. + + Args: + env: The environment to wrap. + teleop_device: The teleoperation device. + use_gripper: Whether to include gripper control. + intervention_threshold: Error threshold to trigger intervention. + release_threshold: Error threshold to release intervention. + queue_size: Number of error measurements to track for smoothing. + """ + super().__init__(env, teleop_device, end_effector_step_sizes, use_gripper=use_gripper) + + # Error tracking parameters + self.intervention_threshold = intervention_threshold # Threshold to trigger intervention + self.release_threshold = release_threshold # Threshold to release intervention + self.is_intervention_active = False + self.start_time = time.perf_counter() + + def _check_intervention(self): + """ + Determine if intervention should occur based on the rate of change of leader-follower error in end_effector space. + + This method monitors the rate of change of leader-follower error in end_effector space + and automatically triggers intervention when the rate of change exceeds + the intervention threshold, releasing when it falls below the release threshold. + + Returns: + Boolean indicating whether intervention should be active. + """ + + # Condition for starting the intervention + # If the error in teleoperation is too high, that means the a user has grasped the leader robot and he wants to take over + if ( + not self.is_intervention_active + and len(self.leader_tracking_error_queue) == self.leader_tracking_error_queue.maxlen + and np.var(list(self.leader_tracking_error_queue)[-2:]) > self.intervention_threshold + ): + self.is_intervention_active = True + self.leader_tracking_error_queue.clear() + log_say("Intervention started", play_sounds=True) + return True + + # Track the error over time in leader_tracking_error_queue + # If the variance of the tracking error is too low, that means the user has let go of the leader robot and the intervention is over + if ( + self.is_intervention_active + and len(self.leader_tracking_error_queue) == self.leader_tracking_error_queue.maxlen + and np.var(self.leader_tracking_error_queue) < self.release_threshold + ): + self.is_intervention_active = False + self.leader_tracking_error_queue.clear() + log_say("Intervention ended", play_sounds=True) + return False + + # If not change has happened that merits a change in the intervention state, return the current state + return self.is_intervention_active + + def reset(self, **kwargs): + """ + Reset error tracking on environment reset. + + Args: + **kwargs: Keyword arguments passed to the wrapped environment's reset. + + Returns: + The initial observation and info. + """ + self.is_intervention_active = False + return super().reset(**kwargs) + + +class GamepadControlWrapper(gym.Wrapper): + """ + Wrapper that allows controlling a gym environment with a gamepad. + + This wrapper intercepts the step method and allows human input via gamepad + to override the agent's actions when desired. + """ + + def __init__( + self, + env, + teleop_device, # Accepts an instantiated teleoperator + use_gripper=False, # This should align with teleop_device's config + auto_reset=False, + ): + """ + Initialize the gamepad controller wrapper. + + Args: + env: The environment to wrap. + teleop_device: The instantiated teleoperation device (e.g., GamepadTeleop). + use_gripper: Whether to include gripper control (should match teleop_device.config.use_gripper). + auto_reset: Whether to auto reset the environment when episode ends. + """ + super().__init__(env) + + self.teleop_device = teleop_device + # Ensure the teleop_device is connected if it has a connect method + if hasattr(self.teleop_device, "connect") and not self.teleop_device.is_connected: + self.teleop_device.connect() + + # self.controller attribute is removed + + self.auto_reset = auto_reset + # use_gripper from args should ideally match teleop_device.config.use_gripper + # For now, we use the one passed, but it can lead to inconsistency if not set correctly from config + self.use_gripper = use_gripper + + logging.info("Gamepad control wrapper initialized with provided teleop_device.") + print( + "Gamepad controls (managed by the provided teleop_device - specific button mappings might vary):" + ) + print(" Left analog stick: Move in X-Y plane") + print(" Right analog stick: Move in Z axis (up/down)") + print(" X/Square button: End episode (FAILURE)") + print(" Y/Triangle button: End episode (SUCCESS)") + print(" B/Circle button: Exit program") + + def get_gamepad_action( + self, + ) -> tuple[bool, np.ndarray, bool, bool, bool]: + """ + Get the current action from the gamepad if any input is active. + + Returns: + Tuple containing: + - is_active: Whether gamepad input is active (from teleop_device.gamepad.should_intervene()) + - action: The action derived from gamepad input (from teleop_device.get_action()) + - terminate_episode: Whether episode termination was requested + - success: Whether episode success was signaled + - rerecord_episode: Whether episode rerecording was requested + """ + if not hasattr(self.teleop_device, "gamepad") or self.teleop_device.gamepad is None: + raise AttributeError( + "teleop_device does not have a 'gamepad' attribute or it is None. Expected for GamepadControlWrapper." + ) + + # Get status flags from the underlying gamepad controller within the teleop_device + self.teleop_device.gamepad.update() # Ensure gamepad state is fresh + intervention_is_active = self.teleop_device.gamepad.should_intervene() + episode_end_status = self.teleop_device.gamepad.get_episode_end_status() + + terminate_episode = episode_end_status is not None + success = episode_end_status == "success" + rerecord_episode = episode_end_status == "rerecord_episode" + + # Get the action dictionary from the teleop_device + action_dict = self.teleop_device.get_action() + + # Convert action_dict to numpy array based on expected structure + # Order: delta_x, delta_y, delta_z, gripper (if use_gripper) + action_list = [action_dict["delta_x"], action_dict["delta_y"], action_dict["delta_z"]] + if self.use_gripper: + # GamepadTeleop returns gripper action as 0 (close), 1 (stay), 2 (open) + # This needs to be consistent with what EEActionWrapper expects if it's used downstream + # EEActionWrapper for gripper typically expects 0.0 (closed) to 2.0 (open) + # For now, we pass the direct value from GamepadTeleop, ensure downstream compatibility. + gripper_val = action_dict.get("gripper", 1.0) # Default to 1.0 (stay) if not present + action_list.append(float(gripper_val)) + + gamepad_action_np = np.array(action_list, dtype=np.float32) + + return ( + intervention_is_active, + gamepad_action_np, + terminate_episode, + success, + rerecord_episode, + ) + + def step(self, action): + """ + Step the environment, using gamepad input to override actions when active. + + Args: + action: Original action from agent. + + Returns: + Tuple of (observation, reward, terminated, truncated, info). + """ + # Get gamepad state and action + ( + is_intervention, + gamepad_action, + terminate_episode, + success, + rerecord_episode, + ) = self.get_gamepad_action() + + # Update episode ending state if requested + if terminate_episode: + logging.info(f"Episode manually ended: {'SUCCESS' if success else 'FAILURE'}") + + # Only override the action if gamepad is active + action = gamepad_action if is_intervention else action + + # Step the environment + obs, reward, terminated, truncated, info = self.env.step(action) + + # Add episode ending if requested via gamepad + terminated = terminated or truncated or terminate_episode + + if success: + reward = 1.0 + logging.info("Episode ended successfully with reward 1.0") + + if isinstance(action, np.ndarray): + action = torch.from_numpy(action) + + info["is_intervention"] = is_intervention + # The original `BaseLeaderControlWrapper` puts `action_intervention` in info. + # For Gamepad, if intervention, `gamepad_action` is the intervention. + # If not intervention, policy's action is `action`. + # For consistency, let's store the *human's* action if intervention occurred. + info["action_intervention"] = action + + info["rerecord_episode"] = rerecord_episode + + # If episode ended, reset the state + if terminated or truncated: + # Add success/failure information to info dict + info["next.success"] = success + + # Auto reset if configured + if self.auto_reset: + obs, reset_info = self.reset() + info.update(reset_info) + + return obs, reward, terminated, truncated, info + + def close(self): + """ + Clean up resources when environment closes. + + Returns: + Result of closing the wrapped environment. + """ + if hasattr(self.teleop_device, "disconnect"): + self.teleop_device.disconnect() + + # Call the parent close method + return self.env.close() + + +class GymHilDeviceWrapper(gym.Wrapper): + def __init__(self, env, device="cpu"): + super().__init__(env) + self.device = device + + def step(self, action): + obs, reward, terminated, truncated, info = self.env.step(action) + for k in obs: + obs[k] = obs[k].to(self.device) + if "action_intervention" in info: + # NOTE: This is a hack to ensure the action intervention is a float32 tensor and supported on MPS device + info["action_intervention"] = info["action_intervention"].astype(np.float32) + info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device) + return obs, reward, terminated, truncated, info + + def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None): + obs, info = self.env.reset(seed=seed, options=options) + for k in obs: + obs[k] = obs[k].to(self.device) + if "action_intervention" in info: + # NOTE: This is a hack to ensure the action intervention is a float32 tensor and supported on MPS device + info["action_intervention"] = info["action_intervention"].astype(np.float32) + info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device) + return obs, info + + +class GymHilObservationProcessorWrapper(gym.ObservationWrapper): + def __init__(self, env: gym.Env): + super().__init__(env) + prev_space = self.observation_space + new_space = {} + + for key in prev_space: + if "pixels" in key: + for k in prev_space["pixels"]: + new_space[f"observation.images.{k}"] = gym.spaces.Box( + 0.0, 255.0, shape=(3, 128, 128), dtype=np.uint8 + ) + + if key == "agent_pos": + new_space["observation.state"] = prev_space["agent_pos"] + + self.observation_space = gym.spaces.Dict(new_space) + + def observation(self, observation: dict[str, Any]) -> dict[str, Any]: + return preprocess_observation(observation) + + +########################################################### +# Factory functions +########################################################### + + +def make_robot_env(cfg: EnvConfig) -> gym.Env: + """ + Factory function to create a robot environment. + + This function builds a robot environment with all necessary wrappers + based on the provided configuration. + + Args: + cfg: Configuration object containing environment parameters. + + Returns: + A gym environment with all necessary wrappers applied. + """ + if cfg.type == "hil": + import gym_hil # noqa: F401 + + # TODO (azouitine) + env = gym.make( + f"gym_hil/{cfg.task}", + image_obs=True, + render_mode="human", + use_gripper=cfg.wrapper.use_gripper, + gripper_penalty=cfg.wrapper.gripper_penalty, + ) + env = GymHilObservationProcessorWrapper(env=env) + env = GymHilDeviceWrapper(env=env, device=cfg.device) + env = BatchCompatibleWrapper(env=env) + env = TorchActionWrapper(env=env, device=cfg.device) + return env + + if not hasattr(cfg, "robot") or not hasattr(cfg, "teleop"): + raise ValueError( + "Configuration for 'gym_manipulator' must be HILSerlRobotEnvConfig with robot and teleop." + ) + + if cfg.robot is None: + raise ValueError("RobotConfig (cfg.robot) must be provided for gym_manipulator environment.") + robot = make_robot_from_config(cfg.robot) + + teleop_device = make_teleoperator_from_config(cfg.teleop) + teleop_device.connect() + + # Create base environment + env = RobotEnv( + robot=robot, + use_gripper=cfg.wrapper.use_gripper, + display_cameras=cfg.wrapper.display_cameras if cfg.wrapper else False, + ) + + # Add observation and image processing + if cfg.wrapper: + if cfg.wrapper.add_joint_velocity_to_observation: + env = AddJointVelocityToObservation(env=env, fps=cfg.fps) + if cfg.wrapper.add_current_to_observation: + env = AddCurrentToObservation(env=env) + if cfg.wrapper.add_ee_pose_to_observation: + env = EEObservationWrapper(env=env, ee_pose_limits=robot.end_effector_bounds) + + env = ConvertToLeRobotObservation(env=env, device=cfg.device) + + if cfg.wrapper and cfg.wrapper.crop_params_dict is not None: + env = ImageCropResizeWrapper( + env=env, + crop_params_dict=cfg.wrapper.crop_params_dict, + resize_size=cfg.wrapper.resize_size, + ) + + # Add reward computation and control wrappers + reward_classifier = init_reward_classifier(cfg) + if reward_classifier is not None: + env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device) + + env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps) + if cfg.wrapper.use_gripper and cfg.wrapper.gripper_penalty is not None: + env = GripperPenaltyWrapper( + env=env, + penalty=cfg.wrapper.gripper_penalty, + ) + + # Control mode specific wrappers + control_mode = cfg.wrapper.control_mode + if control_mode == "gamepad": + assert isinstance(teleop_device, GamepadTeleop), ( + "teleop_device must be an instance of GamepadTeleop for gamepad control mode" + ) + env = GamepadControlWrapper( + env=env, + teleop_device=teleop_device, + use_gripper=cfg.wrapper.use_gripper, + ) + elif control_mode == "leader": + env = GearedLeaderControlWrapper( + env=env, + teleop_device=teleop_device, + end_effector_step_sizes=cfg.robot.end_effector_step_sizes, + use_gripper=cfg.wrapper.use_gripper, + ) + elif control_mode == "leader_automatic": + env = GearedLeaderAutomaticControlWrapper( + env=env, + teleop_device=teleop_device, + end_effector_step_sizes=cfg.robot.end_effector_step_sizes, + use_gripper=cfg.wrapper.use_gripper, + ) + else: + raise ValueError(f"Invalid control mode: {control_mode}") + + env = ResetWrapper( + env=env, + reset_pose=cfg.wrapper.fixed_reset_joint_positions, + reset_time_s=cfg.wrapper.reset_time_s, + ) + + env = BatchCompatibleWrapper(env=env) + env = TorchActionWrapper(env=env, device=cfg.device) + + return env + + +def init_reward_classifier(cfg): + """ + Load a reward classifier policy from a pretrained path if configured. + + Args: + cfg: The environment configuration containing classifier paths. + + Returns: + The loaded classifier model or None if not configured. + """ + if cfg.reward_classifier_pretrained_path is None: + return None + + from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier + + # Get device from config or default to CUDA + device = getattr(cfg, "device", "cpu") + + # Load the classifier directly using from_pretrained + classifier = Classifier.from_pretrained( + pretrained_name_or_path=cfg.reward_classifier_pretrained_path, + ) + + # Ensure model is on the correct device + classifier.to(device) + classifier.eval() # Set to evaluation mode + + return classifier + + +########################################################### +# Record and replay functions +########################################################### + + +def record_dataset(env, policy, cfg): + """ + Record a dataset of robot interactions using either a policy or teleop. + + This function runs episodes in the environment and records the observations, + actions, and results for dataset creation. + + Args: + env: The environment to record from. + policy: Optional policy to generate actions (if None, uses teleop). + cfg: Configuration object containing recording parameters like: + - repo_id: Repository ID for dataset storage + - dataset_root: Local root directory for dataset + - num_episodes: Number of episodes to record + - fps: Frames per second for recording + - push_to_hub: Whether to push dataset to Hugging Face Hub + - task: Name/description of the task being recorded + - number_of_steps_after_success: Number of additional steps to continue recording after + a success (reward=1) is detected. This helps collect + more positive examples for reward classifier training. + """ + from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + + # Setup initial action (zero action if using teleop) + action = env.action_space.sample() * 0.0 + + action_names = ["delta_x_ee", "delta_y_ee", "delta_z_ee"] + if cfg.wrapper.use_gripper: + action_names.append("gripper_delta") + + # Configure dataset features based on environment spaces + features = { + "observation.state": { + "dtype": "float32", + "shape": env.observation_space["observation.state"].shape, + "names": None, + }, + "action": { + "dtype": "float32", + "shape": (len(action_names),), + "names": action_names, + }, + "next.reward": {"dtype": "float32", "shape": (1,), "names": None}, + "next.done": {"dtype": "bool", "shape": (1,), "names": None}, + "complementary_info.discrete_penalty": { + "dtype": "float32", + "shape": (1,), + "names": ["discrete_penalty"], + }, + } + + # Add image features + for key in env.observation_space: + if "image" in key: + features[key] = { + "dtype": "video", + "shape": env.observation_space[key].shape, + "names": ["channels", "height", "width"], + } + + # Create dataset + dataset = LeRobotDataset.create( + cfg.repo_id, + cfg.fps, + root=cfg.dataset_root, + use_videos=True, + image_writer_threads=4, + image_writer_processes=0, + features=features, + ) + + # Record episodes + episode_index = 0 + recorded_action = None + while episode_index < cfg.num_episodes: + obs, _ = env.reset() + start_episode_t = time.perf_counter() + log_say(f"Recording episode {episode_index}", play_sounds=True) + + # Track success state collection + success_detected = False + success_steps_collected = 0 + + # Run episode steps + while time.perf_counter() - start_episode_t < cfg.wrapper.control_time_s: + start_loop_t = time.perf_counter() + + # Get action from policy if available + if cfg.pretrained_policy_name_or_path is not None: + action = policy.select_action(obs) + + # Step environment + obs, reward, terminated, truncated, info = env.step(action) + + # Check if episode needs to be rerecorded + if info.get("rerecord_episode", False): + break + + # For teleop, get action from intervention + recorded_action = { + "action": info["action_intervention"].cpu().squeeze(0).float() if policy is None else action + } + + # Process observation for dataset + obs_processed = {k: v.cpu().squeeze(0).float() for k, v in obs.items()} + + # Check if we've just detected success + if reward == 1.0 and not success_detected: + success_detected = True + logging.info("Success detected! Collecting additional success states.") + + # Add frame to dataset - continue marking as success even during extra collection steps + frame = {**obs_processed, **recorded_action} + + # If we're in the success collection phase, keep marking rewards as 1.0 + if success_detected: + frame["next.reward"] = np.array([1.0], dtype=np.float32) + else: + frame["next.reward"] = np.array([reward], dtype=np.float32) + + # Only mark as done if we're truly done (reached end or collected enough success states) + really_done = terminated or truncated + if success_detected: + success_steps_collected += 1 + really_done = success_steps_collected >= cfg.number_of_steps_after_success + + frame["next.done"] = np.array([really_done], dtype=bool) + frame["complementary_info.discrete_penalty"] = torch.tensor( + [info.get("discrete_penalty", 0.0)], dtype=torch.float32 + ) + dataset.add_frame(frame, task=cfg.task) + + # Maintain consistent timing + if cfg.fps: + dt_s = time.perf_counter() - start_loop_t + busy_wait(1 / cfg.fps - dt_s) + + # Check if we should end the episode + if (terminated or truncated) and not success_detected: + # Regular termination without success + break + elif success_detected and success_steps_collected >= cfg.number_of_steps_after_success: + # We've collected enough success states + logging.info(f"Collected {success_steps_collected} additional success states") + break + + # Handle episode recording + if info.get("rerecord_episode", False): + dataset.clear_episode_buffer() + logging.info(f"Re-recording episode {episode_index}") + continue + + dataset.save_episode() + episode_index += 1 + + # Finalize dataset + # dataset.consolidate(run_compute_stats=True) + if cfg.push_to_hub: + dataset.push_to_hub() + + +def replay_episode(env, cfg): + """ + Replay a recorded episode in the environment. + + This function loads actions from a previously recorded episode + and executes them in the environment. + + Args: + env: The environment to replay in. + cfg: Configuration object containing replay parameters: + - repo_id: Repository ID for dataset + - dataset_root: Local root directory for dataset + - episode: Episode ID to replay + """ + from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.dataset_root, episodes=[cfg.episode]) + env.reset() + + actions = dataset.hf_dataset.select_columns("action") + + for idx in range(dataset.num_frames): + start_episode_t = time.perf_counter() + + action = actions[idx]["action"] + env.step(action) + + dt_s = time.perf_counter() - start_episode_t + busy_wait(1 / 10 - dt_s) + + +@parser.wrap() +def main(cfg: EnvConfig): + """Main entry point for the robot environment script. + + This function runs the robot environment in one of several modes + based on the provided configuration. + + Args: + cfg: Configuration object defining the run parameters, + including mode (record, replay, random) and other settings. + """ + env = make_robot_env(cfg) + + if cfg.mode == "record": + policy = None + if cfg.pretrained_policy_name_or_path is not None: + from lerobot.common.policies.sac.modeling_sac import SACPolicy + + policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path) + policy.to(cfg.device) + policy.eval() + + record_dataset( + env, + policy=policy, + cfg=cfg, + ) + exit() + + if cfg.mode == "replay": + replay_episode( + env, + cfg=cfg, + ) + exit() + + env.reset() + + # Initialize the smoothed action as a random sample. + smoothed_action = env.action_space.sample() * 0.0 + + # Smoothing coefficient (alpha) defines how much of the new random sample to mix in. + # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth. + alpha = 1.0 + + num_episode = 0 + successes = [] + while num_episode < 10: + start_loop_s = time.perf_counter() + # Sample a new random action from the robot's action space. + new_random_action = env.action_space.sample() + # Update the smoothed action using an exponential moving average. + smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action + + # Execute the step: wrap the NumPy action in a torch tensor. + obs, reward, terminated, truncated, info = env.step(smoothed_action) + if terminated or truncated: + successes.append(reward) + env.reset() + num_episode += 1 + + dt_s = time.perf_counter() - start_loop_s + busy_wait(1 / cfg.fps - dt_s) + + logging.info(f"Success after 20 steps {successes}") + logging.info(f"success rate {sum(successes) / len(successes)}") + + +if __name__ == "__main__": + main() diff --git a/lerobot/scripts/rl/learner.py b/lerobot/scripts/rl/learner.py new file mode 100644 index 0000000000..2d2c3755a9 --- /dev/null +++ b/lerobot/scripts/rl/learner.py @@ -0,0 +1,1206 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Learner server runner for distributed HILSerl robot policy training. + +This script implements the learner component of the distributed HILSerl architecture. +It initializes the policy network, maintains replay buffers, and updates +the policy based on transitions received from the actor server. + +Examples of usage: + +- Start a learner server for training: +```bash +python lerobot/scripts/rl/learner.py --config_path lerobot/configs/train_config_hilserl_so100.json +``` + +**NOTE**: Start the learner server before launching the actor server. The learner opens a gRPC server +to communicate with actors. + +**NOTE**: Training progress can be monitored through Weights & Biases if wandb.enable is set to true +in your configuration. + +**WORKFLOW**: +1. Create training configuration with proper policy, dataset, and environment settings +2. Start this learner server with the configuration +3. Start an actor server with the same configuration +4. Monitor training progress through wandb dashboard + +For more details on the complete HILSerl training workflow, see: +https://github.com/michel-aractingi/lerobot-hilserl-guide +""" + +import logging +import os +import shutil +import time +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from pprint import pformat + +import grpc +import torch +from termcolor import colored +from torch import nn +from torch.multiprocessing import Queue +from torch.optim.optimizer import Optimizer + +from lerobot.common.cameras import opencv # noqa: F401 +from lerobot.common.constants import ( + CHECKPOINTS_DIR, + LAST_CHECKPOINT_LINK, + PRETRAINED_MODEL_DIR, + TRAINING_STATE_DIR, +) +from lerobot.common.datasets.factory import make_dataset +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.policies.factory import make_policy +from lerobot.common.policies.sac.modeling_sac import SACPolicy +from lerobot.common.robots import so100_follower # noqa: F401 +from lerobot.common.teleoperators import gamepad, so100_leader # noqa: F401 +from lerobot.common.transport import services_pb2_grpc +from lerobot.common.transport.utils import ( + bytes_to_python_object, + bytes_to_transitions, + state_to_bytes, +) +from lerobot.common.utils.buffer import ReplayBuffer, concatenate_batch_transitions +from lerobot.common.utils.process import ProcessSignalHandler +from lerobot.common.utils.random_utils import set_seed +from lerobot.common.utils.train_utils import ( + get_step_checkpoint_dir, + save_checkpoint, + update_last_checkpoint, +) +from lerobot.common.utils.train_utils import ( + load_training_state as utils_load_training_state, +) +from lerobot.common.utils.transition import move_state_dict_to_device, move_transition_to_device +from lerobot.common.utils.utils import ( + format_big_number, + get_safe_torch_device, + init_logging, +) +from lerobot.common.utils.wandb_utils import WandBLogger +from lerobot.configs import parser +from lerobot.configs.train import TrainRLServerPipelineConfig +from lerobot.scripts.rl import learner_service + +LOG_PREFIX = "[LEARNER]" + + +################################################# +# MAIN ENTRY POINTS AND CORE ALGORITHM FUNCTIONS # +################################################# + + +@parser.wrap() +def train_cli(cfg: TrainRLServerPipelineConfig): + if not use_threads(cfg): + import torch.multiprocessing as mp + + mp.set_start_method("spawn") + + # Use the job_name from the config + train( + cfg, + job_name=cfg.job_name, + ) + + logging.info("[LEARNER] train_cli finished") + + +def train(cfg: TrainRLServerPipelineConfig, job_name: str | None = None): + """ + Main training function that initializes and runs the training process. + + Args: + cfg (TrainRLServerPipelineConfig): The training configuration + job_name (str | None, optional): Job name for logging. Defaults to None. + """ + + cfg.validate() + + if job_name is None: + job_name = cfg.job_name + + if job_name is None: + raise ValueError("Job name must be specified either in config or as a parameter") + + display_pid = False + if not use_threads(cfg): + display_pid = True + + # Create logs directory to ensure it exists + log_dir = os.path.join(cfg.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"learner_{job_name}.log") + + # Initialize logging with explicit log file + init_logging(log_file=log_file, display_pid=display_pid) + logging.info(f"Learner logging initialized, writing to {log_file}") + logging.info(pformat(cfg.to_dict())) + + # Setup WandB logging if enabled + if cfg.wandb.enable and cfg.wandb.project: + from lerobot.common.utils.wandb_utils import WandBLogger + + wandb_logger = WandBLogger(cfg) + else: + wandb_logger = None + logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) + + # Handle resume logic + cfg = handle_resume_logic(cfg) + + set_seed(seed=cfg.seed) + + torch.backends.cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True + + is_threaded = use_threads(cfg) + shutdown_event = ProcessSignalHandler(is_threaded, display_pid=display_pid).shutdown_event + + start_learner_threads( + cfg=cfg, + wandb_logger=wandb_logger, + shutdown_event=shutdown_event, + ) + + +def start_learner_threads( + cfg: TrainRLServerPipelineConfig, + wandb_logger: WandBLogger | None, + shutdown_event: any, # Event, +) -> None: + """ + Start the learner threads for training. + + Args: + cfg (TrainRLServerPipelineConfig): Training configuration + wandb_logger (WandBLogger | None): Logger for metrics + shutdown_event: Event to signal shutdown + """ + # Create multiprocessing queues + transition_queue = Queue() + interaction_message_queue = Queue() + parameters_queue = Queue() + + concurrency_entity = None + + if use_threads(cfg): + from threading import Thread + + concurrency_entity = Thread + else: + from torch.multiprocessing import Process + + concurrency_entity = Process + + communication_process = concurrency_entity( + target=start_learner, + args=( + parameters_queue, + transition_queue, + interaction_message_queue, + shutdown_event, + cfg, + ), + daemon=True, + ) + communication_process.start() + + add_actor_information_and_train( + cfg=cfg, + wandb_logger=wandb_logger, + shutdown_event=shutdown_event, + transition_queue=transition_queue, + interaction_message_queue=interaction_message_queue, + parameters_queue=parameters_queue, + ) + logging.info("[LEARNER] Training process stopped") + + logging.info("[LEARNER] Closing queues") + transition_queue.close() + interaction_message_queue.close() + parameters_queue.close() + + communication_process.join() + logging.info("[LEARNER] Communication process joined") + + logging.info("[LEARNER] join queues") + transition_queue.cancel_join_thread() + interaction_message_queue.cancel_join_thread() + parameters_queue.cancel_join_thread() + + logging.info("[LEARNER] queues closed") + + +################################################# +# Core algorithm functions # +################################################# + + +def add_actor_information_and_train( + cfg: TrainRLServerPipelineConfig, + wandb_logger: WandBLogger | None, + shutdown_event: any, # Event, + transition_queue: Queue, + interaction_message_queue: Queue, + parameters_queue: Queue, +): + """ + Handles data transfer from the actor to the learner, manages training updates, + and logs training progress in an online reinforcement learning setup. + + This function continuously: + - Transfers transitions from the actor to the replay buffer. + - Logs received interaction messages. + - Ensures training begins only when the replay buffer has a sufficient number of transitions. + - Samples batches from the replay buffer and performs multiple critic updates. + - Periodically updates the actor, critic, and temperature optimizers. + - Logs training statistics, including loss values and optimization frequency. + + NOTE: This function doesn't have a single responsibility, it should be split into multiple functions + in the future. The reason why we did that is the GIL in Python. It's super slow the performance + are divided by 200. So we need to have a single thread that does all the work. + + Args: + cfg (TrainRLServerPipelineConfig): Configuration object containing hyperparameters. + wandb_logger (WandBLogger | None): Logger for tracking training progress. + shutdown_event (Event): Event to signal shutdown. + transition_queue (Queue): Queue for receiving transitions from the actor. + interaction_message_queue (Queue): Queue for receiving interaction messages from the actor. + parameters_queue (Queue): Queue for sending policy parameters to the actor. + """ + # Extract all configuration variables at the beginning, it improve the speed performance + # of 7% + device = get_safe_torch_device(try_device=cfg.policy.device, log=True) + storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device) + clip_grad_norm_value = cfg.policy.grad_clip_norm + online_step_before_learning = cfg.policy.online_step_before_learning + utd_ratio = cfg.policy.utd_ratio + fps = cfg.env.fps + log_freq = cfg.log_freq + save_freq = cfg.save_freq + policy_update_freq = cfg.policy.policy_update_freq + policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency + saving_checkpoint = cfg.save_checkpoint + online_steps = cfg.policy.online_steps + async_prefetch = cfg.policy.async_prefetch + + # Initialize logging for multiprocessing + if not use_threads(cfg): + log_dir = os.path.join(cfg.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"learner_train_process_{os.getpid()}.log") + init_logging(log_file=log_file, display_pid=True) + logging.info("Initialized logging for actor information and training process") + + logging.info("Initializing policy") + + policy: SACPolicy = make_policy( + cfg=cfg.policy, + env_cfg=cfg.env, + ) + + assert isinstance(policy, nn.Module) + + policy.train() + + push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy) + + last_time_policy_pushed = time.time() + + optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy) + + # If we are resuming, we need to load the training state + resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers) + + log_training_info(cfg=cfg, policy=policy) + + replay_buffer = initialize_replay_buffer(cfg, device, storage_device) + batch_size = cfg.batch_size + offline_replay_buffer = None + + if cfg.dataset is not None: + offline_replay_buffer = initialize_offline_replay_buffer( + cfg=cfg, + device=device, + storage_device=storage_device, + ) + batch_size: int = batch_size // 2 # We will sample from both replay buffer + + logging.info("Starting learner thread") + interaction_message = None + optimization_step = resume_optimization_step if resume_optimization_step is not None else 0 + interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0 + + dataset_repo_id = None + if cfg.dataset is not None: + dataset_repo_id = cfg.dataset.repo_id + + # Initialize iterators + online_iterator = None + offline_iterator = None + + # NOTE: THIS IS THE MAIN LOOP OF THE LEARNER + while True: + # Exit the training loop if shutdown is requested + if shutdown_event is not None and shutdown_event.is_set(): + logging.info("[LEARNER] Shutdown signal received. Exiting...") + break + + # Process all available transitions to the replay buffer, send by the actor server + process_transitions( + transition_queue=transition_queue, + replay_buffer=replay_buffer, + offline_replay_buffer=offline_replay_buffer, + device=device, + dataset_repo_id=dataset_repo_id, + shutdown_event=shutdown_event, + ) + + # Process all available interaction messages sent by the actor server + interaction_message = process_interaction_messages( + interaction_message_queue=interaction_message_queue, + interaction_step_shift=interaction_step_shift, + wandb_logger=wandb_logger, + shutdown_event=shutdown_event, + ) + + # Wait until the replay buffer has enough samples to start training + if len(replay_buffer) < online_step_before_learning: + continue + + if online_iterator is None: + online_iterator = replay_buffer.get_iterator( + batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2 + ) + + if offline_replay_buffer is not None and offline_iterator is None: + offline_iterator = offline_replay_buffer.get_iterator( + batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2 + ) + + time_for_one_optimization_step = time.time() + for _ in range(utd_ratio - 1): + # Sample from the iterators + batch = next(online_iterator) + + if dataset_repo_id is not None: + batch_offline = next(offline_iterator) + batch = concatenate_batch_transitions( + left_batch_transitions=batch, right_batch_transition=batch_offline + ) + + actions = batch["action"] + rewards = batch["reward"] + observations = batch["state"] + next_observations = batch["next_state"] + done = batch["done"] + check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) + + observation_features, next_observation_features = get_observation_features( + policy=policy, observations=observations, next_observations=next_observations + ) + + # Create a batch dictionary with all required elements for the forward method + forward_batch = { + "action": actions, + "reward": rewards, + "state": observations, + "next_state": next_observations, + "done": done, + "observation_feature": observation_features, + "next_observation_feature": next_observation_features, + "complementary_info": batch["complementary_info"], + } + + # Use the forward method for critic loss + critic_output = policy.forward(forward_batch, model="critic") + + # Main critic optimization + loss_critic = critic_output["loss_critic"] + optimizers["critic"].zero_grad() + loss_critic.backward() + critic_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value + ) + optimizers["critic"].step() + + # Discrete critic optimization (if available) + if policy.config.num_discrete_actions is not None: + discrete_critic_output = policy.forward(forward_batch, model="discrete_critic") + loss_discrete_critic = discrete_critic_output["loss_discrete_critic"] + optimizers["discrete_critic"].zero_grad() + loss_discrete_critic.backward() + discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value + ) + optimizers["discrete_critic"].step() + + # Update target networks (main and discrete) + policy.update_target_networks() + + # Sample for the last update in the UTD ratio + batch = next(online_iterator) + + if dataset_repo_id is not None: + batch_offline = next(offline_iterator) + batch = concatenate_batch_transitions( + left_batch_transitions=batch, right_batch_transition=batch_offline + ) + + actions = batch["action"] + rewards = batch["reward"] + observations = batch["state"] + next_observations = batch["next_state"] + done = batch["done"] + + check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) + + observation_features, next_observation_features = get_observation_features( + policy=policy, observations=observations, next_observations=next_observations + ) + + # Create a batch dictionary with all required elements for the forward method + forward_batch = { + "action": actions, + "reward": rewards, + "state": observations, + "next_state": next_observations, + "done": done, + "observation_feature": observation_features, + "next_observation_feature": next_observation_features, + } + + critic_output = policy.forward(forward_batch, model="critic") + + loss_critic = critic_output["loss_critic"] + optimizers["critic"].zero_grad() + loss_critic.backward() + critic_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value + ).item() + optimizers["critic"].step() + + # Initialize training info dictionary + training_infos = { + "loss_critic": loss_critic.item(), + "critic_grad_norm": critic_grad_norm, + } + + # Discrete critic optimization (if available) + if policy.config.num_discrete_actions is not None: + discrete_critic_output = policy.forward(forward_batch, model="discrete_critic") + loss_discrete_critic = discrete_critic_output["loss_discrete_critic"] + optimizers["discrete_critic"].zero_grad() + loss_discrete_critic.backward() + discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value + ).item() + optimizers["discrete_critic"].step() + + # Add discrete critic info to training info + training_infos["loss_discrete_critic"] = loss_discrete_critic.item() + training_infos["discrete_critic_grad_norm"] = discrete_critic_grad_norm + + # Actor and temperature optimization (at specified frequency) + if optimization_step % policy_update_freq == 0: + for _ in range(policy_update_freq): + # Actor optimization + actor_output = policy.forward(forward_batch, model="actor") + loss_actor = actor_output["loss_actor"] + optimizers["actor"].zero_grad() + loss_actor.backward() + actor_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value + ).item() + optimizers["actor"].step() + + # Add actor info to training info + training_infos["loss_actor"] = loss_actor.item() + training_infos["actor_grad_norm"] = actor_grad_norm + + # Temperature optimization + temperature_output = policy.forward(forward_batch, model="temperature") + loss_temperature = temperature_output["loss_temperature"] + optimizers["temperature"].zero_grad() + loss_temperature.backward() + temp_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=[policy.log_alpha], max_norm=clip_grad_norm_value + ).item() + optimizers["temperature"].step() + + # Add temperature info to training info + training_infos["loss_temperature"] = loss_temperature.item() + training_infos["temperature_grad_norm"] = temp_grad_norm + training_infos["temperature"] = policy.temperature + + # Update temperature + policy.update_temperature() + + # Push policy to actors if needed + if time.time() - last_time_policy_pushed > policy_parameters_push_frequency: + push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy) + last_time_policy_pushed = time.time() + + # Update target networks (main and discrete) + policy.update_target_networks() + + # Log training metrics at specified intervals + if optimization_step % log_freq == 0: + training_infos["replay_buffer_size"] = len(replay_buffer) + if offline_replay_buffer is not None: + training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer) + training_infos["Optimization step"] = optimization_step + + # Log training metrics + if wandb_logger: + wandb_logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step") + + # Calculate and log optimization frequency + time_for_one_optimization_step = time.time() - time_for_one_optimization_step + frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9) + + logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}") + + # Log optimization frequency + if wandb_logger: + wandb_logger.log_dict( + { + "Optimization frequency loop [Hz]": frequency_for_one_optimization_step, + "Optimization step": optimization_step, + }, + mode="train", + custom_step_key="Optimization step", + ) + + optimization_step += 1 + if optimization_step % log_freq == 0: + logging.info(f"[LEARNER] Number of optimization step: {optimization_step}") + + # Save checkpoint at specified intervals + if saving_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps): + save_training_checkpoint( + cfg=cfg, + optimization_step=optimization_step, + online_steps=online_steps, + interaction_message=interaction_message, + policy=policy, + optimizers=optimizers, + replay_buffer=replay_buffer, + offline_replay_buffer=offline_replay_buffer, + dataset_repo_id=dataset_repo_id, + fps=fps, + ) + + +def start_learner( + parameters_queue: Queue, + transition_queue: Queue, + interaction_message_queue: Queue, + shutdown_event: any, # Event, + cfg: TrainRLServerPipelineConfig, +): + """ + Start the learner server for training. + It will receive transitions and interaction messages from the actor server, + and send policy parameters to the actor server. + + Args: + parameters_queue: Queue for sending policy parameters to the actor + transition_queue: Queue for receiving transitions from the actor + interaction_message_queue: Queue for receiving interaction messages from the actor + shutdown_event: Event to signal shutdown + cfg: Training configuration + """ + if not use_threads(cfg): + # Create a process-specific log file + log_dir = os.path.join(cfg.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"learner_process_{os.getpid()}.log") + + # Initialize logging with explicit log file + init_logging(log_file=log_file, display_pid=True) + logging.info("Learner server process logging initialized") + + # Setup process handlers to handle shutdown signal + # But use shutdown event from the main process + # Return back for MP + # TODO: Check if its useful + _ = ProcessSignalHandler(False, display_pid=True) + + service = learner_service.LearnerService( + shutdown_event=shutdown_event, + parameters_queue=parameters_queue, + seconds_between_pushes=cfg.policy.actor_learner_config.policy_parameters_push_frequency, + transition_queue=transition_queue, + interaction_message_queue=interaction_message_queue, + queue_get_timeout=cfg.policy.actor_learner_config.queue_get_timeout, + ) + + server = grpc.server( + ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS), + options=[ + ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE), + ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE), + ], + ) + + services_pb2_grpc.add_LearnerServiceServicer_to_server( + service, + server, + ) + + host = cfg.policy.actor_learner_config.learner_host + port = cfg.policy.actor_learner_config.learner_port + + server.add_insecure_port(f"{host}:{port}") + server.start() + logging.info("[LEARNER] gRPC server started") + + shutdown_event.wait() + logging.info("[LEARNER] Stopping gRPC server...") + server.stop(learner_service.SHUTDOWN_TIMEOUT) + logging.info("[LEARNER] gRPC server stopped") + + +def save_training_checkpoint( + cfg: TrainRLServerPipelineConfig, + optimization_step: int, + online_steps: int, + interaction_message: dict | None, + policy: nn.Module, + optimizers: dict[str, Optimizer], + replay_buffer: ReplayBuffer, + offline_replay_buffer: ReplayBuffer | None = None, + dataset_repo_id: str | None = None, + fps: int = 30, +) -> None: + """ + Save training checkpoint and associated data. + + This function performs the following steps: + 1. Creates a checkpoint directory with the current optimization step + 2. Saves the policy model, configuration, and optimizer states + 3. Saves the current interaction step for resuming training + 4. Updates the "last" checkpoint symlink to point to this checkpoint + 5. Saves the replay buffer as a dataset for later use + 6. If an offline replay buffer exists, saves it as a separate dataset + + Args: + cfg: Training configuration + optimization_step: Current optimization step + online_steps: Total number of online steps + interaction_message: Dictionary containing interaction information + policy: Policy model to save + optimizers: Dictionary of optimizers + replay_buffer: Replay buffer to save as dataset + offline_replay_buffer: Optional offline replay buffer to save + dataset_repo_id: Repository ID for dataset + fps: Frames per second for dataset + """ + logging.info(f"Checkpoint policy after step {optimization_step}") + _num_digits = max(6, len(str(online_steps))) + interaction_step = interaction_message["Interaction step"] if interaction_message is not None else 0 + + # Create checkpoint directory + checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step) + + # Save checkpoint + save_checkpoint( + checkpoint_dir=checkpoint_dir, + step=optimization_step, + cfg=cfg, + policy=policy, + optimizer=optimizers, + scheduler=None, + ) + + # Save interaction step manually + training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR) + os.makedirs(training_state_dir, exist_ok=True) + training_state = {"step": optimization_step, "interaction_step": interaction_step} + torch.save(training_state, os.path.join(training_state_dir, "training_state.pt")) + + # Update the "last" symlink + update_last_checkpoint(checkpoint_dir) + + # TODO : temporary save replay buffer here, remove later when on the robot + # We want to control this with the keyboard inputs + dataset_dir = os.path.join(cfg.output_dir, "dataset") + if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir): + shutil.rmtree(dataset_dir) + + # Save dataset + # NOTE: Handle the case where the dataset repo id is not specified in the config + # eg. RL training without demonstrations data + repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id + replay_buffer.to_lerobot_dataset(repo_id=repo_id_buffer_save, fps=fps, root=dataset_dir) + + if offline_replay_buffer is not None: + dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline") + if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir): + shutil.rmtree(dataset_offline_dir) + + offline_replay_buffer.to_lerobot_dataset( + cfg.dataset.repo_id, + fps=fps, + root=dataset_offline_dir, + ) + + logging.info("Resume training") + + +def make_optimizers_and_scheduler(cfg: TrainRLServerPipelineConfig, policy: nn.Module): + """ + Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy. + + This function sets up Adam optimizers for: + - The **actor network**, ensuring that only relevant parameters are optimized. + - The **critic ensemble**, which evaluates the value function. + - The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods. + + It also initializes a learning rate scheduler, though currently, it is set to `None`. + + NOTE: + - If the encoder is shared, its parameters are excluded from the actor's optimization process. + - The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor. + + Args: + cfg: Configuration object containing hyperparameters. + policy (nn.Module): The policy model containing the actor, critic, and temperature components. + + Returns: + Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]: + A tuple containing: + - `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers. + - `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling. + + """ + optimizer_actor = torch.optim.Adam( + params=[ + p + for n, p in policy.actor.named_parameters() + if not policy.config.shared_encoder or not n.startswith("encoder") + ], + lr=cfg.policy.actor_lr, + ) + optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr) + + if cfg.policy.num_discrete_actions is not None: + optimizer_discrete_critic = torch.optim.Adam( + params=policy.discrete_critic.parameters(), lr=cfg.policy.critic_lr + ) + optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr) + lr_scheduler = None + optimizers = { + "actor": optimizer_actor, + "critic": optimizer_critic, + "temperature": optimizer_temperature, + } + if cfg.policy.num_discrete_actions is not None: + optimizers["discrete_critic"] = optimizer_discrete_critic + return optimizers, lr_scheduler + + +################################################# +# Training setup functions # +################################################# + + +def handle_resume_logic(cfg: TrainRLServerPipelineConfig) -> TrainRLServerPipelineConfig: + """ + Handle the resume logic for training. + + If resume is True: + - Verifies that a checkpoint exists + - Loads the checkpoint configuration + - Logs resumption details + - Returns the checkpoint configuration + + If resume is False: + - Checks if an output directory exists (to prevent accidental overwriting) + - Returns the original configuration + + Args: + cfg (TrainRLServerPipelineConfig): The training configuration + + Returns: + TrainRLServerPipelineConfig: The updated configuration + + Raises: + RuntimeError: If resume is True but no checkpoint found, or if resume is False but directory exists + """ + out_dir = cfg.output_dir + + # Case 1: Not resuming, but need to check if directory exists to prevent overwrites + if not cfg.resume: + checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) + if os.path.exists(checkpoint_dir): + raise RuntimeError( + f"Output directory {checkpoint_dir} already exists. Use `resume=true` to resume training." + ) + return cfg + + # Case 2: Resuming training + checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) + if not os.path.exists(checkpoint_dir): + raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True") + + # Log that we found a valid checkpoint and are resuming + logging.info( + colored( + "Valid checkpoint found: resume=True detected, resuming previous run", + color="yellow", + attrs=["bold"], + ) + ) + + # Load config using Draccus + checkpoint_cfg_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR, "train_config.json") + checkpoint_cfg = TrainRLServerPipelineConfig.from_pretrained(checkpoint_cfg_path) + + # Ensure resume flag is set in returned config + checkpoint_cfg.resume = True + return checkpoint_cfg + + +def load_training_state( + cfg: TrainRLServerPipelineConfig, + optimizers: Optimizer | dict[str, Optimizer], +): + """ + Loads the training state (optimizers, step count, etc.) from a checkpoint. + + Args: + cfg (TrainRLServerPipelineConfig): Training configuration + optimizers (Optimizer | dict): Optimizers to load state into + + Returns: + tuple: (optimization_step, interaction_step) or (None, None) if not resuming + """ + if not cfg.resume: + return None, None + + # Construct path to the last checkpoint directory + checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) + + logging.info(f"Loading training state from {checkpoint_dir}") + + try: + # Use the utility function from train_utils which loads the optimizer state + step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None) + + # Load interaction step separately from training_state.pt + training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt") + interaction_step = 0 + if os.path.exists(training_state_path): + training_state = torch.load(training_state_path, weights_only=False) # nosec B614: Safe usage of torch.load + interaction_step = training_state.get("interaction_step", 0) + + logging.info(f"Resuming from step {step}, interaction step {interaction_step}") + return step, interaction_step + + except Exception as e: + logging.error(f"Failed to load training state: {e}") + return None, None + + +def log_training_info(cfg: TrainRLServerPipelineConfig, policy: nn.Module) -> None: + """ + Log information about the training process. + + Args: + cfg (TrainRLServerPipelineConfig): Training configuration + policy (nn.Module): Policy model + """ + num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) + num_total_params = sum(p.numel() for p in policy.parameters()) + + logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") + logging.info(f"{cfg.env.task=}") + logging.info(f"{cfg.policy.online_steps=}") + logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") + logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") + + +def initialize_replay_buffer( + cfg: TrainRLServerPipelineConfig, device: str, storage_device: str +) -> ReplayBuffer: + """ + Initialize a replay buffer, either empty or from a dataset if resuming. + + Args: + cfg (TrainRLServerPipelineConfig): Training configuration + device (str): Device to store tensors on + storage_device (str): Device for storage optimization + + Returns: + ReplayBuffer: Initialized replay buffer + """ + if not cfg.resume: + return ReplayBuffer( + capacity=cfg.policy.online_buffer_capacity, + device=device, + state_keys=cfg.policy.input_features.keys(), + storage_device=storage_device, + optimize_memory=True, + ) + + logging.info("Resume training load the online dataset") + dataset_path = os.path.join(cfg.output_dir, "dataset") + + # NOTE: In RL is possible to not have a dataset. + repo_id = None + if cfg.dataset is not None: + repo_id = cfg.dataset.repo_id + dataset = LeRobotDataset( + repo_id=repo_id, + root=dataset_path, + ) + return ReplayBuffer.from_lerobot_dataset( + lerobot_dataset=dataset, + capacity=cfg.policy.online_buffer_capacity, + device=device, + state_keys=cfg.policy.input_features.keys(), + optimize_memory=True, + ) + + +def initialize_offline_replay_buffer( + cfg: TrainRLServerPipelineConfig, + device: str, + storage_device: str, +) -> ReplayBuffer: + """ + Initialize an offline replay buffer from a dataset. + + Args: + cfg (TrainRLServerPipelineConfig): Training configuration + device (str): Device to store tensors on + storage_device (str): Device for storage optimization + + Returns: + ReplayBuffer: Initialized offline replay buffer + """ + if not cfg.resume: + logging.info("make_dataset offline buffer") + offline_dataset = make_dataset(cfg) + else: + logging.info("load offline dataset") + dataset_offline_path = os.path.join(cfg.output_dir, "dataset_offline") + offline_dataset = LeRobotDataset( + repo_id=cfg.dataset.repo_id, + root=dataset_offline_path, + ) + + logging.info("Convert to a offline replay buffer") + offline_replay_buffer = ReplayBuffer.from_lerobot_dataset( + offline_dataset, + device=device, + state_keys=cfg.policy.input_features.keys(), + storage_device=storage_device, + optimize_memory=True, + capacity=cfg.policy.offline_buffer_capacity, + ) + return offline_replay_buffer + + +################################################# +# Utilities/Helpers functions # +################################################# + + +def get_observation_features( + policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor +) -> tuple[torch.Tensor | None, torch.Tensor | None]: + """ + Get observation features from the policy encoder. It act as cache for the observation features. + when the encoder is frozen, the observation features are not updated. + We can save compute by caching the observation features. + + Args: + policy: The policy model + observations: The current observations + next_observations: The next observations + + Returns: + tuple: observation_features, next_observation_features + """ + + if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder: + return None, None + + with torch.no_grad(): + observation_features = policy.actor.encoder.get_cached_image_features(observations, normalize=True) + next_observation_features = policy.actor.encoder.get_cached_image_features( + next_observations, normalize=True + ) + + return observation_features, next_observation_features + + +def use_threads(cfg: TrainRLServerPipelineConfig) -> bool: + return cfg.policy.concurrency.learner == "threads" + + +def check_nan_in_transition( + observations: torch.Tensor, + actions: torch.Tensor, + next_state: torch.Tensor, + raise_error: bool = False, +) -> bool: + """ + Check for NaN values in transition data. + + Args: + observations: Dictionary of observation tensors + actions: Action tensor + next_state: Dictionary of next state tensors + raise_error: If True, raises ValueError when NaN is detected + + Returns: + bool: True if NaN values were detected, False otherwise + """ + nan_detected = False + + # Check observations + for key, tensor in observations.items(): + if torch.isnan(tensor).any(): + logging.error(f"observations[{key}] contains NaN values") + nan_detected = True + if raise_error: + raise ValueError(f"NaN detected in observations[{key}]") + + # Check next state + for key, tensor in next_state.items(): + if torch.isnan(tensor).any(): + logging.error(f"next_state[{key}] contains NaN values") + nan_detected = True + if raise_error: + raise ValueError(f"NaN detected in next_state[{key}]") + + # Check actions + if torch.isnan(actions).any(): + logging.error("actions contains NaN values") + nan_detected = True + if raise_error: + raise ValueError("NaN detected in actions") + + return nan_detected + + +def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module): + logging.debug("[LEARNER] Pushing actor policy to the queue") + state_dict = move_state_dict_to_device(policy.actor.state_dict(), device="cpu") + state_bytes = state_to_bytes(state_dict) + parameters_queue.put(state_bytes) + + +def process_interaction_message( + message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None +): + """Process a single interaction message with consistent handling.""" + message = bytes_to_python_object(message) + # Shift interaction step for consistency with checkpointed state + message["Interaction step"] += interaction_step_shift + + # Log if logger available + if wandb_logger: + wandb_logger.log_dict(d=message, mode="train", custom_step_key="Interaction step") + + return message + + +def process_transitions( + transition_queue: Queue, + replay_buffer: ReplayBuffer, + offline_replay_buffer: ReplayBuffer, + device: str, + dataset_repo_id: str | None, + shutdown_event: any, +): + """Process all available transitions from the queue. + + Args: + transition_queue: Queue for receiving transitions from the actor + replay_buffer: Replay buffer to add transitions to + offline_replay_buffer: Offline replay buffer to add transitions to + device: Device to move transitions to + dataset_repo_id: Repository ID for dataset + shutdown_event: Event to signal shutdown + """ + while not transition_queue.empty() and not shutdown_event.is_set(): + transition_list = transition_queue.get() + transition_list = bytes_to_transitions(buffer=transition_list) + + for transition in transition_list: + transition = move_transition_to_device(transition=transition, device=device) + + # Skip transitions with NaN values + if check_nan_in_transition( + observations=transition["state"], + actions=transition["action"], + next_state=transition["next_state"], + ): + logging.warning("[LEARNER] NaN detected in transition, skipping") + continue + + replay_buffer.add(**transition) + + # Add to offline buffer if it's an intervention + if dataset_repo_id is not None and transition.get("complementary_info", {}).get( + "is_intervention" + ): + offline_replay_buffer.add(**transition) + + +def process_interaction_messages( + interaction_message_queue: Queue, + interaction_step_shift: int, + wandb_logger: WandBLogger | None, + shutdown_event: any, +) -> dict | None: + """Process all available interaction messages from the queue. + + Args: + interaction_message_queue: Queue for receiving interaction messages + interaction_step_shift: Amount to shift interaction step by + wandb_logger: Logger for tracking progress + shutdown_event: Event to signal shutdown + + Returns: + dict | None: The last interaction message processed, or None if none were processed + """ + last_message = None + while not interaction_message_queue.empty() and not shutdown_event.is_set(): + message = interaction_message_queue.get() + last_message = process_interaction_message( + message=message, + interaction_step_shift=interaction_step_shift, + wandb_logger=wandb_logger, + ) + + return last_message + + +if __name__ == "__main__": + train_cli() + logging.info("[LEARNER] main finished") diff --git a/lerobot/scripts/rl/learner_service.py b/lerobot/scripts/rl/learner_service.py new file mode 100644 index 0000000000..f967d812cf --- /dev/null +++ b/lerobot/scripts/rl/learner_service.py @@ -0,0 +1,118 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from multiprocessing import Event, Queue + +from lerobot.common.transport import services_pb2, services_pb2_grpc +from lerobot.common.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks +from lerobot.common.utils.queue import get_last_item_from_queue + +MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB +MAX_WORKERS = 3 # Stream parameters, send transitions and interactions +SHUTDOWN_TIMEOUT = 10 + + +class LearnerService(services_pb2_grpc.LearnerServiceServicer): + """ + Implementation of the LearnerService gRPC service + This service is used to send parameters to the Actor and receive transitions and interactions from the Actor + check transport.proto for the gRPC service definition + """ + + def __init__( + self, + shutdown_event: Event, # type: ignore + parameters_queue: Queue, + seconds_between_pushes: float, + transition_queue: Queue, + interaction_message_queue: Queue, + queue_get_timeout: float = 0.001, + ): + self.shutdown_event = shutdown_event + self.parameters_queue = parameters_queue + self.seconds_between_pushes = seconds_between_pushes + self.transition_queue = transition_queue + self.interaction_message_queue = interaction_message_queue + self.queue_get_timeout = queue_get_timeout + + def StreamParameters(self, request, context): # noqa: N802 + # TODO: authorize the request + logging.info("[LEARNER] Received request to stream parameters from the Actor") + + last_push_time = 0 + + while not self.shutdown_event.is_set(): + time_since_last_push = time.time() - last_push_time + if time_since_last_push < self.seconds_between_pushes: + self.shutdown_event.wait(self.seconds_between_pushes - time_since_last_push) + # Continue, because we could receive a shutdown event, + # and it's checked in the while loop + continue + + logging.info("[LEARNER] Push parameters to the Actor") + buffer = get_last_item_from_queue( + self.parameters_queue, block=True, timeout=self.queue_get_timeout + ) + + if buffer is None: + continue + + yield from send_bytes_in_chunks( + buffer, + services_pb2.Parameters, + log_prefix="[LEARNER] Sending parameters", + silent=True, + ) + + last_push_time = time.time() + logging.info("[LEARNER] Parameters sent") + + logging.info("[LEARNER] Stream parameters finished") + return services_pb2.Empty() + + def SendTransitions(self, request_iterator, _context): # noqa: N802 + # TODO: authorize the request + logging.info("[LEARNER] Received request to receive transitions from the Actor") + + receive_bytes_in_chunks( + request_iterator, + self.transition_queue, + self.shutdown_event, + log_prefix="[LEARNER] transitions", + ) + + logging.debug("[LEARNER] Finished receiving transitions") + return services_pb2.Empty() + + def SendInteractions(self, request_iterator, _context): # noqa: N802 + # TODO: authorize the request + logging.info("[LEARNER] Received request to receive interactions from the Actor") + + receive_bytes_in_chunks( + request_iterator, + self.interaction_message_queue, + self.shutdown_event, + log_prefix="[LEARNER] interactions", + ) + + logging.debug("[LEARNER] Finished receiving interactions") + return services_pb2.Empty() + + def Ready(self, request, context): # noqa: N802 + return services_pb2.Empty() diff --git a/lerobot/teleoperate.py b/lerobot/teleoperate.py index 97e6104301..6080dfb403 100644 --- a/lerobot/teleoperate.py +++ b/lerobot/teleoperate.py @@ -58,7 +58,7 @@ from lerobot.common.utils.utils import init_logging, move_cursor_up from lerobot.common.utils.visualization_utils import _init_rerun -from .common.teleoperators import koch_leader, so100_leader, so101_leader # noqa: F401 +from .common.teleoperators import gamepad, koch_leader, so100_leader, so101_leader # noqa: F401 @dataclass diff --git a/pyproject.toml b/pyproject.toml index 2ce5d049be..31276a18b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,7 @@ dependencies = [ "pyserial>=3.5", "pyzmq>=26.2.1", "rerun-sdk>=0.21.0", + "scipy>=1.14.0", "termcolor>=2.4.0", "torch>=2.2.1", "torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", @@ -85,19 +86,21 @@ dora = [ ] dynamixel = ["dynamixel-sdk>=3.7.31"] feetech = ["feetech-servo-sdk>=1.0.0"] +gamepad = ["pygame>=2.5.1", "hidapi>=0.14.0"] intelrealsense = [ "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'", "pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'", ] pi0 = ["transformers>=4.48.0"] -smolvla = ["transformers>=4.50.3", "num2words>=0.5.14", "accelerate>=1.7.0"] +smolvla = ["transformers>=4.50.3", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"] pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"] stretch = [ "hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'", "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'", "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'" ] -test = ["pytest>=8.1.0", "pytest-cov>=5.0.0", "mock-serial>=0.0.1 ; sys_platform != 'win32'"] +test = ["pytest>=8.1.0", "pytest-timeout>=2.4.0", "pytest-cov>=5.0.0", "pyserial>=3.5", "mock-serial>=0.0.1 ; sys_platform != 'win32'"] +hilserl = ["transformers>=4.48", "gym-hil>=0.1.8", "protobuf>=5.29.3", "grpcio==1.71.0"] umi = ["imagecodecs>=2024.1.1"] video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"] xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"] @@ -108,7 +111,7 @@ requires-poetry = ">=2.1" [tool.ruff] line-length = 110 target-version = "py310" -exclude = ["tests/artifacts/**/*.safetensors"] +exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"] [tool.ruff.lint] select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"] diff --git a/tests/optim/test_optimizers.py b/tests/optim/test_optimizers.py index 997e14fe94..630353fcaf 100644 --- a/tests/optim/test_optimizers.py +++ b/tests/optim/test_optimizers.py @@ -21,6 +21,7 @@ from lerobot.common.optim.optimizers import ( AdamConfig, AdamWConfig, + MultiAdamConfig, SGDConfig, load_optimizer_state, save_optimizer_state, @@ -33,13 +34,21 @@ (AdamConfig, torch.optim.Adam), (AdamWConfig, torch.optim.AdamW), (SGDConfig, torch.optim.SGD), + (MultiAdamConfig, dict), ], ) def test_optimizer_build(config_cls, expected_class, model_params): config = config_cls() - optimizer = config.build(model_params) - assert isinstance(optimizer, expected_class) - assert optimizer.defaults["lr"] == config.lr + if config_cls == MultiAdamConfig: + params_dict = {"default": model_params} + optimizer = config.build(params_dict) + assert isinstance(optimizer, expected_class) + assert isinstance(optimizer["default"], torch.optim.Adam) + assert optimizer["default"].defaults["lr"] == config.lr + else: + optimizer = config.build(model_params) + assert isinstance(optimizer, expected_class) + assert optimizer.defaults["lr"] == config.lr def test_save_optimizer_state(optimizer, tmp_path): @@ -54,3 +63,180 @@ def test_save_and_load_optimizer_state(model_params, optimizer, tmp_path): loaded_optimizer = load_optimizer_state(loaded_optimizer, tmp_path) torch.testing.assert_close(optimizer.state_dict(), loaded_optimizer.state_dict()) + + +@pytest.fixture +def base_params_dict(): + return { + "actor": [torch.nn.Parameter(torch.randn(10, 10))], + "critic": [torch.nn.Parameter(torch.randn(5, 5))], + "temperature": [torch.nn.Parameter(torch.randn(3, 3))], + } + + +@pytest.mark.parametrize( + "config_params, expected_values", + [ + # Test 1: Basic configuration with different learning rates + ( + { + "lr": 1e-3, + "weight_decay": 1e-4, + "optimizer_groups": { + "actor": {"lr": 1e-4}, + "critic": {"lr": 5e-4}, + "temperature": {"lr": 2e-3}, + }, + }, + { + "actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)}, + "critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)}, + "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999)}, + }, + ), + # Test 2: Different weight decays and beta values + ( + { + "lr": 1e-3, + "weight_decay": 1e-4, + "optimizer_groups": { + "actor": {"lr": 1e-4, "weight_decay": 1e-5}, + "critic": {"lr": 5e-4, "weight_decay": 1e-6}, + "temperature": {"lr": 2e-3, "betas": (0.95, 0.999)}, + }, + }, + { + "actor": {"lr": 1e-4, "weight_decay": 1e-5, "betas": (0.9, 0.999)}, + "critic": {"lr": 5e-4, "weight_decay": 1e-6, "betas": (0.9, 0.999)}, + "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.95, 0.999)}, + }, + ), + # Test 3: Epsilon parameter customization + ( + { + "lr": 1e-3, + "weight_decay": 1e-4, + "optimizer_groups": { + "actor": {"lr": 1e-4, "eps": 1e-6}, + "critic": {"lr": 5e-4, "eps": 1e-7}, + "temperature": {"lr": 2e-3, "eps": 1e-8}, + }, + }, + { + "actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-6}, + "critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-7}, + "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-8}, + }, + ), + ], +) +def test_multi_adam_configuration(base_params_dict, config_params, expected_values): + # Create config with the given parameters + config = MultiAdamConfig(**config_params) + optimizers = config.build(base_params_dict) + + # Verify optimizer count and keys + assert len(optimizers) == len(expected_values) + assert set(optimizers.keys()) == set(expected_values.keys()) + + # Check that all optimizers are Adam instances + for opt in optimizers.values(): + assert isinstance(opt, torch.optim.Adam) + + # Verify hyperparameters for each optimizer + for name, expected in expected_values.items(): + optimizer = optimizers[name] + for param, value in expected.items(): + assert optimizer.defaults[param] == value + + +@pytest.fixture +def multi_optimizers(base_params_dict): + config = MultiAdamConfig( + lr=1e-3, + optimizer_groups={ + "actor": {"lr": 1e-4}, + "critic": {"lr": 5e-4}, + "temperature": {"lr": 2e-3}, + }, + ) + return config.build(base_params_dict) + + +def test_save_multi_optimizer_state(multi_optimizers, tmp_path): + # Save optimizer states + save_optimizer_state(multi_optimizers, tmp_path) + + # Verify that directories were created for each optimizer + for name in multi_optimizers: + assert (tmp_path / name).is_dir() + assert (tmp_path / name / OPTIMIZER_STATE).is_file() + assert (tmp_path / name / OPTIMIZER_PARAM_GROUPS).is_file() + + +def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers, tmp_path): + # Option 1: Add a minimal backward pass to populate optimizer states + for name, params in base_params_dict.items(): + if name in multi_optimizers: + # Create a dummy loss and do backward + dummy_loss = params[0].sum() + dummy_loss.backward() + # Perform an optimization step + multi_optimizers[name].step() + # Zero gradients for next steps + multi_optimizers[name].zero_grad() + + # Save optimizer states + save_optimizer_state(multi_optimizers, tmp_path) + + # Create new optimizers with the same config + config = MultiAdamConfig( + lr=1e-3, + optimizer_groups={ + "actor": {"lr": 1e-4}, + "critic": {"lr": 5e-4}, + "temperature": {"lr": 2e-3}, + }, + ) + new_optimizers = config.build(base_params_dict) + + # Load optimizer states + loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path) + + # Verify state dictionaries match + for name in multi_optimizers: + torch.testing.assert_close(multi_optimizers[name].state_dict(), loaded_optimizers[name].state_dict()) + + +def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path): + """Test saving and loading optimizer states even when the state is empty (no backward pass).""" + # Create config and build optimizers + config = MultiAdamConfig( + lr=1e-3, + optimizer_groups={ + "actor": {"lr": 1e-4}, + "critic": {"lr": 5e-4}, + "temperature": {"lr": 2e-3}, + }, + ) + optimizers = config.build(base_params_dict) + + # Save optimizer states without any backward pass (empty state) + save_optimizer_state(optimizers, tmp_path) + + # Create new optimizers with the same config + new_optimizers = config.build(base_params_dict) + + # Load optimizer states + loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path) + + # Verify hyperparameters match even with empty state + for name, optimizer in optimizers.items(): + assert optimizer.defaults["lr"] == loaded_optimizers[name].defaults["lr"] + assert optimizer.defaults["weight_decay"] == loaded_optimizers[name].defaults["weight_decay"] + assert optimizer.defaults["betas"] == loaded_optimizers[name].defaults["betas"] + + # Verify state dictionaries match (they will be empty) + torch.testing.assert_close( + optimizer.state_dict()["param_groups"], loaded_optimizers[name].state_dict()["param_groups"] + ) diff --git a/tests/policies/hilserl/test_modeling_classifier.py b/tests/policies/hilserl/test_modeling_classifier.py new file mode 100644 index 0000000000..526e1f17dd --- /dev/null +++ b/tests/policies/hilserl/test_modeling_classifier.py @@ -0,0 +1,139 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from lerobot.common.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig +from lerobot.common.policies.sac.reward_model.modeling_classifier import ClassifierOutput +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from tests.utils import require_package + + +def test_classifier_output(): + output = ClassifierOutput( + logits=torch.tensor([1, 2, 3]), + probabilities=torch.tensor([0.1, 0.2, 0.3]), + hidden_states=None, + ) + + assert ( + f"{output}" + == "ClassifierOutput(logits=tensor([1, 2, 3]), probabilities=tensor([0.1000, 0.2000, 0.3000]), hidden_states=None)" + ) + + +@require_package("transformers") +def test_binary_classifier_with_default_params(): + from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier + + config = RewardClassifierConfig() + config.input_features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + "next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)), + } + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "REWARD": NormalizationMode.IDENTITY, + } + config.num_cameras = 1 + classifier = Classifier(config) + + batch_size = 10 + + input = { + "observation.image": torch.rand((batch_size, 3, 128, 128)), + "next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(), + } + + images, labels = classifier.extract_images_and_labels(input) + assert len(images) == 1 + assert images[0].shape == torch.Size([batch_size, 3, 128, 128]) + assert labels.shape == torch.Size([batch_size]) + + output = classifier.predict(images) + + assert output is not None + assert output.logits.size() == torch.Size([batch_size]) + assert not torch.isnan(output.logits).any(), "Tensor contains NaN values" + assert output.probabilities.shape == torch.Size([batch_size]) + assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values" + assert output.hidden_states.shape == torch.Size([batch_size, 256]) + assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values" + + +@require_package("transformers") +def test_multiclass_classifier(): + from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier + + num_classes = 5 + config = RewardClassifierConfig() + config.input_features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + "next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)), + } + config.num_cameras = 1 + config.num_classes = num_classes + classifier = Classifier(config) + + batch_size = 10 + + input = { + "observation.image": torch.rand((batch_size, 3, 128, 128)), + "next.reward": torch.rand((batch_size, num_classes)), + } + + images, labels = classifier.extract_images_and_labels(input) + assert len(images) == 1 + assert images[0].shape == torch.Size([batch_size, 3, 128, 128]) + assert labels.shape == torch.Size([batch_size, num_classes]) + + output = classifier.predict(images) + + assert output is not None + assert output.logits.shape == torch.Size([batch_size, num_classes]) + assert not torch.isnan(output.logits).any(), "Tensor contains NaN values" + assert output.probabilities.shape == torch.Size([batch_size, num_classes]) + assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values" + assert output.hidden_states.shape == torch.Size([batch_size, 256]) + assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values" + + +@require_package("transformers") +def test_default_device(): + from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier + + config = RewardClassifierConfig() + assert config.device == "cpu" + + classifier = Classifier(config) + for p in classifier.parameters(): + assert p.device == torch.device("cpu") + + +@require_package("transformers") +def test_explicit_device_setup(): + from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier + + config = RewardClassifierConfig(device="cpu") + assert config.device == "cpu" + + classifier = Classifier(config) + for p in classifier.parameters(): + assert p.device == torch.device("cpu") diff --git a/tests/policies/test_sac_config.py b/tests/policies/test_sac_config.py new file mode 100644 index 0000000000..d94ee41e04 --- /dev/null +++ b/tests/policies/test_sac_config.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from lerobot.common.policies.sac.configuration_sac import ( + ActorLearnerConfig, + ActorNetworkConfig, + ConcurrencyConfig, + CriticNetworkConfig, + PolicyConfig, + SACConfig, +) +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature + + +def test_sac_config_default_initialization(): + config = SACConfig() + + assert config.normalization_mapping == { + "VISUAL": NormalizationMode.MEAN_STD, + "STATE": NormalizationMode.MIN_MAX, + "ENV": NormalizationMode.MIN_MAX, + "ACTION": NormalizationMode.MIN_MAX, + } + assert config.dataset_stats == { + "observation.image": { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + }, + "observation.state": { + "min": [0.0, 0.0], + "max": [1.0, 1.0], + }, + "action": { + "min": [0.0, 0.0, 0.0], + "max": [1.0, 1.0, 1.0], + }, + } + + # Basic parameters + assert config.device == "cpu" + assert config.storage_device == "cpu" + assert config.discount == 0.99 + assert config.temperature_init == 1.0 + assert config.num_critics == 2 + + # Architecture specifics + assert config.vision_encoder_name is None + assert config.freeze_vision_encoder is True + assert config.image_encoder_hidden_dim == 32 + assert config.shared_encoder is True + assert config.num_discrete_actions is None + assert config.image_embedding_pooling_dim == 8 + + # Training parameters + assert config.online_steps == 1000000 + assert config.online_env_seed == 10000 + assert config.online_buffer_capacity == 100000 + assert config.offline_buffer_capacity == 100000 + assert config.async_prefetch is False + assert config.online_step_before_learning == 100 + assert config.policy_update_freq == 1 + + # SAC algorithm parameters + assert config.num_subsample_critics is None + assert config.critic_lr == 3e-4 + assert config.actor_lr == 3e-4 + assert config.temperature_lr == 3e-4 + assert config.critic_target_update_weight == 0.005 + assert config.utd_ratio == 1 + assert config.state_encoder_hidden_dim == 256 + assert config.latent_dim == 256 + assert config.target_entropy is None + assert config.use_backup_entropy is True + assert config.grad_clip_norm == 40.0 + + # Dataset stats defaults + expected_dataset_stats = { + "observation.image": { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + }, + "observation.state": { + "min": [0.0, 0.0], + "max": [1.0, 1.0], + }, + "action": { + "min": [0.0, 0.0, 0.0], + "max": [1.0, 1.0, 1.0], + }, + } + assert config.dataset_stats == expected_dataset_stats + + # Critic network configuration + assert config.critic_network_kwargs.hidden_dims == [256, 256] + assert config.critic_network_kwargs.activate_final is True + assert config.critic_network_kwargs.final_activation is None + + # Actor network configuration + assert config.actor_network_kwargs.hidden_dims == [256, 256] + assert config.actor_network_kwargs.activate_final is True + + # Policy configuration + assert config.policy_kwargs.use_tanh_squash is True + assert config.policy_kwargs.std_min == 1e-5 + assert config.policy_kwargs.std_max == 10.0 + assert config.policy_kwargs.init_final == 0.05 + + # Discrete critic network configuration + assert config.discrete_critic_network_kwargs.hidden_dims == [256, 256] + assert config.discrete_critic_network_kwargs.activate_final is True + assert config.discrete_critic_network_kwargs.final_activation is None + + # Actor learner configuration + assert config.actor_learner_config.learner_host == "127.0.0.1" + assert config.actor_learner_config.learner_port == 50051 + assert config.actor_learner_config.policy_parameters_push_frequency == 4 + + # Concurrency configuration + assert config.concurrency.actor == "threads" + assert config.concurrency.learner == "threads" + + assert isinstance(config.actor_network_kwargs, ActorNetworkConfig) + assert isinstance(config.critic_network_kwargs, CriticNetworkConfig) + assert isinstance(config.policy_kwargs, PolicyConfig) + assert isinstance(config.actor_learner_config, ActorLearnerConfig) + assert isinstance(config.concurrency, ConcurrencyConfig) + + +def test_critic_network_kwargs(): + config = CriticNetworkConfig() + assert config.hidden_dims == [256, 256] + assert config.activate_final is True + assert config.final_activation is None + + +def test_actor_network_kwargs(): + config = ActorNetworkConfig() + assert config.hidden_dims == [256, 256] + assert config.activate_final is True + + +def test_policy_kwargs(): + config = PolicyConfig() + assert config.use_tanh_squash is True + assert config.std_min == 1e-5 + assert config.std_max == 10.0 + assert config.init_final == 0.05 + + +def test_actor_learner_config(): + config = ActorLearnerConfig() + assert config.learner_host == "127.0.0.1" + assert config.learner_port == 50051 + assert config.policy_parameters_push_frequency == 4 + + +def test_concurrency_config(): + config = ConcurrencyConfig() + assert config.actor == "threads" + assert config.learner == "threads" + + +def test_sac_config_custom_initialization(): + config = SACConfig( + device="cpu", + discount=0.95, + temperature_init=0.5, + num_critics=3, + ) + + assert config.device == "cpu" + assert config.discount == 0.95 + assert config.temperature_init == 0.5 + assert config.num_critics == 3 + + +def test_validate_features(): + config = SACConfig( + input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))}, + output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, + ) + config.validate_features() + + +def test_validate_features_missing_observation(): + config = SACConfig( + input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))}, + output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, + ) + with pytest.raises( + ValueError, match="You must provide either 'observation.state' or an image observation" + ): + config.validate_features() + + +def test_validate_features_missing_action(): + config = SACConfig( + input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))}, + output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, + ) + with pytest.raises(ValueError, match="You must provide 'action' in the output features"): + config.validate_features() diff --git a/tests/policies/test_sac_policy.py b/tests/policies/test_sac_policy.py new file mode 100644 index 0000000000..e4e2dd8a99 --- /dev/null +++ b/tests/policies/test_sac_policy.py @@ -0,0 +1,541 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import pytest +import torch +from torch import Tensor, nn + +from lerobot.common.policies.sac.configuration_sac import SACConfig +from lerobot.common.policies.sac.modeling_sac import MLP, SACPolicy +from lerobot.common.utils.random_utils import seeded_context, set_seed +from lerobot.configs.types import FeatureType, PolicyFeature + +try: + import transformers # noqa: F401 + + TRANSFORMERS_AVAILABLE = True +except ImportError: + TRANSFORMERS_AVAILABLE = False + + +@pytest.fixture(autouse=True) +def set_random_seed(): + seed = 42 + set_seed(seed) + + +def test_mlp_with_default_args(): + mlp = MLP(input_dim=10, hidden_dims=[256, 256]) + + x = torch.randn(10) + y = mlp(x) + assert y.shape == (256,) + + +def test_mlp_with_batch_dim(): + mlp = MLP(input_dim=10, hidden_dims=[256, 256]) + x = torch.randn(2, 10) + y = mlp(x) + assert y.shape == (2, 256) + + +def test_forward_with_empty_hidden_dims(): + mlp = MLP(input_dim=10, hidden_dims=[]) + x = torch.randn(1, 10) + assert mlp(x).shape == (1, 10) + + +def test_mlp_with_dropout(): + mlp = MLP(input_dim=10, hidden_dims=[256, 256, 11], dropout_rate=0.1) + x = torch.randn(1, 10) + y = mlp(x) + assert y.shape == (1, 11) + + drop_out_layers_count = sum(isinstance(layer, nn.Dropout) for layer in mlp.net) + assert drop_out_layers_count == 2 + + +def test_mlp_with_custom_final_activation(): + mlp = MLP(input_dim=10, hidden_dims=[256, 256], final_activation=torch.nn.Tanh()) + x = torch.randn(1, 10) + y = mlp(x) + assert y.shape == (1, 256) + assert (y >= -1).all() and (y <= 1).all() + + +def test_sac_policy_with_default_args(): + with pytest.raises(ValueError, match="should be an instance of class `PreTrainedConfig`"): + SACPolicy() + + +def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor: + return { + "observation.state": torch.randn(batch_size, state_dim), + } + + +def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor: + return { + "observation.image": torch.randn(batch_size, 3, 84, 84), + "observation.state": torch.randn(batch_size, state_dim), + } + + +def create_dummy_action(batch_size: int, action_dim: int = 10) -> Tensor: + return torch.randn(batch_size, action_dim) + + +def create_default_train_batch( + batch_size: int = 8, state_dim: int = 10, action_dim: int = 10 +) -> dict[str, Tensor]: + return { + "action": create_dummy_action(batch_size, action_dim), + "reward": torch.randn(batch_size), + "state": create_dummy_state(batch_size, state_dim), + "next_state": create_dummy_state(batch_size, state_dim), + "done": torch.randn(batch_size), + } + + +def create_train_batch_with_visual_input( + batch_size: int = 8, state_dim: int = 10, action_dim: int = 10 +) -> dict[str, Tensor]: + return { + "action": create_dummy_action(batch_size, action_dim), + "reward": torch.randn(batch_size), + "state": create_dummy_with_visual_input(batch_size, state_dim), + "next_state": create_dummy_with_visual_input(batch_size, state_dim), + "done": torch.randn(batch_size), + } + + +def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]: + return { + "observation.state": torch.randn(batch_size, state_dim), + } + + +def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]: + return { + "observation.state": torch.randn(batch_size, state_dim), + "observation.image": torch.randn(batch_size, 3, 84, 84), + } + + +def make_optimizers(policy: SACPolicy, has_discrete_action: bool = False) -> dict[str, torch.optim.Optimizer]: + """Create optimizers for the SAC policy.""" + optimizer_actor = torch.optim.Adam( + # Handle the case of shared encoder where the encoder weights are not optimized with the actor gradient + params=[ + p + for n, p in policy.actor.named_parameters() + if not policy.config.shared_encoder or not n.startswith("encoder") + ], + lr=policy.config.actor_lr, + ) + optimizer_critic = torch.optim.Adam( + params=policy.critic_ensemble.parameters(), + lr=policy.config.critic_lr, + ) + optimizer_temperature = torch.optim.Adam( + params=[policy.log_alpha], + lr=policy.config.critic_lr, + ) + + optimizers = { + "actor": optimizer_actor, + "critic": optimizer_critic, + "temperature": optimizer_temperature, + } + + if has_discrete_action: + optimizers["discrete_critic"] = torch.optim.Adam( + params=policy.discrete_critic.parameters(), + lr=policy.config.critic_lr, + ) + + return optimizers + + +def create_default_config( + state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False +) -> SACConfig: + action_dim = continuous_action_dim + if has_discrete_action: + action_dim += 1 + + config = SACConfig( + input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, + output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))}, + dataset_stats={ + "observation.state": { + "min": [0.0] * state_dim, + "max": [1.0] * state_dim, + }, + "action": { + "min": [0.0] * continuous_action_dim, + "max": [1.0] * continuous_action_dim, + }, + }, + ) + config.validate_features() + return config + + +def create_config_with_visual_input( + state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False +) -> SACConfig: + config = create_default_config( + state_dim=state_dim, + continuous_action_dim=continuous_action_dim, + has_discrete_action=has_discrete_action, + ) + config.input_features["observation.image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84)) + config.dataset_stats["observation.image"] = { + "mean": torch.randn(3, 1, 1), + "std": torch.randn(3, 1, 1), + } + + # Let make tests a little bit faster + config.state_encoder_hidden_dim = 32 + config.latent_dim = 32 + + config.validate_features() + return config + + +@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) +def test_sac_policy_with_default_config(batch_size: int, state_dim: int, action_dim: int): + batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim) + config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) + + policy = SACPolicy(config=config) + policy.train() + + optimizers = make_optimizers(policy) + + cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] + assert cirtic_loss.item() is not None + assert cirtic_loss.shape == () + cirtic_loss.backward() + optimizers["critic"].step() + + actor_loss = policy.forward(batch, model="actor")["loss_actor"] + assert actor_loss.item() is not None + assert actor_loss.shape == () + + actor_loss.backward() + optimizers["actor"].step() + + temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"] + assert temperature_loss.item() is not None + assert temperature_loss.shape == () + + temperature_loss.backward() + optimizers["temperature"].step() + + policy.eval() + with torch.no_grad(): + observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + selected_action = policy.select_action(observation_batch) + assert selected_action.shape == (batch_size, action_dim) + + +@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) +def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_dim: int): + config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) + policy = SACPolicy(config=config) + + batch = create_train_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim, action_dim=action_dim + ) + + policy.train() + + optimizers = make_optimizers(policy) + + cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] + assert cirtic_loss.item() is not None + assert cirtic_loss.shape == () + cirtic_loss.backward() + optimizers["critic"].step() + + actor_loss = policy.forward(batch, model="actor")["loss_actor"] + assert actor_loss.item() is not None + assert actor_loss.shape == () + + actor_loss.backward() + optimizers["actor"].step() + + temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"] + assert temperature_loss.item() is not None + assert temperature_loss.shape == () + + temperature_loss.backward() + optimizers["temperature"].step() + + policy.eval() + with torch.no_grad(): + observation_batch = create_observation_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim + ) + selected_action = policy.select_action(observation_batch) + assert selected_action.shape == (batch_size, action_dim) + + +# Let's check best candidates for pretrained encoders +@pytest.mark.parametrize( + "batch_size,state_dim,action_dim,vision_encoder_name", + [(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")], +) +@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed") +def test_sac_policy_with_pretrained_encoder( + batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str +): + config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) + config.vision_encoder_name = vision_encoder_name + policy = SACPolicy(config=config) + policy.train() + + batch = create_train_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim, action_dim=action_dim + ) + + optimizers = make_optimizers(policy) + + cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] + assert cirtic_loss.item() is not None + assert cirtic_loss.shape == () + cirtic_loss.backward() + optimizers["critic"].step() + + actor_loss = policy.forward(batch, model="actor")["loss_actor"] + assert actor_loss.item() is not None + assert actor_loss.shape == () + + +def test_sac_policy_with_shared_encoder(): + batch_size = 2 + action_dim = 10 + state_dim = 10 + config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) + config.shared_encoder = True + + policy = SACPolicy(config=config) + policy.train() + + batch = create_train_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim, action_dim=action_dim + ) + + policy.train() + + optimizers = make_optimizers(policy) + + cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] + assert cirtic_loss.item() is not None + assert cirtic_loss.shape == () + cirtic_loss.backward() + optimizers["critic"].step() + + actor_loss = policy.forward(batch, model="actor")["loss_actor"] + assert actor_loss.item() is not None + assert actor_loss.shape == () + + actor_loss.backward() + optimizers["actor"].step() + + +def test_sac_policy_with_discrete_critic(): + batch_size = 2 + continuous_action_dim = 9 + full_action_dim = continuous_action_dim + 1 # the last action is discrete + state_dim = 10 + config = create_config_with_visual_input( + state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True + ) + + num_discrete_actions = 5 + config.num_discrete_actions = num_discrete_actions + + policy = SACPolicy(config=config) + policy.train() + + batch = create_train_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim + ) + + policy.train() + + optimizers = make_optimizers(policy, has_discrete_action=True) + + cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] + assert cirtic_loss.item() is not None + assert cirtic_loss.shape == () + cirtic_loss.backward() + optimizers["critic"].step() + + discrete_critic_loss = policy.forward(batch, model="discrete_critic")["loss_discrete_critic"] + assert discrete_critic_loss.item() is not None + assert discrete_critic_loss.shape == () + discrete_critic_loss.backward() + optimizers["discrete_critic"].step() + + actor_loss = policy.forward(batch, model="actor")["loss_actor"] + assert actor_loss.item() is not None + assert actor_loss.shape == () + + actor_loss.backward() + optimizers["actor"].step() + + policy.eval() + with torch.no_grad(): + observation_batch = create_observation_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim + ) + selected_action = policy.select_action(observation_batch) + assert selected_action.shape == (batch_size, full_action_dim) + + discrete_actions = selected_action[:, -1].long() + discrete_action_values = set(discrete_actions.tolist()) + + assert all(action in range(num_discrete_actions) for action in discrete_action_values), ( + f"Discrete action {discrete_action_values} is not in range({num_discrete_actions})" + ) + + +def test_sac_policy_with_default_entropy(): + config = create_default_config(continuous_action_dim=10, state_dim=10) + policy = SACPolicy(config=config) + assert policy.target_entropy == -5.0 + + +def test_sac_policy_default_target_entropy_with_discrete_action(): + config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True) + policy = SACPolicy(config=config) + assert policy.target_entropy == -3.0 + + +def test_sac_policy_with_predefined_entropy(): + config = create_default_config(state_dim=10, continuous_action_dim=6) + config.target_entropy = -3.5 + + policy = SACPolicy(config=config) + assert policy.target_entropy == pytest.approx(-3.5) + + +def test_sac_policy_update_temperature(): + config = create_default_config(continuous_action_dim=10, state_dim=10) + policy = SACPolicy(config=config) + + assert policy.temperature == pytest.approx(1.0) + policy.log_alpha.data = torch.tensor([math.log(0.1)]) + policy.update_temperature() + assert policy.temperature == pytest.approx(0.1) + + +def test_sac_policy_update_target_network(): + config = create_default_config(state_dim=10, continuous_action_dim=6) + config.critic_target_update_weight = 1.0 + + policy = SACPolicy(config=config) + policy.train() + + for p in policy.critic_ensemble.parameters(): + p.data = torch.ones_like(p.data) + + policy.update_target_networks() + for p in policy.critic_target.parameters(): + assert torch.allclose(p.data, torch.ones_like(p.data)), ( + f"Target network {p.data} is not equal to {torch.ones_like(p.data)}" + ) + + +@pytest.mark.parametrize("num_critics", [1, 3]) +def test_sac_policy_with_critics_number_of_heads(num_critics: int): + batch_size = 2 + action_dim = 10 + state_dim = 10 + config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) + config.num_critics = num_critics + + policy = SACPolicy(config=config) + policy.train() + + assert len(policy.critic_ensemble.critics) == num_critics + + batch = create_train_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim, action_dim=action_dim + ) + + policy.train() + + optimizers = make_optimizers(policy) + + cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] + assert cirtic_loss.item() is not None + assert cirtic_loss.shape == () + cirtic_loss.backward() + optimizers["critic"].step() + + +def test_sac_policy_save_and_load(tmp_path): + root = tmp_path / "test_sac_save_and_load" + + state_dim = 10 + action_dim = 10 + batch_size = 2 + + config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) + policy = SACPolicy(config=config) + policy.eval() + policy.save_pretrained(root) + loaded_policy = SACPolicy.from_pretrained(root, config=config) + loaded_policy.eval() + + batch = create_default_train_batch(batch_size=1, state_dim=10, action_dim=10) + + with torch.no_grad(): + with seeded_context(12): + # Collect policy values before saving + cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] + actor_loss = policy.forward(batch, model="actor")["loss_actor"] + temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"] + + observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + actions = policy.select_action(observation_batch) + + with seeded_context(12): + # Collect policy values after loading + loaded_cirtic_loss = loaded_policy.forward(batch, model="critic")["loss_critic"] + loaded_actor_loss = loaded_policy.forward(batch, model="actor")["loss_actor"] + loaded_temperature_loss = loaded_policy.forward(batch, model="temperature")["loss_temperature"] + + loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + loaded_actions = loaded_policy.select_action(loaded_observation_batch) + + assert policy.state_dict().keys() == loaded_policy.state_dict().keys() + for k in policy.state_dict(): + assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6) + + # Compare values before and after saving and loading + # They should be the same + assert torch.allclose(cirtic_loss, loaded_cirtic_loss) + assert torch.allclose(actor_loss, loaded_actor_loss) + assert torch.allclose(temperature_loss, loaded_temperature_loss) + assert torch.allclose(actions, loaded_actions) diff --git a/tests/rl/test_actor.py b/tests/rl/test_actor.py new file mode 100644 index 0000000000..0cf6a8f644 --- /dev/null +++ b/tests/rl/test_actor.py @@ -0,0 +1,208 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from concurrent import futures +from unittest.mock import patch + +import pytest +import torch +from torch.multiprocessing import Event, Queue + +from lerobot.common.utils.transition import Transition +from tests.utils import require_package + + +def create_learner_service_stub(): + import grpc + + from lerobot.common.transport import services_pb2, services_pb2_grpc + + class MockLearnerService(services_pb2_grpc.LearnerServiceServicer): + def __init__(self): + self.ready_call_count = 0 + self.should_fail = False + + def Ready(self, request, context): # noqa: N802 + self.ready_call_count += 1 + if self.should_fail: + context.set_code(grpc.StatusCode.UNAVAILABLE) + context.set_details("Service unavailable") + raise grpc.RpcError("Service unavailable") + return services_pb2.Empty() + + """Fixture to start a LearnerService gRPC server and provide a connected stub.""" + + servicer = MockLearnerService() + + # Create a gRPC server and add our servicer to it. + server = grpc.server(futures.ThreadPoolExecutor(max_workers=4)) + services_pb2_grpc.add_LearnerServiceServicer_to_server(servicer, server) + port = server.add_insecure_port("[::]:0") # bind to a free port chosen by OS + server.start() # start the server (non-blocking call):contentReference[oaicite:1]{index=1} + + # Create a client channel and stub connected to the server's port. + channel = grpc.insecure_channel(f"localhost:{port}") + return services_pb2_grpc.LearnerServiceStub(channel), servicer, channel, server + + +def close_service_stub(channel, server): + channel.close() + server.stop(None) + + +@require_package("grpc") +def test_establish_learner_connection_success(): + from lerobot.scripts.rl.actor import establish_learner_connection + + """Test successful connection establishment.""" + stub, _servicer, channel, server = create_learner_service_stub() + + shutdown_event = Event() + + # Test successful connection + result = establish_learner_connection(stub, shutdown_event, attempts=5) + + assert result is True + + close_service_stub(channel, server) + + +@require_package("grpc") +def test_establish_learner_connection_failure(): + from lerobot.scripts.rl.actor import establish_learner_connection + + """Test connection failure.""" + stub, servicer, channel, server = create_learner_service_stub() + servicer.should_fail = True + + shutdown_event = Event() + + # Test failed connection + with patch("time.sleep"): # Speed up the test + result = establish_learner_connection(stub, shutdown_event, attempts=2) + + assert result is False + + close_service_stub(channel, server) + + +@require_package("grpc") +def test_push_transitions_to_transport_queue(): + from lerobot.common.transport.utils import bytes_to_transitions + from lerobot.scripts.rl.actor import push_transitions_to_transport_queue + from tests.transport.test_transport_utils import assert_transitions_equal + + """Test pushing transitions to transport queue.""" + # Create mock transitions + transitions = [] + for i in range(3): + transition = Transition( + state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, + action=torch.randn(5), + reward=torch.tensor(1.0 + i), + done=torch.tensor(False), + truncated=torch.tensor(False), + next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, + complementary_info={"step": torch.tensor(i)}, + ) + transitions.append(transition) + + transitions_queue = Queue() + + # Test pushing transitions + push_transitions_to_transport_queue(transitions, transitions_queue) + + # Verify the data can be retrieved + serialized_data = transitions_queue.get() + assert isinstance(serialized_data, bytes) + deserialized_transitions = bytes_to_transitions(serialized_data) + assert len(deserialized_transitions) == len(transitions) + for i, deserialized_transition in enumerate(deserialized_transitions): + assert_transitions_equal(deserialized_transition, transitions[i]) + + +@require_package("grpc") +@pytest.mark.timeout(3) # force cross-platform watchdog +def test_transitions_stream(): + from lerobot.scripts.rl.actor import transitions_stream + + """Test transitions stream functionality.""" + shutdown_event = Event() + transitions_queue = Queue() + + # Add test data to queue + test_data = [b"transition_data_1", b"transition_data_2", b"transition_data_3"] + for data in test_data: + transitions_queue.put(data) + + # Collect streamed data + streamed_data = [] + stream_generator = transitions_stream(shutdown_event, transitions_queue, 0.1) + + # Process a few items + for i, message in enumerate(stream_generator): + streamed_data.append(message) + if i >= len(test_data) - 1: + shutdown_event.set() + break + + # Verify we got messages + assert len(streamed_data) == len(test_data) + assert streamed_data[0].data == b"transition_data_1" + assert streamed_data[1].data == b"transition_data_2" + assert streamed_data[2].data == b"transition_data_3" + + +@require_package("grpc") +@pytest.mark.timeout(3) # force cross-platform watchdog +def test_interactions_stream(): + from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes + from lerobot.scripts.rl.actor import interactions_stream + + """Test interactions stream functionality.""" + shutdown_event = Event() + interactions_queue = Queue() + + # Create test interaction data (similar structure to what would be sent) + test_interactions = [ + {"episode_reward": 10.5, "step": 1, "policy_fps": 30.2}, + {"episode_reward": 15.2, "step": 2, "policy_fps": 28.7}, + {"episode_reward": 8.7, "step": 3, "policy_fps": 29.1}, + ] + + # Serialize the interaction data as it would be in practice + test_data = [ + interactions_queue.put(python_object_to_bytes(interaction)) for interaction in test_interactions + ] + + # Collect streamed data + streamed_data = [] + stream_generator = interactions_stream(shutdown_event, interactions_queue, 0.1) + + # Process the items + for i, message in enumerate(stream_generator): + streamed_data.append(message) + if i >= len(test_data) - 1: + shutdown_event.set() + break + + # Verify we got messages + assert len(streamed_data) == len(test_data) + + # Verify the messages can be deserialized back to original data + for i, message in enumerate(streamed_data): + deserialized_interaction = bytes_to_python_object(message.data) + assert deserialized_interaction == test_interactions[i] diff --git a/tests/rl/test_actor_learner.py b/tests/rl/test_actor_learner.py new file mode 100644 index 0000000000..cb72da7e40 --- /dev/null +++ b/tests/rl/test_actor_learner.py @@ -0,0 +1,297 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import socket +import threading +import time + +import pytest +import torch +from torch.multiprocessing import Event, Queue + +from lerobot.common.policies.sac.configuration_sac import SACConfig +from lerobot.common.utils.transition import Transition +from lerobot.configs.train import TrainRLServerPipelineConfig +from tests.utils import require_package + + +def create_test_transitions(count: int = 3) -> list[Transition]: + """Create test transitions for integration testing.""" + transitions = [] + for i in range(count): + transition = Transition( + state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, + action=torch.randn(5), + reward=torch.tensor(1.0 + i), + done=torch.tensor(i == count - 1), # Last transition is done + truncated=torch.tensor(False), + next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, + complementary_info={"step": torch.tensor(i), "episode_id": i // 2}, + ) + transitions.append(transition) + return transitions + + +def create_test_interactions(count: int = 3) -> list[dict]: + """Create test interactions for integration testing.""" + interactions = [] + for i in range(count): + interaction = { + "episode_reward": 10.0 + i * 5, + "step": i * 100, + "policy_fps": 30.0 + i, + "intervention_rate": 0.1 * i, + "episode_length": 200 + i * 50, + } + interactions.append(interaction) + return interactions + + +def find_free_port(): + """Finds a free port on the local machine.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) # Bind to port 0 to let the OS choose a free port + s.listen(1) + port = s.getsockname()[1] + return port + + +@pytest.fixture +def cfg(): + cfg = TrainRLServerPipelineConfig() + + port = find_free_port() + + policy_cfg = SACConfig() + policy_cfg.actor_learner_config.learner_host = "127.0.0.1" + policy_cfg.actor_learner_config.learner_port = port + policy_cfg.concurrency.actor = "threads" + policy_cfg.concurrency.learner = "threads" + policy_cfg.actor_learner_config.queue_get_timeout = 0.1 + + cfg.policy = policy_cfg + + return cfg + + +@require_package("grpc") +@pytest.mark.timeout(10) # force cross-platform watchdog +def test_end_to_end_transitions_flow(cfg): + from lerobot.common.transport.utils import bytes_to_transitions + from lerobot.scripts.rl.actor import ( + establish_learner_connection, + learner_service_client, + push_transitions_to_transport_queue, + send_transitions, + ) + from lerobot.scripts.rl.learner import start_learner + from tests.transport.test_transport_utils import assert_transitions_equal + + """Test complete transitions flow from actor to learner.""" + transitions_actor_queue = Queue() + transitions_learner_queue = Queue() + + interactions_queue = Queue() + parameters_queue = Queue() + shutdown_event = Event() + + learner_thread = threading.Thread( + target=start_learner, + args=(parameters_queue, transitions_learner_queue, interactions_queue, shutdown_event, cfg), + ) + learner_thread.start() + + policy_cfg = cfg.policy + learner_client, channel = learner_service_client( + host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port + ) + + assert establish_learner_connection(learner_client, shutdown_event, attempts=5) + + send_transitions_thread = threading.Thread( + target=send_transitions, args=(cfg, transitions_actor_queue, shutdown_event, learner_client, channel) + ) + send_transitions_thread.start() + + input_transitions = create_test_transitions(count=5) + + push_transitions_to_transport_queue(input_transitions, transitions_actor_queue) + + # Wait for learner to start + time.sleep(0.1) + + shutdown_event.set() + + # Wait for learner to receive transitions + learner_thread.join() + send_transitions_thread.join() + channel.close() + + received_transitions = [] + while not transitions_learner_queue.empty(): + received_transitions.extend(bytes_to_transitions(transitions_learner_queue.get())) + + assert len(received_transitions) == len(input_transitions) + for i, transition in enumerate(received_transitions): + assert_transitions_equal(transition, input_transitions[i]) + + +@require_package("grpc") +@pytest.mark.timeout(10) +def test_end_to_end_interactions_flow(cfg): + from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes + from lerobot.scripts.rl.actor import ( + establish_learner_connection, + learner_service_client, + send_interactions, + ) + from lerobot.scripts.rl.learner import start_learner + + """Test complete interactions flow from actor to learner.""" + # Queues for actor-learner communication + interactions_actor_queue = Queue() + interactions_learner_queue = Queue() + + # Other queues required by the learner + parameters_queue = Queue() + transitions_learner_queue = Queue() + + shutdown_event = Event() + + # Start the learner in a separate thread + learner_thread = threading.Thread( + target=start_learner, + args=(parameters_queue, transitions_learner_queue, interactions_learner_queue, shutdown_event, cfg), + ) + learner_thread.start() + + # Establish connection from actor to learner + policy_cfg = cfg.policy + learner_client, channel = learner_service_client( + host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port + ) + + assert establish_learner_connection(learner_client, shutdown_event, attempts=5) + + # Start the actor's interaction sending process in a separate thread + send_interactions_thread = threading.Thread( + target=send_interactions, + args=(cfg, interactions_actor_queue, shutdown_event, learner_client, channel), + ) + send_interactions_thread.start() + + # Create and push test interactions to the actor's queue + input_interactions = create_test_interactions(count=5) + for interaction in input_interactions: + interactions_actor_queue.put(python_object_to_bytes(interaction)) + + # Wait for the communication to happen + time.sleep(0.1) + + # Signal shutdown and wait for threads to complete + shutdown_event.set() + learner_thread.join() + send_interactions_thread.join() + channel.close() + + # Verify that the learner received the interactions + received_interactions = [] + while not interactions_learner_queue.empty(): + received_interactions.append(bytes_to_python_object(interactions_learner_queue.get())) + + assert len(received_interactions) == len(input_interactions) + + # Sort by a unique key to handle potential reordering in queues + received_interactions.sort(key=lambda x: x["step"]) + input_interactions.sort(key=lambda x: x["step"]) + + for received, expected in zip(received_interactions, input_interactions, strict=False): + assert received == expected + + +@require_package("grpc") +@pytest.mark.parametrize("data_size", ["small", "large"]) +@pytest.mark.timeout(10) +def test_end_to_end_parameters_flow(cfg, data_size): + from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes + from lerobot.scripts.rl.actor import establish_learner_connection, learner_service_client, receive_policy + from lerobot.scripts.rl.learner import start_learner + + """Test complete parameter flow from learner to actor, with small and large data.""" + # Actor's local queue to receive params + parameters_actor_queue = Queue() + # Learner's queue to send params from + parameters_learner_queue = Queue() + + # Other queues required by the learner + transitions_learner_queue = Queue() + interactions_learner_queue = Queue() + + shutdown_event = Event() + + # Start the learner in a separate thread + learner_thread = threading.Thread( + target=start_learner, + args=( + parameters_learner_queue, + transitions_learner_queue, + interactions_learner_queue, + shutdown_event, + cfg, + ), + ) + learner_thread.start() + + # Establish connection from actor to learner + policy_cfg = cfg.policy + learner_client, channel = learner_service_client( + host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port + ) + + assert establish_learner_connection(learner_client, shutdown_event, attempts=5) + + # Start the actor's parameter receiving process in a separate thread + receive_params_thread = threading.Thread( + target=receive_policy, + args=(cfg, parameters_actor_queue, shutdown_event, learner_client, channel), + ) + receive_params_thread.start() + + # Create test parameters based on parametrization + if data_size == "small": + input_params = {"layer.weight": torch.randn(128, 64)} + else: # "large" + # CHUNK_SIZE is 2MB, so this tensor (4MB) will force chunking + input_params = {"large_layer.weight": torch.randn(1024, 1024)} + + # Simulate learner having new parameters to send + parameters_learner_queue.put(state_to_bytes(input_params)) + + # Wait for the actor to receive the parameters + time.sleep(0.1) + + # Signal shutdown and wait for threads to complete + shutdown_event.set() + learner_thread.join() + receive_params_thread.join() + channel.close() + + # Verify that the actor received the parameters correctly + received_params = bytes_to_state_dict(parameters_actor_queue.get()) + + assert received_params.keys() == input_params.keys() + for key in input_params: + assert torch.allclose(received_params[key], input_params[key]) diff --git a/tests/rl/test_learner_service.py b/tests/rl/test_learner_service.py new file mode 100644 index 0000000000..ee9d06e914 --- /dev/null +++ b/tests/rl/test_learner_service.py @@ -0,0 +1,374 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import threading +import time +from concurrent import futures +from multiprocessing import Event, Queue + +import pytest + +from tests.utils import require_package # our gRPC servicer class + + +@pytest.fixture(scope="function") +def learner_service_stub(): + shutdown_event = Event() + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + seconds_between_pushes = 1 + client, channel, server = create_learner_service_stub( + shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes + ) + + yield client # provide the stub to the test function + + close_learner_service_stub(channel, server) + + +@require_package("grpc") +def create_learner_service_stub( + shutdown_event: Event, + parameters_queue: Queue, + transitions_queue: Queue, + interactions_queue: Queue, + seconds_between_pushes: int, + queue_get_timeout: float = 0.1, +): + import grpc + + from lerobot.common.transport import services_pb2_grpc # generated from .proto + from lerobot.scripts.rl.learner_service import LearnerService + + """Fixture to start a LearnerService gRPC server and provide a connected stub.""" + + servicer = LearnerService( + shutdown_event=shutdown_event, + parameters_queue=parameters_queue, + seconds_between_pushes=seconds_between_pushes, + transition_queue=transitions_queue, + interaction_message_queue=interactions_queue, + queue_get_timeout=queue_get_timeout, + ) + + # Create a gRPC server and add our servicer to it. + server = grpc.server(futures.ThreadPoolExecutor(max_workers=4)) + services_pb2_grpc.add_LearnerServiceServicer_to_server(servicer, server) + port = server.add_insecure_port("[::]:0") # bind to a free port chosen by OS + server.start() # start the server (non-blocking call):contentReference[oaicite:1]{index=1} + + # Create a client channel and stub connected to the server's port. + channel = grpc.insecure_channel(f"localhost:{port}") + return services_pb2_grpc.LearnerServiceStub(channel), channel, server + + +@require_package("grpc") +def close_learner_service_stub(channel, server): + channel.close() + server.stop(None) + + +@pytest.mark.timeout(3) # force cross-platform watchdog +def test_ready_method(learner_service_stub): + from lerobot.common.transport import services_pb2 + + """Test the ready method of the UserService.""" + request = services_pb2.Empty() + response = learner_service_stub.Ready(request) + assert response == services_pb2.Empty() + + +@require_package("grpc") +@pytest.mark.timeout(3) # force cross-platform watchdog +def test_send_interactions(): + from lerobot.common.transport import services_pb2 + + shutdown_event = Event() + + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + seconds_between_pushes = 1 + client, channel, server = create_learner_service_stub( + shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes + ) + + list_of_interaction_messages = [ + services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"1"), + services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"2"), + services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"3"), + services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"4"), + services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"5"), + services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"6"), + services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"7"), + services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"8"), + ] + + def mock_intercations_stream(): + yield from list_of_interaction_messages + + return services_pb2.Empty() + + response = client.SendInteractions(mock_intercations_stream()) + assert response == services_pb2.Empty() + + close_learner_service_stub(channel, server) + + # Extract the data from the interactions queue + interactions = [] + while not interactions_queue.empty(): + interactions.append(interactions_queue.get()) + + assert interactions == [b"123", b"4", b"5", b"678"] + + +@require_package("grpc") +@pytest.mark.timeout(3) # force cross-platform watchdog +def test_send_transitions(): + from lerobot.common.transport import services_pb2 + + """Test the SendTransitions method with various transition data.""" + shutdown_event = Event() + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + seconds_between_pushes = 1 + + client, channel, server = create_learner_service_stub( + shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes + ) + + # Create test transition messages + list_of_transition_messages = [ + services_pb2.Transition( + transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"transition_1" + ), + services_pb2.Transition( + transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"transition_2" + ), + services_pb2.Transition(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"transition_3"), + services_pb2.Transition(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"batch_1"), + services_pb2.Transition(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"batch_2"), + ] + + def mock_transitions_stream(): + yield from list_of_transition_messages + + response = client.SendTransitions(mock_transitions_stream()) + assert response == services_pb2.Empty() + + close_learner_service_stub(channel, server) + + # Extract the data from the transitions queue + transitions = [] + while not transitions_queue.empty(): + transitions.append(transitions_queue.get()) + + # Should have assembled the chunked data + assert transitions == [b"transition_1transition_2transition_3", b"batch_1batch_2"] + + +@require_package("grpc") +@pytest.mark.timeout(3) # force cross-platform watchdog +def test_send_transitions_empty_stream(): + from lerobot.common.transport import services_pb2 + + """Test SendTransitions with empty stream.""" + shutdown_event = Event() + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + seconds_between_pushes = 1 + + client, channel, server = create_learner_service_stub( + shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes + ) + + def empty_stream(): + return iter([]) + + response = client.SendTransitions(empty_stream()) + assert response == services_pb2.Empty() + + close_learner_service_stub(channel, server) + + # Queue should remain empty + assert transitions_queue.empty() + + +@require_package("grpc") +@pytest.mark.timeout(10) # force cross-platform watchdog +def test_stream_parameters(): + import time + + from lerobot.common.transport import services_pb2 + + """Test the StreamParameters method.""" + shutdown_event = Event() + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + seconds_between_pushes = 0.2 # Short delay for testing + + client, channel, server = create_learner_service_stub( + shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes + ) + + # Add test parameters to the queue + test_params = [b"param_batch_1", b"param_batch_2"] + for param in test_params: + parameters_queue.put(param) + + # Start streaming parameters + request = services_pb2.Empty() + stream = client.StreamParameters(request) + + # Collect streamed parameters and timestamps + received_params = [] + timestamps = [] + + for response in stream: + received_params.append(response.data) + timestamps.append(time.time()) + + # We should receive one last item + break + + parameters_queue.put(b"param_batch_3") + + for response in stream: + received_params.append(response.data) + timestamps.append(time.time()) + + # We should receive only one item + break + + shutdown_event.set() + close_learner_service_stub(channel, server) + + assert received_params == [b"param_batch_2", b"param_batch_3"] + + # Check the time difference between the two sends + time_diff = timestamps[1] - timestamps[0] + # Check if the time difference is close to the expected push frequency + assert time_diff == pytest.approx(seconds_between_pushes, abs=0.1) + + +@require_package("grpc") +@pytest.mark.timeout(3) # force cross-platform watchdog +def test_stream_parameters_with_shutdown(): + from lerobot.common.transport import services_pb2 + + """Test StreamParameters handles shutdown gracefully.""" + shutdown_event = Event() + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + seconds_between_pushes = 0.1 + queue_get_timeout = 0.001 + + client, channel, server = create_learner_service_stub( + shutdown_event, + parameters_queue, + transitions_queue, + interactions_queue, + seconds_between_pushes, + queue_get_timeout=queue_get_timeout, + ) + + test_params = [b"param_batch_1", b"stop", b"param_batch_3", b"param_batch_4"] + + # create a thread that will put the parameters in the queue + def producer(): + for param in test_params: + parameters_queue.put(param) + time.sleep(0.1) + + producer_thread = threading.Thread(target=producer) + producer_thread.start() + + # Start streaming + request = services_pb2.Empty() + stream = client.StreamParameters(request) + + # Collect streamed parameters + received_params = [] + + for response in stream: + received_params.append(response.data) + + if response.data == b"stop": + shutdown_event.set() + + producer_thread.join() + close_learner_service_stub(channel, server) + + assert received_params == [b"param_batch_1", b"stop"] + + +@require_package("grpc") +@pytest.mark.timeout(3) # force cross-platform watchdog +def test_stream_parameters_waits_and_retries_on_empty_queue(): + import threading + import time + + from lerobot.common.transport import services_pb2 + + """Test that StreamParameters waits and retries when the queue is empty.""" + shutdown_event = Event() + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + seconds_between_pushes = 0.05 + queue_get_timeout = 0.01 + + client, channel, server = create_learner_service_stub( + shutdown_event, + parameters_queue, + transitions_queue, + interactions_queue, + seconds_between_pushes, + queue_get_timeout=queue_get_timeout, + ) + + request = services_pb2.Empty() + stream = client.StreamParameters(request) + + received_params = [] + + def producer(): + # Let the consumer start and find an empty queue. + # It will wait `seconds_between_pushes` (0.05s), then `get` will timeout after `queue_get_timeout` (0.01s). + # Total time for the first empty loop is > 0.06s. We wait a bit longer to be safe. + time.sleep(0.06) + parameters_queue.put(b"param_after_wait") + time.sleep(0.05) + parameters_queue.put(b"param_after_wait_2") + + producer_thread = threading.Thread(target=producer) + producer_thread.start() + + # The consumer will block here until the producer sends an item. + for response in stream: + received_params.append(response.data) + if response.data == b"param_after_wait_2": + break # We only need one item for this test. + + shutdown_event.set() + producer_thread.join() + close_learner_service_stub(channel, server) + + assert received_params == [b"param_after_wait", b"param_after_wait_2"] diff --git a/tests/transport/test_transport_utils.py b/tests/transport/test_transport_utils.py new file mode 100644 index 0000000000..cf33f52c04 --- /dev/null +++ b/tests/transport/test_transport_utils.py @@ -0,0 +1,571 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +from multiprocessing import Event, Queue +from pickle import UnpicklingError + +import pytest +import torch + +from lerobot.common.utils.transition import Transition +from tests.utils import require_cuda, require_package + + +@require_package("grpc") +def test_bytes_buffer_size_empty_buffer(): + from lerobot.common.transport.utils import bytes_buffer_size + + """Test with an empty buffer.""" + buffer = io.BytesIO() + assert bytes_buffer_size(buffer) == 0 + # Ensure position is reset to beginning + assert buffer.tell() == 0 + + +@require_package("grpc") +def test_bytes_buffer_size_small_buffer(): + from lerobot.common.transport.utils import bytes_buffer_size + + """Test with a small buffer.""" + buffer = io.BytesIO(b"Hello, World!") + assert bytes_buffer_size(buffer) == 13 + assert buffer.tell() == 0 + + +@require_package("grpc") +def test_bytes_buffer_size_large_buffer(): + from lerobot.common.transport.utils import CHUNK_SIZE, bytes_buffer_size + + """Test with a large buffer.""" + data = b"x" * (CHUNK_SIZE * 2 + 1000) + buffer = io.BytesIO(data) + assert bytes_buffer_size(buffer) == len(data) + assert buffer.tell() == 0 + + +@require_package("grpc") +def test_send_bytes_in_chunks_empty_data(): + from lerobot.common.transport.utils import send_bytes_in_chunks, services_pb2 + + """Test sending empty data.""" + message_class = services_pb2.InteractionMessage + chunks = list(send_bytes_in_chunks(b"", message_class)) + assert len(chunks) == 0 + + +@require_package("grpc") +def test_single_chunk_small_data(): + from lerobot.common.transport.utils import send_bytes_in_chunks, services_pb2 + + """Test data that fits in a single chunk.""" + data = b"Some data" + message_class = services_pb2.InteractionMessage + chunks = list(send_bytes_in_chunks(data, message_class)) + + assert len(chunks) == 1 + assert chunks[0].data == b"Some data" + assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END + + +@require_package("grpc") +def test_not_silent_mode(): + from lerobot.common.transport.utils import send_bytes_in_chunks, services_pb2 + + """Test not silent mode.""" + data = b"Some data" + message_class = services_pb2.InteractionMessage + chunks = list(send_bytes_in_chunks(data, message_class, silent=False)) + assert len(chunks) == 1 + assert chunks[0].data == b"Some data" + + +@require_package("grpc") +def test_send_bytes_in_chunks_large_data(): + from lerobot.common.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2 + + """Test sending large data.""" + data = b"x" * (CHUNK_SIZE * 2 + 1000) + message_class = services_pb2.InteractionMessage + chunks = list(send_bytes_in_chunks(data, message_class)) + assert len(chunks) == 3 + assert chunks[0].data == b"x" * CHUNK_SIZE + assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_BEGIN + assert chunks[1].data == b"x" * CHUNK_SIZE + assert chunks[1].transfer_state == services_pb2.TransferState.TRANSFER_MIDDLE + assert chunks[2].data == b"x" * 1000 + assert chunks[2].transfer_state == services_pb2.TransferState.TRANSFER_END + + +@require_package("grpc") +def test_send_bytes_in_chunks_large_data_with_exact_chunk_size(): + from lerobot.common.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2 + + """Test sending large data with exact chunk size.""" + data = b"x" * CHUNK_SIZE + message_class = services_pb2.InteractionMessage + chunks = list(send_bytes_in_chunks(data, message_class)) + assert len(chunks) == 1 + assert chunks[0].data == data + assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END + + +@require_package("grpc") +def test_receive_bytes_in_chunks_empty_data(): + from lerobot.common.transport.utils import receive_bytes_in_chunks + + """Test receiving empty data.""" + queue = Queue() + shutdown_event = Event() + + # Empty iterator + receive_bytes_in_chunks(iter([]), queue, shutdown_event) + + assert queue.empty() + + +@require_package("grpc") +def test_receive_bytes_in_chunks_single_chunk(): + from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2 + + """Test receiving a single chunk message.""" + queue = Queue() + shutdown_event = Event() + + data = b"Single chunk data" + chunks = [ + services_pb2.InteractionMessage(data=data, transfer_state=services_pb2.TransferState.TRANSFER_END) + ] + + receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) + + assert queue.get(timeout=0.01) == data + assert queue.empty() + + +@require_package("grpc") +def test_receive_bytes_in_chunks_single_not_end_chunk(): + from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2 + + """Test receiving a single chunk message.""" + queue = Queue() + shutdown_event = Event() + + data = b"Single chunk data" + chunks = [ + services_pb2.InteractionMessage(data=data, transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE) + ] + + receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) + + assert queue.empty() + + +@require_package("grpc") +def test_receive_bytes_in_chunks_multiple_chunks(): + from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2 + + """Test receiving a multi-chunk message.""" + queue = Queue() + shutdown_event = Event() + + chunks = [ + services_pb2.InteractionMessage( + data=b"First ", transfer_state=services_pb2.TransferState.TRANSFER_BEGIN + ), + services_pb2.InteractionMessage( + data=b"Middle ", transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE + ), + services_pb2.InteractionMessage(data=b"Last", transfer_state=services_pb2.TransferState.TRANSFER_END), + ] + + receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) + + assert queue.get(timeout=0.01) == b"First Middle Last" + assert queue.empty() + + +@require_package("grpc") +def test_receive_bytes_in_chunks_multiple_messages(): + from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2 + + """Test receiving multiple complete messages in sequence.""" + queue = Queue() + shutdown_event = Event() + + chunks = [ + # First message - single chunk + services_pb2.InteractionMessage( + data=b"Message1", transfer_state=services_pb2.TransferState.TRANSFER_END + ), + # Second message - multi chunk + services_pb2.InteractionMessage( + data=b"Start2 ", transfer_state=services_pb2.TransferState.TRANSFER_BEGIN + ), + services_pb2.InteractionMessage( + data=b"Middle2 ", transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE + ), + services_pb2.InteractionMessage(data=b"End2", transfer_state=services_pb2.TransferState.TRANSFER_END), + # Third message - single chunk + services_pb2.InteractionMessage( + data=b"Message3", transfer_state=services_pb2.TransferState.TRANSFER_END + ), + ] + + receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) + + # Should have three messages in queue + assert queue.get(timeout=0.01) == b"Message1" + assert queue.get(timeout=0.01) == b"Start2 Middle2 End2" + assert queue.get(timeout=0.01) == b"Message3" + assert queue.empty() + + +@require_package("grpc") +def test_receive_bytes_in_chunks_shutdown_during_receive(): + from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2 + + """Test that shutdown event stops receiving mid-stream.""" + queue = Queue() + shutdown_event = Event() + shutdown_event.set() + + chunks = [ + services_pb2.InteractionMessage( + data=b"First ", transfer_state=services_pb2.TransferState.TRANSFER_BEGIN + ), + services_pb2.InteractionMessage( + data=b"Middle ", transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE + ), + services_pb2.InteractionMessage(data=b"Last", transfer_state=services_pb2.TransferState.TRANSFER_END), + ] + + receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) + + assert queue.empty() + + +@require_package("grpc") +def test_receive_bytes_in_chunks_only_begin_chunk(): + from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2 + + """Test receiving only a BEGIN chunk without END.""" + queue = Queue() + shutdown_event = Event() + + chunks = [ + services_pb2.InteractionMessage( + data=b"Start", transfer_state=services_pb2.TransferState.TRANSFER_BEGIN + ), + # No END chunk + ] + + receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) + + assert queue.empty() + + +@require_package("grpc") +def test_receive_bytes_in_chunks_missing_begin(): + from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2 + + """Test receiving chunks starting with MIDDLE instead of BEGIN.""" + queue = Queue() + shutdown_event = Event() + + chunks = [ + # Missing BEGIN + services_pb2.InteractionMessage( + data=b"Middle", transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE + ), + services_pb2.InteractionMessage(data=b"End", transfer_state=services_pb2.TransferState.TRANSFER_END), + ] + + receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) + + # The implementation continues from where it is, so we should get partial data + assert queue.get(timeout=0.01) == b"MiddleEnd" + assert queue.empty() + + +# Tests for state_to_bytes and bytes_to_state_dict +@require_package("grpc") +def test_state_to_bytes_empty_dict(): + from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes + + """Test converting empty state dict to bytes.""" + state_dict = {} + data = state_to_bytes(state_dict) + reconstructed = bytes_to_state_dict(data) + assert reconstructed == state_dict + + +@require_package("grpc") +def test_bytes_to_state_dict_empty_data(): + from lerobot.common.transport.utils import bytes_to_state_dict + + """Test converting empty data to state dict.""" + with pytest.raises(EOFError): + bytes_to_state_dict(b"") + + +@require_package("grpc") +def test_state_to_bytes_simple_dict(): + from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes + + """Test converting simple state dict to bytes.""" + state_dict = { + "layer1.weight": torch.randn(10, 5), + "layer1.bias": torch.randn(10), + "layer2.weight": torch.randn(1, 10), + "layer2.bias": torch.randn(1), + } + + data = state_to_bytes(state_dict) + assert isinstance(data, bytes) + assert len(data) > 0 + + reconstructed = bytes_to_state_dict(data) + + assert len(reconstructed) == len(state_dict) + for key in state_dict: + assert key in reconstructed + assert torch.allclose(state_dict[key], reconstructed[key]) + + +@require_package("grpc") +def test_state_to_bytes_various_dtypes(): + from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes + + """Test converting state dict with various tensor dtypes.""" + state_dict = { + "float32": torch.randn(5, 5), + "float64": torch.randn(3, 3).double(), + "int32": torch.randint(0, 100, (4, 4), dtype=torch.int32), + "int64": torch.randint(0, 100, (2, 2), dtype=torch.int64), + "bool": torch.tensor([True, False, True]), + "uint8": torch.randint(0, 255, (3, 3), dtype=torch.uint8), + } + + data = state_to_bytes(state_dict) + reconstructed = bytes_to_state_dict(data) + + for key in state_dict: + assert reconstructed[key].dtype == state_dict[key].dtype + if state_dict[key].dtype == torch.bool: + assert torch.equal(state_dict[key], reconstructed[key]) + else: + assert torch.allclose(state_dict[key], reconstructed[key]) + + +@require_package("grpc") +def test_bytes_to_state_dict_invalid_data(): + from lerobot.common.transport.utils import bytes_to_state_dict + + """Test bytes_to_state_dict with invalid data.""" + with pytest.raises(UnpicklingError): + bytes_to_state_dict(b"This is not a valid torch save file") + + +@require_cuda +@require_package("grpc") +def test_state_to_bytes_various_dtypes_cuda(): + from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes + + """Test converting state dict with various tensor dtypes.""" + state_dict = { + "float32": torch.randn(5, 5).cuda(), + "float64": torch.randn(3, 3).double().cuda(), + "int32": torch.randint(0, 100, (4, 4), dtype=torch.int32).cuda(), + "int64": torch.randint(0, 100, (2, 2), dtype=torch.int64).cuda(), + "bool": torch.tensor([True, False, True]), + "uint8": torch.randint(0, 255, (3, 3), dtype=torch.uint8), + } + + data = state_to_bytes(state_dict) + reconstructed = bytes_to_state_dict(data) + + for key in state_dict: + assert reconstructed[key].dtype == state_dict[key].dtype + if state_dict[key].dtype == torch.bool: + assert torch.equal(state_dict[key], reconstructed[key]) + else: + assert torch.allclose(state_dict[key], reconstructed[key]) + + +@require_package("grpc") +def test_python_object_to_bytes_none(): + from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes + + """Test converting None to bytes.""" + obj = None + data = python_object_to_bytes(obj) + reconstructed = bytes_to_python_object(data) + assert reconstructed is None + + +@pytest.mark.parametrize( + "obj", + [ + 42, + -123, + 3.14159, + -2.71828, + "Hello, World!", + "Unicode: 你好世界 🌍", + True, + False, + b"byte string", + [], + [1, 2, 3], + [1, "two", 3.0, True, None], + {}, + {"key": "value", "number": 123, "nested": {"a": 1}}, + (), + (1, 2, 3), + ], +) +@require_package("grpc") +def test_python_object_to_bytes_simple_types(obj): + from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes + + """Test converting simple Python types.""" + data = python_object_to_bytes(obj) + reconstructed = bytes_to_python_object(data) + assert reconstructed == obj + assert type(reconstructed) is type(obj) + + +@require_package("grpc") +def test_python_object_to_bytes_with_tensors(): + from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes + + """Test converting objects containing PyTorch tensors.""" + obj = { + "tensor": torch.randn(5, 5), + "list_with_tensor": [1, 2, torch.randn(3, 3), "string"], + "nested": { + "tensor1": torch.randn(2, 2), + "tensor2": torch.tensor([1, 2, 3]), + }, + } + + data = python_object_to_bytes(obj) + reconstructed = bytes_to_python_object(data) + + assert torch.allclose(obj["tensor"], reconstructed["tensor"]) + assert reconstructed["list_with_tensor"][0] == 1 + assert reconstructed["list_with_tensor"][3] == "string" + assert torch.allclose(obj["list_with_tensor"][2], reconstructed["list_with_tensor"][2]) + assert torch.allclose(obj["nested"]["tensor1"], reconstructed["nested"]["tensor1"]) + assert torch.equal(obj["nested"]["tensor2"], reconstructed["nested"]["tensor2"]) + + +@require_package("grpc") +def test_transitions_to_bytes_empty_list(): + from lerobot.common.transport.utils import bytes_to_transitions, transitions_to_bytes + + """Test converting empty transitions list.""" + transitions = [] + data = transitions_to_bytes(transitions) + reconstructed = bytes_to_transitions(data) + assert reconstructed == transitions + assert isinstance(reconstructed, list) + + +@require_package("grpc") +def test_transitions_to_bytes_single_transition(): + from lerobot.common.transport.utils import bytes_to_transitions, transitions_to_bytes + + """Test converting a single transition.""" + transition = Transition( + state={"image": torch.randn(3, 64, 64), "state": torch.randn(10)}, + action=torch.randn(5), + reward=torch.tensor(1.5), + done=torch.tensor(False), + next_state={"image": torch.randn(3, 64, 64), "state": torch.randn(10)}, + ) + + transitions = [transition] + data = transitions_to_bytes(transitions) + reconstructed = bytes_to_transitions(data) + + assert len(reconstructed) == 1 + + assert_transitions_equal(transitions[0], reconstructed[0]) + + +@require_package("grpc") +def assert_transitions_equal(t1: Transition, t2: Transition): + """Helper to assert two transitions are equal.""" + assert_observation_equal(t1["state"], t2["state"]) + assert torch.allclose(t1["action"], t2["action"]) + assert torch.allclose(t1["reward"], t2["reward"]) + assert torch.equal(t1["done"], t2["done"]) + assert_observation_equal(t1["next_state"], t2["next_state"]) + + +@require_package("grpc") +def assert_observation_equal(o1: dict, o2: dict): + """Helper to assert two observations are equal.""" + assert set(o1.keys()) == set(o2.keys()) + for key in o1: + assert torch.allclose(o1[key], o2[key]) + + +@require_package("grpc") +def test_transitions_to_bytes_multiple_transitions(): + from lerobot.common.transport.utils import bytes_to_transitions, transitions_to_bytes + + """Test converting multiple transitions.""" + transitions = [] + for i in range(5): + transition = Transition( + state={"data": torch.randn(10)}, + action=torch.randn(3), + reward=torch.tensor(float(i)), + done=torch.tensor(i == 4), + next_state={"data": torch.randn(10)}, + ) + transitions.append(transition) + + data = transitions_to_bytes(transitions) + reconstructed = bytes_to_transitions(data) + + assert len(reconstructed) == len(transitions) + for original, reconstructed_item in zip(transitions, reconstructed, strict=False): + assert_transitions_equal(original, reconstructed_item) + + +@require_package("grpc") +def test_receive_bytes_in_chunks_unknown_state(): + from lerobot.common.transport.utils import receive_bytes_in_chunks + + """Test receive_bytes_in_chunks with an unknown transfer state.""" + + # Mock the gRPC message object, which has `transfer_state` and `data` attributes. + class MockMessage: + def __init__(self, transfer_state, data): + self.transfer_state = transfer_state + self.data = data + + # 10 is not a valid TransferState enum value + bad_iterator = [MockMessage(transfer_state=10, data=b"bad_data")] + output_queue = Queue() + shutdown_event = Event() + + with pytest.raises(ValueError, match="Received unknown transfer state"): + receive_bytes_in_chunks(bad_iterator, output_queue, shutdown_event) diff --git a/tests/utils/test_process.py b/tests/utils/test_process.py new file mode 100644 index 0000000000..054a8593a5 --- /dev/null +++ b/tests/utils/test_process.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing +import os +import signal +import threading +from unittest.mock import patch + +import pytest + +from lerobot.common.utils.process import ProcessSignalHandler + + +# Fixture to reset shutdown_event_counter and original signal handlers before and after each test +@pytest.fixture(autouse=True) +def reset_globals_and_handlers(): + # Store original signal handlers + original_handlers = { + sig: signal.getsignal(sig) + for sig in [signal.SIGINT, signal.SIGTERM, signal.SIGHUP, signal.SIGQUIT] + if hasattr(signal, sig.name) + } + + yield + + # Restore original signal handlers + for sig, handler in original_handlers.items(): + signal.signal(sig, handler) + + +def test_setup_process_handlers_event_with_threads(): + """Test that setup_process_handlers returns the correct event type.""" + handler = ProcessSignalHandler(use_threads=True) + shutdown_event = handler.shutdown_event + assert isinstance(shutdown_event, threading.Event), "Should be a threading.Event" + assert not shutdown_event.is_set(), "Event should initially be unset" + + +def test_setup_process_handlers_event_with_processes(): + """Test that setup_process_handlers returns the correct event type.""" + handler = ProcessSignalHandler(use_threads=False) + shutdown_event = handler.shutdown_event + assert isinstance(shutdown_event, type(multiprocessing.Event())), "Should be a multiprocessing.Event" + assert not shutdown_event.is_set(), "Event should initially be unset" + + +@pytest.mark.parametrize("use_threads", [True, False]) +@pytest.mark.parametrize( + "sig", + [ + signal.SIGINT, + signal.SIGTERM, + # SIGHUP and SIGQUIT are not reliably available on all platforms (e.g. Windows) + pytest.param( + signal.SIGHUP, + marks=pytest.mark.skipif(not hasattr(signal, "SIGHUP"), reason="SIGHUP not available"), + ), + pytest.param( + signal.SIGQUIT, + marks=pytest.mark.skipif(not hasattr(signal, "SIGQUIT"), reason="SIGQUIT not available"), + ), + ], +) +def test_signal_handler_sets_event(use_threads, sig): + """Test that the signal handler sets the event on receiving a signal.""" + handler = ProcessSignalHandler(use_threads=use_threads) + shutdown_event = handler.shutdown_event + + assert handler.counter == 0 + + os.kill(os.getpid(), sig) + + # In some environments, the signal might take a moment to be handled. + shutdown_event.wait(timeout=1.0) + + assert shutdown_event.is_set(), f"Event should be set after receiving signal {sig}" + + # Ensure the internal counter was incremented + assert handler.counter == 1 + + +@pytest.mark.parametrize("use_threads", [True, False]) +@patch("sys.exit") +def test_force_shutdown_on_second_signal(mock_sys_exit, use_threads): + """Test that a second signal triggers a force shutdown.""" + handler = ProcessSignalHandler(use_threads=use_threads) + + os.kill(os.getpid(), signal.SIGINT) + # Give a moment for the first signal to be processed + import time + + time.sleep(0.1) + os.kill(os.getpid(), signal.SIGINT) + + time.sleep(0.1) + + assert handler.counter == 2 + mock_sys_exit.assert_called_once_with(1) diff --git a/tests/utils/test_queue.py b/tests/utils/test_queue.py new file mode 100644 index 0000000000..863231e82b --- /dev/null +++ b/tests/utils/test_queue.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +import time +from queue import Queue + +from lerobot.common.utils.queue import get_last_item_from_queue + + +def test_get_last_item_single_item(): + """Test getting the last item when queue has only one item.""" + queue = Queue() + queue.put("single_item") + + result = get_last_item_from_queue(queue) + + assert result == "single_item" + assert queue.empty() + + +def test_get_last_item_multiple_items(): + """Test getting the last item when queue has multiple items.""" + queue = Queue() + items = ["first", "second", "third", "fourth", "last"] + + for item in items: + queue.put(item) + + result = get_last_item_from_queue(queue) + + assert result == "last" + assert queue.empty() + + +def test_get_last_item_different_types(): + """Test with different data types in the queue.""" + queue = Queue() + items = [1, 2.5, "string", {"key": "value"}, [1, 2, 3], ("tuple", "data")] + + for item in items: + queue.put(item) + + result = get_last_item_from_queue(queue) + + assert result == ("tuple", "data") + assert queue.empty() + + +def test_get_last_item_maxsize_queue(): + """Test with a queue that has a maximum size.""" + queue = Queue(maxsize=5) + + # Fill the queue + for i in range(5): + queue.put(i) + + # Give the queue time to fill + time.sleep(0.1) + + result = get_last_item_from_queue(queue) + + assert result == 4 + assert queue.empty() + + +def test_get_last_item_with_none_values(): + """Test with None values in the queue.""" + queue = Queue() + items = [1, None, 2, None, 3] + + for item in items: + queue.put(item) + + # Give the queue time to fill + time.sleep(0.1) + + result = get_last_item_from_queue(queue) + + assert result == 3 + assert queue.empty() + + +def test_get_last_item_blocking_timeout(): + """Test get_last_item_from_queue returns None on timeout.""" + queue = Queue() + result = get_last_item_from_queue(queue, block=True, timeout=0.1) + assert result is None + + +def test_get_last_item_non_blocking_empty(): + """Test get_last_item_from_queue with block=False on an empty queue returns None.""" + queue = Queue() + result = get_last_item_from_queue(queue, block=False) + assert result is None + + +def test_get_last_item_non_blocking_success(): + """Test get_last_item_from_queue with block=False on a non-empty queue.""" + queue = Queue() + items = ["first", "second", "last"] + for item in items: + queue.put(item) + + # Give the queue time to fill + time.sleep(0.1) + + result = get_last_item_from_queue(queue, block=False) + assert result == "last" + assert queue.empty() + + +def test_get_last_item_blocking_waits_for_item(): + """Test that get_last_item_from_queue waits for an item if block=True.""" + queue = Queue() + result = [] + + def producer(): + queue.put("item1") + queue.put("item2") + + def consumer(): + # This will block until the producer puts the first item + item = get_last_item_from_queue(queue, block=True, timeout=0.2) + result.append(item) + + producer_thread = threading.Thread(target=producer) + consumer_thread = threading.Thread(target=consumer) + + producer_thread.start() + consumer_thread.start() + + producer_thread.join() + consumer_thread.join() + + assert result == ["item2"] + assert queue.empty() diff --git a/tests/utils/test_replay_buffer.py b/tests/utils/test_replay_buffer.py new file mode 100644 index 0000000000..f7a055b20f --- /dev/null +++ b/tests/utils/test_replay_buffer.py @@ -0,0 +1,682 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from typing import Callable + +import pytest +import torch + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.utils.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized +from tests.fixtures.constants import DUMMY_REPO_ID + + +def state_dims() -> list[str]: + return ["observation.image", "observation.state"] + + +@pytest.fixture +def replay_buffer() -> ReplayBuffer: + return create_empty_replay_buffer() + + +def clone_state(state: dict) -> dict: + return {k: v.clone() for k, v in state.items()} + + +def create_empty_replay_buffer( + optimize_memory: bool = False, + use_drq: bool = False, + image_augmentation_function: Callable | None = None, +) -> ReplayBuffer: + buffer_capacity = 10 + device = "cpu" + return ReplayBuffer( + buffer_capacity, + device, + state_dims(), + optimize_memory=optimize_memory, + use_drq=use_drq, + image_augmentation_function=image_augmentation_function, + ) + + +def create_random_image() -> torch.Tensor: + return torch.rand(3, 84, 84) + + +def create_dummy_transition() -> dict: + return { + "observation.image": create_random_image(), + "action": torch.randn(4), + "reward": torch.tensor(1.0), + "observation.state": torch.randn( + 10, + ), + "done": torch.tensor(False), + "truncated": torch.tensor(False), + "complementary_info": {}, + } + + +def create_dataset_from_replay_buffer(tmp_path) -> tuple[LeRobotDataset, ReplayBuffer]: + dummy_state_1 = create_dummy_state() + dummy_action_1 = create_dummy_action() + + dummy_state_2 = create_dummy_state() + dummy_action_2 = create_dummy_action() + + dummy_state_3 = create_dummy_state() + dummy_action_3 = create_dummy_action() + + dummy_state_4 = create_dummy_state() + dummy_action_4 = create_dummy_action() + + replay_buffer = create_empty_replay_buffer() + replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False) + replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False) + replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True) + replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True) + + root = tmp_path / "test" + return (replay_buffer.to_lerobot_dataset(DUMMY_REPO_ID, root=root), replay_buffer) + + +def create_dummy_state() -> dict: + return { + "observation.image": create_random_image(), + "observation.state": torch.randn( + 10, + ), + } + + +def get_tensor_memory_consumption(tensor): + return tensor.nelement() * tensor.element_size() + + +def get_tensors_memory_consumption(obj, visited_addresses): + total_size = 0 + + address = id(obj) + if address in visited_addresses: + return 0 + + visited_addresses.add(address) + + if isinstance(obj, torch.Tensor): + return get_tensor_memory_consumption(obj) + elif isinstance(obj, (list, tuple)): + for item in obj: + total_size += get_tensors_memory_consumption(item, visited_addresses) + elif isinstance(obj, dict): + for value in obj.values(): + total_size += get_tensors_memory_consumption(value, visited_addresses) + elif hasattr(obj, "__dict__"): + # It's an object, we need to get the size of the attributes + for _, attr in vars(obj).items(): + total_size += get_tensors_memory_consumption(attr, visited_addresses) + + return total_size + + +def get_object_memory(obj): + # Track visited addresses to avoid infinite loops + # and cases when two properties point to the same object + visited_addresses = set() + + # Get the size of the object in bytes + total_size = sys.getsizeof(obj) + + # Get the size of the tensor attributes + total_size += get_tensors_memory_consumption(obj, visited_addresses) + + return total_size + + +def create_dummy_action() -> torch.Tensor: + return torch.randn(4) + + +def dict_properties() -> list: + return ["state", "next_state"] + + +@pytest.fixture +def dummy_state() -> dict: + return create_dummy_state() + + +@pytest.fixture +def next_dummy_state() -> dict: + return create_dummy_state() + + +@pytest.fixture +def dummy_action() -> torch.Tensor: + return torch.randn(4) + + +def test_empty_buffer_sample_raises_error(replay_buffer): + assert len(replay_buffer) == 0, "Replay buffer should be empty." + assert replay_buffer.capacity == 10, "Replay buffer capacity should be 10." + with pytest.raises(RuntimeError, match="Cannot sample from an empty buffer"): + replay_buffer.sample(1) + + +def test_zero_capacity_buffer_raises_error(): + with pytest.raises(ValueError, match="Capacity must be greater than 0."): + ReplayBuffer(0, "cpu", ["observation", "next_observation"]) + + +def test_add_transition(replay_buffer, dummy_state, dummy_action): + replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False) + assert len(replay_buffer) == 1, "Replay buffer should have one transition after adding." + assert torch.equal(replay_buffer.actions[0], dummy_action), ( + "Action should be equal to the first transition." + ) + assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the first transition." + assert not replay_buffer.dones[0], "Done should be False for the first transition." + assert not replay_buffer.truncateds[0], "Truncated should be False for the first transition." + + for dim in state_dims(): + assert torch.equal(replay_buffer.states[dim][0], dummy_state[dim]), ( + "Observation should be equal to the first transition." + ) + assert torch.equal(replay_buffer.next_states[dim][0], dummy_state[dim]), ( + "Next observation should be equal to the first transition." + ) + + +def test_add_over_capacity(): + replay_buffer = ReplayBuffer(2, "cpu", ["observation", "next_observation"]) + dummy_state_1 = create_dummy_state() + dummy_action_1 = create_dummy_action() + + dummy_state_2 = create_dummy_state() + dummy_action_2 = create_dummy_action() + + dummy_state_3 = create_dummy_state() + dummy_action_3 = create_dummy_action() + + replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False) + replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False) + replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True) + + assert len(replay_buffer) == 2, "Replay buffer should have 2 transitions after adding 3." + + for dim in state_dims(): + assert torch.equal(replay_buffer.states[dim][0], dummy_state_3[dim]), ( + "Observation should be equal to the first transition." + ) + assert torch.equal(replay_buffer.next_states[dim][0], dummy_state_3[dim]), ( + "Next observation should be equal to the first transition." + ) + + assert torch.equal(replay_buffer.actions[0], dummy_action_3), ( + "Action should be equal to the last transition." + ) + assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the last transition." + assert replay_buffer.dones[0], "Done should be True for the first transition." + assert replay_buffer.truncateds[0], "Truncated should be True for the first transition." + + +def test_sample_from_empty_buffer(replay_buffer): + with pytest.raises(RuntimeError, match="Cannot sample from an empty buffer"): + replay_buffer.sample(1) + + +def test_sample_with_1_transition(replay_buffer, dummy_state, next_dummy_state, dummy_action): + replay_buffer.add(dummy_state, dummy_action, 1.0, next_dummy_state, False, False) + got_batch_transition = replay_buffer.sample(1) + + expected_batch_transition = BatchTransition( + state=clone_state(dummy_state), + action=dummy_action.clone(), + reward=1.0, + next_state=clone_state(next_dummy_state), + done=False, + truncated=False, + ) + + for buffer_property in dict_properties(): + for k, v in expected_batch_transition[buffer_property].items(): + got_state = got_batch_transition[buffer_property][k] + + assert got_state.shape[0] == 1, f"{k} should have 1 transition." + assert got_state.device.type == "cpu", f"{k} should be on cpu." + + assert torch.equal(got_state[0], v), f"{k} should be equal to the expected batch transition." + + for key, _value in expected_batch_transition.items(): + if key in dict_properties(): + continue + + got_value = got_batch_transition[key] + + v_tensor = expected_batch_transition[key] + if not isinstance(v_tensor, torch.Tensor): + v_tensor = torch.tensor(v_tensor) + + assert got_value.shape[0] == 1, f"{key} should have 1 transition." + assert got_value.device.type == "cpu", f"{key} should be on cpu." + assert torch.equal(got_value[0], v_tensor), f"{key} should be equal to the expected batch transition." + + +def test_sample_with_batch_bigger_than_buffer_size( + replay_buffer, dummy_state, next_dummy_state, dummy_action +): + replay_buffer.add(dummy_state, dummy_action, 1.0, next_dummy_state, False, False) + got_batch_transition = replay_buffer.sample(10) + + expected_batch_transition = BatchTransition( + state=dummy_state, + action=dummy_action, + reward=1.0, + next_state=next_dummy_state, + done=False, + truncated=False, + ) + + for buffer_property in dict_properties(): + for k in expected_batch_transition[buffer_property]: + got_state = got_batch_transition[buffer_property][k] + + assert got_state.shape[0] == 1, f"{k} should have 1 transition." + + for key in expected_batch_transition: + if key in dict_properties(): + continue + + got_value = got_batch_transition[key] + assert got_value.shape[0] == 1, f"{key} should have 1 transition." + + +def test_sample_batch(replay_buffer): + dummy_state_1 = create_dummy_state() + dummy_action_1 = create_dummy_action() + + dummy_state_2 = create_dummy_state() + dummy_action_2 = create_dummy_action() + + dummy_state_3 = create_dummy_state() + dummy_action_3 = create_dummy_action() + + dummy_state_4 = create_dummy_state() + dummy_action_4 = create_dummy_action() + + replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False) + replay_buffer.add(dummy_state_2, dummy_action_2, 2.0, dummy_state_2, False, False) + replay_buffer.add(dummy_state_3, dummy_action_3, 3.0, dummy_state_3, True, True) + replay_buffer.add(dummy_state_4, dummy_action_4, 4.0, dummy_state_4, True, True) + + dummy_states = [dummy_state_1, dummy_state_2, dummy_state_3, dummy_state_4] + dummy_actions = [dummy_action_1, dummy_action_2, dummy_action_3, dummy_action_4] + + got_batch_transition = replay_buffer.sample(3) + + for buffer_property in dict_properties(): + for k in got_batch_transition[buffer_property]: + got_state = got_batch_transition[buffer_property][k] + + assert got_state.shape[0] == 3, f"{k} should have 3 transition." + + for got_state_item in got_state: + assert any(torch.equal(got_state_item, dummy_state[k]) for dummy_state in dummy_states), ( + f"{k} should be equal to one of the dummy states." + ) + + for got_action_item in got_batch_transition["action"]: + assert any(torch.equal(got_action_item, dummy_action) for dummy_action in dummy_actions), ( + "Actions should be equal to the dummy actions." + ) + + for k in got_batch_transition: + if k in dict_properties() or k == "complementary_info": + continue + + got_value = got_batch_transition[k] + assert got_value.shape[0] == 3, f"{k} should have 3 transition." + + +def test_to_lerobot_dataset_with_empty_buffer(replay_buffer): + with pytest.raises(ValueError, match="The replay buffer is empty. Cannot convert to a dataset."): + replay_buffer.to_lerobot_dataset("dummy_repo") + + +def test_to_lerobot_dataset(tmp_path): + ds, buffer = create_dataset_from_replay_buffer(tmp_path) + + assert len(ds) == len(buffer), "Dataset should have the same size as the Replay Buffer" + assert ds.fps == 1, "FPS should be 1" + assert ds.repo_id == "dummy/repo", "The dataset should have `dummy/repo` repo id" + + for dim in state_dims(): + assert dim in ds.features + assert ds.features[dim]["shape"] == buffer.states[dim][0].shape + + assert ds.num_episodes == 2 + assert ds.num_frames == 4 + + for j, value in enumerate(ds): + print(torch.equal(value["observation.image"], buffer.next_states["observation.image"][j])) + + for i in range(len(ds)): + for feature, value in ds[i].items(): + if feature == "action": + assert torch.equal(value, buffer.actions[i]) + elif feature == "next.reward": + assert torch.equal(value, buffer.rewards[i]) + elif feature == "next.done": + assert torch.equal(value, buffer.dones[i]) + elif feature == "observation.image": + # Tenssor -> numpy is not precise, so we have some diff there + # TODO: Check and fix it + torch.testing.assert_close(value, buffer.states["observation.image"][i], rtol=0.3, atol=0.003) + elif feature == "observation.state": + assert torch.equal(value, buffer.states["observation.state"][i]) + + +def test_from_lerobot_dataset(tmp_path): + dummy_state_1 = create_dummy_state() + dummy_action_1 = create_dummy_action() + + dummy_state_2 = create_dummy_state() + dummy_action_2 = create_dummy_action() + + dummy_state_3 = create_dummy_state() + dummy_action_3 = create_dummy_action() + + dummy_state_4 = create_dummy_state() + dummy_action_4 = create_dummy_action() + + replay_buffer = create_empty_replay_buffer() + replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False) + replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False) + replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True) + replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True) + + root = tmp_path / "test" + ds = replay_buffer.to_lerobot_dataset(DUMMY_REPO_ID, root=root) + + reconverted_buffer = ReplayBuffer.from_lerobot_dataset( + ds, state_keys=list(state_dims()), device="cpu", capacity=replay_buffer.capacity, use_drq=False + ) + + # Check only the part of the buffer that's actually filled with data + assert torch.equal( + reconverted_buffer.actions[: len(replay_buffer)], + replay_buffer.actions[: len(replay_buffer)], + ), "Actions from converted buffer should be equal to the original replay buffer." + assert torch.equal( + reconverted_buffer.rewards[: len(replay_buffer)], replay_buffer.rewards[: len(replay_buffer)] + ), "Rewards from converted buffer should be equal to the original replay buffer." + assert torch.equal( + reconverted_buffer.dones[: len(replay_buffer)], replay_buffer.dones[: len(replay_buffer)] + ), "Dones from converted buffer should be equal to the original replay buffer." + + # Lerobot DS haven't supported truncateds yet + expected_truncateds = torch.zeros(len(replay_buffer)).bool() + assert torch.equal(reconverted_buffer.truncateds[: len(replay_buffer)], expected_truncateds), ( + "Truncateds from converted buffer should be equal False" + ) + + assert torch.equal( + replay_buffer.states["observation.state"][: len(replay_buffer)], + reconverted_buffer.states["observation.state"][: len(replay_buffer)], + ), "State should be the same after converting to dataset and return back" + + for i in range(4): + torch.testing.assert_close( + replay_buffer.states["observation.image"][i], + reconverted_buffer.states["observation.image"][i], + rtol=0.4, + atol=0.004, + ) + + # The 2, 3 frames have done flag, so their values will be equal to the current state + for i in range(2): + # In the current implementation we take the next state from the `states` and ignore `next_states` + next_index = (i + 1) % 4 + + torch.testing.assert_close( + replay_buffer.states["observation.image"][next_index], + reconverted_buffer.next_states["observation.image"][i], + rtol=0.4, + atol=0.004, + ) + + for i in range(2, 4): + assert torch.equal( + replay_buffer.states["observation.state"][i], + reconverted_buffer.next_states["observation.state"][i], + ) + + +def test_buffer_sample_alignment(): + # Initialize buffer + buffer = ReplayBuffer(capacity=100, device="cpu", state_keys=["state_value"], storage_device="cpu") + + # Fill buffer with patterned data + for i in range(100): + signature = float(i) / 100.0 + state = {"state_value": torch.tensor([[signature]]).float()} + action = torch.tensor([[2.0 * signature]]).float() + reward = 3.0 * signature + + is_end = (i + 1) % 10 == 0 + if is_end: + next_state = {"state_value": torch.tensor([[signature]]).float()} + done = True + else: + next_signature = float(i + 1) / 100.0 + next_state = {"state_value": torch.tensor([[next_signature]]).float()} + done = False + + buffer.add(state, action, reward, next_state, done, False) + + # Sample and verify + batch = buffer.sample(50) + + for i in range(50): + state_sig = batch["state"]["state_value"][i].item() + action_val = batch["action"][i].item() + reward_val = batch["reward"][i].item() + next_state_sig = batch["next_state"]["state_value"][i].item() + is_done = batch["done"][i].item() > 0.5 + + # Verify relationships + assert abs(action_val - 2.0 * state_sig) < 1e-4, ( + f"Action {action_val} should be 2x state signature {state_sig}" + ) + + assert abs(reward_val - 3.0 * state_sig) < 1e-4, ( + f"Reward {reward_val} should be 3x state signature {state_sig}" + ) + + if is_done: + assert abs(next_state_sig - state_sig) < 1e-4, ( + f"For done states, next_state {next_state_sig} should equal state {state_sig}" + ) + else: + # Either it's the next sequential state (+0.01) or same state (for episode boundaries) + valid_next = ( + abs(next_state_sig - state_sig - 0.01) < 1e-4 or abs(next_state_sig - state_sig) < 1e-4 + ) + assert valid_next, ( + f"Next state {next_state_sig} should be either state+0.01 or same as state {state_sig}" + ) + + +def test_memory_optimization(): + dummy_state_1 = create_dummy_state() + dummy_action_1 = create_dummy_action() + + dummy_state_2 = create_dummy_state() + dummy_action_2 = create_dummy_action() + + dummy_state_3 = create_dummy_state() + dummy_action_3 = create_dummy_action() + + dummy_state_4 = create_dummy_state() + dummy_action_4 = create_dummy_action() + + replay_buffer = create_empty_replay_buffer() + replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_2, False, False) + replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_3, False, False) + replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_4, False, False) + replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True) + + optimized_replay_buffer = create_empty_replay_buffer(True) + optimized_replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_2, False, False) + optimized_replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_3, False, False) + optimized_replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_4, False, False) + optimized_replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, None, True, True) + + assert get_object_memory(optimized_replay_buffer) < get_object_memory(replay_buffer), ( + "Optimized replay buffer should be smaller than the original replay buffer" + ) + + +def test_check_image_augmentations_with_drq_and_dummy_image_augmentation_function(dummy_state, dummy_action): + def dummy_image_augmentation_function(x): + return torch.ones_like(x) * 10 + + replay_buffer = create_empty_replay_buffer( + use_drq=True, image_augmentation_function=dummy_image_augmentation_function + ) + + replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False) + + sampled_transitions = replay_buffer.sample(1) + assert torch.all(sampled_transitions["state"]["observation.image"] == 10), ( + "Image augmentations should be applied" + ) + assert torch.all(sampled_transitions["next_state"]["observation.image"] == 10), ( + "Image augmentations should be applied" + ) + + +def test_check_image_augmentations_with_drq_and_default_image_augmentation_function( + dummy_state, dummy_action +): + replay_buffer = create_empty_replay_buffer(use_drq=True) + + replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False) + + # Let's check that it doesn't fail and shapes are correct + sampled_transitions = replay_buffer.sample(1) + assert sampled_transitions["state"]["observation.image"].shape == (1, 3, 84, 84) + assert sampled_transitions["next_state"]["observation.image"].shape == (1, 3, 84, 84) + + +def test_random_crop_vectorized_basic(): + # Create a batch of 2 images with known patterns + batch_size, channels, height, width = 2, 3, 10, 8 + images = torch.zeros((batch_size, channels, height, width)) + + # Fill with unique values for testing + for b in range(batch_size): + images[b] = b + 1 + + crop_size = (6, 4) # Smaller than original + cropped = random_crop_vectorized(images, crop_size) + + # Check output shape + assert cropped.shape == (batch_size, channels, *crop_size) + + # Check that values are preserved (should be either 1s or 2s for respective batches) + assert torch.all(cropped[0] == 1) + assert torch.all(cropped[1] == 2) + + +def test_random_crop_vectorized_invalid_size(): + images = torch.zeros((2, 3, 10, 8)) + + # Test crop size larger than image + with pytest.raises(ValueError, match="Requested crop size .* is bigger than the image size"): + random_crop_vectorized(images, (12, 8)) + + with pytest.raises(ValueError, match="Requested crop size .* is bigger than the image size"): + random_crop_vectorized(images, (10, 10)) + + +def _populate_buffer_for_async_test(capacity: int = 10) -> ReplayBuffer: + """Create a small buffer with deterministic 3×128×128 images and 11-D state.""" + buffer = ReplayBuffer( + capacity=capacity, + device="cpu", + state_keys=["observation.image", "observation.state"], + storage_device="cpu", + ) + + for i in range(capacity): + img = torch.ones(3, 128, 128) * i + state_vec = torch.arange(11).float() + i + state = { + "observation.image": img, + "observation.state": state_vec, + } + buffer.add( + state=state, + action=torch.tensor([0.0]), + reward=0.0, + next_state=state, + done=False, + truncated=False, + ) + return buffer + + +def test_async_iterator_shapes_basic(): + buffer = _populate_buffer_for_async_test() + batch_size = 2 + iterator = buffer.get_iterator(batch_size=batch_size, async_prefetch=True, queue_size=1) + batch = next(iterator) + + images = batch["state"]["observation.image"] + states = batch["state"]["observation.state"] + + assert images.shape == (batch_size, 3, 128, 128) + assert states.shape == (batch_size, 11) + + next_images = batch["next_state"]["observation.image"] + next_states = batch["next_state"]["observation.state"] + + assert next_images.shape == (batch_size, 3, 128, 128) + assert next_states.shape == (batch_size, 11) + + +def test_async_iterator_multiple_iterations(): + buffer = _populate_buffer_for_async_test() + batch_size = 2 + iterator = buffer.get_iterator(batch_size=batch_size, async_prefetch=True, queue_size=2) + + for _ in range(5): + batch = next(iterator) + images = batch["state"]["observation.image"] + states = batch["state"]["observation.state"] + assert images.shape == (batch_size, 3, 128, 128) + assert states.shape == (batch_size, 11) + + next_images = batch["next_state"]["observation.image"] + next_states = batch["next_state"]["observation.state"] + assert next_images.shape == (batch_size, 3, 128, 128) + assert next_states.shape == (batch_size, 11) + + # Ensure iterator can be disposed without blocking + del iterator