Skip to content

Commit 0431c7a

Browse files
authored
Merge pull request #5 from DLR-RM/check_test_env
Check test env in tests
2 parents 669ef02 + 7460782 commit 0431c7a

File tree

6 files changed

+52
-5
lines changed

6 files changed

+52
-5
lines changed

tests/test_buffers.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from gym import spaces
66

77
from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer
8+
from stable_baselines3.common.env_checker import check_env
89
from stable_baselines3.common.env_util import make_vec_env
910
from stable_baselines3.common.type_aliases import DictReplayBufferSamples, ReplayBufferSamples
1011
from stable_baselines3.common.utils import get_device
@@ -19,7 +20,7 @@ class DummyEnv(gym.Env):
1920
def __init__(self):
2021
self.action_space = spaces.Box(1, 5, (1,))
2122
self.observation_space = spaces.Box(1, 5, (1,))
22-
self._observations = [1, 2, 3, 4, 5]
23+
self._observations = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]], dtype=np.float32)
2324
self._rewards = [1, 2, 3, 4, 5]
2425
self._t = 0
2526
self._ep_length = 100
@@ -47,7 +48,7 @@ def __init__(self):
4748
self.action_space = spaces.Box(1, 5, (1,))
4849
space = spaces.Box(1, 5, (1,))
4950
self.observation_space = spaces.Dict({"observation": space, "achieved_goal": space, "desired_goal": space})
50-
self._observations = [1, 2, 3, 4, 5]
51+
self._observations = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]], dtype=np.float32)
5152
self._rewards = [1, 2, 3, 4, 5]
5253
self._t = 0
5354
self._ep_length = 100
@@ -66,6 +67,13 @@ def step(self, action):
6667
return obs, reward, done, truncated, {}
6768

6869

70+
@pytest.mark.parametrize("env_cls", [DummyEnv, DummyDictEnv])
71+
def test_env(env_cls):
72+
# Check the env used for testing
73+
# Do not warn for asymmetric space
74+
check_env(env_cls(), warn=False, skip_render_check=True)
75+
76+
6977
@pytest.mark.parametrize("replay_buffer_cls", [ReplayBuffer, DictReplayBuffer])
7078
def test_replay_buffer_normalization(replay_buffer_cls):
7179
env = {ReplayBuffer: DummyEnv, DictReplayBuffer: DummyDictEnv}[replay_buffer_cls]

tests/test_dict_env.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from gym import spaces
77

88
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
9+
from stable_baselines3.common.env_checker import check_env
910
from stable_baselines3.common.env_util import make_vec_env
1011
from stable_baselines3.common.envs import BitFlippingEnv, SimpleMultiObsEnv
1112
from stable_baselines3.common.evaluation import evaluate_policy
@@ -71,9 +72,6 @@ def step(self, action):
7172
done = truncated = False
7273
return self.observation_space.sample(), reward, done, truncated, {}
7374

74-
def compute_reward(self, achieved_goal, desired_goal, info):
75-
return np.zeros((len(achieved_goal),))
76-
7775
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
7876
if seed is not None:
7977
self.observation_space.seed(seed)
@@ -83,6 +81,19 @@ def render(self):
8381
pass
8482

8583

84+
@pytest.mark.parametrize("use_discrete_actions", [True, False])
85+
@pytest.mark.parametrize("channel_last", [True, False])
86+
@pytest.mark.parametrize("nested_dict_obs", [True, False])
87+
@pytest.mark.parametrize("vec_only", [True, False])
88+
def test_env(use_discrete_actions, channel_last, nested_dict_obs, vec_only):
89+
# Check the env used for testing
90+
if nested_dict_obs:
91+
with pytest.warns(UserWarning, match="Nested observation spaces are not supported"):
92+
check_env(DummyDictEnv(use_discrete_actions, channel_last, nested_dict_obs, vec_only))
93+
else:
94+
check_env(DummyDictEnv(use_discrete_actions, channel_last, nested_dict_obs, vec_only))
95+
96+
8697
@pytest.mark.parametrize("policy", ["MlpPolicy", "CnnPolicy"])
8798
def test_policy_hint(policy):
8899
# Common mistake: using the wrong policy

tests/test_gae.py

+7
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from stable_baselines3 import A2C, PPO, SAC
1010
from stable_baselines3.common.callbacks import BaseCallback
11+
from stable_baselines3.common.env_checker import check_env
1112
from stable_baselines3.common.policies import ActorCriticPolicy
1213

1314

@@ -121,6 +122,12 @@ def forward(self, obs, deterministic=False):
121122
return actions, values, log_prob
122123

123124

125+
@pytest.mark.parametrize("env_cls", [CustomEnv, InfiniteHorizonEnv])
126+
def test_env(env_cls):
127+
# Check the env used for testing
128+
check_env(env_cls(), skip_render_check=True)
129+
130+
124131
@pytest.mark.parametrize("model_class", [A2C, PPO])
125132
@pytest.mark.parametrize("gae_lambda", [1.0, 0.9])
126133
@pytest.mark.parametrize("gamma", [1.0, 0.99])

tests/test_logger.py

+7
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from pandas.errors import EmptyDataError
1414

1515
from stable_baselines3 import A2C, DQN
16+
from stable_baselines3.common.env_checker import check_env
1617
from stable_baselines3.common.logger import (
1718
DEBUG,
1819
INFO,
@@ -363,6 +364,12 @@ def step(self, action):
363364
return obs, 0.0, True, False, {}
364365

365366

367+
@pytest.mark.parametrize("env_cls", [TimeDelayEnv])
368+
def test_env(env_cls):
369+
# Check the env used for testing
370+
check_env(env_cls(), skip_render_check=True)
371+
372+
366373
class InMemoryLogger(Logger):
367374
"""
368375
Logger that keeps key/value pairs in memory without any writers.

tests/test_predict.py

+7
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from gym import spaces
66

77
from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
8+
from stable_baselines3.common.env_checker import check_env
89
from stable_baselines3.common.envs import IdentityEnv
910
from stable_baselines3.common.utils import get_device
1011
from stable_baselines3.common.vec_env import DummyVecEnv
@@ -36,6 +37,12 @@ def step(self, action):
3637
return self.observation_space.sample(), 0.0, np.random.rand() > 0.5, False, {}
3738

3839

40+
@pytest.mark.parametrize("env_cls", [CustomSubClassedSpaceEnv])
41+
def test_env(env_cls):
42+
# Check the env used for testing
43+
check_env(env_cls(), skip_render_check=True)
44+
45+
3946
@pytest.mark.parametrize("model_class", MODEL_LIST)
4047
def test_auto_wrap(model_class):
4148
"""Test auto wrapping of env into a VecEnv."""

tests/test_spaces.py

+7
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from gym import spaces
77

88
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
9+
from stable_baselines3.common.env_checker import check_env
910
from stable_baselines3.common.env_util import make_vec_env
1011
from stable_baselines3.common.evaluation import evaluate_policy
1112

@@ -53,6 +54,12 @@ def step(self, action):
5354
return self.observation_space.sample(), 0.0, False, False, {}
5455

5556

57+
@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8), DummyMultiBinary((3, 2))])
58+
def test_env(env):
59+
# Check the env used for testing
60+
check_env(env, skip_render_check=True)
61+
62+
5663
@pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
5764
@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8), DummyMultiBinary((3, 2))])
5865
def test_identity_spaces(model_class, env):

0 commit comments

Comments
 (0)