Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion python/ray/rllib/bc/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,14 @@ def _setup_graph(self, ob_space, ac_space):
tf.get_variable_scope().name)

def setup_loss(self, action_space):
    """Create the action placeholder and the policy loss.

    Sets ``self.ac`` to a placeholder whose dtype and shape depend on the
    type of ``action_space``, then sets ``self.pi_loss`` and ``self.loss``
    to the negated sum of ``self.curr_dist.logp(self.ac)``.

    Args:
        action_space: The action space of the environment. Instances of
            ``gym.spaces.Box`` and ``gym.spaces.Discrete`` are handled.

    Raises:
        NotImplementedError: If ``action_space`` is of any other type.
    """
    # Exactly one placeholder named "ac" is created, so no unused node is
    # left in the graph and the name is not uniquified to something else.
    if isinstance(action_space, gym.spaces.Box):
        # One float entry per action dimension, with a leading batch axis.
        self.ac = tf.placeholder(
            tf.float32, [None] + list(action_space.shape), name="ac")
    elif isinstance(action_space, gym.spaces.Discrete):
        # One integer action index per batch element.
        self.ac = tf.placeholder(tf.int64, [None], name="ac")
    else:
        raise NotImplementedError(
            "action space " + str(type(action_space)) +
            " currently not supported")
    log_prob = self.curr_dist.logp(self.ac)
    self.pi_loss = -tf.reduce_sum(log_prob)
    self.loss = self.pi_loss
Expand Down