From fc1a70ff16ac95918196ae26a842112ce98f03bc Mon Sep 17 00:00:00 2001 From: StringTheory Date: Tue, 16 Aug 2022 03:20:35 +0100 Subject: [PATCH 1/6] Initial commit --- gym/utils/env_checker.py | 2 +- gym/wrappers/step_api_compatibility.py | 31 ++-- gym/wrappers/time_limit.py | 11 +- tests/utils/test_step_api_compatibility.py | 160 +++++++++++++++++++++ tests/utils/test_terminated_truncated.py | 91 ------------ 5 files changed, 185 insertions(+), 110 deletions(-) create mode 100644 tests/utils/test_step_api_compatibility.py delete mode 100644 tests/utils/test_terminated_truncated.py diff --git a/gym/utils/env_checker.py b/gym/utils/env_checker.py index 95355e2a6e5..5a5564b571b 100644 --- a/gym/utils/env_checker.py +++ b/gym/utils/env_checker.py @@ -45,7 +45,7 @@ def data_equivalence(data_1, data_2) -> bool: return data_1.keys() == data_2.keys() and all( data_equivalence(data_1[k], data_2[k]) for k in data_1.keys() ) - elif isinstance(data_1, tuple): + elif isinstance(data_1, (tuple, list)): return len(data_1) == len(data_2) and all( data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2) ) diff --git a/gym/wrappers/step_api_compatibility.py b/gym/wrappers/step_api_compatibility.py index 72d4c8a1e07..bddb3a0a847 100644 --- a/gym/wrappers/step_api_compatibility.py +++ b/gym/wrappers/step_api_compatibility.py @@ -1,7 +1,10 @@ """Implementation of StepAPICompatibility wrapper class for transforming envs between new and old step API.""" import gym from gym.logger import deprecation -from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api +from gym.utils.step_api_compatibility import ( + to_done_step_api, + to_terminated_truncated_step_api, +) class StepAPICompatibility(gym.Wrapper): @@ -15,33 +18,33 @@ class StepAPICompatibility(gym.Wrapper): Args: env (gym.Env): the env to wrap. Can be in old or new API - new_step_api (bool): True to use env with new step API, False to use env with old step API. (False by default) + to_termination_truncation_api (bool): True to use env with new step API, False to use env with old step API. (False by default) Examples: >>> env = gym.make("CartPole-v1") >>> env # wrapper applied by default, set to old API >>>> - >>> env = gym.make("CartPole-v1", new_step_api=True) # set to new API - >>> env = StepAPICompatibility(CustomEnv(), new_step_api=True) # manually using wrapper on unregistered envs + >>> env = gym.make("CartPole-v1", to_termination_truncation_api=True) # set to new API + >>> env = StepAPICompatibility(CustomEnv(), to_termination_truncation_api=True) # manually using wrapper on unregistered envs """ - def __init__(self, env: gym.Env, new_step_api=False): + def __init__(self, env: gym.Env, to_termination_truncation_api: bool = False): """A wrapper which can transform an environment from new step API to old and vice-versa. Args: env (gym.Env): the env to wrap. Can be in old or new API - new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) + to_termination_truncation_api (bool): Whether the wrapper's step method outputs two booleans (termination and truncation) with True, or one boolean (done) with False """ - super().__init__(env, new_step_api) - self.new_step_api = new_step_api - if not self.new_step_api: + super().__init__(env, to_termination_truncation_api) + self.to_termination_truncation_api = to_termination_truncation_api + if self.to_termination_truncation_api is False: deprecation( - "Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future." + "Initializing environment in old done step API which returns one bool instead of two. It is recommended to set `to_termination_truncation_api=True` to use new step API. This will be the default behaviour in future." ) def step(self, action): - """Steps through the environment, returning 5 or 4 items depending on `new_step_api`. + """Steps through the environment, returning 5 or 4 items depending on `to_termination_truncation_api`. Args: action: action to step through the environment with @@ -50,7 +53,7 @@ def step(self, action): (observation, reward, terminated, truncated, info) or (observation, reward, done, info) """ step_returns = self.env.step(action) - if self.new_step_api: - return step_to_new_api(step_returns) + if self.to_termination_truncation_api: + return to_terminated_truncated_step_api(step_returns) else: - return step_to_old_api(step_returns) + return to_done_step_api(step_returns) diff --git a/gym/wrappers/time_limit.py b/gym/wrappers/time_limit.py index 8e9f67f4ae9..df985f50b08 100644 --- a/gym/wrappers/time_limit.py +++ b/gym/wrappers/time_limit.py @@ -34,7 +34,7 @@ def __init__( Args: env: The environment to apply the wrapper - max_episode_steps: An optional max episode steps (if ``Ǹone``, ``env.spec.max_episode_steps`` is used) + max_episode_steps: An optional max episode steps (if ``None``, ``env.spec.max_episode_steps`` is used) new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ super().__init__(env, new_step_api) @@ -58,16 +58,19 @@ def step(self, action): """ observation, reward, terminated, truncated, info = step_api_compatibility( self.env.step(action), - True, + to_termination_truncation=True, ) self._elapsed_steps += 1 if self._elapsed_steps >= self._max_episode_steps: - truncated = True + if self.new_step_api is True or terminated is False: + # As the old step api cannot encode both terminated and truncated, we favor terminated in the case of both. + # Therefore, if new step api (i.e. not old step api) or when terminated is False to prevent the overriding + truncated = True return step_api_compatibility( (observation, reward, terminated, truncated, info), - self.new_step_api, + to_termination_truncation=self.new_step_api, ) def reset(self, **kwargs): diff --git a/tests/utils/test_step_api_compatibility.py b/tests/utils/test_step_api_compatibility.py new file mode 100644 index 00000000000..fa8e4d9b65a --- /dev/null +++ b/tests/utils/test_step_api_compatibility.py @@ -0,0 +1,160 @@ +import numpy as np +import pytest + +from gym.utils.env_checker import data_equivalence +from gym.utils.step_api_compatibility import ( + to_done_step_api, + to_terminated_truncated_step_api, +) + + +@pytest.mark.parametrize( + "is_vector_env, done_returns, expected_terminated, expected_truncated", + ( + # Test each of the permutations for single environments with and without the old info + (False, (0, 0, False, {"Test-info": True}), False, False), + (False, (0, 0, False, {"TimeLimit.truncated": False}), False, False), + (False, (0, 0, True, {}), True, False), + (False, (0, 0, True, {"TimeLimit.truncated": True}), False, True), + (False, (0, 0, True, {"Test-info": True}), True, False), + # Test vectorise versions with both list and dict infos testing each permutation for sub-environments + ( + True, + ( + 0, + 0, + np.array([False, True, True]), + [{}, {}, {"TimeLimit.truncated": True}], + ), + np.array([False, True, False]), + np.array([False, False, True]), + ), + ( + True, + ( + 0, + 0, + np.array([False, True, True]), + {"TimeLimit.truncated": np.array([False, False, True])}, + ), + np.array([False, True, False]), + np.array([False, False, True]), + ), + # empty truncated info + ( + True, + ( + 0, + 0, + np.array([False, True]), + {}, + ), + np.array([False, True]), + np.array([False, False]), + ), + ), +) +def test_to_done_step_api( + is_vector_env, done_returns, expected_terminated, expected_truncated +): + _, _, terminated, truncated, info = to_terminated_truncated_step_api( + done_returns, is_vector_env=is_vector_env + ) + assert np.all(terminated == expected_terminated) + assert np.all(truncated == expected_truncated) + if is_vector_env is False: + assert "TimeLimit.truncated" not in info + elif isinstance(info, list): + assert all("TimeLimit.truncated" not in sub_info for sub_info in info) + else: # isinstance(info, dict) + assert "TimeLimit.truncated" not in info + + roundtripped_returns = to_done_step_api( + (0, 0, terminated, truncated, info), is_vector_env=is_vector_env + ) + assert data_equivalence(done_returns, roundtripped_returns) + + +@pytest.mark.parametrize( + "is_vector_env, terminated_truncated_returns, expected_done, expected_truncated", + ( + (False, (0, 0, False, False, {"Test-info": True}), False, False), + (False, (0, 0, True, False, {}), True, False), + (False, (0, 0, False, True, {}), True, True), + # (False, (), True, True), # Not possible to encode in the old step api + # Test vector dict info + ( + True, + (0, 0, np.array([False, True, False]), np.array([False, False, True]), {}), + np.array([False, True, True]), + np.array([False, False, True]), + ), + # Test vector dict info with no truncation + ( + True, + (0, 0, np.array([False, True]), np.array([False, False]), {}), + np.array([False, True]), + np.array([False, False]), + ), + # Test vector list info + ( + True, + ( + 0, + 0, + np.array([False, True, False]), + np.array([False, False, True]), + [{}, {}, {}], + ), + np.array([False, True, True]), + np.array([False, False, True]), + ), + ), +) +def test_to_terminated_truncated_step_api( + is_vector_env, terminated_truncated_returns, expected_done, expected_truncated +): + _, _, done, info = to_done_step_api( + terminated_truncated_returns, is_vector_env=is_vector_env + ) + assert np.all(done == expected_done) + if is_vector_env is False: + if expected_truncated: + assert info["TimeLimit.truncated"] == expected_truncated + else: + assert "TimeLimit.truncated" not in info + elif isinstance(info, list): + for sub_info, trunc in zip(info, expected_truncated): + if trunc: + assert sub_info["TimeLimit.truncated"] == trunc + else: + assert "TimeLimit.truncated" not in sub_info + else: # isinstance(info, dict) + if np.any(expected_truncated): + assert np.all(info["TimeLimit.truncated"] == expected_truncated) + else: + assert "TimeLimit.truncated" not in info + + roundtripped_returns = to_terminated_truncated_step_api( + (0, 0, done, info), is_vector_env=is_vector_env + ) + assert data_equivalence(terminated_truncated_returns, roundtripped_returns) + + +def test_edge_case(): + # When converting between the two-step APIs this is not possible in a single case + # terminated=True and truncated=True -> done=True and info={} + # We cannot test this in test_to_terminated_truncated_step_api as the roundtripping test will fail + _, _, done, info = to_done_step_api((0, 0, True, True, {})) + assert done is True + assert info == {} + + _, _, done, info = to_done_step_api((0, 0, np.array([True]), np.array([True]), {})) + assert np.all(done) + assert info == {} + + _, _, done, info = to_done_step_api( + (0, 0, np.array([True]), np.array([True]), [{"Test-Info": True}]) + ) + assert np.all(done) + assert info == [{"Test-Info": True}] diff --git a/tests/utils/test_terminated_truncated.py b/tests/utils/test_terminated_truncated.py deleted file mode 100644 index e74fdc85378..00000000000 --- a/tests/utils/test_terminated_truncated.py +++ /dev/null @@ -1,91 +0,0 @@ -import pytest - -import gym -from gym.spaces import Discrete -from gym.vector import AsyncVectorEnv, SyncVectorEnv -from gym.wrappers import TimeLimit - - -# An environment where termination happens after 20 steps -class DummyEnv(gym.Env): - def __init__(self): - self.action_space = Discrete(2) - self.observation_space = Discrete(2) - self.terminal_timestep = 20 - - self.timestep = 0 - - def step(self, action): - self.timestep += 1 - terminated = True if self.timestep >= self.terminal_timestep else False - truncated = False - - return 0, 0, terminated, truncated, {} - - def reset(self): - self.timestep = 0 - return 0 - - -@pytest.mark.parametrize("time_limit", [10, 20, 30]) -def test_terminated_truncated(time_limit): - test_env = TimeLimit(DummyEnv(), time_limit, new_step_api=True) - - terminated = False - truncated = False - test_env.reset() - while not (terminated or truncated): - _, _, terminated, truncated, _ = test_env.step(0) - - if test_env.terminal_timestep < time_limit: - assert terminated - assert not truncated - elif test_env.terminal_timestep == time_limit: - assert ( - terminated - ), "`terminated` should be True even when termination and truncation happen at the same step" - assert ( - truncated - ), "`truncated` should be True even when termination and truncation occur at same step " - else: - assert not terminated - assert truncated - - -def test_terminated_truncated_vector(): - env0 = TimeLimit(DummyEnv(), 10, new_step_api=True) - env1 = TimeLimit(DummyEnv(), 20, new_step_api=True) - env2 = TimeLimit(DummyEnv(), 30, new_step_api=True) - - async_env = AsyncVectorEnv( - [lambda: env0, lambda: env1, lambda: env2], new_step_api=True - ) - async_env.reset() - terminateds = [False, False, False] - truncateds = [False, False, False] - counter = 0 - while not all([x or y for x, y in zip(terminateds, truncateds)]): - counter += 1 - _, _, terminateds, truncateds, _ = async_env.step( - async_env.action_space.sample() - ) - print(counter) - assert counter == 20 - assert all(terminateds == [False, True, True]) - assert all(truncateds == [True, True, False]) - - sync_env = SyncVectorEnv( - [lambda: env0, lambda: env1, lambda: env2], new_step_api=True - ) - sync_env.reset() - terminateds = [False, False, False] - truncateds = [False, False, False] - counter = 0 - while not all([x or y for x, y in zip(terminateds, truncateds)]): - counter += 1 - _, _, terminateds, truncateds, _ = sync_env.step( - async_env.action_space.sample() - ) - assert counter == 20 - assert all(terminateds == [False, True, True]) - assert all(truncateds == [True, True, False]) From 84f1829142eca5f88aeb73a1d24c591353d41ed6 Mon Sep 17 00:00:00 2001 From: StringTheory Date: Tue, 16 Aug 2022 11:36:42 +0100 Subject: [PATCH 2/6] Fixed tests and forced TimeLimit.truncated to always exist when truncated or terminated --- gym/utils/step_api_compatibility.py | 210 ++++++++++----------- gym/vector/sync_vector_env.py | 2 +- tests/utils/test_step_api_compatibility.py | 31 +-- tests/wrappers/test_step_compatibility.py | 2 +- 4 files changed, 118 insertions(+), 127 deletions(-) diff --git a/gym/utils/step_api_compatibility.py b/gym/utils/step_api_compatibility.py index 2be07dbe35c..58d794f460a 100644 --- a/gym/utils/step_api_compatibility.py +++ b/gym/utils/step_api_compatibility.py @@ -5,14 +5,14 @@ from gym.core import ObsType -OldStepType = Tuple[ +DoneStepType = Tuple[ Union[ObsType, np.ndarray], Union[float, np.ndarray], Union[bool, np.ndarray], Union[dict, list], ] -NewStepType = Tuple[ +TerminationTruncationStepType = Tuple[ Union[ObsType, np.ndarray], Union[float, np.ndarray], Union[bool, np.ndarray], @@ -21,9 +21,10 @@ ] -def step_to_new_api( - step_returns: Union[OldStepType, NewStepType], is_vector_env=False -) -> NewStepType: +def to_terminated_truncated_step_api( + step_returns: Union[DoneStepType, TerminationTruncationStepType], + is_vector_env=False, +) -> TerminationTruncationStepType: """Function to transform step returns to new step API irrespective of input API. Args: @@ -36,67 +37,47 @@ def step_to_new_api( assert len(step_returns) == 4 observations, rewards, dones, infos = step_returns - terminateds = [] - truncateds = [] - if not is_vector_env: - dones = [dones] - - for i in range(len(dones)): - # For every condition, handling - info single env / info vector env (list) / info vector env (dict) - - # TimeLimit.truncated attribute not present - implies either terminated or episode still ongoing based on `done` - if (not is_vector_env and "TimeLimit.truncated" not in infos) or ( - is_vector_env - and ( - ( - isinstance(infos, list) - and "TimeLimit.truncated" not in infos[i] - ) # vector env, list info api - or ( - "TimeLimit.truncated" not in infos - or ( - "TimeLimit.truncated" in infos - and not infos["_TimeLimit.truncated"][i] - ) - ) # vector env, dict info api, if mask is False, it's the same as TimeLimit.truncated attribute not being present for env 'i' - ) - ): - - terminateds.append(dones[i]) - truncateds.append(False) - - # This means info["TimeLimit.truncated"] exists and this elif checks if it is True, which means the truncation has occurred but termination has not. - elif ( - infos["TimeLimit.truncated"] - if not is_vector_env - else ( - infos["TimeLimit.truncated"][i] - if isinstance(infos, dict) - else infos[i]["TimeLimit.truncated"] - ) - ): - assert dones[i] - terminateds.append(False) - truncateds.append(True) - else: - # This means info["TimeLimit.truncated"] exists but is False, which means the core environment had already terminated, - # but it also exceeded maximum timesteps at the same step. - assert dones[i] - terminateds.append(True) - truncateds.append(True) - - return ( - observations, - rewards, - np.array(terminateds, dtype=np.bool_) if is_vector_env else terminateds[0], - np.array(truncateds, dtype=np.bool_) if is_vector_env else truncateds[0], - infos, - ) - - -def step_to_old_api( - step_returns: Union[NewStepType, OldStepType], is_vector_env: bool = False -) -> OldStepType: + # Cases to handle - info single env / info vector env (list) / info vector env (dict) + if is_vector_env is False: + truncated = infos.pop("TimeLimit.truncated", False) + return ( + observations, + rewards, + dones and not truncated, + dones and truncated, + infos, + ) + elif isinstance(infos, list): + truncated = np.array( + [info.pop("TimeLimit.truncated", False) for info in infos] + ) + return ( + observations, + rewards, + np.logical_and(dones, np.logical_not(truncated)), + np.logical_and(dones, truncated), + infos, + ) + elif isinstance(infos, dict): + num_envs = len(dones) + truncated = infos.pop("TimeLimit.truncated", np.zeros(num_envs, dtype=bool)) + return ( + observations, + rewards, + np.logical_and(dones, np.logical_not(truncated)), + np.logical_and(dones, truncated), + infos, + ) + else: + raise TypeError( + f"Unexpected value of infos, as is_vector_envs=False, expects `info` to be a list or dict, actual type: {type(infos)}" + ) + + +def to_done_step_api( + step_returns: Union[TerminationTruncationStepType, DoneStepType], + is_vector_env: bool = False, +) -> DoneStepType: """Function to transform step returns to old step API irrespective of input API. Args: @@ -107,60 +88,61 @@ def step_to_old_api( return step_returns else: assert len(step_returns) == 5 - observations, rewards, terminateds, truncateds, infos = step_returns - dones = [] - if not is_vector_env: - terminateds = [terminateds] - truncateds = [truncateds] - - n_envs = len(terminateds) - - for i in range(n_envs): - dones.append(terminateds[i] or truncateds[i]) - if truncateds[i]: - if is_vector_env: - # handle vector infos for dict and list - if isinstance(infos, dict): - if "TimeLimit.truncated" not in infos: - # TODO: This should ideally not be done manually and should use vector_env's _add_info() - infos["TimeLimit.truncated"] = np.zeros(n_envs, dtype=bool) - infos["_TimeLimit.truncated"] = np.zeros(n_envs, dtype=bool) - - infos["TimeLimit.truncated"][i] = ( - not terminateds[i] or infos["TimeLimit.truncated"][i] - ) - infos["_TimeLimit.truncated"][i] = True - else: - # if vector info is a list - infos[i]["TimeLimit.truncated"] = not terminateds[i] or infos[ - i - ].get("TimeLimit.truncated", False) - else: - infos["TimeLimit.truncated"] = not terminateds[i] or infos.get( - "TimeLimit.truncated", False - ) - return ( - observations, - rewards, - np.array(dones, dtype=np.bool_) if is_vector_env else dones[0], - infos, - ) + observations, rewards, terminated, truncated, infos = step_returns + + # Cases to handle - info single env / info vector env (list) / info vector env (dict) + if is_vector_env is False: + if truncated or terminated: + infos["TimeLimit.truncated"] = truncated and not terminated + return ( + observations, + rewards, + terminated or truncated, + infos, + ) + elif isinstance(infos, list): + for info, env_truncated, env_terminated in zip( + infos, truncated, terminated + ): + if env_truncated or env_terminated: + info["TimeLimit.truncated"] = env_truncated and not env_terminated + return ( + observations, + rewards, + np.logical_or(terminated, truncated), + infos, + ) + elif isinstance(infos, dict): + if np.logical_or(np.any(truncated), np.any(terminated)): + infos["TimeLimit.truncated"] = np.logical_and( + truncated, np.logical_not(terminated) + ) + return ( + observations, + rewards, + np.logical_or(terminated, truncated), + infos, + ) + else: + raise TypeError( + f"Unexpected value of infos, as is_vector_envs=False, expects `info` to be a list or dict, actual type: {type(infos)}" + ) def step_api_compatibility( - step_returns: Union[NewStepType, OldStepType], - new_step_api: bool = False, + step_returns: Union[TerminationTruncationStepType, DoneStepType], + to_termination_truncation: bool = False, is_vector_env: bool = False, -) -> Union[NewStepType, OldStepType]: +) -> Union[TerminationTruncationStepType, DoneStepType]: """Function to transform step returns to the API specified by `new_step_api` bool. - Old step API refers to step() method returning (observation, reward, done, info) - New step API refers to step() method returning (observation, reward, terminated, truncated, info) + Done step API refers to step() method returning (observation, reward, done, info) + Termination Truncation step API refers to step() method returning (observation, reward, terminated, truncated, info) (Refer to docs for details on the API change) Args: step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info) - new_step_api (bool): Whether the output should be in new step API or old (False by default) + to_termination_truncation (bool): Whether the output should be in new step API or old (False by default) is_vector_env (bool): Whether the step_returns are from a vector environment Returns: @@ -171,10 +153,10 @@ def step_api_compatibility( wrapper is written in new API, and the final step output is desired to be in old API. >>> obs, rew, done, info = step_api_compatibility(env.step(action)) - >>> obs, rew, terminated, truncated, info = step_api_compatibility(env.step(action), new_step_api=True) + >>> obs, rew, terminated, truncated, info = step_api_compatibility(env.step(action), to_termination_truncation=True) >>> observations, rewards, dones, infos = step_api_compatibility(vec_env.step(action), is_vector_env=True) """ - if new_step_api: - return step_to_new_api(step_returns, is_vector_env) + if to_termination_truncation: + return to_terminated_truncated_step_api(step_returns, is_vector_env) else: - return step_to_old_api(step_returns, is_vector_env) + return to_done_step_api(step_returns, is_vector_env) diff --git a/gym/vector/sync_vector_env.py b/gym/vector/sync_vector_env.py index cc3408e7adb..d6bdb793192 100644 --- a/gym/vector/sync_vector_env.py +++ b/gym/vector/sync_vector_env.py @@ -180,7 +180,7 @@ def step_wait(self): np.copy(self._truncateds), infos, ), - new_step_api=self.new_step_api, + to_termination_truncation=self.new_step_api, is_vector_env=True, ) diff --git a/tests/utils/test_step_api_compatibility.py b/tests/utils/test_step_api_compatibility.py index fa8e4d9b65a..c66157e2138 100644 --- a/tests/utils/test_step_api_compatibility.py +++ b/tests/utils/test_step_api_compatibility.py @@ -62,6 +62,7 @@ def test_to_done_step_api( ) assert np.all(terminated == expected_terminated) assert np.all(truncated == expected_truncated) + if is_vector_env is False: assert "TimeLimit.truncated" not in info elif isinstance(info, list): @@ -104,7 +105,7 @@ def test_to_done_step_api( 0, np.array([False, True, False]), np.array([False, False, True]), - [{}, {}, {}], + [{"Test-Info": True}, {}, {}], ), np.array([False, True, True]), np.array([False, False, True]), @@ -118,19 +119,22 @@ def test_to_terminated_truncated_step_api( terminated_truncated_returns, is_vector_env=is_vector_env ) assert np.all(done == expected_done) + if is_vector_env is False: - if expected_truncated: + if expected_done: assert info["TimeLimit.truncated"] == expected_truncated else: assert "TimeLimit.truncated" not in info elif isinstance(info, list): - for sub_info, trunc in zip(info, expected_truncated): - if trunc: - assert sub_info["TimeLimit.truncated"] == trunc + for sub_info, env_done, env_truncated in zip( + info, expected_done, expected_truncated + ): + if env_done: + assert sub_info["TimeLimit.truncated"] == env_truncated else: assert "TimeLimit.truncated" not in sub_info else: # isinstance(info, dict) - if np.any(expected_truncated): + if np.any(expected_done): assert np.all(info["TimeLimit.truncated"] == expected_truncated) else: assert "TimeLimit.truncated" not in info @@ -147,14 +151,19 @@ def test_edge_case(): # We cannot test this in test_to_terminated_truncated_step_api as the roundtripping test will fail _, _, done, info = to_done_step_api((0, 0, True, True, {})) assert done is True - assert info == {} + assert info == {"TimeLimit.truncated": False} - _, _, done, info = to_done_step_api((0, 0, np.array([True]), np.array([True]), {})) + # Test with vector dict info + _, _, done, info = to_done_step_api( + (0, 0, np.array([True]), np.array([True]), {}), is_vector_env=True + ) assert np.all(done) - assert info == {} + assert info == {"TimeLimit.truncated": np.array([False])} + # Test with vector list info _, _, done, info = to_done_step_api( - (0, 0, np.array([True]), np.array([True]), [{"Test-Info": True}]) + (0, 0, np.array([True]), np.array([True]), [{"Test-Info": True}]), + is_vector_env=True, ) assert np.all(done) - assert info == [{"Test-Info": True}] + assert info == [{"Test-Info": True, "TimeLimit.truncated": False}] diff --git a/tests/wrappers/test_step_compatibility.py b/tests/wrappers/test_step_compatibility.py index 83557f02db6..7f0f1d2c798 100644 --- a/tests/wrappers/test_step_compatibility.py +++ b/tests/wrappers/test_step_compatibility.py @@ -58,7 +58,7 @@ def test_step_compatibility_to_old_api(env, new_step_api): def test_step_compatibility_in_make(new_step_api): if new_step_api is None: with pytest.warns( - DeprecationWarning, match="Initializing environment in old step API" + DeprecationWarning, match="Initializing environment in old done step API" ): env = gym.make("CartPole-v1") else: From e34722e7f77da19d98120e2c4056cfa800b75094 Mon Sep 17 00:00:00 2001 From: StringTheory Date: Tue, 16 Aug 2022 13:04:31 +0100 Subject: [PATCH 3/6] Fix CI issues --- gym/wrappers/step_api_compatibility.py | 2 +- tests/wrappers/test_autoreset.py | 1 + tests/wrappers/test_step_compatibility.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/gym/wrappers/step_api_compatibility.py b/gym/wrappers/step_api_compatibility.py index bddb3a0a847..bb88a2652c0 100644 --- a/gym/wrappers/step_api_compatibility.py +++ b/gym/wrappers/step_api_compatibility.py @@ -40,7 +40,7 @@ def __init__(self, env: gym.Env, to_termination_truncation_api: bool = False): self.to_termination_truncation_api = to_termination_truncation_api if self.to_termination_truncation_api is False: deprecation( - "Initializing environment in old done step API which returns one bool instead of two. It is recommended to set `to_termination_truncation_api=True` to use new step API. This will be the default behaviour in future." + "Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future." ) def step(self, action): diff --git a/tests/wrappers/test_autoreset.py b/tests/wrappers/test_autoreset.py index e4ed3f9b593..5d50d6dfde7 100644 --- a/tests/wrappers/test_autoreset.py +++ b/tests/wrappers/test_autoreset.py @@ -138,6 +138,7 @@ def test_autoreset_wrapper_autoreset(): "count": 0, "final_observation": np.array([3]), "final_info": {"count": 3}, + "TimeLimit.truncated": False } obs, reward, done, info = env.step(action) diff --git a/tests/wrappers/test_step_compatibility.py b/tests/wrappers/test_step_compatibility.py index 7f0f1d2c798..83557f02db6 100644 --- a/tests/wrappers/test_step_compatibility.py +++ b/tests/wrappers/test_step_compatibility.py @@ -58,7 +58,7 @@ def test_step_compatibility_to_old_api(env, new_step_api): def test_step_compatibility_in_make(new_step_api): if new_step_api is None: with pytest.warns( - DeprecationWarning, match="Initializing environment in old done step API" + DeprecationWarning, match="Initializing environment in old step API" ): env = gym.make("CartPole-v1") else: From 53fdff779f3bc86a04b1bc0c62e346a75ec20136 Mon Sep 17 00:00:00 2001 From: StringTheory Date: Tue, 16 Aug 2022 13:14:21 +0100 Subject: [PATCH 4/6] pre-commit --- tests/wrappers/test_autoreset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/wrappers/test_autoreset.py b/tests/wrappers/test_autoreset.py index 5d50d6dfde7..6598d919c64 100644 --- a/tests/wrappers/test_autoreset.py +++ b/tests/wrappers/test_autoreset.py @@ -138,7 +138,7 @@ def test_autoreset_wrapper_autoreset(): "count": 0, "final_observation": np.array([3]), "final_info": {"count": 3}, - "TimeLimit.truncated": False + "TimeLimit.truncated": False, } obs, reward, done, info = env.step(action) From 67ab63331255178a32c99d7e85794c68da72da7f Mon Sep 17 00:00:00 2001 From: StringTheory Date: Tue, 16 Aug 2022 17:11:33 +0100 Subject: [PATCH 5/6] Revert back to old language --- gym/utils/step_api_compatibility.py | 38 ++++++++++------------ gym/vector/sync_vector_env.py | 2 +- gym/wrappers/time_limit.py | 4 +-- tests/utils/test_step_api_compatibility.py | 19 +++++------ 4 files changed, 29 insertions(+), 34 deletions(-) diff --git a/gym/utils/step_api_compatibility.py b/gym/utils/step_api_compatibility.py index 58d794f460a..85d9b031147 100644 --- a/gym/utils/step_api_compatibility.py +++ b/gym/utils/step_api_compatibility.py @@ -5,14 +5,14 @@ from gym.core import ObsType -DoneStepType = Tuple[ +OldStepType = Tuple[ Union[ObsType, np.ndarray], Union[float, np.ndarray], Union[bool, np.ndarray], Union[dict, list], ] -TerminationTruncationStepType = Tuple[ +NewStepType = Tuple[ Union[ObsType, np.ndarray], Union[float, np.ndarray], Union[bool, np.ndarray], @@ -21,10 +21,9 @@ ] -def to_terminated_truncated_step_api( - step_returns: Union[DoneStepType, TerminationTruncationStepType], - is_vector_env=False, -) -> TerminationTruncationStepType: +def step_to_new_api( + step_returns: Union[OldStepType, NewStepType], is_vector_env=False +) -> NewStepType: """Function to transform step returns to new step API irrespective of input API. Args: @@ -74,10 +73,9 @@ def to_terminated_truncated_step_api( ) -def to_done_step_api( - step_returns: Union[TerminationTruncationStepType, DoneStepType], - is_vector_env: bool = False, -) -> DoneStepType: +def step_to_old_api( + step_returns: Union[NewStepType, OldStepType], is_vector_env: bool = False +) -> OldStepType: """Function to transform step returns to old step API irrespective of input API. Args: @@ -130,19 +128,19 @@ def to_done_step_api( def step_api_compatibility( - step_returns: Union[TerminationTruncationStepType, DoneStepType], - to_termination_truncation: bool = False, + step_returns: Union[NewStepType, OldStepType], + new_step_api: bool = False, is_vector_env: bool = False, -) -> Union[TerminationTruncationStepType, DoneStepType]: +) -> Union[NewStepType, OldStepType]: """Function to transform step returns to the API specified by `new_step_api` bool. - Done step API refers to step() method returning (observation, reward, done, info) - Termination Truncation step API refers to step() method returning (observation, reward, terminated, truncated, info) + Old step API refers to step() method returning (observation, reward, done, info) + New step API refers to step() method returning (observation, reward, terminated, truncated, info) (Refer to docs for details on the API change) Args: step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info) - to_termination_truncation (bool): Whether the output should be in new step API or old (False by default) + new_step_api (bool): Whether the output should be in new step API or old (False by default) is_vector_env (bool): Whether the step_returns are from a vector environment Returns: @@ -153,10 +151,10 @@ def step_api_compatibility( wrapper is written in new API, and the final step output is desired to be in old API. >>> obs, rew, done, info = step_api_compatibility(env.step(action)) - >>> obs, rew, terminated, truncated, info = step_api_compatibility(env.step(action), to_termination_truncation=True) + >>> obs, rew, terminated, truncated, info = step_api_compatibility(env.step(action), new_step_api=True) >>> observations, rewards, dones, infos = step_api_compatibility(vec_env.step(action), is_vector_env=True) """ - if to_termination_truncation: - return to_terminated_truncated_step_api(step_returns, is_vector_env) + if new_step_api: + return step_to_new_api(step_returns, is_vector_env) else: - return to_done_step_api(step_returns, is_vector_env) + return step_to_old_api(step_returns, is_vector_env) diff --git a/gym/vector/sync_vector_env.py b/gym/vector/sync_vector_env.py index d6bdb793192..cc3408e7adb 100644 --- a/gym/vector/sync_vector_env.py +++ b/gym/vector/sync_vector_env.py @@ -180,7 +180,7 @@ def step_wait(self): np.copy(self._truncateds), infos, ), - to_termination_truncation=self.new_step_api, + new_step_api=self.new_step_api, is_vector_env=True, ) diff --git a/gym/wrappers/time_limit.py b/gym/wrappers/time_limit.py index df985f50b08..17481d68070 100644 --- a/gym/wrappers/time_limit.py +++ b/gym/wrappers/time_limit.py @@ -58,7 +58,7 @@ def step(self, action): """ observation, reward, terminated, truncated, info = step_api_compatibility( self.env.step(action), - to_termination_truncation=True, + True, ) self._elapsed_steps += 1 @@ -70,7 +70,7 @@ def step(self, action): return step_api_compatibility( (observation, reward, terminated, truncated, info), - to_termination_truncation=self.new_step_api, + self.new_step_api, ) def reset(self, **kwargs): diff --git a/tests/utils/test_step_api_compatibility.py b/tests/utils/test_step_api_compatibility.py index c66157e2138..ade45892072 100644 --- a/tests/utils/test_step_api_compatibility.py +++ b/tests/utils/test_step_api_compatibility.py @@ -2,10 +2,7 @@ import pytest from gym.utils.env_checker import data_equivalence -from gym.utils.step_api_compatibility import ( - to_done_step_api, - to_terminated_truncated_step_api, -) +from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api @pytest.mark.parametrize( @@ -57,7 +54,7 @@ def test_to_done_step_api( is_vector_env, done_returns, expected_terminated, expected_truncated ): - _, _, terminated, truncated, info = to_terminated_truncated_step_api( + _, _, terminated, truncated, info = step_to_new_api( done_returns, is_vector_env=is_vector_env ) assert np.all(terminated == expected_terminated) @@ -70,7 +67,7 @@ def test_to_done_step_api( else: # isinstance(info, dict) assert "TimeLimit.truncated" not in info - roundtripped_returns = to_done_step_api( + roundtripped_returns = step_to_old_api( (0, 0, terminated, truncated, info), is_vector_env=is_vector_env ) assert data_equivalence(done_returns, roundtripped_returns) @@ -115,7 +112,7 @@ def test_to_done_step_api( def test_to_terminated_truncated_step_api( is_vector_env, terminated_truncated_returns, expected_done, expected_truncated ): - _, _, done, info = to_done_step_api( + _, _, done, info = step_to_old_api( terminated_truncated_returns, is_vector_env=is_vector_env ) assert np.all(done == expected_done) @@ -139,7 +136,7 @@ def test_to_terminated_truncated_step_api( else: assert "TimeLimit.truncated" not in info - roundtripped_returns = to_terminated_truncated_step_api( + roundtripped_returns = step_to_new_api( (0, 0, done, info), is_vector_env=is_vector_env ) assert data_equivalence(terminated_truncated_returns, roundtripped_returns) @@ -149,19 +146,19 @@ def test_edge_case(): # When converting between the two-step APIs this is not possible in a single case # terminated=True and truncated=True -> done=True and info={} # We cannot test this in test_to_terminated_truncated_step_api as the roundtripping test will fail - _, _, done, info = to_done_step_api((0, 0, True, True, {})) + _, _, done, info = step_to_old_api((0, 0, True, True, {})) assert done is True assert info == {"TimeLimit.truncated": False} # Test with vector dict info - _, _, done, info = to_done_step_api( + _, _, done, info = step_to_old_api( (0, 0, np.array([True]), np.array([True]), {}), is_vector_env=True ) assert np.all(done) assert info == {"TimeLimit.truncated": np.array([False])} # Test with vector list info - _, _, done, info = to_done_step_api( + _, _, done, info = step_to_old_api( (0, 0, np.array([True]), np.array([True]), [{"Test-Info": True}]), is_vector_env=True, ) From 8984498eee547a9b7af01471914223034fafc1ae Mon Sep 17 00:00:00 2001 From: StringTheory Date: Tue, 16 Aug 2022 17:20:02 +0100 Subject: [PATCH 6/6] Revert changes to step api wrapper --- gym/wrappers/step_api_compatibility.py | 29 ++++++++++++-------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/gym/wrappers/step_api_compatibility.py b/gym/wrappers/step_api_compatibility.py index bb88a2652c0..72d4c8a1e07 100644 --- a/gym/wrappers/step_api_compatibility.py +++ b/gym/wrappers/step_api_compatibility.py @@ -1,10 +1,7 @@ """Implementation of StepAPICompatibility wrapper class for transforming envs between new and old step API.""" import gym from gym.logger import deprecation -from gym.utils.step_api_compatibility import ( - to_done_step_api, - to_terminated_truncated_step_api, -) +from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api class StepAPICompatibility(gym.Wrapper): @@ -18,33 +15,33 @@ class StepAPICompatibility(gym.Wrapper): Args: env (gym.Env): the env to wrap. Can be in old or new API - to_termination_truncation_api (bool): True to use env with new step API, False to use env with old step API. (False by default) + new_step_api (bool): True to use env with new step API, False to use env with old step API. (False by default) Examples: >>> env = gym.make("CartPole-v1") >>> env # wrapper applied by default, set to old API >>>> - >>> env = gym.make("CartPole-v1", to_termination_truncation_api=True) # set to new API - >>> env = StepAPICompatibility(CustomEnv(), to_termination_truncation_api=True) # manually using wrapper on unregistered envs + >>> env = gym.make("CartPole-v1", new_step_api=True) # set to new API + >>> env = StepAPICompatibility(CustomEnv(), new_step_api=True) # manually using wrapper on unregistered envs """ - def __init__(self, env: gym.Env, to_termination_truncation_api: bool = False): + def __init__(self, env: gym.Env, new_step_api=False): """A wrapper which can transform an environment from new step API to old and vice-versa. Args: env (gym.Env): the env to wrap. Can be in old or new API - to_termination_truncation_api (bool): Whether the wrapper's step method outputs two booleans (termination and truncation) with True, or one boolean (done) with False + new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ - super().__init__(env, to_termination_truncation_api) - self.to_termination_truncation_api = to_termination_truncation_api - if self.to_termination_truncation_api is False: + super().__init__(env, new_step_api) + self.new_step_api = new_step_api + if not self.new_step_api: deprecation( "Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future." ) def step(self, action): - """Steps through the environment, returning 5 or 4 items depending on `to_termination_truncation_api`. + """Steps through the environment, returning 5 or 4 items depending on `new_step_api`. Args: action: action to step through the environment with @@ -53,7 +50,7 @@ def step(self, action): (observation, reward, terminated, truncated, info) or (observation, reward, done, info) """ step_returns = self.env.step(action) - if self.to_termination_truncation_api: - return to_terminated_truncated_step_api(step_returns) + if self.new_step_api: + return step_to_new_api(step_returns) else: - return to_done_step_api(step_returns) + return step_to_old_api(step_returns)