diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py
index 019a617acd..45d7e8171a 100644
--- a/lmdeploy/serve/openai/api_server.py
+++ b/lmdeploy/serve/openai/api_server.py
@@ -15,7 +15,7 @@
 
 from lmdeploy.archs import get_task
 from lmdeploy.messages import (GenerationConfig, LogitsProcessor,
-                               PytorchEngineConfig, TurbomindEngineConfig)
+                               PytorchEngineConfig, TurbomindEngineConfig, VisionConfig)
 from lmdeploy.model import ChatTemplateConfig
 from lmdeploy.serve.async_engine import AsyncEngine
 from lmdeploy.serve.openai.protocol import (  # noqa: E501
@@ -1054,6 +1054,7 @@ def serve(model_path: str,
 
     _, pipeline_class = get_task(model_path)
 
+    vision_config = VisionConfig(kwargs.get("vision_max_batch_size", 1))
     VariableInterface.async_engine = pipeline_class(
         model_path=model_path,
         model_name=model_name,
@@ -1061,6 +1062,7 @@ def serve(model_path: str,
         backend_config=backend_config,
         chat_template_config=chat_template_config,
         max_log_len=max_log_len,
+        vision_config=vision_config,
         **kwargs)
 
     if proxy_url is not None:
diff --git a/lmdeploy/vl/model/llava_hf.py b/lmdeploy/vl/model/llava_hf.py
index 66faf4f467..4be3c10387 100644
--- a/lmdeploy/vl/model/llava_hf.py
+++ b/lmdeploy/vl/model/llava_hf.py
@@ -24,6 +24,33 @@ def build_model(self):
             warnings.simplefilter('ignore')
             from transformers import LlavaForConditionalGeneration
             model = LlavaForConditionalGeneration._from_config(self.hf_config)
+
+        if not getattr(model.config, "tie_word_embeddings", False):
+            if not self.with_llm:
+                del model.language_model
+                for key in ['language_model']:
+                    setattr(model, key, None)
+            else:
+                self.vl_model = model
+            with disable_logging():
+                load_checkpoint_and_dispatch(
+                    model=model,
+                    max_memory=self.max_memory,
+                    checkpoint=self.model_path,
+                    device_map='auto' if not self.with_llm else {'': 'cpu'},
+                    no_split_module_classes=['CLIPEncoderLayer', 'SiglipEncoderLayer'],
+                    dtype=torch.half)
+        else:
+            # fix for llava-hf/llava-interleave-qwen-7b-hf: the llm can only be
+            # removed after loading, because llm.get_output_embeddings() has to be called
+            with disable_logging():
+                load_checkpoint_and_dispatch(
+                    model=model,
+                    max_memory=self.max_memory,
+                    checkpoint=self.model_path,
+                    device_map='auto' if not self.with_llm else {'': 'cpu'},
+                    no_split_module_classes=['CLIPEncoderLayer', 'SiglipEncoderLayer'],
+                    dtype=torch.half)
             if not self.with_llm:
                 del model.language_model
                 for key in ['language_model']:
@@ -31,14 +58,6 @@ def build_model(self):
             else:
                 self.vl_model = model
 
-        with disable_logging():
-            load_checkpoint_and_dispatch(
-                model=model,
-                max_memory=self.max_memory,
-                checkpoint=self.model_path,
-                device_map='auto' if not self.with_llm else {'': 'cpu'},
-                no_split_module_classes=['CLIPEncoderLayer'],
-                dtype=torch.half)
         model.eval()
         self.model = model
         # processor
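
Note (not part of the patch): a minimal sketch of how the new `vision_max_batch_size` keyword is meant to flow into the vision encoder. It assumes `max_batch_size` is the first positional field of `VisionConfig` (which the positional `VisionConfig(kwargs.get(...))` call in the api_server.py hunk relies on) and that passing `vision_config` to `lmdeploy.pipeline` is the usual VLM pipeline usage.

```python
from lmdeploy import pipeline
from lmdeploy.messages import VisionConfig

# What the patched serve() builds internally from its **kwargs:
# VisionConfig(kwargs.get("vision_max_batch_size", 1)); max_batch_size
# controls how many images the vision encoder processes per forward pass.
vision_config = VisionConfig(max_batch_size=4)

# The async engine ends up constructed roughly like this, with the config
# forwarded alongside the other serve() arguments:
pipe = pipeline('llava-hf/llava-interleave-qwen-7b-hf',
                vision_config=vision_config)
```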
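
A hedged illustration of why the llava_hf.py hunk defers removing `language_model` when embeddings are tied. It assumes, as the patch comment implies, that `LlavaForConditionalGeneration.get_output_embeddings()` delegates to the wrapped language model in the transformers versions targeted here, so the output embeddings can no longer be resolved once the language model has been stripped; the patch therefore loads the checkpoint first and drops the llm afterwards.

```python
import warnings

from accelerate import init_empty_weights
from transformers import AutoConfig, LlavaForConditionalGeneration

config = AutoConfig.from_pretrained('llava-hf/llava-interleave-qwen-7b-hf')

with init_empty_weights(), warnings.catch_warnings():
    warnings.simplefilter('ignore')
    model = LlavaForConditionalGeneration._from_config(config)

print(model.get_output_embeddings())  # resolves via model.language_model

del model.language_model
setattr(model, 'language_model', None)
try:
    model.get_output_embeddings()  # now fails: nothing left to tie against
except AttributeError as err:
    print('cannot resolve output embeddings without the language model:', err)
```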