diff --git a/verl/models/transformers/monkey_patch.py b/verl/models/transformers/monkey_patch.py
index 720b25c1668..bd7a0ed5091 100644
--- a/verl/models/transformers/monkey_patch.py
+++ b/verl/models/transformers/monkey_patch.py
@@ -16,6 +16,7 @@
 """
 
 import sys
+from types import SimpleNamespace
 from typing import Optional
 
 import torch
@@ -239,70 +240,60 @@ def state_dict(self, *args, **kwargs):
         print("Monkey patch state_dict in AutoModelForCausalLMWithValueHead. ")
 
     # TODO: VLM models only, unify monkey patch to LLM models.
-    if model.config.model_type == "qwen2_5_vl":
+    if model.config.model_type in ["qwen2_5_vl", "qwen2_vl"]:
+        # Step 1: patch model to support image-text mixed data
         if is_transformers_version_in_range(min_version="4.52.0"):
             from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
-                Qwen2_5_VLAttention,
                 Qwen2_5_VLForConditionalGeneration,
                 Qwen2_5_VLModel,
                 Qwen2_5_VLTextModel,
             )
-
-            from verl.models.transformers.qwen2_vl import forward_with_normal_backend, qwen2_vl_base_forward
-
-            Qwen2_5_VLModel.forward = qwen2_vl_base_forward
-            Qwen2_5_VLForConditionalGeneration.forward = forward_with_normal_backend
-        else:
-            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
-                Qwen2_5_VLFlashAttention2 as Qwen2_5_VLAttention,
-            )
-            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
-                Qwen2_5_VLForConditionalGeneration,
-            )
-            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel as Qwen2_5_VLTextModel
-
-            from verl.models.transformers.qwen2_vl import forward_with_normal_backend
-
-            Qwen2_5_VLForConditionalGeneration.forward = forward_with_normal_backend
-
-        if use_remove_padding or ulysses_sp_size > 1:
-            from verl.models.transformers.qwen2_vl import qwen2_vl_attn_forward
-
-            Qwen2_5_VLAttention.forward = qwen2_vl_attn_forward
-            print("Monkey patch Qwen2.5VL attention layer")
-
-        if ulysses_sp_size > 1:
-            patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLTextModel)
-
-    elif model.config.model_type == "qwen2_vl":
-        if is_transformers_version_in_range(min_version="4.52.0"):
             from transformers.models.qwen2_vl.modeling_qwen2_vl import (
-                Qwen2VLAttention,
                 Qwen2VLForConditionalGeneration,
                 Qwen2VLModel,
                 Qwen2VLTextModel,
             )
-
-            from verl.models.transformers.qwen2_vl import forward_with_normal_backend, qwen2_vl_base_forward
-
-            Qwen2VLModel.forward = qwen2_vl_base_forward
-            Qwen2VLForConditionalGeneration.forward = forward_with_normal_backend
         else:
-            from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 as Qwen2VLAttention
+            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
+            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel as Qwen2_5_VLTextModel
             from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
             from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel as Qwen2VLTextModel
 
-            from verl.models.transformers.qwen2_vl import forward_with_normal_backend
+            Qwen2_5_VLModel = SimpleNamespace(forward=None)
+            Qwen2VLModel = SimpleNamespace(forward=None)
+
+        from verl.models.transformers.qwen2_vl import forward_with_normal_backend, qwen2_vl_base_forward
 
-            Qwen2VLForConditionalGeneration.forward = forward_with_normal_backend
+        Qwen2_5_VLModel.forward = qwen2_vl_base_forward
+        Qwen2VLModel.forward = qwen2_vl_base_forward
+        Qwen2_5_VLForConditionalGeneration.forward = forward_with_normal_backend
+        Qwen2VLForConditionalGeneration.forward = forward_with_normal_backend
+        print(f"Monkey patch {model.__class__.__name__} model forward")
+
+        # Step 2: patch attention to support ulysses parallelism
+        if is_transformers_version_in_range(min_version="4.54.0"):
+            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLAttention
+            from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention
+        elif is_transformers_version_in_range(min_version="4.53.0"):
+            raise RuntimeError("Transformers 4.53.* is bugged. Use transformers 4.54.0 or later.")
+        else:
+            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+                Qwen2_5_VLFlashAttention2 as Qwen2_5_VLAttention,
+            )
+            from transformers.models.qwen2_vl.modeling_qwen2_vl import (
+                Qwen2VLFlashAttention2 as Qwen2VLAttention,
+            )
 
         if use_remove_padding or ulysses_sp_size > 1:
             from verl.models.transformers.qwen2_vl import qwen2_vl_attn_forward
 
+            Qwen2_5_VLAttention.forward = qwen2_vl_attn_forward
             Qwen2VLAttention.forward = qwen2_vl_attn_forward
-            print("Monkey patch Qwen2VL attention layer")
+            print(f"Monkey patch {model.__class__.__name__} attention layer")
 
+        # Step 3: patch input for multimodal sequence parallelism
         if ulysses_sp_size > 1:
+            patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLTextModel)
             patch_vlm_for_ulysses_input_slicing(Qwen2VLTextModel)
 
     elif model.config.model_type == "kimi_vl":
diff --git a/verl/models/transformers/qwen2_vl.py b/verl/models/transformers/qwen2_vl.py
index 0a07ca029ef..3afd4a44db7 100644
--- a/verl/models/transformers/qwen2_vl.py
+++ b/verl/models/transformers/qwen2_vl.py
@@ -216,19 +216,17 @@ def _custom_flash_attention_forward(
 
     if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
         batch_size = query_states.size(0)
-        query_states, key_states, value_states, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
+        q, k, v, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = prepare_fa2_from_position_ids(
             query_states, key_states, value_states, position_ids
         )
-        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
-        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
         attn_output = flash_attn_varlen_func(
-            query_states,
-            key_states,
-            value_states,
+            q=q,
+            k=k,
+            v=v,
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_k=cu_seqlens_k,
-            max_seqlen_q=max_seqlen_in_batch_q,
-            max_seqlen_k=max_seqlen_in_batch_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
             dropout_p=kwargs.pop("dropout", 0.0),
             softmax_scale=kwargs.pop("softmax_scale", None),
             causal=is_causal,
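
Note on the `SimpleNamespace(forward=None)` placeholders in Step 1: on older transformers the `Qwen2_5_VLModel` / `Qwen2VLModel` base classes are not importable, so the `else:` branch binds those names to throwaway namespaces. The unconditional `*.forward = ...` assignments that follow then work in both branches: a real patch when the classes exist, a harmless write to the placeholder otherwise. A minimal, self-contained sketch of that pattern (the class and flag names below are illustrative stand-ins, not the real transformers or verl symbols):

```python
from types import SimpleNamespace

# Hypothetical stand-in for a class that only exists in newer library versions.
class RealBaseModel:
    def forward(self):
        return "library forward"

def patched_forward(self):
    return "patched forward"

HAS_NEW_API = False  # pretend we are on an old library version

if HAS_NEW_API:
    BaseModel = RealBaseModel
else:
    # Placeholder that accepts the same attribute assignment, keeping the
    # patching code below branch-free; the write has no effect on any library class.
    BaseModel = SimpleNamespace(forward=None)

BaseModel.forward = patched_forward  # real patch on new versions, no-op on the placeholder
print(f"patched: {BaseModel.forward is patched_forward}")
```

The same idea is what lets Step 1 above run as one straight-line block regardless of the installed transformers version.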