diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py
index d0a694cfed9..f245587e715 100644
--- a/verl/workers/actor/dp_actor.py
+++ b/verl/workers/actor/dp_actor.py
@@ -414,6 +414,8 @@ def update_policy(self, data: DataProto):
                 self.actor_optimizer.zero_grad()
 
                 for data in micro_batches:
+                    micro_batch_metrics = {}
+
                     # Support all hardwares
                     if isinstance(data, DataProto):
                         data = {**data.batch.to(get_device_id()), **data.non_tensor_batch}
@@ -484,8 +486,8 @@
                         kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
 
                         policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
-                        metrics["actor/kl_loss"] = kl_loss.detach().item()
-                        metrics["actor/kl_coef"] = self.config.kl_loss_coef
+                        micro_batch_metrics["actor/kl_loss"] = kl_loss.detach().item()
+                        micro_batch_metrics["actor/kl_coef"] = self.config.kl_loss_coef
 
                     if self.config.use_dynamic_bsz:
                         # relative to the dynamic bsz
@@ -494,16 +496,18 @@
                         loss = policy_loss / self.gradient_accumulation
                     loss.backward()
 
-                    data = {
-                        "actor/pg_loss": pg_loss.detach().item(),
-                        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
-                        "actor/ppo_kl": ppo_kl.detach().item(),
-                        "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
-                    }
-                    append_to_dict(metrics, data)
+                    micro_batch_metrics.update(
+                        {
+                            "actor/pg_loss": pg_loss.detach().item(),
+                            "actor/pg_clipfrac": pg_clipfrac.detach().item(),
+                            "actor/ppo_kl": ppo_kl.detach().item(),
+                            "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
+                        }
+                    )
+                    append_to_dict(metrics, micro_batch_metrics)
 
                 grad_norm = self._optimizer_step()
-                data = {"actor/grad_norm": grad_norm.detach().item()}
-                append_to_dict(metrics, data)
+                mini_batch_metrics = {"actor/grad_norm": grad_norm.detach().item()}
+                append_to_dict(metrics, mini_batch_metrics)
         self.actor_optimizer.zero_grad()
         return metrics
diff --git a/verl/workers/critic/dp_critic.py b/verl/workers/critic/dp_critic.py
index 9a6991db294..3a866928dcc 100644
--- a/verl/workers/critic/dp_critic.py
+++ b/verl/workers/critic/dp_critic.py
@@ -221,6 +221,8 @@ def update_critic(self, data: DataProto):
                 self.critic_optimizer.zero_grad()
 
                 for data in micro_batches:
+                    micro_batch_metrics = {}
+
                     # Support all devices
                     if isinstance(data, DataProto):
                         data = {**data.batch.to(get_device_id()), **data.non_tensor_batch}
@@ -254,16 +256,18 @@
 
                     loss.backward()
 
-                    data = {
-                        "critic/vf_loss": vf_loss.detach().item(),
-                        "critic/vf_clipfrac": vf_clipfrac.detach().item(),
-                        "critic/vpred_mean": masked_mean(vpreds, response_mask).detach().item(),
-                    }
+                    micro_batch_metrics.update(
+                        {
+                            "critic/vf_loss": vf_loss.detach().item(),
+                            "critic/vf_clipfrac": vf_clipfrac.detach().item(),
+                            "critic/vpred_mean": masked_mean(vpreds, response_mask).detach().item(),
+                        }
+                    )
 
-                    append_to_dict(metrics, data)
+                    append_to_dict(metrics, micro_batch_metrics)
 
                 grad_norm = self._optimizer_step()
-                data = {"critic/grad_norm": grad_norm.detach().item()}
-                append_to_dict(metrics, data)
+                mini_batch_metrics = {"critic/grad_norm": grad_norm.detach().item()}
+                append_to_dict(metrics, mini_batch_metrics)
         self.critic_optimizer.zero_grad()
         return metrics
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index d46b3d02605..dcf964ec202 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -1123,9 +1123,9 @@ def update_critic(self, data: DataProto):
             estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
             metrics["perf/mfu/critic"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size
 
-            self.critic_lr_scheduler.step()
             lr = self.critic_lr_scheduler.get_last_lr()[0]
             metrics["critic/lr"] = lr
+            self.critic_lr_scheduler.step()
 
             output = DataProto(batch=None, meta_info={"metrics": metrics})
             output = self.ulysses_sharding_manager.postprocess_data(data=output)
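
Why these changes matter, in brief: the old code wrote per-micro-batch values such as actor/kl_loss into the shared metrics dict with plain assignment, so each micro-batch overwrote the previous one and only the last value survived; staging them in micro_batch_metrics and passing that dict through append_to_dict records every micro-batch. The fsdp_workers.py reorder logs critic/lr before critic_lr_scheduler.step(), so the recorded learning rate is the one this update actually used rather than the post-step value. Below is a minimal, self-contained sketch of the aggregation fix; the append_to_dict here is a hypothetical stand-in that assumes verl's helper appends each value to a list keyed by metric name.

# Minimal sketch (hypothetical, not verl's implementation) of the
# micro-batch metrics fix in this diff.

def append_to_dict(metrics: dict, new_metrics: dict) -> None:
    # Stand-in for verl's helper: append each value under its metric key,
    # so per-micro-batch values can later be averaged for logging.
    for key, value in new_metrics.items():
        metrics.setdefault(key, []).append(value)

kl_losses = [0.10, 0.30, 0.20]  # fake per-micro-batch kl_loss values

# Old behavior: direct assignment keeps only the final micro-batch.
old_metrics = {}
for kl in kl_losses:
    old_metrics["actor/kl_loss"] = kl
print(old_metrics)  # {'actor/kl_loss': 0.2} -- earlier micro-batches lost

# New behavior: stage values in a per-micro-batch dict, then append.
new_metrics = {}
for kl in kl_losses:
    micro_batch_metrics = {"actor/kl_loss": kl}
    append_to_dict(new_metrics, micro_batch_metrics)
print(new_metrics)  # {'actor/kl_loss': [0.1, 0.3, 0.2]} -- all retained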