Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib Contrib] APEX DQN #36591

Merged
merged 11 commits into from
Oct 4, 2023
72 changes: 52 additions & 20 deletions .buildkite/pipeline.ml.yml
Original file line number Diff line number Diff line change
Expand Up @@ -463,23 +463,28 @@
doc/...


- label: ":exploding_death_star: RLlib Contrib: A3C Tests"
- label: ":exploding_death_star: RLlib Contrib: A2C Tests"
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
- (cd rllib_contrib/a3c && pip install -r requirements.txt && pip install -e .)
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/a2c && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/a3c/tests/test_a3c.py
- pytest rllib_contrib/a2c/tests/
- python rllib_contrib/a2c/examples/a2c_cartpole_v1.py --run-as-test

- label: ":exploding_death_star: RLlib Contrib: MAML Tests"
- label: ":exploding_death_star: RLlib Contrib: A3C Tests"
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT

- source /root/.bashrc
- (cd rllib_contrib/maml && pip install -r requirements.txt && pip install -e .)
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/a3c && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/maml/tests/test_maml.py
- pytest rllib_contrib/a3c/tests/test_a3c.py

- label: ":exploding_death_star: RLlib Contrib: Alpha Star Tests"
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
Expand All @@ -490,29 +495,56 @@
- pytest rllib_contrib/alpha_star/tests/
- python rllib_contrib/alpha_star/examples/multi-agent-cartpole-alpha-star.py --run-as-test

- label: ":exploding_death_star: RLlib Contrib: A2C Tests"
- label: ":exploding_death_star: RLlib Contrib: APEX DQN Tests"
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
- (cd rllib_contrib/a2c && pip install -r requirements.txt && pip install -e .)
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/apex_dqn && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/a2c/tests/
- python rllib_contrib/a2c/examples/a2c_cartpole_v1.py --run-as-test
- pytest rllib_contrib/apex_dqn/tests/
- python rllib_contrib/apex_dqn/examples/apex_dqn_cartpole_v1.py --run-as-test

- label: ":exploding_death_star: RLlib Contrib: DDPG Tests"
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/ddpg && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/ddpg/tests/
- python rllib_contrib/ddpg/examples/ddpg_pendulum_v1.py --run-as-test

- label: ":exploding_death_star: RLlib Contrib: R2D2 Tests"
- label: ":exploding_death_star: RLlib Contrib: MAML Tests"
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
- (cd rllib_contrib/r2d2 && pip install -r requirements.txt && pip install -e .)

# Install mujoco necessary for the testing environments
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- sudo apt install libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf -y
# Download and unpack MuJoCo only when /root/.mujoco does not exist yet. Written as a single
# `if` so the item is a plain command string and exits 0 when the directory is already there.
- if [ ! -d "/root/.mujoco" ]; then mkdir -p /root/.mujoco && wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz && mv mujoco210-linux-x86_64.tar.gz /root/.mujoco/. && (cd /root/.mujoco && tar -xf /root/.mujoco/mujoco210-linux-x86_64.tar.gz); fi
- export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/.mujoco/mujoco210/bin
- (cd rllib_contrib/maml && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/r2d2/tests/
- python rllib_contrib/r2d2/examples/r2d2_stateless_cartpole.py --run-as-test
- pytest rllib_contrib/maml/tests/test_maml.py

- label: ":exploding_death_star: RLlib Contrib: DDPG Tests"
- label: ":exploding_death_star: RLlib Contrib: R2D2 Tests"
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
- (cd rllib_contrib/ddpg && pip install -r requirements.txt && pip install -e .)
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/r2d2 && pip install -r requirements.txt && pip install -e ".[development]")
- ./ci/env/env_info.sh
- pytest rllib_contrib/ddpg/tests/
- python rllib_contrib/ddpg/examples/ddpg_pendulum_v1.py --run-as-test
- pytest rllib_contrib/r2d2/tests/
- python rllib_contrib/r2d2/examples/r2d2_stateless_cartpole.py --run-as-test
6 changes: 3 additions & 3 deletions rllib_contrib/TOC.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@


* [A3C](./a3c)
* [MAML](./maml)
* [Alpha Star](./alpha_star)
* [A2C](./a2c)
* [Alpha Star](./alpha_star)
* [APEX DQN](./apex_dqn/)
* [DDPG](./ddpg)
* [MAML](./maml)
* [R2D2](./r2d2)




# Example Use-cases

