Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions examples/train_infer_mismatch_helper/mis_fsdp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Any

import torch

from .mis import compute_mis_weights


def compute_mis_weights_fsdp(
    args,
    *,
    pg_loss: torch.Tensor,
    train_log_probs: list[torch.Tensor],
    rollout_log_probs: list[torch.Tensor],
    loss_masks: list[torch.Tensor],
    **kwargs: Any,
) -> tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]:
    """FSDP adapter around ``compute_mis_weights``; no context parallelism.

    The per-sequence work is delegated to ``compute_mis_weights``. This wrapper
    only flattens the per-sequence outputs so they line up with the already
    flattened ``pg_loss``.

    Args:
        args: Arguments object forwarded unchanged to ``compute_mis_weights``
            (MIS settings such as use_tis, tis_mode, etc.).
        pg_loss: Policy gradient loss, flattened tensor [total_tokens].
        train_log_probs: Training log probs, one 1D tensor per sequence.
        rollout_log_probs: Rollout log probs, one 1D tensor per sequence.
        loss_masks: Loss masks, one 1D tensor per sequence.
        **kwargs: Extra keyword arguments (cp_rank, cp_size, etc.). Accepted so
            the signature is call-compatible with other TIS functions; unused.

    Returns:
        A ``(pg_loss, modified_masks, mis_metrics)`` tuple:
        ``pg_loss`` multiplied by the concatenated IS weights (returned
        untouched when ``compute_mis_weights`` yields ``None`` for the weights),
        the masks returned by ``compute_mis_weights``, and a dict whose keys are
        prefixed with ``mis_`` and whose values are concatenated along dim 0.
    """
    per_seq_weights, updated_masks, per_seq_metrics = compute_mis_weights(
        args=args,
        train_log_probs=train_log_probs,
        rollout_log_probs=rollout_log_probs,
        loss_masks=loss_masks,
    )

    # Weights arrive as one tensor per sequence; join them to match pg_loss.
    if per_seq_weights is not None:
        pg_loss = pg_loss * torch.cat(per_seq_weights, dim=0)

    flat_metrics = {
        f"mis_{name}": torch.cat(chunks, dim=0)
        for name, chunks in per_seq_metrics.items()
    }

    return pg_loss, updated_masks, flat_metrics
148 changes: 148 additions & 0 deletions examples/train_infer_mismatch_helper/run-qwen3-4b-fsdp-mis.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
#!/bin/bash

# Clean up leftovers from a previous run so the task can be rerun.
# These commands run before `set -e` below, so a non-zero exit status
# (for example when there is nothing to kill) does not abort the script.
# NOTE: `pkill -9 python` targets every process whose name matches "python",
# not only the ones started by this script.
pkill -9 sglang
sleep 3
ray stop --force
pkill -9 ray
pkill -9 python
sleep 3
# Second pass after the pause, repeating the ray/python kills.
pkill -9 ray
pkill -9 python




# From here on: exit on the first failing command (-e) and trace commands (-x).
set -ex

# Prevent buffering of Python stdout/stderr so logs show up immediately.
# The interpreter reads PYTHONUNBUFFERED (any non-empty value); the previously
# exported PYTHONBUFFERED is not a variable Python recognizes, so it had no effect.
export PYTHONUNBUFFERED=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

# Count occurrences of "NVLink" in the nvidia-smi output. Only the last command
# of the pipeline (wc) decides the exit status, so a failing nvidia-smi yields 0
# here instead of aborting under `set -e`.
# NOTE(review): HAS_NVLINK is only echoed in this script; also confirm that the
# plain `nvidia-smi` output is the intended source for this detection.
NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l)
if [ "$NVLINK_COUNT" -gt 0 ]; then
    HAS_NVLINK=1
else
    HAS_NVLINK=0
fi
echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"



# Absolute directory of this script; added to PYTHONPATH for the ray job below.
SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"

# RUN_ID may be supplied by the caller; otherwise it is derived from the current time.
RUN_ID=${RUN_ID:-"run_$(date +%Y%m%d_%H%M%S)"}
# NOTE(review): LOAD_SAVE_PATH is not referenced anywhere else in this script.
LOAD_SAVE_PATH="/root/shared_data/${RUN_ID}/checkpoints"

