
[FSDP] Add Masked importance sampling#1122

Merged
zhuzilin merged 8 commits into THUDM:main from zijiexia:zijie_dev_branch
Dec 22, 2025

Conversation

@zijiexia
Contributor

@zijiexia zijiexia commented Dec 15, 2025

Add masked importance sampling (MIS) at both the token level and the sequence level, as proposed in #1063.
Results from @GuanxingLu:

Summary:

  • Directly use compute_mis_weights func from megatron backend
  • Add a pytest file (tests/test_fsdp_mis.py)

Run on 4×H200 GPUs (using examples/train_infer_mismatch_helper/run-qwen3-4b-fsdp-mis.sh):
[attached screenshot of training curves]

Unfortunately, the mismatch between the training and rollout engines is quite marginal in this setup, so MIS has little visible effect on the curves. A pytest script was therefore added to test the functionality.
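For context, masked importance sampling reweights the policy-gradient loss by the train/rollout probability ratio and zeroes out tokens whose ratio leaves a trust band. A minimal standalone sketch with hypothetical names (the PR's actual entry point is compute_mis_weights in examples/train_infer_mismatch_helper/mis.py):

```python
import torch

def mis_token_weights(train_logp, rollout_logp, clip_low=0.1, clip_high=2.0):
    # Per-token importance ratio between the training policy and the
    # rollout (inference) policy.
    ratio = torch.exp(train_logp - rollout_logp)
    # Masked IS: drop tokens whose ratio falls outside [clip_low, clip_high]
    # instead of clamping them to the bounds as truncated IS would.
    keep = (ratio >= clip_low) & (ratio <= clip_high)
    return ratio * keep.float()

train_logp = torch.tensor([-1.0, -2.0, -1.5])
rollout_logp = torch.tensor([-1.1, -5.0, -1.5])
weights = mis_token_weights(train_logp, rollout_logp)
# The middle token's ratio exp(3) ≈ 20 lies outside the band and is zeroed.
```

When training and rollout engines agree closely (as in the run above), almost every ratio stays inside the band, which is why MIS barely moves the curves here.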

@zijiexia zijiexia marked this pull request as draft December 16, 2025 00:34
@zijiexia zijiexia changed the title from "[FSDP] Add Masked importance sampling #1063" to "[WIP][FSDP] Add Masked importance sampling #1063" Dec 16, 2025
@zijiexia zijiexia marked this pull request as ready for review December 16, 2025 02:03
@zijiexia zijiexia changed the title from "[WIP][FSDP] Add Masked importance sampling #1063" to "[FSDP] Add Masked importance sampling" Dec 16, 2025
@GuanxingLu
Contributor

GuanxingLu commented Dec 17, 2025

@PopSoda2002 Hi, could you please review this? Thank you!

    calculates PPO-style clipped policy gradient loss. For GSPO, gathers
    full sequences via context-parallel all-gather before computing per-sample
-   KL. Optionally applies TIS (Temporal Importance Sampling) correction and
+   KL. Optionally applies TIS (Truncated Importance Sampling) correction and

Nice catch 😂


@PopSoda2002 PopSoda2002 left a comment


In my opinion, this PR should not be this large:

  1. Delete the script and test it locally.
  2. We can add an argument like use-mis and implement the MIS func.
  3. Don't make a big diff for what is just a small function; the current change makes the code harder to read and more likely to introduce bugs.

Comment on lines +824 to +868
    def _compute_tis_weights(
        self,
        old_log_probs: torch.Tensor,
        rollout_log_probs: torch.Tensor,
        loss_masks: list[torch.Tensor],
        response_lengths: list[int],
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute importance sampling weights for TIS/MIS.

        Supports both token-level and sequence-level aggregation, and
        truncate/mask modes. Returns the clipped weights, the raw weights,
        and a per-element indicator of which weights were affected by clipping.
        """
        tis_mode = self.args.tis_mode if self.args.tis_mode is not None else "truncate"
        tis_level = self.args.tis_level if self.args.tis_level is not None else "token"
        tis_clip_low = self.args.tis_clip_low if self.args.tis_clip_low is not None else 0.1
        tis_clip_high = self.args.tis_clip if self.args.tis_clip is not None else 2.0

        log_ratio = old_log_probs - rollout_log_probs

        # Calculate raw TIS weights based on level
        if tis_level == "token":
            tis = torch.exp(log_ratio)
        elif tis_level == "sequence":
            tis_list = []
            for seq_log_ratio, mask in zip(log_ratio.split(response_lengths, dim=0), loss_masks, strict=False):
                seq_mask = mask.to(seq_log_ratio.device)
                sum_log_ratio = masked_sum(seq_log_ratio, seq_mask, expand=True)
                seq_tis = torch.exp(sum_log_ratio)
                tis_list.append(seq_tis)
            tis = torch.cat(tis_list, dim=0)
        else:
            raise ValueError(f"Unsupported tis_level: {tis_level}")

        # Apply mode (truncate or mask)
        if tis_mode == "truncate":
            tis_clip = torch.clamp(tis, min=tis_clip_low, max=tis_clip_high)
        elif tis_mode == "mask":
            mask = (tis >= tis_clip_low) & (tis <= tis_clip_high)
            tis_clip = tis * mask.float()
        else:
            raise ValueError(f"Unsupported tis_mode: {tis_mode}")

        tis_clipfrac = (tis_clip != tis).float()

        return tis_clip, tis, tis_clipfrac
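To illustrate the two modes above, here is what truncate vs mask do to a few sample ratios (a standalone sketch with the same default band, independent of the PR's code):

```python
import torch

tis = torch.tensor([0.05, 0.5, 1.0, 3.0])
clip_low, clip_high = 0.1, 2.0

# truncate: out-of-band ratios are clamped to the nearest bound,
# so every token keeps a (bounded) gradient contribution.
truncated = torch.clamp(tis, min=clip_low, max=clip_high)

# mask: out-of-band ratios are zeroed, removing those tokens
# from the loss entirely.
in_band = (tis >= clip_low) & (tis <= clip_high)
masked = tis * in_band.float()
# truncated -> [0.10, 0.50, 1.00, 2.00]
# masked    -> [0.00, 0.50, 1.00, 0.00]
```

The trade-off: truncate keeps a biased but nonzero signal from off-policy tokens, while mask discards them outright.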


Are you sure this function should be put under fsdp_utils? You can refer to where we put the similar function for Megatron.

Comment on lines +1185 to +1223
def vanilla_tis_function_fsdp(
    args,
    *,
    pg_loss: torch.Tensor,
    train_log_probs: list[torch.Tensor],
    rollout_log_probs: list[torch.Tensor],
    loss_masks: list[torch.Tensor],
    **kwargs,
) -> tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]:
    """Apply TIS off-policy correction using importance sampling.

    Parameters:
        args: Arguments containing TIS settings.
        pg_loss: Policy gradient loss tensor of shape [total_seq_len - 1].
        train_log_probs: List of tensors containing training log-probabilities
            for each sequence.
        rollout_log_probs: List of tensors containing rollout log-probabilities
            for each sequence.
        loss_masks: List of tensors containing loss masks for each sequence.
    """
    rollout_log_probs_flat = torch.cat(rollout_log_probs, dim=0)
    train_log_probs_flat = torch.cat(train_log_probs, dim=0)

    tis = torch.exp(train_log_probs_flat - rollout_log_probs_flat)
    tis_abs = (tis - 1).abs()

    tis_clip_low = args.tis_clip_low if args.tis_clip_low is not None else 0.1
    tis_clip_high = args.tis_clip if args.tis_clip is not None else 2.0
    tis_clip = torch.clamp(tis, min=tis_clip_low, max=tis_clip_high)
    tis_clipfrac = (tis_clip != tis).float()

    metrics = {
        "tis": tis.clone().detach(),
        "tis_clipfrac": tis_clipfrac.clone().detach(),
        "tis_abs": tis_abs.clone().detach(),
    }
    pg_loss = pg_loss * tis_clip

    return pg_loss, loss_masks, metrics
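For the sequence-level aggregation discussed in this review, the weight is the exponentiated masked sum of per-token log-ratios, broadcast back over the sequence. A simplified stand-in for the masked_sum(..., expand=True) path (hypothetical helper, not the repo's implementation):

```python
import torch

def seq_level_weights(log_ratio, loss_masks, response_lengths):
    # One scalar weight per sequence: exp(sum of masked token log-ratios),
    # expanded back so every token in the sequence shares the same weight.
    out = []
    for seq_lr, mask in zip(log_ratio.split(response_lengths, dim=0), loss_masks):
        sum_log_ratio = (seq_lr * mask).sum()
        out.append(torch.exp(sum_log_ratio).expand_as(seq_lr))
    return torch.cat(out, dim=0)

log_ratio = torch.tensor([0.1, -0.1, 0.2, 0.0])
loss_masks = [torch.tensor([1.0, 1.0]), torch.tensor([1.0, 0.0])]
weights = seq_level_weights(log_ratio, loss_masks, [2, 2])
# First sequence: exp(0.1 - 0.1) = 1.0 for both tokens;
# second: exp(0.2) for both (the masked-out token contributes no log-ratio).
```

Because every token in a sequence shares one weight, a single badly mismatched token can knock out the whole sequence under the mask mode, which is why the token/sequence level is exposed as a separate setting.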
As I mentioned, maybe these functions should not be put here. We want to keep actor.py as clean as possible.


Hi @zijiexia, can we just remove vanilla_tis_function_fsdp (it should not be used as we can specify custom-tis-function-path to use compute_mis_weights in examples/train_infer_mismatch_helper/mis.py) and _compute_tis_weights (it is not used now)?


Hi @GuanxingLu , I've made the following changes:

  1. Cleaned up the unused functions in actor.py.
  2. Moved vanilla_tis_function_fsdp to ppo_utils.py, since I think we do need it, following the same pattern as Megatron:
    tis_func = vanilla_tis_function
  3. Moved compute_mis_weights_fsdp to mis.py.

I didn't add a new use-mis arg because I'm trying to follow the same parameter system as mis.yaml. Could you take a look and let me know what you think before I mark it back to ready for review? Thanks!


@GuanxingLu GuanxingLu Dec 18, 2025


Looks good to me!

@zijiexia zijiexia marked this pull request as draft December 18, 2025 06:42
@zijiexia
Contributor Author

Hi @PopSoda2002 @zhaochenyang20 , thanks for the review. I've refactored the code accordingly:

  1. Deleted the test script.
  2. Removed the unused functions (sorry!) and cleaned up actor.py.
  3. I didn't add an additional use-mis arg, since I was trying to follow the same pattern as Refactoring training inference importance sampling with sequence/geometry level #429. Please let me know what you think, and I'll refactor further accordingly.

Thank you!

@zijiexia zijiexia marked this pull request as ready for review December 18, 2025 17:13
@zhaochenyang20
Collaborator

Thanks! Sorry for the late reply. Huapeng and I will review this. @PopSoda2002


@PopSoda2002 PopSoda2002 left a comment


It looks pretty nice now! Thanks for your great work!

@zhuzilin zhuzilin merged commit b74858a into THUDM:main Dec 22, 2025
7 checks passed
@zhaochenyang20
Collaborator

Nicely done, Zijie! @zijiexia

@PopSoda2002
Collaborator

Nicely done, Zijie! @zijiexia

I think @GuanxingLu also made important contributions to this PR. cc @zhaochenyang20 😂

@zijiexia
Contributor Author

Nicely done, Zijie! @zijiexia

@GuanxingLu started this before I joined, so most of the credit should go to him.

@GuanxingLu
Contributor

Appreciate it, we all contributed a lot. Happy to contribute!

Yangruipis pushed a commit to rednote-ai/slime that referenced this pull request Feb 28, 2026
Co-authored-by: Guanxing Lu <747398423@qq.com>
