diff --git a/recipe/dapo/dapo_ray_trainer.py b/recipe/dapo/dapo_ray_trainer.py
index 9d15c74c681..cea58308228 100644
--- a/recipe/dapo/dapo_ray_trainer.py
+++ b/recipe/dapo/dapo_ray_trainer.py
@@ -26,6 +26,7 @@
 from tqdm import tqdm
 
 from verl import DataProto
+from verl.trainer.ppo.core_algos import agg_loss
 from verl.trainer.ppo.metric_utils import (
     compute_data_metrics,
     compute_throughout_metrics,
@@ -220,6 +221,13 @@ def fit(self):
                 # recompute old_log_probs
                 with _timer("old_log_prob", timing_raw):
                     old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
+                    entropys = old_log_prob.batch["entropys"]
+                    response_masks = batch.batch["response_mask"]
+                    loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
+                    entropy_loss = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
+                    old_log_prob_metrics = {"actor/entropy_loss": entropy_loss.detach().item()}
+                    metrics.update(old_log_prob_metrics)
+                    old_log_prob.batch.pop("entropys")
                 batch = batch.union(old_log_prob)
 
                 if self.use_reference_policy:
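
For intuition, here is a minimal sketch of what the logged metric computes: a mask-aware aggregation of per-token entropies over response tokens. The helper `agg_loss_sketch` below is a hypothetical toy re-implementation for illustration only, not verl's actual `agg_loss` (which lives in `verl.trainer.ppo.core_algos` and supports additional aggregation modes); the mode names shown are assumptions based on common verl configurations.

```python
import torch

def agg_loss_sketch(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str) -> torch.Tensor:
    """Toy stand-in for verl's agg_loss, for illustration only.

    loss_mat:  (batch, seq_len) per-token values, e.g. entropies.
    loss_mask: (batch, seq_len) 0/1 mask selecting response tokens.
    """
    if loss_agg_mode == "token-mean":
        # Mean over all valid response tokens pooled across the batch.
        return (loss_mat * loss_mask).sum() / loss_mask.sum()
    elif loss_agg_mode == "seq-mean-token-mean":
        # Per-sequence token mean first, then mean over sequences.
        seq_mean = (loss_mat * loss_mask).sum(dim=-1) / loss_mask.sum(dim=-1)
        return seq_mean.mean()
    raise NotImplementedError(loss_agg_mode)

# Example: 2 responses of lengths 2 and 1, padded to seq_len 3.
entropys = torch.tensor([[0.5, 0.7, 0.0], [0.9, 0.0, 0.0]])
response_masks = torch.tensor([[1, 1, 0], [1, 0, 0]], dtype=torch.float32)
print(agg_loss_sketch(entropys, response_masks, "token-mean"))  # (0.5 + 0.7 + 0.9) / 3 = 0.7
```

The diff pops `entropys` from `old_log_prob` after logging, so only the log-probs themselves are merged into `batch` via `batch.union(old_log_prob)`; the entropy tensor is consumed purely as a monitoring metric (`actor/entropy_loss`) and never enters the training batch.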