# Checkpoint locations: the HF checkpoint, --load and --ref-load all point at
# the same /root/Qwen3-4B directory.
CKPT_ARGS=(
--hf-checkpoint /root/Qwen3-4B
--load /root/Qwen3-4B
--ref-load /root/Qwen3-4B
)

# Rollout / data settings: prompt dataset, sampling parameters and batch sizes.
# 8 prompts x 8 samples per prompt = 64, the same number as --global-batch-size
# (presumably one optimizer step per rollout -- confirm against train.py).
ROLLOUT_ARGS=(
--prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl
--input-key prompt
--label-key label
--apply-chat-template
--rollout-shuffle
--balance-data
--rm-type deepscaler
--num-rollout 100
--rollout-batch-size 8
--n-samples-per-prompt 8
--rollout-max-response-len 4096
--rollout-temperature 0.8
--global-batch-size 64
)

# GRPO loss settings. The KL and entropy coefficients are all 0.00 here.
# --use-tis turns on the TIS/MIS correction; the function that implements it is
# selected through CUSTOM_ARGS below.
GRPO_ARGS=(
--use-kl-loss
--advantage-estimator grpo
--kl-loss-coef 0.00
--kl-loss-type low_var_kl
--kl-coef 0.00
--entropy-coef 0.00
--eps-clip 0.2
--eps-clip-high 0.28
--use-tis
)

# Optimizer: Adam with a constant learning rate of 1e-6.
OPTIMIZER_ARGS=(
--optimizer adam
--lr 1e-6
--lr-decay-style constant
--weight-decay 0.1
--adam-beta1 0.9
--adam-beta2 0.98
)

# Weights & Biases logging. WANDB_API_KEY must be set in the environment; the
# expansion is unquoted, so an empty/unset value leaves --wandb-key without an
# argument.
WANDB_ARGS=(
--use-wandb
--wandb-project slime-dev-mcore-fsdp
--wandb-group qwen3-4B-fsdp-1130-ref
--wandb-key ${WANDB_API_KEY}
)

# SGLang rollout engine settings (one GPU per engine).
SGLANG_ARGS=(
--rollout-num-gpus-per-engine 1
--sglang-mem-fraction-static 0.75
--sglang-decode-log-interval 1000
--sglang-chunked-prefill-size 4096
--sglang-attention-backend fa3
)

# Training backend: FSDP with gradient checkpointing.
# 536870912 = 512 * 1024 * 1024 (presumably bytes, i.e. 512 MiB -- confirm).
TRAIN_BACKEND_ARGS=(
--train-backend fsdp
--update-weight-buffer-size 536870912
--gradient-checkpointing
--attn-implementation flash_attention_3
--train-env-vars '{"PYTORCH_CUDA_ALLOC_CONF":"expandable_segments:True"}'
)

# Dynamic batching with a per-GPU token budget.
PERF_ARGS=(
--use-dynamic-batch-size
--max-tokens-per-gpu 9216
)

# Cluster layout and miscellaneous switches: one node, 8 GPUs, colocated.
# NOTE(review): the --dump-details path is hard-coded and does not use RUN_ID.
MISC_ARGS=(
--actor-num-nodes 1
--actor-num-gpus-per-node 8
--colocate
--use-fault-tolerance
--dump-details /root/shared_data/qwen3-4B-fsdp-1116-noref/dump_details
# --fsdp-cpu-offload
)

# MIS configuration file and the dotted path of the custom TIS function
# (compute_mis_weights_fsdp in examples/train_infer_mismatch_helper/mis_fsdp.py).
CUSTOM_ARGS=(
--custom-config-path examples/train_infer_mismatch_helper/mis.yaml
--custom-tis-function-path examples.train_infer_mismatch_helper.mis_fsdp.compute_mis_weights_fsdp
)

# Launch the ray head node inside the container with 8 GPUs for training.
# MASTER_ADDR can be overridden by the caller; it defaults to localhost.
export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus 8 --disable-usage-stats


