[recipe, algo] feat: Representation-based Exploration (RepExp) #4278
jens321 wants to merge 10 commits into verl-project:main
Conversation
Code Review
This PR introduces the RepExp algorithm. The changes are extensive, adding a new recipe with multiple components. The implementation looks mostly correct, but I've found several issues that need to be addressed. There are critical bugs in the data preprocessing scripts that will cause them to crash. Some scripts use ~ in default file paths, which is not robust. There are brittle path manipulations in the logging utilities that can lead to IndexError. The documentation contains copy-paste errors in example commands. Finally, there's a performance issue in the EllipticalRewardModelWorker where a tensor is repeatedly moved to the GPU.
```python
dataset = datasets.load_dataset(
    "parquet",
)
```
docs/algo/repexp.md (outdated diff)
2. Evaluate the merged model.

```bash
sh recipe/rep_exp/eval.sh TASK /path/to/global_step_X/actor/hf # where X is the global step of the checkpoint with the best pass@1 on dev
```
The example evaluation command is incorrect. TASK is used as a literal string, but it should be the variable $TASK to match the shell command format. This will cause the command to fail if a user copies and pastes it.
```diff
- sh recipe/rep_exp/eval.sh TASK /path/to/global_step_X/actor/hf # where X is the global step of the checkpoint with the best pass@1 on dev
+ sh recipe/rep_exp/eval.sh $TASK /path/to/global_step_X/actor/hf # where X is the global step of the checkpoint with the best pass@1 on dev
```
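To see why this matters, compare what the shell actually passes in each case (the `TASK=math` value below is just an illustration, not a value from the recipe):

```shell
TASK=math
# Without '$', the literal word TASK is handed to eval.sh as its first argument:
echo "sh recipe/rep_exp/eval.sh TASK"
# With '$', the shell substitutes the variable's value before running the command:
echo "sh recipe/rep_exp/eval.sh $TASK"
```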
recipe/rep_exp/README.md (outdated diff)
2. Evaluate the merged model.

```bash
sh recipe/rep_exp/eval.sh TASK /path/to/global_step_X/actor/hf # where X is the global step of the checkpoint with the best pass@1 on dev
```
The example evaluation command is incorrect. TASK is used as a literal string, but it should be the variable $TASK to match the shell command format. This will cause the command to fail if a user copies and pastes it.
```diff
- sh recipe/rep_exp/eval.sh TASK /path/to/global_step_X/actor/hf # where X is the global step of the checkpoint with the best pass@1 on dev
+ sh recipe/rep_exp/eval.sh $TASK /path/to/global_step_X/actor/hf # where X is the global step of the checkpoint with the best pass@1 on dev
```
```python
train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
dev_dataset = dev_dataset.map(function=make_map_fn("test"), with_indices=True)

local_dir = args.local_dir
```
The default value for --local_dir contains a tilde (~), which is not automatically expanded by argparse. This will result in creating a directory named ~ in the current working directory, instead of using the user's home directory. To fix this, you should expand the user's home directory path.
```diff
- local_dir = args.local_dir
+ local_dir = os.path.expanduser(args.local_dir)
```
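A minimal sketch of the failure and the fix (the default path below is hypothetical, not the script's actual default): argparse stores default strings verbatim, so the `~` survives until something expands it.

```python
import argparse
import os

parser = argparse.ArgumentParser()
# Hypothetical default; argparse keeps the string verbatim, tilde included.
parser.add_argument("--local_dir", default="~/data/rep_exp")
args = parser.parse_args([])

# os.makedirs(args.local_dir) here would create a directory literally named '~'
# under the current working directory. Expand the home directory first:
local_dir = os.path.expanduser(args.local_dir)
```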
```python
test_dataset = test_dataset.filter(filter_dev_indices, with_indices=True)

local_dir = args.local_dir
```
The default value for --local_dir contains a tilde (~), which is not automatically expanded. This will lead to incorrect path resolution. You should use os.path.expanduser to correctly resolve the home directory.
```diff
- local_dir = args.local_dir
+ local_dir = os.path.expanduser(args.local_dir)
```
```python
if save_path is not None and save_path != "":
    self.experiment_name = save_path.split("/")[-2]
    self.checkpoint_type = save_path.split("/")[-1]
```
This path splitting logic is not robust. If save_path does not have the expected directory structure (e.g., it's a top-level directory or a file), save_path.split("/")[-2] will raise an IndexError. You should add checks to handle different path structures gracefully.
```diff
- if save_path is not None and save_path != "":
-     self.experiment_name = save_path.split("/")[-2]
-     self.checkpoint_type = save_path.split("/")[-1]
+ if save_path is not None and save_path != "":
+     parts = save_path.rstrip("/").split("/")
+     self.experiment_name = parts[-2] if len(parts) > 1 else f"{task}_untrained"
+     self.checkpoint_type = parts[-1]
```
```python
if model_path.endswith("actor/hf"):
    # Case where the model path is a saved checkpoint
    save_path = model_path.split("/")[-4:-2]
    save_path = "/".join(save_path)
```
This path parsing is brittle and assumes a fixed directory structure. If model_path does not match it, model_path.split("/")[-4:-2] fails silently (slicing never raises), yielding an incorrect or empty save_path that only surfaces later as a bad path. Consider using pathlib or more robust path parsing to avoid this.
```diff
  if model_path.endswith("actor/hf"):
      # Case where the model path is a saved checkpoint
-     save_path = model_path.split("/")[-4:-2]
-     save_path = "/".join(save_path)
+     path_parts = model_path.split("/")
+     if len(path_parts) >= 4:
+         save_path = "/".join(path_parts[-4:-2])
+     else:
+         save_path = ""
```
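A small sketch of the guarded parsing with `pathlib` (the checkpoint layout shown is hypothetical, not taken from the recipe):

```python
from pathlib import Path

model_path = "ckpts/my_exp/seed_0/global_step_100/actor/hf"  # hypothetical layout
parts = Path(model_path).parts
save_path = "/".join(parts[-4:-2]) if len(parts) >= 4 else ""
# Picks out the experiment/run components two levels above 'actor/hf'.
assert save_path == "seed_0/global_step_100"

# A too-shallow path now yields an explicit empty string instead of garbage:
shallow = Path("actor/hf").parts
assert ("/".join(shallow[-4:-2]) if len(shallow) >= 4 else "") == ""
```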
```python
# Parse task from config
train_file = config["data"]["train_files"][0]
task = train_file.split("/")[-2]
```
Hardcoding the index [-2] to extract the task name makes the code fragile. If the train_file path structure changes, this will either fail with an IndexError or extract the wrong directory name. Using pathlib would make this more robust. You would need to add from pathlib import Path at the top of the file.
```diff
- task = train_file.split("/")[-2]
+ task = Path(train_file).parent.name
```
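A quick sketch of why `Path.parent.name` is the safer extraction (the file paths here are made up for illustration):

```python
from pathlib import Path

train_file = "data/rep_exp/math/train.parquet"  # hypothetical train_files entry
task = Path(train_file).parent.name
assert task == "math"

# split("/")[-2] raises IndexError on a bare filename; pathlib degrades gracefully:
assert Path("train.parquet").parent.name == ""
```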
```python
if self.sparse_matrix is None:
    d = data.batch["mean_hidden_states"].shape[-1]
    sparse_matrix = self._construct_sparse_matrix(torch.randn(1, d), self.sparse_dim)
    if not self.randomize_sparse_matrix:
        self.sparse_matrix = sparse_matrix
else:
    sparse_matrix = self.sparse_matrix

mean_hidden_states = data.batch["mean_hidden_states"].to(get_device_id()).float()

# sparse project
mean_hidden_states = mean_hidden_states @ sparse_matrix.to(get_device_id())
```
There is a performance issue here. The sparse_matrix is moved to the GPU on every call to compute_rm_score inside the mean_hidden_states @ sparse_matrix.to(get_device_id()) operation. When randomize_sparse_matrix is False, the matrix is cached but remains on the CPU. To improve efficiency, the matrix should be moved to the GPU only once when it is created.
```diff
  if self.sparse_matrix is None:
      d = data.batch["mean_hidden_states"].shape[-1]
-     sparse_matrix = self._construct_sparse_matrix(torch.randn(1, d), self.sparse_dim)
+     sparse_matrix = self._construct_sparse_matrix(torch.randn(1, d), self.sparse_dim).to(get_device_id())
      if not self.randomize_sparse_matrix:
          self.sparse_matrix = sparse_matrix
  else:
      sparse_matrix = self.sparse_matrix

  mean_hidden_states = data.batch["mean_hidden_states"].to(get_device_id()).float()
  # sparse project
- mean_hidden_states = mean_hidden_states @ sparse_matrix.to(get_device_id())
+ mean_hidden_states = mean_hidden_states @ sparse_matrix
```
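`_construct_sparse_matrix` isn't shown in this diff, so its exact construction is an assumption, but Achlioptas-style sparse random projections are a standard way to cheaply compress hidden states before computing elliptical scores. A NumPy sketch of that idea, with the construction done once rather than per call; all names here are hypothetical:

```python
import numpy as np

def sparse_random_projection(d, k, s=3, seed=0):
    """Achlioptas-style projection: entries are +/-sqrt(s/k) with prob 1/(2s) each, else 0."""
    rng = np.random.default_rng(seed)
    signs = rng.choice([-1.0, 0.0, 1.0], size=(d, k), p=[1 / (2 * s), 1 - 1 / s, 1 / (2 * s)])
    return signs * np.sqrt(s / k)

# Build the projection once and cache it; per call, only the matmul remains.
P = sparse_random_projection(d=1024, k=128)
hidden = np.random.default_rng(1).standard_normal((4, 1024))  # batch of mean hidden states
projected = hidden @ P
assert projected.shape == (4, 128)
```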
@jens321 Hi, thanks for your contribution. We're moving recipes to a separate project, verl-project/verl-recipe; could you submit a PR to that project instead? #4283

@wuxibin89 Thanks for the quick reply! Sounds good, will go ahead and close this one then.
What does this PR do?
Add support for the training and evaluation of the RepExp method introduced in section 5 of the paper Representation-Based Exploration for Language Models: From Test-Time to Post-Training.
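For intuition on the method: elliptical bonuses score each sample by how novel its (projected) representation is relative to the batch, typically b_i = phi_i^T (lam*I + Phi^T Phi)^(-1) phi_i. Whether the paper uses exactly this form is an assumption here; the sketch below is illustrative only, with made-up shapes.

```python
import numpy as np

def elliptical_bonus(phi, lam=1.0):
    """phi: (n, k) matrix of per-sample features.
    Returns b_i = phi_i^T (lam*I + Phi^T Phi)^{-1} phi_i, a common exploration bonus."""
    k = phi.shape[1]
    cov = lam * np.eye(k) + phi.T @ phi
    inv = np.linalg.inv(cov)
    return np.einsum("ik,kl,il->i", phi, inv, phi)

phi = np.random.default_rng(0).standard_normal((16, 8))
bonus = elliptical_bonus(phi)
assert bonus.shape == (16,) and (bonus > 0).all() and (bonus < 1).all()

# Duplicating a sample makes its direction less novel, so its bonus shrinks:
bonus_dup = elliptical_bonus(np.vstack([phi, phi[:1]]))
assert bonus_dup[0] < bonus[0]
```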
Checklist Before Starting

- Title format: `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - Prepend `[BREAKING]` to the title for breaking changes, e.g. `[BREAKING][fsdp, megatron] feat: dynamic batching`

Test
Please refer to Figure 2 in https://arxiv.org/abs/2510.11686 for evaluation results.
API and Usage Example
The general format for training with our method is as follows
where
`$TASK` is the task name, `$SPARSE_DIM` is the sparse dimension, `$BETA` is the beta parameter, and `$SEED` is the seed. For example, for training on MATH with the original parameters from the paper, one would do
Once training is done, one can evaluate the model on the test set in two steps.
This is necessary because the model checkpoint is saved in multiple shards (depending on the number of GPUs), and we need to merge them into a single checkpoint.
```bash
sh recipe/rep_exp/model_merge.sh /path/to/global_step_X/actor # where X is the global step of the checkpoint with the best pass@1 on dev
```

Design & Code Changes
All changes are contained in the `recipe/rep_exp` folder and summarized below:

- `rep_exp/main_rep_exp.py`: copy of `verl/trainer/main_ppo.py` but imports `EllipticalRewardModelWorker` instead of the standard `RewardModelWorker`
- `rep_exp/rep_exp_trainer.py`: copy of `verl/trainer/ppo/ray_trainer.py` but adds computing of hidden states and elliptical reward scores. In addition, a few lines of code help keep track of the best pass@1 seen so far on validation.
- `rep_exp/metric_utils.py`: adds and extends utility functions that provide additional metrics for our method that could be helpful for debugging.
- `rep_exp/workers/elliptical_reward_model_worker.py`: adds the `EllipticalRewardModelWorker` class, which provides functionality for (1) computing hidden states and (2) computing elliptical reward scores based on the hidden states.
- `rep_exp/reward_manager`: adds the `EllipticalRewardManager` class that handles combining the external reward (from a verifier) and the elliptical reward.
- `rep_exp/utils/tracking.py` and `rep_exp/utils/aggregate_logger.py`: adds the `JsonEvalLogger`, which can be used to log final evaluation results to a JSON file.
- `rep_exp/data_preprocess.py`: contains a script for each of MATH, GSM8K, and AIME 2024 that provides the logic for getting the dataset splits (train, dev, test).
- `rep_exp/reward_score/__init__.py`: copy of `verl/utils/reward_score/__init__.py` but uses `math_verify` for dapo training
- `rep_exp/plot_pass_at_k.py`: sample script with basic plotting code for a pass@k curve based on the logged JSON files saved after running the evaluation script.
- `rep_exp/config/rep_exp_trainer.yaml`: overwrites and adds any RepExp-specific configuration parameters.
- `rep_exp/train_elliptical.sh`: training script
- `rep_exp/model_merge.sh`: script to merge model checkpoints
- `rep_exp/eval.sh`: evaluation script

Checklist Before Submitting
Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- Run `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- Request CI through the `ci-request` channel in the `verl` Slack workspace. (If not accessible, please try the Feishu group (飞书群).)