diff --git a/examples/split_placement/split_monkey_patch.py b/examples/split_placement/split_monkey_patch.py index d7bf25380a6..f90ab2ba639 100644 --- a/examples/split_placement/split_monkey_patch.py +++ b/examples/split_placement/split_monkey_patch.py @@ -23,15 +23,7 @@ import torch from verl import DataProto -from verl.trainer.ppo.ray_trainer import ( - AdvantageEstimator, - _timer, - apply_kl_penalty, - compute_advantage, - compute_data_metrics, - compute_timing_metrics, - reduce_metrics, -) +from verl.trainer.ppo.ray_trainer import AdvantageEstimator, _timer, apply_kl_penalty, calc_mini_batch_loss_token_nums, compute_advantage, compute_data_metrics, compute_timing_metrics, reduce_metrics def fit(self): @@ -115,6 +107,8 @@ def fit(self): # compute global_valid tokens batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + batch.meta_info["mini_batch_loss_token_nums"] = calc_mini_batch_loss_token_nums(batch, traj_mini_bsz=self.config.actor_rollout_ref.actor.ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n, num_dp_ranks=self.actor_rollout_wg.world_size) + # recompute old_log_probs with _timer("old_log_prob", timing_raw): old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) diff --git a/recipe/dapo/src/dapo_ray_trainer.py b/recipe/dapo/src/dapo_ray_trainer.py index 881c3256e9e..faf017e166a 100644 --- a/recipe/dapo/src/dapo_ray_trainer.py +++ b/recipe/dapo/src/dapo_ray_trainer.py @@ -32,7 +32,7 @@ compute_timing_metrics, reduce_metrics, ) -from verl.trainer.ppo.ray_trainer import AdvantageEstimator, RayPPOTrainer, _timer, apply_kl_penalty, compute_advantage +from verl.trainer.ppo.ray_trainer import AdvantageEstimator, RayPPOTrainer, _timer, apply_kl_penalty, calc_mini_batch_loss_token_nums, compute_advantage class RayDAPOTrainer(RayPPOTrainer): @@ -216,6 +216,8 @@ def fit(self): # compute global_valid tokens batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + 
batch.meta_info["mini_batch_loss_token_nums"] = calc_mini_batch_loss_token_nums(batch, traj_mini_bsz=self.config.actor_rollout_ref.actor.ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n, num_dp_ranks=self.actor_rollout_wg.world_size) + # recompute old_log_probs with _timer("old_log_prob", timing_raw): old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) diff --git a/recipe/prime/prime_dp_rm.py b/recipe/prime/prime_dp_rm.py index 59998ae54b2..dccc31aefc1 100644 --- a/recipe/prime/prime_dp_rm.py +++ b/recipe/prime/prime_dp_rm.py @@ -26,7 +26,7 @@ import verl.utils.torch_functional as verl_F from verl import DataProto from verl.utils.py_functional import append_to_dict -from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches +from verl.utils.seqlen_balancing import get_reverse_idx, get_uniform_data_chunks from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs from .prime_core_algos import compute_ce_dpo_loss_rm, compute_detach_dpo_loss_rm @@ -170,22 +170,24 @@ def compute_rm_score(self, data: DataProto): self.ref_module.eval() micro_batch_size = data.meta_info["micro_batch_size"] select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "acc"] - batch = data.select(batch_keys=select_keys).batch + selected_data = data.select(batch_keys=select_keys) use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] - prompt_length = data.batch["input_ids"].shape[-1] - data.batch["responses"].shape[-1] + batch = selected_data.batch + prompt_length = batch["input_ids"].shape[-1] - batch["responses"].shape[-1] if use_dynamic_bsz: # split using dynamic bsz max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size - micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) + micro_data_chunks, indices = get_uniform_data_chunks(data=selected_data, max_token_len=max_token_len) else: - micro_batches = batch.split(micro_batch_size) + 
num_micro_batches = len(selected_data) // micro_batch_size + micro_data_chunks = selected_data.chunk(num_micro_batches) rm_scores_lst = [] q_lst = [] - for micro_batch in micro_batches: + for micro_data_chunk in micro_data_chunks: with torch.no_grad(): - rm_score, q = self._forward_micro_batch(micro_batch, prompt_length) + rm_score, q = self._forward_micro_batch(micro_batch=micro_data_chunk.batch, prompt_length=prompt_length) rm_scores_lst.append(rm_score) q_lst.append(q) rm_scores = torch.concat(rm_scores_lst, dim=0) @@ -221,37 +223,38 @@ def update_rm(self, data: DataProto): if key in data.batch.keys(): select_keys.append(key) - batch = data.select(batch_keys=select_keys).batch + selected_data = data.select(batch_keys=select_keys) # Split to make minibatch iterator for updating the actor # See PPO paper for details. https://arxiv.org/abs/1707.06347 - dataloader = batch.split(self.config.mini_batch_size) + num_mini_batches = len(selected_data) // self.config.mini_batch_size + mini_dataloader = selected_data.chunk(num_mini_batches) rm_scores_lst = [] q_lst = [] - for batch_idx, data in enumerate(dataloader): - # split batch into micro_batches - mini_batch = data + for mini_idx, mini_data_chunk in enumerate(mini_dataloader): if self.config.use_dynamic_bsz: max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) + micro_data_chunks, _ = get_uniform_data_chunks(data=mini_data_chunk, max_token_len=max_token_len) else: - micro_batches = mini_batch.split(self.config.micro_batch_size_per_gpu) self.gradient_accumulation = self.config.mini_batch_size // self.config.micro_batch_size_per_gpu + num_micro_batches = len(mini_data_chunk) // self.config.micro_batch_size_per_gpu + micro_data_chunks = mini_data_chunk.chunk(num_micro_batches) self.reward_optimizer.zero_grad() - for data in micro_batches: - data = data.cuda() - attention_mask = 
data["attention_mask"] - acc = data["acc"] + for micro_data_chunk in micro_data_chunks: + micro_batch = micro_data_chunk.batch + micro_batch = micro_batch.cuda() + attention_mask = micro_batch["attention_mask"] + acc = micro_batch["acc"] - prompt_ids = data["prompts"] + prompt_ids = micro_batch["prompts"] prompt_length = prompt_ids.shape[-1] response_mask = attention_mask[:, prompt_length:] - rm_score, q = self._forward_micro_batch(data, prompt_length) + rm_score, q = self._forward_micro_batch(micro_batch=micro_batch, prompt_length=prompt_length) rm_scores_lst.append(rm_score) q_lst.append(q.detach()) @@ -289,7 +292,7 @@ def update_rm(self, data: DataProto): if self.config.use_dynamic_bsz: # relative to the dynamic bsz - loss = dpo_loss * (len(data) / self.config.ppo_mini_batch_size) + loss = dpo_loss * (len(micro_data_chunk) / self.config.ppo_mini_batch_size) else: loss = dpo_loss / self.gradient_accumulation diff --git a/recipe/prime/prime_ray_trainer.py b/recipe/prime/prime_ray_trainer.py index 724e91cd6e9..bbe2e758b26 100644 --- a/recipe/prime/prime_ray_trainer.py +++ b/recipe/prime/prime_ray_trainer.py @@ -30,7 +30,7 @@ from verl.single_controller.ray import RayWorkerGroup from verl.trainer.ppo.core_algos import agg_loss from verl.trainer.ppo.metric_utils import _compute_response_info -from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role, WorkerType, _timer, reduce_metrics +from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role, WorkerType, _timer, calc_mini_batch_loss_token_nums, reduce_metrics from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn @@ -379,6 +379,8 @@ def fit(self): # compute global_valid tokens batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + batch.meta_info["mini_batch_loss_token_nums"] = calc_mini_batch_loss_token_nums(batch, 
traj_mini_bsz=self.config.actor_rollout_ref.actor.ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n, num_dp_ranks=self.actor_rollout_wg.world_size) + # verify with _timer("verify", timing_raw): scores = self.reward_fn.verify(batch) diff --git a/tests/e2e/grad_accum/grad_accum_test_fsdp_workers.py b/tests/e2e/grad_accum/grad_accum_test_fsdp_workers.py new file mode 100644 index 00000000000..a8d4bbf0fa5 --- /dev/null +++ b/tests/e2e/grad_accum/grad_accum_test_fsdp_workers.py @@ -0,0 +1,353 @@ +import logging +import os +from collections import defaultdict +from typing import Any, Optional + +import pandas as pd +import torch +from omegaconf import open_dict + +from verl import DataProto +from verl.single_controller.base.decorator import Dispatch, register +from verl.trainer.ppo import core_algos +from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager +from verl.utils.debug import log_gpu_memory_usage +from verl.utils.flops_counter import FlopsCounter +from verl.utils.fsdp_utils import offload_fsdp_model_to_cpu, offload_fsdp_optimizer +from verl.utils.import_utils import import_external_libs +from verl.utils.seqlen_balancing import get_uniform_data_chunks +from verl.utils.torch_functional import compute_response_mask +from verl.workers.actor.dp_actor import DataParallelPPOActor +from verl.workers.critic.dp_critic import DataParallelPPOCritic +from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker + +ALL_LOSS_AGG_MODES: list[str] = ["token-mean", "seq-mean-token-sum", "seq-mean-token-mean"] +GRAD_ACCUM_RTOL: float = 0.01 + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN")) + + +class GradAccumulationTestActorRolloutRefWorker(ActorRolloutRefWorker): + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + # This is used to import external_lib into the huggingface systems + import_external_libs(self.config.model.get("external_lib", None)) + + from 
omegaconf import OmegaConf + + override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) + + use_remove_padding = self.config.model.get("use_remove_padding", False) + + if self._is_actor or self._is_rollout: + # we need the model for actor and rollout + if self._is_actor: + optim_config = self.config.actor.optim + fsdp_config = self.config.actor.fsdp_config + else: + optim_config = None + fsdp_config = OmegaConf.create() + self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = self._build_model_optimizer( + model_path=self.config.model.path, + fsdp_config=fsdp_config, + optim_config=optim_config, + override_model_config=override_model_config, + use_remove_padding=use_remove_padding, + enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False), + trust_remote_code=self.config.model.get("trust_remote_code", False), + use_liger=self.config.model.get("use_liger", False), + role="actor", + ) + + # get the original unwrapped module + self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module + + if self._is_offload_optimizer: + offload_fsdp_optimizer(optimizer=self.actor_optimizer) + log_gpu_memory_usage("After offload actor optimizer during init", logger=logger) + # load from checkpoint + if self._is_actor: + OmegaConf.set_struct(self.config.actor, True) + with open_dict(self.config.actor): + self.config.actor.use_remove_padding = use_remove_padding + self.actor = GradAccumulationTestDPActor(config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer) + + if self._is_rollout: + self.rollout, self.rollout_sharding_manager = self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False)) + + if self._is_ref: + self.ref_module_fsdp = self._build_model_optimizer( + model_path=self.config.model.path, + fsdp_config=self.config.ref.fsdp_config, + optim_config=None, + 
override_model_config=override_model_config, + use_remove_padding=use_remove_padding, + trust_remote_code=self.config.model.get("trust_remote_code", False), + use_liger=self.config.model.get("use_liger", False), + role="ref", + )[0] + OmegaConf.set_struct(self.config.ref, True) + with open_dict(self.config.ref): + self.config.ref.use_remove_padding = use_remove_padding + self.ref_policy = GradAccumulationTestDPActor(config=self.config.ref, actor_module=self.ref_module_fsdp) + + if self._is_actor: + self.flops_counter = FlopsCounter(self.actor_model_config) + self.checkpoint_manager = FSDPCheckpointManager( + model=self.actor_module_fsdp, optimizer=self.actor.actor_optimizer, lr_scheduler=self.actor_lr_scheduler, processing_class=self.processor if self.processor is not None else self.tokenizer, checkpoint_contents=self.config.actor.checkpoint.contents + ) + + +class GradAccumulationTestCriticWorker(CriticWorker): + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + # This is used to import external_lib into the huggingface systems + import_external_libs(self.config.model.get("external_lib", None)) + + from verl.workers.critic import DataParallelPPOCritic + + self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = self._build_critic_model_optimizer(self.config) + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.critic_module) + if self._is_offload_optimizer: + offload_fsdp_optimizer(optimizer=self.critic_optimizer) + + self.critic = DataParallelPPOCritic(config=self.config, critic_module=self.critic_module, critic_optimizer=self.critic_optimizer) + + self.flops_counter = FlopsCounter(self.critic_model_config) + self.checkpoint_manager = FSDPCheckpointManager(model=self.critic_module, optimizer=self.critic_optimizer, lr_scheduler=self.critic_lr_scheduler, processing_class=self.processor if self.processor is not None else self.tokenizer, checkpoint_contents=self.config.checkpoint.contents) + + +class 
GradAccumulationTestDPActor(DataParallelPPOActor): + def compute_batch_loss(self, data: DataProto, loss_agg_mode: str = "token-mean", mini_batch_loss_token_num: Optional[int] = None, disable_grad_accum: bool = False) -> tuple[torch.Tensor, int]: + accum_loss = None + temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error + + if disable_grad_accum: + micro_data_chunks = [data] + else: + if self.config.use_dynamic_bsz: + max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size + micro_data_chunks, _ = get_uniform_data_chunks(data=data, max_token_len=max_token_len) + else: + num_micro_batches = len(data) // self.config.ppo_micro_batch_size_per_gpu + micro_data_chunks = data.chunk(num_micro_batches) + + assert len(micro_data_chunks) > 1, f"len(micro_data_chunks) must be greater than 1 to test grad accumulation, but got {len(micro_data_chunks)=}" + + micro_weights = [] + raw_micro_losses = [] + for micro_data_chunk in micro_data_chunks: + micro_batch = {**micro_data_chunk.batch, **micro_data_chunk.non_tensor_batch} + + response_mask = compute_response_mask(response_ids=micro_batch["responses"], attention_mask=micro_batch["attention_mask"]) + old_log_prob = micro_batch["old_log_probs"] + advantages = micro_batch["advantages"] + + clip_ratio = self.config.clip_ratio + clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio + clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio + clip_ratio_c = self.config.get("clip_ratio_c", 3.0) + entropy_coeff = self.config.entropy_coeff + + # all return: (bsz, response_length) + entropy, log_prob = self._forward_micro_batch(micro_batch=micro_batch, temperature=temperature) + + pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = core_algos.compute_policy_loss( + old_log_prob=old_log_prob, log_prob=log_prob, advantages=advantages, 
response_mask=response_mask, cliprange=clip_ratio, cliprange_low=clip_ratio_low, cliprange_high=clip_ratio_high, clip_ratio_c=clip_ratio_c + ) + loss = pg_loss + + # compute entropy loss from entropy + entropy_loss = core_algos.compute_entropy_loss(entropy=entropy, response_mask=response_mask, loss_agg_mode=loss_agg_mode) + loss += -entropy_loss * entropy_coeff + + if self.config.use_kl_loss: + ref_log_prob = micro_batch["ref_log_prob"] + # compute kl loss + kld = core_algos.kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type) + kl_loss = core_algos.compute_kl_loss(kld=kld, response_mask=response_mask, loss_agg_mode=loss_agg_mode) + + loss += kl_loss * self.config.kl_loss_coef + + # Rescale the final model loss together instead of separately in core_algos + if loss_agg_mode == "token-mean": + num_valid_toks = response_mask.sum() + micro_weight = num_valid_toks / mini_batch_loss_token_num + else: # seq-mean + micro_weight = len(micro_data_chunk) / self.config.ppo_mini_batch_size + + micro_loss = loss * micro_weight + if accum_loss is None: + accum_loss = micro_loss + else: + accum_loss += micro_loss + + micro_weights.append(micro_weight) + raw_micro_losses.append(loss.detach().cpu().item()) + + print(f"{raw_micro_losses=}") + print(f"{sum(micro_weights)=}") + print(f"{micro_weights=}") + + return accum_loss, len(micro_data_chunks) + + def update_policy(self, data: DataProto): + """ + Tests gradient accumulation by comparing loss computed with mini-batches vs single batch + """ + # make sure we are in training mode + self.actor_module.train() + metrics = defaultdict(list) + + select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages"] + if self.config.use_kl_loss: + select_keys.append("ref_log_prob") + non_tensor_select_keys = ["multi_modal_inputs"] if "multi_modal_inputs" in data.non_tensor_batch.keys() else [] + + selected_data = data.select(batch_keys=select_keys, 
non_tensor_batch_keys=non_tensor_select_keys) + + # Split to make minibatch iterator for updating the actor + # See PPO paper for details. https://arxiv.org/abs/1707.06347 + num_mini_batches = len(selected_data) // self.config.ppo_mini_batch_size + assert num_mini_batches > 1, f"num_mini_batches must be greater than 1 to test grad accumulation, but got {num_mini_batches=}" + + mini_dataloader = selected_data.chunk(num_mini_batches) + + test_infos: list[dict[str, Any]] = [] + + for mini_idx, mini_data_chunk in enumerate(mini_dataloader): + for loss_agg_mode in ALL_LOSS_AGG_MODES: + mini_loss_w_grad_accum, num_micro_batches = self.compute_batch_loss(data=mini_data_chunk, loss_agg_mode=loss_agg_mode, mini_batch_loss_token_num=data.meta_info["mini_batch_loss_token_nums"][mini_idx], disable_grad_accum=False) + mini_loss, _ = self.compute_batch_loss(data=mini_data_chunk, loss_agg_mode=loss_agg_mode, mini_batch_loss_token_num=data.meta_info["mini_batch_loss_token_nums"][mini_idx], disable_grad_accum=True) + if loss_agg_mode == self.config.loss_agg_mode: + mini_loss.backward() + mini_loss_w_grad_accum = mini_loss_w_grad_accum.detach().cpu() + mini_loss = mini_loss.detach().cpu() + + test_infos.append( + { + "mini_idx": mini_idx, + "loss_agg_mode": loss_agg_mode, + "num_micro_batches": num_micro_batches, + "mini_loss": mini_loss.item(), + "mini_loss_w_grad_accum": mini_loss_w_grad_accum.item(), + "rtol": GRAD_ACCUM_RTOL, + "isclose": torch.isclose(mini_loss_w_grad_accum, mini_loss, rtol=GRAD_ACCUM_RTOL), + } + ) + self._optimizer_step() + self.actor_optimizer.zero_grad() + + test_info_df = pd.DataFrame(test_infos) + print(test_info_df) + + return metrics + + +class GradAccumulationTestDPCritic(DataParallelPPOCritic): + def compute_batch_loss(self, data: DataProto, loss_agg_mode: str = "token-mean", mini_batch_loss_token_num: Optional[int] = None, disable_grad_accum: bool = False) -> tuple[torch.Tensor, int]: + accum_loss = None + + if disable_grad_accum: + micro_data_chunks 
= [data] + else: + if self.config.use_dynamic_bsz: + max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size + micro_data_chunks, _ = get_uniform_data_chunks(data=data, max_token_len=max_token_len) + else: + num_micro_batches = len(data) // self.config.ppo_micro_batch_size_per_gpu + micro_data_chunks = data.chunk(num_micro_batches) + + assert len(micro_data_chunks) > 1, f"len(micro_data_chunks) must be greater than 1 to test grad accumulation, but got {len(micro_data_chunks)=}" + + micro_weights = [] + raw_micro_losses = [] + for micro_data_chunk in micro_data_chunks: + micro_batch = {**micro_data_chunk.batch, **micro_data_chunk.non_tensor_batch} + + responses = micro_batch["responses"] + attention_mask = micro_batch["attention_mask"] + values = micro_batch["values"] + returns = micro_batch["returns"] + response_mask = compute_response_mask(response_ids=responses, attention_mask=attention_mask) + + vpreds = self._forward_micro_batch(micro_batch) + + # assert not torch.any(torch.isnan(vpreds)).item() + + vf_loss, vf_clipfrac = core_algos.compute_value_loss(vpreds=vpreds, values=values, returns=returns, response_mask=response_mask, cliprange_value=self.config.cliprange_value, loss_agg_mode=loss_agg_mode) + + loss = vf_loss + # Rescale the final model loss together instead of separately in core_algos + if loss_agg_mode == "token-mean": + num_valid_toks = response_mask.sum() + micro_weight = num_valid_toks / mini_batch_loss_token_num + else: # seq-mean + micro_weight = len(micro_data_chunk) / self.config.ppo_mini_batch_size + + micro_loss = loss * micro_weight + if accum_loss is None: + accum_loss = micro_loss + else: + accum_loss += micro_loss + + micro_weights.append(micro_weight) + raw_micro_losses.append(loss.detach().cpu().item()) + + print(f"{raw_micro_losses=}") + print(f"{sum(micro_weights)=}") + print(f"{micro_weights=}") + + assert accum_loss is not None, "accum_loss must not be None" + return accum_loss, 
len(micro_data_chunks) + + def update_critic(self, data: DataProto): + """ + TODO: Merge common part with update_actor as update + """ + # make sure we are in training mode + self.critic_module.train() + metrics = defaultdict(list) + + select_keys = ["input_ids", "responses", "attention_mask", "position_ids", "values", "returns"] + non_tensor_select_keys = ["multi_modal_inputs"] if "multi_modal_inputs" in data.non_tensor_batch.keys() else [] + + selected_data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) + + # Split to make minibatch iterator for updating the actor + # See PPO paper for details. https://arxiv.org/abs/1707.06347 + num_mini_batches = len(selected_data) // self.config.ppo_mini_batch_size + + mini_dataloader = data.chunk(num_mini_batches) # TODO: `make_minibatch_iterator`` as in megatron + + test_infos: list[dict[str, Any]] = [] + + for mini_idx, mini_data_chunk in enumerate(mini_dataloader): + for loss_agg_mode in ALL_LOSS_AGG_MODES: + mini_loss_w_grad_accum, num_micro_batches = self.compute_batch_loss(data=mini_data_chunk, loss_agg_mode=loss_agg_mode, mini_batch_loss_token_num=data.meta_info["mini_batch_loss_token_nums"][mini_idx], disable_grad_accum=False) + mini_loss, _ = self.compute_batch_loss(data=mini_data_chunk, loss_agg_mode=loss_agg_mode, mini_batch_loss_token_num=data.meta_info["mini_batch_loss_token_nums"][mini_idx], disable_grad_accum=True) + if loss_agg_mode == self.config.loss_agg_mode: + mini_loss.backward() + mini_loss_w_grad_accum = mini_loss_w_grad_accum.detach().cpu() + mini_loss = mini_loss.detach().cpu() + test_infos.append( + { + "mini_idx": mini_idx, + "loss_agg_mode": loss_agg_mode, + "num_micro_batches": num_micro_batches, + "mini_loss": mini_loss.item(), + "mini_loss_w_grad_accum": mini_loss_w_grad_accum.item(), + "rtol": GRAD_ACCUM_RTOL, + "isclose": torch.isclose(mini_loss_w_grad_accum, mini_loss, rtol=GRAD_ACCUM_RTOL), + } + ) + self._optimizer_step() + 
self.critic_optimizer.zero_grad() + + test_info_df = pd.DataFrame(test_infos) + print(test_info_df) + + return metrics diff --git a/tests/e2e/grad_accum/test_grad_accum.py b/tests/e2e/grad_accum/test_grad_accum.py new file mode 100644 index 00000000000..06f95f98fe1 --- /dev/null +++ b/tests/e2e/grad_accum/test_grad_accum.py @@ -0,0 +1,447 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. 
+""" + +import os +import uuid +from copy import deepcopy +from unittest.mock import patch + +import hydra +import numpy as np +import ray +import torch +import torch.distributed +from omegaconf import OmegaConf +from tqdm import tqdm + +from verl import DataProto +from verl.single_controller.base.worker import Worker +from verl.single_controller.ray.base import ( + RayClassWithInitArgs, + _bind_workers_method_to_parent, + _unwrap_ray_remote, +) +from verl.trainer.main_ppo import get_custom_reward_fn +from verl.trainer.ppo.metric_utils import ( + compute_data_metrics, + compute_throughout_metrics, + compute_timing_metrics, + reduce_metrics, +) +from verl.trainer.ppo.ray_trainer import ( + AdvantageEstimator, + RayPPOTrainer, + Role, + _timer, + apply_kl_penalty, + calc_mini_batch_loss_token_nums, + compute_advantage, +) + +from .grad_accum_test_fsdp_workers import ( + GradAccumulationTestActorRolloutRefWorker, + GradAccumulationTestCriticWorker, +) + + +def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]): + """ + This function should return a class instance that delegates the calls to every + cls in cls_dict + """ + # TODO: MegatronWorker + worker_cls = Worker + assert all(issubclass(cls.cls.__ray_actor_class__, worker_cls) for cls in class_dict.values()), f"all classes in class_dict must be subclass of {worker_cls=}" + + cls_dict = {} + init_args_dict = {} + + for key, cls in class_dict.items(): + cls_dict[key] = cls.cls + init_args_dict[key] = {"args": cls.args, "kwargs": cls.kwargs} + + assert cls_dict.keys() == init_args_dict.keys() + + # TODO: create a class with customizable name + class WorkerDict(worker_cls): + def __init__(self): + super().__init__() + self.worker_dict = {} + for key, user_defined_cls in cls_dict.items(): + user_defined_cls = _unwrap_ray_remote(user_defined_cls) + # directly instantiate the class without remote + # in worker class, e.g. 
when DISABLE_WORKER_INIT == 1 it will return immediately + with patch.dict(os.environ, {"DISABLE_WORKER_INIT": "1"}): + self.worker_dict[key] = user_defined_cls(*init_args_dict[key].get("args", ()), **init_args_dict[key].get("kwargs", {})) + + # now monkey-patch the methods from inner class to WorkerDict + for key, user_defined_cls in cls_dict.items(): + user_defined_cls = _unwrap_ray_remote(user_defined_cls) + _bind_workers_method_to_parent(WorkerDict, key, user_defined_cls) + + remote_cls = ray.remote(WorkerDict) + remote_cls = RayClassWithInitArgs(cls=remote_cls) + return remote_cls + + +class GradAccumulationTestTrainer(RayPPOTrainer): + """ + Note that this trainer runs on the driver process on a single CPU/GPU node. + """ + + def init_workers(self): + """Init resource pool and worker group""" + self.resource_pool_manager.create_resource_pool() + + self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} + + # create actor and rollout + if self.hybrid_engine: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) + actor_rollout_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.ActorRollout], config=self.config.actor_rollout_ref, role="actor_rollout") + self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls + else: + raise NotImplementedError + + # create critic + if self.use_critic: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) + critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic) + self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls + + # create reference policy if needed + if self.use_reference_policy: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) + ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy], config=self.config.actor_rollout_ref, role="ref") + 
self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls + + # create a reward model if reward_fn is None + if self.use_rm: + # we create a RM here + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) + rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model) + self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls + + # initialize WorkerGroup + # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, + # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to different worker groups. + # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. + all_wg = {} + self.wg_dicts = [] + wg_kwargs = {} # Setting up kwargs for RayWorkerGroup + if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: + wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout + + for resource_pool, class_dict in self.resource_pool_to_cls.items(): + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) + wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls, **wg_kwargs) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + # keep the referece of WorkerDict to support ray >= 2.31. 
Ref: https://github.com/ray-project/ray/pull/45699 + self.wg_dicts.append(wg_dict) + + if self.use_critic: + self.critic_wg = all_wg["critic"] + self.critic_wg.init_model() + + if self.use_reference_policy: + self.ref_policy_wg = all_wg["ref"] + self.ref_policy_wg.init_model() + + if self.use_rm: + self.rm_wg = all_wg["rm"] + self.rm_wg.init_model() + + # we should create rollout at the end so that vllm can have a better estimation of kv cache memory + self.actor_rollout_wg = all_wg["actor_rollout"] + self.actor_rollout_wg.init_model() + + def fit(self) -> None: + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from omegaconf import OmegaConf + + from verl.utils.tracking import Tracking + + logger = Tracking(project_name=self.config.trainer.project_name, experiment_name=self.config.trainer.experiment_name, default_backend=self.config.trainer.logger, config=OmegaConf.to_container(self.config, resolve=True)) + + self.global_steps = 0 + + # load checkpoint before doing anything + self._load_checkpoint() + + # add tqdm + progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + + # we start from step 1 + self.global_steps += 1 + + for epoch in range(self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + metrics = {} + timing_raw = {} + + batch = DataProto.from_single_dict(batch_dict) + + # pop those keys for generation + if "multi_modal_inputs" in batch.non_tensor_batch.keys(): + gen_batch = batch.pop( + batch_keys=["input_ids", "attention_mask", "position_ids"], + non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"], + ) + else: + gen_batch = batch.pop( + batch_keys=["input_ids", "attention_mask", "position_ids"], + non_tensor_batch_keys=["raw_prompt_ids"], + ) + + with 
_timer("step", timing_raw): + # generate a batch + with _timer("gen", timing_raw): + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + + if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: + with _timer("gen_max", timing_raw): + gen_baseline_batch = deepcopy(gen_batch) + gen_baseline_batch.meta_info["do_sample"] = False + gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + + batch = batch.union(gen_baseline_output) + reward_baseline_tensor = self.reward_fn(batch) + reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) + + batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) + + batch.batch["reward_baselines"] = reward_baseline_tensor + + del gen_baseline_batch, gen_baseline_output + + batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object) + # repeat to align with repeated responses in rollout + batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.union(gen_batch_output) + + # balance the number of valid tokens on each dp rank. + # Note that this breaks the order of data inside the batch. 
+ # Please take care when you implement group based adv computation such as GRPO and rloo + if self.config.trainer.balance_batch: + self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + + batch.meta_info["mini_batch_loss_token_nums"] = calc_mini_batch_loss_token_nums(batch, traj_mini_bsz=self.config.actor_rollout_ref.actor.ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n, num_dp_ranks=self.actor_rollout_wg.world_size) + + # recompute old_log_probs + with _timer("old_log_prob", timing_raw): + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + batch = batch.union(old_log_prob) + + if self.use_reference_policy: + # compute reference log_prob + with _timer("ref", timing_raw): + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + # compute values + if self.use_critic: + with _timer("values", timing_raw): + values = self.critic_wg.compute_values(batch) + batch = batch.union(values) + + with _timer("adv", timing_raw): + # compute scores. Support both model and function-based. + # We first compute the scores using reward model. Then, we call reward_fn to combine + # the results from reward model and rule-based results. 
+ if self.use_rm: + # we first compute reward model score + reward_tensor = self.rm_wg.compute_rm_score(batch) + batch = batch.union(reward_tensor) + + # we combine with rule-based rm + reward_extra_infos_dict: dict[str, list] + try: + reward_result = self.reward_fn(batch, return_dict=True) + reward_tensor = reward_result["reward_tensor"] + reward_extra_infos_dict = reward_result["reward_extra_info"] + except Exception as e: + print(f"Error in reward_fn: {e}") + reward_tensor = self.reward_fn(batch) + reward_extra_infos_dict = {} + + batch.batch["token_level_scores"] = reward_tensor + + print(f"{list(reward_extra_infos_dict.keys())=}") + if reward_extra_infos_dict: + batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) + + # compute rewards. apply_kl_penalty if available + if self.config.algorithm.use_kl_in_reward: + batch, kl_metrics = apply_kl_penalty(batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty) + metrics.update(kl_metrics) + else: + batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] + + # compute advantages, executed on the driver process + batch = compute_advantage(batch, adv_estimator=self.config.algorithm.adv_estimator, gamma=self.config.algorithm.gamma, lam=self.config.algorithm.lam, num_repeat=self.config.actor_rollout_ref.rollout.n) + + # update critic + if self.use_critic: + with _timer("update_critic", timing_raw): + critic_output = self.critic_wg.update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with _timer("update_actor", timing_raw): + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + + # collect metrics + 
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + # TODO: implement actual tflpo and theoretical tflpo + n_gpus = self.resource_pool_manager.get_n_gpus() + metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + progress_bar.update(1) + self.global_steps += 1 + + +@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head +class TaskRunner: + def run(self, config): + # print initial config + from pprint import pprint + + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + OmegaConf.resolve(config) + + # download the checkpoint from hdfs + local_path = copy_to_local(config.actor_rollout_ref.model.path) + + # instantiate tokenizer + from verl.utils import hf_processor, hf_tokenizer + + tokenizer = hf_tokenizer(local_path) + processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none + + # define worker classes + if config.actor_rollout_ref.actor.strategy == "fsdp": + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.single_controller.ray import RayWorkerGroup + + actor_rollout_ref_worker_cls = GradAccumulationTestActorRolloutRefWorker + critic_worker_cls = GradAccumulationTestCriticWorker + ray_worker_group_cls = RayWorkerGroup + elif config.actor_rollout_ref.actor.strategy == "megatron": + raise NotImplementedError("TODO: Gradient accumulation test for Megatron is not supported yet.") + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + + role_worker_mapping = {Role.ActorRollout: ray.remote(actor_rollout_ref_worker_cls), Role.Critic: 
ray.remote(critic_worker_cls), Role.RefPolicy: ray.remote(actor_rollout_ref_worker_cls)} + + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + Role.Critic: global_pool_id, + Role.RefPolicy: global_pool_id, + } + + # we should adopt a multi-source reward function here + # - for rule-based rm, we directly call a reward score + # - for model-based rm, we call a model + # - for code related prompt, we send to a sandbox if there are test cases + # - finally, we combine all the rewards together + # - The reward type depends on the tag of the data + if config.reward_model.enable: + if config.reward_model.strategy == "fsdp": + from verl.workers.fsdp_workers import RewardModelWorker + elif config.reward_model.strategy == "megatron": + from verl.workers.megatron_workers import RewardModelWorker + else: + raise NotImplementedError + role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) + mapping[Role.RewardModel] = global_pool_id + + # reference model + if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + role_worker_mapping[Role.RefPolicy] = ray.remote(GradAccumulationTestActorRolloutRefWorker) + mapping[Role.RefPolicy] = global_pool_id + + reward_manager_name = config.reward_model.get("reward_manager", "naive") + if reward_manager_name == "naive": + from verl.workers.reward_manager import NaiveRewardManager + + reward_manager_cls = NaiveRewardManager + elif reward_manager_name == "prime": + from verl.workers.reward_manager import PrimeRewardManager + + reward_manager_cls = PrimeRewardManager + elif reward_manager_name == "dapo": + from verl.workers.reward_manager import DAPORewardManager + + reward_manager_cls = DAPORewardManager + else: + raise NotImplementedError + + compute_score = get_custom_reward_fn(config) + reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, 
@hydra.main(config_path="../../../verl/trainer/config", config_name="ppo_trainer", version_base=None)
def run_ppo(config) -> None:
    """Hydra entry point: bring up a local Ray cluster (if one is not already
    running) and execute the PPO task remotely via ``TaskRunner``.

    Args:
        config: hydra/omegaconf config resolved from ``ppo_trainer.yaml``.
    """
    # TODO(linjunrong.ocss884): this ENV is left for resolving SGLang conflict with ray devices
    # isolation, will solve in the future
    os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "")

    if not ray.is_initialized():
        # Local-cluster bootstrap; quiet down tokenizer/NCCL/vLLM chatter in workers.
        worker_env = {
            "TOKENIZERS_PARALLELISM": "true",
            "NCCL_DEBUG": "WARN",
            "VLLM_LOGGING_LEVEL": "WARN",
        }
        ray.init(runtime_env={"env_vars": worker_env})

    # Schedule the task on a worker process and block until it finishes.
    ray.get(TaskRunner.remote().run.remote(config))
MAX_RESPONSE_LEN=${MAX_RESPONSE_LEN:-512}

ENGINE=${ENGINE:-vllm}
RM_PAD=${RM_PAD:-True}
ADV_ESTIMATOR=${ADV_ESTIMATOR:-gae}
USE_KL=${USE_KL:-False}
CUSTOM_REWARD_FN=${CUSTOM_REWARD_FN:-False}
# Validation
VAL_BEFORE_TRAIN=${VAL_BEFORE_TRAIN:-False}
TEST_FREQ=${TEST_FREQ:--1}
# Save & Resume
RESUME_MODE=${RESUME_MODE:-disable}
SAVE_FREQ=${SAVE_FREQ:--1}
TOT_TRAIN_STEPS=${TOT_TRAIN_STEPS:-1}

NUM_GPUS=${NUM_GPUS:-8}

# Batch-size bookkeeping: b = per-GPU trajectory micro batch, g = responses per prompt.
train_traj_micro_bsz_per_gpu=2 # b
n_resp_per_prompt=4 # g
num_micro_batches=2

train_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS))   # global trajectory micro batch
train_traj_mini_bsz=$((train_traj_micro_bsz * num_micro_batches))   # global trajectory mini batch
# FIX: one prompt expands into g trajectories, so the prompt-level mini batch is the
# trajectory mini batch DIVIDED by g. The original multiplied by g, contradicting its
# own "/ g" comment and inflating the configured mini batch by a factor of g^2.
train_prompt_mini_bsz=$((train_traj_mini_bsz / n_resp_per_prompt))  # = traj_mini / g
train_prompt_bsz=$((train_prompt_mini_bsz * 2))                     # two mini batches per step

exp_name="$(basename "${MODEL_ID,,}")-test-grad-accum"

# FIX: pass rollout.n explicitly — n_resp_per_prompt participates in the batch-size
# arithmetic above, so the trainer must actually sample g responses per prompt.
python3 -m tests.e2e.grad_accum.test_grad_accum \
    algorithm.adv_estimator="${ADV_ESTIMATOR}" \
    data.train_files="${TRAIN_FILES}" \
    data.val_files="${VAL_FILES}" \
    data.train_batch_size=${train_prompt_bsz} \
    data.max_prompt_length="${MAX_PROMPT_LEN}" \
    data.max_response_length="${MAX_RESPONSE_LEN}" \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding="${RM_PAD}" \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.actor.use_kl_loss="${USE_KL}" \
    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name="${ENGINE}" \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
    critic.optim.lr=1e-5 \
    critic.model.use_remove_padding="${RM_PAD}" \
    critic.model.path="${MODEL_PATH}" \
    critic.model.enable_gradient_checkpointing=False \
    critic.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
    critic.model.fsdp_config.param_offload=False \
    critic.model.fsdp_config.optimizer_offload=False \
    algorithm.use_kl_in_reward="${USE_KL}" \
    algorithm.kl_penalty=kl \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.critic_warmup=0 \
    trainer.logger=['console'] \
    trainer.project_name='verl-test' \
    trainer.experiment_name="${exp_name}" \
    trainer.nnodes=1 \
    trainer.n_gpus_per_node="${NUM_GPUS}" \
    trainer.val_before_train="${VAL_BEFORE_TRAIN}" \
    trainer.test_freq="${TEST_FREQ}" \
    trainer.save_freq="${SAVE_FREQ}" \
    trainer.resume_mode="${RESUME_MODE}" \
    trainer.total_epochs=2 \
    trainer.total_training_steps="${TOT_TRAIN_STEPS}" "$@"
micro_bsz_idx_lst = get_uniform_data_chunks(dataproto, max_token_len=300) + micro_batches = [micro_data_chunk.batch for micro_data_chunk in micro_data_chunks] batch = torch.cat(micro_batches) micro_bsz_idx = [] for idx in micro_bsz_idx_lst: diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 1b6668dce8a..cec8dc5b864 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -164,6 +164,7 @@ critic: shuffle: ${actor_rollout_ref.actor.shuffle} grad_clip: 1.0 cliprange_value: 0.5 + loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode} checkpoint: contents: ['model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py index a0e7cd8af48..70ed1eec370 100644 --- a/verl/trainer/ppo/core_algos.py +++ b/verl/trainer/ppo/core_algos.py @@ -356,11 +356,11 @@ def compute_policy_loss( log_prob, advantages, response_mask, + loss_agg_mode, cliprange=None, cliprange_low=None, cliprange_high=None, clip_ratio_c=3.0, - loss_agg_mode="token-mean", ): """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122 Args: @@ -372,6 +372,8 @@ def compute_policy_loss( shape: (bs, response_length) response_mask: `(torch.Tensor)` shape: (bs, response_length) + loss_agg_mode: (str) choices: "token-mean" / "seq-mean-token-sum" / "seq-mean-token-mean" / "seq-mean-token-sum-norm" + "token-mean" is the default behavior cliprange: (float) The clip range used in PPO. See https://arxiv.org/abs/1707.06347 cliprange_low: (float) @@ -380,11 +382,6 @@ def compute_policy_loss( The higher clip range used in PPO. 
def compute_entropy_loss(logits: torch.Tensor, response_mask: torch.Tensor, loss_agg_mode: str):
    """Aggregate per-token categorical entropy into a scalar entropy loss.

    Args:
        logits: `(torch.Tensor)`
            shape: (bs, response_length, vocab_size)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        loss_agg_mode: (str) choices: "token-mean" / "seq-mean-token-sum" /
            "seq-mean-token-mean" / "seq-mean-token-sum-norm";
            "token-mean" is the default behavior

    Returns:
        entropy: a scalar torch.Tensor
    """
    # per-token entropy, then aggregate under the configured mode
    token_entropy = verl_F.entropy_from_logits(logits)  # (bs, response_len)
    return agg_loss(loss_mat=token_entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)


def compute_kl_loss(kld: torch.Tensor, response_mask: torch.Tensor, loss_agg_mode: str):
    """Compute KL divergence loss

    Args:
        kld: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        loss_agg_mode: (str) choices: "token-mean" / "seq-mean-token-sum" / "seq-mean-token-mean";
            "token-mean" is the default behavior

    Returns:
        kl_loss: `torch.Tensor`, shape: ()
    """
    return agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
def compute_value_loss(vpreds: torch.Tensor, returns: torch.Tensor, values: torch.Tensor, response_mask: torch.Tensor, cliprange_value: float, loss_agg_mode: str):
    """Clipped value-function loss.
    Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151

    Args:
        vpreds (`torch.FloatTensor`):
            Predicted values of the value head, shape (`batch_size`, `response_length`)
        returns: (`torch.FloatTensor`):
            Ground truth returns, shape (`batch_size`, `response_length`)
        values (`torch.FloatTensor`):
            Old values of value head, shape (`batch_size`, `response_length`)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length)
        cliprange_value: (float) half-width of the clipping interval around `values`
        loss_agg_mode: (str) choices: "token-mean" / "seq-mean-token-sum" / "seq-mean-token-mean";
            "token-mean" is the default behavior

    Returns:
        vf_loss: a scalar (`torch.FloatTensor`): value-function loss
        vf_clipfrac: a scalar, fraction of tokens where the clipped branch dominated
    """
    # Clip predictions to a band around the old values, then take the
    # element-wise worse (larger) of the clipped/unclipped squared errors.
    clipped_preds = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
    err_unclipped = (vpreds - returns) ** 2
    err_clipped = (clipped_preds - returns) ** 2
    per_token_loss = 0.5 * torch.max(err_unclipped, err_clipped)
    vf_loss = agg_loss(loss_mat=per_token_loss, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
    vf_clipfrac = verl_F.masked_mean(torch.gt(err_clipped, err_unclipped).float(), response_mask)
    return vf_loss, vf_clipfrac
def calc_mini_batch_loss_token_nums(batch: DataProto, traj_mini_bsz: int, num_dp_ranks: int) -> list[int]:
    """Pre-compute, for each PPO mini batch, the global number of loss tokens.

    NOTE: Be compatible with

    1. verl.workers.fsdp_workers.ActorRolloutRefWorker.update_actor
    2. verl.workers.fsdp_workers.CriticWorker.update_critic

    TODO: Calculate separate numbers if adopting different strategies for actor and critic

    Args:
        batch: training batch; reads ``batch.batch["responses"]`` and
            ``batch.batch["attention_mask"]``.
        traj_mini_bsz: trajectory-level mini-batch size, global across DP ranks.
        num_dp_ranks: number of data-parallel ranks the batch is sharded over.

    Returns:
        One int per mini batch: unmasked response-token count summed over the
        slice every DP rank processes for that mini batch.
    """
    # Response mask = trailing `response_length` columns of the attention mask
    # (inlined equivalent of verl.utils.torch_functional.compute_response_mask).
    responses = batch.batch["responses"]
    response_mask = batch.batch["attention_mask"][:, -responses.size(-1):]

    traj_bsz = len(batch.batch)
    num_mini_batches = (traj_bsz + traj_mini_bsz - 1) // traj_mini_bsz  # ceil division
    traj_mini_bsz_per_rank = traj_mini_bsz // num_dp_ranks

    mini_batch_loss_token_nums = []
    for mini_batch_idx in range(num_mini_batches):
        mini_batch_traj_idxs = []
        for dp_rank in range(num_dp_ranks):
            # Pure integer arithmetic for shard boundaries (the original used
            # float division, which can mis-round for large batch sizes).
            shard_start = traj_bsz * dp_rank // num_dp_ranks
            shard_end = traj_bsz * (dp_rank + 1) // num_dp_ranks
            # BUG FIX: advance by mini_batch_idx within each rank's shard.
            # The original ignored the mini-batch index, so every mini batch
            # re-counted the tokens of the first per-rank slice.
            start_traj_idx = shard_start + mini_batch_idx * traj_mini_bsz_per_rank
            end_traj_idx = min(start_traj_idx + traj_mini_bsz_per_rank, shard_end)
            mini_batch_traj_idxs.extend(range(start_traj_idx, end_traj_idx))
        # .item() so the declared list[int] return type holds (the original
        # appended 0-dim tensors).
        mini_batch_loss_token_nums.append(int(response_mask[mini_batch_traj_idxs].sum().item()))

    return mini_batch_loss_token_nums
grpo_calculation_mask.size(1) # Get length from the initial response mask grpo_calculation_mask = data.batch["loss_mask"][:, -response_length:] # This mask is the one intended for GRPO # Call compute_grpo_outcome_advantage with parameters matching its definition - advantages, returns = core_algos.compute_grpo_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - response_mask=grpo_calculation_mask, - index=data.non_tensor_batch["uid"], - norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, - ) + advantages, returns = core_algos.compute_grpo_outcome_advantage(token_level_rewards=data.batch["token_level_rewards"], response_mask=response_mask, index=data.non_tensor_batch["uid"]) data.batch["advantages"] = advantages data.batch["returns"] = returns elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE: @@ -224,28 +233,16 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re data.batch["advantages"] = advantages data.batch["returns"] = returns elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS: - advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - response_mask=data.batch["response_mask"], - gamma=gamma, - ) + advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(token_level_rewards=data.batch["token_level_rewards"], response_mask=response_mask, gamma=gamma) data.batch["advantages"] = advantages data.batch["returns"] = returns elif adv_estimator == AdvantageEstimator.REMAX: - advantages, returns = core_algos.compute_remax_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - reward_baselines=data.batch["reward_baselines"], - response_mask=data.batch["response_mask"], - ) + advantages, returns = core_algos.compute_remax_outcome_advantage(token_level_rewards=data.batch["token_level_rewards"], reward_baselines=data.batch["reward_baselines"], 
response_mask=response_mask) data.batch["advantages"] = advantages data.batch["returns"] = returns elif adv_estimator == AdvantageEstimator.RLOO: - advantages, returns = core_algos.compute_rloo_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - response_mask=data.batch["response_mask"], - index=data.non_tensor_batch["uid"], - ) + advantages, returns = core_algos.compute_rloo_outcome_advantage(token_level_rewards=data.batch["token_level_rewards"], response_mask=response_mask, index=data.non_tensor_batch["uid"]) data.batch["advantages"] = advantages data.batch["returns"] = returns else: @@ -951,7 +948,6 @@ def fit(self): batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) batch = batch.union(gen_batch_output) - batch.batch["response_mask"] = compute_response_mask(batch) # balance the number of valid tokens on each dp rank. # Note that this breaks the order of data inside the batch. # Please take care when you implement group based adv computation such as GRPO and rloo @@ -972,6 +968,8 @@ def fit(self): else: reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn) + batch.meta_info["mini_batch_loss_token_nums"] = calc_mini_batch_loss_token_nums(batch, traj_mini_bsz=self.config.actor_rollout_ref.actor.ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n, num_dp_ranks=self.actor_rollout_wg.world_size) + # recompute old_log_probs with _timer("old_log_prob", timing_raw): old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) diff --git a/verl/utils/seqlen_balancing.py b/verl/utils/seqlen_balancing.py index 16dca31c961..25ece3b5bca 100644 --- a/verl/utils/seqlen_balancing.py +++ b/verl/utils/seqlen_balancing.py @@ -14,12 +14,13 @@ import copy import heapq -from typing import List, Tuple +from typing import List, Optional, Tuple import torch -from tensordict import TensorDict from torch import distributed as dist +from verl import DataProto + def 
karmarkar_karp(seqlen_list: List[int], k_partitions: int, equal_size: bool): # see: https://en.wikipedia.org/wiki/Largest_differencing_method @@ -213,15 +214,15 @@ def ceildiv(a, b): return -(a // -b) -def rearrange_micro_batches(batch: TensorDict, max_token_len, dp_group=None): +def get_uniform_data_chunks(data: DataProto, max_token_len: int, dp_group: Optional[dist.ProcessGroup] = None) -> tuple[list[DataProto], list[list[int]]]: """Split the batch into a list of micro_batches, where the max_token_len is smaller than max_token_len and the number of valid tokens in each micro batch is well balanced. """ # this is per local micro_bsz - max_seq_len = batch["attention_mask"].shape[-1] + max_seq_len = data.batch["attention_mask"].shape[-1] assert max_token_len >= max_seq_len, f"max_token_len must be greater than the sequence length. Got {max_token_len=} and {max_seq_len=}" - seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1) + seq_len_effective: torch.Tensor = data.batch["attention_mask"].sum(dim=1) total_seqlen = seq_len_effective.sum().item() num_micro_batches = ceildiv(total_seqlen, max_token_len) if dist.is_initialized(): @@ -232,19 +233,11 @@ def rearrange_micro_batches(batch: TensorDict, max_token_len, dp_group=None): seq_len_effective = seq_len_effective.tolist() assert num_micro_batches <= len(seq_len_effective) - micro_bsz_idx = get_seqlen_balanced_partitions(seq_len_effective, num_micro_batches, equal_size=False) - - micro_batches = [] - - for partition in micro_bsz_idx: - curr_micro_batch = [] - for idx in partition: - curr_micro_batch.append(batch[idx : idx + 1]) - curr_micro_batch = torch.cat(curr_micro_batch) + micro_idx_partitions = get_seqlen_balanced_partitions(seq_len_effective, num_micro_batches, equal_size=False) - micro_batches.append(curr_micro_batch) + uniform_data_chunks = [data.select_idxs(partition) for partition in micro_idx_partitions] - return micro_batches, micro_bsz_idx + return uniform_data_chunks, micro_idx_partitions 
def compute_response_mask(response_ids: torch.Tensor, attention_mask: torch.Tensor):
    """Slice the response segment out of a full-sequence attention mask.

    The response occupies the trailing ``response_ids.size(-1)`` positions of
    the concatenated prompt+response sequence, so the mask for it is simply the
    last ``response_length`` columns of ``attention_mask``.

    Args:
        response_ids: (bs, response_length) response token ids; only its last
            dimension is read.
        attention_mask: (bs, prompt_length + response_length) full mask.

    Returns:
        (bs, response_length) mask over response positions.
    """
    resp_len = response_ids.size(-1)
    return attention_mask[:, -resp_len:]
