
Refactoring training inference importance sampling with sequence/geometry level #429

Merged
zhuzilin merged 70 commits into THUDM:main from zhaochenyang20:importance_sampling
Oct 20, 2025

Conversation


@zhaochenyang20 zhaochenyang20 commented Oct 6, 2025

Many thanks to the authors of the paper *When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch* for their contribution.

This PR refactors the importance sampling (IS) functionality, replacing the legacy --use-tis parameter with a more flexible --use-train-infer-is parameter system. We introduce multiple aggregation levels and processing modes to better handle the training-inference mismatch.

  1. Parameter System Refactoring
  • Removed legacy parameters: --use-tis, --tis-clip, --tis-clip-low
  • New parameters:
    • --use-train-infer-is: Enable training-inference importance sampling
    • --train-infer-is-level: Aggregation level (token/sequence/geometric)
    • --train-infer-is-mode: Processing mode (truncate/mask/clip)
    • --train-infer-is-lower-bound/--train-infer-is-upper-bound: Weight bounds
    • --train-infer-is-veto-threshold: Catastrophic token threshold
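For reference, the new flags could be declared with argparse roughly as below. The flag names come from this PR; the defaults and choice lists shown here are illustrative guesses, not the merged implementation:

```python
import argparse

# Sketch only: flag names match the PR description; defaults and
# `choices` values here are assumptions for illustration.
parser = argparse.ArgumentParser()
parser.add_argument("--use-train-infer-is", action="store_true",
                    help="Enable training-inference importance sampling")
parser.add_argument("--train-infer-is-level", default="token",
                    choices=["token", "sequence", "geometric"])
parser.add_argument("--train-infer-is-mode", default="truncate",
                    choices=["truncate", "mask", "clip"])
parser.add_argument("--train-infer-is-lower-bound", type=float, default=0.5)
parser.add_argument("--train-infer-is-upper-bound", type=float, default=2.0)
parser.add_argument("--train-infer-is-veto-threshold", type=float, default=None)

args = parser.parse_args(
    ["--use-train-infer-is", "--train-infer-is-level", "geometric"]
)
```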
  2. Aggregation Levels

Token Level (default):

  • Computes importance weights independently for each token
  • Formula: w_i = exp(log π_train(x_i) - log π_rollout(x_i))
  • Characteristics: Biased but computationally simple, suitable for most scenarios

Sequence Level:

  • Uses the product of all token weights as the sequence weight
  • Formula: w_seq = exp(Σ(log π_train(x_i) - log π_rollout(x_i)))
  • Characteristics: Unbiased but high variance, suitable for sequence-level optimization

Geometric Level:

  • Uses geometric mean to compute sequence weights
  • Formula: w_seq = exp(mean(log π_train(x_i) - log π_rollout(x_i)))
  • Characteristics: Biased but low variance, balances bias and variance
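The three aggregation levels can be sketched in PyTorch as follows. `train_infer_is_weights` is a hypothetical helper written from the formulas above, not the function merged in this PR; at the token level, masked positions get weight exp(0) = 1, which the loss mask later zeroes out anyway:

```python
import torch

def train_infer_is_weights(train_logp, rollout_logp, loss_mask, level="token"):
    """Illustrative sketch of the token/sequence/geometric aggregation levels."""
    # Log-ratio between training and rollout (inference) policies, masked.
    log_ratio = (train_logp - rollout_logp) * loss_mask
    if level == "token":
        # Per-token weights: w_i = exp(log pi_train - log pi_rollout).
        return torch.exp(log_ratio)
    if level == "sequence":
        # Product of token weights = exp of the summed log-ratios.
        w = torch.exp(log_ratio.sum(dim=-1, keepdim=True))
    elif level == "geometric":
        # Geometric mean over the valid tokens.
        n_tokens = loss_mask.sum(dim=-1, keepdim=True).clamp(min=1)
        w = torch.exp(log_ratio.sum(dim=-1, keepdim=True) / n_tokens)
    else:
        raise ValueError(f"unknown level: {level}")
    # Broadcast the per-sequence weight back to every token position.
    return w.expand_as(train_logp)
```

With a single token whose training probability is twice its rollout probability, the token and sequence levels both yield 2.0 for that position, while the geometric level over two tokens yields sqrt(2).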
  3. Processing Modes

Truncate Mode (TIS):

  • Clips weights exceeding the upper bound to the upper bound
  • Maintains original TIS behavior, suitable for variance control

Mask Mode (MIS):

  • Sets weights outside [lower, upper] range to zero
  • More aggressive filtering strategy, suitable for handling extreme mismatches

Clip Mode (CIS):

  • Constrains weights within [lower, upper] range
  • Balanced truncation strategy
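A minimal sketch of the three processing modes, assuming per-token weight tensors; `apply_is_mode` is an illustrative name, and the default bounds mirror the usage example below rather than the merged code:

```python
import torch

def apply_is_mode(weights, lower=0.5, upper=2.0, mode="truncate"):
    """Illustrative sketch of the truncate (TIS) / mask (MIS) / clip (CIS) modes."""
    if mode == "truncate":
        # TIS: cap only the upper tail; small weights pass through unchanged.
        return weights.clamp(max=upper)
    if mode == "mask":
        # MIS: zero out any weight outside [lower, upper], dropping those tokens.
        keep = (weights >= lower) & (weights <= upper)
        return weights * keep
    if mode == "clip":
        # CIS: clamp into [lower, upper] on both sides.
        return weights.clamp(min=lower, max=upper)
    raise ValueError(f"unknown mode: {mode}")
```

The design difference is visible on an extreme weight such as 3.0: truncate and clip pull it back to the upper bound, while mask removes its contribution entirely.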
  4. Others
  • Catastrophic token detection: detects and filters sequences containing catastrophic tokens via --train-infer-is-veto-threshold
  • Monitoring metrics: added training/inference perplexity, KL divergence, the K3 KL estimator, and more
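The veto could work roughly as sketched below, with the semantics assumed from the description above: a token is treated as catastrophic when the training policy assigns it probability below the threshold, and any sequence containing one is filtered out. `sequence_veto_mask` is a hypothetical helper, not the merged implementation:

```python
import torch

def sequence_veto_mask(train_probs, veto_threshold=1e-3):
    """Assumed semantics: drop whole sequences that contain any token whose
    training-policy probability falls below the veto threshold."""
    catastrophic = train_probs < veto_threshold  # (batch, seq_len) bool
    # 1.0 keeps the sequence, 0.0 vetoes it.
    return (~catastrophic.any(dim=-1)).float()
```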
  5. Usage
# Using geometric mean + mask mode
--use-train-infer-is \
--train-infer-is-level geometric \
--train-infer-is-mode mask \
--train-infer-is-lower-bound 0.5 \
--train-infer-is-upper-bound 2.0 \
--train-infer-is-veto-threshold 1e-3

@zhaochenyang20
Collaborator Author

/gemini review

Comment on lines +225 to +229
seq_mean = masked_mean(tis_weights, eos_mask, dim=-1)
metrics["tis_seq_mean"] = seq_mean.mean()
metrics["tis_seq_std"] = seq_mean.std()
metrics["tis_seq_max"] = seq_mean.max()
metrics["tis_seq_min"] = seq_mean.min()
Collaborator


I think it might have a problem when cp>1.

tis_clip = torch.clamp(
    tis, min=getattr(self.args, "tis_clip_low", 0.1), max=getattr(self.args, "tis_clip", 2.0)
)
# Build eos mask from loss masks
eos_mask = torch.cat(loss_masks, dim=0).to(device=log_probs.device)
Collaborator


I think it might have a problem when cp>1, because the loss mask is not split by cp_size while logp is. You can reuse the implementation of sum_of_sample_mean in cp_utils.

@zhaochenyang20 zhaochenyang20 changed the title [WIP] Importance sampling Refactoring training inference importance sampling with seqeunce/geometry level Oct 15, 2025
@zhuzilin zhuzilin merged commit 46e2cd4 into THUDM:main Oct 20, 2025
4 checks passed
nanjiangwill pushed a commit to nanjiangwill/slime that referenced this pull request Oct 22, 2025
…etry level (THUDM#429)

Co-authored-by: Jiajun Li <guapisolo@gmail.com>
llltttwww pushed a commit to llltttwww/slime that referenced this pull request Nov 30, 2025
…etry level (THUDM#429)

Co-authored-by: Jiajun Li <guapisolo@gmail.com>
Yangruipis pushed a commit to rednote-ai/slime that referenced this pull request Feb 28, 2026
…etry level (THUDM#429)

Co-authored-by: Jiajun Li <guapisolo@gmail.com>

6 participants