Refactoring reward functions. Adding step by step reasoning reward. Adding test coverage for reward functions #144
The first file in the diff (@@ -0,0 +1,80 @@) adds the reward functions:

```python
"""Reward functions for GRPO training."""

import re

from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify


def accuracy_reward(completions, solution, **kwargs):
    """Reward function that checks if the completion is the same as the ground truth."""
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, sol in zip(contents, solution):
        gold_parsed = parse(
            sol,
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )
        if len(gold_parsed) != 0:
            # We require the answer to be provided in correct latex (no malformed operators)
            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            boxed="all",
                            units=True,
                        ),
                        # Ensures that boxed is tried first
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )
            # Reward 1 if the content is the same as the ground truth, 0 otherwise
            reward = float(verify(answer_parsed, gold_parsed))
        else:
            # If the gold solution is not parseable, we reward 1 to skip this example
            reward = 1.0
```
Review thread on the `reward = 1.0` fallback:

> I don't follow: why do we give the same reward in the "non parseable" case and the correct case?

> The function was taken as is (link). But, I speculate that: …

> Shouldn't we give it 0 weight instead then?

> It's completely equivalent; the reward is normalized per group.

> Yes, I was just thinking of it. I guess it is guaranteed that the group would contain only the completions for that particular unparseable solution, right? Because if the group contained others, the normalization wouldn't work properly. Does that make sense?

> Nevertheless, I can come back to it in the next PR.

> I'd just made an issue about this (#159) and was reading this thread to see if it was related. A rare exception from Math-Verify stopped my training run; it needs to be caught and skipped as well (#158). I thought it was odd because the rewards were like 0.19, 0.20, etc., so rewarding 1 for a pair where we couldn't find the gold solution, or where we got an error from Math-Verify, felt off. But if we're sure it works out, I'll close those.

> We wouldn't want to give it 0 weight either, since that could be a penalty or even a net reward depending on the other rewards in the group (0 might be the highest if they all have negative reward). If we don't know, we don't know; we can't put information there that isn't there, we are just diluting the signal. We would actually want to omit them. It will take upstream changes to trl's GRPOTrainer, but this is necessary IMO. I've looked at it a few times and I'm confident it's not right. We should be able to return …

> That's not right. Gold is shared in the group. If it's not parsable for one sample, then it's not parsable for the whole group.

> Oh OK, I see now.
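For readers following the thread, here is a minimal sketch of why the `reward = 1.0` fallback washes out, assuming GRPO-style per-group normalization (advantages centered on the group mean). This is an illustration only, not trl's actual implementation:

```python
# Assumed GRPO-style normalization (illustration only, not trl's code):
# when the gold answer is unparseable, every completion in the group gets 1.0,
# so after centering on the group mean every advantage is zero and the example
# contributes no gradient signal. A constant 0.0 would behave the same way.
group_rewards = [1.0, 1.0, 1.0, 1.0]
mean_reward = sum(group_rewards) / len(group_rewards)
advantages = [r - mean_reward for r in group_rewards]
print(advantages)  # [0.0, 0.0, 0.0, 0.0]
```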
The file continues:

```python
            print("Failed to parse gold solution: ", sol)
        rewards.append(reward)

    return rewards


def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]
```
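It is worth noting how strict this pattern is in practice; a small sketch using only the stdlib `re` module (the example strings are illustrative, not from the PR):

```python
import re

pattern = r"^<think>.*?</think><answer>.*?</answer>$"

single_line = "<think>reasoning</think><answer>42</answer>"
multi_line = "<think>line one\nline two</think><answer>42</answer>"

print(bool(re.match(pattern, single_line)))  # True
print(bool(re.match(pattern, multi_line)))   # False: '.' does not match '\n' without re.DOTALL
```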
Next comes the new step-by-step reasoning reward:

```python
def reasoning_steps_reward(completions, **kwargs):
    r"""Reward function that checks for clear step-by-step reasoning.
    Regex pattern:
        Step \d+: - matches "Step 1:", "Step 2:", etc.
        ^\d+\. - matches numbered lists like "1.", "2.", etc. at start of line
        \n- - matches bullet points with hyphens
        \n\* - matches bullet points with asterisks
        First,|Second,|Next,|Finally, - matches transition words
    """
    pattern = r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [len(re.findall(pattern, content)) for content in completion_contents]

    # Magic number 3 to encourage 3 steps and more, otherwise partial reward
    return [min(1.0, count / 3) for count in matches]
```
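To make the partial credit concrete, the values follow directly from `min(1.0, count / 3)`:

```python
# Reward as a function of the number of matched reasoning markers:
# 0 -> 0.0, 1 -> ~0.33, 2 -> ~0.67, 3 or more -> capped at 1.0.
for count in (0, 1, 2, 3, 5):
    print(count, min(1.0, count / 3))
```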
Finally, the registry that closes out the file:

```python
REWARD_FUNCS_REGISTRY = {
    "accuracy": accuracy_reward,
    "format": format_reward,
    "reasoning_steps": reasoning_steps_reward,
}
```
Review comment: Not exactly a fan of this "registry", but the list of reward functions is small, so I guess it works for now ¯\_(ツ)_/¯
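For context, the likely intent of such a registry is to resolve reward names from a training config into callables; a minimal sketch (the `requested` list is illustrative and not part of this PR):

```python
# Hypothetical usage sketch: map reward names from a config to callables.
# The `requested` list stands in for whatever the training script reads
# from its arguments; it is not part of this PR.
requested = ["accuracy", "format", "reasoning_steps"]
reward_funcs = [REWARD_FUNCS_REGISTRY[name] for name in requested]
```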
The second file in the diff (@@ -0,0 +1,91 @@) adds test coverage for the reward functions:

```python
import unittest

from open_r1.rewards import accuracy_reward, format_reward, reasoning_steps_reward


class TestRewards(unittest.TestCase):
    def test_accuracy_reward_correct_answer(self):
        """Test accuracy_reward with a correct answer."""
        completion = [[{"content": r"\boxed{\frac{63}{400}}"}]]
```
Review thread on this test case:

> Shouldn't the test be without the …?

> Valid point; it's passed in as a system prompt argument.
| solution = [r"\frac{63}{400}"] | ||
|
|
||
| rewards = accuracy_reward(completion, solution) | ||
| self.assertEqual(rewards[0], 1.0) | ||
|
|
||
| def test_accuracy_reward_wrong_answer(self): | ||
| """Test accuracy_reward with an incorrect answer.""" | ||
| completion = [[{"content": r"\boxed{\frac{64}{400}}"}]] | ||
| solution = [r"\frac{63}{400}"] | ||
|
|
||
| rewards = accuracy_reward(completion, solution) | ||
| self.assertEqual(rewards[0], 0.0) | ||
|
|
||
| def test_format_reward_correct(self): | ||
| """Test format_reward with correct format.""" | ||
| completion = [[{"content": "<think>Some reasoning</think><answer>The answer</answer>"}]] | ||
| rewards = format_reward(completion) | ||
| self.assertEqual(rewards[0], 1.0) | ||
|
|
||
| def test_format_reward_incorrect(self): | ||
| """Test format_reward with incorrect format.""" | ||
| incorrect_formats = [ | ||
| "<think>Only thinking</think>", | ||
| "<answer>Only answer</answer>", | ||
| "No tags at all", | ||
| "<think>Missing closing</think><answer>Missing closing", | ||
| "<think>Wrong order</answer><answer>Wrong order</think>" | ||
| ] | ||
|
|
||
| for fmt in incorrect_formats: | ||
| completion = [[{"content": fmt}]] | ||
| rewards = format_reward(completion) | ||
| self.assertEqual(rewards[0], 0.0) | ||
|
|
||
| def test_reasoning_steps_reward(self): | ||
| """Test reasoning_steps_reward with various formats.""" | ||
| test_cases = [ | ||
| # Full credit cases (3 or more steps) | ||
| ( | ||
| "Step 1: First step\nStep 2: Second step\nStep 3: Third step", | ||
| 1.0 | ||
| ), | ||
| ( | ||
| "First, we do this.\nSecond, we do that.\nFinally, we conclude.", | ||
| 1.0 | ||
| ), | ||
| # Partial credit cases (less than 3 steps) | ||
| ( | ||
| "Step 1: Only step", | ||
| 1/3 | ||
| ), | ||
| ( | ||
| "First, we do this.\nFinally, we conclude.", | ||
| 2/3 | ||
| ), | ||
| # No credit case | ||
| ( | ||
| "Just plain text without any clear steps", | ||
| 0.0 | ||
| ) | ||
| ] | ||
|
|
||
| for content, expected_reward in test_cases: | ||
| completion = [[{"content": content}]] | ||
| rewards = reasoning_steps_reward(completion) | ||
| self.assertAlmostEqual(rewards[0], expected_reward) | ||
|
|
||
| def test_multiple_completions(self): | ||
| """Test handling multiple completions at once.""" | ||
| completions = [ | ||
| [{"content": r"\boxed{\frac{63}{400}}"}], | ||
| [{"content": r"\boxed{\frac{64}{400}}"}] | ||
| ] | ||
| solutions = [r"\frac{63}{400}", r"\frac{63}{400}"] | ||
|
|
||
| rewards = accuracy_reward(completions, solutions) | ||
| self.assertEqual(len(rewards), 2) | ||
| self.assertEqual(rewards[0], 1.0) | ||
| self.assertEqual(rewards[1], 0.0) | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| unittest.main() | ||
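If useful, the suite can also be driven programmatically with the stdlib runner; a sketch assuming the file is saved as `tests/test_rewards.py` (the directory name is an assumption, suggested by the imports above):

```python
# Discover and run the reward tests (equivalent to running the file directly).
import unittest

suite = unittest.defaultTestLoader.discover("tests", pattern="test_rewards.py")
unittest.TextTestRunner(verbosity=2).run(suite)
```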
Review comment: I've left the default the same.