63 changes: 6 additions & 57 deletions src/open_r1/grpo.py
@@ -14,7 +14,6 @@

import logging
import os
import re
import sys
from dataclasses import dataclass, field

@@ -25,9 +24,8 @@
from transformers import set_seed
from transformers.trainer_utils import get_last_checkpoint

from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from open_r1.configs import GRPOConfig
from open_r1.rewards import REWARD_FUNCS_REGISTRY
from open_r1.utils.callbacks import get_callbacks
from trl import GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config

@@ -42,66 +40,17 @@ class GRPOScriptArguments(ScriptArguments):

Args:
reward_funcs (`list[str]`):
List of reward functions. Possible values: 'accuracy', 'format'.
List of reward functions. Possible values are dynamically populated from REWARD_FUNCS_REGISTRY.
"""

reward_funcs: list[str] = field(
default_factory=lambda: ["accuracy", "format"],
Contributor Author

I've left the default the same.

metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
metadata={
"help": f"List of reward functions. Possible values: {', '.join(REWARD_FUNCS_REGISTRY.keys())}"
},
)
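A minimal sketch of how names supplied for `reward_funcs` resolve against `REWARD_FUNCS_REGISTRY`; the `resolve_reward_funcs` helper and the CLI flag shape in the comment are assumptions for illustration, not part of this PR:

```python
from open_r1.rewards import REWARD_FUNCS_REGISTRY


def resolve_reward_funcs(names):
    """Hypothetical helper: map reward-function names to callables, failing fast on unknown names."""
    unknown = [name for name in names if name not in REWARD_FUNCS_REGISTRY]
    if unknown:
        raise ValueError(
            f"Unknown reward functions {unknown}; available: {sorted(REWARD_FUNCS_REGISTRY)}"
        )
    return [REWARD_FUNCS_REGISTRY[name] for name in names]


# e.g. what a flag like `--reward_funcs accuracy format reasoning_steps` would resolve to
reward_funcs = resolve_reward_funcs(["accuracy", "format", "reasoning_steps"])
```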


def accuracy_reward(completions, solution, **kwargs):
"""Reward function that checks if the completion is the same as the ground truth."""
contents = [completion[0]["content"] for completion in completions]
rewards = []
for content, sol in zip(contents, solution):
gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
if len(gold_parsed) != 0:
# We require the answer to be provided in correct latex (no malformed operators)
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed="all",
units=True,
),
# Ensures that boxed is tried first
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
# Reward 1 if the content is the same as the ground truth, 0 otherwise
reward = float(verify(answer_parsed, gold_parsed))
else:
# If the gold solution is not parseable, we reward 1 to skip this example
reward = 1.0
print("Failed to parse gold solution: ", sol)
rewards.append(reward)

return rewards


def format_reward(completions, **kwargs):
"""Reward function that checks if the completion has a specific format."""
pattern = r"^<think>.*?</think><answer>.*?</answer>$"
completion_contents = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, content) for content in completion_contents]
return [1.0 if match else 0.0 for match in matches]


reward_funcs_registry = {
"accuracy": accuracy_reward,
"format": format_reward,
}

SYSTEM_PROMPT = (
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
@@ -149,7 +98,7 @@ def main(script_args, training_args, model_args):
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

# Get reward functions
reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]

# Format into conversation
def make_conversation(example):
80 changes: 80 additions & 0 deletions src/open_r1/rewards.py
@@ -0,0 +1,80 @@
"""Reward functions for GRPO training."""

import re

from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify


def accuracy_reward(completions, solution, **kwargs):
"""Reward function that checks if the completion is the same as the ground truth."""
contents = [completion[0]["content"] for completion in completions]
rewards = []
for content, sol in zip(contents, solution):
gold_parsed = parse(
sol,
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
if len(gold_parsed) != 0:
# We require the answer to be provided in correct latex (no malformed operators)
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed="all",
units=True,
),
# Ensures that boxed is tried first
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
# Reward 1 if the content is the same as the ground truth, 0 otherwise
reward = float(verify(answer_parsed, gold_parsed))
else:
# If the gold solution is not parseable, we reward 1 to skip this example
reward = 1.0
Contributor

I don't follow: why do we give the same reward in the "non parseable" case and the correct case?

Contributor Author (@zeenolife), Feb 1, 2025

The function was taken as is: link

But, I speculate that:

  1. It's rare that the ground truth dataset is not parseable.
  2. The downstream GRPO trainer currently doesn't support ignoring the samples.

Contributor

shouldn't we give it 0 weight instead then?

Member

It's completely equivalent, the reward is normalized per group.

Contributor Author

Yes, I was just thinking of it.

I guess it's guaranteed that a group contains only completions for that particular unparseable solution, right? Because if the group contained others, the normalization wouldn't work properly. Does that make sense?

Contributor Author

Nevertheless, I can come back to it in the next PR

Contributor (@ctjlewis), Feb 2, 2025

I'd just made an issue about this (#159) and was reading this thread to see if it was related. A rare exception from Math-Verify stopped my training run; it needs to be caught and skipped as well (#158).

I thought it was odd because the rewards were like 0.19, 0.20, etc., so rewarding 1 for a pair where we couldn't find the gold solution, or we got an error from Math-Verify, felt off. But if we're sure it works out, I'll close those.

Contributor (@ctjlewis), Feb 4, 2025

> shouldn't we give it 0 weight instead then?

We wouldn't want to give it 0 weight either, since that could be a penalty or even a net reward depending on the other rewards in the group (0 might be highest if they all have negative reward). If we don't know, we don't know; we can't put information there that isn't there, we'd just be diluting the signal.

We would actually want to omit them. It will take upstream changes to trl's GRPOTrainer but this is necessary IMO. I've looked at it a few times and I'm confident it's not right. We should be able to return None from a reward_func and we can just reshape the tensors to remove those examples.

Member (@qgallouedec), Feb 5, 2025

> We wouldn't want to give it 0 weight either since that could be a penalty or even a net reward depending on the other rewards in the group

That's not right. Gold is shared in the group. If it's not parsable for one sample, then it's not parsable for the whole group.

Contributor

Oh OK, I see now.
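To make the "completely equivalent" point concrete, here is a minimal sketch assuming GRPO's group-wise normalization of rewards into advantages (reward minus group mean, divided by group std); the helper below is illustrative, not from this PR. A reward that is constant across a group yields zero advantage whether that constant is 1.0 or 0.0, so the unparseable-gold group is effectively skipped.

```python
import numpy as np


def group_advantages(rewards, eps=1e-8):
    """Group-normalized advantages as in GRPO: (r - mean(group)) / (std(group) + eps)."""
    rewards = np.asarray(rewards, dtype=float)
    return (rewards - rewards.mean()) / (rewards.std() + eps)


# The gold solution is shared by every completion in a group, so an unparseable
# gold gives the whole group the same constant reward. The resulting advantages
# are all zero regardless of whether that constant is 1.0 or 0.0.
print(group_advantages([1.0, 1.0, 1.0, 1.0]))  # [0. 0. 0. 0.]
print(group_advantages([0.0, 0.0, 0.0, 0.0]))  # [0. 0. 0. 0.]
```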

print("Failed to parse gold solution: ", sol)
rewards.append(reward)

return rewards


def format_reward(completions, **kwargs):
"""Reward function that checks if the completion has a specific format."""
pattern = r"^<think>.*?</think><answer>.*?</answer>$"
completion_contents = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, content) for content in completion_contents]
return [1.0 if match else 0.0 for match in matches]


def reasoning_steps_reward(completions, **kwargs):
"""Reward function that checks for clear step-by-step reasoning.
Regex pattern:
Step \d+: - matches "Step 1:", "Step 2:", etc.
^\d+\. - matches numbered lists like "1.", "2.", etc. at start of line
\n- - matches bullet points with hyphens
\n\* - matches bullet points with asterisks
First,|Second,|Next,|Finally, - matches transition words
"""
pattern = r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)"
completion_contents = [completion[0]["content"] for completion in completions]
matches = [len(re.findall(pattern, content)) for content in completion_contents]

    # Magic number 3 to encourage 3 or more steps; otherwise give partial reward
return [min(1.0, count / 3) for count in matches]


REWARD_FUNCS_REGISTRY = {
"accuracy": accuracy_reward,
"format": format_reward,
"reasoning_steps": reasoning_steps_reward,
}
Contributor Author

Not exactly a fan of this "registry", but the list of reward functions is small, so I guess it works for now ¯\_(ツ)_/¯
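Since the registry is a plain dict from name to callable, extending it stays cheap. A minimal sketch of adding a new reward under the same (completions, **kwargs) convention; the length_reward function and the "length" key are purely illustrative, not part of this PR:

```python
from open_r1.rewards import REWARD_FUNCS_REGISTRY


def length_reward(completions, **kwargs):
    """Illustrative only: mildly favor shorter completions."""
    contents = [completion[0]["content"] for completion in completions]
    return [1.0 / (1.0 + len(content) / 1000) for content in contents]


# Registering it under a new key makes it resolvable by name like the built-ins.
REWARD_FUNCS_REGISTRY["length"] = length_reward
```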

Empty file added tests/__init__.py
Empty file.
91 changes: 91 additions & 0 deletions tests/test_rewards.py
@@ -0,0 +1,91 @@
import unittest
from open_r1.rewards import accuracy_reward, format_reward, reasoning_steps_reward


class TestRewards(unittest.TestCase):
def test_accuracy_reward_correct_answer(self):
"""Test accuracy_reward with a correct answer."""
completion = [[{"content": r"\boxed{\frac{63}{400}}"}]]
Shouldn't the test be without the \boxed? Nowhere in the system prompt do we tell the model to use \boxed, so even if it gets the answer right with, say, <answer>\frac{63}{400}</answer>, shouldn't that be counted as correct?

Contributor Author

Valid point; it's passed in as a system prompt argument.

solution = [r"\frac{63}{400}"]

rewards = accuracy_reward(completion, solution)
self.assertEqual(rewards[0], 1.0)

def test_accuracy_reward_wrong_answer(self):
"""Test accuracy_reward with an incorrect answer."""
completion = [[{"content": r"\boxed{\frac{64}{400}}"}]]
solution = [r"\frac{63}{400}"]

rewards = accuracy_reward(completion, solution)
self.assertEqual(rewards[0], 0.0)

def test_format_reward_correct(self):
"""Test format_reward with correct format."""
completion = [[{"content": "<think>Some reasoning</think><answer>The answer</answer>"}]]
rewards = format_reward(completion)
self.assertEqual(rewards[0], 1.0)

def test_format_reward_incorrect(self):
"""Test format_reward with incorrect format."""
incorrect_formats = [
"<think>Only thinking</think>",
"<answer>Only answer</answer>",
"No tags at all",
"<think>Missing closing</think><answer>Missing closing",
"<think>Wrong order</answer><answer>Wrong order</think>"
]

for fmt in incorrect_formats:
completion = [[{"content": fmt}]]
rewards = format_reward(completion)
self.assertEqual(rewards[0], 0.0)

def test_reasoning_steps_reward(self):
"""Test reasoning_steps_reward with various formats."""
test_cases = [
# Full credit cases (3 or more steps)
(
"Step 1: First step\nStep 2: Second step\nStep 3: Third step",
1.0
),
(
"First, we do this.\nSecond, we do that.\nFinally, we conclude.",
1.0
),
# Partial credit cases (less than 3 steps)
(
"Step 1: Only step",
1/3
),
(
"First, we do this.\nFinally, we conclude.",
2/3
),
# No credit case
(
"Just plain text without any clear steps",
0.0
)
]

for content, expected_reward in test_cases:
completion = [[{"content": content}]]
rewards = reasoning_steps_reward(completion)
self.assertAlmostEqual(rewards[0], expected_reward)

def test_multiple_completions(self):
"""Test handling multiple completions at once."""
completions = [
[{"content": r"\boxed{\frac{63}{400}}"}],
[{"content": r"\boxed{\frac{64}{400}}"}]
]
solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]

rewards = accuracy_reward(completions, solutions)
self.assertEqual(len(rewards), 2)
self.assertEqual(rewards[0], 1.0)
self.assertEqual(rewards[1], 0.0)


if __name__ == '__main__':
unittest.main()