diff --git a/examples/train_infer_mismatch_helper/mis.py b/examples/train_infer_mismatch_helper/mis.py index 5866741cca..f17b232a75 100644 --- a/examples/train_infer_mismatch_helper/mis.py +++ b/examples/train_infer_mismatch_helper/mis.py @@ -395,3 +395,45 @@ def add_ppl_metrics( rho_squared_seq = torch.exp(2.0 * log_ratio_sum_safe) # (Π ρ_t)² chi2_seq = rho_squared_seq - 1.0 metrics_append(metrics, "chi2_seq", chi2_seq) + + +def compute_mis_weights_fsdp( + args, + *, + pg_loss: torch.Tensor, + train_log_probs: list[torch.Tensor], + rollout_log_probs: list[torch.Tensor], + loss_masks: list[torch.Tensor], + **kwargs: Any, +) -> tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]: + """Compute masked importance sampling weights for FSDP. No context parallelism. + + Args: + args: Arguments containing MIS settings (use_tis, tis_mode, etc.) + pg_loss: Policy gradient loss, flattened tensor [total_tokens] + train_log_probs: Training log probs, list of 1D tensors per sequence + rollout_log_probs: Rollout log probs, list of 1D tensors per sequence + loss_masks: Loss masks, list of 1D tensors per sequence + **kwargs: Additional arguments (cp_rank, cp_size, etc.) 
#!/bin/bash
#
# Launch a Qwen3-4B GRPO run on the FSDP training backend with the custom
# masked-importance-sampling (MIS) correction defined in
# examples/train_infer_mismatch_helper/mis.py.

# Clean up leftovers from a previous run so the task can be re-launched.
# These run before `set -e` on purpose: pkill exits non-zero when nothing
# matches, which must not abort the script.
pkill -9 sglang
sleep 3
ray stop --force
pkill -9 ray
pkill -9 python
sleep 3
pkill -9 ray
pkill -9 python

set -ex

# Prevent Python (and therefore ray workers) from buffering stdout/stderr.
# The variable Python actually reads is PYTHONUNBUFFERED; the previous
# PYTHONBUFFERED spelling was silently ignored.
export PYTHONUNBUFFERED=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l)
if [ "$NVLINK_COUNT" -gt 0 ]; then
    HAS_NVLINK=1
else
    HAS_NVLINK=0
fi
echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"

SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"

RUN_ID=${RUN_ID:-"run_$(date +%Y%m%d_%H%M%S)"}
# NOTE(review): LOAD_SAVE_PATH is defined but not passed to any argument below.
LOAD_SAVE_PATH="/root/shared_data/${RUN_ID}/checkpoints"

CKPT_ARGS=(
   --hf-checkpoint /root/Qwen3-4B
   --load /root/Qwen3-4B
   --ref-load /root/Qwen3-4B
)

ROLLOUT_ARGS=(
   --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl
   --input-key prompt
   --label-key label
   --apply-chat-template
   --rollout-shuffle
   --balance-data
   --rm-type deepscaler
   --num-rollout 100
   --rollout-batch-size 8
   --n-samples-per-prompt 8
   --rollout-max-response-len 4096
   --rollout-temperature 0.8
   --global-batch-size 64
)

GRPO_ARGS=(
   --use-kl-loss
   --advantage-estimator grpo
   --kl-loss-coef 0.00
   --kl-loss-type low_var_kl
   --kl-coef 0.00
   --entropy-coef 0.00
   --eps-clip 0.2
   --eps-clip-high 0.28
   --use-tis
)

OPTIMIZER_ARGS=(
   --optimizer adam
   --lr 1e-6
   --lr-decay-style constant
   --weight-decay 0.1
   --adam-beta1 0.9
   --adam-beta2 0.98
)

# The key is quoted so that an unset WANDB_API_KEY yields an empty value
# instead of letting --wandb-key swallow the next flag on the command line.
WANDB_ARGS=(
   --use-wandb
   --wandb-project slime-dev-mcore-fsdp
   --wandb-group qwen3-4B-fsdp-1130-ref
   --wandb-key "${WANDB_API_KEY}"
)

