Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add testing for step api compatibility functions and wrapper #3028

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gym/utils/env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def data_equivalence(data_1, data_2) -> bool:
return data_1.keys() == data_2.keys() and all(
data_equivalence(data_1[k], data_2[k]) for k in data_1.keys()
)
elif isinstance(data_1, tuple):
elif isinstance(data_1, (tuple, list)):
return len(data_1) == len(data_2) and all(
data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
)
Expand Down
31 changes: 17 additions & 14 deletions gym/wrappers/step_api_compatibility.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Implementation of StepAPICompatibility wrapper class for transforming envs between new and old step API."""
import gym
from gym.logger import deprecation
from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api
from gym.utils.step_api_compatibility import (
to_done_step_api,
to_terminated_truncated_step_api,
)


class StepAPICompatibility(gym.Wrapper):
Expand All @@ -15,33 +18,33 @@ class StepAPICompatibility(gym.Wrapper):

Args:
env (gym.Env): the env to wrap. Can be in old or new API
new_step_api (bool): True to use env with new step API, False to use env with old step API. (False by default)
to_termination_truncation_api (bool): True to use env with new step API, False to use env with old step API. (False by default)

Examples:
>>> env = gym.make("CartPole-v1")
>>> env # wrapper applied by default, set to old API
<TimeLimit<OrderEnforcing<StepAPICompatibility<CartPoleEnv<CartPole-v1>>>>>
>>> env = gym.make("CartPole-v1", new_step_api=True) # set to new API
>>> env = StepAPICompatibility(CustomEnv(), new_step_api=True) # manually using wrapper on unregistered envs
>>> env = gym.make("CartPole-v1", to_termination_truncation_api=True) # set to new API
>>> env = StepAPICompatibility(CustomEnv(), to_termination_truncation_api=True) # manually using wrapper on unregistered envs

"""

def __init__(self, env: gym.Env, new_step_api=False):
def __init__(self, env: gym.Env, to_termination_truncation_api: bool = False):
"""A wrapper which can transform an environment from new step API to old and vice-versa.

Args:
env (gym.Env): the env to wrap. Can be in old or new API
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
to_termination_truncation_api (bool): Whether the wrapper's step method outputs two booleans (termination and truncation) with True, or one boolean (done) with False
"""
super().__init__(env, new_step_api)
self.new_step_api = new_step_api
if not self.new_step_api:
super().__init__(env, to_termination_truncation_api)
self.to_termination_truncation_api = to_termination_truncation_api
if self.to_termination_truncation_api is False:
deprecation(
"Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future."
"Initializing environment in old done step API which returns one bool instead of two. It is recommended to set `to_termination_truncation_api=True` to use new step API. This will be the default behaviour in future."
)

def step(self, action):
"""Steps through the environment, returning 5 or 4 items depending on `new_step_api`.
"""Steps through the environment, returning 5 or 4 items depending on `to_termination_truncation_api`.

