Skip to content

Commit

Permalink
[RLlib Contrib] APEX DQN. (#36591)
Browse files Browse the repository at this point in the history
  • Loading branch information
avnishn authored Oct 4, 2023
1 parent 6606cc8 commit 67593a9
Show file tree
Hide file tree
Showing 10 changed files with 1,077 additions and 24 deletions.
72 changes: 52 additions & 20 deletions .buildkite/pipeline.ml.yml
Original file line number Diff line number Diff line change
Expand Up @@ -463,23 +463,28 @@
doc/...


- label: ":exploding_death_star: RLlib Contrib: A3C Tests"
- label: ":exploding_death_star: RLlib Contrib: A2C Tests"
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
- (cd rllib_contrib/a3c && pip install -r requirements.txt && pip install -e .)
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/a2c && pip install -r requirements.txt && pip install -e ".[development"])
- ./ci/env/env_info.sh
- pytest rllib_contrib/a3c/tests/test_a3c.py
- pytest rllib_contrib/a2c/tests/
- python rllib_contrib/a2c/examples/a2c_cartpole_v1.py --run-as-test

- label: ":exploding_death_star: RLlib Contrib: MAML Tests"
- label: ":exploding_death_star: RLlib Contrib: A3C Tests"
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT

- source /root/.bashrc
- (cd rllib_contrib/maml && pip install -r requirements.txt && pip install -e .)
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/a3c && pip install -r requirements.txt && pip install -e ".[development"])
- ./ci/env/env_info.sh
- pytest rllib_contrib/maml/tests/test_maml.py
- pytest rllib_contrib/a3c/tests/test_a3c.py

- label: ":exploding_death_star: RLlib Contrib: Alpha Star Tests"
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
Expand All @@ -490,29 +495,56 @@
- pytest rllib_contrib/alpha_star/tests/
- python rllib_contrib/alpha_star/examples/multi-agent-cartpole-alpha-star.py --run-as-test

- label: ":exploding_death_star: RLlib Contrib: A2C Tests"
- label: ":exploding_death_star: RLlib Contrib: APEX DQN Tests"
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
- (cd rllib_contrib/a2c && pip install -r requirements.txt && pip install -e .)
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/apex_dqn && pip install -r requirements.txt && pip install -e ".[development"])
- ./ci/env/env_info.sh
- pytest rllib_contrib/a2c/tests/
- python rllib_contrib/a2c/examples/a2c_cartpole_v1.py --run-as-test
- pytest rllib_contrib/apex_dqn/tests/
- python rllib_contrib/apex_dqn/examples/apex_dqn_cartpole_v1.py --run-as-test

- label: ":exploding_death_star: RLlib Contrib: DDPG Tests"
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/ddpg && pip install -r requirements.txt && pip install -e ".[development"])
- ./ci/env/env_info.sh
- pytest rllib_contrib/ddpg/tests/
- python rllib_contrib/ddpg/examples/ddpg_pendulum_v1.py --run-as-test

- label: ":exploding_death_star: RLlib Contrib: R2D2 Tests"
- label: ":exploding_death_star: RLlib Contrib: MAML Tests"
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
- (cd rllib_contrib/r2d2 && pip install -r requirements.txt && pip install -e .)

# Install mujoco necessary for the testing environments
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- sudo apt install libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf -y
- [ ! -d "/root/.mujoco" ] && mkdir -p /root/.mujoco && wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz \
&& mv mujoco210-linux-x86_64.tar.gz /root/.mujoco/. && \
(cd /root/.mujoco && tar -xf /root/.mujoco/mujoco210-linux-x86_64.tar.gz)
- export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/.mujoco/mujoco210/bin
- (cd rllib_contrib/maml && pip install -r requirements.txt && pip install -e ".[development"])
- ./ci/env/env_info.sh
- pytest rllib_contrib/r2d2/tests/
- python rllib_contrib/r2d2/examples/r2d2_stateless_cartpole.py --run-as-test
- pytest rllib_contrib/maml/tests/test_maml.py

- label: ":exploding_death_star: RLlib Contrib: DDPG Tests"
- label: ":exploding_death_star: RLlib Contrib: R2D2 Tests"
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
- (cd rllib_contrib/ddpg && pip install -r requirements.txt && pip install -e .)
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/r2d2 && pip install -r requirements.txt && pip install -e ".[development"])
- ./ci/env/env_info.sh
- pytest rllib_contrib/ddpg/tests/
- python rllib_contrib/ddpg/examples/ddpg_pendulum_v1.py --run-as-test
- pytest rllib_contrib/r2d2/tests/
- python rllib_contrib/r2d2/examples/r2d2_stateless_cartpole.py --run-as-test
6 changes: 3 additions & 3 deletions rllib_contrib/TOC.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@


* [A3C](./a3c)
* [MAML](./maml)
* [Alpha Star](./alpha_star)
* [A2C](./a2c)
* [Alpha Star](./alpha_star)
* [APEX DQN](./apex_dqn/)
* [DDPG](./ddpg)
* [MAML](./maml)
* [R2D2](./r2d2)




# Example Use-cases

