read the img size from the config (#332)
Co-authored-by: baishihao <[email protected]>
shihaobai and baishihao authored Feb 20, 2024
1 parent 2e6f961 commit 163e845
Showing 4 changed files with 19 additions and 7 deletions.
2 changes: 2 additions & 0 deletions lightllm/models/llava/llava_visual.py
@@ -21,6 +21,8 @@ def load_model(self, weight_dir):
         # load clip vision model by cfg['mm_vision_tower']:
         #   huggingface_name or path_of_clip_relative_to_llava_model_dir
         vision_path = config.get('mm_vision_tower', 'openai/clip-vit-large-patch14-336')
+        if isinstance(vision_path, list):
+            vision_path = vision_path[0]
         if vision_path.startswith("./"):
             vision_path = os.path.join(weight_dir, vision_path)

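For context, a minimal standalone sketch of what the new guard does, assuming a hypothetical config where mm_vision_tower is a list rather than a string (the helper name and sample values are illustrative, not lightllm API):

import os

def resolve_vision_path(config, weight_dir):
    # Same normalization as the patched load_model: accept either a string or a
    # list for mm_vision_tower, then resolve "./" paths against the model dir.
    vision_path = config.get('mm_vision_tower', 'openai/clip-vit-large-patch14-336')
    if isinstance(vision_path, list):
        vision_path = vision_path[0]
    if vision_path.startswith("./"):
        vision_path = os.path.join(weight_dir, vision_path)
    return vision_path

print(resolve_vision_path({'mm_vision_tower': ['./vision_tower']}, '/models/llava-v1.5'))
# -> /models/llava-v1.5/./vision_tower
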
16 changes: 13 additions & 3 deletions lightllm/models/llava/model.py
@@ -1,3 +1,4 @@
+import re
 import json
 import numpy as np
 from lightllm.models.llama.model import LlamaTpPartModel
@@ -8,11 +9,20 @@
 # Warp of the origal tokenizer
 class LlavaTokenizer:

-    def __init__(self, tokenizer):
+    def __init__(self, tokenizer, model_cfg):
         self.tokenizer = tokenizer
         self.image_token = "<image>"
-        # (image_size // patch_size) ** 2: (336 // 14) ** 2
-        self.image_length = 576
+        mm_vision_tower = model_cfg.get('mm_vision_tower', 'openai/clip-vit-large-patch14-336')
+        if isinstance(mm_vision_tower, list):
+            mm_vision_tower = mm_vision_tower[0]
+        mm_vision_tower = mm_vision_tower.split('/')[-1]
+        vision_tower_match = re.match(r'^clip-vit-large-patch(\d+)-(\d+)$', mm_vision_tower)
+        patch_size = int(vision_tower_match.group(1))
+        default_img_size = int(vision_tower_match.group(2))
+        image_size = model_cfg.get("img_size", default_img_size)
+        image_size = model_cfg.get("mm_image_size", image_size)
+        # (image_size // patch_size) ** 2: (336 // 14) ** 2 = 576
+        self.image_length = (image_size // patch_size) ** 2

     # only change the impl of the encode func:
     def encode(self, prompt, multimodal_params: MultimodalParams = None):
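The rewritten LlavaTokenizer derives its image token count from the vision tower name plus optional config overrides instead of hard-coding 576. A standalone sketch of that computation with made-up config values (the helper name and the 448 override are illustrative only):

import re

def llava_image_length(model_cfg):
    # Parse patch size and default image size from a tower name like
    # "openai/clip-vit-large-patch14-336", then let explicit config keys
    # ("img_size", "mm_image_size") override the parsed image size.
    tower = model_cfg.get('mm_vision_tower', 'openai/clip-vit-large-patch14-336')
    if isinstance(tower, list):
        tower = tower[0]
    tower = tower.split('/')[-1]
    match = re.match(r'^clip-vit-large-patch(\d+)-(\d+)$', tower)
    patch_size = int(match.group(1))        # 14 for the default tower
    default_img_size = int(match.group(2))  # 336 for the default tower
    image_size = model_cfg.get("img_size", default_img_size)
    image_size = model_cfg.get("mm_image_size", image_size)
    return (image_size // patch_size) ** 2

print(llava_image_length({}))                      # 576  = (336 // 14) ** 2
print(llava_image_length({"mm_image_size": 448}))  # 1024 = (448 // 14) ** 2
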
4 changes: 2 additions & 2 deletions lightllm/models/qwen_vl/model.py
@@ -9,7 +9,7 @@
 # Warp of the origal tokenizer
 class QWenVLTokenizer:

-    def __init__(self, tokenizer):
+    def __init__(self, tokenizer, model_cfg):
         self.tokenizer = tokenizer
         # <img>: 151857
         self.image_start_tag = tokenizer.image_start_tag
@@ -18,7 +18,7 @@ def __init__(self, tokenizer):
         self.image_end_tag = tokenizer.image_end_tag
         self.image_end_id = tokenizer.img_end_id
         # <imgpad>: 151859
-        self.image_length = 256
+        self.image_length = model_cfg['visual'].get("n_queries", 256)

     def _list_find(self, input_list, target, start_idx):
         cur_list = input_list[start_idx:]
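Likewise, the Qwen-VL image token count is now read from the config's "visual" section when available. A short sketch with a made-up config (the n_queries value is illustrative):

model_cfg = {"model_type": "qwen", "visual": {"n_queries": 1024}}
image_length = model_cfg['visual'].get("n_queries", 256)  # falls back to 256 if absent
print(image_length)  # 1024
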
4 changes: 2 additions & 2 deletions lightllm/server/tokenizer.py
@@ -70,9 +70,9 @@ def get_tokenizer(

     model_cfg, _ = PretrainedConfig.get_config_dict(tokenizer_name)
     if model_cfg["model_type"] == "llava":
-        tokenizer = LlavaTokenizer(tokenizer)
+        tokenizer = LlavaTokenizer(tokenizer, model_cfg)
     elif model_cfg["model_type"] == "qwen" and "visual" in model_cfg:
-        tokenizer = QWenVLTokenizer(tokenizer)
+        tokenizer = QWenVLTokenizer(tokenizer, model_cfg)

     if not isinstance(tokenizer, PreTrainedTokenizerFast):
         logger.info(
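The model_cfg handed to both wrappers is the raw config dict, so the new size keys are read straight from the model's config.json. A hedged sketch of fetching that dict with transformers (the repo id below is an assumed example, not taken from this commit):

from transformers import PretrainedConfig

# get_config_dict returns (config_dict, unused_kwargs); the dict mirrors config.json,
# which is where keys like mm_vision_tower, img_size and visual.n_queries live.
model_cfg, _ = PretrainedConfig.get_config_dict("liuhaotian/llava-v1.5-13b")
print(model_cfg["model_type"])  # expected: "llava"
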
