Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gym-Gymnasium compatibility converter #61

Merged
merged 16 commits into from
Oct 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions gymnasium/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,11 @@
entry_point="gymnasium.envs.mujoco.humanoidstandup_v4:HumanoidStandupEnv",
max_episode_steps=1000,
)


# Gym conversion
# ----------------------------------------
register(
id="GymV26Environment-v0",
entry_point="gymnasium.envs.external.gym_env:GymEnvironment",
)
Empty file.
159 changes: 159 additions & 0 deletions gymnasium/envs/external/gym_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from typing import Optional, Tuple

import gymnasium
from gymnasium import error
from gymnasium.core import ActType, ObsType

try:
import gym
import gym.wrappers
except ImportError as e:
GYM_IMPORT_ERROR = e
else:
GYM_IMPORT_ERROR = None


class GymEnvironment(gymnasium.Env):
"""
Converts a gym environment to a gymnasium environment.
"""

def __init__(
self,
env_id: Optional[str] = None,
make_kwargs: Optional[dict] = None,
env: Optional["gym.Env"] = None,
):
if GYM_IMPORT_ERROR is not None:
raise error.DependencyNotInstalled(
f"{GYM_IMPORT_ERROR} (Hint: You need to install gym with `pip install gym` to use gym environments"
)

if make_kwargs is None:
make_kwargs = {}

if env is not None:
self.gym_env = env
elif env_id is not None:
self.gym_env = gym.make(env_id, **make_kwargs)
else:
raise gymnasium.error.MissingArgument(
"Either env_id or env must be provided to create a legacy gym environment."
)
self.gym_env = _strip_default_wrappers(self.gym_env)

self.observation_space = _convert_space(self.gym_env.observation_space)
self.action_space = _convert_space(self.gym_env.action_space)

self.metadata = getattr(self.gym_env, "metadata", {"render_modes": []})
self.render_mode = self.gym_env.render_mode
self.reward_range = getattr(self.gym_env, "reward_range", None)
self.spec = getattr(self.gym_env, "spec", None)

def reset(
self, seed: Optional[int] = None, options: Optional[dict] = None
) -> Tuple[ObsType, dict]:
"""Resets the environment.

Args:
seed: the seed to reset the environment with
options: the options to reset the environment with

Returns:
(observation, info)
"""
super().reset(seed=seed)
# Options are ignored
return self.gym_env.reset(seed=seed, options=options)

def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
"""Steps through the environment.

Args:
action: action to step through the environment with

Returns:
(observation, reward, terminated, truncated, info)
"""
return self.gym_env.step(action)

def render(self):
"""Renders the environment.

Returns:
The rendering of the environment, depending on the render mode
"""
return self.gym_env.render()

def close(self):
"""Closes the environment."""
self.gym_env.close()

def __str__(self):
return f"GymEnvironment({self.gym_env})"

def __repr__(self):
return f"GymEnvironment({self.gym_env})"


def _strip_default_wrappers(env: "gym.Env") -> "gym.Env":
"""Strips builtin wrappers from the environment.

Args:
env: the environment to strip builtin wrappers from

Returns:
The environment without builtin wrappers
"""

default_wrappers = (
gym.wrappers.render_collection.RenderCollection,
gym.wrappers.human_rendering.HumanRendering,
)
while isinstance(env, default_wrappers):
env = env.env
return env


def _convert_space(space: "gym.Space") -> gymnasium.Space:
"""Converts a gym space to a gymnasium space.

Args:
space: the space to convert

Returns:
The converted space
"""
if isinstance(space, gym.spaces.Discrete):
return gymnasium.spaces.Discrete(n=space.n)
elif isinstance(space, gym.spaces.Box):
return gymnasium.spaces.Box(
low=space.low, high=space.high, shape=space.shape, dtype=space.dtype
)
elif isinstance(space, gym.spaces.MultiDiscrete):
return gymnasium.spaces.MultiDiscrete(nvec=space.nvec)
elif isinstance(space, gym.spaces.MultiBinary):
return gymnasium.spaces.MultiBinary(n=space.n)
elif isinstance(space, gym.spaces.Tuple):
return gymnasium.spaces.Tuple(spaces=tuple(map(_convert_space, space.spaces)))
elif isinstance(space, gym.spaces.Dict):
return gymnasium.spaces.Dict(
spaces={k: _convert_space(v) for k, v in space.spaces.items()}
)
elif isinstance(space, gym.spaces.Sequence):
return gymnasium.spaces.Sequence(space=_convert_space(space.feature_space))
elif isinstance(space, gym.spaces.Graph):
return gymnasium.spaces.Graph(
node_space=_convert_space(space.node_space), # type: ignore
edge_space=_convert_space(space.edge_space), # type: ignore
)
elif isinstance(space, gym.spaces.Text):
return gymnasium.spaces.Text(
max_length=space.max_length,
min_length=space.min_length,
charset=space._char_str,
)
else:
raise NotImplementedError(
f"Cannot convert space of type {space}. Please upgrade your code to gymnasium."
)
4 changes: 4 additions & 0 deletions gymnasium/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ class InvalidAction(Error):
"""Raised when the user performs an action not contained within the action space."""


