Add gymnasium support for DQN #370

Merged (19 commits) on May 3, 2023
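For context, the heart of this migration is Gymnasium's revised environment API: reset() now accepts a seed argument and returns (obs, info), and step() returns five values, splitting the old done flag into terminated and truncated. A minimal sketch of the new interaction loop, assuming the CartPole-v1 environment, illustrates the pattern the edited scripts follow:

import gymnasium as gym

# Gymnasium API: reset() takes a seed and returns (obs, info);
# step() returns (obs, reward, terminated, truncated, info).
env = gym.make("CartPole-v1")
obs, info = env.reset(seed=0)
for _ in range(100):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()
env.close()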
8 changes: 8 additions & 0 deletions .github/workflows/tests.yaml
@@ -39,6 +39,10 @@ jobs:
- name: Install jax
if: runner.os == 'Linux' || runner.os == 'macOS'
run: poetry install -E "pytest jax"
- name: Run gymnasium migration dependencies
run: poetry run pip install "stable_baselines3==2.0.0a1"
- name: Run gymnasium tests
run: poetry run pytest tests/test_classic_control_gymnasium.py
- name: Run core tests with jax
if: runner.os == 'Linux' || runner.os == 'macOS'
run: poetry run pytest tests/test_classic_control_jax.py
@@ -78,6 +82,10 @@ jobs:
- name: Install jax
if: runner.os == 'Linux' || runner.os == 'macOS'
run: poetry install -E "pytest atari jax"
- name: Run gymnasium migration dependencies
run: poetry run pip install "stable_baselines3==2.0.0a1" "gymnasium[atari,accept-rom-license]==0.28.1" "ale-py==0.8.1"
- name: Run gymnasium tests
run: poetry run pytest tests/test_atari_gymnasium.py
- name: Run core tests with jax
if: runner.os == 'Linux' || runner.os == 'macOS'
run: poetry run pytest tests/test_atari_jax.py
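The test files referenced by these new CI steps are not part of this diff. As a rough illustration only, a smoke test such as tests/test_classic_control_gymnasium.py would typically shell out to the migrated script with a tiny step budget so the job merely exercises the Gymnasium code path; the exact flags below are an assumption:

import subprocess

def test_dqn_gymnasium():
    # Hypothetical smoke test: run dqn.py for a handful of steps and fail on a
    # non-zero exit code, without checking learning performance.
    subprocess.run(
        "python cleanrl/dqn.py --learning-starts 200 --total-timesteps 205",
        shell=True,
        check=True,
    )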
53 changes: 35 additions & 18 deletions cleanrl/dqn.py
@@ -5,7 +5,7 @@
import time
from distutils.util import strtobool

import gym
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
@@ -48,6 +48,8 @@ def parse_args():
help="total timesteps of the experiments")
parser.add_argument("--learning-rate", type=float, default=2.5e-4,
help="the learning rate of the optimizer")
parser.add_argument("--num-envs", type=int, default=1,
help="the number of parallel game environments")
parser.add_argument("--buffer-size", type=int, default=10000,
help="the replay memory buffer size")
parser.add_argument("--gamma", type=float, default=0.99,
@@ -70,19 +72,21 @@
help="the frequency of training")
args = parser.parse_args()
# fmt: on
assert args.num_envs == 1, "vectorized envs are not supported at the moment"

return args


def make_env(env_id, seed, idx, capture_video, run_name):
def thunk():
env = gym.make(env_id)
if capture_video and idx == 0:
env = gym.make(env_id, render_mode="rgb_array")
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)
if capture_video:
if idx == 0:
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
env.seed(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)

return env

return thunk
@@ -110,6 +114,15 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):


if __name__ == "__main__":
import stable_baselines3 as sb3

if sb3.__version__ < "2.0":
raise ValueError(
"""Ongoing migration: run the following command to install the new dependencies:

poetry run pip install "stable_baselines3==2.0.0a1"
"""
)
args = parse_args()
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
@@ -139,7 +152,9 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

# env setup
envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
envs = gym.vector.SyncVectorEnv(
[make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
)
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

q_network = QNetwork(envs).to(device)
@@ -152,12 +167,12 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
envs.single_observation_space,
envs.single_action_space,
device,
handle_timeout_termination=True,
handle_timeout_termination=False,
)
start_time = time.time()

# TRY NOT TO MODIFY: start the game
obs = envs.reset()
obs, _ = envs.reset(seed=args.seed)
for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
@@ -168,23 +183,25 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
actions = torch.argmax(q_values, dim=1).cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)
next_obs, rewards, terminated, truncated, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
for info in infos:
if "episode" in info.keys():
if "final_info" in infos:
for info in infos["final_info"]:
# Skip the envs that are not done
if "episode" not in info:
continue
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)
break

# TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(dones):
for idx, d in enumerate(truncated):
if d:
real_next_obs[idx] = infos[idx]["terminal_observation"]
rb.add(obs, real_next_obs, actions, rewards, dones, infos)
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminated, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
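The switch to handle_timeout_termination=False works together with the truncation handling above: only terminated is stored as the done flag, so episodes cut off by a time limit still bootstrap from their next state, and final_observation supplies the real next observation for those truncated steps. A minimal sketch of the intended TD target, with illustrative variable names not taken from the script:

import torch

def td_target(rewards, next_obs, terminated, target_network, gamma=0.99):
    # Terminated transitions contribute no future value; truncated (time-limit)
    # transitions keep the bootstrap term because the episode did not really end.
    with torch.no_grad():
        next_q = target_network(next_obs).max(dim=1).values
    return rewards + gamma * next_q * (1.0 - terminated.float())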
53 changes: 35 additions & 18 deletions cleanrl/dqn_atari.py
@@ -5,7 +5,7 @@
import time
from distutils.util import strtobool

import gym
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
@@ -55,6 +55,8 @@ def parse_args():
help="total timesteps of the experiments")
parser.add_argument("--learning-rate", type=float, default=1e-4,
help="the learning rate of the optimizer")
parser.add_argument("--num-envs", type=int, default=1,
help="the number of parallel game environments")
parser.add_argument("--buffer-size", type=int, default=1000000,
help="the replay memory buffer size")
parser.add_argument("--gamma", type=float, default=0.99,
@@ -77,16 +79,19 @@
help="the frequency of training")
args = parser.parse_args()
# fmt: on
assert args.num_envs == 1, "vectorized envs are not supported at the moment"

return args


def make_env(env_id, seed, idx, capture_video, run_name):
def thunk():
env = gym.make(env_id)
if capture_video and idx == 0:
env = gym.make(env_id, render_mode="rgb_array")
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)
if capture_video:
if idx == 0:
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
env = EpisodicLifeEnv(env)
@@ -96,9 +101,8 @@ def thunk():
env = gym.wrappers.ResizeObservation(env, (84, 84))
env = gym.wrappers.GrayScaleObservation(env)
env = gym.wrappers.FrameStack(env, 4)
env.seed(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)

return env

return thunk
@@ -131,6 +135,15 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):


if __name__ == "__main__":
import stable_baselines3 as sb3

if sb3.__version__ < "2.0":
raise ValueError(
"""Ongoing migration: run the following command to install the new dependencies:

poetry run pip install "stable_baselines3==2.0.0a1" "gymnasium[atari,accept-rom-license]==0.28.1" "ale-py==0.8.1"
"""
)
args = parse_args()
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
@@ -160,7 +173,9 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

# env setup
envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
envs = gym.vector.SyncVectorEnv(
[make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
)
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

q_network = QNetwork(envs).to(device)
@@ -174,12 +189,12 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
envs.single_action_space,
device,
optimize_memory_usage=True,
handle_timeout_termination=True,
handle_timeout_termination=False,
)
start_time = time.time()

# TRY NOT TO MODIFY: start the game
obs = envs.reset()
obs, _ = envs.reset(seed=args.seed)
for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
@@ -190,23 +205,25 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
actions = torch.argmax(q_values, dim=1).cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)
next_obs, rewards, terminated, truncated, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
for info in infos:
if "episode" in info.keys():
if "final_info" in infos:
for info in infos["final_info"]:
# Skip the envs that are not done
if "episode" not in info:
continue
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)
break

# TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(dones):
for idx, d in enumerate(truncated):
if d:
real_next_obs[idx] = infos[idx]["terminal_observation"]
rb.add(obs, real_next_obs, actions, rewards, dones, infos)
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminated, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
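One more migration detail worth noting: env.seed() no longer exists in Gymnasium, which is why the edited make_env drops it and both scripts pass the seed to the first reset() call instead, keeping only the action-space seeding. A condensed sketch of that pattern, assuming CartPole-v1:

import gymnasium as gym

def make_env(env_id, seed):
    def thunk():
        env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)  # seed the action space's sampler
        return env
    return thunk

envs = gym.vector.SyncVectorEnv([make_env("CartPole-v1", seed=1)])
obs, info = envs.reset(seed=1)  # episode randomness is seeded here, not via env.seed()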