diff --git a/.github/workflows/e2e_gsm8k_dapo.yml b/.github/workflows/e2e_gsm8k_dapo.yml new file mode 100644 index 00000000000..633d0975155 --- /dev/null +++ b/.github/workflows/e2e_gsm8k_dapo.yml @@ -0,0 +1,54 @@ +name: e2e_gsm8k_dapo + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + - v0.2.x + paths: + - "**/*.py" + - .github/workflows/e2e_gsm8k_dapo.yml + pull_request: + branches: + - main + - v0.2.x + paths: + - "**/*.py" + - "verl/trainer/config/*.yaml" + - .github/workflows/e2e_gsm8k_dapo.yml + - "tests/e2e/*.sh" + +# Declare permissions just read content. +permissions: + contents: read + +jobs: + e2e_gsm8k_dapo: + runs-on: [self-hosted, l20-1] + timeout-minutes: 40 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1" + HF_HUB_ENABLE_HF_TRANSFER: 1 + container: + image: verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3 + options: --gpus all --shm-size=10g + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install hf_transfer + pip3 install -e .[test,gpu] + - name: Prepare gsm8k dataset + run: | + ray stop --force + python3 examples/data_preprocess/gsm8k.py + - name: Running gsm8k e2e with dapo alg + run: | + ray stop --force + bash tests/e2e/run_qwen_gsm8k_dapo.sh \ No newline at end of file diff --git a/.github/workflows/e2e_gsm8k_prime.yml b/.github/workflows/e2e_gsm8k_prime.yml index e14097cb84f..15e373c8ed2 100644 --- a/.github/workflows/e2e_gsm8k_prime.yml +++ b/.github/workflows/e2e_gsm8k_prime.yml @@ -25,7 +25,7 @@ permissions: contents: read jobs: - e2e_gsm8k: + e2e_gsm8k_prime: runs-on: [self-hosted, l20-1] timeout-minutes: 40 # Increase this timeout value as needed env: diff --git a/examples/split_placement/main_ppo_split.py 
b/examples/split_placement/main_ppo_split.py index 311b35a8ec9..85347be0929 100644 --- a/examples/split_placement/main_ppo_split.py +++ b/examples/split_placement/main_ppo_split.py @@ -36,7 +36,7 @@ def __init__(self, tokenizer, num_examine) -> None: self.tokenizer = tokenizer self.num_examine = num_examine # the number of batches of decoded responses to print to the console - def __call__(self, data: DataProto): + def __call__(self, data: DataProto, return_dict: bool = False): """We will expand this function gradually based on the available datasets""" # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn @@ -81,7 +81,10 @@ def __call__(self, data: DataProto): already_print_data_sources[data_source] += 1 print(sequences_str) - return reward_tensor + if return_dict: + return {"reward_tensor": reward_tensor} + else: + return reward_tensor import ray diff --git a/recipe/dapo/README.md b/recipe/dapo/README.md new file mode 100644 index 00000000000..5c5fdbb6412 --- /dev/null +++ b/recipe/dapo/README.md @@ -0,0 +1,163 @@ +# DAPO Open-Source Implementation + +> Open-Source Algorithm Implementation & Experiment Running: [Yuxuan Tong](https://tongyx361.github.io/), [Guangming Sheng](https://hk.linkedin.com/in/guangming-sheng-b50640211) + +> [!IMPORTANT] +> **🔥 News!!!** +> - [2025/03] We published the training record of [an early version of DAPO (w/o Token-level PG Loss & Dynamic Sampling)](./run_dapo_early_qwen2.5_32b.sh), achieving 44% on AIME 2024, in [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl). 
+ +🏠 [Homepage](https://dapo-sia.github.io/) | 📝 [Paper](https://dapo-sia.github.io/static/pdf/dapo_paper.pdf) | 🤗 [Datasets&Models@HF](https://huggingface.co/collections/BytedTsinghua-SIA/dapo-67d7f1517ee33c8aed059da0) | 🐱 [Code@GitHub](https://github.com/volcengine/verl/tree/gm-tyx/puffin/main/recipe/dapo) | 🐱 [Repo@GitHub](https://github.com/BytedTsinghua-SIA/DAPO) + +> We propose the **D**ecoupled Clip and Dynamic s**A**mpling **P**olicy **O**ptimization (DAPO) algorithm. By making our work publicly available, we provide the broader research community and society with practical access to scalable reinforcement learning, enabling all to benefit from these advancements. Our system is based on the awesome [verl](https://github.com/volcengine/verl) framework. Thanks for their great work! Applying DAPO training to Qwen2.5-32B base model proves to outperform the previous state-of-the-art DeepSeek-R1-Zero-Qwen-32B on AIME 2024, achieving **50%** accuracy with **50%** less training steps. +> +> ![dapo-main-result](https://dapo-sia.github.io/static/images/score.png) + +## Quickstart + +1. Prepare the datasets **on the Ray cluster**: + +```bash +bash prepare_dapo_data.sh # This downloads the datasets to ${HOME}/verl/data by default +``` + +2. Submit the job to the Ray cluster **from any machine**: + +```bash +cd verl # Repo root +export RAY_ADDRESS="http://${RAY_IP:-localhost}:8265" # The Ray cluster address to connect to +export WORKING_DIR="${PWD}" # The local directory to package to the Ray cluster +# Set the runtime environment like env vars and pip packages for the Ray cluster in yaml +export RUNTIME_ENV="./verl/trainer/runtime_env.yaml" +bash recipe/dapo/run_dapo_qwen2.5_32b.sh +``` + +## Reproduction Runs + +| Setup | AIME 2024 Acc. 
| Training Script | Training Record | +|-------|----------------------|-----------------|-----------------| +| DAPO w/o Token-level PG Loss & Dynamic Sampling | 44% | [run_dapo_early_qwen2.5_32b.sh](./run_dapo_early_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl) | +| DAPO | 50% | [run_dapo_qwen2.5_32b.sh](./run_dapo_qwen2.5_32b.sh) | W&B (Coming soon) | + +## Configuration + +> [!NOTE] +> Most experiments in the paper, including the best-performant one, are run without Overlong Filtering because it's somehow overlapping with Overlong Reward Shaping in terms of properly learning from the longest outputs. So we don't implement it here. + +### Separated Clip Epsilons (-> Clip-Higher) + +An example configuration: + +```yaml +actor_rollout_ref: + actor: + clip_ratio_low: 0.2 + clip_ratio_high: 0.28 +``` + +`clip_ratio_low` and `clip_ratio_high` specify the $\varepsilon_{\text {low }}$ and $\varepsilon_{\text {high }}$ in the DAPO objective. + +Core relevant code: + +```python +pg_losses1 = -advantages * ratio +pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) +pg_losses = torch.maximum(pg_losses1, pg_losses2) +``` + +### Dynamic Sampling (with Group Filtering) + +An example configuration: + +```yaml +data: + gen_batch_size: 1536 + train_batch_size: 512 +algorithm: + filter_groups: + enable: True + metric: acc # score / seq_reward / seq_final_reward / ... + max_num_gen_batches: 10 # Non-positive values mean no upper limit +``` + +Setting `filter_groups.enable` to `True` will filter out groups whose outputs' `metric` are all the same, e.g., for `acc`, groups whose outputs' accuracies are all 1 or 0. + +The trainer will repeat sampling with `gen_batch_size` until there are enough qualified groups for `train_batch_size` or reaching the upper limit specified by `max_num_gen_batches`. 
+ +Core relevant code: + +```python +prompt_bsz = self.config.data.train_batch_size +if num_prompt_in_batch < prompt_bsz: + print(f'{num_prompt_in_batch=} < {prompt_bsz=}') + num_gen_batches += 1 + max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches + if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches: + print(f'{num_gen_batches=} < {max_num_gen_batches=}. Keep generating...') + continue + else: + raise ValueError( + f'{num_gen_batches=} >= {max_num_gen_batches=}. Generated too many. Please check your data.' + ) +else: + # Align the batch + traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n + batch = batch[:traj_bsz] +``` + +### Flexible Loss Aggregation Mode (-> Token-level Policy Gradient Loss) + +An example configuration: + +```yaml +actor_rollout_ref: + actor: + loss_agg_mode: "token-mean" # / "seq-mean-token-sum" / "seq-mean-token-mean" + # NOTE: "token-mean" is the default behavior +``` + +Setting `loss_agg_mode` to `token-mean` will mean the (policy gradient) loss across all the tokens in all the sequences in a mini-batch. + +Core relevant code: + +```python +if loss_agg_mode == "token-mean": + pg_loss = verl_F.masked_mean(pg_losses, eos_mask) +elif loss_agg_mode == "seq-mean-token-sum": + pg_loss = torch.sum(pg_losses * eos_mask, dim=-1) / torch.sum(eos_mask, dim=-1) + pg_loss = torch.mean(pg_loss) +elif loss_agg_mode == "seq-mean-token-mean": + pg_loss = torch.sum(pg_losses * eos_mask, dim=-1) / torch.sum(eos_mask, dim=-1) + pg_loss = torch.mean(pg_loss) +else: + raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}") +``` + +### Overlong Reward Shaping + +An example configuration: + +```yaml +data: + max_response_length: 20480 # 16384 + 4096 +reward_model: + overlong_buffer: + enable: True + len: 4096 + penalty_factor: 1.0 +``` + +Setting `overlong_buffer.enable` to `True` will penalize the outputs whose lengths are overlong but still within the hard context limit. 
+ +Specifically, the penalty increases linearly from `0` to `overlong_buffer.penalty_factor` when the length of the output exceeds the `max_response_length` by `0` to `overlong_buffer.len` tokens. + +Core relevant code: + +```python +if self.overlong_buffer_cfg.enable: + overlong_buffer_len = self.overlong_buffer_cfg.len + expected_len = self.max_resp_len - overlong_buffer_len + exceed_len = valid_response_length - expected_len + overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor + overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0) + reward += overlong_reward +``` diff --git a/recipe/dapo/prepare_dapo_data.sh b/recipe/dapo/prepare_dapo_data.sh new file mode 100644 index 00000000000..c6c60bf2020 --- /dev/null +++ b/recipe/dapo/prepare_dapo_data.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash +set -uxo pipefail + +export VERL_HOME=${VERL_HOME:-"${HOME}/verl"} +export TRAIN_FILE=${TRAIN_FILE:-"${VERL_HOME}/data/dapo-math-17k.parquet"} +export TEST_FILE=${TEST_FILE:-"${VERL_HOME}/data/aime-2024.parquet"} + +mkdir -p "${VERL_HOME}/data" + +wget -O "${TRAIN_FILE}" "https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k/resolve/main/data/dapo-math-17k.parquet?download=true" + +wget -O "${TEST_FILE}" "https://huggingface.co/datasets/BytedTsinghua-SIA/AIME-2024/resolve/main/data/aime-2024.parquet?download=true" \ No newline at end of file diff --git a/recipe/dapo/run_dapo_early_qwen2.5_32b.sh b/recipe/dapo/run_dapo_early_qwen2.5_32b.sh new file mode 100644 index 00000000000..3b1ebfce6ee --- /dev/null +++ b/recipe/dapo/run_dapo_early_qwen2.5_32b.sh @@ -0,0 +1,131 @@ +#!/usr/bin/env bash +set -euxo pipefail + +project_name='DAPO' +exp_name='DAPO-Early-Qwen2.5-32B' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 20)) +enable_overlong_buffer=True 
+overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +# An early version for DAPO +loss_agg_mode="seq-mean-token-sum" + +enable_filter_groups=False +gen_prompt_bsz=512 # NOTE: no filtering here +train_prompt_bsz=512 +train_prompt_mini_bsz=32 +n_resp_per_prompt=16 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-16} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + + +# Performance Related Parameter +sp_size=8 +use_dynamic_bsz=True +actor_ppo_max_token_len=$((max_prompt_length + max_response_length)) +infer_ppo_max_token_len=$((max_prompt_length + max_response_length)) +offload=True +gen_tp=4 + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --working-dir "${WORKING_DIR}" \ + -- python3 -m recipe.dapo.src.main_dapo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + 
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + 
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=5 \ + trainer.save_freq=5 \ + trainer.total_epochs=1 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ No newline at end of file diff --git a/recipe/dapo/run_dapo_qwen2.5_32b.sh b/recipe/dapo/run_dapo_qwen2.5_32b.sh new file mode 100644 index 00000000000..3943a2e89b8 --- /dev/null +++ b/recipe/dapo/run_dapo_qwen2.5_32b.sh @@ -0,0 +1,132 @@ +#!/usr/bin/env bash +set -euxo pipefail + +project_name='DAPO' +exp_name='DAPO-Qwen2.5-32B' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 20)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=512 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} 
+WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-16} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Performance Related Parameter +sp_size=8 +use_dynamic_bsz=True +actor_ppo_max_token_len=$((max_prompt_length + max_response_length)) +infer_ppo_max_token_len=$((max_prompt_length + max_response_length)) +offload=True +gen_tp=4 + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --working-dir "${WORKING_DIR}" \ + -- python3 -m recipe.dapo.src.main_dapo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + actor_rollout_ref.model.use_remove_padding=True \ + 
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + 
actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=5 \ + trainer.save_freq=5 \ + trainer.total_epochs=1 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ No newline at end of file diff --git a/recipe/dapo/src/config/dapo_trainer.yaml b/recipe/dapo/src/config/dapo_trainer.yaml new file mode 100644 index 00000000000..bcafe90f691 --- /dev/null +++ b/recipe/dapo/src/config/dapo_trainer.yaml @@ -0,0 +1,223 @@ +data: + tokenizer: null + train_files: ~/data/rlhf/gsm8k/train.parquet + val_files: ~/data/rlhf/gsm8k/test.parquet + prompt_key: prompt + reward_fn_key: data_source + max_prompt_length: 512 + max_response_length: 512 + gen_batch_size: ${data.train_batch_size} + train_batch_size: 1024 + val_batch_size: null # DEPRECATED: Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves + return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs + return_raw_chat: False + shuffle: True + filter_overlong_prompts: False # for large-scale dataset, filtering overlong prompts could be timeconsuming. 
You should disable this and set `truncation='left' + truncation: error + image_key: images + +actor_rollout_ref: + hybrid_engine: True + model: + path: ~/models/deepseek-llm-7b-chat + external_lib: null + override_config: { } + enable_gradient_checkpointing: True + use_remove_padding: False + actor: + strategy: fsdp # This is for backward-compatibility + ppo_mini_batch_size: 256 + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: null + use_dynamic_bsz: False + ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} + grad_clip: 1.0 + # pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) + clip_ratio: 0.2 # default value if clip_ratio_low and clip_ratio_high are not specified + clip_ratio_low: 0.2 + clip_ratio_high: 0.2 + clip_ratio_c: 3.0 # lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729 + loss_agg_mode: "token-mean" # / "seq-mean-token-sum" / "seq-mean-token-mean" + # NOTE: "token-mean" is the default behavior + entropy_coeff: 0.001 + use_kl_loss: False # True for GRPO + use_torch_compile: True # False to disable torch compile + kl_loss_coef: 0.001 # for grpo + kl_loss_type: low_var_kl # for grpo + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 1 # sp size + checkpoint: + contents: ['model', 'hf_model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + optim: + lr: 1e-6 + lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio. + lr_warmup_steps_ratio: 0. 
# the total steps will be injected during runtime + min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + weight_decay: 0.01 + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + optimizer_offload: False + fsdp_size: -1 + ref: + fsdp_config: + param_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + rollout: + name: vllm + temperature: 1.0 + top_k: -1 # 0 for hf rollout, -1 for vllm rollout + top_p: 1 + use_fire_sampling: False # https://arxiv.org/abs/2410.21236 + prompt_length: ${data.max_prompt_length} # not use for opensource + response_length: ${data.max_response_length} + # for vllm rollout + dtype: bfloat16 # should align with FSDP + gpu_memory_utilization: 0.5 + ignore_eos: False + enforce_eager: True + free_cache_engine: True + load_format: dummy_dtensor + tensor_model_parallel_size: 2 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + disable_log_stats: True + enable_chunked_prefill: True # may get higher throughput when set to True. When activated, Please increase max_num_batched_tokens or decrease max_model_len. 
+ # for hf rollout + do_sample: True + # number of responses (i.e. num sample times) + n: 1 # > 1 for grpo + val_kwargs: + # sampling parameters for validation + top_k: -1 # 0 for hf rollout, -1 for vllm rollout + top_p: 1.0 + temperature: 0 + n: 1 + do_sample: False # default eager for validation + +critic: + rollout_n: ${actor_rollout_ref.rollout.n} + strategy: fsdp + optim: + lr: 1e-5 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + weight_decay: 0.01 + model: + path: ~/models/deepseek-llm-7b-chat + tokenizer_path: ${actor_rollout_ref.model.path} + override_config: { } + external_lib: ${actor_rollout_ref.model.external_lib} + enable_gradient_checkpointing: True + use_remove_padding: False + fsdp_config: + param_offload: False + optimizer_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + fsdp_size: -1 + ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: null + forward_micro_batch_size: ${critic.ppo_micro_batch_size} + forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} + use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2 + forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: 1 # sp size + ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} + shuffle: ${actor_rollout_ref.actor.shuffle} + grad_clip: 1.0 + cliprange_value: 0.5 + checkpoint: + contents: ['model', 'hf_model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + +reward_model: + enable: 
False + strategy: fsdp + model: + input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + external_lib: ${actor_rollout_ref.model.external_lib} + use_remove_padding: False + fsdp_config: + wrap_policy: + min_num_params: 0 + param_offload: False + fsdp_size: -1 + micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu + micro_batch_size_per_gpu: null # set a number + max_length: null + ulysses_sequence_parallel_size: 1 # sp size + use_dynamic_bsz: ${critic.use_dynamic_bsz} + forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} + reward_manager: naive + overlong_buffer: + enable: False # We try to avoid forgetting to set enable + len: 0 + penalty_factor: 0.0 + log: False + +custom_reward_function: + path: null + name: compute_score + +algorithm: + gamma: 1.0 + lam: 1.0 + adv_estimator: gae + use_kl_in_reward: False + kl_penalty: kl # how to estimate kl divergence + kl_ctrl: + type: fixed + kl_coef: 0.001 + horizon: 10000 + target_kl: 0.1 + filter_groups: + enable: False # We try to avoid forgetting to set enable + metric: null # acc / score / seq_reward / seq_final_reward / ... + max_num_gen_batches: 0 # Non-positive values mean no upper limit + +trainer: + balance_batch: True + total_epochs: 30 + total_training_steps: null + project_name: verl_examples + experiment_name: gsm8k + logger: [ 'console', 'wandb' ] + log_val_generations: 0 + nnodes: 1 + n_gpus_per_node: 8 + save_freq: -1 + # auto: find the last ckpt to resume. 
If can't find, start from scratch + resume_mode: auto # or disable or resume_path if resume_from_path is set + resume_from_path: null + val_before_train: True + test_freq: -1 + critic_warmup: 0 + default_hdfs_dir: null + remove_previous_ckpt_in_save: False + del_local_ckpt_after_load: False + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + max_actor_ckpt_to_keep: null + max_critic_ckpt_to_keep: null diff --git a/recipe/dapo/src/dapo_ray_trainer.py b/recipe/dapo/src/dapo_ray_trainer.py new file mode 100644 index 00000000000..7fa161af336 --- /dev/null +++ b/recipe/dapo/src/dapo_ray_trainer.py @@ -0,0 +1,302 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +FSDP PPO Trainer with Ray-based single controller. +This trainer supports model-agnostic model initialization with huggingface +""" + +import uuid +from pprint import pprint +from copy import deepcopy +from collections import defaultdict +from tqdm import tqdm +import numpy as np +import torch + +from verl import DataProto +from verl.trainer.ppo.ray_trainer import RayPPOTrainer, _timer, apply_kl_penalty, compute_advantage, AdvantageEstimator +from verl.trainer.ppo.metric_utils import (compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, + reduce_metrics) + + +class RayDAPOTrainer(RayPPOTrainer): + """ + Note that this trainer runs on the driver process on a single CPU/GPU node. 
+ """ + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from verl.utils.tracking import Tracking + from omegaconf import OmegaConf + + logger = Tracking(project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True)) + + self.global_steps = 0 + + # load checkpoint before doing anything + self._load_checkpoint() + + # perform validation before training + # currently, we only support validation using the reward_function. + if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True): + val_metrics = self._validate() + pprint(f'Initial validation metrics: {val_metrics}') + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get('val_only', False): + return + + # add tqdm + progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + + # we start from step 1 + self.global_steps += 1 + last_val_metrics = None + + timing_raw = defaultdict(float) + batch = None + num_prompt_in_batch = 0 + num_gen_batches = 0 + for epoch in range(self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + metrics = {} + + new_batch: DataProto = DataProto.from_single_dict(batch_dict) + num_gen_batches += 1 + # pop those keys for generation + if 'multi_modal_inputs' in new_batch.non_tensor_batch.keys(): + gen_batch = new_batch.pop( + batch_keys=['input_ids', 'attention_mask', 'position_ids'], + non_tensor_batch_keys=['raw_prompt_ids', 'multi_modal_data', 'multi_modal_inputs'], + ) + else: + gen_batch = new_batch.pop( + batch_keys=['input_ids', 'attention_mask', 'position_ids'], + non_tensor_batch_keys=['raw_prompt_ids'], + ) + + 
is_last_step = self.global_steps >= self.total_training_steps + + with _timer('step', timing_raw): + # generate a batch + with _timer('gen', timing_raw): + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + + if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: + with _timer('gen_max', timing_raw): + gen_baseline_batch = deepcopy(gen_batch) + gen_baseline_batch.meta_info['do_sample'] = False + gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + + new_batch = new_batch.union(gen_baseline_output) + reward_baseline_tensor = self.reward_fn(new_batch) + reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) + + new_batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) + + new_batch.batch['reward_baselines'] = reward_baseline_tensor + + del gen_baseline_batch, gen_baseline_output + + new_batch.non_tensor_batch['uid'] = np.array( + [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object) + # repeat to align with repeated responses in rollout + new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + new_batch = new_batch.union(gen_batch_output) + + with _timer('reward', timing_raw): + # compute scores. Support both model and function-based. + # We first compute the scores using reward model. Then, we call reward_fn to combine + # the results from reward model and rule-based results. 
+ if self.use_rm: + # we first compute reward model score + reward_tensor = self.rm_wg.compute_rm_score(new_batch) + new_batch = new_batch.union(reward_tensor) + + # we combine with rule-based rm + reward_extra_infos_dict: dict[str, list] + try: + reward_result = self.reward_fn(new_batch, return_dict=True) + reward_tensor = reward_result['reward_tensor'] + reward_extra_infos_dict = reward_result['reward_extra_info'] + except Exception as e: + print(f'Error in reward_fn: {e}') + reward_tensor = self.reward_fn(new_batch) + reward_extra_infos_dict = {} + + new_batch.batch['token_level_scores'] = reward_tensor + + print(f'{list(reward_extra_infos_dict.keys())=}') + if reward_extra_infos_dict: + new_batch.non_tensor_batch.update({ + k: np.array(v) for k, v in reward_extra_infos_dict.items() + }) + + # compute rewards. apply_kl_penalty if available + if self.config.algorithm.use_kl_in_reward: + new_batch, kl_metrics = apply_kl_penalty(new_batch, + kl_ctrl=self.kl_ctrl_in_reward, + kl_penalty=self.config.algorithm.kl_penalty) + metrics.update( + kl_metrics) # TODO: This will be cleared if we use multiple genenration batches + else: + new_batch.batch['token_level_rewards'] = new_batch.batch['token_level_scores'] + + if not self.config.algorithm.filter_groups.enable: + batch = new_batch + else: # NOTE: When prompts after filtering is less than train batch size, we skip to the next generation batch + metric_name = self.config.algorithm.filter_groups.metric + if metric_name == "seq_final_reward": + # Turn to numpy for easier filtering + new_batch.non_tensor_batch["seq_final_reward"] = new_batch.batch['token_level_scores'].sum( + dim=-1).numpy() + elif metric_name == "seq_reward": + new_batch.non_tensor_batch["seq_reward"] = new_batch.batch['token_level_scores'].sum( + dim=-1).numpy() + + # Collect the sequence reward for each trajectory + prompt_uid2metric_vals = defaultdict(list) + for uid, metric_val in zip(new_batch.non_tensor_batch['uid'], + 
new_batch.non_tensor_batch[metric_name]): + prompt_uid2metric_vals[uid].append(metric_val) + + prompt_uid2metric_std = {} + for prompt_uid, metric_vals in prompt_uid2metric_vals.items(): + prompt_uid2metric_std[prompt_uid] = np.std(metric_vals) + + kept_prompt_uids = [ + uid for uid, std in prompt_uid2metric_std.items() + if std > 0 or len(prompt_uid2metric_vals[uid]) == 1 + ] + num_prompt_in_batch += len(kept_prompt_uids) + + kept_traj_idxs = [] + for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch['uid']): + if traj_from_prompt_uid in kept_prompt_uids: + kept_traj_idxs.append(idx) + + new_batch = new_batch[kept_traj_idxs] + if batch is None: + batch = new_batch + else: + batch = DataProto.concat([batch, new_batch]) + + prompt_bsz = self.config.data.train_batch_size + if num_prompt_in_batch < prompt_bsz: + print(f'{num_prompt_in_batch=} < {prompt_bsz=}') + max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches + if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches: + print(f'{num_gen_batches=}. Keep generating...') + continue + else: + raise ValueError( + f'{num_gen_batches=} >= {max_num_gen_batches=}. Generated too many. Please check your data.' + ) + else: + # Align the batch + traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n + batch = batch[:traj_bsz] + + # balance the number of valid tokens on each dp rank. + # Note that this breaks the order of data inside the batch. 
+ # Please take care when you implement group based adv computation such as GRPO and rloo + if self.config.trainer.balance_batch: + self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist() + + # recompute old_log_probs + with _timer('old_log_prob', timing_raw): + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + batch = batch.union(old_log_prob) + + if self.use_reference_policy: + # compute reference log_prob + with _timer('ref', timing_raw): + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + # compute values + if self.use_critic: + with _timer('values', timing_raw): + values = self.critic_wg.compute_values(batch) + batch = batch.union(values) + + with _timer('adv', timing_raw): + # compute advantages, executed on the driver process + batch = compute_advantage(batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n) + + # update critic + if self.use_critic: + with _timer('update_critic', timing_raw): + critic_output = self.critic_wg.update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics']) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with _timer('update_actor', timing_raw): + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics']) + metrics.update(actor_output_metrics) + + # validate + if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \ + (is_last_step or self.global_steps % self.config.trainer.test_freq == 0): + with _timer('testing', timing_raw): + val_metrics: dict = self._validate() + if is_last_step: + 
last_val_metrics = val_metrics + metrics.update(val_metrics) + + if self.config.trainer.save_freq > 0 and (is_last_step or + self.global_steps % self.config.trainer.save_freq == 0): + with _timer('save_checkpoint', timing_raw): + self._save_checkpoint() + + # collect metrics + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + # TODO: implement actual tflpo and theoretical tflpo + n_gpus = self.resource_pool_manager.get_n_gpus() + metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + timing_raw = defaultdict(float) # clear timing + + metrics["train/num_gen_batches"] = num_gen_batches + batch = None + num_prompt_in_batch = 0 + num_gen_batches = 0 + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + if is_last_step: + pprint(f'Final validation metrics: {last_val_metrics}') + progress_bar.close() + return + + progress_bar.update(1) + self.global_steps += 1 diff --git a/recipe/dapo/src/main_dapo.py b/recipe/dapo/src/main_dapo.py new file mode 100644 index 00000000000..df96d4bbe14 --- /dev/null +++ b/recipe/dapo/src/main_dapo.py @@ -0,0 +1,193 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. 
+""" +from .dapo_ray_trainer import RayDAPOTrainer + +import os +import ray +import hydra + + +def get_custom_reward_fn(config): + import importlib.util, os + + reward_fn_config = config.get("custom_reward_function") or {} + file_path = reward_fn_config.get("path") + if not file_path: + return None + + if not os.path.exists(file_path): + raise FileNotFoundError(f"Reward function file '{file_path}' not found.") + + spec = importlib.util.spec_from_file_location("custom_module", file_path) + module = importlib.util.module_from_spec(spec) + try: + spec.loader.exec_module(module) + except Exception as e: + raise RuntimeError(f"Error loading module from '{file_path}': {e}") + + function_name = reward_fn_config.get("name") + + if not hasattr(module, function_name): + raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.") + + print(f"using customized reward function '{function_name}' from '{file_path}'") + + return getattr(module, function_name) + + +@hydra.main(config_path='config', config_name='dapo_trainer', version_base=None) +def main(config): + run_ppo(config) + + +def run_ppo(config) -> None: + # TODO(linjunrong.ocss884): this ENV is left for resolving SGLang conflict with ray devices + # isolation, will solve in the future + os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get('CUDA_VISIBLE_DEVICES', '') + if not ray.is_initialized(): + # this is for local ray cluster + ray.init(runtime_env={ + 'env_vars': { + 'TOKENIZERS_PARALLELISM': 'true', + 'NCCL_DEBUG': 'WARN', + 'VLLM_LOGGING_LEVEL': 'WARN' + } + }) + + runner = TaskRunner.remote() + ray.get(runner.run.remote(config)) + + +@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head +class TaskRunner: + + def run(self, config): + from verl.utils.fs import copy_to_local + # print initial config + from pprint import pprint + from omegaconf import OmegaConf + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + 
OmegaConf.resolve(config) + + # download the checkpoint from hdfs + local_path = copy_to_local(config.actor_rollout_ref.model.path) + + # instantiate tokenizer + from verl.utils import hf_tokenizer, hf_processor + tokenizer = hf_tokenizer(local_path) + processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none + + # define worker classes + if config.actor_rollout_ref.actor.strategy == 'fsdp': + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker + from verl.single_controller.ray import RayWorkerGroup + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == 'megatron': + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker + from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup + ray_worker_group_cls = NVMegatronRayWorkerGroup + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + + role_worker_mapping = { + Role.ActorRollout: ray.remote(ActorRolloutRefWorker), + Role.Critic: ray.remote(CriticWorker), + Role.RefPolicy: ray.remote(ActorRolloutRefWorker) + } + + global_pool_id = 'global_pool' + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + Role.Critic: global_pool_id, + Role.RefPolicy: global_pool_id, + } + + # we should adopt a multi-source reward function here + # - for rule-based rm, we directly call a reward score + # - for model-based rm, we call a model + # - for code related prompt, we send to a sandbox if there are test cases + # - finally, we combine all the rewards together + # - The reward type depends on the tag of the data + if config.reward_model.enable: + if config.reward_model.strategy == 'fsdp': + from 
verl.workers.fsdp_workers import RewardModelWorker + elif config.reward_model.strategy == 'megatron': + from verl.workers.megatron_workers import RewardModelWorker + else: + raise NotImplementedError + role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) + mapping[Role.RewardModel] = global_pool_id + + # reference model + if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) + mapping[Role.RefPolicy] = global_pool_id + + reward_manager_name = config.reward_model.get("reward_manager", "naive") + if reward_manager_name == 'naive': + from verl.workers.reward_manager import NaiveRewardManager + reward_manager_cls = NaiveRewardManager + elif reward_manager_name == 'prime': + from verl.workers.reward_manager import PrimeRewardManager + reward_manager_cls = PrimeRewardManager + elif reward_manager_name == 'dapo': + from verl.workers.reward_manager import DAPORewardManager + reward_manager_cls = DAPORewardManager + else: + + raise NotImplementedError + + compute_score = get_custom_reward_fn(config) + reward_fn = reward_manager_cls(tokenizer=tokenizer, + num_examine=0, + compute_score=compute_score, + reward_fn_key=config.data.reward_fn_key, + max_resp_len=config.data.max_response_length, + overlong_buffer_cfg=config.reward_model.overlong_buffer) + + # Note that we always use function-based RM for validation + val_reward_fn = reward_manager_cls(tokenizer=tokenizer, + num_examine=1, + compute_score=compute_score, + reward_fn_key=config.data.reward_fn_key, + max_resp_len=config.data.max_response_length, + overlong_buffer_cfg=config.reward_model.overlong_buffer) + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + trainer = RayDAPOTrainer(config=config, + tokenizer=tokenizer, + processor=processor, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + 
ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn) + trainer.init_workers() + trainer.fit() + + +if __name__ == '__main__': + main() diff --git a/recipe/dapo/test_dapo_7b.sh b/recipe/dapo/test_dapo_7b.sh new file mode 100644 index 00000000000..f370fffb41a --- /dev/null +++ b/recipe/dapo/test_dapo_7b.sh @@ -0,0 +1,135 @@ +#!/usr/bin/env bash +set -euxo pipefail + +project_name='DAPO' +exp_name='DAPO-Qwen2.5-7B-Math-Test' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 2)) +enable_overlong_buffer=True +overlong_buffer_len=512 +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=512 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +train_prompt_mini_bsz=32 +n_resp_per_prompt=16 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-4} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Mathematically equivalent +use_dynamic_bsz=True +infer_micro_batch_size=null +train_micro_batch_size=null +offload=False + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --working-dir "${WORKING_DIR}" \ + -- python3 -m recipe.dapo.src.main_dapo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + 
data.truncation='left' \
+    data.max_prompt_length=${max_prompt_length} \
+    data.max_response_length=${max_response_length} \
+    data.gen_batch_size=${gen_prompt_bsz} \
+    data.train_batch_size=${train_prompt_bsz} \
+    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
+    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
+    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
+    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
+    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
+    actor_rollout_ref.actor.clip_ratio_c=10.0 \
+    algorithm.adv_estimator=${adv_estimator} \
+    algorithm.use_kl_in_reward=${use_kl_in_reward} \
+    algorithm.kl_ctrl.kl_coef=${kl_coef} \
+    algorithm.filter_groups.enable=${enable_filter_groups} \
+    algorithm.filter_groups.metric=${filter_groups_metric} \
+    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
+    actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
+    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \
+    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \
+    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \
+    actor_rollout_ref.model.path="${MODEL_PATH}" \
+    +actor_rollout_ref.model.override_config.attention_dropout=0. \
+    +actor_rollout_ref.model.override_config.embd_pdrop=0. \
+    +actor_rollout_ref.model.override_config.resid_pdrop=0.
\
+    actor_rollout_ref.model.enable_gradient_checkpointing=True \
+    actor_rollout_ref.actor.optim.lr=1e-6 \
+    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
+    actor_rollout_ref.actor.optim.weight_decay=0.1 \
+    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
+    actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
+    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
+    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
+    actor_rollout_ref.actor.entropy_coeff=0 \
+    actor_rollout_ref.actor.grad_clip=1.0 \
+    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
+    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
+    actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+    actor_rollout_ref.rollout.enable_chunked_prefill=True \
+    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
+    actor_rollout_ref.rollout.temperature=${temperature} \
+    actor_rollout_ref.rollout.top_p=${top_p} \
+    actor_rollout_ref.rollout.top_k="${top_k}" \
+    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
+    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
+    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
+    actor_rollout_ref.rollout.val_kwargs.do_sample=True \
+    actor_rollout_ref.rollout.val_kwargs.n=1 \
+    actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
+    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
+    actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \
+    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
+    reward_model.reward_manager=dapo \
+    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
+    reward_model.overlong_buffer.len=${overlong_buffer_len} \
+
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=2 \ + trainer.save_freq=2 \ + trainer.total_epochs=1 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=disable \ No newline at end of file diff --git a/tests/e2e/arithmetic_sequence/rl/main_trainer.py b/tests/e2e/arithmetic_sequence/rl/main_trainer.py index 9376eeb292a..c1861b98792 100644 --- a/tests/e2e/arithmetic_sequence/rl/main_trainer.py +++ b/tests/e2e/arithmetic_sequence/rl/main_trainer.py @@ -28,7 +28,7 @@ def make_reward_function(tokenizer, num_examine): - def arithmetic_sequence_reward_function(data: DataProto): + def arithmetic_sequence_reward_function(data: DataProto, return_dict: bool = False): from tests.e2e.envs.digit_completion.task import compute_reward reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32) @@ -77,7 +77,10 @@ def arithmetic_sequence_reward_function(data: DataProto): dense_reward = torch.as_tensor(dense_reward, dtype=torch.float32, device=reward_tensor.device) reward_tensor[i] = dense_reward * response_mask - return reward_tensor + if return_dict: + return {"reward_tensor": reward_tensor} + else: + return reward_tensor return arithmetic_sequence_reward_function diff --git a/tests/e2e/run_qwen_gsm8k_dapo.sh b/tests/e2e/run_qwen_gsm8k_dapo.sh new file mode 100644 index 00000000000..463b4d693d6 --- /dev/null +++ b/tests/e2e/run_qwen_gsm8k_dapo.sh @@ -0,0 +1,74 @@ +#!/usr/bin/env bash +set -x + +export VLLM_ATTENTION_BACKEND=XFORMERS + +adv_estimator=grpo + +kl_coef=0.0 +use_kl_in_reward=False +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=512 +max_response_length=512 +enable_overlong_buffer=True +overlong_buffer_len=128 
+overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=True +filter_groups_metric=seq_reward +max_num_gen_batches=10 +train_prompt_bsz=32 +train_prompt_mini_bsz=$((train_prompt_bsz / 2)) +gen_prompt_bsz=$((train_prompt_bsz * 3)) +n_resp_per_prompt=4 + +python3 -m recipe.dapo.src.main_dapo \ + data.train_files="$HOME/data/gsm8k/train.parquet" \ + data.val_files="$HOME/data/gsm8k/test.parquet" \ + reward_model.reward_manager=dapo \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ + 
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + trainer.logger=['console'] \ + trainer.project_name='verl_example_gsm8k' \ + trainer.experiment_name='qwen2.5_0.5b_e2e_ci_dapo' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.total_training_steps=1 $@ diff --git a/tests/e2e/run_qwen_gsm8k_prime.sh b/tests/e2e/run_qwen_gsm8k_prime.sh index af9397441eb..4fea00fa96d 100644 --- a/tests/e2e/run_qwen_gsm8k_prime.sh +++ b/tests/e2e/run_qwen_gsm8k_prime.sh @@ -41,7 +41,7 @@ python3 -m recipe.prime.main_prime \ reward_model.model.optim.grad_clip=10.0 \ reward_model.model.input_tokenizer=null \ reward_model.mini_batch_size=32 \ - reward_model.reward_manager=naive \ + reward_model.reward_manager=prime \ trainer.val_before_train=False \ trainer.logger=['console'] \ trainer.project_name='verl_example' \ diff --git a/verl/protocol.py b/verl/protocol.py index a7c843ec677..847bc92a786 100644 --- a/verl/protocol.py +++ b/verl/protocol.py @@ -195,10 +195,39 @@ def __len__(self): return 0 def __getitem__(self, item): - tensor_data = self.batch[item] - non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()} - return_type = DataProto if isinstance(item, slice) else DataProtoItem - return return_type(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info) + """ + Enhanced indexing for DataProto objects. 
+ + Args: + item: Can be one of: + - int: A single index + - slice: A slice object (start:stop:step) + - list: A list of indices + - numpy.ndarray: An array of indices + - torch.Tensor: A tensor of indices + + Returns: + DataProto: For all indexing types except single integers + DataProtoItem: Only for single integer indices + """ + # Case 1: Slice object - use the slice method + if isinstance(item, slice): + return self.slice(item.start, item.stop, item.step) + + # Case 2: List, numpy array, or torch tensor - use sel_idxs + elif isinstance(item, (list, np.ndarray, torch.Tensor)): + return self.select_idxs(item) + + # Case 3: Single integer - return DataProtoItem for backward compatibility + elif isinstance(item, (int, np.integer)): + tensor_data = self.batch[item] + non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()} + return_type = DataProto if isinstance(item, slice) else DataProtoItem + return return_type(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info) + + # Case 4: Unsupported type + else: + raise TypeError(f"Indexing with {type(item)} is not supported") def __getstate__(self): import io @@ -267,7 +296,7 @@ def check_consistency(self): for key, val in self.non_tensor_batch.items(): assert isinstance( val, np.ndarray - ) and val.dtype == object, 'data in the non_tensor_batch must be a numpy.array with dtype=object' + ), f'data in the non_tensor_batch must be a numpy.array with dtype=object, but for {key=}, got {type(val)=}' assert val.shape[ 0] == batch_size, f'key {key} length {len(val)} is not equal to batch size {batch_size}' @@ -371,6 +400,87 @@ def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=Non return DataProto(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info) + def select_idxs(self, idxs): + """ + Select specific indices from the DataProto. 
+ + Args: + idxs (torch.Tensor or numpy.ndarray or list): Indices to select + + Returns: + DataProto: A new DataProto containing only the selected indices + """ + if isinstance(idxs, list): + idxs = torch.tensor(idxs, dtype=torch.int32) + + if isinstance(idxs, np.ndarray): + idxs_np = idxs + idxs_torch = torch.from_numpy(idxs) + else: # torch.Tensor + idxs_torch = idxs + idxs_np = idxs.detach().cpu().numpy() + + if self.batch is not None: + # Use TensorDict's built-in indexing capabilities + selected_batch = TensorDict(source={ + key: tensor[idxs_torch] for key, tensor in self.batch.items() + }, + batch_size=(idxs_torch.shape[0],)) + else: + selected_batch = None + + selected_non_tensor = {} + for key, val in self.non_tensor_batch.items(): + selected_non_tensor[key] = val[idxs_np] + + return DataProto(batch=selected_batch, non_tensor_batch=selected_non_tensor, meta_info=self.meta_info) + + def slice(self, start=None, end=None, step=None): + """ + Slice the DataProto and return a new DataProto object. + This is an improved version of direct slicing which returns a DataProtoItem. + + Args: + start (int, optional): Start index. Defaults to None (start from beginning). + end (int, optional): End index (exclusive). Defaults to None (go to end). + step (int, optional): Step size. Defaults to None (step=1). 
+ + Returns: + DataProto: A new DataProto containing the sliced data + + Examples: + # Using the slice method directly + sliced_data = data_proto.slice(10, 20) + + # Using enhanced indexing (returns DataProto) + sliced_data = data_proto[10:20] + sliced_data = data_proto[::2] # Every other element + + # Using list indexing (returns DataProto) + indices = [1, 5, 10] + selected_data = data_proto[indices] + + # Single index still returns DataProtoItem + single_item = data_proto[5] + """ + # Create a slice object + slice_obj = slice(start, end, step) + + # Handle the batch data + if self.batch is not None: + # Use TensorDict's built-in slicing capabilities + sliced_batch = self.batch[slice_obj] + else: + sliced_batch = None + + # Handle the non-tensor batch data + sliced_non_tensor = {} + for key, val in self.non_tensor_batch.items(): + sliced_non_tensor[key] = val[slice_obj] + + # Return a new DataProto object + return DataProto(batch=sliced_batch, non_tensor_batch=sliced_non_tensor, meta_info=self.meta_info) + def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> 'DataProto': """Pop a subset of the DataProto via `batch_keys` and `meta_info_keys` diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index 4d493942886..81a44207fd0 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -3,9 +3,11 @@ data: train_files: ~/data/rlhf/gsm8k/train.parquet val_files: ~/data/rlhf/gsm8k/test.parquet prompt_key: prompt + reward_fn_key: data_source max_prompt_length: 512 max_response_length: 512 train_batch_size: 1024 + gen_batch_size: ${data.train_batch_size} val_batch_size: null # DEPRECATED: Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs return_raw_chat: False @@ -27,8 
+29,13 @@ actor_rollout_ref: ppo_micro_batch_size_per_gpu: null use_dynamic_bsz: False use_torch_compile: True # False to disable torch compile - clip_ratio: 0.2 + # pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) + clip_ratio: 0.2 # default value if clip_ratio_low and clip_ratio_high are not specified + clip_ratio_low: 0.2 + clip_ratio_high: 0.2 clip_ratio_c: 3.0 # lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729 + loss_agg_mode: "token-mean" # / "seq-mean-token-sum" / "seq-mean-token-mean" + # NOTE: "token-mean" is the default behavior entropy_coeff: 0.001 use_kl_loss: False # True for GRPO kl_loss_coef: 0.001 # for grpo @@ -43,6 +50,7 @@ actor_rollout_ref: min_lr_ratio: null # only useful for warmup with cosine warmup_style: constant # select from constant/cosine total_training_steps: -1 # must be override by program + weight_decay: 0.01 megatron: tensor_model_parallel_size: 4 pipeline_model_parallel_size: 1 @@ -112,6 +120,7 @@ critic: min_lr_ratio: null # only useful for warmup with cosine warmup_style: constant # select from constant/cosine total_training_steps: -1 # must be override by program + weight_decay: 0.01 model: path: ~/models/deepseek-llm-7b-chat tokenizer_path: ${actor_rollout_ref.model.path} diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 7dd270bbaa4..d22e4dc6cd4 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -3,6 +3,7 @@ data: train_files: ~/data/rlhf/gsm8k/train.parquet val_files: ~/data/rlhf/gsm8k/test.parquet prompt_key: prompt + reward_fn_key: data_source max_prompt_length: 512 max_response_length: 512 train_batch_size: 1024 @@ -30,8 +31,12 @@ actor_rollout_ref: use_dynamic_bsz: False ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 - clip_ratio: 0.2 + # pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 
+ cliprange_high) + clip_ratio: 0.2 # default value if clip_ratio_low and clip_ratio_high are not specified + clip_ratio_low: 0.2 + clip_ratio_high: 0.2 clip_ratio_c: 3.0 # lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729 + loss_agg_mode: "token-mean" # / "seq-mean-token-sum" / "seq-mean-token-mean" entropy_coeff: 0.001 use_kl_loss: False # True for GRPO use_torch_compile: True # False to disable torch compile @@ -49,6 +54,7 @@ actor_rollout_ref: min_lr_ratio: null # only useful for warmup with cosine warmup_style: constant # select from constant/cosine total_training_steps: -1 # must be override by program + weight_decay: 0.01 fsdp_config: wrap_policy: # transformer_layer_cls_to_wrap: None @@ -113,6 +119,7 @@ critic: min_lr_ratio: null # only useful for warmup with cosine warmup_style: constant # select from constant/cosine total_training_steps: -1 # must be override by program + weight_decay: 0.01 model: path: ~/models/deepseek-llm-7b-chat tokenizer_path: ${actor_rollout_ref.model.path} diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index 29345999775..ae4abffab8e 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -152,15 +152,24 @@ def run(self, config): elif reward_manager_name == 'prime': from verl.workers.reward_manager import PrimeRewardManager reward_manager_cls = PrimeRewardManager + elif reward_manager_name == 'dapo': + from verl.workers.reward_manager import DAPORewardManager + reward_manager_cls = DAPORewardManager else: + raise NotImplementedError compute_score = get_custom_reward_fn(config) - reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score) + reward_fn = reward_manager_cls(tokenizer=tokenizer, + num_examine=0, + compute_score=compute_score, + reward_fn_key=config.data.reward_fn_key) # Note that we always use function-based RM for validation - val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, 
compute_score=compute_score) - + val_reward_fn = reward_manager_cls(tokenizer=tokenizer, + num_examine=1, + compute_score=compute_score, + reward_fn_key=config.data.reward_fn_key) resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) trainer = RayPPOTrainer(config=config, diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py index 16b647e568f..bdb0a99510c 100644 --- a/verl/trainer/ppo/core_algos.py +++ b/verl/trainer/ppo/core_algos.py @@ -265,9 +265,44 @@ def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio): return token_level_scores - kl * kl_ratio -def compute_policy_loss(old_log_prob, log_prob, advantages, response_mask, cliprange, clip_ratio_c=3.0): - """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122 +def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str): + """ + Aggregate the loss matrix into a scalar. + Args: + loss_mat: `(torch.Tensor)` + shape: (bs, response_length) + loss_mask: `(torch.Tensor)` + shape: (bs, response_length) + loss_agg_mode: (str) choices: "token-mean" / "seq-mean-token-sum" / "seq-mean-token-mean" + "token-mean" is the default behavior + Returns: + loss: `a scalar torch.Tensor` + aggregated loss + """ + if loss_agg_mode == "token-mean": + loss = verl_F.masked_mean(loss_mat, loss_mask) + elif loss_agg_mode == "seq-mean-token-sum": + seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)  # token-sum + loss = torch.mean(seq_losses)  # seq-mean + elif loss_agg_mode == "seq-mean-token-mean": + seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1)  # token-mean + loss = torch.mean(seq_losses)  # seq-mean + else: + raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}") + + return loss + +def compute_policy_loss(old_log_prob, + log_prob, + advantages, + response_mask, + cliprange=None, + cliprange_low=None, + cliprange_high=None, + clip_ratio_c=3.0, + 
loss_agg_mode="token-mean"): + """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122 Args: old_log_prob: `(torch.Tensor)` shape: (bs, response_length) @@ -279,16 +314,24 @@ def compute_policy_loss(old_log_prob, log_prob, advantages, response_mask, clipr shape: (bs, response_length) cliprange: (float) The clip range used in PPO. See https://arxiv.org/abs/1707.06347 - clip_ratio_c: (float) - THe lower bound of the ratio for dual-clip PPO, defalut 3. See https://arxiv.org/pdf/1912.09729 + cliprange_low: (float) + The lower clip range used in PPO. + cliprange_high: (float) + The higher clip range used in PPO. + clip_ratio_c: (float) default: 3.0 + The lower bound of the ratio for dual-clip PPO, See https://arxiv.org/pdf/1912.09729 + loss_agg_mode: (str) choices: "token-mean" / "seq-mean-token-sum" / "seq-mean-token-mean" + "token-mean" is the default behavior Returns: pg_loss: `a scalar torch.Tensor` policy gradient loss computed via PPO pg_clipfrac: (float) - a float number indicating the fraction of policy gradient loss being clipped + the fraction of policy gradient loss being clipped ppo_kl: (float) the estimated KL divergence between the latest updating policy and the old sampling policy + pg_clipfrac_lower: (float) + the fraction of policy gradient loss being clipped when the advantage is negative """ assert clip_ratio_c > 1.0, f"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0, but get the value: {clip_ratio_c}." 
@@ -296,20 +339,24 @@ def compute_policy_loss(old_log_prob, log_prob, advantages, response_mask, clipr ratio = torch.exp(negative_approx_kl) ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) - pg_losses = -advantages * ratio - pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange) - - clip_pg_losses1 = torch.max(pg_losses, pg_losses2) - pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), response_mask) + pg_losses1 = -advantages * ratio + if cliprange_low is None: + cliprange_low = cliprange + if cliprange_high is None: + cliprange_high = cliprange + pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, + 1 + cliprange_high) # - clip(ratio, 1-cliprange, 1+cliprange) * A + clip_pg_losses1 = torch.maximum(pg_losses1, + pg_losses2) # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A) + pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask) pg_losses3 = -advantages * clip_ratio_c clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) pg_clipfrac_lower = verl_F.masked_mean( torch.gt(clip_pg_losses2, pg_losses3) * (advantages < 0).float(), response_mask) - # We only apply the dual-clip when the advantage is negative. 
- pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) - pg_loss = verl_F.masked_mean(pg_losses, response_mask) + pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) + pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower diff --git a/verl/trainer/ppo/metric_utils.py b/verl/trainer/ppo/metric_utils.py index 4f0859d0058..a5b8e4e63b4 100644 --- a/verl/trainer/ppo/metric_utils.py +++ b/verl/trainer/ppo/metric_utils.py @@ -16,9 +16,11 @@ """ import torch -from typing import Any, Dict, List +from typing import Any, Dict, List, Callable import numpy as np from verl import DataProto +from collections import Counter, defaultdict +from functools import partial def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]: @@ -166,3 +168,108 @@ def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n 'perf/time_per_step': time, 'perf/throughput': total_num_tokens / (time * n_gpus), } + + +def bootstrap_metric(data: list[dict[str, Any]], + subset_size: int, + reduce_fns: list[Callable[[np.ndarray], float]], + n_bootstrap: int = 1000, + seed: int = 42) -> list[tuple[float, float]]: + np.random.seed(seed) + + bootstrap_metric_lsts = [[] for _ in range(len(reduce_fns))] + for _ in range(n_bootstrap): + bootstrap_idxs = np.random.choice(len(data), size=subset_size, replace=True) + bootstrap_data = [data[i] for i in bootstrap_idxs] + for i, reduce_fn in enumerate(reduce_fns): + bootstrap_metric_lsts[i].append(reduce_fn(bootstrap_data)) + return [(np.mean(lst), np.std(lst)) for lst in bootstrap_metric_lsts] + + +def calc_maj_val(data: list[dict[str, Any]], vote_key: str, val_key: str) -> float: + """ + Calculate the majority voting metric + """ + vote2vals = defaultdict(list) + for d in data: + vote2vals[d[vote_key]].append(d[val_key]) + + vote2cnt = {k: len(v) for k, v in vote2vals.items()} + maj_vote = max(vote2cnt, 
key=vote2cnt.get) + + maj_val = vote2vals[maj_vote][0] + + return maj_val + + +def process_validation_metrics(data_sources: list[str], sample_inputs: list[str], + infos_dict: dict[str, list[Any]]) -> dict[str, dict[str, dict[str, float]]]: + """Process validation metrics into a structured format. + + Args: + data_sources: Array of data source identifiers for each sample + sample_inputs: List of input prompts + infos_dict: variable name -> list of values for each sample + + Returns: + dict[str, dict[str, dict[str, float]]]: data source -> variable name -> metric value + """ + # Group metrics by data source, prompt and variable + data_src2prompt2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + for sample_idx, data_source in enumerate(data_sources): + prompt = sample_inputs[sample_idx] + var2vals = data_src2prompt2var2vals[data_source][prompt] + for var_name, var_vals in infos_dict.items(): + var2vals[var_name].append(var_vals[sample_idx]) + + # Calculate metrics for each group + data_src2prompt2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) + for data_source, prompt2var2vals in data_src2prompt2var2vals.items(): + for prompt, var2vals in prompt2var2vals.items(): + n_resps = len(var2vals["final_reward"]) + preds = var2vals["pred"] + for var_name, var_vals in var2vals.items(): + if var_name in ["pred", "final_reward"]: + continue + metric = {} + + metric[f"mean@{n_resps}"] = np.mean(var_vals) + metric[f"std@{n_resps}"] = np.std(var_vals) + + ns = [] + n = 2 + while n < n_resps: + ns.append(n) + n *= 2 + ns.append(n_resps) + + data = [{"val": val, "pred": pred} for val, pred in zip(var_vals, preds)] + for n in ns: + (bon_mean, bon_std), (won_mean, won_std), (maj_n_mean, maj_n_std) = bootstrap_metric( + data, + subset_size=n, + reduce_fns=[ + lambda arr: np.max([d["val"] for d in arr]), lambda arr: np.min([d["val"] for d in arr]), + partial(calc_maj_val, vote_key="pred", val_key="val") + ]) + metric[f"best@{n}/mean"], 
metric[f"best@{n}/std"] = bon_mean, bon_std + metric[f"worst@{n}/mean"], metric[f"worst@{n}/std"] = won_mean, won_std + metric[f"maj@{n}/mean"], metric[f"maj@{n}/std"] = maj_n_mean, maj_n_std + + data_src2prompt2var2metric[data_source][prompt][var_name] = metric + + # Aggregate metrics across prompts + data_src2var2metric2prompt_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + for data_source, prompt2var2metric in data_src2prompt2var2metric.items(): + for prompt, var2metric in prompt2var2metric.items(): + for var_name, metric in var2metric.items(): + for metric_name, metric_val in metric.items(): + data_src2var2metric2prompt_vals[data_source][var_name][metric_name].append(metric_val) + + data_src2var2metric2val = defaultdict(lambda: defaultdict(lambda: defaultdict(float))) + for data_source, var2metric2prompt_vals in data_src2var2metric2prompt_vals.items(): + for var_name, metric2prompt_vals in var2metric2prompt_vals.items(): + for metric_name, prompt_vals in metric2prompt_vals.items(): + data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(prompt_vals) + + return data_src2var2metric2val diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 115035a66e1..119f3f03ee8 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -24,6 +24,8 @@ from pprint import pprint from typing import Type, Dict from copy import deepcopy +from collections import defaultdict +from functools import partial from tqdm import tqdm import ray @@ -36,7 +38,7 @@ from verl.single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs from verl.single_controller.ray.base import create_colocated_worker_cls from verl.trainer.ppo import core_algos -from verl.trainer.ppo.metric_utils import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, reduce_metrics +from verl.trainer.ppo.metric_utils import compute_data_metrics, compute_throughout_metrics, 
compute_timing_metrics, reduce_metrics, bootstrap_metric, calc_maj_val, process_validation_metrics from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn @@ -357,6 +359,10 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): assert config.actor_rollout_ref.actor.ppo_mini_batch_size % config.actor_rollout_ref.actor.ppo_micro_batch_size == 0 assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus + assert config.actor_rollout_ref.actor.loss_agg_mode in [ + "token-mean", "seq-mean-token-sum", "seq-mean-token-mean" + ], f"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}" + if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss: print(f"NOTICE: You have both enabled in-reward kl and kl loss.") @@ -491,8 +497,8 @@ def _maybe_log_val_generations(self, inputs, outputs, scores): self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps) def _validate(self): - reward_tensor_lst = [] data_source_lst = [] + reward_extra_infos_dict: dict[str, list] = defaultdict(list) # Lists to collect samples for the table sample_inputs = [] @@ -512,6 +518,7 @@ def _validate(self): # Store original inputs input_ids = test_batch.batch['input_ids'] + # TODO: Can we keep special tokens except for padding tokens? 
input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] sample_inputs.extend(input_texts) @@ -551,31 +558,40 @@ def _validate(self): test_batch = test_batch.union(test_output_gen_batch) # evaluate using reward_function - reward_tensor = self.val_reward_fn(test_batch) - - # Store scores + result = self.val_reward_fn(test_batch, return_dict=True) + reward_tensor = result["reward_tensor"] scores = reward_tensor.sum(-1).cpu().tolist() sample_scores.extend(scores) + if "reward_extra_info" in result: + reward_extra_infos_dict["final_reward"].extend(scores) + for key, lst in result["reward_extra_info"].items(): + reward_extra_infos_dict[key].extend(lst) - reward_tensor_lst.append(reward_tensor) data_source_lst.append(test_batch.non_tensor_batch.get('data_source', ['unknown'] * reward_tensor.shape[0])) self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) - reward_tensor = torch.cat(reward_tensor_lst, dim=0).sum(-1).cpu() # (batch_size,) + for lst in reward_extra_infos_dict.values(): + assert len(lst) == 0 or len(lst) == len(sample_scores) + data_sources = np.concatenate(data_source_lst, axis=0) - # evaluate test_score based on data source - data_source_reward = {} - for i in range(reward_tensor.shape[0]): - data_source = data_sources[i] - if data_source not in data_source_reward: - data_source_reward[data_source] = [] - data_source_reward[data_source].append(reward_tensor[i].item()) + data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict) metric_dict = {} - for data_source, rewards in data_source_reward.items(): - metric_dict[f'val/test_score/{data_source}'] = np.mean(rewards) + for data_source, var2metric2val in data_src2var2metric2val.items(): + core_var = "acc" if "acc" in var2metric2val else "final_reward" + for var_name, metric2val in var2metric2val.items(): + n_max = max([int(name.split("@")[-1].split("/")[0]) for name in 
metric2val.keys()]) + for metric_name, metric_val in metric2val.items(): + if var_name == core_var and any( + metric_name.startswith(pfx) + for pfx in ["mean", "std", "maj", "best"]) and f"@{n_max}/" in metric_name: + metric_sec = "val-core" + else: + metric_sec = "val-aux" + pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}" + metric_dict[pfx] = metric_val return metric_dict @@ -882,9 +898,22 @@ def fit(self): batch = batch.union(reward_tensor) # we combine with rule-based rm - reward_tensor = self.reward_fn(batch) + reward_extra_infos_dict: dict[str, list] + try: + reward_result = self.reward_fn(batch, return_dict=True) + reward_tensor = reward_result['reward_tensor'] + reward_extra_infos_dict = reward_result['reward_extra_info'] + except Exception as e: + print(f'Error in reward_fn: {e}') + reward_tensor = self.reward_fn(batch) + reward_extra_infos_dict = {} + batch.batch['token_level_scores'] = reward_tensor + print(f'{list(reward_extra_infos_dict.keys())=}') + if reward_extra_infos_dict: + batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) + # compute rewards. 
apply_kl_penalty if available if self.config.algorithm.use_kl_in_reward: batch, kl_metrics = apply_kl_penalty(batch, diff --git a/verl/utils/megatron/tensor_parallel.py b/verl/utils/megatron/tensor_parallel.py index 28bf64eccbc..522d5008f49 100644 --- a/verl/utils/megatron/tensor_parallel.py +++ b/verl/utils/megatron/tensor_parallel.py @@ -168,22 +168,3 @@ def vocab_parallel_log_probs_from_logits_response_rmpad(input_ids, attention_mas seqlen=seqlen) output = full_output.squeeze(-1)[:, -response_length - 1:-1] # [batch_size, response_length] return output - - -def vocab_parallel_compute_entropy_loss(logits, response_mask): - """Compute Categorical entropy loss - - Args: - logits: `(torch.Tensor)` - shape: (bs, response_length, vocab_size) - response_mask: `(torch.Tensor)` - shape: (bs, response_length) - - Returns: - entropy: a scalar torch.Tensor - - """ - # compute entropy - entropy = vocab_parallel_entropy(logits) - entropy_loss = verl_F.masked_mean(entropy, mask=response_mask) - return entropy_loss diff --git a/verl/utils/reward_score/__init__.py b/verl/utils/reward_score/__init__.py index 1818bc1447a..70346be7fa6 100644 --- a/verl/utils/reward_score/__init__.py +++ b/verl/utils/reward_score/__init__.py @@ -21,7 +21,6 @@ def _default_compute_score(data_source, solution_str, ground_truth, extra_info=N elif data_source in ['lighteval/MATH', 'DigitalLearningGmbH/MATH-lighteval']: from . import math res = math.compute_score(solution_str, ground_truth) - # [Optional] Math-Verify Integration # For enhanced accuracy, consider utilizing Math-Verify (https://github.com/huggingface/Math-Verify). # Note: Math-Verify needs to be manually installed via pip: `pip install math-verify`. @@ -29,6 +28,9 @@ def _default_compute_score(data_source, solution_str, ground_truth, extra_info=N # from . import math_verify # res = math_verify.compute_score(solution_str, ground_truth) + elif data_source == 'math_dapo' or data_source.startswith("aime"): + from . 
import math_dapo + res = math_dapo.compute_score(solution_str, ground_truth) elif data_source in [ 'numina_aops_forum', 'numina_synthetic_math', 'numina_amc_aime', 'numina_synthetic_amc', 'numina_cn_k12', 'numina_olympiads' @@ -42,9 +44,11 @@ def _default_compute_score(data_source, solution_str, ground_truth, extra_info=N from . import geo3k res = geo3k.compute_score(solution_str, ground_truth) else: - raise NotImplementedError + raise NotImplementedError(f"Reward function is not implemented for {data_source=}") - if isinstance(res, (int, float, bool)): + if isinstance(res, dict): + return res + elif isinstance(res, (int, float, bool)): return float(res) else: return float(res[0]) diff --git a/verl/utils/reward_score/math_dapo.py b/verl/utils/reward_score/math_dapo.py new file mode 100644 index 00000000000..624651ad0df --- /dev/null +++ b/verl/utils/reward_score/math_dapo.py @@ -0,0 +1,290 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py + +import re +import signal +from typing import Optional + + +def last_boxed_only_string(string: str) -> Optional[str]: + """Extract the last LaTeX boxed expression from a string. 
+ + Args: + string: Input string containing LaTeX code + + Returns: + The last boxed expression or None if not found + """ + idx = string.rfind("\\boxed{") + if idx < 0: + return None + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + return string[idx:right_brace_idx + 1] if right_brace_idx is not None else None + + +def remove_boxed(s: str) -> str: + """Remove the LaTeX boxed command from a string. + + Args: + s: String with format "\\boxed{content}" + + Returns: + The content inside the boxed command + """ + left = "\\boxed{" + assert s[:len(left)] == left, f"box error: {s}" + assert s[-1] == "}", f"box error: {s}" + return s[len(left):-1] + + +class timeout: + + def __init__(self, seconds=1, error_message="Timeout"): + self.seconds = seconds + self.error_message = error_message + + def handle_timeout(self, signum, frame): + raise TimeoutError(self.error_message) + + def __enter__(self): + signal.signal(signal.SIGALRM, self.handle_timeout) + signal.alarm(self.seconds) + + def __exit__(self, type, value, traceback): + signal.alarm(0) + + +# Constants for normalization +SUBSTITUTIONS = [ + ("an ", ""), + ("a ", ""), + (".$", "$"), + ("\\$", ""), + (r"\ ", ""), + (" ", ""), + ("mbox", "text"), + (",\\text{and}", ","), + ("\\text{and}", ","), + ("\\text{m}", "\\text{}"), +] + +REMOVED_EXPRESSIONS = [ + "square", + "ways", + "integers", + "dollars", + "mph", + "inches", + "hours", + "km", + "units", + "\\ldots", + "sue", + "points", + "feet", + "minutes", + "digits", + "cents", + "degrees", + "cm", + "gm", + "pounds", + "meters", + "meals", + "edges", + "students", + "childrentickets", + "multiples", + "\\text{s}", + "\\text{.}", + "\\text{\ns}", + "\\text{}^2", + "\\text{}^3", + "\\text{\n}", + "\\text{}", + r"\mathrm{th}", + r"^\circ", + r"^{\circ}", + 
r"\;", + r",\!", + "{,}", + '"', + "\\dots", +] + + +def normalize_final_answer(final_answer: str) -> str: + """Normalize a final answer to a quantitative reasoning question. + + Args: + final_answer: The answer string to normalize + + Returns: + Normalized answer string + """ + final_answer = final_answer.split("=")[-1] + + # Apply substitutions and removals + for before, after in SUBSTITUTIONS: + final_answer = final_answer.replace(before, after) + for expr in REMOVED_EXPRESSIONS: + final_answer = final_answer.replace(expr, "") + + # Extract and normalize LaTeX math + final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) + final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) + + # Normalize shorthand TeX: + # \fracab -> \frac{a}{b} + # \frac{abc}{bef} -> \frac{abc}{bef} + # \fracabc -> \frac{a}{b}c + # \sqrta -> \sqrt{a} + # \sqrtab -> sqrt{a}b + final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) + final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) + final_answer = final_answer.replace("$", "") + + # Normalize numbers + if final_answer.replace(",", "").isdigit(): + final_answer = final_answer.replace(",", "") + + return final_answer.strip() + + +def is_correct_minerva(solution_str: str, + gt: str, + gt_need_extract: bool = False, + answer_pattern: str = r"(?i)Answer\s*:\s*([^\n]+)") -> tuple[bool, str]: + """Check if the solution is correct according to Minerva criteria. 
+ + Args: + solution_str: The solution string to check + gt: The ground truth answer + gt_need_extract: Whether the ground truth needs extraction + answer_pattern: Regex pattern to extract the answer + + Returns: + Tuple of (is_correct, normalized_prediction) + """ + # Extract answer from solution + match = re.findall(answer_pattern, solution_str) + extracted_answer = match[-1] if match else "[INVALID]" + pred = normalize_final_answer(extracted_answer) + + # Process ground truth + if gt_need_extract: + gt = normalize_final_answer(remove_boxed(last_boxed_only_string(gt))) + else: + gt = normalize_final_answer(gt) + + return (pred == gt), pred + + +def is_correct_strict_box(pred: str, + gt: str, + pause_tokens_index: Optional[list[int]] = None) -> tuple[int, Optional[str]]: + """Check if the prediction is correct using strict boxed answer criteria. + + Args: + pred: The prediction string + gt: The ground truth answer + pause_tokens_index: Indices of pause tokens + + Returns: + Tuple of (score, extracted_prediction) + """ + # Extract the relevant part of the prediction + if pause_tokens_index is not None: + assert len(pause_tokens_index) == 4 + pred = pred[pause_tokens_index[-1] - 100:] + else: + pred = pred[-100:] + + # Extract and check the boxed answer + boxed_pred = last_boxed_only_string(pred) + extracted_pred = remove_boxed(boxed_pred) if boxed_pred is not None else None + + return 1 if (extracted_pred == gt) else -1, extracted_pred + + +def verify(solution_str: str, + answer: str, + strict_box_verify: bool = False, + pause_tokens_index: Optional[list[int]] = None) -> bool: + """Verify if the solution is correct. 
+ + Args: + solution_str: The solution string to verify + answer: The ground truth answer + strict_box_verify: Whether to use strict box verification + pause_tokens_index: Indices of pause tokens + + Returns: + Tuple of (is_correct, extracted_prediction) + """ + if strict_box_verify: + correct, pred = is_correct_strict_box(solution_str, answer, pause_tokens_index) + return correct == 1, pred + + correct, pred = is_correct_minerva(solution_str, answer) + return correct, pred + + +def compute_score(solution_str: str, + ground_truth: str, + strict_box_verify: bool = False, + pause_tokens_index: Optional[list[int]] = None) -> dict: + """Compute the reward score for a solution. + + Args: + solution_str: The solution string + ground_truth: The ground truth answer + strict_box_verify: Whether to use strict boxed-answer verification + pause_tokens_index: Indices of pause tokens + + Returns: + Dict with "score" (1.0 for correct, -1.0 for incorrect), "acc" (bool correctness) and "pred" (extracted answer or None) + """ + # Limit solution length for efficiency + solution_str = solution_str[-300:] # The longest answer in MATH-500 has 159 characters + + # Verify the solution + correct, pred = verify(solution_str, ground_truth, strict_box_verify, pause_tokens_index) + + reward = 1.0 if correct else -1.0 + acc = correct + + return { + "score": reward, + "acc": acc, + "pred": pred, + } diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index 0e9c7dbb8ce..a3ae5e6f005 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -23,7 +23,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from verl import DataProto -from verl.trainer.ppo import core_algos +from verl.trainer.ppo.core_algos import compute_policy_loss, kl_penalty, agg_loss from verl.workers.actor import BasePPOActor from verl.utils.py_functional import append_to_dict from verl.utils.torch_functional import logprobs_from_logits, masked_mean @@ -286,21 +286,26 @@ def update_policy(self, data: DataProto): advantages = 
data['advantages'] clip_ratio = self.config.clip_ratio - entropy_coeff = self.config.entropy_coeff + clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio + clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio clip_ratio_c = self.config.get('clip_ratio_c', 3.0) + entropy_coeff = self.config.entropy_coeff + loss_agg_mode = self.config.loss_agg_mode # all return: (bsz, response_length) entropy, log_prob = self._forward_micro_batch(micro_batch=data, temperature=temperature) - pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = core_algos.compute_policy_loss( + pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss( old_log_prob=old_log_prob, log_prob=log_prob, advantages=advantages, response_mask=response_mask, cliprange=clip_ratio, + cliprange_low=clip_ratio_low, + cliprange_high=clip_ratio_high, clip_ratio_c=clip_ratio_c) # compute entropy loss from entropy - entropy_loss = verl_F.masked_mean(entropy, response_mask) + entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) # compute policy loss policy_loss = pg_loss - entropy_loss * entropy_coeff @@ -308,10 +313,12 @@ def update_policy(self, data: DataProto): if self.config.use_kl_loss: ref_log_prob = data['ref_log_prob'] # compute kl loss - kld = core_algos.kl_penalty(logprob=log_prob, - ref_logprob=ref_log_prob, - kl_penalty=self.config.kl_loss_type) - kl_loss = masked_mean(kld, response_mask) + kld = kl_penalty(logprob=log_prob, + ref_logprob=ref_log_prob, + kl_penalty=self.config.kl_loss_type) + kl_loss = agg_loss(loss_mat=kld, + loss_mask=response_mask, + loss_agg_mode=self.config.loss_agg_mode) policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef metrics['actor/kl_loss'] = kl_loss.detach().item() @@ -325,7 +332,7 @@ def update_policy(self, data: DataProto): loss.backward() data = { - 'actor/entropy_loss': entropy_loss.detach().item(), + 
'actor/entropy': entropy_loss.detach().item(), 'actor/pg_loss': pg_loss.detach().item(), 'actor/pg_clipfrac': pg_clipfrac.detach().item(), 'actor/ppo_kl': ppo_kl.detach().item(), diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index 4de30b3f6bf..62c831b2979 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -39,10 +39,10 @@ from megatron.core.optimizer import DistributedOptimizer from omegaconf import OmegaConf -from verl.utils.megatron.tensor_parallel import vocab_parallel_compute_entropy_loss, vocab_parallel_log_probs_from_logits +from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits from verl.utils.megatron.pipeline_parallel import (compute_transformers_input_shapes, make_batch_generator) from verl import DataProto -from verl.trainer.ppo import core_algos +from verl.trainer.ppo.core_algos import compute_policy_loss, kl_penalty, agg_loss from verl.workers.actor import BasePPOActor from verl.utils.py_functional import append_to_dict from verl.utils.torch_functional import logprobs_from_logits, masked_mean, broadcast_dict_tensor, split_dict_tensor_into_batches @@ -276,8 +276,11 @@ def loss_func(output, data, meta_info): advantages = data['advantages'] clip_ratio = meta_info['clip_ratio'] - entropy_coeff = meta_info['entropy_coeff'] + clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio + clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio clip_ratio_c = meta_info['clip_ratio_c'] + entropy_coeff = meta_info['entropy_coeff'] + loss_agg_mode = self.config.loss_agg_mode # compute policy loss logits = output.logits @@ -285,24 +288,25 @@ def loss_func(output, data, meta_info): logits_back = logits.clone() log_prob = vocab_parallel_log_probs_from_logits(logits, responses) logits = logits_back - pg_loss, pg_clipfrac, ppo_kl, 
pg_clipfrac_lower = core_algos.compute_policy_loss( - old_log_prob=old_log_prob, - log_prob=log_prob, - advantages=advantages, - response_mask=response_mask, - cliprange=clip_ratio, - clip_ratio_c=clip_ratio_c) - entropy_loss = vocab_parallel_compute_entropy_loss(logits, response_mask=response_mask) + pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(old_log_prob=old_log_prob, + log_prob=log_prob, + advantages=advantages, + response_mask=response_mask, + cliprange=clip_ratio, + cliprange_low=clip_ratio_low, + cliprange_high=clip_ratio_high, + clip_ratio_c=clip_ratio_c, + loss_agg_mode=loss_agg_mode) + entropy = vocab_parallel_entropy(logits) + entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) policy_loss = pg_loss - entropy_loss * entropy_coeff metrics = {} if self.config.use_kl_loss: ref_log_prob = data['ref_log_prob'] # compute kl loss - kld = core_algos.kl_penalty(logprob=log_prob, - ref_logprob=ref_log_prob, - kl_penalty=self.config.kl_loss_type) - kl_loss = masked_mean(kld, response_mask) + kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type) + kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode) policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef metrics['actor/kl_loss'] = kl_loss.detach().item() diff --git a/verl/workers/reward_manager/__init__.py b/verl/workers/reward_manager/__init__.py index 3de46a7717e..6e85a388cbe 100644 --- a/verl/workers/reward_manager/__init__.py +++ b/verl/workers/reward_manager/__init__.py @@ -13,4 +13,5 @@ # limitations under the License. 
from .naive import NaiveRewardManager -from .prime import PrimeRewardManager \ No newline at end of file +from .prime import PrimeRewardManager +from .dapo import DAPORewardManager diff --git a/verl/workers/reward_manager/dapo.py b/verl/workers/reward_manager/dapo.py new file mode 100644 index 00000000000..f98558d998a --- /dev/null +++ b/verl/workers/reward_manager/dapo.py @@ -0,0 +1,135 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from verl import DataProto +from verl.utils.reward_score import _default_compute_score +import torch +from collections import defaultdict + + +class DAPORewardManager: + """The reward manager. 
+ """ + + def __init__(self, + tokenizer, + num_examine, + compute_score=None, + reward_fn_key='data_source', + max_resp_len=None, + overlong_buffer_cfg=None) -> None: + self.tokenizer = tokenizer + self.num_examine = num_examine # the number of batches of decoded responses to print to the console + self.compute_score = compute_score or _default_compute_score + self.reward_fn_key = reward_fn_key + self.overlong_buffer_cfg = overlong_buffer_cfg + self.max_resp_len = max_resp_len + + if self.overlong_buffer_cfg is not None: + assert self.max_resp_len is not None, f"max_resp_len must be provided if {overlong_buffer_cfg=}, but got None" + + def __call__(self, data: DataProto, return_dict: bool = False): + """We will expand this function gradually based on the available datasets""" + + # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn + if 'rm_scores' in data.batch.keys(): + if return_dict: + return {"reward": data.batch['rm_scores']} + else: + return data.batch['rm_scores'] + + reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32) + reward_extra_info = defaultdict(list) + + already_print_data_sources = {} + + for i in range(len(data)): + data_item = data[i] # DataProtoItem + + prompt_ids = data_item.batch['prompts'] + + prompt_length = prompt_ids.shape[-1] + + valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum() + valid_prompt_ids = prompt_ids[-valid_prompt_length:] + + response_ids = data_item.batch['responses'] + valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum() + valid_response_ids = response_ids[:valid_response_length] + + # decode + prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True) + response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True) + eos_token = self.tokenizer.eos_token + if response_str.endswith(eos_token): + response_str = response_str[:-len(eos_token)] + + ground_truth = 
data_item.non_tensor_batch['reward_model']['ground_truth'] + + data_source = data_item.non_tensor_batch[self.reward_fn_key] + + extra_info = data_item.non_tensor_batch.get('extra_info', None) + + result = self.compute_score( + data_source=data_source, + solution_str=response_str, + ground_truth=ground_truth, + extra_info=extra_info, + ) + + score: float + if isinstance(result, dict): + score = result["score"] + # Store the information including original reward + for key, value in result.items(): + reward_extra_info[key].append(value) + else: + score = result + + reward = score + + if self.overlong_buffer_cfg.enable: + overlong_buffer_len = self.overlong_buffer_cfg.len + expected_len = self.max_resp_len - overlong_buffer_len + exceed_len = valid_response_length - expected_len + overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor + overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0) + reward += overlong_reward + if self.overlong_buffer_cfg.log: + reward_extra_info["overlong_reward"].append(overlong_reward) + reward_extra_info["overlong"].append(overlong_reward < 0) + + reward_tensor[i, valid_response_length - 1] = reward + + if data_source not in already_print_data_sources: + already_print_data_sources[data_source] = 0 + + if already_print_data_sources[data_source] < self.num_examine: + already_print_data_sources[data_source] += 1 + print("[prompt]", prompt_str) + print("[response]", response_str) + print("[ground_truth]", ground_truth) + if isinstance(result, dict): + for key, value in result.items(): + print(f"[{key}]", value) + else: + print(f"[score]", score) + + if return_dict: + return { + "reward_tensor": reward_tensor, + "reward_extra_info": reward_extra_info, + } + else: + return reward_tensor diff --git a/verl/workers/reward_manager/naive.py b/verl/workers/reward_manager/naive.py index 8e34228b83b..8dd1684775a 100644 --- a/verl/workers/reward_manager/naive.py +++ b/verl/workers/reward_manager/naive.py @@ 
-15,61 +15,31 @@ from verl import DataProto from verl.utils.reward_score import _default_compute_score import torch +from collections import defaultdict class NaiveRewardManager: """The reward manager. """ - def __init__(self, tokenizer, num_examine, compute_score=None) -> None: + def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key='data_source') -> None: self.tokenizer = tokenizer self.num_examine = num_examine # the number of batches of decoded responses to print to the console self.compute_score = compute_score or _default_compute_score + self.reward_fn_key = reward_fn_key - def verify(self, data): - scores = [] - for i in range(len(data)): - data_item = data[i] # DataProtoItem - - prompt_ids = data_item.batch['prompts'] - - prompt_length = prompt_ids.shape[-1] - - valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum() - valid_prompt_ids = prompt_ids[-valid_prompt_length:] - - response_ids = data_item.batch['responses'] - valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum() - valid_response_ids = response_ids[:valid_response_length] - - # decode - prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True) - response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True) - - ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth'] - - data_source = data_item.non_tensor_batch['data_source'] - - extra_info = data_item.non_tensor_batch.get('extra_info', None) - - score = self.compute_score( - data_source=data_source, - solution_str=response_str, - ground_truth=ground_truth, - extra_info=extra_info, - ) - scores.append(score) - data.batch['acc'] = torch.tensor(scores, dtype=torch.float32, device=prompt_ids.device) - return scores - - def __call__(self, data: DataProto): + def __call__(self, data: DataProto, return_dict=False): """We will expand this function gradually based on the available datasets""" # If there is rm score, we directly 
return rm score. Otherwise, we compute via rm_score_fn if 'rm_scores' in data.batch.keys(): - return data.batch['rm_scores'] + if return_dict: + return {"reward": data.batch['rm_scores']} + else: + return data.batch['rm_scores'] reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32) + reward_extra_info = defaultdict(list) already_print_data_sources = {} @@ -93,7 +63,7 @@ def __call__(self, data: DataProto): ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth'] - data_source = data_item.non_tensor_batch['data_source'] + data_source = data_item.non_tensor_batch[self.reward_fn_key] extra_info = data_item.non_tensor_batch.get('extra_info', None) @@ -103,7 +73,16 @@ def __call__(self, data: DataProto): ground_truth=ground_truth, extra_info=extra_info, ) - reward_tensor[i, valid_response_length - 1] = score + + if isinstance(score, dict): + reward = score["score"] + # Store the information including original reward + for key, value in score.items(): + reward_extra_info[key].append(value) + else: + reward = score + + reward_tensor[i, valid_response_length - 1] = reward if data_source not in already_print_data_sources: already_print_data_sources[data_source] = 0 @@ -113,6 +92,16 @@ def __call__(self, data: DataProto): print("[prompt]", prompt_str) print("[response]", response_str) print("[ground_truth]", ground_truth) - print("[score]", score) - - return reward_tensor + if isinstance(score, dict): + for key, value in score.items(): + print(f"[{key}]", value) + else: + print(f"[score]", score) + + if return_dict: + return { + "reward_tensor": reward_tensor, + "reward_extra_info": reward_extra_info, + } + else: + return reward_tensor diff --git a/verl/workers/reward_manager/prime.py b/verl/workers/reward_manager/prime.py index f7e3e5add7c..11e86860226 100644 --- a/verl/workers/reward_manager/prime.py +++ b/verl/workers/reward_manager/prime.py @@ -122,7 +122,7 @@ def verify(self, data): data.batch['acc'] = torch.tensor(scores, 
dtype=torch.float32, device=prompt_ids.device) return scores - def __call__(self, data: DataProto): + def __call__(self, data: DataProto, return_dict: bool = False): """We will expand this function gradually based on the available datasets""" # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn @@ -155,4 +155,7 @@ def __call__(self, data: DataProto): already_print_data_sources[data_source] += 1 print(sequences_str) - return reward_tensor + if return_dict: + return {"reward_tensor": reward_tensor} + else: + return reward_tensor