From 0cfb13433da3d6826cb6a8c9dde6df6c18ebba30 Mon Sep 17 00:00:00 2001 From: Guangming Sheng Date: Sat, 15 Mar 2025 10:03:59 +0000 Subject: [PATCH 1/3] support multiple verifier and add reward_fn_key --- verl/trainer/config/ppo_megatron_trainer.yaml | 1 + verl/trainer/config/ppo_trainer.yaml | 1 + verl/trainer/main_ppo.py | 4 +- verl/utils/reward_score/__init__.py | 3 + verl/utils/reward_score/gpqa.py | 27 ++ verl/utils/reward_score/math_v2.py | 363 ++++++++++++++++++ verl/workers/reward_manager/naive.py | 7 +- 7 files changed, 401 insertions(+), 5 deletions(-) create mode 100644 verl/utils/reward_score/gpqa.py create mode 100644 verl/utils/reward_score/math_v2.py diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index c903934b8bb..e1dd29d7b91 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -3,6 +3,7 @@ data: train_files: ~/data/rlhf/gsm8k/train.parquet val_files: ~/data/rlhf/gsm8k/test.parquet prompt_key: prompt + reward_fn_key: data_source max_prompt_length: 512 max_response_length: 512 train_batch_size: 1024 diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index eab987d3aad..ac121a02405 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -3,6 +3,7 @@ data: train_files: ~/data/rlhf/gsm8k/train.parquet val_files: ~/data/rlhf/gsm8k/test.parquet prompt_key: prompt + reward_fn_key: data_source max_prompt_length: 512 max_response_length: 512 train_batch_size: 1024 diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index 3d5837a6c16..65a7865c92e 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -146,10 +146,10 @@ def main_task(config): raise NotImplementedError compute_score = get_custom_reward_fn(config) - reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score) + reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key) # Note that we always use function-based RM for validation - val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score) + val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key) resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) diff --git a/verl/utils/reward_score/__init__.py b/verl/utils/reward_score/__init__.py index 12599556d0f..82782ec65ab 100644 --- a/verl/utils/reward_score/__init__.py +++ b/verl/utils/reward_score/__init__.py @@ -25,6 +25,9 @@ def _default_compute_score(data_source, solution_str, ground_truth, extra_info=N # Use Math-Verify (https://github.com/huggingface/Math-Verify) for better evaluation accuracy from . import math_verify res = math_verify.compute_score(solution_str, ground_truth) + elif data_source == 'rule-lighteval/MATH_v2': + from . import math_v2 + res = math_v2.compute_score elif data_source in [ 'numina_aops_forum', 'numina_synthetic_math', 'numina_amc_aime', 'numina_synthetic_amc', 'numina_cn_k12', 'numina_olympiads' diff --git a/verl/utils/reward_score/gpqa.py b/verl/utils/reward_score/gpqa.py new file mode 100644 index 00000000000..f8971bf1dc1 --- /dev/null +++ b/verl/utils/reward_score/gpqa.py @@ -0,0 +1,27 @@ +from typing import Optional, Any + +from verl.utils.reward_score.math_v2 import last_boxed_only_string_v2, remove_boxed + + +def compute_score(solution_str, ground_truth, **argv) -> tuple[float, dict[str, Any]]: + pred = last_boxed_only_string_v2(solution_str[-100:]) + correct: bool = False + reward: float = 0 + if pred is None: + # TODO: return a dict for analysis + # extra_info = {"acc": correct, "format/answer": "answer:" in solution_str.lower(), "format/boxed": "boxed{" in solution_str, "format/option": False} + return reward + + pred = remove_boxed(pred) + if pred.upper() == ground_truth: + correct = True + reward = 1 + + # TODO: return a dict for analysis + # extra_info = {"acc": correct, "format/answer": "answer:" in solution_str.lower(), "format/boxed": "boxed{" in solution_str, "format/option": pred.upper() in ["A", "B", "C", "D"]} + + return reward + + +if __name__ == "__main__": + pass diff --git a/verl/utils/reward_score/math_v2.py b/verl/utils/reward_score/math_v2.py new file mode 100644 index 00000000000..fc1f3cb4399 --- /dev/null +++ b/verl/utils/reward_score/math_v2.py @@ -0,0 +1,363 @@ +import re +import signal +from typing import Optional, Any +import torch + +try: + import sympy + from sympy.parsing.latex import parse_latex +except ModuleNotFoundError: + raise ModuleNotFoundError( + "`sympy` is required for generating translation task prompt templates. \ +please install sympy via pip install lm-eval[math] or pip install -e .[math]" , + ) + + +def list_fewshot_samples() -> list[dict]: + return [ + { + "problem": + "Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.}", + "solution": + "The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$. Also, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$. Therefore, the domain of the expression is $\\boxed{[2,5)}$.\nFinal Answer: The final answer is $[2,5)$. I hope it is correct.", + "few_shot": + "1", + }, + { + "problem": + "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$", + "solution": + "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}.$\nFinal Answer: The final answer is $24$. I hope it is correct.", + "few_shot": + "1", + }, + { + "problem": + "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?", + "solution": + "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$:\n\\begin{align*}\n30n&=480\\\n\\Rightarrow\\qquad n&=480/30=\\boxed{16}\n\\end{align*}\nFinal Answer: The final answer is $16$. I hope it is correct.", + "few_shot": + "1", + }, + { + "problem": + "If the system of equations\n\n\\begin{align*}\n6x-4y&=a,\\\n6y-9x &=b.\n\\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero,\nfind $\\frac{a}{b},$ assuming $b$ is nonzero.", + "solution": + "If we multiply the first equation by $-\\frac{3}{2}$, we obtain\n\n$$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have\n\n$$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$\nFinal Answer: The final answer is $-\\frac{2}{3}$. I hope it is correct.", + "few_shot": + "1", + }, + ] + +def last_boxed_only_string_v2(string: str) -> Optional[str]: + """ + find last \\boxed{...} + """ + idx = string.rfind("\\boxed{") + if idx < 0: + return None + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + if right_brace_idx is None: + retval = None + else: + retval = string[idx:right_brace_idx + 1] + + return retval + + +def remove_boxed(s: str) -> str: + left = "\\boxed{" + + assert s[:len(left)] == left, f"box error: {s}" + assert s[-1] == "}", f"box error: {s}" + + return s[len(left):-1] + + +class timeout: + + def __init__(self, seconds=1, error_message="Timeout"): + self.seconds = seconds + self.error_message = error_message + + def handle_timeout(self, signum, frame): + raise TimeoutError(self.error_message) + + def __enter__(self): + signal.signal(signal.SIGALRM, self.handle_timeout) + signal.alarm(self.seconds) + + def __exit__(self, type, value, traceback): + signal.alarm(0) + + +def is_equiv(x1: str, x2: str) -> bool: + """ + x1 and x2 are normalized latex string + """ + try: + with timeout(seconds=10): + try: + parsed_x1 = parse_latex(x1) + parsed_x2 = parse_latex(x2) + except ( + sympy.parsing.latex.errors.LaTeXParsingError, + sympy.SympifyError, + TypeError, + ): + # eval_logger.debug(f"couldn't parse one of {x1} or {x2}") + return False + + try: + diff = parsed_x1 - parsed_x2 + except TypeError: + # eval_logger.debug(f"couldn't subtract {x1} and {x2}") + return False + + try: + if sympy.simplify(diff) == 0: + return True + else: + return False + except ValueError: + # eval_logger.debug( + # f"Had some trouble simplifying when comparing {x1} and {x2}" + # ) + return False + + except TimeoutError: + # eval_logger.debug(f"Timed out comparing {x1} and {x2}") + return False + except ImportError as e: + # eval_logger.error(e) + raise + except Exception as e: + # eval_logger.debug(f"Failed comparing {x1} and {x2} with {e}") + return False + + +SUBSTITUTIONS = [ + ("an ", ""), + ("a ", ""), + (".$", "$"), + ("\\$", ""), + (r"\ ", ""), + (" ", ""), + ("mbox", "text"), + (",\\text{and}", ","), + ("\\text{and}", ","), + ("\\text{m}", "\\text{}"), +] +REMOVED_EXPRESSIONS = [ + "square", + "ways", + "integers", + "dollars", + "mph", + "inches", + # "ft", #this is dangerous, infty, left will be damaged! + "hours", + "km", + "units", + "\\ldots", + "sue", + "points", + "feet", + "minutes", + "digits", + "cents", + "degrees", + "cm", + "gm", + "pounds", + "meters", + "meals", + "edges", + "students", + "childrentickets", + "multiples", + "\\text{s}", + "\\text{.}", + "\\text{\ns}", + "\\text{}^2", + "\\text{}^3", + "\\text{\n}", + "\\text{}", + r"\mathrm{th}", + r"^\circ", + r"^{\circ}", + r"\;", + r",\!", + "{,}", + '"', + "\\dots", +] + + +def normalize_final_answer(final_answer: str) -> str: + """ + Normalize a final answer to a quantitative reasoning question. + + Copied character for character from appendix D of Lewkowycz et al. (2022) + """ + final_answer = final_answer.split("=")[-1] + + for before, after in SUBSTITUTIONS: + final_answer = final_answer.replace(before, after) + for expr in REMOVED_EXPRESSIONS: + final_answer = final_answer.replace(expr, "") + + # Extract answer that is in LaTeX math, is bold, + # is surrounded by a box, etc. + final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) + final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) + + # Normalize shorthand TeX: + # \fracab -> \frac{a}{b} + # \frac{abc}{bef} -> \frac{abc}{bef} + # \fracabc -> \frac{a}{b}c + # \sqrta -> \sqrt{a} + # \sqrtab -> sqrt{a}b + final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) + final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) + final_answer = final_answer.replace("$", "") + + # Normalize 100,000 -> 100000 + if final_answer.replace(",", "").isdigit(): + final_answer = final_answer.replace(",", "") + + return final_answer.strip() + + +INVALID_ANS_GSM8k = "[invalid]" +ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)" +STRICT_BOX_PATTERN = r"\\boxed\{([^}]*)\}" + + +def filter_ignores(st, regexes_to_ignore): + if regexes_to_ignore is not None: + for s in regexes_to_ignore: + st = re.sub(s, "", st) + return st + + +def is_correct_integer( + og_pred, + gt, +): + numbers = re.findall(r'-?\d+', og_pred[-100:]) + numbers = numbers[-1] if len(numbers) > 0 else "" # 很难通过枚举把最后一个搞成正确答案 + correctness = gt == numbers + return correctness, og_pred[-100:] + + +def is_correct_minerva(og_pred, gt, gt_need_extract=False): + og_pred = og_pred[-300:] #math500最长answer为159 + match = re.findall(ANSWER_PATTERN, og_pred) + extracted_answer = match[-1] if match else "[INVALID]" + pred = normalize_final_answer(extracted_answer) + if gt_need_extract: + gt = normalize_final_answer(remove_boxed(last_boxed_only_string_v2(gt))) + else: + gt = normalize_final_answer(gt) + # return (pred == gt or is_equiv(pred, gt)), pred + return (pred == gt), pred + + +def is_correct_strict_box(pred, gt, pause_tokens_index): + if pause_tokens_index is not None: + assert len(pause_tokens_index) == 4 + pred = pred[pause_tokens_index[-1] - 100:] + else: + pred = pred[-100:] + pred = last_boxed_only_string_v2(pred) + pred = remove_boxed(pred) if pred is not None else None + return 1 if (pred == gt) else -1, pred + + +def verify(pred, answer, resp_len, max_resp_len, add_int_verify=True, strict_box_verify=False, pause_tokens_index=None) -> bool: + + if strict_box_verify: + corr_strict_box, pred_strict_box = is_correct_strict_box(pred, answer, pause_tokens_index) + return corr_strict_box + + corr_minerva, pred_minerva = is_correct_minerva(pred, + answer) # To remove if math is also converted to interger format + if add_int_verify: + corr_integer, pred_integer = is_correct_integer(pred, answer) + pred = pred_minerva if corr_minerva else pred_integer + corr = corr_minerva or corr_integer + else: + pred = pred_minerva + corr = corr_minerva + + return corr + + +def compute_score(batch_info, solution_str, ground_truth, config, rm_name, pause_tokens_index, **argv) -> tuple[float, dict[str, Any]]: + """ + default行为:对给1,其余给-1 + punish_no_answer: + * v0: 0 + * v1: -0.1 + * v2: -0.2 + """ + prompt_length = batch_info['prompts'].shape[-1] + max_resp_len = batch_info['responses'].shape[-1] + resp_len = sum(batch_info['attention_mask'][prompt_length:].tolist()) + add_int_verify = config.reward_model.add_int_verify + strict_box_verify = config.reward_model.strict_box_verify + + correct = verify(solution_str, ground_truth, resp_len, max_resp_len, add_int_verify, strict_box_verify, + pause_tokens_index) + + final_reward = 1 if correct else -1 + + # extra_info = {"acc": correct, "format/answer": "answer:" in solution_str, "format/boxed": "boxed{" in solution_str, "format/option": False} + + return final_reward + + +if __name__ == "__main__": + + # 不含合法答案的长/短回复,启用不含answer惩罚 + # pred = r"""Another way is to look at the height of one corner of our shape considering the big square we are trying to find where this outer tangent line from the big circle to the adjacent small circles is touching the corners of those small circles at specific heights relative to the center of the big circle which may be related to the radius of the circle such as the bottom of the outer tangent line from the big excess outer circle to the adjacent small circles is some height relative to the center of the big outer circle which directly relates to the radius length of the small circles in relationship to the big outer shape in their outer area. We draw a line from the corner of one of the small circles to the outer edge of the big outer shape and from the center of the small circle perpendicular to the outer edge of the big outer shape so that it hits the outer edge of the big outer shape at a right angle which we call $ h $ which may be related to the radius of the small circle and the shape formed at that corner in relationship to the outer edges of the outer shapes. And the distance from the corner of one of the small circles to the outer edge of the big outer shape which includes the half of this figure that represents how high the outer tangent line from the big excess outer circle to the adjacent small circles is touching the corners of those small circles at specific heights relative to the center of the big circle which includes the radius part of the small circle which we call $ l $ . If we know the height of this line from the corner of the big outer main shape and the line perpendicular to the outer edge of the big outer shape which is what we call it as coming from the corner of the big outer main shape and is hitting the outer edge of the big outer shape with the half of this outer tangent line from the big excess outer circle to the adjacent small circles which we call it with the radius of the small circles which we call this value $ h $ which is what we are looking at in relationship to how that relates to the rest of the features of this outer shape. We might use some trigonometry involving the angle between this half part of the outer tangent line from the big excess outer circle to the adjacent small circles being some angle at one corner of one of the small circles in relationship to the other parts of the outer shape to figure out what the value of the radius of the small circle is. And if we have this value of the height of this line from the corner of the big outer main shape and the line perpendicular to the outer edge of the big outer shape which is what we have already called it as coming from the corner of the big outer main shape and is hitting the outer edge of the big outer shape with the half of this outer tangent line from the big excess outer circle to the adjacent small circles which we call it with the radius of the small circles which we call this value $ h $ which is what we have been using to think about in relationship to how that relates to the rest of the features of this outer shape and the angle at one corner of one of the small circles in relationship to the other parts of the outer shape which we call the "angle between relevant sides value" let's call it $ \theta_{rs} $ . Then using the height of this line from the corner of the big outer main shape and the line perpendicular to the outer edge of the big outer shape which is what we call it as coming from the corner of the big outer main shape and is hitting the outer edge of the big outer shape with the half of this outer tangent line""" + answer = "22+12\sqrt{2}" + pred = r"""Answer:22+12\sqrt{2}""" + assert verify(pred, answer, 0, 16384, False, "v0", strict_box_verify=True) == -1 + + pred = "So we have 235 as answer." + answer = "235" + assert verify(pred, answer, 0, 16384, False, "v0", strict_box_verify=True) == -1 + + pred = "So we have \\boxed{235} as answer." + answer = "235" + assert verify(pred, answer, 0, 16384, False, "v0", strict_box_verify=True) == 1 + + pred = "Answer: 11" + answer = "11" + assert verify(pred, answer, 0, 16384, False, "v0", strict_box_verify=True) == -1 + + pred = "Answer: \\boxed{11}" + answer = "-13" + assert verify(pred, answer, 0, 16384, False, "v0", strict_box_verify=True) == -1 + + pred = "Answer: \\boxed{11}" + answer = "11" + assert verify(pred, answer, 0, 16384, False, "v0", strict_box_verify=True) == 1 diff --git a/verl/workers/reward_manager/naive.py b/verl/workers/reward_manager/naive.py index a65dd126bf7..5206fb48d52 100644 --- a/verl/workers/reward_manager/naive.py +++ b/verl/workers/reward_manager/naive.py @@ -21,10 +21,11 @@ class NaiveRewardManager: """The reward manager. """ - def __init__(self, tokenizer, num_examine, compute_score=None) -> None: + def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key='data_source') -> None: self.tokenizer = tokenizer self.num_examine = num_examine # the number of batches of decoded responses to print to the console self.compute_score = compute_score or _default_compute_score + self.reward_fn_key = reward_fn_key def verify(self, data): scores = [] @@ -48,7 +49,7 @@ def verify(self, data): ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth'] - data_source = data_item.non_tensor_batch['data_source'] + data_source = data_item.non_tensor_batch[self.reward_fn_key] extra_info = data_item.non_tensor_batch.get('extra_info', None) @@ -93,7 +94,7 @@ def __call__(self, data: DataProto): ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth'] - data_source = data_item.non_tensor_batch['data_source'] + data_source = data_item.non_tensor_batch[self.reward_fn_key] extra_info = data_item.non_tensor_batch.get('extra_info', None) From 83a36e650937ebc0b1c876a9dd7da1e97ab31126 Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Sat, 15 Mar 2025 11:14:18 +0000 Subject: [PATCH 2/3] feat: verifier for math in Puffin --- verl/utils/reward_score/__init__.py | 6 +- verl/utils/reward_score/gpqa.py | 27 -- verl/utils/reward_score/math_puffin.py | 276 +++++++++++++++++++ verl/utils/reward_score/math_v2.py | 363 ------------------------- 4 files changed, 279 insertions(+), 393 deletions(-) delete mode 100644 verl/utils/reward_score/gpqa.py create mode 100644 verl/utils/reward_score/math_puffin.py delete mode 100644 verl/utils/reward_score/math_v2.py diff --git a/verl/utils/reward_score/__init__.py b/verl/utils/reward_score/__init__.py index 82782ec65ab..d72005da500 100644 --- a/verl/utils/reward_score/__init__.py +++ b/verl/utils/reward_score/__init__.py @@ -25,9 +25,9 @@ def _default_compute_score(data_source, solution_str, ground_truth, extra_info=N # Use Math-Verify (https://github.com/huggingface/Math-Verify) for better evaluation accuracy from . import math_verify res = math_verify.compute_score(solution_str, ground_truth) - elif data_source == 'rule-lighteval/MATH_v2': - from . import math_v2 - res = math_v2.compute_score + elif data_source == 'math_puffin': + from . import math_puffin + res = math_puffin.compute_score elif data_source in [ 'numina_aops_forum', 'numina_synthetic_math', 'numina_amc_aime', 'numina_synthetic_amc', 'numina_cn_k12', 'numina_olympiads' diff --git a/verl/utils/reward_score/gpqa.py b/verl/utils/reward_score/gpqa.py deleted file mode 100644 index f8971bf1dc1..00000000000 --- a/verl/utils/reward_score/gpqa.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Optional, Any - -from verl.utils.reward_score.math_v2 import last_boxed_only_string_v2, remove_boxed - - -def compute_score(solution_str, ground_truth, **argv) -> tuple[float, dict[str, Any]]: - pred = last_boxed_only_string_v2(solution_str[-100:]) - correct: bool = False - reward: float = 0 - if pred is None: - # TODO: return a dict for analysis - # extra_info = {"acc": correct, "format/answer": "answer:" in solution_str.lower(), "format/boxed": "boxed{" in solution_str, "format/option": False} - return reward - - pred = remove_boxed(pred) - if pred.upper() == ground_truth: - correct = True - reward = 1 - - # TODO: return a dict for analysis - # extra_info = {"acc": correct, "format/answer": "answer:" in solution_str.lower(), "format/boxed": "boxed{" in solution_str, "format/option": pred.upper() in ["A", "B", "C", "D"]} - - return reward - - -if __name__ == "__main__": - pass diff --git a/verl/utils/reward_score/math_puffin.py b/verl/utils/reward_score/math_puffin.py new file mode 100644 index 00000000000..e1a4863872f --- /dev/null +++ b/verl/utils/reward_score/math_puffin.py @@ -0,0 +1,276 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py + +import re +import signal +from typing import Optional + +import sympy +from sympy.parsing.latex import parse_latex + + +def last_boxed_only_string(string: str) -> Optional[str]: + """Extract the last LaTeX boxed expression from a string. + + Args: + string: Input string containing LaTeX code + + Returns: + The last boxed expression or None if not found + """ + idx = string.rfind("\\boxed{") + if idx < 0: + return None + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + return string[idx:right_brace_idx + 1] if right_brace_idx is not None else None + + +def remove_boxed(s: str) -> str: + """Remove the LaTeX boxed command from a string. + + Args: + s: String with format "\\boxed{content}" + + Returns: + The content inside the boxed command + """ + left = "\\boxed{" + assert s[:len(left)] == left, f"box error: {s}" + assert s[-1] == "}", f"box error: {s}" + return s[len(left):-1] + + +class timeout: + + def __init__(self, seconds=1, error_message="Timeout"): + self.seconds = seconds + self.error_message = error_message + + def handle_timeout(self, signum, frame): + raise TimeoutError(self.error_message) + + def __enter__(self): + signal.signal(signal.SIGALRM, self.handle_timeout) + signal.alarm(self.seconds) + + def __exit__(self, type, value, traceback): + signal.alarm(0) + + +def is_equiv(x1: str, x2: str) -> bool: + """ + Args: + x1, x2: normalized LaTeX string + """ + try: + with timeout(seconds=10): + try: + parsed_x1 = parse_latex(x1) + parsed_x2 = parse_latex(x2) + except ( + sympy.parsing.latex.errors.LaTeXParsingError, + sympy.SympifyError, + TypeError, + ): + return False + + try: + diff = parsed_x1 - parsed_x2 + except TypeError: + return False + + try: + if sympy.simplify(diff) == 0: + return True + else: + return False + except ValueError: + return False + + except TimeoutError: + return False + except ImportError as e: + raise + except Exception as e: + return False + + +# Constants for normalization +SUBSTITUTIONS = [ + ("an ", ""), ("a ", ""), (".$", "$"), ("\\$", ""), + (r"\ ", ""), (" ", ""), ("mbox", "text"), + (",\\text{and}", ","), ("\\text{and}", ","), ("\\text{m}", "\\text{}"), +] + +REMOVED_EXPRESSIONS = [ + "square", "ways", "integers", "dollars", "mph", "inches", + "hours", "km", "units", "\\ldots", "sue", "points", "feet", + "minutes", "digits", "cents", "degrees", "cm", "gm", "pounds", + "meters", "meals", "edges", "students", "childrentickets", + "multiples", "\\text{s}", "\\text{.}", "\\text{\ns}", + "\\text{}^2", "\\text{}^3", "\\text{\n}", "\\text{}", + r"\mathrm{th}", r"^\circ", r"^{\circ}", r"\;", r",\!", + "{,}", '"', "\\dots", +] + + +def normalize_final_answer(final_answer: str) -> str: + """Normalize a final answer to a quantitative reasoning question. + + Args: + final_answer: The answer string to normalize + + Returns: + Normalized answer string + """ + final_answer = final_answer.split("=")[-1] + + # Apply substitutions and removals + for before, after in SUBSTITUTIONS: + final_answer = final_answer.replace(before, after) + for expr in REMOVED_EXPRESSIONS: + final_answer = final_answer.replace(expr, "") + + # Extract and normalize LaTeX math + final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) + final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) + + # Normalize shorthand TeX: + # \fracab -> \frac{a}{b} + # \frac{abc}{bef} -> \frac{abc}{bef} + # \fracabc -> \frac{a}{b}c + # \sqrta -> \sqrt{a} + # \sqrtab -> sqrt{a}b + final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) + final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) + final_answer = final_answer.replace("$", "") + + # Normalize numbers + if final_answer.replace(",", "").isdigit(): + final_answer = final_answer.replace(",", "") + + return final_answer.strip() + + +def is_correct_minerva(solution_str: str, gt: str, gt_need_extract: bool = False, + answer_pattern: str = r"(?i)Answer\s*:\s*([^\n]+)") -> tuple[bool, str]: + """Check if the solution is correct according to Minerva criteria. + + Args: + solution_str: The solution string to check + gt: The ground truth answer + gt_need_extract: Whether the ground truth needs extraction + answer_pattern: Regex pattern to extract the answer + + Returns: + Tuple of (is_correct, normalized_prediction) + """ + # Extract answer from solution + match = re.findall(answer_pattern, solution_str) + extracted_answer = match[-1] if match else "[INVALID]" + pred = normalize_final_answer(extracted_answer) + + # Process ground truth + if gt_need_extract: + gt = normalize_final_answer(remove_boxed(last_boxed_only_string(gt))) + else: + gt = normalize_final_answer(gt) + + return (pred == gt or is_equiv(pred, gt)), pred + + +def is_correct_strict_box(pred: str, gt: str, pause_tokens_index: Optional[list[int]] = None) -> tuple[int, Optional[str]]: + """Check if the prediction is correct using strict boxed answer criteria. + + Args: + pred: The prediction string + gt: The ground truth answer + pause_tokens_index: Indices of pause tokens + + Returns: + Tuple of (score, extracted_prediction) + """ + # Extract the relevant part of the prediction + if pause_tokens_index is not None: + assert len(pause_tokens_index) == 4 + pred = pred[pause_tokens_index[-1] - 100:] + else: + pred = pred[-100:] + + # Extract and check the boxed answer + boxed_pred = last_boxed_only_string(pred) + extracted_pred = remove_boxed(boxed_pred) if boxed_pred is not None else None + + return 1 if (extracted_pred == gt) else -1, extracted_pred + + +def verify(solution_str: str, answer: str, strict_box_verify: bool = False, + pause_tokens_index: Optional[list[int]] = None) -> bool: + """Verify if the solution is correct. + + Args: + solution_str: The solution string to verify + answer: The ground truth answer + strict_box_verify: Whether to use strict box verification + pause_tokens_index: Indices of pause tokens + + Returns: + True if the solution is correct, False otherwise + """ + if strict_box_verify: + corr, _ = is_correct_strict_box(solution_str, answer, pause_tokens_index) + return corr == 1 + + corr, _ = is_correct_minerva(solution_str, answer) + return corr + + +def compute_score(solution_str: str, ground_truth: str, config, pause_tokens_index: Optional[list[int]] = None) -> float: + """Compute the reward score for a solution. + + Args: + solution_str: The solution string + ground_truth: The ground truth answer + config: Configuration object containing reward model settings + pause_tokens_index: Indices of pause tokens + + Returns: + Reward score (1.0 for correct, -1.0 for incorrect) + """ + # Limit solution length for efficiency + solution_str = solution_str[-300:] # The longest answer in MATH-500 has 159 characters + + # Verify the solution + strict_box_verify = config.reward_model.strict_box_verify + correct = verify(solution_str, ground_truth, strict_box_verify, pause_tokens_index) + + return 1.0 if correct else -1.0 diff --git a/verl/utils/reward_score/math_v2.py b/verl/utils/reward_score/math_v2.py deleted file mode 100644 index fc1f3cb4399..00000000000 --- a/verl/utils/reward_score/math_v2.py +++ /dev/null @@ -1,363 +0,0 @@ -import re -import signal -from typing import Optional, Any -import torch - -try: - import sympy - from sympy.parsing.latex import parse_latex -except ModuleNotFoundError: - raise ModuleNotFoundError( - "`sympy` is required for generating translation task prompt templates. \ -please install sympy via pip install lm-eval[math] or pip install -e .[math]" , - ) - - -def list_fewshot_samples() -> list[dict]: - return [ - { - "problem": - "Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.}", - "solution": - "The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$. Also, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$. Therefore, the domain of the expression is $\\boxed{[2,5)}$.\nFinal Answer: The final answer is $[2,5)$. I hope it is correct.", - "few_shot": - "1", - }, - { - "problem": - "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$", - "solution": - "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}.$\nFinal Answer: The final answer is $24$. I hope it is correct.", - "few_shot": - "1", - }, - { - "problem": - "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?", - "solution": - "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$:\n\\begin{align*}\n30n&=480\\\n\\Rightarrow\\qquad n&=480/30=\\boxed{16}\n\\end{align*}\nFinal Answer: The final answer is $16$. I hope it is correct.", - "few_shot": - "1", - }, - { - "problem": - "If the system of equations\n\n\\begin{align*}\n6x-4y&=a,\\\n6y-9x &=b.\n\\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero,\nfind $\\frac{a}{b},$ assuming $b$ is nonzero.", - "solution": - "If we multiply the first equation by $-\\frac{3}{2}$, we obtain\n\n$$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have\n\n$$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$\nFinal Answer: The final answer is $-\\frac{2}{3}$. I hope it is correct.", - "few_shot": - "1", - }, - ] - -def last_boxed_only_string_v2(string: str) -> Optional[str]: - """ - find last \\boxed{...} - """ - idx = string.rfind("\\boxed{") - if idx < 0: - return None - - i = idx - right_brace_idx = None - num_left_braces_open = 0 - while i < len(string): - if string[i] == "{": - num_left_braces_open += 1 - if string[i] == "}": - num_left_braces_open -= 1 - if num_left_braces_open == 0: - right_brace_idx = i - break - i += 1 - - if right_brace_idx is None: - retval = None - else: - retval = string[idx:right_brace_idx + 1] - - return retval - - -def remove_boxed(s: str) -> str: - left = "\\boxed{" - - assert s[:len(left)] == left, f"box error: {s}" - assert s[-1] == "}", f"box error: {s}" - - return s[len(left):-1] - - -class timeout: - - def __init__(self, seconds=1, error_message="Timeout"): - self.seconds = seconds - self.error_message = error_message - - def handle_timeout(self, signum, frame): - raise TimeoutError(self.error_message) - - def __enter__(self): - signal.signal(signal.SIGALRM, self.handle_timeout) - signal.alarm(self.seconds) - - def __exit__(self, type, value, traceback): - signal.alarm(0) - - -def is_equiv(x1: str, x2: str) -> bool: - """ - x1 and x2 are normalized latex string - """ - try: - with timeout(seconds=10): - try: - parsed_x1 = parse_latex(x1) - parsed_x2 = parse_latex(x2) - except ( - sympy.parsing.latex.errors.LaTeXParsingError, - sympy.SympifyError, - TypeError, - ): - # eval_logger.debug(f"couldn't parse one of {x1} or {x2}") - return False - - try: - diff = parsed_x1 - parsed_x2 - except TypeError: - # eval_logger.debug(f"couldn't subtract {x1} and {x2}") - return False - - try: - if sympy.simplify(diff) == 0: - return True - else: - return False - except ValueError: - # eval_logger.debug( - # f"Had some trouble simplifying when comparing {x1} and {x2}" - # ) - return False - - except TimeoutError: - # eval_logger.debug(f"Timed out comparing {x1} and {x2}") - return False - except ImportError as e: - # eval_logger.error(e) - raise - except Exception as e: - # eval_logger.debug(f"Failed comparing {x1} and {x2} with {e}") - return False - - -SUBSTITUTIONS = [ - ("an ", ""), - ("a ", ""), - (".$", "$"), - ("\\$", ""), - (r"\ ", ""), - (" ", ""), - ("mbox", "text"), - (",\\text{and}", ","), - ("\\text{and}", ","), - ("\\text{m}", "\\text{}"), -] -REMOVED_EXPRESSIONS = [ - "square", - "ways", - "integers", - "dollars", - "mph", - "inches", - # "ft", #this is dangerous, infty, left will be damaged! - "hours", - "km", - "units", - "\\ldots", - "sue", - "points", - "feet", - "minutes", - "digits", - "cents", - "degrees", - "cm", - "gm", - "pounds", - "meters", - "meals", - "edges", - "students", - "childrentickets", - "multiples", - "\\text{s}", - "\\text{.}", - "\\text{\ns}", - "\\text{}^2", - "\\text{}^3", - "\\text{\n}", - "\\text{}", - r"\mathrm{th}", - r"^\circ", - r"^{\circ}", - r"\;", - r",\!", - "{,}", - '"', - "\\dots", -] - - -def normalize_final_answer(final_answer: str) -> str: - """ - Normalize a final answer to a quantitative reasoning question. - - Copied character for character from appendix D of Lewkowycz et al. (2022) - """ - final_answer = final_answer.split("=")[-1] - - for before, after in SUBSTITUTIONS: - final_answer = final_answer.replace(before, after) - for expr in REMOVED_EXPRESSIONS: - final_answer = final_answer.replace(expr, "") - - # Extract answer that is in LaTeX math, is bold, - # is surrounded by a box, etc. - final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) - final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) - final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) - final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) - final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) - - # Normalize shorthand TeX: - # \fracab -> \frac{a}{b} - # \frac{abc}{bef} -> \frac{abc}{bef} - # \fracabc -> \frac{a}{b}c - # \sqrta -> \sqrt{a} - # \sqrtab -> sqrt{a}b - final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) - final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) - final_answer = final_answer.replace("$", "") - - # Normalize 100,000 -> 100000 - if final_answer.replace(",", "").isdigit(): - final_answer = final_answer.replace(",", "") - - return final_answer.strip() - - -INVALID_ANS_GSM8k = "[invalid]" -ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)" -STRICT_BOX_PATTERN = r"\\boxed\{([^}]*)\}" - - -def filter_ignores(st, regexes_to_ignore): - if regexes_to_ignore is not None: - for s in regexes_to_ignore: - st = re.sub(s, "", st) - return st - - -def is_correct_integer( - og_pred, - gt, -): - numbers = re.findall(r'-?\d+', og_pred[-100:]) - numbers = numbers[-1] if len(numbers) > 0 else "" # 很难通过枚举把最后一个搞成正确答案 - correctness = gt == numbers - return correctness, og_pred[-100:] - - -def is_correct_minerva(og_pred, gt, gt_need_extract=False): - og_pred = og_pred[-300:] #math500最长answer为159 - match = re.findall(ANSWER_PATTERN, og_pred) - extracted_answer = match[-1] if match else "[INVALID]" - pred = normalize_final_answer(extracted_answer) - if gt_need_extract: - gt = normalize_final_answer(remove_boxed(last_boxed_only_string_v2(gt))) - else: - gt = normalize_final_answer(gt) - # return (pred == gt or is_equiv(pred, gt)), pred - return (pred == gt), pred - - -def is_correct_strict_box(pred, gt, pause_tokens_index): - if pause_tokens_index is not None: - assert len(pause_tokens_index) == 4 - pred = pred[pause_tokens_index[-1] - 100:] - else: - pred = pred[-100:] - pred = last_boxed_only_string_v2(pred) - pred = remove_boxed(pred) if pred is not None else None - return 1 if (pred == gt) else -1, pred - - -def verify(pred, answer, resp_len, max_resp_len, add_int_verify=True, strict_box_verify=False, pause_tokens_index=None) -> bool: - - if strict_box_verify: - corr_strict_box, pred_strict_box = is_correct_strict_box(pred, answer, pause_tokens_index) - return corr_strict_box - - corr_minerva, pred_minerva = is_correct_minerva(pred, - answer) # To remove if math is also converted to interger format - if add_int_verify: - corr_integer, pred_integer = is_correct_integer(pred, answer) - pred = pred_minerva if corr_minerva else pred_integer - corr = corr_minerva or corr_integer - else: - pred = pred_minerva - corr = corr_minerva - - return corr - - -def compute_score(batch_info, solution_str, ground_truth, config, rm_name, pause_tokens_index, **argv) -> tuple[float, dict[str, Any]]: - """ - default行为:对给1,其余给-1 - punish_no_answer: - * v0: 0 - * v1: -0.1 - * v2: -0.2 - """ - prompt_length = batch_info['prompts'].shape[-1] - max_resp_len = batch_info['responses'].shape[-1] - resp_len = sum(batch_info['attention_mask'][prompt_length:].tolist()) - add_int_verify = config.reward_model.add_int_verify - strict_box_verify = config.reward_model.strict_box_verify - - correct = verify(solution_str, ground_truth, resp_len, max_resp_len, add_int_verify, strict_box_verify, - pause_tokens_index) - - final_reward = 1 if correct else -1 - - # extra_info = {"acc": correct, "format/answer": "answer:" in solution_str, "format/boxed": "boxed{" in solution_str, "format/option": False} - - return final_reward - - -if __name__ == "__main__": - - # 不含合法答案的长/短回复,启用不含answer惩罚 - # pred = r"""Another way is to look at the height of one corner of our shape considering the big square we are trying to find where this outer tangent line from the big circle to the adjacent small circles is touching the corners of those small circles at specific heights relative to the center of the big circle which may be related to the radius of the circle such as the bottom of the outer tangent line from the big excess outer circle to the adjacent small circles is some height relative to the center of the big outer circle which directly relates to the radius length of the small circles in relationship to the big outer shape in their outer area. We draw a line from the corner of one of the small circles to the outer edge of the big outer shape and from the center of the small circle perpendicular to the outer edge of the big outer shape so that it hits the outer edge of the big outer shape at a right angle which we call $ h $ which may be related to the radius of the small circle and the shape formed at that corner in relationship to the outer edges of the outer shapes. And the distance from the corner of one of the small circles to the outer edge of the big outer shape which includes the half of this figure that represents how high the outer tangent line from the big excess outer circle to the adjacent small circles is touching the corners of those small circles at specific heights relative to the center of the big circle which includes the radius part of the small circle which we call $ l $ . If we know the height of this line from the corner of the big outer main shape and the line perpendicular to the outer edge of the big outer shape which is what we call it as coming from the corner of the big outer main shape and is hitting the outer edge of the big outer shape with the half of this outer tangent line from the big excess outer circle to the adjacent small circles which we call it with the radius of the small circles which we call this value $ h $ which is what we are looking at in relationship to how that relates to the rest of the features of this outer shape. We might use some trigonometry involving the angle between this half part of the outer tangent line from the big excess outer circle to the adjacent small circles being some angle at one corner of one of the small circles in relationship to the other parts of the outer shape to figure out what the value of the radius of the small circle is. And if we have this value of the height of this line from the corner of the big outer main shape and the line perpendicular to the outer edge of the big outer shape which is what we have already called it as coming from the corner of the big outer main shape and is hitting the outer edge of the big outer shape with the half of this outer tangent line from the big excess outer circle to the adjacent small circles which we call it with the radius of the small circles which we call this value $ h $ which is what we have been using to think about in relationship to how that relates to the rest of the features of this outer shape and the angle at one corner of one of the small circles in relationship to the other parts of the outer shape which we call the "angle between relevant sides value" let's call it $ \theta_{rs} $ . Then using the height of this line from the corner of the big outer main shape and the line perpendicular to the outer edge of the big outer shape which is what we call it as coming from the corner of the big outer main shape and is hitting the outer edge of the big outer shape with the half of this outer tangent line""" - answer = "22+12\sqrt{2}" - pred = r"""Answer:22+12\sqrt{2}""" - assert verify(pred, answer, 0, 16384, False, "v0", strict_box_verify=True) == -1 - - pred = "So we have 235 as answer." - answer = "235" - assert verify(pred, answer, 0, 16384, False, "v0", strict_box_verify=True) == -1 - - pred = "So we have \\boxed{235} as answer." - answer = "235" - assert verify(pred, answer, 0, 16384, False, "v0", strict_box_verify=True) == 1 - - pred = "Answer: 11" - answer = "11" - assert verify(pred, answer, 0, 16384, False, "v0", strict_box_verify=True) == -1 - - pred = "Answer: \\boxed{11}" - answer = "-13" - assert verify(pred, answer, 0, 16384, False, "v0", strict_box_verify=True) == -1 - - pred = "Answer: \\boxed{11}" - answer = "11" - assert verify(pred, answer, 0, 16384, False, "v0", strict_box_verify=True) == 1 From 220d89fecb1220e2d193bd5deaea1e7b8495d4d0 Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Sat, 15 Mar 2025 13:28:38 +0000 Subject: [PATCH 3/3] feat: test script --- recipe/puffin/run_puffin_qwen2.5_7b_test.sh | 89 +++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 recipe/puffin/run_puffin_qwen2.5_7b_test.sh diff --git a/recipe/puffin/run_puffin_qwen2.5_7b_test.sh b/recipe/puffin/run_puffin_qwen2.5_7b_test.sh new file mode 100644 index 00000000000..b95c031e23b --- /dev/null +++ b/recipe/puffin/run_puffin_qwen2.5_7b_test.sh @@ -0,0 +1,89 @@ +#!/usr/bin/env bash +set -euxo pipefail + +project_name='puffin' +exp_name='Qwen2.5-7B-Puffin-Test' + +# Paths +MODEL_PATH=${MODEL_PATH:-"${HOME}/verl/models/Qwen2.5-7B"} +CKPTS_DIR=${CKPTS_DIR:-"${HOME}/verl/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${HOME}/verl/data/puffin_train.parquet"} +TEST_FILE=${TEST_FILE:-"${HOME}/verl/data/puffin_test.parquet"} + +# Algorithm +## Train +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +# TODO +# force_append_eos=True +## Validation +val_top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + + +# Mathematically equivalent +use_dynamic_bsz=True +infer_micro_batch_size=null +train_micro_batch_size=null +offload=True + +python3 -m verl.trainer.main_ppo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=512 \ + data.truncation='left' \ + actor_rollout_ref.rollout.n=16 \ + actor_rollout_ref.actor.kl_loss_coef=0 \ + actor_rollout_ref.actor.clip_ratio=0.2 \ + algorithm.adv_estimator=grpo \ + algorithm.kl_ctrl.kl_coef=0.0 \ + algorithm.gamma=1.0 \ + algorithm.lam=0.95 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + +actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=512 \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.val_kwargs.top_k="${val_top_k}" \ + actor_rollout_ref.rollout.val_kwargs.top_p=0.7\ + actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \ + actor_rollout_ref.rollout.val_kwargs.n=32 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + trainer.logger=['console','wandb'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + +trainer.val_before_train=True \ + trainer.test_freq=2 \ + trainer.save_freq=2 \ + trainer.total_epochs=5000 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=disable \ No newline at end of file