-
Notifications
You must be signed in to change notification settings - Fork 148
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Create Gymnasium and GymnasiumAtari environments - Adapt some mushroom_rl Gym examples to Gymnasium - Add headless argument to ImageViewer
- Loading branch information
1 parent
e92d3d5
commit cc0dfe3
Showing
12 changed files
with
361 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
from copy import deepcopy | ||
from collections import deque | ||
|
||
import gymnasium as gym | ||
|
||
from mushroom_rl.core import Environment, MDPInfo | ||
from mushroom_rl.rl_utils.spaces import * | ||
from mushroom_rl.utils.frames import LazyFrames, preprocess_frame | ||
from mushroom_rl.utils.viewer import ImageViewer | ||
|
||
class MaxAndSkip(gym.Wrapper):
    """
    Wrapper that repeats every action for a fixed number of frames and
    returns a pooled version of the two frames kept in an internal buffer.

    """
    def __init__(self, env, skip, max_pooling=True):
        """
        Constructor.

        Args:
            env (gym.Env): the environment to wrap;
            skip (int): number of times each action is repeated;
            max_pooling (bool, True): whether to pool the two buffered frames
                with an element-wise maximum or with a mean.

        """
        super().__init__(env)
        # Slot 0 holds the second-to-last frame of a repetition, slot 1 the
        # last one.
        self._obs_buffer = np.zeros((2,) + env.observation_space.shape,
                                    dtype=np.uint8)
        self._skip = skip
        self._max_pooling = max_pooling

    def reset(self, **kwargs):
        """
        Reset the wrapped environment, forwarding any keyword argument.

        Returns:
            Whatever the wrapped environment ``reset`` returns.

        """
        return self.env.reset(**kwargs)

    def step(self, action):
        """
        Repeat ``action`` up to ``skip`` times, accumulating the reward.

        Args:
            action: the action forwarded to the wrapped environment.

        Returns:
            The pooled frame, the summed reward, the terminated flag, the
            truncated flag and the info dictionary of the last executed step.

        """
        total_reward = 0.
        for i in range(self._skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            if i == self._skip - 2:
                self._obs_buffer[0] = obs
            if i == self._skip - 1:
                self._obs_buffer[1] = obs
            total_reward += reward
            # Stop repeating as soon as the episode is over, for either
            # reason. In that case the buffer may still hold frames written
            # by a previous call.
            if terminated or truncated:
                break

        if self._max_pooling:
            frame = self._obs_buffer.max(axis=0)
        else:
            frame = self._obs_buffer.mean(axis=0)

        # Keep the five-element Gymnasium step signature, so that callers can
        # unpack the result exactly as they do for an unwrapped environment.
        return frame, total_reward, terminated, truncated, info
|
||
|
||
class GymnasiumAtari(Environment):
    """
    The Atari environment as presented in:
    "Human-level control through deep reinforcement learning". Mnih et al..
    2015.

    """
    def __init__(self, name, width=84, height=84, ends_at_life=False,
                 max_pooling=True, history_length=4, max_no_op_actions=30,
                 headless=False):
        """
        Constructor.

        Args:
            name (str): id name of the Atari game in Gym;
            width (int, 84): width of the screen;
            height (int, 84): height of the screen;
            ends_at_life (bool, False): whether the episode ends when a life is
                lost or not;
            max_pooling (bool, True): whether to do max-pooling or
                average-pooling of the last two frames when using NoFrameskip;
            history_length (int, 4): number of frames to form a state;
            max_no_op_actions (int, 30): maximum number of no-op action to
                execute at the beginning of an episode;
            headless (bool, False): If True, the rendering is forced to be
                headless.

        """
        # MDP creation
        if 'NoFrameskip' in name:
            # The history length is also used as the number of action
            # repetitions of the frame-skipping wrapper.
            self.env = MaxAndSkip(gym.make(name, render_mode='rgb_array'),
                                  history_length, max_pooling)
        else:
            self.env = gym.make(name, render_mode='rgb_array')

        # MDP parameters
        self._headless = headless
        self._img_size = (width, height)
        self._episode_ends_at_life = ends_at_life
        self._max_lives = self.env.unwrapped.ale.lives()
        self._lives = self._max_lives
        self._force_fire = None
        self._real_reset = True
        self._max_no_op_actions = max_no_op_actions
        self._history_length = history_length
        self._current_no_op = None

        assert self.env.unwrapped.get_action_meanings()[0] == 'NOOP'

        # MDP properties
        action_space = Discrete(self.env.action_space.n)
        observation_space = Box(
            low=0., high=255.,
            shape=(history_length, self._img_size[1], self._img_size[0]))
        horizon = 1e4  # instead of np.inf
        gamma = .99
        dt = 1/60
        mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, dt)

        # Viewer
        self._viewer = ImageViewer((self._img_size[1], self._img_size[0]), dt,
                                   headless=self._headless)

        super().__init__(mdp_info)

    def reset(self, state=None):
        """
        Start a new episode.

        The wrapped environment is reset only when the previous episode ended
        in the wrapped environment too; otherwise the current frame history is
        kept.

        Args:
            state (None): unused.

        Returns:
            The frame history wrapped in ``LazyFrames`` and the info
            dictionary returned by the wrapped environment, or an empty
            dictionary if the wrapped environment has not been reset.

        """
        # Without a real reset there is no info coming from the wrapped
        # environment: default to an empty dictionary so that the return
        # statement never refers to an unbound name.
        info = {}
        if self._real_reset:
            frame, info = self.env.reset()
            self._state = preprocess_frame(frame, self._img_size)
            self._state = deque(
                [deepcopy(self._state) for _ in range(self._history_length)],
                maxlen=self._history_length
            )
            self._lives = self._max_lives

        self._force_fire = \
            self.env.unwrapped.get_action_meanings()[1] == 'FIRE'

        self._current_no_op = np.random.randint(self._max_no_op_actions + 1)

        return LazyFrames(list(self._state), self._history_length), info

    def step(self, action):
        """
        Execute an action in the environment.

        Args:
            action: indexable container whose first element is the action to
                execute.

        Returns:
            The frame history wrapped in ``LazyFrames``, the reward, the
            absorbing flag and the info dictionary.

        """
        action = action[0]

        # Force FIRE action to start episodes in games with lives
        if self._force_fire:
            self.env.env.step(1)
            self._force_fire = False
        # Execute the no-op actions still pending from the last reset
        while self._current_no_op > 0:
            self.env.env.step(0)
            self._current_no_op -= 1

        obs, reward, absorbing, truncated, info = self.env.step(action)
        # A truncated episode is not absorbing, but the wrapped environment
        # must be reset before it can be used for a new episode.
        self._real_reset = absorbing or truncated

        if info['lives'] != self._lives:
            if self._episode_ends_at_life:
                # The flag for the real reset has already been set: the loss
                # of a life ends the episode without resetting the game.
                absorbing = True
            self._lives = info['lives']
            self._force_fire = \
                self.env.unwrapped.get_action_meanings()[1] == 'FIRE'

        self._state.append(preprocess_frame(obs, self._img_size))

        return (LazyFrames(list(self._state), self._history_length), reward,
                absorbing, info)

    def render(self, record=False):
        """
        Display the current frame in the viewer.

        Args:
            record (bool, False): whether to return the rendered image.

        Returns:
            The image rendered by the wrapped environment if ``record`` is
            True, None otherwise.

        """
        img = self.env.render()

        self._viewer.display(img)

        return img if record else None

    def stop(self):
        """
        Close the wrapped environment and the viewer, and force a real reset
        at the next call of ``reset``.

        """
        self.env.close()
        self._viewer.close()
        self._real_reset = True

    def set_episode_end(self, ends_at_life):
        """
        Setter.

        Args:
            ends_at_life (bool): whether the episode ends when a life is
                lost or not.

        """
        self._episode_ends_at_life = ends_at_life
Oops, something went wrong.