Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 68 additions & 7 deletions slime/utils/train_infer_is.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@
from slime.backends.megatron_utils.cp_utils import all_gather_with_cp, slice_log_prob_with_cp


def masked_sum(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False) -> torch.Tensor:
    """Sum ``x`` over the positions selected by ``loss_mask``.

    Args:
        x: per-token values.
        loss_mask: float mask (1.0 = counted, 0.0 = ignored), same shape as ``x``.
        expand: when True, broadcast the scalar sum back to ``x``'s shape.

    Returns:
        A 0-dim tensor with the masked sum, or that value expanded to ``x``'s
        shape when ``expand`` is set.
    """
    total = torch.sum(x * loss_mask)
    if not expand:
        return total
    return total.expand_as(x)


def masked_mean(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False) -> torch.Tensor:
    """Mean of ``x`` over the positions selected by ``loss_mask``.

    The denominator is clamped to at least 1 so an all-zero mask yields 0
    instead of dividing by zero.

    Args:
        x: per-token values.
        loss_mask: float mask (1.0 = counted, 0.0 = ignored), same shape as ``x``.
        expand: when True, broadcast the scalar mean back to ``x``'s shape.

    Returns:
        A 0-dim tensor with the masked mean, or that value expanded to ``x``'s
        shape when ``expand`` is set.
    """
    denominator = torch.clamp_min(loss_mask.sum(), 1)
    mean_value = (x * loss_mask).sum() / denominator
    if not expand:
        return mean_value
    return mean_value.expand_as(x)


def metrics_append(metrics: Dict[str, list[torch.Tensor]], key: str, value: torch.Tensor) -> None:
"""

Expand Down Expand Up @@ -138,33 +148,31 @@ def compute_train_infer_is_weights(

# handle each sequence independently
for train_log_prob, rollout_log_prob, loss_mask in zip(train_log_probs, rollout_log_probs, loss_masks):
raw_log_ratio = train_log_prob - rollout_log_prob
loss_mask = loss_mask.float()
add_ppl_metrics(train_log_prob, rollout_log_prob, loss_mask, metrics)
raw_log_ratio = train_log_prob - rollout_log_prob
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change the naming of raw_log_ratio, maybe to raw_log_ratio_diff


# level: The aggregation level for the importance sampling weights.
if level == "token":
# Per-token ratio (biased)
log_ratio_for_metrics = raw_log_ratio
elif level == "sequence":
# Product of ratios (unbiased but high variance)
agg_log_ratio = (raw_log_ratio * loss_mask).sum()
log_ratio_for_metrics = agg_log_ratio.expand_as(raw_log_ratio)
log_ratio_for_metrics = masked_sum(raw_log_ratio, loss_mask, expand=True)
elif level == "geometric":
# Geometric mean of ratios (biased but low variance)
agg_log_ratio = (raw_log_ratio * loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1)
log_ratio_for_metrics = agg_log_ratio.expand_as(raw_log_ratio)
log_ratio_for_metrics = masked_mean(raw_log_ratio, loss_mask, expand=True)
else:
raise ValueError(f"Invalid importance sampling level: {level}")

log_ratio_safe = torch.clamp(log_ratio_for_metrics, min=-SAFETY_BOUND, max=SAFETY_BOUND)
weights = torch.exp(log_ratio_safe)
metrics_append(metrics, "ratio_mean_before_tis", weights)
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mean_is_weight_before_clip


# mask out catastrophic tokens
if args.train_infer_is_veto_threshold is not None:
veto_mask = calculate_veto_mask(raw_log_ratio, loss_mask, args.train_infer_is_veto_threshold, metrics)

metrics_append(metrics, "raw_ratio_mean", weights)

# mode: how to handle the importance sampling weights exceeding the thresholds.
if args.train_infer_is_mode == "truncate":
# Cap the importance sampling weights at the upper threshold
Expand Down Expand Up @@ -261,3 +269,56 @@ def slice_cp_and_concat(
is_metrics[key] = values

return is_weights, is_metrics


def add_ppl_metrics(
    train_log_prob: torch.Tensor,
    rollout_log_prob: torch.Tensor,
    loss_mask: torch.Tensor,
    metrics: Dict[str, list[torch.Tensor]],
) -> None:
    """Record per-sequence perplexity and KL metrics for train vs. rollout policy.

    Compares the training policy's per-token log-probs against the rollout
    (inference) policy's log-probs over the tokens selected by ``loss_mask``,
    appending each metric into ``metrics`` via ``metrics_append``.

    Args:
        train_log_prob: per-token log-probs under the training policy.
        rollout_log_prob: per-token log-probs under the rollout policy,
            same shape as ``train_log_prob``.
        loss_mask: 1/0 token mask; any numeric dtype is accepted (converted
            to float below), so callers need not pre-convert.
        metrics: accumulator mapping metric name -> list of tensors.
    """
    # Defensive, idempotent conversion: the caller in
    # compute_train_infer_is_weights already passes a float mask, but keeping
    # this makes the function safe to call directly with an int/bool mask.
    loss_mask = loss_mask.float()

    # 1. Training policy perplexity: ppl = exp(-mean masked log-prob).
    mean_log_prob_training = masked_mean(train_log_prob, loss_mask, expand=True)
    training_log_ppl = -mean_log_prob_training
    metrics_append(metrics, "training_log_ppl", training_log_ppl)
    metrics_append(metrics, "training_ppl", torch.exp(training_log_ppl))

    # 2. Rollout policy perplexity, computed the same way.
    mean_log_prob_rollout = masked_mean(rollout_log_prob, loss_mask, expand=True)
    rollout_log_ppl = -mean_log_prob_rollout
    metrics_append(metrics, "rollout_log_ppl", rollout_log_ppl)
    metrics_append(metrics, "rollout_ppl", torch.exp(rollout_log_ppl))

    # 3a. Direct per-token estimator of KL(pi_rollout || pi_training):
    # E[log(pi_rollout) - log(pi_training)]. Positive means the rollout
    # policy is more confident than the training policy.
    kl_per_token = rollout_log_prob - train_log_prob
    metrics_append(metrics, "kl", kl_per_token)

    # 3b. K3 KL estimator: KL ~= E[r - log(r) - 1] with
    # r = pi_training / pi_rollout. More stable than 3a for small KL values.
    log_ratio = train_log_prob - rollout_log_prob
    k3_kl_matrix = torch.exp(log_ratio) - log_ratio - 1
    metrics_append(metrics, "k3_kl", k3_kl_matrix)

    # 3c. Sequence-level log-perplexity difference:
    # log_ppl_diff = log(training_ppl / rollout_ppl). Positive means training
    # assigns lower probability (higher PPL) than rollout.
    log_ppl_diff = mean_log_prob_rollout - mean_log_prob_training
    metrics_append(metrics, "log_ppl_diff", log_ppl_diff)
    metrics_append(metrics, "log_ppl_abs_diff", log_ppl_diff.abs())
    # NOTE(review): the former "log_ppl_diff_max"/"log_ppl_diff_min" entries
    # appended the same tensor as "log_ppl_diff", so the mean-based metric
    # aggregation produced three identical values instead of a batch max/min.
    # They are removed until a real cross-sequence max/min reduction exists.

    # 3d. PPL ratio (training_ppl / rollout_ppl), computed in log space for
    # numerical stability: ppl_ratio = exp(log_ppl_diff).
    ppl_ratio = torch.exp(log_ppl_diff)
    metrics_append(metrics, "ppl_ratio", ppl_ratio)