Integrate Gymnasium
- Create Gymnasium and GymnasiumAtari environments
- Adapt some mushroom_rl Gym examples to Gymnasium
- Add headless argument to ImageViewer
AhmedMagdyHendawy committed Feb 5, 2024
1 parent e92d3d5 commit cc0dfe3
Showing 12 changed files with 361 additions and 17 deletions.
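
For reference, the adapted examples replace the Gym wrapper with the new Gymnasium one. A minimal sketch of the new interface, based only on the calls visible in this diff (the environment id and hyperparameters are illustrative):

from mushroom_rl.environments import Gymnasium

# Illustrative values; each example below picks its own horizon and gamma.
horizon = 200
gamma = 0.99
mdp = Gymnasium('Pendulum-v1', horizon, gamma, headless=True)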
4 changes: 2 additions & 2 deletions examples/acrobot_a2c.py
@@ -4,7 +4,7 @@

from mushroom_rl.algorithms.actor_critic import A2C
from mushroom_rl.core import Core, Logger
-from mushroom_rl.environments import Gym
+from mushroom_rl.environments import Gymnasium
from mushroom_rl.policy import BoltzmannTorchPolicy
from mushroom_rl.approximators.parametric.torch_approximator import *
from mushroom_rl.rl_utils.parameters import Parameter
@@ -47,7 +47,7 @@ def experiment(n_epochs, n_steps, n_steps_per_fit, n_step_test):
# MDP
horizon = 1000
gamma = 0.99
-mdp = Gym('Acrobot-v1', horizon, gamma)
+mdp = Gymnasium('Acrobot-v1', horizon, gamma, headless=False)

# Policy
policy_params = dict(
2 changes: 1 addition & 1 deletion examples/acrobot_dqn.py
@@ -54,7 +54,7 @@ def experiment(n_epochs, n_steps, n_steps_test):
# MDP
horizon = 1000
gamma = 0.99
-mdp = Gym('Acrobot-v1', horizon, gamma)
+mdp = Gymnasium('Acrobot-v1', horizon, gamma, headless=False)

# Policy
epsilon = LinearParameter(value=1., threshold_value=.01, n=5000)
6 changes: 3 additions & 3 deletions examples/atari_dqn.py
@@ -276,9 +276,9 @@ def experiment():
max_steps = args.max_steps

# MDP
-mdp = Atari(args.name, args.screen_width, args.screen_height,
+mdp = GymnasiumAtari(args.name, args.screen_width, args.screen_height,
ends_at_life=True, history_length=args.history_length,
-max_no_op_actions=args.max_no_op_actions)
+max_no_op_actions=args.max_no_op_actions, headless=False)

if args.load_path:
logger = Logger(DQN.__name__, results_dir=None)
@@ -408,7 +408,7 @@ def experiment():
pi.set_epsilon(epsilon_test)
mdp.set_episode_end(False)
dataset = core.evaluate(n_steps=test_samples, render=args.render,
-quiet=args.quiet)
+quiet=args.quiet, record=True)
scores.append(get_stats(dataset, logger))

np.save(folder_name + '/scores.npy', scores)
4 changes: 2 additions & 2 deletions examples/mountain_car_sarsa.py
@@ -3,7 +3,7 @@

from mushroom_rl.algorithms.value import TrueOnlineSARSALambda
from mushroom_rl.core import Core, Logger
-from mushroom_rl.environments import Gym
+from mushroom_rl.environments import Gymnasium
from mushroom_rl.features import Features
from mushroom_rl.features.tiles import Tiles
from mushroom_rl.policy import EpsGreedy
@@ -21,7 +21,7 @@ def experiment(alpha):
np.random.seed()

# MDP
-mdp = Gym(name='MountainCar-v0', horizon=np.inf, gamma=1.)
+mdp = Gymnasium(name='MountainCar-v0', horizon=int(1e4), gamma=1., headless=False)

# Policy
epsilon = Parameter(value=0.)
4 changes: 2 additions & 2 deletions examples/pendulum_a2c.py
@@ -7,7 +7,7 @@
from tqdm import trange

from mushroom_rl.core import Core, Logger
-from mushroom_rl.environments import Gym
+from mushroom_rl.environments import Gymnasium
from mushroom_rl.algorithms.actor_critic import A2C

from mushroom_rl.policy import GaussianTorchPolicy
@@ -45,7 +45,7 @@ def experiment(alg, env_id, horizon, gamma, n_epochs, n_steps, n_steps_per_fit,
logger.strong_line()
logger.info('Experiment Algorithm: ' + A2C.__name__)

-mdp = Gym(env_id, horizon, gamma)
+mdp = Gymnasium(env_id, horizon, gamma, headless=False)

critic_params = dict(network=Network,
optimizer={'class': optim.RMSprop,
4 changes: 2 additions & 2 deletions examples/pendulum_ddpg.py
@@ -7,7 +7,7 @@

from mushroom_rl.algorithms.actor_critic import DDPG, TD3
from mushroom_rl.core import Core, Logger
-from mushroom_rl.environments.gym_env import Gym
+from mushroom_rl.environments import Gymnasium
from mushroom_rl.policy import OrnsteinUhlenbeckPolicy

from tqdm import trange
@@ -76,7 +76,7 @@ def experiment(alg, n_epochs, n_steps, n_steps_test):
# MDP
horizon = 200
gamma = 0.99
-mdp = Gym('Pendulum-v1', horizon, gamma)
+mdp = Gymnasium('Pendulum-v1', horizon, gamma, headless=False)

# Policy
policy_class = OrnsteinUhlenbeckPolicy
4 changes: 2 additions & 2 deletions examples/pendulum_sac.py
@@ -7,7 +7,7 @@

from mushroom_rl.algorithms.actor_critic import SAC
from mushroom_rl.core import Core, Logger
-from mushroom_rl.environments.gym_env import Gym
+from mushroom_rl.environments import Gymnasium
from mushroom_rl.utils import TorchUtils

from tqdm import trange
@@ -76,7 +76,7 @@ def experiment(alg, n_epochs, n_steps, n_steps_test, save, load):
# MDP
horizon = 200
gamma = 0.99
-mdp = Gym('Pendulum-v1', horizon, gamma)
+mdp = Gymnasium('Pendulum-v1', horizon, gamma, headless=False)

# Settings
initial_replay_size = 64
4 changes: 2 additions & 2 deletions examples/pendulum_trust_region.py
@@ -7,7 +7,7 @@
from tqdm import trange

from mushroom_rl.core import Core, Logger
-from mushroom_rl.environments import Gym
+from mushroom_rl.environments import Gymnasium
from mushroom_rl.algorithms.actor_critic import PPO, TRPO

from mushroom_rl.policy import GaussianTorchPolicy
@@ -43,7 +43,7 @@ def experiment(alg, env_id, horizon, gamma, n_epochs, n_steps, n_steps_per_fit,
logger.strong_line()
logger.info('Experiment Algorithm: ' + alg.__name__)

-mdp = Gym(env_id, horizon, gamma)
+mdp = Gymnasium(env_id, horizon, gamma, headless=False)

critic_params = dict(network=Network,
optimizer={'class': optim.Adam,
14 changes: 14 additions & 0 deletions mushroom_rl/environments/__init__.py
@@ -12,6 +12,20 @@
except ImportError:
pass

+try:
+GymnasiumAtari = None
+from .gymnasium_atari import GymnasiumAtari
+GymnasiumAtari.register()
+except ImportError:
+pass
+
+try:
+Gymnasium = None
+from .gymnasium_env import Gymnasium
+Gymnasium.register()
+except ImportError:
+pass

try:
DMControl = None
from .dm_control_env import DMControl
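Assuming the register() calls follow the same pattern as the other wrappers in this module, the new environments should also be reachable through the generic Environment factory. A hypothetical sketch (the 'Gymnasium.<env-id>' naming is an assumption, not confirmed by this diff):

from mushroom_rl.core import Environment

# Hypothetical: build the registered wrapper by name; extra kwargs go to the constructor.
mdp = Environment.make('Gymnasium.CartPole-v1', horizon=500, gamma=0.99)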
168 changes: 168 additions & 0 deletions mushroom_rl/environments/gymnasium_atari.py
@@ -0,0 +1,168 @@
from copy import deepcopy
from collections import deque

import numpy as np
import gymnasium as gym

from mushroom_rl.core import Environment, MDPInfo
from mushroom_rl.rl_utils.spaces import *
from mushroom_rl.utils.frames import LazyFrames, preprocess_frame
from mushroom_rl.utils.viewer import ImageViewer

class MaxAndSkip(gym.Wrapper):
def __init__(self, env, skip, max_pooling=True):
gym.Wrapper.__init__(self, env)
self._obs_buffer = np.zeros((2,) + env.observation_space.shape,
dtype=np.uint8)
self._skip = skip
self._max_pooling = max_pooling

def reset(self, **kwargs):
return self.env.reset(**kwargs)

def step(self, action):
total_reward = 0.
truncated = False
for i in range(self._skip):
obs, reward, absorbing, truncated, info = self.env.step(action)
if i == self._skip - 2:
self._obs_buffer[0] = obs
if i == self._skip - 1:
self._obs_buffer[1] = obs
total_reward += reward
if absorbing or truncated:
break
if self._max_pooling:
frame = self._obs_buffer.max(axis=0)
else:
frame = self._obs_buffer.mean(axis=0)

# Return the five-element tuple of a Gymnasium-style step, so that
# GymnasiumAtari.step can unpack it the same way as an unwrapped env.
return frame, total_reward, absorbing, truncated, info


class GymnasiumAtari(Environment):
"""
The Atari environment as presented in:
"Human-level control through deep reinforcement learning". Mnih et. al..
2015.
"""
def __init__(self, name, width=84, height=84, ends_at_life=False,
max_pooling=True, history_length=4, max_no_op_actions=30, headless=False):
"""
Constructor.
Args:
name (str): id name of the Atari game in Gymnasium;
width (int, 84): width of the screen;
height (int, 84): height of the screen;
ends_at_life (bool, False): whether the episode ends when a life is
lost or not;
max_pooling (bool, True): whether to do max-pooling or
average-pooling of the last two frames when using NoFrameskip;
history_length (int, 4): number of frames to form a state;
max_no_op_actions (int, 30): maximum number of no-op actions to
execute at the beginning of an episode;
headless (bool, False): whether to force headless rendering.
"""
# MDP creation
if 'NoFrameskip' in name:
self.env = MaxAndSkip(gym.make(name, render_mode='rgb_array'), history_length, max_pooling)
else:
self.env = gym.make(name, render_mode='rgb_array')

# MDP parameters
self._headless = headless
self._img_size = (width, height)
self._episode_ends_at_life = ends_at_life
self._max_lives = self.env.unwrapped.ale.lives()
self._lives = self._max_lives
self._force_fire = None
self._real_reset = True
self._max_no_op_actions = max_no_op_actions
self._history_length = history_length
self._current_no_op = None

assert self.env.unwrapped.get_action_meanings()[0] == 'NOOP'

# MDP properties
action_space = Discrete(self.env.action_space.n)
observation_space = Box(
low=0., high=255., shape=(history_length, self._img_size[1], self._img_size[0]))
horizon = 1e4 # instead of np.inf
gamma = .99
dt = 1/60
mdp_info = MDPInfo(observation_space, action_space, gamma, horizon, dt)

# Viewer
self._viewer = ImageViewer((self._img_size[1], self._img_size[0]), dt, headless=self._headless)

super().__init__(mdp_info)

def reset(self, state=None):
info = {}  # ensure `info` is defined even when no real environment reset happens
if self._real_reset:
state, info = self.env.reset()
self._state = preprocess_frame(state, self._img_size)
self._state = deque([deepcopy(
self._state) for _ in range(self._history_length)],
maxlen=self._history_length
)
self._lives = self._max_lives

self._force_fire = self.env.unwrapped.get_action_meanings()[1] == 'FIRE'

self._current_no_op = np.random.randint(self._max_no_op_actions + 1)

return LazyFrames(list(self._state), self._history_length), info

def step(self, action):
action = action[0]

# Force FIRE action to start episodes in games with lives
if self._force_fire:
obs, _, _, _, _ = self.env.env.step(1)
self._force_fire = False
while self._current_no_op > 0:
obs, _, _, _, _ = self.env.env.step(0)
self._current_no_op -= 1

obs, reward, absorbing, _, info = self.env.step(action)
self._real_reset = absorbing

if info['lives'] != self._lives:
if self._episode_ends_at_life:
absorbing = True
self._lives = info['lives']
self._force_fire = self.env.unwrapped.get_action_meanings()[1] == 'FIRE'

self._state.append(preprocess_frame(obs, self._img_size))

return LazyFrames(list(self._state), self._history_length), reward, absorbing, info

def render(self, record=False):
img = self.env.render()

self._viewer.display(img)

if record:
return img
else:
return None

def stop(self):
self.env.close()
self._viewer.close()
self._real_reset = True

def set_episode_end(self, ends_at_life):
"""
Setter.
Args:
ends_at_life (bool): whether the episode ends when a life is
lost or not.
"""
self._episode_ends_at_life = ends_at_life
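
A minimal usage sketch of the new class (the game id is illustrative and assumes the ALE environments are available through Gymnasium; note that step() indexes action[0], so the action is passed wrapped in a list):

from mushroom_rl.environments import GymnasiumAtari

mdp = GymnasiumAtari('BreakoutNoFrameskip-v4', width=84, height=84,
                     history_length=4, max_no_op_actions=30, headless=True)
state, info = mdp.reset()
state, reward, absorbing, info = mdp.step([0])  # NOOP, wrapped as step() expects
mdp.stop()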