Skip to content

Commit b954703

Browse files
committed
Fix type annotations
1 parent c0a6a18 commit b954703

File tree

3 files changed

+20
-10
lines changed

3 files changed

+20
-10
lines changed

stable_baselines3/common/env_checker.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def _check_goal_env_compute_reward(
126126
env: gym.Env,
127127
reward: float,
128128
info: Dict[str, Any],
129-
):
129+
) -> None:
130130
"""
131131
Check that reward is computed with `compute_reward`
132132
and that the implementation is vectorized.

stable_baselines3/common/env_util.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[g
2525
return None
2626

2727

28-
def is_wrapped(env: Type[gym.Env], wrapper_class: Type[gym.Wrapper]) -> bool:
28+
def is_wrapped(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> bool:
2929
"""
3030
Check if a given environment has been wrapped with a given wrapper.
3131
@@ -73,13 +73,19 @@ def make_vec_env(
7373
:param wrapper_kwargs: Keyword arguments to pass to the ``Wrapper`` class constructor.
7474
:return: The wrapped environment
7575
"""
76-
env_kwargs = {} if env_kwargs is None else env_kwargs
77-
vec_env_kwargs = {} if vec_env_kwargs is None else vec_env_kwargs
78-
monitor_kwargs = {} if monitor_kwargs is None else monitor_kwargs
79-
wrapper_kwargs = {} if wrapper_kwargs is None else wrapper_kwargs
76+
env_kwargs = env_kwargs or {}
77+
vec_env_kwargs = vec_env_kwargs or {}
78+
monitor_kwargs = monitor_kwargs or {}
79+
wrapper_kwargs = wrapper_kwargs or {}
80+
assert vec_env_kwargs is not None # for mypy
81+
82+
def make_env(rank: int) -> Callable[[], gym.Env]:
83+
def _init() -> gym.Env:
84+
# For type checker:
85+
assert monitor_kwargs is not None
86+
assert wrapper_kwargs is not None
87+
assert env_kwargs is not None
8088

81-
def make_env(rank):
82-
def _init():
8389
if isinstance(env_id, str):
8490
env = gym.make(env_id, **env_kwargs)
8591
else:
@@ -91,7 +97,7 @@ def _init():
9197
# to have additional training information
9298
monitor_path = os.path.join(monitor_dir, str(rank)) if monitor_dir is not None else None
9399
# Create the monitor folder if needed
94-
if monitor_path is not None:
100+
if monitor_path is not None and monitor_dir is not None:
95101
os.makedirs(monitor_dir, exist_ok=True)
96102
env = Monitor(env, filename=monitor_path, **monitor_kwargs)
97103
# Optionally, wrap the environment with the provided wrapper

stable_baselines3/common/torch_layers.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,11 @@ class NatureCNN(BaseFeaturesExtractor):
5757
This corresponds to the number of unit for the last layer.
5858
"""
5959

60-
def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
60+
def __init__(self, observation_space: gym.Space, features_dim: int = 512):
61+
assert isinstance(observation_space, gym.spaces.Box), (
62+
"NatureCNN must be used with a gym.spaces.Box ",
63+
f"observation space, not {observation_space}",
64+
)
6165
super().__init__(observation_space, features_dim)
6266
# We assume CxHxW images (channels first)
6367
# Re-ordering will be done by pre-preprocessing or wrapper

0 commit comments

Comments
 (0)