TRFL: Reinforcement Learning Building Blocks

TRFL (pronounced "truffle") is a library built on top of TensorFlow that exposes several useful building blocks for implementing Reinforcement Learning agents.

Background

Common RL algorithms describe a particular update to either a policy, a value function, or an action-value (Q) function. In deep RL, a policy, value function, or Q-function is typically represented by a neural network (the model, not to be confused with an environment model, which is used in model-based RL). We formulate common RL update rules for these neural networks as differentiable loss functions, as is common in (un)supervised learning. Under automatic differentiation, the original update rule is recovered. We find that loss functions are more modular and composable than traditional RL updates, and more natural to combine with supervised or unsupervised objectives.
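To make the point about recovering the update rule concrete, here is a minimal pure-Python check for a single Q-value. It is not TRFL code. It takes one gradient-descent step on the loss 0.5 * (target - q)**2, holding the target constant, and lands on the same value as the classic TD update q + alpha * (target - q). The numbers are illustrative assumptions, and a finite-difference derivative stands in for automatic differentiation.

```python
# Sketch (not TRFL code): gradient descent on a half-squared TD loss, with the
# target held fixed, reproduces the tabular update q <- q + alpha * (target - q).

def td_loss(q, target):
    """Half squared TD error; `target` is treated as a constant."""
    return 0.5 * (target - q) ** 2

def numerical_grad(f, x, eps=1e-6):
    """Central finite-difference derivative of a scalar function of one variable."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)

q = 0.5                              # current estimate Q(s_tm1, a_tm1)
r, gamma, max_q_next = 1.0, 0.9, 2.0
target = r + gamma * max_q_next      # TD target, computed once and held fixed
alpha = 0.1                          # learning rate

grad = numerical_grad(lambda x: td_loss(x, target), q)
q_from_loss = q - alpha * grad              # one gradient-descent step on the loss
q_from_rule = q + alpha * (target - q)      # classic TD / Q-learning update
print(q_from_loss, q_from_rule)
```

Because the derivative of the loss with respect to q is -(target - q), the two results agree up to floating-point error.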

The loss functions and other operations provided here are implemented in pure TensorFlow. They are not complete algorithms, but implementations of RL-specific mathematical operations needed when building fully functional RL agents. In particular, the updates are only valid if the input data are sampled in the correct manner. For example, the sequence-advantage-actor-critic loss (i.e., A2C) is only valid if the input trajectory is an unbiased sample from the current policy, i.e., the data are on-policy. This library cannot check or enforce such constraints.

Installation

TRFL can be installed with pip directly from GitHub, using the following command (GitHub no longer serves the unauthenticated git:// protocol, so use HTTPS):

pip install git+https://github.com/deepmind/trfl.git

TRFL works with both the CPU and GPU versions of TensorFlow. To allow for either, it does not list TensorFlow as a requirement, so you need to install TensorFlow and TensorFlow Probability separately if you haven't already done so.

Example usage

Import TensorFlow and TRFL.

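```python
# These examples use the TensorFlow 1.x API (tf.get_variable, tf.train.AdamOptimizer);
# under TensorFlow 2.x, the same symbols are available under tf.compat.v1.
import tensorflow as tf
import trfl
```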

Define the relevant data associated with a transition in the environment from state s_tm1 to state s_t. This typically includes action values (or another characterization of the agent's policy) in both the source and destination states. The action a_tm1 is the one selected after observing s_tm1; it resulted in the immediate reward r_t and the subsequent state s_t. pcont_t represents a time-dependent discount factor or, equivalently, the continuation probability from state s_t. In most applications its value equals a constant factor (e.g., 0.99), except when s_t is the final state of an episode, in which case it is set to zero.
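A common way to build pcont_t is sketched here with NumPy. It assumes the environment reports a boolean done flag per transition. The names done and gamma are illustrative and are not TRFL arguments.

```python
import numpy as np

# Hypothetical per-transition termination flags (True if s_t ends the episode).
done = np.array([True, False])
gamma = 0.99  # assumed constant discount factor

# Continuation probability / discount: gamma for ordinary steps, 0 at episode end.
pcont_t = gamma * (1.0 - done.astype(np.float32))
```

With gamma = 1, the same construction yields the [0, 1] discounts used in the example that follows.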

```python
# Q-values for the previous and next timesteps, shape [batch_size, num_actions].
q_tm1 = tf.get_variable(
    "q_tm1", initializer=[[1., 1., 0.], [1., 2., 0.]], dtype=tf.float32)
q_t = tf.get_variable(
    "q_t", initializer=[[0., 1., 0.], [1., 2., 0.]], dtype=tf.float32)

# Action indices, rewards and discounts, shape [batch_size].
a_tm1 = tf.constant([0, 1], dtype=tf.int32)
r_t = tf.constant([1, 1], dtype=tf.float32)
pcont_t = tf.constant([0, 1], dtype=tf.float32)  # discount; 0 marks a terminal s_t

# Q-learning loss, and auxiliary data.
loss, q_learning = trfl.qlearning(q_tm1, a_tm1, r_t, pcont_t, q_t)
```

loss is the tensor representing the loss. For Q-learning, it is half the squared difference between the predicted Q-values and the TD targets, shape [batch_size]. Extra information is in the q_learning namedtuple, including q_learning.td_error and q_learning.target.
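For intuition, the quantities described above can be recomputed by hand. The NumPy function below is a reference sketch of the Q-learning math: the TD target r_t + pcont_t * max_a q_t, the TD error, and half the squared error. It is written for illustration only and is not TRFL's implementation; qlearning_reference is a hypothetical name. Applied to the example values, it gives the per-element results noted in the final comment.

```python
import numpy as np

def qlearning_reference(q_tm1, a_tm1, r_t, pcont_t, q_t):
    """NumPy sketch of the Q-learning loss math; not TRFL's implementation."""
    batch = np.arange(q_tm1.shape[0])
    qa_tm1 = q_tm1[batch, a_tm1]                 # Q(s_tm1, a_tm1), shape [B]
    target = r_t + pcont_t * q_t.max(axis=1)     # TD target, shape [B]
    td_error = target - qa_tm1                   # shape [B]
    loss = 0.5 * np.square(td_error)             # per-example loss, shape [B]
    return loss, target, td_error

# Same values as the TensorFlow example.
q_tm1 = np.array([[1., 1., 0.], [1., 2., 0.]])
q_t = np.array([[0., 1., 0.], [1., 2., 0.]])
a_tm1 = np.array([0, 1])
r_t = np.array([1., 1.])
pcont_t = np.array([0., 1.])

loss, target, td_error = qlearning_reference(q_tm1, a_tm1, r_t, pcont_t, q_t)
# target = [1, 3], td_error = [0, 1], loss = [0, 0.5]
```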

Most of the time, you may only be interested in the loss, in which case you can use either of the following expressions:

```python
loss, _ = trfl.qlearning(q_tm1, a_tm1, r_t, pcont_t, q_t)
loss = trfl.qlearning(q_tm1, a_tm1, r_t, pcont_t, q_t).loss
```

The loss tensor can be differentiated to derive the corresponding RL update. Note that in Q-learning, as in other bootstrapped losses, the TD targets are wrapped in a tf.stop_gradient. Differentiating loss therefore results in gradients with respect to q_tm1 but not with respect to q_t.
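The effect of the stop-gradient can be written out by hand. The TD target is treated as a constant, so the gradient of each per-example loss 0.5 * td_error**2 with respect to q_tm1 is -td_error in the column of the selected action and zero elsewhere, and q_t receives no gradient. The NumPy sketch below builds that gradient explicitly, using illustrative values; it is not TRFL code.

```python
import numpy as np

# Illustrative TD errors for a batch of two transitions, and the actions taken.
td_error = np.array([0., 1.])
a_tm1 = np.array([0, 1])
num_actions = 3

# d(0.5 * td_error**2) / d q_tm1[b, a] = -td_error[b] if a == a_tm1[b], else 0,
# since td_error = target - q_tm1[b, a_tm1[b]] and the target is held constant.
one_hot = np.eye(num_actions)[a_tm1]          # shape [B, num_actions]
grad_q_tm1 = -td_error[:, None] * one_hot     # shape [B, num_actions]

# q_t only appears inside the (stopped) target, so its gradient is identically zero.
grad_q_t = np.zeros_like(grad_q_tm1)
```

A gradient-descent step q_tm1 - lr * grad_q_tm1 therefore moves only the selected action's value toward its target.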

```python
reduced_loss = tf.reduce_mean(loss)
optimizer = tf.train.AdamOptimizer(learning_rate=0.1)
train_op = optimizer.minimize(reduced_loss)
```

All loss functions in the package follow the convention above: they return both a loss tensor and a namedtuple with extra information, although the extra fields differ from function to function. Check the documentation of each function below for details.
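The (loss, extra) convention maps naturally onto a namedtuple, which is why both tuple unpacking and attribute access work. The pure-Python sketch below illustrates the pattern with a toy scalar loss. The type and field names are illustrative assumptions, not TRFL's own definitions.

```python
import collections

# Illustrative only: these names are assumptions that mirror the convention
# described above, not TRFL's actual definitions.
LossOutput = collections.namedtuple("LossOutput", ["loss", "extra"])
QExtra = collections.namedtuple("QExtra", ["target", "td_error"])

def toy_qlearning_loss(qa_tm1, target):
    """Toy scalar loss returning (loss, extra) in the same style of API."""
    td_error = target - qa_tm1
    return LossOutput(loss=0.5 * td_error ** 2,
                      extra=QExtra(target=target, td_error=td_error))

loss_a, _ = toy_qlearning_loss(2.0, 3.0)       # tuple unpacking
loss_b = toy_qlearning_loss(2.0, 3.0).loss     # attribute access
extra = toy_qlearning_loss(2.0, 3.0).extra     # function-specific extra fields
```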

Naming Conventions and Developer Guidelines

Throughout the package, we use the following conventions:

  • Time indices and variable names:

    • q_tm1: the action values in the source state of a transition.
    • a_tm1: the action that was selected in the source state.
    • r_t: the resulting rewards collected in the destination state.
    • pcont_t: the continuation probability / discount for a transition.
    • q_t: the action values in the destination state.
  • Tensor shapes:

    • All ops support minibatches only. We use B to denote the batch size.
    • Batches of rewards, continuation probabilities / discounts have shape [B].
    • Batches of state-values have shape [B].
    • Batches of action-values / q-values have shape [B, num_actions].
    • All losses have shape [B], i.e. the loss is not reduced over the batch dimension. This allows the user to easily weight the loss for different elements of the batch (e.g., by their importance sampling weights).
    • For ops that take batches of sequences of data, T denotes the sequence length. Tensors are time-major, and have shape [T, B, ...]. Index 0 of the time dimension is assumed to be the start of the sequence.
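Both shape conventions can be illustrated with NumPy, using made-up values and hypothetical importance weights. The sketch below first weights a [B]-shaped loss before averaging over the batch. It then walks a time-major [T, B] sequence along axis 0 to accumulate a discounted return per batch element, assuming the return at step t is r[t] + pcont[t] times the return from step t + 1, with a zero bootstrap value after the last step.

```python
import numpy as np

# Per-example losses, shape [B] (made-up values standing in for a TRFL loss).
loss = np.array([0.0, 0.5, 2.0])
# Hypothetical importance-sampling weights, e.g. from a prioritized replay buffer.
is_weights = np.array([1.0, 0.5, 0.25])

# Weight each batch element, then reduce over the batch dimension.
weighted_loss = np.mean(is_weights * loss)

# Time-major sequence data, shape [T, B] = [3, 2]; rewards[0] is the first step.
rewards = np.array([[1., 1.], [1., 1.], [1., 1.]])
pcont = np.array([[0.9, 0.9], [0.9, 0.0], [0.9, 0.9]])  # 0.0 marks an episode end

# Discounted return from t = 0, iterating backwards over the time axis (axis 0),
# assuming a bootstrap value of zero after the last step.
returns = np.zeros(rewards.shape[1])          # shape [B]
for t in reversed(range(rewards.shape[0])):
    returns = rewards[t] + pcont[t] * returns
```

Here weighted_loss is (0 + 0.25 + 0.5) / 3 = 0.25. The second batch element's return stops accumulating at the step where its pcont is zero.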

Learning updates

Other

More information