TRFL (pronounced "truffle") is a library built on top of TensorFlow that exposes several useful building blocks for implementing Reinforcement Learning agents.
Common RL algorithms describe a particular update to either a Policy, a Value function, or an Action-Value (Q) function. In Deep-RL, a policy, value function, or Q-function is typically represented by a neural network (the model, not to be confused with an environment model, which is used in model-based RL). We formulate common RL update rules for these neural networks as differentiable loss functions, as is common in (un-)supervised learning. Under automatic differentiation, the original update rule is recovered. We find that loss functions are more modular and composable than traditional RL updates, and more natural when combined with supervised or unsupervised objectives.
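As a hedged illustration of this idea (our notation, not the library's), consider the classic one-step Q-learning update

Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))

Treating the TD target r + gamma * max_a' Q(s', a') as a constant (i.e. wrapping it in a stop-gradient) and defining the loss L = 0.5 * (target - Q(s, a))^2, gradient descent on L with step size alpha performs exactly the same update; this is the convention used by `trfl.qlearning` in the example below.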
The loss functions and other operations provided here are implemented in pure TensorFlow. They are not complete algorithms, but implementations of RL-specific mathematical operations needed when building fully-functional RL agents. In particular, the updates are only valid if the input data are sampled in the correct manner. For example, the sequence-advantage-actor-critic loss (i.e. A2C) is only valid if the input trajectory is an unbiased sample from the current policy; i.e. the data are on-policy. This library cannot check or enforce such constraints.
TRFL can be installed with pip directly from GitHub, using the following command:
pip install git+git://github.com/deepmind/trfl.git
TRFL works with both the CPU and GPU versions of TensorFlow. To allow for either, it does not list TensorFlow as a requirement, so you need to install TensorFlow and TensorFlow Probability separately if you haven't already done so.
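For example (a hedged suggestion, not an official requirement pin), the dependencies can be installed with:

pip install tensorflow tensorflow-probability

Note that the snippets below use TF1-style APIs (e.g. tf.get_variable, tf.train.AdamOptimizer), so pick versions that still provide them.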
Import TensorFlow and TRFL.
import tensorflow as tf
import trfl
Define the relevant data associated to a transition in the environment from state `s_tm1` to state `s_t`. This typically includes action values (or other characterization of the agent's policy) in both the source and destination states. The action `a_tm1` is the one selected after observing `s_tm1`, and resulted in observing the immediate reward `r_t` and the subsequent state `s_t`. `pcont_t` represents a time-dependent discount factor, or (equivalently) the continuation probability from state `s_t`. In most applications, its value will be equal to a constant factor (e.g., 0.99), except if `s_t` is the final state in an episode, in which case it is set to zero.
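As a hedged illustration (the names `discount` and `done_t` below are ours, not part of the TRFL API), a per-transition `pcont_t` is often built from a constant discount and an end-of-episode flag:

# Illustrative only: combine a constant discount with a "done" flag.
discount = 0.99
done_t = tf.constant([1., 0.])               # 1. where s_t ended the episode
pcont_t_example = discount * (1. - done_t)   # -> [0., 0.99]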
# Q-values for the previous and next timesteps, shape [batch_size, num_actions].
q_tm1 = tf.get_variable(
    "q_tm1", initializer=[[1., 1., 0.], [1., 2., 0.]], dtype=tf.float32)
q_t = tf.get_variable(
    "q_t", initializer=[[0., 1., 0.], [1., 2., 0.]], dtype=tf.float32)
# Action indices, discounts and rewards, shape [batch_size].
a_tm1 = tf.constant([0, 1], dtype=tf.int32)
r_t = tf.constant([1, 1], dtype=tf.float32)
pcont_t = tf.constant([0, 1], dtype=tf.float32)  # per-transition discount / continuation probability
# Q-learning loss, and auxiliary data.
loss, q_learning = trfl.qlearning(q_tm1, a_tm1, r_t, pcont_t, q_t)
`loss` is the tensor representing the loss. For Q-learning, it is half the squared difference between the predicted Q-values and the TD targets, with shape `[batch_size]`. Extra information is in the `q_learning` namedtuple, including `q_learning.td_error` and `q_learning.target`.
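For example (a hedged sketch, not part of the library's documentation), the TD errors are a natural choice of priority when feeding a prioritized replay buffer:

# Illustrative only: absolute TD errors as replay priorities, shape [batch_size].
priorities = tf.abs(q_learning.td_error)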
Most of the time, you may only be interested in the loss, in which case you can use either of the following expressions:
loss, _ = trfl.qlearning(q_tm1, a_tm1, r_t, pcont_t, q_t)
loss = trfl.qlearning(q_tm1, a_tm1, r_t, pcont_t, q_t).loss
The `loss` tensor can be differentiated to derive the corresponding RL update. Note that in Q-learning, as in other bootstrapped losses, the TD targets are wrapped in a `tf.stop_gradient`. Differentiating `loss` therefore results in gradients with respect to `q_tm1` but not with respect to `q_t`.
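This can be verified directly (a hedged sketch using plain TensorFlow, not a TRFL op):

# Illustrative only: gradients flow into q_tm1 but not into q_t.
grads = tf.gradients(tf.reduce_sum(loss), [q_tm1, q_t])
# grads[0] is a tensor (the gradient w.r.t. q_tm1); grads[1] is None, because
# the TD target, which depends on q_t, sits behind tf.stop_gradient.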
reduced_loss = tf.reduce_mean(loss)
optimizer = tf.train.AdamOptimizer(learning_rate=0.1)
train_op = optimizer.minimize(reduced_loss)
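A minimal sketch of running the resulting op (assuming TF1-style graph execution, as in the snippets above):

# Illustrative only: initialize the variables and take a few gradient steps.
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(5):
    loss_value, _ = sess.run([reduced_loss, train_op])
    print(loss_value)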
All loss functions in the package return both a loss tensor and a namedtuple with extra information, using the above convention, but different functions may have different `extra` fields. Check the documentation of each function below for more information.
Throughout the package, we use the following conventions:
- Time indices and variable names:
  - `q_tm1`: the action value in the source state of a transition.
  - `a_tm1`: the action that was selected in the source state.
  - `r_t`: the resulting rewards collected in the destination state.
  - `pcont_t`: the continuation probability / discount for a transition.
  - `q_t`: the action values in the destination state.
- Tensor shapes:
  - All ops support minibatches only. We use `B` to denote the batch size.
  - Batches of rewards and continuation probabilities / discounts have shape `[B]`.
  - Batches of state-values have shape `[B]`.
  - Batches of action-values / q-values have shape `[B, num_actions]`.
  - All losses have shape `[B]`, i.e. the loss is not reduced over the batch dimension. This allows the user to easily weight the loss for different elements of the batch (e.g., by their importance sampling weights), as in the sketch after this list.
  - For ops that take batches of sequences of data, `T` denotes the sequence length. Tensors are time-major, and have shape `[T, B, ...]`. Index `0` of the time dimension is assumed to be the start of the sequence.
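For example, because losses are returned per batch element, they can be reweighted before the final reduction (a hedged sketch; `is_weights` is a hypothetical tensor of importance sampling weights, not a TRFL API):

# Illustrative only: weight the per-element loss before reducing over the batch.
is_weights = tf.constant([0.5, 2.0], dtype=tf.float32)  # hypothetical weights, shape [B]
weighted_loss = tf.reduce_mean(is_weights * loss)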
The library provides the following families of loss functions and other ops:

- State Value learning:
- Discrete-action Value learning:
- Distributional Value learning:
- Continuous-action Policy Gradient:
- Deterministic Policy Gradient:
- Discrete-action Policy Gradient:
  - `discrete_policy_entropy_loss`
  - `sequence_advantage_actor_critic_loss`: this is the commonly-used A2C/A3C loss function.
  - `discrete_policy_gradient`
  - `discrete_policy_gradient_loss`
- Pixel control:
- Retrace:
- Target Network Updating:
- V-trace:
- Clipping ops
- Distributions
- Indexing ops
- Periodic execution ops
- Policy ops
- Sequence ops
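As one example from the list above, here is a hedged sketch of target network updating (this assumes the op is exposed as `trfl.update_target_variables`; check the installed version's documentation for the exact name and signature):

# Illustrative only: copy the online network's variables into a target network.
target_q = tf.get_variable(
    "target_q", initializer=[[0., 0., 0.], [0., 0., 0.]], dtype=tf.float32)
update_target_op = trfl.update_target_variables(
    target_variables=[target_q], source_variables=[q_tm1])
# Running update_target_op assigns q_tm1's current values to target_q.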