import torch
from transformers import AutoModel, AutoTokenizer
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from datetime import timedelta
import logging

from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
from lmms_eval.utils import stop_sequences_criteria

from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs
from accelerate.state import AcceleratorState

from typing import Optional, Sequence, List, Tuple, Union
import re
from tqdm import tqdm

pattern = re.compile(r"[A-Z]")

eval_logger = logging.getLogger("lmms-eval")

meta_instruction = """You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).
- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language model that is developed\
 by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.
- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the language chosen by\
 the user such as English and 中文.
- InternLM-XComposer (浦语·灵笔) is capable of comprehending and articulating responses\
 effectively based on the provided image."""


@register_model("xcomposer2_4khd")
class XComposer2_4KHD(lmms):
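    """lmms-eval model wrapper for InternLM-XComposer2-4KHD, loaded from Hugging Face via AutoModel with trust_remote_code."""
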
    def __init__(
        self,
        pretrained: str = "internlm/internlm-xcomposer2-4khd-7b",
        device: Optional[str] = "cuda:0",
        batch_size: Optional[Union[int, str]] = 1,
        device_map="cuda:0",
        need_bos: bool = True,
        padding: bool = False,
        half: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()

        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
        if accelerator.num_processes > 1:
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"
        elif accelerator.num_processes == 1 and device_map == "auto":
            self._device = torch.device(device)
            self.device_map = device_map
        else:
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"

        self.pretrained = pretrained
        self.need_bos = need_bos
        self.padding = padding
        self._model = AutoModel.from_pretrained(self.pretrained, device_map=self.device_map, trust_remote_code=True)
        self._tokenizer = AutoTokenizer.from_pretrained(self.pretrained, trust_remote_code=True)
        self._config = self._model.config  # backing attribute for the `config` property below
        self.model.tokenizer = self.tokenizer
        self.batch_size_per_gpu = batch_size

        if accelerator.num_processes > 1:
            assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP, FSDP, and DeepSpeed are supported."
            # If you want to use DistributedType.DEEPSPEED, you have to run `accelerate config` before using the model.
            # You also have to select ZeRO stage 0 (equivalent to DDP) so that preparing the model works;
            # setting different kwargs to make the default ZeRO stage 2 work did not succeed.
            if accelerator.distributed_type == DistributedType.DEEPSPEED:
                kwargs = {
                    "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
                    "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
                }
                AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
                eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
            if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        elif accelerator.num_processes == 1 and device_map == "auto":
            eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism")
            self._rank = 0
            self._world_size = 1
        else:
            eval_logger.info(f"Using single device: {self._device}")
            self.model.to(self._device)
            self._rank = 0
            self._world_size = 1

    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def flatten(self, input):
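        # flatten a list of lists into a single flat list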
        new_list = []
        for i in input:
            for j in i:
                new_list.append(j)
        return new_list

    def generate_until(self, requests) -> List[str]:
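        """Generate a free-form response for each request by interleaving text and image embeddings and calling model.generate()."""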
        res = []
        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")

        for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
            # wrap the raw prompt in the InternLM-XComposer chat template if it is not already formatted
            if "[UNUSED_TOKEN_146]" not in contexts:
                contexts = f"[UNUSED_TOKEN_146]user\n{contexts}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n"
            visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
            visuals = self.flatten(visuals)

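            # hd_num caps how many 336x336 sub-images HD_transform may produce; document/OCR-heavy tasks get more tiles.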
            if "hd_num" not in gen_kwargs:
                if listinstr(["docvqa_test", "infovqa_test"], task.lower()):
                    self.model.hd_num = 65
                elif listinstr(["docvqa_val", "infovqa_val", "ocrbench"], task.lower()):
                    self.model.hd_num = 55
                elif listinstr(["mmmu", "mmbench", "mmvet"], task.lower()):
                    self.model.hd_num = 16
                else:
                    self.model.hd_num = 25
            else:
                self.model.hd_num = gen_kwargs.pop("hd_num")

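            # Walk through the prompt, embedding text segments and images in the order they appear,
            # and build an image mask (1 = image token, 0 = text token) alongside the embeddings.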
            pt1 = 0
            embeds = []
            im_mask = []
            images_loc = [0]
            need_bos = self.need_bos
            padding = self.padding
            for i, pts in enumerate(images_loc + [len(contexts)]):
                subtext = contexts[pt1:pts]
                if need_bos or len(subtext) > 0:
                    text_embeds = self.model.encode_text(subtext, add_special_tokens=need_bos).to(self.device)
                    embeds.append(text_embeds)
                    im_mask.append(torch.zeros(text_embeds.shape[:2]).to(self.device))
                    need_bos = False
                if i < len(visuals):
                    image = visuals[i]

                    image = HD_transform(image, im_num=self.model.hd_num)
                    image = self.model.vis_processor(image).unsqueeze(0).to(self.device)
                    image_embeds = self.model.encode_img(image)
                    embeds.append(image_embeds)
                    im_mask.append(torch.ones(image_embeds.shape[:2]).to(self.device))
                pt1 = pts
            embeds = torch.cat(embeds, dim=1)
            im_mask = torch.cat(im_mask, dim=1)
            im_mask = im_mask.bool()

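            # Fall back to the wrapper's default generation settings when the task config does not specify them.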
            if "max_new_tokens" not in gen_kwargs:
                gen_kwargs["max_new_tokens"] = 1024
            if "temperature" not in gen_kwargs:
                gen_kwargs["temperature"] = 0
            if "top_p" not in gen_kwargs:
                gen_kwargs["top_p"] = None
            if "num_beams" not in gen_kwargs:
                gen_kwargs["num_beams"] = 1
            if "do_sample" not in gen_kwargs:
                gen_kwargs["do_sample"] = False
            if "repetition_penalty" not in gen_kwargs:
                gen_kwargs["repetition_penalty"] = 1.0

            outputs = self.model.generate(
                inputs_embeds=embeds,
                im_mask=im_mask,
                temperature=gen_kwargs["temperature"],
                max_new_tokens=gen_kwargs["max_new_tokens"],
                num_beams=gen_kwargs["num_beams"],
                do_sample=gen_kwargs["do_sample"],
                repetition_penalty=gen_kwargs["repetition_penalty"],
            )
            output_token = outputs[0]
            if output_token[0] == 0 or output_token[0] == 1:
                # drop a leading special token (unk/bos) before decoding
                output_token = output_token[1:]
            output_text = self.model.tokenizer.decode(output_token)
            output_text = output_text.split("[UNUSED_TOKEN_145]")[0].strip()
            output_text = output_text.split("<|im_end|>")[0].strip()
            # if DATASET_TYPE(task) == "multi-choice":
            #     output_text = pattern.findall(output_text)
            #     if len(output_text) == 0:
            #         print("Error:", output_text)
            #         output_text = "Z"
            #     if type(output_text) == list:
            #         output_text = output_text[0]
            res.append(output_text)
            pbar.update(1)
        pbar.close()
        return res

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
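        # Loglikelihood scoring is not implemented for this wrapper; defer to the base class.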
        return super().loglikelihood(requests)


def padding_336(b):
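    # Pad the image height (white fill) up to the next multiple of 336, keeping the content vertically centered.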
    width, height = b.size
    tar = int(np.ceil(height / 336) * 336)
    top_padding = int((tar - height) / 2)
    bottom_padding = tar - height - top_padding
    left_padding = 0
    right_padding = 0
    b = transforms.functional.pad(b, [left_padding, top_padding, right_padding, bottom_padding], fill=[255, 255, 255])

    return b


def HD_transform(img, im_num=16):
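    # Resize the image so its width is a multiple of 336 (keeping the aspect ratio) and pad its height to a multiple
    # of 336, so the result tiles into at most `im_num` 336x336 patches. Portrait images are transposed first and
    # transposed back at the end so the longer side is always treated as the width.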
    width, height = img.size
    trans = False
    if width < height:
        img = img.transpose(Image.TRANSPOSE)
        trans = True
        width, height = img.size
    ratio = width / height
    scale = 1
    while scale * np.ceil(scale / ratio) <= im_num:
        scale += 1
    scale -= 1
    new_w = int(scale * 336)
    new_h = int(new_w / ratio)

    img = transforms.functional.resize(
        img,
        [new_h, new_w],
    )
    img = padding_336(img)
    width, height = img.size
    assert width * height <= im_num * 336 * 336
    if trans:
        img = img.transpose(Image.TRANSPOSE)

    return img


def listinstr(lst, s):
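    # Return True if any string in `lst` occurs as a substring of `s`.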
    assert isinstance(lst, list)
    for item in lst:
        if item in s:
            return True
    return False


def DATASET_TYPE(dataset):
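    # Map a task/dataset name to a coarse answer type: multi-choice, Y/N, Caption, VQA, or generic QA.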
    # Dealing with Custom Dataset
    dataset = dataset.lower()
    if listinstr(["mmbench", "seedbench", "ccbench", "mmmu", "scienceqa", "ai2d", "mmstar"], dataset):
        return "multi-choice"
    elif listinstr(["mme", "hallusion"], dataset):
        return "Y/N"
    elif "coco" in dataset:
        return "Caption"
    elif listinstr(["ocrvqa", "textvqa", "chartqa", "mathvista", "docvqa", "infovqa", "llavabench", "mmvet", "ocrbench"], dataset):
        return "VQA"
    else:
        return "QA"