-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Add the support for more VLMs(Gemma3 and InternVL) #2327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| # Copyright 2024 Bytedance Ltd. and/or its affiliates | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from transformers import AutoProcessor | ||
| from .gemma import Gemma3Preprocessor | ||
| from .internvl import InternVLPreprocessor | ||
| from .qwen_vl import QwenVLPreProcessor | ||
| from .minicpmo import MiniCPMOPreProcessor | ||
| from .kimi_vl import KimiVLPreprocessor | ||
| from .registry import PREPROCESSOR_REGISTER | ||
| import re | ||
|
|
||
|
|
||
def map_processor_to_preprocessor(processor: AutoProcessor):
    """Map a HuggingFace processor instance to its registered preprocessor class.

    The lookup key is derived from the processor class name by replacing the
    ``processor`` suffix with ``preprocessor`` (lower-cased), except for the
    Qwen2/Qwen2.5 VL family, which all share one preprocessor implementation.

    Args:
        processor (AutoProcessor): The HuggingFace processor instance.

    Returns:
        class: The preprocessor class registered under the derived name.

    Raises:
        ValueError: If ``processor`` is not a Processor, or if no preprocessor
            is registered for it (previously ``None`` was silently returned,
            deferring the failure to a confusing downstream error).
    """
    processor_name = processor.__class__.__name__
    if not processor_name.lower().endswith("processor"):
        raise ValueError(f"Source object '{processor_name}' is not a 'Processor'.")
    # Qwen2VLProcessor / Qwen2_5_VLProcessor etc. share the QwenVL preprocessor.
    if re.match(r"Qwen2.*?VLProcessor", processor_name):
        print("QwenVL2 Series will use the QwenVLPreprocessor")
        dest_name = "QwenVLPreprocessor".lower()
    else:
        dest_name = processor_name.lower().replace("processor", "preprocessor")

    dest_class = PREPROCESSOR_REGISTER.get(dest_name)
    if dest_class is None:
        # Fail fast with a descriptive message instead of returning None.
        raise ValueError(f"No preprocessor registered for '{processor_name}' (looked up '{dest_name}').")
    return dest_class
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| # Copyright 2024 Bytedance Ltd. and/or its affiliates | ||
|
|
||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
|
|
||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """ | ||
| The basic preprocessor used for the multi-modal models. | ||
| """ | ||
|
|
||
class BasicPreprocessor:
    """Base class for multi-modal model preprocessors.

    Subclasses implement the modality-specific ``process_*`` hooks; the
    default ``__call__`` wires them into a common chat-template pipeline.
    """

    def __init__(self, processor, image_key="image", video_key="video", audio_key="audio"):
        # HuggingFace processor used for chat templating and tensorization.
        self.processor = processor
        # Keys under which each modality is stored in a dataset row dict.
        self.image_key = image_key
        self.video_key = video_key
        self.audio_key = audio_key

    def process_image(self, image, **kwargs):
        """Convert one raw image entry into the form the processor expects."""
        raise NotImplementedError("The process_image method must be implemented")

    def process_video(self, video, **kwargs):
        """Convert one raw video entry into the form the processor expects."""
        raise NotImplementedError("The process_video method must be implemented")

    def process_audio(self, audio, **kwargs):
        """Convert one raw audio entry into the form the processor expects."""
        # Fixed: the message previously referenced process_video by mistake.
        raise NotImplementedError("The process_audio method must be implemented")

    def __call__(self, messages, row_dict):
        """Render the chat template and tensorize any media found in ``row_dict``.

        Args:
            messages: Chat messages in the processor's expected format.
            row_dict (dict): Dataset row; media keys are popped in place.

        Returns:
            tuple: ``(row_dict, model_inputs, input_ids, attention_mask, raw_prompt)``.
        """
        raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        multi_modal_data = {}

        images = None
        if self.image_key in row_dict:
            images = [self.process_image(image) for image in row_dict.pop(self.image_key)]
            multi_modal_data["image"] = images

        videos = None
        if self.video_key in row_dict:
            videos = [self.process_video(video) for video in row_dict.pop(self.video_key)]
            multi_modal_data["video"] = [video.numpy() for video in videos]
        model_inputs = self.processor(text=[raw_prompt], images=images, videos=videos, return_tensors="pt")
        input_ids = model_inputs.pop("input_ids")
        attention_mask = model_inputs.pop("attention_mask")

        # second_per_grid_ts isn't used for training, just for mrope.
        model_inputs.pop("second_per_grid_ts", None)
        # There's a trap here: multi_modal_inputs has to be a dict, not a BatchFeature.
        row_dict["multi_modal_data"] = multi_modal_data
        row_dict["multi_modal_inputs"] = dict(model_inputs)

        row_dict["multi_modal_inputs"].pop("second_per_grid_ts", None)
        return row_dict, model_inputs, input_ids, attention_mask, raw_prompt
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| # Copyright 2024 Bytedance Ltd. and/or its affiliates | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import base64 | ||
| import copy | ||
| from PIL import Image | ||
| import requests | ||
| from io import BytesIO | ||
|
|
||
| from .base_processor import BasicPreprocessor | ||
| from .registry import PREPROCESSOR_REGISTER | ||
|
|
||
| __all__ = ["Gemma3Preprocessor"] | ||
|
|
||
| @PREPROCESSOR_REGISTER.register() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. BTW, I am thinking of moving all model-related code into the same folder, one per model. #2338 (review)
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I think it is a good strategy for the Multi-modality framework. |
||
class Gemma3Preprocessor(BasicPreprocessor):
    """Preprocessor for Gemma3, an image-only multi-modal model."""

    def __init__(self, processor, image_key="image", video_key="video"):
        super().__init__(processor, image_key, video_key)

    def process_image(self, image, **kwargs):
        """Load ``image`` from a PIL object, URL, file path, or base64 data URI.

        Args:
            image: A ``PIL.Image.Image`` or a string source (http(s) URL,
                ``file://`` path, ``data:image`` URI, or plain file path).

        Returns:
            PIL.Image.Image: The image converted to RGB.

        Raises:
            ValueError: If a ``data:image`` URI carries no base64 payload
                (previously this fell through with ``image_obj`` unbound,
                producing an UnboundLocalError).
        """
        if isinstance(image, Image.Image):
            image_obj = image
        elif image.startswith("http://") or image.startswith("https://"):
            # fix memory leak issue while using BytesIO
            with requests.get(image, stream=True) as response:
                response.raise_for_status()
                with BytesIO(response.content) as bio:
                    image_obj = copy.deepcopy(Image.open(bio))
        elif image.startswith("file://"):
            image_obj = Image.open(image[7:])
        elif image.startswith("data:image"):
            if "base64," not in image:
                raise ValueError(f"Unsupported data URI (expected base64 payload): {image[:32]}...")
            _, base64_data = image.split("base64,", 1)
            data = base64.b64decode(base64_data)
            # fix memory leak issue while using BytesIO
            with BytesIO(data) as bio:
                image_obj = copy.deepcopy(Image.open(bio))
        else:
            image_obj = Image.open(image)
        return image_obj.convert("RGB")
