Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions recipe/puffin/run_puffin_7b_test_ray.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ export TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/puffin_test.parquet"}
## Train
max_prompt_length=$((1024 * 1))
max_response_length=$((1024 * 3))
gen_prompt_bsz=512
train_prompt_bsz=512
train_prompt_mini_bsz=32
## Validation
val_top_k=-1 # 0 for HF rollout, -1 for vLLM rollout

Expand All @@ -38,16 +41,18 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=512 \
data.gen_batch_size=${gen_prompt_bsz} \
data.train_batch_size=${train_prompt_bsz} \
data.truncation='left' \
actor_rollout_ref.rollout.n=16 \
actor_rollout_ref.actor.kl_loss_coef=0 \
actor_rollout_ref.actor.clip_ratio_low=0.2 \
actor_rollout_ref.actor.clip_ratio_high=0.25 \
algorithm.adv_estimator=grpo \
algorithm.kl_ctrl.kl_coef=0.0 \
algorithm.gamma=1.0 \
algorithm.lam=0.95 \
algorithm.filter_groups.enable=True \
algorithm.filter_groups.fill_train_batch=True \
algorithm.filter_groups.drop_last_mini_batch=True \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
Expand All @@ -63,7 +68,7 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=512 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
Expand Down
35 changes: 35 additions & 0 deletions verl/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,41 @@ def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=Non

return DataProto(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info)

def sel_idxs(self, idxs):
    """
    Select specific row indices from this DataProto.

    Args:
        idxs (torch.Tensor or numpy.ndarray or list): Indices of the rows
            to keep, in the order they should appear in the result.

    Returns:
        DataProto: A new DataProto containing only the selected rows.
            ``meta_info`` is shared by reference with the original
            (consistent with ``select``), not copied.
    """
    if isinstance(idxs, list):
        # torch.long (int64) is the canonical dtype for advanced indexing;
        # it also matches the dtype torch.from_numpy produces for the
        # numpy-input path below, keeping all three input kinds consistent.
        idxs = torch.tensor(idxs, dtype=torch.long)

    # Keep both a torch and a numpy view of the indices: TensorDict entries
    # are indexed with the torch tensor, non-tensor (object) arrays with numpy.
    if isinstance(idxs, np.ndarray):
        idxs_np = idxs
        idxs_torch = torch.from_numpy(idxs)
    else:  # torch.Tensor
        idxs_torch = idxs
        idxs_np = idxs.detach().cpu().numpy()

    if self.batch is not None:
        # Index every tensor along the leading (batch) dimension; the new
        # TensorDict's batch_size is the number of selected rows.
        selected_batch = TensorDict(
            source={key: tensor[idxs_torch] for key, tensor in self.batch.items()},
            batch_size=(idxs_torch.shape[0],),
        )
    else:
        selected_batch = None

    # numpy fancy indexing selects the same rows from the object arrays.
    selected_non_tensor = {key: val[idxs_np] for key, val in self.non_tensor_batch.items()}

    return DataProto(batch=selected_batch, non_tensor_batch=selected_non_tensor, meta_info=self.meta_info)

