Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions verl/models/transformers/monkey_patch.py

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello, I have one question here. I didn't see any code for the InternVL model in the monkey patch here. Does that mean InternVL does not require custom code, or that sequence parallel is not applicable for InternVL right now?
Thanks a lot!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

InternVL does not have a special design that requires monkey patching. However, the vision model of InternVL does generate a high memory cost. For example, InternVL-Chat-V1.5, a 26B model, requires about 50G of memory for model parameters in BF16 format, and considering the additional overhead during training, it requires around 100-150G. The special requirement for vision encoder may need some discussion.

Original file line number Diff line number Diff line change
Expand Up @@ -212,10 +212,18 @@ def apply_monkey_patch(
try:
num_attention_heads, num_key_value_heads = model.config.num_attention_heads, model.config.num_key_value_heads
except AttributeError:
num_attention_heads, num_key_value_heads = (
model.config.text_config.num_attention_heads,
model.config.text_config.num_key_value_heads,
)
if hasattr(model.config, "text_config"):
num_attention_heads, num_key_value_heads = (
model.config.text_config.num_attention_heads,
model.config.text_config.num_key_value_heads,
)
elif hasattr(model.config, "llm_config"):
num_attention_heads, num_key_value_heads = (
model.config.llm_config.num_attention_heads,
model.config.llm_config.num_key_value_heads,
)
else:
raise ValueError("We cannot get num_attention_heads and num_key_value_heads from the model's config")

assert num_attention_heads % ulysses_sp_size == 0, (
f"num_attention_heads {num_attention_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}"
Expand Down
6 changes: 5 additions & 1 deletion verl/trainer/config/ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,10 @@ actor_rollout_ref:
# Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather
# before the current forward computation.
forward_prefetch: False
# Whether to use orig params. Needed when some parameters are frozen.
use_orig_params: False
# Whether to freeze the vision tower
frozen_vision_tower: False

# profiler configs
profiler:
Expand Down Expand Up @@ -1041,4 +1045,4 @@ ray_init:
num_cpus: null

# Path to save Ray timeline JSON for performance profiling
timeline_json_file: null
timeline_json_file: null
43 changes: 43 additions & 0 deletions verl/utils/dataset/preprocessor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import AutoProcessor
from .gemma import Gemma3Preprocessor
from .internvl import InternVLPreprocessor
from .qwen_vl import QwenVLPreProcessor
from .minicpmo import MiniCPMOPreProcessor
from .kimi_vl import KimiVLPreprocessor
from .registry import PREPROCESSOR_REGISTER
import re


def map_processor_to_preprocessor(processor: AutoProcessor):
    """Map a HuggingFace processor instance to its registered preprocessor class.

    Args:
        processor (AutoProcessor): The multi-modal processor instance.

    Returns:
        type: The preprocessor class registered under the processor's name.

    Raises:
        ValueError: If ``processor`` is not a Processor, or if no preprocessor
            is registered for it.
    """
    processor_name = processor.__class__.__name__
    if not processor_name.lower().endswith("processor"):
        raise ValueError(f"Source object '{processor_name}' is not a 'Processor'.")
    # All Qwen2-VL family processors (Qwen2VLProcessor, Qwen2_5_VLProcessor, ...)
    # share a single preprocessor implementation.
    if re.match(r"Qwen2.*?VLProcessor", processor_name):
        print("QwenVL2 Series will use the QwenVLPreprocessor")
        dest_name = "QwenVLPreprocessor".lower()
    else:
        dest_name = processor_name.lower().replace("processor", "preprocessor")

    dest_class = PREPROCESSOR_REGISTER.get(dest_name)
    if dest_class is None:
        # Fail loudly instead of returning None, which would otherwise surface
        # later as a confusing "'NoneType' object is not callable" at the call site.
        raise ValueError(f"No preprocessor registered for processor '{processor_name}' (key '{dest_name}').")
    return dest_class
60 changes: 60 additions & 0 deletions verl/utils/dataset/preprocessor/base_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
The basic preprocessor used for the multi-modal models.
"""

class BasicPreprocessor:
    """Base preprocessor for multi-modal models.

    Subclasses implement the per-modality ``process_*`` hooks; ``__call__``
    drives the shared flow of chat templating and running the HuggingFace
    processor over text/images/videos.
    """

    def __init__(self, processor, image_key="image", video_key="video", audio_key="audio"):
        # processor: a HuggingFace processor (tokenizer + feature extractor).
        self.processor = processor
        # Keys under which each modality is stored in a dataset row dict.
        self.image_key = image_key
        self.video_key = video_key
        self.audio_key = audio_key

    def process_image(self, image, **kwargs):
        raise NotImplementedError("The process_image method must be implemented")

    def process_video(self, video, **kwargs):
        raise NotImplementedError("The process_video method must be implemented")

    def process_audio(self, audio, **kwargs):
        # Fixed copy-paste bug: this message previously said "process_video".
        raise NotImplementedError("The process_audio method must be implemented")

    def __call__(self, messages, row_dict):
        """Encode one sample: template messages, process media, run the processor.

        Args:
            messages: Chat messages in the processor's chat-template format.
            row_dict: Dataset row; media entries are popped and replaced by
                ``multi_modal_data`` / ``multi_modal_inputs``.

        Returns:
            Tuple of (row_dict, model_inputs, input_ids, attention_mask, raw_prompt).
        """
        raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        multi_modal_data = {}

        images = None
        if self.image_key in row_dict:
            images = [self.process_image(image) for image in row_dict.pop(self.image_key)]
            multi_modal_data["image"] = images

        videos = None
        if self.video_key in row_dict:
            videos = [self.process_video(video) for video in row_dict.pop(self.video_key)]
            multi_modal_data["video"] = [video.numpy() for video in videos]
        model_inputs = self.processor(text=[raw_prompt], images=images, videos=videos, return_tensors="pt")
        input_ids = model_inputs.pop("input_ids")
        attention_mask = model_inputs.pop("attention_mask")

        if "second_per_grid_ts" in model_inputs:
            model_inputs.pop("second_per_grid_ts")
        # There's a trap here: multi_modal_inputs has to be a plain dict, not a BatchFeature.
        row_dict["multi_modal_data"] = multi_modal_data
        row_dict["multi_modal_inputs"] = dict(model_inputs)

        # second_per_grid_ts isn't used for training, just for mrope.
        row_dict["multi_modal_inputs"].pop("second_per_grid_ts", None)
        return row_dict, model_inputs, input_ids, attention_mask, raw_prompt
73 changes: 73 additions & 0 deletions verl/utils/dataset/preprocessor/gemma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import copy
from PIL import Image
import requests
from io import BytesIO

from .base_processor import BasicPreprocessor
from .registry import PREPROCESSOR_REGISTER

__all__ = ["Gemma3Preprocessor"]

@PREPROCESSOR_REGISTER.register()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW I am thinking moving all model related code to the same folder, one per model. #2338 (review)
Given the complexity of multimodal structures, i think it's worth a RFC for the overall approach and design

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think it is a good strategy for the Multi-modality framework.

class Gemma3Preprocessor(BasicPreprocessor):
    """Preprocessor for Gemma-3 multi-modal models (image + text only; no video/audio)."""

    def __init__(self, processor, image_key="image", video_key="video", **kwargs):
        # Accept **kwargs for signature consistency with the other registered
        # preprocessors (e.g. InternVLPreprocessor); extra kwargs are ignored.
        super().__init__(processor, image_key=image_key, video_key=video_key)

def process_image(self, image, **kwargs):
    """Load *image* from a PIL object, URL, ``file://`` path, data URI, or path.

    Returns:
        PIL.Image.Image: The loaded image converted to RGB.

    Raises:
        ValueError: For a ``data:image`` URI that is not base64-encoded.
    """
    if isinstance(image, Image.Image):
        image_obj = image
    elif image.startswith("http://") or image.startswith("https://"):
        # Stream the download and deep-copy out of BytesIO to fix a
        # memory-leak issue while using BytesIO.
        with requests.get(image, stream=True) as response:
            response.raise_for_status()
            with BytesIO(response.content) as bio:
                image_obj = copy.deepcopy(Image.open(bio))
    elif image.startswith("file://"):
        image_obj = Image.open(image[7:])
    elif image.startswith("data:image"):
        if "base64," in image:
            _, base64_data = image.split("base64,", 1)
            data = base64.b64decode(base64_data)
            # fix memory leak issue while using BytesIO
            with BytesIO(data) as bio:
                image_obj = copy.deepcopy(Image.open(bio))
        else:
            # Previously this path fell through with `image_obj` unbound and
            # raised UnboundLocalError; fail with a clear error instead.
            raise ValueError(f"Unsupported non-base64 data URI: {image[:32]}...")
    else:
        image_obj = Image.open(image)
    return image_obj.convert("RGB")
Comment on lines +30 to +50
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would that be possible to create some kind of mixin class to handle the duplicate code? such as:

class MediaProcessingMixin:
    """Mixin providing common media processing functionality"""
    
    def _process_image_from_source(self, image, **kwargs):
        """Shared image processing logic"""
        if isinstance(image, Image.Image):
            image_obj = image
        elif image.startswith("http://") or image.startswith("https://"):
            with requests.get(image, stream=True) as response:
                response.raise_for_status()
                with BytesIO(response.content) as bio:
                    image_obj = copy.deepcopy(Image.open(bio))
        elif image.startswith("file://"):
            image_obj = Image.open(image[7:])
        elif image.startswith("data:image"):
            if "base64," in image:
                _, base64_data = image.split("base64,", 1)
                data = base64.b64decode(base64_data)
                with BytesIO(data) as bio:
                    image_obj = copy.deepcopy(Image.open(bio))
        else:
            image_obj = Image.open(image)
        return image_obj.convert("RGB")

# Now each preprocessor can inherit from both the base class AND the mixin
class Gemma3Preprocessor(BasicPreprocessor, MediaProcessingMixin):
    def process_image(self, image, **kwargs):
        return self._process_image_from_source(image, **kwargs)

class InternVLPreprocessor(BasicPreprocessor, MediaProcessingMixin):
    def process_image(self, image, **kwargs):
        return self._process_image_from_source(image, **kwargs)

class KimiVLPreprocessor(BasicPreprocessor, MediaProcessingMixin):
    def process_image(self, image, **kwargs):
        return self._process_image_from_source(image, **kwargs)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your advice, I will solve this.


def process_video(self, video, **kwargs):
    """Gemma3 is an image-text model; video input is unsupported."""
    # Typo fix: message previously read "dose not".
    raise ValueError("Gemma3 does not support the video")

def process_audio(self, audio, **kwargs):
    """Gemma3 is an image-text model; audio input is unsupported."""
    raise ValueError("Gemma3 does not support the audio")

def __call__(self, messages, row_dict):
    """Template the chat messages, encode text + images, and attach the
    multi-modal fields to *row_dict*.

    Returns:
        Tuple of (row_dict, model_inputs, input_ids, attention_mask, raw_prompt).
    """
    raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

    multi_modal_data = {}
    images = None
    if self.image_key in row_dict:
        raw_images = row_dict.pop(self.image_key)
        images = [self.process_image(img) for img in raw_images]
        multi_modal_data["image"] = images

    model_inputs = self.processor(text=[raw_prompt], images=images, return_tensors="pt")
    input_ids = model_inputs.pop("input_ids")
    attention_mask = model_inputs.pop("attention_mask")
    # Gemma's processor emits token_type_ids, which training does not use.
    model_inputs.pop("token_type_ids", None)

    row_dict["multi_modal_data"] = multi_modal_data
    # multi_modal_inputs must be a plain dict, not a BatchFeature.
    row_dict["multi_modal_inputs"] = dict(model_inputs)
    return row_dict, model_inputs, input_ids, attention_mask, raw_prompt
146 changes: 146 additions & 0 deletions verl/utils/dataset/preprocessor/internvl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
The InternVL preprocessor used for the multi-modal models.
"""
import base64
import copy
from PIL import Image
import requests
from io import BytesIO
from qwen_vl_utils import fetch_video

from .base_processor import BasicPreprocessor
from .registry import PREPROCESSOR_REGISTER

__all__ = ["InternVLPreprocessor"]

VIDEO_FORMAT_HELP = """Currently, we only support the video formats introduced in qwen2-vl.
Refer to https://github.com/QwenLM/Qwen2.5-VL?tab=readme-ov-file#using---transformers-to-chat.

eg.
{
"type": "video",
"video": [
"file:///path/to/frame1.jpg",
"file:///path/to/frame2.jpg"
]
}

{
"type": "video",
"video": "file:///path/to/video.mp4"
}
# Defaults to fps=2, min_frames=4, max_frames=768

{
"type": "video",
"video": "file:///path/to/video.mp4",
"fps": 2,
"min_frames": 1,
"max_frames": 32
}
"""

@PREPROCESSOR_REGISTER.register()
class InternVLPreprocessor(BasicPreprocessor):
    """Preprocessor for InternVL models (image + video; audio unsupported)."""

    def __init__(self, processor, image_key="image", video_key="video", **kwargs):
        super().__init__(processor, image_key=image_key, video_key=video_key)

    def process_image(self, image, **kwargs):
        """Load *image* from a PIL object, URL, ``file://`` path, data URI, or path.

        Returns:
            PIL.Image.Image: The loaded image converted to RGB.

        Raises:
            ValueError: For a ``data:image`` URI that is not base64-encoded.
        """
        if isinstance(image, Image.Image):
            image_obj = image
        elif image.startswith("http://") or image.startswith("https://"):
            # Stream the download and deep-copy out of BytesIO to fix a
            # memory-leak issue while using BytesIO.
            with requests.get(image, stream=True) as response:
                response.raise_for_status()
                with BytesIO(response.content) as bio:
                    image_obj = copy.deepcopy(Image.open(bio))
        elif image.startswith("file://"):
            image_obj = Image.open(image[7:])
        elif image.startswith("data:image"):
            if "base64," in image:
                _, base64_data = image.split("base64,", 1)
                data = base64.b64decode(base64_data)
                # fix memory leak issue while using BytesIO
                with BytesIO(data) as bio:
                    image_obj = copy.deepcopy(Image.open(bio))
            else:
                # Previously this path fell through with `image_obj` unbound and
                # raised UnboundLocalError; fail with a clear error instead.
                raise ValueError(f"Unsupported non-base64 data URI: {image[:32]}...")
        else:
            image_obj = Image.open(image)
        return image_obj.convert("RGB")

    def process_video(self, video, **kwargs):
        """Convert a video dict into a [n_frames, 3, H, W] tensor.

        TODO: add video sample FPS in a future MR.
        """
        nframes = kwargs.get("nframes", None)
        fps = kwargs.get("fps", None)
        # BUG FIX: the previous version had trailing commas on the next two
        # assignments, making each value a 1-tuple like (None,) — always
        # truthy — so bogus min/max frame limits were written into the video
        # dict whenever `fps` was set.
        fps_min_frames = kwargs.get("fps_min_frames", None)
        fps_max_frames = kwargs.get("fps_max_frames", None)
        if not isinstance(video, dict) or "video" not in video:
            raise NotImplementedError(VIDEO_FORMAT_HELP)
        assert nframes is None or fps is None, "Can't use both `nframes` and `fps`"

        # Shallow copy, since we might want to add some sampling keys below.
        video = dict(video)

        contains_sampling_rules = "nframes" in video or "fps" in video
        if not contains_sampling_rules:
            if nframes is not None:
                video["nframes"] = nframes
            elif fps is not None:
                video["fps"] = fps
                if fps_min_frames is not None:
                    video["min_frames"] = fps_min_frames
                if fps_max_frames is not None:
                    video["max_frames"] = fps_max_frames
        return fetch_video(video)

    def process_audio(self, audio, **kwargs):
        # Typo fix: message previously read "dose not".
        raise ValueError("InternVL does not support audio")

    def __call__(self, messages, row_dict):
        """Template messages, encode text/images/videos, and attach the
        multi-modal fields to *row_dict*.

        Returns:
            Tuple of (row_dict, model_inputs, input_ids, attention_mask, raw_prompt).
        """
        raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        multi_modal_data = {}

        images = None
        if self.image_key in row_dict:
            images = [self.process_image(image) for image in row_dict.pop(self.image_key)]
            multi_modal_data["image"] = images

        videos = None
        if self.video_key in row_dict:
            videos = [self.process_video(video) for video in row_dict.pop(self.video_key)]
            multi_modal_data["video"] = [video.numpy() for video in videos]

        raw_prompt_convert = raw_prompt
        if "<image>" in raw_prompt_convert:
            # Older chat-template versions emit the fake image token <image>;
            # the processor expects <IMG_CONTEXT>.
            raw_prompt_convert = raw_prompt_convert.replace("<image>", "<IMG_CONTEXT>")
        model_inputs = self.processor(text=[raw_prompt_convert], images=images, videos=videos, return_tensors="pt")
        input_ids = model_inputs.pop("input_ids")
        attention_mask = model_inputs.pop("attention_mask")

        if "second_per_grid_ts" in model_inputs:
            model_inputs.pop("second_per_grid_ts")

        # There's a trap here: multi_modal_inputs has to be a plain dict, not a BatchFeature.
        row_dict["multi_modal_data"] = multi_modal_data
        row_dict["multi_modal_inputs"] = dict(model_inputs)

        # second_per_grid_ts isn't used for training, just for mrope.
        row_dict["multi_modal_inputs"].pop("second_per_grid_ts", None)
        return row_dict, model_inputs, input_ids, attention_mask, raw_prompt
Loading