Skip to content

Commit

Permalink
Bug fix for QR-DQN (#21)
Browse files Browse the repository at this point in the history
* Bug fix for QR-DQN

* Upgrade SB3
  • Loading branch information
araffin authored Mar 6, 2021
1 parent 7c2eb83 commit 9824dac
Show file tree
Hide file tree
Showing 9 changed files with 48 additions and 16 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.6, 3.7] # 3.8 not supported yet by pytype
python-version: [3.6, 3.7, 3.8]

steps:
- uses: actions/checkout@v2
Expand All @@ -28,7 +28,7 @@ jobs:
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install torch==1.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch==1.8.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
# Install dependencies for docs and tests
pip install stable_baselines3[extra,tests,docs]
# Install master version
Expand Down
12 changes: 12 additions & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,18 @@
Changelog
==========

Release 1.0rc1 (WIP)
-------------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Upgraded to Stable-Baselines3 >= 1.0rc1

Bug Fixes:
^^^^^^^^^^
- Fixed a bug in the ``QR-DQN`` ``predict()`` method when using ``deterministic=False`` with an image observation space


Pre-Release 0.11.1 (2021-02-27)
-------------------------------

Expand Down
6 changes: 3 additions & 3 deletions sb3_contrib/qrdqn/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,10 @@ def __init__(
)

if net_arch is None:
if features_extractor_class == FlattenExtractor:
net_arch = [64, 64]
else:
if features_extractor_class == NatureCNN:
net_arch = []
else:
net_arch = [64, 64]

self.n_quantiles = n_quantiles
self.net_arch = net_arch
Expand Down
3 changes: 2 additions & 1 deletion sb3_contrib/qrdqn/qrdqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch as th
from stable_baselines3.common import logger
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.preprocessing import maybe_transpose
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update

Expand Down Expand Up @@ -211,7 +212,7 @@ def predict(
(used in recurrent policies)
"""
if not deterministic and np.random.rand() < self.exploration_rate:
if is_vectorized_observation(observation, self.observation_space):
if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
n_batch = observation.shape[0]
action = np.array([self.action_space.sample() for _ in range(n_batch)])
else:
Expand Down
6 changes: 3 additions & 3 deletions sb3_contrib/tqc/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,10 +314,10 @@ def __init__(
)

if net_arch is None:
if features_extractor_class == FlattenExtractor:
net_arch = [256, 256]
else:
if features_extractor_class == NatureCNN:
net_arch = []
else:
net_arch = [256, 256]

actor_arch, critic_arch = get_actor_critic_arch(net_arch)

Expand Down
2 changes: 1 addition & 1 deletion sb3_contrib/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.11.1
1.0rc1
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3[tests,docs]>=0.11.1",
"stable_baselines3[tests,docs]>=1.0rc1",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",
Expand Down
23 changes: 21 additions & 2 deletions tests/test_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch as th
from stable_baselines3.common.identity_env import FakeImageEnv
from stable_baselines3.common.utils import zip_strict
from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped

from sb3_contrib import QRDQN, TQC

Expand All @@ -16,19 +17,37 @@ def test_cnn(tmp_path, model_class):
# Fake grayscale with frameskip
# Atari after preprocessing: 84x84x1; here we are using a lower resolution
# to check that the network handles it automatically
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {TQC})
env = FakeImageEnv(
screen_height=40,
screen_width=40,
n_channels=1,
discrete=model_class not in {TQC},
)
kwargs = {}
if model_class in {TQC, QRDQN}:
# Avoid memory error when using replay buffer
# Reduce the size of the features and the number of quantiles
kwargs = dict(
buffer_size=250,
policy_kwargs=dict(n_quantiles=25, features_extractor_kwargs=dict(features_dim=32)),
policy_kwargs=dict(
n_quantiles=25,
features_extractor_kwargs=dict(features_dim=32),
),
)
model = model_class("CnnPolicy", env, **kwargs).learn(250)

obs = env.reset()

# FakeImageEnv is channel-last by default, so the model's env should be wrapped in VecTransposeImage
assert is_vecenv_wrapped(model.get_env(), VecTransposeImage)

# Test stochastic predict with channel last input
if model_class == QRDQN:
model.exploration_rate = 0.9

for _ in range(10):
model.predict(obs, deterministic=False)

action, _ = model.predict(obs, deterministic=True)

model.save(tmp_path / SAVE_NAME)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,17 +184,17 @@ def test_set_env(model_class):
# create model
model = model_class("MlpPolicy", env, **kwargs)
# learn
model.learn(total_timesteps=300)
model.learn(total_timesteps=150)

# change env
model.set_env(env2)
# learn again
model.learn(total_timesteps=300)
model.learn(total_timesteps=150)

# change env test wrapping
model.set_env(env3)
# learn again
model.learn(total_timesteps=300)
model.learn(total_timesteps=150)


@pytest.mark.parametrize("model_class", MODEL_LIST)
Expand Down

0 comments on commit 9824dac

Please sign in to comment.