Commit cd52d8b

[algo] feat: support GRPO algorithm (#124)
- Implement the KL loss, the GRPO outcome advantage, and best-of-n (BoN) rollouts.
- Provide scripts for DeepSeek and Qwen on GSM8K; more can be provided for other datasets.
- Support sequence balancing (seq balance).
- Training Qwen2-7B with this recipe reaches a GSM8K score of 0.89.

1 parent 5b90cd7

14 files changed: +314 -39 lines
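At a glance, the GRPO pieces added in this commit work as follows (an editor's paraphrase of the code in verl/trainer/ppo/core_algos.py and the new config keys below, not text from the commit itself). For each prompt, the rollout worker samples n responses (actor_rollout_ref.rollout.n, set to 5 in the example scripts), each scored with a single outcome reward r_i. The advantage broadcast to every valid token of response i is

    A_i = (r_i - mean(r_1, ..., r_n)) / (std(r_1, ..., r_n) + eps)

and, rather than folding a KL penalty into the reward, the actor adds a KL loss weighted by actor.kl_loss_coef, using the low-variance estimator exp(D) - D - 1 with D = log pi_ref - log pi_theta per token.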
New file (+39 lines)

@@ -0,0 +1,39 @@
set -x

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.train_batch_size=1024 \
    data.val_batch_size=1312 \
    data.max_prompt_length=512 \
    data.max_response_length=1024 \
    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size=128 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.grad_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size=256 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.ref.log_prob_micro_batch_size=256 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.critic_warmup=0 \
    trainer.logger=['console','wandb'] \
    trainer.project_name='verl_grpo_example_gsm8k' \
    trainer.experiment_name='deepseek_llm_7b_function_rm' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \
    trainer.test_freq=5 \
    trainer.total_epochs=15 $@
New file (+38 lines)

@@ -0,0 +1,38 @@
set -x

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.train_batch_size=1024 \
    data.val_batch_size=1312 \
    data.max_prompt_length=512 \
    data.max_response_length=512 \
    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.use_dynamic_bsz=True \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.grad_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.critic_warmup=0 \
    trainer.logger=['console','wandb'] \
    trainer.project_name='verl_grpo_example_gsm8k' \
    trainer.experiment_name='deepseek_llm_7b_function_rm_seq_packing' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \
    trainer.test_freq=5 \
    trainer.total_epochs=15 $@

examples/grpo_trainer/run_qwen2-7b.sh (+41 lines)

@@ -0,0 +1,41 @@
set -x

export VLLM_ATTENTION_BACKEND=XFORMERS

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.train_batch_size=1024 \
    data.val_batch_size=1312 \
    data.max_prompt_length=512 \
    data.max_response_length=1024 \
    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size=128 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.grad_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size=256 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.ref.log_prob_micro_batch_size=256 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.critic_warmup=0 \
    trainer.logger=['console','wandb'] \
    trainer.project_name='verl_grpo_example_gsm8k' \
    trainer.experiment_name='qwen2_7b_function_rm' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \
    trainer.test_freq=5 \
    trainer.total_epochs=15 $@
New file (+41 lines)

@@ -0,0 +1,41 @@
set -x

export VLLM_ATTENTION_BACKEND=XFORMERS

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.train_batch_size=1024 \
    data.val_batch_size=1312 \
    data.max_prompt_length=512 \
    data.max_response_length=1024 \
    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.use_dynamic_bsz=True \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.grad_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.critic_warmup=0 \
    trainer.logger=['console','wandb'] \
    trainer.project_name='verl_grpo_example_gsm8k' \
    trainer.experiment_name='qwen2_7b_function_rm_kl1e-3' \
    +trainer.val_before_train=False \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \
    trainer.test_freq=5 \
    trainer.total_epochs=15 $@

examples/ppo_trainer/run_qwen2-7b_rm_seq_balance.sh (+1 -1)

@@ -58,4 +58,4 @@ python3 -m verl.trainer.main_ppo \
     trainer.nnodes=1 \
     trainer.save_freq=-1 \
     trainer.test_freq=5 \
-    trainer.total_epochs=100 $@
+    trainer.total_epochs=15 $@

examples/ppo_trainer/run_qwen2-7b_seq_balance.sh (+1 -1)

@@ -49,4 +49,4 @@ python3 -m verl.trainer.main_ppo \
     trainer.nnodes=1 \
     trainer.save_freq=-1 \
     trainer.test_freq=5 \
-    trainer.total_epochs=100 $@
+    trainer.total_epochs=15 $@

tests/__init__.py (+13 lines)

@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

tests/e2e/arithmetic_sequence/rl/config/ray_trainer.yaml (+3)

@@ -28,6 +28,9 @@ actor_rollout_ref:
     grad_clip: 1.0
     clip_ratio: 0.2
     entropy_coeff: 0.0
+    use_kl_loss: False # True for GRPO
+    kl_loss_coef: 0.001 # for grpo
+    kl_loss_type: low_var_kl # for grpo
     ppo_epochs: 1
     shuffle: False
     ulysses_sequence_parallel_size: 1 # sp size

tests/e2e/check_results.py (+1 -1)

@@ -48,5 +48,5 @@ def extract_reward_from_line(line):
             best_reward = reward

     print(f'Best reward is {best_reward}')
-    assert best_reward > 0.2, f'Best reward must be greater than 0.3. best_reward: {best_reward}'
+    assert best_reward > 0.2, f'Best reward must be greater than 0.2. best_reward: {best_reward}'
     print('Check passes')

verl/trainer/config/ppo_trainer.yaml (+3)

@@ -27,6 +27,9 @@ actor_rollout_ref:
     grad_clip: 1.0
     clip_ratio: 0.2
     entropy_coeff: 0.001
+    use_kl_loss: False # True for GRPO
+    kl_loss_coef: 0.001 # for grpo
+    kl_loss_type: low_var_kl # for grpo
     ppo_epochs: 1
     shuffle: False
     ulysses_sequence_parallel_size: 1 # sp size

verl/trainer/ppo/core_algos.py (+57)

@@ -20,6 +20,7 @@

 import numpy as np
 import torch
+from collections import defaultdict

 import verl.utils.torch_functional as verl_F

@@ -106,6 +107,54 @@ def compute_gae_advantage_return(token_level_rewards: torch.Tensor, values: torc
     return advantages, returns


+# NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar.
+def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
+                                   eos_mask: torch.Tensor,
+                                   index: torch.Tensor,
+                                   epsilon: float = 1e-6):
+    """
+    Compute advantage for GRPO, operating only on Outcome reward
+    (with only one scalar reward for each response).
+    Args:
+        token_level_rewards: `(torch.Tensor)`
+            shape: (bs, response_length)
+        eos_mask: `(torch.Tensor)`
+            shape: (bs, response_length)
+
+    Returns:
+        advantages: `(torch.Tensor)`
+            shape: (bs, response_length)
+        Returns: `(torch.Tensor)`
+            shape: (bs, response_length)
+    """
+    response_length = token_level_rewards.shape[-1]
+    non_zero_mask = (token_level_rewards != 0)
+    scores = (token_level_rewards * non_zero_mask).sum(dim=-1)
+
+    id2score = defaultdict(list)
+    id2mean = {}
+    id2std = {}
+
+    with torch.no_grad():
+        bsz = scores.shape[0]
+        for i in range(bsz):
+            id2score[index[i]].append(scores[i])
+        for idx in id2score:
+            if len(id2score[idx]) == 1:
+                id2mean[idx] = torch.tensor(0.0)
+                id2std[idx] = torch.tensor(1.0)
+            elif len(id2score[idx]) > 1:
+                id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
+                id2std[idx] = torch.std(torch.tensor([id2score[idx]]))
+            else:
+                raise ValueError(f"no score in prompt index: {idx}")
+        for i in range(bsz):
+            scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
+        scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
+
+    return scores, scores
+
+
 def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
     kl = old_log_prob - ref_log_prob
     return token_level_scores - kl * kl_ratio
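As an illustration of the new estimator above (an editor's sketch, not part of the commit), the snippet below builds two prompt groups of two best-of-n responses each, places the scalar outcome reward on a response token, and calls compute_grpo_outcome_advantage. The tensor shapes follow the docstring; all numbers and the string group ids are made up.

import torch
from verl.trainer.ppo.core_algos import compute_grpo_outcome_advantage

# Two prompts ("p0", "p1") with two sampled responses each; response length 3.
# The scalar outcome reward sits on one token of each response (zero elsewhere).
token_level_rewards = torch.tensor([
    [0.0, 0.0, 1.0],   # p0, correct answer
    [0.0, 0.0, 0.0],   # p0, wrong answer
    [0.0, 1.0, 0.0],   # p1, correct, shorter response
    [0.0, 0.0, 0.0],   # p1, wrong answer
])
eos_mask = torch.tensor([
    [1., 1., 1.],
    [1., 1., 1.],
    [1., 1., 0.],      # last position is padding for this response
    [1., 1., 1.],
])
# Group ids: the function only indexes and hashes these, so plain strings work here.
index = ["p0", "p0", "p1", "p1"]

advantages, returns = compute_grpo_outcome_advantage(token_level_rewards, eos_mask, index)
print(advantages)
# Each row repeats the group-normalized scalar (r_i - group mean) / (group std + eps)
# across its tokens, zeroed where eos_mask is 0; `returns` is the same tensor.

With rollout.n=5 as in the example scripts, each group would instead contain five responses per prompt.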
@@ -210,6 +259,14 @@ def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_pe
     if kl_penalty == "mse":
         return 0.5 * (logprob - ref_logprob).square()

+    # J. Schulman. Approximating kl divergence, 2020.
+    # # URL http://joschu.net/blog/kl-approx.html.
+    if kl_penalty == 'low_var_kl':
+        kl = ref_logprob - logprob
+        ratio = torch.exp(kl)
+        kld = (ratio - kl - 1).contiguous()
+        return torch.clamp(kld, min=-10, max=10)
+
     if kl_penalty == "full":
         # so, here logprob and ref_logprob should contain the logits for every token in vocabulary
         raise NotImplementedError
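For completeness, here is a small, hypothetical check (again not part of the commit) of the low_var_kl branch added above. It compares the naive per-token estimator logprob - ref_logprob with the k3 estimator from the Schulman post cited in the comment, exp(D) - D - 1 with D = ref_logprob - logprob; the random log-probs are fabricated for illustration.

import torch
from verl.trainer.ppo.core_algos import kl_penalty

torch.manual_seed(0)

# Fabricated per-token log-probs for the actor and the reference policy.
logprob = torch.randn(4, 16) * 0.1 - 2.0
ref_logprob = logprob + torch.randn(4, 16) * 0.05

# Naive estimator: log pi_theta - log pi_ref (can be negative token by token).
k1 = logprob - ref_logprob

# New 'low_var_kl' branch: exp(D) - D - 1 with D = ref_logprob - logprob, clamped to [-10, 10].
k3 = kl_penalty(logprob, ref_logprob, 'low_var_kl')

print('k1 mean/std:', k1.mean().item(), k1.std().item())
print('k3 mean/std:', k3.mean().item(), k3.std().item())
print('k3 elementwise non-negative:', bool((k3 >= 0).all().item()))

Both estimators have the same expectation (the true KL) when tokens are sampled from the actor; the clamp in the new branch guards against rare outliers when the two policies diverge.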
