-
Notifications
You must be signed in to change notification settings - Fork 8.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add testing for step api compatibility functions and wrapper (#3028)
* Initial commit
* Fixed tests and forced TimeLimit.truncated to always exist when truncated or terminated
* Fix CI issues
* pre-commit
* Revert back to old language
* Revert changes to step api wrapper
- Loading branch information
1 parent
aa43d13
commit a8d4dd7
Showing
6 changed files
with
247 additions
and
192 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from gym.utils.env_checker import data_equivalence | ||
from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api | ||
|
||
|
||
@pytest.mark.parametrize(
    "is_vector_env, done_returns, expected_terminated, expected_truncated",
    [
        # Single environment: each combination of `done` with and without
        # the legacy "TimeLimit.truncated" info key.
        (False, (0, 0, False, {"Test-info": True}), False, False),
        (False, (0, 0, False, {"TimeLimit.truncated": False}), False, False),
        (False, (0, 0, True, {}), True, False),
        (False, (0, 0, True, {"TimeLimit.truncated": True}), False, True),
        (False, (0, 0, True, {"Test-info": True}), True, False),
        # Vector environment, info given as a list of per-sub-environment dicts.
        (
            True,
            (
                0,
                0,
                np.array([False, True, True]),
                [{}, {}, {"TimeLimit.truncated": True}],
            ),
            np.array([False, True, False]),
            np.array([False, False, True]),
        ),
        # Vector environment, info given as a single dict of arrays.
        (
            True,
            (
                0,
                0,
                np.array([False, True, True]),
                {"TimeLimit.truncated": np.array([False, False, True])},
            ),
            np.array([False, True, False]),
            np.array([False, False, True]),
        ),
        # Vector environment whose info dict carries no truncation entry.
        (
            True,
            (0, 0, np.array([False, True]), {}),
            np.array([False, True]),
            np.array([False, False]),
        ),
    ],
)
def test_to_done_step_api(
    is_vector_env, done_returns, expected_terminated, expected_truncated
):
    """Convert old-style ``(obs, reward, done, info)`` returns to the new API.

    Despite its name, this test drives ``step_to_new_api``. It checks that:

    * the resulting ``terminated`` / ``truncated`` values equal the expected ones,
    * the ``"TimeLimit.truncated"`` key is absent from the converted info, and
    * feeding the converted values through ``step_to_old_api`` gives a result
      that ``data_equivalence`` reports as equivalent to ``done_returns``.
    """
    _obs, _reward, terminated, truncated, info = step_to_new_api(
        done_returns, is_vector_env=is_vector_env
    )
    assert np.all(terminated == expected_terminated)
    assert np.all(truncated == expected_truncated)

    # Only a vector env with list-style info is inspected element by element;
    # every other case is a single container checked directly.
    if is_vector_env is not False and isinstance(info, list):
        infos_to_check = info
    else:
        infos_to_check = [info]
    for checked_info in infos_to_check:
        assert "TimeLimit.truncated" not in checked_info

    # Converting back must reproduce the original old-API returns.
    old_api_again = step_to_old_api(
        (0, 0, terminated, truncated, info), is_vector_env=is_vector_env
    )
    assert data_equivalence(done_returns, old_api_again)
|
||
|
||
@pytest.mark.parametrize(
    "is_vector_env, terminated_truncated_returns, expected_done, expected_truncated",
    [
        # Single environment.
        (False, (0, 0, False, False, {"Test-info": True}), False, False),
        (False, (0, 0, True, False, {}), True, False),
        (False, (0, 0, False, True, {}), True, True),
        # terminated=True together with truncated=True is deliberately left
        # out: it is not possible to encode in the old step API, so the
        # round-trip check below could not hold (see test_edge_case).
        # Vector environment, dict info.
        (
            True,
            (0, 0, np.array([False, True, False]), np.array([False, False, True]), {}),
            np.array([False, True, True]),
            np.array([False, False, True]),
        ),
        # Vector environment, dict info, nothing truncated.
        (
            True,
            (0, 0, np.array([False, True]), np.array([False, False]), {}),
            np.array([False, True]),
            np.array([False, False]),
        ),
        # Vector environment, list info.
        (
            True,
            (
                0,
                0,
                np.array([False, True, False]),
                np.array([False, False, True]),
                [{"Test-Info": True}, {}, {}],
            ),
            np.array([False, True, True]),
            np.array([False, False, True]),
        ),
    ],
)
def test_to_terminated_truncated_step_api(
    is_vector_env, terminated_truncated_returns, expected_done, expected_truncated
):
    """Convert new-style five-element returns to the old ``done`` API.

    Despite its name, this test drives ``step_to_old_api``. It checks that:

    * ``done`` equals the expected value,
    * ``"TimeLimit.truncated"`` is present in the info (with the expected
      value) when the expectation says the episode is done, and absent
      otherwise, and
    * feeding the converted values through ``step_to_new_api`` gives a result
      that ``data_equivalence`` reports as equivalent to the input.
    """
    truncation_key = "TimeLimit.truncated"

    _obs, _reward, done, info = step_to_old_api(
        terminated_truncated_returns, is_vector_env=is_vector_env
    )
    assert np.all(done == expected_done)

    if is_vector_env is False:
        # Single environment: scalar expectations.
        if not expected_done:
            assert truncation_key not in info
        else:
            assert info[truncation_key] == expected_truncated
    elif not isinstance(info, list):
        # Vector environment with one dict of arrays.
        if not np.any(expected_done):
            assert truncation_key not in info
        else:
            assert np.all(info[truncation_key] == expected_truncated)
    else:
        # Vector environment with one info dict per sub-environment.
        per_env = zip(info, expected_done, expected_truncated)
        for env_info, env_done, env_truncated in per_env:
            if not env_done:
                assert truncation_key not in env_info
            else:
                assert env_info[truncation_key] == env_truncated

    # Converting back must reproduce the original new-API returns.
    new_api_again = step_to_new_api((0, 0, done, info), is_vector_env=is_vector_env)
    assert data_equivalence(terminated_truncated_returns, new_api_again)
|
||
|
||
def test_edge_case():
    """Check ``step_to_old_api`` when terminated and truncated are both True.

    The expected result in each variant is ``done=True`` with
    ``"TimeLimit.truncated"`` set to False in the info. This case is kept out
    of ``test_to_terminated_truncated_step_api`` because its round-trip
    assertion would fail for it.
    """
    # Single environment (is_vector_env left at its default).
    _obs, _reward, single_done, single_info = step_to_old_api(
        (0, 0, True, True, {})
    )
    assert single_done is True
    assert single_info == {"TimeLimit.truncated": False}

    # Vector environment with dict info.
    dict_case = (0, 0, np.array([True]), np.array([True]), {})
    _obs, _reward, vector_done, dict_info = step_to_old_api(
        dict_case, is_vector_env=True
    )
    assert np.all(vector_done)
    # NOTE: this dict comparison relies on the array holding a single element.
    assert dict_info == {"TimeLimit.truncated": np.array([False])}

    # Vector environment with list info; existing keys must be kept.
    list_case = (0, 0, np.array([True]), np.array([True]), [{"Test-Info": True}])
    _obs, _reward, vector_done, list_info = step_to_old_api(
        list_case,
        is_vector_env=True,
    )
    assert np.all(vector_done)
    assert list_info == [{"Test-Info": True, "TimeLimit.truncated": False}]
Oops, something went wrong.