* [Using TF-GNN for encoding graph spaces in RLlib using Tensorflow](https://github.com/kk-55/tf-gnn-example-for-rllib)
2 changes: 1 addition & 1 deletion rllib_contrib/a3c/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ version = "0.1.0"
description = ""
readme = "README.md"
requires-python = ">=3.7, <3.11"
dependencies = ["gym[accept-rom-license]", "gymnasium[mujoco]==0.26.3", "higher", "ray[rllib]==2.3.1"]
dependencies = ["gym[accept-rom-license]", "gymnasium[mujoco]==0.26.3", "ray[rllib]==2.3.1"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to our requirements, higher is used for MAML on PyTorch. Not sure why it went here in the first place.
I have no idea what the library does, just wanna point out this irregularity.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It slipped through


[project.optional-dependencies]
development = ["pytest>=7.2.2", "pre-commit==2.21.0", "tensorflow==2.11.0", "torch==1.12.0"]
19 changes: 19 additions & 0 deletions rllib_contrib/apex_dqn/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# APEX DQN (Distributed Prioritized Experience Replay)

[APEX DQN](https://arxiv.org/pdf/1803.00933.pdf) (Distributed Prioritized Experience Replay) is an algorithm that decouples
acting from learning. Actors interact with their own instances of the environment by selecting actions according
to a shared neural network, and accumulate the resulting experience in a shared experience replay memory; the learner replays samples of experience and updates the neural network. The architecture relies on prioritized experience replay to
focus only on the most significant data generated by the actors.

## Installation

```
conda create -n rllib-apex-dqn python=3.10
conda activate rllib-apex-dqn
pip install -r requirements.txt
pip install -e '.[development]'
```

## Usage

[APEX-DQN Example](examples/apex_dqn_cartpole_v1.py)
58 changes: 58 additions & 0 deletions rllib_contrib/apex_dqn/examples/apex_dqn_cartpole_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import argparse
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why "v1"?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we are using the cartpole v1 environment


from rllib_apex_dqn.apex_dqn import ApexDQN, ApexDQNConfig

import ray
from ray import air, tune
from ray.rllib.utils.test_utils import check_learning_achieved


def get_cli_args():
    """Create the CLI parser and return the parsed arguments.

    Returns:
        argparse.Namespace: The parsed arguments. ``run_as_test`` is True
        when ``--run-as-test`` was given on the command line, else False.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--run-as-test",
        action="store_true",
        default=False,
        help="Check that the stop reward was reached once training finishes.",
    )
    args = parser.parse_args()
    # Echo the effective arguments so they show up in the run's log output.
    print(f"Running with following CLI args: {args}")
    return args


if __name__ == "__main__":
    cli_args = get_cli_args()

    ray.init()

    # Mean episode reward used both as a stop criterion and as the pass
    # threshold when the script is run as a test.
    stop_reward = 150.0

    # Keyword arguments for the training section of the algorithm config.
    training_kwargs = dict(
        replay_buffer_config={
            "type": "MultiAgentPrioritizedReplayBuffer",
            "capacity": 20000,
        },
        num_steps_sampled_before_learning_starts=1000,
        optimizer={"num_replay_buffer_shards": 2},
        target_network_update_freq=500,
        training_intensity=4,
    )

    apex_config = (
        ApexDQNConfig()
        .rollouts(num_rollout_workers=3)
        .environment("CartPole-v1")
        .training(**training_kwargs)
        .resources(num_gpus=0)
        .reporting(min_sample_timesteps_per_iteration=1000, min_time_s_per_iteration=5)
    )

    # The run stops as soon as either criterion is met.
    stop_criteria = {
        "sampler_results/episode_reward_mean": stop_reward,
        "timesteps_total": 250000,
    }
    run_config = air.RunConfig(
        stop=stop_criteria,
        failure_config=air.FailureConfig(fail_fast="raise"),
    )

    results = tune.Tuner(
        ApexDQN,
        param_space=apex_config.to_dict(),
        run_config=run_config,
    ).fit()

    # Only test runs verify that the stop reward was actually reached.
    if cli_args.run_as_test:
        check_learning_achieved(results, stop_reward)
18 changes: 18 additions & 0 deletions rllib_contrib/apex_dqn/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Build configuration: the package is built with setuptools.
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

# Package discovery looks under the src/ directory.
[tool.setuptools.packages.find]
where = ["src"]

# Distribution metadata for the APEX DQN contrib package.
[project]
name = "rllib-apex-dqn"
authors = [{name = "Anyscale Inc."}]
version = "0.1.0"
description = ""
readme = "README.md"
requires-python = ">=3.7, <3.11"
# Runtime dependencies; ray is pinned to one exact release.
dependencies = ["gymnasium[atari]", "ray[rllib]==2.5.0"]

# Extras selected with `pip install -e ".[development]"`.
[project.optional-dependencies]
development = ["pytest>=7.2.2", "pre-commit==2.21.0", "tensorflow==2.11.0", "torch==1.12.0"]
2 changes: 2 additions & 0 deletions rllib_contrib/apex_dqn/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
tensorflow==2.11.0
torch==1.12.0
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from rllib_apex_dqn.apex_dqn.apex_dqn import ApexDQN, ApexDQNConfig

from ray.tune.registry import register_trainable

# Names exported by `from <package> import *`.
__all__ = ["ApexDQNConfig", "ApexDQN"]

# Import-time side effect: register ApexDQN in Ray Tune's registry under the
# name "rllib-contrib-apex-dqn".
register_trainable("rllib-contrib-apex-dqn", ApexDQN)
Loading
Loading