diff --git a/docs/algo/repexp.md b/docs/algo/repexp.md new file mode 100644 index 00000000000..456b70760a4 --- /dev/null +++ b/docs/algo/repexp.md @@ -0,0 +1,68 @@ +# Recipe: Representation-Based Exploration (RepExp) + +Last updated: 11/14/2025. + +
+ +Representation-Based Exploration for Language Models:
From Test-Time to Post-Training + +[📄 arXiv](https://arxiv.org/abs/2510.11686)     [🌐 Website](https://rep-exp.github.io)     [🐦 Twitter / X ](https://x.com/JensTuyls/status/1978244454617128993) + +
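## Method Sketch 🧠

At its core, the recipe augments the rule-based task reward with an elliptical exploration bonus computed from (sparsely projected) policy representations of each rollout. The sketch below is purely illustrative and is **not** the recipe's actual implementation: the function name, shapes, and sizes are our own, and only the projection dimension and the regularizer loosely echo the `sparse_dim` and `lamb` options in `rep_exp_trainer.yaml`.

```python
import numpy as np

def elliptical_bonus(reps: np.ndarray, lamb: float = 0.01) -> np.ndarray:
    """Leverage-score-style elliptical bonus for a batch of rollouts.

    reps: (n, d) array with one feature vector per rollout.
    Returns phi_i^T (Phi^T Phi + lamb * I)^{-1} phi_i for each rollout i:
    rollouts whose representations are poorly covered by the batch
    covariance receive a larger bonus.
    """
    _, d = reps.shape
    cov = reps.T @ reps + lamb * np.eye(d)  # regularized batch covariance
    inv = np.linalg.inv(cov)
    # Quadratic form per row: reps[i] @ inv @ reps[i]
    return np.einsum("nd,de,ne->n", reps, inv, reps)

# Random projection of hidden states down to a smaller dimension, loosely
# mirroring the recipe's `sparse_dim` option (all sizes here are hypothetical).
rng = np.random.default_rng(0)
hidden = rng.normal(size=(8, 64))              # 8 rollouts, 64-dim hidden states
proj = rng.normal(size=(64, 32)) / np.sqrt(32)
bonus = elliptical_bonus(hidden @ proj, lamb=0.01)
```

With `lamb > 0`, each rollout's bonus lies in [0, 1) and the bonuses sum to at most the number of rollouts, which keeps the exploration term on a scale comparable to a rule-based 0/1 reward.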
+
+
+## Installation 🔌
+
+Besides the base verl installation, which you can find [here](https://verl.readthedocs.io/en/latest/start/install.html), the only additional package to install is scikit-learn (`pip install scikit-learn`).
+
+## Running the Experiments 🚀
+
+You can reproduce or extend our experiments by running the following commands:
+
+```bash
+# General format
+sh recipe/rep_exp/train_elliptical.sh $TASK $SPARSE_DIM $BETA $SEED
+
+# MATH
+sh recipe/rep_exp/train_elliptical.sh math 32 0.01 42
+
+# GSM8K
+sh recipe/rep_exp/train_elliptical.sh gsm8k 32 0.01 42
+
+# DAPO-WITH-AIME
+sh recipe/rep_exp/train_elliptical.sh dapo-with-aime24 128 0.01 42
+```
+where `$TASK` is the task name, `$SPARSE_DIM` is the sparse dimension, `$BETA` is the beta parameter, and `$SEED` is the random seed.
+
+## Evaluation 📊
+Once training is done, you can evaluate the model on the test set in two steps.
+1. Merge the model checkpoint.
+
+This is necessary because the checkpoint is saved in multiple shards (depending on the number of GPUs), and they need to be merged into a single checkpoint.
+
+```bash
+sh recipe/rep_exp/model_merge.sh /path/to/global_step_X/actor # where X is the global step of the checkpoint with the best pass@1 on dev
+```
+
+2. Evaluate the merged model.
+
+```bash
+sh recipe/rep_exp/eval.sh $TASK /path/to/global_step_X/actor/hf # where X is the global step of the checkpoint with the best pass@1 on dev
+```
+
+The results are saved as a JSON file in a folder named `eval`.
+
+## Citation 📝
+
+```bibtex
+@article{tuyls2025representation,
+  title={Representation-Based Exploration for Language Models: From Test-Time to Post-Training},
+  author={Tuyls, Jens and Foster, Dylan J and Krishnamurthy, Akshay and Ash, Jordan T},
+  journal={arXiv preprint arXiv:2510.11686},
+  year={2025}
+}
+```
+
+## Contact 📬
+
+If you have any questions or suggestions, feel free to reach out at [jtuyls@princeton.edu](mailto:jtuyls@princeton.edu).
\ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 7ddb244ac02..e91e748ac0e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -75,6 +75,7 @@ verl is fast with: algo/spin.md algo/sppo.md algo/entropy.md + algo/repexp.md algo/opo.md algo/baseline.md algo/gpg.md diff --git a/recipe/rep_exp/README.md b/recipe/rep_exp/README.md new file mode 100644 index 00000000000..0fcfacdc355 --- /dev/null +++ b/recipe/rep_exp/README.md @@ -0,0 +1,66 @@ +
+ +# Representation-Based Exploration for Language Models:
From Test-Time to Post-Training + +[📄 arXiv](https://arxiv.org/abs/2510.11686)     [🌐 Website](https://rep-exp.github.io)     [🐦 Twitter / X ](https://x.com/JensTuyls/status/1978244454617128993) + +
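## Reward Shaping Sketch ⚖️

The recipe's config (`rep_exp_trainer.yaml`) exposes `alpha`/`beta` weights and several `turn_off_elliptical_if_*` switches for blending the rule-based task reward with the elliptical exploration bonus. The following is our own minimal illustration of what such a combination could look like, not the recipe's code; only the parameter names echo `reward_kwargs.elliptical` in the config.

```python
import numpy as np

def combine_rewards(task_rewards: np.ndarray,
                    elliptical: np.ndarray,
                    alpha: float = 1.0,
                    beta: float = 1.0,
                    turn_off_if_none_correct: bool = True) -> np.ndarray:
    """Blend rule-based correctness rewards with an elliptical exploration
    bonus for one prompt's group of rollouts (illustrative sketch only).
    """
    bonus = elliptical.copy()
    # Optionally disable exploration when no rollout in the group is correct,
    # mirroring the spirit of the `turn_off_elliptical_if_none_correct` switch.
    if turn_off_if_none_correct and not np.any(task_rewards > 0):
        bonus[:] = 0.0
    return alpha * task_rewards + beta * bonus

# One correct rollout in the group, so the bonus stays on for all rollouts.
rewards = combine_rewards(np.array([0.0, 1.0, 0.0]), np.array([0.3, 0.1, 0.2]))
```

The remaining `turn_off_elliptical_if_*` flags in the config gate the bonus on other group-level outcomes (some correct, all correct, or the individual rollout being incorrect) in the same spirit.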
+
+## Installation 🔌
+
+Besides the base verl installation, which you can find [here](https://verl.readthedocs.io/en/latest/start/install.html), the only additional package to install is scikit-learn:
+```bash
+pip install scikit-learn
+```
+
+## Running the Experiments 🚀
+
+You can reproduce or extend our experiments by running the following commands:
+
+```bash
+# General format
+sh recipe/rep_exp/train_elliptical.sh $TASK $SPARSE_DIM $BETA $SEED
+
+# MATH
+sh recipe/rep_exp/train_elliptical.sh math 32 0.01 42
+
+# GSM8K
+sh recipe/rep_exp/train_elliptical.sh gsm8k 32 0.01 42
+
+# DAPO-WITH-AIME
+sh recipe/rep_exp/train_elliptical.sh dapo-with-aime24 128 0.01 42
+```
+where `$TASK` is the task name, `$SPARSE_DIM` is the sparse dimension, `$BETA` is the beta parameter, and `$SEED` is the random seed.
+
+## Evaluation 📊
+Once training is done, you can evaluate the model on the test set in two steps.
+1. Merge the model checkpoint.
+
+This is necessary because the checkpoint is saved in multiple shards (depending on the number of GPUs), and they need to be merged into a single checkpoint.
+
+```bash
+sh recipe/rep_exp/model_merge.sh /path/to/global_step_X/actor # where X is the global step of the checkpoint with the best pass@1 on dev
+```
+
+2. Evaluate the merged model.
+
+```bash
+sh recipe/rep_exp/eval.sh $TASK /path/to/global_step_X/actor/hf # where X is the global step of the checkpoint with the best pass@1 on dev
+```
+
+The results are saved as a JSON file in a folder named `eval`.
+
+## Citation 📝
+
+```bibtex
+@article{tuyls2025representation,
+  title={Representation-Based Exploration for Language Models: From Test-Time to Post-Training},
+  author={Tuyls, Jens and Foster, Dylan J and Krishnamurthy, Akshay and Ash, Jordan T},
+  journal={arXiv preprint arXiv:2510.11686},
+  year={2025}
+}
+```
+
+## Contact 📬
+
+If you have any questions or suggestions, feel free to reach out at [jtuyls@princeton.edu](mailto:jtuyls@princeton.edu).
\ No newline at end of file diff --git a/recipe/rep_exp/config/rep_exp_trainer.yaml b/recipe/rep_exp/config/rep_exp_trainer.yaml new file mode 100644 index 00000000000..30fe1c84a31 --- /dev/null +++ b/recipe/rep_exp/config/rep_exp_trainer.yaml @@ -0,0 +1,33 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +reward_model: + elliptical: + enable: True + lamb: 0.01 + normalization: none # none, rnd, z_score + reward_type: leave_one_out # leave_one_out, leverage + sparse_dim: 512 + randomize_sparse_matrix: True + persist_covariance: False + + reward_kwargs: + elliptical: + alpha: 1.0 + beta: 1.0 + turn_off_elliptical_if_none_correct: True + turn_off_elliptical_if_some_correct: False + turn_off_elliptical_if_all_correct: False + turn_off_elliptical_if_rollout_incorrect: False + +actor_rollout_ref: + rollout: + val_kwargs: + temperature: 1.0 + n: 128 + do_sample: True diff --git a/recipe/rep_exp/data_preprocess/dapo_with_aime.py b/recipe/rep_exp/data_preprocess/dapo_with_aime.py new file mode 100644 index 00000000000..db3dd03d42f --- /dev/null +++ b/recipe/rep_exp/data_preprocess/dapo_with_aime.py @@ -0,0 +1,104 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Preprocess DAPO dataset to parquet format +""" + +import argparse +import os + +import datasets +import numpy as np + +from verl.utils.hdfs_io import copy, makedirs + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--local_dir", default="~/data/dapo-with-aime24") + parser.add_argument("--hdfs_dir", default=None) + parser.add_argument("--dapo_dataset_path", type=str, default="ftajwar/deduplicated_dapo_dataset") + parser.add_argument("--aime24_part_1_dataset_path", type=str, default="MathArena/aime_2024_I") + parser.add_argument("--aime24_part_2_dataset_path", type=str, default="MathArena/aime_2024_II") + parser.add_argument("--train_size", type=int, default=4096) + + args = parser.parse_args() + + data_source = "math_dapo" + + # Load DAPO dataset for training + dapo_dataset_path = args.dapo_dataset_path + dapo_dataset = datasets.load_dataset(dapo_dataset_path, trust_remote_code=True) + + # Load AIME 2024 part 1 dataset for testing + aime24_dataset_path_part_1 = args.aime24_part_1_dataset_path + aime24_dataset_part_1 = datasets.load_dataset(aime24_dataset_path_part_1, trust_remote_code=True) + + # Load AIME 2024 part 2 dataset for testing + aime24_dataset_path_part_2 = args.aime24_part_2_dataset_path + aime24_dataset_part_2 = datasets.load_dataset(aime24_dataset_path_part_2, trust_remote_code=True) + + train_dataset = dapo_dataset["train"] + train_dataset = train_dataset.select(np.random.choice(len(train_dataset), size=args.train_size, replace=False)) + + dev_dataset_aime24_part_1 = aime24_dataset_part_1["train"] + dev_dataset_aime24_part_2 = aime24_dataset_part_2["train"] + dev_dataset = datasets.concatenate_datasets([dev_dataset_aime24_part_1, dev_dataset_aime24_part_2]) + + instruction_following = "Let's think step by step and output the final answer within \\boxed{}." 
+ + # add a row to each data item that represents a unique id + def make_map_fn(split): + def process_fn(example, idx): + if "prompt" in example: + question = example.pop("prompt") + elif "problem" in example: + question = example.pop("problem") + else: + raise ValueError(f"Unknown question type: {example}") + + question = question + " " + instruction_following + + if "answer" in example: + solution = example.pop("answer") + else: + raise ValueError(f"Unknown answer type: {example}") + solution = str(solution) + + data = { + "data_source": data_source, + "prompt": [{"role": "user", "content": question}], + "ability": "math", + "reward_model": { + "style": "rule", + "ground_truth": solution, + }, + "extra_info": {"split": split, "index": idx}, + } + return data + + return process_fn + + train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) + dev_dataset = dev_dataset.map(function=make_map_fn("test"), with_indices=True) + + local_dir = args.local_dir + hdfs_dir = args.hdfs_dir + + train_dataset.to_parquet(os.path.join(local_dir, "train.parquet")) + dev_dataset.to_parquet(os.path.join(local_dir, "dev.parquet")) + + if hdfs_dir is not None: + makedirs(hdfs_dir) + + copy(src=local_dir, dst=hdfs_dir) diff --git a/recipe/rep_exp/data_preprocess/gsm8k.py b/recipe/rep_exp/data_preprocess/gsm8k.py new file mode 100644 index 00000000000..e4d8cf4fc85 --- /dev/null +++ b/recipe/rep_exp/data_preprocess/gsm8k.py @@ -0,0 +1,112 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +Preprocess the GSM8k dataset to parquet format +""" + +import argparse +import os +import re + +import datasets +import numpy as np + +from verl.utils.hdfs_io import copy, makedirs + + +def extract_solution(solution_str): + solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str) + assert solution is not None + final_solution = solution.group(0) + final_solution = final_solution.split("#### ")[1].replace(",", "") + return final_solution + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--local_dir", default=None, help="The save directory for the preprocessed dataset.") + parser.add_argument("--hdfs_dir", default=None) + parser.add_argument("--local_dataset_path", default=None, help="The local path to the raw dataset, if it exists.") + parser.add_argument( + "--local_save_dir", default="~/data/gsm8k", help="The save directory for the preprocessed dataset." + ) + + args = parser.parse_args() + local_dataset_path = args.local_dataset_path + + data_source = "openai/gsm8k" + + if local_dataset_path is not None: + dataset = datasets.load_dataset(local_dataset_path, "main") + else: + dataset = datasets.load_dataset(data_source, "main") + + train_dataset = dataset["train"] + test_dataset = dataset["test"] + + instruction_following = 'Let\'s think step by step and output the final answer after "####".' 
+ + # add a row to each data item that represents a unique id + def make_map_fn(split): + def process_fn(example, idx): + question_raw = example.pop("question") + + question = question_raw + " " + instruction_following + + answer_raw = example.pop("answer") + solution = extract_solution(answer_raw) + data = { + "data_source": data_source, + "prompt": [ + { + "role": "user", + "content": question, + } + ], + "ability": "math", + "reward_model": {"style": "rule", "ground_truth": solution}, + "extra_info": { + "split": split, + "index": idx, + "answer": answer_raw, + "question": question_raw, + }, + } + return data + + return process_fn + + train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) + test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True) + # split test into dev and test by picking random subset of 512 examples + all_test_indices = range(len(test_dataset)) + all_test_indices = list(all_test_indices) + np.random.shuffle(all_test_indices) + dev_dataset = test_dataset.select(all_test_indices[:512]) + test_dataset = test_dataset.select(all_test_indices[512:]) + + hdfs_dir = args.hdfs_dir + local_save_dir = args.local_dir + if local_save_dir is not None: + print("Warning: Argument 'local_dir' is deprecated. Please use 'local_save_dir' instead.") + else: + local_save_dir = args.local_save_dir + + train_dataset.to_parquet(os.path.join(local_save_dir, "train.parquet")) + test_dataset.to_parquet(os.path.join(local_save_dir, "test.parquet")) + + if hdfs_dir is not None: + makedirs(hdfs_dir) + + copy(src=local_save_dir, dst=hdfs_dir) diff --git a/recipe/rep_exp/data_preprocess/math_dataset.py b/recipe/rep_exp/data_preprocess/math_dataset.py new file mode 100644 index 00000000000..1ae35ea93f4 --- /dev/null +++ b/recipe/rep_exp/data_preprocess/math_dataset.py @@ -0,0 +1,595 @@ +# Copyright 2024 Bytedance Ltd. 
and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Preprocess the MATH-lighteval dataset to parquet format +""" + +import argparse +import os + +import datasets + +from verl.utils.hdfs_io import copy, makedirs +from verl.utils.reward_score.math import last_boxed_only_string, remove_boxed + +# These are the MATH-500 indices +DEV_INDICES = [ + 4, + 6, + 15, + 18, + 34, + 36, + 37, + 41, + 45, + 64, + 66, + 85, + 92, + 100, + 120, + 127, + 133, + 136, + 149, + 160, + 161, + 162, + 166, + 168, + 202, + 215, + 243, + 247, + 256, + 260, + 270, + 320, + 361, + 367, + 381, + 392, + 396, + 411, + 450, + 451, + 452, + 460, + 496, + 501, + 503, + 505, + 511, + 513, + 520, + 534, + 563, + 564, + 571, + 576, + 579, + 587, + 596, + 601, + 607, + 609, + 612, + 615, + 622, + 666, + 673, + 683, + 684, + 695, + 700, + 703, + 709, + 718, + 722, + 738, + 748, + 757, + 761, + 762, + 782, + 805, + 817, + 834, + 840, + 849, + 853, + 854, + 859, + 882, + 885, + 888, + 906, + 909, + 933, + 941, + 962, + 978, + 985, + 988, + 991, + 1008, + 1033, + 1037, + 1046, + 1048, + 1054, + 1058, + 1067, + 1073, + 1085, + 1088, + 1095, + 1111, + 1119, + 1123, + 1127, + 1128, + 1131, + 1136, + 1144, + 1145, + 1150, + 1172, + 1173, + 1180, + 1188, + 1190, + 1194, + 1196, + 1215, + 1243, + 1250, + 1251, + 1258, + 1262, + 1271, + 1281, + 1285, + 1287, + 1290, + 1302, + 1308, + 1311, + 1312, + 1322, + 1339, + 1359, + 1374, + 1380, + 1402, + 1441, + 1442, + 1449, + 1513, + 1531, + 1540, 
+ 1543, + 1552, + 1555, + 1576, + 1603, + 1612, + 1620, + 1690, + 1710, + 1715, + 1730, + 1764, + 1767, + 1769, + 1788, + 1790, + 1791, + 1801, + 1806, + 1820, + 1842, + 1843, + 1880, + 1890, + 1897, + 1901, + 1905, + 1908, + 1932, + 1935, + 1940, + 1963, + 1967, + 1981, + 1996, + 2001, + 2006, + 2011, + 2041, + 2047, + 2053, + 2057, + 2062, + 2063, + 2078, + 2110, + 2119, + 2120, + 2143, + 2148, + 2150, + 2151, + 2170, + 2186, + 2191, + 2196, + 2199, + 2210, + 2214, + 2215, + 2217, + 2231, + 2236, + 2237, + 2238, + 2246, + 2253, + 2263, + 2264, + 2275, + 2289, + 2294, + 2297, + 2303, + 2311, + 2323, + 2324, + 2325, + 2327, + 2328, + 2334, + 2352, + 2359, + 2360, + 2371, + 2382, + 2384, + 2397, + 2404, + 2409, + 2413, + 2416, + 2473, + 2505, + 2512, + 2515, + 2522, + 2536, + 2539, + 2546, + 2569, + 2571, + 2579, + 2602, + 2607, + 2609, + 2611, + 2622, + 2628, + 2637, + 2647, + 2681, + 2682, + 2700, + 2707, + 2731, + 2752, + 2758, + 2767, + 2799, + 2802, + 2808, + 2816, + 2838, + 2851, + 2863, + 2868, + 2876, + 2883, + 2896, + 2907, + 2937, + 2938, + 2946, + 2966, + 2977, + 2991, + 2994, + 3018, + 3019, + 3020, + 3022, + 3024, + 3035, + 3037, + 3046, + 3047, + 3058, + 3067, + 3072, + 3079, + 3080, + 3105, + 3126, + 3134, + 3141, + 3165, + 3181, + 3186, + 3187, + 3196, + 3200, + 3210, + 3220, + 3226, + 3236, + 3240, + 3246, + 3287, + 3295, + 3299, + 3317, + 3320, + 3323, + 3334, + 3341, + 3342, + 3344, + 3350, + 3352, + 3365, + 3366, + 3369, + 3375, + 3392, + 3404, + 3411, + 3417, + 3419, + 3420, + 3440, + 3444, + 3447, + 3460, + 3467, + 3474, + 3480, + 3498, + 3507, + 3511, + 3519, + 3529, + 3539, + 3541, + 3548, + 3549, + 3569, + 3586, + 3604, + 3607, + 3646, + 3647, + 3658, + 3669, + 3700, + 3711, + 3725, + 3730, + 3732, + 3738, + 3740, + 3741, + 3752, + 3768, + 3769, + 3773, + 3779, + 3802, + 3805, + 3824, + 3849, + 3856, + 3878, + 3913, + 3923, + 3941, + 3942, + 3951, + 3982, + 3990, + 3994, + 3999, + 4011, + 4034, + 4036, + 4042, + 4043, + 4046, + 4055, + 4071, 
+ 4074, + 4088, + 4090, + 4104, + 4108, + 4127, + 4149, + 4150, + 4155, + 4157, + 4158, + 4160, + 4177, + 4181, + 4190, + 4193, + 4210, + 4222, + 4235, + 4242, + 4253, + 4265, + 4272, + 4279, + 4297, + 4303, + 4315, + 4326, + 4333, + 4352, + 4368, + 4384, + 4404, + 4413, + 4423, + 4425, + 4441, + 4449, + 4451, + 4479, + 4487, + 4500, + 4515, + 4523, + 4533, + 4535, + 4547, + 4549, + 4550, + 4569, + 4584, + 4590, + 4591, + 4597, + 4600, + 4603, + 4610, + 4626, + 4657, + 4666, + 4678, + 4697, + 4706, + 4713, + 4731, + 4744, + 4751, + 4753, + 4758, + 4765, + 4776, + 4796, + 4812, + 4834, + 4850, + 4857, + 4861, + 4866, + 4868, + 4871, + 4885, + 4896, + 4900, + 4909, + 4914, + 4924, + 4926, + 4947, + 4955, + 4964, + 4969, + 4978, + 4990, + 4992, + 4993, +] + + +def extract_solution(solution_str): + return remove_boxed(last_boxed_only_string(solution_str)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--local_dir", default="~/data/math") + parser.add_argument("--hdfs_dir", default=None) + + args = parser.parse_args() + + # 'lighteval/MATH' is no longer available on huggingface. + # Use mirror repo: DigitalLearningGmbH/MATH-lighteval + data_source = "DigitalLearningGmbH/MATH-lighteval" + print(f"Loading the {data_source} dataset from huggingface...", flush=True) + dataset = datasets.load_dataset(data_source, trust_remote_code=True) + + train_dataset = dataset["train"] + test_dataset = dataset["test"] + + instruction_following = "Let's think step by step and output the final answer within \\boxed{}." 
+
+    # attach a unique index and the standard verl fields to each data item
+    def make_map_fn(split):
+        def process_fn(example, idx):
+            question = example.pop("problem")
+
+            question = question + " " + instruction_following
+
+            answer = example.pop("solution")
+            solution = extract_solution(answer)
+            data = {
+                "data_source": data_source,
+                "prompt": [{"role": "user", "content": question}],
+                "ability": "math",
+                "reward_model": {"style": "rule", "ground_truth": solution},
+                "extra_info": {"split": split, "index": idx},
+            }
+            return data
+
+        return process_fn
+
+    train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
+    test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)
+
+    # Split test into dev and test
+    dev_indices_set = set(DEV_INDICES)
+    dev_dataset = test_dataset.select(DEV_INDICES)
+
+    def filter_dev_indices(example, idx):
+        return idx not in dev_indices_set
+
+    test_dataset = test_dataset.filter(filter_dev_indices, with_indices=True)
+
+    local_dir = args.local_dir
+    hdfs_dir = args.hdfs_dir
+
+    train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
+    dev_dataset.to_parquet(os.path.join(local_dir, "dev.parquet"))
+    test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))
+
+    if hdfs_dir is not None:
+        makedirs(hdfs_dir)
+
+        copy(src=local_dir, dst=hdfs_dir)
diff --git a/recipe/rep_exp/eval.sh b/recipe/rep_exp/eval.sh
new file mode 100644
index 00000000000..cb667bcd698
--- /dev/null
+++ b/recipe/rep_exp/eval.sh
@@ -0,0 +1,83 @@
+TASK=${1} # math, gsm8k, dapo-with-aime24
+
+# Custom model path for evaluation after training
+MODEL_PATH=${2} # /path/to/global_step_X/actor/hf, where X is the global step of the checkpoint with the best pass@1 on dev
+
+# If you want to evaluate the base model before training
+# MODEL_PATH=Qwen/Qwen2.5-7B-Instruct
+
+train_path=$HOME/data/${TASK}/train.parquet
+train_files="['$train_path']"
+CHECKPOINT_SAVE_CONTENTS='["model"]'
+
+if [ "${TASK}" = "dapo-with-aime24" ];
then + MAX_PROMPT_LENGTH=$((1024 * 2)) + MAX_RESPONSE_LENGTH=$((1024 * 8)) + MAX_NUM_BATCHED_TOKENS=$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH)) + test_path=$HOME/data/${TASK}/dev.parquet +else + MAX_PROMPT_LENGTH=1024 + MAX_RESPONSE_LENGTH=1024 + MAX_NUM_BATCHED_TOKENS=8192 + test_path=$HOME/data/${TASK}/test.parquet +fi + +test_files="['$test_path']" + +# If you're on a cluster with no internet access, set to OFFLINE=True +OFFLINE=False + +PYTHONUNBUFFERED=1 WANDB_MODE=disabled TRANSFORMERS_OFFLINE=${OFFLINE} python3 -u -m recipe.rep_exp.main_rep_exp \ + algorithm.adv_estimator=grpo \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=1024 \ + data.max_prompt_length=$MAX_PROMPT_LENGTH \ + data.max_response_length=$MAX_RESPONSE_LENGTH \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.val_batch_size=128 \ + actor_rollout_ref.model.path="$MODEL_PATH" \ + actor_rollout_ref.actor.checkpoint.save_contents=$CHECKPOINT_SAVE_CONTENTS \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.actor.kl_loss_coef=0.0 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.ppo_epochs=1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.mode=sync \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.max_num_batched_tokens=$MAX_NUM_BATCHED_TOKENS \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.95 \ + actor_rollout_ref.rollout.val_kwargs.n=256 \ + 
actor_rollout_ref.rollout.n=8 \
+    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
+    actor_rollout_ref.ref.fsdp_config.param_offload=True \
+    reward_model.model.path="$MODEL_PATH" \
+    reward_model.model.use_remove_padding=False \
+    reward_model.model.fsdp_config.param_offload=True \
+    reward_model.micro_batch_size_per_gpu=32 \
+    reward_model.model.input_tokenizer=null \
+    actor_rollout_ref.actor.use_kl_loss=False \
+    algorithm.use_kl_in_reward=False \
+    trainer.critic_warmup=0 \
+    trainer.logger='["console","json_eval"]' \
+    trainer.project_name='rep-exp' \
+    trainer.experiment_name="${TASK}_eval" \
+    trainer.n_gpus_per_node=1 \
+    trainer.nnodes=1 \
+    trainer.save_freq=-1 \
+    trainer.test_freq=1 \
+    trainer.total_epochs=100 \
+    trainer.val_only=True \
+    trainer.resume_mode=disable \
+    trainer.resume_from_path=''
+
+exit 0
diff --git a/recipe/rep_exp/main_rep_exp.py b/recipe/rep_exp/main_rep_exp.py
new file mode 100644
index 00000000000..ad7068f12c1
--- /dev/null
+++ b/recipe/rep_exp/main_rep_exp.py
@@ -0,0 +1,483 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Note that we don't combine the main with ray_trainer, because ray_trainer is used by other entry points as well.
+""" + +import os +import socket +import warnings + +import hydra +import ray +from omegaconf import OmegaConf + +from verl.experimental.dataset.sampler import AbstractSampler +from verl.trainer.constants_ppo import get_ppo_ray_runtime_env +from verl.trainer.ppo.reward import load_reward_manager +from verl.trainer.ppo.utils import need_critic, need_reference_policy +from verl.utils.config import validate_config +from verl.utils.device import is_cuda_available +from verl.utils.import_utils import load_extern_type + +from .rep_exp_trainer import RayRepExpTrainer + + +@hydra.main(config_path="config", config_name="rep_exp_trainer", version_base=None) +def main(config): + """Main entry point for PPO training with Hydra configuration management. + + Args: + config_dict: Hydra configuration dictionary containing training parameters. + """ + run_ppo(config) + + +# Define a function to run the PPO-like training process +def run_ppo(config, task_runner_class=None) -> None: + """Initialize Ray cluster and run distributed PPO training process. + + Args: + config: Training configuration object containing all necessary parameters + for distributed PPO training including Ray initialization settings, + model paths, and training hyperparameters. + task_runner_class: For recipe to change TaskRunner. 
+ """ + # Check if Ray is not initialized + if not ray.is_initialized(): + # Initialize Ray with a local cluster configuration + # Set environment variables in the runtime environment to control tokenizer parallelism, + # NCCL debug level, VLLM logging level, and allow runtime LoRA updating + # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration + default_runtime_env = get_ppo_ray_runtime_env() + ray_init_kwargs = config.ray_kwargs.get("ray_init", {}) + runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {}) + + if config.transfer_queue.enable: + # Add runtime environment variables for transfer queue + runtime_env_vars = runtime_env_kwargs.get("env_vars", {}) + runtime_env_vars["TRANSFER_QUEUE_ENABLE"] = "1" + runtime_env_kwargs["env_vars"] = runtime_env_vars + + runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs) + ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env}) + print(f"ray init kwargs: {ray_init_kwargs}") + ray.init(**OmegaConf.to_container(ray_init_kwargs)) + + if task_runner_class is None: + task_runner_class = ray.remote(num_cpus=1)(TaskRunner) # please make sure main_task is not scheduled on head + + # Create a remote instance of the TaskRunner class, and + # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete + if ( + is_cuda_available + and config.global_profiler.tool == "nsys" + and config.global_profiler.get("steps") is not None + and len(config.global_profiler.get("steps", [])) > 0 + ): + from verl.utils.import_utils import is_nvtx_available + + assert is_nvtx_available(), "nvtx is not available in CUDA platform. 
Please 'pip3 install nvtx'" + nsight_options = OmegaConf.to_container( + config.global_profiler.global_tool_config.nsys.controller_nsight_options + ) + runner = task_runner_class.options(runtime_env={"nsight": nsight_options}).remote() + else: + runner = task_runner_class.remote() + ray.get(runner.run.remote(config)) + + # [Optional] get the path of the timeline trace file from the configuration, default to None + # This file is used for performance analysis + timeline_json_file = config.ray_kwargs.get("timeline_json_file", None) + if timeline_json_file: + ray.timeline(filename=timeline_json_file) + + +class TaskRunner: + """Ray remote class for executing distributed PPO training tasks. + + This class encapsulates the main training logic and runs as a Ray remote actor + to enable distributed execution across multiple nodes and GPUs. + + Attributes: + role_worker_mapping: Dictionary mapping Role enums to Ray remote worker classes + mapping: Dictionary mapping Role enums to resource pool IDs for GPU allocation + """ + + def __init__(self): + self.role_worker_mapping = {} + self.mapping = {} + + def add_actor_rollout_worker(self, config): + """Add actor rollout worker based on the actor strategy.""" + from verl.single_controller.ray import RayWorkerGroup + from verl.trainer.ppo.ray_trainer import Role + + use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") + + # use new model engine implementation + if use_legacy_worker_impl == "disable": + from verl.workers.engine_workers import ActorRolloutRefWorker + + actor_rollout_cls = ActorRolloutRefWorker + ray_worker_group_cls = RayWorkerGroup + # NOTE: In new model engine, ref policy and actor rollout are in same ActorRolloutRefWorker, + # while in legacy model engine, ref policy is in a separate ActorRolloutRefWorker. 
+ if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + role = Role.ActorRolloutRef + else: + role = Role.ActorRollout + self.role_worker_mapping[role] = ray.remote(actor_rollout_cls) + self.mapping[role] = "global_pool" + return actor_rollout_cls, ray_worker_group_cls + + if config.actor_rollout_ref.rollout.mode == "sync": + warnings.warn("spmd rollout mode is deprecated and will be removed in v0.6.2", stacklevel=2) + + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker + + actor_rollout_cls = ( + AsyncActorRolloutRefWorker + if config.actor_rollout_ref.rollout.mode == "async" + else ActorRolloutRefWorker + ) + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker + + actor_rollout_cls = ( + AsyncActorRolloutRefWorker + if config.actor_rollout_ref.rollout.mode == "async" + else ActorRolloutRefWorker + ) + ray_worker_group_cls = RayWorkerGroup + + else: + raise NotImplementedError + + self.role_worker_mapping[Role.ActorRollout] = ray.remote(actor_rollout_cls) + self.mapping[Role.ActorRollout] = "global_pool" + return actor_rollout_cls, ray_worker_group_cls + + def add_critic_worker(self, config): + """Add critic worker to role mapping.""" + if config.critic.strategy in {"fsdp", "fsdp2"}: + use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") + if use_legacy_worker_impl in ["auto", "enable"]: + from verl.workers.fsdp_workers import CriticWorker + elif use_legacy_worker_impl == "disable": + from verl.workers.roles import CriticWorker + + print("Using new worker implementation") + else: + raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}") + + elif config.critic.strategy == "megatron": + from verl.workers.megatron_workers import CriticWorker + + 
+        else:
+            raise NotImplementedError
+
+        from verl.trainer.ppo.ray_trainer import Role
+
+        self.role_worker_mapping[Role.Critic] = ray.remote(CriticWorker)
+        self.mapping[Role.Critic] = "global_pool"
+
+    def init_resource_pool_mgr(self, config):
+        """Initialize resource pool manager."""
+
+        global_pool_id = "global_pool"
+        resource_pool_spec = {
+            global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
+        }
+        # TODO: use the new registration method here to support dynamic registration of roles
+        if config.reward_model.enable_resource_pool:
+            if config.reward_model.n_gpus_per_node <= 0:
+                raise ValueError("config.reward_model.n_gpus_per_node must be greater than 0")
+            if config.reward_model.nnodes <= 0:
+                raise ValueError("config.reward_model.nnodes must be greater than 0")
+
+            reward_pool = [config.reward_model.n_gpus_per_node] * config.reward_model.nnodes
+            resource_pool_spec["reward_pool"] = reward_pool
+
+        from verl.trainer.ppo.ray_trainer import ResourcePoolManager
+
+        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=self.mapping)
+        return resource_pool_manager
+
+    def add_reward_model_worker(self, config):
+        """Add reward model worker if enabled."""
+        from verl.trainer.ppo.ray_trainer import Role
+
+        if config.reward_model.enable:
+            use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
+            if use_legacy_worker_impl in ["auto", "enable"]:
+                if config.reward_model.strategy in {"fsdp", "fsdp2"}:
+                    if config.reward_model.elliptical:
+                        from .workers.elliptical_reward_model_worker import (
+                            EllipticalRewardModelWorker as RewardModelWorker,
+                        )
+                    else:
+                        from verl.workers.fsdp_workers import RewardModelWorker
+                elif config.reward_model.strategy == "megatron":
+                    from verl.workers.megatron_workers import RewardModelWorker
+                else:
+                    raise NotImplementedError
+            elif use_legacy_worker_impl == "disable":
+                from verl.workers.roles import RewardModelWorker
+
+                print("Using new worker implementation")
+            else:
+                raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")
+
+            self.role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
+            if config.reward_model.enable_resource_pool:
+                self.mapping[Role.RewardModel] = "reward_pool"
+            else:
+                self.mapping[Role.RewardModel] = "global_pool"
+
+    def add_ref_policy_worker(self, config, ref_policy_cls):
+        """Add reference policy worker if KL loss or KL reward is used."""
+        from verl.trainer.ppo.ray_trainer import Role
+
+        # The ref policy has been fused into ActorRolloutRefWorker in the new model engine,
+        # so we don't need to add a separate ref policy worker group.
+        use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
+        if use_legacy_worker_impl == "disable":
+            return
+
+        if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
+            self.role_worker_mapping[Role.RefPolicy] = ray.remote(ref_policy_cls)
+            self.mapping[Role.RefPolicy] = "global_pool"
+
+    def run(self, config):
+        """Execute the main PPO training workflow.
+
+        This method sets up the distributed training environment, initializes
+        workers, datasets, and reward functions, then starts the training process.
+
+        Args:
+            config: Training configuration object containing all parameters needed
+                for setting up and running the PPO training process.
+        """
+        # Print the initial configuration. `resolve=True` will evaluate symbolic values.
+ from pprint import pprint + + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + + print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") + pprint(OmegaConf.to_container(config, resolve=True)) + OmegaConf.resolve(config) + + actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config) + self.add_critic_worker(config) + + # We should adopt a multi-source reward function here: + # - for rule-based rm, we directly call a reward score + # - for model-based rm, we call a model + # - for code related prompt, we send to a sandbox if there are test cases + # finally, we combine all the rewards together + # The reward type depends on the tag of the data + self.add_reward_model_worker(config) + + # Add a reference policy worker if KL loss or KL reward is used. + self.add_ref_policy_worker(config, actor_rollout_cls) + + # validate config + validate_config( + config=config, + use_reference_policy=need_reference_policy(self.role_worker_mapping), + use_critic=need_critic(config), + ) + + # Download the checkpoint from HDFS to the local machine. + # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on + local_path = copy_to_local( + config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False) + ) + + # Instantiate the tokenizer and processor. + from verl.utils import hf_processor, hf_tokenizer + + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + # Used for multimodal LLM, could be None + processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) + + # Make sure the elliptical reward manager is registered + from .reward_manager.elliptical_reward_manager import EllipticalRewardManager # noqa: F401 + + # Load the reward manager for training and validation. 
+ reward_manager_name = config.reward_model.get("reward_manager", "naive") + reward_fn = load_reward_manager( + config, + tokenizer, + num_examine=0, + **config.reward_model.get("reward_kwargs", {}).get(reward_manager_name, {}), + ) + val_reward_fn = load_reward_manager( + config, + tokenizer, + num_examine=1, + **config.reward_model.get("reward_kwargs", {}).get(reward_manager_name, {}), + ) + + resource_pool_manager = self.init_resource_pool_mgr(config) + + from verl.utils.dataset.rl_dataset import collate_fn + + # Create training and validation datasets. + train_dataset = create_rl_dataset( + config.data.train_files, + config.data, + tokenizer, + processor, + is_train=True, + max_samples=config.data.get("train_max_samples", -1), + ) + val_dataset = create_rl_dataset( + config.data.val_files, + config.data, + tokenizer, + processor, + is_train=False, + max_samples=config.data.get("val_max_samples", -1), + ) + train_sampler = create_rl_sampler(config.data, train_dataset) + + # Initialize the PPO trainer. + trainer = RayRepExpTrainer( + config=config, + tokenizer=tokenizer, + processor=processor, + role_worker_mapping=self.role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + train_dataset=train_dataset, + val_dataset=val_dataset, + collate_fn=collate_fn, + train_sampler=train_sampler, + ) + # Initialize the workers of the trainer. + trainer.init_workers() + + # Start the training process. + trainer.fit() + + +def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=True, max_samples: int = -1): + """Create a dataset. + + Arguments: + data_paths: List of paths to data files. + data_config: The data config. + tokenizer (Tokenizer): The tokenizer. + processor (Processor): The processor. + + Returns: + dataset (Dataset): The dataset. 
+ """ + from torch.utils.data import Dataset + + from verl.utils.dataset.rl_dataset import RLHFDataset + + # Check if a custom dataset class is specified in the data configuration + # and if the path to the custom class is provided + if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None: + # Dynamically load the custom dataset class + dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name) + # Verify that the custom dataset class inherits from torch.utils.data.Dataset + if not issubclass(dataset_cls, Dataset): + raise TypeError( + f"The custom dataset class '{data_config.custom_cls.name}' from " + f"'{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset" + ) + elif "datagen" in data_config and data_config.datagen.get("path", None) is not None and is_train: + # If a data generation strategy is specified, use the DynamicGenDataset class + from verl.utils.dataset.dynamicgen_dataset import DynamicGenDataset + + dataset_cls = DynamicGenDataset + print("Using DynamicGenDataset for data generation.") + else: + # Use the default RLHFDataset class if no custom class is specified + dataset_cls = RLHFDataset + print(f"Using dataset class: {dataset_cls.__name__}") + + # Instantiate the dataset using the determined dataset class + dataset = dataset_cls( + data_files=data_paths, + tokenizer=tokenizer, + processor=processor, + config=data_config, + max_samples=max_samples, + ) + + return dataset + + +def create_rl_sampler(data_config, dataset): + """Create a sampler for the dataset. + + Arguments: + data_config: The data config. + dataset (Dataset): The dataset. + + Returns: + sampler (Sampler): The sampler. 
+ """ + import torch + from torch.utils.data import SequentialSampler + + # torch.utils.data.RandomSampler could not recover properly + from torchdata.stateful_dataloader.sampler import RandomSampler + + if data_config.sampler is not None and data_config.sampler.get("class_path", None) is not None: + curriculum_class = load_extern_type( + data_config.sampler.class_path, + data_config.sampler.class_name, + ) + sampler = curriculum_class( + data_source=dataset, + data_config=data_config, + ) + assert isinstance(sampler, AbstractSampler) + assert data_config.get("dataloader_num_workers", 8) == 0, ( + "If using curriculum, num_workers must be 0 to prevent data caching. " + "If the dataloader caches data before the batch is done the " + "curriculum sampler won't have the opportunity to reorder it. " + ) + + # Use a sampler to facilitate checkpoint resumption. + # If shuffling is enabled in the data configuration, create a random sampler. + elif data_config.shuffle: + train_dataloader_generator = torch.Generator() + seed = data_config.get("seed") + if seed is not None: + train_dataloader_generator.manual_seed(seed) + sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator) + else: + # If shuffling is disabled, use a sequential sampler to iterate through the dataset in order. + sampler = SequentialSampler(data_source=dataset) + + return sampler + + +if __name__ == "__main__": + main() diff --git a/recipe/rep_exp/metric_utils.py b/recipe/rep_exp/metric_utils.py new file mode 100644 index 00000000000..519a5b6c0f8 --- /dev/null +++ b/recipe/rep_exp/metric_utils.py @@ -0,0 +1,382 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Metrics related to the RepExp trainer. +""" + +from collections import defaultdict +from functools import partial +from typing import Any + +import numpy as np +import torch + +from verl import DataProto +from verl.trainer.ppo.metric_utils import _compute_response_info, bootstrap_metric, calc_maj_val + + +def _compute_three_case_stats(data: DataProto, extrinsic_reward_tensor: torch.Tensor) -> dict: + """ + Compute the fraction of samples that have no rollouts correct, some rollouts correct, and all rollouts correct. + + Args: + data (DataProto): The data proto containing the batch data. + extrinsic_reward_tensor (torch.Tensor): The extrinsic reward tensor. + + Returns: + dict[str, float]: A dictionary containing the fraction of samples that have no rollouts correct, + some rollouts correct, and all rollouts correct. 
+ """ + no_rollouts_correct = 0 + some_rollouts_correct = 0 + all_rollouts_correct = 0 + + visited_uids = set() + for uid in data.non_tensor_batch["uid"]: + if uid in visited_uids: + continue + + visited_uids.add(uid) + mask = torch.from_numpy(data.non_tensor_batch["uid"] == uid) + + # Split into three cases + if extrinsic_reward_tensor[mask].sum() == 0: + no_rollouts_correct += 1 + elif extrinsic_reward_tensor[mask].sum() == mask.sum(): + all_rollouts_correct += 1 + elif extrinsic_reward_tensor[mask].sum() > 0 and extrinsic_reward_tensor[mask].sum() < mask.sum(): + some_rollouts_correct += 1 + else: + raise ValueError(f"Invalid extrinsic reward tensor: {extrinsic_reward_tensor[mask].sum()}") + + # Sanity checks + assert len(visited_uids) == no_rollouts_correct + some_rollouts_correct + all_rollouts_correct + + return { + "no_rollouts_correct_frac": no_rollouts_correct / len(visited_uids), + "some_rollouts_correct_frac": some_rollouts_correct / len(visited_uids), + "all_rollouts_correct_frac": all_rollouts_correct / len(visited_uids), + } + + +def compute_data_metrics(batch: DataProto, use_critic: bool = True, elliptical: bool = False) -> dict[str, Any]: + """ + Computes various metrics from a batch of data for PPO training. + + This function calculates metrics related to scores, rewards, advantages, returns, values, + and sequence lengths from a batch of data. It provides statistical information (mean, max, min) + for each metric category. + + Args: + batch: A DataProto object containing batch data with token-level scores, rewards, advantages, etc. + use_critic: Whether to include critic-specific metrics. Defaults to True. + elliptical: Whether to include elliptical-specific metrics. Defaults to False. 
+ + Returns: + A dictionary of metrics including: + - critic/score/mean, max, min: Statistics about sequence scores + - critic/rewards/mean, max, min: Statistics about sequence rewards + - critic/advantages/mean, max, min: Statistics about advantages + - critic/returns/mean, max, min: Statistics about returns + - critic/values/mean, max, min: Statistics about critic values (if use_critic=True) + - critic/vf_explained_var: Explained variance of the value function (if use_critic=True) + - response_length/mean, max, min, clip_ratio: Statistics about response lengths + - prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths + - num_turns/mean, max, min: Statistics about the number of multi-turn conversations + """ + sequence_score = batch.batch["token_level_scores"].sum(-1) + sequence_reward = batch.batch["token_level_rewards"].sum(-1) + + if elliptical: + sequence_intrinsic_reward = batch.non_tensor_batch["intrinsic_reward"].sum(-1) + sequence_beta_scaled_intrinsic_reward = batch.non_tensor_batch["beta_scaled_intrinsic_reward"].sum(-1) + sequence_extrinsic_reward = batch.non_tensor_batch["extrinsic_reward"].sum(-1) + sequence_total_reward = batch.non_tensor_batch["total_reward"].sum(-1) + sequence_raw_bonuses = batch.non_tensor_batch["raw_bonuses"].sum(-1) + + three_case_stats = _compute_three_case_stats(batch, batch.non_tensor_batch["extrinsic_reward"]) + + advantages = batch.batch["advantages"] + returns = batch.batch["returns"] + + max_response_length = batch.batch["responses"].shape[-1] + + prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool() + response_mask = batch.batch["response_mask"].bool() + + max_prompt_length = prompt_mask.size(-1) + + response_info = _compute_response_info(batch) + prompt_length = response_info["prompt_length"] + response_length = response_info["response_length"] + + aborted_mask = (response_length == 0).bool() + non_aborted_mask = ~aborted_mask + + non_aborted_sequence_score = 
sequence_score[non_aborted_mask] + non_aborted_sequence_reward = sequence_reward[non_aborted_mask] + + score_mean = torch.mean(non_aborted_sequence_score).detach().item() + score_max = torch.max(non_aborted_sequence_score).detach().item() + score_min = torch.min(non_aborted_sequence_score).detach().item() + + reward_mean = torch.mean(non_aborted_sequence_reward).detach().item() + reward_max = torch.max(non_aborted_sequence_reward).detach().item() + reward_min = torch.min(non_aborted_sequence_reward).detach().item() + + valid_adv = torch.masked_select(advantages, response_mask) + valid_returns = torch.masked_select(returns, response_mask) + + if use_critic: + values = batch.batch["values"] + valid_values = torch.masked_select(values, response_mask) + return_diff_var = torch.var(valid_returns - valid_values) + return_var = torch.var(valid_returns) + + # Aborted samples and non-aborted response length statistics + # response_length_non_aborted/*: statistics computed on non-aborted samples only + aborted_ratio = torch.mean(aborted_mask.float()).detach().item() + + non_aborted_response_length = response_length[non_aborted_mask] + if non_aborted_response_length.numel() > 0: + non_aborted_response_length_mean = torch.mean(non_aborted_response_length).detach().item() + non_aborted_response_length_max = torch.max(non_aborted_response_length).detach().item() + non_aborted_response_length_min = torch.min(non_aborted_response_length).detach().item() + non_aborted_response_length_clip_ratio = ( + torch.mean(torch.eq(non_aborted_response_length, max_response_length).float()).detach().item() + ) + else: + raise ValueError("All samples are aborted, this should not happen.") + + metrics = { + # score + "critic/score/mean": score_mean, + "critic/score/max": score_max, + "critic/score/min": score_min, + # reward + "critic/rewards/mean": reward_mean, + "critic/rewards/max": reward_max, + "critic/rewards/min": reward_min, + # adv + "critic/advantages/mean": 
torch.mean(valid_adv).detach().item(), + "critic/advantages/max": torch.max(valid_adv).detach().item(), + "critic/advantages/min": torch.min(valid_adv).detach().item(), + # returns + "critic/returns/mean": torch.mean(valid_returns).detach().item(), + "critic/returns/max": torch.max(valid_returns).detach().item(), + "critic/returns/min": torch.min(valid_returns).detach().item(), + **( + { + # values + "critic/values/mean": torch.mean(valid_values).detach().item(), + "critic/values/max": torch.max(valid_values).detach().item(), + "critic/values/min": torch.min(valid_values).detach().item(), + # vf explained var + "critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(), + } + if use_critic + else {} + ), + **( + { + # raw bonuses + "critic/raw_bonuses/mean": np.mean(sequence_raw_bonuses).item(), + "critic/raw_bonuses/max": np.max(sequence_raw_bonuses).item(), + "critic/raw_bonuses/min": np.min(sequence_raw_bonuses).item(), + "critic/raw_bonuses/std": np.std(sequence_raw_bonuses).item(), + # intrinsic_reward + "critic/intrinsic_reward/mean": np.mean(sequence_intrinsic_reward).item(), + "critic/intrinsic_reward/max": np.max(sequence_intrinsic_reward).item(), + "critic/intrinsic_reward/min": np.min(sequence_intrinsic_reward).item(), + "critic/intrinsic_reward/std": np.std(sequence_intrinsic_reward).item(), + # beta_scaled_intrinsic_reward + "critic/beta_scaled_intrinsic_reward/mean": np.mean(sequence_beta_scaled_intrinsic_reward).item(), + "critic/beta_scaled_intrinsic_reward/max": np.max(sequence_beta_scaled_intrinsic_reward).item(), + "critic/beta_scaled_intrinsic_reward/min": np.min(sequence_beta_scaled_intrinsic_reward).item(), + "critic/beta_scaled_intrinsic_reward/std": np.std(sequence_beta_scaled_intrinsic_reward).item(), + # extrinsic_reward + "critic/extrinsic_reward/mean": np.mean(sequence_extrinsic_reward).item(), + "critic/extrinsic_reward/max": np.max(sequence_extrinsic_reward).item(), + "critic/extrinsic_reward/min": 
np.min(sequence_extrinsic_reward).item(), + "critic/extrinsic_reward/std": np.std(sequence_extrinsic_reward).item(), + # three_case_stats + "critic/extrinsic_reward/no_rollouts_correct_frac": three_case_stats["no_rollouts_correct_frac"], + "critic/extrinsic_reward/some_rollouts_correct_frac": three_case_stats["some_rollouts_correct_frac"], + "critic/extrinsic_reward/all_rollouts_correct_frac": three_case_stats["all_rollouts_correct_frac"], + # total_reward + "critic/total_reward/mean": np.mean(sequence_total_reward).item(), + "critic/total_reward/max": np.max(sequence_total_reward).item(), + "critic/total_reward/min": np.min(sequence_total_reward).item(), + "critic/total_reward/std": np.std(sequence_total_reward).item(), + } + if elliptical + else {} + ), + # response length + "response_length/mean": torch.mean(response_length).detach().item(), + "response_length/max": torch.max(response_length).detach().item(), + "response_length/min": torch.min(response_length).detach().item(), + "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()) + .detach() + .item(), + # response length (non-aborted only) + # These statistics exclude aborted samples to avoid skew from zeros + "response_length_non_aborted/mean": non_aborted_response_length_mean, + "response_length_non_aborted/max": non_aborted_response_length_max, + "response_length_non_aborted/min": non_aborted_response_length_min, + "response_length_non_aborted/clip_ratio": non_aborted_response_length_clip_ratio, + # aborted ratio + # Fraction of samples whose response length is zero + "response/aborted_ratio": aborted_ratio, + # prompt length + "prompt_length/mean": torch.mean(prompt_length).detach().item(), + "prompt_length/max": torch.max(prompt_length).detach().item(), + "prompt_length/min": torch.min(prompt_length).detach().item(), + "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), + } + + # multi-turn conversation + 
if "__num_turns__" in batch.non_tensor_batch: + num_turns = batch.non_tensor_batch["__num_turns__"] + metrics["num_turns/min"] = num_turns.min() + metrics["num_turns/max"] = num_turns.max() + metrics["num_turns/mean"] = num_turns.mean() + + if "tool_call_counts" in batch.non_tensor_batch: + tool_call_counts = batch.non_tensor_batch["tool_call_counts"] + metrics["tool_call_counts/min"] = tool_call_counts.min() + metrics["tool_call_counts/max"] = tool_call_counts.max() + metrics["tool_call_counts/mean"] = tool_call_counts.mean() + + return metrics + + +def comb_estimator(n: int, c: int, k: int) -> float: + """Calculates 1 - comb(n - c, k) / comb(n, k).""" + if n - c < k: + return 1.0 + return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)) + + +def process_validation_metrics( + data_sources: list[str], sample_uids: list[str], infos_dict: dict[str, list[Any]], seed: int = 42 +) -> dict[str, dict[str, dict[str, float]]]: + """ + Process validation metrics into a structured format with statistical analysis. + + This function organizes validation metrics by data source and prompt, then computes + various statistical measures including means, standard deviations, best/worst values, + and majority voting results. It also performs bootstrap sampling to estimate statistics + for different sample sizes. + + Args: + data_sources: List of data source identifiers for each sample. + sample_uids: List of sample uids corresponding to each sample. + infos_dict: Dictionary mapping variable names to lists of values for each sample. + seed: Random seed for bootstrap sampling. Defaults to 42. 
+ + Returns: + A nested dictionary with the structure: + { + data_source: { + variable_name: { + metric_name: value + } + } + } + + Where metric_name includes: + - "mean@N": Mean value across N samples + - "std@N": Standard deviation across N samples + - "best@N/mean": Mean of the best values in bootstrap samples of size N + - "best@N/std": Standard deviation of the best values in bootstrap samples + - "worst@N/mean": Mean of the worst values in bootstrap samples + - "worst@N/std": Standard deviation of the worst values in bootstrap samples + - "maj@N/mean": Mean of majority voting results in bootstrap samples (if "pred" exists) + - "maj@N/std": Standard deviation of majority voting results (if "pred" exists) + + Example: + >>> data_sources = ["source1", "source1", "source2"] + >>> sample_uids = ["uid1", "uid1", "uid2"] + >>> infos_dict = {"score": [0.8, 0.9, 0.7], "pred": ["A", "A", "B"]} + >>> result = process_validation_metrics(data_sources, sample_uids, infos_dict) + >>> # result will contain statistics for each data source and variable + """ + # Group metrics by data source, prompt and variable + data_src2uid2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + for sample_idx, data_source in enumerate(data_sources): + uid = sample_uids[sample_idx] + var2vals = data_src2uid2var2vals[data_source][uid] + for var_name, var_vals in infos_dict.items(): + var2vals[var_name].append(var_vals[sample_idx]) + + # Calculate metrics for each group + data_src2uid2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) + for data_source, uid2var2vals in data_src2uid2var2vals.items(): + for uid, var2vals in uid2var2vals.items(): + for var_name, var_vals in var2vals.items(): + if isinstance(var_vals[0], str): + continue + + metric = {} + n_resps = len(var_vals) + metric[f"mean@{n_resps}"] = np.mean(var_vals) + metric["pass@1/mean"] = comb_estimator(n_resps, np.sum(var_vals), 1) + + if n_resps > 1: + metric[f"std@{n_resps}"] = 
np.std(var_vals) + + ns = [] + n = 2 + while n < n_resps: + ns.append(n) + n *= 2 + ns.append(n_resps) + + for n in ns: + # [(bon_mean, bon_std), (won_mean, won_std)] = bootstrap_metric( + # data=var_vals, subset_size=n, reduce_fns=[np.max, np.min], seed=seed + # ) + # metric[f"best@{n}/mean"], metric[f"best@{n}/std"] = bon_mean, bon_std + # metric[f"worst@{n}/mean"], metric[f"worst@{n}/std"] = won_mean, won_std + metric[f"pass@{n}/mean"] = comb_estimator(n_resps, np.sum(var_vals), n) + if var2vals.get("pred", None) is not None: + vote_data = [ + {"val": val, "pred": pred} for val, pred in zip(var_vals, var2vals["pred"], strict=True) + ] + [(maj_n_mean, maj_n_std)] = bootstrap_metric( + data=vote_data, + subset_size=n, + reduce_fns=[partial(calc_maj_val, vote_key="pred", val_key="val")], + seed=seed, + ) + metric[f"maj@{n}/mean"], metric[f"maj@{n}/std"] = maj_n_mean, maj_n_std + + data_src2uid2var2metric[data_source][uid][var_name] = metric + + # Aggregate metrics across uids + data_src2var2metric2uid_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + for data_source, uid2var2metric in data_src2uid2var2metric.items(): + for uid, var2metric in uid2var2metric.items(): + for var_name, metric in var2metric.items(): + for metric_name, metric_val in metric.items(): + data_src2var2metric2uid_vals[data_source][var_name][metric_name].append(metric_val) + + data_src2var2metric2val = defaultdict(lambda: defaultdict(lambda: defaultdict(float))) + for data_source, var2metric2uid_vals in data_src2var2metric2uid_vals.items(): + for var_name, metric2uid_vals in var2metric2uid_vals.items(): + for metric_name, uid_vals in metric2uid_vals.items(): + data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(uid_vals) + + return data_src2var2metric2val diff --git a/recipe/rep_exp/model_merge.sh b/recipe/rep_exp/model_merge.sh new file mode 100644 index 00000000000..b7c673ebe8c --- /dev/null +++ b/recipe/rep_exp/model_merge.sh @@ -0,0 +1,6 @@ 
+CHECKPOINT_PATH=${1} # /path/to/global_step_X/actor, where X is the global step of the checkpoint with the best pass@1 on dev
+
+python3 -m verl.model_merger merge \
+    --backend fsdp \
+    --local_dir "$CHECKPOINT_PATH" \
+    --target_dir "$CHECKPOINT_PATH/hf"
\ No newline at end of file
diff --git a/recipe/rep_exp/plot_pass_at_k.py b/recipe/rep_exp/plot_pass_at_k.py
new file mode 100644
index 00000000000..eea2011bab0
--- /dev/null
+++ b/recipe/rep_exp/plot_pass_at_k.py
@@ -0,0 +1,241 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Plot the pass@k results from RepExp RL training.
+""" + +import json +import os +from collections import defaultdict + +import matplotlib.pyplot as plt +import numpy as np +import scipy.stats as stats +import seaborn as sns +from matplotlib.lines import Line2D + +# Content configuration +EVAL_FOLDER = "./eval" +TASKS = ["math"] # ["math", "gsm8k", "dapo-with-aime24"] +SEEDS = [41, 42, 43] +ALGORITHMS = ["elliptical"] # ["grpo", "elliptical", "untrained", "unlikely"] +LOG_AXES = True + +# Plot configuration +FACE_COLOR = "#F7F7FF" +MARKER = "o" +LINEWIDTH = 1.275 +MARKERSIZE = 6 +MARKEREDGEWIDTH = 0.9 +LABEL_FONT_SIZE = 10 +TITLE_FONT_SIZE = 11 +TICK_LABEL_FONT_SIZE = 8 +LEGEND_FONT_SIZE = 8 + +TASK_TO_NICE_NAME = { + "math": "MATH", + "gsm8k": "GSM8K", + "dapo-with-aime24": "AIME 2024", + "countdown-4": "Countdown", +} + +ALGO_TO_COLOR = { + "grpo": sns.color_palette("deep")[-1], + "untrained": sns.color_palette("deep")[7], + "elliptical": sns.color_palette("colorblind")[2], + "unlikely": sns.color_palette("deep")[1], +} + +ALGO_TO_NICE_NAME = { + "grpo": "GRPO", + "untrained": "Base Model", + "elliptical": r"RepExp (ours)", + "unlikely": "Unlikeliness", +} + + +def process_data(data: list[dict[str, float]], algorithm: str) -> tuple[dict[int, float], dict[int, float]]: + """ + Process the pass@k data generated by a given algorithm. + + Args: + data (List[Dict]): The data to process. + algorithm (str): Algorithm that generated the data. + + Returns: + Tuple[Dict[int, float], Dict[int, float]]: + pass_at_k - The mean pass@k values. + pass_at_k_sem - The standard error of the pass@k values. 
+ """ + pass_at_k = defaultdict(list) + for d in data: + for key, v in d.items(): + for k in [1, 2, 4, 8, 16, 32, 64, 128, 256]: + if key.endswith(f"reward/pass@{k}/mean"): + pass_at_k[k].append(v) + + # NOTE: we only use a single seed for untrained since there is only one checkpoint for it + if algorithm != "untrained": + for k in pass_at_k.keys(): + assert len(pass_at_k[k]) == len(SEEDS) + + pass_at_k_sem = {k: stats.sem(v) for k, v in pass_at_k.items()} if algorithm != "untrained" else None + pass_at_k = {k: np.mean(v) for k, v in pass_at_k.items()} + + return pass_at_k, pass_at_k_sem + + +def main(): + # Get all top-level folders in EVAL_FOLDER + eval_folders = os.listdir(EVAL_FOLDER) + + # Figure setup + sns.set_style("whitegrid") + fig, axs = plt.subplots(1, len(TASKS), figsize=(3 * len(TASKS), 3)) + + for i, task in enumerate(TASKS): + ax = axs[i] if len(TASKS) > 1 else axs + algo_to_xs = {} + algo_to_ys = {} + + for algorithm in ALGORITHMS: + # Get all eval folders for the current task and algorithm + folders = [f for f in eval_folders if f.startswith(f"{task}_{algorithm}")] + if len(folders) == 0: + continue + + data = [] + for folder in folders: + if algorithm == "untrained": + with open(os.path.join(EVAL_FOLDER, folder, "eval.json")) as f: + data.append(json.load(f)) + else: + # walk all files recursively in folder + for root, dirs, files in os.walk(os.path.join(EVAL_FOLDER, folder)): + for file in files: + if file.endswith("eval.json"): + with open(os.path.join(root, file)) as f: + data.append(json.load(f)) + break + + pass_at_k, pass_at_k_sem = process_data(data, algorithm) + + xs = np.array(list(pass_at_k.keys())) + ys = np.array([pass_at_k[k] for k in xs]) + algo_to_xs[algorithm] = xs + algo_to_ys[algorithm] = ys + + # Plot the current task - algorithm data + ax.plot( + xs, + ys, + color=ALGO_TO_COLOR[algorithm], + label=algorithm, + markeredgecolor=FACE_COLOR, + marker=MARKER, + linewidth=LINEWIDTH, + markersize=MARKERSIZE, + 
markeredgewidth=MARKEREDGEWIDTH, + alpha=1.0 if algorithm != "untrained" else 0.8, + ) + + # Plot the standard error in shaded bands + if algorithm != "untrained": + sems = np.array([pass_at_k_sem[k] for k in xs]) + ax.fill_between(xs, ys - sems, ys + sems, alpha=0.2, color=ALGO_TO_COLOR[algorithm]) + + # Set y-axis limits + if task == "math": + y_min = 0.7 + ax.set_ylim(top=0.95, bottom=y_min) + elif task == "gsm8k": + y_min = 0.925 + ax.set_ylim(top=0.995, bottom=y_min) + elif task == "dapo-with-aime24": + y_min = 0.1 + ax.set_ylim(bottom=y_min, top=0.63) + + # Set x-axis limits + if LOG_AXES: + ax.set_xlim(left=2 ** (-0.2), right=2 ** (8.2)) + else: + ax.set_xlim(left=-10, right=266) + + # Set x-axis scale and ticks + if LOG_AXES: + ax.set_xscale("log", base=2) + x_ticks = [2**i for i in range(int(np.log2(max(xs))) + 1)] + x_tick_labels = [f"$2^{{{i}}}$" for i in range(int(np.log2(max(xs))) + 1)] + else: + # set every 64 + x_ticks = [1, 32, 64, 96, 128, 160, 192, 224, 256] + x_tick_labels = ["1", "32", "64", "96", "128", "160", "192", "224", "256"] + ax.set_xticks(x_ticks, x_tick_labels) + + # Set axes labels + ax.set_xlabel("k", fontsize=LABEL_FONT_SIZE) + if i == 0: + ax.set_ylabel("Pass@k", fontsize=LABEL_FONT_SIZE) + + # Set title + ax.set_title(f"{TASK_TO_NICE_NAME[task]}", fontsize=TITLE_FONT_SIZE) + + # Set font size for tick labels + for _label in ax.get_xticklabels(): + _label.set_fontsize(TICK_LABEL_FONT_SIZE) + for _label in ax.get_yticklabels(): + _label.set_fontsize(TICK_LABEL_FONT_SIZE) + + # Create legend handles + legend_handles = [ + Line2D( + [0], + [0], + color=ALGO_TO_COLOR[algo], + marker=MARKER, + linestyle="-", + linewidth=LINEWIDTH, + markersize=MARKERSIZE, + markeredgewidth=MARKEREDGEWIDTH, + markeredgecolor=FACE_COLOR, + label=ALGO_TO_NICE_NAME[algo], + ) + for algo in ALGORITHMS + ] + + # Create legend + legend = fig.legend( + handles=legend_handles, + loc="lower center", + ncol=len(ALGORITHMS), + bbox_to_anchor=(0.5, -0.07), + 
fontsize=LEGEND_FONT_SIZE,
+    )
+
+    plt.tight_layout()
+
+    os.makedirs("figures", exist_ok=True)
+    # Save figure; join task names so the filename does not contain a Python list repr
+    plt.savefig(
+        os.path.join("figures", f"rl_pass_at_k_{'-'.join(TASKS)}{'' if LOG_AXES else '_linear_axes'}.pdf"),
+        bbox_extra_artists=(legend,),
+        bbox_inches="tight",
+    )
+
+    # Close figure
+    plt.close()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/recipe/rep_exp/rep_exp_trainer.py b/recipe/rep_exp/rep_exp_trainer.py
new file mode 100644
index 00000000000..c7c23b848d7
--- /dev/null
+++ b/recipe/rep_exp/rep_exp_trainer.py
@@ -0,0 +1,739 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright 2023-2024 SGLang Team
+# Copyright 2025 ModelBest Inc. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+PPO Trainer with Ray-based single controller.
+This trainer supports model-agnostic model initialization with HuggingFace
+"""
+
+import json
+import os
+import uuid
+from collections import defaultdict
+from copy import deepcopy
+from pprint import pprint
+
+import numpy as np
+import ray
+import torch
+from omegaconf import OmegaConf
+from tqdm import tqdm
+
+from verl import DataProto
+from verl.experimental.dataset.sampler import AbstractCurriculumSampler
+from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
+from verl.single_controller.ray import RayClassWithInitArgs
+from verl.single_controller.ray.base import create_colocated_worker_cls
+from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss
+from verl.trainer.ppo.metric_utils import (
+    compute_throughout_metrics,
+    compute_timing_metrics,
+)
+from verl.trainer.ppo.ray_trainer import RayPPOTrainer, apply_kl_penalty, compute_advantage, compute_response_mask
+from verl.trainer.ppo.reward import compute_reward, compute_reward_async
+from verl.trainer.ppo.utils import Role
+from verl.utils.checkpoint.checkpoint_manager import should_save_ckpt_esi
+from verl.utils.config import omega_conf_to_dataclass
+from verl.utils.debug import marked_timer
+from verl.utils.metric import reduce_metrics
+from verl.utils.rollout_skip import RolloutSkip
+
+from .metric_utils import compute_data_metrics, process_validation_metrics
+
+
+class RayRepExpTrainer(RayPPOTrainer):
+    """Distributed RepExp trainer using Ray for scalable reinforcement learning.
+
+    See RayPPOTrainer parent class for more details.
+    """
+
+    def _save_checkpoint(self):
+        super()._save_checkpoint()
+
+        # Write best metric to global steps
+        local_best_metric_to_global_step = os.path.join(
+            self.config.trainer.default_local_dir, "best_metric_to_global_step.json"
+        )
+        with open(local_best_metric_to_global_step, "w") as f:
+            json.dump(self.best_dev_pass_at_k_to_global_step, f)
+
+    def _update_best_pass_at(self, val_metrics: dict[str, float], pass_at_k: int) -> bool:
+        """
+        Update the best dev pass@k seen so far and record the global step at which it occurred.
+
+        Args:
+            val_metrics: The validation metrics.
+            pass_at_k: The pass@k value used to determine whether a new best was reached.
+
+        Returns:
+            True if a new best pass@k was recorded, False otherwise.
+        """
+        for k in val_metrics.keys():
+            if k.endswith(f"reward/pass@{pass_at_k}/mean"):
+                if val_metrics[k] > self.best_dev_pass_at_k[pass_at_k]:
+                    self.best_dev_pass_at_k[pass_at_k] = val_metrics[k]
+                    self.best_dev_pass_at_k_to_global_step[pass_at_k] = self.global_steps
+                    return True
+
+        return False
+
+    def _validate(self):
+        data_source_lst = []
+        reward_extra_infos_dict: dict[str, list] = defaultdict(list)
+
+        # Lists to collect samples for the table
+        sample_inputs = []
+        sample_outputs = []
+        sample_gts = []
+        sample_scores = []
+        sample_turns = []
+        sample_uids = []
+
+        for test_data in tqdm(self.val_dataloader, desc="Validating ..."):
+            test_batch = DataProto.from_single_dict(test_data)
+
+            if "uid" not in test_batch.non_tensor_batch:
+                test_batch.non_tensor_batch["uid"] = np.array(
+                    [str(uuid.uuid4()) for _ in range(len(test_batch.batch))], dtype=object
+                )
+
+            # repeat test batch
+            test_batch = test_batch.repeat(
+                repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True
+            )
+
+            # we only do validation on rule-based rm
+            if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model":
+                return {}
+
+            # Store original inputs
+            input_ids = test_batch.batch["input_ids"]
+            # TODO: Can we keep special tokens except for padding tokens?
+ input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] + sample_inputs.extend(input_texts) + sample_uids.extend(test_batch.non_tensor_batch["uid"]) + + ground_truths = [ + item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch + ] + sample_gts.extend(ground_truths) + + test_gen_batch = self._get_gen_batch(test_batch) + test_gen_batch.meta_info = { + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, + "recompute_log_prob": False, + "do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample, + "validate": True, + "global_steps": self.global_steps, + } + print(f"test_gen_batch meta info: {test_gen_batch.meta_info}") + + # pad to be divisible by dp_size + size_divisor = ( + self.actor_rollout_wg.world_size + if not self.async_rollout_mode + else self.config.actor_rollout_ref.rollout.agent.num_workers + ) + test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor) + if not self.async_rollout_mode: + test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) + else: + test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded) + + # unpad + test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) + + print("validation generation end") + + # Store generated outputs + output_ids = test_output_gen_batch.batch["responses"] + output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] + sample_outputs.extend(output_texts) + + test_batch = test_batch.union(test_output_gen_batch) + test_batch.meta_info["validate"] = True + + # evaluate using reward_function + if self.val_reward_fn is None: + raise ValueError("val_reward_fn must be provided for validation.") + result = self.val_reward_fn(test_batch, return_dict=True) + reward_tensor = result["reward_tensor"] + scores = 
reward_tensor.sum(-1).cpu().tolist() + sample_scores.extend(scores) + + reward_extra_infos_dict["reward"].extend(scores) + if "reward_extra_info" in result: + for key, lst in result["reward_extra_info"].items(): + reward_extra_infos_dict[key].extend(lst) + + # collect num_turns of each prompt + if "__num_turns__" in test_batch.non_tensor_batch: + sample_turns.append(test_batch.non_tensor_batch["__num_turns__"]) + + data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])) + + self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) + + # dump generations + val_data_dir = self.config.trainer.get("validation_data_dir", None) + if val_data_dir: + self._dump_generations( + inputs=sample_inputs, + outputs=sample_outputs, + gts=sample_gts, + scores=sample_scores, + reward_extra_infos_dict=reward_extra_infos_dict, + dump_path=val_data_dir, + ) + + for key_info, lst in reward_extra_infos_dict.items(): + assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}" + + data_sources = np.concatenate(data_source_lst, axis=0) + + data_src2var2metric2val = process_validation_metrics(data_sources, sample_uids, reward_extra_infos_dict) + metric_dict = {} + for data_source, var2metric2val in data_src2var2metric2val.items(): + core_var = "acc" if "acc" in var2metric2val else "reward" + for var_name, metric2val in var2metric2val.items(): + n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()]) + for metric_name, metric_val in metric2val.items(): + if ( + (var_name == core_var) + and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"]) + and (f"@{n_max}" in metric_name) + ): + metric_sec = "val-core" + else: + metric_sec = "val-aux" + pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}" + metric_dict[pfx] = metric_val + + if len(sample_turns) > 0: + sample_turns = np.concatenate(sample_turns) + 
metric_dict["val-aux/num_turns/min"] = sample_turns.min() + metric_dict["val-aux/num_turns/max"] = sample_turns.max() + metric_dict["val-aux/num_turns/mean"] = sample_turns.mean() + + return metric_dict + + def init_workers(self): + """Initialize distributed training workers using Ray backend. + + Creates: + 1. Ray resource pools from configuration + 2. Worker groups for each role (actor, critic, etc.) + """ + self.resource_pool_manager.create_resource_pool() + + self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} + val_only = self.config.trainer.get("val_only", False) + + # create actor and rollout + actor_role = Role.ActorRolloutRef if Role.ActorRolloutRef in self.role_worker_mapping else Role.ActorRollout + if self.hybrid_engine: + resource_pool = self.resource_pool_manager.get_resource_pool(actor_role) + actor_rollout_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[actor_role], + config=self.config.actor_rollout_ref, + role=str(actor_role), + ) + self.resource_pool_to_cls[resource_pool][str(actor_role)] = actor_rollout_cls + else: + raise NotImplementedError + + # create critic + if self.use_critic and not val_only: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) + critic_cfg = omega_conf_to_dataclass(self.config.critic) + critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg) + self.resource_pool_to_cls[resource_pool][str(Role.Critic)] = critic_cls + + # create reference policy if needed + if self.use_reference_policy and not val_only: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) + ref_policy_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RefPolicy], + config=self.config.actor_rollout_ref, + role=str(Role.RefPolicy), + ) + self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls + + # create a reward model if reward_fn is None + if self.use_rm and not val_only: 
+ # we create a RM here + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) + rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model) + self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls + + # initialize WorkerGroup + # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, + # you should not use `create_colocated_worker_cls`. + # Instead, directly pass different resource pool to different worker groups. + # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. + all_wg = {} + wg_kwargs = {} # Setting up kwargs for RayWorkerGroup + if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: + wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout + if OmegaConf.select(self.config.global_profiler, "steps") is not None: + wg_kwargs["profile_steps"] = OmegaConf.select(self.config.global_profiler, "steps") + # Only require nsight worker options when tool is nsys + if OmegaConf.select(self.config.global_profiler, "tool") == "nsys": + assert ( + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + is not None + ), "worker_nsight_options must be set when using nsys with profile_steps" + wg_kwargs["worker_nsight_options"] = OmegaConf.to_container( + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + ) + wg_kwargs["device_name"] = self.device_name + + for resource_pool, class_dict in self.resource_pool_to_cls.items(): + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) + wg_dict = self.ray_worker_group_cls( + resource_pool=resource_pool, + ray_cls_with_init=worker_dict_cls, + **wg_kwargs, + ) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + + if 
self.use_critic:
+            self.critic_wg = all_wg[str(Role.Critic)]
+            self.critic_wg.init_model()
+
+        if self.use_reference_policy and not self.ref_in_actor:
+            if str(Role.RefPolicy) in all_wg:
+                self.ref_policy_wg = all_wg[str(Role.RefPolicy)]
+                self.ref_policy_wg.init_model()
+            else:
+                # Model engine: ActorRolloutRefWorker
+                assert str(Role.ActorRolloutRef) in all_wg, f"{all_wg.keys()=}"
+                self.ref_policy_wg = all_wg[str(Role.ActorRolloutRef)]
+
+        self.rm_wg = None
+        # initialization of rm_wg will be deprecated in the future
+        if self.use_rm:
+            self.rm_wg = all_wg[str(Role.RewardModel)]
+            self.rm_wg.init_model()
+
+        # we should create rollout at the end so that vllm can have a better estimation of kv cache memory
+        self.actor_rollout_wg = all_wg[str(actor_role)]
+        self.actor_rollout_wg.init_model()
+
+        # create async rollout manager and request scheduler
+        self.async_rollout_mode = False
+        if self.config.actor_rollout_ref.rollout.mode == "async":
+            from verl.experimental.agent_loop import AgentLoopManager
+
+            self.async_rollout_mode = True
+            self.async_rollout_manager = AgentLoopManager(
+                config=self.config, worker_group=self.actor_rollout_wg, rm_wg=self.rm_wg
+            )
+
+    def fit(self):
+        """
+        The training loop of PPO.
+        The driver process only needs to call the compute functions of the worker group through RPC
+        to construct the PPO dataflow.
+        The lightweight advantage computation is done on the driver process.
+ """ + from omegaconf import OmegaConf + + from .utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + # global vars to track during training + self.global_steps = 0 + + self.best_dev_pass_at_k = { + 1: 0, + } + self.best_dev_pass_at_k_to_global_step = { + 1: 0, + } + + # load checkpoint before doing anything + self._load_checkpoint() + + current_epoch = self.global_steps // len(self.train_dataloader) + + # perform validation before training + # currently, we only support validation using the reward_function. + if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + val_metrics = self._validate() + assert val_metrics, f"{val_metrics=}" + + # Initialize the best validation metrics for pass@k before training + self._update_best_pass_at(val_metrics, 1) + val_metrics["best/pass@1"] = self.best_dev_pass_at_k[1] + + pprint(f"Initial validation metrics: {val_metrics}") + logger.log(data=val_metrics, step=self.global_steps) + + if self.config.trainer.get("val_only", False): + return + + if self.config.actor_rollout_ref.rollout.get("skip_rollout", False): + rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg) + rollout_skip.wrap_generate_sequences() + + # add tqdm + progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + + # we start from step 1 + self.global_steps += 1 + last_val_metrics = None + self.max_steps_duration = 0 + + prev_step_profile = False + curr_step_profile = ( + self.global_steps in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + next_step_profile = False + + for epoch in range(current_epoch, self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + metrics = {} + 
timing_raw = {} + + with marked_timer("start_profile", timing_raw): + self._start_profiling( + not prev_step_profile and curr_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + batch: DataProto = DataProto.from_single_dict(batch_dict) + batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature + + # add uid to batch + batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object + ) + + gen_batch = self._get_gen_batch(batch) + + # pass global_steps to trace + gen_batch.meta_info["global_steps"] = self.global_steps + gen_batch_output = gen_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True + ) + + is_last_step = self.global_steps >= self.total_training_steps + with marked_timer("step", timing_raw): + # generate a batch + with marked_timer("gen", timing_raw, color="red"): + if not self.async_rollout_mode: + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output) + else: + gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output) + + timing_raw.update(gen_batch_output.meta_info["timing"]) + gen_batch_output.meta_info.pop("timing", None) + + if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: + if self.reward_fn is None: + raise ValueError("A reward_fn is required for REMAX advantage estimation.") + + with marked_timer("gen_max", timing_raw, color="purple"): + gen_baseline_batch = deepcopy(gen_batch) + gen_baseline_batch.meta_info["do_sample"] = False + if not self.async_rollout_mode: + gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + else: + gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch) + batch = batch.union(gen_baseline_output) + # compute reward model score on batch + rm_scores = None + if self.use_rm and "rm_scores" not in batch.batch.keys(): + rm_scores = 
self.rm_wg.compute_rm_score(batch) + batch = batch.union(rm_scores) + reward_baseline_tensor, _ = compute_reward(batch, self.reward_fn) + reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) + + keys_to_pop = set(gen_baseline_output.batch.keys()) + if rm_scores is not None: + keys_to_pop.update(rm_scores.batch.keys()) + batch.pop(batch_keys=list(keys_to_pop)) + + batch.batch["reward_baselines"] = reward_baseline_tensor + + del rm_scores, gen_baseline_batch, gen_baseline_output + # repeat to align with repeated responses in rollout + batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.union(gen_batch_output) + + if "response_mask" not in batch.batch.keys(): + batch.batch["response_mask"] = compute_response_mask(batch) + # Balance the number of valid tokens across DP ranks. + # NOTE: This usually changes the order of data in the `batch`, + # which won't affect the advantage calculation (since it's based on uid), + # but might affect the loss calculation (due to the change of mini-batching). 
+ if self.config.trainer.balance_batch: + self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + + with marked_timer("reward", timing_raw, color="yellow"): + # compute reward model score + if self.use_rm and "rm_scores" not in batch.batch.keys(): + if self.config.reward_model.elliptical.enable: + hidden_states = self.rm_wg.compute_hidden_states(batch) + batch = batch.union(hidden_states) + reward_tensor = self.rm_wg.compute_rm_score(batch) + else: + reward_tensor = self.rm_wg.compute_rm_score(batch) + batch = batch.union(reward_tensor) + + if self.config.reward_model.launch_reward_fn_async: + future_reward = compute_reward_async.remote( + data=batch, config=self.config, tokenizer=self.tokenizer + ) + else: + reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn) + + # Operating Mode Selection: + # - Bypass mode: Sets old_log_probs = rollout_log_probs (2 policies: π_rollout, π_θ) + # - Decoupled mode: Recomputes old_log_probs as proximal anchor (3 policies: π_rollout, π_old, π_θ) + # Note: π_old computed once per data batch, serves as stable reference during mini-batch updates + rollout_corr_config = self.config.algorithm.get("rollout_correction", None) + bypass_recomputing_logprobs = rollout_corr_config and rollout_corr_config.get("bypass_mode", False) + if bypass_recomputing_logprobs: # Use `rollout_log_probs` + from verl.trainer.ppo.rollout_corr_helper import apply_rollout_correction + + apply_rollout_correction( + batch=batch, + rollout_corr_config=rollout_corr_config, + policy_loss_config=self.config.actor_rollout_ref.actor.policy_loss, + ) + else: # Recompute old_log_probs + with marked_timer("old_log_prob", timing_raw, color="blue"): + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + entropys = old_log_prob.batch["entropys"] + response_masks = batch.batch["response_mask"] + loss_agg_mode = 
self.config.actor_rollout_ref.actor.loss_agg_mode + entropy_agg = agg_loss( + loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode + ) + old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} + metrics.update(old_log_prob_metrics) + old_log_prob.batch.pop("entropys") + batch = batch.union(old_log_prob) + if "rollout_log_probs" in batch.batch.keys(): + # TODO: we may want to add diff of probs too. + from verl.utils.debug.metrics import calculate_debug_metrics + + metrics.update(calculate_debug_metrics(batch)) + + assert "old_log_probs" in batch.batch, f'"old_log_prob" not in {batch.batch.keys()=}' + + if self.use_reference_policy: + # compute reference log_prob + with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"): + if not self.ref_in_actor: + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + else: + ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + # compute values + if self.use_critic: + with marked_timer("values", timing_raw, color="cyan"): + values = self.critic_wg.compute_values(batch) + batch = batch.union(values) + + with marked_timer("adv", timing_raw, color="brown"): + # we combine with rule-based rm + reward_extra_infos_dict: dict[str, list] + if self.config.reward_model.launch_reward_fn_async: + reward_tensor, reward_extra_infos_dict = ray.get(future_reward) + batch.batch["token_level_scores"] = reward_tensor + + if reward_extra_infos_dict: + batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) + + # compute rewards. 
apply_kl_penalty if available + if self.config.algorithm.use_kl_in_reward: + batch, kl_metrics = apply_kl_penalty( + batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + ) + metrics.update(kl_metrics) + else: + batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] + + # Compute rollout correction: IS weights, rejection sampling, and metrics + # Only runs in decoupled mode (computes once per batch using stable π_old) + # In bypass mode, this is skipped - actor computes metrics from evolving π_θ vs π_rollout + if ( + rollout_corr_config is not None + and "rollout_log_probs" in batch.batch + and not bypass_recomputing_logprobs # Only in decoupled mode + ): + from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch + + # Compute IS weights, apply rejection sampling, compute metrics + batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config) + # IS and off-policy metrics already have rollout_corr/ prefix + metrics.update(is_metrics) + + # compute advantages, executed on the driver process + norm_adv_by_std_in_grpo = self.config.algorithm.get( + "norm_adv_by_std_in_grpo", True + ) # GRPO adv normalization factor + + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + config=self.config.algorithm, + ) + + # update critic + if self.use_critic: + with marked_timer("update_critic", timing_raw, color="pink"): + critic_output = self.critic_wg.update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with marked_timer("update_actor", timing_raw, color="red"): + 
rollout_config = self.config.actor_rollout_ref.rollout + batch.meta_info["multi_turn"] = rollout_config.multi_turn.enable + # TODO: Make "temperature" single source of truth from generation. + batch.meta_info["temperature"] = rollout_config.temperature + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + + # Log rollout generations if enabled + rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) + if rollout_data_dir: + self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir) + + # validate + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with marked_timer("testing", timing_raw, color="green"): + val_metrics: dict = self._validate() + + # Initialize the best validation metrics for pass@k before training + self._update_best_pass_at(val_metrics, 1) + val_metrics["best/pass@1"] = self.best_dev_pass_at_k[1] + + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + + # Check if the ESI (Elastic Server Instance)/training plan is close to expiration. + esi_close_to_expiration = should_save_ckpt_esi( + max_steps_duration=self.max_steps_duration, + redundant_time=self.config.trainer.esi_redundant_time, + ) + # Check if the conditions for saving a checkpoint are met. + # The conditions include a mandatory condition (1) and + # one of the following optional conditions (2/3/4): + # 1. The save frequency is set to a positive value. + # 2. It's the last training step. + # 3. The current step number is a multiple of the save frequency. + # 4. The ESI(Elastic Server Instance)/training plan is close to expiration. 
+ if self.config.trainer.save_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration + ): + if esi_close_to_expiration: + print("Force saving checkpoint: ESI instance expiration approaching.") + with marked_timer("save_checkpoint", timing_raw, color="green"): + self._save_checkpoint() + + with marked_timer("stop_profile", timing_raw): + next_step_profile = ( + self.global_steps + 1 in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + self._stop_profiling( + curr_step_profile and not next_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + prev_step_profile = curr_step_profile + curr_step_profile = next_step_profile + + steps_duration = timing_raw["step"] + self.max_steps_duration = max(self.max_steps_duration, steps_duration) + + # training metrics + metrics.update( + { + "training/global_step": self.global_steps, + "training/epoch": epoch, + } + ) + # collect metrics + metrics.update( + compute_data_metrics( + batch=batch, + use_critic=self.use_critic, + elliptical=self.config.reward_model.elliptical.enable, + ) + ) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + # TODO: implement actual tflpo and theoretical tflpo + n_gpus = self.resource_pool_manager.get_n_gpus() + metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + # Note: mismatch metrics (KL, PPL, etc.) 
are collected after advantage computation
+
+                # this is experimental and may be changed/removed in the future in favor of a general-purpose one
+                if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):
+                    self.train_dataloader.sampler.update(batch=batch)
+
+                # TODO: make a canonical logger that supports various backend
+                logger.log(data=metrics, step=self.global_steps)
+
+                progress_bar.update(1)
+                self.global_steps += 1
+
+                if (
+                    hasattr(self.config.actor_rollout_ref.actor, "profiler")
+                    and self.config.actor_rollout_ref.actor.profiler.tool == "torch_memory"
+                ):
+                    self.actor_rollout_wg.dump_memory_snapshot(
+                        tag=f"post_update_step{self.global_steps}", sub_dir=f"step{self.global_steps}"
+                    )
+
+                if is_last_step:
+                    pprint(f"Final validation metrics: {last_val_metrics}")
+                    progress_bar.close()
+                    return
+
+                # this is experimental and may be changed/removed in the future
+                # in favor of a general-purpose data buffer pool
+                if hasattr(self.train_dataset, "on_batch_end"):
+                    # The dataset may be changed after each training batch
+                    self.train_dataset.on_batch_end(batch=batch)
diff --git a/recipe/rep_exp/reward_manager/elliptical_reward_manager.py b/recipe/rep_exp/reward_manager/elliptical_reward_manager.py
new file mode 100644
index 00000000000..83040ac1c02
--- /dev/null
+++ b/recipe/rep_exp/reward_manager/elliptical_reward_manager.py
@@ -0,0 +1,138 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict
+
+import torch
+
+from verl import DataProto
+from verl.workers.reward_manager import NaiveRewardManager, register
+
+from ..reward_score import default_compute_score
+
+
+@register("elliptical")
+class EllipticalRewardManager(NaiveRewardManager):
+    """Reward manager that combines extrinsic task rewards with an elliptical intrinsic reward."""
+
+    def __init__(
+        self,
+        tokenizer,
+        num_examine,
+        compute_score=None,
+        reward_fn_key="data_source",
+        beta: float = 1.0,
+        turn_off_elliptical_if_none_correct: bool = False,
+        turn_off_elliptical_if_some_correct: bool = False,
+        turn_off_elliptical_if_all_correct: bool = False,
+        turn_off_elliptical_if_rollout_incorrect: bool = False,
+        alpha: float = 1.0,
+    ) -> None:
+        """
+        Initialize the EllipticalRewardManager instance.
+
+        Args:
+            tokenizer: The tokenizer used to decode token IDs into text.
+            num_examine: The number of batches of decoded responses to print to the console for debugging purposes.
+            compute_score: A function to compute the reward score. If None, `default_compute_score` will be used.
+            reward_fn_key: The key used to access the data source in the non-tensor batch data. Defaults to
+                "data_source".
+            beta: Coefficient that scales the intrinsic (elliptical) reward. Defaults to 1.0.
+            turn_off_elliptical_if_none_correct: If True, zero out the intrinsic reward for prompts where no
+                rollout is correct. Defaults to False.
+            turn_off_elliptical_if_some_correct: If True, zero out the intrinsic reward for prompts where some,
+                but not all, rollouts are correct. Defaults to False.
+            turn_off_elliptical_if_all_correct: If True, zero out the intrinsic reward for prompts where all
+                rollouts are correct. Defaults to False.
+            turn_off_elliptical_if_rollout_incorrect: If True, zero out the intrinsic reward for individual
+                incorrect rollouts. Defaults to False.
+            alpha: Coefficient that scales the extrinsic reward. Defaults to 1.0.
+ """ + super().__init__(tokenizer, num_examine, default_compute_score, reward_fn_key) + self.beta = beta + self.turn_off_elliptical_if_none_correct = turn_off_elliptical_if_none_correct + self.turn_off_elliptical_if_some_correct = turn_off_elliptical_if_some_correct + self.turn_off_elliptical_if_all_correct = turn_off_elliptical_if_all_correct + self.turn_off_elliptical_if_rollout_incorrect = turn_off_elliptical_if_rollout_incorrect + self.alpha = alpha + + def __call__(self, data: DataProto, return_dict=False): + if "rm_scores" not in data.batch: + # this means we're doing validation, so we don't need to compute the elliptical reward + return super().__call__(data, return_dict=return_dict) + + reward_extra_info = defaultdict(list) + + intrinsic_reward_tensor = data.batch["rm_scores"] + data.pop(batch_keys=["rm_scores"]) + + extrinsic_reward_result = super().__call__(data, return_dict=True) + extrinsic_reward_tensor = extrinsic_reward_result["reward_tensor"] + extrinsic_reward_extra_info = extrinsic_reward_result["reward_extra_info"] + + self._maybe_turn_off_elliptical(data, extrinsic_reward_tensor, intrinsic_reward_tensor) + + reward_tensor = self.alpha * extrinsic_reward_tensor + self.beta * intrinsic_reward_tensor + + # Intrinsic reward extra info + reward_extra_info["intrinsic_reward"] = intrinsic_reward_tensor.numpy() + reward_extra_info["beta_scaled_intrinsic_reward"] = self.beta * intrinsic_reward_tensor.numpy() + reward_extra_info["extrinsic_reward"] = extrinsic_reward_tensor.numpy() + reward_extra_info["alpha_scaled_extrinsic_reward"] = self.alpha * extrinsic_reward_tensor.numpy() + reward_extra_info["total_reward"] = reward_tensor.numpy() + + # Update with extrinsic reward extra info + reward_extra_info.update(extrinsic_reward_extra_info) + + if return_dict: + return { + "reward_tensor": reward_tensor, + "reward_extra_info": reward_extra_info, + } + else: + return reward_tensor + + def _maybe_turn_off_elliptical( + self, data: DataProto, 
extrinsic_reward_tensor: torch.Tensor, intrinsic_reward_tensor: torch.Tensor
+    ) -> None:
+        """
+        Potentially turn off the elliptical (intrinsic) reward, depending on the configured flags:
+        (1) for individual rollouts with an incorrect answer (`turn_off_elliptical_if_rollout_incorrect`)
+        (2) for prompts where no rollout has the correct answer (`turn_off_elliptical_if_none_correct`)
+        (3) for prompts where some, but not all, rollouts have the correct answer
+            (`turn_off_elliptical_if_some_correct`)
+        (4) for prompts where all rollouts have the correct answer (`turn_off_elliptical_if_all_correct`)
+
+        The intrinsic reward tensor is modified in place.
+
+        Args:
+            data (DataProto): The data proto containing the batch data.
+            extrinsic_reward_tensor (torch.Tensor): The extrinsic reward tensor.
+            intrinsic_reward_tensor (torch.Tensor): The intrinsic reward tensor.
+
+        Returns:
+            None
+        """
+        if self.turn_off_elliptical_if_rollout_incorrect:
+            mask = extrinsic_reward_tensor.sum(dim=-1) == 0
+            intrinsic_reward_tensor[mask] = 0.0
+
+        visited_uids = set()
+        for uid in data.non_tensor_batch["uid"]:
+            if uid in visited_uids:
+                continue
+
+            visited_uids.add(uid)
+            mask = torch.from_numpy(data.non_tensor_batch["uid"] == uid)
+
+            # Potentially turn off elliptical if **no** rollout has the correct answer
+            if self.turn_off_elliptical_if_none_correct and extrinsic_reward_tensor[mask].sum() == 0:
+                intrinsic_reward_tensor[mask] = 0.0
+
+            # Potentially turn off elliptical if **some** rollouts have the correct answer
+            if (
+                self.turn_off_elliptical_if_some_correct
+                and extrinsic_reward_tensor[mask].sum() > 0
+                and extrinsic_reward_tensor[mask].sum() < mask.sum()
+            ):
+                intrinsic_reward_tensor[mask] = 0.0
+
+            # Potentially turn off elliptical if **all** rollouts have the correct answer
+            if self.turn_off_elliptical_if_all_correct and extrinsic_reward_tensor[mask].sum() == mask.sum():
+                intrinsic_reward_tensor[mask] = 0.0
diff --git a/recipe/rep_exp/reward_score/__init__.py b/recipe/rep_exp/reward_score/__init__.py
new file mode 100644
index 00000000000..124189fa228
--- /dev/null
+++ b/recipe/rep_exp/reward_score/__init__.py
@@ -0,0 +1,136 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# from . import gsm8k, math, prime_math, prime_code + +from verl.utils.import_utils import deprecated + + +def default_compute_score( + data_source, + solution_str, + ground_truth, + extra_info=None, + sandbox_fusion_url=None, + concurrent_semaphore=None, + memory_limit_mb=None, + **kwargs, +): + """Compute the score for a given solution based on the data source. + + Args: + data_source (str): The source dataset identifier which determines the scoring method. + solution_str (str): The solution string to be evaluated. + ground_truth (str): The ground truth answer for comparison. + extra_info (dict, optional): Additional information that might be needed for scoring. Defaults to None. + + Returns: + float: The computed score as a floating point number. If the result is a dictionary, + it returns the dictionary instead. + + Raises: + NotImplementedError: If the reward function is not implemented for the given data source. + """ + if data_source == "openai/gsm8k": + from verl.utils.reward_score import gsm8k + + res = gsm8k.compute_score(solution_str, ground_truth) + elif data_source in ["lighteval/MATH", "DigitalLearningGmbH/MATH-lighteval", "HuggingFaceH4/MATH-500"]: + from verl.utils.reward_score import math_reward + + res = math_reward.compute_score(solution_str, ground_truth) + # [Optional] Math-Verify Integration + # For enhanced accuracy, consider utilizing Math-Verify (https://github.com/huggingface/Math-Verify). + # Note: Math-Verify needs to be manually installed via pip: `pip install math-verify`. 
+ # To use it, override the `compute_score` function with the following implementation: + + # from . import math_verify + # res = math_verify.compute_score(solution_str, ground_truth) + elif data_source in ["math_dapo", "math", "math_dapo_reasoning"] or data_source.startswith("aime"): + # res = math_dapo.compute_score(solution_str, ground_truth) + from verl.utils.reward_score import math_verify + + res = math_verify.compute_score(solution_str, ground_truth) + elif data_source in [ + "numina_aops_forum", + "numina_synthetic_math", + "numina_amc_aime", + "numina_synthetic_amc", + "numina_cn_k12", + "numina_olympiads", + ]: + from verl.utils.reward_score import prime_math + + res = prime_math.compute_score(solution_str, ground_truth) + elif data_source in ["codecontests", "apps", "codeforces", "taco"]: + # Use the passed sandbox_fusion_url if available + if sandbox_fusion_url: + from verl.utils.reward_score import sandbox_fusion + + # Pass the URL directly, ground_truth likely contains test cases here + res = sandbox_fusion.compute_score( + sandbox_fusion_url, concurrent_semaphore, memory_limit_mb, solution_str, ground_truth, continuous=True + ) + else: + # If no sandbox URL is provided, fall back to prime_code or raise error + from verl.utils.reward_score import prime_code + + # Assuming prime_code doesn't need the URL + res = prime_code.compute_score(solution_str, ground_truth, continuous=True) + elif data_source in ["hiyouga/geometry3k"]: + from verl.utils.reward_score import geo3k + + res = geo3k.compute_score(solution_str, ground_truth) + elif data_source in [ + "searchR1_nq", + "searchR1_triviaqa", + "searchR1_popqa", + "searchR1_hotpotqa", + "searchR1_2wikimultihopqa", + "searchR1_musique", + "searchR1_bamboogle", + ]: + from verl.utils.reward_score import search_r1_like_qa_em + + res = search_r1_like_qa_em.compute_score(solution_str, ground_truth) + + else: + raise NotImplementedError(f"Reward function is not implemented for {data_source=}") + + if 
isinstance(res, dict): + return res + elif isinstance(res, int | float | bool): + return float(res) + else: + return float(res[0]) + + +@deprecated("verl.utils.reward_score.default_compute_score") +def _default_compute_score( + data_source, + solution_str, + ground_truth, + extra_info=None, + sandbox_fusion_url=None, + concurrent_semaphore=None, + memory_limit_mb=None, +): + """ + Legacy function API to be deprecated. Please use `default_compute_score` instead. + """ + return default_compute_score( + data_source, solution_str, ground_truth, extra_info, sandbox_fusion_url, concurrent_semaphore, memory_limit_mb + ) + + +__all__ = ["default_compute_score"] diff --git a/recipe/rep_exp/train_elliptical.sh b/recipe/rep_exp/train_elliptical.sh new file mode 100644 index 00000000000..70c56afec37 --- /dev/null +++ b/recipe/rep_exp/train_elliptical.sh @@ -0,0 +1,104 @@ +TASK=${1} # math, gsm8k, dapo-with-aime24 +SPARSE_DIM=${2} # the original paper used 32 for math/gsm8k, 128 for dapo-with-aime24 +BETA=${3} # 0.01 +SEED=${4} + +train_path=$HOME/data/${TASK}/train.parquet +dev_path=$HOME/data/${TASK}/dev.parquet + +train_files="['$train_path']" +dev_files="['$dev_path']" + +# Adjust things a bit for dapo-aime training since it has longer generations +# and hence is slower and consumes more memory +if [ ${TASK} == "dapo-with-aime24" ]; then + TEST_FREQ=10 + SAVE_FREQ=10 + TRAIN_BATCH_SIZE=512 + PPO_MINI_BATCH_SIZE=128 + + MAX_PROMPT_LENGTH=$((1024 * 2)) + MAX_RESPONSE_LENGTH=$((1024 * 8)) + MAX_NUM_BATCHED_TOKENS=$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH)) + GPU_MEMORY_UTILIZATION=0.5 + PPO_MICRO_BATCH_SIZE_PER_GPU=8 + REWARD_MODEL_MICRO_BATCH_SIZE_PER_GPU=16 +else + TEST_FREQ=20 + SAVE_FREQ=20 + TRAIN_BATCH_SIZE=1024 + PPO_MINI_BATCH_SIZE=256 + + MAX_PROMPT_LENGTH=1024 + MAX_RESPONSE_LENGTH=1024 + MAX_NUM_BATCHED_TOKENS=8192 + GPU_MEMORY_UTILIZATION=0.6 + PPO_MICRO_BATCH_SIZE_PER_GPU=16 + REWARD_MODEL_MICRO_BATCH_SIZE_PER_GPU=32 +fi + +OFFLINE=True + +PYTHONUNBUFFERED=1 
TRANSFORMERS_OFFLINE=${OFFLINE} python3 -u -m recipe.rep_exp.main_rep_exp \ + algorithm.adv_estimator=grpo \ + data.train_files="$train_files" \ + data.val_files="$dev_files" \ + data.train_batch_size=$TRAIN_BATCH_SIZE \ + data.max_prompt_length=$MAX_PROMPT_LENGTH \ + data.max_response_length=$MAX_RESPONSE_LENGTH \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=$PPO_MINI_BATCH_SIZE \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$PPO_MICRO_BATCH_SIZE_PER_GPU \ + actor_rollout_ref.actor.kl_loss_coef=0.0 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.ppo_epochs=1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.mode=sync \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.max_num_batched_tokens=$MAX_NUM_BATCHED_TOKENS \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=$GPU_MEMORY_UTILIZATION \ + actor_rollout_ref.rollout.n=8 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + reward_model.enable=True \ + reward_model.model.path=Qwen/Qwen2.5-7B-Instruct \ + reward_model.model.use_remove_padding=False \ + reward_model.model.fsdp_config.param_offload=True \ + reward_model.micro_batch_size_per_gpu=$REWARD_MODEL_MICRO_BATCH_SIZE_PER_GPU \ + reward_model.model.input_tokenizer=null \ + reward_model.elliptical.enable=True \ + reward_model.elliptical.sparse_dim=$SPARSE_DIM \ + 
reward_model.elliptical.reward_type=leverage \ + reward_model.elliptical.randomize_sparse_matrix=True \ + reward_model.elliptical.normalization=none \ + reward_model.elliptical.persist_covariance=False \ + reward_model.reward_manager=elliptical \ + reward_model.reward_kwargs.elliptical.beta=$BETA \ + reward_model.reward_kwargs.elliptical.turn_off_elliptical_if_none_correct=True \ + reward_model.reward_kwargs.elliptical.turn_off_elliptical_if_some_correct=False \ + reward_model.reward_kwargs.elliptical.turn_off_elliptical_if_all_correct=False \ + reward_model.reward_kwargs.elliptical.turn_off_elliptical_if_rollout_incorrect=False \ + actor_rollout_ref.actor.loss_agg_mode=token-mean \ + actor_rollout_ref.actor.use_kl_loss=True \ + algorithm.norm_adv_by_std_in_grpo=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='rep-exp' \ + trainer.experiment_name="${TASK}_elliptical_seed_${SEED}_beta_${BETA}_sparse_dim_${SPARSE_DIM}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=$SAVE_FREQ \ + trainer.test_freq=$TEST_FREQ \ + trainer.total_epochs=1000 \ + trainer.resume_mode=disable \ + trainer.resume_from_path='' \ No newline at end of file diff --git a/recipe/rep_exp/utils/aggregate_logger.py b/recipe/rep_exp/utils/aggregate_logger.py new file mode 100644 index 00000000000..54a70272d50 --- /dev/null +++ b/recipe/rep_exp/utils/aggregate_logger.py @@ -0,0 +1,49 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+A logger that aggregates evaluation results and writes them to a JSON file.
+"""
+
+import json
+import os
+
+
+class JsonEvalLogger:
+    """
+    A logger that logs to a JSON file.
+    Args:
+        save_path: The path to the merged model checkpoint being evaluated.
+        task: The task name, used to name the experiment.
+    """
+
+    def __init__(self, save_path: str, task: str):
+        self.root = "eval"
+        if save_path is not None and save_path != "":
+            self.experiment_name = save_path.split("/")[-2]
+            self.checkpoint_type = save_path.split("/")[-1]
+        else:
+            self.experiment_name = f"{task}_untrained"
+            self.checkpoint_type = ""
+
+    def flush(self):
+        pass
+
+    def log(self, data, step):
+        # Create eval folder
+        save_folder = os.path.join(self.root, self.experiment_name, self.checkpoint_type)
+        os.makedirs(save_folder, exist_ok=True)
+
+        # Save to json
+        with open(os.path.join(save_folder, "eval.json"), "w") as f:
+            json.dump(data, f)
diff --git a/recipe/rep_exp/utils/tracking.py b/recipe/rep_exp/utils/tracking.py
new file mode 100644
index 00000000000..898fc0f1aae
--- /dev/null
+++ b/recipe/rep_exp/utils/tracking.py
@@ -0,0 +1,517 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" +A unified tracking interface that supports logging data to different backend +""" + +import dataclasses +import json +import os +from enum import Enum +from functools import partial +from pathlib import Path +from typing import Any + + +class Tracking: + """A unified tracking interface for logging experiment data to multiple backends. + + This class provides a centralized way to log experiment metrics, parameters, and artifacts + to various tracking backends including WandB, MLflow, SwanLab, TensorBoard, and console. + + Attributes: + supported_backend: List of supported tracking backends. + logger: Dictionary of initialized logger instances for each backend. + """ + + supported_backend = [ + "wandb", + "mlflow", + "swanlab", + "vemlp_wandb", + "tensorboard", + "console", + "clearml", + "trackio", + "file", + "json_eval", + ] + + def __init__(self, project_name, experiment_name, default_backend: str | list[str] = "console", config=None): + if isinstance(default_backend, str): + default_backend = [default_backend] + for backend in default_backend: + if backend == "tracking": + import warnings + + warnings.warn("`tracking` logger is deprecated. 
use `wandb` instead.", DeprecationWarning, stacklevel=2) + else: + assert backend in self.supported_backend, f"{backend} is not supported" + + self.logger = {} + + if "tracking" in default_backend or "wandb" in default_backend: + import os + + import wandb + + settings = None + if config and config["trainer"].get("wandb_proxy", None): + settings = wandb.Settings(https_proxy=config["trainer"]["wandb_proxy"]) + entity = os.environ.get("WANDB_ENTITY", None) + wandb.init(project=project_name, name=experiment_name, entity=entity, config=config, settings=settings) + self.logger["wandb"] = wandb + + if "trackio" in default_backend: + import trackio + + trackio.init(project=project_name, name=experiment_name, config=config) + self.logger["trackio"] = trackio + + if "mlflow" in default_backend: + import os + + import mlflow + + MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "sqlite:////tmp/mlruns.db") + mlflow.set_tracking_uri(MLFLOW_TRACKING_URI) + + # Project_name is actually experiment_name in MLFlow + # If experiment does not exist, will create a new experiment + experiment = mlflow.set_experiment(project_name) + mlflow.start_run(experiment_id=experiment.experiment_id, run_name=experiment_name) + mlflow.log_params(_compute_mlflow_params_from_objects(config)) + self.logger["mlflow"] = _MlflowLoggingAdapter() + + if "swanlab" in default_backend: + import os + + import swanlab + + SWANLAB_API_KEY = os.environ.get("SWANLAB_API_KEY", None) + SWANLAB_LOG_DIR = os.environ.get("SWANLAB_LOG_DIR", "swanlog") + SWANLAB_MODE = os.environ.get("SWANLAB_MODE", "cloud") + if SWANLAB_API_KEY: + swanlab.login(SWANLAB_API_KEY) # NOTE: previous login information will be overwritten + + if config is None: + config = {} # make sure config is not None, otherwise **config will raise error + swanlab.init( + project=project_name, + experiment_name=experiment_name, + config={"FRAMEWORK": "verl", **config}, + logdir=SWANLAB_LOG_DIR, + mode=SWANLAB_MODE, + ) + self.logger["swanlab"] = 
swanlab + + if "vemlp_wandb" in default_backend: + import os + + import volcengine_ml_platform + from volcengine_ml_platform import wandb as vemlp_wandb + + volcengine_ml_platform.init( + ak=os.environ["VOLC_ACCESS_KEY_ID"], + sk=os.environ["VOLC_SECRET_ACCESS_KEY"], + region=os.environ["MLP_TRACKING_REGION"], + ) + + vemlp_wandb.init( + project=project_name, + name=experiment_name, + config=config, + sync_tensorboard=True, + ) + self.logger["vemlp_wandb"] = vemlp_wandb + + if "tensorboard" in default_backend: + self.logger["tensorboard"] = _TensorboardAdapter(project_name, experiment_name) + + if "console" in default_backend: + from verl.utils.logger import LocalLogger + + self.console_logger = LocalLogger(print_to_console=True) + self.logger["console"] = self.console_logger + + if "json_eval" in default_backend: + from .aggregate_logger import JsonEvalLogger + + model_path = config["actor_rollout_ref"]["model"]["path"] + if model_path.endswith("actor/hf"): + # Case where the model path is a saved checkpoint + save_path = model_path.split("/")[-4:-2] + save_path = "/".join(save_path) + else: + # Case where the model is pretrained model from huggingface + save_path = "" + + # Parse task from config + train_file = config["data"]["train_files"][0] + task = train_file.split("/")[-2] + + self.json_eval_logger = JsonEvalLogger(save_path=save_path, task=task) + self.logger["json_eval"] = self.json_eval_logger + + if "clearml" in default_backend: + self.logger["clearml"] = ClearMLLogger(project_name, experiment_name, config) + + if "file" in default_backend: + self.logger["file"] = FileLogger(project_name, experiment_name) + + def log(self, data, step, backend=None): + for default_backend, logger_instance in self.logger.items(): + if backend is None or default_backend in backend: + logger_instance.log(data=data, step=step) + + def __del__(self): + if "wandb" in self.logger: + self.logger["wandb"].finish(exit_code=0) + if "swanlab" in self.logger: + 
self.logger["swanlab"].finish() + if "vemlp_wandb" in self.logger: + self.logger["vemlp_wandb"].finish(exit_code=0) + if "tensorboard" in self.logger: + self.logger["tensorboard"].finish() + if "clearml" in self.logger: + self.logger["clearml"].finish() + if "trackio" in self.logger: + self.logger["trackio"].finish() + if "file" in self.logger: + self.logger["file"].finish() + + +class ClearMLLogger: + def __init__(self, project_name: str, experiment_name: str, config): + self.project_name = project_name + self.experiment_name = experiment_name + + import clearml + + self._task: clearml.Task = clearml.Task.init( + task_name=experiment_name, + project_name=project_name, + continue_last_task=True, + output_uri=False, + ) + + self._task.connect_configuration(config, name="Hyperparameters") + + def _get_logger(self): + return self._task.get_logger() + + def log(self, data, step): + import numpy as np + import pandas as pd + + # logs = self._rewrite_logs(data) + logger = self._get_logger() + for k, v in data.items(): + title, series = k.split("/", 1) + + if isinstance(v, int | float | np.floating | np.integer): + logger.report_scalar( + title=title, + series=series, + value=v, + iteration=step, + ) + elif isinstance(v, pd.DataFrame): + logger.report_table( + title=title, + series=series, + table_plot=v, + iteration=step, + ) + else: + logger.warning( + f'Trainer is attempting to log a value of "{v}" of type {type(v)} for key "{k}". This ' + f"invocation of ClearML logger's function is incorrect so this attribute was dropped. 
" + ) + + def finish(self): + self._task.close() + + +class FileLogger: + def __init__(self, project_name: str, experiment_name: str): + self.project_name = project_name + self.experiment_name = experiment_name + + self.filepath = os.getenv("VERL_FILE_LOGGER_PATH", None) + if self.filepath is None: + root_path = os.path.expanduser(os.getenv("VERL_FILE_LOGGER_ROOT", ".")) + directory = os.path.join(root_path, self.project_name) + os.makedirs(directory, exist_ok=True) + self.filepath = os.path.join(directory, f"{self.experiment_name}.jsonl") + print(f"Creating file logger at {self.filepath}") + self.fp = open(self.filepath, "w") + + def log(self, data, step): + data = {"step": step, "data": data} + self.fp.write(json.dumps(data) + "\n") + + def finish(self): + self.fp.close() + + +class _TensorboardAdapter: + def __init__(self, project_name, experiment_name): + import os + + from torch.utils.tensorboard import SummaryWriter + + tensorboard_dir = os.environ.get("TENSORBOARD_DIR", f"tensorboard_log/{project_name}/{experiment_name}") + os.makedirs(tensorboard_dir, exist_ok=True) + print(f"Saving tensorboard log to {tensorboard_dir}.") + self.writer = SummaryWriter(tensorboard_dir) + + def log(self, data, step): + for key in data: + self.writer.add_scalar(key, data[key], step) + + def finish(self): + self.writer.close() + + +class _MlflowLoggingAdapter: + def __init__(self): + import logging + import re + + self.logger = logging.getLogger(__name__) + # MLflow metric key validation logic: + # https://github.com/mlflow/mlflow/blob/master/mlflow/utils/validation.py#L157C12-L157C44 + # Only characters allowed: slashes, alphanumerics, underscores, periods, dashes, colons, + # and spaces. + self._invalid_chars_pattern = re.compile( + r"[^/\w.\- :]" + ) # Allowed: slashes, alphanumerics, underscores, periods, dashes, colons, and spaces. 
+ + def log(self, data, step): + import mlflow + + def sanitize_key(key): + # First replace @ with _at_ for backward compatibility + sanitized = key.replace("@", "_at_") + # Then replace any other invalid characters with _ + sanitized = self._invalid_chars_pattern.sub("_", sanitized) + if sanitized != key: + self.logger.warning( + "[MLflow] Metric key '%s' sanitized to '%s' due to invalid characters.", key, sanitized + ) + return sanitized + + results = {sanitize_key(k): v for k, v in data.items()} + mlflow.log_metrics(metrics=results, step=step) + + +def _compute_mlflow_params_from_objects(params) -> dict[str, Any]: + if params is None: + return {} + + return _flatten_dict(_transform_params_to_json_serializable(params, convert_list_to_dict=True), sep="/") + + +def _transform_params_to_json_serializable(x, convert_list_to_dict: bool): + _transform = partial(_transform_params_to_json_serializable, convert_list_to_dict=convert_list_to_dict) + + if dataclasses.is_dataclass(x): + return _transform(dataclasses.asdict(x)) + if isinstance(x, dict): + return {k: _transform(v) for k, v in x.items()} + if isinstance(x, list): + if convert_list_to_dict: + return {"list_len": len(x)} | {f"{i}": _transform(v) for i, v in enumerate(x)} + else: + return [_transform(v) for v in x] + if isinstance(x, Path): + return str(x) + if isinstance(x, Enum): + return x.value + + return x + + +def _flatten_dict(raw: dict[str, Any], *, sep: str) -> dict[str, Any]: + import pandas as pd + + ans = pd.json_normalize(raw, sep=sep).to_dict(orient="records")[0] + assert isinstance(ans, dict) + return ans + + +@dataclasses.dataclass +class ValidationGenerationsLogger: + project_name: str = None + experiment_name: str = None + + def log(self, loggers, samples, step): + if "wandb" in loggers: + self.log_generations_to_wandb(samples, step) + if "swanlab" in loggers: + self.log_generations_to_swanlab(samples, step) + if "mlflow" in loggers: + self.log_generations_to_mlflow(samples, step) + + if "clearml" 
in loggers: + self.log_generations_to_clearml(samples, step) + if "tensorboard" in loggers: + self.log_generations_to_tensorboard(samples, step) + + if "vemlp_wandb" in loggers: + self.log_generations_to_vemlp_wandb(samples, step) + + def log_generations_to_vemlp_wandb(self, samples, step): + from volcengine_ml_platform import wandb as vemlp_wandb + + self._log_generations_to_wandb(samples, step, vemlp_wandb) + + def log_generations_to_wandb(self, samples, step): + import wandb + + self._log_generations_to_wandb(samples, step, wandb) + + def _log_generations_to_wandb(self, samples, step, wandb): + """Log samples to wandb as a table""" + + # Create column names for all samples + columns = ["step"] + sum( + [[f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], [] + ) + + if not hasattr(self, "validation_table"): + # Initialize the table on first call + self.validation_table = wandb.Table(columns=columns) + + # Create a new table with same columns and existing data + # Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737 + new_table = wandb.Table(columns=columns, data=self.validation_table.data) + + # Add new row with all data + row_data = [] + row_data.append(step) + for sample in samples: + row_data.extend(sample) + + new_table.add_data(*row_data) + + # Update reference and log + wandb.log({"val/generations": new_table}, step=step) + self.validation_table = new_table + + def log_generations_to_swanlab(self, samples, step): + """Log samples to swanlab as text""" + import swanlab + + swanlab_table = swanlab.echarts.Table() + + # Create column names + headers = ["step", "input", "output", "score"] + + swanlab_row_list = [[step, *sample] for sample in samples] + swanlab_table.add(headers=headers, rows=swanlab_row_list) + + # Log to swanlab + swanlab.log({"val/generations": swanlab_table}, step=step) + + def log_generations_to_mlflow(self, samples, step): + """Log validation generation to mlflow as 
artifacts""" + # https://mlflow.org/docs/latest/api_reference/python_api/mlflow.html?highlight=log_artifact#mlflow.log_artifact + + import json + import tempfile + + import mlflow + + try: + with tempfile.TemporaryDirectory() as tmp_dir: + validation_gen_step_file = Path(tmp_dir, f"val_step{step}.json") + row_data = [] + for sample in samples: + data = {"input": sample[0], "output": sample[1], "score": sample[2]} + row_data.append(data) + with open(validation_gen_step_file, "w") as file: + json.dump(row_data, file) + mlflow.log_artifact(validation_gen_step_file) + except Exception as e: + print(f"WARNING: save validation generation file to mlflow failed with error {e}") + + def log_generations_to_clearml(self, samples, step): + """Log validation generation to clearml as table""" + + import clearml + import pandas as pd + + task: clearml.Task | None = clearml.Task.current_task() + if task is None: + return + + table = [ + { + "step": step, + "input": sample[0], + "output": sample[1], + "score": sample[2], + } + for sample in samples + ] + + logger = task.get_logger() + logger.report_table( + series="Validation generations", + title="Validation", + table_plot=pd.DataFrame.from_records(table), + iteration=step, + ) + + def log_generations_to_tensorboard(self, samples, step): + """Log samples to tensorboard as text""" + # Initialize tensorboard writer if not exists + if not hasattr(self, "writer"): + from torch.utils.tensorboard import SummaryWriter + + # Use the same directory structure as _TensorboardAdapter + if self.project_name and self.experiment_name: + default_dir = os.path.join("tensorboard_log", self.project_name, self.experiment_name) + else: + default_dir = "tensorboard_log" + + tensorboard_dir = os.environ.get("TENSORBOARD_DIR", default_dir) + os.makedirs(tensorboard_dir, exist_ok=True) + self.writer = SummaryWriter(log_dir=tensorboard_dir) + + # Format the samples data into readable text + text_content = f"**Generation Results - Step {step}**\n\n" + + for 
i, sample in enumerate(samples): + text_content += f"### Sample {i + 1}\n" + + # Assuming sample contains [input, output, score] + if len(sample) >= 3: + input_text, output_text, score = sample[0], sample[1], sample[2] + + text_content += f"**Input:** {input_text}\n\n" + text_content += f"**Output:** {output_text}\n\n" + text_content += f"**Score:** {score}\n\n" + else: + # Handle cases where sample format might be different + text_content += f"**Data:** {sample}\n\n" + + text_content += "---\n\n" + + # Log to tensorboard as text + self.writer.add_text("val/generations", text_content, step) + # Flush to ensure data is written + self.writer.flush() diff --git a/recipe/rep_exp/workers/elliptical_reward_model_worker.py b/recipe/rep_exp/workers/elliptical_reward_model_worker.py new file mode 100644 index 00000000000..931779bf8c9 --- /dev/null +++ b/recipe/rep_exp/workers/elliptical_reward_model_worker.py @@ -0,0 +1,389 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""
+FSDP reward-model worker that computes elliptical exploration bonuses (RepExp)
+from the mean hidden states of sampled responses.
+"""
+
+import logging
+import os
+import warnings
+
+import numpy as np
+import torch
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+from verl import DataProto
+from verl.models.transformers.monkey_patch import apply_monkey_patch
+from verl.single_controller.base.decorator import Dispatch, Execute, register
+from verl.utils import hf_tokenizer
+from verl.utils.device import (
+    get_device_id,
+    get_device_name,
+)
+from verl.utils.fs import copy_to_local
+from verl.utils.fsdp_utils import (
+    CPUOffloadPolicy,
+    apply_fsdp2,
+    fsdp2_load_full_state_dict,
+    fsdp_version,
+    get_fsdp_wrap_policy,
+    get_init_weight_context_manager,
+    get_shard_placement_fn,
+    init_fn,
+)
+from verl.utils.profiler import DistProfiler
+from verl.workers.fsdp_workers import RewardModelWorker, get_sharding_strategy
+
+logger = logging.getLogger(__file__)
+logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
+
+device_name = get_device_name()
+
+
+class EllipticalRewardModelWorker(RewardModelWorker):
+    def __init__(self, config):
+        super().__init__(config)
+        self.lamb = config.elliptical.lamb
+        self.normalization = config.elliptical.normalization
+        self.sparse_dim = config.elliptical.sparse_dim
+        self.sparse_matrix = None
+        self.randomize_sparse_matrix = config.elliptical.randomize_sparse_matrix
+        self.persist_covariance = config.elliptical.persist_covariance
+        self.cov_inv_dict = {}
+        self.mean_hidden_states_mu_dict = {}
+        self.hidden_mean_counter_dict = {}
+
+    @staticmethod
+    def _construct_sparse_matrix(features: torch.Tensor, sparse_dim: int) -> torch.Tensor:
+        from sklearn.random_projection import SparseRandomProjection
+
+        sparse_proj = SparseRandomProjection(sparse_dim, density="auto")
+        sparse_proj.fit(features)
+        sparse_matrix = sparse_proj.components_
+        sparse_matrix_coo = sparse_matrix.tocoo()
+
+        # Convert the row and col lists to numpy arrays and then to a LongTensor (faster than converting from lists)
+        indices = 
torch.LongTensor(np.array([sparse_matrix_coo.row, sparse_matrix_coo.col]))
+        values = torch.FloatTensor(sparse_matrix_coo.data)
+
+        sparse_mat = torch.sparse_coo_tensor(indices, values, [sparse_dim, features.shape[1]]).t()
+
+        return sparse_mat
+
+    def _build_model(self, config):
+        # the following imports are needed only at model build time
+        from torch.distributed.fsdp import CPUOffload
+        from transformers import AutoConfig, AutoModel
+
+        use_shm = config.model.get("use_shm", False)
+        # download the checkpoint from hdfs
+        local_path = copy_to_local(config.model.path, use_shm=use_shm)
+
+        if self.config.model.input_tokenizer is None:
+            self._do_switch_chat_template = False
+        else:
+            self._do_switch_chat_template = True
+            input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer, use_shm=use_shm)
+            self.input_tokenizer = hf_tokenizer(
+                input_tokenizer_local_path, trust_remote_code=config.model.get("trust_remote_code", False)
+            )
+        self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get("trust_remote_code", False))
+
+        trust_remote_code = config.model.get("trust_remote_code", False)
+        model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
+        model_config.num_labels = 1
+
+        # note: the reward model is used for inference only (no optimizer state), so it is loaded directly in bf16
+        init_context = get_init_weight_context_manager(
+            use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh
+        )
+
+        with init_context(), warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            model_config.classifier_dropout = 0.0
+            reward_module = AutoModel.from_pretrained(
+                pretrained_model_name_or_path=local_path,
+                config=model_config,
+                torch_dtype=torch.bfloat16,
+                attn_implementation="flash_attention_2",
+                trust_remote_code=trust_remote_code,
+            )
+
+            apply_monkey_patch(
+                model=reward_module,
+                use_remove_padding=config.model.get("use_remove_padding", False),
+                ulysses_sp_size=self.ulysses_sequence_parallel_size,
+            )
+
+            reward_module.to(torch.bfloat16)
+
+        auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config)
+
+        fsdp_mesh = self.device_mesh
+        sharding_strategy = get_sharding_strategy(fsdp_mesh)
+
+        if config.strategy == "fsdp":
+            reward_module = FSDP(
+                reward_module,
+                param_init_fn=init_fn,
+                use_orig_params=False,
+                auto_wrap_policy=auto_wrap_policy,
+                device_id=get_device_id(),
+                sharding_strategy=sharding_strategy,  # zero3
+                sync_module_states=True,
+                cpu_offload=CPUOffload(offload_params=True),
+                forward_prefetch=self.config.model.fsdp_config.forward_prefetch,
+                device_mesh=self.device_mesh,
+            )
+        elif config.strategy == "fsdp2":
+            assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
+            cpu_offload = CPUOffloadPolicy(pin_memory=True)
+            fsdp_kwargs = {
+                "mesh": fsdp_mesh,
+                "offload_policy": cpu_offload,
+                "reshard_after_forward": config.model.fsdp_config.reshard_after_forward,
+                "shard_placement_fn": get_shard_placement_fn(fsdp_size=self.device_mesh.shape[-1]),
+            }
+            full_state = reward_module.state_dict()
+            apply_fsdp2(reward_module, fsdp_kwargs, config.model.fsdp_config)
+            fsdp2_load_full_state_dict(reward_module, full_state, fsdp_mesh, cpu_offload)
+        else:
+            raise 
NotImplementedError(f"Unknown strategy: {config.strategy}")
+
+        return reward_module
+
+    def _forward_micro_batch(self, micro_batch, start_of_response: int):
+        with torch.no_grad(), torch.autocast(device_type=device_name, dtype=torch.bfloat16):
+            input_ids = micro_batch["input_ids"]
+            batch_size, seqlen = input_ids.shape
+            attention_mask = micro_batch["attention_mask"]
+            position_ids = micro_batch["position_ids"]
+            if position_ids.dim() == 3:  # qwen2vl mrope
+                position_ids = position_ids.transpose(0, 1)  # (bsz, 3, seqlen) -> (3, bsz, seqlen)
+
+            if self.use_remove_padding:
+                raise NotImplementedError("Remove padding is not implemented for elliptical reward model")
+            else:
+                output = self.reward_module(
+                    input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False
+                )
+
+                sequence_lengths = attention_mask[:, start_of_response:].sum(dim=1)
+                mean_hidden_states = []
+                for i, seq_len in enumerate(sequence_lengths):
+                    mean_hidden_states.append(
+                        output.last_hidden_state[i, start_of_response : start_of_response + seq_len].mean(dim=0)
+                    )
+                mean_hidden_states = torch.stack(mean_hidden_states)
+
+            return mean_hidden_states
+
+    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
+    @DistProfiler.annotate(color="brown")
+    def compute_hidden_states(self, data: DataProto):
+        import itertools
+
+        from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
+
+        # Support all hardware
+        data = data.to(get_device_id())
+        if self._do_switch_chat_template:
+            rm_data = self._switch_chat_template(data)
+        else:
+            rm_input_ids = data.batch["input_ids"]
+            rm_attention_mask = data.batch["attention_mask"]
+            rm_position_ids = data.batch["position_ids"]
+            rm_inputs = {
+                "input_ids": rm_input_ids,
+                "attention_mask": rm_attention_mask,
+                "position_ids": rm_position_ids,
+            }
+            rm_data = DataProto.from_dict(rm_inputs)
+
+        # Support all hardware
+        rm_data = rm_data.to(get_device_id())
+
+        # perform forward computation
+        with 
self.ulysses_sharding_manager:
+            use_dynamic_bsz = self.config.use_dynamic_bsz
+            if use_dynamic_bsz:
+                max_token_len = self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
+                micro_batches, indices = rearrange_micro_batches(batch=rm_data.batch, max_token_len=max_token_len)
+            else:
+                micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu)
+            output = []
+            for micro_batch in micro_batches:
+                mean_hidden_states = self._forward_micro_batch(
+                    micro_batch, start_of_response=data.batch["prompts"].shape[-1]
+                )
+                output.append(mean_hidden_states)
+            mean_hidden_states = torch.cat(output, dim=0)  # (batch_size, hidden_size)
+
+        # NOTE(Jens): this has not been thoroughly checked
+        if use_dynamic_bsz:
+            indices = list(itertools.chain.from_iterable(indices))
+            assert len(indices) == mean_hidden_states.size(0), f"{len(indices)} vs. {mean_hidden_states.size()}"
+            revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
+            mean_hidden_states = mean_hidden_states[revert_indices]
+
+        # Note: these are only the mean hidden states; the exploration bonuses used as rewards are computed later in compute_rm_score
+        output = DataProto.from_dict(tensors={"mean_hidden_states": mean_hidden_states})
+
+        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
+        # unshard the root FSDP module
+        if self.world_size > 1 and fsdp_version(self.reward_module) == 1:
+            self.reward_module._handle.reshard(True)
+
+        output = output.to("cpu")
+        return output
+
+    def _compute_bonuses(self, hidden_states, cov_inv, prompt_index: int):
+        if self.config.elliptical.reward_type == "leave_one_out":
+            if self.persist_covariance:
+                raise NotImplementedError("Leave-one-out with persistence is not implemented")
+            else:
+                bonuses = []
+                for i, hidden_state in enumerate(hidden_states):
+                    chosen_samp = hidden_state.unsqueeze(1)
+                    # Sherman-Morrison downdate: remove sample i's contribution from the inverse covariance
+                    middle_part = torch.inverse(1 - chosen_samp.t() @ cov_inv @ chosen_samp)
+                    leave_one_out_cov_inv = cov_inv + cov_inv @ chosen_samp @ middle_part @ chosen_samp.t() @ cov_inv
+                    bonus = 
(chosen_samp.t() @ leave_one_out_cov_inv @ chosen_samp).flatten().float() + bonuses.append(bonus) + + bonuses = torch.concat(bonuses) + + elif self.config.elliptical.reward_type == "leverage": + if self.persist_covariance: + hidden_mean = self.mean_hidden_states_mu_dict[prompt_index] + hidden_mean_counter = self.hidden_mean_counter_dict[prompt_index] + + hidden_states = hidden_states - hidden_mean + + numerator = cov_inv @ hidden_mean.unsqueeze(1) @ hidden_mean.unsqueeze(0) @ cov_inv + denominator = -1 / hidden_mean_counter + hidden_mean.t() @ cov_inv @ hidden_mean + cov_inv_mean_adjusted = cov_inv - numerator / denominator + batch_cov_inv = cov_inv_mean_adjusted.unsqueeze(0).expand(hidden_states.shape[0], -1, -1) + else: + batch_cov_inv = cov_inv.unsqueeze(0).expand(hidden_states.shape[0], -1, -1) + + bonuses = (hidden_states.unsqueeze(1) @ batch_cov_inv @ hidden_states.unsqueeze(2)).flatten().float() + + return bonuses + + def _normalize_bonuses(self, bonuses): + if self.normalization == "none": + pass + elif self.normalization == "rnd": + std = torch.std(bonuses) + if std > 0: + bonuses = bonuses / std + elif self.normalization == "z_score": + mean = torch.mean(bonuses) + std = torch.std(bonuses) + if std > 0: + bonuses = (bonuses - mean) / std + else: + bonuses = bonuses - mean + else: + raise ValueError(f"Unknown normalization: {self.normalization}") + + return bonuses + + @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO) + @DistProfiler.annotate(color="brown") + def compute_rm_score(self, data: DataProto): + if self.sparse_matrix is None: + d = data.batch["mean_hidden_states"].shape[-1] + sparse_matrix = self._construct_sparse_matrix(torch.randn(1, d), self.sparse_dim) + if not self.randomize_sparse_matrix: + self.sparse_matrix = sparse_matrix + else: + sparse_matrix = self.sparse_matrix + + mean_hidden_states = data.batch["mean_hidden_states"].to(get_device_id()).float() + + # sparse project + mean_hidden_states = 
mean_hidden_states @ sparse_matrix.to(get_device_id()) + + # upgrade to float64 + mean_hidden_states = mean_hidden_states.to(torch.float64) + + seen_uids = set() + reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32).to(get_device_id()) + raw_bonuses_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32).to(get_device_id()) + for i in range(len(data)): + data_item = data[i] + uid = data_item.non_tensor_batch["uid"] + if uid in seen_uids: + continue + + seen_uids.add(uid) + mask = data.non_tensor_batch["uid"] == uid + filtered_mean_hidden_states = mean_hidden_states[mask] + + prompt_index = data_item.non_tensor_batch["extra_info"]["index"] + + if self.persist_covariance: + # first update the mean hidden states mu + if prompt_index not in self.mean_hidden_states_mu_dict: + self.mean_hidden_states_mu_dict[prompt_index] = filtered_mean_hidden_states.mean(dim=0) + self.hidden_mean_counter_dict[prompt_index] = mask.sum() + else: + total_count = self.hidden_mean_counter_dict[prompt_index] + mask.sum() + old_mu = self.mean_hidden_states_mu_dict[prompt_index] + new_mu = ( + old_mu * self.hidden_mean_counter_dict[prompt_index] + + filtered_mean_hidden_states.mean(dim=0) * mask.sum() + ) / total_count + self.mean_hidden_states_mu_dict[prompt_index] = new_mu + self.hidden_mean_counter_dict[prompt_index] = total_count + + # NOTE: we don't center here since otherwise the covariance will accumulate stale means + final_mean_hidden_states = filtered_mean_hidden_states + + if prompt_index not in self.cov_inv_dict: + d = final_mean_hidden_states.shape[-1] + self.cov_inv_dict[prompt_index] = ( + torch.eye(d, dtype=torch.float64).to(get_device_id()) * self.lamb**-1 + ) + cov_inv = self.cov_inv_dict[prompt_index] + else: + centered_mean_hidden_states = filtered_mean_hidden_states - filtered_mean_hidden_states.mean(dim=0) + final_mean_hidden_states = centered_mean_hidden_states + + d = final_mean_hidden_states.shape[-1] + cov_inv = torch.eye(d, 
dtype=torch.float64).to(get_device_id()) * self.lamb**-1 + + # update inverse covariance matrix with rank-1 updates + for hidden_state in final_mean_hidden_states: + chosen_samp = hidden_state.unsqueeze(1) + middle_part = torch.inverse(1 + chosen_samp.t() @ cov_inv @ chosen_samp) + cov_inv = cov_inv - cov_inv @ chosen_samp @ middle_part @ chosen_samp.t() @ cov_inv + + if self.persist_covariance: + self.cov_inv_dict[prompt_index] = cov_inv + + raw_bonuses = self._compute_bonuses(final_mean_hidden_states, cov_inv, prompt_index) + normalized_bonuses = self._normalize_bonuses(raw_bonuses) + + prompt_ids = data.batch["prompts"][mask] + prompt_length = prompt_ids.shape[-1] + valid_response_lengths = data.batch["attention_mask"][mask, prompt_length:].sum(-1) + + raw_bonuses_tensor[mask, valid_response_lengths - 1] = raw_bonuses + reward_tensor[mask, valid_response_lengths - 1] = normalized_bonuses + + output = DataProto.from_dict( + tensors={"rm_scores": reward_tensor}, non_tensors={"raw_bonuses": raw_bonuses_tensor.cpu().numpy()} + ) + return output.to("cpu")