Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add testing for step api compatibility functions and wrapper #3028

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gym/utils/env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def data_equivalence(data_1, data_2) -> bool:
return data_1.keys() == data_2.keys() and all(
data_equivalence(data_1[k], data_2[k]) for k in data_1.keys()
)
elif isinstance(data_1, tuple):
elif isinstance(data_1, (tuple, list)):
return len(data_1) == len(data_2) and all(
data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
)
Expand Down
172 changes: 74 additions & 98 deletions gym/utils/step_api_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,66 +36,41 @@ def step_to_new_api(
assert len(step_returns) == 4
observations, rewards, dones, infos = step_returns

terminateds = []
truncateds = []
if not is_vector_env:
dones = [dones]

for i in range(len(dones)):
# For every condition, handling - info single env / info vector env (list) / info vector env (dict)

# TimeLimit.truncated attribute not present - implies either terminated or episode still ongoing based on `done`
if (not is_vector_env and "TimeLimit.truncated" not in infos) or (
is_vector_env
and (
(
isinstance(infos, list)
and "TimeLimit.truncated" not in infos[i]
) # vector env, list info api
or (
"TimeLimit.truncated" not in infos
or (
"TimeLimit.truncated" in infos
and not infos["TimeLimit.truncated"][i]
)
)
# vector env, dict info api, for env i, vector mask `_TimeLimit.truncated` is not considered, to be compatible with envpool
# For env i, `TimeLimit.truncated` not being present is treated same as being present and set to False.
# therefore, terminated=True, truncated=True simultaneously is not allowed while using compatibility functions
# with vector info
)
):
terminateds.append(dones[i])
truncateds.append(False)

# This means info["TimeLimit.truncated"] exists and this elif checks if it is True, which means the truncation has occurred but termination has not.
elif (
infos["TimeLimit.truncated"]
if not is_vector_env
else (
infos["TimeLimit.truncated"][i]
if isinstance(infos, dict)
else infos[i]["TimeLimit.truncated"]
)
):
assert dones[i]
terminateds.append(False)
truncateds.append(True)
else:
# This means info["TimeLimit.truncated"] exists but is False, which means the core environment had already terminated,
# but it also exceeded maximum timesteps at the same step. However to be compatible with envpool, and to be backward compatible
# truncated is set to False here.
assert dones[i]
terminateds.append(True)
truncateds.append(False)

return (
observations,
rewards,
np.array(terminateds, dtype=np.bool_) if is_vector_env else terminateds[0],
np.array(truncateds, dtype=np.bool_) if is_vector_env else truncateds[0],
infos,
)
# Cases to handle - info single env / info vector env (list) / info vector env (dict)
if is_vector_env is False:
truncated = infos.pop("TimeLimit.truncated", False)
return (
observations,
rewards,
dones and not truncated,
dones and truncated,
infos,
)
elif isinstance(infos, list):
truncated = np.array(
[info.pop("TimeLimit.truncated", False) for info in infos]
)
return (
observations,
rewards,
np.logical_and(dones, np.logical_not(truncated)),
np.logical_and(dones, truncated),
infos,
)
elif isinstance(infos, dict):
num_envs = len(dones)
truncated = infos.pop("TimeLimit.truncated", np.zeros(num_envs, dtype=bool))
return (
observations,
rewards,
np.logical_and(dones, np.logical_not(truncated)),
np.logical_and(dones, truncated),
infos,
)
else:
raise TypeError(
f"Unexpected value of infos, as is_vector_envs=False, expects `info` to be a list or dict, actual type: {type(infos)}"
)


def step_to_old_api(
Expand All @@ -111,44 +86,45 @@ def step_to_old_api(
return step_returns
else:
assert len(step_returns) == 5
observations, rewards, terminateds, truncateds, infos = step_returns
dones = []
if not is_vector_env:
terminateds = [terminateds]
truncateds = [truncateds]

n_envs = len(terminateds)

for i in range(n_envs):
dones.append(terminateds[i] or truncateds[i])
if truncateds[i]:
if is_vector_env:
# handle vector infos for dict and list
if isinstance(infos, dict):
if "TimeLimit.truncated" not in infos:
# TODO: This should ideally not be done manually and should use vector_env's _add_info()
infos["TimeLimit.truncated"] = np.zeros(n_envs, dtype=bool)
infos["_TimeLimit.truncated"] = np.zeros(n_envs, dtype=bool)

infos["TimeLimit.truncated"][i] = (
not terminateds[i] or infos["TimeLimit.truncated"][i]
)
infos["_TimeLimit.truncated"][i] = True
else:
# if vector info is a list
infos[i]["TimeLimit.truncated"] = not terminateds[i] or infos[
i
].get("TimeLimit.truncated", False)
else:
infos["TimeLimit.truncated"] = not terminateds[i] or infos.get(
"TimeLimit.truncated", False
)
return (
observations,
rewards,
np.array(dones, dtype=np.bool_) if is_vector_env else dones[0],
infos,
)
observations, rewards, terminated, truncated, infos = step_returns

# Cases to handle - info single env / info vector env (list) / info vector env (dict)
if is_vector_env is False:
if truncated or terminated:
infos["TimeLimit.truncated"] = truncated and not terminated
return (
observations,
rewards,
terminated or truncated,
infos,
)
elif isinstance(infos, list):
for info, env_truncated, env_terminated in zip(
infos, truncated, terminated
):
if env_truncated or env_terminated:
info["TimeLimit.truncated"] = env_truncated and not env_terminated
return (
observations,
rewards,
np.logical_or(terminated, truncated),
infos,
)
elif isinstance(infos, dict):
if np.logical_or(np.any(truncated), np.any(terminated)):
infos["TimeLimit.truncated"] = np.logical_and(
truncated, np.logical_not(terminated)
)
return (
observations,
rewards,
np.logical_or(terminated, truncated),
infos,
)
else:
raise TypeError(
f"Unexpected value of infos, as is_vector_envs=False, expects `info` to be a list or dict, actual type: {type(infos)}"
)


def step_api_compatibility(
Expand Down
7 changes: 5 additions & 2 deletions gym/wrappers/time_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(

Args:
env: The environment to apply the wrapper
max_episode_steps: An optional max episode steps (if ``Ǹone``, ``env.spec.max_episode_steps`` is used)
max_episode_steps: An optional max episode steps (if ``None``, ``env.spec.max_episode_steps`` is used)
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
"""
super().__init__(env, new_step_api)
Expand Down Expand Up @@ -63,7 +63,10 @@ def step(self, action):
self._elapsed_steps += 1

if self._elapsed_steps >= self._max_episode_steps:
truncated = True
if self.new_step_api is True or terminated is False:
# The old step API cannot encode terminated and truncated both being True, so terminated takes priority in that case.
# Therefore truncated is only set when the new step API is used, or when terminated is False (so that it is not overridden).
truncated = True

return step_api_compatibility(
(observation, reward, terminated, truncated, info),
Expand Down
166 changes: 166 additions & 0 deletions tests/utils/test_step_api_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import numpy as np
import pytest

from gym.utils.env_checker import data_equivalence
from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api


@pytest.mark.parametrize(
    "is_vector_env, done_returns, expected_terminated, expected_truncated",
    (
        # Single environment: every combination of `done` with and without
        # the legacy "TimeLimit.truncated" info entry.
        (False, (0, 0, False, {"Test-info": True}), False, False),
        (False, (0, 0, False, {"TimeLimit.truncated": False}), False, False),
        (False, (0, 0, True, {}), True, False),
        (False, (0, 0, True, {"TimeLimit.truncated": True}), False, True),
        (False, (0, 0, True, {"Test-info": True}), True, False),
        # Vector environment, list-of-dicts info: one sub-environment per outcome.
        (
            True,
            (
                0,
                0,
                np.array([False, True, True]),
                [{}, {}, {"TimeLimit.truncated": True}],
            ),
            np.array([False, True, False]),
            np.array([False, False, True]),
        ),
        # Vector environment, dict-of-arrays info: one sub-environment per outcome.
        (
            True,
            (
                0,
                0,
                np.array([False, True, True]),
                {"TimeLimit.truncated": np.array([False, False, True])},
            ),
            np.array([False, True, False]),
            np.array([False, False, True]),
        ),
        # Vector environment, dict info without any truncation entry.
        (
            True,
            (
                0,
                0,
                np.array([False, True]),
                {},
            ),
            np.array([False, True]),
            np.array([False, False]),
        ),
    ),
)
def test_to_done_step_api(
    is_vector_env, done_returns, expected_terminated, expected_truncated
):
    """Check that old-API (done) step returns convert to the new API and back.

    Args:
        is_vector_env: Whether ``done_returns`` describes a vector environment.
        done_returns: An ``(obs, reward, done, info)`` tuple in the old step API.
        expected_terminated: The ``terminated`` value expected after conversion.
        expected_truncated: The ``truncated`` value expected after conversion.
    """
    truncation_key = "TimeLimit.truncated"

    converted = step_to_new_api(done_returns, is_vector_env=is_vector_env)
    _, _, terminated, truncated, info = converted

    assert np.all(terminated == expected_terminated)
    assert np.all(truncated == expected_truncated)

    # The legacy truncation entry must no longer be present in the new-API info.
    if is_vector_env is not False and isinstance(info, list):
        for sub_info in info:
            assert truncation_key not in sub_info
    else:
        # Single-environment info, or vector info stored as one dict.
        assert truncation_key not in info

    # Converting back to the old API should reproduce the original step returns.
    # NOTE(review): the conversion functions appear to edit and return the very
    # info object they are given, so the info compared below may be the same
    # object on both sides -- confirm this round-trip check is meaningful for info.
    new_api_returns = (0, 0, terminated, truncated, info)
    roundtripped_returns = step_to_old_api(
        new_api_returns, is_vector_env=is_vector_env
    )
    assert data_equivalence(done_returns, roundtripped_returns)


@pytest.mark.parametrize(
    "is_vector_env, terminated_truncated_returns, expected_done, expected_truncated",
    (
        # Single environment.
        (False, (0, 0, False, False, {"Test-info": True}), False, False),
        (False, (0, 0, True, False, {}), True, False),
        (False, (0, 0, False, True, {}), True, True),
        # terminated=True together with truncated=True is deliberately omitted:
        # it cannot be encoded in the old step API, so it cannot round-trip
        # (test_edge_case covers it).
        # Vector environment, dict info.
        (
            True,
            (0, 0, np.array([False, True, False]), np.array([False, False, True]), {}),
            np.array([False, True, True]),
            np.array([False, False, True]),
        ),
        # Vector environment, dict info, no truncation in any sub-environment.
        (
            True,
            (0, 0, np.array([False, True]), np.array([False, False]), {}),
            np.array([False, True]),
            np.array([False, False]),
        ),
        # Vector environment, list info.
        (
            True,
            (
                0,
                0,
                np.array([False, True, False]),
                np.array([False, False, True]),
                [{"Test-Info": True}, {}, {}],
            ),
            np.array([False, True, True]),
            np.array([False, False, True]),
        ),
    ),
)
def test_to_terminated_truncated_step_api(
    is_vector_env, terminated_truncated_returns, expected_done, expected_truncated
):
    """Check that new-API step returns convert to the old (done) API and back.

    Args:
        is_vector_env: Whether the step returns describe a vector environment.
        terminated_truncated_returns: An
            ``(obs, reward, terminated, truncated, info)`` tuple in the new step API.
        expected_done: The ``done`` value expected after conversion.
        expected_truncated: The expected "TimeLimit.truncated" info value(s)
            wherever an episode has ended.
    """
    truncation_key = "TimeLimit.truncated"

    old_api_returns = step_to_old_api(
        terminated_truncated_returns, is_vector_env=is_vector_env
    )
    _, _, done, info = old_api_returns
    assert np.all(done == expected_done)

    # The truncation entry is only expected where an episode has actually ended.
    if is_vector_env is False:
        if not expected_done:
            assert truncation_key not in info
        else:
            assert info[truncation_key] == expected_truncated
    elif isinstance(info, list):
        per_env = zip(info, expected_done, expected_truncated)
        for sub_info, env_done, env_truncated in per_env:
            if not env_done:
                assert truncation_key not in sub_info
            else:
                assert sub_info[truncation_key] == env_truncated
    elif np.any(expected_done):
        # Dict info: a single array entry covers all sub-environments.
        assert np.all(info[truncation_key] == expected_truncated)
    else:
        assert truncation_key not in info

    # Converting back to the new API should reproduce the original step returns.
    roundtripped_returns = step_to_new_api(
        (0, 0, done, info), is_vector_env=is_vector_env
    )
    assert data_equivalence(terminated_truncated_returns, roundtripped_returns)


def test_edge_case():
    """Check old-API conversion when terminated and truncated are both True.

    The old step API cannot represent both flags at once, so the expected
    result is ``done=True`` with ``info["TimeLimit.truncated"]`` set to False.
    This case is kept out of ``test_to_terminated_truncated_step_api`` because
    its round-trip check would fail for it.
    """
    # Single environment.
    single_returns = step_to_old_api((0, 0, True, True, {}))
    _, _, done, info = single_returns
    assert done is True
    assert info == {"TimeLimit.truncated": False}

    # Vector environment with dict info.
    vector_dict_step = (0, 0, np.array([True]), np.array([True]), {})
    _, _, done, info = step_to_old_api(vector_dict_step, is_vector_env=True)
    assert np.all(done)
    assert info == {"TimeLimit.truncated": np.array([False])}

    # Vector environment with list info; unrelated info entries must be kept.
    vector_list_step = (
        0,
        0,
        np.array([True]),
        np.array([True]),
        [{"Test-Info": True}],
    )
    _, _, done, info = step_to_old_api(vector_list_step, is_vector_env=True)
    assert np.all(done)
    assert info == [{"Test-Info": True, "TimeLimit.truncated": False}]
Loading