Add optional r1-style thinking reward #551

Merged · 10 commits · Feb 6, 2025
18 changes: 18 additions & 0 deletions open_instruct/dataset_transformation.py
@@ -182,6 +182,24 @@ def visualize_token(tokens: list[int], tokenizer: PreTrainedTokenizer):
"{% endif %}"
"{% endfor %}"
),
"r1_simple_chat_prepend_think": (
"A conversation between User and Assistant. "
"The user asks a question, and the Assistant solves it. "
"The assistant first thinks about the reasoning process in "
"the mind and then provides the user with the answer. "
"The reasoning process and answer are enclosed within <think> </think> "
"and <answer> </answer> tags, respectively, "
"i.e., <think> reasoning process here </think> "
"<answer> answer here </answer>."
"\n\n"
"{% for message in messages %}"
"{{ '\n\n' if not loop.first else '' }}"
"{{ message['role'].capitalize() + ': ' + message['content'] + '\n' }}"
"{% if loop.last and add_generation_prompt %}"
"{{ 'Assistant: <think>' }}"
"{% endif %}"
"{% endfor %}"
),
}
# flake8: noqa

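For reference, below is a minimal sketch (not part of the PR) of what a prompt rendered with the new `r1_simple_chat_prepend_think` template looks like for a one-turn conversation. Rendering directly with `jinja2` and abbreviating the preamble are illustration-only assumptions; in the repo the template string would normally be applied through the tokenizer's chat-template machinery.

```python
# Minimal sketch (not from the PR): render the new template for a one-turn chat.
# The preamble is abbreviated here; the full text is in the diff above.
from jinja2 import Template

PREAMBLE = (
    "A conversation between User and Assistant. ... "
    "<answer> answer here </answer>.\n\n"  # abbreviated
)
LOOP = (
    "{% for message in messages %}"
    "{{ '\n\n' if not loop.first else '' }}"
    "{{ message['role'].capitalize() + ': ' + message['content'] + '\n' }}"
    "{% if loop.last and add_generation_prompt %}"
    "{{ 'Assistant: <think>' }}"
    "{% endif %}"
    "{% endfor %}"
)

prompt = Template(PREAMBLE + LOOP).render(
    messages=[{"role": "user", "content": "What color is the sky?"}],
    add_generation_prompt=True,
)
print(prompt)
# A conversation between User and Assistant. ... <answer> answer here </answer>.
#
# User: What color is the sky?
# Assistant: <think>
```

In practice one would presumably pass the template to the tokenizer (e.g. `tokenizer.apply_chat_template(messages, chat_template=..., add_generation_prompt=True, tokenize=False)`); the direct `jinja2` call above is only to make the rendered prompt visible.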
108 changes: 1 addition & 107 deletions open_instruct/ground_truth_utils.py
@@ -139,7 +139,7 @@ def verify_flan_sample(model_output, ground_truth_answer):

def soft_format_reward_func(responses: list[str]) -> list[float]:
Collaborator:
Would agree with nathan that it feels like reward weight should be a param (and maybe even what pattern you are looking for?), to help with tuning reward stuff in the future.

"""Reward function that checks if the completion has a specific format."""
pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
pattern = r".*?</think>\s*<answer>.*?</answer>"
matches = [re.match(pattern, r, re.DOTALL) for r in responses]
return [0.5 if match else 0.0 for match in matches]
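If the weight (and the pattern) were made configurable as the reviewer suggests, one possible shape is sketched below. This is an assumption about how it could look, not the change merged in this PR; the defaults simply mirror the hard-coded values.

```python
# Hypothetical sketch following the review suggestion above -- not the merged code.
import re


def soft_format_reward_func(
    responses: list[str],
    reward_weight: float = 0.5,
    pattern: str = r".*?</think>\s*<answer>.*?</answer>",
) -> list[float]:
    """Give `reward_weight` to completions that loosely match the expected format."""
    matches = [re.match(pattern, r, re.DOTALL) for r in responses]
    return [reward_weight if match else 0.0 for match in matches]
```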

@@ -150,115 +150,9 @@ def strict_format_reward_func(responses: list[str]) -> list[float]:
matches = [re.match(pattern, r, re.DOTALL) for r in responses]
return [1.0 if match else 0.0 for match in matches]

def test_basic_valid():
response = """<think>
The sky is blue
</think>
<answer>
Blue
</answer>
"""
result = strict_format_reward_func([response])
assert result == [1.0], f"Basic valid case failed, got {result}"
print("✓ Basic valid case passed")

def test_multiline_think():
response = """<think>
The sky is blue because of Rayleigh scattering
This is a second line of think
And here's a third line
</think>
<answer>
Blue
</answer>
"""
result = strict_format_reward_func([response])
assert result == [1.0], f"Multiline think failed, got {result}"
print("✓ Multiline think passed")

def test_multiline_answer():
response = """<think>
The sky is blue
</think>
<answer>
Blue
And also sometimes lighter blue
And even white when cloudy
</answer>
"""
result = strict_format_reward_func([response])
assert result == [1.0], f"Multiline answer failed, got {result}"
print("✓ Multiline answer passed")

def test_no_newlines():
response = "<think>The sky is blue</think><answer>Blue</answer>"
result = strict_format_reward_func([response])
assert result == [0.0], f"No newlines case failed, got {result}"
print("✓ No newlines case passed")

def test_missing_final_newline():
response = """<think>
The sky is blue
</think>
<answer>
Blue
</answer>""" # No final newline
result = strict_format_reward_func([response])
assert result == [0.0], f"Missing final newline failed, got {result}"
print("✓ Missing final newline passed")

def test_extra_content():
response = """Extra content
<think>
The sky is blue
</think>
<answer>
Blue
</answer>
"""
result = strict_format_reward_func([response])
assert result == [0.0], f"Extra content case failed, got {result}"
print("✓ Extra content case passed")

def test_wrong_order():
response = """<answer>
Blue
</answer>
<think>
The sky is blue
</think>
"""
result = strict_format_reward_func([response])
assert result == [0.0], f"Wrong order case failed, got {result}"
print("✓ Wrong order case passed")

def test_multiple_responses():
valid = """<think>
First think
</think>
<answer>
First answer
</answer>
"""
invalid = "Invalid format"

result = strict_format_reward_func([valid, invalid])
assert result == [1.0, 0.0], f"Multiple responses failed, got {result}"
print("✓ Multiple responses passed")


# debug code
if __name__ == "__main__":
Collaborator:
@vwxyzjn let's add this to the tests check we have?
Also, let's make the scale of the reward set by a hyperparam / config? Could get tricky reward shaping issues.

print("Running tests...")
test_basic_valid()
test_multiline_think()
test_multiline_answer()
test_no_newlines()
test_missing_final_newline()
test_extra_content()
test_wrong_order()
test_multiple_responses()
print("\nAll tests passed!")
from datasets import load_dataset
ds = load_dataset("ai2-adapt-dev/prompts_with_constraints_for_ground_truth")
test_model_output = "<|assistant|>\nThe answer is $\\boxed{3.14}$"
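Below is a sketch of how the inline checks above could be folded into the repo's test suite, as the reviewer asks. The file location and the use of `pytest.mark.parametrize` are assumptions for illustration; the expected values restate the assertions from the removed tests.

```python
# Hypothetical tests/test_ground_truth_utils.py -- an assumption, not part of the PR.
import pytest

from open_instruct.ground_truth_utils import strict_format_reward_func


@pytest.mark.parametrize(
    "response, expected",
    [
        # Valid: tags on their own lines with a trailing newline (basic valid case).
        ("<think>\nThe sky is blue\n</think>\n<answer>\nBlue\n</answer>\n", 1.0),
        # Invalid: no newlines between the tags.
        ("<think>The sky is blue</think><answer>Blue</answer>", 0.0),
        # Invalid: extra content before <think>.
        ("Extra content\n<think>\nThe sky is blue\n</think>\n<answer>\nBlue\n</answer>\n", 0.0),
    ],
)
def test_strict_format_reward_func(response, expected):
    assert strict_format_reward_func([response]) == [expected]
```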