#!/usr/bin/env bash
# Launch GRPO training (verl main_ppo) for Qwen2.5-7B on the Puffin dataset.
# All paths can be overridden via environment variables.
set -euxo pipefail

project_name='puffin'
exp_name='Qwen2.5-7B-Puffin-Test'

# Paths
MODEL_PATH=${MODEL_PATH:-"${HOME}/verl/models/Qwen2.5-7B"}
CKPTS_DIR=${CKPTS_DIR:-"${HOME}/verl/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${HOME}/verl/data/puffin_train.parquet"}
TEST_FILE=${TEST_FILE:-"${HOME}/verl/data/puffin_test.parquet"}

# Algorithm
## Train
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
# TODO
# force_append_eos=True
## Validation
val_top_k=-1 # 0 for HF rollout, -1 for vLLM rollout


# Mathematically equivalent
use_dynamic_bsz=True
infer_micro_batch_size=null
train_micro_batch_size=null
offload=True

# NOTE: overrides prefixed with '+' are Hydra *append* overrides (keys absent
# from the base config); plain overrides replace existing keys.
python3 -m verl.trainer.main_ppo \
    data.train_files="${TRAIN_FILE}" \
    data.val_files="${TEST_FILE}" \
    data.prompt_key=prompt \
    data.truncation='left' \
    data.max_prompt_length=${max_prompt_length} \
    data.max_response_length=${max_response_length} \
    data.train_batch_size=512 \
    actor_rollout_ref.rollout.n=16 \
    actor_rollout_ref.actor.kl_loss_coef=0 \
    actor_rollout_ref.actor.clip_ratio=0.2 \
    algorithm.adv_estimator=grpo \
    algorithm.kl_ctrl.kl_coef=0.0 \
    algorithm.gamma=1.0 \
    algorithm.lam=0.95 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    +actor_rollout_ref.model.override_config.attention_dropout=0. \
    +actor_rollout_ref.model.override_config.embd_pdrop=0. \
    +actor_rollout_ref.model.override_config.resid_pdrop=0. \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
    +actor_rollout_ref.actor.optim.weight_decay=0.1 \
    actor_rollout_ref.actor.ppo_mini_batch_size=512 \
    actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.grad_clip=1.0 \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.val_kwargs.top_k="${val_top_k}" \
    actor_rollout_ref.rollout.val_kwargs.top_p=0.7 \
    actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \
    actor_rollout_ref.rollout.val_kwargs.n=32 \
    actor_rollout_ref.rollout.val_kwargs.do_sample=True \
    actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \
    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
    trainer.logger=['console','wandb'] \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    +trainer.val_before_train=True \
    trainer.test_freq=2 \
    trainer.save_freq=2 \
    trainer.total_epochs=5000 \
trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=disable \ No newline at end of file diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index c903934b8bb..e1dd29d7b91 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -3,6 +3,7 @@ data: train_files: ~/data/rlhf/gsm8k/train.parquet val_files: ~/data/rlhf/gsm8k/test.parquet prompt_key: prompt + reward_fn_key: data_source max_prompt_length: 512 max_response_length: 512 train_batch_size: 1024 diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index eab987d3aad..ac121a02405 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -3,6 +3,7 @@ data: train_files: ~/data/rlhf/gsm8k/train.parquet val_files: ~/data/rlhf/gsm8k/test.parquet prompt_key: prompt + reward_fn_key: data_source max_prompt_length: 512 max_response_length: 512 train_batch_size: 1024 diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index 3d5837a6c16..65a7865c92e 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -146,10 +146,10 @@ def main_task(config): raise NotImplementedError compute_score = get_custom_reward_fn(config) - reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score) + reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key) # Note that we always use function-based RM for validation - val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score) + val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key) resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) diff --git a/verl/utils/reward_score/__init__.py 
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
"""Reward scoring for the ``math_puffin`` data source.

Two verification modes are supported:

* strict boxed matching (exact string match on the last ``\\boxed{...}``), and
* Minerva-style matching (regex answer extraction + normalization, with a
  sympy symbolic-equivalence fallback).

NOTE(review): the dispatcher hunk in ``verl/utils/reward_score/__init__.py``
assigns ``res = math_puffin.compute_score`` without calling it; it should call
``math_puffin.compute_score(solution_str, ground_truth)`` instead — confirm
and fix at that call site.
"""

import re
import signal
from typing import Optional

# sympy is intentionally NOT imported at module level: it is a heavy,
# optional dependency needed only by the symbolic fallback in is_equiv().
# Importing it lazily keeps this module (and strict-box scoring) usable
# without sympy installed.