@GPUMemoryLogger(role="dp actor", logger=logger) - def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Tensor: + def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Compute the log probability of the responses given input_ids, attention_mask and position_ids Args: @@ -189,35 +189,32 @@ def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Te ``responses``: tensor of shape [batch_size, response_length]. torch.int64. Returns: - torch.Tensor: the log_prob tensor + log_probs: tensor of shape [batch_size, response_length]. torch.float32. + entropys: tensor of shape [batch_size, response_length]. torch.float32. """ # set to eval self.actor_module.eval() - micro_batch_size = data.meta_info["micro_batch_size"] temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid slient error use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] - batch = data.select(batch_keys=select_keys).batch - has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() - - if has_multi_modal_inputs: - num_micro_batches = data.batch.batch_size[0] // micro_batch_size - non_tensor_select_keys = ["multi_modal_inputs"] - micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) - elif use_dynamic_bsz: - # split using dynamic bsz + non_tensor_select_keys = ["multi_modal_inputs"] if "multi_modal_inputs" in data.non_tensor_batch.keys() else [] + + selected_data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) + + if use_dynamic_bsz: max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size - micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) + micro_data_chunks, indices = get_uniform_data_chunks(data=selected_data, 
max_token_len=max_token_len) else: - micro_batches = batch.split(micro_batch_size) + micro_batch_size = data.meta_info["micro_batch_size"] + num_micro_batches = len(selected_data) // micro_batch_size + micro_data_chunks = selected_data.chunk(num_micro_batches) log_probs_lst = [] entropy_lst = [] - for micro_batch in micro_batches: - if isinstance(micro_batch, DataProto): - micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch} + for micro_data_chunk in micro_data_chunks: + micro_batch = {**micro_data_chunk.batch, **micro_data_chunk.non_tensor_batch} with torch.no_grad(): entropy, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature, calculate_entropy=calculate_entropy) log_probs_lst.append(log_probs) @@ -249,53 +246,38 @@ def update_policy(self, data: DataProto): select_keys.append("loss_mask") if self.config.use_kl_loss: select_keys.append("ref_log_prob") - batch = data.select(batch_keys=select_keys).batch - has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + non_tensor_select_keys = ["multi_modal_inputs"] if "multi_modal_inputs" in data.non_tensor_batch.keys() else [] + + selected_data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) # Split to make minibatch iterator for updating the actor # See PPO paper for details. 
https://arxiv.org/abs/1707.06347 - if has_multi_modal_inputs: - num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size - non_tensor_select_keys = ["multi_modal_inputs"] - dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches) - else: - dataloader = batch.split(self.config.ppo_mini_batch_size) + num_mini_batches = len(selected_data) // self.config.ppo_mini_batch_size + mini_dataloader = selected_data.chunk(num_mini_batches) # TODO: `make_minibatch_iterator`` as in megatron metrics = {} for epoch in range(self.config.ppo_epochs): - for batch_idx, data in enumerate(dataloader): + for mini_idx, mini_data_chunk in enumerate(mini_dataloader): # split batch into micro_batches - mini_batch = data - if has_multi_modal_inputs: - self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu - num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu - micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) - elif self.config.use_dynamic_bsz: + if self.config.use_dynamic_bsz: max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) + micro_data_chunks, _ = get_uniform_data_chunks(data=mini_data_chunk, max_token_len=max_token_len) else: - self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu - # split batch into micro_batches - micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) + num_micro_batches = len(mini_data_chunk) // self.config.ppo_micro_batch_size_per_gpu + micro_data_chunks = mini_data_chunk.chunk(num_micro_batches) self.actor_optimizer.zero_grad() - for data in micro_batches: - # Support all hardwares - if isinstance(data, DataProto): - data = {**data.batch.to(torch.cuda.current_device()), 
**data.non_tensor_batch} - else: - data = data.to(torch.cuda.current_device()) # actor device is cpu when using offload - responses = data["responses"] - response_length = responses.size(1) - attention_mask = data["attention_mask"] + for micro_data_chunk in micro_data_chunks: + micro_batch = {**micro_data_chunk.batch.to(torch.cuda.current_device()), **micro_data_chunk.non_tensor_batch} + response_length = micro_batch["responses"].size(-1) + attention_mask = micro_batch["attention_mask"] if multi_turn: - response_mask = data["loss_mask"][:, -response_length:] + response_mask = micro_batch["loss_mask"][:, -response_length:] else: response_mask = attention_mask[:, -response_length:] - - old_log_prob = data["old_log_probs"] - advantages = data["advantages"] + old_log_prob = micro_batch["old_log_probs"] + advantages = micro_batch["advantages"] clip_ratio = self.config.clip_ratio clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio @@ -308,7 +290,7 @@ def update_policy(self, data: DataProto): calculate_entropy = False if entropy_coeff != 0: calculate_entropy = True - entropy, log_prob = self._forward_micro_batch(micro_batch=data, temperature=temperature, calculate_entropy=calculate_entropy) + entropy, log_prob = self._forward_micro_batch(micro_batch=micro_batch, temperature=temperature, calculate_entropy=calculate_entropy) pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss( old_log_prob=old_log_prob, @@ -331,7 +313,7 @@ def update_policy(self, data: DataProto): policy_loss = pg_loss if self.config.use_kl_loss: - ref_log_prob = data["ref_log_prob"] + ref_log_prob = micro_batch["ref_log_prob"] # compute kl loss kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type) kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode) @@ -340,23 +322,27 @@ def update_policy(self, data: DataProto): metrics["actor/kl_loss"] = kl_loss.detach().item() metrics["actor/kl_coef"] = self.config.kl_loss_coef -
if self.config.use_dynamic_bsz: - # relative to the dynamic bsz - loss = policy_loss * (len(data) / self.config.ppo_mini_batch_size) - else: - loss = policy_loss / self.gradient_accumulation + if self.config.loss_agg_mode == "token-mean": + mini_batch_loss_token_nums = data.meta_info["mini_batch_loss_token_nums"] + mini_batch_loss_token_num = mini_batch_loss_token_nums[mini_idx] + num_valid_toks = response_mask.sum() + loss = policy_loss * num_valid_toks / mini_batch_loss_token_num + else: # seq-mean + loss = policy_loss * (len(micro_data_chunk) / self.config.ppo_mini_batch_size) + loss.backward() - data = { + mini_metric_data = { + "actor/entropy": entropy_loss.detach().item(), "actor/pg_loss": pg_loss.detach().item(), "actor/pg_clipfrac": pg_clipfrac.detach().item(), "actor/ppo_kl": ppo_kl.detach().item(), "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(), } - append_to_dict(metrics, data) + append_to_dict(metrics, mini_metric_data) grad_norm = self._optimizer_step() - data = {"actor/grad_norm": grad_norm.detach().item()} - append_to_dict(metrics, data) + metric_data = {"actor/grad_norm": grad_norm.detach().item()} + append_to_dict(metrics, metric_data) self.actor_optimizer.zero_grad() return metrics diff --git a/verl/workers/critic/dp_critic.py b/verl/workers/critic/dp_critic.py index f87dc3a9cf4..4e99ec8c266 100644 --- a/verl/workers/critic/dp_critic.py +++ b/verl/workers/critic/dp_critic.py @@ -29,7 +29,7 @@ from verl.trainer.ppo import core_algos from verl.utils.debug import GPUMemoryLogger from verl.utils.py_functional import append_to_dict -from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches +from verl.utils.seqlen_balancing import get_reverse_idx, get_uniform_data_chunks from verl.utils.torch_functional import masked_mean from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs from verl.workers.critic import BasePPOCritic @@ -128,27 +128,25 @@ def _optimizer_step(self): 
@GPUMemoryLogger(role="dp critic", logger=logger) def compute_values(self, data: DataProto) -> torch.Tensor: self.critic_module.eval() - micro_batch_size = data.meta_info["micro_batch_size"] + select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] - batch = data.select(batch_keys=select_keys).batch + non_tensor_select_keys = ["multi_modal_inputs"] if "multi_modal_inputs" in data.non_tensor_batch.keys() else [] + + selected_data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) + use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] - has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() - - if has_multi_modal_inputs: - num_micro_batches = data.batch.batch_size[0] // micro_batch_size - non_tensor_select_keys = ["multi_modal_inputs"] - micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) - elif use_dynamic_bsz: - # split using dynamic bsz + + if use_dynamic_bsz: max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size - micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) + micro_data_chunks, indices = get_uniform_data_chunks(data=selected_data, max_token_len=max_token_len) else: - micro_batches = batch.split(micro_batch_size) + micro_batch_size = data.meta_info["micro_batch_size"] + num_micro_batches = len(selected_data) // micro_batch_size + micro_data_chunks = selected_data.chunk(num_micro_batches) values_lst = [] - for micro_batch in micro_batches: - if isinstance(micro_batch, DataProto): - micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch} + for micro_data_chunk in micro_data_chunks: + micro_batch = {**micro_data_chunk.batch, **micro_data_chunk.non_tensor_batch} with torch.no_grad(): values = self._forward_micro_batch(micro_batch) @@ -174,77 +172,62 @@ def update_critic(self, data: DataProto): metrics = {} select_keys = ["input_ids", "responses", "attention_mask", 
"position_ids", "values", "returns"] - batch = data.select(batch_keys=select_keys).batch - has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + non_tensor_select_keys = ["multi_modal_inputs"] if "multi_modal_inputs" in data.non_tensor_batch.keys() else [] + + selected_data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) # Split to make minibatch iterator for updating the actor # See PPO paper for details. https://arxiv.org/abs/1707.06347 - if has_multi_modal_inputs: - num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size - non_tensor_select_keys = ["multi_modal_inputs"] - dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches) - else: - dataloader = batch.split(self.config.ppo_mini_batch_size) + num_mini_batches = len(selected_data) // self.config.ppo_mini_batch_size + mini_dataloader = selected_data.chunk(num_mini_batches) # TODO: `make_minibatch_iterator`` as in megatron for epoch in range(self.config.ppo_epochs): - for batch_idx, data in enumerate(dataloader): - # split batch into micro_batches - mini_batch = data - if has_multi_modal_inputs: - num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu - micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) - elif self.config.use_dynamic_bsz: + for mini_idx, mini_data_chunk in enumerate(mini_dataloader): + if self.config.use_dynamic_bsz: max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) + micro_data_chunks, _ = get_uniform_data_chunks(data=mini_data_chunk, max_token_len=max_token_len) else: - micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) - self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + num_micro_batches = 
len(mini_data_chunk) // self.config.ppo_micro_batch_size_per_gpu + micro_data_chunks = mini_data_chunk.chunk(num_micro_batches) self.critic_optimizer.zero_grad() - for data in micro_batches: - # Support all devices - if isinstance(data, DataProto): - data = {**data.batch.to(torch.cuda.current_device()), **data.non_tensor_batch} - else: - data = data.to(torch.cuda.current_device()) # critic device is cpu when using offload - responses = data["responses"] - attention_mask = data["attention_mask"] - values = data["values"] - returns = data["returns"] - response_length = responses.size(1) + for micro_data_chunk in micro_data_chunks: + micro_batch = {**micro_data_chunk.batch.to(torch.cuda.current_device()), **micro_data_chunk.non_tensor_batch} - response_mask = attention_mask[:, -response_length - 1 : -1] + response_ids = micro_batch["responses"] + attention_mask = micro_batch["attention_mask"] + values = micro_batch["values"] + returns = micro_batch["returns"] + response_length = response_ids.size(1) - vpreds = self._forward_micro_batch(data) + vpreds = self._forward_micro_batch(micro_batch) # assert not torch.any(torch.isnan(vpreds)).item() - vf_loss, vf_clipfrac = core_algos.compute_value_loss( - vpreds=vpreds, - values=values, - returns=returns, - response_mask=response_mask, - cliprange_value=self.config.cliprange_value, - ) - if self.config.use_dynamic_bsz: - # relative to the dynamic bsz - loss = vf_loss * (len(data) / self.config.ppo_mini_batch_size) - else: - loss = vf_loss / self.gradient_accumulation + state_mask = attention_mask[:, -response_length - 1 : -1] + vf_loss, vf_clipfrac = core_algos.compute_value_loss(vpreds=vpreds, values=values, returns=returns, response_mask=state_mask, cliprange_value=self.config.cliprange_value, loss_agg_mode=self.config.loss_agg_mode) + + if self.config.loss_agg_mode == "token-mean": + mini_batch_loss_token_nums = data.meta_info["mini_batch_loss_token_nums"] + mini_batch_loss_token_num = mini_batch_loss_token_nums[mini_idx] + num_loss_toks =
state_mask.sum() + loss = vf_loss * num_loss_toks / mini_batch_loss_token_num + else: # seq-mean + loss = vf_loss * (len(micro_data_chunk) / self.config.ppo_mini_batch_size) loss.backward() - data = { + mini_metric_data = { "critic/vf_loss": vf_loss.detach().item(), "critic/vf_clipfrac": vf_clipfrac.detach().item(), - "critic/vpred_mean": masked_mean(vpreds, response_mask).detach().item(), + "critic/vpred_mean": masked_mean(vpreds, state_mask).detach().item(), } - append_to_dict(metrics, data) + append_to_dict(metrics, mini_metric_data) grad_norm = self._optimizer_step() - data = {"critic/grad_norm": grad_norm.detach().item()} - append_to_dict(metrics, data) + metric_data = {"critic/grad_norm": grad_norm.detach().item()} + append_to_dict(metrics, metric_data) self.critic_optimizer.zero_grad() return metrics diff --git a/verl/workers/critic/megatron_critic.py b/verl/workers/critic/megatron_critic.py index 68c1d51e889..42419054741 100644 --- a/verl/workers/critic/megatron_critic.py +++ b/verl/workers/critic/megatron_critic.py @@ -163,6 +163,7 @@ def loss_func(output, data, meta_info): returns=returns, response_mask=response_mask, cliprange_value=cliprange_value, + loss_agg_mode=self.config.loss_agg_mode, ) stats = { "critic/vf_loss": vf_loss.detach().item(), diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index c1b501c0cb9..838efa55585 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -1195,7 +1195,7 @@ def _switch_chat_template(self, data: DataProto): def compute_rm_score(self, data: DataProto): import itertools - from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches + from verl.utils.seqlen_balancing import get_reverse_idx, get_uniform_data_chunks # Support all hardwares data = data.to(torch.cuda.current_device()) @@ -1223,12 +1223,13 @@ def compute_rm_score(self, data: DataProto): use_dynamic_bsz = self.config.use_dynamic_bsz if use_dynamic_bsz: max_token_len = 
self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - micro_batches, indices = rearrange_micro_batches(batch=rm_data.batch, max_token_len=max_token_len) + micro_data_chunks, indices = get_uniform_data_chunks(data=rm_data, max_token_len=max_token_len) else: - micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu) + num_micro_batches = len(rm_data) // self.config.micro_batch_size_per_gpu + micro_data_chunks = rm_data.chunk(num_micro_batches) output = [] - for micro_batch in micro_batches: - rm_score = self._forward_micro_batch(micro_batch) + for micro_data_chunk in micro_data_chunks: + rm_score = self._forward_micro_batch(micro_batch=micro_data_chunk.batch) output.append(rm_score) scores = torch.cat(output, dim=0) # (batch_size)