[side effect] fix vlm quant failed #2914

Merged · 1 commit merged on Dec 22, 2024
25 changes: 14 additions & 11 deletions lmdeploy/lite/apis/calibrate.py
@@ -11,6 +11,7 @@
from lmdeploy.lite.quantization import CalibrationContext, CalibrationContextV2
from lmdeploy.lite.utils import (collect_target_modules, get_calib_loaders,
load_hf_from_pretrained)
from lmdeploy.vl.model.builder import load_vl_model

LAYER_TYPE_MAP = {
'InternLMForCausalLM': 'InternLMDecoderLayer',
@@ -243,18 +244,20 @@ def calibrate(model: str,
# Load tokenizer and configuration
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)

model = load_hf_from_pretrained(model,
torch_dtype=torch.float16,
trust_remote_code=True)
vl_model = None
if model_type == 'vlm':
vl_model = model
if hasattr(model, 'language_model'):
model = model.language_model
if hasattr(model, 'llm'):
model = model.llm
if model_type == 'llm':
model = load_hf_from_pretrained(model,
torch_dtype=torch.float16,
trust_remote_code=True)
vl_model = None
elif model_type == 'vlm':
vl_model = load_vl_model(model, backend=None, with_llm=True).vl_model
model = vl_model
if hasattr(vl_model, 'language_model'): # deepseek vl
model = vl_model.language_model
if hasattr(vl_model, 'llm'): # MiniCPMV
model = vl_model.llm
model.config.use_cache = False
model = model.half().eval()
model.half().eval()

model_type = type(model).__name__
if model_type not in LAYER_TYPE_MAP or model_type not in NORM_TYPE_MAP:
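The calibration path above now builds the full VLM through `load_vl_model(..., with_llm=True)` and then extracts the inner language model for calibration. Below is a minimal usage sketch of that flow, assuming a wrapper that exposes either `language_model` (DeepSeek-VL) or `llm` (MiniCPM-V); the checkpoint path is a placeholder:

```python
# Illustrative sketch of the new VLM calibration loading path; the checkpoint
# path is a placeholder, not something used by this PR.
import torch
from lmdeploy.vl.model.builder import load_vl_model

vl_model = load_vl_model('/path/to/vlm-checkpoint', backend=None,
                         with_llm=True).vl_model
model = vl_model
if hasattr(vl_model, 'language_model'):  # e.g. DeepSeek-VL wrappers
    model = vl_model.language_model
if hasattr(vl_model, 'llm'):             # e.g. MiniCPM-V wrappers
    model = vl_model.llm
model.config.use_cache = False
model.half().eval()                      # calibration runs in fp16, eval mode
```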
8 changes: 5 additions & 3 deletions lmdeploy/vl/model/base.py
@@ -17,11 +17,13 @@ class VisonModel(ABC):

def __init__(self,
model_path: str,
with_llm: bool = False,
max_memory: Dict[int, int] = None,
hf_config: AutoConfig = None,
backend: str = ''):
"""init."""
self.model_path = model_path
self.with_llm = with_llm
self.max_memory = max_memory
self.backend = backend
if hf_config is None:
@@ -38,11 +40,11 @@ def build_preprocessor(self, ):
raise NotImplementedError()

def build_model(self, ):
"""build model.
"""build the vision part of a VLM model when backend is turbomind.

ONLY implement it when the backend is turbomind engine
But when `with_llm=True`, load the whole VLM model
"""
if self.backend == 'turbomind':
if self.backend == 'turbomind' or self.with_llm:
raise NotImplementedError()

@abstractmethod
13 changes: 10 additions & 3 deletions lmdeploy/vl/model/builder.py
@@ -32,13 +32,16 @@

def load_vl_model(model_path: str,
backend: str,
with_llm: bool = False,
backend_config: Optional[Union[TurbomindEngineConfig,
PytorchEngineConfig]] = None):
"""load visual model.

Args:
model_path(str): the path or repo_id from model hub of the model
backend(str): the name of inference backend
with_llm(bool): load LLM model or not. Set it to False for VLM
inference scenarios and True for VLM quantization
backend_config: the config of the inference engine
"""
if not os.path.exists(model_path):
@@ -49,11 +52,13 @@ def load_vl_model(model_path: str,
download_dir=download_dir)

max_memory = None
tp = getattr(backend_config, 'tp', 1)
max_memory = {i: torch.cuda.mem_get_info(i)[0] for i in range(tp)}
if not with_llm:
tp = getattr(backend_config, 'tp', 1)
max_memory = {i: torch.cuda.mem_get_info(i)[0] for i in range(tp)}

_, hf_config = get_model_arch(model_path)
kwargs = dict(model_path=model_path,
with_llm=with_llm,
max_memory=max_memory,
hf_config=hf_config,
backend=backend)
@@ -63,7 +68,9 @@ def load_vl_model(model_path: str,
logger.info(f'matching vision model: {name}')
model = module(**kwargs)
model.build_preprocessor()
if backend == 'turbomind':
# build the vision part of a VLM model when backend is
# turbomind, or load the whole VLM model when `with_llm==True`
if backend == 'turbomind' or with_llm:
model.build_model()
return model
except Exception:
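With the new `with_llm` flag, `load_vl_model` serves two callers: VLM inference (vision part only, with a per-GPU memory map derived from `tp`) and VLM quantization (whole model, no memory map). A hedged usage sketch of both paths; the checkpoint path and `tp` value are placeholders:

```python
# Usage sketch for the two load_vl_model paths introduced in this PR.
from lmdeploy import TurbomindEngineConfig
from lmdeploy.vl.model.builder import load_vl_model

# Inference: only the vision part is materialized, split across `tp` GPUs.
vision_only = load_vl_model('/path/to/vlm-checkpoint',
                            backend='turbomind',
                            backend_config=TurbomindEngineConfig(tp=1))

# Quantization: the whole VLM is loaded (placed on CPU by the per-model
# classes below), and no per-GPU max_memory map is computed.
full_vlm = load_vl_model('/path/to/vlm-checkpoint',
                         backend=None,
                         with_llm=True).vl_model
```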
16 changes: 11 additions & 5 deletions lmdeploy/vl/model/deepseek.py
@@ -36,11 +36,15 @@ def build_preprocessor(self):
self.model_path).image_processor

def build_model(self):
"""build the vision part of a VLM model when backend is turbomind, or
load the whole VLM model when `self.with_llm==True`"""
from accelerate import init_empty_weights
with init_empty_weights():
warnings.simplefilter('ignore')
model = AutoModelForCausalLM.from_pretrained(self.model_path)
del model.language_model
self.vl_model = model
if not self.with_llm:
del model.language_model

from accelerate.utils import get_balanced_memory, infer_auto_device_map
max_memory = get_balanced_memory(model,
@@ -74,11 +78,13 @@

from accelerate import load_checkpoint_and_dispatch
with disable_logging():
load_checkpoint_and_dispatch(model=model,
checkpoint=self.model_path,
device_map=device_map,
dtype=torch.half)
load_checkpoint_and_dispatch(
model=model,
checkpoint=self.model_path,
device_map=device_map if not self.with_llm else {'': 'cpu'},
dtype=torch.half)

self.model = model.eval()
self.vision_model = model.vision_model.eval()
self.aligner = model.aligner.eval()

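The DeepSeek-VL change above is the template repeated in the remaining model classes (InternVL, the LLaVA variants, MiniGemini): keep `self.vl_model` as a handle to the full model, delete the LLM sub-module only on the vision-only path, and dispatch weights to CPU instead of GPUs when `with_llm=True`. A condensed sketch of that pattern; the use of `AutoModelForCausalLM` and the `language_model` attribute are assumptions for illustration:

```python
# Condensed sketch of the per-model pattern in this PR.
import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoModelForCausalLM


def build_vision_or_full(model_path: str, with_llm: bool):
    with init_empty_weights():
        # weights stay on the meta device here (mirrors deepseek.py above)
        model = AutoModelForCausalLM.from_pretrained(model_path)
    vl_model = model                     # keep a handle to the full VLM
    if not with_llm:
        del model.language_model         # vision-only build for inference
    load_checkpoint_and_dispatch(
        model=model,
        checkpoint=model_path,
        # inference spreads weights over GPUs; quantization keeps them on CPU
        device_map='auto' if not with_llm else {'': 'cpu'},
        dtype=torch.half)
    return vl_model, model.eval()
```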
12 changes: 8 additions & 4 deletions lmdeploy/vl/model/internvl.py
@@ -80,10 +80,11 @@ class InternVLVisionModel(VisonModel):

def __init__(self,
model_path: str,
with_llm: bool = False,
max_memory: Dict[int, int] = None,
hf_config: AutoConfig = None,
backend: str = ''):
super().__init__(model_path, max_memory, hf_config, backend)
super().__init__(model_path, with_llm, max_memory, hf_config, backend)

def build_preprocessor(self):
self.config = self.hf_config
@@ -124,21 +125,24 @@ def build_preprocessor(self):
(force_image_size // patch_size)**2 * (downsample_ratio**2))

def build_model(self):
"""Load model."""
"""build the vision part of a VLM model when backend is turbomind, or
load the whole VLM model when `self.with_llm==True`"""
from accelerate import init_empty_weights
with init_empty_weights():
# transformers below 4.37.0 may raise error about flash_attn
self.config.llm_config.attn_implementation = 'eager'
model = AutoModel.from_config(self.config, trust_remote_code=True)
del model.language_model
self.vl_model = model
if not self.with_llm:
del model.language_model

model.half()
from accelerate import load_checkpoint_and_dispatch
with disable_logging():
load_checkpoint_and_dispatch(
model=model,
checkpoint=self.model_path,
device_map='auto',
device_map='auto' if not self.with_llm else {'': 'cpu'},
max_memory=self.max_memory,
no_split_module_classes=['InternVisionEncoderLayer'],
dtype=torch.half)
15 changes: 9 additions & 6 deletions lmdeploy/vl/model/internvl_llava.py
@@ -80,7 +80,8 @@ def build_preprocessor(self):
return super().build_preprocessor()

def build_model(self):
"""build model & load weights."""
"""build the vision part of a VLM model when backend is turbomind, or
load the whole VLM model when `self.with_llm==True`"""
check_llava_install()
# currently, only support llava llama
from llava.model.language_model.llava_llama import ( # noqa
@@ -98,10 +99,12 @@ def build_model(self):
} # disable vision part quantization
model = AutoModelForCausalLM.from_config(self.config,
trust_remote_code=True)
del model.lm_head
del model.model.embed_tokens
del model.model.layers
del model.model.norm
self.vl_model = model
if not self.with_llm:
del model.lm_head
del model.model.embed_tokens
del model.model.layers
del model.model.norm

with init_empty_vit():
vision_tower = model.get_vision_tower()
@@ -126,7 +129,7 @@ def build_model(self):
model=model,
max_memory=self.max_memory,
checkpoint=self.model_path,
device_map='auto',
device_map='auto' if not self.with_llm else {'': 'cpu'},
no_split_module_classes=['InternVisionEncoderLayer'],
dtype=torch.half)

17 changes: 10 additions & 7 deletions lmdeploy/vl/model/llava.py
@@ -242,7 +242,8 @@ def build_preprocessor(self):
self.n_token_per_image += 1

def build_model(self):
"""build model & load weights."""
"""build the vision part of a VLM model when backend is turbomind, or
load the whole VLM model when `self.with_llm==True`"""
check_llava_install()

self.arch = self.hf_config.architectures[0]
@@ -271,11 +272,13 @@ def build_model(self):
model = AutoModelForCausalLM.from_config(self.config,
trust_remote_code=True)

# remove the LLM part from llava model.
del model.lm_head
del model.model.embed_tokens
del model.model.layers
del model.model.norm
self.vl_model = model
if not self.with_llm:
# remove the LLM part from llava model.
del model.lm_head
del model.model.embed_tokens
del model.model.layers
del model.model.norm

# init empty vision_tower, the embedding layer in CLIPVisionModel
# can't init right under init_empty_weights
@@ -292,7 +295,7 @@ def build_model(self):
model=model,
max_memory=self.max_memory,
checkpoint=self.model_path,
device_map='auto',
device_map='auto' if not self.with_llm else {'': 'cpu'},
no_split_module_classes=['CLIPEncoderLayer'],
dtype=torch.half)

24 changes: 14 additions & 10 deletions lmdeploy/vl/model/llava_hf.py
@@ -32,26 +32,30 @@ def build_preprocessor(self):
self.n_token_per_image += 1

def build_model(self):
"""build the vision part of a VLM model when backend is turbomind, or
load the whole VLM model when `self.with_llm==True`"""
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

with init_empty_weights(), warnings.catch_warnings():
warnings.simplefilter('ignore')
from transformers import LlavaForConditionalGeneration
model = LlavaForConditionalGeneration._from_config(self.hf_config)
del model.language_model
self.vl_model = model
if not self.with_llm:
del model.language_model

# fix for llava-hf/llava-interleave-qwen-7b-hf
setattr(model.config, 'tie_word_embeddings', False)
with disable_logging():
load_checkpoint_and_dispatch(model=model,
max_memory=self.max_memory,
checkpoint=self.model_path,
device_map='auto',
no_split_module_classes=[
'CLIPEncoderLayer',
'SiglipEncoderLayer'
],
dtype=torch.half)
load_checkpoint_and_dispatch(
model=model,
max_memory=self.max_memory,
checkpoint=self.model_path,
device_map='auto' if not self.with_llm else {'': 'cpu'},
no_split_module_classes=[
'CLIPEncoderLayer', 'SiglipEncoderLayer'
],
dtype=torch.half)
model.eval()
self.model = model

8 changes: 6 additions & 2 deletions lmdeploy/vl/model/llava_next.py
@@ -28,9 +28,13 @@ def build_preprocessor(self):
from transformers import LlavaNextForConditionalGeneration
self.model = LlavaNextForConditionalGeneration._from_config(
self.hf_config)
del self.model.language_model
self.vl_model = self.model
if not self.with_llm:
del self.model.language_model

def build_model(self):
"""build the vision part of a VLM model when backend is turbomind, or
load the whole VLM model when `self.with_llm==True`"""
from accelerate import load_checkpoint_and_dispatch
from accelerate.utils import get_balanced_memory, infer_auto_device_map

@@ -58,7 +62,7 @@ def build_model(self):
load_checkpoint_and_dispatch(
model=self.model,
checkpoint=self.model_path,
device_map=device_map,
device_map=device_map if not self.with_llm else {'': 'cpu'},
no_split_module_classes=no_split_module_classes,
dtype=torch.half)
self.model.eval()
14 changes: 9 additions & 5 deletions lmdeploy/vl/model/mini_gemeni.py
@@ -178,6 +178,8 @@ def build_preprocessor(self):
pass

def build_model(self):
"""build the vision part of a VLM model when backend is turbomind, or
load the whole VLM model when `self.with_llm==True`"""
check_mini_gemini_install()
# empty init
from accelerate import init_empty_weights
@@ -201,10 +203,12 @@ def build_model(self):
vision_tower.load_model()
vision_tower_aux = model.get_vision_tower_aux()
vision_tower_aux.load_model()
del model.lm_head
del model.model.embed_tokens
del model.model.layers
del model.model.norm
self.vl_model = model
if not self.with_llm:
del model.lm_head
del model.model.embed_tokens
del model.model.layers
del model.model.norm

from accelerate.utils import get_balanced_memory, infer_auto_device_map
max_memory = get_balanced_memory(
@@ -230,7 +234,7 @@
load_checkpoint_and_dispatch(
model=model,
checkpoint=self.model_path,
device_map=device_map,
device_map=device_map if not self.with_llm else {'': 'cpu'},
no_split_module_classes=['CLIPEncoderLayer', 'ConvNeXtStage'],
dtype=torch.half)
