Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions examples/train_infer_mismatch_helper/mis.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,3 +395,45 @@ def add_ppl_metrics(
rho_squared_seq = torch.exp(2.0 * log_ratio_sum_safe) # (Π ρ_t)²
chi2_seq = rho_squared_seq - 1.0
metrics_append(metrics, "chi2_seq", chi2_seq)


def compute_mis_weights_fsdp(
args,
*,
pg_loss: torch.Tensor,
train_log_probs: list[torch.Tensor],
rollout_log_probs: list[torch.Tensor],
loss_masks: list[torch.Tensor],
**kwargs: Any,
) -> tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]:
"""Compute masked importance sampling weights for FSDP. No context parallelism.

Args:
args: Arguments containing MIS settings (use_tis, tis_mode, etc.)
pg_loss: Policy gradient loss, flattened tensor [total_tokens]
train_log_probs: Training log probs, list of 1D tensors per sequence
rollout_log_probs: Rollout log probs, list of 1D tensors per sequence
loss_masks: Loss masks, list of 1D tensors per sequence
**kwargs: Additional arguments (cp_rank, cp_size, etc.) for compatibility

Returns:
pg_loss: Policy gradient loss with IS weights applied
modified_masks: Modified loss masks after rejection sampling
mis_metrics: Metrics dict with flattened tensors
"""
is_weights, modified_masks, is_metrics = compute_mis_weights(
args=args,
train_log_probs=train_log_probs,
rollout_log_probs=rollout_log_probs,
loss_masks=loss_masks,
)

result_metrics = {}
if is_weights is not None:
is_weights_flat = torch.cat(is_weights, dim=0)
pg_loss = pg_loss * is_weights_flat

for key, values in is_metrics.items():
result_metrics[f"mis_{key}"] = torch.cat(values, dim=0)

return pg_loss, modified_masks, result_metrics
148 changes: 148 additions & 0 deletions examples/train_infer_mismatch_helper/run-qwen3-4b-fsdp-mis.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
#!/bin/bash

# for rerun the task
pkill -9 sglang
sleep 3
ray stop --force
pkill -9 ray
pkill -9 python
sleep 3
pkill -9 ray
pkill -9 python




set -ex

# will prevent ray from buffering stdout/stderr
export PYTHONBUFFERED=16
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l)
if [ "$NVLINK_COUNT" -gt 0 ]; then
HAS_NVLINK=1
else
HAS_NVLINK=0
fi
echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"



SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"

RUN_ID=${RUN_ID:-"run_$(date +%Y%m%d_%H%M%S)"}
LOAD_SAVE_PATH="/root/shared_data/${RUN_ID}/checkpoints"

CKPT_ARGS=(
--hf-checkpoint /root/Qwen3-4B
--load /root/Qwen3-4B
--ref-load /root/Qwen3-4B
)

ROLLOUT_ARGS=(
--prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl
--input-key prompt
--label-key label
--apply-chat-template
--rollout-shuffle
--balance-data
--rm-type deepscaler
--num-rollout 100
--rollout-batch-size 8
--n-samples-per-prompt 8
--rollout-max-response-len 4096
--rollout-temperature 0.8
--global-batch-size 64
)

GRPO_ARGS=(
--use-kl-loss
--advantage-estimator grpo
--kl-loss-coef 0.00
--kl-loss-type low_var_kl
--kl-coef 0.00
--entropy-coef 0.00
--eps-clip 0.2
--eps-clip-high 0.28
--use-tis
)

OPTIMIZER_ARGS=(
--optimizer adam
--lr 1e-6
--lr-decay-style constant
--weight-decay 0.1
--adam-beta1 0.9
--adam-beta2 0.98
)

WANDB_ARGS=(
--use-wandb
--wandb-project slime-dev-mcore-fsdp
--wandb-group qwen3-4B-fsdp-1130-ref
--wandb-key ${WANDB_API_KEY}
)

SGLANG_ARGS=(
--rollout-num-gpus-per-engine 1
--sglang-mem-fraction-static 0.75
--sglang-decode-log-interval 1000
--sglang-chunked-prefill-size 4096
--sglang-attention-backend fa3
)

TRAIN_BACKEND_ARGS=(
--train-backend fsdp
--update-weight-buffer-size 536870912
--gradient-checkpointing
--attn-implementation flash_attention_3
--train-env-vars '{"PYTORCH_CUDA_ALLOC_CONF":"expandable_segments:True"}'
)

