diff --git a/examples/train_infer_mismatch_helper/mis.py b/examples/train_infer_mismatch_helper/mis.py index fc8c80424f..26448f8e44 100644 --- a/examples/train_infer_mismatch_helper/mis.py +++ b/examples/train_infer_mismatch_helper/mis.py @@ -101,12 +101,15 @@ def mask( metrics: Dict[str, list[torch.Tensor]], lower_bound: float, upper_bound: float, -) -> torch.Tensor: +) -> Tuple[torch.Tensor, torch.Tensor]: assert lower_bound is not None and upper_bound is not None and lower_bound < upper_bound metrics_append(metrics, "mask_fraction_low", (weights < lower_bound).int()) metrics_append(metrics, "mask_fraction_high", (weights > upper_bound).int()) - mask = (weights >= lower_bound) & (weights <= upper_bound) - return weights * mask * loss_mask + in_range = (weights >= lower_bound) & (weights <= upper_bound) + modified_mask = loss_mask * in_range.float() + # Zero weights where loss_mask is 0; out-of-range tokens keep their weights and are rejected via modified_mask + weights = weights * loss_mask + return weights, modified_mask def compute_mis_weights( @@ -115,7 +118,7 @@ def compute_mis_weights( train_log_probs: list[torch.Tensor], rollout_log_probs: list[torch.Tensor], loss_masks: list[torch.Tensor], -) -> Tuple[list[torch.Tensor], Dict[str, list[torch.Tensor]]]: +) -> Tuple[list[torch.Tensor], list[torch.Tensor], Dict[str, list[torch.Tensor]]]: """ Compute the importance sampling (IS) weights and metrics between the inference and training engine. Args: @@ -126,7 +129,8 @@ def compute_mis_weights( For multi-turn RL, the tool response will be marked as 0 in the loss_mask. Returns: - weights: List of importance sampling weights. 1D tensor each. + weights: List of importance sampling weights (safety-bounded; zeroed at padding only). 1D tensor each. + modified_response_masks: List of rejection masks to apply in aggregation (mask mode + veto). 1D tensor each. metrics: The metrics for the importance sampling weights, a dict of list[torch.Tensor]. 1D tensor each. 
""" @@ -148,6 +152,7 @@ def compute_mis_weights( SAFETY_BOUND = 20.0 # Add a safety bound to avoid exp overflow all_weights = [] + all_modified_masks = [] # handle each sequence independently for train_log_prob, rollout_log_prob, loss_mask in zip(train_log_probs, rollout_log_probs, loss_masks): @@ -172,9 +177,7 @@ def compute_mis_weights( weights = torch.exp(log_ratio_safe) metrics_append(metrics, "mean_is_weight_before_clip", weights) - # mask out catastrophic tokens - if args.mis_veto_threshold is not None: - veto_mask = calculate_veto_mask(raw_log_ratio_diff, loss_mask, args.mis_veto_threshold, metrics) + modified_mask = loss_mask.clone().float() # mode: how to handle the importance sampling weights exceeding the thresholds. if args.mis_mode == "truncate": @@ -182,9 +185,9 @@ def compute_mis_weights( # https://fengyao.notion.site/off-policy-rl#279721e3f6c48092bbe2fcfe0e9c6b33 weights = truncate(weights, loss_mask, metrics, args.mis_upper_bound) elif args.mis_mode == "mask": - # Zero the importance sampling weights outside the [lower, upper] range. 
+ # Preserve safety-bounded weights; apply thresholds via modified_mask # https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda - weights = mask( + weights, modified_mask = mask( weights, loss_mask, metrics, @@ -204,15 +207,20 @@ def compute_mis_weights( else: raise ValueError(f"Unsupported mis_mode: {args.mis_mode}") - metrics_append(metrics, "ratio_mean_after_mis", weights) + # Veto on raw per-token ratios (sequence-wise rejection) + # Works independently of truncate/mask mode and does NOT modify IS weights if args.mis_veto_threshold is not None: - weights = weights * veto_mask - metrics_append(metrics, "ratio_mean_after_veto_mask", weights) + veto_mask = calculate_veto_mask(raw_log_ratio_diff, loss_mask, args.mis_veto_threshold, metrics) + modified_mask = modified_mask * veto_mask + + metrics_append(metrics, "ratio_mean_after_mis", weights) weights = weights.detach() + modified_mask = modified_mask.detach() all_weights.append(weights) + all_modified_masks.append(modified_mask) - return all_weights, metrics + return all_weights, all_modified_masks, metrics def compute_mis_weights_with_cp( @@ -225,7 +233,7 @@ def compute_mis_weights_with_cp( total_lengths: list[int], response_lengths: list[int], **kwargs: Any, -) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: +) -> Tuple[torch.Tensor, list[torch.Tensor], Dict[str, torch.Tensor]]: """ Compute the importance sampling (IS) weights and metrics with context parallel. Args: @@ -235,9 +243,9 @@ def compute_mis_weights_with_cp( total_lengths: List of total lengths. response_lengths: List of response lengths. Returns: - is_weights: Importance sampling weights on this CP rank and flattened along dim=0. - is_metrics: The metrics for the importance sampling weights, a dict of list[torch.Tensor]. 1D tensor each. - Also flattened along dim=0. + pg_loss: Policy gradient loss with IS weights applied (flattened along dim=0). 
+ modified_masks: List of response masks with rejection applied (one per sequence); returned as computed on the gathered sequences, without CP slicing. + is_metrics: The metrics for the importance sampling weights, a dict of flattened tensors. """ # Gather cp slice from other cp ranks full_rollout_log_probs = [ @@ -249,8 +257,8 @@ def compute_mis_weights_with_cp( for old_log_prob, total_length, response_length in zip(train_log_probs, total_lengths, response_lengths) ] - # Main logic for is - is_weights, is_metrics = compute_mis_weights( + # Main IS logic: IS weights and rejection masks are computed separately (decoupled) + is_weights, modified_masks, is_metrics = compute_mis_weights( args=args, train_log_probs=full_old_log_probs, rollout_log_probs=full_rollout_log_probs, @@ -270,6 +278,7 @@ def slice_cp_and_concat( result_metrics = {} is_weights = slice_cp_and_concat(is_weights, total_lengths, response_lengths) + for key, values in is_metrics.items(): key_name = f"mis_{key}" values = slice_cp_and_concat(values, total_lengths, response_lengths) @@ -277,7 +286,7 @@ def slice_cp_and_concat( pg_loss = pg_loss * is_weights - return pg_loss, result_metrics + return pg_loss, modified_masks, result_metrics def add_ppl_metrics( diff --git a/examples/train_infer_mismatch_helper/mis.yaml b/examples/train_infer_mismatch_helper/mis.yaml index 3a7fbbafc9..893f331277 100644 --- a/examples/train_infer_mismatch_helper/mis.yaml +++ b/examples/train_infer_mismatch_helper/mis.yaml @@ -1,5 +1,5 @@ # Enable importance sampling, details refer to the comments of compute_mis_weights in mis.py -use_mis: false +use_tis: true # Aggregation level for importance sampling weights: # token: per-token @@ -11,7 +11,7 @@ mis_level: "token" # truncate: cap to upper bound, TIS # mask: zero outside [lower, upper], MIS # clip: clip to [lower, upper], CIS -mis_mode: "truncate" +mis_mode: "mask" # For mask or clip mode, the lower bound of the IS weights. # For truncate mode, it will not be used. 
diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index a7c40d74ea..b80f887f4f 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -427,8 +427,9 @@ def vanilla_tis_function( pg_loss: torch.Tensor, train_log_probs: list[torch.Tensor], rollout_log_probs: list[torch.Tensor], + loss_masks: list[torch.Tensor], **kwargs: Any, - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + ) -> Tuple[torch.Tensor, list[torch.Tensor], Dict[str, torch.Tensor]]: rollout_log_probs = torch.cat(rollout_log_probs, dim=0) old_log_probs = torch.cat(train_log_probs, dim=0) tis = torch.exp(old_log_probs - rollout_log_probs) @@ -441,7 +442,7 @@ def vanilla_tis_function( "tis_abs": tis_abs.clone().detach(), } pg_loss = pg_loss * tis_weights - return pg_loss, metrics + return pg_loss, loss_masks, metrics assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS" @@ -460,7 +461,13 @@ def vanilla_tis_function( tis_func = load_function(args.custom_tis_function_path) else: tis_func = vanilla_tis_function - pg_loss, tis_metrics = tis_func(**tis_kwargs) + pg_loss, modified_response_masks, tis_metrics = tis_func(**tis_kwargs) + + # [decouple IS and rejection] Rebuild sum_of_sample_mean with modified_response_masks for denominator correction + # modified_response_masks will be sliced with cp in get_sum_of_sample_mean + sum_of_sample_mean = get_sum_of_sample_mean( + total_lengths, response_lengths, modified_response_masks, args.calculate_per_token_loss + ) pg_loss = sum_of_sample_mean(pg_loss) pg_clipfrac = sum_of_sample_mean(pg_clipfrac)