Add gymnasium support for DQN (#370)
* Add gymnasium dqn.py

* Add gymnasium support for dqn_jax.py

* moved assert to parse func

* add gymnasium support for dqn atari

* black formatting

* fix make_env for rendering

* moved np to jnp

* add warning message and update dependencies

* update test cases

* update test cases

* pre-commit

* bump shimmy version

* update shimmy

* update shimmy

* trigger CI

* trigger CI

* trigger CI

* fix poetry

---------

Co-authored-by: Costa Huang <[email protected]>
vcharraut and vwxyzjn authored May 3, 2023
1 parent 790c917 commit 39670fc
Showing 24 changed files with 379 additions and 246 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/tests.yaml
@@ -39,6 +39,10 @@ jobs:
- name: Install jax
if: runner.os == 'Linux' || runner.os == 'macOS'
run: poetry install -E "pytest jax"
+- name: Run gymnasium migration dependencies
+run: poetry run pip install "stable_baselines3==2.0.0a1"
+- name: Run gymnasium tests
+run: poetry run pytest tests/test_classic_control_gymnasium.py
- name: Run core tests with jax
if: runner.os == 'Linux' || runner.os == 'macOS'
run: poetry run pytest tests/test_classic_control_jax.py
@@ -78,6 +82,10 @@ jobs:
- name: Install jax
if: runner.os == 'Linux' || runner.os == 'macOS'
run: poetry install -E "pytest atari jax"
+- name: Run gymnasium migration dependencies
+run: poetry run pip install "stable_baselines3==2.0.0a1" "gymnasium[atari,accept-rom-license]==0.28.1" "ale-py==0.8.1"
+- name: Run gymnasium tests
+run: poetry run pytest tests/test_atari_gymnasium.py
- name: Run core tests with jax
if: runner.os == 'Linux' || runner.os == 'macOS'
run: poetry run pytest tests/test_atari_jax.py
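The two new CI steps install the Gymnasium-compatible SB3 alpha and then run the new gymnasium test modules. The contents of those test files are not shown in this diff, so the snippet below is only a hedged sketch of the kind of subprocess-based smoke test they are likely to contain; the flags and tiny budgets are illustrative.

```python
# Hypothetical sketch of what tests/test_classic_control_gymnasium.py might run;
# the actual file is not part of this diff, and the flags/budgets are illustrative.
import subprocess


def test_dqn():
    subprocess.run(
        "python cleanrl/dqn.py --total-timesteps 200 --learning-starts 100 --buffer-size 200",
        shell=True,
        check=True,
    )
```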
53 changes: 35 additions & 18 deletions cleanrl/dqn.py
@@ -5,7 +5,7 @@
import time
from distutils.util import strtobool

-import gym
+import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
@@ -48,6 +48,8 @@ def parse_args():
help="total timesteps of the experiments")
parser.add_argument("--learning-rate", type=float, default=2.5e-4,
help="the learning rate of the optimizer")
parser.add_argument("--num-envs", type=int, default=1,
help="the number of parallel game environments")
parser.add_argument("--buffer-size", type=int, default=10000,
help="the replay memory buffer size")
parser.add_argument("--gamma", type=float, default=0.99,
@@ -70,19 +72,21 @@
help="the frequency of training")
args = parser.parse_args()
# fmt: on
+assert args.num_envs == 1, "vectorized envs are not supported at the moment"

return args


def make_env(env_id, seed, idx, capture_video, run_name):
def thunk():
-env = gym.make(env_id)
+if capture_video and idx == 0:
+env = gym.make(env_id, render_mode="rgb_array")
+env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
+else:
+env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)
-if capture_video:
-if idx == 0:
-env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
-env.seed(seed)
env.action_space.seed(seed)
-env.observation_space.seed(seed)

return env

return thunk
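The make_env change reflects how video capture works in Gymnasium: the render mode is fixed when the environment is constructed, and env.seed() no longer exists (seeding moves to reset() and to the spaces). A standalone sketch of the same pattern, with an illustrative env id and video folder:

```python
import gymnasium as gym

# Gymnasium fixes the render mode at construction time, so RecordVideo only works
# if the env was created with render_mode="rgb_array".
env = gym.make("CartPole-v1", render_mode="rgb_array")  # env id illustrative
env = gym.wrappers.RecordVideo(env, "videos/demo-run")  # folder name illustrative
env = gym.wrappers.RecordEpisodeStatistics(env)

# env.seed(...) is gone; episodes are seeded through reset(), and spaces are seeded
# directly, which is why the thunk above only keeps env.action_space.seed(seed).
env.action_space.seed(1)
obs, info = env.reset(seed=1)
```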
@@ -110,6 +114,15 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):


if __name__ == "__main__":
+import stable_baselines3 as sb3
+
+if sb3.__version__ < "2.0":
+raise ValueError(
+"""Ongoing migration: run the following command to install the new dependencies:
+poetry run pip install "stable_baselines3==2.0.0a1"
+"""
+)
args = parse_args()
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
@@ -139,7 +152,9 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

# env setup
-envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
+envs = gym.vector.SyncVectorEnv(
+[make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
+)
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

q_network = QNetwork(envs).to(device)
@@ -152,12 +167,12 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
envs.single_observation_space,
envs.single_action_space,
device,
-handle_timeout_termination=True,
+handle_timeout_termination=False,
)
start_time = time.time()

# TRY NOT TO MODIFY: start the game
-obs = envs.reset()
+obs, _ = envs.reset(seed=args.seed)
for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
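The changes in the hunks above and below come from the core Gym-to-Gymnasium API shift: reset() now takes a seed and returns (obs, info), and step() returns a 5-tuple in which the old done flag is split in two. A minimal illustration (env id illustrative):

```python
import gymnasium as gym

env = gym.make("CartPole-v1")                 # env id illustrative
obs, info = env.reset(seed=1)                 # was: obs = env.reset()
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
# the old `done` flag is now split in two:
#   terminated -> the MDP reached a terminal state
#   truncated  -> the episode was cut off (e.g. by a time limit)
done = terminated or truncated
```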
@@ -168,23 +183,25 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
actions = torch.argmax(q_values, dim=1).cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
-next_obs, rewards, dones, infos = envs.step(actions)
+next_obs, rewards, terminated, truncated, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
-for info in infos:
-if "episode" in info.keys():
+if "final_info" in infos:
+for info in infos["final_info"]:
+# Skip the envs that are not done
+if "episode" not in info:
+continue
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)
-break

-# TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
+# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
-for idx, d in enumerate(dones):
+for idx, d in enumerate(truncated):
if d:
-real_next_obs[idx] = infos[idx]["terminal_observation"]
-rb.add(obs, real_next_obs, actions, rewards, dones, infos)
+real_next_obs[idx] = infos["final_observation"][idx]
+rb.add(obs, real_next_obs, actions, rewards, terminated, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
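The replay-buffer handling above is the subtlest part of the migration. Gymnasium's vector envs auto-reset a sub-env the moment its episode ends, so next_obs already belongs to the new episode; the true last observation is stashed in infos["final_observation"] (and the episode stats in infos["final_info"]). Only truncated transitions need the patch, because for terminated ones the bootstrap target is masked by the done flag anyway, which is also why handle_timeout_termination can now be False. A self-contained sketch of that pattern (env id illustrative):

```python
import gymnasium as gym
import numpy as np

envs = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")])  # env id illustrative
obs, _ = envs.reset(seed=1)
for _ in range(300):
    actions = np.array([envs.single_action_space.sample()])
    next_obs, rewards, terminated, truncated, infos = envs.step(actions)
    real_next_obs = next_obs.copy()
    for idx, trunc in enumerate(truncated):
        if trunc:
            # the auto-reset already replaced next_obs[idx] with the first
            # observation of the new episode; recover the true final one
            real_next_obs[idx] = infos["final_observation"][idx]
    # (obs, real_next_obs, actions, rewards, terminated) is what dqn.py hands to rb.add
    obs = next_obs
```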
53 changes: 35 additions & 18 deletions cleanrl/dqn_atari.py
@@ -5,7 +5,7 @@
import time
from distutils.util import strtobool

-import gym
+import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
@@ -55,6 +55,8 @@ def parse_args():
help="total timesteps of the experiments")
parser.add_argument("--learning-rate", type=float, default=1e-4,
help="the learning rate of the optimizer")
parser.add_argument("--num-envs", type=int, default=1,
help="the number of parallel game environments")
parser.add_argument("--buffer-size", type=int, default=1000000,
help="the replay memory buffer size")
parser.add_argument("--gamma", type=float, default=0.99,
@@ -77,16 +79,19 @@
help="the frequency of training")
args = parser.parse_args()
# fmt: on
+assert args.num_envs == 1, "vectorized envs are not supported at the moment"

return args


def make_env(env_id, seed, idx, capture_video, run_name):
def thunk():
-env = gym.make(env_id)
+if capture_video and idx == 0:
+env = gym.make(env_id, render_mode="rgb_array")
+env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
+else:
+env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)
-if capture_video:
-if idx == 0:
-env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
env = EpisodicLifeEnv(env)
@@ -96,9 +101,8 @@ def thunk():
env = gym.wrappers.ResizeObservation(env, (84, 84))
env = gym.wrappers.GrayScaleObservation(env)
env = gym.wrappers.FrameStack(env, 4)
-env.seed(seed)
env.action_space.seed(seed)
-env.observation_space.seed(seed)

return env

return thunk
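dqn_atari.py follows the same make_env pattern with the Atari preprocessing stack on top. The snippet below sketches an equivalent standalone construction under the dependencies the warning installs (gymnasium[atari,accept-rom-license] and ale-py); the wrapper imports assume CleanRL's usual source, stable_baselines3.common.atari_wrappers, which this diff does not show.

```python
import gymnasium as gym
# assumed import path (not shown in this diff): CleanRL takes these wrappers from SB3
from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)

env = gym.make("BreakoutNoFrameskip-v4", render_mode="rgb_array")  # env id illustrative
env = gym.wrappers.RecordEpisodeStatistics(env)
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
env = EpisodicLifeEnv(env)
if "FIRE" in env.unwrapped.get_action_meanings():
    env = FireResetEnv(env)
env = ClipRewardEnv(env)
env = gym.wrappers.ResizeObservation(env, (84, 84))
env = gym.wrappers.GrayScaleObservation(env)
env = gym.wrappers.FrameStack(env, 4)

obs, info = env.reset(seed=1)  # obs stacks 4 gray-scale 84x84 frames
```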
@@ -131,6 +135,15 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):


if __name__ == "__main__":
+import stable_baselines3 as sb3
+
+if sb3.__version__ < "2.0":
+raise ValueError(
+"""Ongoing migration: run the following command to install the new dependencies:
+poetry run pip install "stable_baselines3==2.0.0a1" "gymnasium[atari,accept-rom-license]==0.28.1" "ale-py==0.8.1"
+"""
+)
args = parse_args()
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
@@ -160,7 +173,9 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

# env setup
-envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
+envs = gym.vector.SyncVectorEnv(
+[make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
+)
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

q_network = QNetwork(envs).to(device)
@@ -174,12 +189,12 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
envs.single_action_space,
device,
optimize_memory_usage=True,
-handle_timeout_termination=True,
+handle_timeout_termination=False,
)
start_time = time.time()

# TRY NOT TO MODIFY: start the game
-obs = envs.reset()
+obs, _ = envs.reset(seed=args.seed)
for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
@@ -190,23 +205,25 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
actions = torch.argmax(q_values, dim=1).cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
-next_obs, rewards, dones, infos = envs.step(actions)
+next_obs, rewards, terminated, truncated, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
-for info in infos:
-if "episode" in info.keys():
+if "final_info" in infos:
+for info in infos["final_info"]:
+# Skip the envs that are not done
+if "episode" not in info:
+continue
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)
-break

-# TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
+# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
-for idx, d in enumerate(dones):
+for idx, d in enumerate(truncated):
if d:
-real_next_obs[idx] = infos[idx]["terminal_observation"]
-rb.add(obs, real_next_obs, actions, rewards, dones, infos)
+real_next_obs[idx] = infos["final_observation"][idx]
+rb.add(obs, real_next_obs, actions, rewards, terminated, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
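After running the pip command from the warning above, a quick way to confirm the pinned migration stack is in place; the version numbers are the ones pinned in this commit's CI steps.

```python
# Sanity check of the installed migration dependencies; expected values come from
# the pins in .github/workflows/tests.yaml above.
import ale_py
import gymnasium
import stable_baselines3 as sb3

print(gymnasium.__version__)  # expected: 0.28.1 (per the CI pin)
print(ale_py.__version__)     # expected: 0.8.1 (per the CI pin)
print(sb3.__version__)        # expected: 2.0.0a1 (per the CI pin)
```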
(Diffs for the remaining 21 changed files are not shown.)

1 comment on commit 39670fc

vercel bot commented on 39670fc May 3, 2023


Successfully deployed to the following URLs:

cleanrl – ./

cleanrl-git-master-vwxyzjn.vercel.app
cleanrl.vercel.app
cleanrl-vwxyzjn.vercel.app
docs.cleanrl.dev
