Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
a478470
feat: dataset and rollout vlm single turn ready
nanjiangwill Nov 12, 2025
316cf44
misc: remove code
nanjiangwill Nov 12, 2025
9299459
add geo3k reward utils
coding-famer Nov 13, 2025
366f694
multimodal inputs
coding-famer Nov 14, 2025
86511b6
do apply_chat_template in inference side
coding-famer Nov 15, 2025
a31fb56
apply precommit
coding-famer Nov 15, 2025
16e83da
fix
coding-famer Nov 15, 2025
42271a3
add script
coding-famer Nov 15, 2025
8acccdd
Merge remote-tracking branch 'origin/geo3k_utils' into feat/vlm
nanjiangwill Nov 17, 2025
122ef7c
feat: vlm single turn all components ready, start testing
nanjiangwill Nov 17, 2025
030af3f
merge main
nanjiangwill Nov 17, 2025
906ae47
update
nanjiangwill Nov 17, 2025
52ce7fa
lint
nanjiangwill Nov 17, 2025
f5844a3
lint
nanjiangwill Nov 17, 2025
7b0876d
rename
nanjiangwill Nov 17, 2025
7f0f773
remove useless code
coding-famer Nov 18, 2025
a13b061
make math grader robust
coding-famer Nov 22, 2025
a0640a3
add script
nanjiangwill Nov 24, 2025
2ae60e8
fix name
nanjiangwill Nov 24, 2025
7cdd062
fix name
nanjiangwill Nov 24, 2025
5046add
merge main
nanjiangwill Nov 24, 2025
9d097ca
add
nanjiangwill Nov 24, 2025
de0b7d0
add
nanjiangwill Nov 24, 2025
6f562e8
use spawn for multiprocessing
coding-famer Nov 24, 2025
f3ed679
update script
nanjiangwill Nov 28, 2025
c5cdbb0
update requirements
nanjiangwill Nov 28, 2025
77b60f2
update requirements
nanjiangwill Nov 28, 2025
fdf9da1
update requirements
nanjiangwill Nov 28, 2025
16b83e1
update scripts
nanjiangwill Nov 29, 2025
2a324b7
update scripts
nanjiangwill Nov 29, 2025
5bf5228
Change eval-top-k value from 0.7 to 1
nanjiangwill Nov 29, 2025
ff2cc8a
fix: hack
nanjiangwill Dec 1, 2025
6634abe
fix: hack
nanjiangwill Dec 1, 2025
72eed6d
merge main
nanjiangwill Dec 2, 2025
b3c3192
update script
nanjiangwill Dec 2, 2025
9ccd5e2
update script
nanjiangwill Dec 2, 2025
2c5f45b
update name
nanjiangwill Dec 2, 2025
fd528c3
update script
nanjiangwill Dec 2, 2025
e065a84
update readme
nanjiangwill Dec 2, 2025
8b3ac97
update readme
nanjiangwill Dec 2, 2025
118de3a
update script
nanjiangwill Dec 2, 2025
c30f2eb
better naming
nanjiangwill Dec 2, 2025
dfc0990
better naming
nanjiangwill Dec 2, 2025
d9bd14d
better naming
nanjiangwill Dec 2, 2025
d6152b6
better import
nanjiangwill Dec 2, 2025
33536db
update vlm readme
jhinpan Dec 2, 2025
20cabd5
Remove geo3k reward model & tol and use default math RM
jhinpan Dec 3, 2025
0f10239
Resolve conflicts
jhinpan Dec 3, 2025
d042915
Merge branch 'main' into feat/vlm
jhinpan Dec 3, 2025
80cebdf
Add new exp figs
jhinpan Dec 3, 2025
0f9d7a0
pre-commit and tiny fix
jhinpan Dec 3, 2025
115ce76
remove unused script
jhinpan Dec 4, 2025
dac6703
update cleaner notes about numerical precision issue
jhinpan Dec 4, 2025
ef10af1
revert tol in math utils
jhinpan Dec 4, 2025
7965a26
resolve 1st comment
jhinpan Dec 4, 2025
d8f5320
solve 2nd comments
jhinpan Dec 4, 2025
9f07e85
merge two multimodal blocks
jhinpan Dec 4, 2025
2a4f58d
fix ci
jhinpan Dec 4, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions examples/geo3k_vlm/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# FSDP + VLM Single-Turn RL

