diff --git a/python/ray/rllib/a3c/a3c_torch_policy.py b/python/ray/rllib/a3c/a3c_torch_policy.py
index 5a654fa5732c..1da89db4c4be 100644
--- a/python/ray/rllib/a3c/a3c_torch_policy.py
+++ b/python/ray/rllib/a3c/a3c_torch_policy.py
@@ -27,7 +27,8 @@ def __init__(self, obs_space, action_space, registry, config):
         self.lock = Lock()
 
     def setup_graph(self, obs_space, action_space):
-        _, self.logit_dim = ModelCatalog.get_action_dist(action_space)
+        _, self.logit_dim = ModelCatalog.get_action_dist(
+            action_space, self.config["model"])
         self._model = ModelCatalog.get_torch_model(
             self.registry, obs_space.shape, self.logit_dim,
             self.config["model"])
diff --git a/python/ray/rllib/a3c/shared_model.py b/python/ray/rllib/a3c/shared_model.py
index 3a093fa906f8..f01d8fbeddc9 100644
--- a/python/ray/rllib/a3c/shared_model.py
+++ b/python/ray/rllib/a3c/shared_model.py
@@ -16,7 +16,8 @@ def __init__(self, ob_space, ac_space, registry, config, **kwargs):
 
     def _setup_graph(self, ob_space, ac_space):
         self.x = tf.placeholder(tf.float32, [None] + list(ob_space.shape))
-        dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
+        dist_class, self.logit_dim = ModelCatalog.get_action_dist(
+            ac_space, self.config["model"])
         self._model = ModelCatalog.get_model(
             self.registry, self.x, self.logit_dim, self.config["model"])
         self.logits = self._model.outputs
diff --git a/python/ray/rllib/a3c/shared_model_lstm.py b/python/ray/rllib/a3c/shared_model_lstm.py
index 7cb64e684aa6..7950a24925bc 100644
--- a/python/ray/rllib/a3c/shared_model_lstm.py
+++ b/python/ray/rllib/a3c/shared_model_lstm.py
@@ -17,7 +17,8 @@ def __init__(self, ob_space, ac_space, registry, config, **kwargs):
 
     def _setup_graph(self, ob_space, ac_space):
         self.x = tf.placeholder(tf.float32, [None] + list(ob_space.shape))
-        dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
+        dist_class, self.logit_dim = ModelCatalog.get_action_dist(
+            ac_space, self.config["model"])
         self._model = LSTM(self.x, self.logit_dim, {})
 
         self.state_in = self._model.state_in
diff --git a/python/ray/rllib/bc/policy.py b/python/ray/rllib/bc/policy.py
index 2c4210a57cf5..bc972bba1cfa 100644
--- a/python/ray/rllib/bc/policy.py
+++ b/python/ray/rllib/bc/policy.py
@@ -22,7 +22,8 @@ def __init__(self, registry, obs_space, action_space, config):
 
     def _setup_graph(self, obs_space, ac_space):
         self.x = tf.placeholder(tf.float32, [None] + list(obs_space.shape))
-        dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
+        dist_class, self.logit_dim = ModelCatalog.get_action_dist(
+            ac_space, self.config["model"])
         self._model = ModelCatalog.get_model(
             self.registry, self.x, self.logit_dim, self.config["model"])
         self.logits = self._model.outputs
diff --git a/python/ray/rllib/models/action_dist.py b/python/ray/rllib/models/action_dist.py
index cf5fecb1968e..c4de85004000 100644
--- a/python/ray/rllib/models/action_dist.py
+++ b/python/ray/rllib/models/action_dist.py
@@ -73,10 +73,19 @@ class DiagGaussian(ActionDistribution):
     second half the gaussian standard deviations.
     """
 
-    def __init__(self, inputs):
+    def __init__(self, inputs, low=None, high=None):
         ActionDistribution.__init__(self, inputs)
         mean, log_std = tf.split(inputs, 2, axis=1)
         self.mean = mean
+        self.low = low
+        self.high = high
+
+        # Squash to range if specified.
+        # TODO(ekl) might make sense to use a beta distribution instead:
+        # http://proceedings.mlr.press/v70/chou17a/chou17a.pdf
+        if low is not None:
+            self.mean = low + tf.sigmoid(self.mean) * (high - low)
+
         self.log_std = log_std
         self.std = tf.exp(log_std)
 
@@ -99,7 +108,10 @@ def entropy(self):
                               reduction_indices=[1])
 
     def sample(self):
-        return self.mean + self.std * tf.random_normal(tf.shape(self.mean))
+        out = self.mean + self.std * tf.random_normal(tf.shape(self.mean))
+        if self.low is not None:
+            out = tf.clip_by_value(out, self.low, self.high)
+        return out
 
 
 class Deterministic(ActionDistribution):
@@ -112,6 +124,34 @@ def sample(self):
         return self.inputs
 
 
+def squash_to_range(dist_cls, low, high):
+    """Squashes an action distribution to a range in (low, high).
+
+    Arguments:
+        dist_cls (class): ActionDistribution class to wrap.
+        low (float|array): Scalar value or array of values.
+        high (float|array): Scalar value or array of values.
+    """
+
+    class SquashToRangeWrapper(dist_cls):
+        def __init__(self, inputs):
+            dist_cls.__init__(self, inputs, low=low, high=high)
+
+        def logp(self, x):
+            return dist_cls.logp(self, x)
+
+        def kl(self, other):
+            return dist_cls.kl(self, other)
+
+        def entropy(self):
+            return dist_cls.entropy(self)
+
+        def sample(self):
+            return dist_cls.sample(self)
+
+    return SquashToRangeWrapper
+
+
 class MultiActionDistribution(ActionDistribution):
     """Action distribution that operates for list of actions.
 
diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py
index 603073dbfebc..f5684f0aa220 100644
--- a/python/ray/rllib/models/catalog.py
+++ b/python/ray/rllib/models/catalog.py
@@ -11,7 +11,8 @@
     _default_registry
 
 from ray.rllib.models.action_dist import (
-    Categorical, Deterministic, DiagGaussian, MultiActionDistribution)
+    Categorical, Deterministic, DiagGaussian, MultiActionDistribution,
+    squash_to_range)
 from ray.rllib.models.preprocessors import get_preprocessor
 from ray.rllib.models.fcnet import FullyConnectedNetwork
 from ray.rllib.models.visionnet import VisionNetwork
@@ -29,6 +30,7 @@
     "fcnet_hiddens",  # Number of hidden layers for fully connected net
     "free_log_std",  # Documented in ray.rllib.models.Model
     "channel_major",  # Pytorch conv requires images to be channel-major
+    "squash_to_range",  # Whether to squash the action output to space range
 
     # === Options for custom models ===
     "custom_preprocessor",  # Name of a custom preprocessor to use
@@ -51,11 +53,12 @@ class ModelCatalog(object):
     """
 
     @staticmethod
-    def get_action_dist(action_space, dist_type=None):
+    def get_action_dist(action_space, config=None, dist_type=None):
         """Returns action distribution class and size for the given action space.
 
         Args:
             action_space (Space): Action space of the target gym env.
+            config (dict): Optional model config.
             dist_type (str): Optional identifier of the action distribution.
 
         Returns:
@@ -66,10 +69,14 @@ def get_action_dist(action_space, dist_type=None):
 
         # TODO(ekl) are list spaces valid?
         if isinstance(action_space, list):
             action_space = gym.spaces.Tuple(action_space)
-
+        config = config or {}
         if isinstance(action_space, gym.spaces.Box):
             if dist_type is None:
-                return DiagGaussian, action_space.shape[0] * 2
+                dist = DiagGaussian
+                if config.get("squash_to_range"):
+                    dist = squash_to_range(
+                        dist, action_space.low, action_space.high)
+                return dist, action_space.shape[0] * 2
             elif dist_type == 'deterministic':
                 return Deterministic, action_space.shape[0]
         elif isinstance(action_space, gym.spaces.Discrete):
diff --git a/python/ray/rllib/pg/pg_policy_graph.py b/python/ray/rllib/pg/pg_policy_graph.py
index b605a513f39c..55919be046c7 100644
--- a/python/ray/rllib/pg/pg_policy_graph.py
+++ b/python/ray/rllib/pg/pg_policy_graph.py
@@ -16,7 +16,8 @@ def __init__(self, obs_space, action_space, registry, config):
 
         # setup policy
         self.x = tf.placeholder(tf.float32, shape=[None]+list(obs_space.shape))
-        dist_class, self.logit_dim = ModelCatalog.get_action_dist(action_space)
+        dist_class, self.logit_dim = ModelCatalog.get_action_dist(
+            action_space, self.config["model"])
         self.model = ModelCatalog.get_model(
             registry, self.x, self.logit_dim, options=self.config["model"])
         self.dist = dist_class(self.model.outputs)  # logit for each action
diff --git a/python/ray/rllib/ppo/ppo_evaluator.py b/python/ray/rllib/ppo/ppo_evaluator.py
index a8ca6e54ca92..b78297661042 100644
--- a/python/ray/rllib/ppo/ppo_evaluator.py
+++ b/python/ray/rllib/ppo/ppo_evaluator.py
@@ -69,7 +69,7 @@ def __init__(self, registry, env_creator, config, logdir, is_remote):
         action_space = self.env.action_space
         self.actions = ModelCatalog.get_action_placeholder(action_space)
         self.distribution_class, self.logit_dim = ModelCatalog.get_action_dist(
-            action_space)
+            action_space, config["model"])
         # Log probabilities from the policy before the policy update.
         self.prev_logits = tf.placeholder(
             tf.float32, shape=(None, self.logit_dim))
diff --git a/python/ray/rllib/tuned_examples/pendulum-ppo.yaml b/python/ray/rllib/tuned_examples/pendulum-ppo.yaml
index 089bb946de4c..841bbfd6f88f 100644
--- a/python/ray/rllib/tuned_examples/pendulum-ppo.yaml
+++ b/python/ray/rllib/tuned_examples/pendulum-ppo.yaml
@@ -12,3 +12,4 @@ pendulum-ppo:
         num_sgd_iter: 10
         model:
             fcnet_hiddens: [64, 64]
+            squash_to_range: True
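Note on the behavior this patch adds (not part of the patch itself): with
squash_to_range enabled, DiagGaussian maps the raw mean into the action
bounds via mean' = low + sigmoid(mean) * (high - low), and clips sampled
actions back into [low, high]. A minimal, self-contained NumPy sketch of
that logic follows; the helper names `sigmoid` and
`squashed_gaussian_sample` are illustrative, not RLlib APIs.

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def squashed_gaussian_sample(inputs, low, high):
        """Mirrors DiagGaussian when low/high are set.

        `inputs` has shape (batch, 2 * action_dim): the first half holds
        the raw means, the second half the log standard deviations.
        """
        mean, log_std = np.split(inputs, 2, axis=1)
        # Squash the mean into [low, high], as in DiagGaussian.__init__.
        mean = low + sigmoid(mean) * (high - low)
        std = np.exp(log_std)
        # Sample, then clip outliers back into range, as in sample().
        out = mean + std * np.random.standard_normal(mean.shape)
        return np.clip(out, low, high)

    # Pendulum-v0 actions lie in [-2, 2]; a raw mean of 0.0 squashes to 0.0.
    logits = np.array([[0.0, -1.0]])  # raw mean 0.0, log_std -1.0
    print(squashed_gaussian_sample(logits, low=-2.0, high=2.0))

Users opt in per agent through the model config, e.g.
{"model": {"squash_to_range": True}}, as the pendulum-ppo.yaml change
above shows; runs without the option keep the previous unbounded behavior.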