From 68bd27a85106cff04f3bb9b3eaffffa16ec2e5f3 Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Thu, 1 May 2025 04:12:50 +0800 Subject: [PATCH 1/6] feat: separate data, batch and metric with clear variable names --- recipe/prime/prime_dp_rm.py | 52 ++++++++++--------- verl/protocol.py | 5 ++ verl/utils/seqlen_balancing.py | 25 ++++----- verl/workers/actor/dp_actor.py | 89 ++++++++++++-------------------- verl/workers/critic/dp_critic.py | 85 +++++++++++++------------------ 5 files changed, 109 insertions(+), 147 deletions(-) diff --git a/recipe/prime/prime_dp_rm.py b/recipe/prime/prime_dp_rm.py index 59998ae54b2..cc2b39c6179 100644 --- a/recipe/prime/prime_dp_rm.py +++ b/recipe/prime/prime_dp_rm.py @@ -26,7 +26,7 @@ import verl.utils.torch_functional as verl_F from verl import DataProto from verl.utils.py_functional import append_to_dict -from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches +from verl.utils.seqlen_balancing import get_reverse_idx, get_uniform_data_chunks from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs from .prime_core_algos import compute_ce_dpo_loss_rm, compute_detach_dpo_loss_rm @@ -170,22 +170,24 @@ def compute_rm_score(self, data: DataProto): self.ref_module.eval() micro_batch_size = data.meta_info["micro_batch_size"] select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "acc"] - batch = data.select(batch_keys=select_keys).batch + selected_data = data.select(batch_keys=select_keys) use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] - prompt_length = data.batch["input_ids"].shape[-1] - data.batch["responses"].shape[-1] + batch = selected_data.batch + prompt_length = batch["input_ids"].shape[-1] - batch["responses"].shape[-1] if use_dynamic_bsz: # split using dynamic bsz max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size - micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) + micro_data_chunks, indices = get_uniform_data_chunks(data=selected_data, max_token_len=max_token_len) else: - micro_batches = batch.split(micro_batch_size) + num_micro_batches = len(selected_data) // micro_batch_size + micro_data_chunks = selected_data.chunk(num_micro_batches) rm_scores_lst = [] q_lst = [] - for micro_batch in micro_batches: + for micro_data_chunk in micro_data_chunks: with torch.no_grad(): - rm_score, q = self._forward_micro_batch(micro_batch, prompt_length) + rm_score, q = self._forward_micro_batch(micro_batch=micro_data_chunk.batch, prompt_length=prompt_length) rm_scores_lst.append(rm_score) q_lst.append(q) rm_scores = torch.concat(rm_scores_lst, dim=0) @@ -221,37 +223,37 @@ def update_rm(self, data: DataProto): if key in data.batch.keys(): select_keys.append(key) - batch = data.select(batch_keys=select_keys).batch + selected_data = data.select(batch_keys=select_keys) # Split to make minibatch iterator for updating the reward model # See PPO paper for details. 
https://arxiv.org/abs/1707.06347 - dataloader = batch.split(self.config.mini_batch_size) + mini_dataloader = selected_data.split(split_size=self.config.mini_batch_size) rm_scores_lst = [] q_lst = [] - for batch_idx, data in enumerate(dataloader): - # split batch into micro_batches - mini_batch = data + for mini_idx, mini_data_chunk in enumerate(mini_dataloader): if self.config.use_dynamic_bsz: max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) + micro_data_chunks, _ = get_uniform_data_chunks(data=mini_data_chunk, max_token_len=max_token_len) else: - micro_batches = mini_batch.split(self.config.micro_batch_size_per_gpu) self.gradient_accumulation = self.config.mini_batch_size // self.config.micro_batch_size_per_gpu + num_micro_batches = len(mini_data_chunk) // self.config.micro_batch_size_per_gpu + micro_data_chunks = mini_data_chunk.chunk(num_micro_batches) self.reward_optimizer.zero_grad() - for data in micro_batches: - data = data.cuda() - attention_mask = data["attention_mask"] - acc = data["acc"] + for micro_data_chunk in micro_data_chunks: + micro_batch = micro_data_chunk.batch + micro_batch = micro_batch.cuda() + attention_mask = micro_batch["attention_mask"] + acc = micro_batch["acc"] - prompt_ids = data["prompts"] + prompt_ids = micro_batch["prompts"] prompt_length = prompt_ids.shape[-1] response_mask = attention_mask[:, prompt_length:] - rm_score, q = self._forward_micro_batch(data, prompt_length) + rm_score, q = self._forward_micro_batch(micro_batch=micro_batch, prompt_length=prompt_length) rm_scores_lst.append(rm_score) q_lst.append(q.detach()) @@ -285,21 +287,21 @@ def update_rm(self, data: DataProto): else: raise NotImplementedError - data = {"reward_model/dpo_loss": dpo_loss.detach().item()} + micro_metric_data = {"reward_model/dpo_loss": dpo_loss.detach().item()} if self.config.use_dynamic_bsz: # relative to the dynamic bsz - loss = dpo_loss * (len(data) / self.config.ppo_mini_batch_size) + loss = dpo_loss * (len(micro_data_chunk) / self.config.ppo_mini_batch_size) else: loss = dpo_loss / self.gradient_accumulation loss.backward() - append_to_dict(metrics, data) + append_to_dict(metrics, micro_metric_data) grad_norm = self._optimizer_step() - data = {"reward_model/grad_norm": grad_norm.detach().item()} - append_to_dict(metrics, data) + metric_data = {"reward_model/grad_norm": grad_norm.detach().item()} + append_to_dict(metrics, metric_data) self.reward_optimizer.zero_grad() rm_scores = torch.cat(rm_scores_lst, dim=0) diff --git a/verl/protocol.py b/verl/protocol.py index 57f494b7e1e..10beed04910 100644 --- a/verl/protocol.py +++ b/verl/protocol.py @@ -616,6 +616,11 @@ def chunk(self, chunks: int) -> List["DataProto"]: return output + def split(self, split_size: int) -> List["DataProto"]: + """Split the DataProto along dim=0 into a list of DataProto, each holding at most split_size items.""" + num_splits = -(-len(self) // split_size) # ceil(len(self) / split_size) + return self.chunk(chunks=num_splits) + @staticmethod def concat(data: List["DataProto"]) -> "DataProto": """Concat a list of DataProto. The batch is concatenated among dim=0. 
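A side note on the new `DataProto.split` above: `-(-a // b)` is the standard integer idiom for ceil(a / b). A minimal standalone sketch (plain Python, independent of verl; the helper name and the numbers are illustrative only):

    def ceil_div(a: int, b: int) -> int:
        # Floor division of the negated numerator turns floor into ceiling.
        return -(-a // b)

    assert ceil_div(20, 8) == 3  # 20 samples at split_size=8 -> 3 chunks
    assert ceil_div(16, 8) == 2  # exact division -> 2 chunks

Rounding the chunk count up rather than down is what keeps trailing samples that do not fill a whole split_size from being silently dropped.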
diff --git a/verl/utils/seqlen_balancing.py b/verl/utils/seqlen_balancing.py index 16dca31c961..25ece3b5bca 100644 --- a/verl/utils/seqlen_balancing.py +++ b/verl/utils/seqlen_balancing.py @@ -14,12 +14,13 @@ import copy import heapq -from typing import List, Tuple +from typing import List, Optional, Tuple import torch -from tensordict import TensorDict from torch import distributed as dist +from verl import DataProto + def karmarkar_karp(seqlen_list: List[int], k_partitions: int, equal_size: bool): # see: https://en.wikipedia.org/wiki/Largest_differencing_method @@ -213,15 +214,15 @@ def ceildiv(a, b): return -(a // -b) -def rearrange_micro_batches(batch: TensorDict, max_token_len, dp_group=None): +def get_uniform_data_chunks(data: DataProto, max_token_len: int, dp_group: Optional[dist.ProcessGroup] = None) -> tuple[list[DataProto], list[list[int]]]: """Split the data into a list of micro batches, such that the number of valid tokens in each micro batch is smaller than max_token_len and the valid-token counts are well balanced across micro batches. """ # this is per local micro_bsz - max_seq_len = batch["attention_mask"].shape[-1] + max_seq_len = data.batch["attention_mask"].shape[-1] assert max_token_len >= max_seq_len, f"max_token_len must be at least the sequence length. Got {max_token_len=} and {max_seq_len=}" - seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1) + seq_len_effective: torch.Tensor = data.batch["attention_mask"].sum(dim=1) total_seqlen = seq_len_effective.sum().item() num_micro_batches = ceildiv(total_seqlen, max_token_len) if dist.is_initialized(): @@ -232,19 +233,11 @@ def rearrange_micro_batches(batch: TensorDict, max_token_len, dp_group=None): seq_len_effective = seq_len_effective.tolist() assert num_micro_batches <= len(seq_len_effective) - micro_bsz_idx = get_seqlen_balanced_partitions(seq_len_effective, num_micro_batches, equal_size=False) - - micro_batches = [] - - for partition in micro_bsz_idx: - curr_micro_batch = [] - for idx in partition: - curr_micro_batch.append(batch[idx : idx + 1]) - curr_micro_batch = torch.cat(curr_micro_batch) + micro_idx_partitions = get_seqlen_balanced_partitions(seq_len_effective, num_micro_batches, equal_size=False) - micro_batches.append(curr_micro_batch) + uniform_data_chunks = [data.select_idxs(partition) for partition in micro_idx_partitions] - return micro_batches, micro_bsz_idx + return uniform_data_chunks, micro_idx_partitions def get_reverse_idx(idx_map): diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index 3b26d81fe6d..a48bac0ef4c 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -32,7 +32,7 @@ from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, kl_penalty from verl.utils.debug import GPUMemoryLogger from verl.utils.py_functional import append_to_dict -from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches +from verl.utils.seqlen_balancing import get_reverse_idx, get_uniform_data_chunks from verl.utils.torch_functional import logprobs_from_logits from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs from verl.workers.actor import BasePPOActor @@ -194,30 +194,25 @@ def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Te # set to eval self.actor_module.eval() - micro_batch_size = data.meta_info["micro_batch_size"] temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error use_dynamic_bsz = 
data.meta_info["use_dynamic_bsz"] select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] - batch = data.select(batch_keys=select_keys).batch - has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() - - if has_multi_modal_inputs: - num_micro_batches = data.batch.batch_size[0] // micro_batch_size - non_tensor_select_keys = ["multi_modal_inputs"] - micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) - elif use_dynamic_bsz: - # split using dynamic bsz + non_tensor_select_keys = ["multi_modal_inputs"] if "multi_modal_inputs" in data.non_tensor_batch.keys() else [] + + selected_data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) + + if use_dynamic_bsz: max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size - micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) + micro_data_chunks, indices = get_uniform_data_chunks(data=selected_data, max_token_len=max_token_len) else: - micro_batches = batch.split(micro_batch_size) + micro_batch_size = data.meta_info["micro_batch_size"] + micro_data_chunks = selected_data.split(split_size=micro_batch_size) log_probs_lst = [] entropy_lst = [] - for micro_batch in micro_batches: - if isinstance(micro_batch, DataProto): - micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch} + for micro_data_chunk in micro_data_chunks: + micro_batch = {**micro_data_chunk.batch, **micro_data_chunk.non_tensor_batch} with torch.no_grad(): entropy, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature, calculate_entropy=calculate_entropy) log_probs_lst.append(log_probs) @@ -249,53 +244,36 @@ def update_policy(self, data: DataProto): select_keys.append("loss_mask") if self.config.use_kl_loss: select_keys.append("ref_log_prob") - batch = data.select(batch_keys=select_keys).batch - has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + non_tensor_select_keys = ["multi_modal_inputs"] if "multi_modal_inputs" in data.non_tensor_batch.keys() else [] + + selected_data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) # Split to make minibatch iterator for updating the actor # See PPO paper for details. 
https://arxiv.org/abs/1707.06347 - if has_multi_modal_inputs: - num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size - non_tensor_select_keys = ["multi_modal_inputs"] - dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches) - else: - dataloader = batch.split(self.config.ppo_mini_batch_size) + mini_dataloader = selected_data.split(split_size=self.config.ppo_mini_batch_size) # TODO: `make_minibatch_iterator` as in megatron metrics = {} for epoch in range(self.config.ppo_epochs): - for batch_idx, data in enumerate(dataloader): + for mini_idx, mini_data_chunk in enumerate(mini_dataloader): # split batch into micro_batches - mini_batch = data - if has_multi_modal_inputs: - self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu - num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu - micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) - elif self.config.use_dynamic_bsz: + if self.config.use_dynamic_bsz: max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) + micro_data_chunks, _ = get_uniform_data_chunks(data=mini_data_chunk, max_token_len=max_token_len) else: - self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu - # split batch into micro_batches - micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) + micro_data_chunks = mini_data_chunk.split(split_size=self.config.ppo_micro_batch_size_per_gpu) self.actor_optimizer.zero_grad() - for data in micro_batches: - # Support all hardwares - if isinstance(data, DataProto): - data = {**data.batch.to(torch.cuda.current_device()), **data.non_tensor_batch} - else: - data = data.to(torch.cuda.current_device()) # actor device is cpu when using offload - responses = data["responses"] - response_length = responses.size(1) - attention_mask = data["attention_mask"] + for micro_data_chunk in micro_data_chunks: + micro_batch = {**micro_data_chunk.batch.to(torch.cuda.current_device()), **micro_data_chunk.non_tensor_batch} + response_length = micro_batch["responses"].size(-1) + attention_mask = micro_batch["attention_mask"] if multi_turn: response_mask = data["loss_mask"][:, -response_length:] else: response_mask = attention_mask[:, -response_length:] - - old_log_prob = data["old_log_probs"] - advantages = data["advantages"] + old_log_prob = micro_batch["old_log_probs"] + advantages = micro_batch["advantages"] clip_ratio = self.config.clip_ratio clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio @@ -308,7 +286,7 @@ def update_policy(self, data: DataProto): calculate_entropy = False if entropy_coeff != 0: calculate_entropy = True - entropy, log_prob = self._forward_micro_batch(micro_batch=data, temperature=temperature, calculate_entropy=calculate_entropy) + entropy, log_prob = self._forward_micro_batch(micro_batch=micro_batch, temperature=temperature, calculate_entropy=calculate_entropy) pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss( old_log_prob=old_log_prob, @@ -331,7 +309,7 @@ def update_policy(self, data: DataProto): policy_loss = pg_loss if self.config.use_kl_loss: - ref_log_prob = data["ref_log_prob"] + ref_log_prob = micro_batch["ref_log_prob"] # compute kl loss kld = kl_penalty(logprob=log_prob, 
ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type) kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode) @@ -344,19 +322,20 @@ def update_policy(self, data: DataProto): # relative to the dynamic bsz - loss = policy_loss * (len(data) / self.config.ppo_mini_batch_size) + loss = policy_loss * (len(micro_data_chunk) / self.config.ppo_mini_batch_size) else: - loss = policy_loss / self.gradient_accumulation + loss = policy_loss / len(micro_data_chunks) loss.backward() - data = { + mini_metric_data = { + "actor/entropy": entropy_loss.detach().item(), "actor/pg_loss": pg_loss.detach().item(), "actor/pg_clipfrac": pg_clipfrac.detach().item(), "actor/ppo_kl": ppo_kl.detach().item(), "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(), } - append_to_dict(metrics, data) + append_to_dict(metrics, mini_metric_data) grad_norm = self._optimizer_step() - data = {"actor/grad_norm": grad_norm.detach().item()} - append_to_dict(metrics, data) + metric_data = {"actor/grad_norm": grad_norm.detach().item()} + append_to_dict(metrics, metric_data) self.actor_optimizer.zero_grad() return metrics diff --git a/verl/workers/critic/dp_critic.py b/verl/workers/critic/dp_critic.py index f87dc3a9cf4..dbe018eeb44 100644 --- a/verl/workers/critic/dp_critic.py +++ b/verl/workers/critic/dp_critic.py @@ -29,7 +29,7 @@ from verl.trainer.ppo import core_algos from verl.utils.debug import GPUMemoryLogger from verl.utils.py_functional import append_to_dict -from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches +from verl.utils.seqlen_balancing import get_reverse_idx, get_uniform_data_chunks from verl.utils.torch_functional import masked_mean from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs from verl.workers.critic import BasePPOCritic @@ -128,27 +128,24 @@ def _optimizer_step(self): @GPUMemoryLogger(role="dp critic", logger=logger) def compute_values(self, data: DataProto) -> torch.Tensor: self.critic_module.eval() - micro_batch_size = data.meta_info["micro_batch_size"] + select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] - batch = data.select(batch_keys=select_keys).batch + non_tensor_select_keys = ["multi_modal_inputs"] if "multi_modal_inputs" in data.non_tensor_batch.keys() else [] + + selected_data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) + use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] - has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() - - if has_multi_modal_inputs: - num_micro_batches = data.batch.batch_size[0] // micro_batch_size - non_tensor_select_keys = ["multi_modal_inputs"] - micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) - elif use_dynamic_bsz: - # split using dynamic bsz + + if use_dynamic_bsz: max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size - micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) + micro_data_chunks, indices = get_uniform_data_chunks(data=selected_data, max_token_len=max_token_len) else: - micro_batches = batch.split(micro_batch_size) + micro_batch_size = data.meta_info["micro_batch_size"] + micro_data_chunks = selected_data.split(split_size=micro_batch_size) values_lst = [] - for micro_batch in micro_batches: - if isinstance(micro_batch, DataProto): - micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch} + for micro_data_chunk in micro_data_chunks: + micro_batch = {**micro_data_chunk.batch, **micro_data_chunk.non_tensor_batch} 
with torch.no_grad(): values = self._forward_micro_batch(micro_batch) @@ -174,49 +171,35 @@ def update_critic(self, data: DataProto): metrics = {} select_keys = ["input_ids", "responses", "attention_mask", "position_ids", "values", "returns"] - batch = data.select(batch_keys=select_keys).batch - has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + non_tensor_select_keys = ["multi_modal_inputs"] if "multi_modal_inputs" in data.non_tensor_batch.keys() else [] + + selected_data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) # Split to make minibatch iterator for updating the critic # See PPO paper for details. https://arxiv.org/abs/1707.06347 - if has_multi_modal_inputs: - num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size - non_tensor_select_keys = ["multi_modal_inputs"] - dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches) - else: - dataloader = batch.split(self.config.ppo_mini_batch_size) + mini_dataloader = selected_data.split(split_size=self.config.ppo_mini_batch_size) # TODO: `make_minibatch_iterator` as in megatron for epoch in range(self.config.ppo_epochs): - for batch_idx, data in enumerate(dataloader): - # split batch into micro_batches - mini_batch = data - if has_multi_modal_inputs: - num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu - micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) - elif self.config.use_dynamic_bsz: + for mini_idx, mini_data_chunk in enumerate(mini_dataloader): + if self.config.use_dynamic_bsz: max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) + micro_data_chunks, _ = get_uniform_data_chunks(data=mini_data_chunk, max_token_len=max_token_len) else: - micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) - self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + micro_data_chunks = mini_data_chunk.split(split_size=self.config.ppo_micro_batch_size_per_gpu) self.critic_optimizer.zero_grad() - for data in micro_batches: - # Support all devices - if isinstance(data, DataProto): - data = {**data.batch.to(torch.cuda.current_device()), **data.non_tensor_batch} - else: - data = data.to(torch.cuda.current_device()) # critic device is cpu when using offload - responses = data["responses"] - attention_mask = data["attention_mask"] - values = data["values"] - returns = data["returns"] - response_length = responses.size(1) + for micro_data_chunk in micro_data_chunks: + micro_batch = {**micro_data_chunk.batch, **micro_data_chunk.non_tensor_batch} - response_mask = attention_mask[:, -response_length - 1 : -1] + values = micro_batch["values"] + returns = micro_batch["returns"] - vpreds = self._forward_micro_batch(data) + responses = micro_batch["responses"] + response_length = responses.size(1) + attention_mask = micro_batch["attention_mask"] + response_mask = attention_mask[:, -response_length - 1 : -1] + vpreds = self._forward_micro_batch(micro_batch) # assert not torch.any(torch.isnan(vpreds)).item() @@ -235,16 +218,16 @@ def update_critic(self, data: DataProto): loss.backward() - data = { + mini_metric_data = { "critic/vf_loss": vf_loss.detach().item(), "critic/vf_clipfrac": vf_clipfrac.detach().item(), "critic/vpred_mean": masked_mean(vpreds, 
response_mask).detach().item(), } - append_to_dict(metrics, data) + append_to_dict(metrics, mini_metric_data) grad_norm = self._optimizer_step() - data = {"critic/grad_norm": grad_norm.detach().item()} - append_to_dict(metrics, data) + metric_data = {"critic/grad_norm": grad_norm.detach().item()} + append_to_dict(metrics, metric_data) self.critic_optimizer.zero_grad() return metrics From 4b3ec194b7592b0de199315a51c886a334845b01 Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Thu, 1 May 2025 04:17:09 +0800 Subject: [PATCH 2/6] fix: unify to split --- recipe/prime/prime_dp_rm.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/recipe/prime/prime_dp_rm.py b/recipe/prime/prime_dp_rm.py index cc2b39c6179..1e48e989c24 100644 --- a/recipe/prime/prime_dp_rm.py +++ b/recipe/prime/prime_dp_rm.py @@ -236,9 +236,7 @@ def update_rm(self, data: DataProto): max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size micro_data_chunks, _ = get_uniform_data_chunks(data=mini_data_chunk, max_token_len=max_token_len) else: - self.gradient_accumulation = self.config.mini_batch_size // self.config.micro_batch_size_per_gpu - num_micro_batches = len(mini_data_chunk) // self.config.micro_batch_size_per_gpu - micro_data_chunks = mini_data_chunk.chunk(num_micro_batches) + micro_data_chunks = mini_data_chunk.split(split_size=self.config.micro_batch_size_per_gpu) self.reward_optimizer.zero_grad() @@ -293,7 +291,7 @@ def update_rm(self, data: DataProto): # relative to the dynamic bsz loss = dpo_loss * (len(micro_data_chunk) / self.config.ppo_mini_batch_size) else: - loss = dpo_loss / self.gradient_accumulation + loss = dpo_loss / len(micro_data_chunks) loss.backward() From 66caa467eb9c583336ffe1f116e75ac48fff7db5 Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Thu, 1 May 2025 04:26:50 +0800 Subject: [PATCH 3/6] fix: fsdp_workers and tests --- tests/utility/test_tensor_dict_utilities.py | 5 +++-- verl/workers/fsdp_workers.py | 10 +++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/utility/test_tensor_dict_utilities.py b/tests/utility/test_tensor_dict_utilities.py index befa5bd1e3c..9308927457e 100644 --- a/tests/utility/test_tensor_dict_utilities.py +++ b/tests/utility/test_tensor_dict_utilities.py @@ -276,7 +276,7 @@ def test_len(): def test_seqlen_balancing(): - from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches + from verl.utils.seqlen_balancing import get_reverse_idx, get_uniform_data_chunks input_ids = torch.randint(low=0, high=10, size=(20, 100)) from verl.utils.model import create_random_mask @@ -284,7 +284,8 @@ def test_seqlen_balancing(): attention_mask = create_random_mask(input_ids=input_ids, max_ratio_of_left_padding=0.1, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.5) data = {"input_ids": input_ids, "attention_mask": attention_mask} dataproto = DataProto.from_single_dict(data) - micro_batches, micro_bsz_idx_lst = rearrange_micro_batches(dataproto.batch, max_token_len=300) + micro_data_chunks, micro_bsz_idx_lst = get_uniform_data_chunks(dataproto, max_token_len=300) + micro_batches = [micro_data_chunk.batch for micro_data_chunk in micro_data_chunks] batch = torch.cat(micro_batches) micro_bsz_idx = [] for idx in micro_bsz_idx_lst: diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index c1b501c0cb9..c028db57147 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -1195,7 +1195,7 @@ def _switch_chat_template(self, data: 
DataProto): def compute_rm_score(self, data: DataProto): import itertools - from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches + from verl.utils.seqlen_balancing import get_reverse_idx, get_uniform_data_chunks # Support all hardware data = data.to(torch.cuda.current_device()) @@ -1223,12 +1223,12 @@ def compute_rm_score(self, data: DataProto): use_dynamic_bsz = self.config.use_dynamic_bsz if use_dynamic_bsz: max_token_len = self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - micro_batches, indices = rearrange_micro_batches(batch=rm_data.batch, max_token_len=max_token_len) + micro_data_chunks, indices = get_uniform_data_chunks(data=rm_data, max_token_len=max_token_len) else: - micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu) + micro_data_chunks = rm_data.split(split_size=self.config.micro_batch_size_per_gpu) output = [] - for micro_batch in micro_batches: - rm_score = self._forward_micro_batch(micro_batch) + for micro_data_chunk in micro_data_chunks: + rm_score = self._forward_micro_batch(micro_batch=micro_data_chunk.batch) output.append(rm_score) scores = torch.cat(output, dim=0) # (batch_size) From 35870818ee72adbbe7badfd3ad760244e77d0177 Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Thu, 1 May 2025 05:17:02 +0800 Subject: [PATCH 4/6] fix: do not log entropy_loss --- recipe/prime/prime_core_algos.py | 17 +++++++++++------ recipe/prime/prime_ray_trainer.py | 15 ++------------- verl/workers/actor/dp_actor.py | 2 -- 3 files changed, 13 insertions(+), 21 deletions(-) diff --git a/recipe/prime/prime_core_algos.py b/recipe/prime/prime_core_algos.py index c17c668a723..8269ef0ad0d 100644 --- a/recipe/prime/prime_core_algos.py +++ b/recipe/prime/prime_core_algos.py @@ -13,13 +13,18 @@ # limitations under the License. 
import torch +from omegaconf import DictConfig import verl import verl.utils.torch_functional as verl_F +from verl.trainer.ppo.ray_trainer import compute_response_mask -def compute_rloo_advantage_return(data: verl.DataProto, response_mask: torch.Tensor, n_samples, config): +def compute_rloo_advantage_return(data: verl.DataProto, config: DictConfig): # calculate rloo reward on different reward sources, and sum again + action_mask = compute_response_mask(data) + n_samples = config.actor_rollout_ref.rollout.n + def masked_rloo(reward_tensor_original, mask_tensor): reward_tensor = reward_tensor_original.clone() reward_tensor[~mask_tensor] = 0 @@ -39,13 +44,13 @@ def masked_rloo(reward_tensor_original, mask_tensor): with torch.no_grad(): if "rm_scores" in data.batch.keys() and config.algorithm.reward_dpo_coef != 0.0: reward_tensor = data.batch["rm_scores"] - reward_mask = response_mask.bool() + reward_mask = action_mask.bool() reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_dpo_coef) if "acc" in data.batch.keys() and config.algorithm.reward_gt_coef != 0.0: - reward_tensor = torch.zeros_like(response_mask, dtype=torch.float32) - reward_mask = torch.zeros_like(response_mask, dtype=torch.bool) + reward_tensor = torch.zeros_like(action_mask, dtype=torch.float32) + reward_mask = torch.zeros_like(action_mask, dtype=torch.bool) prompt_ids = data.batch["prompts"] prompt_length = prompt_ids.shape[-1] @@ -64,10 +69,10 @@ def masked_rloo(reward_tensor_original, mask_tensor): final_reward_tensor = sum(reward_tensors) - returns = (final_reward_tensor * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) + returns = (final_reward_tensor * action_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) advantages = returns.clone() - advantages = verl_F.masked_whiten(advantages, response_mask) + advantages = verl_F.masked_whiten(advantages, action_mask) return advantages, returns diff --git a/recipe/prime/prime_ray_trainer.py b/recipe/prime/prime_ray_trainer.py index 724e91cd6e9..ae16fa0c0e9 100644 --- a/recipe/prime/prime_ray_trainer.py +++ b/recipe/prime/prime_ray_trainer.py @@ -30,7 +30,7 @@ from verl.single_controller.ray import RayWorkerGroup from verl.trainer.ppo.core_algos import agg_loss from verl.trainer.ppo.metric_utils import _compute_response_info -from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role, WorkerType, _timer, reduce_metrics +from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role, WorkerType, _timer, compute_response_mask, reduce_metrics from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn @@ -39,11 +39,7 @@ def compute_advantage(data: DataProto, adv_estimator, config): if adv_estimator == "rloo": - responses = data.batch["responses"] - response_length = responses.size(-1) - attention_mask = data.batch["attention_mask"] - response_mask = attention_mask[:, -response_length:] - advantages, returns = prime_core_algos.compute_rloo_advantage_return(data, response_mask, config.actor_rollout_ref.rollout.n, config) + advantages, returns = prime_core_algos.compute_rloo_advantage_return(data, config=config) data.batch["advantages"] = advantages data.batch["returns"] = returns else: @@ -110,13 +106,6 @@ def compute_data_metrics(batch, use_critic=True): return metrics -def compute_response_mask(data: DataProto): - responses = data.batch["responses"] - response_length = responses.size(1) - attention_mask = 
data.batch["attention_mask"] - return attention_mask[:, -response_length:] - - def compute_timing_metrics(batch, timing_raw): response_info = _compute_response_info(batch) num_prompt_tokens = torch.sum(response_info["prompt_length"]).item() diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index a48bac0ef4c..b99991c2154 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -302,7 +302,6 @@ def update_policy(self, data: DataProto): if entropy_coeff != 0: entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) - # compute policy loss policy_loss = pg_loss - entropy_loss * entropy_coeff else: @@ -326,7 +325,6 @@ def update_policy(self, data: DataProto): loss.backward() mini_metric_data = { - "actor/entropy": entropy_loss.detach().item(), "actor/pg_loss": pg_loss.detach().item(), "actor/pg_clipfrac": pg_clipfrac.detach().item(), "actor/ppo_kl": ppo_kl.detach().item(), From e4d7dfb6c6114fc6a7aee927ca1d13c3f3aa0154 Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Thu, 1 May 2025 05:17:30 +0800 Subject: [PATCH 5/6] fix: self.gradient_accumulation --- verl/workers/critic/dp_critic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verl/workers/critic/dp_critic.py b/verl/workers/critic/dp_critic.py index dbe018eeb44..8595739b20e 100644 --- a/verl/workers/critic/dp_critic.py +++ b/verl/workers/critic/dp_critic.py @@ -214,7 +214,7 @@ def update_critic(self, data: DataProto): # relative to the dynamic bsz loss = vf_loss * (len(data) / self.config.ppo_mini_batch_size) else: - loss = vf_loss / self.gradient_accumulation + loss = vf_loss / len(micro_data_chunks) loss.backward() From 730d120a16cb01f31852ca8bef15f77f5f712d27 Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Thu, 1 May 2025 07:34:51 +0800 Subject: [PATCH 6/6] fix: loss_mask --- verl/workers/actor/dp_actor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index b99991c2154..a3097cb32db 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -269,7 +269,7 @@ def update_policy(self, data: DataProto): response_length = micro_batch["responses"].size(-1) attention_mask = micro_batch["attention_mask"] if multi_turn: - response_mask = data["loss_mask"][:, -response_length:] + response_mask = micro_batch["loss_mask"][:, -response_length:] else: response_mask = attention_mask[:, -response_length:] old_log_prob = micro_batch["old_log_probs"]