Training VLMs with FSDP on a single-turn reasoning task using GRPO on the [GEO3K dataset](https://huggingface.co/datasets/hiyouga/geometry3k). We used the processed version available [here](https://huggingface.co/datasets/chenhegu/geo3k_imgurl).

<p align="center">
<img src="rewards.png" alt="Reward Plot" width="800">
</p>

## Reproduce

```bash
export WANDB_API_KEY=your_wandb_api_key

SLIME_SCRIPT_MODEL_NAME=Qwen3-VL-2B-Instruct SLIME_SCRIPT_NUM_GPUS=8 python examples/geo3k_vlm/run_geo3k_vlm.py 2>&1 | tee run_simple.log
```

## Notes

### Reward Model Configuration

We experimented with three reward model configurations:
1. A geo3k-specific RM with tolerance=0.05 (to handle rounding in ground truth labels)
2. A geo3k-specific RM with tolerance=0.0 (strict matching)
3. The default math RM

All three performed similarly, so we use the default math RM for simplicity.

### Numerical Precision with Non-Binary Rewards

Our initial geo3k-specific verifier produced "format scores" (**0 and 0.9**) instead of clean binary rewards. Under **fp32**, fractional values like 0.9 can't be exactly represented, so when all samples in a group have the same reward, `reward - mean` doesn't equal zero—creating spurious gradient signal.

We fixed this by switching to the default math RM with clean **binary 0/1 rewards**. If you encounter similar precision issues with non-binary rewards, you can change the reward tensor dtype from `torch.float` to `torch.float16` in `slime/ray/rollout.py` (`_post_process_rewards` method) to truncate precision artifacts.
Binary file added examples/geo3k_vlm/rewards.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
193 changes: 193 additions & 0 deletions examples/geo3k_vlm/run_geo3k_vlm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import json
import os
import subprocess

import slime.utils.misc as U

# Model to train; overridable via env. Must be one of the Qwen VL checkpoints below.
MODEL_NAME = os.environ.get("SLIME_SCRIPT_MODEL_NAME", "Qwen3-VL-2B-Instruct")
assert MODEL_NAME in {"Qwen2.5-VL-3B-Instruct", "Qwen3-VL-2B-Instruct", "Qwen3-VL-4B-Instruct", "Qwen3-VL-8B-Instruct"}

# Number of GPUs given to the actor (and to the local Ray head, if started here).
NUM_GPUS = int(os.environ.get("SLIME_SCRIPT_NUM_GPUS", "1"))
# Non-zero: assume a Ray cluster is already running and skip ray start/stop.
EXTERNAL_RAY = int(os.environ.get("SLIME_SCRIPT_EXTERNAL_RAY", "0"))
# Node IP used for `ray start --node-ip-address`.
MASTER_ADDR = os.environ.get("MASTER_ADDR", "127.0.0.1")


def detect_nvlink():
    """Return 1 if NVLink appears present on this machine, else 0.

    Runs ``nvidia-smi`` and counts occurrences of the literal string
    "NVLink" in its stdout. Any failure (binary missing, non-zero exit)
    is treated as "no NVLink" and reported.
    """
    try:
        proc = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=True)
    except Exception as e:
        print(f"Failed to detect NVLink: {e}")
        return 0
    nvlink_count = proc.stdout.count("NVLink")
    has_nvlink = int(nvlink_count > 0)
    print(f"HAS_NVLINK: {has_nvlink} (detected {nvlink_count} NVLink references)")
    return has_nvlink