# Environment passed to the submitted job: PYTHONPATH gains Megatron-LM and this
# script's directory. The string content is kept exactly as before.
RUNTIME_ENV_JSON="{
\"env_vars\": {
\"PYTHONPATH\": \"/root/Megatron-LM/:${SCRIPT_DIR}\",
\"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\"
}
}"


# Submit the training job. Every array is expanded as "${ARR[@]}" so each
# element reaches train.py as exactly one argument; an unquoted expansion would
# be re-split on whitespace and glob-expanded (this matters for values such as
# the JSON passed to --train-env-vars).
ray job submit --address="http://127.0.0.1:8265" \
    --runtime-env-json="${RUNTIME_ENV_JSON}" \
    -- python3 train.py \
    "${CKPT_ARGS[@]}" \
    "${ROLLOUT_ARGS[@]}" \
    "${OPTIMIZER_ARGS[@]}" \
    "${GRPO_ARGS[@]}" \
    "${WANDB_ARGS[@]}" \
    "${SGLANG_ARGS[@]}" \
    "${TRAIN_BACKEND_ARGS[@]}" \
    "${PERF_ARGS[@]}" \
    "${MISC_ARGS[@]}" \
    "${CUSTOM_ARGS[@]}"


158 changes: 137 additions & 21 deletions slime/backends/fsdp_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from slime.utils.distributed_utils import get_gloo_group
from slime.utils.memory_utils import clear_memory, print_memory
from slime.utils.metric_utils import compute_rollout_step
from slime.utils.misc import load_function
from slime.utils.ppo_utils import compute_approx_kl, compute_gspo_kl, compute_opsm_mask, compute_policy_loss
from slime.utils.processing_utils import load_processor, load_tokenizer
from slime.utils.ray_utils import Box
Expand Down Expand Up @@ -655,26 +656,41 @@ def _has_rollout_log_probs(batch) -> bool:
else None
)

# Apply TIS before sample mean calculation
tis_metrics = {}
if self.args.use_tis:
# Apply TIS off-policy correction using importance sampling
assert (
has_rollout_log_probs and rollout_log_probs is not None
), "rollout_log_probs must be provided as non-empty torch.Tensor for TIS"

tis = torch.exp(old_log_probs - rollout_log_probs)
ois = (-ppo_kl).exp()
tis_clip = torch.clamp(
tis, min=getattr(self.args, "tis_clip_low", 0.1), max=getattr(self.args, "tis_clip", 2.0)
)
tis_clipfrac = tis_clip != tis

pg_loss = pg_loss * tis_clip

assert not self.args.calculate_per_token_loss, "calculate_per_token_loss not yet implemented"
pg_loss = sum_of_sample_mean(pg_loss, response_lengths, loss_masks)
pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks)
ppo_kl = sum_of_sample_mean(ppo_kl.abs(), response_lengths, loss_masks)
), "rollout_log_probs must be provided as non-empty torch.Tensor for TIS/MIS"

train_log_probs_list = list(log_probs.split(response_lengths, dim=0))
rollout_log_probs_list = list(rollout_log_probs.split(response_lengths, dim=0))

tis_kwargs = {
"args": self.args,
"pg_loss": pg_loss,
"train_log_probs": train_log_probs_list,
"rollout_log_probs": rollout_log_probs_list,
"loss_masks": loss_masks,
"response_lengths": response_lengths,
"cp_rank": self.cp_rank,
"cp_size": self.cp_size,
"cp_group": self.cp_group,
}

if self.args.custom_tis_function_path is not None:
tis_func = load_function(self.args.custom_tis_function_path)
else:
tis_func = vanilla_tis_function_fsdp
pg_loss, loss_masks, tis_metrics = tis_func(**tis_kwargs)

