Skip to content

Commit a8d4dd7

Browse files
Add testing for step api compatibility functions and wrapper (#3028)
* Initial commit * Fixed tests and forced TimeLimit.truncated to always exist when truncated or terminated * Fix CI issues * pre-commit * Revert back to old language * Revert changes to step api wrapper
1 parent aa43d13 commit a8d4dd7

File tree

6 files changed

+247
-192
lines changed

6 files changed

+247
-192
lines changed

gym/utils/env_checker.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def data_equivalence(data_1, data_2) -> bool:
4545
return data_1.keys() == data_2.keys() and all(
4646
data_equivalence(data_1[k], data_2[k]) for k in data_1.keys()
4747
)
48-
elif isinstance(data_1, tuple):
48+
elif isinstance(data_1, (tuple, list)):
4949
return len(data_1) == len(data_2) and all(
5050
data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
5151
)

gym/utils/step_api_compatibility.py

+74-98
Original file line numberDiff line numberDiff line change
@@ -36,66 +36,41 @@ def step_to_new_api(
3636
assert len(step_returns) == 4
3737
observations, rewards, dones, infos = step_returns
3838

39-
terminateds = []
40-
truncateds = []
41-
if not is_vector_env:
42-
dones = [dones]
43-
44-
for i in range(len(dones)):
45-
# For every condition, handling - info single env / info vector env (list) / info vector env (dict)
46-
47-
# TimeLimit.truncated attribute not present - implies either terminated or episode still ongoing based on `done`
48-
if (not is_vector_env and "TimeLimit.truncated" not in infos) or (
49-
is_vector_env
50-
and (
51-
(
52-
isinstance(infos, list)
53-
and "TimeLimit.truncated" not in infos[i]
54-
) # vector env, list info api
55-
or (
56-
"TimeLimit.truncated" not in infos
57-
or (
58-
"TimeLimit.truncated" in infos
59-
and not infos["TimeLimit.truncated"][i]
60-
)
61-
)
62-
# vector env, dict info api, for env i, vector mask `_TimeLimit.truncated` is not considered, to be compatible with envpool
63-
# For env i, `TimeLimit.truncated` not being present is treated same as being present and set to False.
64-
# therefore, terminated=True, truncated=True simultaneously is not allowed while using compatibility functions
65-
# with vector info
66-
)
67-
):
68-
terminateds.append(dones[i])
69-
truncateds.append(False)
70-
71-
# This means info["TimeLimit.truncated"] exists and this elif checks if it is True, which means the truncation has occurred but termination has not.
72-
elif (
73-
infos["TimeLimit.truncated"]
74-
if not is_vector_env
75-
else (
76-
infos["TimeLimit.truncated"][i]
77-
if isinstance(infos, dict)
78-
else infos[i]["TimeLimit.truncated"]
79-
)
80-
):
81-
assert dones[i]
82-
terminateds.append(False)
83-
truncateds.append(True)
84-
else:
85-
# This means info["TimeLimit.truncated"] exists but is False, which means the core environment had already terminated,
86-
# but it also exceeded maximum timesteps at the same step. However to be compatible with envpool, and to be backward compatible
87-
# truncated is set to False here.
88-
assert dones[i]
89-
terminateds.append(True)
90-
truncateds.append(False)
91-
92-
return (
93-
observations,
94-
rewards,
95-
np.array(terminateds, dtype=np.bool_) if is_vector_env else terminateds[0],
96-
np.array(truncateds, dtype=np.bool_) if is_vector_env else truncateds[0],
97-
infos,
98-
)
39+
# Cases to handle - info single env / info vector env (list) / info vector env (dict)
40+
if is_vector_env is False:
41+
truncated = infos.pop("TimeLimit.truncated", False)
42+
return (
43+
observations,
44+
rewards,
45+
dones and not truncated,
46+
dones and truncated,
47+
infos,
48+
)
49+
elif isinstance(infos, list):
50+
truncated = np.array(
51+
[info.pop("TimeLimit.truncated", False) for info in infos]
52+
)
53+
return (
54+
observations,
55+
rewards,
56+
np.logical_and(dones, np.logical_not(truncated)),
57+
np.logical_and(dones, truncated),
58+
infos,
59+
)
60+
elif isinstance(infos, dict):
61+
num_envs = len(dones)
62+
truncated = infos.pop("TimeLimit.truncated", np.zeros(num_envs, dtype=bool))
63+
return (
64+
observations,
65+
rewards,
66+
np.logical_and(dones, np.logical_not(truncated)),
67+
np.logical_and(dones, truncated),
68+
infos,
69+
)
70+
else:
71+
raise TypeError(
72+
f"Unexpected value of infos, as is_vector_envs=False, expects `info` to be a list or dict, actual type: {type(infos)}"
73+
)
9974

10075

10176
def step_to_old_api(
@@ -111,44 +86,45 @@ def step_to_old_api(
11186
return step_returns
11287
else:
11388
assert len(step_returns) == 5
114-
observations, rewards, terminateds, truncateds, infos = step_returns
115-
dones = []
116-
if not is_vector_env:
117-
terminateds = [terminateds]
118-
truncateds = [truncateds]
119-
120-
n_envs = len(terminateds)
121-
122-
for i in range(n_envs):
123-
dones.append(terminateds[i] or truncateds[i])
124-
if truncateds[i]:
125-
if is_vector_env:
126-
# handle vector infos for dict and list
127-
if isinstance(infos, dict):
128-
if "TimeLimit.truncated" not in infos:
129-
# TODO: This should ideally not be done manually and should use vector_env's _add_info()
130-
infos["TimeLimit.truncated"] = np.zeros(n_envs, dtype=bool)
131-
infos["_TimeLimit.truncated"] = np.zeros(n_envs, dtype=bool)
132-
133-
infos["TimeLimit.truncated"][i] = (
134-
not terminateds[i] or infos["TimeLimit.truncated"][i]
135-
)
136-
infos["_TimeLimit.truncated"][i] = True
137-
else:
138-
# if vector info is a list
139-
infos[i]["TimeLimit.truncated"] = not terminateds[i] or infos[
140-
i
141-
].get("TimeLimit.truncated", False)
142-
else:
143-
infos["TimeLimit.truncated"] = not terminateds[i] or infos.get(
144-
"TimeLimit.truncated", False
145-
)
146-
return (
147-
observations,
148-
rewards,
149-
np.array(dones, dtype=np.bool_) if is_vector_env else dones[0],
150-
infos,
151-
)
89+
observations, rewards, terminated, truncated, infos = step_returns
90+
91+
# Cases to handle - info single env / info vector env (list) / info vector env (dict)
92+
if is_vector_env is False:
93+
if truncated or terminated:
94+
infos["TimeLimit.truncated"] = truncated and not terminated
95+
return (
96+
observations,
97+
rewards,
98+
terminated or truncated,
99+
infos,
100+
)
101+
elif isinstance(infos, list):
102+
for info, env_truncated, env_terminated in zip(
103+
infos, truncated, terminated
104+
):
105+
if env_truncated or env_terminated:
106+
info["TimeLimit.truncated"] = env_truncated and not env_terminated
107+
return (
108+
observations,
109+
rewards,
110+
np.logical_or(terminated, truncated),
111+
infos,
112+
)
113+
elif isinstance(infos, dict):
114+
if np.logical_or(np.any(truncated), np.any(terminated)):
115+
infos["TimeLimit.truncated"] = np.logical_and(
116+
truncated, np.logical_not(terminated)
117+
)
118+
return (
119+
observations,
120+
rewards,
121+
np.logical_or(terminated, truncated),
122+
infos,
123+
)
124+
else:
125+
raise TypeError(
126+
f"Unexpected value of infos, as is_vector_envs=False, expects `info` to be a list or dict, actual type: {type(infos)}"
127+
)
152128

153129

154130
def step_api_compatibility(

gym/wrappers/time_limit.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(
3434
3535
Args:
3636
env: The environment to apply the wrapper
37-
max_episode_steps: An optional max episode steps (if ``Ǹone``, ``env.spec.max_episode_steps`` is used)
37+
max_episode_steps: An optional max episode steps (if ``None``, ``env.spec.max_episode_steps`` is used)
3838
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
3939
"""
4040
super().__init__(env, new_step_api)
@@ -63,7 +63,10 @@ def step(self, action):
6363
self._elapsed_steps += 1
6464

6565
if self._elapsed_steps >= self._max_episode_steps:
66-
truncated = True
66+
if self.new_step_api is True or terminated is False:
67+
# As the old step api cannot encode both terminated and truncated, we favor terminated in the case of both.
68+
# Therefore, if new step api (i.e. not old step api) or when terminated is False to prevent the overriding
69+
truncated = True
6770

6871
return step_api_compatibility(
6972
(observation, reward, terminated, truncated, info),
+166
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
import numpy as np
2+
import pytest
3+
4+
from gym.utils.env_checker import data_equivalence
5+
from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api
6+
7+
8+
@pytest.mark.parametrize(
9+
"is_vector_env, done_returns, expected_terminated, expected_truncated",
10+
(
11+
# Test each of the permutations for single environments with and without the old info
12+
(False, (0, 0, False, {"Test-info": True}), False, False),
13+
(False, (0, 0, False, {"TimeLimit.truncated": False}), False, False),
14+
(False, (0, 0, True, {}), True, False),
15+
(False, (0, 0, True, {"TimeLimit.truncated": True}), False, True),
16+
(False, (0, 0, True, {"Test-info": True}), True, False),
17+
# Test vectorise versions with both list and dict infos testing each permutation for sub-environments
18+
(
19+
True,
20+
(
21+
0,
22+
0,
23+
np.array([False, True, True]),
24+
[{}, {}, {"TimeLimit.truncated": True}],
25+
),
26+
np.array([False, True, False]),
27+
np.array([False, False, True]),
28+
),
29+
(
30+
True,
31+
(
32+
0,
33+
0,
34+
np.array([False, True, True]),
35+
{"TimeLimit.truncated": np.array([False, False, True])},
36+
),
37+
np.array([False, True, False]),
38+
np.array([False, False, True]),
39+
),
40+
# empty truncated info
41+
(
42+
True,
43+
(
44+
0,
45+
0,
46+
np.array([False, True]),
47+
{},
48+
),
49+
np.array([False, True]),
50+
np.array([False, False]),
51+
),
52+
),
53+
)
54+
def test_to_done_step_api(
55+
is_vector_env, done_returns, expected_terminated, expected_truncated
56+
):
57+
_, _, terminated, truncated, info = step_to_new_api(
58+
done_returns, is_vector_env=is_vector_env
59+
)
60+
assert np.all(terminated == expected_terminated)
61+
assert np.all(truncated == expected_truncated)
62+
63+
if is_vector_env is False:
64+
assert "TimeLimit.truncated" not in info
65+
elif isinstance(info, list):
66+
assert all("TimeLimit.truncated" not in sub_info for sub_info in info)
67+
else: # isinstance(info, dict)
68+
assert "TimeLimit.truncated" not in info
69+
70+
roundtripped_returns = step_to_old_api(
71+
(0, 0, terminated, truncated, info), is_vector_env=is_vector_env
72+
)
73+
assert data_equivalence(done_returns, roundtripped_returns)
74+
75+
76+
@pytest.mark.parametrize(
77+
"is_vector_env, terminated_truncated_returns, expected_done, expected_truncated",
78+
(
79+
(False, (0, 0, False, False, {"Test-info": True}), False, False),
80+
(False, (0, 0, True, False, {}), True, False),
81+
(False, (0, 0, False, True, {}), True, True),
82+
# (False, (), True, True), # Not possible to encode in the old step api
83+
# Test vector dict info
84+
(
85+
True,
86+
(0, 0, np.array([False, True, False]), np.array([False, False, True]), {}),
87+
np.array([False, True, True]),
88+
np.array([False, False, True]),
89+
),
90+
# Test vector dict info with no truncation
91+
(
92+
True,
93+
(0, 0, np.array([False, True]), np.array([False, False]), {}),
94+
np.array([False, True]),
95+
np.array([False, False]),
96+
),
97+
# Test vector list info
98+
(
99+
True,
100+
(
101+
0,
102+
0,
103+
np.array([False, True, False]),
104+
np.array([False, False, True]),
105+
[{"Test-Info": True}, {}, {}],
106+
),
107+
np.array([False, True, True]),
108+
np.array([False, False, True]),
109+
),
110+
),
111+
)
112+
def test_to_terminated_truncated_step_api(
113+
is_vector_env, terminated_truncated_returns, expected_done, expected_truncated
114+
):
115+
_, _, done, info = step_to_old_api(
116+
terminated_truncated_returns, is_vector_env=is_vector_env
117+
)
118+
assert np.all(done == expected_done)
119+
120+
if is_vector_env is False:
121+
if expected_done:
122+
assert info["TimeLimit.truncated"] == expected_truncated
123+
else:
124+
assert "TimeLimit.truncated" not in info
125+
elif isinstance(info, list):
126+
for sub_info, env_done, env_truncated in zip(
127+
info, expected_done, expected_truncated
128+
):
129+
if env_done:
130+
assert sub_info["TimeLimit.truncated"] == env_truncated
131+
else:
132+
assert "TimeLimit.truncated" not in sub_info
133+
else: # isinstance(info, dict)
134+
if np.any(expected_done):
135+
assert np.all(info["TimeLimit.truncated"] == expected_truncated)
136+
else:
137+
assert "TimeLimit.truncated" not in info
138+
139+
roundtripped_returns = step_to_new_api(
140+
(0, 0, done, info), is_vector_env=is_vector_env
141+
)
142+
assert data_equivalence(terminated_truncated_returns, roundtripped_returns)
143+
144+
145+
def test_edge_case():
146+
# When converting between the two-step APIs this is not possible in a single case
147+
# terminated=True and truncated=True -> done=True and info={}
148+
# We cannot test this in test_to_terminated_truncated_step_api as the roundtripping test will fail
149+
_, _, done, info = step_to_old_api((0, 0, True, True, {}))
150+
assert done is True
151+
assert info == {"TimeLimit.truncated": False}
152+
153+
# Test with vector dict info
154+
_, _, done, info = step_to_old_api(
155+
(0, 0, np.array([True]), np.array([True]), {}), is_vector_env=True
156+
)
157+
assert np.all(done)
158+
assert info == {"TimeLimit.truncated": np.array([False])}
159+
160+
# Test with vector list info
161+
_, _, done, info = step_to_old_api(
162+
(0, 0, np.array([True]), np.array([True]), [{"Test-Info": True}]),
163+
is_vector_env=True,
164+
)
165+
assert np.all(done)
166+
assert info == [{"Test-Info": True, "TimeLimit.truncated": False}]

0 commit comments

Comments
 (0)