SGLANG_ARGS=(
   --rollout-num-gpus-per-engine 1
   --sglang-mem-fraction-static 0.75
   --sglang-decode-log-interval 1000
   --sglang-chunked-prefill-size 4096
   --sglang-attention-backend fa3
)

TRAIN_BACKEND_ARGS=(
   --train-backend fsdp
   --update-weight-buffer-size 536870912
   --gradient-checkpointing
   --attn-implementation flash_attention_3
   --train-env-vars '{"PYTORCH_CUDA_ALLOC_CONF":"expandable_segments:True"}'
)

PERF_ARGS=(
   --use-dynamic-batch-size
   --max-tokens-per-gpu 9216
)

# NOTE(review): the --dump-details directory is hard-coded and does not
# follow RUN_ID; confirm this is intended.
MISC_ARGS=(
   --actor-num-nodes 1
   --actor-num-gpus-per-node 8
   --colocate
   --use-fault-tolerance
   --dump-details /root/shared_data/qwen3-4B-fsdp-1116-noref/dump_details
   # --fsdp-cpu-offload
)

# Route the importance-sampling correction through the FSDP MIS helper.
CUSTOM_ARGS=(
   --custom-config-path examples/train_infer_mismatch_helper/mis.yaml
   --custom-tis-function-path examples.train_infer_mismatch_helper.mis.compute_mis_weights_fsdp
)

# launch the master node of ray in container - 8 GPUs for training
export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus 8 --disable-usage-stats

RUNTIME_ENV_JSON="{
  \"env_vars\": {
    \"PYTHONPATH\": \"/root/Megatron-LM/:${SCRIPT_DIR}\",
    \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\"
  }
}"

# Array expansions are quoted so every element reaches train.py as exactly
# one argument (no word splitting or globbing of values such as the JSON
# passed to --train-env-vars).
ray job submit --address="http://127.0.0.1:8265" \
   --runtime-env-json="${RUNTIME_ENV_JSON}" \
   -- python3 train.py \
   "${CKPT_ARGS[@]}" \
   "${ROLLOUT_ARGS[@]}" \
   "${OPTIMIZER_ARGS[@]}" \
   "${GRPO_ARGS[@]}" \
   "${WANDB_ARGS[@]}" \
   "${SGLANG_ARGS[@]}" \
   "${TRAIN_BACKEND_ARGS[@]}" \
   "${PERF_ARGS[@]}" \
   "${MISC_ARGS[@]}" \
   "${CUSTOM_ARGS[@]}"
