|
7 | 7 | from pathlib import PosixPath |
8 | 8 | from typing import Optional, Union |
9 | 9 |
|
| 10 | +import numpy as np |
| 11 | +import numpy.typing as npt |
10 | 12 | import regex as re |
11 | 13 | import torch |
12 | 14 | from PIL.Image import Image |
@@ -495,30 +497,74 @@ def __init__(self, hf_runner: HfRunner): |
495 | 497 | self.max_num = self.config.max_dynamic_patch |
496 | 498 | self.image_size = self.vision_config.image_size |
497 | 499 |
|
498 | | - def __call__(self, text: str, images: Union[Image, list[Image]], |
499 | | - **kwargs): |
| 500 | + def __call__( |
| 501 | + self, |
| 502 | + text: str, |
| 503 | + images: Union[Image, list[Image]] = None, |
| 504 | + videos: Union[npt.NDArray, list[npt.NDArray]] = None, |
| 505 | + **kwargs, |
| 506 | + ): |
500 | 507 | from vllm.model_executor.models.internvl import ( |
501 | 508 | IMG_CONTEXT, IMG_END, IMG_START, |
502 | | - image_to_pixel_values_internvl) |
| 509 | + image_to_pixel_values_internvl, video_to_pixel_values_internvl) |
503 | 510 | images = [images] if isinstance(images, Image) else images |
504 | | - pixel_values = [ |
505 | | - image_to_pixel_values_internvl( |
506 | | - image, |
507 | | - input_size=self.image_size, |
508 | | - min_num=self.min_num, |
509 | | - max_num=self.max_num, |
510 | | - use_thumbnail=self.use_thumbnail, |
511 | | - ) for image in images |
512 | | - ] |
513 | | - num_patches_list = [ |
514 | | - pixel_value.shape[0] for pixel_value in pixel_values |
515 | | - ] |
| 511 | + videos = [videos] if isinstance(videos, np.ndarray) else videos |
| 512 | + if images is not None: |
| 513 | + pixel_values_images = [ |
| 514 | + image_to_pixel_values_internvl( |
| 515 | + image, |
| 516 | + input_size=self.image_size, |
| 517 | + min_num=self.min_num, |
| 518 | + max_num=self.max_num, |
| 519 | + use_thumbnail=self.use_thumbnail, |
| 520 | + ) for image in images |
| 521 | + ] |
| 522 | + num_patches_images = [ |
| 523 | + pixel_value.shape[0] for pixel_value in pixel_values_images |
| 524 | + ] |
| 525 | + else: |
| 526 | + pixel_values_images, num_patches_images = [], [] |
| 527 | + |
| 528 | + if videos is not None: |
| 529 | + pixel_values_videos = [ |
| 530 | + video_to_pixel_values_internvl( |
| 531 | + video, |
| 532 | + input_size=self.image_size, |
| 533 | + min_num=1, |
| 534 | + max_num=1, |
| 535 | + use_thumbnail=False, |
| 536 | + ) for video in videos |
| 537 | + ] |
| 538 | + num_patches_videos = [ |
| 539 | + pixel_value.shape[0] for pixel_value in pixel_values_videos |
| 540 | + ] |
| 541 | + else: |
| 542 | + pixel_values_videos, num_patches_videos = [], [] |
| 543 | + |
| 544 | + pixel_values = [] |
| 545 | + while ("<image>" in text) or ("<video>" in text): |
| 546 | + image_index = text.find("<image>") |
| 547 | + video_index = text.find("<video>") |
| 548 | + if image_index == -1 or (video_index > -1 |
| 549 | + and video_index < image_index): |
| 550 | + num_patches = num_patches_videos.pop(0) |
| 551 | + pixel_values.append(pixel_values_videos.pop(0)) |
| 552 | + context_tokens = IMG_START + \ |
| 553 | + IMG_CONTEXT * self.num_image_token + IMG_END |
| 554 | + video_tokens = ''.join([ |
| 555 | + f'Frame{i+1}: {context_tokens}' |
| 556 | + for i in range(num_patches) |
| 557 | + ]) |
| 558 | + text = text.replace('<video>', video_tokens, 1) |
| 559 | + else: |
| 560 | + num_patches = num_patches_images.pop(0) |
| 561 | + pixel_values.append(pixel_values_images.pop(0)) |
| 562 | + context_tokens = IMG_CONTEXT * self.num_image_token \ |
| 563 | + * num_patches |
| 564 | + image_tokens = IMG_START + context_tokens + IMG_END |
| 565 | + text = text.replace('<image>', image_tokens, 1) |
516 | 566 | pixel_values = torch.cat(pixel_values, dim=0) |
517 | | - for num_patches in num_patches_list: |
518 | | - context_tokens = IMG_CONTEXT * self.num_image_token \ |
519 | | - * num_patches |
520 | | - image_tokens = IMG_START + context_tokens + IMG_END |
521 | | - text = text.replace('<image>', image_tokens, 1) |
| 567 | + |
522 | 568 | prompt = self.tokenizer(text, return_tensors="pt") |
523 | 569 | prompt.update({"pixel_values": pixel_values}) |
524 | 570 | return prompt |
|
0 commit comments