2 changes: 1 addition & 1 deletion Makefile
@@ -3,7 +3,7 @@
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
export PYTHONPATH = src

check_dirs := src
check_dirs := src tests

style:
ruff format --line-length 119 --target-version py310 $(check_dirs) setup.py
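Note on the hunk above: with tests added to check_dirs, any Makefile target that iterates over $(check_dirs), such as the ruff format invocation in the style target shown here, should now cover the test suite in addition to src and setup.py.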
4 changes: 4 additions & 0 deletions setup.py
@@ -52,6 +52,8 @@
"hf_transfer>=0.1.4",
"huggingface-hub[cli]>=0.19.2,<1.0",
"isort>=5.12.0",
"latex2sympy2_extended>=1.0.6",
"math-verify>=0.5.2",
"liger_kernel==0.5.2",
"lighteval @ git+https://github.com/huggingface/lighteval.git@86f62259f105ae164f655e0b91c92a823a742724#egg=lighteval[math]",
"math-verify==0.5.2", # Used for math verification in grpo
@@ -96,6 +98,8 @@ def deps_list(*pkgs):
deps["deepspeed"],
deps["hf_transfer"],
deps["huggingface-hub"],
deps["latex2sympy2_extended"],
deps["math-verify"],
deps["liger_kernel"],
deps["packaging"], # utilities from PyPA to e.g., compare versions
deps["safetensors"],
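For context, here is a minimal sketch of how the two newly declared dependencies are exercised by the reward code later in this diff. The import path math_verify and the exact parse/verify behaviour are assumptions on my part rather than part of the diff; the extraction configuration mirrors the one used in cosine_scaled_reward below.

from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify  # assumed import path for math-verify

# Gold answers are parsed with a plain LatexExtractionConfig, as in the reward functions.
gold = parse(r"\frac{63}{400}", extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])

# Model completions use the stricter configuration shown in this PR.
pred = parse(
    r"\boxed{\frac{63}{400}}",
    extraction_config=[
        LatexExtractionConfig(
            normalization_config=NormalizationConfig(
                nits=False,
                malformed_operators=False,
                basic_latex=True,
                equations=True,
                boxed=True,
                units=True,
            ),
            boxed_match_priority=0,
            try_extract_without_anchor=False,
        )
    ],
    extraction_mode="first_match",
)

print(verify(pred, gold))  # expected to print True for this pair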
52 changes: 48 additions & 4 deletions src/open_r1/grpo.py
@@ -16,6 +16,7 @@
import os
import sys
from dataclasses import dataclass, field
from functools import partial

import datasets
import torch
@@ -25,7 +26,7 @@
from transformers.trainer_utils import get_last_checkpoint

from open_r1.configs import GRPOConfig
from open_r1.rewards import REWARD_FUNCS_REGISTRY
from open_r1.rewards import accuracy_reward, cosine_scaled_reward, format_reward, reasoning_steps_reward
from open_r1.utils.callbacks import get_callbacks
from trl import GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config

@@ -40,15 +41,45 @@ class GRPOScriptArguments(ScriptArguments):

Args:
reward_funcs (`list[str]`):
List of reward functions. Possible values are dynamically populated from REWARD_FUNCS_REGISTRY.
List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine'.
cosine_min_value_wrong (`float`):
Minimum reward for cosine scaling for wrong answers.
cosine_max_value_wrong (`float`):
Maximum reward for cosine scaling for wrong answers.
cosine_min_value_correct (`float`):
Minimum reward for cosine scaling for correct answers.
cosine_max_value_correct (`float`):
Maximum reward for cosine scaling for correct answers.
cosine_max_len (`int`):
Maximum length for cosine scaling.
"""

reward_funcs: list[str] = field(
default_factory=lambda: ["accuracy", "format"],
default_factory=lambda: ["accuracy", "format", "reasoning_steps", "cosine"],
metadata={
"help": f"List of reward functions. Possible values: {', '.join(REWARD_FUNCS_REGISTRY.keys())}"
"help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine'"
},
)
cosine_min_value_wrong: float = field(
default=0.0,
metadata={"help": "Minimum reward for wrong answers"},
)
cosine_max_value_wrong: float = field(
default=-0.5,
metadata={"help": "Maximum reward for wrong answers"},
)
cosine_min_value_correct: float = field(
default=0.5,
metadata={"help": "Minimum reward for correct answers"},
)
cosine_max_value_correct: float = field(
default=1.0,
metadata={"help": "Maximum reward for correct answers"},
)
cosine_max_len: int = field(
default=1000,
metadata={"help": "Maximum length for scaling"},
)


SYSTEM_PROMPT = (
@@ -98,6 +129,19 @@ def main(script_args, training_args, model_args):
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

# Get reward functions
REWARD_FUNCS_REGISTRY = {
"accuracy": accuracy_reward,
"format": format_reward,
"reasoning_steps": reasoning_steps_reward,
"cosine": partial(
cosine_scaled_reward,
min_value_wrong=script_args.cosine_min_value_wrong,
max_value_wrong=script_args.cosine_max_value_wrong,
min_value_correct=script_args.cosine_min_value_correct,
max_value_correct=script_args.cosine_max_value_correct,
max_len=script_args.cosine_max_len,
),
}
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]

# Format into conversation
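Design note on the change above: building REWARD_FUNCS_REGISTRY inside main() rather than importing it from rewards.py lets the cosine entry be specialised with the cosine_* script arguments through functools.partial, while keeping the same call shape as the other reward functions. A minimal sketch of that idea, using hypothetical hyperparameter values in place of the real script arguments:

from functools import partial

from open_r1.rewards import accuracy_reward, cosine_scaled_reward

# Hypothetical values standing in for script_args.cosine_*.
cosine_fn = partial(
    cosine_scaled_reward,
    min_value_wrong=-1.0,
    max_value_wrong=-0.5,
    min_value_correct=0.5,
    max_value_correct=1.0,
    max_len=2000,
)

completions = [[{"content": r"\boxed{\frac{63}{400}}"}]]
solution = [r"\frac{63}{400}"]

# Both callables can now be invoked identically, e.g. by GRPOTrainer.
print(accuracy_reward(completions, solution))  # [1.0]
print(cosine_fn(completions, solution))        # length-dependent value in [0.5, 1.0]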
78 changes: 73 additions & 5 deletions src/open_r1/rewards.py
@@ -1,5 +1,6 @@
"""Reward functions for GRPO training."""

import math
import re

from latex2sympy2_extended import NormalizationConfig
@@ -73,8 +74,75 @@ def reasoning_steps_reward(completions, **kwargs):
return [min(1.0, count / 3) for count in matches]


REWARD_FUNCS_REGISTRY = {
"accuracy": accuracy_reward,
"format": format_reward,
"reasoning_steps": reasoning_steps_reward,
}
def cosine_scaled_reward(
completions,
solution,
min_value_wrong: float = -1.0,
max_value_wrong: float = -0.5,
min_value_correct: float = 0.5,
max_value_correct: float = 1.0,
max_len: int = 1000,
**kwargs,
):
"""Reward function that scales based on completion length using a cosine schedule.

Shorter correct solutions are rewarded more than longer ones.
Longer incorrect solutions are penalized less than shorter ones.

Args:
completions: List of model completions
solution: List of ground truth solutions
min_value_wrong: Minimum reward for wrong answers
max_value_wrong: Maximum reward for wrong answers
min_value_correct: Minimum reward for correct answers
max_value_correct: Maximum reward for correct answers
max_len: Maximum length for scaling
"""
contents = [completion[0]["content"] for completion in completions]
rewards = []

for content, sol in zip(contents, solution):
gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
if len(gold_parsed) == 0:
rewards.append(1.0) # Skip unparseable examples
print("Failed to parse gold solution: ", sol)
continue

answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed=True,
units=True,
),
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)

is_correct = verify(answer_parsed, gold_parsed)
gen_len = len(content)

# Apply cosine scaling based on length
progress = gen_len / max_len
cosine = math.cos(progress * math.pi)

if is_correct:
min_value = min_value_correct
max_value = max_value_correct
else:
# Swap min/max for incorrect answers
min_value = max_value_wrong
max_value = min_value_wrong

reward = min_value + 0.5 * (max_value - min_value) * (1.0 + cosine)
rewards.append(float(reward))

return rewards
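
To make the schedule concrete, here is a small standalone sketch (not part of the diff) that reproduces the reward formula with the parameters used in tests/test_rewards.py; the printed values agree with the expected rewards in that test to two decimal places.

import math

def scale(gen_len, max_len, min_value, max_value):
    # Same formula as cosine_scaled_reward: the cosine goes from 1 at length 0 to -1 at max_len.
    cosine = math.cos(gen_len / max_len * math.pi)
    return min_value + 0.5 * (max_value - min_value) * (1.0 + cosine)

# Correct answers: 0.5 .. 1.0, shorter completions earn more.
print(round(scale(22, 100, 0.5, 1.0), 3))    # ~0.943 (short correct)
print(round(scale(80, 100, 0.5, 1.0), 3))    # ~0.548 (long correct)

# Wrong answers: min/max swapped, so longer completions are penalised less.
print(round(scale(22, 100, -0.5, -1.0), 3))  # ~-0.943 (short wrong)
print(round(scale(80, 100, -0.5, -1.0), 3))  # ~-0.548 (long wrong)

The lengths 22 and 80 correspond to the actual character counts of the test completions: the \boxed{...} strings are already 22 characters long, so the test's space padding cannot shorten them to the nominal 20.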
77 changes: 44 additions & 33 deletions tests/test_rewards.py
@@ -1,21 +1,22 @@
import unittest
from open_r1.rewards import accuracy_reward, format_reward, reasoning_steps_reward

from open_r1.rewards import accuracy_reward, cosine_scaled_reward, format_reward, reasoning_steps_reward


class TestRewards(unittest.TestCase):
def test_accuracy_reward_correct_answer(self):
"""Test accuracy_reward with a correct answer."""
completion = [[{"content": r"\boxed{\frac{63}{400}}"}]]
solution = [r"\frac{63}{400}"]

rewards = accuracy_reward(completion, solution)
self.assertEqual(rewards[0], 1.0)

def test_accuracy_reward_wrong_answer(self):
"""Test accuracy_reward with an incorrect answer."""
completion = [[{"content": r"\boxed{\frac{64}{400}}"}]]
solution = [r"\frac{63}{400}"]

rewards = accuracy_reward(completion, solution)
self.assertEqual(rewards[0], 0.0)

@@ -32,9 +33,9 @@ def test_format_reward_incorrect(self):
"<answer>Only answer</answer>",
"No tags at all",
"<think>Missing closing</think><answer>Missing closing",
"<think>Wrong order</answer><answer>Wrong order</think>"
"<think>Wrong order</answer><answer>Wrong order</think>",
]

for fmt in incorrect_formats:
completion = [[{"content": fmt}]]
rewards = format_reward(completion)
@@ -44,48 +45,58 @@ def test_reasoning_steps_reward(self):
"""Test reasoning_steps_reward with various formats."""
test_cases = [
# Full credit cases (3 or more steps)
(
"Step 1: First step\nStep 2: Second step\nStep 3: Third step",
1.0
),
(
"First, we do this.\nSecond, we do that.\nFinally, we conclude.",
1.0
),
("Step 1: First step\nStep 2: Second step\nStep 3: Third step", 1.0),
("First, we do this.\nSecond, we do that.\nFinally, we conclude.", 1.0),
# Partial credit cases (less than 3 steps)
(
"Step 1: Only step",
1/3
),
(
"First, we do this.\nFinally, we conclude.",
2/3
),
("Step 1: Only step", 1 / 3),
("First, we do this.\nFinally, we conclude.", 2 / 3),
# No credit case
(
"Just plain text without any clear steps",
0.0
)
("Just plain text without any clear steps", 0.0),
]

for content, expected_reward in test_cases:
completion = [[{"content": content}]]
rewards = reasoning_steps_reward(completion)
self.assertAlmostEqual(rewards[0], expected_reward)

def test_multiple_completions(self):
"""Test handling multiple completions at once."""
completions = [
[{"content": r"\boxed{\frac{63}{400}}"}],
[{"content": r"\boxed{\frac{64}{400}}"}]
]
completions = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{64}{400}}"}]]
solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]

rewards = accuracy_reward(completions, solutions)
self.assertEqual(len(rewards), 2)
self.assertEqual(rewards[0], 1.0)
self.assertEqual(rewards[1], 0.0)

def test_cosine_scaled_reward(self):
"""Test cosine_scaled_reward with various cases."""
# Test parameters
test_params = {
"min_value_wrong": -1.0,
"max_value_wrong": -0.5,
"min_value_correct": 0.5,
"max_value_correct": 1.0,
"max_len": 100,
}

test_cases = [
# Correct answers with different lengths
(r"\boxed{\frac{63}{400}}", r"\frac{63}{400}", 20, 0.943), # Short correct answer
(r"\boxed{\frac{63}{400}}", r"\frac{63}{400}", 80, 0.547), # Long correct answer
# Wrong answers with different lengths
(r"\boxed{\frac{64}{400}}", r"\frac{63}{400}", 20, -0.942), # Short wrong answer
(r"\boxed{\frac{64}{400}}", r"\frac{63}{400}", 80, -0.547), # Long wrong answer
]

for content, solution, content_len, expected_reward in test_cases:
# Pad content to desired length
padded_content = content + " " * (content_len - len(content))
completion = [[{"content": padded_content}]]

rewards = cosine_scaled_reward(completion, [solution], **test_params)
self.assertAlmostEqual(rewards[0], expected_reward, places=2)


if __name__ == '__main__':
unittest.main()
if __name__ == "__main__":
unittest.main()
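
With the package installed in editable mode or PYTHONPATH=src set (as the Makefile does), these tests should be runnable from the repository root with, for example, python -m unittest discover -s tests, or with pytest tests/test_rewards.py assuming pytest is available in the development environment.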