* [Using TF-GNN for encoding graph spaces in RLlib using Tensorflow](https://github.com/kk-55/tf-gnn-example-for-rllib)
2 changes: 1 addition & 1 deletion rllib_contrib/a3c/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ version = "0.1.0"
description = ""
readme = "README.md"
requires-python = ">=3.7, <3.11"
dependencies = ["gym[accept-rom-license]", "gymnasium[mujoco]==0.26.3", "higher", "ray[rllib]==2.3.1"]
dependencies = ["gym[accept-rom-license]", "gymnasium[mujoco]==0.26.3", "ray[rllib]==2.3.1"]

[project.optional-dependencies]
development = ["pytest>=7.2.2", "pre-commit==2.21.0", "tensorflow==2.11.0", "torch==1.12.0"]
19 changes: 19 additions & 0 deletions rllib_contrib/apex_dqn/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# APEX DQN (Distributed Prioritized Experience Replay)

[APEX DQN](https://arxiv.org/pdf/1803.00933.pdf) (Distributed Prioritized Experience Replay) is an algorithm that decouples
acting (sampling) from learning. Actors interact with their own instances of the environment by selecting actions according
to a shared neural network, and accumulate the resulting experience in a shared experience replay memory; the learner replays samples of experience and updates the neural network. The architecture relies on prioritized experience replay to
focus only on the most significant data generated by the actors.

## Installation

```
conda create -n rllib-apex-dqn python=3.10
conda activate rllib-apex-dqn
pip install -r requirements.txt
pip install -e '.[development]'
```

## Usage

[APEX-DQN Example](examples/apex_dqn_cartpole_v1.py)
58 changes: 58 additions & 0 deletions rllib_contrib/apex_dqn/examples/apex_dqn_cartpole_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import argparse

from rllib_apex_dqn.apex_dqn import ApexDQN, ApexDQNConfig

import ray
from ray import air, tune
from ray.rllib.utils.test_utils import check_learning_achieved


def get_cli_args(argv=None):
    """Create the CLI parser and return the parsed arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the process command line, which is
            the behavior existing callers rely on. Passing an explicit list
            lets the parser be exercised from tests or other code.

    Returns:
        The ``argparse.Namespace`` with a boolean ``run_as_test`` attribute,
        ``True`` only when ``--run-as-test`` was supplied.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--run-as-test",
        action="store_true",
        default=False,
        help="After training, check that the stop reward was reached.",
    )
    args = parser.parse_args(argv)
    print(f"Running with following CLI args: {args}")
    return args


if __name__ == "__main__":
    # Script entry point: train ApexDQN on CartPole-v1 through Tune and,
    # when --run-as-test is given, verify the reward threshold was reached.
    args = get_cli_args()

    ray.init()

    # Prioritized replay buffer settings handed to the training config.
    replay_settings = {
        "type": "MultiAgentPrioritizedReplayBuffer",
        "capacity": 20000,
    }

    # Build the algorithm config one section at a time, in the same order
    # as the builder methods would be chained.
    config = ApexDQNConfig()
    config = config.rollouts(num_rollout_workers=3)
    config = config.environment("CartPole-v1")
    config = config.training(
        replay_buffer_config=replay_settings,
        num_steps_sampled_before_learning_starts=1000,
        optimizer={"num_replay_buffer_shards": 2},
        target_network_update_freq=500,
        training_intensity=4,
    )
    config = config.resources(num_gpus=0)
    config = config.reporting(
        min_sample_timesteps_per_iteration=1000,
        min_time_s_per_iteration=5,
    )

    stop_reward = 150.0

    # Stopping criteria passed to Tune: a mean-reward target and a cap on
    # the total number of timesteps.
    stop_criteria = {
        "sampler_results/episode_reward_mean": stop_reward,
        "timesteps_total": 250000,
    }
    run_settings = air.RunConfig(
        stop=stop_criteria,
        failure_config=air.FailureConfig(fail_fast="raise"),
    )

    tuner = tune.Tuner(
        ApexDQN,
        param_space=config.to_dict(),
        run_config=run_settings,
    )
    results = tuner.fit()

    # Only enforce the learning check when run as a CI-style test.
    if args.run_as_test:
        check_learning_achieved(results, stop_reward)
18 changes: 18 additions & 0 deletions rllib_contrib/apex_dqn/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = ["src"]

[project]
name = "rllib-apex-dqn"
authors = [{name = "Anyscale Inc."}]
version = "0.1.0"
description = ""
readme = "README.md"
requires-python = ">=3.7, <3.11"
dependencies = ["gymnasium[atari]", "ray[rllib]==2.5.0"]

[project.optional-dependencies]
development = ["pytest>=7.2.2", "pre-commit==2.21.0", "tensorflow==2.11.0", "torch==1.12.0"]
2 changes: 2 additions & 0 deletions rllib_contrib/apex_dqn/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
tensorflow==2.11.0
torch==1.12.0
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from rllib_apex_dqn.apex_dqn.apex_dqn import ApexDQN, ApexDQNConfig

from ray.tune.registry import register_trainable

# Names exported by a star-import of this package.
__all__ = ["ApexDQNConfig", "ApexDQN"]

# Register the ApexDQN class with Ray Tune's registry under the name
# "rllib-contrib-apex-dqn". This runs as a side effect of importing the package.
# NOTE(review): presumably this lets Tune resolve the algorithm from that
# string name — confirm against the Tune registry docs.
register_trainable("rllib-contrib-apex-dqn", ApexDQN)
Loading

0 comments on commit 67593a9

Please sign in to comment.