class MissingArgument(Error):
"""Raised when a required argument in the initializer is missing."""


# API errors


Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def get_version():
testing_group = set(extras.keys()) - {"accept-rom-license", "atari"}
extras["testing"] = list(
set(itertools.chain.from_iterable(map(lambda group: extras[group], testing_group)))
) + ["pytest==7.0.1"]
) + ["pytest==7.0.1", "gym==0.26.2"]

# All dependency groups - accept rom license as requires user to run
all_groups = set(extras.keys()) - {"accept-rom-license"}
Expand Down
1 change: 1 addition & 0 deletions test_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ imageio>=2.14.1
pygame==2.1.0
mujoco_py<2.2,>=2.1
pytest==7.0.1
gym==0.26.2
18 changes: 9 additions & 9 deletions tests/envs/test_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

import numpy as np

import gymnasium as gym
import gymnasium
from gymnasium.spaces import Discrete
from gymnasium.wrappers.compatibility import EnvCompatibility, LegacyEnv


class LegacyEnvExplicit(LegacyEnv, gym.Env):
class LegacyEnvExplicit(LegacyEnv, gymnasium.Env):
"""Legacy env that explicitly implements the old API."""

observation_space = Discrete(1)
Expand Down Expand Up @@ -37,7 +37,7 @@ def seed(self, seed=None):
pass


class LegacyEnvImplicit(gym.Env):
class LegacyEnvImplicit(gymnasium.Env):
"""Legacy env that implicitly implements the old API as a protocol."""

observation_space = Discrete(1)
Expand Down Expand Up @@ -95,12 +95,12 @@ def test_implicit():


def test_make_compatibility_in_spec():
gym.register(
gymnasium.register(
id="LegacyTestEnv-v0",
entry_point=LegacyEnvExplicit,
apply_api_compatibility=True,
)
env = gym.make("LegacyTestEnv-v0", render_mode="rgb_array")
env = gymnasium.make("LegacyTestEnv-v0", render_mode="rgb_array")
assert env.observation_space == Discrete(1)
assert env.action_space == Discrete(1)
assert env.reset() == (0, {})
Expand All @@ -110,12 +110,12 @@ def test_make_compatibility_in_spec():
assert isinstance(img, np.ndarray)
assert img.shape == (1, 1, 3) # type: ignore
env.close()
del gym.envs.registration.registry["LegacyTestEnv-v0"]
del gymnasium.envs.registration.registry["LegacyTestEnv-v0"]


def test_make_compatibility_in_make():
gym.register(id="LegacyTestEnv-v0", entry_point=LegacyEnvExplicit)
env = gym.make(
gymnasium.register(id="LegacyTestEnv-v0", entry_point=LegacyEnvExplicit)
env = gymnasium.make(
"LegacyTestEnv-v0", apply_api_compatibility=True, render_mode="rgb_array"
)
assert env.observation_space == Discrete(1)
Expand All @@ -127,4 +127,4 @@ def test_make_compatibility_in_make():
assert isinstance(img, np.ndarray)
assert img.shape == (1, 1, 3) # type: ignore
env.close()
del gym.envs.registration.registry["LegacyTestEnv-v0"]
del gymnasium.envs.registration.registry["LegacyTestEnv-v0"]
27 changes: 27 additions & 0 deletions tests/envs/test_gym_conversion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest

import gymnasium
from gymnasium.utils.env_checker import check_env

pytest.importorskip("gym")

import gym # noqa: E402, isort: skip

ALL_GYM_ENVS = gym.envs.registry.keys()


@pytest.mark.parametrize(
"env_id", ALL_GYM_ENVS, ids=[env_id for env_id in ALL_GYM_ENVS]
)
def test_gym_conversion_by_id(env_id):
env = gymnasium.make("GymV26Environment-v0", env_id=env_id)
check_env(env)


@pytest.mark.parametrize(
"env_id", ALL_GYM_ENVS, ids=[env_id for env_id in ALL_GYM_ENVS]
)
def test_gym_conversion_instantiated(env_id):
env = gym.make(env_id)
env = gymnasium.make("GymV26Environment-v0", env=env)
check_env(env)
6 changes: 5 additions & 1 deletion tests/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ def try_make_env(env_spec: EnvSpec) -> Optional[gym.Env]:
if "gymnasium.envs." in env_spec.entry_point:
try:
return env_spec.make(disable_env_checker=True).unwrapped
except (ImportError, gym.error.DependencyNotInstalled) as e:
except (
ImportError,
gym.error.DependencyNotInstalled,
gym.error.MissingArgument,
) as e:
logger.warn(f"Not testing {env_spec.id} due to error: {e}")
return None

Expand Down