Args:
action: action to step through the environment with
Expand All @@ -50,7 +53,7 @@ def step(self, action):
(observation, reward, terminated, truncated, info) or (observation, reward, done, info)
"""
step_returns = self.env.step(action)
if self.new_step_api:
return step_to_new_api(step_returns)
if self.to_termination_truncation_api:
return to_terminated_truncated_step_api(step_returns)
else:
return step_to_old_api(step_returns)
return to_done_step_api(step_returns)
11 changes: 7 additions & 4 deletions gym/wrappers/time_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(

Args:
env: The environment to apply the wrapper
max_episode_steps: An optional max episode steps (if ``Ǹone``, ``env.spec.max_episode_steps`` is used)
max_episode_steps: An optional max episode steps (if ``None``, ``env.spec.max_episode_steps`` is used)
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
"""
super().__init__(env, new_step_api)
Expand All @@ -58,16 +58,19 @@ def step(self, action):
"""
observation, reward, terminated, truncated, info = step_api_compatibility(
self.env.step(action),
True,
to_termination_truncation=True,
)
self._elapsed_steps += 1

if self._elapsed_steps >= self._max_episode_steps:
truncated = True
if self.new_step_api is True or terminated is False:
# As the old step api cannot encode both terminated and truncated, we favor terminated in the case of both.
# Therefore, if new step api (i.e. not old step api) or when terminated is False to prevent the overriding
truncated = True

return step_api_compatibility(
(observation, reward, terminated, truncated, info),
self.new_step_api,
to_termination_truncation=self.new_step_api,
)

def reset(self, **kwargs):
Expand Down
160 changes: 160 additions & 0 deletions tests/utils/test_step_api_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import numpy as np
import pytest

from gym.utils.env_checker import data_equivalence
from gym.utils.step_api_compatibility import (
to_done_step_api,
to_terminated_truncated_step_api,
)


@pytest.mark.parametrize(
"is_vector_env, done_returns, expected_terminated, expected_truncated",
(
# Test each of the permutations for single environments with and without the old info
(False, (0, 0, False, {"Test-info": True}), False, False),
(False, (0, 0, False, {"TimeLimit.truncated": False}), False, False),
(False, (0, 0, True, {}), True, False),
(False, (0, 0, True, {"TimeLimit.truncated": True}), False, True),
(False, (0, 0, True, {"Test-info": True}), True, False),
# Test vectorise versions with both list and dict infos testing each permutation for sub-environments
(
True,
(
0,
0,
np.array([False, True, True]),
[{}, {}, {"TimeLimit.truncated": True}],
),
np.array([False, True, False]),
np.array([False, False, True]),
),
(
True,
(
0,
0,
np.array([False, True, True]),
{"TimeLimit.truncated": np.array([False, False, True])},
),
np.array([False, True, False]),
np.array([False, False, True]),
),
# empty truncated info
(
True,
(
0,
0,
np.array([False, True]),
{},
),
np.array([False, True]),
np.array([False, False]),
),
),
)
def test_to_done_step_api(
is_vector_env, done_returns, expected_terminated, expected_truncated
):
_, _, terminated, truncated, info = to_terminated_truncated_step_api(
done_returns, is_vector_env=is_vector_env
)
assert np.all(terminated == expected_terminated)
assert np.all(truncated == expected_truncated)
if is_vector_env is False:
assert "TimeLimit.truncated" not in info
elif isinstance(info, list):
assert all("TimeLimit.truncated" not in sub_info for sub_info in info)
else: # isinstance(info, dict)
assert "TimeLimit.truncated" not in info

roundtripped_returns = to_done_step_api(
(0, 0, terminated, truncated, info), is_vector_env=is_vector_env
)
assert data_equivalence(done_returns, roundtripped_returns)


@pytest.mark.parametrize(
"is_vector_env, terminated_truncated_returns, expected_done, expected_truncated",
(
(False, (0, 0, False, False, {"Test-info": True}), False, False),
(False, (0, 0, True, False, {}), True, False),
(False, (0, 0, False, True, {}), True, True),
# (False, (), True, True), # Not possible to encode in the old step api
# Test vector dict info
(
True,
(0, 0, np.array([False, True, False]), np.array([False, False, True]), {}),
np.array([False, True, True]),
np.array([False, False, True]),
),
# Test vector dict info with no truncation
(
True,
(0, 0, np.array([False, True]), np.array([False, False]), {}),
np.array([False, True]),
np.array([False, False]),
),
# Test vector list info
(
True,
(
0,
0,
np.array([False, True, False]),
np.array([False, False, True]),
[{}, {}, {}],
),
np.array([False, True, True]),
np.array([False, False, True]),
),
),
)
def test_to_terminated_truncated_step_api(
is_vector_env, terminated_truncated_returns, expected_done, expected_truncated
):
_, _, done, info = to_done_step_api(
terminated_truncated_returns, is_vector_env=is_vector_env
)
assert np.all(done == expected_done)
if is_vector_env is False:
if expected_truncated:
assert info["TimeLimit.truncated"] == expected_truncated
else:
assert "TimeLimit.truncated" not in info
elif isinstance(info, list):
for sub_info, trunc in zip(info, expected_truncated):
if trunc:
assert sub_info["TimeLimit.truncated"] == trunc
else:
assert "TimeLimit.truncated" not in sub_info
else: # isinstance(info, dict)
if np.any(expected_truncated):
assert np.all(info["TimeLimit.truncated"] == expected_truncated)
else:
assert "TimeLimit.truncated" not in info

roundtripped_returns = to_terminated_truncated_step_api(
(0, 0, done, info), is_vector_env=is_vector_env
)
assert data_equivalence(terminated_truncated_returns, roundtripped_returns)


def test_edge_case():
# When converting between the two-step APIs this is not possible in a single case
# terminated=True and truncated=True -> done=True and info={}
# We cannot test this in test_to_terminated_truncated_step_api as the roundtripping test will fail
_, _, done, info = to_done_step_api((0, 0, True, True, {}))
assert done is True
assert info == {}

_, _, done, info = to_done_step_api((0, 0, np.array([True]), np.array([True]), {}))
assert np.all(done)
assert info == {}

_, _, done, info = to_done_step_api(
(0, 0, np.array([True]), np.array([True]), [{"Test-Info": True}])
)
assert np.all(done)
assert info == [{"Test-Info": True}]
91 changes: 0 additions & 91 deletions tests/utils/test_terminated_truncated.py

This file was deleted.