diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml
index 5fec2a8d014..750797f442b 100644
--- a/verl/trainer/config/ppo_megatron_trainer.yaml
+++ b/verl/trainer/config/ppo_megatron_trainer.yaml
@@ -199,6 +199,7 @@ critic:
   kl_ctrl:
     type: fixed
     kl_coef: 0.001
+  loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode}
   checkpoint:
     contents: ['model', 'optimizer', 'extra']  # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml
index 1b6668dce8a..cec8dc5b864 100644
--- a/verl/trainer/config/ppo_trainer.yaml
+++ b/verl/trainer/config/ppo_trainer.yaml
@@ -164,6 +164,7 @@ critic:
   shuffle: ${actor_rollout_ref.actor.shuffle}
   grad_clip: 1.0
   cliprange_value: 0.5
+  loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode}
   checkpoint:
     contents: ['model', 'optimizer', 'extra']  # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py
index a0e7cd8af48..d449a031816 100644
--- a/verl/trainer/ppo/core_algos.py
+++ b/verl/trainer/ppo/core_algos.py
@@ -360,7 +360,7 @@ def compute_policy_loss(
     cliprange_low=None,
     cliprange_high=None,
     clip_ratio_c=3.0,
-    loss_agg_mode="token-mean",
+    loss_agg_mode: str = "token-mean",
 ):
     """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122
     Args:
@@ -380,11 +380,7 @@
             The higher clip range used in PPO.
         clip_ratio_c: (float) default: 3.0
             The lower bound of the ratio for dual-clip PPO, See https://arxiv.org/pdf/1912.09729
-        loss_agg_mode: (str) choices: "token-mean" /
-                       "seq-mean-token-sum" /
-                       "seq-mean-token-mean" /
-                       "seq-mean-token-sum-norm" /
-                       "token-mean" is the default behavior
+        loss_agg_mode: (str) see `agg_loss`

     Returns:
         pg_loss: `a scalar torch.Tensor`
@@ -421,8 +417,8 @@
     return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower


-def compute_entropy_loss(logits, response_mask):
-    """Compute Categorical entropy loss
+def compute_entropy_loss(logits, response_mask, loss_agg_mode: str = "token-mean"):
+    """Compute categorical entropy loss (For backward compatibility)

     Args:
         logits: `(torch.Tensor)`
@@ -435,12 +431,12 @@ def compute_entropy_loss(logits, response_mask):

     """
     # compute entropy
-    entropy = verl_F.entropy_from_logits(logits)  # (bs, response_len)
-    entropy_loss = verl_F.masked_mean(entropy, mask=response_mask)
+    token_entropy = verl_F.entropy_from_logits(logits)  # (bs, response_len)
+    entropy_loss = agg_loss(loss_mat=token_entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
     return entropy_loss


-def compute_value_loss(vpreds, returns, values, response_mask, cliprange_value):
+def compute_value_loss(vpreds: torch.Tensor, returns: torch.Tensor, values: torch.Tensor, response_mask: torch.Tensor, cliprange_value: float, loss_agg_mode: str = "token-mean"):
     """Compute the value loss. Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151

     Args:
@@ -450,6 +446,9 @@ def compute_value_loss(vpreds, returns, values, response_mask, cliprange_value):
             Old values of value head, shape (`batch_size`, `response_length`)
         returns: (`torch.FloatTensor`):
             Ground truth returns, shape (`batch_size`, `response_length`)
+        response_mask: `(torch.Tensor)`
+            Mask for tokens to calculate value function losses. # TODO: Rename to `state_mask`.
+        loss_agg_mode: (str) see `agg_loss`

     Returns:
         vf_loss: a scalar (`torch.FloatTensor`):
@@ -461,7 +460,8 @@
     vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
     vf_losses1 = (vpreds - returns) ** 2
     vf_losses2 = (vpredclipped - returns) ** 2
-    vf_loss = 0.5 * verl_F.masked_mean(torch.max(vf_losses1, vf_losses2), response_mask)
+    clipped_vf_losses = torch.max(vf_losses1, vf_losses2)
+    vf_loss = 0.5 * agg_loss(loss_mat=clipped_vf_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
     vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask)
     return vf_loss, vf_clipfrac
diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py
index 3b26d81fe6d..f8c6d11c7d6 100644
--- a/verl/workers/actor/dp_actor.py
+++ b/verl/workers/actor/dp_actor.py
@@ -334,7 +334,7 @@ def update_policy(self, data: DataProto):
                     ref_log_prob = data["ref_log_prob"]
                     # compute kl loss
                     kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type)
-                    kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode)
+                    kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
                     policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
                     metrics["actor/kl_loss"] = kl_loss.detach().item()
diff --git a/verl/workers/critic/dp_critic.py b/verl/workers/critic/dp_critic.py
index f87dc3a9cf4..0107ff34d60 100644
--- a/verl/workers/critic/dp_critic.py
+++ b/verl/workers/critic/dp_critic.py
@@ -226,6 +226,7 @@ def update_critic(self, data: DataProto):
                     returns=returns,
                     response_mask=response_mask,
                     cliprange_value=self.config.cliprange_value,
+                    loss_agg_mode=self.config.loss_agg_mode,
                 )
                 if self.config.use_dynamic_bsz:
                     # relative to the dynamic bsz
diff --git a/verl/workers/critic/megatron_critic.py b/verl/workers/critic/megatron_critic.py
index 68c1d51e889..42419054741 100644
--- a/verl/workers/critic/megatron_critic.py
+++ b/verl/workers/critic/megatron_critic.py
@@ -163,6 +163,7 @@ def loss_func(output, data, meta_info):
                 returns=returns,
                 response_mask=response_mask,
                 cliprange_value=cliprange_value,
+                loss_agg_mode=self.config.loss_agg_mode,
             )
             stats = {
                 "critic/vf_loss": vf_loss.detach().item(),
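
The `loss_agg_mode` value threaded through the critic above selects how `agg_loss` in verl/trainer/ppo/core_algos.py reduces a token-level loss matrix to a scalar. As a reading aid (not part of the diff), the sketch below approximates what the four modes named in the removed docstring compute; `agg_loss_sketch` is a hypothetical stand-in for illustration, not verl's exact implementation.

# Sketch only: approximates the aggregation modes named in the old docstring
# ("token-mean", "seq-mean-token-sum", "seq-mean-token-mean", "seq-mean-token-sum-norm").
# See verl/trainer/ppo/core_algos.py::agg_loss for the real implementation.
import torch


def agg_loss_sketch(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str) -> torch.Tensor:
    """Aggregate a (batch, response_len) loss matrix into a scalar under the given mode."""
    if loss_agg_mode == "token-mean":
        # mean over all valid tokens in the batch; long responses contribute more tokens
        return (loss_mat * loss_mask).sum() / loss_mask.sum()
    if loss_agg_mode == "seq-mean-token-sum":
        # sum per sequence, then mean over sequences
        return (loss_mat * loss_mask).sum(dim=-1).mean()
    if loss_agg_mode == "seq-mean-token-mean":
        # mean per sequence (each sequence weighted equally), then mean over sequences
        return ((loss_mat * loss_mask).sum(dim=-1) / loss_mask.sum(dim=-1)).mean()
    if loss_agg_mode == "seq-mean-token-sum-norm":
        # sum over all valid tokens, normalized by the fixed mask width rather than the token count
        return (loss_mat * loss_mask).sum() / loss_mask.shape[-1]
    raise ValueError(f"Unknown loss_agg_mode: {loss_agg_mode}")


if __name__ == "__main__":
    # toy batch: one 3-token response and one 1-token response
    loss = torch.tensor([[1.0, 2.0, 3.0], [4.0, 0.0, 0.0]])
    mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]])
    for mode in ("token-mean", "seq-mean-token-sum", "seq-mean-token-mean", "seq-mean-token-sum-norm"):
        print(mode, agg_loss_sketch(loss, mask, mode).item())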