Update to support Gymnasium #277

Closed · wants to merge 19 commits
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
@@ -128,7 +128,7 @@ jobs:
run: poetry install --with pytest,procgen
- name: Downgrade setuptools
run: poetry run pip install setuptools==59.5.0
- name: Run pybullet tests
- name: Run procgen tests
run: poetry run pytest tests/test_procgen.py

test-mujoco-envs:
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -11,7 +11,7 @@ Good luck and have fun!
```bash
poetry install
poetry install --with atari
poetry install --with pybullet
poetry install --with mujoco
```

Then you can run the scripts under the poetry environment in two ways: `poetry run` or `poetry shell`.
20 changes: 15 additions & 5 deletions README.md
@@ -11,6 +11,16 @@
[<img src="https://img.shields.io/badge/%F0%9F%A4%97%20Models-Huggingface-F8D521">](https://huggingface.co/cleanrl)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vwxyzjn/cleanrl/blob/master/docs/get-started/CleanRL_Huggingface_Integration_Demo.ipynb)

# ⚠️ NOTE: Gym 0.26.1 Migration

This branch is an ongoing effort to integrate the latest gym into CleanRL. Check out [vwxyzjn/cleanrl#277](https://github.com/vwxyzjn/cleanrl/pull/277) for the current progress.

Things that work:
* `dqn.py`
* `dqn_jax.py`
* `ppo.py`
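
For reference, a minimal sketch of the `gym>=0.26` reset/step API these scripts are being ported to (the env id is only an illustration):

```python
import gym

env = gym.make("CartPole-v1")
obs, info = env.reset(seed=0)  # reset now returns (obs, info) and accepts a seed
done = False
while not done:
    action = env.action_space.sample()
    # step now returns five values; the episode ends when either flag is True
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
env.close()
```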

----------

CleanRL is a Deep Reinforcement Learning library that provides high-quality single-file implementations with research-friendly features. The implementations are clean and simple, yet they scale to run thousands of experiments using AWS Batch. The highlight features of CleanRL are:

@@ -93,11 +103,11 @@ python cleanrl/ppo_atari_envpool.py --env-id BreakoutNoFrameskip-v4
# Side effects such as lower sample efficiency might occur
poetry run python ppo_atari_envpool.py --clip-coef=0.2 --num-envs=16 --num-minibatches=8 --num-steps=128 --update-epochs=3

# pybullet
poetry install --with pybullet
python cleanrl/td3_continuous_action.py --env-id MinitaurBulletDuckEnv-v0
python cleanrl/ddpg_continuous_action.py --env-id MinitaurBulletDuckEnv-v0
python cleanrl/sac_continuous_action.py --env-id MinitaurBulletDuckEnv-v0
# mujoco
poetry install --with mujoco
python cleanrl/td3_continuous_action.py --env-id HalfCheetah-v4
python cleanrl/ddpg_continuous_action.py --env-id HalfCheetah-v4
python cleanrl/sac_continuous_action.py --env-id HalfCheetah-v4

# procgen
poetry install --with procgen
254 changes: 254 additions & 0 deletions cleanrl/atari_wrappers.py
@@ -0,0 +1,254 @@
from typing import Dict, Tuple

import gym
import numpy as np
from gym import spaces

try:
    import cv2  # pytype:disable=import-error

    cv2.ocl.setUseOpenCL(False)
except ImportError:
    cv2 = None

from stable_baselines3.common.type_aliases import Gym26ResetReturn, Gym26StepReturn


class NoopResetEnv(gym.Wrapper):
    """
    Sample initial states by taking a random number of no-ops on reset.
    No-op is assumed to be action 0.

    :param env: the environment to wrap
    :param noop_max: the maximum value of no-ops to run
    """

    def __init__(self, env: gym.Env, noop_max: int = 30):
        gym.Wrapper.__init__(self, env)
        self.noop_max = noop_max
        self.override_num_noops = None
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == "NOOP"

    def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]:
        self.env.reset(**kwargs)
        if self.override_num_noops is not None:
            noops = self.override_num_noops
        else:
            noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
        assert noops > 0
        obs = np.zeros(0)
        info = {}
        for _ in range(noops):
            obs, _, done, truncated, info = self.env.step(self.noop_action)
            if done or truncated:
                obs, info = self.env.reset(**kwargs)
        return obs, info


class FireResetEnv(gym.Wrapper):
    """
    Take action on reset for environments that are fixed until firing.

    :param env: the environment to wrap
    """

    def __init__(self, env: gym.Env):
        gym.Wrapper.__init__(self, env)
        assert env.unwrapped.get_action_meanings()[1] == "FIRE"
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]:
        self.env.reset(**kwargs)
        obs, _, done, truncated, _ = self.env.step(1)
        if done or truncated:
            self.env.reset(**kwargs)
        obs, _, done, truncated, _ = self.env.step(2)
        if done or truncated:
            self.env.reset(**kwargs)
        return obs, {}


class EpisodicLifeEnv(gym.Wrapper):
    """
    Make end-of-life == end-of-episode, but only reset on true game over.
    Done by DeepMind for the DQN and co. since it helps value estimation.

    :param env: the environment to wrap
    """

    def __init__(self, env: gym.Env):
        gym.Wrapper.__init__(self, env)
        self.lives = 0
        self.was_real_done = True

    def step(self, action: int) -> Gym26StepReturn:
        obs, reward, done, truncated, info = self.env.step(action)
        self.was_real_done = done
        # check current lives, make loss of life terminal,
        # then update lives to handle bonus lives
        lives = self.env.unwrapped.ale.lives()
        if 0 < lives < self.lives:
            # for Qbert sometimes we stay in lives == 0 condition for a few frames
            # so it's important to keep lives > 0, so that we only reset once
            # the environment advertises done.
            done = True
        self.lives = lives
        return obs, reward, done, truncated, info

    def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]:
        """
        Calls the Gym environment reset, only when lives are exhausted.
        This way all states are still reachable even though lives are episodic,
        and the learner need not know about any of this behind-the-scenes.

        :param kwargs: Extra keywords passed to env.reset() call
        :return: the first observation of the environment
        """
        if self.was_real_done:
            obs, info = self.env.reset(**kwargs)
        else:
            # no-op step to advance from terminal/lost life state
            obs, _, _, _, info = self.env.step(0)
        self.lives = self.env.unwrapped.ale.lives()
        return obs, info


class MaxAndSkipEnv(gym.Wrapper):
    """
    Return only every ``skip``-th frame (frameskipping)

    :param env: the environment
    :param skip: number of frames to skip between returned observations
    """

    def __init__(self, env: gym.Env, skip: int = 4):
        gym.Wrapper.__init__(self, env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=env.observation_space.dtype)
        self._skip = skip

    def step(self, action: int) -> Gym26StepReturn:
        """
        Step the environment with the given action.
        Repeat action, sum reward, and max over the last two observations.

        :param action: the action
        :return: observation, reward, terminated, truncated, information
        """
        total_reward = 0.0
        terminated = truncated = False
        for i in range(self._skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            done = terminated or truncated
            if i == self._skip - 2:
                self._obs_buffer[0] = obs
            if i == self._skip - 1:
                self._obs_buffer[1] = obs
            total_reward += reward
            if done:
                break
        # Note that the observation on the done=True frame
        # doesn't matter
        max_frame = self._obs_buffer.max(axis=0)

        return max_frame, total_reward, terminated, truncated, info

    def reset(self, **kwargs) -> Gym26ResetReturn:
        return self.env.reset(**kwargs)


class ClipRewardEnv(gym.RewardWrapper):
    """
    Clip the reward to {+1, 0, -1} by its sign.

    :param env: the environment
    """

    def __init__(self, env: gym.Env):
        gym.RewardWrapper.__init__(self, env)

    def reward(self, reward: float) -> float:
        """
        Bin reward to {+1, 0, -1} by its sign.

        :param reward: the raw reward
        :return: the clipped reward
        """
        return np.sign(reward)


class WarpFrame(gym.ObservationWrapper):
    """
    Convert to grayscale and warp frames to 84x84 (default)
    as done in the Nature paper and later work.

    :param env: the environment
    :param width: the target frame width
    :param height: the target frame height
    """

    def __init__(self, env: gym.Env, width: int = 84, height: int = 84):
        gym.ObservationWrapper.__init__(self, env)
        self.width = width
        self.height = height
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(self.height, self.width, 1), dtype=env.observation_space.dtype
        )

    def observation(self, frame: np.ndarray) -> np.ndarray:
        """
        Returns the current observation from a frame.

        :param frame: environment frame
        :return: the observation
        """
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return frame[:, :, None]


class AtariWrapper(gym.Wrapper):
    """
    Atari 2600 preprocessing

    Specifically:

    * NoopReset: obtain initial state by taking a random number of no-ops on reset.
    * Frame skipping: 4 by default
    * Max-pooling: most recent two observations
    * Termination signal when a life is lost.
    * Resize to a square image: 84x84 by default
    * Grayscale observation
    * Clip reward to {-1, 0, 1}

    :param env: gym environment
    :param noop_max: max number of no-ops
    :param frame_skip: the frequency at which the agent experiences the game.
    :param screen_size: resize Atari frame
    :param terminal_on_life_loss: if True, then step() returns done=True whenever a life is lost.
    :param clip_reward: if True (default), the reward is clipped to {-1, 0, 1} depending on its sign.
    """

    def __init__(
        self,
        env: gym.Env,
        noop_max: int = 30,
        frame_skip: int = 4,
        screen_size: int = 84,
        terminal_on_life_loss: bool = True,
        clip_reward: bool = True,
    ):
        if noop_max > 0:
            env = NoopResetEnv(env, noop_max=noop_max)
        if frame_skip > 0:
            env = MaxAndSkipEnv(env, skip=frame_skip)
        if terminal_on_life_loss:
            env = EpisodicLifeEnv(env)
        if "FIRE" in env.unwrapped.get_action_meanings():
            env = FireResetEnv(env)
        env = WarpFrame(env, width=screen_size, height=screen_size)
        if clip_reward:
            env = ClipRewardEnv(env)

        super().__init__(env)
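
A hypothetical usage sketch (not part of the diff) showing how the wrappers above compose; it assumes `ale-py` with the Atari ROMs installed, and the import path simply mirrors this file's location:

```python
import gym

from cleanrl.atari_wrappers import AtariWrapper

# Raw ALE env without built-in frame skip; AtariWrapper adds no-op resets,
# frame skipping with max-pooling, life-loss termination, 84x84 grayscale
# frames, and reward clipping.
env = gym.make("BreakoutNoFrameskip-v4")
env = AtariWrapper(env)

obs, info = env.reset(seed=0)
print(obs.shape)  # (84, 84, 1)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
```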
38 changes: 20 additions & 18 deletions cleanrl/c51.py
@@ -83,7 +83,7 @@ def thunk():
if capture_video:
if idx == 0:
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
env.seed(seed)

env.action_space.seed(seed)
env.observation_space.seed(seed)
return env
@@ -165,12 +165,12 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
envs.single_observation_space,
envs.single_action_space,
device,
handle_timeout_termination=True,
handle_timeout_termination=False,
Owner commented: Perfect!
)
start_time = time.time()

# TRY NOT TO MODIFY: start the game
obs = envs.reset()
obs, _ = envs.reset(seed=args.seed)
for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
@@ -181,23 +181,25 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
actions = actions.cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)
next_obs, rewards, terminateds, _, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
for info in infos:
if "episode" in info.keys():
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)
break

# TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(dones):
if d:
real_next_obs[idx] = infos[idx]["terminal_observation"]
rb.add(obs, real_next_obs, actions, rewards, dones, infos)
if "episode" in infos:
first_idx = infos["_episode"].nonzero()[0][0]
r = infos["episode"]["r"][first_idx]
l = infos["episode"]["l"][first_idx]
print(f"global_step={global_step}, episodic_return={r}")
writer.add_scalar("charts/episodic_return", r, global_step)
writer.add_scalar("charts/episodic_length", l, global_step)

# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs
if "final_observation" in infos:
real_next_obs = next_obs.copy()
for idx, d in enumerate(infos["_final_observation"]):
if d:
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminateds, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
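
A self-contained sketch of the vector-env bookkeeping pattern used in the hunk above, assuming `gym>=0.26`'s aggregated `infos` from `SyncVectorEnv` plus `RecordEpisodeStatistics`; `CartPole-v1` is only a stand-in env id:

```python
import gym
import numpy as np

envs = gym.vector.SyncVectorEnv(
    [lambda: gym.wrappers.RecordEpisodeStatistics(gym.make("CartPole-v1")) for _ in range(4)]
)
obs, _ = envs.reset(seed=0)
for _ in range(500):
    actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
    next_obs, rewards, terminateds, truncateds, infos = envs.step(actions)

    # Episode statistics arrive in the aggregated `infos` dict with a `_episode` mask.
    if "episode" in infos:
        first_idx = infos["_episode"].nonzero()[0][0]
        print("episodic_return:", infos["episode"]["r"][first_idx])

    # Sub-envs auto-reset, so the true last observation of a finished episode
    # lives under `final_observation`; `real_next_obs` is what would be stored
    # in a replay buffer.
    real_next_obs = next_obs
    if "final_observation" in infos:
        real_next_obs = next_obs.copy()
        for idx, d in enumerate(infos["_final_observation"]):
            if d:
                real_next_obs[idx] = infos["final_observation"][idx]

    obs = next_obs
```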