-
Notifications
You must be signed in to change notification settings - Fork 63
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Incorporate distributed RL framework, Ape-X and Ape-X DQN (#246)
* Take context as init_communication input; all processes share the same context. * implement abstract classes for distributed and ApeX Learner wrapper * Implement params2numpy method that loads torch state_dict as array of np.ndarray. * add __init__ * implement worker as abstract class, not wrapper base class * Change apex_learner file name to learner. * Implement Ape-X worker and learner base classes * implement Ape-X DQN worker * Create base class for distributed architectures * Implement and test Ape-X DQN working on Pong * Accept current change (master) for PongNoFrameskip-v4 dqn config * Make env_info more explicit in run_pong script (accept incoming change) * Make learner return cpu state_dict (accept incoming change) * Fix minor errors * Implement ApeXWorker as a wrapper ApeXWorkerWrapper Implement Logger and test wandb functionality Add worker and logger render in argparse Implement load_param() method in logger and worker * Move num_workers to hyperparams, and add logger_interval to hyperparams. * Implement safe exit condition for all ray actors. * Change _init_communication -> init_communication and call outside of __init__ for all ApeX actors Implement test() in distributed architectures (load from checkpoint and run logger test()) * * Add documentation * Move collect_data from worker class to ApeX Wrapper * Change hyperparameters around * Add worker-verbose as argparse flag * * Move num_worker to hyper_param cfg * * Add author * Add separate integration test for ApeX * Add integration test flag to pong * argparse integration test flag store_false->store_true * Change default config to dqn. * * Log worker scores per update step on Wandb. * Modify integration test * Modify apex buffer config for integration test * Change distributed directory structure * Add documentation * Modify readme.md * Modify readme.md * Add Ape-X to README. * Add description about args flags for distributed training. Co-authored-by: khkim <[email protected]> Co-authored-by: Kyunghwan Kim <[email protected]>
- Loading branch information
1 parent
07743f6
commit 9e897ad
Showing
28 changed files
with
1,538 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,4 @@ | ||
# Our repository | ||
MIT License | ||
The MIT License (MIT) | ||
|
||
Copyright (c) 2019 Medipixel | ||
|
||
|
@@ -20,16 +19,3 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. | ||
|
||
# Mujoco models | ||
This work is derived from [MuJuCo models](http://www.mujoco.org/forum/index.php?resources/) used under the following license: | ||
``` | ||
This file is part of MuJoCo. | ||
Copyright 2009-2015 Roboti LLC. | ||
Mujoco :: Advanced physics simulation engine | ||
Source : www.roboti.us | ||
Version : 1.31 | ||
Released : 23Apr16 | ||
Author :: Vikash Kumar | ||
Contacts : [email protected] | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
"""Config for ApeX-DQN on Pong-No_FrameSkip-v4. | ||
- Author: Chris Yoon | ||
- Contact: [email protected] | ||
""" | ||
|
||
from rl_algorithms.common.helper_functions import identity | ||
|
||
agent = dict( | ||
type="ApeX", | ||
hyper_params=dict( | ||
gamma=0.99, | ||
tau=5e-3, | ||
buffer_size=int(2.5e5), # openai baselines: int(1e4) | ||
batch_size=512, # openai baselines: 32 | ||
update_starts_from=int(1e5), # openai baselines: int(1e4) | ||
multiple_update=1, # multiple learning updates | ||
train_freq=1, # in openai baselines, train_freq = 4 | ||
gradient_clip=10.0, # dueling: 10.0 | ||
n_step=5, | ||
w_n_step=1.0, | ||
w_q_reg=0.0, | ||
per_alpha=0.6, # openai baselines: 0.6 | ||
per_beta=0.4, | ||
per_eps=1e-6, | ||
loss_type=dict(type="DQNLoss"), | ||
# Epsilon Greedy | ||
max_epsilon=1.0, | ||
min_epsilon=0.1, # openai baselines: 0.01 | ||
epsilon_decay=1e-6, # openai baselines: 1e-7 / 1e-1 | ||
# grad_cam | ||
grad_cam_layer_list=[ | ||
"backbone.cnn.cnn_0.cnn", | ||
"backbone.cnn.cnn_1.cnn", | ||
"backbone.cnn.cnn_2.cnn", | ||
], | ||
num_workers=4, | ||
local_buffer_max_size=1000, | ||
worker_update_interval=50, | ||
logger_interval=2000, | ||
), | ||
learner_cfg=dict( | ||
type="DQNLearner", | ||
device="cuda", | ||
backbone=dict( | ||
type="CNN", | ||
configs=dict( | ||
input_sizes=[4, 32, 64], | ||
output_sizes=[32, 64, 64], | ||
kernel_sizes=[8, 4, 3], | ||
strides=[4, 2, 1], | ||
paddings=[1, 0, 0], | ||
), | ||
), | ||
head=dict( | ||
type="DuelingMLP", | ||
configs=dict( | ||
use_noisy_net=False, hidden_sizes=[512], output_activation=identity | ||
), | ||
), | ||
optim_cfg=dict( | ||
lr_dqn=0.0003, # dueling: 6.25e-5, openai baselines: 1e-4 | ||
weight_decay=0.0, # this makes saturation in cnn weights | ||
adam_eps=1e-8, # rainbow: 1.5e-4, openai baselines: 1e-8 | ||
), | ||
), | ||
worker_cfg=dict(type="DQNWorker", device="cpu",), | ||
logger_cfg=dict(type="DQNLogger",), | ||
comm_cfg=dict( | ||
learner_buffer_port=6554, | ||
learner_worker_port=6555, | ||
worker_buffer_port=6556, | ||
learner_logger_port=6557, | ||
send_batch_port=6558, | ||
priorities_port=6559, | ||
), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
"""Abstract class for distributed architectures. | ||
- Author: Chris Yoon | ||
- Contact: [email protected] | ||
""" | ||
|
||
from abc import ABC, abstractmethod | ||
|
||
|
||
class Architecture(ABC): | ||
"""Abstract class for distributed architectures""" | ||
|
||
@abstractmethod | ||
def _spawn(self): | ||
pass | ||
|
||
@abstractmethod | ||
def train(self): | ||
pass | ||
|
||
@abstractmethod | ||
def test(self): | ||
pass |
Oops, something went wrong.