Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 30 additions & 21 deletions examples/train_infer_mismatch_helper/mis.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,15 @@ def mask(
metrics: Dict[str, list[torch.Tensor]],
lower_bound: float,
upper_bound: float,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
assert lower_bound is not None and upper_bound is not None and lower_bound < upper_bound
metrics_append(metrics, "mask_fraction_low", (weights < lower_bound).int())
metrics_append(metrics, "mask_fraction_high", (weights > upper_bound).int())
mask = (weights >= lower_bound) & (weights <= upper_bound)
return weights * mask * loss_mask
in_range = (weights >= lower_bound) & (weights <= upper_bound)
modified_mask = loss_mask * in_range.float()
# Zero out padding in weights but preserve values at non-rejected positions
weights = weights * loss_mask
return weights, modified_mask


def compute_mis_weights(
Expand All @@ -115,7 +118,7 @@ def compute_mis_weights(
train_log_probs: list[torch.Tensor],
rollout_log_probs: list[torch.Tensor],
loss_masks: list[torch.Tensor],
) -> Tuple[list[torch.Tensor], Dict[str, list[torch.Tensor]]]:
) -> Tuple[list[torch.Tensor], list[torch.Tensor], Dict[str, list[torch.Tensor]]]:
"""
Compute the importance sampling (IS) weights and metrics between the inference and training engine.
Args:
Expand All @@ -126,7 +129,8 @@ def compute_mis_weights(
For multi-turn RL, the tool response will be marked as 0 in the loss_mask.

Returns:
weights: List of importance sampling weights. 1D tensor each.
weights: List of importance sampling weights (safety-bounded; zeroed at padding only). 1D tensor each.
modified_response_masks: List of rejection masks to apply in aggregation (mask mode + veto). 1D tensor each.
metrics: The metrics for the importance sampling weights, a dict of list[torch.Tensor]. 1D tensor each.
"""

Expand All @@ -148,6 +152,7 @@ def compute_mis_weights(

SAFETY_BOUND = 20.0 # Add a safety bound to avoid exp overflow
all_weights = []
all_modified_masks = []

# handle each sequence independently
for train_log_prob, rollout_log_prob, loss_mask in zip(train_log_probs, rollout_log_probs, loss_masks):
Expand All @@ -172,19 +177,17 @@ def compute_mis_weights(
weights = torch.exp(log_ratio_safe)
metrics_append(metrics, "mean_is_weight_before_clip", weights)

# mask out catastrophic tokens
if args.mis_veto_threshold is not None:
veto_mask = calculate_veto_mask(raw_log_ratio_diff, loss_mask, args.mis_veto_threshold, metrics)
modified_mask = loss_mask.clone().float()

# mode: how to handle the importance sampling weights exceeding the thresholds.
if args.mis_mode == "truncate":
# Cap the importance sampling weights at the upper threshold
# https://fengyao.notion.site/off-policy-rl#279721e3f6c48092bbe2fcfe0e9c6b33
weights = truncate(weights, loss_mask, metrics, args.mis_upper_bound)
elif args.mis_mode == "mask":
# Zero the importance sampling weights outside the [lower, upper] range.
# Preserve safety-bounded weights; apply thresholds via modified_mask
# https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda
weights = mask(
weights, modified_mask = mask(
weights,
loss_mask,
metrics,
Expand All @@ -204,15 +207,20 @@ def compute_mis_weights(
else:
raise ValueError(f"Unsupported mis_mode: {args.mis_mode}")

metrics_append(metrics, "ratio_mean_after_mis", weights)
# Veto on raw per-token ratios (sequence-wise rejection)
# Works independently of truncate/mask mode and does NOT modify IS weights
if args.mis_veto_threshold is not None:
weights = weights * veto_mask
metrics_append(metrics, "ratio_mean_after_veto_mask", weights)
veto_mask = calculate_veto_mask(raw_log_ratio_diff, loss_mask, args.mis_veto_threshold, metrics)
modified_mask = modified_mask * veto_mask

metrics_append(metrics, "ratio_mean_after_mis", weights)

weights = weights.detach()
modified_mask = modified_mask.detach()
all_weights.append(weights)
all_modified_masks.append(modified_mask)

return all_weights, metrics
return all_weights, all_modified_masks, metrics


def compute_mis_weights_with_cp(
Expand All @@ -225,7 +233,7 @@ def compute_mis_weights_with_cp(
total_lengths: list[int],
response_lengths: list[int],
**kwargs: Any,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
) -> Tuple[torch.Tensor, list[torch.Tensor], Dict[str, torch.Tensor]]:
"""
Compute the importance sampling (IS) weights and metrics with context parallel.
Args:
Expand All @@ -235,9 +243,9 @@ def compute_mis_weights_with_cp(
total_lengths: List of total lengths.
response_lengths: List of response lengths.
Returns:
is_weights: Importance sampling weights on this CP rank and flattened along dim=0.
is_metrics: The metrics for the importance sampling weights, a dict of list[torch.Tensor]. 1D tensor each.
Also flattened along dim=0.
pg_loss: Policy gradient loss with IS weights applied (flattened along dim=0).
modified_masks: List of modified response masks with rejection applied (one per sequence).
is_metrics: The metrics for the importance sampling weights, a dict of flattened tensors.
"""
# Gather cp slice from other cp ranks
full_rollout_log_probs = [
Expand All @@ -249,8 +257,8 @@ def compute_mis_weights_with_cp(
for old_log_prob, total_length, response_length in zip(train_log_probs, total_lengths, response_lengths)
]

# Main logic for is
is_weights, is_metrics = compute_mis_weights(
# Main logic for is (decoupled)
is_weights, modified_masks, is_metrics = compute_mis_weights(
args=args,
train_log_probs=full_old_log_probs,
rollout_log_probs=full_rollout_log_probs,
Expand All @@ -270,14 +278,15 @@ def slice_cp_and_concat(

result_metrics = {}
is_weights = slice_cp_and_concat(is_weights, total_lengths, response_lengths)

for key, values in is_metrics.items():
key_name = f"mis_{key}"
values = slice_cp_and_concat(values, total_lengths, response_lengths)
result_metrics[key_name] = values

pg_loss = pg_loss * is_weights

return pg_loss, result_metrics
return pg_loss, modified_masks, result_metrics


def add_ppl_metrics(
Expand Down
4 changes: 2 additions & 2 deletions examples/train_infer_mismatch_helper/mis.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Enable importance sampling, details refer to the comments of compute_mis_weights in mis.py
use_mis: false
use_tis: true
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we change the mis -> tis here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that "use_mis" is not used anywhere, so perhaps we can simply delete it?


# Aggregation level for importance sampling weights:
# token: per-token
Expand All @@ -11,7 +11,7 @@ mis_level: "token"
# truncate: cap to upper bound, TIS
# mask: zero outside [lower, upper], MIS
# clip: clip to [lower, upper], CIS
mis_mode: "truncate"
mis_mode: "mask"

# For mask or clip mode, the lower bound of the IS weights.
# For truncate mode, it will not be used.
Expand Down
13 changes: 10 additions & 3 deletions slime/backends/megatron_utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,9 @@ def vanilla_tis_function(
pg_loss: torch.Tensor,
train_log_probs: list[torch.Tensor],
rollout_log_probs: list[torch.Tensor],
loss_masks: list[torch.Tensor],
**kwargs: Any,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
) -> Tuple[torch.Tensor, list[torch.Tensor], Dict[str, torch.Tensor]]:
rollout_log_probs = torch.cat(rollout_log_probs, dim=0)
old_log_probs = torch.cat(train_log_probs, dim=0)
tis = torch.exp(old_log_probs - rollout_log_probs)
Expand All @@ -441,7 +442,7 @@ def vanilla_tis_function(
"tis_abs": tis_abs.clone().detach(),
}
pg_loss = pg_loss * tis_weights
return pg_loss, metrics
return pg_loss, loss_masks, metrics

assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS"

Expand All @@ -460,7 +461,13 @@ def vanilla_tis_function(
tis_func = load_function(args.custom_tis_function_path)
else:
tis_func = vanilla_tis_function
pg_loss, tis_metrics = tis_func(**tis_kwargs)
pg_loss, modified_response_masks, tis_metrics = tis_func(**tis_kwargs)

# [decouple IS and rejection] Rebuild sum_of_sample_mean with modified_response_masks for denominator correction
# modified_response_masks will be sliced with cp in get_sum_of_sample_mean
sum_of_sample_mean = get_sum_of_sample_mean(
total_lengths, response_lengths, modified_response_masks, args.calculate_per_token_loss
)

pg_loss = sum_of_sample_mean(pg_loss)
pg_clipfrac = sum_of_sample_mean(pg_clipfrac)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If use_tis is not set, pg_loss relies on the passed-in sum_of_sample_mean. If use_tis is set, then we create a new sum_of_sample_mean from modified_response_masks via:

        sum_of_sample_mean = get_sum_of_sample_mean(
            total_lengths, response_lengths, modified_response_masks, args.calculate_per_token_loss
        )

Right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. If we don't use TIS, we do not rebuild the sum_of_sample_mean function; it stays the one that was originally created from loss_mask.

Expand Down