tis_clip - - assert not self.args.calculate_per_token_loss, "calculate_per_token_loss not yet implemented" - pg_loss = sum_of_sample_mean(pg_loss, response_lengths, loss_masks) - pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks) - ppo_kl = sum_of_sample_mean(ppo_kl.abs(), response_lengths, loss_masks) + tis_kwargs = { + "args": self.args, + "pg_loss": pg_loss, + "train_log_probs": train_log_probs_list, + "rollout_log_probs": rollout_log_probs_list, + "loss_masks": loss_masks, + "response_lengths": response_lengths, + "cp_rank": self.cp_rank, + "cp_size": self.cp_size, + "cp_group": self.cp_group, + } + + if self.args.custom_tis_function_path is not None: + tis_func = load_function(self.args.custom_tis_function_path) + else: + tis_func = vanilla_tis_function + pg_loss, loss_masks, tis_metrics = tis_func(**tis_kwargs) + + if self.args.calculate_per_token_loss: + pg_loss = sum_of_token(pg_loss, response_lengths, loss_masks) + pg_clipfrac = sum_of_token(pg_clipfrac, response_lengths, loss_masks) + ppo_kl = sum_of_token(ppo_kl.abs(), response_lengths, loss_masks) + else: + pg_loss = sum_of_sample_mean(pg_loss, response_lengths, loss_masks) + pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks) + ppo_kl = sum_of_sample_mean(ppo_kl.abs(), response_lengths, loss_masks) # Only compare rollout vs. train log probs when they originate from different stages. 
train_rollout_logprob_abs_diff = None @@ -721,10 +743,13 @@ def _has_rollout_log_probs(batch) -> bool: if self.args.use_opsm: reported["opsm_clipfrac"] = opsm_clipfrac - if self.args.use_tis and tis is not None: - reported["tis"] = sum_of_sample_mean(tis, response_lengths, loss_masks).detach() + if self.args.use_tis and tis_metrics: reported["ois"] = sum_of_sample_mean(ois, response_lengths, loss_masks).detach() - reported["tis_clipfrac"] = sum_of_sample_mean(tis_clipfrac.float(), response_lengths, loss_masks).detach() + for k, v in tis_metrics.items(): + if self.args.calculate_per_token_loss: + reported[k] = sum_of_token(v, response_lengths, loss_masks).detach() + else: + reported[k] = sum_of_sample_mean(v, response_lengths, loss_masks).detach() # Scale loss for gradient accumulation loss = loss * self.dp_size / self.args.global_batch_size @@ -1103,3 +1128,12 @@ def apply_fsdp2(model, mesh=None, cpu_offload=False, args=None): fully_shard(model, **fsdp_kwargs) return model + + +def sum_of_token(x: torch.Tensor, response_lengths: list[int], loss_masks: list[torch.Tensor]) -> torch.Tensor: + return sum( + [ + (x_i * loss_mask_i).sum() + for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=False) + ] + ) diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 39d0dae2db..cc3cdcf8de 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -420,7 +420,7 @@ def policy_loss_function( Computes current log-probabilities and entropy from model logits, then calculates PPO-style clipped policy gradient loss. For GSPO, gathers full sequences via context-parallel all-gather before computing per-sample - KL. Optionally applies TIS (Temporal Importance Sampling) correction and + KL. Optionally applies TIS (Truncated Importance Sampling) correction and adds KL loss term if configured. 
def vanilla_tis_function(
    args,
    *,
    pg_loss: torch.Tensor,
    train_log_probs: list[torch.Tensor],
    rollout_log_probs: list[torch.Tensor],
    loss_masks: list[torch.Tensor],
    **kwargs,
) -> tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]:
    """Apply truncated importance sampling (TIS) off-policy correction.

    The per-token ratio ``exp(train_log_prob - rollout_log_prob)`` is clamped
    to ``[tis_clip_low, tis_clip]`` and multiplied into ``pg_loss``. The ratio
    is detached first, so the weights act as constant coefficients and no
    gradient flows through them even when ``train_log_probs`` carries grad.

    Parameters:
        args: Arguments containing TIS settings. ``tis_clip_low`` and
            ``tis_clip`` are read if present; a missing or ``None`` value
            falls back to 0.1 and 2.0 respectively.
        pg_loss: Policy gradient loss tensor; must broadcast against the
            concatenation of the per-sequence log-prob tensors.
        train_log_probs: List of tensors containing training log-probabilities
            for each sequence.
        rollout_log_probs: List of tensors containing rollout log-probabilities
            for each sequence.
        loss_masks: List of tensors containing loss masks for each sequence.
        **kwargs: Accepted and ignored, so callers can pass one shared keyword
            set to any TIS function.

    Returns:
        A tuple ``(pg_loss, loss_masks, metrics)``: the re-weighted loss, the
        ``loss_masks`` argument unchanged, and a dict of detached per-token
        tensors under the keys ``"tis"``, ``"tis_clipfrac"`` and ``"tis_abs"``.
    """
    rollout_flat = torch.cat(rollout_log_probs, dim=0)
    train_flat = torch.cat(train_log_probs, dim=0)

    # Importance weights are coefficients, not part of the objective: stop
    # gradients here so the clamp/multiply below cannot leak them into pg_loss.
    tis = torch.exp(train_flat - rollout_flat).detach()
    tis_abs = (tis - 1).abs()

    clip_low = getattr(args, "tis_clip_low", None)
    if clip_low is None:
        clip_low = 0.1
    clip_high = getattr(args, "tis_clip", None)
    if clip_high is None:
        clip_high = 2.0

    tis_weights = torch.clamp(tis, min=clip_low, max=clip_high)
    tis_clipfrac = (tis_weights != tis).float()

    metrics = {
        "tis": tis.clone(),
        "tis_clipfrac": tis_clipfrac,
        "tis_abs": tis_abs,
    }
    pg_loss = pg_loss * tis_weights
    return pg_loss, loss_masks, metrics