diff --git a/docs/algo/repexp.md b/docs/algo/repexp.md
new file mode 100644
index 00000000000..456b70760a4
--- /dev/null
+++ b/docs/algo/repexp.md
@@ -0,0 +1,68 @@
+# Recipe: Representation-Based Exploration (RepExp)
+
+Last updated: 11/14/2025.
+
+
+
+Representation-Based Exploration for Language Models: From Test-Time to Post-Training
+
+[📄 arXiv](https://arxiv.org/abs/2510.11686) [🌐 Website](https://rep-exp.github.io) [🐦 Twitter / X](https://x.com/JensTuyls/status/1978244454617128993)
+
+
+
+
+## Installation 🔌
+
+Besides the base verl installation, which you can find [here](https://verl.readthedocs.io/en/latest/start/install.html), the only additional package to install is scikit-learn (`pip install scikit-learn`).
+
+## Running the Experiments 🚀
+
+You can reproduce or extend our experiments by running the following commands:
+
+```bash
+# General format
+sh recipe/rep_exp/train_elliptical.sh $TASK $SPARSE_DIM $BETA $SEED
+
+# MATH
+sh recipe/rep_exp/train_elliptical.sh math 32 0.01 42
+
+# GSM8K
+sh recipe/rep_exp/train_elliptical.sh gsm8k 32 0.01 42
+
+# DAPO-WITH-AIME
+sh recipe/rep_exp/train_elliptical.sh dapo-with-aime24 128 0.01 42
+```
+where `$TASK` is the task name (one of `math`, `gsm8k`, or `dapo-with-aime24`), `$SPARSE_DIM` is the dimension of the sparse random projection used for the elliptical bonus, `$BETA` is the weight of the elliptical exploration bonus, and `$SEED` is the random seed.
+
+## Evaluation 📊
+Once training is done, you can evaluate the model on the test set in two steps.
+1. Merge the model checkpoint.
+
+This is necessary because the model checkpoint is saved in multiple shards (depending on the number of GPUs), and they must be merged into a single checkpoint before evaluation.
+
+```bash
+sh recipe/rep_exp/model_merge.sh /path/to/global_step_X/actor # where X is the global step of the checkpoint with the best pass@1 on dev
+```
+
+2. Evaluate the merged model.
+
+```bash
+sh recipe/rep_exp/eval.sh $TASK /path/to/global_step_X/actor/hf # where X is the global step of the checkpoint with the best pass@1 on dev
+```
+
+The results should appear in a folder named `eval`, saved as a JSON file.
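
If you want to inspect those results programmatically, a few lines of Python suffice. This is only a sketch: the `eval` directory name comes from the step above, but the metric keys inside each JSON file depend on the `json_eval` logger's output, so treat them as placeholders and inspect one file first.

```python
import glob
import json
import os


def load_eval_results(eval_dir="eval"):
    """Collect the metrics from every JSON file in the eval folder.

    The metric keys inside each file depend on the json_eval logger's
    output, so check one file before relying on specific key names.
    """
    results = {}
    for path in sorted(glob.glob(os.path.join(eval_dir, "*.json"))):
        with open(path) as f:
            results[os.path.basename(path)] = json.load(f)
    return results
```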
+
+## Citation 📝
+
+```bibtex
+@article{tuyls2025representation,
+ title={Representation-Based Exploration for Language Models: From Test-Time to Post-Training},
+ author={Tuyls, Jens and Foster, Dylan J and Krishnamurthy, Akshay and Ash, Jordan T},
+ journal={arXiv preprint arXiv:2510.11686},
+ year={2025}
+}
+```
+
+## Contact 📬
+
+If you have any questions or suggestions, feel free to reach out at [jtuyls@princeton.edu](mailto:jtuyls@princeton.edu).
\ No newline at end of file
diff --git a/docs/index.rst b/docs/index.rst
index 7ddb244ac02..e91e748ac0e 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -75,6 +75,7 @@ verl is fast with:
algo/spin.md
algo/sppo.md
algo/entropy.md
+ algo/repexp.md
algo/opo.md
algo/baseline.md
algo/gpg.md
diff --git a/recipe/rep_exp/README.md b/recipe/rep_exp/README.md
new file mode 100644
index 00000000000..0fcfacdc355
--- /dev/null
+++ b/recipe/rep_exp/README.md
@@ -0,0 +1,66 @@
+
+
+# Representation-Based Exploration for Language Models: From Test-Time to Post-Training
+
+[📄 arXiv](https://arxiv.org/abs/2510.11686) [🌐 Website](https://rep-exp.github.io) [🐦 Twitter / X](https://x.com/JensTuyls/status/1978244454617128993)
+
+
+
+## Installation 🔌
+
+Besides the base verl installation, which you can find [here](https://verl.readthedocs.io/en/latest/start/install.html), the only additional package to install is scikit-learn:
+```bash
+pip install scikit-learn
+```
+
+## Running the Experiments 🚀
+
+You can reproduce or extend our experiments by running the following commands:
+
+```bash
+# General format
+sh recipe/rep_exp/train_elliptical.sh $TASK $SPARSE_DIM $BETA $SEED
+
+# MATH
+sh recipe/rep_exp/train_elliptical.sh math 32 0.01 42
+
+# GSM8K
+sh recipe/rep_exp/train_elliptical.sh gsm8k 32 0.01 42
+
+# DAPO-WITH-AIME
+sh recipe/rep_exp/train_elliptical.sh dapo-with-aime24 128 0.01 42
+```
+where `$TASK` is the task name (one of `math`, `gsm8k`, or `dapo-with-aime24`), `$SPARSE_DIM` is the dimension of the sparse random projection used for the elliptical bonus, `$BETA` is the weight of the elliptical exploration bonus, and `$SEED` is the random seed.
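
Sweeping over seeds (or any other argument) is just a shell loop. The sketch below only prints the launch commands; remove the `echo` to actually run them, and adjust the seed list to your compute budget.

```shell
# Prints one launch command per seed; drop the `echo` to launch for real.
for SEED in 0 1 2; do
  echo sh recipe/rep_exp/train_elliptical.sh math 32 0.01 "$SEED"
done
```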
+
+## Evaluation 📊
+Once training is done, you can evaluate the model on the test set in two steps.
+1. Merge the model checkpoint.
+
+This is necessary because the model checkpoint is saved in multiple shards (depending on the number of GPUs), and they must be merged into a single checkpoint before evaluation.
+
+```bash
+sh recipe/rep_exp/model_merge.sh /path/to/global_step_X/actor # where X is the global step of the checkpoint with the best pass@1 on dev
+```
+
+2. Evaluate the merged model.
+
+```bash
+sh recipe/rep_exp/eval.sh $TASK /path/to/global_step_X/actor/hf # where X is the global step of the checkpoint with the best pass@1 on dev
+```
+
+The results should appear in a folder named `eval`, saved as a JSON file.
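
To pick X in the merge step above, i.e. the global step with the best pass@1 on dev, a tiny helper is enough once you have gathered per-step dev metrics. The dictionary structure below is hypothetical; how you build it from your training logs is up to you.

```python
def best_global_step(dev_pass_at_1):
    """Return the global step whose dev pass@1 is highest.

    dev_pass_at_1 maps global step -> pass@1 on the dev split
    (hypothetical structure, assembled from your own logs).
    """
    return max(dev_pass_at_1, key=dev_pass_at_1.get)
```

For example, `best_global_step({100: 0.41, 200: 0.47, 300: 0.45})` returns `200`.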
+
+## Citation 📝
+
+```bibtex
+@article{tuyls2025representation,
+ title={Representation-Based Exploration for Language Models: From Test-Time to Post-Training},
+ author={Tuyls, Jens and Foster, Dylan J and Krishnamurthy, Akshay and Ash, Jordan T},
+ journal={arXiv preprint arXiv:2510.11686},
+ year={2025}
+}
+```
+
+## Contact 📬
+
+If you have any questions or suggestions, feel free to reach out at [jtuyls@princeton.edu](mailto:jtuyls@princeton.edu).
\ No newline at end of file
diff --git a/recipe/rep_exp/config/rep_exp_trainer.yaml b/recipe/rep_exp/config/rep_exp_trainer.yaml
new file mode 100644
index 00000000000..30fe1c84a31
--- /dev/null
+++ b/recipe/rep_exp/config/rep_exp_trainer.yaml
@@ -0,0 +1,33 @@
+hydra:
+ searchpath:
+ - file://verl/trainer/config
+
+defaults:
+ - ppo_trainer
+ - _self_
+
+reward_model:
+ elliptical:
+ enable: True
+ lamb: 0.01
+ normalization: none # none, rnd, z_score
+ reward_type: leave_one_out # leave_one_out, leverage
+ sparse_dim: 512
+ randomize_sparse_matrix: True
+ persist_covariance: False
+
+ reward_kwargs:
+ elliptical:
+ alpha: 1.0
+ beta: 1.0
+ turn_off_elliptical_if_none_correct: True
+ turn_off_elliptical_if_some_correct: False
+ turn_off_elliptical_if_all_correct: False
+ turn_off_elliptical_if_rollout_incorrect: False
+
+actor_rollout_ref:
+ rollout:
+ val_kwargs:
+ temperature: 1.0
+ n: 128
+ do_sample: True
diff --git a/recipe/rep_exp/data_preprocess/dapo_with_aime.py b/recipe/rep_exp/data_preprocess/dapo_with_aime.py
new file mode 100644
index 00000000000..db3dd03d42f
--- /dev/null
+++ b/recipe/rep_exp/data_preprocess/dapo_with_aime.py
@@ -0,0 +1,104 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Preprocess DAPO dataset to parquet format
+"""
+
+import argparse
+import os
+
+import datasets
+import numpy as np
+
+from verl.utils.hdfs_io import copy, makedirs
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--local_dir", default="~/data/dapo-with-aime24")
+ parser.add_argument("--hdfs_dir", default=None)
+ parser.add_argument("--dapo_dataset_path", type=str, default="ftajwar/deduplicated_dapo_dataset")
+ parser.add_argument("--aime24_part_1_dataset_path", type=str, default="MathArena/aime_2024_I")
+ parser.add_argument("--aime24_part_2_dataset_path", type=str, default="MathArena/aime_2024_II")
+ parser.add_argument("--train_size", type=int, default=4096)
+
+ args = parser.parse_args()
+
+ data_source = "math_dapo"
+
+ # Load DAPO dataset for training
+ dapo_dataset_path = args.dapo_dataset_path
+ dapo_dataset = datasets.load_dataset(dapo_dataset_path, trust_remote_code=True)
+
+ # Load AIME 2024 part 1 dataset for testing
+ aime24_dataset_path_part_1 = args.aime24_part_1_dataset_path
+ aime24_dataset_part_1 = datasets.load_dataset(aime24_dataset_path_part_1, trust_remote_code=True)
+
+ # Load AIME 2024 part 2 dataset for testing
+ aime24_dataset_path_part_2 = args.aime24_part_2_dataset_path
+ aime24_dataset_part_2 = datasets.load_dataset(aime24_dataset_path_part_2, trust_remote_code=True)
+
+ train_dataset = dapo_dataset["train"]
+ train_dataset = train_dataset.select(np.random.choice(len(train_dataset), size=args.train_size, replace=False))
+
+ dev_dataset_aime24_part_1 = aime24_dataset_part_1["train"]
+ dev_dataset_aime24_part_2 = aime24_dataset_part_2["train"]
+ dev_dataset = datasets.concatenate_datasets([dev_dataset_aime24_part_1, dev_dataset_aime24_part_2])
+
+ instruction_following = "Let's think step by step and output the final answer within \\boxed{}."
+
+ # add a row to each data item that represents a unique id
+ def make_map_fn(split):
+ def process_fn(example, idx):
+ if "prompt" in example:
+ question = example.pop("prompt")
+ elif "problem" in example:
+ question = example.pop("problem")
+ else:
+ raise ValueError(f"Unknown question type: {example}")
+
+ question = question + " " + instruction_following
+
+ if "answer" in example:
+ solution = example.pop("answer")
+ else:
+ raise ValueError(f"Unknown answer type: {example}")
+ solution = str(solution)
+
+ data = {
+ "data_source": data_source,
+ "prompt": [{"role": "user", "content": question}],
+ "ability": "math",
+ "reward_model": {
+ "style": "rule",
+ "ground_truth": solution,
+ },
+ "extra_info": {"split": split, "index": idx},
+ }
+ return data
+
+ return process_fn
+
+ train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
+ dev_dataset = dev_dataset.map(function=make_map_fn("test"), with_indices=True)
+
+ local_dir = args.local_dir
+ hdfs_dir = args.hdfs_dir
+
+ train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
+ dev_dataset.to_parquet(os.path.join(local_dir, "dev.parquet"))
+
+ if hdfs_dir is not None:
+ makedirs(hdfs_dir)
+
+ copy(src=local_dir, dst=hdfs_dir)
diff --git a/recipe/rep_exp/data_preprocess/gsm8k.py b/recipe/rep_exp/data_preprocess/gsm8k.py
new file mode 100644
index 00000000000..e4d8cf4fc85
--- /dev/null
+++ b/recipe/rep_exp/data_preprocess/gsm8k.py
@@ -0,0 +1,112 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Preprocess the GSM8k dataset to parquet format
+"""
+
+import argparse
+import os
+import re
+
+import datasets
+import numpy as np
+
+from verl.utils.hdfs_io import copy, makedirs
+
+
+def extract_solution(solution_str):
+ solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
+ assert solution is not None
+ final_solution = solution.group(0)
+ final_solution = final_solution.split("#### ")[1].replace(",", "")
+ return final_solution
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--local_dir", default=None, help="The save directory for the preprocessed dataset.")
+ parser.add_argument("--hdfs_dir", default=None)
+ parser.add_argument("--local_dataset_path", default=None, help="The local path to the raw dataset, if it exists.")
+ parser.add_argument(
+ "--local_save_dir", default="~/data/gsm8k", help="The save directory for the preprocessed dataset."
+ )
+
+ args = parser.parse_args()
+ local_dataset_path = args.local_dataset_path
+
+ data_source = "openai/gsm8k"
+
+ if local_dataset_path is not None:
+ dataset = datasets.load_dataset(local_dataset_path, "main")
+ else:
+ dataset = datasets.load_dataset(data_source, "main")
+
+ train_dataset = dataset["train"]
+ test_dataset = dataset["test"]
+
+ instruction_following = 'Let\'s think step by step and output the final answer after "####".'
+
+ # add a row to each data item that represents a unique id
+ def make_map_fn(split):
+ def process_fn(example, idx):
+ question_raw = example.pop("question")
+
+ question = question_raw + " " + instruction_following
+
+ answer_raw = example.pop("answer")
+ solution = extract_solution(answer_raw)
+ data = {
+ "data_source": data_source,
+ "prompt": [
+ {
+ "role": "user",
+ "content": question,
+ }
+ ],
+ "ability": "math",
+ "reward_model": {"style": "rule", "ground_truth": solution},
+ "extra_info": {
+ "split": split,
+ "index": idx,
+ "answer": answer_raw,
+ "question": question_raw,
+ },
+ }
+ return data
+
+ return process_fn
+
+ train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
+ test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)
+ # split test into dev and test by picking random subset of 512 examples
+ all_test_indices = range(len(test_dataset))
+ all_test_indices = list(all_test_indices)
+ np.random.shuffle(all_test_indices)
+ dev_dataset = test_dataset.select(all_test_indices[:512])
+ test_dataset = test_dataset.select(all_test_indices[512:])
+
+ hdfs_dir = args.hdfs_dir
+ local_save_dir = args.local_dir
+ if local_save_dir is not None:
+ print("Warning: Argument 'local_dir' is deprecated. Please use 'local_save_dir' instead.")
+ else:
+ local_save_dir = args.local_save_dir
+
+ train_dataset.to_parquet(os.path.join(local_save_dir, "train.parquet"))
+ test_dataset.to_parquet(os.path.join(local_save_dir, "test.parquet"))
+
+ if hdfs_dir is not None:
+ makedirs(hdfs_dir)
+
+ copy(src=local_save_dir, dst=hdfs_dir)
diff --git a/recipe/rep_exp/data_preprocess/math_dataset.py b/recipe/rep_exp/data_preprocess/math_dataset.py
new file mode 100644
index 00000000000..1ae35ea93f4
--- /dev/null
+++ b/recipe/rep_exp/data_preprocess/math_dataset.py
@@ -0,0 +1,595 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Preprocess the MATH-lighteval dataset to parquet format
+"""
+
+import argparse
+import os
+
+import datasets
+
+from verl.utils.hdfs_io import copy, makedirs
+from verl.utils.reward_score.math import last_boxed_only_string, remove_boxed
+
+# These are the MATH-500 indices
+DEV_INDICES = [
+ 4,
+ 6,
+ 15,
+ 18,
+ 34,
+ 36,
+ 37,
+ 41,
+ 45,
+ 64,
+ 66,
+ 85,
+ 92,
+ 100,
+ 120,
+ 127,
+ 133,
+ 136,
+ 149,
+ 160,
+ 161,
+ 162,
+ 166,
+ 168,
+ 202,
+ 215,
+ 243,
+ 247,
+ 256,
+ 260,
+ 270,
+ 320,
+ 361,
+ 367,
+ 381,
+ 392,
+ 396,
+ 411,
+ 450,
+ 451,
+ 452,
+ 460,
+ 496,
+ 501,
+ 503,
+ 505,
+ 511,
+ 513,
+ 520,
+ 534,
+ 563,
+ 564,
+ 571,
+ 576,
+ 579,
+ 587,
+ 596,
+ 601,
+ 607,
+ 609,
+ 612,
+ 615,
+ 622,
+ 666,
+ 673,
+ 683,
+ 684,
+ 695,
+ 700,
+ 703,
+ 709,
+ 718,
+ 722,
+ 738,
+ 748,
+ 757,
+ 761,
+ 762,
+ 782,
+ 805,
+ 817,
+ 834,
+ 840,
+ 849,
+ 853,
+ 854,
+ 859,
+ 882,
+ 885,
+ 888,
+ 906,
+ 909,
+ 933,
+ 941,
+ 962,
+ 978,
+ 985,
+ 988,
+ 991,
+ 1008,
+ 1033,
+ 1037,
+ 1046,
+ 1048,
+ 1054,
+ 1058,
+ 1067,
+ 1073,
+ 1085,
+ 1088,
+ 1095,
+ 1111,
+ 1119,
+ 1123,
+ 1127,
+ 1128,
+ 1131,
+ 1136,
+ 1144,
+ 1145,
+ 1150,
+ 1172,
+ 1173,
+ 1180,
+ 1188,
+ 1190,
+ 1194,
+ 1196,
+ 1215,
+ 1243,
+ 1250,
+ 1251,
+ 1258,
+ 1262,
+ 1271,
+ 1281,
+ 1285,
+ 1287,
+ 1290,
+ 1302,
+ 1308,
+ 1311,
+ 1312,
+ 1322,
+ 1339,
+ 1359,
+ 1374,
+ 1380,
+ 1402,
+ 1441,
+ 1442,
+ 1449,
+ 1513,
+ 1531,
+ 1540,
+ 1543,
+ 1552,
+ 1555,
+ 1576,
+ 1603,
+ 1612,
+ 1620,
+ 1690,
+ 1710,
+ 1715,
+ 1730,
+ 1764,
+ 1767,
+ 1769,
+ 1788,
+ 1790,
+ 1791,
+ 1801,
+ 1806,
+ 1820,
+ 1842,
+ 1843,
+ 1880,
+ 1890,
+ 1897,
+ 1901,
+ 1905,
+ 1908,
+ 1932,
+ 1935,
+ 1940,
+ 1963,
+ 1967,
+ 1981,
+ 1996,
+ 2001,
+ 2006,
+ 2011,
+ 2041,
+ 2047,
+ 2053,
+ 2057,
+ 2062,
+ 2063,
+ 2078,
+ 2110,
+ 2119,
+ 2120,
+ 2143,
+ 2148,
+ 2150,
+ 2151,
+ 2170,
+ 2186,
+ 2191,
+ 2196,
+ 2199,
+ 2210,
+ 2214,
+ 2215,
+ 2217,
+ 2231,
+ 2236,
+ 2237,
+ 2238,
+ 2246,
+ 2253,
+ 2263,
+ 2264,
+ 2275,
+ 2289,
+ 2294,
+ 2297,
+ 2303,
+ 2311,
+ 2323,
+ 2324,
+ 2325,
+ 2327,
+ 2328,
+ 2334,
+ 2352,
+ 2359,
+ 2360,
+ 2371,
+ 2382,
+ 2384,
+ 2397,
+ 2404,
+ 2409,
+ 2413,
+ 2416,
+ 2473,
+ 2505,
+ 2512,
+ 2515,
+ 2522,
+ 2536,
+ 2539,
+ 2546,
+ 2569,
+ 2571,
+ 2579,
+ 2602,
+ 2607,
+ 2609,
+ 2611,
+ 2622,
+ 2628,
+ 2637,
+ 2647,
+ 2681,
+ 2682,
+ 2700,
+ 2707,
+ 2731,
+ 2752,
+ 2758,
+ 2767,
+ 2799,
+ 2802,
+ 2808,
+ 2816,
+ 2838,
+ 2851,
+ 2863,
+ 2868,
+ 2876,
+ 2883,
+ 2896,
+ 2907,
+ 2937,
+ 2938,
+ 2946,
+ 2966,
+ 2977,
+ 2991,
+ 2994,
+ 3018,
+ 3019,
+ 3020,
+ 3022,
+ 3024,
+ 3035,
+ 3037,
+ 3046,
+ 3047,
+ 3058,
+ 3067,
+ 3072,
+ 3079,
+ 3080,
+ 3105,
+ 3126,
+ 3134,
+ 3141,
+ 3165,
+ 3181,
+ 3186,
+ 3187,
+ 3196,
+ 3200,
+ 3210,
+ 3220,
+ 3226,
+ 3236,
+ 3240,
+ 3246,
+ 3287,
+ 3295,
+ 3299,
+ 3317,
+ 3320,
+ 3323,
+ 3334,
+ 3341,
+ 3342,
+ 3344,
+ 3350,
+ 3352,
+ 3365,
+ 3366,
+ 3369,
+ 3375,
+ 3392,
+ 3404,
+ 3411,
+ 3417,
+ 3419,
+ 3420,
+ 3440,
+ 3444,
+ 3447,
+ 3460,
+ 3467,
+ 3474,
+ 3480,
+ 3498,
+ 3507,
+ 3511,
+ 3519,
+ 3529,
+ 3539,
+ 3541,
+ 3548,
+ 3549,
+ 3569,
+ 3586,
+ 3604,
+ 3607,
+ 3646,
+ 3647,
+ 3658,
+ 3669,
+ 3700,
+ 3711,
+ 3725,
+ 3730,
+ 3732,
+ 3738,
+ 3740,
+ 3741,
+ 3752,
+ 3768,
+ 3769,
+ 3773,
+ 3779,
+ 3802,
+ 3805,
+ 3824,
+ 3849,
+ 3856,
+ 3878,
+ 3913,
+ 3923,
+ 3941,
+ 3942,
+ 3951,
+ 3982,
+ 3990,
+ 3994,
+ 3999,
+ 4011,
+ 4034,
+ 4036,
+ 4042,
+ 4043,
+ 4046,
+ 4055,
+ 4071,
+ 4074,
+ 4088,
+ 4090,
+ 4104,
+ 4108,
+ 4127,
+ 4149,
+ 4150,
+ 4155,
+ 4157,
+ 4158,
+ 4160,
+ 4177,
+ 4181,
+ 4190,
+ 4193,
+ 4210,
+ 4222,
+ 4235,
+ 4242,
+ 4253,
+ 4265,
+ 4272,
+ 4279,
+ 4297,
+ 4303,
+ 4315,
+ 4326,
+ 4333,
+ 4352,
+ 4368,
+ 4384,
+ 4404,
+ 4413,
+ 4423,
+ 4425,
+ 4441,
+ 4449,
+ 4451,
+ 4479,
+ 4487,
+ 4500,
+ 4515,
+ 4523,
+ 4533,
+ 4535,
+ 4547,
+ 4549,
+ 4550,
+ 4569,
+ 4584,
+ 4590,
+ 4591,
+ 4597,
+ 4600,
+ 4603,
+ 4610,
+ 4626,
+ 4657,
+ 4666,
+ 4678,
+ 4697,
+ 4706,
+ 4713,
+ 4731,
+ 4744,
+ 4751,
+ 4753,
+ 4758,
+ 4765,
+ 4776,
+ 4796,
+ 4812,
+ 4834,
+ 4850,
+ 4857,
+ 4861,
+ 4866,
+ 4868,
+ 4871,
+ 4885,
+ 4896,
+ 4900,
+ 4909,
+ 4914,
+ 4924,
+ 4926,
+ 4947,
+ 4955,
+ 4964,
+ 4969,
+ 4978,
+ 4990,
+ 4992,
+ 4993,
+]
+
+
+def extract_solution(solution_str):
+ return remove_boxed(last_boxed_only_string(solution_str))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--local_dir", default="~/data/math")
+ parser.add_argument("--hdfs_dir", default=None)
+
+ args = parser.parse_args()
+
+ # 'lighteval/MATH' is no longer available on huggingface.
+ # Use mirror repo: DigitalLearningGmbH/MATH-lighteval
+ data_source = "DigitalLearningGmbH/MATH-lighteval"
+ print(f"Loading the {data_source} dataset from huggingface...", flush=True)
+ dataset = datasets.load_dataset(data_source, trust_remote_code=True)
+
+ train_dataset = dataset["train"]
+ test_dataset = dataset["test"]
+
+ instruction_following = "Let's think step by step and output the final answer within \\boxed{}."
+
+ # add a row to each data item that represents a unique id
+ def make_map_fn(split):
+ def process_fn(example, idx):
+ question = example.pop("problem")
+
+ question = question + " " + instruction_following
+
+ answer = example.pop("solution")
+ solution = extract_solution(answer)
+ data = {
+ "data_source": data_source,
+ "prompt": [{"role": "user", "content": question}],
+ "ability": "math",
+ "reward_model": {"style": "rule", "ground_truth": solution},
+ "extra_info": {"split": split, "index": idx},
+ }
+ return data
+
+ return process_fn
+
+ train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
+ test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)
+
+ # Split test into dev and test
+ dev_indices_set = set(DEV_INDICES)
+ dev_dataset = test_dataset.select(DEV_INDICES)
+
+ def filter_dev_indices(example, idx):
+ return idx not in dev_indices_set
+
+ test_dataset = test_dataset.filter(filter_dev_indices, with_indices=True)
+
+ local_dir = args.local_dir
+ hdfs_dir = args.hdfs_dir
+
+ train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
+ dev_dataset.to_parquet(os.path.join(local_dir, "dev.parquet"))
+ test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))
+
+ if hdfs_dir is not None:
+ makedirs(hdfs_dir)
+
+ copy(src=local_dir, dst=hdfs_dir)
diff --git a/recipe/rep_exp/eval.sh b/recipe/rep_exp/eval.sh
new file mode 100644
index 00000000000..cb667bcd698
--- /dev/null
+++ b/recipe/rep_exp/eval.sh
@@ -0,0 +1,83 @@
+TASK=${1} # math, gsm8k, dapo-with-aime24
+
+# Custom model path for evaluation after training
+MODEL_PATH=${2} # /path/to/global_step_X/actor/hf, where X is the global step of the checkpoint with the best pass@1 on dev
+
+# If you want to evaluate the base model before training
+# MODEL_PATH=Qwen/Qwen2.5-7B-Instruct
+
+train_path=$HOME/data/${TASK}/train.parquet
+train_files="['$train_path']"
+CHECKPOINT_SAVE_CONTENTS='["model"]'
+
+if [ ${TASK} == "dapo-with-aime24" ]; then
+ MAX_PROMPT_LENGTH=$((1024 * 2))
+ MAX_RESPONSE_LENGTH=$((1024 * 8))
+ MAX_NUM_BATCHED_TOKENS=$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH))
+ test_path=$HOME/data/${TASK}/dev.parquet
+else
+ MAX_PROMPT_LENGTH=1024
+ MAX_RESPONSE_LENGTH=1024
+ MAX_NUM_BATCHED_TOKENS=8192
+ test_path=$HOME/data/${TASK}/test.parquet
+fi
+
+test_files="['$test_path']"
+
+# If you're on a cluster with no internet access, set OFFLINE=True
+OFFLINE=False
+
+PYTHONUNBUFFERED=1 WANDB_MODE=disabled TRANSFORMERS_OFFLINE=${OFFLINE} python3 -u -m recipe.rep_exp.main_rep_exp \
+ algorithm.adv_estimator=grpo \
+ data.train_files="$train_files" \
+ data.val_files="$test_files" \
+ data.train_batch_size=1024 \
+ data.max_prompt_length=$MAX_PROMPT_LENGTH \
+ data.max_response_length=$MAX_RESPONSE_LENGTH \
+ data.filter_overlong_prompts=True \
+ data.truncation='error' \
+ data.val_batch_size=128 \
+ actor_rollout_ref.model.path="$MODEL_PATH" \
+ actor_rollout_ref.actor.checkpoint.save_contents=$CHECKPOINT_SAVE_CONTENTS \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.actor.ppo_mini_batch_size=256 \
+ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \
+ actor_rollout_ref.actor.kl_loss_coef=0.0 \
+ actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.ppo_epochs=1 \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.fsdp_config.param_offload=False \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+ actor_rollout_ref.rollout.mode=sync \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+ actor_rollout_ref.rollout.max_num_batched_tokens=$MAX_NUM_BATCHED_TOKENS \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.95 \
+ actor_rollout_ref.rollout.val_kwargs.n=256 \
+ actor_rollout_ref.rollout.n=8 \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
+ actor_rollout_ref.ref.fsdp_config.param_offload=True \
+ reward_model.model.path="$MODEL_PATH" \
+ reward_model.model.use_remove_padding=False \
+ reward_model.model.fsdp_config.param_offload=True \
+ reward_model.micro_batch_size_per_gpu=32 \
+ reward_model.model.input_tokenizer=null \
+ actor_rollout_ref.actor.use_kl_loss=False \
+ algorithm.use_kl_in_reward=False \
+ trainer.critic_warmup=0 \
+ trainer.logger='["console","json_eval"]' \
+ trainer.project_name='rep-exp' \
+ trainer.experiment_name="${TASK}_eval" \
+ trainer.n_gpus_per_node=1 \
+ trainer.nnodes=1 \
+ trainer.save_freq=-1 \
+ trainer.test_freq=1 \
+ trainer.total_epochs=100 \
+ trainer.val_only=True \
+ trainer.resume_mode=disable \
+ trainer.resume_from_path=''
+
+exit 0
diff --git a/recipe/rep_exp/main_rep_exp.py b/recipe/rep_exp/main_rep_exp.py
new file mode 100644
index 00000000000..ad7068f12c1
--- /dev/null
+++ b/recipe/rep_exp/main_rep_exp.py
@@ -0,0 +1,483 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Note that we don't combine this main entry point with ray_trainer, since ray_trainer is also used by other mains.
+"""
+
+import os
+import socket
+import warnings
+
+import hydra
+import ray
+from omegaconf import OmegaConf
+
+from verl.experimental.dataset.sampler import AbstractSampler
+from verl.trainer.constants_ppo import get_ppo_ray_runtime_env
+from verl.trainer.ppo.reward import load_reward_manager
+from verl.trainer.ppo.utils import need_critic, need_reference_policy
+from verl.utils.config import validate_config
+from verl.utils.device import is_cuda_available
+from verl.utils.import_utils import load_extern_type
+
+from .rep_exp_trainer import RayRepExpTrainer
+
+
+@hydra.main(config_path="config", config_name="rep_exp_trainer", version_base=None)
+def main(config):
+ """Main entry point for PPO training with Hydra configuration management.
+
+ Args:
+        config: Hydra configuration dictionary containing training parameters.
+ """
+ run_ppo(config)
+
+
+# Define a function to run the PPO-like training process
+def run_ppo(config, task_runner_class=None) -> None:
+ """Initialize Ray cluster and run distributed PPO training process.
+
+ Args:
+ config: Training configuration object containing all necessary parameters
+ for distributed PPO training including Ray initialization settings,
+ model paths, and training hyperparameters.
+ task_runner_class: For recipe to change TaskRunner.
+ """
+ # Check if Ray is not initialized
+ if not ray.is_initialized():
+ # Initialize Ray with a local cluster configuration
+ # Set environment variables in the runtime environment to control tokenizer parallelism,
+ # NCCL debug level, VLLM logging level, and allow runtime LoRA updating
+ # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration
+ default_runtime_env = get_ppo_ray_runtime_env()
+ ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
+ runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
+
+ if config.transfer_queue.enable:
+ # Add runtime environment variables for transfer queue
+ runtime_env_vars = runtime_env_kwargs.get("env_vars", {})
+ runtime_env_vars["TRANSFER_QUEUE_ENABLE"] = "1"
+ runtime_env_kwargs["env_vars"] = runtime_env_vars
+
+ runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
+ ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
+ print(f"ray init kwargs: {ray_init_kwargs}")
+ ray.init(**OmegaConf.to_container(ray_init_kwargs))
+
+ if task_runner_class is None:
+ task_runner_class = ray.remote(num_cpus=1)(TaskRunner) # please make sure main_task is not scheduled on head
+
+ # Create a remote instance of the TaskRunner class, and
+ # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete
+ if (
+ is_cuda_available
+ and config.global_profiler.tool == "nsys"
+ and config.global_profiler.get("steps") is not None
+ and len(config.global_profiler.get("steps", [])) > 0
+ ):
+ from verl.utils.import_utils import is_nvtx_available
+
+ assert is_nvtx_available(), "nvtx is not available in CUDA platform. Please 'pip3 install nvtx'"
+ nsight_options = OmegaConf.to_container(
+ config.global_profiler.global_tool_config.nsys.controller_nsight_options
+ )
+ runner = task_runner_class.options(runtime_env={"nsight": nsight_options}).remote()
+ else:
+ runner = task_runner_class.remote()
+ ray.get(runner.run.remote(config))
+
+ # [Optional] get the path of the timeline trace file from the configuration, default to None
+ # This file is used for performance analysis
+ timeline_json_file = config.ray_kwargs.get("timeline_json_file", None)
+ if timeline_json_file:
+ ray.timeline(filename=timeline_json_file)
+
+
+class TaskRunner:
+ """Ray remote class for executing distributed PPO training tasks.
+
+ This class encapsulates the main training logic and runs as a Ray remote actor
+ to enable distributed execution across multiple nodes and GPUs.
+
+ Attributes:
+ role_worker_mapping: Dictionary mapping Role enums to Ray remote worker classes
+ mapping: Dictionary mapping Role enums to resource pool IDs for GPU allocation
+ """
+
+ def __init__(self):
+ self.role_worker_mapping = {}
+ self.mapping = {}
+
+ def add_actor_rollout_worker(self, config):
+ """Add actor rollout worker based on the actor strategy."""
+ from verl.single_controller.ray import RayWorkerGroup
+ from verl.trainer.ppo.ray_trainer import Role
+
+ use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
+
+ # use new model engine implementation
+ if use_legacy_worker_impl == "disable":
+ from verl.workers.engine_workers import ActorRolloutRefWorker
+
+ actor_rollout_cls = ActorRolloutRefWorker
+ ray_worker_group_cls = RayWorkerGroup
+ # NOTE: In new model engine, ref policy and actor rollout are in same ActorRolloutRefWorker,
+ # while in legacy model engine, ref policy is in a separate ActorRolloutRefWorker.
+ if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
+ role = Role.ActorRolloutRef
+ else:
+ role = Role.ActorRollout
+ self.role_worker_mapping[role] = ray.remote(actor_rollout_cls)
+ self.mapping[role] = "global_pool"
+ return actor_rollout_cls, ray_worker_group_cls
+
+ if config.actor_rollout_ref.rollout.mode == "sync":
+ warnings.warn("spmd rollout mode is deprecated and will be removed in v0.6.2", stacklevel=2)
+
+ if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
+ from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
+
+ actor_rollout_cls = (
+ AsyncActorRolloutRefWorker
+ if config.actor_rollout_ref.rollout.mode == "async"
+ else ActorRolloutRefWorker
+ )
+ ray_worker_group_cls = RayWorkerGroup
+
+ elif config.actor_rollout_ref.actor.strategy == "megatron":
+ from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
+
+ actor_rollout_cls = (
+ AsyncActorRolloutRefWorker
+ if config.actor_rollout_ref.rollout.mode == "async"
+ else ActorRolloutRefWorker
+ )
+ ray_worker_group_cls = RayWorkerGroup
+
+ else:
+ raise NotImplementedError
+
+ self.role_worker_mapping[Role.ActorRollout] = ray.remote(actor_rollout_cls)
+ self.mapping[Role.ActorRollout] = "global_pool"
+ return actor_rollout_cls, ray_worker_group_cls
+
+ def add_critic_worker(self, config):
+ """Add critic worker to role mapping."""
+ if config.critic.strategy in {"fsdp", "fsdp2"}:
+ use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
+ if use_legacy_worker_impl in ["auto", "enable"]:
+ from verl.workers.fsdp_workers import CriticWorker
+ elif use_legacy_worker_impl == "disable":
+ from verl.workers.roles import CriticWorker
+
+ print("Using new worker implementation")
+ else:
+ raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")
+
+ elif config.critic.strategy == "megatron":
+ from verl.workers.megatron_workers import CriticWorker
+
+ else:
+ raise NotImplementedError(f"Unsupported critic strategy: {config.critic.strategy}")
+
+ from verl.trainer.ppo.ray_trainer import Role
+
+ self.role_worker_mapping[Role.Critic] = ray.remote(CriticWorker)
+ self.mapping[Role.Critic] = "global_pool"
+
+ def init_resource_pool_mgr(self, config):
+ """Initialize resource pool manager."""
+
+ global_pool_id = "global_pool"
+ resource_pool_spec = {
+ global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
+ }
+ # TODO: use the new registration method here to support dynamic registration of roles
+ if config.reward_model.enable_resource_pool:
+ if config.reward_model.n_gpus_per_node <= 0:
+ raise ValueError("config.reward_model.n_gpus_per_node must be greater than 0")
+ if config.reward_model.nnodes <= 0:
+ raise ValueError("config.reward_model.nnodes must be greater than 0")
+
+ reward_pool = [config.reward_model.n_gpus_per_node] * config.reward_model.nnodes
+ resource_pool_spec["reward_pool"] = reward_pool
+
+ from verl.trainer.ppo.ray_trainer import ResourcePoolManager
+
+ resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=self.mapping)
+ return resource_pool_manager
+
+ def add_reward_model_worker(self, config):
+ """Add reward model worker if enabled."""
+ from verl.trainer.ppo.ray_trainer import Role
+
+ if config.reward_model.enable:
+ use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
+ if use_legacy_worker_impl in ["auto", "enable"]:
+ if config.reward_model.strategy in {"fsdp", "fsdp2"}:
+ if config.reward_model.elliptical:
+ from .workers.elliptical_reward_model_worker import (
+ EllipticalRewardModelWorker as RewardModelWorker,
+ )
+ else:
+ from verl.workers.fsdp_workers import RewardModelWorker
+ elif config.reward_model.strategy == "megatron":
+ from verl.workers.megatron_workers import RewardModelWorker
+ else:
+ raise NotImplementedError(f"Unsupported reward model strategy: {config.reward_model.strategy}")
+ elif use_legacy_worker_impl == "disable":
+ from verl.workers.roles import RewardModelWorker
+
+ print("Using new worker implementation")
+ else:
+ raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")
+
+ self.role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
+ if config.reward_model.enable_resource_pool:
+ self.mapping[Role.RewardModel] = "reward_pool"
+ else:
+ self.mapping[Role.RewardModel] = "global_pool"
+
+ def add_ref_policy_worker(self, config, ref_policy_cls):
+ """Add reference policy worker if KL loss or KL reward is used."""
+ from verl.trainer.ppo.ray_trainer import Role
+
+ # The ref policy has been fused into ActorRolloutRefWorker in the new model engine,
+ # so we don't need to add a separate ref policy worker group.
+ use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
+ if use_legacy_worker_impl == "disable":
+ return
+
+ if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
+ self.role_worker_mapping[Role.RefPolicy] = ray.remote(ref_policy_cls)
+ self.mapping[Role.RefPolicy] = "global_pool"
+
+ def run(self, config):
+ """Execute the main PPO training workflow.
+
+ This method sets up the distributed training environment, initializes
+ workers, datasets, and reward functions, then starts the training process.
+
+ Args:
+ config: Training configuration object containing all parameters needed
+ for setting up and running the PPO training process.
+ """
+ # Print the initial configuration. `resolve=True` will evaluate symbolic values.
+ from pprint import pprint
+
+ from omegaconf import OmegaConf
+
+ from verl.utils.fs import copy_to_local
+
+ print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
+ pprint(OmegaConf.to_container(config, resolve=True))
+ OmegaConf.resolve(config)
+
+ actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config)
+ self.add_critic_worker(config)
+
+ # We adopt a multi-source reward function here:
+ # - for a rule-based RM, we directly call a reward score function
+ # - for a model-based RM, we call a model
+ # - for code-related prompts, we send them to a sandbox if test cases are available
+ # Finally, we combine all the rewards together; the reward type depends on the tag of the data.
+ self.add_reward_model_worker(config)
+
+ # Add a reference policy worker if KL loss or KL reward is used.
+ self.add_ref_policy_worker(config, actor_rollout_cls)
+
+ # validate config
+ validate_config(
+ config=config,
+ use_reference_policy=need_reference_policy(self.role_worker_mapping),
+ use_critic=need_critic(config),
+ )
+
+ # Download the checkpoint from HDFS to the local machine.
+ # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on
+ local_path = copy_to_local(
+ config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)
+ )
+
+ # Instantiate the tokenizer and processor.
+ from verl.utils import hf_processor, hf_tokenizer
+
+ trust_remote_code = config.data.get("trust_remote_code", False)
+ tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
+ # Used for multimodal LLM, could be None
+ processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
+
+ # Make sure the elliptical reward manager is registered
+ from .reward_manager.elliptical_reward_manager import EllipticalRewardManager # noqa: F401
+
+ # Load the reward manager for training and validation.
+ reward_manager_name = config.reward_model.get("reward_manager", "naive")
+ reward_fn = load_reward_manager(
+ config,
+ tokenizer,
+ num_examine=0,
+ **config.reward_model.get("reward_kwargs", {}).get(reward_manager_name, {}),
+ )
+ val_reward_fn = load_reward_manager(
+ config,
+ tokenizer,
+ num_examine=1,
+ **config.reward_model.get("reward_kwargs", {}).get(reward_manager_name, {}),
+ )
+
+ resource_pool_manager = self.init_resource_pool_mgr(config)
+
+ from verl.utils.dataset.rl_dataset import collate_fn
+
+ # Create training and validation datasets.
+ train_dataset = create_rl_dataset(
+ config.data.train_files,
+ config.data,
+ tokenizer,
+ processor,
+ is_train=True,
+ max_samples=config.data.get("train_max_samples", -1),
+ )
+ val_dataset = create_rl_dataset(
+ config.data.val_files,
+ config.data,
+ tokenizer,
+ processor,
+ is_train=False,
+ max_samples=config.data.get("val_max_samples", -1),
+ )
+ train_sampler = create_rl_sampler(config.data, train_dataset)
+
+ # Initialize the PPO trainer.
+ trainer = RayRepExpTrainer(
+ config=config,
+ tokenizer=tokenizer,
+ processor=processor,
+ role_worker_mapping=self.role_worker_mapping,
+ resource_pool_manager=resource_pool_manager,
+ ray_worker_group_cls=ray_worker_group_cls,
+ reward_fn=reward_fn,
+ val_reward_fn=val_reward_fn,
+ train_dataset=train_dataset,
+ val_dataset=val_dataset,
+ collate_fn=collate_fn,
+ train_sampler=train_sampler,
+ )
+ # Initialize the workers of the trainer.
+ trainer.init_workers()
+
+ # Start the training process.
+ trainer.fit()
+
+
+def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=True, max_samples: int = -1):
+ """Create a dataset.
+
+ Arguments:
+ data_paths: List of paths to data files.
+ data_config: The data config.
+ tokenizer (Tokenizer): The tokenizer.
+ processor (Processor): The processor.
+ is_train (bool): Whether this is the training split. Defaults to True.
+ max_samples (int): Maximum number of samples to load; -1 loads all. Defaults to -1.
+
+ Returns:
+ dataset (Dataset): The dataset.
+ """
+ from torch.utils.data import Dataset
+
+ from verl.utils.dataset.rl_dataset import RLHFDataset
+
+ # Check if a custom dataset class is specified in the data configuration
+ # and if the path to the custom class is provided
+ if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None:
+ # Dynamically load the custom dataset class
+ dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)
+ # Verify that the custom dataset class inherits from torch.utils.data.Dataset
+ if not issubclass(dataset_cls, Dataset):
+ raise TypeError(
+ f"The custom dataset class '{data_config.custom_cls.name}' from "
+ f"'{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset"
+ )
+ elif "datagen" in data_config and data_config.datagen.get("path", None) is not None and is_train:
+ # If a data generation strategy is specified, use the DynamicGenDataset class
+ from verl.utils.dataset.dynamicgen_dataset import DynamicGenDataset
+
+ dataset_cls = DynamicGenDataset
+ print("Using DynamicGenDataset for data generation.")
+ else:
+ # Use the default RLHFDataset class if no custom class is specified
+ dataset_cls = RLHFDataset
+ print(f"Using dataset class: {dataset_cls.__name__}")
+
+ # Instantiate the dataset using the determined dataset class
+ dataset = dataset_cls(
+ data_files=data_paths,
+ tokenizer=tokenizer,
+ processor=processor,
+ config=data_config,
+ max_samples=max_samples,
+ )
+
+ return dataset
+
+
+def create_rl_sampler(data_config, dataset):
+ """Create a sampler for the dataset.
+
+ Arguments:
+ data_config: The data config.
+ dataset (Dataset): The dataset.
+
+ Returns:
+ sampler (Sampler): The sampler.
+ """
+ import torch
+ from torch.utils.data import SequentialSampler
+
+ # torch.utils.data.RandomSampler cannot restore its state when resuming from a checkpoint,
+ # so we use the stateful RandomSampler from torchdata instead.
+ from torchdata.stateful_dataloader.sampler import RandomSampler
+
+ if data_config.sampler is not None and data_config.sampler.get("class_path", None) is not None:
+ curriculum_class = load_extern_type(
+ data_config.sampler.class_path,
+ data_config.sampler.class_name,
+ )
+ sampler = curriculum_class(
+ data_source=dataset,
+ data_config=data_config,
+ )
+ assert isinstance(sampler, AbstractSampler)
+ assert data_config.get("dataloader_num_workers", 8) == 0, (
+ "If using curriculum, num_workers must be 0 to prevent data caching. "
+ "If the dataloader caches data before the batch is done the "
+ "curriculum sampler won't have the opportunity to reorder it. "
+ )
+
+ # Use a sampler to facilitate checkpoint resumption.
+ # If shuffling is enabled in the data configuration, create a random sampler.
+ elif data_config.shuffle:
+ train_dataloader_generator = torch.Generator()
+ seed = data_config.get("seed")
+ if seed is not None:
+ train_dataloader_generator.manual_seed(seed)
+ sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator)
+ else:
+ # If shuffling is disabled, use a sequential sampler to iterate through the dataset in order.
+ sampler = SequentialSampler(data_source=dataset)
+
+ return sampler
+
+
+if __name__ == "__main__":
+ main()
diff --git a/recipe/rep_exp/metric_utils.py b/recipe/rep_exp/metric_utils.py
new file mode 100644
index 00000000000..519a5b6c0f8
--- /dev/null
+++ b/recipe/rep_exp/metric_utils.py
@@ -0,0 +1,382 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Metrics related to the RepExp trainer.
+"""
+
+from collections import defaultdict
+from functools import partial
+from typing import Any
+
+import numpy as np
+import torch
+
+from verl import DataProto
+from verl.trainer.ppo.metric_utils import _compute_response_info, bootstrap_metric, calc_maj_val
+
+
+def _compute_three_case_stats(data: DataProto, extrinsic_reward_tensor: torch.Tensor) -> dict:
+ """
+ Compute the fraction of prompts (uids) for which no rollouts are correct, some are correct, and all are correct.
+
+ Args:
+ data (DataProto): The data proto containing the batch data.
+ extrinsic_reward_tensor (torch.Tensor): The extrinsic reward tensor.
+
+ Returns:
+ dict[str, float]: A dictionary containing the fraction of prompts for which no rollouts are correct,
+ some are correct, and all are correct.
+ """
+ no_rollouts_correct = 0
+ some_rollouts_correct = 0
+ all_rollouts_correct = 0
+
+ visited_uids = set()
+ for uid in data.non_tensor_batch["uid"]:
+ if uid in visited_uids:
+ continue
+
+ visited_uids.add(uid)
+ mask = torch.from_numpy(data.non_tensor_batch["uid"] == uid)
+
+ # Split into three cases
+ if extrinsic_reward_tensor[mask].sum() == 0:
+ no_rollouts_correct += 1
+ elif extrinsic_reward_tensor[mask].sum() == mask.sum():
+ all_rollouts_correct += 1
+ elif extrinsic_reward_tensor[mask].sum() > 0 and extrinsic_reward_tensor[mask].sum() < mask.sum():
+ some_rollouts_correct += 1
+ else:
+ raise ValueError(f"Invalid extrinsic reward tensor: {extrinsic_reward_tensor[mask].sum()}")
+
+ # Sanity checks
+ assert len(visited_uids) == no_rollouts_correct + some_rollouts_correct + all_rollouts_correct
+
+ return {
+ "no_rollouts_correct_frac": no_rollouts_correct / len(visited_uids),
+ "some_rollouts_correct_frac": some_rollouts_correct / len(visited_uids),
+ "all_rollouts_correct_frac": all_rollouts_correct / len(visited_uids),
+ }
+
+
+def compute_data_metrics(batch: DataProto, use_critic: bool = True, elliptical: bool = False) -> dict[str, Any]:
+ """
+ Computes various metrics from a batch of data for PPO training.
+
+ This function calculates metrics related to scores, rewards, advantages, returns, values,
+ and sequence lengths from a batch of data. It provides statistical information (mean, max, min)
+ for each metric category.
+
+ Args:
+ batch: A DataProto object containing batch data with token-level scores, rewards, advantages, etc.
+ use_critic: Whether to include critic-specific metrics. Defaults to True.
+ elliptical: Whether to include elliptical-specific metrics. Defaults to False.
+
+ Returns:
+ A dictionary of metrics including:
+ - critic/score/mean, max, min: Statistics about sequence scores
+ - critic/rewards/mean, max, min: Statistics about sequence rewards
+ - critic/advantages/mean, max, min: Statistics about advantages
+ - critic/returns/mean, max, min: Statistics about returns
+ - critic/values/mean, max, min: Statistics about critic values (if use_critic=True)
+ - critic/vf_explained_var: Explained variance of the value function (if use_critic=True)
+ - response_length/mean, max, min, clip_ratio: Statistics about response lengths
+ - prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths
+ - num_turns/mean, max, min: Statistics about the number of multi-turn conversations
+ """
+ sequence_score = batch.batch["token_level_scores"].sum(-1)
+ sequence_reward = batch.batch["token_level_rewards"].sum(-1)
+
+ if elliptical:
+ sequence_intrinsic_reward = batch.non_tensor_batch["intrinsic_reward"].sum(-1)
+ sequence_beta_scaled_intrinsic_reward = batch.non_tensor_batch["beta_scaled_intrinsic_reward"].sum(-1)
+ sequence_extrinsic_reward = batch.non_tensor_batch["extrinsic_reward"].sum(-1)
+ sequence_total_reward = batch.non_tensor_batch["total_reward"].sum(-1)
+ sequence_raw_bonuses = batch.non_tensor_batch["raw_bonuses"].sum(-1)
+
+ three_case_stats = _compute_three_case_stats(batch, batch.non_tensor_batch["extrinsic_reward"])
+
+ advantages = batch.batch["advantages"]
+ returns = batch.batch["returns"]
+
+ max_response_length = batch.batch["responses"].shape[-1]
+
+ prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool()
+ response_mask = batch.batch["response_mask"].bool()
+
+ max_prompt_length = prompt_mask.size(-1)
+
+ response_info = _compute_response_info(batch)
+ prompt_length = response_info["prompt_length"]
+ response_length = response_info["response_length"]
+
+ aborted_mask = (response_length == 0).bool()
+ non_aborted_mask = ~aborted_mask
+
+ non_aborted_sequence_score = sequence_score[non_aborted_mask]
+ non_aborted_sequence_reward = sequence_reward[non_aborted_mask]
+
+ score_mean = torch.mean(non_aborted_sequence_score).detach().item()
+ score_max = torch.max(non_aborted_sequence_score).detach().item()
+ score_min = torch.min(non_aborted_sequence_score).detach().item()
+
+ reward_mean = torch.mean(non_aborted_sequence_reward).detach().item()
+ reward_max = torch.max(non_aborted_sequence_reward).detach().item()
+ reward_min = torch.min(non_aborted_sequence_reward).detach().item()
+
+ valid_adv = torch.masked_select(advantages, response_mask)
+ valid_returns = torch.masked_select(returns, response_mask)
+
+ if use_critic:
+ values = batch.batch["values"]
+ valid_values = torch.masked_select(values, response_mask)
+ return_diff_var = torch.var(valid_returns - valid_values)
+ return_var = torch.var(valid_returns)
+
+ # Aborted samples and non-aborted response length statistics
+ # response_length_non_aborted/*: statistics computed on non-aborted samples only
+ aborted_ratio = torch.mean(aborted_mask.float()).detach().item()
+
+ non_aborted_response_length = response_length[non_aborted_mask]
+ if non_aborted_response_length.numel() > 0:
+ non_aborted_response_length_mean = torch.mean(non_aborted_response_length).detach().item()
+ non_aborted_response_length_max = torch.max(non_aborted_response_length).detach().item()
+ non_aborted_response_length_min = torch.min(non_aborted_response_length).detach().item()
+ non_aborted_response_length_clip_ratio = (
+ torch.mean(torch.eq(non_aborted_response_length, max_response_length).float()).detach().item()
+ )
+ else:
+ raise ValueError("All samples are aborted; this should not happen.")
+
+ metrics = {
+ # score
+ "critic/score/mean": score_mean,
+ "critic/score/max": score_max,
+ "critic/score/min": score_min,
+ # reward
+ "critic/rewards/mean": reward_mean,
+ "critic/rewards/max": reward_max,
+ "critic/rewards/min": reward_min,
+ # adv
+ "critic/advantages/mean": torch.mean(valid_adv).detach().item(),
+ "critic/advantages/max": torch.max(valid_adv).detach().item(),
+ "critic/advantages/min": torch.min(valid_adv).detach().item(),
+ # returns
+ "critic/returns/mean": torch.mean(valid_returns).detach().item(),
+ "critic/returns/max": torch.max(valid_returns).detach().item(),
+ "critic/returns/min": torch.min(valid_returns).detach().item(),
+ **(
+ {
+ # values
+ "critic/values/mean": torch.mean(valid_values).detach().item(),
+ "critic/values/max": torch.max(valid_values).detach().item(),
+ "critic/values/min": torch.min(valid_values).detach().item(),
+ # vf explained var
+ "critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
+ }
+ if use_critic
+ else {}
+ ),
+ **(
+ {
+ # raw bonuses
+ "critic/raw_bonuses/mean": np.mean(sequence_raw_bonuses).item(),
+ "critic/raw_bonuses/max": np.max(sequence_raw_bonuses).item(),
+ "critic/raw_bonuses/min": np.min(sequence_raw_bonuses).item(),
+ "critic/raw_bonuses/std": np.std(sequence_raw_bonuses).item(),
+ # intrinsic_reward
+ "critic/intrinsic_reward/mean": np.mean(sequence_intrinsic_reward).item(),
+ "critic/intrinsic_reward/max": np.max(sequence_intrinsic_reward).item(),
+ "critic/intrinsic_reward/min": np.min(sequence_intrinsic_reward).item(),
+ "critic/intrinsic_reward/std": np.std(sequence_intrinsic_reward).item(),
+ # beta_scaled_intrinsic_reward
+ "critic/beta_scaled_intrinsic_reward/mean": np.mean(sequence_beta_scaled_intrinsic_reward).item(),
+ "critic/beta_scaled_intrinsic_reward/max": np.max(sequence_beta_scaled_intrinsic_reward).item(),
+ "critic/beta_scaled_intrinsic_reward/min": np.min(sequence_beta_scaled_intrinsic_reward).item(),
+ "critic/beta_scaled_intrinsic_reward/std": np.std(sequence_beta_scaled_intrinsic_reward).item(),
+ # extrinsic_reward
+ "critic/extrinsic_reward/mean": np.mean(sequence_extrinsic_reward).item(),
+ "critic/extrinsic_reward/max": np.max(sequence_extrinsic_reward).item(),
+ "critic/extrinsic_reward/min": np.min(sequence_extrinsic_reward).item(),
+ "critic/extrinsic_reward/std": np.std(sequence_extrinsic_reward).item(),
+ # three_case_stats
+ "critic/extrinsic_reward/no_rollouts_correct_frac": three_case_stats["no_rollouts_correct_frac"],
+ "critic/extrinsic_reward/some_rollouts_correct_frac": three_case_stats["some_rollouts_correct_frac"],
+ "critic/extrinsic_reward/all_rollouts_correct_frac": three_case_stats["all_rollouts_correct_frac"],
+ # total_reward
+ "critic/total_reward/mean": np.mean(sequence_total_reward).item(),
+ "critic/total_reward/max": np.max(sequence_total_reward).item(),
+ "critic/total_reward/min": np.min(sequence_total_reward).item(),
+ "critic/total_reward/std": np.std(sequence_total_reward).item(),
+ }
+ if elliptical
+ else {}
+ ),
+ # response length
+ "response_length/mean": torch.mean(response_length).detach().item(),
+ "response_length/max": torch.max(response_length).detach().item(),
+ "response_length/min": torch.min(response_length).detach().item(),
+ "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float())
+ .detach()
+ .item(),
+ # response length (non-aborted only)
+ # These statistics exclude aborted samples to avoid skew from zeros
+ "response_length_non_aborted/mean": non_aborted_response_length_mean,
+ "response_length_non_aborted/max": non_aborted_response_length_max,
+ "response_length_non_aborted/min": non_aborted_response_length_min,
+ "response_length_non_aborted/clip_ratio": non_aborted_response_length_clip_ratio,
+ # aborted ratio
+ # Fraction of samples whose response length is zero
+ "response/aborted_ratio": aborted_ratio,
+ # prompt length
+ "prompt_length/mean": torch.mean(prompt_length).detach().item(),
+ "prompt_length/max": torch.max(prompt_length).detach().item(),
+ "prompt_length/min": torch.min(prompt_length).detach().item(),
+ "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
+ }
+
+ # multi-turn conversation
+ if "__num_turns__" in batch.non_tensor_batch:
+ num_turns = batch.non_tensor_batch["__num_turns__"]
+ metrics["num_turns/min"] = num_turns.min()
+ metrics["num_turns/max"] = num_turns.max()
+ metrics["num_turns/mean"] = num_turns.mean()
+
+ if "tool_call_counts" in batch.non_tensor_batch:
+ tool_call_counts = batch.non_tensor_batch["tool_call_counts"]
+ metrics["tool_call_counts/min"] = tool_call_counts.min()
+ metrics["tool_call_counts/max"] = tool_call_counts.max()
+ metrics["tool_call_counts/mean"] = tool_call_counts.mean()
+
+ return metrics
+
+
+def comb_estimator(n: int, c: int, k: int) -> float:
+ """Unbiased pass@k estimator: 1 - comb(n - c, k) / comb(n, k).
+
+ Args:
+ n: Total number of rollouts.
+ c: Number of correct rollouts.
+ k: Number of rollouts drawn without replacement.
+ """
+ if n - c < k:
+ # Fewer than k incorrect rollouts exist, so every size-k subset contains a correct one.
+ return 1.0
+ # Product form avoids computing large binomial coefficients directly.
+ return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
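+
+
+# Worked example for comb_estimator (illustrative comment only, not used by the pipeline):
+# with n = 4 rollouts of which c = 2 are correct,
+#   comb_estimator(4, 2, 1) = 1 - C(2, 1) / C(4, 1) = 1 - 2/4 = 1/2
+#   comb_estimator(4, 2, 4) = 1.0, since k > n - c means every size-k subset contains a correct rollout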
+
+
+def process_validation_metrics(
+ data_sources: list[str], sample_uids: list[str], infos_dict: dict[str, list[Any]], seed: int = 42
+) -> dict[str, dict[str, dict[str, float]]]:
+ """
+ Process validation metrics into a structured format with statistical analysis.
+
+ This function organizes validation metrics by data source and prompt, then computes
+ various statistical measures including means, standard deviations, best/worst values,
+ and majority voting results. It also performs bootstrap sampling to estimate statistics
+ for different sample sizes.
+
+ Args:
+ data_sources: List of data source identifiers for each sample.
+ sample_uids: List of sample uids corresponding to each sample.
+ infos_dict: Dictionary mapping variable names to lists of values for each sample.
+ seed: Random seed for bootstrap sampling. Defaults to 42.
+
+ Returns:
+ A nested dictionary with the structure:
+ {
+ data_source: {
+ variable_name: {
+ metric_name: value
+ }
+ }
+ }
+
+ Where metric_name includes:
+ - "mean@N": Mean value across N samples
+ - "std@N": Standard deviation across N samples
+ - "best@N/mean": Mean of the best values in bootstrap samples of size N
+ - "best@N/std": Standard deviation of the best values in bootstrap samples
+ - "worst@N/mean": Mean of the worst values in bootstrap samples
+ - "worst@N/std": Standard deviation of the worst values in bootstrap samples
+ - "maj@N/mean": Mean of majority voting results in bootstrap samples (if "pred" exists)
+ - "maj@N/std": Standard deviation of majority voting results (if "pred" exists)
+
+ Example:
+ >>> data_sources = ["source1", "source1", "source2"]
+ >>> sample_uids = ["uid1", "uid1", "uid2"]
+ >>> infos_dict = {"score": [0.8, 0.9, 0.7], "pred": ["A", "A", "B"]}
+ >>> result = process_validation_metrics(data_sources, sample_uids, infos_dict)
+ >>> # result will contain statistics for each data source and variable
+ """
+ # Group metrics by data source, prompt and variable
+ data_src2uid2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
+ for sample_idx, data_source in enumerate(data_sources):
+ uid = sample_uids[sample_idx]
+ var2vals = data_src2uid2var2vals[data_source][uid]
+ for var_name, var_vals in infos_dict.items():
+ var2vals[var_name].append(var_vals[sample_idx])
+
+ # Calculate metrics for each group
+ data_src2uid2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
+ for data_source, uid2var2vals in data_src2uid2var2vals.items():
+ for uid, var2vals in uid2var2vals.items():
+ for var_name, var_vals in var2vals.items():
+ if isinstance(var_vals[0], str):
+ continue
+
+ metric = {}
+ n_resps = len(var_vals)
+ metric[f"mean@{n_resps}"] = np.mean(var_vals)
+ metric["pass@1/mean"] = comb_estimator(n_resps, np.sum(var_vals), 1)
+
+ if n_resps > 1:
+ metric[f"std@{n_resps}"] = np.std(var_vals)
+
+ ns = []
+ n = 2
+ while n < n_resps:
+ ns.append(n)
+ n *= 2
+ ns.append(n_resps)
+
+ for n in ns:
+ # [(bon_mean, bon_std), (won_mean, won_std)] = bootstrap_metric(
+ # data=var_vals, subset_size=n, reduce_fns=[np.max, np.min], seed=seed
+ # )
+ # metric[f"best@{n}/mean"], metric[f"best@{n}/std"] = bon_mean, bon_std
+ # metric[f"worst@{n}/mean"], metric[f"worst@{n}/std"] = won_mean, won_std
+ metric[f"pass@{n}/mean"] = comb_estimator(n_resps, np.sum(var_vals), n)
+ if var2vals.get("pred", None) is not None:
+ vote_data = [
+ {"val": val, "pred": pred} for val, pred in zip(var_vals, var2vals["pred"], strict=True)
+ ]
+ [(maj_n_mean, maj_n_std)] = bootstrap_metric(
+ data=vote_data,
+ subset_size=n,
+ reduce_fns=[partial(calc_maj_val, vote_key="pred", val_key="val")],
+ seed=seed,
+ )
+ metric[f"maj@{n}/mean"], metric[f"maj@{n}/std"] = maj_n_mean, maj_n_std
+
+ data_src2uid2var2metric[data_source][uid][var_name] = metric
+
+ # Aggregate metrics across uids
+ data_src2var2metric2uid_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
+ for data_source, uid2var2metric in data_src2uid2var2metric.items():
+ for uid, var2metric in uid2var2metric.items():
+ for var_name, metric in var2metric.items():
+ for metric_name, metric_val in metric.items():
+ data_src2var2metric2uid_vals[data_source][var_name][metric_name].append(metric_val)
+
+ data_src2var2metric2val = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
+ for data_source, var2metric2uid_vals in data_src2var2metric2uid_vals.items():
+ for var_name, metric2uid_vals in var2metric2uid_vals.items():
+ for metric_name, uid_vals in metric2uid_vals.items():
+ data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(uid_vals)
+
+ return data_src2var2metric2val
diff --git a/recipe/rep_exp/model_merge.sh b/recipe/rep_exp/model_merge.sh
new file mode 100644
index 00000000000..b7c673ebe8c
--- /dev/null
+++ b/recipe/rep_exp/model_merge.sh
@@ -0,0 +1,6 @@
+CHECKPOINT_PATH=${1} # /path/to/global_step_X/actor, where X is the global step of the checkpoint with the best pass@1 on dev
+
+python3 -m verl.model_merger merge \
+ --backend fsdp \
+ --local_dir "$CHECKPOINT_PATH" \
+ --target_dir "$CHECKPOINT_PATH/hf"
\ No newline at end of file
diff --git a/recipe/rep_exp/plot_pass_at_k.py b/recipe/rep_exp/plot_pass_at_k.py
new file mode 100644
index 00000000000..eea2011bab0
--- /dev/null
+++ b/recipe/rep_exp/plot_pass_at_k.py
@@ -0,0 +1,241 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Code to plot the pass@k results for the RepExp RL training results.
+"""
+
+import json
+import os
+from collections import defaultdict
+
+import matplotlib.pyplot as plt
+import numpy as np
+import scipy.stats as stats
+import seaborn as sns
+from matplotlib.lines import Line2D
+
+# Content configuration
+EVAL_FOLDER = "./eval"
+TASKS = ["math"] # ["math", "gsm8k", "dapo-with-aime24"]
+SEEDS = [41, 42, 43]
+ALGORITHMS = ["elliptical"] # ["grpo", "elliptical", "untrained", "unlikely"]
+LOG_AXES = True
+
+# Plot configuration
+FACE_COLOR = "#F7F7FF"
+MARKER = "o"
+LINEWIDTH = 1.275
+MARKERSIZE = 6
+MARKEREDGEWIDTH = 0.9
+LABEL_FONT_SIZE = 10
+TITLE_FONT_SIZE = 11
+TICK_LABEL_FONT_SIZE = 8
+LEGEND_FONT_SIZE = 8
+
+TASK_TO_NICE_NAME = {
+ "math": "MATH",
+ "gsm8k": "GSM8K",
+ "dapo-with-aime24": "AIME 2024",
+ "countdown-4": "Countdown",
+}
+
+ALGO_TO_COLOR = {
+ "grpo": sns.color_palette("deep")[-1],
+ "untrained": sns.color_palette("deep")[7],
+ "elliptical": sns.color_palette("colorblind")[2],
+ "unlikely": sns.color_palette("deep")[1],
+}
+
+ALGO_TO_NICE_NAME = {
+ "grpo": "GRPO",
+ "untrained": "Base Model",
+ "elliptical": r"RepExp (ours)",
+ "unlikely": "Unlikeliness",
+}
+
+
+def process_data(data: list[dict[str, float]], algorithm: str) -> tuple[dict[int, float], dict[int, float] | None]:
+ """
+ Process the pass@k data generated by a given algorithm.
+
+ Args:
+ data (list[dict[str, float]]): The data to process.
+ algorithm (str): Algorithm that generated the data.
+
+ Returns:
+ tuple[dict[int, float], dict[int, float] | None]:
+ pass_at_k - The mean pass@k values.
+ pass_at_k_sem - The standard error of the pass@k values, or None for "untrained"
+ (only a single checkpoint exists for it).
+ """
+ pass_at_k = defaultdict(list)
+ for d in data:
+ for key, v in d.items():
+ for k in [1, 2, 4, 8, 16, 32, 64, 128, 256]:
+ if key.endswith(f"reward/pass@{k}/mean"):
+ pass_at_k[k].append(v)
+
+ # NOTE: we only use a single seed for untrained since there is only one checkpoint for it
+ if algorithm != "untrained":
+ for k in pass_at_k.keys():
+ assert len(pass_at_k[k]) == len(SEEDS)
+
+ pass_at_k_sem = {k: stats.sem(v) for k, v in pass_at_k.items()} if algorithm != "untrained" else None
+ pass_at_k = {k: np.mean(v) for k, v in pass_at_k.items()}
+
+ return pass_at_k, pass_at_k_sem
+
+
+def main():
+ # Get all top-level folders in EVAL_FOLDER
+ eval_folders = os.listdir(EVAL_FOLDER)
+
+ # Figure setup
+ sns.set_style("whitegrid")
+ fig, axs = plt.subplots(1, len(TASKS), figsize=(3 * len(TASKS), 3))
+
+ for i, task in enumerate(TASKS):
+ ax = axs[i] if len(TASKS) > 1 else axs
+ algo_to_xs = {}
+ algo_to_ys = {}
+
+ for algorithm in ALGORITHMS:
+ # Get all eval folders for the current task and algorithm
+ folders = [f for f in eval_folders if f.startswith(f"{task}_{algorithm}")]
+ if len(folders) == 0:
+ continue
+
+ data = []
+ for folder in folders:
+ if algorithm == "untrained":
+ with open(os.path.join(EVAL_FOLDER, folder, "eval.json")) as f:
+ data.append(json.load(f))
+ else:
+ # walk all files recursively in folder
+ for root, dirs, files in os.walk(os.path.join(EVAL_FOLDER, folder)):
+ for file in files:
+ if file.endswith("eval.json"):
+ with open(os.path.join(root, file)) as f:
+ data.append(json.load(f))
+ break
+
+ pass_at_k, pass_at_k_sem = process_data(data, algorithm)
+
+ xs = np.array(list(pass_at_k.keys()))
+ ys = np.array([pass_at_k[k] for k in xs])
+ algo_to_xs[algorithm] = xs
+ algo_to_ys[algorithm] = ys
+
+ # Plot the current task - algorithm data
+ ax.plot(
+ xs,
+ ys,
+ color=ALGO_TO_COLOR[algorithm],
+ label=algorithm,
+ markeredgecolor=FACE_COLOR,
+ marker=MARKER,
+ linewidth=LINEWIDTH,
+ markersize=MARKERSIZE,
+ markeredgewidth=MARKEREDGEWIDTH,
+ alpha=1.0 if algorithm != "untrained" else 0.8,
+ )
+
+ # Plot the standard error in shaded bands
+ if algorithm != "untrained":
+ sems = np.array([pass_at_k_sem[k] for k in xs])
+ ax.fill_between(xs, ys - sems, ys + sems, alpha=0.2, color=ALGO_TO_COLOR[algorithm])
+
+ # Set y-axis limits
+ if task == "math":
+ y_min = 0.7
+ ax.set_ylim(top=0.95, bottom=y_min)
+ elif task == "gsm8k":
+ y_min = 0.925
+ ax.set_ylim(top=0.995, bottom=y_min)
+ elif task == "dapo-with-aime24":
+ y_min = 0.1
+ ax.set_ylim(bottom=y_min, top=0.63)
+
+ # Set x-axis limits
+ if LOG_AXES:
+ ax.set_xlim(left=2 ** (-0.2), right=2 ** (8.2))
+ else:
+ ax.set_xlim(left=-10, right=266)
+
+ # Set x-axis scale and ticks
+ if LOG_AXES:
+ ax.set_xscale("log", base=2)
+ x_ticks = [2**i for i in range(int(np.log2(max(xs))) + 1)]
+ x_tick_labels = [f"$2^{{{i}}}$" for i in range(int(np.log2(max(xs))) + 1)]
+ else:
+ # set ticks at 1 and every multiple of 32
+ x_ticks = [1, 32, 64, 96, 128, 160, 192, 224, 256]
+ x_tick_labels = ["1", "32", "64", "96", "128", "160", "192", "224", "256"]
+ ax.set_xticks(x_ticks, x_tick_labels)
+
+ # Set axes labels
+ ax.set_xlabel("k", fontsize=LABEL_FONT_SIZE)
+ if i == 0:
+ ax.set_ylabel("Pass@k", fontsize=LABEL_FONT_SIZE)
+
+ # Set title
+ ax.set_title(f"{TASK_TO_NICE_NAME[task]}", fontsize=TITLE_FONT_SIZE)
+
+ # Set font size for tick labels
+ for _label in ax.get_xticklabels():
+ _label.set_fontsize(TICK_LABEL_FONT_SIZE)
+ for _label in ax.get_yticklabels():
+ _label.set_fontsize(TICK_LABEL_FONT_SIZE)
+
+ # Create legend handles
+ legend_handles = [
+ Line2D(
+ [0],
+ [0],
+ color=ALGO_TO_COLOR[algo],
+ marker=MARKER,
+ linestyle="-",
+ linewidth=LINEWIDTH,
+ markersize=MARKERSIZE,
+ markeredgewidth=MARKEREDGEWIDTH,
+ markeredgecolor=FACE_COLOR,
+ label=ALGO_TO_NICE_NAME[algo],
+ )
+ for algo in ALGORITHMS
+ ]
+
+ # Create legend
+ legend = fig.legend(
+ handles=legend_handles,
+ loc="lower center",
+ ncol=len(ALGORITHMS),
+ bbox_to_anchor=(0.5, -0.07),
+ fontsize=LEGEND_FONT_SIZE,
+ )
+
+ plt.tight_layout()
+
+ os.makedirs("figures", exist_ok=True)
+ # Save figure
+ plt.savefig(
+ os.path.join("figures", f"rl_pass_at_k_{TASKS}{'' if LOG_AXES else '_linear_axes'}.pdf"),
+ bbox_extra_artists=(legend,),
+ bbox_inches="tight",
+ )
+
+ # Close figure
+ plt.close()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/recipe/rep_exp/rep_exp_trainer.py b/recipe/rep_exp/rep_exp_trainer.py
new file mode 100644
index 00000000000..c7c23b848d7
--- /dev/null
+++ b/recipe/rep_exp/rep_exp_trainer.py
@@ -0,0 +1,739 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright 2023-2024 SGLang Team
+# Copyright 2025 ModelBest Inc. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+PPO trainer with a Ray-based single controller.
+This trainer supports model-agnostic model initialization with HuggingFace.
+"""
+
+import json
+import os
+import uuid
+from collections import defaultdict
+from copy import deepcopy
+from pprint import pprint
+
+import numpy as np
+import ray
+import torch
+from omegaconf import OmegaConf
+from tqdm import tqdm
+
+from verl import DataProto
+from verl.experimental.dataset.sampler import AbstractCurriculumSampler
+from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
+from verl.single_controller.ray import RayClassWithInitArgs
+from verl.single_controller.ray.base import create_colocated_worker_cls
+from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss
+from verl.trainer.ppo.metric_utils import (
+ compute_throughout_metrics,
+ compute_timing_metrics,
+)
+from verl.trainer.ppo.ray_trainer import RayPPOTrainer, apply_kl_penalty, compute_advantage, compute_response_mask
+from verl.trainer.ppo.reward import compute_reward, compute_reward_async
+from verl.trainer.ppo.utils import Role
+from verl.utils.checkpoint.checkpoint_manager import should_save_ckpt_esi
+from verl.utils.config import omega_conf_to_dataclass
+from verl.utils.debug import marked_timer
+from verl.utils.metric import reduce_metrics
+from verl.utils.rollout_skip import RolloutSkip
+
+from .metric_utils import compute_data_metrics, process_validation_metrics
+
+
+class RayRepExpTrainer(RayPPOTrainer):
+ """Distributed RepExp trainer using Ray for scalable reinforcement learning.
+
+ See RayPPOTrainer parent class for more details.
+ """
+
+ def _save_checkpoint(self):
+ super()._save_checkpoint()
+
+ # Write best metric to global steps
+ local_best_metric_to_global_step = os.path.join(
+ self.config.trainer.default_local_dir, "best_metric_to_global_step.json"
+ )
+ with open(local_best_metric_to_global_step, "w") as f:
+ json.dump(self.best_dev_pass_at_k_to_global_step, f)
+
+ def _update_best_pass_at(self, val_metrics: dict[str, float], pass_at_k: int) -> bool:
+ """
+ Update the tracked best pass@k metric and the global step at which it was achieved.
+
+ Args:
+ val_metrics: The validation metrics.
+ pass_at_k: The value of k whose pass@k metric should be tracked.
+
+ Returns:
+ True if the best metric improved and the trackers were updated, False otherwise.
+ """
+ for k in val_metrics.keys():
+ if k.endswith(f"reward/pass@{pass_at_k}/mean"):
+ if val_metrics[k] > self.best_dev_pass_at_k[pass_at_k]:
+ self.best_dev_pass_at_k[pass_at_k] = val_metrics[k]
+ self.best_dev_pass_at_k_to_global_step[pass_at_k] = self.global_steps
+ return True
+
+ return False
+
+ def _validate(self):
+ data_source_lst = []
+ reward_extra_infos_dict: dict[str, list] = defaultdict(list)
+
+ # Lists to collect samples for the table
+ sample_inputs = []
+ sample_outputs = []
+ sample_gts = []
+ sample_scores = []
+ sample_turns = []
+ sample_uids = []
+
+ for test_data in tqdm(self.val_dataloader, desc="Validating ..."):
+ test_batch = DataProto.from_single_dict(test_data)
+
+ if "uid" not in test_batch.non_tensor_batch:
+ test_batch.non_tensor_batch["uid"] = np.array(
+ [str(uuid.uuid4()) for _ in range(len(test_batch.batch))], dtype=object
+ )
+
+ # repeat test batch
+ test_batch = test_batch.repeat(
+ repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True
+ )
+
+ # we only do validation on rule-based rm
+ if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model":
+ return {}
+
+ # Store original inputs
+ input_ids = test_batch.batch["input_ids"]
+ # TODO: Can we keep special tokens except for padding tokens?
+ input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
+ sample_inputs.extend(input_texts)
+ sample_uids.extend(test_batch.non_tensor_batch["uid"])
+
+ ground_truths = [
+ item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch
+ ]
+ sample_gts.extend(ground_truths)
+
+ test_gen_batch = self._get_gen_batch(test_batch)
+ test_gen_batch.meta_info = {
+ "eos_token_id": self.tokenizer.eos_token_id,
+ "pad_token_id": self.tokenizer.pad_token_id,
+ "recompute_log_prob": False,
+ "do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample,
+ "validate": True,
+ "global_steps": self.global_steps,
+ }
+ print(f"test_gen_batch meta info: {test_gen_batch.meta_info}")
+
+ # pad to be divisible by dp_size
+ size_divisor = (
+ self.actor_rollout_wg.world_size
+ if not self.async_rollout_mode
+ else self.config.actor_rollout_ref.rollout.agent.num_workers
+ )
+ test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor)
+ if not self.async_rollout_mode:
+ test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
+ else:
+ test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded)
+
+ # unpad
+ test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
+
+ print("validation generation end")
+
+ # Store generated outputs
+ output_ids = test_output_gen_batch.batch["responses"]
+ output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
+ sample_outputs.extend(output_texts)
+
+ test_batch = test_batch.union(test_output_gen_batch)
+ test_batch.meta_info["validate"] = True
+
+ # evaluate using reward_function
+ if self.val_reward_fn is None:
+ raise ValueError("val_reward_fn must be provided for validation.")
+ result = self.val_reward_fn(test_batch, return_dict=True)
+ reward_tensor = result["reward_tensor"]
+ scores = reward_tensor.sum(-1).cpu().tolist()
+ sample_scores.extend(scores)
+
+ reward_extra_infos_dict["reward"].extend(scores)
+ if "reward_extra_info" in result:
+ for key, lst in result["reward_extra_info"].items():
+ reward_extra_infos_dict[key].extend(lst)
+
+ # collect num_turns of each prompt
+ if "__num_turns__" in test_batch.non_tensor_batch:
+ sample_turns.append(test_batch.non_tensor_batch["__num_turns__"])
+
+ data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0]))
+
+ self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)
+
+ # dump generations
+ val_data_dir = self.config.trainer.get("validation_data_dir", None)
+ if val_data_dir:
+ self._dump_generations(
+ inputs=sample_inputs,
+ outputs=sample_outputs,
+ gts=sample_gts,
+ scores=sample_scores,
+ reward_extra_infos_dict=reward_extra_infos_dict,
+ dump_path=val_data_dir,
+ )
+
+ for key_info, lst in reward_extra_infos_dict.items():
+ assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}"
+
+ data_sources = np.concatenate(data_source_lst, axis=0)
+
+ data_src2var2metric2val = process_validation_metrics(data_sources, sample_uids, reward_extra_infos_dict)
+ metric_dict = {}
+ for data_source, var2metric2val in data_src2var2metric2val.items():
+ core_var = "acc" if "acc" in var2metric2val else "reward"
+ for var_name, metric2val in var2metric2val.items():
+ n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()])
+ for metric_name, metric_val in metric2val.items():
+ if (
+ (var_name == core_var)
+ and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"])
+ and (f"@{n_max}" in metric_name)
+ ):
+ metric_sec = "val-core"
+ else:
+ metric_sec = "val-aux"
+ pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}"
+ metric_dict[pfx] = metric_val
+
+ if len(sample_turns) > 0:
+ sample_turns = np.concatenate(sample_turns)
+ metric_dict["val-aux/num_turns/min"] = sample_turns.min()
+ metric_dict["val-aux/num_turns/max"] = sample_turns.max()
+ metric_dict["val-aux/num_turns/mean"] = sample_turns.mean()
+
+ return metric_dict
+
+ def init_workers(self):
+ """Initialize distributed training workers using Ray backend.
+
+ Creates:
+ 1. Ray resource pools from configuration
+ 2. Worker groups for each role (actor, critic, etc.)
+ """
+ self.resource_pool_manager.create_resource_pool()
+
+ self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
+ val_only = self.config.trainer.get("val_only", False)
+
+ # create actor and rollout
+ actor_role = Role.ActorRolloutRef if Role.ActorRolloutRef in self.role_worker_mapping else Role.ActorRollout
+ if self.hybrid_engine:
+ resource_pool = self.resource_pool_manager.get_resource_pool(actor_role)
+ actor_rollout_cls = RayClassWithInitArgs(
+ cls=self.role_worker_mapping[actor_role],
+ config=self.config.actor_rollout_ref,
+ role=str(actor_role),
+ )
+ self.resource_pool_to_cls[resource_pool][str(actor_role)] = actor_rollout_cls
+ else:
+ raise NotImplementedError
+
+ # create critic
+ if self.use_critic and not val_only:
+ resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
+ critic_cfg = omega_conf_to_dataclass(self.config.critic)
+ critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg)
+ self.resource_pool_to_cls[resource_pool][str(Role.Critic)] = critic_cls
+
+ # create reference policy if needed
+ if self.use_reference_policy and not val_only:
+ resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
+ ref_policy_cls = RayClassWithInitArgs(
+ self.role_worker_mapping[Role.RefPolicy],
+ config=self.config.actor_rollout_ref,
+ role=str(Role.RefPolicy),
+ )
+ self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls
+
+ # create a reward model if reward_fn is None
+ if self.use_rm and not val_only:
+ # we create a RM here
+ resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
+ rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
+ self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls
+
+ # initialize WorkerGroup
+ # NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
+ # you should not use `create_colocated_worker_cls`.
+ # Instead, directly pass different resource pool to different worker groups.
+ # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
+ all_wg = {}
+ wg_kwargs = {} # Setting up kwargs for RayWorkerGroup
+ if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
+ wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
+ if OmegaConf.select(self.config.global_profiler, "steps") is not None:
+ wg_kwargs["profile_steps"] = OmegaConf.select(self.config.global_profiler, "steps")
+ # Only require nsight worker options when tool is nsys
+ if OmegaConf.select(self.config.global_profiler, "tool") == "nsys":
+ assert (
+ OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
+ is not None
+ ), "worker_nsight_options must be set when using nsys with profile_steps"
+ wg_kwargs["worker_nsight_options"] = OmegaConf.to_container(
+ OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
+ )
+ wg_kwargs["device_name"] = self.device_name
+
+ for resource_pool, class_dict in self.resource_pool_to_cls.items():
+ worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
+ wg_dict = self.ray_worker_group_cls(
+ resource_pool=resource_pool,
+ ray_cls_with_init=worker_dict_cls,
+ **wg_kwargs,
+ )
+ spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
+ all_wg.update(spawn_wg)
+
+ if self.use_critic:
+ self.critic_wg = all_wg[str(Role.Critic)]
+ self.critic_wg.init_model()
+
+ if self.use_reference_policy and not self.ref_in_actor:
+ if str(Role.RefPolicy) in all_wg:
+ self.ref_policy_wg = all_wg[str(Role.RefPolicy)]
+ self.ref_policy_wg.init_model()
+ else:
+ # Model engine: ActorRolloutRefWorker
+ assert str(Role.ActorRolloutRef) in all_wg, f"{all_wg.keys()=}"
+ self.ref_policy_wg = all_wg[str(Role.ActorRolloutRef)]
+
+ self.rm_wg = None
+ # initialization of rm_wg will be deprecated in the future
+ if self.use_rm:
+ self.rm_wg = all_wg[str(Role.RewardModel)]
+ self.rm_wg.init_model()
+
+ # we should create rollout at the end so that vllm can have a better estimation of kv cache memory
+ self.actor_rollout_wg = all_wg[str(actor_role)]
+ self.actor_rollout_wg.init_model()
+
+ # create async rollout manager and request scheduler
+ self.async_rollout_mode = False
+ if self.config.actor_rollout_ref.rollout.mode == "async":
+ from verl.experimental.agent_loop import AgentLoopManager
+
+ self.async_rollout_mode = True
+ self.async_rollout_manager = AgentLoopManager(
+ config=self.config, worker_group=self.actor_rollout_wg, rm_wg=self.rm_wg
+ )
+
+ def fit(self):
+ """
+ The training loop of PPO.
+ The driver process only needs to call the compute functions of the worker groups through RPC
+ to construct the PPO dataflow.
+ The lightweight advantage computation is done on the driver process.
+ """
+ from omegaconf import OmegaConf
+
+ from .utils.tracking import Tracking
+
+ logger = Tracking(
+ project_name=self.config.trainer.project_name,
+ experiment_name=self.config.trainer.experiment_name,
+ default_backend=self.config.trainer.logger,
+ config=OmegaConf.to_container(self.config, resolve=True),
+ )
+
+ # global vars to track during training
+ self.global_steps = 0
+
+ self.best_dev_pass_at_k = {
+ 1: 0,
+ }
+ self.best_dev_pass_at_k_to_global_step = {
+ 1: 0,
+ }
+
+ # load checkpoint before doing anything
+ self._load_checkpoint()
+
+ current_epoch = self.global_steps // len(self.train_dataloader)
+
+ # perform validation before training
+ # currently, we only support validation using the reward_function.
+ if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
+ val_metrics = self._validate()
+ assert val_metrics, f"{val_metrics=}"
+
+ # Initialize the best validation metrics for pass@k before training
+ self._update_best_pass_at(val_metrics, 1)
+ val_metrics["best/pass@1"] = self.best_dev_pass_at_k[1]
+
+ pprint(f"Initial validation metrics: {val_metrics}")
+ logger.log(data=val_metrics, step=self.global_steps)
+
+ if self.config.trainer.get("val_only", False):
+ return
+
+ if self.config.actor_rollout_ref.rollout.get("skip_rollout", False):
+ rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)
+ rollout_skip.wrap_generate_sequences()
+
+ # add tqdm
+ progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
+
+ # we start from step 1
+ self.global_steps += 1
+ last_val_metrics = None
+ self.max_steps_duration = 0
+
+ prev_step_profile = False
+ curr_step_profile = (
+ self.global_steps in self.config.global_profiler.steps
+ if self.config.global_profiler.steps is not None
+ else False
+ )
+ next_step_profile = False
+
+ for epoch in range(current_epoch, self.config.trainer.total_epochs):
+ for batch_dict in self.train_dataloader:
+ metrics = {}
+ timing_raw = {}
+
+ with marked_timer("start_profile", timing_raw):
+ self._start_profiling(
+ not prev_step_profile and curr_step_profile
+ if self.config.global_profiler.profile_continuous_steps
+ else curr_step_profile
+ )
+ batch: DataProto = DataProto.from_single_dict(batch_dict)
+ batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature
+
+ # add uid to batch
+ batch.non_tensor_batch["uid"] = np.array(
+ [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
+ )
+
+ gen_batch = self._get_gen_batch(batch)
+
+ # pass global_steps to trace
+ gen_batch.meta_info["global_steps"] = self.global_steps
+ gen_batch_output = gen_batch.repeat(
+ repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
+ )
+
+ is_last_step = self.global_steps >= self.total_training_steps
+ with marked_timer("step", timing_raw):
+ # generate a batch
+ with marked_timer("gen", timing_raw, color="red"):
+ if not self.async_rollout_mode:
+ gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)
+ else:
+ gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)
+
+ timing_raw.update(gen_batch_output.meta_info["timing"])
+ gen_batch_output.meta_info.pop("timing", None)
+
+ if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
+ if self.reward_fn is None:
+ raise ValueError("A reward_fn is required for REMAX advantage estimation.")
+
+ with marked_timer("gen_max", timing_raw, color="purple"):
+ gen_baseline_batch = deepcopy(gen_batch)
+ gen_baseline_batch.meta_info["do_sample"] = False
+ if not self.async_rollout_mode:
+ gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
+ else:
+ gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)
+ batch = batch.union(gen_baseline_output)
+ # compute reward model score on batch
+ rm_scores = None
+ if self.use_rm and "rm_scores" not in batch.batch.keys():
+ rm_scores = self.rm_wg.compute_rm_score(batch)
+ batch = batch.union(rm_scores)
+ reward_baseline_tensor, _ = compute_reward(batch, self.reward_fn)
+ reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
+
+ keys_to_pop = set(gen_baseline_output.batch.keys())
+ if rm_scores is not None:
+ keys_to_pop.update(rm_scores.batch.keys())
+ batch.pop(batch_keys=list(keys_to_pop))
+
+ batch.batch["reward_baselines"] = reward_baseline_tensor
+
+ del rm_scores, gen_baseline_batch, gen_baseline_output
+ # repeat to align with repeated responses in rollout
+ batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
+ batch = batch.union(gen_batch_output)
+
+ if "response_mask" not in batch.batch.keys():
+ batch.batch["response_mask"] = compute_response_mask(batch)
+ # Balance the number of valid tokens across DP ranks.
+ # NOTE: This usually changes the order of data in the `batch`,
+ # which won't affect the advantage calculation (since it's based on uid),
+ # but might affect the loss calculation (due to the change of mini-batching).
+ if self.config.trainer.balance_batch:
+ self._balance_batch(batch, metrics=metrics)
+
+ # compute global_valid tokens
+ batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
+
+ with marked_timer("reward", timing_raw, color="yellow"):
+ # compute reward model score
+ if self.use_rm and "rm_scores" not in batch.batch.keys():
+ if self.config.reward_model.elliptical.enable:
+ hidden_states = self.rm_wg.compute_hidden_states(batch)
+ batch = batch.union(hidden_states)
+ reward_tensor = self.rm_wg.compute_rm_score(batch)
+ else:
+ reward_tensor = self.rm_wg.compute_rm_score(batch)
+ batch = batch.union(reward_tensor)
+
+ if self.config.reward_model.launch_reward_fn_async:
+ future_reward = compute_reward_async.remote(
+ data=batch, config=self.config, tokenizer=self.tokenizer
+ )
+ else:
+ reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)
+
+ # Operating Mode Selection:
+ # - Bypass mode: Sets old_log_probs = rollout_log_probs (2 policies: π_rollout, π_θ)
+ # - Decoupled mode: Recomputes old_log_probs as proximal anchor (3 policies: π_rollout, π_old, π_θ)
+ # Note: π_old computed once per data batch, serves as stable reference during mini-batch updates
+ rollout_corr_config = self.config.algorithm.get("rollout_correction", None)
+ bypass_recomputing_logprobs = rollout_corr_config and rollout_corr_config.get("bypass_mode", False)
+ if bypass_recomputing_logprobs: # Use `rollout_log_probs`
+ from verl.trainer.ppo.rollout_corr_helper import apply_rollout_correction
+
+ apply_rollout_correction(
+ batch=batch,
+ rollout_corr_config=rollout_corr_config,
+ policy_loss_config=self.config.actor_rollout_ref.actor.policy_loss,
+ )
+ else: # Recompute old_log_probs
+ with marked_timer("old_log_prob", timing_raw, color="blue"):
+ old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
+ entropys = old_log_prob.batch["entropys"]
+ response_masks = batch.batch["response_mask"]
+ loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
+ entropy_agg = agg_loss(
+ loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode
+ )
+ old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
+ metrics.update(old_log_prob_metrics)
+ old_log_prob.batch.pop("entropys")
+ batch = batch.union(old_log_prob)
+ if "rollout_log_probs" in batch.batch.keys():
+ # TODO: we may want to add diff of probs too.
+ from verl.utils.debug.metrics import calculate_debug_metrics
+
+ metrics.update(calculate_debug_metrics(batch))
+
+ assert "old_log_probs" in batch.batch, f'"old_log_probs" not in {batch.batch.keys()=}'
+
+ if self.use_reference_policy:
+ # compute reference log_prob
+ with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"):
+ if not self.ref_in_actor:
+ ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
+ else:
+ ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
+ batch = batch.union(ref_log_prob)
+
+ # compute values
+ if self.use_critic:
+ with marked_timer("values", timing_raw, color="cyan"):
+ values = self.critic_wg.compute_values(batch)
+ batch = batch.union(values)
+
+ with marked_timer("adv", timing_raw, color="brown"):
+ # we combine with rule-based rm
+ reward_extra_infos_dict: dict[str, list]
+ if self.config.reward_model.launch_reward_fn_async:
+ reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
+ batch.batch["token_level_scores"] = reward_tensor
+
+ if reward_extra_infos_dict:
+ batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})
+
+ # compute rewards. apply_kl_penalty if available
+ if self.config.algorithm.use_kl_in_reward:
+ batch, kl_metrics = apply_kl_penalty(
+ batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
+ )
+ metrics.update(kl_metrics)
+ else:
+ batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
+
+ # Compute rollout correction: IS weights, rejection sampling, and metrics
+ # Only runs in decoupled mode (computes once per batch using stable π_old)
+ # In bypass mode, this is skipped - actor computes metrics from evolving π_θ vs π_rollout
+ if (
+ rollout_corr_config is not None
+ and "rollout_log_probs" in batch.batch
+ and not bypass_recomputing_logprobs # Only in decoupled mode
+ ):
+ from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch
+
+ # Compute IS weights, apply rejection sampling, compute metrics
+ batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config)
+ # IS and off-policy metrics already have rollout_corr/ prefix
+ metrics.update(is_metrics)
+
+ # compute advantages, executed on the driver process
+ norm_adv_by_std_in_grpo = self.config.algorithm.get(
+ "norm_adv_by_std_in_grpo", True
+ ) # GRPO adv normalization factor
+
+ batch = compute_advantage(
+ batch,
+ adv_estimator=self.config.algorithm.adv_estimator,
+ gamma=self.config.algorithm.gamma,
+ lam=self.config.algorithm.lam,
+ num_repeat=self.config.actor_rollout_ref.rollout.n,
+ norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
+ config=self.config.algorithm,
+ )
+
+ # update critic
+ if self.use_critic:
+ with marked_timer("update_critic", timing_raw, color="pink"):
+ critic_output = self.critic_wg.update_critic(batch)
+ critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
+ metrics.update(critic_output_metrics)
+
+ # implement critic warmup
+ if self.config.trainer.critic_warmup <= self.global_steps:
+ # update actor
+ with marked_timer("update_actor", timing_raw, color="red"):
+ rollout_config = self.config.actor_rollout_ref.rollout
+ batch.meta_info["multi_turn"] = rollout_config.multi_turn.enable
+ # TODO: Make "temperature" single source of truth from generation.
+ batch.meta_info["temperature"] = rollout_config.temperature
+ actor_output = self.actor_rollout_wg.update_actor(batch)
+ actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
+ metrics.update(actor_output_metrics)
+
+ # Log rollout generations if enabled
+ rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
+ if rollout_data_dir:
+ self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)
+
+ # validate
+ if (
+ self.val_reward_fn is not None
+ and self.config.trainer.test_freq > 0
+ and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
+ ):
+ with marked_timer("testing", timing_raw, color="green"):
+ val_metrics: dict = self._validate()
+
+ # Initialize the best validation metrics for pass@k before training
+ self._update_best_pass_at(val_metrics, 1)
+ val_metrics["best/pass@1"] = self.best_dev_pass_at_k[1]
+
+ if is_last_step:
+ last_val_metrics = val_metrics
+ metrics.update(val_metrics)
+
+ # Check if the ESI (Elastic Server Instance)/training plan is close to expiration.
+ esi_close_to_expiration = should_save_ckpt_esi(
+ max_steps_duration=self.max_steps_duration,
+ redundant_time=self.config.trainer.esi_redundant_time,
+ )
+ # Check if the conditions for saving a checkpoint are met.
+ # The conditions include a mandatory condition (1) and
+ # one of the following optional conditions (2/3/4):
+ # 1. The save frequency is set to a positive value.
+ # 2. It's the last training step.
+ # 3. The current step number is a multiple of the save frequency.
+ # 4. The ESI (Elastic Server Instance)/training plan is close to expiration.
+ if self.config.trainer.save_freq > 0 and (
+ is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration
+ ):
+ if esi_close_to_expiration:
+ print("Force saving checkpoint: ESI instance expiration approaching.")
+ with marked_timer("save_checkpoint", timing_raw, color="green"):
+ self._save_checkpoint()
+
+ with marked_timer("stop_profile", timing_raw):
+ next_step_profile = (
+ self.global_steps + 1 in self.config.global_profiler.steps
+ if self.config.global_profiler.steps is not None
+ else False
+ )
+ self._stop_profiling(
+ curr_step_profile and not next_step_profile
+ if self.config.global_profiler.profile_continuous_steps
+ else curr_step_profile
+ )
+ prev_step_profile = curr_step_profile
+ curr_step_profile = next_step_profile
+
+ steps_duration = timing_raw["step"]
+ self.max_steps_duration = max(self.max_steps_duration, steps_duration)
+
+ # training metrics
+ metrics.update(
+ {
+ "training/global_step": self.global_steps,
+ "training/epoch": epoch,
+ }
+ )
+ # collect metrics
+ metrics.update(
+ compute_data_metrics(
+ batch=batch,
+ use_critic=self.use_critic,
+ elliptical=self.config.reward_model.elliptical.enable,
+ )
+ )
+ metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
+ # TODO: implement actual tflpo and theoretical tflpo
+ n_gpus = self.resource_pool_manager.get_n_gpus()
+ metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
+ # Note: mismatch metrics (KL, PPL, etc.) are collected after advantage computation
+
+ # this is experimental and may be changed/removed in the future in favor of a general-purpose one
+ if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):
+ self.train_dataloader.sampler.update(batch=batch)
+
+ # TODO: make a canonical logger that supports various backend
+ logger.log(data=metrics, step=self.global_steps)
+
+ progress_bar.update(1)
+ self.global_steps += 1
+
+ if (
+ hasattr(self.config.actor_rollout_ref.actor, "profiler")
+ and self.config.actor_rollout_ref.actor.profiler.tool == "torch_memory"
+ ):
+ self.actor_rollout_wg.dump_memory_snapshot(
+ tag=f"post_update_step{self.global_steps}", sub_dir=f"step{self.global_steps}"
+ )
+
+ if is_last_step:
+ pprint(f"Final validation metrics: {last_val_metrics}")
+ progress_bar.close()
+ return
+
+ # this is experimental and may be changed/removed in the future
+ # in favor of a general-purpose data buffer pool
+ if hasattr(self.train_dataset, "on_batch_end"):
+ # The dataset may be changed after each training batch
+ self.train_dataset.on_batch_end(batch=batch)
diff --git a/recipe/rep_exp/reward_manager/elliptical_reward_manager.py b/recipe/rep_exp/reward_manager/elliptical_reward_manager.py
new file mode 100644
index 00000000000..83040ac1c02
--- /dev/null
+++ b/recipe/rep_exp/reward_manager/elliptical_reward_manager.py
@@ -0,0 +1,138 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict
+
+import torch
+
+from verl import DataProto
+from verl.workers.reward_manager import NaiveRewardManager, register
+
+from ..reward_score import default_compute_score
+
+
+@register("elliptical")
+class EllipticalRewardManager(NaiveRewardManager):
+    """Reward manager that combines extrinsic task rewards with intrinsic elliptical (exploration) rewards."""
+
+ def __init__(
+ self,
+ tokenizer,
+ num_examine,
+ compute_score=None,
+ reward_fn_key="data_source",
+        beta: float = 1.0,
+ turn_off_elliptical_if_none_correct: bool = False,
+ turn_off_elliptical_if_some_correct: bool = False,
+ turn_off_elliptical_if_all_correct: bool = False,
+ turn_off_elliptical_if_rollout_incorrect: bool = False,
+ alpha: float = 1.0,
+ ) -> None:
+        """
+        Initialize the EllipticalRewardManager instance.
+
+        Args:
+            tokenizer: The tokenizer used to decode token IDs into text.
+            num_examine: The number of batches of decoded responses to print to the console for debugging purposes.
+            compute_score: A function to compute the reward score. If None, `default_compute_score` will be used.
+            reward_fn_key: The key used to access the data source in the non-tensor batch data. Defaults to
+                "data_source".
+            beta: Scaling coefficient applied to the intrinsic (elliptical) reward.
+            turn_off_elliptical_if_none_correct: Zero the intrinsic reward for prompts where no rollout is correct.
+            turn_off_elliptical_if_some_correct: Zero the intrinsic reward for prompts where some, but not all,
+                rollouts are correct.
+            turn_off_elliptical_if_all_correct: Zero the intrinsic reward for prompts where all rollouts are correct.
+            turn_off_elliptical_if_rollout_incorrect: Zero the intrinsic reward for each individual incorrect rollout.
+            alpha: Scaling coefficient applied to the extrinsic (task) reward.
+        """
+        super().__init__(tokenizer, num_examine, compute_score or default_compute_score, reward_fn_key)
+ self.beta = beta
+ self.turn_off_elliptical_if_none_correct = turn_off_elliptical_if_none_correct
+ self.turn_off_elliptical_if_some_correct = turn_off_elliptical_if_some_correct
+ self.turn_off_elliptical_if_all_correct = turn_off_elliptical_if_all_correct
+ self.turn_off_elliptical_if_rollout_incorrect = turn_off_elliptical_if_rollout_incorrect
+ self.alpha = alpha
+
+ def __call__(self, data: DataProto, return_dict=False):
+ if "rm_scores" not in data.batch:
+ # this means we're doing validation, so we don't need to compute the elliptical reward
+ return super().__call__(data, return_dict=return_dict)
+
+ reward_extra_info = defaultdict(list)
+
+ intrinsic_reward_tensor = data.batch["rm_scores"]
+ data.pop(batch_keys=["rm_scores"])
+
+ extrinsic_reward_result = super().__call__(data, return_dict=True)
+ extrinsic_reward_tensor = extrinsic_reward_result["reward_tensor"]
+ extrinsic_reward_extra_info = extrinsic_reward_result["reward_extra_info"]
+
+ self._maybe_turn_off_elliptical(data, extrinsic_reward_tensor, intrinsic_reward_tensor)
+
+ reward_tensor = self.alpha * extrinsic_reward_tensor + self.beta * intrinsic_reward_tensor
+
+ # Intrinsic reward extra info
+ reward_extra_info["intrinsic_reward"] = intrinsic_reward_tensor.numpy()
+ reward_extra_info["beta_scaled_intrinsic_reward"] = self.beta * intrinsic_reward_tensor.numpy()
+ reward_extra_info["extrinsic_reward"] = extrinsic_reward_tensor.numpy()
+ reward_extra_info["alpha_scaled_extrinsic_reward"] = self.alpha * extrinsic_reward_tensor.numpy()
+ reward_extra_info["total_reward"] = reward_tensor.numpy()
+
+ # Update with extrinsic reward extra info
+ reward_extra_info.update(extrinsic_reward_extra_info)
+
+ if return_dict:
+ return {
+ "reward_tensor": reward_tensor,
+ "reward_extra_info": reward_extra_info,
+ }
+ else:
+ return reward_tensor
+
+ def _maybe_turn_off_elliptical(
+ self, data: DataProto, extrinsic_reward_tensor: torch.Tensor, intrinsic_reward_tensor: torch.Tensor
+ ) -> None:
+        """
+        Potentially zero out the elliptical (intrinsic) reward, depending on the configured flags, for:
+        (1) each individual rollout that is incorrect,
+        (2) prompts where none of the rollouts are correct,
+        (3) prompts where some, but not all, of the rollouts are correct,
+        (4) prompts where all of the rollouts are correct.
+
+        Args:
+            data (DataProto): The data proto containing the batch data.
+            extrinsic_reward_tensor (torch.Tensor): The extrinsic reward tensor.
+            intrinsic_reward_tensor (torch.Tensor): The intrinsic reward tensor, modified in place.
+
+        Returns:
+            None
+        """
+ if self.turn_off_elliptical_if_rollout_incorrect:
+ mask = extrinsic_reward_tensor.sum(dim=-1) == 0
+ intrinsic_reward_tensor[mask] = 0.0
+
+ visited_uids = set()
+ for uid in data.non_tensor_batch["uid"]:
+ if uid in visited_uids:
+ continue
+
+ visited_uids.add(uid)
+ mask = torch.from_numpy(data.non_tensor_batch["uid"] == uid)
+
+ # Potentially turn off elliptical if **no** rollout has the correct answer
+ if self.turn_off_elliptical_if_none_correct and extrinsic_reward_tensor[mask].sum() == 0:
+ intrinsic_reward_tensor[mask] = 0.0
+
+ # Potentially turn off elliptical if **some** rollouts have the correct answer
+ if (
+ self.turn_off_elliptical_if_some_correct
+ and extrinsic_reward_tensor[mask].sum() > 0
+ and extrinsic_reward_tensor[mask].sum() < mask.sum()
+ ):
+ intrinsic_reward_tensor[mask] = 0.0
+
+ # Potentially turn off elliptical if **all** rollouts have the correct answer
+ if self.turn_off_elliptical_if_all_correct and extrinsic_reward_tensor[mask].sum() == mask.sum():
+ intrinsic_reward_tensor[mask] = 0.0
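Stripped of the verl plumbing, the per-prompt gating performed by `_maybe_turn_off_elliptical` plus the `alpha`/`beta` combination can be sketched in plain Python. This is an illustrative sketch only: `gate_intrinsic` is a made-up name, and it assumes binary (0/1) extrinsic rewards with one scalar reward per rollout rather than per-token tensors.

```python
# Minimal sketch of elliptical reward gating, assuming binary extrinsic
# rewards and a single scalar reward per rollout (not the real tensors).

def gate_intrinsic(extrinsic, intrinsic, uids, alpha=1.0, beta=0.01,
                   turn_off_if_none_correct=True):
    """Combine extrinsic and intrinsic rewards, zeroing the intrinsic
    bonus for prompts (grouped by uid) where no rollout was correct."""
    intrinsic = list(intrinsic)
    for uid in set(uids):
        idxs = [i for i, u in enumerate(uids) if u == uid]
        if turn_off_if_none_correct and sum(extrinsic[i] for i in idxs) == 0:
            for i in idxs:
                intrinsic[i] = 0.0
    return [alpha * e + beta * r for e, r in zip(extrinsic, intrinsic)]

# Two prompts ("a", "b"), two rollouts each; prompt "b" has no correct
# rollout, so its intrinsic bonus is gated off.
rewards = gate_intrinsic(
    extrinsic=[1.0, 0.0, 0.0, 0.0],
    intrinsic=[0.5, 0.5, 0.5, 0.5],
    uids=["a", "a", "b", "b"],
)
print([round(r, 3) for r in rewards])  # [1.005, 0.005, 0.0, 0.0]
```

The real manager applies the same idea on per-token reward tensors, with the uid mask built via `torch.from_numpy(data.non_tensor_batch["uid"] == uid)`.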
diff --git a/recipe/rep_exp/reward_score/__init__.py b/recipe/rep_exp/reward_score/__init__.py
new file mode 100644
index 00000000000..124189fa228
--- /dev/null
+++ b/recipe/rep_exp/reward_score/__init__.py
@@ -0,0 +1,136 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# from . import gsm8k, math, prime_math, prime_code
+
+from verl.utils.import_utils import deprecated
+
+
+def default_compute_score(
+ data_source,
+ solution_str,
+ ground_truth,
+ extra_info=None,
+ sandbox_fusion_url=None,
+ concurrent_semaphore=None,
+ memory_limit_mb=None,
+ **kwargs,
+):
+ """Compute the score for a given solution based on the data source.
+
+ Args:
+ data_source (str): The source dataset identifier which determines the scoring method.
+ solution_str (str): The solution string to be evaluated.
+ ground_truth (str): The ground truth answer for comparison.
+ extra_info (dict, optional): Additional information that might be needed for scoring. Defaults to None.
+
+ Returns:
+ float: The computed score as a floating point number. If the result is a dictionary,
+ it returns the dictionary instead.
+
+ Raises:
+ NotImplementedError: If the reward function is not implemented for the given data source.
+ """
+ if data_source == "openai/gsm8k":
+ from verl.utils.reward_score import gsm8k
+
+ res = gsm8k.compute_score(solution_str, ground_truth)
+ elif data_source in ["lighteval/MATH", "DigitalLearningGmbH/MATH-lighteval", "HuggingFaceH4/MATH-500"]:
+ from verl.utils.reward_score import math_reward
+
+ res = math_reward.compute_score(solution_str, ground_truth)
+ # [Optional] Math-Verify Integration
+ # For enhanced accuracy, consider utilizing Math-Verify (https://github.com/huggingface/Math-Verify).
+ # Note: Math-Verify needs to be manually installed via pip: `pip install math-verify`.
+ # To use it, override the `compute_score` function with the following implementation:
+
+ # from . import math_verify
+ # res = math_verify.compute_score(solution_str, ground_truth)
+ elif data_source in ["math_dapo", "math", "math_dapo_reasoning"] or data_source.startswith("aime"):
+ # res = math_dapo.compute_score(solution_str, ground_truth)
+ from verl.utils.reward_score import math_verify
+
+ res = math_verify.compute_score(solution_str, ground_truth)
+ elif data_source in [
+ "numina_aops_forum",
+ "numina_synthetic_math",
+ "numina_amc_aime",
+ "numina_synthetic_amc",
+ "numina_cn_k12",
+ "numina_olympiads",
+ ]:
+ from verl.utils.reward_score import prime_math
+
+ res = prime_math.compute_score(solution_str, ground_truth)
+ elif data_source in ["codecontests", "apps", "codeforces", "taco"]:
+ # Use the passed sandbox_fusion_url if available
+ if sandbox_fusion_url:
+ from verl.utils.reward_score import sandbox_fusion
+
+ # Pass the URL directly, ground_truth likely contains test cases here
+ res = sandbox_fusion.compute_score(
+ sandbox_fusion_url, concurrent_semaphore, memory_limit_mb, solution_str, ground_truth, continuous=True
+ )
+ else:
+ # If no sandbox URL is provided, fall back to prime_code or raise error
+ from verl.utils.reward_score import prime_code
+
+ # Assuming prime_code doesn't need the URL
+ res = prime_code.compute_score(solution_str, ground_truth, continuous=True)
+ elif data_source in ["hiyouga/geometry3k"]:
+ from verl.utils.reward_score import geo3k
+
+ res = geo3k.compute_score(solution_str, ground_truth)
+ elif data_source in [
+ "searchR1_nq",
+ "searchR1_triviaqa",
+ "searchR1_popqa",
+ "searchR1_hotpotqa",
+ "searchR1_2wikimultihopqa",
+ "searchR1_musique",
+ "searchR1_bamboogle",
+ ]:
+ from verl.utils.reward_score import search_r1_like_qa_em
+
+ res = search_r1_like_qa_em.compute_score(solution_str, ground_truth)
+
+ else:
+ raise NotImplementedError(f"Reward function is not implemented for {data_source=}")
+
+ if isinstance(res, dict):
+ return res
+ elif isinstance(res, int | float | bool):
+ return float(res)
+ else:
+ return float(res[0])
+
+
+@deprecated("verl.utils.reward_score.default_compute_score")
+def _default_compute_score(
+ data_source,
+ solution_str,
+ ground_truth,
+ extra_info=None,
+ sandbox_fusion_url=None,
+ concurrent_semaphore=None,
+ memory_limit_mb=None,
+):
+ """
+ Legacy function API to be deprecated. Please use `default_compute_score` instead.
+ """
+ return default_compute_score(
+ data_source, solution_str, ground_truth, extra_info, sandbox_fusion_url, concurrent_semaphore, memory_limit_mb
+ )
+
+
+__all__ = ["default_compute_score"]
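The dispatcher above is a chain of `data_source` checks with lazy imports. A table-driven sketch of the same pattern is shown below; the scorers here are stand-ins, not the real verl reward functions, and `SCORERS` is an illustrative name.

```python
# Illustrative dispatch-by-data-source pattern, mirroring the shape of
# default_compute_score. The scorer bodies are hypothetical stand-ins.

def exact_match_score(solution, truth):
    return 1.0 if solution.strip() == truth.strip() else 0.0

SCORERS = {
    "openai/gsm8k": exact_match_score,      # stand-in for gsm8k.compute_score
    "lighteval/MATH": exact_match_score,    # stand-in for math_reward.compute_score
}

def compute_score(data_source, solution_str, ground_truth):
    try:
        scorer = SCORERS[data_source]
    except KeyError:
        raise NotImplementedError(f"Reward function is not implemented for {data_source=}")
    res = scorer(solution_str, ground_truth)
    # Normalize the result the same way default_compute_score does:
    # dicts pass through, scalars become floats, sequences take the first element.
    if isinstance(res, dict):
        return res
    if isinstance(res, (int, float, bool)):
        return float(res)
    return float(res[0])

print(compute_score("openai/gsm8k", "42", "42"))  # 1.0
```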
diff --git a/recipe/rep_exp/train_elliptical.sh b/recipe/rep_exp/train_elliptical.sh
new file mode 100644
index 00000000000..70c56afec37
--- /dev/null
+++ b/recipe/rep_exp/train_elliptical.sh
@@ -0,0 +1,104 @@
+TASK=${1} # math, gsm8k, dapo-with-aime24
+SPARSE_DIM=${2} # the original paper used 32 for math/gsm8k, 128 for dapo-with-aime24
+BETA=${3} # 0.01
+SEED=${4}
+
+train_path=$HOME/data/${TASK}/train.parquet
+dev_path=$HOME/data/${TASK}/dev.parquet
+
+train_files="['$train_path']"
+dev_files="['$dev_path']"
+
+# Adjust things a bit for dapo-aime training since it has longer generations
+# and hence is slower and consumes more memory
+if [ "${TASK}" = "dapo-with-aime24" ]; then
+ TEST_FREQ=10
+ SAVE_FREQ=10
+ TRAIN_BATCH_SIZE=512
+ PPO_MINI_BATCH_SIZE=128
+
+ MAX_PROMPT_LENGTH=$((1024 * 2))
+ MAX_RESPONSE_LENGTH=$((1024 * 8))
+ MAX_NUM_BATCHED_TOKENS=$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH))
+ GPU_MEMORY_UTILIZATION=0.5
+ PPO_MICRO_BATCH_SIZE_PER_GPU=8
+ REWARD_MODEL_MICRO_BATCH_SIZE_PER_GPU=16
+else
+ TEST_FREQ=20
+ SAVE_FREQ=20
+ TRAIN_BATCH_SIZE=1024
+ PPO_MINI_BATCH_SIZE=256
+
+ MAX_PROMPT_LENGTH=1024
+ MAX_RESPONSE_LENGTH=1024
+ MAX_NUM_BATCHED_TOKENS=8192
+ GPU_MEMORY_UTILIZATION=0.6
+ PPO_MICRO_BATCH_SIZE_PER_GPU=16
+ REWARD_MODEL_MICRO_BATCH_SIZE_PER_GPU=32
+fi
+
+OFFLINE=True
+
+PYTHONUNBUFFERED=1 TRANSFORMERS_OFFLINE=${OFFLINE} python3 -u -m recipe.rep_exp.main_rep_exp \
+ algorithm.adv_estimator=grpo \
+ data.train_files="$train_files" \
+ data.val_files="$dev_files" \
+ data.train_batch_size=$TRAIN_BATCH_SIZE \
+ data.max_prompt_length=$MAX_PROMPT_LENGTH \
+ data.max_response_length=$MAX_RESPONSE_LENGTH \
+ data.filter_overlong_prompts=True \
+ data.truncation='error' \
+ actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \
+ actor_rollout_ref.actor.optim.lr=1e-6 \
+ actor_rollout_ref.model.use_remove_padding=True \
+ actor_rollout_ref.actor.ppo_mini_batch_size=$PPO_MINI_BATCH_SIZE \
+ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$PPO_MICRO_BATCH_SIZE_PER_GPU \
+ actor_rollout_ref.actor.kl_loss_coef=0.0 \
+ actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+ actor_rollout_ref.actor.entropy_coeff=0 \
+ actor_rollout_ref.actor.ppo_epochs=1 \
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
+ actor_rollout_ref.actor.fsdp_config.param_offload=False \
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+ actor_rollout_ref.rollout.mode=sync \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
+ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
+ actor_rollout_ref.rollout.max_num_batched_tokens=$MAX_NUM_BATCHED_TOKENS \
+ actor_rollout_ref.rollout.name=vllm \
+ actor_rollout_ref.rollout.gpu_memory_utilization=$GPU_MEMORY_UTILIZATION \
+ actor_rollout_ref.rollout.n=8 \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
+ actor_rollout_ref.ref.fsdp_config.param_offload=True \
+ reward_model.enable=True \
+ reward_model.model.path=Qwen/Qwen2.5-7B-Instruct \
+ reward_model.model.use_remove_padding=False \
+ reward_model.model.fsdp_config.param_offload=True \
+ reward_model.micro_batch_size_per_gpu=$REWARD_MODEL_MICRO_BATCH_SIZE_PER_GPU \
+ reward_model.model.input_tokenizer=null \
+ reward_model.elliptical.enable=True \
+ reward_model.elliptical.sparse_dim=$SPARSE_DIM \
+ reward_model.elliptical.reward_type=leverage \
+ reward_model.elliptical.randomize_sparse_matrix=True \
+ reward_model.elliptical.normalization=none \
+ reward_model.elliptical.persist_covariance=False \
+ reward_model.reward_manager=elliptical \
+ reward_model.reward_kwargs.elliptical.beta=$BETA \
+ reward_model.reward_kwargs.elliptical.turn_off_elliptical_if_none_correct=True \
+ reward_model.reward_kwargs.elliptical.turn_off_elliptical_if_some_correct=False \
+ reward_model.reward_kwargs.elliptical.turn_off_elliptical_if_all_correct=False \
+ reward_model.reward_kwargs.elliptical.turn_off_elliptical_if_rollout_incorrect=False \
+ actor_rollout_ref.actor.loss_agg_mode=token-mean \
+ actor_rollout_ref.actor.use_kl_loss=True \
+ algorithm.norm_adv_by_std_in_grpo=True \
+ algorithm.use_kl_in_reward=False \
+ trainer.critic_warmup=0 \
+ trainer.logger='["console","wandb"]' \
+ trainer.project_name='rep-exp' \
+ trainer.experiment_name="${TASK}_elliptical_seed_${SEED}_beta_${BETA}_sparse_dim_${SPARSE_DIM}" \
+ trainer.n_gpus_per_node=8 \
+ trainer.nnodes=1 \
+ trainer.save_freq=$SAVE_FREQ \
+ trainer.test_freq=$TEST_FREQ \
+ trainer.total_epochs=1000 \
+ trainer.resume_mode=disable \
+ trainer.resume_from_path=''
\ No newline at end of file
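The script assumes all four positional arguments are provided and that `$TASK` is one of the three supported tasks. A hypothetical guard one could prepend (the function name and usage string are illustrative, not part of the shipped script) is:

```shell
#!/bin/sh
# Hypothetical argument guard for train_elliptical.sh; not part of the recipe.
check_args() {
    if [ "$#" -ne 4 ]; then
        echo "usage: train_elliptical.sh TASK SPARSE_DIM BETA SEED" >&2
        return 1
    fi
    case "$1" in
        math|gsm8k|dapo-with-aime24) return 0 ;;
        *) echo "unknown task: $1" >&2; return 1 ;;
    esac
}

check_args math 32 0.01 42 && echo "args ok"
```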
diff --git a/recipe/rep_exp/utils/aggregate_logger.py b/recipe/rep_exp/utils/aggregate_logger.py
new file mode 100644
index 00000000000..54a70272d50
--- /dev/null
+++ b/recipe/rep_exp/utils/aggregate_logger.py
@@ -0,0 +1,49 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+A logger that aggregates evaluation results into a JSON file.
+"""
+
+import json
+import os
+
+
+class JsonEvalLogger:
+ """
+ A logger that logs to a json file.
+ Args:
+        save_path: The path to the evaluated checkpoint, used to derive the experiment name and checkpoint type.
+ task: The task name, used to name the experiment.
+ """
+
+ def __init__(self, save_path: str, task: str):
+ self.root = "eval"
+ if save_path is not None and save_path != "":
+ self.experiment_name = save_path.split("/")[-2]
+ self.checkpoint_type = save_path.split("/")[-1]
+ else:
+ self.experiment_name = f"{task}_untrained"
+ self.checkpoint_type = ""
+
+ def flush(self):
+ pass
+
+ def log(self, data, step):
+ # Create eval folder
+ save_folder = os.path.join(self.root, self.experiment_name, self.checkpoint_type)
+ os.makedirs(save_folder, exist_ok=True)
+
+ # Save to json
+ with open(os.path.join(save_folder, "eval.json"), "w") as f:
+ json.dump(data, f)
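`JsonEvalLogger` recovers the experiment name and checkpoint type purely from path components, via `split("/")[-2]` and `[-1]`. A quick illustration (the path itself is made up):

```python
# How JsonEvalLogger splits a checkpoint path (path is illustrative).
save_path = "checkpoints/math_elliptical_seed_42/global_step_100"

experiment_name = save_path.split("/")[-2]  # second-to-last path component
checkpoint_type = save_path.split("/")[-1]  # last path component

print(experiment_name)  # math_elliptical_seed_42
print(checkpoint_type)  # global_step_100
```

Results then land under `eval/<experiment_name>/<checkpoint_type>/eval.json`, which is the folder layout the recipe docs describe.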
diff --git a/recipe/rep_exp/utils/tracking.py b/recipe/rep_exp/utils/tracking.py
new file mode 100644
index 00000000000..898fc0f1aae
--- /dev/null
+++ b/recipe/rep_exp/utils/tracking.py
@@ -0,0 +1,517 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+A unified tracking interface that supports logging data to different backends
+"""
+
+import dataclasses
+import json
+import os
+from enum import Enum
+from functools import partial
+from pathlib import Path
+from typing import Any
+
+
+class Tracking:
+ """A unified tracking interface for logging experiment data to multiple backends.
+
+ This class provides a centralized way to log experiment metrics, parameters, and artifacts
+ to various tracking backends including WandB, MLflow, SwanLab, TensorBoard, and console.
+
+ Attributes:
+ supported_backend: List of supported tracking backends.
+ logger: Dictionary of initialized logger instances for each backend.
+ """
+
+ supported_backend = [
+ "wandb",
+ "mlflow",
+ "swanlab",
+ "vemlp_wandb",
+ "tensorboard",
+ "console",
+ "clearml",
+ "trackio",
+ "file",
+ "json_eval",
+ ]
+
+ def __init__(self, project_name, experiment_name, default_backend: str | list[str] = "console", config=None):
+ if isinstance(default_backend, str):
+ default_backend = [default_backend]
+ for backend in default_backend:
+ if backend == "tracking":
+ import warnings
+
+ warnings.warn("`tracking` logger is deprecated. use `wandb` instead.", DeprecationWarning, stacklevel=2)
+ else:
+ assert backend in self.supported_backend, f"{backend} is not supported"
+
+ self.logger = {}
+
+ if "tracking" in default_backend or "wandb" in default_backend:
+ import os
+
+ import wandb
+
+ settings = None
+ if config and config["trainer"].get("wandb_proxy", None):
+ settings = wandb.Settings(https_proxy=config["trainer"]["wandb_proxy"])
+ entity = os.environ.get("WANDB_ENTITY", None)
+ wandb.init(project=project_name, name=experiment_name, entity=entity, config=config, settings=settings)
+ self.logger["wandb"] = wandb
+
+ if "trackio" in default_backend:
+ import trackio
+
+ trackio.init(project=project_name, name=experiment_name, config=config)
+ self.logger["trackio"] = trackio
+
+ if "mlflow" in default_backend:
+ import os
+
+ import mlflow
+
+ MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "sqlite:////tmp/mlruns.db")
+ mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
+
+ # Project_name is actually experiment_name in MLFlow
+ # If experiment does not exist, will create a new experiment
+ experiment = mlflow.set_experiment(project_name)
+ mlflow.start_run(experiment_id=experiment.experiment_id, run_name=experiment_name)
+ mlflow.log_params(_compute_mlflow_params_from_objects(config))
+ self.logger["mlflow"] = _MlflowLoggingAdapter()
+
+ if "swanlab" in default_backend:
+ import os
+
+ import swanlab
+
+ SWANLAB_API_KEY = os.environ.get("SWANLAB_API_KEY", None)
+ SWANLAB_LOG_DIR = os.environ.get("SWANLAB_LOG_DIR", "swanlog")
+ SWANLAB_MODE = os.environ.get("SWANLAB_MODE", "cloud")
+ if SWANLAB_API_KEY:
+ swanlab.login(SWANLAB_API_KEY) # NOTE: previous login information will be overwritten
+
+ if config is None:
+ config = {} # make sure config is not None, otherwise **config will raise error
+ swanlab.init(
+ project=project_name,
+ experiment_name=experiment_name,
+ config={"FRAMEWORK": "verl", **config},
+ logdir=SWANLAB_LOG_DIR,
+ mode=SWANLAB_MODE,
+ )
+ self.logger["swanlab"] = swanlab
+
+ if "vemlp_wandb" in default_backend:
+ import os
+
+ import volcengine_ml_platform
+ from volcengine_ml_platform import wandb as vemlp_wandb
+
+ volcengine_ml_platform.init(
+ ak=os.environ["VOLC_ACCESS_KEY_ID"],
+ sk=os.environ["VOLC_SECRET_ACCESS_KEY"],
+ region=os.environ["MLP_TRACKING_REGION"],
+ )
+
+ vemlp_wandb.init(
+ project=project_name,
+ name=experiment_name,
+ config=config,
+ sync_tensorboard=True,
+ )
+ self.logger["vemlp_wandb"] = vemlp_wandb
+
+ if "tensorboard" in default_backend:
+ self.logger["tensorboard"] = _TensorboardAdapter(project_name, experiment_name)
+
+ if "console" in default_backend:
+ from verl.utils.logger import LocalLogger
+
+ self.console_logger = LocalLogger(print_to_console=True)
+ self.logger["console"] = self.console_logger
+
+ if "json_eval" in default_backend:
+ from .aggregate_logger import JsonEvalLogger
+
+ model_path = config["actor_rollout_ref"]["model"]["path"]
+ if model_path.endswith("actor/hf"):
+ # Case where the model path is a saved checkpoint
+ save_path = model_path.split("/")[-4:-2]
+ save_path = "/".join(save_path)
+ else:
+ # Case where the model is pretrained model from huggingface
+ save_path = ""
+
+ # Parse task from config
+ train_file = config["data"]["train_files"][0]
+ task = train_file.split("/")[-2]
+
+ self.json_eval_logger = JsonEvalLogger(save_path=save_path, task=task)
+ self.logger["json_eval"] = self.json_eval_logger
+
+ if "clearml" in default_backend:
+ self.logger["clearml"] = ClearMLLogger(project_name, experiment_name, config)
+
+ if "file" in default_backend:
+ self.logger["file"] = FileLogger(project_name, experiment_name)
+
+ def log(self, data, step, backend=None):
+ for default_backend, logger_instance in self.logger.items():
+ if backend is None or default_backend in backend:
+ logger_instance.log(data=data, step=step)
+
+ def __del__(self):
+ if "wandb" in self.logger:
+ self.logger["wandb"].finish(exit_code=0)
+ if "swanlab" in self.logger:
+ self.logger["swanlab"].finish()
+ if "vemlp_wandb" in self.logger:
+ self.logger["vemlp_wandb"].finish(exit_code=0)
+ if "tensorboard" in self.logger:
+ self.logger["tensorboard"].finish()
+ if "clearml" in self.logger:
+ self.logger["clearml"].finish()
+ if "trackio" in self.logger:
+ self.logger["trackio"].finish()
+ if "file" in self.logger:
+ self.logger["file"].finish()
+
+
+class ClearMLLogger:
+ def __init__(self, project_name: str, experiment_name: str, config):
+ self.project_name = project_name
+ self.experiment_name = experiment_name
+
+ import clearml
+
+ self._task: clearml.Task = clearml.Task.init(
+ task_name=experiment_name,
+ project_name=project_name,
+ continue_last_task=True,
+ output_uri=False,
+ )
+
+ self._task.connect_configuration(config, name="Hyperparameters")
+
+ def _get_logger(self):
+ return self._task.get_logger()
+
+ def log(self, data, step):
+ import numpy as np
+ import pandas as pd
+
+ # logs = self._rewrite_logs(data)
+ logger = self._get_logger()
+ for k, v in data.items():
+ title, series = k.split("/", 1)
+
+ if isinstance(v, int | float | np.floating | np.integer):
+ logger.report_scalar(
+ title=title,
+ series=series,
+ value=v,
+ iteration=step,
+ )
+ elif isinstance(v, pd.DataFrame):
+ logger.report_table(
+ title=title,
+ series=series,
+ table_plot=v,
+ iteration=step,
+ )
+ else:
+ logger.warning(
+ f'Trainer is attempting to log a value of "{v}" of type {type(v)} for key "{k}". This '
+ f"invocation of ClearML logger's function is incorrect so this attribute was dropped. "
+ )
+
+ def finish(self):
+ self._task.close()
+
+
+class FileLogger:
+ def __init__(self, project_name: str, experiment_name: str):
+ self.project_name = project_name
+ self.experiment_name = experiment_name
+
+ self.filepath = os.getenv("VERL_FILE_LOGGER_PATH", None)
+ if self.filepath is None:
+ root_path = os.path.expanduser(os.getenv("VERL_FILE_LOGGER_ROOT", "."))
+ directory = os.path.join(root_path, self.project_name)
+ os.makedirs(directory, exist_ok=True)
+ self.filepath = os.path.join(directory, f"{self.experiment_name}.jsonl")
+ print(f"Creating file logger at {self.filepath}")
+ self.fp = open(self.filepath, "w")
+
+ def log(self, data, step):
+ data = {"step": step, "data": data}
+ self.fp.write(json.dumps(data) + "\n")
+
+ def finish(self):
+ self.fp.close()
+
+
+class _TensorboardAdapter:
+ def __init__(self, project_name, experiment_name):
+ import os
+
+ from torch.utils.tensorboard import SummaryWriter
+
+ tensorboard_dir = os.environ.get("TENSORBOARD_DIR", f"tensorboard_log/{project_name}/{experiment_name}")
+ os.makedirs(tensorboard_dir, exist_ok=True)
+ print(f"Saving tensorboard log to {tensorboard_dir}.")
+ self.writer = SummaryWriter(tensorboard_dir)
+
+ def log(self, data, step):
+ for key in data:
+ self.writer.add_scalar(key, data[key], step)
+
+ def finish(self):
+ self.writer.close()
+
+
+class _MlflowLoggingAdapter:
+ def __init__(self):
+ import logging
+ import re
+
+ self.logger = logging.getLogger(__name__)
+ # MLflow metric key validation logic:
+ # https://github.com/mlflow/mlflow/blob/master/mlflow/utils/validation.py#L157C12-L157C44
+ # Only characters allowed: slashes, alphanumerics, underscores, periods, dashes, colons,
+ # and spaces.
+ self._invalid_chars_pattern = re.compile(
+ r"[^/\w.\- :]"
+ ) # Allowed: slashes, alphanumerics, underscores, periods, dashes, colons, and spaces.
+
+ def log(self, data, step):
+ import mlflow
+
+ def sanitize_key(key):
+ # First replace @ with _at_ for backward compatibility
+ sanitized = key.replace("@", "_at_")
+ # Then replace any other invalid characters with _
+ sanitized = self._invalid_chars_pattern.sub("_", sanitized)
+ if sanitized != key:
+ self.logger.warning(
+ "[MLflow] Metric key '%s' sanitized to '%s' due to invalid characters.", key, sanitized
+ )
+ return sanitized
+
+ results = {sanitize_key(k): v for k, v in data.items()}
+ mlflow.log_metrics(metrics=results, step=step)
+
+
+def _compute_mlflow_params_from_objects(params) -> dict[str, Any]:
+ if params is None:
+ return {}
+
+ return _flatten_dict(_transform_params_to_json_serializable(params, convert_list_to_dict=True), sep="/")
+
+
+def _transform_params_to_json_serializable(x, convert_list_to_dict: bool):
+ _transform = partial(_transform_params_to_json_serializable, convert_list_to_dict=convert_list_to_dict)
+
+ if dataclasses.is_dataclass(x):
+ return _transform(dataclasses.asdict(x))
+ if isinstance(x, dict):
+ return {k: _transform(v) for k, v in x.items()}
+ if isinstance(x, list):
+ if convert_list_to_dict:
+ return {"list_len": len(x)} | {f"{i}": _transform(v) for i, v in enumerate(x)}
+ else:
+ return [_transform(v) for v in x]
+ if isinstance(x, Path):
+ return str(x)
+ if isinstance(x, Enum):
+ return x.value
+
+ return x
+
+
+def _flatten_dict(raw: dict[str, Any], *, sep: str) -> dict[str, Any]:
+ import pandas as pd
+
+ ans = pd.json_normalize(raw, sep=sep).to_dict(orient="records")[0]
+ assert isinstance(ans, dict)
+ return ans
+
+
+@dataclasses.dataclass
+class ValidationGenerationsLogger:
+ project_name: str = None
+ experiment_name: str = None
+
+ def log(self, loggers, samples, step):
+ if "wandb" in loggers:
+ self.log_generations_to_wandb(samples, step)
+ if "swanlab" in loggers:
+ self.log_generations_to_swanlab(samples, step)
+ if "mlflow" in loggers:
+ self.log_generations_to_mlflow(samples, step)
+
+ if "clearml" in loggers:
+ self.log_generations_to_clearml(samples, step)
+ if "tensorboard" in loggers:
+ self.log_generations_to_tensorboard(samples, step)
+
+ if "vemlp_wandb" in loggers:
+ self.log_generations_to_vemlp_wandb(samples, step)
+
+ def log_generations_to_vemlp_wandb(self, samples, step):
+ from volcengine_ml_platform import wandb as vemlp_wandb
+
+ self._log_generations_to_wandb(samples, step, vemlp_wandb)
+
+ def log_generations_to_wandb(self, samples, step):
+ import wandb
+
+ self._log_generations_to_wandb(samples, step, wandb)
+
+ def _log_generations_to_wandb(self, samples, step, wandb):
+ """Log samples to wandb as a table"""
+
+ # Create column names for all samples
+ columns = ["step"] + sum(
+ [[f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], []
+ )
+
+ if not hasattr(self, "validation_table"):
+ # Initialize the table on first call
+ self.validation_table = wandb.Table(columns=columns)
+
+ # Create a new table with same columns and existing data
+ # Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737
+ new_table = wandb.Table(columns=columns, data=self.validation_table.data)
+
+ # Add new row with all data
+ row_data = []
+ row_data.append(step)
+ for sample in samples:
+ row_data.extend(sample)
+
+ new_table.add_data(*row_data)
+
+ # Update reference and log
+ wandb.log({"val/generations": new_table}, step=step)
+ self.validation_table = new_table
+
+ def log_generations_to_swanlab(self, samples, step):
+ """Log samples to swanlab as text"""
+ import swanlab
+
+ swanlab_table = swanlab.echarts.Table()
+
+ # Create column names
+ headers = ["step", "input", "output", "score"]
+
+ swanlab_row_list = [[step, *sample] for sample in samples]
+ swanlab_table.add(headers=headers, rows=swanlab_row_list)
+
+ # Log to swanlab
+ swanlab.log({"val/generations": swanlab_table}, step=step)
+
+ def log_generations_to_mlflow(self, samples, step):
+ """Log validation generation to mlflow as artifacts"""
+ # https://mlflow.org/docs/latest/api_reference/python_api/mlflow.html?highlight=log_artifact#mlflow.log_artifact
+
+ import json
+ import tempfile
+
+ import mlflow
+
+ try:
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ validation_gen_step_file = Path(tmp_dir, f"val_step{step}.json")
+ row_data = []
+ for sample in samples:
+ data = {"input": sample[0], "output": sample[1], "score": sample[2]}
+ row_data.append(data)
+ with open(validation_gen_step_file, "w") as file:
+ json.dump(row_data, file)
+ mlflow.log_artifact(validation_gen_step_file)
+ except Exception as e:
+ print(f"WARNING: save validation generation file to mlflow failed with error {e}")
+
+ def log_generations_to_clearml(self, samples, step):
+ """Log validation generation to clearml as table"""
+
+ import clearml
+ import pandas as pd
+
+ task: clearml.Task | None = clearml.Task.current_task()
+ if task is None:
+ return
+
+ table = [
+ {
+ "step": step,
+ "input": sample[0],
+ "output": sample[1],
+ "score": sample[2],
+ }
+ for sample in samples
+ ]
+
+ logger = task.get_logger()
+ logger.report_table(
+ series="Validation generations",
+ title="Validation",
+ table_plot=pd.DataFrame.from_records(table),
+ iteration=step,
+ )
+
+ def log_generations_to_tensorboard(self, samples, step):
+ """Log samples to tensorboard as text"""
+ # Initialize tensorboard writer if not exists
+ if not hasattr(self, "writer"):
+ from torch.utils.tensorboard import SummaryWriter
+
+ # Use the same directory structure as _TensorboardAdapter
+ if self.project_name and self.experiment_name:
+ default_dir = os.path.join("tensorboard_log", self.project_name, self.experiment_name)
+ else:
+ default_dir = "tensorboard_log"
+
+ tensorboard_dir = os.environ.get("TENSORBOARD_DIR", default_dir)
+ os.makedirs(tensorboard_dir, exist_ok=True)
+ self.writer = SummaryWriter(log_dir=tensorboard_dir)
+
+ # Format the samples data into readable text
+ text_content = f"**Generation Results - Step {step}**\n\n"
+
+ for i, sample in enumerate(samples):
+ text_content += f"### Sample {i + 1}\n"
+
+ # Assuming sample contains [input, output, score]
+ if len(sample) >= 3:
+ input_text, output_text, score = sample[0], sample[1], sample[2]
+
+ text_content += f"**Input:** {input_text}\n\n"
+ text_content += f"**Output:** {output_text}\n\n"
+ text_content += f"**Score:** {score}\n\n"
+ else:
+ # Handle cases where sample format might be different
+ text_content += f"**Data:** {sample}\n\n"
+
+ text_content += "---\n\n"
+
+ # Log to tensorboard as text
+ self.writer.add_text("val/generations", text_content, step)
+ # Flush to ensure data is written
+ self.writer.flush()
diff --git a/recipe/rep_exp/workers/elliptical_reward_model_worker.py b/recipe/rep_exp/workers/elliptical_reward_model_worker.py
new file mode 100644
index 00000000000..931779bf8c9
--- /dev/null
+++ b/recipe/rep_exp/workers/elliptical_reward_model_worker.py
@@ -0,0 +1,389 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Elliptical reward model worker: computes representation-based exploration
+bonuses from mean response hidden states, rather than a learned reward score.
+"""
+
+import logging
+import os
+import warnings
+
+import numpy as np
+import torch
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+from verl import DataProto
+from verl.models.transformers.monkey_patch import apply_monkey_patch
+from verl.single_controller.base.decorator import Dispatch, Execute, register
+from verl.utils import hf_tokenizer
+from verl.utils.device import (
+ get_device_id,
+ get_device_name,
+)
+from verl.utils.fs import copy_to_local
+from verl.utils.fsdp_utils import (
+ CPUOffloadPolicy,
+ apply_fsdp2,
+ fsdp2_load_full_state_dict,
+ fsdp_version,
+ get_fsdp_wrap_policy,
+ get_init_weight_context_manager,
+ get_shard_placement_fn,
+ init_fn,
+)
+from verl.utils.profiler import DistProfiler
+from verl.workers.fsdp_workers import RewardModelWorker, get_sharding_strategy
+
+logger = logging.getLogger(__file__)
+logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
+
+device_name = get_device_name()
+
+
+class EllipticalRewardModelWorker(RewardModelWorker):
+ def __init__(self, config):
+ super().__init__(config)
+ self.lamb = config.elliptical.lamb
+ self.normalization = config.elliptical.normalization
+ self.sparse_dim = config.elliptical.sparse_dim
+ self.sparse_matrix = None
+ self.randomize_sparse_matrix = config.elliptical.randomize_sparse_matrix
+ self.persist_covariance = config.elliptical.persist_covariance
+ self.cov_inv_dict = {}
+ self.mean_hidden_states_mu_dict = {}
+ self.hidden_mean_counter_dict = {}
+
+ @staticmethod
+ def _construct_sparse_matrix(features: torch.Tensor, sparse_dim: int) -> torch.Tensor:
+ from sklearn.random_projection import SparseRandomProjection
+
+ sparse_proj = SparseRandomProjection(sparse_dim, density="auto")
+ sparse_proj.fit(features)
+ sparse_matrix = sparse_proj.components_
+ sparse_matrix_coo = sparse_matrix.tocoo()
+
+ # Convert the row and col lists to numpy arrays and then to a LongTensor (speed up)
+ indices = torch.LongTensor(np.array([sparse_matrix_coo.row, sparse_matrix_coo.col]))
+ values = torch.FloatTensor(sparse_matrix_coo.data)
+
+ sparse_mat = torch.sparse_coo_tensor(indices, values, [sparse_dim, features.shape[1]]).t()
+
+ return sparse_mat
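`_construct_sparse_matrix` only uses scikit-learn's `fit` to draw the random components; the rebuilt torch sparse tensor then reproduces sklearn's projection exactly. A minimal self-contained sketch (dummy shapes, assumes scikit-learn and PyTorch are installed):

```python
import numpy as np
import torch
from sklearn.random_projection import SparseRandomProjection

rng = np.random.default_rng(0)
features = rng.standard_normal((4, 64)).astype(np.float32)

# Draw the random components with sklearn; fit() only uses the input's shape
proj = SparseRandomProjection(n_components=8, density="auto", random_state=0)
proj.fit(features)

# Rebuild the projection as a (d, sparse_dim) torch sparse tensor,
# as _construct_sparse_matrix does
coo = proj.components_.tocoo()
indices = torch.LongTensor(np.array([coo.row, coo.col]))
values = torch.FloatTensor(coo.data)
sparse_mat = torch.sparse_coo_tensor(indices, values, [8, 64]).t()

# Multiplying by the transposed sparse matrix matches sklearn's transform
projected = torch.from_numpy(features) @ sparse_mat.to_dense()
ref = torch.from_numpy(proj.transform(features).astype(np.float32))
print(torch.allclose(projected, ref, atol=1e-4))  # True
```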
+
+ def _build_model(self, config):
+ # NOTE: this deferred import is intentional
+ from torch.distributed.fsdp import CPUOffload
+ from transformers import AutoConfig, AutoModel
+
+ use_shm = config.model.get("use_shm", False)
+ # download the checkpoint from hdfs
+ local_path = copy_to_local(config.model.path, use_shm=use_shm)
+
+ if self.config.model.input_tokenizer is None:
+ self._do_switch_chat_template = False
+ else:
+ self._do_switch_chat_template = True
+ input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer, use_shm=use_shm)
+ self.input_tokenizer = hf_tokenizer(
+ input_tokenizer_local_path, trust_remote_code=config.model.get("trust_remote_code", False)
+ )
+ self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get("trust_remote_code", False))
+
+ trust_remote_code = config.model.get("trust_remote_code", False)
+ model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
+ model_config.num_labels = 1
+
+ # no optimizer state is kept for this frozen model, so it can be built directly in bf16
+ init_context = get_init_weight_context_manager(
+ use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh
+ )
+
+ with init_context(), warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ model_config.classifier_dropout = 0.0
+ reward_module = AutoModel.from_pretrained(
+ pretrained_model_name_or_path=local_path,
+ config=model_config,
+ torch_dtype=torch.bfloat16,
+ attn_implementation="flash_attention_2",
+ trust_remote_code=trust_remote_code,
+ )
+
+ apply_monkey_patch(
+ model=reward_module,
+ use_remove_padding=config.model.get("use_remove_padding", False),
+ ulysses_sp_size=self.ulysses_sequence_parallel_size,
+ )
+
+ reward_module.to(torch.bfloat16)
+
+ auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config)
+
+ fsdp_mesh = self.device_mesh
+ sharding_strategy = get_sharding_strategy(fsdp_mesh)
+
+ if config.strategy == "fsdp":
+ reward_module = FSDP(
+ reward_module,
+ param_init_fn=init_fn,
+ use_orig_params=False,
+ auto_wrap_policy=auto_wrap_policy,
+ device_id=get_device_id(),
+ sharding_strategy=sharding_strategy, # zero3
+ sync_module_states=True,
+ cpu_offload=CPUOffload(offload_params=True),
+ forward_prefetch=self.config.model.fsdp_config.forward_prefetch,
+ device_mesh=self.device_mesh,
+ )
+ elif config.strategy == "fsdp2":
+ assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
+ cpu_offload = CPUOffloadPolicy(pin_memory=True)
+ fsdp_kwargs = {
+ "mesh": fsdp_mesh,
+ "offload_policy": cpu_offload,
+ "reshard_after_forward": config.model.fsdp_config.reshard_after_forward,
+ "shard_placement_fn": get_shard_placement_fn(fsdp_size=self.device_mesh.shape[-1]),
+ }
+ full_state = reward_module.state_dict()
+ apply_fsdp2(reward_module, fsdp_kwargs, config.model.fsdp_config)
+ fsdp2_load_full_state_dict(reward_module, full_state, fsdp_mesh, cpu_offload)
+ else:
+ raise NotImplementedError(f"Unknown strategy: {config.strategy}")
+ return reward_module
+
+ def _forward_micro_batch(self, micro_batch, start_of_response: int):
+ with torch.no_grad(), torch.autocast(device_type=device_name, dtype=torch.bfloat16):
+ input_ids = micro_batch["input_ids"]
+ batch_size, seqlen = input_ids.shape
+ attention_mask = micro_batch["attention_mask"]
+ position_ids = micro_batch["position_ids"]
+ if position_ids.dim() == 3: # qwen2vl mrope
+ position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen)
+
+ if self.use_remove_padding:
+ raise NotImplementedError("Remove padding is not implemented for elliptical reward model")
+ else:
+ output = self.reward_module(
+ input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False
+ )
+
+ sequence_lengths = attention_mask[:, start_of_response:].sum(dim=1)
+ mean_hidden_states = []
+ for i, seq_len in enumerate(sequence_lengths):
+ mean_hidden_states.append(
+ output.last_hidden_state[i, start_of_response : start_of_response + seq_len].mean(dim=0)
+ )
+ mean_hidden_states = torch.stack(mean_hidden_states)
+
+ return mean_hidden_states
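The pooling above takes the mean of `last_hidden_state` over each sequence's valid response tokens. Assuming right-padded sequences with dummy shapes (no verl dependencies), the loop agrees with a vectorized masked mean:

```python
import torch

torch.manual_seed(0)
bsz, seqlen, hidden, start = 3, 10, 4, 6  # response tokens begin at `start`
last_hidden = torch.randn(bsz, seqlen, hidden)
attention_mask = torch.ones(bsz, seqlen, dtype=torch.long)
attention_mask[1, 8:] = 0  # one sequence has right padding

# Loop version, as in _forward_micro_batch
seq_lens = attention_mask[:, start:].sum(dim=1)
pooled = torch.stack(
    [last_hidden[i, start : start + n].mean(dim=0) for i, n in enumerate(seq_lens)]
)

# Equivalent vectorized masked mean over the response region
mask = attention_mask[:, start:].unsqueeze(-1).float()
vectorized = (last_hidden[:, start:] * mask).sum(dim=1) / mask.sum(dim=1)
print(torch.allclose(pooled, vectorized, atol=1e-5))  # True
```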
+
+ @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
+ @DistProfiler.annotate(color="brown")
+ def compute_hidden_states(self, data: DataProto):
+ import itertools
+
+ from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
+
+ # Support all hardware backends
+ data = data.to(get_device_id())
+ if self._do_switch_chat_template:
+ rm_data = self._switch_chat_template(data)
+ else:
+ rm_input_ids = data.batch["input_ids"]
+ rm_attention_mask = data.batch["attention_mask"]
+ rm_position_ids = data.batch["position_ids"]
+ rm_inputs = {
+ "input_ids": rm_input_ids,
+ "attention_mask": rm_attention_mask,
+ "position_ids": rm_position_ids,
+ }
+ rm_data = DataProto.from_dict(rm_inputs)
+
+ # Support all hardware backends
+ rm_data = rm_data.to(get_device_id())
+
+ # perform forward computation
+ with self.ulysses_sharding_manager:
+ use_dynamic_bsz = self.config.use_dynamic_bsz
+ if use_dynamic_bsz:
+ max_token_len = self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
+ micro_batches, indices = rearrange_micro_batches(batch=rm_data.batch, max_token_len=max_token_len)
+ else:
+ micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu)
+ output = []
+ for micro_batch in micro_batches:
+ mean_hidden_states = self._forward_micro_batch(
+ micro_batch, start_of_response=data.batch["prompts"].shape[-1]
+ )
+ output.append(mean_hidden_states)
+ mean_hidden_states = torch.cat(output, dim=0)  # (batch_size, hidden_dim)
+
+ # NOTE(Jens): this has not been thoroughly checked
+ if use_dynamic_bsz:
+ indices = list(itertools.chain.from_iterable(indices))
+ assert len(indices) == mean_hidden_states.size(0), f"{len(indices)} vs. {mean_hidden_states.size()}"
+ revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
+ mean_hidden_states = mean_hidden_states[revert_indices]
+
+ # Note: this returns only the pooled hidden states; the elliptical bonuses are computed later in compute_rm_score
+ output = DataProto.from_dict(tensors={"mean_hidden_states": mean_hidden_states})
+
+ # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
+ # unshard the root FSDP module
+ if self.world_size > 1 and fsdp_version(self.reward_module) == 1:
+ self.reward_module._handle.reshard(True)
+
+ output = output.to("cpu")
+ return output
+
+ def _compute_bonuses(self, hidden_states, cov_inv, prompt_index: int):
+ if self.config.elliptical.reward_type == "leave_one_out":
+ if self.persist_covariance:
+ raise NotImplementedError("Leave-one-out with persistence is not implemented")
+ else:
+ bonuses = []
+ for i, hidden_state in enumerate(hidden_states):
+ chosen_samp = hidden_state.unsqueeze(1)
+ middle_part = torch.inverse(1 - chosen_samp.t() @ cov_inv @ chosen_samp)
+ leave_one_out_cov_inv = cov_inv + cov_inv @ chosen_samp @ middle_part @ chosen_samp.t() @ cov_inv
+ bonus = (chosen_samp.t() @ leave_one_out_cov_inv @ chosen_samp).flatten().float()
+ bonuses.append(bonus)
+
+ bonuses = torch.concat(bonuses)
+
+ elif self.config.elliptical.reward_type == "leverage":
+ if self.persist_covariance:
+ hidden_mean = self.mean_hidden_states_mu_dict[prompt_index]
+ hidden_mean_counter = self.hidden_mean_counter_dict[prompt_index]
+
+ hidden_states = hidden_states - hidden_mean
+
+ numerator = cov_inv @ hidden_mean.unsqueeze(1) @ hidden_mean.unsqueeze(0) @ cov_inv
+ denominator = -1 / hidden_mean_counter + hidden_mean.t() @ cov_inv @ hidden_mean
+ cov_inv_mean_adjusted = cov_inv - numerator / denominator
+ batch_cov_inv = cov_inv_mean_adjusted.unsqueeze(0).expand(hidden_states.shape[0], -1, -1)
+ else:
+ batch_cov_inv = cov_inv.unsqueeze(0).expand(hidden_states.shape[0], -1, -1)
+
+ bonuses = (hidden_states.unsqueeze(1) @ batch_cov_inv @ hidden_states.unsqueeze(2)).flatten().float()
+
+ return bonuses
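Both the rank-1 covariance updates and the leave-one-out variant rest on the Sherman-Morrison identity, `(A + x x^T)^{-1} = A^{-1} - (A^{-1} x x^T A^{-1}) / (1 + x^T A^{-1} x)`. A minimal numerical check of the update against direct inversion, using the same regularized initialization as `compute_rm_score` (dummy dimensions, illustrative only):

```python
import torch

torch.manual_seed(0)
d, lamb = 8, 0.1
x = torch.randn(d, 1, dtype=torch.float64)

# Start from the regularized inverse (lambda * I)^-1
cov = lamb * torch.eye(d, dtype=torch.float64)
cov_inv = torch.eye(d, dtype=torch.float64) / lamb

# Sherman-Morrison rank-1 update for (cov + x x^T)^-1
middle = torch.inverse(1 + x.t() @ cov_inv @ x)
cov_inv_updated = cov_inv - cov_inv @ x @ middle @ x.t() @ cov_inv

# Direct inversion agrees
direct = torch.inverse(cov + x @ x.t())
print(torch.allclose(cov_inv_updated, direct, atol=1e-8))  # True
```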
+
+ def _normalize_bonuses(self, bonuses):
+ if self.normalization == "none":
+ pass
+ elif self.normalization == "rnd":
+ std = torch.std(bonuses)
+ if std > 0:
+ bonuses = bonuses / std
+ elif self.normalization == "z_score":
+ mean = torch.mean(bonuses)
+ std = torch.std(bonuses)
+ if std > 0:
+ bonuses = (bonuses - mean) / std
+ else:
+ bonuses = bonuses - mean
+ else:
+ raise ValueError(f"Unknown normalization: {self.normalization}")
+
+ return bonuses
+
+ @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)
+ @DistProfiler.annotate(color="brown")
+ def compute_rm_score(self, data: DataProto):
+ if self.sparse_matrix is None:
+ d = data.batch["mean_hidden_states"].shape[-1]
+ # SparseRandomProjection.fit only inspects the feature dimension, so one dummy row suffices
+ sparse_matrix = self._construct_sparse_matrix(torch.randn(1, d), self.sparse_dim)
+ if not self.randomize_sparse_matrix:
+ self.sparse_matrix = sparse_matrix
+ else:
+ sparse_matrix = self.sparse_matrix
+
+ mean_hidden_states = data.batch["mean_hidden_states"].to(get_device_id()).float()
+
+ # sparse project
+ mean_hidden_states = mean_hidden_states @ sparse_matrix.to(get_device_id())
+
+ # upgrade to float64
+ mean_hidden_states = mean_hidden_states.to(torch.float64)
+
+ seen_uids = set()
+ reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32).to(get_device_id())
+ raw_bonuses_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32).to(get_device_id())
+ for i in range(len(data)):
+ data_item = data[i]
+ uid = data_item.non_tensor_batch["uid"]
+ if uid in seen_uids:
+ continue
+
+ seen_uids.add(uid)
+ mask = data.non_tensor_batch["uid"] == uid
+ filtered_mean_hidden_states = mean_hidden_states[mask]
+
+ prompt_index = data_item.non_tensor_batch["extra_info"]["index"]
+
+ if self.persist_covariance:
+ # first update the mean hidden states mu
+ if prompt_index not in self.mean_hidden_states_mu_dict:
+ self.mean_hidden_states_mu_dict[prompt_index] = filtered_mean_hidden_states.mean(dim=0)
+ self.hidden_mean_counter_dict[prompt_index] = mask.sum()
+ else:
+ total_count = self.hidden_mean_counter_dict[prompt_index] + mask.sum()
+ old_mu = self.mean_hidden_states_mu_dict[prompt_index]
+ new_mu = (
+ old_mu * self.hidden_mean_counter_dict[prompt_index]
+ + filtered_mean_hidden_states.mean(dim=0) * mask.sum()
+ ) / total_count
+ self.mean_hidden_states_mu_dict[prompt_index] = new_mu
+ self.hidden_mean_counter_dict[prompt_index] = total_count
+
+ # NOTE: we don't center here since otherwise the covariance will accumulate stale means
+ final_mean_hidden_states = filtered_mean_hidden_states
+
+ if prompt_index not in self.cov_inv_dict:
+ d = final_mean_hidden_states.shape[-1]
+ self.cov_inv_dict[prompt_index] = (
+ torch.eye(d, dtype=torch.float64).to(get_device_id()) * self.lamb**-1
+ )
+ cov_inv = self.cov_inv_dict[prompt_index]
+ else:
+ centered_mean_hidden_states = filtered_mean_hidden_states - filtered_mean_hidden_states.mean(dim=0)
+ final_mean_hidden_states = centered_mean_hidden_states
+
+ d = final_mean_hidden_states.shape[-1]
+ cov_inv = torch.eye(d, dtype=torch.float64).to(get_device_id()) * self.lamb**-1
+
+ # update inverse covariance matrix with rank-1 updates
+ for hidden_state in final_mean_hidden_states:
+ chosen_samp = hidden_state.unsqueeze(1)
+ middle_part = torch.inverse(1 + chosen_samp.t() @ cov_inv @ chosen_samp)
+ cov_inv = cov_inv - cov_inv @ chosen_samp @ middle_part @ chosen_samp.t() @ cov_inv
+
+ if self.persist_covariance:
+ self.cov_inv_dict[prompt_index] = cov_inv
+
+ raw_bonuses = self._compute_bonuses(final_mean_hidden_states, cov_inv, prompt_index)
+ normalized_bonuses = self._normalize_bonuses(raw_bonuses)
+
+ prompt_ids = data.batch["prompts"][mask]
+ prompt_length = prompt_ids.shape[-1]
+ valid_response_lengths = data.batch["attention_mask"][mask, prompt_length:].sum(-1)
+
+ raw_bonuses_tensor[mask, valid_response_lengths - 1] = raw_bonuses
+ reward_tensor[mask, valid_response_lengths - 1] = normalized_bonuses
+
+ output = DataProto.from_dict(
+ tensors={"rm_scores": reward_tensor}, non_tensors={"raw_bonuses": raw_bonuses_tensor.cpu().numpy()}
+ )
+ return output.to("cpu")