Skip to content

Commit

Permalink
vec_envs fix seed() causing a reset (#1486)
Browse files Browse the repository at this point in the history
* `dummy_vec_env` fix `seed()` causing a reset

* rename `seed`

* fixes

* bug fix

* fix seed return type

* Cleanup seeding, add test and remove compat wrapper

* Update env checker and tests

* Add deterministic test for make_vec_env

---------

Co-authored-by: Antonin Raffin <[email protected]>
  • Loading branch information
Kallinteris-Andreas and araffin authored May 20, 2023
1 parent fd0cd82 commit 9c338f9
Show file tree
Hide file tree
Showing 18 changed files with 94 additions and 76 deletions.
4 changes: 3 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 2.0.0a8 (WIP)
Release 2.0.0a9 (WIP)
--------------------------

**Gymnasium support**
Expand All @@ -22,6 +22,7 @@ Breaking Changes:
- Renamed environment output observations in ``evaluate_policy`` to prevent shadowing the input observations during callbacks (@npit)
- Upgraded wrappers and custom environment to Gymnasium
- Refined the ``HumanOutputFormat`` file check: now it verifies if the object is an instance of ``io.TextIOBase`` instead of only checking for the presence of a ``write`` method.
- Because of the new Gym API (0.26+), the random seed passed to ``vec_env.seed(seed=seed)`` will only take effect after the next ``env.reset()`` call.

New Features:
^^^^^^^^^^^^^
Expand Down Expand Up @@ -55,6 +56,7 @@ Others:
- Fixed ``stable_baselines3/common/vec_env/base_vec_env.py`` type hints
- Fixed ``stable_baselines3/common/vec_env/vec_frame_stack.py`` type hints
- Fixed ``stable_baselines3/common/vec_env/dummy_vec_env.py`` type hints
- Fixed ``stable_baselines3/common/vec_env/subproc_vec_env.py`` type hints
- Upgraded docker images to use mamba/micromamba and CUDA 11.7
- Updated env checker to reflect what subset of Gymnasium is supported and improve GoalEnv checks
- Improve type annotation of wrappers
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ exclude = """(?x)(
| stable_baselines3/common/save_util.py$
| stable_baselines3/common/utils.py$
| stable_baselines3/common/vec_env/__init__.py$
| stable_baselines3/common/vec_env/subproc_vec_env.py$
| stable_baselines3/common/vec_env/vec_normalize.py$
| stable_baselines3/common/vec_env/vec_transpose.py$
| stable_baselines3/common/vec_env/vec_video_recorder.py$
Expand Down
5 changes: 5 additions & 0 deletions stable_baselines3/common/env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,11 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -
observation_space = env.observation_space
action_space = env.action_space

try:
env.reset(seed=0)
except TypeError as e:
raise TypeError("The reset() method must accept a `seed` parameter") from e

# Warn the user if needed.
# A warning means that the environment may run but not work properly with Stable Baselines algorithms
if warn:
Expand Down
9 changes: 6 additions & 3 deletions stable_baselines3/common/env_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from stable_baselines3.common.atari_wrappers import AtariWrapper
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import compat_gym_seed
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv
from stable_baselines3.common.vec_env.patch_gym import _patch_env

Expand Down Expand Up @@ -101,7 +100,8 @@ def _init() -> gym.Env:
env = _patch_env(env)

if seed is not None:
compat_gym_seed(env, seed=seed + rank)
# Note: here we only seed the action space
# We will seed the env at the next reset
env.action_space.seed(seed + rank)
# Wrap the env in a Monitor wrapper
# to have additional training information
Expand All @@ -122,7 +122,10 @@ def _init() -> gym.Env:
# Default: use a DummyVecEnv
vec_env_cls = DummyVecEnv

return vec_env_cls([make_env(i + start_index) for i in range(n_envs)], **vec_env_kwargs)
vec_env = vec_env_cls([make_env(i + start_index) for i in range(n_envs)], **vec_env_kwargs)
# Prepare the seeds for the first reset
vec_env.seed(seed)
return vec_env


def make_atari_env(
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def reset(self, indices: Optional[Iterable[int]] = None) -> None:
self.noises[index].reset()

def __repr__(self) -> str:
return f"VecNoise(BaseNoise={repr(self.base_noise)}), n_envs={len(self.noises)})"
return f"VecNoise(BaseNoise={self.base_noise!r}), n_envs={len(self.noises)})"

def __call__(self) -> np.ndarray:
"""
Expand Down
16 changes: 0 additions & 16 deletions stable_baselines3/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import random
import re
from collections import deque
from inspect import signature
from itertools import zip_longest
from typing import Dict, Iterable, List, Optional, Tuple, Union

Expand Down Expand Up @@ -549,18 +548,3 @@ def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]:
if print_info:
print(env_info_str)
return env_info, env_info_str


