Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions recipe/puffin/run_puffin_7b_test_ray.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ export TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/puffin_test.parquet"}
## Train
max_prompt_length=$((1024 * 1))
max_response_length=$((1024 * 3))
gen_prompt_bsz=512
train_prompt_bsz=512
train_prompt_mini_bsz=32
## Validation
val_top_k=-1 # 0 for HF rollout, -1 for vLLM rollout

Expand All @@ -38,16 +41,18 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=512 \
data.gen_batch_size=${gen_prompt_bsz} \
data.train_batch_size=${train_prompt_bsz} \
data.truncation='left' \
actor_rollout_ref.rollout.n=16 \
actor_rollout_ref.actor.kl_loss_coef=0 \
actor_rollout_ref.actor.clip_ratio_low=0.2 \
actor_rollout_ref.actor.clip_ratio_high=0.25 \
algorithm.adv_estimator=grpo \
algorithm.kl_ctrl.kl_coef=0.0 \
algorithm.gamma=1.0 \
algorithm.lam=0.95 \
algorithm.filter_groups.enable=True \
algorithm.filter_groups.fill_train_batch=True \
algorithm.filter_groups.drop_last_mini_batch=True \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
Expand All @@ -63,7 +68,7 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=512 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
Expand All @@ -84,6 +89,8 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \
actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
custom_reward_function.overlong_buffer.len=512 \
custom_reward_function.overlong_buffer.penalty_factor=1.0 \
trainer.logger=['console','wandb'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
Expand Down
35 changes: 35 additions & 0 deletions verl/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,41 @@ def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=Non

return DataProto(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info)

def sel_idxs(self, idxs):
    """Select specific row indices from the DataProto.

    Args:
        idxs (torch.Tensor, numpy.ndarray, or list): Integer indices of the
            rows to keep, in the order they should appear in the result.

    Returns:
        DataProto: A new DataProto containing only the selected rows. Note
        that ``meta_info`` is shared with (not copied from) the original.
    """
    if isinstance(idxs, list):
        # Advanced indexing requires long (int64) index tensors; int32
        # index tensors are rejected by many torch versions.
        idxs = torch.tensor(idxs, dtype=torch.long)

    # Keep both a numpy and a torch view of the indices so the tensor batch
    # and the numpy-backed non-tensor batch are each indexed with the
    # matching array type.
    if isinstance(idxs, np.ndarray):
        idxs_np = idxs
        idxs_torch = torch.from_numpy(idxs).long()
    else:  # torch.Tensor
        idxs_torch = idxs.long()
        idxs_np = idxs_torch.detach().cpu().numpy()

    if self.batch is not None:
        # Use TensorDict's built-in indexing capabilities; the new batch
        # size is the number of selected rows.
        selected_batch = TensorDict(
            source={key: tensor[idxs_torch] for key, tensor in self.batch.items()},
            batch_size=(idxs_torch.shape[0],))
    else:
        selected_batch = None

    selected_non_tensor = {key: val[idxs_np] for key, val in self.non_tensor_batch.items()}

    return DataProto(batch=selected_batch, non_tensor_batch=selected_non_tensor, meta_info=self.meta_info)

def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> 'DataProto':
"""Pop a subset of the DataProto via `batch_keys` and `meta_info_keys`

Expand Down
11 changes: 10 additions & 1 deletion verl/trainer/config/ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ data:
reward_fn_key: data_source
max_prompt_length: 512
max_response_length: 512
train_batch_size: 1024
gen_batch_size: 1024
train_batch_size: ${data.gen_batch_size}
val_batch_size: null # DEPRECATED: Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves
return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
return_raw_chat: False
Expand Down Expand Up @@ -167,6 +168,10 @@ reward_model:
custom_reward_function:
path: null
name: compute_score
overlong_buffer:
len: 0
penalty_factor: 1.0
log: False

algorithm:
gamma: 1.0
Expand All @@ -176,6 +181,10 @@ algorithm:
kl_ctrl:
type: fixed
kl_coef: 0.001
filter_groups:
enable: False
fill_train_batch: True
drop_last_mini_batch: True

trainer:
balance_batch: True
Expand Down
8 changes: 6 additions & 2 deletions verl/trainer/main_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,14 @@ def main_task(config):
raise NotImplementedError

compute_score = get_custom_reward_fn(config)
reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key)
reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key,
max_resp_len=config.data.max_response_length,
overlong_buffer_cfg=config.custom_reward_function.overlong_buffer)

# Note that we always use function-based RM for validation
val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key)
val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key,
max_resp_len=config.data.max_response_length,
overlong_buffer_cfg=config.custom_reward_function.overlong_buffer)

resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

Expand Down
106 changes: 80 additions & 26 deletions verl/trainer/ppo/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,10 @@ def _validate_config(self):
# number of GPUs total
n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes

if not config.algorithm.filter_groups.enable:
assert config.data.train_batch_size == config.data.gen_batch_size, \
f"train_batch_size must be equal to gen_batch_size when filter_groups.enable is False, but got {config.data.train_batch_size =} and {config.data.gen_batch_size =}"

# 1. Check total batch size for data correctness
real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
assert real_train_batch_size % n_gpus == 0, \
Expand Down Expand Up @@ -424,7 +428,7 @@ def _create_dataloader(self):
sampler = SequentialSampler(data_source=self.train_dataset)

self.train_dataloader = StatefulDataLoader(dataset=self.train_dataset,
batch_size=self.config.data.train_batch_size,
batch_size=self.config.data.gen_batch_size,
num_workers=8,
drop_last=True,
collate_fn=collate_fn,
Expand Down Expand Up @@ -573,17 +577,17 @@ def _validate(self):
prompt = sample_inputs[sample_idx]

var2vals = data_src2prompt2var2vals[data_source][prompt]
var2vals["reward_sum"].append(sample_scores[sample_idx])
var2vals["final_reward"].append(sample_scores[sample_idx])
for metric_name, metric_vals in reward_extra_infos_dict.items():
var2vals[metric_name].append(metric_vals[sample_idx])

data_src2prompt2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
for data_source, prompt2var2vals in data_src2prompt2var2vals.items():
for prompt, var2vals in prompt2var2vals.items():
n_resps = len(var2vals["reward_sum"])
n_resps = len(var2vals["final_reward"])
preds = var2vals["pred"]
for var_name, var_vals in var2vals.items():
if var_name in ["pred", "reward_sum"]:
if var_name in ["pred", "final_reward"]:
continue
metric = {}

Expand Down Expand Up @@ -617,7 +621,7 @@ def _validate(self):
data_src2var2metric2prompt_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
for data_source, prompt2var2metric in data_src2prompt2var2metric.items():
for prompt, var2metric in prompt2var2metric.items():
for metric_name, metric in var2metric.items():
for var_name, metric in var2metric.items():
for metric_name, metric_val in metric.items():
data_src2var2metric2prompt_vals[data_source][var_name][metric_name].append(metric_val)

Expand Down Expand Up @@ -881,6 +885,77 @@ def fit(self):
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)

with _timer('reward', timing_raw):
# compute scores. Support both model and function-based.
# We first compute the scores using reward model. Then, we call reward_fn to combine
# the results from reward model and rule-based results.
if self.use_rm:
# we first compute reward model score
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)

# we combine with rule-based rm
reward_tensor = self.reward_fn(batch)
batch.batch['token_level_scores'] = reward_tensor

# compute rewards. apply_kl_penalty if available
if not self.config.actor_rollout_ref.actor.get('use_kl_loss', False):
batch, kl_metrics = apply_kl_penalty(batch,
kl_ctrl=self.kl_ctrl,
kl_penalty=self.config.algorithm.kl_penalty)
metrics.update(kl_metrics)
else:
batch.batch['token_level_rewards'] = batch.batch['token_level_scores']

if self.config.algorithm.filter_groups.enable:
filter_metric_dict = {}

uid2seq_rewards = defaultdict(list)
for uid, tok_rewards in zip(batch.non_tensor_batch['uid'], batch.batch['token_level_rewards']):
seq_reward = torch.sum(tok_rewards).item()
uid2seq_rewards[uid].append(seq_reward)

uid2seq_reward_std = {}
for uid, seq_rewards in uid2seq_rewards.items():
uid2seq_reward_std[uid] = np.std(seq_rewards)

kept_uids = [uid for uid, std in uid2seq_reward_std.items() if std > 0]
filter_metric_dict["non_uni_rew_prompt_ratio"] = len(kept_uids) / len(uid2seq_rewards)
filter_metric_dict["non_uni_rew_prompt_bsz"] = len(kept_uids)

kept_idxs = []


train_prompt_bsz = len(batch.batch)
fill_train_batch = self.config.algorithm.filter_groups.fill_train_batch
if len(kept_uids) > train_prompt_bsz or not fill_train_batch:
kept_uids = kept_uids[:train_prompt_bsz]
else:
for uid in uid2seq_reward_std.keys():
if uid not in kept_uids:
kept_uids.append(uid)
if len(kept_uids) == train_prompt_bsz:
break

for idx, uid in enumerate(batch.non_tensor_batch['uid']):
if uid in kept_uids:
kept_idxs.append(idx)
filter_metric_dict["non_uni_rew_traj_bsz"] = len(kept_idxs)

world_size = self.actor_rollout_wg.world_size
kept_idxs = kept_idxs[:len(kept_idxs) // world_size * world_size]
if self.config.algorithm.filter_groups.drop_last_mini_batch:
train_traj_mini_bsz = self.config.actor_rollout_ref.actor.ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n
if len(kept_idxs) > train_traj_mini_bsz:
kept_idxs = kept_idxs[:len(kept_idxs) // train_traj_mini_bsz * train_traj_mini_bsz]
else:
print(f'[WARNING] {len(kept_idxs)=} < {train_traj_mini_bsz=}')

filter_metric_dict["final_traj_ratio"] = len(kept_idxs) / len(batch.batch)
filter_metric_dict["final_traj_bsz"] = len(kept_idxs)

batch = batch.sel_idxs(kept_idxs)

# balance the number of valid tokens on each dp rank.
# Note that this breaks the order of data inside the batch.
# Please take care when you implement group based adv computation such as GRPO and rloo
Expand Down Expand Up @@ -908,27 +983,6 @@ def fit(self):
batch = batch.union(values)

with _timer('adv', timing_raw):
# compute scores. Support both model and function-based.
# We first compute the scores using reward model. Then, we call reward_fn to combine
# the results from reward model and rule-based results.
if self.use_rm:
# we first compute reward model score
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)

# we combine with rule-based rm
reward_tensor = self.reward_fn(batch)
batch.batch['token_level_scores'] = reward_tensor

# compute rewards. apply_kl_penalty if available
if not self.config.actor_rollout_ref.actor.get('use_kl_loss', False):
batch, kl_metrics = apply_kl_penalty(batch,
kl_ctrl=self.kl_ctrl,
kl_penalty=self.config.algorithm.kl_penalty)
metrics.update(kl_metrics)
else:
batch.batch['token_level_rewards'] = batch.batch['token_level_scores']

# compute advantages, executed on the driver process
batch = compute_advantage(batch,
adv_estimator=self.config.algorithm.adv_estimator,
Expand Down
25 changes: 22 additions & 3 deletions verl/workers/reward_manager/naive.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,16 @@ class NaiveRewardManager:
"""The reward manager.
"""

def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key='data_source') -> None:
def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key='data_source', max_resp_len=None, overlong_buffer_cfg=None) -> None:
    """Initialize the reward manager.

    Args:
        tokenizer: Tokenizer used to decode responses for console printing.
        num_examine: Number of batches of decoded responses to print to the console.
        compute_score: Optional scoring callable; falls back to ``_default_compute_score``.
        reward_fn_key: Key in the batch identifying the data source.
        max_resp_len: Maximum response length; required when ``overlong_buffer_cfg`` is set,
            since the overlong penalty is computed relative to it.
        overlong_buffer_cfg: Optional config with ``len``/``penalty_factor``/``log`` fields
            controlling the soft overlong-length penalty; ``None`` disables it.
    """
    self.tokenizer = tokenizer
    self.num_examine = num_examine  # the number of batches of decoded responses to print to the console
    self.compute_score = compute_score or _default_compute_score
    self.reward_fn_key = reward_fn_key
    self.overlong_buffer_cfg = overlong_buffer_cfg
    self.max_resp_len = max_resp_len

    if self.overlong_buffer_cfg is not None:
        # The overlong penalty needs the response-length budget to compute the buffer window.
        assert self.max_resp_len is not None, \
            f"max_resp_len must be provided when overlong_buffer_cfg is set ({overlong_buffer_cfg=}), but got None"

# TODO: Is this still necessary in algorithms other than PRIME?
def verify(self, data):
Expand Down Expand Up @@ -116,18 +121,32 @@ def __call__(self, data: DataProto, return_dict: bool = False):
extra_info=extra_info,
)

final_reward = 0

reward: float
if isinstance(result, dict):
assert "reward" in result
reward = result["reward"]
else:
reward = result

reward_tensor[i, valid_response_length - 1] = reward

for key, value in result.items():
reward_extra_info[key].append(value)

final_reward += reward

overlong_buffer_len = self.overlong_buffer_cfg.len
if overlong_buffer_len > 0:
overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor
exceed_len = valid_response_length - (self.max_resp_len - overlong_buffer_len)
overlong_reward = max(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0)
final_reward += overlong_reward
if self.overlong_buffer_cfg.log:
reward_extra_info["overlong_reward"].append(overlong_reward)
reward_extra_info["overlong"].append(overlong_reward < 0)

reward_tensor[i, valid_response_length - 1] = final_reward

if data_source not in already_print_data_sources:
already_print_data_sources[data_source] = 0

Expand Down