Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions functioncall/code/local_verify.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -95,7 +95,7 @@ def call_verify(problem, generation, debug, timeout=SINGLE_CASE_EXEC_TIMEOUT):
return result["result"], result["info"]


def code_verify(id2info, generateds, query_ids, debug=False):
def code_verify(id2info, generateds, query_ids, max_workers=None, debug=False):
assert len(generateds) == len(query_ids)
problems = [id2info[qid] for qid in query_ids]

Expand All @@ -106,8 +106,10 @@ def code_verify(id2info, generateds, query_ids, debug=False):
infer_args.append((problem, generated, debug, SINGLE_CASE_EXEC_TIMEOUT))

run_results = []
num_process = max(1, os.cpu_count() // 8)
with concurrent.futures.ProcessPoolExecutor(num_process) as executor:
if max_workers is None:
max_workers = max(1, os.cpu_count() // 8)

with concurrent.futures.ProcessPoolExecutor(max_workers) as executor:
run_results = executor.map(call_verify, *zip(*infer_args))

for run_result in run_results:
Expand Down
11 changes: 9 additions & 2 deletions realhf/impl/dataset/math_parser.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -7,7 +7,7 @@

import regex
from latex2sympy2 import latex2sympy
from pebble import ProcessPool
from pebble import ProcessExpired, ProcessPool
from sympy import N, simplify
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
Expand Down Expand Up @@ -289,6 +289,7 @@ def strip_string(string, skip_unit=False):

# remove percentage
string = string.replace("\\%", "")
string = string.replace("\%", "")
string = string.replace("%", "")

# " 0." is equivalent to " ." and "{0." is equivalent to "{."; alternatively, add "0" if "." is at the start of the string
Expand Down Expand Up @@ -398,7 +399,7 @@ def extract_answer(pred_str, data_name, use_last_number=True):
pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
else: # use the last number
if use_last_number:
pattern = r"-?\d*\.?\d+"
pattern = "-?\d*\.?\d+"
pred = re.findall(pattern, pred_str.replace(",", ""))
if len(pred) >= 1:
pred = pred[-1]
Expand Down Expand Up @@ -836,6 +837,12 @@ def parse_lines_in_parallel(
# print("[debug: timeout]")
logger.warning(f"Timeout occurred while justifying the math answer.")
x = (0, "timeout", "timeout")
except ProcessExpired as e:
logger.warning(f"Process terminated abnormally: {e}")
x = (0, "error", "error")
except Exception as e:
logger.warning(f"Other error occurred: {e.__class__.__name__}, {e}")
x = (0, "error", "error")
label = label or x[0]
labels.append(label)
return labels
Expand Down
2 changes: 2 additions & 0 deletions realhf/impl/environment/math_code_single_step_env.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -57,6 +57,7 @@ async def step(self, action: Tuple[str, List[str]]):
self.id2info,
answers,
[qid for _ in range(group_size)],
max_workers=1,
)
elif cur_task == "code":
answers = [extract_code(x) for x in answers]
Expand All @@ -65,6 +66,7 @@ async def step(self, action: Tuple[str, List[str]]):
self.id2info,
answers,
[qid for _ in range(group_size)],
max_workers=1,
)
else:
raise NotImplementedError()
Expand Down