12 changes: 3 additions & 9 deletions examples/split_placement/split_monkey_patch.py
@@ -23,15 +23,7 @@
 import torch
 
 from verl import DataProto
-from verl.trainer.ppo.ray_trainer import (
-    AdvantageEstimator,
-    _timer,
-    apply_kl_penalty,
-    compute_advantage,
-    compute_data_metrics,
-    compute_timing_metrics,
-    reduce_metrics,
-)
+from verl.trainer.ppo.ray_trainer import AdvantageEstimator, _timer, apply_kl_penalty, calc_mini_batch_loss_token_nums, compute_advantage, compute_data_metrics, compute_timing_metrics, reduce_metrics
 
 
 def fit(self):
Expand Down Expand Up @@ -115,6 +107,8 @@ def fit(self):
# compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

batch.meta_info["mini_batch_loss_token_nums"] = calc_mini_batch_loss_token_nums(batch, traj_mini_bsz=self.config.actor_rollout_ref.actor.ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n, num_dp_ranks=self.actor_rollout_wg.world_size)

# recompute old_log_probs
with _timer("old_log_prob", timing_raw):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
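The body of the new helper is not part of this diff. As a minimal sketch of what it plausibly computes (the function below is an assumption for illustration, not the verl source): group the batch's trajectories into PPO mini-batches and return, for each one, the total number of valid response tokens, so that every data-parallel rank can later normalize its token-level loss by the same global count.

```python
import torch

def calc_mini_batch_loss_token_nums_sketch(attention_mask, response_length,
                                           traj_mini_bsz, num_dp_ranks):
    """Hypothetical stand-in for verl's new helper: per-mini-batch counts of
    valid loss tokens (loss is assumed to cover response tokens only)."""
    # Response tokens occupy the trailing span of the attention mask.
    tokens_per_traj = attention_mask[:, -response_length:].sum(dim=-1)
    assert tokens_per_traj.shape[0] % traj_mini_bsz == 0
    # Assumed role of num_dp_ranks: mini-batches must shard evenly across ranks.
    assert traj_mini_bsz % num_dp_ranks == 0
    num_mini = tokens_per_traj.shape[0] // traj_mini_bsz
    return [
        int(tokens_per_traj[i * traj_mini_bsz:(i + 1) * traj_mini_bsz].sum())
        for i in range(num_mini)
    ]

# Toy check: 4 trajectories whose responses are the last 2 positions.
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0], [1, 1, 1, 1], [1, 0, 0, 0]])
print(calc_mini_batch_loss_token_nums_sketch(
    mask, response_length=2, traj_mini_bsz=2, num_dp_ranks=1))  # [1, 2]
```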
4 changes: 3 additions & 1 deletion recipe/dapo/src/dapo_ray_trainer.py
@@ -32,7 +32,7 @@
     compute_timing_metrics,
     reduce_metrics,
 )
-from verl.trainer.ppo.ray_trainer import AdvantageEstimator, RayPPOTrainer, _timer, apply_kl_penalty, compute_advantage
+from verl.trainer.ppo.ray_trainer import AdvantageEstimator, RayPPOTrainer, _timer, apply_kl_penalty, calc_mini_batch_loss_token_nums, compute_advantage
 
 
 class RayDAPOTrainer(RayPPOTrainer):
@@ -216,6 +216,8 @@ def fit(self):
                 # compute global_valid tokens
                 batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
 
+                batch.meta_info["mini_batch_loss_token_nums"] = calc_mini_batch_loss_token_nums(batch, traj_mini_bsz=self.config.actor_rollout_ref.actor.ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n, num_dp_ranks=self.actor_rollout_wg.world_size)
+
                 # recompute old_log_probs
                 with _timer("old_log_prob", timing_raw):
                     old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
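Note how traj_mini_bsz is derived at these call sites: ppo_mini_batch_size is configured in prompts, while the token bookkeeping is per trajectory, so the prompt count is multiplied by the number of rollout samples per prompt. With illustrative numbers (not from this PR):

```python
# Illustrative values only; the real ones come from the trainer config.
ppo_mini_batch_size = 32  # prompts per PPO mini-batch (actor.ppo_mini_batch_size)
rollout_n = 4             # responses sampled per prompt (rollout.n)

traj_mini_bsz = ppo_mini_batch_size * rollout_n
assert traj_mini_bsz == 128  # each mini-batch holds 128 trajectories
```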
45 changes: 24 additions & 21 deletions recipe/prime/prime_dp_rm.py
@@ -26,7 +26,7 @@
 import verl.utils.torch_functional as verl_F
 from verl import DataProto
 from verl.utils.py_functional import append_to_dict
-from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
+from verl.utils.seqlen_balancing import get_reverse_idx, get_uniform_data_chunks
 from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs
 
 from .prime_core_algos import compute_ce_dpo_loss_rm, compute_detach_dpo_loss_rm
@@ -170,22 +170,24 @@ def compute_rm_score(self, data: DataProto):
         self.ref_module.eval()
         micro_batch_size = data.meta_info["micro_batch_size"]
         select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "acc"]
-        batch = data.select(batch_keys=select_keys).batch
+        selected_data = data.select(batch_keys=select_keys)
         use_dynamic_bsz = data.meta_info["use_dynamic_bsz"]
-        prompt_length = data.batch["input_ids"].shape[-1] - data.batch["responses"].shape[-1]
+        batch = selected_data.batch
+        prompt_length = batch["input_ids"].shape[-1] - batch["responses"].shape[-1]
 
         if use_dynamic_bsz:
             # split using dynamic bsz
             max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size
-            micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)
+            micro_data_chunks, indices = get_uniform_data_chunks(data=selected_data, max_token_len=max_token_len)
         else:
-            micro_batches = batch.split(micro_batch_size)
+            num_micro_batches = len(selected_data) // micro_batch_size
+            micro_data_chunks = selected_data.chunk(num_micro_batches)
 
         rm_scores_lst = []
         q_lst = []
-        for micro_batch in micro_batches:
+        for micro_data_chunk in micro_data_chunks:
             with torch.no_grad():
-                rm_score, q = self._forward_micro_batch(micro_batch, prompt_length)
+                rm_score, q = self._forward_micro_batch(micro_batch=micro_data_chunk.batch, prompt_length=prompt_length)
             rm_scores_lst.append(rm_score)
             q_lst.append(q)
         rm_scores = torch.concat(rm_scores_lst, dim=0)
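get_uniform_data_chunks is assumed here to follow the same contract as the rearrange_micro_batches call it replaces: chunks come back in a length-balanced (i.e. permuted) order together with `indices`, and downstream code restores the original row order via get_reverse_idx, which stays imported above. A self-contained toy of that revert pattern (the local get_reverse_idx mirrors the assumed contract of verl.utils.seqlen_balancing.get_reverse_idx):

```python
import itertools
import torch

def get_reverse_idx(idx):
    # Inverse permutation: reverse[idx[i]] = i.
    reverse = [0] * len(idx)
    for pos, original in enumerate(idx):
        reverse[original] = pos
    return reverse

scores_in = torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0])
indices = [[3, 0], [4, 1, 2]]                 # chunk -> original row ids
flat = list(itertools.chain.from_iterable(indices))
permuted = scores_in[torch.tensor(flat)]      # the order the model saw
revert = torch.tensor(get_reverse_idx(flat), dtype=torch.long)
assert torch.equal(permuted[revert], scores_in)
```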
@@ -221,37 +223,38 @@ def update_rm(self, data: DataProto):
             if key in data.batch.keys():
                 select_keys.append(key)
 
-        batch = data.select(batch_keys=select_keys).batch
+        selected_data = data.select(batch_keys=select_keys)
         # Split to make minibatch iterator for updating the actor
         # See PPO paper for details. https://arxiv.org/abs/1707.06347
-        dataloader = batch.split(self.config.mini_batch_size)
+        num_mini_batches = len(selected_data) // self.config.mini_batch_size
+        mini_dataloader = selected_data.chunk(num_mini_batches)
 
         rm_scores_lst = []
         q_lst = []
 
-        for batch_idx, data in enumerate(dataloader):
-            # split batch into micro_batches
-            mini_batch = data
+        for mini_idx, mini_data_chunk in enumerate(mini_dataloader):
             if self.config.use_dynamic_bsz:
                 max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
-                micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
+                micro_data_chunks, _ = get_uniform_data_chunks(data=mini_data_chunk, max_token_len=max_token_len)
             else:
-                micro_batches = mini_batch.split(self.config.micro_batch_size_per_gpu)
                 self.gradient_accumulation = self.config.mini_batch_size // self.config.micro_batch_size_per_gpu
+                num_micro_batches = len(mini_data_chunk) // self.config.micro_batch_size_per_gpu
+                micro_data_chunks = mini_data_chunk.chunk(num_micro_batches)
 
             self.reward_optimizer.zero_grad()
 
-            for data in micro_batches:
-                data = data.cuda()
-                attention_mask = data["attention_mask"]
-                acc = data["acc"]
+            for micro_data_chunk in micro_data_chunks:
+                micro_batch = micro_data_chunk.batch
+                micro_batch = micro_batch.cuda()
+                attention_mask = micro_batch["attention_mask"]
+                acc = micro_batch["acc"]
 
-                prompt_ids = data["prompts"]
+                prompt_ids = micro_batch["prompts"]
                 prompt_length = prompt_ids.shape[-1]
 
                 response_mask = attention_mask[:, prompt_length:]
 
-                rm_score, q = self._forward_micro_batch(data, prompt_length)
+                rm_score, q = self._forward_micro_batch(micro_batch=micro_batch, prompt_length=prompt_length)
 
                 rm_scores_lst.append(rm_score)
                 q_lst.append(q.detach())
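The switch from TensorDict.split to DataProto.chunk also flips the argument's meaning: split takes a chunk size, while chunk takes a chunk count, which is why the diff computes num_mini_batches and num_micro_batches first. A toy list version of the two conventions (DataProto.chunk is assumed to behave like the count-based variant):

```python
data = list(range(12))
mini_batch_size = 4

# Size-based, like TensorDict.split(mini_batch_size):
by_size = [data[i:i + mini_batch_size] for i in range(0, len(data), mini_batch_size)]

# Count-based, like DataProto.chunk(num_chunks):
num_chunks = len(data) // mini_batch_size
by_count = [data[i * mini_batch_size:(i + 1) * mini_batch_size] for i in range(num_chunks)]

# Identical whenever the batch size divides evenly, as the // above assumes.
assert by_size == by_count
```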
@@ -289,7 +292,7 @@ def update_rm(self, data: DataProto):
 
                 if self.config.use_dynamic_bsz:
                     # relative to the dynamic bsz
-                    loss = dpo_loss * (len(data) / self.config.ppo_mini_batch_size)
+                    loss = dpo_loss * (len(micro_data_chunk) / self.config.ppo_mini_batch_size)
                 else:
                     loss = dpo_loss / self.gradient_accumulation
 
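Both loss-scaling branches above implement the same idea: recover a mini-batch-level mean from per-micro-batch losses. With fixed micro-batches every chunk carries equal weight 1/gradient_accumulation; with dynamic batching chunks differ in size, so each is weighted by its share len(chunk)/mini_batch_size. A small numeric check (illustrative values only):

```python
# Per-sample losses for one mini-batch of 8 samples (made-up numbers).
samples = [2.0, 4.0, 6.0, 1.0, 3.0, 5.0, 7.0, 8.0]
n = len(samples)

# Dynamic batching produced uneven micro-batches of 5 and 3 samples.
chunks = [samples[:5], samples[5:]]

# Weight each micro-batch mean by its size share of the mini-batch.
weighted = sum((sum(c) / len(c)) * (len(c) / n) for c in chunks)

assert abs(weighted - sum(samples) / n) < 1e-9  # matches the true mini-batch mean
```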
4 changes: 3 additions & 1 deletion recipe/prime/prime_ray_trainer.py
@@ -30,7 +30,7 @@
 from verl.single_controller.ray import RayWorkerGroup
 from verl.trainer.ppo.core_algos import agg_loss
 from verl.trainer.ppo.metric_utils import _compute_response_info
-from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role, WorkerType, _timer, reduce_metrics
+from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role, WorkerType, _timer, calc_mini_batch_loss_token_nums, reduce_metrics
 from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
 from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
 
@@ -379,6 +379,8 @@ def fit(self):
                 # compute global_valid tokens
                 batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
 
+                batch.meta_info["mini_batch_loss_token_nums"] = calc_mini_batch_loss_token_nums(batch, traj_mini_bsz=self.config.actor_rollout_ref.actor.ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n, num_dp_ranks=self.actor_rollout_wg.world_size)
+
                 # verify
                 with _timer("verify", timing_raw):
                     scores = self.reward_fn.verify(batch)
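Why ship a global token count in meta_info at all? If each DP rank normalizes by its own local token count, token-mean losses depend on how sequences happen to be sharded; normalizing every rank by the shared global count makes the DDP-averaged gradient match the true global token mean. A toy simulation of that identity (assumed rationale, not code from this PR):

```python
import torch

torch.manual_seed(0)
num_dp_ranks = 2
per_token_loss = torch.rand(num_dp_ranks, 16)         # [rank, tokens]
mask = (torch.rand(num_dp_ranks, 16) > 0.3).float()   # valid loss tokens

global_tokens = mask.sum()  # what mini_batch_loss_token_nums would carry

# Each rank scales its local token-loss sum by num_dp_ranks / global_tokens;
# DDP then averages across ranks, recovering the exact global token mean.
rank_losses = [
    (per_token_loss[r] * mask[r]).sum() * num_dp_ranks / global_tokens
    for r in range(num_dp_ranks)
]
ddp_mean = torch.stack(rank_losses).mean()
true_mean = (per_token_loss * mask).sum() / global_tokens
assert torch.allclose(ddp_mean, true_mean)
```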