[FSDP] Add Masked importance sampling#1122
Conversation

Force-pushed commit "rewrite masked sum" (branch updated from 68bf817 to 51954ed).
Add masked importance sampling for FSDP backend (THUDM#1063).
@PopSoda2002 Hi, could you please review this? Thank you!
```diff
     calculates PPO-style clipped policy gradient loss. For GSPO, gathers
     full sequences via context-parallel all-gather before computing per-sample
-    KL. Optionally applies TIS (Temporal Importance Sampling) correction and
+    KL. Optionally applies TIS (Truncated Importance Sampling) correction and
```
PopSoda2002
left a comment
In my opinion, this PR should not be so large:
- Delete the script and test it locally.
- We can add an argument like `use-mis` and implement the MIS func.
- Do not change the code with a big diff, since it's just a small func; currently the code is harder to read and may introduce a higher potential for bugs.
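The suggested flag could look like the following minimal argparse sketch. Note that `--use-mis` is the reviewer's proposed name, not an existing slime argument:

```python
import argparse

# Hypothetical flag per the review suggestion; `--use-mis` is illustrative only.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--use-mis",
    action="store_true",
    help="Enable masked importance sampling (MIS) correction",
)

# argparse maps "--use-mis" to the attribute `use_mis`
args = parser.parse_args(["--use-mis"])
print(args.use_mis)  # True when the flag is passed, False otherwise
```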
slime/backends/fsdp_utils/actor.py (outdated)
```python
def _compute_tis_weights(
    self,
    old_log_probs: torch.Tensor,
    rollout_log_probs: torch.Tensor,
    loss_masks: list[torch.Tensor],
    response_lengths: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute importance sampling weights for TIS/MIS.

    Supports both token-level and sequence-level aggregation, and truncate/mask modes.
    """
    tis_mode = self.args.tis_mode if self.args.tis_mode is not None else "truncate"
    tis_level = self.args.tis_level if self.args.tis_level is not None else "token"
    tis_clip_low = self.args.tis_clip_low if self.args.tis_clip_low is not None else 0.1
    tis_clip_high = self.args.tis_clip if self.args.tis_clip is not None else 2.0

    log_ratio = old_log_probs - rollout_log_probs

    # Calculate raw TIS weights based on level
    if tis_level == "token":
        tis = torch.exp(log_ratio)
    elif tis_level == "sequence":
        tis_list = []
        for seq_log_ratio, mask in zip(log_ratio.split(response_lengths, dim=0), loss_masks, strict=False):
            seq_mask = mask.to(seq_log_ratio.device)
            sum_log_ratio = masked_sum(seq_log_ratio, seq_mask, expand=True)
            seq_tis = torch.exp(sum_log_ratio)
            tis_list.append(seq_tis)
        tis = torch.cat(tis_list, dim=0)
    else:
        raise ValueError(f"Unsupported tis_level: {tis_level}")

    # Apply mode (truncate or mask)
    if tis_mode == "truncate":
        tis_clip = torch.clamp(tis, min=tis_clip_low, max=tis_clip_high)
    elif tis_mode == "mask":
        mask = (tis >= tis_clip_low) & (tis <= tis_clip_high)
        tis_clip = tis * mask.float()
    else:
        raise ValueError(f"Unsupported tis_mode: {tis_mode}")

    tis_clipfrac = tis_clip != tis

    return tis_clip, tis, tis_clipfrac
```
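To see the truncate vs mask behavior from `_compute_tis_weights` in isolation, here is a minimal standalone sketch (not part of the PR; clip bounds match the defaults above):

```python
import torch

# Importance ratios for four tokens; default clip range is [0.1, 2.0]
tis = torch.tensor([0.05, 0.5, 1.0, 3.0])
tis_clip_low, tis_clip_high = 0.1, 2.0

# "truncate" mode: clamp outlier ratios into the clip range
truncated = torch.clamp(tis, min=tis_clip_low, max=tis_clip_high)

# "mask" mode: zero out tokens whose ratio falls outside the clip range
keep = (tis >= tis_clip_low) & (tis <= tis_clip_high)
masked = tis * keep.float()

# truncated -> [0.1, 0.5, 1.0, 2.0]; masked -> [0.0, 0.5, 1.0, 0.0]
# Both modes flag the same tokens as clipped (the first and last here)
clipfrac = (truncated != tis).float().mean()  # 0.5 for this example
```

Truncate keeps a (bounded) gradient contribution from outlier tokens, while mask drops them from the loss entirely.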
Are you sure that this function should be put under fsdp_utils? You can refer to where we put similar function of Megatron.
slime/backends/fsdp_utils/actor.py (outdated)
```python
def vanilla_tis_function_fsdp(
    args,
    *,
    pg_loss: torch.Tensor,
    train_log_probs: list[torch.Tensor],
    rollout_log_probs: list[torch.Tensor],
    loss_masks: list[torch.Tensor],
    **kwargs,
) -> tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]:
    """Apply TIS off-policy correction using importance sampling.

    Parameters:
        args: Arguments containing TIS settings.
        pg_loss: Policy gradient loss tensor of shape [total_seq_len - 1].
        train_log_probs: List of tensors containing training log-probabilities
            for each sequence.
        rollout_log_probs: List of tensors containing rollout log-probabilities
            for each sequence.
        loss_masks: List of tensors containing loss masks for each sequence.
    """
    rollout_log_probs_flat = torch.cat(rollout_log_probs, dim=0)
    train_log_probs_flat = torch.cat(train_log_probs, dim=0)

    tis = torch.exp(train_log_probs_flat - rollout_log_probs_flat)
    tis_abs = (tis - 1).abs()

    tis_clip_low = args.tis_clip_low if args.tis_clip_low is not None else 0.1
    tis_clip_high = args.tis_clip if args.tis_clip is not None else 2.0
    tis_clip = torch.clamp(tis, min=tis_clip_low, max=tis_clip_high)
    tis_clipfrac = (tis_clip != tis).float()

    metrics = {
        "tis": tis.clone().detach(),
        "tis_clipfrac": tis_clipfrac.clone().detach(),
        "tis_abs": tis_abs.clone().detach(),
    }
    pg_loss = pg_loss * tis_clip

    return pg_loss, loss_masks, metrics
```
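For intuition, the core ratio-and-clamp step of the function above can be exercised with dummy tensors. This is a standalone sketch; the `args` object here is a stand-in, not slime's real argument container:

```python
import types

import torch

# Stand-in for slime's args; only the two clip settings are used here
args = types.SimpleNamespace(tis_clip_low=0.1, tis_clip=2.0)

# Two sequences with per-token log-probs under the training and rollout policies
train_log_probs = [torch.log(torch.tensor([0.5, 0.9])), torch.log(torch.tensor([0.2]))]
rollout_log_probs = [torch.log(torch.tensor([0.4, 0.9])), torch.log(torch.tensor([0.8]))]
pg_loss = torch.ones(3)

# Core of the TIS correction: exp(train - rollout), clamped, then reweight the loss
tis = torch.exp(torch.cat(train_log_probs) - torch.cat(rollout_log_probs))
tis_clip = torch.clamp(tis, min=args.tis_clip_low, max=args.tis_clip)
weighted_loss = pg_loss * tis_clip

# tis is approximately [1.25, 1.0, 0.25]; all ratios lie inside [0.1, 2.0],
# so nothing is clipped and tis_clipfrac would be zero for this batch
```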
As I mentioned, maybe these functions should not be put here. We want to keep actor.py as clean as possible.
Hi @zijiexia, can we just remove `vanilla_tis_function_fsdp` (it should not be used, since we can specify `custom-tis-function-path` to use `compute_mis_weights` in `examples/train_infer_mismatch_helper/mis.py`) and `_compute_tis_weights` (it is not used now)?
Hi @GuanxingLu, I've made the following changes:
- Cleaned up the unused functions in `actor.py`.
- Moved `vanilla_tis_function_fsdp` to `ppo_utils.py`, as I think we do need it, following the same pattern as Megatron (see `slime/backends/megatron_utils/loss.py`, line 526 at 461fc8a).
- Moved `compute_mis_weights_fsdp` to `mis.py`.

I didn't add a new `use-mis` arg, as I'm trying to follow the same parameter system as `mis.yaml`. Could you take a look at it and let me know what you think before I mark it back to ready for review? Thanks!
Hi @PopSoda2002 @zhaochenyang20, thanks for the review, I've refactored the code accordingly. Thank you!
Thanks! Sorry for the late reply. Let me and Huapeng review this. @PopSoda2002
PopSoda2002
left a comment
It looks pretty nice now! Thanks for your great work!
Nicely done, Zijie! @zijiexia

I think @GuanxingLu also made important contributions to this PR. cc @zhaochenyang20 😂

@GuanxingLu started this before I joined, so most of the credit should go to him.

Appreciate it, we all contributed a lot, happy to contribute!
Co-authored-by: Guanxing Lu <747398423@qq.com>
Add masked importance sampling at both token level and sequence level, as in #1063.
Results from @GuanxingLu:
Summary:
Run with 4xH200 GPUs (using `examples/train_infer_mismatch_helper/run-qwen3-4b-fsdp-mis.sh`):

Unfortunately, the original mismatch between the training and rollout engines is quite marginal, so MIS has essentially no effect here. A pytest script was therefore added to test the functionality.
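A pytest-style check in that spirit might look like the following. The names are illustrative, not the PR's actual test file:

```python
import torch


def masked_is_weights(log_ratio, clip_low=0.1, clip_high=2.0):
    """Toy MIS: zero out importance weights that leave [clip_low, clip_high]."""
    ratio = torch.exp(log_ratio)
    keep = (ratio >= clip_low) & (ratio <= clip_high)
    return ratio * keep.float()


def test_mis_masks_out_of_range_ratios():
    log_ratio = torch.log(torch.tensor([0.05, 1.0, 3.0]))
    weights = masked_is_weights(log_ratio)
    # Ratios outside [0.1, 2.0] are masked to exactly zero
    assert weights[0].item() == 0.0
    assert weights[2].item() == 0.0
    # An in-range ratio passes through unchanged
    assert torch.isclose(weights[1], torch.tensor(1.0))
```

A synthetic check like this exercises the masking logic even when the real train/rollout mismatch is too small to show an effect in end-to-end training.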