73 changes: 32 additions & 41 deletions verl/models/transformers/monkey_patch.py
@@ -16,6 +16,7 @@
"""

import sys
from types import SimpleNamespace
from typing import Optional

import torch
@@ -239,70 +240,60 @@ def state_dict(self, *args, **kwargs):
print("Monkey patch state_dict in AutoModelForCausalLMWithValueHead. ")

# TODO: VLM models only, unify monkey patch to LLM models.
if model.config.model_type == "qwen2_5_vl":
if model.config.model_type in ["qwen2_5_vl", "qwen2_vl"]:
# Step 1: patch model to support image-text mixed data
if is_transformers_version_in_range(min_version="4.52.0"):
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLAttention,
Qwen2_5_VLForConditionalGeneration,
Qwen2_5_VLModel,
Qwen2_5_VLTextModel,
)

from verl.models.transformers.qwen2_vl import forward_with_normal_backend, qwen2_vl_base_forward

Qwen2_5_VLModel.forward = qwen2_vl_base_forward
Qwen2_5_VLForConditionalGeneration.forward = forward_with_normal_backend
else:
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLFlashAttention2 as Qwen2_5_VLAttention,
)
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration,
)
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel as Qwen2_5_VLTextModel

from verl.models.transformers.qwen2_vl import forward_with_normal_backend

Qwen2_5_VLForConditionalGeneration.forward = forward_with_normal_backend

if use_remove_padding or ulysses_sp_size > 1:
from verl.models.transformers.qwen2_vl import qwen2_vl_attn_forward

Qwen2_5_VLAttention.forward = qwen2_vl_attn_forward
print("Monkey patch Qwen2.5VL attention layer")

if ulysses_sp_size > 1:
patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLTextModel)

elif model.config.model_type == "qwen2_vl":
if is_transformers_version_in_range(min_version="4.52.0"):
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
Qwen2VLAttention,
Qwen2VLForConditionalGeneration,
Qwen2VLModel,
Qwen2VLTextModel,
)

from verl.models.transformers.qwen2_vl import forward_with_normal_backend, qwen2_vl_base_forward

Qwen2VLModel.forward = qwen2_vl_base_forward
Qwen2VLForConditionalGeneration.forward = forward_with_normal_backend
else:
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 as Qwen2VLAttention
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel as Qwen2_5_VLTextModel
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel as Qwen2VLTextModel

from verl.models.transformers.qwen2_vl import forward_with_normal_backend
Qwen2_5_VLModel = SimpleNamespace(forward=None)
Qwen2VLModel = SimpleNamespace(forward=None)

from verl.models.transformers.qwen2_vl import forward_with_normal_backend, qwen2_vl_base_forward

Qwen2VLForConditionalGeneration.forward = forward_with_normal_backend
Qwen2_5_VLModel.forward = qwen2_vl_base_forward
Qwen2VLModel.forward = qwen2_vl_base_forward
Qwen2_5_VLForConditionalGeneration.forward = forward_with_normal_backend
Qwen2VLForConditionalGeneration.forward = forward_with_normal_backend
print(f"Monkey patch {model.__class__.__name__} model forward")

# Step 2: patch attention to support ulysses parallelism
if is_transformers_version_in_range(min_version="4.54.0"):
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLAttention
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention
elif is_transformers_version_in_range(min_version="4.53.0"):
raise RuntimeError("Transformers 4.53.* is bugged. Use transformers 4.54.0 or later.")
else:
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLFlashAttention2 as Qwen2_5_VLAttention,
)
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
Qwen2VLFlashAttention2 as Qwen2VLAttention,
)

if use_remove_padding or ulysses_sp_size > 1:
from verl.models.transformers.qwen2_vl import qwen2_vl_attn_forward

Qwen2_5_VLAttention.forward = qwen2_vl_attn_forward
Qwen2VLAttention.forward = qwen2_vl_attn_forward
print("Monkey patch Qwen2VL attention layer")
print(f"Monkey patch {model.__class__.__name__} attention layer")

# Step 3: patch input for multimodal sequence parallelism
if ulysses_sp_size > 1:
patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLTextModel)
patch_vlm_for_ulysses_input_slicing(Qwen2VLTextModel)

elif model.config.model_type == "kimi_vl":
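The unified branch above leans on a small trick: on transformers versions before 4.52, the `Qwen2_5_VLModel` / `Qwen2VLModel` names are rebound to `SimpleNamespace(forward=None)` dummies, so the shared `*.forward = qwen2_vl_base_forward` assignments that follow become harmless no-ops instead of requiring a second version branch. A minimal sketch of the pattern (the class and flag names below are illustrative, not the real transformers API):

```python
from types import SimpleNamespace

class RealModel:  # stand-in for a class that only exists on newer versions
    pass

def patched_forward(self, *args, **kwargs):
    ...  # replacement implementation goes here

new_api_available = False  # pretend we are on an old library version

if new_api_available:
    PatchTarget = RealModel
else:
    # Dummy object: attribute assignment on it is harmless, so the
    # shared patching code below needs no further version checks.
    PatchTarget = SimpleNamespace(forward=None)

PatchTarget.forward = patched_forward  # real monkey patch, or silent no-op
```

Keeping a single assignment site for both versions is what lets the PR collapse the previously duplicated `qwen2_5_vl` and `qwen2_vl` branches into one.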
14 changes: 6 additions & 8 deletions verl/models/transformers/qwen2_vl.py
@@ -216,19 +216,17 @@ def _custom_flash_attention_forward(

if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
batch_size = query_states.size(0)
query_states, key_states, value_states, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
q, k, v, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = prepare_fa2_from_position_ids(
query_states, key_states, value_states, position_ids
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=kwargs.pop("dropout", 0.0),
softmax_scale=kwargs.pop("softmax_scale", None),
causal=is_causal,
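The `qwen2_vl.py` hunk tightens the hand-off from `prepare_fa2_from_position_ids` to `flash_attn_varlen_func` by unpacking its return values directly into the names the kernel expects, instead of going through intermediate tuples. For context, a hedged sketch of the varlen calling convention on packed sequences, assuming `flash-attn` is installed and a CUDA device is available; the shapes and toy lengths are illustrative:

```python
import torch
from flash_attn import flash_attn_varlen_func

# Two sequences (lengths 3 and 5) packed along one token axis:
# total_tokens=8, num_heads=4, head_dim=64.
q = torch.randn(8, 4, 64, dtype=torch.float16, device="cuda")
k = torch.randn(8, 4, 64, dtype=torch.float16, device="cuda")
v = torch.randn(8, 4, 64, dtype=torch.float16, device="cuda")

# Cumulative sequence lengths mark the packing boundaries.
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")

out = flash_attn_varlen_func(
    q=q,
    k=k,
    v=v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=5,  # longest packed sequence
    max_seqlen_k=5,
    dropout_p=0.0,
    causal=True,
)
# out has shape (8, 4, 64); attention never crosses the boundary at token 3.
```

The `torch.diff(position_ids, dim=-1)` check in the patched forward is what detects this packed layout: position ids reset to 0 at each sequence start, so a negative diff signals more than one sequence per row and routes the batch through the varlen kernel.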