-
Notifications
You must be signed in to change notification settings - Fork 5.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
1,077 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# APEX DQN (Distributed Prioritized Experience Replay) | ||
|
||
[APEX DQN](https://arxiv.org/pdf/1803.00933.pdf) Distributed Prioritized Experience Replay is an algorithm that decouples | ||
active learning from sampling. Actors interact with their own instances of the environment by selecting actions according | ||
to a shared neural network, and accumulate the resulting experience in a shared experience replay memory; the learner replays samples of experience and updates the neural network. The architecture relies on prioritized experience replay to | ||
focus only on the most significant data generated by the actors. | ||
|
||
## Installation | ||
|
||
``` | ||
conda create -n rllib-apex-dqn python=3.10 | ||
conda activate rllib-apex-dqn | ||
pip install -r requirements.txt | ||
pip install -e '.[development]' | ||
``` | ||
|
||
## Usage | ||
|
||
[APEX-DQN Example]() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import argparse | ||
|
||
from rllib_apex_dqn.apex_dqn import ApexDQN, ApexDQNConfig | ||
|
||
import ray | ||
from ray import air, tune | ||
from ray.rllib.utils.test_utils import check_learning_achieved | ||
|
||
|
||
def get_cli_args(): | ||
"""Create CLI parser and return parsed arguments""" | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--run-as-test", action="store_true", default=False) | ||
args = parser.parse_args() | ||
print(f"Running with following CLI args: {args}") | ||
return args | ||
|
||
|
||
if __name__ == "__main__": | ||
args = get_cli_args() | ||
|
||
ray.init() | ||
|
||
config = ( | ||
ApexDQNConfig() | ||
.rollouts(num_rollout_workers=3) | ||
.environment("CartPole-v1") | ||
.training( | ||
replay_buffer_config={ | ||
"type": "MultiAgentPrioritizedReplayBuffer", | ||
"capacity": 20000, | ||
}, | ||
num_steps_sampled_before_learning_starts=1000, | ||
optimizer={"num_replay_buffer_shards": 2}, | ||
target_network_update_freq=500, | ||
training_intensity=4, | ||
) | ||
.resources(num_gpus=0) | ||
.reporting(min_sample_timesteps_per_iteration=1000, min_time_s_per_iteration=5) | ||
) | ||
|
||
stop_reward = 150.0 | ||
|
||
tuner = tune.Tuner( | ||
ApexDQN, | ||
param_space=config.to_dict(), | ||
run_config=air.RunConfig( | ||
stop={ | ||
"sampler_results/episode_reward_mean": stop_reward, | ||
"timesteps_total": 250000, | ||
}, | ||
failure_config=air.FailureConfig(fail_fast="raise"), | ||
), | ||
) | ||
results = tuner.fit() | ||
|
||
if args.run_as_test: | ||
check_learning_achieved(results, stop_reward) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
[build-system] | ||
requires = ["setuptools>=61.0"] | ||
build-backend = "setuptools.build_meta" | ||
|
||
[tool.setuptools.packages.find] | ||
where = ["src"] | ||
|
||
[project] | ||
name = "rllib-apex-dqn" | ||
authors = [{name = "Anyscale Inc."}] | ||
version = "0.1.0" | ||
description = "" | ||
readme = "README.md" | ||
requires-python = ">=3.7, <3.11" | ||
dependencies = ["gymnasium[atari]", "ray[rllib]==2.5.0"] | ||
|
||
[project.optional-dependencies] | ||
development = ["pytest>=7.2.2", "pre-commit==2.21.0", "tensorflow==2.11.0", "torch==1.12.0"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
tensorflow==2.11.0 | ||
torch==1.12.0 |
7 changes: 7 additions & 0 deletions
7
rllib_contrib/apex_dqn/src/rllib_apex_dqn/apex_dqn/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from rllib_apex_dqn.apex_dqn.apex_dqn import ApexDQN, ApexDQNConfig | ||
|
||
from ray.tune.registry import register_trainable | ||
|
||
__all__ = ["ApexDQNConfig", "ApexDQN"] | ||
|
||
register_trainable("rllib-contrib-apex-dqn", ApexDQN) |
Oops, something went wrong.