PERF_ARGS=(
--use-dynamic-batch-size
--max-tokens-per-gpu 9216
)

MISC_ARGS=(
--actor-num-nodes 1
--actor-num-gpus-per-node 8
--colocate
--use-fault-tolerance
--dump-details /root/shared_data/qwen3-4B-fsdp-1116-noref/dump_details
# --fsdp-cpu-offload
)

CUSTOM_ARGS=(
--custom-config-path examples/train_infer_mismatch_helper/mis.yaml
--custom-tis-function-path examples.train_infer_mismatch_helper.mis.compute_mis_weights_fsdp
)

# launch the master node of ray in container - 8 GPUs for training
export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats


RUNTIME_ENV_JSON="{
\"env_vars\": {
\"PYTHONPATH\": \"/root/Megatron-LM/:${SCRIPT_DIR}\",
\"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\"
}
}"


ray job submit --address="http://127.0.0.1:8265" \
--runtime-env-json="${RUNTIME_ENV_JSON}" \
-- python3 train.py \
${CKPT_ARGS[@]} \
${ROLLOUT_ARGS[@]} \
${OPTIMIZER_ARGS[@]} \
${GRPO_ARGS[@]} \
${WANDB_ARGS[@]} \
${SGLANG_ARGS[@]} \
${TRAIN_BACKEND_ARGS[@]} \
${PERF_ARGS[@]} \
${MISC_ARGS[@]} \
${CUSTOM_ARGS[@]}


72 changes: 53 additions & 19 deletions slime/backends/fsdp_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@
from slime.utils.distributed_utils import get_gloo_group
from slime.utils.memory_utils import clear_memory, print_memory
from slime.utils.metric_utils import compute_rollout_step
from slime.utils.ppo_utils import compute_approx_kl, compute_gspo_kl, compute_opsm_mask, compute_policy_loss
from slime.utils.misc import load_function
from slime.utils.ppo_utils import (
compute_approx_kl,
compute_gspo_kl,
compute_opsm_mask,
compute_policy_loss,
vanilla_tis_function,
)
from slime.utils.processing_utils import load_processor, load_tokenizer
from slime.utils.ray_utils import Box
from slime.utils.timer import Timer, inverse_timer, timer
Expand Down Expand Up @@ -655,26 +662,41 @@ def _has_rollout_log_probs(batch) -> bool:
else None
)

# Apply TIS before sample mean calculation
# Apply off-policy correction using importance sampling if enabled
if self.args.use_tis:
# Apply TIS off-policy correction using importance sampling
assert (
has_rollout_log_probs and rollout_log_probs is not None
), "rollout_log_probs must be provided as non-empty torch.Tensor for TIS"
), "rollout_log_probs must be provided as non-empty torch.Tensor for TIS/MIS"

tis = torch.exp(old_log_probs - rollout_log_probs)
train_log_probs_list = list(log_probs.split(response_lengths, dim=0))
rollout_log_probs_list = list(rollout_log_probs.split(response_lengths, dim=0))
ois = (-ppo_kl).exp()
tis_clip = torch.clamp(
tis, min=getattr(self.args, "tis_clip_low", 0.1), max=getattr(self.args, "tis_clip", 2.0)
)
tis_clipfrac = tis_clip != tis

pg_loss = pg_loss * tis_clip

assert not self.args.calculate_per_token_loss, "calculate_per_token_loss not yet implemented"
pg_loss = sum_of_sample_mean(pg_loss, response_lengths, loss_masks)
pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks)
ppo_kl = sum_of_sample_mean(ppo_kl.abs(), response_lengths, loss_masks)
tis_kwargs = {
"args": self.args,
"pg_loss": pg_loss,
"train_log_probs": train_log_probs_list,
"rollout_log_probs": rollout_log_probs_list,
"loss_masks": loss_masks,
"response_lengths": response_lengths,
"cp_rank": self.cp_rank,
"cp_size": self.cp_size,
"cp_group": self.cp_group,
}

if self.args.custom_tis_function_path is not None:
tis_func = load_function(self.args.custom_tis_function_path)
else:
tis_func = vanilla_tis_function
pg_loss, loss_masks, tis_metrics = tis_func(**tis_kwargs)

