2 changes: 1 addition & 1 deletion .github/workflows/pr-test.yml
@@ -48,7 +48,7 @@ jobs:
strategy:
fail-fast: false
matrix:
info: [{"num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"num_gpus": 2, "test_file": "test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}]
info: [{"num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"num_gpus": 2, "test_file": "test_qwen3_4B_fsdp_true_on_policy.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}, {"num_gpus": 8, "test_file": "test_qwen3_vl_4B_fsdp.py"}]
defaults:
run:
working-directory: ${{ github.workspace }}
1 change: 1 addition & 0 deletions .github/workflows/pr-test.yml.j2
@@ -8,6 +8,7 @@
{'test_file': 'test_moonlight_16B_A3B.py', 'num_gpus': 8},
{'test_file': 'test_qwen3_4B_fsdp_true_on_policy.py', 'num_gpus': 2},
{'test_file': 'test_mimo_7B_mtp_only_grad.py', 'num_gpus': 8},
+ {'test_file': 'test_qwen3_vl_4B_fsdp.py', 'num_gpus': 8},
],
},
'e2e-test-long': {
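The same matrix entry lands in both files: `pr-test.yml` appears to be generated from the `pr-test.yml.j2` Jinja template, so the rendered workflow and its template must be updated together to stay in sync.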
2 changes: 1 addition & 1 deletion docs/en/get_started/customization.md
@@ -270,7 +270,7 @@ dict: {
"rollout_log_probs": list, # Log probs (for off-policy correction)
"rollout_routed_experts": list, # Routed experts (for MoE)
"metadata": list, # Train metadata
"multimodal_inputs": list, # Multimodal inputs (for VLM)
"multimodal_train_inputs": list, # Multimodal tensors (for VLM)
"teacher_log_probs": list, # Teacher log probs (for distillation)
}
```
2 changes: 1 addition & 1 deletion docs/zh/get_started/customization.md
@@ -270,7 +270,7 @@ dict: {
"rollout_log_probs": list, # log 概率(用于离策略校正)
"rollout_routed_experts": list, # 路由专家(用于 MoE)
"metadata": list, # 训练元数据
"multimodal_inputs": list, # 多模态输入(用于 VLM)
"multimodal_train_inputs": list, # 多模态张量(用于 VLM)
"teacher_log_probs": list, # 教师 log 概率(用于蒸馏)
}
```
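Both docs now describe `multimodal_train_inputs` as ready-to-train tensors rather than raw processor inputs. As a rough illustration, assuming a Qwen-VL-style Hugging Face processor (the key names and shapes below are illustrative, not taken from this diff), one entry might look like:

```python
import torch

# Hypothetical multimodal_train_inputs entry for a single sample.
# Keys and shapes assume a Qwen-VL-style processor; they vary by model.
sample_entry = {
    "pixel_values": torch.randn(1024, 1176),        # flattened image patches
    "image_grid_thw": torch.tensor([[1, 32, 32]]),  # one (t, h, w) grid per image
}
# Text-side keys ("input_ids", "attention_mask") are filtered out upstream,
# so only per-image tensors travel with each sample into training.
```

The train-data dict then holds one such entry (or `None`) per sample.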
54 changes: 27 additions & 27 deletions examples/geo3k_vlm/run_geo3k_vlm.sh
@@ -80,8 +80,6 @@ fi
# Common args
CKPT_ARGS=(
--hf-checkpoint /root/models/${MODEL_NAME}
- # vl model has rotary base 5000000
- --rotary-base 5000000
)

ROLLOUT_ARGS=(
@@ -154,41 +152,43 @@ MISC_ARGS=(
# Backend-specific args
if [ "$TRAIN_BACKEND" = "fsdp" ]; then
BACKEND_ARGS=(
- --train-backend fsdp
- --gradient-checkpointing
- --sglang-attention-backend fa3
- --attn-implementation flash_attention_3
- --update-weight-buffer-size 536870912
+ --train-backend fsdp
+ --gradient-checkpointing
+ --sglang-attention-backend fa3
+ --attn-implementation flash_attention_3
+ --update-weight-buffer-size 536870912
)
MODEL_ARGS=()
else
# megatron backend (default)
BACKEND_ARGS=(
- --train-backend megatron
- --load /root/models/${MODEL_NAME}
- --tensor-model-parallel-size 4
- --sequence-parallel
- --pipeline-model-parallel-size 1
- --context-parallel-size 1
- --expert-model-parallel-size 1
- --expert-tensor-parallel-size 1
- --recompute-granularity full
- --recompute-method uniform
- --recompute-num-layers 1
- --use-dynamic-batch-size
- --max-tokens-per-gpu 4096
- --attention-dropout 0.0
- --hidden-dropout 0.0
- --accumulate-allreduce-grads-in-fp32
- --attention-softmax-in-fp32
- --attention-backend flash
- --megatron-to-hf-mode bridge
+ --train-backend megatron
+ --load /root/models/${MODEL_NAME}
+ --tensor-model-parallel-size 4
+ --sequence-parallel
+ --pipeline-model-parallel-size 1
+ --context-parallel-size 1
+ --expert-model-parallel-size 1
+ --expert-tensor-parallel-size 1
+ --recompute-granularity full
+ --recompute-method uniform
+ --recompute-num-layers 1
+ --use-dynamic-batch-size
+ --max-tokens-per-gpu 4096
+ --attention-dropout 0.0
+ --hidden-dropout 0.0
+ --accumulate-allreduce-grads-in-fp32
+ --attention-softmax-in-fp32
+ --attention-backend flash
+ --megatron-to-hf-mode bridge
)

# get MODEL_ARGS from scripts/models for megatron backend
SLIME_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/../.." &>/dev/null && pwd)"
MODEL_ARGS_FILE=$(echo "$MODEL_NAME" | sed 's/-Instruct//g; s/-Thinking//g; s/Qwen3-VL-/qwen3-/g; s/-2B/-1.7B/g')
source "${SLIME_DIR}/scripts/models/${MODEL_ARGS_FILE}.sh"
# VL models require rotary-base 5000000
MODEL_ARGS_ROTARY_BASE=5000000 source "${SLIME_DIR}/scripts/models/${MODEL_ARGS_FILE}.sh"

fi

# Start Ray if not using external Ray
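The effect of this refactor: instead of hardcoding `--rotary-base 5000000` in CKPT_ARGS, the example now passes the VL override through the `MODEL_ARGS_ROTARY_BASE` environment variable when sourcing the shared model scripts, which consume it via `${MODEL_ARGS_ROTARY_BASE:-1000000}` (see the `scripts/models/*.sh` diffs below); text-only runs keep the 1000000 default. The paired removed/added BACKEND_ARGS lines above are an indentation-only change; their contents are identical.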
4 changes: 2 additions & 2 deletions examples/geo3k_vlm/run_geo3k_vlm_sft.sh
@@ -72,7 +72,6 @@ fi
CKPT_ARGS=(
--hf-checkpoint /root/models/${MODEL_NAME}
--load /root/models/${MODEL_NAME}
- --rotary-base 5000000
)

SFT_ARGS=(
@@ -152,7 +151,8 @@ else
# get MODEL_ARGS from scripts/models for megatron backend
SLIME_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/../.." &>/dev/null && pwd)"
MODEL_ARGS_FILE=$(echo "$MODEL_NAME" | sed 's/-Instruct//g; s/-Thinking//g; s/Qwen3-VL-/qwen3-/g; s/-2B/-1.7B/g')
source "${SLIME_DIR}/scripts/models/${MODEL_ARGS_FILE}.sh"
# VL models require rotary-base 5000000
MODEL_ARGS_ROTARY_BASE=5000000 source "${SLIME_DIR}/scripts/models/${MODEL_ARGS_FILE}.sh"
fi

# Start Ray if not using external Ray
2 changes: 1 addition & 1 deletion scripts/models/qwen3-1.7B.sh
@@ -10,7 +10,7 @@ MODEL_ARGS=(
--disable-bias-linear
--normalization "RMSNorm"
--norm-epsilon 1e-6
- --rotary-base 1000000
+ --rotary-base "${MODEL_ARGS_ROTARY_BASE:-1000000}"
--vocab-size 151936
--kv-channels 128
--qk-layernorm
2 changes: 1 addition & 1 deletion scripts/models/qwen3-235B-A22B.sh
@@ -32,7 +32,7 @@ MODEL_ARGS=(
--untie-embeddings-and-output-weights
--vocab-size 151936

- --rotary-base 1000000
+ --rotary-base "${MODEL_ARGS_ROTARY_BASE:-1000000}"

# moe
--moe-ffn-hidden-size 1536
2 changes: 1 addition & 1 deletion scripts/models/qwen3-30B-A3B.sh
@@ -32,7 +32,7 @@ MODEL_ARGS=(
--untie-embeddings-and-output-weights
--vocab-size 151936

- --rotary-base 1000000
+ --rotary-base "${MODEL_ARGS_ROTARY_BASE:-1000000}"

# moe
--moe-ffn-hidden-size 768
2 changes: 1 addition & 1 deletion scripts/models/qwen3-8B.sh
@@ -10,7 +10,7 @@ MODEL_ARGS=(
--disable-bias-linear
--normalization "RMSNorm"
--norm-epsilon 1e-6
- --rotary-base 1000000
+ --rotary-base "${MODEL_ARGS_ROTARY_BASE:-1000000}"
--vocab-size 151936
--kv-channels 128
--qk-layernorm
10 changes: 6 additions & 4 deletions slime/backends/fsdp_utils/actor.py
@@ -463,8 +463,10 @@ def _packed_data(
rollout_log_probs=(
rollout_data["rollout_log_probs"][start:end] if "rollout_log_probs" in rollout_data else None
),
- multimodal_inputs=(
- rollout_data["multimodal_inputs"][start:end] if "multimodal_inputs" in rollout_data else None
+ multimodal_train_inputs=(
+ rollout_data["multimodal_train_inputs"][start:end]
+ if "multimodal_train_inputs" in rollout_data
+ else None
),
num_packs=mbs_size,
)
@@ -890,8 +892,8 @@ def _get_model_inputs_args(self, packed_sequence: dict) -> dict:
"position_ids": position_ids,
"attention_mask": None,
}
- if packed_sequence.get("multimodal_inputs"):
- model_args.update(packed_sequence["multimodal_inputs"])
+ if packed_sequence.get("multimodal_train_inputs"):
+ model_args.update(packed_sequence["multimodal_train_inputs"])
return model_args


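Taken together, the two hunks keep the FSDP path simple: packed multimodal tensors ride into the model forward as extra keyword arguments. A minimal sketch of that call pattern (the `model` call is a stand-in, not the actual actor code):

```python
def build_model_args(packed_sequence: dict) -> dict:
    # Base language-model inputs, as in _get_model_inputs_args.
    model_args = {
        "input_ids": packed_sequence["tokens"],
        "position_ids": packed_sequence["position_ids"],
        "attention_mask": None,
    }
    # VLM extras (e.g. pixel_values) are merged in as additional kwargs.
    if packed_sequence.get("multimodal_train_inputs"):
        model_args.update(packed_sequence["multimodal_train_inputs"])
    return model_args

# outputs = model(**build_model_args(packed_sequence))
```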
16 changes: 8 additions & 8 deletions slime/backends/fsdp_utils/data_packing.py
@@ -17,7 +17,7 @@ def pack_sequences(
advantages: list[float],
returns: list[float],
rollout_log_probs: list[list[float]] | None = None,
- multimodal_inputs: list[dict] | None = None,
+ multimodal_train_inputs: list[dict] | None = None,
max_tokens_per_gpu: int | None = None,
num_packs: int | None = None,
) -> list[dict]:
@@ -33,7 +33,7 @@ def pack_sequences(
advantages: List of advantages per sequence
returns: List of returns per sequence
rollout_log_probs: List of rollout log probabilities per sequence
- multimodal_inputs: List of dict of multimodal tokens per sequence
+ multimodal_train_inputs: List of dict of multimodal tensors for training per sequence
max_tokens_per_gpu: Maximum tokens per GPU pack
num_packs: Explicit number of packs to create

@@ -100,19 +100,19 @@ def pack_sequences(
),
}

- # Collect and add multimodal inputs for this partition
- if multimodal_inputs:
+ # Collect and add multimodal training tensors for this partition
+ if multimodal_train_inputs:
multimodal_data = {} # key -> concatenated tensor
multimodal_num_items = {} # key -> list of item counts per sequence
for i in indices:
- for key, mm_tensor in multimodal_inputs[i].items():
+ for key, mm_tensor in multimodal_train_inputs[i].items():
if key not in multimodal_data:
multimodal_data[key] = mm_tensor
multimodal_num_items[key] = [mm_tensor.size(0)]
else:
multimodal_data[key] = torch.cat([multimodal_data[key], mm_tensor], dim=0)
multimodal_num_items[key].append(mm_tensor.size(0))
packed_batch["multimodal_inputs"] = multimodal_data
packed_batch["multimodal_train_inputs"] = multimodal_data
packed_batch["multimodal_num_items"] = multimodal_num_items

result.append(packed_batch)
@@ -157,8 +157,8 @@ def unpack_sequences(packed_batch: dict) -> list[dict]:
# Skip multimodal_num_items - it's metadata
if key == "multimodal_num_items":
continue
- # Handle multimodal_inputs dict: split each tensor using multimodal_num_items
- elif key == "multimodal_inputs" and isinstance(value, dict):
+ # Handle multimodal_train_inputs dict: split each tensor using multimodal_num_items
+ elif key == "multimodal_train_inputs" and isinstance(value, dict):
instance[key] = {}
for mm_key, mm_tensor in value.items():
if mm_key in multimodal_num_items:
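The renaming leaves the packing scheme itself unchanged: per-sequence tensors are concatenated along dim 0 with per-sequence item counts recorded in `multimodal_num_items`, and `torch.split` recovers them on unpacking. A self-contained sketch of that round trip (tensor names and shapes are illustrative):

```python
import torch

# Two sequences, each carrying a dict of multimodal training tensors.
per_sequence = [
    {"pixel_values": torch.randn(4, 8)},  # sequence 0: 4 items
    {"pixel_values": torch.randn(6, 8)},  # sequence 1: 6 items
]

# Pack: concatenate along dim 0 and remember how many items each sequence had.
packed = torch.cat([d["pixel_values"] for d in per_sequence], dim=0)  # shape (10, 8)
num_items = [d["pixel_values"].size(0) for d in per_sequence]         # [4, 6]

# Unpack: torch.split consumes the recorded counts to restore per-sequence tensors.
chunks = torch.split(packed, num_items, dim=0)
assert [c.size(0) for c in chunks] == num_items
```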
8 changes: 4 additions & 4 deletions slime/backends/megatron_utils/actor.py
@@ -193,15 +193,15 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch:
rollout_data["loss_masks"] = [
torch.tensor(t, dtype=torch.int, device=torch.cuda.current_device()) for t in rollout_data["loss_masks"]
]
if "multimodal_inputs" in rollout_data:
# Move multimodal inputs to GPU in advance
rollout_data["multimodal_inputs"] = [
if "multimodal_train_inputs" in rollout_data:
# Move multimodal training tensors to GPU in advance
rollout_data["multimodal_train_inputs"] = [
(
{key: tensor.to(device=torch.cuda.current_device()) for key, tensor in mm_dict.items()}
if mm_dict is not None
else None
)
- for mm_dict in rollout_data["multimodal_inputs"]
+ for mm_dict in rollout_data["multimodal_train_inputs"]
]
if "rollout_log_probs" in rollout_data:
rollout_data["rollout_log_probs"] = [
12 changes: 6 additions & 6 deletions slime/backends/megatron_utils/data.py
@@ -107,12 +107,12 @@ def get_batch(
assert loss_masks.shape == tokens.shape, f"loss_masks.shape: {loss_masks.shape}, tokens.shape: {tokens.shape}"
batch["full_loss_masks"] = loss_masks

- # Process multimodal inputs if present
- multimodal_inputs = batch.get("multimodal_inputs", None)
- if multimodal_inputs is not None:
+ # Process multimodal training tensors if present
+ multimodal_train_inputs = batch.get("multimodal_train_inputs", None)
+ if multimodal_train_inputs is not None:
multimodal_data = {} # key -> concatenated tensor
multimodal_num_items = {} # key -> list of item counts per sequence
- for mm_input_dict in multimodal_inputs:
+ for mm_input_dict in multimodal_train_inputs:
if mm_input_dict is not None:
for key, mm_tensor in mm_input_dict.items():
if key not in multimodal_data:
@@ -121,7 +121,7 @@
else:
multimodal_data[key] = torch.cat([multimodal_data[key], mm_tensor], dim=0)
multimodal_num_items[key].append(mm_tensor.size(0))
batch["multimodal_inputs"] = multimodal_data
batch["multimodal_train_inputs"] = multimodal_data
batch["multimodal_num_items"] = multimodal_num_items

return batch
@@ -349,7 +349,7 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc
for key, val in rollout_data.items():
if key in [
"tokens",
"multimodal_inputs",
"multimodal_train_inputs",
"loss_masks",
"sample_indices",
"rollout_routed_experts",
8 changes: 4 additions & 4 deletions slime/backends/megatron_utils/model.py
@@ -207,7 +207,7 @@ def forward_step(
[
"tokens",
"loss_masks",
"multimodal_inputs",
"multimodal_train_inputs",
"total_lengths",
"response_lengths",
],
@@ -225,7 +225,7 @@
labels=None,
packed_seq_params=packed_seq_params,
loss_mask=batch["full_loss_masks"],
**(batch["multimodal_inputs"] if batch["multimodal_inputs"] is not None else {}),
**(batch["multimodal_train_inputs"] if batch["multimodal_train_inputs"] is not None else {}),
)

return output_tensor, partial(
@@ -354,7 +354,7 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p
data_iterator,
[
"tokens",
"multimodal_inputs",
"multimodal_train_inputs",
"packed_seq_params",
"total_lengths",
"response_lengths",
@@ -392,7 +392,7 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p
packed_seq_params=batch["packed_seq_params"],
loss_mask=batch["full_loss_masks"],
mtp_kwargs={"mtp_labels": batch["tokens"]} if args.enable_mtp_training else {},
**(batch["multimodal_inputs"] if batch["multimodal_inputs"] is not None else {}),
**(batch["multimodal_train_inputs"] if batch["multimodal_train_inputs"] is not None else {}),
)

if os.environ.get("ENABLE_ROUTING_REPLAY", "0") == "1":
18 changes: 3 additions & 15 deletions slime/ray/rollout.py
@@ -20,7 +20,6 @@
from slime.utils.metric_checker import MetricChecker
from slime.utils.metric_utils import compute_pass_rate, compute_rollout_step, compute_statistics, dict_add_prefix
from slime.utils.misc import load_function
from slime.utils.processing_utils import load_processor
from slime.utils.ray_utils import Box
from slime.utils.seqlen_balancing import get_seqlen_balanced_partitions
from slime.utils.tracking_utils import init_tracking
@@ -78,7 +77,6 @@ def __init__(self, args, pg):
self._metric_checker = MetricChecker.maybe_create(args)
if self.args.use_fault_tolerance:
self._health_monitor = RolloutHealthMonitor(self, args)
- self.processor = None

def dispose(self):
if self._metric_checker is not None:
@@ -276,18 +274,8 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl
if samples[0].train_metadata is not None:
train_data["metadata"] = [sample.train_metadata for sample in samples]

- if samples[0].multimodal_inputs is not None:
- if self.processor is None:
- self.processor = load_processor(self.args.hf_checkpoint, trust_remote_code=True)
- train_data["multimodal_inputs"] = []
- for sample in samples:
- # Get input IDs with full prompt (text + multimodal)
- processor_output = self.processor(text=sample.prompt, **sample.multimodal_inputs)
-
- # Extract multimodal tokens (exclude text-related tokens)
- train_data["multimodal_inputs"].append(
- {k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"]}
- )
+ if samples[0].multimodal_train_inputs is not None:
+ train_data["multimodal_train_inputs"] = [sample.multimodal_train_inputs for sample in samples]

if "teacher_log_probs" in samples[0].__dict__:
train_data["teacher_log_probs"] = [sample.teacher_log_probs for sample in samples]
@@ -320,7 +308,7 @@ def _split_train_data_by_dp(self, data, dp_size):
rollout_data["partition"] = partition
for key in [
"tokens",
"multimodal_inputs",
"multimodal_train_inputs",
"response_lengths",
"rewards",
"truncated",
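Net effect: the rollout manager no longer loads its own Hugging Face processor. The processor now runs once per sample at generation time (see `sglang_rollout.py` below), and `_convert_samples_to_train_data` simply forwards the precomputed `sample.multimodal_train_inputs`, instead of re-running the processor over every prompt during train-data conversion.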
3 changes: 3 additions & 0 deletions slime/rollout/sglang_rollout.py
@@ -98,6 +98,9 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A
if state.processor:
processor_output = state.processor(text=sample.prompt, **sample.multimodal_inputs)
prompt_ids = processor_output["input_ids"][0]
+ sample.multimodal_train_inputs = {
+ k: v for k, v in processor_output.items() if k not in ["input_ids", "attention_mask"]
+ } or None
else:
prompt_ids = state.tokenizer.encode(sample.prompt, add_special_tokens=False)
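Note the trailing `or None`: if the processor returns only text keys, the filtered dict is empty and `multimodal_train_inputs` stays `None`, matching the `is not None` checks upstream.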

Expand Down