Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions tests/test_rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import threading

from trl.rewards import accuracy_reward, get_soft_overlong_punishment, reasoning_accuracy_reward, think_format_reward

from .testing_utils import TrlTestCase, require_math_latex
Expand Down Expand Up @@ -127,6 +129,27 @@ def test_accuracy_reward_unparsable_gold(self):
assert rewards[0] is None
assert rewards[1] is None

@require_math_latex
def test_accuracy_reward_in_worker_thread(self):
"""Test that accuracy_reward works when called from a non-main thread."""
completions = [[{"content": r"\boxed{\frac{1}{3}}"}]]
solutions = [r"\frac{1}{3}"]
results = []
exceptions = []

def target():
try:
results.extend(accuracy_reward(completions, solutions))
except Exception as e:
exceptions.append(e)

t = threading.Thread(target=target)
t.start()
t.join()

assert not exceptions, f"accuracy_reward raised in worker thread: {exceptions[0]}"
assert results == [1.0]


class TestReasoningAccuracyReward:
@require_math_latex
Expand Down
37 changes: 33 additions & 4 deletions trl/rewards/accuracy_rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import threading

from ..import_utils import is_math_verify_available


Expand Down Expand Up @@ -53,8 +56,20 @@ def accuracy_reward(completions: list[list[dict[str, str]]], solution: list[str]

contents = [completion[0]["content"] for completion in completions]
rewards = []

# math_verify uses signal.alarm() for timeouts, which only works in the main thread.
# Disable timeouts when running in a non-main thread to avoid ValueError.
is_main_thread = threading.current_thread() is threading.main_thread()
parsing_timeout = None if not is_main_thread else 10
verify_timeout = None if not is_main_thread else 5

# Suppress the "Timeout is disabled" warnings from math_verify when we intentionally disable timeouts
if not is_main_thread:
logging.getLogger("math_verify.parser").setLevel(logging.ERROR)
logging.getLogger("math_verify.grader").setLevel(logging.ERROR)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Logger level permanently modified as global side effect

Medium Severity

logging.getLogger() returns a process-wide singleton, so calling setLevel(logging.ERROR) permanently suppresses warnings from math_verify.parser and math_verify.grader for the entire process. After any single call from a non-main thread, all subsequent calls — including from the main thread — will have these warnings silenced. The log levels are never restored after the function returns. A scoped approach (e.g., saving and restoring the original level in a try/finally) would avoid this persistent global side effect.

Additional Locations (1)
Fix in Cursor Fix in Web


for content, sol in zip(contents, solution, strict=True):
gold_parsed = parse(sol)
gold_parsed = parse(sol, parsing_timeout=parsing_timeout)
if len(gold_parsed) != 0:
# We require the answer to be provided in correct latex (no malformed operators)
answer_parsed = parse(
Expand All @@ -68,8 +83,9 @@ def accuracy_reward(completions: list[list[dict[str, str]]], solution: list[str]
)
],
extraction_mode="first_match",
parsing_timeout=parsing_timeout,
)
reward = float(verify(gold_parsed, answer_parsed))
reward = float(verify(gold_parsed, answer_parsed, timeout_seconds=verify_timeout))
else:
# If the gold solution cannot be parsed, we assign `None` to skip this example
reward = None
Expand Down Expand Up @@ -140,6 +156,18 @@ def reasoning_accuracy_reward(

rewards = []
contents = [completion[0]["content"] for completion in completions]

# math_verify uses signal.alarm() for timeouts, which only works in the main thread.
# Disable timeouts when running in a non-main thread to avoid ValueError.
is_main_thread = threading.current_thread() is threading.main_thread()
parsing_timeout = None if not is_main_thread else 10
verify_timeout = None if not is_main_thread else 5

# Suppress the "Timeout is disabled" warnings from math_verify when we intentionally disable timeouts
if not is_main_thread:
logging.getLogger("math_verify.parser").setLevel(logging.ERROR)
logging.getLogger("math_verify.grader").setLevel(logging.ERROR)

for content, sol in zip(contents, solution, strict=True):
# Split final answer from reasoning content
is_reasoning_complete = False
Expand All @@ -153,7 +181,7 @@ def reasoning_accuracy_reward(
rewards.append(0.0)
continue

gold_parsed = parse(sol)
gold_parsed = parse(sol, parsing_timeout=parsing_timeout)
if len(gold_parsed) != 0:
# We require the answer to be provided in correct latex (no malformed operators)
answer_parsed = parse(
Expand All @@ -168,8 +196,9 @@ def reasoning_accuracy_reward(
)
],
extraction_mode="first_match",
parsing_timeout=parsing_timeout,
)
reward = float(verify(gold_parsed, answer_parsed))
reward = float(verify(gold_parsed, answer_parsed, timeout_seconds=verify_timeout))
else:
# If the gold solution cannot be parsed, we assign `None` to skip this example
reward = None
Expand Down
Loading