def masked_sum(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False) -> torch.Tensor:
    """Return the sum of ``x`` weighted element-wise by ``loss_mask``.

    Args:
        x: Values to aggregate.
        loss_mask: Weights multiplied element-wise into ``x``. Positions whose
            weight is exactly zero contribute nothing, even when ``x`` holds
            ``inf`` or ``nan`` there.
        expand: When true, broadcast the scalar result to the shape of ``x``.

    Returns:
        A 0-dim tensor with the masked sum, or that value expanded to the
        shape of ``x`` when ``expand`` is set.
    """
    weighted = x * loss_mask
    # inf * 0 and nan * 0 evaluate to nan, so a non-finite value at a
    # masked-out position would otherwise poison the whole sum.
    weighted = torch.where(loss_mask != 0, weighted, torch.zeros_like(weighted))
    total = weighted.sum()
    return total.expand_as(x) if expand else total


def masked_mean(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False) -> torch.Tensor:
    """Return the mean of ``x`` over the positions selected by ``loss_mask``.

    The denominator is ``loss_mask.sum()`` clamped to at least 1, so a fully
    masked input yields 0 instead of dividing by zero.

    Args:
        x: Values to aggregate.
        loss_mask: Weights multiplied element-wise into ``x``; see
            ``masked_sum`` for how zero-weight positions are treated.
        expand: When true, broadcast the scalar result to the shape of ``x``.

    Returns:
        A 0-dim tensor with the masked mean, or that value expanded to the
        shape of ``x`` when ``expand`` is set.
    """
    denominator = torch.clamp_min(loss_mask.sum(), 1)
    mean = masked_sum(x, loss_mask) / denominator
    return mean.expand_as(x) if expand else mean
def add_ppl_metrics(
    train_log_prob: torch.Tensor,
    rollout_log_prob: torch.Tensor,
    loss_mask: torch.Tensor,
    metrics: Dict[str, list[torch.Tensor]],
):
    """Append perplexity and train/rollout divergence diagnostics to ``metrics``.

    Every value is recorded through ``metrics_append``. The mean-based
    quantities come from ``masked_mean(..., expand=True)`` and therefore have
    the same shape as the log-prob inputs; the ``kl`` and ``k3_kl`` entries are
    element-wise differences that are appended without applying ``loss_mask``
    here.

    Args:
        train_log_prob: Log-probabilities under the training policy.
        rollout_log_prob: Log-probabilities under the rollout policy.
        loss_mask: Mask selecting the positions that enter the means; it is
            cast to float before use.
        metrics: Mapping from metric name to the list of recorded tensors;
            updated via ``metrics_append``.
    """

    def record(key: str, value: torch.Tensor) -> None:
        # Thin wrapper so each metric below is a single key/value line.
        metrics_append(metrics, key, value)

    mask = loss_mask.float()

    # 1. Perplexity under the training policy: ppl = exp(-mean log-prob).
    train_mean_lp = masked_mean(train_log_prob, mask, expand=True)
    train_log_ppl = -train_mean_lp
    train_ppl = torch.exp(train_log_ppl)
    record("training_log_ppl", train_log_ppl)
    record("training_ppl", train_ppl)

    # 2. Perplexity under the rollout policy, computed the same way.
    rollout_mean_lp = masked_mean(rollout_log_prob, mask, expand=True)
    rollout_log_ppl = -rollout_mean_lp
    rollout_ppl = torch.exp(rollout_log_ppl)
    record("rollout_log_ppl", rollout_log_ppl)
    record("rollout_ppl", rollout_ppl)

    # 3a. Direct estimator of KL(pi_rollout || pi_training):
    #     E[log(pi_rollout) - log(pi_training)].
    #     A positive value means the rollout policy assigned the higher
    #     log-probability.
    rollout_minus_train = rollout_log_prob - train_log_prob
    record("kl", rollout_minus_train)

    # 3b. K3 estimator, E[r - log(r) - 1] with r = pi_training / pi_rollout.
    #     It is non-negative element-wise and more stable for small KL values.
    train_minus_rollout = train_log_prob - rollout_log_prob
    k3_values = torch.exp(train_minus_rollout) - train_minus_rollout - 1
    record("k3_kl", k3_values)

    # 3c. Difference of the mean log-probs. Because log_ppl = -mean_log_prob,
    #     this equals log(training_ppl / rollout_ppl); a positive value means
    #     the training policy has the higher perplexity.
    #     The same tensor is stored under the *_max / *_min keys; presumably
    #     the consumer of ``metrics`` reduces those differently -- confirm.
    ppl_gap = rollout_mean_lp - train_mean_lp
    record("log_ppl_diff", ppl_gap)
    record("log_ppl_abs_diff", ppl_gap.abs())
    record("log_ppl_diff_max", ppl_gap)
    record("log_ppl_diff_min", ppl_gap)

    # 3d. Perplexity ratio training_ppl / rollout_ppl, obtained from the log
    #     difference rather than by dividing the two perplexities.
    record("ppl_ratio", torch.exp(ppl_gap))