if getattr(self.args, "calculate_per_token_loss", False):
Comment thread
zijiexia marked this conversation as resolved.
Outdated
pg_loss = sum_of_token(pg_loss, response_lengths, loss_masks)
pg_clipfrac = sum_of_token(pg_clipfrac, response_lengths, loss_masks)
ppo_kl = sum_of_token(ppo_kl.abs(), response_lengths, loss_masks)
else:
pg_loss = sum_of_sample_mean(pg_loss, response_lengths, loss_masks)
pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks)
ppo_kl = sum_of_sample_mean(ppo_kl.abs(), response_lengths, loss_masks)

# Only compare rollout vs. train log probs when they originate from different stages.
train_rollout_logprob_abs_diff = None
Expand Down Expand Up @@ -721,10 +737,12 @@ def _has_rollout_log_probs(batch) -> bool:
if self.args.use_opsm:
reported["opsm_clipfrac"] = opsm_clipfrac

if self.args.use_tis and tis is not None:
reported["tis"] = sum_of_sample_mean(tis, response_lengths, loss_masks).detach()
reported["ois"] = sum_of_sample_mean(ois, response_lengths, loss_masks).detach()
reported["tis_clipfrac"] = sum_of_sample_mean(tis_clipfrac.float(), response_lengths, loss_masks).detach()
if self.args.use_tis and tis_metrics:
for k, v in tis_metrics.items():
if getattr(self.args, "calculate_per_token_loss", False):
Comment thread
zijiexia marked this conversation as resolved.
Outdated
reported[k] = sum_of_token(v, response_lengths, loss_masks).detach()
else:
reported[k] = sum_of_sample_mean(v, response_lengths, loss_masks).detach()

# Scale loss for gradient accumulation
loss = loss * self.dp_size / self.args.global_batch_size
Expand Down Expand Up @@ -803,6 +821,51 @@ def update_weights(self) -> None: # type: ignore[override]

clear_memory()

def _compute_tis_weights(
    self,
    old_log_probs: torch.Tensor,
    rollout_log_probs: torch.Tensor,
    loss_masks: list[torch.Tensor],
    response_lengths: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute importance sampling weights for TIS/MIS.

    Settings are read from ``self.args`` with these fallbacks when the
    attribute is missing: ``tis_mode="truncate"``, ``tis_level="token"``,
    ``tis_clip_low=0.1`` and ``tis_clip=2.0`` (the upper bound).

    Args:
        old_log_probs: Flattened training-side log probs.
        rollout_log_probs: Flattened rollout-side log probs, same layout.
        loss_masks: One mask per sequence; only used for ``tis_level="sequence"``.
        response_lengths: Per-sequence lengths used to split the flattened
            log ratio; only used for ``tis_level="sequence"``.

    Returns:
        ``(tis_clip, tis, tis_clipfrac)``:
        ``tis_clip`` holds the weights after truncation or masking, ``tis`` the
        raw weights, and ``tis_clipfrac`` is the boolean tensor
        ``tis_clip != tis`` marking every changed position.

    Raises:
        ValueError: If ``tis_level`` is not ``"token"``/``"sequence"`` or
            ``tis_mode`` is not ``"truncate"``/``"mask"``.
    """
    tis_mode = getattr(self.args, "tis_mode", "truncate")
    tis_level = getattr(self.args, "tis_level", "token")
    tis_clip_low = getattr(self.args, "tis_clip_low", 0.1)
    tis_clip_high = getattr(self.args, "tis_clip", 2.0)

    log_ratio = old_log_probs - rollout_log_probs

    # Raw weights: per-token ratio, or one ratio per sequence broadcast to
    # every token of that sequence.
    if tis_level == "token":
        tis = torch.exp(log_ratio)
    elif tis_level == "sequence":
        per_sequence = [
            # masked_sum(..., expand=True) repeats the sequence total at every
            # position, so the concatenation keeps the flattened layout.
            torch.exp(masked_sum(seq_log_ratio, mask.to(seq_log_ratio.device), expand=True))
            for seq_log_ratio, mask in zip(
                log_ratio.split(response_lengths, dim=0), loss_masks, strict=False
            )
        ]
        tis = torch.cat(per_sequence, dim=0)
    else:
        raise ValueError(f"Unsupported tis_level: {tis_level}")

    # "truncate" clamps into [low, high]; "mask" zeroes weights outside it.
    if tis_mode == "truncate":
        tis_clip = torch.clamp(tis, min=tis_clip_low, max=tis_clip_high)
    elif tis_mode == "mask":
        in_range = (tis >= tis_clip_low) & (tis <= tis_clip_high)
        tis_clip = tis * in_range.float()
    else:
        raise ValueError(f"Unsupported tis_mode: {tis_mode}")

    tis_clipfrac = tis_clip != tis

    return tis_clip, tis, tis_clipfrac

def _create_ref_model(self, ref_load_path: str | None):
"""Create and initialize a separate reference model with FSDP2 CPUOffloadPolicy.

Expand Down Expand Up @@ -1035,6 +1098,11 @@ def sum_of_sample_mean(x: torch.Tensor, response_lengths: list[int], loss_masks:
)


def masked_sum(x: torch.Tensor, loss_mask: torch.Tensor, expand: bool = False) -> torch.Tensor:
    """Sum the entries of ``x`` at positions where ``loss_mask`` is truthy.

    Positions where the mask is false contribute 0.0. With ``expand=True`` the
    scalar total is expanded to the shape of ``x``; otherwise the scalar tensor
    is returned as is.
    """
    keep = loss_mask.bool()
    total = torch.sum(torch.where(keep, x, 0.0))
    if expand:
        return total.expand_as(x)
    return total


@torch.no_grad()
def move_torch_optimizer(optimizer, device):
"""ref: https://github.com/volcengine/verl/blob/main/verl/utils/fsdp_utils.py"""
Expand Down Expand Up @@ -1103,3 +1171,51 @@ def apply_fsdp2(model, mesh=None, cpu_offload=False, args=None):
fully_shard(model, **fsdp_kwargs)

return model


def sum_of_token(x: torch.Tensor, response_lengths: list[int], loss_masks: list[torch.Tensor]) -> torch.Tensor:
return sum(
[
(x_i * loss_mask_i).sum()
for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=False)
]
)


