-
Notifications
You must be signed in to change notification settings - Fork 704
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Gymnasium support for DDPG continuous (+Jax) #371
Merged
vwxyzjn
merged 16 commits into
vwxyzjn:master
from
arjun-kg:ddpg_continuous_action_gymnasium
May 3, 2023
Merged
Changes from 3 commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
54439c8
ddpg continuous + jax
arjun-kg 6f4f072
fix video recording
arjun-kg 91aae1d
Merge branch 'master' of https://github.com/arjun-kg/cleanrl into ddp…
arjun-kg 4171609
remove pybullet
arjun-kg f2608a3
move to usage docs
arjun-kg 6e6a5b5
isort
arjun-kg d8dd801
update lock files
vwxyzjn 06f41ce
Merge branch 'master' into ddpg_continuous_action_gymnasium
vwxyzjn d9825b4
try trigger CI
vwxyzjn 91f770f
Merge branch 'master' into ddpg_continuous_action_gymnasium
vwxyzjn a05e618
Merge branch 'master' into ddpg_continuous_action_gymnasium
vwxyzjn 03b3c7e
update ddpg default v4 environments
vwxyzjn b6d8598
trigger CI
vwxyzjn 8f0029d
install jax dependency
vwxyzjn d630603
fix CI
vwxyzjn 65000a8
remove windows CI
vwxyzjn File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,7 +30,15 @@ You can read more about CleanRL in our [JMLR paper](https://www.jmlr.org/papers/ | |
|
||
CleanRL only contains implementations of **online** deep reinforcement learning algorithms. If you are looking for **offline** algorithms, please check out [tinkoff-ai/CORL](https://github.com/tinkoff-ai/CORL), which shares a similar design philosophy as CleanRL. | ||
|
||
> ℹ️ **Support for Gymnasium**: [Farama-Foundation/Gymnasium](https://github.com/Farama-Foundation/Gymnasium) is the next generation of [`openai/gym`](https://github.com/openai/gym) that will continue to be maintained and introduce new features. Please see their [announcement](https://farama.org/Announcing-The-Farama-Foundation) for further detail. We are migrating to `gymnasium` and the progress can be tracked in [vwxyzjn/cleanrl#277](https://github.com/vwxyzjn/cleanrl/pull/277). | ||
> ℹ️ **Support for Gymnasium**: [Farama-Foundation/Gymnasium](https://github.com/Farama-Foundation/Gymnasium) is the next generation of [`openai/gym`](https://github.com/openai/gym) that will continue to be maintained and introduce new features. Please see their [announcement](https://farama.org/Announcing-The-Farama-Foundation) for further detail. We are migrating to `gymnasium` and the progress can be tracked in [vwxyzjn/cleanrl#277](https://github.com/vwxyzjn/cleanrl/pull/277). | ||
|
||
Currently, `ppo_continuous_action_isaacgym.py`, `ddpg_continuous_action_jax.py`, `ddpg_continuous_action.py` have been ported to gymnasium. | ||
|
||
Please note that, `stable-baselines3` version `1.2` does not support `gymnasium`. To use these scripts, please install the `alpha1` version like, | ||
|
||
``` | ||
poetry run pip install sb3==2.0.0a1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we move this to the usage docs? |
||
``` | ||
|
||
|
||
|
||
|
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,9 +5,10 @@ | |
import time | ||
from distutils.util import strtobool | ||
|
||
import gym | ||
import gymnasium as gym | ||
import numpy as np | ||
import pybullet_envs # noqa | ||
|
||
# import pybullet_envs # noqa | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of commenting, just remove it :) |
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
@@ -37,7 +38,7 @@ def parse_args(): | |
help="whether to capture videos of the agent performances (check out `videos` folder)") | ||
|
||
# Algorithm specific arguments | ||
parser.add_argument("--env-id", type=str, default="HopperBulletEnv-v0", | ||
parser.add_argument("--env-id", type=str, default="HalfCheetah-v2", | ||
help="the id of the environment") | ||
parser.add_argument("--total-timesteps", type=int, default=1000000, | ||
help="total timesteps of the experiments") | ||
|
@@ -66,12 +67,14 @@ def parse_args(): | |
|
||
def make_env(env_id, seed, idx, capture_video, run_name): | ||
def thunk(): | ||
env = gym.make(env_id) | ||
if capture_video: | ||
env = gym.make(env_id, render_mode="rgb_array") | ||
else: | ||
env = gym.make(env_id) | ||
env = gym.wrappers.RecordEpisodeStatistics(env) | ||
if capture_video: | ||
if idx == 0: | ||
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") | ||
env.seed(seed) | ||
env.action_space.seed(seed) | ||
env.observation_space.seed(seed) | ||
return env | ||
|
@@ -128,7 +131,7 @@ def forward(self, x): | |
sync_tensorboard=True, | ||
config=vars(args), | ||
name=run_name, | ||
monitor_gym=True, | ||
# monitor_gym=True, # no longer works for gymnasium | ||
save_code=True, | ||
) | ||
writer = SummaryWriter(f"runs/{run_name}") | ||
|
@@ -164,12 +167,14 @@ def forward(self, x): | |
envs.single_observation_space, | ||
envs.single_action_space, | ||
device, | ||
handle_timeout_termination=True, | ||
handle_timeout_termination=False, | ||
) | ||
start_time = time.time() | ||
|
||
# TRY NOT TO MODIFY: start the game | ||
obs = envs.reset() | ||
obs, _ = envs.reset(seed=args.seed) | ||
video_filenames = set() | ||
|
||
for global_step in range(args.total_timesteps): | ||
# ALGO LOGIC: put action logic here | ||
if global_step < args.learning_starts: | ||
|
@@ -181,22 +186,23 @@ def forward(self, x): | |
actions = actions.cpu().numpy().clip(envs.single_action_space.low, envs.single_action_space.high) | ||
|
||
# TRY NOT TO MODIFY: execute the game and log data. | ||
next_obs, rewards, dones, infos = envs.step(actions) | ||
next_obs, rewards, terminateds, truncateds, infos = envs.step(actions) | ||
|
||
# TRY NOT TO MODIFY: record rewards for plotting purposes | ||
for info in infos: | ||
if "episode" in info.keys(): | ||
|
||
if "final_info" in infos: | ||
for info in infos["final_info"]: | ||
print(f"global_step={global_step}, episodic_return={info['episode']['r']}") | ||
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) | ||
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) | ||
break | ||
|
||
# TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation` | ||
real_next_obs = next_obs.copy() | ||
for idx, d in enumerate(dones): | ||
for idx, d in enumerate(truncateds): | ||
if d: | ||
real_next_obs[idx] = infos[idx]["terminal_observation"] | ||
rb.add(obs, real_next_obs, actions, rewards, dones, infos) | ||
real_next_obs[idx] = infos["final_observation"][idx] | ||
rb.add(obs, real_next_obs, actions, rewards, terminateds, infos) | ||
|
||
# TRY NOT TO MODIFY: CRUCIAL step easy to overlook | ||
obs = next_obs | ||
|
@@ -237,4 +243,10 @@ def forward(self, x): | |
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) | ||
|
||
envs.close() | ||
|
||
if args.track and args.capture_video: | ||
for filename in os.listdir(f"videos/{run_name}"): | ||
if filename not in video_filenames and filename.endswith(".mp4"): | ||
wandb.log({f"videos": wandb.Video(f"videos/{run_name}/{filename}")}) | ||
video_filenames.add(filename) | ||
writer.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ppo_continuous_action_isaacgym.py
should not be included, right? It should beppo_continuous_action.py