From 33f80122174ddc5f0d5b3d7b96dcd59bf3b4ce36 Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Sat, 15 Mar 2025 14:42:46 +0000 Subject: [PATCH 01/22] fix: config chunked prefilling --- recipe/puffin/run_puffin_qwen2.5_7b_test.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/recipe/puffin/run_puffin_qwen2.5_7b_test.sh b/recipe/puffin/run_puffin_qwen2.5_7b_test.sh index dc9e23c527c..df8d93dd0f2 100644 --- a/recipe/puffin/run_puffin_qwen2.5_7b_test.sh +++ b/recipe/puffin/run_puffin_qwen2.5_7b_test.sh @@ -67,6 +67,8 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \ actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ actor_rollout_ref.rollout.val_kwargs.top_k="${val_top_k}" \ actor_rollout_ref.rollout.val_kwargs.top_p=0.7\ actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \ From c87a749d9bb6af37b1eb28bd6fbe8c88187ca813 Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Sun, 16 Mar 2025 05:38:32 +0000 Subject: [PATCH 02/22] feat: more statistics for validation metrics --- verl/trainer/ppo/ray_trainer.py | 64 ++++++++++++++++++++------ verl/utils/reward_score/__init__.py | 4 +- verl/utils/reward_score/math_puffin.py | 8 +++- verl/workers/reward_manager/naive.py | 40 +++++++++++++--- 4 files changed, 93 insertions(+), 23 deletions(-) diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index d8228b1514b..891c4c71f5b 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -24,7 +24,9 @@ from pprint import pprint from typing import Type, Dict from copy import deepcopy +from collections import defaultdict +import pandas as pd import ray import numpy as np from codetiming import Timer @@ -494,8 +496,8 @@ def _maybe_log_val_generations(self, inputs, outputs, scores): self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps) def _validate(self): - reward_tensor_lst = [] data_source_lst = [] + extra_infos_dict: dict[str, list] = defaultdict(list) # Lists to collect samples for the table sample_inputs = [] @@ -511,6 +513,7 @@ def _validate(self): # Store original inputs input_ids = test_batch.batch['input_ids'] + # TODO: Can we keep special tokens except for padding tokens? input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] sample_inputs.extend(input_texts) @@ -549,33 +552,66 @@ def _validate(self): test_batch = test_batch.union(test_output_gen_batch) # evaluate using reward_function - reward_tensor = self.val_reward_fn(test_batch) + result = self.val_reward_fn(test_batch, return_dict=True) + reward_tensor = result["reward_tensor"] + extra_info = result["extra_info"] + for key, lst in extra_info.items(): + extra_infos_dict[key].extend(lst) # Store scores scores = reward_tensor.sum(-1).cpu().tolist() sample_scores.extend(scores) - reward_tensor_lst.append(reward_tensor) data_source_lst.append(test_batch.non_tensor_batch.get('data_source', ['unknown'] * reward_tensor.shape[0])) self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) - reward_tensor = torch.cat(reward_tensor_lst, dim=0).sum(-1).cpu() # (batch_size,) data_sources = np.concatenate(data_source_lst, axis=0) - # evaluate test_score based on data source - data_source_reward = {} - for i in range(reward_tensor.shape[0]): - data_source = data_sources[i] - if data_source not in data_source_reward: - data_source_reward[data_source] = [] - data_source_reward[data_source].append(reward_tensor[i].item()) + sample_df = pd.DataFrame( + { + "data_source": data_sources, + "prompt": sample_inputs, + "response": sample_outputs, + "sum_reward": sample_scores, + **extra_infos_dict, + } + ) metric_dict = {} - for data_source, rewards in data_source_reward.items(): - metric_dict[f'val/test_score/{data_source}'] = np.mean(rewards) - return metric_dict + # Calculate metrics for each data source + # First, identify all numeric columns that might be metrics + num_cols = [col for col in extra_infos_dict.keys() + if isinstance(extra_infos_dict[col][0], (int, float, bool, np.number))] + + + # Group by data_source and calculate statistics + prompt_stats = sample_df.groupby(["data_source", "prompt"]).agg({ + **{col: ["mean", "std", "min", "max"] for col in num_cols}, + "response": "count" # Count responses per prompt + }) + + # This creates a multi-level column index, which we can flatten + prompt_stats.columns = [f"{col}_{stat}" for col, stat in prompt_stats.columns] + prompt_stats = prompt_stats.reset_index() + + # Calculate metrics for each data source + for data_source in sample_df["data_source"].unique(): + # Get stats for this data source + source_stats: pd.DataFrame = prompt_stats[prompt_stats["data_source"] == data_source] + + # Calculate mean of prompt means for each column + for stat_col in prompt_stats.columns: + # Assert each prompt has the same number of responses + uniq_resp_cnts = source_stats["response_count"].unique() + assert len(uniq_resp_cnts) == 1, f"Each prompt must have the same number of responses, but got {uniq_resp_cnts}" + resp_cnt = uniq_resp_cnts[0] + # Mean of means for this column across all prompts in this data source + metric_dict[f"{data_source}/{stat_col}@{resp_cnt}"] = source_stats[stat_col].mean() + + val_metric_dict = {f"val/{key}": value for key, value in metric_dict.items()} + return val_metric_dict def init_workers(self): """Init resource pool and worker group""" diff --git a/verl/utils/reward_score/__init__.py b/verl/utils/reward_score/__init__.py index d72005da500..32bf4b73dd1 100644 --- a/verl/utils/reward_score/__init__.py +++ b/verl/utils/reward_score/__init__.py @@ -43,7 +43,9 @@ def _default_compute_score(data_source, solution_str, ground_truth, extra_info=N else: raise NotImplementedError - if isinstance(res, (int, float, bool)): + if isinstance(res, dict): + return res + elif isinstance(res, (int, float, bool)): return float(res) else: return float(res[0]) diff --git a/verl/utils/reward_score/math_puffin.py b/verl/utils/reward_score/math_puffin.py index e1a4863872f..3f387bea414 100644 --- a/verl/utils/reward_score/math_puffin.py +++ b/verl/utils/reward_score/math_puffin.py @@ -272,5 +272,11 @@ def compute_score(solution_str: str, ground_truth: str, config, pause_tokens_ind # Verify the solution strict_box_verify = config.reward_model.strict_box_verify correct = verify(solution_str, ground_truth, strict_box_verify, pause_tokens_index) + + reward = 1.0 if correct else -1.0 + acc = correct - return 1.0 if correct else -1.0 + return { + "reward": reward, + "acc": acc, + } diff --git a/verl/workers/reward_manager/naive.py b/verl/workers/reward_manager/naive.py index 5206fb48d52..7091dd2f397 100644 --- a/verl/workers/reward_manager/naive.py +++ b/verl/workers/reward_manager/naive.py @@ -15,6 +15,7 @@ from verl import DataProto from verl.utils.reward_score import _default_compute_score import torch +from collections import defaultdict class NaiveRewardManager: @@ -27,6 +28,7 @@ def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key='da self.compute_score = compute_score or _default_compute_score self.reward_fn_key = reward_fn_key + # TODO: Is this still necessary in algorithms other than PRIME? def verify(self, data): scores = [] for i in range(len(data)): @@ -63,14 +65,20 @@ def verify(self, data): data.batch['acc'] = torch.tensor(scores, dtype=torch.float32, device=prompt_ids.device) return scores - def __call__(self, data: DataProto): + def __call__(self, data: DataProto, return_dict: bool = False): """We will expand this function gradually based on the available datasets""" # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn if 'rm_scores' in data.batch.keys(): - return data.batch['rm_scores'] + if return_dict: + return { + "reward": data.batch['rm_scores'] + } + else: + return data.batch['rm_scores'] reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32) + extra_info = defaultdict(list) already_print_data_sources = {} @@ -98,13 +106,24 @@ def __call__(self, data: DataProto): extra_info = data_item.non_tensor_batch.get('extra_info', None) - score = self.compute_score( + result = self.compute_score( data_source=data_source, solution_str=response_str, ground_truth=ground_truth, extra_info=extra_info, ) - reward_tensor[i, valid_response_length - 1] = score + + reward: float + if isinstance(result, dict): + assert "reward" in result + reward = result["reward"] + else: + reward = result + + reward_tensor[i, valid_response_length - 1] = reward + + for key, value in result.items(): + extra_info[key].append(value) if data_source not in already_print_data_sources: already_print_data_sources[data_source] = 0 @@ -114,6 +133,13 @@ def __call__(self, data: DataProto): print("[prompt]", prompt_str) print("[response]", response_str) print("[ground_truth]", ground_truth) - print("[score]", score) - - return reward_tensor + for key, value in result.items(): + print(f"[{key}]", value) + + if return_dict: + return { + "reward_tensor": reward_tensor, + "extra_info": extra_info, + } + else: + return reward_tensor From 49b4c1ee1f10fc6c7bc498db28dbefaff882aea9 Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Sun, 16 Mar 2025 05:41:28 +0000 Subject: [PATCH 03/22] fix: "entropy" as metric name --- verl/workers/actor/dp_actor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index 1057c6d08a6..7116b6545b9 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -322,7 +322,7 @@ def update_policy(self, data: DataProto): loss.backward() data = { - 'actor/entropy_loss': entropy_loss.detach().item(), + 'actor/entropy': entropy_loss.detach().item(), 'actor/pg_loss': pg_loss.detach().item(), 'actor/pg_clipfrac': pg_clipfrac.detach().item(), 'actor/ppo_kl': ppo_kl.detach().item(), From c088d08a4993ae94ca753bacccf987ddf43fa97c Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Sun, 16 Mar 2025 06:53:14 +0000 Subject: [PATCH 04/22] fix: call math_puffin --- verl/utils/reward_score/__init__.py | 2 +- verl/utils/reward_score/math_puffin.py | 12 +++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/verl/utils/reward_score/__init__.py b/verl/utils/reward_score/__init__.py index 32bf4b73dd1..5946dbf2c41 100644 --- a/verl/utils/reward_score/__init__.py +++ b/verl/utils/reward_score/__init__.py @@ -27,7 +27,7 @@ def _default_compute_score(data_source, solution_str, ground_truth, extra_info=N res = math_verify.compute_score(solution_str, ground_truth) elif data_source == 'math_puffin': from . import math_puffin - res = math_puffin.compute_score + res = math_puffin.compute_score(solution_str, ground_truth) elif data_source in [ 'numina_aops_forum', 'numina_synthetic_math', 'numina_amc_aime', 'numina_synthetic_amc', 'numina_cn_k12', 'numina_olympiads' diff --git a/verl/utils/reward_score/math_puffin.py b/verl/utils/reward_score/math_puffin.py index 3f387bea414..55c85f62ac2 100644 --- a/verl/utils/reward_score/math_puffin.py +++ b/verl/utils/reward_score/math_puffin.py @@ -37,7 +37,7 @@ def last_boxed_only_string(string: str) -> Optional[str]: i = idx right_brace_idx = None num_left_braces_open = 0 - + while i < len(string): if string[i] == "{": num_left_braces_open += 1 @@ -254,7 +254,10 @@ def verify(solution_str: str, answer: str, strict_box_verify: bool = False, return corr -def compute_score(solution_str: str, ground_truth: str, config, pause_tokens_index: Optional[list[int]] = None) -> float: +def compute_score(solution_str: str, + ground_truth: str, + strict_box_verify: bool = False, + pause_tokens_index: Optional[list[int]] = None) -> float: """Compute the reward score for a solution. Args: @@ -268,14 +271,13 @@ def compute_score(solution_str: str, ground_truth: str, config, pause_tokens_ind """ # Limit solution length for efficiency solution_str = solution_str[-300:] # The longest answer in MATH-500 has 159 characters - + # Verify the solution - strict_box_verify = config.reward_model.strict_box_verify correct = verify(solution_str, ground_truth, strict_box_verify, pause_tokens_index) reward = 1.0 if correct else -1.0 acc = correct - + return { "reward": reward, "acc": acc, From ccca602b06d8949b6c3234384b9abff4201ab9d2 Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Sun, 16 Mar 2025 07:05:23 +0000 Subject: [PATCH 05/22] fix: test script --- recipe/puffin/run_puffin_qwen2.5_7b_test.sh | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/recipe/puffin/run_puffin_qwen2.5_7b_test.sh b/recipe/puffin/run_puffin_qwen2.5_7b_test.sh index df8d93dd0f2..6252477f9ed 100644 --- a/recipe/puffin/run_puffin_qwen2.5_7b_test.sh +++ b/recipe/puffin/run_puffin_qwen2.5_7b_test.sh @@ -1,6 +1,8 @@ #!/usr/bin/env bash set -euxo pipefail +pip install "antlr4-python3-runtime==4.11.*" + project_name='puffin' exp_name='Qwen2.5-7B-Puffin-Test' @@ -14,12 +16,9 @@ TEST_FILE=${TEST_FILE:-"${HOME}/verl/data/puffin_test.parquet"} ## Train max_prompt_length=$((1024 * 2)) max_response_length=$((1024 * 8)) -# TODO -# force_append_eos=True ## Validation val_top_k=-1 # 0 for HF rollout, -1 for vLLM rollout - # Mathematically equivalent use_dynamic_bsz=True infer_micro_batch_size=null From d761bf75ebe512346e5d7b6f8dda52cb8bd0764b Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Sun, 16 Mar 2025 07:46:59 +0000 Subject: [PATCH 06/22] fix: puffin verifier --- verl/utils/reward_score/math_puffin.py | 44 +------------------------- 1 file changed, 1 insertion(+), 43 deletions(-) diff --git a/verl/utils/reward_score/math_puffin.py b/verl/utils/reward_score/math_puffin.py index 55c85f62ac2..a3f1cea87eb 100644 --- a/verl/utils/reward_score/math_puffin.py +++ b/verl/utils/reward_score/math_puffin.py @@ -17,9 +17,6 @@ import signal from typing import Optional -import sympy -from sympy.parsing.latex import parse_latex - def last_boxed_only_string(string: str) -> Optional[str]: """Extract the last LaTeX boxed expression from a string. @@ -82,45 +79,6 @@ def __enter__(self): def __exit__(self, type, value, traceback): signal.alarm(0) - -def is_equiv(x1: str, x2: str) -> bool: - """ - Args: - x1, x2: normalized LaTeX string - """ - try: - with timeout(seconds=10): - try: - parsed_x1 = parse_latex(x1) - parsed_x2 = parse_latex(x2) - except ( - sympy.parsing.latex.errors.LaTeXParsingError, - sympy.SympifyError, - TypeError, - ): - return False - - try: - diff = parsed_x1 - parsed_x2 - except TypeError: - return False - - try: - if sympy.simplify(diff) == 0: - return True - else: - return False - except ValueError: - return False - - except TimeoutError: - return False - except ImportError as e: - raise - except Exception as e: - return False - - # Constants for normalization SUBSTITUTIONS = [ ("an ", ""), ("a ", ""), (".$", "$"), ("\\$", ""), @@ -205,7 +163,7 @@ def is_correct_minerva(solution_str: str, gt: str, gt_need_extract: bool = False else: gt = normalize_final_answer(gt) - return (pred == gt or is_equiv(pred, gt)), pred + return (pred == gt), pred def is_correct_strict_box(pred: str, gt: str, pause_tokens_index: Optional[list[int]] = None) -> tuple[int, Optional[str]]: From 85b0c8a4558ab8901e113498302359d4c5e8ad96 Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Sun, 16 Mar 2025 07:48:14 +0000 Subject: [PATCH 07/22] feat: test on ray --- ...t.sh => run_puffin_qwen2.5_7b_test_ray.sh} | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) rename recipe/puffin/{run_puffin_qwen2.5_7b_test.sh => run_puffin_qwen2.5_7b_test_ray.sh} (83%) diff --git a/recipe/puffin/run_puffin_qwen2.5_7b_test.sh b/recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh similarity index 83% rename from recipe/puffin/run_puffin_qwen2.5_7b_test.sh rename to recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh index 6252477f9ed..1b8d57988f9 100644 --- a/recipe/puffin/run_puffin_qwen2.5_7b_test.sh +++ b/recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh @@ -1,31 +1,39 @@ #!/usr/bin/env bash set -euxo pipefail -pip install "antlr4-python3-runtime==4.11.*" - project_name='puffin' exp_name='Qwen2.5-7B-Puffin-Test' +# Ray +export RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +export RUNTIME_ENV=${RUNTIME_ENV:-"./verl/trainer/runtime_env.yaml"} +export NNODES=${NNODES:-4} # Paths -MODEL_PATH=${MODEL_PATH:-"${HOME}/verl/models/Qwen2.5-7B"} -CKPTS_DIR=${CKPTS_DIR:-"${HOME}/verl/ckpts/${project_name}/${exp_name}"} -TRAIN_FILE=${TRAIN_FILE:-"${HOME}/verl/data/puffin_train.parquet"} -TEST_FILE=${TEST_FILE:-"${HOME}/verl/data/puffin_test.parquet"} +export RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +export MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-7B"} +export CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +export TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/puffin_train.parquet"} +export TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/puffin_test.parquet"} # Algorithm ## Train max_prompt_length=$((1024 * 2)) max_response_length=$((1024 * 8)) +# TODO +# force_append_eos=True ## Validation val_top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + # Mathematically equivalent use_dynamic_bsz=True infer_micro_batch_size=null train_micro_batch_size=null offload=True -python3 -m verl.trainer.main_ppo \ +ray job submit --runtime-env="${RUNTIME_ENV}" \ + --working-dir "${PWD}" \ + -- python3 -m verl.trainer.main_ppo \ data.train_files="${TRAIN_FILE}" \ data.val_files="${TEST_FILE}" \ data.prompt_key=prompt \ @@ -81,7 +89,7 @@ python3 -m verl.trainer.main_ppo \ trainer.project_name="${project_name}" \ trainer.experiment_name="${exp_name}" \ trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ + trainer.nnodes="${NNODES}" \ +trainer.val_before_train=True \ trainer.test_freq=2 \ trainer.save_freq=2 \ From 77cbbf269475ef067acd6311a8f127d99f304108 Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Sun, 16 Mar 2025 07:50:56 +0000 Subject: [PATCH 08/22] fix: reward_extra_info --- verl/trainer/ppo/ray_trainer.py | 30 ++++++++++++++-------------- verl/workers/reward_manager/naive.py | 6 +++--- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 891c4c71f5b..2460bb3c854 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -497,7 +497,7 @@ def _maybe_log_val_generations(self, inputs, outputs, scores): def _validate(self): data_source_lst = [] - extra_infos_dict: dict[str, list] = defaultdict(list) + reward_extra_infos_dict: dict[str, list] = defaultdict(list) # Lists to collect samples for the table sample_inputs = [] @@ -554,9 +554,9 @@ def _validate(self): # evaluate using reward_function result = self.val_reward_fn(test_batch, return_dict=True) reward_tensor = result["reward_tensor"] - extra_info = result["extra_info"] - for key, lst in extra_info.items(): - extra_infos_dict[key].extend(lst) + reward_extra_info = result["reward_extra_info"] + for key, lst in reward_extra_info.items(): + reward_extra_infos_dict[key].extend(lst) # Store scores scores = reward_tensor.sum(-1).cpu().tolist() @@ -568,22 +568,22 @@ def _validate(self): data_sources = np.concatenate(data_source_lst, axis=0) - sample_df = pd.DataFrame( - { - "data_source": data_sources, - "prompt": sample_inputs, - "response": sample_outputs, - "sum_reward": sample_scores, - **extra_infos_dict, - } - ) + sample_df = pd.DataFrame({ + "data_source": data_sources, + "prompt": sample_inputs, + "response": sample_outputs, + "sum_reward": sample_scores, + **reward_extra_infos_dict, + }) metric_dict = {} # Calculate metrics for each data source # First, identify all numeric columns that might be metrics - num_cols = [col for col in extra_infos_dict.keys() - if isinstance(extra_infos_dict[col][0], (int, float, bool, np.number))] + num_cols = [ + col for col in reward_extra_infos_dict.keys() + if isinstance(reward_extra_infos_dict[col][0], (int, float, bool, np.number)) + ] # Group by data_source and calculate statistics diff --git a/verl/workers/reward_manager/naive.py b/verl/workers/reward_manager/naive.py index 7091dd2f397..ebbd13ebd71 100644 --- a/verl/workers/reward_manager/naive.py +++ b/verl/workers/reward_manager/naive.py @@ -78,7 +78,7 @@ def __call__(self, data: DataProto, return_dict: bool = False): return data.batch['rm_scores'] reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32) - extra_info = defaultdict(list) + reward_extra_info = defaultdict(list) already_print_data_sources = {} @@ -123,7 +123,7 @@ def __call__(self, data: DataProto, return_dict: bool = False): reward_tensor[i, valid_response_length - 1] = reward for key, value in result.items(): - extra_info[key].append(value) + reward_extra_info[key].append(value) if data_source not in already_print_data_sources: already_print_data_sources[data_source] = 0 @@ -139,7 +139,7 @@ def __call__(self, data: DataProto, return_dict: bool = False): if return_dict: return { "reward_tensor": reward_tensor, - "extra_info": extra_info, + "reward_extra_info": reward_extra_info, } else: return reward_tensor From 2fefdcceb6240a5fcde7069c41c7e5b82ed7d621 Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Sun, 16 Mar 2025 08:51:15 +0000 Subject: [PATCH 09/22] feat: BoN / WoN --- verl/trainer/ppo/metric_utils.py | 25 ++++++++- verl/trainer/ppo/ray_trainer.py | 87 +++++++++++++++++--------------- 2 files changed, 70 insertions(+), 42 deletions(-) diff --git a/verl/trainer/ppo/metric_utils.py b/verl/trainer/ppo/metric_utils.py index 4f0859d0058..3fcc9e1bd23 100644 --- a/verl/trainer/ppo/metric_utils.py +++ b/verl/trainer/ppo/metric_utils.py @@ -16,7 +16,7 @@ """ import torch -from typing import Any, Dict, List +from typing import Any, Dict, List, Callable import numpy as np from verl import DataProto @@ -166,3 +166,26 @@ def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n 'perf/time_per_step': time, 'perf/throughput': total_num_tokens / (time * n_gpus), } + +def bootstrap_metric(metric_vals: list[float], + subset_size: int, + reduce_fns: list[Callable[[np.ndarray], float]], + n_bootstrap: int = 1000) -> list[tuple[float, float]]: + """ + Bootstrap the metric to get the confidence interval + """ + metric_vals = np.array(metric_vals) + + # bootstrap_metrics = [] + # for _ in range(n_bootstrap): + # bootstrap_vals = np.random.choice(metric_vals, size=subset_size, replace=True) + + # bootstrap_metrics.append(reduce_fn(bootstrap_vals)) + # return np.mean(bootstrap_metrics), np.std(bootstrap_metrics) + + bootstrap_metric_lsts = [[] for _ in range(len(reduce_fns))] + for _ in range(n_bootstrap): + bootstrap_vals = np.random.choice(metric_vals, size=subset_size, replace=True) + for i, reduce_fn in enumerate(reduce_fns): + bootstrap_metric_lsts[i].append(reduce_fn(bootstrap_vals)) + return [(np.mean(lst), np.std(lst)) for lst in bootstrap_metric_lsts] diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 2460bb3c854..15fee15f01d 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -26,7 +26,6 @@ from copy import deepcopy from collections import defaultdict -import pandas as pd import ray import numpy as np from codetiming import Timer @@ -37,7 +36,7 @@ from verl.single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs from verl.single_controller.ray.base import create_colocated_worker_cls from verl.trainer.ppo import core_algos -from verl.trainer.ppo.metric_utils import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, reduce_metrics +from verl.trainer.ppo.metric_utils import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, reduce_metrics, bootstrap_metric from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn @@ -568,47 +567,53 @@ def _validate(self): data_sources = np.concatenate(data_source_lst, axis=0) - sample_df = pd.DataFrame({ - "data_source": data_sources, - "prompt": sample_inputs, - "response": sample_outputs, - "sum_reward": sample_scores, - **reward_extra_infos_dict, - }) + data_src2prompt2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + for sample_idx, data_source in enumerate(data_sources): + prompt = sample_inputs[sample_idx] + + var2vals = {} + var2vals["reward_sum"].append(sample_scores[sample_idx]) + for metric_name, metric_vals in reward_extra_infos_dict.items(): + var2vals[metric_name].append(metric_vals[sample_idx]) + data_src2prompt2var2vals[data_source][prompt] = var2vals + + data_src2prompt2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) + for data_source, prompt2var2vals in data_src2prompt2var2vals.items(): + for prompt, var2vals in prompt2var2vals.items(): + n_resps = len(var2vals["reward_sum"]) + for var_name, var_vals in var2vals.items(): + metric = {} + + metric[f"mean@{n_resps}"] = np.mean(var_vals) + metric[f"std@{n_resps}"] = np.std(var_vals) + + ns = [] + n = 2 + while n < n_resps: + ns.append(n) + n *= 2 + ns.append(n_resps) + + for n in ns: + (bon_mean, bon_std), (won_mean, won_std) = bootstrap_metric(var_vals, n, [np.max, np.min]) + metric[f"best@{n}/mean"], metric[f"best@{n}/std"] = bon_mean, bon_std + metric[f"worst@{n}/mean"], metric[f"worst@{n}/std"] = won_mean, won_std + + data_src2prompt2var2metric[data_source][prompt][var_name] = metric + + data_src2var2metric2prompt_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + for data_source, prompt2var2metric in data_src2prompt2var2metric.items(): + for prompt, var2metric in prompt2var2metric.items(): + for metric_name, metric in var2metric.items(): + for metric_name, metric_val in metric.items(): + data_src2var2metric2prompt_vals[data_source][metric_name][metric_name].append(metric_val) metric_dict = {} - - # Calculate metrics for each data source - # First, identify all numeric columns that might be metrics - num_cols = [ - col for col in reward_extra_infos_dict.keys() - if isinstance(reward_extra_infos_dict[col][0], (int, float, bool, np.number)) - ] - - - # Group by data_source and calculate statistics - prompt_stats = sample_df.groupby(["data_source", "prompt"]).agg({ - **{col: ["mean", "std", "min", "max"] for col in num_cols}, - "response": "count" # Count responses per prompt - }) - - # This creates a multi-level column index, which we can flatten - prompt_stats.columns = [f"{col}_{stat}" for col, stat in prompt_stats.columns] - prompt_stats = prompt_stats.reset_index() - - # Calculate metrics for each data source - for data_source in sample_df["data_source"].unique(): - # Get stats for this data source - source_stats: pd.DataFrame = prompt_stats[prompt_stats["data_source"] == data_source] - - # Calculate mean of prompt means for each column - for stat_col in prompt_stats.columns: - # Assert each prompt has the same number of responses - uniq_resp_cnts = source_stats["response_count"].unique() - assert len(uniq_resp_cnts) == 1, f"Each prompt must have the same number of responses, but got {uniq_resp_cnts}" - resp_cnt = uniq_resp_cnts[0] - # Mean of means for this column across all prompts in this data source - metric_dict[f"{data_source}/{stat_col}@{resp_cnt}"] = source_stats[stat_col].mean() + for data_source, var2metric2prompt_vals in data_src2var2metric2prompt_vals.items(): + for metric_name, metric2prompt_vals in var2metric2prompt_vals.items(): + for metric_name, prompt_vals in metric2prompt_vals.items(): + pfx = f"{data_source}/{metric_name}/{metric_name}" + metric_dict[pfx] = np.mean(prompt_vals) val_metric_dict = {f"val/{key}": value for key, value in metric_dict.items()} return val_metric_dict From 08d63ce17fcb71ba259d26b3211fa8980a346803 Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Sun, 16 Mar 2025 08:52:59 +0000 Subject: [PATCH 10/22] feat: same top_p when validating --- recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh b/recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh index 1b8d57988f9..405f29293cf 100644 --- a/recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh +++ b/recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh @@ -77,7 +77,7 @@ ray job submit --runtime-env="${RUNTIME_ENV}" \ actor_rollout_ref.rollout.enable_chunked_prefill=True \ actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ actor_rollout_ref.rollout.val_kwargs.top_k="${val_top_k}" \ - actor_rollout_ref.rollout.val_kwargs.top_p=0.7\ + actor_rollout_ref.rollout.val_kwargs.top_p=1.0\ actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \ actor_rollout_ref.rollout.val_kwargs.n=1 \ actor_rollout_ref.rollout.val_kwargs.do_sample=True \ From 1b43855df8414f5baf236c29332ed598f94da2db Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Sun, 16 Mar 2025 08:54:17 +0000 Subject: [PATCH 11/22] feat: separate clip epsilon --- recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh b/recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh index 405f29293cf..a52d52109e9 100644 --- a/recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh +++ b/recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh @@ -44,7 +44,8 @@ ray job submit --runtime-env="${RUNTIME_ENV}" \ data.truncation='left' \ actor_rollout_ref.rollout.n=16 \ actor_rollout_ref.actor.kl_loss_coef=0 \ - actor_rollout_ref.actor.clip_ratio=0.2 \ + actor_rollout_ref.actor.clip_ratio_low=0.2 \ + actor_rollout_ref.actor.clip_ratio_high=0.25 \ algorithm.adv_estimator=grpo \ algorithm.kl_ctrl.kl_coef=0.0 \ algorithm.gamma=1.0 \ From 48a251cd1c5ad776529f76176c94aca5fa391b2f Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Sun, 16 Mar 2025 08:56:49 +0000 Subject: [PATCH 12/22] fix: remove force_append_eos --- recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh b/recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh index a52d52109e9..66de7012f0b 100644 --- a/recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh +++ b/recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh @@ -19,8 +19,6 @@ export TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/puffin_test.parquet"} ## Train max_prompt_length=$((1024 * 2)) max_response_length=$((1024 * 8)) -# TODO -# force_append_eos=True ## Validation val_top_k=-1 # 0 for HF rollout, -1 for vLLM rollout From adfaf25b08ce0d699512307a8e0a238d0090e9a3 Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Sun, 16 Mar 2025 09:02:31 +0000 Subject: [PATCH 13/22] chore: disable offload for 7B on 4 nodes --- recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh b/recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh index 66de7012f0b..65a37cfd36a 100644 --- a/recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh +++ b/recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh @@ -27,7 +27,7 @@ val_top_k=-1 # 0 for HF rollout, -1 for vLLM rollout use_dynamic_bsz=True infer_micro_batch_size=null train_micro_batch_size=null -offload=True +offload=False ray job submit --runtime-env="${RUNTIME_ENV}" \ --working-dir "${PWD}" \ From 71ddbe79bf38d9f2878f70ca1f14cf22c9524b0a Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Sun, 16 Mar 2025 09:18:49 +0000 Subject: [PATCH 14/22] chore: no-wait --- recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh b/recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh index 65a37cfd36a..bd490c61e5c 100644 --- a/recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh +++ b/recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh @@ -29,7 +29,7 @@ infer_micro_batch_size=null train_micro_batch_size=null offload=False -ray job submit --runtime-env="${RUNTIME_ENV}" \ +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ --working-dir "${PWD}" \ -- python3 -m verl.trainer.main_ppo \ data.train_files="${TRAIN_FILE}" \ From 4888a0b741ca067ae41e0dc9e0b85cb67ce4b7cf Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Sun, 16 Mar 2025 09:20:03 +0000 Subject: [PATCH 15/22] feat: use Qwen2.5-Math-7B for test --- ...uffin_qwen2.5_7b_test_ray.sh => run_puffin_7b_test_ray.sh} | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) rename recipe/puffin/{run_puffin_qwen2.5_7b_test_ray.sh => run_puffin_7b_test_ray.sh} (98%) diff --git a/recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh b/recipe/puffin/run_puffin_7b_test_ray.sh similarity index 98% rename from recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh rename to recipe/puffin/run_puffin_7b_test_ray.sh index bd490c61e5c..873155a7f85 100644 --- a/recipe/puffin/run_puffin_qwen2.5_7b_test_ray.sh +++ b/recipe/puffin/run_puffin_7b_test_ray.sh @@ -2,7 +2,7 @@ set -euxo pipefail project_name='puffin' -exp_name='Qwen2.5-7B-Puffin-Test' +exp_name='Qwen2.5-7B-Math-Puffin-Test' # Ray export RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} @@ -10,7 +10,7 @@ export RUNTIME_ENV=${RUNTIME_ENV:-"./verl/trainer/runtime_env.yaml"} export NNODES=${NNODES:-4} # Paths export RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} -export MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-7B"} +export MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} export CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} export TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/puffin_train.parquet"} export TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/puffin_test.parquet"} From b866eab1b12e6346522803207e3190ee663f292c Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Sun, 16 Mar 2025 09:24:56 +0000 Subject: [PATCH 16/22] feat: majority voting --- verl/trainer/ppo/metric_utils.py | 29 +++++++++++++++++--------- verl/trainer/ppo/ray_trainer.py | 18 ++++++++++++---- verl/utils/reward_score/math_puffin.py | 12 +++++------ 3 files changed, 39 insertions(+), 20 deletions(-) diff --git a/verl/trainer/ppo/metric_utils.py b/verl/trainer/ppo/metric_utils.py index 3fcc9e1bd23..9ceeef9e68a 100644 --- a/verl/trainer/ppo/metric_utils.py +++ b/verl/trainer/ppo/metric_utils.py @@ -19,6 +19,7 @@ from typing import Any, Dict, List, Callable import numpy as np from verl import DataProto +from collections import Counter def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]: @@ -167,25 +168,33 @@ def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n 'perf/throughput': total_num_tokens / (time * n_gpus), } -def bootstrap_metric(metric_vals: list[float], +def bootstrap_metric(vals: list[float], subset_size: int, reduce_fns: list[Callable[[np.ndarray], float]], n_bootstrap: int = 1000) -> list[tuple[float, float]]: """ Bootstrap the metric to get the confidence interval """ - metric_vals = np.array(metric_vals) - - # bootstrap_metrics = [] - # for _ in range(n_bootstrap): - # bootstrap_vals = np.random.choice(metric_vals, size=subset_size, replace=True) - - # bootstrap_metrics.append(reduce_fn(bootstrap_vals)) - # return np.mean(bootstrap_metrics), np.std(bootstrap_metrics) + vals = np.array(vals) bootstrap_metric_lsts = [[] for _ in range(len(reduce_fns))] for _ in range(n_bootstrap): - bootstrap_vals = np.random.choice(metric_vals, size=subset_size, replace=True) + bootstrap_vals = np.random.choice(vals, size=subset_size, replace=True) for i, reduce_fn in enumerate(reduce_fns): bootstrap_metric_lsts[i].append(reduce_fn(bootstrap_vals)) return [(np.mean(lst), np.std(lst)) for lst in bootstrap_metric_lsts] + +def calc_maj_val(val_key_pairs: list[tuple[float, str]]) -> float: + """ + Calculate the majority voting metric + """ + keys = [pair[1] for pair in val_key_pairs] + # TODO: use clustering method beyond string + key_counter = Counter(keys) + maj_key = key_counter.most_common(1)[0][0] + + for val, key in val_key_pairs: + if key == maj_key: + return val + + raise ValueError("No majority voting metric found") diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 15fee15f01d..d42450d63ad 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -36,7 +36,7 @@ from verl.single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs from verl.single_controller.ray.base import create_colocated_worker_cls from verl.trainer.ppo import core_algos -from verl.trainer.ppo.metric_utils import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, reduce_metrics, bootstrap_metric +from verl.trainer.ppo.metric_utils import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, reduce_metrics, bootstrap_metric, calc_maj_val from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn @@ -571,17 +571,19 @@ def _validate(self): for sample_idx, data_source in enumerate(data_sources): prompt = sample_inputs[sample_idx] - var2vals = {} + var2vals = data_src2prompt2var2vals[data_source][prompt] var2vals["reward_sum"].append(sample_scores[sample_idx]) for metric_name, metric_vals in reward_extra_infos_dict.items(): var2vals[metric_name].append(metric_vals[sample_idx]) - data_src2prompt2var2vals[data_source][prompt] = var2vals data_src2prompt2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) for data_source, prompt2var2vals in data_src2prompt2var2vals.items(): for prompt, var2vals in prompt2var2vals.items(): n_resps = len(var2vals["reward_sum"]) + preds = var2vals["pred"] for var_name, var_vals in var2vals.items(): + if var_name in ["pred", "reward_sum"]: + continue metric = {} metric[f"mean@{n_resps}"] = np.mean(var_vals) @@ -594,10 +596,18 @@ def _validate(self): n *= 2 ns.append(n_resps) + val_pred_pairs = list(zip(var_vals, preds)) for n in ns: - (bon_mean, bon_std), (won_mean, won_std) = bootstrap_metric(var_vals, n, [np.max, np.min]) + (bon_mean, bon_std), (won_mean, won_std), (maj_n_mean, maj_n_std) = bootstrap_metric( + vals=val_pred_pairs, + subset_size=n, + reduce_fns=[ + lambda arr: np.max([x[0] for x in arr]), lambda arr: np.min([x[0] for x in arr]), + calc_maj_val + ]) metric[f"best@{n}/mean"], metric[f"best@{n}/std"] = bon_mean, bon_std metric[f"worst@{n}/mean"], metric[f"worst@{n}/std"] = won_mean, won_std + metric[f"maj@{n}/mean"], metric[f"maj@{n}/std"] = maj_n_mean, maj_n_std data_src2prompt2var2metric[data_source][prompt][var_name] = metric diff --git a/verl/utils/reward_score/math_puffin.py b/verl/utils/reward_score/math_puffin.py index a3f1cea87eb..752eafe201c 100644 --- a/verl/utils/reward_score/math_puffin.py +++ b/verl/utils/reward_score/math_puffin.py @@ -205,12 +205,11 @@ def verify(solution_str: str, answer: str, strict_box_verify: bool = False, True if the solution is correct, False otherwise """ if strict_box_verify: - corr, _ = is_correct_strict_box(solution_str, answer, pause_tokens_index) - return corr == 1 - - corr, _ = is_correct_minerva(solution_str, answer) - return corr + correct, pred = is_correct_strict_box(solution_str, answer, pause_tokens_index) + return correct == 1, pred + correct, pred = is_correct_minerva(solution_str, answer) + return correct, pred def compute_score(solution_str: str, ground_truth: str, @@ -231,7 +230,7 @@ def compute_score(solution_str: str, solution_str = solution_str[-300:] # The longest answer in MATH-500 has 159 characters # Verify the solution - correct = verify(solution_str, ground_truth, strict_box_verify, pause_tokens_index) + correct, pred = verify(solution_str, ground_truth, strict_box_verify, pause_tokens_index) reward = 1.0 if correct else -1.0 acc = correct @@ -239,4 +238,5 @@ def compute_score(solution_str: str, return { "reward": reward, "acc": acc, + "pred": pred, } From 0bde9e46a8ac049058d7a1b341e94132e9503aaa Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Sun, 16 Mar 2025 09:28:44 +0000 Subject: [PATCH 17/22] fix: context length for Qwen2.5-Math-7B --- recipe/puffin/run_puffin_7b_test_ray.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/recipe/puffin/run_puffin_7b_test_ray.sh b/recipe/puffin/run_puffin_7b_test_ray.sh index 873155a7f85..39a19453b82 100644 --- a/recipe/puffin/run_puffin_7b_test_ray.sh +++ b/recipe/puffin/run_puffin_7b_test_ray.sh @@ -17,8 +17,8 @@ export TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/puffin_test.parquet"} # Algorithm ## Train -max_prompt_length=$((1024 * 2)) -max_response_length=$((1024 * 8)) +max_prompt_length=$((1024 * 1)) +max_response_length=$((1024 * 3)) ## Validation val_top_k=-1 # 0 for HF rollout, -1 for vLLM rollout From 1fe5eadb09d68aaef2df24db9588747b3617337b Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Sun, 16 Mar 2025 09:50:17 +0000 Subject: [PATCH 18/22] fix: majority voting --- verl/trainer/ppo/metric_utils.py | 37 +++++++++++++++++--------------- verl/trainer/ppo/ray_trainer.py | 11 ++++++---- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/verl/trainer/ppo/metric_utils.py b/verl/trainer/ppo/metric_utils.py index 9ceeef9e68a..49f54d1a614 100644 --- a/verl/trainer/ppo/metric_utils.py +++ b/verl/trainer/ppo/metric_utils.py @@ -19,7 +19,7 @@ from typing import Any, Dict, List, Callable import numpy as np from verl import DataProto -from collections import Counter +from collections import Counter, defaultdict def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]: @@ -168,33 +168,36 @@ def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n 'perf/throughput': total_num_tokens / (time * n_gpus), } -def bootstrap_metric(vals: list[float], +def bootstrap_metric(data: list[dict[str, Any]], subset_size: int, reduce_fns: list[Callable[[np.ndarray], float]], - n_bootstrap: int = 1000) -> list[tuple[float, float]]: + n_bootstrap: int = 1000, + seed: int = 42 + ) -> list[tuple[float, float]]: """ Bootstrap the metric to get the confidence interval """ - vals = np.array(vals) + np.random.seed(seed) bootstrap_metric_lsts = [[] for _ in range(len(reduce_fns))] for _ in range(n_bootstrap): - bootstrap_vals = np.random.choice(vals, size=subset_size, replace=True) + bootstrap_idxs = np.random.choice(len(data), size=subset_size, replace=True) + bootstrap_data = [data[i] for i in bootstrap_idxs] for i, reduce_fn in enumerate(reduce_fns): - bootstrap_metric_lsts[i].append(reduce_fn(bootstrap_vals)) + bootstrap_metric_lsts[i].append(reduce_fn(bootstrap_data)) return [(np.mean(lst), np.std(lst)) for lst in bootstrap_metric_lsts] -def calc_maj_val(val_key_pairs: list[tuple[float, str]]) -> float: +def calc_maj_val(data: list[dict[str, Any]], vote_key: str, val_key: str) -> float: """ Calculate the majority voting metric """ - keys = [pair[1] for pair in val_key_pairs] - # TODO: use clustering method beyond string - key_counter = Counter(keys) - maj_key = key_counter.most_common(1)[0][0] - - for val, key in val_key_pairs: - if key == maj_key: - return val - - raise ValueError("No majority voting metric found") + vote2vals = defaultdict(list) + for d in data: + vote2vals[d[vote_key]].append(d[val_key]) + + vote2cnt = {k: len(v) for k, v in vote2vals.items()} + maj_vote = max(vote2cnt, key=vote2cnt.get) + + maj_val = vote2vals[maj_vote][0] + + return maj_val diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index d42450d63ad..90239c49cdc 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -25,6 +25,7 @@ from typing import Type, Dict from copy import deepcopy from collections import defaultdict +from functools import partial import ray import numpy as np @@ -596,14 +597,16 @@ def _validate(self): n *= 2 ns.append(n_resps) - val_pred_pairs = list(zip(var_vals, preds)) + data = [{"val": val, "pred": pred} for val, pred in zip(var_vals, preds)] for n in ns: + (bon_mean, bon_std), (won_mean, won_std), (maj_n_mean, maj_n_std) = bootstrap_metric( - vals=val_pred_pairs, + data, subset_size=n, reduce_fns=[ - lambda arr: np.max([x[0] for x in arr]), lambda arr: np.min([x[0] for x in arr]), - calc_maj_val + lambda arr: np.max([d["val"] for d in arr]), + lambda arr: np.min([d["val"] for d in arr]), + partial(calc_maj_val, vote_key="pred", val_key="val") ]) metric[f"best@{n}/mean"], metric[f"best@{n}/std"] = bon_mean, bon_std metric[f"worst@{n}/mean"], metric[f"worst@{n}/std"] = won_mean, won_std From 118674b14c555038aac8f4fddedc8afabab0722d Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Sun, 16 Mar 2025 10:08:11 +0000 Subject: [PATCH 19/22] fix: var_name --- verl/trainer/ppo/ray_trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 90239c49cdc..0fbcc515e1f 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -619,13 +619,13 @@ def _validate(self): for prompt, var2metric in prompt2var2metric.items(): for metric_name, metric in var2metric.items(): for metric_name, metric_val in metric.items(): - data_src2var2metric2prompt_vals[data_source][metric_name][metric_name].append(metric_val) + data_src2var2metric2prompt_vals[data_source][var_name][metric_name].append(metric_val) metric_dict = {} for data_source, var2metric2prompt_vals in data_src2var2metric2prompt_vals.items(): - for metric_name, metric2prompt_vals in var2metric2prompt_vals.items(): + for var_name, metric2prompt_vals in var2metric2prompt_vals.items(): for metric_name, prompt_vals in metric2prompt_vals.items(): - pfx = f"{data_source}/{metric_name}/{metric_name}" + pfx = f"{data_source}/{var_name}/{metric_name}" metric_dict[pfx] = np.mean(prompt_vals) val_metric_dict = {f"val/{key}": value for key, value in metric_dict.items()} From 0df61dc51b07fa4146f733069c5d59ff2a57bb98 Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Sun, 16 Mar 2025 11:37:35 +0000 Subject: [PATCH 20/22] fix: eos_token --- verl/workers/reward_manager/naive.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/verl/workers/reward_manager/naive.py b/verl/workers/reward_manager/naive.py index ebbd13ebd71..177cea0c732 100644 --- a/verl/workers/reward_manager/naive.py +++ b/verl/workers/reward_manager/naive.py @@ -48,6 +48,9 @@ def verify(self, data): # decode prompt_str = self.tokenizer.decode(valid_prompt_ids) response_str = self.tokenizer.decode(valid_response_ids) + eos_token = self.tokenizer.eos_token + if response_str.endswith(eos_token): + response_str = response_str[:-len(eos_token)] ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth'] From fe6e2fd86aeb5a7c023b995224eb764bfb02514f Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Sun, 16 Mar 2025 11:41:00 +0000 Subject: [PATCH 21/22] fix: eos_token --- verl/workers/reward_manager/naive.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/verl/workers/reward_manager/naive.py b/verl/workers/reward_manager/naive.py index 177cea0c732..b67c176da4d 100644 --- a/verl/workers/reward_manager/naive.py +++ b/verl/workers/reward_manager/naive.py @@ -48,9 +48,6 @@ def verify(self, data): # decode prompt_str = self.tokenizer.decode(valid_prompt_ids) response_str = self.tokenizer.decode(valid_response_ids) - eos_token = self.tokenizer.eos_token - if response_str.endswith(eos_token): - response_str = response_str[:-len(eos_token)] ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth'] @@ -102,6 +99,9 @@ def __call__(self, data: DataProto, return_dict: bool = False): # decode prompt_str = self.tokenizer.decode(valid_prompt_ids) response_str = self.tokenizer.decode(valid_response_ids) + eos_token = self.tokenizer.eos_token + if response_str.endswith(eos_token): + response_str = response_str[:-len(eos_token)] ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth'] From dc23c23f53a0f8dd8b22f1dfa3a064ad823d1e4a Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Sun, 16 Mar 2025 20:59:56 +0800 Subject: [PATCH 22/22] [reward] feat: overlong buffer (#621) --- recipe/puffin/run_puffin_7b_test_ray.sh | 15 +++- verl/protocol.py | 35 ++++++++ verl/trainer/config/ppo_trainer.yaml | 11 ++- verl/trainer/main_ppo.py | 8 +- verl/trainer/ppo/ray_trainer.py | 106 ++++++++++++++++++------ verl/workers/reward_manager/naive.py | 25 +++++- 6 files changed, 164 insertions(+), 36 deletions(-) diff --git a/recipe/puffin/run_puffin_7b_test_ray.sh b/recipe/puffin/run_puffin_7b_test_ray.sh index 39a19453b82..77c9c8de894 100644 --- a/recipe/puffin/run_puffin_7b_test_ray.sh +++ b/recipe/puffin/run_puffin_7b_test_ray.sh @@ -19,6 +19,9 @@ export TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/puffin_test.parquet"} ## Train max_prompt_length=$((1024 * 1)) max_response_length=$((1024 * 3)) +gen_prompt_bsz=512 +train_prompt_bsz=512 +train_prompt_mini_bsz=32 ## Validation val_top_k=-1 # 0 for HF rollout, -1 for vLLM rollout @@ -38,7 +41,8 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ data.truncation='left' \ data.max_prompt_length=${max_prompt_length} \ data.max_response_length=${max_response_length} \ - data.train_batch_size=512 \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ data.truncation='left' \ actor_rollout_ref.rollout.n=16 \ actor_rollout_ref.actor.kl_loss_coef=0 \ @@ -46,8 +50,9 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ actor_rollout_ref.actor.clip_ratio_high=0.25 \ algorithm.adv_estimator=grpo \ algorithm.kl_ctrl.kl_coef=0.0 \ - algorithm.gamma=1.0 \ - algorithm.lam=0.95 \ + algorithm.filter_groups.enable=True \ + algorithm.filter_groups.fill_train_batch=True \ + algorithm.filter_groups.drop_last_mini_batch=True \ actor_rollout_ref.model.use_remove_padding=True \ actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ @@ -63,7 +68,7 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ actor_rollout_ref.actor.optim.weight_decay=0.1 \ - actor_rollout_ref.actor.ppo_mini_batch_size=512 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ @@ -84,6 +89,8 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \ actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + custom_reward_function.overlong_buffer.len=512 \ + custom_reward_function.overlong_buffer.penalty_factor=1.0 \ trainer.logger=['console','wandb'] \ trainer.project_name="${project_name}" \ trainer.experiment_name="${exp_name}" \ diff --git a/verl/protocol.py b/verl/protocol.py index 94d7410f687..a525feb9bb2 100644 --- a/verl/protocol.py +++ b/verl/protocol.py @@ -370,6 +370,41 @@ def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=Non return DataProto(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info) + def sel_idxs(self, idxs): + """ + Select specific indices from the DataProto. + + Args: + idxs (torch.Tensor or numpy.ndarray or list): Indices to select + + Returns: + DataProto: A new DataProto containing only the selected indices + """ + if isinstance(idxs, list): + idxs = torch.tensor(idxs, dtype=torch.int32) + + if isinstance(idxs, np.ndarray): + idxs_np = idxs + idxs_torch = torch.from_numpy(idxs) + else: # torch.Tensor + idxs_torch = idxs + idxs_np = idxs.detach().cpu().numpy() + + if self.batch is not None: + # Use TensorDict's built-in indexing capabilities + selected_batch = TensorDict(source={ + key: tensor[idxs_torch] for key, tensor in self.batch.items() + }, + batch_size=(idxs_torch.shape[0],)) + else: + selected_batch = None + + selected_non_tensor = {} + for key, val in self.non_tensor_batch.items(): + selected_non_tensor[key] = val[idxs_np] + + return DataProto(batch=selected_batch, non_tensor_batch=selected_non_tensor, meta_info=self.meta_info) + def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> 'DataProto': """Pop a subset of the DataProto via `batch_keys` and `meta_info_keys` diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 40dd2e6fe94..54dfb619ca6 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -6,7 +6,8 @@ data: reward_fn_key: data_source max_prompt_length: 512 max_response_length: 512 - train_batch_size: 1024 + gen_batch_size: 1024 + train_batch_size: ${data.gen_batch_size} val_batch_size: null # DEPRECATED: Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs return_raw_chat: False @@ -167,6 +168,10 @@ reward_model: custom_reward_function: path: null name: compute_score + overlong_buffer: + len: 0 + penalty_factor: 1.0 + log: False algorithm: gamma: 1.0 @@ -176,6 +181,10 @@ algorithm: kl_ctrl: type: fixed kl_coef: 0.001 + filter_groups: + enable: False + fill_train_batch: True + drop_last_mini_batch: True trainer: balance_batch: True diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index 65a7865c92e..31623ad7b30 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -146,10 +146,14 @@ def main_task(config): raise NotImplementedError compute_score = get_custom_reward_fn(config) - reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key) + reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key, + max_resp_len=config.data.max_response_length, + overlong_buffer_cfg=config.custom_reward_function.overlong_buffer) # Note that we always use function-based RM for validation - val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key) + val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key, + max_resp_len=config.data.max_response_length, + overlong_buffer_cfg=config.custom_reward_function.overlong_buffer) resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 0fbcc515e1f..e091d23c45f 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -316,6 +316,10 @@ def _validate_config(self): # number of GPUs total n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes + if not config.algorithm.filter_groups.enable: + assert config.data.train_batch_size == config.data.gen_batch_size, \ + f"train_batch_size must be equal to gen_batch_size when filter_groups.enable is False, but got {config.data.train_batch_size =} and {config.data.gen_batch_size =}" + # 1. Check total batch size for data correctness real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n assert real_train_batch_size % n_gpus == 0, \ @@ -424,7 +428,7 @@ def _create_dataloader(self): sampler = SequentialSampler(data_source=self.train_dataset) self.train_dataloader = StatefulDataLoader(dataset=self.train_dataset, - batch_size=self.config.data.train_batch_size, + batch_size=self.config.data.gen_batch_size, num_workers=8, drop_last=True, collate_fn=collate_fn, @@ -573,17 +577,17 @@ def _validate(self): prompt = sample_inputs[sample_idx] var2vals = data_src2prompt2var2vals[data_source][prompt] - var2vals["reward_sum"].append(sample_scores[sample_idx]) + var2vals["final_reward"].append(sample_scores[sample_idx]) for metric_name, metric_vals in reward_extra_infos_dict.items(): var2vals[metric_name].append(metric_vals[sample_idx]) data_src2prompt2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) for data_source, prompt2var2vals in data_src2prompt2var2vals.items(): for prompt, var2vals in prompt2var2vals.items(): - n_resps = len(var2vals["reward_sum"]) + n_resps = len(var2vals["final_reward"]) preds = var2vals["pred"] for var_name, var_vals in var2vals.items(): - if var_name in ["pred", "reward_sum"]: + if var_name in ["pred", "final_reward"]: continue metric = {} @@ -617,7 +621,7 @@ def _validate(self): data_src2var2metric2prompt_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) for data_source, prompt2var2metric in data_src2prompt2var2metric.items(): for prompt, var2metric in prompt2var2metric.items(): - for metric_name, metric in var2metric.items(): + for var_name, metric in var2metric.items(): for metric_name, metric_val in metric.items(): data_src2var2metric2prompt_vals[data_source][var_name][metric_name].append(metric_val) @@ -881,6 +885,77 @@ def fit(self): batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) batch = batch.union(gen_batch_output) + with _timer('reward', timing_raw): + # compute scores. Support both model and function-based. + # We first compute the scores using reward model. Then, we call reward_fn to combine + # the results from reward model and rule-based results. + if self.use_rm: + # we first compute reward model score + reward_tensor = self.rm_wg.compute_rm_score(batch) + batch = batch.union(reward_tensor) + + # we combine with rule-based rm + reward_tensor = self.reward_fn(batch) + batch.batch['token_level_scores'] = reward_tensor + + # compute rewards. apply_kl_penalty if available + if not self.config.actor_rollout_ref.actor.get('use_kl_loss', False): + batch, kl_metrics = apply_kl_penalty(batch, + kl_ctrl=self.kl_ctrl, + kl_penalty=self.config.algorithm.kl_penalty) + metrics.update(kl_metrics) + else: + batch.batch['token_level_rewards'] = batch.batch['token_level_scores'] + + if self.config.algorithm.filter_groups.enable: + filter_metric_dict = {} + + uid2seq_rewards = defaultdict(list) + for uid, tok_rewards in zip(batch.non_tensor_batch['uid'], batch.batch['token_level_rewards']): + seq_reward = torch.sum(tok_rewards).item() + uid2seq_rewards[uid].append(seq_reward) + + uid2seq_reward_std = {} + for uid, seq_rewards in uid2seq_rewards.items(): + uid2seq_reward_std[uid] = np.std(seq_rewards) + + kept_uids = [uid for uid, std in uid2seq_reward_std.items() if std > 0] + filter_metric_dict["non_uni_rew_prompt_ratio"] = len(kept_uids) / len(uid2seq_rewards) + filter_metric_dict["non_uni_rew_prompt_bsz"] = len(kept_uids) + + kept_idxs = [] + + + train_prompt_bsz = len(batch.batch) + fill_train_batch = self.config.algorithm.filter_groups.fill_train_batch + if len(kept_uids) > train_prompt_bsz or not fill_train_batch: + kept_uids = kept_uids[:train_prompt_bsz] + else: + for uid in uid2seq_reward_std.keys(): + if uid not in kept_uids: + kept_uids.append(uid) + if len(kept_uids) == train_prompt_bsz: + break + + for idx, uid in enumerate(batch.non_tensor_batch['uid']): + if uid in kept_uids: + kept_idxs.append(idx) + filter_metric_dict["non_uni_rew_traj_bsz"] = len(kept_idxs) + + world_size = self.actor_rollout_wg.world_size + kept_idxs = kept_idxs[:len(kept_idxs) // world_size * world_size] + if self.config.algorithm.filter_groups.drop_last_mini_batch: + train_traj_mini_bsz = self.config.actor_rollout_ref.actor.ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n + if len(kept_idxs) > train_traj_mini_bsz: + kept_idxs = kept_idxs[:len(kept_idxs) // train_traj_mini_bsz * train_traj_mini_bsz] + else: + print(f'[WARNING] {len(kept_idxs)=} < {train_traj_mini_bsz=}') + + filter_metric_dict["final_traj_ratio"] = len(kept_idxs) / len(batch.batch) + filter_metric_dict["final_traj_bsz"] = len(kept_idxs) + + batch = batch.sel_idxs(kept_idxs) + # balance the number of valid tokens on each dp rank. # Note that this breaks the order of data inside the batch. # Please take care when you implement group based adv computation such as GRPO and rloo @@ -908,27 +983,6 @@ def fit(self): batch = batch.union(values) with _timer('adv', timing_raw): - # compute scores. Support both model and function-based. - # We first compute the scores using reward model. Then, we call reward_fn to combine - # the results from reward model and rule-based results. - if self.use_rm: - # we first compute reward model score - reward_tensor = self.rm_wg.compute_rm_score(batch) - batch = batch.union(reward_tensor) - - # we combine with rule-based rm - reward_tensor = self.reward_fn(batch) - batch.batch['token_level_scores'] = reward_tensor - - # compute rewards. apply_kl_penalty if available - if not self.config.actor_rollout_ref.actor.get('use_kl_loss', False): - batch, kl_metrics = apply_kl_penalty(batch, - kl_ctrl=self.kl_ctrl, - kl_penalty=self.config.algorithm.kl_penalty) - metrics.update(kl_metrics) - else: - batch.batch['token_level_rewards'] = batch.batch['token_level_scores'] - # compute advantages, executed on the driver process batch = compute_advantage(batch, adv_estimator=self.config.algorithm.adv_estimator, diff --git a/verl/workers/reward_manager/naive.py b/verl/workers/reward_manager/naive.py index b67c176da4d..23e42b79175 100644 --- a/verl/workers/reward_manager/naive.py +++ b/verl/workers/reward_manager/naive.py @@ -22,11 +22,16 @@ class NaiveRewardManager: """The reward manager. """ - def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key='data_source') -> None: + def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key='data_source', max_resp_len=None, overlong_buffer_cfg = None) -> None: self.tokenizer = tokenizer self.num_examine = num_examine # the number of batches of decoded responses to print to the console self.compute_score = compute_score or _default_compute_score self.reward_fn_key = reward_fn_key + self.overlong_buffer_cfg = overlong_buffer_cfg + self.max_resp_len = max_resp_len + + if self.overlong_buffer_cfg is not None: + assert self.max_resp_len is not None, f"max_resp_len must be provided if {overlong_buffer_cfg=}, but got None" # TODO: Is this still necessary in algorithms other than PRIME? def verify(self, data): @@ -116,6 +121,8 @@ def __call__(self, data: DataProto, return_dict: bool = False): extra_info=extra_info, ) + final_reward = 0 + reward: float if isinstance(result, dict): assert "reward" in result @@ -123,11 +130,23 @@ def __call__(self, data: DataProto, return_dict: bool = False): else: reward = result - reward_tensor[i, valid_response_length - 1] = reward - for key, value in result.items(): reward_extra_info[key].append(value) + final_reward += reward + + overlong_buffer_len = self.overlong_buffer_cfg.len + if overlong_buffer_len > 0: + overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor + exceed_len = valid_response_length - (self.max_resp_len - overlong_buffer_len) + overlong_reward = max(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0) + final_reward += overlong_reward + if self.overlong_buffer_cfg.log: + reward_extra_info["overlong_reward"].append(overlong_reward) + reward_extra_info["overlong"].append(overlong_reward < 0) + + reward_tensor[i, valid_response_length - 1] = final_reward + if data_source not in already_print_data_sources: already_print_data_sources[data_source] = 0