Refactoring reward functions. Adding step by step reasoning reward. Adding test coverage for reward functions #144
Conversation
| """ | ||
|
|
||
| reward_funcs: list[str] = field( | ||
| default_factory=lambda: ["accuracy", "format"], |
I've left the default the same.
| "accuracy": accuracy_reward, | ||
| "format": format_reward, | ||
| "reasoning_steps": reasoning_steps_reward, | ||
| } |
Not exactly a fan of this "registry", but the list of reward functions is small, so I guess it works for now ¯\_(ツ)_/¯
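For reference, a minimal sketch of the registry shape being discussed. The function names come from the diff above; the module path and the `get_reward_funcs` helper are assumptions for illustration, not necessarily what the PR implements:

```python
# Minimal sketch of the "registry" pattern discussed above. The function
# names come from the PR's diff; the import path and lookup helper are
# hypothetical.
from collections.abc import Callable

from rewards import accuracy_reward, format_reward, reasoning_steps_reward  # assumed module path

REWARD_FUNCS_REGISTRY: dict[str, Callable] = {
    "accuracy": accuracy_reward,
    "format": format_reward,
    "reasoning_steps": reasoning_steps_reward,
}

def get_reward_funcs(names: list[str]) -> list[Callable]:
    """Resolve config strings such as ["accuracy", "format"] to callables."""
    return [REWARD_FUNCS_REGISTRY[name] for name in names]
```

A plain dict lookup keeps the config purely string-based, which is arguably good enough while the list of reward functions stays this small.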
```python
    reward = float(verify(answer_parsed, gold_parsed))
else:
    # If the gold solution is not parseable, we reward 1 to skip this example
    reward = 1.0
```
I don't follow: why do we give the same reward in the "non-parseable" case and the correct case?
The function was taken as is: link
But I speculate that:
- It's rare that the ground truth dataset is not parseable.
- The downstream GRPO trainer currently doesn't support ignoring the samples.
shouldn't we give it 0 weight instead then?
It's completely equivalent; the reward is normalized per group.
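To make the equivalence concrete, here is a small sketch of GRPO-style per-group normalization (not the trainer's actual code): any constant reward within a group, whether 1.0 or 0.0, normalizes to an all-zero advantage.

```python
# Sketch of per-group reward normalization (GRPO-style advantages).
# A constant reward within a group yields zero advantage either way.
import torch

def group_advantages(rewards: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    # rewards: (num_groups, group_size); each row holds one prompt's completions
    mean = rewards.mean(dim=1, keepdim=True)
    std = rewards.std(dim=1, keepdim=True)
    return (rewards - mean) / (std + eps)

ones = torch.ones(1, 4)    # unparseable gold: every completion rewarded 1.0
zeros = torch.zeros(1, 4)  # the alternative: every completion rewarded 0.0
assert torch.equal(group_advantages(ones), group_advantages(zeros))  # both all zeros
```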
Yes, I was just thinking of that.
I guess it's guaranteed that the group contains only the completions for that particular unparseable solution, right? Because if the group contained others, the normalization wouldn't work properly. Does that make sense?
Nevertheless, I can come back to it in the next PR.
I'd just made an issue about this (#159) and was reading this thread to see if it was related. A rare exception from Math-Verify stopped my training run; it needs to be caught and skipped as well (#158).
I thought it was odd because the rewards were like 0.19, 0.20, etc., so rewarding 1 for a pair where we couldn't find the gold solution, or where we got an error from Math-Verify, felt off. But if we're sure it works out, I'll close those.
> shouldn't we give it 0 weight instead then?
We wouldn't want to give it 0 weight either, since that could be a penalty or even a net reward depending on the other rewards in the group (0 might be the highest if they all have negative reward). If we don't know, we don't know; we can't put information there that isn't there, we're just diluting the signal.
We would actually want to omit them. It will take upstream changes to trl's GRPOTrainer, but this is necessary IMO. I've looked at it a few times and I'm confident it's not right. We should be able to return None from a reward_func and then just reshape the tensors to remove those examples.
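A rough sketch of what that could look like on the trainer side, assuming the hypothetical convention that a reward function returns None for unscorable samples (this is not trl's current API):

```python
# Hypothetical handling of None rewards before advantage computation.
# trl's GRPOTrainer does not support this today; the convention is assumed.
import torch

def drop_unscorable(rewards: list[float | None],
                    per_token_logps: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    keep = torch.tensor([r is not None for r in rewards])
    kept_rewards = torch.tensor([r for r in rewards if r is not None])
    # Drop the corresponding rows from any per-sample tensors as well.
    return kept_rewards, per_token_logps[keep]
```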
> We wouldn't want to give it 0 weight either, since that could be a penalty or even a net reward depending on the other rewards in the group
That's not right. Gold is shared within the group: if it's not parseable for one sample, then it's not parseable for the whole group.
Oh OK, I see now.
```python
class TestRewards(unittest.TestCase):
    def test_accuracy_reward_correct_answer(self):
        """Test accuracy_reward with a correct answer."""
        completion = [[{"content": r"\boxed{\frac{63}{400}}"}]]
```
Shouldn't the test be without the \boxed? Nowhere in the system prompt do we tell the model to use \boxed, so even if it gets the answer right with, say, <answer>\frac{63}{400}</answer>, shouldn't that be counted as correct?
Valid point; it's passed in as a system prompt argument.
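For completeness, a hypothetical test method along those lines, to sit inside the `TestRewards` class quoted above. It assumes `accuracy_reward(completions, solution)` returns a list of floats, as the existing tests suggest; whether the tag-only answer actually parses depends on how Math-Verify extraction is configured.

```python
# Hypothetical test: answer wrapped in <answer> tags instead of \boxed{}.
# Assumes accuracy_reward(completions, solution) -> list[float]; whether
# this passes depends on the Math-Verify extraction config.
def test_accuracy_reward_answer_tags(self):
    """accuracy_reward should also credit <answer>...</answer> answers."""
    completion = [[{"content": r"<answer>\frac{63}{400}</answer>"}]]
    solution = [r"\frac{63}{400}"]
    rewards = accuracy_reward(completion, solution)
    self.assertEqual(rewards[0], 1.0)
```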
@qgallouedec kindly tagging for review
Force-pushed from 57db2dc to f81cfab
@kashif @lewtun @edbeeching tagging other maintainers
LGTM, thanks for this refactor.
Sorry, minor ruff error, fixed it. @edbeeching
Refactoring reward functions. Adding step by step reasoning reward. Adding test coverage for reward functions
Force-pushed from 5225fe4 to c991fe0
Rebased because of the merge conflict
Refactoring reward functions. Adding step by step reasoning reward. Adding test coverage for reward functions (huggingface#144)
* Refactoring reward functions. Adding step by step reasoning reward. Adding test coverage for reward functions
* [Refactoring reward functions] - Ruff error fix
* [Refactoring reward functions] - Linting error fix
I've placed the reward functions into their own file to accommodate future expansion (code-specific reward functions, etc.). Added tests to make sure the reward functions work as intended.
Also, I've written an optional step-by-step reasoning encouragement reward function.
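For readers skimming the diff, a rough sketch of what such a step-counting reward can look like; the regex and the target of three steps are illustrative assumptions, not necessarily the PR's exact implementation.

```python
# Illustrative step-by-step reasoning reward: count step-like markers and
# give partial credit up to a target count. Pattern and target are assumptions.
import re

def reasoning_steps_reward(completions: list[list[dict]], **kwargs) -> list[float]:
    pattern = r"(?:Step \d+:|\n-|\nFirst,|\nSecond,|\nFinally,)"
    contents = [completion[0]["content"] for completion in completions]
    counts = [len(re.findall(pattern, content)) for content in contents]
    return [min(1.0, count / 3) for count in counts]  # full reward at 3+ markers
```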