10 changes: 10 additions & 0 deletions ci/jenkins_tests/run_rllib_tests.sh
@@ -392,6 +392,16 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/rollout_worker_custom_workflow.py

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/eager_execution.py --iters=2

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output /ray/python/ray/rllib/train.py \
--env CartPole-v0 \
--run PPO \
--stop '{"training_iteration": 1}' \
--config '{"use_eager": true, "simple_optimizer": true}'

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_tf_policy.py --iters=2

31 changes: 31 additions & 0 deletions doc/source/rllib-concepts.rst
@@ -346,6 +346,37 @@ In PPO we run ``setup_mixins`` before the loss function is called (i.e., ``befor

Finally, note that you do not have to use ``build_tf_policy`` to define a TensorFlow policy. You can alternatively subclass ``Policy``, ``TFPolicy``, or ``DynamicTFPolicy`` as convenient.

Building Policies in TensorFlow Eager
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

While RLlib runs all TF operations in graph mode, you can still leverage TensorFlow eager using `tf.py_function <https://www.tensorflow.org/api_docs/python/tf/py_function>`__. However, note that eager and non-eager tensors cannot be mixed within the ``py_function``. Here's an example of embedding eager execution within a policy loss function:

.. code-block:: python

def eager_loss(policy, batch_tensors):
"""Example of using embedded eager execution in a custom loss.

Here `compute_penalty` prints the actions and rewards for debugging, and
also computes a (dummy) penalty term to add to the loss.
"""

def compute_penalty(actions, rewards):
penalty = tf.reduce_mean(tf.cast(actions, tf.float32))
if random.random() > 0.9:
print("The eagerly computed penalty is", penalty, actions, rewards)
return penalty

actions = batch_tensors[SampleBatch.ACTIONS]
rewards = batch_tensors[SampleBatch.REWARDS]
penalty = tf.py_function(
compute_penalty, [actions, rewards], Tout=tf.float32)
Contributor: So is the trick mainly to run py_functions in the static mode?

Contributor Author: Yep. Right now this is limited to loss_fn, but with ModelV2 this can also include the model forward pass.

return penalty - tf.reduce_mean(policy.action_dist.logp(actions) * rewards)

You can find a runnable file for the above eager execution example `here <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/eager_execution.py>`__.

There is also experimental support for running the entire loss function in eager mode. This can be enabled by setting ``use_eager: True`` in the trainer config, e.g., ``rllib train --env=CartPole-v0 --run=PPO --config='{"use_eager": true, "simple_optimizer": true}'``. However, this currently works for only a couple of algorithms.
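
A minimal sketch of the equivalent Python / Tune usage (assuming the same PPO setup as the command above; this is illustrative, not additional documented API):

.. code-block:: python

    import ray
    from ray import tune

    ray.init()
    # Equivalent to the ``rllib train`` command above. ``simple_optimizer`` is
    # needed because the multi-GPU optimizer does not support the eager loss.
    tune.run(
        "PPO",
        stop={"training_iteration": 1},
        config={
            "env": "CartPole-v0",
            "use_eager": True,
            "simple_optimizer": True,
        },
    )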

Building Policies in PyTorch
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

2 changes: 2 additions & 0 deletions doc/source/rllib-examples.rst
@@ -38,6 +38,8 @@ Custom Envs and Models
Example of adding batch norm layers to a custom model.
- `Parametric actions <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/parametric_action_cartpole.py>`__:
Example of how to handle variable-length or parametric action spaces.
- `Eager execution <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/eager_execution.py>`__:
Example of how to leverage TensorFlow eager to simplify debugging and design of custom models and policies.

Serving and Offline
-------------------
7 changes: 7 additions & 0 deletions doc/source/rllib-training.rst
@@ -367,6 +367,13 @@ The ``"monitor": true`` config can be used to save Gym episode videos to the res
openaigym.video.0.31403.video000000.meta.json
openaigym.video.0.31403.video000000.mp4

TensorFlow Eager
~~~~~~~~~~~~~~~~

While RLlib uses TF graph mode for all computations, you can still leverage TF eager to inspect the intermediate state of computations using `tf.py_function <https://www.tensorflow.org/api_docs/python/tf/py_function>`__. Here's an example of using eager mode in `a custom RLlib model and loss <https://github.com/ray-project/ray/blob/master/python/ray/rllib/examples/eager_execution.py>`__.
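
As a minimal sketch (the ``debug_print`` helper and the tensor it wraps are illustrative, not part of the RLlib API), you can route any intermediate graph tensor through ``tf.py_function`` to print it eagerly:

.. code-block:: python

    import tensorflow as tf

    def debug_print(t):
        # This body executes eagerly even when called from graph-mode code.
        assert tf.executing_eagerly()
        print("intermediate mean value:", tf.reduce_mean(t))
        return t

    # Inside a graph-mode model or loss, wrap the tensor of interest:
    # features = tf.py_function(debug_print, [features], tf.float32)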

There is also experimental support for running the entire loss function in eager mode. This can be enabled by setting ``use_eager: True`` in the trainer config, e.g., ``rllib train --env=CartPole-v0 --run=PPO --config='{"use_eager": true, "simple_optimizer": true}'``. However, this currently works for only a couple of algorithms.
Contributor: Can you specify "a couple"?

Contributor Author: Going to fill this out in ModelV2 -- hopefully by that time we will support most algos.

Episode Traces
~~~~~~~~~~~~~~

2 changes: 2 additions & 0 deletions doc/source/rllib.rst
@@ -101,6 +101,8 @@ Concepts and Building Custom Algorithms

- `Building Policies in TensorFlow <rllib-concepts.html#building-policies-in-tensorflow>`__

- `Building Policies in TensorFlow Eager <rllib-concepts.html#building-policies-in-tensorflow-eager>`__

- `Building Policies in PyTorch <rllib-concepts.html#building-policies-in-pytorch>`__

- `Extending Existing Policies <rllib-concepts.html#extending-existing-policies>`__
5 changes: 3 additions & 2 deletions python/ray/rllib/agents/a3c/a3c_tf_policy.py
@@ -41,8 +41,9 @@ def actor_critic_loss(policy, batch_tensors):
policy.loss = A3CLoss(
policy.action_dist, batch_tensors[SampleBatch.ACTIONS],
batch_tensors[Postprocessing.ADVANTAGES],
batch_tensors[Postprocessing.VALUE_TARGETS], policy.vf,
policy.config["vf_loss_coeff"], policy.config["entropy_coeff"])
batch_tensors[Postprocessing.VALUE_TARGETS],
policy.convert_to_eager(policy.vf), policy.config["vf_loss_coeff"],
policy.config["entropy_coeff"])
return policy.loss.total_loss


10 changes: 6 additions & 4 deletions python/ray/rllib/agents/ppo/ppo_policy.py
@@ -106,8 +106,10 @@ def reduce_mean_valid(t):

def ppo_surrogate_loss(policy, batch_tensors):
if policy.model.state_in:
max_seq_len = tf.reduce_max(policy.model.seq_lens)
mask = tf.sequence_mask(policy.model.seq_lens, max_seq_len)
max_seq_len = tf.reduce_max(
policy.convert_to_eager(policy.model.seq_lens))
mask = tf.sequence_mask(
policy.convert_to_eager(policy.model.seq_lens), max_seq_len)
mask = tf.reshape(mask, [-1])
else:
mask = tf.ones_like(
@@ -121,8 +123,8 @@ def ppo_surrogate_loss(policy, batch_tensors):
batch_tensors[BEHAVIOUR_LOGITS],
batch_tensors[SampleBatch.VF_PREDS],
policy.action_dist,
policy.value_function,
policy.kl_coeff,
policy.convert_to_eager(policy.value_function),
policy.convert_to_eager(policy.kl_coeff),
mask,
entropy_coeff=policy.config["entropy_coeff"],
clip_param=policy.config["clip_param"],
3 changes: 3 additions & 0 deletions python/ray/rllib/agents/trainer.py
@@ -67,6 +67,9 @@
},
# Whether to attempt to continue training if a worker crashes.
"ignore_worker_failures": False,
# Execute TF loss functions in eager mode. This is currently experimental
# and only really works with the basic PG algorithm.
"use_eager": False,

# === Policy ===
# Arguments to pass to model. See models/catalog.py for a full list of the
101 changes: 101 additions & 0 deletions python/ray/rllib/examples/eager_execution.py
@@ -0,0 +1,101 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import random

import ray
from ray import tune
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.models import FullyConnectedNetwork, Model, ModelCatalog
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils import try_import_tf

tf = try_import_tf()

parser = argparse.ArgumentParser()
parser.add_argument("--iters", type=int, default=200)


class EagerModel(Model):
"""Example of using embedded eager execution in a custom model.

This shows how to use tf.py_function() to execute a snippet of TF code
in eager mode. Here the `self.forward_eager` method just prints out
the intermediate tensor for debug purposes, but you can in general
perform any TF eager operation in tf.py_function().
"""

def _build_layers_v2(self, input_dict, num_outputs, options):
self.fcnet = FullyConnectedNetwork(input_dict, self.obs_space,
self.action_space, num_outputs,
options)
feature_out = tf.py_function(self.forward_eager,
[self.fcnet.last_layer], tf.float32)

with tf.control_dependencies([feature_out]):
return tf.identity(self.fcnet.outputs), feature_out

def forward_eager(self, feature_layer):
assert tf.executing_eagerly()
if random.random() > 0.99:
print("Eagerly printing the feature layer mean value",
tf.reduce_mean(feature_layer))
return feature_layer


def policy_gradient_loss(policy, batch_tensors):
"""Example of using embedded eager execution in a custom loss.

Here `compute_penalty` prints the actions and rewards for debugging, and
also computes a (dummy) penalty term to add to the loss.

Alternatively, you can set config["use_eager"] = True, which will try to
automatically eagerify the entire loss function. However, this only works
if your loss doesn't reference any non-eager tensors. It also won't work
with the multi-GPU optimizer used by PPO.
"""

def compute_penalty(actions, rewards):
assert tf.executing_eagerly()
penalty = tf.reduce_mean(tf.cast(actions, tf.float32))
if random.random() > 0.9:
print("The eagerly computed penalty is", penalty, actions, rewards)
return penalty

actions = batch_tensors[SampleBatch.ACTIONS]
rewards = batch_tensors[SampleBatch.REWARDS]
penalty = tf.py_function(
compute_penalty, [actions, rewards], Tout=tf.float32)

return penalty - tf.reduce_mean(policy.action_dist.logp(actions) * rewards)


# <class 'ray.rllib.policy.tf_policy_template.MyTFPolicy'>
MyTFPolicy = build_tf_policy(
name="MyTFPolicy",
loss_fn=policy_gradient_loss,
)

# <class 'ray.rllib.agents.trainer_template.MyCustomTrainer'>
MyTrainer = build_trainer(
name="MyCustomTrainer",
default_policy=MyTFPolicy,
)

if __name__ == "__main__":
ray.init()
args = parser.parse_args()
ModelCatalog.register_custom_model("eager_model", EagerModel)
tune.run(
MyTrainer,
stop={"training_iteration": args.iters},
config={
"env": "CartPole-v0",
"num_workers": 0,
"model": {
"custom_model": "eager_model"
},
})
50 changes: 50 additions & 0 deletions python/ray/rllib/policy/dynamic_tf_policy.py
@@ -167,6 +167,8 @@ def __init__(self,
batch_divisibility_req=batch_divisibility_req)

# Phase 2 init
self._needs_eager_conversion = set()
self._eager_tensors = {}
before_loss_init(self, obs_space, action_space, config)
if not existing_inputs:
self._initialize_loss()
@@ -178,10 +180,26 @@ def get_obs_input_dict(self):
"""
return self.input_dict

def convert_to_eager(self, tensor):
"""Convert a graph tensor accessed in the loss to an eager tensor.

Experimental.
"""
if tf.executing_eagerly():
return self._eager_tensors[tensor]
else:
self._needs_eager_conversion.add(tensor)
return tensor

@override(TFPolicy)
def copy(self, existing_inputs):
"""Creates a copy of self using existing input placeholders."""

if self.config["use_eager"]:
raise ValueError(
"eager not implemented for multi-GPU, try setting "
"`simple_optimizer: true`")

# Note that there might be RNN state inputs at the end of the list
if self._state_inputs:
num_state_inputs = len(self._state_inputs) + 1
@@ -297,6 +315,38 @@ def fake_array(tensor):
loss = self._do_loss_init(batch_tensors)
for k in sorted(batch_tensors.accessed_keys):
loss_inputs.append((k, batch_tensors[k]))

# XXX experimental support for automatically eagerifying the loss.
# The main limitation right now is that TF doesn't support mixing eager
# and non-eager tensors, so losses that read non-eager tensors through
# `policy` need to use `policy.convert_to_eager(tensor)`.
if self.config["use_eager"]:
if not self.model:
raise ValueError("eager not implemented in this case")
graph_tensors = list(self._needs_eager_conversion)

def gen_loss(model_outputs, *args):
# fill in the batch tensor dict with eager tensors
eager_inputs = dict(
zip([k for (k, v) in loss_inputs],
args[:len(loss_inputs)]))
# fill in the eager versions of all accessed graph tensors
self._eager_tensors = dict(
zip(graph_tensors, args[len(loss_inputs):]))
# patch the action dist to use eager mode tensors
self.action_dist.inputs = model_outputs
return self._loss_fn(self, eager_inputs)

# TODO(ekl) also handle the stats funcs
loss = tf.py_function(
gen_loss,
# cast works around TypeError: Cannot convert provided value
# to EagerTensor. Provided value: 0.0 Requested dtype: int64
[self.model.outputs] + [
tf.cast(v, tf.float32) for (k, v) in loss_inputs
] + [tf.cast(t, tf.float32) for t in graph_tensors],
tf.float32)

TFPolicy._initialize_loss(self, loss, loss_inputs)
if self._grad_stats_fn:
self._stats_fetches.update(self._grad_stats_fn(self, self._grads))