Gymnasium support for DDPG continuous (+Jax) (#371)
* ddpg continuous + jax

* fix video recording

* remove pybullet

* move to usage docs

* isort

* update lock files

* try trigger CI

* update ddpg default v4 environments

* trigger CI

* install jax dependency

* fix CI

* remove windows CI

---------

Co-authored-by: Costa Huang <[email protected]>
arjun-kg and vwxyzjn authored May 3, 2023
1 parent 39670fc commit 9f8b64b
Showing 12 changed files with 206 additions and 49 deletions.
98 changes: 96 additions & 2 deletions .github/workflows/tests.yaml
@@ -174,13 +174,13 @@ jobs:
continue-on-error: true # MUJOCO_GL=osmesa results in `free(): invalid pointer`
run: poetry run pytest tests/test_mujoco.py

test-mujoco-envs-windows-mac:
test-mujoco-gymnasium-envs:
strategy:
fail-fast: false
matrix:
python-version: [3.8]
poetry-version: [1.3]
os: [macos-latest, windows-latest]
os: [ubuntu-22.04]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
@@ -197,9 +197,68 @@ jobs:
run: poetry install -E "pytest mujoco dm_control"
- name: Downgrade setuptools
run: poetry run pip install setuptools==59.5.0
- name: Run gymnasium migration dependencies
run: poetry run pip install "stable_baselines3==2.0.0a1"
- name: install mujoco dependencies
run: |
sudo apt-get update && sudo apt-get -y install libgl1-mesa-glx libosmesa6 libglfw3
- name: Run mujoco tests
continue-on-error: true # MUJOCO_GL=osmesa results in `free(): invalid pointer`
run: poetry run pytest tests/test_mujoco_gymnasium.py

test-mujoco-envs-mac:
strategy:
fail-fast: false
matrix:
python-version: [3.8]
poetry-version: [1.3]
os: [macos-latest]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Run image
uses: abatilo/[email protected]
with:
poetry-version: ${{ matrix.poetry-version }}

# mujoco tests
- name: Install dependencies
run: poetry install -E "pytest mujoco dm_control jax"
- name: Downgrade setuptools
run: poetry run pip install setuptools==59.5.0
- name: Run mujoco tests
run: poetry run pytest tests/test_mujoco.py

test-mujoco-gymnasium-mac:
strategy:
fail-fast: false
matrix:
python-version: [3.8]
poetry-version: [1.3]
os: [macos-latest]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Run image
uses: abatilo/[email protected]
with:
poetry-version: ${{ matrix.poetry-version }}

# mujoco tests
- name: Install dependencies
run: poetry install -E "pytest mujoco dm_control jax"
- name: Downgrade setuptools
run: poetry run pip install setuptools==59.5.0
- name: Run gymnasium migration dependencies
run: poetry run pip install "stable_baselines3==2.0.0a1"
- name: Run mujoco tests
run: poetry run pytest tests/test_mujoco_gymnasium.py

test-mujoco_py-envs:
strategy:
@@ -234,6 +293,41 @@ jobs:
- name: Run mujoco_py tests
run: poetry run pytest tests/test_mujoco_py.py

test-mujoco_py-envs-gymnasium:
strategy:
fail-fast: false
matrix:
python-version: [3.8]
poetry-version: [1.3]
os: [ubuntu-22.04]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Run image
uses: abatilo/[email protected]
with:
poetry-version: ${{ matrix.poetry-version }}

# mujoco_py tests
- name: Install dependencies
run: poetry install -E "pytest pybullet mujoco_py mujoco jax"
- name: Run gymnasium migration dependencies
run: poetry run pip install "stable_baselines3==2.0.0a1"
- name: Downgrade setuptools
run: poetry run pip install setuptools==59.5.0
- name: install mujoco_py dependencies
run: |
sudo apt-get update && sudo apt-get -y install wget unzip software-properties-common \
libgl1-mesa-dev \
libgl1-mesa-glx \
libglew-dev \
libosmesa6-dev patchelf
- name: Run mujoco_py tests
run: poetry run pytest tests/test_mujoco_py_gymnasium.py

test-envpool-envs:
strategy:
fail-fast: false
3 changes: 1 addition & 2 deletions README.md
@@ -30,8 +30,7 @@ You can read more about CleanRL in our [JMLR paper](https://www.jmlr.org/papers/

CleanRL only contains implementations of **online** deep reinforcement learning algorithms. If you are looking for **offline** algorithms, please check out [tinkoff-ai/CORL](https://github.com/tinkoff-ai/CORL), which shares a similar design philosophy with CleanRL.

> ℹ️ **Support for Gymnasium**: [Farama-Foundation/Gymnasium](https://github.com/Farama-Foundation/Gymnasium) is the next generation of [`openai/gym`](https://github.com/openai/gym) that will continue to be maintained and introduce new features. Please see their [announcement](https://farama.org/Announcing-The-Farama-Foundation) for further detail. We are migrating to `gymnasium` and the progress can be tracked in [vwxyzjn/cleanrl#277](https://github.com/vwxyzjn/cleanrl/pull/277).

> ⚠️ **NOTE**: CleanRL is *not* a modular library and therefore it is not meant to be imported. At the cost of duplicate code, we make all implementation details of a DRL algorithm variant easy to understand, so CleanRL comes with its own pros and cons. You should consider using CleanRL if you want to 1) understand all implementation details of an algorithm's variant or 2) prototype advanced features that other modular DRL libraries do not support (CleanRL has minimal lines of code, so it gives you a great debugging experience, and you don't have to do a lot of subclassing as you sometimes would in modular DRL libraries).
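The API changes handled throughout this commit follow from Gymnasium's new core interface: `reset()` takes the seed and returns `(obs, info)`, and `step()` returns five values instead of four. A minimal sketch of the new loop (the environment id is illustrative and requires the `mujoco` extra):

```python
import gymnasium as gym

env = gym.make("HalfCheetah-v4")
obs, info = env.reset(seed=0)  # env.seed() is gone; seeding moved into reset()
done = False
while not done:
    action = env.action_space.sample()
    # the old `done` flag is split into `terminated` and `truncated`
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
env.close()
```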
Empty file modified: benchmark/ddpg.sh (mode 100644 → 100755)
47 changes: 33 additions & 14 deletions cleanrl/ddpg_continuous_action.py
@@ -5,9 +5,8 @@
import time
from distutils.util import strtobool

import gym
import gymnasium as gym
import numpy as np
import pybullet_envs # noqa
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -37,7 +36,7 @@ def parse_args():
help="whether to capture videos of the agent performances (check out `videos` folder)")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="HopperBulletEnv-v0",
parser.add_argument("--env-id", type=str, default="HalfCheetah-v4",
help="the id of the environment")
parser.add_argument("--total-timesteps", type=int, default=1000000,
help="total timesteps of the experiments")
@@ -66,12 +65,14 @@

def make_env(env_id, seed, idx, capture_video, run_name):
def thunk():
env = gym.make(env_id)
if capture_video:
env = gym.make(env_id, render_mode="rgb_array")
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)
if capture_video:
if idx == 0:
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
env.seed(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)
return env
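Two Gymnasium details drive the `make_env` change above: the `RecordVideo` wrapper requires the env to be created with `render_mode="rgb_array"`, and `env.seed(seed)` no longer exists, so only the spaces are seeded here while the observation seed is passed to `reset()` later. A standalone sketch of the wrapper stack (environment id and video folder are illustrative):

```python
import gymnasium as gym

env = gym.make("HalfCheetah-v4", render_mode="rgb_array")  # required by RecordVideo
env = gym.wrappers.RecordEpisodeStatistics(env)  # adds info["episode"] stats at episode end
env = gym.wrappers.RecordVideo(env, "videos/demo")
env.action_space.seed(0)
obs, info = env.reset(seed=0)
```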
Expand Down Expand Up @@ -117,6 +118,15 @@ def forward(self, x):


if __name__ == "__main__":
import stable_baselines3 as sb3

if sb3.__version__ < "2.0":
raise ValueError(
"""Ongoing migration: run the following command to install the new dependencies:
poetry run pip install "stable_baselines3==2.0.0a1"
"""
)
args = parse_args()
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
@@ -128,7 +138,7 @@ def forward(self, x):
sync_tensorboard=True,
config=vars(args),
name=run_name,
monitor_gym=True,
# monitor_gym=True, # no longer works for gymnasium
save_code=True,
)
writer = SummaryWriter(f"runs/{run_name}")
@@ -164,12 +174,14 @@ def forward(self, x):
envs.single_observation_space,
envs.single_action_space,
device,
handle_timeout_termination=True,
handle_timeout_termination=False,
)
start_time = time.time()

# TRY NOT TO MODIFY: start the game
obs = envs.reset()
obs, _ = envs.reset(seed=args.seed)
video_filenames = set()

for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
if global_step < args.learning_starts:
@@ -181,22 +193,23 @@
actions = actions.cpu().numpy().clip(envs.single_action_space.low, envs.single_action_space.high)

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)
next_obs, rewards, terminateds, truncateds, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
for info in infos:
if "episode" in info.keys():

if "final_info" in infos:
for info in infos["final_info"]:
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
break

# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(dones):
for idx, d in enumerate(truncateds):
if d:
real_next_obs[idx] = infos[idx]["terminal_observation"]
rb.add(obs, real_next_obs, actions, rewards, dones, infos)
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminateds, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
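Background for the two new blocks above: Gymnasium vector envs auto-reset, so when a sub-env finishes, `step()` already returns the first observation of the next episode; the finished episode's statistics and true last observation arrive out-of-band in `infos["final_info"]` and `infos["final_observation"]`. That is also why the buffer is now created with `handle_timeout_termination=False`: truncated transitions are patched by hand before `rb.add`. A minimal sketch of the protocol, assuming a single `SyncVectorEnv` sub-env (environment id illustrative):

```python
import gymnasium as gym

def make_env():
    env = gym.make("Pendulum-v1")  # illustrative; truncates at 200 steps
    return gym.wrappers.RecordEpisodeStatistics(env)

envs = gym.vector.SyncVectorEnv([make_env])
obs, _ = envs.reset(seed=0)
for _ in range(400):
    actions = envs.action_space.sample()
    obs, rewards, terminateds, truncateds, infos = envs.step(actions)
    if "final_info" in infos:  # at least one sub-env just finished an episode
        for info in infos["final_info"]:
            if info is not None:  # None for sub-envs that did not finish
                print(f"episodic_return={info['episode']['r']}")
```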
@@ -237,4 +250,10 @@ def forward(self, x):
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

envs.close()

if args.track and args.capture_video:
for filename in os.listdir(f"videos/{run_name}"):
if filename not in video_filenames and filename.endswith(".mp4"):
wandb.log({f"videos": wandb.Video(f"videos/{run_name}/{filename}")})
video_filenames.add(filename)
writer.close()
45 changes: 31 additions & 14 deletions cleanrl/ddpg_continuous_action_jax.py
@@ -8,12 +8,11 @@

import flax
import flax.linen as nn
import gym
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
import optax
import pybullet_envs # noqa
from flax.training.train_state import TrainState
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter
@@ -36,7 +35,7 @@ def parse_args():
help="whether to capture videos of the agent performances (check out `videos` folder)")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="HalfCheetah-v2",
parser.add_argument("--env-id", type=str, default="HalfCheetah-v4",
help="the id of the environment")
parser.add_argument("--total-timesteps", type=int, default=1000000,
help="total timesteps of the experiments")
@@ -65,12 +64,14 @@

def make_env(env_id, seed, idx, capture_video, run_name):
def thunk():
env = gym.make(env_id)
if capture_video:
env = gym.make(env_id, render_mode="rgb_array")
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)
if capture_video:
if idx == 0:
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
env.seed(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)
return env
Expand Down Expand Up @@ -113,6 +114,15 @@ class TrainState(TrainState):


if __name__ == "__main__":
import stable_baselines3 as sb3

if sb3.__version__ < "2.0":
raise ValueError(
"""Ongoing migration: run the following command to install the new dependencies:
poetry run pip install "stable_baselines3==2.0.0a1"
"""
)
args = parse_args()
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
@@ -124,7 +134,7 @@ class TrainState(TrainState):
sync_tensorboard=True,
config=vars(args),
name=run_name,
monitor_gym=True,
monitor_gym=True, # does not work on gymnasium
save_code=True,
)
writer = SummaryWriter(f"runs/{run_name}")
@@ -150,11 +160,12 @@ class TrainState(TrainState):
envs.single_observation_space,
envs.single_action_space,
device="cpu",
handle_timeout_termination=True,
handle_timeout_termination=False,
)

# TRY NOT TO MODIFY: start the game
obs = envs.reset()
obs, _ = envs.reset()
video_filenames = set()
action_scale = np.array((envs.action_space.high - envs.action_space.low) / 2.0)
action_bias = np.array((envs.action_space.high + envs.action_space.low) / 2.0)
actor = Actor(
@@ -235,22 +246,22 @@ def actor_loss(params):
)

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)
next_obs, rewards, terminateds, truncateds, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
for info in infos:
if "episode" in info.keys():
if "final_info" in infos:
for info in infos["final_info"]:
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
break

# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(dones):
for idx, d in enumerate(truncateds):
if d:
real_next_obs[idx] = infos[idx]["terminal_observation"]
rb.add(obs, real_next_obs, actions, rewards, dones, infos)
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminateds, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
@@ -282,4 +293,10 @@ def actor_loss(params):
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

envs.close()
if args.track and args.capture_video:
for filename in os.listdir(f"videos/{run_name}"):
if filename not in video_filenames and filename.endswith(".mp4"):
wandb.log({f"videos": wandb.Video(f"videos/{run_name}/{filename}")})
video_filenames.add(filename)

writer.close()
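One JAX-specific detail worth calling out: the `class TrainState(TrainState):` hunk above extends flax's train state so the target-network parameters travel alongside the optimizer state. The field name below matches CleanRL's other JAX scripts but is an assumption here, since the diff does not show the class body:

```python
import flax
import optax
from flax.training.train_state import TrainState

class DDPGTrainState(TrainState):
    # assumed field: a frozen copy of `params` acting as the target network
    target_params: flax.core.FrozenDict = None

# Typical usage (sketch): create with target_params=params, then after each
# gradient step apply a Polyak average:
#   state = state.replace(
#       target_params=optax.incremental_update(state.params, state.target_params, tau)
#   )
```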
9 changes: 9 additions & 0 deletions docs/get-started/basic-usage.md
@@ -44,6 +44,15 @@ the CleanRL script under the poetry virtual environments.
**Throughout the documentation, we assume that other commands (e.g. `tensorboard`) are run inside Poetry's shell.**
!!! note
Currently, `ddpg_continuous_action.py` and `ddpg_continuous_action_jax.py` have been ported to gymnasium.
Note that `stable-baselines3` version `1.2` does not support `gymnasium`. To use these scripts, install the alpha release:
```
poetry run pip install "stable_baselines3==2.0.0a1"
```
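A quick way to verify the right version is active, mirroring the guard the two migrated scripts run at startup:

```python
import stable_baselines3 as sb3

# The migrated scripts raise at startup when they detect a pre-2.0 release.
if sb3.__version__ < "2.0":
    raise ValueError('run: poetry run pip install "stable_baselines3==2.0.0a1"')
```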
!!! warning
If you are using NVIDIA Ampere GPUs (e.g., the 3060 Ti), you might encounter the following error
2 changes: 1 addition & 1 deletion poetry.lock

1 comment on commit 9f8b64b

@vercel bot commented on 9f8b64b, May 3, 2023:

Successfully deployed to the following URLs:

cleanrl – ./

cleanrl-vwxyzjn.vercel.app
cleanrl-git-master-vwxyzjn.vercel.app
docs.cleanrl.dev
cleanrl.vercel.app