python/ray/rllib/a3c/a3c_torch_policy.py (2 additions, 1 deletion)

@@ -27,7 +27,8 @@ def __init__(self, obs_space, action_space, registry, config):
         self.lock = Lock()

     def setup_graph(self, obs_space, action_space):
-        _, self.logit_dim = ModelCatalog.get_action_dist(action_space)
+        _, self.logit_dim = ModelCatalog.get_action_dist(
+            action_space, self.config["model"])
         self._model = ModelCatalog.get_torch_model(
             self.registry, obs_space.shape, self.logit_dim,
             self.config["model"])
python/ray/rllib/a3c/shared_model.py (2 additions, 1 deletion)

@@ -16,7 +16,8 @@ def __init__(self, ob_space, ac_space, registry, config, **kwargs):

     def _setup_graph(self, ob_space, ac_space):
         self.x = tf.placeholder(tf.float32, [None] + list(ob_space.shape))
-        dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
+        dist_class, self.logit_dim = ModelCatalog.get_action_dist(
+            ac_space, self.config["model"])
         self._model = ModelCatalog.get_model(
             self.registry, self.x, self.logit_dim, self.config["model"])
         self.logits = self._model.outputs
python/ray/rllib/a3c/shared_model_lstm.py (2 additions, 1 deletion)

@@ -17,7 +17,8 @@ def __init__(self, ob_space, ac_space, registry, config, **kwargs):

     def _setup_graph(self, ob_space, ac_space):
         self.x = tf.placeholder(tf.float32, [None] + list(ob_space.shape))
-        dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
+        dist_class, self.logit_dim = ModelCatalog.get_action_dist(
+            ac_space, self.config["model"])
         self._model = LSTM(self.x, self.logit_dim, {})

         self.state_in = self._model.state_in
python/ray/rllib/bc/policy.py (2 additions, 1 deletion)

@@ -22,7 +22,8 @@ def __init__(self, registry, obs_space, action_space, config):

     def _setup_graph(self, obs_space, ac_space):
         self.x = tf.placeholder(tf.float32, [None] + list(obs_space.shape))
-        dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
+        dist_class, self.logit_dim = ModelCatalog.get_action_dist(
+            ac_space, self.config["model"])
         self._model = ModelCatalog.get_model(
             self.registry, self.x, self.logit_dim, self.config["model"])
         self.logits = self._model.outputs
python/ray/rllib/models/action_dist.py (42 additions, 2 deletions)

@@ -73,10 +73,19 @@ class DiagGaussian(ActionDistribution):
     second half the gaussian standard deviations.
     """

-    def __init__(self, inputs):
+    def __init__(self, inputs, low=None, high=None):
         ActionDistribution.__init__(self, inputs)
         mean, log_std = tf.split(inputs, 2, axis=1)
         self.mean = mean
+        self.low = low
+        self.high = high
+
+        # Squash to range if specified.
+        # TODO(ekl) might make sense to use a beta distribution instead:
+        # http://proceedings.mlr.press/v70/chou17a/chou17a.pdf
+        if low is not None:
+            self.mean = low + tf.sigmoid(self.mean) * (high - low)
+
         self.log_std = log_std
         self.std = tf.exp(log_std)

@@ -99,7 +108,10 @@ def entropy(self):
             reduction_indices=[1])

     def sample(self):
-        return self.mean + self.std * tf.random_normal(tf.shape(self.mean))
+        out = self.mean + self.std * tf.random_normal(tf.shape(self.mean))
+        if self.low is not None:
+            out = tf.clip_by_value(out, self.low, self.high)
+        return out


 class Deterministic(ActionDistribution):

@@ -112,6 +124,34 @@ def sample(self):
         return self.inputs


+def squash_to_range(dist_cls, low, high):
+    """Squashes an action distribution to a range in (low, high).
+
+    Arguments:
+        dist_cls (class): ActionDistribution class to wrap.
+        low (float|array): Scalar value or array of values.
+        high (float|array): Scalar value or array of values.
+    """
+
+    class SquashToRangeWrapper(dist_cls):
+        def __init__(self, inputs):
+            dist_cls.__init__(self, inputs, low=low, high=high)
+
+        def logp(self, x):
+            return dist_cls.logp(self, x)
+
+        def kl(self, other):
+            return dist_cls.kl(self, other)
+
+        def entropy(self):
+            return dist_cls.entropy(self)
+
+        def sample(self):
+            return dist_cls.sample(self)
+
+    return SquashToRangeWrapper


 class MultiActionDistribution(ActionDistribution):
     """Action distribution that operates for list of actions.
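
To make the wrapper's behavior concrete, here is a minimal usage sketch, assuming a TF1-style session; the bounds, placeholder shape, and feed values below are illustrative and not part of this diff:

import numpy as np
import tensorflow as tf

from ray.rllib.models.action_dist import DiagGaussian, squash_to_range

# Wrap DiagGaussian so the mean is sigmoid-squashed into [-2, 2] and
# samples are clipped to the same bounds.
SquashedGaussian = squash_to_range(DiagGaussian, -2.0, 2.0)

# DiagGaussian expects 2 * action_dim inputs: means, then log stds.
logits = tf.placeholder(tf.float32, shape=[None, 2])
dist = SquashedGaussian(logits)
sample = dist.sample()

with tf.Session() as sess:
    # Even with an extreme mean logit, the sampled action stays in range.
    out = sess.run(sample, feed_dict={logits: np.array([[100.0, 0.0]])})
    assert -2.0 <= out[0][0] <= 2.0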
python/ray/rllib/models/catalog.py (11 additions, 4 deletions)

@@ -11,7 +11,8 @@
     _default_registry

 from ray.rllib.models.action_dist import (
-    Categorical, Deterministic, DiagGaussian, MultiActionDistribution)
+    Categorical, Deterministic, DiagGaussian, MultiActionDistribution,
+    squash_to_range)
 from ray.rllib.models.preprocessors import get_preprocessor
 from ray.rllib.models.fcnet import FullyConnectedNetwork
 from ray.rllib.models.visionnet import VisionNetwork

@@ -29,6 +30,7 @@
     "fcnet_hiddens",  # Number of hidden layers for fully connected net
     "free_log_std",  # Documented in ray.rllib.models.Model
     "channel_major",  # Pytorch conv requires images to be channel-major
+    "squash_to_range",  # Whether to squash the action output to space range

     # === Options for custom models ===
     "custom_preprocessor",  # Name of a custom preprocessor to use

@@ -51,11 +53,12 @@ class ModelCatalog(object):
     """

     @staticmethod
-    def get_action_dist(action_space, dist_type=None):
+    def get_action_dist(action_space, config=None, dist_type=None):
         """Returns action distribution class and size for the given action space.

         Args:
             action_space (Space): Action space of the target gym env.
+            config (dict): Optional model config.
             dist_type (str): Optional identifier of the action distribution.

         Returns:

@@ -66,10 +69,14 @@ def get_action_dist(action_space, dist_type=None):
         # TODO(ekl) are list spaces valid?
         if isinstance(action_space, list):
             action_space = gym.spaces.Tuple(action_space)
+
+        config = config or {}
         if isinstance(action_space, gym.spaces.Box):
             if dist_type is None:
-                return DiagGaussian, action_space.shape[0] * 2
+                dist = DiagGaussian
+                if config.get("squash_to_range"):
+                    dist = squash_to_range(
+                        dist, action_space.low, action_space.high)
+                return dist, action_space.shape[0] * 2
             elif dist_type == 'deterministic':
                 return Deterministic, action_space.shape[0]
         elif isinstance(action_space, gym.spaces.Discrete):
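
For reference, a hedged sketch of the updated catalog call; the Box bounds here are made up for illustration, and the dict below stands in for the model config that the call sites now pass as config["model"]:

import gym
import numpy as np

from ray.rllib.models.catalog import ModelCatalog

action_space = gym.spaces.Box(np.array([-2.0]), np.array([2.0]))

# Default behavior is unchanged: plain DiagGaussian, logit_dim = 2 * dim.
dist_class, logit_dim = ModelCatalog.get_action_dist(action_space)
assert logit_dim == 2

# With the new option set, the returned class is the squash_to_range
# wrapper, parameterized with the space's own low/high bounds.
dist_class, logit_dim = ModelCatalog.get_action_dist(
    action_space, {"squash_to_range": True})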
python/ray/rllib/pg/pg_policy_graph.py (2 additions, 1 deletion)

@@ -16,7 +16,8 @@ def __init__(self, obs_space, action_space, registry, config):

         # setup policy
         self.x = tf.placeholder(tf.float32, shape=[None]+list(obs_space.shape))
-        dist_class, self.logit_dim = ModelCatalog.get_action_dist(action_space)
+        dist_class, self.logit_dim = ModelCatalog.get_action_dist(
+            action_space, self.config["model"])
         self.model = ModelCatalog.get_model(
             registry, self.x, self.logit_dim, options=self.config["model"])
         self.dist = dist_class(self.model.outputs)  # logit for each action
python/ray/rllib/ppo/ppo_evaluator.py (1 addition, 1 deletion)

@@ -69,7 +69,7 @@ def __init__(self, registry, env_creator, config, logdir, is_remote):
         action_space = self.env.action_space
         self.actions = ModelCatalog.get_action_placeholder(action_space)
         self.distribution_class, self.logit_dim = ModelCatalog.get_action_dist(
-            action_space)
+            action_space, config["model"])
         # Log probabilities from the policy before the policy update.
         self.prev_logits = tf.placeholder(
             tf.float32, shape=(None, self.logit_dim))
python/ray/rllib/tuned_examples/pendulum-ppo.yaml (1 addition)

@@ -12,3 +12,4 @@ pendulum-ppo:
         num_sgd_iter: 10
         model:
             fcnet_hiddens: [64, 64]
+            squash_to_range: True
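
With this option set, the Pendulum policy's Gaussian mean is sigmoid-squashed into the environment's [-2, 2] torque range and sampled actions are clipped to the same bounds, rather than being emitted unbounded.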