
[FSDP, VLM] feat: add vlm training for FSDP #501

Merged
zhuzilin merged 58 commits into THUDM:main from nanjiangwill:feat/vlm on Dec 4, 2025
Conversation

@nanjiangwill (Collaborator) commented Oct 15, 2025

Goal: Support VLM training on slime with FSDP

TODO

@nanjiangwill nanjiangwill marked this pull request as draft October 15, 2025 06:20
# Process images for training (like tokenization for images)
if images_for_training and state.processor is not None:
    processed = state.processor(images=images_for_training, return_tensors="pt")
    sample.pixel_values = processed["pixel_values"]
Contributor:

We shouldn't try to pass the pixel values from sglang to megatron; instead, it's probably better to re-process the image on the training side.
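For illustration, the suggested direction might look like the following sketch, in which only the raw images travel with the sample and pixel_values are re-derived on the training side. All names here are hypothetical; a real HF `AutoProcessor` would stand in for the processor argument:

```python
# Hypothetical sketch of re-processing images on the training side,
# instead of serializing pixel_values from the sglang rollout engine.
class Sample:
    def __init__(self, raw_images):
        self.raw_images = raw_images  # e.g. PIL images or raw bytes kept with the sample
        self.pixel_values = None

def reprocess_on_trainer(sample, processor):
    # `processor` is assumed to mirror the HF AutoProcessor call signature:
    # processor(images=..., return_tensors="pt") -> {"pixel_values": ...}
    if sample.raw_images and processor is not None:
        processed = processor(images=sample.raw_images, return_tensors="pt")
        sample.pixel_values = processed["pixel_values"]
    return sample
```

This keeps the rollout payload small and guarantees the training side always sees pixel values produced by its own processor version.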

@coding-famer (Contributor) commented:

Hi @nanjiangwill, glad to see this important feature! I'd love to help with this PR if you're open to collaborating. I can help with supporting qwen-vl models and a geo3k training example. Please let me know if you're happy to collaborate!

@nanjiangwill (Collaborator, Author) commented:

> Hi @nanjiangwill, glad to see this important feature! I'd love to help with this PR if you're open to collaborating. I can help with supporting qwen-vl models and geo3k training example.

Hi @coding-famer, that would be amazing! Can I have your email/contact number so I can reach out to you?

@coding-famer (Contributor) commented:

> Can I have your email/contact number so I can reach out to you?

I've sent you an email!

@dongyuanjushi commented:

Hi @nanjiangwill and @coding-famer, I'd love to help with this PR about the multi-turn part if you are open to this.

@nanjiangwill (Collaborator, Author) commented:

> Hi @nanjiangwill and @coding-famer, I'd love to help with this PR about the multi-turn part if you are open to this.

Hey, thanks for reaching out! Can I have your email?

@dongyuanjushi commented:

> Hey, thanks for reaching out! Can I have your email?

I've sent the email!

@zhaochenyang20 (Collaborator) commented:

My god, @nanjiangwill 😭

@zhaochenyang20 (Collaborator) commented:

Awesome!

@aJupyter commented:

Roughly when will this be merged?

@zhaochenyang20 (Collaborator) commented:

> Roughly when will this be merged?

Definitely before the Stargate cluster is built.

@jhinpan (Contributor) commented Dec 2, 2025

> Could the code for multimodal processing be merged as soon as possible? Multi-turn support is a separate module and can be merged later. Otherwise, this will block issues that depend on the multimodal model and merged code.

This PR will be merged in the next few days. Please just give us some time to do a final check and review.

@zhaochenyang20 (Collaborator) commented:

Invincible!

@jhinpan (Contributor) commented Dec 3, 2025

We ran three experiments comparing different reward model configurations on 8×H100:

  1. geo3k RM with tol=0.05 (original)
  2. geo3k RM with tol=0.0 (strict matching)
  3. default math RM

Results showed all three configurations perform similarly, indicating:

  • The tolerance parameter has minimal impact on training outcomes
  • The geo3k-specific reward model provides no advantage over the default math RM

I will remove them later; just leaving a note here to record this:

(image attached)
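For context, a tolerance-based numeric-match reward of the kind being compared can be sketched as follows (a hypothetical illustration; the actual geo3k RM in this PR may differ):

```python
def numeric_match_reward(pred: float, gt: float, tol: float = 0.05) -> float:
    """Reward 1.0 when the prediction matches the ground truth within
    a relative tolerance; tol=0.0 degenerates to strict equality."""
    if tol == 0.0:
        return 1.0 if pred == gt else 0.0
    # relative tolerance, with a floor of 1.0 so small answers aren't penalized
    return 1.0 if abs(pred - gt) <= tol * max(abs(gt), 1.0) else 0.0
```

With a reasonably trained model, predictions tend to be either exactly right or far off, which is consistent with the observation above that the tolerance setting had minimal impact.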

@zhuzilin (Contributor) left a comment:

LGTM! Left some minor comments.

flat_rollout_log_probs, dtype=torch.float32, device=torch.cuda.current_device()
),
"multimodal_inputs": multimodal_data,
"multimodal_num_items": multimodal_num_items,
Contributor:

Here we need something like:

packed_batch = {
    "tokens": ...,
}
if multimodal_inputs:
    for key, mm_tensor in multimodal_inputs[i].items():
        ...
    packed_batch.extend({
        "multimodal_inputs": multimodal_data,
        "multimodal_num_items": multimodal_num_items,
    })
result.append(packed_batch)

Contributor:

Used .update() instead of .extend(), since dictionaries merge key-value pairs with update().
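For illustration, the resulting merge pattern in isolation (plain Python; names mirror the snippet above):

```python
packed_batch = {"tokens": [101, 102, 103]}
multimodal_extras = {
    "multimodal_inputs": {"pixel_values": "tensor-placeholder"},
    "multimodal_num_items": 2,
}
# dicts have no .extend(); .update() merges key-value pairs in place
packed_batch.update(multimodal_extras)
print(sorted(packed_batch))
# → ['multimodal_inputs', 'multimodal_num_items', 'tokens']
```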

):
    # group norm
-   rewards = torch.tensor(raw_rewards, dtype=torch.float)
+   rewards = torch.tensor(raw_rewards, dtype=torch.float16)
Contributor:

Hmm, it seems better to move this type conversion into the custom reward model? Otherwise it may affect other users who are using a dense RM for RLHF.

Contributor:

Moved the float16 conversion specifically to the notes in the geo3k_vlm example.

@zhuzilin zhuzilin marked this pull request as ready for review December 4, 2025 03:47
@zhuzilin zhuzilin merged commit 81d8cd6 into THUDM:main Dec 4, 2025
7 of 16 checks passed
@nanjiangwill nanjiangwill changed the title from "[WIP] feat: add vlm training for FSDP" to "[FSDP, VLM] feat: add vlm training for FSDP" Dec 4, 2025
@simondong1 commented Dec 7, 2025

Hi @nanjiangwill, this is awesome! I am also working on RL with VLMs and would love to contribute and collaborate on further work! My email is simondong0919 at gmail.com; would love to get in touch.

Fengzdadi pushed a commit to Fengzdadi/slime that referenced this pull request Dec 19, 2025
Co-authored-by: Chenhe Gu <chenhegu0109@gmail.com>
Co-authored-by: Jin Pan <jpan236@wisc.edu>
Co-authored-by: Jinn <47354855+jhinpan@users.noreply.github.com>
@adol001 commented Dec 26, 2025

@nanjiangwill @zhuzilin

It looks like this commit changed the behavior of the apply_chat_template parameter. As a result, this parameter becomes useless during dataset construction: instead of applying the chat template, the data from prompt_key is fed directly into something like:

messages = [{"role": "user", "content": messages}]

In custom generate scenarios, do you realize what this can lead to?

You’re effectively replacing what used to be a str with a dict (or even a list of dicts). If downstream code doesn’t explicitly validate types, it will just get stuffed into the prompt and you end up with outputs like:

First rollout sample: ['<|im_start|>user\n[{'role': 'user', 'content':

This kind of change alters the default behavior in a non-backward-compatible way. At the very least, compatibility with existing users should be considered — and there should be warnings or clear notices. In my case, it indirectly caused training quality to degrade. I wouldn’t have noticed this logic change if it weren’t for a new training job that made the issue obvious.
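The failure mode described above can be reproduced in miniature: when downstream code still assumes a string prompt, the wrapped messages list gets its repr stuffed into the template. The template text below is illustrative, not slime's actual one:

```python
prompt = "What is 2 + 2?"
# new behavior: the prompt_key value is wrapped into a messages list
messages = [{"role": "user", "content": prompt}]

# downstream code that still assumes it received a plain string
template = "<|im_start|>user\n{}<|im_end|>"
garbled = template.format(messages)
print(garbled)
# the dict repr ends up inside the rendered prompt:
# <|im_start|>user
# [{'role': 'user', 'content': 'What is 2 + 2?'}]<|im_end|>
```

Because `str.format` silently calls `str()` on whatever it receives, no exception is raised; the corruption only shows up in the rendered prompts, which matches the quiet training-quality degradation reported above.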

@zhuzilin (Contributor) commented:

@adol001 really sorry about this, we will revert this change.

@nanjiangwill (Collaborator, Author) commented:

@adol001, sorry again about this; we have fixed the issue in #1232 and #1234.

jind11 pushed a commit to eigen-ai-labs/slime that referenced this pull request Feb 24, 2026
Co-authored-by: Chenhe Gu <chenhegu0109@gmail.com>
Co-authored-by: Jin Pan <jpan236@wisc.edu>
Co-authored-by: Jinn <47354855+jhinpan@users.noreply.github.com>
10 participants