[trainer] fix: fallback vision tower to flash_attention_2 for Qwen2.5-VL when u…#4670

Merged
wuxibin89 merged 1 commit into verl-project:main from aoshen524:fix/qwen2.5vl-flash-attention-3-vit-fallback
Dec 29, 2025
Conversation

@aoshen524 (Contributor) commented Dec 25, 2025

Fix: Fallback Vision Tower to Flash Attention 2 for Qwen2.5-VL when using Flash Attention 3

Description

This PR adds a patch for Qwen2.5-VL models that falls back the vision tower's attention implementation to flash_attention_2 when the main model uses flash_attention_3.

Motivation

Qwen2.5-VL's vision tower does not support flash_attention_3 properly. When attn_implementation is set to flash_attention_3, using FA3 for the vision tower causes significant performance degradation compared to flash_attention_2.

Experimental Validation

We have tested this patch across the entire Qwen2.5-VL series (3B, 7B, 32B, and 72B models) using the Transformers library on an 8×H100 GPU machine with auto device placement.

Below is the performance comparison for Qwen2.5-VL-7B with input of one 1260×700 image + 150 tokens of text:

======================================================================
COMPARISON SUMMARY
======================================================================

Implementation            Avg Latency (ms)   Throughput (tok/s)
-------------------------------------------------------------
flash_attention_2         102.85             12503.46      
flash_attention_3         309.49             4155.19              

FA3 vs FA2 Speedup: 0.33x
Memory Difference: +0.00 GB
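The reported speedup is simply the ratio of the two average latencies. A quick sanity check of the summary numbers above (values copied from the table):

```python
# Sanity-check the comparison summary above.
fa2_latency_ms = 102.85
fa3_latency_ms = 309.49

# "FA3 vs FA2 Speedup" is FA2 latency divided by FA3 latency:
speedup = fa2_latency_ms / fa3_latency_ms
print(f"FA3 vs FA2 Speedup: {speedup:.2f}x")  # 0.33x

# Equivalently, FA3 is roughly 3x slower than FA2 on the vision tower:
slowdown = fa3_latency_ms / fa2_latency_ms
print(f"FA3 slowdown: {slowdown:.2f}x")  # 3.01x
```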

Test Environment:

  • Hardware: 8×H100 GPUs
  • Library: Transformers with auto device placement
  • Models tested: Qwen2.5-VL-3B, 7B, 32B, 72B

Key Findings:

  • Flash Attention 3 is 3x slower than Flash Attention 2 for the vision tower
  • No memory benefit from using FA3 for vision components
  • Consistent behavior observed across all model sizes (3B, 7B, 32B, 72B)

Changes

  • Added a check for qwen2_5_vl model type
  • When attn_implementation == "flash_attention_3", automatically set actor_model_config.vision_config._attn_implementation = "flash_attention_2" for the vision tower
  • This allows the language model to use FA3 while the vision tower uses FA2, achieving optimal performance

Impact

This change ensures that Qwen2.5-VL models can benefit from flash_attention_3 for text processing while maintaining optimal performance for vision encoding.

Technical Details

The patch is applied in verl/workers/fsdp_workers.py in the _build_model_optimizer method:

# patch for qwen2.5-vl: when using flash_attention_3, set vision tower to use flash_attention_2
# because the vision tower does not support flash_attention_3
if (
    getattr(actor_model_config, "model_type", None) == "qwen2_5_vl"
    and attn_implementation == "flash_attention_3"
    and hasattr(actor_model_config, "vision_config")
):
    actor_model_config.vision_config._attn_implementation = "flash_attention_2"
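The override works because Transformers honors a per-sub-config `_attn_implementation` when building the vision tower. A minimal, self-contained sketch of the same fallback logic, using a stand-in config object rather than the real Hugging Face config class:

```python
from types import SimpleNamespace

def fallback_vision_attn(config, attn_implementation):
    """Mirror the patch: if the model is qwen2_5_vl and the language model
    uses flash_attention_3, downgrade only the vision tower to FA2."""
    if (
        getattr(config, "model_type", None) == "qwen2_5_vl"
        and attn_implementation == "flash_attention_3"
        and hasattr(config, "vision_config")
    ):
        config.vision_config._attn_implementation = "flash_attention_2"
    return config

# Stand-in config with the same attribute shape as the HF model config.
cfg = SimpleNamespace(
    model_type="qwen2_5_vl",
    vision_config=SimpleNamespace(_attn_implementation="flash_attention_3"),
)
fallback_vision_attn(cfg, "flash_attention_3")
print(cfg.vision_config._attn_implementation)  # flash_attention_2
```

Note that the guard is deliberately narrow: any other model type, or any attention implementation other than flash_attention_3, leaves the config untouched.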

Testing

Tested on:

  • Qwen2.5-VL-3B
  • Qwen2.5-VL-7B
  • Qwen2.5-VL-32B
  • Qwen2.5-VL-72B

All models show consistent performance improvements with this patch when using flash_attention_3 for the language model.

…sing flash_attention_3

Qwen2.5-VL vision tower does not support flash_attention_3, so when
attn_implementation is set to flash_attention_3, we need to set the
vision tower's _attn_implementation to flash_attention_2 instead.
@gemini-code-assist bot left a comment
Code Review

The pull request introduces a code change in verl/workers/fsdp_workers.py within the _build_model_optimizer function. This change adds a specific patch for the qwen2_5_vl model. If the model type is qwen2_5_vl and the attention implementation is set to flash_attention_3, the patch overrides the vision tower's attention implementation to flash_attention_2, as the vision tower does not support flash_attention_3.

@wuxibin89 wuxibin89 changed the title fix: fallback vision tower to flash_attention_2 for Qwen2.5-VL when u… [trainer] fix: fallback vision tower to flash_attention_2 for Qwen2.5-VL when u… Dec 29, 2025
@wuxibin89 wuxibin89 merged commit cd4072d into verl-project:main Dec 29, 2025
47 of 49 checks passed
boren-ms pushed a commit to boren-ms/verl that referenced this pull request Dec 30, 2025
jsfanfanfan pushed a commit to meituan-search/verl that referenced this pull request Jan 9, 2026
vyomakesh0728 added a commit to vyomakesh0728/verl that referenced this pull request Jan 22, 2026
sophiayyya pushed a commit to sophiayyya/verl that referenced this pull request Jan 25, 2026