Added gymnasium and gym 0.26 support #267

Open · wants to merge 3 commits into base: master
2 changes: 1 addition & 1 deletion notebooks/train_and_export_onnx_example_continuous.ipynb
@@ -42,7 +42,7 @@
"import yaml\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"import gym\n",
"import gymnasium as gym\n",
"from IPython import display\n",
"import numpy as np\n",
"import onnx\n",
2 changes: 1 addition & 1 deletion notebooks/train_and_export_onnx_example_discrete.ipynb
@@ -41,7 +41,7 @@
"import yaml\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"import gym\n",
"import gymnasium as gym\n",
"from IPython import display\n",
"import numpy as np\n",
"import onnx\n",
(additional notebook, filename not captured in this view)
@@ -46,7 +46,7 @@
"import yaml\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"import gym\n",
"import gymnasium as gym\n",
"from IPython import display\n",
"import numpy as np\n",
"import onnx\n",
2 changes: 1 addition & 1 deletion rl_games/algos_torch/players.py
@@ -2,7 +2,7 @@
from rl_games.algos_torch import torch_ext
from rl_games.algos_torch.running_mean_std import RunningMeanStd
from rl_games.common.tr_helpers import unsqueeze_obs
import gym
import gymnasium as gym
import torch
from torch import nn
import numpy as np
4 changes: 2 additions & 2 deletions rl_games/algos_torch/sac_agent.py
@@ -4,7 +4,7 @@
from rl_games.common import schedulers
from rl_games.common import experience
from rl_games.common.a2c_common import print_statistics

from rl_games.common.env_configurations import patch_env_info
from rl_games.interfaces.base_algorithm import BaseAlgorithm
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
@@ -103,7 +103,7 @@ def base_init(self, base_name, config):
self.env_info = config.get('env_info')
if self.env_info is None:
self.vec_env = vecenv.create_vec_env(self.env_name, self.num_actors, **self.env_config)
self.env_info = self.vec_env.get_env_info()
self.env_info = patch_env_info(self.vec_env.get_env_info())

self._device = config.get('device', 'cuda:0')

5 changes: 3 additions & 2 deletions rl_games/common/a2c_common.py
@@ -12,9 +12,10 @@
from rl_games.common.diagnostics import DefaultDiagnostics, PpoDiagnostics
from rl_games.algos_torch import model_builder
from rl_games.interfaces.base_algorithm import BaseAlgorithm
from rl_games.common.env_configurations import patch_env_info
import numpy as np
import time
import gym
import gymnasium as gym

from datetime import datetime
from tensorboardX import SummaryWriter
@@ -127,7 +128,7 @@ def __init__(self, base_name, params):
self.env_info = config.get('env_info')
if self.env_info is None:
self.vec_env = vecenv.create_vec_env(self.env_name, self.num_actors, **self.env_config)
self.env_info = self.vec_env.get_env_info()
self.env_info = patch_env_info(self.vec_env.get_env_info())
else:
self.vec_env = config.get('vec_env', None)

26 changes: 17 additions & 9 deletions rl_games/common/env_configurations.py
@@ -1,10 +1,11 @@
import rl_games.envs.test
from rl_games.common import wrappers
from rl_games.common import tr_helpers
from rl_games.common import gymnasium_utils
from rl_games.envs.brax import create_brax_env
from rl_games.envs.envpool import create_envpool
from rl_games.envs.cule import create_cule
import gym
import gymnasium as gym
from gym.wrappers import FlattenObservation, FilterObservation
import numpy as np
import math
@@ -108,10 +109,10 @@ def create_dm_control_env(**kwargs):
return env

def create_super_mario_env(name='SuperMarioBros-v1'):
import gym
import gymnasium as gym
from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT
import gym_super_mario_bros
env = gym_super_mario_bros.make(name)
env = JoypadSpace(env, SIMPLE_MOVEMENT)

@@ -120,11 +121,11 @@ def create_super_mario_env(name='SuperMarioBros-v1'):
return env

def create_super_mario_env_stage1(name='SuperMarioBrosRandomStage1-v1'):
import gym
import gymnasium as gym
from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT

import gym_super_mario_bros
stage_names = [
'SuperMarioBros-1-1-v1',
'SuperMarioBros-1-2-v1',
@@ -142,13 +143,13 @@ def create_super_mario_env_stage1(name='SuperMarioBrosRandomStage1-v1'):
return env

def create_quadrupped_env():
import gym
import gymnasium as gym
import roboschool
import quadruppedEnv
return wrappers.FrameStack(wrappers.MaxAndSkipEnv(gym.make('QuadruppedWalk-v1'), 4, False), 2, True)

def create_roboschool_env(name):
import gym
import gymnasium as gym
import roboschool
return gym.make(name)

@@ -441,15 +442,22 @@ def get_env_info(env):
if hasattr(env, "value_size"):
result_shapes['value_size'] = env.value_size
print(result_shapes)
return result_shapes
return patch_env_info(result_shapes)

def get_obs_and_action_spaces_from_config(config):
env_config = config.get('env_config', {})
env = configurations[config['env_name']]['env_creator'](**env_config)
result_shapes = get_env_info(env)
env.close()
return result_shapes
return patch_env_info(result_shapes)


def patch_env_info(env_info):
env_info['observation_space'] = gymnasium_utils.convert_space(env_info['observation_space'])
env_info['action_space'] = gymnasium_utils.convert_space(env_info['action_space'])
if 'state_space' in env_info:
env_info['state_space'] = gymnasium_utils.convert_space(env_info['state_space'])
return env_info

def register(name, config):
configurations[name] = config
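
As a quick illustration (not part of this PR) of what the new patch_env_info helper does: given an env_info dict that still carries legacy gym spaces, it returns the same dict with gymnasium spaces. The sketch below assumes legacy gym and shimmy are installed; the Box/Discrete spaces are made up for the example:

import numpy as np
import gym  # legacy OpenAI Gym, used here only to build example spaces
from rl_games.common.env_configurations import patch_env_info

# env_info as a vec env might report it, still holding legacy gym spaces
env_info = {
    'observation_space': gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32),
    'action_space': gym.spaces.Discrete(4),
}

patched = patch_env_info(env_info)

# both entries are now gymnasium spaces, so isinstance checks against
# gymnasium.Space inside rl_games succeed
print(type(patched['observation_space']), type(patched['action_space']))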
2 changes: 1 addition & 1 deletion rl_games/common/experience.py
@@ -1,6 +1,6 @@
import numpy as np
import random
import gym
import gymnasium as gym
import torch
from rl_games.common.segment_tree import SumSegmentTree, MinSegmentTree
import torch
131 changes: 131 additions & 0 deletions rl_games/common/gymnasium_utils.py
@@ -0,0 +1,131 @@
"""
Compatibility layer for Gym -> Gymnasium transition.
Adapted from Stable Baselines3, Tianshou, and Shimmy (https://github.com/Farama-Foundation/Shimmy).
Thanks to @alex-petrenko
"""

import warnings
from inspect import signature
from typing import Union

import gymnasium

try:
import gym # pytype: disable=import-error

gym_installed = True
except ImportError:
gym_installed = False


def patch_non_gymnasium_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env:
env = _patch_env(env)

try:
# patching spaces
if not isinstance(env.observation_space, gymnasium.Space):
env.observation_space = convert_space(env.observation_space)
if not isinstance(env.action_space, gymnasium.Space):
env.action_space = convert_space(env.action_space)
except AttributeError:
# gym.Env may not define observation_space / action_space, or they may be read-only properties;
# in that case we leave the spaces untouched
warnings.warn("Could not patch spaces for the environment. Consider switching to the Gymnasium API.")

return env


def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env:
"""
Adapted from https://github.com/thu-ml/tianshou.

Takes an environment and patches it to return Gymnasium env.
This function takes the environment object and returns a patched
env, using shimmy wrapper to convert it to Gymnasium,
if necessary.

:param env: A gym/gymnasium env
:return: Patched env (gymnasium env)
"""

# Gymnasium env, no patching to be done
if isinstance(env, gymnasium.Env):
return env

if not gym_installed or not isinstance(env, gym.Env):
raise ValueError(
f"The environment is of type {type(env)}, not a Gymnasium "
f"environment. In this case, we expect OpenAI Gym to be "
f"installed and the environment to be an OpenAI Gym environment."
)

try:
import shimmy
except ImportError as e:
raise ImportError(
"Missing shimmy installation. You are using an OpenAI Gym environment. "
"Sample Factory has transitioned to using Gymnasium internally. "
"In order to use OpenAI Gym environments with SF, you need to "
"install shimmy (`pip install 'shimmy>=0.2.1'`)."
) from e

warnings.warn(
"You provided an OpenAI Gym environment. "
"We strongly recommend transitioning to Gymnasium environments. "
"Sample Factory is automatically wrapping your environments in a compatibility "
"layer, which could potentially cause issues."
)

if "seed" in signature(env.unwrapped.reset).parameters:
# Gym 0.26+ env
gymnasium_env = shimmy.GymV26CompatibilityV0(env=env)
else:
# Gym 0.21 env
gymnasium_env = shimmy.GymV21CompatibilityV0(env=env)

# preserving potential multi-agent env attributes
if hasattr(env, "num_agents"):
gymnasium_env.num_agents = env.num_agents
if hasattr(env, "is_multiagent"):
gymnasium_env.is_multiagent = env.is_multiagent

return gymnasium_env


def convert_space(space: Union["gym.Space", gymnasium.Space]) -> gymnasium.Space: # pragma: no cover
"""
Takes a space and patches it to return Gymnasium Space.
This function takes the space object and returns a patched
space, using shimmy wrapper to convert it to Gymnasium,
if necessary.

:param space: A gym/gymnasium Space
:return: Patched space (gymnasium Space)
"""
if space is None:
return None
# Gymnasium space, no conversion to be done
if isinstance(space, gymnasium.Space):
return space

if not gym_installed or not isinstance(space, gym.Space):
raise ValueError(
f"The space is of type {type(space)}, not a Gymnasium "
f"space. In this case, we expect OpenAI Gym to be "
f"installed and the space to be an OpenAI Gym space."
)

try:
import shimmy # pytype: disable=import-error
except ImportError as e:
raise ImportError(
"Missing shimmy installation. You provided an OpenAI Gym space. "
"Sample Factory has transitioned to using Gymnasium internally. "
"In order to use OpenAI Gym space with Sample Factory, you need to "
"install shimmy (`pip install 'shimmy>=0.2.1'`)."
) from e

return shimmy.openai_gym_compatibility._convert_space(space)
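
To make the intended use of this compatibility module concrete, here is an illustrative sketch (not part of the PR). It assumes legacy gym and shimmy are installed and uses CartPole-v1 purely as a stand-in legacy environment:

import gym  # legacy OpenAI Gym
import gymnasium
from rl_games.common.gymnasium_utils import patch_non_gymnasium_env, convert_space

# wrap a legacy Gym env so the rest of rl_games can treat it as a gymnasium.Env
legacy_env = gym.make('CartPole-v1')
env = patch_non_gymnasium_env(legacy_env)
assert isinstance(env, gymnasium.Env)

# the patched env follows the Gymnasium API: reset returns (obs, info),
# step returns the 5-tuple with separate terminated/truncated flags
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())

# spaces can also be converted on their own
space = convert_space(gym.spaces.Discrete(3))
assert isinstance(space, gymnasium.spaces.Discrete)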



8 changes: 4 additions & 4 deletions rl_games/common/player.py
@@ -2,7 +2,7 @@
import shutil
import threading
import time
import gym
import gymnasium as gym
import numpy as np
import torch
import copy
@@ -11,7 +11,7 @@
from rl_games.common import vecenv
from rl_games.common import env_configurations
from rl_games.algos_torch import model_builder

from rl_games.common.env_configurations import patch_env_info

class BasePlayer(object):

@@ -32,11 +32,11 @@ def __init__(self, params):
print('[BasePlayer] Creating vecenv: ', self.env_name)
self.env = vecenv.create_vec_env(
self.env_name, self.config['num_actors'], **self.env_config)
self.env_info = self.env.get_env_info()
self.env_info = patch_env_info(self.env.get_env_info())
else:
print('[BasePlayer] Creating regular env: ', self.env_name)
self.env = self.create_env()
self.env_info = env_configurations.get_env_info(self.env)
self.env_info = patch_env_info(env_configurations.get_env_info(self.env))
else:
self.env = config.get('vec_env')