Comment on lines
+30
to
+50
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it be possible to create some kind of mixin class to handle the duplicated code? Such as: class MediaProcessingMixin:
"""Mixin providing common media processing functionality"""
def _process_image_from_source(self, image, **kwargs):
"""Shared image processing logic"""
if isinstance(image, Image.Image):
image_obj = image
elif image.startswith("http://") or image.startswith("https://"):
with requests.get(image, stream=True) as response:
response.raise_for_status()
with BytesIO(response.content) as bio:
image_obj = copy.deepcopy(Image.open(bio))
elif image.startswith("file://"):
image_obj = Image.open(image[7:])
elif image.startswith("data:image"):
if "base64," in image:
_, base64_data = image.split("base64,", 1)
data = base64.b64decode(base64_data)
with BytesIO(data) as bio:
image_obj = copy.deepcopy(Image.open(bio))
else:
image_obj = Image.open(image)
return image_obj.convert("RGB")
# Now each preprocessor can inherit from both the base class AND the mixin
class Gemma3Preprocessor(BasicPreprocessor, MediaProcessingMixin):
def process_image(self, image, **kwargs):
return self._process_image_from_source(image, **kwargs)
class InternVLPreprocessor(BasicPreprocessor, MediaProcessingMixin):
def process_image(self, image, **kwargs):
return self._process_image_from_source(image, **kwargs)
class KimiVLPreprocessor(BasicPreprocessor, MediaProcessingMixin):
def process_image(self, image, **kwargs):
return self._process_image_from_source(image, **kwargs)
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for your advice, I will solve this. |
||
|
|
||
| def process_video(self, video, **kwargs): | ||
| raise ValueError("Gemma3 dose not support the video") | ||
|
|
||
| def process_audio(self, audio, **kwargs): | ||
| raise ValueError("Gemma3 dose not support the audio") | ||
|
|
||
| def __call__(self, messages, row_dict): | ||
| raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) | ||
| multi_modal_data = {} | ||
|
|
||
| images = None | ||
| if self.image_key in row_dict: | ||
| images = [self.process_image(image) for image in row_dict.pop(self.image_key)] | ||
| multi_modal_data["image"] = images | ||
| model_inputs = self.processor(text=[raw_prompt], images=images, return_tensors="pt") | ||
| input_ids = model_inputs.pop("input_ids") | ||
| attention_mask = model_inputs.pop("attention_mask") | ||
| if 'token_type_ids' in model_inputs: | ||
| model_inputs.pop("token_type_ids") | ||
| row_dict["multi_modal_data"] = multi_modal_data | ||
| row_dict["multi_modal_inputs"] = dict(model_inputs) | ||
| return row_dict, model_inputs, input_ids, attention_mask, raw_prompt | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,146 @@ | ||
| # Copyright 2024 Bytedance Ltd. and/or its affiliates | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
|
|
||
| """ | ||
| The InternVL preprocessor used for the multi-modal models. | ||
| """ | ||
| import base64 | ||
| import copy | ||
| from PIL import Image | ||
| import requests | ||
| from io import BytesIO | ||
| from qwen_vl_utils import fetch_video | ||
|
|
||
| from .base_processor import BasicPreprocessor | ||
| from .registry import PREPROCESSOR_REGISTER | ||
|
|
||
| __all__ = ["InternVLPreprocessor"] | ||
|
|
||
# Guidance text raised (via NotImplementedError) when a video entry is not
# supplied in the qwen2-vl dict format expected by process_video below.
VIDEO_FORMAT_HELP = """Currently, we only support the video formats introduced in qwen2-vl.
Refer to https://github.com/QwenLM/Qwen2.5-VL?tab=readme-ov-file#using---transformers-to-chat.

eg.
{
    "type": "video",
    "video": [
        "file:///path/to/frame1.jpg",
        "file:///path/to/frame2.jpg"
    ]
}

{
    "type": "video",
    "video": "file:///path/to/video.mp4"
}
# Defaults to fps=2, min_frames=4, max_frames=768

{
    "type": "video",
    "video": "file:///path/to/video.mp4",
    "fps": 2,
    "min_frames": 1,
    "max_frames": 32
}
"""
|
|
||
@PREPROCESSOR_REGISTER.register()
class InternVLPreprocessor(BasicPreprocessor):
    """Preprocessor for InternVL: images plus qwen2-vl style video inputs."""

    def __init__(self, processor, image_key="image", video_key="video", **kwargs):
        super().__init__(processor, image_key=image_key, video_key=video_key)

    def process_image(self, image, **kwargs):
        """Load ``image`` from a PIL object, URL, file path, or base64 data URI.

        Returns:
            PIL.Image.Image: The image converted to RGB.

        Raises:
            ValueError: If a ``data:image`` URI carries no base64 payload
                (previously ``image_obj`` was left unbound on that path).
        """
        if isinstance(image, Image.Image):
            image_obj = image
        elif image.startswith("http://") or image.startswith("https://"):
            # fix memory leak issue while using BytesIO
            with requests.get(image, stream=True) as response:
                response.raise_for_status()
                with BytesIO(response.content) as bio:
                    image_obj = copy.deepcopy(Image.open(bio))
        elif image.startswith("file://"):
            image_obj = Image.open(image[7:])
        elif image.startswith("data:image"):
            if "base64," not in image:
                raise ValueError(f"Unsupported data URI (expected base64 payload): {image[:32]}...")
            _, base64_data = image.split("base64,", 1)
            data = base64.b64decode(base64_data)
            # fix memory leak issue while using BytesIO
            with BytesIO(data) as bio:
                image_obj = copy.deepcopy(Image.open(bio))
        else:
            image_obj = Image.open(image)
        return image_obj.convert("RGB")

    def process_video(self, video, **kwargs):
        """Converts a video dict into a [n_frames, 3, H, W] tensor

        Add video sample FPS in a future MR
        """
        nframes = kwargs.get("nframes", None)
        fps = kwargs.get("fps", None)
        # BUG FIX: the original lines ended with trailing commas, turning these
        # into 1-tuples, so the `is not None` checks below always passed and
        # wrote tuple garbage into video["min_frames"]/["max_frames"].
        fps_min_frames = kwargs.get("fps_min_frames", None)
        fps_max_frames = kwargs.get("fps_max_frames", None)
        if not isinstance(video, dict) or "video" not in video:
            raise NotImplementedError(VIDEO_FORMAT_HELP)
        assert nframes is None or fps is None, "Can't use both `nframes` or `fps`"

        # Shallow copy... since we might want to add some keys
        video = dict(video)

        contains_sampling_rules = "nframes" in video or "fps" in video
        if not contains_sampling_rules:
            if nframes is not None:
                video["nframes"] = nframes
            elif fps is not None:
                video["fps"] = fps
                if fps_min_frames is not None:
                    video["min_frames"] = fps_min_frames
                if fps_max_frames is not None:
                    video["max_frames"] = fps_max_frames
        return fetch_video(video)

    def process_audio(self, audio, **kwargs):
        """Reject audio input: InternVL handles images and video only."""
        # Fixed typo in the error message: "dose" -> "does".
        raise ValueError("InternVL does not support audio")

    def __call__(self, messages, row_dict):
        """Render the chat prompt, tensorize media, and split out ids/mask.

        Returns:
            tuple: ``(row_dict, model_inputs, input_ids, attention_mask, raw_prompt)``.
        """
        raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        multi_modal_data = {}

        images = None
        if self.image_key in row_dict:
            images = [self.process_image(image) for image in row_dict.pop(self.image_key)]
            multi_modal_data["image"] = images

        videos = None
        if self.video_key in row_dict:
            videos = [self.process_video(video) for video in row_dict.pop(self.video_key)]
            multi_modal_data["video"] = [video.numpy() for video in videos]
        raw_prompt_convert = raw_prompt
        if "<image>" in raw_prompt_convert:
            # In older versions the fake image token <image> is used;
            # InternVL expects <IMG_CONTEXT> as the image placeholder.
            raw_prompt_convert = raw_prompt_convert.replace("<image>", "<IMG_CONTEXT>")
        model_inputs = self.processor(text=[raw_prompt_convert], images=images, videos=videos, return_tensors="pt")
        input_ids = model_inputs.pop("input_ids")
        attention_mask = model_inputs.pop("attention_mask")

        model_inputs.pop("second_per_grid_ts", None)

        # There's a trap here, multi_modal_inputs has to be a dict, not BatchFeature
        row_dict["multi_modal_data"] = multi_modal_data
        row_dict["multi_modal_inputs"] = dict(model_inputs)

        # second_per_grid_ts isn't used for training, just for mrope
        row_dict["multi_modal_inputs"].pop("second_per_grid_ts", None)
        return row_dict, model_inputs, input_ids, attention_mask, raw_prompt
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hello, I have one question here. I didn't see any monkey-patching code for the InternVL model here. Does that mean InternVL does not require custom code, or that sequence parallel is not applicable for InternVL now?
Thanks a lot!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
InternVL does not have a special design that requires monkey patching. However, the vision model of InternVL does generate a high memory cost. For example, InternVL-Chat-V1.5, a 26B model, requires about 50G of memory for model parameters in BF16 format, and considering the additional overhead during training, it requires around 100-150G. The special requirement for vision encoder may need some discussion.