
Commit 662f05c

Comment out parse result in xcomposer
1 parent 0932932 commit 662f05c

File tree

1 file changed: +295 -0 lines changed


Diff for: lmms_eval/models/xcomposer2_4KHD.py

@@ -0,0 +1,295 @@
from multiprocessing import context
import torch
from transformers import AutoModel, AutoTokenizer
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from datetime import timedelta
import logging

from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
from lmms_eval.utils import stop_sequences_criteria

from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs
from accelerate.state import AcceleratorState

from typing import Optional, Sequence, List, Tuple, Union
import re
from tqdm import tqdm

pattern = re.compile(r"[A-Z]")

eval_logger = logging.getLogger("lmms-eval")

meta_instruction = """You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).
- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language model that is developed\
by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.
- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the language chosen by\
the user such as English and 中文.
- InternLM-XComposer (浦语·灵笔) is capable of comprehending and articulating responses\
effectively based on the provided image."""

@register_model("xcomposer2_4khd")
class XComposer2_4KHD(lmms):
    def __init__(
        self,
        pretrained: str = "internlm/internlm-xcomposer2-4khd-7b",
        device: Optional[str] = "cuda:0",
        batch_size: Optional[Union[int, str]] = 1,
        device_map="cuda:0",
        need_bos: bool = True,
        padding: bool = False,
        half: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()

        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
        if accelerator.num_processes > 1:
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"
        elif accelerator.num_processes == 1 and device_map == "auto":
            self._device = torch.device(device)
            self.device_map = device_map
        else:
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"

        self.pretrained = pretrained
        self.need_bos = need_bos
        self.padding = padding
        self._model = AutoModel.from_pretrained(self.pretrained, device_map=self.device_map, trust_remote_code=True)
        self._tokenizer = AutoTokenizer.from_pretrained(self.pretrained, trust_remote_code=True)
        self.model.tokenizer = self.tokenizer
        self.batch_size_per_gpu = batch_size

        if accelerator.num_processes > 1:
            assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
            # If you want to use DistributedType.DEEPSPEED, you have to run `accelerate config` before using the model.
            # Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepared model work.
            # Setting other parameters in the kwargs to make the default zero stage 2 work did not succeed.
            if accelerator.distributed_type == DistributedType.DEEPSPEED:
                kwargs = {
                    "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
                    "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
                }
                AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
                eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
            if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        elif accelerator.num_processes == 1 and device_map == "auto":
            eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism")
            self._rank = 0
            self._world_size = 1
        else:
            eval_logger.info(f"Using single device: {self._device}")
            self.model.to(self._device)
            self._rank = 0
            self._world_size = 1

    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def flatten(self, input):
        new_list = []
        for i in input:
            for j in i:
                new_list.append(j)
        return new_list

    def generate_until(self, requests) -> List[str]:
        res = []
        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")

        for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
            # encode, pad, and truncate contexts for this batch
            if "[UNUSED_TOKEN_146]" not in contexts:
                contexts = f"[UNUSED_TOKEN_146]user\n{contexts}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n"
            visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
            visuals = self.flatten(visuals)

if "hd_num" not in gen_kwargs:
154+
if listinstr(["docvqa_test", "infovqa_test"], task.lower()):
155+
self.model.hd_num = 65
156+
elif listinstr(["docvqa_val", "infovqa_val", "OCRBench"], task.lower()):
157+
self.model.hd_num = 55
158+
elif listinstr(["mmmu", "mmbench", "mmvet"], task.lower()):
159+
self.model.hd_num = 16
160+
else:
161+
self.model.hd_num = 25
162+
else:
163+
self.model.hd_num = gen_kwargs.pop("hd_num")
            pt1 = 0
            embeds = []
            im_mask = []
            images_loc = [0]
            need_bos = self.need_bos
            padding = self.padding
            for i, pts in enumerate(images_loc + [len(contexts)]):
                subtext = contexts[pt1:pts]
                if need_bos or len(subtext) > 0:
                    text_embeds = self.model.encode_text(subtext, add_special_tokens=need_bos).to(self.device)
                    embeds.append(text_embeds)
                    im_mask.append(torch.zeros(text_embeds.shape[:2]).to(self.device))
                    need_bos = False
                if i < len(visuals):
                    image = visuals[i]

                    image = HD_transform(image, im_num=self.model.hd_num)
                    image = self.model.vis_processor(image).unsqueeze(0).to(self.device)
                    image_embeds = self.model.encode_img(image)
                    embeds.append(image_embeds)
                    im_mask.append(torch.ones(image_embeds.shape[:2]).to(self.device))
                pt1 = pts
            embeds = torch.cat(embeds, dim=1)
            im_mask = torch.cat(im_mask, dim=1)
            im_mask = im_mask.bool()

            if "max_new_tokens" not in gen_kwargs:
                gen_kwargs["max_new_tokens"] = 1024
            if "temperature" not in gen_kwargs:
                gen_kwargs["temperature"] = 0
            if "top_p" not in gen_kwargs:
                gen_kwargs["top_p"] = None
            if "num_beams" not in gen_kwargs:
                gen_kwargs["num_beams"] = 1
            if "do_sample" not in gen_kwargs:
                gen_kwargs["do_sample"] = False
            if "repetition_penalty" not in gen_kwargs:
                gen_kwargs["repetition_penalty"] = 1.0

            outputs = self.model.generate(
                inputs_embeds=embeds,
                im_mask=im_mask,
                temperature=gen_kwargs["temperature"],
                max_new_tokens=gen_kwargs["max_new_tokens"],
                num_beams=gen_kwargs["num_beams"],
                do_sample=gen_kwargs["do_sample"],
                repetition_penalty=gen_kwargs["repetition_penalty"],
            )
            output_token = outputs[0]
            if output_token[0] == 0 or output_token[0] == 1:
                output_token = output_token[1:]
            output_text = self.model.tokenizer.decode(output_token, add_special_tokens=False)
            output_text = output_text.split("[UNUSED_TOKEN_145]")[0].strip()
            output_text = output_text.split("<|im_end|>")[0].strip()
            # if DATASET_TYPE(task) == "multi-choice":
            #     output_text = pattern.findall(output_text)
            #     if len(output_text) == 0:
            #         print("Error:", output_text)
            #         output_text = "Z"
            #     if type(output_text) == list:
            #         output_text = output_text[0]
            res.append(output_text)
            pbar.update(1)
        pbar.close()
        return res

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        return super().loglikelihood(requests)


def padding_336(b):
    width, height = b.size
    tar = int(np.ceil(height / 336) * 336)
    top_padding = int((tar - height) / 2)
    bottom_padding = tar - height - top_padding
    left_padding = 0
    right_padding = 0
    b = transforms.functional.pad(b, [left_padding, top_padding, right_padding, bottom_padding], fill=[255, 255, 255])

    return b


def HD_transform(img, im_num=16):
    width, height = img.size
    trans = False
    if width < height:
        img = img.transpose(Image.TRANSPOSE)
        trans = True
        width, height = img.size
    ratio = width / height
    scale = 1
    while scale * np.ceil(scale / ratio) <= im_num:
        scale += 1
    scale -= 1
    new_w = int(scale * 336)
    new_h = int(new_w / ratio)

    img = transforms.functional.resize(
        img,
        [new_h, new_w],
    )
    img = padding_336(img)
    width, height = img.size
    assert width * height <= im_num * 336 * 336
    if trans:
        img = img.transpose(Image.TRANSPOSE)

    return img


def listinstr(lst, s):
    assert isinstance(lst, list)
    for item in lst:
        if item in s:
            return True
    return False


def DATASET_TYPE(dataset):
    # Dealing with Custom Dataset
    dataset = dataset.lower()
    if listinstr(["mmbench", "seedbench", "ccbench", "mmmu", "scienceqa", "ai2d", "mmstar"], dataset):
        return "multi-choice"
    elif listinstr(["mme", "hallusion"], dataset):
        return "Y/N"
    elif "coco" in dataset:
        return "Caption"
    elif listinstr(["ocrvqa", "textvqa", "chartqa", "mathvista", "docvqa", "infovqa", "llavabench", "mmvet", "ocrbench"], dataset):
        return "VQA"
    else:
        return "QA"
