From acadcf3c41b7a733c5027cf0458211f9d1dea9bf Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 21 Feb 2022 22:04:30 +0100 Subject: [PATCH 1/3] Pin gym version --- docs/misc/changelog.rst | 7 +++---- setup.py | 5 +++-- stable_baselines3/version.txt | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 1f5052d53..ffb824ddd 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,17 +4,17 @@ Changelog ========== -Release 1.4.1a0 (WIP) +Release 1.4.1a1 (WIP) --------------------------- Breaking Changes: ^^^^^^^^^^^^^^^^^ - +- Switched minimum Gym version to 0.21.0. New Features: ^^^^^^^^^^^^^ -- Makes the length of keys and values in `HumanOutputFormat` configurable, +- Makes the length of keys and values in ``HumanOutputFormat`` configurable, depending on desired maximum width of output. SB3-Contrib @@ -30,7 +30,6 @@ Bug Fixes: Deprecations: ^^^^^^^^^^^^^ -- Switched minimum Gym version to 0.21.0. Others: ^^^^^^^ diff --git a/setup.py b/setup.py index eabf30c66..de615a7ff 100644 --- a/setup.py +++ b/setup.py @@ -73,7 +73,7 @@ packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gym>=0.21", # Remember to also update gym version in "extra" below when this changes + "gym==0.21", # Fixed version due to breaking changes in 0.22 "numpy", "torch>=1.8.1", # For saving models @@ -116,7 +116,8 @@ # For render "opencv-python", # For atari games, - "gym[atari,accept-rom-license]>=0.21", + "ale-py~=0.7.4", + "autorom[accept-rom-license]~=0.4.2", "pillow", # Tensorboard support "tensorboard>=2.2.0", diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 97ec5cca0..d012e1c67 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.4.1a0 +1.4.1a1 From d29b871b4a6c4a2f02793c722cb429fcbd1759fd Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 21 Feb 2022 22:46:24 +0100 Subject: [PATCH 2/3] Cleanup warnings --- docs/misc/changelog.rst | 1 + setup.cfg | 2 ++ stable_baselines3/common/logger.py | 32 ++++++++++++++++++------------ tests/test_envs.py | 9 +++++---- tests/test_save_load.py | 12 +++++------ tests/test_vec_normalize.py | 3 ++- 6 files changed, 35 insertions(+), 24 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index ffb824ddd..9a875bfa8 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -33,6 +33,7 @@ Deprecations: Others: ^^^^^^^ +- Fixed pytest warnings Documentation: ^^^^^^^^^^^^^^ diff --git a/setup.cfg b/setup.cfg index 73ae3dbcc..e23ad4513 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,6 +13,8 @@ filterwarnings = ignore:Parameters to load are deprecated.:DeprecationWarning ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning ignore::UserWarning:gym + ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning + ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning [pytype] inputs = stable_baselines3 diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py index 26e5a6e47..6493a3e0d 100644 --- a/stable_baselines3/common/logger.py +++ b/stable_baselines3/common/logger.py @@ -66,6 +66,15 @@ def __init__(self, image: Union[th.Tensor, np.ndarray, str], dataformats: str): class FormatUnsupportedError(NotImplementedError): + """ + Custom error to display informative message when + a value is not supported by some formats. + + :param unsupported_formats: A sequence of unsupported formats, + for instance ``["stdout"]``. + :param value_description: Description of the value that cannot be logged by this format. + """ + def __init__(self, unsupported_formats: Sequence[str], value_description: str): if len(unsupported_formats) > 1: format_str = f"formats {', '.join(unsupported_formats)} are" @@ -116,21 +125,18 @@ def write_sequence(self, sequence: List) -> None: class HumanOutputFormat(KVWriter, SeqWriter): """A human-readable output format producing ASCII tables of key-value pairs. - Set attribute `max_length` to change the maximum length of keys and values - to write to output (or specify it when calling `__init__`). + Set attribute ``max_length`` to change the maximum length of keys and values + to write to output (or specify it when calling ``__init__``). + + :param filename_or_file: the file to write the log to + :param max_length: the maximum length of keys and values to write to output. + Outputs longer than this will be truncated. An error will be raised + if multiple keys are truncated to the same value. The maximum output + width will be ``2*max_length + 7``. The default of 36 produces output + no longer than 79 characters wide. """ def __init__(self, filename_or_file: Union[str, TextIO], max_length: int = 36): - """ - log to a file, in a human readable format - - :param filename_or_file: the file to write the log to - :param max_length: the maximum length of keys and values to write to output. - Outputs longer than this will be truncated. An error will be raised - if multiple keys are truncated to the same value. The maximum output - width will be ``2*max_length + 7``. The default of 36 produces output - no longer than 79 characters wide. - """ self.max_length = max_length if isinstance(filename_or_file, str): self.file = open(filename_or_file, "wt") @@ -174,7 +180,7 @@ def write(self, key_values: Dict, key_excluded: Dict, step: int = 0) -> None: truncated_key = self._truncate(key) if truncated_key in key2str: raise ValueError( - f"Key '{key}' truncated to " f"'{truncated_key}' that already exists. Consider increasing `max_length`." + f"Key '{key}' truncated to '{truncated_key}' that already exists. Consider increasing `max_length`." ) key2str[truncated_key] = self._truncate(value_str) diff --git a/tests/test_envs.py b/tests/test_envs.py index b859ed703..34e3dfdf4 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -4,6 +4,7 @@ import numpy as np import pytest from gym import spaces +import warnings from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.envs import ( @@ -35,7 +36,7 @@ def test_env(env_id): :param env_id: (str) """ env = gym.make(env_id) - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: check_env(env) # Pendulum-v1 will produce a warning because the action space is @@ -50,7 +51,7 @@ def test_env(env_id): @pytest.mark.parametrize("env_class", ENV_CLASSES) def test_custom_envs(env_class): env = env_class() - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: check_env(env) # No warnings for custom envs assert len(record) == 0 @@ -68,7 +69,7 @@ def test_custom_envs(env_class): def test_bit_flipping(kwargs): # Additional tests for BitFlippingEnv env = BitFlippingEnv(**kwargs) - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: check_env(env) # No warnings for custom envs @@ -147,7 +148,7 @@ def patched_step(_action): def test_non_default_action_spaces(new_action_space): env = FakeImageEnv(discrete=False) # Default, should pass the test - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: check_env(env) # No warnings for custom envs diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 7d810c70e..19b2c90bd 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -580,7 +580,7 @@ def test_open_file_str_pathlib(tmp_path, pathtype): with open_path(pathtype(f"{tmp_path}/t1"), "w") as fp1: save_to_pkl(fp1, "foo") assert fp1.closed - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: assert load_from_pkl(pathtype(f"{tmp_path}/t1")) == "foo" assert not record @@ -588,7 +588,7 @@ def test_open_file_str_pathlib(tmp_path, pathtype): with open_path(pathtype(f"{tmp_path}/t1.custom_ext"), "w") as fp1: save_to_pkl(fp1, "foo") assert fp1.closed - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: assert load_from_pkl(pathtype(f"{tmp_path}/t1.custom_ext")) == "foo" assert not record @@ -596,7 +596,7 @@ def test_open_file_str_pathlib(tmp_path, pathtype): with open_path(pathtype(f"{tmp_path}/t1"), "w", suffix="pkl") as fp1: save_to_pkl(fp1, "foo") assert fp1.closed - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: assert load_from_pkl(pathtype(f"{tmp_path}/t1.pkl")) == "foo" assert not record @@ -604,11 +604,11 @@ def test_open_file_str_pathlib(tmp_path, pathtype): with open_path(pathtype(f"{tmp_path}/t2.pkl"), "w") as fp1: save_to_pkl(fp1, "foo") assert fp1.closed - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl")) == "foo" assert len(record) == 0 - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl", verbose=2)) == "foo" assert len(record) == 1 @@ -616,7 +616,7 @@ def test_open_file_str_pathlib(tmp_path, pathtype): fp.write("rubbish") fp.close() # test that a warning is only raised when verbose = 0 - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=0).close() open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=1).close() open_path(pathtype(f"{tmp_path}/t2"), "w", suffix="pkl", verbose=2).close() diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index c3d1d3065..e25a96cb5 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -4,6 +4,7 @@ import numpy as np import pytest from gym import spaces +import warnings from stable_baselines3 import SAC, TD3, HerReplayBuffer from stable_baselines3.common.monitor import Monitor @@ -120,7 +121,7 @@ def make_dict_env(): def test_deprecation(): venv = DummyVecEnv([lambda: gym.make("CartPole-v1")]) venv = VecNormalize(venv) - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: assert np.allclose(venv.ret, venv.returns) # Deprecation warning when using .ret assert len(record) == 1 From 7943e9f11707d1b632e9e16599b9c26b069e3e68 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 21 Feb 2022 22:48:30 +0100 Subject: [PATCH 3/3] Reformat --- tests/test_envs.py | 2 +- tests/test_vec_normalize.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_envs.py b/tests/test_envs.py index 34e3dfdf4..671e2a5e6 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -1,10 +1,10 @@ import types +import warnings import gym import numpy as np import pytest from gym import spaces -import warnings from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.envs import ( diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index e25a96cb5..8134340d9 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -1,10 +1,10 @@ import operator +import warnings import gym import numpy as np import pytest from gym import spaces -import warnings from stable_baselines3 import SAC, TD3, HerReplayBuffer from stable_baselines3.common.monitor import Monitor