diff --git a/ci/jenkins_tests/run_rllib_tests.sh b/ci/jenkins_tests/run_rllib_tests.sh index 1fbb46e8283a..9b83b2c9cc31 100644 --- a/ci/jenkins_tests/run_rllib_tests.sh +++ b/ci/jenkins_tests/run_rllib_tests.sh @@ -446,6 +446,9 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output python /ray/rllib/examples/twostep_game.py --stop=2000 --run=APEX_QMIX +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + /ray/ci/suppress_output python /ray/rllib/examples/autoregressive_action_dist.py --stop=150 + docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output /ray/rllib/train.py \ --env PongDeterministic-v4 \ diff --git a/doc/source/rllib-algorithms.rst b/doc/source/rllib-algorithms.rst index 244929b9c81f..7e144754e406 100644 --- a/doc/source/rllib-algorithms.rst +++ b/doc/source/rllib-algorithms.rst @@ -336,8 +336,8 @@ Tuned examples: `Two-step game `__ `[implementation] `__ MADDPG is a specialized multi-agent algorithm. Code here is adapted from https://github.com/openai/maddpg to integrate with RLlib multi-agent APIs. Please check `wsjeon/maddpg-rllib `__ for examples and more information. **MADDPG-specific configs** (see also `common configs `__): diff --git a/doc/source/rllib-components.svg b/doc/source/rllib-components.svg index dac6268736d9..b9f7bbb11549 100644 --- a/doc/source/rllib-components.svg +++ b/doc/source/rllib-components.svg @@ -1,4 +1 @@ - - - - + \ No newline at end of file diff --git a/doc/source/rllib-concepts.rst b/doc/source/rllib-concepts.rst index 32aed3851482..b58ceba8abc9 100644 --- a/doc/source/rllib-concepts.rst +++ b/doc/source/rllib-concepts.rst @@ -407,7 +407,7 @@ The action sampler is straightforward, it just takes the q_model, runs a forward config): # do max over Q values... ... - return action, action_prob + return action, action_logp The remainder of DQN is similar to other algorithms. Target updates are handled by a ``after_optimizer_step`` callback that periodically copies the weights of the Q network to the target. diff --git a/doc/source/rllib-env.rst b/doc/source/rllib-env.rst index 1aa9aba54aa3..e310e9e1ba56 100644 --- a/doc/source/rllib-env.rst +++ b/doc/source/rllib-env.rst @@ -5,27 +5,33 @@ RLlib works with several different types of environments, including `OpenAI Gym .. 
image:: rllib-envs.svg -**Compatibility matrix**: - -============= ======================= ================== =========== ================== -Algorithm Discrete Actions Continuous Actions Multi-Agent Recurrent Policies -============= ======================= ================== =========== ================== -A2C, A3C **Yes** `+parametric`_ **Yes** **Yes** **Yes** -PPO, APPO **Yes** `+parametric`_ **Yes** **Yes** **Yes** -PG **Yes** `+parametric`_ **Yes** **Yes** **Yes** -IMPALA **Yes** `+parametric`_ **Yes** **Yes** **Yes** -DQN, Rainbow **Yes** `+parametric`_ No **Yes** No -DDPG, TD3 No **Yes** **Yes** No -APEX-DQN **Yes** `+parametric`_ No **Yes** No -APEX-DDPG No **Yes** **Yes** No -SAC (todo) **Yes** **Yes** No -ES **Yes** **Yes** No No -ARS **Yes** **Yes** No No -QMIX **Yes** No **Yes** **Yes** -MARWIL **Yes** `+parametric`_ **Yes** **Yes** **Yes** -============= ======================= ================== =========== ================== +Feature Compatibility Matrix +---------------------------- + +============= ======================= ================== =========== =========================== +Algorithm Discrete Actions Continuous Multi-Agent Model Support +============= ======================= ================== =========== =========================== +A2C, A3C **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_ +PPO, APPO **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_ +PG **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_ +IMPALA **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+autoreg`_ +DQN, Rainbow **Yes** `+parametric`_ No **Yes** +DDPG, TD3 No **Yes** **Yes** +APEX-DQN **Yes** `+parametric`_ No **Yes** +APEX-DDPG No **Yes** **Yes** +SAC (todo) **Yes** **Yes** +ES **Yes** **Yes** No +ARS **Yes** **Yes** No +QMIX **Yes** No **Yes** `+RNN`_ +MARWIL **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_ +============= ======================= ================== =========== =========================== .. _`+parametric`: rllib-models.html#variable-length-parametric-action-spaces +.. _`+RNN`: rllib-models.html#recurrent-models +.. _`+autoreg`: rllib-models.html#autoregressive-action-distributions + +Configuring Environments +------------------------ You can pass either a string name or a Python class to specify an environment. By default, strings will be interpreted as a gym `environment name `__. Custom env classes passed directly to the trainer must take a single ``env_config`` parameter in their constructor: @@ -69,9 +75,6 @@ For a full runnable code example using the custom environment API, see `custom_e The gym registry is not compatible with Ray. Instead, always use the registration flows documented above to ensure Ray workers can access the environment. -Configuring Environments ------------------------- - In the above example, note that the ``env_creator`` function takes in an ``env_config`` object. This is a dict containing options passed in through your trainer. You can also access ``env_config.worker_index`` and ``env_config.vector_index`` to get the worker id and env id within the worker (if ``num_envs_per_worker > 0``). This can be useful if you want to train over an ensemble of different environments, for example: .. 
code-block:: python diff --git a/doc/source/rllib-models.rst b/doc/source/rllib-models.rst index a392dba5c358..d75798c95c62 100644 --- a/doc/source/rllib-models.rst +++ b/doc/source/rllib-models.rst @@ -1,5 +1,5 @@ -RLlib Models and Preprocessors -============================== +RLlib Models, Preprocessors, and Action Distributions +===================================================== The following diagram provides a conceptual overview of data flow between different components in RLlib. We start with an ``Environment``, which given an action produces an observation. The observation is preprocessed by a ``Preprocessor`` and ``Filter`` (e.g. for running mean normalization) before being sent to a neural network ``Model``. The model output is in turn interpreted by an ``ActionDistribution`` to determine the next action. @@ -145,6 +145,7 @@ Custom preprocessors should subclass the RLlib `preprocessor class `__. + +.. code-block:: python + + import ray + import ray.rllib.agents.ppo as ppo + from ray.rllib.models import ModelCatalog + from ray.rllib.models.preprocessors import Preprocessor + + class MyActionDist(ActionDistribution): + @staticmethod + def required_model_output_shape(action_space, model_config): + return 7 # controls model output feature vector size + + def __init__(self, inputs, model): + super(MyActionDist, self).__init__(inputs, model) + assert model.num_outputs == 7 + + def sample(self): ... + def logp(self, actions): ... + def entropy(self): ... + + ModelCatalog.register_custom_action_dist("my_dist", MyActionDist) + + ray.init() + trainer = ppo.PPOTrainer(env="CartPole-v0", config={ + "model": { + "custom_action_dist": "my_dist", + }, + }) + Supervised Model Losses ----------------------- @@ -231,26 +266,119 @@ Custom models can be used to work with environments where (1) the set of valid a return action_logits + inf_mask, state -Depending on your use case it may make sense to use just the masking, just action embeddings, or both. For a runnable example of this in code, check out `parametric_action_cartpole.py `__. Note that since masking introduces ``tf.float32.min`` values into the model output, this technique might not work with all algorithm options. For example, algorithms might crash if they incorrectly process the ``tf.float32.min`` values. The cartpole example has working configurations for DQN (must set ``hiddens=[]``), PPO (must disable running mean and set ``vf_share_layers=True``), and several other algorithms. +Depending on your use case it may make sense to use just the masking, just action embeddings, or both. For a runnable example of this in code, check out `parametric_action_cartpole.py `__. Note that since masking introduces ``tf.float32.min`` values into the model output, this technique might not work with all algorithm options. For example, algorithms might crash if they incorrectly process the ``tf.float32.min`` values. The cartpole example has working configurations for DQN (must set ``hiddens=[]``), PPO (must disable running mean and set ``vf_share_layers=True``), and several other algorithms. Not all algorithms support parametric actions; see the `feature compatibility matrix `__. -Model-Based Rollouts -~~~~~~~~~~~~~~~~~~~~ -With a custom policy, you can also perform model-based rollouts and optionally incorporate the results of those rollouts as training data. For example, suppose you wanted to extend PGPolicy for model-based rollouts. 
This involves overriding the ``compute_actions`` method of that policy: +Autoregressive Action Distributions +----------------------------------- + +In an action space with multiple components (e.g., ``Tuple(a1, a2)``), you might want ``a2`` to be conditioned on the sampled value of ``a1``, i.e., ``a2_sampled ~ P(a2 | a1_sampled, obs)``. Normally, ``a1`` and ``a2`` would be sampled independently, reducing the expressivity of the policy. + +To do this, you need both a custom model that implements the autoregressive pattern, and a custom action distribution class that leverages that model. The `autoregressive_action_dist.py `__ example shows how this can be implemented for a simple binary action space. For a more complex space, a more efficient architecture such as a `MADE `__ is recommended. Note that sampling a `N-part` action requires `N` forward passes through the model, however computing the log probability of an action can be done in one pass: .. code-block:: python - class ModelBasedPolicy(PGPolicy): - def compute_actions(self, - obs_batch, - state_batches, - prev_action_batch=None, - prev_reward_batch=None, - episodes=None): - # compute a batch of actions based on the current obs_batch - # and state of each episode (i.e., for multiagent). You can do - # whatever is needed here, e.g., MCTS rollouts. - return action_batch + class BinaryAutoregressiveOutput(ActionDistribution): + """Action distribution P(a1, a2) = P(a1) * P(a2 | a1)""" + + @staticmethod + def required_model_output_shape(self, model_config): + return 16 # controls model output feature vector size + + def sample(self): + # first, sample a1 + a1_dist = self._a1_distribution() + a1 = a1_dist.sample() + + # sample a2 conditioned on a1 + a2_dist = self._a2_distribution(a1) + a2 = a2_dist.sample() + + # return the action tuple + return TupleActions([a1, a2]) + + def logp(self, actions): + a1, a2 = actions[:, 0], actions[:, 1] + a1_vec = tf.expand_dims(tf.cast(a1, tf.float32), 1) + a1_logits, a2_logits = self.model.action_model([self.inputs, a1_vec]) + return (Categorical(a1_logits, None).logp(a1) + Categorical( + a2_logits, None).logp(a2)) + + def _a1_distribution(self): + BATCH = tf.shape(self.inputs)[0] + a1_logits, _ = self.model.action_model( + [self.inputs, tf.zeros((BATCH, 1))]) + a1_dist = Categorical(a1_logits, None) + return a1_dist + + def _a2_distribution(self, a1): + a1_vec = tf.expand_dims(tf.cast(a1, tf.float32), 1) + _, a2_logits = self.model.action_model([self.inputs, a1_vec]) + a2_dist = Categorical(a2_logits, None) + return a2_dist + + class AutoregressiveActionsModel(TFModelV2): + """Implements the `.action_model` branch required above.""" + + def __init__(self, obs_space, action_space, num_outputs, model_config, + name): + super(AutoregressiveActionsModel, self).__init__( + obs_space, action_space, num_outputs, model_config, name) + if action_space != Tuple([Discrete(2), Discrete(2)]): + raise ValueError( + "This model only supports the [2, 2] action space") + + # Inputs + obs_input = tf.keras.layers.Input( + shape=obs_space.shape, name="obs_input") + a1_input = tf.keras.layers.Input(shape=(1, ), name="a1_input") + ctx_input = tf.keras.layers.Input( + shape=(num_outputs, ), name="ctx_input") + + # Output of the model (normally 'logits', but for an autoregressive + # dist this is more like a context/feature layer encoding the obs) + context = tf.keras.layers.Dense( + num_outputs, + name="hidden", + activation=tf.nn.tanh, + kernel_initializer=normc_initializer(1.0))(obs_input) + + # P(a1 | obs) + a1_logits = 
tf.keras.layers.Dense( + 2, + name="a1_logits", + activation=None, + kernel_initializer=normc_initializer(0.01))(ctx_input) + + # P(a2 | a1) + # --note: typically you'd want to implement P(a2 | a1, obs) as follows: + # a2_context = tf.keras.layers.Concatenate(axis=1)( + # [ctx_input, a1_input]) + a2_context = a1_input + a2_hidden = tf.keras.layers.Dense( + 16, + name="a2_hidden", + activation=tf.nn.tanh, + kernel_initializer=normc_initializer(1.0))(a2_context) + a2_logits = tf.keras.layers.Dense( + 2, + name="a2_logits", + activation=None, + kernel_initializer=normc_initializer(0.01))(a2_hidden) + + # Base layers + self.base_model = tf.keras.Model(obs_input, context) + self.register_variables(self.base_model.variables) + self.base_model.summary() + + # Autoregressive action sampler + self.action_model = tf.keras.Model([ctx_input, a1_input], + [a1_logits, a2_logits]) + self.action_model.summary() + self.register_variables(self.action_model.variables) + + +.. note:: -If you want take this rollouts data and append it to the sample batch, use the ``add_extra_batch()`` method of the `episode objects `__ passed in. For an example of this, see the ``testReturningModelBasedRolloutsData`` `unit test `__. + Not all algorithms support autoregressive action distributions; see the `feature compatibility matrix `__. diff --git a/doc/source/rllib.rst b/doc/source/rllib.rst index aa16e02d3901..6bc2789404e9 100644 --- a/doc/source/rllib.rst +++ b/doc/source/rllib.rst @@ -35,20 +35,23 @@ Training APIs Environments ------------ * `RLlib Environments Overview `__ +* `Feature Compatibility Matrix `__ * `OpenAI Gym `__ * `Vectorized `__ * `Multi-Agent and Hierarchical `__ * `Interfacing with External Agents `__ * `Advanced Integrations `__ -Models and Preprocessors ------------------------- -* `RLlib Models and Preprocessors Overview `__ +Models, Preprocessors, and Action Distributions +----------------------------------------------- +* `RLlib Models, Preprocessors, and Action Distributions Overview `__ * `TensorFlow Models `__ * `PyTorch Models `__ * `Custom Preprocessors `__ +* `Custom Action Distributions `__ * `Supervised Model Losses `__ * `Variable-length / Parametric Action Spaces `__ +* `Autoregressive Action Distributions `__ Algorithms ---------- @@ -84,7 +87,7 @@ Algorithms * Multi-agent specific - `QMIX Monotonic Value Factorisation (QMIX, VDN, IQN) `__ - - `Multi-Agent Actor Critic (contrib/MADDPG) `__ + - `Multi-Agent Deep Deterministic Policy Gradient (contrib/MADDPG) `__ * Offline diff --git a/rllib/agents/a3c/a3c_torch_policy.py b/rllib/agents/a3c/a3c_torch_policy.py index f14f1f16d247..014c6a44ef80 100644 --- a/rllib/agents/a3c/a3c_torch_policy.py +++ b/rllib/agents/a3c/a3c_torch_policy.py @@ -18,7 +18,7 @@ def actor_critic_loss(policy, batch_tensors): SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS] }) # TODO(ekl) seq lens shouldn't be None values = policy.model.value_function() - dist = policy.dist_class(logits, policy.config["model"]) + dist = policy.dist_class(logits, policy.model) log_probs = dist.logp(batch_tensors[SampleBatch.ACTIONS]) policy.entropy = dist.entropy().mean() policy.pi_err = -batch_tensors[Postprocessing.ADVANTAGES].dot( diff --git a/rllib/agents/ars/policies.py b/rllib/agents/ars/policies.py index 6029241c9f97..ce3e837a3fb9 100644 --- a/rllib/agents/ars/policies.py +++ b/rllib/agents/ars/policies.py @@ -81,7 +81,7 @@ def __init__(self, model = ModelCatalog.get_model({ "obs": self.inputs }, obs_space, action_space, dist_dim, model_config) - dist = 
dist_class(model.outputs, model_config=model_config) + dist = dist_class(model.outputs, model) self.sampler = dist.sample() self.variables = ray.experimental.tf_utils.TensorFlowVariables( diff --git a/rllib/agents/dqn/dqn_policy.py b/rllib/agents/dqn/dqn_policy.py index 46c891f7d7be..168e453487cc 100644 --- a/rllib/agents/dqn/dqn_policy.py +++ b/rllib/agents/dqn/dqn_policy.py @@ -109,10 +109,9 @@ class QValuePolicy(object): def __init__(self, q_values, observations, num_actions, stochastic, eps, softmax, softmax_temp, model_config): if softmax: - action_dist = Categorical( - q_values / softmax_temp, model_config=model_config) + action_dist = Categorical(q_values / softmax_temp) self.action = action_dist.sample() - self.action_prob = action_dist.sampled_action_prob() + self.action_prob = tf.exp(action_dist.sampled_action_logp()) return deterministic_actions = tf.argmax(q_values, axis=1) @@ -260,7 +259,10 @@ def build_q_networks(policy, q_model, input_dict, obs_space, action_space, config["model"]) policy.output_actions, policy.action_prob = qvp.action, qvp.action_prob - return policy.output_actions, policy.action_prob + actions = policy.output_actions + action_prob = (tf.log(policy.action_prob) + if policy.action_prob is not None else None) + return actions, action_prob def _build_parameter_noise(policy, pnet_params): diff --git a/rllib/agents/dqn/simple_q_policy.py b/rllib/agents/dqn/simple_q_policy.py index 0212fdef6524..44fd188533b2 100644 --- a/rllib/agents/dqn/simple_q_policy.py +++ b/rllib/agents/dqn/simple_q_policy.py @@ -128,9 +128,9 @@ def build_action_sampler(policy, q_model, input_dict, obs_space, action_space, deterministic_actions) action = tf.cond(policy.stochastic, lambda: stochastic_actions, lambda: deterministic_actions) - action_prob = None + action_logp = None - return action, action_prob + return action, action_logp def build_q_losses(policy, batch_tensors): diff --git a/rllib/agents/es/policies.py b/rllib/agents/es/policies.py index 8b15cfca4a85..3ddb4dbeda9d 100644 --- a/rllib/agents/es/policies.py +++ b/rllib/agents/es/policies.py @@ -59,7 +59,7 @@ def __init__(self, sess, action_space, obs_space, preprocessor, model = ModelCatalog.get_model({ "obs": self.inputs }, obs_space, action_space, dist_dim, model_options) - dist = dist_class(model.outputs, model_config=model_options) + dist = dist_class(model.outputs, model) self.sampler = dist.sample() self.variables = ray.experimental.tf_utils.TensorFlowVariables( diff --git a/rllib/agents/impala/vtrace.py b/rllib/agents/impala/vtrace.py index 7062d5d6f7a3..6edc9f571b83 100644 --- a/rllib/agents/impala/vtrace.py +++ b/rllib/agents/impala/vtrace.py @@ -49,14 +49,14 @@ def log_probs_from_logits_and_actions(policy_logits, actions, - config, - dist_class=Categorical): + dist_class=Categorical, + model=None): return multi_log_probs_from_logits_and_actions([policy_logits], [actions], - dist_class, config)[0] + dist_class, model)[0] def multi_log_probs_from_logits_and_actions(policy_logits, actions, dist_class, - config): + model): """Computes action log-probs from policy logits and actions. In the notation used throughout documentation and comments, T refers to the @@ -78,7 +78,6 @@ def multi_log_probs_from_logits_and_actions(policy_logits, actions, dist_class, [T, B, ...] with actions. 
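+    model: backing ModelV2 instance, or None; passed through to
+      dist_class when constructing the per-action distributions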
dist_class: Python class of the action distribution - config: Trainer config dict Returns: A list with length of ACTION_SPACE of float32 @@ -100,8 +99,7 @@ def multi_log_probs_from_logits_and_actions(policy_logits, actions, dist_class, tf.concat([[-1], a_shape[2:]], axis=0)) log_probs.append( tf.reshape( - dist_class(policy_logits_flat, - model_config=config["model"]).logp(actions_flat), + dist_class(policy_logits_flat, model).logp(actions_flat), a_shape[:2])) return log_probs @@ -114,8 +112,8 @@ def from_logits(behaviour_policy_logits, rewards, values, bootstrap_value, - config, dist_class=Categorical, + model=None, clip_rho_threshold=1.0, clip_pg_rho_threshold=1.0, name="vtrace_from_logits"): @@ -127,8 +125,8 @@ def from_logits(behaviour_policy_logits, rewards, values, bootstrap_value, - config, dist_class, + model, clip_rho_threshold=clip_rho_threshold, clip_pg_rho_threshold=clip_pg_rho_threshold, name=name) @@ -151,8 +149,9 @@ def multi_from_logits(behaviour_policy_logits, rewards, values, bootstrap_value, - config, dist_class, + model, + behaviour_action_log_probs=None, clip_rho_threshold=1.0, clip_pg_rho_threshold=1.0, name="vtrace_from_logits"): @@ -203,6 +202,8 @@ def multi_from_logits(behaviour_policy_logits, bootstrap_value: A float32 of shape [B] with the value function estimate at time T. dist_class: action distribution class for the logits. + model: backing ModelV2 instance + behaviour_action_log_probs: precalculated values of the behaviour actions clip_rho_threshold: A scalar float32 tensor with the clipping threshold for importance weights (rho) when calculating the baseline targets (vs). rho^bar in the paper. @@ -242,9 +243,16 @@ def multi_from_logits(behaviour_policy_logits, discounts, rewards, values, bootstrap_value ]): target_action_log_probs = multi_log_probs_from_logits_and_actions( - target_policy_logits, actions, dist_class, config) - behaviour_action_log_probs = multi_log_probs_from_logits_and_actions( - behaviour_policy_logits, actions, dist_class, config) + target_policy_logits, actions, dist_class, model) + + if (len(behaviour_policy_logits) > 1 + or behaviour_action_log_probs is None): + # can't use precalculated values, recompute them. Note that + # recomputing won't work well for autoregressive action dists + # which may have variables not captured by 'logits' + behaviour_action_log_probs = ( + multi_log_probs_from_logits_and_actions( + behaviour_policy_logits, actions, dist_class, model)) log_rhos = get_log_rhos(target_action_log_probs, behaviour_action_log_probs) diff --git a/rllib/agents/impala/vtrace_policy.py b/rllib/agents/impala/vtrace_policy.py index 45228ea7190e..d13bbea006a2 100644 --- a/rllib/agents/impala/vtrace_policy.py +++ b/rllib/agents/impala/vtrace_policy.py @@ -16,7 +16,7 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.policy.tf_policy import LearningRateSchedule, \ - EntropyCoeffSchedule + EntropyCoeffSchedule, ACTION_LOGP from ray.rllib.utils.explained_variance import explained_variance from ray.rllib.utils import try_import_tf @@ -33,6 +33,7 @@ def __init__(self, actions_logp, actions_entropy, dones, + behaviour_action_logp, behaviour_logits, target_logits, discount, @@ -40,6 +41,7 @@ def __init__(self, values, bootstrap_value, dist_class, + model, valid_mask, config, vf_loss_coeff=0.5, @@ -57,6 +59,7 @@ def __init__(self, actions_logp: A float32 tensor of shape [T, B]. actions_entropy: A float32 tensor of shape [T, B]. 
dones: A bool tensor of shape [T, B]. + behaviour_action_logp: Tensor of shape [T, B]. behaviour_logits: A list with length of ACTION_SPACE of float32 tensors of shapes [T, B, ACTION_SPACE[0]], @@ -79,6 +82,7 @@ def __init__(self, # Compute vtrace on the CPU for better perf. with tf.device("/cpu:0"): self.vtrace_returns = vtrace.multi_from_logits( + behaviour_action_log_probs=behaviour_action_logp, behaviour_policy_logits=behaviour_logits, target_policy_logits=target_logits, actions=tf.unstack(actions, axis=2), @@ -87,10 +91,10 @@ def __init__(self, values=values, bootstrap_value=bootstrap_value, dist_class=dist_class, + model=model, clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32), clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold, - tf.float32), - config=config) + tf.float32)) self.value_targets = self.vtrace_returns.vs # The policy gradients loss @@ -164,6 +168,7 @@ def make_time_major(*args, **kw): actions = batch_tensors[SampleBatch.ACTIONS] dones = batch_tensors[SampleBatch.DONES] rewards = batch_tensors[SampleBatch.REWARDS] + behaviour_action_logp = batch_tensors[ACTION_LOGP] behaviour_logits = batch_tensors[BEHAVIOUR_LOGITS] unpacked_behaviour_logits = tf.split( behaviour_logits, output_hidden_shape, axis=1) @@ -190,6 +195,8 @@ def make_time_major(*args, **kw): actions_entropy=make_time_major( action_dist.multi_entropy(), drop_last=True), dones=make_time_major(dones, drop_last=True), + behaviour_action_logp=make_time_major( + behaviour_action_logp, drop_last=True), behaviour_logits=make_time_major( unpacked_behaviour_logits, drop_last=True), target_logits=make_time_major(unpacked_outputs, drop_last=True), @@ -198,6 +205,7 @@ def make_time_major(*args, **kw): values=make_time_major(values, drop_last=True), bootstrap_value=make_time_major(values)[-1], dist_class=Categorical if is_multidiscrete else policy.dist_class, + model=policy.model, valid_mask=make_time_major(mask, drop_last=True), config=policy.config, vf_loss_coeff=policy.config["vf_loss_coeff"], diff --git a/rllib/agents/impala/vtrace_test.py b/rllib/agents/impala/vtrace_test.py index 9d88fefa96fc..e1f39991b097 100644 --- a/rllib/agents/impala/vtrace_test.py +++ b/rllib/agents/impala/vtrace_test.py @@ -98,7 +98,7 @@ def test_log_probs_from_logits_and_actions(self, batch_size): 0, num_actions - 1, size=(seq_len, batch_size), dtype=np.int32) action_log_probs_tensor = vtrace.log_probs_from_logits_and_actions( - policy_logits, actions, {"model": None}) # dummy config dict + policy_logits, actions) # Ground Truth # Using broadcasting to create a mask that indexes action logits @@ -159,8 +159,6 @@ def test_vtrace_from_logits(self, batch_size): clip_rho_threshold = None # No clipping. clip_pg_rho_threshold = None # No clipping. - dummy_config = {"model": None} - # Intentionally leaving shapes unspecified to test if V-trace can # deal with that. 
placeholders = { @@ -180,15 +178,12 @@ def test_vtrace_from_logits(self, batch_size): from_logits_output = vtrace.from_logits( clip_rho_threshold=clip_rho_threshold, clip_pg_rho_threshold=clip_pg_rho_threshold, - config=dummy_config, **placeholders) target_log_probs = vtrace.log_probs_from_logits_and_actions( - placeholders["target_policy_logits"], placeholders["actions"], - dummy_config) + placeholders["target_policy_logits"], placeholders["actions"]) behaviour_log_probs = vtrace.log_probs_from_logits_and_actions( - placeholders["behaviour_policy_logits"], placeholders["actions"], - dummy_config) + placeholders["behaviour_policy_logits"], placeholders["actions"]) log_rhos = target_log_probs - behaviour_log_probs ground_truth = (log_rhos, behaviour_log_probs, target_log_probs) diff --git a/rllib/agents/marwil/marwil_policy.py b/rllib/agents/marwil/marwil_policy.py index 3ee1abfcbe33..72b8a239383b 100644 --- a/rllib/agents/marwil/marwil_policy.py +++ b/rllib/agents/marwil/marwil_policy.py @@ -29,7 +29,7 @@ def __init__(self, state_values, cumulative_rewards): class ReweightedImitationLoss(object): def __init__(self, state_values, cumulative_rewards, logits, actions, - action_space, beta, model_config): + action_space, beta, model): ma_adv_norm = tf.get_variable( name="moving_average_of_advantage_norm", dtype=tf.float32, @@ -48,8 +48,8 @@ def __init__(self, state_values, cumulative_rewards, logits, actions, beta * tf.divide(adv, 1e-8 + tf.sqrt(ma_adv_norm))) # log\pi_\theta(a|s) - dist_cls, _ = ModelCatalog.get_action_dist(action_space, model_config) - action_dist = dist_cls(logits, model_config=model_config) + dist_class, _ = ModelCatalog.get_action_dist(action_space, {}) + action_dist = dist_class(logits, model) logprobs = action_dist.logp(actions) self.loss = -1.0 * tf.reduce_mean( @@ -84,7 +84,7 @@ def __init__(self, observation_space, action_space, config): config = dict(ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, **config) self.config = config - dist_cls, logit_dim = ModelCatalog.get_action_dist( + dist_class, logit_dim = ModelCatalog.get_action_dist( action_space, self.config["model"]) # Action inputs @@ -106,7 +106,7 @@ def __init__(self, observation_space, action_space, config): self.p_func_vars = scope_vars(scope.name) # Action outputs - action_dist = dist_cls(logits, model_config=self.config["model"]) + action_dist = dist_class(logits, self.model) self.output_actions = action_dist.sample() # Training inputs @@ -141,7 +141,7 @@ def __init__(self, observation_space, action_space, config): self.sess, obs_input=self.obs_t, action_sampler=self.output_actions, - action_prob=action_dist.sampled_action_prob(), + action_logp=action_dist.sampled_action_logp(), loss=objective, model=self.model, loss_inputs=self.loss_inputs, @@ -165,7 +165,7 @@ def _build_policy_loss(self, state_values, cum_rwds, logits, actions, action_space): return ReweightedImitationLoss(state_values, cum_rwds, logits, actions, action_space, self.config["beta"], - self.config["model"]) + self.model) @override(TFPolicy) def extra_compute_grad_fetches(self): diff --git a/rllib/agents/pg/torch_pg_policy.py b/rllib/agents/pg/torch_pg_policy.py index 2dc4a280f5ac..1e1fca7c4057 100644 --- a/rllib/agents/pg/torch_pg_policy.py +++ b/rllib/agents/pg/torch_pg_policy.py @@ -13,8 +13,7 @@ def pg_torch_loss(policy, batch_tensors): logits, _ = policy.model({ SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS] }) - action_dist = policy.dist_class( - logits, model_config=policy.config["model"]) + action_dist = policy.dist_class(logits, 
policy.model) log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS]) # save the error in the policy object policy.pi_err = -batch_tensors[Postprocessing.ADVANTAGES].dot( diff --git a/rllib/agents/ppo/appo_policy.py b/rllib/agents/ppo/appo_policy.py index 604eeab96885..6ecc8189a814 100644 --- a/rllib/agents/ppo/appo_policy.py +++ b/rllib/agents/ppo/appo_policy.py @@ -112,8 +112,8 @@ def __init__(self, rewards, values, bootstrap_value, - config, dist_class, + model, valid_mask, vf_loss_coeff=0.5, entropy_coeff=0.01, @@ -144,8 +144,8 @@ def __init__(self, rewards: A float32 tensor of shape [T, B]. values: A float32 tensor of shape [T, B]. bootstrap_value: A float32 tensor of shape [B]. - config: Trainer config dict. dist_class: action distribution class for logits. + model: backing ModelV2 instance valid_mask: A bool tensor of valid RNN input elements (#2992). vf_loss_coeff (float): Coefficient of the value function loss. entropy_coeff (float): Coefficient of the entropy regularizer. @@ -167,8 +167,8 @@ def reduce_mean_valid(t): rewards=rewards, values=values, bootstrap_value=bootstrap_value, - config=config, dist_class=dist_class, + model=model, clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32), clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold, tf.float32)) @@ -254,10 +254,9 @@ def make_time_major(*args, **kw): old_policy_behaviour_logits, output_hidden_shape, axis=1) unpacked_outputs = tf.split(policy.model_out, output_hidden_shape, axis=1) action_dist = policy.action_dist - old_policy_action_dist = policy.dist_class( - old_policy_behaviour_logits, model_config=policy.config["model"]) - prev_action_dist = policy.dist_class( - behaviour_logits, model_config=policy.config["model"]) + old_policy_action_dist = policy.dist_class(old_policy_behaviour_logits, + policy.model) + prev_action_dist = policy.dist_class(behaviour_logits, policy.model) values = policy.value_function policy.model_vars = policy.model.variables() @@ -303,8 +302,8 @@ def make_time_major(*args, **kw): rewards=make_time_major(rewards, drop_last=True), values=make_time_major(values, drop_last=True), bootstrap_value=make_time_major(values)[-1], - config=policy.config, dist_class=Categorical if is_multidiscrete else policy.dist_class, + model=policy.model, valid_mask=make_time_major(mask, drop_last=True), vf_loss_coeff=policy.config["vf_loss_coeff"], entropy_coeff=policy.config["entropy_coeff"], diff --git a/rllib/agents/ppo/ppo_policy.py b/rllib/agents/ppo/ppo_policy.py index d41aeb90043b..60d05a5a6802 100644 --- a/rllib/agents/ppo/ppo_policy.py +++ b/rllib/agents/ppo/ppo_policy.py @@ -9,9 +9,8 @@ Postprocessing from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy import LearningRateSchedule, \ - EntropyCoeffSchedule + EntropyCoeffSchedule, ACTION_LOGP from ray.rllib.policy.tf_policy_template import build_tf_policy -from ray.rllib.models.catalog import ModelCatalog from ray.rllib.utils.explained_variance import explained_variance from ray.rllib.utils import try_import_tf @@ -26,10 +25,13 @@ class PPOLoss(object): def __init__(self, action_space, + dist_class, + model, value_targets, advantages, actions, - logits, + prev_logits, + prev_actions_logp, vf_preds, curr_action_dist, value_fn, @@ -45,13 +47,16 @@ def __init__(self, Arguments: action_space: Environment observation space specification. + dist_class: action distribution class for logits. value_targets (Placeholder): Placeholder for target values; used for GAE. 
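+            model: backing ModelV2 instance; passed to dist_class to build
+                the action distribution from the previous model evaluation.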
actions (Placeholder): Placeholder for actions taken from previous model evaluation. advantages (Placeholder): Placeholder for calculated advantages from previous model evaluation. - logits (Placeholder): Placeholder for logits output from + prev_logits (Placeholder): Placeholder for logits output from + previous model evaluation. + prev_actions_logp (Placeholder): Placeholder for prob output from previous model evaluation. vf_preds (Placeholder): Placeholder for value function output from previous model evaluation. @@ -73,11 +78,9 @@ def __init__(self, def reduce_mean_valid(t): return tf.reduce_mean(tf.boolean_mask(t, valid_mask)) - dist_cls, _ = ModelCatalog.get_action_dist(action_space, model_config) - prev_dist = dist_cls(logits, model_config=model_config) + prev_dist = dist_class(prev_logits, model) # Make loss functions. - logp_ratio = tf.exp( - curr_action_dist.logp(actions) - prev_dist.logp(actions)) + logp_ratio = tf.exp(curr_action_dist.logp(actions) - prev_actions_logp) action_kl = prev_dist.kl(curr_action_dist) self.mean_kl = reduce_mean_valid(action_kl) @@ -119,10 +122,13 @@ def ppo_surrogate_loss(policy, batch_tensors): policy.loss_obj = PPOLoss( policy.action_space, + policy.dist_class, + policy.model, batch_tensors[Postprocessing.VALUE_TARGETS], batch_tensors[Postprocessing.ADVANTAGES], batch_tensors[SampleBatch.ACTIONS], batch_tensors[BEHAVIOUR_LOGITS], + batch_tensors[ACTION_LOGP], batch_tensors[SampleBatch.VF_PREDS], policy.action_dist, policy.value_function, diff --git a/rllib/examples/autoregressive_action_dist.py b/rllib/examples/autoregressive_action_dist.py new file mode 100644 index 000000000000..819341594f2c --- /dev/null +++ b/rllib/examples/autoregressive_action_dist.py @@ -0,0 +1,212 @@ +"""Example of specifying an autoregressive action distribution. + +In an action space with multiple components (e.g., Tuple(a1, a2)), you might +want a2 to be sampled based on the sampled value of a1, i.e., +a2_sampled ~ P(a2 | a1_sampled, obs). Normally, a1 and a2 would be sampled +independently. + +To do this, you need both a custom model that implements the autoregressive +pattern, and a custom action distribution class that leverages that model. +This examples shows both. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gym +from gym.spaces import Discrete, Tuple +import argparse +import random + +import ray +from ray import tune +from ray.rllib.models import ModelCatalog +from ray.rllib.models.tf.tf_action_dist import Categorical, ActionDistribution +from ray.rllib.models.tf.misc import normc_initializer +from ray.rllib.models.tf.tf_modelv2 import TFModelV2 +from ray.rllib.policy.policy import TupleActions +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() + +parser = argparse.ArgumentParser() +parser.add_argument("--run", type=str, default="PPO") # try PG, PPO, IMPALA +parser.add_argument("--stop", type=int, default=200) + + +class CorrelatedActionsEnv(gym.Env): + """Simple env in which the policy has to emit a tuple of equal actions. 
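+
+    The observation is a random binary value; each step yields +5 reward
+    if a1 matches the last observation and another +5 if a2 == a1, over
+    an episode of about 20 steps.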
+ + The best score would be ~200 reward.""" + + def __init__(self, _): + self.observation_space = Discrete(2) + self.action_space = Tuple([Discrete(2), Discrete(2)]) + + def reset(self): + self.t = 0 + self.last = random.choice([0, 1]) + return self.last + + def step(self, action): + self.t += 1 + a1, a2 = action + reward = 0 + if a1 == self.last: + reward += 5 + # encourage correlation between a1 and a2 + if a1 == a2: + reward += 5 + done = self.t > 20 + self.last = random.choice([0, 1]) + return self.last, reward, done, {} + + +class BinaryAutoregressiveOutput(ActionDistribution): + """Action distribution P(a1, a2) = P(a1) * P(a2 | a1)""" + + @staticmethod + def required_model_output_shape(self, model_config): + return 16 # controls model output feature vector size + + def sample(self): + # first, sample a1 + a1_dist = self._a1_distribution() + a1 = a1_dist.sample() + + # sample a2 conditioned on a1 + a2_dist = self._a2_distribution(a1) + a2 = a2_dist.sample() + self._action_logp = a1_dist.logp(a1) + a2_dist.logp(a2) + + # return the action tuple + return TupleActions([a1, a2]) + + def logp(self, actions): + a1, a2 = actions[:, 0], actions[:, 1] + a1_vec = tf.expand_dims(tf.cast(a1, tf.float32), 1) + a1_logits, a2_logits = self.model.action_model([self.inputs, a1_vec]) + return ( + Categorical(a1_logits).logp(a1) + Categorical(a2_logits).logp(a2)) + + def sampled_action_logp(self): + return tf.exp(self._action_logp) + + def entropy(self): + a1_dist = self._a1_distribution() + a2_dist = self._a2_distribution(a1_dist.sample()) + return a1_dist.entropy() + a2_dist.entropy() + + def kl(self, other): + # TODO: implement this properly + return tf.zeros_like(self.entropy()) + + def _a1_distribution(self): + BATCH = tf.shape(self.inputs)[0] + a1_logits, _ = self.model.action_model( + [self.inputs, tf.zeros((BATCH, 1))]) + a1_dist = Categorical(a1_logits) + return a1_dist + + def _a2_distribution(self, a1): + a1_vec = tf.expand_dims(tf.cast(a1, tf.float32), 1) + _, a2_logits = self.model.action_model([self.inputs, a1_vec]) + a2_dist = Categorical(a2_logits) + return a2_dist + + +class AutoregressiveActionsModel(TFModelV2): + """Implements the `.action_model` branch required above.""" + + def __init__(self, obs_space, action_space, num_outputs, model_config, + name): + super(AutoregressiveActionsModel, self).__init__( + obs_space, action_space, num_outputs, model_config, name) + if action_space != Tuple([Discrete(2), Discrete(2)]): + raise ValueError( + "This model only supports the [2, 2] action space") + + # Inputs + obs_input = tf.keras.layers.Input( + shape=obs_space.shape, name="obs_input") + a1_input = tf.keras.layers.Input(shape=(1, ), name="a1_input") + ctx_input = tf.keras.layers.Input( + shape=(num_outputs, ), name="ctx_input") + + # Output of the model (normally 'logits', but for an autoregressive + # dist this is more like a context/feature layer encoding the obs) + context = tf.keras.layers.Dense( + num_outputs, + name="hidden", + activation=tf.nn.tanh, + kernel_initializer=normc_initializer(1.0))(obs_input) + + # V(s) + value_out = tf.keras.layers.Dense( + 1, + name="value_out", + activation=None, + kernel_initializer=normc_initializer(0.01))(context) + + # P(a1 | obs) + a1_logits = tf.keras.layers.Dense( + 2, + name="a1_logits", + activation=None, + kernel_initializer=normc_initializer(0.01))(ctx_input) + + # P(a2 | a1) + # --note: typically you'd want to implement P(a2 | a1, obs) as follows: + # a2_context = tf.keras.layers.Concatenate(axis=1)( + # [ctx_input, a1_input]) + 
a2_context = a1_input + a2_hidden = tf.keras.layers.Dense( + 16, + name="a2_hidden", + activation=tf.nn.tanh, + kernel_initializer=normc_initializer(1.0))(a2_context) + a2_logits = tf.keras.layers.Dense( + 2, + name="a2_logits", + activation=None, + kernel_initializer=normc_initializer(0.01))(a2_hidden) + + # Base layers + self.base_model = tf.keras.Model(obs_input, [context, value_out]) + self.register_variables(self.base_model.variables) + self.base_model.summary() + + # Autoregressive action sampler + self.action_model = tf.keras.Model([ctx_input, a1_input], + [a1_logits, a2_logits]) + self.action_model.summary() + self.register_variables(self.action_model.variables) + + def forward(self, input_dict, state, seq_lens): + context, self._value_out = self.base_model(input_dict["obs"]) + return context, state + + def value_function(self): + return tf.reshape(self._value_out, [-1]) + + +if __name__ == "__main__": + ray.init() + args = parser.parse_args() + ModelCatalog.register_custom_model("autoregressive_model", + AutoregressiveActionsModel) + ModelCatalog.register_custom_action_dist("binary_autoreg_output", + BinaryAutoregressiveOutput) + tune.run( + args.run, + stop={"episode_reward_mean": args.stop}, + config={ + "env": CorrelatedActionsEnv, + "gamma": 0.5, + "num_gpus": 0, + "model": { + "custom_model": "autoregressive_model", + "custom_action_dist": "binary_autoreg_output", + }, + }) diff --git a/rllib/examples/centralized_critic.py b/rllib/examples/centralized_critic.py index 7e9495204465..9d07df5af6cf 100644 --- a/rllib/examples/centralized_critic.py +++ b/rllib/examples/centralized_critic.py @@ -28,7 +28,7 @@ from ray.rllib.models import ModelCatalog from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy import LearningRateSchedule, \ - EntropyCoeffSchedule + EntropyCoeffSchedule, ACTION_LOGP from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork from ray.rllib.utils.explained_variance import explained_variance @@ -141,10 +141,13 @@ def loss_with_central_critic(policy, batch_tensors): policy.loss_obj = PPOLoss( policy.action_space, + policy.dist_class, + policy.model, batch_tensors[Postprocessing.VALUE_TARGETS], batch_tensors[Postprocessing.ADVANTAGES], batch_tensors[SampleBatch.ACTIONS], batch_tensors[BEHAVIOUR_LOGITS], + batch_tensors[ACTION_LOGP], batch_tensors[SampleBatch.VF_PREDS], policy.action_dist, policy.central_value_function, diff --git a/rllib/examples/custom_torch_policy.py b/rllib/examples/custom_torch_policy.py index 8f6ef5444f8e..4fdb3a064c38 100644 --- a/rllib/examples/custom_torch_policy.py +++ b/rllib/examples/custom_torch_policy.py @@ -18,7 +18,7 @@ def policy_gradient_loss(policy, batch_tensors): logits, _ = policy.model({ SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS] }) - action_dist = policy.dist_class(logits, policy.config["model"]) + action_dist = policy.dist_class(logits, policy.model) log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS]) return -batch_tensors[SampleBatch.REWARDS].dot(log_probs) diff --git a/rllib/keras_policy.py b/rllib/keras_policy.py deleted file mode 100644 index 3008e133c1c6..000000000000 --- a/rllib/keras_policy.py +++ /dev/null @@ -1,65 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from ray.rllib.policy.policy import Policy - - -def _sample(probs): - return [np.random.choice(len(pr), p=pr) for pr in probs] - - -class 
KerasPolicy(Policy): - """Initialize the Keras Policy. - - This is a Policy used for models with actor and critics. - Note: This class is built for specific usage of Actor-Critic models, - and is less general compared to TFPolicy and TorchPolicies. - - Args: - observation_space (gym.Space): Observation space of the policy. - action_space (gym.Space): Action space of the policy. - config (dict): Policy-specific configuration data. - actor (Model): A model that holds the policy. - critic (Model): A model that holds the value function. - """ - - def __init__(self, - observation_space, - action_space, - config, - actor=None, - critic=None): - Policy.__init__(self, observation_space, action_space, config) - self.actor = actor - self.critic = critic - self.models = [self.actor, self.critic] - - def compute_actions(self, obs, *args, **kwargs): - state = np.array(obs) - policy = self.actor.predict(state) - value = self.critic.predict(state) - return _sample(policy), [], {"vf_preds": value.flatten()} - - def learn_on_batch(self, batch, *args): - self.actor.fit( - batch["obs"], - batch["adv_targets"], - epochs=1, - verbose=0, - steps_per_epoch=20) - self.critic.fit( - batch["obs"], - batch["value_targets"], - epochs=1, - verbose=0, - steps_per_epoch=20) - return {} - - def get_weights(self): - return [model.get_weights() for model in self.models] - - def set_weights(self, weights): - return [model.set_weights(w) for model, w in zip(self.models, weights)] diff --git a/rllib/models/action_dist.py b/rllib/models/action_dist.py index 9bfeb32fb120..f5a5f1e3c0cf 100644 --- a/rllib/models/action_dist.py +++ b/rllib/models/action_dist.py @@ -9,22 +9,35 @@ class ActionDistribution(object): """The policy action distribution of an agent. - Args: - inputs (Tensor): The input vector to compute samples from. - model_config (dict): Optional model config dict - (as defined in catalog.py) + Attributes: + inputs (Tensors): input vector to compute samples from. + model (ModelV2): reference to model producing the inputs. """ @DeveloperAPI - def __init__(self, inputs, model_config): + def __init__(self, inputs, model): + """Initialize the action dist. + + Arguments: + inputs (Tensors): input vector to compute samples from. + model (ModelV2): reference to model producing the inputs. This + is mainly useful if you want to use model variables to compute + action outputs (i.e., for auto-regressive action distributions, + see examples/autoregressive_action_dist.py). 
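+                For simple distributions such as Categorical, model may be
+                left None.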
+ """ self.inputs = inputs - self.model_config = model_config + self.model = model @DeveloperAPI def sample(self): """Draw a sample from the action distribution.""" raise NotImplementedError + @DeveloperAPI + def sampled_action_logp(self): + """Returns the log probability of the last sampled action.""" + raise NotImplementedError + @DeveloperAPI def logp(self, x): """The log-likelihood of the action distribution.""" diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index 243cfccc54ba..33b6b36cb573 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -97,10 +97,10 @@ class ModelCatalog(object): >>> prep = ModelCatalog.get_preprocessor(env) >>> observation = prep.transform(raw_observation) - >>> dist_cls, dist_dim = ModelCatalog.get_action_dist( + >>> dist_class, dist_dim = ModelCatalog.get_action_dist( env.action_space, {}) >>> model = ModelCatalog.get_model(inputs, dist_dim, options) - >>> dist = dist_cls(model.outputs) + >>> dist = dist_class(model.outputs, model) >>> action = dist.sample() """ diff --git a/rllib/models/tf/tf_action_dist.py b/rllib/models/tf/tf_action_dist.py index 48b5b40eb185..9474d278543c 100644 --- a/rllib/models/tf/tf_action_dist.py +++ b/rllib/models/tf/tf_action_dist.py @@ -17,9 +17,8 @@ class TFActionDistribution(ActionDistribution): """TF-specific extensions for building action distributions.""" @DeveloperAPI - def __init__(self, inputs, model_config): - super(TFActionDistribution, self).__init__( - inputs, model_config=model_config) + def __init__(self, inputs, model): + super(TFActionDistribution, self).__init__(inputs, model) self.sample_op = self._build_sample_op() @DeveloperAPI @@ -27,24 +26,28 @@ def _build_sample_op(self): """Implement this instead of sample(), to enable op reuse. This is needed since the sample op is non-deterministic and is shared - between sample() and sampled_action_prob(). + between sample() and sampled_action_logp(). """ raise NotImplementedError - @DeveloperAPI + @override(ActionDistribution) def sample(self): """Draw a sample from the action distribution.""" return self.sample_op - @DeveloperAPI - def sampled_action_prob(self): + @override(ActionDistribution) + def sampled_action_logp(self): """Returns the log probability of the sampled action.""" - return tf.exp(self.logp(self.sample_op)) + return self.logp(self.sample_op) class Categorical(TFActionDistribution): """Categorical distribution for discrete action spaces.""" + @DeveloperAPI + def __init__(self, inputs, model=None): + super(Categorical, self).__init__(inputs, model) + @override(ActionDistribution) def logp(self, x): return -tf.nn.sparse_softmax_cross_entropy_with_logits( @@ -86,13 +89,14 @@ def required_model_output_shape(action_space, model_config): class MultiCategorical(TFActionDistribution): """MultiCategorical distribution for MultiDiscrete action spaces.""" - def __init__(self, inputs, input_lens, model_config): + def __init__(self, inputs, model, input_lens): + # skip TFActionDistribution init + ActionDistribution.__init__(self, inputs, model) self.cats = [ - Categorical(input_, model_config=model_config) + Categorical(input_, model) for input_ in tf.split(inputs, input_lens, axis=1) ] self.sample_op = self._build_sample_op() - self.model_config = model_config @override(ActionDistribution) def logp(self, actions): @@ -136,12 +140,12 @@ class DiagGaussian(TFActionDistribution): second half the gaussian standard deviations. 
""" - def __init__(self, inputs, model_config): + def __init__(self, inputs, model): mean, log_std = tf.split(inputs, 2, axis=1) self.mean = mean self.log_std = log_std self.std = tf.exp(log_std) - super(DiagGaussian, self).__init__(inputs, model_config) + TFActionDistribution.__init__(self, inputs, model) @override(ActionDistribution) def logp(self, x): @@ -182,8 +186,8 @@ class Deterministic(TFActionDistribution): """ @override(TFActionDistribution) - def sampled_action_prob(self): - return 1.0 + def sampled_action_logp(self): + return 0.0 @override(TFActionDistribution) def _build_sample_op(self): @@ -202,14 +206,15 @@ class MultiActionDistribution(TFActionDistribution): inputs (Tensor list): A list of tensors from which to compute samples. """ - def __init__(self, inputs, action_space, child_distributions, input_lens, - model_config): + def __init__(self, inputs, model, action_space, child_distributions, + input_lens): + # skip TFActionDistribution init + ActionDistribution.__init__(self, inputs, model) self.input_lens = input_lens split_inputs = tf.split(inputs, self.input_lens, axis=1) child_list = [] for i, distribution in enumerate(child_distributions): - child_list.append( - distribution(split_inputs[i], model_config=model_config)) + child_list.append(distribution(split_inputs[i], model)) self.child_distributions = child_list @override(ActionDistribution) @@ -252,10 +257,10 @@ def sample(self): return TupleActions([s.sample() for s in self.child_distributions]) @override(TFActionDistribution) - def sampled_action_prob(self): - p = self.child_distributions[0].sampled_action_prob() + def sampled_action_logp(self): + p = self.child_distributions[0].sampled_action_logp() for c in self.child_distributions[1:]: - p *= c.sampled_action_prob() + p += c.sampled_action_logp() return p @@ -265,7 +270,7 @@ class Dirichlet(TFActionDistribution): e.g. actions that represent resource allocation.""" - def __init__(self, inputs, model_config): + def __init__(self, inputs, model): """Input is a tensor of logits. The exponential of logits is used to parametrize the Dirichlet distribution as all parameters need to be positive. 
An arbitrary small epsilon is added to the concentration @@ -280,8 +285,7 @@ def __init__(self, inputs, model_config): validate_args=True, allow_nan_stats=False, ) - super(Dirichlet, self).__init__( - concentration, model_config=model_config) + TFActionDistribution.__init__(self, concentration, model) @override(ActionDistribution) def logp(self, x): diff --git a/rllib/models/torch/torch_action_dist.py b/rllib/models/torch/torch_action_dist.py index b1a373f1582d..6c4eb72689fa 100644 --- a/rllib/models/torch/torch_action_dist.py +++ b/rllib/models/torch/torch_action_dist.py @@ -37,9 +37,8 @@ class TorchCategorical(TorchDistributionWrapper): """Wrapper class for PyTorch Categorical distribution.""" @override(ActionDistribution) - def __init__(self, inputs, model_config): + def __init__(self, inputs, model): self.dist = torch.distributions.categorical.Categorical(logits=inputs) - self.model_config = model_config @staticmethod @override(ActionDistribution) @@ -51,10 +50,9 @@ class TorchDiagGaussian(TorchDistributionWrapper): """Wrapper class for PyTorch Normal distribution.""" @override(ActionDistribution) - def __init__(self, inputs, model_config): + def __init__(self, inputs, model): mean, log_std = torch.chunk(inputs, 2, dim=1) self.dist = torch.distributions.normal.Normal(mean, torch.exp(log_std)) - self.model_config = model_config @override(TorchDistributionWrapper) def logp(self, actions): diff --git a/rllib/optimizers/aso_multi_gpu_learner.py b/rllib/optimizers/aso_multi_gpu_learner.py index cbc8f61988c8..80109edfff93 100644 --- a/rllib/optimizers/aso_multi_gpu_learner.py +++ b/rllib/optimizers/aso_multi_gpu_learner.py @@ -165,7 +165,7 @@ def _step(self): opt = s.idle_optimizers.get() with self.load_timer: - tuples = s.policy._get_loss_inputs_dict(batch) + tuples = s.policy._get_loss_inputs_dict(batch, shuffle=False) data_keys = [ph for _, ph in s.policy._loss_inputs] if s.policy._state_inputs: state_keys = s.policy._state_inputs + [s.policy._seq_lens] diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index be80ec6b3844..f47cfc1d0a2a 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -66,7 +66,7 @@ def __init__(self, All policy variables should be created in this function. If not specified, a default model will be created. action_sampler_fn (func): optional function that returns a - tuple of action and action prob tensors given + tuple of action and action logp tensors given (policy, model, input_dict, obs_space, action_space, config). If not specified, a default action distribution will be used. 
existing_inputs (OrderedDict): when copying a policy, this @@ -144,6 +144,7 @@ def __init__(self, logit_dim, self.config["model"], framework="tf") + if existing_inputs: self.state_in = [ v for k, v in existing_inputs.items() @@ -162,14 +163,13 @@ def __init__(self, # Setup action sampler if action_sampler_fn: self.action_dist = None - action_sampler, action_prob = action_sampler_fn( + action_sampler, action_logp = action_sampler_fn( self, self.model, self.input_dict, obs_space, action_space, config) else: - self.action_dist = self.dist_class( - self.model_out, model_config=self.config["model"]) + self.action_dist = self.dist_class(self.model_out, self.model) action_sampler = self.action_dist.sample() - action_prob = self.action_dist.sampled_action_prob() + action_logp = self.action_dist.sampled_action_logp() # Phase 1 init sess = tf.get_default_session() or tf.Session() @@ -184,7 +184,7 @@ def __init__(self, sess, obs_input=obs, action_sampler=action_sampler, - action_prob=action_prob, + action_logp=action_logp, loss=None, # dynamically initialized on run loss_inputs=[], model=self.model, diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index 5605044889bf..ad030ec732b1 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -22,6 +22,9 @@ tf = try_import_tf() logger = logging.getLogger(__name__) +ACTION_PROB = "action_prob" +ACTION_LOGP = "action_logp" + @DeveloperAPI class TFPolicy(Policy): @@ -59,7 +62,7 @@ def __init__(self, loss, loss_inputs, model=None, - action_prob=None, + action_logp=None, state_inputs=None, state_outputs=None, prev_action_input=None, @@ -87,7 +90,7 @@ def __init__(self, placeholders during loss computation. model (rllib.models.Model): used to integrate custom losses and stats from user-defined RLlib models. - action_prob (Tensor): probability of the sampled action. + action_logp (Tensor): log probability of the sampled action. state_inputs (list): list of RNN state input Tensors. state_outputs (list): list of RNN state output Tensors. prev_action_input (Tensor): placeholder for previous actions @@ -113,7 +116,9 @@ def __init__(self, self._prev_reward_input = prev_reward_input self._sampler = action_sampler self._is_training = self._get_is_training_placeholder() - self._action_prob = action_prob + self._action_logp = action_logp + self._action_prob = (tf.exp(self._action_logp) + if self._action_logp is not None else None) self._state_inputs = state_inputs or [] self._state_outputs = state_outputs or [] self._seq_lens = seq_lens @@ -297,8 +302,11 @@ def extra_compute_action_fetches(self): By default we only return action probability info (if present). """ - if self._action_prob is not None: - return {"action_prob": self._action_prob} + if self._action_logp is not None: + return { + ACTION_PROB: self._action_prob, + ACTION_LOGP: self._action_logp, + } else: return {} diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index 1a151625e2d7..aff405e332c5 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -30,7 +30,7 @@ class TorchPolicy(Policy): """ def __init__(self, observation_space, action_space, model, loss, - action_distribution_cls): + action_distribution_class): """Build a policy from policy and loss torch modules. Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES @@ -44,7 +44,7 @@ def __init__(self, observation_space, action_space, model, loss, first item is action logits, and the rest can be any value. 
loss (func): Function that takes (policy, batch_tensors) and returns a single scalar loss. - action_distribution_cls (ActionDistribution): Class for action + action_distribution_class (ActionDistribution): Class for action distribution. """ self.observation_space = observation_space @@ -56,7 +56,7 @@ def __init__(self, observation_space, action_space, model, loss, self._model = model.to(self.device) self._loss = loss self._optimizer = self.optimizer() - self._action_dist_cls = action_distribution_cls + self._action_dist_class = action_distribution_class @override(Policy) def compute_actions(self, @@ -78,8 +78,7 @@ def compute_actions(self, input_dict["prev_rewards"] = prev_reward_batch model_out = self._model(input_dict, state_batches, [1]) logits, state = model_out - action_dist = self._action_dist_cls( - logits, model_config=self.config["model"]) + action_dist = self._action_dist_class(logits, self._model) actions = action_dist.sample() return (actions.cpu().numpy(), [h.cpu().numpy() for h in state],