6 changes: 6 additions & 0 deletions ci/jenkins_tests/run_rllib_tests.sh
@@ -389,6 +389,12 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
    /ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_loss.py --iters=2

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
    /ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_tf_policy.py --iters=2

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
    /ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_torch_policy.py --iters=2

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
    /ray/ci/suppress_output python /ray/python/ray/rllib/examples/policy_evaluator_custom_workflow.py

33 changes: 29 additions & 4 deletions doc/source/rllib-concepts.rst
@@ -6,7 +6,7 @@ This page describes the internal concepts used to implement algorithms in RLlib.
Policies
--------

Policy classes encapsulate the core numerical components of RL algorithms. This typically includes the policy model that determines actions to take, a trajectory postprocessor for experiences, and a loss function to improve the policy given postprocessed experiences. For a simple example, see the policy gradients `graph definition <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/pg/pg_policy.py>`__.
Policy classes encapsulate the core numerical components of RL algorithms. This typically includes the policy model that determines actions to take, a trajectory postprocessor for experiences, and a loss function to improve the policy given postprocessed experiences. For a simple example, see the policy gradients `policy definition <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/pg/pg_policy.py>`__.

Most interaction with deep learning frameworks is isolated to the `Policy interface <https://github.com/ray-project/ray/blob/master/python/ray/rllib/policy/policy.py>`__, allowing RLlib to support multiple frameworks. To simplify the definition of policies, RLlib includes `TensorFlow <#building-policies-in-tensorflow>`__ and `PyTorch-specific <#building-policies-in-pytorch>`__ templates. You can also write your own from scratch. Here is an example:
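
The sketch below gives the rough shape of such a from-scratch policy; the exact ``compute_actions`` signature and the ``Policy`` constructor arguments are assumptions here, and this toy policy simply acts at random without learning:

.. code-block:: python

    from ray.rllib.policy.policy import Policy

    class MyRandomPolicy(Policy):
        """Illustrative from-scratch policy: random actions, no learning."""

        def __init__(self, observation_space, action_space, config):
            Policy.__init__(self, observation_space, action_space, config)
            self.w = 1.0  # example parameter, so weight syncing has something to ship

        def compute_actions(self,
                            obs_batch,
                            state_batches,
                            prev_action_batch=None,
                            prev_reward_batch=None,
                            info_batch=None,
                            episodes=None,
                            **kwargs):
            # Return a batch of actions, RNN state outputs, and extra fetches.
            return [self.action_space.sample() for _ in obs_batch], [], {}

        def learn_on_batch(self, samples):
            # Improve the policy here given a postprocessed sample batch.
            return {}  # learner stats

        def get_weights(self):
            return {"w": self.w}

        def set_weights(self, weights):
            self.w = weights["w"]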

@@ -148,7 +148,7 @@ We can create a `Trainer <#trainers>`__ and try running this policy on a toy env
    tune.run(MyTrainer, config={"env": "CartPole-v0", "num_workers": 2})


If you run the above snippet, you'll probably notice that CartPole doesn't learn so well:
If you run the above snippet `(runnable file here) <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_tf_policy.py>`__, you'll probably notice that CartPole doesn't learn so well:

.. code-block:: bash

@@ -208,7 +208,7 @@ In the above section you saw how to compose a simple policy gradient algorithm w

Besides some boilerplate for defining the PPO configuration and some warnings, there are two important arguments to take note of here: ``make_policy_optimizer=choose_policy_optimizer``, and ``after_optimizer_step=update_kl``.

The ``choose_policy_optimizer`` function chooses which `Policy Optimizer <#policy-optimization>`__ to use for distributed training. You can think of these policy optimizers as coordinating the distributed workflow needed to improve the policy. Depending on the trainer config, PPO can switch between a simple synchronous optimizer (the default), or a multi-GPU optimizer that implements minibatch SGD:
The ``choose_policy_optimizer`` function chooses which `Policy Optimizer <#policy-optimization>`__ to use for distributed training. You can think of these policy optimizers as coordinating the distributed workflow needed to improve the policy. Depending on the trainer config, PPO can switch between a simple synchronous optimizer and a multi-GPU optimizer that implements minibatch SGD (the default):

.. code-block:: python

@@ -349,7 +349,27 @@ Finally, note that you do not have to use ``build_tf_policy`` to define a Tensor
Building Policies in PyTorch
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Defining a policy in PyTorch is quite similar to that for TensorFlow (and the process of defining a trainer given a Torch policy is exactly the same). Building on the TF examples above, let's look at how the `A3C torch policy <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/a3c/a3c_torch_policy.py>`__ is defined:
Defining a policy in PyTorch is quite similar to defining one for TensorFlow (and the process of defining a trainer given a Torch policy is exactly the same). Here's a simple example of a trivial torch policy `(runnable file here) <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/custom_torch_policy.py>`__:

.. code-block:: python

    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.policy.torch_policy_template import build_torch_policy

    def policy_gradient_loss(policy, batch_tensors):
        logits, _, values, _ = policy.model({
            SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
        }, [])
        action_dist = policy.dist_class(logits)
        log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS])
        return -batch_tensors[SampleBatch.REWARDS].dot(log_probs)

    # <class 'ray.rllib.policy.torch_policy_template.MyTorchPolicy'>
    MyTorchPolicy = build_torch_policy(
        name="MyTorchPolicy",
        loss_fn=policy_gradient_loss)
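
As in the TF example, the generated policy class can then be passed to ``build_trainer`` and run with Tune; a minimal sketch mirroring the runnable example file linked above:

.. code-block:: python

    from ray import tune
    from ray.rllib.agents.trainer_template import build_trainer

    MyTrainer = build_trainer(
        name="MyCustomTrainer",
        default_policy=MyTorchPolicy)

    tune.run(MyTrainer, config={"env": "CartPole-v0", "num_workers": 2})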

Now, building on the TF examples above, let's look at how the `A3C torch policy <https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/a3c/a3c_torch_policy.py>`__ is defined:

.. code-block:: python

@@ -423,6 +443,11 @@ You can find the full policy definition in `a3c_torch_policy.py <https://github.

In summary, the main difference between the PyTorch and TensorFlow policy builder functions is that the TF loss and stats functions are built symbolically when the policy is initialized, whereas for PyTorch these functions are called imperatively each time they are used.
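
To make the distinction concrete, here is a framework-level illustration (not RLlib code; it assumes TF graph mode, which RLlib's TF policies use):

.. code-block:: python

    import tensorflow as tf
    import torch

    # TF graph mode: building the loss creates a symbolic op once; no numbers
    # are produced until it is evaluated in a session.
    obs = tf.placeholder(tf.float32, [None])
    tf_loss = -tf.reduce_mean(obs)

    # PyTorch: the equivalent expression executes immediately and returns a
    # concrete value, and is re-run on every call.
    t = torch.tensor([1.0, 2.0, 3.0])
    torch_loss = -t.mean()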

Extending Existing Policies
~~~~~~~~~~~~~~~~~~~~~~~~~~~

(todo)

Policy Evaluation
-----------------

2 changes: 2 additions & 0 deletions doc/source/rllib.rst
@@ -103,6 +103,8 @@ Concepts and Building Custom Algorithms

- `Building Policies in PyTorch <rllib-concepts.html#building-policies-in-pytorch>`__

- `Extending Existing Policies <rllib-concepts.html#extending-existing-policies>`__

* `Policy Evaluation <rllib-concepts.html#policy-evaluation>`__
* `Policy Optimization <rllib-concepts.html#policy-optimization>`__
* `Trainers <rllib-concepts.html#trainers>`__
2 changes: 2 additions & 0 deletions python/ray/rllib/agents/trainer.py
@@ -100,6 +100,8 @@
"clip_actions": True,
# Whether to use rllib or deepmind preprocessors by default
"preprocessor_pref": "deepmind",
# The default learning rate
"lr": 0.0001,

# === Evaluation ===
# Evaluate with every `evaluation_interval` training iterations.
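
For illustration, the new "lr" key behaves like any other common config value and can be overridden per experiment; a sketch assuming the built-in PG trainer:

    from ray import tune

    tune.run("PG", config={"env": "CartPole-v0", "lr": 0.0005})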
47 changes: 47 additions & 0 deletions python/ray/rllib/examples/custom_tf_policy.py
@@ -0,0 +1,47 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse

import ray
from ray import tune
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils import try_import_tf

tf = try_import_tf()

parser = argparse.ArgumentParser()
parser.add_argument("--iters", type=int, default=200)


def policy_gradient_loss(policy, batch_tensors):
    # Vanilla policy gradient: maximize the reward-weighted log-probability
    # of the actions that were taken (no discounting or baseline).
    actions = batch_tensors[SampleBatch.ACTIONS]
    rewards = batch_tensors[SampleBatch.REWARDS]
    return -tf.reduce_mean(policy.action_dist.logp(actions) * rewards)


# <class 'ray.rllib.policy.tf_policy_template.MyTFPolicy'>
MyTFPolicy = build_tf_policy(
    name="MyTFPolicy",
    loss_fn=policy_gradient_loss,
)

# <class 'ray.rllib.agents.trainer_template.MyCustomTrainer'>
MyTrainer = build_trainer(
    name="MyCustomTrainer",
    default_policy=MyTFPolicy,
)

if __name__ == "__main__":
    ray.init()
    args = parser.parse_args()
    tune.run(
        MyTrainer,
        stop={"training_iteration": args.iters},
        config={
            "env": "CartPole-v0",
            "num_workers": 2,
        })
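
For illustration, the generated trainer class can also be used directly instead of going through Tune; a sketch assuming the standard Trainer constructor:

    import ray
    ray.init()
    trainer = MyTrainer(env="CartPole-v0", config={"num_workers": 2})
    print(trainer.train())  # one training iteration; returns a result dict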
45 changes: 45 additions & 0 deletions python/ray/rllib/examples/custom_torch_policy.py
@@ -0,0 +1,45 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse

import ray
from ray import tune
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy

parser = argparse.ArgumentParser()
parser.add_argument("--iters", type=int, default=200)


def policy_gradient_loss(policy, batch_tensors):
    # Forward pass through the policy model to get action logits, then return
    # a reward-weighted negative log-probability (vanilla policy gradient) loss.
    logits, _, values, _ = policy.model({
        SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
    }, [])
    action_dist = policy.dist_class(logits)
    log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS])
    return -batch_tensors[SampleBatch.REWARDS].dot(log_probs)


# <class 'ray.rllib.policy.torch_policy_template.MyTorchPolicy'>
MyTorchPolicy = build_torch_policy(
    name="MyTorchPolicy", loss_fn=policy_gradient_loss)

# <class 'ray.rllib.agents.trainer_template.MyCustomTrainer'>
MyTrainer = build_trainer(
    name="MyCustomTrainer",
    default_policy=MyTorchPolicy,
)

if __name__ == "__main__":
    ray.init()
    args = parser.parse_args()
    tune.run(
        MyTrainer,
        stop={"training_iteration": args.iters},
        config={
            "env": "CartPole-v0",
            "num_workers": 2,
        })