[rllib] Autoregressive action distributions #5304

Merged (41 commits, Aug 10, 2019)
Changes shown are from 37 of 41 commits.

Commits
556985a
wip
ericl Jul 29, 2019
515f7ce
wip
ericl Jul 29, 2019
dee364f
fix
ericl Jul 29, 2019
e2d4fcc
doc
ericl Jul 29, 2019
3a51e24
doc
ericl Jul 29, 2019
81c731f
Update dqn_policy.py
ericl Jul 29, 2019
a9e5e14
none
ericl Jul 29, 2019
05083c3
Merge branch 'autoregressive' of github.com:ericl/ray into autoregres…
ericl Jul 29, 2019
292d1ba
lint
ericl Jul 29, 2019
6e6059d
Update rllib-models.rst
ericl Jul 29, 2019
368188e
docs update
ericl Jul 29, 2019
c1980d7
Merge branch 'autoregressive' of github.com:ericl/ray into autoregres…
ericl Jul 29, 2019
ca4cbbc
doc update
ericl Jul 29, 2019
b469b47
move matrix
ericl Jul 29, 2019
a4e3069
model
ericl Jul 29, 2019
f5e5d0b
env
ericl Jul 29, 2019
2c34ebc
update
ericl Jul 29, 2019
67a2ae0
fix shuffle
ericl Aug 6, 2019
e223e85
Merge remote-tracking branch 'upstream/master' into autoregressive
ericl Aug 6, 2019
1c7c0b3
remove keras
ericl Aug 6, 2019
eca623f
Merge remote-tracking branch 'upstream/master' into autoregressive
ericl Aug 6, 2019
19b91da
update docs
ericl Aug 6, 2019
59a29f6
docs
ericl Aug 6, 2019
b1ed891
switch to logp for stability
ericl Aug 6, 2019
6e584a3
remove override
ericl Aug 6, 2019
ba1b531
fix op leak
ericl Aug 6, 2019
b551fcb
fix
ericl Aug 6, 2019
4c4786a
fix
ericl Aug 6, 2019
ba007f1
lint
ericl Aug 6, 2019
ef30c39
Merge remote-tracking branch 'upstream/master' into autoregressive
ericl Aug 6, 2019
afc5002
Merge remote-tracking branch 'upstream/master' into autoregressive
ericl Aug 6, 2019
1134d17
doc
ericl Aug 6, 2019
2d946f1
cateogrical
ericl Aug 7, 2019
da85071
fix
ericl Aug 8, 2019
7a66f09
fix vtrace
ericl Aug 8, 2019
b10b749
Merge remote-tracking branch 'upstream/master' into autoregressive
ericl Aug 8, 2019
406620d
fix appo
ericl Aug 8, 2019
fe8ecff
to note
ericl Aug 9, 2019
7e3f040
comments
ericl Aug 9, 2019
3ffb4b7
Merge remote-tracking branch 'upstream/master' into autoregressive
ericl Aug 10, 2019
979fd5a
fix merge
ericl Aug 10, 2019
3 changes: 3 additions & 0 deletions ci/jenkins_tests/run_rllib_tests.sh
@@ -440,6 +440,9 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/rllib/examples/twostep_game.py --stop=2000 --run=APEX_QMIX

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/rllib/examples/autoregressive_action_dist.py --stop=150

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output /ray/rllib/train.py \
--env PongDeterministic-v4 \
4 changes: 2 additions & 2 deletions doc/source/rllib-algorithms.rst
@@ -336,8 +336,8 @@ Tuned examples: `Two-step game <https://github.com/ray-project/ray/blob/master/r
:start-after: __sphinx_doc_begin__
:end-before: __sphinx_doc_end__

Multi-Agent Actor Critic (contrib/MADDPG)
-----------------------------------------
Multi-Agent Deep Deterministic Policy Gradient (contrib/MADDPG)
---------------------------------------------------------------
`[paper] <https://arxiv.org/abs/1706.02275>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/contrib/maddpg/maddpg.py>`__ MADDPG is a specialized multi-agent algorithm. Code here is adapted from https://github.com/openai/maddpg to integrate with RLlib multi-agent APIs. Please check `wsjeon/maddpg-rllib <https://github.com/wsjeon/maddpg-rllib>`__ for examples and more information.

**MADDPG-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
5 changes: 1 addition & 4 deletions doc/source/rllib-components.svg
2 changes: 1 addition & 1 deletion doc/source/rllib-concepts.rst
@@ -407,7 +407,7 @@ The action sampler is straightforward, it just takes the q_model, runs a forward
config):
# do max over Q values...
...
return action, action_prob
return action, action_logp

The remainder of DQN is similar to other algorithms. Target updates are handled by an ``after_optimizer_step`` callback that periodically copies the weights of the Q network to the target.

47 changes: 25 additions & 22 deletions doc/source/rllib-env.rst
@@ -5,27 +5,33 @@ RLlib works with several different types of environments, including `OpenAI Gym

.. image:: rllib-envs.svg

**Compatibility matrix**:

============= ======================= ================== =========== ==================
Algorithm Discrete Actions Continuous Actions Multi-Agent Recurrent Policies
============= ======================= ================== =========== ==================
A2C, A3C **Yes** `+parametric`_ **Yes** **Yes** **Yes**
PPO, APPO **Yes** `+parametric`_ **Yes** **Yes** **Yes**
PG **Yes** `+parametric`_ **Yes** **Yes** **Yes**
IMPALA **Yes** `+parametric`_ **Yes** **Yes** **Yes**
DQN, Rainbow **Yes** `+parametric`_ No **Yes** No
DDPG, TD3 No **Yes** **Yes** No
APEX-DQN **Yes** `+parametric`_ No **Yes** No
APEX-DDPG No **Yes** **Yes** No
SAC (todo) **Yes** **Yes** No
ES **Yes** **Yes** No No
ARS **Yes** **Yes** No No
QMIX **Yes** No **Yes** **Yes**
MARWIL **Yes** `+parametric`_ **Yes** **Yes** **Yes**
============= ======================= ================== =========== ==================
Feature Compatibility Matrix
----------------------------

============= ======================= ================== =========== ===========================
Algorithm Discrete Actions Continuous Multi-Agent Model Support
============= ======================= ================== =========== ===========================
A2C, A3C **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
PPO, APPO **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
PG **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
IMPALA **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_
DQN, Rainbow **Yes** `+parametric`_ No **Yes**
DDPG, TD3 No **Yes** **Yes**
APEX-DQN **Yes** `+parametric`_ No **Yes**
APEX-DDPG No **Yes** **Yes**
SAC (todo) **Yes** **Yes**
ES **Yes** **Yes** No
ARS **Yes** **Yes** No
QMIX **Yes** No **Yes** `+RNN`_
MARWIL **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_
============= ======================= ================== =========== ===========================

.. _`+parametric`: rllib-models.html#variable-length-parametric-action-spaces
.. _`+RNN`: rllib-models.html#recurrent-models
.. _`+autoreg`: rllib-models.html#autoregressive-action-distributions

Configuring Environments
------------------------

You can pass either a string name or a Python class to specify an environment. By default, strings will be interpreted as a gym `environment name <https://gym.openai.com/envs>`__. Custom env classes passed directly to the trainer must take a single ``env_config`` parameter in their constructor:
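
A minimal sketch of such a class (the class name and the spaces below are illustrative, not taken from the full example):

.. code-block:: python

    import gym
    from gym.spaces import Discrete

    class MyEnv(gym.Env):
        def __init__(self, env_config):
            # env_config is the "env_config" dict from the trainer config
            self.action_space = Discrete(2)
            self.observation_space = Discrete(2)
            self.cur_obs = 0

        def reset(self):
            self.cur_obs = 0
            return self.cur_obs

        def step(self, action):
            self.cur_obs = action
            reward, done = 1.0, True
            return self.cur_obs, reward, done, {}

The class can then be passed directly, e.g. ``trainer = ppo.PPOTrainer(env=MyEnv, config={"env_config": {}})``.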

@@ -69,9 +75,6 @@ For a full runnable code example using the custom environment API, see `custom_e

The gym registry is not compatible with Ray. Instead, always use the registration flows documented above to ensure Ray workers can access the environment.
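
For instance, a minimal sketch of that registration flow (the ``env_creator`` and ``"my_env"`` names are illustrative):

.. code-block:: python

    from ray.tune.registry import register_env

    def env_creator(env_config):
        return MyEnv(env_config)  # returns an env instance

    register_env("my_env", env_creator)
    trainer = ppo.PPOTrainer(env="my_env")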


In the above example, note that the ``env_creator`` function takes in an ``env_config`` object. This is a dict containing options passed in through your trainer. You can also access ``env_config.worker_index`` and ``env_config.vector_index`` to get the worker id and env id within the worker (if ``num_envs_per_worker > 0``). This can be useful if you want to train over an ensemble of different environments, for example:

.. code-block:: python
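
    # A sketch only (the helper names and env IDs below are illustrative, not
    # from the original example): pick a different underlying gym env per
    # rollout worker via env_config.worker_index.
    import gym
    from ray.tune.registry import register_env

    ENSEMBLE = ["CartPole-v0", "CartPole-v1"]  # illustrative ensemble

    class MultiEnv(gym.Env):
        def __init__(self, env_config):
            # choose the underlying env based on the worker index
            self.env = gym.make(
                ENSEMBLE[env_config.worker_index % len(ENSEMBLE)])
            self.action_space = self.env.action_space
            self.observation_space = self.env.observation_space

        def reset(self):
            return self.env.reset()

        def step(self, action):
            return self.env.step(action)

    register_env("multienv", lambda env_config: MultiEnv(env_config))
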
162 changes: 143 additions & 19 deletions doc/source/rllib-models.rst
@@ -1,5 +1,5 @@
RLlib Models and Preprocessors
==============================
RLlib Models, Preprocessors, and Action Distributions
=====================================================

The following diagram provides a conceptual overview of data flow between different components in RLlib. We start with an ``Environment``, which given an action produces an observation. The observation is preprocessed by a ``Preprocessor`` and ``Filter`` (e.g. for running mean normalization) before being sent to a neural network ``Model``. The model output is in turn interpreted by an ``ActionDistribution`` to determine the next action.

@@ -145,6 +145,7 @@ Custom preprocessors should subclass the RLlib `preprocessor class <https://gith

import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.models import ModelCatalog
from ray.rllib.models.preprocessors import Preprocessor

class MyPreprocessorClass(Preprocessor):
@@ -164,6 +165,40 @@ Custom preprocessors should subclass the RLlib `preprocessor class <https://gith
},
})

Custom Action Distributions
---------------------------

Similar to custom models and preprocessors, you can also specify a custom action distribution class as follows. The action dist class is passed a reference to the ``model``, which you can use to access ``model.model_config`` or other attributes of the model. This is commonly used to implement `autoregressive action outputs <#autoregressive-action-distributions>`__. Not all algorithms support custom action distributions; see the `feature compatibility matrix <rllib-env.html#feature-compatibility-matrix>`__.

.. code-block:: python

    import ray
    import ray.rllib.agents.ppo as ppo
    from ray.rllib.models import ModelCatalog
    # The example subclasses ActionDistribution, so import the base class.
    # (Exact import path may vary by Ray version.)
    from ray.rllib.models.tf.tf_action_dist import ActionDistribution

    class MyActionDist(ActionDistribution):
        @staticmethod
        def required_model_output_shape(action_space, model_config):
            return 7  # controls model output feature vector size

        def __init__(self, inputs, model):
            super(MyActionDist, self).__init__(inputs, model)
            assert model.num_outputs == 7

        def sample(self): ...
        def logp(self, actions): ...
        def entropy(self): ...

    ModelCatalog.register_custom_action_dist("my_dist", MyActionDist)

    ray.init()
    trainer = ppo.PPOTrainer(env="CartPole-v0", config={
        "model": {
            "custom_action_dist": "my_dist",
        },
    })

Supervised Model Losses
-----------------------

@@ -231,26 +266,115 @@ Custom models can be used to work with environments where (1) the set of valid a
return action_logits + inf_mask, state


Depending on your use case it may make sense to use just the masking, just action embeddings, or both. For a runnable example of this in code, check out `parametric_action_cartpole.py <https://github.com/ray-project/ray/blob/master/rllib/examples/parametric_action_cartpole.py>`__. Note that since masking introduces ``tf.float32.min`` values into the model output, this technique might not work with all algorithm options. For example, algorithms might crash if they incorrectly process the ``tf.float32.min`` values. The cartpole example has working configurations for DQN (must set ``hiddens=[]``), PPO (must disable running mean and set ``vf_share_layers=True``), and several other algorithms. Not all algorithms support parametric actions; see the `feature compatibility matrix <rllib-env.html#feature-compatibility-matrix>`__.
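
For reference, a hedged sketch of those per-algorithm overrides (the dict names here are illustrative; only the config keys come from the description above):

.. code-block:: python

    # Working configurations described above for the parametric-actions cartpole.
    ppo_overrides = {
        "observation_filter": "NoFilter",  # disable running mean normalization
        "vf_share_layers": True,           # share layers instead of building a second value net
    }
    dqn_overrides = {
        "hiddens": [],  # no extra postprocessing layers on top of the masked logits
    }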

Model-Based Rollouts
~~~~~~~~~~~~~~~~~~~~

With a custom policy, you can also perform model-based rollouts and optionally incorporate the results of those rollouts as training data. For example, suppose you wanted to extend PGPolicy for model-based rollouts. This involves overriding the ``compute_actions`` method of that policy:

.. code-block:: python

    class ModelBasedPolicy(PGPolicy):
        def compute_actions(self,
                            obs_batch,
                            state_batches,
                            prev_action_batch=None,
                            prev_reward_batch=None,
                            episodes=None):
            # compute a batch of actions based on the current obs_batch
            # and state of each episode (i.e., for multiagent). You can do
            # whatever is needed here, e.g., MCTS rollouts.
            return action_batch

If you want to take this rollout data and append it to the sample batch, use the ``add_extra_batch()`` method of the `episode objects <https://github.com/ray-project/ray/blob/master/rllib/evaluation/episode.py>`__ passed in. For an example of this, see the ``testReturningModelBasedRolloutsData`` `unit test <https://github.com/ray-project/ray/blob/master/rllib/tests/test_multi_agent_env.py>`__.

Autoregressive Action Distributions
-----------------------------------

In an action space with multiple components (e.g., ``Tuple(a1, a2)``), you might want ``a2`` to be conditioned on the sampled value of ``a1``, i.e., ``a2_sampled ~ P(a2 | a1_sampled, obs)``. Normally, ``a1`` and ``a2`` would be sampled independently, reducing the expressivity of the policy.

To do this, you need both a custom model that implements the autoregressive pattern, and a custom action distribution class that leverages that model. The `autoregressive_action_dist.py <https://github.com/ray-project/ray/blob/master/rllib/examples/autoregressive_action_dist.py>`__ example shows how this can be implemented for a simple binary action space. For a more complex space, a more efficient architecture such as a `MADE <https://arxiv.org/abs/1502.03509>`__ is recommended. Note that sampling an `N-part` action requires `N` forward passes through the model, while computing the log probability of an action can be done in one pass:

.. code-block:: python

    # [PR review comment on the block below, from a Contributor]: why not use
    # literalinclude so that this doesn't go out of sync?
    class BinaryAutoregressiveOutput(ActionDistribution):
        """Action distribution P(a1, a2) = P(a1) * P(a2 | a1)"""

        @staticmethod
        def required_model_output_shape(action_space, model_config):
            return 16  # controls model output feature vector size

        def sample(self):
            # first, sample a1
            a1_dist = self._a1_distribution()
            a1 = a1_dist.sample()

            # sample a2 conditioned on a1
            a2_dist = self._a2_distribution(a1)
            a2 = a2_dist.sample()

            # return the action tuple
            return TupleActions([a1, a2])

        def logp(self, actions):
            a1, a2 = actions[:, 0], actions[:, 1]
            a1_vec = tf.expand_dims(tf.cast(a1, tf.float32), 1)
            a1_logits, a2_logits = self.model.action_model([self.inputs, a1_vec])
            return (Categorical(a1_logits, None).logp(a1) + Categorical(
                a2_logits, None).logp(a2))

        def _a1_distribution(self):
            BATCH = tf.shape(self.inputs)[0]
            a1_logits, _ = self.model.action_model(
                [self.inputs, tf.zeros((BATCH, 1))])
            a1_dist = Categorical(a1_logits, None)
            return a1_dist

        def _a2_distribution(self, a1):
            a1_vec = tf.expand_dims(tf.cast(a1, tf.float32), 1)
            _, a2_logits = self.model.action_model([self.inputs, a1_vec])
            a2_dist = Categorical(a2_logits, None)
            return a2_dist

    class AutoregressiveActionsModel(TFModelV2):
        """Implements the `.action_model` branch required above."""

        def __init__(self, obs_space, action_space, num_outputs, model_config,
                     name):
            super(AutoregressiveActionsModel, self).__init__(
                obs_space, action_space, num_outputs, model_config, name)
            if action_space != Tuple([Discrete(2), Discrete(2)]):
                raise ValueError(
                    "This model only supports the [2, 2] action space")

            # Inputs
            obs_input = tf.keras.layers.Input(
                shape=obs_space.shape, name="obs_input")
            a1_input = tf.keras.layers.Input(shape=(1, ), name="a1_input")
            ctx_input = tf.keras.layers.Input(
                shape=(num_outputs, ), name="ctx_input")

            # Output of the model (normally 'logits', but for an autoregressive
            # dist this is more like a context/feature layer encoding the obs)
            context = tf.keras.layers.Dense(
                num_outputs,
                name="hidden",
                activation=tf.nn.tanh,
                kernel_initializer=normc_initializer(1.0))(obs_input)

            # P(a1 | obs)
            a1_logits = tf.keras.layers.Dense(
                2,
                name="a1_logits",
                activation=None,
                kernel_initializer=normc_initializer(0.01))(ctx_input)

            # P(a2 | a1)
            # --note: typically you'd want to implement P(a2 | a1, obs) as follows:
            # a2_context = tf.keras.layers.Concatenate(axis=1)(
            #     [ctx_input, a1_input])
            a2_context = a1_input
            a2_hidden = tf.keras.layers.Dense(
                16,
                name="a2_hidden",
                activation=tf.nn.tanh,
                kernel_initializer=normc_initializer(1.0))(a2_context)
            a2_logits = tf.keras.layers.Dense(
                2,
                name="a2_logits",
                activation=None,
                kernel_initializer=normc_initializer(0.01))(a2_hidden)

            # Base layers
            self.base_model = tf.keras.Model(obs_input, context)
            self.register_variables(self.base_model.variables)
            self.base_model.summary()

            # Autoregressive action sampler
            self.action_model = tf.keras.Model([ctx_input, a1_input],
                                               [a1_logits, a2_logits])
            self.action_model.summary()
            self.register_variables(self.action_model.variables)

Not all algorithms support custom action distributions; see the `feature compatibility matrix <rllib-env.html#feature-compatibility-matrix>`__.
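
To wire these classes into a trainer, register them with the ``ModelCatalog`` and reference them in the model config, following the same pattern shown for custom models and custom action distributions above. A hedged sketch (the registration names and the env are illustrative, not taken from the example file):

.. code-block:: python

    import ray.rllib.agents.ppo as ppo
    from ray.rllib.models import ModelCatalog

    ModelCatalog.register_custom_model("autoreg_model", AutoregressiveActionsModel)
    ModelCatalog.register_custom_action_dist("binary_autoreg_dist",
                                             BinaryAutoregressiveOutput)

    trainer = ppo.PPOTrainer(env=YourTupleActionEnv, config={
        "model": {
            "custom_model": "autoreg_model",
            "custom_action_dist": "binary_autoreg_dist",
        },
    })
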
11 changes: 7 additions & 4 deletions doc/source/rllib.rst
@@ -35,20 +35,23 @@ Training APIs
Environments
------------
* `RLlib Environments Overview <rllib-env.html>`__
* `Feature Compatibility Matrix <rllib-env.html#feature-compatibility-matrix>`__
* `OpenAI Gym <rllib-env.html#openai-gym>`__
* `Vectorized <rllib-env.html#vectorized>`__
* `Multi-Agent and Hierarchical <rllib-env.html#multi-agent-and-hierarchical>`__
* `Interfacing with External Agents <rllib-env.html#interfacing-with-external-agents>`__
* `Advanced Integrations <rllib-env.html#advanced-integrations>`__

Models and Preprocessors
------------------------
* `RLlib Models and Preprocessors Overview <rllib-models.html>`__
Models, Preprocessors, and Action Distributions
-----------------------------------------------
* `RLlib Models, Preprocessors, and Action Distributions Overview <rllib-models.html>`__
* `TensorFlow Models <rllib-models.html#tensorflow-models>`__
* `PyTorch Models <rllib-models.html#pytorch-models>`__
* `Custom Preprocessors <rllib-models.html#custom-preprocessors>`__
* `Custom Action Distributions <rllib-models.html#custom-action-distributions>`__
* `Supervised Model Losses <rllib-models.html#supervised-model-losses>`__
* `Variable-length / Parametric Action Spaces <rllib-models.html#variable-length-parametric-action-spaces>`__
* `Autoregressive Action Distributions <rllib-models.html#autoregressive-action-distributions>`__

Algorithms
----------
@@ -84,7 +87,7 @@ Algorithms
* Multi-agent specific

- `QMIX Monotonic Value Factorisation (QMIX, VDN, IQN) <rllib-algorithms.html#qmix-monotonic-value-factorisation-qmix-vdn-iqn>`__
- `Multi-Agent Actor Critic (contrib/MADDPG) <rllib-algorithms.html#multi-agent-actor-critic-contrib-maddpg>`__
- `Multi-Agent Deep Deterministic Policy Gradient (contrib/MADDPG) <rllib-algorithms.html#multi-agent-deep-deterministic-policy-gradient-contrib-maddpg>`__

* Offline

2 changes: 1 addition & 1 deletion rllib/agents/a3c/a3c_torch_policy.py
@@ -18,7 +18,7 @@ def actor_critic_loss(policy, batch_tensors):
SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
}) # TODO(ekl) seq lens shouldn't be None
values = policy.model.value_function()
dist = policy.dist_class(logits, policy.config["model"])
dist = policy.dist_class(logits, policy.model)
log_probs = dist.logp(batch_tensors[SampleBatch.ACTIONS])
policy.entropy = dist.entropy().mean()
policy.pi_err = -batch_tensors[Postprocessing.ADVANTAGES].dot(
2 changes: 1 addition & 1 deletion rllib/agents/ars/policies.py
@@ -81,7 +81,7 @@ def __init__(self,
model = ModelCatalog.get_model({
"obs": self.inputs
}, obs_space, action_space, dist_dim, model_config)
dist = dist_class(model.outputs, model_config=model_config)
dist = dist_class(model.outputs, model)
self.sampler = dist.sample()

self.variables = ray.experimental.tf_utils.TensorFlowVariables(