if self.args.calculate_per_token_loss:
pg_loss = sum_of_token(pg_loss, response_lengths, loss_masks)
pg_clipfrac = sum_of_token(pg_clipfrac, response_lengths, loss_masks)
ppo_kl = sum_of_token(ppo_kl.abs(), response_lengths, loss_masks)
else:
pg_loss = sum_of_sample_mean(pg_loss, response_lengths, loss_masks)
pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks)
ppo_kl = sum_of_sample_mean(ppo_kl.abs(), response_lengths, loss_masks)

# Only compare rollout vs. train log probs when they originate from different stages.
train_rollout_logprob_abs_diff = None
Expand Down Expand Up @@ -721,10 +743,13 @@ def _has_rollout_log_probs(batch) -> bool:
if self.args.use_opsm:
reported["opsm_clipfrac"] = opsm_clipfrac

if self.args.use_tis and tis is not None:
reported["tis"] = sum_of_sample_mean(tis, response_lengths, loss_masks).detach()
if self.args.use_tis and tis_metrics:
reported["ois"] = sum_of_sample_mean(ois, response_lengths, loss_masks).detach()
reported["tis_clipfrac"] = sum_of_sample_mean(tis_clipfrac.float(), response_lengths, loss_masks).detach()
for k, v in tis_metrics.items():
if self.args.calculate_per_token_loss:
reported[k] = sum_of_token(v, response_lengths, loss_masks).detach()
else:
reported[k] = sum_of_sample_mean(v, response_lengths, loss_masks).detach()

# Scale loss for gradient accumulation
loss = loss * self.dp_size / self.args.global_batch_size
Expand Down Expand Up @@ -1103,3 +1128,12 @@ def apply_fsdp2(model, mesh=None, cpu_offload=False, args=None):
fully_shard(model, **fsdp_kwargs)

return model


def sum_of_token(x: torch.Tensor, response_lengths: list[int], loss_masks: list[torch.Tensor]) -> torch.Tensor:
return sum(
[
(x_i * loss_mask_i).sum()
for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=False)
]
)
2 changes: 1 addition & 1 deletion slime/backends/megatron_utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def policy_loss_function(
Computes current log-probabilities and entropy from model logits, then
calculates PPO-style clipped policy gradient loss. For GSPO, gathers
full sequences via context-parallel all-gather before computing per-sample
KL. Optionally applies TIS (Temporal Importance Sampling) correction and
KL. Optionally applies TIS (Truncated Importance Sampling) correction and

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch 😂

adds KL loss term if configured.

Args:
Expand Down
37 changes: 37 additions & 0 deletions slime/utils/ppo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,3 +662,40 @@ def calculate_log_probs_and_entropy(logits, tokens, tp_group, with_entropy: bool
else:
entropy = None
return log_prob, entropy


def vanilla_tis_function(
args,
*,
pg_loss: torch.Tensor,
train_log_probs: list[torch.Tensor],
rollout_log_probs: list[torch.Tensor],
loss_masks: list[torch.Tensor],
**kwargs,
) -> tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]:
"""Apply TIS off-policy correction using importance sampling.

Parameters:
args: Arguments containing TIS settings.
pg_loss: Policy gradient loss tensor of shape [total_seq_len - 1].
train_log_probs: List of tensors containing training log-probabilities
for each sequence.
rollout_log_probs: List of tensors containing rollout log-probabilities
for each sequence.
loss_masks: List of tensors containing loss masks for each sequence.
"""
rollout_log_probs = torch.cat(rollout_log_probs, dim=0)
old_log_probs = torch.cat(train_log_probs, dim=0)
tis = torch.exp(old_log_probs - rollout_log_probs)
tis_abs = (tis - 1).abs()
tis_clip_low = args.tis_clip_low if args.tis_clip_low is not None else 0.1
tis_clip_high = args.tis_clip if args.tis_clip is not None else 2.0
tis_weights = torch.clamp(tis, min=tis_clip_low, max=tis_clip_high)
tis_clipfrac = (tis_weights != tis).float()
metrics = {
"tis": tis.clone().detach(),
"tis_clipfrac": tis_clipfrac.clone().detach(),
"tis_abs": tis_abs.clone().detach(),
}
pg_loss = pg_loss * tis_weights
return pg_loss, loss_masks, metrics