Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gym-Gymnasium compatibility converter #61

Merged
merged 16 commits into from
Oct 20, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions gymnasium/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,11 @@
entry_point="gymnasium.envs.mujoco.humanoidstandup_v4:HumanoidStandupEnv",
max_episode_steps=1000,
)


# Gym conversion
# ----------------------------------------
register(
id="GymV26Environment-v0",
entry_point="gymnasium.envs.external.gym_env:GymEnvironment",
)
Empty file.
212 changes: 212 additions & 0 deletions gymnasium/envs/external/gym_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
from typing import Optional, Tuple

import gymnasium
from gymnasium import error
from gymnasium.core import ActType, ObsType

try:
import gym
except ImportError as e:
# gym = None
GYM_IMPORT_ERROR = e
else:
GYM_IMPORT_ERROR = None


class GymEnvironment(gymnasium.Env):
"""
Converts a gym environment to a gymnasium environment.
"""

def __init__(
self,
env_id: Optional[str] = None,
make_kwargs: Optional[dict] = None,
env: Optional["gym.Env"] = None,
):
if GYM_IMPORT_ERROR is not None:
raise error.DependencyNotInstalled(
f"{GYM_IMPORT_ERROR} (Hint: You need to install gym with `pip install gym` to use gym environments"
)

if make_kwargs is None:
make_kwargs = {}

if env is not None:
self.gym_env = env
elif env_id is not None:
self.gym_env = gym.make(env_id, **make_kwargs)
else:
raise gymnasium.error.MissingArgument(
"Either env_id or env must be provided to create a legacy gym environment."
)
self.gym_env = _strip_default_wrappers(gym.make(env_id, **make_kwargs))

self.observation_space = _convert_space(self.gym_env.observation_space)
self.action_space = _convert_space(self.gym_env.action_space)

self.metadata = getattr(self.gym_env, "metadata", {"render_modes": []})
self.render_mode = self.gym_env.render_mode
self.reward_range = getattr(self.gym_env, "reward_range", None)
self.spec = getattr(self.gym_env, "spec", None)

def reset(
self, seed: Optional[int] = None, options: Optional[dict] = None
) -> Tuple[ObsType, dict]:
"""Resets the environment.

Args:
seed: the seed to reset the environment with
options: the options to reset the environment with

Returns:
(observation, info)
"""
super().reset(seed=None) # We don't need the seed inside gymnasium
# Options are ignored
return self.gym_env.reset()
RedTachyon marked this conversation as resolved.
Show resolved Hide resolved

def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
"""Steps through the environment.

Args:
action: action to step through the environment with

Returns:
(observation, reward, terminated, truncated, info)
"""
return self.gym_env.step(action)

def render(self):
"""Renders the environment.

Returns:
The rendering of the environment, depending on the render mode
"""
return self.gym_env.render()

def close(self):
"""Closes the environment."""
self.gym_env.close()

def __str__(self):
return f"GymEnvironment({self.gym_env})"

def __repr__(self):
return f"GymEnvironment({self.gym_env})"


def _strip_default_wrappers(env: "gym.Env") -> "gym.Env":
"""Strips builtin wrappers from the environment.

Args:
env: the environment to strip builtin wrappers from

Returns:
The environment without builtin wrappers
"""
import gym.wrappers.compatibility # Cheat because gym doesn't expose it in __init__
pseudo-rnd-thoughts marked this conversation as resolved.
Show resolved Hide resolved
import gym.wrappers.env_checker # Cheat because gym doesn't expose it in __init__

default_wrappers = (
gym.wrappers.RenderCollection,
gym.wrappers.HumanRendering,
)
while isinstance(env, default_wrappers):
env = env.env
return env


