11import unittest
2- from open_r1 .rewards import accuracy_reward , format_reward , reasoning_steps_reward
2+
3+ from open_r1 .rewards import accuracy_reward , cosine_scaled_reward , format_reward , reasoning_steps_reward
34
45
56class TestRewards (unittest .TestCase ):
67 def test_accuracy_reward_correct_answer (self ):
78 """Test accuracy_reward with a correct answer."""
89 completion = [[{"content" : r"\boxed{\frac{63}{400}}" }]]
910 solution = [r"\frac{63}{400}" ]
10-
11+
1112 rewards = accuracy_reward (completion , solution )
1213 self .assertEqual (rewards [0 ], 1.0 )
1314
1415 def test_accuracy_reward_wrong_answer (self ):
1516 """Test accuracy_reward with an incorrect answer."""
1617 completion = [[{"content" : r"\boxed{\frac{64}{400}}" }]]
1718 solution = [r"\frac{63}{400}" ]
18-
19+
1920 rewards = accuracy_reward (completion , solution )
2021 self .assertEqual (rewards [0 ], 0.0 )
2122
@@ -32,9 +33,9 @@ def test_format_reward_incorrect(self):
3233 "<answer>Only answer</answer>" ,
3334 "No tags at all" ,
3435 "<think>Missing closing</think><answer>Missing closing" ,
35- "<think>Wrong order</answer><answer>Wrong order</think>"
36+ "<think>Wrong order</answer><answer>Wrong order</think>" ,
3637 ]
37-
38+
3839 for fmt in incorrect_formats :
3940 completion = [[{"content" : fmt }]]
4041 rewards = format_reward (completion )
@@ -44,48 +45,58 @@ def test_reasoning_steps_reward(self):
4445 """Test reasoning_steps_reward with various formats."""
4546 test_cases = [
4647 # Full credit cases (3 or more steps)
47- (
48- "Step 1: First step\n Step 2: Second step\n Step 3: Third step" ,
49- 1.0
50- ),
51- (
52- "First, we do this.\n Second, we do that.\n Finally, we conclude." ,
53- 1.0
54- ),
48+ ("Step 1: First step\n Step 2: Second step\n Step 3: Third step" , 1.0 ),
49+ ("First, we do this.\n Second, we do that.\n Finally, we conclude." , 1.0 ),
5550 # Partial credit cases (less than 3 steps)
56- (
57- "Step 1: Only step" ,
58- 1 / 3
59- ),
60- (
61- "First, we do this.\n Finally, we conclude." ,
62- 2 / 3
63- ),
51+ ("Step 1: Only step" , 1 / 3 ),
52+ ("First, we do this.\n Finally, we conclude." , 2 / 3 ),
6453 # No credit case
65- (
66- "Just plain text without any clear steps" ,
67- 0.0
68- )
54+ ("Just plain text without any clear steps" , 0.0 ),
6955 ]
70-
56+
7157 for content , expected_reward in test_cases :
7258 completion = [[{"content" : content }]]
7359 rewards = reasoning_steps_reward (completion )
7460 self .assertAlmostEqual (rewards [0 ], expected_reward )
7561
7662 def test_multiple_completions (self ):
7763 """Test handling multiple completions at once."""
78- completions = [
79- [{"content" : r"\boxed{\frac{63}{400}}" }],
80- [{"content" : r"\boxed{\frac{64}{400}}" }]
81- ]
64+ completions = [[{"content" : r"\boxed{\frac{63}{400}}" }], [{"content" : r"\boxed{\frac{64}{400}}" }]]
8265 solutions = [r"\frac{63}{400}" , r"\frac{63}{400}" ]
83-
66+
8467 rewards = accuracy_reward (completions , solutions )
8568 self .assertEqual (len (rewards ), 2 )
8669 self .assertEqual (rewards [0 ], 1.0 )
8770 self .assertEqual (rewards [1 ], 0.0 )
8871
72+ def test_cosine_scaled_reward (self ):
73+ """Test cosine_scaled_reward with various cases."""
74+ # Test parameters
75+ test_params = {
76+ "min_value_wrong" : - 1.0 ,
77+ "max_value_wrong" : - 0.5 ,
78+ "min_value_correct" : 0.5 ,
79+ "max_value_correct" : 1.0 ,
80+ "max_len" : 100 ,
81+ }
82+
83+ test_cases = [
84+ # Correct answers with different lengths
85+ (r"\boxed{\frac{63}{400}}" , r"\frac{63}{400}" , 20 , 0.943 ), # Short correct answer
86+ (r"\boxed{\frac{63}{400}}" , r"\frac{63}{400}" , 80 , 0.547 ), # Long correct answer
87+ # Wrong answers with different lengths
88+ (r"\boxed{\frac{64}{400}}" , r"\frac{63}{400}" , 20 , - 0.942 ), # Short wrong answer
89+ (r"\boxed{\frac{64}{400}}" , r"\frac{63}{400}" , 80 , - 0.547 ), # Long wrong answer
90+ ]
91+
92+ for content , solution , content_len , expected_reward in test_cases :
93+ # Pad content to desired length
94+ padded_content = content + " " * (content_len - len (content ))
95+ completion = [[{"content" : padded_content }]]
96+
97+ rewards = cosine_scaled_reward (completion , [solution ], ** test_params )
98+ self .assertAlmostEqual (rewards [0 ], expected_reward , places = 2 )
99+
89100
90- if __name__ == ' __main__' :
91- unittest .main ()
101+ if __name__ == " __main__" :
102+ unittest .main ()
0 commit comments