Custom action distributions #5164

Merged (22 commits, Aug 6, 2019)

Changes from 20 commits

Commits (22)
5176959  custom action dist wip (mawright, Jun 2, 2019)
8c4d684  Test case for custom action dist (mawright, Jun 2, 2019)
508ed4a  ActionDistribution.get_parameter_shape_for_action_space pattern (mawright, Jul 8, 2019)
7d0ae68  Edit exception message to also suggest using a custom action distribu… (mawright, Jul 8, 2019)
289141c  Clean up ModelCatalog.get_action_dist (mawright, Jul 9, 2019)
33c3907  Pass model config to ActionDistribution constructors (mawright, Jul 9, 2019)
ce720a3  Update custom action distribution test case (mawright, Jul 10, 2019)
d0b8a64  Name fix (mawright, Jul 10, 2019)
c3ec408  Autoformatter (mawright, Jul 10, 2019)
f78f447  parameter shape static methods for torch distributions (mawright, Jul 10, 2019)
489c573  Fix docstring (mawright, Jul 10, 2019)
ff5076e  Generalize fake array for graph initialization (mawright, Jul 30, 2019)
bfdfa90  Merge branch 'master' into action_dist (mawright, Aug 1, 2019)
11243cd  Fix action dist constructors (mawright, Aug 1, 2019)
a9939d4  Correct parameter shape static methods for multicategorical and gaussian (mawright, Aug 1, 2019)
96bba6c  Make suggested changes to custom action dist's (mawright, Aug 1, 2019)
8158f24  Correct instances of not passing model config to action dist (mawright, Aug 2, 2019)
f11fbca  Autoformatter (mawright, Aug 2, 2019)
bd378b0  fix tuple distribution constructor (mawright, Aug 2, 2019)
2f39c88  bugfix (mawright, Aug 3, 2019)
a1321a8  Merge remote-tracking branch 'upstream/master' into action_dist (ericl, Aug 5, 2019)
1b2eb98  Merge remote-tracking branch 'upstream/master' into action_dist (ericl, Aug 6, 2019)
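
The thread running through these commits: action distributions become pluggable via the ModelCatalog, and every ActionDistribution constructor now receives the model config dict. As a rough sketch of the intended workflow (illustrative only; the register_custom_action_dist() hook and the "custom_action_dist" model config key are assumed names for the registration API this PR introduces and do not appear in the hunks below):

import tensorflow as tf

from ray.rllib.models import ModelCatalog
from ray.rllib.models.action_dist import Categorical


class DeterministicCategorical(Categorical):
    """Toy custom distribution that always picks the argmax action."""

    def sample(self):
        # self.inputs holds the logits; self.model_config is the model config
        # dict that every constructor in this diff now receives.
        return tf.argmax(self.inputs, axis=1)


ModelCatalog.register_custom_action_dist("det_categorical",
                                         DeterministicCategorical)

# The distribution is then selected through the trainer config, e.g.
# config = {"model": {"custom_action_dist": "det_categorical"}}

The per-file diffs below are mostly the plumbing needed to make that possible: each call site that constructs a distribution now passes the model config through.
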
2 changes: 1 addition & 1 deletion python/ray/rllib/agents/a3c/a3c_torch_policy.py
@@ -18,7 +18,7 @@ def actor_critic_loss(policy, batch_tensors):
SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
}) # TODO(ekl) seq lens shouldn't be None
values = policy.model.value_function()
dist = policy.dist_class(logits)
dist = policy.dist_class(logits, policy.config["model"])
log_probs = dist.logp(batch_tensors[SampleBatch.ACTIONS])
policy.entropy = dist.entropy().mean()
policy.pi_err = -batch_tensors[Postprocessing.ADVANTAGES].dot(
2 changes: 1 addition & 1 deletion python/ray/rllib/agents/ars/policies.py
@@ -81,7 +81,7 @@ def __init__(self,
model = ModelCatalog.get_model({
"obs": self.inputs
}, obs_space, action_space, dist_dim, model_config)
dist = dist_class(model.outputs)
dist = dist_class(model.outputs, model_config=model_config)
self.sampler = dist.sample()

self.variables = ray.experimental.tf_utils.TensorFlowVariables(
8 changes: 5 additions & 3 deletions python/ray/rllib/agents/dqn/dqn_policy.py
@@ -107,9 +107,10 @@ def __init__(self,

class QValuePolicy(object):
def __init__(self, q_values, observations, num_actions, stochastic, eps,
softmax, softmax_temp):
softmax, softmax_temp, model_config):
if softmax:
action_dist = Categorical(q_values / softmax_temp)
action_dist = Categorical(
q_values / softmax_temp, model_config=model_config)
self.action = action_dist.sample()
self.action_prob = action_dist.sampled_action_prob()
return
@@ -255,7 +256,8 @@ def build_q_networks(policy, q_model, input_dict, obs_space, action_space,
# Action outputs
qvp = QValuePolicy(q_values, input_dict[SampleBatch.CUR_OBS],
action_space.n, policy.stochastic, policy.eps,
config["soft_q"], config["softmax_temp"])
config["soft_q"], config["softmax_temp"],
config["model"])
policy.output_actions, policy.action_prob = qvp.action, qvp.action_prob

return policy.output_actions, policy.action_prob
2 changes: 1 addition & 1 deletion python/ray/rllib/agents/es/policies.py
@@ -59,7 +59,7 @@ def __init__(self, sess, action_space, obs_space, preprocessor,
model = ModelCatalog.get_model({
"obs": self.inputs
}, obs_space, action_space, dist_dim, model_options)
dist = dist_class(model.outputs)
dist = dist_class(model.outputs, model_config=model_options)
self.sampler = dist.sample()

self.variables = ray.experimental.tf_utils.TensorFlowVariables(
19 changes: 13 additions & 6 deletions python/ray/rllib/agents/impala/vtrace.py
@@ -49,13 +49,14 @@

def log_probs_from_logits_and_actions(policy_logits,
actions,
config,
dist_class=Categorical):
return multi_log_probs_from_logits_and_actions([policy_logits], [actions],
dist_class)[0]
dist_class, config)[0]


def multi_log_probs_from_logits_and_actions(policy_logits, actions,
dist_class):
def multi_log_probs_from_logits_and_actions(policy_logits, actions, dist_class,
config):
"""Computes action log-probs from policy logits and actions.

In the notation used throughout documentation and comments, T refers to the
@@ -76,6 +77,8 @@ def multi_log_probs_from_logits_and_actions(policy_logits, actions,
...,
[T, B, ...]
with actions.
dist_class: Python class of the action distribution
config: Trainer config dict

Returns:
A list with length of ACTION_SPACE of float32
@@ -97,7 +100,8 @@ def multi_log_probs_from_logits_and_actions(policy_logits, actions,
tf.concat([[-1], a_shape[2:]], axis=0))
log_probs.append(
tf.reshape(
dist_class(policy_logits_flat).logp(actions_flat),
dist_class(policy_logits_flat,
model_config=config["model"]).logp(actions_flat),
a_shape[:2]))

return log_probs
@@ -110,6 +114,7 @@ def from_logits(behaviour_policy_logits,
rewards,
values,
bootstrap_value,
config,
dist_class=Categorical,
clip_rho_threshold=1.0,
clip_pg_rho_threshold=1.0,
@@ -122,6 +127,7 @@ def from_logits(behaviour_policy_logits,
rewards,
values,
bootstrap_value,
config,
dist_class,
clip_rho_threshold=clip_rho_threshold,
clip_pg_rho_threshold=clip_pg_rho_threshold,
@@ -145,6 +151,7 @@ def multi_from_logits(behaviour_policy_logits,
rewards,
values,
bootstrap_value,
config,
dist_class,
clip_rho_threshold=1.0,
clip_pg_rho_threshold=1.0,
@@ -235,9 +242,9 @@ def multi_from_logits(behaviour_policy_logits,
discounts, rewards, values, bootstrap_value
]):
target_action_log_probs = multi_log_probs_from_logits_and_actions(
target_policy_logits, actions, dist_class)
target_policy_logits, actions, dist_class, config)
behaviour_action_log_probs = multi_log_probs_from_logits_and_actions(
behaviour_policy_logits, actions, dist_class)
behaviour_policy_logits, actions, dist_class, config)

log_rhos = get_log_rhos(target_action_log_probs,
behaviour_action_log_probs)
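
The vtrace changes above only thread the trainer config down to the point where the action distribution is constructed (config["model"] is forwarded to the constructor). A minimal call sketch, mirroring the dummy config used in vtrace_test.py below (TF1 graph mode; the result is a tensor to evaluate in a session):

import numpy as np

from ray.rllib.agents.impala import vtrace

seq_len, batch_size, num_actions = 5, 2, 3
policy_logits = np.random.randn(
    seq_len, batch_size, num_actions).astype(np.float32)
actions = np.random.randint(
    0, num_actions, size=(seq_len, batch_size), dtype=np.int32)

# The trainer config is the new positional argument; vtrace only reads its
# "model" entry and forwards it to the action distribution constructor.
log_probs_op = vtrace.log_probs_from_logits_and_actions(
    policy_logits, actions, {"model": None})
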
6 changes: 5 additions & 1 deletion python/ray/rllib/agents/impala/vtrace_policy.py
@@ -41,6 +41,7 @@ def __init__(self,
bootstrap_value,
dist_class,
valid_mask,
config,
vf_loss_coeff=0.5,
entropy_coeff=0.01,
clip_rho_threshold=1.0,
@@ -72,6 +73,7 @@ def __init__(self,
bootstrap_value: A float32 tensor of shape [B].
dist_class: action distribution class for logits.
valid_mask: A bool tensor of valid RNN input elements (#2992).
config: Trainer config dict.
"""

# Compute vtrace on the CPU for better perf.
@@ -87,7 +89,8 @@ def __init__(self,
dist_class=dist_class,
clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32),
clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold,
tf.float32))
tf.float32),
config=config)
self.value_targets = self.vtrace_returns.vs

# The policy gradients loss
@@ -196,6 +199,7 @@ def make_time_major(*args, **kw):
bootstrap_value=make_time_major(values)[-1],
dist_class=Categorical if is_multidiscrete else policy.dist_class,
valid_mask=make_time_major(mask, drop_last=True),
config=policy.config,
vf_loss_coeff=policy.config["vf_loss_coeff"],
entropy_coeff=policy.entropy_coeff,
clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
11 changes: 8 additions & 3 deletions python/ray/rllib/agents/impala/vtrace_test.py
@@ -98,7 +98,7 @@ def test_log_probs_from_logits_and_actions(self, batch_size):
0, num_actions - 1, size=(seq_len, batch_size), dtype=np.int32)

action_log_probs_tensor = vtrace.log_probs_from_logits_and_actions(
policy_logits, actions)
policy_logits, actions, {"model": None}) # dummy config dict

# Ground Truth
# Using broadcasting to create a mask that indexes action logits
@@ -159,6 +159,8 @@ def test_vtrace_from_logits(self, batch_size):
clip_rho_threshold = None # No clipping.
clip_pg_rho_threshold = None # No clipping.

dummy_config = {"model": None}

# Intentionally leaving shapes unspecified to test if V-trace can
# deal with that.
placeholders = {
@@ -178,12 +180,15 @@ def test_vtrace_from_logits(self, batch_size):
from_logits_output = vtrace.from_logits(
clip_rho_threshold=clip_rho_threshold,
clip_pg_rho_threshold=clip_pg_rho_threshold,
config=dummy_config,
**placeholders)

target_log_probs = vtrace.log_probs_from_logits_and_actions(
placeholders["target_policy_logits"], placeholders["actions"])
placeholders["target_policy_logits"], placeholders["actions"],
dummy_config)
behaviour_log_probs = vtrace.log_probs_from_logits_and_actions(
placeholders["behaviour_policy_logits"], placeholders["actions"])
placeholders["behaviour_policy_logits"], placeholders["actions"],
dummy_config)
log_rhos = target_log_probs - behaviour_log_probs
ground_truth = (log_rhos, behaviour_log_probs, target_log_probs)

11 changes: 6 additions & 5 deletions python/ray/rllib/agents/marwil/marwil_policy.py
@@ -29,7 +29,7 @@ def __init__(self, state_values, cumulative_rewards):

class ReweightedImitationLoss(object):
def __init__(self, state_values, cumulative_rewards, logits, actions,
action_space, beta):
action_space, beta, model_config):
ma_adv_norm = tf.get_variable(
name="moving_average_of_advantage_norm",
dtype=tf.float32,
@@ -48,8 +48,8 @@ def __init__(self, state_values, cumulative_rewards, logits, actions,
beta * tf.divide(adv, 1e-8 + tf.sqrt(ma_adv_norm)))

# log\pi_\theta(a|s)
dist_cls, _ = ModelCatalog.get_action_dist(action_space, {})
action_dist = dist_cls(logits)
dist_cls, _ = ModelCatalog.get_action_dist(action_space, model_config)
action_dist = dist_cls(logits, model_config=model_config)
logprobs = action_dist.logp(actions)

self.loss = -1.0 * tf.reduce_mean(
@@ -106,7 +106,7 @@ def __init__(self, observation_space, action_space, config):
self.p_func_vars = scope_vars(scope.name)

# Action outputs
action_dist = dist_cls(logits)
action_dist = dist_cls(logits, model_config=self.config["model"])
self.output_actions = action_dist.sample()

# Training inputs
@@ -164,7 +164,8 @@ def _build_value_loss(self, state_values, cum_rwds):
def _build_policy_loss(self, state_values, cum_rwds, logits, actions,
action_space):
return ReweightedImitationLoss(state_values, cum_rwds, logits, actions,
action_space, self.config["beta"])
action_space, self.config["beta"],
self.config["model"])

@override(TFPolicy)
def extra_compute_grad_fetches(self):
3 changes: 2 additions & 1 deletion python/ray/rllib/agents/pg/torch_pg_policy.py
@@ -13,7 +13,8 @@ def pg_torch_loss(policy, batch_tensors):
logits, _ = policy.model({
SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
})
action_dist = policy.dist_class(logits)
action_dist = policy.dist_class(
logits, model_config=policy.config["model"])
log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS])
# save the error in the policy object
policy.pi_err = -batch_tensors[Postprocessing.ADVANTAGES].dot(
10 changes: 8 additions & 2 deletions python/ray/rllib/agents/ppo/appo_policy.py
@@ -112,6 +112,7 @@ def __init__(self,
rewards,
values,
bootstrap_value,
config,
dist_class,
valid_mask,
vf_loss_coeff=0.5,
@@ -143,6 +144,7 @@ def __init__(self,
rewards: A float32 tensor of shape [T, B].
values: A float32 tensor of shape [T, B].
bootstrap_value: A float32 tensor of shape [B].
config: Trainer config dict.
dist_class: action distribution class for logits.
valid_mask: A bool tensor of valid RNN input elements (#2992).
vf_loss_coeff (float): Coefficient of the value function loss.
@@ -165,6 +167,7 @@ def reduce_mean_valid(t):
rewards=rewards,
values=values,
bootstrap_value=bootstrap_value,
config=config,
dist_class=dist_class,
clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32),
clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold,
@@ -251,8 +254,10 @@ def make_time_major(*args, **kw):
old_policy_behaviour_logits, output_hidden_shape, axis=1)
unpacked_outputs = tf.split(policy.model_out, output_hidden_shape, axis=1)
action_dist = policy.action_dist
old_policy_action_dist = policy.dist_class(old_policy_behaviour_logits)
prev_action_dist = policy.dist_class(behaviour_logits)
old_policy_action_dist = policy.dist_class(
old_policy_behaviour_logits, model_config=policy.config["model"])
prev_action_dist = policy.dist_class(
behaviour_logits, model_config=policy.config["model"])
values = policy.value_function

policy.model_vars = policy.model.variables()
@@ -298,6 +303,7 @@ def make_time_major(*args, **kw):
rewards=make_time_major(rewards, drop_last=True),
values=make_time_major(values, drop_last=True),
bootstrap_value=make_time_major(values)[-1],
config=policy.config,
dist_class=Categorical if is_multidiscrete else policy.dist_class,
valid_mask=make_time_major(mask, drop_last=True),
vf_loss_coeff=policy.config["vf_loss_coeff"],
12 changes: 8 additions & 4 deletions python/ray/rllib/agents/ppo/ppo_policy.py
@@ -39,7 +39,7 @@ def __init__(self,
clip_param=0.1,
vf_clip_param=0.1,
vf_loss_coeff=1.0,
use_gae=True):
use_gae=True,
model_config=None):
"""Constructs the loss for Proximal Policy Objective.

Arguments:
@@ -65,13 +66,15 @@ def __init__(self,
vf_clip_param (float): Clip parameter for the value function
vf_loss_coeff (float): Coefficient of the value function loss
use_gae (bool): If true, use the Generalized Advantage Estimator.
model_config (dict): (Optional) model config for use in specifying
action distributions.
"""

def reduce_mean_valid(t):
return tf.reduce_mean(tf.boolean_mask(t, valid_mask))

dist_cls, _ = ModelCatalog.get_action_dist(action_space, {})
prev_dist = dist_cls(logits)
dist_cls, _ = ModelCatalog.get_action_dist(action_space, model_config)
prev_dist = dist_cls(logits, model_config=model_config)
# Make loss functions.
logp_ratio = tf.exp(
curr_action_dist.logp(actions) - prev_dist.logp(actions))
@@ -129,7 +132,8 @@ def ppo_surrogate_loss(policy, batch_tensors):
clip_param=policy.config["clip_param"],
vf_clip_param=policy.config["vf_clip_param"],
vf_loss_coeff=policy.config["vf_loss_coeff"],
use_gae=policy.config["use_gae"])
use_gae=policy.config["use_gae"],
model_config=policy.config["model"])

return policy.loss_obj.loss

2 changes: 1 addition & 1 deletion python/ray/rllib/agents/ppo/test/test.py
@@ -20,7 +20,7 @@ def testCategorical(self):
logits = tf.placeholder(tf.float32, shape=(None, 10))
z = 8 * (np.random.rand(10) - 0.5)
data = np.tile(z, (num_samples, 1))
c = Categorical(logits)
c = Categorical(logits, {}) # dummy config dict
sample_op = c.sample()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
2 changes: 1 addition & 1 deletion python/ray/rllib/examples/custom_loss.py
@@ -67,7 +67,7 @@ def custom_loss(self, policy_loss, loss_inputs):
print("FYI: You can also use these tensors: {}, ".format(loss_inputs))

# compute the IL loss
action_dist = Categorical(logits)
action_dist = Categorical(logits, self.options)
self.policy_loss = policy_loss
self.imitation_loss = tf.reduce_mean(
-action_dist.logp(input_ops["actions"]))
2 changes: 1 addition & 1 deletion python/ray/rllib/examples/custom_torch_policy.py
@@ -18,7 +18,7 @@ def policy_gradient_loss(policy, batch_tensors):
logits, _ = policy.model({
SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS]
})
action_dist = policy.dist_class(logits)
action_dist = policy.dist_class(logits, policy.config["model"])
log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS])
return -batch_tensors[SampleBatch.REWARDS].dot(log_probs)

24 changes: 23 additions & 1 deletion python/ray/rllib/models/action_dist.py
@@ -11,11 +11,14 @@ class ActionDistribution(object):

Args:
inputs (Tensor): The input vector to compute samples from.
model_config (dict): Optional model config dict
(as defined in catalog.py)
"""

@DeveloperAPI
def __init__(self, inputs):
def __init__(self, inputs, model_config):
self.inputs = inputs
self.model_config = model_config

@DeveloperAPI
def sample(self):
@@ -52,3 +55,22 @@ def multi_entropy(self):
MultiDiscrete. TODO(ekl) consider removing this.
"""
return self.entropy()

@DeveloperAPI
@staticmethod

[Inline review thread on the @staticmethod decorator above]

Contributor: Suggested change: use @classmethod instead of @staticmethod.

Contributor Author: I don't really understand this suggestion. Why do you think this should be a class method?

Contributor: Hm I guess staticmethod is fine, since you don't really need the class.

def required_model_output_shape(action_space, model_config):
"""Returns the required shape of an input parameter tensor for a
particular action space and an optional dict of distribution-specific
options.

Args:
action_space (gym.Space): The action space this distribution will
be used for, whose shape attributes will be used to determine
the required shape of the input parameter tensor.
model_config (dict): Model's config dict (as defined in catalog.py)

Returns:
model_output_shape (int or np.ndarray of ints): size of the
required input vector (minus leading batch dimension).
"""
raise NotImplementedError
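
For reference, a hedged sketch of what a concrete subclass looks like under the new interface; this DiagGaussianExample class is illustrative and not part of the diff:

import numpy as np
import tensorflow as tf

from ray.rllib.models.action_dist import ActionDistribution


class DiagGaussianExample(ActionDistribution):
    """Toy diagonal Gaussian parameterized by concatenated [mean, log_std]."""

    def __init__(self, inputs, model_config):
        super(DiagGaussianExample, self).__init__(inputs, model_config)
        self.mean, self.log_std = tf.split(inputs, 2, axis=1)

    def sample(self):
        return self.mean + tf.exp(self.log_std) * tf.random_normal(
            tf.shape(self.mean))

    @staticmethod
    def required_model_output_shape(action_space, model_config):
        # The model must emit two parameters (mean and log_std) per action
        # dimension, regardless of what model_config says.
        return np.prod(action_space.shape) * 2

Keeping required_model_output_shape static matches the review thread above: the catalog presumably needs this size before any distribution instance exists, so nothing class- or instance-bound is required.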