diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index fedbae445f37..777c4274eea5 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -1086,9 +1086,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         image_size = config.force_image_size or config.vision_config.image_size
         patch_size = config.vision_config.patch_size
         self.patch_size = patch_size
-        self.num_image_token = int(
-            (image_size // patch_size) ** 2 * (config.downsample_ratio**2)
-        )
+        self.patch_tokens = (image_size // patch_size) ** 2
+        self.num_image_token = int(self.patch_tokens * (config.downsample_ratio**2))
         self.downsample_ratio = config.downsample_ratio
         self.ps_version = config.ps_version
 
@@ -1430,3 +1429,17 @@ def get_mm_mapping(self) -> MultiModelKeys:
             connector="mlp1",
             tower_model="vision_model",
         )
+
+    def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
+        if num_image_tokens <= 0 or self.num_image_token <= 0:
+            return 0
+
+        num_patches = num_image_tokens // self.num_image_token
+        return num_patches * (self.patch_tokens + 1)
+
+    def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
+        if num_vision_tokens <= 0 or self.num_image_token <= 0:
+            return 0
+
+        num_patches = num_vision_tokens // (self.patch_tokens + 1)
+        return num_patches * self.num_image_token