def prepare():
    """Download the model checkpoint and the geo3k dataset under /root."""
    U.exec_command("mkdir -p /root/models /root/datasets")
    # Pull the HF checkpoint for the configured model.
    U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}")
    dataset_name = "chenhegu/geo3k_imgurl"
    # Local directory name is the repo name without the org prefix.
    repo_dir = dataset_name.split("/")[1]
    U.exec_command(f"hf download --repo-type dataset {dataset_name} --local-dir /root/datasets/{repo_dir}")


def execute():
    """Assemble the training CLI and submit the Ray training job.

    Builds the argument string for ``/root/slime/train.py`` from grouped
    sections (checkpoint, rollout, eval, GRPO, optimizer, sglang, FSDP,
    wandb, misc), kills stale processes, optionally starts a local Ray
    head node, and submits the job via ``ray job submit``.

    Fix: the original exported ``PYTHONBUFFERED=16`` — Python only honors
    ``PYTHONUNBUFFERED`` (any non-empty value), so the intended unbuffered
    stdout/stderr never took effect. Corrected in both shell commands.
    """
    # Detect NVLink for optimized NCCL settings
    has_nvlink = detect_nvlink()

    ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} "

    rollout_args = (
        "--prompt-data /root/datasets/geo3k_imgurl/train.parquet "
        "--input-key problem "
        "--label-key answer "
        '--multimodal-keys \'{"image": "images"}\' '
        "--apply-chat-template "
        "--rollout-shuffle "
        "--rm-type math "
        "--num-rollout 3000 "
        "--rollout-batch-size 64 "
        "--n-samples-per-prompt 8 "
        "--rollout-max-response-len 4096 "
        "--rollout-temperature 0.8 "
        "--global-batch-size 512 "
    )

    eval_args = (
        # "--eval-interval 20 "
        "--eval-prompt-data geo3k-test /root/datasets/geo3k_imgurl/test.parquet "
        "--n-samples-per-eval-prompt 1 "
        "--eval-max-response-len 4096 "
        "--eval-top-k 1 "
    )

    grpo_args = (
        "--advantage-estimator grpo "
        # "--use-kl-loss "
        "--kl-loss-coef 0.00 "
        "--kl-loss-type low_var_kl "
        "--kl-coef 0.00 "
        "--entropy-coef 0.00 "
        "--eps-clip 0.2 "
        "--eps-clip-high 0.28 "
    )

    optimizer_args = (
        "--optimizer adam "
        "--lr 1e-6 "
        "--lr-decay-style constant "
        "--weight-decay 0.1 "
        "--adam-beta1 0.9 "
        "--adam-beta2 0.98 "
    )

    sglang_args = (
        "--rollout-num-gpus-per-engine 1 "
        "--sglang-mem-fraction-static 0.6 "
        # CUDA-graph capture batch sizes: small powers of two, then every 8 up to 256.
        f"--sglang-cuda-graph-bs {' '.join(map(str, [1, 2, 4, 8] + list(range(16, 257, 8))))} "
    )

    fsdp_args = (
        # Set to true for FULL_STATE_DICT mode, false for SHARDED_STATE_DICT mode (default)
        # "--fsdp-full-params "  # Uncomment this line to enable full params mode
        # Set the bucket size for weight update
        "--update-weight-buffer-size 536870912 "  # 512MB
        "--train-backend fsdp "
        "--gradient-checkpointing "
        "--sglang-attention-backend fa3 "
        "--attn-implementation flash_attention_3 "
    )

    wandb_args = (
        "--use-wandb "
        "--wandb-project geo3k-vlm "
        "--wandb-group geo3k-vlm "
        # Expanded by the shell that runs the command, not by Python.
        "--wandb-key ${WANDB_API_KEY} "
        "--disable-wandb-random-suffix "
    )

    misc_args = "--actor-num-nodes 1 " f"--actor-num-gpus-per-node {NUM_GPUS} " "--colocate "

    # misc_args += (
    #     "--use-dynamic-batch-size "
    #     # TODO pick a good value
    #     "--max-tokens-per-gpu 2048 "
    # )

    # true_on_policy_args = (
    #     "--sglang-enable-deterministic-inference "
    #     "--sglang-rl-on-policy-target fsdp "
    #     "--deterministic-mode "
    #     "--true-on-policy-mode "
    # )
    # true_on_policy_envs = {
    #     # TODO note: "Ring" in original RL PR, "allreduce:tree" in SGLang
    #     # "NCCL_ALGO": "Ring",
    #     "NCCL_ALGO": "allreduce:tree",
    #     "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0",
    #     "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
    # }

    train_args = (
        f"{ckpt_args} "
        f"{rollout_args} "
        f"{optimizer_args} "
        f"{grpo_args} "
        f"{sglang_args} "
        f"{fsdp_args} "
        f"{eval_args} "
        f"{misc_args} "
        f"{wandb_args} "
        # f"{true_on_policy_args} "
    )

    # Kill existing processes (best-effort; trailing `true` keeps exit code 0).
    U.exec_command(
        "pkill -9 sglang; "
        "sleep 3; "
        f"{'' if EXTERNAL_RAY else 'ray stop --force; '}"
        f"{'' if EXTERNAL_RAY else 'pkill -9 ray; '}"
        "pkill -9 slime; "
        "sleep 3; "
        f"{'' if EXTERNAL_RAY else 'pkill -9 ray; '}"
        "pkill -9 slime; "
        "pkill -9 redis; "
        "true; "
    )

    if not EXTERNAL_RAY:
        # Start Ray
        U.exec_command(
            # NOTE: was `PYTHONBUFFERED=16`, which Python ignores.
            "export PYTHONUNBUFFERED=1 && "
            f"ray start --head --node-ip-address {MASTER_ADDR} --num-gpus {NUM_GPUS} "
            f"--disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265"
        )

    # Prepare runtime environment (env vars injected into every Ray worker).
    runtime_env_json = json.dumps(
        {
            "env_vars": {
                "CUDA_DEVICE_MAX_CONNECTIONS": "1",
                "NCCL_NVLS_ENABLE": str(has_nvlink),
                # **true_on_policy_envs,
                # "SGLANG_DUMPER_ENABLE": "0",
                # "SGLANG_TEMP_UTILS_ENABLE_DEBUG_PRINT": "0",
            }
        }
    )

    # Submit Ray job
    U.exec_command(
        # NOTE: was `PYTHONBUFFERED=16`, which Python ignores.
        "export no_proxy=127.0.0.1 && export PYTHONUNBUFFERED=1 && "
        f'ray job submit --address="http://127.0.0.1:8265" '
        f"--runtime-env-json='{runtime_env_json}' "
        f"-- python3 /root/slime/train.py "
        f"{train_args}"
    )


