diff --git a/python/ray/rllib/bc/policy.py b/python/ray/rllib/bc/policy.py
index 7566422fa154..11178a50d23a 100644
--- a/python/ray/rllib/bc/policy.py
+++ b/python/ray/rllib/bc/policy.py
@@ -2,8 +2,10 @@
 from __future__ import division
 from __future__ import print_function
 
-import ray
 import tensorflow as tf
+import gym
+
+import ray
 
 from ray.rllib.a3c.policy import Policy
 from ray.rllib.models.catalog import ModelCatalog
@@ -38,7 +40,16 @@ def _setup_graph(self, ob_space, ac_space):
                                           tf.get_variable_scope().name)
 
     def setup_loss(self, action_space):
-        self.ac = tf.placeholder(tf.int64, [None], name="ac")
+        if isinstance(action_space, gym.spaces.Box):
+            self.ac = tf.placeholder(tf.float32,
+                                     [None] + list(action_space.shape),
+                                     name="ac")
+        elif isinstance(action_space, gym.spaces.Discrete):
+            self.ac = tf.placeholder(tf.int64, [None], name="ac")
+        else:
+            raise NotImplementedError(
+                "action space " + str(type(action_space)) +
+                " currently not supported")
         log_prob = self.curr_dist.logp(self.ac)
         self.pi_loss = - tf.reduce_sum(log_prob)
         self.loss = self.pi_loss
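
For context, a minimal standalone sketch of the placeholder branching this patch adds to setup_loss, assuming TensorFlow 1.x-style tf.placeholder and gym action spaces; the make_action_placeholder helper is an illustrative name, not part of the patch:

    # Standalone sketch (assumes TF 1.x and gym); make_action_placeholder is
    # a hypothetical helper mirroring the patched setup_loss logic.
    import gym
    import tensorflow as tf

    def make_action_placeholder(action_space):
        if isinstance(action_space, gym.spaces.Box):
            # Continuous actions: one float per action dimension.
            return tf.placeholder(tf.float32,
                                  [None] + list(action_space.shape),
                                  name="ac")
        elif isinstance(action_space, gym.spaces.Discrete):
            # Discrete actions: a single integer index per sample.
            return tf.placeholder(tf.int64, [None], name="ac")
        else:
            raise NotImplementedError(
                "action space " + str(type(action_space)) +
                " currently not supported")

    # A 3-dimensional Box yields a [None, 3] float32 placeholder,
    # while Discrete(4) yields a [None] int64 placeholder.
    box_ac = make_action_placeholder(
        gym.spaces.Box(low=-1.0, high=1.0, shape=(3,)))
    disc_ac = make_action_placeholder(gym.spaces.Discrete(4))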