From 4265598babb3817da16fd1822a05c624d2bc638e Mon Sep 17 00:00:00 2001 From: fanxinyao Date: Mon, 17 Mar 2025 12:11:38 +0800 Subject: [PATCH 1/5] support internvl and MMMU dataset evaluation --- benchmark/mmmu/bench_hf.py | 66 +-- benchmark/mmmu/bench_sglang.py | 57 +- benchmark/mmmu/data_utils.py | 2 +- benchmark/mmmu/eval_utils.py | 25 +- benchmark/mmmu/internvl_chat.py | 205 +++++++ benchmark/mmmu/prompt_format.yaml | 10 +- benchmark/mmmu/qwen2vl_chat.py | 131 +++++ python/sglang/srt/configs/model_config.py | 1 + python/sglang/srt/conversation.py | 25 +- python/sglang/srt/hf_transformers_utils.py | 14 +- .../srt/managers/image_processors/internvl.py | 262 +++++++++ python/sglang/srt/managers/scheduler.py | 4 +- .../sglang/srt/managers/tokenizer_manager.py | 4 +- python/sglang/srt/managers/tp_worker.py | 4 +- python/sglang/srt/models/internvl.py | 268 +++++++++ python/sglang/srt/models/internvl_vit.py | 554 ++++++++++++++++++ test/srt/test_vision_openai_server.py | 16 +- 17 files changed, 1511 insertions(+), 137 deletions(-) create mode 100644 benchmark/mmmu/internvl_chat.py create mode 100644 benchmark/mmmu/qwen2vl_chat.py create mode 100644 python/sglang/srt/managers/image_processors/internvl.py create mode 100644 python/sglang/srt/models/internvl.py create mode 100644 python/sglang/srt/models/internvl_vit.py diff --git a/benchmark/mmmu/bench_hf.py b/benchmark/mmmu/bench_hf.py index 0a237b07b58..872f97a516a 100644 --- a/benchmark/mmmu/bench_hf.py +++ b/benchmark/mmmu/bench_hf.py @@ -9,6 +9,7 @@ import argparse import random +import re import torch from data_utils import save_json @@ -16,85 +17,32 @@ EvalArgs, eval_result, get_sampling_params, + load_model, prepare_samples, process_result, ) +from Qwen2VLchat import Qwen2VLchat from tqdm import tqdm -from transformers import AutoModelForImageTextToText, AutoProcessor, GenerationConfig @torch.no_grad() def eval_mmmu(args): eval_args = EvalArgs.from_cli_args(args) - - model = AutoModelForImageTextToText.from_pretrained( - args.model_path, - torch_dtype="auto", - trust_remote_code=True, - ) - model = model.eval().cuda() - - processor = AutoProcessor.from_pretrained( - args.model_path, torch_dtype="auto", device_map="auto" - ) - + model = load_model(args.model_path) + model.build_model() samples = prepare_samples(eval_args) out_samples = dict() - - sampling_params = get_sampling_params(eval_args) - generation_config = GenerationConfig( - max_new_tokens=sampling_params["max_new_tokens"], - do_sample=False, - ) - answer_dict = {} for sample in tqdm(samples): - prompt = sample["final_input_prompt"] - image = sample["image"] - prefix = prompt.split("<")[0] - suffix = prompt.split(">")[1] + image = sample["image_1"] if image is not None: - messages = [ - { - "role": "user", - "content": [ - {"type": "text", "text": prefix}, - { - "type": "image", - "image": image, - }, - {"type": "text", "text": suffix}, - ], - } - ] - text = processor.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - inputs = processor( - text=[text], - images=[image], - padding=True, - return_tensors="pt", - ).to(model.device) - - generated_ids = model.generate( - **inputs, generation_config=generation_config - ) - - response = processor.decode( - generated_ids[0], - skip_special_tokens=True, - clean_up_tokenization_spaces=False, - )[len(text) :] - print(f"response: {response}") + response = model.chat(sample) else: # multiple images actually if sample["question_type"] == "multiple-choice": all_choices = sample["all_choices"] response = 
random.choice(all_choices) - else: response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS" - process_result(response, sample, answer_dict, out_samples) args.output_path = f"{args.model_path}_val_hf.json" diff --git a/benchmark/mmmu/bench_sglang.py b/benchmark/mmmu/bench_sglang.py index ba03dced367..18c6d90fb08 100644 --- a/benchmark/mmmu/bench_sglang.py +++ b/benchmark/mmmu/bench_sglang.py @@ -18,6 +18,7 @@ EvalArgs, eval_result, get_sampling_params, + load_model, prepare_samples, process_result, ) @@ -35,64 +36,30 @@ def eval_mmmu(args): if server_args.chat_template is None: raise ValueError("Chat template must be provided for this benchmark") - + model = load_model(args.model_path) backend = Engine(**dataclasses.asdict(server_args)) - out_samples = dict() - - sampling_params = get_sampling_params(eval_args) - samples = prepare_samples(eval_args) answer_dict = {} for sample in tqdm(samples): - prompt = sample["final_input_prompt"] - image = sample["image"] - buff = BytesIO() - image.save(buff, format="PNG") - base64_str = base64.b64encode(buff.getvalue()).decode("utf-8") - prefix = prompt.split("<")[0] - suffix = prompt.split(">")[1] - request_dict = { - "model": "", - "messages": [ - { - "role": "user", - "content": [ - { - "type": "text", - "text": prefix, - }, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{base64_str}" - }, - }, - { - "type": "text", - "text": suffix, - }, - ], - } - ], - } - - conv = generate_chat_conv( - ChatCompletionRequest(**request_dict), - template_name=server_args.chat_template, - ) - prompt = conv.get_prompt() + image = sample["image_1"] if image is not None: + request_dict = model.build_prompt_sglang(sample) + conv = generate_chat_conv( + ChatCompletionRequest(**request_dict), + template_name=server_args.chat_template, + ) + prompt = conv.get_prompt() + print(f"\033[31m{prompt}\033[0m") gen_out = backend.generate( prompt=prompt, image_data=conv.image_data, - sampling_params=sampling_params, + sampling_params=model.sampling_params, )["text"] - response = gen_out - + print(f"\033[32m{response}\033[0m") else: # multiple images actually if sample["question_type"] == "multiple-choice": all_choices = sample["all_choices"] diff --git a/benchmark/mmmu/data_utils.py b/benchmark/mmmu/data_utils.py index 197e906383e..33446522a4a 100644 --- a/benchmark/mmmu/data_utils.py +++ b/benchmark/mmmu/data_utils.py @@ -187,7 +187,7 @@ def construct_prompt(sample, config): index2ans = {} for option in options: prediction_range.append(start_chr) - example += f"({start_chr}) {option}\n" + example += f"{start_chr}. 
{option}\n" index2ans[start_chr] = option start_chr = chr(ord(start_chr) + 1) empty_prompt_sample_structure = config["multi_choice_example_format"] diff --git a/benchmark/mmmu/eval_utils.py b/benchmark/mmmu/eval_utils.py index 6daf5db6f84..25709e82aed 100644 --- a/benchmark/mmmu/eval_utils.py +++ b/benchmark/mmmu/eval_utils.py @@ -20,13 +20,15 @@ ) from datasets import concatenate_datasets, load_dataset +from internvl_chat import InternVLChat +from qwen2vl_chat import Qwen2VLChat + @dataclasses.dataclass class EvalArgs: backend: str = "engine" seed: int = 42 split: str = "validation" - # Default setting to make the benchmark available on A100 for most 7B models image_pixels_limit: int = 4300000 result_filename: str = "" prompt_format_file: str = "prompt_format.yaml" @@ -39,7 +41,6 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--result-filename", type=str, default=EvalArgs.result_filename ) - parser.add_argument( "--image-pixels-limit", type=int, default=EvalArgs.image_pixels_limit ) @@ -122,18 +123,8 @@ def prepare_samples(eval_args: EvalArgs): samples = [] skip_count = 0 for i, sample in enumerate(dataset): - sample = process_single_sample(sample) sample = construct_prompt(sample, eval_args.config) - image = sample["image"] - width, height = image.size - if width * height >= eval_args.image_pixels_limit: - skip_count += 1 - continue samples.append(sample) - - print( - f"skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset" - ) return samples @@ -548,3 +539,13 @@ def eval_result(model_answer_path, answer_dict): print(f"eval out saved to {out}") print(f"Overall accuracy: {overall_acc}") + + +def load_model(path): + if "Qwen2-VL" in path: + model = Qwen2VLChat(path) + elif "InternVL" in path: + model = InternVLChat(path) + else: + raise Exception("This model is not supported yet.") + return model diff --git a/benchmark/mmmu/internvl_chat.py b/benchmark/mmmu/internvl_chat.py new file mode 100644 index 00000000000..d849f18fbbf --- /dev/null +++ b/benchmark/mmmu/internvl_chat.py @@ -0,0 +1,205 @@ +import base64 +import re +from io import BytesIO + +import torch +import torchvision.transforms as T +from PIL import Image +from torchvision.transforms.functional import InterpolationMode +from transformers import AutoModel, AutoProcessor, AutoTokenizer + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) + + +def build_transform(input_size): + MEAN, STD = IMAGENET_MEAN, IMAGENET_STD + transform = T.Compose( + [ + T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), + T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD), + ] + ) + return transform + + +def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float("inf") + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + +def dynamic_preprocess( + image, min_num=1, max_num=12, image_size=448, use_thumbnail=False +): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio 
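+    # The grids enumerated below are all (columns, rows) pairs whose tile
+    # count lies in [min_num, max_num]; the grid whose aspect ratio is closest
+    # to the input image is selected further down. For example, an 800x600
+    # image (ratio ~1.33) with max_num=12 picks the (4, 3) grid, so the image
+    # is resized to 1792x1344 and split into twelve 448x448 tiles (plus one
+    # thumbnail tile when use_thumbnail is True).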
+ target_ratios = set( + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num + ) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size + ) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images + + +def load_image(image_file, input_size=448, max_num=12, upscale=False): + image = image_file.convert("RGB") + if upscale: + image = image.resize((image.width * 2, image.height * 2), Image.BILINEAR) + transform = build_transform(input_size=input_size) + images = dynamic_preprocess( + image, image_size=input_size, use_thumbnail=True, max_num=max_num + ) + pixel_values = [transform(image) for image in images] + pixel_values = torch.stack(pixel_values) + return pixel_values + + +class InternVLChat: + def __init__( + self, + model_path: str, + max_new_tokens=2048, + temperature=0.01, + repetition_penalty=1.0, + ): + self.model_path = model_path + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_path, trust_remote_code=True, use_fast=False + ) + self.generate_kwargs = dict( + max_new_tokens=max_new_tokens, + top_p=None, + do_sample=False, + temperature=temperature, + ) + self.sampling_params = dict( + max_new_tokens=max_new_tokens, + stop_token_ids=[self.tokenizer.added_tokens_encoder["<|im_end|>"]], + temperature=temperature, + ) + + def build_model(self): + self.model = AutoModel.from_pretrained( + self.model_path, + torch_dtype=torch.bfloat16, + use_flash_attn=True, + trust_remote_code=True, + ) + self.model.eval().cuda() + torch.cuda.empty_cache() + + def build_prompt_hf(self, sample): + prompt = sample["final_input_prompt"] + count = len(sorted(set(int(m) for m in re.findall(r"", prompt)))) + num_patches_list = [] + if count == 1: + prompt = "\n" + prompt + pixel_values = ( + load_image(sample["image_1"], upscale=True).to(torch.bfloat16).cuda() + ) + num_patches_list = [pixel_values.size(0)] + else: + pixel_values_list = [] + for idx in range(1, count + 1): + prompt = f"Image-{count+1-idx}: \n" + prompt + pixel_values = load_image(sample[f"image_{idx}"], upscale=True).to( + torch.bfloat16 + ) + num_patches_list.append(pixel_values.size(0)) + pixel_values_list.append(pixel_values) + pixel_values = torch.cat(pixel_values_list, dim=0).cuda() + print(f"\033[31m{prompt}\033[0m") + return { + "prompt": prompt, + "pixel_values": pixel_values, + "num_patches_list": num_patches_list, + } + + def build_prompt_sglang(self, sample): + prompt = sample["final_input_prompt"] + count = len(sorted(set(int(m) for m in 
re.findall(r"", prompt)))) + request_dict = { + "model": "", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt, + }, + ], + } + ], + } + for idx in range(1, count + 1): + image = sample[f"image_{idx}"] + bytes_io = BytesIO() + image.save(bytes_io, format="PNG") + base64_str = base64.b64encode(bytes_io.getvalue()).decode("utf-8") + request_dict["messages"][0]["content"].append( + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{base64_str}"}, + } + ) + return request_dict + + def chat(self, sample): + inputs = self.build_prompt_hf(sample) + response = self.model.chat( + tokenizer=self.tokenizer, + pixel_values=inputs["pixel_values"], + num_patches_list=inputs["num_patches_list"], + question=inputs["prompt"], + generation_config=self.generate_kwargs, + verbose=True, + ) + # print(f'\033[32m{response}\033[0m') + return response diff --git a/benchmark/mmmu/prompt_format.yaml b/benchmark/mmmu/prompt_format.yaml index 1a0f7211fea..5126c129927 100644 --- a/benchmark/mmmu/prompt_format.yaml +++ b/benchmark/mmmu/prompt_format.yaml @@ -1,15 +1,9 @@ task_instructions: - "" multi_choice_example_format: -- "{} - -{} - -Answer with the option's letter from the given choices directly." +- "{}\n{}Answer with the option's letter from the given choices directly.\n" short_ans_example_format: -- "{} - -Answer the question using a single word or phrase." +- "{}Answer the question using a single word or phrase.\n" temperature: - 0 diff --git a/benchmark/mmmu/qwen2vl_chat.py b/benchmark/mmmu/qwen2vl_chat.py new file mode 100644 index 00000000000..0ac5c675839 --- /dev/null +++ b/benchmark/mmmu/qwen2vl_chat.py @@ -0,0 +1,131 @@ +import base64 +import re +from io import BytesIO + +import torch +from qwen_vl_utils import process_vision_info +from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor + + +class Qwen2VLChat: + def __init__( + self, + model_path: str, + max_new_tokens=2048, + top_p=0.001, + top_k=1, + temperature=0.01, + repetition_penalty=1.0, + ): + self.model_path = model_path + self.processor = Qwen2VLProcessor.from_pretrained(model_path) + self.generate_kwargs = dict( + max_new_tokens=max_new_tokens, + top_p=top_p, + top_k=top_k, + temperature=temperature, + repetition_penalty=repetition_penalty, + ) + self.sampling_params = dict( + max_new_tokens=max_new_tokens, + top_p=top_p, + stop_token_ids=[ + self.processor.tokenizer.added_tokens_encoder["<|im_end|>"] + ], + top_k=top_k, + temperature=temperature, + ) + + def build_model(self): + self.model = Qwen2VLForConditionalGeneration.from_pretrained( + self.model_path, + torch_dtype="auto", + device_map="cpu", + attn_implementation="flash_attention_2", + ) + self.model.cuda().eval() + torch.cuda.empty_cache() + + def build_prompt_hf(self, sample): + prompt = sample["final_input_prompt"] + image_count = len( + sorted(set(int(m) for m in re.findall(r"", prompt))) + ) + image = sample["image_1"] + if image is not None: + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Question: " + prompt}, + ], + } + ] + for i in range(1, image_count + 1): + messages[0]["content"].append( + { + "type": "image", + "image": sample["image_1"], + "min_pixels": 1003520, + "max_pixels": 12845056, + } + ) + print(f"\033[31m{messages}\033[0m") + text = self.processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + images, videos = process_vision_info([messages]) + inputs = self.processor( + text=[text], + images=[images], + 
padding=True, + return_tensors="pt", + ).to(self.model.device) + return inputs + + def build_prompt_sglang(self, sample): + prompt = sample["final_input_prompt"] + count = len(sorted(set(int(m) for m in re.findall(r"", prompt)))) + request_dict = { + "model": "", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Question: " + prompt, + }, + ], + } + ], + } + for idx in range(1, count + 1): + image = sample[f"image_{idx}"] + bytes_io = BytesIO() + image.save(bytes_io, format="PNG") + base64_str = base64.b64encode(bytes_io.getvalue()).decode("utf-8") + request_dict["messages"][0]["content"].append( + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{base64_str}"}, + } + ) + return request_dict + + def chat(self, sample): + inputs = self.build_prompt_hf(sample) + generated_ids = self.model.generate( + **inputs, + **self.generate_kwargs, + ) + generated_ids = [ + output_ids[len(input_ids) :] + for input_ids, output_ids in zip(inputs.input_ids, generated_ids) + ] + out = self.processor.tokenizer.batch_decode( + generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + response = out[0] + print(f"\033[32m{response}\033[0m") + return response diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 22174c922d0..07a6eaa115f 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -465,6 +465,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal "Qwen2_5_VLForConditionalGeneration", "MiniCPMV", "MultiModalityCausalLM", + "InternVLChatModel" ] diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index 6255126be7b..36d54d4cd54 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -46,7 +46,7 @@ class SeparatorStyle(IntEnum): METAMATH = auto() QWEN2_VL_EMBED = auto() GEMMA3 = auto() - + MPT = auto() @dataclasses.dataclass class Conversation: @@ -297,7 +297,16 @@ def get_prompt(self) -> str: else: ret += role return ret - + elif self.sep_style == SeparatorStyle.MPT: + ret = system_prompt + self.sep + for role, message in self.messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + self.sep + else: + ret += role + return ret else: raise ValueError(f"Invalid style: {self.sep_style}") @@ -673,3 +682,15 @@ def generate_chat_conv( image_token="", ) ) +# Reference: https://huggingface.co/OpenGVLab/InternVL2_5-38B#inference-with-transformers +register_conv_template( + Conversation( + name="internvl2_5", + system_template="<|im_start|>system\n{system_message}", + system_message="你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>\n", + stop_str=["<|im_end|>", "<|action_end|>"], + ) +) \ No newline at end of file diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index 987cc98dcaf..c2af43ba595 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -26,6 +26,7 @@ AutoTokenizer, PretrainedConfig, PreTrainedTokenizer, + PreTrainedTokenizerBase, PreTrainedTokenizerFast, ) from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES @@ -86,7 +87,10 @@ def get_config( for key, val in config.language_config.__dict__.items(): setattr(config, key, val) 
setattr(config, "architectures", ["MultiModalityCausalLM"]) - + if isinstance(model, str) and "InternVL" in model: + for key, val in config.llm_config.__dict__.items(): + if not hasattr(config, key): + setattr(config, key, val) if config.model_type in _CONFIG_REGISTRY: config_class = _CONFIG_REGISTRY[config.model_type] config = config_class.from_pretrained(model, revision=revision) @@ -212,6 +216,10 @@ def get_tokenizer( attach_additional_stop_token_ids(tokenizer) return tokenizer +def get_tokenizer_from_processor(processor): + if isinstance(processor, PreTrainedTokenizerBase): + return processor + return processor.tokenizer def get_processor( tokenizer_name: str, @@ -228,8 +236,8 @@ def get_processor( tokenizer_revision=tokenizer_revision, **kwargs, ) - - attach_additional_stop_token_ids(processor.tokenizer) + tokenizer = get_tokenizer_from_processor(processor) + attach_additional_stop_token_ids(tokenizer) return processor diff --git a/python/sglang/srt/managers/image_processors/internvl.py b/python/sglang/srt/managers/image_processors/internvl.py new file mode 100644 index 00000000000..3df8db649f2 --- /dev/null +++ b/python/sglang/srt/managers/image_processors/internvl.py @@ -0,0 +1,262 @@ +import asyncio +import math +from typing import List, Union + +import numpy as np +import torch +from PIL import Image + +from sglang.srt.managers.image_processor import BaseImageProcessor +from sglang.srt.managers.image_processors.base_image_processor import ( + get_global_processor, +) +from sglang.srt.models.internvl import InternVLChatModel +from sglang.srt.utils import load_image + + +# Compatible with InternVL +class InternVLImageProcessor(BaseImageProcessor): + def __init__(self, hf_config, server_args, _image_processor): + super().__init__(hf_config, server_args, _image_processor) + self._image_processor = _image_processor + image_size = hf_config.force_image_size or hf_config.vision_config.image_size + patch_size = hf_config.vision_config.patch_size + + self.IMG_CONTEXT_TOKEN = "" + self.IMG_START_TOKEN = "" + self.IMG_END_TOKEN = "" + self.num_image_token = int( + (image_size // patch_size) ** 2 * (hf_config.downsample_ratio**2) + ) + + tokenizer = self._processor + self.img_start_token_id = tokenizer.convert_tokens_to_ids(self.IMG_START_TOKEN) + self.img_end_token_id = tokenizer.convert_tokens_to_ids(self.IMG_END_TOKEN) + self.img_context_token_id = tokenizer.convert_tokens_to_ids( + self.IMG_CONTEXT_TOKEN + ) + + @staticmethod + def _process_single_image_task( + image_data: Union[str, bytes], + image_processor=None, + ): + pass + + @staticmethod + def build_transform(input_size): + IMAGENET_MEAN = (0.485, 0.456, 0.406) + IMAGENET_STD = (0.229, 0.224, 0.225) + + def resize_image(img, size): + return img.resize((size, size), Image.Resampling.BICUBIC) + + def to_tensor(img): + # Convert PIL Image to numpy array + img_array = np.array(img).astype(np.float32) / 255.0 + # Convert HWC to CHW format + img_array = img_array.transpose(2, 0, 1) + return torch.from_numpy(img_array) + + def normalize(tensor, mean, std): + mean = torch.tensor(mean).view(-1, 1, 1) + std = torch.tensor(std).view(-1, 1, 1) + return (tensor - mean) / std + + def transform(img): + img = img.convert("RGB") if img.mode != "RGB" else img + img = resize_image(img, input_size) + tensor = to_tensor(img) + tensor = normalize(tensor, IMAGENET_MEAN, IMAGENET_STD) + return tensor + + return transform + + @staticmethod + def dynamic_preprocess( + image, min_num=1, max_num=12, image_size=448, use_thumbnail=False + ): + + def 
find_closest_aspect_ratio( + aspect_ratio, target_ratios, width, height, image_size + ): + best_ratio_diff = float("inf") + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set( + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num + ) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size + ) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images + + @staticmethod + def get_index(bound, fps, max_frame, first_idx=0, num_segments=32): + if bound: + start, end = bound[0], bound[1] + else: + start, end = -100000, 100000 + start_idx = max(first_idx, round(start * fps)) + end_idx = min(round(end * fps), max_frame) + seg_size = float(end_idx - start_idx) / num_segments + frame_indices = np.array( + [ + int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) + for idx in range(num_segments) + ] + ) + return frame_indices + + @staticmethod + def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32): + vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) + max_frame = len(vr) - 1 + fps = float(vr.get_avg_fps()) + + pixel_values_list, num_patches_list = [], [] + transform = InternVLImageProcessor.build_transform(input_size=input_size) + frame_indices = InternVLImageProcessor.get_index( + bound, fps, max_frame, first_idx=0, num_segments=num_segments + ) + for frame_index in frame_indices: + img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB") + img = InternVLImageProcessor.dynamic_preprocess( + img, image_size=input_size, use_thumbnail=True, max_num=max_num + ) + pixel_values = [transform(tile) for tile in img] + pixel_values = torch.stack(pixel_values) + num_patches_list.append(pixel_values.shape[0]) + pixel_values_list.append(pixel_values) + pixel_values = torch.cat(pixel_values_list) + return pixel_values, num_patches_list + + async def process_images_async( + self, + image_data: List[Union[str, bytes]], + input_ids, + request_obj, + *args, + **kwargs, + ): + if not image_data: + 
return None + + tokenizer = self._processor + if isinstance(input_ids, list): + assert len(input_ids) and isinstance(input_ids[0], int) + input_text = tokenizer.decode(input_ids) + else: + input_text = input_ids + + image_hashes, image_sizes = [], [] + + all_frames = [] + + def load_image_internvl(image_file, input_size=448, max_num=12): + image, _size = load_image(image_file) + image = image.convert("RGB") + image = image.resize((image.width * 2, image.height * 2), Image.BILINEAR) + transform = InternVLImageProcessor.build_transform(input_size=input_size) + images = InternVLImageProcessor.dynamic_preprocess( + image, image_size=input_size, use_thumbnail=True, max_num=max_num + ) + pixel_values = [transform(image) for image in images] + pixel_values = torch.stack(pixel_values) + return pixel_values + + num_patches_list = [] + + # Process each input with allocated frames + for image_index, (image) in enumerate(image_data): + try: + if isinstance(image, str) and image.startswith("video:"): + path = image[len("video:") :] + pixel_values, num_patches_list_video = ( + InternVLImageProcessor.load_video(path) + ) + + frames = [pixel_values.to(torch.bfloat16)] + num_patches_list += num_patches_list_video + else: + raw_image = load_image_internvl(image) + frames = [raw_image.to(torch.bfloat16)] + num_patches = raw_image.shape[0] + num_patches_list += [num_patches] + + except FileNotFoundError as e: + print(e) + return None + image_hashes += [hash(image)] * len(frames) + all_frames += frames + + pixel_values = torch.cat(all_frames, dim=0) + for idx, num_patches in enumerate(num_patches_list): + image_tokens = ( + self.IMG_START_TOKEN + + self.IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + + self.IMG_END_TOKEN + ) + input_text = input_text.replace("", image_tokens, 1) + return { + "input_ids": tokenizer(input_text, return_tensors="pt")["input_ids"] + .flatten() + .tolist(), + "pixel_values": pixel_values, + "im_start_id": self.img_start_token_id, + "im_end_id": self.img_end_token_id, + "im_token_id": self.img_context_token_id, + "image_hashes": image_hashes, + "image_sizes": image_sizes, + "modalities": request_obj.modalities or ["image"], + } + + +ImageProcessorMapping = { + InternVLChatModel: InternVLImageProcessor, +} diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index b74dcc39df1..26801171119 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -37,7 +37,7 @@ from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import create_grammar_backend -from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer +from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer, get_tokenizer_from_processor from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.io_struct import ( @@ -413,7 +413,7 @@ def init_tokenizer(self): trust_remote_code=server_args.trust_remote_code, revision=server_args.revision, ) - self.tokenizer = self.processor.tokenizer + self.tokenizer = get_tokenizer_from_processor(self.processor) else: self.tokenizer = get_tokenizer( server_args.tokenizer_path, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index c211d76ff57..f40cbaf1c80 100644 --- 
a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -49,7 +49,7 @@ from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer +from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer, get_tokenizer_from_processor from sglang.srt.managers.image_processor import ( get_dummy_image_processor, get_image_processor, @@ -187,7 +187,7 @@ def __init__( self.tokenizer = self.processor = None else: self.processor = _processor - self.tokenizer = self.processor.tokenizer + self.tokenizer = get_tokenizer_from_processor(self.processor) os.environ["TOKENIZERS_PARALLELISM"] = "false" else: self.image_processor = get_dummy_image_processor() diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 1423f253f1b..d78714d351a 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -20,7 +20,7 @@ import torch from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer +from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer, get_tokenizer_from_processor from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.io_struct import ( GetWeightsByNameReqInput, @@ -93,7 +93,7 @@ def __init__( trust_remote_code=server_args.trust_remote_code, revision=server_args.revision, ) - self.tokenizer = self.processor.tokenizer + self.tokenizer = get_tokenizer_from_processor(self.processor) else: self.tokenizer = get_tokenizer( server_args.tokenizer_path, diff --git a/python/sglang/srt/models/internvl.py b/python/sglang/srt/models/internvl.py new file mode 100644 index 00000000000..78c972a9f3d --- /dev/null +++ b/python/sglang/srt/models/internvl.py @@ -0,0 +1,268 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.managers.schedule_batch import ImageInputs +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.internlm2 import InternLM2ForCausalLM +from sglang.srt.models.internvl_vit import InternVisionModel +from sglang.srt.models.qwen2 import Qwen2ForCausalLM +from sglang.utils import logger + + +class InternVLChatModel(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + use_flash_attn=True, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + + image_size = config.force_image_size or config.vision_config.image_size + patch_size = config.vision_config.patch_size + self.patch_size = patch_size + self.select_layer = config.select_layer + self.template = config.template + self.num_image_token = int( + (image_size // patch_size) ** 2 * (config.downsample_ratio**2) + ) + self.downsample_ratio = config.downsample_ratio + self.ps_version = config.ps_version + config.llm_config.attn_implementation = ( + "flash_attention_2" if use_flash_attn else "eager" + ) + + self.vision_model = InternVisionModel( + config=config.vision_config, quant_config=quant_config + ) + if config.llm_config.architectures[0] == "InternLM2ForCausalLM": + self.language_model = InternLM2ForCausalLM(config.llm_config) + self.llm_architectures = "InternLM2ForCausalLM" + elif config.llm_config.architectures[0] == "Qwen2ForCausalLM": + self.language_model = Qwen2ForCausalLM(config.llm_config) + self.llm_architectures = "Qwen2ForCausalLM" + else: + raise NotImplementedError( + f"{config.llm_config.architectures[0]} is not implemented." + ) + + vit_hidden_size = config.vision_config.hidden_size + llm_hidden_size = config.llm_config.hidden_size + + self.mlp1 = nn.Sequential( + nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2), + nn.Linear( + vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size + ), + nn.GELU(), + nn.Linear(llm_hidden_size, llm_hidden_size), + ) + + def pixel_shuffle(self, x, scale_factor=0.5): + n, w, h, c = x.size() + # N, W, H, C --> N, W, H * scale, C // scale + x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) + # N, W, H * scale, C // scale --> N, H * scale, W, C // scale + x = x.permute(0, 2, 1, 3).contiguous() + # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2) + x = x.view( + n, + int(h * scale_factor), + int(w * scale_factor), + int(c / (scale_factor * scale_factor)), + ) + if self.ps_version == "v1": + logger.warn( + "In ps_version 'v1', the height and width have not been swapped back, " + "which results in a transposed image." 
+ ) + else: + x = x.permute(0, 2, 1, 3).contiguous() + return x + + def extract_feature(self, pixel_values): + if self.select_layer == -1: + vit_embeds = self.vision_model( + pixel_values=pixel_values, output_hidden_states=False, return_dict=True + ).last_hidden_state + else: + vit_embeds = self.vision_model( + pixel_values=pixel_values, output_hidden_states=True, return_dict=True + ).hidden_states[self.select_layer] + vit_embeds = vit_embeds[:, 1:, :] + + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) + vit_embeds = self.mlp1(vit_embeds) + return vit_embeds + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if ( + forward_batch.image_inputs is not None + and forward_batch.image_inputs[0] is not None + ): + image_input = forward_batch.image_inputs[0] + + image_token_indices = torch.isin( + input_ids, + torch.tensor(image_input.pad_values).to(device=input_ids.device), + ) + if image_token_indices.sum() == 0: + pass + else: + # [B * S] -> [B, S] + input_ids = input_ids.unsqueeze(0) + input_ids.clamp_(min=0, max=self.config.vocab_size - 1) + if self.llm_architectures == "Qwen2ForCausalLM": + input_embeds = self.language_model.model.embed_tokens(input_ids) + else: + input_embeds = self.language_model.model.tok_embeddings(input_ids) + B, N, C = input_embeds.shape + input_embeds = input_embeds.reshape(B * N, C) + pixel_values = image_input.pixel_values + vit_embeds = self.extract_feature(pixel_values) + + num_image_tokens = image_token_indices.sum() + input_embeds[image_token_indices] = vit_embeds.reshape(-1, C)[ + -num_image_tokens: + ].to(input_embeds.device) + input_embeds = input_embeds.reshape(N, C) + input_ids = None + + if input_ids is not None: + input_ids.clamp_(min=0, max=self.config.vocab_size - 1) + return self.language_model( + input_ids=input_ids, + positions=positions, + forward_batch=forward_batch, + input_embeds=input_embeds, + ) + + def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): + im_start_id: int = image_inputs.im_start_id + im_end_id: int = image_inputs.im_end_id + media_token_pairs = [(im_start_id, im_end_id)] + pad_values = image_inputs.pad_values + media_token_pairs = media_token_pairs + start_tokens = [s for s, _e in media_token_pairs] + end_tokens = [e for _s, e in media_token_pairs] + # First start token marks new media + media_start_token = start_tokens[0] + + padded_ids = [] + last_idx = 0 + media_idx = -1 + + start_indices = [i for i, x in enumerate(input_ids) if x in start_tokens] + end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens] + + if len(start_indices) != len(end_indices): + return input_ids + + for start_idx, end_idx in zip(start_indices, end_indices): + padded_ids.extend(input_ids[last_idx : start_idx + 1]) + + if input_ids[start_idx] == media_start_token: + media_idx += 1 + + num_tokens = end_idx - start_idx - 1 + pad_value = pad_values[media_idx] + padded_ids.extend([pad_value] * num_tokens) + + last_idx = end_idx + + padded_ids.extend(input_ids[last_idx:]) + + assert len(input_ids) == len(padded_ids) + return padded_ids + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + if self.llm_architectures == "Qwen2ForCausalLM": + 
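+            # The HF checkpoint stores q_proj/k_proj/v_proj and
+            # gate_proj/up_proj as separate tensors; the mapping below loads
+            # them into SGLang's fused qkv_proj and gate_up_proj parameters,
+            # with the shard id selecting the slice each tensor occupies.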
stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "up_proj", 1), + ("gate_up_proj", "gate_proj", 0), + ] + else: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "w1", 0), + ("gate_up_proj", "w3", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + if "wqkv" in name: + config = self.config + kv_groups = config.num_attention_heads // config.num_key_value_heads + head_dim = config.hidden_size // config.num_attention_heads + loaded_weight = loaded_weight.view( + -1, 2 + kv_groups, head_dim, loaded_weight.shape[-1] + ) + wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1], dim=1) + wq = wq.reshape(-1, wq.shape[-1]) + wk = wk.reshape(-1, wk.shape[-1]) + wv = wv.reshape(-1, wv.shape[-1]) + weight_loader = param.weight_loader + weight_loader(param, wq, "q") + weight_loader(param, wk, "k") + weight_loader(param, wv, "v") + else: + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + + +EntryClass = InternVLChatModel diff --git a/python/sglang/srt/models/internvl_vit.py b/python/sglang/srt/models/internvl_vit.py new file mode 100644 index 00000000000..eb2af464a02 --- /dev/null +++ b/python/sglang/srt/models/internvl_vit.py @@ -0,0 +1,554 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +from typing import Iterable, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from einops import rearrange +from modelscope.models.cv.action_recognition.tada_convnext import DropPath +from torch import nn +from transformers import PretrainedConfig, PreTrainedModel +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling + +from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.utils import logger + +try: + from flash_attn.bert_padding import pad_input, unpad_input + from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func + + has_flash_attn = True +except: + print("FlashAttention2 is not installed.") + has_flash_attn = False + + +class FlashAttention(nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.0) + """ + + def __init__( + self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None + ): + super().__init__() + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + + def forward( + self, + qkv, + key_padding_mask=None, + causal=False, + cu_seqlens=None, + max_s=None, + need_weights=False, + ): + """Implements the multihead softmax attention. + Arguments + --------- + qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None + if unpadded: (nnz, 3, h, d) + key_padding_mask: a bool tensor of shape (B, S) + """ + assert not need_weights + assert qkv.dtype in [torch.float16, torch.bfloat16] + assert qkv.is_cuda + + if cu_seqlens is None: + batch_size = qkv.shape[0] + seqlen = qkv.shape[1] + if key_padding_mask is None: + qkv = rearrange(qkv, "b s ... -> (b s) ...") + max_s = seqlen + cu_seqlens = torch.arange( + 0, + (batch_size + 1) * seqlen, + step=seqlen, + dtype=torch.int32, + device=qkv.device, + ) + output = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens, + max_s, + self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, + causal=causal, + ) + output = rearrange(output, "(b s) ... 
-> b s ...", b=batch_size) + else: + nheads = qkv.shape[-2] + x = rearrange(qkv, "b s three h d -> b s (three h d)") + x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask) + x_unpad = rearrange( + x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads + ) + output_unpad = flash_attn_varlen_qkvpacked_func( + x_unpad, + cu_seqlens, + max_s, + self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, + causal=causal, + ) + output = rearrange( + pad_input( + rearrange(output_unpad, "nnz h d -> nnz (h d)"), + indices, + batch_size, + seqlen, + ), + "b s (h d) -> b s h d", + h=nheads, + ) + else: + assert max_s is not None + output = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens, + max_s, + self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, + causal=causal, + ) + + return output, None + + +class InternVisionEmbeddings(nn.Module): + def __init__(self, config: PretrainedConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter( + torch.randn(1, 1, self.embed_dim), + ) + + self.patch_embedding = nn.Conv2d( + in_channels=3, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter( + torch.randn(1, self.num_positions, self.embed_dim) + ) + + def _get_pos_embed(self, pos_embed, H, W): + target_dtype = pos_embed.dtype + pos_embed = ( + pos_embed.float() + .reshape( + 1, + self.image_size // self.patch_size, + self.image_size // self.patch_size, + -1, + ) + .permute(0, 3, 1, 2) + ) + pos_embed = ( + F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False) + .reshape(1, -1, H * W) + .permute(0, 2, 1) + .to(target_dtype) + ) + return pos_embed + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding( + pixel_values + ) # shape = [*, channel, width, height] + batch_size, _, height, width = patch_embeds.shape + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + position_embedding = torch.cat( + [ + self.position_embedding[:, :1, :], + self._get_pos_embed(self.position_embedding[:, 1:, :], height, width), + ], + dim=1, + ) + embeddings = embeddings + position_embedding.to(target_dtype) + return embeddings + + +class InternAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.use_flash_attn = config.use_flash_attn and has_flash_attn + if config.use_flash_attn and not has_flash_attn: + print( + "Warning: Flash Attention is not available, use_flash_attn is set to False." + ) + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + + self.scale = self.head_dim**-0.5 + self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias) + self.attn_drop = nn.Dropout(config.attention_dropout) + self.proj_drop = nn.Dropout(config.dropout) + + self.qk_normalization = config.qk_normalization + + if self.qk_normalization: + self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps) + self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps) + if self.use_flash_attn: + self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout) + self.proj = nn.Linear(self.embed_dim, self.embed_dim) + + def _naive_attn(self, x): + B, N, C = x.shape + qkv = self.qkv(x) + # [b, s, embed_dim] --> [3, b, h, s,head_size] + qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute( + 2, 0, 3, 1, 4 + ) + # [3, b, h, s,head_size] --> [b, h, s,head_size] + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + if self.qk_normalization: + B_, H_, N_, D_ = q.shape + q = ( + self.q_norm(q.transpose(1, 2).flatten(-2, -1)) + .view(B_, N_, H_, D_) + .transpose(1, 2) + ) + k = ( + self.k_norm(k.transpose(1, 2).flatten(-2, -1)) + .view(B_, N_, H_, D_) + .transpose(1, 2) + ) + attn = (q * self.scale) @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def _flash_attn(self, x, key_padding_mask=None, need_weights=False): + qkv = self.qkv(x) + qkv = rearrange( + qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads + ) + if self.qk_normalization: + q, k, v = qkv.unbind(2) + q = self.q_norm(q.flatten(-2, -1)).view(q.shape) + k = self.k_norm(k.flatten(-2, -1)).view(k.shape) + qkv = torch.stack([q, k, v], dim=2) + + context, _ = self.inner_attn( + qkv, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + causal=False, + ) + outs = self.proj(rearrange(context, "b s h d -> b s (h d)")) + outs = self.proj_drop(outs) + return outs + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + x = ( + self._naive_attn(hidden_states) + if not self.use_flash_attn + else self._flash_attn(hidden_states) + ) + return x + + +class InternRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class InternMLP(nn.Module): + def __init__(self, config: PretrainedConfig): + super().__init__() + self.config = config + self.act = ACT2FN[config.hidden_act] + self.fc1 = ColumnParallelLinear(config.hidden_size, config.intermediate_size) + self.fc2 = RowParallelLinear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +NORM2FN = { + "rms_norm": InternRMSNorm, + "layer_norm": nn.LayerNorm, +} + + +class InternVisionEncoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + drop_path_rate: float, + quant_config: QuantizationConfig = None, + ): + 
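+        # Pre-norm transformer block: the attention and MLP branches are each
+        # scaled by a learnable LayerScale vector (ls1 / ls2) and wrapped in
+        # DropPath (stochastic depth) before being added to the residual.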
super().__init__() + self.embed_dim = config.hidden_size + self.intermediate_size = config.intermediate_size + self.norm_type = config.norm_type + + self.attn = InternAttention(config) + self.mlp = InternMLP(config) + self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps) + self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps) + + self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim)) + self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim)) + self.drop_path1 = ( + DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + ) + self.drop_path2 = ( + DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + ) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> Tuple[ + torch.FloatTensor, + Optional[torch.FloatTensor], + Optional[Tuple[torch.FloatTensor]], + ]: + """ + Args: + hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)` + """ + hidden_states = hidden_states + self.drop_path1( + self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1 + ) + + hidden_states = hidden_states + self.drop_path2( + self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2 + ) + + return hidden_states + + +class InternVisionEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`InternEncoderLayer`]. + + Args: + config (`InternConfig`): + The corresponding vision configuration for the `InternEncoder`. + """ + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + # stochastic depth decay rule + dpr = [ + x.item() + for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers) + ] + self.layers = nn.ModuleList( + [ + InternVisionEncoderLayer(config, dpr[idx], quant_config) + for idx in range(config.num_hidden_layers) + ] + ) + self.gradient_checkpointing = True + + def forward( + self, + inputs_embeds, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Embedded representation of the inputs. Should be float, not int tokens. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + encoder_states = () if output_hidden_states else None + hidden_states = inputs_embeds + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + layer_outputs = encoder_layer( + hidden_states, + ) + hidden_states = layer_outputs + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states + ) + + +class InternVisionModel(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + + self.embeddings = InternVisionEmbeddings( + config, + ) + self.encoder = InternVisionEncoder(config, quant_config) + + def resize_pos_embeddings(self, old_size, new_size, patch_size): + pos_emb = self.embeddings.position_embedding + _, num_positions, embed_dim = pos_emb.shape + cls_emb = pos_emb[:, :1, :] + pos_emb = ( + pos_emb[:, 1:, :] + .reshape(1, old_size // patch_size, old_size // patch_size, -1) + .permute(0, 3, 1, 2) + ) + pos_emb = F.interpolate( + pos_emb.float(), + size=new_size // patch_size, + mode="bicubic", + align_corners=False, + ) + pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1) + pos_emb = torch.cat([cls_emb, pos_emb], dim=1) + self.embeddings.position_embedding = nn.Parameter(pos_emb) + self.embeddings.image_size = new_size + logger.info( + "Resized position embeddings from {} to {}".format(old_size, new_size) + ) + + def get_input_embeddings(self): + return self.embeddings + + @property + def dtype(self) -> torch.dtype: + return self.encoder.layers[0].mlp.fc1.weight.dtype + + @property + def device(self) -> torch.device: + return self.encoder.layers[0].mlp.fc1.weight.device + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_embeds: Optional[torch.FloatTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + pixel_values = pixel_values.to(device=self.device, dtype=self.dtype) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if pixel_values is None and pixel_embeds is None: + raise ValueError("You have to specify pixel_values or pixel_embeds") + + if pixel_embeds is not None: + hidden_states = pixel_embeds + else: + if len(pixel_values.shape) == 4: + hidden_states = self.embeddings(pixel_values) + else: + raise ValueError(f"wrong pixel_values size: {pixel_values.shape}") + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = encoder_outputs.last_hidden_state + pooled_output = last_hidden_state[:, 0, :] + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + 
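+            # pooler_output is the embedding of the leading CLS token,
+            # i.e. last_hidden_state[:, 0, :] computed above.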
hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +EntryClass = InternVisionModel diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py index 935c2057be2..dbd39109883 100644 --- a/test/srt/test_vision_openai_server.py +++ b/test/srt/test_vision_openai_server.py @@ -562,6 +562,20 @@ def setUpClass(cls): def test_video_chat_completion(self): pass - +class TestInternVL2_5Server(TestOpenAIVisionServer): + @classmethod + def setUpClass(cls): + cls.model = "OpenGVLab/InternVL2_5-2B" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--trust-remote-code", "--chat-template", "internvl2_5"], + ) + cls.base_url += "/v1" + def test_regex(self): + pass if __name__ == "__main__": unittest.main() From 973f750831eaace486ebf32ebe5ba281eb495ad8 Mon Sep 17 00:00:00 2001 From: fanxinyao Date: Mon, 17 Mar 2025 12:11:38 +0800 Subject: [PATCH 2/5] support internvl and MMMU dataset evaluation --- benchmark/mmmu/prompt_format.yaml | 4 ++-- python/sglang/srt/conversation.py | 11 ++++++++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/benchmark/mmmu/prompt_format.yaml b/benchmark/mmmu/prompt_format.yaml index 5126c129927..4910799351a 100644 --- a/benchmark/mmmu/prompt_format.yaml +++ b/benchmark/mmmu/prompt_format.yaml @@ -1,9 +1,9 @@ task_instructions: - "" multi_choice_example_format: -- "{}\n{}Answer with the option's letter from the given choices directly.\n" +- "{}\n{}Answer with the option's letter from the given choices directly." short_ans_example_format: -- "{}Answer the question using a single word or phrase.\n" +- "{}Answer the question using a single word or phrase." temperature: - 0 diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index 36d54d4cd54..bca3f1e3413 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -16,6 +16,7 @@ # Adapted from # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py import dataclasses +import re from enum import IntEnum, auto from typing import Dict, List, Optional, Tuple, Union @@ -303,6 +304,14 @@ def get_prompt(self) -> str: if message: if type(message) is tuple: message, _, _ = message + if 'user' in role: + if len(self.modalities) > 1: + image_section = "\n".join([f"Image-{i+1}: " for i in range(len(self.modalities))]) + message = re.sub(r"", "", message, count=len(self.modalities)).strip() + message = f"{image_section}\n{message}" + else: + message = message.replace("", "", 1).strip() + message = f"\n{message}" ret += role + message + self.sep else: ret += role @@ -693,4 +702,4 @@ def generate_chat_conv( sep="<|im_end|>\n", stop_str=["<|im_end|>", "<|action_end|>"], ) -) \ No newline at end of file +) From b6d8f3aca6e55e205b60a064cec35f1bf657058a Mon Sep 17 00:00:00 2001 From: fanxinyao Date: Mon, 17 Mar 2025 12:11:38 +0800 Subject: [PATCH 3/5] support internvl and MMMU dataset evaluation --- benchmark/mmmu/README.md | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/benchmark/mmmu/README.md b/benchmark/mmmu/README.md index be3f1e0434c..c870a747adb 100644 --- a/benchmark/mmmu/README.md +++ b/benchmark/mmmu/README.md @@ -15,8 +15,7 @@ python benchmark/mmmu/bench_hf.py --model-path Qwen/Qwen2-VL-7B-Instruct ``` Some popular model results: - -1. Qwen/Qwen2-VL-2B-Instruct: 0.241 -2. Qwen/Qwen2-VL-7B-Instruct: 0.255 -3. 
Qwen/Qwen2.5-VL-3B-Instruct: 0.245 -4. Qwen/Qwen2.5-VL-7B-Instruct: 0.242 +1. Qwen/Qwen2-VL-7B-Instruct(sglang): 0.48 +2. Qwen/Qwen2-VL-7B-Instruct(hf): 0.482 +3. OpenGVLab/InternVL2_5-38B(sglang): 0.612 +4. OpenGVLab/InternVL2_5-38B(hf): 0.61 \ No newline at end of file From be092d40a1286ae22641cec040ac46dfd3738b1f Mon Sep 17 00:00:00 2001 From: fanxinyao Date: Mon, 17 Mar 2025 12:11:38 +0800 Subject: [PATCH 4/5] support internvl and MMMU dataset evaluation --- benchmark/mmmu/README.md | 2 +- benchmark/mmmu/eval_utils.py | 1 - python/sglang/srt/conversation.py | 14 +++++-- python/sglang/srt/hf_transformers_utils.py | 2 + python/sglang/srt/managers/scheduler.py | 6 ++- .../sglang/srt/managers/tokenizer_manager.py | 6 ++- python/sglang/srt/managers/tp_worker.py | 6 ++- python/sglang/srt/models/internvl.py | 38 +++---------------- test/srt/test_vision_openai_server.py | 4 ++ 9 files changed, 38 insertions(+), 41 deletions(-) diff --git a/benchmark/mmmu/README.md b/benchmark/mmmu/README.md index 1940df82709..e96b1456600 100644 --- a/benchmark/mmmu/README.md +++ b/benchmark/mmmu/README.md @@ -22,4 +22,4 @@ Some popular model results: 1. Qwen/Qwen2-VL-7B-Instruct(sglang): 0.48 2. Qwen/Qwen2-VL-7B-Instruct(hf): 0.482 3. OpenGVLab/InternVL2_5-38B(sglang): 0.612 -4. OpenGVLab/InternVL2_5-38B(hf): 0.61 \ No newline at end of file +4. OpenGVLab/InternVL2_5-38B(hf): 0.61 diff --git a/benchmark/mmmu/eval_utils.py b/benchmark/mmmu/eval_utils.py index 25709e82aed..cea2bcad6a3 100644 --- a/benchmark/mmmu/eval_utils.py +++ b/benchmark/mmmu/eval_utils.py @@ -19,7 +19,6 @@ process_single_sample, ) from datasets import concatenate_datasets, load_dataset - from internvl_chat import InternVLChat from qwen2vl_chat import Qwen2VLChat diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index ab5851441fe..602ad210386 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -50,6 +50,7 @@ class SeparatorStyle(IntEnum): GEMMA3 = auto() MPT = auto() + @dataclasses.dataclass class Conversation: """A class that manages prompt templates and keeps all conversation history.""" @@ -318,10 +319,17 @@ def get_prompt(self) -> str: if message: if type(message) is tuple: message, _, _ = message - if 'user' in role: + if "user" in role: if len(self.modalities) > 1: - image_section = "\n".join([f"Image-{i+1}: " for i in range(len(self.modalities))]) - message = re.sub(r"", "", message, count=len(self.modalities)).strip() + image_section = "\n".join( + [ + f"Image-{i+1}: " + for i in range(len(self.modalities)) + ] + ) + message = re.sub( + r"", "", message, count=len(self.modalities) + ).strip() message = f"{image_section}\n{message}" else: message = message.replace("", "", 1).strip() diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index a513e09c0ff..a5ab80bca8c 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -218,11 +218,13 @@ def get_tokenizer( attach_additional_stop_token_ids(tokenizer) return tokenizer + def get_tokenizer_from_processor(processor): if isinstance(processor, PreTrainedTokenizerBase): return processor return processor.tokenizer + def get_processor( tokenizer_name: str, *args, diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 95235a321bf..bda3d712a78 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -37,7 +37,11 @@ from 
sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import create_grammar_backend -from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer, get_tokenizer_from_processor +from sglang.srt.hf_transformers_utils import ( + get_processor, + get_tokenizer, + get_tokenizer_from_processor, +) from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.io_struct import ( diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index f40cbaf1c80..96dfbfed2cd 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -49,7 +49,11 @@ from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer, get_tokenizer_from_processor +from sglang.srt.hf_transformers_utils import ( + get_processor, + get_tokenizer, + get_tokenizer_from_processor, +) from sglang.srt.managers.image_processor import ( get_dummy_image_processor, get_image_processor, diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 0592b8cbe52..1fd46529dff 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -20,7 +20,11 @@ import torch from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer, get_tokenizer_from_processor +from sglang.srt.hf_transformers_utils import ( + get_processor, + get_tokenizer, + get_tokenizer_from_processor, +) from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.io_struct import ( GetWeightsByNameReqInput, diff --git a/python/sglang/srt/models/internvl.py b/python/sglang/srt/models/internvl.py index 78c972a9f3d..b3d82a4bc37 100644 --- a/python/sglang/srt/models/internvl.py +++ b/python/sglang/srt/models/internvl.py @@ -19,6 +19,9 @@ from transformers import PretrainedConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.managers.multi_modality_padding import ( + MultiModalityDataPaddingPatternTokenPairs, +) from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -172,39 +175,8 @@ def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): im_start_id: int = image_inputs.im_start_id im_end_id: int = image_inputs.im_end_id media_token_pairs = [(im_start_id, im_end_id)] - pad_values = image_inputs.pad_values - media_token_pairs = media_token_pairs - start_tokens = [s for s, _e in media_token_pairs] - end_tokens = [e for _s, e in media_token_pairs] - # First start token marks new media - media_start_token = start_tokens[0] - - padded_ids = [] - last_idx = 0 - media_idx = -1 - - start_indices = [i for i, x in enumerate(input_ids) if x in start_tokens] - end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens] - - if len(start_indices) != len(end_indices): - return input_ids - - for start_idx, end_idx in zip(start_indices, end_indices): - padded_ids.extend(input_ids[last_idx : start_idx + 1]) - - if input_ids[start_idx] == media_start_token: - media_idx += 1 - - 
num_tokens = end_idx - start_idx - 1 - pad_value = pad_values[media_idx] - padded_ids.extend([pad_value] * num_tokens) - - last_idx = end_idx - - padded_ids.extend(input_ids[last_idx:]) - - assert len(input_ids) == len(padded_ids) - return padded_ids + pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs) + return pattern.pad_input_tokens(input_ids, image_inputs) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if self.llm_architectures == "Qwen2ForCausalLM": diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py index f10151d9acc..05e591253dd 100644 --- a/test/srt/test_vision_openai_server.py +++ b/test/srt/test_vision_openai_server.py @@ -603,6 +603,7 @@ def setUpClass(cls): def test_video_chat_completion(self): pass + class TestInternVL2_5Server(TestOpenAIVisionServer): @classmethod def setUpClass(cls): @@ -616,7 +617,10 @@ def setUpClass(cls): other_args=["--trust-remote-code", "--chat-template", "internvl2_5"], ) cls.base_url += "/v1" + def test_regex(self): pass + + if __name__ == "__main__": unittest.main() From 3743ec16864e661e763c8955f7f5760a8c04317e Mon Sep 17 00:00:00 2001 From: fanxinyao Date: Mon, 17 Mar 2025 12:11:38 +0800 Subject: [PATCH 5/5] support internvl and MMMU dataset evaluation --- benchmark/mmmu/bench_hf.py | 2 +- python/sglang/srt/configs/model_config.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmark/mmmu/bench_hf.py b/benchmark/mmmu/bench_hf.py index e7bd0221a13..73328c64f01 100644 --- a/benchmark/mmmu/bench_hf.py +++ b/benchmark/mmmu/bench_hf.py @@ -2,7 +2,7 @@ Bench the huggingface vLM with benchmark MMMU Usage: - python benchmark/mmmu/bench_hf.py --model-path Qwen/Qwen2-VL-7B-Instruct --dataset-path + python benchmark/mmmu/bench_hf.py --model-path Qwen/Qwen2-VL-7B-Instruct --dataset-path The eval output will be logged """ diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 16f708f88a2..8272d518ac5 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -472,7 +472,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal "Qwen2_5_VLForConditionalGeneration", "MiniCPMV", "MultiModalityCausalLM", - "InternVLChatModel" + "InternVLChatModel", "DeepseekVL2ForCausalLM", ]
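As a quick way to exercise the new InternVL2_5 path end to end, the sketch below mirrors the server setup used by `TestInternVL2_5Server` (model `OpenGVLab/InternVL2_5-2B`, `--trust-remote-code`, `--chat-template internvl2_5`, OpenAI-compatible `/v1` endpoint) and sends a single image-chat request. The local port, image URL, and prompt text are illustrative assumptions, not values taken from the patches.

```python
# Minimal client-side smoke test, assuming the server was launched locally, e.g.:
#   python -m sglang.launch_server --model-path OpenGVLab/InternVL2_5-2B \
#       --trust-remote-code --chat-template internvl2_5
# Port 30000 is an assumption (SGLang's default); adjust to your deployment.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="sk-123456")

response = client.chat.completions.create(
    model="OpenGVLab/InternVL2_5-2B",
    messages=[
        {
            "role": "user",
            "content": [
                # Placeholder image URL; any reachable image (or a base64 data URL) works.
                {
                    "type": "image_url",
                    "image_url": {"url": "https://example.com/sample.jpg"},
                },
                {"type": "text", "text": "Describe this image in one sentence."},
            ],
        }
    ],
    max_tokens=64,
    temperature=0,
)
print(response.choices[0].message.content)
```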