if __name__ == "__main__":
    # Download assets first, then launch the Ray training job.
    prepare()
    execute()
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ omegaconf
pillow
pylatexenc
pyyaml
qwen_vl_utils # for VLM
ray[default]
ring_flash_attn
sglang-router>=0.2.3
Expand Down
35 changes: 23 additions & 12 deletions slime/backends/fsdp_utils/actor.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import logging
import os
from argparse import Namespace
from itertools import accumulate


import ray
import torch
import torch.distributed as dist
import torch.nn.functional as F
from ring_flash_attn import substitute_hf_flash_attn, update_ring_flash_attn_params
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoTokenizer
from transformers import AutoConfig

from slime.ray.train_actor import TrainRayActor
from slime.utils import train_dump_utils, train_metric_utils
Expand All @@ -19,6 +19,7 @@
from slime.utils.memory_utils import clear_memory, print_memory
from slime.utils.metric_utils import compute_rollout_step
from slime.utils.ppo_utils import compute_approx_kl, compute_gspo_kl, compute_opsm_mask, compute_policy_loss
from slime.utils.processing_utils import load_processor, load_tokenizer
from slime.utils.ray_utils import Box
from slime.utils.timer import Timer, inverse_timer, timer
from slime.utils.tracking_utils import init_tracking
Expand Down Expand Up @@ -73,16 +74,15 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty
for i in range(dist.get_world_size()):
if i == dist.get_rank():
self.hf_config = AutoConfig.from_pretrained(self.args.hf_checkpoint, trust_remote_code=True)
self.tokenizer = AutoTokenizer.from_pretrained(self.args.hf_checkpoint, trust_remote_code=True)
self.tokenizer = load_tokenizer(self.args.hf_checkpoint, trust_remote_code=True)
if self.args.multimodal_keys:
self.processor = load_processor(self.args.hf_checkpoint, trust_remote_code=True)
dist.barrier(group=get_gloo_group())