def last_boxed_only_string(string: str) -> Optional[str]:
    """Extract the last LaTeX boxed expression from a string.

    Args:
        string: Input string containing LaTeX code

    Returns:
        The last ``\\boxed{...}`` expression (braces balanced), or None if
        no boxed expression is found.
    """
    idx = string.rfind("\\boxed{")
    if idx < 0:
        return None

    i = idx
    right_brace_idx = None
    num_left_braces_open = 0

    # Scan forward, tracking brace depth, to find the matching closing brace.
    while i < len(string):
        if string[i] == "{":
            num_left_braces_open += 1
        if string[i] == "}":
            num_left_braces_open -= 1
            if num_left_braces_open == 0:
                right_brace_idx = i
                break
        i += 1

    return string[idx:right_brace_idx + 1] if right_brace_idx is not None else None


def remove_boxed(s: str) -> str:
    """Remove the LaTeX boxed command from a string.

    Args:
        s: String with format "\\boxed{content}"

    Returns:
        The content inside the boxed command

    Raises:
        AssertionError: if *s* is not exactly of the form ``\\boxed{...}``.
    """
    left = "\\boxed{"
    assert s[:len(left)] == left, f"box error: {s}"
    assert s[-1] == "}", f"box error: {s}"
    return s[len(left):-1]


class timeout:
    """Context manager that raises TimeoutError after *seconds*.

    Implemented with SIGALRM, so it only works on Unix and only in the main
    thread — TODO(review): confirm the reward workers call this from the main
    thread.
    """

    def __init__(self, seconds=1, error_message="Timeout"):
        self.seconds = seconds
        self.error_message = error_message

    def handle_timeout(self, signum, frame):
        raise TimeoutError(self.error_message)

    def __enter__(self):
        signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.seconds)

    def __exit__(self, type, value, traceback):
        # Cancel any pending alarm on exit (normal or exceptional).
        signal.alarm(0)


def is_equiv(x1: str, x2: str) -> bool:
    """Return True if two normalized LaTeX strings are symbolically equal.

    Args:
        x1, x2: normalized LaTeX string

    sympy is imported lazily here; if it is missing, the ImportError still
    propagates (matching the original contract) instead of being swallowed
    like other failures, which all yield False.
    """
    try:
        with timeout(seconds=10):
            # Lazy import: only this symbolic fallback needs sympy.
            import sympy
            from sympy.parsing.latex import parse_latex

            try:
                parsed_x1 = parse_latex(x1)
                parsed_x2 = parse_latex(x2)
            except (
                    sympy.parsing.latex.errors.LaTeXParsingError,
                    sympy.SympifyError,
                    TypeError,
            ):
                return False

            try:
                diff = parsed_x1 - parsed_x2
            except TypeError:
                return False

            try:
                # Structural equality against 0 after simplification.
                return sympy.simplify(diff) == 0
            except ValueError:
                return False

    except TimeoutError:
        return False
    except ImportError:
        raise
    except Exception:
        # Any other sympy failure is treated as "not equivalent".
        return False


# Constants for normalization: ordered textual substitutions applied first,
# then whole substrings removed outright.
SUBSTITUTIONS = [
    ("an ", ""), ("a ", ""), (".$", "$"), ("\\$", ""),
    (r"\ ", ""), (" ", ""), ("mbox", "text"),
    (",\\text{and}", ","), ("\\text{and}", ","), ("\\text{m}", "\\text{}"),
]

REMOVED_EXPRESSIONS = [
    "square", "ways", "integers", "dollars", "mph", "inches",
    "hours", "km", "units", "\\ldots", "sue", "points", "feet",
    "minutes", "digits", "cents", "degrees", "cm", "gm", "pounds",
    "meters", "meals", "edges", "students", "childrentickets",
    "multiples", "\\text{s}", "\\text{.}", "\\text{\ns}",
    "\\text{}^2", "\\text{}^3", "\\text{\n}", "\\text{}",
    r"\mathrm{th}", r"^\circ", r"^{\circ}", r"\;", r",\!",
    "{,}", '"', "\\dots",
]


def normalize_final_answer(final_answer: str) -> str:
    """Normalize a final answer to a quantitative reasoning question.

    Args:
        final_answer: The answer string to normalize

    Returns:
        Normalized answer string
    """
    # Keep only the right-hand side of any equation.
    final_answer = final_answer.split("=")[-1]

    # Apply substitutions and removals
    for before, after in SUBSTITUTIONS:
        final_answer = final_answer.replace(before, after)
    for expr in REMOVED_EXPRESSIONS:
        final_answer = final_answer.replace(expr, "")

    # Extract and normalize LaTeX math
    final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
    final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)

    # Normalize shorthand TeX:
    #   \fracab -> \frac{a}{b}
    #   \frac{abc}{bef} -> \frac{abc}{bef}
    #   \fracabc -> \frac{a}{b}c
    #   \sqrta -> \sqrt{a}
    #   \sqrtab -> sqrt{a}b
    final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
    final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
    final_answer = final_answer.replace("$", "")

    # Normalize numbers: drop thousands separators from pure digit strings.
    if final_answer.replace(",", "").isdigit():
        final_answer = final_answer.replace(",", "")

    return final_answer.strip()


