Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions verl/trainer/config/ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ critic:
kl_ctrl:
type: fixed
kl_coef: 0.001
loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode}
checkpoint:
contents: ['model', 'optimizer', 'extra'] # add 'hf_model' to also save the whole model in HF format; for now only the sharded model checkpoint is saved, to save space

Expand Down
1 change: 1 addition & 0 deletions verl/trainer/config/ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ critic:
shuffle: ${actor_rollout_ref.actor.shuffle}
grad_clip: 1.0
cliprange_value: 0.5
loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode}
checkpoint:
contents: ['model', 'optimizer', 'extra'] # add 'hf_model' to also save the whole model in HF format; for now only the sharded model checkpoint is saved, to save space

Expand Down
24 changes: 12 additions & 12 deletions verl/trainer/ppo/core_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def compute_policy_loss(
cliprange_low=None,
cliprange_high=None,
clip_ratio_c=3.0,
loss_agg_mode="token-mean",
loss_agg_mode: str = "token-mean",
):
"""Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122
Args:
Expand All @@ -380,11 +380,7 @@ def compute_policy_loss(
The higher clip range used in PPO.
clip_ratio_c: (float) default: 3.0
The lower bound of the ratio for dual-clip PPO, See https://arxiv.org/pdf/1912.09729
loss_agg_mode: (str) choices: "token-mean" /
"seq-mean-token-sum" /
"seq-mean-token-mean" /
"seq-mean-token-sum-norm" /
"token-mean" is the default behavior
loss_agg_mode: (str) see `agg_loss`

Returns:
pg_loss: `a scalar torch.Tensor`
Expand Down Expand Up @@ -421,8 +417,8 @@ def compute_policy_loss(
return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower


def compute_entropy_loss(logits, response_mask):
"""Compute Categorical entropy loss
def compute_entropy_loss(logits, response_mask, loss_agg_mode: str = "token-mean"):
"""Compute categorical entropy loss (For backward compatibility)

Args:
logits: `(torch.Tensor)`
Expand All @@ -435,12 +431,12 @@ def compute_entropy_loss(logits, response_mask):

"""
# compute entropy
entropy = verl_F.entropy_from_logits(logits) # (bs, response_len)
entropy_loss = verl_F.masked_mean(entropy, mask=response_mask)
token_entropy = verl_F.entropy_from_logits(logits) # (bs, response_len)
entropy_loss = agg_loss(loss_mat=token_entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
return entropy_loss


def compute_value_loss(vpreds, returns, values, response_mask, cliprange_value):
def compute_value_loss(vpreds: torch.Tensor, returns: torch.Tensor, values: torch.Tensor, response_mask: torch.Tensor, cliprange_value: float, loss_agg_mode: str = "token-mean"):
"""Compute the value loss. Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151

Args:
Expand All @@ -450,6 +446,9 @@ def compute_value_loss(vpreds, returns, values, response_mask, cliprange_value):
Old values of value head, shape (`batch_size`, `response_length`)
returns: (`torch.FloatTensor`):
Ground truth returns, shape (`batch_size`, `response_length`)
response_mask: `(torch.Tensor)`
Mask for tokens to calculate value function losses. # TODO: Rename to `state_mask`.
loss_agg_mode: (str) see `agg_loss`

Returns:
vf_loss: a scalar (`torch.FloatTensor`):
Expand All @@ -461,7 +460,8 @@ def compute_value_loss(vpreds, returns, values, response_mask, cliprange_value):
vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
vf_losses1 = (vpreds - returns) ** 2
vf_losses2 = (vpredclipped - returns) ** 2
vf_loss = 0.5 * verl_F.masked_mean(torch.max(vf_losses1, vf_losses2), response_mask)
clipped_vf_losses = torch.max(vf_losses1, vf_losses2)
vf_loss = agg_loss(loss_mat=clipped_vf_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask)
return vf_loss, vf_clipfrac

Expand Down
2 changes: 1 addition & 1 deletion verl/workers/actor/dp_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def update_policy(self, data: DataProto):
ref_log_prob = data["ref_log_prob"]
# compute kl loss
kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type)
kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode)
kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
metrics["actor/kl_loss"] = kl_loss.detach().item()
Expand Down
1 change: 1 addition & 0 deletions verl/workers/critic/dp_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def update_critic(self, data: DataProto):
returns=returns,
response_mask=response_mask,
cliprange_value=self.config.cliprange_value,
loss_agg_mode=self.config.loss_agg_mode,
)
if self.config.use_dynamic_bsz:
# relative to the dynamic bsz
Expand Down
1 change: 1 addition & 0 deletions verl/workers/critic/megatron_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def loss_func(output, data, meta_info):
returns=returns,
response_mask=response_mask,
cliprange_value=cliprange_value,
loss_agg_mode=self.config.loss_agg_mode,
)
stats = {
"critic/vf_loss": vf_loss.detach().item(),
Expand Down