if self.args.multimodal_keys:
self.vlm_processor = AutoProcessor.from_pretrained(self.args.hf_checkpoint, trust_remote_code=True)

init_context = self._get_init_weight_context_manager()

with init_context():
model = AutoModelForCausalLM.from_pretrained(
model = self.get_model_cls().from_pretrained(
self.args.hf_checkpoint,
trust_remote_code=True,
attn_implementation=self.args.attn_implementation,
Expand Down Expand Up @@ -142,6 +142,16 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty

return int(getattr(self.args, "start_rollout_id", 0))

def get_model_cls(self):
if self.args.multimodal_keys:
from transformers import AutoModelForVision2Seq

return AutoModelForVision2Seq
else:
from transformers import AutoModelForCausalLM

return AutoModelForCausalLM

def _enable_true_on_policy_optimizations(self, args):
if args.true_on_policy_mode:
from sglang.srt.batch_invariant_ops import enable_batch_invariant_mode
Expand Down Expand Up @@ -347,8 +357,6 @@ def _compute_log_prob(
tqdm(packed_batches, desc=f"{store_prefix}log_probs", disable=dist.get_rank() != 0)
):
model_args = self._get_model_inputs_args(batch)
if "pixel_values" in batch:
model_args["pixel_values"] = batch["pixel_values"]
logits = active_model(**model_args).logits.squeeze(0).float()
log_probs_result, entropy_result = get_logprob_and_entropy_with_cp(
logits=logits,
Expand Down Expand Up @@ -436,6 +444,9 @@ def _packed_data(
rollout_log_probs=(
rollout_data["rollout_log_probs"][start:end] if "rollout_log_probs" in rollout_data else None
),
multimodal_inputs=(
rollout_data["multimodal_inputs"][start:end] if "multimodal_inputs" in rollout_data else None
),
num_packs=mbs_size,
)
)
Expand Down Expand Up @@ -783,15 +794,13 @@ def _create_ref_model(self, ref_load_path: str | None):
if ref_load_path is None:
raise ValueError("ref_load_path must be provided when loading reference model")

import os

if os.path.isdir(ref_load_path):
logger.info(f"[Rank {dist.get_rank()}] Creating separate ref model from {ref_load_path}")

init_context = self._get_init_weight_context_manager()

with init_context():
ref_model = AutoModelForCausalLM.from_pretrained(
ref_model = self.get_model_cls().from_pretrained(
ref_load_path,
trust_remote_code=True,
attn_implementation=self.args.attn_implementation,
Expand Down Expand Up @@ -828,6 +837,8 @@ def _get_model_inputs_args(self, packed_sequence: dict) -> dict:
"position_ids": position_ids,
"attention_mask": None,
}
if packed_sequence.get("multimodal_inputs"):
model_args.update(packed_sequence["multimodal_inputs"])
return model_args


Expand Down
Loading