Custom action distributions #5164
Changes to the action distribution module (ray.rllib.models.action_dist):
@@ -28,11 +28,14 @@ class ActionDistribution(object):
 
     Args:
         inputs (Tensor): The input vector to compute samples from.
+        model_config (dict): Optional model config dict
+            (as defined in catalog.py)
     """
 
     @DeveloperAPI
-    def __init__(self, inputs):
+    def __init__(self, inputs, model_config=None):
         self.inputs = inputs
+        self.model_config = model_config
         self.sample_op = self._build_sample_op()
 
     @DeveloperAPI
@@ -69,6 +72,25 @@ def sampled_action_prob(self):
         """Returns the log probability of the sampled action."""
         return tf.exp(self.logp(self.sample_op))
 
+    @DeveloperAPI
+    @staticmethod
+    def parameter_shape_for_action_space(action_space, model_config=None):
+        """Returns the required shape of an input parameter tensor for a
+        particular action space and an optional dict of distribution-specific
+        options.
+
+        Args:
+            action_space (gym.Space): The action space this distribution will
+                be used for, whose shape attributes will be used to determine
+                the required shape of the input parameter tensor.
+            model_config (dict): Model's config dict (as defined in catalog.py)
+
+        Returns:
+            dist_dim (int or np.ndarray of ints): size of the required
+                input vector (minus leading batch dimension).
+        """
+        raise NotImplementedError
+
 
 class Categorical(ActionDistribution):
     """Categorical distribution for discrete action spaces."""

Review thread on the new `@staticmethod`:
- A suggested change proposed making this a classmethod.
- Author: "I don't really understand this suggestion. Why do you think this should be a class method?"
- Reviewer: "Hm I guess staticmethod is fine, since you don't really need the class."

Review comment on `def parameter_shape_for_action_space(...)`: "How about …"
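To make the new hook concrete, here is a minimal sketch (not part of this diff) of a user-defined distribution that reports its own parameter shape. The class name `MyTanhGaussian` and its tanh-squashed Gaussian parametrization are illustrative assumptions; only the `ActionDistribution` base class and the hook's signature come from the PR.

```python
# Illustrative sketch only -- MyTanhGaussian is not part of this PR.
import numpy as np
import tensorflow as tf

from ray.rllib.models.action_dist import ActionDistribution


class MyTanhGaussian(ActionDistribution):
    """Toy distribution: a diagonal Gaussian squashed through tanh."""

    def __init__(self, inputs, model_config=None):
        self.mean, self.log_std = tf.split(inputs, 2, axis=1)
        ActionDistribution.__init__(self, inputs, model_config)

    @staticmethod
    def parameter_shape_for_action_space(action_space, model_config=None):
        # One mean and one log-std per action element of a Box space.
        return int(np.product(action_space.shape)) * 2

    def _build_sample_op(self):
        eps = tf.random_normal(tf.shape(self.mean))
        return tf.tanh(self.mean + tf.exp(self.log_std) * eps)
```

With this in place, the catalog can size the model's output layer by asking the distribution class instead of hard-coding per-space arithmetic.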
@@ -122,16 +144,22 @@ def kl(self, other):
     def _build_sample_op(self):
         return tf.squeeze(tf.multinomial(self.inputs, 1), axis=1)
 
+    @staticmethod
+    @override(ActionDistribution)
+    def parameter_shape_for_action_space(action_space, model_config=None):
+        return action_space.n
+
 
 class MultiCategorical(ActionDistribution):
     """Categorical distribution for discrete action spaces."""
 
-    def __init__(self, inputs, input_lens):
+    def __init__(self, inputs, input_lens, model_config=None):
         self.cats = [
             Categorical(input_)
             for input_ in tf.split(inputs, input_lens, axis=1)
         ]
         self.sample_op = self._build_sample_op()
+        self.model_config = model_config
 
     def logp(self, actions):
         # If tensor is provided, unstack it into list
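For the built-in distributions, the static hook simply reproduces the sizes the catalog used to hard-code; a quick illustration (not code from the PR):

```python
import gym

from ray.rllib.models.action_dist import Categorical, DiagGaussian

# Discrete(n) needs n logits, exactly what the catalog returned as action_space.n.
assert Categorical.parameter_shape_for_action_space(gym.spaces.Discrete(4)) == 4
# A 1-D Box needs a mean and a log-std per dimension, i.e. shape[0] * 2.
assert DiagGaussian.parameter_shape_for_action_space(
    gym.spaces.Box(-1.0, 1.0, shape=(3, ))) == 6
```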
@@ -158,12 +186,12 @@ class DiagGaussian(ActionDistribution):
         second half the gaussian standard deviations.
     """
 
-    def __init__(self, inputs):
+    def __init__(self, inputs, model_config=None):
         mean, log_std = tf.split(inputs, 2, axis=1)
         self.mean = mean
         self.log_std = log_std
         self.std = tf.exp(log_std)
-        ActionDistribution.__init__(self, inputs)
+        ActionDistribution.__init__(self, inputs, model_config)
 
     @override(ActionDistribution)
     def logp(self, x):
@@ -191,6 +219,11 @@ def entropy(self):
     def _build_sample_op(self):
         return self.mean + self.std * tf.random_normal(tf.shape(self.mean))
 
+    @staticmethod
+    @override(ActionDistribution)
+    def parameter_shape_for_action_space(action_space, model_config=None):
+        return action_space.shape[0] * 2
+
 
 class Deterministic(ActionDistribution):
     """Action distribution that returns the input values directly.

Review comment on `return action_space.shape[0] * 2`: "np.product(action_space.shape)? here and elsewhere?"
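A sketch of what the reviewer's `np.product` suggestion would look like; this is the reviewer's proposal, not what the commits above implement, and the helper name is hypothetical:

```python
import numpy as np


def diag_gaussian_param_shape(action_space):
    # np.product collapses a Box of any rank, e.g. shape (2, 3) -> 12 outputs,
    # whereas action_space.shape[0] * 2 only accounts for the first dimension.
    return int(np.product(action_space.shape)) * 2
```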
@@ -206,21 +239,35 @@ def sampled_action_prob(self):
     def _build_sample_op(self):
         return self.inputs
 
+    @staticmethod
+    @override(ActionDistribution)
+    def parameter_shape_for_action_space(action_space, model_config=None):
+        return action_space.shape[0]
+
 
 class MultiActionDistribution(ActionDistribution):
     """Action distribution that operates for list of actions.
 
     Args:
         inputs (Tensor list): A list of tensors from which to compute samples.
+        model_config (dict): Config dict for the model (as defined in
+            catalog.py)
     """
 
-    def __init__(self, inputs, action_space, child_distributions, input_lens):
+    def __init__(self,
+                 inputs,
+                 action_space,
+                 child_distributions,
+                 input_lens,
+                 model_config=None):
         self.input_lens = input_lens
         split_inputs = tf.split(inputs, self.input_lens, axis=1)
         child_list = []
         for i, distribution in enumerate(child_distributions):
-            child_list.append(distribution(split_inputs[i]))
+            child_list.append(
+                distribution(split_inputs[i], model_config=model_config))
         self.child_distributions = child_list
+        self.model_config = model_config
 
     @override(ActionDistribution)
     def logp(self, x):
@@ -278,7 +325,7 @@ class Dirichlet(ActionDistribution):
 
     e.g. actions that represent resource allocation."""
 
-    def __init__(self, inputs):
+    def __init__(self, inputs, model_config=None):
         """Input is a tensor of logits. The exponential of logits is used to
         parametrize the Dirichlet distribution as all parameters need to be
         positive. An arbitrary small epsilon is added to the concentration
@@ -293,7 +340,7 @@ def __init__(self, inputs):
             validate_args=True,
             allow_nan_stats=False,
         )
-        ActionDistribution.__init__(self, concentration)
+        ActionDistribution.__init__(self, concentration, model_config)
 
     @override(ActionDistribution)
     def logp(self, x):
@@ -315,3 +362,8 @@ def kl(self, other):
     @override(ActionDistribution)
     def _build_sample_op(self):
         return self.dist.sample()
+
+    @staticmethod
+    @override(ActionDistribution)
+    def parameter_shape_for_action_space(action_space, model_config=None):
+        return action_space.shape[0]
Changes to the model catalog module (catalog.py):
@@ -8,7 +8,7 @@
 from functools import partial
 
 from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
-    _global_registry
+    RLLIB_ACTION_DIST, _global_registry
 
 from ray.rllib.models.extra_spaces import Simplex
 from ray.rllib.models.action_dist import (Categorical, MultiCategorical,
@@ -80,6 +80,8 @@
     "custom_preprocessor": None,
     # Name of a custom model to use
     "custom_model": None,
+    # Name of a custom action distribution to use
+    "custom_action_dist": None,
     # Extra options to pass to the custom classes
     "custom_options": {},
 }
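With the new key in `MODEL_DEFAULTS`, a registered distribution is selected through the model config. The name `"my_dist"` below is a placeholder and must first be registered via the `register_custom_action_dist` method added later in this diff:

```python
# Hypothetical config snippet; "my_dist" must be registered beforehand.
config = {
    "model": {
        "custom_action_dist": "my_dist",
        # All other keys fall back to MODEL_DEFAULTS.
    },
}
```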
@@ -119,46 +121,57 @@ def get_action_dist(action_space, config, dist_type=None, torch=False):
         """
 
 
         config = config or MODEL_DEFAULTS
-        if isinstance(action_space, gym.spaces.Box):
+        if config.get("custom_action_dist"):
+            action_dist_name = config["custom_action_dist"]
+            logger.debug(
+                "Using custom action distribution {}".format(action_dist_name))
+            dist = _global_registry.get(RLLIB_ACTION_DIST, action_dist_name)
+
+        elif isinstance(action_space, gym.spaces.Box):
             if len(action_space.shape) > 1:
                 raise UnsupportedSpaceException(
                     "Action space has multiple dimensions "
                     "{}. ".format(action_space.shape) +
                     "Consider reshaping this into a single dimension, "
+                    "using a custom action distribution, "
                     "using a Tuple action space, or the multi-agent API.")
             if dist_type is None:
                 dist = TorchDiagGaussian if torch else DiagGaussian
-                return dist, action_space.shape[0] * 2
             elif dist_type == "deterministic":
-                return Deterministic, action_space.shape[0]
+                dist = Deterministic
         elif isinstance(action_space, gym.spaces.Discrete):
             dist = TorchCategorical if torch else Categorical
-            return dist, action_space.n
         elif isinstance(action_space, gym.spaces.Tuple):
+            if torch:
+                raise NotImplementedError("Tuple action spaces not supported "
+                                          "for Pytorch.")
             child_dist = []
             input_lens = []
             for action in action_space.spaces:
                 dist, action_size = ModelCatalog.get_action_dist(
                     action, config)
                 child_dist.append(dist)
                 input_lens.append(action_size)
-            if torch:
-                raise NotImplementedError
             return partial(
                 MultiActionDistribution,
                 child_distributions=child_dist,
                 action_space=action_space,
                 input_lens=input_lens), sum(input_lens)
         elif isinstance(action_space, Simplex):
             if torch:
-                raise NotImplementedError
-            return Dirichlet, action_space.shape[0]
-        elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
+                raise NotImplementedError("Simplex action spaces not "
+                                          "supported for Pytorch.")
+            dist = Dirichlet
+        elif isinstance(action_space, gym.spaces.MultiDiscrete):
             if torch:
-                raise NotImplementedError
+                raise NotImplementedError("MultiDiscrete action spaces not "
+                                          "supported for Pytorch.")
             return partial(MultiCategorical, input_lens=action_space.nvec), \
                 int(sum(action_space.nvec))
 
+        return dist, dist.parameter_shape_for_action_space(
+            action_space, config)
+
         raise NotImplementedError("Unsupported args: {} {}".format(
             action_space, dist_type))

Review comment on `return dist, dist.parameter_shape_for_action_space(`: "Seems like now you could simplify this to return just …"
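As an illustration of the refactored dispatch (not code from the PR): built-in spaces still resolve to the same (dist class, input size) pairs, but the size now comes from the distribution class itself.

```python
# Illustrative check; assumes an RLlib build containing this PR.
import gym

from ray.rllib.models.catalog import ModelCatalog

dist_cls, dist_dim = ModelCatalog.get_action_dist(
    gym.spaces.Box(-1.0, 1.0, shape=(3, )), config=None)
# dist_cls is DiagGaussian; dist_dim == 6, obtained via
# DiagGaussian.parameter_shape_for_action_space(...) rather than shape[0] * 2.
```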
@@ -173,11 +186,16 @@ def get_action_placeholder(action_space):
             action_placeholder (Tensor): A placeholder for the actions
         """
 
-        if isinstance(action_space, gym.spaces.Box):
-            return tf.placeholder(
-                tf.float32, shape=(None, action_space.shape[0]), name="action")
-        elif isinstance(action_space, gym.spaces.Discrete):
+        if isinstance(action_space, gym.spaces.Discrete):
             return tf.placeholder(tf.int64, shape=(None, ), name="action")
+        elif isinstance(action_space, (gym.spaces.Box, Simplex)):
+            return tf.placeholder(
+                tf.float32, shape=(None, ) + action_space.shape, name="action")
+        elif isinstance(action_space, gym.spaces.MultiDiscrete):
+            return tf.placeholder(
+                tf.as_dtype(action_space.dtype),
+                shape=(None, ) + action_space.shape,
+                name="action")
         elif isinstance(action_space, gym.spaces.Tuple):
             size = 0
             all_discrete = True
@@ -191,14 +209,6 @@ def get_action_placeholder(action_space):
                 tf.int64 if all_discrete else tf.float32,
                 shape=(None, size),
                 name="action")
-        elif isinstance(action_space, Simplex):
-            return tf.placeholder(
-                tf.float32, shape=(None, action_space.shape[0]), name="action")
-        elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
-            return tf.placeholder(
-                tf.as_dtype(action_space.dtype),
-                shape=(None, len(action_space.nvec)),
-                name="action")
         else:
             raise NotImplementedError("action space {}"
                                       " not supported".format(action_space))
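A small illustration (not from the PR) of the placeholder shape the consolidated branches produce; `(None, ) + action_space.shape` matches the `(None, len(action_space.nvec))` tuple the removed MultiDiscrete branch built by hand:

```python
import gym

space = gym.spaces.MultiDiscrete([3, 5, 2])
# space.shape is (3,), so the placeholder shape is (None, 3) -- identical to
# the (None, len(nvec)) tuple from the removed branch.
print((None, ) + space.shape)
```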
@@ -476,3 +486,18 @@ def register_custom_model(model_name, model_class):
             model_class (type): Python class of the model.
         """
         _global_registry.register(RLLIB_MODEL, model_name, model_class)
+
+    @staticmethod
+    @PublicAPI
+    def register_custom_action_dist(action_dist_name, action_dist_class):
+        """Register a custom action distribution class by name.
+
+        The model can be later used by specifying
+        {"custom_action_dist": action_dist_name} in the model config.
+
+        Args:
+            model_name (str): Name to register the model under.
+            model_class (type): Python class of the model.
+        """
+        _global_registry.register(RLLIB_ACTION_DIST, action_dist_name,
+                                  action_dist_class)
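A minimal end-to-end usage sketch, assuming a Ray/RLlib build that includes this PR; the `MyTanhGaussian` class from the earlier sketch, the environment, and the choice of PPO are arbitrary placeholders:

```python
import ray
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.models import ModelCatalog

# Register the illustrative distribution (defined in the earlier sketch)
# under a name of our choosing.
ModelCatalog.register_custom_action_dist("my_tanh_gaussian", MyTanhGaussian)

ray.init()
trainer = PPOTrainer(
    env="Pendulum-v0",
    config={
        "model": {
            "custom_action_dist": "my_tanh_gaussian",
        },
    })
print(trainer.train())
```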
Review thread on the `model_config=None` default:
- "To avoid accidents where we forget to pass this, consider making it a required argument."
- "I lean towards making it required but this may break other users' code."
- "Since this is DeveloperAPI I would say it's ok to err on the side of avoiding bugs vs backwards compatibility."