def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> 'DataProto':
"""Pop a subset of the DataProto via `batch_keys` and `meta_info_keys`

Expand Down
7 changes: 6 additions & 1 deletion verl/trainer/config/ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ data:
reward_fn_key: data_source
max_prompt_length: 512
max_response_length: 512
train_batch_size: 1024
gen_batch_size: 1024
train_batch_size: ${data.gen_batch_size}
val_batch_size: null # DEPRECATED: Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves
return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
return_raw_chat: False
Expand Down Expand Up @@ -179,6 +180,10 @@ algorithm:
kl_ctrl:
type: fixed
kl_coef: 0.001
filter_groups:
enable: False
fill_train_batch: True
drop_last_mini_batch: True

trainer:
balance_batch: True
Expand Down
98 changes: 76 additions & 22 deletions verl/trainer/ppo/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,10 @@ def _validate_config(self):
# number of GPUs total
n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes

if not config.algorithm.filter_groups.enable:
assert config.data.train_batch_size == config.data.gen_batch_size, \
f"train_batch_size must be equal to gen_batch_size when filter_groups.enable is False, but got {config.data.train_batch_size =} and {config.data.gen_batch_size =}"

# 1. Check total batch size for data correctness
real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
assert real_train_batch_size % n_gpus == 0, \
Expand Down Expand Up @@ -424,7 +428,7 @@ def _create_dataloader(self):
sampler = SequentialSampler(data_source=self.train_dataset)

self.train_dataloader = StatefulDataLoader(dataset=self.train_dataset,
batch_size=self.config.data.train_batch_size,
batch_size=self.config.data.gen_batch_size,
num_workers=8,
drop_last=True,
collate_fn=collate_fn,
Expand Down Expand Up @@ -881,6 +885,77 @@ def fit(self):
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)

with _timer('reward', timing_raw):
# compute scores. Support both model and function-based.
# We first compute the scores using reward model. Then, we call reward_fn to combine
# the results from reward model and rule-based results.
if self.use_rm:
# we first compute reward model score
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)

# we combine with rule-based rm
reward_tensor = self.reward_fn(batch)
batch.batch['token_level_scores'] = reward_tensor

# compute rewards. apply_kl_penalty if available
if not self.config.actor_rollout_ref.actor.get('use_kl_loss', False):
batch, kl_metrics = apply_kl_penalty(batch,
kl_ctrl=self.kl_ctrl,
kl_penalty=self.config.algorithm.kl_penalty)
metrics.update(kl_metrics)
else:
batch.batch['token_level_rewards'] = batch.batch['token_level_scores']

if self.config.algorithm.filter_groups.enable:
filter_metric_dict = {}

uid2seq_rewards = defaultdict(list)
for uid, tok_rewards in zip(batch.non_tensor_batch['uid'], batch.batch['token_level_rewards']):
seq_reward = torch.sum(tok_rewards).item()
uid2seq_rewards[uid].append(seq_reward)

uid2seq_reward_std = {}
for uid, seq_rewards in uid2seq_rewards.items():
uid2seq_reward_std[uid] = np.std(seq_rewards)

kept_uids = [uid for uid, std in uid2seq_reward_std.items() if std > 0]
filter_metric_dict["non_uni_rew_prompt_ratio"] = len(kept_uids) / len(uid2seq_rewards)
filter_metric_dict["non_uni_rew_prompt_bsz"] = len(kept_uids)

kept_idxs = []


train_prompt_bsz = len(batch.batch)
fill_train_batch = self.config.algorithm.filter_groups.fill_train_batch
if len(kept_uids) > train_prompt_bsz or not fill_train_batch:
kept_uids = kept_uids[:train_prompt_bsz]
else:
for uid in uid2seq_reward_std.keys():
if uid not in kept_uids:
kept_uids.append(uid)
if len(kept_uids) == train_prompt_bsz:
break

for idx, uid in enumerate(batch.non_tensor_batch['uid']):
if uid in kept_uids:
kept_idxs.append(idx)
filter_metric_dict["non_uni_rew_traj_bsz"] = len(kept_idxs)

world_size = self.actor_rollout_wg.world_size
kept_idxs = kept_idxs[:len(kept_idxs) // world_size * world_size]
if self.config.algorithm.filter_groups.drop_last_mini_batch:
train_traj_mini_bsz = self.config.actor_rollout_ref.actor.ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n
if len(kept_idxs) > train_traj_mini_bsz:
kept_idxs = kept_idxs[:len(kept_idxs) // train_traj_mini_bsz * train_traj_mini_bsz]
else:
print(f'[WARNING] {len(kept_idxs)=} < {train_traj_mini_bsz=}')

filter_metric_dict["final_traj_ratio"] = len(kept_idxs) / len(batch.batch)
filter_metric_dict["final_traj_bsz"] = len(kept_idxs)

batch = batch.sel_idxs(kept_idxs)

# balance the number of valid tokens on each dp rank.
# Note that this breaks the order of data inside the batch.
# Please take care when you implement group based adv computation such as GRPO and rloo
Expand Down Expand Up @@ -908,27 +983,6 @@ def fit(self):
batch = batch.union(values)

with _timer('adv', timing_raw):
# compute scores. Support both model and function-based.
# We first compute the scores using reward model. Then, we call reward_fn to combine
# the results from reward model and rule-based results.
if self.use_rm:
# we first compute reward model score
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)

# we combine with rule-based rm
reward_tensor = self.reward_fn(batch)
batch.batch['token_level_scores'] = reward_tensor

# compute rewards. apply_kl_penalty if available
if not self.config.actor_rollout_ref.actor.get('use_kl_loss', False):
batch, kl_metrics = apply_kl_penalty(batch,
kl_ctrl=self.kl_ctrl,
kl_penalty=self.config.algorithm.kl_penalty)
metrics.update(kl_metrics)
else:
batch.batch['token_level_rewards'] = batch.batch['token_level_scores']

# compute advantages, executed on the driver process
batch = compute_advantage(batch,
adv_estimator=self.config.algorithm.adv_estimator,
Expand Down