Skip to content
Merged
Show file tree
Hide file tree
Changes from 63 commits
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
30e01df
fix quick start docs in zh/en
zhaochenyang20 Oct 5, 2025
82695df
Update run-qwen3-30B-A3B.sh
zhaochenyang20 Oct 5, 2025
fc0ec27
[Importance sampling] seperate importance sampling as a function
zhaochenyang20 Oct 6, 2025
2ee67b6
Merge branch 'THUDM:main' into importance_sampling
zhaochenyang20 Oct 6, 2025
35f414c
fix lint
zhaochenyang20 Oct 6, 2025
901da8d
Merge branch 'THUDM:main' into main
zhaochenyang20 Oct 6, 2025
19649f9
fix lint of main
zhaochenyang20 Oct 6, 2025
81ffa47
adding pre-commit as a CI flow
zhaochenyang20 Oct 6, 2025
243cc44
only pre-commit with yml
zhaochenyang20 Oct 6, 2025
0b164f3
fix up pre commit
zhaochenyang20 Oct 6, 2025
0717171
unigy local pre-commit with third party
zhaochenyang20 Oct 6, 2025
0d27716
rebase with main for lint
zhaochenyang20 Oct 6, 2025
984c724
fix lint with main
zhaochenyang20 Oct 6, 2025
0695c73
adding kl metircs
zhaochenyang20 Oct 6, 2025
cac9fa9
fix type custing for metrics
zhaochenyang20 Oct 6, 2025
e7cf0c2
comments to compute_tis_weights
zhaochenyang20 Oct 7, 2025
2924975
[lint] tis comment
zhaochenyang20 Oct 7, 2025
7963809
refactor clip mode in sequence level
zhaochenyang20 Oct 7, 2025
6f36eef
[test] geometric level
zhaochenyang20 Oct 7, 2025
3cc4982
adding metrics to new tis
zhaochenyang20 Oct 7, 2025
d60a595
[log probs in 1D]
zhaochenyang20 Oct 9, 2025
5cac6e0
stash with main
zhaochenyang20 Oct 10, 2025
71194c3
slice tis with slice_log_prob_with_cp
zhaochenyang20 Oct 10, 2025
92e6e97
[todo] filter out catastrophic tokens
zhaochenyang20 Oct 10, 2025
dbd27d7
before metrics
zhaochenyang20 Oct 11, 2025
083c300
fix with rollout log probs
zhaochenyang20 Oct 11, 2025
ff68b32
[wait for the metircs]
zhaochenyang20 Oct 11, 2025
afc1cd7
logging a whole sequence
zhaochenyang20 Oct 9, 2025
4ab1aee
rebase with main, ready to metircs global
zhaochenyang20 Oct 12, 2025
f31c9cd
tmp commit
guapisolo Oct 13, 2025
46adaef
make good abstraction to metrics
guapisolo Oct 13, 2025
2d9ffd8
fix seq concat problem
guapisolo Oct 13, 2025
f44a7d1
fix key not exist
guapisolo Oct 13, 2025
8d77db6
update metrics calculation
guapisolo Oct 13, 2025
88f24dd
fix cp>1 with scatter_with_cp func and update argument
guapisolo Oct 13, 2025
3b68708
move cp scatter to loss.py
guapisolo Oct 14, 2025
74ea791
revert fsdp tis
guapisolo Oct 14, 2025
7e4ee12
fix argument names from tis to is
guapisolo Oct 14, 2025
3fd863b
Update clip bound name and veto impl to fully follow the paper
guapisolo Oct 14, 2025
c047b2c
upd clip_mask to mask
guapisolo Oct 14, 2025
ca6e90e
upd arg name, fix small bug, delete scatter_with_cp
guapisolo Oct 14, 2025
c0aeb32
move cp logic to tis.py
guapisolo Oct 14, 2025
4444634
add a qwen3-4b tis sample
guapisolo Oct 14, 2025
c4441ad
fix code comments, file name
guapisolo Oct 14, 2025
293cbce
fix small bug and arg setting
guapisolo Oct 14, 2025
d750125
fix small bug in veto, add mask, and add a script
guapisolo Oct 14, 2025
ce01008
small change fsdp
guapisolo Oct 15, 2025
43d04bd
fix parameter type in some functions and update comments
guapisolo Oct 15, 2025
1074d7a
Update train_infer_is.py
zhaochenyang20 Oct 15, 2025
003c002
Merge pull request #1 from guapisolo/tis
zhaochenyang20 Oct 15, 2025
99224fa
remove sh
zhaochenyang20 Oct 15, 2025
8a8c44c
logging a whole sequence
zhaochenyang20 Oct 9, 2025
4326b28
rebase with main
zhaochenyang20 Oct 15, 2025
52a401e
Update run-qwen3-30B-A3B.sh
zhaochenyang20 Oct 15, 2025
ac4e63a
create test scripts
zhaochenyang20 Oct 15, 2025
4df0724
revert change in qwen3 30B sh
zhaochenyang20 Oct 15, 2025
7b01369
remove two tests sh
zhaochenyang20 Oct 15, 2025
4455479
add kl metrics
guapisolo Oct 15, 2025
3319f38
fix comment
guapisolo Oct 15, 2025
ce541be
Merge branch 'main' into importance_sampling
zhaochenyang20 Oct 15, 2025
1d35f45
Merge pull request #3 from guapisolo/tis
zhaochenyang20 Oct 15, 2025
bf8df31
adding kl metrics
zhaochenyang20 Oct 15, 2025
e7a88e3
Merge branch 'main' into importance_sampling
zhaochenyang20 Oct 16, 2025
8249244
Merge branch 'main' into cytis
guapisolo Oct 17, 2025
fa4606d
revert changes to use_tis
guapisolo Oct 17, 2025
a66753f
move to examples and use yaml for custom args parsing
guapisolo Oct 17, 2025
5809cd1
Merge pull request #5 from guapisolo/tis
zhaochenyang20 Oct 17, 2025
fffeab9
fix small bug
guapisolo Oct 17, 2025
9e4bf7c
remove tis file
guapisolo Oct 17, 2025
eb7711c
give vanilla tis func
guapisolo Oct 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/en/get_started/quick_start.md
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ GRPO_ARGS=(

- `--advantage-estimator`: In addition to [GRPO](https://arxiv.org/abs/2402.03300), slime also supports several other training algorithms, such as [GSPO](https://arxiv.org/abs/2507.18071), [Reinforce++](https://arxiv.org/abs/2501.03262) and [Reinforce++ Baseline](https://arxiv.org/abs/2501.03262), and [PPO](https://arxiv.org/abs/1707.06347).
- `--calculate-per-token-loss`: By default, slime calculates the loss on a per-sample basis, i.e., `mean(sum(sample_i) / len(sample_i))`. To calculate the loss on a per-token basis, i.e., `sum(sum(sample_i)) / sum(len(sample_i))`, you can enable this flag.
- `--use-tis`: Enable this setting to use TIS (Truncated Importance Sampling), which is introduced by this [blog](https://fengyao.notion.site/off-policy-rl).
- `--use-train-infer-is`: Enable this setting to use TIS (Truncated Importance Sampling), which is introduced by this [blog](https://fengyao.notion.site/off-policy-rl).

### OPTIMIZER_ARGS: Optimizer Parameters

Expand Down
2 changes: 1 addition & 1 deletion docs/en/get_started/usage.md
Comment thread
zhaochenyang20 marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ Additionally, we provide a `metadata_key`, which defaults to `"metadata"`. When
- `reinforce_plus_plus` and `reinforce_plus_plus_baseline` ([https://arxiv.org/abs/2501.03262](https://arxiv.org/abs/2501.03262))
- `ppo` ([https://arxiv.org/abs/1707.06347](https://arxiv.org/abs/1707.06347))
- `--calculate-per-token-loss`: By default, Slime calculates loss on a per-sample basis, i.e., `mean(sum(sample_i) / len(sample_i))`. Enable this flag to calculate loss on a per-token basis, i.e., `sum(sum(sample_i)) / sum(len(sample_i))`.
- `--use-tis`: Enable this setting to use TIS (Truncated Importance Sampling) (https://fengyao.notion.site/off-policy-rl).
- `--use-train-infer-is`: Enable this setting to use TIS (Truncated Importance Sampling) (https://fengyao.notion.site/off-policy-rl).

## Custom Rollout Function

Expand Down
2 changes: 1 addition & 1 deletion docs/zh/get_started/quick_start.md
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ GRPO_ARGS=(

- `--advantage-estimator`: 除去 [GRPO](https://arxiv.org/abs/2402.03300),slime 还支持丰富的其他训练算法,例如 [GSPO](https://arxiv.org/abs/2507.18071)、[Reinforce++](https://arxiv.org/abs/2501.03262) 与 [Reinforce++ Baseline](https://arxiv.org/abs/2501.03262)、以及 [PPO](https://arxiv.org/abs/1707.06347);
- `--calculate-per-token-loss`:slime 中默认的方案是 per sample loss,即 `mean(sum(sample_i) / len(sample_i))`,如果需要计算 per token loss,即 `sum(sum(sample_i)) / sum(len(sample_i))`,可以开启 `--calculate-per-token-loss`;
- `--use-tis`:如果需要开启 TIS (Truncated Importance Sampling),可以开启这一设置。TIS 由此[博客](https://fengyao.notion.site/off-policy-rl)介绍。
- `--use-train-infer-is`:如果需要开启 TIS (Truncated Importance Sampling),可以开启这一设置。TIS 由此[博客](https://fengyao.notion.site/off-policy-rl)介绍。

### OPTIMIZER_ARGS: 优化器参数

Expand Down
2 changes: 1 addition & 1 deletion docs/zh/get_started/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ sglang 的加载非常简单,只需要:
- `reinforce_plus_plus` 与 `reinforce_plus_plus_baseline`(https://arxiv.org/abs/2501.03262);
- `ppo`(https://arxiv.org/abs/1707.06347)。
- `--calculate-per-token-loss`:slime 中默认的方案是 per sample loss,即 `mean(sum(sample_i) / len(sample_i))`,如果需要计算 per token loss,即 `sum(sum(sample_i)) / sum(len(sample_i))`,可以开启 `--calculate-per-token-loss`;
- `--use-tis`:如果需要开启 tis(https://fengyao.notion.site/off-policy-rl),可以开启这一设置。
- `--use-train-infer-is`:如果需要开启 tis(https://fengyao.notion.site/off-policy-rl),可以开启这一设置。

## 自定义 rollout 函数

Expand Down
2 changes: 1 addition & 1 deletion examples/fully_async/run-qwen3-4b-fully_async.sh
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ GRPO_ARGS=(
--eps-clip 0.2
--eps-clip-high 0.28

--use-tis
--use-train-infer-is
)

OPTIMIZER_ARGS=(
Expand Down
2 changes: 1 addition & 1 deletion scripts/run-glm4.5-355B-A32B.sh
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ GRPO_ARGS=(
--eps-clip 1e-4
--eps-clip-high 2e-4

--use-tis
--use-train-infer-is
)

OPTIMIZER_ARGS=(
Expand Down
166 changes: 166 additions & 0 deletions scripts/run-qwen3-30B-A3B-test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
#!/bin/bash

# for rerun the task
pkill -9 sglang
sleep 3
ray stop --force
pkill -9 ray
pkill -9 python
sleep 3
pkill -9 ray
pkill -9 python
pkill -9 redis

set -ex

# will prevent ray from buffering stdout/stderr
export PYTHONBUFFERED=16

NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l)
if [ "$NVLINK_COUNT" -gt 0 ]; then
HAS_NVLINK=1
else
HAS_NVLINK=0
fi
echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"

SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
source "${SCRIPT_DIR}/models/qwen3-30B-A3B.sh"

CKPT_ARGS=(
--hf-checkpoint /root/Qwen3-30B-A3B
#--hf-checkpoint /root/Qwen3-30B-A3B-FP8
--ref-load /root/Qwen3-30B-A3B_torch_dist
# --load /root/Qwen3-30B-A3B_slime/
# --save /root/Qwen3-30B-A3B_slime/
# --save-interval 20
)

ROLLOUT_ARGS=(
--prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl
--input-key prompt
--label-key label
--apply-chat-template
--rollout-shuffle
--rm-type deepscaler
--num-rollout 3000
--rollout-batch-size 8
--n-samples-per-prompt 4
--rollout-max-response-len 8192
--rollout-temperature 0.8

--global-batch-size 32
--balance-data
)

EVAL_ARGS=(
# --eval-interval 20
--eval-prompt-data aime /root/aime-2024/aime-2024.jsonl
--n-samples-per-eval-prompt 16
--eval-max-response-len 16384
--eval-top-p 0.7
)

PERF_ARGS=(
--tensor-model-parallel-size 2
--sequence-parallel
--pipeline-model-parallel-size 1
--context-parallel-size 2
--expert-model-parallel-size 4
--expert-tensor-parallel-size 1

--recompute-granularity full
--recompute-method uniform
--recompute-num-layers 1

# --micro-batch-size 1
--use-dynamic-batch-size
--max-tokens-per-gpu 20480
)

GRPO_ARGS=(
--advantage-estimator grpo
--use-kl-loss
--kl-loss-coef 0.00
--kl-loss-type low_var_kl
--entropy-coef 0.00
--eps-clip 0.2
--eps-clip-high 0.28
)

IS_ARGS=(
--use-train-infer-is
--train-infer-is-level geometric
--train-infer-is-mode mask
--train-infer-is-lower-bound 0.5
--train-infer-is-upper-bound 2.0
--train-infer-is-veto-threshold 1e-3
)

OPTIMIZER_ARGS=(
--optimizer adam
--lr 1e-6
--lr-decay-style constant
--weight-decay 0.1
--adam-beta1 0.9
--adam-beta2 0.98

--optimizer-cpu-offload
--overlap-cpu-optimizer-d2h-h2d
--use-precision-aware-optimizer
)

WANDB_ARGS=(
--use-wandb
--wandb-project slime-dev
--wandb-group qwen3-30B-A3B-TIS
--wandb-run-id qwen3-30B-A3B-TIS-sequence
--wandb-key ${WANDB_KEY}
)

SGLANG_ARGS=(
--rollout-num-gpus-per-engine 4
--sglang-mem-fraction-static 0.7
--sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256)
)

MISC_ARGS=(
# default dropout in megatron is 0.1
--attention-dropout 0.0
--hidden-dropout 0.0
# should be good for model performance
--accumulate-allreduce-grads-in-fp32
--attention-softmax-in-fp32
# need to comment this when using model with MLA
--attention-backend flash
)

# launch the master node of ray in container
export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 4 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265

# Build the runtime environment JSON with proper variable substitution
RUNTIME_ENV_JSON="{
\"env_vars\": {
\"PYTHONPATH\": \"/root/Megatron-LM/\",
\"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
\"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\"
}
}"

ray job submit --address="http://127.0.0.1:8265" \
--runtime-env-json="${RUNTIME_ENV_JSON}" \
-- python3 train.py \
--actor-num-nodes 1 \
--actor-num-gpus-per-node 4 \
--colocate \
${MODEL_ARGS[@]} \
${CKPT_ARGS[@]} \
${ROLLOUT_ARGS[@]} \
${OPTIMIZER_ARGS[@]} \
${GRPO_ARGS[@]} \
${WANDB_ARGS[@]} \
${PERF_ARGS[@]} \
${EVAL_ARGS[@]} \
${SGLANG_ARGS[@]} \
${MISC_ARGS[@]}
6 changes: 4 additions & 2 deletions slime/backends/fsdp_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, self.args.eps_clip, self.args.eps_clip_high)

# Apply TIS before sample mean calculation
if self.args.use_tis:
if self.args.use_train_infer_is:
# Initialize TIS variables
tis = None
tis_clipfrac = None
Expand All @@ -420,7 +420,9 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
tis = torch.exp(old_log_probs - rollout_log_probs)
ois = (-ppo_kl).exp()
tis_clip = torch.clamp(
tis, min=getattr(self.args, "tis_clip_low", 0.1), max=getattr(self.args, "tis_clip", 2.0)
tis,
min=getattr(self.args, "train_infer_is_lower_bound", 0.1),
max=getattr(self.args, "train_infer_is_upper_bound", 2.0),
Comment thread
zhaochenyang20 marked this conversation as resolved.
Outdated
)
tis_clipfrac = tis_clip != tis

Expand Down
30 changes: 18 additions & 12 deletions slime/backends/megatron_utils/loss.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from argparse import Namespace
from collections.abc import Callable, Iterator
from typing import Optional, Union
from typing import Union

import torch
from megatron.core import mpu
Expand All @@ -16,6 +16,7 @@
get_reinforce_plus_plus_baseline_advantages,
get_reinforce_plus_plus_returns,
)
from slime.utils.train_infer_is import compute_train_infer_is_weights_with_cp
from slime.utils.types import RolloutBatch

from .cp_utils import all_gather_with_cp, get_logits_and_tokens_offset_with_cp, get_sum_of_sample_mean
Expand Down Expand Up @@ -419,17 +420,20 @@ def policy_loss_function(
pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, args.eps_clip, args.eps_clip_high)

# Apply TIS off-policy correction using importance sampling if enabled
if args.use_tis:
if args.use_train_infer_is:
assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS"
rollout_log_probs = torch.cat(batch["rollout_log_probs"], dim=0)
old_log_probs = torch.cat(batch["log_probs"], dim=0)

tis = torch.exp(old_log_probs - rollout_log_probs)
ois = (-ppo_kl).exp()
tis_clip = torch.clamp(tis, min=args.tis_clip_low, max=args.tis_clip)
tis_clipfrac = (tis_clip != tis).float()
is_weights, is_metrics = compute_train_infer_is_weights_with_cp(
args=args,
train_log_probs=batch["log_probs"],
rollout_log_probs=batch["rollout_log_probs"],
loss_masks=batch["loss_masks"],
total_lengths=total_lengths,
response_lengths=response_lengths,
)

pg_loss = pg_loss * tis_clip
ois = (-ppo_kl).exp()
pg_loss = pg_loss * is_weights

pg_loss = sum_of_sample_mean(pg_loss)
pg_clipfrac = sum_of_sample_mean(pg_clipfrac)
Expand Down Expand Up @@ -469,10 +473,12 @@ def policy_loss_function(
if args.use_kl_loss:
reported_loss["kl_loss"] = kl_loss.clone().detach()

if args.use_tis:
reported_loss["tis"] = sum_of_sample_mean(tis).clone().detach()
if args.use_train_infer_is:
# Backward compatible basic logs
reported_loss["ois"] = sum_of_sample_mean(ois).clone().detach()
reported_loss["tis_clipfrac"] = sum_of_sample_mean(tis_clipfrac).clone().detach()
for metric_key, metric_value in is_metrics.items():
key_name = f"train_infer_{metric_key}"
reported_loss[key_name] = sum_of_sample_mean(metric_value)

return loss, reported_loss

Expand Down
Loading
Loading