From ae80da939502025e7002cc184d95bb292fe771bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 13 Mar 2026 03:38:04 +0000 Subject: [PATCH 1/2] Fix `accuracy_reward` crash when called from non-main thread --- tests/test_rewards.py | 24 +++++++++++++++++++++ trl/rewards/accuracy_rewards.py | 37 +++++++++++++++++++++++++++++---- 2 files changed, 57 insertions(+), 4 deletions(-) diff --git a/tests/test_rewards.py b/tests/test_rewards.py index 2442b3bd2ce..fbf6b4f52c1 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import threading + from trl.rewards import accuracy_reward, get_soft_overlong_punishment, reasoning_accuracy_reward, think_format_reward from .testing_utils import TrlTestCase, require_math_latex @@ -128,6 +130,28 @@ def test_accuracy_reward_unparsable_gold(self): assert rewards[1] is None + @require_math_latex + def test_accuracy_reward_in_worker_thread(self): + """Test that accuracy_reward works when called from a non-main thread.""" + completions = [[{"content": r"\boxed{\frac{1}{3}}"}]] + solutions = [r"\frac{1}{3}"] + results = [] + exceptions = [] + + def target(): + try: + results.extend(accuracy_reward(completions, solutions)) + except Exception as e: + exceptions.append(e) + + t = threading.Thread(target=target) + t.start() + t.join() + + assert not exceptions, f"accuracy_reward raised in worker thread: {exceptions[0]}" + assert results == [1.0] + + class TestReasoningAccuracyReward: @require_math_latex def test_correct_answer_yields_unit_reward(self): diff --git a/trl/rewards/accuracy_rewards.py b/trl/rewards/accuracy_rewards.py index a06c84bc413..7970a9a4cae 100644 --- a/trl/rewards/accuracy_rewards.py +++ b/trl/rewards/accuracy_rewards.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +import threading + from ..import_utils import is_math_verify_available @@ -53,8 +56,20 @@ def accuracy_reward(completions: list[list[dict[str, str]]], solution: list[str] contents = [completion[0]["content"] for completion in completions] rewards = [] + + # math_verify uses signal.alarm() for timeouts, which only works in the main thread. + # Disable timeouts when running in a non-main thread to avoid ValueError. + is_main_thread = threading.current_thread() is threading.main_thread() + parsing_timeout = None if not is_main_thread else 10 + verify_timeout = None if not is_main_thread else 5 + + # Suppress the "Timeout is disabled" warnings from math_verify when we intentionally disable timeouts + if not is_main_thread: + logging.getLogger("math_verify.parser").setLevel(logging.ERROR) + logging.getLogger("math_verify.grader").setLevel(logging.ERROR) + for content, sol in zip(contents, solution, strict=True): - gold_parsed = parse(sol) + gold_parsed = parse(sol, parsing_timeout=parsing_timeout) if len(gold_parsed) != 0: # We require the answer to be provided in correct latex (no malformed operators) answer_parsed = parse( @@ -68,8 +83,9 @@ def accuracy_reward(completions: list[list[dict[str, str]]], solution: list[str] ) ], extraction_mode="first_match", + parsing_timeout=parsing_timeout, ) - reward = float(verify(gold_parsed, answer_parsed)) + reward = float(verify(gold_parsed, answer_parsed, timeout_seconds=verify_timeout)) else: # If the gold solution cannot be parsed, we assign `None` to skip this example reward = None @@ -140,6 +156,18 @@ def reasoning_accuracy_reward( rewards = [] contents = [completion[0]["content"] for completion in completions] + + # math_verify uses signal.alarm() for timeouts, which only works in the main thread. + # Disable timeouts when running in a non-main thread to avoid ValueError. + is_main_thread = threading.current_thread() is threading.main_thread() + parsing_timeout = None if not is_main_thread else 10 + verify_timeout = None if not is_main_thread else 5 + + # Suppress the "Timeout is disabled" warnings from math_verify when we intentionally disable timeouts + if not is_main_thread: + logging.getLogger("math_verify.parser").setLevel(logging.ERROR) + logging.getLogger("math_verify.grader").setLevel(logging.ERROR) + for content, sol in zip(contents, solution, strict=True): # Split final answer from reasoning content is_reasoning_complete = False @@ -153,7 +181,7 @@ def reasoning_accuracy_reward( rewards.append(0.0) continue - gold_parsed = parse(sol) + gold_parsed = parse(sol, parsing_timeout=parsing_timeout) if len(gold_parsed) != 0: # We require the answer to be provided in correct latex (no malformed operators) answer_parsed = parse( @@ -168,8 +196,9 @@ def reasoning_accuracy_reward( ) ], extraction_mode="first_match", + parsing_timeout=parsing_timeout, ) - reward = float(verify(gold_parsed, answer_parsed)) + reward = float(verify(gold_parsed, answer_parsed, timeout_seconds=verify_timeout)) else: # If the gold solution cannot be parsed, we assign `None` to skip this example reward = None From e0e6a766cbf09e2157c64edce2e4e6fc08b15897 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 13 Mar 2026 04:12:10 +0000 Subject: [PATCH 2/2] code style --- tests/test_rewards.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_rewards.py b/tests/test_rewards.py index fbf6b4f52c1..890ee95fdef 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -129,7 +129,6 @@ def test_accuracy_reward_unparsable_gold(self): assert rewards[0] is None assert rewards[1] is None - @require_math_latex def test_accuracy_reward_in_worker_thread(self): """Test that accuracy_reward works when called from a non-main thread."""