-
Notifications
You must be signed in to change notification settings - Fork 0
Further modify tis #3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -5,6 +5,16 @@ | |||||||||||||||||||||||||
| from slime.backends.megatron_utils.cp_utils import all_gather_with_cp, slice_log_prob_with_cp | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def masked_sum(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False) -> torch.Tensor: | ||||||||||||||||||||||||||
| result = (x * loss_mask).sum() | ||||||||||||||||||||||||||
| return result.expand_as(x) if expand else result | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def masked_mean(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False) -> torch.Tensor: | ||||||||||||||||||||||||||
| result = masked_sum(x, loss_mask) / torch.clamp_min(loss_mask.sum(), 1) | ||||||||||||||||||||||||||
| return result.expand_as(x) if expand else result | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def metrics_append(metrics: Dict[str, list[torch.Tensor]], key: str, value: torch.Tensor) -> None: | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
@@ -138,33 +148,31 @@ def compute_train_infer_is_weights( | |||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # handle each sequence independently | ||||||||||||||||||||||||||
| for train_log_prob, rollout_log_prob, loss_mask in zip(train_log_probs, rollout_log_probs, loss_masks): | ||||||||||||||||||||||||||
| raw_log_ratio = train_log_prob - rollout_log_prob | ||||||||||||||||||||||||||
| loss_mask = loss_mask.float() | ||||||||||||||||||||||||||
| add_ppl_metrics(train_log_prob, rollout_log_prob, loss_mask, metrics) | ||||||||||||||||||||||||||
| raw_log_ratio = train_log_prob - rollout_log_prob | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # level: The aggregation level for the importance sampling weights. | ||||||||||||||||||||||||||
| if level == "token": | ||||||||||||||||||||||||||
| # Per-token ratio (biased) | ||||||||||||||||||||||||||
| log_ratio_for_metrics = raw_log_ratio | ||||||||||||||||||||||||||
| elif level == "sequence": | ||||||||||||||||||||||||||
| # Product of ratios (unbiased but high variance) | ||||||||||||||||||||||||||
| agg_log_ratio = (raw_log_ratio * loss_mask).sum() | ||||||||||||||||||||||||||
| log_ratio_for_metrics = agg_log_ratio.expand_as(raw_log_ratio) | ||||||||||||||||||||||||||
| log_ratio_for_metrics = masked_sum(raw_log_ratio, loss_mask, expand=True) | ||||||||||||||||||||||||||
| elif level == "geometric": | ||||||||||||||||||||||||||
| # Geometric mean of ratios (biased but low variance) | ||||||||||||||||||||||||||
| agg_log_ratio = (raw_log_ratio * loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1) | ||||||||||||||||||||||||||
| log_ratio_for_metrics = agg_log_ratio.expand_as(raw_log_ratio) | ||||||||||||||||||||||||||
| log_ratio_for_metrics = masked_mean(raw_log_ratio, loss_mask, expand=True) | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| raise ValueError(f"Invalid importance sampling level: {level}") | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| log_ratio_safe = torch.clamp(log_ratio_for_metrics, min=-SAFETY_BOUND, max=SAFETY_BOUND) | ||||||||||||||||||||||||||
| weights = torch.exp(log_ratio_safe) | ||||||||||||||||||||||||||
| metrics_append(metrics, "ratio_mean_before_tis", weights) | ||||||||||||||||||||||||||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # mask out catastrophic tokens | ||||||||||||||||||||||||||
| if args.train_infer_is_veto_threshold is not None: | ||||||||||||||||||||||||||
| veto_mask = calculate_veto_mask(raw_log_ratio, loss_mask, args.train_infer_is_veto_threshold, metrics) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| metrics_append(metrics, "raw_ratio_mean", weights) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # mode: how to handle the importance sampling weights exceeding the thresholds. | ||||||||||||||||||||||||||
| if args.train_infer_is_mode == "truncate": | ||||||||||||||||||||||||||
| # Cap the importance sampling weights at the upper threshold | ||||||||||||||||||||||||||
|
|
@@ -261,3 +269,56 @@ def slice_cp_and_concat( | |||||||||||||||||||||||||
| is_metrics[key] = values | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| return is_weights, is_metrics | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def add_ppl_metrics( | ||||||||||||||||||||||||||
| train_log_prob: torch.Tensor, | ||||||||||||||||||||||||||
| rollout_log_prob: torch.Tensor, | ||||||||||||||||||||||||||
| loss_mask: torch.Tensor, | ||||||||||||||||||||||||||
| metrics: Dict[str, list[torch.Tensor]], | ||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||
|
Comment on lines
+274
to
+279
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The function
Suggested change
|
||||||||||||||||||||||||||
| loss_mask = loss_mask.float() | ||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # 1. Training policy perplexity metrics | ||||||||||||||||||||||||||
| mean_log_prob_training = masked_mean(train_log_prob, loss_mask, expand=True) | ||||||||||||||||||||||||||
| training_log_ppl = -mean_log_prob_training | ||||||||||||||||||||||||||
| training_ppl = torch.exp(training_log_ppl) | ||||||||||||||||||||||||||
| metrics_append(metrics, "training_log_ppl", training_log_ppl) | ||||||||||||||||||||||||||
| metrics_append(metrics, "training_ppl", training_ppl) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # 2. Rollout policy perplexity metrics | ||||||||||||||||||||||||||
| mean_log_prob_rollout = masked_mean(rollout_log_prob, loss_mask, expand=True) | ||||||||||||||||||||||||||
| rollout_log_ppl = -mean_log_prob_rollout | ||||||||||||||||||||||||||
| rollout_ppl = torch.exp(rollout_log_ppl) | ||||||||||||||||||||||||||
| metrics_append(metrics, "rollout_log_ppl", rollout_log_ppl) | ||||||||||||||||||||||||||
| metrics_append(metrics, "rollout_ppl", rollout_ppl) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # 3a. kl: Direct estimator for KL(π_rollout || π_training) | ||||||||||||||||||||||||||
| # This is the standard KL divergence: E[log(π_rollout) - log(π_training)] | ||||||||||||||||||||||||||
| # Positive value means rollout policy is more confident than training policy | ||||||||||||||||||||||||||
| kl_per_token = rollout_log_prob - train_log_prob | ||||||||||||||||||||||||||
| metrics_append(metrics, "kl", kl_per_token) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # 3b. K3 KL estimator for improved stability | ||||||||||||||||||||||||||
| # More stable for small KL values using: E[exp(log_ratio) - log_ratio - 1] | ||||||||||||||||||||||||||
| # Formula: KL ≈ E[r - log(r) - 1] where r = π_training/π_rollout | ||||||||||||||||||||||||||
| log_ratio = train_log_prob - rollout_log_prob | ||||||||||||||||||||||||||
| k3_kl_matrix = torch.exp(log_ratio) - log_ratio - 1 | ||||||||||||||||||||||||||
| metrics_append(metrics, "k3_kl", k3_kl_matrix) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # 3c. Log PPL difference (sequence-level perplexity difference) | ||||||||||||||||||||||||||
| # log_ppl_diff = mean_log_prob_rollout - mean_log_prob_training | ||||||||||||||||||||||||||
| # Since ppl = exp(-log_prob), we have: | ||||||||||||||||||||||||||
| # log(ppl_ratio) = log(training_ppl/rollout_ppl) = log_ppl_diff | ||||||||||||||||||||||||||
|
Comment on lines
+311
to
+312
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The metrics I recommend removing these two lines until a proper mechanism for computing and logging max/min values across the batch is implemented. |
||||||||||||||||||||||||||
| # Positive value means training assigns lower probability (higher PPL) than rollout | ||||||||||||||||||||||||||
| log_ppl_diff = mean_log_prob_rollout - mean_log_prob_training | ||||||||||||||||||||||||||
| metrics_append(metrics, "log_ppl_diff", log_ppl_diff) | ||||||||||||||||||||||||||
| metrics_append(metrics, "log_ppl_abs_diff", log_ppl_diff.abs()) | ||||||||||||||||||||||||||
| metrics_append(metrics, "log_ppl_diff_max", log_ppl_diff) | ||||||||||||||||||||||||||
| metrics_append(metrics, "log_ppl_diff_min", log_ppl_diff) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # 3d. PPL ratio (how much higher is training PPL vs rollout PPL) | ||||||||||||||||||||||||||
| # For numerical stability, compute in log space using log_ppl_diff | ||||||||||||||||||||||||||
| # Note: log_ppl_diff = log(ppl_ratio), so ppl_ratio = exp(log_ppl_diff) | ||||||||||||||||||||||||||
| ppl_ratio = torch.exp(log_ppl_diff) | ||||||||||||||||||||||||||
| metrics_append(metrics, "ppl_ratio", ppl_ratio) | ||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
change the naming of
raw_log_ratio, maybe toraw_log_ratio_diff