Update to support Gymnasium #277
Closed
Changes from all commits (19 commits)
4c39915 update gym version (arjun-kg)
a7c42fb update done, reset (arjun-kg)
0c0a647 bump ale-py version, dont handle timeout in rb (arjun-kg)
35d01ee pre-commit (arjun-kg)
0ddcac7 cache changes (vwxyzjn)
2495c5f Merge branch 'throwaway' into gym_v0.26 (vwxyzjn)
1d64b5b pre-commit (vwxyzjn)
d4dcc60 fix indent (vwxyzjn)
4e8f8b8 remove pybullet (vwxyzjn)
59e727c Fix next observation (vwxyzjn)
2ae0be5 fix dqn script (vwxyzjn)
ba8983f update some docs (vwxyzjn)
ed68e76 Test API (vwxyzjn)
7e8f2db Merge branch 'master' of https://github.com/vwxyzjn/cleanrl into gym_… (arjun-kg)
4a05385 Merge branch 'master' of https://github.com/vwxyzjn/cleanrl into gym_… (arjun-kg)
f8271fe Support DM control and make backward compatibility (vwxyzjn)
ecffa00 Merge branch 'master' of https://github.com/arjun-kg/cleanrl into gym… (arjun-kg)
28fd178 Merge branch 'master' of https://github.com/vwxyzjn/cleanrl into gym_… (arjun-kg)
813192d Merge branch 'master' of https://github.com/arjun-kg/cleanrl into gym… (arjun-kg)
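For context, the commit messages above ("update done, reset", "Fix next observation") refer to the gym 0.26 / Gymnasium-style API migration this PR performs. A minimal sketch of the change in a training loop follows; it is illustrative only and not taken from this diff, and the env id, seed, and random-action loop are assumptions:

import gym

# Old (gym < 0.26) API being migrated away from:
#   obs = env.reset()
#   obs, reward, done, info = env.step(action)

# New (gym 0.26 / Gymnasium-style) API used throughout this PR:
env = gym.make("CartPole-v1")            # hypothetical env id for illustration
obs, info = env.reset(seed=0)            # reset now returns (obs, info)
done = False
while not done:
    action = env.action_space.sample()   # random policy, just to drive the loop
    obs, reward, terminated, truncated, info = env.step(action)  # 5-tuple step
    done = terminated or truncated       # episode ends on either signal
env.close()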
@@ -0,0 +1,254 @@
from typing import Dict, Tuple

import gym
import numpy as np
from gym import spaces

try:
    import cv2  # pytype:disable=import-error

    cv2.ocl.setUseOpenCL(False)
except ImportError:
    cv2 = None

from stable_baselines3.common.type_aliases import Gym26ResetReturn, Gym26StepReturn


class NoopResetEnv(gym.Wrapper):
    """
    Sample initial states by taking a random number of no-ops on reset.
    No-op is assumed to be action 0.

    :param env: the environment to wrap
    :param noop_max: the maximum number of no-ops to run
    """

    def __init__(self, env: gym.Env, noop_max: int = 30):
        gym.Wrapper.__init__(self, env)
        self.noop_max = noop_max
        self.override_num_noops = None
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == "NOOP"

    def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]:
        self.env.reset(**kwargs)
        if self.override_num_noops is not None:
            noops = self.override_num_noops
        else:
            noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
        assert noops > 0
        obs = np.zeros(0)
        info = {}
        for _ in range(noops):
            obs, _, done, truncated, info = self.env.step(self.noop_action)
            if done or truncated:
                obs, info = self.env.reset(**kwargs)
        return obs, info


class FireResetEnv(gym.Wrapper):
    """
    Take action on reset for environments that are fixed until firing.

    :param env: the environment to wrap
    """

    def __init__(self, env: gym.Env):
        gym.Wrapper.__init__(self, env)
        assert env.unwrapped.get_action_meanings()[1] == "FIRE"
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]:
        self.env.reset(**kwargs)
        obs, _, done, truncated, _ = self.env.step(1)
        if done or truncated:
            self.env.reset(**kwargs)
        obs, _, done, truncated, _ = self.env.step(2)
        if done or truncated:
            self.env.reset(**kwargs)
        return obs, {}


class EpisodicLifeEnv(gym.Wrapper):
    """
    Make end-of-life == end-of-episode, but only reset on true game over.
    Done by DeepMind for the DQN and co. since it helps value estimation.

    :param env: the environment to wrap
    """

    def __init__(self, env: gym.Env):
        gym.Wrapper.__init__(self, env)
        self.lives = 0
        self.was_real_done = True

    def step(self, action: int) -> Gym26StepReturn:
        obs, reward, done, truncated, info = self.env.step(action)
        self.was_real_done = done
        # check current lives, make loss of life terminal,
        # then update lives to handle bonus lives
        lives = self.env.unwrapped.ale.lives()
        if 0 < lives < self.lives:
            # for Qbert sometimes we stay in lives == 0 condition for a few frames,
            # so it's important to keep lives > 0, so that we only reset once
            # the environment advertises done.
            done = True
        self.lives = lives
        return obs, reward, done, truncated, info

    def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]:
        """
        Calls the Gym environment reset, only when lives are exhausted.
        This way all states are still reachable even though lives are episodic,
        and the learner need not know about any of this behind-the-scenes.

        :param kwargs: Extra keywords passed to env.reset() call
        :return: the first observation of the environment
        """
        if self.was_real_done:
            obs, info = self.env.reset(**kwargs)
        else:
            # no-op step to advance from terminal/lost life state
            obs, _, _, _, info = self.env.step(0)
        self.lives = self.env.unwrapped.ale.lives()
        return obs, info


class MaxAndSkipEnv(gym.Wrapper):
    """
    Return only every ``skip``-th frame (frameskipping)

    :param env: the environment
    :param skip: number of ``skip``-th frame
    """

    def __init__(self, env: gym.Env, skip: int = 4):
        gym.Wrapper.__init__(self, env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=env.observation_space.dtype)
        self._skip = skip

    def step(self, action: int) -> Gym26StepReturn:
        """
        Step the environment with the given action.
        Repeat action, sum reward, and max over last observations.

        :param action: the action
        :return: observation, reward, terminated, truncated, information
        """
        total_reward = 0.0
        terminated = truncated = False
        for i in range(self._skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            done = terminated or truncated
            if i == self._skip - 2:
                self._obs_buffer[0] = obs
            if i == self._skip - 1:
                self._obs_buffer[1] = obs
            total_reward += reward
            if done:
                break
        # Note that the observation on the done=True frame
        # doesn't matter
        max_frame = self._obs_buffer.max(axis=0)

        return max_frame, total_reward, terminated, truncated, info

    def reset(self, **kwargs) -> Gym26ResetReturn:
        return self.env.reset(**kwargs)


class ClipRewardEnv(gym.RewardWrapper):
    """
    Clips the reward to {+1, 0, -1} by its sign.

    :param env: the environment
    """

    def __init__(self, env: gym.Env):
        gym.RewardWrapper.__init__(self, env)

    def reward(self, reward: float) -> float:
        """
        Bin reward to {+1, 0, -1} by its sign.

        :param reward: the raw reward
        :return: the clipped reward
        """
        return np.sign(reward)


class WarpFrame(gym.ObservationWrapper):
    """
    Convert to grayscale and warp frames to 84x84 (default)
    as done in the Nature paper and later work.

    :param env: the environment
    :param width: the target frame width
    :param height: the target frame height
    """

    def __init__(self, env: gym.Env, width: int = 84, height: int = 84):
        gym.ObservationWrapper.__init__(self, env)
        self.width = width
        self.height = height
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(self.height, self.width, 1), dtype=env.observation_space.dtype
        )

    def observation(self, frame: np.ndarray) -> np.ndarray:
        """
        Returns the current observation from a frame.

        :param frame: environment frame
        :return: the observation
        """
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return frame[:, :, None]


class AtariWrapper(gym.Wrapper):
    """
    Atari 2600 preprocessings

    Specifically:

    * NoopReset: obtain initial state by taking a random number of no-ops on reset.
    * Frame skipping: 4 by default
    * Max-pooling: most recent two observations
    * Termination signal when a life is lost.
    * Resize to a square image: 84x84 by default
    * Grayscale observation
    * Clip reward to {-1, 0, 1}

    :param env: gym environment
    :param noop_max: max number of no-ops
    :param frame_skip: the frequency at which the agent experiences the game.
    :param screen_size: resize Atari frame
    :param terminal_on_life_loss: if True, then step() returns done=True whenever a life is lost.
    :param clip_reward: if True (default), the reward is clipped to {-1, 0, 1} depending on its sign.
    """

    def __init__(
        self,
        env: gym.Env,
        noop_max: int = 30,
        frame_skip: int = 4,
        screen_size: int = 84,
        terminal_on_life_loss: bool = True,
        clip_reward: bool = True,
    ):
        if noop_max > 0:
            env = NoopResetEnv(env, noop_max=noop_max)
        if frame_skip > 0:
            env = MaxAndSkipEnv(env, skip=frame_skip)
        if terminal_on_life_loss:
            env = EpisodicLifeEnv(env)
        if "FIRE" in env.unwrapped.get_action_meanings():
            env = FireResetEnv(env)
        env = WarpFrame(env, width=screen_size, height=screen_size)
        if clip_reward:
            env = ClipRewardEnv(env)

        super().__init__(env)
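As a rough usage sketch of the wrapper stack defined above (not part of this diff): the env id "BreakoutNoFrameskip-v4" assumes ale-py and the Atari ROMs are installed, and the random-action loop is purely illustrative.

import gym

env = gym.make("BreakoutNoFrameskip-v4")   # hypothetical ALE env id for illustration
env = AtariWrapper(env)                     # no-ops, frame skip + max-pool, life-loss terminal, 84x84 grayscale, reward clipping

obs, info = env.reset(seed=0)               # new-style reset: (obs, info)
terminated = truncated = False
while not (terminated or truncated):
    action = env.action_space.sample()      # random policy, just to exercise the wrappers
    obs, reward, terminated, truncated, info = env.step(action)  # new-style 5-tuple step
env.close()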
Perfect!