def _convert_space(space: "gym.Space") -> gymnasium.Space:
pseudo-rnd-thoughts marked this conversation as resolved.
Show resolved Hide resolved
"""Converts a gym space to a gymnasium space.

Args:
space: the space to convert

Returns:
The converted space
"""
if isinstance(space, gym.spaces.Discrete):
return gymnasium.spaces.Discrete(n=space.n)
elif isinstance(space, gym.spaces.Box):
return gymnasium.spaces.Box(
low=space.low, high=space.high, shape=space.shape, dtype=space.dtype
)
elif isinstance(space, gym.spaces.MultiDiscrete):
return gymnasium.spaces.MultiDiscrete(nvec=space.nvec)
elif isinstance(space, gym.spaces.MultiBinary):
return gymnasium.spaces.MultiBinary(n=space.n)
elif isinstance(space, gym.spaces.Tuple):
return gymnasium.spaces.Tuple(spaces=tuple(map(_convert_space, space.spaces)))
elif isinstance(space, gym.spaces.Dict):
return gymnasium.spaces.Dict(
spaces={k: _convert_space(v) for k, v in space.spaces.items()}
)
elif isinstance(space, gym.spaces.Sequence):
return gymnasium.spaces.Sequence(space=_convert_space(space.feature_space))
elif isinstance(space, gym.spaces.Graph):
return gymnasium.spaces.Graph(
node_space=_convert_space(space.node_space), # type: ignore
edge_space=_convert_space(space.edge_space), # type: ignore
)
elif isinstance(space, gym.spaces.Text):
return gymnasium.spaces.Text(
max_length=space.max_length,
min_length=space.min_length,
charset=space._char_str,
)
else:
raise NotImplementedError(
f"Cannot convert space of type {space}. Please upgrade your code to gymnasium."
)


# @singledispatch
# def _convert_space(space: "gym.Space") -> gymnasium.Space:
# """Blah"""
# raise NotImplementedError(
# f"Cannot convert space of type {type(space)}. Please upgrade your code to gymnasium."
# )
#
#
# @_convert_space.register
# def _(space: "gym.spaces.Discrete") -> gymnasium.spaces.Discrete:
# return gymnasium.spaces.Discrete(space.n)
#
#
# @_convert_space.register
# def _(space: "gym.spaces.Box") -> gymnasium.spaces.Box:
# return gymnasium.spaces.Box(space.low, space.high, space.shape)
#
#
# @_convert_space.register
# def _(space: "gym.spaces.Tuple") -> gymnasium.spaces.Tuple:
# return gymnasium.spaces.Tuple(_convert_space(s) for s in space.spaces)
#
#
# @_convert_space.register
# def _(space: "gym.spaces.Dict") -> gymnasium.spaces.Dict:
# return gymnasium.spaces.Dict(
# {k: _convert_space(s) for k, s in space.spaces.items()}
# )
#
#
# @_convert_space.register
# def _(space: "gym.spaces.MultiDiscrete") -> gymnasium.spaces.MultiDiscrete:
# return gymnasium.spaces.MultiDiscrete(space.nvec)
#
#
# @_convert_space.register
# def _(space: "gym.spaces.MultiBinary") -> gymnasium.spaces.MultiBinary:
# return gymnasium.spaces.MultiBinary(space.n)
#
#
# @_convert_space.register
# def _(space: "gym.spaces.Sequence") -> gymnasium.spaces.Sequence:
# return gymnasium.spaces.Sequence(_convert_space(space.feature_space))
#
#
# @_convert_space.register
# def _(space: "gym.spaces.Graph") -> gymnasium.spaces.Graph:
# # Pycharm is throwing up a type warning, but as long as the base space is correct, this is valid
# return gymnasium.spaces.Graph(_convert_space(space.node_space), _convert_space(space.edge_space)) # type: ignore
4 changes: 4 additions & 0 deletions gymnasium/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ class InvalidAction(Error):
"""Raised when the user performs an action not contained within the action space."""


class MissingArgument(Error):
"""Raised when a required argument in the initializer is missing."""


# API errors


Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def get_version():
testing_group = set(extras.keys()) - {"accept-rom-license", "atari"}
extras["testing"] = list(
set(itertools.chain.from_iterable(map(lambda group: extras[group], testing_group)))
) + ["pytest==7.0.1"]
) + ["pytest==7.0.1", "gym==0.26.2"]

