17 changes: 11 additions & 6 deletions recipe/prime/prime_core_algos.py
@@ -13,13 +13,18 @@
 # limitations under the License.

 import torch
+from omegaconf import DictConfig

 import verl
 import verl.utils.torch_functional as verl_F
+from verl.trainer.ppo.ray_trainer import compute_response_mask


-def compute_rloo_advantage_return(data: verl.DataProto, response_mask: torch.Tensor, n_samples, config):
+def compute_rloo_advantage_return(data: verl.DataProto, config: DictConfig):
     # calculate rloo reward on different reward sources, and sum again
+    action_mask = compute_response_mask(data)
+    n_samples = config.actor_rollout_ref.rollout.n
+
     def masked_rloo(reward_tensor_original, mask_tensor):
         reward_tensor = reward_tensor_original.clone()
         reward_tensor[~mask_tensor] = 0
@@ -39,13 +44,13 @@ def masked_rloo(reward_tensor_original, mask_tensor):
     with torch.no_grad():
         if "rm_scores" in data.batch.keys() and config.algorithm.reward_dpo_coef != 0.0:
             reward_tensor = data.batch["rm_scores"]
-            reward_mask = response_mask.bool()
+            reward_mask = action_mask.bool()

             reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_dpo_coef)

         if "acc" in data.batch.keys() and config.algorithm.reward_gt_coef != 0.0:
-            reward_tensor = torch.zeros_like(response_mask, dtype=torch.float32)
-            reward_mask = torch.zeros_like(response_mask, dtype=torch.bool)
+            reward_tensor = torch.zeros_like(action_mask, dtype=torch.float32)
+            reward_mask = torch.zeros_like(action_mask, dtype=torch.bool)

             prompt_ids = data.batch["prompts"]
             prompt_length = prompt_ids.shape[-1]
@@ -64,10 +69,10 @@ def masked_rloo(reward_tensor_original, mask_tensor):

     final_reward_tensor = sum(reward_tensors)

-    returns = (final_reward_tensor * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
+    returns = (final_reward_tensor * action_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])

     advantages = returns.clone()
-    advantages = verl_F.masked_whiten(advantages, response_mask)
+    advantages = verl_F.masked_whiten(advantages, action_mask)

     return advantages, returns
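For readers unfamiliar with RLOO, the `masked_rloo` helper above implements a leave-one-out baseline: each response's reward is compared against the mean reward of the other `n_samples - 1` responses drawn for the same prompt. A minimal standalone sketch of that baseline, with illustrative names and samples for each prompt grouped contiguously (not the recipe's exact masked implementation):

```python
import torch

def rloo_advantage(rewards: torch.Tensor, n_samples: int) -> torch.Tensor:
    """Leave-one-out advantage: reward minus the mean reward of the other samples."""
    grouped = rewards.reshape(-1, n_samples)            # (num_prompts, n_samples)
    group_sum = grouped.sum(dim=-1, keepdim=True)       # (num_prompts, 1)
    baseline = (group_sum - grouped) / (n_samples - 1)  # mean of the other n-1 rewards
    return (grouped - baseline).reshape(-1)

rewards = torch.tensor([1.0, 0.0, 0.0, 1.0, 1.0, 1.0])  # two prompts, three samples each
print(rloo_advantage(rewards, n_samples=3))  # tensor([ 1.0, -0.5, -0.5, 0.0, 0.0, 0.0])
```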
54 changes: 27 additions & 27 deletions recipe/prime/prime_dp_rm.py
@@ -26,7 +26,7 @@
 import verl.utils.torch_functional as verl_F
 from verl import DataProto
 from verl.utils.py_functional import append_to_dict
-from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
+from verl.utils.seqlen_balancing import get_reverse_idx, get_uniform_data_chunks
 from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs

 from .prime_core_algos import compute_ce_dpo_loss_rm, compute_detach_dpo_loss_rm
@@ -170,22 +170,24 @@ def compute_rm_score(self, data: DataProto):
         self.ref_module.eval()
         micro_batch_size = data.meta_info["micro_batch_size"]
         select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "acc"]
-        batch = data.select(batch_keys=select_keys).batch
+        selected_data = data.select(batch_keys=select_keys)
         use_dynamic_bsz = data.meta_info["use_dynamic_bsz"]
-        prompt_length = data.batch["input_ids"].shape[-1] - data.batch["responses"].shape[-1]
+        batch = selected_data.batch
+        prompt_length = batch["input_ids"].shape[-1] - batch["responses"].shape[-1]

         if use_dynamic_bsz:
             # split using dynamic bsz
             max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size
-            micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)
+            micro_data_chunks, indices = get_uniform_data_chunks(data=selected_data, max_token_len=max_token_len)
         else:
-            micro_batches = batch.split(micro_batch_size)
+            num_micro_batches = len(selected_data) // micro_batch_size
+            micro_data_chunks = selected_data.chunk(num_micro_batches)

         rm_scores_lst = []
         q_lst = []
-        for micro_batch in micro_batches:
+        for micro_data_chunk in micro_data_chunks:
             with torch.no_grad():
-                rm_score, q = self._forward_micro_batch(micro_batch, prompt_length)
+                rm_score, q = self._forward_micro_batch(micro_batch=micro_data_chunk.batch, prompt_length=prompt_length)
             rm_scores_lst.append(rm_score)
             q_lst.append(q)
         rm_scores = torch.concat(rm_scores_lst, dim=0)
@@ -221,37 +223,35 @@ def update_rm(self, data: DataProto):
             if key in data.batch.keys():
                 select_keys.append(key)

-        batch = data.select(batch_keys=select_keys).batch
+        selected_data = data.select(batch_keys=select_keys)
         # Split to make minibatch iterator for updating the actor
         # See PPO paper for details. https://arxiv.org/abs/1707.06347
-        dataloader = batch.split(self.config.mini_batch_size)
+        mini_dataloader = selected_data.split(split_size=self.config.mini_batch_size)

         rm_scores_lst = []
         q_lst = []

-        for batch_idx, data in enumerate(dataloader):
-            # split batch into micro_batches
-            mini_batch = data
+        for mini_idx, mini_data_chunk in enumerate(mini_dataloader):
             if self.config.use_dynamic_bsz:
                 max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
-                micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
+                micro_data_chunks, _ = get_uniform_data_chunks(data=mini_data_chunk, max_token_len=max_token_len)
             else:
-                micro_batches = mini_batch.split(self.config.micro_batch_size_per_gpu)
-                self.gradient_accumulation = self.config.mini_batch_size // self.config.micro_batch_size_per_gpu
+                micro_data_chunks = mini_data_chunk.split(split_size=self.config.micro_batch_size_per_gpu)

             self.reward_optimizer.zero_grad()

-            for data in micro_batches:
-                data = data.cuda()
-                attention_mask = data["attention_mask"]
-                acc = data["acc"]
+            for micro_data_chunk in micro_data_chunks:
+                micro_batch = micro_data_chunk.batch
+                micro_batch = micro_batch.cuda()
+                attention_mask = micro_batch["attention_mask"]
+                acc = micro_batch["acc"]

-                prompt_ids = data["prompts"]
+                prompt_ids = micro_batch["prompts"]
                 prompt_length = prompt_ids.shape[-1]

                 response_mask = attention_mask[:, prompt_length:]

-                rm_score, q = self._forward_micro_batch(data, prompt_length)
+                rm_score, q = self._forward_micro_batch(micro_batch=micro_batch, prompt_length=prompt_length)

                 rm_scores_lst.append(rm_score)
                 q_lst.append(q.detach())
@@ -285,21 +285,21 @@ def update_rm(self, data: DataProto):
                 else:
                     raise NotImplementedError

-                data = {"reward_model/dpo_loss": dpo_loss.detach().item()}
+                micro_metric_data = {"reward_model/dpo_loss": dpo_loss.detach().item()}

                 if self.config.use_dynamic_bsz:
                     # relative to the dynamic bsz
-                    loss = dpo_loss * (len(data) / self.config.ppo_mini_batch_size)
+                    loss = dpo_loss * (len(micro_data_chunk) / self.config.ppo_mini_batch_size)
                 else:
-                    loss = dpo_loss / self.gradient_accumulation
+                    loss = dpo_loss / len(micro_data_chunks)

                 loss.backward()

-                append_to_dict(metrics, data)
+                append_to_dict(metrics, micro_metric_data)

             grad_norm = self._optimizer_step()
-            data = {"reward_model/grad_norm": grad_norm.detach().item()}
-            append_to_dict(metrics, data)
+            metric_data = {"reward_model/grad_norm": grad_norm.detach().item()}
+            append_to_dict(metrics, metric_data)
             self.reward_optimizer.zero_grad()

         rm_scores = torch.cat(rm_scores_lst, dim=0)
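The loss-scaling change at the end of `update_rm` deserves a second look. With fixed micro batches, the old `self.gradient_accumulation` divisor equaled the number of chunks per mini batch, so dividing by `len(micro_data_chunks)` is equivalent; under dynamic batch sizing, each chunk's mean loss is instead weighted by its share of the mini batch. A standalone sketch with made-up sizes showing that both schemes accumulate to a per-mini-batch mean:

```python
mini_batch_size = 8
per_sample_loss = 1.0  # pretend every sample contributes the same mean loss

# Fixed-size micro batches: 4 chunks of 2 samples each.
fixed_chunks = [2, 2, 2, 2]
fixed_total = sum(per_sample_loss / len(fixed_chunks) for _ in fixed_chunks)

# Dynamic micro batches: chunk sizes vary with sequence lengths.
dynamic_chunks = [3, 1, 4]
dynamic_total = sum(per_sample_loss * (size / mini_batch_size) for size in dynamic_chunks)

print(fixed_total, dynamic_total)  # 1.0 1.0 -- both accumulate a mini-batch mean
```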
15 changes: 2 additions & 13 deletions recipe/prime/prime_ray_trainer.py
@@ -30,7 +30,7 @@
 from verl.single_controller.ray import RayWorkerGroup
 from verl.trainer.ppo.core_algos import agg_loss
 from verl.trainer.ppo.metric_utils import _compute_response_info
-from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role, WorkerType, _timer, reduce_metrics
+from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role, WorkerType, _timer, compute_response_mask, reduce_metrics
 from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
 from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn

@@ -39,11 +39,7 @@

 def compute_advantage(data: DataProto, adv_estimator, config):
     if adv_estimator == "rloo":
-        responses = data.batch["responses"]
-        response_length = responses.size(-1)
-        attention_mask = data.batch["attention_mask"]
-        response_mask = attention_mask[:, -response_length:]
-        advantages, returns = prime_core_algos.compute_rloo_advantage_return(data, response_mask, config.actor_rollout_ref.rollout.n, config)
+        advantages, returns = prime_core_algos.compute_rloo_advantage_return(data, config=config)
         data.batch["advantages"] = advantages
         data.batch["returns"] = returns
     else:
Expand Down Expand Up @@ -110,13 +106,6 @@ def compute_data_metrics(batch, use_critic=True):
return metrics


def compute_response_mask(data: DataProto):
responses = data.batch["responses"]
response_length = responses.size(1)
attention_mask = data.batch["attention_mask"]
return attention_mask[:, -response_length:]


def compute_timing_metrics(batch, timing_raw):
response_info = _compute_response_info(batch)
num_prompt_tokens = torch.sum(response_info["prompt_length"]).item()
Expand Down
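The deleted local helper is not lost: `compute_response_mask` is now imported from `verl.trainer.ppo.ray_trainer`, and its body (visible in the removed lines above) just slices the response columns off the attention mask. A toy illustration of that slicing:

```python
import torch

# Toy shapes: prompt of length 3, response of length 3.
attention_mask = torch.tensor([[0, 1, 1, 1, 1, 0],
                               [1, 1, 1, 1, 1, 1]])
responses = torch.zeros(2, 3, dtype=torch.long)

response_length = responses.size(1)
response_mask = attention_mask[:, -response_length:]  # keep only response columns
print(response_mask)  # tensor([[1, 1, 0], [1, 1, 1]])
```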
5 changes: 3 additions & 2 deletions tests/utility/test_tensor_dict_utilities.py
@@ -276,15 +276,16 @@ def test_len():


 def test_seqlen_balancing():
-    from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
+    from verl.utils.seqlen_balancing import get_reverse_idx, get_uniform_data_chunks

     input_ids = torch.randint(low=0, high=10, size=(20, 100))
     from verl.utils.model import create_random_mask

     attention_mask = create_random_mask(input_ids=input_ids, max_ratio_of_left_padding=0.1, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.5)
     data = {"input_ids": input_ids, "attention_mask": attention_mask}
     dataproto = DataProto.from_single_dict(data)
-    micro_batches, micro_bsz_idx_lst = rearrange_micro_batches(dataproto.batch, max_token_len=300)
+    micro_data_chunks, micro_bsz_idx_lst = get_uniform_data_chunks(dataproto, max_token_len=300)
+    micro_batches = [micro_data_chunk.batch for micro_data_chunk in micro_data_chunks]
     batch = torch.cat(micro_batches)
     micro_bsz_idx = []
     for idx in micro_bsz_idx_lst:
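Past the visible hunk, the imported `get_reverse_idx` is evidently used to restore the original row order after concatenating the balanced chunks. The idea is a plain inverse permutation; a small standalone sketch of that idea (not the test's exact code):

```python
import torch

idx_map = [3, 0, 2, 1]            # original index of each concatenated row
reverse_idx = [0] * len(idx_map)
for pos, original in enumerate(idx_map):
    reverse_idx[original] = pos   # where each original row ended up

shuffled = torch.tensor([30, 0, 20, 10])
print(shuffled[torch.tensor(reverse_idx)])  # tensor([ 0, 10, 20, 30])
```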
5 changes: 5 additions & 0 deletions verl/protocol.py
@@ -616,6 +616,11 @@ def chunk(self, chunks: int) -> List["DataProto"]:

         return output

+    def split(self, split_size: int) -> List["DataProto"]:
+        """Split the DataProto into a list of DataProto."""
+        num_splits = -(-len(self) // split_size)  # Ceiling
+        return self.chunk(chunks=num_splits)
+
     @staticmethod
     def concat(data: List["DataProto"]) -> "DataProto":
         """Concat a list of DataProto. The batch is concatenated among dim=0.
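The new `split` method rides on `chunk`, using negative floor division as an overflow-safe integer ceiling so a non-divisible batch still lands every sample in some chunk; only the chunk count is pinned down here, while the per-chunk sizes follow whatever `chunk` produces. A quick check of the idiom with illustrative values:

```python
def ceil_div(a: int, b: int) -> int:
    # -(-a // b) == ceil(a / b) for positive ints, with no float rounding
    return -(-a // b)

print(ceil_div(20, 8))  # 3 -> a 20-sample DataProto with split_size=8 yields 3 chunks
print(ceil_div(24, 8))  # 3 -> exact division adds no extra chunk
print(ceil_div(1, 8))   # 1 -> even a tiny batch produces one chunk
```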
25 changes: 9 additions & 16 deletions verl/utils/seqlen_balancing.py
@@ -14,12 +14,13 @@

 import copy
 import heapq
-from typing import List, Tuple
+from typing import List, Optional, Tuple

 import torch
 from tensordict import TensorDict
 from torch import distributed as dist

+from verl import DataProto
+

 def karmarkar_karp(seqlen_list: List[int], k_partitions: int, equal_size: bool):
     # see: https://en.wikipedia.org/wiki/Largest_differencing_method
@@ -213,15 +214,15 @@ def ceildiv(a, b):
     return -(a // -b)


-def rearrange_micro_batches(batch: TensorDict, max_token_len, dp_group=None):
-    """Split the batch into a list of micro_batches, where the max_token_len is smaller than max_token_len
-    and the number of valid tokens in each micro batch is well balanced.
+def get_uniform_data_chunks(data: DataProto, max_token_len: int, dp_group: Optional[dist.ProcessGroup] = None) -> tuple[list[DataProto], list[list[int]]]:
+    """Split the data into a list of chunks, where each chunk's total token count stays within max_token_len
+    and the number of valid tokens per chunk is well balanced.
     """
     # this is per local micro_bsz
-    max_seq_len = batch["attention_mask"].shape[-1]
+    max_seq_len = data.batch["attention_mask"].shape[-1]
     assert max_token_len >= max_seq_len, f"max_token_len must be greater than the sequence length. Got {max_token_len=} and {max_seq_len=}"

-    seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1)
+    seq_len_effective: torch.Tensor = data.batch["attention_mask"].sum(dim=1)
     total_seqlen = seq_len_effective.sum().item()
     num_micro_batches = ceildiv(total_seqlen, max_token_len)
     if dist.is_initialized():
@@ -232,19 +233,11 @@ def rearrange_micro_batches(batch: TensorDict, max_token_len, dp_group=None):
     seq_len_effective = seq_len_effective.tolist()
     assert num_micro_batches <= len(seq_len_effective)

-    micro_bsz_idx = get_seqlen_balanced_partitions(seq_len_effective, num_micro_batches, equal_size=False)
-
-    micro_batches = []
-
-    for partition in micro_bsz_idx:
-        curr_micro_batch = []
-        for idx in partition:
-            curr_micro_batch.append(batch[idx : idx + 1])
-        curr_micro_batch = torch.cat(curr_micro_batch)
-
-        micro_batches.append(curr_micro_batch)
+    micro_idx_partitions = get_seqlen_balanced_partitions(seq_len_effective, num_micro_batches, equal_size=False)

-    return micro_batches, micro_bsz_idx
+    uniform_data_chunks = [data.select_idxs(partition) for partition in micro_idx_partitions]
+
+    return uniform_data_chunks, micro_idx_partitions


 def get_reverse_idx(idx_map):
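To tie the refactor together, here is a hedged usage sketch of the renamed helper on toy data. It assumes a single-process run (so the distributed all-reduce branch is skipped); shapes and values are made up:

```python
import torch
from verl import DataProto
from verl.utils.seqlen_balancing import get_uniform_data_chunks

# Six sequences of length 8 with varying numbers of valid tokens (30 total).
input_ids = torch.randint(0, 10, (6, 8))
attention_mask = torch.zeros(6, 8, dtype=torch.long)
for i, valid in enumerate([8, 2, 6, 3, 7, 4]):
    attention_mask[i, :valid] = 1

data = DataProto.from_single_dict({"input_ids": input_ids, "attention_mask": attention_mask})

# ceildiv(30, 16) = 2 chunks, each holding roughly half of the valid tokens.
chunks, idx_partitions = get_uniform_data_chunks(data, max_token_len=16)
for chunk, partition in zip(chunks, idx_partitions):
    print(partition, chunk.batch["attention_mask"].sum().item())
```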