5
5
from gym import spaces
6
6
7
7
from stable_baselines3 .common .buffers import DictReplayBuffer , DictRolloutBuffer , ReplayBuffer , RolloutBuffer
8
+ from stable_baselines3 .common .env_checker import check_env
8
9
from stable_baselines3 .common .env_util import make_vec_env
9
10
from stable_baselines3 .common .type_aliases import DictReplayBufferSamples , ReplayBufferSamples
10
11
from stable_baselines3 .common .utils import get_device
@@ -19,7 +20,7 @@ class DummyEnv(gym.Env):
19
20
def __init__ (self ):
20
21
self .action_space = spaces .Box (1 , 5 , (1 ,))
21
22
self .observation_space = spaces .Box (1 , 5 , (1 ,))
22
- self ._observations = [ 1 , 2 , 3 , 4 , 5 ]
23
+ self ._observations = np . array ([[ 1.0 ], [ 2.0 ], [ 3.0 ], [ 4.0 ], [ 5.0 ]], dtype = np . float32 )
23
24
self ._rewards = [1 , 2 , 3 , 4 , 5 ]
24
25
self ._t = 0
25
26
self ._ep_length = 100
@@ -47,7 +48,7 @@ def __init__(self):
47
48
self .action_space = spaces .Box (1 , 5 , (1 ,))
48
49
space = spaces .Box (1 , 5 , (1 ,))
49
50
self .observation_space = spaces .Dict ({"observation" : space , "achieved_goal" : space , "desired_goal" : space })
50
- self ._observations = [ 1 , 2 , 3 , 4 , 5 ]
51
+ self ._observations = np . array ([[ 1.0 ], [ 2.0 ], [ 3.0 ], [ 4.0 ], [ 5.0 ]], dtype = np . float32 )
51
52
self ._rewards = [1 , 2 , 3 , 4 , 5 ]
52
53
self ._t = 0
53
54
self ._ep_length = 100
@@ -66,6 +67,13 @@ def step(self, action):
66
67
return obs , reward , done , truncated , {}
67
68
68
69
70
+ @pytest .mark .parametrize ("env_cls" , [DummyEnv , DummyDictEnv ])
71
+ def test_env (env_cls ):
72
+ # Check the env used for testing
73
+ # Do not warn for assymetric space
74
+ check_env (env_cls (), warn = False , skip_render_check = True )
75
+
76
+
69
77
@pytest .mark .parametrize ("replay_buffer_cls" , [ReplayBuffer , DictReplayBuffer ])
70
78
def test_replay_buffer_normalization (replay_buffer_cls ):
71
79
env = {ReplayBuffer : DummyEnv , DictReplayBuffer : DummyDictEnv }[replay_buffer_cls ]
0 commit comments