[Feature] Add RLPD Baseline #437

Merged · 6 commits · Jul 17, 2024
10 changes: 5 additions & 5 deletions docs/source/user_guide/workflows/learning_from_demos/index.md
@@ -14,11 +14,11 @@ As part of these baselines we establish a few standard learning from demonstration

**Online Learning from Demonstrations Baselines**

| Baseline | Code | Results | Paper |
| --------------------------------------------------- | ----------------------------------------------------------------------------------- | ------- | ---------------------------------------- |
| Reverse Forward Curriculum Learning (RFCL)* | [Link](https://github.com/haosulab/ManiSkill/blob/main/examples/baselines/rfcl) | WIP | [Link](https://arxiv.org/abs/2405.03379) |
| Reinforcement Learning from Prior Data (RLPD) | WIP | WIP | [Link](https://arxiv.org/abs/2302.02948) |
| SAC + Demos (SAC+Demos) | WIP | N/A | |
| Baseline | Code | Results | Paper |
| --------------------------------------------- | ------------------------------------------------------------------------------- | ------- | ---------------------------------------- |
| Reverse Forward Curriculum Learning (RFCL)* | [Link](https://github.com/haosulab/ManiSkill/blob/main/examples/baselines/rfcl) | WIP | [Link](https://arxiv.org/abs/2405.03379) |
| Reinforcement Learning from Prior Data (RLPD) | [Link](https://github.com/haosulab/ManiSkill/blob/main/examples/baselines/rlpd) | WIP | [Link](https://arxiv.org/abs/2302.02948) |
| SAC + Demos (SAC+Demos) | WIP | N/A | |


\* - This indicates the baseline uses environment state reset which is typically a simulation only feature
19 changes: 11 additions & 8 deletions examples/baselines/rfcl/README.md
@@ -1,8 +1,8 @@
# Reverse Forward Curriculum Learning

Fast offline/online imitation learning from sparse rewards in simulation based on ["Reverse Forward Curriculum Learning for Extreme Sample and Demo Efficiency in Reinforcement Learning (ICLR 2024)"](https://arxiv.org/abs/2405.03379). Code adapted from https://github.com/StoneT2000/rfcl/
Fast offline/online imitation learning from sparse rewards from very few demonstrations in simulation based on ["Reverse Forward Curriculum Learning for Extreme Sample and Demo Efficiency in Reinforcement Learning (ICLR 2024)"](https://arxiv.org/abs/2405.03379). Code adapted from https://github.com/StoneT2000/rfcl/

This code can be useful for solving tasks, verifying tasks are solvable via neural nets, and generating infinite demonstrations via trained neural nets, all without using dense rewards (provided the task is not too long horizon)
This code can be useful for solving tasks, verifying tasks are solvable via neural nets, and generating infinite demonstrations via trained neural nets, all without using dense rewards (provided the task is not too long horizon).

## Installation
To get started run `git clone https://github.com/StoneT2000/rfcl.git rfcl_jax --branch ms3-gpu` which contains the code for RFCL written in jax. While ManiSkill3 does run on torch, the jax implementation is much more optimized and trains faster.
@@ -31,14 +31,15 @@ python -m mani_skill.utils.download_demo "PickCube-v1"
```

<!-- TODO (stao): note how this part can be optional if user wants to do action free learning -->
Process the demonstrations in preparation for the learning workflow. We will use the teleoperated trajectories to train. Other provided demonstration sources (like motion planning and RL generated) can work as well but may require modifying a few hyperparameters.
Process the demonstrations in preparation for the learning workflow. We will use the motion planning trajectories to train (see the path in the command below). Other provided demonstration sources (like RL-generated ones) can work as well but may require modifying a few hyperparameters. RFCL is extremely demonstration-efficient, so we only need to process and save 5 demonstrations for training here.

```bash
env_id="PickCube-v1"
python -m mani_skill.trajectory.replay_trajectory \
--traj-path ~/.maniskill/demos/${env_id}/motionplanning/trajectory.h5 \
--use-first-env-state \
-c pd_joint_delta_pos -o state \
--save-traj
--save-traj --count 5
```

## Train
@@ -81,12 +82,14 @@ To generate 1000 demonstrations you can run
XLA_PYTHON_CLIENT_PREALLOCATE=false python rfcl_jax/scripts/collect_demos.py exps/path/to/model.jx \
num_envs=8 num_episodes=1000
```
This saves the demos

which uses CPU vectorization to generate demonstrations in parallel. Note that while the demos are generated on the CPU, you can always convert them to demonstrations on the GPU via the [replay trajectory tool](https://maniskill.readthedocs.io/en/latest/user_guide/datasets/replay.html) as so
This saves the demos, using CPU vectorization to generate demonstrations in parallel. Note that while the demos are generated on the CPU, you can always convert them to GPU-simulation demonstrations via the [replay trajectory tool](https://maniskill.readthedocs.io/en/latest/user_guide/datasets/replay.html) like so:

```bash
python -m mani_skill.trajectory.replay_trajectory \
--traj-path exps/ms3/PickCube-v1/_5_demos_s42/eval_videos/trajectory.h5 \
-b gpu --use-first-env-state
```

The replay_trajectory tool can also be used to generate videos.

See the rfcl_jax/scripts/collect_demos.py code for details on how to load the saved policies and adapt them to your needs.
4 changes: 4 additions & 0 deletions examples/baselines/rlpd/.gitignore
@@ -0,0 +1,4 @@
/rlpd_jax
/wandb
/exps
/videos
100 changes: 100 additions & 0 deletions examples/baselines/rlpd/README.md
@@ -0,0 +1,100 @@
# Reinforcement Learning from Prior Data (RLPD)

Sample-efficient offline/online imitation learning from sparse rewards leveraging prior data, based on ["Efficient Online Reinforcement Learning with Offline Data (ICML 2023)"](https://arxiv.org/abs/2302.02948). Code adapted from https://github.com/ikostrikov/rlpd

RLPD leverages previously collected trajectory data (both expert and non-expert data work) and trains on this prior data while collecting online data, enabling sample-efficient learning of a policy that solves the task from sparse rewards alone.

## Installation

To get started run `git clone https://github.com/StoneT2000/rfcl.git rlpd_jax --branch ms3-gpu` which contains the code for RLPD written in jax (a partial fork of the original RLPD and JaxRL repos that has been optimized to run faster and support vectorized environments).

We recommend using conda/mamba; you can install the dependencies as follows:

```bash
conda create -n "rlpd_ms3" "python==3.9"
conda activate rlpd_ms3
pip install --upgrade "jax[cuda12_pip]==0.4.28" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install -e rlpd_jax
```

Then you can install ManiSkill and its dependencies

```bash
pip install mani_skill torch==2.3.1
```
Note that since jax and torch are used, we recommend installing the specific versions detailed in the commands above as those are tested to work together.
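Before training, it can be worth a quick sanity check that both frameworks see the GPU. This is optional and not part of the baseline itself:

```bash
# Optional sanity check: JAX should list a CUDA device and torch should report CUDA as available.
python -c "import jax; print(jax.devices())"
python -c "import torch; print(torch.cuda.is_available())"
```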

## Download and Process Dataset

Download demonstrations for a desired task e.g. PickCube-v1
```bash
python -m mani_skill.utils.download_demo "PickCube-v1"
```

<!-- TODO (stao): note how this part can be optional if user wants to do action free learning -->
Process the demonstrations in preparation for the learning workflow. RLPD works well on harder tasks if sufficient data is provided and the data itself is not too multi-modal. Hence we will use the RL-generated trajectories (plenty of data and not multi-modal, so much easier to learn from) for the example below.

The preprocessing step here simply replays all trajectories by environment state (so the exact same trajectory is reproduced) and saves the state observations to train on. Failed demos are also saved, since RLPD can learn from sub-optimal data as well.

```bash
env_id="PickCube-v1"
python -m mani_skill.trajectory.replay_trajectory \
--traj-path ~/.maniskill/demos/${env_id}/rl/trajectory.h5 \
--use-env-state --allow-failure \
-c pd_joint_delta_pos -o state \
--save-traj --num-procs 4
```
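As an optional sanity check, you can count how many trajectories ended up in the processed file. This is just a sketch: it assumes `h5py` is installed (ManiSkill uses it for trajectory files), that the processed file lands at the path used by the training commands below, and that trajectories are stored under keys named `traj_0`, `traj_1`, and so on.

```bash
# Optional: count the processed trajectories (path and key naming are assumptions).
python -c "
import h5py
with h5py.File('$HOME/.maniskill/demos/PickCube-v1/rl/trajectory.state.pd_joint_delta_pos.h5', 'r') as f:
    print(sum(k.startswith('traj_') for k in f.keys()), 'trajectories')
"
```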

## Train

To train with environment vectorization run

```bash
env_id=PickCube-v1
demos=1000 # number of demos to train on.
seed=42
XLA_PYTHON_CLIENT_PREALLOCATE=false python train_ms3.py configs/base_rlpd_ms3.yml \
logger.exp_name="rlpd-${env_id}-state-${demos}_rl_demos-${seed}-walltime_efficient" logger.wandb=True \
seed=${seed} train.num_demos=${demos} train.steps=200_000 \
env.env_id=${env_id} \
train.dataset_path="~/.maniskill/demos/${env_id}/rl/trajectory.state.pd_joint_delta_pos.h5"
```

This should solve the PickCube-v1 task in a few minutes of wall-clock time, but it is not particularly sample efficient.

For sample-efficient settings you can use the configuration stored in configs/base_rlpd_ms3_sample_efficient.yml (no environment parallelization, more critics, and a higher update-to-data ratio). This takes fewer environment samples (around 50K to solve the task) but runs slower in wall-clock time.

```bash
env_id=PickCube-v1
demos=1000 # number of demos to train on.
seed=42
XLA_PYTHON_CLIENT_PREALLOCATE=false python train_ms3.py configs/base_rlpd_ms3_sample_efficient.yml \
logger.exp_name="rlpd-${env_id}-state-${demos}_rl_demos-${seed}-sample_efficient" logger.wandb=True \
seed=${seed} train.num_demos=${demos} train.steps=100_000 \
env.env_id=${env_id} \
train.dataset_path="~/.maniskill/demos/${env_id}/rl/trajectory.state.pd_joint_delta_pos.h5"
```

Evaluation videos are saved to `exps/<exp_name>/videos`.
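Any key in the YAML config can be overridden from the command line using dotted paths, as the commands above do with `env.env_id`, `train.num_demos`, and `logger.wandb`. A minimal sketch under that assumption; the specific values here are illustrative, and `env.num_envs` is assumed to be overridable the same way:

```bash
# Illustrative overrides via dotted config paths (values are examples only).
XLA_PYTHON_CLIENT_PREALLOCATE=false python train_ms3.py configs/base_rlpd_ms3.yml \
  env.env_id=PickCube-v1 env.num_envs=16 \
  train.num_demos=100 train.steps=200_000 \
  train.dataset_path="~/.maniskill/demos/PickCube-v1/rl/trajectory.state.pd_joint_delta_pos.h5"
```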

## Generating Demonstrations / Evaluating policies

To generate 1000 demonstrations you can run

```bash
XLA_PYTHON_CLIENT_PREALLOCATE=false python rlpd_jax/scripts/collect_demos.py exps/path/to/model.jx \
num_envs=8 num_episodes=1000
```
This saves the demos, using CPU vectorization to generate demonstrations in parallel. Note that while the demos are generated on the CPU, you can always convert them to GPU-simulation demonstrations via the [replay trajectory tool](https://maniskill.readthedocs.io/en/latest/user_guide/datasets/replay.html) like so:

```bash
python -m mani_skill.trajectory.replay_trajectory \
--traj-path exps/<exp_name>/eval_videos/trajectory.h5 \
-b gpu --use-first-env-state
```

The replay_trajectory tool can also be used to generate videos.
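For example, a sketch of rendering videos from the evaluation trajectories; this assumes the `--save-video` flag is available in your ManiSkill version and reuses the trajectory path from the command above:

```bash
# Replay the evaluation trajectories and save rendered videos (flag availability assumed).
python -m mani_skill.trajectory.replay_trajectory \
  --traj-path exps/<exp_name>/eval_videos/trajectory.h5 \
  --use-first-env-state --save-video
```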

See the rlpd_jax/scripts/collect_demos.py code for details on how to load the saved policies and adapt them to your needs.
75 changes: 75 additions & 0 deletions examples/baselines/rlpd/configs/base_rlpd_ms3.yml
@@ -0,0 +1,75 @@
jax_env: False

seed: 0
algo: sac
verbose: 1
# Environment configuration
env:
  env_id: None
  max_episode_steps: 50
  num_envs: 8
  env_type: "gym:cpu"
  env_kwargs:
    control_mode: "pd_joint_delta_pos"
    render_mode: "rgb_array"
    reward_mode: "sparse"
eval_env:
  num_envs: 2
  max_episode_steps: 50

sac:
  num_seed_steps: 5_000
  seed_with_policy: False
  replay_buffer_capacity: 1_000_000
  batch_size: 256
  steps_per_env: 4
  grad_updates_per_step: 16
  actor_update_freq: 1

  num_qs: 2
  num_min_qs: 2

  discount: 0.9
  tau: 0.005
  backup_entropy: False

  eval_freq: 50_000
  eval_steps: 250

  log_freq: 1000
  save_freq: 50_000

  learnable_temp: True
  initial_temperature: 1.0

network:
  actor:
    type: "mlp"
    arch_cfg:
      features: [256, 256, 256]
      output_activation: "relu"
  critic:
    type: "mlp"
    arch_cfg:
      features: [256, 256, 256]
      output_activation: "relu"
      use_layer_norm: True

train:
  actor_lr: 3e-4
  critic_lr: 3e-4
  steps: 100_000_000
  dataset_path: None
  shuffle_demos: True
  num_demos: 1000

  data_action_scale: null

logger:
  tensorboard: True
  wandb: False

  workspace: "exps"
  project_name: "ManiSkill"
  wandb_cfg:
    group: "RLPD"
75 changes: 75 additions & 0 deletions examples/baselines/rlpd/configs/base_rlpd_ms3_sample_efficient.yml
@@ -0,0 +1,75 @@
jax_env: False

seed: 0
algo: sac
verbose: 1
# Environment configuration
env:
  env_id: None
  max_episode_steps: 50
  num_envs: 1
  env_type: "gym:cpu"
  env_kwargs:
    control_mode: "pd_joint_delta_pos"
    render_mode: "rgb_array"
    reward_mode: "sparse"
eval_env:
  num_envs: 2
  max_episode_steps: 50

sac:
  num_seed_steps: 5_000
  seed_with_policy: False
  replay_buffer_capacity: 1_000_000
  batch_size: 256
  steps_per_env: 1
  grad_updates_per_step: 16
  actor_update_freq: 1

  num_qs: 10
  num_min_qs: 2

  discount: 0.9
  tau: 0.005
  backup_entropy: False

  eval_freq: 5_000
  eval_steps: 250

  log_freq: 1000
  save_freq: 5_000

  learnable_temp: True
  initial_temperature: 1.0

network:
  actor:
    type: "mlp"
    arch_cfg:
      features: [256, 256, 256]
      output_activation: "relu"
  critic:
    type: "mlp"
    arch_cfg:
      features: [256, 256, 256]
      output_activation: "relu"
      use_layer_norm: True

train:
  actor_lr: 3e-4
  critic_lr: 3e-4
  steps: 100_000_000
  dataset_path: None
  shuffle_demos: True
  num_demos: 1000

  data_action_scale: null

logger:
  tensorboard: True
  wandb: False

  workspace: "exps"
  project_name: "ManiSkill"
  wandb_cfg:
    group: "RLPD"