Add comment on the use of categorical cross entropy in REINFORCE and a2c #54

Open
fredcallaway opened this issue Jun 29, 2017 · 9 comments


@fredcallaway
Contributor

I was surprised to see this loss function because it is generally used when the target is a distribution (i.e. it sums to 1), which is not the case for the advantage estimate. However, I worked out the math, and it does appear to be doing the right thing, which is neat!

I think this trick should be mentioned in the code.

@dnddnjs
Contributor

dnddnjs commented Jul 3, 2017

Yes, categorical cross entropy generally cannot be used for policy gradients. What I did was change the cross-entropy target to a one-hot vector and multiply it by the advantage function, so the target is a one-hot vector scaled by the advantage. In this way the categorical cross entropy becomes the objective function for the policy gradient.

I did this because I wanted to use model.fit() for the policy gradient, since it is simple! I will write up an explanation for this. Thank you!
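Something like the following minimal sketch illustrates the idea (it is not the repository's actual code; the layer sizes and variable names are made up, and it assumes standalone Keras as used at the time):

```python
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

state_size, action_size = 4, 2

# Policy network: the softmax output gives action probabilities pi(.|s)
model = Sequential()
model.add(Dense(24, activation='relu', input_dim=state_size))
model.add(Dense(action_size, activation='softmax'))

# categorical_crossentropy computes -sum_i(p_i * log(q_i)); with the target
# p set to advantage * one_hot(action), this is exactly -A * log(pi(a|s))
model.compile(loss='categorical_crossentropy', optimizer='adam')

state = np.random.rand(1, state_size)  # dummy state for illustration
action, advantage = 1, 0.7             # action taken and its advantage

target = np.zeros((1, action_size))
target[0, action] = advantage          # one-hot vector scaled by the advantage

# One model.fit() call then performs one gradient step on -A * log(pi(a|s))
model.fit(state, target, epochs=1, verbose=0)
```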

@fredcallaway
Contributor Author

Categorical cross entropy is defined (dropping the conventional minus sign for now) as H(p, q) = sum(p_i * log(q_i)). For the action taken, a, you set p_a = A, and q_a is the probability of taking action a, i.e. π(a|s). All other p_i are zero, thus we have H(p, q) = A * log(π(a|s)). It's a very clever and generalizable trick!
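A quick numeric check of that identity (the values are made up for illustration; note that Keras' categorical_crossentropy returns the negation of this sum, which comes up again below):

```python
import numpy as np

pi = np.array([0.2, 0.5, 0.3])  # softmax output pi(.|s), made-up values
a, A = 1, 1.5                   # action taken and its advantage estimate

p = np.zeros(3)
p[a] = A                        # the "target": one-hot scaled by the advantage

# sum(p_i * log(q_i)) collapses to the single non-zero term A * log(pi(a|s))
assert np.isclose(np.sum(p * np.log(pi)), A * np.log(pi[a]))
```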

@dnddnjs
Contributor

dnddnjs commented Jul 5, 2017

Thank you! Your explanation is clear enough to understand easily. I will add an explanation about this soon.

@dnddnjs dnddnjs closed this as completed Jul 5, 2017
@keon keon reopened this Jul 5, 2017
@fredcallaway
Contributor Author

Shall I make a pull request?

@dnddnjs
Contributor

dnddnjs commented Jul 12, 2017

@fredcallaway
We would appreciate it if you did!

@nyck33

nyck33 commented Jan 24, 2019

Hello, I am confused as to why "all other p_i are zero" if the p's are advantages and softmax is applied. Is this the answer? "You correctly pointed out that the softmax numerator should never have zero-values due to the exponential. However, due to floating point precision, the numerator could be a very small value, say, exp(-50000), which essentially evaluates to zero." (from https://stackoverflow.com/questions/39109674/the-output-of-a-softmax-isnt-supposed-to-have-zeros-right)

@fredcallaway
Contributor Author

Softmax is not applied to the advantages (the p_i). It is applied to the action scores to get action probabilities. All other p_i are zero simply because they are initialized that way: cartpole_reinforce.py#L84

Note that p_i refers to the target (true) labels and q_i to the predictions in the standard classification case. So by feeding advantages as the "target" on line 90, we are defining p_i to be the one-hot advantages.
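For concreteness, here is a minimal sketch of how such a one-hot advantage target can be built (the variable names and values are made up for illustration):

```python
import numpy as np

num_actions = 2
# (action, advantage) pairs for one episode -- made-up values
episode = [(1, 1.55), (1, 1.21), (0, -0.87)]

advantages = np.zeros((len(episode), num_actions))
for t, (action, advantage) in enumerate(episode):
    advantages[t, action] = advantage  # all other entries stay zero

# 'advantages' is then fed to model.fit() as the target; each row is zero
# everywhere except at the index of the action actually taken
```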

@nyck33

nyck33 commented Jan 25, 2019

I printed out the advantages array:

 advantages final 
[[ 0.          1.54572815]
 [ 0.          1.21139962]
 [ 0.          0.87369404]
 [ 0.          0.53257729]
 [ 0.          0.18801491]
 [ 0.         -0.16002789]
 [ 0.         -0.51158628]
 [-0.86669576  0.        ]
 [ 0.         -1.22539221]
 [ 0.         -1.58771185]]


So the final line, H(p, q) = -A * log(π(a|s)), means the cross-entropy loss is the negation of the advantage times the log probability of the policy? (I kept seeing the negation everywhere cross entropy is explained, so I wanted to include it.)

It's still magical to me how the network outputs probabilities but you can use the advantages like this.

Is this a hack? If so, I think it is very ingenious, but I would also prefer to learn it the plain vanilla way. Is Karpathy's pong example the best for that?

@fredcallaway
Contributor Author

Good point regarding the negation. You're correct that the loss has the negation. However, we are actually trying to maximize sum(p_i * log(q_i)), which corresponds to minimizing its negation.
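A tiny hand-rolled check of that equivalence (the numbers and step size are made up; the gradient formula is the standard softmax-cross-entropy gradient, scaled by the advantage):

```python
import numpy as np

# For loss = -A * log(pi[a]) with pi = softmax(logits), the gradient is
# d(loss)/d(logits) = A * (pi - one_hot(a))
logits = np.array([0.1, -0.2])
pi = np.exp(logits) / np.sum(np.exp(logits))
a, A = 0, 2.0

grad = A * (pi - np.eye(2)[a])  # gradient of the loss
logits = logits - 0.1 * grad    # one gradient *descent* step on the loss

pi_new = np.exp(logits) / np.sum(np.exp(logits))
assert pi_new[a] > pi[a]        # pi(a|s) increased: we ascended A * log(pi(a|s))
```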

I definitely agree that it's a bit confusing, but I think the idea of this library is not so much to teach you the math, but rather to give you some code to play around with to test the effects of hyperparameters and apply the algorithms to different environments.
