Skip to content

Commit

Permalink
Add warning when using non-zero start with Discrete (fixes DLR-RM#1197)
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin committed Dec 7, 2022
1 parent b954703 commit 3f75a8a
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
14 changes: 13 additions & 1 deletion stable_baselines3/common/env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,15 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act

if isinstance(observation_space, spaces.Dict):
nested_dict = False
for space in observation_space.spaces.values():
for key, space in observation_space.spaces.items():
if isinstance(space, spaces.Dict):
nested_dict = True
if isinstance(space, spaces.Discrete) and space.start != 0:
warnings.warn(
f"Discrete observation space (key '{key}') with a non-zero start is not supported by Stable-Baselines3. "
"You can use a wrapper or update your observation space."
)

if nested_dict:
warnings.warn(
"Nested observation spaces are not supported by Stable Baselines3 "
Expand All @@ -77,6 +83,12 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
"which is supported by SB3."
)

if isinstance(observation_space, spaces.Discrete) and observation_space.start != 0:
warnings.warn(
"Discrete observation space with a non-zero start is not supported by Stable-Baselines3. "
"You can use a wrapper or update your observation space."
)

if not _is_numpy_array_space(action_space):
warnings.warn(
"The action space is not based off a numpy array. Typically this means it's either a Dict or Tuple space. "
Expand Down
4 changes: 4 additions & 0 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ def patched_step(_action):
spaces.Dict({"position": spaces.Dict({"abs": spaces.Discrete(5), "rel": spaces.Discrete(2)})}),
# Small image inside a dict
spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}),
# Non zero start index
spaces.Discrete(3, start=-1),
# Non zero start index inside a Dict
spaces.Dict({"obs": spaces.Discrete(3, start=1)}),
],
)
def test_non_default_spaces(new_obs_space):
Expand Down

0 comments on commit 3f75a8a

Please sign in to comment.