def is_correct_minerva(solution_str: str, gt: str, gt_need_extract: bool = False,
                       answer_pattern: str = r"(?i)Answer\s*:\s*([^\n]+)") -> tuple[bool, str]:
    """Check if the solution is correct according to Minerva criteria.

    Args:
        solution_str: The solution string to check
        gt: The ground truth answer
        gt_need_extract: Whether the ground truth needs extraction from a
            ``\\boxed{...}`` expression
        answer_pattern: Regex pattern to extract the answer

    Returns:
        Tuple of (is_correct, normalized_prediction)
    """
    # Extract answer from solution (last "Answer: ..." occurrence wins).
    match = re.findall(answer_pattern, solution_str)
    extracted_answer = match[-1] if match else "[INVALID]"
    pred = normalize_final_answer(extracted_answer)

    # Process ground truth
    if gt_need_extract:
        gt = normalize_final_answer(remove_boxed(last_boxed_only_string(gt)))
    else:
        gt = normalize_final_answer(gt)

    # Fast path: exact string equality; symbolic equivalence only as fallback.
    return (pred == gt or is_equiv(pred, gt)), pred


def is_correct_strict_box(pred: str, gt: str, pause_tokens_index: Optional[list[int]] = None) -> tuple[int, Optional[str]]:
    """Check if the prediction is correct using strict boxed answer criteria.

    Args:
        pred: The prediction string
        gt: The ground truth answer
        pause_tokens_index: Indices of pause tokens

    Returns:
        Tuple of (score, extracted_prediction) where score is 1 if the
        extracted boxed answer equals *gt* exactly, else -1.
    """
    # Only the tail of the prediction can contain the final boxed answer.
    if pause_tokens_index is not None:
        assert len(pause_tokens_index) == 4
        pred = pred[pause_tokens_index[-1] - 100:]
    else:
        pred = pred[-100:]

    # Extract and check the boxed answer
    boxed_pred = last_boxed_only_string(pred)
    extracted_pred = remove_boxed(boxed_pred) if boxed_pred is not None else None

    return 1 if (extracted_pred == gt) else -1, extracted_pred


def verify(solution_str: str, answer: str, strict_box_verify: bool = False,
           pause_tokens_index: Optional[list[int]] = None) -> bool:
    """Verify if the solution is correct.

    Args:
        solution_str: The solution string to verify
        answer: The ground truth answer
        strict_box_verify: Whether to use strict box verification
        pause_tokens_index: Indices of pause tokens

    Returns:
        True if the solution is correct, False otherwise
    """
    if strict_box_verify:
        corr, _ = is_correct_strict_box(solution_str, answer, pause_tokens_index)
        return corr == 1

    corr, _ = is_correct_minerva(solution_str, answer)
    return corr


def compute_score(solution_str: str, ground_truth: str, config=None,
                  pause_tokens_index: Optional[list[int]] = None) -> float:
    """Compute the reward score for a solution.

    Args:
        solution_str: The solution string
        ground_truth: The ground truth answer
        config: Optional configuration object; when it provides
            ``config.reward_model.strict_box_verify`` that flag is honored,
            otherwise lenient Minerva-style verification is used. Made
            optional so the generic ``_default_compute_score`` dispatcher
            (which has no config) can call this directly.
        pause_tokens_index: Indices of pause tokens

    Returns:
        Reward score (1.0 for correct, -1.0 for incorrect)
    """
    # Limit solution length for efficiency
    solution_str = solution_str[-300:]  # The longest answer in MATH-500 has 159 characters

    # Defensive lookup: callers may pass None or a config without reward_model.
    reward_model_cfg = getattr(config, "reward_model", None)
    strict_box_verify = getattr(reward_model_cfg, "strict_box_verify", False)

    correct = verify(solution_str, ground_truth, strict_box_verify, pause_tokens_index)

    return 1.0 if correct else -1.0
""" - def __init__(self, tokenizer, num_examine, compute_score=None) -> None: + def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key='data_source') -> None: self.tokenizer = tokenizer self.num_examine = num_examine # the number of batches of decoded responses to print to the console self.compute_score = compute_score or _default_compute_score + self.reward_fn_key = reward_fn_key def verify(self, data): scores = [] @@ -48,7 +49,7 @@ def verify(self, data): ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth'] - data_source = data_item.non_tensor_batch['data_source'] + data_source = data_item.non_tensor_batch[self.reward_fn_key] extra_info = data_item.non_tensor_batch.get('extra_info', None) @@ -93,7 +94,7 @@ def __call__(self, data: DataProto): ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth'] - data_source = data_item.non_tensor_batch['data_source'] + data_source = data_item.non_tensor_batch[self.reward_fn_key] extra_info = data_item.non_tensor_batch.get('extra_info', None)