Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gymnasium support for DDPG continuous (+Jax) #371

Merged
merged 16 commits into from
May 3, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ You can read more about CleanRL in our [JMLR paper](https://www.jmlr.org/papers/

CleanRL only contains implementations of **online** deep reinforcement learning algorithms. If you are looking for **offline** algorithms, please check out [tinkoff-ai/CORL](https://github.com/tinkoff-ai/CORL), which shares a similar design philosophy as CleanRL.

> ℹ️ **Support for Gymnasium**: [Farama-Foundation/Gymnasium](https://github.com/Farama-Foundation/Gymnasium) is the next generation of [`openai/gym`](https://github.com/openai/gym) that will continue to be maintained and introduce new features. Please see their [announcement](https://farama.org/Announcing-The-Farama-Foundation) for further detail. We are migrating to `gymnasium` and the progress can be tracked in [vwxyzjn/cleanrl#277](https://github.com/vwxyzjn/cleanrl/pull/277).

> ℹ️ **Support for Gymnasium**: [Farama-Foundation/Gymnasium](https://github.com/Farama-Foundation/Gymnasium) is the next generation of [`openai/gym`](https://github.com/openai/gym) that will continue to be maintained and introduce new features. Please see their [announcement](https://farama.org/Announcing-The-Farama-Foundation) for further detail. We are migrating to `gymnasium` and the progress can be tracked in [vwxyzjn/cleanrl#277](https://github.com/vwxyzjn/cleanrl/pull/277).


> ⚠️ **NOTE**: CleanRL is *not* a modular library and therefore it is not meant to be imported. At the cost of duplicate code, we make all implementation details of a DRL algorithm variant easy to understand, so CleanRL comes with its own pros and cons. You should consider using CleanRL if you want to 1) understand all implementation details of an algorithm's varaint or 2) prototype advanced features that other modular DRL libraries do not support (CleanRL has minimal lines of code so it gives you great debugging experience and you don't have do a lot of subclassing like sometimes in modular DRL libraries).
Expand Down
Empty file modified benchmark/ddpg.sh
100644 → 100755
Empty file.
42 changes: 28 additions & 14 deletions cleanrl/ddpg_continuous_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
import time
from distutils.util import strtobool

import gym
import gymnasium as gym
import numpy as np
import pybullet_envs # noqa
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -37,7 +36,7 @@ def parse_args():
help="whether to capture videos of the agent performances (check out `videos` folder)")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="HopperBulletEnv-v0",
parser.add_argument("--env-id", type=str, default="HalfCheetah-v2",
help="the id of the environment")
parser.add_argument("--total-timesteps", type=int, default=1000000,
help="total timesteps of the experiments")
Expand Down Expand Up @@ -66,12 +65,14 @@ def parse_args():

def make_env(env_id, seed, idx, capture_video, run_name):
def thunk():
env = gym.make(env_id)
if capture_video:
env = gym.make(env_id, render_mode="rgb_array")
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)
if capture_video:
if idx == 0:
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
env.seed(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)
return env
Expand Down Expand Up @@ -118,6 +119,10 @@ def forward(self, x):

if __name__ == "__main__":
args = parse_args()
import stable_baselines3 as sb3
if sb3.__version__ < "2.0":
raise ValueError("Ongoing migration: run `poetry run pip install sb3==2.0.0a1`")
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
import wandb
Expand All @@ -128,7 +133,7 @@ def forward(self, x):
sync_tensorboard=True,
config=vars(args),
name=run_name,
monitor_gym=True,
# monitor_gym=True, # no longer works for gymnasium
save_code=True,
)
writer = SummaryWriter(f"runs/{run_name}")
Expand Down Expand Up @@ -164,12 +169,14 @@ def forward(self, x):
envs.single_observation_space,
envs.single_action_space,
device,
handle_timeout_termination=True,
handle_timeout_termination=False,
)
start_time = time.time()

# TRY NOT TO MODIFY: start the game
obs = envs.reset()
obs, _ = envs.reset(seed=args.seed)
video_filenames = set()

for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
if global_step < args.learning_starts:
Expand All @@ -181,22 +188,23 @@ def forward(self, x):
actions = actions.cpu().numpy().clip(envs.single_action_space.low, envs.single_action_space.high)

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)
next_obs, rewards, terminateds, truncateds, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
for info in infos:
if "episode" in info.keys():

if "final_info" in infos:
for info in infos["final_info"]:
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
break

# TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(dones):
for idx, d in enumerate(truncateds):
if d:
real_next_obs[idx] = infos[idx]["terminal_observation"]
rb.add(obs, real_next_obs, actions, rewards, dones, infos)
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminateds, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
Expand Down Expand Up @@ -237,4 +245,10 @@ def forward(self, x):
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

envs.close()

if args.track and args.capture_video:
for filename in os.listdir(f"videos/{run_name}"):
if filename not in video_filenames and filename.endswith(".mp4"):
wandb.log({f"videos": wandb.Video(f"videos/{run_name}/{filename}")})
video_filenames.add(filename)
writer.close()
34 changes: 21 additions & 13 deletions cleanrl/ddpg_continuous_action_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@

import flax
import flax.linen as nn
import gym
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
import optax
import pybullet_envs # noqa
from flax.training.train_state import TrainState
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter
Expand Down Expand Up @@ -65,12 +64,14 @@ def parse_args():

def make_env(env_id, seed, idx, capture_video, run_name):
def thunk():
env = gym.make(env_id)
if capture_video:
env = gym.make(env_id, render_mode="rgb_array")
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)
if capture_video:
if idx == 0:
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
env.seed(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)
return env
Expand Down Expand Up @@ -124,7 +125,7 @@ class TrainState(TrainState):
sync_tensorboard=True,
config=vars(args),
name=run_name,
monitor_gym=True,
monitor_gym=True, # does not work on gymnasium
save_code=True,
)
writer = SummaryWriter(f"runs/{run_name}")
Expand All @@ -150,11 +151,12 @@ class TrainState(TrainState):
envs.single_observation_space,
envs.single_action_space,
device="cpu",
handle_timeout_termination=True,
handle_timeout_termination=False,
)

# TRY NOT TO MODIFY: start the game
obs = envs.reset()
obs, _ = envs.reset()
video_filenames = set()
action_scale = np.array((envs.action_space.high - envs.action_space.low) / 2.0)
action_bias = np.array((envs.action_space.high + envs.action_space.low) / 2.0)
actor = Actor(
Expand Down Expand Up @@ -235,22 +237,22 @@ def actor_loss(params):
)

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)
next_obs, rewards, terminateds, truncateds, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
for info in infos:
if "episode" in info.keys():
if "final_info" in infos:
for info in infos["final_info"]:
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
break

# TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(dones):
for idx, d in enumerate(truncateds):
if d:
real_next_obs[idx] = infos[idx]["terminal_observation"]
rb.add(obs, real_next_obs, actions, rewards, dones, infos)
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminateds, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
Expand Down Expand Up @@ -282,4 +284,10 @@ def actor_loss(params):
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

envs.close()
if args.track and args.capture_video:
for filename in os.listdir(f"videos/{run_name}"):
if filename not in video_filenames and filename.endswith(".mp4"):
wandb.log({f"videos": wandb.Video(f"videos/{run_name}/{filename}")})
video_filenames.add(filename)

writer.close()
9 changes: 9 additions & 0 deletions docs/get-started/basic-usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ the CleanRL script under the poetry virtual environments.
**We will assume to run other commands (e.g. `tensorboard`) in the documentation within the poetry's shell.**


!!! note
Currently, `ddpg_continuous_action_jax.py`, `ddpg_continuous_action.py` have been ported to gymnasium.

Please note that, `stable-baselines3` version `1.2` does not support `gymnasium`. To use these scripts, please install the `alpha1` version like,

```
poetry run pip install sb3==2.0.0a1
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be poetry run pip install stable_baselines3==2.0.0a1

```

!!! warning

If you are using NVIDIA ampere GPUs (e.g., 3060 TI), you might meet the following error
Expand Down
80 changes: 51 additions & 29 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ wandb = "^0.13.6"
gym = "0.23.1"
torch = ">=1.12.1"
stable-baselines3 = "1.2.0"
gymnasium = "^0.26.3"
gymnasium = "^0.28.1"
moviepy = "^1.0.3"
pygame = "2.1.0"
huggingface-hub = "^0.11.1"
Expand Down