
Conversation

@popomen (Contributor) commented May 28, 2025

Checklist Before Starting

  • Search for similar PR(s).

What does this PR do?

Only apply Ulysses sequence parallelism to the LLM part of the VLM model, which is the main component, to prevent the `Image features and image tokens do not match` error from occurring before `masked_scatter`.

High-Level Design

Demonstrate the high-level design if this PR is complex.

Specific Changes

  1. For the VLM model, pad the inputs before the forward pass without slicing them; slicing is performed after the embedding stage.
  2. Where the ViT and the LLM share the same FlashAttention implementation, detect the ViT case and skip the Ulysses logic.
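The pad-then-slice order above can be sketched in a few lines of plain Python. This is a simplified illustration, not verl's actual API: `ulysses_pad` here has a reduced signature, and `slice_for_rank` is a hypothetical helper.

```python
# Simplified sketch of "pad before forward, slice after embedding".
# ulysses_pad's signature is reduced here; slice_for_rank is hypothetical.

def ulysses_pad(tokens, sp_size, pad_id=0):
    """Pad the full sequence so its length divides evenly across SP ranks."""
    pad_len = (-len(tokens)) % sp_size
    return tokens + [pad_id] * pad_len, pad_len

def slice_for_rank(embeds, rank, sp_size):
    """After the embedding stage, each SP rank keeps only its own shard."""
    chunk = len(embeds) // sp_size
    return embeds[rank * chunk : (rank + 1) * chunk]

tokens = list(range(10))                 # 10 tokens, sp_size=4 -> pad to 12
padded, pad_len = ulysses_pad(tokens, sp_size=4)
embeds = [float(t) for t in padded]      # stand-in for the embedding lookup
shards = [slice_for_rank(embeds, r, sp_size=4) for r in range(4)]
```

Because the image features are merged into the embeddings before slicing, `masked_scatter` always sees the full, consistent sequence; only the already-merged embeddings are sharded.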

API

Demonstrate how the API changes if any.

Usage Example

Provide usage example(s) for easier usage.

# Add code snippet or script demonstrating how to use this 

Test

python -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=/mnt/hdfs/zhudelin123/data/geo3k/train.parquet \
    data.val_files=/mnt/hdfs/zhudelin123/data/geo3k/test.parquet \
    data.train_batch_size=64 \
    data.max_prompt_length=2048 \
    data.max_response_length=2048 \
    data.filter_overlong_prompts=True \
    data.truncation=error \
    data.image_key=images \
    actor_rollout_ref.model.path=/mnt/hdfs/Qwen2.5-VL-7B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=64 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.01 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.model.use_fused_kernels=True \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
    actor_rollout_ref.rollout.enable_chunked_prefill=False \
    actor_rollout_ref.rollout.enforce_eager=False \
    actor_rollout_ref.rollout.free_cache_engine=False \
    actor_rollout_ref.rollout.n=4 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=[console,wandb] \
    trainer.project_name=nanzhe_verl_grpo_example_geo3k \
    trainer.experiment_name=qwen2_5_vl_7b_sp2_test \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=2 \
    trainer.save_freq=-1 \
    trainer.test_freq=-1 \
    trainer.default_hdfs_dir=null \
    trainer.total_epochs=1 \
    trainer.resume_mode=disable
[screenshot: test run results]

Additional Info.

  • Issue Number: Fixes issue # or discussion # if any.
  • Training: [Note which backend this PR will affect: FSDP, Megatron, both, or none]
  • Inference: [Note which backend this PR will affect: vLLM, SGLang, both, or none]

Checklist Before Submitting

  • Read the Contribute Guide.
  • Apply pre-commit checks.
  • Add [BREAKING] to the PR title if it breaks any API.
  • Update the documentation about your changes in the docs.
  • Add CI test(s) if necessary.

@CLAassistant commented May 28, 2025

CLA assistant check
All committers have signed the CLA.

@vermouth1992 vermouth1992 requested a review from hiyouga May 28, 2025 13:09
@hiyouga (Collaborator) previously approved these changes May 28, 2025 and left a comment:

overall LGTM!

@vermouth1992 (Collaborator)

Could you modify the vlm test to use sp=2? Thanks

@vermouth1992 (Collaborator)

> overall LGTM!

Do we need per-model modification for ulysses?

@hiyouga (Collaborator) commented May 29, 2025

@vermouth1992 True. To implement SP for VLMs, we need to insert some operations into the top-level module. Since the definition of this module (e.g., `Qwen2VLModel`) varies across models, we currently need to apply model-specific patches for Ulysses. The transformers library is being refactored to unify VLMs, so we'll be able to adopt a more unified approach once that work is done.
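The kind of model-specific patch described here can be sketched as a plain monkey patch. The stub class and `apply_ulysses_patch` below are hypothetical, not verl's actual code:

```python
# Hypothetical sketch of a per-model Ulysses patch: wrap the top-level
# module's forward and shard its output for the local SP rank.

class Qwen2VLModelStub:
    """Stand-in for a top-level VLM module such as Qwen2VLModel."""
    def forward(self, hidden):
        return [h + 1 for h in hidden]    # pretend computation

def apply_ulysses_patch(model_cls, sp_rank, sp_size):
    original_forward = model_cls.forward

    def patched_forward(self, hidden):
        out = original_forward(self, hidden)
        chunk = len(out) // sp_size       # slice after the original forward
        return out[sp_rank * chunk : (sp_rank + 1) * chunk]

    model_cls.forward = patched_forward   # model-specific monkey patch

apply_ulysses_patch(Qwen2VLModelStub, sp_rank=1, sp_size=2)
model = Qwen2VLModelStub()
sharded = model.forward([0, 1, 2, 3])     # rank 1 keeps the second half
```

Because the patch targets the concrete class, each model family whose top-level module differs needs its own patch, which is the per-model cost discussed above.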

@hiyouga (Collaborator) left a comment:

Please also enable SP for the unittest:

TRAIN_FILES=$HOME/data/geo3k/train.parquet VAL_FILES=$HOME/data/geo3k/test.parquet \
MAX_PROMPT_LEN=1536 MAX_RESPONSE_LEN=1536 \
MODEL_ID=Qwen/Qwen2-VL-2B-Instruct \
ADV_ESTIMATOR=grpo RM_PAD=True USE_KL=True ENABLE_CHUNKED_PREFILL=False \
bash tests/e2e/ppo_trainer/run_function_reward.sh

@popomen popomen changed the title [Bugfix] Fix ulysses sequence parallelism for vlm [vlm] Support ulysses sequence parallelism for vlm May 29, 2025
@popomen (Contributor, Author) commented May 29, 2025

> overall LGTM!
>
> Do we need per-model modification for ulysses?

I've attempted a less invasive approach. For slicing the inputs, you only need to inject the Ulysses logic with a single line of code. However, for the all-to-all operation inside the attention block, you still need to patch the entire method.
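For context, that all-to-all redistributes data from a sequence-sharded to a head-sharded layout so each rank can run full attention over a subset of heads. A toy pure-Python simulation (not the actual patched method; shapes and the function name are illustrative):

```python
# Toy simulation of the Ulysses all-to-all inside attention: before it,
# each of sp_size ranks holds ALL heads for seq_len/sp_size tokens;
# after it, each rank holds ALL tokens for n_heads/sp_size heads.

def ulysses_all_to_all(per_rank, sp_size):
    n_heads = len(per_rank[0][0])
    heads_per_rank = n_heads // sp_size
    out = []
    for r in range(sp_size):                 # receiving rank r
        gathered = []
        for src in range(sp_size):           # tokens arrive from every rank
            for tok in per_rank[src]:
                gathered.append(tok[r * heads_per_rank : (r + 1) * heads_per_rank])
        out.append(gathered)
    return out

sp_size, seq_len, n_heads = 2, 4, 4
# Each entry (t, h) marks token t, head h.
full = [[(t, h) for h in range(n_heads)] for t in range(seq_len)]
per_rank = [full[:2], full[2:]]              # sequence sharded across ranks
out = ulysses_all_to_all(per_rank, sp_size)  # now head sharded, full sequence
```

Since this exchange happens mid-forward, between the QKV projection and the attention computation, it cannot be injected from outside with one line; the whole attention method has to be patched.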

@hiyouga (Collaborator) left a comment:

LGTM

@hiyouga hiyouga merged commit 9c50ffd into volcengine:main May 30, 2025
18 of 36 checks passed
wwwjn pushed a commit to wwwjn/verl that referenced this pull request Jun 10, 2025
vermouth1992 pushed a commit that referenced this pull request Jun 10, 2025
### Checklist Before Starting

- [x] Search for similar PR(s).

### What does this PR do?
Fix a sequence parallelism conflict in the KimiVL patch.

Background:
A recent VLM-related PR (#1739) changed the sequence parallelism logic for VLMs: `inputs_embeds` is now split after the model's embedding layer, instead of splitting `input_ids` and `position_ids` before the forward pass. However, the SP logic I implemented in KimiVL's PR (#1639) still followed the old approach, splitting at the point where image tokens and text tokens are combined in order to avoid the `Image features and image tokens do not match` problem. Because the two PRs were developed in parallel, merging both introduced a logical conflict.

### High-Level Design

> Demonstrate the high-level design if this PR is complex.

### Specific Changes

- Delete the patch for `_merge_with_image_features`, which assigned image tokens to the corresponding SP rank.
- Adjust the `position_ids` handling in `_ulysses_flash_attn_forward`.

### API

> Demonstrate how the API changes if any.

### Usage Example

> Provide usage example(s) for easier usage.

```python
# Add code snippet or script demonstrating how to use this 
```

### Test


![image](https://github.com/user-attachments/assets/82ef7a74-66f8-4bb0-a0fc-3702b215c8c0)


### Additional Info.

- **Issue Number**: Fixes issue # or discussion # if any.
- **Training**: [Note which backend this PR will affect: FSDP, Megatron,
both, or none]
- **Inference**: [Note which backend this PR will affect: vLLM, SGLang,
both, or none]

### Checklist Before Submitting

- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting).
- [ ] Add `[BREAKING]` to the PR title if it breaks any API.
- [ ] Update the documentation about your changes in the
[docs](https://github.com/volcengine/verl/tree/main/docs).
- [ ] New CI unit test(s) are added to cover the code path.
- [ ] Rely on existing unit tests on CI that covers the code path.

---------

Signed-off-by: ShareLer <[email protected]>
yellowbee686 pushed a commit to yellowbee686/verl that referenced this pull request Jun 10, 2025
Comment on lines +113 to +120
if is_vlm_model:
    # vlm model's inputs will be sliced after embedding
    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad(
        input_ids_rmpad,
        position_ids_rmpad=position_ids_rmpad,
        sp_size=self.ulysses_sequence_parallel_size,
    )
else:
whatadayG pushed a commit to whatadayG/verl that referenced this pull request Sep 5, 2025
chenjiaoAngel added a commit to chenjiaoAngel/verl that referenced this pull request Nov 14, 2025
chenjiaoAngel added a commit to chenjiaoAngel/verl that referenced this pull request Nov 14, 2025
TimurTaepov pushed a commit to giorgossideris/verl that referenced this pull request Dec 20, 2025
TimurTaepov pushed a commit to giorgossideris/verl that referenced this pull request Dec 20, 2025