Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pin gym version #782

Merged
merged 3 commits into from
Feb 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@ Changelog
==========


Release 1.4.1a0 (WIP)
Release 1.4.1a1 (WIP)
---------------------------


Breaking Changes:
^^^^^^^^^^^^^^^^^

- Switched minimum Gym version to 0.21.0.

New Features:
^^^^^^^^^^^^^
- Makes the length of keys and values in `HumanOutputFormat` configurable,
- Makes the length of keys and values in ``HumanOutputFormat`` configurable,
depending on desired maximum width of output.

SB3-Contrib
Expand All @@ -30,10 +30,10 @@ Bug Fixes:

Deprecations:
^^^^^^^^^^^^^
- Switched minimum Gym version to 0.21.0.

Others:
^^^^^^^
- Fixed pytest warnings

Documentation:
^^^^^^^^^^^^^^
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ filterwarnings =
ignore:Parameters to load are deprecated.:DeprecationWarning
ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning
ignore::UserWarning:gym
ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning
ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning

[pytype]
inputs = stable_baselines3
Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
packages=[package for package in find_packages() if package.startswith("stable_baselines3")],
package_data={"stable_baselines3": ["py.typed", "version.txt"]},
install_requires=[
"gym>=0.21", # Remember to also update gym version in "extra" below when this changes
"gym==0.21", # Fixed version due to breaking changes in 0.22
"numpy",
"torch>=1.8.1",
# For saving models
Expand Down Expand Up @@ -116,7 +116,8 @@
# For render
"opencv-python",
# For atari games
"gym[atari,accept-rom-license]>=0.21",
"ale-py~=0.7.4",
"autorom[accept-rom-license]~=0.4.2",
"pillow",
# Tensorboard support
"tensorboard>=2.2.0",
Expand Down
32 changes: 19 additions & 13 deletions stable_baselines3/common/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ def __init__(self, image: Union[th.Tensor, np.ndarray, str], dataformats: str):


class FormatUnsupportedError(NotImplementedError):
"""
Custom error to display an informative message when
a value is not supported by some formats.

:param unsupported_formats: A sequence of unsupported formats,
for instance ``["stdout"]``.
:param value_description: Description of the value that cannot be logged by this format.
"""

def __init__(self, unsupported_formats: Sequence[str], value_description: str):
if len(unsupported_formats) > 1:
format_str = f"formats {', '.join(unsupported_formats)} are"
Expand Down Expand Up @@ -116,21 +125,18 @@ def write_sequence(self, sequence: List) -> None:
class HumanOutputFormat(KVWriter, SeqWriter):
"""A human-readable output format producing ASCII tables of key-value pairs.

Set attribute `max_length` to change the maximum length of keys and values
to write to output (or specify it when calling `__init__`).
Set attribute ``max_length`` to change the maximum length of keys and values
to write to output (or specify it when calling ``__init__``).

:param filename_or_file: the file to write the log to
:param max_length: the maximum length of keys and values to write to output.
Outputs longer than this will be truncated. An error will be raised
if multiple keys are truncated to the same value. The maximum output
width will be ``2*max_length + 7``. The default of 36 produces output
no longer than 79 characters wide.
"""

def __init__(self, filename_or_file: Union[str, TextIO], max_length: int = 36):
"""
log to a file, in a human readable format

:param filename_or_file: the file to write the log to
:param max_length: the maximum length of keys and values to write to output.
Outputs longer than this will be truncated. An error will be raised
if multiple keys are truncated to the same value. The maximum output
width will be ``2*max_length + 7``. The default of 36 produces output
no longer than 79 characters wide.
"""
self.max_length = max_length
if isinstance(filename_or_file, str):
self.file = open(filename_or_file, "wt")
Expand Down Expand Up @@ -174,7 +180,7 @@ def write(self, key_values: Dict, key_excluded: Dict, step: int = 0) -> None:
truncated_key = self._truncate(key)
if truncated_key in key2str:
raise ValueError(
f"Key '{key}' truncated to " f"'{truncated_key}' that already exists. Consider increasing `max_length`."
f"Key '{key}' truncated to '{truncated_key}' that already exists. Consider increasing `max_length`."
)
key2str[truncated_key] = self._truncate(value_str)

Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.4.1a0
1.4.1a1
9 changes: 5 additions & 4 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import types
import warnings

import gym
import numpy as np
Expand Down Expand Up @@ -35,7 +36,7 @@ def test_env(env_id):
:param env_id: (str)
"""
env = gym.make(env_id)
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
check_env(env)

# Pendulum-v1 will produce a warning because the action space is
Expand All @@ -50,7 +51,7 @@ def test_env(env_id):
@pytest.mark.parametrize("env_class", ENV_CLASSES)
def test_custom_envs(env_class):
env = env_class()
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
check_env(env)
# No warnings for custom envs
assert len(record) == 0
Expand All @@ -68,7 +69,7 @@ def test_custom_envs(env_class):
def test_bit_flipping(kwargs):
# Additional tests for BitFlippingEnv
env = BitFlippingEnv(**kwargs)
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
check_env(env)

# No warnings for custom envs
Expand Down Expand Up @@ -147,7 +148,7 @@ def patched_step(_action):
def test_non_default_action_spaces(new_action_space):
env = FakeImageEnv(discrete=False)
# Default, should pass the test
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
check_env(env)

# No warnings for custom envs
Expand Down
12 changes: 6 additions & 6 deletions tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,43 +580,43 @@ def test_open_file_str_pathlib(tmp_path, pathtype):
with open_path(pathtype(f"{tmp_path}/t1"), "w") as fp1:
save_to_pkl(fp1, "foo")
assert fp1.closed
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
assert load_from_pkl(pathtype(f"{tmp_path}/t1")) == "foo"
assert not record

# test custom suffix
with open_path(pathtype(f"{tmp_path}/t1.custom_ext"), "w") as fp1:
save_to_pkl(fp1, "foo")
assert fp1.closed
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
assert load_from_pkl(pathtype(f"{tmp_path}/t1.custom_ext")) == "foo"
assert not record

# test without suffix
with open_path(pathtype(f"{tmp_path}/t1"), "w", suffix="pkl") as fp1:
save_to_pkl(fp1, "foo")
assert fp1.closed
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
assert load_from_pkl(pathtype(f"{tmp_path}/t1.pkl")) == "foo"
assert not record

# test that a warning is raised when the path doesn't exist
with open_path(pathtype(f"{tmp_path}/t2.pkl"), "w") as fp1:
save_to_pkl(fp1, "foo")
assert fp1.closed
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl")) == "foo"
assert len(record) == 0

with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl", verbose=2)) == "foo"
assert len(record) == 1

fp = pathlib.Path(f"{tmp_path}/t2").open("w")
fp.write("rubbish")
fp.close()
# test that a warning is only raised when verbose = 0
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=0).close()
open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=1).close()
open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=2).close()
Expand Down
3 changes: 2 additions & 1 deletion tests/test_vec_normalize.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import operator
import warnings

import gym
import numpy as np
Expand Down Expand Up @@ -120,7 +121,7 @@ def make_dict_env():
def test_deprecation():
venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])
venv = VecNormalize(venv)
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
assert np.allclose(venv.ret, venv.returns)
# Deprecation warning when using .ret
assert len(record) == 1
Expand Down