Merged
79 commits
f39fae6
[misc] feat: add weight decay option in config (#611)
PeterSH6 Mar 15, 2025
25b643a
[misc] feat: verifier for Puffin (#612)
tongyx361 Mar 15, 2025
6e594fe
fix: n = 1 with duplicated data
tongyx361 Mar 15, 2025
6aefcdc
fix: config actor weight decay
tongyx361 Mar 15, 2025
3a7a0e8
[misc] feat: support token_level_loss and support different clip rang…
PeterSH6 Mar 16, 2025
198af7c
[log] feat: more statistics in validation (#620)
tongyx361 Mar 16, 2025
3ca12f9
chore: format
tongyx361 Mar 16, 2025
e7e888e
fix: use train_batch_size by default
tongyx361 Mar 16, 2025
2a6aa9e
fix: return_dict
tongyx361 Mar 16, 2025
fc1d1fc
feat: script for Puffin-Zero-Qwen2.5-32B
tongyx361 Mar 16, 2025
54acfd0
fix: extra_reward_info
tongyx361 Mar 16, 2025
c181511
chore: select_idxs
tongyx361 Mar 16, 2025
f3c513e
chore: fill_to_train_bsz
tongyx361 Mar 16, 2025
7ff4243
fix: naive.py
tongyx361 Mar 16, 2025
c16e82f
chore: rename 32B script
tongyx361 Mar 16, 2025
92fd507
fix: reward dict result
tongyx361 Mar 16, 2025
6ddea1c
chore: non_uniform_reward
tongyx361 Mar 16, 2025
4d077f5
fix: train_prompt_bsz
tongyx361 Mar 16, 2025
eff7f61
chore: filter prefix
tongyx361 Mar 16, 2025
a0050c9
chore: format
tongyx361 Mar 16, 2025
6075b73
fix 32b no filter script
PeterSH6 Mar 16, 2025
38a9fbe
feat: config overlong_buffer
tongyx361 Mar 16, 2025
8e64972
chore: comments
tongyx361 Mar 16, 2025
50109f1
fix: megatron config
tongyx361 Mar 16, 2025
3f2f815
fix: 32B script
tongyx361 Mar 16, 2025
a32aac0
fix: scripts
tongyx361 Mar 17, 2025
d8e0a73
feat: rename [skip ci]
tongyx361 Mar 17, 2025
a324ed4
fix: scripts [skip ci]
tongyx361 Mar 17, 2025
317a6c7
fix: scripts [skip ci]
tongyx361 Mar 17, 2025
f5486e4
fix: log filtering
tongyx361 Mar 17, 2025
1156f86
fix: default use_token_level_loss to True [skip ci]
tongyx361 Mar 18, 2025
04f41b7
feat: docs & metric for filtering
tongyx361 Mar 18, 2025
4536d99
fix: remove formula [skip ci]
tongyx361 Mar 18, 2025
8c3f765
fix: repro doc [skip ci]
tongyx361 Mar 18, 2025
818c048
fix: scripts [skip ci]
tongyx361 Mar 18, 2025
c2aab03
fix: filter by metric
tongyx361 Mar 18, 2025
52d8bc1
feat: allow to accumulate time
tongyx361 Mar 18, 2025
1124c78
fix: reward_extra_info
tongyx361 Mar 18, 2025
97205e4
feat: filter until max gen batches
tongyx361 Mar 18, 2025
c0d5f5b
fix: example config
tongyx361 Mar 18, 2025
aa2d8c0
fix: num_gen_batches [skip ci]
tongyx361 Mar 18, 2025
a5b8cc2
chore: tweak README [skip ci]
tongyx361 Mar 19, 2025
8889e0f
fix: typo [skip ci]
tongyx361 Mar 19, 2025
78de7f0
feat: improve README [skip ci]
tongyx361 Mar 20, 2025
e18d0ba
fix: typo [skip ci]
tongyx361 Mar 20, 2025
c6a2a6a
fix: typo [skip ci]
tongyx361 Mar 20, 2025
01ef718
chore: note about overlong filtering [skip ci]
tongyx361 Mar 20, 2025
88cf46d
resolve conflict by merging main
PeterSH6 Mar 25, 2025
66686b4
chore: wandb run of an early version
tongyx361 Mar 27, 2025
2045493
[recipe] refactor: decouple DAPO (#790)
tongyx361 Mar 27, 2025
a3c0f9e
fix: config
tongyx361 Mar 27, 2025
bd2059c
fix: algo job name
tongyx361 Mar 27, 2025
3b5ef9c
fix: config
tongyx361 Mar 27, 2025
2758a30
feat: reward_fn_key
tongyx361 Mar 27, 2025
74b7edb
fix: remove uncessary verify from naive
tongyx361 Mar 27, 2025
c01b097
fix: validation top p as paper
tongyx361 Mar 27, 2025
992d1fc
fix: top p for full version
tongyx361 Mar 27, 2025
7177a42
fix: train_batch_size
tongyx361 Mar 27, 2025
13e8b3d
chore: improve news
tongyx361 Mar 27, 2025
2ec29d9
feat: loss_agg_mode
tongyx361 Mar 27, 2025
69d57f3
feat: return_dict
tongyx361 Mar 27, 2025
ccf178e
fix: timing_raw
tongyx361 Mar 27, 2025
a4bd7a8
fix: verify for PRIME CI
tongyx361 Mar 27, 2025
7c42d47
fix: CI for DAPO
tongyx361 Mar 27, 2025
11411cc
fix: import AdvantageEstimator
tongyx361 Mar 28, 2025
066675d
fix: reward manager in CI
tongyx361 Mar 28, 2025
194aa3a
fix: CI for DAPO
tongyx361 Mar 28, 2025
9ee7ccc
fix: CI for DAPO
tongyx361 Mar 28, 2025
338f5de
fix: no filtering for single-item group
tongyx361 Mar 29, 2025
89da9db
Merge branch 'main' into gm-tyx/puffin/main
tongyx361 Mar 31, 2025
17f7b4b
fix: config
tongyx361 Mar 31, 2025
da7f9b9
feat: better metric sectioning
tongyx361 Mar 31, 2025
db456cc
refactor: extract process_validation_metrics
tongyx361 Mar 31, 2025
4bbc0ab
fix: reward_tensor
tongyx361 Mar 31, 2025
50b3fe0
Merge branch 'main' into gm-tyx/puffin/main
tongyx361 Apr 2, 2025
7bd84a7
fix: DAPO config
tongyx361 Apr 2, 2025
404a38b
fix: new features
tongyx361 Apr 2, 2025
6ba61e1
Merge branch 'main' into gm-tyx/puffin/main
tongyx361 Apr 3, 2025
a11ced5
Merge branch 'main' into gm-tyx/puffin/main
tongyx361 Apr 3, 2025
54 changes: 54 additions & 0 deletions .github/workflows/e2e_gsm8k_dapo.yml
@@ -0,0 +1,54 @@
name: e2e_gsm8k_dapo

on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
- v0.2.x
paths:
- "**/*.py"
- .github/workflows/e2e_gsm8k_dapo.yml
pull_request:
branches:
- main
- v0.2.x
paths:
- "**/*.py"
- "verl/trainer/config/*.yaml"
- .github/workflows/e2e_gsm8k_dapo.yml
- "tests/e2e/*.sh"

# Declare permissions just read content.
permissions:
contents: read

jobs:
e2e_gsm8k_dapo:
runs-on: [self-hosted, l20-1]
timeout-minutes: 40 # Increase this timeout value as needed
env:
HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
NO_PROXY: "localhost,127.0.0.1"
HF_HUB_ENABLE_HF_TRANSFER: 1
container:
image: verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3
options: --gpus all --shm-size=10g
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 0
- name: Install the current repository
run: |
pip3 install hf_transfer
pip3 install -e .[test,gpu]
- name: Prepare gsm8k dataset
run: |
ray stop --force
python3 examples/data_preprocess/gsm8k.py
- name: Running gsm8k e2e with dapo alg
run: |
ray stop --force
bash tests/e2e/run_qwen_gsm8k_dapo.sh
2 changes: 1 addition & 1 deletion .github/workflows/e2e_gsm8k_prime.yml
@@ -25,7 +25,7 @@ permissions:
contents: read

jobs:
e2e_gsm8k:
e2e_gsm8k_prime:
runs-on: [self-hosted, l20-1]
timeout-minutes: 40 # Increase this timeout value as needed
env:
7 changes: 5 additions & 2 deletions examples/split_placement/main_ppo_split.py
@@ -36,7 +36,7 @@ def __init__(self, tokenizer, num_examine) -> None:
self.tokenizer = tokenizer
self.num_examine = num_examine # the number of batches of decoded responses to print to the console

def __call__(self, data: DataProto):
def __call__(self, data: DataProto, return_dict: bool = False):
"""We will expand this function gradually based on the available datasets"""

# If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
@@ -81,7 +81,10 @@ def __call__(self, data: DataProto):
already_print_data_sources[data_source] += 1
print(sequences_str)

return reward_tensor
if return_dict:
return {"reward_tensor": reward_tensor}
else:
return reward_tensor


import ray
163 changes: 163 additions & 0 deletions recipe/dapo/README.md
@@ -0,0 +1,163 @@
# DAPO Open-Source Implementation

> Open-Source Algorithm Implementation & Experiment Running: [Yuxuan Tong](https://tongyx361.github.io/), [Guangming Sheng](https://hk.linkedin.com/in/guangming-sheng-b50640211)

> [!IMPORTANT]
> **🔥 News!!!**
> - [2025/03] We published the training record of [an early version of DAPO (w/o Token-level PG Loss & Dynamic Sampling)](./run_dapo_early_qwen2.5_32b.sh), achieving 44% on AIME 2024, in [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl).

🏠 [Homepage](https://dapo-sia.github.io/) | 📝 [Paper](https://dapo-sia.github.io/static/pdf/dapo_paper.pdf) | 🤗 [Datasets&Models@HF](https://huggingface.co/collections/BytedTsinghua-SIA/dapo-67d7f1517ee33c8aed059da0) | 🐱 [Code@GitHub](https://github.com/volcengine/verl/tree/gm-tyx/puffin/main/recipe/dapo) | 🐱 [Repo@GitHub](https://github.com/BytedTsinghua-SIA/DAPO)

> We propose the **D**ecoupled Clip and Dynamic s**A**mpling **P**olicy **O**ptimization (DAPO) algorithm. By making our work publicly available, we provide the broader research community and society with practical access to scalable reinforcement learning, enabling all to benefit from these advancements. Our system is based on the awesome [verl](https://github.com/volcengine/verl) framework. Thanks for their great work! Training the Qwen2.5-32B base model with DAPO outperforms the previous state-of-the-art DeepSeek-R1-Zero-Qwen-32B on AIME 2024, achieving **50%** accuracy with **50%** fewer training steps.
>
> ![dapo-main-result](https://dapo-sia.github.io/static/images/score.png)

## Quickstart

1. Prepare the datasets **on the Ray cluster**:

```bash
bash prepare_dapo_data.sh # This downloads the datasets to ${HOME}/verl/data by default
```

2. Submit the job to the Ray cluster **from any machine**:

```bash
cd verl # Repo root
export RAY_ADDRESS="http://${RAY_IP:-localhost}:8265" # The Ray cluster address to connect to
export WORKING_DIR="${PWD}" # The local directory to package to the Ray cluster
# Set the runtime environment like env vars and pip packages for the Ray cluster in yaml
export RUNTIME_ENV="./verl/trainer/runtime_env.yaml"
bash recipe/dapo/run_dapo_qwen2.5_32b.sh
```

## Reproduction Runs

| Setup | AIME 2024 Acc. | Training Script | Training Record |
|-------|----------------------|-----------------|-----------------|
| DAPO w/o Token-level PG Loss & Dynamic Sampling | 44% | [run_dapo_early_qwen2.5_32b.sh](./run_dapo_early_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl) |
| DAPO | 50% | [run_dapo_qwen2.5_32b.sh](./run_dapo_qwen2.5_32b.sh) | W&B (Coming soon) |

## Configuration

> [!NOTE]
> Most experiments in the paper, including the best-performing one, were run without Overlong Filtering, because its effect largely overlaps with Overlong Reward Shaping in terms of properly learning from the longest outputs. So we don't implement it here.

### Separated Clip Epsilons (-> Clip-Higher)

An example configuration:

```yaml
actor_rollout_ref:
actor:
clip_ratio_low: 0.2
clip_ratio_high: 0.28
```

`clip_ratio_low` and `clip_ratio_high` specify the $\varepsilon_{\text {low }}$ and $\varepsilon_{\text {high }}$ in the DAPO objective.

Core relevant code:

```python
pg_losses1 = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
pg_losses = torch.maximum(pg_losses1, pg_losses2)
```
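The asymmetry can be checked in isolation. Below is a minimal, self-contained sketch of the decoupled clip range, mirroring the core code above; the `ratio` and `advantages` values and epsilons are illustrative, not taken from a real rollout.

```python
import torch

cliprange_low, cliprange_high = 0.2, 0.28

ratio = torch.tensor([0.5, 1.0, 1.5])       # pi_theta / pi_theta_old, illustrative
advantages = torch.tensor([1.0, 1.0, 1.0])  # positive advantages

pg_losses1 = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
pg_losses = torch.maximum(pg_losses1, pg_losses2)

# With positive advantages, the raised upper clip (1 + 0.28) lets the ratio
# grow to 1.28 before clipping, so the loss for ratio=1.5 is about -1.28
# instead of -1.5, while the lower clip stays at the usual 1 - 0.2.
print(pg_losses.tolist())
```

With a symmetric epsilon of 0.2, the last entry would be clipped at 1.2; Clip-Higher raises only the upper bound, which is what encourages exploration on low-probability tokens.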

### Dynamic Sampling (with Group Filtering)

An example configuration:

```yaml
data:
gen_batch_size: 1536
train_batch_size: 512
algorithm:
filter_groups:
enable: True
metric: acc # score / seq_reward / seq_final_reward / ...
max_num_gen_batches: 10 # Non-positive values mean no upper limit
```

Setting `filter_groups.enable` to `True` will filter out groups whose outputs' `metric` values are all the same, e.g., for `acc`, groups whose outputs are all correct (accuracy 1) or all incorrect (accuracy 0).

The trainer will repeat sampling with `gen_batch_size` until there are enough qualified groups to fill `train_batch_size`, or until the upper limit specified by `max_num_gen_batches` is reached.

Core relevant code:

```python
prompt_bsz = self.config.data.train_batch_size
if num_prompt_in_batch < prompt_bsz:
print(f'{num_prompt_in_batch=} < {prompt_bsz=}')
num_gen_batches += 1
max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
print(f'{num_gen_batches=} < {max_num_gen_batches=}. Keep generating...')
continue
else:
raise ValueError(
f'{num_gen_batches=} >= {max_num_gen_batches=}. Generated too many. Please check your data.'
)
else:
# Align the batch
traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
batch = batch[:traj_bsz]
```
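The filtering rule itself can be sketched in plain Python. This is a hedged stand-in for the actual `DataProto`-based implementation; the function and variable names here are illustrative, but the rule follows the description above, including the fix that single-sample groups are never filtered.

```python
from collections import defaultdict

def filter_group_uids(sample_uids, metric_values):
    """Return the uids of prompt groups that survive group filtering."""
    groups = defaultdict(list)
    for uid, value in zip(sample_uids, metric_values):
        groups[uid].append(value)
    # A group where every output scores the same (e.g. acc all 0 or all 1)
    # yields zero advantage under GRPO, so it is dropped; single-sample
    # groups are kept rather than filtered.
    return {uid for uid, vals in groups.items()
            if len(vals) == 1 or len(set(vals)) > 1}

uids = ["p0", "p0", "p1", "p1", "p2", "p2"]
accs = [ 1.0,  1.0,  0.0,  1.0,  0.0,  0.0]
print(sorted(filter_group_uids(uids, accs)))  # only p1 has mixed accuracies
```

Groups `p0` (all correct) and `p2` (all incorrect) are dropped; only `p1`, whose outputs disagree, contributes a nonzero learning signal.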

### Flexible Loss Aggregation Mode (-> Token-level Policy Gradient Loss)

An example configuration:

```yaml
actor_rollout_ref:
actor:
loss_agg_mode: "token-mean" # / "seq-mean-token-sum" / "seq-mean-token-mean"
# NOTE: "token-mean" is the default behavior
```

Setting `loss_agg_mode` to `token-mean` averages the (policy gradient) loss across all tokens of all sequences in a mini-batch, so each token contributes equally regardless of sequence length.

Core relevant code:

```python
if loss_agg_mode == "token-mean":
pg_loss = verl_F.masked_mean(pg_losses, eos_mask)
elif loss_agg_mode == "seq-mean-token-sum":
    pg_loss = torch.sum(pg_losses * eos_mask, dim=-1)
    pg_loss = torch.mean(pg_loss)
elif loss_agg_mode == "seq-mean-token-mean":
pg_loss = torch.sum(pg_losses * eos_mask, dim=-1) / torch.sum(eos_mask, dim=-1)
pg_loss = torch.mean(pg_loss)
else:
raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")
```
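The difference between the modes shows up with variable-length sequences. Here is a small self-contained sketch on a toy mini-batch of two sequences with 4 and 2 valid tokens; the numbers are illustrative.

```python
import torch

pg_losses = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                          [2.0, 2.0, 0.0, 0.0]])
eos_mask  = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                          [1.0, 1.0, 0.0, 0.0]])  # second sequence: 2 valid tokens

# token-mean: average over all 6 valid tokens -> (4*1 + 2*2) / 6
token_mean = (pg_losses * eos_mask).sum() / eos_mask.sum()

# seq-mean-token-mean: per-sequence token average, then average over sequences
per_seq = (pg_losses * eos_mask).sum(dim=-1) / eos_mask.sum(dim=-1)
seq_mean = per_seq.mean()

# token-mean weights every token equally, so long sequences contribute more
# to the batch loss; seq-mean weights every sequence equally.
print(token_mean.item(), seq_mean.item())
```

Here `token-mean` gives 8/6 ≈ 1.33 while `seq-mean-token-mean` gives (1 + 2)/2 = 1.5, which is why the choice matters when long incorrect samples dominate a batch.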

### Overlong Reward Shaping

An example configuration:

```yaml
data:
max_response_length: 20480 # 16384 + 4096
reward_model:
overlong_buffer:
enable: True
len: 4096
penalty_factor: 1.0
```

Setting `overlong_buffer.enable` to `True` penalizes outputs that are overlong but still within the hard context limit.

Specifically, the penalty increases linearly from `0` to `overlong_buffer.penalty_factor` when the length of the output exceeds the `max_response_length` by `0` to `overlong_buffer.len` tokens.

Core relevant code:

```python
if self.overlong_buffer_cfg.enable:
overlong_buffer_len = self.overlong_buffer_cfg.len
expected_len = self.max_resp_len - overlong_buffer_len
exceed_len = valid_response_length - expected_len
overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor
overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0)
reward += overlong_reward
```
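A worked example of the linear ramp, using the numbers from the example config above (`max_response_length: 20480`, `overlong_buffer.len: 4096`, so the unpenalized budget is 16384 tokens). The standalone function below is an illustrative restatement of the core code, not the actual reward-manager method.

```python
max_resp_len = 20480
overlong_buffer_len = 4096
overlong_penalty_factor = 1.0

def overlong_penalty(valid_response_length: int) -> float:
    expected_len = max_resp_len - overlong_buffer_len  # 16384-token budget
    exceed_len = valid_response_length - expected_len
    # Linear ramp from 0 at 16384 tokens down to -penalty_factor at the
    # hard limit of 20480 tokens; shorter responses are not penalized.
    return min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0.0)

print(overlong_penalty(16000))  # within budget -> 0.0
print(overlong_penalty(18432))  # halfway into the buffer -> -0.5
print(overlong_penalty(20480))  # at the hard limit -> -1.0
```

The penalty is added to the verifier reward, so a correct but overlong answer is discounted smoothly rather than truncated or zeroed outright.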
12 changes: 12 additions & 0 deletions recipe/dapo/prepare_dapo_data.sh
@@ -0,0 +1,12 @@
#!/usr/bin/env bash
set -uxo pipefail

export VERL_HOME=${VERL_HOME:-"${HOME}/verl"}
export TRAIN_FILE=${TRAIN_FILE:-"${VERL_HOME}/data/dapo-math-17k.parquet"}
export TEST_FILE=${TEST_FILE:-"${VERL_HOME}/data/aime-2024.parquet"}

mkdir -p "${VERL_HOME}/data"

wget -O "${TRAIN_FILE}" "https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k/resolve/main/data/dapo-math-17k.parquet?download=true"

wget -O "${TEST_FILE}" "https://huggingface.co/datasets/BytedTsinghua-SIA/AIME-2024/resolve/main/data/aime-2024.parquet?download=true"
131 changes: 131 additions & 0 deletions recipe/dapo/run_dapo_early_qwen2.5_32b.sh
@@ -0,0 +1,131 @@
#!/usr/bin/env bash
set -euxo pipefail

project_name='DAPO'
exp_name='DAPO-Early-Qwen2.5-32B'

adv_estimator=grpo

use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=0.28

max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 20))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0

# An early version for DAPO
loss_agg_mode="seq-mean-token-sum"

enable_filter_groups=False
gen_prompt_bsz=512 # NOTE: no filtering here
train_prompt_bsz=512
train_prompt_mini_bsz=32
n_resp_per_prompt=16

# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-16}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}

# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7


# Performance Related Parameter
sp_size=8
use_dynamic_bsz=True
actor_ppo_max_token_len=$((max_prompt_length + max_response_length))
infer_ppo_max_token_len=$((max_prompt_length + max_response_length))
offload=True
gen_tp=4

ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
--working-dir "${WORKING_DIR}" \
-- python3 -m recipe.dapo.src.main_dapo \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.gen_batch_size=${gen_prompt_bsz} \
data.train_batch_size=${train_prompt_bsz} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
algorithm.filter_groups.enable=${enable_filter_groups} \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
+actor_rollout_ref.model.override_config.attention_dropout=0. \
+actor_rollout_ref.model.override_config.embd_pdrop=0. \
+actor_rollout_ref.model.override_config.resid_pdrop=0. \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k="${top_k}" \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
reward_model.reward_manager=dapo \
reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
reward_model.overlong_buffer.len=${overlong_buffer_len} \
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
trainer.logger=['console','wandb'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.n_gpus_per_node=8 \
trainer.nnodes="${NNODES}" \
trainer.val_before_train=True \
trainer.test_freq=5 \
trainer.save_freq=5 \
trainer.total_epochs=1 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto