Skip to content

Commit

Permalink
add comment on use of categorical_crossentropy
Browse files Browse the repository at this point in the history
  • Loading branch information
fredcallaway committed Jul 13, 2017
1 parent 589719f commit a58645c
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 0 deletions.
6 changes: 6 additions & 0 deletions 2-cartpole/3-reinforce/cartpole_reinforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ def build_model(self):
model.add(Dense(self.hidden2, activation='relu', kernel_initializer='glorot_uniform'))
model.add(Dense(self.action_size, activation='softmax', kernel_initializer='glorot_uniform'))
model.summary()
# Using categorical crossentropy as a loss is a trick to easily
# implement the policy gradient. Categorical cross entropy is defined
# H(p, q) = -sum(p_i * log(q_i)). For the action taken, a, you set
# p_a = advantage. q_a is the output of the policy network, which is
# the probability of taking the action a, i.e. policy(s, a).
# All other p_i are zero, thus we have H(p, q) = -A * log(policy(s, a)),
# so minimizing this loss maximizes A * log(policy(s, a)).
model.compile(loss="categorical_crossentropy", optimizer=Adam(lr=self.learning_rate))
return model

Expand Down
1 change: 1 addition & 0 deletions 2-cartpole/4-actor-critic/cartpole_a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def build_actor(self):
actor.add(Dense(self.action_size, activation='softmax',
kernel_initializer='he_uniform'))
actor.summary()
# See note regarding crossentropy in cartpole_reinforce.py
actor.compile(loss='categorical_crossentropy',
optimizer=Adam(lr=self.actor_lr))
return actor
Expand Down
1 change: 1 addition & 0 deletions 3-atari/2-pong/pong_reinforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def _build_model(self):
model.add(Dense(32, activation='relu', init='he_uniform'))
model.add(Dense(self.action_size, activation='softmax'))
opt = Adam(lr=self.learning_rate)
# See note regarding crossentropy in cartpole_reinforce.py
model.compile(loss='categorical_crossentropy', optimizer=opt)
return model

Expand Down

1 comment on commit a58645c

@hoangcuong2011
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @fredcallaway for your explanation of why we should use categorical_crossentropy here. Very clever indeed!

Please sign in to comment.