From 30e01df5020af3a523fde00ca4b571827119dc2b Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Sun, 5 Oct 2025 23:35:28 +0000 Subject: [PATCH 01/58] fix quick start docs in zh/en --- docs/en/get_started/quick_start.md | 5 ++--- docs/zh/get_started/quick_start.md | 4 ++-- scripts/run-qwen3-30B-A3B.sh | 18 +++++++++--------- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/docs/en/get_started/quick_start.md b/docs/en/get_started/quick_start.md index ee5686d03e..65a2b1a027 100644 --- a/docs/en/get_started/quick_start.md +++ b/docs/en/get_started/quick_start.md @@ -272,11 +272,10 @@ GRPO_ARGS=( --eps-clip-high 0.28 ) ``` -Of course. Here is the English translation for your documentation: -- `--advantage-estimator`: In addition to `grpo` ([https://arxiv.org/abs/2402.03300](https://arxiv.org/abs/2402.03300)), slime also supports several other training algorithms, such as `gspo` ([https://arxiv.org/abs/2507.18071](https://arxiv.org/abs/2507.18071)), `reinforce_plus_plus` and `reinforce_plus_plus_baseline` ([https://arxiv.org/abs/2501.03262](https://arxiv.org/abs/2501.03262)), and `ppo` ([https://arxiv.org/abs/1707.06347](https://arxiv.org/abs/1707.06347)). +- `--advantage-estimator`: In addition to [GRPO](https://arxiv.org/abs/2402.03300), slime also supports several other training algorithms, such as [GSPO](https://arxiv.org/abs/2507.18071), [Reinforce++](https://arxiv.org/abs/2501.03262) and [Reinforce++ Baseline](https://arxiv.org/abs/2501.03262), and [PPO](https://arxiv.org/abs/1707.06347). - `--calculate-per-token-loss`: By default, slime calculates the loss on a per-sample basis, i.e., `mean(sum(sample_i) / len(sample_i))`. To calculate the loss on a per-token basis, i.e., `sum(sum(sample_i)) / sum(len(sample_i))`, you can enable this flag. -- `--use-tis`: Enable this setting to use TIS (Truncated Importance Sampling) ([https://fengyao.notion.site/off-policy-rl](https://fengyao.notion.site/off-policy-rl)). 
+- `--use-tis`: Enable this setting to use TIS (Truncated Importance Sampling), which is introduced by this [blog](https://fengyao.notion.site/off-policy-rl). ### OPTIMIZER_ARGS: Optimizer Parameters diff --git a/docs/zh/get_started/quick_start.md b/docs/zh/get_started/quick_start.md index 73d292bfb8..d4c24bdf57 100644 --- a/docs/zh/get_started/quick_start.md +++ b/docs/zh/get_started/quick_start.md @@ -272,9 +272,9 @@ GRPO_ARGS=( ) ``` -- `--advantage-estimator`: 除去 `grpo`(https://arxiv.org/abs/2402.03300),slime 还支持其他的一些训练算法,例如 `gspo`(https://arxiv.org/abs/2507.18071)、`reinforce_plus_plus` 与 `reinforce_plus_plus_baseline`(https://arxiv.org/abs/2501.03262)、以及 `ppo`(https://arxiv.org/abs/1707.06347); +- `--advantage-estimator`: 除去 [GRPO](https://arxiv.org/abs/2402.03300),slime 还支持丰富的其他训练算法,例如 [GSPO](https://arxiv.org/abs/2507.18071)、[Reinforce++](https://arxiv.org/abs/2501.03262) 与 [Reinforce++ Baseline](https://arxiv.org/abs/2501.03262)、以及 [PPO](https://arxiv.org/abs/1707.06347); - `--calculate-per-token-loss`:slime 中默认的方案是 per sample loss,即 `mean(sum(sample_i) / len(sample_i))`,如果需要计算 per token loss,即 `sum(sum(sample_i)) / sum(len(sample_i))`,可以开启 `--calculate-per-token-loss`; -- `--use-tis`:如果需要开启 tis(https://fengyao.notion.site/off-policy-rl),可以开启这一设置。 +- `--use-tis`:如果需要开启 TIS (Truncated Importance Sampling),可以开启这一设置。TIS 由此[博客](https://fengyao.notion.site/off-policy-rl)介绍。 ### OPTIMIZER_ARGS: 优化器参数 diff --git a/scripts/run-qwen3-30B-A3B.sh b/scripts/run-qwen3-30B-A3B.sh index fd8375551b..88f8b99fc7 100644 --- a/scripts/run-qwen3-30B-A3B.sh +++ b/scripts/run-qwen3-30B-A3B.sh @@ -42,18 +42,18 @@ ROLLOUT_ARGS=( --apply-chat-template --rollout-shuffle --rm-type deepscaler - --num-rollout 3000 - --rollout-batch-size 32 - --n-samples-per-prompt 8 + --num-rollout 100 + --rollout-batch-size 16 + --n-samples-per-prompt 4 --rollout-max-response-len 8192 --rollout-temperature 0.8 - --global-batch-size 256 + --global-batch-size 64 --balance-data ) EVAL_ARGS=( - 
--eval-interval 20 + # --eval-interval 20 --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 @@ -101,10 +101,10 @@ OPTIMIZER_ARGS=( ) WANDB_ARGS=( - #--use-wandb - # --wandb-project slime-dev - # --wandb-group qwen3-30B-A3B-test - # --wandb-key ${WANDB_KEY} + --use-wandb + --wandb-project slime-dev + --wandb-group qwen3-30B-A3B-test + --wandb-key ${WANDB_KEY} ) SGLANG_ARGS=( From 82695df68697f545041d793d58be705a44004ee1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E6=99=A8=E9=98=B3?= Date: Sun, 5 Oct 2025 16:37:12 -0700 Subject: [PATCH 02/58] Update run-qwen3-30B-A3B.sh --- scripts/run-qwen3-30B-A3B.sh | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/scripts/run-qwen3-30B-A3B.sh b/scripts/run-qwen3-30B-A3B.sh index 88f8b99fc7..fd8375551b 100644 --- a/scripts/run-qwen3-30B-A3B.sh +++ b/scripts/run-qwen3-30B-A3B.sh @@ -42,18 +42,18 @@ ROLLOUT_ARGS=( --apply-chat-template --rollout-shuffle --rm-type deepscaler - --num-rollout 100 - --rollout-batch-size 16 - --n-samples-per-prompt 4 + --num-rollout 3000 + --rollout-batch-size 32 + --n-samples-per-prompt 8 --rollout-max-response-len 8192 --rollout-temperature 0.8 - --global-batch-size 64 + --global-batch-size 256 --balance-data ) EVAL_ARGS=( - # --eval-interval 20 + --eval-interval 20 --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 @@ -101,10 +101,10 @@ OPTIMIZER_ARGS=( ) WANDB_ARGS=( - --use-wandb - --wandb-project slime-dev - --wandb-group qwen3-30B-A3B-test - --wandb-key ${WANDB_KEY} + #--use-wandb + # --wandb-project slime-dev + # --wandb-group qwen3-30B-A3B-test + # --wandb-key ${WANDB_KEY} ) SGLANG_ARGS=( From fc0ec271ecc818ea2da70d335b3227d0df7fbfaf Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Mon, 6 Oct 2025 04:20:03 +0000 Subject: [PATCH 03/58] [Importance sampling] seperate importance sampling as a function --- 
scripts/run-qwen3-30B-A3B.sh | 2 + slime/backends/fsdp_utils/actor.py | 100 +++++---- slime/backends/fsdp_utils/data_packing.py | 4 +- slime/backends/megatron_utils/data.py | 2 +- slime/backends/megatron_utils/loss.py | 50 ++++- slime/backends/megatron_utils/model.py | 2 +- slime/ray/rollout.py | 2 +- slime/utils/arguments.py | 49 +++++ slime/utils/distributed_utils.py | 2 +- slime/utils/tis.py | 204 ++++++++++++++++++ .../rollout_buffer/rollout_buffer_example.py | 2 +- 11 files changed, 369 insertions(+), 50 deletions(-) create mode 100644 slime/utils/tis.py diff --git a/scripts/run-qwen3-30B-A3B.sh b/scripts/run-qwen3-30B-A3B.sh index fd8375551b..3c442c3ebf 100644 --- a/scripts/run-qwen3-30B-A3B.sh +++ b/scripts/run-qwen3-30B-A3B.sh @@ -9,6 +9,7 @@ pkill -9 python sleep 3 pkill -9 ray pkill -9 python +pkill -9 redis set -ex @@ -85,6 +86,7 @@ GRPO_ARGS=( --entropy-coef 0.00 --eps-clip 0.2 --eps-clip-high 0.28 + --use-tis ) OPTIMIZER_ARGS=( diff --git a/slime/backends/fsdp_utils/actor.py b/slime/backends/fsdp_utils/actor.py index 69f8eb926b..4bd16b4a23 100644 --- a/slime/backends/fsdp_utils/actor.py +++ b/slime/backends/fsdp_utils/actor.py @@ -4,13 +4,10 @@ import ray import torch import torch.distributed as dist +from packaging import version +from torch.distributed.tensor import DTensor from torch_memory_saver import torch_memory_saver from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer -import logging - - -from torch.distributed.tensor import DTensor -from packaging import version # Import FSDP v2 components based on PyTorch version if version.parse(torch.__version__) >= version.parse("2.6"): @@ -27,6 +24,7 @@ from slime.utils.memory_utils import clear_memory from slime.utils.ppo_utils import compute_approx_kl, compute_policy_loss from slime.utils.timer import Timer, timer +from slime.utils.tis import compute_tis_weights from slime.utils.wandb_utils import init_wandb_secondary from .data_packing import pack_sequences, 
unpack_sequences @@ -222,7 +220,9 @@ def packed_data(self, rollout_data): rollout_data["response_lengths"][start:end], rollout_data["advantages"][start:end], rollout_data["returns"][start:end], - rollout_log_probs=rollout_data["rollout_log_probs"][start:end] if "rollout_log_probs" in rollout_data else None, + rollout_log_probs=( + rollout_data["rollout_log_probs"][start:end] if "rollout_log_probs" in rollout_data else None + ), num_packs=mbs_size, ) ) @@ -319,31 +319,52 @@ def train(self, rollout_id, rollout_data_ref): ppo_kl = old_log_probs.to(device=log_probs.device) - log_probs advantages = advantages.to(device=ppo_kl.device) pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, self.args.eps_clip, self.args.eps_clip_high) - + # Apply TIS before sample mean calculation if self.args.use_tis: # Initialize TIS variables - tis = None - tis_clipfrac = None + tis_weights = None ois = None # Apply TIS off-policy correction using importance sampling assert all( - "rollout_log_probs" in batch and - isinstance(batch["rollout_log_probs"], torch.Tensor) and - batch["rollout_log_probs"].numel() > 0 + "rollout_log_probs" in batch + and isinstance(batch["rollout_log_probs"], torch.Tensor) + and batch["rollout_log_probs"].numel() > 0 for batch in unpacked_batches ), "rollout_log_probs must be provided as non-empty torch.Tensor for TIS" - - rollout_log_probs = torch.cat([batch["rollout_log_probs"] for batch in unpacked_batches], dim=0) - rollout_log_probs = rollout_log_probs.to(device=log_probs.device) - - tis = torch.exp(old_log_probs - rollout_log_probs) + + rollout_log_probs = torch.cat([batch["rollout_log_probs"] for batch in unpacked_batches], dim=0).to( + device=log_probs.device + ) + old_log_probs_flat = old_log_probs + + # Build eos mask from loss masks + eos_mask = torch.cat(loss_masks, dim=0).to(device=log_probs.device) + + upper = getattr(self.args, "tis_threshold", None) + if upper is None: + upper = getattr(self.args, "tis_clip", 2.0) + lower = 
getattr(self.args, "tis_threshold_lower", None) + if lower is None: + lower = getattr(self.args, "tis_clip_low", 0.0) + + tis_weights, tis_metrics = compute_tis_weights( + old_log_prob=old_log_probs_flat, + rollout_log_prob=rollout_log_probs, + eos_mask=eos_mask, + level=getattr(self.args, "tis_level", "token"), + mode=getattr(self.args, "tis_mode", "truncate"), + upper_threshold=upper, + lower_threshold=lower, + veto_threshold=getattr(self.args, "tis_veto_threshold", 1e-4), + safety_bound=getattr(self.args, "tis_safety_bound", 20.0), + ) + ois = (-ppo_kl).exp() - tis_clip = torch.clamp(tis, min=getattr(self.args, 'tis_clip_low', 0.1), max=getattr(self.args, 'tis_clip', 2.0)) - tis_clipfrac = tis_clip != tis - - pg_loss = pg_loss * tis_clip - + + if tis_weights is not None: + pg_loss = pg_loss * tis_weights + pg_loss = sum_of_sample_mean(pg_loss, response_lengths, loss_masks) pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks) ppo_kl = sum_of_sample_mean(ppo_kl.abs(), response_lengths, loss_masks) @@ -375,11 +396,23 @@ def train(self, rollout_id, rollout_data_ref): if self.args.use_kl_loss: reported["kl_loss"] = kl_loss.detach() - - if self.args.use_tis and tis is not None: - reported["tis"] = sum_of_sample_mean(tis, response_lengths, loss_masks).detach() + + if self.args.use_tis and tis_weights is not None: reported["ois"] = sum_of_sample_mean(ois, response_lengths, loss_masks).detach() - reported["tis_clipfrac"] = sum_of_sample_mean(tis_clipfrac.float(), response_lengths, loss_masks).detach() + # Extended metrics + for k in [ + "tis_mean", + "tis_std", + "tis_ratio_fraction_high", + "tis_ratio_fraction_low", + "tis_seq_clipped_fraction", + "tis_veto_fraction", + ]: + if k in tis_metrics: + val = tis_metrics[k] + reported[k] = ( + val.detach() if torch.is_tensor(val) else torch.tensor(val, device=log_probs.device) + ) # Scale loss for gradient accumulation loss = loss * dist.get_world_size() / self.args.global_batch_size @@ -443,16 +476,12 
@@ def update_weights(self): # type: ignore[override] self.weight_updator.connect_rollout_engines(rollout_engines, rollout_engine_lock) dist.barrier(group=get_gloo_group()) - # For colocated mode with sharded updates (full_params=False), + # For colocated mode with sharded updates (full_params=False), # we don't need to wake up the entire model # The bucket-based approach will load parameters selectively from CPU storage # TODO: Add bucket optimization for from distributed mode - use_bucket_optimization = ( - self.args.colocate and - not getattr(self.weight_updator, 'full_params', False) - ) - - + use_bucket_optimization = self.args.colocate and not getattr(self.weight_updator, "full_params", False) + if self.args.offload and not use_bucket_optimization: # Wake up for distributed mode or full_params mode self.wake_up(("model")) @@ -467,13 +496,13 @@ def update_weights(self): # type: ignore[override] @torch.no_grad() def update_cpu_params_dict(self, params_dict): """Copy model parameters from GPU to CPU storage dictionary""" - + state_dict = self.model.state_dict() for name, param in state_dict.items(): if isinstance(param, DTensor): param = param.full_tensor() - + if name not in params_dict: params_dict[name] = torch.empty_like(param, device=torch.device("cpu"), pin_memory=True) params_dict[name].copy_(param.detach(), non_blocking=True) @@ -492,7 +521,6 @@ def load_ref_model(self, ref_load_path): if ref_load_path is None: raise ValueError("ref_load_path must be provided when loading reference model") - current_weights = {} self.update_cpu_params_dict(current_weights) diff --git a/slime/backends/fsdp_utils/data_packing.py b/slime/backends/fsdp_utils/data_packing.py index e040306f46..0583f97218 100644 --- a/slime/backends/fsdp_utils/data_packing.py +++ b/slime/backends/fsdp_utils/data_packing.py @@ -67,7 +67,7 @@ def pack_sequences( flat_advantages = [] flat_returns = [] flat_rollout_log_probs = [] - + for i in indices: seq_tokens = tokens[i] seq_mask = loss_masks[i] 
@@ -92,7 +92,7 @@ def pack_sequences( "response_lengths": [response_lengths[i] for i in indices], "advantages": torch.tensor(flat_advantages, dtype=torch.float32), "returns": torch.tensor(flat_returns, dtype=torch.float32), - "rollout_log_probs": torch.tensor(flat_rollout_log_probs, dtype=torch.float32) + "rollout_log_probs": torch.tensor(flat_rollout_log_probs, dtype=torch.float32), } ) diff --git a/slime/backends/megatron_utils/data.py b/slime/backends/megatron_utils/data.py index 82f6398528..325e887bdf 100644 --- a/slime/backends/megatron_utils/data.py +++ b/slime/backends/megatron_utils/data.py @@ -5,10 +5,10 @@ import torch import torch.distributed as dist import torch.nn.functional as F -import wandb from megatron.core import mpu from megatron.core.packed_seq_params import PackedSeqParams +import wandb from slime.utils.data import get_minimum_num_micro_batch_size from slime.utils.flops_utils import calculate_fwd_flops from slime.utils.seqlen_balancing import get_seqlen_balanced_partitions diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 2290fda8dd..7deb87cdd3 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -14,6 +14,7 @@ get_reinforce_plus_plus_baseline_advantages, get_reinforce_plus_plus_returns, ) +from slime.utils.tis import compute_tis_weights from .cp_utils import all_gather_with_cp, get_logits_and_tokens_offset_with_cp, get_sum_of_sample_mean @@ -308,14 +309,36 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): if args.use_tis: assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS" rollout_log_probs = torch.cat(batch["rollout_log_probs"], dim=0) - old_log_probs = torch.cat(batch["log_probs"], dim=0) + old_log_probs_flat = torch.cat(batch["log_probs"], dim=0) + + # Build eos mask from loss masks (concatenated) to match flattened tensors + eos_mask = torch.cat(batch["loss_masks"], 
dim=0).to(device=log_probs.device) + + # Backward compatible thresholds + upper = args.tis_threshold if getattr(args, "tis_threshold", None) is not None else args.tis_clip + lower = ( + args.tis_threshold_lower + if getattr(args, "tis_threshold_lower", None) is not None + else getattr(args, "tis_clip_low", 0.0) + ) - tis = torch.exp(old_log_probs - rollout_log_probs) + tis_weights, tis_metrics = compute_tis_weights( + old_log_prob=old_log_probs_flat, + rollout_log_prob=rollout_log_probs, + eos_mask=eos_mask, + level=getattr(args, "tis_level", "token"), + mode=getattr(args, "tis_mode", "truncate"), + upper_threshold=upper, + lower_threshold=lower, + veto_threshold=getattr(args, "tis_veto_threshold", 1e-4), + safety_bound=getattr(args, "tis_safety_bound", 20.0), + ) + + # On-policy ratio for monitoring (π_new/π_old) ois = (-ppo_kl).exp() - tis_clip = torch.clamp(tis, min=args.tis_clip_low, max=args.tis_clip) - tis_clipfrac = tis_clip != tis - pg_loss = pg_loss * tis_clip + if tis_weights is not None: + pg_loss = pg_loss * tis_weights pg_loss = sum_of_sample_mean(pg_loss) pg_clipfrac = sum_of_sample_mean(pg_clipfrac) @@ -356,9 +379,22 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): reported_loss["kl_loss"] = kl_loss.clone().detach() if args.use_tis: - reported_loss["tis"] = sum_of_sample_mean(tis).clone().detach() + # Backward compatible basic logs reported_loss["ois"] = sum_of_sample_mean(ois).clone().detach() - reported_loss["tis_clipfrac"] = sum_of_sample_mean(tis_clipfrac).clone().detach() + # Extended metrics from generalized TIS + for k in [ + "tis_mean", + "tis_std", + "tis_ratio_fraction_high", + "tis_ratio_fraction_low", + "tis_seq_clipped_fraction", + "tis_veto_fraction", + ]: + if k in tis_metrics: + val = tis_metrics[k] + reported_loss[k] = ( + val.clone().detach() if torch.is_tensor(val) else torch.tensor(val, device=logits.device) + ) return loss, reported_loss diff --git a/slime/backends/megatron_utils/model.py 
b/slime/backends/megatron_utils/model.py index 5b9e31ab41..3b78290dda 100644 --- a/slime/backends/megatron_utils/model.py +++ b/slime/backends/megatron_utils/model.py @@ -6,7 +6,6 @@ from functools import partial import torch -import wandb from megatron.core import mpu from megatron.core.distributed import DistributedDataParallel as DDP from megatron.core.distributed import DistributedDataParallelConfig, finalize_model_grads @@ -19,6 +18,7 @@ from megatron.training.global_vars import get_args from megatron.training.training import get_model +import wandb from slime.utils.memory_utils import clear_memory from .checkpoint import load_checkpoint, save_checkpoint diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index bb3247896d..b21214e255 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -8,9 +8,9 @@ import ray import torch -import wandb from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +import wandb from slime.backends.sglang_utils.sglang_engine import SGLangEngine from slime.ray.rollout_data_source import RolloutDataSourceWithBuffer from slime.utils.http_utils import find_available_port, get_host_info, init_http_client diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 52e10a21aa..3a35c515a8 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -654,6 +654,55 @@ def add_algo_arguments(parser): help="Lower bound clipping threshold C for importance sampling ratios to control variance.", ) + # Extended TIS controls (levels/modes/thresholds) with backward compatibility + parser.add_argument( + "--tis-level", + type=str, + choices=["token", "sequence", "geometric"], + default="token", + help=( + "Aggregation level for importance sampling weights: token (per-token), " + "sequence (product over tokens), geometric (geometric mean)." 
+ ), + ) + parser.add_argument( + "--tis-mode", + type=str, + choices=["truncate", "clip"], + default="truncate", + help=( + "Handling mode for IS weights: truncate (cap upper bound, TIS) or clip " + "(zero outside [lower, upper], CIS)." + ), + ) + parser.add_argument( + "--tis-threshold", + type=float, + default=None, + help=("Upper threshold for IS weights. If not set, falls back to --tis-clip (default 2.0)."), + ) + parser.add_argument( + "--tis-threshold-lower", + type=float, + default=None, + help=( + "Lower threshold for IS weights. If not set: for clip mode uses reciprocal of upper; " + "for truncate mode remains unused. If --tis-clip-low provided, that will be used when applicable." + ), + ) + parser.add_argument( + "--tis-veto-threshold", + type=float, + default=1e-4, + help=("Per-token veto threshold. If any token ratio < this, zero the entire sequence weight."), + ) + parser.add_argument( + "--tis-safety-bound", + type=float, + default=20.0, + help=("Safety clamp for log-space ratio to avoid numerical overflow (exp(20) ~= 4.85e8)."), + ) + parser.add_argument( "--use-routing-replay", action="store_true", diff --git a/slime/utils/distributed_utils.py b/slime/utils/distributed_utils.py index d4f34e9480..f358bd34ce 100644 --- a/slime/utils/distributed_utils.py +++ b/slime/utils/distributed_utils.py @@ -108,7 +108,7 @@ def distributed_masked_whiten( Args: values (torch.Tensor): The local tensor of values to whiten. mask (torch.Tensor): The local mask corresponding to the values. - process_group: The process group for all_reduce. + process_group: The process group for all_reduce. If None, uses the default world group. shift_mean (bool): If True, the output is zero-mean. Defaults to True. epsilon (float): A small value for numerical stability. 
diff --git a/slime/utils/tis.py b/slime/utils/tis.py new file mode 100644 index 0000000000..b558558cc2 --- /dev/null +++ b/slime/utils/tis.py @@ -0,0 +1,204 @@ +import torch +from typing import Any, Dict, Optional, Tuple + + +def masked_sum(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor: + if mask is None: + return x.sum(dim=dim) + return (x * mask).sum(dim=dim) + + +def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor: + if mask is None: + return x.mean(dim=dim) + denom = mask.sum(dim=dim).clamp_min(1) + return masked_sum(x, mask, dim=dim) / denom + + +def compute_is_metrics( + is_weights: torch.Tensor, + log_ratio_for_metrics: torch.Tensor, + eos_mask: Optional[torch.Tensor], + *, + level: str, + upper_threshold: float, + lower_threshold: float, + log_threshold_upper: torch.Tensor, + log_threshold_lower: torch.Tensor, + has_catastrophic: Optional[torch.Tensor], + catastrophic_tokens: Optional[torch.Tensor], + safety_bound: float, +) -> Dict[str, Any]: + metrics: Dict[str, Any] = {} + + if eos_mask is None: + eos_mask = torch.ones_like(is_weights, dtype=torch.bool) + + device = is_weights.device + + if has_catastrophic is not None: + metrics["tis_veto_fraction"] = has_catastrophic.float().mean() + if catastrophic_tokens is not None and eos_mask is not None: + metrics["tis_catastrophic_token_fraction"] = masked_mean(catastrophic_tokens.float(), eos_mask) + + if level in ["sequence", "geometric"]: + log_max = log_ratio_for_metrics.max() + log_min = log_ratio_for_metrics.min() + metrics["tis_max"] = torch.exp(torch.clamp(log_max, max=safety_bound)) + metrics["tis_min"] = torch.exp(log_min) + metrics["tis_mean"] = masked_mean(is_weights, eos_mask) + exceeds_upper = log_ratio_for_metrics > log_threshold_upper + below_lower = log_ratio_for_metrics < log_threshold_lower + if level == "sequence": + metrics["tis_ratio_fraction_high"] = exceeds_upper.float().mean() + metrics["tis_ratio_fraction_low"] = 
below_lower.float().mean() + else: + exceeds_upper_exp = exceeds_upper.expand_as(eos_mask) + below_lower_exp = below_lower.expand_as(eos_mask) + metrics["tis_ratio_fraction_high"] = masked_mean(exceeds_upper_exp.float(), eos_mask) + metrics["tis_ratio_fraction_low"] = masked_mean(below_lower_exp.float(), eos_mask) + else: + metrics["tis_mean"] = masked_mean(is_weights, eos_mask) + above = is_weights > upper_threshold + below = is_weights < lower_threshold + metrics["tis_ratio_fraction_high"] = masked_mean(above.float(), eos_mask) + metrics["tis_ratio_fraction_low"] = masked_mean(below.float(), eos_mask) + if eos_mask.any(): + mask_bool = eos_mask.bool() + metrics["tis_max"] = is_weights.masked_fill(~mask_bool, float("-inf")).max() + metrics["tis_min"] = is_weights.masked_fill(~mask_bool, float("inf")).min() + else: + metrics["tis_max"] = torch.tensor(0.0, device=device) + metrics["tis_min"] = torch.tensor(0.0, device=device) + + if eos_mask.any(): + weights_for_std = is_weights.clamp(min=lower_threshold, max=upper_threshold) + var = masked_mean(weights_for_std.square(), eos_mask) - metrics["tis_mean"].square() + metrics["tis_std"] = torch.sqrt(torch.clamp(var, min=0.0)) + weights_for_ess = weights_for_std / (metrics["tis_mean"] + 1e-8) + metrics["tis_eff_sample_size"] = 1.0 / masked_mean(weights_for_ess.square(), eos_mask) + else: + metrics["tis_std"] = torch.tensor(0.0, device=device) + metrics["tis_eff_sample_size"] = torch.tensor(1.0, device=device) + + if is_weights.dim() > 1 and eos_mask.any(): + seq_mean = masked_mean(is_weights, eos_mask, dim=-1) + metrics["tis_seq_mean"] = seq_mean.mean() + metrics["tis_seq_std"] = seq_mean.std() if seq_mean.numel() > 1 else torch.tensor(0.0, device=is_weights.device) + metrics["tis_seq_max"] = seq_mean.max() + metrics["tis_seq_min"] = seq_mean.min() + seq_dev = (seq_mean - 1.0).abs() + metrics["tis_seq_max_deviation"] = seq_dev.max() + metrics["tis_seq_fraction_high"] = (seq_mean > upper_threshold).float().mean() + 
metrics["tis_seq_fraction_low"] = (seq_mean < 1.0 / upper_threshold).float().mean() + + if eos_mask.any(): + flat = is_weights[eos_mask.bool()] + if flat.numel() > 0: + metrics["tis_p25"] = torch.quantile(flat, 0.25) + metrics["tis_p50"] = torch.quantile(flat, 0.50) + metrics["tis_p75"] = torch.quantile(flat, 0.75) + metrics["tis_p95"] = torch.quantile(flat, 0.95) + metrics["tis_p99"] = torch.quantile(flat, 0.99) + + return metrics + + +def compute_tis_weights( + *, + old_log_prob: torch.Tensor, + rollout_log_prob: torch.Tensor, + eos_mask: Optional[torch.Tensor], + level: str = "token", + mode: str = "truncate", + upper_threshold: Optional[float] = None, + lower_threshold: Optional[float] = None, + veto_threshold: float = 1e-4, + safety_bound: float = 20.0, +) -> Tuple[Optional[torch.Tensor], Dict[str, Any]]: + if upper_threshold is None: + return None, {} + + device = old_log_prob.device + if eos_mask is None: + eos_mask = torch.ones_like(old_log_prob, dtype=torch.bool, device=device) + + if lower_threshold is None: + lower_threshold = 1.0 / upper_threshold + + log_ratio = old_log_prob - rollout_log_prob + + log_threshold_upper = torch.log(torch.tensor(upper_threshold, device=device)) + log_threshold_lower = torch.log(torch.tensor(lower_threshold, device=device)) + + if level == "token": + log_ratio_for_metrics = log_ratio + log_ratio_safe = torch.clamp(log_ratio, min=-safety_bound, max=safety_bound) + weights = torch.exp(log_ratio_safe) + elif level == "sequence": + log_ratio_sum = masked_sum(log_ratio, eos_mask, dim=-1).unsqueeze(-1) + log_ratio_for_metrics = log_ratio_sum + log_ratio_sum_safe = torch.clamp(log_ratio_sum, min=-safety_bound, max=safety_bound) + weights = torch.exp(log_ratio_sum_safe).expand_as(old_log_prob) + elif level == "geometric": + log_ratio_mean = masked_mean(log_ratio, eos_mask, dim=-1).unsqueeze(-1) + log_ratio_for_metrics = log_ratio_mean + log_ratio_mean_safe = torch.clamp(log_ratio_mean, min=-safety_bound, max=safety_bound) + weights 
= torch.exp(log_ratio_mean_safe).expand_as(old_log_prob) + else: + raise ValueError(f"Invalid tis level: {level}") + + log_veto_threshold = torch.log(torch.tensor(veto_threshold, device=device)) + catastrophic_tokens = (log_ratio < log_veto_threshold) & eos_mask.bool() + has_catastrophic = catastrophic_tokens.any(dim=-1, keepdim=True) + veto_mask = (~has_catastrophic).float() + + metrics = compute_is_metrics( + is_weights=weights, + log_ratio_for_metrics=log_ratio_for_metrics, + eos_mask=eos_mask, + level=level, + upper_threshold=upper_threshold, + lower_threshold=lower_threshold, + log_threshold_upper=log_threshold_upper, + log_threshold_lower=log_threshold_lower, + has_catastrophic=has_catastrophic, + catastrophic_tokens=catastrophic_tokens, + safety_bound=safety_bound, + ) + + if mode == "truncate": + weights = weights.clamp(max=upper_threshold) + elif mode == "clip": + clip_mask = (weights >= lower_threshold) & (weights <= upper_threshold) + clip_mask_f = clip_mask.float() + metrics["tis_clipped_fraction"] = masked_mean(1 - clip_mask_f, eos_mask) + if level in ["sequence", "geometric"]: + seq_w = weights[:, 0] if weights.dim() > 1 else weights + seq_clipped = ((seq_w < lower_threshold) | (seq_w > upper_threshold)).float() + metrics["tis_seq_clipped_fraction"] = seq_clipped.mean() + else: + clipped_indicator = 1 - clip_mask_f + seq_has_clipped = masked_sum(clipped_indicator, eos_mask, dim=-1) > 0 + metrics["tis_seq_clipped_fraction"] = seq_has_clipped.float().mean() + weights = weights * clip_mask_f + else: + raise ValueError(f"Invalid tis mode: {mode}") + + weights = weights * veto_mask + weights = weights * eos_mask + weights = weights.detach() + + metrics.update( + { + "tis_threshold_upper": upper_threshold, + "tis_threshold_lower": lower_threshold, + "tis_level": level, + "tis_mode": mode, + "tis_veto_threshold": veto_threshold, + } + ) + + return weights, metrics + + diff --git a/slime_plugins/rollout_buffer/rollout_buffer_example.py 
b/slime_plugins/rollout_buffer/rollout_buffer_example.py index ad320bea66..a97d5b18aa 100644 --- a/slime_plugins/rollout_buffer/rollout_buffer_example.py +++ b/slime_plugins/rollout_buffer/rollout_buffer_example.py @@ -4,9 +4,9 @@ import aiohttp import requests -import wandb from transformers import AutoTokenizer +import wandb from slime.utils.async_utils import run from slime.utils.mask_utils import MultiTurnLossMaskGenerator from slime.utils.types import Sample From 35f414c404c3c3679fe1b961407789375dbc0660 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Mon, 6 Oct 2025 04:27:20 +0000 Subject: [PATCH 04/58] fix lint --- slime/backends/fsdp_utils/actor.py | 52 +++++++++---------- slime/backends/fsdp_utils/data_packing.py | 4 +- slime/backends/megatron_utils/data.py | 2 +- slime/backends/megatron_utils/model.py | 2 +- slime/ray/rollout.py | 2 +- slime/utils/distributed_utils.py | 2 +- .../rollout_buffer/rollout_buffer_example.py | 2 +- 7 files changed, 32 insertions(+), 34 deletions(-) diff --git a/slime/backends/fsdp_utils/actor.py b/slime/backends/fsdp_utils/actor.py index 69f8eb926b..31b8eb99f8 100644 --- a/slime/backends/fsdp_utils/actor.py +++ b/slime/backends/fsdp_utils/actor.py @@ -4,13 +4,10 @@ import ray import torch import torch.distributed as dist +from packaging import version +from torch.distributed.tensor import DTensor from torch_memory_saver import torch_memory_saver from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer -import logging - - -from torch.distributed.tensor import DTensor -from packaging import version # Import FSDP v2 components based on PyTorch version if version.parse(torch.__version__) >= version.parse("2.6"): @@ -222,7 +219,9 @@ def packed_data(self, rollout_data): rollout_data["response_lengths"][start:end], rollout_data["advantages"][start:end], rollout_data["returns"][start:end], - rollout_log_probs=rollout_data["rollout_log_probs"][start:end] if "rollout_log_probs" in rollout_data else None, 
+ rollout_log_probs=( + rollout_data["rollout_log_probs"][start:end] if "rollout_log_probs" in rollout_data else None + ), num_packs=mbs_size, ) ) @@ -319,7 +318,7 @@ def train(self, rollout_id, rollout_data_ref): ppo_kl = old_log_probs.to(device=log_probs.device) - log_probs advantages = advantages.to(device=ppo_kl.device) pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, self.args.eps_clip, self.args.eps_clip_high) - + # Apply TIS before sample mean calculation if self.args.use_tis: # Initialize TIS variables @@ -328,22 +327,24 @@ def train(self, rollout_id, rollout_data_ref): ois = None # Apply TIS off-policy correction using importance sampling assert all( - "rollout_log_probs" in batch and - isinstance(batch["rollout_log_probs"], torch.Tensor) and - batch["rollout_log_probs"].numel() > 0 + "rollout_log_probs" in batch + and isinstance(batch["rollout_log_probs"], torch.Tensor) + and batch["rollout_log_probs"].numel() > 0 for batch in unpacked_batches ), "rollout_log_probs must be provided as non-empty torch.Tensor for TIS" - + rollout_log_probs = torch.cat([batch["rollout_log_probs"] for batch in unpacked_batches], dim=0) rollout_log_probs = rollout_log_probs.to(device=log_probs.device) - + tis = torch.exp(old_log_probs - rollout_log_probs) ois = (-ppo_kl).exp() - tis_clip = torch.clamp(tis, min=getattr(self.args, 'tis_clip_low', 0.1), max=getattr(self.args, 'tis_clip', 2.0)) + tis_clip = torch.clamp( + tis, min=getattr(self.args, "tis_clip_low", 0.1), max=getattr(self.args, "tis_clip", 2.0) + ) tis_clipfrac = tis_clip != tis - + pg_loss = pg_loss * tis_clip - + pg_loss = sum_of_sample_mean(pg_loss, response_lengths, loss_masks) pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks) ppo_kl = sum_of_sample_mean(ppo_kl.abs(), response_lengths, loss_masks) @@ -375,11 +376,13 @@ def train(self, rollout_id, rollout_data_ref): if self.args.use_kl_loss: reported["kl_loss"] = kl_loss.detach() - + if self.args.use_tis and tis is not 
None: reported["tis"] = sum_of_sample_mean(tis, response_lengths, loss_masks).detach() reported["ois"] = sum_of_sample_mean(ois, response_lengths, loss_masks).detach() - reported["tis_clipfrac"] = sum_of_sample_mean(tis_clipfrac.float(), response_lengths, loss_masks).detach() + reported["tis_clipfrac"] = sum_of_sample_mean( + tis_clipfrac.float(), response_lengths, loss_masks + ).detach() # Scale loss for gradient accumulation loss = loss * dist.get_world_size() / self.args.global_batch_size @@ -443,16 +446,12 @@ def update_weights(self): # type: ignore[override] self.weight_updator.connect_rollout_engines(rollout_engines, rollout_engine_lock) dist.barrier(group=get_gloo_group()) - # For colocated mode with sharded updates (full_params=False), + # For colocated mode with sharded updates (full_params=False), # we don't need to wake up the entire model # The bucket-based approach will load parameters selectively from CPU storage # TODO: Add bucket optimization for from distributed mode - use_bucket_optimization = ( - self.args.colocate and - not getattr(self.weight_updator, 'full_params', False) - ) - - + use_bucket_optimization = self.args.colocate and not getattr(self.weight_updator, "full_params", False) + if self.args.offload and not use_bucket_optimization: # Wake up for distributed mode or full_params mode self.wake_up(("model")) @@ -467,13 +466,13 @@ def update_weights(self): # type: ignore[override] @torch.no_grad() def update_cpu_params_dict(self, params_dict): """Copy model parameters from GPU to CPU storage dictionary""" - + state_dict = self.model.state_dict() for name, param in state_dict.items(): if isinstance(param, DTensor): param = param.full_tensor() - + if name not in params_dict: params_dict[name] = torch.empty_like(param, device=torch.device("cpu"), pin_memory=True) params_dict[name].copy_(param.detach(), non_blocking=True) @@ -492,7 +491,6 @@ def load_ref_model(self, ref_load_path): if ref_load_path is None: raise ValueError("ref_load_path must 
be provided when loading reference model") - current_weights = {} self.update_cpu_params_dict(current_weights) diff --git a/slime/backends/fsdp_utils/data_packing.py b/slime/backends/fsdp_utils/data_packing.py index e040306f46..0583f97218 100644 --- a/slime/backends/fsdp_utils/data_packing.py +++ b/slime/backends/fsdp_utils/data_packing.py @@ -67,7 +67,7 @@ def pack_sequences( flat_advantages = [] flat_returns = [] flat_rollout_log_probs = [] - + for i in indices: seq_tokens = tokens[i] seq_mask = loss_masks[i] @@ -92,7 +92,7 @@ def pack_sequences( "response_lengths": [response_lengths[i] for i in indices], "advantages": torch.tensor(flat_advantages, dtype=torch.float32), "returns": torch.tensor(flat_returns, dtype=torch.float32), - "rollout_log_probs": torch.tensor(flat_rollout_log_probs, dtype=torch.float32) + "rollout_log_probs": torch.tensor(flat_rollout_log_probs, dtype=torch.float32), } ) diff --git a/slime/backends/megatron_utils/data.py b/slime/backends/megatron_utils/data.py index 82f6398528..325e887bdf 100644 --- a/slime/backends/megatron_utils/data.py +++ b/slime/backends/megatron_utils/data.py @@ -5,10 +5,10 @@ import torch import torch.distributed as dist import torch.nn.functional as F -import wandb from megatron.core import mpu from megatron.core.packed_seq_params import PackedSeqParams +import wandb from slime.utils.data import get_minimum_num_micro_batch_size from slime.utils.flops_utils import calculate_fwd_flops from slime.utils.seqlen_balancing import get_seqlen_balanced_partitions diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py index 5b9e31ab41..3b78290dda 100644 --- a/slime/backends/megatron_utils/model.py +++ b/slime/backends/megatron_utils/model.py @@ -6,7 +6,6 @@ from functools import partial import torch -import wandb from megatron.core import mpu from megatron.core.distributed import DistributedDataParallel as DDP from megatron.core.distributed import DistributedDataParallelConfig, 
finalize_model_grads @@ -19,6 +18,7 @@ from megatron.training.global_vars import get_args from megatron.training.training import get_model +import wandb from slime.utils.memory_utils import clear_memory from .checkpoint import load_checkpoint, save_checkpoint diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index bb3247896d..b21214e255 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -8,9 +8,9 @@ import ray import torch -import wandb from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +import wandb from slime.backends.sglang_utils.sglang_engine import SGLangEngine from slime.ray.rollout_data_source import RolloutDataSourceWithBuffer from slime.utils.http_utils import find_available_port, get_host_info, init_http_client diff --git a/slime/utils/distributed_utils.py b/slime/utils/distributed_utils.py index d4f34e9480..f358bd34ce 100644 --- a/slime/utils/distributed_utils.py +++ b/slime/utils/distributed_utils.py @@ -108,7 +108,7 @@ def distributed_masked_whiten( Args: values (torch.Tensor): The local tensor of values to whiten. mask (torch.Tensor): The local mask corresponding to the values. - process_group: The process group for all_reduce. + process_group: The process group for all_reduce. If None, uses the default world group. shift_mean (bool): If True, the output is zero-mean. Defaults to True. epsilon (float): A small value for numerical stability. 
diff --git a/slime_plugins/rollout_buffer/rollout_buffer_example.py b/slime_plugins/rollout_buffer/rollout_buffer_example.py index ad320bea66..a97d5b18aa 100644 --- a/slime_plugins/rollout_buffer/rollout_buffer_example.py +++ b/slime_plugins/rollout_buffer/rollout_buffer_example.py @@ -4,9 +4,9 @@ import aiohttp import requests -import wandb from transformers import AutoTokenizer +import wandb from slime.utils.async_utils import run from slime.utils.mask_utils import MultiTurnLossMaskGenerator from slime.utils.types import Sample From 19649f9d29c398d6c77298df20fcac547eca67f0 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Mon, 6 Oct 2025 04:27:50 +0000 Subject: [PATCH 05/58] fix lint of main --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 3a7547ce17..1a3a2a2cba 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ pillow pylatexenc pyyaml ray[default] +tensorboard torch transformers wandb -tensorboard From 81ffa4742e39d27bd70c0c10ea7d0c105a5d2a33 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Mon, 6 Oct 2025 04:31:25 +0000 Subject: [PATCH 06/58] adding pre-commit as a CI flow --- .github/workflows/pre-commit.yml | 39 ++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 .github/workflows/pre-commit.yml diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 0000000000..449d4766d7 --- /dev/null +++ b/.github/workflows/pre-commit.yml @@ -0,0 +1,39 @@ +name: pre-commit + +on: + pull_request: + types: [opened, synchronize, reopened, ready_for_review] + +permissions: + contents: read + +jobs: + run-pre-commit: + name: Run pre-commit + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + cache: 'pip' + + - name: Install pre-commit + 
run: pip install --upgrade pip pre-commit + + - name: Cache pre-commit environments + uses: actions/cache@v4 + with: + path: ~/.cache/pre-commit + key: pre-commit-${{ runner.os }}-${{ hashFiles('.pre-commit-config.yaml') }} + restore-keys: | + pre-commit-${{ runner.os }}- + + - name: Run pre-commit on all files + run: pre-commit run --all-files --show-diff-on-failure --color=always + From 243cc446efdcbd7a4b3fb2489836690e89548c47 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Mon, 6 Oct 2025 04:52:31 +0000 Subject: [PATCH 07/58] only pre-commit with yml --- .github/workflows/pre-commit.yaml | 21 --------------------- .github/workflows/pre-commit.yml | 2 ++ 2 files changed, 2 insertions(+), 21 deletions(-) delete mode 100644 .github/workflows/pre-commit.yaml diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml deleted file mode 100644 index 63b26d3e38..0000000000 --- a/.github/workflows/pre-commit.yaml +++ /dev/null @@ -1,21 +0,0 @@ -name: pre-commit - -on: - push: - branches: [main] - pull_request: - -jobs: - pre-commit: - runs-on: ubuntu-latest - steps: - - name: Checkout code - uses: actions/checkout@v3 - - name: Setup Python - uses: actions/setup-python@v3 - with: - python-version: '3.x' - - name: Run pre-commit hooks - uses: pre-commit/action@v3.0.1 - with: - extra_args: --all-files diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 449d4766d7..d0e05b27cb 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -1,6 +1,8 @@ name: pre-commit on: + push: + branches: [main] pull_request: types: [opened, synchronize, reopened, ready_for_review] From 0b164f3c44e88dd53eb233c0ed9cd9d69093c73a Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Mon, 6 Oct 2025 04:56:45 +0000 Subject: [PATCH 08/58] fix up pre commit --- .github/workflows/pre-commit.yml | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git 
a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index d0e05b27cb..7fbe2402b9 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -7,7 +7,7 @@ on: types: [opened, synchronize, reopened, ready_for_review] permissions: - contents: read + contents: write jobs: run-pre-commit: @@ -36,6 +36,26 @@ jobs: restore-keys: | pre-commit-${{ runner.os }}- - - name: Run pre-commit on all files + - name: Run pre-commit (autofix pass) + id: autofix + run: | + set -e + pre-commit run --all-files --show-diff-on-failure --color=always || true + if [[ -n "$(git status --porcelain)" ]]; then + echo "changes=true" >> $GITHUB_OUTPUT + else + echo "changes=false" >> $GITHUB_OUTPUT + fi + + - name: Commit and push fixes + if: steps.autofix.outputs.changes == 'true' && github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository + run: | + git config user.name "github-actions[bot]" + git config user.email "41898282+github-actions[bot]@users.noreply.github.com" + git add -A + git commit -m "chore(pre-commit): apply autofixes" + git push + + - name: Run pre-commit (verify pass) run: pre-commit run --all-files --show-diff-on-failure --color=always From 071717129804b26c17dc312fe96587395c32aac4 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Mon, 6 Oct 2025 05:05:14 +0000 Subject: [PATCH 09/58] unigy local pre-commit with third party --- .github/workflows/pre-commit.yml | 24 ++----------------- pyproject.toml | 3 +++ slime/backends/megatron_utils/data.py | 2 +- slime/backends/megatron_utils/model.py | 2 +- slime/ray/rollout.py | 2 +- slime_plugins/mbridge/glm4.py | 3 ++- .../rollout_buffer/rollout_buffer_example.py | 2 +- tools/convert_hf_to_torch_dist.py | 2 +- 8 files changed, 12 insertions(+), 28 deletions(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 7fbe2402b9..d0e05b27cb 100644 --- a/.github/workflows/pre-commit.yml +++ 
b/.github/workflows/pre-commit.yml @@ -7,7 +7,7 @@ on: types: [opened, synchronize, reopened, ready_for_review] permissions: - contents: write + contents: read jobs: run-pre-commit: @@ -36,26 +36,6 @@ jobs: restore-keys: | pre-commit-${{ runner.os }}- - - name: Run pre-commit (autofix pass) - id: autofix - run: | - set -e - pre-commit run --all-files --show-diff-on-failure --color=always || true - if [[ -n "$(git status --porcelain)" ]]; then - echo "changes=true" >> $GITHUB_OUTPUT - else - echo "changes=false" >> $GITHUB_OUTPUT - fi - - - name: Commit and push fixes - if: steps.autofix.outputs.changes == 'true' && github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository - run: | - git config user.name "github-actions[bot]" - git config user.email "41898282+github-actions[bot]@users.noreply.github.com" - git add -A - git commit -m "chore(pre-commit): apply autofixes" - git push - - - name: Run pre-commit (verify pass) + - name: Run pre-commit on all files run: pre-commit run --all-files --show-diff-on-failure --color=always diff --git a/pyproject.toml b/pyproject.toml index 3c2b7d3baf..6c998bb666 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,9 @@ py_version = 310 # python 3.10 as a target version sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"] default_section = "THIRDPARTY" extend_skip = ["setup.py", "docs/source/conf.py"] +known_first_party = ["slime", "slime_plugins"] +known_third_party = ["megatron", "wandb", "ray", "transformers"] +src_paths = ["slime", "slime_plugins"] [tool.black] diff --git a/slime/backends/megatron_utils/data.py b/slime/backends/megatron_utils/data.py index 325e887bdf..82f6398528 100644 --- a/slime/backends/megatron_utils/data.py +++ b/slime/backends/megatron_utils/data.py @@ -5,10 +5,10 @@ import torch import torch.distributed as dist import torch.nn.functional as F +import wandb from megatron.core import mpu from megatron.core.packed_seq_params 
import PackedSeqParams -import wandb from slime.utils.data import get_minimum_num_micro_batch_size from slime.utils.flops_utils import calculate_fwd_flops from slime.utils.seqlen_balancing import get_seqlen_balanced_partitions diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py index 3b78290dda..5b9e31ab41 100644 --- a/slime/backends/megatron_utils/model.py +++ b/slime/backends/megatron_utils/model.py @@ -6,6 +6,7 @@ from functools import partial import torch +import wandb from megatron.core import mpu from megatron.core.distributed import DistributedDataParallel as DDP from megatron.core.distributed import DistributedDataParallelConfig, finalize_model_grads @@ -18,7 +19,6 @@ from megatron.training.global_vars import get_args from megatron.training.training import get_model -import wandb from slime.utils.memory_utils import clear_memory from .checkpoint import load_checkpoint, save_checkpoint diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index b21214e255..bb3247896d 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -8,9 +8,9 @@ import ray import torch +import wandb from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy -import wandb from slime.backends.sglang_utils.sglang_engine import SGLangEngine from slime.ray.rollout_data_source import RolloutDataSourceWithBuffer from slime.utils.http_utils import find_available_port, get_host_info, init_http_client diff --git a/slime_plugins/mbridge/glm4.py b/slime_plugins/mbridge/glm4.py index acea62fe85..ef9e6ea703 100644 --- a/slime_plugins/mbridge/glm4.py +++ b/slime_plugins/mbridge/glm4.py @@ -1,6 +1,7 @@ -from mbridge.core import LLMBridge, register_model from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from mbridge.core import LLMBridge, register_model + @register_model("glm4") class GLM4Bridge(LLMBridge): diff --git a/slime_plugins/rollout_buffer/rollout_buffer_example.py 
b/slime_plugins/rollout_buffer/rollout_buffer_example.py index a97d5b18aa..ad320bea66 100644 --- a/slime_plugins/rollout_buffer/rollout_buffer_example.py +++ b/slime_plugins/rollout_buffer/rollout_buffer_example.py @@ -4,9 +4,9 @@ import aiohttp import requests +import wandb from transformers import AutoTokenizer -import wandb from slime.utils.async_utils import run from slime.utils.mask_utils import MultiTurnLossMaskGenerator from slime.utils.types import Sample diff --git a/tools/convert_hf_to_torch_dist.py b/tools/convert_hf_to_torch_dist.py index d9e8597d14..e56aff8c82 100644 --- a/tools/convert_hf_to_torch_dist.py +++ b/tools/convert_hf_to_torch_dist.py @@ -3,13 +3,13 @@ import torch import torch.distributed as dist -from mbridge import AutoBridge from megatron.core.enums import ModelType from megatron.training.arguments import parse_args, validate_args from megatron.training.checkpointing import get_checkpoint_name, get_checkpoint_tracker_filename, save_checkpoint from megatron.training.training import get_model import slime_plugins.mbridge # noqa: F401 +from mbridge import AutoBridge from slime.backends.megatron_utils import set_default_megatron_args from slime.backends.megatron_utils.initialize import init from slime.backends.megatron_utils.model_provider import get_model_provider_func From 0695c73d853275903a770b80bd5585dc5da01d95 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Mon, 6 Oct 2025 18:35:05 +0000 Subject: [PATCH 10/58] adding kl metircs --- slime/backends/fsdp_utils/actor.py | 31 ++++++++--------- slime/backends/megatron_utils/loss.py | 31 ++++++++--------- slime/utils/tis.py | 49 +++++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 34 deletions(-) diff --git a/slime/backends/fsdp_utils/actor.py b/slime/backends/fsdp_utils/actor.py index 4bd16b4a23..b02616af0b 100644 --- a/slime/backends/fsdp_utils/actor.py +++ b/slime/backends/fsdp_utils/actor.py @@ -18,13 +18,14 @@ raise ImportError("FSDP v2 not available") import wandb + from 
slime.ray.train_actor import TrainRayActor from slime.utils.data import get_minimum_num_micro_batch_size, process_rollout_data from slime.utils.distributed_utils import get_gloo_group from slime.utils.memory_utils import clear_memory from slime.utils.ppo_utils import compute_approx_kl, compute_policy_loss from slime.utils.timer import Timer, timer -from slime.utils.tis import compute_tis_weights +from slime.utils.tis import compute_kl_metrics, compute_tis_weights from slime.utils.wandb_utils import init_wandb_secondary from .data_packing import pack_sequences, unpack_sequences @@ -336,7 +337,6 @@ def train(self, rollout_id, rollout_data_ref): rollout_log_probs = torch.cat([batch["rollout_log_probs"] for batch in unpacked_batches], dim=0).to( device=log_probs.device ) - old_log_probs_flat = old_log_probs # Build eos mask from loss masks eos_mask = torch.cat(loss_masks, dim=0).to(device=log_probs.device) @@ -349,7 +349,7 @@ def train(self, rollout_id, rollout_data_ref): lower = getattr(self.args, "tis_clip_low", 0.0) tis_weights, tis_metrics = compute_tis_weights( - old_log_prob=old_log_probs_flat, + old_log_prob=old_log_probs, rollout_log_prob=rollout_log_probs, eos_mask=eos_mask, level=getattr(self.args, "tis_level", "token"), @@ -365,6 +365,14 @@ def train(self, rollout_id, rollout_data_ref): if tis_weights is not None: pg_loss = pg_loss * tis_weights + # KL metrics next to TIS metrics + kl_metrics = compute_kl_metrics( + old_log_prob=old_log_probs, + rollout_log_prob=rollout_log_probs, + eos_mask=eos_mask, + response_lengths=response_lengths, + ) + pg_loss = sum_of_sample_mean(pg_loss, response_lengths, loss_masks) pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks) ppo_kl = sum_of_sample_mean(ppo_kl.abs(), response_lengths, loss_masks) @@ -399,20 +407,9 @@ def train(self, rollout_id, rollout_data_ref): if self.args.use_tis and tis_weights is not None: reported["ois"] = sum_of_sample_mean(ois, response_lengths, loss_masks).detach() - # 
Extended metrics - for k in [ - "tis_mean", - "tis_std", - "tis_ratio_fraction_high", - "tis_ratio_fraction_low", - "tis_seq_clipped_fraction", - "tis_veto_fraction", - ]: - if k in tis_metrics: - val = tis_metrics[k] - reported[k] = ( - val.detach() if torch.is_tensor(val) else torch.tensor(val, device=log_probs.device) - ) + # Report all TIS and KL metrics uniformly + for k, v in {**tis_metrics, **kl_metrics}.items(): + reported[k] = v.detach() if torch.is_tensor(v) else torch.tensor(v, device=log_probs.device) # Scale loss for gradient accumulation loss = loss * dist.get_world_size() / self.args.global_batch_size diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 7deb87cdd3..f2d2c5dca8 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -14,7 +14,7 @@ get_reinforce_plus_plus_baseline_advantages, get_reinforce_plus_plus_returns, ) -from slime.utils.tis import compute_tis_weights +from slime.utils.tis import compute_kl_metrics, compute_tis_weights from .cp_utils import all_gather_with_cp, get_logits_and_tokens_offset_with_cp, get_sum_of_sample_mean @@ -309,7 +309,7 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): if args.use_tis: assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS" rollout_log_probs = torch.cat(batch["rollout_log_probs"], dim=0) - old_log_probs_flat = torch.cat(batch["log_probs"], dim=0) + old_log_probs = torch.cat(batch["log_probs"], dim=0) # Build eos mask from loss masks (concatenated) to match flattened tensors eos_mask = torch.cat(batch["loss_masks"], dim=0).to(device=log_probs.device) @@ -323,7 +323,7 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): ) tis_weights, tis_metrics = compute_tis_weights( - old_log_prob=old_log_probs_flat, + old_log_prob=old_log_probs, rollout_log_prob=rollout_log_probs, eos_mask=eos_mask, level=getattr(args, "tis_level", "token"), @@ -340,6 +340,14 @@ 
def policy_loss_function(args, batch, logits, sum_of_sample_mean): if tis_weights is not None: pg_loss = pg_loss * tis_weights + # KL metrics next to TIS metrics + kl_metrics = compute_kl_metrics( + old_log_prob=old_log_probs, + rollout_log_prob=rollout_log_probs, + eos_mask=eos_mask, + response_lengths=batch["response_lengths"], + ) + pg_loss = sum_of_sample_mean(pg_loss) pg_clipfrac = sum_of_sample_mean(pg_clipfrac) ppo_kl = sum_of_sample_mean(ppo_kl) @@ -381,20 +389,9 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): if args.use_tis: # Backward compatible basic logs reported_loss["ois"] = sum_of_sample_mean(ois).clone().detach() - # Extended metrics from generalized TIS - for k in [ - "tis_mean", - "tis_std", - "tis_ratio_fraction_high", - "tis_ratio_fraction_low", - "tis_seq_clipped_fraction", - "tis_veto_fraction", - ]: - if k in tis_metrics: - val = tis_metrics[k] - reported_loss[k] = ( - val.clone().detach() if torch.is_tensor(val) else torch.tensor(val, device=logits.device) - ) + # Report all TIS and KL metrics uniformly + for k, v in {**tis_metrics, **kl_metrics}.items(): + reported_loss[k] = v.clone().detach() if torch.is_tensor(v) else torch.tensor(v, device=logits.device) return loss, reported_loss diff --git a/slime/utils/tis.py b/slime/utils/tis.py index 38c9c9d29f..62d641c043 100644 --- a/slime/utils/tis.py +++ b/slime/utils/tis.py @@ -107,6 +107,55 @@ def compute_is_metrics( return metrics +def compute_kl_metrics( + *, + old_log_prob: torch.Tensor, + rollout_log_prob: torch.Tensor, + eos_mask: Optional[torch.Tensor], + response_lengths: Optional[list[int]] = None, +) -> Dict[str, Any]: + metrics: Dict[str, Any] = {} + + device = old_log_prob.device + if eos_mask is None: + eos_mask = torch.ones_like(old_log_prob, dtype=torch.bool, device=device) + + # Direct estimator for KL(pi_rollout || pi_old): E[log pi_rollout - log pi_old] + metrics["rollout_kl"] = masked_mean(rollout_log_prob - old_log_prob, eos_mask) + + # K3 estimator: 
E[exp(log(pi_old/pi_rollout)) - log(pi_old/pi_rollout) - 1] + log_ratio = old_log_prob - rollout_log_prob + k3_matrix = torch.exp(log_ratio) - log_ratio - 1 + metrics["rollout_k3_kl"] = masked_mean(k3_matrix, eos_mask) + + # Sequence-level perplexity difference metrics + if old_log_prob.dim() == 2: + mean_log_prob_rollout_per_seq = masked_mean(rollout_log_prob, eos_mask, dim=-1) + mean_log_prob_old_per_seq = masked_mean(old_log_prob, eos_mask, dim=-1) + elif response_lengths is not None and len(response_lengths) > 0 and old_log_prob.dim() == 1: + seq_rollout_means = [] + seq_old_means = [] + start = 0 + for length in response_lengths: + end = start + int(length) + mask_chunk = eos_mask[start:end] if eos_mask is not None else None + seq_rollout_means.append(masked_mean(rollout_log_prob[start:end], mask_chunk)) + seq_old_means.append(masked_mean(old_log_prob[start:end], mask_chunk)) + start = end + mean_log_prob_rollout_per_seq = torch.stack(seq_rollout_means) + mean_log_prob_old_per_seq = torch.stack(seq_old_means) + else: + # Fallback to global means if sequence boundaries are unavailable + mean_log_prob_rollout_per_seq = masked_mean(rollout_log_prob, eos_mask).unsqueeze(0) + mean_log_prob_old_per_seq = masked_mean(old_log_prob, eos_mask).unsqueeze(0) + + diff = mean_log_prob_rollout_per_seq - mean_log_prob_old_per_seq + metrics["log_ppl_diff"] = diff.mean() + metrics["log_ppl_abs_diff"] = diff.abs().mean() + + return metrics + + def compute_tis_weights( *, old_log_prob: torch.Tensor, From cac9fa923f1e69c4fed4e555223375b62cebc0c0 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Mon, 6 Oct 2025 22:21:26 +0000 Subject: [PATCH 11/58] fix type custing for metrics --- scripts/run-qwen3-30B-A3B.sh | 16 ++++++++-------- slime/backends/fsdp_utils/actor.py | 16 ++++++++-------- slime/backends/megatron_utils/loss.py | 18 +++++++++--------- slime/utils/arguments.py | 24 ++++++------------------ 4 files changed, 31 insertions(+), 43 deletions(-) diff --git 
a/scripts/run-qwen3-30B-A3B.sh b/scripts/run-qwen3-30B-A3B.sh index 3c442c3ebf..fd3e3c51e3 100644 --- a/scripts/run-qwen3-30B-A3B.sh +++ b/scripts/run-qwen3-30B-A3B.sh @@ -44,17 +44,17 @@ ROLLOUT_ARGS=( --rollout-shuffle --rm-type deepscaler --num-rollout 3000 - --rollout-batch-size 32 - --n-samples-per-prompt 8 + --rollout-batch-size 8 + --n-samples-per-prompt 4 --rollout-max-response-len 8192 --rollout-temperature 0.8 - --global-batch-size 256 + --global-batch-size 32 --balance-data ) EVAL_ARGS=( - --eval-interval 20 + # --eval-interval 20 --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 @@ -103,10 +103,10 @@ OPTIMIZER_ARGS=( ) WANDB_ARGS=( - #--use-wandb - # --wandb-project slime-dev - # --wandb-group qwen3-30B-A3B-test - # --wandb-key ${WANDB_KEY} + --use-wandb + --wandb-project slime-dev + --wandb-group qwen3-30B-A3B-TIS + --wandb-key ${WANDB_KEY} ) SGLANG_ARGS=( diff --git a/slime/backends/fsdp_utils/actor.py b/slime/backends/fsdp_utils/actor.py index b02616af0b..36ff2bf401 100644 --- a/slime/backends/fsdp_utils/actor.py +++ b/slime/backends/fsdp_utils/actor.py @@ -341,12 +341,8 @@ def train(self, rollout_id, rollout_data_ref): # Build eos mask from loss masks eos_mask = torch.cat(loss_masks, dim=0).to(device=log_probs.device) - upper = getattr(self.args, "tis_threshold", None) - if upper is None: - upper = getattr(self.args, "tis_clip", 2.0) - lower = getattr(self.args, "tis_threshold_lower", None) - if lower is None: - lower = getattr(self.args, "tis_clip_low", 0.0) + upper = self.args.tis_threshold_upper + lower = self.args.tis_threshold_lower tis_weights, tis_metrics = compute_tis_weights( old_log_prob=old_log_probs, @@ -407,9 +403,13 @@ def train(self, rollout_id, rollout_data_ref): if self.args.use_tis and tis_weights is not None: reported["ois"] = sum_of_sample_mean(ois, response_lengths, loss_masks).detach() - # Report all TIS and KL metrics uniformly + # Report all TIS and KL metrics 
uniformly, filtering out non-numeric values for k, v in {**tis_metrics, **kl_metrics}.items(): - reported[k] = v.detach() if torch.is_tensor(v) else torch.tensor(v, device=log_probs.device) + if torch.is_tensor(v): + reported[k] = v.detach() + elif isinstance(v, (int, float)): + reported[k] = torch.tensor(v, device=log_probs.device) + # Skip string and other non-numeric types # Scale loss for gradient accumulation loss = loss * dist.get_world_size() / self.args.global_batch_size diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index f2d2c5dca8..8fe0422371 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -314,13 +314,9 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): # Build eos mask from loss masks (concatenated) to match flattened tensors eos_mask = torch.cat(batch["loss_masks"], dim=0).to(device=log_probs.device) - # Backward compatible thresholds - upper = args.tis_threshold if getattr(args, "tis_threshold", None) is not None else args.tis_clip - lower = ( - args.tis_threshold_lower - if getattr(args, "tis_threshold_lower", None) is not None - else getattr(args, "tis_clip_low", 0.0) - ) + # Use the new threshold parameters + upper = args.tis_threshold_upper + lower = args.tis_threshold_lower tis_weights, tis_metrics = compute_tis_weights( old_log_prob=old_log_probs, @@ -389,9 +385,13 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): if args.use_tis: # Backward compatible basic logs reported_loss["ois"] = sum_of_sample_mean(ois).clone().detach() - # Report all TIS and KL metrics uniformly + # Report all TIS and KL metrics uniformly, filtering out non-numeric values for k, v in {**tis_metrics, **kl_metrics}.items(): - reported_loss[k] = v.clone().detach() if torch.is_tensor(v) else torch.tensor(v, device=logits.device) + if torch.is_tensor(v): + reported_loss[k] = v.clone().detach() + elif isinstance(v, (int, float)): + 
reported_loss[k] = torch.tensor(v, device=logits.device) + # Skip string and other non-numeric types return loss, reported_loss diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 3a35c515a8..97ba3f8610 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -641,18 +641,6 @@ def add_algo_arguments(parser): default=False, help="Enable TIS from https://fengyao.notion.site/off-policy-rl for off-policy importance sampling.", ) - parser.add_argument( - "--tis-clip", - type=float, - default=2.0, - help="Clipping threshold C for importance sampling ratios to control variance.", - ) - parser.add_argument( - "--tis-clip-low", - type=float, - default=0, - help="Lower bound clipping threshold C for importance sampling ratios to control variance.", - ) # Extended TIS controls (levels/modes/thresholds) with backward compatibility parser.add_argument( @@ -676,18 +664,18 @@ def add_algo_arguments(parser): ), ) parser.add_argument( - "--tis-threshold", + "--tis-threshold-upper", type=float, - default=None, - help=("Upper threshold for IS weights. If not set, falls back to --tis-clip (default 2.0)."), + default=2.0, + help=("Upper threshold for IS weights. Default is 2.0."), ) parser.add_argument( "--tis-threshold-lower", type=float, - default=None, + default=0.0, help=( - "Lower threshold for IS weights. If not set: for clip mode uses reciprocal of upper; " - "for truncate mode remains unused. If --tis-clip-low provided, that will be used when applicable." + "Lower threshold for IS weights. Default is 0.0. " + "For clip mode uses this value; for truncate mode remains unused." 
), ) parser.add_argument( From e7cf0c26f79f771515521997f664150735979c84 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Tue, 7 Oct 2025 03:31:39 +0000 Subject: [PATCH 12/58] comments to compute_tis_weights --- scripts/run-qwen3-30B-A3B.sh | 6 +- slime/utils/tis.py | 234 ++++++++++++++++++++--------------- 2 files changed, 136 insertions(+), 104 deletions(-) diff --git a/scripts/run-qwen3-30B-A3B.sh b/scripts/run-qwen3-30B-A3B.sh index fd3e3c51e3..d9dabecece 100644 --- a/scripts/run-qwen3-30B-A3B.sh +++ b/scripts/run-qwen3-30B-A3B.sh @@ -31,9 +31,9 @@ CKPT_ARGS=( --hf-checkpoint /root/Qwen3-30B-A3B #--hf-checkpoint /root/Qwen3-30B-A3B-FP8 --ref-load /root/Qwen3-30B-A3B_torch_dist - --load /root/Qwen3-30B-A3B_slime/ - --save /root/Qwen3-30B-A3B_slime/ - --save-interval 20 + # --load /root/Qwen3-30B-A3B_slime/ + # --save /root/Qwen3-30B-A3B_slime/ + # --save-interval 20 ) ROLLOUT_ARGS=( diff --git a/slime/utils/tis.py b/slime/utils/tis.py index 62d641c043..85d2b7c1ab 100644 --- a/slime/utils/tis.py +++ b/slime/utils/tis.py @@ -16,6 +16,135 @@ def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Ten return masked_sum(x, mask, dim=dim) / denom +def compute_tis_weights( + *, + old_log_prob: torch.Tensor, + rollout_log_prob: torch.Tensor, + eos_mask: Optional[torch.Tensor], + level: str = "token", + mode: str = "truncate", + upper_threshold: Optional[float] = None, + lower_threshold: Optional[float] = None, + veto_threshold: float = 1e-4, + safety_bound: float = 20.0, +) -> Tuple[Optional[torch.Tensor], Dict[str, Any]]: + """ + Compute the truncated importance sampling (TIS) weights and metrics. + + Adapted from: + + https://fengyao.notion.site/off-policy-rl#279721e3f6c48092bbe2fcfe0e9c6b33 + https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda + + Args: + old_log_prob: The log probs from the policy model in the training backend. 
[batch_size, seq_len] + rollout_log_prob: The log probs from the policy model in the rollout backend. [batch_size, seq_len] + eos_mask: The mask of valid tokens. [batch_size, seq_len] + level: The aggregation level for the importance sampling weights. + - "token": per-token importance sampling weights, biased low variance. + - "sequence": product over tokens, unbiased but high variance. + - "geometric": geometric mean over tokens, biased, medium variance. + mode: how to handle the importance sampling weights exceeding the thresholds. + - "truncate": cap the importance sampling weights at the upper threshold, i.e., truncated importance sampling. + - "clip": zero the importance sampling weights outside the [lower, upper] range. + upper_threshold: The upper threshold for the importance sampling weights. + lower_threshold: The lower threshold for the importance sampling weights, only used in "clip" mode. + If not provided, it will be set to 1.0 / upper_threshold. + veto_threshold: If any token's importance sampling weight is less than this, zero the entire sequence weight. + safety_bound: The safety bound for the log-space ratio to avoid numerical overflow. + + Returns: + weights: The importance sampling weights. [batch_size, seq_len] + metrics: The metrics for the importance sampling weights. + """ + if upper_threshold is None: + return None, {} + if lower_threshold is None: + lower_threshold = 1.0 / upper_threshold + + device = old_log_prob.device + log_ratio = old_log_prob - rollout_log_prob + + log_upper_threshold = torch.log(torch.tensor(upper_threshold, device=device)) + log_lower_threshold = torch.log(torch.tensor(lower_threshold, device=device)) + + if level == "token": + # Token-level IS: π_training(a|s) / π_rollout(a|s) per token + # The truncation will be applied later. 
+ log_ratio_for_metrics = log_ratio # [batch_size, seq_len] + log_ratio_safe = torch.clamp(log_ratio, min=-safety_bound, max=safety_bound) + weights = torch.exp(log_ratio_safe) + elif level == "sequence": + # Sequence-level IS: π_training(a|s) / π_rollout(a|s) across the entire sequence + log_ratio_sum = masked_sum(log_ratio, eos_mask, dim=-1).unsqueeze(-1) + log_ratio_for_metrics = log_ratio_sum # [batch_size, 1] + log_ratio_sum_safe = torch.clamp(log_ratio_sum, min=-safety_bound, max=safety_bound) + weights = torch.exp(log_ratio_sum_safe).expand_as(old_log_prob) + elif level == "geometric": + log_ratio_mean = masked_mean(log_ratio, eos_mask, dim=-1).unsqueeze(-1) + log_ratio_for_metrics = log_ratio_mean # [batch_size, 1] + log_ratio_mean_safe = torch.clamp(log_ratio_mean, min=-safety_bound, max=safety_bound) + weights = torch.exp(log_ratio_mean_safe).expand_as(old_log_prob) + else: + raise ValueError(f"Invalid importance sampling level: {level}") + + log_veto_threshold = torch.log(torch.tensor(veto_threshold, device=device)) + # Veto sequences with any token's log ratio below the threshold. 
+ # log(π_training / π_rollout) < log(veto_threshold) ⟺ π_training / π_rollout < veto_threshold + catastrophic_tokens = (log_ratio < log_veto_threshold) & eos_mask.bool() + has_catastrophic = catastrophic_tokens.any(dim=-1, keepdim=True) # [batch_size, 1] + veto_mask = (~has_catastrophic).float() # [batch_size, 1] + + metrics = compute_is_metrics( + is_weights=weights, + log_ratio_for_metrics=log_ratio_for_metrics, + eos_mask=eos_mask, + level=level, + upper_threshold=upper_threshold, + lower_threshold=lower_threshold, + log_upper_threshold=log_upper_threshold, + log_lower_threshold=log_lower_threshold, + has_catastrophic=has_catastrophic, + catastrophic_tokens=catastrophic_tokens, + safety_bound=safety_bound, + ) + + if mode == "truncate": + weights = weights.clamp(max=upper_threshold) + elif mode == "clip": + clip_mask = (weights >= lower_threshold) & (weights <= upper_threshold) + clip_mask_f = clip_mask.float() + metrics["tis_clipped_fraction"] = masked_mean(1 - clip_mask_f, eos_mask) + if level in ["sequence", "geometric"]: + seq_w = weights[:, 0] if weights.dim() > 1 else weights + seq_clipped = ((seq_w < lower_threshold) | (seq_w > upper_threshold)).float() + metrics["tis_seq_clipped_fraction"] = seq_clipped.mean() + else: + clipped_indicator = 1 - clip_mask_f + seq_has_clipped = masked_sum(clipped_indicator, eos_mask, dim=-1) > 0 + metrics["tis_seq_clipped_fraction"] = seq_has_clipped.float().mean() + weights = weights * clip_mask_f + else: + raise ValueError(f"Invalid tis mode: {mode}") + + weights = weights * veto_mask + weights = weights * eos_mask + weights = weights.detach() + + metrics.update( + { + "tis_threshold_upper": upper_threshold, + "tis_threshold_lower": lower_threshold, + "tis_level": level, + "tis_mode": mode, + "tis_veto_threshold": veto_threshold, + } + ) + + return weights, metrics + + + def compute_is_metrics( is_weights: torch.Tensor, log_ratio_for_metrics: torch.Tensor, @@ -24,8 +153,8 @@ def compute_is_metrics( level: str, 
upper_threshold: float, lower_threshold: float, - log_threshold_upper: torch.Tensor, - log_threshold_lower: torch.Tensor, + log_upper_threshold: torch.Tensor, + log_lower_threshold: torch.Tensor, has_catastrophic: Optional[torch.Tensor], catastrophic_tokens: Optional[torch.Tensor], safety_bound: float, @@ -48,8 +177,8 @@ def compute_is_metrics( metrics["tis_max"] = torch.exp(torch.clamp(log_max, max=safety_bound)) metrics["tis_min"] = torch.exp(log_min) metrics["tis_mean"] = masked_mean(is_weights, eos_mask) - exceeds_upper = log_ratio_for_metrics > log_threshold_upper - below_lower = log_ratio_for_metrics < log_threshold_lower + exceeds_upper = log_ratio_for_metrics > log_upper_threshold + below_lower = log_ratio_for_metrics < log_lower_threshold if level == "sequence": metrics["tis_ratio_fraction_high"] = exceeds_upper.float().mean() metrics["tis_ratio_fraction_low"] = below_lower.float().mean() @@ -155,100 +284,3 @@ def compute_kl_metrics( return metrics - -def compute_tis_weights( - *, - old_log_prob: torch.Tensor, - rollout_log_prob: torch.Tensor, - eos_mask: Optional[torch.Tensor], - level: str = "token", - mode: str = "truncate", - upper_threshold: Optional[float] = None, - lower_threshold: Optional[float] = None, - veto_threshold: float = 1e-4, - safety_bound: float = 20.0, -) -> Tuple[Optional[torch.Tensor], Dict[str, Any]]: - if upper_threshold is None: - return None, {} - - device = old_log_prob.device - if eos_mask is None: - eos_mask = torch.ones_like(old_log_prob, dtype=torch.bool, device=device) - - if lower_threshold is None: - lower_threshold = 1.0 / upper_threshold - - log_ratio = old_log_prob - rollout_log_prob - - log_threshold_upper = torch.log(torch.tensor(upper_threshold, device=device)) - log_threshold_lower = torch.log(torch.tensor(lower_threshold, device=device)) - - if level == "token": - log_ratio_for_metrics = log_ratio - log_ratio_safe = torch.clamp(log_ratio, min=-safety_bound, max=safety_bound) - weights = torch.exp(log_ratio_safe) - 
elif level == "sequence": - log_ratio_sum = masked_sum(log_ratio, eos_mask, dim=-1).unsqueeze(-1) - log_ratio_for_metrics = log_ratio_sum - log_ratio_sum_safe = torch.clamp(log_ratio_sum, min=-safety_bound, max=safety_bound) - weights = torch.exp(log_ratio_sum_safe).expand_as(old_log_prob) - elif level == "geometric": - log_ratio_mean = masked_mean(log_ratio, eos_mask, dim=-1).unsqueeze(-1) - log_ratio_for_metrics = log_ratio_mean - log_ratio_mean_safe = torch.clamp(log_ratio_mean, min=-safety_bound, max=safety_bound) - weights = torch.exp(log_ratio_mean_safe).expand_as(old_log_prob) - else: - raise ValueError(f"Invalid tis level: {level}") - - log_veto_threshold = torch.log(torch.tensor(veto_threshold, device=device)) - catastrophic_tokens = (log_ratio < log_veto_threshold) & eos_mask.bool() - has_catastrophic = catastrophic_tokens.any(dim=-1, keepdim=True) - veto_mask = (~has_catastrophic).float() - - metrics = compute_is_metrics( - is_weights=weights, - log_ratio_for_metrics=log_ratio_for_metrics, - eos_mask=eos_mask, - level=level, - upper_threshold=upper_threshold, - lower_threshold=lower_threshold, - log_threshold_upper=log_threshold_upper, - log_threshold_lower=log_threshold_lower, - has_catastrophic=has_catastrophic, - catastrophic_tokens=catastrophic_tokens, - safety_bound=safety_bound, - ) - - if mode == "truncate": - weights = weights.clamp(max=upper_threshold) - elif mode == "clip": - clip_mask = (weights >= lower_threshold) & (weights <= upper_threshold) - clip_mask_f = clip_mask.float() - metrics["tis_clipped_fraction"] = masked_mean(1 - clip_mask_f, eos_mask) - if level in ["sequence", "geometric"]: - seq_w = weights[:, 0] if weights.dim() > 1 else weights - seq_clipped = ((seq_w < lower_threshold) | (seq_w > upper_threshold)).float() - metrics["tis_seq_clipped_fraction"] = seq_clipped.mean() - else: - clipped_indicator = 1 - clip_mask_f - seq_has_clipped = masked_sum(clipped_indicator, eos_mask, dim=-1) > 0 - metrics["tis_seq_clipped_fraction"] = 
seq_has_clipped.float().mean() - weights = weights * clip_mask_f - else: - raise ValueError(f"Invalid tis mode: {mode}") - - weights = weights * veto_mask - weights = weights * eos_mask - weights = weights.detach() - - metrics.update( - { - "tis_threshold_upper": upper_threshold, - "tis_threshold_lower": lower_threshold, - "tis_level": level, - "tis_mode": mode, - "tis_veto_threshold": veto_threshold, - } - ) - - return weights, metrics From 2924975adfc070fa41b96fd0f30c348507f3f1cf Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Tue, 7 Oct 2025 03:32:10 +0000 Subject: [PATCH 13/58] [lint] tis comment --- slime/utils/tis.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/slime/utils/tis.py b/slime/utils/tis.py index 85d2b7c1ab..d3aa2c09f8 100644 --- a/slime/utils/tis.py +++ b/slime/utils/tis.py @@ -71,18 +71,18 @@ def compute_tis_weights( if level == "token": # Token-level IS: π_training(a|s) / π_rollout(a|s) per token # The truncation will be applied later. 
- log_ratio_for_metrics = log_ratio # [batch_size, seq_len] + log_ratio_for_metrics = log_ratio # [batch_size, seq_len] log_ratio_safe = torch.clamp(log_ratio, min=-safety_bound, max=safety_bound) weights = torch.exp(log_ratio_safe) elif level == "sequence": # Sequence-level IS: π_training(a|s) / π_rollout(a|s) across the entire sequence log_ratio_sum = masked_sum(log_ratio, eos_mask, dim=-1).unsqueeze(-1) - log_ratio_for_metrics = log_ratio_sum # [batch_size, 1] + log_ratio_for_metrics = log_ratio_sum # [batch_size, 1] log_ratio_sum_safe = torch.clamp(log_ratio_sum, min=-safety_bound, max=safety_bound) weights = torch.exp(log_ratio_sum_safe).expand_as(old_log_prob) elif level == "geometric": log_ratio_mean = masked_mean(log_ratio, eos_mask, dim=-1).unsqueeze(-1) - log_ratio_for_metrics = log_ratio_mean # [batch_size, 1] + log_ratio_for_metrics = log_ratio_mean # [batch_size, 1] log_ratio_mean_safe = torch.clamp(log_ratio_mean, min=-safety_bound, max=safety_bound) weights = torch.exp(log_ratio_mean_safe).expand_as(old_log_prob) else: @@ -92,8 +92,8 @@ def compute_tis_weights( # Veto sequences with any token's log ratio below the threshold. 
# log(π_training / π_rollout) < log(veto_threshold) ⟺ π_training / π_rollout < veto_threshold catastrophic_tokens = (log_ratio < log_veto_threshold) & eos_mask.bool() - has_catastrophic = catastrophic_tokens.any(dim=-1, keepdim=True) # [batch_size, 1] - veto_mask = (~has_catastrophic).float() # [batch_size, 1] + has_catastrophic = catastrophic_tokens.any(dim=-1, keepdim=True) # [batch_size, 1] + veto_mask = (~has_catastrophic).float() # [batch_size, 1] metrics = compute_is_metrics( is_weights=weights, @@ -144,7 +144,6 @@ def compute_tis_weights( return weights, metrics - def compute_is_metrics( is_weights: torch.Tensor, log_ratio_for_metrics: torch.Tensor, @@ -283,4 +282,3 @@ def compute_kl_metrics( metrics["log_ppl_abs_diff"] = diff.abs().mean() return metrics - From 7963809c1b24e397251cd3c07aa41052a27b8433 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Tue, 7 Oct 2025 04:35:09 +0000 Subject: [PATCH 14/58] refactor clip mode in sequence level --- scripts/run-qwen3-30B-A3B.sh | 1 + slime/utils/tis.py | 22 ++++++++++++---------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/scripts/run-qwen3-30B-A3B.sh b/scripts/run-qwen3-30B-A3B.sh index d9dabecece..f31c84042e 100644 --- a/scripts/run-qwen3-30B-A3B.sh +++ b/scripts/run-qwen3-30B-A3B.sh @@ -87,6 +87,7 @@ GRPO_ARGS=( --eps-clip 0.2 --eps-clip-high 0.28 --use-tis + --tis-level sequence ) OPTIMIZER_ARGS=( diff --git a/slime/utils/tis.py b/slime/utils/tis.py index d3aa2c09f8..49cd5c0121 100644 --- a/slime/utils/tis.py +++ b/slime/utils/tis.py @@ -110,20 +110,22 @@ def compute_tis_weights( ) if mode == "truncate": + # only truncate the weights at the upper threshold weights = weights.clamp(max=upper_threshold) elif mode == "clip": - clip_mask = (weights >= lower_threshold) & (weights <= upper_threshold) - clip_mask_f = clip_mask.float() - metrics["tis_clipped_fraction"] = masked_mean(1 - clip_mask_f, eos_mask) + # zero the weights outside the [lower, upper] range if level in ["sequence", 
"geometric"]: - seq_w = weights[:, 0] if weights.dim() > 1 else weights - seq_clipped = ((seq_w < lower_threshold) | (seq_w > upper_threshold)).float() - metrics["tis_seq_clipped_fraction"] = seq_clipped.mean() + seq_weights = weights[:, 0] if weights.dim() > 1 else weights + sequence_clipped = ((seq_weights < lower_threshold) | (seq_weights > upper_threshold)).float() + metrics["tis_sequence_clipped_fraction"] = sequence_clipped.mean() else: - clipped_indicator = 1 - clip_mask_f - seq_has_clipped = masked_sum(clipped_indicator, eos_mask, dim=-1) > 0 - metrics["tis_seq_clipped_fraction"] = seq_has_clipped.float().mean() - weights = weights * clip_mask_f + clip_mask = (weights >= lower_threshold) & (weights <= upper_threshold) + clip_mask = clip_mask.float() + clipped_indicator = 1 - clip_mask + metrics["tis_token_clipped_fraction"] = masked_mean(clipped_indicator, eos_mask) + sequence_has_clipped = masked_sum(clipped_indicator, eos_mask, dim=-1) > 0 + metrics["tis_sequence_clipped_fraction"] = sequence_has_clipped.float().mean() + weights = weights * clip_mask else: raise ValueError(f"Invalid tis mode: {mode}") From 6f36eef8ed803f74038a0de1ea42787daee86086 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Tue, 7 Oct 2025 06:37:18 +0000 Subject: [PATCH 15/58] [test] geometric level --- scripts/run-qwen3-30B-A3B.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/run-qwen3-30B-A3B.sh b/scripts/run-qwen3-30B-A3B.sh index f31c84042e..14a6cddb17 100644 --- a/scripts/run-qwen3-30B-A3B.sh +++ b/scripts/run-qwen3-30B-A3B.sh @@ -87,7 +87,7 @@ GRPO_ARGS=( --eps-clip 0.2 --eps-clip-high 0.28 --use-tis - --tis-level sequence + --tis-level geometric ) OPTIMIZER_ARGS=( @@ -107,6 +107,7 @@ WANDB_ARGS=( --use-wandb --wandb-project slime-dev --wandb-group qwen3-30B-A3B-TIS + --wandb-run-id qwen3-30B-A3B-TIS-sequence --wandb-key ${WANDB_KEY} ) From 3cc498225edebf1ab1ecc654fac34d8e9af2fa56 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Tue, 7 Oct 
2025 20:04:15 +0000 Subject: [PATCH 16/58] adding metrics to new tis --- slime/utils/tis.py | 169 +++++++++++++++++++++++---------------------- 1 file changed, 87 insertions(+), 82 deletions(-) diff --git a/slime/utils/tis.py b/slime/utils/tis.py index 49cd5c0121..252472de06 100644 --- a/slime/utils/tis.py +++ b/slime/utils/tis.py @@ -4,23 +4,37 @@ def masked_sum(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor: - if mask is None: - return x.sum(dim=dim) + """ + Computes the sum of the tensor x, masked by the mask. + + x = [[1, 2, 3], [4, 5, 6]] + mask = [[1, 1, 1], [1, 1, 0]] + masked_sum(x, mask, dim=-1) = [6, 9] + """ + valid_tokens = mask.sum(dim=dim) + assert valid_tokens.min() > 0, "any sequence must have at least one valid token" + assert x.shape == mask.shape, "x and mask must have the same shape" return (x * mask).sum(dim=dim) def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor: - if mask is None: - return x.mean(dim=dim) - denom = mask.sum(dim=dim).clamp_min(1) - return masked_sum(x, mask, dim=dim) / denom + """ + Computes the mean of the tensor x, masked by the mask. + + x = [[1, 2, 3], [4, 5, 6]] + mask = [[1, 1, 1], [1, 1, 0]] + masked_mean(x, mask, dim=-1) = [2, 4.5] + """ + valid_tokens = mask.sum(dim=dim) + assert valid_tokens.min() > 0, "any sequence must have at least one valid token" + return masked_sum(x, mask, dim=dim) / valid_tokens def compute_tis_weights( *, old_log_prob: torch.Tensor, rollout_log_prob: torch.Tensor, - eos_mask: Optional[torch.Tensor], + eos_mask: torch.Tensor, level: str = "token", mode: str = "truncate", upper_threshold: Optional[float] = None, @@ -57,6 +71,10 @@ def compute_tis_weights( weights: The importance sampling weights. [batch_size, seq_len] metrics: The metrics for the importance sampling weights. 
""" + assert ( + eos_mask.shape == old_log_prob.shape and eos_mask.shape == rollout_log_prob.shape + ), "eos_mask, old_log_prob, and rollout_log_prob must have the same shape" + if upper_threshold is None: return None, {} if lower_threshold is None: @@ -95,8 +113,8 @@ def compute_tis_weights( has_catastrophic = catastrophic_tokens.any(dim=-1, keepdim=True) # [batch_size, 1] veto_mask = (~has_catastrophic).float() # [batch_size, 1] - metrics = compute_is_metrics( - is_weights=weights, + metrics = compute_tis_metrics( + tis_weights=weights, log_ratio_for_metrics=log_ratio_for_metrics, eos_mask=eos_mask, level=level, @@ -133,106 +151,93 @@ def compute_tis_weights( weights = weights * eos_mask weights = weights.detach() - metrics.update( - { - "tis_threshold_upper": upper_threshold, - "tis_threshold_lower": lower_threshold, - "tis_level": level, - "tis_mode": mode, - "tis_veto_threshold": veto_threshold, - } - ) - + metrics["tis_threshold_upper"] = upper_threshold + metrics["tis_threshold_lower"] = lower_threshold + metrics["tis_level"] = level + metrics["tis_mode"] = mode + metrics["tis_veto_threshold"] = veto_threshold return weights, metrics -def compute_is_metrics( - is_weights: torch.Tensor, - log_ratio_for_metrics: torch.Tensor, - eos_mask: Optional[torch.Tensor], +def compute_tis_metrics( *, + tis_weights: torch.Tensor, + log_ratio_for_metrics: torch.Tensor, + eos_mask: torch.Tensor, level: str, upper_threshold: float, lower_threshold: float, log_upper_threshold: torch.Tensor, log_lower_threshold: torch.Tensor, - has_catastrophic: Optional[torch.Tensor], - catastrophic_tokens: Optional[torch.Tensor], + has_catastrophic: torch.Tensor, + catastrophic_tokens: torch.Tensor, safety_bound: float, ) -> Dict[str, Any]: + """ + Computes metrics that reflect the TRUE distribution (before clamping) + for the truncated importance sampling (TIS) weights. 
+ """ metrics: Dict[str, Any] = {} - if eos_mask is None: - eos_mask = torch.ones_like(is_weights, dtype=torch.bool) - - device = is_weights.device + assert eos_mask.shape == tis_weights.shape, "eos_mask and tis_weights must have the same shape" - if has_catastrophic is not None: - metrics["tis_veto_fraction"] = has_catastrophic.float().mean() - if catastrophic_tokens is not None and eos_mask is not None: - metrics["tis_catastrophic_token_fraction"] = masked_mean(catastrophic_tokens.float(), eos_mask) + metrics["tis_veto_fraction"] = has_catastrophic.float().mean() + metrics["tis_catastrophic_token_fraction"] = masked_mean(catastrophic_tokens.float(), eos_mask) + metrics["tis_level"] = level + metrics["tis_upper_threshold"] = upper_threshold + metrics["tis_lower_threshold"] = lower_threshold + metrics["tis_log_upper_threshold"] = log_upper_threshold + metrics["tis_log_lower_threshold"] = log_lower_threshold + metrics["tis_safety_bound"] = safety_bound if level in ["sequence", "geometric"]: log_max = log_ratio_for_metrics.max() log_min = log_ratio_for_metrics.min() metrics["tis_max"] = torch.exp(torch.clamp(log_max, max=safety_bound)) metrics["tis_min"] = torch.exp(log_min) - metrics["tis_mean"] = masked_mean(is_weights, eos_mask) + metrics["tis_mean"] = masked_mean(tis_weights, eos_mask) exceeds_upper = log_ratio_for_metrics > log_upper_threshold below_lower = log_ratio_for_metrics < log_lower_threshold if level == "sequence": - metrics["tis_ratio_fraction_high"] = exceeds_upper.float().mean() - metrics["tis_ratio_fraction_low"] = below_lower.float().mean() + metrics["tis_ratio_fraction_exceeds_upper"] = exceeds_upper.float().mean() + metrics["tis_ratio_fraction_below_lower"] = below_lower.float().mean() else: exceeds_upper_exp = exceeds_upper.expand_as(eos_mask) below_lower_exp = below_lower.expand_as(eos_mask) - metrics["tis_ratio_fraction_high"] = masked_mean(exceeds_upper_exp.float(), eos_mask) - metrics["tis_ratio_fraction_low"] = 
masked_mean(below_lower_exp.float(), eos_mask) - else: - metrics["tis_mean"] = masked_mean(is_weights, eos_mask) - above = is_weights > upper_threshold - below = is_weights < lower_threshold - metrics["tis_ratio_fraction_high"] = masked_mean(above.float(), eos_mask) - metrics["tis_ratio_fraction_low"] = masked_mean(below.float(), eos_mask) - if eos_mask.any(): - mask_bool = eos_mask.bool() - metrics["tis_max"] = is_weights.masked_fill(~mask_bool, float("-inf")).max() - metrics["tis_min"] = is_weights.masked_fill(~mask_bool, float("inf")).min() - else: - metrics["tis_max"] = torch.tensor(0.0, device=device) - metrics["tis_min"] = torch.tensor(0.0, device=device) - - if eos_mask.any(): - weights_for_std = is_weights.clamp(min=lower_threshold, max=upper_threshold) - var = masked_mean(weights_for_std.square(), eos_mask) - metrics["tis_mean"].square() - metrics["tis_std"] = torch.sqrt(torch.clamp(var, min=0.0)) - weights_for_ess = weights_for_std / (metrics["tis_mean"] + 1e-8) - metrics["tis_eff_sample_size"] = 1.0 / masked_mean(weights_for_ess.square(), eos_mask) + metrics["tis_ratio_fraction_exceeds_upper"] = masked_mean(exceeds_upper_exp.float(), eos_mask) + metrics["tis_ratio_fraction_below_lower"] = masked_mean(below_lower_exp.float(), eos_mask) else: - metrics["tis_std"] = torch.tensor(0.0, device=device) - metrics["tis_eff_sample_size"] = torch.tensor(1.0, device=device) - - if is_weights.dim() > 1 and eos_mask.any(): - seq_mean = masked_mean(is_weights, eos_mask, dim=-1) - metrics["tis_seq_mean"] = seq_mean.mean() - metrics["tis_seq_std"] = ( - seq_mean.std() if seq_mean.numel() > 1 else torch.tensor(0.0, device=is_weights.device) - ) - metrics["tis_seq_max"] = seq_mean.max() - metrics["tis_seq_min"] = seq_mean.min() - seq_dev = (seq_mean - 1.0).abs() - metrics["tis_seq_max_deviation"] = seq_dev.max() - metrics["tis_seq_fraction_high"] = (seq_mean > upper_threshold).float().mean() - metrics["tis_seq_fraction_low"] = (seq_mean < 1.0 / 
upper_threshold).float().mean() - - if eos_mask.any(): - flat = is_weights[eos_mask.bool()] - if flat.numel() > 0: - metrics["tis_p25"] = torch.quantile(flat, 0.25) - metrics["tis_p50"] = torch.quantile(flat, 0.50) - metrics["tis_p75"] = torch.quantile(flat, 0.75) - metrics["tis_p95"] = torch.quantile(flat, 0.95) - metrics["tis_p99"] = torch.quantile(flat, 0.99) + metrics["tis_mean"] = masked_mean(tis_weights, eos_mask) + exceeds_upper = tis_weights > upper_threshold + below_lower = tis_weights < lower_threshold + metrics["tis_ratio_fraction_exceeds_upper"] = masked_mean(exceeds_upper.float(), eos_mask) + metrics["tis_ratio_fraction_below_lower"] = masked_mean(below_lower.float(), eos_mask) + valid = eos_mask.bool() + metrics["tis_max"] = tis_weights[valid].max() + metrics["tis_min"] = tis_weights[valid].min() + + weights_for_std = tis_weights.clamp(min=lower_threshold, max=upper_threshold) + var = masked_mean(weights_for_std.square(), eos_mask) - metrics["tis_mean"].square() + metrics["tis_std"] = torch.sqrt(torch.clamp(var, min=0.0)) + weights_for_ess = weights_for_std / (metrics["tis_mean"] + 1e-8) + metrics["tis_eff_sample_size"] = 1.0 / masked_mean(weights_for_ess.square(), eos_mask) + + seq_mean = masked_mean(tis_weights, eos_mask, dim=-1) + metrics["tis_seq_mean"] = seq_mean.mean() + metrics["tis_seq_std"] = seq_mean.std() + metrics["tis_seq_max"] = seq_mean.max() + metrics["tis_seq_min"] = seq_mean.min() + seq_dev = (seq_mean - 1.0).abs() + metrics["tis_seq_max_deviation"] = seq_dev.max() + metrics["tis_seq_fraction_exceeds_upper"] = (seq_mean > upper_threshold).float().mean() + metrics["tis_seq_fraction_below_lower"] = (seq_mean < lower_threshold).float().mean() + + flat = tis_weights[eos_mask.bool()] + metrics["tis_p25"] = torch.quantile(flat, 0.25) + metrics["tis_p50"] = torch.quantile(flat, 0.50) + metrics["tis_p75"] = torch.quantile(flat, 0.75) + metrics["tis_p95"] = torch.quantile(flat, 0.95) + metrics["tis_p99"] = torch.quantile(flat, 0.99) return 
metrics From d60a5958a2a6705d3ceeb1bfbd33e6ff43410bbc Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Thu, 9 Oct 2025 02:12:22 +0000 Subject: [PATCH 17/58] [log probs in 1D] --- slime/backends/fsdp_utils/actor.py | 1 + slime/backends/megatron_utils/loss.py | 32 +++- slime/utils/tis.py | 232 +++++++++++++++----------- 3 files changed, 168 insertions(+), 97 deletions(-) diff --git a/slime/backends/fsdp_utils/actor.py b/slime/backends/fsdp_utils/actor.py index 36ff2bf401..742d705951 100644 --- a/slime/backends/fsdp_utils/actor.py +++ b/slime/backends/fsdp_utils/actor.py @@ -354,6 +354,7 @@ def train(self, rollout_id, rollout_data_ref): lower_threshold=lower, veto_threshold=getattr(self.args, "tis_veto_threshold", 1e-4), safety_bound=getattr(self.args, "tis_safety_bound", 20.0), + response_lengths=response_lengths, ) ois = (-ppo_kl).exp() diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 8fe0422371..a7f1310cb9 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -16,7 +16,12 @@ ) from slime.utils.tis import compute_kl_metrics, compute_tis_weights -from .cp_utils import all_gather_with_cp, get_logits_and_tokens_offset_with_cp, get_sum_of_sample_mean +from .cp_utils import ( + all_gather_with_cp, + get_logits_and_tokens_offset_with_cp, + get_sum_of_sample_mean, + slice_log_prob_with_cp, +) def get_responses( @@ -311,13 +316,33 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): rollout_log_probs = torch.cat(batch["rollout_log_probs"], dim=0) old_log_probs = torch.cat(batch["log_probs"], dim=0) - # Build eos mask from loss masks (concatenated) to match flattened tensors - eos_mask = torch.cat(batch["loss_masks"], dim=0).to(device=log_probs.device) + # Build eos mask aligned with local (possibly CP-chunked) flattened tensors + cp_size = mpu.get_context_parallel_world_size() + if cp_size == 1: + eos_mask = torch.cat(batch["loss_masks"], 
dim=0).to(device=log_probs.device) + else: + # if CP enabled, use slice_log_prob_with_cp to slice loss_mask + mask_chunks = [ + slice_log_prob_with_cp(loss_mask, total_len, resp_len) + for loss_mask, total_len, resp_len in zip( + batch["loss_masks"], batch["total_lengths"], batch["response_lengths"] + ) + ] + eos_mask = torch.cat(mask_chunks).to(device=log_probs.device) + + # Ensure shape alignment with log-probs tensors for TIS + assert ( + eos_mask.shape == old_log_probs.shape + ), f"eos_mask {eos_mask.shape} vs old_log_probs {old_log_probs.shape}" + assert ( + eos_mask.shape == rollout_log_probs.shape + ), f"eos_mask {eos_mask.shape} vs rollout_log_probs {rollout_log_probs.shape}" # Use the new threshold parameters upper = args.tis_threshold_upper lower = args.tis_threshold_lower + assert upper == 2.0 tis_weights, tis_metrics = compute_tis_weights( old_log_prob=old_log_probs, rollout_log_prob=rollout_log_probs, @@ -328,6 +353,7 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): lower_threshold=lower, veto_threshold=getattr(args, "tis_veto_threshold", 1e-4), safety_bound=getattr(args, "tis_safety_bound", 20.0), + response_lengths=batch["response_lengths"], ) # On-policy ratio for monitoring (π_new/π_old) diff --git a/slime/utils/tis.py b/slime/utils/tis.py index 252472de06..869745fdc5 100644 --- a/slime/utils/tis.py +++ b/slime/utils/tis.py @@ -30,6 +30,24 @@ def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Ten return masked_sum(x, mask, dim=dim) / valid_tokens +def per_seq_masked_mean( + x: torch.Tensor, + mask: torch.Tensor, + response_lengths: Optional[list[int]] = None, +) -> torch.Tensor: + """ + 计算按样本的 masked mean 后再求和,返回一个可加性的标量(适配 DP 汇总)。 + 支持二维 [B, T] 与拍平后一维、并提供 response_lengths 的两种输入形态。 + """ + if response_lengths is not None and len(response_lengths) > 0: + xs = torch.split(x, [int(l) for l in response_lengths], dim=0) + ms = torch.split(mask, [int(l) for l in response_lengths], dim=0) + seq_means = 
[masked_mean(xi, mi) for xi, mi in zip(xs, ms)] + return torch.stack(seq_means).sum() + # fallback:视为单一样本 + return masked_mean(x, mask).unsqueeze(0).sum() + + def compute_tis_weights( *, old_log_prob: torch.Tensor, @@ -41,6 +59,7 @@ def compute_tis_weights( lower_threshold: Optional[float] = None, veto_threshold: float = 1e-4, safety_bound: float = 20.0, + response_lengths: Optional[list[int]] = None, ) -> Tuple[Optional[torch.Tensor], Dict[str, Any]]: """ Compute the truncated importance sampling (TIS) weights and metrics. @@ -51,9 +70,9 @@ def compute_tis_weights( https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda Args: - old_log_prob: The log probs from the policy model in the training backend. [batch_size, seq_len] - rollout_log_prob: The log probs from the policy model in the rollout backend. [batch_size, seq_len] - eos_mask: The mask of valid tokens. [batch_size, seq_len] + old_log_prob: Flattened log probs from training backend. Shape: [sum(response_lengths)] + rollout_log_prob: Flattened log probs from rollout backend. Shape: [sum(response_lengths)] + eos_mask: Flattened mask aligned with flattened tensors. Shape: [sum(response_lengths)] level: The aggregation level for the importance sampling weights. - "token": per-token importance sampling weights, biased low variance. - "sequence": product over tokens, unbiased but high variance. @@ -74,6 +93,7 @@ def compute_tis_weights( assert ( eos_mask.shape == old_log_prob.shape and eos_mask.shape == rollout_log_prob.shape ), "eos_mask, old_log_prob, and rollout_log_prob must have the same shape" + assert response_lengths is not None and len(response_lengths) > 0, "response_lengths must be provided" if upper_threshold is None: return None, {} @@ -89,20 +109,29 @@ def compute_tis_weights( if level == "token": # Token-level IS: π_training(a|s) / π_rollout(a|s) per token # The truncation will be applied later. 
- log_ratio_for_metrics = log_ratio # [batch_size, seq_len] + log_ratio_for_metrics = log_ratio # [sum(response_lengths)] log_ratio_safe = torch.clamp(log_ratio, min=-safety_bound, max=safety_bound) weights = torch.exp(log_ratio_safe) - elif level == "sequence": - # Sequence-level IS: π_training(a|s) / π_rollout(a|s) across the entire sequence - log_ratio_sum = masked_sum(log_ratio, eos_mask, dim=-1).unsqueeze(-1) - log_ratio_for_metrics = log_ratio_sum # [batch_size, 1] - log_ratio_sum_safe = torch.clamp(log_ratio_sum, min=-safety_bound, max=safety_bound) - weights = torch.exp(log_ratio_sum_safe).expand_as(old_log_prob) - elif level == "geometric": - log_ratio_mean = masked_mean(log_ratio, eos_mask, dim=-1).unsqueeze(-1) - log_ratio_for_metrics = log_ratio_mean # [batch_size, 1] - log_ratio_mean_safe = torch.clamp(log_ratio_mean, min=-safety_bound, max=safety_bound) - weights = torch.exp(log_ratio_mean_safe).expand_as(old_log_prob) + elif level in ["sequence", "geometric"]: + # Sequence-level/geometric: compute per-sequence aggregate in log-space, then expand to tokens + xs = torch.split(log_ratio, [int(l) for l in response_lengths], dim=0) + ms = torch.split(eos_mask, [int(l) for l in response_lengths], dim=0) + per_seq_vals = [] + for xi, mi in zip(xs, ms): + if level == "sequence": + val = (xi * mi).sum() + else: # geometric + val = masked_mean(xi, mi) + per_seq_vals.append(torch.clamp(val, min=-safety_bound, max=safety_bound)) + per_seq_vals = torch.stack(per_seq_vals) # [num_sequences] + per_seq_weights = torch.exp(per_seq_vals) + # Expand to per-token weights per sequence + expanded = [] + for w, xi in zip(per_seq_weights, xs): + expanded.append(torch.ones_like(xi) * w) + weights = torch.cat(expanded, dim=0) + # For metrics that need the aggregated log-ratio, keep per-seq values + log_ratio_for_metrics = per_seq_vals else: raise ValueError(f"Invalid importance sampling level: {level}") @@ -110,8 +139,16 @@ def compute_tis_weights( # Veto sequences with any 
token's log ratio below the threshold. # log(π_training / π_rollout) < log(veto_threshold) ⟺ π_training / π_rollout < veto_threshold catastrophic_tokens = (log_ratio < log_veto_threshold) & eos_mask.bool() - has_catastrophic = catastrophic_tokens.any(dim=-1, keepdim=True) # [batch_size, 1] - veto_mask = (~has_catastrophic).float() # [batch_size, 1] + # Build per-sequence veto and expand to tokens + cat_chunks = torch.split(catastrophic_tokens, [int(l) for l in response_lengths], dim=0) + has_catastrophic_per_seq = torch.tensor([chunk.any() for chunk in cat_chunks], device=device) + veto_mask = torch.cat( + [ + torch.zeros_like(chunk, dtype=torch.float32) if has_cat else torch.ones_like(chunk, dtype=torch.float32) + for has_cat, chunk in zip(has_catastrophic_per_seq, cat_chunks) + ], + dim=0, + ) metrics = compute_tis_metrics( tis_weights=weights, @@ -122,9 +159,10 @@ def compute_tis_weights( lower_threshold=lower_threshold, log_upper_threshold=log_upper_threshold, log_lower_threshold=log_lower_threshold, - has_catastrophic=has_catastrophic, + has_catastrophic=has_catastrophic_per_seq, catastrophic_tokens=catastrophic_tokens, safety_bound=safety_bound, + response_lengths=response_lengths, ) if mode == "truncate": @@ -172,6 +210,7 @@ def compute_tis_metrics( has_catastrophic: torch.Tensor, catastrophic_tokens: torch.Tensor, safety_bound: float, + response_lengths: Optional[list[int]] = None, ) -> Dict[str, Any]: """ Computes metrics that reflect the TRUE distribution (before clamping) @@ -181,63 +220,67 @@ def compute_tis_metrics( assert eos_mask.shape == tis_weights.shape, "eos_mask and tis_weights must have the same shape" - metrics["tis_veto_fraction"] = has_catastrophic.float().mean() - metrics["tis_catastrophic_token_fraction"] = masked_mean(catastrophic_tokens.float(), eos_mask) + # Counts/fractions reported as sum over sequences; external reducer divides by num_samples + metrics["tis_veto_fraction"] = has_catastrophic.float().sum() + 
metrics["tis_catastrophic_token_fraction"] = per_seq_masked_mean( + catastrophic_tokens.float(), eos_mask, response_lengths=response_lengths + ) metrics["tis_level"] = level - metrics["tis_upper_threshold"] = upper_threshold - metrics["tis_lower_threshold"] = lower_threshold - metrics["tis_log_upper_threshold"] = log_upper_threshold - metrics["tis_log_lower_threshold"] = log_lower_threshold - metrics["tis_safety_bound"] = safety_bound + assert upper_threshold == 2.0 + # Make numeric constants DP-safe by scaling with number of sequences in this batch + if tis_weights.dim() == 2: + num_sequences = tis_weights.size(0) + elif response_lengths is not None and len(response_lengths) > 0: + num_sequences = len(response_lengths) + else: + num_sequences = 1 + metrics["tis_upper_threshold"] = torch.tensor(2.0 * num_sequences, device=tis_weights.device) + metrics["tis_lower_threshold"] = torch.tensor(lower_threshold * num_sequences, device=tis_weights.device) + metrics["tis_log_upper_threshold"] = log_upper_threshold * num_sequences + metrics["tis_log_lower_threshold"] = log_lower_threshold * num_sequences + metrics["tis_safety_bound"] = torch.tensor(safety_bound * num_sequences, device=tis_weights.device) if level in ["sequence", "geometric"]: - log_max = log_ratio_for_metrics.max() - log_min = log_ratio_for_metrics.min() - metrics["tis_max"] = torch.exp(torch.clamp(log_max, max=safety_bound)) - metrics["tis_min"] = torch.exp(log_min) - metrics["tis_mean"] = masked_mean(tis_weights, eos_mask) - exceeds_upper = log_ratio_for_metrics > log_upper_threshold - below_lower = log_ratio_for_metrics < log_lower_threshold - if level == "sequence": - metrics["tis_ratio_fraction_exceeds_upper"] = exceeds_upper.float().mean() - metrics["tis_ratio_fraction_below_lower"] = below_lower.float().mean() - else: - exceeds_upper_exp = exceeds_upper.expand_as(eos_mask) - below_lower_exp = below_lower.expand_as(eos_mask) - metrics["tis_ratio_fraction_exceeds_upper"] = 
masked_mean(exceeds_upper_exp.float(), eos_mask) - metrics["tis_ratio_fraction_below_lower"] = masked_mean(below_lower_exp.float(), eos_mask) + # log_ratio_for_metrics is per-seq aggregated log-ratio: compare per-seq + exceeds_upper = (log_ratio_for_metrics > log_upper_threshold).float().sum() + below_lower = (log_ratio_for_metrics < log_lower_threshold).float().sum() + metrics["tis_ratio_fraction_exceeds_upper"] = exceeds_upper + metrics["tis_ratio_fraction_below_lower"] = below_lower + metrics["tis_mean"] = per_seq_masked_mean(tis_weights, eos_mask, response_lengths=response_lengths) else: - metrics["tis_mean"] = masked_mean(tis_weights, eos_mask) - exceeds_upper = tis_weights > upper_threshold - below_lower = tis_weights < lower_threshold - metrics["tis_ratio_fraction_exceeds_upper"] = masked_mean(exceeds_upper.float(), eos_mask) - metrics["tis_ratio_fraction_below_lower"] = masked_mean(below_lower.float(), eos_mask) - valid = eos_mask.bool() - metrics["tis_max"] = tis_weights[valid].max() - metrics["tis_min"] = tis_weights[valid].min() - + metrics["tis_mean"] = per_seq_masked_mean(tis_weights, eos_mask, response_lengths=response_lengths) + exceeds_upper = (tis_weights > upper_threshold).float() + below_lower = (tis_weights < lower_threshold).float() + metrics["tis_ratio_fraction_exceeds_upper"] = per_seq_masked_mean( + exceeds_upper, eos_mask, response_lengths=response_lengths + ) + metrics["tis_ratio_fraction_below_lower"] = per_seq_masked_mean( + below_lower, eos_mask, response_lengths=response_lengths + ) + + # Per-sequence std and ESS, reported as sum across sequences weights_for_std = tis_weights.clamp(min=lower_threshold, max=upper_threshold) - var = masked_mean(weights_for_std.square(), eos_mask) - metrics["tis_mean"].square() - metrics["tis_std"] = torch.sqrt(torch.clamp(var, min=0.0)) - weights_for_ess = weights_for_std / (metrics["tis_mean"] + 1e-8) - metrics["tis_eff_sample_size"] = 1.0 / masked_mean(weights_for_ess.square(), eos_mask) - - seq_mean = 
masked_mean(tis_weights, eos_mask, dim=-1) - metrics["tis_seq_mean"] = seq_mean.mean() - metrics["tis_seq_std"] = seq_mean.std() - metrics["tis_seq_max"] = seq_mean.max() - metrics["tis_seq_min"] = seq_mean.min() - seq_dev = (seq_mean - 1.0).abs() - metrics["tis_seq_max_deviation"] = seq_dev.max() - metrics["tis_seq_fraction_exceeds_upper"] = (seq_mean > upper_threshold).float().mean() - metrics["tis_seq_fraction_below_lower"] = (seq_mean < lower_threshold).float().mean() - - flat = tis_weights[eos_mask.bool()] - metrics["tis_p25"] = torch.quantile(flat, 0.25) - metrics["tis_p50"] = torch.quantile(flat, 0.50) - metrics["tis_p75"] = torch.quantile(flat, 0.75) - metrics["tis_p95"] = torch.quantile(flat, 0.95) - metrics["tis_p99"] = torch.quantile(flat, 0.99) + xs = torch.split(tis_weights, [int(l) for l in response_lengths], dim=0) + ms = torch.split(eos_mask, [int(l) for l in response_lengths], dim=0) + per_seq_mean = torch.stack([masked_mean(xi, mi) for xi, mi in zip(xs, ms)]) + per_seq_var = ( + torch.stack( + [masked_mean(xi.clamp(min=lower_threshold, max=upper_threshold).square(), mi) for xi, mi in zip(xs, ms)] + ) + - per_seq_mean.square() + ) + per_seq_std = torch.sqrt(torch.clamp(per_seq_var, min=0.0)) + metrics["tis_std"] = per_seq_std.sum() + # ESS per sequence using normalized weights + weights_for_ess_list = [xi / (pm + 1e-8) for xi, pm in zip(xs, per_seq_mean)] + per_seq_ess = torch.stack([1.0 / masked_mean(xi.square(), mi) for xi, mi in zip(weights_for_ess_list, ms)]) + metrics["tis_eff_sample_size"] = per_seq_ess.sum() + seq_mean = per_seq_mean + + # Sequence-level summaries (sum-style for DP-safe reduction) + metrics["tis_seq_mean"] = seq_mean.sum() + metrics["tis_seq_fraction_exceeds_upper"] = (seq_mean > upper_threshold).float().sum() + metrics["tis_seq_fraction_below_lower"] = (seq_mean < lower_threshold).float().sum() return metrics @@ -255,37 +298,38 @@ def compute_kl_metrics( if eos_mask is None: eos_mask = torch.ones_like(old_log_prob, 
dtype=torch.bool, device=device) - # Direct estimator for KL(pi_rollout || pi_old): E[log pi_rollout - log pi_old] - metrics["rollout_kl"] = masked_mean(rollout_log_prob - old_log_prob, eos_mask) + # Direct estimator for KL(pi_rollout || pi_old): per-seq mean then sum (1D inputs only) + assert response_lengths is not None and eos_mask is not None + xs = torch.split(rollout_log_prob - old_log_prob, [int(l) for l in response_lengths], dim=0) + ms = torch.split(eos_mask, [int(l) for l in response_lengths], dim=0) + per_seq = [masked_mean(xi, mi) for xi, mi in zip(xs, ms)] + metrics["rollout_kl"] = torch.stack(per_seq).sum() # K3 estimator: E[exp(log(pi_old/pi_rollout)) - log(pi_old/pi_rollout) - 1] log_ratio = old_log_prob - rollout_log_prob k3_matrix = torch.exp(log_ratio) - log_ratio - 1 - metrics["rollout_k3_kl"] = masked_mean(k3_matrix, eos_mask) + xs = torch.split(k3_matrix, [int(l) for l in response_lengths], dim=0) + ms = torch.split(eos_mask, [int(l) for l in response_lengths], dim=0) + per_seq = [masked_mean(xi, mi) for xi, mi in zip(xs, ms)] + metrics["rollout_k3_kl"] = torch.stack(per_seq).sum() # Sequence-level perplexity difference metrics - if old_log_prob.dim() == 2: - mean_log_prob_rollout_per_seq = masked_mean(rollout_log_prob, eos_mask, dim=-1) - mean_log_prob_old_per_seq = masked_mean(old_log_prob, eos_mask, dim=-1) - elif response_lengths is not None and len(response_lengths) > 0 and old_log_prob.dim() == 1: - seq_rollout_means = [] - seq_old_means = [] - start = 0 - for length in response_lengths: - end = start + int(length) - mask_chunk = eos_mask[start:end] if eos_mask is not None else None - seq_rollout_means.append(masked_mean(rollout_log_prob[start:end], mask_chunk)) - seq_old_means.append(masked_mean(old_log_prob[start:end], mask_chunk)) - start = end - mean_log_prob_rollout_per_seq = torch.stack(seq_rollout_means) - mean_log_prob_old_per_seq = torch.stack(seq_old_means) - else: - # Fallback to global means if sequence boundaries are 
unavailable - mean_log_prob_rollout_per_seq = masked_mean(rollout_log_prob, eos_mask).unsqueeze(0) - mean_log_prob_old_per_seq = masked_mean(old_log_prob, eos_mask).unsqueeze(0) + assert response_lengths is not None and len(response_lengths) > 0 + seq_rollout_means = [] + seq_old_means = [] + start = 0 + for length in response_lengths: + end = start + int(length) + mask_chunk = eos_mask[start:end] + seq_rollout_means.append(masked_mean(rollout_log_prob[start:end], mask_chunk)) + seq_old_means.append(masked_mean(old_log_prob[start:end], mask_chunk)) + start = end + mean_log_prob_rollout_per_seq = torch.stack(seq_rollout_means) + mean_log_prob_old_per_seq = torch.stack(seq_old_means) diff = mean_log_prob_rollout_per_seq - mean_log_prob_old_per_seq - metrics["log_ppl_diff"] = diff.mean() - metrics["log_ppl_abs_diff"] = diff.abs().mean() + # report sums; external reducer divides by num_samples + metrics["log_ppl_diff"] = diff.sum() + metrics["log_ppl_abs_diff"] = diff.abs().sum() return metrics From 5cac6e0e5333c0d904fe9290e7fd492d53422551 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Fri, 10 Oct 2025 00:53:15 +0000 Subject: [PATCH 18/58] stash with main --- slime/backends/fsdp_utils/actor.py | 6 +- slime/backends/megatron_utils/loss.py | 93 ++++++++++++--------- slime/utils/tis.py | 112 ++++++++++++++++---------- 3 files changed, 127 insertions(+), 84 deletions(-) diff --git a/slime/backends/fsdp_utils/actor.py b/slime/backends/fsdp_utils/actor.py index 742d705951..e314f0bbfc 100644 --- a/slime/backends/fsdp_utils/actor.py +++ b/slime/backends/fsdp_utils/actor.py @@ -339,7 +339,7 @@ def train(self, rollout_id, rollout_data_ref): ) # Build eos mask from loss masks - eos_mask = torch.cat(loss_masks, dim=0).to(device=log_probs.device) + loss_mask = torch.cat(loss_masks, dim=0).to(device=log_probs.device) upper = self.args.tis_threshold_upper lower = self.args.tis_threshold_lower @@ -347,7 +347,7 @@ def train(self, rollout_id, rollout_data_ref): tis_weights, 
tis_metrics = compute_tis_weights( old_log_prob=old_log_probs, rollout_log_prob=rollout_log_probs, - eos_mask=eos_mask, + loss_mask=loss_mask, level=getattr(self.args, "tis_level", "token"), mode=getattr(self.args, "tis_mode", "truncate"), upper_threshold=upper, @@ -366,7 +366,7 @@ def train(self, rollout_id, rollout_data_ref): kl_metrics = compute_kl_metrics( old_log_prob=old_log_probs, rollout_log_prob=rollout_log_probs, - eos_mask=eos_mask, + loss_mask=loss_mask, response_lengths=response_lengths, ) diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index a7f1310cb9..48aaf28a76 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -1,3 +1,4 @@ +import re from typing import Union import torch @@ -313,61 +314,77 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): # Apply TIS off-policy correction using importance sampling if enabled if args.use_tis: assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS" - rollout_log_probs = torch.cat(batch["rollout_log_probs"], dim=0) - old_log_probs = torch.cat(batch["log_probs"], dim=0) - - # Build eos mask aligned with local (possibly CP-chunked) flattened tensors cp_size = mpu.get_context_parallel_world_size() - if cp_size == 1: - eos_mask = torch.cat(batch["loss_masks"], dim=0).to(device=log_probs.device) - else: - # if CP enabled, use slice_log_prob_with_cp to slice loss_mask - mask_chunks = [ - slice_log_prob_with_cp(loss_mask, total_len, resp_len) - for loss_mask, total_len, resp_len in zip( - batch["loss_masks"], batch["total_lengths"], batch["response_lengths"] - ) - ] - eos_mask = torch.cat(mask_chunks).to(device=log_probs.device) - - # Ensure shape alignment with log-probs tensors for TIS - assert ( - eos_mask.shape == old_log_probs.shape - ), f"eos_mask {eos_mask.shape} vs old_log_probs {old_log_probs.shape}" - assert ( - eos_mask.shape == rollout_log_probs.shape - ), f"eos_mask 
{eos_mask.shape} vs rollout_log_probs {rollout_log_probs.shape}" - - # Use the new threshold parameters upper = args.tis_threshold_upper lower = args.tis_threshold_lower - assert upper == 2.0 - tis_weights, tis_metrics = compute_tis_weights( - old_log_prob=old_log_probs, - rollout_log_prob=rollout_log_probs, - eos_mask=eos_mask, + + total_lengths = batch["total_lengths"] + response_lengths = batch["response_lengths"] + + # 1) 组装全序列 old/rollout/mask(CP=1 直接拼接;CP>1 用 all_gather 重建) + if cp_size == 1: + full_old_list = batch["log_probs"] + full_rollout_list = batch["rollout_log_probs"] + full_mask_list = batch["loss_masks"] + else: + full_old_list = [ + all_gather_with_cp(lp, total_len, resp_len) + for lp, total_len, resp_len in zip(batch["log_probs"], total_lengths, response_lengths) + ] + full_rollout_list = [ + all_gather_with_cp(lp, total_len, resp_len) + for lp, total_len, resp_len in zip(batch["rollout_log_probs"], total_lengths, response_lengths) + ] + # loss_masks 已是每样本全序列 + full_mask_list = batch["loss_masks"] + + old_full_flat = torch.cat(full_old_list, dim=0) + rollout_full_flat = torch.cat(full_rollout_list, dim=0) + mask_full_flat = torch.cat(full_mask_list, dim=0).to(device=log_probs.device) + + # 2) 基本一致性与格式校验 + assert old_full_flat.shape == rollout_full_flat.shape == mask_full_flat.shape + loss_mask_str = "".join([str(int(x)) for x in mask_full_flat]) + pattern = r"^1+(0+1+)*0*1*$" + assert re.fullmatch(pattern, loss_mask_str), "loss_mask format is not expected!" 
+ + # 3) 全序列上计算 TIS 权重和指标 + tis_weights_full_flat, tis_metrics = compute_tis_weights( + old_log_prob=old_full_flat, + rollout_log_prob=rollout_full_flat, + loss_mask=mask_full_flat, level=getattr(args, "tis_level", "token"), mode=getattr(args, "tis_mode", "truncate"), upper_threshold=upper, lower_threshold=lower, veto_threshold=getattr(args, "tis_veto_threshold", 1e-4), safety_bound=getattr(args, "tis_safety_bound", 20.0), - response_lengths=batch["response_lengths"], + response_lengths=response_lengths, ) # On-policy ratio for monitoring (π_new/π_old) ois = (-ppo_kl).exp() - if tis_weights is not None: - pg_loss = pg_loss * tis_weights + # 4) 应用权重(CP>1 时回切至本地切片) + if tis_weights_full_flat is not None: + if cp_size == 1: + pg_loss = pg_loss * tis_weights_full_flat + else: + per_seq_weights = list(torch.split(tis_weights_full_flat, [int(l) for l in response_lengths], dim=0)) + local_weight_chunks = [ + slice_log_prob_with_cp(w, total_len, resp_len) + for w, total_len, resp_len in zip(per_seq_weights, total_lengths, response_lengths) + ] + tis_weights_local_flat = torch.cat(local_weight_chunks, dim=0) + pg_loss = pg_loss * tis_weights_local_flat - # KL metrics next to TIS metrics + # 5) KL 指标统一基于全序列 kl_metrics = compute_kl_metrics( - old_log_prob=old_log_probs, - rollout_log_prob=rollout_log_probs, - eos_mask=eos_mask, - response_lengths=batch["response_lengths"], + old_log_prob=old_full_flat, + rollout_log_prob=rollout_full_flat, + loss_mask=mask_full_flat, + response_lengths=response_lengths, ) pg_loss = sum_of_sample_mean(pg_loss) diff --git a/slime/utils/tis.py b/slime/utils/tis.py index 869745fdc5..7835e5c53b 100644 --- a/slime/utils/tis.py +++ b/slime/utils/tis.py @@ -40,9 +40,12 @@ def per_seq_masked_mean( 支持二维 [B, T] 与拍平后一维、并提供 response_lengths 的两种输入形态。 """ if response_lengths is not None and len(response_lengths) > 0: - xs = torch.split(x, [int(l) for l in response_lengths], dim=0) - ms = torch.split(mask, [int(l) for l in response_lengths], dim=0) - 
seq_means = [masked_mean(xi, mi) for xi, mi in zip(xs, ms)] + sequence_log_ratios = torch.split(x, [int(l) for l in response_lengths], dim=0) + sequence_loss_masks = torch.split(mask, [int(l) for l in response_lengths], dim=0) + seq_means = [ + masked_mean(sequence_log_ratio, sequence_loss_mask) + for sequence_log_ratio, sequence_loss_mask in zip(sequence_log_ratios, sequence_loss_masks) + ] return torch.stack(seq_means).sum() # fallback:视为单一样本 return masked_mean(x, mask).unsqueeze(0).sum() @@ -52,7 +55,7 @@ def compute_tis_weights( *, old_log_prob: torch.Tensor, rollout_log_prob: torch.Tensor, - eos_mask: torch.Tensor, + loss_mask: torch.Tensor, level: str = "token", mode: str = "truncate", upper_threshold: Optional[float] = None, @@ -72,7 +75,7 @@ def compute_tis_weights( Args: old_log_prob: Flattened log probs from training backend. Shape: [sum(response_lengths)] rollout_log_prob: Flattened log probs from rollout backend. Shape: [sum(response_lengths)] - eos_mask: Flattened mask aligned with flattened tensors. Shape: [sum(response_lengths)] + loss_mask: Flattened mask aligned with flattened tensors. Shape: [sum(response_lengths)] level: The aggregation level for the importance sampling weights. - "token": per-token importance sampling weights, biased low variance. - "sequence": product over tokens, unbiased but high variance. @@ -91,8 +94,8 @@ def compute_tis_weights( metrics: The metrics for the importance sampling weights. 
""" assert ( - eos_mask.shape == old_log_prob.shape and eos_mask.shape == rollout_log_prob.shape - ), "eos_mask, old_log_prob, and rollout_log_prob must have the same shape" + loss_mask.shape == old_log_prob.shape and loss_mask.shape == rollout_log_prob.shape + ), "loss_mask, old_log_prob, and rollout_log_prob must have the same shape" assert response_lengths is not None and len(response_lengths) > 0, "response_lengths must be provided" if upper_threshold is None: @@ -114,21 +117,21 @@ def compute_tis_weights( weights = torch.exp(log_ratio_safe) elif level in ["sequence", "geometric"]: # Sequence-level/geometric: compute per-sequence aggregate in log-space, then expand to tokens - xs = torch.split(log_ratio, [int(l) for l in response_lengths], dim=0) - ms = torch.split(eos_mask, [int(l) for l in response_lengths], dim=0) + sequence_log_ratios = torch.split(log_ratio, [int(l) for l in response_lengths], dim=0) + sequence_loss_masks = torch.split(loss_mask, [int(l) for l in response_lengths], dim=0) per_seq_vals = [] - for xi, mi in zip(xs, ms): + for sequence_log_ratio, sequence_loss_mask in zip(sequence_log_ratios, sequence_loss_masks): if level == "sequence": - val = (xi * mi).sum() + val = (sequence_log_ratio * sequence_loss_mask).sum() else: # geometric - val = masked_mean(xi, mi) + val = masked_mean(sequence_log_ratio, sequence_loss_mask) per_seq_vals.append(torch.clamp(val, min=-safety_bound, max=safety_bound)) per_seq_vals = torch.stack(per_seq_vals) # [num_sequences] per_seq_weights = torch.exp(per_seq_vals) # Expand to per-token weights per sequence expanded = [] - for w, xi in zip(per_seq_weights, xs): - expanded.append(torch.ones_like(xi) * w) + for w, sequence_log_ratio in zip(per_seq_weights, sequence_log_ratios): + expanded.append(torch.ones_like(sequence_log_ratio) * w) weights = torch.cat(expanded, dim=0) # For metrics that need the aggregated log-ratio, keep per-seq values log_ratio_for_metrics = per_seq_vals @@ -138,7 +141,7 @@ def 
compute_tis_weights( log_veto_threshold = torch.log(torch.tensor(veto_threshold, device=device)) # Veto sequences with any token's log ratio below the threshold. # log(π_training / π_rollout) < log(veto_threshold) ⟺ π_training / π_rollout < veto_threshold - catastrophic_tokens = (log_ratio < log_veto_threshold) & eos_mask.bool() + catastrophic_tokens = (log_ratio < log_veto_threshold) & loss_mask.bool() # Build per-sequence veto and expand to tokens cat_chunks = torch.split(catastrophic_tokens, [int(l) for l in response_lengths], dim=0) has_catastrophic_per_seq = torch.tensor([chunk.any() for chunk in cat_chunks], device=device) @@ -153,7 +156,7 @@ def compute_tis_weights( metrics = compute_tis_metrics( tis_weights=weights, log_ratio_for_metrics=log_ratio_for_metrics, - eos_mask=eos_mask, + loss_mask=loss_mask, level=level, upper_threshold=upper_threshold, lower_threshold=lower_threshold, @@ -178,15 +181,15 @@ def compute_tis_weights( clip_mask = (weights >= lower_threshold) & (weights <= upper_threshold) clip_mask = clip_mask.float() clipped_indicator = 1 - clip_mask - metrics["tis_token_clipped_fraction"] = masked_mean(clipped_indicator, eos_mask) - sequence_has_clipped = masked_sum(clipped_indicator, eos_mask, dim=-1) > 0 + metrics["tis_token_clipped_fraction"] = masked_mean(clipped_indicator, loss_mask) + sequence_has_clipped = masked_sum(clipped_indicator, loss_mask, dim=-1) > 0 metrics["tis_sequence_clipped_fraction"] = sequence_has_clipped.float().mean() weights = weights * clip_mask else: raise ValueError(f"Invalid tis mode: {mode}") weights = weights * veto_mask - weights = weights * eos_mask + weights = weights * loss_mask weights = weights.detach() metrics["tis_threshold_upper"] = upper_threshold @@ -201,7 +204,7 @@ def compute_tis_metrics( *, tis_weights: torch.Tensor, log_ratio_for_metrics: torch.Tensor, - eos_mask: torch.Tensor, + loss_mask: torch.Tensor, level: str, upper_threshold: float, lower_threshold: float, @@ -218,12 +221,12 @@ def 
compute_tis_metrics( """ metrics: Dict[str, Any] = {} - assert eos_mask.shape == tis_weights.shape, "eos_mask and tis_weights must have the same shape" + assert loss_mask.shape == tis_weights.shape, "loss_mask and tis_weights must have the same shape" # Counts/fractions reported as sum over sequences; external reducer divides by num_samples metrics["tis_veto_fraction"] = has_catastrophic.float().sum() metrics["tis_catastrophic_token_fraction"] = per_seq_masked_mean( - catastrophic_tokens.float(), eos_mask, response_lengths=response_lengths + catastrophic_tokens.float(), loss_mask, response_lengths=response_lengths ) metrics["tis_level"] = level assert upper_threshold == 2.0 @@ -246,34 +249,51 @@ def compute_tis_metrics( below_lower = (log_ratio_for_metrics < log_lower_threshold).float().sum() metrics["tis_ratio_fraction_exceeds_upper"] = exceeds_upper metrics["tis_ratio_fraction_below_lower"] = below_lower - metrics["tis_mean"] = per_seq_masked_mean(tis_weights, eos_mask, response_lengths=response_lengths) + metrics["tis_mean"] = per_seq_masked_mean(tis_weights, loss_mask, response_lengths=response_lengths) else: - metrics["tis_mean"] = per_seq_masked_mean(tis_weights, eos_mask, response_lengths=response_lengths) + metrics["tis_mean"] = per_seq_masked_mean(tis_weights, loss_mask, response_lengths=response_lengths) exceeds_upper = (tis_weights > upper_threshold).float() below_lower = (tis_weights < lower_threshold).float() metrics["tis_ratio_fraction_exceeds_upper"] = per_seq_masked_mean( - exceeds_upper, eos_mask, response_lengths=response_lengths + exceeds_upper, loss_mask, response_lengths=response_lengths ) metrics["tis_ratio_fraction_below_lower"] = per_seq_masked_mean( - below_lower, eos_mask, response_lengths=response_lengths + below_lower, loss_mask, response_lengths=response_lengths ) # Per-sequence std and ESS, reported as sum across sequences weights_for_std = tis_weights.clamp(min=lower_threshold, max=upper_threshold) - xs = torch.split(tis_weights, 
[int(l) for l in response_lengths], dim=0) - ms = torch.split(eos_mask, [int(l) for l in response_lengths], dim=0) - per_seq_mean = torch.stack([masked_mean(xi, mi) for xi, mi in zip(xs, ms)]) + sequence_log_ratios = torch.split(tis_weights, [int(l) for l in response_lengths], dim=0) + sequence_loss_masks = torch.split(loss_mask, [int(l) for l in response_lengths], dim=0) + per_seq_mean = torch.stack( + [ + masked_mean(sequence_log_ratio, sequence_loss_mask) + for sequence_log_ratio, sequence_loss_mask in zip(sequence_log_ratios, sequence_loss_masks) + ] + ) per_seq_var = ( torch.stack( - [masked_mean(xi.clamp(min=lower_threshold, max=upper_threshold).square(), mi) for xi, mi in zip(xs, ms)] + [ + masked_mean( + sequence_log_ratio.clamp(min=lower_threshold, max=upper_threshold).square(), sequence_loss_mask + ) + for sequence_log_ratio, sequence_loss_mask in zip(sequence_log_ratios, sequence_loss_masks) + ] ) - per_seq_mean.square() ) per_seq_std = torch.sqrt(torch.clamp(per_seq_var, min=0.0)) metrics["tis_std"] = per_seq_std.sum() # ESS per sequence using normalized weights - weights_for_ess_list = [xi / (pm + 1e-8) for xi, pm in zip(xs, per_seq_mean)] - per_seq_ess = torch.stack([1.0 / masked_mean(xi.square(), mi) for xi, mi in zip(weights_for_ess_list, ms)]) + weights_for_ess_list = [ + sequence_log_ratio / (pm + 1e-8) for sequence_log_ratio, pm in zip(sequence_log_ratios, per_seq_mean) + ] + per_seq_ess = torch.stack( + [ + 1.0 / masked_mean(sequence_log_ratio.square(), sequence_loss_mask) + for sequence_log_ratio, sequence_loss_mask in zip(weights_for_ess_list, sequence_loss_masks) + ] + ) metrics["tis_eff_sample_size"] = per_seq_ess.sum() seq_mean = per_seq_mean @@ -289,28 +309,34 @@ def compute_kl_metrics( *, old_log_prob: torch.Tensor, rollout_log_prob: torch.Tensor, - eos_mask: Optional[torch.Tensor], + loss_mask: Optional[torch.Tensor], response_lengths: Optional[list[int]] = None, ) -> Dict[str, Any]: metrics: Dict[str, Any] = {} device = 
old_log_prob.device - if eos_mask is None: - eos_mask = torch.ones_like(old_log_prob, dtype=torch.bool, device=device) + if loss_mask is None: + loss_mask = torch.ones_like(old_log_prob, dtype=torch.bool, device=device) # Direct estimator for KL(pi_rollout || pi_old): per-seq mean then sum (1D inputs only) - assert response_lengths is not None and eos_mask is not None - xs = torch.split(rollout_log_prob - old_log_prob, [int(l) for l in response_lengths], dim=0) - ms = torch.split(eos_mask, [int(l) for l in response_lengths], dim=0) - per_seq = [masked_mean(xi, mi) for xi, mi in zip(xs, ms)] + assert response_lengths is not None and loss_mask is not None + sequence_log_ratios = torch.split(rollout_log_prob - old_log_prob, [int(l) for l in response_lengths], dim=0) + sequence_loss_masks = torch.split(loss_mask, [int(l) for l in response_lengths], dim=0) + per_seq = [ + masked_mean(sequence_log_ratio, sequence_loss_mask) + for sequence_log_ratio, sequence_loss_mask in zip(sequence_log_ratios, sequence_loss_masks) + ] metrics["rollout_kl"] = torch.stack(per_seq).sum() # K3 estimator: E[exp(log(pi_old/pi_rollout)) - log(pi_old/pi_rollout) - 1] log_ratio = old_log_prob - rollout_log_prob k3_matrix = torch.exp(log_ratio) - log_ratio - 1 - xs = torch.split(k3_matrix, [int(l) for l in response_lengths], dim=0) - ms = torch.split(eos_mask, [int(l) for l in response_lengths], dim=0) - per_seq = [masked_mean(xi, mi) for xi, mi in zip(xs, ms)] + sequence_log_ratios = torch.split(k3_matrix, [int(l) for l in response_lengths], dim=0) + sequence_loss_masks = torch.split(loss_mask, [int(l) for l in response_lengths], dim=0) + per_seq = [ + masked_mean(sequence_log_ratio, sequence_loss_mask) + for sequence_log_ratio, sequence_loss_mask in zip(sequence_log_ratios, sequence_loss_masks) + ] metrics["rollout_k3_kl"] = torch.stack(per_seq).sum() # Sequence-level perplexity difference metrics @@ -320,7 +346,7 @@ def compute_kl_metrics( start = 0 for length in response_lengths: end = start 
+ int(length) - mask_chunk = eos_mask[start:end] + mask_chunk = loss_mask[start:end] seq_rollout_means.append(masked_mean(rollout_log_prob[start:end], mask_chunk)) seq_old_means.append(masked_mean(old_log_prob[start:end], mask_chunk)) start = end From 71194c39c80de6a9f83fad56ca243647b59f7bc3 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Fri, 10 Oct 2025 01:29:02 +0000 Subject: [PATCH 19/58] slice tis with slice_log_prob_with_cp --- slime/backends/megatron_utils/loss.py | 101 ++++++++++---------------- slime/utils/tis.py | 19 +++++ 2 files changed, 56 insertions(+), 64 deletions(-) diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 48aaf28a76..41dfbf8a16 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -1,4 +1,3 @@ -import re from typing import Union import torch @@ -15,7 +14,7 @@ get_reinforce_plus_plus_baseline_advantages, get_reinforce_plus_plus_returns, ) -from slime.utils.tis import compute_kl_metrics, compute_tis_weights +from slime.utils.tis import assert_tis_input_format, compute_tis_weights from .cp_utils import ( all_gather_with_cp, @@ -314,78 +313,52 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): # Apply TIS off-policy correction using importance sampling if enabled if args.use_tis: assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS" - cp_size = mpu.get_context_parallel_world_size() - upper = args.tis_threshold_upper - lower = args.tis_threshold_lower - assert upper == 2.0 - total_lengths = batch["total_lengths"] - response_lengths = batch["response_lengths"] + full_log_probs = [ + all_gather_with_cp(log_prob, total_length, response_length) + for log_prob, total_length, response_length in zip(log_probs, total_lengths, response_lengths) + ] + full_old_log_probs = [ + all_gather_with_cp(old_log_prob, total_length, response_length) + for old_log_prob, total_length, response_length in zip(old_log_probs, 
total_lengths, response_lengths) + ] - # 1) 组装全序列 old/rollout/mask(CP=1 直接拼接;CP>1 用 all_gather 重建) - if cp_size == 1: - full_old_list = batch["log_probs"] - full_rollout_list = batch["rollout_log_probs"] - full_mask_list = batch["loss_masks"] - else: - full_old_list = [ - all_gather_with_cp(lp, total_len, resp_len) - for lp, total_len, resp_len in zip(batch["log_probs"], total_lengths, response_lengths) - ] - full_rollout_list = [ - all_gather_with_cp(lp, total_len, resp_len) - for lp, total_len, resp_len in zip(batch["rollout_log_probs"], total_lengths, response_lengths) - ] - # loss_masks 已是每样本全序列 - full_mask_list = batch["loss_masks"] - - old_full_flat = torch.cat(full_old_list, dim=0) - rollout_full_flat = torch.cat(full_rollout_list, dim=0) - mask_full_flat = torch.cat(full_mask_list, dim=0).to(device=log_probs.device) - - # 2) 基本一致性与格式校验 - assert old_full_flat.shape == rollout_full_flat.shape == mask_full_flat.shape - loss_mask_str = "".join([str(int(x)) for x in mask_full_flat]) - pattern = r"^1+(0+1+)*0*1*$" - assert re.fullmatch(pattern, loss_mask_str), "loss_mask format is not expected!" 
- - # 3) 全序列上计算 TIS 权重和指标 - tis_weights_full_flat, tis_metrics = compute_tis_weights( - old_log_prob=old_full_flat, - rollout_log_prob=rollout_full_flat, - loss_mask=mask_full_flat, + # old_log_probs, log_probs, loss_masks are all concated into 1D tensor + full_old_log_probs = torch.cat(full_old_log_probs, dim=0) + full_log_probs = torch.cat(full_log_probs, dim=0) + # loss_mask is not sliced by cp, so no need to all_gather + full_loss_masks = torch.cat(batch["loss_masks"], dim=0) + + assert_tis_input_format(full_old_log_probs, full_log_probs, full_loss_masks) + + tis_weights, tis_metrics = compute_tis_weights( + old_log_prob=full_old_log_probs, + rollout_log_prob=full_log_probs, + loss_mask=full_loss_masks, level=getattr(args, "tis_level", "token"), mode=getattr(args, "tis_mode", "truncate"), - upper_threshold=upper, - lower_threshold=lower, + upper_threshold=getattr(args, "tis_threshold_upper", 2.0), + lower_threshold=getattr(args, "tis_threshold_lower", 1.0 / getattr(args, "tis_threshold_upper", 2.0)), veto_threshold=getattr(args, "tis_veto_threshold", 1e-4), safety_bound=getattr(args, "tis_safety_bound", 20.0), - response_lengths=response_lengths, + response_lengths=total_lengths, ) - # On-policy ratio for monitoring (π_new/π_old) ois = (-ppo_kl).exp() - # 4) 应用权重(CP>1 时回切至本地切片) - if tis_weights_full_flat is not None: - if cp_size == 1: - pg_loss = pg_loss * tis_weights_full_flat - else: - per_seq_weights = list(torch.split(tis_weights_full_flat, [int(l) for l in response_lengths], dim=0)) - local_weight_chunks = [ - slice_log_prob_with_cp(w, total_len, resp_len) - for w, total_len, resp_len in zip(per_seq_weights, total_lengths, response_lengths) - ] - tis_weights_local_flat = torch.cat(local_weight_chunks, dim=0) - pg_loss = pg_loss * tis_weights_local_flat - - # 5) KL 指标统一基于全序列 - kl_metrics = compute_kl_metrics( - old_log_prob=old_full_flat, - rollout_log_prob=rollout_full_flat, - loss_mask=mask_full_flat, - response_lengths=response_lengths, - ) + # 
tis_weights is a 1D tensor, should be sliced to the local cp rank + local_tis_chunks = [] + start = 0 + for total_len, response_len in zip(total_lengths, response_lengths): + end = start + int(response_len) + seq_weights = tis_weights[start:end] + # Slice to the two local chunks of this CP rank + local_chunk = slice_log_prob_with_cp(seq_weights, int(total_len), int(response_len)) + local_tis_chunks.append(local_chunk) + start = end + tis_weights = torch.cat(local_tis_chunks, dim=0) + + pg_loss = pg_loss * tis_weights pg_loss = sum_of_sample_mean(pg_loss) pg_clipfrac = sum_of_sample_mean(pg_clipfrac) diff --git a/slime/utils/tis.py b/slime/utils/tis.py index 7835e5c53b..706d8117a2 100644 --- a/slime/utils/tis.py +++ b/slime/utils/tis.py @@ -1,8 +1,27 @@ +import re from typing import Any, Dict, Optional, Tuple import torch +def assert_tis_input_format( + full_old_log_probs: torch.Tensor, + full_log_probs: torch.Tensor, + full_loss_masks: torch.Tensor, +) -> None: + assert all( + tensor.dim() == 1 for tensor in [full_old_log_probs, full_log_probs, full_loss_masks] + ), f"{full_old_log_probs.dim()} vs {full_log_probs.dim()} vs {full_loss_masks.dim()}" + + assert ( + full_old_log_probs.shape == full_log_probs.shape and full_old_log_probs.shape == full_loss_masks.shape + ), f"{full_old_log_probs.shape} vs {full_log_probs.shape} vs {full_loss_masks.shape}" + + loss_mask_str = "".join([str(int(x)) for x in full_loss_masks]) + pattern = r"^1+(0+1+)*0*1*$" + assert re.fullmatch(pattern, loss_mask_str), "loss_mask format is not expected!" + + def masked_sum(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor: """ Computes the sum of the tensor x, masked by the mask. 
From 92e6e976f9b825ca8745b1bb7762cb27e79585b7 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Fri, 10 Oct 2025 04:13:10 +0000 Subject: [PATCH 20/58] [todo] filter out catastrophic tokens --- slime/backends/megatron_utils/loss.py | 19 +++-- slime/utils/tis.py | 110 +++++++------------------- 2 files changed, 37 insertions(+), 92 deletions(-) diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 41dfbf8a16..f2968cf2f2 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -14,7 +14,7 @@ get_reinforce_plus_plus_baseline_advantages, get_reinforce_plus_plus_returns, ) -from slime.utils.tis import assert_tis_input_format, compute_tis_weights +from slime.utils.tis import compute_tis_weights from .cp_utils import ( all_gather_with_cp, @@ -324,24 +324,23 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): ] # old_log_probs, log_probs, loss_masks are all concated into 1D tensor - full_old_log_probs = torch.cat(full_old_log_probs, dim=0) - full_log_probs = torch.cat(full_log_probs, dim=0) + full_old_log_probs_flat = torch.cat(full_old_log_probs, dim=0) + full_log_probs_flat = torch.cat(full_log_probs, dim=0) # loss_mask is not sliced by cp, so no need to all_gather - full_loss_masks = torch.cat(batch["loss_masks"], dim=0) - - assert_tis_input_format(full_old_log_probs, full_log_probs, full_loss_masks) + full_loss_masks_flat = torch.cat(batch["loss_masks"], dim=0) tis_weights, tis_metrics = compute_tis_weights( - old_log_prob=full_old_log_probs, - rollout_log_prob=full_log_probs, - loss_mask=full_loss_masks, + old_log_prob_flat=full_old_log_probs_flat, + rollout_log_prob_flat=full_log_probs_flat, + loss_mask_flat=full_loss_masks_flat, level=getattr(args, "tis_level", "token"), mode=getattr(args, "tis_mode", "truncate"), upper_threshold=getattr(args, "tis_threshold_upper", 2.0), lower_threshold=getattr(args, "tis_threshold_lower", 1.0 / getattr(args, 
"tis_threshold_upper", 2.0)), veto_threshold=getattr(args, "tis_veto_threshold", 1e-4), safety_bound=getattr(args, "tis_safety_bound", 20.0), - response_lengths=total_lengths, + response_lengths=response_lengths, + total_lengths=total_lengths, ) ois = (-ppo_kl).exp() diff --git a/slime/utils/tis.py b/slime/utils/tis.py index 706d8117a2..6a89f2869e 100644 --- a/slime/utils/tis.py +++ b/slime/utils/tis.py @@ -1,80 +1,13 @@ -import re from typing import Any, Dict, Optional, Tuple import torch -def assert_tis_input_format( - full_old_log_probs: torch.Tensor, - full_log_probs: torch.Tensor, - full_loss_masks: torch.Tensor, -) -> None: - assert all( - tensor.dim() == 1 for tensor in [full_old_log_probs, full_log_probs, full_loss_masks] - ), f"{full_old_log_probs.dim()} vs {full_log_probs.dim()} vs {full_loss_masks.dim()}" - - assert ( - full_old_log_probs.shape == full_log_probs.shape and full_old_log_probs.shape == full_loss_masks.shape - ), f"{full_old_log_probs.shape} vs {full_log_probs.shape} vs {full_loss_masks.shape}" - - loss_mask_str = "".join([str(int(x)) for x in full_loss_masks]) - pattern = r"^1+(0+1+)*0*1*$" - assert re.fullmatch(pattern, loss_mask_str), "loss_mask format is not expected!" - - -def masked_sum(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor: - """ - Computes the sum of the tensor x, masked by the mask. - - x = [[1, 2, 3], [4, 5, 6]] - mask = [[1, 1, 1], [1, 1, 0]] - masked_sum(x, mask, dim=-1) = [6, 9] - """ - valid_tokens = mask.sum(dim=dim) - assert valid_tokens.min() > 0, "any sequence must have at least one valid token" - assert x.shape == mask.shape, "x and mask must have the same shape" - return (x * mask).sum(dim=dim) - - -def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor: - """ - Computes the mean of the tensor x, masked by the mask. 
- - x = [[1, 2, 3], [4, 5, 6]] - mask = [[1, 1, 1], [1, 1, 0]] - masked_mean(x, mask, dim=-1) = [2, 4.5] - """ - valid_tokens = mask.sum(dim=dim) - assert valid_tokens.min() > 0, "any sequence must have at least one valid token" - return masked_sum(x, mask, dim=dim) / valid_tokens - - -def per_seq_masked_mean( - x: torch.Tensor, - mask: torch.Tensor, - response_lengths: Optional[list[int]] = None, -) -> torch.Tensor: - """ - 计算按样本的 masked mean 后再求和,返回一个可加性的标量(适配 DP 汇总)。 - 支持二维 [B, T] 与拍平后一维、并提供 response_lengths 的两种输入形态。 - """ - if response_lengths is not None and len(response_lengths) > 0: - sequence_log_ratios = torch.split(x, [int(l) for l in response_lengths], dim=0) - sequence_loss_masks = torch.split(mask, [int(l) for l in response_lengths], dim=0) - seq_means = [ - masked_mean(sequence_log_ratio, sequence_loss_mask) - for sequence_log_ratio, sequence_loss_mask in zip(sequence_log_ratios, sequence_loss_masks) - ] - return torch.stack(seq_means).sum() - # fallback:视为单一样本 - return masked_mean(x, mask).unsqueeze(0).sum() - - def compute_tis_weights( *, - old_log_prob: torch.Tensor, - rollout_log_prob: torch.Tensor, - loss_mask: torch.Tensor, + old_log_prob_flat: torch.Tensor, + rollout_log_prob_flat: torch.Tensor, + loss_mask_flat: torch.Tensor, level: str = "token", mode: str = "truncate", upper_threshold: Optional[float] = None, @@ -82,6 +15,7 @@ def compute_tis_weights( veto_threshold: float = 1e-4, safety_bound: float = 20.0, response_lengths: Optional[list[int]] = None, + total_lengths: Optional[list[int]] = None, ) -> Tuple[Optional[torch.Tensor], Dict[str, Any]]: """ Compute the truncated importance sampling (TIS) weights and metrics. @@ -92,9 +26,11 @@ def compute_tis_weights( https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda Args: - old_log_prob: Flattened log probs from training backend. 
Shape: [sum(response_lengths)] - rollout_log_prob: Flattened log probs from rollout backend. Shape: [sum(response_lengths)] - loss_mask: Flattened mask aligned with flattened tensors. Shape: [sum(response_lengths)] + old_log_prob_flat: Flattened log probs from training backend. Shape: [sum(response_lengths)] + rollout_log_prob_flat: Flattened log probs from rollout backend. Shape: [sum(response_lengths)] + loss_mask_flat: Flattened mask aligned with flattened tensors. Shape: [sum(response_lengths)] + Note that for single turn RL, the loss_mask_flat is [1] * sum(response_lengths) + For multi turn RL, the tool response will be marked as 0 in the loss_mask_flat. level: The aggregation level for the importance sampling weights. - "token": per-token importance sampling weights, biased low variance. - "sequence": product over tokens, unbiased but high variance. @@ -107,29 +43,37 @@ def compute_tis_weights( If not provided, it will be set to 1.0 / upper_threshold. veto_threshold: If any token's importance sampling weight is less than this, zero the entire sequence weight. safety_bound: The safety bound for the log-space ratio to avoid numerical overflow. + response_lengths: The length of the response for each sequence. + total_lengths: The total length of the whole sequence for each sequence. Returns: weights: The importance sampling weights. [batch_size, seq_len] metrics: The metrics for the importance sampling weights. 
""" + + assert all( + tensor.dim() == 1 for tensor in [old_log_prob_flat, rollout_log_prob_flat, loss_mask_flat] + ), f"{old_log_prob_flat.dim()} vs {rollout_log_prob_flat.dim()} vs {loss_mask_flat.dim()}" + assert ( - loss_mask.shape == old_log_prob.shape and loss_mask.shape == rollout_log_prob.shape - ), "loss_mask, old_log_prob, and rollout_log_prob must have the same shape" - assert response_lengths is not None and len(response_lengths) > 0, "response_lengths must be provided" + old_log_prob_flat.shape == rollout_log_prob_flat.shape and old_log_prob_flat.shape == loss_mask_flat.shape + ), f"{old_log_prob_flat.shape} vs {rollout_log_prob_flat.shape} vs {loss_mask_flat.shape}" if upper_threshold is None: return None, {} if lower_threshold is None: lower_threshold = 1.0 / upper_threshold - device = old_log_prob.device - log_ratio = old_log_prob - rollout_log_prob + device = old_log_prob_flat.device + log_ratio = old_log_prob_flat - rollout_log_prob_flat log_upper_threshold = torch.log(torch.tensor(upper_threshold, device=device)) log_lower_threshold = torch.log(torch.tensor(lower_threshold, device=device)) + # compute TIS weights without truncation/clipping + if level == "token": - # Token-level IS: π_training(a|s) / π_rollout(a|s) per token + # Token-level IS: π_training(a|s) / π_rollout(a|s) per token # The truncation will be applied later. 
log_ratio_for_metrics = log_ratio # [sum(response_lengths)] log_ratio_safe = torch.clamp(log_ratio, min=-safety_bound, max=safety_bound) @@ -137,13 +81,13 @@ def compute_tis_weights( elif level in ["sequence", "geometric"]: # Sequence-level/geometric: compute per-sequence aggregate in log-space, then expand to tokens sequence_log_ratios = torch.split(log_ratio, [int(l) for l in response_lengths], dim=0) - sequence_loss_masks = torch.split(loss_mask, [int(l) for l in response_lengths], dim=0) + sequence_loss_masks = torch.split(loss_mask_flat, [int(l) for l in response_lengths], dim=0) per_seq_vals = [] for sequence_log_ratio, sequence_loss_mask in zip(sequence_log_ratios, sequence_loss_masks): if level == "sequence": val = (sequence_log_ratio * sequence_loss_mask).sum() else: # geometric - val = masked_mean(sequence_log_ratio, sequence_loss_mask) + val = (sequence_log_ratio * sequence_loss_mask).sum() / sequence_loss_mask.sum() per_seq_vals.append(torch.clamp(val, min=-safety_bound, max=safety_bound)) per_seq_vals = torch.stack(per_seq_vals) # [num_sequences] per_seq_weights = torch.exp(per_seq_vals) @@ -157,10 +101,12 @@ def compute_tis_weights( else: raise ValueError(f"Invalid importance sampling level: {level}") + # TODO:继续 filter out catastrophic tokens + log_veto_threshold = torch.log(torch.tensor(veto_threshold, device=device)) # Veto sequences with any token's log ratio below the threshold. 
# log(π_training / π_rollout) < log(veto_threshold) ⟺ π_training / π_rollout < veto_threshold - catastrophic_tokens = (log_ratio < log_veto_threshold) & loss_mask.bool() + catastrophic_tokens = (log_ratio < log_veto_threshold) & loss_mask_flat.bool() # Build per-sequence veto and expand to tokens cat_chunks = torch.split(catastrophic_tokens, [int(l) for l in response_lengths], dim=0) has_catastrophic_per_seq = torch.tensor([chunk.any() for chunk in cat_chunks], device=device) From dbd27d7f33659152c8cdbf979aa5a8b063215e7d Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Sat, 11 Oct 2025 00:48:02 +0000 Subject: [PATCH 21/58] before metrics --- slime/utils/tis.py | 183 +++++++++++++++++++++++++++++---------------- 1 file changed, 118 insertions(+), 65 deletions(-) diff --git a/slime/utils/tis.py b/slime/utils/tis.py index 6a89f2869e..34414704bd 100644 --- a/slime/utils/tis.py +++ b/slime/utils/tis.py @@ -3,6 +3,41 @@ import torch +def masked_sum( + tensor: torch.Tensor, mask: torch.Tensor, dim: Optional[int] = None, keepdim: bool = False +) -> torch.Tensor: + mask_f = mask if mask.dtype.is_floating_point else mask.float() + return (tensor * mask_f).sum(dim=dim, keepdim=keepdim) + + +def masked_mean( + tensor: torch.Tensor, + mask: torch.Tensor, + dim: Optional[int] = None, + keepdim: bool = False, + eps: float = 1e-8, +) -> torch.Tensor: + mask_f = mask if mask.dtype.is_floating_point else mask.float() + total = (tensor * mask_f).sum(dim=dim, keepdim=keepdim) + denom = mask_f.sum(dim=dim, keepdim=keepdim) + return total / (denom + eps) + + +def per_seq_masked_mean( + tensor: torch.Tensor, + mask: torch.Tensor, + *, + response_lengths: Optional[list[int]] = None, + eps: float = 1e-8, +) -> torch.Tensor: + assert response_lengths is not None and len(response_lengths) > 0 + lengths = [int(l) for l in response_lengths] + chunks = torch.split(tensor, lengths, dim=0) + mask_chunks = torch.split(mask, lengths, dim=0) + per_seq = [masked_mean(t, m, eps=eps) for t, m in 
zip(chunks, mask_chunks)] + return torch.stack(per_seq).sum() + + def compute_tis_weights( *, old_log_prob_flat: torch.Tensor, @@ -69,59 +104,106 @@ def compute_tis_weights( log_upper_threshold = torch.log(torch.tensor(upper_threshold, device=device)) log_lower_threshold = torch.log(torch.tensor(lower_threshold, device=device)) + log_veto_threshold = torch.log(torch.tensor(veto_threshold, device=device)) - # compute TIS weights without truncation/clipping + # compute TIS raw weights and catastrophe flags in a single big if/else if level == "token": - # Token-level IS: π_training(a|s) / π_rollout(a|s) per token - # The truncation will be applied later. + # Token-level IS: π_training(a|s) / π_rollout(a|s) per token (raw, before veto and truncate/clip) log_ratio_for_metrics = log_ratio # [sum(response_lengths)] log_ratio_safe = torch.clamp(log_ratio, min=-safety_bound, max=safety_bound) weights = torch.exp(log_ratio_safe) + + # Catastrophe detection and veto mask per sequence + + catastrophic_tokens = (log_ratio < log_veto_threshold) & loss_mask_flat.bool() + cat_chunks = torch.split(catastrophic_tokens, [int(l) for l in response_lengths], dim=0) + has_catastrophic_per_seq = torch.tensor([chunk.any() for chunk in cat_chunks], device=device) + veto_mask = torch.cat( + [ + ( + torch.zeros_like(chunk, dtype=torch.float32) + if has_cat + else torch.ones_like(chunk, dtype=torch.float32) + ) + for has_cat, chunk in zip(has_catastrophic_per_seq, cat_chunks) + ], + dim=0, + ) elif level in ["sequence", "geometric"]: - # Sequence-level/geometric: compute per-sequence aggregate in log-space, then expand to tokens - sequence_log_ratios = torch.split(log_ratio, [int(l) for l in response_lengths], dim=0) - sequence_loss_masks = torch.split(loss_mask_flat, [int(l) for l in response_lengths], dim=0) - per_seq_vals = [] - for sequence_log_ratio, sequence_loss_mask in zip(sequence_log_ratios, sequence_loss_masks): + # Sequence-level/geometric: single split, aggregate, then expand 
(raw, before veto and truncate/clip) + assert response_lengths is not None and len(response_lengths) > 0 + seq_lengths = [int(l) for l in response_lengths] + + seq_log_ratio_chunks = torch.split(log_ratio, seq_lengths, dim=0) + seq_loss_mask_chunks = torch.split(loss_mask_flat, seq_lengths, dim=0) + + per_seq_vals_list = [] + catastrophic_token_chunks = [] + + for log_ratio_chunk, loss_mask_chunk in zip(seq_log_ratio_chunks, seq_loss_mask_chunks): if level == "sequence": - val = (sequence_log_ratio * sequence_loss_mask).sum() + aggregated = (log_ratio_chunk * loss_mask_chunk).sum() # [1] else: # geometric - val = (sequence_log_ratio * sequence_loss_mask).sum() / sequence_loss_mask.sum() - per_seq_vals.append(torch.clamp(val, min=-safety_bound, max=safety_bound)) - per_seq_vals = torch.stack(per_seq_vals) # [num_sequences] - per_seq_weights = torch.exp(per_seq_vals) - # Expand to per-token weights per sequence - expanded = [] - for w, sequence_log_ratio in zip(per_seq_weights, sequence_log_ratios): - expanded.append(torch.ones_like(sequence_log_ratio) * w) - weights = torch.cat(expanded, dim=0) - # For metrics that need the aggregated log-ratio, keep per-seq values - log_ratio_for_metrics = per_seq_vals + aggregated = (log_ratio_chunk * loss_mask_chunk).sum() / loss_mask_chunk.sum() + per_seq_vals_list.append(torch.clamp(aggregated, min=-safety_bound, max=safety_bound)) + + # Catastrophic tokens per sequence + cat_tokens_chunk = (log_ratio_chunk < log_veto_threshold) & loss_mask_chunk.bool() # [response_lengths] + catastrophic_token_chunks.append(cat_tokens_chunk) + + # each sequence only has one value in per_seq_vals_list + per_seq_vals = torch.stack(per_seq_vals_list) # [num_sequences] + # Expand per_seq_vals to match token-level shape for consistent metrics calculation + log_ratio_for_metrics = torch.cat( + [ + torch.ones_like(log_ratio_chunk) * per_seq_val + for per_seq_val, log_ratio_chunk in zip(per_seq_vals, seq_log_ratio_chunks) + ], + dim=0, + ) # 
[sum(response_lengths)] + weights = torch.exp(log_ratio_for_metrics) + + # Compose catastrophe tensors + catastrophic_tokens = torch.cat(catastrophic_token_chunks, dim=0) # [sum(response_lengths)] + has_catastrophic_per_seq = torch.tensor( + [chunk.any() for chunk in catastrophic_token_chunks], device=device + ) # [num_sequences] + + # Build veto mask per token from per-seq flags + veto_mask = torch.cat( + [ + ( + torch.zeros_like(chunk, dtype=torch.float32) + if has_cat + else torch.ones_like(chunk, dtype=torch.float32) + ) + for has_cat, chunk in zip(has_catastrophic_per_seq, seq_log_ratio_chunks) + ], + dim=0, + ) else: raise ValueError(f"Invalid importance sampling level: {level}") - # TODO:继续 filter out catastrophic tokens + # Apply veto before mode as requested + weights = weights * veto_mask + # Then apply mode (truncate/clip) in a unified way + if mode == "truncate": + weights = weights.clamp(max=upper_threshold) + elif mode == "clip": + clip_mask = (weights >= lower_threshold) & (weights <= upper_threshold) + clip_mask = clip_mask.float() + weights = weights * clip_mask + else: + raise ValueError(f"Invalid tis mode: {mode}") - log_veto_threshold = torch.log(torch.tensor(veto_threshold, device=device)) - # Veto sequences with any token's log ratio below the threshold. 
- # log(π_training / π_rollout) < log(veto_threshold) ⟺ π_training / π_rollout < veto_threshold - catastrophic_tokens = (log_ratio < log_veto_threshold) & loss_mask_flat.bool() - # Build per-sequence veto and expand to tokens - cat_chunks = torch.split(catastrophic_tokens, [int(l) for l in response_lengths], dim=0) - has_catastrophic_per_seq = torch.tensor([chunk.any() for chunk in cat_chunks], device=device) - veto_mask = torch.cat( - [ - torch.zeros_like(chunk, dtype=torch.float32) if has_cat else torch.ones_like(chunk, dtype=torch.float32) - for has_cat, chunk in zip(has_catastrophic_per_seq, cat_chunks) - ], - dim=0, - ) + weights = weights * loss_mask_flat + weights = weights.detach() metrics = compute_tis_metrics( tis_weights=weights, log_ratio_for_metrics=log_ratio_for_metrics, - loss_mask=loss_mask, + loss_mask=loss_mask_flat, level=level, upper_threshold=upper_threshold, lower_threshold=lower_threshold, @@ -133,35 +215,6 @@ def compute_tis_weights( response_lengths=response_lengths, ) - if mode == "truncate": - # only truncate the weights at the upper threshold - weights = weights.clamp(max=upper_threshold) - elif mode == "clip": - # zero the weights outside the [lower, upper] range - if level in ["sequence", "geometric"]: - seq_weights = weights[:, 0] if weights.dim() > 1 else weights - sequence_clipped = ((seq_weights < lower_threshold) | (seq_weights > upper_threshold)).float() - metrics["tis_sequence_clipped_fraction"] = sequence_clipped.mean() - else: - clip_mask = (weights >= lower_threshold) & (weights <= upper_threshold) - clip_mask = clip_mask.float() - clipped_indicator = 1 - clip_mask - metrics["tis_token_clipped_fraction"] = masked_mean(clipped_indicator, loss_mask) - sequence_has_clipped = masked_sum(clipped_indicator, loss_mask, dim=-1) > 0 - metrics["tis_sequence_clipped_fraction"] = sequence_has_clipped.float().mean() - weights = weights * clip_mask - else: - raise ValueError(f"Invalid tis mode: {mode}") - - weights = weights * veto_mask - 
weights = weights * loss_mask - weights = weights.detach() - - metrics["tis_threshold_upper"] = upper_threshold - metrics["tis_threshold_lower"] = lower_threshold - metrics["tis_level"] = level - metrics["tis_mode"] = mode - metrics["tis_veto_threshold"] = veto_threshold return weights, metrics From 083c30040fedb6a2e57a3e9dcee1ebcd757dee3f Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Sat, 11 Oct 2025 03:47:21 +0000 Subject: [PATCH 22/58] fix with rollout log probs --- slime/backends/megatron_utils/loss.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index f2968cf2f2..85e797e76c 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -314,9 +314,12 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): if args.use_tis: assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS" - full_log_probs = [ + rollout_log_probs = batch["rollout_log_probs"] + old_log_probs = batch["log_probs"] + + full_rollout_log_probs = [ all_gather_with_cp(log_prob, total_length, response_length) - for log_prob, total_length, response_length in zip(log_probs, total_lengths, response_lengths) + for log_prob, total_length, response_length in zip(rollout_log_probs, total_lengths, response_lengths) ] full_old_log_probs = [ all_gather_with_cp(old_log_prob, total_length, response_length) @@ -325,13 +328,13 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): # old_log_probs, log_probs, loss_masks are all concated into 1D tensor full_old_log_probs_flat = torch.cat(full_old_log_probs, dim=0) - full_log_probs_flat = torch.cat(full_log_probs, dim=0) + full_rollout_log_probs = torch.cat(full_rollout_log_probs, dim=0) # loss_mask is not sliced by cp, so no need to all_gather full_loss_masks_flat = torch.cat(batch["loss_masks"], dim=0) tis_weights, tis_metrics = compute_tis_weights( 
old_log_prob_flat=full_old_log_probs_flat, - rollout_log_prob_flat=full_log_probs_flat, + rollout_log_prob_flat=full_rollout_log_probs, loss_mask_flat=full_loss_masks_flat, level=getattr(args, "tis_level", "token"), mode=getattr(args, "tis_mode", "truncate"), From ff68b322a53e442b7e84ec0177bff57ffc6cc4a5 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Sat, 11 Oct 2025 04:43:46 +0000 Subject: [PATCH 23/58] [wait for the metircs] --- slime/backends/megatron_utils/loss.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 85e797e76c..6157a56b38 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -403,13 +403,7 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): if args.use_tis: # Backward compatible basic logs reported_loss["ois"] = sum_of_sample_mean(ois).clone().detach() - # Report all TIS and KL metrics uniformly, filtering out non-numeric values - for k, v in {**tis_metrics, **kl_metrics}.items(): - if torch.is_tensor(v): - reported_loss[k] = v.clone().detach() - elif isinstance(v, (int, float)): - reported_loss[k] = torch.tensor(v, device=logits.device) - # Skip string and other non-numeric types + return loss, reported_loss From afc1cd7bf92abdda6fbd2a28629de9b7a9be5e41 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Thu, 9 Oct 2025 02:10:43 +0000 Subject: [PATCH 24/58] logging a whole sequence --- scripts/run-qwen3-30B-A3B.sh | 30 ++++++++++++++------------- slime/backends/megatron_utils/loss.py | 3 +++ 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/scripts/run-qwen3-30B-A3B.sh b/scripts/run-qwen3-30B-A3B.sh index 3cdb749d9d..c1f1cc43ea 100644 --- a/scripts/run-qwen3-30B-A3B.sh +++ b/scripts/run-qwen3-30B-A3B.sh @@ -33,7 +33,7 @@ CKPT_ARGS=( --ref-load /root/Qwen3-30B-A3B_torch_dist --load /root/Qwen3-30B-A3B_slime/ --save /root/Qwen3-30B-A3B_slime/ - --save-interval 20 
+ --save-interval 1000 ) ROLLOUT_ARGS=( @@ -44,29 +44,29 @@ ROLLOUT_ARGS=( --rollout-shuffle --rm-type deepscaler --num-rollout 3000 - --rollout-batch-size 32 - --n-samples-per-prompt 8 + --rollout-batch-size 8 + --n-samples-per-prompt 4 --rollout-max-response-len 8192 --rollout-temperature 0.8 - --global-batch-size 256 + --global-batch-size 32 --balance-data ) EVAL_ARGS=( - --eval-interval 20 - --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl - --n-samples-per-eval-prompt 16 - --eval-max-response-len 16384 - --eval-top-p 0.7 + # --eval-interval 20 + # --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl + # --n-samples-per-eval-prompt 16 + # --eval-max-response-len 16384 + # --eval-top-p 0.7 ) PERF_ARGS=( - --tensor-model-parallel-size 4 + --tensor-model-parallel-size 2 --sequence-parallel --pipeline-model-parallel-size 1 --context-parallel-size 1 - --expert-model-parallel-size 8 + --expert-model-parallel-size 4 --expert-tensor-parallel-size 1 --recompute-granularity full @@ -86,6 +86,8 @@ GRPO_ARGS=( --entropy-coef 0.00 --eps-clip 0.2 --eps-clip-high 0.28 + + --use-tis ) OPTIMIZER_ARGS=( @@ -109,7 +111,7 @@ WANDB_ARGS=( ) SGLANG_ARGS=( - --rollout-num-gpus-per-engine 8 + --rollout-num-gpus-per-engine 4 --sglang-mem-fraction-static 0.7 --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256) ) @@ -127,7 +129,7 @@ MISC_ARGS=( # launch the master node of ray in container export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 4 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 # Build the runtime environment JSON with proper variable substitution RUNTIME_ENV_JSON="{ @@ -142,7 +144,7 @@ ray job submit --address="http://127.0.0.1:8265" \ --runtime-env-json="${RUNTIME_ENV_JSON}" \ -- python3 train.py \ --actor-num-nodes 1 \ - --actor-num-gpus-per-node 8 \ + 
--actor-num-gpus-per-node 4 \ --colocate \ ${MODEL_ARGS[@]} \ ${CKPT_ARGS[@]} \ diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 2290fda8dd..afe5af9863 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -309,6 +309,9 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS" rollout_log_probs = torch.cat(batch["rollout_log_probs"], dim=0) old_log_probs = torch.cat(batch["log_probs"], dim=0) + print("================================") + print(f"old_log_probs.shape: {old_log_probs.shape}, rollout_log_probs.shape: {rollout_log_probs.shape}") + print("================================") tis = torch.exp(old_log_probs - rollout_log_probs) ois = (-ppo_kl).exp() From f31c9cdf26c86b59bb93a1f71d50ca832f55cbe3 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Mon, 13 Oct 2025 02:08:29 +0000 Subject: [PATCH 25/58] tmp commit --- slime/backends/megatron_utils/loss.py | 31 +-- slime/utils/arguments.py | 115 ++++---- slime/utils/tis.py | 376 +++++++------------------- 3 files changed, 172 insertions(+), 350 deletions(-) diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 7b09e3eca2..ed94f9bf56 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -311,7 +311,7 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, args.eps_clip, args.eps_clip_high) # Apply TIS off-policy correction using importance sampling if enabled - if args.use_tis: + if args.use_train_infer_tis: assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS" rollout_log_probs = batch["rollout_log_probs"] @@ -326,24 +326,13 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): for old_log_prob, total_length, response_length in 
zip(old_log_probs, total_lengths, response_lengths) ] - # old_log_probs, log_probs, loss_masks are all concated into 1D tensor - full_old_log_probs_flat = torch.cat(full_old_log_probs, dim=0) - full_rollout_log_probs = torch.cat(full_rollout_log_probs, dim=0) - # loss_mask is not sliced by cp, so no need to all_gather - full_loss_masks_flat = torch.cat(batch["loss_masks"], dim=0) - tis_weights, tis_metrics = compute_tis_weights( - old_log_prob_flat=full_old_log_probs_flat, - rollout_log_prob_flat=full_rollout_log_probs, - loss_mask_flat=full_loss_masks_flat, - level=getattr(args, "tis_level", "token"), - mode=getattr(args, "tis_mode", "truncate"), - upper_threshold=getattr(args, "tis_threshold_upper", 2.0), - lower_threshold=getattr(args, "tis_threshold_lower", 1.0 / getattr(args, "tis_threshold_upper", 2.0)), - veto_threshold=getattr(args, "tis_veto_threshold", 1e-4), - safety_bound=getattr(args, "tis_safety_bound", 20.0), + args=args, + new_log_probs=full_old_log_probs, + old_log_probs=full_rollout_log_probs, + loss_masks=batch["loss_masks"], response_lengths=response_lengths, - total_lengths=total_lengths, + prefix="train_infer", ) ois = (-ppo_kl).exp() @@ -400,9 +389,15 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): if args.use_kl_loss: reported_loss["kl_loss"] = kl_loss.clone().detach() - if args.use_tis: + if args.use_train_infer_tis: # Backward compatible basic logs reported_loss["ois"] = sum_of_sample_mean(ois).clone().detach() + for metric_key, metric_value in tis_metrics.items(): + key_name = f"train_infer_{metric_key}" + if torch.is_tensor(metric_value): + reported_loss[key_name] = metric_value.clone().detach() + elif isinstance(metric_value, (int, float)): + reported_loss[key_name] = torch.tensor(metric_value, device=log_probs.device) return loss, reported_loss diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 4e3f585772..c2bfcc6571 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -25,6 
+25,61 @@ def reset_arg(parser, name, **kwargs): parser.add_argument(name, **kwargs) +def add_tis_arguments(parser: argparse.ArgumentParser): + # Off-Policy Correction using Importance Sampling: https://fengyao.notion.site/off-policy-rl + parser.add_argument( + "--use-train-infer-tis", + action="store_true", + default=False, + help="Enable TIS from https://fengyao.notion.site/off-policy-rl for off-policy importance sampling.", + ) + + # Extended TIS controls (levels/modes/thresholds) with backward compatibility + parser.add_argument( + "--train-infer-tis-level", + type=str, + choices=["token", "sequence", "geometric"], + default="token", + help=( + "Aggregation level for importance sampling weights: token (per-token), " + "sequence (product over tokens), geometric (geometric mean)." + ), + ) + parser.add_argument( + "--train-infer-tis-mode", + type=str, + choices=["truncate", "clip"], + default="truncate", + help=( + "Handling mode for IS weights: truncate (cap upper bound, TIS) or clip " + "(zero outside [lower, upper], CIS)." + ), + ) + parser.add_argument( + "--train-infer-tis-eps-clip", + type=float, + default=0.2, + help=( + "Lower clip threshold for IS weights. Default is 0.2. " + "For clip mode uses this value; for truncate mode remains unused." + ), + ) + parser.add_argument( + "--train-infer-tis-eps-clip-high", + type=float, + default=None, + help=("Upper clip threshold for IS weights. Default is None."), + ) + parser.add_argument( + "--train-infer-tis-veto-threshold", + type=float, + default=1e-4, + help=("Per-token veto threshold. If any token ratio < this, zero the entire sequence weight."), + ) + + return parser + + def get_slime_extra_args_provider(add_custom_arguments=None): def add_slime_arguments(parser): # Ray @@ -659,62 +714,6 @@ def add_algo_arguments(parser): "This is useful for doing special loss mask." 
), ) - # Off-Policy Correction using Importance Sampling: https://fengyao.notion.site/off-policy-rl - parser.add_argument( - "--use-tis", - action="store_true", - default=False, - help="Enable TIS from https://fengyao.notion.site/off-policy-rl for off-policy importance sampling.", - ) - - # Extended TIS controls (levels/modes/thresholds) with backward compatibility - parser.add_argument( - "--tis-level", - type=str, - choices=["token", "sequence", "geometric"], - default="token", - help=( - "Aggregation level for importance sampling weights: token (per-token), " - "sequence (product over tokens), geometric (geometric mean)." - ), - ) - parser.add_argument( - "--tis-mode", - type=str, - choices=["truncate", "clip"], - default="truncate", - help=( - "Handling mode for IS weights: truncate (cap upper bound, TIS) or clip " - "(zero outside [lower, upper], CIS)." - ), - ) - parser.add_argument( - "--tis-threshold-upper", - type=float, - default=2.0, - help=("Upper threshold for IS weights. Default is 2.0."), - ) - parser.add_argument( - "--tis-threshold-lower", - type=float, - default=0.0, - help=( - "Lower threshold for IS weights. Default is 0.0. " - "For clip mode uses this value; for truncate mode remains unused." - ), - ) - parser.add_argument( - "--tis-veto-threshold", - type=float, - default=1e-4, - help=("Per-token veto threshold. 
If any token ratio < this, zero the entire sequence weight."), - ) - parser.add_argument( - "--tis-safety-bound", - type=float, - default=20.0, - help=("Safety clamp for log-space ratio to avoid numerical overflow (exp(20) ~= 4.85e8)."), - ) parser.add_argument( "--use-routing-replay", @@ -1018,6 +1017,7 @@ def add_ci_arguments(parser): parser = add_reward_model_arguments(parser) parser = add_rollout_buffer_arguments(parser) parser = add_ci_arguments(parser) + parser = add_tis_arguments(parser) # For megatron parser = add_custom_megatron_plugins_arguments(parser) @@ -1151,6 +1151,9 @@ def slime_validate_args(args): if args.eps_clip_high is None: args.eps_clip_high = args.eps_clip + if args.train_infer_tis_eps_clip_high is None: + args.train_infer_tis_eps_clip_high = args.train_infer_tis_eps_clip + if args.eval_reward_key is None: args.eval_reward_key = args.reward_key diff --git a/slime/utils/tis.py b/slime/utils/tis.py index 34414704bd..6d3ca3a328 100644 --- a/slime/utils/tis.py +++ b/slime/utils/tis.py @@ -6,8 +6,7 @@ def masked_sum( tensor: torch.Tensor, mask: torch.Tensor, dim: Optional[int] = None, keepdim: bool = False ) -> torch.Tensor: - mask_f = mask if mask.dtype.is_floating_point else mask.float() - return (tensor * mask_f).sum(dim=dim, keepdim=keepdim) + return (tensor * mask.float()).sum(dim=dim, keepdim=keepdim) def masked_mean( @@ -17,41 +16,23 @@ def masked_mean( keepdim: bool = False, eps: float = 1e-8, ) -> torch.Tensor: - mask_f = mask if mask.dtype.is_floating_point else mask.float() - total = (tensor * mask_f).sum(dim=dim, keepdim=keepdim) - denom = mask_f.sum(dim=dim, keepdim=keepdim) + total = (tensor * mask.float()).sum(dim=dim, keepdim=keepdim) + denom = mask.float().sum(dim=dim, keepdim=keepdim) return total / (denom + eps) -def per_seq_masked_mean( - tensor: torch.Tensor, - mask: torch.Tensor, - *, - response_lengths: Optional[list[int]] = None, - eps: float = 1e-8, -) -> torch.Tensor: - assert response_lengths is not None and 
len(response_lengths) > 0 - lengths = [int(l) for l in response_lengths] - chunks = torch.split(tensor, lengths, dim=0) - mask_chunks = torch.split(mask, lengths, dim=0) - per_seq = [masked_mean(t, m, eps=eps) for t, m in zip(chunks, mask_chunks)] - return torch.stack(per_seq).sum() +SAFETY_BOUND = 20.0 def compute_tis_weights( + args, *, - old_log_prob_flat: torch.Tensor, - rollout_log_prob_flat: torch.Tensor, - loss_mask_flat: torch.Tensor, - level: str = "token", - mode: str = "truncate", - upper_threshold: Optional[float] = None, - lower_threshold: Optional[float] = None, - veto_threshold: float = 1e-4, - safety_bound: float = 20.0, + new_log_probs: list[torch.Tensor], + old_log_probs: list[torch.Tensor], + loss_masks: list[torch.Tensor], response_lengths: Optional[list[int]] = None, - total_lengths: Optional[list[int]] = None, -) -> Tuple[Optional[torch.Tensor], Dict[str, Any]]: + prefix: str = "", +) -> Tuple[list[torch.Tensor], Dict[str, Any]]: """ Compute the truncated importance sampling (TIS) weights and metrics. @@ -61,268 +42,111 @@ def compute_tis_weights( https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda Args: - old_log_prob_flat: Flattened log probs from training backend. Shape: [sum(response_lengths)] - rollout_log_prob_flat: Flattened log probs from rollout backend. Shape: [sum(response_lengths)] - loss_mask_flat: Flattened mask aligned with flattened tensors. Shape: [sum(response_lengths)] - Note that for single turn RL, the loss_mask_flat is [1] * sum(response_lengths) - For multi turn RL, the tool response will be marked as 0 in the loss_mask_flat. - level: The aggregation level for the importance sampling weights. - - "token": per-token importance sampling weights, biased low variance. - - "sequence": product over tokens, unbiased but high variance. - - "geometric": geometric mean over tokens, biased, medium variance. 
- mode: how to handle the importance sampling weights exceeding the thresholds. - - "truncate": cap the importance sampling weights at the upper threshold, i.e., truncated importance sampling. - - "clip": zero the importance sampling weights outside the [lower, upper] range. - upper_threshold: The upper threshold for the importance sampling weights. - lower_threshold: The lower threshold for the importance sampling weights, only used in "clip" mode. - If not provided, it will be set to 1.0 / upper_threshold. - veto_threshold: If any token's importance sampling weight is less than this, zero the entire sequence weight. - safety_bound: The safety bound for the log-space ratio to avoid numerical overflow. + new_log_probs: List of log probs from new policy, one tensor per sequence. + old_log_probs: List of log probs from old policy, one tensor per sequence. + - under training/inference tis + - new_log_probs = training backend + - old_log_probs = rollout backend + - under mini batch tis + - new_log_probs = new batch + - old_log_probs = old batch + loss_masks: List of loss masks, one tensor per sequence. + Note that for single turn RL, the loss_mask is [1] * response_length for each sequence + For multi turn RL, the tool response will be marked as 0 in the loss_mask. response_lengths: The length of the response for each sequence. - total_lengths: The total length of the whole sequence for each sequence. + prefix: The prefix for the parameters, indicating which tis is used. Returns: weights: The importance sampling weights. [batch_size, seq_len] metrics: The metrics for the importance sampling weights. 
""" - assert all( - tensor.dim() == 1 for tensor in [old_log_prob_flat, rollout_log_prob_flat, loss_mask_flat] - ), f"{old_log_prob_flat.dim()} vs {rollout_log_prob_flat.dim()} vs {loss_mask_flat.dim()}" - - assert ( - old_log_prob_flat.shape == rollout_log_prob_flat.shape and old_log_prob_flat.shape == loss_mask_flat.shape - ), f"{old_log_prob_flat.shape} vs {rollout_log_prob_flat.shape} vs {loss_mask_flat.shape}" - - if upper_threshold is None: - return None, {} - if lower_threshold is None: - lower_threshold = 1.0 / upper_threshold - - device = old_log_prob_flat.device - log_ratio = old_log_prob_flat - rollout_log_prob_flat - - log_upper_threshold = torch.log(torch.tensor(upper_threshold, device=device)) - log_lower_threshold = torch.log(torch.tensor(lower_threshold, device=device)) - log_veto_threshold = torch.log(torch.tensor(veto_threshold, device=device)) + """ + level: The aggregation level for the importance sampling weights. + - "token": per-token importance sampling weights, biased low variance. + - "sequence": product over tokens, unbiased but high variance. + - "geometric": geometric mean over tokens, biased, medium variance. + mode: how to handle the importance sampling weights exceeding the thresholds. + - "truncate": cap the importance sampling weights at the upper threshold, i.e., truncated importance sampling. + - "clip": zero the importance sampling weights outside the [lower, upper] range. + eps_clip: The lower clip threshold for the importance sampling weights. + eps_clip_high: The upper clip threshold for the importance sampling weights. + Clip ratio between [1 - eps_clip, 1 + eps_clip_high] + When token ratio is out of [veto_threshold, 1 / veto_threshold], the sequence will be masked. 
+ """ + level: str = getattr(args, prefix + "_tis_level", "token") + mode: str = getattr(args, prefix + "_tis_mode", "truncate") + eps_clip: Optional[float] = getattr(args, prefix + "_tis_eps_clip", None) + eps_clip_high: Optional[float] = getattr(args, prefix + "_tis_eps_clip_high", None) + veto_threshold: Optional[float] = getattr(args, prefix + "_tis_veto_threshold", None) - # compute TIS raw weights and catastrophe flags in a single big if/else + assert eps_clip is not None and eps_clip_high is not None, "eps_clip and eps_clip_high must be provided" - if level == "token": - # Token-level IS: π_training(a|s) / π_rollout(a|s) per token (raw, before veto and truncate/clip) - log_ratio_for_metrics = log_ratio # [sum(response_lengths)] - log_ratio_safe = torch.clamp(log_ratio, min=-safety_bound, max=safety_bound) + # Validate input lists have same length and each sequence has matching shapes + assert ( + len(old_log_probs) == len(new_log_probs) == len(loss_masks) + ), f"Input lists must have same length: {len(old_log_probs)} vs {len(new_log_probs)} vs {len(loss_masks)}" + + for i, (old, new, mask) in enumerate(zip(old_log_probs, new_log_probs, loss_masks)): + assert ( + old.shape == new.shape == mask.shape + ), f"Sequence {i}: shapes must match - old: {old.shape}, new: {new.shape}, mask: {mask.shape}" + + # Get device from first tensor + device = old_log_probs[0].device + + all_weights = [] + + # Process each sequence individually + for old_log_prob, new_log_prob, loss_mask in zip(old_log_probs, new_log_probs, loss_masks): + raw_log_ratio = old_log_prob - new_log_prob + + if level == "token": + # Token-level IS + log_ratio_for_metrics = raw_log_ratio + elif level == "sequence": + # Sequence-level IS + agg_log_ratio = masked_sum(raw_log_ratio, loss_mask) + log_ratio_for_metrics = torch.full_like(raw_log_ratio, agg_log_ratio) + elif level == "geometric": + # Geometric mean IS + agg_log_ratio = masked_mean(raw_log_ratio, loss_mask) + log_ratio_for_metrics = 
torch.full_like(raw_log_ratio, agg_log_ratio) + else: + raise ValueError(f"Invalid importance sampling level: {level}") + + log_ratio_safe = torch.clamp(log_ratio_for_metrics, min=-SAFETY_BOUND, max=SAFETY_BOUND) weights = torch.exp(log_ratio_safe) - # Catastrophe detection and veto mask per sequence - - catastrophic_tokens = (log_ratio < log_veto_threshold) & loss_mask_flat.bool() - cat_chunks = torch.split(catastrophic_tokens, [int(l) for l in response_lengths], dim=0) - has_catastrophic_per_seq = torch.tensor([chunk.any() for chunk in cat_chunks], device=device) - veto_mask = torch.cat( - [ - ( - torch.zeros_like(chunk, dtype=torch.float32) - if has_cat - else torch.ones_like(chunk, dtype=torch.float32) - ) - for has_cat, chunk in zip(has_catastrophic_per_seq, cat_chunks) - ], - dim=0, - ) - elif level in ["sequence", "geometric"]: - # Sequence-level/geometric: single split, aggregate, then expand (raw, before veto and truncate/clip) - assert response_lengths is not None and len(response_lengths) > 0 - seq_lengths = [int(l) for l in response_lengths] - - seq_log_ratio_chunks = torch.split(log_ratio, seq_lengths, dim=0) - seq_loss_mask_chunks = torch.split(loss_mask_flat, seq_lengths, dim=0) - - per_seq_vals_list = [] - catastrophic_token_chunks = [] - - for log_ratio_chunk, loss_mask_chunk in zip(seq_log_ratio_chunks, seq_loss_mask_chunks): - if level == "sequence": - aggregated = (log_ratio_chunk * loss_mask_chunk).sum() # [1] - else: # geometric - aggregated = (log_ratio_chunk * loss_mask_chunk).sum() / loss_mask_chunk.sum() - per_seq_vals_list.append(torch.clamp(aggregated, min=-safety_bound, max=safety_bound)) - - # Catastrophic tokens per sequence - cat_tokens_chunk = (log_ratio_chunk < log_veto_threshold) & loss_mask_chunk.bool() # [response_lengths] - catastrophic_token_chunks.append(cat_tokens_chunk) - - # each sequence only has one value in per_seq_vals_list - per_seq_vals = torch.stack(per_seq_vals_list) # [num_sequences] - # Expand per_seq_vals to 
match token-level shape for consistent metrics calculation - log_ratio_for_metrics = torch.cat( - [ - torch.ones_like(log_ratio_chunk) * per_seq_val - for per_seq_val, log_ratio_chunk in zip(per_seq_vals, seq_log_ratio_chunks) - ], - dim=0, - ) # [sum(response_lengths)] - weights = torch.exp(log_ratio_for_metrics) - - # Compose catastrophe tensors - catastrophic_tokens = torch.cat(catastrophic_token_chunks, dim=0) # [sum(response_lengths)] - has_catastrophic_per_seq = torch.tensor( - [chunk.any() for chunk in catastrophic_token_chunks], device=device - ) # [num_sequences] - - # Build veto mask per token from per-seq flags - veto_mask = torch.cat( - [ - ( - torch.zeros_like(chunk, dtype=torch.float32) - if has_cat - else torch.ones_like(chunk, dtype=torch.float32) - ) - for has_cat, chunk in zip(has_catastrophic_per_seq, seq_log_ratio_chunks) - ], - dim=0, - ) - else: - raise ValueError(f"Invalid importance sampling level: {level}") - - # Apply veto before mode as requested - weights = weights * veto_mask - # Then apply mode (truncate/clip) in a unified way - if mode == "truncate": - weights = weights.clamp(max=upper_threshold) - elif mode == "clip": - clip_mask = (weights >= lower_threshold) & (weights <= upper_threshold) - clip_mask = clip_mask.float() - weights = weights * clip_mask - else: - raise ValueError(f"Invalid tis mode: {mode}") - - weights = weights * loss_mask_flat - weights = weights.detach() - - metrics = compute_tis_metrics( - tis_weights=weights, - log_ratio_for_metrics=log_ratio_for_metrics, - loss_mask=loss_mask_flat, - level=level, - upper_threshold=upper_threshold, - lower_threshold=lower_threshold, - log_upper_threshold=log_upper_threshold, - log_lower_threshold=log_lower_threshold, - has_catastrophic=has_catastrophic_per_seq, - catastrophic_tokens=catastrophic_tokens, - safety_bound=safety_bound, - response_lengths=response_lengths, - ) + if veto_threshold is not None: + log_veto_threshold = torch.log(torch.tensor(veto_threshold, 
device=device)) + # For each sequence, check if it has any catastrophic tokens + catastrophic_tokens = ( + (log_ratio_for_metrics < log_veto_threshold) | (log_ratio_for_metrics > 1 / log_veto_threshold) + ) & loss_mask.bool() + has_catastrophic = catastrophic_tokens.any() + # Create veto mask: 0 if sequence has catastrophic tokens, 1 otherwise + veto_mask = (~has_catastrophic).float() + else: + catastrophic_tokens = torch.zeros_like(log_ratio_for_metrics, dtype=torch.bool) + has_catastrophic = False + veto_mask = torch.ones_like(weights) + + if mode == "clip": + weights = weights.clamp(1 - eps_clip, 1 + eps_clip_high) + # TODO: which mode should be masked importance sampling? + else: + raise ValueError(f"Invalid tis mode: {mode}") + + weights = weights * veto_mask + weights = weights * loss_mask + weights = weights.detach() + + all_weights.append(weights) return weights, metrics -def compute_tis_metrics( - *, - tis_weights: torch.Tensor, - log_ratio_for_metrics: torch.Tensor, - loss_mask: torch.Tensor, - level: str, - upper_threshold: float, - lower_threshold: float, - log_upper_threshold: torch.Tensor, - log_lower_threshold: torch.Tensor, - has_catastrophic: torch.Tensor, - catastrophic_tokens: torch.Tensor, - safety_bound: float, - response_lengths: Optional[list[int]] = None, -) -> Dict[str, Any]: - """ - Computes metrics that reflect the TRUE distribution (before clamping) - for the truncated importance sampling (TIS) weights. 
- """ - metrics: Dict[str, Any] = {} - - assert loss_mask.shape == tis_weights.shape, "loss_mask and tis_weights must have the same shape" - - # Counts/fractions reported as sum over sequences; external reducer divides by num_samples - metrics["tis_veto_fraction"] = has_catastrophic.float().sum() - metrics["tis_catastrophic_token_fraction"] = per_seq_masked_mean( - catastrophic_tokens.float(), loss_mask, response_lengths=response_lengths - ) - metrics["tis_level"] = level - assert upper_threshold == 2.0 - # Make numeric constants DP-safe by scaling with number of sequences in this batch - if tis_weights.dim() == 2: - num_sequences = tis_weights.size(0) - elif response_lengths is not None and len(response_lengths) > 0: - num_sequences = len(response_lengths) - else: - num_sequences = 1 - metrics["tis_upper_threshold"] = torch.tensor(2.0 * num_sequences, device=tis_weights.device) - metrics["tis_lower_threshold"] = torch.tensor(lower_threshold * num_sequences, device=tis_weights.device) - metrics["tis_log_upper_threshold"] = log_upper_threshold * num_sequences - metrics["tis_log_lower_threshold"] = log_lower_threshold * num_sequences - metrics["tis_safety_bound"] = torch.tensor(safety_bound * num_sequences, device=tis_weights.device) - - if level in ["sequence", "geometric"]: - # log_ratio_for_metrics is per-seq aggregated log-ratio: compare per-seq - exceeds_upper = (log_ratio_for_metrics > log_upper_threshold).float().sum() - below_lower = (log_ratio_for_metrics < log_lower_threshold).float().sum() - metrics["tis_ratio_fraction_exceeds_upper"] = exceeds_upper - metrics["tis_ratio_fraction_below_lower"] = below_lower - metrics["tis_mean"] = per_seq_masked_mean(tis_weights, loss_mask, response_lengths=response_lengths) - else: - metrics["tis_mean"] = per_seq_masked_mean(tis_weights, loss_mask, response_lengths=response_lengths) - exceeds_upper = (tis_weights > upper_threshold).float() - below_lower = (tis_weights < lower_threshold).float() - 
metrics["tis_ratio_fraction_exceeds_upper"] = per_seq_masked_mean( - exceeds_upper, loss_mask, response_lengths=response_lengths - ) - metrics["tis_ratio_fraction_below_lower"] = per_seq_masked_mean( - below_lower, loss_mask, response_lengths=response_lengths - ) - - # Per-sequence std and ESS, reported as sum across sequences - weights_for_std = tis_weights.clamp(min=lower_threshold, max=upper_threshold) - sequence_log_ratios = torch.split(tis_weights, [int(l) for l in response_lengths], dim=0) - sequence_loss_masks = torch.split(loss_mask, [int(l) for l in response_lengths], dim=0) - per_seq_mean = torch.stack( - [ - masked_mean(sequence_log_ratio, sequence_loss_mask) - for sequence_log_ratio, sequence_loss_mask in zip(sequence_log_ratios, sequence_loss_masks) - ] - ) - per_seq_var = ( - torch.stack( - [ - masked_mean( - sequence_log_ratio.clamp(min=lower_threshold, max=upper_threshold).square(), sequence_loss_mask - ) - for sequence_log_ratio, sequence_loss_mask in zip(sequence_log_ratios, sequence_loss_masks) - ] - ) - - per_seq_mean.square() - ) - per_seq_std = torch.sqrt(torch.clamp(per_seq_var, min=0.0)) - metrics["tis_std"] = per_seq_std.sum() - # ESS per sequence using normalized weights - weights_for_ess_list = [ - sequence_log_ratio / (pm + 1e-8) for sequence_log_ratio, pm in zip(sequence_log_ratios, per_seq_mean) - ] - per_seq_ess = torch.stack( - [ - 1.0 / masked_mean(sequence_log_ratio.square(), sequence_loss_mask) - for sequence_log_ratio, sequence_loss_mask in zip(weights_for_ess_list, sequence_loss_masks) - ] - ) - metrics["tis_eff_sample_size"] = per_seq_ess.sum() - seq_mean = per_seq_mean - - # Sequence-level summaries (sum-style for DP-safe reduction) - metrics["tis_seq_mean"] = seq_mean.sum() - metrics["tis_seq_fraction_exceeds_upper"] = (seq_mean > upper_threshold).float().sum() - metrics["tis_seq_fraction_below_lower"] = (seq_mean < lower_threshold).float().sum() - - return metrics - - def compute_kl_metrics( *, old_log_prob: torch.Tensor, 
From 46adaeff2b615d61bc88a04c8052412958588d69 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Mon, 13 Oct 2025 04:55:53 +0000 Subject: [PATCH 26/58] make good abstraction to metrics --- slime/backends/megatron_utils/loss.py | 36 +++++++++- slime/utils/tis.py | 96 ++++++++++++++++----------- 2 files changed, 89 insertions(+), 43 deletions(-) diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index ed94f9bf56..c41477e77b 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Union import torch @@ -14,7 +15,7 @@ get_reinforce_plus_plus_baseline_advantages, get_reinforce_plus_plus_returns, ) -from slime.utils.tis import compute_tis_weights +from slime.utils.tis import clip, clip_to_zero, compute_train_infer_tis_weights, truncate from .cp_utils import ( all_gather_with_cp, @@ -326,13 +327,42 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): for old_log_prob, total_length, response_length in zip(old_log_probs, total_lengths, response_lengths) ] - tis_weights, tis_metrics = compute_tis_weights( + """ + mode: how to handle the importance sampling weights exceeding the thresholds. + - "truncated": cap the importance sampling weights at the upper threshold + https://fengyao.notion.site/off-policy-rl#279721e3f6c48092bbe2fcfe0e9c6b33 + - "clip_to_zero": zero the importance sampling weights outside the [lower, upper] range. + https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda + - "clip": clip the importance sampling weights to the [lower, upper] range. 
+ """ + + if args.train_infer_tis_mode == "clip_to_zero": + assert ( + args.train_infer_tis_eps_clip is not None and args.train_infer_tis_eps_clip_high is not None + ), "eps_clip and eps_clip_high must be provided" + tis_function = partial( + clip_to_zero, eps_clip=args.train_infer_tis_eps_clip, eps_clip_high=args.train_infer_tis_eps_clip_high + ) + elif args.train_infer_tis_mode == "clip": + assert ( + args.train_infer_tis_eps_clip is not None and args.train_infer_tis_eps_clip_high is not None + ), "eps_clip and eps_clip_high must be provided" + tis_function = partial( + clip, eps_clip=args.train_infer_tis_eps_clip, eps_clip_high=args.train_infer_tis_eps_clip_high + ) + elif args.train_infer_tis_mode == "truncate": + assert args.train_infer_tis_eps_clip is not None, "eps_clip must be provided" + tis_function = partial(truncate, eps=args.train_infer_tis_eps_clip) + else: + raise ValueError(f"Unsupported train_infer_tis_mode: {args.train_infer_tis_mode}") + + tis_weights, tis_metrics = compute_train_infer_tis_weights( args=args, new_log_probs=full_old_log_probs, old_log_probs=full_rollout_log_probs, loss_masks=batch["loss_masks"], response_lengths=response_lengths, - prefix="train_infer", + tis_function=tis_function, ) ois = (-ppo_kl).exp() diff --git a/slime/utils/tis.py b/slime/utils/tis.py index 6d3ca3a328..8e59f5b925 100644 --- a/slime/utils/tis.py +++ b/slime/utils/tis.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Tuple +from typing import Any, Callable, Dict, Optional, Tuple import torch @@ -21,10 +21,54 @@ def masked_mean( return total / (denom + eps) -SAFETY_BOUND = 20.0 +def calculate_veto_mask( + log_ratio_for_metrics: torch.Tensor, + loss_mask: torch.Tensor, + veto_threshold: Optional[float], + metrics: Dict[str, Any], +) -> torch.Tensor: + if veto_threshold is None: + return torch.ones_like(log_ratio_for_metrics) + log_veto_threshold = torch.log(torch.tensor(veto_threshold, device=log_ratio_for_metrics.device)) + # For each sequence, check 
if it has any catastrophic tokens + catastrophic_tokens = ( + (log_ratio_for_metrics < log_veto_threshold) | (log_ratio_for_metrics > 1 / log_veto_threshold) + ) & loss_mask.bool() + has_catastrophic = catastrophic_tokens.any() + # Create veto mask: 0 if sequence has catastrophic tokens, 1 otherwise + veto_mask = (~has_catastrophic).float().expand_as(log_ratio_for_metrics) + + # Update metrics + metrics["catastrophic_ratio"] += masked_mean(has_catastrophic.int(), loss_mask) + return veto_mask + + +def truncate(weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, Any], *, eps: float) -> torch.Tensor: + metrics["mean"] += masked_mean(weights, loss_mask) + metrics["truncate_fraction"] += masked_mean((weights > eps).int(), loss_mask) + return weights.clamp(0, eps) * loss_mask + + +def clip( + weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, Any], *, eps_clip: float, eps_clip_high: float +) -> torch.Tensor: + metrics["mean"] += masked_mean(weights, loss_mask) + metrics["clip_fraction_low"] += masked_mean((weights < 1 - eps_clip).int(), loss_mask) + metrics["clip_fraction_high"] += masked_mean((weights > 1 + eps_clip_high).int(), loss_mask) + return weights.clamp(1 - eps_clip, 1 + eps_clip_high) + + +def clip_to_zero( + weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, Any], *, eps_clip: float, eps_clip_high: float +) -> torch.Tensor: + metrics["mean"] += masked_mean(weights, loss_mask) + metrics["clip_fraction_low"] += masked_mean((weights < 1 - eps_clip).int(), loss_mask) + metrics["clip_fraction_high"] += masked_mean((weights > 1 + eps_clip_high).int(), loss_mask) + clip_mask = (weights >= 1 - eps_clip) & (weights <= 1 + eps_clip_high) + return weights * clip_mask -def compute_tis_weights( +def compute_train_infer_tis_weights( args, *, new_log_probs: list[torch.Tensor], @@ -32,6 +76,7 @@ def compute_tis_weights( loss_masks: list[torch.Tensor], response_lengths: Optional[list[int]] = None, prefix: str = "", + 
tis_function: Callable[[torch.Tensor, torch.Tensor, Dict[str, Any]], torch.Tensor] = None, ) -> Tuple[list[torch.Tensor], Dict[str, Any]]: """ Compute the truncated importance sampling (TIS) weights and metrics. @@ -66,21 +111,9 @@ def compute_tis_weights( - "token": per-token importance sampling weights, biased low variance. - "sequence": product over tokens, unbiased but high variance. - "geometric": geometric mean over tokens, biased, medium variance. - mode: how to handle the importance sampling weights exceeding the thresholds. - - "truncate": cap the importance sampling weights at the upper threshold, i.e., truncated importance sampling. - - "clip": zero the importance sampling weights outside the [lower, upper] range. - eps_clip: The lower clip threshold for the importance sampling weights. - eps_clip_high: The upper clip threshold for the importance sampling weights. - Clip ratio between [1 - eps_clip, 1 + eps_clip_high] - When token ratio is out of [veto_threshold, 1 / veto_threshold], the sequence will be masked. 
""" - level: str = getattr(args, prefix + "_tis_level", "token") - mode: str = getattr(args, prefix + "_tis_mode", "truncate") - eps_clip: Optional[float] = getattr(args, prefix + "_tis_eps_clip", None) - eps_clip_high: Optional[float] = getattr(args, prefix + "_tis_eps_clip_high", None) - veto_threshold: Optional[float] = getattr(args, prefix + "_tis_veto_threshold", None) - - assert eps_clip is not None and eps_clip_high is not None, "eps_clip and eps_clip_high must be provided" + level: str = args.train_infer_tis_level + metrics: Dict[str, Any] = {} # Validate input lists have same length and each sequence has matching shapes assert ( @@ -92,12 +125,11 @@ def compute_tis_weights( old.shape == new.shape == mask.shape ), f"Sequence {i}: shapes must match - old: {old.shape}, new: {new.shape}, mask: {mask.shape}" - # Get device from first tensor - device = old_log_probs[0].device - + # TODO: Get device from first tensor and apply to tensors + # device = old_log_probs[0].device + SAFETY_BOUND = 20.0 all_weights = [] - # Process each sequence individually for old_log_prob, new_log_prob, loss_mask in zip(old_log_probs, new_log_probs, loss_masks): raw_log_ratio = old_log_prob - new_log_prob @@ -118,27 +150,11 @@ def compute_tis_weights( log_ratio_safe = torch.clamp(log_ratio_for_metrics, min=-SAFETY_BOUND, max=SAFETY_BOUND) weights = torch.exp(log_ratio_safe) - if veto_threshold is not None: - log_veto_threshold = torch.log(torch.tensor(veto_threshold, device=device)) - # For each sequence, check if it has any catastrophic tokens - catastrophic_tokens = ( - (log_ratio_for_metrics < log_veto_threshold) | (log_ratio_for_metrics > 1 / log_veto_threshold) - ) & loss_mask.bool() - has_catastrophic = catastrophic_tokens.any() - # Create veto mask: 0 if sequence has catastrophic tokens, 1 otherwise - veto_mask = (~has_catastrophic).float() - else: - catastrophic_tokens = torch.zeros_like(log_ratio_for_metrics, dtype=torch.bool) - has_catastrophic = False - veto_mask = 
torch.ones_like(weights) + veto_mask = calculate_veto_mask(log_ratio_for_metrics, loss_mask, args.train_infer_tis_veto_threshold, metrics) + loss_mask = loss_mask & veto_mask # mask out catastrophic tokens - if mode == "clip": - weights = weights.clamp(1 - eps_clip, 1 + eps_clip_high) - # TODO: which mode should be masked importance sampling? - else: - raise ValueError(f"Invalid tis mode: {mode}") + weights = tis_function(weights, loss_mask, metrics) - weights = weights * veto_mask weights = weights * loss_mask weights = weights.detach() From 2d9ffd893305ca18a65ddb07db82c307e1a2c720 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Mon, 13 Oct 2025 05:08:02 +0000 Subject: [PATCH 27/58] fix seq concat problem --- slime/backends/megatron_utils/loss.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index c41477e77b..4db73fcc9e 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -17,12 +17,7 @@ ) from slime.utils.tis import clip, clip_to_zero, compute_train_infer_tis_weights, truncate -from .cp_utils import ( - all_gather_with_cp, - get_logits_and_tokens_offset_with_cp, - get_sum_of_sample_mean, - slice_log_prob_with_cp, -) +from .cp_utils import all_gather_with_cp, get_logits_and_tokens_offset_with_cp, get_sum_of_sample_mean def get_responses( @@ -367,17 +362,7 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): ois = (-ppo_kl).exp() - # tis_weights is a 1D tensor, should be sliced to the local cp rank - local_tis_chunks = [] - start = 0 - for total_len, response_len in zip(total_lengths, response_lengths): - end = start + int(response_len) - seq_weights = tis_weights[start:end] - # Slice to the two local chunks of this CP rank - local_chunk = slice_log_prob_with_cp(seq_weights, int(total_len), int(response_len)) - local_tis_chunks.append(local_chunk) - start = end - tis_weights = 
torch.cat(local_tis_chunks, dim=0) + tis_weights = torch.cat(tis_weights, dim=0) pg_loss = pg_loss * tis_weights From f44a7d1ffff742d131c8a7cf1bae7bce908386fb Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Mon, 13 Oct 2025 05:41:51 +0000 Subject: [PATCH 28/58] fix key not exist --- slime/utils/tis.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/slime/utils/tis.py b/slime/utils/tis.py index 8e59f5b925..ae9e4dd9d3 100644 --- a/slime/utils/tis.py +++ b/slime/utils/tis.py @@ -21,6 +21,12 @@ def masked_mean( return total / (denom + eps) +def metrics_add(metrics: Dict[str, Any], key: str, value: float) -> None: + if key not in metrics: + metrics[key] = 0 + metrics[key] += value + + def calculate_veto_mask( log_ratio_for_metrics: torch.Tensor, loss_mask: torch.Tensor, @@ -39,31 +45,31 @@ def calculate_veto_mask( veto_mask = (~has_catastrophic).float().expand_as(log_ratio_for_metrics) # Update metrics - metrics["catastrophic_ratio"] += masked_mean(has_catastrophic.int(), loss_mask) + metrics_add(metrics, "catastrophic_ratio", masked_mean(has_catastrophic.int(), loss_mask)) return veto_mask def truncate(weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, Any], *, eps: float) -> torch.Tensor: - metrics["mean"] += masked_mean(weights, loss_mask) - metrics["truncate_fraction"] += masked_mean((weights > eps).int(), loss_mask) + metrics_add(metrics, "mean", masked_mean(weights, loss_mask)) + metrics_add(metrics, "truncate_fraction", masked_mean((weights > eps).int(), loss_mask)) return weights.clamp(0, eps) * loss_mask def clip( weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, Any], *, eps_clip: float, eps_clip_high: float ) -> torch.Tensor: - metrics["mean"] += masked_mean(weights, loss_mask) - metrics["clip_fraction_low"] += masked_mean((weights < 1 - eps_clip).int(), loss_mask) - metrics["clip_fraction_high"] += masked_mean((weights > 1 + eps_clip_high).int(), loss_mask) + metrics_add(metrics, 
"mean", masked_mean(weights, loss_mask)) + metrics_add(metrics, "clip_fraction_low", masked_mean((weights < 1 - eps_clip).int(), loss_mask)) + metrics_add(metrics, "clip_fraction_high", masked_mean((weights > 1 + eps_clip_high).int(), loss_mask)) return weights.clamp(1 - eps_clip, 1 + eps_clip_high) def clip_to_zero( weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, Any], *, eps_clip: float, eps_clip_high: float ) -> torch.Tensor: - metrics["mean"] += masked_mean(weights, loss_mask) - metrics["clip_fraction_low"] += masked_mean((weights < 1 - eps_clip).int(), loss_mask) - metrics["clip_fraction_high"] += masked_mean((weights > 1 + eps_clip_high).int(), loss_mask) + metrics_add(metrics, "mean", masked_mean(weights, loss_mask)) + metrics_add(metrics, "clip_fraction_low", masked_mean((weights < 1 - eps_clip).int(), loss_mask)) + metrics_add(metrics, "clip_fraction_high", masked_mean((weights > 1 + eps_clip_high).int(), loss_mask)) clip_mask = (weights >= 1 - eps_clip) & (weights <= 1 + eps_clip_high) return weights * clip_mask From 8d77db6da9b7a87d8eddaf32fc8f2e70a59d9a97 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Mon, 13 Oct 2025 18:11:43 +0000 Subject: [PATCH 29/58] update metrics calculation --- slime/backends/megatron_utils/loss.py | 14 +- slime/utils/tis.py | 209 ++++++++++++++------------ 2 files changed, 117 insertions(+), 106 deletions(-) diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 4db73fcc9e..12a89fac9b 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -353,17 +353,15 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): tis_weights, tis_metrics = compute_train_infer_tis_weights( args=args, - new_log_probs=full_old_log_probs, - old_log_probs=full_rollout_log_probs, + train_log_probs=full_old_log_probs, + rollout_log_probs=full_rollout_log_probs, loss_masks=batch["loss_masks"], response_lengths=response_lengths, 
tis_function=tis_function, ) ois = (-ppo_kl).exp() - - tis_weights = torch.cat(tis_weights, dim=0) - + tis = tis_weights pg_loss = pg_loss * tis_weights pg_loss = sum_of_sample_mean(pg_loss) @@ -407,12 +405,10 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): if args.use_train_infer_tis: # Backward compatible basic logs reported_loss["ois"] = sum_of_sample_mean(ois).clone().detach() + reported_loss["tis"] = sum_of_sample_mean(tis).clone().detach() for metric_key, metric_value in tis_metrics.items(): key_name = f"train_infer_{metric_key}" - if torch.is_tensor(metric_value): - reported_loss[key_name] = metric_value.clone().detach() - elif isinstance(metric_value, (int, float)): - reported_loss[key_name] = torch.tensor(metric_value, device=log_probs.device) + reported_loss[key_name] = sum_of_sample_mean(metric_value).clone().detach() return loss, reported_loss diff --git a/slime/utils/tis.py b/slime/utils/tis.py index ae9e4dd9d3..cbbed9dad5 100644 --- a/slime/utils/tis.py +++ b/slime/utils/tis.py @@ -1,12 +1,13 @@ from typing import Any, Callable, Dict, Optional, Tuple import torch +from slime.backends.megatron_utils.cp_utils import slice_with_cp def masked_sum( tensor: torch.Tensor, mask: torch.Tensor, dim: Optional[int] = None, keepdim: bool = False ) -> torch.Tensor: - return (tensor * mask.float()).sum(dim=dim, keepdim=keepdim) + return (tensor * mask).sum(dim=dim, keepdim=keepdim) def masked_mean( @@ -16,8 +17,8 @@ def masked_mean( keepdim: bool = False, eps: float = 1e-8, ) -> torch.Tensor: - total = (tensor * mask.float()).sum(dim=dim, keepdim=keepdim) - denom = mask.float().sum(dim=dim, keepdim=keepdim) + total = (tensor * mask).sum(dim=dim, keepdim=keepdim) + denom = mask.sum(dim=dim, keepdim=keepdim) return total / (denom + eps) @@ -27,6 +28,18 @@ def metrics_add(metrics: Dict[str, Any], key: str, value: float) -> None: metrics[key] += value +def metrics_append(metrics: Dict[str, Any], key: str, value: torch.Tensor) -> None: + if key not 
in metrics: + metrics[key] = [] + metrics[key].append(slice_with_cp(value.clone().detach(), 0)) + + +def metrics_concat(metrics: Dict[str, Any]) -> None: + for key, values in metrics.items(): + assert isinstance(values, list), f"Metric {key} is not a list" + metrics[key] = torch.cat(values, dim=0) + + def calculate_veto_mask( log_ratio_for_metrics: torch.Tensor, loss_mask: torch.Tensor, @@ -36,40 +49,36 @@ def calculate_veto_mask( if veto_threshold is None: return torch.ones_like(log_ratio_for_metrics) log_veto_threshold = torch.log(torch.tensor(veto_threshold, device=log_ratio_for_metrics.device)) - # For each sequence, check if it has any catastrophic tokens + # For each sequence, if it has any catastrophic tokens, return 0 for the sequence catastrophic_tokens = ( - (log_ratio_for_metrics < log_veto_threshold) | (log_ratio_for_metrics > 1 / log_veto_threshold) + (log_ratio_for_metrics < log_veto_threshold) | (log_ratio_for_metrics > -log_veto_threshold) ) & loss_mask.bool() has_catastrophic = catastrophic_tokens.any() - # Create veto mask: 0 if sequence has catastrophic tokens, 1 otherwise veto_mask = (~has_catastrophic).float().expand_as(log_ratio_for_metrics) # Update metrics - metrics_add(metrics, "catastrophic_ratio", masked_mean(has_catastrophic.int(), loss_mask)) + metrics_append(metrics, "catastrophic_fraction", has_catastrophic.int().expand_as(loss_mask)) return veto_mask def truncate(weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, Any], *, eps: float) -> torch.Tensor: - metrics_add(metrics, "mean", masked_mean(weights, loss_mask)) - metrics_add(metrics, "truncate_fraction", masked_mean((weights > eps).int(), loss_mask)) - return weights.clamp(0, eps) * loss_mask + metrics_append(metrics, "truncate_fraction", (weights > eps).int()) + return weights.clamp(0, eps) def clip( weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, Any], *, eps_clip: float, eps_clip_high: float ) -> torch.Tensor: - metrics_add(metrics, "mean", 
masked_mean(weights, loss_mask)) - metrics_add(metrics, "clip_fraction_low", masked_mean((weights < 1 - eps_clip).int(), loss_mask)) - metrics_add(metrics, "clip_fraction_high", masked_mean((weights > 1 + eps_clip_high).int(), loss_mask)) + metrics_append(metrics, "clip_fraction_low", (weights < 1 - eps_clip).int()) + metrics_append(metrics, "clip_fraction_high", (weights > 1 + eps_clip_high).int()) return weights.clamp(1 - eps_clip, 1 + eps_clip_high) def clip_to_zero( weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, Any], *, eps_clip: float, eps_clip_high: float ) -> torch.Tensor: - metrics_add(metrics, "mean", masked_mean(weights, loss_mask)) - metrics_add(metrics, "clip_fraction_low", masked_mean((weights < 1 - eps_clip).int(), loss_mask)) - metrics_add(metrics, "clip_fraction_high", masked_mean((weights > 1 + eps_clip_high).int(), loss_mask)) + metrics_append(metrics, "clip_fraction_low", (weights < 1 - eps_clip).int()) + metrics_append(metrics, "clip_fraction_high", (weights > 1 + eps_clip_high).int()) clip_mask = (weights >= 1 - eps_clip) & (weights <= 1 + eps_clip_high) return weights * clip_mask @@ -77,12 +86,12 @@ def clip_to_zero( def compute_train_infer_tis_weights( args, *, - new_log_probs: list[torch.Tensor], - old_log_probs: list[torch.Tensor], + train_log_probs: list[torch.Tensor], + rollout_log_probs: list[torch.Tensor], loss_masks: list[torch.Tensor], response_lengths: Optional[list[int]] = None, prefix: str = "", - tis_function: Callable[[torch.Tensor, torch.Tensor, Dict[str, Any]], torch.Tensor] = None, + tis_function: Callable[[torch.Tensor, torch.Tensor, Dict[str, Any]], torch.Tensor], ) -> Tuple[list[torch.Tensor], Dict[str, Any]]: """ Compute the truncated importance sampling (TIS) weights and metrics. 
@@ -93,14 +102,8 @@ def compute_train_infer_tis_weights( https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda Args: - new_log_probs: List of log probs from new policy, one tensor per sequence. - old_log_probs: List of log probs from old policy, one tensor per sequence. - - under training/inference tis - - new_log_probs = training backend - - old_log_probs = rollout backend - - under mini batch tis - - new_log_probs = new batch - - old_log_probs = old batch + train_log_probs: List of log probs from training backend, one tensor per sequence. + rollout_log_probs: List of log probs from inference backend, one tensor per sequence. loss_masks: List of loss masks, one tensor per sequence. Note that for single turn RL, the loss_mask is [1] * response_length for each sequence For multi turn RL, the tool response will be marked as 0 in the loss_mask. @@ -123,21 +126,22 @@ def compute_train_infer_tis_weights( # Validate input lists have same length and each sequence has matching shapes assert ( - len(old_log_probs) == len(new_log_probs) == len(loss_masks) - ), f"Input lists must have same length: {len(old_log_probs)} vs {len(new_log_probs)} vs {len(loss_masks)}" + len(train_log_probs) == len(rollout_log_probs) == len(loss_masks) + ), f"Input lists must have same length: {len(train_log_probs)} vs {len(rollout_log_probs)} vs {len(loss_masks)}" - for i, (old, new, mask) in enumerate(zip(old_log_probs, new_log_probs, loss_masks)): + for i, (train, rollout, mask) in enumerate(zip(train_log_probs, rollout_log_probs, loss_masks)): assert ( - old.shape == new.shape == mask.shape - ), f"Sequence {i}: shapes must match - old: {old.shape}, new: {new.shape}, mask: {mask.shape}" + train.shape == rollout.shape == mask.shape + ), f"Sequence {i}: shapes must match - train: {train.shape}, rollout: {rollout.shape}, mask: {mask.shape}" # TODO: Get device from first tensor and apply to tensors - # 
device = old_log_probs[0].device + # device = train_log_probs[0].device SAFETY_BOUND = 20.0 all_weights = [] - for old_log_prob, new_log_prob, loss_mask in zip(old_log_probs, new_log_probs, loss_masks): - raw_log_ratio = old_log_prob - new_log_prob + for train_log_prob, rollout_log_prob, loss_mask in zip(train_log_probs, rollout_log_probs, loss_masks): + raw_log_ratio = train_log_prob - rollout_log_prob + loss_mask = loss_mask.float() if level == "token": # Token-level IS @@ -156,70 +160,81 @@ def compute_train_infer_tis_weights( log_ratio_safe = torch.clamp(log_ratio_for_metrics, min=-SAFETY_BOUND, max=SAFETY_BOUND) weights = torch.exp(log_ratio_safe) - veto_mask = calculate_veto_mask(log_ratio_for_metrics, loss_mask, args.train_infer_tis_veto_threshold, metrics) - loss_mask = loss_mask & veto_mask # mask out catastrophic tokens + # mask out catastrophic tokens + if args.train_infer_tis_veto_threshold is not None: + veto_mask = calculate_veto_mask( + log_ratio_for_metrics, loss_mask, args.train_infer_tis_veto_threshold, metrics + ) + metrics_append(metrics, "raw_ratio_mean", weights) weights = tis_function(weights, loss_mask, metrics) - weights = weights * loss_mask - weights = weights.detach() + metrics_append(metrics, "ratio_mean_after_tis", weights) + if args.train_infer_tis_veto_threshold is not None: + weights = weights * veto_mask + metrics_append(metrics, "ratio_mean_after_veto_mask", weights) - all_weights.append(weights) - - return weights, metrics - - -def compute_kl_metrics( - *, - old_log_prob: torch.Tensor, - rollout_log_prob: torch.Tensor, - loss_mask: Optional[torch.Tensor], - response_lengths: Optional[list[int]] = None, -) -> Dict[str, Any]: - metrics: Dict[str, Any] = {} + weights = weights.detach() - device = old_log_prob.device - if loss_mask is None: - loss_mask = torch.ones_like(old_log_prob, dtype=torch.bool, device=device) - - # Direct estimator for KL(pi_rollout || pi_old): per-seq mean then sum (1D inputs only) - assert response_lengths 
is not None and loss_mask is not None - sequence_log_ratios = torch.split(rollout_log_prob - old_log_prob, [int(l) for l in response_lengths], dim=0) - sequence_loss_masks = torch.split(loss_mask, [int(l) for l in response_lengths], dim=0) - per_seq = [ - masked_mean(sequence_log_ratio, sequence_loss_mask) - for sequence_log_ratio, sequence_loss_mask in zip(sequence_log_ratios, sequence_loss_masks) - ] - metrics["rollout_kl"] = torch.stack(per_seq).sum() - - # K3 estimator: E[exp(log(pi_old/pi_rollout)) - log(pi_old/pi_rollout) - 1] - log_ratio = old_log_prob - rollout_log_prob - k3_matrix = torch.exp(log_ratio) - log_ratio - 1 - sequence_log_ratios = torch.split(k3_matrix, [int(l) for l in response_lengths], dim=0) - sequence_loss_masks = torch.split(loss_mask, [int(l) for l in response_lengths], dim=0) - per_seq = [ - masked_mean(sequence_log_ratio, sequence_loss_mask) - for sequence_log_ratio, sequence_loss_mask in zip(sequence_log_ratios, sequence_loss_masks) - ] - metrics["rollout_k3_kl"] = torch.stack(per_seq).sum() - - # Sequence-level perplexity difference metrics - assert response_lengths is not None and len(response_lengths) > 0 - seq_rollout_means = [] - seq_old_means = [] - start = 0 - for length in response_lengths: - end = start + int(length) - mask_chunk = loss_mask[start:end] - seq_rollout_means.append(masked_mean(rollout_log_prob[start:end], mask_chunk)) - seq_old_means.append(masked_mean(old_log_prob[start:end], mask_chunk)) - start = end - mean_log_prob_rollout_per_seq = torch.stack(seq_rollout_means) - mean_log_prob_old_per_seq = torch.stack(seq_old_means) - - diff = mean_log_prob_rollout_per_seq - mean_log_prob_old_per_seq - # report sums; external reducer divides by num_samples - metrics["log_ppl_diff"] = diff.sum() - metrics["log_ppl_abs_diff"] = diff.abs().sum() - - return metrics + all_weights.append(slice_with_cp(weights, 0)) + + all_weights = torch.cat(all_weights, dim=0) + metrics_concat(metrics) + + return all_weights, metrics + + +# 
def compute_kl_metrics( +# *, +# train_log_prob: torch.Tensor, +# rollout_log_prob: torch.Tensor, +# loss_mask: Optional[torch.Tensor], +# response_lengths: Optional[list[int]] = None, +# ) -> Dict[str, Any]: +# metrics: Dict[str, Any] = {} + +# device = train_log_prob.device +# if loss_mask is None: +# loss_mask = torch.ones_like(train_log_prob, dtype=torch.bool, device=device) + +# # Direct estimator for KL(pi_rollout || pi_old): per-seq mean then sum (1D inputs only) +# assert response_lengths is not None and loss_mask is not None +# sequence_log_ratios = torch.split(rollout_log_prob - train_log_prob, [int(l) for l in response_lengths], dim=0) +# sequence_loss_masks = torch.split(loss_mask, [int(l) for l in response_lengths], dim=0) +# per_seq = [ +# masked_mean(sequence_log_ratio, sequence_loss_mask) +# for sequence_log_ratio, sequence_loss_mask in zip(sequence_log_ratios, sequence_loss_masks) +# ] +# metrics["rollout_kl"] = torch.stack(per_seq).sum() + +# # K3 estimator: E[exp(log(pi_old/pi_rollout)) - log(pi_old/pi_rollout) - 1] +# log_ratio = train_log_prob - rollout_log_prob +# k3_matrix = torch.exp(log_ratio) - log_ratio - 1 +# sequence_log_ratios = torch.split(k3_matrix, [int(l) for l in response_lengths], dim=0) +# sequence_loss_masks = torch.split(loss_mask, [int(l) for l in response_lengths], dim=0) +# per_seq = [ +# masked_mean(sequence_log_ratio, sequence_loss_mask) +# for sequence_log_ratio, sequence_loss_mask in zip(sequence_log_ratios, sequence_loss_masks) +# ] +# metrics["rollout_k3_kl"] = torch.stack(per_seq).sum() + +# # Sequence-level perplexity difference metrics +# assert response_lengths is not None and len(response_lengths) > 0 +# seq_rollout_means = [] +# seq_train_means = [] +# start = 0 +# for length in response_lengths: +# end = start + int(length) +# mask_chunk = loss_mask[start:end] +# seq_rollout_means.append(masked_mean(rollout_log_prob[start:end], mask_chunk)) +# seq_train_means.append(masked_mean(train_log_prob[start:end], 
mask_chunk)) +# start = end +# mean_log_prob_rollout_per_seq = torch.stack(seq_rollout_means) +# mean_log_prob_train_per_seq = torch.stack(seq_train_means) + +# diff = mean_log_prob_rollout_per_seq - mean_log_prob_train_per_seq +# # report sums; external reducer divides by num_samples +# metrics["log_ppl_diff"] = diff.sum() +# metrics["log_ppl_abs_diff"] = diff.abs().sum() + +# return metrics From 88f24dd0a2796d088a08a43ffb1c5638c10542ca Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Mon, 13 Oct 2025 23:50:07 +0000 Subject: [PATCH 30/58] fix cp>1 with scatter_with_cp func and update argument --- slime/backends/megatron_utils/cp_utils.py | 26 +++++++ slime/backends/megatron_utils/loss.py | 3 +- slime/utils/arguments.py | 2 +- slime/utils/tis.py | 95 ++++++----------------- 4 files changed, 50 insertions(+), 76 deletions(-) diff --git a/slime/backends/megatron_utils/cp_utils.py b/slime/backends/megatron_utils/cp_utils.py index 3ba6adfdf2..8d44cd52ce 100644 --- a/slime/backends/megatron_utils/cp_utils.py +++ b/slime/backends/megatron_utils/cp_utils.py @@ -155,6 +155,32 @@ def zero(len): return full_tensor +def scatter_with_cp(full_tensor: torch.Tensor, total_length: int, response_length: int) -> torch.Tensor: + """ + Inverse of `all_gather_with_cp`. Returns the chunk for the current CP rank. 
+ """ + cp_size = mpu.get_context_parallel_world_size() + if cp_size == 1: + return full_tensor + + prompt_length = total_length - response_length + _, _, logits_offset, _ = get_logits_and_tokens_offset_with_cp(total_length, response_length) + + def _slice(offset: tuple[int, int]) -> torch.Tensor: + start, end = offset + if start >= end: + return full_tensor.new_empty((0,) + full_tensor.shape[1:]) + start -= prompt_length - 1 + end -= prompt_length - 1 + if start >= end: + return full_tensor.new_empty((0,) + full_tensor.shape[1:]) + return full_tensor[start:end] + + chunk_0 = _slice(logits_offset[0]) + chunk_1 = _slice(logits_offset[1]) + return torch.cat([chunk_0, chunk_1], dim=0) + + def slice_with_cp(tokens: torch.Tensor, pad_value): cp_rank = mpu.get_context_parallel_rank() cp_size = mpu.get_context_parallel_world_size() diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 12a89fac9b..3cfd533e98 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -356,12 +356,12 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): train_log_probs=full_old_log_probs, rollout_log_probs=full_rollout_log_probs, loss_masks=batch["loss_masks"], + total_lengths=total_lengths, response_lengths=response_lengths, tis_function=tis_function, ) ois = (-ppo_kl).exp() - tis = tis_weights pg_loss = pg_loss * tis_weights pg_loss = sum_of_sample_mean(pg_loss) @@ -405,7 +405,6 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): if args.use_train_infer_tis: # Backward compatible basic logs reported_loss["ois"] = sum_of_sample_mean(ois).clone().detach() - reported_loss["tis"] = sum_of_sample_mean(tis).clone().detach() for metric_key, metric_value in tis_metrics.items(): key_name = f"train_infer_{metric_key}" reported_loss[key_name] = sum_of_sample_mean(metric_value).clone().detach() diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index c2bfcc6571..3ed945e39f 
100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -73,7 +73,7 @@ def add_tis_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--train-infer-tis-veto-threshold", type=float, - default=1e-4, + default=None, help=("Per-token veto threshold. If any token ratio < this, zero the entire sequence weight."), ) diff --git a/slime/utils/tis.py b/slime/utils/tis.py index cbbed9dad5..b044305679 100644 --- a/slime/utils/tis.py +++ b/slime/utils/tis.py @@ -1,7 +1,7 @@ from typing import Any, Callable, Dict, Optional, Tuple import torch -from slime.backends.megatron_utils.cp_utils import slice_with_cp +from slime.backends.megatron_utils.cp_utils import scatter_with_cp def masked_sum( @@ -31,13 +31,14 @@ def metrics_add(metrics: Dict[str, Any], key: str, value: float) -> None: def metrics_append(metrics: Dict[str, Any], key: str, value: torch.Tensor) -> None: if key not in metrics: metrics[key] = [] - metrics[key].append(slice_with_cp(value.clone().detach(), 0)) + metrics[key].append(value.clone().detach()) -def metrics_concat(metrics: Dict[str, Any]) -> None: - for key, values in metrics.items(): - assert isinstance(values, list), f"Metric {key} is not a list" - metrics[key] = torch.cat(values, dim=0) +def scatter_cp_and_concat( + values: list[torch.Tensor], total_lengths: list[int], response_lengths: list[int] +) -> list[torch.Tensor]: + values = [scatter_with_cp(values[i], total_lengths[i], response_lengths[i]) for i in range(len(values))] + return torch.cat(values, dim=0) def calculate_veto_mask( @@ -56,14 +57,13 @@ def calculate_veto_mask( has_catastrophic = catastrophic_tokens.any() veto_mask = (~has_catastrophic).float().expand_as(log_ratio_for_metrics) - # Update metrics metrics_append(metrics, "catastrophic_fraction", has_catastrophic.int().expand_as(loss_mask)) return veto_mask def truncate(weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, Any], *, eps: float) -> torch.Tensor: metrics_append(metrics, 
"truncate_fraction", (weights > eps).int()) - return weights.clamp(0, eps) + return weights.clamp(0, eps) * loss_mask def clip( @@ -71,7 +71,7 @@ def clip( ) -> torch.Tensor: metrics_append(metrics, "clip_fraction_low", (weights < 1 - eps_clip).int()) metrics_append(metrics, "clip_fraction_high", (weights > 1 + eps_clip_high).int()) - return weights.clamp(1 - eps_clip, 1 + eps_clip_high) + return weights.clamp(1 - eps_clip, 1 + eps_clip_high) * loss_mask def clip_to_zero( @@ -80,7 +80,7 @@ def clip_to_zero( metrics_append(metrics, "clip_fraction_low", (weights < 1 - eps_clip).int()) metrics_append(metrics, "clip_fraction_high", (weights > 1 + eps_clip_high).int()) clip_mask = (weights >= 1 - eps_clip) & (weights <= 1 + eps_clip_high) - return weights * clip_mask + return weights * clip_mask * loss_mask def compute_train_infer_tis_weights( @@ -89,8 +89,8 @@ def compute_train_infer_tis_weights( train_log_probs: list[torch.Tensor], rollout_log_probs: list[torch.Tensor], loss_masks: list[torch.Tensor], + total_lengths: Optional[list[int]] = None, response_lengths: Optional[list[int]] = None, - prefix: str = "", tis_function: Callable[[torch.Tensor, torch.Tensor, Dict[str, Any]], torch.Tensor], ) -> Tuple[list[torch.Tensor], Dict[str, Any]]: """ @@ -108,7 +108,6 @@ def compute_train_infer_tis_weights( Note that for single turn RL, the loss_mask is [1] * response_length for each sequence For multi turn RL, the tool response will be marked as 0 in the loss_mask. response_lengths: The length of the response for each sequence. - prefix: The prefix for the parameters, indicating which tis is used. Returns: weights: The importance sampling weights. 
[batch_size, seq_len] @@ -129,6 +128,12 @@ def compute_train_infer_tis_weights( len(train_log_probs) == len(rollout_log_probs) == len(loss_masks) ), f"Input lists must have same length: {len(train_log_probs)} vs {len(rollout_log_probs)} vs {len(loss_masks)}" + if total_lengths is not None: + assert response_lengths is not None, "response_lengths must be provided when total_lengths is set" + assert len(total_lengths) == len( + train_log_probs + ), f"total_lengths must match number of sequences, got {len(total_lengths)} vs {len(train_log_probs)}" + for i, (train, rollout, mask) in enumerate(zip(train_log_probs, rollout_log_probs, loss_masks)): assert ( train.shape == rollout.shape == mask.shape @@ -168,73 +173,17 @@ def compute_train_infer_tis_weights( metrics_append(metrics, "raw_ratio_mean", weights) weights = tis_function(weights, loss_mask, metrics) - weights = weights * loss_mask metrics_append(metrics, "ratio_mean_after_tis", weights) if args.train_infer_tis_veto_threshold is not None: weights = weights * veto_mask metrics_append(metrics, "ratio_mean_after_veto_mask", weights) weights = weights.detach() + all_weights.append(weights) - all_weights.append(slice_with_cp(weights, 0)) - - all_weights = torch.cat(all_weights, dim=0) - metrics_concat(metrics) + all_weights = scatter_cp_and_concat(all_weights, total_lengths, response_lengths) + for key, values in metrics.items(): + values = scatter_cp_and_concat(values, total_lengths, response_lengths) + metrics[key] = values return all_weights, metrics - - -# def compute_kl_metrics( -# *, -# train_log_prob: torch.Tensor, -# rollout_log_prob: torch.Tensor, -# loss_mask: Optional[torch.Tensor], -# response_lengths: Optional[list[int]] = None, -# ) -> Dict[str, Any]: -# metrics: Dict[str, Any] = {} - -# device = train_log_prob.device -# if loss_mask is None: -# loss_mask = torch.ones_like(train_log_prob, dtype=torch.bool, device=device) - -# # Direct estimator for KL(pi_rollout || pi_old): per-seq mean then sum (1D 
inputs only) -# assert response_lengths is not None and loss_mask is not None -# sequence_log_ratios = torch.split(rollout_log_prob - train_log_prob, [int(l) for l in response_lengths], dim=0) -# sequence_loss_masks = torch.split(loss_mask, [int(l) for l in response_lengths], dim=0) -# per_seq = [ -# masked_mean(sequence_log_ratio, sequence_loss_mask) -# for sequence_log_ratio, sequence_loss_mask in zip(sequence_log_ratios, sequence_loss_masks) -# ] -# metrics["rollout_kl"] = torch.stack(per_seq).sum() - -# # K3 estimator: E[exp(log(pi_old/pi_rollout)) - log(pi_old/pi_rollout) - 1] -# log_ratio = train_log_prob - rollout_log_prob -# k3_matrix = torch.exp(log_ratio) - log_ratio - 1 -# sequence_log_ratios = torch.split(k3_matrix, [int(l) for l in response_lengths], dim=0) -# sequence_loss_masks = torch.split(loss_mask, [int(l) for l in response_lengths], dim=0) -# per_seq = [ -# masked_mean(sequence_log_ratio, sequence_loss_mask) -# for sequence_log_ratio, sequence_loss_mask in zip(sequence_log_ratios, sequence_loss_masks) -# ] -# metrics["rollout_k3_kl"] = torch.stack(per_seq).sum() - -# # Sequence-level perplexity difference metrics -# assert response_lengths is not None and len(response_lengths) > 0 -# seq_rollout_means = [] -# seq_train_means = [] -# start = 0 -# for length in response_lengths: -# end = start + int(length) -# mask_chunk = loss_mask[start:end] -# seq_rollout_means.append(masked_mean(rollout_log_prob[start:end], mask_chunk)) -# seq_train_means.append(masked_mean(train_log_prob[start:end], mask_chunk)) -# start = end -# mean_log_prob_rollout_per_seq = torch.stack(seq_rollout_means) -# mean_log_prob_train_per_seq = torch.stack(seq_train_means) - -# diff = mean_log_prob_rollout_per_seq - mean_log_prob_train_per_seq -# # report sums; external reducer divides by num_samples -# metrics["log_ppl_diff"] = diff.sum() -# metrics["log_ppl_abs_diff"] = diff.abs().sum() - -# return metrics From 3b687082f27aaad0e2192fabcd369860f4c48ae7 Mon Sep 17 00:00:00 2001 
From: Jiajun Li Date: Tue, 14 Oct 2025 01:09:59 +0000 Subject: [PATCH 31/58] move cp scatter to loss.py --- slime/backends/megatron_utils/loss.py | 18 +++++++--- slime/utils/arguments.py | 9 ++--- slime/utils/tis.py | 52 +++++++++------------------ 3 files changed, 33 insertions(+), 46 deletions(-) diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 3cfd533e98..8a122e4458 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -17,7 +17,7 @@ ) from slime.utils.tis import clip, clip_to_zero, compute_train_infer_tis_weights, truncate -from .cp_utils import all_gather_with_cp, get_logits_and_tokens_offset_with_cp, get_sum_of_sample_mean +from .cp_utils import all_gather_with_cp, get_logits_and_tokens_offset_with_cp, get_sum_of_sample_mean, scatter_with_cp def get_responses( @@ -356,11 +356,21 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): train_log_probs=full_old_log_probs, rollout_log_probs=full_rollout_log_probs, loss_masks=batch["loss_masks"], - total_lengths=total_lengths, - response_lengths=response_lengths, tis_function=tis_function, ) + def scatter_cp_and_concat( + values: list[torch.Tensor], total_lengths: list[int], response_lengths: list[int] + ) -> list[torch.Tensor]: + # reshape value to the sequence size of the cp rank. 
+ values = [scatter_with_cp(values[i], total_lengths[i], response_lengths[i]) for i in range(len(values))] + return torch.cat(values, dim=0) + + tis_weights = scatter_cp_and_concat(tis_weights, total_lengths, response_lengths) + for key, values in tis_metrics.items(): + values = scatter_cp_and_concat(values, total_lengths, response_lengths) + tis_metrics[key] = values + ois = (-ppo_kl).exp() pg_loss = pg_loss * tis_weights @@ -407,7 +417,7 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): reported_loss["ois"] = sum_of_sample_mean(ois).clone().detach() for metric_key, metric_value in tis_metrics.items(): key_name = f"train_infer_{metric_key}" - reported_loss[key_name] = sum_of_sample_mean(metric_value).clone().detach() + reported_loss[key_name] = sum_of_sample_mean(metric_value) return loss, reported_loss diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 3ed945e39f..b8fbe7cfae 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -59,22 +59,19 @@ def add_tis_arguments(parser: argparse.ArgumentParser): "--train-infer-tis-eps-clip", type=float, default=0.2, - help=( - "Lower clip threshold for IS weights. Default is 0.2. " - "For clip mode uses this value; for truncate mode remains unused." - ), + help=("Lower clip threshold for IS weights. Default is 0.2. "), ) parser.add_argument( "--train-infer-tis-eps-clip-high", type=float, default=None, - help=("Upper clip threshold for IS weights. Default is None."), + help=("Upper clip threshold for IS weights. Default is None. Truncate mode will not use this value."), ) parser.add_argument( "--train-infer-tis-veto-threshold", type=float, default=None, - help=("Per-token veto threshold. If any token ratio < this, zero the entire sequence weight."), + help=("Per-token veto threshold. 
If any token ratio < this or > 1/this, zero the entire sequence weight."), ) return parser diff --git a/slime/utils/tis.py b/slime/utils/tis.py index b044305679..adc89b35e9 100644 --- a/slime/utils/tis.py +++ b/slime/utils/tis.py @@ -1,7 +1,6 @@ from typing import Any, Callable, Dict, Optional, Tuple import torch -from slime.backends.megatron_utils.cp_utils import scatter_with_cp def masked_sum( @@ -22,25 +21,22 @@ def masked_mean( return total / (denom + eps) -def metrics_add(metrics: Dict[str, Any], key: str, value: float) -> None: - if key not in metrics: - metrics[key] = 0 - metrics[key] += value - - -def metrics_append(metrics: Dict[str, Any], key: str, value: torch.Tensor) -> None: +def metrics_append(metrics: Dict[str, list[torch.Tensor]], key: str, value: torch.Tensor) -> None: + """ + Any metric should be list[torch.Tensor] with size [num_seq, response_length] + All tis metrics will be aggregated and averaged by `sum_of_sample_mean` and megatron automatically + The result will be sequence-level average or token-level if `calculate_per_token_loss` is set. + You have no need to worry about loss_mask — the sum_of_sample_mean automatically ignores statistics where loss_mask = 0. + + e.g. + To calculate a token-level metric like the ratio of catastrophic tokens, just append the orignal ratio tensor to the metrics dict. + To calculate a sequence-level metric like the ratio of vetoed sequences, you should set every value in the tensor to be 0 or 1. 
+ """ if key not in metrics: metrics[key] = [] metrics[key].append(value.clone().detach()) -def scatter_cp_and_concat( - values: list[torch.Tensor], total_lengths: list[int], response_lengths: list[int] -) -> list[torch.Tensor]: - values = [scatter_with_cp(values[i], total_lengths[i], response_lengths[i]) for i in range(len(values))] - return torch.cat(values, dim=0) - - def calculate_veto_mask( log_ratio_for_metrics: torch.Tensor, loss_mask: torch.Tensor, @@ -57,7 +53,11 @@ def calculate_veto_mask( has_catastrophic = catastrophic_tokens.any() veto_mask = (~has_catastrophic).float().expand_as(log_ratio_for_metrics) - metrics_append(metrics, "catastrophic_fraction", has_catastrophic.int().expand_as(loss_mask)) + # TODO(jiajun): A single catastrophic token may not be enough to veto the entire sequence? + # May be we can set a threshold for the ratio of catastrophic tokens? + # If exceeds, veto the entire sequence. If not, only mask the catastrophic tokens. + metrics_append(metrics, "catastrophic_token_fraction", catastrophic_tokens.int()) + metrics_append(metrics, "catastrophic_seq_fraction", has_catastrophic.int().expand_as(loss_mask)) return veto_mask @@ -89,25 +89,16 @@ def compute_train_infer_tis_weights( train_log_probs: list[torch.Tensor], rollout_log_probs: list[torch.Tensor], loss_masks: list[torch.Tensor], - total_lengths: Optional[list[int]] = None, - response_lengths: Optional[list[int]] = None, tis_function: Callable[[torch.Tensor, torch.Tensor, Dict[str, Any]], torch.Tensor], ) -> Tuple[list[torch.Tensor], Dict[str, Any]]: """ Compute the truncated importance sampling (TIS) weights and metrics. - - Adapted from: - - https://fengyao.notion.site/off-policy-rl#279721e3f6c48092bbe2fcfe0e9c6b33 - https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda - Args: train_log_probs: List of log probs from training backend, one tensor per sequence. 
rollout_log_probs: List of log probs from inference backend, one tensor per sequence. loss_masks: List of loss masks, one tensor per sequence. Note that for single turn RL, the loss_mask is [1] * response_length for each sequence For multi turn RL, the tool response will be marked as 0 in the loss_mask. - response_lengths: The length of the response for each sequence. Returns: weights: The importance sampling weights. [batch_size, seq_len] @@ -128,12 +119,6 @@ def compute_train_infer_tis_weights( len(train_log_probs) == len(rollout_log_probs) == len(loss_masks) ), f"Input lists must have same length: {len(train_log_probs)} vs {len(rollout_log_probs)} vs {len(loss_masks)}" - if total_lengths is not None: - assert response_lengths is not None, "response_lengths must be provided when total_lengths is set" - assert len(total_lengths) == len( - train_log_probs - ), f"total_lengths must match number of sequences, got {len(total_lengths)} vs {len(train_log_probs)}" - for i, (train, rollout, mask) in enumerate(zip(train_log_probs, rollout_log_probs, loss_masks)): assert ( train.shape == rollout.shape == mask.shape @@ -181,9 +166,4 @@ def compute_train_infer_tis_weights( weights = weights.detach() all_weights.append(weights) - all_weights = scatter_cp_and_concat(all_weights, total_lengths, response_lengths) - for key, values in metrics.items(): - values = scatter_cp_and_concat(values, total_lengths, response_lengths) - metrics[key] = values - return all_weights, metrics From 74ea79195df64283a1562e351a8b537dfd532d4f Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Tue, 14 Oct 2025 01:44:43 +0000 Subject: [PATCH 32/58] revert fsdp tis --- scripts/run-qwen3-30B-A3B.sh | 10 ++++-- slime/backends/fsdp_utils/actor.py | 58 ++++++++---------------------- 2 files changed, 23 insertions(+), 45 deletions(-) diff --git a/scripts/run-qwen3-30B-A3B.sh b/scripts/run-qwen3-30B-A3B.sh index 146b45dc22..3511b9d8dd 100644 --- a/scripts/run-qwen3-30B-A3B.sh +++ b/scripts/run-qwen3-30B-A3B.sh @@ 
-86,8 +86,14 @@ GRPO_ARGS=( --entropy-coef 0.00 --eps-clip 0.2 --eps-clip-high 0.28 - --use-tis - --tis-level geometric +) + +TIS_ARGS=( + --use-train-infer-tis + --train-infer-tis-level token + --train-infer-tis-mode clip + --train-infer-tis-eps-clip 0.2 + --train-infer-tis-veto-threshold 1e-4 ) OPTIMIZER_ARGS=( diff --git a/slime/backends/fsdp_utils/actor.py b/slime/backends/fsdp_utils/actor.py index 79af768f1d..f0c954c10e 100644 --- a/slime/backends/fsdp_utils/actor.py +++ b/slime/backends/fsdp_utils/actor.py @@ -18,13 +18,11 @@ raise ImportError("FSDP v2 not available") import wandb - from slime.ray.train_actor import TrainRayActor from slime.utils.data import get_minimum_num_micro_batch_size, process_rollout_data from slime.utils.distributed_utils import get_gloo_group from slime.utils.ppo_utils import compute_approx_kl, compute_policy_loss from slime.utils.timer import Timer, timer -from slime.utils.tis import compute_kl_metrics, compute_tis_weights from slime.utils.wandb_utils import init_wandb_secondary from .data_packing import pack_sequences, unpack_sequences @@ -323,7 +321,8 @@ def train(self, rollout_id, rollout_data_ref): # Apply TIS before sample mean calculation if self.args.use_tis: # Initialize TIS variables - tis_weights = None + tis = None + tis_clipfrac = None ois = None # Apply TIS off-policy correction using importance sampling assert all( @@ -333,41 +332,17 @@ def train(self, rollout_id, rollout_data_ref): for batch in unpacked_batches ), "rollout_log_probs must be provided as non-empty torch.Tensor for TIS" - rollout_log_probs = torch.cat([batch["rollout_log_probs"] for batch in unpacked_batches], dim=0).to( - device=log_probs.device - ) - - # Build eos mask from loss masks - loss_mask = torch.cat(loss_masks, dim=0).to(device=log_probs.device) - - upper = self.args.tis_threshold_upper - lower = self.args.tis_threshold_lower - - tis_weights, tis_metrics = compute_tis_weights( - old_log_prob=old_log_probs, - rollout_log_prob=rollout_log_probs, 
- loss_mask=loss_mask, - level=getattr(self.args, "tis_level", "token"), - mode=getattr(self.args, "tis_mode", "truncate"), - upper_threshold=upper, - lower_threshold=lower, - veto_threshold=getattr(self.args, "tis_veto_threshold", 1e-4), - safety_bound=getattr(self.args, "tis_safety_bound", 20.0), - response_lengths=response_lengths, - ) + rollout_log_probs = torch.cat([batch["rollout_log_probs"] for batch in unpacked_batches], dim=0) + rollout_log_probs = rollout_log_probs.to(device=log_probs.device) + tis = torch.exp(old_log_probs - rollout_log_probs) ois = (-ppo_kl).exp() - - if tis_weights is not None: - pg_loss = pg_loss * tis_weights - - # KL metrics next to TIS metrics - kl_metrics = compute_kl_metrics( - old_log_prob=old_log_probs, - rollout_log_prob=rollout_log_probs, - loss_mask=loss_mask, - response_lengths=response_lengths, + tis_clip = torch.clamp( + tis, min=getattr(self.args, "tis_clip_low", 0.1), max=getattr(self.args, "tis_clip", 2.0) ) + tis_clipfrac = tis_clip != tis + + pg_loss = pg_loss * tis_clip pg_loss = sum_of_sample_mean(pg_loss, response_lengths, loss_masks) pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks) @@ -401,15 +376,12 @@ def train(self, rollout_id, rollout_data_ref): if self.args.use_kl_loss: reported["kl_loss"] = kl_loss.detach() - if self.args.use_tis and tis_weights is not None: + if self.args.use_tis and tis is not None: + reported["tis"] = sum_of_sample_mean(tis, response_lengths, loss_masks).detach() reported["ois"] = sum_of_sample_mean(ois, response_lengths, loss_masks).detach() - # Report all TIS and KL metrics uniformly, filtering out non-numeric values - for k, v in {**tis_metrics, **kl_metrics}.items(): - if torch.is_tensor(v): - reported[k] = v.detach() - elif isinstance(v, (int, float)): - reported[k] = torch.tensor(v, device=log_probs.device) - # Skip string and other non-numeric types + reported["tis_clipfrac"] = sum_of_sample_mean( + tis_clipfrac.float(), response_lengths, loss_masks + 
).detach() # Scale loss for gradient accumulation loss = loss * dist.get_world_size() / self.args.global_batch_size From 7e4ee128a8bdc6ddb2def80ec1e5a5ff7f12531a Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Tue, 14 Oct 2025 03:49:59 +0000 Subject: [PATCH 33/58] fix argument names from tis to is --- scripts/run-qwen3-30B-A3B.sh | 14 ++++---- slime/backends/megatron_utils/loss.py | 46 +++++++++++++-------------- slime/utils/arguments.py | 23 +++++++------- slime/utils/tis.py | 16 +++++----- 4 files changed, 50 insertions(+), 49 deletions(-) diff --git a/scripts/run-qwen3-30B-A3B.sh b/scripts/run-qwen3-30B-A3B.sh index 3511b9d8dd..37bad293c2 100644 --- a/scripts/run-qwen3-30B-A3B.sh +++ b/scripts/run-qwen3-30B-A3B.sh @@ -65,7 +65,7 @@ PERF_ARGS=( --tensor-model-parallel-size 2 --sequence-parallel --pipeline-model-parallel-size 1 - --context-parallel-size 1 + --context-parallel-size 2 --expert-model-parallel-size 4 --expert-tensor-parallel-size 1 @@ -88,12 +88,12 @@ GRPO_ARGS=( --eps-clip-high 0.28 ) -TIS_ARGS=( - --use-train-infer-tis - --train-infer-tis-level token - --train-infer-tis-mode clip - --train-infer-tis-eps-clip 0.2 - --train-infer-tis-veto-threshold 1e-4 +IS_ARGS=( + --use-train-infer-is + --train-infer-is-level token + --train-infer-is-mode clip + --train-infer-is-eps-clip 0.2 + --train-infer-is-veto-threshold 1e-4 ) OPTIMIZER_ARGS=( diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 8a122e4458..c40b84d57a 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -15,7 +15,7 @@ get_reinforce_plus_plus_baseline_advantages, get_reinforce_plus_plus_returns, ) -from slime.utils.tis import clip, clip_to_zero, compute_train_infer_tis_weights, truncate +from slime.utils.tis import clip, clip_mask, compute_train_infer_is_weights, truncate from .cp_utils import all_gather_with_cp, get_logits_and_tokens_offset_with_cp, get_sum_of_sample_mean, scatter_with_cp @@ -307,7 +307,7 @@ 
def policy_loss_function(args, batch, logits, sum_of_sample_mean): pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, args.eps_clip, args.eps_clip_high) # Apply TIS off-policy correction using importance sampling if enabled - if args.use_train_infer_tis: + if args.use_train_infer_is: assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS" rollout_log_probs = batch["rollout_log_probs"] @@ -326,37 +326,37 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): mode: how to handle the importance sampling weights exceeding the thresholds. - "truncated": cap the importance sampling weights at the upper threshold https://fengyao.notion.site/off-policy-rl#279721e3f6c48092bbe2fcfe0e9c6b33 - - "clip_to_zero": zero the importance sampling weights outside the [lower, upper] range. + - "clip_mask": zero the importance sampling weights outside the [lower, upper] range. https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda - "clip": clip the importance sampling weights to the [lower, upper] range. 
""" - if args.train_infer_tis_mode == "clip_to_zero": + if args.train_infer_is_mode == "clip_mask": assert ( - args.train_infer_tis_eps_clip is not None and args.train_infer_tis_eps_clip_high is not None + args.train_infer_is_eps_clip is not None and args.train_infer_is_eps_clip_high is not None ), "eps_clip and eps_clip_high must be provided" - tis_function = partial( - clip_to_zero, eps_clip=args.train_infer_tis_eps_clip, eps_clip_high=args.train_infer_tis_eps_clip_high + is_function = partial( + clip_mask, eps_clip=args.train_infer_is_eps_clip, eps_clip_high=args.train_infer_is_eps_clip_high ) - elif args.train_infer_tis_mode == "clip": + elif args.train_infer_is_mode == "clip": assert ( - args.train_infer_tis_eps_clip is not None and args.train_infer_tis_eps_clip_high is not None + args.train_infer_is_eps_clip is not None and args.train_infer_is_eps_clip_high is not None ), "eps_clip and eps_clip_high must be provided" - tis_function = partial( - clip, eps_clip=args.train_infer_tis_eps_clip, eps_clip_high=args.train_infer_tis_eps_clip_high + is_function = partial( + clip, eps_clip=args.train_infer_is_eps_clip, eps_clip_high=args.train_infer_is_eps_clip_high ) - elif args.train_infer_tis_mode == "truncate": - assert args.train_infer_tis_eps_clip is not None, "eps_clip must be provided" - tis_function = partial(truncate, eps=args.train_infer_tis_eps_clip) + elif args.train_infer_is_mode == "truncate": + assert args.train_infer_is_eps_clip is not None, "eps_clip must be provided" + is_function = partial(truncate, eps=args.train_infer_is_eps_clip) else: - raise ValueError(f"Unsupported train_infer_tis_mode: {args.train_infer_tis_mode}") + raise ValueError(f"Unsupported train_infer_is_mode: {args.train_infer_is_mode}") - tis_weights, tis_metrics = compute_train_infer_tis_weights( + is_weights, is_metrics = compute_train_infer_is_weights( args=args, train_log_probs=full_old_log_probs, rollout_log_probs=full_rollout_log_probs, loss_masks=batch["loss_masks"], - 
tis_function=tis_function, + is_function=is_function, ) def scatter_cp_and_concat( @@ -366,13 +366,13 @@ def scatter_cp_and_concat( values = [scatter_with_cp(values[i], total_lengths[i], response_lengths[i]) for i in range(len(values))] return torch.cat(values, dim=0) - tis_weights = scatter_cp_and_concat(tis_weights, total_lengths, response_lengths) - for key, values in tis_metrics.items(): + is_weights = scatter_cp_and_concat(is_weights, total_lengths, response_lengths) + for key, values in is_metrics.items(): values = scatter_cp_and_concat(values, total_lengths, response_lengths) - tis_metrics[key] = values + is_metrics[key] = values ois = (-ppo_kl).exp() - pg_loss = pg_loss * tis_weights + pg_loss = pg_loss * is_weights pg_loss = sum_of_sample_mean(pg_loss) pg_clipfrac = sum_of_sample_mean(pg_clipfrac) @@ -412,10 +412,10 @@ def scatter_cp_and_concat( if args.use_kl_loss: reported_loss["kl_loss"] = kl_loss.clone().detach() - if args.use_train_infer_tis: + if args.use_train_infer_is: # Backward compatible basic logs reported_loss["ois"] = sum_of_sample_mean(ois).clone().detach() - for metric_key, metric_value in tis_metrics.items(): + for metric_key, metric_value in is_metrics.items(): key_name = f"train_infer_{metric_key}" reported_loss[key_name] = sum_of_sample_mean(metric_value) diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index b8fbe7cfae..bc2ebf253d 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -25,10 +25,11 @@ def reset_arg(parser, name, **kwargs): parser.add_argument(name, **kwargs) -def add_tis_arguments(parser: argparse.ArgumentParser): - # Off-Policy Correction using Importance Sampling: https://fengyao.notion.site/off-policy-rl +def add_is_arguments(parser: argparse.ArgumentParser): + # Off-Policy Correction arguments for importance sampling + # training/inference importance sampling parser.add_argument( - "--use-train-infer-tis", + "--use-train-infer-is", action="store_true", default=False, help="Enable 
TIS from https://fengyao.notion.site/off-policy-rl for off-policy importance sampling.", @@ -36,7 +37,7 @@ def add_tis_arguments(parser: argparse.ArgumentParser): # Extended TIS controls (levels/modes/thresholds) with backward compatibility parser.add_argument( - "--train-infer-tis-level", + "--train-infer-is-level", type=str, choices=["token", "sequence", "geometric"], default="token", @@ -46,7 +47,7 @@ def add_tis_arguments(parser: argparse.ArgumentParser): ), ) parser.add_argument( - "--train-infer-tis-mode", + "--train-infer-is-mode", type=str, choices=["truncate", "clip"], default="truncate", @@ -56,19 +57,19 @@ def add_tis_arguments(parser: argparse.ArgumentParser): ), ) parser.add_argument( - "--train-infer-tis-eps-clip", + "--train-infer-is-eps-clip", type=float, default=0.2, help=("Lower clip threshold for IS weights. Default is 0.2. "), ) parser.add_argument( - "--train-infer-tis-eps-clip-high", + "--train-infer-is-eps-clip-high", type=float, default=None, help=("Upper clip threshold for IS weights. Default is None. Truncate mode will not use this value."), ) parser.add_argument( - "--train-infer-tis-veto-threshold", + "--train-infer-is-veto-threshold", type=float, default=None, help=("Per-token veto threshold. 
If any token ratio < this or > 1/this, zero the entire sequence weight."), @@ -1014,7 +1015,7 @@ def add_ci_arguments(parser): parser = add_reward_model_arguments(parser) parser = add_rollout_buffer_arguments(parser) parser = add_ci_arguments(parser) - parser = add_tis_arguments(parser) + parser = add_is_arguments(parser) # For megatron parser = add_custom_megatron_plugins_arguments(parser) @@ -1148,8 +1149,8 @@ def slime_validate_args(args): if args.eps_clip_high is None: args.eps_clip_high = args.eps_clip - if args.train_infer_tis_eps_clip_high is None: - args.train_infer_tis_eps_clip_high = args.train_infer_tis_eps_clip + if args.train_infer_is_eps_clip_high is None: + args.train_infer_is_eps_clip_high = args.train_infer_is_eps_clip if args.eval_reward_key is None: args.eval_reward_key = args.reward_key diff --git a/slime/utils/tis.py b/slime/utils/tis.py index adc89b35e9..1761f7c2e0 100644 --- a/slime/utils/tis.py +++ b/slime/utils/tis.py @@ -74,7 +74,7 @@ def clip( return weights.clamp(1 - eps_clip, 1 + eps_clip_high) * loss_mask -def clip_to_zero( +def clip_mask( weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, Any], *, eps_clip: float, eps_clip_high: float ) -> torch.Tensor: metrics_append(metrics, "clip_fraction_low", (weights < 1 - eps_clip).int()) @@ -83,13 +83,13 @@ def clip_to_zero( return weights * clip_mask * loss_mask -def compute_train_infer_tis_weights( +def compute_train_infer_is_weights( args, *, train_log_probs: list[torch.Tensor], rollout_log_probs: list[torch.Tensor], loss_masks: list[torch.Tensor], - tis_function: Callable[[torch.Tensor, torch.Tensor, Dict[str, Any]], torch.Tensor], + is_function: Callable[[torch.Tensor, torch.Tensor, Dict[str, Any]], torch.Tensor], ) -> Tuple[list[torch.Tensor], Dict[str, Any]]: """ Compute the truncated importance sampling (TIS) weights and metrics. @@ -111,7 +111,7 @@ def compute_train_infer_tis_weights( - "sequence": product over tokens, unbiased but high variance. 
- "geometric": geometric mean over tokens, biased, medium variance. """ - level: str = args.train_infer_tis_level + level: str = args.train_infer_is_level metrics: Dict[str, Any] = {} # Validate input lists have same length and each sequence has matching shapes @@ -151,15 +151,15 @@ def compute_train_infer_tis_weights( weights = torch.exp(log_ratio_safe) # mask out catastrophic tokens - if args.train_infer_tis_veto_threshold is not None: + if args.train_infer_is_veto_threshold is not None: veto_mask = calculate_veto_mask( - log_ratio_for_metrics, loss_mask, args.train_infer_tis_veto_threshold, metrics + log_ratio_for_metrics, loss_mask, args.train_infer_is_veto_threshold, metrics ) metrics_append(metrics, "raw_ratio_mean", weights) - weights = tis_function(weights, loss_mask, metrics) + weights = is_function(weights, loss_mask, metrics) metrics_append(metrics, "ratio_mean_after_tis", weights) - if args.train_infer_tis_veto_threshold is not None: + if args.train_infer_is_veto_threshold is not None: weights = weights * veto_mask metrics_append(metrics, "ratio_mean_after_veto_mask", weights) From 3fd863b769f4cf0f0d5e1e5065c5ce888df3a5c3 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Tue, 14 Oct 2025 05:55:03 +0000 Subject: [PATCH 34/58] Update clip bound name and veto impl to fully follow the paper --- slime/backends/megatron_utils/cp_utils.py | 2 +- slime/backends/megatron_utils/loss.py | 33 +------------ slime/utils/arguments.py | 26 +++++----- slime/utils/tis.py | 59 +++++++++++++++++------ 4 files changed, 61 insertions(+), 59 deletions(-) diff --git a/slime/backends/megatron_utils/cp_utils.py b/slime/backends/megatron_utils/cp_utils.py index 8d44cd52ce..e0f06847c0 100644 --- a/slime/backends/megatron_utils/cp_utils.py +++ b/slime/backends/megatron_utils/cp_utils.py @@ -157,7 +157,7 @@ def zero(len): def scatter_with_cp(full_tensor: torch.Tensor, total_length: int, response_length: int) -> torch.Tensor: """ - Inverse of `all_gather_with_cp`. 
Returns the chunk for the current CP rank. + Inverse of `all_gather_with_cp`. Slice the full tensor and return the corresponding chunk on each CP rank. """ cp_size = mpu.get_context_parallel_world_size() if cp_size == 1: diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index c40b84d57a..eb28230e26 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -1,4 +1,3 @@ -from functools import partial from typing import Union import torch @@ -15,7 +14,7 @@ get_reinforce_plus_plus_baseline_advantages, get_reinforce_plus_plus_returns, ) -from slime.utils.tis import clip, clip_mask, compute_train_infer_is_weights, truncate +from slime.utils.tis import compute_train_infer_is_weights from .cp_utils import all_gather_with_cp, get_logits_and_tokens_offset_with_cp, get_sum_of_sample_mean, scatter_with_cp @@ -322,41 +321,11 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): for old_log_prob, total_length, response_length in zip(old_log_probs, total_lengths, response_lengths) ] - """ - mode: how to handle the importance sampling weights exceeding the thresholds. - - "truncated": cap the importance sampling weights at the upper threshold - https://fengyao.notion.site/off-policy-rl#279721e3f6c48092bbe2fcfe0e9c6b33 - - "clip_mask": zero the importance sampling weights outside the [lower, upper] range. - https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda - - "clip": clip the importance sampling weights to the [lower, upper] range. 
- """ - - if args.train_infer_is_mode == "clip_mask": - assert ( - args.train_infer_is_eps_clip is not None and args.train_infer_is_eps_clip_high is not None - ), "eps_clip and eps_clip_high must be provided" - is_function = partial( - clip_mask, eps_clip=args.train_infer_is_eps_clip, eps_clip_high=args.train_infer_is_eps_clip_high - ) - elif args.train_infer_is_mode == "clip": - assert ( - args.train_infer_is_eps_clip is not None and args.train_infer_is_eps_clip_high is not None - ), "eps_clip and eps_clip_high must be provided" - is_function = partial( - clip, eps_clip=args.train_infer_is_eps_clip, eps_clip_high=args.train_infer_is_eps_clip_high - ) - elif args.train_infer_is_mode == "truncate": - assert args.train_infer_is_eps_clip is not None, "eps_clip must be provided" - is_function = partial(truncate, eps=args.train_infer_is_eps_clip) - else: - raise ValueError(f"Unsupported train_infer_is_mode: {args.train_infer_is_mode}") - is_weights, is_metrics = compute_train_infer_is_weights( args=args, train_log_probs=full_old_log_probs, rollout_log_probs=full_rollout_log_probs, loss_masks=batch["loss_masks"], - is_function=is_function, ) def scatter_cp_and_concat( diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index bc2ebf253d..4ad3c70506 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -49,30 +49,34 @@ def add_is_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--train-infer-is-mode", type=str, - choices=["truncate", "clip"], + choices=["truncate", "clip_mask", "clip"], default="truncate", help=( - "Handling mode for IS weights: truncate (cap upper bound, TIS) or clip " - "(zero outside [lower, upper], CIS)." + "Handling mode for IS weights:" + "truncate (cap upper bound, TIS)," + "clip_mask (zero outside [lower, upper], CIS)," + "clip (clip to [lower, upper], CIS)." 
), ) parser.add_argument( - "--train-infer-is-eps-clip", + "--train-infer-is-lower-bound", type=float, - default=0.2, - help=("Lower clip threshold for IS weights. Default is 0.2. "), + default=0.5, + help=( + "For clip or clip_mask mode, the lower bound of the IS weights. For truncate mode, it will not be used." + ), ) parser.add_argument( - "--train-infer-is-eps-clip-high", + "--train-infer-is-upper-bound", type=float, - default=None, - help=("Upper clip threshold for IS weights. Default is None. Truncate mode will not use this value."), + default=2.0, + help=("For truncate, clip, and clip_mask mode, the upper bound of the IS weights."), ) parser.add_argument( "--train-infer-is-veto-threshold", type=float, - default=None, - help=("Per-token veto threshold. If any token ratio < this or > 1/this, zero the entire sequence weight."), + default=1e-4, + help=("Per-token veto threshold. If any token ratio < this, zero the entire sequence weight."), ) return parser diff --git a/slime/utils/tis.py b/slime/utils/tis.py index 1761f7c2e0..569ad79dea 100644 --- a/slime/utils/tis.py +++ b/slime/utils/tis.py @@ -47,9 +47,7 @@ def calculate_veto_mask( return torch.ones_like(log_ratio_for_metrics) log_veto_threshold = torch.log(torch.tensor(veto_threshold, device=log_ratio_for_metrics.device)) # For each sequence, if it has any catastrophic tokens, return 0 for the sequence - catastrophic_tokens = ( - (log_ratio_for_metrics < log_veto_threshold) | (log_ratio_for_metrics > -log_veto_threshold) - ) & loss_mask.bool() + catastrophic_tokens = ((log_ratio_for_metrics < log_veto_threshold)) & loss_mask.bool() has_catastrophic = catastrophic_tokens.any() veto_mask = (~has_catastrophic).float().expand_as(log_ratio_for_metrics) @@ -61,25 +59,27 @@ def calculate_veto_mask( return veto_mask -def truncate(weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, Any], *, eps: float) -> torch.Tensor: - metrics_append(metrics, "truncate_fraction", (weights > eps).int()) - return 
weights.clamp(0, eps) * loss_mask +def truncate( + weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, Any], *, upper_bound: float +) -> torch.Tensor: + metrics_append(metrics, "truncate_fraction", (weights > upper_bound).int()) + return weights.clamp(0, upper_bound) * loss_mask def clip( - weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, Any], *, eps_clip: float, eps_clip_high: float + weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, Any], *, lower_bound: float, upper_bound: float ) -> torch.Tensor: - metrics_append(metrics, "clip_fraction_low", (weights < 1 - eps_clip).int()) - metrics_append(metrics, "clip_fraction_high", (weights > 1 + eps_clip_high).int()) - return weights.clamp(1 - eps_clip, 1 + eps_clip_high) * loss_mask + metrics_append(metrics, "clip_fraction_low", (weights < lower_bound).int()) + metrics_append(metrics, "clip_fraction_high", (weights > upper_bound).int()) + return weights.clamp(lower_bound, upper_bound) * loss_mask def clip_mask( - weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, Any], *, eps_clip: float, eps_clip_high: float + weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, Any], *, lower_bound: float, upper_bound: float ) -> torch.Tensor: - metrics_append(metrics, "clip_fraction_low", (weights < 1 - eps_clip).int()) - metrics_append(metrics, "clip_fraction_high", (weights > 1 + eps_clip_high).int()) - clip_mask = (weights >= 1 - eps_clip) & (weights <= 1 + eps_clip_high) + metrics_append(metrics, "clip_fraction_low", (weights < lower_bound).int()) + metrics_append(metrics, "clip_fraction_high", (weights > upper_bound).int()) + clip_mask = (weights >= lower_bound) & (weights <= upper_bound) return weights * clip_mask * loss_mask @@ -157,7 +157,36 @@ def compute_train_infer_is_weights( ) metrics_append(metrics, "raw_ratio_mean", weights) - weights = is_function(weights, loss_mask, metrics) + + """ + mode: how to handle the importance sampling weights 
exceeding the thresholds. + - "truncated": cap the importance sampling weights at the upper threshold + https://fengyao.notion.site/off-policy-rl#279721e3f6c48092bbe2fcfe0e9c6b33 + - "clip_mask": zero the importance sampling weights outside the [lower, upper] range. + https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda + - "clip": clip the importance sampling weights to the [lower, upper] range. + """ + if args.train_infer_is_mode == "clip_mask": + weights = clip_mask( + weights, + loss_mask, + metrics, + lower_bound=args.train_infer_is_lower_bound, + upper_bound=args.train_infer_is_upper_bound, + ) + elif args.train_infer_is_mode == "clip": + weights = clip( + weights, + loss_mask, + metrics, + lower_bound=args.train_infer_is_lower_bound, + upper_bound=args.train_infer_is_upper_bound, + ) + elif args.train_infer_is_mode == "truncate": + weights = truncate(weights, loss_mask, metrics, upper_bound=args.train_infer_is_upper_bound) + else: + raise ValueError(f"Unsupported train_infer_is_mode: {args.train_infer_is_mode}") + metrics_append(metrics, "ratio_mean_after_tis", weights) if args.train_infer_is_veto_threshold is not None: weights = weights * veto_mask From c047b2c00548ab809f30a7c2de1cca16a50d7882 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Tue, 14 Oct 2025 05:58:17 +0000 Subject: [PATCH 35/58] upd clip_mask to mask --- slime/utils/arguments.py | 10 ++++------ slime/utils/tis.py | 22 +++++++++++----------- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 4ad3c70506..ea60881b8b 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -49,12 +49,12 @@ def add_is_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--train-infer-is-mode", type=str, - choices=["truncate", "clip_mask", "clip"], + choices=["truncate", "mask", "clip"], default="truncate", help=( "Handling 
mode for IS weights:" "truncate (cap upper bound, TIS)," - "clip_mask (zero outside [lower, upper], CIS)," + "mask (zero outside [lower, upper], CIS)," "clip (clip to [lower, upper], CIS)." ), ) @@ -62,15 +62,13 @@ def add_is_arguments(parser: argparse.ArgumentParser): "--train-infer-is-lower-bound", type=float, default=0.5, - help=( - "For clip or clip_mask mode, the lower bound of the IS weights. For truncate mode, it will not be used." - ), + help=("For clip or mask mode, the lower bound of the IS weights. For truncate mode, it will not be used."), ) parser.add_argument( "--train-infer-is-upper-bound", type=float, default=2.0, - help=("For truncate, clip, and clip_mask mode, the upper bound of the IS weights."), + help=("For truncate, clip, and mask mode, the upper bound of the IS weights."), ) parser.add_argument( "--train-infer-is-veto-threshold", diff --git a/slime/utils/tis.py b/slime/utils/tis.py index 569ad79dea..6060834adb 100644 --- a/slime/utils/tis.py +++ b/slime/utils/tis.py @@ -74,13 +74,13 @@ def clip( return weights.clamp(lower_bound, upper_bound) * loss_mask -def clip_mask( +def mask( weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, Any], *, lower_bound: float, upper_bound: float ) -> torch.Tensor: - metrics_append(metrics, "clip_fraction_low", (weights < lower_bound).int()) - metrics_append(metrics, "clip_fraction_high", (weights > upper_bound).int()) - clip_mask = (weights >= lower_bound) & (weights <= upper_bound) - return weights * clip_mask * loss_mask + metrics_append(metrics, "mask_fraction_low", (weights < lower_bound).int()) + metrics_append(metrics, "mask_fraction_high", (weights > upper_bound).int()) + mask = (weights >= lower_bound) & (weights <= upper_bound) + return weights * mask * loss_mask def compute_train_infer_is_weights( @@ -119,10 +119,10 @@ def compute_train_infer_is_weights( len(train_log_probs) == len(rollout_log_probs) == len(loss_masks) ), f"Input lists must have same length: {len(train_log_probs)} vs 
{len(rollout_log_probs)} vs {len(loss_masks)}" - for i, (train, rollout, mask) in enumerate(zip(train_log_probs, rollout_log_probs, loss_masks)): + for i, (train, rollout, loss_mask) in enumerate(zip(train_log_probs, rollout_log_probs, loss_masks)): assert ( - train.shape == rollout.shape == mask.shape - ), f"Sequence {i}: shapes must match - train: {train.shape}, rollout: {rollout.shape}, mask: {mask.shape}" + train.shape == rollout.shape == loss_mask.shape + ), f"Sequence {i}: shapes must match - train: {train.shape}, rollout: {rollout.shape}, loss_mask: {loss_mask.shape}" # TODO: Get device from first tensor and apply to tensors # device = train_log_probs[0].device @@ -162,12 +162,12 @@ def compute_train_infer_is_weights( mode: how to handle the importance sampling weights exceeding the thresholds. - "truncated": cap the importance sampling weights at the upper threshold https://fengyao.notion.site/off-policy-rl#279721e3f6c48092bbe2fcfe0e9c6b33 - - "clip_mask": zero the importance sampling weights outside the [lower, upper] range. + - "mask": zero the importance sampling weights outside the [lower, upper] range. https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda - "clip": clip the importance sampling weights to the [lower, upper] range. 
""" - if args.train_infer_is_mode == "clip_mask": - weights = clip_mask( + if args.train_infer_is_mode == "mask": + weights = mask( weights, loss_mask, metrics, From ca6e90e67f162e39260dbf14c0a4d89819b9cb29 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Tue, 14 Oct 2025 06:46:50 +0000 Subject: [PATCH 36/58] upd arg name, fix small bug, delete scatter_with_cp --- slime/backends/megatron_utils/cp_utils.py | 26 ----------------------- slime/utils/arguments.py | 4 ++-- 2 files changed, 2 insertions(+), 28 deletions(-) diff --git a/slime/backends/megatron_utils/cp_utils.py b/slime/backends/megatron_utils/cp_utils.py index e0f06847c0..3ba6adfdf2 100644 --- a/slime/backends/megatron_utils/cp_utils.py +++ b/slime/backends/megatron_utils/cp_utils.py @@ -155,32 +155,6 @@ def zero(len): return full_tensor -def scatter_with_cp(full_tensor: torch.Tensor, total_length: int, response_length: int) -> torch.Tensor: - """ - Inverse of `all_gather_with_cp`. Slice the full tensor and return the corresponding chunk on each CP rank. 
- """ - cp_size = mpu.get_context_parallel_world_size() - if cp_size == 1: - return full_tensor - - prompt_length = total_length - response_length - _, _, logits_offset, _ = get_logits_and_tokens_offset_with_cp(total_length, response_length) - - def _slice(offset: tuple[int, int]) -> torch.Tensor: - start, end = offset - if start >= end: - return full_tensor.new_empty((0,) + full_tensor.shape[1:]) - start -= prompt_length - 1 - end -= prompt_length - 1 - if start >= end: - return full_tensor.new_empty((0,) + full_tensor.shape[1:]) - return full_tensor[start:end] - - chunk_0 = _slice(logits_offset[0]) - chunk_1 = _slice(logits_offset[1]) - return torch.cat([chunk_0, chunk_1], dim=0) - - def slice_with_cp(tokens: torch.Tensor, pad_value): cp_rank = mpu.get_context_parallel_rank() cp_size = mpu.get_context_parallel_world_size() diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index ea60881b8b..16a1c6e8e3 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -1151,8 +1151,8 @@ def slime_validate_args(args): if args.eps_clip_high is None: args.eps_clip_high = args.eps_clip - if args.train_infer_is_eps_clip_high is None: - args.train_infer_is_eps_clip_high = args.train_infer_is_eps_clip + if args.train_infer_is_lower_bound is None: + args.train_infer_is_lower_bound = args.train_infer_is_upper_bound if args.eval_reward_key is None: args.eval_reward_key = args.reward_key From c0aeb32f3cfbf2f3b32b7ca650fa063beffd8748 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Tue, 14 Oct 2025 06:49:43 +0000 Subject: [PATCH 37/58] move cp logic to tis.py --- slime/backends/megatron_utils/loss.py | 33 +++--------- slime/utils/tis.py | 73 +++++++++++++++++++++++++-- 2 files changed, 75 insertions(+), 31 deletions(-) diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index eb28230e26..82d8b88e00 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -14,9 +14,9 @@ 
get_reinforce_plus_plus_baseline_advantages, get_reinforce_plus_plus_returns, ) -from slime.utils.tis import compute_train_infer_is_weights +from slime.utils.tis import compute_train_infer_is_weights_with_cp -from .cp_utils import all_gather_with_cp, get_logits_and_tokens_offset_with_cp, get_sum_of_sample_mean, scatter_with_cp +from .cp_utils import all_gather_with_cp, get_logits_and_tokens_offset_with_cp, get_sum_of_sample_mean def get_responses( @@ -312,34 +312,15 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): rollout_log_probs = batch["rollout_log_probs"] old_log_probs = batch["log_probs"] - full_rollout_log_probs = [ - all_gather_with_cp(log_prob, total_length, response_length) - for log_prob, total_length, response_length in zip(rollout_log_probs, total_lengths, response_lengths) - ] - full_old_log_probs = [ - all_gather_with_cp(old_log_prob, total_length, response_length) - for old_log_prob, total_length, response_length in zip(old_log_probs, total_lengths, response_lengths) - ] - - is_weights, is_metrics = compute_train_infer_is_weights( + is_weights, is_metrics = compute_train_infer_is_weights_with_cp( args=args, - train_log_probs=full_old_log_probs, - rollout_log_probs=full_rollout_log_probs, + train_log_probs=old_log_probs, + rollout_log_probs=rollout_log_probs, loss_masks=batch["loss_masks"], + total_lengths=total_lengths, + response_lengths=response_lengths, ) - def scatter_cp_and_concat( - values: list[torch.Tensor], total_lengths: list[int], response_lengths: list[int] - ) -> list[torch.Tensor]: - # reshape value to the sequence size of the cp rank. 
- values = [scatter_with_cp(values[i], total_lengths[i], response_lengths[i]) for i in range(len(values))] - return torch.cat(values, dim=0) - - is_weights = scatter_cp_and_concat(is_weights, total_lengths, response_lengths) - for key, values in is_metrics.items(): - values = scatter_cp_and_concat(values, total_lengths, response_lengths) - is_metrics[key] = values - ois = (-ppo_kl).exp() pg_loss = pg_loss * is_weights diff --git a/slime/utils/tis.py b/slime/utils/tis.py index 6060834adb..368144f904 100644 --- a/slime/utils/tis.py +++ b/slime/utils/tis.py @@ -1,7 +1,9 @@ -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple import torch +from slime.backends.megatron_utils.cp_utils import all_gather_with_cp, slice_log_prob_with_cp + def masked_sum( tensor: torch.Tensor, mask: torch.Tensor, dim: Optional[int] = None, keepdim: bool = False @@ -62,6 +64,7 @@ def calculate_veto_mask( def truncate( weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, Any], *, upper_bound: float ) -> torch.Tensor: + assert upper_bound is not None metrics_append(metrics, "truncate_fraction", (weights > upper_bound).int()) return weights.clamp(0, upper_bound) * loss_mask @@ -69,6 +72,7 @@ def truncate( def clip( weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, Any], *, lower_bound: float, upper_bound: float ) -> torch.Tensor: + assert lower_bound is not None and upper_bound is not None and lower_bound < upper_bound metrics_append(metrics, "clip_fraction_low", (weights < lower_bound).int()) metrics_append(metrics, "clip_fraction_high", (weights > upper_bound).int()) return weights.clamp(lower_bound, upper_bound) * loss_mask @@ -77,6 +81,7 @@ def clip( def mask( weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, Any], *, lower_bound: float, upper_bound: float ) -> torch.Tensor: + assert lower_bound is not None and upper_bound is not None and lower_bound < upper_bound metrics_append(metrics, 
"mask_fraction_low", (weights < lower_bound).int()) metrics_append(metrics, "mask_fraction_high", (weights > upper_bound).int()) mask = (weights >= lower_bound) & (weights <= upper_bound) @@ -89,14 +94,13 @@ def compute_train_infer_is_weights( train_log_probs: list[torch.Tensor], rollout_log_probs: list[torch.Tensor], loss_masks: list[torch.Tensor], - is_function: Callable[[torch.Tensor, torch.Tensor, Dict[str, Any]], torch.Tensor], ) -> Tuple[list[torch.Tensor], Dict[str, Any]]: """ Compute the truncated importance sampling (TIS) weights and metrics. Args: - train_log_probs: List of log probs from training backend, one tensor per sequence. - rollout_log_probs: List of log probs from inference backend, one tensor per sequence. - loss_masks: List of loss masks, one tensor per sequence. + train_log_probs: List of log probs from training backend. + rollout_log_probs: List of log probs from inference backend. + loss_masks: List of loss masks. Note that for single turn RL, the loss_mask is [1] * response_length for each sequence For multi turn RL, the tool response will be marked as 0 in the loss_mask. @@ -196,3 +200,62 @@ def compute_train_infer_is_weights( all_weights.append(weights) return all_weights, metrics + + +def compute_train_infer_is_weights_with_cp( + args, + *, + train_log_probs: list[torch.Tensor], + rollout_log_probs: list[torch.Tensor], + loss_masks: list[torch.Tensor], + total_lengths: list[int], + response_lengths: list[int], +) -> Tuple[list[torch.Tensor], Dict[str, Any]]: + """ + Compute the truncated importance sampling (TIS) weights and metrics with context parallel. + Args: + train_log_probs: List of log probs from training backend on this cp rank. + rollout_log_probs: List of log probs from inference backend on this cp rank. + loss_masks: List of loss masks. + total_lengths: List of total lengths. + response_lengths: List of response lengths. + Returns: + is_weights: The importance sampling weights. 
[batch_size, seq_len] + is_metrics: The metrics for the importance sampling weights. + """ + # Gather cp slice from other cp ranks + full_rollout_log_probs = [ + all_gather_with_cp(log_prob, total_length, response_length) + for log_prob, total_length, response_length in zip(rollout_log_probs, total_lengths, response_lengths) + ] + full_old_log_probs = [ + all_gather_with_cp(old_log_prob, total_length, response_length) + for old_log_prob, total_length, response_length in zip(train_log_probs, total_lengths, response_lengths) + ] + + # Main logic for is + is_weights, is_metrics = compute_train_infer_is_weights( + args=args, + train_log_probs=full_old_log_probs, + rollout_log_probs=full_rollout_log_probs, + loss_masks=loss_masks, + ) + + # Slice cp slice and concat to the full response tensor + def slice_cp_and_concat( + values: list[torch.Tensor], total_lengths: list[int], response_lengths: list[int] + ) -> list[torch.Tensor]: + # reshape value to the sequence size of the cp rank. + values = [ + # TODO: A rename of this function ? 
+ slice_log_prob_with_cp(values[i], total_lengths[i], response_lengths[i]) + for i in range(len(values)) + ] + return torch.cat(values, dim=0) + + is_weights = slice_cp_and_concat(is_weights, total_lengths, response_lengths) + for key, values in is_metrics.items(): + values = slice_cp_and_concat(values, total_lengths, response_lengths) + is_metrics[key] = values + + return is_weights, is_metrics From 44446342abad7a4aea5fc747f1a99b7e7f8d6866 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Tue, 14 Oct 2025 06:57:40 +0000 Subject: [PATCH 38/58] add a qwen3-4b tis sample --- scripts/run-qwen3-4B-tis-example.sh | 161 ++++++++++++++++++++++++++++ 1 file changed, 161 insertions(+) create mode 100644 scripts/run-qwen3-4B-tis-example.sh diff --git a/scripts/run-qwen3-4B-tis-example.sh b/scripts/run-qwen3-4B-tis-example.sh new file mode 100644 index 0000000000..0c11a48646 --- /dev/null +++ b/scripts/run-qwen3-4B-tis-example.sh @@ -0,0 +1,161 @@ +#!/bin/bash + +# for rerun the task +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=16 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +source "${SCRIPT_DIR}/models/qwen3-4B.sh" + +CKPT_ARGS=( + --hf-checkpoint /root/Qwen3-4B + #--hf-checkpoint /root/Qwen3-4B-FP8 + --ref-load /root/Qwen3-4B_torch_dist + # --load /root/Qwen3-4B_slime/ + --save /root/Qwen3-4B_slime/ + --save-interval 2000 +) + +ROLLOUT_ARGS=( + --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type deepscaler + --num-rollout 3000 + --rollout-batch-size 8 + 
--n-samples-per-prompt 4 + --rollout-max-response-len 8192 + --rollout-temperature 0.8 + + --global-batch-size 32 + --balance-data +) + +EVAL_ARGS=( + --eval-interval 20 + --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl + --n-samples-per-eval-prompt 1 + --eval-max-response-len 16384 + --eval-top-p 0.7 +) + +PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 2 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + # --micro-batch-size 1 + --use-dynamic-batch-size + --max-tokens-per-gpu 9216 +) + +GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=( + --use-wandb + --wandb-project slime-is + --wandb-group qwen3-4B-is-cp2 + --wandb-key ${WANDB_KEY} +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + --sglang-mem-fraction-static 0.7 +) + +MISC_ARGS=( + # default dropout in megatron is 0.1 + --attention-dropout 0.0 + --hidden-dropout 0.0 + # should be good for model performance + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + # need to comment this when using model with MLA + --attention-backend flash +) + +TIS_ARGS=( + --use-train-infer-is + --train-infer-is-level token + --train-infer-is-mode clip + --train-infer-is-upper-bound 2.0 + --train-infer-is-lower-bound 0.5 + --train-infer-is-veto-threshold 1e-4 +) + +export CUDA_VISIBLE_DEVICES=4,5,6,7 +# launch the master node of ray in container +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 4 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +# Build the 
runtime environment JSON with proper variable substitution +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" + } +}" + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 4 \ + --colocate \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${TIS_ARGS[@]} \ No newline at end of file From c4441ad6ec640659f5881abf92c9fce10121933b Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Tue, 14 Oct 2025 21:22:38 +0000 Subject: [PATCH 39/58] fix code comments, file name --- slime/backends/megatron_utils/loss.py | 7 +- slime/utils/arguments.py | 12 ++- slime/utils/{tis.py => train_infer_is.py} | 122 ++++++++++++---------- 3 files changed, 75 insertions(+), 66 deletions(-) rename slime/utils/{tis.py => train_infer_is.py} (69%) diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 82d8b88e00..d579b0785d 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -309,13 +309,10 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): if args.use_train_infer_is: assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS" - rollout_log_probs = batch["rollout_log_probs"] - old_log_probs = batch["log_probs"] - is_weights, is_metrics = compute_train_infer_is_weights_with_cp( args=args, - train_log_probs=old_log_probs, - rollout_log_probs=rollout_log_probs, + train_log_probs=batch["log_probs"], + rollout_log_probs=batch["rollout_log_probs"], loss_masks=batch["loss_masks"], total_lengths=total_lengths, response_lengths=response_lengths, diff --git a/slime/utils/arguments.py 
b/slime/utils/arguments.py index 16a1c6e8e3..251c437418 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -32,7 +32,7 @@ def add_is_arguments(parser: argparse.ArgumentParser): "--use-train-infer-is", action="store_true", default=False, - help="Enable TIS from https://fengyao.notion.site/off-policy-rl for off-policy importance sampling.", + help="Enable importance sampling, details refer to the comments of compute_train_infer_is_weights in train_infer_is.py", ) # Extended TIS controls (levels/modes/thresholds) with backward compatibility @@ -54,7 +54,7 @@ def add_is_arguments(parser: argparse.ArgumentParser): help=( "Handling mode for IS weights:" "truncate (cap upper bound, TIS)," - "mask (zero outside [lower, upper], CIS)," + "mask (zero outside [lower, upper], MIS)," "clip (clip to [lower, upper], CIS)." ), ) @@ -62,19 +62,21 @@ def add_is_arguments(parser: argparse.ArgumentParser): "--train-infer-is-lower-bound", type=float, default=0.5, - help=("For clip or mask mode, the lower bound of the IS weights. For truncate mode, it will not be used."), + help=("For mask or clip mode, the lower bound of the IS weights. For truncate mode, it will not be used."), ) parser.add_argument( "--train-infer-is-upper-bound", type=float, default=2.0, - help=("For truncate, clip, and mask mode, the upper bound of the IS weights."), + help=("For truncate, mask, or clip mode, the upper bound of the IS weights."), ) parser.add_argument( "--train-infer-is-veto-threshold", type=float, default=1e-4, - help=("Per-token veto threshold. If any token ratio < this, zero the entire sequence weight."), + help=( + "Per-token veto threshold. If any token ratio < this, zero the entire sequence weight, the sequences won't have gradient." 
+ ), ) return parser diff --git a/slime/utils/tis.py b/slime/utils/train_infer_is.py similarity index 69% rename from slime/utils/tis.py rename to slime/utils/train_infer_is.py index 368144f904..6b2b0eef81 100644 --- a/slime/utils/tis.py +++ b/slime/utils/train_infer_is.py @@ -5,34 +5,43 @@ from slime.backends.megatron_utils.cp_utils import all_gather_with_cp, slice_log_prob_with_cp -def masked_sum( - tensor: torch.Tensor, mask: torch.Tensor, dim: Optional[int] = None, keepdim: bool = False -) -> torch.Tensor: - return (tensor * mask).sum(dim=dim, keepdim=keepdim) +def metrics_append(metrics: Dict[str, list[torch.Tensor]], key: str, value: torch.Tensor) -> None: + """ + Every metrics-dict value is a list[torch.Tensor] (variable-length tensors) with shapes exactly the same + as train_log_probs, rollout_log_probs, and loss_masks. -def masked_mean( - tensor: torch.Tensor, - mask: torch.Tensor, - dim: Optional[int] = None, - keepdim: bool = False, - eps: float = 1e-8, -) -> torch.Tensor: - total = (tensor * mask).sum(dim=dim, keepdim=keepdim) - denom = mask.sum(dim=dim, keepdim=keepdim) - return total / (denom + eps) + All metrics will be aggregated and averaged by `sum_of_sample_mean` and divided by DP size automatically + - If calculate_per_token_loss=False (default), the final results will first be averaged in each sequence, + then across all the sequences in the global batch. + - If calculate_per_token_loss=True, the final results will be the mean of all the tokens in the global batch. - -def metrics_append(metrics: Dict[str, list[torch.Tensor]], key: str, value: torch.Tensor) -> None: - """ - Any metric should be list[torch.Tensor] with size [num_seq, response_length] - All tis metrics will be aggregated and averaged by `sum_of_sample_mean` and megatron automatically - The result will be sequence-level average or token-level if `calculate_per_token_loss` is set. 
-    You have no need to worry about loss_mask — the sum_of_sample_mean automatically ignores statistics where loss_mask = 0.
+    No need to specifically handle loss_mask, sum_of_sample_mean automatically ignores statistics where loss_mask = 0.
     e.g.
-        To calculate a token-level metric like the ratio of catastrophic tokens, just append the orignal ratio tensor to the metrics dict.
-        To calculate a sequence-level metric like the ratio of vetoed sequences, you should set every value in the tensor to be 0 or 1.
+        For token-level metric:
+            value = [
+                [0.1, 0.2],
+                [0.1, 0.2, 0.3, 0.4, 0.5],
+                [0.6]
+            ]
+            When calculate_per_token_loss = False (default):
+                result = ((0.1 + 0.2) / 2 + (0.1 + 0.2 + 0.3 + 0.4 + 0.5) / 5 + (0.6) / 1) / 3 = (0.15 + 0.3 + 0.6) / 3 = 1.05 / 3 = 0.35
+            When calculate_per_token_loss = True:
+                result = (0.1 + 0.2 + 0.1 + 0.2 + 0.3 + 0.4 + 0.5 + 0.6) / 8 = 2.4 / 8 = 0.3
+        For sequence-level metric:
+            original sequence lengths = [2, 5, 1]
+            We should expand the metrics to the length of each sequence
+            value = [
+                [2, 2],
+                [5, 5, 5, 5, 5],
+                [1]
+            ]
+            When calculate_per_token_loss = False (default):
+                result = ((2 + 2) / 2 + (5 + 5 + 5 + 5 + 5) / 5 + (1) / 1) / 3 = (2 + 5 + 1) / 3 = 8 / 3 ≈ 2.667
+            When calculate_per_token_loss = True:
+                result = (2 + 2 + 5 + 5 + 5 + 5 + 5 + 1) / 8 = 30 / 8 = 3.75 ???
+                This is not what we expect. So, when calculate_per_token_loss = True, sequence-level metrics become invalid !!
""" if key not in metrics: metrics[key] = [] @@ -43,7 +52,7 @@ def calculate_veto_mask( log_ratio_for_metrics: torch.Tensor, loss_mask: torch.Tensor, veto_threshold: Optional[float], - metrics: Dict[str, Any], + metrics: Dict[str, list[torch.Tensor]], ) -> torch.Tensor: if veto_threshold is None: return torch.ones_like(log_ratio_for_metrics) @@ -62,7 +71,7 @@ def calculate_veto_mask( def truncate( - weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, Any], *, upper_bound: float + weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, list[torch.Tensor]], upper_bound: float ) -> torch.Tensor: assert upper_bound is not None metrics_append(metrics, "truncate_fraction", (weights > upper_bound).int()) @@ -70,7 +79,11 @@ def truncate( def clip( - weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, Any], *, lower_bound: float, upper_bound: float + weights: torch.Tensor, + loss_mask: torch.Tensor, + metrics: Dict[str, list[torch.Tensor]], + lower_bound: float, + upper_bound: float, ) -> torch.Tensor: assert lower_bound is not None and upper_bound is not None and lower_bound < upper_bound metrics_append(metrics, "clip_fraction_low", (weights < lower_bound).int()) @@ -79,7 +92,11 @@ def clip( def mask( - weights: torch.Tensor, loss_mask: torch.Tensor, metrics: Dict[str, Any], *, lower_bound: float, upper_bound: float + weights: torch.Tensor, + loss_mask: torch.Tensor, + metrics: Dict[str, list[torch.Tensor]], + lower_bound: float, + upper_bound: float, ) -> torch.Tensor: assert lower_bound is not None and upper_bound is not None and lower_bound < upper_bound metrics_append(metrics, "mask_fraction_low", (weights < lower_bound).int()) @@ -109,14 +126,8 @@ def compute_train_infer_is_weights( metrics: The metrics for the importance sampling weights. """ - """ - level: The aggregation level for the importance sampling weights. - - "token": per-token importance sampling weights, biased low variance. 
- - "sequence": product over tokens, unbiased but high variance. - - "geometric": geometric mean over tokens, biased, medium variance. - """ level: str = args.train_infer_is_level - metrics: Dict[str, Any] = {} + metrics: Dict[str, list[torch.Tensor]] = {} # Validate input lists have same length and each sequence has matching shapes assert ( @@ -128,25 +139,25 @@ def compute_train_infer_is_weights( train.shape == rollout.shape == loss_mask.shape ), f"Sequence {i}: shapes must match - train: {train.shape}, rollout: {rollout.shape}, loss_mask: {loss_mask.shape}" - # TODO: Get device from first tensor and apply to tensors - # device = train_log_probs[0].device SAFETY_BOUND = 20.0 all_weights = [] + # handle each sequence independently for train_log_prob, rollout_log_prob, loss_mask in zip(train_log_probs, rollout_log_probs, loss_masks): raw_log_ratio = train_log_prob - rollout_log_prob loss_mask = loss_mask.float() + # level: The aggregation level for the importance sampling weights. if level == "token": - # Token-level IS + # Per-token ratio (biased) log_ratio_for_metrics = raw_log_ratio elif level == "sequence": - # Sequence-level IS - agg_log_ratio = masked_sum(raw_log_ratio, loss_mask) + # Product of ratios (unbiased) + agg_log_ratio = (raw_log_ratio * loss_mask).sum() log_ratio_for_metrics = torch.full_like(raw_log_ratio, agg_log_ratio) elif level == "geometric": - # Geometric mean IS - agg_log_ratio = masked_mean(raw_log_ratio, loss_mask) + # Geometric mean of ratios (experimental) + agg_log_ratio = (raw_log_ratio * loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1) log_ratio_for_metrics = torch.full_like(raw_log_ratio, agg_log_ratio) else: raise ValueError(f"Invalid importance sampling level: {level}") @@ -162,32 +173,31 @@ def compute_train_infer_is_weights( metrics_append(metrics, "raw_ratio_mean", weights) - """ - mode: how to handle the importance sampling weights exceeding the thresholds. 
- - "truncated": cap the importance sampling weights at the upper threshold - https://fengyao.notion.site/off-policy-rl#279721e3f6c48092bbe2fcfe0e9c6b33 - - "mask": zero the importance sampling weights outside the [lower, upper] range. - https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda - - "clip": clip the importance sampling weights to the [lower, upper] range. - """ - if args.train_infer_is_mode == "mask": + # mode: how to handle the importance sampling weights exceeding the thresholds. + if args.train_infer_is_mode == "truncate": + # Cap the importance sampling weights at the upper threshold + # https://fengyao.notion.site/off-policy-rl#279721e3f6c48092bbe2fcfe0e9c6b33 + weights = truncate(weights, loss_mask, metrics, args.train_infer_is_upper_bound) + elif args.train_infer_is_mode == "mask": + # Zero the importance sampling weights outside the [lower, upper] range. + # https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda weights = mask( weights, loss_mask, metrics, - lower_bound=args.train_infer_is_lower_bound, - upper_bound=args.train_infer_is_upper_bound, + args.train_infer_is_lower_bound, + args.train_infer_is_upper_bound, ) elif args.train_infer_is_mode == "clip": + # Clip the importance sampling weights to the [lower, upper] range. + # Original behavior in slime. 
weights = clip( weights, loss_mask, metrics, - lower_bound=args.train_infer_is_lower_bound, - upper_bound=args.train_infer_is_upper_bound, + args.train_infer_is_lower_bound, + args.train_infer_is_upper_bound, ) - elif args.train_infer_is_mode == "truncate": - weights = truncate(weights, loss_mask, metrics, upper_bound=args.train_infer_is_upper_bound) else: raise ValueError(f"Unsupported train_infer_is_mode: {args.train_infer_is_mode}") From 293cbce6dcb6198e83e998a644489127b4b32368 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Tue, 14 Oct 2025 21:56:26 +0000 Subject: [PATCH 40/58] fix small bug and arg setting --- slime/backends/megatron_utils/loss.py | 2 +- slime/utils/arguments.py | 21 +++++++++++++-------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index d579b0785d..c17ea7d607 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -14,7 +14,7 @@ get_reinforce_plus_plus_baseline_advantages, get_reinforce_plus_plus_returns, ) -from slime.utils.tis import compute_train_infer_is_weights_with_cp +from slime.utils.train_infer_is import compute_train_infer_is_weights_with_cp from .cp_utils import all_gather_with_cp, get_logits_and_tokens_offset_with_cp, get_sum_of_sample_mean diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 251c437418..ff503c2841 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -32,10 +32,11 @@ def add_is_arguments(parser: argparse.ArgumentParser): "--use-train-infer-is", action="store_true", default=False, - help="Enable importance sampling, details refer to the comments of compute_train_infer_is_weights in train_infer_is.py", + help=( + "Enable importance sampling, details refer to the comments of compute_train_infer_is_weights " + "in train_infer_is.py" + ), ) - - # Extended TIS controls (levels/modes/thresholds) with backward compatibility parser.add_argument( 
"--train-infer-is-level", type=str, @@ -61,8 +62,11 @@ def add_is_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--train-infer-is-lower-bound", type=float, - default=0.5, - help=("For mask or clip mode, the lower bound of the IS weights. For truncate mode, it will not be used."), + default=None, + help=( + "For mask or clip mode, the lower bound of the IS weights. For truncate mode, it will not be used. " + "If not set, it will be set to 1.0 / train_infer_is_upper_bound." + ), ) parser.add_argument( "--train-infer-is-upper-bound", @@ -73,7 +77,7 @@ def add_is_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--train-infer-is-veto-threshold", type=float, - default=1e-4, + default=None, help=( "Per-token veto threshold. If any token ratio < this, zero the entire sequence weight, the sequences won't have gradient." ), @@ -1153,8 +1157,9 @@ def slime_validate_args(args): if args.eps_clip_high is None: args.eps_clip_high = args.eps_clip - if args.train_infer_is_lower_bound is None: - args.train_infer_is_lower_bound = args.train_infer_is_upper_bound + if args.use_train_infer_is: + if args.train_infer_is_lower_bound is None: + args.train_infer_is_lower_bound = 1.0 / args.train_infer_is_upper_bound if args.eval_reward_key is None: args.eval_reward_key = args.reward_key From d75012557d233049978aa43442d72e338a0e5da4 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Tue, 14 Oct 2025 22:10:13 +0000 Subject: [PATCH 41/58] fix small bug in veto, add mask, and add a script --- scripts/4b-cp2.sh | 162 ++++++++++++++++++++++++++++++++++ slime/utils/train_infer_is.py | 27 +++--- 2 files changed, 173 insertions(+), 16 deletions(-) create mode 100644 scripts/4b-cp2.sh diff --git a/scripts/4b-cp2.sh b/scripts/4b-cp2.sh new file mode 100644 index 0000000000..0af58252bc --- /dev/null +++ b/scripts/4b-cp2.sh @@ -0,0 +1,162 @@ +#!/bin/bash + +# for rerun the task +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 
ray +pkill -9 python +pkill -9 redis + +set -ex + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=16 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +source "${SCRIPT_DIR}/models/qwen3-4B.sh" + +CKPT_ARGS=( + --hf-checkpoint /root/Qwen3-4B + #--hf-checkpoint /root/Qwen3-4B-FP8 + --ref-load /root/Qwen3-4B_torch_dist + # --load /root/Qwen3-4B_slime/ + --save /root/Qwen3-4B_slime/ + --save-interval 20 +) + +ROLLOUT_ARGS=( + --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type deepscaler + --num-rollout 3000 + --rollout-batch-size 8 + --n-samples-per-prompt 4 + --rollout-max-response-len 8192 + --rollout-temperature 0.8 + + --global-batch-size 32 + --balance-data +) + +EVAL_ARGS=( + # --eval-interval 20 + --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl + --n-samples-per-eval-prompt 1 + --eval-max-response-len 16384 + --eval-top-p 0.7 +) + +PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 2 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + # --micro-batch-size 1 + --use-dynamic-batch-size + --max-tokens-per-gpu 9216 +) + +GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +IS_ARGS=( + --use-train-infer-is + --train-infer-is-level sequence + --train-infer-is-mode clip + --train-infer-is-upper-bound 2.0 + # --train-infer-is-lower-bound 0.5 + --train-infer-is-veto-threshold 
1e-4 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=( + --use-wandb + --wandb-project slime-tis + --wandb-group qwen3-4B-tis-seq + --wandb-key ${WANDB_KEY} +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + --sglang-mem-fraction-static 0.7 +) + +MISC_ARGS=( + # default dropout in megatron is 0.1 + --attention-dropout 0.0 + --hidden-dropout 0.0 + # should be good for model performance + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + # need to comment this when using model with MLA + --attention-backend flash +) + +# launch the master node of ray in container +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +export CUDA_VISIBLE_DEVICES=4,5,6,7 +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 4 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +# Build the runtime environment JSON with proper variable substitution +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" + } +}" + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 4 \ + --colocate \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${IS_ARGS[@]} \ No newline at end of file diff --git a/slime/utils/train_infer_is.py b/slime/utils/train_infer_is.py index 6b2b0eef81..16a4531c41 100644 --- a/slime/utils/train_infer_is.py +++ b/slime/utils/train_infer_is.py @@ -62,9 +62,6 @@ def calculate_veto_mask( has_catastrophic = catastrophic_tokens.any() veto_mask = (~has_catastrophic).float().expand_as(log_ratio_for_metrics) - # TODO(jiajun): A single 
catastrophic token may not be enough to veto the entire sequence? - # May be we can set a threshold for the ratio of catastrophic tokens? - # If exceeds, veto the entire sequence. If not, only mask the catastrophic tokens. metrics_append(metrics, "catastrophic_token_fraction", catastrophic_tokens.int()) metrics_append(metrics, "catastrophic_seq_fraction", has_catastrophic.int().expand_as(loss_mask)) return veto_mask @@ -115,14 +112,14 @@ def compute_train_infer_is_weights( """ Compute the truncated importance sampling (TIS) weights and metrics. Args: - train_log_probs: List of log probs from training backend. - rollout_log_probs: List of log probs from inference backend. - loss_masks: List of loss masks. - Note that for single turn RL, the loss_mask is [1] * response_length for each sequence + train_log_probs: List of log probs from training backend (1D tensor) + rollout_log_probs: List of log probs from inference backend (1D tensor) + loss_masks: List of loss masks (1D tensor) + Note that for single turn RL, the loss_mask is [1] * response_length for each sequence (1D tensor) For multi turn RL, the tool response will be marked as 0 in the loss_mask. Returns: - weights: The importance sampling weights. [batch_size, seq_len] + weights: List of importance sampling weights (1D tensor) metrics: The metrics for the importance sampling weights. 
""" @@ -154,11 +151,11 @@ def compute_train_infer_is_weights( elif level == "sequence": # Product of ratios (unbiased) agg_log_ratio = (raw_log_ratio * loss_mask).sum() - log_ratio_for_metrics = torch.full_like(raw_log_ratio, agg_log_ratio) + log_ratio_for_metrics = agg_log_ratio.expand_as(raw_log_ratio) elif level == "geometric": # Geometric mean of ratios (experimental) agg_log_ratio = (raw_log_ratio * loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1) - log_ratio_for_metrics = torch.full_like(raw_log_ratio, agg_log_ratio) + log_ratio_for_metrics = agg_log_ratio.expand_as(raw_log_ratio) else: raise ValueError(f"Invalid importance sampling level: {level}") @@ -167,9 +164,7 @@ def compute_train_infer_is_weights( # mask out catastrophic tokens if args.train_infer_is_veto_threshold is not None: - veto_mask = calculate_veto_mask( - log_ratio_for_metrics, loss_mask, args.train_infer_is_veto_threshold, metrics - ) + veto_mask = calculate_veto_mask(raw_log_ratio, loss_mask, args.train_infer_is_veto_threshold, metrics) metrics_append(metrics, "raw_ratio_mean", weights) @@ -224,9 +219,9 @@ def compute_train_infer_is_weights_with_cp( """ Compute the truncated importance sampling (TIS) weights and metrics with context parallel. Args: - train_log_probs: List of log probs from training backend on this cp rank. - rollout_log_probs: List of log probs from inference backend on this cp rank. - loss_masks: List of loss masks. + train_log_probs: List of log probs from training backend on this cp rank (1D tensor) + rollout_log_probs: List of log probs from inference backend on this cp rank (1D tensor) + loss_masks: List of loss masks (1D tensor) total_lengths: List of total lengths. response_lengths: List of response lengths. 
Returns: From ce01008a995253a576f849bca1ded7a47d0792c4 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Wed, 15 Oct 2025 03:05:50 +0000 Subject: [PATCH 42/58] small change fsdp --- slime/backends/fsdp_utils/actor.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/slime/backends/fsdp_utils/actor.py b/slime/backends/fsdp_utils/actor.py index f0c954c10e..f63da958b9 100644 --- a/slime/backends/fsdp_utils/actor.py +++ b/slime/backends/fsdp_utils/actor.py @@ -319,7 +319,7 @@ def train(self, rollout_id, rollout_data_ref): pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, self.args.eps_clip, self.args.eps_clip_high) # Apply TIS before sample mean calculation - if self.args.use_tis: + if self.args.use_train_infer_is: # Initialize TIS variables tis = None tis_clipfrac = None @@ -338,7 +338,9 @@ def train(self, rollout_id, rollout_data_ref): tis = torch.exp(old_log_probs - rollout_log_probs) ois = (-ppo_kl).exp() tis_clip = torch.clamp( - tis, min=getattr(self.args, "tis_clip_low", 0.1), max=getattr(self.args, "tis_clip", 2.0) + tis, + min=getattr(self.args, "train_infer_is_lower_bound", 0.1), + max=getattr(self.args, "train_infer_is_upper_bound", 2.0), ) tis_clipfrac = tis_clip != tis From 43d04bda19c3d0ad01d68218aa3a9b1a5ebbff85 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Wed, 15 Oct 2025 05:36:57 +0000 Subject: [PATCH 43/58] fix parameter type in some functions and update comments --- slime/utils/arguments.py | 2 +- slime/utils/train_infer_is.py | 55 +++++++++++++++++------------------ 2 files changed, 28 insertions(+), 29 deletions(-) diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index ff503c2841..03325482bc 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -54,7 +54,7 @@ def add_is_arguments(parser: argparse.ArgumentParser): default="truncate", help=( "Handling mode for IS weights:" - "truncate (cap upper bound, TIS)," + "truncate (cap to upper bound, TIS)," "mask (zero outside [lower, upper], 
MIS)," "clip (clip to [lower, upper], CIS)." ), diff --git a/slime/utils/train_infer_is.py b/slime/utils/train_infer_is.py index 16a4531c41..53c7eb05a5 100644 --- a/slime/utils/train_infer_is.py +++ b/slime/utils/train_infer_is.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Tuple +from typing import Dict, Optional, Tuple import torch @@ -8,8 +8,7 @@ def metrics_append(metrics: Dict[str, list[torch.Tensor]], key: str, value: torch.Tensor) -> None: """ - Every metrics-dict value is a list[torch.Tensor] (variable-length tensors) with shapes exactly the same - as train_log_probs, rollout_log_probs, and loss_masks. + Every metrics-dict value is a list[torch.Tensor] with shapes exactly the same as log_probs. All metrics will be aggregated and averaged by `sum_of_sample_mean` and divided by DP size automatically - If calculate_per_token_loss=False (default), the final results will first be averaged in each sequence, @@ -49,18 +48,18 @@ def metrics_append(metrics: Dict[str, list[torch.Tensor]], key: str, value: torc def calculate_veto_mask( - log_ratio_for_metrics: torch.Tensor, + log_ratio: torch.Tensor, loss_mask: torch.Tensor, veto_threshold: Optional[float], metrics: Dict[str, list[torch.Tensor]], ) -> torch.Tensor: if veto_threshold is None: - return torch.ones_like(log_ratio_for_metrics) - log_veto_threshold = torch.log(torch.tensor(veto_threshold, device=log_ratio_for_metrics.device)) + return torch.ones_like(log_ratio) + log_veto_threshold = torch.log(torch.tensor(veto_threshold, device=log_ratio.device)) # For each sequence, if it has any catastrophic tokens, return 0 for the sequence - catastrophic_tokens = ((log_ratio_for_metrics < log_veto_threshold)) & loss_mask.bool() + catastrophic_tokens = ((log_ratio < log_veto_threshold)) & loss_mask.bool() has_catastrophic = catastrophic_tokens.any() - veto_mask = (~has_catastrophic).float().expand_as(log_ratio_for_metrics) + veto_mask = (~has_catastrophic).float().expand_as(log_ratio) 
metrics_append(metrics, "catastrophic_token_fraction", catastrophic_tokens.int()) metrics_append(metrics, "catastrophic_seq_fraction", has_catastrophic.int().expand_as(loss_mask)) @@ -108,19 +107,19 @@ def compute_train_infer_is_weights( train_log_probs: list[torch.Tensor], rollout_log_probs: list[torch.Tensor], loss_masks: list[torch.Tensor], -) -> Tuple[list[torch.Tensor], Dict[str, Any]]: +) -> Tuple[list[torch.Tensor], Dict[str, list[torch.Tensor]]]: """ Compute the truncated importance sampling (TIS) weights and metrics. Args: - train_log_probs: List of log probs from training backend (1D tensor) - rollout_log_probs: List of log probs from inference backend (1D tensor) - loss_masks: List of loss masks (1D tensor) - Note that for single turn RL, the loss_mask is [1] * response_length for each sequence (1D tensor) + train_log_probs: List of log probs from training backend. 1D tensor each. + rollout_log_probs: List of log probs from inference backend. 1D tensor each. + loss_masks: List of loss masks. 1D tensor each. + Note that for single turn RL, the loss_mask is [1] * response_length tensor for each sequence For multi turn RL, the tool response will be marked as 0 in the loss_mask. Returns: - weights: List of importance sampling weights (1D tensor) - metrics: The metrics for the importance sampling weights. + weights: List of importance sampling weights. 1D tensor each. + metrics: The metrics for the importance sampling weights, a dict of list[torch.Tensor]. 1D tensor each. 
""" level: str = args.train_infer_is_level @@ -129,14 +128,14 @@ def compute_train_infer_is_weights( # Validate input lists have same length and each sequence has matching shapes assert ( len(train_log_probs) == len(rollout_log_probs) == len(loss_masks) - ), f"Input lists must have same length: {len(train_log_probs)} vs {len(rollout_log_probs)} vs {len(loss_masks)}" + ), f"Input lists must have the same number of sequences: {len(train_log_probs)} vs {len(rollout_log_probs)} vs {len(loss_masks)}" for i, (train, rollout, loss_mask) in enumerate(zip(train_log_probs, rollout_log_probs, loss_masks)): assert ( train.shape == rollout.shape == loss_mask.shape ), f"Sequence {i}: shapes must match - train: {train.shape}, rollout: {rollout.shape}, loss_mask: {loss_mask.shape}" - SAFETY_BOUND = 20.0 + SAFETY_BOUND = 20.0 # Add a safety bound to avoid exp overflow all_weights = [] # handle each sequence independently @@ -149,11 +148,11 @@ def compute_train_infer_is_weights( # Per-token ratio (biased) log_ratio_for_metrics = raw_log_ratio elif level == "sequence": - # Product of ratios (unbiased) + # Product of ratios (unbiased but high variance) agg_log_ratio = (raw_log_ratio * loss_mask).sum() log_ratio_for_metrics = agg_log_ratio.expand_as(raw_log_ratio) elif level == "geometric": - # Geometric mean of ratios (experimental) + # Geometric mean of ratios (experimental but low variance) agg_log_ratio = (raw_log_ratio * loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1) log_ratio_for_metrics = agg_log_ratio.expand_as(raw_log_ratio) else: @@ -215,18 +214,19 @@ def compute_train_infer_is_weights_with_cp( loss_masks: list[torch.Tensor], total_lengths: list[int], response_lengths: list[int], -) -> Tuple[list[torch.Tensor], Dict[str, Any]]: +) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Compute the truncated importance sampling (TIS) weights and metrics with context parallel. 
Args: - train_log_probs: List of log probs from training backend on this cp rank (1D tensor) - rollout_log_probs: List of log probs from inference backend on this cp rank (1D tensor) - loss_masks: List of loss masks (1D tensor) + train_log_probs: List of log probs from training backend on this cp rank. 1D tensor each. + rollout_log_probs: List of log probs from inference backend on this cp rank. 1D tensor each. + loss_masks: List of loss masks. 1D tensor each. total_lengths: List of total lengths. response_lengths: List of response lengths. Returns: - is_weights: The importance sampling weights. [batch_size, seq_len] - is_metrics: The metrics for the importance sampling weights. + is_weights: Importance sampling weights on this CP rank and flattened along dim=0. + is_metrics: The metrics for the importance sampling weights, a dict of list[torch.Tensor]. 1D tensor each. + Also flattened along dim=0. """ # Gather cp slice from other cp ranks full_rollout_log_probs = [ @@ -246,11 +246,10 @@ def compute_train_infer_is_weights_with_cp( loss_masks=loss_masks, ) - # Slice cp slice and concat to the full response tensor + # Slice out the value shards for this CP rank and concat them into a 1D tensor along dim=0 for loss.py computation. def slice_cp_and_concat( values: list[torch.Tensor], total_lengths: list[int], response_lengths: list[int] - ) -> list[torch.Tensor]: - # reshape value to the sequence size of the cp rank. + ) -> torch.Tensor: values = [ # TODO: A rename of this function ? 
slice_log_prob_with_cp(values[i], total_lengths[i], response_lengths[i]) From 1074d7a69dab4d69b3a3fe84e46b10160692ca26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E6=99=A8=E9=98=B3?= Date: Tue, 14 Oct 2025 23:01:04 -0700 Subject: [PATCH 44/58] Update train_infer_is.py --- slime/utils/train_infer_is.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/slime/utils/train_infer_is.py b/slime/utils/train_infer_is.py index 53c7eb05a5..c0379257a8 100644 --- a/slime/utils/train_infer_is.py +++ b/slime/utils/train_infer_is.py @@ -8,7 +8,7 @@ def metrics_append(metrics: Dict[str, list[torch.Tensor]], key: str, value: torch.Tensor) -> None: """ - Every metrics-dict value is a list[torch.Tensor] with shapes exactly the same as log_probs. + Every metrics-dict value is a list of 1D tensor, i.e., [torch.Tensor] with shapes exactly the same as log_probs. All metrics will be aggregated and averaged by `sum_of_sample_mean` and divided by DP size automatically - If calculate_per_token_loss=False (default), the final results will first be averaged in each sequence, @@ -29,8 +29,8 @@ def metrics_append(metrics: Dict[str, list[torch.Tensor]], key: str, value: torc When calculate_per_token_loss = True: result = (0.1 + 0.2 + 0.1 + 0.2 + 0.3 + 0.4 + 0.5 + 0.6) / 8 = 2.4 / 8 = 0.3 For sequence-level metric: - original sequence lengths = [2, 5, 1] - We should expand the metrics to the length of each sequence + original sequence lengths = [2, 5, 2] + We should expand the metrics to the length of each sequence: value = [ [2, 2], [5, 5, 5, 5, 5], @@ -38,9 +38,7 @@ def metrics_append(metrics: Dict[str, list[torch.Tensor]], key: str, value: torc ] When calculate_per_token_loss = False (default): result = (2 + 2) / 2 + (5 + 5 + 5 + 5 + 5) / 5 + (1 + 1) / 2 = 2 + 5 + 1 = 8 / 3 = 2.6665 - When calculate_per_token_loss = True: - result = (2 + 2 + 5 + 5 + 5 + 5 + 5 + 1 + 1) / 8 = 31 / 8 = 3.875 ??? - This is not what we expect. 
So, when calculate_per_token_loss = True, sequence-level metrics becomes invalid !! + Note that for sequence-level, calculating token-level loss is invalid; thus, calculate_per_token_loss should always be False. """ if key not in metrics: metrics[key] = [] @@ -109,13 +107,13 @@ def compute_train_infer_is_weights( loss_masks: list[torch.Tensor], ) -> Tuple[list[torch.Tensor], Dict[str, list[torch.Tensor]]]: """ - Compute the truncated importance sampling (TIS) weights and metrics. + Compute the importance sampling (IS) weights and metrics between the inference and training engine. Args: - train_log_probs: List of log probs from training backend. 1D tensor each. + train_log_probs: List of log probs from training backend. 1D tensor each. Lengths can be different. rollout_log_probs: List of log probs from inference backend. 1D tensor each. loss_masks: List of loss masks. 1D tensor each. Note that for single turn RL, the loss_mask is [1] * response_length tensor for each sequence - For multi turn RL, the tool response will be marked as 0 in the loss_mask. + For multi-turn RL, the tool response will be marked as 0 in the loss_mask. Returns: weights: List of importance sampling weights. 1D tensor each. @@ -152,7 +150,7 @@ def compute_train_infer_is_weights( agg_log_ratio = (raw_log_ratio * loss_mask).sum() log_ratio_for_metrics = agg_log_ratio.expand_as(raw_log_ratio) elif level == "geometric": - # Geometric mean of ratios (experimental but low variance) + # Geometric mean of ratios (biased but low variance) agg_log_ratio = (raw_log_ratio * loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1) log_ratio_for_metrics = agg_log_ratio.expand_as(raw_log_ratio) else: @@ -218,7 +216,7 @@ def compute_train_infer_is_weights_with_cp( """ Compute the truncated importance sampling (TIS) weights and metrics with context parallel. Args: - train_log_probs: List of log probs from training backend on this cp rank. 1D tensor each. 
+ train_log_probs: List of log probs from training backend on this cp rank. 1D tensor each. Lengths can be different. rollout_log_probs: List of log probs from inference backend on this cp rank. 1D tensor each. loss_masks: List of loss masks. 1D tensor each. total_lengths: List of total lengths. @@ -251,7 +249,7 @@ def slice_cp_and_concat( values: list[torch.Tensor], total_lengths: list[int], response_lengths: list[int] ) -> torch.Tensor: values = [ - # TODO: A rename of this function ? + # TODO: A rename of this function? slice_log_prob_with_cp(values[i], total_lengths[i], response_lengths[i]) for i in range(len(values)) ] From 99224fab5e7220bd01e24ed6250adb995a057c44 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Wed, 15 Oct 2025 06:17:32 +0000 Subject: [PATCH 45/58] remove sh --- docs/en/get_started/quick_start.md | 2 +- docs/en/get_started/usage.md | 2 +- docs/zh/get_started/quick_start.md | 2 +- docs/zh/get_started/usage.md | 2 +- .../fully_async/run-qwen3-4b-fully_async.sh | 2 +- scripts/4b-cp2.sh | 162 ------------------ scripts/run-glm4.5-355B-A32B.sh | 2 +- scripts/run-qwen3-30B-A3B.sh | 43 ++--- scripts/run-qwen3-4B-tis-example.sh | 161 ----------------- tests/test_quick_start_glm4-9B.sh | 2 +- tests/test_qwen3-30B-A3B.sh | 2 +- 11 files changed, 25 insertions(+), 357 deletions(-) delete mode 100644 scripts/4b-cp2.sh delete mode 100644 scripts/run-qwen3-4B-tis-example.sh diff --git a/docs/en/get_started/quick_start.md b/docs/en/get_started/quick_start.md index 65a2b1a027..56e3092968 100644 --- a/docs/en/get_started/quick_start.md +++ b/docs/en/get_started/quick_start.md @@ -275,7 +275,7 @@ GRPO_ARGS=( - `--advantage-estimator`: In addition to [GRPO](https://arxiv.org/abs/2402.03300), slime also supports several other training algorithms, such as [GSPO](https://arxiv.org/abs/2507.18071), [Reinforce++](https://arxiv.org/abs/2501.03262) and [Reinforce++ Baseline](https://arxiv.org/abs/2501.03262), and [PPO](https://arxiv.org/abs/1707.06347). 
- `--calculate-per-token-loss`: By default, slime calculates the loss on a per-sample basis, i.e., `mean(sum(sample_i) / len(sample_i))`. To calculate the loss on a per-token basis, i.e., `sum(sum(sample_i)) / sum(len(sample_i))`, you can enable this flag. -- `--use-tis`: Enable this setting to use TIS (Truncated Importance Sampling), which is introduced by this [blog](https://fengyao.notion.site/off-policy-rl). +- `--use-train-infer-is`: Enable this setting to use TIS (Truncated Importance Sampling), which is introduced by this [blog](https://fengyao.notion.site/off-policy-rl). ### OPTIMIZER_ARGS: Optimizer Parameters diff --git a/docs/en/get_started/usage.md b/docs/en/get_started/usage.md index 26310cb9f0..827e5fd62f 100644 --- a/docs/en/get_started/usage.md +++ b/docs/en/get_started/usage.md @@ -177,7 +177,7 @@ Additionally, we provide a `metadata_key`, which defaults to `"metadata"`. When - `reinforce_plus_plus` and `reinforce_plus_plus_baseline` ([https://arxiv.org/abs/2501.03262](https://arxiv.org/abs/2501.03262)) - `ppo` ([https://arxiv.org/abs/1707.06347](https://arxiv.org/abs/1707.06347)) - `--calculate-per-token-loss`: By default, Slime calculates loss on a per-sample basis, i.e., `mean(sum(sample_i) / len(sample_i))`. Enable this flag to calculate loss on a per-token basis, i.e., `sum(sum(sample_i)) / sum(len(sample_i))`. -- `--use-tis`: Enable this setting to use TIS (Truncated Importance Sampling) (https://fengyao.notion.site/off-policy-rl). +- `--use-train-infer-is`: Enable this setting to use TIS (Truncated Importance Sampling) (https://fengyao.notion.site/off-policy-rl). 
## Custom Rollout Function diff --git a/docs/zh/get_started/quick_start.md b/docs/zh/get_started/quick_start.md index d4c24bdf57..09d17d144a 100644 --- a/docs/zh/get_started/quick_start.md +++ b/docs/zh/get_started/quick_start.md @@ -274,7 +274,7 @@ GRPO_ARGS=( - `--advantage-estimator`: 除去 [GRPO](https://arxiv.org/abs/2402.03300),slime 还支持丰富的其他训练算法,例如 [GSPO](https://arxiv.org/abs/2507.18071)、[Reinforce++](https://arxiv.org/abs/2501.03262) 与 [Reinforce++ Baseline](https://arxiv.org/abs/2501.03262)、以及 [PPO](https://arxiv.org/abs/1707.06347); - `--calculate-per-token-loss`:slime 中默认的方案是 per sample loss,即 `mean(sum(sample_i) / len(sample_i))`,如果需要计算 per token loss,即 `sum(sum(sample_i)) / sum(len(sample_i))`,可以开启 `--calculate-per-token-loss`; -- `--use-tis`:如果需要开启 TIS (Truncated Importance Sampling),可以开启这一设置。TIS 由此[博客](https://fengyao.notion.site/off-policy-rl)介绍。 +- `--use-train-infer-is`:如果需要开启 TIS (Truncated Importance Sampling),可以开启这一设置。TIS 由此[博客](https://fengyao.notion.site/off-policy-rl)介绍。 ### OPTIMIZER_ARGS: 优化器参数 diff --git a/docs/zh/get_started/usage.md b/docs/zh/get_started/usage.md index 332be9d417..47e1b09b85 100644 --- a/docs/zh/get_started/usage.md +++ b/docs/zh/get_started/usage.md @@ -181,7 +181,7 @@ sglang 的加载非常简单,只需要: - `reinforce_plus_plus` 与 `reinforce_plus_plus_baseline`(https://arxiv.org/abs/2501.03262); - `ppo`(https://arxiv.org/abs/1707.06347)。 - `--calculate-per-token-loss`:slime 中默认的方案是 per sample loss,即 `mean(sum(sample_i) / len(sample_i))`,如果需要计算 per token loss,即 `sum(sum(sample_i)) / sum(len(sample_i))`,可以开启 `--calculate-per-token-loss`; -- `--use-tis`:如果需要开启 tis(https://fengyao.notion.site/off-policy-rl),可以开启这一设置。 +- `--use-train-infer-is`:如果需要开启 tis(https://fengyao.notion.site/off-policy-rl),可以开启这一设置。 ## 自定义 rollout 函数 diff --git a/examples/fully_async/run-qwen3-4b-fully_async.sh b/examples/fully_async/run-qwen3-4b-fully_async.sh index 2ef6162ab0..3665c6fb62 100644 --- a/examples/fully_async/run-qwen3-4b-fully_async.sh +++ 
b/examples/fully_async/run-qwen3-4b-fully_async.sh @@ -79,7 +79,7 @@ GRPO_ARGS=( --eps-clip 0.2 --eps-clip-high 0.28 - --use-tis + --use-train-infer-is ) OPTIMIZER_ARGS=( diff --git a/scripts/4b-cp2.sh b/scripts/4b-cp2.sh deleted file mode 100644 index 0af58252bc..0000000000 --- a/scripts/4b-cp2.sh +++ /dev/null @@ -1,162 +0,0 @@ -#!/bin/bash - -# for rerun the task -pkill -9 sglang -sleep 3 -ray stop --force -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python -pkill -9 redis - -set -ex - -# will prevent ray from buffering stdout/stderr -export PYTHONBUFFERED=16 - -NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) -if [ "$NVLINK_COUNT" -gt 0 ]; then - HAS_NVLINK=1 -else - HAS_NVLINK=0 -fi -echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" - -SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" -source "${SCRIPT_DIR}/models/qwen3-4B.sh" - -CKPT_ARGS=( - --hf-checkpoint /root/Qwen3-4B - #--hf-checkpoint /root/Qwen3-4B-FP8 - --ref-load /root/Qwen3-4B_torch_dist - # --load /root/Qwen3-4B_slime/ - --save /root/Qwen3-4B_slime/ - --save-interval 20 -) - -ROLLOUT_ARGS=( - --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl - --input-key prompt - --label-key label - --apply-chat-template - --rollout-shuffle - --rm-type deepscaler - --num-rollout 3000 - --rollout-batch-size 8 - --n-samples-per-prompt 4 - --rollout-max-response-len 8192 - --rollout-temperature 0.8 - - --global-batch-size 32 - --balance-data -) - -EVAL_ARGS=( - # --eval-interval 20 - --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl - --n-samples-per-eval-prompt 1 - --eval-max-response-len 16384 - --eval-top-p 0.7 -) - -PERF_ARGS=( - --tensor-model-parallel-size 2 - --sequence-parallel - --pipeline-model-parallel-size 1 - --context-parallel-size 2 - --expert-model-parallel-size 1 - --expert-tensor-parallel-size 1 - - --recompute-granularity full - --recompute-method uniform - --recompute-num-layers 1 - - # 
--micro-batch-size 1 - --use-dynamic-batch-size - --max-tokens-per-gpu 9216 -) - -GRPO_ARGS=( - --advantage-estimator grpo - --use-kl-loss - --kl-loss-coef 0.00 - --kl-loss-type low_var_kl - --entropy-coef 0.00 - --eps-clip 0.2 - --eps-clip-high 0.28 -) - -IS_ARGS=( - --use-train-infer-is - --train-infer-is-level sequence - --train-infer-is-mode clip - --train-infer-is-upper-bound 2.0 - # --train-infer-is-lower-bound 0.5 - --train-infer-is-veto-threshold 1e-4 -) - -OPTIMIZER_ARGS=( - --optimizer adam - --lr 1e-6 - --lr-decay-style constant - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.98 -) - -WANDB_ARGS=( - --use-wandb - --wandb-project slime-tis - --wandb-group qwen3-4B-tis-seq - --wandb-key ${WANDB_KEY} -) - -SGLANG_ARGS=( - --rollout-num-gpus-per-engine 1 - --sglang-mem-fraction-static 0.7 -) - -MISC_ARGS=( - # default dropout in megatron is 0.1 - --attention-dropout 0.0 - --hidden-dropout 0.0 - # should be good for model performance - --accumulate-allreduce-grads-in-fp32 - --attention-softmax-in-fp32 - # need to comment this when using model with MLA - --attention-backend flash -) - -# launch the master node of ray in container -export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -export CUDA_VISIBLE_DEVICES=4,5,6,7 -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 4 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 - -# Build the runtime environment JSON with proper variable substitution -RUNTIME_ENV_JSON="{ - \"env_vars\": { - \"PYTHONPATH\": \"/root/Megatron-LM/\", - \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", - \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" - } -}" - -ray job submit --address="http://127.0.0.1:8265" \ - --runtime-env-json="${RUNTIME_ENV_JSON}" \ - -- python3 train.py \ - --actor-num-nodes 1 \ - --actor-num-gpus-per-node 4 \ - --colocate \ - ${MODEL_ARGS[@]} \ - ${CKPT_ARGS[@]} \ - ${ROLLOUT_ARGS[@]} \ - ${OPTIMIZER_ARGS[@]} \ - ${GRPO_ARGS[@]} \ - ${WANDB_ARGS[@]} \ - ${PERF_ARGS[@]} \ - ${EVAL_ARGS[@]} \ - 
${SGLANG_ARGS[@]} \ - ${MISC_ARGS[@]} \ - ${IS_ARGS[@]} \ No newline at end of file diff --git a/scripts/run-glm4.5-355B-A32B.sh b/scripts/run-glm4.5-355B-A32B.sh index 4a9792df0c..3d39ad334c 100644 --- a/scripts/run-glm4.5-355B-A32B.sh +++ b/scripts/run-glm4.5-355B-A32B.sh @@ -86,7 +86,7 @@ GRPO_ARGS=( --eps-clip 1e-4 --eps-clip-high 2e-4 - --use-tis + --use-train-infer-is ) OPTIMIZER_ARGS=( diff --git a/scripts/run-qwen3-30B-A3B.sh b/scripts/run-qwen3-30B-A3B.sh index 37bad293c2..3cdb749d9d 100644 --- a/scripts/run-qwen3-30B-A3B.sh +++ b/scripts/run-qwen3-30B-A3B.sh @@ -31,9 +31,9 @@ CKPT_ARGS=( --hf-checkpoint /root/Qwen3-30B-A3B #--hf-checkpoint /root/Qwen3-30B-A3B-FP8 --ref-load /root/Qwen3-30B-A3B_torch_dist - # --load /root/Qwen3-30B-A3B_slime/ - # --save /root/Qwen3-30B-A3B_slime/ - # --save-interval 20 + --load /root/Qwen3-30B-A3B_slime/ + --save /root/Qwen3-30B-A3B_slime/ + --save-interval 20 ) ROLLOUT_ARGS=( @@ -44,17 +44,17 @@ ROLLOUT_ARGS=( --rollout-shuffle --rm-type deepscaler --num-rollout 3000 - --rollout-batch-size 8 - --n-samples-per-prompt 4 + --rollout-batch-size 32 + --n-samples-per-prompt 8 --rollout-max-response-len 8192 --rollout-temperature 0.8 - --global-batch-size 32 + --global-batch-size 256 --balance-data ) EVAL_ARGS=( - # --eval-interval 20 + --eval-interval 20 --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 @@ -62,11 +62,11 @@ EVAL_ARGS=( ) PERF_ARGS=( - --tensor-model-parallel-size 2 + --tensor-model-parallel-size 4 --sequence-parallel --pipeline-model-parallel-size 1 - --context-parallel-size 2 - --expert-model-parallel-size 4 + --context-parallel-size 1 + --expert-model-parallel-size 8 --expert-tensor-parallel-size 1 --recompute-granularity full @@ -88,14 +88,6 @@ GRPO_ARGS=( --eps-clip-high 0.28 ) -IS_ARGS=( - --use-train-infer-is - --train-infer-is-level token - --train-infer-is-mode clip - --train-infer-is-eps-clip 0.2 - --train-infer-is-veto-threshold 1e-4 -) 
- OPTIMIZER_ARGS=( --optimizer adam --lr 1e-6 @@ -110,15 +102,14 @@ OPTIMIZER_ARGS=( ) WANDB_ARGS=( - --use-wandb - --wandb-project slime-dev - --wandb-group qwen3-30B-A3B-TIS - --wandb-run-id qwen3-30B-A3B-TIS-sequence - --wandb-key ${WANDB_KEY} + #--use-wandb + # --wandb-project slime-dev + # --wandb-group qwen3-30B-A3B-test + # --wandb-key ${WANDB_KEY} ) SGLANG_ARGS=( - --rollout-num-gpus-per-engine 4 + --rollout-num-gpus-per-engine 8 --sglang-mem-fraction-static 0.7 --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256) ) @@ -136,7 +127,7 @@ MISC_ARGS=( # launch the master node of ray in container export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 4 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 # Build the runtime environment JSON with proper variable substitution RUNTIME_ENV_JSON="{ @@ -151,7 +142,7 @@ ray job submit --address="http://127.0.0.1:8265" \ --runtime-env-json="${RUNTIME_ENV_JSON}" \ -- python3 train.py \ --actor-num-nodes 1 \ - --actor-num-gpus-per-node 4 \ + --actor-num-gpus-per-node 8 \ --colocate \ ${MODEL_ARGS[@]} \ ${CKPT_ARGS[@]} \ diff --git a/scripts/run-qwen3-4B-tis-example.sh b/scripts/run-qwen3-4B-tis-example.sh deleted file mode 100644 index 0c11a48646..0000000000 --- a/scripts/run-qwen3-4B-tis-example.sh +++ /dev/null @@ -1,161 +0,0 @@ -#!/bin/bash - -# for rerun the task -pkill -9 sglang -sleep 3 -ray stop --force -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python - -set -ex - -# will prevent ray from buffering stdout/stderr -export PYTHONBUFFERED=16 - -NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) -if [ "$NVLINK_COUNT" -gt 0 ]; then - HAS_NVLINK=1 -else - HAS_NVLINK=0 -fi -echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" - -SCRIPT_DIR="$(cd -- 
"$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" -source "${SCRIPT_DIR}/models/qwen3-4B.sh" - -CKPT_ARGS=( - --hf-checkpoint /root/Qwen3-4B - #--hf-checkpoint /root/Qwen3-4B-FP8 - --ref-load /root/Qwen3-4B_torch_dist - # --load /root/Qwen3-4B_slime/ - --save /root/Qwen3-4B_slime/ - --save-interval 2000 -) - -ROLLOUT_ARGS=( - --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl - --input-key prompt - --label-key label - --apply-chat-template - --rollout-shuffle - --rm-type deepscaler - --num-rollout 3000 - --rollout-batch-size 8 - --n-samples-per-prompt 4 - --rollout-max-response-len 8192 - --rollout-temperature 0.8 - - --global-batch-size 32 - --balance-data -) - -EVAL_ARGS=( - --eval-interval 20 - --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl - --n-samples-per-eval-prompt 1 - --eval-max-response-len 16384 - --eval-top-p 0.7 -) - -PERF_ARGS=( - --tensor-model-parallel-size 2 - --sequence-parallel - --pipeline-model-parallel-size 1 - --context-parallel-size 2 - --expert-model-parallel-size 1 - --expert-tensor-parallel-size 1 - - --recompute-granularity full - --recompute-method uniform - --recompute-num-layers 1 - - # --micro-batch-size 1 - --use-dynamic-batch-size - --max-tokens-per-gpu 9216 -) - -GRPO_ARGS=( - --advantage-estimator grpo - --use-kl-loss - --kl-loss-coef 0.00 - --kl-loss-type low_var_kl - --entropy-coef 0.00 - --eps-clip 0.2 - --eps-clip-high 0.28 -) - -OPTIMIZER_ARGS=( - --optimizer adam - --lr 1e-6 - --lr-decay-style constant - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.98 -) - -WANDB_ARGS=( - --use-wandb - --wandb-project slime-is - --wandb-group qwen3-4B-is-cp2 - --wandb-key ${WANDB_KEY} -) - -SGLANG_ARGS=( - --rollout-num-gpus-per-engine 1 - --sglang-mem-fraction-static 0.7 -) - -MISC_ARGS=( - # default dropout in megatron is 0.1 - --attention-dropout 0.0 - --hidden-dropout 0.0 - # should be good for model performance - --accumulate-allreduce-grads-in-fp32 - --attention-softmax-in-fp32 - # need to comment this 
when using model with MLA - --attention-backend flash -) - -TIS_ARGS=( - --use-train-infer-is - --train-infer-is-level token - --train-infer-is-mode clip - --train-infer-is-upper-bound 2.0 - --train-infer-is-lower-bound 0.5 - --train-infer-is-veto-threshold 1e-4 -) - -export CUDA_VISIBLE_DEVICES=4,5,6,7 -# launch the master node of ray in container -export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 4 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 - -# Build the runtime environment JSON with proper variable substitution -RUNTIME_ENV_JSON="{ - \"env_vars\": { - \"PYTHONPATH\": \"/root/Megatron-LM/\", - \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", - \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" - } -}" - -ray job submit --address="http://127.0.0.1:8265" \ - --runtime-env-json="${RUNTIME_ENV_JSON}" \ - -- python3 train.py \ - --actor-num-nodes 1 \ - --actor-num-gpus-per-node 4 \ - --colocate \ - ${MODEL_ARGS[@]} \ - ${CKPT_ARGS[@]} \ - ${ROLLOUT_ARGS[@]} \ - ${OPTIMIZER_ARGS[@]} \ - ${GRPO_ARGS[@]} \ - ${WANDB_ARGS[@]} \ - ${PERF_ARGS[@]} \ - ${EVAL_ARGS[@]} \ - ${SGLANG_ARGS[@]} \ - ${MISC_ARGS[@]} \ - ${TIS_ARGS[@]} \ No newline at end of file diff --git a/tests/test_quick_start_glm4-9B.sh b/tests/test_quick_start_glm4-9B.sh index 99832b5e43..3d1aab9e66 100644 --- a/tests/test_quick_start_glm4-9B.sh +++ b/tests/test_quick_start_glm4-9B.sh @@ -76,7 +76,7 @@ GRPO_ARGS=( --eps-clip 0.2 --eps-clip-high 0.28 - --use-tis + --use-train-infer-is --calculate-per-token-loss ) diff --git a/tests/test_qwen3-30B-A3B.sh b/tests/test_qwen3-30B-A3B.sh index 531eb613ba..98efc3069a 100644 --- a/tests/test_qwen3-30B-A3B.sh +++ b/tests/test_qwen3-30B-A3B.sh @@ -73,7 +73,7 @@ GRPO_ARGS=( --entropy-coef 0.00 --eps-clip 4e-4 - --use-tis + --use-train-infer-is --use-routing-replay ) From 8a8c44c50690261ca1b5e32b5039fb35ddb53a77 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Thu, 9 Oct 2025 02:10:43 +0000 Subject: 
[PATCH 46/58] logging a whole sequence --- scripts/run-qwen3-30B-A3B.sh | 30 ++++++++++++++------------- slime/backends/megatron_utils/loss.py | 3 +++ 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/scripts/run-qwen3-30B-A3B.sh b/scripts/run-qwen3-30B-A3B.sh index 3cdb749d9d..c1f1cc43ea 100644 --- a/scripts/run-qwen3-30B-A3B.sh +++ b/scripts/run-qwen3-30B-A3B.sh @@ -33,7 +33,7 @@ CKPT_ARGS=( --ref-load /root/Qwen3-30B-A3B_torch_dist --load /root/Qwen3-30B-A3B_slime/ --save /root/Qwen3-30B-A3B_slime/ - --save-interval 20 + --save-interval 1000 ) ROLLOUT_ARGS=( @@ -44,29 +44,29 @@ ROLLOUT_ARGS=( --rollout-shuffle --rm-type deepscaler --num-rollout 3000 - --rollout-batch-size 32 - --n-samples-per-prompt 8 + --rollout-batch-size 8 + --n-samples-per-prompt 4 --rollout-max-response-len 8192 --rollout-temperature 0.8 - --global-batch-size 256 + --global-batch-size 32 --balance-data ) EVAL_ARGS=( - --eval-interval 20 - --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl - --n-samples-per-eval-prompt 16 - --eval-max-response-len 16384 - --eval-top-p 0.7 + # --eval-interval 20 + # --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl + # --n-samples-per-eval-prompt 16 + # --eval-max-response-len 16384 + # --eval-top-p 0.7 ) PERF_ARGS=( - --tensor-model-parallel-size 4 + --tensor-model-parallel-size 2 --sequence-parallel --pipeline-model-parallel-size 1 --context-parallel-size 1 - --expert-model-parallel-size 8 + --expert-model-parallel-size 4 --expert-tensor-parallel-size 1 --recompute-granularity full @@ -86,6 +86,8 @@ GRPO_ARGS=( --entropy-coef 0.00 --eps-clip 0.2 --eps-clip-high 0.28 + + --use-tis ) OPTIMIZER_ARGS=( @@ -109,7 +111,7 @@ WANDB_ARGS=( ) SGLANG_ARGS=( - --rollout-num-gpus-per-engine 8 + --rollout-num-gpus-per-engine 4 --sglang-mem-fraction-static 0.7 --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256) ) @@ -127,7 +129,7 @@ MISC_ARGS=( # launch the master node of ray in container export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -ray start --head 
--node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 4 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 # Build the runtime environment JSON with proper variable substitution RUNTIME_ENV_JSON="{ @@ -142,7 +144,7 @@ ray job submit --address="http://127.0.0.1:8265" \ --runtime-env-json="${RUNTIME_ENV_JSON}" \ -- python3 train.py \ --actor-num-nodes 1 \ - --actor-num-gpus-per-node 8 \ + --actor-num-gpus-per-node 4 \ --colocate \ ${MODEL_ARGS[@]} \ ${CKPT_ARGS[@]} \ diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 2290fda8dd..afe5af9863 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -309,6 +309,9 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean): assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS" rollout_log_probs = torch.cat(batch["rollout_log_probs"], dim=0) old_log_probs = torch.cat(batch["log_probs"], dim=0) + print("================================") + print(f"old_log_probs.shape: {old_log_probs.shape}, rollout_log_probs.shape: {rollout_log_probs.shape}") + print("================================") tis = torch.exp(old_log_probs - rollout_log_probs) ois = (-ppo_kl).exp() From 52a401e4251a1c96f3d1bbb3d66769cd958746ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E6=99=A8=E9=98=B3?= Date: Tue, 14 Oct 2025 23:37:14 -0700 Subject: [PATCH 47/58] Update run-qwen3-30B-A3B.sh --- scripts/run-qwen3-30B-A3B.sh | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/scripts/run-qwen3-30B-A3B.sh b/scripts/run-qwen3-30B-A3B.sh index c1f1cc43ea..eb35f12d66 100644 --- a/scripts/run-qwen3-30B-A3B.sh +++ b/scripts/run-qwen3-30B-A3B.sh @@ -86,8 +86,15 @@ GRPO_ARGS=( --entropy-coef 0.00 --eps-clip 0.2 --eps-clip-high 0.28 +) - --use-tis +IS_ARGS=( + --use-train-infer-is + 
--train-infer-is-level geometric + --train-infer-is-mode mask + --train-infer-is-lower-bound 0.5 + --train-infer-is-upper-bound 2.0 + --train-infer-is-veto-threshold 1e-3 ) OPTIMIZER_ARGS=( From ac4e63ad3fd607c601e2ecd036c54059bff73b2b Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Wed, 15 Oct 2025 06:37:53 +0000 Subject: [PATCH 48/58] create test scripts --- ...3-30B-A3B.sh => run-qwen3-30B-A3B-test.sh} | 25 ++++++++++--------- tests/test_quick_start_glm4_9B.py | 2 +- tests/test_qwen3_30B_A3B.py | 2 +- 3 files changed, 15 insertions(+), 14 deletions(-) rename scripts/{run-qwen3-30B-A3B.sh => run-qwen3-30B-A3B-test.sh} (89%) diff --git a/scripts/run-qwen3-30B-A3B.sh b/scripts/run-qwen3-30B-A3B-test.sh similarity index 89% rename from scripts/run-qwen3-30B-A3B.sh rename to scripts/run-qwen3-30B-A3B-test.sh index eb35f12d66..2bfbfc6395 100644 --- a/scripts/run-qwen3-30B-A3B.sh +++ b/scripts/run-qwen3-30B-A3B-test.sh @@ -31,9 +31,9 @@ CKPT_ARGS=( --hf-checkpoint /root/Qwen3-30B-A3B #--hf-checkpoint /root/Qwen3-30B-A3B-FP8 --ref-load /root/Qwen3-30B-A3B_torch_dist - --load /root/Qwen3-30B-A3B_slime/ - --save /root/Qwen3-30B-A3B_slime/ - --save-interval 1000 + # --load /root/Qwen3-30B-A3B_slime/ + # --save /root/Qwen3-30B-A3B_slime/ + # --save-interval 20 ) ROLLOUT_ARGS=( @@ -55,17 +55,17 @@ ROLLOUT_ARGS=( EVAL_ARGS=( # --eval-interval 20 - # --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl - # --n-samples-per-eval-prompt 16 - # --eval-max-response-len 16384 - # --eval-top-p 0.7 + --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl + --n-samples-per-eval-prompt 16 + --eval-max-response-len 16384 + --eval-top-p 0.7 ) PERF_ARGS=( --tensor-model-parallel-size 2 --sequence-parallel --pipeline-model-parallel-size 1 - --context-parallel-size 1 + --context-parallel-size 2 --expert-model-parallel-size 4 --expert-tensor-parallel-size 1 @@ -111,10 +111,11 @@ OPTIMIZER_ARGS=( ) WANDB_ARGS=( - #--use-wandb - # --wandb-project slime-dev - # --wandb-group 
qwen3-30B-A3B-test - # --wandb-key ${WANDB_KEY} + --use-wandb + --wandb-project slime-dev + --wandb-group qwen3-30B-A3B-TIS + --wandb-run-id qwen3-30B-A3B-TIS-sequence + --wandb-key ${WANDB_KEY} ) SGLANG_ARGS=( diff --git a/tests/test_quick_start_glm4_9B.py b/tests/test_quick_start_glm4_9B.py index c491be37c4..63b648dfb9 100644 --- a/tests/test_quick_start_glm4_9B.py +++ b/tests/test_quick_start_glm4_9B.py @@ -67,7 +67,7 @@ def execute(): "--entropy-coef 0.00 " "--eps-clip 0.2 " "--eps-clip-high 0.28 " - "--use-tis " + "--use-train-infer-is " "--calculate-per-token-loss " ) diff --git a/tests/test_qwen3_30B_A3B.py b/tests/test_qwen3_30B_A3B.py index 6db31bbb0e..e7c4ff07d8 100644 --- a/tests/test_qwen3_30B_A3B.py +++ b/tests/test_qwen3_30B_A3B.py @@ -69,7 +69,7 @@ def execute(): "--kl-coef 0.00 " "--entropy-coef 0.00 " "--eps-clip 4e-4 " - "--use-tis " + "--use-train-infer-is " "--use-routing-replay " ) From 4df0724961f678f0751e4e9e3a17e5f0c5b77981 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Wed, 15 Oct 2025 06:40:24 +0000 Subject: [PATCH 49/58] revert change in qwen3 30B sh --- scripts/run-qwen3-30B-A3B.sh | 156 +++++++++++++++++++++++++++++++++++ 1 file changed, 156 insertions(+) create mode 100644 scripts/run-qwen3-30B-A3B.sh diff --git a/scripts/run-qwen3-30B-A3B.sh b/scripts/run-qwen3-30B-A3B.sh new file mode 100644 index 0000000000..3cdb749d9d --- /dev/null +++ b/scripts/run-qwen3-30B-A3B.sh @@ -0,0 +1,156 @@ +#!/bin/bash + +# for rerun the task +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python +pkill -9 redis + +set -ex + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=16 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" 
&>/dev/null && pwd)" +source "${SCRIPT_DIR}/models/qwen3-30B-A3B.sh" + +CKPT_ARGS=( + --hf-checkpoint /root/Qwen3-30B-A3B + #--hf-checkpoint /root/Qwen3-30B-A3B-FP8 + --ref-load /root/Qwen3-30B-A3B_torch_dist + --load /root/Qwen3-30B-A3B_slime/ + --save /root/Qwen3-30B-A3B_slime/ + --save-interval 20 +) + +ROLLOUT_ARGS=( + --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type deepscaler + --num-rollout 3000 + --rollout-batch-size 32 + --n-samples-per-prompt 8 + --rollout-max-response-len 8192 + --rollout-temperature 0.8 + + --global-batch-size 256 + --balance-data +) + +EVAL_ARGS=( + --eval-interval 20 + --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl + --n-samples-per-eval-prompt 16 + --eval-max-response-len 16384 + --eval-top-p 0.7 +) + +PERF_ARGS=( + --tensor-model-parallel-size 4 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 8 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + # --micro-batch-size 1 + --use-dynamic-batch-size + --max-tokens-per-gpu 20480 +) + +GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 + + --optimizer-cpu-offload + --overlap-cpu-optimizer-d2h-h2d + --use-precision-aware-optimizer +) + +WANDB_ARGS=( + #--use-wandb + # --wandb-project slime-dev + # --wandb-group qwen3-30B-A3B-test + # --wandb-key ${WANDB_KEY} +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 8 + --sglang-mem-fraction-static 0.7 + --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256) +) + +MISC_ARGS=( + # default dropout in megatron is 0.1 + --attention-dropout 0.0 + 
--hidden-dropout 0.0 + # should be good for model performance + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + # need to comment this when using model with MLA + --attention-backend flash +) + +# launch the master node of ray in container +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +# Build the runtime environment JSON with proper variable substitution +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" + } +}" + +ray job submit --address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 8 \ + --colocate \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} From 7b013692c942f9d8a62dc81045f9e3f9170b53dc Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Wed, 15 Oct 2025 06:42:18 +0000 Subject: [PATCH 50/58] remove two tests sh --- tests/test_quick_start_glm4-9B.sh | 140 --------------------------- tests/test_qwen3-30B-A3B.sh | 151 ------------------------------ 2 files changed, 291 deletions(-) delete mode 100644 tests/test_quick_start_glm4-9B.sh delete mode 100644 tests/test_qwen3-30B-A3B.sh diff --git a/tests/test_quick_start_glm4-9B.sh b/tests/test_quick_start_glm4-9B.sh deleted file mode 100644 index 3d1aab9e66..0000000000 --- a/tests/test_quick_start_glm4-9B.sh +++ /dev/null @@ -1,140 +0,0 @@ -#!/bin/bash - -set -e - -# will prevent ray from buffering stdout/stderr -export PYTHONBUFFERED=16 - -NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) -if [ "$NVLINK_COUNT" -gt 0 ]; then - HAS_NVLINK=1 -else - 
HAS_NVLINK=0 -fi -echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" - -SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" -source "${SCRIPT_DIR}/../scripts/models/glm4-9B.sh" - -CKPT_ARGS=( - --hf-checkpoint /root/models/GLM-Z1-9B-0414/ - --ref-load /root/GLM-Z1-9B-0414_torch_dist - - --fp8-format e4m3 - --fp8-recipe blockwise -) - -ROLLOUT_ARGS=( - --prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl - --input-key prompt - --label-key label - --apply-chat-template - --rollout-shuffle - - --rm-type deepscaler - - --num-rollout 3 - --rollout-batch-size 8 - --n-samples-per-prompt 8 - --rollout-max-response-len 8192 - --rollout-temperature 0.8 - - --global-batch-size 32 - --balance-data -) - -EVAL_ARGS=( - --eval-interval 20 - --eval-prompt-data aime24 /root/datasets/aime-2024/aime-2024.jsonl - --n-samples-per-eval-prompt 1 - --eval-max-response-len 16384 - --eval-top-k 1 -) - -PERF_ARGS=( - --tensor-model-parallel-size 2 - --sequence-parallel - --pipeline-model-parallel-size 1 - --context-parallel-size 2 - --expert-model-parallel-size 1 - --expert-tensor-parallel-size 1 - - --recompute-granularity full - --recompute-method uniform - --recompute-num-layers 1 - - --use-dynamic-batch-size - --max-tokens-per-gpu 4608 -) - -GRPO_ARGS=( - --advantage-estimator grpo - --use-kl-loss - --kl-loss-coef 0.00 - --kl-loss-type low_var_kl - --entropy-coef 0.00 - --eps-clip 0.2 - --eps-clip-high 0.28 - - --use-train-infer-is - --calculate-per-token-loss -) - -OPTIMIZER_ARGS=( - --optimizer adam - --lr 1e-6 - --lr-decay-style constant - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.98 -) - -SGLANG_ARGS=( - --rollout-num-gpus-per-engine 2 - - --use-slime-router -) - -MISC_ARGS=( - # default dropout in megatron is 0.1 - --attention-dropout 0.0 - --hidden-dropout 0.0 - # should be good for model performance - --accumulate-allreduce-grads-in-fp32 - --attention-softmax-in-fp32 - # need to comment this when using model 
with MLA - --attention-backend flash -) - -# launch the master node of ray in container -export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -export no_proxy="127.0.0.1,${MASTER_ADDR}" -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats - -# Build the runtime environment JSON with proper variable substitution -RUNTIME_ENV_JSON="{ - \"env_vars\": { - \"PYTHONPATH\": \"/root/Megatron-LM/\", - \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", - \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\", - \"no_proxy\": \"${no_proxy}\" - } -}" - -ray job submit --address="http://127.0.0.1:8265" \ - --runtime-env-json="${RUNTIME_ENV_JSON}" \ - -- python3 train.py \ - --actor-num-nodes 1 \ - --actor-num-gpus-per-node 4 \ - --rollout-num-gpus 4 \ - ${MODEL_ARGS[@]} \ - ${CKPT_ARGS[@]} \ - ${ROLLOUT_ARGS[@]} \ - ${OPTIMIZER_ARGS[@]} \ - ${GRPO_ARGS[@]} \ - ${WANDB_ARGS[@]} \ - ${PERF_ARGS[@]} \ - ${EVAL_ARGS[@]} \ - ${SGLANG_ARGS[@]} \ - ${MISC_ARGS[@]} \ - --ci-test diff --git a/tests/test_qwen3-30B-A3B.sh b/tests/test_qwen3-30B-A3B.sh deleted file mode 100644 index 98efc3069a..0000000000 --- a/tests/test_qwen3-30B-A3B.sh +++ /dev/null @@ -1,151 +0,0 @@ -#!/bin/bash - -set -e - -# will prevent ray from buffering stdout/stderr -export PYTHONBUFFERED=16 - -NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) -if [ "$NVLINK_COUNT" -gt 0 ]; then - HAS_NVLINK=1 -else - HAS_NVLINK=0 -fi -echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" - -SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" -source "${SCRIPT_DIR}/../scripts/models/qwen3-30B-A3B.sh" - -CKPT_ARGS=( - --hf-checkpoint /root/models/Qwen3-30B-A3B-FP8 - --ref-load /root/Qwen3-30B-A3B_torch_dist -) - -ROLLOUT_ARGS=( - --prompt-data /root/datasets/dapo-math-17k/dapo-math-17k.jsonl - --input-key prompt - --label-key label - --apply-chat-template - --rollout-shuffle - - --rm-type deepscaler - - --num-rollout 3 - --rollout-batch-size 8 - 
--n-samples-per-prompt 8 - --rollout-max-response-len 8192 - --rollout-temperature 0.8 - - --global-batch-size 32 - --balance-data -) - -EVAL_ARGS=( - --eval-interval 20 - --eval-prompt-data aime24 /root/datasets/aime-2024/aime-2024.jsonl - --n-samples-per-eval-prompt 1 - --eval-max-response-len 16384 - --eval-top-k 1 -) - -PERF_ARGS=( - --tensor-model-parallel-size 4 - --sequence-parallel - --pipeline-model-parallel-size 1 - --context-parallel-size 2 - --expert-model-parallel-size 8 - --expert-tensor-parallel-size 1 - - --recompute-granularity full - --recompute-method uniform - --recompute-num-layers 1 - - --use-dynamic-batch-size - --max-tokens-per-gpu 16384 -) - -GRPO_ARGS=( - --advantage-estimator gspo - --use-kl-loss - --kl-loss-coef 0.00 - --kl-loss-type low_var_kl - --kl-coef 0.00 - --entropy-coef 0.00 - --eps-clip 4e-4 - - --use-train-infer-is - --use-routing-replay -) - -OPTIMIZER_ARGS=( - --optimizer adam - --lr 1e-6 - --lr-decay-style constant - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.98 - - --optimizer-cpu-offload - --overlap-cpu-optimizer-d2h-h2d - --use-precision-aware-optimizer -) - -SGLANG_ARGS=( - --rollout-num-gpus-per-engine 8 - --sglang-mem-fraction-static 0.8 - - --sglang-moe-a2a-backend deepep - --sglang-deepep-mode auto - - --sglang-max-running-requests 512 - --sglang-disable-radix-cache -) - -MISC_ARGS=( - # default dropout in megatron is 0.1 - --attention-dropout 0.0 - --hidden-dropout 0.0 - # should be good for model performance - --accumulate-allreduce-grads-in-fp32 - --attention-softmax-in-fp32 - # need to comment this when using model with MLA - --attention-backend flash - - --moe-token-dispatcher-type flex - --moe-enable-deepep -) - -# launch the master node of ray in container -export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -export no_proxy="127.0.0.1,${MASTER_ADDR}" -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats - -export CUDA_HOME=${CUDA_HOME:-"/usr/local/cuda"} -# Build the 
runtime environment JSON with proper variable substitution -RUNTIME_ENV_JSON="{ - \"env_vars\": { - \"PYTHONPATH\": \"/root/Megatron-LM/\", - \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", - \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\", - \"no_proxy\": \"${no_proxy}\", - \"CUDA_HOME\": \"${CUDA_HOME}\" - } -}" - -ray job submit --address="http://127.0.0.1:8265" \ - --runtime-env-json="${RUNTIME_ENV_JSON}" \ - -- python3 train.py \ - --actor-num-nodes 1 \ - --actor-num-gpus-per-node 8 \ - --colocate \ - ${MODEL_ARGS[@]} \ - ${CKPT_ARGS[@]} \ - ${ROLLOUT_ARGS[@]} \ - ${OPTIMIZER_ARGS[@]} \ - ${GRPO_ARGS[@]} \ - ${WANDB_ARGS[@]} \ - ${PERF_ARGS[@]} \ - ${EVAL_ARGS[@]} \ - ${SGLANG_ARGS[@]} \ - ${MISC_ARGS[@]} \ - --ci-test From 4455479fa1220c5685387a6286e71bf28cfe8814 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Wed, 15 Oct 2025 21:10:09 +0000 Subject: [PATCH 51/58] add kl metrics --- slime/utils/train_infer_is.py | 67 +++++++++++++++++++++++++++++++---- 1 file changed, 60 insertions(+), 7 deletions(-) diff --git a/slime/utils/train_infer_is.py b/slime/utils/train_infer_is.py index c0379257a8..43dce53e8b 100644 --- a/slime/utils/train_infer_is.py +++ b/slime/utils/train_infer_is.py @@ -5,6 +5,16 @@ from slime.backends.megatron_utils.cp_utils import all_gather_with_cp, slice_log_prob_with_cp +def masked_sum(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False) -> torch.Tensor: + result = (x * loss_mask).sum() + return result.expand_as(x) if expand else result + + +def masked_mean(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False) -> torch.Tensor: + result = masked_sum(x, loss_mask) / torch.clamp_min(loss_mask.sum(), 1) + return result.expand_as(x) if expand else result + + def metrics_append(metrics: Dict[str, list[torch.Tensor]], key: str, value: torch.Tensor) -> None: """ @@ -138,8 +148,9 @@ def compute_train_infer_is_weights( # handle each sequence independently for train_log_prob, rollout_log_prob, loss_mask in zip(train_log_probs, rollout_log_probs, 
loss_masks): - raw_log_ratio = train_log_prob - rollout_log_prob loss_mask = loss_mask.float() + add_ppl_metrics(train_log_prob, rollout_log_prob, loss_mask, metrics) + raw_log_ratio = train_log_prob - rollout_log_prob # level: The aggregation level for the importance sampling weights. if level == "token": @@ -147,24 +158,21 @@ def compute_train_infer_is_weights( log_ratio_for_metrics = raw_log_ratio elif level == "sequence": # Product of ratios (unbiased but high variance) - agg_log_ratio = (raw_log_ratio * loss_mask).sum() - log_ratio_for_metrics = agg_log_ratio.expand_as(raw_log_ratio) + log_ratio_for_metrics = masked_sum(raw_log_ratio, loss_mask, expand=True) elif level == "geometric": # Geometric mean of ratios (biased but low variance) - agg_log_ratio = (raw_log_ratio * loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1) - log_ratio_for_metrics = agg_log_ratio.expand_as(raw_log_ratio) + log_ratio_for_metrics = masked_mean(raw_log_ratio, loss_mask, expand=True) else: raise ValueError(f"Invalid importance sampling level: {level}") log_ratio_safe = torch.clamp(log_ratio_for_metrics, min=-SAFETY_BOUND, max=SAFETY_BOUND) weights = torch.exp(log_ratio_safe) + metrics_append(metrics, "ratio_mean_before_tis", weights) # mask out catastrophic tokens if args.train_infer_is_veto_threshold is not None: veto_mask = calculate_veto_mask(raw_log_ratio, loss_mask, args.train_infer_is_veto_threshold, metrics) - metrics_append(metrics, "raw_ratio_mean", weights) - # mode: how to handle the importance sampling weights exceeding the thresholds. if args.train_infer_is_mode == "truncate": # Cap the importance sampling weights at the upper threshold @@ -261,3 +269,48 @@ def slice_cp_and_concat( is_metrics[key] = values return is_weights, is_metrics + + +def add_ppl_metrics( + train_log_prob: torch.Tensor, + rollout_log_prob: torch.Tensor, + loss_mask: torch.Tensor, + metrics: Dict[str, list[torch.Tensor]], +): + loss_mask = loss_mask.float() + + # 1. 
Training policy perplexity metrics + mean_log_prob_training = masked_mean(train_log_prob, loss_mask, expand=True) + training_log_ppl = -mean_log_prob_training + training_ppl = torch.exp(training_log_ppl) + metrics_append(metrics, "training_log_ppl", training_log_ppl) + metrics_append(metrics, "training_ppl", training_ppl) + + # 2. Rollout policy perplexity metrics + mean_log_prob_rollout = masked_mean(rollout_log_prob, loss_mask, expand=True) + rollout_log_ppl = -mean_log_prob_rollout + rollout_ppl = torch.exp(rollout_log_ppl) + metrics_append(metrics, "rollout_log_ppl", rollout_log_ppl) + metrics_append(metrics, "rollout_ppl", rollout_ppl) + + # 2a. kl: Direct estimator for KL(π_rollout || π_training) + # This is the standard KL divergence: E[log(π_rollout) - log(π_training)] + # Positive value means rollout policy is more confident than training policy + kl_per_token = rollout_log_prob - train_log_prob + metrics_append(metrics, "kl", kl_per_token) + + # 2b. K3 KL estimator for improved stability + log_ratio = train_log_prob - rollout_log_prob + k3_kl_matrix = torch.exp(log_ratio) - log_ratio - 1 + metrics_append(metrics, "k3_kl", k3_kl_matrix) + + # 2d. Log perplexity differences and related statistics + log_ppl_diff = mean_log_prob_rollout - mean_log_prob_training + metrics_append(metrics, "log_ppl_diff", log_ppl_diff) + metrics_append(metrics, "log_ppl_abs_diff", log_ppl_diff.abs()) + metrics_append(metrics, "log_ppl_diff_max", log_ppl_diff) + metrics_append(metrics, "log_ppl_diff_min", log_ppl_diff) + + # 2e. 
Perplexity ratio between training and rollout policies + ppl_ratio = torch.exp(log_ppl_diff) + metrics_append(metrics, "ppl_ratio", ppl_ratio) From 3319f382cea12a3588a4250baaa03d94c4ef8d84 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Wed, 15 Oct 2025 21:16:51 +0000 Subject: [PATCH 52/58] fix comment --- slime/utils/train_infer_is.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/slime/utils/train_infer_is.py b/slime/utils/train_infer_is.py index 43dce53e8b..bceeb39a87 100644 --- a/slime/utils/train_infer_is.py +++ b/slime/utils/train_infer_is.py @@ -293,24 +293,32 @@ def add_ppl_metrics( metrics_append(metrics, "rollout_log_ppl", rollout_log_ppl) metrics_append(metrics, "rollout_ppl", rollout_ppl) - # 2a. kl: Direct estimator for KL(π_rollout || π_training) + # 3a. kl: Direct estimator for KL(π_rollout || π_training) # This is the standard KL divergence: E[log(π_rollout) - log(π_training)] # Positive value means rollout policy is more confident than training policy kl_per_token = rollout_log_prob - train_log_prob metrics_append(metrics, "kl", kl_per_token) - # 2b. K3 KL estimator for improved stability + # 3b. K3 KL estimator for improved stability + # More stable for small KL values using: E[exp(log_ratio) - log_ratio - 1] + # Formula: KL ≈ E[r - log(r) - 1] where r = π_training/π_rollout log_ratio = train_log_prob - rollout_log_prob k3_kl_matrix = torch.exp(log_ratio) - log_ratio - 1 metrics_append(metrics, "k3_kl", k3_kl_matrix) - # 2d. Log perplexity differences and related statistics + # 3c. 
Log PPL difference (sequence-level perplexity difference) + # log_ppl_diff = mean_log_prob_rollout - mean_log_prob_training + # Since ppl = exp(-log_prob), we have: + # log(ppl_ratio) = log(training_ppl/rollout_ppl) = log_ppl_diff + # Positive value means training assigns lower probability (higher PPL) than rollout log_ppl_diff = mean_log_prob_rollout - mean_log_prob_training metrics_append(metrics, "log_ppl_diff", log_ppl_diff) metrics_append(metrics, "log_ppl_abs_diff", log_ppl_diff.abs()) metrics_append(metrics, "log_ppl_diff_max", log_ppl_diff) metrics_append(metrics, "log_ppl_diff_min", log_ppl_diff) - # 2e. Perplexity ratio between training and rollout policies + # 3d. PPL ratio (how much higher is training PPL vs rollout PPL) + # For numerical stability, compute in log space using log_ppl_diff + # Note: log_ppl_diff = log(ppl_ratio), so ppl_ratio = exp(log_ppl_diff) ppl_ratio = torch.exp(log_ppl_diff) metrics_append(metrics, "ppl_ratio", ppl_ratio) From bf8df317f42296893935aa482185cee6f9ca650c Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Wed, 15 Oct 2025 22:40:21 +0000 Subject: [PATCH 53/58] adding kl metrics --- slime/backends/megatron_utils/loss.py | 2 +- slime/utils/train_infer_is.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 3d842f3ea9..33f985a868 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -1,6 +1,6 @@ from argparse import Namespace from collections.abc import Callable, Iterator -from typing import Optional, Union +from typing import Union import torch from megatron.core import mpu diff --git a/slime/utils/train_infer_is.py b/slime/utils/train_infer_is.py index bceeb39a87..d455a0b1b4 100644 --- a/slime/utils/train_infer_is.py +++ b/slime/utils/train_infer_is.py @@ -150,28 +150,28 @@ def compute_train_infer_is_weights( for train_log_prob, rollout_log_prob, loss_mask in 
zip(train_log_probs, rollout_log_probs, loss_masks): loss_mask = loss_mask.float() add_ppl_metrics(train_log_prob, rollout_log_prob, loss_mask, metrics) - raw_log_ratio = train_log_prob - rollout_log_prob + raw_log_ratio_diff = train_log_prob - rollout_log_prob # level: The aggregation level for the importance sampling weights. if level == "token": # Per-token ratio (biased) - log_ratio_for_metrics = raw_log_ratio + log_ratio_for_metrics = raw_log_ratio_diff elif level == "sequence": # Product of ratios (unbiased but high variance) - log_ratio_for_metrics = masked_sum(raw_log_ratio, loss_mask, expand=True) + log_ratio_for_metrics = masked_sum(raw_log_ratio_diff, loss_mask, expand=True) elif level == "geometric": # Geometric mean of ratios (biased but low variance) - log_ratio_for_metrics = masked_mean(raw_log_ratio, loss_mask, expand=True) + log_ratio_for_metrics = masked_mean(raw_log_ratio_diff, loss_mask, expand=True) else: raise ValueError(f"Invalid importance sampling level: {level}") log_ratio_safe = torch.clamp(log_ratio_for_metrics, min=-SAFETY_BOUND, max=SAFETY_BOUND) weights = torch.exp(log_ratio_safe) - metrics_append(metrics, "ratio_mean_before_tis", weights) + metrics_append(metrics, "mean_is_weight_before_clip", weights) # mask out catastrophic tokens if args.train_infer_is_veto_threshold is not None: - veto_mask = calculate_veto_mask(raw_log_ratio, loss_mask, args.train_infer_is_veto_threshold, metrics) + veto_mask = calculate_veto_mask(raw_log_ratio_diff, loss_mask, args.train_infer_is_veto_threshold, metrics) # mode: how to handle the importance sampling weights exceeding the thresholds. 
if args.train_infer_is_mode == "truncate": From fa4606dbfe8f53c8e6a21515d6b7417013e340c6 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Fri, 17 Oct 2025 04:05:52 +0000 Subject: [PATCH 54/58] revert changes to use_tis --- docs/en/get_started/quick_start.md | 2 +- docs/en/get_started/usage.md | 2 +- docs/zh/get_started/quick_start.md | 2 +- docs/zh/get_started/usage.md | 2 +- .../fully_async/run-qwen3-4b-fully_async.sh | 2 +- scripts/run-glm4.5-355B-A32B.sh | 2 +- scripts/run-qwen3-30B-A3B-test.sh | 166 ------------------ slime/backends/fsdp_utils/actor.py | 6 +- tests/test_quick_start_glm4_9B.py | 2 +- tests/test_qwen3_30B_A3B.py | 2 +- 10 files changed, 10 insertions(+), 178 deletions(-) delete mode 100644 scripts/run-qwen3-30B-A3B-test.sh diff --git a/docs/en/get_started/quick_start.md b/docs/en/get_started/quick_start.md index c0d91ed7b7..8142699443 100644 --- a/docs/en/get_started/quick_start.md +++ b/docs/en/get_started/quick_start.md @@ -275,7 +275,7 @@ GRPO_ARGS=( - `--advantage-estimator`: In addition to [GRPO](https://arxiv.org/abs/2402.03300), slime also supports several other training algorithms, such as [GSPO](https://arxiv.org/abs/2507.18071), [Reinforce++](https://arxiv.org/abs/2501.03262) and [Reinforce++ Baseline](https://arxiv.org/abs/2501.03262), and [PPO](https://arxiv.org/abs/1707.06347). - `--calculate-per-token-loss`: By default, slime calculates the loss on a per-sample basis, i.e., `mean(sum(sample_i) / len(sample_i))`. To calculate the loss on a per-token basis, i.e., `sum(sum(sample_i)) / sum(len(sample_i))`, you can enable this flag. -- `--use-train-infer-is`: Enable this setting to use TIS (Truncated Importance Sampling), which is introduced by this [blog](https://fengyao.notion.site/off-policy-rl). +- `--use-tis`: Enable this setting to use TIS (Truncated Importance Sampling), which is introduced by this [blog](https://fengyao.notion.site/off-policy-rl). 
### OPTIMIZER_ARGS: Optimizer Parameters diff --git a/docs/en/get_started/usage.md b/docs/en/get_started/usage.md index 827e5fd62f..26310cb9f0 100644 --- a/docs/en/get_started/usage.md +++ b/docs/en/get_started/usage.md @@ -177,7 +177,7 @@ Additionally, we provide a `metadata_key`, which defaults to `"metadata"`. When - `reinforce_plus_plus` and `reinforce_plus_plus_baseline` ([https://arxiv.org/abs/2501.03262](https://arxiv.org/abs/2501.03262)) - `ppo` ([https://arxiv.org/abs/1707.06347](https://arxiv.org/abs/1707.06347)) - `--calculate-per-token-loss`: By default, Slime calculates loss on a per-sample basis, i.e., `mean(sum(sample_i) / len(sample_i))`. Enable this flag to calculate loss on a per-token basis, i.e., `sum(sum(sample_i)) / sum(len(sample_i))`. -- `--use-train-infer-is`: Enable this setting to use TIS (Truncated Importance Sampling) (https://fengyao.notion.site/off-policy-rl). +- `--use-tis`: Enable this setting to use TIS (Truncated Importance Sampling) (https://fengyao.notion.site/off-policy-rl). 
## Custom Rollout Function diff --git a/docs/zh/get_started/quick_start.md b/docs/zh/get_started/quick_start.md index e8db493495..a0ad9730e7 100644 --- a/docs/zh/get_started/quick_start.md +++ b/docs/zh/get_started/quick_start.md @@ -274,7 +274,7 @@ GRPO_ARGS=( - `--advantage-estimator`: 除去 [GRPO](https://arxiv.org/abs/2402.03300),slime 还支持丰富的其他训练算法,例如 [GSPO](https://arxiv.org/abs/2507.18071)、[Reinforce++](https://arxiv.org/abs/2501.03262) 与 [Reinforce++ Baseline](https://arxiv.org/abs/2501.03262)、以及 [PPO](https://arxiv.org/abs/1707.06347); - `--calculate-per-token-loss`:slime 中默认的方案是 per sample loss,即 `mean(sum(sample_i) / len(sample_i))`,如果需要计算 per token loss,即 `sum(sum(sample_i)) / sum(len(sample_i))`,可以开启 `--calculate-per-token-loss`; -- `--use-train-infer-is`:如果需要开启 TIS (Truncated Importance Sampling),可以开启这一设置。TIS 由此[博客](https://fengyao.notion.site/off-policy-rl)介绍。 +- `--use-tis`:如果需要开启 TIS (Truncated Importance Sampling),可以开启这一设置。TIS 由此[博客](https://fengyao.notion.site/off-policy-rl)介绍。 ### OPTIMIZER_ARGS: 优化器参数 diff --git a/docs/zh/get_started/usage.md b/docs/zh/get_started/usage.md index 47e1b09b85..332be9d417 100644 --- a/docs/zh/get_started/usage.md +++ b/docs/zh/get_started/usage.md @@ -181,7 +181,7 @@ sglang 的加载非常简单,只需要: - `reinforce_plus_plus` 与 `reinforce_plus_plus_baseline`(https://arxiv.org/abs/2501.03262); - `ppo`(https://arxiv.org/abs/1707.06347)。 - `--calculate-per-token-loss`:slime 中默认的方案是 per sample loss,即 `mean(sum(sample_i) / len(sample_i))`,如果需要计算 per token loss,即 `sum(sum(sample_i)) / sum(len(sample_i))`,可以开启 `--calculate-per-token-loss`; -- `--use-train-infer-is`:如果需要开启 tis(https://fengyao.notion.site/off-policy-rl),可以开启这一设置。 +- `--use-tis`:如果需要开启 tis(https://fengyao.notion.site/off-policy-rl),可以开启这一设置。 ## 自定义 rollout 函数 diff --git a/examples/fully_async/run-qwen3-4b-fully_async.sh b/examples/fully_async/run-qwen3-4b-fully_async.sh index 3665c6fb62..2ef6162ab0 100644 --- a/examples/fully_async/run-qwen3-4b-fully_async.sh +++ 
b/examples/fully_async/run-qwen3-4b-fully_async.sh @@ -79,7 +79,7 @@ GRPO_ARGS=( --eps-clip 0.2 --eps-clip-high 0.28 - --use-train-infer-is + --use-tis ) OPTIMIZER_ARGS=( diff --git a/scripts/run-glm4.5-355B-A32B.sh b/scripts/run-glm4.5-355B-A32B.sh index 3d39ad334c..4a9792df0c 100644 --- a/scripts/run-glm4.5-355B-A32B.sh +++ b/scripts/run-glm4.5-355B-A32B.sh @@ -86,7 +86,7 @@ GRPO_ARGS=( --eps-clip 1e-4 --eps-clip-high 2e-4 - --use-train-infer-is + --use-tis ) OPTIMIZER_ARGS=( diff --git a/scripts/run-qwen3-30B-A3B-test.sh b/scripts/run-qwen3-30B-A3B-test.sh deleted file mode 100644 index 2bfbfc6395..0000000000 --- a/scripts/run-qwen3-30B-A3B-test.sh +++ /dev/null @@ -1,166 +0,0 @@ -#!/bin/bash - -# for rerun the task -pkill -9 sglang -sleep 3 -ray stop --force -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python -pkill -9 redis - -set -ex - -# will prevent ray from buffering stdout/stderr -export PYTHONBUFFERED=16 - -NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) -if [ "$NVLINK_COUNT" -gt 0 ]; then - HAS_NVLINK=1 -else - HAS_NVLINK=0 -fi -echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" - -SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" -source "${SCRIPT_DIR}/models/qwen3-30B-A3B.sh" - -CKPT_ARGS=( - --hf-checkpoint /root/Qwen3-30B-A3B - #--hf-checkpoint /root/Qwen3-30B-A3B-FP8 - --ref-load /root/Qwen3-30B-A3B_torch_dist - # --load /root/Qwen3-30B-A3B_slime/ - # --save /root/Qwen3-30B-A3B_slime/ - # --save-interval 20 -) - -ROLLOUT_ARGS=( - --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl - --input-key prompt - --label-key label - --apply-chat-template - --rollout-shuffle - --rm-type deepscaler - --num-rollout 3000 - --rollout-batch-size 8 - --n-samples-per-prompt 4 - --rollout-max-response-len 8192 - --rollout-temperature 0.8 - - --global-batch-size 32 - --balance-data -) - -EVAL_ARGS=( - # --eval-interval 20 - --eval-prompt-data aime 
/root/aime-2024/aime-2024.jsonl - --n-samples-per-eval-prompt 16 - --eval-max-response-len 16384 - --eval-top-p 0.7 -) - -PERF_ARGS=( - --tensor-model-parallel-size 2 - --sequence-parallel - --pipeline-model-parallel-size 1 - --context-parallel-size 2 - --expert-model-parallel-size 4 - --expert-tensor-parallel-size 1 - - --recompute-granularity full - --recompute-method uniform - --recompute-num-layers 1 - - # --micro-batch-size 1 - --use-dynamic-batch-size - --max-tokens-per-gpu 20480 -) - -GRPO_ARGS=( - --advantage-estimator grpo - --use-kl-loss - --kl-loss-coef 0.00 - --kl-loss-type low_var_kl - --entropy-coef 0.00 - --eps-clip 0.2 - --eps-clip-high 0.28 -) - -IS_ARGS=( - --use-train-infer-is - --train-infer-is-level geometric - --train-infer-is-mode mask - --train-infer-is-lower-bound 0.5 - --train-infer-is-upper-bound 2.0 - --train-infer-is-veto-threshold 1e-3 -) - -OPTIMIZER_ARGS=( - --optimizer adam - --lr 1e-6 - --lr-decay-style constant - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.98 - - --optimizer-cpu-offload - --overlap-cpu-optimizer-d2h-h2d - --use-precision-aware-optimizer -) - -WANDB_ARGS=( - --use-wandb - --wandb-project slime-dev - --wandb-group qwen3-30B-A3B-TIS - --wandb-run-id qwen3-30B-A3B-TIS-sequence - --wandb-key ${WANDB_KEY} -) - -SGLANG_ARGS=( - --rollout-num-gpus-per-engine 4 - --sglang-mem-fraction-static 0.7 - --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256) -) - -MISC_ARGS=( - # default dropout in megatron is 0.1 - --attention-dropout 0.0 - --hidden-dropout 0.0 - # should be good for model performance - --accumulate-allreduce-grads-in-fp32 - --attention-softmax-in-fp32 - # need to comment this when using model with MLA - --attention-backend flash -) - -# launch the master node of ray in container -export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 4 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 - -# Build the runtime environment JSON with proper 
variable substitution -RUNTIME_ENV_JSON="{ - \"env_vars\": { - \"PYTHONPATH\": \"/root/Megatron-LM/\", - \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", - \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" - } -}" - -ray job submit --address="http://127.0.0.1:8265" \ - --runtime-env-json="${RUNTIME_ENV_JSON}" \ - -- python3 train.py \ - --actor-num-nodes 1 \ - --actor-num-gpus-per-node 4 \ - --colocate \ - ${MODEL_ARGS[@]} \ - ${CKPT_ARGS[@]} \ - ${ROLLOUT_ARGS[@]} \ - ${OPTIMIZER_ARGS[@]} \ - ${GRPO_ARGS[@]} \ - ${WANDB_ARGS[@]} \ - ${PERF_ARGS[@]} \ - ${EVAL_ARGS[@]} \ - ${SGLANG_ARGS[@]} \ - ${MISC_ARGS[@]} diff --git a/slime/backends/fsdp_utils/actor.py b/slime/backends/fsdp_utils/actor.py index 3f5f7997fa..ec65b77954 100644 --- a/slime/backends/fsdp_utils/actor.py +++ b/slime/backends/fsdp_utils/actor.py @@ -401,7 +401,7 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None: pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, self.args.eps_clip, self.args.eps_clip_high) # Apply TIS before sample mean calculation - if self.args.use_train_infer_is: + if self.args.use_tis: # Initialize TIS variables tis = None tis_clipfrac = None @@ -420,9 +420,7 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None: tis = torch.exp(old_log_probs - rollout_log_probs) ois = (-ppo_kl).exp() tis_clip = torch.clamp( - tis, - min=getattr(self.args, "train_infer_is_lower_bound", 0.1), - max=getattr(self.args, "train_infer_is_upper_bound", 2.0), + tis, min=getattr(self.args, "tis_clip_low", 0.1), max=getattr(self.args, "tis_clip", 2.0) ) tis_clipfrac = tis_clip != tis diff --git a/tests/test_quick_start_glm4_9B.py b/tests/test_quick_start_glm4_9B.py index 63b648dfb9..c491be37c4 100644 --- a/tests/test_quick_start_glm4_9B.py +++ b/tests/test_quick_start_glm4_9B.py @@ -67,7 +67,7 @@ def execute(): "--entropy-coef 0.00 " "--eps-clip 0.2 " "--eps-clip-high 0.28 " - "--use-train-infer-is " + "--use-tis " "--calculate-per-token-loss " ) diff --git 
a/tests/test_qwen3_30B_A3B.py b/tests/test_qwen3_30B_A3B.py index e7c4ff07d8..6db31bbb0e 100644 --- a/tests/test_qwen3_30B_A3B.py +++ b/tests/test_qwen3_30B_A3B.py @@ -69,7 +69,7 @@ def execute(): "--kl-coef 0.00 " "--entropy-coef 0.00 " "--eps-clip 4e-4 " - "--use-train-infer-is " + "--use-tis " "--use-routing-replay " ) From a66753f88b76ab09101b0ee5736da9736204cd83 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Fri, 17 Oct 2025 04:07:06 +0000 Subject: [PATCH 55/58] move to examples and use yaml for custom args parsing --- .../train_infer_mismatch_helper/mis.py | 59 ++++--- examples/train_infer_mismatch_helper/mis.yaml | 26 +++ .../run-qwen3-4b-mis.sh | 157 ++++++++++++++++++ slime/backends/megatron_utils/loss.py | 38 +++-- slime/utils/arguments.py | 91 +++------- slime/utils/tis.py | 24 +++ 6 files changed, 289 insertions(+), 106 deletions(-) rename slime/utils/train_infer_is.py => examples/train_infer_mismatch_helper/mis.py (89%) create mode 100644 examples/train_infer_mismatch_helper/mis.yaml create mode 100644 examples/train_infer_mismatch_helper/run-qwen3-4b-mis.sh create mode 100644 slime/utils/tis.py diff --git a/slime/utils/train_infer_is.py b/examples/train_infer_mismatch_helper/mis.py similarity index 89% rename from slime/utils/train_infer_is.py rename to examples/train_infer_mismatch_helper/mis.py index d455a0b1b4..bd3703bc2d 100644 --- a/slime/utils/train_infer_is.py +++ b/examples/train_infer_mismatch_helper/mis.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple import torch @@ -109,7 +109,7 @@ def mask( return weights * mask * loss_mask -def compute_train_infer_is_weights( +def compute_mis_weights( args, *, train_log_probs: list[torch.Tensor], @@ -130,9 +130,12 @@ def compute_train_infer_is_weights( metrics: The metrics for the importance sampling weights, a dict of list[torch.Tensor]. 1D tensor each. 
""" - level: str = args.train_infer_is_level + level: str = args.mis_level metrics: Dict[str, list[torch.Tensor]] = {} + if args.mis_lower_bound is None: + return 1.0 / args.mis_upper_bound + # Validate input lists have same length and each sequence has matching shapes assert ( len(train_log_probs) == len(rollout_log_probs) == len(loss_masks) @@ -151,6 +154,15 @@ def compute_train_infer_is_weights( loss_mask = loss_mask.float() add_ppl_metrics(train_log_prob, rollout_log_prob, loss_mask, metrics) raw_log_ratio_diff = train_log_prob - rollout_log_prob + print("train_log_prob: ", train_log_prob) + print("max_train_log_prob: ", train_log_prob.max()) + print("min_train_log_prob: ", train_log_prob.min()) + print("rollout_log_prob: ", rollout_log_prob) + print("max_rollout_log_prob: ", rollout_log_prob.max()) + print("min_rollout_log_prob: ", rollout_log_prob.min()) + print("raw_log_ratio_diff: ", raw_log_ratio_diff) + print("max_raw_log_ratio_diff: ", raw_log_ratio_diff.max()) + print("min_raw_log_ratio_diff: ", raw_log_ratio_diff.min()) # level: The aggregation level for the importance sampling weights. if level == "token": @@ -170,39 +182,39 @@ def compute_train_infer_is_weights( metrics_append(metrics, "mean_is_weight_before_clip", weights) # mask out catastrophic tokens - if args.train_infer_is_veto_threshold is not None: - veto_mask = calculate_veto_mask(raw_log_ratio_diff, loss_mask, args.train_infer_is_veto_threshold, metrics) + if args.mis_veto_threshold is not None: + veto_mask = calculate_veto_mask(raw_log_ratio_diff, loss_mask, args.mis_veto_threshold, metrics) # mode: how to handle the importance sampling weights exceeding the thresholds. 
- if args.train_infer_is_mode == "truncate": + if args.mis_mode == "truncate": # Cap the importance sampling weights at the upper threshold # https://fengyao.notion.site/off-policy-rl#279721e3f6c48092bbe2fcfe0e9c6b33 - weights = truncate(weights, loss_mask, metrics, args.train_infer_is_upper_bound) - elif args.train_infer_is_mode == "mask": + weights = truncate(weights, loss_mask, metrics, args.mis_upper_bound) + elif args.mis_mode == "mask": # Zero the importance sampling weights outside the [lower, upper] range. # https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda weights = mask( weights, loss_mask, metrics, - args.train_infer_is_lower_bound, - args.train_infer_is_upper_bound, + args.mis_lower_bound, + args.mis_upper_bound, ) - elif args.train_infer_is_mode == "clip": + elif args.mis_mode == "clip": # Clip the importance sampling weights to the [lower, upper] range. # Original behavior in slime. 
weights = clip( weights, loss_mask, metrics, - args.train_infer_is_lower_bound, - args.train_infer_is_upper_bound, + args.mis_lower_bound, + args.mis_upper_bound, ) else: - raise ValueError(f"Unsupported train_infer_is_mode: {args.train_infer_is_mode}") + raise ValueError(f"Unsupported mis_mode: {args.mis_mode}") - metrics_append(metrics, "ratio_mean_after_tis", weights) - if args.train_infer_is_veto_threshold is not None: + metrics_append(metrics, "ratio_mean_after_mis", weights) + if args.mis_veto_threshold is not None: weights = weights * veto_mask metrics_append(metrics, "ratio_mean_after_veto_mask", weights) @@ -212,7 +224,7 @@ def compute_train_infer_is_weights( return all_weights, metrics -def compute_train_infer_is_weights_with_cp( +def compute_mis_weights_with_cp( args, *, train_log_probs: list[torch.Tensor], @@ -220,9 +232,10 @@ def compute_train_infer_is_weights_with_cp( loss_masks: list[torch.Tensor], total_lengths: list[int], response_lengths: list[int], + **kwargs: Any, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ - Compute the truncated importance sampling (TIS) weights and metrics with context parallel. + Compute the importance sampling (IS) weights and metrics with context parallel. Args: train_log_probs: List of log probs from training backend on this cp rank. 1D tensor each. Lengths can be different. rollout_log_probs: List of log probs from inference backend on this cp rank. 1D tensor each. 
@@ -245,7 +258,7 @@ def compute_train_infer_is_weights_with_cp( ] # Main logic for is - is_weights, is_metrics = compute_train_infer_is_weights( + is_weights, is_metrics = compute_mis_weights( args=args, train_log_probs=full_old_log_probs, rollout_log_probs=full_rollout_log_probs, @@ -263,12 +276,14 @@ def slice_cp_and_concat( ] return torch.cat(values, dim=0) + result_metrics = {} is_weights = slice_cp_and_concat(is_weights, total_lengths, response_lengths) for key, values in is_metrics.items(): + key_name = f"mis_{key}" values = slice_cp_and_concat(values, total_lengths, response_lengths) - is_metrics[key] = values + result_metrics[key_name] = values - return is_weights, is_metrics + return is_weights, result_metrics def add_ppl_metrics( @@ -314,8 +329,6 @@ def add_ppl_metrics( log_ppl_diff = mean_log_prob_rollout - mean_log_prob_training metrics_append(metrics, "log_ppl_diff", log_ppl_diff) metrics_append(metrics, "log_ppl_abs_diff", log_ppl_diff.abs()) - metrics_append(metrics, "log_ppl_diff_max", log_ppl_diff) - metrics_append(metrics, "log_ppl_diff_min", log_ppl_diff) # 3d. PPL ratio (how much higher is training PPL vs rollout PPL) # For numerical stability, compute in log space using log_ppl_diff diff --git a/examples/train_infer_mismatch_helper/mis.yaml b/examples/train_infer_mismatch_helper/mis.yaml new file mode 100644 index 0000000000..3a7fbbafc9 --- /dev/null +++ b/examples/train_infer_mismatch_helper/mis.yaml @@ -0,0 +1,26 @@ +# Enable importance sampling, details refer to the comments of compute_mis_weights in mis.py +use_mis: false + +# Aggregation level for importance sampling weights: +# token: per-token +# sequence: product over tokens +# geometric: geometric mean +mis_level: "token" + +# Handling mode for IS weights: +# truncate: cap to upper bound, TIS +# mask: zero outside [lower, upper], MIS +# clip: clip to [lower, upper], CIS +mis_mode: "truncate" + +# For mask or clip mode, the lower bound of the IS weights. 
+# For truncate mode, it will not be used. +# If not set, it will be set to 1.0 / mis_upper_bound +mis_lower_bound: 0.5 + +# For truncate, mask, or clip mode, the upper bound of the IS weights +mis_upper_bound: 2.0 + +# Per-token veto threshold. If any token ratio < this, zero the entire sequence weight, the sequences won't have gradient +# Note: float number must be written with dot e.g. 1.0e-4, not 1e-4 +mis_veto_threshold: 1.0e-4 diff --git a/examples/train_infer_mismatch_helper/run-qwen3-4b-mis.sh b/examples/train_infer_mismatch_helper/run-qwen3-4b-mis.sh new file mode 100644 index 0000000000..f19f5e87fe --- /dev/null +++ b/examples/train_infer_mismatch_helper/run-qwen3-4b-mis.sh @@ -0,0 +1,157 @@ +#!/bin/bash + +# for rerun the task +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +# will prevent ray from buffering stdout/stderr +export PYTHONBUFFERED=16 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +source "/root/slime/scripts/models/qwen3-4B.sh" + +CKPT_ARGS=( + --hf-checkpoint /root/Qwen3-4B + #--hf-checkpoint /root/Qwen3-4B-FP8 + --ref-load /root/Qwen3-4B_torch_dist + # --load /root/Qwen3-4B_slime/ + --save /root/Qwen3-4B_slime/ + --save-interval 200 +) + +ROLLOUT_ARGS=( + --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type deepscaler + --num-rollout 3000 + --rollout-batch-size 32 + --n-samples-per-prompt 8 + --rollout-max-response-len 8192 + --rollout-temperature 0.8 + + --global-batch-size 256 + --balance-data +) + +EVAL_ARGS=( + # --eval-interval 20 + --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl + 
--n-samples-per-eval-prompt 1 + --eval-max-response-len 16384 + --eval-top-p 0.7 +) + +PERF_ARGS=( + --tensor-model-parallel-size 2 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 2 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + # --micro-batch-size 1 + --use-dynamic-batch-size + --max-tokens-per-gpu 9216 +) + +GRPO_ARGS=( + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 + --use-tis +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=( + --use-wandb + --wandb-project slime-mis + --wandb-group qwen3-4B-mis + --wandb-key ${WANDB_KEY} +) + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 1 + --sglang-mem-fraction-static 0.7 +) + +MISC_ARGS=( + # default dropout in megatron is 0.1 + --attention-dropout 0.0 + --hidden-dropout 0.0 + # should be good for model performance + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + # need to comment this when using model with MLA + --attention-backend flash +) + +CUSTOM_ARGS=( + --custom-config-path examples/train_infer_mismatch_helper/mis.yaml + --custom-tis-function-path examples.train_infer_mismatch_helper.mis.compute_mis_weights_with_cp +) + +# launch the master node of ray in container +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 + +# Build the runtime environment JSON with proper variable substitution +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", + \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" + } +}" + +ray job submit 
--address="http://127.0.0.1:8265" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 8 \ + --colocate \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${MISC_ARGS[@]} \ + ${CUSTOM_ARGS[@]} \ No newline at end of file diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 33f985a868..62d851efd6 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -16,7 +16,7 @@ get_reinforce_plus_plus_baseline_advantages, get_reinforce_plus_plus_returns, ) -from slime.utils.train_infer_is import compute_train_infer_is_weights_with_cp +from slime.utils.tis import compute_tis_weights from slime.utils.types import RolloutBatch from .cp_utils import all_gather_with_cp, get_logits_and_tokens_offset_with_cp, get_sum_of_sample_mean @@ -419,21 +419,26 @@ def policy_loss_function( pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, args.eps_clip, args.eps_clip_high) - # Apply TIS off-policy correction using importance sampling if enabled - if args.use_train_infer_is: + # Apply off-policy correction using importance sampling if enabled + if args.use_tis: assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS" + ois = (-ppo_kl).exp() - is_weights, is_metrics = compute_train_infer_is_weights_with_cp( - args=args, - train_log_probs=batch["log_probs"], - rollout_log_probs=batch["rollout_log_probs"], - loss_masks=batch["loss_masks"], - total_lengths=total_lengths, - response_lengths=response_lengths, - ) + tis_kwargs = { + "args": args, + "train_log_probs": batch["log_probs"], + "rollout_log_probs": batch["rollout_log_probs"], + "loss_masks": batch["loss_masks"], + "total_lengths": total_lengths, + "response_lengths": response_lengths, + } + if 
args.custom_tis_function_path is not None: + tis_func = load_function(args.custom_tis_function_path) + tis_weights, tis_metrics = tis_func(**tis_kwargs) + else: + tis_weights, tis_metrics = compute_tis_weights(**tis_kwargs) - ois = (-ppo_kl).exp() - pg_loss = pg_loss * is_weights + pg_loss = pg_loss * tis_weights pg_loss = sum_of_sample_mean(pg_loss) pg_clipfrac = sum_of_sample_mean(pg_clipfrac) @@ -473,11 +478,10 @@ def policy_loss_function( if args.use_kl_loss: reported_loss["kl_loss"] = kl_loss.clone().detach() - if args.use_train_infer_is: - # Backward compatible basic logs + if args.use_tis: reported_loss["ois"] = sum_of_sample_mean(ois).clone().detach() - for metric_key, metric_value in is_metrics.items(): - key_name = f"train_infer_{metric_key}" + for metric_key, metric_value in tis_metrics.items(): + key_name = f"{metric_key}" reported_loss[key_name] = sum_of_sample_mean(metric_value) return loss, reported_loss diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 8799c7ef8f..d6341d3821 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -26,67 +26,6 @@ def reset_arg(parser, name, **kwargs): parser.add_argument(name, **kwargs) -def add_is_arguments(parser: argparse.ArgumentParser): - # Off-Policy Correction arguments for importance sampling - # training/inference importance sampling - parser.add_argument( - "--use-train-infer-is", - action="store_true", - default=False, - help=( - "Enable importance sampling, details refer to the comments of compute_train_infer_is_weights " - "in train_infer_is.py" - ), - ) - parser.add_argument( - "--train-infer-is-level", - type=str, - choices=["token", "sequence", "geometric"], - default="token", - help=( - "Aggregation level for importance sampling weights: token (per-token), " - "sequence (product over tokens), geometric (geometric mean)." 
- ), - ) - parser.add_argument( - "--train-infer-is-mode", - type=str, - choices=["truncate", "mask", "clip"], - default="truncate", - help=( - "Handling mode for IS weights:" - "truncate (cap to upper bound, TIS)," - "mask (zero outside [lower, upper], MIS)," - "clip (clip to [lower, upper], CIS)." - ), - ) - parser.add_argument( - "--train-infer-is-lower-bound", - type=float, - default=None, - help=( - "For mask or clip mode, the lower bound of the IS weights. For truncate mode, it will not be used. " - "If not set, it will be set to 1.0 / train_infer_is_upper_bound." - ), - ) - parser.add_argument( - "--train-infer-is-upper-bound", - type=float, - default=2.0, - help=("For truncate, mask, or clip mode, the upper bound of the IS weights."), - ) - parser.add_argument( - "--train-infer-is-veto-threshold", - type=float, - default=None, - help=( - "Per-token veto threshold. If any token ratio < this, zero the entire sequence weight, the sequences won't have gradient." - ), - ) - - return parser - - def get_slime_extra_args_provider(add_custom_arguments=None): def add_slime_arguments(parser): # Ray @@ -736,6 +675,31 @@ def add_algo_arguments(parser): "This is useful for doing special loss mask." 
), ) + # Off-Policy Correction using Importance Sampling: https://fengyao.notion.site/off-policy-rl + parser.add_argument( + "--use-tis", + action="store_true", + default=False, + help="Enable TIS from https://fengyao.notion.site/off-policy-rl for off-policy importance sampling.", + ) + parser.add_argument( + "--tis-clip", + type=float, + default=2.0, + help="Clipping threshold C for importance sampling ratios to control variance.", + ) + parser.add_argument( + "--tis-clip-low", + type=float, + default=0, + help="Lower bound clipping threshold C for importance sampling ratios to control variance.", + ) + parser.add_argument( + "--custom-tis-function-path", + type=str, + default=None, + help="Path to the custom TIS function.", + ) parser.add_argument( "--use-routing-replay", @@ -1047,7 +1011,6 @@ def add_sglang_tp_size(): parser = add_reward_model_arguments(parser) parser = add_rollout_buffer_arguments(parser) parser = add_ci_arguments(parser) - parser = add_is_arguments(parser) parser.set_defaults(sglang_tensor_parallel_size=add_sglang_tp_size()) # For megatron @@ -1188,10 +1151,6 @@ def slime_validate_args(args): if args.eps_clip_high is None: args.eps_clip_high = args.eps_clip - if args.use_train_infer_is: - if args.train_infer_is_lower_bound is None: - args.train_infer_is_lower_bound = 1.0 / args.train_infer_is_upper_bound - if args.eval_reward_key is None: args.eval_reward_key = args.reward_key diff --git a/slime/utils/tis.py b/slime/utils/tis.py new file mode 100644 index 0000000000..11f834ac1a --- /dev/null +++ b/slime/utils/tis.py @@ -0,0 +1,24 @@ +from typing import Any, Dict, Tuple + +import torch + + +def compute_tis_weights( + args, + *, + train_log_probs: list[torch.Tensor], + rollout_log_probs: list[torch.Tensor], + **kwargs: Any, +) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + rollout_log_probs = torch.cat(rollout_log_probs, dim=0) + old_log_probs = torch.cat(train_log_probs, dim=0) + tis = torch.exp(old_log_probs - rollout_log_probs) + tis_clip 
= torch.clamp(tis, min=args.tis_clip_low, max=args.tis_clip) + tis_clipfrac = (tis_clip != tis).float() + + metrics = { + "tis": tis.clone().detach(), + "tis_clipfrac": tis_clipfrac.clone().detach(), + } + + return tis_clip, metrics From fffeab9880797ef5f71cb01e0b9a0601010d92fd Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Fri, 17 Oct 2025 04:31:00 +0000 Subject: [PATCH 56/58] fix small bug --- examples/train_infer_mismatch_helper/mis.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/examples/train_infer_mismatch_helper/mis.py b/examples/train_infer_mismatch_helper/mis.py index bd3703bc2d..6514e8a09c 100644 --- a/examples/train_infer_mismatch_helper/mis.py +++ b/examples/train_infer_mismatch_helper/mis.py @@ -154,15 +154,6 @@ def compute_mis_weights( loss_mask = loss_mask.float() add_ppl_metrics(train_log_prob, rollout_log_prob, loss_mask, metrics) raw_log_ratio_diff = train_log_prob - rollout_log_prob - print("train_log_prob: ", train_log_prob) - print("max_train_log_prob: ", train_log_prob.max()) - print("min_train_log_prob: ", train_log_prob.min()) - print("rollout_log_prob: ", rollout_log_prob) - print("max_rollout_log_prob: ", rollout_log_prob.max()) - print("min_rollout_log_prob: ", rollout_log_prob.min()) - print("raw_log_ratio_diff: ", raw_log_ratio_diff) - print("max_raw_log_ratio_diff: ", raw_log_ratio_diff.max()) - print("min_raw_log_ratio_diff: ", raw_log_ratio_diff.min()) # level: The aggregation level for the importance sampling weights. 
if level == "token": From 9e4bf7c9e2f2fbc69f07a07bd49214834fe6cbc3 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Fri, 17 Oct 2025 06:50:52 +0000 Subject: [PATCH 57/58] remove tis file --- slime/backends/megatron_utils/loss.py | 31 +++++++++++++++++---------- slime/utils/tis.py | 24 --------------------- 2 files changed, 20 insertions(+), 35 deletions(-) delete mode 100644 slime/utils/tis.py diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 62d851efd6..32765c9c81 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -16,7 +16,6 @@ get_reinforce_plus_plus_baseline_advantages, get_reinforce_plus_plus_returns, ) -from slime.utils.tis import compute_tis_weights from slime.utils.types import RolloutBatch from .cp_utils import all_gather_with_cp, get_logits_and_tokens_offset_with_cp, get_sum_of_sample_mean @@ -422,21 +421,30 @@ def policy_loss_function( # Apply off-policy correction using importance sampling if enabled if args.use_tis: assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS" - ois = (-ppo_kl).exp() - tis_kwargs = { - "args": args, - "train_log_probs": batch["log_probs"], - "rollout_log_probs": batch["rollout_log_probs"], - "loss_masks": batch["loss_masks"], - "total_lengths": total_lengths, - "response_lengths": response_lengths, - } + ois = (-ppo_kl).exp() if args.custom_tis_function_path is not None: + tis_kwargs = { + "args": args, + "train_log_probs": batch["log_probs"], + "rollout_log_probs": batch["rollout_log_probs"], + "loss_masks": batch["loss_masks"], + "total_lengths": total_lengths, + "response_lengths": response_lengths, + } tis_func = load_function(args.custom_tis_function_path) tis_weights, tis_metrics = tis_func(**tis_kwargs) else: - tis_weights, tis_metrics = compute_tis_weights(**tis_kwargs) + rollout_log_probs = torch.cat(batch["rollout_log_probs"], dim=0) + old_log_probs = torch.cat(batch["log_probs"], dim=0) + tis = 
torch.exp(old_log_probs - rollout_log_probs) + tis_weights = torch.clamp(tis, min=args.tis_clip_low, max=args.tis_clip) + tis_clipfrac = (tis_weights != tis).float() + + tis_metrics = { + "tis": tis.clone().detach(), + "tis_clipfrac": tis_clipfrac.clone().detach(), + } pg_loss = pg_loss * tis_weights @@ -480,6 +488,7 @@ def policy_loss_function( if args.use_tis: reported_loss["ois"] = sum_of_sample_mean(ois).clone().detach() + # Assume all metrics are already cloned and detached for metric_key, metric_value in tis_metrics.items(): key_name = f"{metric_key}" reported_loss[key_name] = sum_of_sample_mean(metric_value) diff --git a/slime/utils/tis.py b/slime/utils/tis.py deleted file mode 100644 index 11f834ac1a..0000000000 --- a/slime/utils/tis.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import Any, Dict, Tuple - -import torch - - -def compute_tis_weights( - args, - *, - train_log_probs: list[torch.Tensor], - rollout_log_probs: list[torch.Tensor], - **kwargs: Any, -) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: - rollout_log_probs = torch.cat(rollout_log_probs, dim=0) - old_log_probs = torch.cat(train_log_probs, dim=0) - tis = torch.exp(old_log_probs - rollout_log_probs) - tis_clip = torch.clamp(tis, min=args.tis_clip_low, max=args.tis_clip) - tis_clipfrac = (tis_clip != tis).float() - - metrics = { - "tis": tis.clone().detach(), - "tis_clipfrac": tis_clipfrac.clone().detach(), - } - - return tis_clip, metrics From eb7711c8767c0792d01bc29f86307ebca0184671 Mon Sep 17 00:00:00 2001 From: Jiajun Li Date: Sun, 19 Oct 2025 17:20:07 +0000 Subject: [PATCH 58/58] give vanilla tis func --- slime/backends/megatron_utils/loss.py | 49 ++++++++++++++++----------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 32765c9c81..83d135ef44 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -1,6 +1,6 @@ from argparse import Namespace from 
collections.abc import Callable, Iterator -from typing import Union +from typing import Any, Dict, Tuple, Union import torch from megatron.core import mpu @@ -420,31 +420,42 @@ def policy_loss_function( # Apply off-policy correction using importance sampling if enabled if args.use_tis: - assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS" - ois = (-ppo_kl).exp() - if args.custom_tis_function_path is not None: - tis_kwargs = { - "args": args, - "train_log_probs": batch["log_probs"], - "rollout_log_probs": batch["rollout_log_probs"], - "loss_masks": batch["loss_masks"], - "total_lengths": total_lengths, - "response_lengths": response_lengths, - } - tis_func = load_function(args.custom_tis_function_path) - tis_weights, tis_metrics = tis_func(**tis_kwargs) - else: - rollout_log_probs = torch.cat(batch["rollout_log_probs"], dim=0) - old_log_probs = torch.cat(batch["log_probs"], dim=0) + def vanilla_tis_function( + args, + *, + train_log_probs: list[torch.Tensor], + rollout_log_probs: list[torch.Tensor], + **kwargs: Any, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + rollout_log_probs = torch.cat(rollout_log_probs, dim=0) + old_log_probs = torch.cat(train_log_probs, dim=0) tis = torch.exp(old_log_probs - rollout_log_probs) tis_weights = torch.clamp(tis, min=args.tis_clip_low, max=args.tis_clip) tis_clipfrac = (tis_weights != tis).float() - - tis_metrics = { + metrics = { "tis": tis.clone().detach(), "tis_clipfrac": tis_clipfrac.clone().detach(), } + return tis_weights, metrics + + assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS" + + ois = (-ppo_kl).exp() + tis_kwargs = { + "args": args, + "train_log_probs": batch["log_probs"], + "rollout_log_probs": batch["rollout_log_probs"], + "loss_masks": batch["loss_masks"], + "total_lengths": total_lengths, + "response_lengths": response_lengths, + } + + if args.custom_tis_function_path is not None: + tis_func = load_function(args.custom_tis_function_path) + 
else: + tis_func = vanilla_tis_function + tis_weights, tis_metrics = tis_func(**tis_kwargs) pg_loss = pg_loss * tis_weights