diff --git a/src/open_r1/grpo.py b/src/open_r1/grpo.py index 4bdc335f7..31006ef16 100644 --- a/src/open_r1/grpo.py +++ b/src/open_r1/grpo.py @@ -57,6 +57,8 @@ def accuracy_reward(completions, solution, **kwargs): rewards = [] for content, sol in zip(contents, solution): gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()]) + + skip = False if len(gold_parsed) != 0: # We require the answer to be provided in correct latex (no malformed operators) answer_parsed = parse( @@ -78,12 +80,22 @@ def accuracy_reward(completions, solution, **kwargs): ], extraction_mode="first_match", ) - # Reward 1 if the content is the same as the ground truth, 0 otherwise - reward = float(verify(answer_parsed, gold_parsed)) + + try: + # Reward 1 if the content is the same as the ground truth, 0 otherwise + reward = float(verify(answer_parsed, gold_parsed)) + except Exception as e: + print("Failed to verify answer: ", content) + print(e) + skip = True else: - # If the gold solution is not parseable, we reward 1 to skip this example - reward = 1.0 print("Failed to parse gold solution: ", sol) + skip = True + + if skip: + # We reward 1 to skip this example + reward = 1.0 + rewards.append(reward) return rewards