88 changes: 88 additions & 0 deletions recipe/puffin/run_puffin_qwen2.5_7b_test.sh
@@ -0,0 +1,88 @@
#!/usr/bin/env bash
set -euxo pipefail

project_name='puffin'
exp_name='Qwen2.5-7B-Puffin-Test'

# Paths
MODEL_PATH=${MODEL_PATH:-"${HOME}/verl/models/Qwen2.5-7B"}
CKPTS_DIR=${CKPTS_DIR:-"${HOME}/verl/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${HOME}/verl/data/puffin_train.parquet"}
TEST_FILE=${TEST_FILE:-"${HOME}/verl/data/puffin_test.parquet"}

# Algorithm
## Train
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
# TODO
# force_append_eos=True
## Validation
val_top_k=-1 # 0 for HF rollout, -1 for vLLM rollout


# Dynamic token-based batching (mathematically equivalent to fixed micro-batch sizes)
use_dynamic_bsz=True
infer_micro_batch_size=null
train_micro_batch_size=null
offload=True

python3 -m verl.trainer.main_ppo \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=512 \
actor_rollout_ref.rollout.n=16 \
actor_rollout_ref.actor.kl_loss_coef=0 \
actor_rollout_ref.actor.clip_ratio=0.2 \
algorithm.adv_estimator=grpo \
algorithm.kl_ctrl.kl_coef=0.0 \
algorithm.gamma=1.0 \
algorithm.lam=0.95 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.model.path="${MODEL_PATH}" \
+actor_rollout_ref.model.override_config.attention_dropout=0. \
+actor_rollout_ref.model.override_config.embd_pdrop=0. \
+actor_rollout_ref.model.override_config.resid_pdrop=0. \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
+actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=512 \
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.val_kwargs.top_k="${val_top_k}" \
actor_rollout_ref.rollout.val_kwargs.top_p=0.7 \
actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \
actor_rollout_ref.rollout.val_kwargs.n=32 \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \
actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
trainer.logger=['console','wandb'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
+trainer.val_before_train=True \
trainer.test_freq=2 \
trainer.save_freq=2 \
trainer.total_epochs=5000 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=disable
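A note on the caps above: max_prompt_length resolves to 2 * 1024 = 2048 tokens and max_response_length to 8 * 1024 = 8192, so each *_max_token_len_per_gpu setting comes out to 2048 + 8192 = 10240 tokens per GPU. MODEL_PATH, CKPTS_DIR, TRAIN_FILE, and TEST_FILE fall back to defaults under ${HOME}/verl but honor environment overrides, so the script can be pointed at other checkpoints or datasets without editing it.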
1 change: 1 addition & 0 deletions verl/trainer/config/ppo_megatron_trainer.yaml
@@ -3,6 +3,7 @@ data:
  train_files: ~/data/rlhf/gsm8k/train.parquet
  val_files: ~/data/rlhf/gsm8k/test.parquet
  prompt_key: prompt
  reward_fn_key: data_source
  max_prompt_length: 512
  max_response_length: 512
  train_batch_size: 1024
1 change: 1 addition & 0 deletions verl/trainer/config/ppo_trainer.yaml
@@ -3,6 +3,7 @@ data:
  train_files: ~/data/rlhf/gsm8k/train.parquet
  val_files: ~/data/rlhf/gsm8k/test.parquet
  prompt_key: prompt
  reward_fn_key: data_source
  max_prompt_length: 512
  max_response_length: 512
  train_batch_size: 1024
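Both trainer configs now default reward_fn_key to data_source, so every dataset row must carry that field for the reward manager to route it to a scorer. Below is a hedged sketch of a compatible training row; everything beyond prompt and data_source (the reward_model/ground_truth layout, the file name) is an assumption modeled on verl's example datasets, not taken from this diff.

import pandas as pd

# One illustrative row tagged for the math_puffin scorer added below.
# The reward_model/ground_truth layout is an assumption.
row = {
    "prompt": [{"role": "user", "content": "Compute 2 + 2."}],  # chat-style prompt (data.prompt_key)
    "data_source": "math_puffin",                               # looked up via data.reward_fn_key
    "reward_model": {"style": "rule", "ground_truth": "4"},     # consumed by the score function
}
pd.DataFrame([row]).to_parquet("puffin_train.parquet")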
4 changes: 2 additions & 2 deletions verl/trainer/main_ppo.py
@@ -146,10 +146,10 @@ def main_task(config):
        raise NotImplementedError

    compute_score = get_custom_reward_fn(config)
    reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score)
    reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key)

    # Note that we always use function-based RM for validation
    val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score)
    val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key)

    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

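For context, a minimal sketch of how a reward manager might consume the new reward_fn_key argument; verl's real manager classes differ in structure, and the sample dict below is illustrative, not the actual batch format.

from verl.utils.reward_score import _default_compute_score

class RoutingRewardManager:
    """Sketch only: route each sample to a scorer by its dataset tag."""

    def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key="data_source"):
        self.tokenizer = tokenizer
        self.num_examine = num_examine
        self.compute_score = compute_score or _default_compute_score
        # Name of the dataset field that tags each sample's source.
        self.reward_fn_key = reward_fn_key

    def score(self, sample: dict) -> float:
        # Route on the configured key instead of a hard-coded 'data_source'.
        return self.compute_score(
            data_source=sample[self.reward_fn_key],
            solution_str=sample["response_str"],
            ground_truth=sample["ground_truth"],
        )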
3 changes: 3 additions & 0 deletions verl/utils/reward_score/__init__.py
@@ -25,6 +25,9 @@ def _default_compute_score(data_source, solution_str, ground_truth, extra_info=None):
        # Use Math-Verify (https://github.com/huggingface/Math-Verify) for better evaluation accuracy
        from . import math_verify
        res = math_verify.compute_score(solution_str, ground_truth)
    elif data_source == 'math_puffin':
        from . import math_puffin
        res = math_puffin.compute_score(solution_str, ground_truth)
    elif data_source in [
            'numina_aops_forum', 'numina_synthetic_math', 'numina_amc_aime', 'numina_synthetic_amc', 'numina_cn_k12',
            'numina_olympiads'
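The math_puffin module itself is not part of this diff; only its call signature, compute_score(solution_str, ground_truth), is pinned down by the dispatch above (mirroring the math_verify branch). A hypothetical sketch, with the boxed-answer matching rule purely an assumption:

import re

def compute_score(solution_str: str, ground_truth: str) -> float:
    """Hypothetical scorer: 1.0 when the last \\boxed{...} answer matches, else 0.0."""
    matches = re.findall(r"\\boxed\{([^{}]*)\}", solution_str)
    if not matches:
        return 0.0
    return 1.0 if matches[-1].strip() == ground_truth.strip() else 0.0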