
Gymnasium support for DDPG continuous (+Jax) #371

Merged: 16 commits, May 3, 2023
10 changes: 9 additions & 1 deletion README.md
@@ -30,7 +30,15 @@ You can read more about CleanRL in our [JMLR paper](https://www.jmlr.org/papers/

CleanRL only contains implementations of **online** deep reinforcement learning algorithms. If you are looking for **offline** algorithms, please check out [tinkoff-ai/CORL](https://github.com/tinkoff-ai/CORL), which shares a similar design philosophy as CleanRL.

> ℹ️ **Support for Gymnasium**: [Farama-Foundation/Gymnasium](https://github.com/Farama-Foundation/Gymnasium) is the next generation of [`openai/gym`](https://github.com/openai/gym) that will continue to be maintained and introduce new features. Please see their [announcement](https://farama.org/Announcing-The-Farama-Foundation) for further detail. We are migrating to `gymnasium` and the progress can be tracked in [vwxyzjn/cleanrl#277](https://github.com/vwxyzjn/cleanrl/pull/277).

Currently, `ppo_continuous_action_isaacgym.py`, `ddpg_continuous_action_jax.py`, and `ddpg_continuous_action.py` have been ported to gymnasium.
vwxyzjn (Owner) commented on Apr 6, 2023:

ppo_continuous_action_isaacgym.py should not be included, right? It should be ppo_continuous_action.py
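Note: the script diffs below all follow from the same `gym` → `gymnasium` API change. As a minimal sketch of that change (not part of this PR's diff; the environment id is only illustrative):

```
import gymnasium as gym

env = gym.make("Pendulum-v1")

# reset() now takes the seed directly and returns (observation, info)
obs, info = env.reset(seed=0)

# step() now returns five values: `terminated` (true terminal state) and
# `truncated` (time-limit cutoff) replace the old single `done` flag
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
done = terminated or truncated
```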


Please note that `stable-baselines3` version `1.2` does not support `gymnasium`. To use these scripts, please install the alpha release (`2.0.0a1`), for example:

```
poetry run pip install sb3==2.0.0a1
```

vwxyzjn (Owner) commented:

Could we move this to the usage docs?

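An optional sanity check that the gymnasium-compatible SB3 alpha is the one actually in use (a minimal sketch; the package imports as `stable_baselines3` regardless of how it was requested on the command line):

```
import gymnasium
import stable_baselines3

print("gymnasium:", gymnasium.__version__)
print("stable-baselines3:", stable_baselines3.__version__)

# The 2.0.0a1 alpha and later are the gymnasium-compatible releases.
assert stable_baselines3.__version__.startswith("2."), "expected SB3 2.0.0a1 or newer"
```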
benchmark/ddpg.sh: file mode changed from 100644 to 100755 (made executable); no content changes.
40 changes: 26 additions & 14 deletions cleanrl/ddpg_continuous_action.py
@@ -5,9 +5,10 @@
import time
from distutils.util import strtobool

import gym
import gymnasium as gym
import numpy as np
import pybullet_envs # noqa

# import pybullet_envs # noqa
vwxyzjn (Owner) commented:

Instead of commenting, just remove it :)

import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -37,7 +38,7 @@ def parse_args():
help="whether to capture videos of the agent performances (check out `videos` folder)")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="HopperBulletEnv-v0",
parser.add_argument("--env-id", type=str, default="HalfCheetah-v2",
help="the id of the environment")
parser.add_argument("--total-timesteps", type=int, default=1000000,
help="total timesteps of the experiments")
@@ -66,12 +67,14 @@ def parse_args():

def make_env(env_id, seed, idx, capture_video, run_name):
def thunk():
env = gym.make(env_id)
if capture_video:
env = gym.make(env_id, render_mode="rgb_array")
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)
if capture_video:
if idx == 0:
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
env.seed(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)
return env
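Because the diff interleaves removed and added lines, here is roughly what the ported `make_env` thunk looks like after this change (reconstructed from the diff above; indentation and the trailing `return thunk` are inferred):

```
def make_env(env_id, seed, idx, capture_video, run_name):
    def thunk():
        # render_mode must be requested at construction time in gymnasium
        if capture_video:
            env = gym.make(env_id, render_mode="rgb_array")
        else:
            env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video:
            if idx == 0:
                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        # env.seed(seed) no longer exists; seeding now happens via reset(seed=...)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk
```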
@@ -128,7 +131,7 @@ def forward(self, x):
sync_tensorboard=True,
config=vars(args),
name=run_name,
monitor_gym=True,
# monitor_gym=True, # no longer works for gymnasium
save_code=True,
)
writer = SummaryWriter(f"runs/{run_name}")
@@ -164,12 +167,14 @@ def forward(self, x):
envs.single_observation_space,
envs.single_action_space,
device,
handle_timeout_termination=True,
handle_timeout_termination=False,
)
start_time = time.time()

# TRY NOT TO MODIFY: start the game
obs = envs.reset()
obs, _ = envs.reset(seed=args.seed)
video_filenames = set()

for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
if global_step < args.learning_starts:
@@ -181,22 +186,23 @@ def forward(self, x):
actions = actions.cpu().numpy().clip(envs.single_action_space.low, envs.single_action_space.high)

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)
next_obs, rewards, terminateds, truncateds, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
for info in infos:
if "episode" in info.keys():

if "final_info" in infos:
for info in infos["final_info"]:
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
break

# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(dones):
for idx, d in enumerate(truncateds):
if d:
real_next_obs[idx] = infos[idx]["terminal_observation"]
rb.add(obs, real_next_obs, actions, rewards, dones, infos)
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminateds, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
@@ -237,4 +243,10 @@ def forward(self, x):
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

envs.close()

if args.track and args.capture_video:
for filename in os.listdir(f"videos/{run_name}"):
if filename not in video_filenames and filename.endswith(".mp4"):
wandb.log({f"videos": wandb.Video(f"videos/{run_name}/{filename}")})
video_filenames.add(filename)
writer.close()
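Both scripts use the same replay-buffer pattern under the new API: the autoresetting vector env already returns the next episode's first observation in `next_obs`, so the true last observation of a truncated episode is recovered from `infos["final_observation"]`, and only `terminateds` is stored as the done flag so that time-limit truncations still bootstrap (which is also why `handle_timeout_termination=False` is passed to the buffer). A rough standalone sketch of that pattern, with names following the diff above:

```
next_obs, rewards, terminateds, truncateds, infos = envs.step(actions)

# Sub-envs auto-reset, so next_obs already holds the first observation of the
# next episode wherever an episode just ended. Recover the real last
# observation for truncated episodes from infos["final_observation"].
# (Terminated episodes don't need this: their done flag masks the bootstrap.)
real_next_obs = next_obs.copy()
for idx, trunc in enumerate(truncateds):
    if trunc:
        real_next_obs[idx] = infos["final_observation"][idx]

# Store `terminateds`, not `truncateds`: a time-limit cutoff is not a true
# terminal state, so the target value should still bootstrap through it.
rb.add(obs, real_next_obs, actions, rewards, terminateds, infos)
obs = next_obs
```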
34 changes: 21 additions & 13 deletions cleanrl/ddpg_continuous_action_jax.py
@@ -8,12 +8,11 @@

import flax
import flax.linen as nn
import gym
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
import optax
import pybullet_envs # noqa
from flax.training.train_state import TrainState
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter
@@ -65,12 +64,14 @@ def parse_args():

def make_env(env_id, seed, idx, capture_video, run_name):
def thunk():
env = gym.make(env_id)
if capture_video:
env = gym.make(env_id, render_mode="rgb_array")
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)
if capture_video:
if idx == 0:
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
env.seed(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)
return env
@@ -124,7 +125,7 @@ class TrainState(TrainState):
sync_tensorboard=True,
config=vars(args),
name=run_name,
monitor_gym=True,
monitor_gym=True, # does not work on gymnasium
save_code=True,
)
writer = SummaryWriter(f"runs/{run_name}")
@@ -150,11 +151,12 @@ class TrainState(TrainState):
envs.single_observation_space,
envs.single_action_space,
device="cpu",
handle_timeout_termination=True,
handle_timeout_termination=False,
)

# TRY NOT TO MODIFY: start the game
obs = envs.reset()
obs, _ = envs.reset()
video_filenames = set()
action_scale = np.array((envs.action_space.high - envs.action_space.low) / 2.0)
action_bias = np.array((envs.action_space.high + envs.action_space.low) / 2.0)
actor = Actor(
@@ -235,22 +237,22 @@ def actor_loss(params):
)

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)
next_obs, rewards, terminateds, truncateds, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
for info in infos:
if "episode" in info.keys():
if "final_info" in infos:
for info in infos["final_info"]:
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
break

# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(dones):
for idx, d in enumerate(truncateds):
if d:
real_next_obs[idx] = infos[idx]["terminal_observation"]
rb.add(obs, real_next_obs, actions, rewards, dones, infos)
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminateds, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
@@ -282,4 +284,10 @@ def actor_loss(params):
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

envs.close()
if args.track and args.capture_video:
for filename in os.listdir(f"videos/{run_name}"):
if filename not in video_filenames and filename.endswith(".mp4"):
wandb.log({f"videos": wandb.Video(f"videos/{run_name}/{filename}")})
video_filenames.add(filename)

writer.close()
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -18,7 +18,7 @@ wandb = "^0.13.6"
gym = "0.23.1"
torch = ">=1.12.1"
stable-baselines3 = "1.2.0"
gymnasium = "^0.26.3"
gymnasium = "^0.28.1"
moviepy = "^1.0.3"
pygame = "2.1.0"
huggingface-hub = "^0.11.1"