GRPO implementation update #534
That's a good catch. While the kl1 estimator is a "decent" estimator of the KL divergence, its gradient is not the "correct" estimator of the gradient of the KL divergence. I suspect that you are using the kl1 estimator. The gradient of the KL divergence between the policy $\pi_\theta$ and the reference $\pi_{\mathrm{ref}}$ is $\nabla_\theta\,\mathrm{KL}(\pi_\theta\,\|\,\pi_{\mathrm{ref}}) = \mathbb{E}_{x\sim\pi_\theta}\big[\big(\log\tfrac{\pi_\theta(x)}{\pi_{\mathrm{ref}}(x)} + 1\big)\nabla_\theta\log\pi_\theta(x)\big]$. If the samples are drawn from $\pi_\theta$, the correct single-sample gradient estimate is therefore $\big(\log\tfrac{\pi_\theta(x)}{\pi_{\mathrm{ref}}(x)} + 1\big)\nabla_\theta\log\pi_\theta(x)$, whereas backpropagating through kl1 $= \log\tfrac{\pi_\theta(x)}{\pi_{\mathrm{ref}}(x)}$ only produces $\nabla_\theta\log\pi_\theta(x)$. An estimator whose gradient matches that expression is what kl1 should have been. BTW, the GRPO paper recommends using the kl3 estimator (Equation 4 in https://arxiv.org/pdf/2402.03300).
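For concreteness, here is a minimal sketch (not taken from the repository; tensor names are illustrative) of the three estimators being discussed, assuming `logprob` and `ref_logprob` hold per-token log-probabilities under the current and reference policies:

```python
import torch

def kl_estimators(logprob: torch.Tensor, ref_logprob: torch.Tensor):
    """Per-token estimators of KL(pi_theta || pi_ref) for tokens sampled from pi_theta."""
    log_ratio = logprob - ref_logprob            # log(pi_theta / pi_ref)
    kl1 = log_ratio                              # unbiased value, but its naive gradient is wrong
    kl2 = 0.5 * log_ratio ** 2                   # biased value, low variance, always >= 0
    # kl3, as in Equation 4 of the GRPO paper: pi_ref/pi_theta - log(pi_ref/pi_theta) - 1
    kl3 = torch.exp(-log_ratio) + log_ratio - 1  # always >= 0, equality iff the two probs match
    return kl1, kl2, kl3
```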
Yeah, so I thought about kl3. The issue is that kl3 blows up, but this only happens when using KL as part of the reward. I am not sure what happens when we put it in the loss directly... |
Interesting. I agree that it might work if added in the loss directly. I have modified my local fork with the following changes:
Basically, it gives a kl1 estimator whose gradient is unbiased. I haven't created a new PR since it has conflicting changes with your PR. |
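The actual diff from the fork isn't reproduced above, so the following is only a sketch of the idea as I read it: multiply kl1 by a ratio that always evaluates to 1 in the forward pass but restores the missing score-function term when backpropagated (`logprob` and `ref_logprob` are assumed per-token log-probabilities):

```python
import torch

def kl1_with_unbiased_grad(logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
    """Same value as kl1, but its gradient matches the gradient of KL(pi_theta || pi_ref).

    exp(logprob - logprob.detach()) is identically 1 in the forward pass, yet
    differentiating it yields grad = (log_ratio + 1) * grad(logprob), which is an
    unbiased single-sample estimate of the true KL gradient.
    """
    log_ratio = logprob - ref_logprob
    return torch.exp(logprob - logprob.detach()) * log_ratio
```

Because the multiplier is identically 1, the logged KL values stay numerically identical to plain kl1; only the backward pass changes.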
I can help test it out. I modified it to the following and launched a scan. Will report back with results.
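The exact modification that was scanned isn't shown above; the snippet below is only a rough sketch, with hypothetical names (`pg_loss_per_token`, `response_mask`, `masked_mean`, and an arbitrary `beta`), of how such a per-token KL term can be added to the loss directly rather than folded into the reward:

```python
import torch

def masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean over valid (non-padding) response tokens."""
    return (x * mask).sum() / mask.sum().clamp(min=1)

def grpo_loss_with_kl(pg_loss_per_token: torch.Tensor,
                      logprob: torch.Tensor,
                      ref_logprob: torch.Tensor,
                      response_mask: torch.Tensor,
                      beta: float = 0.05) -> torch.Tensor:
    """Policy-gradient loss plus a per-token KL penalty applied in the loss, not the reward."""
    log_ratio = logprob - ref_logprob
    # kl1 with the ratio trick, so the penalty's gradient is an unbiased KL-gradient estimate
    kl = torch.exp(logprob - logprob.detach()) * log_ratio
    return masked_mean(pg_loss_per_token + beta * kl, response_mask)
```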
Great. Thanks |
Since RLVR looks only at the final answer, completely ignoring the text generated, the model can be motivated to generate text that deviates from the reference policy as long as it reaches the correct final answer. Perhaps a stronger KL term (higher beta value) is desirable in RLVR, to ensure that the generated text stays sensible. |
FYI on the KL3 estimator: https://x.com/vwxyzjn/status/1885329398821187633 |
Meanwhile, @gauravpandeyamu why did you multiply the kl1 term by the ratio? |
Ahh, yes. Now the graphs of kl1, kl2, kl3 and kl4 make perfect sense. As for why the kl1 estimator is a bad estimator, and how multiplying by the ratio fixes it (mainly, the bias is fixed), here is ChatGPT's response: https://chatgpt.com/share/679dfb75-cdec-800f-9078-f838d3925f9e |
If I am not mistaken, I believe that the gradient of KL3 is also not an unbiased estimator of the gradient of the KL divergence. Given this, I don't fully understand how KL3 can work for optimization (maybe the lower variance alone is good enough, despite this potential bias). It should be a good estimator to monitor the KL divergence (notably because it remains non-negative, is unbiased and has low variance), but using it for optimization is a bit mysterious to me. I'm sorry to hijack this issue, but I haven't seen a lot of discussion on the use of KL3 in GRPO. |
You are right. It is biased but with a low variance. Intuitively, another reason why it works is that each term of the kl3 estimator ($-\log\tfrac{p(x)}{q(x)} + \tfrac{p(x)}{q(x)} - 1$) is lower bounded by 0 (same as kl2), with equality achieved only when $p(x) = q(x)$. So, if you try to estimate the KL with just one sample and optimize the kl3 estimator for that sample, you will end up making the two probabilities equal for that sample. On the other hand, if you try to minimize the kl1 estimator for a single sample ($-\log\tfrac{p(x)}{q(x)}$), you will just end up pushing the ratio $\tfrac{p(x)}{q(x)}$ up without bound rather than toward 1. This is why optimizing kl1 leads to a blow-up in KL when beta is higher. |
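To spell out the per-sample bound (using the same $p$, $q$ notation as the comment above):

```math
-\log\frac{p(x)}{q(x)} + \frac{p(x)}{q(x)} - 1 \;\ge\; 0,
\qquad \text{with equality iff } p(x) = q(x),
```

since $\log t \le t - 1$ for every $t > 0$, with equality only at $t = 1$. The kl1 term $-\log\tfrac{p(x)}{q(x)}$ alone has no such lower bound.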
Also worth noting that the expectation of the kl3 estimator is itself a valid divergence: it is non-negative and equals zero only when the two distributions match, so minimizing it still pulls the policy toward the reference. |
That makes sense. In that case, I would find KL2 a more "principled" choice for optimization, since its gradient w.r.t. the policy parameters is an unbiased estimator of the gradient of the KL divergence. But your comment on it being a valid divergence also makes a lot of sense, and maybe what the success of KL3 shows is that we should go beyond just the KL divergence. |
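For reference, a quick check of that claim, writing $r(x) = \pi_\theta(x)/\pi_{\mathrm{ref}}(x)$ and sampling $x \sim \pi_\theta$:

```math
\nabla_\theta\,\tfrac{1}{2}\big(\log r(x)\big)^2 = \log r(x)\,\nabla_\theta \log \pi_\theta(x),
\qquad
\mathbb{E}_{x\sim\pi_\theta}\!\big[\log r(x)\,\nabla_\theta \log \pi_\theta(x)\big]
= \nabla_\theta\,\mathrm{KL}(\pi_\theta\,\|\,\pi_{\mathrm{ref}}),
```

so the naive gradient of the kl2 estimator is already an unbiased estimate of the true KL gradient (the extra score term $\mathbb{E}_{x\sim\pi_\theta}[\nabla_\theta \log \pi_\theta(x)]$ that appears in the exact gradient vanishes in expectation).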
I agree with KL2 being a more principled choice for optimization. |
Let's use this issue to share the latest GRPO development updates. CC @gauravpandeyamu thanks for your fix.
The command below (without @gauravpandeyamu's fix) yields the charts below. Overall, the training score and sequence length go up, but the downstream eval on MATH seems to suffer. I am gonna try out @gauravpandeyamu's fix for KL regularization to see if it helps.