Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions verl/workers/actor/dp_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,8 @@ def update_policy(self, data: DataProto):
self.actor_optimizer.zero_grad()

for data in micro_batches:
micro_batch_metrics = {}

# Support all hardwares
if isinstance(data, DataProto):
data = {**data.batch.to(get_device_id()), **data.non_tensor_batch}
Expand Down Expand Up @@ -484,8 +486,8 @@ def update_policy(self, data: DataProto):
kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
metrics["actor/kl_loss"] = kl_loss.detach().item()
metrics["actor/kl_coef"] = self.config.kl_loss_coef
micro_batch_metrics["actor/kl_loss"] = kl_loss.detach().item()
micro_batch_metrics["actor/kl_coef"] = self.config.kl_loss_coef

if self.config.use_dynamic_bsz:
# relative to the dynamic bsz
Expand All @@ -494,16 +496,18 @@ def update_policy(self, data: DataProto):
loss = policy_loss / self.gradient_accumulation
loss.backward()

data = {
"actor/pg_loss": pg_loss.detach().item(),
"actor/pg_clipfrac": pg_clipfrac.detach().item(),
"actor/ppo_kl": ppo_kl.detach().item(),
"actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
}
append_to_dict(metrics, data)
micro_batch_metrics.update(
{
"actor/pg_loss": pg_loss.detach().item(),
"actor/pg_clipfrac": pg_clipfrac.detach().item(),
"actor/ppo_kl": ppo_kl.detach().item(),
"actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
}
)
append_to_dict(metrics, micro_batch_metrics)

grad_norm = self._optimizer_step()
data = {"actor/grad_norm": grad_norm.detach().item()}
append_to_dict(metrics, data)
mini_batch_metrics = {"actor/grad_norm": grad_norm.detach().item()}
append_to_dict(metrics, mini_batch_metrics)
self.actor_optimizer.zero_grad()
return metrics
20 changes: 12 additions & 8 deletions verl/workers/critic/dp_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ def update_critic(self, data: DataProto):
self.critic_optimizer.zero_grad()

for data in micro_batches:
micro_batch_metrics = {}

# Support all devices
if isinstance(data, DataProto):
data = {**data.batch.to(get_device_id()), **data.non_tensor_batch}
Expand Down Expand Up @@ -254,16 +256,18 @@ def update_critic(self, data: DataProto):

loss.backward()

data = {
"critic/vf_loss": vf_loss.detach().item(),
"critic/vf_clipfrac": vf_clipfrac.detach().item(),
"critic/vpred_mean": masked_mean(vpreds, response_mask).detach().item(),
}
micro_batch_metrics.update(
{
"critic/vf_loss": vf_loss.detach().item(),
"critic/vf_clipfrac": vf_clipfrac.detach().item(),
"critic/vpred_mean": masked_mean(vpreds, response_mask).detach().item(),
}
)

append_to_dict(metrics, data)
append_to_dict(metrics, micro_batch_metrics)

grad_norm = self._optimizer_step()
data = {"critic/grad_norm": grad_norm.detach().item()}
append_to_dict(metrics, data)
mini_batch_metrics = {"critic/grad_norm": grad_norm.detach().item()}
append_to_dict(metrics, mini_batch_metrics)
self.critic_optimizer.zero_grad()
return metrics
2 changes: 1 addition & 1 deletion verl/workers/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,9 +1123,9 @@ def update_critic(self, data: DataProto):
estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
metrics["perf/mfu/critic"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size

self.critic_lr_scheduler.step()
lr = self.critic_lr_scheduler.get_last_lr()[0]
metrics["critic/lr"] = lr
self.critic_lr_scheduler.step()
Comment on lines -1126 to +1128
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also fix an issue where previous recorded lr is actually the next step's lr instead of the current step. as #1463


output = DataProto(batch=None, meta_info={"metrics": metrics})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
Expand Down
Loading