Skip to content

Commit

Permalink
Add testing for step api compatibility functions and wrapper (#3028)
Browse files Browse the repository at this point in the history
* Initial commit

* Fixed tests and forced TimeLimit.truncated to always exist when truncated or terminated

* Fix CI issues

* pre-commit

* Revert back to old language

* Revert changes to step api wrapper
  • Loading branch information
pseudo-rnd-thoughts authored Aug 18, 2022
1 parent aa43d13 commit a8d4dd7
Show file tree
Hide file tree
Showing 6 changed files with 247 additions and 192 deletions.
2 changes: 1 addition & 1 deletion gym/utils/env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def data_equivalence(data_1, data_2) -> bool:
return data_1.keys() == data_2.keys() and all(
data_equivalence(data_1[k], data_2[k]) for k in data_1.keys()
)
elif isinstance(data_1, tuple):
elif isinstance(data_1, (tuple, list)):
return len(data_1) == len(data_2) and all(
data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
)
Expand Down
172 changes: 74 additions & 98 deletions gym/utils/step_api_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,66 +36,41 @@ def step_to_new_api(
assert len(step_returns) == 4
observations, rewards, dones, infos = step_returns

terminateds = []
truncateds = []
if not is_vector_env:
dones = [dones]

for i in range(len(dones)):
# For every condition, handling - info single env / info vector env (list) / info vector env (dict)

# TimeLimit.truncated attribute not present - implies either terminated or episode still ongoing based on `done`
if (not is_vector_env and "TimeLimit.truncated" not in infos) or (
is_vector_env
and (
(
isinstance(infos, list)
and "TimeLimit.truncated" not in infos[i]
) # vector env, list info api
or (
"TimeLimit.truncated" not in infos
or (
"TimeLimit.truncated" in infos
and not infos["TimeLimit.truncated"][i]
)
)
# vector env, dict info api, for env i, vector mask `_TimeLimit.truncated` is not considered, to be compatible with envpool
# For env i, `TimeLimit.truncated` not being present is treated same as being present and set to False.
# therefore, terminated=True, truncated=True simultaneously is not allowed while using compatibility functions
# with vector info
)
):
terminateds.append(dones[i])
truncateds.append(False)

# This means info["TimeLimit.truncated"] exists and this elif checks if it is True, which means the truncation has occurred but termination has not.
elif (
infos["TimeLimit.truncated"]
if not is_vector_env
else (
infos["TimeLimit.truncated"][i]
if isinstance(infos, dict)
else infos[i]["TimeLimit.truncated"]
)
):
assert dones[i]
terminateds.append(False)
truncateds.append(True)
else:
# This means info["TimeLimit.truncated"] exists but is False, which means the core environment had already terminated,
# but it also exceeded maximum timesteps at the same step. However to be compatible with envpool, and to be backward compatible
# truncated is set to False here.
assert dones[i]
terminateds.append(True)
truncateds.append(False)

return (
observations,
rewards,
np.array(terminateds, dtype=np.bool_) if is_vector_env else terminateds[0],
np.array(truncateds, dtype=np.bool_) if is_vector_env else truncateds[0],
infos,
)
# Cases to handle - info single env / info vector env (list) / info vector env (dict)
if is_vector_env is False:
truncated = infos.pop("TimeLimit.truncated", False)
return (
observations,
rewards,
dones and not truncated,
dones and truncated,
infos,
)
elif isinstance(infos, list):
truncated = np.array(
[info.pop("TimeLimit.truncated", False) for info in infos]
)
return (
observations,
rewards,
np.logical_and(dones, np.logical_not(truncated)),
np.logical_and(dones, truncated),
infos,
)
elif isinstance(infos, dict):
num_envs = len(dones)
truncated = infos.pop("TimeLimit.truncated", np.zeros(num_envs, dtype=bool))
return (
observations,
rewards,
np.logical_and(dones, np.logical_not(truncated)),
np.logical_and(dones, truncated),
infos,
)
else:
raise TypeError(
f"Unexpected value of infos, as is_vector_envs=False, expects `info` to be a list or dict, actual type: {type(infos)}"
)


def step_to_old_api(
Expand All @@ -111,44 +86,45 @@ def step_to_old_api(
return step_returns
else:
assert len(step_returns) == 5
observations, rewards, terminateds, truncateds, infos = step_returns
dones = []
if not is_vector_env:
terminateds = [terminateds]
truncateds = [truncateds]

n_envs = len(terminateds)

for i in range(n_envs):
dones.append(terminateds[i] or truncateds[i])
if truncateds[i]:
if is_vector_env:
# handle vector infos for dict and list
if isinstance(infos, dict):
if "TimeLimit.truncated" not in infos:
# TODO: This should ideally not be done manually and should use vector_env's _add_info()
infos["TimeLimit.truncated"] = np.zeros(n_envs, dtype=bool)
infos["_TimeLimit.truncated"] = np.zeros(n_envs, dtype=bool)

infos["TimeLimit.truncated"][i] = (
not terminateds[i] or infos["TimeLimit.truncated"][i]
)
infos["_TimeLimit.truncated"][i] = True
else:
# if vector info is a list
infos[i]["TimeLimit.truncated"] = not terminateds[i] or infos[
i
].get("TimeLimit.truncated", False)
else:
infos["TimeLimit.truncated"] = not terminateds[i] or infos.get(
"TimeLimit.truncated", False
)
return (
observations,
rewards,
np.array(dones, dtype=np.bool_) if is_vector_env else dones[0],
infos,
)
observations, rewards, terminated, truncated, infos = step_returns

# Cases to handle - info single env / info vector env (list) / info vector env (dict)
if is_vector_env is False:
if truncated or terminated:
infos["TimeLimit.truncated"] = truncated and not terminated
return (
observations,
rewards,
terminated or truncated,
infos,
)
elif isinstance(infos, list):
for info, env_truncated, env_terminated in zip(
infos, truncated, terminated
):
if env_truncated or env_terminated:
info["TimeLimit.truncated"] = env_truncated and not env_terminated
return (
observations,
rewards,
np.logical_or(terminated, truncated),
infos,
)
elif isinstance(infos, dict):
if np.logical_or(np.any(truncated), np.any(terminated)):
infos["TimeLimit.truncated"] = np.logical_and(
truncated, np.logical_not(terminated)
)
return (
observations,
rewards,
np.logical_or(terminated, truncated),
infos,
)
else:
raise TypeError(
f"Unexpected value of infos, as is_vector_envs=False, expects `info` to be a list or dict, actual type: {type(infos)}"
)


def step_api_compatibility(
Expand Down
7 changes: 5 additions & 2 deletions gym/wrappers/time_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
Args:
env: The environment to apply the wrapper
max_episode_steps: An optional max episode steps (if ``Ǹone``, ``env.spec.max_episode_steps`` is used)
max_episode_steps: An optional max episode steps (if ``None``, ``env.spec.max_episode_steps`` is used)
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
"""
super().__init__(env, new_step_api)
Expand Down Expand Up @@ -63,7 +63,10 @@ def step(self, action):
self._elapsed_steps += 1

if self._elapsed_steps >= self._max_episode_steps:
truncated = True
if self.new_step_api is True or terminated is False:
# As the old step api cannot encode both terminated and truncated, we favor terminated in the case of both.
# Therefore, if new step api (i.e. not old step api) or when terminated is False to prevent the overriding
truncated = True

return step_api_compatibility(
(observation, reward, terminated, truncated, info),
Expand Down
166 changes: 166 additions & 0 deletions tests/utils/test_step_api_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import numpy as np
import pytest

from gym.utils.env_checker import data_equivalence
from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api


@pytest.mark.parametrize(
"is_vector_env, done_returns, expected_terminated, expected_truncated",
(
# Test each of the permutations for single environments with and without the old info
(False, (0, 0, False, {"Test-info": True}), False, False),
(False, (0, 0, False, {"TimeLimit.truncated": False}), False, False),
(False, (0, 0, True, {}), True, False),
(False, (0, 0, True, {"TimeLimit.truncated": True}), False, True),
(False, (0, 0, True, {"Test-info": True}), True, False),
# Test vectorise versions with both list and dict infos testing each permutation for sub-environments
(
True,
(
0,
0,
np.array([False, True, True]),
[{}, {}, {"TimeLimit.truncated": True}],
),
np.array([False, True, False]),
np.array([False, False, True]),
),
(
True,
(
0,
0,
np.array([False, True, True]),
{"TimeLimit.truncated": np.array([False, False, True])},
),
np.array([False, True, False]),
np.array([False, False, True]),
),
# empty truncated info
(
True,
(
0,
0,
np.array([False, True]),
{},
),
np.array([False, True]),
np.array([False, False]),
),
),
)
def test_to_done_step_api(
is_vector_env, done_returns, expected_terminated, expected_truncated
):
_, _, terminated, truncated, info = step_to_new_api(
done_returns, is_vector_env=is_vector_env
)
assert np.all(terminated == expected_terminated)
assert np.all(truncated == expected_truncated)

if is_vector_env is False:
assert "TimeLimit.truncated" not in info
elif isinstance(info, list):
assert all("TimeLimit.truncated" not in sub_info for sub_info in info)
else: # isinstance(info, dict)
assert "TimeLimit.truncated" not in info

roundtripped_returns = step_to_old_api(
(0, 0, terminated, truncated, info), is_vector_env=is_vector_env
)
assert data_equivalence(done_returns, roundtripped_returns)


@pytest.mark.parametrize(
"is_vector_env, terminated_truncated_returns, expected_done, expected_truncated",
(
(False, (0, 0, False, False, {"Test-info": True}), False, False),
(False, (0, 0, True, False, {}), True, False),
(False, (0, 0, False, True, {}), True, True),
# (False, (), True, True), # Not possible to encode in the old step api
# Test vector dict info
(
True,
(0, 0, np.array([False, True, False]), np.array([False, False, True]), {}),
np.array([False, True, True]),
np.array([False, False, True]),
),
# Test vector dict info with no truncation
(
True,
(0, 0, np.array([False, True]), np.array([False, False]), {}),
np.array([False, True]),
np.array([False, False]),
),
# Test vector list info
(
True,
(
0,
0,
np.array([False, True, False]),
np.array([False, False, True]),
[{"Test-Info": True}, {}, {}],
),
np.array([False, True, True]),
np.array([False, False, True]),
),
),
)
def test_to_terminated_truncated_step_api(
is_vector_env, terminated_truncated_returns, expected_done, expected_truncated
):
_, _, done, info = step_to_old_api(
terminated_truncated_returns, is_vector_env=is_vector_env
)
assert np.all(done == expected_done)

if is_vector_env is False:
if expected_done:
assert info["TimeLimit.truncated"] == expected_truncated
else:
assert "TimeLimit.truncated" not in info
elif isinstance(info, list):
for sub_info, env_done, env_truncated in zip(
info, expected_done, expected_truncated
):
if env_done:
assert sub_info["TimeLimit.truncated"] == env_truncated
else:
assert "TimeLimit.truncated" not in sub_info
else: # isinstance(info, dict)
if np.any(expected_done):
assert np.all(info["TimeLimit.truncated"] == expected_truncated)
else:
assert "TimeLimit.truncated" not in info

roundtripped_returns = step_to_new_api(
(0, 0, done, info), is_vector_env=is_vector_env
)
assert data_equivalence(terminated_truncated_returns, roundtripped_returns)


def test_edge_case():
# When converting between the two-step APIs this is not possible in a single case
# terminated=True and truncated=True -> done=True and info={}
# We cannot test this in test_to_terminated_truncated_step_api as the roundtripping test will fail
_, _, done, info = step_to_old_api((0, 0, True, True, {}))
assert done is True
assert info == {"TimeLimit.truncated": False}

# Test with vector dict info
_, _, done, info = step_to_old_api(
(0, 0, np.array([True]), np.array([True]), {}), is_vector_env=True
)
assert np.all(done)
assert info == {"TimeLimit.truncated": np.array([False])}

# Test with vector list info
_, _, done, info = step_to_old_api(
(0, 0, np.array([True]), np.array([True]), [{"Test-Info": True}]),
is_vector_env=True,
)
assert np.all(done)
assert info == [{"Test-Info": True, "TimeLimit.truncated": False}]
Loading

0 comments on commit a8d4dd7

Please sign in to comment.