
Commit fa78e5d

[GRPO] add cosine reward (huggingface#206)
* add cosine reward
* fix merge
* fix typo
* fix check
1 parent 4e42ee9 commit fa78e5d

5 files changed: +170 -43 lines changed


Makefile

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
 export PYTHONPATH = src
 
-check_dirs := src
+check_dirs := src tests
 
 style:
 	ruff format --line-length 119 --target-version py310 $(check_dirs) setup.py

setup.py

Lines changed: 4 additions & 0 deletions
@@ -52,6 +52,8 @@
         "hf_transfer>=0.1.4",
         "huggingface-hub[cli]>=0.19.2,<1.0",
         "isort>=5.12.0",
+        "latex2sympy2_extended>=1.0.6",
+        "math-verify>=0.5.2",
         "liger_kernel==0.5.2",
         "lighteval @ git+https://github.com/huggingface/lighteval.git@86f62259f105ae164f655e0b91c92a823a742724#egg=lighteval[math]",
         "math-verify==0.5.2", # Used for math verification in grpo
@@ -96,6 +98,8 @@ def deps_list(*pkgs):
         deps["deepspeed"],
         deps["hf_transfer"],
         deps["huggingface-hub"],
+        deps["latex2sympy2_extended"],
+        deps["math-verify"],
         deps["liger_kernel"],
         deps["packaging"],  # utilities from PyPA to e.g., compare versions
         deps["safetensors"],

src/open_r1/grpo.py

Lines changed: 48 additions & 4 deletions
@@ -16,6 +16,7 @@
 import os
 import sys
 from dataclasses import dataclass, field
+from functools import partial
 
 import datasets
 import torch
@@ -25,7 +26,7 @@
 from transformers.trainer_utils import get_last_checkpoint
 
 from open_r1.configs import GRPOConfig
-from open_r1.rewards import REWARD_FUNCS_REGISTRY
+from open_r1.rewards import accuracy_reward, cosine_scaled_reward, format_reward, reasoning_steps_reward
 from open_r1.utils.callbacks import get_callbacks
 from trl import GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
 
@@ -40,15 +41,45 @@ class GRPOScriptArguments(ScriptArguments):
 
     Args:
         reward_funcs (`list[str]`):
-            List of reward functions. Possible values are dynamically populated from REWARD_FUNCS_REGISTRY.
+            List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine'.
+        cosine_min_value_wrong (`float`):
+            Minimum reward for cosine scaling for wrong answers.
+        cosine_max_value_wrong (`float`):
+            Maximum reward for cosine scaling for wrong answers.
+        cosine_min_value_correct (`float`):
+            Minimum reward for cosine scaling for correct answers.
+        cosine_max_value_correct (`float`):
+            Maximum reward for cosine scaling for correct answers.
+        cosine_max_len (`int`):
+            Maximum length for cosine scaling.
     """
 
     reward_funcs: list[str] = field(
-        default_factory=lambda: ["accuracy", "format"],
+        default_factory=lambda: ["accuracy", "format", "reasoning_steps", "cosine"],
         metadata={
-            "help": f"List of reward functions. Possible values: {', '.join(REWARD_FUNCS_REGISTRY.keys())}"
+            "help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine'"
         },
     )
+    cosine_min_value_wrong: float = field(
+        default=0.0,
+        metadata={"help": "Minimum reward for wrong answers"},
+    )
+    cosine_max_value_wrong: float = field(
+        default=-0.5,
+        metadata={"help": "Maximum reward for wrong answers"},
+    )
+    cosine_min_value_correct: float = field(
+        default=0.5,
+        metadata={"help": "Minimum reward for correct answers"},
+    )
+    cosine_max_value_correct: float = field(
+        default=1.0,
+        metadata={"help": "Maximum reward for correct answers"},
+    )
+    cosine_max_len: int = field(
+        default=1000,
+        metadata={"help": "Maximum length for scaling"},
+    )
 
 
 SYSTEM_PROMPT = (
@@ -98,6 +129,19 @@ def main(script_args, training_args, model_args):
     dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
 
     # Get reward functions
+    REWARD_FUNCS_REGISTRY = {
+        "accuracy": accuracy_reward,
+        "format": format_reward,
+        "reasoning_steps": reasoning_steps_reward,
+        "cosine": partial(
+            cosine_scaled_reward,
+            min_value_wrong=script_args.cosine_min_value_wrong,
+            max_value_wrong=script_args.cosine_max_value_wrong,
+            min_value_correct=script_args.cosine_min_value_correct,
+            max_value_correct=script_args.cosine_max_value_correct,
+            max_len=script_args.cosine_max_len,
+        ),
+    }
     reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
 
     # Format into conversation
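
Building REWARD_FUNCS_REGISTRY inside main() lets the cosine entry capture the new script arguments through functools.partial, so the scaling parameters are fixed once up front and the trainer only ever passes batch data to the reward. A minimal, self-contained sketch of that binding pattern (toy_reward and the dummy settings are illustrative stand-ins, not repo code):

# Hedged sketch of the partial-binding pattern: configuration is bound ahead of time,
# and the returned callable keeps the (completions, solution, **kwargs) shape.
from functools import partial

def toy_reward(completions, solution, *, max_len=1000, **kwargs):
    # Illustrative reward: longer completions score higher, capped at 1.0.
    return [min(len(c[0]["content"]) / max_len, 1.0) for c in completions]

bound_reward = partial(toy_reward, max_len=100)  # analogous to the "cosine" registry entry

completions = [[{"content": "short answer"}]]
print(bound_reward(completions, solution=["42"]))  # [0.12]; max_len is already applied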

src/open_r1/rewards.py

Lines changed: 73 additions & 5 deletions
@@ -1,5 +1,6 @@
 """Reward functions for GRPO training."""
 
+import math
 import re
 
 from latex2sympy2_extended import NormalizationConfig
@@ -73,8 +74,75 @@ def reasoning_steps_reward(completions, **kwargs):
     return [min(1.0, count / 3) for count in matches]
 
 
-REWARD_FUNCS_REGISTRY = {
-    "accuracy": accuracy_reward,
-    "format": format_reward,
-    "reasoning_steps": reasoning_steps_reward,
-}
+def cosine_scaled_reward(
+    completions,
+    solution,
+    min_value_wrong: float = -1.0,
+    max_value_wrong: float = -0.5,
+    min_value_correct: float = 0.5,
+    max_value_correct: float = 1.0,
+    max_len: int = 1000,
+    **kwargs,
+):
+    """Reward function that scales based on completion length using a cosine schedule.
+
+    Shorter correct solutions are rewarded more than longer ones.
+    Longer incorrect solutions are penalized less than shorter ones.
+
+    Args:
+        completions: List of model completions
+        solution: List of ground truth solutions
+        min_value_wrong: Minimum reward for wrong answers
+        max_value_wrong: Maximum reward for wrong answers
+        min_value_correct: Minimum reward for correct answers
+        max_value_correct: Maximum reward for correct answers
+        max_len: Maximum length for scaling
+    """
+    contents = [completion[0]["content"] for completion in completions]
+    rewards = []
+
+    for content, sol in zip(contents, solution):
+        gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
+        if len(gold_parsed) == 0:
+            rewards.append(1.0)  # Skip unparseable examples
+            print("Failed to parse gold solution: ", sol)
+            continue
+
+        answer_parsed = parse(
+            content,
+            extraction_config=[
+                LatexExtractionConfig(
+                    normalization_config=NormalizationConfig(
+                        nits=False,
+                        malformed_operators=False,
+                        basic_latex=True,
+                        equations=True,
+                        boxed=True,
+                        units=True,
+                    ),
+                    boxed_match_priority=0,
+                    try_extract_without_anchor=False,
+                )
+            ],
+            extraction_mode="first_match",
+        )
+
+        is_correct = verify(answer_parsed, gold_parsed)
+        gen_len = len(content)
+
+        # Apply cosine scaling based on length
+        progress = gen_len / max_len
+        cosine = math.cos(progress * math.pi)
+
+        if is_correct:
+            min_value = min_value_correct
+            max_value = max_value_correct
+        else:
+            # Swap min/max for incorrect answers
+            min_value = max_value_wrong
+            max_value = min_value_wrong
+
+        reward = min_value + 0.5 * (max_value - min_value) * (1.0 + cosine)
+        rewards.append(float(reward))
+
+    return rewards
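
The schedule itself is easy to sanity-check by hand: reward = min_value + 0.5 * (max_value - min_value) * (1 + cos(gen_len / max_len * pi)), with the wrong-answer bounds swapped so longer incorrect completions lose less. Below is a standalone re-computation using the ranges from the test further down (the cosine_schedule helper is only for illustration); note that r"\boxed{\frac{63}{400}}" is already 22 characters, so the test case labelled as length 20 effectively runs at length 22, which is what the expected values correspond to.

# Hedged re-computation of the cosine schedule, without any parsing or verification.
import math

def cosine_schedule(gen_len, is_correct, max_len=100):
    # Correct answers map into [0.5, 1.0]; wrong answers use the swapped bounds (-0.5 down to -1.0).
    min_value, max_value = (0.5, 1.0) if is_correct else (-0.5, -1.0)
    cosine = math.cos(gen_len / max_len * math.pi)
    return min_value + 0.5 * (max_value - min_value) * (1.0 + cosine)

for gen_len, is_correct in [(22, True), (80, True), (22, False), (80, False)]:
    print(gen_len, is_correct, round(cosine_schedule(gen_len, is_correct), 3))
# 22 True 0.943    short correct answer, near the maximum reward
# 80 True 0.548    long correct answer, near the minimum reward
# 22 False -0.943  short wrong answer, penalized hardest
# 80 False -0.548  long wrong answer, penalized least

These values agree, to two decimal places, with the expectations asserted in tests/test_rewards.py.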

tests/test_rewards.py

Lines changed: 44 additions & 33 deletions
@@ -1,21 +1,22 @@
 import unittest
-from open_r1.rewards import accuracy_reward, format_reward, reasoning_steps_reward
+
+from open_r1.rewards import accuracy_reward, cosine_scaled_reward, format_reward, reasoning_steps_reward
 
 
 class TestRewards(unittest.TestCase):
     def test_accuracy_reward_correct_answer(self):
         """Test accuracy_reward with a correct answer."""
         completion = [[{"content": r"\boxed{\frac{63}{400}}"}]]
         solution = [r"\frac{63}{400}"]
-
+
         rewards = accuracy_reward(completion, solution)
         self.assertEqual(rewards[0], 1.0)
 
     def test_accuracy_reward_wrong_answer(self):
         """Test accuracy_reward with an incorrect answer."""
         completion = [[{"content": r"\boxed{\frac{64}{400}}"}]]
         solution = [r"\frac{63}{400}"]
-
+
         rewards = accuracy_reward(completion, solution)
         self.assertEqual(rewards[0], 0.0)
 
@@ -32,9 +33,9 @@ def test_format_reward_incorrect(self):
             "<answer>Only answer</answer>",
             "No tags at all",
             "<think>Missing closing</think><answer>Missing closing",
-            "<think>Wrong order</answer><answer>Wrong order</think>"
+            "<think>Wrong order</answer><answer>Wrong order</think>",
         ]
-
+
         for fmt in incorrect_formats:
             completion = [[{"content": fmt}]]
             rewards = format_reward(completion)
@@ -44,48 +45,58 @@ def test_reasoning_steps_reward(self):
         """Test reasoning_steps_reward with various formats."""
         test_cases = [
             # Full credit cases (3 or more steps)
-            (
-                "Step 1: First step\nStep 2: Second step\nStep 3: Third step",
-                1.0
-            ),
-            (
-                "First, we do this.\nSecond, we do that.\nFinally, we conclude.",
-                1.0
-            ),
+            ("Step 1: First step\nStep 2: Second step\nStep 3: Third step", 1.0),
+            ("First, we do this.\nSecond, we do that.\nFinally, we conclude.", 1.0),
             # Partial credit cases (less than 3 steps)
-            (
-                "Step 1: Only step",
-                1/3
-            ),
-            (
-                "First, we do this.\nFinally, we conclude.",
-                2/3
-            ),
+            ("Step 1: Only step", 1 / 3),
+            ("First, we do this.\nFinally, we conclude.", 2 / 3),
             # No credit case
-            (
-                "Just plain text without any clear steps",
-                0.0
-            )
+            ("Just plain text without any clear steps", 0.0),
         ]
-
+
         for content, expected_reward in test_cases:
             completion = [[{"content": content}]]
             rewards = reasoning_steps_reward(completion)
             self.assertAlmostEqual(rewards[0], expected_reward)
 
     def test_multiple_completions(self):
         """Test handling multiple completions at once."""
-        completions = [
-            [{"content": r"\boxed{\frac{63}{400}}"}],
-            [{"content": r"\boxed{\frac{64}{400}}"}]
-        ]
+        completions = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{64}{400}}"}]]
         solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
-
+
         rewards = accuracy_reward(completions, solutions)
         self.assertEqual(len(rewards), 2)
         self.assertEqual(rewards[0], 1.0)
         self.assertEqual(rewards[1], 0.0)
 
+    def test_cosine_scaled_reward(self):
+        """Test cosine_scaled_reward with various cases."""
+        # Test parameters
+        test_params = {
+            "min_value_wrong": -1.0,
+            "max_value_wrong": -0.5,
+            "min_value_correct": 0.5,
+            "max_value_correct": 1.0,
+            "max_len": 100,
+        }
+
+        test_cases = [
+            # Correct answers with different lengths
+            (r"\boxed{\frac{63}{400}}", r"\frac{63}{400}", 20, 0.943),  # Short correct answer
+            (r"\boxed{\frac{63}{400}}", r"\frac{63}{400}", 80, 0.547),  # Long correct answer
+            # Wrong answers with different lengths
+            (r"\boxed{\frac{64}{400}}", r"\frac{63}{400}", 20, -0.942),  # Short wrong answer
+            (r"\boxed{\frac{64}{400}}", r"\frac{63}{400}", 80, -0.547),  # Long wrong answer
+        ]
+
+        for content, solution, content_len, expected_reward in test_cases:
+            # Pad content to desired length
+            padded_content = content + " " * (content_len - len(content))
+            completion = [[{"content": padded_content}]]
+
+            rewards = cosine_scaled_reward(completion, [solution], **test_params)
+            self.assertAlmostEqual(rewards[0], expected_reward, places=2)
+
 
-if __name__ == '__main__':
-    unittest.main()
+if __name__ == "__main__":
+    unittest.main()
