Merged
4 changes: 4 additions & 0 deletions doc/source/rllib-algorithms.rst
@@ -101,6 +101,10 @@ Tuned examples: `PongNoFrameskip-v4 <https://github.com/ray-project/ray/blob/mas
 
 **APPO-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
 
+.. warning::
+
+    Keras custom models are not compatible with multi-GPU (this includes PPO in single-GPU mode). This is because the multi-GPU implementation in RLlib relies on variable scopes to implement cross-GPU support.
+
 .. literalinclude:: ../../python/ray/rllib/agents/ppo/appo.py
    :language: python
    :start-after: __sphinx_doc_begin__
4 changes: 4 additions & 0 deletions doc/source/rllib-models.rst
@@ -35,6 +35,10 @@ Custom Models (TensorFlow)
 
 Custom TF models should subclass the common RLlib `model class <https://github.com/ray-project/ray/blob/master/python/ray/rllib/models/model.py>`__ and override the ``_build_layers_v2`` method. This method takes in a dict of tensor inputs (the observation ``obs``, ``prev_action``, and ``prev_reward``, ``is_training``), and returns a feature layer and float vector of the specified output size. You can also override the ``value_function`` method to implement a custom value branch. Additional supervised / self-supervised losses can be added via the ``custom_loss`` method. The model can then be registered and used in place of a built-in model:
 
+.. warning::
+
+    Keras custom models are not compatible with multi-GPU (this includes PPO in single-GPU mode). This is because the multi-GPU implementation in RLlib relies on variable scopes to implement cross-GPU support.
+
 .. code-block:: python
 
     import ray
10 changes: 8 additions & 2 deletions python/ray/rllib/agents/trainer.py
@@ -54,14 +54,20 @@
     # Callbacks that will be run during various phases of training. These all
     # take a single "info" dict as an argument. For episode callbacks, custom
     # metrics can be attached to the episode by updating the episode object's
-    # custom metrics dict (see examples/custom_metrics_and_callbacks.py).
+    # custom metrics dict (see examples/custom_metrics_and_callbacks.py). You
+    # may also mutate the passed in batch data in your callback.
     "callbacks": {
         "on_episode_start": None,  # arg: {"env": .., "episode": ...}
         "on_episode_step": None,  # arg: {"env": .., "episode": ...}
         "on_episode_end": None,  # arg: {"env": .., "episode": ...}
         "on_sample_end": None,  # arg: {"samples": .., "evaluator": ...}
         "on_train_result": None,  # arg: {"trainer": ..., "result": ...}
-        "on_postprocess_traj": None,  # arg: {"batch": ..., "episode": ...}
+        "on_postprocess_traj": None,  # arg: {
Review comment (Contributor): Consider eventually making this a class that you can pass in (like KerasCallbacks). That would enable type hinting and better documentation.

+        #   "agent_id": ..., "episode": ...,
+        #   "pre_batch": (before processing),
+        #   "post_batch": (after processing),
+        #   "all_pre_batches": (other agent ids),
+        # }
     },
     # Whether to attempt to continue training if a worker crashes.
     "ignore_worker_failures": False,
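The review comment above suggests passing callbacks as a class instead of a dict of functions. A minimal sketch of what that could look like follows. `TrainerCallbacks`, `as_config`, and `CountAgents` are hypothetical names invented for illustration and are not part of RLlib in this PR; the adapter simply produces the dict format shown in the hunk above.

```python
from typing import Any, Callable, Dict, Set


class TrainerCallbacks:
    """Hypothetical base class (not an RLlib API): one typed method per hook."""

    def on_episode_start(self, info: Dict[str, Any]) -> None:
        pass

    def on_postprocess_traj(self, info: Dict[str, Any]) -> None:
        pass

    def as_config(self) -> Dict[str, Callable[[Dict[str, Any]], None]]:
        # Adapt the object to the dict-of-functions format used by "callbacks".
        return {
            "on_episode_start": self.on_episode_start,
            "on_postprocess_traj": self.on_postprocess_traj,
        }


class CountAgents(TrainerCallbacks):
    """Example subclass that records which agent ids were postprocessed."""

    def __init__(self) -> None:
        self.seen_agents: Set[str] = set()

    def on_postprocess_traj(self, info: Dict[str, Any]) -> None:
        self.seen_agents.add(info["agent_id"])


callbacks = CountAgents()
config = {"callbacks": callbacks.as_config()}

# Simulate one invocation with a stand-in info dict.
config["callbacks"]["on_postprocess_traj"](
    {"agent_id": "agent_0", "episode": None})
```

A class gives each hook a documented, type-annotated signature, at the cost of an adapter step while the trainer still expects a dict.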
8 changes: 7 additions & 1 deletion python/ray/rllib/evaluation/sample_batch_builder.py
@@ -165,7 +165,13 @@ def postprocess_batch_so_far(self, episode):
             self.policy_builders[self.agent_to_policy[agent_id]].add_batch(
                 post_batch)
             if self.postp_callback:
-                self.postp_callback({"episode": episode, "batch": post_batch})
+                self.postp_callback({
+                    "episode": episode,
+                    "agent_id": agent_id,
+                    "pre_batch": pre_batches[agent_id],
+                    "post_batch": post_batch,
+                    "all_pre_batches": pre_batches,
+                })
 
         self.agent_builders.clear()
         self.agent_to_policy.clear()
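To illustrate the new info dict, here is a small standalone sketch of the per-agent dispatch pattern in the hunk above. It does not use RLlib: `dispatch_postprocess_callback` is a hypothetical helper, and plain dicts stand in for RLlib's batch objects, whose real structure may differ.

```python
def dispatch_postprocess_callback(callback, episode, pre_batches, post_batches):
    """Call `callback` once per agent with the keys introduced in the diff."""
    for agent_id, post_batch in post_batches.items():
        callback({
            "episode": episode,
            "agent_id": agent_id,
            "pre_batch": pre_batches[agent_id],
            "post_batch": post_batch,
            # Every agent also sees the unprocessed batches of all agents.
            "all_pre_batches": pre_batches,
        })


# Stand-in data: two agents, dicts instead of real sample batches.
pre = {"a": {"rewards": [1.0]}, "b": {"rewards": [2.0]}}
post = {
    "a": {"rewards": [1.0], "advantages": [0.5]},
    "b": {"rewards": [2.0], "advantages": [1.5]},
}

calls = []
dispatch_postprocess_callback(calls.append, "episode_0", pre, post)
```

Each call receives the same `all_pre_batches` object, so a callback for one agent can inspect the other agents' unprocessed data.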
@@ -46,7 +46,7 @@ def on_train_result(info):
 
 def on_postprocess_traj(info):
     episode = info["episode"]
-    batch = info["batch"]
+    batch = info["post_batch"]
     print("postprocessed {} steps".format(batch.count))
     if "num_batches" not in episode.custom_metrics:
         episode.custom_metrics["num_batches"] = 0
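As a sketch of how a user callback might take advantage of the additional keys, the variant below also tracks per-agent step counts via `agent_id`. The `steps_<agent_id>` metric name is an assumption made for illustration, and `SimpleNamespace` objects stand in for RLlib's episode and batch so the snippet runs without Ray installed.

```python
from types import SimpleNamespace


def on_postprocess_traj(info):
    episode = info["episode"]
    batch = info["post_batch"]  # renamed from "batch" in this PR
    metrics = episode.custom_metrics
    metrics["num_batches"] = metrics.get("num_batches", 0) + 1
    # Hypothetical per-agent metric, made possible by the new "agent_id" key.
    key = "steps_{}".format(info["agent_id"])
    metrics[key] = metrics.get(key, 0) + batch.count


# Stand-ins for the objects RLlib would pass in.
episode = SimpleNamespace(custom_metrics={})
info = {
    "episode": episode,
    "agent_id": "agent_0",
    "pre_batch": SimpleNamespace(count=5),
    "post_batch": SimpleNamespace(count=5),
    "all_pre_batches": {},
}

on_postprocess_traj(info)
on_postprocess_traj(info)
```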