def compat_gym_seed(env: GymEnv, seed: int) -> None:
"""
Compatibility helper to seed Gym envs.
:param env: The Gym environment.
:param seed: The seed for the pseudo random generator
"""
if "seed" in signature(env.unwrapped.reset).parameters:
# gym >= 0.23.1
env.reset(seed=seed)
else:
# VecEnv and backward compatibility
env.seed(seed)
19 changes: 17 additions & 2 deletions stable_baselines3/common/vec_env/base_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,14 @@ def __init__(
self.render_mode = render_mode
# store info returned by the reset method
self.reset_infos: List[Dict[str, Any]] = [{} for _ in range(num_envs)]
# seeds to be used in the next call to env.reset()
self._seeds: List[Optional[int]] = [None for _ in range(num_envs)]

def _reset_seeds(self) -> None:
"""
Reset the seeds that are going to be used at the next reset.
"""
self._seeds = [None for _ in range(self.num_envs)]

@abstractmethod
def reset(self) -> VecEnvObs:
Expand Down Expand Up @@ -239,17 +247,24 @@ def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]:
self.env_method("render")
return None

@abstractmethod
def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]:
"""
Sets the random seeds for all environments, based on a given seed.
Each individual environment will still get its own seed, by incrementing the given seed.
WARNING: since gym 0.26, those seeds will only be passed to the environment
at the next reset.
:param seed: The random seed. May be None for completely random seeding.
:return: Returns a list containing the seeds for each individual env.
Note that all list elements may be None, if the env does not return anything when being seeded.
"""
pass
if seed is None:
# To ensure that subprocesses have different seeds,
# we still populate the seed variable when no argument is passed
seed = np.random.randint(0, 2**32 - 1)

self._seeds = [seed + idx for idx in range(self.num_envs)]
return self._seeds

@property
def unwrapped(self) -> "VecEnv":
Expand Down
17 changes: 4 additions & 13 deletions stable_baselines3/common/vec_env/dummy_vec_env.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Type

import gymnasium as gym
import numpy as np
Expand Down Expand Up @@ -71,21 +71,12 @@ def step_wait(self) -> VecEnvStepReturn:
self._save_obs(env_idx, obs)
return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos))

def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]:
# Avoid circular import
from stable_baselines3.common.utils import compat_gym_seed

if seed is None:
seed = np.random.randint(0, 2**32 - 1)
seeds = []
for idx, env in enumerate(self.envs):
seeds.append(compat_gym_seed(env, seed=seed + idx)) # type: ignore[func-returns-value]
return seeds

def reset(self) -> VecEnvObs:
for env_idx in range(self.num_envs):
obs, self.reset_infos[env_idx] = self.envs[env_idx].reset()
obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(seed=self._seeds[env_idx])
self._save_obs(env_idx, obs)
# Seeds are only used once
self._reset_seeds()
return self._obs_from_buf()

def close(self) -> None:
Expand Down
32 changes: 13 additions & 19 deletions stable_baselines3/common/vec_env/subproc_vec_env.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import multiprocessing as mp
import warnings
from collections import OrderedDict
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union

import gymnasium as gym
import numpy as np
Expand All @@ -24,11 +24,10 @@ def _worker(
) -> None:
# Import here to avoid a circular import
from stable_baselines3.common.env_util import is_wrapped
from stable_baselines3.common.utils import compat_gym_seed

parent_remote.close()
env = _patch_env(env_fn_wrapper.var())
reset_info = {}
reset_info: Optional[Dict[str, Any]] = {}
while True:
try:
cmd, data = remote.recv()
Expand All @@ -42,10 +41,8 @@ def _worker(
info["terminal_observation"] = observation
observation, reset_info = env.reset()
remote.send((observation, reward, done, info, reset_info))
elif cmd == "seed":
remote.send(compat_gym_seed(env, seed=data))
elif cmd == "reset":
observation, reset_info = env.reset()
observation, reset_info = env.reset(seed=data)
remote.send((observation, reset_info))
elif cmd == "render":
remote.send(env.render())
Expand All @@ -61,7 +58,7 @@ def _worker(
elif cmd == "get_attr":
remote.send(getattr(env, data))
elif cmd == "set_attr":
remote.send(setattr(env, data[0], data[1]))
remote.send(setattr(env, data[0], data[1])) # type: ignore[func-returns-value]
elif cmd == "is_wrapped":
remote.send(is_wrapped(env, data))
else:
Expand Down Expand Up @@ -112,7 +109,9 @@ def __init__(self, env_fns: List[Callable[[], gym.Env]], start_method: Optional[
for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns):
args = (work_remote, remote, CloudpickleWrapper(env_fn))
# daemon=True: if the main process crashes, we should not cause things to hang
process = ctx.Process(target=_worker, args=args, daemon=True) # pytype:disable=attribute-error
# pytype: disable=attribute-error
process = ctx.Process(target=_worker, args=args, daemon=True) # type: ignore[attr-defined]
# pytype: enable=attribute-error
process.start()
self.processes.append(process)
work_remote.close()
Expand All @@ -135,18 +134,13 @@ def step_wait(self) -> VecEnvStepReturn:
obs, rews, dones, infos, self.reset_infos = zip(*results)
return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos

def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]:
if seed is None:
seed = np.random.randint(0, 2**32 - 1)
for idx, remote in enumerate(self.remotes):
remote.send(("seed", seed + idx))
return [remote.recv() for remote in self.remotes]

def reset(self) -> VecEnvObs:
for remote in self.remotes:
remote.send(("reset", None))
for env_idx, remote in enumerate(self.remotes):
remote.send(("reset", self._seeds[env_idx]))
results = [remote.recv() for remote in self.remotes]
obs, self.reset_infos = zip(*results)
# Seeds are only used once
self._reset_seeds()
return _flatten_obs(obs, self.observation_space)

def close(self) -> None:
Expand Down Expand Up @@ -235,6 +229,6 @@ def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Sp
elif isinstance(space, spaces.Tuple):
assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space"
obs_len = len(space.spaces)
return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len))
return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len)) # type: ignore[index]
else:
return np.stack(obs)
return np.stack(obs) # type: ignore[arg-type]
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.0.0a8
2.0.0a9
4 changes: 2 additions & 2 deletions tests/test_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self):
self._t = 0
self._ep_length = 100

