diff --git a/.gitignore b/.gitignore index 97b6af2f8d..4ab886933e 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,7 @@ outputs # VS Code .vscode +.devcontainer # HPC nautilus/*.yaml diff --git a/README.md b/README.md index 8462634f45..e98f35663a 100644 --- a/README.md +++ b/README.md @@ -418,6 +418,19 @@ Additionally, if you are using any of the particular policy architecture, pretra year={2024} } ``` + + +- [HIL-SERL](https://hil-serl.github.io/) +```bibtex +@Article{luo2024hilserl, +title={Precise and Dexterous Robotic Manipulation via Human-in-the-Loop Reinforcement Learning}, +author={Jianlan Luo and Charles Xu and Jeffrey Wu and Sergey Levine}, +year={2024}, +eprint={2410.21845}, +archivePrefix={arXiv}, +primaryClass={cs.RO} +} +``` ## Star History [![Star History Chart](https://api.star-history.com/svg?repos=huggingface/lerobot&type=Timeline)](https://star-history.com/#huggingface/lerobot&Timeline) diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 5e628dec37..37938358ff 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -9,6 +9,10 @@ title: Getting Started with Real-World Robots - local: cameras title: Cameras + - local: hilserl + title: Train a Robot with RL + - local: hilserl_sim + title: Train RL in Simulation title: "Tutorials" - sections: - local: so101 diff --git a/docs/source/hilserl.mdx b/docs/source/hilserl.mdx new file mode 100644 index 0000000000..149b25c687 --- /dev/null +++ b/docs/source/hilserl.mdx @@ -0,0 +1,547 @@ +# HIL-SERL Real Robot Training Workflow Guide + +In this tutorial you will go through the full Human-in-the-Loop Sample-Efficient Reinforcement Learning (HIL-SERL) workflow using LeRobot. You will master training a policy with RL on a real robot in just a few hours. + +HIL-SERL is a sample-efficient reinforcement learning algorithm that combines human demonstrations with online learning and human interventions. The approach starts from a small set of human demonstrations, uses them to train a reward classifier, and then employs an actor-learner architecture where humans can intervene during policy execution to guide exploration and correct unsafe behaviors. In this tutorial, you'll use a gamepad to provide interventions and control the robot during the learning process. + +It combines three key ingredients: + 1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point. + 2. **On-robot actor / learner loop with human interventions:** a distributed Soft Actor Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour. + 3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing and WandB monitoring keep the data useful and the hardware safe. + +Together these elements let HIL-SERL reach near-perfect task success and faster cycle times than imitation-only baselines. + +

+ HIL-SERL workflow +

+ +

HIL-SERL workflow, Luo et al. 2024

+ +This guide provides step-by-step instructions for training a robot policy using LeRobot's HilSerl implementation to train on a real robot. + +## What do I need? + +- A gamepad (recommended) or keyboard to control the robot +- A Nvidia GPU +- A real robot with a follower and leader arm (optional if you use the keyboard or the gamepad) + +## What kind of tasks can I train? + +One can use HIL-SERL to train on a variety of manipulation tasks. Some recommendations: +- Start with a simple task to understand how the system works. + - Push cube to a goal region + - Pick and lift cube with the gripper +- Avoid extremely long horizon tasks. Focus on tasks that can be completed in 5-10 seconds. +- Once you have a good idea of how the system works, you can try more complex tasks and longer horizons. + - Pick and place cube + - Bimanual tasks to pick objects with two arms + - Hand-over tasks to transfer objects from one arm to another + - Go crazy! + +## Install LeRobot with HIL-SERL + +To install LeRobot with HIL-SERL, you need to install the `hilserl` extra. + +```bash +pip install -e ".[hilserl]" +``` + +## Real Robot Training Workflow + +### Understanding Configuration + +The training process begins with proper configuration for the HILSerl environment. The configuration class of interest is `HILSerlRobotEnvConfig` in `lerobot/common/envs/configs.py`. Which is defined as: + +```python +class HILSerlRobotEnvConfig(EnvConfig): + robot: RobotConfig | None = None # Main robot agent (defined in `lerobot/common/robots`) + teleop: TeleoperatorConfig | None = None # Teleoperator agent, e.g., gamepad or leader arm, (defined in `lerobot/common/teleoperators`) + wrapper: EnvTransformConfig | None = None # Environment wrapper settings; check `lerobot/scripts/server/gym_manipulator.py` + fps: int = 10 # Control frequency + name: str = "real_robot" # Environment name + mode: str = None # "record", "replay", or None (for training) + repo_id: str | None = None # LeRobot dataset repository ID + dataset_root: str | None = None # Local dataset root (optional) + task: str = "" # Task identifier + num_episodes: int = 10 # Number of episodes for recording + episode: int = 0 # episode index for replay + device: str = "cuda" # Compute device + push_to_hub: bool = True # Whether to push the recorded datasets to Hub + pretrained_policy_name_or_path: str | None = None # For policy loading + reward_classifier_pretrained_path: str | None = None # For reward model + number_of_steps_after_success: int = 0 # For reward classifier, collect more positive examples after a success to train a classifier +``` + + +### Finding Robot Workspace Bounds + +Before collecting demonstrations, you need to determine the appropriate operational bounds for your robot. + +This helps simplify the problem of learning on the real robot in two ways: 1) by limiting the robot's operational space to a specific region that solves the task and avoids unnecessary or unsafe exploration, and 2) by allowing training in end-effector space rather than joint space. Empirically, learning in joint space for reinforcement learning in manipulation is often a harder problem - some tasks are nearly impossible to learn in joint space but become learnable when the action space is transformed to end-effector coordinates. + +**Using find_joint_limits.py** + +This script helps you find the safe operational bounds for your robot's end-effector. Given that you have a follower and leader arm, you can use the script to find the bounds for the follower arm that will be applied during training. +Bounding the action space will reduce the redundant exploration of the agent and guarantees safety. + +```bash +python -m lerobot.scripts.find_joint_limits \ + --robot.type=so100_follower \ + --robot.port=/dev/tty.usbmodem58760431541 \ + --robot.id=black \ + --teleop.type=so100_leader \ + --teleop.port=/dev/tty.usbmodem58760431551 \ + --teleop.id=blue +``` + +**Workflow** + +1. Run the script and move the robot through the space that solves the task +2. The script will record the minimum and maximum end-effector positions and the joint angles and prints them to the console, for example: + ``` + Max ee position [0.2417 0.2012 0.1027] + Min ee position [0.1663 -0.0823 0.0336] + Max joint positions [-20.0, -20.0, -20.0, -20.0, -20.0, -20.0] + Min joint positions [50.0, 50.0, 50.0, 50.0, 50.0, 50.0] + ``` +3. Use these values in the configuration of your teleoperation device (TeleoperatorConfig) under the `end_effector_bounds` field + +**Example Configuration** + +```json +"end_effector_bounds": { + "max": [0.24, 0.20, 0.10], + "min": [0.16, -0.08, 0.03] +} +``` + +### Collecting Demonstrations + +With the bounds defined, you can safely collect demonstrations for training. Training RL with off-policy algorithm allows us to use offline datasets collected in order to improve the efficiency of the learning process. + +**Setting Up Record Mode** + +Create a configuration file for recording demonstrations (or edit an existing one like [env_config_so100.json](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_so100.json)): + +1. Set `mode` to `"record"` +2. Specify a unique `repo_id` for your dataset (e.g., "username/task_name") +3. Set `num_episodes` to the number of demonstrations you want to collect +4. Set `crop_params_dict` to `null` initially (we'll determine crops later) +5. Configure `robot`, `cameras`, and other hardware settings + +Example configuration section: +```json +"mode": "record", +"repo_id": "username/pick_lift_cube", +"dataset_root": null, +"task": "pick_and_lift", +"num_episodes": 15, +"episode": 0, +"push_to_hub": true +``` + +### Using a Teleoperation Device + +Along with your robot, you will need a teleoperation device to control it in order to collect datasets of your task and perform interventions during the online training. +We support using a gamepad or a keyboard or the leader arm of the robot. + +HIL-Serl learns actions in the end-effector space of the robot. Therefore, the teleoperation will control the end-effector's x,y,z displacements. + +For that we need to define a version of the robot that takes actions in the end-effector space. Check the robot class `SO100FollowerEndEffector` and its configuration `SO100FollowerEndEffectorConfig` for the default parameters related to the end-effector space. + +```python +class SO100FollowerEndEffectorConfig(SO100FollowerConfig): + """Configuration for the SO100FollowerEndEffector robot.""" + + # Default bounds for the end-effector position (in meters) + end_effector_bounds: dict[str, list[float]] = field( # bounds for the end-effector in x,y,z direction + default_factory=lambda: { + "min": [-1.0, -1.0, -1.0], # min x, y, z + "max": [1.0, 1.0, 1.0], # max x, y, z + } + ) + + max_gripper_pos: float = 50 # maximum gripper position that the gripper will be open at + + end_effector_step_sizes: dict[str, float] = field( # maximum step size for the end-effector in x,y,z direction + default_factory=lambda: { + "x": 0.02, + "y": 0.02, + "z": 0.02, + } + ) +``` + +The `Teleoperator` defines the teleoperation device. You can check the list of available teleoperators in `lerobot/common/teleoperators`. + +**Setting up the Gamepad** + +The gamepad provides a very convenient way to control the robot and the episode state. + +To setup the gamepad, you need to set the `control_mode` to `"gamepad"` and define the `teleop` section in the configuration file. + +```json + "teleop": { + "type": "gamepad", + "use_gripper": true + }, +``` + +

+ Figure shows the control mappings on a Logitech gamepad. +

+

Gamepad button mapping for robot control and episode management

+ +**Setting up the SO101 leader** + +The SO101 leader arm has reduced gears that allows it to move and track the follower arm during exploration. Therefore, taking over is much smoother than the gearless SO100. + +To setup the SO101 leader, you need to set the `control_mode` to `"leader"` and define the `teleop` section in the configuration file. + +```json + "teleop": { + "type": "so101_leader", + "port": "/dev/tty.usbmodem585A0077921", # check your port number + "use_degrees": true + }, +``` + +In order to annotate the success/failure of the episode, **you will need** to use a keyboard to press `s` for success, `esc` for failure. +During the online training, press `space` to take over the policy and `space` again to give the control back to the policy. + +
+Video: SO101 leader teleoperation + +
+ +
+ +

SO101 leader teleoperation example, the leader tracks the follower, press `space` to intervene

+
+ +**Recording Demonstrations** + +Start the recording process, an example of the config file can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_so100.json): + +```bash +python lerobot/scripts/rl/gym_manipulator.py --config_path lerobot/configs/env_config_so100.json +``` + +During recording: +1. The robot will reset to the initial position defined in the configuration file `fixed_reset_joint_positions` +2. Complete the task successfully +3. The episode ends with a reward of 1 when you press the "success" button +4. If the time limit is reached, or the fail button is pressed, the episode ends with a reward of 0 +5. You can rerecord an episode by pressing the "rerecord" button +6. The process automatically continues to the next episode +7. After recording all episodes, the dataset is pushed to the Hugging Face Hub (optional) and saved locally + + +### Processing the Dataset + +After collecting demonstrations, process them to determine optimal camera crops. +Reinforcement learning is sensitive to background distractions, so it is important to crop the images to the relevant workspace area. + +Visual RL algorithms learn directly from pixel inputs, making them vulnerable to irrelevant visual information. Background elements like changing lighting, shadows, people moving, or objects outside the workspace can confuse the learning process. Good ROI selection should: +- Include only the essential workspace where the task happens +- Capture the robot's end-effector and all objects involved in the task +- Exclude unnecessary background elements and distractions + +Note: If you already know the crop parameters, you can skip this step and just set the `crop_params_dict` in the configuration file during recording. + +**Determining Crop Parameters** + +Use the `crop_dataset_roi.py` script to interactively select regions of interest in your camera images: + +```bash +python lerobot/scripts/rl/crop_dataset_roi.py --repo-id username/pick_lift_cube +``` + +1. For each camera view, the script will display the first frame +2. Draw a rectangle around the relevant workspace area +3. Press 'c' to confirm the selection +4. Repeat for all camera views +5. The script outputs cropping parameters and creates a new cropped dataset + +Example output: +``` +Selected Rectangular Regions of Interest (top, left, height, width): +observation.images.side: [180, 207, 180, 200] +observation.images.front: [180, 250, 120, 150] +``` + +

+ +

+ +

Interactive cropping tool for selecting regions of interest

+ + +**Updating Configuration** + +Add these crop parameters to your training configuration: + +```json +"crop_params_dict": { + "observation.images.side": [180, 207, 180, 200], + "observation.images.front": [180, 250, 120, 150] +}, +"resize_size": [128, 128] +``` + +**Recommended image resolution** + +Most vision-based policies have been validated on square inputs of either **128×128** (default) or **64×64** pixels. We therefore advise setting the resize_size parameter to [128, 128] – or [64, 64] if you need to save GPU memory and bandwidth. Other resolutions are possible but have not been extensively tested. + + +### Training a Reward Classifier + +The reward classifier plays an important role in the HIL-SERL workflow by automating reward assignment and automatically detecting episode success. Instead of manually defining reward functions or relying on human feedback for every timestep, the reward classifier learns to predict success/failure from visual observations. This enables the RL algorithm to learn efficiently by providing consistent and automated reward signals based on the robot's camera inputs. + +This guide explains how to train a reward classifier for human-in-the-loop reinforcement learning implementation of LeRobot. Reward classifiers learn to predict the reward value given a state which can be used in an RL setup to train a policy. + +**Note**: Training a reward classifier is optional. You can start the first round of RL experiments by annotating the success manually with your gamepad or keyboard device. + +The reward classifier implementation in `modeling_classifier.py` uses a pretrained vision model to process the images. It can output either a single value for binary rewards to predict success/fail cases or multiple values for multi-class settings. + +**Collecting a Dataset for the reward classifier** + +Before training, you need to collect a dataset with labeled examples. The `record_dataset` function in `gym_manipulator.py` enables the process of collecting a dataset of observations, actions, and rewards. + +To collect a dataset, you need to modify some parameters in the environment configuration based on HILSerlRobotEnvConfig. + +```bash +python lerobot/scripts/rl/gym_manipulator.py --config_path lerobot/configs/reward_classifier_train_config.json +``` + +**Key Parameters for Data Collection** + +- **mode**: set it to `"record"` to collect a dataset +- **repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub +- **num_episodes**: Number of episodes to record +- **number_of_steps_after_success**: Number of additional frames to record after a success (reward=1) is detected +- **fps**: Number of frames per second to record +- **push_to_hub**: Whether to push the dataset to the hub + +The `number_of_steps_after_success` parameter is crucial as it allows you to collect more positive examples. When a success is detected, the system will continue recording for the specified number of steps while maintaining the reward=1 label. Otherwise, there won't be enough states in the dataset labeled to 1 to train a good classifier. + +Example configuration section for data collection: + +```json +{ + "mode": "record", + "repo_id": "hf_username/dataset_name", + "dataset_root": "data/your_dataset", + "num_episodes": 20, + "push_to_hub": true, + "fps": 10, + "number_of_steps_after_success": 15 +} +``` + +**Reward Classifier Configuration** + +The reward classifier is configured using `configuration_classifier.py`. Here are the key parameters: + +- **model_name**: Base model architecture (e.g., we mainly use `"helper2424/resnet10"`) +- **model_type**: `"cnn"` or `"transformer"` +- **num_cameras**: Number of camera inputs +- **num_classes**: Number of output classes (typically 2 for binary success/failure) +- **hidden_dim**: Size of hidden representation +- **dropout_rate**: Regularization parameter +- **learning_rate**: Learning rate for optimizer + +Example configuration for training the [reward classifier](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/reward_classifier_train_config.json): + +```json +{ + "policy": { + "type": "reward_classifier", + "model_name": "helper2424/resnet10", + "model_type": "cnn", + "num_cameras": 2, + "num_classes": 2, + "hidden_dim": 256, + "dropout_rate": 0.1, + "learning_rate": 1e-4, + "device": "cuda", + "use_amp": true, + "input_features": { + "observation.images.front": { + "type": "VISUAL", + "shape": [3, 128, 128] + }, + "observation.images.side": { + "type": "VISUAL", + "shape": [3, 128, 128] + } + } + } +} +``` + +**Training the Classifier** + +To train the classifier, use the `train.py` script with your configuration: + +```bash +python lerobot/scripts/train.py --config_path path/to/reward_classifier_train_config.json +``` + +**Deploying and Testing the Model** + +To use your trained reward classifier, configure the `HILSerlRobotEnvConfig` to use your model: + +```python +env_config = HILSerlRobotEnvConfig( + reward_classifier_pretrained_path="path_to_your_pretrained_trained_model", + # Other environment parameters +) +``` +or set the argument in the json config file. + +```json +{ + "reward_classifier_pretrained_path": "path_to_your_pretrained_model" +} +``` + +Run `gym_manipulator.py` to test the model. +```bash +python lerobot/scripts/rl/gym_manipulator.py --config_path path/to/env_config.json +``` + +The reward classifier will automatically provide rewards based on the visual input from the robot's cameras. + +**Example Workflow for training the reward classifier** + +1. **Create the configuration files**: + Create the necessary json configuration files for the reward classifier and the environment. Check the examples [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/tree/main). + +2. **Collect a dataset**: + ```bash + python lerobot/scripts/rl/gym_manipulator.py --config_path lerobot/configs/env_config.json + ``` + +3. **Train the classifier**: + ```bash + python lerobot/scripts/train.py --config_path lerobot/configs/reward_classifier_train_config.json + ``` + +4. **Test the classifier**: + ```bash + python lerobot/scripts/rl/gym_manipulator.py --config_path lerobot/configs/env_config.json + ``` + +### Training with Actor-Learner + +The LeRobot system uses a distributed actor-learner architecture for training. This architecture decouples robot interactions from the learning process, allowing them to run concurrently without blocking each other. The actor server handles robot observations and actions, sending interaction data to the learner server. The learner server performs gradient descent and periodically updates the actor's policy weights. You will need to start two processes: a learner and an actor. + +**Configuration Setup** + +Create a training configuration file (example available [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/train_config_hilserl_so100.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/configs/train.py`. + +1. Configure the policy settings (`type="sac"`, `device`, etc.) +2. Set `dataset` to your cropped dataset +3. Configure environment settings with crop parameters +4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/19bb621a7d0a31c20cd3cc08b1dbab68d3031454/lerobot/common/policies/sac/configuration_sac.py#L79). +5. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task. + +**Starting the Learner** + +First, start the learner server process: + +```bash +python lerobot/scripts/rl/learner.py --config_path lerobot/configs/train_config_hilserl_so100.json +``` + +The learner: +- Initializes the policy network +- Prepares replay buffers +- Opens a `gRPC` server to communicate with actors +- Processes transitions and updates the policy + +**Starting the Actor** + +In a separate terminal, start the actor process with the same configuration: + +```bash +python lerobot/scripts/rl/actor.py --config_path lerobot/configs/train_config_hilserl_so100.json +``` + +The actor: +- Connects to the learner via `gRPC` +- Initializes the environment +- Execute rollouts of the policy to collect experience +- Sends transitions to the learner +- Receives updated policy parameters + +**Training Flow** + +The training proceeds automatically: + +1. The actor executes the policy in the environment +2. Transitions are collected and sent to the learner +3. The learner updates the policy based on these transitions +4. Updated policy parameters are sent back to the actor +5. The process continues until the specified step limit is reached + +**Human in the Loop** + +- The key to learning efficiently is to have human interventions to provide corrective feedback and completing the task to aide the policy learning and exploration. +- To perform human interventions, you can press the upper right trigger button on the gamepad (or the `space` key on the keyboard). This will pause the policy actions and allow you to take over. +- A successful experiment is one where the human has to intervene at the start but then reduces the amount of interventions as the policy improves. You can monitor the intervention rate in the `wandb` dashboard. + +

+ Figure shows the control mappings on a Logitech gamepad. +

+ +

Example showing how human interventions help guide policy learning over time

+ +- The figure shows the plot of the episodic reward over interaction step. The figure shows the effect of human interventions on the policy learning. +- The orange curve is an experiment without any human interventions. While the pink and blue curves are experiments with human interventions. +- We can observe that the number of steps where the policy starts achieving the maximum reward is cut by a quarter when human interventions are present. + +**Monitoring and Debugging** + +If you have `wandb.enable` set to `true` in your configuration, you can monitor training progress in real-time through the [Weights & Biases](https://wandb.ai/site/) dashboard. + +### Guide to Human Interventions +The learning process is very sensitive to the intervention strategy. It will takes a few runs to understand how to intervene effectively. Some tips and hints: +- Allow the policy to explore for a few episodes at the start of training. +- Avoid intervening for long periods of time. Try to intervene in situation to correct the robot's behaviour when it goes off track. +- Once the policy starts achieving the task, even if its not perfect, you can limit your interventions to simple quick actions like a simple grasping commands. + +The ideal behaviour is that your intervention rate should drop gradually during training as shown in the figure below. + +

+ Intervention rate +

+ +

Plot of the intervention rate during a training run on a pick and lift cube task

+ +### Key hyperparameters to tune + +Some configuration values have a disproportionate impact on training stability and speed: + +- **`temperature_init`** (`policy.temperature_init`) – initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning. +- **`policy_parameters_push_frequency`** (`policy.actor_learner_config.policy_parameters_push_frequency`) – interval in *seconds* between two weight pushes from the learner to the actor. The default is `4 s`. Decrease to **1-2 s** to provide fresher weights (at the cost of more network traffic); increase only if your connection is slow, as this will reduce sample efficiency. +- **`storage_device`** (`policy.storage_device`) – device on which the learner keeps the policy parameters. If you have spare GPU memory, set this to `"cuda"` (instead of the default `"cpu"`). Keeping the weights on-GPU removes CPU→GPU transfer overhead and can significantly increase the number of learner updates per second. + + +Congrats 🎉, you have finished this tutorial! + +> [!TIP] +> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). + +Paper citation: +``` +@article{luo2024precise, + title={Precise and Dexterous Robotic Manipulation via Human-in-the-Loop Reinforcement Learning}, + author={Luo, Jianlan and Xu, Charles and Wu, Jeffrey and Levine, Sergey}, + journal={arXiv preprint arXiv:2410.21845}, + year={2024} +} +``` diff --git a/docs/source/hilserl_sim.mdx b/docs/source/hilserl_sim.mdx new file mode 100644 index 0000000000..3239ba91ac --- /dev/null +++ b/docs/source/hilserl_sim.mdx @@ -0,0 +1,120 @@ +# Train RL in Simulation + +This guide explains how to use the `gym_hil` simulation environments as an alternative to real robots when working with the LeRobot framework for Human-In-the-Loop (HIL) reinforcement learning. + +`gym_hil` is a package that provides Gymnasium-compatible simulation environments specifically designed for Human-In-the-Loop reinforcement learning. These environments allow you to: + +- Train policies in simulation to test the RL stack before training on real robots + +- Collect demonstrations in sim using external devices like gamepads or keyboards +- Perform human interventions during policy learning + +Currently, the main environment is a Franka Panda robot simulation based on MuJoCo, with tasks like picking up a cube. + + +## Installation + +First, install the `gym_hil` package within the LeRobot environment: + +```bash +pip install -e ".[hilserl]" +``` + +## What do I need? + +- A gamepad or keyboard to control the robot +- A Nvidia GPU + + + +## Configuration + +To use `gym_hil` with LeRobot, you need to create a configuration file. An example is provided [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/gym_hil_env.json). Key configuration sections include: + +### Environment Type and Task + +```json +{ + "type": "hil", + "name": "franka_sim", + "task": "PandaPickCubeGamepad-v0", + "device": "cuda" +} +``` + +Available tasks: +- `PandaPickCubeBase-v0`: Basic environment +- `PandaPickCubeGamepad-v0`: With gamepad control +- `PandaPickCubeKeyboard-v0`: With keyboard control + +### Gym Wrappers Configuration + +```json +"wrapper": { + "gripper_penalty": -0.02, + "control_time_s": 15.0, + "use_gripper": true, + "fixed_reset_joint_positions": [0.0, 0.195, 0.0, -2.43, 0.0, 2.62, 0.785], + "end_effector_step_sizes": { + "x": 0.025, + "y": 0.025, + "z": 0.025 + }, + "control_mode": "gamepad" + } +``` + +Important parameters: +- `gripper_penalty`: Penalty for excessive gripper movement +- `use_gripper`: Whether to enable gripper control +- `end_effector_step_sizes`: Size of the steps in the x,y,z axes of the end-effector +- `control_mode`: Set to `"gamepad"` to use a gamepad controller + +## Running with HIL RL of LeRobot + +### Basic Usage + +To run the environment, set mode to null: + +```python +python lerobot/scripts/rl/gym_manipulator.py --config_path path/to/gym_hil_env.json +``` + +### Recording a Dataset + +To collect a dataset, set the mode to `record` whilst defining the repo_id and number of episodes to record: + +```python +python lerobot/scripts/rl/gym_manipulator.py --config_path path/to/gym_hil_env.json +``` + +### Training a Policy + +To train a policy, checkout the configuration example available [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/train_gym_hil_env.json) and run the actor and learner servers: + +```python +python lerobot/scripts/rl/actor.py --config_path path/to/train_gym_hil_env.json +``` + +In a different terminal, run the learner server: + +```python +python lerobot/scripts/rl/learner.py --config_path path/to/train_gym_hil_env.json +``` + +The simulation environment provides a safe and repeatable way to develop and test your Human-In-the-Loop reinforcement learning components before deploying to real robots. + +Congrats 🎉, you have finished this tutorial! + +> [!TIP] +> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). + +Paper citation: +``` +@article{luo2024precise, + title={Precise and Dexterous Robotic Manipulation via Human-in-the-Loop Reinforcement Learning}, + author={Luo, Jianlan and Xu, Charles and Wu, Jeffrey and Levine, Sergey}, + journal={arXiv preprint arXiv:2410.21845}, + year={2024} +} +``` diff --git a/lerobot/common/constants.py b/lerobot/common/constants.py index e78e748baf..990f2aa1eb 100644 --- a/lerobot/common/constants.py +++ b/lerobot/common/constants.py @@ -22,6 +22,7 @@ OBS_IMAGE = "observation.image" OBS_IMAGES = "observation.images" ACTION = "action" +REWARD = "next.reward" ROBOTS = "robots" TELEOPERATORS = "teleoperators" diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py index c99fba811d..ea081e9fbf 100644 --- a/lerobot/common/envs/configs.py +++ b/lerobot/common/envs/configs.py @@ -14,10 +14,13 @@ import abc from dataclasses import dataclass, field +from typing import Any, Optional import draccus from lerobot.common.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE +from lerobot.common.robots import RobotConfig +from lerobot.common.teleoperators.config import TeleoperatorConfig from lerobot.configs.types import FeatureType, PolicyFeature @@ -155,3 +158,116 @@ def gym_kwargs(self) -> dict: "visualization_height": self.visualization_height, "max_episode_steps": self.episode_length, } + + +@dataclass +class VideoRecordConfig: + """Configuration for video recording in ManiSkill environments.""" + + enabled: bool = False + record_dir: str = "videos" + trajectory_name: str = "trajectory" + + +@dataclass +class EnvTransformConfig: + """Configuration for environment wrappers.""" + + # ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig) + control_mode: str = "gamepad" + display_cameras: bool = False + add_joint_velocity_to_observation: bool = False + add_current_to_observation: bool = False + add_ee_pose_to_observation: bool = False + crop_params_dict: Optional[dict[str, tuple[int, int, int, int]]] = None + resize_size: Optional[tuple[int, int]] = None + control_time_s: float = 20.0 + fixed_reset_joint_positions: Optional[Any] = None + reset_time_s: float = 5.0 + use_gripper: bool = True + gripper_quantization_threshold: float | None = 0.8 + gripper_penalty: float = 0.0 + gripper_penalty_in_reward: bool = False + + +@EnvConfig.register_subclass(name="gym_manipulator") +@dataclass +class HILSerlRobotEnvConfig(EnvConfig): + """Configuration for the HILSerlRobotEnv environment.""" + + robot: Optional[RobotConfig] = None + teleop: Optional[TeleoperatorConfig] = None + wrapper: Optional[EnvTransformConfig] = None + fps: int = 10 + name: str = "real_robot" + mode: str = None # Either "record", "replay", None + repo_id: Optional[str] = None + dataset_root: Optional[str] = None + task: str = "" + num_episodes: int = 10 # only for record mode + episode: int = 0 + device: str = "cuda" + push_to_hub: bool = True + pretrained_policy_name_or_path: Optional[str] = None + reward_classifier_pretrained_path: Optional[str] = None + # For the reward classifier, to record more positive examples after a success + number_of_steps_after_success: int = 0 + + def gym_kwargs(self) -> dict: + return {} + + +@EnvConfig.register_subclass("hil") +@dataclass +class HILEnvConfig(EnvConfig): + """Configuration for the HIL environment.""" + + type: str = "hil" + name: str = "PandaPickCube" + task: str = "PandaPickCubeKeyboard-v0" + use_viewer: bool = True + gripper_penalty: float = 0.0 + use_gamepad: bool = True + state_dim: int = 18 + action_dim: int = 4 + fps: int = 100 + episode_length: int = 100 + video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig) + features: dict[str, PolicyFeature] = field( + default_factory=lambda: { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)), + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(18,)), + } + ) + features_map: dict[str, str] = field( + default_factory=lambda: { + "action": ACTION, + "observation.image": OBS_IMAGE, + "observation.state": OBS_STATE, + } + ) + ################# args from hilserlrobotenv + reward_classifier_pretrained_path: Optional[str] = None + robot_config: Optional[RobotConfig] = None + teleop_config: Optional[TeleoperatorConfig] = None + wrapper: Optional[EnvTransformConfig] = None + mode: str = None # Either "record", "replay", None + repo_id: Optional[str] = None + dataset_root: Optional[str] = None + num_episodes: int = 10 # only for record mode + episode: int = 0 + device: str = "cuda" + push_to_hub: bool = True + pretrained_policy_name_or_path: Optional[str] = None + # For the reward classifier, to record more positive examples after a success + number_of_steps_after_success: int = 0 + ############################ + + @property + def gym_kwargs(self) -> dict: + return { + "use_viewer": self.use_viewer, + "use_gamepad": self.use_gamepad, + "gripper_penalty": self.gripper_penalty, + } diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 8450f84b95..4f5d59c698 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -17,7 +17,7 @@ import gymnasium as gym -from lerobot.common.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv +from lerobot.common.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv def make_env_config(env_type: str, **kwargs) -> EnvConfig: @@ -27,6 +27,8 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig: return PushtEnv(**kwargs) elif env_type == "xarm": return XarmEnv(**kwargs) + elif env_type == "hil": + return HILEnvConfig(**kwargs) else: raise ValueError(f"Policy type '{env_type}' is not available.") diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py index 83334f876d..66d6e5f93f 100644 --- a/lerobot/common/envs/utils.py +++ b/lerobot/common/envs/utils.py @@ -47,6 +47,10 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten # TODO(aliberts, rcadene): use transforms.ToTensor()? img = torch.from_numpy(img) + # When preprocessing observations in a non-vectorized environment, we need to add a batch dimension. + # This is the case for human-in-the-loop RL where there is only one environment. + if img.ndim == 3: + img = img.unsqueeze(0) # sanity check that images are channel last _, h, w, c = img.shape assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}" @@ -62,13 +66,18 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten return_observations[imgkey] = img if "environment_state" in observations: - return_observations["observation.environment_state"] = torch.from_numpy( - observations["environment_state"] - ).float() + env_state = torch.from_numpy(observations["environment_state"]).float() + if env_state.dim() == 1: + env_state = env_state.unsqueeze(0) + + return_observations["observation.environment_state"] = env_state # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing - # requirement for "agent_pos" - return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float() + agent_pos = torch.from_numpy(observations["agent_pos"]).float() + if agent_pos.dim() == 1: + agent_pos = agent_pos.unsqueeze(0) + return_observations["observation.state"] = agent_pos + return return_observations diff --git a/lerobot/common/model/kinematics.py b/lerobot/common/model/kinematics.py new file mode 100644 index 0000000000..367b609e19 --- /dev/null +++ b/lerobot/common/model/kinematics.py @@ -0,0 +1,483 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +from numpy.typing import NDArray +from scipy.spatial.transform import Rotation + + +def skew_symmetric(w: NDArray[np.float32]) -> NDArray[np.float32]: + """Creates the skew-symmetric matrix from a 3D vector.""" + return np.array([[0, -w[2], w[1]], [w[2], 0, -w[0]], [-w[1], w[0], 0]]) + + +def rodrigues_rotation(w: NDArray[np.float32], theta: float) -> NDArray[np.float32]: + """Computes the rotation matrix using Rodrigues' formula.""" + w_hat = skew_symmetric(w) + return np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat + + +def screw_axis_to_transform(s: NDArray[np.float32], theta: float) -> NDArray[np.float32]: + """Converts a screw axis to a 4x4 transformation matrix.""" + screw_axis_rot = s[:3] + screw_axis_trans = s[3:] + + # Pure translation + if np.allclose(screw_axis_rot, 0) and np.linalg.norm(screw_axis_trans) == 1: + transform = np.eye(4) + transform[:3, 3] = screw_axis_trans * theta + + # Rotation (and potentially translation) + elif np.linalg.norm(screw_axis_rot) == 1: + w_hat = skew_symmetric(screw_axis_rot) + rot_mat = np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat + t = ( + np.eye(3) * theta + (1 - np.cos(theta)) * w_hat + (theta - np.sin(theta)) * w_hat @ w_hat + ) @ screw_axis_trans + transform = np.eye(4) + transform[:3, :3] = rot_mat + transform[:3, 3] = t + else: + raise ValueError("Invalid screw axis parameters") + return transform + + +def pose_difference_se3(pose1: NDArray[np.float32], pose2: NDArray[np.float32]) -> NDArray[np.float32]: + """ + Calculates the SE(3) difference between two 4x4 homogeneous transformation matrices. + SE(3) (Special Euclidean Group) represents rigid body transformations in 3D space, + combining rotation (SO(3)) and translation. + + Each 4x4 matrix has the following structure: + [R11 R12 R13 tx] + [R21 R22 R23 ty] + [R31 R32 R33 tz] + [ 0 0 0 1] + + where R is the 3x3 rotation matrix and [tx,ty,tz] is the translation vector. + + Args: + pose1: A 4x4 numpy array representing the first pose. + pose2: A 4x4 numpy array representing the second pose. + + Returns: + A 6D numpy array concatenating translation and rotation differences. + First 3 elements are the translational difference (position). + Last 3 elements are the rotational difference in axis-angle representation. + """ + rot1 = pose1[:3, :3] + rot2 = pose2[:3, :3] + + translation_diff = pose1[:3, 3] - pose2[:3, 3] + + # Calculate rotational difference using scipy's Rotation library + rot_diff = Rotation.from_matrix(rot1 @ rot2.T) + rotation_diff = rot_diff.as_rotvec() # Axis-angle representation + + return np.concatenate([translation_diff, rotation_diff]) + + +def se3_error(target_pose: NDArray[np.float32], current_pose: NDArray[np.float32]) -> NDArray[np.float32]: + pos_error = target_pose[:3, 3] - current_pose[:3, 3] + + rot_target = target_pose[:3, :3] + rot_current = current_pose[:3, :3] + rot_error_mat = rot_target @ rot_current.T + rot_error = Rotation.from_matrix(rot_error_mat).as_rotvec() + + return np.concatenate([pos_error, rot_error]) + + +class RobotKinematics: + """Robot kinematics class supporting multiple robot models.""" + + # Robot measurements dictionary + ROBOT_MEASUREMENTS = { + "koch": { + "gripper": [0.239, -0.001, 0.024], + "wrist": [0.209, 0, 0.024], + "forearm": [0.108, 0, 0.02], + "humerus": [0, 0, 0.036], + "shoulder": [0, 0, 0], + "base": [0, 0, 0.02], + }, + "moss": { + "gripper": [0.246, 0.013, 0.111], + "wrist": [0.245, 0.002, 0.064], + "forearm": [0.122, 0, 0.064], + "humerus": [0.001, 0.001, 0.063], + "shoulder": [0, 0, 0], + "base": [0, 0, 0.02], + }, + "so_old_calibration": { + "gripper": [0.320, 0, 0.050], + "wrist": [0.278, 0, 0.050], + "forearm": [0.143, 0, 0.044], + "humerus": [0.031, 0, 0.072], + "shoulder": [0, 0, 0], + "base": [0, 0, 0.02], + }, + "so_new_calibration": { + "gripper": [0.33, 0.0, 0.285], + "wrist": [0.30, 0.0, 0.267], + "forearm": [0.25, 0.0, 0.266], + "humerus": [0.06, 0.0, 0.264], + "shoulder": [0.0, 0.0, 0.238], + "base": [0.0, 0.0, 0.12], + }, + } + + def __init__(self, robot_type: str = "so100"): + """Initialize kinematics for the specified robot type. + + Args: + robot_type: String specifying the robot model ("koch", "so100", or "moss") + """ + if robot_type not in self.ROBOT_MEASUREMENTS: + raise ValueError( + f"Unknown robot type: {robot_type}. Available types: {list(self.ROBOT_MEASUREMENTS.keys())}" + ) + + self.robot_type = robot_type + self.measurements = self.ROBOT_MEASUREMENTS[robot_type] + + # Initialize all transformation matrices and screw axes + self._setup_transforms() + + def _create_translation_matrix( + self, x: float = 0.0, y: float = 0.0, z: float = 0.0 + ) -> NDArray[np.float32]: + """Create a 4x4 translation matrix.""" + return np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]]) + + def _setup_transforms(self): + """Setup all transformation matrices and screw axes for the robot.""" + # Set up rotation matrices (constant across robot types) + + # Gripper orientation + self.gripper_X0 = np.array( + [ + [1, 0, 0, 0], + [0, 0, 1, 0], + [0, -1, 0, 0], + [0, 0, 0, 1], + ], + dtype=np.float32, + ) + + # Wrist orientation + self.wrist_X0 = np.array( + [ + [0, -1, 0, 0], + [1, 0, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1], + ], + dtype=np.float32, + ) + + # Base orientation + self.base_X0 = np.array( + [ + [0, 0, 1, 0], + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 0, 1], + ], + dtype=np.float32, + ) + + # Gripper + # Screw axis of gripper frame wrt base frame + self.S_BG = np.array( + [ + 1, + 0, + 0, + 0, + self.measurements["gripper"][2], + -self.measurements["gripper"][1], + ], + dtype=np.float32, + ) + + # Gripper origin to centroid transform + self.X_GoGc = self._create_translation_matrix(x=0.07) + + # Gripper origin to tip transform + self.X_GoGt = self._create_translation_matrix(x=0.12) + + # 0-position gripper frame pose wrt base + self.X_BoGo = self._create_translation_matrix( + x=self.measurements["gripper"][0], + y=self.measurements["gripper"][1], + z=self.measurements["gripper"][2], + ) + + # Wrist + # Screw axis of wrist frame wrt base frame + self.S_BR = np.array( + [0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]], dtype=np.float32 + ) + + # 0-position origin to centroid transform + self.X_RoRc = self._create_translation_matrix(x=0.0035, y=-0.002) + + # 0-position wrist frame pose wrt base + self.X_BR = self._create_translation_matrix( + x=self.measurements["wrist"][0], + y=self.measurements["wrist"][1], + z=self.measurements["wrist"][2], + ) + + # Forearm + # Screw axis of forearm frame wrt base frame + self.S_BF = np.array( + [ + 0, + 1, + 0, + -self.measurements["forearm"][2], + 0, + self.measurements["forearm"][0], + ], + dtype=np.float32, + ) + + # Forearm origin + centroid transform + self.X_ForearmFc = self._create_translation_matrix(x=0.036) + + # 0-position forearm frame pose wrt base + self.X_BF = self._create_translation_matrix( + x=self.measurements["forearm"][0], + y=self.measurements["forearm"][1], + z=self.measurements["forearm"][2], + ) + + # Humerus + # Screw axis of humerus frame wrt base frame + self.S_BH = np.array( + [ + 0, + -1, + 0, + self.measurements["humerus"][2], + 0, + -self.measurements["humerus"][0], + ], + dtype=np.float32, + ) + + # Humerus origin to centroid transform + self.X_HoHc = self._create_translation_matrix(x=0.0475) + + # 0-position humerus frame pose wrt base + self.X_BH = self._create_translation_matrix( + x=self.measurements["humerus"][0], + y=self.measurements["humerus"][1], + z=self.measurements["humerus"][2], + ) + + # Shoulder + # Screw axis of shoulder frame wrt Base frame + self.S_BS = np.array([0, 0, -1, 0, 0, 0], dtype=np.float32) + + # Shoulder origin to centroid transform + self.X_SoSc = self._create_translation_matrix(x=-0.017, z=0.0235) + + # 0-position shoulder frame pose wrt base + self.X_BS = self._create_translation_matrix( + x=self.measurements["shoulder"][0], + y=self.measurements["shoulder"][1], + z=self.measurements["shoulder"][2], + ) + + # Base + # Base origin to centroid transform + self.X_BoBc = self._create_translation_matrix(y=0.015) + + # World to base transform + self.X_WoBo = self._create_translation_matrix( + x=self.measurements["base"][0], + y=self.measurements["base"][1], + z=self.measurements["base"][2], + ) + + # Pre-compute gripper post-multiplication matrix + self._fk_gripper_post = self.X_GoGc @ self.X_BoGo @ self.gripper_X0 + + def forward_kinematics( + self, + robot_pos_deg: NDArray[np.float32], + frame: str = "gripper_tip", + ) -> NDArray[np.float32]: + """Generic forward kinematics. + + Args: + robot_pos_deg: Joint positions in degrees. Can be ``None`` when + computing the *base* frame as it does not depend on joint + angles. + frame: Target frame. One of + ``{"base", "shoulder", "humerus", "forearm", "wrist", "gripper", "gripper_tip"}``. + + Returns + ------- + NDArray[np.float32] + 4×4 homogeneous transformation matrix of the requested frame + expressed in the world coordinate system. + """ + frame = frame.lower() + if frame not in { + "base", + "shoulder", + "humerus", + "forearm", + "wrist", + "gripper", + "gripper_tip", + }: + raise ValueError( + f"Unknown frame '{frame}'. Valid options are base, shoulder, humerus, forearm, wrist, gripper, gripper_tip." + ) + + # Base frame does not rely on joint angles. + if frame == "base": + return self.X_WoBo @ self.X_BoBc @ self.base_X0 + + robot_pos_rad = robot_pos_deg / 180 * np.pi + + # Extract joint angles (note the sign convention for shoulder lift). + theta_shoulder_pan = robot_pos_rad[0] + theta_shoulder_lift = -robot_pos_rad[1] + theta_elbow_flex = robot_pos_rad[2] + theta_wrist_flex = robot_pos_rad[3] + theta_wrist_roll = robot_pos_rad[4] + + # Start with the world-to-base transform; incrementally add successive links. + transformation_matrix = self.X_WoBo @ screw_axis_to_transform(self.S_BS, theta_shoulder_pan) + if frame == "shoulder": + return transformation_matrix @ self.X_SoSc @ self.X_BS + + transformation_matrix = transformation_matrix @ screw_axis_to_transform( + self.S_BH, theta_shoulder_lift + ) + if frame == "humerus": + return transformation_matrix @ self.X_HoHc @ self.X_BH + + transformation_matrix = transformation_matrix @ screw_axis_to_transform(self.S_BF, theta_elbow_flex) + if frame == "forearm": + return transformation_matrix @ self.X_ForearmFc @ self.X_BF + + transformation_matrix = transformation_matrix @ screw_axis_to_transform(self.S_BR, theta_wrist_flex) + if frame == "wrist": + return transformation_matrix @ self.X_RoRc @ self.X_BR @ self.wrist_X0 + + transformation_matrix = transformation_matrix @ screw_axis_to_transform(self.S_BG, theta_wrist_roll) + if frame == "gripper": + return transformation_matrix @ self._fk_gripper_post + else: # frame == "gripper_tip" + return transformation_matrix @ self.X_GoGt @ self.X_BoGo @ self.gripper_X0 + + def compute_jacobian( + self, robot_pos_deg: NDArray[np.float32], frame: str = "gripper_tip" + ) -> NDArray[np.float32]: + """Finite differences to compute the Jacobian. + J(i, j) represents how the ith component of the end-effector's velocity changes wrt a small change + in the jth joint's velocity. + + Args: + robot_pos_deg: Current joint positions in degrees + fk_func: Forward kinematics function to use (defaults to fk_gripper) + """ + + eps = 1e-8 + jac = np.zeros(shape=(6, 5)) + delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64) + for el_ix in range(len(robot_pos_deg[:-1])): + delta *= 0 + delta[el_ix] = eps / 2 + sdot = ( + pose_difference_se3( + self.forward_kinematics(robot_pos_deg[:-1] + delta, frame), + self.forward_kinematics(robot_pos_deg[:-1] - delta, frame), + ) + / eps + ) + jac[:, el_ix] = sdot + return jac + + def compute_positional_jacobian( + self, robot_pos_deg: NDArray[np.float32], frame: str = "gripper_tip" + ) -> NDArray[np.float32]: + """Finite differences to compute the positional Jacobian. + J(i, j) represents how the ith component of the end-effector's position changes wrt a small change + in the jth joint's velocity. + + Args: + robot_pos_deg: Current joint positions in degrees + fk_func: Forward kinematics function to use (defaults to fk_gripper) + """ + eps = 1e-8 + jac = np.zeros(shape=(3, 5)) + delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64) + for el_ix in range(len(robot_pos_deg[:-1])): + delta *= 0 + delta[el_ix] = eps / 2 + sdot = ( + self.forward_kinematics(robot_pos_deg[:-1] + delta, frame)[:3, 3] + - self.forward_kinematics(robot_pos_deg[:-1] - delta, frame)[:3, 3] + ) / eps + jac[:, el_ix] = sdot + return jac + + def ik( + self, + current_joint_pos: NDArray[np.float32], + desired_ee_pose: NDArray[np.float32], + position_only: bool = True, + frame: str = "gripper_tip", + max_iterations: int = 5, + learning_rate: float = 1, + ) -> NDArray[np.float32]: + """Inverse kinematics using gradient descent. + + Args: + current_joint_state: Initial joint positions in degrees + desired_ee_pose: Target end-effector pose as a 4x4 transformation matrix + position_only: If True, only match end-effector position, not orientation + frame: Target frame. One of + ``{"base", "shoulder", "humerus", "forearm", "wrist", "gripper", "gripper_tip"}``. + max_iterations: Maximum number of iterations to run + learning_rate: Learning rate for gradient descent + + Returns: + Joint positions in degrees that achieve the desired end-effector pose + """ + # Do gradient descent. + current_joint_state = current_joint_pos.copy() + for _ in range(max_iterations): + current_ee_pose = self.forward_kinematics(current_joint_state, frame) + if not position_only: + error = se3_error(desired_ee_pose, current_ee_pose) + jac = self.compute_jacobian(current_joint_state, frame) + else: + error = desired_ee_pose[:3, 3] - current_ee_pose[:3, 3] + jac = self.compute_positional_jacobian(current_joint_state, frame) + delta_angles = np.linalg.pinv(jac) @ error + current_joint_state[:-1] += learning_rate * delta_angles + + if np.linalg.norm(error) < 5e-3: + return current_joint_state + return current_joint_state diff --git a/lerobot/common/optim/optimizers.py b/lerobot/common/optim/optimizers.py index 0cf4124ce6..903434f593 100644 --- a/lerobot/common/optim/optimizers.py +++ b/lerobot/common/optim/optimizers.py @@ -14,8 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc -from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, field from pathlib import Path +from typing import Any import draccus import torch @@ -44,7 +45,16 @@ def default_choice_name(cls) -> str | None: return "adam" @abc.abstractmethod - def build(self) -> torch.optim.Optimizer: + def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]: + """ + Build the optimizer. It can be a single optimizer or a dictionary of optimizers. + NOTE: Multiple optimizers are useful when you have different models to optimize. + For example, you can have one optimizer for the policy and another one for the value function + in reinforcement learning settings. + + Returns: + The optimizer or a dictionary of optimizers. + """ raise NotImplementedError @@ -94,7 +104,76 @@ def build(self, params: dict) -> torch.optim.Optimizer: return torch.optim.SGD(params, **kwargs) -def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None: +@OptimizerConfig.register_subclass("multi_adam") +@dataclass +class MultiAdamConfig(OptimizerConfig): + """Configuration for multiple Adam optimizers with different parameter groups. + + This creates a dictionary of Adam optimizers, each with its own hyperparameters. + + Args: + lr: Default learning rate (used if not specified for a group) + weight_decay: Default weight decay (used if not specified for a group) + optimizer_groups: Dictionary mapping parameter group names to their hyperparameters + grad_clip_norm: Gradient clipping norm + """ + + lr: float = 1e-3 + weight_decay: float = 0.0 + grad_clip_norm: float = 10.0 + optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict) + + def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]: + """Build multiple Adam optimizers. + + Args: + params_dict: Dictionary mapping parameter group names to lists of parameters + The keys should match the keys in optimizer_groups + + Returns: + Dictionary mapping parameter group names to their optimizers + """ + optimizers = {} + + for name, params in params_dict.items(): + # Get group-specific hyperparameters or use defaults + group_config = self.optimizer_groups.get(name, {}) + + # Create optimizer with merged parameters (defaults + group-specific) + optimizer_kwargs = { + "lr": group_config.get("lr", self.lr), + "betas": group_config.get("betas", (0.9, 0.999)), + "eps": group_config.get("eps", 1e-5), + "weight_decay": group_config.get("weight_decay", self.weight_decay), + } + + optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs) + + return optimizers + + +def save_optimizer_state( + optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path +) -> None: + """Save optimizer state to disk. + + Args: + optimizer: Either a single optimizer or a dictionary of optimizers. + save_dir: Directory to save the optimizer state. + """ + if isinstance(optimizer, dict): + # Handle dictionary of optimizers + for name, opt in optimizer.items(): + optimizer_dir = save_dir / name + optimizer_dir.mkdir(exist_ok=True, parents=True) + _save_single_optimizer_state(opt, optimizer_dir) + else: + # Handle single optimizer + _save_single_optimizer_state(optimizer, save_dir) + + +def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None: + """Save a single optimizer's state to disk.""" state = optimizer.state_dict() param_groups = state.pop("param_groups") flat_state = flatten_dict(state) @@ -102,11 +181,44 @@ def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> No write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS) -def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer: +def load_optimizer_state( + optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path +) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]: + """Load optimizer state from disk. + + Args: + optimizer: Either a single optimizer or a dictionary of optimizers. + save_dir: Directory to load the optimizer state from. + + Returns: + The updated optimizer(s) with loaded state. + """ + if isinstance(optimizer, dict): + # Handle dictionary of optimizers + loaded_optimizers = {} + for name, opt in optimizer.items(): + optimizer_dir = save_dir / name + if optimizer_dir.exists(): + loaded_optimizers[name] = _load_single_optimizer_state(opt, optimizer_dir) + else: + loaded_optimizers[name] = opt + return loaded_optimizers + else: + # Handle single optimizer + return _load_single_optimizer_state(optimizer, save_dir) + + +def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer: + """Load a single optimizer's state from disk.""" current_state_dict = optimizer.state_dict() flat_state = load_file(save_dir / OPTIMIZER_STATE) state = unflatten_dict(flat_state) - loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}} + + # Handle case where 'state' key might not exist (for newly created optimizers) + if "state" in state: + loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}} + else: + loaded_state_dict = {"state": {}} if "param_groups" in current_state_dict: param_groups = deserialize_json_into_object( diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 3aade06656..682bb8cee9 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -27,6 +27,8 @@ from lerobot.common.policies.pi0.configuration_pi0 import PI0Config from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.policies.sac.configuration_sac import SACConfig +from lerobot.common.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig @@ -60,6 +62,14 @@ def get_policy_class(name: str) -> PreTrainedPolicy: from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy return PI0FASTPolicy + elif name == "sac": + from lerobot.common.policies.sac.modeling_sac import SACPolicy + + return SACPolicy + elif name == "reward_classifier": + from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier + + return Classifier elif name == "smolvla": from lerobot.common.policies.smolvla.modeling_smolvla import SmolVLAPolicy @@ -81,8 +91,12 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return PI0Config(**kwargs) elif policy_type == "pi0fast": return PI0FASTConfig(**kwargs) + elif policy_type == "sac": + return SACConfig(**kwargs) elif policy_type == "smolvla": return SmolVLAConfig(**kwargs) + elif policy_type == "reward_classifier": + return RewardClassifierConfig(**kwargs) else: raise ValueError(f"Policy type '{policy_type}' is not available.") diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index b3255ec106..9cc94b9298 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -151,6 +151,7 @@ def __init__( # TODO(rcadene): should we remove torch.no_grad? @torch.no_grad def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + # TODO: Remove this shallow copy batch = dict(batch) # shallow copy avoids mutating the input batch for key, ft in self.features.items(): if key not in batch: @@ -252,3 +253,168 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: else: raise ValueError(norm_mode) return batch + + +# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization +# and remove the `Normalize` and `Unnormalize` classes. +def _initialize_stats_buffers( + module: nn.Module, + features: dict[str, PolicyFeature], + norm_map: dict[str, NormalizationMode], + stats: dict[str, dict[str, Tensor]] | None = None, +) -> None: + """Register statistics buffers (mean/std or min/max) on the given *module*. + + The logic matches the previous constructors of `NormalizeBuffer` and `UnnormalizeBuffer`, + but is factored out so it can be reused by both classes and stay in sync. + """ + for key, ft in features.items(): + norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY) + if norm_mode is NormalizationMode.IDENTITY: + continue + + shape: tuple[int, ...] = tuple(ft.shape) + if ft.type is FeatureType.VISUAL: + # reduce spatial dimensions, keep channel dimension only + c, *_ = shape + shape = (c, 1, 1) + + prefix = key.replace(".", "_") + + if norm_mode is NormalizationMode.MEAN_STD: + mean = torch.full(shape, torch.inf, dtype=torch.float32) + std = torch.full(shape, torch.inf, dtype=torch.float32) + + if stats and key in stats and "mean" in stats[key] and "std" in stats[key]: + mean_data = stats[key]["mean"] + std_data = stats[key]["std"] + if isinstance(mean_data, torch.Tensor): + # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated + # tensors anywhere (for example, when we use the same stats for normalization and + # unnormalization). See the logic here + # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. + mean = mean_data.clone().to(dtype=torch.float32) + std = std_data.clone().to(dtype=torch.float32) + else: + raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).") + + module.register_buffer(f"{prefix}_mean", mean) + module.register_buffer(f"{prefix}_std", std) + continue + + if norm_mode is NormalizationMode.MIN_MAX: + min_val = torch.full(shape, torch.inf, dtype=torch.float32) + max_val = torch.full(shape, torch.inf, dtype=torch.float32) + + if stats and key in stats and "min" in stats[key] and "max" in stats[key]: + min_data = stats[key]["min"] + max_data = stats[key]["max"] + if isinstance(min_data, torch.Tensor): + min_val = min_data.clone().to(dtype=torch.float32) + max_val = max_data.clone().to(dtype=torch.float32) + else: + raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).") + + module.register_buffer(f"{prefix}_min", min_val) + module.register_buffer(f"{prefix}_max", max_val) + continue + + raise ValueError(norm_mode) + + +class NormalizeBuffer(nn.Module): + """Same as `Normalize` but statistics are stored as registered buffers rather than parameters.""" + + def __init__( + self, + features: dict[str, PolicyFeature], + norm_map: dict[str, NormalizationMode], + stats: dict[str, dict[str, Tensor]] | None = None, + ): + super().__init__() + self.features = features + self.norm_map = norm_map + + _initialize_stats_buffers(self, features, norm_map, stats) + + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + batch = dict(batch) + for key, ft in self.features.items(): + if key not in batch: + continue + + norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) + if norm_mode is NormalizationMode.IDENTITY: + continue + + prefix = key.replace(".", "_") + + if norm_mode is NormalizationMode.MEAN_STD: + mean = getattr(self, f"{prefix}_mean") + std = getattr(self, f"{prefix}_std") + assert not torch.isinf(mean).any(), _no_stats_error_str("mean") + assert not torch.isinf(std).any(), _no_stats_error_str("std") + batch[key] = (batch[key] - mean) / (std + 1e-8) + continue + + if norm_mode is NormalizationMode.MIN_MAX: + min_val = getattr(self, f"{prefix}_min") + max_val = getattr(self, f"{prefix}_max") + assert not torch.isinf(min_val).any(), _no_stats_error_str("min") + assert not torch.isinf(max_val).any(), _no_stats_error_str("max") + batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8) + batch[key] = batch[key] * 2 - 1 + continue + + raise ValueError(norm_mode) + + return batch + + +class UnnormalizeBuffer(nn.Module): + """Inverse operation of `NormalizeBuffer`. Uses registered buffers for statistics.""" + + def __init__( + self, + features: dict[str, PolicyFeature], + norm_map: dict[str, NormalizationMode], + stats: dict[str, dict[str, Tensor]] | None = None, + ): + super().__init__() + self.features = features + self.norm_map = norm_map + + _initialize_stats_buffers(self, features, norm_map, stats) + + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + # batch = dict(batch) + for key, ft in self.features.items(): + if key not in batch: + continue + + norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) + if norm_mode is NormalizationMode.IDENTITY: + continue + + prefix = key.replace(".", "_") + + if norm_mode is NormalizationMode.MEAN_STD: + mean = getattr(self, f"{prefix}_mean") + std = getattr(self, f"{prefix}_std") + assert not torch.isinf(mean).any(), _no_stats_error_str("mean") + assert not torch.isinf(std).any(), _no_stats_error_str("std") + batch[key] = batch[key] * std + mean + continue + + if norm_mode is NormalizationMode.MIN_MAX: + min_val = getattr(self, f"{prefix}_min") + max_val = getattr(self, f"{prefix}_max") + assert not torch.isinf(min_val).any(), _no_stats_error_str("min") + assert not torch.isinf(max_val).any(), _no_stats_error_str("max") + batch[key] = (batch[key] + 1) / 2 + batch[key] = batch[key] * (max_val - min_val) + min_val + continue + + raise ValueError(norm_mode) + + return batch diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py new file mode 100644 index 0000000000..db58beb2f0 --- /dev/null +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -0,0 +1,245 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_STATE +from lerobot.common.optim.optimizers import MultiAdamConfig +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import NormalizationMode + + +def is_image_feature(key: str) -> bool: + """Check if a feature key represents an image feature. + + Args: + key: The feature key to check + + Returns: + True if the key represents an image feature, False otherwise + """ + return key.startswith(OBS_IMAGE) + + +@dataclass +class ConcurrencyConfig: + """Configuration for the concurrency of the actor and learner. + Possible values are: + - "threads": Use threads for the actor and learner. + - "processes": Use processes for the actor and learner. + """ + + actor: str = "threads" + learner: str = "threads" + + +@dataclass +class ActorLearnerConfig: + learner_host: str = "127.0.0.1" + learner_port: int = 50051 + policy_parameters_push_frequency: int = 4 + queue_get_timeout: float = 2 + + +@dataclass +class CriticNetworkConfig: + hidden_dims: list[int] = field(default_factory=lambda: [256, 256]) + activate_final: bool = True + final_activation: str | None = None + + +@dataclass +class ActorNetworkConfig: + hidden_dims: list[int] = field(default_factory=lambda: [256, 256]) + activate_final: bool = True + + +@dataclass +class PolicyConfig: + use_tanh_squash: bool = True + std_min: float = 1e-5 + std_max: float = 10.0 + init_final: float = 0.05 + + +@PreTrainedConfig.register_subclass("sac") +@dataclass +class SACConfig(PreTrainedConfig): + """Soft Actor-Critic (SAC) configuration. + + SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy + reinforcement learning framework. It learns a policy and a Q-function simultaneously + using experience collected from the environment. + + This configuration class contains all the parameters needed to define a SAC agent, + including network architectures, optimization settings, and algorithm-specific + hyperparameters. + """ + + # Mapping of feature types to normalization modes + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.MEAN_STD, + "STATE": NormalizationMode.MIN_MAX, + "ENV": NormalizationMode.MIN_MAX, + "ACTION": NormalizationMode.MIN_MAX, + } + ) + + # Statistics for normalizing different types of inputs + dataset_stats: dict[str, dict[str, list[float]]] | None = field( + default_factory=lambda: { + OBS_IMAGE: { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + }, + OBS_STATE: { + "min": [0.0, 0.0], + "max": [1.0, 1.0], + }, + ACTION: { + "min": [0.0, 0.0, 0.0], + "max": [1.0, 1.0, 1.0], + }, + } + ) + + # Architecture specifics + # Device to run the model on (e.g., "cuda", "cpu") + device: str = "cpu" + # Device to store the model on + storage_device: str = "cpu" + # Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10) + vision_encoder_name: str | None = None + # Whether to freeze the vision encoder during training + freeze_vision_encoder: bool = True + # Hidden dimension size for the image encoder + image_encoder_hidden_dim: int = 32 + # Whether to use a shared encoder for actor and critic + shared_encoder: bool = True + # Number of discrete actions, eg for gripper actions + num_discrete_actions: int | None = None + # Dimension of the image embedding pooling + image_embedding_pooling_dim: int = 8 + + # Training parameter + # Number of steps for online training + online_steps: int = 1000000 + # Seed for the online environment + online_env_seed: int = 10000 + # Capacity of the online replay buffer + online_buffer_capacity: int = 100000 + # Capacity of the offline replay buffer + offline_buffer_capacity: int = 100000 + # Whether to use asynchronous prefetching for the buffers + async_prefetch: bool = False + # Number of steps before learning starts + online_step_before_learning: int = 100 + # Frequency of policy updates + policy_update_freq: int = 1 + + # SAC algorithm parameters + # Discount factor for the SAC algorithm + discount: float = 0.99 + # Initial temperature value + temperature_init: float = 1.0 + # Number of critics in the ensemble + num_critics: int = 2 + # Number of subsampled critics for training + num_subsample_critics: int | None = None + # Learning rate for the critic network + critic_lr: float = 3e-4 + # Learning rate for the actor network + actor_lr: float = 3e-4 + # Learning rate for the temperature parameter + temperature_lr: float = 3e-4 + # Weight for the critic target update + critic_target_update_weight: float = 0.005 + # Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1) + utd_ratio: int = 1 + # Hidden dimension size for the state encoder + state_encoder_hidden_dim: int = 256 + # Dimension of the latent space + latent_dim: int = 256 + # Target entropy for the SAC algorithm + target_entropy: float | None = None + # Whether to use backup entropy for the SAC algorithm + use_backup_entropy: bool = True + # Gradient clipping norm for the SAC algorithm + grad_clip_norm: float = 40.0 + + # Network configuration + # Configuration for the critic network architecture + critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) + # Configuration for the actor network architecture + actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig) + # Configuration for the policy parameters + policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig) + # Configuration for the discrete critic network + discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) + # Configuration for actor-learner architecture + actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig) + # Configuration for concurrency settings (you can use threads or processes for the actor and learner) + concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig) + + # Optimizations + use_torch_compile: bool = True + + def __post_init__(self): + super().__post_init__() + # Any validation specific to SAC configuration + + def get_optimizer_preset(self) -> MultiAdamConfig: + return MultiAdamConfig( + weight_decay=0.0, + optimizer_groups={ + "actor": {"lr": self.actor_lr}, + "critic": {"lr": self.critic_lr}, + "temperature": {"lr": self.temperature_lr}, + }, + ) + + def get_scheduler_preset(self) -> None: + return None + + def validate_features(self) -> None: + has_image = any(is_image_feature(key) for key in self.input_features) + has_state = OBS_STATE in self.input_features + + if not (has_state or has_image): + raise ValueError( + "You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features" + ) + + if "action" not in self.output_features: + raise ValueError("You must provide 'action' in the output features") + + @property + def image_features(self) -> list[str]: + return [key for key in self.input_features if is_image_feature(key)] + + @property + def observation_delta_indices(self) -> list: + return None + + @property + def action_delta_indices(self) -> list: + return None # SAC typically predicts one action at a time + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py new file mode 100644 index 0000000000..b588115ea0 --- /dev/null +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -0,0 +1,1111 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import asdict +from typing import Callable, Literal + +import einops +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F # noqa: N812 +from torch import Tensor +from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution + +from lerobot.common.policies.normalize import NormalizeBuffer +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.policies.sac.configuration_sac import SACConfig, is_image_feature +from lerobot.common.policies.utils import get_device_from_parameters + +DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension + + +class SACPolicy( + PreTrainedPolicy, +): + config_class = SACConfig + name = "sac" + + def __init__( + self, + config: SACConfig | None = None, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): + super().__init__(config) + config.validate_features() + self.config = config + + # Determine action dimension and initialize all components + continuous_action_dim = config.output_features["action"].shape[0] + self._init_normalization(dataset_stats) + self._init_encoders() + self._init_critics(continuous_action_dim) + self._init_actor(continuous_action_dim) + self._init_temperature() + + def get_optim_params(self) -> dict: + optim_params = { + "actor": [ + p + for n, p in self.actor.named_parameters() + if not n.startswith("encoder") or not self.shared_encoder + ], + "critic": self.critic_ensemble.parameters(), + "temperature": self.log_alpha, + } + if self.config.num_discrete_actions is not None: + optim_params["discrete_critic"] = self.discrete_critic.parameters() + return optim_params + + def reset(self): + """Reset the policy""" + pass + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select action for inference/evaluation""" + + observations_features = None + if self.shared_encoder and self.actor.encoder.has_images: + # Cache and normalize image features + observations_features = self.actor.encoder.get_cached_image_features(batch, normalize=True) + + actions, _, _ = self.actor(batch, observations_features) + + if self.config.num_discrete_actions is not None: + discrete_action_value = self.discrete_critic(batch, observations_features) + discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True) + actions = torch.cat([actions, discrete_action], dim=-1) + + return actions + + def critic_forward( + self, + observations: dict[str, Tensor], + actions: Tensor, + use_target: bool = False, + observation_features: Tensor | None = None, + ) -> Tensor: + """Forward pass through a critic network ensemble + + Args: + observations: Dictionary of observations + actions: Action tensor + use_target: If True, use target critics, otherwise use ensemble critics + + Returns: + Tensor of Q-values from all critics + """ + + critics = self.critic_target if use_target else self.critic_ensemble + q_values = critics(observations, actions, observation_features) + return q_values + + def discrete_critic_forward( + self, observations, use_target=False, observation_features=None + ) -> torch.Tensor: + """Forward pass through a discrete critic network + + Args: + observations: Dictionary of observations + use_target: If True, use target critics, otherwise use ensemble critics + observation_features: Optional pre-computed observation features to avoid recomputing encoder output + + Returns: + Tensor of Q-values from the discrete critic network + """ + discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic + q_values = discrete_critic(observations, observation_features) + return q_values + + def forward( + self, + batch: dict[str, Tensor | dict[str, Tensor]], + model: Literal["actor", "critic", "temperature", "discrete_critic"] = "critic", + ) -> dict[str, Tensor]: + """Compute the loss for the given model + + Args: + batch: Dictionary containing: + - action: Action tensor + - reward: Reward tensor + - state: Observations tensor dict + - next_state: Next observations tensor dict + - done: Done mask tensor + - observation_feature: Optional pre-computed observation features + - next_observation_feature: Optional pre-computed next observation features + model: Which model to compute the loss for ("actor", "critic", "discrete_critic", or "temperature") + + Returns: + The computed loss tensor + """ + # Extract common components from batch + actions: Tensor = batch["action"] + observations: dict[str, Tensor] = batch["state"] + observation_features: Tensor = batch.get("observation_feature") + + if model == "critic": + # Extract critic-specific components + rewards: Tensor = batch["reward"] + next_observations: dict[str, Tensor] = batch["next_state"] + done: Tensor = batch["done"] + next_observation_features: Tensor = batch.get("next_observation_feature") + + loss_critic = self.compute_loss_critic( + observations=observations, + actions=actions, + rewards=rewards, + next_observations=next_observations, + done=done, + observation_features=observation_features, + next_observation_features=next_observation_features, + ) + + return {"loss_critic": loss_critic} + + if model == "discrete_critic" and self.config.num_discrete_actions is not None: + # Extract critic-specific components + rewards: Tensor = batch["reward"] + next_observations: dict[str, Tensor] = batch["next_state"] + done: Tensor = batch["done"] + next_observation_features: Tensor = batch.get("next_observation_feature") + complementary_info = batch.get("complementary_info") + loss_discrete_critic = self.compute_loss_discrete_critic( + observations=observations, + actions=actions, + rewards=rewards, + next_observations=next_observations, + done=done, + observation_features=observation_features, + next_observation_features=next_observation_features, + complementary_info=complementary_info, + ) + return {"loss_discrete_critic": loss_discrete_critic} + if model == "actor": + return { + "loss_actor": self.compute_loss_actor( + observations=observations, + observation_features=observation_features, + ) + } + + if model == "temperature": + return { + "loss_temperature": self.compute_loss_temperature( + observations=observations, + observation_features=observation_features, + ) + } + + raise ValueError(f"Unknown model type: {model}") + + def update_target_networks(self): + """Update target networks with exponential moving average""" + for target_param, param in zip( + self.critic_target.parameters(), + self.critic_ensemble.parameters(), + strict=True, + ): + target_param.data.copy_( + param.data * self.config.critic_target_update_weight + + target_param.data * (1.0 - self.config.critic_target_update_weight) + ) + if self.config.num_discrete_actions is not None: + for target_param, param in zip( + self.discrete_critic_target.parameters(), + self.discrete_critic.parameters(), + strict=True, + ): + target_param.data.copy_( + param.data * self.config.critic_target_update_weight + + target_param.data * (1.0 - self.config.critic_target_update_weight) + ) + + def update_temperature(self): + self.temperature = self.log_alpha.exp().item() + + def compute_loss_critic( + self, + observations, + actions, + rewards, + next_observations, + done, + observation_features: Tensor | None = None, + next_observation_features: Tensor | None = None, + ) -> Tensor: + with torch.no_grad(): + next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features) + + # 2- compute q targets + q_targets = self.critic_forward( + observations=next_observations, + actions=next_action_preds, + use_target=True, + observation_features=next_observation_features, + ) + + # subsample critics to prevent overfitting if use high UTD (update to date) + # TODO: Get indices before forward pass to avoid unnecessary computation + if self.config.num_subsample_critics is not None: + indices = torch.randperm(self.config.num_critics) + indices = indices[: self.config.num_subsample_critics] + q_targets = q_targets[indices] + + # critics subsample size + min_q, _ = q_targets.min(dim=0) # Get values from min operation + if self.config.use_backup_entropy: + min_q = min_q - (self.temperature * next_log_probs) + + td_target = rewards + (1 - done) * self.config.discount * min_q + + # 3- compute predicted qs + if self.config.num_discrete_actions is not None: + # NOTE: We only want to keep the continuous action part + # In the buffer we have the full action space (continuous + discrete) + # We need to split them before concatenating them in the critic forward + actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX] + q_preds = self.critic_forward( + observations=observations, + actions=actions, + use_target=False, + observation_features=observation_features, + ) + + # 4- Calculate loss + # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. + td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]) + # You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up + critics_loss = ( + F.mse_loss( + input=q_preds, + target=td_target_duplicate, + reduction="none", + ).mean(dim=1) + ).sum() + return critics_loss + + def compute_loss_discrete_critic( + self, + observations, + actions, + rewards, + next_observations, + done, + observation_features=None, + next_observation_features=None, + complementary_info=None, + ): + # NOTE: We only want to keep the discrete action part + # In the buffer we have the full action space (continuous + discrete) + # We need to split them before concatenating them in the critic forward + actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone() + actions_discrete = torch.round(actions_discrete) + actions_discrete = actions_discrete.long() + + discrete_penalties: Tensor | None = None + if complementary_info is not None: + discrete_penalties: Tensor | None = complementary_info.get("discrete_penalty") + + with torch.no_grad(): + # For DQN, select actions using online network, evaluate with target network + next_discrete_qs = self.discrete_critic_forward( + next_observations, use_target=False, observation_features=next_observation_features + ) + best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True) + + # Get target Q-values from target network + target_next_discrete_qs = self.discrete_critic_forward( + observations=next_observations, + use_target=True, + observation_features=next_observation_features, + ) + + # Use gather to select Q-values for best actions + target_next_discrete_q = torch.gather( + target_next_discrete_qs, dim=1, index=best_next_discrete_action + ).squeeze(-1) + + # Compute target Q-value with Bellman equation + rewards_discrete = rewards + if discrete_penalties is not None: + rewards_discrete = rewards + discrete_penalties + target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q + + # Get predicted Q-values for current observations + predicted_discrete_qs = self.discrete_critic_forward( + observations=observations, use_target=False, observation_features=observation_features + ) + + # Use gather to select Q-values for taken actions + predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1) + + # Compute MSE loss between predicted and target Q-values + discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q) + return discrete_critic_loss + + def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor: + """Compute the temperature loss""" + # calculate temperature loss + with torch.no_grad(): + _, log_probs, _ = self.actor(observations, observation_features) + temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean() + return temperature_loss + + def compute_loss_actor( + self, + observations, + observation_features: Tensor | None = None, + ) -> Tensor: + actions_pi, log_probs, _ = self.actor(observations, observation_features) + + q_preds = self.critic_forward( + observations=observations, + actions=actions_pi, + use_target=False, + observation_features=observation_features, + ) + min_q_preds = q_preds.min(dim=0)[0] + + actor_loss = ((self.temperature * log_probs) - min_q_preds).mean() + return actor_loss + + def _init_normalization(self, dataset_stats): + """Initialize input/output normalization modules.""" + self.normalize_inputs = nn.Identity() + self.normalize_targets = nn.Identity() + if self.config.dataset_stats is not None: + params = _convert_normalization_params_to_tensor(self.config.dataset_stats) + self.normalize_inputs = NormalizeBuffer( + self.config.input_features, self.config.normalization_mapping, params + ) + stats = dataset_stats or params + self.normalize_targets = NormalizeBuffer( + self.config.output_features, self.config.normalization_mapping, stats + ) + + def _init_encoders(self): + """Initialize shared or separate encoders for actor and critic.""" + self.shared_encoder = self.config.shared_encoder + self.encoder_critic = SACObservationEncoder(self.config, self.normalize_inputs) + self.encoder_actor = ( + self.encoder_critic + if self.shared_encoder + else SACObservationEncoder(self.config, self.normalize_inputs) + ) + + def _init_critics(self, continuous_action_dim): + """Build critic ensemble, targets, and optional discrete critic.""" + heads = [ + CriticHead( + input_dim=self.encoder_critic.output_dim + continuous_action_dim, + **asdict(self.config.critic_network_kwargs), + ) + for _ in range(self.config.num_critics) + ] + self.critic_ensemble = CriticEnsemble( + encoder=self.encoder_critic, ensemble=heads, output_normalization=self.normalize_targets + ) + target_heads = [ + CriticHead( + input_dim=self.encoder_critic.output_dim + continuous_action_dim, + **asdict(self.config.critic_network_kwargs), + ) + for _ in range(self.config.num_critics) + ] + self.critic_target = CriticEnsemble( + encoder=self.encoder_critic, ensemble=target_heads, output_normalization=self.normalize_targets + ) + self.critic_target.load_state_dict(self.critic_ensemble.state_dict()) + + if self.config.use_torch_compile: + self.critic_ensemble = torch.compile(self.critic_ensemble) + self.critic_target = torch.compile(self.critic_target) + + if self.config.num_discrete_actions is not None: + self._init_discrete_critics() + + def _init_discrete_critics(self): + """Build discrete discrete critic ensemble and target networks.""" + self.discrete_critic = DiscreteCritic( + encoder=self.encoder_critic, + input_dim=self.encoder_critic.output_dim, + output_dim=self.config.num_discrete_actions, + **asdict(self.config.discrete_critic_network_kwargs), + ) + self.discrete_critic_target = DiscreteCritic( + encoder=self.encoder_critic, + input_dim=self.encoder_critic.output_dim, + output_dim=self.config.num_discrete_actions, + **asdict(self.config.discrete_critic_network_kwargs), + ) + + # TODO: (maractingi, azouitine) Compile the discrete critic + self.discrete_critic_target.load_state_dict(self.discrete_critic.state_dict()) + + def _init_actor(self, continuous_action_dim): + """Initialize policy actor network and default target entropy.""" + # NOTE: The actor select only the continuous action part + self.actor = Policy( + encoder=self.encoder_actor, + network=MLP(input_dim=self.encoder_actor.output_dim, **asdict(self.config.actor_network_kwargs)), + action_dim=continuous_action_dim, + encoder_is_shared=self.shared_encoder, + **asdict(self.config.policy_kwargs), + ) + + self.target_entropy = self.config.target_entropy + if self.target_entropy is None: + dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0) + self.target_entropy = -np.prod(dim) / 2 + + def _init_temperature(self): + """Set up temperature parameter and initial log_alpha.""" + temp_init = self.config.temperature_init + self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)])) + self.temperature = self.log_alpha.exp().item() + + +class SACObservationEncoder(nn.Module): + """Encode image and/or state vector observations.""" + + def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None: + super().__init__() + self.config = config + self.input_normalization = input_normalizer + self._init_image_layers() + self._init_state_layers() + self._compute_output_dim() + + def _init_image_layers(self) -> None: + self.image_keys = [k for k in self.config.input_features if is_image_feature(k)] + self.has_images = bool(self.image_keys) + if not self.has_images: + return + + if self.config.vision_encoder_name is not None: + self.image_encoder = PretrainedImageEncoder(self.config) + else: + self.image_encoder = DefaultImageEncoder(self.config) + + if self.config.freeze_vision_encoder: + freeze_image_encoder(self.image_encoder) + + dummy = torch.zeros(1, *self.config.input_features[self.image_keys[0]].shape) + with torch.no_grad(): + _, channels, height, width = self.image_encoder(dummy).shape + + self.spatial_embeddings = nn.ModuleDict() + self.post_encoders = nn.ModuleDict() + + for key in self.image_keys: + name = key.replace(".", "_") + self.spatial_embeddings[name] = SpatialLearnedEmbeddings( + height=height, + width=width, + channel=channels, + num_features=self.config.image_embedding_pooling_dim, + ) + self.post_encoders[name] = nn.Sequential( + nn.Dropout(0.1), + nn.Linear( + in_features=channels * self.config.image_embedding_pooling_dim, + out_features=self.config.latent_dim, + ), + nn.LayerNorm(normalized_shape=self.config.latent_dim), + nn.Tanh(), + ) + + def _init_state_layers(self) -> None: + self.has_env = "observation.environment_state" in self.config.input_features + self.has_state = "observation.state" in self.config.input_features + if self.has_env: + dim = self.config.input_features["observation.environment_state"].shape[0] + self.env_encoder = nn.Sequential( + nn.Linear(dim, self.config.latent_dim), + nn.LayerNorm(self.config.latent_dim), + nn.Tanh(), + ) + if self.has_state: + dim = self.config.input_features["observation.state"].shape[0] + self.state_encoder = nn.Sequential( + nn.Linear(dim, self.config.latent_dim), + nn.LayerNorm(self.config.latent_dim), + nn.Tanh(), + ) + + def _compute_output_dim(self) -> None: + out = 0 + if self.has_images: + out += len(self.image_keys) * self.config.latent_dim + if self.has_env: + out += self.config.latent_dim + if self.has_state: + out += self.config.latent_dim + self._out_dim = out + + def forward( + self, obs: dict[str, Tensor], cache: dict[str, Tensor] | None = None, detach: bool = False + ) -> Tensor: + obs = self.input_normalization(obs) + parts = [] + if self.has_images: + if cache is None: + cache = self.get_cached_image_features(obs, normalize=False) + parts.append(self._encode_images(cache, detach)) + if self.has_env: + parts.append(self.env_encoder(obs["observation.environment_state"])) + if self.has_state: + parts.append(self.state_encoder(obs["observation.state"])) + if parts: + return torch.cat(parts, dim=-1) + + raise ValueError( + "No parts to concatenate, you should have at least one image or environment state or state" + ) + + def get_cached_image_features(self, obs: dict[str, Tensor], normalize: bool = False) -> dict[str, Tensor]: + """Extract and optionally cache image features from observations. + + This function processes image observations through the vision encoder once and returns + the resulting features. + When the image encoder is shared between actor and critics AND frozen, these features can be safely cached and + reused across policy components (actor, critic, discrete_critic), avoiding redundant forward passes. + + Performance impact: + - The vision encoder forward pass is typically the main computational bottleneck during training and inference + - Caching these features can provide 2-4x speedup in training and inference + + Normalization behavior: + - When called from inside forward(): set normalize=False since inputs are already normalized + - When called from outside forward(): set normalize=True to ensure proper input normalization + + Usage patterns: + - Called in select_action() with normalize=True + - Called in learner.py's get_observation_features() to pre-compute features for all policy components + - Called internally by forward() with normalize=False + + Args: + obs: Dictionary of observation tensors containing image keys + normalize: Whether to normalize observations before encoding + Set to True when calling directly from outside the encoder's forward method + Set to False when calling from within forward() where inputs are already normalized + + Returns: + Dictionary mapping image keys to their corresponding encoded features + """ + if normalize: + obs = self.input_normalization(obs) + batched = torch.cat([obs[k] for k in self.image_keys], dim=0) + out = self.image_encoder(batched) + chunks = torch.chunk(out, len(self.image_keys), dim=0) + return dict(zip(self.image_keys, chunks, strict=False)) + + def _encode_images(self, cache: dict[str, Tensor], detach: bool) -> Tensor: + """Encode image features from cached observations. + + This function takes pre-encoded image features from the cache and applies spatial embeddings and post-encoders. + It also supports detaching the encoded features if specified. + + Args: + cache (dict[str, Tensor]): The cached image features. + detach (bool): Usually when the encoder is shared between actor and critics, + we want to detach the encoded features on the policy side to avoid backprop through the encoder. + More detail here `https://cdn.aaai.org/ojs/17276/17276-13-20770-1-2-20210518.pdf` + + Returns: + Tensor: The encoded image features. + """ + feats = [] + for k, feat in cache.items(): + safe_key = k.replace(".", "_") + x = self.spatial_embeddings[safe_key](feat) + x = self.post_encoders[safe_key](x) + if detach: + x = x.detach() + feats.append(x) + return torch.cat(feats, dim=-1) + + @property + def output_dim(self) -> int: + return self._out_dim + + +class MLP(nn.Module): + """Multi-layer perceptron builder. + + Dynamically constructs a sequence of layers based on `hidden_dims`: + 1) Linear (in_dim -> out_dim) + 2) Optional Dropout if `dropout_rate` > 0 and (not final layer or `activate_final`) + 3) LayerNorm on the output features + 4) Activation (standard for intermediate layers, `final_activation` for last layer if `activate_final`) + + Arguments: + input_dim (int): Size of input feature dimension. + hidden_dims (list[int]): Sizes for each hidden layer. + activations (Callable or str): Activation to apply between layers. + activate_final (bool): Whether to apply activation at the final layer. + dropout_rate (Optional[float]): Dropout probability applied before normalization and activation. + final_activation (Optional[Callable or str]): Activation for the final layer when `activate_final` is True. + + For each layer, `in_dim` is updated to the previous `out_dim`. All constructed modules are + stored in `self.net` as an `nn.Sequential` container. + """ + + def __init__( + self, + input_dim: int, + hidden_dims: list[int], + activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(), + activate_final: bool = False, + dropout_rate: float | None = None, + final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None, + ): + super().__init__() + layers: list[nn.Module] = [] + in_dim = input_dim + total = len(hidden_dims) + + for idx, out_dim in enumerate(hidden_dims): + # 1) linear transform + layers.append(nn.Linear(in_dim, out_dim)) + + is_last = idx == total - 1 + # 2-4) optionally add dropout, normalization, and activation + if not is_last or activate_final: + if dropout_rate and dropout_rate > 0: + layers.append(nn.Dropout(p=dropout_rate)) + layers.append(nn.LayerNorm(out_dim)) + act_cls = final_activation if is_last and final_activation else activations + act = act_cls if isinstance(act_cls, nn.Module) else getattr(nn, act_cls)() + layers.append(act) + + in_dim = out_dim + + self.net = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +class CriticHead(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dims: list[int], + activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(), + activate_final: bool = False, + dropout_rate: float | None = None, + init_final: float | None = None, + final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None, + ): + super().__init__() + self.net = MLP( + input_dim=input_dim, + hidden_dims=hidden_dims, + activations=activations, + activate_final=activate_final, + dropout_rate=dropout_rate, + final_activation=final_activation, + ) + self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1) + if init_final is not None: + nn.init.uniform_(self.output_layer.weight, -init_final, init_final) + nn.init.uniform_(self.output_layer.bias, -init_final, init_final) + else: + orthogonal_init()(self.output_layer.weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.output_layer(self.net(x)) + + +class CriticEnsemble(nn.Module): + """ + CriticEnsemble wraps multiple CriticHead modules into an ensemble. + + Args: + encoder (SACObservationEncoder): encoder for observations. + ensemble (List[CriticHead]): list of critic heads. + output_normalization (nn.Module): normalization layer for actions. + init_final (float | None): optional initializer scale for final layers. + + Forward returns a tensor of shape (num_critics, batch_size) containing Q-values. + """ + + def __init__( + self, + encoder: SACObservationEncoder, + ensemble: list[CriticHead], + output_normalization: nn.Module, + init_final: float | None = None, + ): + super().__init__() + self.encoder = encoder + self.init_final = init_final + self.output_normalization = output_normalization + self.critics = nn.ModuleList(ensemble) + + def forward( + self, + observations: dict[str, torch.Tensor], + actions: torch.Tensor, + observation_features: torch.Tensor | None = None, + ) -> torch.Tensor: + device = get_device_from_parameters(self) + # Move each tensor in observations to device + observations = {k: v.to(device) for k, v in observations.items()} + # NOTE: We normalize actions it helps for sample efficiency + actions: dict[str, torch.tensor] = {"action": actions} + # NOTE: Normalization layer took dict in input and outputs a dict that why + actions = self.output_normalization(actions)["action"] + actions = actions.to(device) + + obs_enc = self.encoder(observations, cache=observation_features) + + inputs = torch.cat([obs_enc, actions], dim=-1) + + # Loop through critics and collect outputs + q_values = [] + for critic in self.critics: + q_values.append(critic(inputs)) + + # Stack outputs to match expected shape [num_critics, batch_size] + q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0) + return q_values + + +class DiscreteCritic(nn.Module): + def __init__( + self, + encoder: nn.Module, + input_dim: int, + hidden_dims: list[int], + output_dim: int = 3, + activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(), + activate_final: bool = False, + dropout_rate: float | None = None, + init_final: float | None = None, + final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None, + ): + super().__init__() + self.encoder = encoder + self.output_dim = output_dim + + self.net = MLP( + input_dim=input_dim, + hidden_dims=hidden_dims, + activations=activations, + activate_final=activate_final, + dropout_rate=dropout_rate, + final_activation=final_activation, + ) + + self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=self.output_dim) + if init_final is not None: + nn.init.uniform_(self.output_layer.weight, -init_final, init_final) + nn.init.uniform_(self.output_layer.bias, -init_final, init_final) + else: + orthogonal_init()(self.output_layer.weight) + + def forward( + self, observations: torch.Tensor, observation_features: torch.Tensor | None = None + ) -> torch.Tensor: + device = get_device_from_parameters(self) + observations = {k: v.to(device) for k, v in observations.items()} + obs_enc = self.encoder(observations, cache=observation_features) + return self.output_layer(self.net(obs_enc)) + + +class Policy(nn.Module): + def __init__( + self, + encoder: SACObservationEncoder, + network: nn.Module, + action_dim: int, + std_min: float = -5, + std_max: float = 2, + fixed_std: torch.Tensor | None = None, + init_final: float | None = None, + use_tanh_squash: bool = False, + encoder_is_shared: bool = False, + ): + super().__init__() + self.encoder: SACObservationEncoder = encoder + self.network = network + self.action_dim = action_dim + self.std_min = std_min + self.std_max = std_max + self.fixed_std = fixed_std + self.use_tanh_squash = use_tanh_squash + self.encoder_is_shared = encoder_is_shared + + # Find the last Linear layer's output dimension + for layer in reversed(network.net): + if isinstance(layer, nn.Linear): + out_features = layer.out_features + break + # Mean layer + self.mean_layer = nn.Linear(out_features, action_dim) + if init_final is not None: + nn.init.uniform_(self.mean_layer.weight, -init_final, init_final) + nn.init.uniform_(self.mean_layer.bias, -init_final, init_final) + else: + orthogonal_init()(self.mean_layer.weight) + + # Standard deviation layer or parameter + if fixed_std is None: + self.std_layer = nn.Linear(out_features, action_dim) + if init_final is not None: + nn.init.uniform_(self.std_layer.weight, -init_final, init_final) + nn.init.uniform_(self.std_layer.bias, -init_final, init_final) + else: + orthogonal_init()(self.std_layer.weight) + + def forward( + self, + observations: torch.Tensor, + observation_features: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # We detach the encoder if it is shared to avoid backprop through it + # This is important to avoid the encoder to be updated through the policy + obs_enc = self.encoder(observations, cache=observation_features, detach=self.encoder_is_shared) + + # Get network outputs + outputs = self.network(obs_enc) + means = self.mean_layer(outputs) + + # Compute standard deviations + if self.fixed_std is None: + log_std = self.std_layer(outputs) + std = torch.exp(log_std) # Match JAX "exp" + std = torch.clamp(std, self.std_min, self.std_max) # Match JAX default clip + else: + std = self.fixed_std.expand_as(means) + + # Build transformed distribution + dist = TanhMultivariateNormalDiag(loc=means, scale_diag=std) + + # Sample actions (reparameterized) + actions = dist.rsample() + + # Compute log_probs + log_probs = dist.log_prob(actions) + + return actions, log_probs, means + + def get_features(self, observations: torch.Tensor) -> torch.Tensor: + """Get encoded features from observations""" + device = get_device_from_parameters(self) + observations = observations.to(device) + if self.encoder is not None: + with torch.inference_mode(): + return self.encoder(observations) + return observations + + +class DefaultImageEncoder(nn.Module): + def __init__(self, config: SACConfig): + super().__init__() + image_key = next(key for key in config.input_features if is_image_feature(key)) + self.image_enc_layers = nn.Sequential( + nn.Conv2d( + in_channels=config.input_features[image_key].shape[0], + out_channels=config.image_encoder_hidden_dim, + kernel_size=7, + stride=2, + ), + nn.ReLU(), + nn.Conv2d( + in_channels=config.image_encoder_hidden_dim, + out_channels=config.image_encoder_hidden_dim, + kernel_size=5, + stride=2, + ), + nn.ReLU(), + nn.Conv2d( + in_channels=config.image_encoder_hidden_dim, + out_channels=config.image_encoder_hidden_dim, + kernel_size=3, + stride=2, + ), + nn.ReLU(), + nn.Conv2d( + in_channels=config.image_encoder_hidden_dim, + out_channels=config.image_encoder_hidden_dim, + kernel_size=3, + stride=2, + ), + nn.ReLU(), + ) + + def forward(self, x): + x = self.image_enc_layers(x) + return x + + +def freeze_image_encoder(image_encoder: nn.Module): + """Freeze all parameters in the encoder""" + for param in image_encoder.parameters(): + param.requires_grad = False + + +class PretrainedImageEncoder(nn.Module): + def __init__(self, config: SACConfig): + super().__init__() + + self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config) + + def _load_pretrained_vision_encoder(self, config: SACConfig): + """Set up CNN encoder""" + from transformers import AutoModel + + self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True) + + if hasattr(self.image_enc_layers.config, "hidden_sizes"): + self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension + elif hasattr(self.image_enc_layers, "fc"): + self.image_enc_out_shape = self.image_enc_layers.fc.in_features + else: + raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN") + return self.image_enc_layers, self.image_enc_out_shape + + def forward(self, x): + enc_feat = self.image_enc_layers(x).last_hidden_state + return enc_feat + + +def orthogonal_init(): + return lambda x: torch.nn.init.orthogonal_(x, gain=1.0) + + +class SpatialLearnedEmbeddings(nn.Module): + def __init__(self, height, width, channel, num_features=8): + """ + PyTorch implementation of learned spatial embeddings + + Args: + height: Spatial height of input features + width: Spatial width of input features + channel: Number of input channels + num_features: Number of output embedding dimensions + """ + super().__init__() + self.height = height + self.width = width + self.channel = channel + self.num_features = num_features + + self.kernel = nn.Parameter(torch.empty(channel, height, width, num_features)) + + nn.init.kaiming_normal_(self.kernel, mode="fan_in", nonlinearity="linear") + + def forward(self, features): + """ + Forward pass for spatial embedding + + Args: + features: Input tensor of shape [B, C, H, W] where B is batch size, + C is number of channels, H is height, and W is width + Returns: + Output tensor of shape [B, C*F] where F is the number of features + """ + + features_expanded = features.unsqueeze(-1) # [B, C, H, W, 1] + kernel_expanded = self.kernel.unsqueeze(0) # [1, C, H, W, F] + + # Element-wise multiplication and spatial reduction + output = (features_expanded * kernel_expanded).sum(dim=(2, 3)) # Sum over H,W dimensions + + # Reshape to combine channel and feature dimensions + output = output.view(output.size(0), -1) # [B, C*F] + + return output + + +class RescaleFromTanh(Transform): + def __init__(self, low: float = -1, high: float = 1): + super().__init__() + + self.low = low + + self.high = high + + def _call(self, x): + # Rescale from (-1, 1) to (low, high) + + return 0.5 * (x + 1.0) * (self.high - self.low) + self.low + + def _inverse(self, y): + # Rescale from (low, high) back to (-1, 1) + + return 2.0 * (y - self.low) / (self.high - self.low) - 1.0 + + def log_abs_det_jacobian(self, x, y): + # log|d(rescale)/dx| = sum(log(0.5 * (high - low))) + + scale = 0.5 * (self.high - self.low) + + return torch.sum(torch.log(scale), dim=-1) + + +class TanhMultivariateNormalDiag(TransformedDistribution): + def __init__(self, loc, scale_diag, low=None, high=None): + base_dist = MultivariateNormal(loc, torch.diag_embed(scale_diag)) + + transforms = [TanhTransform(cache_size=1)] + + if low is not None and high is not None: + low = torch.as_tensor(low) + + high = torch.as_tensor(high) + + transforms.insert(0, RescaleFromTanh(low, high)) + + super().__init__(base_dist, transforms) + + def mode(self): + # Mode is mean of base distribution, passed through transforms + + x = self.base_dist.mean + + for transform in self.transforms: + x = transform(x) + + return x + + def stddev(self): + std = self.base_dist.stddev + + x = std + + for transform in self.transforms: + x = transform(x) + + return x + + +def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict: + converted_params = {} + for outer_key, inner_dict in normalization_params.items(): + converted_params[outer_key] = {} + for key, value in inner_dict.items(): + converted_params[outer_key][key] = torch.tensor(value) + if "image" in outer_key: + converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1) + + return converted_params diff --git a/lerobot/common/policies/sac/reward_model/configuration_classifier.py b/lerobot/common/policies/sac/reward_model/configuration_classifier.py new file mode 100644 index 0000000000..6e2a551d4d --- /dev/null +++ b/lerobot/common/policies/sac/reward_model/configuration_classifier.py @@ -0,0 +1,76 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field + +from lerobot.common.optim.optimizers import AdamWConfig, OptimizerConfig +from lerobot.common.optim.schedulers import LRSchedulerConfig +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import NormalizationMode + + +@PreTrainedConfig.register_subclass(name="reward_classifier") +@dataclass +class RewardClassifierConfig(PreTrainedConfig): + """Configuration for the Reward Classifier model.""" + + name: str = "reward_classifier" + num_classes: int = 2 + hidden_dim: int = 256 + latent_dim: int = 256 + image_embedding_pooling_dim: int = 8 + dropout_rate: float = 0.1 + model_name: str = "helper2424/resnet10" + device: str = "cpu" + model_type: str = "cnn" # "transformer" or "cnn" + num_cameras: int = 2 + learning_rate: float = 1e-4 + weight_decay: float = 0.01 + grad_clip_norm: float = 1.0 + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.MEAN_STD, + } + ) + + @property + def observation_delta_indices(self) -> list | None: + return None + + @property + def action_delta_indices(self) -> list | None: + return None + + @property + def reward_delta_indices(self) -> list | None: + return None + + def get_optimizer_preset(self) -> OptimizerConfig: + return AdamWConfig( + lr=self.learning_rate, + weight_decay=self.weight_decay, + grad_clip_norm=self.grad_clip_norm, + ) + + def get_scheduler_preset(self) -> LRSchedulerConfig | None: + return None + + def validate_features(self) -> None: + """Validate feature configurations.""" + has_image = any(key.startswith("observation.image") for key in self.input_features) + if not has_image: + raise ValueError( + "You must provide an image observation (key starting with 'observation.image') in the input features" + ) diff --git a/lerobot/common/policies/sac/reward_model/modeling_classifier.py b/lerobot/common/policies/sac/reward_model/modeling_classifier.py new file mode 100644 index 0000000000..f537e3aefd --- /dev/null +++ b/lerobot/common/policies/sac/reward_model/modeling_classifier.py @@ -0,0 +1,316 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import torch +from torch import Tensor, nn + +from lerobot.common.constants import OBS_IMAGE, REWARD +from lerobot.common.policies.normalize import Normalize, Unnormalize +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig + + +class ClassifierOutput: + """Wrapper for classifier outputs with additional metadata.""" + + def __init__( + self, + logits: Tensor, + probabilities: Tensor | None = None, + hidden_states: Tensor | None = None, + ): + self.logits = logits + self.probabilities = probabilities + self.hidden_states = hidden_states + + def __repr__(self): + return ( + f"ClassifierOutput(logits={self.logits}, " + f"probabilities={self.probabilities}, " + f"hidden_states={self.hidden_states})" + ) + + +class SpatialLearnedEmbeddings(nn.Module): + def __init__(self, height, width, channel, num_features=8): + """ + PyTorch implementation of learned spatial embeddings + + Args: + height: Spatial height of input features + width: Spatial width of input features + channel: Number of input channels + num_features: Number of output embedding dimensions + """ + super().__init__() + self.height = height + self.width = width + self.channel = channel + self.num_features = num_features + + self.kernel = nn.Parameter(torch.empty(channel, height, width, num_features)) + + nn.init.kaiming_normal_(self.kernel, mode="fan_in", nonlinearity="linear") + + def forward(self, features): + """ + Forward pass for spatial embedding + + Args: + features: Input tensor of shape [B, H, W, C] or [H, W, C] if no batch + Returns: + Output tensor of shape [B, C*F] or [C*F] if no batch + """ + + features = features.last_hidden_state + + original_shape = features.shape + if features.dim() == 3: + features = features.unsqueeze(0) # Add batch dim + + features_expanded = features.unsqueeze(-1) # [B, H, W, C, 1] + kernel_expanded = self.kernel.unsqueeze(0) # [1, H, W, C, F] + + # Element-wise multiplication and spatial reduction + output = (features_expanded * kernel_expanded).sum(dim=(2, 3)) # Sum H,W + + # Reshape to combine channel and feature dimensions + output = output.view(output.size(0), -1) # [B, C*F] + + # Remove batch dim + if len(original_shape) == 3: + output = output.squeeze(0) + + return output + + +class Classifier(PreTrainedPolicy): + """Image classifier built on top of a pre-trained encoder.""" + + name = "reward_classifier" + config_class = RewardClassifierConfig + + def __init__( + self, + config: RewardClassifierConfig, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): + from transformers import AutoModel + + super().__init__(config) + self.config = config + + # Initialize normalization (standardized with the policy framework) + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + + # Set up encoder + encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True) + # Extract vision model if we're given a multimodal model + if hasattr(encoder, "vision_model"): + logging.info("Multimodal model detected - using vision encoder only") + self.encoder = encoder.vision_model + self.vision_config = encoder.config.vision_config + else: + self.encoder = encoder + self.vision_config = getattr(encoder, "config", None) + + # Model type from config + self.is_cnn = self.config.model_type == "cnn" + + # For CNNs, initialize backbone + if self.is_cnn: + self._setup_cnn_backbone() + + self._freeze_encoder() + + # Extract image keys from input_features + self.image_keys = [ + key.replace(".", "_") for key in config.input_features if key.startswith(OBS_IMAGE) + ] + + if self.is_cnn: + self.encoders = nn.ModuleDict() + for image_key in self.image_keys: + encoder = self._create_single_encoder() + self.encoders[image_key] = encoder + + self._build_classifier_head() + + def _setup_cnn_backbone(self): + """Set up CNN encoder""" + if hasattr(self.encoder, "fc"): + self.feature_dim = self.encoder.fc.in_features + self.encoder = nn.Sequential(*list(self.encoder.children())[:-1]) + elif hasattr(self.encoder.config, "hidden_sizes"): + self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension + else: + raise ValueError("Unsupported CNN architecture") + + def _freeze_encoder(self) -> None: + """Freeze the encoder parameters.""" + for param in self.encoder.parameters(): + param.requires_grad = False + + def _create_single_encoder(self): + encoder = nn.Sequential( + self.encoder, + SpatialLearnedEmbeddings( + height=4, + width=4, + channel=self.feature_dim, + num_features=self.config.image_embedding_pooling_dim, + ), + nn.Dropout(self.config.dropout_rate), + nn.Linear(self.feature_dim * self.config.image_embedding_pooling_dim, self.config.latent_dim), + nn.LayerNorm(self.config.latent_dim), + nn.Tanh(), + ) + + return encoder + + def _build_classifier_head(self) -> None: + """Initialize the classifier head architecture.""" + # Get input dimension based on model type + if self.is_cnn: + input_dim = self.config.latent_dim + else: # Transformer models + if hasattr(self.encoder.config, "hidden_size"): + input_dim = self.encoder.config.hidden_size + else: + raise ValueError("Unsupported transformer architecture since hidden_size is not found") + + self.classifier_head = nn.Sequential( + nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim), + nn.Dropout(self.config.dropout_rate), + nn.LayerNorm(self.config.hidden_dim), + nn.ReLU(), + nn.Linear( + self.config.hidden_dim, + 1 if self.config.num_classes == 2 else self.config.num_classes, + ), + ) + + def _get_encoder_output(self, x: torch.Tensor, image_key: str) -> torch.Tensor: + """Extract the appropriate output from the encoder.""" + with torch.no_grad(): + if self.is_cnn: + # The HF ResNet applies pooling internally + outputs = self.encoders[image_key](x) + return outputs + else: # Transformer models + outputs = self.encoder(x) + return outputs.last_hidden_state[:, 0, :] + + def extract_images_and_labels(self, batch: dict[str, Tensor]) -> tuple[list, Tensor]: + """Extract image tensors and label tensors from batch.""" + # Check for both OBS_IMAGE and OBS_IMAGES prefixes + images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)] + labels = batch[REWARD] + + return images, labels + + def predict(self, xs: list) -> ClassifierOutput: + """Forward pass of the classifier for inference.""" + encoder_outputs = torch.hstack( + [self._get_encoder_output(x, img_key) for x, img_key in zip(xs, self.image_keys, strict=True)] + ) + logits = self.classifier_head(encoder_outputs) + + if self.config.num_classes == 2: + logits = logits.squeeze(-1) + probabilities = torch.sigmoid(logits) + else: + probabilities = torch.softmax(logits, dim=-1) + + return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs) + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]: + """Standard forward pass for training compatible with train.py.""" + # Normalize inputs if needed + batch = self.normalize_inputs(batch) + batch = self.normalize_targets(batch) + + # Extract images and labels + images, labels = self.extract_images_and_labels(batch) + + # Get predictions + outputs = self.predict(images) + + # Calculate loss + if self.config.num_classes == 2: + # Binary classification + loss = nn.functional.binary_cross_entropy_with_logits(outputs.logits, labels) + predictions = (torch.sigmoid(outputs.logits) > 0.5).float() + else: + # Multi-class classification + loss = nn.functional.cross_entropy(outputs.logits, labels.long()) + predictions = torch.argmax(outputs.logits, dim=1) + + # Calculate accuracy for logging + correct = (predictions == labels).sum().item() + total = labels.size(0) + accuracy = 100 * correct / total + + # Return loss and metrics for logging + output_dict = { + "accuracy": accuracy, + "correct": correct, + "total": total, + } + + return loss, output_dict + + def predict_reward(self, batch, threshold=0.5): + """Eval method. Returns predicted reward with the decision threshold as argument.""" + # Check for both OBS_IMAGE and OBS_IMAGES prefixes + batch = self.normalize_inputs(batch) + batch = self.normalize_targets(batch) + + # Extract images from batch dict + images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)] + + if self.config.num_classes == 2: + probs = self.predict(images).probabilities + logging.debug(f"Predicted reward images: {probs}") + return (probs > threshold).float() + else: + return torch.argmax(self.predict(images).probabilities, dim=1) + + def get_optim_params(self): + """Return optimizer parameters for the policy.""" + return self.parameters() + + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """ + This method is required by PreTrainedPolicy but not used for reward classifiers. + The reward classifier is not an actor and does not select actions. + """ + raise NotImplementedError("Reward classifiers do not select actions") + + def reset(self): + """ + This method is required by PreTrainedPolicy but not used for reward classifiers. + The reward classifier is not an actor and does not select actions. + """ + pass diff --git a/lerobot/common/robots/so100_follower/__init__.py b/lerobot/common/robots/so100_follower/__init__.py index 087fd64562..63c3e1c17a 100644 --- a/lerobot/common/robots/so100_follower/__init__.py +++ b/lerobot/common/robots/so100_follower/__init__.py @@ -1,2 +1,3 @@ -from .config_so100_follower import SO100FollowerConfig +from .config_so100_follower import SO100FollowerConfig, SO100FollowerEndEffectorConfig from .so100_follower import SO100Follower +from .so100_follower_end_effector import SO100FollowerEndEffector diff --git a/lerobot/common/robots/so100_follower/config_so100_follower.py b/lerobot/common/robots/so100_follower/config_so100_follower.py index 2a5a966ee2..b76675d26a 100644 --- a/lerobot/common/robots/so100_follower/config_so100_follower.py +++ b/lerobot/common/robots/so100_follower/config_so100_follower.py @@ -37,3 +37,27 @@ class SO100FollowerConfig(RobotConfig): # Set to `True` for backward compatibility with previous policies/dataset use_degrees: bool = False + + +@RobotConfig.register_subclass("so100_follower_end_effector") +@dataclass +class SO100FollowerEndEffectorConfig(SO100FollowerConfig): + """Configuration for the SO100FollowerEndEffector robot.""" + + # Default bounds for the end-effector position (in meters) + end_effector_bounds: dict[str, list[float]] = field( + default_factory=lambda: { + "min": [-1.0, -1.0, -1.0], # min x, y, z + "max": [1.0, 1.0, 1.0], # max x, y, z + } + ) + + max_gripper_pos: float = 50 + + end_effector_step_sizes: dict[str, float] = field( + default_factory=lambda: { + "x": 0.02, + "y": 0.02, + "z": 0.02, + } + ) diff --git a/lerobot/common/robots/so100_follower/so100_follower_end_effector.py b/lerobot/common/robots/so100_follower/so100_follower_end_effector.py new file mode 100644 index 0000000000..82e89305b3 --- /dev/null +++ b/lerobot/common/robots/so100_follower/so100_follower_end_effector.py @@ -0,0 +1,193 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from typing import Any + +import numpy as np + +from lerobot.common.cameras import make_cameras_from_configs +from lerobot.common.errors import DeviceNotConnectedError +from lerobot.common.model.kinematics import RobotKinematics +from lerobot.common.motors import Motor, MotorNormMode +from lerobot.common.motors.feetech import FeetechMotorsBus + +from . import SO100Follower +from .config_so100_follower import SO100FollowerEndEffectorConfig + +logger = logging.getLogger(__name__) +EE_FRAME = "gripper_tip" + + +class SO100FollowerEndEffector(SO100Follower): + """ + SO100Follower robot with end-effector space control. + + This robot inherits from SO100Follower but transforms actions from + end-effector space to joint space before sending them to the motors. + """ + + config_class = SO100FollowerEndEffectorConfig + name = "so100_follower_end_effector" + + def __init__(self, config: SO100FollowerEndEffectorConfig): + super().__init__(config) + self.bus = FeetechMotorsBus( + port=self.config.port, + motors={ + "shoulder_pan": Motor(1, "sts3215", MotorNormMode.DEGREES), + "shoulder_lift": Motor(2, "sts3215", MotorNormMode.DEGREES), + "elbow_flex": Motor(3, "sts3215", MotorNormMode.DEGREES), + "wrist_flex": Motor(4, "sts3215", MotorNormMode.DEGREES), + "wrist_roll": Motor(5, "sts3215", MotorNormMode.DEGREES), + "gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100), + }, + calibration=self.calibration, + ) + + self.cameras = make_cameras_from_configs(config.cameras) + + self.config = config + + # Initialize the kinematics module for the so100 robot + self.kinematics = RobotKinematics(robot_type="so_new_calibration") + + # Store the bounds for end-effector position + self.end_effector_bounds = self.config.end_effector_bounds + + self.current_ee_pos = None + self.current_joint_pos = None + + @property + def action_features(self) -> dict[str, Any]: + """ + Define action features for end-effector control. + Returns dictionary with dtype, shape, and names. + """ + return { + "dtype": "float32", + "shape": (4,), + "names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3}, + } + + def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + """ + Transform action from end-effector space to joint space and send to motors. + + Args: + action: Dictionary with keys 'delta_x', 'delta_y', 'delta_z' for end-effector control + or a numpy array with [delta_x, delta_y, delta_z] + + Returns: + The joint-space action that was sent to the motors + """ + + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + # Convert action to numpy array if not already + if isinstance(action, dict): + if all(k in action for k in ["delta_x", "delta_y", "delta_z"]): + delta_ee = np.array( + [ + action["delta_x"] * self.config.end_effector_step_sizes["x"], + action["delta_y"] * self.config.end_effector_step_sizes["y"], + action["delta_z"] * self.config.end_effector_step_sizes["z"], + ], + dtype=np.float32, + ) + if "gripper" not in action: + action["gripper"] = [1.0] + action = np.append(delta_ee, action["gripper"]) + else: + logger.warning( + f"Expected action keys 'delta_x', 'delta_y', 'delta_z', got {list(action.keys())}" + ) + action = np.zeros(4, dtype=np.float32) + + if self.current_joint_pos is None: + # Read current joint positions + current_joint_pos = self.bus.sync_read("Present_Position") + self.current_joint_pos = np.array([current_joint_pos[name] for name in self.bus.motors]) + + # Calculate current end-effector position using forward kinematics + if self.current_ee_pos is None: + self.current_ee_pos = self.kinematics.forward_kinematics(self.current_joint_pos, frame=EE_FRAME) + + # Set desired end-effector position by adding delta + desired_ee_pos = np.eye(4) + desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3] # Keep orientation + + # Add delta to position and clip to bounds + desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + action[:3] + if self.end_effector_bounds is not None: + desired_ee_pos[:3, 3] = np.clip( + desired_ee_pos[:3, 3], + self.end_effector_bounds["min"], + self.end_effector_bounds["max"], + ) + + # Compute inverse kinematics to get joint positions + target_joint_values_in_degrees = self.kinematics.ik( + self.current_joint_pos, desired_ee_pos, position_only=True, frame=EE_FRAME + ) + + target_joint_values_in_degrees = np.clip(target_joint_values_in_degrees, -180.0, 180.0) + # Create joint space action dictionary + joint_action = { + f"{key}.pos": target_joint_values_in_degrees[i] for i, key in enumerate(self.bus.motors.keys()) + } + + # Handle gripper separately if included in action + # Gripper delta action is in the range 0 - 2, + # We need to shift the action to the range -1, 1 so that we can expand it to -Max_gripper_pos, Max_gripper_pos + joint_action["gripper.pos"] = np.clip( + self.current_joint_pos[-1] + (action[-1] - 1) * self.config.max_gripper_pos, + 5, + self.config.max_gripper_pos, + ) + + self.current_ee_pos = desired_ee_pos.copy() + self.current_joint_pos = target_joint_values_in_degrees.copy() + self.current_joint_pos[-1] = joint_action["gripper.pos"] + + # Send joint space action to parent class + return super().send_action(joint_action) + + def get_observation(self) -> dict[str, Any]: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + # Read arm position + start = time.perf_counter() + obs_dict = self.bus.sync_read("Present_Position") + obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()} + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read state: {dt_ms:.1f}ms") + + # Capture images from cameras + for cam_key, cam in self.cameras.items(): + start = time.perf_counter() + obs_dict[cam_key] = cam.async_read() + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") + + return obs_dict + + def reset(self): + self.current_ee_pos = None + self.current_joint_pos = None diff --git a/lerobot/common/robots/utils.py b/lerobot/common/robots/utils.py index d100c8366c..ccc1c58e86 100644 --- a/lerobot/common/robots/utils.py +++ b/lerobot/common/robots/utils.py @@ -29,6 +29,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot: from .so100_follower import SO100Follower return SO100Follower(config) + elif config.type == "so100_follower_end_effector": + from .so100_follower import SO100FollowerEndEffector + + return SO100FollowerEndEffector(config) elif config.type == "so101_follower": from .so101_follower import SO101Follower diff --git a/lerobot/common/teleoperators/gamepad/__init__.py b/lerobot/common/teleoperators/gamepad/__init__.py new file mode 100644 index 0000000000..6f9f7fbd91 --- /dev/null +++ b/lerobot/common/teleoperators/gamepad/__init__.py @@ -0,0 +1,18 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_gamepad import GamepadTeleopConfig +from .teleop_gamepad import GamepadTeleop diff --git a/lerobot/common/teleoperators/gamepad/configuration_gamepad.py b/lerobot/common/teleoperators/gamepad/configuration_gamepad.py new file mode 100644 index 0000000000..b3a565c072 --- /dev/null +++ b/lerobot/common/teleoperators/gamepad/configuration_gamepad.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from ..config import TeleoperatorConfig + + +@TeleoperatorConfig.register_subclass("gamepad") +@dataclass +class GamepadTeleopConfig(TeleoperatorConfig): + use_gripper: bool = True diff --git a/lerobot/common/teleoperators/gamepad/gamepad_utils.py b/lerobot/common/teleoperators/gamepad/gamepad_utils.py new file mode 100644 index 0000000000..21a293c771 --- /dev/null +++ b/lerobot/common/teleoperators/gamepad/gamepad_utils.py @@ -0,0 +1,480 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + + +class InputController: + """Base class for input controllers that generate motion deltas.""" + + def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0): + """ + Initialize the controller. + + Args: + x_step_size: Base movement step size in meters + y_step_size: Base movement step size in meters + z_step_size: Base movement step size in meters + """ + self.x_step_size = x_step_size + self.y_step_size = y_step_size + self.z_step_size = z_step_size + self.running = True + self.episode_end_status = None # None, "success", or "failure" + self.intervention_flag = False + self.open_gripper_command = False + self.close_gripper_command = False + + def start(self): + """Start the controller and initialize resources.""" + pass + + def stop(self): + """Stop the controller and release resources.""" + pass + + def get_deltas(self): + """Get the current movement deltas (dx, dy, dz) in meters.""" + return 0.0, 0.0, 0.0 + + def should_quit(self): + """Return True if the user has requested to quit.""" + return not self.running + + def update(self): + """Update controller state - call this once per frame.""" + pass + + def __enter__(self): + """Support for use in 'with' statements.""" + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Ensure resources are released when exiting 'with' block.""" + self.stop() + + def get_episode_end_status(self): + """ + Get the current episode end status. + + Returns: + None if episode should continue, "success" or "failure" otherwise + """ + status = self.episode_end_status + self.episode_end_status = None # Reset after reading + return status + + def should_intervene(self): + """Return True if intervention flag was set.""" + return self.intervention_flag + + def gripper_command(self): + """Return the current gripper command.""" + if self.open_gripper_command == self.close_gripper_command: + return "stay" + elif self.open_gripper_command: + return "open" + elif self.close_gripper_command: + return "close" + + +class KeyboardController(InputController): + """Generate motion deltas from keyboard input.""" + + def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0): + super().__init__(x_step_size, y_step_size, z_step_size) + self.key_states = { + "forward_x": False, + "backward_x": False, + "forward_y": False, + "backward_y": False, + "forward_z": False, + "backward_z": False, + "quit": False, + "success": False, + "failure": False, + } + self.listener = None + + def start(self): + """Start the keyboard listener.""" + from pynput import keyboard + + def on_press(key): + try: + if key == keyboard.Key.up: + self.key_states["forward_x"] = True + elif key == keyboard.Key.down: + self.key_states["backward_x"] = True + elif key == keyboard.Key.left: + self.key_states["forward_y"] = True + elif key == keyboard.Key.right: + self.key_states["backward_y"] = True + elif key == keyboard.Key.shift: + self.key_states["backward_z"] = True + elif key == keyboard.Key.shift_r: + self.key_states["forward_z"] = True + elif key == keyboard.Key.esc: + self.key_states["quit"] = True + self.running = False + return False + elif key == keyboard.Key.enter: + self.key_states["success"] = True + self.episode_end_status = "success" + elif key == keyboard.Key.backspace: + self.key_states["failure"] = True + self.episode_end_status = "failure" + except AttributeError: + pass + + def on_release(key): + try: + if key == keyboard.Key.up: + self.key_states["forward_x"] = False + elif key == keyboard.Key.down: + self.key_states["backward_x"] = False + elif key == keyboard.Key.left: + self.key_states["forward_y"] = False + elif key == keyboard.Key.right: + self.key_states["backward_y"] = False + elif key == keyboard.Key.shift: + self.key_states["backward_z"] = False + elif key == keyboard.Key.shift_r: + self.key_states["forward_z"] = False + elif key == keyboard.Key.enter: + self.key_states["success"] = False + elif key == keyboard.Key.backspace: + self.key_states["failure"] = False + except AttributeError: + pass + + self.listener = keyboard.Listener(on_press=on_press, on_release=on_release) + self.listener.start() + + print("Keyboard controls:") + print(" Arrow keys: Move in X-Y plane") + print(" Shift and Shift_R: Move in Z axis") + print(" Enter: End episode with SUCCESS") + print(" Backspace: End episode with FAILURE") + print(" ESC: Exit") + + def stop(self): + """Stop the keyboard listener.""" + if self.listener and self.listener.is_alive(): + self.listener.stop() + + def get_deltas(self): + """Get the current movement deltas from keyboard state.""" + delta_x = delta_y = delta_z = 0.0 + + if self.key_states["forward_x"]: + delta_x += self.x_step_size + if self.key_states["backward_x"]: + delta_x -= self.x_step_size + if self.key_states["forward_y"]: + delta_y += self.y_step_size + if self.key_states["backward_y"]: + delta_y -= self.y_step_size + if self.key_states["forward_z"]: + delta_z += self.z_step_size + if self.key_states["backward_z"]: + delta_z -= self.z_step_size + + return delta_x, delta_y, delta_z + + def should_quit(self): + """Return True if ESC was pressed.""" + return self.key_states["quit"] + + def should_save(self): + """Return True if Enter was pressed (save episode).""" + return self.key_states["success"] or self.key_states["failure"] + + +class GamepadController(InputController): + """Generate motion deltas from gamepad input.""" + + def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0, deadzone=0.1): + super().__init__(x_step_size, y_step_size, z_step_size) + self.deadzone = deadzone + self.joystick = None + self.intervention_flag = False + + def start(self): + """Initialize pygame and the gamepad.""" + import pygame + + pygame.init() + pygame.joystick.init() + + if pygame.joystick.get_count() == 0: + logging.error("No gamepad detected. Please connect a gamepad and try again.") + self.running = False + return + + self.joystick = pygame.joystick.Joystick(0) + self.joystick.init() + logging.info(f"Initialized gamepad: {self.joystick.get_name()}") + + print("Gamepad controls:") + print(" Left analog stick: Move in X-Y plane") + print(" Right analog stick (vertical): Move in Z axis") + print(" B/Circle button: Exit") + print(" Y/Triangle button: End episode with SUCCESS") + print(" A/Cross button: End episode with FAILURE") + print(" X/Square button: Rerecord episode") + + def stop(self): + """Clean up pygame resources.""" + import pygame + + if pygame.joystick.get_init(): + if self.joystick: + self.joystick.quit() + pygame.joystick.quit() + pygame.quit() + + def update(self): + """Process pygame events to get fresh gamepad readings.""" + import pygame + + for event in pygame.event.get(): + if event.type == pygame.JOYBUTTONDOWN: + if event.button == 3: + self.episode_end_status = "success" + # A button (1) for failure + elif event.button == 1: + self.episode_end_status = "failure" + # X button (0) for rerecord + elif event.button == 0: + self.episode_end_status = "rerecord_episode" + + # RB button (6) for closing gripper + elif event.button == 6: + self.close_gripper_command = True + + # LT button (7) for opening gripper + elif event.button == 7: + self.open_gripper_command = True + + # Reset episode status on button release + elif event.type == pygame.JOYBUTTONUP: + if event.button in [0, 2, 3]: + self.episode_end_status = None + + elif event.button == 6: + self.close_gripper_command = False + + elif event.button == 7: + self.open_gripper_command = False + + # Check for RB button (typically button 5) for intervention flag + if self.joystick.get_button(5): + self.intervention_flag = True + else: + self.intervention_flag = False + + def get_deltas(self): + """Get the current movement deltas from gamepad state.""" + import pygame + + try: + # Read joystick axes + # Left stick X and Y (typically axes 0 and 1) + x_input = self.joystick.get_axis(0) # Left/Right + y_input = self.joystick.get_axis(1) # Up/Down (often inverted) + + # Right stick Y (typically axis 3 or 4) + z_input = self.joystick.get_axis(3) # Up/Down for Z + + # Apply deadzone to avoid drift + x_input = 0 if abs(x_input) < self.deadzone else x_input + y_input = 0 if abs(y_input) < self.deadzone else y_input + z_input = 0 if abs(z_input) < self.deadzone else z_input + + # Calculate deltas (note: may need to invert axes depending on controller) + delta_x = -y_input * self.y_step_size # Forward/backward + delta_y = -x_input * self.x_step_size # Left/right + delta_z = -z_input * self.z_step_size # Up/down + + return delta_x, delta_y, delta_z + + except pygame.error: + logging.error("Error reading gamepad. Is it still connected?") + return 0.0, 0.0, 0.0 + + +class GamepadControllerHID(InputController): + """Generate motion deltas from gamepad input using HIDAPI.""" + + def __init__( + self, + x_step_size=1.0, + y_step_size=1.0, + z_step_size=1.0, + deadzone=0.1, + ): + """ + Initialize the HID gamepad controller. + + Args: + step_size: Base movement step size in meters + z_scale: Scaling factor for Z-axis movement + deadzone: Joystick deadzone to prevent drift + """ + super().__init__(x_step_size, y_step_size, z_step_size) + self.deadzone = deadzone + self.device = None + self.device_info = None + + # Movement values (normalized from -1.0 to 1.0) + self.left_x = 0.0 + self.left_y = 0.0 + self.right_x = 0.0 + self.right_y = 0.0 + + # Button states + self.buttons = {} + self.quit_requested = False + self.save_requested = False + + def find_device(self): + """Look for the gamepad device by vendor and product ID.""" + import hid + + devices = hid.enumerate() + for device in devices: + device_name = device["product_string"] + if any(controller in device_name for controller in ["Logitech", "Xbox", "PS4", "PS5"]): + return device + + logging.error( + "No gamepad found, check the connection and the product string in HID to add your gamepad" + ) + return None + + def start(self): + """Connect to the gamepad using HIDAPI.""" + import hid + + self.device_info = self.find_device() + if not self.device_info: + self.running = False + return + + try: + logging.info(f"Connecting to gamepad at path: {self.device_info['path']}") + self.device = hid.device() + self.device.open_path(self.device_info["path"]) + self.device.set_nonblocking(1) + + manufacturer = self.device.get_manufacturer_string() + product = self.device.get_product_string() + logging.info(f"Connected to {manufacturer} {product}") + + logging.info("Gamepad controls (HID mode):") + logging.info(" Left analog stick: Move in X-Y plane") + logging.info(" Right analog stick: Move in Z axis (vertical)") + logging.info(" Button 1/B/Circle: Exit") + logging.info(" Button 2/A/Cross: End episode with SUCCESS") + logging.info(" Button 3/X/Square: End episode with FAILURE") + + except OSError as e: + logging.error(f"Error opening gamepad: {e}") + logging.error("You might need to run this with sudo/admin privileges on some systems") + self.running = False + + def stop(self): + """Close the HID device connection.""" + if self.device: + self.device.close() + self.device = None + + def update(self): + """ + Read and process the latest gamepad data. + Due to an issue with the HIDAPI, we need to read the read the device several times in order to get a stable reading + """ + for _ in range(10): + self._update() + + def _update(self): + """Read and process the latest gamepad data.""" + if not self.device or not self.running: + return + + try: + # Read data from the gamepad + data = self.device.read(64) + # Interpret gamepad data - this will vary by controller model + # These offsets are for the Logitech RumblePad 2 + if data and len(data) >= 8: + # Normalize joystick values from 0-255 to -1.0-1.0 + self.left_x = (data[1] - 128) / 128.0 + self.left_y = (data[2] - 128) / 128.0 + self.right_x = (data[3] - 128) / 128.0 + self.right_y = (data[4] - 128) / 128.0 + + # Apply deadzone + self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x + self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y + self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x + self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y + + # Parse button states (byte 5 in the Logitech RumblePad 2) + buttons = data[5] + + # Check if RB is pressed then the intervention flag should be set + self.intervention_flag = data[6] in [2, 6, 10, 14] + + # Check if RT is pressed + self.open_gripper_command = data[6] in [8, 10, 12] + + # Check if LT is pressed + self.close_gripper_command = data[6] in [4, 6, 12] + + # Check if Y/Triangle button (bit 7) is pressed for saving + # Check if X/Square button (bit 5) is pressed for failure + # Check if A/Cross button (bit 4) is pressed for rerecording + if buttons & 1 << 7: + self.episode_end_status = "success" + elif buttons & 1 << 5: + self.episode_end_status = "failure" + elif buttons & 1 << 4: + self.episode_end_status = "rerecord_episode" + else: + self.episode_end_status = None + + except OSError as e: + logging.error(f"Error reading from gamepad: {e}") + + def get_deltas(self): + """Get the current movement deltas from gamepad state.""" + # Calculate deltas - invert as needed based on controller orientation + delta_x = -self.left_y * self.x_step_size # Forward/backward + delta_y = -self.left_x * self.y_step_size # Left/right + delta_z = -self.right_y * self.z_step_size # Up/down + + return delta_x, delta_y, delta_z + + def should_quit(self): + """Return True if quit button was pressed.""" + return self.quit_requested + + def should_save(self): + """Return True if save button was pressed.""" + return self.save_requested diff --git a/lerobot/common/teleoperators/gamepad/teleop_gamepad.py b/lerobot/common/teleoperators/gamepad/teleop_gamepad.py new file mode 100644 index 0000000000..98a0647e21 --- /dev/null +++ b/lerobot/common/teleoperators/gamepad/teleop_gamepad.py @@ -0,0 +1,138 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from enum import IntEnum +from typing import Any + +import numpy as np + +from ..teleoperator import Teleoperator +from .configuration_gamepad import GamepadTeleopConfig + + +class GripperAction(IntEnum): + CLOSE = 0 + STAY = 1 + OPEN = 2 + + +gripper_action_map = { + "close": GripperAction.CLOSE.value, + "open": GripperAction.OPEN.value, + "stay": GripperAction.STAY.value, +} + + +class GamepadTeleop(Teleoperator): + """ + Teleop class to use gamepad inputs for control. + """ + + config_class = GamepadTeleopConfig + name = "gamepad" + + def __init__(self, config: GamepadTeleopConfig): + super().__init__(config) + self.config = config + self.robot_type = config.type + + self.gamepad = None + + @property + def action_features(self) -> dict: + if self.config.use_gripper: + return { + "dtype": "float32", + "shape": (4,), + "names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3}, + } + else: + return { + "dtype": "float32", + "shape": (3,), + "names": {"delta_x": 0, "delta_y": 1, "delta_z": 2}, + } + + @property + def feedback_features(self) -> dict: + return {} + + def connect(self) -> None: + # use HidApi for macos + if sys.platform == "darwin": + # NOTE: On macOS, pygame doesn’t reliably detect input from some controllers so we fall back to hidapi + from .gamepad_utils import GamepadControllerHID as Gamepad + else: + from .gamepad_utils import GamepadController as Gamepad + + self.gamepad = Gamepad() + self.gamepad.start() + + def get_action(self) -> dict[str, Any]: + # Update the controller to get fresh inputs + self.gamepad.update() + + # Get movement deltas from the controller + delta_x, delta_y, delta_z = self.gamepad.get_deltas() + + # Create action from gamepad input + gamepad_action = np.array([delta_x, delta_y, delta_z], dtype=np.float32) + + action_dict = { + "delta_x": gamepad_action[0], + "delta_y": gamepad_action[1], + "delta_z": gamepad_action[2], + } + + # Default gripper action is to stay + gripper_action = GripperAction.STAY.value + if self.config.use_gripper: + gripper_command = self.gamepad.gripper_command() + gripper_action = gripper_action_map[gripper_command] + action_dict["gripper"] = gripper_action + + return action_dict + + def disconnect(self) -> None: + """Disconnect from the gamepad.""" + if self.gamepad is not None: + self.gamepad.stop() + self.gamepad = None + + def is_connected(self) -> bool: + """Check if gamepad is connected.""" + return self.gamepad is not None + + def calibrate(self) -> None: + """Calibrate the gamepad.""" + # No calibration needed for gamepad + pass + + def is_calibrated(self) -> bool: + """Check if gamepad is calibrated.""" + # Gamepad doesn't require calibration + return True + + def configure(self) -> None: + """Configure the gamepad.""" + # No additional configuration needed + pass + + def send_feedback(self, feedback: dict) -> None: + """Send feedback to the gamepad.""" + # Gamepad doesn't support feedback + pass diff --git a/lerobot/common/teleoperators/so101_leader/config_so101_leader.py b/lerobot/common/teleoperators/so101_leader/config_so101_leader.py index 5f2e110da1..8d91c32dfe 100644 --- a/lerobot/common/teleoperators/so101_leader/config_so101_leader.py +++ b/lerobot/common/teleoperators/so101_leader/config_so101_leader.py @@ -24,3 +24,5 @@ class SO101LeaderConfig(TeleoperatorConfig): # Port to connect to the arm port: str + + use_degrees: bool = False diff --git a/lerobot/common/teleoperators/so101_leader/so101_leader.py b/lerobot/common/teleoperators/so101_leader/so101_leader.py index 34ad31dafe..d324e2a888 100644 --- a/lerobot/common/teleoperators/so101_leader/so101_leader.py +++ b/lerobot/common/teleoperators/so101_leader/so101_leader.py @@ -41,14 +41,15 @@ class SO101Leader(Teleoperator): def __init__(self, config: SO101LeaderConfig): super().__init__(config) self.config = config + norm_mode_body = MotorNormMode.DEGREES if config.use_degrees else MotorNormMode.RANGE_M100_100 self.bus = FeetechMotorsBus( port=self.config.port, motors={ - "shoulder_pan": Motor(1, "sts3215", MotorNormMode.RANGE_M100_100), - "shoulder_lift": Motor(2, "sts3215", MotorNormMode.RANGE_M100_100), - "elbow_flex": Motor(3, "sts3215", MotorNormMode.RANGE_M100_100), - "wrist_flex": Motor(4, "sts3215", MotorNormMode.RANGE_M100_100), - "wrist_roll": Motor(5, "sts3215", MotorNormMode.RANGE_M100_100), + "shoulder_pan": Motor(1, "sts3215", norm_mode_body), + "shoulder_lift": Motor(2, "sts3215", norm_mode_body), + "elbow_flex": Motor(3, "sts3215", norm_mode_body), + "wrist_flex": Motor(4, "sts3215", norm_mode_body), + "wrist_roll": Motor(5, "sts3215", norm_mode_body), "gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100), }, calibration=self.calibration, diff --git a/lerobot/common/teleoperators/utils.py b/lerobot/common/teleoperators/utils.py index 4942084ac7..d7b7bcf0e6 100644 --- a/lerobot/common/teleoperators/utils.py +++ b/lerobot/common/teleoperators/utils.py @@ -45,5 +45,9 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator: from tests.mocks.mock_teleop import MockTeleop return MockTeleop(config) + elif config.type == "gamepad": + from .gamepad.teleop_gamepad import GamepadTeleop + + return GamepadTeleop(config) else: raise ValueError(config.type) diff --git a/lerobot/common/transport/services.proto b/lerobot/common/transport/services.proto new file mode 100644 index 0000000000..29d00005a6 --- /dev/null +++ b/lerobot/common/transport/services.proto @@ -0,0 +1,59 @@ +// Copyright 2024 The HuggingFace Inc. team. +// All rights reserved. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// To generate a classes for transport part (services_pb2.py and services_pb2_grpc.py) use the following command: +// +// python -m grpc_tools.protoc -I . --python_out=. --grpc_python_out=. lerobot/common/transport/services.proto +// +// The command should be launched from the root of the project. + +syntax = "proto3"; + +package transport; + +// LearnerService: the Actor calls this to push transitions. +// The Learner implements this service. +service LearnerService { + // Actor -> Learner to store transitions + rpc StreamParameters(Empty) returns (stream Parameters); + rpc SendTransitions(stream Transition) returns (Empty); + rpc SendInteractions(stream InteractionMessage) returns (Empty); + rpc Ready(Empty) returns (Empty); +} + +enum TransferState { + TRANSFER_UNKNOWN = 0; + TRANSFER_BEGIN = 1; + TRANSFER_MIDDLE = 2; + TRANSFER_END = 3; +} + +// Messages +message Transition { + TransferState transfer_state = 1; + bytes data = 2; +} + +message Parameters { + TransferState transfer_state = 1; + bytes data = 2; +} + +message InteractionMessage { + TransferState transfer_state = 1; + bytes data = 2; +} + +message Empty {} diff --git a/lerobot/common/transport/services_pb2.py b/lerobot/common/transport/services_pb2.py new file mode 100644 index 0000000000..727beb60de --- /dev/null +++ b/lerobot/common/transport/services_pb2.py @@ -0,0 +1,45 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: lerobot/common/transport/services.proto +# Protobuf Python Version: 5.29.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 29, + 0, + '', + 'lerobot/common/transport/services.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\'lerobot/common/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'lerobot.common.transport.services_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_TRANSFERSTATE']._serialized_start=305 + _globals['_TRANSFERSTATE']._serialized_end=401 + _globals['_TRANSITION']._serialized_start=54 + _globals['_TRANSITION']._serialized_end=130 + _globals['_PARAMETERS']._serialized_start=132 + _globals['_PARAMETERS']._serialized_end=208 + _globals['_INTERACTIONMESSAGE']._serialized_start=210 + _globals['_INTERACTIONMESSAGE']._serialized_end=294 + _globals['_EMPTY']._serialized_start=296 + _globals['_EMPTY']._serialized_end=303 + _globals['_LEARNERSERVICE']._serialized_start=404 + _globals['_LEARNERSERVICE']._serialized_end=661 +# @@protoc_insertion_point(module_scope) diff --git a/lerobot/common/transport/services_pb2_grpc.py b/lerobot/common/transport/services_pb2_grpc.py new file mode 100644 index 0000000000..5a7a924fd2 --- /dev/null +++ b/lerobot/common/transport/services_pb2_grpc.py @@ -0,0 +1,233 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + +from lerobot.common.transport import services_pb2 as lerobot_dot_common_dot_transport_dot_services__pb2 + +GRPC_GENERATED_VERSION = '1.71.0' +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in lerobot/common/transport/services_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + ) + + +class LearnerServiceStub: + """LearnerService: the Actor calls this to push transitions. + The Learner implements this service. + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.StreamParameters = channel.unary_stream( + '/transport.LearnerService/StreamParameters', + request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, + response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Parameters.FromString, + _registered_method=True) + self.SendTransitions = channel.stream_unary( + '/transport.LearnerService/SendTransitions', + request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Transition.SerializeToString, + response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + _registered_method=True) + self.SendInteractions = channel.stream_unary( + '/transport.LearnerService/SendInteractions', + request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString, + response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + _registered_method=True) + self.Ready = channel.unary_unary( + '/transport.LearnerService/Ready', + request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, + response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + _registered_method=True) + + +class LearnerServiceServicer: + """LearnerService: the Actor calls this to push transitions. + The Learner implements this service. + """ + + def StreamParameters(self, request, context): + """Actor -> Learner to store transitions + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendTransitions(self, request_iterator, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendInteractions(self, request_iterator, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Ready(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_LearnerServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'StreamParameters': grpc.unary_stream_rpc_method_handler( + servicer.StreamParameters, + request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Parameters.SerializeToString, + ), + 'SendTransitions': grpc.stream_unary_rpc_method_handler( + servicer.SendTransitions, + request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Transition.FromString, + response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, + ), + 'SendInteractions': grpc.stream_unary_rpc_method_handler( + servicer.SendInteractions, + request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.FromString, + response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, + ), + 'Ready': grpc.unary_unary_rpc_method_handler( + servicer.Ready, + request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'transport.LearnerService', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('transport.LearnerService', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class LearnerService: + """LearnerService: the Actor calls this to push transitions. + The Learner implements this service. + """ + + @staticmethod + def StreamParameters(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream( + request, + target, + '/transport.LearnerService/StreamParameters', + lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, + lerobot_dot_common_dot_transport_dot_services__pb2.Parameters.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SendTransitions(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_unary( + request_iterator, + target, + '/transport.LearnerService/SendTransitions', + lerobot_dot_common_dot_transport_dot_services__pb2.Transition.SerializeToString, + lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SendInteractions(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_unary( + request_iterator, + target, + '/transport.LearnerService/SendInteractions', + lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString, + lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def Ready(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/transport.LearnerService/Ready', + lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, + lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/lerobot/common/transport/utils.py b/lerobot/common/transport/utils.py new file mode 100644 index 0000000000..774721fc6d --- /dev/null +++ b/lerobot/common/transport/utils.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import logging +import pickle # nosec B403: Safe usage for internal serialization only +from multiprocessing import Event, Queue +from typing import Any + +import torch + +from lerobot.common.transport import services_pb2 +from lerobot.common.utils.transition import Transition + +CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB + + +def bytes_buffer_size(buffer: io.BytesIO) -> int: + buffer.seek(0, io.SEEK_END) + result = buffer.tell() + buffer.seek(0) + return result + + +def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True): + buffer = io.BytesIO(buffer) + size_in_bytes = bytes_buffer_size(buffer) + + sent_bytes = 0 + + logging_method = logging.info if not silent else logging.debug + + logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with") + + while sent_bytes < size_in_bytes: + transfer_state = services_pb2.TransferState.TRANSFER_MIDDLE + + if sent_bytes + CHUNK_SIZE >= size_in_bytes: + transfer_state = services_pb2.TransferState.TRANSFER_END + elif sent_bytes == 0: + transfer_state = services_pb2.TransferState.TRANSFER_BEGIN + + size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes) + chunk = buffer.read(size_to_read) + + yield message_class(transfer_state=transfer_state, data=chunk) + sent_bytes += size_to_read + logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}") + + logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB") + + +def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""): # type: ignore + bytes_buffer = io.BytesIO() + step = 0 + + logging.info(f"{log_prefix} Starting receiver") + for item in iterator: + logging.debug(f"{log_prefix} Received item") + if shutdown_event.is_set(): + logging.info(f"{log_prefix} Shutting down receiver") + return + + if item.transfer_state == services_pb2.TransferState.TRANSFER_BEGIN: + bytes_buffer.seek(0) + bytes_buffer.truncate(0) + bytes_buffer.write(item.data) + logging.debug(f"{log_prefix} Received data at step 0") + step = 0 + elif item.transfer_state == services_pb2.TransferState.TRANSFER_MIDDLE: + bytes_buffer.write(item.data) + step += 1 + logging.debug(f"{log_prefix} Received data at step {step}") + elif item.transfer_state == services_pb2.TransferState.TRANSFER_END: + bytes_buffer.write(item.data) + logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}") + + queue.put(bytes_buffer.getvalue()) + + bytes_buffer.seek(0) + bytes_buffer.truncate(0) + step = 0 + + logging.debug(f"{log_prefix} Queue updated") + else: + logging.warning(f"{log_prefix} Received unknown transfer state {item.transfer_state}") + raise ValueError(f"Received unknown transfer state {item.transfer_state}") + + +def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes: + """Convert model state dict to flat array for transmission""" + buffer = io.BytesIO() + + torch.save(state_dict, buffer) + + return buffer.getvalue() + + +def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]: + buffer = io.BytesIO(buffer) + buffer.seek(0) + return torch.load(buffer, weights_only=True) + + +def python_object_to_bytes(python_object: Any) -> bytes: + return pickle.dumps(python_object) + + +def bytes_to_python_object(buffer: bytes) -> Any: + buffer = io.BytesIO(buffer) + buffer.seek(0) + obj = pickle.load(buffer) # nosec B301: Safe usage of pickle.load + # Add validation checks here + return obj + + +def bytes_to_transitions(buffer: bytes) -> list[Transition]: + buffer = io.BytesIO(buffer) + buffer.seek(0) + transitions = torch.load(buffer, weights_only=True) + return transitions + + +def transitions_to_bytes(transitions: list[Transition]) -> bytes: + buffer = io.BytesIO() + torch.save(transitions, buffer) + return buffer.getvalue() diff --git a/lerobot/common/utils/buffer.py b/lerobot/common/utils/buffer.py new file mode 100644 index 0000000000..9ae231ad92 --- /dev/null +++ b/lerobot/common/utils/buffer.py @@ -0,0 +1,841 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from contextlib import suppress +from typing import Callable, Sequence, TypedDict + +import torch +import torch.nn.functional as F # noqa: N812 +from tqdm import tqdm + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.utils.transition import Transition + + +class BatchTransition(TypedDict): + state: dict[str, torch.Tensor] + action: torch.Tensor + reward: torch.Tensor + next_state: dict[str, torch.Tensor] + done: torch.Tensor + truncated: torch.Tensor + complementary_info: dict[str, torch.Tensor | float | int] | None = None + + +def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor: + """ + Perform a per-image random crop over a batch of images in a vectorized way. + (Same as shown previously.) + """ + B, C, H, W = images.shape # noqa: N806 + crop_h, crop_w = output_size + + if crop_h > H or crop_w > W: + raise ValueError( + f"Requested crop size ({crop_h}, {crop_w}) is bigger than the image size ({H}, {W})." + ) + + tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device) + lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device) + + rows = torch.arange(crop_h, device=images.device).unsqueeze(0) + tops.unsqueeze(1) + cols = torch.arange(crop_w, device=images.device).unsqueeze(0) + lefts.unsqueeze(1) + + rows = rows.unsqueeze(2).expand(-1, -1, crop_w) # (B, crop_h, crop_w) + cols = cols.unsqueeze(1).expand(-1, crop_h, -1) # (B, crop_h, crop_w) + + images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C) + + # Gather pixels + cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :] + # cropped_hwcn => (B, crop_h, crop_w, C) + + cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w) + return cropped + + +def random_shift(images: torch.Tensor, pad: int = 4): + """Vectorized random shift, imgs: (B,C,H,W), pad: #pixels""" + _, _, h, w = images.shape + images = F.pad(input=images, pad=(pad, pad, pad, pad), mode="replicate") + return random_crop_vectorized(images=images, output_size=(h, w)) + + +class ReplayBuffer: + def __init__( + self, + capacity: int, + device: str = "cuda:0", + state_keys: Sequence[str] | None = None, + image_augmentation_function: Callable | None = None, + use_drq: bool = True, + storage_device: str = "cpu", + optimize_memory: bool = False, + ): + """ + Replay buffer for storing transitions. + It will allocate tensors on the specified device, when the first transition is added. + NOTE: If you encounter memory issues, you can try to use the `optimize_memory` flag to save memory or + and use the `storage_device` flag to store the buffer on a different device. + Args: + capacity (int): Maximum number of transitions to store in the buffer. + device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu"). + state_keys (List[str]): The list of keys that appear in `state` and `next_state`. + image_augmentation_function (Optional[Callable]): A function that takes a batch of images + and returns a batch of augmented images. If None, a default augmentation function is used. + use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer. + storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored. + Using "cpu" can help save GPU memory. + optimize_memory (bool): If True, optimizes memory by not storing duplicate next_states when + they can be derived from states. This is useful for large datasets where next_state[i] = state[i+1]. + """ + if capacity <= 0: + raise ValueError("Capacity must be greater than 0.") + + self.capacity = capacity + self.device = device + self.storage_device = storage_device + self.position = 0 + self.size = 0 + self.initialized = False + self.optimize_memory = optimize_memory + + # Track episode boundaries for memory optimization + self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device) + + # If no state_keys provided, default to an empty list + self.state_keys = state_keys if state_keys is not None else [] + + self.image_augmentation_function = image_augmentation_function + + if image_augmentation_function is None: + base_function = functools.partial(random_shift, pad=4) + self.image_augmentation_function = torch.compile(base_function) + self.use_drq = use_drq + + def _initialize_storage( + self, + state: dict[str, torch.Tensor], + action: torch.Tensor, + complementary_info: dict[str, torch.Tensor] | None = None, + ): + """Initialize the storage tensors based on the first transition.""" + # Determine shapes from the first transition + state_shapes = {key: val.squeeze(0).shape for key, val in state.items()} + action_shape = action.squeeze(0).shape + + # Pre-allocate tensors for storage + self.states = { + key: torch.empty((self.capacity, *shape), device=self.storage_device) + for key, shape in state_shapes.items() + } + self.actions = torch.empty((self.capacity, *action_shape), device=self.storage_device) + self.rewards = torch.empty((self.capacity,), device=self.storage_device) + + if not self.optimize_memory: + # Standard approach: store states and next_states separately + self.next_states = { + key: torch.empty((self.capacity, *shape), device=self.storage_device) + for key, shape in state_shapes.items() + } + else: + # Memory-optimized approach: don't allocate next_states buffer + # Just create a reference to states for consistent API + self.next_states = self.states # Just a reference for API consistency + + self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) + self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) + + # Initialize storage for complementary_info + self.has_complementary_info = complementary_info is not None + self.complementary_info_keys = [] + self.complementary_info = {} + + if self.has_complementary_info: + self.complementary_info_keys = list(complementary_info.keys()) + # Pre-allocate tensors for each key in complementary_info + for key, value in complementary_info.items(): + if isinstance(value, torch.Tensor): + value_shape = value.squeeze(0).shape + self.complementary_info[key] = torch.empty( + (self.capacity, *value_shape), device=self.storage_device + ) + elif isinstance(value, (int, float)): + # Handle scalar values similar to reward + self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device) + else: + raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]") + + self.initialized = True + + def __len__(self): + return self.size + + def add( + self, + state: dict[str, torch.Tensor], + action: torch.Tensor, + reward: float, + next_state: dict[str, torch.Tensor], + done: bool, + truncated: bool, + complementary_info: dict[str, torch.Tensor] | None = None, + ): + """Saves a transition, ensuring tensors are stored on the designated storage device.""" + # Initialize storage if this is the first transition + if not self.initialized: + self._initialize_storage(state=state, action=action, complementary_info=complementary_info) + + # Store the transition in pre-allocated tensors + for key in self.states: + self.states[key][self.position].copy_(state[key].squeeze(dim=0)) + + if not self.optimize_memory: + # Only store next_states if not optimizing memory + self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0)) + + self.actions[self.position].copy_(action.squeeze(dim=0)) + self.rewards[self.position] = reward + self.dones[self.position] = done + self.truncateds[self.position] = truncated + + # Handle complementary_info if provided and storage is initialized + if complementary_info is not None and self.has_complementary_info: + # Store the complementary_info + for key in self.complementary_info_keys: + if key in complementary_info: + value = complementary_info[key] + if isinstance(value, torch.Tensor): + self.complementary_info[key][self.position].copy_(value.squeeze(dim=0)) + elif isinstance(value, (int, float)): + self.complementary_info[key][self.position] = value + + self.position = (self.position + 1) % self.capacity + self.size = min(self.size + 1, self.capacity) + + def sample(self, batch_size: int) -> BatchTransition: + """Sample a random batch of transitions and collate them into batched tensors.""" + if not self.initialized: + raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.") + + batch_size = min(batch_size, self.size) + high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size + + # Random indices for sampling - create on the same device as storage + idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device) + + # Identify image keys that need augmentation + image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else [] + + # Create batched state and next_state + batch_state = {} + batch_next_state = {} + + # First pass: load all state tensors to target device + for key in self.states: + batch_state[key] = self.states[key][idx].to(self.device) + + if not self.optimize_memory: + # Standard approach - load next_states directly + batch_next_state[key] = self.next_states[key][idx].to(self.device) + else: + # Memory-optimized approach - get next_state from the next index + next_idx = (idx + 1) % self.capacity + batch_next_state[key] = self.states[key][next_idx].to(self.device) + + # Apply image augmentation in a batched way if needed + if self.use_drq and image_keys: + # Concatenate all images from state and next_state + all_images = [] + for key in image_keys: + all_images.append(batch_state[key]) + all_images.append(batch_next_state[key]) + + # Optimization: Batch all images and apply augmentation once + all_images_tensor = torch.cat(all_images, dim=0) + augmented_images = self.image_augmentation_function(all_images_tensor) + + # Split the augmented images back to their sources + for i, key in enumerate(image_keys): + # Calculate offsets for the current image key: + # For each key, we have 2*batch_size images (batch_size for states, batch_size for next_states) + # States start at index i*2*batch_size and take up batch_size slots + batch_state[key] = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size] + # Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots + batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size] + + # Sample other tensors + batch_actions = self.actions[idx].to(self.device) + batch_rewards = self.rewards[idx].to(self.device) + batch_dones = self.dones[idx].to(self.device).float() + batch_truncateds = self.truncateds[idx].to(self.device).float() + + # Sample complementary_info if available + batch_complementary_info = None + if self.has_complementary_info: + batch_complementary_info = {} + for key in self.complementary_info_keys: + batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device) + + return BatchTransition( + state=batch_state, + action=batch_actions, + reward=batch_rewards, + next_state=batch_next_state, + done=batch_dones, + truncated=batch_truncateds, + complementary_info=batch_complementary_info, + ) + + def get_iterator( + self, + batch_size: int, + async_prefetch: bool = True, + queue_size: int = 2, + ): + """ + Creates an infinite iterator that yields batches of transitions. + Will automatically restart when internal iterator is exhausted. + + Args: + batch_size (int): Size of batches to sample + async_prefetch (bool): Whether to use asynchronous prefetching with threads (default: True) + queue_size (int): Number of batches to prefetch (default: 2) + + Yields: + BatchTransition: Batched transitions + """ + while True: # Create an infinite loop + if async_prefetch: + # Get the standard iterator + iterator = self._get_async_iterator(queue_size=queue_size, batch_size=batch_size) + else: + iterator = self._get_naive_iterator(batch_size=batch_size, queue_size=queue_size) + + # Yield all items from the iterator + with suppress(StopIteration): + yield from iterator + + def _get_async_iterator(self, batch_size: int, queue_size: int = 2): + """ + Create an iterator that continuously yields prefetched batches in a + background thread. The design is intentionally simple and avoids busy + waiting / complex state management. + + Args: + batch_size (int): Size of batches to sample. + queue_size (int): Maximum number of prefetched batches to keep in + memory. + + Yields: + BatchTransition: A batch sampled from the replay buffer. + """ + import queue + import threading + + data_queue: queue.Queue = queue.Queue(maxsize=queue_size) + shutdown_event = threading.Event() + + def producer() -> None: + """Continuously put sampled batches into the queue until shutdown.""" + while not shutdown_event.is_set(): + try: + batch = self.sample(batch_size) + # The timeout ensures the thread unblocks if the queue is full + # and the shutdown event gets set meanwhile. + data_queue.put(batch, block=True, timeout=0.5) + except queue.Full: + # Queue is full – loop again (will re-check shutdown_event) + continue + except Exception: + # Surface any unexpected error and terminate the producer. + shutdown_event.set() + + producer_thread = threading.Thread(target=producer, daemon=True) + producer_thread.start() + + try: + while not shutdown_event.is_set(): + try: + yield data_queue.get(block=True) + except Exception: + # If the producer already set the shutdown flag we exit. + if shutdown_event.is_set(): + break + finally: + shutdown_event.set() + # Drain the queue quickly to help the thread exit if it's blocked on `put`. + while not data_queue.empty(): + _ = data_queue.get_nowait() + # Give the producer thread a bit of time to finish. + producer_thread.join(timeout=1.0) + + def _get_naive_iterator(self, batch_size: int, queue_size: int = 2): + """ + Creates a simple non-threaded iterator that yields batches. + + Args: + batch_size (int): Size of batches to sample + queue_size (int): Number of initial batches to prefetch + + Yields: + BatchTransition: Batch transitions + """ + import collections + + queue = collections.deque() + + def enqueue(n): + for _ in range(n): + data = self.sample(batch_size) + queue.append(data) + + enqueue(queue_size) + while queue: + yield queue.popleft() + enqueue(1) + + @classmethod + def from_lerobot_dataset( + cls, + lerobot_dataset: LeRobotDataset, + device: str = "cuda:0", + state_keys: Sequence[str] | None = None, + capacity: int | None = None, + image_augmentation_function: Callable | None = None, + use_drq: bool = True, + storage_device: str = "cpu", + optimize_memory: bool = False, + ) -> "ReplayBuffer": + """ + Convert a LeRobotDataset into a ReplayBuffer. + + Args: + lerobot_dataset (LeRobotDataset): The dataset to convert. + device (str): The device for sampling tensors. Defaults to "cuda:0". + state_keys (Sequence[str] | None): The list of keys that appear in `state` and `next_state`. + capacity (int | None): Buffer capacity. If None, uses dataset length. + action_mask (Sequence[int] | None): Indices of action dimensions to keep. + image_augmentation_function (Callable | None): Function for image augmentation. + If None, uses default random shift with pad=4. + use_drq (bool): Whether to use DrQ image augmentation when sampling. + storage_device (str): Device for storing tensor data. Using "cpu" saves GPU memory. + optimize_memory (bool): If True, reduces memory usage by not duplicating state data. + + Returns: + ReplayBuffer: The replay buffer with dataset transitions. + """ + if capacity is None: + capacity = len(lerobot_dataset) + + if capacity < len(lerobot_dataset): + raise ValueError( + "The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset." + ) + + # Create replay buffer with image augmentation and DrQ settings + replay_buffer = cls( + capacity=capacity, + device=device, + state_keys=state_keys, + image_augmentation_function=image_augmentation_function, + use_drq=use_drq, + storage_device=storage_device, + optimize_memory=optimize_memory, + ) + + # Convert dataset to transitions + list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys) + + # Initialize the buffer with the first transition to set up storage tensors + if list_transition: + first_transition = list_transition[0] + first_state = {k: v.to(device) for k, v in first_transition["state"].items()} + first_action = first_transition["action"].to(device) + + # Get complementary info if available + first_complementary_info = None + if ( + "complementary_info" in first_transition + and first_transition["complementary_info"] is not None + ): + first_complementary_info = { + k: v.to(device) for k, v in first_transition["complementary_info"].items() + } + + replay_buffer._initialize_storage( + state=first_state, action=first_action, complementary_info=first_complementary_info + ) + + # Fill the buffer with all transitions + for data in list_transition: + for k, v in data.items(): + if isinstance(v, dict): + for key, tensor in v.items(): + v[key] = tensor.to(storage_device) + elif isinstance(v, torch.Tensor): + data[k] = v.to(storage_device) + + action = data["action"] + + replay_buffer.add( + state=data["state"], + action=action, + reward=data["reward"], + next_state=data["next_state"], + done=data["done"], + truncated=False, # NOTE: Truncation are not supported yet in lerobot dataset + complementary_info=data.get("complementary_info", None), + ) + + return replay_buffer + + def to_lerobot_dataset( + self, + repo_id: str, + fps=1, + root=None, + task_name="from_replay_buffer", + ) -> LeRobotDataset: + """ + Converts all transitions in this ReplayBuffer into a single LeRobotDataset object. + """ + if self.size == 0: + raise ValueError("The replay buffer is empty. Cannot convert to a dataset.") + + # Create features dictionary for the dataset + features = { + "index": {"dtype": "int64", "shape": [1]}, # global index across episodes + "episode_index": {"dtype": "int64", "shape": [1]}, # which episode + "frame_index": {"dtype": "int64", "shape": [1]}, # index inside an episode + "timestamp": {"dtype": "float32", "shape": [1]}, # for now we store dummy + "task_index": {"dtype": "int64", "shape": [1]}, + } + + # Add "action" + sample_action = self.actions[0] + act_info = guess_feature_info(t=sample_action, name="action") + features["action"] = act_info + + # Add "reward" and "done" + features["next.reward"] = {"dtype": "float32", "shape": (1,)} + features["next.done"] = {"dtype": "bool", "shape": (1,)} + + # Add state keys + for key in self.states: + sample_val = self.states[key][0] + f_info = guess_feature_info(t=sample_val, name=key) + features[key] = f_info + + # Add complementary_info keys if available + if self.has_complementary_info: + for key in self.complementary_info_keys: + sample_val = self.complementary_info[key][0] + if isinstance(sample_val, torch.Tensor) and sample_val.ndim == 0: + sample_val = sample_val.unsqueeze(0) + f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}") + features[f"complementary_info.{key}"] = f_info + + # Create an empty LeRobotDataset + lerobot_dataset = LeRobotDataset.create( + repo_id=repo_id, + fps=fps, + root=root, + robot_type=None, + features=features, + use_videos=True, + ) + + # Start writing images if needed + lerobot_dataset.start_image_writer(num_processes=0, num_threads=3) + + # Convert transitions into episodes and frames + episode_index = 0 + lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index=episode_index) + + frame_idx_in_episode = 0 + for idx in range(self.size): + actual_idx = (self.position - self.size + idx) % self.capacity + + frame_dict = {} + + # Fill the data for state keys + for key in self.states: + frame_dict[key] = self.states[key][actual_idx].cpu() + + # Fill action, reward, done + frame_dict["action"] = self.actions[actual_idx].cpu() + frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu() + frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu() + + # Add complementary_info if available + if self.has_complementary_info: + for key in self.complementary_info_keys: + val = self.complementary_info[key][actual_idx] + # Convert tensors to CPU + if isinstance(val, torch.Tensor): + if val.ndim == 0: + val = val.unsqueeze(0) + frame_dict[f"complementary_info.{key}"] = val.cpu() + # Non-tensor values can be used directly + else: + frame_dict[f"complementary_info.{key}"] = val + + # Add to the dataset's buffer + lerobot_dataset.add_frame(frame_dict, task=task_name) + + # Move to next frame + frame_idx_in_episode += 1 + + # If we reached an episode boundary, call save_episode, reset counters + if self.dones[actual_idx] or self.truncateds[actual_idx]: + lerobot_dataset.save_episode() + episode_index += 1 + frame_idx_in_episode = 0 + lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer( + episode_index=episode_index + ) + + # Save any remaining frames in the buffer + if lerobot_dataset.episode_buffer["size"] > 0: + lerobot_dataset.save_episode() + + lerobot_dataset.stop_image_writer() + + return lerobot_dataset + + @staticmethod + def _lerobotdataset_to_transitions( + dataset: LeRobotDataset, + state_keys: Sequence[str] | None = None, + ) -> list[Transition]: + """ + Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions. + + Args: + dataset (LeRobotDataset): + The dataset to convert. Each item in the dataset is expected to have + at least the following keys: + { + "action": ... + "next.reward": ... + "next.done": ... + "episode_index": ... + } + plus whatever your 'state_keys' specify. + + state_keys (Sequence[str] | None): + The dataset keys to include in 'state' and 'next_state'. Their names + will be kept as-is in the output transitions. E.g. + ["observation.state", "observation.environment_state"]. + If None, you must handle or define default keys. + + Returns: + transitions (List[Transition]): + A list of Transition dictionaries with the same length as `dataset`. + """ + if state_keys is None: + raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.") + + transitions = [] + num_frames = len(dataset) + + # Check if the dataset has "next.done" key + sample = dataset[0] + has_done_key = "next.done" in sample + + # Check for complementary_info keys + complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")] + has_complementary_info = len(complementary_info_keys) > 0 + + # If not, we need to infer it from episode boundaries + if not has_done_key: + print("'next.done' key not found in dataset. Inferring from episode boundaries...") + + for i in tqdm(range(num_frames)): + current_sample = dataset[i] + + # ----- 1) Current state ----- + current_state: dict[str, torch.Tensor] = {} + for key in state_keys: + val = current_sample[key] + current_state[key] = val.unsqueeze(0) # Add batch dimension + + # ----- 2) Action ----- + action = current_sample["action"].unsqueeze(0) # Add batch dimension + + # ----- 3) Reward and done ----- + reward = float(current_sample["next.reward"].item()) # ensure float + + # Determine done flag - use next.done if available, otherwise infer from episode boundaries + if has_done_key: + done = bool(current_sample["next.done"].item()) # ensure bool + else: + # If this is the last frame or if next frame is in a different episode, mark as done + done = False + if i == num_frames - 1: + done = True + elif i < num_frames - 1: + next_sample = dataset[i + 1] + if next_sample["episode_index"] != current_sample["episode_index"]: + done = True + + # TODO: (azouitine) Handle truncation (using the same value as done for now) + truncated = done + + # ----- 4) Next state ----- + # If not done and the next sample is in the same episode, we pull the next sample's state. + # Otherwise (done=True or next sample crosses to a new episode), next_state = current_state. + next_state = current_state # default + if not done and (i < num_frames - 1): + next_sample = dataset[i + 1] + if next_sample["episode_index"] == current_sample["episode_index"]: + # Build next_state from the same keys + next_state_data: dict[str, torch.Tensor] = {} + for key in state_keys: + val = next_sample[key] + next_state_data[key] = val.unsqueeze(0) # Add batch dimension + next_state = next_state_data + + # ----- 5) Complementary info (if available) ----- + complementary_info = None + if has_complementary_info: + complementary_info = {} + for key in complementary_info_keys: + # Strip the "complementary_info." prefix to get the actual key + clean_key = key[len("complementary_info.") :] + val = current_sample[key] + # Handle tensor and non-tensor values differently + if isinstance(val, torch.Tensor): + complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension + else: + # TODO: (azouitine) Check if it's necessary to convert to tensor + # For non-tensor values, use directly + complementary_info[clean_key] = val + + # ----- Construct the Transition ----- + transition = Transition( + state=current_state, + action=action, + reward=reward, + next_state=next_state, + done=done, + truncated=truncated, + complementary_info=complementary_info, + ) + transitions.append(transition) + + return transitions + + +# Utility function to guess shapes/dtypes from a tensor +def guess_feature_info(t, name: str): + """ + Return a dictionary with the 'dtype' and 'shape' for a given tensor or scalar value. + If it looks like a 3D (C,H,W) shape, we might consider it an 'image'. + Otherwise default to appropriate dtype for numeric. + """ + + shape = tuple(t.shape) + # Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image' + if len(shape) == 3 and shape[0] in [1, 3]: + return { + "dtype": "image", + "shape": shape, + } + else: + # Otherwise treat as numeric + return { + "dtype": "float32", + "shape": shape, + } + + +def concatenate_batch_transitions( + left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition +) -> BatchTransition: + """ + Concatenates two BatchTransition objects into one. + + This function merges the right BatchTransition into the left one by concatenating + all corresponding tensors along dimension 0. The operation modifies the left_batch_transitions + in place and also returns it. + + Args: + left_batch_transitions (BatchTransition): The first batch to concatenate and the one + that will be modified in place. + right_batch_transition (BatchTransition): The second batch to append to the first one. + + Returns: + BatchTransition: The concatenated batch (same object as left_batch_transitions). + + Warning: + This function modifies the left_batch_transitions object in place. + """ + # Concatenate state fields + left_batch_transitions["state"] = { + key: torch.cat( + [left_batch_transitions["state"][key], right_batch_transition["state"][key]], + dim=0, + ) + for key in left_batch_transitions["state"] + } + + # Concatenate basic fields + left_batch_transitions["action"] = torch.cat( + [left_batch_transitions["action"], right_batch_transition["action"]], dim=0 + ) + left_batch_transitions["reward"] = torch.cat( + [left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0 + ) + + # Concatenate next_state fields + left_batch_transitions["next_state"] = { + key: torch.cat( + [left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], + dim=0, + ) + for key in left_batch_transitions["next_state"] + } + + # Concatenate done and truncated fields + left_batch_transitions["done"] = torch.cat( + [left_batch_transitions["done"], right_batch_transition["done"]], dim=0 + ) + left_batch_transitions["truncated"] = torch.cat( + [left_batch_transitions["truncated"], right_batch_transition["truncated"]], + dim=0, + ) + + # Handle complementary_info + left_info = left_batch_transitions.get("complementary_info") + right_info = right_batch_transition.get("complementary_info") + + # Only process if right_info exists + if right_info is not None: + # Initialize left complementary_info if needed + if left_info is None: + left_batch_transitions["complementary_info"] = right_info + else: + # Concatenate each field + for key in right_info: + if key in left_info: + left_info[key] = torch.cat([left_info[key], right_info[key]], dim=0) + else: + left_info[key] = right_info[key] + + return left_batch_transitions diff --git a/lerobot/common/utils/import_utils.py b/lerobot/common/utils/import_utils.py index cd5f824502..5c29b5a847 100644 --- a/lerobot/common/utils/import_utils.py +++ b/lerobot/common/utils/import_utils.py @@ -28,6 +28,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b try: # Primary method to get the package version package_version = importlib.metadata.version(pkg_name) + except importlib.metadata.PackageNotFoundError: # Fallback method: Only for "torch" and versions containing "dev" if pkg_name == "torch": @@ -43,6 +44,9 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b except ImportError: # If the package can't be imported, it's not available package_exists = False + elif pkg_name == "grpc": + package = importlib.import_module(pkg_name) + package_version = getattr(package, "__version__", "N/A") else: # For packages other than "torch", don't attempt the fallback and set as not available package_exists = False diff --git a/lerobot/common/utils/process.py b/lerobot/common/utils/process.py new file mode 100644 index 0000000000..72438b6f98 --- /dev/null +++ b/lerobot/common/utils/process.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import signal +import sys + + +class ProcessSignalHandler: + """Utility class to attach graceful shutdown signal handlers. + + The class exposes a shutdown_event attribute that is set when a shutdown + signal is received. A counter tracks how many shutdown signals have been + caught. On the second signal the process exits with status 1. + """ + + _SUPPORTED_SIGNALS = ("SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT") + + def __init__(self, use_threads: bool, display_pid: bool = False): + # TODO: Check if we can use Event from threading since Event from + # multiprocessing is the a clone of threading.Event. + # https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Event + if use_threads: + from threading import Event + else: + from multiprocessing import Event + + self.shutdown_event = Event() + self._counter: int = 0 + self._display_pid = display_pid + + self._register_handlers() + + @property + def counter(self) -> int: # pragma: no cover – simple accessor + """Number of shutdown signals that have been intercepted.""" + return self._counter + + def _register_handlers(self): + """Attach the internal _signal_handler to a subset of POSIX signals.""" + + def _signal_handler(signum, frame): + pid_str = "" + if self._display_pid: + pid_str = f"[PID: {os.getpid()}]" + logging.info(f"{pid_str} Shutdown signal {signum} received. Cleaning up…") + self.shutdown_event.set() + self._counter += 1 + + # On a second Ctrl-C (or any supported signal) force the exit to + # mimic the previous behaviour while giving the caller one chance to + # shutdown gracefully. + # TODO: Investigate if we need it later + if self._counter > 1: + logging.info("Force shutdown") + sys.exit(1) + + for sig_name in self._SUPPORTED_SIGNALS: + sig = getattr(signal, sig_name, None) + if sig is None: + # The signal is not available on this platform (Windows for + # instance does not provide SIGHUP, SIGQUIT…). Skip it. + continue + try: + signal.signal(sig, _signal_handler) + except (ValueError, OSError): # pragma: no cover – unlikely but safe + # Signal not supported or we are in a non-main thread. + continue diff --git a/lerobot/common/utils/queue.py b/lerobot/common/utils/queue.py new file mode 100644 index 0000000000..ceb30e2bff --- /dev/null +++ b/lerobot/common/utils/queue.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from queue import Empty +from typing import Any + +from torch.multiprocessing import Queue + + +def get_last_item_from_queue(queue: Queue, block=True, timeout: float = 0.1) -> Any: + if block: + try: + item = queue.get(timeout=timeout) + except Empty: + return None + else: + item = None + + # Drain queue and keep only the most recent parameters + try: + while True: + item = queue.get_nowait() + except Empty: + pass + + return item diff --git a/lerobot/common/utils/transition.py b/lerobot/common/utils/transition.py new file mode 100644 index 0000000000..db413c388f --- /dev/null +++ b/lerobot/common/utils/transition.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TypedDict + +import torch + + +class Transition(TypedDict): + state: dict[str, torch.Tensor] + action: torch.Tensor + reward: float + next_state: dict[str, torch.Tensor] + done: bool + truncated: bool + complementary_info: dict[str, torch.Tensor | float | int] | None = None + + +def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition: + device = torch.device(device) + non_blocking = device.type == "cuda" + + # Move state tensors to device + transition["state"] = { + key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items() + } + + # Move action to device + transition["action"] = transition["action"].to(device, non_blocking=non_blocking) + + # Move reward and done if they are tensors + if isinstance(transition["reward"], torch.Tensor): + transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking) + + if isinstance(transition["done"], torch.Tensor): + transition["done"] = transition["done"].to(device, non_blocking=non_blocking) + + if isinstance(transition["truncated"], torch.Tensor): + transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking) + + # Move next_state tensors to device + transition["next_state"] = { + key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items() + } + + # Move complementary_info tensors if present + if transition.get("complementary_info") is not None: + for key, val in transition["complementary_info"].items(): + if isinstance(val, torch.Tensor): + transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking) + elif isinstance(val, (int, float, bool)): + transition["complementary_info"][key] = torch.tensor(val, device=device) + else: + raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]") + return transition + + +def move_state_dict_to_device(state_dict, device="cpu"): + """ + Recursively move all tensors in a (potentially) nested + dict/list/tuple structure to the CPU. + """ + if isinstance(state_dict, torch.Tensor): + return state_dict.to(device) + elif isinstance(state_dict, dict): + return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()} + elif isinstance(state_dict, list): + return [move_state_dict_to_device(v, device=device) for v in state_dict] + elif isinstance(state_dict, tuple): + return tuple(move_state_dict_to_device(v, device=device) for v in state_dict) + else: + return state_dict diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py index 08e9a3c06b..cba65ba456 100644 --- a/lerobot/common/utils/utils.py +++ b/lerobot/common/utils/utils.py @@ -20,9 +20,11 @@ import select import subprocess import sys -from copy import copy +import time +from copy import copy, deepcopy from datetime import datetime, timezone from pathlib import Path +from statistics import mean import numpy as np import torch @@ -109,11 +111,17 @@ def is_amp_available(device: str): raise ValueError(f"Unknown device '{device}.") -def init_logging(): +def init_logging(log_file: Path | None = None, display_pid: bool = False): def custom_format(record): dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") fnameline = f"{record.pathname}:{record.lineno}" - message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}" + + # NOTE: Display PID is useful for multi-process logging. + if display_pid: + pid_str = f"[PID: {os.getpid()}]" + message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.msg}" + else: + message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}" return message logging.basicConfig(level=logging.INFO) @@ -127,6 +135,12 @@ def custom_format(record): console_handler.setFormatter(formatter) logging.getLogger().addHandler(console_handler) + if log_file is not None: + # Additionally write logs to file + file_handler = logging.FileHandler(log_file) + file_handler.setFormatter(formatter) + logging.getLogger().addHandler(file_handler) + def format_big_number(num, precision=0): suffixes = ["", "K", "M", "B", "T", "Q"] @@ -247,3 +261,114 @@ def enter_pressed() -> bool: def move_cursor_up(lines): """Move the cursor up by a specified number of lines.""" print(f"\033[{lines}A", end="") + + +class TimerManager: + """ + Lightweight utility to measure elapsed time. + + Examples + -------- + ```python + # Example 1: Using context manager + timer = TimerManager("Policy", log=False) + for _ in range(3): + with timer: + time.sleep(0.01) + print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01 + ``` + + ```python + # Example 2: Using start/stop methods + timer = TimerManager("Policy", log=False) + timer.start() + time.sleep(0.01) + timer.stop() + print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01 + ``` + """ + + def __init__( + self, + label: str = "Elapsed-time", + log: bool = True, + logger: logging.Logger | None = None, + ): + self.label = label + self.log = log + self.logger = logger + self._start: float | None = None + self._history: list[float] = [] + + def __enter__(self): + return self.start() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() + + def start(self): + self._start = time.perf_counter() + return self + + def stop(self) -> float: + if self._start is None: + raise RuntimeError("Timer was never started.") + elapsed = time.perf_counter() - self._start + self._history.append(elapsed) + self._start = None + if self.log: + if self.logger is not None: + self.logger.info(f"{self.label}: {elapsed:.6f} s") + else: + logging.info(f"{self.label}: {elapsed:.6f} s") + return elapsed + + def reset(self): + self._history.clear() + + @property + def last(self) -> float: + return self._history[-1] if self._history else 0.0 + + @property + def avg(self) -> float: + return mean(self._history) if self._history else 0.0 + + @property + def total(self) -> float: + return sum(self._history) + + @property + def count(self) -> int: + return len(self._history) + + @property + def history(self) -> list[float]: + return deepcopy(self._history) + + @property + def fps_history(self) -> list[float]: + return [1.0 / t for t in self._history] + + @property + def fps_last(self) -> float: + return 0.0 if self.last == 0 else 1.0 / self.last + + @property + def fps_avg(self) -> float: + return 0.0 if self.avg == 0 else 1.0 / self.avg + + def percentile(self, p: float) -> float: + """ + Return the p-th percentile of recorded times. + """ + if not self._history: + return 0.0 + return float(np.percentile(self._history, p)) + + def fps_percentile(self, p: float) -> float: + """ + FPS corresponding to the p-th percentile time. + """ + val = self.percentile(p) + return 0.0 if val == 0 else 1.0 / val diff --git a/lerobot/common/utils/wandb_utils.py b/lerobot/common/utils/wandb_utils.py index 9e938e1917..ac4d223433 100644 --- a/lerobot/common/utils/wandb_utils.py +++ b/lerobot/common/utils/wandb_utils.py @@ -30,9 +30,10 @@ def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[st """Return a group name for logging. Optionally returns group name as list.""" lst = [ f"policy:{cfg.policy.type}", - f"dataset:{cfg.dataset.repo_id}", f"seed:{cfg.seed}", ] + if cfg.dataset is not None: + lst.append(f"dataset:{cfg.dataset.repo_id}") if cfg.env is not None: lst.append(f"env:{cfg.env.type}") return lst if return_list else "-".join(lst) @@ -92,6 +93,12 @@ def __init__(self, cfg: TrainPipelineConfig): resume="must" if cfg.resume else None, mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online", ) + run_id = wandb.run.id + # NOTE: We will override the cfg.wandb.run_id with the wandb run id. + # This is because we want to be able to resume the run from the wandb run id. + cfg.wandb.run_id = run_id + # Handle custom step key for rl asynchronous training. + self._wandb_custom_step_key: set[str] | None = None print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}") self._wandb = wandb @@ -108,9 +115,26 @@ def log_policy(self, checkpoint_dir: Path): artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE) self._wandb.log_artifact(artifact) - def log_dict(self, d: dict, step: int, mode: str = "train"): + def log_dict( + self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None + ): if mode not in {"train", "eval"}: raise ValueError(mode) + if step is None and custom_step_key is None: + raise ValueError("Either step or custom_step_key must be provided.") + + # NOTE: This is not simple. Wandb step must always monotonically increase and it + # increases with each wandb.log call, but in the case of asynchronous RL for example, + # multiple time steps is possible. For example, the interaction step with the environment, + # the training step, the evaluation step, etc. So we need to define a custom step key + # to log the correct step for each metric. + if custom_step_key is not None: + if self._wandb_custom_step_key is None: + self._wandb_custom_step_key = set() + new_custom_key = f"{mode}/{custom_step_key}" + if new_custom_key not in self._wandb_custom_step_key: + self._wandb_custom_step_key.add(new_custom_key) + self._wandb.define_metric(new_custom_key, hidden=True) for k, v in d.items(): if not isinstance(v, (int, float, str)): @@ -118,7 +142,18 @@ def log_dict(self, d: dict, step: int, mode: str = "train"): f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.' ) continue - self._wandb.log({f"{mode}/{k}": v}, step=step) + + # Do not log the custom step key itself. + if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key: + continue + + if custom_step_key is not None: + value_custom_step = d[custom_step_key] + data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step} + self._wandb.log(data) + continue + + self._wandb.log(data={f"{mode}/{k}": v}, step=step) def log_video(self, video_path: str, step: int, mode: str = "train"): if mode not in {"train", "eval"}: diff --git a/lerobot/configs/control.py b/lerobot/configs/control.py deleted file mode 100644 index 07b8d13523..0000000000 --- a/lerobot/configs/control.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from pathlib import Path - -import draccus - -from lerobot.common.robots import RobotConfig -from lerobot.configs import parser -from lerobot.configs.policies import PreTrainedConfig - - -@dataclass -class ControlConfig(draccus.ChoiceRegistry): - pass - - -@ControlConfig.register_subclass("calibrate") -@dataclass -class CalibrateControlConfig(ControlConfig): - # List of arms to calibrate (e.g. `--arms='["left_follower","right_follower"]' left_leader`) - arms: list[str] | None = None - - -@ControlConfig.register_subclass("teleoperate") -@dataclass -class TeleoperateControlConfig(ControlConfig): - # Limit the maximum frames per second. By default, no limit. - fps: int | None = None - teleop_time_s: float | None = None - # Display all cameras on screen - display_data: bool = False - - -@ControlConfig.register_subclass("record") -@dataclass -class RecordControlConfig(ControlConfig): - # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`). - repo_id: str - # A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.") - single_task: str - # Root directory where the dataset will be stored (e.g. 'dataset/path'). - root: str | Path | None = None - policy: PreTrainedConfig | None = None - # Limit the frames per second. By default, uses the policy fps. - fps: int | None = None - # Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize. - warmup_time_s: int | float = 10 - # Number of seconds for data recording for each episode. - episode_time_s: int | float = 60 - # Number of seconds for resetting the environment after each episode. - reset_time_s: int | float = 60 - # Number of episodes to record. - num_episodes: int = 50 - # Encode frames in the dataset into video - video: bool = True - # Upload dataset to Hugging Face hub. - push_to_hub: bool = True - # Upload on private repository on the Hugging Face hub. - private: bool = False - # Add tags to your dataset on the hub. - tags: list[str] | None = None - # Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only; - # set to ≥1 to use subprocesses, each using threads to write images. The best number of processes - # and threads depends on your system. We recommend 4 threads per camera with 0 processes. - # If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses. - num_image_writer_processes: int = 0 - # Number of threads writing the frames as png images on disk, per camera. - # Too many threads might cause unstable teleoperation fps due to main thread being blocked. - # Not enough threads might cause low camera fps. - num_image_writer_threads_per_camera: int = 4 - # Display all cameras on screen - display_data: bool = False - # Use vocal synthesis to read events. - play_sounds: bool = True - # Resume recording on an existing dataset. - resume: bool = False - - def __post_init__(self): - # HACK: We parse again the cli args here to get the pretrained path if there was one. - policy_path = parser.get_path_arg("control.policy") - if policy_path: - cli_overrides = parser.get_cli_overrides("control.policy") - self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) - self.policy.pretrained_path = policy_path - - -@ControlConfig.register_subclass("replay") -@dataclass -class ReplayControlConfig(ControlConfig): - # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`). - repo_id: str - # Index of the episode to replay. - episode: int - # Root directory where the dataset will be stored (e.g. 'dataset/path'). - root: str | Path | None = None - # Limit the frames per second. By default, uses the dataset fps. - fps: int | None = None - # Use vocal synthesis to read events. - play_sounds: bool = True - - -@ControlConfig.register_subclass("remote_robot") -@dataclass -class RemoteRobotConfig(ControlConfig): - log_interval: int = 100 - # Display all cameras on screen - display_data: bool = False - # Rerun configuration for remote robot (https://ref.rerun.io/docs/python/0.22.1/common/initialization_functions/#rerun.connect_tcp) - viewer_ip: str | None = None - viewer_port: str | None = None - - -@dataclass -class ControlPipelineConfig: - robot: RobotConfig - control: ControlConfig - - @classmethod - def __get_path_fields__(cls) -> list[str]: - """This enables the parser to load config from the policy using `--policy.path=local/dir`""" - return ["control.policy"] diff --git a/lerobot/configs/train.py b/lerobot/configs/train.py index 98826294ea..96a460bdf1 100644 --- a/lerobot/configs/train.py +++ b/lerobot/configs/train.py @@ -172,3 +172,8 @@ def from_pretrained( cli_args = kwargs.pop("cli_args", []) with draccus.config_type("json"): return draccus.parse(cls, config_file, args=cli_args) + + +@dataclass(kw_only=True) +class TrainRLServerPipelineConfig(TrainPipelineConfig): + dataset: DatasetConfig | None = None # NOTE: In RL, we don't need an offline dataset diff --git a/lerobot/configs/types.py b/lerobot/configs/types.py index 6b3d92e80d..6040ff70ba 100644 --- a/lerobot/configs/types.py +++ b/lerobot/configs/types.py @@ -23,6 +23,7 @@ class FeatureType(str, Enum): VISUAL = "VISUAL" ENV = "ENV" ACTION = "ACTION" + REWARD = "REWARD" class NormalizationMode(str, Enum): diff --git a/lerobot/scripts/find_joint_limits.py b/lerobot/scripts/find_joint_limits.py new file mode 100644 index 0000000000..95676dd359 --- /dev/null +++ b/lerobot/scripts/find_joint_limits.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Simple script to control a robot from teleoperation. + +Example: + +```shell +python -m lerobot.scripts.server.find_joint_limits \ + --robot.type=so100_follower \ + --robot.port=/dev/tty.usbmodem58760431541 \ + --robot.id=black \ + --teleop.type=so100_leader \ + --teleop.port=/dev/tty.usbmodem58760431551 \ + --teleop.id=blue +``` +""" + +import time +from dataclasses import dataclass + +import draccus +import numpy as np + +from lerobot.common.model.kinematics import RobotKinematics +from lerobot.common.robots import ( # noqa: F401 + RobotConfig, + koch_follower, + make_robot_from_config, + so100_follower, +) +from lerobot.common.teleoperators import ( # noqa: F401 + TeleoperatorConfig, + gamepad, + koch_leader, + make_teleoperator_from_config, + so100_leader, +) + + +@dataclass +class FindJointLimitsConfig: + teleop: TeleoperatorConfig + robot: RobotConfig + # Limit the maximum frames per second. By default, no limit. + teleop_time_s: float = 30 + # Display all cameras on screen + display_data: bool = False + + +@draccus.wrap() +def find_joint_and_ee_bounds(cfg: FindJointLimitsConfig): + teleop = make_teleoperator_from_config(cfg.teleop) + robot = make_robot_from_config(cfg.robot) + + teleop.connect() + robot.connect() + + start_episode_t = time.perf_counter() + robot_type = getattr(robot.config, "robot_type", "so101") + if "so100" in robot_type or "so101" in robot_type: + # Note to be compatible with the rest of the codebase, + # we are using the new calibration method for so101 and so100 + robot_type = "so_new_calibration" + kinematics = RobotKinematics(robot_type=robot_type) + + # Initialize min/max values + observation = robot.get_observation() + joint_positions = np.array([observation[f"{key}.pos"] for key in robot.bus.motors]) + ee_pos = kinematics.forward_kinematics(joint_positions, frame="gripper_tip")[:3, 3] + + max_pos = joint_positions.copy() + min_pos = joint_positions.copy() + max_ee = ee_pos.copy() + min_ee = ee_pos.copy() + + while True: + action = teleop.get_action() + robot.send_action(action) + + observation = robot.get_observation() + joint_positions = np.array([observation[f"{key}.pos"] for key in robot.bus.motors]) + ee_pos = kinematics.forward_kinematics(joint_positions, frame="gripper_tip")[:3, 3] + + # Skip initial warmup period + if (time.perf_counter() - start_episode_t) < 5: + continue + + # Update min/max values + max_ee = np.maximum(max_ee, ee_pos) + min_ee = np.minimum(min_ee, ee_pos) + max_pos = np.maximum(max_pos, joint_positions) + min_pos = np.minimum(min_pos, joint_positions) + + if time.perf_counter() - start_episode_t > cfg.teleop_time_s: + print(f"Max ee position {np.round(max_ee, 4).tolist()}") + print(f"Min ee position {np.round(min_ee, 4).tolist()}") + print(f"Max joint pos position {np.round(max_pos, 4).tolist()}") + print(f"Min joint pos position {np.round(min_pos, 4).tolist()}") + break + + +if __name__ == "__main__": + find_joint_and_ee_bounds() diff --git a/lerobot/scripts/rl/actor.py b/lerobot/scripts/rl/actor.py new file mode 100644 index 0000000000..da24d0dc58 --- /dev/null +++ b/lerobot/scripts/rl/actor.py @@ -0,0 +1,709 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Actor server runner for distributed HILSerl robot policy training. + +This script implements the actor component of the distributed HILSerl architecture. +It executes the policy in the robot environment, collects experience, +and sends transitions to the learner server for policy updates. + +Examples of usage: + +- Start an actor server for real robot training with human-in-the-loop intervention: +```bash +python lerobot/scripts/rl/actor.py --config_path lerobot/configs/train_config_hilserl_so100.json +``` + +**NOTE**: The actor server requires a running learner server to connect to. Ensure the learner +server is started before launching the actor. + +**NOTE**: Human intervention is key to HILSerl training. Press the upper right trigger button on the +gamepad to take control of the robot during training. Initially intervene frequently, then gradually +reduce interventions as the policy improves. + +**WORKFLOW**: +1. Determine robot workspace bounds using `find_joint_limits.py` +2. Record demonstrations with `gym_manipulator.py` in record mode +3. Process the dataset and determine camera crops with `crop_dataset_roi.py` +4. Start the learner server with the training configuration +5. Start this actor server with the same configuration +6. Use human interventions to guide policy learning + +For more details on the complete HILSerl training workflow, see: +https://github.com/michel-aractingi/lerobot-hilserl-guide +""" + +import logging +import os +import time +from functools import lru_cache +from queue import Empty + +import grpc +import torch +from torch import nn +from torch.multiprocessing import Event, Queue + +from lerobot.common.cameras import opencv # noqa: F401 +from lerobot.common.policies.factory import make_policy +from lerobot.common.policies.sac.modeling_sac import SACPolicy +from lerobot.common.robots import so100_follower # noqa: F401 +from lerobot.common.teleoperators import gamepad, so101_leader # noqa: F401 +from lerobot.common.transport import services_pb2, services_pb2_grpc +from lerobot.common.transport.utils import ( + bytes_to_state_dict, + python_object_to_bytes, + receive_bytes_in_chunks, + send_bytes_in_chunks, + transitions_to_bytes, +) +from lerobot.common.utils.process import ProcessSignalHandler +from lerobot.common.utils.queue import get_last_item_from_queue +from lerobot.common.utils.random_utils import set_seed +from lerobot.common.utils.robot_utils import busy_wait +from lerobot.common.utils.transition import ( + Transition, + move_state_dict_to_device, + move_transition_to_device, +) +from lerobot.common.utils.utils import ( + TimerManager, + get_safe_torch_device, + init_logging, +) +from lerobot.configs import parser +from lerobot.configs.train import TrainRLServerPipelineConfig +from lerobot.scripts.rl import learner_service +from lerobot.scripts.rl.gym_manipulator import make_robot_env + +ACTOR_SHUTDOWN_TIMEOUT = 30 + + +################################################# +# Main entry point # +################################################# + + +@parser.wrap() +def actor_cli(cfg: TrainRLServerPipelineConfig): + cfg.validate() + display_pid = False + if not use_threads(cfg): + import torch.multiprocessing as mp + + mp.set_start_method("spawn") + display_pid = True + + # Create logs directory to ensure it exists + log_dir = os.path.join(cfg.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"actor_{cfg.job_name}.log") + + # Initialize logging with explicit log file + init_logging(log_file=log_file, display_pid=display_pid) + logging.info(f"Actor logging initialized, writing to {log_file}") + + is_threaded = use_threads(cfg) + shutdown_event = ProcessSignalHandler(is_threaded, display_pid=display_pid).shutdown_event + + learner_client, grpc_channel = learner_service_client( + host=cfg.policy.actor_learner_config.learner_host, + port=cfg.policy.actor_learner_config.learner_port, + ) + + logging.info("[ACTOR] Establishing connection with Learner") + if not establish_learner_connection(learner_client, shutdown_event): + logging.error("[ACTOR] Failed to establish connection with Learner") + return + + if not use_threads(cfg): + # If we use multithreading, we can reuse the channel + grpc_channel.close() + grpc_channel = None + + logging.info("[ACTOR] Connection with Learner established") + + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + + concurrency_entity = None + if use_threads(cfg): + from threading import Thread + + concurrency_entity = Thread + else: + from multiprocessing import Process + + concurrency_entity = Process + + receive_policy_process = concurrency_entity( + target=receive_policy, + args=(cfg, parameters_queue, shutdown_event, grpc_channel), + daemon=True, + ) + + transitions_process = concurrency_entity( + target=send_transitions, + args=(cfg, transitions_queue, shutdown_event, grpc_channel), + daemon=True, + ) + + interactions_process = concurrency_entity( + target=send_interactions, + args=(cfg, interactions_queue, shutdown_event, grpc_channel), + daemon=True, + ) + + transitions_process.start() + interactions_process.start() + receive_policy_process.start() + + act_with_policy( + cfg=cfg, + shutdown_event=shutdown_event, + parameters_queue=parameters_queue, + transitions_queue=transitions_queue, + interactions_queue=interactions_queue, + ) + logging.info("[ACTOR] Policy process joined") + + logging.info("[ACTOR] Closing queues") + transitions_queue.close() + interactions_queue.close() + parameters_queue.close() + + transitions_process.join() + logging.info("[ACTOR] Transitions process joined") + interactions_process.join() + logging.info("[ACTOR] Interactions process joined") + receive_policy_process.join() + logging.info("[ACTOR] Receive policy process joined") + + logging.info("[ACTOR] join queues") + transitions_queue.cancel_join_thread() + interactions_queue.cancel_join_thread() + parameters_queue.cancel_join_thread() + + logging.info("[ACTOR] queues closed") + + +################################################# +# Core algorithm functions # +################################################# + + +def act_with_policy( + cfg: TrainRLServerPipelineConfig, + shutdown_event: any, # Event, + parameters_queue: Queue, + transitions_queue: Queue, + interactions_queue: Queue, +): + """ + Executes policy interaction within the environment. + + This function rolls out the policy in the environment, collecting interaction data and pushing it to a queue for streaming to the learner. + Once an episode is completed, updated network parameters received from the learner are retrieved from a queue and loaded into the network. + + Args: + cfg: Configuration settings for the interaction process. + shutdown_event: Event to check if the process should shutdown. + parameters_queue: Queue to receive updated network parameters from the learner. + transitions_queue: Queue to send transitions to the learner. + interactions_queue: Queue to send interactions to the learner. + """ + # Initialize logging for multiprocessing + if not use_threads(cfg): + log_dir = os.path.join(cfg.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"actor_policy_{os.getpid()}.log") + init_logging(log_file=log_file, display_pid=True) + logging.info("Actor policy process logging initialized") + + logging.info("make_env online") + + online_env = make_robot_env(cfg=cfg.env) + + set_seed(cfg.seed) + device = get_safe_torch_device(cfg.policy.device, log=True) + + torch.backends.cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True + + logging.info("make_policy") + + ### Instantiate the policy in both the actor and learner processes + ### To avoid sending a SACPolicy object through the port, we create a policy instance + ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters + policy: SACPolicy = make_policy( + cfg=cfg.policy, + env_cfg=cfg.env, + ) + policy = policy.eval() + assert isinstance(policy, nn.Module) + + obs, info = online_env.reset() + + # NOTE: For the moment we will solely handle the case of a single environment + sum_reward_episode = 0 + list_transition_to_send_to_learner = [] + episode_intervention = False + # Add counters for intervention rate calculation + episode_intervention_steps = 0 + episode_total_steps = 0 + + policy_timer = TimerManager("Policy inference", log=False) + + for interaction_step in range(cfg.policy.online_steps): + start_time = time.perf_counter() + if shutdown_event.is_set(): + logging.info("[ACTOR] Shutting down act_with_policy") + return + + if interaction_step >= cfg.policy.online_step_before_learning: + # Time policy inference and check if it meets FPS requirement + with policy_timer: + action = policy.select_action(batch=obs) + policy_fps = policy_timer.fps_last + + log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step) + + else: + action = online_env.action_space.sample() + + next_obs, reward, done, truncated, info = online_env.step(action) + + sum_reward_episode += float(reward) + # Increment total steps counter for intervention rate + episode_total_steps += 1 + + # NOTE: We override the action if the intervention is True, because the action applied is the intervention action + if "is_intervention" in info and info["is_intervention"]: + # NOTE: The action space for demonstration before hand is with the full action space + # but sometimes for example we want to deactivate the gripper + action = info["action_intervention"] + episode_intervention = True + # Increment intervention steps counter + episode_intervention_steps += 1 + + list_transition_to_send_to_learner.append( + Transition( + state=obs, + action=action, + reward=reward, + next_state=next_obs, + done=done, + truncated=truncated, # TODO: (azouitine) Handle truncation properly + complementary_info=info, + ) + ) + # assign obs to the next obs and continue the rollout + obs = next_obs + + if done or truncated: + logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}") + + update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device) + + if len(list_transition_to_send_to_learner) > 0: + push_transitions_to_transport_queue( + transitions=list_transition_to_send_to_learner, + transitions_queue=transitions_queue, + ) + list_transition_to_send_to_learner = [] + + stats = get_frequency_stats(policy_timer) + policy_timer.reset() + + # Calculate intervention rate + intervention_rate = 0.0 + if episode_total_steps > 0: + intervention_rate = episode_intervention_steps / episode_total_steps + + # Send episodic reward to the learner + interactions_queue.put( + python_object_to_bytes( + { + "Episodic reward": sum_reward_episode, + "Interaction step": interaction_step, + "Episode intervention": int(episode_intervention), + "Intervention rate": intervention_rate, + **stats, + } + ) + ) + + # Reset intervention counters + sum_reward_episode = 0.0 + episode_intervention = False + episode_intervention_steps = 0 + episode_total_steps = 0 + obs, info = online_env.reset() + + if cfg.env.fps is not None: + dt_time = time.perf_counter() - start_time + busy_wait(1 / cfg.env.fps - dt_time) + + +################################################# +# Communication Functions - Group all gRPC/messaging functions # +################################################# + + +def establish_learner_connection( + stub: services_pb2_grpc.LearnerServiceStub, + shutdown_event: Event, # type: ignore + attempts: int = 30, +): + """Establish a connection with the learner. + + Args: + stub (services_pb2_grpc.LearnerServiceStub): The stub to use for the connection. + shutdown_event (Event): The event to check if the connection should be established. + attempts (int): The number of attempts to establish the connection. + Returns: + bool: True if the connection is established, False otherwise. + """ + for _ in range(attempts): + if shutdown_event.is_set(): + logging.info("[ACTOR] Shutting down establish_learner_connection") + return False + + # Force a connection attempt and check state + try: + logging.info("[ACTOR] Send ready message to Learner") + if stub.Ready(services_pb2.Empty()) == services_pb2.Empty(): + return True + except grpc.RpcError as e: + logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}") + time.sleep(2) + return False + + +@lru_cache(maxsize=1) +def learner_service_client( + host: str = "127.0.0.1", + port: int = 50051, +) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]: + import json + + """ + Returns a client for the learner service. + + GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection. + So we need to create only one client and reuse it. + """ + + service_config = { + "methodConfig": [ + { + "name": [{}], # Applies to ALL methods in ALL services + "retryPolicy": { + "maxAttempts": 5, # Max retries (total attempts = 5) + "initialBackoff": "0.1s", # First retry after 0.1s + "maxBackoff": "2s", # Max wait time between retries + "backoffMultiplier": 2, # Exponential backoff factor + "retryableStatusCodes": [ + "UNAVAILABLE", + "DEADLINE_EXCEEDED", + ], # Retries on network failures + }, + } + ] + } + + service_config_json = json.dumps(service_config) + + channel = grpc.insecure_channel( + f"{host}:{port}", + options=[ + ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE), + ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE), + ("grpc.enable_retries", 1), + ("grpc.service_config", service_config_json), + ], + ) + stub = services_pb2_grpc.LearnerServiceStub(channel) + logging.info("[ACTOR] Learner service client created") + return stub, channel + + +def receive_policy( + cfg: TrainRLServerPipelineConfig, + parameters_queue: Queue, + shutdown_event: Event, # type: ignore + learner_client: services_pb2_grpc.LearnerServiceStub | None = None, + grpc_channel: grpc.Channel | None = None, +): + """Receive parameters from the learner. + + Args: + cfg (TrainRLServerPipelineConfig): The configuration for the actor. + parameters_queue (Queue): The queue to receive the parameters. + shutdown_event (Event): The event to check if the process should shutdown. + """ + logging.info("[ACTOR] Start receiving parameters from the Learner") + if not use_threads(cfg): + # Create a process-specific log file + log_dir = os.path.join(cfg.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"actor_receive_policy_{os.getpid()}.log") + + # Initialize logging with explicit log file + init_logging(log_file=log_file, display_pid=True) + logging.info("Actor receive policy process logging initialized") + + # Setup process handlers to handle shutdown signal + # But use shutdown event from the main process + _ = ProcessSignalHandler(use_threads=False, display_pid=True) + + if grpc_channel is None or learner_client is None: + learner_client, grpc_channel = learner_service_client( + host=cfg.policy.actor_learner_config.learner_host, + port=cfg.policy.actor_learner_config.learner_port, + ) + + try: + iterator = learner_client.StreamParameters(services_pb2.Empty()) + receive_bytes_in_chunks( + iterator, + parameters_queue, + shutdown_event, + log_prefix="[ACTOR] parameters", + ) + + except grpc.RpcError as e: + logging.error(f"[ACTOR] gRPC error: {e}") + + if not use_threads(cfg): + grpc_channel.close() + logging.info("[ACTOR] Received policy loop stopped") + + +def send_transitions( + cfg: TrainRLServerPipelineConfig, + transitions_queue: Queue, + shutdown_event: any, # Event, + learner_client: services_pb2_grpc.LearnerServiceStub | None = None, + grpc_channel: grpc.Channel | None = None, +) -> services_pb2.Empty: + """ + Sends transitions to the learner. + + This function continuously retrieves messages from the queue and processes: + + - Transition Data: + - A batch of transitions (observation, action, reward, next observation) is collected. + - Transitions are moved to the CPU and serialized using PyTorch. + - The serialized data is wrapped in a `services_pb2.Transition` message and sent to the learner. + """ + + if not use_threads(cfg): + # Create a process-specific log file + log_dir = os.path.join(cfg.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"actor_transitions_{os.getpid()}.log") + + # Initialize logging with explicit log file + init_logging(log_file=log_file, display_pid=True) + logging.info("Actor transitions process logging initialized") + + if grpc_channel is None or learner_client is None: + learner_client, grpc_channel = learner_service_client( + host=cfg.policy.actor_learner_config.learner_host, + port=cfg.policy.actor_learner_config.learner_port, + ) + + try: + learner_client.SendTransitions( + transitions_stream( + shutdown_event, transitions_queue, cfg.policy.actor_learner_config.queue_get_timeout + ) + ) + except grpc.RpcError as e: + logging.error(f"[ACTOR] gRPC error: {e}") + + logging.info("[ACTOR] Finished streaming transitions") + + if not use_threads(cfg): + grpc_channel.close() + logging.info("[ACTOR] Transitions process stopped") + + +def send_interactions( + cfg: TrainRLServerPipelineConfig, + interactions_queue: Queue, + shutdown_event: Event, # type: ignore + learner_client: services_pb2_grpc.LearnerServiceStub | None = None, + grpc_channel: grpc.Channel | None = None, +) -> services_pb2.Empty: + """ + Sends interactions to the learner. + + This function continuously retrieves messages from the queue and processes: + + - Interaction Messages: + - Contains useful statistics about episodic rewards and policy timings. + - The message is serialized using `pickle` and sent to the learner. + """ + + if not use_threads(cfg): + # Create a process-specific log file + log_dir = os.path.join(cfg.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"actor_interactions_{os.getpid()}.log") + + # Initialize logging with explicit log file + init_logging(log_file=log_file, display_pid=True) + logging.info("Actor interactions process logging initialized") + + # Setup process handlers to handle shutdown signal + # But use shutdown event from the main process + _ = ProcessSignalHandler(use_threads=False, display_pid=True) + + if grpc_channel is None or learner_client is None: + learner_client, grpc_channel = learner_service_client( + host=cfg.policy.actor_learner_config.learner_host, + port=cfg.policy.actor_learner_config.learner_port, + ) + + try: + learner_client.SendInteractions( + interactions_stream( + shutdown_event, interactions_queue, cfg.policy.actor_learner_config.queue_get_timeout + ) + ) + except grpc.RpcError as e: + logging.error(f"[ACTOR] gRPC error: {e}") + + logging.info("[ACTOR] Finished streaming interactions") + + if not use_threads(cfg): + grpc_channel.close() + logging.info("[ACTOR] Interactions process stopped") + + +def transitions_stream(shutdown_event: Event, transitions_queue: Queue, timeout: float) -> services_pb2.Empty: # type: ignore + while not shutdown_event.is_set(): + try: + message = transitions_queue.get(block=True, timeout=timeout) + except Empty: + logging.debug("[ACTOR] Transition queue is empty") + continue + + yield from send_bytes_in_chunks( + message, services_pb2.Transition, log_prefix="[ACTOR] Send transitions" + ) + + return services_pb2.Empty() + + +def interactions_stream( + shutdown_event: Event, + interactions_queue: Queue, + timeout: float, # type: ignore +) -> services_pb2.Empty: + while not shutdown_event.is_set(): + try: + message = interactions_queue.get(block=True, timeout=timeout) + except Empty: + logging.debug("[ACTOR] Interaction queue is empty") + continue + + yield from send_bytes_in_chunks( + message, + services_pb2.InteractionMessage, + log_prefix="[ACTOR] Send interactions", + ) + + return services_pb2.Empty() + + +################################################# +# Policy functions # +################################################# + + +def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device): + bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False) + if bytes_state_dict is not None: + logging.info("[ACTOR] Load new parameters from Learner.") + state_dict = bytes_to_state_dict(bytes_state_dict) + state_dict = move_state_dict_to_device(state_dict, device=device) + policy.load_state_dict(state_dict) + + +################################################# +# Utilities functions # +################################################# + + +def push_transitions_to_transport_queue(transitions: list, transitions_queue): + """Send transitions to learner in smaller chunks to avoid network issues. + + Args: + transitions: List of transitions to send + message_queue: Queue to send messages to learner + chunk_size: Size of each chunk to send + """ + transition_to_send_to_learner = [] + for transition in transitions: + tr = move_transition_to_device(transition=transition, device="cpu") + for key, value in tr["state"].items(): + if torch.isnan(value).any(): + logging.warning(f"Found NaN values in transition {key}") + + transition_to_send_to_learner.append(tr) + + transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner)) + + +def get_frequency_stats(timer: TimerManager) -> dict[str, float]: + """Get the frequency statistics of the policy. + + Args: + timer (TimerManager): The timer with collected metrics. + + Returns: + dict[str, float]: The frequency statistics of the policy. + """ + stats = {} + if timer.count > 1: + avg_fps = timer.fps_avg + p90_fps = timer.fps_percentile(90) + logging.debug(f"[ACTOR] Average policy frame rate: {avg_fps}") + logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {p90_fps}") + stats = { + "Policy frequency [Hz]": avg_fps, + "Policy frequency 90th-p [Hz]": p90_fps, + } + return stats + + +def log_policy_frequency_issue(policy_fps: float, cfg: TrainRLServerPipelineConfig, interaction_step: int): + if policy_fps < cfg.env.fps: + logging.warning( + f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.env.fps} at step {interaction_step}" + ) + + +def use_threads(cfg: TrainRLServerPipelineConfig) -> bool: + return cfg.policy.concurrency.actor == "threads" + + +if __name__ == "__main__": + actor_cli() diff --git a/lerobot/scripts/rl/crop_dataset_roi.py b/lerobot/scripts/rl/crop_dataset_roi.py new file mode 100644 index 0000000000..5b7038de30 --- /dev/null +++ b/lerobot/scripts/rl/crop_dataset_roi.py @@ -0,0 +1,314 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +from copy import deepcopy +from pathlib import Path +from typing import Dict, Tuple + +import cv2 + +# import torch.nn.functional as F # noqa: N812 +import torchvision.transforms.functional as F # type: ignore # noqa: N812 +from tqdm import tqdm # type: ignore + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + + +def select_rect_roi(img): + """ + Allows the user to draw a rectangular ROI on the image. + + The user must click and drag to draw the rectangle. + - While dragging, the rectangle is dynamically drawn. + - On mouse button release, the rectangle is fixed. + - Press 'c' to confirm the selection. + - Press 'r' to reset the selection. + - Press ESC to cancel. + + Returns: + A tuple (top, left, height, width) representing the rectangular ROI, + or None if no valid ROI is selected. + """ + # Create a working copy of the image + clone = img.copy() + working_img = clone.copy() + + roi = None # Will store the final ROI as (top, left, height, width) + drawing = False + index_x, index_y = -1, -1 # Initial click coordinates + + def mouse_callback(event, x, y, flags, param): + nonlocal index_x, index_y, drawing, roi, working_img + + if event == cv2.EVENT_LBUTTONDOWN: + # Start drawing: record starting coordinates + drawing = True + index_x, index_y = x, y + + elif event == cv2.EVENT_MOUSEMOVE: + if drawing: + # Compute the top-left and bottom-right corners regardless of drag direction + top = min(index_y, y) + left = min(index_x, x) + bottom = max(index_y, y) + right = max(index_x, x) + # Show a temporary image with the current rectangle drawn + temp = working_img.copy() + cv2.rectangle(temp, (left, top), (right, bottom), (0, 255, 0), 2) + cv2.imshow("Select ROI", temp) + + elif event == cv2.EVENT_LBUTTONUP: + # Finish drawing + drawing = False + top = min(index_y, y) + left = min(index_x, x) + bottom = max(index_y, y) + right = max(index_x, x) + height = bottom - top + width = right - left + roi = (top, left, height, width) # (top, left, height, width) + # Draw the final rectangle on the working image and display it + working_img = clone.copy() + cv2.rectangle(working_img, (left, top), (right, bottom), (0, 255, 0), 2) + cv2.imshow("Select ROI", working_img) + + # Create the window and set the callback + cv2.namedWindow("Select ROI") + cv2.setMouseCallback("Select ROI", mouse_callback) + cv2.imshow("Select ROI", working_img) + + print("Instructions for ROI selection:") + print(" - Click and drag to draw a rectangular ROI.") + print(" - Press 'c' to confirm the selection.") + print(" - Press 'r' to reset and draw again.") + print(" - Press ESC to cancel the selection.") + + # Wait until the user confirms with 'c', resets with 'r', or cancels with ESC + while True: + key = cv2.waitKey(1) & 0xFF + # Confirm ROI if one has been drawn + if key == ord("c") and roi is not None: + break + # Reset: clear the ROI and restore the original image + elif key == ord("r"): + working_img = clone.copy() + roi = None + cv2.imshow("Select ROI", working_img) + # Cancel selection for this image + elif key == 27: # ESC key + roi = None + break + + cv2.destroyWindow("Select ROI") + return roi + + +def select_square_roi_for_images(images: dict) -> dict: + """ + For each image in the provided dictionary, open a window to allow the user + to select a rectangular ROI. Returns a dictionary mapping each key to a tuple + (top, left, height, width) representing the ROI. + + Parameters: + images (dict): Dictionary where keys are identifiers and values are OpenCV images. + + Returns: + dict: Mapping of image keys to the selected rectangular ROI. + """ + selected_rois = {} + + for key, img in images.items(): + if img is None: + print(f"Image for key '{key}' is None, skipping.") + continue + + print(f"\nSelect rectangular ROI for image with key: '{key}'") + roi = select_rect_roi(img) + + if roi is None: + print(f"No valid ROI selected for '{key}'.") + else: + selected_rois[key] = roi + print(f"ROI for '{key}': {roi}") + + return selected_rois + + +def get_image_from_lerobot_dataset(dataset: LeRobotDataset): + """ + Find the first row in the dataset and extract the image in order to be used for the crop. + """ + row = dataset[0] + image_dict = {} + for k in row: + if "image" in k: + image_dict[k] = deepcopy(row[k]) + return image_dict + + +def convert_lerobot_dataset_to_cropper_lerobot_dataset( + original_dataset: LeRobotDataset, + crop_params_dict: Dict[str, Tuple[int, int, int, int]], + new_repo_id: str, + new_dataset_root: str, + resize_size: Tuple[int, int] = (128, 128), + push_to_hub: bool = False, + task: str = "", +) -> LeRobotDataset: + """ + Converts an existing LeRobotDataset by iterating over its episodes and frames, + applying cropping and resizing to image observations, and saving a new dataset + with the transformed data. + + Args: + original_dataset (LeRobotDataset): The source dataset. + crop_params_dict (Dict[str, Tuple[int, int, int, int]]): + A dictionary mapping observation keys to crop parameters (top, left, height, width). + new_repo_id (str): Repository id for the new dataset. + new_dataset_root (str): The root directory where the new dataset will be written. + resize_size (Tuple[int, int], optional): The target size (height, width) after cropping. + Defaults to (128, 128). + + Returns: + LeRobotDataset: A new LeRobotDataset where the specified image observations have been cropped + and resized. + """ + # 1. Create a new (empty) LeRobotDataset for writing. + new_dataset = LeRobotDataset.create( + repo_id=new_repo_id, + fps=original_dataset.fps, + root=new_dataset_root, + robot_type=original_dataset.meta.robot_type, + features=original_dataset.meta.info["features"], + use_videos=len(original_dataset.meta.video_keys) > 0, + ) + + # Update the metadata for every image key that will be cropped: + # (Here we simply set the shape to be the final resize_size.) + for key in crop_params_dict: + if key in new_dataset.meta.info["features"]: + new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size) + + # TODO: Directly modify the mp4 video + meta info features, instead of recreating a dataset + prev_episode_index = 0 + for frame_idx in tqdm(range(len(original_dataset))): + frame = original_dataset[frame_idx] + + # Create a copy of the frame to add to the new dataset + new_frame = {} + for key, value in frame.items(): + if key in ("task_index", "timestamp", "episode_index", "frame_index", "index", "task"): + continue + if key in ("next.done", "next.reward"): + # if not isinstance(value, str) and len(value.shape) == 0: + value = value.unsqueeze(0) + + if key in crop_params_dict: + top, left, height, width = crop_params_dict[key] + # Apply crop then resize. + cropped = F.crop(value, top, left, height, width) + value = F.resize(cropped, resize_size) + value = value.clamp(0, 1) + + new_frame[key] = value + + new_dataset.add_frame(new_frame, task=task) + + if frame["episode_index"].item() != prev_episode_index: + # Save the episode + new_dataset.save_episode() + prev_episode_index = frame["episode_index"].item() + + # Save the last episode + new_dataset.save_episode() + + if push_to_hub: + new_dataset.push_to_hub() + + return new_dataset + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.") + parser.add_argument( + "--repo-id", + type=str, + default="lerobot", + help="The repository id of the LeRobot dataset to process.", + ) + parser.add_argument( + "--root", + type=str, + default=None, + help="The root directory of the LeRobot dataset.", + ) + parser.add_argument( + "--crop-params-path", + type=str, + default=None, + help="The path to the JSON file containing the ROIs.", + ) + parser.add_argument( + "--push-to-hub", + type=bool, + default=False, + help="Whether to push the new dataset to the hub.", + ) + parser.add_argument( + "--task", + type=str, + default="", + help="The natural language task to describe the dataset.", + ) + args = parser.parse_args() + + dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root) + + images = get_image_from_lerobot_dataset(dataset) + images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()} + images = {k: (v * 255).astype("uint8") for k, v in images.items()} + + if args.crop_params_path is None: + rois = select_square_roi_for_images(images) + else: + with open(args.crop_params_path) as f: + rois = json.load(f) + + # Print the selected rectangular ROIs + print("\nSelected Rectangular Regions of Interest (top, left, height, width):") + for key, roi in rois.items(): + print(f"{key}: {roi}") + + new_repo_id = args.repo_id + "_cropped_resized" + new_dataset_root = Path(str(dataset.root) + "_cropped_resized") + + cropped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset( + original_dataset=dataset, + crop_params_dict=rois, + new_repo_id=new_repo_id, + new_dataset_root=new_dataset_root, + resize_size=(128, 128), + push_to_hub=args.push_to_hub, + task=args.task, + ) + + meta_dir = new_dataset_root / "meta" + meta_dir.mkdir(exist_ok=True) + + with open(meta_dir / "crop_params.json", "w") as f: + json.dump(rois, f, indent=4) diff --git a/lerobot/scripts/rl/gym_manipulator.py b/lerobot/scripts/rl/gym_manipulator.py new file mode 100644 index 0000000000..98445e6668 --- /dev/null +++ b/lerobot/scripts/rl/gym_manipulator.py @@ -0,0 +1,2171 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Robot Environment for LeRobot Manipulation Tasks + +This module provides a comprehensive gym-compatible environment for robot manipulation +with support for: +- Multiple robot types (SO100, SO101, Koch and Moss) +- Human intervention via leader-follower control or gamepad + +- End-effector and joint space control +- Image processing (cropping and resizing) + +The environment is built using a composable wrapper pattern where each wrapper +adds specific functionality to the base RobotEnv. + +Example: + env = make_robot_env(cfg) + obs, info = env.reset() + action = policy.select_action(obs) + obs, reward, terminated, truncated, info = env.step(action) +""" + +import logging +import time +from collections import deque +from threading import Lock +from typing import Annotated, Any, Sequence + +import gymnasium as gym +import numpy as np +import torch +import torchvision.transforms.functional as F # noqa: N812 + +from lerobot.common.cameras import opencv # noqa: F401 +from lerobot.common.envs.configs import EnvConfig +from lerobot.common.envs.utils import preprocess_observation +from lerobot.common.model.kinematics import RobotKinematics +from lerobot.common.robots import ( # noqa: F401 + RobotConfig, + make_robot_from_config, + so100_follower, +) +from lerobot.common.teleoperators import ( + gamepad, # noqa: F401 + make_teleoperator_from_config, + so101_leader, # noqa: F401 +) +from lerobot.common.teleoperators.gamepad.teleop_gamepad import GamepadTeleop +from lerobot.common.utils.robot_utils import busy_wait +from lerobot.common.utils.utils import log_say +from lerobot.configs import parser + +logging.basicConfig(level=logging.INFO) + + +def reset_follower_position(robot_arm, target_position): + current_position_dict = robot_arm.bus.sync_read("Present_Position") + current_position = np.array( + [current_position_dict[name] for name in current_position_dict], dtype=np.float32 + ) + trajectory = torch.from_numpy( + np.linspace(current_position, target_position, 50) + ) # NOTE: 30 is just an arbitrary number + for pose in trajectory: + action_dict = dict(zip(current_position_dict, pose, strict=False)) + robot_arm.bus.sync_write("Goal_Position", action_dict) + busy_wait(0.015) + + +class TorchBox(gym.spaces.Box): + """ + A version of gym.spaces.Box that handles PyTorch tensors. + + This class extends gym.spaces.Box to work with PyTorch tensors, + providing compatibility between NumPy arrays and PyTorch tensors. + """ + + def __init__( + self, + low: float | Sequence[float] | np.ndarray, + high: float | Sequence[float] | np.ndarray, + shape: Sequence[int] | None = None, + np_dtype: np.dtype | type = np.float32, + torch_dtype: torch.dtype = torch.float32, + device: str = "cpu", + seed: int | np.random.Generator | None = None, + ) -> None: + """ + Initialize the PyTorch-compatible Box space. + + Args: + low: Lower bounds of the space. + high: Upper bounds of the space. + shape: Shape of the space. If None, inferred from low and high. + np_dtype: NumPy data type for internal storage. + torch_dtype: PyTorch data type for tensor conversion. + device: PyTorch device for returned tensors. + seed: Random seed for sampling. + """ + super().__init__(low, high, shape=shape, dtype=np_dtype, seed=seed) + self.torch_dtype = torch_dtype + self.device = device + + def sample(self) -> torch.Tensor: + """ + Sample a random point from the space. + + Returns: + A PyTorch tensor within the space bounds. + """ + arr = super().sample() + return torch.as_tensor(arr, dtype=self.torch_dtype, device=self.device) + + def contains(self, x: torch.Tensor) -> bool: + """ + Check if a tensor is within the space bounds. + + Args: + x: The PyTorch tensor to check. + + Returns: + Boolean indicating whether the tensor is within bounds. + """ + # Move to CPU/numpy and cast to the internal dtype + arr = x.detach().cpu().numpy().astype(self.dtype, copy=False) + return super().contains(arr) + + def seed(self, seed: int | np.random.Generator | None = None): + """ + Set the random seed for sampling. + + Args: + seed: The random seed to use. + + Returns: + List containing the seed. + """ + super().seed(seed) + return [seed] + + def __repr__(self) -> str: + """ + Return a string representation of the space. + + Returns: + Formatted string with space details. + """ + return ( + f"TorchBox({self.low_repr}, {self.high_repr}, {self.shape}, " + f"np={self.dtype.name}, torch={self.torch_dtype}, device={self.device})" + ) + + +class TorchActionWrapper(gym.Wrapper): + """ + Wrapper that changes the action space to use PyTorch tensors. + + This wrapper modifies the action space to return PyTorch tensors when sampled + and handles converting PyTorch actions to NumPy when stepping the environment. + """ + + def __init__(self, env: gym.Env, device: str): + """ + Initialize the PyTorch action space wrapper. + + Args: + env: The environment to wrap. + device: The PyTorch device to use for tensor operations. + """ + super().__init__(env) + self.action_space = TorchBox( + low=env.action_space.low, + high=env.action_space.high, + shape=env.action_space.shape, + torch_dtype=torch.float32, + device=torch.device("cpu"), + ) + + def step(self, action: torch.Tensor): + """ + Step the environment with a PyTorch tensor action. + + This method handles conversion from PyTorch tensors to NumPy arrays + for compatibility with the underlying environment. + + Args: + action: PyTorch tensor action to take. + + Returns: + Tuple of (observation, reward, terminated, truncated, info). + """ + if action.dim() == 2: + action = action.squeeze(0) + action = action.detach().cpu().numpy() + return self.env.step(action) + + +class RobotEnv(gym.Env): + """ + Gym-compatible environment for evaluating robotic control policies with integrated human intervention. + + This environment wraps a robot interface to provide a consistent API for policy evaluation. It supports both relative (delta) + and absolute joint position commands and automatically configures its observation and action spaces based on the robot's + sensors and configuration. + """ + + def __init__( + self, + robot, + use_gripper: bool = False, + display_cameras: bool = False, + ): + """ + Initialize the RobotEnv environment. + + The environment is set up with a robot interface, which is used to capture observations and send joint commands. The setup + supports both relative (delta) adjustments and absolute joint positions for controlling the robot. + + Args: + robot: The robot interface object used to connect and interact with the physical robot. + display_cameras: If True, the robot's camera feeds will be displayed during execution. + """ + super().__init__() + + self.robot = robot + self.display_cameras = display_cameras + + # Connect to the robot if not already connected. + if not self.robot.is_connected: + self.robot.connect() + + # Episode tracking. + self.current_step = 0 + self.episode_data = None + + self._joint_names = [f"{key}.pos" for key in self.robot.bus.motors] + self._image_keys = self.robot.cameras.keys() + + # Read initial joint positions using the bus + self.current_joint_positions = self._get_observation()["agent_pos"] + + self.use_gripper = use_gripper + + self._setup_spaces() + + def _get_observation(self) -> np.ndarray: + """Helper to convert a dictionary from bus.sync_read to an ordered numpy array.""" + obs_dict = self.robot.get_observation() + joint_positions = np.array([obs_dict[name] for name in self._joint_names], dtype=np.float32) + + images = {key: obs_dict[key] for key in self._image_keys} + return {"agent_pos": joint_positions, "pixels": images} + + def _setup_spaces(self): + """ + Dynamically configure the observation and action spaces based on the robot's capabilities. + + Observation Space: + - For keys with "image": A Box space with pixel values ranging from 0 to 255. + - For non-image keys: A nested Dict space is created under 'observation.state' with a suitable range. + + Action Space: + - The action space is defined as a Box space representing joint position commands. It is defined as relative (delta) + or absolute, based on the configuration. + """ + example_obs = self._get_observation() + + observation_spaces = {} + + # Define observation spaces for images and other states. + if "pixels" in example_obs: + prefix = "observation.images" if len(example_obs["pixels"]) > 1 else "observation.image" + observation_spaces = { + f"{prefix}.{key}": gym.spaces.Box( + low=0, high=255, shape=example_obs["pixels"][key].shape, dtype=np.uint8 + ) + for key in example_obs["pixels"] + } + + observation_spaces["observation.state"] = gym.spaces.Box( + low=0, + high=10, + shape=example_obs["agent_pos"].shape, + dtype=np.float32, + ) + + self.observation_space = gym.spaces.Dict(observation_spaces) + + # Define the action space for joint positions along with setting an intervention flag. + action_dim = 3 + bounds = {} + bounds["min"] = -np.ones(action_dim) + bounds["max"] = np.ones(action_dim) + + if self.use_gripper: + action_dim += 1 + bounds["min"] = np.concatenate([bounds["min"], [0]]) + bounds["max"] = np.concatenate([bounds["max"], [2]]) + + self.action_space = gym.spaces.Box( + low=bounds["min"], + high=bounds["max"], + shape=(action_dim,), + dtype=np.float32, + ) + + def reset(self, seed=None, options=None) -> tuple[dict[str, np.ndarray], dict[str, Any]]: + """ + Reset the environment to its initial state. + This method resets the step counter and clears any episodic data. + + Args: + seed: A seed for random number generation to ensure reproducibility. + options: Additional options to influence the reset behavior. + + Returns: + A tuple containing: + - observation (dict): The initial sensor observation. + - info (dict): A dictionary with supplementary information, including the key "is_intervention". + """ + super().reset(seed=seed, options=options) + + self.robot.reset() + + # Capture the initial observation. + observation = self._get_observation() + + # Reset episode tracking variables. + self.current_step = 0 + self.episode_data = None + + return observation, {"is_intervention": False} + + def step(self, action) -> tuple[dict[str, np.ndarray], float, bool, bool, dict[str, Any]]: + """ + Execute a single step within the environment using the specified action. + + The provided action is processed and sent to the robot as joint position commands + that may be either absolute values or deltas based on the environment configuration. + + Args: + action: The commanded joint positions as a numpy array or torch tensor. + + Returns: + A tuple containing: + - observation (dict): The new sensor observation after taking the step. + - reward (float): The step reward (default is 0.0 within this wrapper). + - terminated (bool): True if the episode has reached a terminal state. + - truncated (bool): True if the episode was truncated (e.g., time constraints). + - info (dict): Additional debugging information including intervention status. + """ + self.current_joint_positions = self._get_observation()["agent_pos"] + + action_dict = {"delta_x": action[0], "delta_y": action[1], "delta_z": action[2]} + + # 1.0 action corresponds to no-op action + action_dict["gripper"] = action[3] if self.use_gripper else 1.0 + + self.robot.send_action(action_dict) + + if self.display_cameras: + self.render() + + self.current_step += 1 + + reward = 0.0 + terminated = False + truncated = False + + return ( + self._get_observation(), + reward, + terminated, + truncated, + {"is_intervention": False}, + ) + + def render(self): + """ + Render the current state of the environment by displaying the robot's camera feeds. + """ + import cv2 + + observation = self._get_observation() + image_keys = [key for key in observation if "image" in key] + + for key in image_keys: + cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) + cv2.waitKey(1) + + def close(self): + """ + Close the environment and clean up resources by disconnecting the robot. + + If the robot is currently connected, this method properly terminates the connection to ensure that all + associated resources are released. + """ + if self.robot.is_connected: + self.robot.disconnect() + + +class AddJointVelocityToObservation(gym.ObservationWrapper): + """ + Wrapper that adds joint velocity information to the observation. + + This wrapper computes joint velocities by tracking changes in joint positions over time, + and extends the observation space to include these velocities. + """ + + def __init__(self, env, joint_velocity_limits=100.0, fps=30, num_dof=6): + """ + Initialize the joint velocity wrapper. + + Args: + env: The environment to wrap. + joint_velocity_limits: Maximum expected joint velocity for space bounds. + fps: Frames per second used to calculate velocity (position delta / time). + num_dof: Number of degrees of freedom (joints) in the robot. + """ + super().__init__(env) + + # Extend observation space to include joint velocities + old_low = self.observation_space["observation.state"].low + old_high = self.observation_space["observation.state"].high + old_shape = self.observation_space["observation.state"].shape + + self.last_joint_positions = np.zeros(num_dof) + + new_low = np.concatenate([old_low, np.ones(num_dof) * -joint_velocity_limits]) + new_high = np.concatenate([old_high, np.ones(num_dof) * joint_velocity_limits]) + + new_shape = (old_shape[0] + num_dof,) + + self.observation_space["observation.state"] = gym.spaces.Box( + low=new_low, + high=new_high, + shape=new_shape, + dtype=np.float32, + ) + + self.dt = 1.0 / fps + + def observation(self, observation): + """ + Add joint velocity information to the observation. + + Args: + observation: The original observation from the environment. + + Returns: + The modified observation with joint velocities. + """ + joint_velocities = (observation["agent_pos"] - self.last_joint_positions) / self.dt + self.last_joint_positions = observation["agent_pos"] + observation["agent_pos"] = np.concatenate([observation["agent_pos"], joint_velocities], axis=-1) + return observation + + +class AddCurrentToObservation(gym.ObservationWrapper): + """ + Wrapper that adds motor current information to the observation. + + This wrapper extends the observation space to include the current values + from each motor, providing information about the forces being applied. + """ + + def __init__(self, env, max_current=500, num_dof=6): + """ + Initialize the current observation wrapper. + + Args: + env: The environment to wrap. + max_current: Maximum expected current for space bounds. + num_dof: Number of degrees of freedom (joints) in the robot. + """ + super().__init__(env) + + # Extend observation space to include joint velocities + old_low = self.observation_space["observation.state"].low + old_high = self.observation_space["observation.state"].high + old_shape = self.observation_space["observation.state"].shape + + new_low = np.concatenate([old_low, np.zeros(num_dof)]) + new_high = np.concatenate([old_high, np.ones(num_dof) * max_current]) + + new_shape = (old_shape[0] + num_dof,) + + self.observation_space["observation.state"] = gym.spaces.Box( + low=new_low, + high=new_high, + shape=new_shape, + dtype=np.float32, + ) + + def observation(self, observation): + """ + Add current information to the observation. + + Args: + observation: The original observation from the environment. + + Returns: + The modified observation with current values. + """ + present_current_observation = self.unwrapped._get_observation()["agent_pos"] + observation["agent_pos"] = np.concatenate( + [observation["agent_pos"], present_current_observation], axis=-1 + ) + return observation + + +class RewardWrapper(gym.Wrapper): + def __init__(self, env, reward_classifier, device="cuda"): + """ + Wrapper to add reward prediction to the environment using a trained classifier. + + Args: + env: The environment to wrap. + reward_classifier: The reward classifier model. + device: The device to run the model on. + """ + self.env = env + + self.device = device + + self.reward_classifier = torch.compile(reward_classifier) + self.reward_classifier.to(self.device) + + def step(self, action): + """ + Execute a step and compute the reward using the classifier. + + Args: + action: The action to take in the environment. + + Returns: + Tuple of (observation, reward, terminated, truncated, info). + """ + observation, _, terminated, truncated, info = self.env.step(action) + + images = {} + for key in observation: + if "image" in key: + images[key] = observation[key].to(self.device, non_blocking=(self.device == "cuda")) + if images[key].dim() == 3: + images[key] = images[key].unsqueeze(0) + + start_time = time.perf_counter() + with torch.inference_mode(): + success = ( + self.reward_classifier.predict_reward(images, threshold=0.7) + if self.reward_classifier is not None + else 0.0 + ) + info["Reward classifier frequency"] = 1 / (time.perf_counter() - start_time) + + reward = 0.0 + if success == 1.0: + terminated = True + reward = 1.0 + + return observation, reward, terminated, truncated, info + + def reset(self, seed=None, options=None): + """ + Reset the environment. + + Args: + seed: Random seed for reproducibility. + options: Additional reset options. + + Returns: + The initial observation and info from the wrapped environment. + """ + return self.env.reset(seed=seed, options=options) + + +class TimeLimitWrapper(gym.Wrapper): + """ + Wrapper that adds a time limit to episodes and tracks execution time. + + This wrapper terminates episodes after a specified time has elapsed, providing + better control over episode length. + """ + + def __init__(self, env, control_time_s, fps): + """ + Initialize the time limit wrapper. + + Args: + env: The environment to wrap. + control_time_s: Maximum episode duration in seconds. + fps: Frames per second for calculating the maximum number of steps. + """ + self.env = env + self.control_time_s = control_time_s + self.fps = fps + + self.last_timestamp = 0.0 + self.episode_time_in_s = 0.0 + + self.max_episode_steps = int(self.control_time_s * self.fps) + + self.current_step = 0 + + def step(self, action): + """ + Step the environment and track time elapsed. + + Args: + action: The action to take in the environment. + + Returns: + Tuple of (observation, reward, terminated, truncated, info). + """ + obs, reward, terminated, truncated, info = self.env.step(action) + time_since_last_step = time.perf_counter() - self.last_timestamp + self.episode_time_in_s += time_since_last_step + self.last_timestamp = time.perf_counter() + self.current_step += 1 + # check if last timestep took more time than the expected fps + if 1.0 / time_since_last_step < self.fps: + logging.debug(f"Current timestep exceeded expected fps {self.fps}") + + if self.current_step >= self.max_episode_steps: + terminated = True + return obs, reward, terminated, truncated, info + + def reset(self, seed=None, options=None): + """ + Reset the environment and time tracking. + + Args: + seed: Random seed for reproducibility. + options: Additional reset options. + + Returns: + The initial observation and info from the wrapped environment. + """ + self.episode_time_in_s = 0.0 + self.last_timestamp = time.perf_counter() + self.current_step = 0 + return self.env.reset(seed=seed, options=options) + + +class ImageCropResizeWrapper(gym.Wrapper): + """ + Wrapper that crops and resizes image observations. + + This wrapper processes image observations to focus on relevant regions by + cropping and then resizing to a standard size. + """ + + def __init__( + self, + env, + crop_params_dict: dict[str, Annotated[tuple[int], 4]], + resize_size=None, + ): + """ + Initialize the image crop and resize wrapper. + + Args: + env: The environment to wrap. + crop_params_dict: Dictionary mapping image observation keys to crop parameters + (top, left, height, width). + resize_size: Target size for resized images (height, width). Defaults to (128, 128). + """ + super().__init__(env) + self.env = env + self.crop_params_dict = crop_params_dict + print(f"obs_keys , {self.env.observation_space}") + print(f"crop params dict {crop_params_dict.keys()}") + for key_crop in crop_params_dict: + if key_crop not in self.env.observation_space.keys(): # noqa: SIM118 + raise ValueError(f"Key {key_crop} not in observation space") + for key in crop_params_dict: + new_shape = (3, resize_size[0], resize_size[1]) + self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape) + + self.resize_size = resize_size + if self.resize_size is None: + self.resize_size = (128, 128) + + def step(self, action): + """ + Step the environment and process image observations. + + Args: + action: The action to take in the environment. + + Returns: + Tuple of (observation, reward, terminated, truncated, info) with processed images. + """ + obs, reward, terminated, truncated, info = self.env.step(action) + for k in self.crop_params_dict: + device = obs[k].device + if obs[k].dim() >= 3: + # Reshape to combine height and width dimensions for easier calculation + batch_size = obs[k].size(0) + channels = obs[k].size(1) + flattened_spatial_dims = obs[k].view(batch_size, channels, -1) + + # Calculate standard deviation across spatial dimensions (H, W) + # If any channel has std=0, all pixels in that channel have the same value + # This is helpful if one camera mistakenly covered or the image is black + std_per_channel = torch.std(flattened_spatial_dims, dim=2) + if (std_per_channel <= 0.02).any(): + logging.warning( + f"Potential hardware issue detected: All pixels have the same value in observation {k}" + ) + + if device == torch.device("mps:0"): + obs[k] = obs[k].cpu() + + obs[k] = F.crop(obs[k], *self.crop_params_dict[k]) + obs[k] = F.resize(obs[k], self.resize_size) + # TODO (michel-aractingi): Bug in resize, it returns values outside [0, 1] + obs[k] = obs[k].clamp(0.0, 1.0) + obs[k] = obs[k].to(device) + + return obs, reward, terminated, truncated, info + + def reset(self, seed=None, options=None): + """ + Reset the environment and process image observations. + + Args: + seed: Random seed for reproducibility. + options: Additional reset options. + + Returns: + Tuple of (observation, info) with processed images. + """ + obs, info = self.env.reset(seed=seed, options=options) + for k in self.crop_params_dict: + device = obs[k].device + if device == torch.device("mps:0"): + obs[k] = obs[k].cpu() + obs[k] = F.crop(obs[k], *self.crop_params_dict[k]) + obs[k] = F.resize(obs[k], self.resize_size) + obs[k] = obs[k].clamp(0.0, 1.0) + obs[k] = obs[k].to(device) + return obs, info + + +class ConvertToLeRobotObservation(gym.ObservationWrapper): + """ + Wrapper that converts standard observations to LeRobot format. + + This wrapper processes observations to match the expected format for LeRobot, + including normalizing image values and moving tensors to the specified device. + """ + + def __init__(self, env, device: str = "cpu"): + """ + Initialize the LeRobot observation converter. + + Args: + env: The environment to wrap. + device: Target device for the observation tensors. + """ + super().__init__(env) + + self.device = torch.device(device) + + def observation(self, observation): + """ + Convert observations to LeRobot format. + + Args: + observation: The original observation from the environment. + + Returns: + The processed observation with normalized images and proper tensor formats. + """ + observation = preprocess_observation(observation) + observation = { + key: observation[key].to(self.device, non_blocking=self.device.type == "cuda") + for key in observation + } + return observation + + +class ResetWrapper(gym.Wrapper): + """ + Wrapper that handles environment reset procedures. + + This wrapper provides additional functionality during environment reset, + including the option to reset to a fixed pose or allow manual reset. + """ + + def __init__( + self, + env: RobotEnv, + reset_pose: np.ndarray | None = None, + reset_time_s: float = 5, + ): + """ + Initialize the reset wrapper. + + Args: + env: The environment to wrap. + reset_pose: Fixed joint positions to reset to. If None, manual reset is used. + reset_time_s: Time in seconds to wait after reset or allowed for manual reset. + """ + super().__init__(env) + self.reset_time_s = reset_time_s + self.reset_pose = reset_pose + self.robot = self.unwrapped.robot + + def reset(self, *, seed=None, options=None): + """ + Reset the environment with either fixed or manual reset procedure. + + If reset_pose is provided, the robot will move to that position. + Otherwise, manual teleoperation control is allowed for reset_time_s seconds. + + Args: + seed: Random seed for reproducibility. + options: Additional reset options. + + Returns: + The initial observation and info from the wrapped environment. + """ + start_time = time.perf_counter() + if self.reset_pose is not None: + log_say("Reset the environment.", play_sounds=True) + reset_follower_position(self.unwrapped.robot, self.reset_pose) + log_say("Reset the environment done.", play_sounds=True) + + if hasattr(self.env, "robot_leader"): + self.env.robot_leader.bus.sync_write("Torque_Enable", 1) + log_say("Reset the leader robot.", play_sounds=True) + reset_follower_position(self.env.robot_leader, self.reset_pose) + log_say("Reset the leader robot done.", play_sounds=True) + else: + log_say( + f"Manually reset the environment for {self.reset_time_s} seconds.", + play_sounds=True, + ) + start_time = time.perf_counter() + while time.perf_counter() - start_time < self.reset_time_s: + action = self.env.robot_leader.get_action() + self.unwrapped.robot.send_action(action) + + log_say("Manual reset of the environment done.", play_sounds=True) + + busy_wait(self.reset_time_s - (time.perf_counter() - start_time)) + + return super().reset(seed=seed, options=options) + + +class BatchCompatibleWrapper(gym.ObservationWrapper): + """ + Wrapper that ensures observations are compatible with batch processing. + + This wrapper adds a batch dimension to observations that don't already have one, + making them compatible with models that expect batched inputs. + """ + + def __init__(self, env): + """ + Initialize the batch compatibility wrapper. + + Args: + env: The environment to wrap. + """ + super().__init__(env) + + def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Add batch dimensions to observations if needed. + + Args: + observation: Dictionary of observation tensors. + + Returns: + Dictionary of observation tensors with batch dimensions. + """ + for key in observation: + if "image" in key and observation[key].dim() == 3: + observation[key] = observation[key].unsqueeze(0) + if "state" in key and observation[key].dim() == 1: + observation[key] = observation[key].unsqueeze(0) + if "velocity" in key and observation[key].dim() == 1: + observation[key] = observation[key].unsqueeze(0) + return observation + + +class GripperPenaltyWrapper(gym.RewardWrapper): + """ + Wrapper that adds penalties for inefficient gripper commands. + + This wrapper modifies rewards to discourage excessive gripper movement + or commands that attempt to move the gripper beyond its physical limits. + """ + + def __init__(self, env, penalty: float = -0.1): + """ + Initialize the gripper penalty wrapper. + + Args: + env: The environment to wrap. + penalty: Negative reward value to apply for inefficient gripper actions. + """ + super().__init__(env) + self.penalty = penalty + self.last_gripper_state = None + + def reward(self, reward, action): + """ + Apply penalties to reward based on gripper actions. + + Args: + reward: The original reward from the environment. + action: The action that was taken. + + Returns: + Modified reward with penalty applied if necessary. + """ + gripper_state_normalized = self.last_gripper_state / self.unwrapped.robot.config.max_gripper_pos + + action_normalized = action - 1.0 # action / MAX_GRIPPER_COMMAND + + gripper_penalty_bool = (gripper_state_normalized < 0.5 and action_normalized > 0.5) or ( + gripper_state_normalized > 0.75 and action_normalized < -0.5 + ) + + return reward + self.penalty * int(gripper_penalty_bool) + + def step(self, action): + """ + Step the environment and apply gripper penalties. + + Args: + action: The action to take in the environment. + + Returns: + Tuple of (observation, reward, terminated, truncated, info) with penalty applied. + """ + self.last_gripper_state = self.unwrapped.robot.bus.sync_read("Present_Position")["gripper"] + + gripper_action = action[-1] + obs, reward, terminated, truncated, info = self.env.step(action) + gripper_penalty = self.reward(reward, gripper_action) + + info["discrete_penalty"] = gripper_penalty + + return obs, reward, terminated, truncated, info + + def reset(self, **kwargs): + """ + Reset the environment and penalty tracking. + + Args: + **kwargs: Keyword arguments passed to the wrapped environment's reset. + + Returns: + The initial observation and info with gripper penalty initialized. + """ + self.last_gripper_state = None + obs, info = super().reset(**kwargs) + info["gripper_penalty"] = 0.0 + return obs, info + + +class GripperActionWrapper(gym.ActionWrapper): + """ + Wrapper that processes gripper control commands. + + This wrapper quantizes and processes gripper commands, adding a sleep time between + consecutive gripper actions to prevent rapid toggling. + """ + + def __init__(self, env, quantization_threshold: float = 0.2, gripper_sleep: float = 0.0): + """ + Initialize the gripper action wrapper. + + Args: + env: The environment to wrap. + quantization_threshold: Threshold below which gripper commands are quantized to zero. + gripper_sleep: Minimum time in seconds between consecutive gripper commands. + """ + super().__init__(env) + self.quantization_threshold = quantization_threshold + self.gripper_sleep = gripper_sleep + self.last_gripper_action_time = 0.0 + self.last_gripper_action = None + + def action(self, action): + """ + Process gripper commands in the action. + + Args: + action: The original action from the agent. + + Returns: + Modified action with processed gripper command. + """ + if self.gripper_sleep > 0.0: + if ( + self.last_gripper_action is not None + and time.perf_counter() - self.last_gripper_action_time < self.gripper_sleep + ): + action[-1] = self.last_gripper_action + else: + self.last_gripper_action_time = time.perf_counter() + self.last_gripper_action = action[-1] + + gripper_command = action[-1] + # Gripper actions are between 0, 2 + # we want to quantize them to -1, 0 or 1 + gripper_command = gripper_command - 1.0 + + if self.quantization_threshold is not None: + # Quantize gripper command to -1, 0 or 1 + gripper_command = ( + np.sign(gripper_command) if abs(gripper_command) > self.quantization_threshold else 0.0 + ) + gripper_command = gripper_command * self.unwrapped.robot.config.max_gripper_pos + + gripper_state = self.unwrapped.robot.bus.sync_read("Present_Position")["gripper"] + + gripper_action_value = np.clip( + gripper_state + gripper_command, 0, self.unwrapped.robot.config.max_gripper_pos + ) + action[-1] = gripper_action_value.item() + return action + + def reset(self, **kwargs): + """ + Reset the gripper action tracking. + + Args: + **kwargs: Keyword arguments passed to the wrapped environment's reset. + + Returns: + The initial observation and info. + """ + obs, info = super().reset(**kwargs) + self.last_gripper_action_time = 0.0 + self.last_gripper_action = None + return obs, info + + +class EEObservationWrapper(gym.ObservationWrapper): + """ + Wrapper that adds end-effector pose information to observations. + + This wrapper computes the end-effector pose using forward kinematics + and adds it to the observation space. + """ + + def __init__(self, env, ee_pose_limits): + """ + Initialize the end-effector observation wrapper. + + Args: + env: The environment to wrap. + ee_pose_limits: Dictionary with 'min' and 'max' keys containing limits for EE pose. + """ + super().__init__(env) + + # Extend observation space to include end effector pose + prev_space = self.observation_space["observation.state"] + + self.observation_space["observation.state"] = gym.spaces.Box( + low=np.concatenate([prev_space.low, ee_pose_limits["min"]]), + high=np.concatenate([prev_space.high, ee_pose_limits["max"]]), + shape=(prev_space.shape[0] + 3,), + dtype=np.float32, + ) + + # Initialize kinematics instance for the appropriate robot type + robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so101") + if "so100" in robot_type or "so101" in robot_type: + # Note to be compatible with the rest of the codebase, + # we are using the new calibration method for so101 and so100 + robot_type = "so_new_calibration" + self.kinematics = RobotKinematics(robot_type) + + def observation(self, observation): + """ + Add end-effector pose to the observation. + + Args: + observation: Original observation from the environment. + + Returns: + Enhanced observation with end-effector pose information. + """ + current_joint_pos = self.unwrapped._get_observation()["agent_pos"] + + current_ee_pos = self.kinematics.forward_kinematics(current_joint_pos, frame="gripper_tip")[:3, 3] + observation["agent_pos"] = np.concatenate([observation["agent_pos"], current_ee_pos], -1) + return observation + + +########################################################### +# Wrappers related to human intervention and input devices +########################################################### + + +class BaseLeaderControlWrapper(gym.Wrapper): + """ + Base class for leader-follower robot control wrappers. + + This wrapper enables human intervention through a leader-follower robot setup, + where the human can control a leader robot to guide the follower robot's movements. + """ + + def __init__( + self, + env, + teleop_device, + end_effector_step_sizes, + use_geared_leader_arm: bool = False, + use_gripper=False, + ): + """ + Initialize the base leader control wrapper. + + Args: + env: The environment to wrap. + teleop_device: The teleoperation device. + use_geared_leader_arm: Whether to use a geared leader arm setup. + use_gripper: Whether to include gripper control. + """ + super().__init__(env) + self.robot_leader = teleop_device + self.robot_follower = env.unwrapped.robot + self.use_geared_leader_arm = use_geared_leader_arm + self.use_gripper: bool = use_gripper + self.end_effector_step_sizes = np.array(list(end_effector_step_sizes.values())) + + # Set up keyboard event tracking + self._init_keyboard_events() + self.event_lock = Lock() # Thread-safe access to events + + # Initialize robot control + robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so101") + if "so100" in robot_type or "so101" in robot_type: + # Note to be compatible with the rest of the codebase, + # we are using the new calibration method for so101 and so100 + robot_type = "so_new_calibration" + self.kinematics = RobotKinematics(robot_type) + self.leader_torque_enabled = True + self.prev_leader_gripper = None + + # Configure leader arm + # NOTE: Lower the gains of leader arm for automatic take-over + # With lower gains we can manually move the leader arm without risk of injury to ourselves or the robot + # With higher gains, it would be dangerous and difficult to modify the leader's pose while torque is enabled + # Default value for P_coeff is 32 + self.robot_leader.bus.sync_write("Torque_Enable", 1) + for motor in self.robot_leader.bus.motors: + self.robot_leader.bus.write("P_Coefficient", motor, 16) + self.robot_leader.bus.write("I_Coefficient", motor, 0) + self.robot_leader.bus.write("D_Coefficient", motor, 16) + + self.leader_tracking_error_queue = deque(maxlen=4) + self._init_keyboard_listener() + + def _init_keyboard_events(self): + """ + Initialize the keyboard events dictionary. + + This method sets up tracking for keyboard events used for intervention control. + It should be overridden in subclasses to add additional events. + """ + self.keyboard_events = { + "episode_success": False, + "episode_end": False, + "rerecord_episode": False, + } + + def _handle_key_press(self, key, keyboard): + """ + Handle key press events. + + Args: + key: The key that was pressed. + keyboard: The keyboard module with key definitions. + + This method should be overridden in subclasses for additional key handling. + """ + try: + if key == keyboard.Key.esc: + self.keyboard_events["episode_end"] = True + return + if key == keyboard.Key.left: + self.keyboard_events["rerecord_episode"] = True + return + if hasattr(key, "char") and key.char == "s": + logging.info("Key 's' pressed. Episode success triggered.") + self.keyboard_events["episode_success"] = True + return + except Exception as e: + logging.error(f"Error handling key press: {e}") + + def _init_keyboard_listener(self): + """ + Initialize the keyboard listener for intervention control. + + This method sets up keyboard event handling if not in headless mode. + """ + from pynput import keyboard + + def on_press(key): + with self.event_lock: + self._handle_key_press(key, keyboard) + + self.listener = keyboard.Listener(on_press=on_press) + self.listener.start() + + def _check_intervention(self): + """ + Check if human intervention is needed. + + Returns: + Boolean indicating whether intervention is needed. + + This method should be overridden in subclasses with specific intervention logic. + """ + return False + + def _handle_intervention(self, action): + """ + Process actions during intervention mode. + + Args: + action: The original action from the agent. + + Returns: + Tuple of (modified_action, intervention_action). + """ + if self.leader_torque_enabled: + self.robot_leader.bus.sync_write("Torque_Enable", 0) + self.leader_torque_enabled = False + + leader_pos_dict = self.robot_leader.bus.sync_read("Present_Position") + follower_pos_dict = self.robot_follower.bus.sync_read("Present_Position") + + leader_pos = np.array([leader_pos_dict[name] for name in leader_pos_dict], dtype=np.float32) + follower_pos = np.array([follower_pos_dict[name] for name in follower_pos_dict], dtype=np.float32) + + self.leader_tracking_error_queue.append(np.linalg.norm(follower_pos[:-1] - leader_pos[:-1])) + + # [:3, 3] Last column of the transformation matrix corresponds to the xyz translation + leader_ee = self.kinematics.forward_kinematics(leader_pos, frame="gripper_tip")[:3, 3] + follower_ee = self.kinematics.forward_kinematics(follower_pos, frame="gripper_tip")[:3, 3] + + action = np.clip(leader_ee - follower_ee, -self.end_effector_step_sizes, self.end_effector_step_sizes) + # Normalize the action to the range [-1, 1] + action = action / self.end_effector_step_sizes + + if self.use_gripper: + if self.prev_leader_gripper is None: + self.prev_leader_gripper = np.clip( + leader_pos[-1], 0, self.robot_follower.config.max_gripper_pos + ) + + # Get gripper action delta based on leader pose + leader_gripper = leader_pos[-1] + gripper_delta = leader_gripper - self.prev_leader_gripper + + # Normalize by max angle and quantize to {0,1,2} + normalized_delta = gripper_delta / self.robot_follower.config.max_gripper_pos + if normalized_delta >= 0.3: + gripper_action = 2 + elif normalized_delta <= 0.1: + gripper_action = 0 + else: + gripper_action = 1 + + action = np.append(action, gripper_action) + + return action + + def _handle_leader_teleoperation(self): + """ + Handle leader teleoperation in non-intervention mode. + + This method synchronizes the leader robot position with the follower. + """ + + prev_leader_pos_dict = self.robot_leader.bus.sync_read("Present_Position") + prev_leader_pos = np.array( + [prev_leader_pos_dict[name] for name in prev_leader_pos_dict], dtype=np.float32 + ) + + if not self.leader_torque_enabled: + self.robot_leader.bus.sync_write("Torque_Enable", 1) + self.leader_torque_enabled = True + + follower_pos_dict = self.robot_follower.bus.sync_read("Present_Position") + follower_pos = np.array([follower_pos_dict[name] for name in follower_pos_dict], dtype=np.float32) + + goal_pos = {f"{motor}": follower_pos[i] for i, motor in enumerate(self.robot_leader.bus.motors)} + self.robot_leader.bus.sync_write("Goal_Position", goal_pos) + + self.leader_tracking_error_queue.append(np.linalg.norm(follower_pos[:-1] - prev_leader_pos[:-1])) + + def step(self, action): + """ + Execute a step with possible human intervention. + + Args: + action: The action to take in the environment. + + Returns: + Tuple of (observation, reward, terminated, truncated, info). + """ + is_intervention = self._check_intervention() + + # NOTE: + if is_intervention: + action = self._handle_intervention(action) + else: + self._handle_leader_teleoperation() + + # NOTE: + obs, reward, terminated, truncated, info = self.env.step(action) + + # Add intervention info + info["is_intervention"] = is_intervention + info["action_intervention"] = action if is_intervention else None + + self.prev_leader_gripper = np.clip( + self.robot_leader.bus.sync_read("Present_Position")["gripper"], + 0, + self.robot_follower.config.max_gripper_pos, + ) + + # Check for success or manual termination + success = self.keyboard_events["episode_success"] + terminated = terminated or self.keyboard_events["episode_end"] or success + + if success: + reward = 1.0 + logging.info("Episode ended successfully with reward 1.0") + + return obs, reward, terminated, truncated, info + + def reset(self, **kwargs): + """ + Reset the environment and intervention state. + + Args: + **kwargs: Keyword arguments passed to the wrapped environment's reset. + + Returns: + The initial observation and info. + """ + self.keyboard_events = dict.fromkeys(self.keyboard_events, False) + self.leader_tracking_error_queue.clear() + return super().reset(**kwargs) + + def close(self): + """ + Clean up resources, including stopping keyboard listener. + + Returns: + Result of closing the wrapped environment. + """ + if hasattr(self, "listener") and self.listener is not None: + self.listener.stop() + return self.env.close() + + +class GearedLeaderControlWrapper(BaseLeaderControlWrapper): + """ + Wrapper that enables manual intervention via keyboard. + + This wrapper extends the BaseLeaderControlWrapper to allow explicit toggling + of human intervention mode with keyboard controls. + """ + + def _init_keyboard_events(self): + """ + Initialize keyboard events including human intervention flag. + + Extends the base class dictionary with an additional flag for tracking + intervention state toggled by keyboard. + """ + super()._init_keyboard_events() + self.keyboard_events["human_intervention_step"] = False + + def _handle_key_press(self, key, keyboard): + """ + Handle key presses including space for intervention toggle. + + Args: + key: The key that was pressed. + keyboard: The keyboard module with key definitions. + + Extends the base handler to respond to space key for toggling intervention. + """ + super()._handle_key_press(key, keyboard) + if key == keyboard.Key.space: + if not self.keyboard_events["human_intervention_step"]: + logging.info( + "Space key pressed. Human intervention required.\n" + "Place the leader in similar pose to the follower and press space again." + ) + self.keyboard_events["human_intervention_step"] = True + log_say("Human intervention step.", play_sounds=True) + else: + self.keyboard_events["human_intervention_step"] = False + logging.info("Space key pressed for a second time.\nContinuing with policy actions.") + log_say("Continuing with policy actions.", play_sounds=True) + + def _check_intervention(self): + """ + Check if human intervention is active based on keyboard toggle. + + Returns: + Boolean indicating whether intervention mode is active. + """ + return self.keyboard_events["human_intervention_step"] + + +class GearedLeaderAutomaticControlWrapper(BaseLeaderControlWrapper): + """ + Wrapper with automatic intervention based on error thresholds. + + This wrapper monitors the error between leader and follower positions + and automatically triggers intervention when error exceeds thresholds. + """ + + def __init__( + self, + env, + teleop_device, + end_effector_step_sizes, + use_gripper=False, + intervention_threshold=10.0, + release_threshold=1e-2, + ): + """ + Initialize the automatic intervention wrapper. + + Args: + env: The environment to wrap. + teleop_device: The teleoperation device. + use_gripper: Whether to include gripper control. + intervention_threshold: Error threshold to trigger intervention. + release_threshold: Error threshold to release intervention. + queue_size: Number of error measurements to track for smoothing. + """ + super().__init__(env, teleop_device, end_effector_step_sizes, use_gripper=use_gripper) + + # Error tracking parameters + self.intervention_threshold = intervention_threshold # Threshold to trigger intervention + self.release_threshold = release_threshold # Threshold to release intervention + self.is_intervention_active = False + self.start_time = time.perf_counter() + + def _check_intervention(self): + """ + Determine if intervention should occur based on the rate of change of leader-follower error in end_effector space. + + This method monitors the rate of change of leader-follower error in end_effector space + and automatically triggers intervention when the rate of change exceeds + the intervention threshold, releasing when it falls below the release threshold. + + Returns: + Boolean indicating whether intervention should be active. + """ + + # Condition for starting the intervention + # If the error in teleoperation is too high, that means the a user has grasped the leader robot and he wants to take over + if ( + not self.is_intervention_active + and len(self.leader_tracking_error_queue) == self.leader_tracking_error_queue.maxlen + and np.var(list(self.leader_tracking_error_queue)[-2:]) > self.intervention_threshold + ): + self.is_intervention_active = True + self.leader_tracking_error_queue.clear() + log_say("Intervention started", play_sounds=True) + return True + + # Track the error over time in leader_tracking_error_queue + # If the variance of the tracking error is too low, that means the user has let go of the leader robot and the intervention is over + if ( + self.is_intervention_active + and len(self.leader_tracking_error_queue) == self.leader_tracking_error_queue.maxlen + and np.var(self.leader_tracking_error_queue) < self.release_threshold + ): + self.is_intervention_active = False + self.leader_tracking_error_queue.clear() + log_say("Intervention ended", play_sounds=True) + return False + + # If not change has happened that merits a change in the intervention state, return the current state + return self.is_intervention_active + + def reset(self, **kwargs): + """ + Reset error tracking on environment reset. + + Args: + **kwargs: Keyword arguments passed to the wrapped environment's reset. + + Returns: + The initial observation and info. + """ + self.is_intervention_active = False + return super().reset(**kwargs) + + +class GamepadControlWrapper(gym.Wrapper): + """ + Wrapper that allows controlling a gym environment with a gamepad. + + This wrapper intercepts the step method and allows human input via gamepad + to override the agent's actions when desired. + """ + + def __init__( + self, + env, + teleop_device, # Accepts an instantiated teleoperator + use_gripper=False, # This should align with teleop_device's config + auto_reset=False, + ): + """ + Initialize the gamepad controller wrapper. + + Args: + env: The environment to wrap. + teleop_device: The instantiated teleoperation device (e.g., GamepadTeleop). + use_gripper: Whether to include gripper control (should match teleop_device.config.use_gripper). + auto_reset: Whether to auto reset the environment when episode ends. + """ + super().__init__(env) + + self.teleop_device = teleop_device + # Ensure the teleop_device is connected if it has a connect method + if hasattr(self.teleop_device, "connect") and not self.teleop_device.is_connected: + self.teleop_device.connect() + + # self.controller attribute is removed + + self.auto_reset = auto_reset + # use_gripper from args should ideally match teleop_device.config.use_gripper + # For now, we use the one passed, but it can lead to inconsistency if not set correctly from config + self.use_gripper = use_gripper + + logging.info("Gamepad control wrapper initialized with provided teleop_device.") + print( + "Gamepad controls (managed by the provided teleop_device - specific button mappings might vary):" + ) + print(" Left analog stick: Move in X-Y plane") + print(" Right analog stick: Move in Z axis (up/down)") + print(" X/Square button: End episode (FAILURE)") + print(" Y/Triangle button: End episode (SUCCESS)") + print(" B/Circle button: Exit program") + + def get_gamepad_action( + self, + ) -> tuple[bool, np.ndarray, bool, bool, bool]: + """ + Get the current action from the gamepad if any input is active. + + Returns: + Tuple containing: + - is_active: Whether gamepad input is active (from teleop_device.gamepad.should_intervene()) + - action: The action derived from gamepad input (from teleop_device.get_action()) + - terminate_episode: Whether episode termination was requested + - success: Whether episode success was signaled + - rerecord_episode: Whether episode rerecording was requested + """ + if not hasattr(self.teleop_device, "gamepad") or self.teleop_device.gamepad is None: + raise AttributeError( + "teleop_device does not have a 'gamepad' attribute or it is None. Expected for GamepadControlWrapper." + ) + + # Get status flags from the underlying gamepad controller within the teleop_device + self.teleop_device.gamepad.update() # Ensure gamepad state is fresh + intervention_is_active = self.teleop_device.gamepad.should_intervene() + episode_end_status = self.teleop_device.gamepad.get_episode_end_status() + + terminate_episode = episode_end_status is not None + success = episode_end_status == "success" + rerecord_episode = episode_end_status == "rerecord_episode" + + # Get the action dictionary from the teleop_device + action_dict = self.teleop_device.get_action() + + # Convert action_dict to numpy array based on expected structure + # Order: delta_x, delta_y, delta_z, gripper (if use_gripper) + action_list = [action_dict["delta_x"], action_dict["delta_y"], action_dict["delta_z"]] + if self.use_gripper: + # GamepadTeleop returns gripper action as 0 (close), 1 (stay), 2 (open) + # This needs to be consistent with what EEActionWrapper expects if it's used downstream + # EEActionWrapper for gripper typically expects 0.0 (closed) to 2.0 (open) + # For now, we pass the direct value from GamepadTeleop, ensure downstream compatibility. + gripper_val = action_dict.get("gripper", 1.0) # Default to 1.0 (stay) if not present + action_list.append(float(gripper_val)) + + gamepad_action_np = np.array(action_list, dtype=np.float32) + + return ( + intervention_is_active, + gamepad_action_np, + terminate_episode, + success, + rerecord_episode, + ) + + def step(self, action): + """ + Step the environment, using gamepad input to override actions when active. + + Args: + action: Original action from agent. + + Returns: + Tuple of (observation, reward, terminated, truncated, info). + """ + # Get gamepad state and action + ( + is_intervention, + gamepad_action, + terminate_episode, + success, + rerecord_episode, + ) = self.get_gamepad_action() + + # Update episode ending state if requested + if terminate_episode: + logging.info(f"Episode manually ended: {'SUCCESS' if success else 'FAILURE'}") + + # Only override the action if gamepad is active + action = gamepad_action if is_intervention else action + + # Step the environment + obs, reward, terminated, truncated, info = self.env.step(action) + + # Add episode ending if requested via gamepad + terminated = terminated or truncated or terminate_episode + + if success: + reward = 1.0 + logging.info("Episode ended successfully with reward 1.0") + + if isinstance(action, np.ndarray): + action = torch.from_numpy(action) + + info["is_intervention"] = is_intervention + # The original `BaseLeaderControlWrapper` puts `action_intervention` in info. + # For Gamepad, if intervention, `gamepad_action` is the intervention. + # If not intervention, policy's action is `action`. + # For consistency, let's store the *human's* action if intervention occurred. + info["action_intervention"] = action + + info["rerecord_episode"] = rerecord_episode + + # If episode ended, reset the state + if terminated or truncated: + # Add success/failure information to info dict + info["next.success"] = success + + # Auto reset if configured + if self.auto_reset: + obs, reset_info = self.reset() + info.update(reset_info) + + return obs, reward, terminated, truncated, info + + def close(self): + """ + Clean up resources when environment closes. + + Returns: + Result of closing the wrapped environment. + """ + if hasattr(self.teleop_device, "disconnect"): + self.teleop_device.disconnect() + + # Call the parent close method + return self.env.close() + + +class GymHilDeviceWrapper(gym.Wrapper): + def __init__(self, env, device="cpu"): + super().__init__(env) + self.device = device + + def step(self, action): + obs, reward, terminated, truncated, info = self.env.step(action) + for k in obs: + obs[k] = obs[k].to(self.device) + if "action_intervention" in info: + # NOTE: This is a hack to ensure the action intervention is a float32 tensor and supported on MPS device + info["action_intervention"] = info["action_intervention"].astype(np.float32) + info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device) + return obs, reward, terminated, truncated, info + + def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None): + obs, info = self.env.reset(seed=seed, options=options) + for k in obs: + obs[k] = obs[k].to(self.device) + if "action_intervention" in info: + # NOTE: This is a hack to ensure the action intervention is a float32 tensor and supported on MPS device + info["action_intervention"] = info["action_intervention"].astype(np.float32) + info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device) + return obs, info + + +class GymHilObservationProcessorWrapper(gym.ObservationWrapper): + def __init__(self, env: gym.Env): + super().__init__(env) + prev_space = self.observation_space + new_space = {} + + for key in prev_space: + if "pixels" in key: + for k in prev_space["pixels"]: + new_space[f"observation.images.{k}"] = gym.spaces.Box( + 0.0, 255.0, shape=(3, 128, 128), dtype=np.uint8 + ) + + if key == "agent_pos": + new_space["observation.state"] = prev_space["agent_pos"] + + self.observation_space = gym.spaces.Dict(new_space) + + def observation(self, observation: dict[str, Any]) -> dict[str, Any]: + return preprocess_observation(observation) + + +########################################################### +# Factory functions +########################################################### + + +def make_robot_env(cfg: EnvConfig) -> gym.Env: + """ + Factory function to create a robot environment. + + This function builds a robot environment with all necessary wrappers + based on the provided configuration. + + Args: + cfg: Configuration object containing environment parameters. + + Returns: + A gym environment with all necessary wrappers applied. + """ + if cfg.type == "hil": + import gym_hil # noqa: F401 + + # TODO (azouitine) + env = gym.make( + f"gym_hil/{cfg.task}", + image_obs=True, + render_mode="human", + use_gripper=cfg.wrapper.use_gripper, + gripper_penalty=cfg.wrapper.gripper_penalty, + ) + env = GymHilObservationProcessorWrapper(env=env) + env = GymHilDeviceWrapper(env=env, device=cfg.device) + env = BatchCompatibleWrapper(env=env) + env = TorchActionWrapper(env=env, device=cfg.device) + return env + + if not hasattr(cfg, "robot") or not hasattr(cfg, "teleop"): + raise ValueError( + "Configuration for 'gym_manipulator' must be HILSerlRobotEnvConfig with robot and teleop." + ) + + if cfg.robot is None: + raise ValueError("RobotConfig (cfg.robot) must be provided for gym_manipulator environment.") + robot = make_robot_from_config(cfg.robot) + + teleop_device = make_teleoperator_from_config(cfg.teleop) + teleop_device.connect() + + # Create base environment + env = RobotEnv( + robot=robot, + use_gripper=cfg.wrapper.use_gripper, + display_cameras=cfg.wrapper.display_cameras if cfg.wrapper else False, + ) + + # Add observation and image processing + if cfg.wrapper: + if cfg.wrapper.add_joint_velocity_to_observation: + env = AddJointVelocityToObservation(env=env, fps=cfg.fps) + if cfg.wrapper.add_current_to_observation: + env = AddCurrentToObservation(env=env) + if cfg.wrapper.add_ee_pose_to_observation: + env = EEObservationWrapper(env=env, ee_pose_limits=robot.end_effector_bounds) + + env = ConvertToLeRobotObservation(env=env, device=cfg.device) + + if cfg.wrapper and cfg.wrapper.crop_params_dict is not None: + env = ImageCropResizeWrapper( + env=env, + crop_params_dict=cfg.wrapper.crop_params_dict, + resize_size=cfg.wrapper.resize_size, + ) + + # Add reward computation and control wrappers + reward_classifier = init_reward_classifier(cfg) + if reward_classifier is not None: + env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device) + + env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps) + if cfg.wrapper.use_gripper and cfg.wrapper.gripper_penalty is not None: + env = GripperPenaltyWrapper( + env=env, + penalty=cfg.wrapper.gripper_penalty, + ) + + # Control mode specific wrappers + control_mode = cfg.wrapper.control_mode + if control_mode == "gamepad": + assert isinstance(teleop_device, GamepadTeleop), ( + "teleop_device must be an instance of GamepadTeleop for gamepad control mode" + ) + env = GamepadControlWrapper( + env=env, + teleop_device=teleop_device, + use_gripper=cfg.wrapper.use_gripper, + ) + elif control_mode == "leader": + env = GearedLeaderControlWrapper( + env=env, + teleop_device=teleop_device, + end_effector_step_sizes=cfg.robot.end_effector_step_sizes, + use_gripper=cfg.wrapper.use_gripper, + ) + elif control_mode == "leader_automatic": + env = GearedLeaderAutomaticControlWrapper( + env=env, + teleop_device=teleop_device, + end_effector_step_sizes=cfg.robot.end_effector_step_sizes, + use_gripper=cfg.wrapper.use_gripper, + ) + else: + raise ValueError(f"Invalid control mode: {control_mode}") + + env = ResetWrapper( + env=env, + reset_pose=cfg.wrapper.fixed_reset_joint_positions, + reset_time_s=cfg.wrapper.reset_time_s, + ) + + env = BatchCompatibleWrapper(env=env) + env = TorchActionWrapper(env=env, device=cfg.device) + + return env + + +def init_reward_classifier(cfg): + """ + Load a reward classifier policy from a pretrained path if configured. + + Args: + cfg: The environment configuration containing classifier paths. + + Returns: + The loaded classifier model or None if not configured. + """ + if cfg.reward_classifier_pretrained_path is None: + return None + + from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier + + # Get device from config or default to CUDA + device = getattr(cfg, "device", "cpu") + + # Load the classifier directly using from_pretrained + classifier = Classifier.from_pretrained( + pretrained_name_or_path=cfg.reward_classifier_pretrained_path, + ) + + # Ensure model is on the correct device + classifier.to(device) + classifier.eval() # Set to evaluation mode + + return classifier + + +########################################################### +# Record and replay functions +########################################################### + + +def record_dataset(env, policy, cfg): + """ + Record a dataset of robot interactions using either a policy or teleop. + + This function runs episodes in the environment and records the observations, + actions, and results for dataset creation. + + Args: + env: The environment to record from. + policy: Optional policy to generate actions (if None, uses teleop). + cfg: Configuration object containing recording parameters like: + - repo_id: Repository ID for dataset storage + - dataset_root: Local root directory for dataset + - num_episodes: Number of episodes to record + - fps: Frames per second for recording + - push_to_hub: Whether to push dataset to Hugging Face Hub + - task: Name/description of the task being recorded + - number_of_steps_after_success: Number of additional steps to continue recording after + a success (reward=1) is detected. This helps collect + more positive examples for reward classifier training. + """ + from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + + # Setup initial action (zero action if using teleop) + action = env.action_space.sample() * 0.0 + + action_names = ["delta_x_ee", "delta_y_ee", "delta_z_ee"] + if cfg.wrapper.use_gripper: + action_names.append("gripper_delta") + + # Configure dataset features based on environment spaces + features = { + "observation.state": { + "dtype": "float32", + "shape": env.observation_space["observation.state"].shape, + "names": None, + }, + "action": { + "dtype": "float32", + "shape": (len(action_names),), + "names": action_names, + }, + "next.reward": {"dtype": "float32", "shape": (1,), "names": None}, + "next.done": {"dtype": "bool", "shape": (1,), "names": None}, + "complementary_info.discrete_penalty": { + "dtype": "float32", + "shape": (1,), + "names": ["discrete_penalty"], + }, + } + + # Add image features + for key in env.observation_space: + if "image" in key: + features[key] = { + "dtype": "video", + "shape": env.observation_space[key].shape, + "names": ["channels", "height", "width"], + } + + # Create dataset + dataset = LeRobotDataset.create( + cfg.repo_id, + cfg.fps, + root=cfg.dataset_root, + use_videos=True, + image_writer_threads=4, + image_writer_processes=0, + features=features, + ) + + # Record episodes + episode_index = 0 + recorded_action = None + while episode_index < cfg.num_episodes: + obs, _ = env.reset() + start_episode_t = time.perf_counter() + log_say(f"Recording episode {episode_index}", play_sounds=True) + + # Track success state collection + success_detected = False + success_steps_collected = 0 + + # Run episode steps + while time.perf_counter() - start_episode_t < cfg.wrapper.control_time_s: + start_loop_t = time.perf_counter() + + # Get action from policy if available + if cfg.pretrained_policy_name_or_path is not None: + action = policy.select_action(obs) + + # Step environment + obs, reward, terminated, truncated, info = env.step(action) + + # Check if episode needs to be rerecorded + if info.get("rerecord_episode", False): + break + + # For teleop, get action from intervention + recorded_action = { + "action": info["action_intervention"].cpu().squeeze(0).float() if policy is None else action + } + + # Process observation for dataset + obs_processed = {k: v.cpu().squeeze(0).float() for k, v in obs.items()} + + # Check if we've just detected success + if reward == 1.0 and not success_detected: + success_detected = True + logging.info("Success detected! Collecting additional success states.") + + # Add frame to dataset - continue marking as success even during extra collection steps + frame = {**obs_processed, **recorded_action} + + # If we're in the success collection phase, keep marking rewards as 1.0 + if success_detected: + frame["next.reward"] = np.array([1.0], dtype=np.float32) + else: + frame["next.reward"] = np.array([reward], dtype=np.float32) + + # Only mark as done if we're truly done (reached end or collected enough success states) + really_done = terminated or truncated + if success_detected: + success_steps_collected += 1 + really_done = success_steps_collected >= cfg.number_of_steps_after_success + + frame["next.done"] = np.array([really_done], dtype=bool) + frame["complementary_info.discrete_penalty"] = torch.tensor( + [info.get("discrete_penalty", 0.0)], dtype=torch.float32 + ) + dataset.add_frame(frame, task=cfg.task) + + # Maintain consistent timing + if cfg.fps: + dt_s = time.perf_counter() - start_loop_t + busy_wait(1 / cfg.fps - dt_s) + + # Check if we should end the episode + if (terminated or truncated) and not success_detected: + # Regular termination without success + break + elif success_detected and success_steps_collected >= cfg.number_of_steps_after_success: + # We've collected enough success states + logging.info(f"Collected {success_steps_collected} additional success states") + break + + # Handle episode recording + if info.get("rerecord_episode", False): + dataset.clear_episode_buffer() + logging.info(f"Re-recording episode {episode_index}") + continue + + dataset.save_episode() + episode_index += 1 + + # Finalize dataset + # dataset.consolidate(run_compute_stats=True) + if cfg.push_to_hub: + dataset.push_to_hub() + + +def replay_episode(env, cfg): + """ + Replay a recorded episode in the environment. + + This function loads actions from a previously recorded episode + and executes them in the environment. + + Args: + env: The environment to replay in. + cfg: Configuration object containing replay parameters: + - repo_id: Repository ID for dataset + - dataset_root: Local root directory for dataset + - episode: Episode ID to replay + """ + from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.dataset_root, episodes=[cfg.episode]) + env.reset() + + actions = dataset.hf_dataset.select_columns("action") + + for idx in range(dataset.num_frames): + start_episode_t = time.perf_counter() + + action = actions[idx]["action"] + env.step(action) + + dt_s = time.perf_counter() - start_episode_t + busy_wait(1 / 10 - dt_s) + + +@parser.wrap() +def main(cfg: EnvConfig): + """Main entry point for the robot environment script. + + This function runs the robot environment in one of several modes + based on the provided configuration. + + Args: + cfg: Configuration object defining the run parameters, + including mode (record, replay, random) and other settings. + """ + env = make_robot_env(cfg) + + if cfg.mode == "record": + policy = None + if cfg.pretrained_policy_name_or_path is not None: + from lerobot.common.policies.sac.modeling_sac import SACPolicy + + policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path) + policy.to(cfg.device) + policy.eval() + + record_dataset( + env, + policy=policy, + cfg=cfg, + ) + exit() + + if cfg.mode == "replay": + replay_episode( + env, + cfg=cfg, + ) + exit() + + env.reset() + + # Initialize the smoothed action as a random sample. + smoothed_action = env.action_space.sample() * 0.0 + + # Smoothing coefficient (alpha) defines how much of the new random sample to mix in. + # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth. + alpha = 1.0 + + num_episode = 0 + successes = [] + while num_episode < 10: + start_loop_s = time.perf_counter() + # Sample a new random action from the robot's action space. + new_random_action = env.action_space.sample() + # Update the smoothed action using an exponential moving average. + smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action + + # Execute the step: wrap the NumPy action in a torch tensor. + obs, reward, terminated, truncated, info = env.step(smoothed_action) + if terminated or truncated: + successes.append(reward) + env.reset() + num_episode += 1 + + dt_s = time.perf_counter() - start_loop_s + busy_wait(1 / cfg.fps - dt_s) + + logging.info(f"Success after 20 steps {successes}") + logging.info(f"success rate {sum(successes) / len(successes)}") + + +if __name__ == "__main__": + main() diff --git a/lerobot/scripts/rl/learner.py b/lerobot/scripts/rl/learner.py new file mode 100644 index 0000000000..2d2c3755a9 --- /dev/null +++ b/lerobot/scripts/rl/learner.py @@ -0,0 +1,1206 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Learner server runner for distributed HILSerl robot policy training. + +This script implements the learner component of the distributed HILSerl architecture. +It initializes the policy network, maintains replay buffers, and updates +the policy based on transitions received from the actor server. + +Examples of usage: + +- Start a learner server for training: +```bash +python lerobot/scripts/rl/learner.py --config_path lerobot/configs/train_config_hilserl_so100.json +``` + +**NOTE**: Start the learner server before launching the actor server. The learner opens a gRPC server +to communicate with actors. + +**NOTE**: Training progress can be monitored through Weights & Biases if wandb.enable is set to true +in your configuration. + +**WORKFLOW**: +1. Create training configuration with proper policy, dataset, and environment settings +2. Start this learner server with the configuration +3. Start an actor server with the same configuration +4. Monitor training progress through wandb dashboard + +For more details on the complete HILSerl training workflow, see: +https://github.com/michel-aractingi/lerobot-hilserl-guide +""" + +import logging +import os +import shutil +import time +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from pprint import pformat + +import grpc +import torch +from termcolor import colored +from torch import nn +from torch.multiprocessing import Queue +from torch.optim.optimizer import Optimizer + +from lerobot.common.cameras import opencv # noqa: F401 +from lerobot.common.constants import ( + CHECKPOINTS_DIR, + LAST_CHECKPOINT_LINK, + PRETRAINED_MODEL_DIR, + TRAINING_STATE_DIR, +) +from lerobot.common.datasets.factory import make_dataset +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.policies.factory import make_policy +from lerobot.common.policies.sac.modeling_sac import SACPolicy +from lerobot.common.robots import so100_follower # noqa: F401 +from lerobot.common.teleoperators import gamepad, so100_leader # noqa: F401 +from lerobot.common.transport import services_pb2_grpc +from lerobot.common.transport.utils import ( + bytes_to_python_object, + bytes_to_transitions, + state_to_bytes, +) +from lerobot.common.utils.buffer import ReplayBuffer, concatenate_batch_transitions +from lerobot.common.utils.process import ProcessSignalHandler +from lerobot.common.utils.random_utils import set_seed +from lerobot.common.utils.train_utils import ( + get_step_checkpoint_dir, + save_checkpoint, + update_last_checkpoint, +) +from lerobot.common.utils.train_utils import ( + load_training_state as utils_load_training_state, +) +from lerobot.common.utils.transition import move_state_dict_to_device, move_transition_to_device +from lerobot.common.utils.utils import ( + format_big_number, + get_safe_torch_device, + init_logging, +) +from lerobot.common.utils.wandb_utils import WandBLogger +from lerobot.configs import parser +from lerobot.configs.train import TrainRLServerPipelineConfig +from lerobot.scripts.rl import learner_service + +LOG_PREFIX = "[LEARNER]" + + +################################################# +# MAIN ENTRY POINTS AND CORE ALGORITHM FUNCTIONS # +################################################# + + +@parser.wrap() +def train_cli(cfg: TrainRLServerPipelineConfig): + if not use_threads(cfg): + import torch.multiprocessing as mp + + mp.set_start_method("spawn") + + # Use the job_name from the config + train( + cfg, + job_name=cfg.job_name, + ) + + logging.info("[LEARNER] train_cli finished") + + +def train(cfg: TrainRLServerPipelineConfig, job_name: str | None = None): + """ + Main training function that initializes and runs the training process. + + Args: + cfg (TrainRLServerPipelineConfig): The training configuration + job_name (str | None, optional): Job name for logging. Defaults to None. + """ + + cfg.validate() + + if job_name is None: + job_name = cfg.job_name + + if job_name is None: + raise ValueError("Job name must be specified either in config or as a parameter") + + display_pid = False + if not use_threads(cfg): + display_pid = True + + # Create logs directory to ensure it exists + log_dir = os.path.join(cfg.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"learner_{job_name}.log") + + # Initialize logging with explicit log file + init_logging(log_file=log_file, display_pid=display_pid) + logging.info(f"Learner logging initialized, writing to {log_file}") + logging.info(pformat(cfg.to_dict())) + + # Setup WandB logging if enabled + if cfg.wandb.enable and cfg.wandb.project: + from lerobot.common.utils.wandb_utils import WandBLogger + + wandb_logger = WandBLogger(cfg) + else: + wandb_logger = None + logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) + + # Handle resume logic + cfg = handle_resume_logic(cfg) + + set_seed(seed=cfg.seed) + + torch.backends.cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True + + is_threaded = use_threads(cfg) + shutdown_event = ProcessSignalHandler(is_threaded, display_pid=display_pid).shutdown_event + + start_learner_threads( + cfg=cfg, + wandb_logger=wandb_logger, + shutdown_event=shutdown_event, + ) + + +def start_learner_threads( + cfg: TrainRLServerPipelineConfig, + wandb_logger: WandBLogger | None, + shutdown_event: any, # Event, +) -> None: + """ + Start the learner threads for training. + + Args: + cfg (TrainRLServerPipelineConfig): Training configuration + wandb_logger (WandBLogger | None): Logger for metrics + shutdown_event: Event to signal shutdown + """ + # Create multiprocessing queues + transition_queue = Queue() + interaction_message_queue = Queue() + parameters_queue = Queue() + + concurrency_entity = None + + if use_threads(cfg): + from threading import Thread + + concurrency_entity = Thread + else: + from torch.multiprocessing import Process + + concurrency_entity = Process + + communication_process = concurrency_entity( + target=start_learner, + args=( + parameters_queue, + transition_queue, + interaction_message_queue, + shutdown_event, + cfg, + ), + daemon=True, + ) + communication_process.start() + + add_actor_information_and_train( + cfg=cfg, + wandb_logger=wandb_logger, + shutdown_event=shutdown_event, + transition_queue=transition_queue, + interaction_message_queue=interaction_message_queue, + parameters_queue=parameters_queue, + ) + logging.info("[LEARNER] Training process stopped") + + logging.info("[LEARNER] Closing queues") + transition_queue.close() + interaction_message_queue.close() + parameters_queue.close() + + communication_process.join() + logging.info("[LEARNER] Communication process joined") + + logging.info("[LEARNER] join queues") + transition_queue.cancel_join_thread() + interaction_message_queue.cancel_join_thread() + parameters_queue.cancel_join_thread() + + logging.info("[LEARNER] queues closed") + + +################################################# +# Core algorithm functions # +################################################# + + +def add_actor_information_and_train( + cfg: TrainRLServerPipelineConfig, + wandb_logger: WandBLogger | None, + shutdown_event: any, # Event, + transition_queue: Queue, + interaction_message_queue: Queue, + parameters_queue: Queue, +): + """ + Handles data transfer from the actor to the learner, manages training updates, + and logs training progress in an online reinforcement learning setup. + + This function continuously: + - Transfers transitions from the actor to the replay buffer. + - Logs received interaction messages. + - Ensures training begins only when the replay buffer has a sufficient number of transitions. + - Samples batches from the replay buffer and performs multiple critic updates. + - Periodically updates the actor, critic, and temperature optimizers. + - Logs training statistics, including loss values and optimization frequency. + + NOTE: This function doesn't have a single responsibility, it should be split into multiple functions + in the future. The reason why we did that is the GIL in Python. It's super slow the performance + are divided by 200. So we need to have a single thread that does all the work. + + Args: + cfg (TrainRLServerPipelineConfig): Configuration object containing hyperparameters. + wandb_logger (WandBLogger | None): Logger for tracking training progress. + shutdown_event (Event): Event to signal shutdown. + transition_queue (Queue): Queue for receiving transitions from the actor. + interaction_message_queue (Queue): Queue for receiving interaction messages from the actor. + parameters_queue (Queue): Queue for sending policy parameters to the actor. + """ + # Extract all configuration variables at the beginning, it improve the speed performance + # of 7% + device = get_safe_torch_device(try_device=cfg.policy.device, log=True) + storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device) + clip_grad_norm_value = cfg.policy.grad_clip_norm + online_step_before_learning = cfg.policy.online_step_before_learning + utd_ratio = cfg.policy.utd_ratio + fps = cfg.env.fps + log_freq = cfg.log_freq + save_freq = cfg.save_freq + policy_update_freq = cfg.policy.policy_update_freq + policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency + saving_checkpoint = cfg.save_checkpoint + online_steps = cfg.policy.online_steps + async_prefetch = cfg.policy.async_prefetch + + # Initialize logging for multiprocessing + if not use_threads(cfg): + log_dir = os.path.join(cfg.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"learner_train_process_{os.getpid()}.log") + init_logging(log_file=log_file, display_pid=True) + logging.info("Initialized logging for actor information and training process") + + logging.info("Initializing policy") + + policy: SACPolicy = make_policy( + cfg=cfg.policy, + env_cfg=cfg.env, + ) + + assert isinstance(policy, nn.Module) + + policy.train() + + push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy) + + last_time_policy_pushed = time.time() + + optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy) + + # If we are resuming, we need to load the training state + resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers) + + log_training_info(cfg=cfg, policy=policy) + + replay_buffer = initialize_replay_buffer(cfg, device, storage_device) + batch_size = cfg.batch_size + offline_replay_buffer = None + + if cfg.dataset is not None: + offline_replay_buffer = initialize_offline_replay_buffer( + cfg=cfg, + device=device, + storage_device=storage_device, + ) + batch_size: int = batch_size // 2 # We will sample from both replay buffer + + logging.info("Starting learner thread") + interaction_message = None + optimization_step = resume_optimization_step if resume_optimization_step is not None else 0 + interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0 + + dataset_repo_id = None + if cfg.dataset is not None: + dataset_repo_id = cfg.dataset.repo_id + + # Initialize iterators + online_iterator = None + offline_iterator = None + + # NOTE: THIS IS THE MAIN LOOP OF THE LEARNER + while True: + # Exit the training loop if shutdown is requested + if shutdown_event is not None and shutdown_event.is_set(): + logging.info("[LEARNER] Shutdown signal received. Exiting...") + break + + # Process all available transitions to the replay buffer, send by the actor server + process_transitions( + transition_queue=transition_queue, + replay_buffer=replay_buffer, + offline_replay_buffer=offline_replay_buffer, + device=device, + dataset_repo_id=dataset_repo_id, + shutdown_event=shutdown_event, + ) + + # Process all available interaction messages sent by the actor server + interaction_message = process_interaction_messages( + interaction_message_queue=interaction_message_queue, + interaction_step_shift=interaction_step_shift, + wandb_logger=wandb_logger, + shutdown_event=shutdown_event, + ) + + # Wait until the replay buffer has enough samples to start training + if len(replay_buffer) < online_step_before_learning: + continue + + if online_iterator is None: + online_iterator = replay_buffer.get_iterator( + batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2 + ) + + if offline_replay_buffer is not None and offline_iterator is None: + offline_iterator = offline_replay_buffer.get_iterator( + batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2 + ) + + time_for_one_optimization_step = time.time() + for _ in range(utd_ratio - 1): + # Sample from the iterators + batch = next(online_iterator) + + if dataset_repo_id is not None: + batch_offline = next(offline_iterator) + batch = concatenate_batch_transitions( + left_batch_transitions=batch, right_batch_transition=batch_offline + ) + + actions = batch["action"] + rewards = batch["reward"] + observations = batch["state"] + next_observations = batch["next_state"] + done = batch["done"] + check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) + + observation_features, next_observation_features = get_observation_features( + policy=policy, observations=observations, next_observations=next_observations + ) + + # Create a batch dictionary with all required elements for the forward method + forward_batch = { + "action": actions, + "reward": rewards, + "state": observations, + "next_state": next_observations, + "done": done, + "observation_feature": observation_features, + "next_observation_feature": next_observation_features, + "complementary_info": batch["complementary_info"], + } + + # Use the forward method for critic loss + critic_output = policy.forward(forward_batch, model="critic") + + # Main critic optimization + loss_critic = critic_output["loss_critic"] + optimizers["critic"].zero_grad() + loss_critic.backward() + critic_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value + ) + optimizers["critic"].step() + + # Discrete critic optimization (if available) + if policy.config.num_discrete_actions is not None: + discrete_critic_output = policy.forward(forward_batch, model="discrete_critic") + loss_discrete_critic = discrete_critic_output["loss_discrete_critic"] + optimizers["discrete_critic"].zero_grad() + loss_discrete_critic.backward() + discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value + ) + optimizers["discrete_critic"].step() + + # Update target networks (main and discrete) + policy.update_target_networks() + + # Sample for the last update in the UTD ratio + batch = next(online_iterator) + + if dataset_repo_id is not None: + batch_offline = next(offline_iterator) + batch = concatenate_batch_transitions( + left_batch_transitions=batch, right_batch_transition=batch_offline + ) + + actions = batch["action"] + rewards = batch["reward"] + observations = batch["state"] + next_observations = batch["next_state"] + done = batch["done"] + + check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) + + observation_features, next_observation_features = get_observation_features( + policy=policy, observations=observations, next_observations=next_observations + ) + + # Create a batch dictionary with all required elements for the forward method + forward_batch = { + "action": actions, + "reward": rewards, + "state": observations, + "next_state": next_observations, + "done": done, + "observation_feature": observation_features, + "next_observation_feature": next_observation_features, + } + + critic_output = policy.forward(forward_batch, model="critic") + + loss_critic = critic_output["loss_critic"] + optimizers["critic"].zero_grad() + loss_critic.backward() + critic_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value + ).item() + optimizers["critic"].step() + + # Initialize training info dictionary + training_infos = { + "loss_critic": loss_critic.item(), + "critic_grad_norm": critic_grad_norm, + } + + # Discrete critic optimization (if available) + if policy.config.num_discrete_actions is not None: + discrete_critic_output = policy.forward(forward_batch, model="discrete_critic") + loss_discrete_critic = discrete_critic_output["loss_discrete_critic"] + optimizers["discrete_critic"].zero_grad() + loss_discrete_critic.backward() + discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value + ).item() + optimizers["discrete_critic"].step() + + # Add discrete critic info to training info + training_infos["loss_discrete_critic"] = loss_discrete_critic.item() + training_infos["discrete_critic_grad_norm"] = discrete_critic_grad_norm + + # Actor and temperature optimization (at specified frequency) + if optimization_step % policy_update_freq == 0: + for _ in range(policy_update_freq): + # Actor optimization + actor_output = policy.forward(forward_batch, model="actor") + loss_actor = actor_output["loss_actor"] + optimizers["actor"].zero_grad() + loss_actor.backward() + actor_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value + ).item() + optimizers["actor"].step() + + # Add actor info to training info + training_infos["loss_actor"] = loss_actor.item() + training_infos["actor_grad_norm"] = actor_grad_norm + + # Temperature optimization + temperature_output = policy.forward(forward_batch, model="temperature") + loss_temperature = temperature_output["loss_temperature"] + optimizers["temperature"].zero_grad() + loss_temperature.backward() + temp_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=[policy.log_alpha], max_norm=clip_grad_norm_value + ).item() + optimizers["temperature"].step() + + # Add temperature info to training info + training_infos["loss_temperature"] = loss_temperature.item() + training_infos["temperature_grad_norm"] = temp_grad_norm + training_infos["temperature"] = policy.temperature + + # Update temperature + policy.update_temperature() + + # Push policy to actors if needed + if time.time() - last_time_policy_pushed > policy_parameters_push_frequency: + push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy) + last_time_policy_pushed = time.time() + + # Update target networks (main and discrete) + policy.update_target_networks() + + # Log training metrics at specified intervals + if optimization_step % log_freq == 0: + training_infos["replay_buffer_size"] = len(replay_buffer) + if offline_replay_buffer is not None: + training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer) + training_infos["Optimization step"] = optimization_step + + # Log training metrics + if wandb_logger: + wandb_logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step") + + # Calculate and log optimization frequency + time_for_one_optimization_step = time.time() - time_for_one_optimization_step + frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9) + + logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}") + + # Log optimization frequency + if wandb_logger: + wandb_logger.log_dict( + { + "Optimization frequency loop [Hz]": frequency_for_one_optimization_step, + "Optimization step": optimization_step, + }, + mode="train", + custom_step_key="Optimization step", + ) + + optimization_step += 1 + if optimization_step % log_freq == 0: + logging.info(f"[LEARNER] Number of optimization step: {optimization_step}") + + # Save checkpoint at specified intervals + if saving_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps): + save_training_checkpoint( + cfg=cfg, + optimization_step=optimization_step, + online_steps=online_steps, + interaction_message=interaction_message, + policy=policy, + optimizers=optimizers, + replay_buffer=replay_buffer, + offline_replay_buffer=offline_replay_buffer, + dataset_repo_id=dataset_repo_id, + fps=fps, + ) + + +def start_learner( + parameters_queue: Queue, + transition_queue: Queue, + interaction_message_queue: Queue, + shutdown_event: any, # Event, + cfg: TrainRLServerPipelineConfig, +): + """ + Start the learner server for training. + It will receive transitions and interaction messages from the actor server, + and send policy parameters to the actor server. + + Args: + parameters_queue: Queue for sending policy parameters to the actor + transition_queue: Queue for receiving transitions from the actor + interaction_message_queue: Queue for receiving interaction messages from the actor + shutdown_event: Event to signal shutdown + cfg: Training configuration + """ + if not use_threads(cfg): + # Create a process-specific log file + log_dir = os.path.join(cfg.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"learner_process_{os.getpid()}.log") + + # Initialize logging with explicit log file + init_logging(log_file=log_file, display_pid=True) + logging.info("Learner server process logging initialized") + + # Setup process handlers to handle shutdown signal + # But use shutdown event from the main process + # Return back for MP + # TODO: Check if its useful + _ = ProcessSignalHandler(False, display_pid=True) + + service = learner_service.LearnerService( + shutdown_event=shutdown_event, + parameters_queue=parameters_queue, + seconds_between_pushes=cfg.policy.actor_learner_config.policy_parameters_push_frequency, + transition_queue=transition_queue, + interaction_message_queue=interaction_message_queue, + queue_get_timeout=cfg.policy.actor_learner_config.queue_get_timeout, + ) + + server = grpc.server( + ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS), + options=[ + ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE), + ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE), + ], + ) + + services_pb2_grpc.add_LearnerServiceServicer_to_server( + service, + server, + ) + + host = cfg.policy.actor_learner_config.learner_host + port = cfg.policy.actor_learner_config.learner_port + + server.add_insecure_port(f"{host}:{port}") + server.start() + logging.info("[LEARNER] gRPC server started") + + shutdown_event.wait() + logging.info("[LEARNER] Stopping gRPC server...") + server.stop(learner_service.SHUTDOWN_TIMEOUT) + logging.info("[LEARNER] gRPC server stopped") + + +def save_training_checkpoint( + cfg: TrainRLServerPipelineConfig, + optimization_step: int, + online_steps: int, + interaction_message: dict | None, + policy: nn.Module, + optimizers: dict[str, Optimizer], + replay_buffer: ReplayBuffer, + offline_replay_buffer: ReplayBuffer | None = None, + dataset_repo_id: str | None = None, + fps: int = 30, +) -> None: + """ + Save training checkpoint and associated data. + + This function performs the following steps: + 1. Creates a checkpoint directory with the current optimization step + 2. Saves the policy model, configuration, and optimizer states + 3. Saves the current interaction step for resuming training + 4. Updates the "last" checkpoint symlink to point to this checkpoint + 5. Saves the replay buffer as a dataset for later use + 6. If an offline replay buffer exists, saves it as a separate dataset + + Args: + cfg: Training configuration + optimization_step: Current optimization step + online_steps: Total number of online steps + interaction_message: Dictionary containing interaction information + policy: Policy model to save + optimizers: Dictionary of optimizers + replay_buffer: Replay buffer to save as dataset + offline_replay_buffer: Optional offline replay buffer to save + dataset_repo_id: Repository ID for dataset + fps: Frames per second for dataset + """ + logging.info(f"Checkpoint policy after step {optimization_step}") + _num_digits = max(6, len(str(online_steps))) + interaction_step = interaction_message["Interaction step"] if interaction_message is not None else 0 + + # Create checkpoint directory + checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step) + + # Save checkpoint + save_checkpoint( + checkpoint_dir=checkpoint_dir, + step=optimization_step, + cfg=cfg, + policy=policy, + optimizer=optimizers, + scheduler=None, + ) + + # Save interaction step manually + training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR) + os.makedirs(training_state_dir, exist_ok=True) + training_state = {"step": optimization_step, "interaction_step": interaction_step} + torch.save(training_state, os.path.join(training_state_dir, "training_state.pt")) + + # Update the "last" symlink + update_last_checkpoint(checkpoint_dir) + + # TODO : temporary save replay buffer here, remove later when on the robot + # We want to control this with the keyboard inputs + dataset_dir = os.path.join(cfg.output_dir, "dataset") + if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir): + shutil.rmtree(dataset_dir) + + # Save dataset + # NOTE: Handle the case where the dataset repo id is not specified in the config + # eg. RL training without demonstrations data + repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id + replay_buffer.to_lerobot_dataset(repo_id=repo_id_buffer_save, fps=fps, root=dataset_dir) + + if offline_replay_buffer is not None: + dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline") + if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir): + shutil.rmtree(dataset_offline_dir) + + offline_replay_buffer.to_lerobot_dataset( + cfg.dataset.repo_id, + fps=fps, + root=dataset_offline_dir, + ) + + logging.info("Resume training") + + +def make_optimizers_and_scheduler(cfg: TrainRLServerPipelineConfig, policy: nn.Module): + """ + Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy. + + This function sets up Adam optimizers for: + - The **actor network**, ensuring that only relevant parameters are optimized. + - The **critic ensemble**, which evaluates the value function. + - The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods. + + It also initializes a learning rate scheduler, though currently, it is set to `None`. + + NOTE: + - If the encoder is shared, its parameters are excluded from the actor's optimization process. + - The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor. + + Args: + cfg: Configuration object containing hyperparameters. + policy (nn.Module): The policy model containing the actor, critic, and temperature components. + + Returns: + Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]: + A tuple containing: + - `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers. + - `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling. + + """ + optimizer_actor = torch.optim.Adam( + params=[ + p + for n, p in policy.actor.named_parameters() + if not policy.config.shared_encoder or not n.startswith("encoder") + ], + lr=cfg.policy.actor_lr, + ) + optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr) + + if cfg.policy.num_discrete_actions is not None: + optimizer_discrete_critic = torch.optim.Adam( + params=policy.discrete_critic.parameters(), lr=cfg.policy.critic_lr + ) + optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr) + lr_scheduler = None + optimizers = { + "actor": optimizer_actor, + "critic": optimizer_critic, + "temperature": optimizer_temperature, + } + if cfg.policy.num_discrete_actions is not None: + optimizers["discrete_critic"] = optimizer_discrete_critic + return optimizers, lr_scheduler + + +################################################# +# Training setup functions # +################################################# + + +def handle_resume_logic(cfg: TrainRLServerPipelineConfig) -> TrainRLServerPipelineConfig: + """ + Handle the resume logic for training. + + If resume is True: + - Verifies that a checkpoint exists + - Loads the checkpoint configuration + - Logs resumption details + - Returns the checkpoint configuration + + If resume is False: + - Checks if an output directory exists (to prevent accidental overwriting) + - Returns the original configuration + + Args: + cfg (TrainRLServerPipelineConfig): The training configuration + + Returns: + TrainRLServerPipelineConfig: The updated configuration + + Raises: + RuntimeError: If resume is True but no checkpoint found, or if resume is False but directory exists + """ + out_dir = cfg.output_dir + + # Case 1: Not resuming, but need to check if directory exists to prevent overwrites + if not cfg.resume: + checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) + if os.path.exists(checkpoint_dir): + raise RuntimeError( + f"Output directory {checkpoint_dir} already exists. Use `resume=true` to resume training." + ) + return cfg + + # Case 2: Resuming training + checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) + if not os.path.exists(checkpoint_dir): + raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True") + + # Log that we found a valid checkpoint and are resuming + logging.info( + colored( + "Valid checkpoint found: resume=True detected, resuming previous run", + color="yellow", + attrs=["bold"], + ) + ) + + # Load config using Draccus + checkpoint_cfg_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR, "train_config.json") + checkpoint_cfg = TrainRLServerPipelineConfig.from_pretrained(checkpoint_cfg_path) + + # Ensure resume flag is set in returned config + checkpoint_cfg.resume = True + return checkpoint_cfg + + +def load_training_state( + cfg: TrainRLServerPipelineConfig, + optimizers: Optimizer | dict[str, Optimizer], +): + """ + Loads the training state (optimizers, step count, etc.) from a checkpoint. + + Args: + cfg (TrainRLServerPipelineConfig): Training configuration + optimizers (Optimizer | dict): Optimizers to load state into + + Returns: + tuple: (optimization_step, interaction_step) or (None, None) if not resuming + """ + if not cfg.resume: + return None, None + + # Construct path to the last checkpoint directory + checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) + + logging.info(f"Loading training state from {checkpoint_dir}") + + try: + # Use the utility function from train_utils which loads the optimizer state + step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None) + + # Load interaction step separately from training_state.pt + training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt") + interaction_step = 0 + if os.path.exists(training_state_path): + training_state = torch.load(training_state_path, weights_only=False) # nosec B614: Safe usage of torch.load + interaction_step = training_state.get("interaction_step", 0) + + logging.info(f"Resuming from step {step}, interaction step {interaction_step}") + return step, interaction_step + + except Exception as e: + logging.error(f"Failed to load training state: {e}") + return None, None + + +def log_training_info(cfg: TrainRLServerPipelineConfig, policy: nn.Module) -> None: + """ + Log information about the training process. + + Args: + cfg (TrainRLServerPipelineConfig): Training configuration + policy (nn.Module): Policy model + """ + num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) + num_total_params = sum(p.numel() for p in policy.parameters()) + + logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") + logging.info(f"{cfg.env.task=}") + logging.info(f"{cfg.policy.online_steps=}") + logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") + logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") + + +def initialize_replay_buffer( + cfg: TrainRLServerPipelineConfig, device: str, storage_device: str +) -> ReplayBuffer: + """ + Initialize a replay buffer, either empty or from a dataset if resuming. + + Args: + cfg (TrainRLServerPipelineConfig): Training configuration + device (str): Device to store tensors on + storage_device (str): Device for storage optimization + + Returns: + ReplayBuffer: Initialized replay buffer + """ + if not cfg.resume: + return ReplayBuffer( + capacity=cfg.policy.online_buffer_capacity, + device=device, + state_keys=cfg.policy.input_features.keys(), + storage_device=storage_device, + optimize_memory=True, + ) + + logging.info("Resume training load the online dataset") + dataset_path = os.path.join(cfg.output_dir, "dataset") + + # NOTE: In RL is possible to not have a dataset. + repo_id = None + if cfg.dataset is not None: + repo_id = cfg.dataset.repo_id + dataset = LeRobotDataset( + repo_id=repo_id, + root=dataset_path, + ) + return ReplayBuffer.from_lerobot_dataset( + lerobot_dataset=dataset, + capacity=cfg.policy.online_buffer_capacity, + device=device, + state_keys=cfg.policy.input_features.keys(), + optimize_memory=True, + ) + + +def initialize_offline_replay_buffer( + cfg: TrainRLServerPipelineConfig, + device: str, + storage_device: str, +) -> ReplayBuffer: + """ + Initialize an offline replay buffer from a dataset. + + Args: + cfg (TrainRLServerPipelineConfig): Training configuration + device (str): Device to store tensors on + storage_device (str): Device for storage optimization + + Returns: + ReplayBuffer: Initialized offline replay buffer + """ + if not cfg.resume: + logging.info("make_dataset offline buffer") + offline_dataset = make_dataset(cfg) + else: + logging.info("load offline dataset") + dataset_offline_path = os.path.join(cfg.output_dir, "dataset_offline") + offline_dataset = LeRobotDataset( + repo_id=cfg.dataset.repo_id, + root=dataset_offline_path, + ) + + logging.info("Convert to a offline replay buffer") + offline_replay_buffer = ReplayBuffer.from_lerobot_dataset( + offline_dataset, + device=device, + state_keys=cfg.policy.input_features.keys(), + storage_device=storage_device, + optimize_memory=True, + capacity=cfg.policy.offline_buffer_capacity, + ) + return offline_replay_buffer + + +################################################# +# Utilities/Helpers functions # +################################################# + + +def get_observation_features( + policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor +) -> tuple[torch.Tensor | None, torch.Tensor | None]: + """ + Get observation features from the policy encoder. It act as cache for the observation features. + when the encoder is frozen, the observation features are not updated. + We can save compute by caching the observation features. + + Args: + policy: The policy model + observations: The current observations + next_observations: The next observations + + Returns: + tuple: observation_features, next_observation_features + """ + + if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder: + return None, None + + with torch.no_grad(): + observation_features = policy.actor.encoder.get_cached_image_features(observations, normalize=True) + next_observation_features = policy.actor.encoder.get_cached_image_features( + next_observations, normalize=True + ) + + return observation_features, next_observation_features + + +def use_threads(cfg: TrainRLServerPipelineConfig) -> bool: + return cfg.policy.concurrency.learner == "threads" + + +def check_nan_in_transition( + observations: torch.Tensor, + actions: torch.Tensor, + next_state: torch.Tensor, + raise_error: bool = False, +) -> bool: + """ + Check for NaN values in transition data. + + Args: + observations: Dictionary of observation tensors + actions: Action tensor + next_state: Dictionary of next state tensors + raise_error: If True, raises ValueError when NaN is detected + + Returns: + bool: True if NaN values were detected, False otherwise + """ + nan_detected = False + + # Check observations + for key, tensor in observations.items(): + if torch.isnan(tensor).any(): + logging.error(f"observations[{key}] contains NaN values") + nan_detected = True + if raise_error: + raise ValueError(f"NaN detected in observations[{key}]") + + # Check next state + for key, tensor in next_state.items(): + if torch.isnan(tensor).any(): + logging.error(f"next_state[{key}] contains NaN values") + nan_detected = True + if raise_error: + raise ValueError(f"NaN detected in next_state[{key}]") + + # Check actions + if torch.isnan(actions).any(): + logging.error("actions contains NaN values") + nan_detected = True + if raise_error: + raise ValueError("NaN detected in actions") + + return nan_detected + + +def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module): + logging.debug("[LEARNER] Pushing actor policy to the queue") + state_dict = move_state_dict_to_device(policy.actor.state_dict(), device="cpu") + state_bytes = state_to_bytes(state_dict) + parameters_queue.put(state_bytes) + + +def process_interaction_message( + message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None +): + """Process a single interaction message with consistent handling.""" + message = bytes_to_python_object(message) + # Shift interaction step for consistency with checkpointed state + message["Interaction step"] += interaction_step_shift + + # Log if logger available + if wandb_logger: + wandb_logger.log_dict(d=message, mode="train", custom_step_key="Interaction step") + + return message + + +def process_transitions( + transition_queue: Queue, + replay_buffer: ReplayBuffer, + offline_replay_buffer: ReplayBuffer, + device: str, + dataset_repo_id: str | None, + shutdown_event: any, +): + """Process all available transitions from the queue. + + Args: + transition_queue: Queue for receiving transitions from the actor + replay_buffer: Replay buffer to add transitions to + offline_replay_buffer: Offline replay buffer to add transitions to + device: Device to move transitions to + dataset_repo_id: Repository ID for dataset + shutdown_event: Event to signal shutdown + """ + while not transition_queue.empty() and not shutdown_event.is_set(): + transition_list = transition_queue.get() + transition_list = bytes_to_transitions(buffer=transition_list) + + for transition in transition_list: + transition = move_transition_to_device(transition=transition, device=device) + + # Skip transitions with NaN values + if check_nan_in_transition( + observations=transition["state"], + actions=transition["action"], + next_state=transition["next_state"], + ): + logging.warning("[LEARNER] NaN detected in transition, skipping") + continue + + replay_buffer.add(**transition) + + # Add to offline buffer if it's an intervention + if dataset_repo_id is not None and transition.get("complementary_info", {}).get( + "is_intervention" + ): + offline_replay_buffer.add(**transition) + + +def process_interaction_messages( + interaction_message_queue: Queue, + interaction_step_shift: int, + wandb_logger: WandBLogger | None, + shutdown_event: any, +) -> dict | None: + """Process all available interaction messages from the queue. + + Args: + interaction_message_queue: Queue for receiving interaction messages + interaction_step_shift: Amount to shift interaction step by + wandb_logger: Logger for tracking progress + shutdown_event: Event to signal shutdown + + Returns: + dict | None: The last interaction message processed, or None if none were processed + """ + last_message = None + while not interaction_message_queue.empty() and not shutdown_event.is_set(): + message = interaction_message_queue.get() + last_message = process_interaction_message( + message=message, + interaction_step_shift=interaction_step_shift, + wandb_logger=wandb_logger, + ) + + return last_message + + +if __name__ == "__main__": + train_cli() + logging.info("[LEARNER] main finished") diff --git a/lerobot/scripts/rl/learner_service.py b/lerobot/scripts/rl/learner_service.py new file mode 100644 index 0000000000..f967d812cf --- /dev/null +++ b/lerobot/scripts/rl/learner_service.py @@ -0,0 +1,118 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from multiprocessing import Event, Queue + +from lerobot.common.transport import services_pb2, services_pb2_grpc +from lerobot.common.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks +from lerobot.common.utils.queue import get_last_item_from_queue + +MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB +MAX_WORKERS = 3 # Stream parameters, send transitions and interactions +SHUTDOWN_TIMEOUT = 10 + + +class LearnerService(services_pb2_grpc.LearnerServiceServicer): + """ + Implementation of the LearnerService gRPC service + This service is used to send parameters to the Actor and receive transitions and interactions from the Actor + check transport.proto for the gRPC service definition + """ + + def __init__( + self, + shutdown_event: Event, # type: ignore + parameters_queue: Queue, + seconds_between_pushes: float, + transition_queue: Queue, + interaction_message_queue: Queue, + queue_get_timeout: float = 0.001, + ): + self.shutdown_event = shutdown_event + self.parameters_queue = parameters_queue + self.seconds_between_pushes = seconds_between_pushes + self.transition_queue = transition_queue + self.interaction_message_queue = interaction_message_queue + self.queue_get_timeout = queue_get_timeout + + def StreamParameters(self, request, context): # noqa: N802 + # TODO: authorize the request + logging.info("[LEARNER] Received request to stream parameters from the Actor") + + last_push_time = 0 + + while not self.shutdown_event.is_set(): + time_since_last_push = time.time() - last_push_time + if time_since_last_push < self.seconds_between_pushes: + self.shutdown_event.wait(self.seconds_between_pushes - time_since_last_push) + # Continue, because we could receive a shutdown event, + # and it's checked in the while loop + continue + + logging.info("[LEARNER] Push parameters to the Actor") + buffer = get_last_item_from_queue( + self.parameters_queue, block=True, timeout=self.queue_get_timeout + ) + + if buffer is None: + continue + + yield from send_bytes_in_chunks( + buffer, + services_pb2.Parameters, + log_prefix="[LEARNER] Sending parameters", + silent=True, + ) + + last_push_time = time.time() + logging.info("[LEARNER] Parameters sent") + + logging.info("[LEARNER] Stream parameters finished") + return services_pb2.Empty() + + def SendTransitions(self, request_iterator, _context): # noqa: N802 + # TODO: authorize the request + logging.info("[LEARNER] Received request to receive transitions from the Actor") + + receive_bytes_in_chunks( + request_iterator, + self.transition_queue, + self.shutdown_event, + log_prefix="[LEARNER] transitions", + ) + + logging.debug("[LEARNER] Finished receiving transitions") + return services_pb2.Empty() + + def SendInteractions(self, request_iterator, _context): # noqa: N802 + # TODO: authorize the request + logging.info("[LEARNER] Received request to receive interactions from the Actor") + + receive_bytes_in_chunks( + request_iterator, + self.interaction_message_queue, + self.shutdown_event, + log_prefix="[LEARNER] interactions", + ) + + logging.debug("[LEARNER] Finished receiving interactions") + return services_pb2.Empty() + + def Ready(self, request, context): # noqa: N802 + return services_pb2.Empty() diff --git a/lerobot/teleoperate.py b/lerobot/teleoperate.py index 97e6104301..6080dfb403 100644 --- a/lerobot/teleoperate.py +++ b/lerobot/teleoperate.py @@ -58,7 +58,7 @@ from lerobot.common.utils.utils import init_logging, move_cursor_up from lerobot.common.utils.visualization_utils import _init_rerun -from .common.teleoperators import koch_leader, so100_leader, so101_leader # noqa: F401 +from .common.teleoperators import gamepad, koch_leader, so100_leader, so101_leader # noqa: F401 @dataclass diff --git a/pyproject.toml b/pyproject.toml index a99b1b16c4..1ebef75bff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,7 @@ dependencies = [ "pyserial>=3.5", "pyzmq>=26.2.1", "rerun-sdk>=0.21.0", + "scipy>=1.14.0", "termcolor>=2.4.0", "torch>=2.2.1", "torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", @@ -97,7 +98,8 @@ stretch = [ "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'", "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'" ] -test = ["pytest>=8.1.0", "pytest-cov>=5.0.0", "mock-serial>=0.0.1 ; sys_platform != 'win32'"] +test = ["pytest>=8.1.0", "pytest-timeout>=2.4.0", "pytest-cov>=5.0.0", "pyserial>=3.5", "mock-serial>=0.0.1 ; sys_platform != 'win32'"] +hilserl = ["transformers>=4.48", "gym-hil>=0.1.8", "protobuf>=5.29.3", "grpcio==1.71.0"] umi = ["imagecodecs>=2024.1.1"] video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"] xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"] @@ -108,7 +110,7 @@ requires-poetry = ">=2.1" [tool.ruff] line-length = 110 target-version = "py310" -exclude = ["tests/artifacts/**/*.safetensors"] +exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"] [tool.ruff.lint] select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"] diff --git a/tests/optim/test_optimizers.py b/tests/optim/test_optimizers.py index 997e14fe94..630353fcaf 100644 --- a/tests/optim/test_optimizers.py +++ b/tests/optim/test_optimizers.py @@ -21,6 +21,7 @@ from lerobot.common.optim.optimizers import ( AdamConfig, AdamWConfig, + MultiAdamConfig, SGDConfig, load_optimizer_state, save_optimizer_state, @@ -33,13 +34,21 @@ (AdamConfig, torch.optim.Adam), (AdamWConfig, torch.optim.AdamW), (SGDConfig, torch.optim.SGD), + (MultiAdamConfig, dict), ], ) def test_optimizer_build(config_cls, expected_class, model_params): config = config_cls() - optimizer = config.build(model_params) - assert isinstance(optimizer, expected_class) - assert optimizer.defaults["lr"] == config.lr + if config_cls == MultiAdamConfig: + params_dict = {"default": model_params} + optimizer = config.build(params_dict) + assert isinstance(optimizer, expected_class) + assert isinstance(optimizer["default"], torch.optim.Adam) + assert optimizer["default"].defaults["lr"] == config.lr + else: + optimizer = config.build(model_params) + assert isinstance(optimizer, expected_class) + assert optimizer.defaults["lr"] == config.lr def test_save_optimizer_state(optimizer, tmp_path): @@ -54,3 +63,180 @@ def test_save_and_load_optimizer_state(model_params, optimizer, tmp_path): loaded_optimizer = load_optimizer_state(loaded_optimizer, tmp_path) torch.testing.assert_close(optimizer.state_dict(), loaded_optimizer.state_dict()) + + +@pytest.fixture +def base_params_dict(): + return { + "actor": [torch.nn.Parameter(torch.randn(10, 10))], + "critic": [torch.nn.Parameter(torch.randn(5, 5))], + "temperature": [torch.nn.Parameter(torch.randn(3, 3))], + } + + +@pytest.mark.parametrize( + "config_params, expected_values", + [ + # Test 1: Basic configuration with different learning rates + ( + { + "lr": 1e-3, + "weight_decay": 1e-4, + "optimizer_groups": { + "actor": {"lr": 1e-4}, + "critic": {"lr": 5e-4}, + "temperature": {"lr": 2e-3}, + }, + }, + { + "actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)}, + "critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)}, + "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999)}, + }, + ), + # Test 2: Different weight decays and beta values + ( + { + "lr": 1e-3, + "weight_decay": 1e-4, + "optimizer_groups": { + "actor": {"lr": 1e-4, "weight_decay": 1e-5}, + "critic": {"lr": 5e-4, "weight_decay": 1e-6}, + "temperature": {"lr": 2e-3, "betas": (0.95, 0.999)}, + }, + }, + { + "actor": {"lr": 1e-4, "weight_decay": 1e-5, "betas": (0.9, 0.999)}, + "critic": {"lr": 5e-4, "weight_decay": 1e-6, "betas": (0.9, 0.999)}, + "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.95, 0.999)}, + }, + ), + # Test 3: Epsilon parameter customization + ( + { + "lr": 1e-3, + "weight_decay": 1e-4, + "optimizer_groups": { + "actor": {"lr": 1e-4, "eps": 1e-6}, + "critic": {"lr": 5e-4, "eps": 1e-7}, + "temperature": {"lr": 2e-3, "eps": 1e-8}, + }, + }, + { + "actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-6}, + "critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-7}, + "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-8}, + }, + ), + ], +) +def test_multi_adam_configuration(base_params_dict, config_params, expected_values): + # Create config with the given parameters + config = MultiAdamConfig(**config_params) + optimizers = config.build(base_params_dict) + + # Verify optimizer count and keys + assert len(optimizers) == len(expected_values) + assert set(optimizers.keys()) == set(expected_values.keys()) + + # Check that all optimizers are Adam instances + for opt in optimizers.values(): + assert isinstance(opt, torch.optim.Adam) + + # Verify hyperparameters for each optimizer + for name, expected in expected_values.items(): + optimizer = optimizers[name] + for param, value in expected.items(): + assert optimizer.defaults[param] == value + + +@pytest.fixture +def multi_optimizers(base_params_dict): + config = MultiAdamConfig( + lr=1e-3, + optimizer_groups={ + "actor": {"lr": 1e-4}, + "critic": {"lr": 5e-4}, + "temperature": {"lr": 2e-3}, + }, + ) + return config.build(base_params_dict) + + +def test_save_multi_optimizer_state(multi_optimizers, tmp_path): + # Save optimizer states + save_optimizer_state(multi_optimizers, tmp_path) + + # Verify that directories were created for each optimizer + for name in multi_optimizers: + assert (tmp_path / name).is_dir() + assert (tmp_path / name / OPTIMIZER_STATE).is_file() + assert (tmp_path / name / OPTIMIZER_PARAM_GROUPS).is_file() + + +def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers, tmp_path): + # Option 1: Add a minimal backward pass to populate optimizer states + for name, params in base_params_dict.items(): + if name in multi_optimizers: + # Create a dummy loss and do backward + dummy_loss = params[0].sum() + dummy_loss.backward() + # Perform an optimization step + multi_optimizers[name].step() + # Zero gradients for next steps + multi_optimizers[name].zero_grad() + + # Save optimizer states + save_optimizer_state(multi_optimizers, tmp_path) + + # Create new optimizers with the same config + config = MultiAdamConfig( + lr=1e-3, + optimizer_groups={ + "actor": {"lr": 1e-4}, + "critic": {"lr": 5e-4}, + "temperature": {"lr": 2e-3}, + }, + ) + new_optimizers = config.build(base_params_dict) + + # Load optimizer states + loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path) + + # Verify state dictionaries match + for name in multi_optimizers: + torch.testing.assert_close(multi_optimizers[name].state_dict(), loaded_optimizers[name].state_dict()) + + +def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path): + """Test saving and loading optimizer states even when the state is empty (no backward pass).""" + # Create config and build optimizers + config = MultiAdamConfig( + lr=1e-3, + optimizer_groups={ + "actor": {"lr": 1e-4}, + "critic": {"lr": 5e-4}, + "temperature": {"lr": 2e-3}, + }, + ) + optimizers = config.build(base_params_dict) + + # Save optimizer states without any backward pass (empty state) + save_optimizer_state(optimizers, tmp_path) + + # Create new optimizers with the same config + new_optimizers = config.build(base_params_dict) + + # Load optimizer states + loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path) + + # Verify hyperparameters match even with empty state + for name, optimizer in optimizers.items(): + assert optimizer.defaults["lr"] == loaded_optimizers[name].defaults["lr"] + assert optimizer.defaults["weight_decay"] == loaded_optimizers[name].defaults["weight_decay"] + assert optimizer.defaults["betas"] == loaded_optimizers[name].defaults["betas"] + + # Verify state dictionaries match (they will be empty) + torch.testing.assert_close( + optimizer.state_dict()["param_groups"], loaded_optimizers[name].state_dict()["param_groups"] + ) diff --git a/tests/policies/hilserl/test_modeling_classifier.py b/tests/policies/hilserl/test_modeling_classifier.py new file mode 100644 index 0000000000..526e1f17dd --- /dev/null +++ b/tests/policies/hilserl/test_modeling_classifier.py @@ -0,0 +1,139 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from lerobot.common.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig +from lerobot.common.policies.sac.reward_model.modeling_classifier import ClassifierOutput +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from tests.utils import require_package + + +def test_classifier_output(): + output = ClassifierOutput( + logits=torch.tensor([1, 2, 3]), + probabilities=torch.tensor([0.1, 0.2, 0.3]), + hidden_states=None, + ) + + assert ( + f"{output}" + == "ClassifierOutput(logits=tensor([1, 2, 3]), probabilities=tensor([0.1000, 0.2000, 0.3000]), hidden_states=None)" + ) + + +@require_package("transformers") +def test_binary_classifier_with_default_params(): + from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier + + config = RewardClassifierConfig() + config.input_features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + "next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)), + } + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "REWARD": NormalizationMode.IDENTITY, + } + config.num_cameras = 1 + classifier = Classifier(config) + + batch_size = 10 + + input = { + "observation.image": torch.rand((batch_size, 3, 128, 128)), + "next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(), + } + + images, labels = classifier.extract_images_and_labels(input) + assert len(images) == 1 + assert images[0].shape == torch.Size([batch_size, 3, 128, 128]) + assert labels.shape == torch.Size([batch_size]) + + output = classifier.predict(images) + + assert output is not None + assert output.logits.size() == torch.Size([batch_size]) + assert not torch.isnan(output.logits).any(), "Tensor contains NaN values" + assert output.probabilities.shape == torch.Size([batch_size]) + assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values" + assert output.hidden_states.shape == torch.Size([batch_size, 256]) + assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values" + + +@require_package("transformers") +def test_multiclass_classifier(): + from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier + + num_classes = 5 + config = RewardClassifierConfig() + config.input_features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + "next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)), + } + config.num_cameras = 1 + config.num_classes = num_classes + classifier = Classifier(config) + + batch_size = 10 + + input = { + "observation.image": torch.rand((batch_size, 3, 128, 128)), + "next.reward": torch.rand((batch_size, num_classes)), + } + + images, labels = classifier.extract_images_and_labels(input) + assert len(images) == 1 + assert images[0].shape == torch.Size([batch_size, 3, 128, 128]) + assert labels.shape == torch.Size([batch_size, num_classes]) + + output = classifier.predict(images) + + assert output is not None + assert output.logits.shape == torch.Size([batch_size, num_classes]) + assert not torch.isnan(output.logits).any(), "Tensor contains NaN values" + assert output.probabilities.shape == torch.Size([batch_size, num_classes]) + assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values" + assert output.hidden_states.shape == torch.Size([batch_size, 256]) + assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values" + + +@require_package("transformers") +def test_default_device(): + from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier + + config = RewardClassifierConfig() + assert config.device == "cpu" + + classifier = Classifier(config) + for p in classifier.parameters(): + assert p.device == torch.device("cpu") + + +@require_package("transformers") +def test_explicit_device_setup(): + from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier + + config = RewardClassifierConfig(device="cpu") + assert config.device == "cpu" + + classifier = Classifier(config) + for p in classifier.parameters(): + assert p.device == torch.device("cpu") diff --git a/tests/policies/test_sac_config.py b/tests/policies/test_sac_config.py new file mode 100644 index 0000000000..d94ee41e04 --- /dev/null +++ b/tests/policies/test_sac_config.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from lerobot.common.policies.sac.configuration_sac import ( + ActorLearnerConfig, + ActorNetworkConfig, + ConcurrencyConfig, + CriticNetworkConfig, + PolicyConfig, + SACConfig, +) +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature + + +def test_sac_config_default_initialization(): + config = SACConfig() + + assert config.normalization_mapping == { + "VISUAL": NormalizationMode.MEAN_STD, + "STATE": NormalizationMode.MIN_MAX, + "ENV": NormalizationMode.MIN_MAX, + "ACTION": NormalizationMode.MIN_MAX, + } + assert config.dataset_stats == { + "observation.image": { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + }, + "observation.state": { + "min": [0.0, 0.0], + "max": [1.0, 1.0], + }, + "action": { + "min": [0.0, 0.0, 0.0], + "max": [1.0, 1.0, 1.0], + }, + } + + # Basic parameters + assert config.device == "cpu" + assert config.storage_device == "cpu" + assert config.discount == 0.99 + assert config.temperature_init == 1.0 + assert config.num_critics == 2 + + # Architecture specifics + assert config.vision_encoder_name is None + assert config.freeze_vision_encoder is True + assert config.image_encoder_hidden_dim == 32 + assert config.shared_encoder is True + assert config.num_discrete_actions is None + assert config.image_embedding_pooling_dim == 8 + + # Training parameters + assert config.online_steps == 1000000 + assert config.online_env_seed == 10000 + assert config.online_buffer_capacity == 100000 + assert config.offline_buffer_capacity == 100000 + assert config.async_prefetch is False + assert config.online_step_before_learning == 100 + assert config.policy_update_freq == 1 + + # SAC algorithm parameters + assert config.num_subsample_critics is None + assert config.critic_lr == 3e-4 + assert config.actor_lr == 3e-4 + assert config.temperature_lr == 3e-4 + assert config.critic_target_update_weight == 0.005 + assert config.utd_ratio == 1 + assert config.state_encoder_hidden_dim == 256 + assert config.latent_dim == 256 + assert config.target_entropy is None + assert config.use_backup_entropy is True + assert config.grad_clip_norm == 40.0 + + # Dataset stats defaults + expected_dataset_stats = { + "observation.image": { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + }, + "observation.state": { + "min": [0.0, 0.0], + "max": [1.0, 1.0], + }, + "action": { + "min": [0.0, 0.0, 0.0], + "max": [1.0, 1.0, 1.0], + }, + } + assert config.dataset_stats == expected_dataset_stats + + # Critic network configuration + assert config.critic_network_kwargs.hidden_dims == [256, 256] + assert config.critic_network_kwargs.activate_final is True + assert config.critic_network_kwargs.final_activation is None + + # Actor network configuration + assert config.actor_network_kwargs.hidden_dims == [256, 256] + assert config.actor_network_kwargs.activate_final is True + + # Policy configuration + assert config.policy_kwargs.use_tanh_squash is True + assert config.policy_kwargs.std_min == 1e-5 + assert config.policy_kwargs.std_max == 10.0 + assert config.policy_kwargs.init_final == 0.05 + + # Discrete critic network configuration + assert config.discrete_critic_network_kwargs.hidden_dims == [256, 256] + assert config.discrete_critic_network_kwargs.activate_final is True + assert config.discrete_critic_network_kwargs.final_activation is None + + # Actor learner configuration + assert config.actor_learner_config.learner_host == "127.0.0.1" + assert config.actor_learner_config.learner_port == 50051 + assert config.actor_learner_config.policy_parameters_push_frequency == 4 + + # Concurrency configuration + assert config.concurrency.actor == "threads" + assert config.concurrency.learner == "threads" + + assert isinstance(config.actor_network_kwargs, ActorNetworkConfig) + assert isinstance(config.critic_network_kwargs, CriticNetworkConfig) + assert isinstance(config.policy_kwargs, PolicyConfig) + assert isinstance(config.actor_learner_config, ActorLearnerConfig) + assert isinstance(config.concurrency, ConcurrencyConfig) + + +def test_critic_network_kwargs(): + config = CriticNetworkConfig() + assert config.hidden_dims == [256, 256] + assert config.activate_final is True + assert config.final_activation is None + + +def test_actor_network_kwargs(): + config = ActorNetworkConfig() + assert config.hidden_dims == [256, 256] + assert config.activate_final is True + + +def test_policy_kwargs(): + config = PolicyConfig() + assert config.use_tanh_squash is True + assert config.std_min == 1e-5 + assert config.std_max == 10.0 + assert config.init_final == 0.05 + + +def test_actor_learner_config(): + config = ActorLearnerConfig() + assert config.learner_host == "127.0.0.1" + assert config.learner_port == 50051 + assert config.policy_parameters_push_frequency == 4 + + +def test_concurrency_config(): + config = ConcurrencyConfig() + assert config.actor == "threads" + assert config.learner == "threads" + + +def test_sac_config_custom_initialization(): + config = SACConfig( + device="cpu", + discount=0.95, + temperature_init=0.5, + num_critics=3, + ) + + assert config.device == "cpu" + assert config.discount == 0.95 + assert config.temperature_init == 0.5 + assert config.num_critics == 3 + + +def test_validate_features(): + config = SACConfig( + input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))}, + output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, + ) + config.validate_features() + + +def test_validate_features_missing_observation(): + config = SACConfig( + input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))}, + output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, + ) + with pytest.raises( + ValueError, match="You must provide either 'observation.state' or an image observation" + ): + config.validate_features() + + +def test_validate_features_missing_action(): + config = SACConfig( + input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))}, + output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, + ) + with pytest.raises(ValueError, match="You must provide 'action' in the output features"): + config.validate_features() diff --git a/tests/policies/test_sac_policy.py b/tests/policies/test_sac_policy.py new file mode 100644 index 0000000000..e4e2dd8a99 --- /dev/null +++ b/tests/policies/test_sac_policy.py @@ -0,0 +1,541 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import pytest +import torch +from torch import Tensor, nn + +from lerobot.common.policies.sac.configuration_sac import SACConfig +from lerobot.common.policies.sac.modeling_sac import MLP, SACPolicy +from lerobot.common.utils.random_utils import seeded_context, set_seed +from lerobot.configs.types import FeatureType, PolicyFeature + +try: + import transformers # noqa: F401 + + TRANSFORMERS_AVAILABLE = True +except ImportError: + TRANSFORMERS_AVAILABLE = False + + +@pytest.fixture(autouse=True) +def set_random_seed(): + seed = 42 + set_seed(seed) + + +def test_mlp_with_default_args(): + mlp = MLP(input_dim=10, hidden_dims=[256, 256]) + + x = torch.randn(10) + y = mlp(x) + assert y.shape == (256,) + + +def test_mlp_with_batch_dim(): + mlp = MLP(input_dim=10, hidden_dims=[256, 256]) + x = torch.randn(2, 10) + y = mlp(x) + assert y.shape == (2, 256) + + +def test_forward_with_empty_hidden_dims(): + mlp = MLP(input_dim=10, hidden_dims=[]) + x = torch.randn(1, 10) + assert mlp(x).shape == (1, 10) + + +def test_mlp_with_dropout(): + mlp = MLP(input_dim=10, hidden_dims=[256, 256, 11], dropout_rate=0.1) + x = torch.randn(1, 10) + y = mlp(x) + assert y.shape == (1, 11) + + drop_out_layers_count = sum(isinstance(layer, nn.Dropout) for layer in mlp.net) + assert drop_out_layers_count == 2 + + +def test_mlp_with_custom_final_activation(): + mlp = MLP(input_dim=10, hidden_dims=[256, 256], final_activation=torch.nn.Tanh()) + x = torch.randn(1, 10) + y = mlp(x) + assert y.shape == (1, 256) + assert (y >= -1).all() and (y <= 1).all() + + +def test_sac_policy_with_default_args(): + with pytest.raises(ValueError, match="should be an instance of class `PreTrainedConfig`"): + SACPolicy() + + +def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor: + return { + "observation.state": torch.randn(batch_size, state_dim), + } + + +def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor: + return { + "observation.image": torch.randn(batch_size, 3, 84, 84), + "observation.state": torch.randn(batch_size, state_dim), + } + + +def create_dummy_action(batch_size: int, action_dim: int = 10) -> Tensor: + return torch.randn(batch_size, action_dim) + + +def create_default_train_batch( + batch_size: int = 8, state_dim: int = 10, action_dim: int = 10 +) -> dict[str, Tensor]: + return { + "action": create_dummy_action(batch_size, action_dim), + "reward": torch.randn(batch_size), + "state": create_dummy_state(batch_size, state_dim), + "next_state": create_dummy_state(batch_size, state_dim), + "done": torch.randn(batch_size), + } + + +def create_train_batch_with_visual_input( + batch_size: int = 8, state_dim: int = 10, action_dim: int = 10 +) -> dict[str, Tensor]: + return { + "action": create_dummy_action(batch_size, action_dim), + "reward": torch.randn(batch_size), + "state": create_dummy_with_visual_input(batch_size, state_dim), + "next_state": create_dummy_with_visual_input(batch_size, state_dim), + "done": torch.randn(batch_size), + } + + +def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]: + return { + "observation.state": torch.randn(batch_size, state_dim), + } + + +def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]: + return { + "observation.state": torch.randn(batch_size, state_dim), + "observation.image": torch.randn(batch_size, 3, 84, 84), + } + + +def make_optimizers(policy: SACPolicy, has_discrete_action: bool = False) -> dict[str, torch.optim.Optimizer]: + """Create optimizers for the SAC policy.""" + optimizer_actor = torch.optim.Adam( + # Handle the case of shared encoder where the encoder weights are not optimized with the actor gradient + params=[ + p + for n, p in policy.actor.named_parameters() + if not policy.config.shared_encoder or not n.startswith("encoder") + ], + lr=policy.config.actor_lr, + ) + optimizer_critic = torch.optim.Adam( + params=policy.critic_ensemble.parameters(), + lr=policy.config.critic_lr, + ) + optimizer_temperature = torch.optim.Adam( + params=[policy.log_alpha], + lr=policy.config.critic_lr, + ) + + optimizers = { + "actor": optimizer_actor, + "critic": optimizer_critic, + "temperature": optimizer_temperature, + } + + if has_discrete_action: + optimizers["discrete_critic"] = torch.optim.Adam( + params=policy.discrete_critic.parameters(), + lr=policy.config.critic_lr, + ) + + return optimizers + + +def create_default_config( + state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False +) -> SACConfig: + action_dim = continuous_action_dim + if has_discrete_action: + action_dim += 1 + + config = SACConfig( + input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, + output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))}, + dataset_stats={ + "observation.state": { + "min": [0.0] * state_dim, + "max": [1.0] * state_dim, + }, + "action": { + "min": [0.0] * continuous_action_dim, + "max": [1.0] * continuous_action_dim, + }, + }, + ) + config.validate_features() + return config + + +def create_config_with_visual_input( + state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False +) -> SACConfig: + config = create_default_config( + state_dim=state_dim, + continuous_action_dim=continuous_action_dim, + has_discrete_action=has_discrete_action, + ) + config.input_features["observation.image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84)) + config.dataset_stats["observation.image"] = { + "mean": torch.randn(3, 1, 1), + "std": torch.randn(3, 1, 1), + } + + # Let make tests a little bit faster + config.state_encoder_hidden_dim = 32 + config.latent_dim = 32 + + config.validate_features() + return config + + +@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) +def test_sac_policy_with_default_config(batch_size: int, state_dim: int, action_dim: int): + batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim) + config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) + + policy = SACPolicy(config=config) + policy.train() + + optimizers = make_optimizers(policy) + + cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] + assert cirtic_loss.item() is not None + assert cirtic_loss.shape == () + cirtic_loss.backward() + optimizers["critic"].step() + + actor_loss = policy.forward(batch, model="actor")["loss_actor"] + assert actor_loss.item() is not None + assert actor_loss.shape == () + + actor_loss.backward() + optimizers["actor"].step() + + temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"] + assert temperature_loss.item() is not None + assert temperature_loss.shape == () + + temperature_loss.backward() + optimizers["temperature"].step() + + policy.eval() + with torch.no_grad(): + observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + selected_action = policy.select_action(observation_batch) + assert selected_action.shape == (batch_size, action_dim) + + +@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) +def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_dim: int): + config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) + policy = SACPolicy(config=config) + + batch = create_train_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim, action_dim=action_dim + ) + + policy.train() + + optimizers = make_optimizers(policy) + + cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] + assert cirtic_loss.item() is not None + assert cirtic_loss.shape == () + cirtic_loss.backward() + optimizers["critic"].step() + + actor_loss = policy.forward(batch, model="actor")["loss_actor"] + assert actor_loss.item() is not None + assert actor_loss.shape == () + + actor_loss.backward() + optimizers["actor"].step() + + temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"] + assert temperature_loss.item() is not None + assert temperature_loss.shape == () + + temperature_loss.backward() + optimizers["temperature"].step() + + policy.eval() + with torch.no_grad(): + observation_batch = create_observation_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim + ) + selected_action = policy.select_action(observation_batch) + assert selected_action.shape == (batch_size, action_dim) + + +# Let's check best candidates for pretrained encoders +@pytest.mark.parametrize( + "batch_size,state_dim,action_dim,vision_encoder_name", + [(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")], +) +@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed") +def test_sac_policy_with_pretrained_encoder( + batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str +): + config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) + config.vision_encoder_name = vision_encoder_name + policy = SACPolicy(config=config) + policy.train() + + batch = create_train_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim, action_dim=action_dim + ) + + optimizers = make_optimizers(policy) + + cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] + assert cirtic_loss.item() is not None + assert cirtic_loss.shape == () + cirtic_loss.backward() + optimizers["critic"].step() + + actor_loss = policy.forward(batch, model="actor")["loss_actor"] + assert actor_loss.item() is not None + assert actor_loss.shape == () + + +def test_sac_policy_with_shared_encoder(): + batch_size = 2 + action_dim = 10 + state_dim = 10 + config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) + config.shared_encoder = True + + policy = SACPolicy(config=config) + policy.train() + + batch = create_train_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim, action_dim=action_dim + ) + + policy.train() + + optimizers = make_optimizers(policy) + + cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] + assert cirtic_loss.item() is not None + assert cirtic_loss.shape == () + cirtic_loss.backward() + optimizers["critic"].step() + + actor_loss = policy.forward(batch, model="actor")["loss_actor"] + assert actor_loss.item() is not None + assert actor_loss.shape == () + + actor_loss.backward() + optimizers["actor"].step() + + +def test_sac_policy_with_discrete_critic(): + batch_size = 2 + continuous_action_dim = 9 + full_action_dim = continuous_action_dim + 1 # the last action is discrete + state_dim = 10 + config = create_config_with_visual_input( + state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True + ) + + num_discrete_actions = 5 + config.num_discrete_actions = num_discrete_actions + + policy = SACPolicy(config=config) + policy.train() + + batch = create_train_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim + ) + + policy.train() + + optimizers = make_optimizers(policy, has_discrete_action=True) + + cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] + assert cirtic_loss.item() is not None + assert cirtic_loss.shape == () + cirtic_loss.backward() + optimizers["critic"].step() + + discrete_critic_loss = policy.forward(batch, model="discrete_critic")["loss_discrete_critic"] + assert discrete_critic_loss.item() is not None + assert discrete_critic_loss.shape == () + discrete_critic_loss.backward() + optimizers["discrete_critic"].step() + + actor_loss = policy.forward(batch, model="actor")["loss_actor"] + assert actor_loss.item() is not None + assert actor_loss.shape == () + + actor_loss.backward() + optimizers["actor"].step() + + policy.eval() + with torch.no_grad(): + observation_batch = create_observation_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim + ) + selected_action = policy.select_action(observation_batch) + assert selected_action.shape == (batch_size, full_action_dim) + + discrete_actions = selected_action[:, -1].long() + discrete_action_values = set(discrete_actions.tolist()) + + assert all(action in range(num_discrete_actions) for action in discrete_action_values), ( + f"Discrete action {discrete_action_values} is not in range({num_discrete_actions})" + ) + + +def test_sac_policy_with_default_entropy(): + config = create_default_config(continuous_action_dim=10, state_dim=10) + policy = SACPolicy(config=config) + assert policy.target_entropy == -5.0 + + +def test_sac_policy_default_target_entropy_with_discrete_action(): + config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True) + policy = SACPolicy(config=config) + assert policy.target_entropy == -3.0 + + +def test_sac_policy_with_predefined_entropy(): + config = create_default_config(state_dim=10, continuous_action_dim=6) + config.target_entropy = -3.5 + + policy = SACPolicy(config=config) + assert policy.target_entropy == pytest.approx(-3.5) + + +def test_sac_policy_update_temperature(): + config = create_default_config(continuous_action_dim=10, state_dim=10) + policy = SACPolicy(config=config) + + assert policy.temperature == pytest.approx(1.0) + policy.log_alpha.data = torch.tensor([math.log(0.1)]) + policy.update_temperature() + assert policy.temperature == pytest.approx(0.1) + + +def test_sac_policy_update_target_network(): + config = create_default_config(state_dim=10, continuous_action_dim=6) + config.critic_target_update_weight = 1.0 + + policy = SACPolicy(config=config) + policy.train() + + for p in policy.critic_ensemble.parameters(): + p.data = torch.ones_like(p.data) + + policy.update_target_networks() + for p in policy.critic_target.parameters(): + assert torch.allclose(p.data, torch.ones_like(p.data)), ( + f"Target network {p.data} is not equal to {torch.ones_like(p.data)}" + ) + + +@pytest.mark.parametrize("num_critics", [1, 3]) +def test_sac_policy_with_critics_number_of_heads(num_critics: int): + batch_size = 2 + action_dim = 10 + state_dim = 10 + config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) + config.num_critics = num_critics + + policy = SACPolicy(config=config) + policy.train() + + assert len(policy.critic_ensemble.critics) == num_critics + + batch = create_train_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim, action_dim=action_dim + ) + + policy.train() + + optimizers = make_optimizers(policy) + + cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] + assert cirtic_loss.item() is not None + assert cirtic_loss.shape == () + cirtic_loss.backward() + optimizers["critic"].step() + + +def test_sac_policy_save_and_load(tmp_path): + root = tmp_path / "test_sac_save_and_load" + + state_dim = 10 + action_dim = 10 + batch_size = 2 + + config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) + policy = SACPolicy(config=config) + policy.eval() + policy.save_pretrained(root) + loaded_policy = SACPolicy.from_pretrained(root, config=config) + loaded_policy.eval() + + batch = create_default_train_batch(batch_size=1, state_dim=10, action_dim=10) + + with torch.no_grad(): + with seeded_context(12): + # Collect policy values before saving + cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] + actor_loss = policy.forward(batch, model="actor")["loss_actor"] + temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"] + + observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + actions = policy.select_action(observation_batch) + + with seeded_context(12): + # Collect policy values after loading + loaded_cirtic_loss = loaded_policy.forward(batch, model="critic")["loss_critic"] + loaded_actor_loss = loaded_policy.forward(batch, model="actor")["loss_actor"] + loaded_temperature_loss = loaded_policy.forward(batch, model="temperature")["loss_temperature"] + + loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + loaded_actions = loaded_policy.select_action(loaded_observation_batch) + + assert policy.state_dict().keys() == loaded_policy.state_dict().keys() + for k in policy.state_dict(): + assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6) + + # Compare values before and after saving and loading + # They should be the same + assert torch.allclose(cirtic_loss, loaded_cirtic_loss) + assert torch.allclose(actor_loss, loaded_actor_loss) + assert torch.allclose(temperature_loss, loaded_temperature_loss) + assert torch.allclose(actions, loaded_actions) diff --git a/tests/rl/test_actor.py b/tests/rl/test_actor.py new file mode 100644 index 0000000000..0cf6a8f644 --- /dev/null +++ b/tests/rl/test_actor.py @@ -0,0 +1,208 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from concurrent import futures +from unittest.mock import patch + +import pytest +import torch +from torch.multiprocessing import Event, Queue + +from lerobot.common.utils.transition import Transition +from tests.utils import require_package + + +def create_learner_service_stub(): + import grpc + + from lerobot.common.transport import services_pb2, services_pb2_grpc + + class MockLearnerService(services_pb2_grpc.LearnerServiceServicer): + def __init__(self): + self.ready_call_count = 0 + self.should_fail = False + + def Ready(self, request, context): # noqa: N802 + self.ready_call_count += 1 + if self.should_fail: + context.set_code(grpc.StatusCode.UNAVAILABLE) + context.set_details("Service unavailable") + raise grpc.RpcError("Service unavailable") + return services_pb2.Empty() + + """Fixture to start a LearnerService gRPC server and provide a connected stub.""" + + servicer = MockLearnerService() + + # Create a gRPC server and add our servicer to it. + server = grpc.server(futures.ThreadPoolExecutor(max_workers=4)) + services_pb2_grpc.add_LearnerServiceServicer_to_server(servicer, server) + port = server.add_insecure_port("[::]:0") # bind to a free port chosen by OS + server.start() # start the server (non-blocking call):contentReference[oaicite:1]{index=1} + + # Create a client channel and stub connected to the server's port. + channel = grpc.insecure_channel(f"localhost:{port}") + return services_pb2_grpc.LearnerServiceStub(channel), servicer, channel, server + + +def close_service_stub(channel, server): + channel.close() + server.stop(None) + + +@require_package("grpc") +def test_establish_learner_connection_success(): + from lerobot.scripts.rl.actor import establish_learner_connection + + """Test successful connection establishment.""" + stub, _servicer, channel, server = create_learner_service_stub() + + shutdown_event = Event() + + # Test successful connection + result = establish_learner_connection(stub, shutdown_event, attempts=5) + + assert result is True + + close_service_stub(channel, server) + + +@require_package("grpc") +def test_establish_learner_connection_failure(): + from lerobot.scripts.rl.actor import establish_learner_connection + + """Test connection failure.""" + stub, servicer, channel, server = create_learner_service_stub() + servicer.should_fail = True + + shutdown_event = Event() + + # Test failed connection + with patch("time.sleep"): # Speed up the test + result = establish_learner_connection(stub, shutdown_event, attempts=2) + + assert result is False + + close_service_stub(channel, server) + + +@require_package("grpc") +def test_push_transitions_to_transport_queue(): + from lerobot.common.transport.utils import bytes_to_transitions + from lerobot.scripts.rl.actor import push_transitions_to_transport_queue + from tests.transport.test_transport_utils import assert_transitions_equal + + """Test pushing transitions to transport queue.""" + # Create mock transitions + transitions = [] + for i in range(3): + transition = Transition( + state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, + action=torch.randn(5), + reward=torch.tensor(1.0 + i), + done=torch.tensor(False), + truncated=torch.tensor(False), + next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, + complementary_info={"step": torch.tensor(i)}, + ) + transitions.append(transition) + + transitions_queue = Queue() + + # Test pushing transitions + push_transitions_to_transport_queue(transitions, transitions_queue) + + # Verify the data can be retrieved + serialized_data = transitions_queue.get() + assert isinstance(serialized_data, bytes) + deserialized_transitions = bytes_to_transitions(serialized_data) + assert len(deserialized_transitions) == len(transitions) + for i, deserialized_transition in enumerate(deserialized_transitions): + assert_transitions_equal(deserialized_transition, transitions[i]) + + +@require_package("grpc") +@pytest.mark.timeout(3) # force cross-platform watchdog +def test_transitions_stream(): + from lerobot.scripts.rl.actor import transitions_stream + + """Test transitions stream functionality.""" + shutdown_event = Event() + transitions_queue = Queue() + + # Add test data to queue + test_data = [b"transition_data_1", b"transition_data_2", b"transition_data_3"] + for data in test_data: + transitions_queue.put(data) + + # Collect streamed data + streamed_data = [] + stream_generator = transitions_stream(shutdown_event, transitions_queue, 0.1) + + # Process a few items + for i, message in enumerate(stream_generator): + streamed_data.append(message) + if i >= len(test_data) - 1: + shutdown_event.set() + break + + # Verify we got messages + assert len(streamed_data) == len(test_data) + assert streamed_data[0].data == b"transition_data_1" + assert streamed_data[1].data == b"transition_data_2" + assert streamed_data[2].data == b"transition_data_3" + + +@require_package("grpc") +@pytest.mark.timeout(3) # force cross-platform watchdog +def test_interactions_stream(): + from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes + from lerobot.scripts.rl.actor import interactions_stream + + """Test interactions stream functionality.""" + shutdown_event = Event() + interactions_queue = Queue() + + # Create test interaction data (similar structure to what would be sent) + test_interactions = [ + {"episode_reward": 10.5, "step": 1, "policy_fps": 30.2}, + {"episode_reward": 15.2, "step": 2, "policy_fps": 28.7}, + {"episode_reward": 8.7, "step": 3, "policy_fps": 29.1}, + ] + + # Serialize the interaction data as it would be in practice + test_data = [ + interactions_queue.put(python_object_to_bytes(interaction)) for interaction in test_interactions + ] + + # Collect streamed data + streamed_data = [] + stream_generator = interactions_stream(shutdown_event, interactions_queue, 0.1) + + # Process the items + for i, message in enumerate(stream_generator): + streamed_data.append(message) + if i >= len(test_data) - 1: + shutdown_event.set() + break + + # Verify we got messages + assert len(streamed_data) == len(test_data) + + # Verify the messages can be deserialized back to original data + for i, message in enumerate(streamed_data): + deserialized_interaction = bytes_to_python_object(message.data) + assert deserialized_interaction == test_interactions[i] diff --git a/tests/rl/test_actor_learner.py b/tests/rl/test_actor_learner.py new file mode 100644 index 0000000000..cb72da7e40 --- /dev/null +++ b/tests/rl/test_actor_learner.py @@ -0,0 +1,297 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import socket +import threading +import time + +import pytest +import torch +from torch.multiprocessing import Event, Queue + +from lerobot.common.policies.sac.configuration_sac import SACConfig +from lerobot.common.utils.transition import Transition +from lerobot.configs.train import TrainRLServerPipelineConfig +from tests.utils import require_package + + +def create_test_transitions(count: int = 3) -> list[Transition]: + """Create test transitions for integration testing.""" + transitions = [] + for i in range(count): + transition = Transition( + state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, + action=torch.randn(5), + reward=torch.tensor(1.0 + i), + done=torch.tensor(i == count - 1), # Last transition is done + truncated=torch.tensor(False), + next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, + complementary_info={"step": torch.tensor(i), "episode_id": i // 2}, + ) + transitions.append(transition) + return transitions + + +def create_test_interactions(count: int = 3) -> list[dict]: + """Create test interactions for integration testing.""" + interactions = [] + for i in range(count): + interaction = { + "episode_reward": 10.0 + i * 5, + "step": i * 100, + "policy_fps": 30.0 + i, + "intervention_rate": 0.1 * i, + "episode_length": 200 + i * 50, + } + interactions.append(interaction) + return interactions + + +def find_free_port(): + """Finds a free port on the local machine.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) # Bind to port 0 to let the OS choose a free port + s.listen(1) + port = s.getsockname()[1] + return port + + +@pytest.fixture +def cfg(): + cfg = TrainRLServerPipelineConfig() + + port = find_free_port() + + policy_cfg = SACConfig() + policy_cfg.actor_learner_config.learner_host = "127.0.0.1" + policy_cfg.actor_learner_config.learner_port = port + policy_cfg.concurrency.actor = "threads" + policy_cfg.concurrency.learner = "threads" + policy_cfg.actor_learner_config.queue_get_timeout = 0.1 + + cfg.policy = policy_cfg + + return cfg + + +@require_package("grpc") +@pytest.mark.timeout(10) # force cross-platform watchdog +def test_end_to_end_transitions_flow(cfg): + from lerobot.common.transport.utils import bytes_to_transitions + from lerobot.scripts.rl.actor import ( + establish_learner_connection, + learner_service_client, + push_transitions_to_transport_queue, + send_transitions, + ) + from lerobot.scripts.rl.learner import start_learner + from tests.transport.test_transport_utils import assert_transitions_equal + + """Test complete transitions flow from actor to learner.""" + transitions_actor_queue = Queue() + transitions_learner_queue = Queue() + + interactions_queue = Queue() + parameters_queue = Queue() + shutdown_event = Event() + + learner_thread = threading.Thread( + target=start_learner, + args=(parameters_queue, transitions_learner_queue, interactions_queue, shutdown_event, cfg), + ) + learner_thread.start() + + policy_cfg = cfg.policy + learner_client, channel = learner_service_client( + host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port + ) + + assert establish_learner_connection(learner_client, shutdown_event, attempts=5) + + send_transitions_thread = threading.Thread( + target=send_transitions, args=(cfg, transitions_actor_queue, shutdown_event, learner_client, channel) + ) + send_transitions_thread.start() + + input_transitions = create_test_transitions(count=5) + + push_transitions_to_transport_queue(input_transitions, transitions_actor_queue) + + # Wait for learner to start + time.sleep(0.1) + + shutdown_event.set() + + # Wait for learner to receive transitions + learner_thread.join() + send_transitions_thread.join() + channel.close() + + received_transitions = [] + while not transitions_learner_queue.empty(): + received_transitions.extend(bytes_to_transitions(transitions_learner_queue.get())) + + assert len(received_transitions) == len(input_transitions) + for i, transition in enumerate(received_transitions): + assert_transitions_equal(transition, input_transitions[i]) + + +@require_package("grpc") +@pytest.mark.timeout(10) +def test_end_to_end_interactions_flow(cfg): + from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes + from lerobot.scripts.rl.actor import ( + establish_learner_connection, + learner_service_client, + send_interactions, + ) + from lerobot.scripts.rl.learner import start_learner + + """Test complete interactions flow from actor to learner.""" + # Queues for actor-learner communication + interactions_actor_queue = Queue() + interactions_learner_queue = Queue() + + # Other queues required by the learner + parameters_queue = Queue() + transitions_learner_queue = Queue() + + shutdown_event = Event() + + # Start the learner in a separate thread + learner_thread = threading.Thread( + target=start_learner, + args=(parameters_queue, transitions_learner_queue, interactions_learner_queue, shutdown_event, cfg), + ) + learner_thread.start() + + # Establish connection from actor to learner + policy_cfg = cfg.policy + learner_client, channel = learner_service_client( + host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port + ) + + assert establish_learner_connection(learner_client, shutdown_event, attempts=5) + + # Start the actor's interaction sending process in a separate thread + send_interactions_thread = threading.Thread( + target=send_interactions, + args=(cfg, interactions_actor_queue, shutdown_event, learner_client, channel), + ) + send_interactions_thread.start() + + # Create and push test interactions to the actor's queue + input_interactions = create_test_interactions(count=5) + for interaction in input_interactions: + interactions_actor_queue.put(python_object_to_bytes(interaction)) + + # Wait for the communication to happen + time.sleep(0.1) + + # Signal shutdown and wait for threads to complete + shutdown_event.set() + learner_thread.join() + send_interactions_thread.join() + channel.close() + + # Verify that the learner received the interactions + received_interactions = [] + while not interactions_learner_queue.empty(): + received_interactions.append(bytes_to_python_object(interactions_learner_queue.get())) + + assert len(received_interactions) == len(input_interactions) + + # Sort by a unique key to handle potential reordering in queues + received_interactions.sort(key=lambda x: x["step"]) + input_interactions.sort(key=lambda x: x["step"]) + + for received, expected in zip(received_interactions, input_interactions, strict=False): + assert received == expected + + +@require_package("grpc") +@pytest.mark.parametrize("data_size", ["small", "large"]) +@pytest.mark.timeout(10) +def test_end_to_end_parameters_flow(cfg, data_size): + from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes + from lerobot.scripts.rl.actor import establish_learner_connection, learner_service_client, receive_policy + from lerobot.scripts.rl.learner import start_learner + + """Test complete parameter flow from learner to actor, with small and large data.""" + # Actor's local queue to receive params + parameters_actor_queue = Queue() + # Learner's queue to send params from + parameters_learner_queue = Queue() + + # Other queues required by the learner + transitions_learner_queue = Queue() + interactions_learner_queue = Queue() + + shutdown_event = Event() + + # Start the learner in a separate thread + learner_thread = threading.Thread( + target=start_learner, + args=( + parameters_learner_queue, + transitions_learner_queue, + interactions_learner_queue, + shutdown_event, + cfg, + ), + ) + learner_thread.start() + + # Establish connection from actor to learner + policy_cfg = cfg.policy + learner_client, channel = learner_service_client( + host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port + ) + + assert establish_learner_connection(learner_client, shutdown_event, attempts=5) + + # Start the actor's parameter receiving process in a separate thread + receive_params_thread = threading.Thread( + target=receive_policy, + args=(cfg, parameters_actor_queue, shutdown_event, learner_client, channel), + ) + receive_params_thread.start() + + # Create test parameters based on parametrization + if data_size == "small": + input_params = {"layer.weight": torch.randn(128, 64)} + else: # "large" + # CHUNK_SIZE is 2MB, so this tensor (4MB) will force chunking + input_params = {"large_layer.weight": torch.randn(1024, 1024)} + + # Simulate learner having new parameters to send + parameters_learner_queue.put(state_to_bytes(input_params)) + + # Wait for the actor to receive the parameters + time.sleep(0.1) + + # Signal shutdown and wait for threads to complete + shutdown_event.set() + learner_thread.join() + receive_params_thread.join() + channel.close() + + # Verify that the actor received the parameters correctly + received_params = bytes_to_state_dict(parameters_actor_queue.get()) + + assert received_params.keys() == input_params.keys() + for key in input_params: + assert torch.allclose(received_params[key], input_params[key]) diff --git a/tests/rl/test_learner_service.py b/tests/rl/test_learner_service.py new file mode 100644 index 0000000000..ee9d06e914 --- /dev/null +++ b/tests/rl/test_learner_service.py @@ -0,0 +1,374 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import threading +import time +from concurrent import futures +from multiprocessing import Event, Queue + +import pytest + +from tests.utils import require_package # our gRPC servicer class + + +@pytest.fixture(scope="function") +def learner_service_stub(): + shutdown_event = Event() + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + seconds_between_pushes = 1 + client, channel, server = create_learner_service_stub( + shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes + ) + + yield client # provide the stub to the test function + + close_learner_service_stub(channel, server) + + +@require_package("grpc") +def create_learner_service_stub( + shutdown_event: Event, + parameters_queue: Queue, + transitions_queue: Queue, + interactions_queue: Queue, + seconds_between_pushes: int, + queue_get_timeout: float = 0.1, +): + import grpc + + from lerobot.common.transport import services_pb2_grpc # generated from .proto + from lerobot.scripts.rl.learner_service import LearnerService + + """Fixture to start a LearnerService gRPC server and provide a connected stub.""" + + servicer = LearnerService( + shutdown_event=shutdown_event, + parameters_queue=parameters_queue, + seconds_between_pushes=seconds_between_pushes, + transition_queue=transitions_queue, + interaction_message_queue=interactions_queue, + queue_get_timeout=queue_get_timeout, + ) + + # Create a gRPC server and add our servicer to it. + server = grpc.server(futures.ThreadPoolExecutor(max_workers=4)) + services_pb2_grpc.add_LearnerServiceServicer_to_server(servicer, server) + port = server.add_insecure_port("[::]:0") # bind to a free port chosen by OS + server.start() # start the server (non-blocking call):contentReference[oaicite:1]{index=1} + + # Create a client channel and stub connected to the server's port. + channel = grpc.insecure_channel(f"localhost:{port}") + return services_pb2_grpc.LearnerServiceStub(channel), channel, server + + +@require_package("grpc") +def close_learner_service_stub(channel, server): + channel.close() + server.stop(None) + + +@pytest.mark.timeout(3) # force cross-platform watchdog +def test_ready_method(learner_service_stub): + from lerobot.common.transport import services_pb2 + + """Test the ready method of the UserService.""" + request = services_pb2.Empty() + response = learner_service_stub.Ready(request) + assert response == services_pb2.Empty() + + +@require_package("grpc") +@pytest.mark.timeout(3) # force cross-platform watchdog +def test_send_interactions(): + from lerobot.common.transport import services_pb2 + + shutdown_event = Event() + + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + seconds_between_pushes = 1 + client, channel, server = create_learner_service_stub( + shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes + ) + + list_of_interaction_messages = [ + services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"1"), + services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"2"), + services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"3"), + services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"4"), + services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"5"), + services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"6"), + services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"7"), + services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"8"), + ] + + def mock_intercations_stream(): + yield from list_of_interaction_messages + + return services_pb2.Empty() + + response = client.SendInteractions(mock_intercations_stream()) + assert response == services_pb2.Empty() + + close_learner_service_stub(channel, server) + + # Extract the data from the interactions queue + interactions = [] + while not interactions_queue.empty(): + interactions.append(interactions_queue.get()) + + assert interactions == [b"123", b"4", b"5", b"678"] + + +@require_package("grpc") +@pytest.mark.timeout(3) # force cross-platform watchdog +def test_send_transitions(): + from lerobot.common.transport import services_pb2 + + """Test the SendTransitions method with various transition data.""" + shutdown_event = Event() + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + seconds_between_pushes = 1 + + client, channel, server = create_learner_service_stub( + shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes + ) + + # Create test transition messages + list_of_transition_messages = [ + services_pb2.Transition( + transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"transition_1" + ), + services_pb2.Transition( + transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"transition_2" + ), + services_pb2.Transition(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"transition_3"), + services_pb2.Transition(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"batch_1"), + services_pb2.Transition(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"batch_2"), + ] + + def mock_transitions_stream(): + yield from list_of_transition_messages + + response = client.SendTransitions(mock_transitions_stream()) + assert response == services_pb2.Empty() + + close_learner_service_stub(channel, server) + + # Extract the data from the transitions queue + transitions = [] + while not transitions_queue.empty(): + transitions.append(transitions_queue.get()) + + # Should have assembled the chunked data + assert transitions == [b"transition_1transition_2transition_3", b"batch_1batch_2"] + + +@require_package("grpc") +@pytest.mark.timeout(3) # force cross-platform watchdog +def test_send_transitions_empty_stream(): + from lerobot.common.transport import services_pb2 + + """Test SendTransitions with empty stream.""" + shutdown_event = Event() + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + seconds_between_pushes = 1 + + client, channel, server = create_learner_service_stub( + shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes + ) + + def empty_stream(): + return iter([]) + + response = client.SendTransitions(empty_stream()) + assert response == services_pb2.Empty() + + close_learner_service_stub(channel, server) + + # Queue should remain empty + assert transitions_queue.empty() + + +@require_package("grpc") +@pytest.mark.timeout(10) # force cross-platform watchdog +def test_stream_parameters(): + import time + + from lerobot.common.transport import services_pb2 + + """Test the StreamParameters method.""" + shutdown_event = Event() + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + seconds_between_pushes = 0.2 # Short delay for testing + + client, channel, server = create_learner_service_stub( + shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes + ) + + # Add test parameters to the queue + test_params = [b"param_batch_1", b"param_batch_2"] + for param in test_params: + parameters_queue.put(param) + + # Start streaming parameters + request = services_pb2.Empty() + stream = client.StreamParameters(request) + + # Collect streamed parameters and timestamps + received_params = [] + timestamps = [] + + for response in stream: + received_params.append(response.data) + timestamps.append(time.time()) + + # We should receive one last item + break + + parameters_queue.put(b"param_batch_3") + + for response in stream: + received_params.append(response.data) + timestamps.append(time.time()) + + # We should receive only one item + break + + shutdown_event.set() + close_learner_service_stub(channel, server) + + assert received_params == [b"param_batch_2", b"param_batch_3"] + + # Check the time difference between the two sends + time_diff = timestamps[1] - timestamps[0] + # Check if the time difference is close to the expected push frequency + assert time_diff == pytest.approx(seconds_between_pushes, abs=0.1) + + +@require_package("grpc") +@pytest.mark.timeout(3) # force cross-platform watchdog +def test_stream_parameters_with_shutdown(): + from lerobot.common.transport import services_pb2 + + """Test StreamParameters handles shutdown gracefully.""" + shutdown_event = Event() + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + seconds_between_pushes = 0.1 + queue_get_timeout = 0.001 + + client, channel, server = create_learner_service_stub( + shutdown_event, + parameters_queue, + transitions_queue, + interactions_queue, + seconds_between_pushes, + queue_get_timeout=queue_get_timeout, + ) + + test_params = [b"param_batch_1", b"stop", b"param_batch_3", b"param_batch_4"] + + # create a thread that will put the parameters in the queue + def producer(): + for param in test_params: + parameters_queue.put(param) + time.sleep(0.1) + + producer_thread = threading.Thread(target=producer) + producer_thread.start() + + # Start streaming + request = services_pb2.Empty() + stream = client.StreamParameters(request) + + # Collect streamed parameters + received_params = [] + + for response in stream: + received_params.append(response.data) + + if response.data == b"stop": + shutdown_event.set() + + producer_thread.join() + close_learner_service_stub(channel, server) + + assert received_params == [b"param_batch_1", b"stop"] + + +@require_package("grpc") +@pytest.mark.timeout(3) # force cross-platform watchdog +def test_stream_parameters_waits_and_retries_on_empty_queue(): + import threading + import time + + from lerobot.common.transport import services_pb2 + + """Test that StreamParameters waits and retries when the queue is empty.""" + shutdown_event = Event() + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + seconds_between_pushes = 0.05 + queue_get_timeout = 0.01 + + client, channel, server = create_learner_service_stub( + shutdown_event, + parameters_queue, + transitions_queue, + interactions_queue, + seconds_between_pushes, + queue_get_timeout=queue_get_timeout, + ) + + request = services_pb2.Empty() + stream = client.StreamParameters(request) + + received_params = [] + + def producer(): + # Let the consumer start and find an empty queue. + # It will wait `seconds_between_pushes` (0.05s), then `get` will timeout after `queue_get_timeout` (0.01s). + # Total time for the first empty loop is > 0.06s. We wait a bit longer to be safe. + time.sleep(0.06) + parameters_queue.put(b"param_after_wait") + time.sleep(0.05) + parameters_queue.put(b"param_after_wait_2") + + producer_thread = threading.Thread(target=producer) + producer_thread.start() + + # The consumer will block here until the producer sends an item. + for response in stream: + received_params.append(response.data) + if response.data == b"param_after_wait_2": + break # We only need one item for this test. + + shutdown_event.set() + producer_thread.join() + close_learner_service_stub(channel, server) + + assert received_params == [b"param_after_wait", b"param_after_wait_2"] diff --git a/tests/transport/test_transport_utils.py b/tests/transport/test_transport_utils.py new file mode 100644 index 0000000000..cf33f52c04 --- /dev/null +++ b/tests/transport/test_transport_utils.py @@ -0,0 +1,571 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +from multiprocessing import Event, Queue +from pickle import UnpicklingError + +import pytest +import torch + +from lerobot.common.utils.transition import Transition +from tests.utils import require_cuda, require_package + + +@require_package("grpc") +def test_bytes_buffer_size_empty_buffer(): + from lerobot.common.transport.utils import bytes_buffer_size + + """Test with an empty buffer.""" + buffer = io.BytesIO() + assert bytes_buffer_size(buffer) == 0 + # Ensure position is reset to beginning + assert buffer.tell() == 0 + + +@require_package("grpc") +def test_bytes_buffer_size_small_buffer(): + from lerobot.common.transport.utils import bytes_buffer_size + + """Test with a small buffer.""" + buffer = io.BytesIO(b"Hello, World!") + assert bytes_buffer_size(buffer) == 13 + assert buffer.tell() == 0 + + +@require_package("grpc") +def test_bytes_buffer_size_large_buffer(): + from lerobot.common.transport.utils import CHUNK_SIZE, bytes_buffer_size + + """Test with a large buffer.""" + data = b"x" * (CHUNK_SIZE * 2 + 1000) + buffer = io.BytesIO(data) + assert bytes_buffer_size(buffer) == len(data) + assert buffer.tell() == 0 + + +@require_package("grpc") +def test_send_bytes_in_chunks_empty_data(): + from lerobot.common.transport.utils import send_bytes_in_chunks, services_pb2 + + """Test sending empty data.""" + message_class = services_pb2.InteractionMessage + chunks = list(send_bytes_in_chunks(b"", message_class)) + assert len(chunks) == 0 + + +@require_package("grpc") +def test_single_chunk_small_data(): + from lerobot.common.transport.utils import send_bytes_in_chunks, services_pb2 + + """Test data that fits in a single chunk.""" + data = b"Some data" + message_class = services_pb2.InteractionMessage + chunks = list(send_bytes_in_chunks(data, message_class)) + + assert len(chunks) == 1 + assert chunks[0].data == b"Some data" + assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END + + +@require_package("grpc") +def test_not_silent_mode(): + from lerobot.common.transport.utils import send_bytes_in_chunks, services_pb2 + + """Test not silent mode.""" + data = b"Some data" + message_class = services_pb2.InteractionMessage + chunks = list(send_bytes_in_chunks(data, message_class, silent=False)) + assert len(chunks) == 1 + assert chunks[0].data == b"Some data" + + +@require_package("grpc") +def test_send_bytes_in_chunks_large_data(): + from lerobot.common.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2 + + """Test sending large data.""" + data = b"x" * (CHUNK_SIZE * 2 + 1000) + message_class = services_pb2.InteractionMessage + chunks = list(send_bytes_in_chunks(data, message_class)) + assert len(chunks) == 3 + assert chunks[0].data == b"x" * CHUNK_SIZE + assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_BEGIN + assert chunks[1].data == b"x" * CHUNK_SIZE + assert chunks[1].transfer_state == services_pb2.TransferState.TRANSFER_MIDDLE + assert chunks[2].data == b"x" * 1000 + assert chunks[2].transfer_state == services_pb2.TransferState.TRANSFER_END + + +@require_package("grpc") +def test_send_bytes_in_chunks_large_data_with_exact_chunk_size(): + from lerobot.common.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2 + + """Test sending large data with exact chunk size.""" + data = b"x" * CHUNK_SIZE + message_class = services_pb2.InteractionMessage + chunks = list(send_bytes_in_chunks(data, message_class)) + assert len(chunks) == 1 + assert chunks[0].data == data + assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END + + +@require_package("grpc") +def test_receive_bytes_in_chunks_empty_data(): + from lerobot.common.transport.utils import receive_bytes_in_chunks + + """Test receiving empty data.""" + queue = Queue() + shutdown_event = Event() + + # Empty iterator + receive_bytes_in_chunks(iter([]), queue, shutdown_event) + + assert queue.empty() + + +@require_package("grpc") +def test_receive_bytes_in_chunks_single_chunk(): + from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2 + + """Test receiving a single chunk message.""" + queue = Queue() + shutdown_event = Event() + + data = b"Single chunk data" + chunks = [ + services_pb2.InteractionMessage(data=data, transfer_state=services_pb2.TransferState.TRANSFER_END) + ] + + receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) + + assert queue.get(timeout=0.01) == data + assert queue.empty() + + +@require_package("grpc") +def test_receive_bytes_in_chunks_single_not_end_chunk(): + from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2 + + """Test receiving a single chunk message.""" + queue = Queue() + shutdown_event = Event() + + data = b"Single chunk data" + chunks = [ + services_pb2.InteractionMessage(data=data, transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE) + ] + + receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) + + assert queue.empty() + + +@require_package("grpc") +def test_receive_bytes_in_chunks_multiple_chunks(): + from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2 + + """Test receiving a multi-chunk message.""" + queue = Queue() + shutdown_event = Event() + + chunks = [ + services_pb2.InteractionMessage( + data=b"First ", transfer_state=services_pb2.TransferState.TRANSFER_BEGIN + ), + services_pb2.InteractionMessage( + data=b"Middle ", transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE + ), + services_pb2.InteractionMessage(data=b"Last", transfer_state=services_pb2.TransferState.TRANSFER_END), + ] + + receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) + + assert queue.get(timeout=0.01) == b"First Middle Last" + assert queue.empty() + + +@require_package("grpc") +def test_receive_bytes_in_chunks_multiple_messages(): + from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2 + + """Test receiving multiple complete messages in sequence.""" + queue = Queue() + shutdown_event = Event() + + chunks = [ + # First message - single chunk + services_pb2.InteractionMessage( + data=b"Message1", transfer_state=services_pb2.TransferState.TRANSFER_END + ), + # Second message - multi chunk + services_pb2.InteractionMessage( + data=b"Start2 ", transfer_state=services_pb2.TransferState.TRANSFER_BEGIN + ), + services_pb2.InteractionMessage( + data=b"Middle2 ", transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE + ), + services_pb2.InteractionMessage(data=b"End2", transfer_state=services_pb2.TransferState.TRANSFER_END), + # Third message - single chunk + services_pb2.InteractionMessage( + data=b"Message3", transfer_state=services_pb2.TransferState.TRANSFER_END + ), + ] + + receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) + + # Should have three messages in queue + assert queue.get(timeout=0.01) == b"Message1" + assert queue.get(timeout=0.01) == b"Start2 Middle2 End2" + assert queue.get(timeout=0.01) == b"Message3" + assert queue.empty() + + +@require_package("grpc") +def test_receive_bytes_in_chunks_shutdown_during_receive(): + from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2 + + """Test that shutdown event stops receiving mid-stream.""" + queue = Queue() + shutdown_event = Event() + shutdown_event.set() + + chunks = [ + services_pb2.InteractionMessage( + data=b"First ", transfer_state=services_pb2.TransferState.TRANSFER_BEGIN + ), + services_pb2.InteractionMessage( + data=b"Middle ", transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE + ), + services_pb2.InteractionMessage(data=b"Last", transfer_state=services_pb2.TransferState.TRANSFER_END), + ] + + receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) + + assert queue.empty() + + +@require_package("grpc") +def test_receive_bytes_in_chunks_only_begin_chunk(): + from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2 + + """Test receiving only a BEGIN chunk without END.""" + queue = Queue() + shutdown_event = Event() + + chunks = [ + services_pb2.InteractionMessage( + data=b"Start", transfer_state=services_pb2.TransferState.TRANSFER_BEGIN + ), + # No END chunk + ] + + receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) + + assert queue.empty() + + +@require_package("grpc") +def test_receive_bytes_in_chunks_missing_begin(): + from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2 + + """Test receiving chunks starting with MIDDLE instead of BEGIN.""" + queue = Queue() + shutdown_event = Event() + + chunks = [ + # Missing BEGIN + services_pb2.InteractionMessage( + data=b"Middle", transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE + ), + services_pb2.InteractionMessage(data=b"End", transfer_state=services_pb2.TransferState.TRANSFER_END), + ] + + receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) + + # The implementation continues from where it is, so we should get partial data + assert queue.get(timeout=0.01) == b"MiddleEnd" + assert queue.empty() + + +# Tests for state_to_bytes and bytes_to_state_dict +@require_package("grpc") +def test_state_to_bytes_empty_dict(): + from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes + + """Test converting empty state dict to bytes.""" + state_dict = {} + data = state_to_bytes(state_dict) + reconstructed = bytes_to_state_dict(data) + assert reconstructed == state_dict + + +@require_package("grpc") +def test_bytes_to_state_dict_empty_data(): + from lerobot.common.transport.utils import bytes_to_state_dict + + """Test converting empty data to state dict.""" + with pytest.raises(EOFError): + bytes_to_state_dict(b"") + + +@require_package("grpc") +def test_state_to_bytes_simple_dict(): + from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes + + """Test converting simple state dict to bytes.""" + state_dict = { + "layer1.weight": torch.randn(10, 5), + "layer1.bias": torch.randn(10), + "layer2.weight": torch.randn(1, 10), + "layer2.bias": torch.randn(1), + } + + data = state_to_bytes(state_dict) + assert isinstance(data, bytes) + assert len(data) > 0 + + reconstructed = bytes_to_state_dict(data) + + assert len(reconstructed) == len(state_dict) + for key in state_dict: + assert key in reconstructed + assert torch.allclose(state_dict[key], reconstructed[key]) + + +@require_package("grpc") +def test_state_to_bytes_various_dtypes(): + from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes + + """Test converting state dict with various tensor dtypes.""" + state_dict = { + "float32": torch.randn(5, 5), + "float64": torch.randn(3, 3).double(), + "int32": torch.randint(0, 100, (4, 4), dtype=torch.int32), + "int64": torch.randint(0, 100, (2, 2), dtype=torch.int64), + "bool": torch.tensor([True, False, True]), + "uint8": torch.randint(0, 255, (3, 3), dtype=torch.uint8), + } + + data = state_to_bytes(state_dict) + reconstructed = bytes_to_state_dict(data) + + for key in state_dict: + assert reconstructed[key].dtype == state_dict[key].dtype + if state_dict[key].dtype == torch.bool: + assert torch.equal(state_dict[key], reconstructed[key]) + else: + assert torch.allclose(state_dict[key], reconstructed[key]) + + +@require_package("grpc") +def test_bytes_to_state_dict_invalid_data(): + from lerobot.common.transport.utils import bytes_to_state_dict + + """Test bytes_to_state_dict with invalid data.""" + with pytest.raises(UnpicklingError): + bytes_to_state_dict(b"This is not a valid torch save file") + + +@require_cuda +@require_package("grpc") +def test_state_to_bytes_various_dtypes_cuda(): + from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes + + """Test converting state dict with various tensor dtypes.""" + state_dict = { + "float32": torch.randn(5, 5).cuda(), + "float64": torch.randn(3, 3).double().cuda(), + "int32": torch.randint(0, 100, (4, 4), dtype=torch.int32).cuda(), + "int64": torch.randint(0, 100, (2, 2), dtype=torch.int64).cuda(), + "bool": torch.tensor([True, False, True]), + "uint8": torch.randint(0, 255, (3, 3), dtype=torch.uint8), + } + + data = state_to_bytes(state_dict) + reconstructed = bytes_to_state_dict(data) + + for key in state_dict: + assert reconstructed[key].dtype == state_dict[key].dtype + if state_dict[key].dtype == torch.bool: + assert torch.equal(state_dict[key], reconstructed[key]) + else: + assert torch.allclose(state_dict[key], reconstructed[key]) + + +@require_package("grpc") +def test_python_object_to_bytes_none(): + from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes + + """Test converting None to bytes.""" + obj = None + data = python_object_to_bytes(obj) + reconstructed = bytes_to_python_object(data) + assert reconstructed is None + + +@pytest.mark.parametrize( + "obj", + [ + 42, + -123, + 3.14159, + -2.71828, + "Hello, World!", + "Unicode: 你好世界 🌍", + True, + False, + b"byte string", + [], + [1, 2, 3], + [1, "two", 3.0, True, None], + {}, + {"key": "value", "number": 123, "nested": {"a": 1}}, + (), + (1, 2, 3), + ], +) +@require_package("grpc") +def test_python_object_to_bytes_simple_types(obj): + from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes + + """Test converting simple Python types.""" + data = python_object_to_bytes(obj) + reconstructed = bytes_to_python_object(data) + assert reconstructed == obj + assert type(reconstructed) is type(obj) + + +@require_package("grpc") +def test_python_object_to_bytes_with_tensors(): + from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes + + """Test converting objects containing PyTorch tensors.""" + obj = { + "tensor": torch.randn(5, 5), + "list_with_tensor": [1, 2, torch.randn(3, 3), "string"], + "nested": { + "tensor1": torch.randn(2, 2), + "tensor2": torch.tensor([1, 2, 3]), + }, + } + + data = python_object_to_bytes(obj) + reconstructed = bytes_to_python_object(data) + + assert torch.allclose(obj["tensor"], reconstructed["tensor"]) + assert reconstructed["list_with_tensor"][0] == 1 + assert reconstructed["list_with_tensor"][3] == "string" + assert torch.allclose(obj["list_with_tensor"][2], reconstructed["list_with_tensor"][2]) + assert torch.allclose(obj["nested"]["tensor1"], reconstructed["nested"]["tensor1"]) + assert torch.equal(obj["nested"]["tensor2"], reconstructed["nested"]["tensor2"]) + + +@require_package("grpc") +def test_transitions_to_bytes_empty_list(): + from lerobot.common.transport.utils import bytes_to_transitions, transitions_to_bytes + + """Test converting empty transitions list.""" + transitions = [] + data = transitions_to_bytes(transitions) + reconstructed = bytes_to_transitions(data) + assert reconstructed == transitions + assert isinstance(reconstructed, list) + + +@require_package("grpc") +def test_transitions_to_bytes_single_transition(): + from lerobot.common.transport.utils import bytes_to_transitions, transitions_to_bytes + + """Test converting a single transition.""" + transition = Transition( + state={"image": torch.randn(3, 64, 64), "state": torch.randn(10)}, + action=torch.randn(5), + reward=torch.tensor(1.5), + done=torch.tensor(False), + next_state={"image": torch.randn(3, 64, 64), "state": torch.randn(10)}, + ) + + transitions = [transition] + data = transitions_to_bytes(transitions) + reconstructed = bytes_to_transitions(data) + + assert len(reconstructed) == 1 + + assert_transitions_equal(transitions[0], reconstructed[0]) + + +@require_package("grpc") +def assert_transitions_equal(t1: Transition, t2: Transition): + """Helper to assert two transitions are equal.""" + assert_observation_equal(t1["state"], t2["state"]) + assert torch.allclose(t1["action"], t2["action"]) + assert torch.allclose(t1["reward"], t2["reward"]) + assert torch.equal(t1["done"], t2["done"]) + assert_observation_equal(t1["next_state"], t2["next_state"]) + + +@require_package("grpc") +def assert_observation_equal(o1: dict, o2: dict): + """Helper to assert two observations are equal.""" + assert set(o1.keys()) == set(o2.keys()) + for key in o1: + assert torch.allclose(o1[key], o2[key]) + + +@require_package("grpc") +def test_transitions_to_bytes_multiple_transitions(): + from lerobot.common.transport.utils import bytes_to_transitions, transitions_to_bytes + + """Test converting multiple transitions.""" + transitions = [] + for i in range(5): + transition = Transition( + state={"data": torch.randn(10)}, + action=torch.randn(3), + reward=torch.tensor(float(i)), + done=torch.tensor(i == 4), + next_state={"data": torch.randn(10)}, + ) + transitions.append(transition) + + data = transitions_to_bytes(transitions) + reconstructed = bytes_to_transitions(data) + + assert len(reconstructed) == len(transitions) + for original, reconstructed_item in zip(transitions, reconstructed, strict=False): + assert_transitions_equal(original, reconstructed_item) + + +@require_package("grpc") +def test_receive_bytes_in_chunks_unknown_state(): + from lerobot.common.transport.utils import receive_bytes_in_chunks + + """Test receive_bytes_in_chunks with an unknown transfer state.""" + + # Mock the gRPC message object, which has `transfer_state` and `data` attributes. + class MockMessage: + def __init__(self, transfer_state, data): + self.transfer_state = transfer_state + self.data = data + + # 10 is not a valid TransferState enum value + bad_iterator = [MockMessage(transfer_state=10, data=b"bad_data")] + output_queue = Queue() + shutdown_event = Event() + + with pytest.raises(ValueError, match="Received unknown transfer state"): + receive_bytes_in_chunks(bad_iterator, output_queue, shutdown_event) diff --git a/tests/utils/test_process.py b/tests/utils/test_process.py new file mode 100644 index 0000000000..054a8593a5 --- /dev/null +++ b/tests/utils/test_process.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing +import os +import signal +import threading +from unittest.mock import patch + +import pytest + +from lerobot.common.utils.process import ProcessSignalHandler + + +# Fixture to reset shutdown_event_counter and original signal handlers before and after each test +@pytest.fixture(autouse=True) +def reset_globals_and_handlers(): + # Store original signal handlers + original_handlers = { + sig: signal.getsignal(sig) + for sig in [signal.SIGINT, signal.SIGTERM, signal.SIGHUP, signal.SIGQUIT] + if hasattr(signal, sig.name) + } + + yield + + # Restore original signal handlers + for sig, handler in original_handlers.items(): + signal.signal(sig, handler) + + +def test_setup_process_handlers_event_with_threads(): + """Test that setup_process_handlers returns the correct event type.""" + handler = ProcessSignalHandler(use_threads=True) + shutdown_event = handler.shutdown_event + assert isinstance(shutdown_event, threading.Event), "Should be a threading.Event" + assert not shutdown_event.is_set(), "Event should initially be unset" + + +def test_setup_process_handlers_event_with_processes(): + """Test that setup_process_handlers returns the correct event type.""" + handler = ProcessSignalHandler(use_threads=False) + shutdown_event = handler.shutdown_event + assert isinstance(shutdown_event, type(multiprocessing.Event())), "Should be a multiprocessing.Event" + assert not shutdown_event.is_set(), "Event should initially be unset" + + +@pytest.mark.parametrize("use_threads", [True, False]) +@pytest.mark.parametrize( + "sig", + [ + signal.SIGINT, + signal.SIGTERM, + # SIGHUP and SIGQUIT are not reliably available on all platforms (e.g. Windows) + pytest.param( + signal.SIGHUP, + marks=pytest.mark.skipif(not hasattr(signal, "SIGHUP"), reason="SIGHUP not available"), + ), + pytest.param( + signal.SIGQUIT, + marks=pytest.mark.skipif(not hasattr(signal, "SIGQUIT"), reason="SIGQUIT not available"), + ), + ], +) +def test_signal_handler_sets_event(use_threads, sig): + """Test that the signal handler sets the event on receiving a signal.""" + handler = ProcessSignalHandler(use_threads=use_threads) + shutdown_event = handler.shutdown_event + + assert handler.counter == 0 + + os.kill(os.getpid(), sig) + + # In some environments, the signal might take a moment to be handled. + shutdown_event.wait(timeout=1.0) + + assert shutdown_event.is_set(), f"Event should be set after receiving signal {sig}" + + # Ensure the internal counter was incremented + assert handler.counter == 1 + + +@pytest.mark.parametrize("use_threads", [True, False]) +@patch("sys.exit") +def test_force_shutdown_on_second_signal(mock_sys_exit, use_threads): + """Test that a second signal triggers a force shutdown.""" + handler = ProcessSignalHandler(use_threads=use_threads) + + os.kill(os.getpid(), signal.SIGINT) + # Give a moment for the first signal to be processed + import time + + time.sleep(0.1) + os.kill(os.getpid(), signal.SIGINT) + + time.sleep(0.1) + + assert handler.counter == 2 + mock_sys_exit.assert_called_once_with(1) diff --git a/tests/utils/test_queue.py b/tests/utils/test_queue.py new file mode 100644 index 0000000000..863231e82b --- /dev/null +++ b/tests/utils/test_queue.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +import time +from queue import Queue + +from lerobot.common.utils.queue import get_last_item_from_queue + + +def test_get_last_item_single_item(): + """Test getting the last item when queue has only one item.""" + queue = Queue() + queue.put("single_item") + + result = get_last_item_from_queue(queue) + + assert result == "single_item" + assert queue.empty() + + +def test_get_last_item_multiple_items(): + """Test getting the last item when queue has multiple items.""" + queue = Queue() + items = ["first", "second", "third", "fourth", "last"] + + for item in items: + queue.put(item) + + result = get_last_item_from_queue(queue) + + assert result == "last" + assert queue.empty() + + +def test_get_last_item_different_types(): + """Test with different data types in the queue.""" + queue = Queue() + items = [1, 2.5, "string", {"key": "value"}, [1, 2, 3], ("tuple", "data")] + + for item in items: + queue.put(item) + + result = get_last_item_from_queue(queue) + + assert result == ("tuple", "data") + assert queue.empty() + + +def test_get_last_item_maxsize_queue(): + """Test with a queue that has a maximum size.""" + queue = Queue(maxsize=5) + + # Fill the queue + for i in range(5): + queue.put(i) + + # Give the queue time to fill + time.sleep(0.1) + + result = get_last_item_from_queue(queue) + + assert result == 4 + assert queue.empty() + + +def test_get_last_item_with_none_values(): + """Test with None values in the queue.""" + queue = Queue() + items = [1, None, 2, None, 3] + + for item in items: + queue.put(item) + + # Give the queue time to fill + time.sleep(0.1) + + result = get_last_item_from_queue(queue) + + assert result == 3 + assert queue.empty() + + +def test_get_last_item_blocking_timeout(): + """Test get_last_item_from_queue returns None on timeout.""" + queue = Queue() + result = get_last_item_from_queue(queue, block=True, timeout=0.1) + assert result is None + + +def test_get_last_item_non_blocking_empty(): + """Test get_last_item_from_queue with block=False on an empty queue returns None.""" + queue = Queue() + result = get_last_item_from_queue(queue, block=False) + assert result is None + + +def test_get_last_item_non_blocking_success(): + """Test get_last_item_from_queue with block=False on a non-empty queue.""" + queue = Queue() + items = ["first", "second", "last"] + for item in items: + queue.put(item) + + # Give the queue time to fill + time.sleep(0.1) + + result = get_last_item_from_queue(queue, block=False) + assert result == "last" + assert queue.empty() + + +def test_get_last_item_blocking_waits_for_item(): + """Test that get_last_item_from_queue waits for an item if block=True.""" + queue = Queue() + result = [] + + def producer(): + queue.put("item1") + queue.put("item2") + + def consumer(): + # This will block until the producer puts the first item + item = get_last_item_from_queue(queue, block=True, timeout=0.2) + result.append(item) + + producer_thread = threading.Thread(target=producer) + consumer_thread = threading.Thread(target=consumer) + + producer_thread.start() + consumer_thread.start() + + producer_thread.join() + consumer_thread.join() + + assert result == ["item2"] + assert queue.empty() diff --git a/tests/utils/test_replay_buffer.py b/tests/utils/test_replay_buffer.py new file mode 100644 index 0000000000..f7a055b20f --- /dev/null +++ b/tests/utils/test_replay_buffer.py @@ -0,0 +1,682 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from typing import Callable + +import pytest +import torch + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.utils.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized +from tests.fixtures.constants import DUMMY_REPO_ID + + +def state_dims() -> list[str]: + return ["observation.image", "observation.state"] + + +@pytest.fixture +def replay_buffer() -> ReplayBuffer: + return create_empty_replay_buffer() + + +def clone_state(state: dict) -> dict: + return {k: v.clone() for k, v in state.items()} + + +def create_empty_replay_buffer( + optimize_memory: bool = False, + use_drq: bool = False, + image_augmentation_function: Callable | None = None, +) -> ReplayBuffer: + buffer_capacity = 10 + device = "cpu" + return ReplayBuffer( + buffer_capacity, + device, + state_dims(), + optimize_memory=optimize_memory, + use_drq=use_drq, + image_augmentation_function=image_augmentation_function, + ) + + +def create_random_image() -> torch.Tensor: + return torch.rand(3, 84, 84) + + +def create_dummy_transition() -> dict: + return { + "observation.image": create_random_image(), + "action": torch.randn(4), + "reward": torch.tensor(1.0), + "observation.state": torch.randn( + 10, + ), + "done": torch.tensor(False), + "truncated": torch.tensor(False), + "complementary_info": {}, + } + + +def create_dataset_from_replay_buffer(tmp_path) -> tuple[LeRobotDataset, ReplayBuffer]: + dummy_state_1 = create_dummy_state() + dummy_action_1 = create_dummy_action() + + dummy_state_2 = create_dummy_state() + dummy_action_2 = create_dummy_action() + + dummy_state_3 = create_dummy_state() + dummy_action_3 = create_dummy_action() + + dummy_state_4 = create_dummy_state() + dummy_action_4 = create_dummy_action() + + replay_buffer = create_empty_replay_buffer() + replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False) + replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False) + replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True) + replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True) + + root = tmp_path / "test" + return (replay_buffer.to_lerobot_dataset(DUMMY_REPO_ID, root=root), replay_buffer) + + +def create_dummy_state() -> dict: + return { + "observation.image": create_random_image(), + "observation.state": torch.randn( + 10, + ), + } + + +def get_tensor_memory_consumption(tensor): + return tensor.nelement() * tensor.element_size() + + +def get_tensors_memory_consumption(obj, visited_addresses): + total_size = 0 + + address = id(obj) + if address in visited_addresses: + return 0 + + visited_addresses.add(address) + + if isinstance(obj, torch.Tensor): + return get_tensor_memory_consumption(obj) + elif isinstance(obj, (list, tuple)): + for item in obj: + total_size += get_tensors_memory_consumption(item, visited_addresses) + elif isinstance(obj, dict): + for value in obj.values(): + total_size += get_tensors_memory_consumption(value, visited_addresses) + elif hasattr(obj, "__dict__"): + # It's an object, we need to get the size of the attributes + for _, attr in vars(obj).items(): + total_size += get_tensors_memory_consumption(attr, visited_addresses) + + return total_size + + +def get_object_memory(obj): + # Track visited addresses to avoid infinite loops + # and cases when two properties point to the same object + visited_addresses = set() + + # Get the size of the object in bytes + total_size = sys.getsizeof(obj) + + # Get the size of the tensor attributes + total_size += get_tensors_memory_consumption(obj, visited_addresses) + + return total_size + + +def create_dummy_action() -> torch.Tensor: + return torch.randn(4) + + +def dict_properties() -> list: + return ["state", "next_state"] + + +@pytest.fixture +def dummy_state() -> dict: + return create_dummy_state() + + +@pytest.fixture +def next_dummy_state() -> dict: + return create_dummy_state() + + +@pytest.fixture +def dummy_action() -> torch.Tensor: + return torch.randn(4) + + +def test_empty_buffer_sample_raises_error(replay_buffer): + assert len(replay_buffer) == 0, "Replay buffer should be empty." + assert replay_buffer.capacity == 10, "Replay buffer capacity should be 10." + with pytest.raises(RuntimeError, match="Cannot sample from an empty buffer"): + replay_buffer.sample(1) + + +def test_zero_capacity_buffer_raises_error(): + with pytest.raises(ValueError, match="Capacity must be greater than 0."): + ReplayBuffer(0, "cpu", ["observation", "next_observation"]) + + +def test_add_transition(replay_buffer, dummy_state, dummy_action): + replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False) + assert len(replay_buffer) == 1, "Replay buffer should have one transition after adding." + assert torch.equal(replay_buffer.actions[0], dummy_action), ( + "Action should be equal to the first transition." + ) + assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the first transition." + assert not replay_buffer.dones[0], "Done should be False for the first transition." + assert not replay_buffer.truncateds[0], "Truncated should be False for the first transition." + + for dim in state_dims(): + assert torch.equal(replay_buffer.states[dim][0], dummy_state[dim]), ( + "Observation should be equal to the first transition." + ) + assert torch.equal(replay_buffer.next_states[dim][0], dummy_state[dim]), ( + "Next observation should be equal to the first transition." + ) + + +def test_add_over_capacity(): + replay_buffer = ReplayBuffer(2, "cpu", ["observation", "next_observation"]) + dummy_state_1 = create_dummy_state() + dummy_action_1 = create_dummy_action() + + dummy_state_2 = create_dummy_state() + dummy_action_2 = create_dummy_action() + + dummy_state_3 = create_dummy_state() + dummy_action_3 = create_dummy_action() + + replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False) + replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False) + replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True) + + assert len(replay_buffer) == 2, "Replay buffer should have 2 transitions after adding 3." + + for dim in state_dims(): + assert torch.equal(replay_buffer.states[dim][0], dummy_state_3[dim]), ( + "Observation should be equal to the first transition." + ) + assert torch.equal(replay_buffer.next_states[dim][0], dummy_state_3[dim]), ( + "Next observation should be equal to the first transition." + ) + + assert torch.equal(replay_buffer.actions[0], dummy_action_3), ( + "Action should be equal to the last transition." + ) + assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the last transition." + assert replay_buffer.dones[0], "Done should be True for the first transition." + assert replay_buffer.truncateds[0], "Truncated should be True for the first transition." + + +def test_sample_from_empty_buffer(replay_buffer): + with pytest.raises(RuntimeError, match="Cannot sample from an empty buffer"): + replay_buffer.sample(1) + + +def test_sample_with_1_transition(replay_buffer, dummy_state, next_dummy_state, dummy_action): + replay_buffer.add(dummy_state, dummy_action, 1.0, next_dummy_state, False, False) + got_batch_transition = replay_buffer.sample(1) + + expected_batch_transition = BatchTransition( + state=clone_state(dummy_state), + action=dummy_action.clone(), + reward=1.0, + next_state=clone_state(next_dummy_state), + done=False, + truncated=False, + ) + + for buffer_property in dict_properties(): + for k, v in expected_batch_transition[buffer_property].items(): + got_state = got_batch_transition[buffer_property][k] + + assert got_state.shape[0] == 1, f"{k} should have 1 transition." + assert got_state.device.type == "cpu", f"{k} should be on cpu." + + assert torch.equal(got_state[0], v), f"{k} should be equal to the expected batch transition." + + for key, _value in expected_batch_transition.items(): + if key in dict_properties(): + continue + + got_value = got_batch_transition[key] + + v_tensor = expected_batch_transition[key] + if not isinstance(v_tensor, torch.Tensor): + v_tensor = torch.tensor(v_tensor) + + assert got_value.shape[0] == 1, f"{key} should have 1 transition." + assert got_value.device.type == "cpu", f"{key} should be on cpu." + assert torch.equal(got_value[0], v_tensor), f"{key} should be equal to the expected batch transition." + + +def test_sample_with_batch_bigger_than_buffer_size( + replay_buffer, dummy_state, next_dummy_state, dummy_action +): + replay_buffer.add(dummy_state, dummy_action, 1.0, next_dummy_state, False, False) + got_batch_transition = replay_buffer.sample(10) + + expected_batch_transition = BatchTransition( + state=dummy_state, + action=dummy_action, + reward=1.0, + next_state=next_dummy_state, + done=False, + truncated=False, + ) + + for buffer_property in dict_properties(): + for k in expected_batch_transition[buffer_property]: + got_state = got_batch_transition[buffer_property][k] + + assert got_state.shape[0] == 1, f"{k} should have 1 transition." + + for key in expected_batch_transition: + if key in dict_properties(): + continue + + got_value = got_batch_transition[key] + assert got_value.shape[0] == 1, f"{key} should have 1 transition." + + +def test_sample_batch(replay_buffer): + dummy_state_1 = create_dummy_state() + dummy_action_1 = create_dummy_action() + + dummy_state_2 = create_dummy_state() + dummy_action_2 = create_dummy_action() + + dummy_state_3 = create_dummy_state() + dummy_action_3 = create_dummy_action() + + dummy_state_4 = create_dummy_state() + dummy_action_4 = create_dummy_action() + + replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False) + replay_buffer.add(dummy_state_2, dummy_action_2, 2.0, dummy_state_2, False, False) + replay_buffer.add(dummy_state_3, dummy_action_3, 3.0, dummy_state_3, True, True) + replay_buffer.add(dummy_state_4, dummy_action_4, 4.0, dummy_state_4, True, True) + + dummy_states = [dummy_state_1, dummy_state_2, dummy_state_3, dummy_state_4] + dummy_actions = [dummy_action_1, dummy_action_2, dummy_action_3, dummy_action_4] + + got_batch_transition = replay_buffer.sample(3) + + for buffer_property in dict_properties(): + for k in got_batch_transition[buffer_property]: + got_state = got_batch_transition[buffer_property][k] + + assert got_state.shape[0] == 3, f"{k} should have 3 transition." + + for got_state_item in got_state: + assert any(torch.equal(got_state_item, dummy_state[k]) for dummy_state in dummy_states), ( + f"{k} should be equal to one of the dummy states." + ) + + for got_action_item in got_batch_transition["action"]: + assert any(torch.equal(got_action_item, dummy_action) for dummy_action in dummy_actions), ( + "Actions should be equal to the dummy actions." + ) + + for k in got_batch_transition: + if k in dict_properties() or k == "complementary_info": + continue + + got_value = got_batch_transition[k] + assert got_value.shape[0] == 3, f"{k} should have 3 transition." + + +def test_to_lerobot_dataset_with_empty_buffer(replay_buffer): + with pytest.raises(ValueError, match="The replay buffer is empty. Cannot convert to a dataset."): + replay_buffer.to_lerobot_dataset("dummy_repo") + + +def test_to_lerobot_dataset(tmp_path): + ds, buffer = create_dataset_from_replay_buffer(tmp_path) + + assert len(ds) == len(buffer), "Dataset should have the same size as the Replay Buffer" + assert ds.fps == 1, "FPS should be 1" + assert ds.repo_id == "dummy/repo", "The dataset should have `dummy/repo` repo id" + + for dim in state_dims(): + assert dim in ds.features + assert ds.features[dim]["shape"] == buffer.states[dim][0].shape + + assert ds.num_episodes == 2 + assert ds.num_frames == 4 + + for j, value in enumerate(ds): + print(torch.equal(value["observation.image"], buffer.next_states["observation.image"][j])) + + for i in range(len(ds)): + for feature, value in ds[i].items(): + if feature == "action": + assert torch.equal(value, buffer.actions[i]) + elif feature == "next.reward": + assert torch.equal(value, buffer.rewards[i]) + elif feature == "next.done": + assert torch.equal(value, buffer.dones[i]) + elif feature == "observation.image": + # Tenssor -> numpy is not precise, so we have some diff there + # TODO: Check and fix it + torch.testing.assert_close(value, buffer.states["observation.image"][i], rtol=0.3, atol=0.003) + elif feature == "observation.state": + assert torch.equal(value, buffer.states["observation.state"][i]) + + +def test_from_lerobot_dataset(tmp_path): + dummy_state_1 = create_dummy_state() + dummy_action_1 = create_dummy_action() + + dummy_state_2 = create_dummy_state() + dummy_action_2 = create_dummy_action() + + dummy_state_3 = create_dummy_state() + dummy_action_3 = create_dummy_action() + + dummy_state_4 = create_dummy_state() + dummy_action_4 = create_dummy_action() + + replay_buffer = create_empty_replay_buffer() + replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False) + replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False) + replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True) + replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True) + + root = tmp_path / "test" + ds = replay_buffer.to_lerobot_dataset(DUMMY_REPO_ID, root=root) + + reconverted_buffer = ReplayBuffer.from_lerobot_dataset( + ds, state_keys=list(state_dims()), device="cpu", capacity=replay_buffer.capacity, use_drq=False + ) + + # Check only the part of the buffer that's actually filled with data + assert torch.equal( + reconverted_buffer.actions[: len(replay_buffer)], + replay_buffer.actions[: len(replay_buffer)], + ), "Actions from converted buffer should be equal to the original replay buffer." + assert torch.equal( + reconverted_buffer.rewards[: len(replay_buffer)], replay_buffer.rewards[: len(replay_buffer)] + ), "Rewards from converted buffer should be equal to the original replay buffer." + assert torch.equal( + reconverted_buffer.dones[: len(replay_buffer)], replay_buffer.dones[: len(replay_buffer)] + ), "Dones from converted buffer should be equal to the original replay buffer." + + # Lerobot DS haven't supported truncateds yet + expected_truncateds = torch.zeros(len(replay_buffer)).bool() + assert torch.equal(reconverted_buffer.truncateds[: len(replay_buffer)], expected_truncateds), ( + "Truncateds from converted buffer should be equal False" + ) + + assert torch.equal( + replay_buffer.states["observation.state"][: len(replay_buffer)], + reconverted_buffer.states["observation.state"][: len(replay_buffer)], + ), "State should be the same after converting to dataset and return back" + + for i in range(4): + torch.testing.assert_close( + replay_buffer.states["observation.image"][i], + reconverted_buffer.states["observation.image"][i], + rtol=0.4, + atol=0.004, + ) + + # The 2, 3 frames have done flag, so their values will be equal to the current state + for i in range(2): + # In the current implementation we take the next state from the `states` and ignore `next_states` + next_index = (i + 1) % 4 + + torch.testing.assert_close( + replay_buffer.states["observation.image"][next_index], + reconverted_buffer.next_states["observation.image"][i], + rtol=0.4, + atol=0.004, + ) + + for i in range(2, 4): + assert torch.equal( + replay_buffer.states["observation.state"][i], + reconverted_buffer.next_states["observation.state"][i], + ) + + +def test_buffer_sample_alignment(): + # Initialize buffer + buffer = ReplayBuffer(capacity=100, device="cpu", state_keys=["state_value"], storage_device="cpu") + + # Fill buffer with patterned data + for i in range(100): + signature = float(i) / 100.0 + state = {"state_value": torch.tensor([[signature]]).float()} + action = torch.tensor([[2.0 * signature]]).float() + reward = 3.0 * signature + + is_end = (i + 1) % 10 == 0 + if is_end: + next_state = {"state_value": torch.tensor([[signature]]).float()} + done = True + else: + next_signature = float(i + 1) / 100.0 + next_state = {"state_value": torch.tensor([[next_signature]]).float()} + done = False + + buffer.add(state, action, reward, next_state, done, False) + + # Sample and verify + batch = buffer.sample(50) + + for i in range(50): + state_sig = batch["state"]["state_value"][i].item() + action_val = batch["action"][i].item() + reward_val = batch["reward"][i].item() + next_state_sig = batch["next_state"]["state_value"][i].item() + is_done = batch["done"][i].item() > 0.5 + + # Verify relationships + assert abs(action_val - 2.0 * state_sig) < 1e-4, ( + f"Action {action_val} should be 2x state signature {state_sig}" + ) + + assert abs(reward_val - 3.0 * state_sig) < 1e-4, ( + f"Reward {reward_val} should be 3x state signature {state_sig}" + ) + + if is_done: + assert abs(next_state_sig - state_sig) < 1e-4, ( + f"For done states, next_state {next_state_sig} should equal state {state_sig}" + ) + else: + # Either it's the next sequential state (+0.01) or same state (for episode boundaries) + valid_next = ( + abs(next_state_sig - state_sig - 0.01) < 1e-4 or abs(next_state_sig - state_sig) < 1e-4 + ) + assert valid_next, ( + f"Next state {next_state_sig} should be either state+0.01 or same as state {state_sig}" + ) + + +def test_memory_optimization(): + dummy_state_1 = create_dummy_state() + dummy_action_1 = create_dummy_action() + + dummy_state_2 = create_dummy_state() + dummy_action_2 = create_dummy_action() + + dummy_state_3 = create_dummy_state() + dummy_action_3 = create_dummy_action() + + dummy_state_4 = create_dummy_state() + dummy_action_4 = create_dummy_action() + + replay_buffer = create_empty_replay_buffer() + replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_2, False, False) + replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_3, False, False) + replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_4, False, False) + replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True) + + optimized_replay_buffer = create_empty_replay_buffer(True) + optimized_replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_2, False, False) + optimized_replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_3, False, False) + optimized_replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_4, False, False) + optimized_replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, None, True, True) + + assert get_object_memory(optimized_replay_buffer) < get_object_memory(replay_buffer), ( + "Optimized replay buffer should be smaller than the original replay buffer" + ) + + +def test_check_image_augmentations_with_drq_and_dummy_image_augmentation_function(dummy_state, dummy_action): + def dummy_image_augmentation_function(x): + return torch.ones_like(x) * 10 + + replay_buffer = create_empty_replay_buffer( + use_drq=True, image_augmentation_function=dummy_image_augmentation_function + ) + + replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False) + + sampled_transitions = replay_buffer.sample(1) + assert torch.all(sampled_transitions["state"]["observation.image"] == 10), ( + "Image augmentations should be applied" + ) + assert torch.all(sampled_transitions["next_state"]["observation.image"] == 10), ( + "Image augmentations should be applied" + ) + + +def test_check_image_augmentations_with_drq_and_default_image_augmentation_function( + dummy_state, dummy_action +): + replay_buffer = create_empty_replay_buffer(use_drq=True) + + replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False) + + # Let's check that it doesn't fail and shapes are correct + sampled_transitions = replay_buffer.sample(1) + assert sampled_transitions["state"]["observation.image"].shape == (1, 3, 84, 84) + assert sampled_transitions["next_state"]["observation.image"].shape == (1, 3, 84, 84) + + +def test_random_crop_vectorized_basic(): + # Create a batch of 2 images with known patterns + batch_size, channels, height, width = 2, 3, 10, 8 + images = torch.zeros((batch_size, channels, height, width)) + + # Fill with unique values for testing + for b in range(batch_size): + images[b] = b + 1 + + crop_size = (6, 4) # Smaller than original + cropped = random_crop_vectorized(images, crop_size) + + # Check output shape + assert cropped.shape == (batch_size, channels, *crop_size) + + # Check that values are preserved (should be either 1s or 2s for respective batches) + assert torch.all(cropped[0] == 1) + assert torch.all(cropped[1] == 2) + + +def test_random_crop_vectorized_invalid_size(): + images = torch.zeros((2, 3, 10, 8)) + + # Test crop size larger than image + with pytest.raises(ValueError, match="Requested crop size .* is bigger than the image size"): + random_crop_vectorized(images, (12, 8)) + + with pytest.raises(ValueError, match="Requested crop size .* is bigger than the image size"): + random_crop_vectorized(images, (10, 10)) + + +def _populate_buffer_for_async_test(capacity: int = 10) -> ReplayBuffer: + """Create a small buffer with deterministic 3×128×128 images and 11-D state.""" + buffer = ReplayBuffer( + capacity=capacity, + device="cpu", + state_keys=["observation.image", "observation.state"], + storage_device="cpu", + ) + + for i in range(capacity): + img = torch.ones(3, 128, 128) * i + state_vec = torch.arange(11).float() + i + state = { + "observation.image": img, + "observation.state": state_vec, + } + buffer.add( + state=state, + action=torch.tensor([0.0]), + reward=0.0, + next_state=state, + done=False, + truncated=False, + ) + return buffer + + +def test_async_iterator_shapes_basic(): + buffer = _populate_buffer_for_async_test() + batch_size = 2 + iterator = buffer.get_iterator(batch_size=batch_size, async_prefetch=True, queue_size=1) + batch = next(iterator) + + images = batch["state"]["observation.image"] + states = batch["state"]["observation.state"] + + assert images.shape == (batch_size, 3, 128, 128) + assert states.shape == (batch_size, 11) + + next_images = batch["next_state"]["observation.image"] + next_states = batch["next_state"]["observation.state"] + + assert next_images.shape == (batch_size, 3, 128, 128) + assert next_states.shape == (batch_size, 11) + + +def test_async_iterator_multiple_iterations(): + buffer = _populate_buffer_for_async_test() + batch_size = 2 + iterator = buffer.get_iterator(batch_size=batch_size, async_prefetch=True, queue_size=2) + + for _ in range(5): + batch = next(iterator) + images = batch["state"]["observation.image"] + states = batch["state"]["observation.state"] + assert images.shape == (batch_size, 3, 128, 128) + assert states.shape == (batch_size, 11) + + next_images = batch["next_state"]["observation.image"] + next_states = batch["next_state"]["observation.state"] + assert next_images.shape == (batch_size, 3, 128, 128) + assert next_states.shape == (batch_size, 11) + + # Ensure iterator can be disposed without blocking + del iterator