Closed
20 changes: 16 additions & 4 deletions src/open_r1/grpo.py
@@ -57,6 +57,8 @@ def accuracy_reward(completions, solution, **kwargs):
rewards = []
for content, sol in zip(contents, solution):
gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])

skip = False
if len(gold_parsed) != 0:
# We require the answer to be provided in correct latex (no malformed operators)
answer_parsed = parse(
@@ -78,12 +80,22 @@ def accuracy_reward(completions, solution, **kwargs):
],
extraction_mode="first_match",
)
# Reward 1 if the content is the same as the ground truth, 0 otherwise
reward = float(verify(answer_parsed, gold_parsed))

try:
# Reward 1 if the content is the same as the ground truth, 0 otherwise
reward = float(verify(answer_parsed, gold_parsed))
except Exception as e:
Member

What exception type is raised here? I'd catch something more specific.

Contributor Author

Frankly, I don't think we should: any error from math-verify is reason enough to skip the example. Catching specific exceptions would make sense if we wanted to print something different for each one, but as it stands, if math-verify ever throws, the training run dies.
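
The behaviour described in this reply can be sketched as follows. This is a minimal illustration, not the code in the PR: `flaky_verify` is a hypothetical stand-in for math-verify's `verify`, and `rewards_with_skip` is a made-up helper name. It shows that a broad `except Exception` turns a verifier failure on one sample into a skip (reward 1.0) instead of an uncaught error that stops the whole loop.

```python
def flaky_verify(answer, gold):
    """Hypothetical stand-in for math-verify's verify(); raises on some inputs."""
    if answer is None:
        raise ValueError("could not compare expressions")
    return answer == gold


def rewards_with_skip(answers, golds, verify_fn=flaky_verify):
    """Score each pair; any verifier exception becomes a skip (reward 1.0)."""
    rewards = []
    for answer, gold in zip(answers, golds):
        try:
            # Reward 1 if the answer matches the ground truth, 0 otherwise
            reward = float(verify_fn(answer, gold))
        except Exception:
            # Broad on purpose: whatever the verifier raised, only skip this sample
            reward = 1.0
        rewards.append(reward)
    return rewards
```

With a narrower `except` clause, any exception type not listed would propagate out of the loop, which is the failure mode the reply is concerned about.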

print("Failed to verify answer: ", content)
Member

nit:

Suggested change
print("Failed to verify answer: ", content)
logger.debug(f"Failed to verify answer {content} due to {e}")

print(e)
Member

Suggested change
print(e)

skip = True
else:
# If the gold solution is not parseable, we reward 1 to skip this example
reward = 1.0
print("Failed to parse gold solution: ", sol)
Member

Suggested change
print("Failed to parse gold solution: ", sol)
logger.debug("Failed to parse gold solution: %s", sol)

skip = True

if skip:
# We reward 1 to skip this example
reward = 1.0

rewards.append(reward)

return rewards
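
The logging suggestions made in this review could be collected along these lines. This is a sketch under assumptions, not the code that was merged: the logger name and both helper functions are hypothetical. It passes the values as arguments with `%s` placeholders, which is how the standard `logging` module consumes extra positional arguments, and it defers string formatting until a handler actually emits the record.

```python
import logging

# Hypothetical logger name, used only for this illustration
logger = logging.getLogger("accuracy_reward_sketch")


def log_verify_failure(content, error):
    """Record a verifier failure for one completion at debug level."""
    logger.debug("Failed to verify answer %s due to %s", content, error)


def log_gold_parse_failure(sol):
    """Record an unparseable gold solution at debug level."""
    logger.debug("Failed to parse gold solution: %s", sol)
```

Because these are debug-level messages, they only appear when the logger's level is set to `logging.DEBUG` or lower, unlike the unconditional `print` calls they would replace.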