
Commit 96b771f

ndormann and araffin authored
Implement DQN (#28)
* Created DQN template according to the paper. Next steps: create policy, complete training, debug
* Changed base class
* Refactored save to be consistent with overriding the excluded_save_params function; do not try to exclude the parameters twice
* Added simple DQN policy
* Finished learn and train function - missing correct loss computation
* Changed collect_rollouts to work with discrete spaces
* Moved discrete-space collect_rollouts to DQN
* Basic DQN working
* Deleted SDE-related code
* Added gradient clipping and moved greedy policy to policy
* Changed policy to implement target network and added soft update (in fact standard tau is 1, so hard update)
* Fixed policy setup
* Rebased target_update_interval on _n_updates
* Adapted all tests; all tests passing
* Move to stable-baselines3
* Fixes for DQN
* Fix tests + add CnnPolicy
* Allow any optimizer for DQN
* Added some util functions to create an arbitrary linear schedule, fixed pickle problem with old exploration schedule
* More documentation
* Changed buffer dtype
* Refactor and document
* Added Sphinx documentation, updated changelog.rst
* Removed custom collect_rollouts as it is no longer necessary
* Implemented suggestions to clean code and documentation
* Extracted some functions in tests to reduce duplicated code
* Added support for exploration_fraction
* Fixed exploration_fraction
* Added documentation
* Fixed get_linear_fn -> proper progress scaling
* Merged master
* Added Nature reference
* Changed default parameters to https://www.nature.com/articles/nature14236/tables/1
* Fixed n_updates to be incremented correctly
* Correct train_freq
* Doc update
* Added special parameter for DQN in tests
* Different fix for test_discrete
* Update docs/modules/dqn.rst (Co-authored-by: Antonin RAFFIN <[email protected]>)
* Update docs/modules/dqn.rst (Co-authored-by: Antonin RAFFIN <[email protected]>)
* Update docs/modules/dqn.rst (Co-authored-by: Antonin RAFFIN <[email protected]>)
* Added RMSProp in optimizer_kwargs, as described in the Nature paper
* Exploration fraction is the inverse of 50,000,000 (total frames) / 1,000,000 (frames with linear schedule) according to the Nature paper
* Changelog update for buffer dtype
* Standard exclude parameters should always be excluded to ensure proper saving, unless intentionally included via the ``include`` parameter
* Slightly more iterations on test_discrete to pass the test
* Added param use_rms_prop instead of mutable default argument
* Forgot alpha
* Using Huber loss, Adam and learning rate 1e-4
* Account for train_freq in update_target_network
* Added memory check for both buffers
* Doc updated for buffer allocation
* Added psutil requirement
* Adapted test_identity.py
* Fixes with new SB3 version
* Fix for tensorboard name
* Convert assert to warning and fix tests
* Refactor off-policy algorithms
* Fixes
* test: remove next_obs in replay buffer
* Update changelog
* Fix tests and use tmp_path where possible
* Fix sampling bug in buffer
* Do not store next obs on episode termination
* Fix replay buffer sampling
* Update comment
* Moved epsilon from policy to model
* Update predict method
* Update Atari wrappers to match SB2
* Minor edit in the buffers
* Update changelog
* Merge branch 'master' into dqn
* Update DQN to new structure
* Fix tests and remove hardcoded path
* Fix for DQN
* Disable memory-efficient replay buffer by default
* Fix docstring
* Add tests for memory-efficient buffer
* Update changelog
* Split collect rollout
* Move target update outside `train()` for DQN
* Update changelog
* Update linear schedule doc
* Cleanup DQN code
* Minor edit
* Update version and docker images

Co-authored-by: Antonin RAFFIN <[email protected]>
1 parent e47da42 commit 96b771f

32 files changed (+1280, -275 lines)
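The commit history above mentions the linear exploration schedule several times (``exploration_fraction``, ``get_linear_fn -> proper progress scaling``). As a rough illustration of the idea only, a standalone sketch with assumed names (``linear_schedule``, ``progress``), not the SB3 implementation, might look like this: epsilon decays linearly from a start value to an end value over the first ``end_fraction`` of training and then stays constant.

# Illustrative sketch (assumed names, not the SB3 source): a linear schedule
# mapping training progress to an exploration rate.

def linear_schedule(start: float, end: float, end_fraction: float):
    """Return a function mapping progress in [0, 1] (fraction of total
    timesteps elapsed) to an epsilon value that decays linearly from
    `start` to `end` over the first `end_fraction` of training."""
    def epsilon(progress: float) -> float:
        if progress > end_fraction:
            return end
        return start + progress * (end - start) / end_fraction
    return epsilon


# Example: decay epsilon from 1.0 to 0.05 over the first 10% of training.
schedule = linear_schedule(start=1.0, end=0.05, end_fraction=0.1)
assert abs(schedule(0.0) - 1.0) < 1e-8    # start of training
assert abs(schedule(0.05) - 0.525) < 1e-8  # halfway through the decay
assert schedule(0.5) == 0.05               # constant after the decay phase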

.gitlab-ci.yml (+1, -1)

@@ -1,4 +1,4 @@
-image: stablebaselines/stable-baselines3-cpu:0.6.0
+image: stablebaselines/stable-baselines3-cpu:0.8.0a1
 
 type-check:
   script:

README.md (-1)

@@ -40,7 +40,6 @@ These algorithms will make it easier for the research community and industry to
 Please look at the issue for more details.
 Planned features:
 
-- [ ] DQN (almost ready, currently in testing phase)
 - [ ] DDPG (you can use its successor TD3 for now)
 - [ ] HER

docs/guide/algos.rst (+1)

@@ -12,6 +12,7 @@ A2C ✔️ ✔️ ✔️ ✔️
 PPO ✔️ ✔️ ✔️ ✔️ ✔️
 SAC ✔️ ❌ ❌ ❌ ❌
 TD3 ✔️ ❌ ❌ ❌ ❌
+DQN ❌ ✔️ ❌ ❌ ❌
 ============ =========== ============ ================= =============== ================
 

docs/guide/examples.rst (+6, -6)

@@ -33,7 +33,7 @@ notebooks:
 Basic Usage: Training, Saving, Loading
 --------------------------------------
 
-In the following example, we will train, save and load a A2C model on the Lunar Lander environment.
+In the following example, we will train, save and load a DQN model on the Lunar Lander environment.
 
 .. image:: ../_static/img/colab-badge.svg
   :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/saving_loading_dqn.ipynb
@@ -57,31 +57,31 @@ In the following example, we will train, save and load a A2C model on the Lunar
 
   import gym
 
-  from stable_baselines3 import A2C
+  from stable_baselines3 import DQN
   from stable_baselines3.common.evaluation import evaluate_policy
 
 
   # Create environment
   env = gym.make('LunarLander-v2')
 
   # Instantiate the agent
-  model = A2C('MlpPolicy', env, verbose=1)
+  model = DQN('MlpPolicy', env, verbose=1)
   # Train the agent
   model.learn(total_timesteps=int(2e5))
   # Save the agent
-  model.save("a2c_lunar")
+  model.save("dqn_lunar")
   del model  # delete trained model to demonstrate loading
 
   # Load the trained agent
-  model = A2C.load("a2c_lunar")
+  model = DQN.load("dqn_lunar")
 
   # Evaluate the agent
   mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
 
   # Enjoy trained agent
   obs = env.reset()
   for i in range(1000):
-      action, _states = model.predict(obs)
+      action, _states = model.predict(obs, deterministic=True)
       obs, rewards, dones, info = env.step(action)
       env.render()

docs/index.rst (+1)

@@ -58,6 +58,7 @@ Main Features
    modules/ppo
    modules/sac
    modules/td3
+   modules/dqn
 
 .. toctree::
   :maxdepth: 1

docs/misc/changelog.rst (+11, -1)

@@ -3,15 +3,20 @@
 Changelog
 ==========
 
-Pre-Release 0.8.0a0 (WIP)
+Pre-Release 0.8.0a1 (WIP)
 ------------------------------
 
 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
+- ``AtariWrapper`` and other Atari wrappers were updated to match SB2 ones
 - ``save_replay_buffer`` now receives as argument the file path instead of the folder path (@tirafesi)
 
 New Features:
 ^^^^^^^^^^^^^
+- Added ``DQN`` Algorithm (@Artemis-Skade)
+- Buffer dtype is now set according to action and observation spaces for ``ReplayBuffer``
+- Added warning when allocation of a buffer may exceed the available memory of the system
+  when ``psutil`` is available
 
 Bug Fixes:
 ^^^^^^^^^^
@@ -22,13 +27,18 @@ Deprecations:
 
 Others:
 ^^^^^^^
+- Refactored off-policy algorithms to share the same ``.learn()`` method
+- Split the ``collect_rollout()`` method for off-policy algorithms
+- Added ``_on_step()`` for off-policy base class
+- Optimized replay buffer size by removing the need of ``next_observations`` numpy array
 
 Documentation:
 ^^^^^^^^^^^^^^
 - Updated notebook links
 - Fixed a typo in the section of Enjoy a Trained Agent, in RL Baselines3 Zoo README. (@blurLake)
 
 
+
 Pre-Release 0.7.0 (2020-06-10)
 ------------------------------
 
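The changelog entry about warning when a buffer allocation may exceed the available system memory (when ``psutil`` is available) can be illustrated with a small sketch. The function name ``check_buffer_memory`` and the byte accounting below are assumptions for illustration, not the actual ``ReplayBuffer`` code.

# Illustrative sketch only: warn before allocating a replay buffer that may
# not fit in memory, using psutil when it is available.
import warnings

import numpy as np

try:
    import psutil
except ImportError:
    psutil = None  # the check is simply skipped without psutil


def check_buffer_memory(buffer_size: int, obs_shape: tuple, action_dim: int,
                        obs_dtype=np.float32) -> None:
    if psutil is None:
        return  # cannot check without psutil
    # Rough upper bound: observations + actions + rewards + dones.
    obs_bytes = buffer_size * int(np.prod(obs_shape)) * np.dtype(obs_dtype).itemsize
    act_bytes = buffer_size * action_dim * np.dtype(np.float32).itemsize
    misc_bytes = buffer_size * 2 * np.dtype(np.float32).itemsize
    total_bytes = obs_bytes + act_bytes + misc_bytes
    available = psutil.virtual_memory().available
    if total_bytes > available:
        warnings.warn(
            f"Replay buffer may need {total_bytes / 1e9:.2f} GB "
            f"but only {available / 1e9:.2f} GB are available."
        )


# Example: a 1M-transition buffer of 84x84x4 uint8 Atari frames.
check_buffer_memory(1_000_000, (84, 84, 4), action_dim=1, obs_dtype=np.uint8)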

docs/modules/dqn.rst (+94, new file)

@@ -0,0 +1,94 @@
+.. _dqn:
+
+.. automodule:: stable_baselines3.dqn
+
+
+DQN
+===
+
+`Deep Q Network (DQN) <https://arxiv.org/abs/1312.5602>`_
+
+.. rubric:: Available Policies
+
+.. autosummary::
+    :nosignatures:
+
+    MlpPolicy
+    CnnPolicy
+
+
+Notes
+-----
+
+- Original paper: https://arxiv.org/abs/1312.5602
+- Further reference: https://www.nature.com/articles/nature14236
+
+.. note::
+    This implementation provides only vanilla Deep Q-Learning and has no extensions such as Double-DQN, Dueling-DQN or Prioritized Experience Replay.
+
+
+Can I use?
+----------
+
+- Recurrent policies: ❌
+- Multi processing: ❌
+- Gym spaces:
+
+============= ====== ===========
+Space         Action Observation
+============= ====== ===========
+Discrete      ✔      ✔
+Box           ❌      ✔
+MultiDiscrete ❌      ✔
+MultiBinary   ❌      ✔
+============= ====== ===========
+
+
+Example
+-------
+
+.. code-block:: python
+
+  import gym
+  import numpy as np
+
+  from stable_baselines3 import DQN
+  from stable_baselines3.dqn import MlpPolicy
+
+  env = gym.make('CartPole-v1')
+
+  model = DQN(MlpPolicy, env, verbose=1)
+  model.learn(total_timesteps=10000, log_interval=4)
+  model.save("dqn_cartpole")
+
+  del model  # remove to demonstrate saving and loading
+
+  model = DQN.load("dqn_cartpole")
+
+  obs = env.reset()
+  while True:
+      action, _states = model.predict(obs, deterministic=True)
+      obs, reward, done, info = env.step(action)
+      env.render()
+      if done:
+          obs = env.reset()
+
+Parameters
+----------
+
+.. autoclass:: DQN
+  :members:
+  :inherited-members:
+
+.. _dqn_policies:
+
+DQN Policies
+------------
+
+.. autoclass:: MlpPolicy
+  :members:
+  :inherited-members:
+
+.. autoclass:: CnnPolicy
+  :members:
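The module documentation above describes vanilla DQN with a separate target network, and the commit history notes that the soft update reduces to a hard copy because the default tau is 1. A minimal PyTorch sketch of that update, with illustrative names only (``update_target``, ``q_net``, ``q_net_target``) and not the SB3 source, could look like:

# Minimal sketch (assumed names, not the SB3 code): a polyak/soft target
# update that becomes a hard copy when tau == 1.
import torch as th
import torch.nn as nn


def update_target(q_net: nn.Module, q_net_target: nn.Module, tau: float = 1.0) -> None:
    """Polyak-average the online network into the target network.

    With tau = 1.0 this is a plain hard copy, matching the DQN default
    mentioned in the commit message.
    """
    with th.no_grad():
        for param, target_param in zip(q_net.parameters(), q_net_target.parameters()):
            target_param.data.mul_(1.0 - tau)
            target_param.data.add_(tau * param.data)


# Example: two small Q-networks for a 4-dim observation / 2-action task.
q_net = nn.Linear(4, 2)
q_net_target = nn.Linear(4, 2)
update_target(q_net, q_net_target, tau=1.0)  # hard update: target now equals online net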

setup.cfg (+1)

@@ -27,6 +27,7 @@ per-file-ignores =
     ./stable_baselines3/__init__.py:F401
     ./stable_baselines3/common/__init__.py:F401
     ./stable_baselines3/a2c/__init__.py:F401
+    ./stable_baselines3/dqn/__init__.py:F401
     ./stable_baselines3/ppo/__init__.py:F401
     ./stable_baselines3/sac/__init__.py:F401
     ./stable_baselines3/td3/__init__.py:F401

setup.py (+3, -1)

@@ -108,7 +108,9 @@
         # For atari games,
         'atari_py~=0.2.0', 'pillow',
         # Tensorboard support
-        'tensorboard'
+        'tensorboard',
+        # Checking memory taken by replay buffer
+        'psutil'
     ]
 },
 description='Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.',

stable_baselines3/__init__.py (+1)

@@ -4,6 +4,7 @@
 from stable_baselines3.ppo import PPO
 from stable_baselines3.sac import SAC
 from stable_baselines3.td3 import TD3
+from stable_baselines3.dqn import DQN
 
 # Read version from file
 version_file = os.path.join(os.path.dirname(__file__), 'version.txt')
