Refactor VLM modules for internvl-llava (#2797)
lvhan028 authored Nov 22, 2024
1 parent 45cf22d commit 38eec0d
Showing 2 changed files with 22 additions and 43 deletions.
59 changes: 20 additions & 39 deletions lmdeploy/vl/model/internvl_llava.py
@@ -2,14 +2,13 @@

import warnings
from contextlib import contextmanager
from typing import List, Union
from typing import Dict, List

import torch
from PIL.Image import Image
from transformers import AutoConfig, AutoModelForCausalLM

from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
from lmdeploy.vl.model.llava import VISION_MODELS, LlavaVisionModel
from lmdeploy.vl.model.utils import rewrite_ctx

from .utils import disable_logging, disable_transformers_logging
@@ -18,14 +17,13 @@


def check_llava_install():
"""check llava install."""
try:
from llava.model.multimodal_encoder.clip_encoder import \
InternVisionModel # noqa: F401
except ImportError:
raise ImportError(
'To use LlavaVLModel, please install llava by '
'pip install "git+https://github.com/OpenGVLab/InternVL#subdirectory=internvl_chat_llava" --no-deps' # noqa: E501
'`pip install git+https://github.com/OpenGVLab/InternVL#subdirectory=internvl_chat_llava --no-deps`' # noqa: E501
)


@@ -65,7 +63,7 @@ def init_empty_vit():


@VISION_MODELS.register_module()
class InternVLLlavaVisionModel(VisonModel):
class InternVLLlavaVisionModel(LlavaVisionModel):
"""Llava visual model."""

@classmethod
@@ -78,9 +76,11 @@ def match(cls, config: AutoConfig):
return True
return False

def build_preprocessor(self):
return super().build_preprocessor()

def build_model(self):
"""build model & load weights."""
# check llava install
check_llava_install()
# currently, only support llava llama
from llava.model.language_model.llava_llama import ( # noqa
@@ -137,42 +137,23 @@ def build_model(self):
self.vision_tower = model.model.vision_tower.eval()
self.mm_projector = model.model.mm_projector.eval()

def encode_images(self, images: torch.Tensor) -> torch.Tensor:
"""encode images."""
image_features = self.vision_tower(images)
image_features = self.mm_projector(image_features)
return image_features

def preprocess(
self,
images: List[Image]) -> Union[torch.Tensor, List[torch.Tensor]]:
"""preprocess."""
# TODO: gpu processor
from llava.mm_utils import process_images
images = [x.convert('RGB') for x in images]
image_processor = self.vision_tower.image_processor
outputs = process_images(images, image_processor, self.config)
return outputs
def preprocess(self, messages: List[Dict]) -> List[Dict]:
"""refer to `super().preprocess() for spec."""
return super().preprocess(messages)

@torch.no_grad()
def forward(self, images: List[Image]) -> List[torch.Tensor]:
"""forward."""
images = self.preprocess(images)
if isinstance(images, list):
images = [
x.to(self.vision_tower.device, dtype=torch.float16)
for x in images
]
else:
images = images.to(self.vision_tower.device, dtype=torch.float16)

if type(images) is list or images.ndim == 5:
concat_images = torch.cat([image for image in images], dim=0)
image_features = self.encode_images(concat_images)
split_sizes = [image.shape[0] for image in images]
def forward(self, inputs: List[Dict]) -> List[torch.Tensor]:
pixel_values = [x['pixel_values'] for x in inputs]
split_sizes = [x.shape[0] for x in pixel_values]
pixel_values = torch.cat(pixel_values, dim=0)
pixel_values = pixel_values.to(device=self.vision_tower.device,
dtype=torch.float16)

if pixel_values.ndim == 5:
image_features = self.encode_images(pixel_values)
image_features = torch.split(image_features, split_sizes, dim=0)
image_features = [x.flatten(0, 1) for x in image_features]
else:
image_features = self.encode_images(images)
image_features = self.encode_images(pixel_values)
image_features = [x for x in image_features]
return image_features
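
Note: the refactored forward() above consumes the preprocessor output, a list of dicts carrying 'pixel_values', instead of raw PIL images. Below is a minimal, self-contained sketch of that concat-then-split bookkeeping; the dummy encoder, the 3x336x336 tile shape, and the 576x64 feature dimensions are assumptions for illustration only and do not come from this commit.

import torch
from typing import Dict, List


def encode_images(images: torch.Tensor) -> torch.Tensor:
    # stand-in for vision_tower + mm_projector; token/feature dims are made up
    return torch.rand(images.shape[0], 576, 64)


def forward_sketch(inputs: List[Dict]) -> List[torch.Tensor]:
    # gather per-message pixel_values and remember how many images each holds
    pixel_values = [x['pixel_values'] for x in inputs]
    split_sizes = [x.shape[0] for x in pixel_values]
    # one batched pass through the (stand-in) encoder
    pixel_values = torch.cat(pixel_values, dim=0)
    if pixel_values.ndim == 5:
        feats = torch.split(encode_images(pixel_values), split_sizes, dim=0)
        return [x.flatten(0, 1) for x in feats]
    return list(encode_images(pixel_values))


outputs = forward_sketch([{'pixel_values': torch.rand(2, 3, 336, 336)},
                          {'pixel_values': torch.rand(1, 3, 336, 336)}])
print([o.shape for o in outputs])  # one (576, 64) tensor per image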
6 changes: 2 additions & 4 deletions lmdeploy/vl/model/llava.py
@@ -30,10 +30,8 @@ def check_llava_install():

def _clip_vision_tower_load_model(self, **kwargs):
logger.info(f'CLIPVisionTower.load_model: {self.vision_tower_name}')
from transformers import (CLIPImageProcessor, CLIPVisionConfig,
CLIPVisionModel)
self.image_processor = CLIPImageProcessor.from_pretrained(
self.vision_tower_name)
from transformers import CLIPVisionConfig, CLIPVisionModel

config = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
self.vision_tower = CLIPVisionModel._from_config(config=config)
self.vision_tower.requires_grad_(False)
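
Note: in llava.py the CLIPImageProcessor is no longer created inside _clip_vision_tower_load_model, presumably because image preprocessing is handled by the model's build_preprocessor(); only the vision tower itself is built from its config. A minimal sketch of that config-only instantiation follows; the tower name 'openai/clip-vit-large-patch14-336' is an example assumption, not a value from the diff.

from transformers import CLIPVisionConfig, CLIPVisionModel

# example tower name (assumption, not taken from this commit)
vision_tower_name = 'openai/clip-vit-large-patch14-336'

# fetch the config only; no pretrained weights are downloaded here
config = CLIPVisionConfig.from_pretrained(vision_tower_name)
# randomly initialised skeleton, to be filled by the checkpoint loader later
vision_tower = CLIPVisionModel._from_config(config=config)
vision_tower.requires_grad_(False)  # frozen, inference only
print(sum(p.numel() for p in vision_tower.parameters()))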
