diff --git a/python/ray/rllib/bc/policy.py b/python/ray/rllib/bc/policy.py index 7566422fa154..6ef4cb190309 100644 --- a/python/ray/rllib/bc/policy.py +++ b/python/ray/rllib/bc/policy.py @@ -38,7 +38,7 @@ def _setup_graph(self, ob_space, ac_space): tf.get_variable_scope().name) def setup_loss(self, action_space): - self.ac = tf.placeholder(tf.int64, [None], name="ac") + self.ac = tf.placeholder(tf.float32, [None] + list(action_space.shape), name="ac") log_prob = self.curr_dist.logp(self.ac) self.pi_loss = - tf.reduce_sum(log_prob) self.loss = self.pi_loss