def reset(self):
def reset(self, *, seed=None, options=None):
self._t = 0
obs = self._observations[0]
return obs, {}
Expand Down Expand Up @@ -55,7 +55,7 @@ def __init__(self):
self._t = 0
self._ep_length = 100

def reset(self):
def reset(self, seed=None, options=None):
self._t = 0
obs = {key: self._observations[0] for key in self.observation_space.spaces.keys()}
return obs, {}
Expand Down
2 changes: 1 addition & 1 deletion tests/test_env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def step(self, action):
info = {}
return observation, reward, terminated, truncated, info

def reset(self):
def reset(self, seed=None):
return np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype), {}

def render(self):
Expand Down
23 changes: 19 additions & 4 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,12 @@ def patched_step(_action):
def test_non_default_spaces(new_obs_space):
env = FakeImageEnv()
env.observation_space = new_obs_space

# Patch methods to avoid errors
env.reset = lambda: (new_obs_space.sample(), {})
def patched_reset(seed=None):
return new_obs_space.sample(), {}

env.reset = patched_reset

def patched_step(_action):
return new_obs_space.sample(), 0.0, False, False, {}
Expand Down Expand Up @@ -204,7 +208,7 @@ def check_reset_assert_error(env, new_reset_return):
:param new_reset_return: (Any)
"""

def wrong_reset():
def wrong_reset(seed=None):
return new_reset_return, {}

# Patch the reset method with a wrong one
Expand All @@ -224,10 +228,21 @@ def test_common_failures_reset():
check_reset_assert_error(env, 1)

# Return only obs (gym < 0.26)
env.reset = env.observation_space.sample
def wrong_reset(self, seed=None):
return env.observation_space.sample()

env.reset = types.MethodType(wrong_reset, env)
with pytest.raises(AssertionError):
check_env(env)

# No seed parameter (gym < 0.26)
def wrong_reset(self):
return env.observation_space.sample(), {}

env.reset = types.MethodType(wrong_reset, env)
with pytest.raises(TypeError):
check_env(env)

# Return not only the observation
check_reset_assert_error(env, (env.observation_space.sample(), False))

Expand All @@ -242,7 +257,7 @@ def test_common_failures_reset():

obs, _ = env.reset()

def wrong_reset(self):
def wrong_reset(self, seed=None):
return {"img": obs["img"], "vec": obs["img"]}, {}

env.reset = types.MethodType(wrong_reset, env)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def read_fn(_format):
tb_values_logged = []
for reservoir in [acc.scalars, acc.tensors, acc.images, acc.histograms, acc.compressed_histograms]:
for k in reservoir.Keys():
tb_values_logged.append(f"{k}: {str(reservoir.Items(k))}")
tb_values_logged.append(f"{k}: {reservoir.Items(k)!s}")

content = LogContent(_format, tb_values_logged)
return content
Expand Down Expand Up @@ -353,7 +353,7 @@ def __init__(self, delay: float = 0.01):
self.observation_space = spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32)
self.action_space = spaces.Discrete(2)

def reset(self):
def reset(self, seed=None):
return self.observation_space.sample(), {}

def step(self, action):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self):
self.observation_space = SubClassedBox(-1, 1, shape=(2,), dtype=np.float32)
self.action_space = SubClassedBox(-1, 1, shape=(2,), dtype=np.float32)

def reset(self):
def reset(self, seed=None):
return self.observation_space.sample(), {}

def step(self, action):
Expand Down
Loading

0 comments on commit 9c338f9

Please sign in to comment.