# All dependency groups - accept rom license as requires user to run
all_groups = set(extras.keys()) - {"accept-rom-license"}
Expand Down
1 change: 1 addition & 0 deletions test_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ imageio>=2.14.1
pygame==2.1.0
mujoco_py<2.2,>=2.1
pytest==7.0.1
gym==0.26.2
37 changes: 28 additions & 9 deletions tests/envs/test_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
from typing import Any, Dict, Optional, Tuple

import numpy as np
import pytest

import gymnasium as gym
import gymnasium
from gymnasium.spaces import Discrete
from gymnasium.utils.env_checker import check_env
from gymnasium.wrappers.compatibility import EnvCompatibility, LegacyEnv


class LegacyEnvExplicit(LegacyEnv, gym.Env):
class LegacyEnvExplicit(LegacyEnv, gymnasium.Env):
"""Legacy env that explicitly implements the old API."""

observation_space = Discrete(1)
Expand Down Expand Up @@ -37,7 +39,7 @@ def seed(self, seed=None):
pass


class LegacyEnvImplicit(gym.Env):
class LegacyEnvImplicit(gymnasium.Env):
"""Legacy env that implicitly implements the old API as a protocol."""

observation_space = Discrete(1)
Expand Down Expand Up @@ -95,12 +97,12 @@ def test_implicit():


def test_make_compatibility_in_spec():
gym.register(
gymnasium.register(
id="LegacyTestEnv-v0",
entry_point=LegacyEnvExplicit,
apply_api_compatibility=True,
)
env = gym.make("LegacyTestEnv-v0", render_mode="rgb_array")
env = gymnasium.make("LegacyTestEnv-v0", render_mode="rgb_array")
assert env.observation_space == Discrete(1)
assert env.action_space == Discrete(1)
assert env.reset() == (0, {})
Expand All @@ -110,12 +112,12 @@ def test_make_compatibility_in_spec():
assert isinstance(img, np.ndarray)
assert img.shape == (1, 1, 3) # type: ignore
env.close()
del gym.envs.registration.registry["LegacyTestEnv-v0"]
del gymnasium.envs.registration.registry["LegacyTestEnv-v0"]


def test_make_compatibility_in_make():
gym.register(id="LegacyTestEnv-v0", entry_point=LegacyEnvExplicit)
env = gym.make(
gymnasium.register(id="LegacyTestEnv-v0", entry_point=LegacyEnvExplicit)
env = gymnasium.make(
"LegacyTestEnv-v0", apply_api_compatibility=True, render_mode="rgb_array"
)
assert env.observation_space == Discrete(1)
Expand All @@ -127,4 +129,21 @@ def test_make_compatibility_in_make():
assert isinstance(img, np.ndarray)
assert img.shape == (1, 1, 3) # type: ignore
env.close()
del gym.envs.registration.registry["LegacyTestEnv-v0"]
del gymnasium.envs.registration.registry["LegacyTestEnv-v0"]


def test_gym_conversion_by_id():
gym = pytest.importorskip("gym")
env_list = gym.envs.registry.keys()
for env_id in env_list:
env = gymnasium.make("GymV26Environment-v0", env_id=env_id)
check_env(env)


def test_gym_conversion_instantiated():
gym = pytest.importorskip("gym")
env_list = gym.envs.registry.keys()
for env_id in env_list:
env = gym.make(env_id)
env = gymnasium.make("GymV26Environment-v0", env=env)
check_env(env)
6 changes: 5 additions & 1 deletion tests/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ def try_make_env(env_spec: EnvSpec) -> Optional[gym.Env]:
if "gymnasium.envs." in env_spec.entry_point:
try:
return env_spec.make(disable_env_checker=True).unwrapped
except (ImportError, gym.error.DependencyNotInstalled) as e:
except (
ImportError,
gym.error.DependencyNotInstalled,
gym.error.MissingArgument,
) as e:
logger.warn(f"Not testing {env_spec.id} due to error: {e}")
return None

Expand Down