diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 228ee935..1021c188 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,7 +16,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.6, 3.7] # 3.8 not supported yet by pytype + python-version: [3.6, 3.7, 3.8] steps: - uses: actions/checkout@v2 @@ -28,7 +28,7 @@ jobs: run: | python -m pip install --upgrade pip # cpu version of pytorch - pip install torch==1.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + pip install torch==1.8.0+cpu -f https://download.pytorch.org/whl/torch_stable.html # Install dependencies for docs and tests pip install stable_baselines3[extra,tests,docs] # Install master version diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index f33d7c52..35b29fbd 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,18 @@ Changelog ========== +Release 1.0rc1 (WIP) +------------------------------- + +Breaking Changes: +^^^^^^^^^^^^^^^^^ +- Upgraded to Stable-Baselines3 >= 1.0rc1 + +Bug Fixes: +^^^^^^^^^^ +- Fixed a bug with ``QR-DQN`` predict method when using ``deterministic=False`` with image space + + Pre-Release 0.11.1 (2021-02-27) ------------------------------- diff --git a/sb3_contrib/qrdqn/policies.py b/sb3_contrib/qrdqn/policies.py index 48e94b8c..30ea6432 100644 --- a/sb3_contrib/qrdqn/policies.py +++ b/sb3_contrib/qrdqn/policies.py @@ -129,10 +129,10 @@ def __init__( ) if net_arch is None: - if features_extractor_class == FlattenExtractor: - net_arch = [64, 64] - else: + if features_extractor_class == NatureCNN: net_arch = [] + else: + net_arch = [64, 64] self.n_quantiles = n_quantiles self.net_arch = net_arch diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py index 046f3a3a..b2c22db4 100644 --- a/sb3_contrib/qrdqn/qrdqn.py +++ b/sb3_contrib/qrdqn/qrdqn.py @@ -5,6 +5,7 @@ import torch as th from stable_baselines3.common import logger from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm +from stable_baselines3.common.preprocessing import maybe_transpose from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update @@ -211,7 +212,7 @@ def predict( (used in recurrent policies) """ if not deterministic and np.random.rand() < self.exploration_rate: - if is_vectorized_observation(observation, self.observation_space): + if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space): n_batch = observation.shape[0] action = np.array([self.action_space.sample() for _ in range(n_batch)]) else: diff --git a/sb3_contrib/tqc/policies.py b/sb3_contrib/tqc/policies.py index 07519b82..776a57d5 100644 --- a/sb3_contrib/tqc/policies.py +++ b/sb3_contrib/tqc/policies.py @@ -314,10 +314,10 @@ def __init__( ) if net_arch is None: - if features_extractor_class == FlattenExtractor: - net_arch = [256, 256] - else: + if features_extractor_class == NatureCNN: net_arch = [] + else: + net_arch = [256, 256] actor_arch, critic_arch = get_actor_critic_arch(net_arch) diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index af88ba82..0f82de4c 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -0.11.1 +1.0rc1 diff --git a/setup.py b/setup.py index d662214e..88957161 100644 --- a/setup.py +++ b/setup.py @@ -62,7 +62,7 @@ packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3[tests,docs]>=0.11.1", + "stable_baselines3[tests,docs]>=1.0rc1", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin", diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 91b6c80f..72488c31 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -6,6 +6,7 @@ import torch as th from stable_baselines3.common.identity_env import FakeImageEnv from stable_baselines3.common.utils import zip_strict +from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped from sb3_contrib import QRDQN, TQC @@ -16,19 +17,37 @@ def test_cnn(tmp_path, model_class): # Fake grayscale with frameskip # Atari after preprocessing: 84x84x1, here we are using lower resolution # to check that the network handle it automatically - env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {TQC}) + env = FakeImageEnv( + screen_height=40, + screen_width=40, + n_channels=1, + discrete=model_class not in {TQC}, + ) kwargs = {} if model_class in {TQC, QRDQN}: # Avoid memory error when using replay buffer # Reduce the size of the features and the number of quantiles kwargs = dict( buffer_size=250, - policy_kwargs=dict(n_quantiles=25, features_extractor_kwargs=dict(features_dim=32)), + policy_kwargs=dict( + n_quantiles=25, + features_extractor_kwargs=dict(features_dim=32), + ), ) model = model_class("CnnPolicy", env, **kwargs).learn(250) obs = env.reset() + # FakeImageEnv is channel last by default and should be wrapped + assert is_vecenv_wrapped(model.get_env(), VecTransposeImage) + + # Test stochastic predict with channel last input + if model_class == QRDQN: + model.exploration_rate = 0.9 + + for _ in range(10): + model.predict(obs, deterministic=False) + action, _ = model.predict(obs, deterministic=True) model.save(tmp_path / SAVE_NAME) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index bc16f16a..c4047e08 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -184,17 +184,17 @@ def test_set_env(model_class): # create model model = model_class("MlpPolicy", env, **kwargs) # learn - model.learn(total_timesteps=300) + model.learn(total_timesteps=150) # change env model.set_env(env2) # learn again - model.learn(total_timesteps=300) + model.learn(total_timesteps=150) # change env test wrapping model.set_env(env3) # learn again - model.learn(total_timesteps=300) + model.learn(total_timesteps=150) @pytest.mark.parametrize("model_class", MODEL_LIST)