def vanilla_tis_function_fsdp(
    args,
    *,
    pg_loss: torch.Tensor,
    train_log_probs: list[torch.Tensor],
    rollout_log_probs: list[torch.Tensor],
    loss_masks: list[torch.Tensor],
    **kwargs,
) -> tuple[torch.Tensor, list[torch.Tensor], dict[str, torch.Tensor]]:
    """Default TIS correction: scale the loss by a clamped importance ratio.

    Parameters:
        args: Arguments object; ``tis_clip_low`` (fallback 0.1) and ``tis_clip``
            (fallback 2.0) give the clamp bounds.
        pg_loss: Flattened policy gradient loss; multiplied elementwise by the
            clamped ratio.
        train_log_probs: Training log-probabilities, one tensor per sequence.
        rollout_log_probs: Rollout log-probabilities, one tensor per sequence.
        loss_masks: Per-sequence loss masks; returned unchanged.
        **kwargs: Ignored; accepted for call-site compatibility.

    Returns:
        ``(pg_loss, loss_masks, metrics)`` where ``metrics`` contains detached
        copies of the raw ratio (``"tis"``), the 0/1 indicator of clamped
        positions (``"tis_clipfrac"``) and ``|ratio - 1|`` (``"tis_abs"``).
    """
    clip_low = getattr(args, "tis_clip_low", 0.1)
    clip_high = getattr(args, "tis_clip", 2.0)

    # Flatten the per-sequence lists so the ratio lines up with pg_loss.
    flat_rollout = torch.cat(rollout_log_probs, dim=0)
    flat_train = torch.cat(train_log_probs, dim=0)

    ratio = (flat_train - flat_rollout).exp()
    clamped = ratio.clamp(min=clip_low, max=clip_high)
    was_clamped = clamped.ne(ratio).float()
    deviation = (ratio - 1).abs()

    metrics = {
        "tis": ratio.detach().clone(),
        "tis_clipfrac": was_clamped.detach().clone(),
        "tis_abs": deviation.detach().clone(),
    }

    return pg_loss * clamped, loss_masks, metrics
Loading
Loading