From 685962958754f486c126e91f2aebd11a552701f2 Mon Sep 17 00:00:00 2001 From: effortprogrammer Date: Sat, 10 Jan 2026 13:12:21 +0000 Subject: [PATCH 01/18] fix(hyperclovax): support multimodal prompt handling Signed-off-by: effortprogrammer --- tests/entrypoints/test_chat_utils.py | 29 + vllm/entrypoints/chat_utils.py | 12 +- .../models/hyperclovax_vision.py | 1219 ++++++++--------- vllm/model_executor/models/registry.py | 1 + 4 files changed, 571 insertions(+), 690 deletions(-) diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 6df2d26f2f0d..a94da2b89bd2 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -1434,6 +1434,35 @@ def test_parse_chat_messages_context_text_format( assert mm_uuids is None +def test_parse_chat_messages_openai_format_image_url( + phi3v_model_config, + image_url, +): + content = [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What's in the image?"}, + ] + conversation, mm_data, mm_uuids = parse_chat_messages( + [ + { + "role": "user", + "content": content, + } + ], + phi3v_model_config, + content_format="openai", + ) + + assert conversation == [ + { + "role": "user", + "content": content, + } + ] + _assert_mm_data_is_image_input(mm_data, 1) + _assert_mm_uuids(mm_uuids, 1, expected_uuids=[None]) + + def test_parse_chat_messages_rejects_too_many_images_in_one_message( phi3v_model_config, image_url, diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 5e31f60ad0ca..1caf2dbf2539 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -1519,6 +1519,8 @@ def _parse_chat_message_content_part( with multimodal placeholders. """ if isinstance(part, str): # Handle plain text parts + if wrap_dicts: + return {"type": "text", "text": part} return part # Handle structured dictionary parts part_type, content = _parse_chat_message_content_mm_part(part) @@ -1578,11 +1580,11 @@ def _parse_chat_message_content_part( else: raise NotImplementedError(f"Unknown part type: {part_type}") - return ( - {"type": modality} - if wrap_dicts - else (MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None) - ) + if wrap_dicts: + if isinstance(part, dict): + return dict(part) + return {"type": "text", "text": str(part)} + return MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None # No need to validate using Pydantic again diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py index f5226baba5da..b1d99ba88548 100644 --- a/vllm/model_executor/models/hyperclovax_vision.py +++ b/vllm/model_executor/models/hyperclovax_vision.py @@ -1,25 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # copied from : https://github.com/huggingface/transformers -import ast -from collections import defaultdict from collections.abc import Iterable, Mapping, Sequence from functools import partial from itertools import accumulate -from typing import Annotated, Literal +from typing import Annotated, Any, Literal import numpy as np import torch import torch.nn as nn -from einops import rearrange -from timm.layers import LayerNorm, LayerNorm2d -from timm.models.regnet import RegStage -from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig -from transformers.modeling_utils import no_init_weights +from transformers import BatchFeature from vllm.config import VllmConfig -from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig -from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.forward_context import set_forward_context from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import ( @@ -39,12 +33,11 @@ from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .siglip import SiglipVisionModel +from .qwen2_5_vl import Qwen2_5_VisionTransformer from .utils import ( AutoWeightsLoader, - flatten_bn, + WeightsMapper, init_vllm_registered_model, maybe_prefix, ) @@ -54,6 +47,13 @@ IMAGE_TOKEN: str = "<|dummy3|>" VIDEO_TOKEN: str = "<|_unuse_missing_100270|>" +# V2 (32B Think model) uses different tokens - retrieved from config at runtime +# These placeholder strings must match the chat template format exactly. +# The chat template produces: <|image_start|><|IMAGE_PAD|><|image_end|> +# Similar to Qwen2-VL's <|vision_start|><|image_pad|><|vision_end|> format. +V2_IMAGE_TOKEN: str = "<|image_start|><|IMAGE_PAD|><|image_end|>" +V2_VIDEO_TOKEN: str = "<|video_start|><|VIDEO_PAD|><|video_end|>" + # Based on combine_frames_into_images in # https://huggingface.co/naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B/blob/main/processing_hyperclovax.py @@ -327,7 +327,7 @@ def _get_mm_fields_config( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return dict( + fields = dict( pixel_values_images=MultiModalFieldConfig.batched("image"), image_sizes_images=MultiModalFieldConfig.batched("image"), vision_query_lengths_images=MultiModalFieldConfig.batched("image"), @@ -335,308 +335,489 @@ def _get_mm_fields_config( vision_query_lengths_videos=MultiModalFieldConfig.batched("video"), ) + return fields -def _build_hcxvision_hf_info( - ctx: InputProcessingContext, -) -> HCXVisionProcessingInfo: - return HCXVisionProcessingInfo(ctx) +# ============================================================================= +# HyperCLOVAX V2 (32B Think Model) Support +# Uses Qwen2.5 Vision Transformer instead of CLIP/SigLIP +# ============================================================================= -def _build_hcxvision_hf_processor( - info: HCXVisionProcessingInfo, - dummy_inputs: BaseDummyInputsBuilder[HCXVisionProcessingInfo], - *, - cache: BaseMultiModalProcessorCache | None = None, -) -> BaseMultiModalProcessor: - if isinstance(info, HCXVisionProcessingInfo): - return HCXVisionMultiModalProcessor( - info, - dummy_inputs, # type: ignore - cache=cache, - ) - raise NotImplementedError(type(info)) +class HCXVisionV2ImagePixelInputs(TensorSchema): + """ + V2 Image inputs using Qwen2.5-VL style grid_thw format. + Dimensions: + - np: Number of patches + - ni: Number of images + - cps: Number of channels * patch_size * patch_size + """ -def init_vision_tower_for_hcxvision( - vision_config, - quant_config: QuantizationConfig | None, - multimodal_config: MultiModalConfig | None, - *, - use_nth_layer: int | None = None, - require_post_norm: bool | None = None, - prefix: str = "", -) -> CLIPVisionModel | SiglipVisionModel: - num_hidden_layers = vision_config.num_hidden_layers - if not isinstance(use_nth_layer, int): - pass - elif use_nth_layer >= 0: - num_hidden_layers = use_nth_layer + 1 - else: - num_hidden_layers = num_hidden_layers + use_nth_layer + 1 - - if isinstance(vision_config, CLIPVisionConfig): - return CLIPVisionModel( - vision_config, - quant_config=quant_config, - multimodal_config=multimodal_config, - num_hidden_layers_override=num_hidden_layers, - require_post_norm=require_post_norm, - prefix=prefix, - ) - elif isinstance(vision_config, SiglipVisionConfig): - return SiglipVisionModel( - vision_config, - quant_config=quant_config, - multimodal_config=multimodal_config, - num_hidden_layers_override=num_hidden_layers, - require_post_norm=require_post_norm, - prefix=prefix, - ) + type: Literal["pixel_values"] = "pixel_values" + pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")] + image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) +class HCXVisionV2ImageEmbeddingInputs(TensorSchema): + """ + V2 Image embedding inputs. -class HCXVisionMlp(nn.Module): - def __init__( - self, - mm_projector_type, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - ): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.mm_projector_type = mm_projector_type - if self.mm_projector_type == "mlp": - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - elif self.mm_projector_type == "inverted_mlp": - self.fc1 = nn.Linear(in_features, 2 * hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(2 * hidden_features, out_features) - else: - raise NotImplementedError( - "{} is not implemented".format(self.mm_projector_type) - ) + Dimensions: + - nf: Number of image features + - hs: Hidden size + - ni: Number of images + """ + + type: Literal["image_embeds"] = "image_embeds" + image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")] + image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] + + +HCXVisionV2ImageInputs = HCXVisionV2ImagePixelInputs | HCXVisionV2ImageEmbeddingInputs - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.fc2(x) - return x + +class HCXVisionV2VideoPixelInputs(TensorSchema): + """ + V2 Video inputs using Qwen2.5-VL style grid_thw format. + + Dimensions: + - np: Number of patches + - nv: Number of videos + - ctps: Number of channels * temporal_patch_size * patch_size * patch_size + """ + + type: Literal["pixel_values_videos"] = "pixel_values_videos" + pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctps")] + video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)] -class HCXVisionCAbstractor(nn.Module): +class HCXVisionV2VideoEmbeddingInputs(TensorSchema): """ - This module is based on C-Abstractor, whose license is under apache-2.0. - You can check the original code at - https://github.com/khanrc/honeybee/blob/main/honeybee/projectors/projectors.py - and we made necessary modifications. + V2 Video embedding inputs. + + Dimensions: + - nf: Number of video features + - hs: Hidden size + - nv: Number of videos """ - def __init__( + type: Literal["video_embeds"] = "video_embeds" + video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")] + video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)] + + +HCXVisionV2VideoInputs = HCXVisionV2VideoPixelInputs | HCXVisionV2VideoEmbeddingInputs + + +class HCXVisionV2ProcessingInfo(BaseProcessingInfo): + """Processing info for HyperCLOVAX V2 (32B Think model).""" + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"image": None, "video": None} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + patch_size = vision_config.patch_size + spatial_merge_size = vision_config.spatial_merge_size + + grid_h = image_height // patch_size + grid_w = image_width // patch_size + + return (grid_h * grid_w) // (spatial_merge_size**2) + + def get_num_video_tokens( self, - num_queries: int, - num_input_tokens: int, - encoder_hidden_size: int, - hidden_size: int, - output_hidden_size: int, - pos_emb: bool = True, - prenorm: bool = False, + *, + video_width: int, + video_height: int, + num_frames: int, + ) -> int: + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + patch_size = vision_config.patch_size + temporal_patch_size = vision_config.temporal_patch_size + spatial_merge_size = vision_config.spatial_merge_size + + grid_t = num_frames // temporal_patch_size + grid_h = video_height // patch_size + grid_w = video_width // patch_size + + return (grid_t * grid_h * grid_w) // (spatial_merge_size**2) + + def get_image_size_with_most_features(self) -> ImageSize: + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + # Use a reasonable default size + size = getattr(vision_config, "image_size", 448) + return ImageSize(width=size, height=size) + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + ) + + +class HCXVisionV2DummyInputsBuilder(BaseDummyInputsBuilder[HCXVisionV2ProcessingInfo]): + """Dummy inputs builder for HyperCLOVAX V2 memory profiling.""" + + def get_dummy_text( + self, + mm_counts: Mapping[str, int], + ) -> str: + # This method is not used when get_dummy_processor_inputs is overridden, + # but we keep it for compatibility. + return "" + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, object] | None = None, ): - super().__init__() - self.num_input_tokens = num_input_tokens - self.output_hidden_size = output_hidden_size + """ + Override to use token IDs directly instead of text strings. - # Positional embedding - if pos_emb: - self.pos_emb = torch.nn.Parameter( - torch.zeros(1, num_input_tokens, encoder_hidden_size) - ) - self.pos_emb.data.normal_(mean=0.0, std=0.02) - else: - self.pos_emb = None + This avoids the tokenizer issue where <|IMAGE_PAD|> might not be + recognized as a special token and gets split into multiple tokens. + By passing token IDs directly, we ensure the correct token (128060) + is used for prompt replacement matching. + """ + from vllm.multimodal.profiling import ProcessorInputs - # (Optional) Pre-normalization layer - if prenorm: - self.prenorm = LayerNorm(encoder_hidden_size) - else: - self.prenorm = None + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + hf_config = self.info.get_hf_config() + + # Use token IDs directly to avoid tokenizer issues with special tokens + image_token_id = hf_config.image_token_id # 128060 + video_token_id = hf_config.video_token_id # 128061 - self.build_net( - num_queries, encoder_hidden_size, hidden_size, output_hidden_size + # Create prompt as token ID list instead of text string + prompt_ids: list[int] = [image_token_id] * num_images + [ + video_token_id + ] * num_videos + + dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options) + + return ProcessorInputs( + prompt=prompt_ids, + mm_data=dummy_mm_data, + tokenization_kwargs={"truncation": False}, ) - self.dtype = next(self.parameters()).dtype - def forward( + def get_dummy_mm_data( self, - x: torch.Tensor, - num_queries_vis_abstractors: list[list[int]] | None = None, - num_grids: list[int] | None = None, - ) -> torch.Tensor: - if self.prenorm is not None: - x = self.prenorm(x) + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) - if self.pos_emb is not None: - x = x + self.pos_emb + target_width, target_height = self.info.get_image_size_with_most_features() + target_num_frames = 16 # Default for video + + image_overrides = mm_options.get("image") if mm_options else None + video_overrides = mm_options.get("video") if mm_options else None + + result: MultiModalDataDict = { + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, # type: ignore + ), + "video": self._get_dummy_videos( + width=target_width, + height=target_height, + num_frames=target_num_frames, + num_videos=num_videos, + overrides=video_overrides, # type: ignore + ), + } + + return result - x = self._forward( - x, - num_queries_vis_abstractors=num_queries_vis_abstractors, - num_grids=num_grids, - ) # (B, L, output_hidden_size) - return x +class HCXVisionV2MultiModalProcessor( + BaseMultiModalProcessor[HCXVisionV2ProcessingInfo] +): + """Multimodal processor for HyperCLOVAX V2 (32B Think model).""" - def _forward( + def _call_hf_processor( self, - x: torch.Tensor, - num_queries_vis_abstractors: list[list[int]] | None = None, - num_grids: list[int] | None = None, - ) -> torch.Tensor: - # x: [B, L, dim] - B, L, dim = x.shape - hw = int(L**0.5) - x = rearrange(x, "b (h w) d -> b d h w", h=hw, w=hw) - - if num_queries_vis_abstractors is not None: - assert num_grids is not None - return self._forward_adaptive_num_query( - x, num_queries_vis_abstractors, num_grids - ) + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + images = mm_data.get("images") + videos = mm_data.get("videos") + + # Get the HF processor + hf_processor = self.info.get_hf_processor(**mm_kwargs) + + # Build data dict for HF processor (images/videos only) + # The HF processor (HCXVisionV2Processor) doesn't support audio + # NOTE: We pass the prompt as-is without token normalization. + # Token expansion is handled by vLLM via _get_prompt_updates since + # _hf_processor_applies_updates returns False. + data: dict[str, object] = dict( + text=prompt, + images=images, + videos=videos, + ) - x = self.net(x) - x = rearrange(x, "b d h w -> b (h w) d") - x = self.readout(x) - return x + processed_outputs = self.info.ctx.call_hf_processor( + hf_processor=hf_processor, + data=data, + ) + + return processed_outputs - def _forward_adaptive_num_query( + def _hf_processor_applies_updates( self, - x: torch.Tensor, - num_queries_vis_abstractors: list[list[int]] | None = None, - num_grids: list[int] | None = None, - ) -> list[torch.Tensor]: - # self.net is consisted by 3 layers (s1, sampler, s2) - assert len(self.net) == 3 - - x = self.net[0](x) # s1 - new_x = [] - for i, num_queries in enumerate(num_queries_vis_abstractors): - hw = int(num_queries**0.5) - sampler = nn.AdaptiveAvgPool2d((hw, hw)) - out = sampler(x[num_grids[i] : num_grids[i + 1], :]) - out = self.net[2](out) # s2 - - out = rearrange(out, "b d h w -> b (h w) d") - out = self.readout(out) - - new_x.append(out) - return new_x - - def build_net( + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + # HyperCLOVAX V2 has a token case mismatch bug: + # - Chat template uses <|IMAGE_PAD|> (uppercase) + # - HF processor (Qwen2_5_VLProcessor) expects <|image_pad|> (lowercase) + # - Tokenizer vocab has <|IMAGE_PAD|> (uppercase) = token ID 128060 + # + # The HF processor's token expansion fails because it looks for lowercase + # but the tokenized prompt has uppercase tokens. We bypass HF processor's + # expansion and let vLLM handle it via _get_prompt_updates using the + # correct token IDs from hf_config. + return False + + def _get_prompt_updates( self, - n_queries: int, - encoder_hidden_size: int, - hidden_size: int, - output_hidden_size: int, - depth: int = 3, - mlp_depth: int = 2, - ): - assert (n_queries**0.5).is_integer(), ( - f"n_queries must be square number. n_queries: {n_queries}" - ) - hw = int(n_queries**0.5) - - # RegBlock = ResBlock + SE - RegBlock = partial( - RegStage, - stride=1, - dilation=1, - act_layer=nn.SiLU, - norm_layer=LayerNorm2d, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_config = self.info.get_hf_config() + + # Use token IDs directly from config. + # This matches what get_dummy_processor_inputs uses, ensuring consistency. + placeholder: dict[str, int] = { + "image": hf_config.image_token_id, # 128060 for <|IMAGE_PAD|> + "video": hf_config.video_token_id, # 128061 for <|VIDEO_PAD|> + } + + merge_size = hf_config.vision_config.spatial_merge_size + + def get_replacement_v2( + item_idx: int, + modality: str, + out_mm_kwargs: MultiModalKwargsItems, + ): + out_item = out_mm_kwargs[modality][item_idx] + + if modality == "image": + grid_thw_elem = out_item.get("image_grid_thw") + if grid_thw_elem is not None: + # Access .data to get the actual tensor from MultiModalFieldElem + grid_thw = grid_thw_elem.data + # Qwen2.5-VL style calculation + h, w = grid_thw[1].item(), grid_thw[2].item() + num_tokens = (h * w) // (merge_size**2) + else: + # Fallback or error + raise ValueError("Missing image_grid_thw for V2 model") + elif modality == "video": + grid_thw_elem = out_item.get("video_grid_thw") + if grid_thw_elem is not None: + # Access .data to get the actual tensor from MultiModalFieldElem + grid_thw = grid_thw_elem.data + t, h, w = grid_thw[0].item(), grid_thw[1].item(), grid_thw[2].item() + num_tokens = (t * h * w) // (merge_size**2) + else: + raise ValueError("Missing video_grid_thw for V2 model") + else: + raise NotImplementedError(modality) + + return [placeholder[modality]] * num_tokens + + return [ + PromptReplacement( + modality=modality, + target=[ + placeholder[modality], + ], + replacement=partial( + get_replacement_v2, + modality=modality, + out_mm_kwargs=out_mm_kwargs, + ), + ) + for modality in ("image", "video") + ] + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + # HyperCLOVAX V2 uses Qwen2.5-VL style flattened pixel values where + # pixel_values has shape (num_patches, channels*patch_size*patch_size) + # while image_grid_thw has shape (num_images, 3). + # We need to use flat_from_sizes to correctly handle this mismatch. + hf_config = self.info.get_hf_config() + spatial_merge_size = hf_config.vision_config.spatial_merge_size + + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_pixel_grid_sizes = image_grid_thw.prod(-1) + image_embed_grid_sizes = ( + image_pixel_grid_sizes // spatial_merge_size // spatial_merge_size ) - s1 = RegBlock( - depth, - encoder_hidden_size, - hidden_size, + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_pixel_grid_sizes = video_grid_thw.prod(-1) + video_embed_grid_sizes = ( + video_pixel_grid_sizes // spatial_merge_size // spatial_merge_size ) - sampler = nn.AdaptiveAvgPool2d((hw, hw)) - s2 = RegBlock( - depth, - hidden_size, - hidden_size, + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_pixel_grid_sizes + ), + image_embeds=MultiModalFieldConfig.flat_from_sizes( + "image", image_embed_grid_sizes + ), + image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_pixel_grid_sizes + ), + video_embeds=MultiModalFieldConfig.flat_from_sizes( + "video", video_embed_grid_sizes + ), + video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True), ) - self.net = nn.Sequential(s1, sampler, s2) - self.readout = self.build_mlp(mlp_depth, hidden_size, output_hidden_size) - def build_mlp( - self, - depth: int, - hidden_size: int, - output_hidden_size: int, - ): - layers = [nn.Linear(hidden_size, output_hidden_size)] - for _ in range(1, depth): - layers.append(nn.SiLU()) - layers.append(nn.Linear(output_hidden_size, output_hidden_size)) - return nn.Sequential(*layers) +def _build_hcxvision_v2_hf_info( + ctx: InputProcessingContext, +) -> HCXVisionV2ProcessingInfo: + return HCXVisionV2ProcessingInfo(ctx) + + +def _build_hcxvision_v2_hf_processor( + info: HCXVisionV2ProcessingInfo, + dummy_inputs: BaseDummyInputsBuilder[HCXVisionV2ProcessingInfo], + *, + cache: BaseMultiModalProcessorCache | None = None, +) -> BaseMultiModalProcessor: + return HCXVisionV2MultiModalProcessor( + info, + dummy_inputs, # type: ignore + cache=cache, + ) @MULTIMODAL_REGISTRY.register_processor( - _build_hcxvision_hf_processor, - info=_build_hcxvision_hf_info, - dummy_inputs=HCXVisionDummyInputsBuilder, + _build_hcxvision_v2_hf_processor, + info=_build_hcxvision_v2_hf_info, + dummy_inputs=HCXVisionV2DummyInputsBuilder, ) -class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): +class HCXVisionV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): + """ + HyperCLOVAX-SEED Vision-Language Model (V2 architecture). + + Supports: + - HyperCLOVAX-SEED-Think-32B: Vision + Text + - HyperCLOVAX-SEED-Omni-8B: Vision + Audio + Text + + Uses Qwen2.5 Vision Transformer as the vision encoder. + """ + packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], + "qkv": ["qkv"], # For vision tower } - def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + # Weight mapping for loading HuggingFace checkpoints + # NOTE: Order matters! Ignores (None) should come before renames to prevent + # partial matches + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.": "", # Remove model. prefix if present + "vision_model.": "visual.", # HF uses vision_model, we use visual + }, + orig_to_new_substr={ + # Ignore modules not implemented in vLLM + "discrete_vision_model": None, # TextAlignedTokenizer + }, + ) + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + **kwargs: Any | None, + ) -> None: super().__init__() - # init configs config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config - # text_config + + # Text config text_config = config.text_config if text_config.model_type in ["gpt2", "hyperclovax", "llama"]: text_config._attn_implementation = "sdpa" if text_config.model_type != "hyperclovax": text_config.logits_scaling = 1.0 - # vision_config + + # Vision config vision_config = config.vision_config - vision_config.auto_map = {} - vision_config.anyres = config.anyres - vision_config.max_num_grids = config.max_num_grids + + self.config = config + self.vision_config = vision_config + self.text_config = text_config + self.vllm_config = vllm_config self.dtype = vllm_config.model_config.dtype - ## possible_resolution should be matched with preprocessor_config.json - config.possible_resolutions = self._init_possible_resolutions( - config, vision_config + # Initialize Qwen2.5 Vision Transformer + self.visual = Qwen2_5_VisionTransformer( + vision_config=vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + multimodal_config=multimodal_config, + prefix=maybe_prefix(prefix, "visual"), ) - # init models & parameters - with no_init_weights(): # weight will be loaded in from_pretrained - self.vision_model = init_vision_tower_for_hcxvision( - vision_config, - quant_config=quant_config, - multimodal_config=multimodal_config, - use_nth_layer=getattr(config, "use_nth_layer", -1), - require_post_norm=False, - prefix=maybe_prefix(prefix, "vision_model"), - ) - self.mm_projector = self._init_mm_projector(config, text_config, vision_config) + # Linear projector (vision_hidden_size -> text_hidden_size) + # For V2 model: mm_projector_type is "linear" + vision_hidden_size = vision_config.hidden_size + text_hidden_size = text_config.hidden_size + # Check if out_hidden_size is defined (Qwen2.5-VL style) + # The merger in Qwen2.5 VisionTransformer handles projection to out_hidden_size + if hasattr(vision_config, "out_hidden_size"): + out_hidden = vision_config.out_hidden_size + else: + out_hidden = vision_hidden_size + + # Always create Linear projector since HF checkpoint has mm_projector weights + self.mm_projector = nn.Linear(out_hidden, text_hidden_size) + + # Language model self.lm_head_vocab_size = getattr( text_config, "padded_vocab_size", text_config.vocab_size ) @@ -646,83 +827,131 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: prefix=maybe_prefix(prefix, "language_model"), ) - if config.anyres: - self.image_newline = nn.Parameter( - torch.empty(text_config.hidden_size, dtype=self.dtype) - ) - - self.config = config - self.vision_config = vision_config - self.text_config = text_config - - # use_sum_loss = bool(kwargs.pop("use_sum_loss", False)) - # self.reduction = self._init_reduction_type(use_sum_loss) + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): - return IMAGE_TOKEN + return V2_IMAGE_TOKEN if modality.startswith("video"): - return VIDEO_TOKEN + return V2_VIDEO_TOKEN raise ValueError("Only image or video modality is supported") def _parse_and_validate_image_input( self, **kwargs: object, - ) -> HCXVisionImageInputs | None: - pixel_values_images = kwargs.pop("pixel_values_images", None) + ) -> HCXVisionV2ImageInputs | None: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) - if pixel_values_images is None: + if pixel_values is None and image_embeds is None: return None - image_sizes_images = kwargs.pop("image_sizes_images") + if pixel_values is not None: + return HCXVisionV2ImagePixelInputs( + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) - return HCXVisionImagePixelInputs( - pixel_values_images=pixel_values_images, - image_sizes_images=image_sizes_images, - ) + if image_embeds is not None: + return HCXVisionV2ImageEmbeddingInputs( + image_embeds=image_embeds, + image_grid_thw=image_grid_thw, + ) + + return None def _parse_and_validate_video_input( self, **kwargs: object, - ) -> HCXVisionVideoInputs | None: + ) -> HCXVisionV2VideoInputs | None: pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_embeds = kwargs.pop("video_embeds", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) - if pixel_values_videos is None: + if pixel_values_videos is None and video_embeds is None: return None - return HCXVisionVideoPixelInputs( - pixel_values_videos=pixel_values_videos, - ) + if pixel_values_videos is not None: + return HCXVisionV2VideoPixelInputs( + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + ) + + if video_embeds is not None: + return HCXVisionV2VideoEmbeddingInputs( + video_embeds=video_embeds, + video_grid_thw=video_grid_thw, + ) + + return None def _process_image_input( self, - image_input: HCXVisionImageInputs, + image_input: HCXVisionV2ImageInputs, ) -> tuple[torch.Tensor, ...]: - return self.forward_images( - pixel_values_images=image_input["pixel_values_images"], - image_sizes_images=image_input["image_sizes_images"], - ) + """Process images through Qwen2.5 ViT and projector.""" + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() + + if image_input["type"] == "image_embeds": + image_embeds = image_input["image_embeds"].type(self.visual.dtype) + else: + pixel_values = image_input["pixel_values"] + with set_forward_context(None, self.vllm_config): + image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) + + # Apply projector + image_embeds = self.mm_projector(image_embeds) + + # Split concatenated embeddings for each image + merge_size = self.visual.spatial_merge_size + sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() + return image_embeds.split(sizes) def _process_video_input( self, - video_input: HCXVisionVideoInputs, + video_input: HCXVisionV2VideoInputs, ) -> tuple[torch.Tensor, ...]: - return self.forward_videos( - pixel_values_videos=video_input["pixel_values_videos"], - ) + """Process videos through Qwen2.5 ViT and projector.""" + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() + + if video_input["type"] == "video_embeds": + video_embeds = video_input["video_embeds"].type(self.visual.dtype) + else: + pixel_values_videos = video_input["pixel_values_videos"] + with set_forward_context(None, self.vllm_config): + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list) + + # Apply projector + video_embeds = self.mm_projector(video_embeds) + + # Split concatenated embeddings for each video + merge_size = self.visual.spatial_merge_size + sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() + return video_embeds.split(sizes) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: modalities = {} - # Preserve the order of modalities if there are multiple of them - # from the order of kwargs. for input_key in kwargs: - if input_key == "pixel_values_images" and "images" not in modalities: - modalities["images"] = self._parse_and_validate_image_input(**kwargs) - if input_key == "pixel_values_videos" and "videos" not in modalities: - modalities["videos"] = self._parse_and_validate_video_input(**kwargs) + if ( + input_key in ("pixel_values", "image_embeds") + and "image" not in modalities + ): + modalities["image"] = self._parse_and_validate_image_input(**kwargs) + if ( + input_key in ("pixel_values_videos", "video_embeds") + and "video" not in modalities + ): + modalities["video"] = self._parse_and_validate_video_input(**kwargs) return modalities @@ -737,21 +966,19 @@ def embed_multimodal( if not modalities: return [] - # The result multimodal_embeddings is tuple of tensors, with each - # tensor correspoending to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] = () - # NOTE: It is important to iterate over the keys in this dictionary - # to preserve the order of the modalities. for modality in modalities: - if modality == "images": - image_input = modalities["images"] - image_embeddings = self._process_image_input(image_input) - multimodal_embeddings += tuple(image_embeddings) - if modality == "videos": - video_input = modalities["videos"] - video_embeddings = self._process_video_input(video_input) - multimodal_embeddings += tuple(video_embeddings) + if modality == "image": + image_input = modalities["image"] + if image_input is not None: + image_embeddings = self._process_image_input(image_input) + multimodal_embeddings += tuple(image_embeddings) + if modality == "video": + video_input = modalities["video"] + if video_input is not None: + video_embeddings = self._process_video_input(video_input) + multimodal_embeddings += tuple(video_embeddings) return multimodal_embeddings @@ -771,175 +998,6 @@ def forward( ) return hidden_states - def forward_images( - self, - pixel_values_images: list[torch.Tensor], - image_sizes_images: torch.Tensor, - ) -> tuple[torch.Tensor, ...]: - pixel_values_image_flat = flatten_bn(pixel_values_images, concat=True) - - visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1 - image_forward_outs = self.vision_model(pixel_values_image_flat)[ - :, visual_token_idx: - ] - - image_forward_outs = image_forward_outs.to(dtype=self.mm_projector.dtype) - image_forward_outs = self.mm_projector(image_forward_outs) # b (h w) d - - split_sizes = [len(item) for item in pixel_values_images] - image_forward_outs = torch.split(image_forward_outs, split_sizes, dim=0) - - # newline for anyres postprocessing - image_features = anyres_postprocessing( - image_forward_outs=image_forward_outs, - image_sizes=image_sizes_images.tolist(), - num_queries_vis_abstractor=self.config.num_queries_vis_abstractor_image, - unpad=self.config.unpad, - patch_size=self.vision_config.patch_size, - grid_size=self.vision_config.image_size, - image_newline=self.image_newline, - possible_resolutions=self.config.possible_resolutions, - ) - - return tuple(image_features) - - def forward_videos( - self, - pixel_values_videos: list[list[torch.Tensor]], - ) -> tuple[torch.Tensor, ...]: - pixel_values_videos_flat = flatten_bn( - [frame for frames in pixel_values_videos for frame in frames], - concat=True, - ) - - visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1 - video_forward_outs = self.vision_model(pixel_values_videos_flat)[ - :, visual_token_idx: - ] - - video_forward_outs = video_forward_outs.to(dtype=self.mm_projector.dtype) - - # Run MM-Projector - # len(num_grids) == len(num_queries_vis_abstractors) + 1 - grid_idx = 0 - # e.g. [0, 9, 18, 19, 27, 28, 36, 37, 45, 46, 54, 55, 56] - num_grids = [grid_idx] - # e.g. [81, 81, 81, 9, 81, 9, 81, 9, 81, 9, 81, 9] - num_queries_vis_abstractors = [] - len_total_frames = video_forward_outs.shape[0] - - if self.config.first_last_frames_slow: - # slowfast (first_last_frames_slow) - assert len_total_frames != 0 - if len_total_frames <= 2: - num_queries_vis_abstractors.append( - self.config.num_queries_vis_abstractor_video_slow - ) - grid_idx += len_total_frames - num_grids.append(grid_idx) - else: - num_queries_vis_abstractors.append( - self.config.num_queries_vis_abstractor_video_slow - ) - grid_idx += 1 - num_grids.append(grid_idx) - - num_queries_vis_abstractors.append( - self.config.num_queries_vis_abstractor_video_fast - ) - grid_idx += len_total_frames - 2 - num_grids.append(grid_idx) - - num_queries_vis_abstractors.append( - self.config.num_queries_vis_abstractor_video_slow - ) - grid_idx += 1 - num_grids.append(grid_idx) - else: - # slowfast - for pixel_values_frames in pixel_values_videos: - for pixel_values_frame in pixel_values_frames: - if len(pixel_values_frame) > 0: - num_queries_vis_abstractors.append( - self.config.num_queries_vis_abstractor_video_slow - ) - grid_idx += 1 - num_grids.append(grid_idx) - num_queries_vis_abstractors.append( - self.config.num_queries_vis_abstractor_video_fast - ) - grid_idx = grid_idx + len(pixel_values_frame) - 1 - num_grids.append(grid_idx) - - video_forward_outs = self.mm_projector( - video_forward_outs, num_queries_vis_abstractors, num_grids - ) - - video_features = [] # what we want to return - target_features = [] - target_group_size = 0 - group_counter = 0 - video_groups = [ - len(frame) for frames in pixel_values_videos for frame in frames - ] # for concat video features after projector - - for forward_out in video_forward_outs: - target_group_size += len(forward_out) - target_features.append(forward_out.flatten(0, 1)) - - video_group_size = video_groups[group_counter] - if video_group_size == target_group_size: - video_features.append(torch.cat(target_features, dim=0)) - target_features = [] - group_counter += 1 - target_group_size = 0 - - elif video_group_size < target_group_size: - raise RuntimeError(f"{video_group_size=} < {target_group_size=}") - - assert len(target_features) == 0, ( - f"target_features is not empty!! {target_features}" - ) - assert len(video_groups) == len(video_features) - - feats_per_video = [len(video) for video in pixel_values_videos] - idxs_per_video = [0, *accumulate(feats_per_video)] - return tuple( - torch.cat(video_features[idxs_per_video[i] : idxs_per_video[i + 1]]) - for i in range(len(feats_per_video)) - ) - - def _prepare_multimodal_kwargs(self, **kwargs: object): - output = defaultdict(list) - for k, v in kwargs.items(): - if len(v) < 1 or len(v[0]) < 1: - continue # if empty batch of empty sample - - new_k, is_video = k, False - if not k.endswith("_images") and not k.endswith("_videos"): - pass - else: - new_k, is_video = k.split("_")[:-1], k.split("_")[-1] - new_k = "_".join(new_k) - is_video = is_video == "videos" - - for _sample_idx, _v in enumerate(v): # batch -> sample - if new_k not in ["pixel_values"]: - if len(output[new_k]) < _sample_idx + 1: - output[new_k].append(list()) - _v = _v.detach().cpu().numpy().tolist() - output[new_k][_sample_idx] += _v - elif isinstance(_v, torch.Tensor): - if len(output[new_k]) < _sample_idx + 1: - output[new_k].append(list()) - output["is_videos"].append(list()) - _v = list(torch.unbind(_v, dim=0)) - output[new_k][_sample_idx] += _v - output["is_videos"][_sample_idx] += [ - is_video, - ] * len(_v) - return dict(output) - def compute_logits( self, hidden_states: torch.Tensor, @@ -951,213 +1009,4 @@ def load_weights( weights: Iterable[tuple[str, torch.Tensor]], ) -> set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) - - def _init_possible_resolutions( - self, - config, - vision_config, - ): - if not getattr(config, "possible_resolutions", []): - possible_resolutions = [] - if config.anyres: - assert config.max_num_grids > 0 - for i in range(1, config.max_num_grids + 1): - for j in range(1, config.max_num_grids + 1): - if i == 1 and j == 1 and not config.use_1x1_grid: - continue - if i * j <= config.max_num_grids: - possible_resolutions.append([i, j]) - - possible_resolutions = [ - [ys * vision_config.image_size, xs * vision_config.image_size] - for ys, xs in possible_resolutions - ] - return possible_resolutions - else: - return config.possible_resolutions - - def _init_mm_projector( - self, - config, - text_config, - vision_config, - ): - input_hidden_size = vision_config.hidden_size - if config.mm_projector_type == "linear": - mm_projector = nn.Linear(input_hidden_size, text_config.hidden_size) - mm_projector.dtype = next(mm_projector.parameters()).dtype - elif config.mm_projector_type == "cabstractor": - mm_projector = HCXVisionCAbstractor( - num_queries=config.num_queries_vis_abstractor_image, - num_input_tokens=(vision_config.image_size // vision_config.patch_size) - ** 2, - encoder_hidden_size=input_hidden_size, - hidden_size=input_hidden_size, - output_hidden_size=text_config.hidden_size, - pos_emb=config.proj_pos_emb, - prenorm=config.proj_prenorm, - ) - else: - mm_projector = HCXVisionMlp( - config.mm_projector_type, - input_hidden_size, - hidden_features=input_hidden_size, - out_features=self.text_config.hidden_size, - ) - return mm_projector - - -def unpad_image(tensor: torch.Tensor, original_size: tuple[int, int]) -> torch.Tensor: - original_width, original_height = original_size - current_height, current_width = tensor.shape[1:] - - original_aspect_ratio = original_width / original_height - current_aspect_ratio = current_width / current_height - - if original_aspect_ratio > current_aspect_ratio: - scale_factor = current_width / original_width - new_height = int(original_height * scale_factor) - padding = (current_height - new_height) // 2 - unpadded_tensor = tensor[:, padding : current_height - padding, :] - else: - scale_factor = current_height / original_height - new_width = int(original_width * scale_factor) - padding = (current_width - new_width) // 2 - unpadded_tensor = tensor[:, :, padding : current_width - padding] - - return unpadded_tensor - - -def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple: - original_height, original_width = original_size - best_fit = None - max_effective_resolution = 0 - min_wasted_resolution = float("inf") - - for height, width in possible_resolutions: - scale = min(width / original_width, height / original_height) - downscaled_width, downscaled_height = ( - int(original_width * scale), - int(original_height * scale), - ) - effective_resolution = min( - downscaled_width * downscaled_height, original_width * original_height - ) - wasted_resolution = (width * height) - effective_resolution - - if effective_resolution > max_effective_resolution or ( - effective_resolution == max_effective_resolution - and wasted_resolution < min_wasted_resolution - ): - max_effective_resolution = effective_resolution - min_wasted_resolution = wasted_resolution - best_fit = (height, width) - - return best_fit - - -def get_anyres_image_grid_shape( - image_size: tuple[int, int], - grid_pinpoints: str | list[tuple[int, int]], - patch_size: int, -) -> tuple[int, int]: - possible_resolutions = ( - grid_pinpoints - if isinstance(grid_pinpoints, list) - else ast.literal_eval(grid_pinpoints) - ) - - original_width, original_height = image_size - height, width = select_best_resolution( - (original_height, original_width), possible_resolutions - ) - return width // patch_size, height // patch_size - - -def reshape_and_unpad_image_features( - image_feature: torch.Tensor, - height: int, - width: int, - image_size: tuple[int, int], - possible_resolutions: list[tuple[int, int]], - grid_size: int, - unpad: bool, - image_newline: torch.Tensor, -) -> torch.Tensor: - base_image_feature = image_feature[0] - image_feature = image_feature[1:] - - assert height * width == base_image_feature.shape[0], ( - f"{height=} * {width=} != {base_image_feature.shape[0]=}" - ) - - num_patch_width, num_patch_height = get_anyres_image_grid_shape( - image_size, possible_resolutions, grid_size - ) - image_feature = image_feature.view( - num_patch_height, num_patch_width, height, width, -1 - ) - - if unpad: - image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() - image_feature = image_feature.flatten(1, 2).flatten(2, 3) - image_feature = unpad_image(image_feature, image_size) - image_feature = torch.cat( - ( - image_feature, - image_newline[:, None, None] - .expand(*image_feature.shape[:-1], 1) - .to(image_feature.device), - ), - dim=-1, - ) - image_feature = image_feature.flatten(1, 2).transpose(0, 1) - else: - image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous() - image_feature = image_feature.flatten(0, 3) - image_feature = torch.cat((base_image_feature, image_feature), dim=0) - - return image_feature - - -def anyres_postprocessing( - image_forward_outs: list[torch.Tensor], - image_sizes: list[list[int]], - possible_resolutions: list[tuple[int, int]], - patch_size: int, - grid_size: int, - image_newline: torch.Tensor, - num_queries_vis_abstractor: int = -1, - unpad: bool = False, -) -> list[torch.Tensor]: - height = width = grid_size // patch_size - - if num_queries_vis_abstractor > 0: - assert (num_queries_vis_abstractor**0.5).is_integer(), ( - "n_queries must be square number" - ) - height = width = int(num_queries_vis_abstractor**0.5) - - # post-processing (unpad, add newline) - new_image_features = [] - for image_idx, image_feature in enumerate(image_forward_outs): - if image_feature.shape[0] > 1: - image_feature = reshape_and_unpad_image_features( - image_feature=image_feature, - height=height, - width=width, - image_size=image_sizes[image_idx], - possible_resolutions=possible_resolutions, - grid_size=grid_size, # Pass grid info if needed by helper - unpad=unpad, - image_newline=image_newline, - ) - else: - image_feature = image_feature[0] - image_feature = torch.cat( - (image_feature, image_newline[None].to(image_feature.device)), dim=0 - ) - new_image_features.append(image_feature) - - return new_image_features + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 9124f79badf1..6d99c7f0bac7 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -124,6 +124,7 @@ "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"), "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"), "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"), + "HCXVisionV2ForCausalLM": ("hyperclovax_vision", "HCXVisionV2ForCausalLM"), "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"), From c603df81479689860b702702c32fd4e335572b58 Mon Sep 17 00:00:00 2001 From: effortprogrammer Date: Sat, 10 Jan 2026 13:41:45 +0000 Subject: [PATCH 02/18] fix(hyperclovax): register HyperCLOVAXForCausalLM Signed-off-by: effortprogrammer --- vllm/model_executor/models/registry.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 6d99c7f0bac7..ff7b3ca2ac32 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -125,6 +125,7 @@ "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"), "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"), "HCXVisionV2ForCausalLM": ("hyperclovax_vision", "HCXVisionV2ForCausalLM"), + "HyperCLOVAXForCausalLM": ("llama", "LlamaForCausalLM"), "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"), From ac10e8aa677aa77538b321822c7bb4d9588786b7 Mon Sep 17 00:00:00 2001 From: effortprogrammer Date: Sun, 18 Jan 2026 01:27:11 +0900 Subject: [PATCH 03/18] fix(hyperclovax): restore V1 HCXVisionForCausalLM class for backward compatibility Signed-off-by: effortprogrammer --- .../models/hyperclovax_vision.py | 850 +++++++++++++++++- 1 file changed, 849 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py index b1d99ba88548..f61d187259fd 100644 --- a/vllm/model_executor/models/hyperclovax_vision.py +++ b/vllm/model_executor/models/hyperclovax_vision.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # copied from : https://github.com/huggingface/transformers +import ast +from collections import defaultdict from collections.abc import Iterable, Mapping, Sequence from functools import partial from itertools import accumulate @@ -9,11 +11,16 @@ import numpy as np import torch import torch.nn as nn -from transformers import BatchFeature +from einops import rearrange +from timm.layers import LayerNorm, LayerNorm2d +from timm.models.regnet import RegStage +from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig +from transformers.modeling_utils import no_init_weights from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.forward_context import set_forward_context +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import ( @@ -33,11 +40,14 @@ from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape +from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .qwen2_5_vl import Qwen2_5_VisionTransformer +from .siglip import SiglipVisionModel from .utils import ( AutoWeightsLoader, WeightsMapper, + flatten_bn, init_vllm_registered_model, maybe_prefix, ) @@ -338,6 +348,844 @@ def _get_mm_fields_config( return fields +def _build_hcxvision_hf_info( + ctx: InputProcessingContext, +) -> HCXVisionProcessingInfo: + return HCXVisionProcessingInfo(ctx) + + +def _build_hcxvision_hf_processor( + info: HCXVisionProcessingInfo, + dummy_inputs: BaseDummyInputsBuilder[HCXVisionProcessingInfo], + *, + cache: BaseMultiModalProcessorCache | None = None, +) -> BaseMultiModalProcessor: + if isinstance(info, HCXVisionProcessingInfo): + return HCXVisionMultiModalProcessor( + info, + dummy_inputs, # type: ignore + cache=cache, + ) + + raise NotImplementedError(type(info)) + + +def init_vision_tower_for_hcxvision( + vision_config, + quant_config: QuantizationConfig | None, + *, + use_nth_layer: int | None = None, + require_post_norm: bool | None = None, + prefix: str = "", +) -> CLIPVisionModel | SiglipVisionModel: + num_hidden_layers = vision_config.num_hidden_layers + if not isinstance(use_nth_layer, int): + pass + elif use_nth_layer >= 0: + num_hidden_layers = use_nth_layer + 1 + else: + num_hidden_layers = num_hidden_layers + use_nth_layer + 1 + + if isinstance(vision_config, CLIPVisionConfig): + return CLIPVisionModel( + vision_config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers, + require_post_norm=require_post_norm, + prefix=prefix, + ) + elif isinstance(vision_config, SiglipVisionConfig): + return SiglipVisionModel( + vision_config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers, + require_post_norm=require_post_norm, + prefix=prefix, + ) + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +class HCXVisionMlp(nn.Module): + def __init__( + self, + mm_projector_type, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.mm_projector_type = mm_projector_type + if self.mm_projector_type == "mlp": + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + elif self.mm_projector_type == "inverted_mlp": + self.fc1 = nn.Linear(in_features, 2 * hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(2 * hidden_features, out_features) + else: + raise NotImplementedError( + "{} is not implemented".format(self.mm_projector_type) + ) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return x + + +class HCXVisionCAbstractor(nn.Module): + """ + This module is based on C-Abstractor, whose license is under apache-2.0. + You can check the original code at + https://github.com/khanrc/honeybee/blob/main/honeybee/projectors/projectors.py + and we made necessary modifications. + """ + + def __init__( + self, + num_queries: int, + num_input_tokens: int, + encoder_hidden_size: int, + hidden_size: int, + output_hidden_size: int, + pos_emb: bool = True, + prenorm: bool = False, + ): + super().__init__() + self.num_input_tokens = num_input_tokens + self.output_hidden_size = output_hidden_size + + # Positional embedding + if pos_emb: + self.pos_emb = torch.nn.Parameter( + torch.zeros(1, num_input_tokens, encoder_hidden_size) + ) + self.pos_emb.data.normal_(mean=0.0, std=0.02) + else: + self.pos_emb = None + + # (Optional) Pre-normalization layer + if prenorm: + self.prenorm = LayerNorm(encoder_hidden_size) + else: + self.prenorm = None + + self.build_net( + num_queries, encoder_hidden_size, hidden_size, output_hidden_size + ) + self.dtype = next(self.parameters()).dtype + + def forward( + self, + x: torch.Tensor, + num_queries_vis_abstractors: list[list[int]] | None = None, + num_grids: list[int] | None = None, + ) -> torch.Tensor: + if self.prenorm is not None: + x = self.prenorm(x) + + if self.pos_emb is not None: + x = x + self.pos_emb + + x = self._forward( + x, + num_queries_vis_abstractors=num_queries_vis_abstractors, + num_grids=num_grids, + ) # (B, L, output_hidden_size) + + return x + + def _forward( + self, + x: torch.Tensor, + num_queries_vis_abstractors: list[list[int]] | None = None, + num_grids: list[int] | None = None, + ) -> torch.Tensor: + # x: [B, L, dim] + B, L, dim = x.shape + hw = int(L**0.5) + x = rearrange(x, "b (h w) d -> b d h w", h=hw, w=hw) + + if num_queries_vis_abstractors is not None: + assert num_grids is not None + return self._forward_adaptive_num_query( + x, num_queries_vis_abstractors, num_grids + ) + + x = self.net(x) + x = rearrange(x, "b d h w -> b (h w) d") + x = self.readout(x) + return x + + def _forward_adaptive_num_query( + self, + x: torch.Tensor, + num_queries_vis_abstractors: list[list[int]] | None = None, + num_grids: list[int] | None = None, + ) -> list[torch.Tensor]: + # self.net is consisted by 3 layers (s1, sampler, s2) + assert len(self.net) == 3 + + x = self.net[0](x) # s1 + new_x = [] + for i, num_queries in enumerate(num_queries_vis_abstractors): + hw = int(num_queries**0.5) + sampler = nn.AdaptiveAvgPool2d((hw, hw)) + out = sampler(x[num_grids[i] : num_grids[i + 1], :]) + out = self.net[2](out) # s2 + + out = rearrange(out, "b d h w -> b (h w) d") + out = self.readout(out) + + new_x.append(out) + return new_x + + def build_net( + self, + n_queries: int, + encoder_hidden_size: int, + hidden_size: int, + output_hidden_size: int, + depth: int = 3, + mlp_depth: int = 2, + ): + assert (n_queries**0.5).is_integer(), ( + f"n_queries must be square number. n_queries: {n_queries}" + ) + hw = int(n_queries**0.5) + + # RegBlock = ResBlock + SE + RegBlock = partial( + RegStage, + stride=1, + dilation=1, + act_layer=nn.SiLU, + norm_layer=LayerNorm2d, + ) + + s1 = RegBlock( + depth, + encoder_hidden_size, + hidden_size, + ) + sampler = nn.AdaptiveAvgPool2d((hw, hw)) + s2 = RegBlock( + depth, + hidden_size, + hidden_size, + ) + + self.net = nn.Sequential(s1, sampler, s2) + self.readout = self.build_mlp(mlp_depth, hidden_size, output_hidden_size) + + def build_mlp( + self, + depth: int, + hidden_size: int, + output_hidden_size: int, + ): + layers = [nn.Linear(hidden_size, output_hidden_size)] + for _ in range(1, depth): + layers.append(nn.SiLU()) + layers.append(nn.Linear(output_hidden_size, output_hidden_size)) + return nn.Sequential(*layers) + + +@MULTIMODAL_REGISTRY.register_processor( + _build_hcxvision_hf_processor, + info=_build_hcxvision_hf_info, + dummy_inputs=HCXVisionDummyInputsBuilder, +) +class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): + """ + HyperCLOVAX-SEED Vision-Language Model (V1 architecture). + + Supports: + - HyperCLOVAX-SEED-Vision-Instruct-3B + + Uses CLIP/SigLIP as the vision encoder with C-Abstractor projector. + """ + + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + **kwargs: Any | None, + ) -> None: + super().__init__() + + # init configs + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + # text_config + text_config = config.text_config + if text_config.model_type in ["gpt2", "hyperclovax", "llama"]: + text_config._attn_implementation = "sdpa" + if text_config.model_type != "hyperclovax": + text_config.logits_scaling = 1.0 + # vision_config + vision_config = config.vision_config + vision_config.auto_map = {} + vision_config.anyres = config.anyres + vision_config.max_num_grids = config.max_num_grids + self.dtype = vllm_config.model_config.dtype + + ## possible_resolution should be matched with preprocessor_config.json + config.possible_resolutions = self._init_possible_resolutions( + config, vision_config + ) + + # init models & parameters + with no_init_weights(): # weight will be loaded in from_pretrained + self.vision_model = init_vision_tower_for_hcxvision( + vision_config, + quant_config, + use_nth_layer=getattr(config, "use_nth_layer", -1), + require_post_norm=False, + prefix=maybe_prefix(prefix, "vision_model"), + ) + self.mm_projector = self._init_mm_projector(config, text_config, vision_config) + + self.lm_head_vocab_size = getattr( + text_config, "padded_vocab_size", text_config.vocab_size + ) + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + + if config.anyres: + self.image_newline = nn.Parameter( + torch.empty(text_config.hidden_size, dtype=self.dtype) + ) + + self.config = config + self.vision_config = vision_config + self.text_config = text_config + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + if modality.startswith("image"): + return IMAGE_TOKEN + if modality.startswith("video"): + return VIDEO_TOKEN + + raise ValueError("Only image or video modality is supported") + + def _parse_and_validate_image_input( + self, + **kwargs: object, + ) -> HCXVisionImageInputs | None: + pixel_values_images = kwargs.pop("pixel_values_images", None) + + if pixel_values_images is None: + return None + + image_sizes_images = kwargs.pop("image_sizes_images") + + return HCXVisionImagePixelInputs( + pixel_values_images=pixel_values_images, + image_sizes_images=image_sizes_images, + ) + + def _parse_and_validate_video_input( + self, + **kwargs: object, + ) -> HCXVisionVideoInputs | None: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + + if pixel_values_videos is None: + return None + + return HCXVisionVideoPixelInputs( + pixel_values_videos=pixel_values_videos, + ) + + def _process_image_input( + self, + image_input: HCXVisionImageInputs, + ) -> tuple[torch.Tensor, ...]: + return self.forward_images( + pixel_values_images=image_input["pixel_values_images"], + image_sizes_images=image_input["image_sizes_images"], + ) + + def _process_video_input( + self, + video_input: HCXVisionVideoInputs, + ) -> tuple[torch.Tensor, ...]: + return self.forward_videos( + pixel_values_videos=video_input["pixel_values_videos"], + ) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key == "pixel_values_images" and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input(**kwargs) + if input_key == "pixel_values_videos" and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input(**kwargs) + + return modalities + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def embed_multimodal( + self, + **kwargs: object, + ) -> MultiModalEmbeddings: + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return [] + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in modalities: + if modality == "images": + image_input = modalities["images"] + image_embeddings = self._process_image_input(image_input) + multimodal_embeddings += tuple(image_embeddings) + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_video_input(video_input) + multimodal_embeddings += tuple(video_embeddings) + + return multimodal_embeddings + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor | IntermediateTensors: + if intermediate_tensors is not None: + inputs_embeds = None + + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) + return hidden_states + + def forward_images( + self, + pixel_values_images: list[torch.Tensor], + image_sizes_images: torch.Tensor, + ) -> tuple[torch.Tensor, ...]: + pixel_values_image_flat = flatten_bn(pixel_values_images, concat=True) + + visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1 + image_forward_outs = self.vision_model(pixel_values_image_flat)[ + :, visual_token_idx: + ] + + image_forward_outs = image_forward_outs.to(dtype=self.mm_projector.dtype) + image_forward_outs = self.mm_projector(image_forward_outs) # b (h w) d + + split_sizes = [len(item) for item in pixel_values_images] + image_forward_outs = torch.split(image_forward_outs, split_sizes, dim=0) + + # newline for anyres postprocessing + image_features = anyres_postprocessing( + image_forward_outs=image_forward_outs, + image_sizes=image_sizes_images.tolist(), + num_queries_vis_abstractor=self.config.num_queries_vis_abstractor_image, + unpad=self.config.unpad, + patch_size=self.vision_config.patch_size, + grid_size=self.vision_config.image_size, + image_newline=self.image_newline, + possible_resolutions=self.config.possible_resolutions, + ) + + return tuple(image_features) + + def forward_videos( + self, + pixel_values_videos: list[list[torch.Tensor]], + ) -> tuple[torch.Tensor, ...]: + pixel_values_videos_flat = flatten_bn( + [frame for frames in pixel_values_videos for frame in frames], + concat=True, + ) + + visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1 + video_forward_outs = self.vision_model(pixel_values_videos_flat)[ + :, visual_token_idx: + ] + + video_forward_outs = video_forward_outs.to(dtype=self.mm_projector.dtype) + + # Run MM-Projector + # len(num_grids) == len(num_queries_vis_abstractors) + 1 + grid_idx = 0 + # e.g. [0, 9, 18, 19, 27, 28, 36, 37, 45, 46, 54, 55, 56] + num_grids = [grid_idx] + # e.g. [81, 81, 81, 9, 81, 9, 81, 9, 81, 9, 81, 9] + num_queries_vis_abstractors = [] + len_total_frames = video_forward_outs.shape[0] + + if self.config.first_last_frames_slow: + # slowfast (first_last_frames_slow) + assert len_total_frames != 0 + if len_total_frames <= 2: + num_queries_vis_abstractors.append( + self.config.num_queries_vis_abstractor_video_slow + ) + grid_idx += len_total_frames + num_grids.append(grid_idx) + else: + num_queries_vis_abstractors.append( + self.config.num_queries_vis_abstractor_video_slow + ) + grid_idx += 1 + num_grids.append(grid_idx) + + num_queries_vis_abstractors.append( + self.config.num_queries_vis_abstractor_video_fast + ) + grid_idx += len_total_frames - 2 + num_grids.append(grid_idx) + + num_queries_vis_abstractors.append( + self.config.num_queries_vis_abstractor_video_slow + ) + grid_idx += 1 + num_grids.append(grid_idx) + else: + # slowfast + for pixel_values_frames in pixel_values_videos: + for pixel_values_frame in pixel_values_frames: + if len(pixel_values_frame) > 0: + num_queries_vis_abstractors.append( + self.config.num_queries_vis_abstractor_video_slow + ) + grid_idx += 1 + num_grids.append(grid_idx) + num_queries_vis_abstractors.append( + self.config.num_queries_vis_abstractor_video_fast + ) + grid_idx = grid_idx + len(pixel_values_frame) - 1 + num_grids.append(grid_idx) + + video_forward_outs = self.mm_projector( + video_forward_outs, num_queries_vis_abstractors, num_grids + ) + + video_features = [] # what we want to return + target_features = [] + target_group_size = 0 + group_counter = 0 + video_groups = [ + len(frame) for frames in pixel_values_videos for frame in frames + ] # for concat video features after projector + + for forward_out in video_forward_outs: + target_group_size += len(forward_out) + target_features.append(forward_out.flatten(0, 1)) + + video_group_size = video_groups[group_counter] + if video_group_size == target_group_size: + video_features.append(torch.cat(target_features, dim=0)) + target_features = [] + group_counter += 1 + target_group_size = 0 + + elif video_group_size < target_group_size: + raise RuntimeError(f"{video_group_size=} < {target_group_size=}") + + assert len(target_features) == 0, ( + f"target_features is not empty!! {target_features}" + ) + assert len(video_groups) == len(video_features) + + feats_per_video = [len(video) for video in pixel_values_videos] + idxs_per_video = [0, *accumulate(feats_per_video)] + return tuple( + torch.cat(video_features[idxs_per_video[i] : idxs_per_video[i + 1]]) + for i in range(len(feats_per_video)) + ) + + def _prepare_multimodal_kwargs(self, **kwargs: object): + output = defaultdict(list) + for k, v in kwargs.items(): + if len(v) < 1 or len(v[0]) < 1: + continue # if empty batch of empty sample + + new_k, is_video = k, False + if not k.endswith("_images") and not k.endswith("_videos"): + pass + else: + new_k, is_video = k.split("_")[:-1], k.split("_")[-1] + new_k = "_".join(new_k) + is_video = is_video == "videos" + + for _sample_idx, _v in enumerate(v): # batch -> sample + if new_k not in ["pixel_values"]: + if len(output[new_k]) < _sample_idx + 1: + output[new_k].append(list()) + _v = _v.detach().cpu().numpy().tolist() + output[new_k][_sample_idx] += _v + elif isinstance(_v, torch.Tensor): + if len(output[new_k]) < _sample_idx + 1: + output[new_k].append(list()) + output["is_videos"].append(list()) + _v = list(torch.unbind(_v, dim=0)) + output[new_k][_sample_idx] += _v + output["is_videos"][_sample_idx] += [ + is_video, + ] * len(_v) + return dict(output) + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + return self.language_model.compute_logits(hidden_states) + + def load_weights( + self, + weights: Iterable[tuple[str, torch.Tensor]], + ) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + def _init_possible_resolutions( + self, + config, + vision_config, + ): + if not getattr(config, "possible_resolutions", []): + possible_resolutions = [] + if config.anyres: + assert config.max_num_grids > 0 + for i in range(1, config.max_num_grids + 1): + for j in range(1, config.max_num_grids + 1): + if i == 1 and j == 1 and not config.use_1x1_grid: + continue + if i * j <= config.max_num_grids: + possible_resolutions.append([i, j]) + + possible_resolutions = [ + [ys * vision_config.image_size, xs * vision_config.image_size] + for ys, xs in possible_resolutions + ] + return possible_resolutions + else: + return config.possible_resolutions + + def _init_mm_projector( + self, + config, + text_config, + vision_config, + ): + input_hidden_size = vision_config.hidden_size + if config.mm_projector_type == "linear": + mm_projector = nn.Linear(input_hidden_size, text_config.hidden_size) + mm_projector.dtype = next(mm_projector.parameters()).dtype + elif config.mm_projector_type == "cabstractor": + mm_projector = HCXVisionCAbstractor( + num_queries=config.num_queries_vis_abstractor_image, + num_input_tokens=(vision_config.image_size // vision_config.patch_size) + ** 2, + encoder_hidden_size=input_hidden_size, + hidden_size=input_hidden_size, + output_hidden_size=text_config.hidden_size, + pos_emb=config.proj_pos_emb, + prenorm=config.proj_prenorm, + ) + else: + mm_projector = HCXVisionMlp( + config.mm_projector_type, + input_hidden_size, + hidden_features=input_hidden_size, + out_features=self.text_config.hidden_size, + ) + return mm_projector + + +def unpad_image(tensor: torch.Tensor, original_size: tuple[int, int]) -> torch.Tensor: + original_width, original_height = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding : current_height - padding, :] + else: + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding : current_width - padding] + + return unpadded_tensor + + +def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple: + original_height, original_width = original_size + best_fit = None + max_effective_resolution = 0 + min_wasted_resolution = float("inf") + + for height, width in possible_resolutions: + scale = min(width / original_width, height / original_height) + downscaled_width, downscaled_height = ( + int(original_width * scale), + int(original_height * scale), + ) + effective_resolution = min( + downscaled_width * downscaled_height, original_width * original_height + ) + wasted_resolution = (width * height) - effective_resolution + + if effective_resolution > max_effective_resolution or ( + effective_resolution == max_effective_resolution + and wasted_resolution < min_wasted_resolution + ): + max_effective_resolution = effective_resolution + min_wasted_resolution = wasted_resolution + best_fit = (height, width) + + return best_fit + + +def get_anyres_image_grid_shape( + image_size: tuple[int, int], + grid_pinpoints: str | list[tuple[int, int]], + patch_size: int, +) -> tuple[int, int]: + possible_resolutions = ( + grid_pinpoints + if isinstance(grid_pinpoints, list) + else ast.literal_eval(grid_pinpoints) + ) + + original_width, original_height = image_size + height, width = select_best_resolution( + (original_height, original_width), possible_resolutions + ) + return width // patch_size, height // patch_size + + +def reshape_and_unpad_image_features( + image_feature: torch.Tensor, + height: int, + width: int, + image_size: tuple[int, int], + possible_resolutions: list[tuple[int, int]], + grid_size: int, + unpad: bool, + image_newline: torch.Tensor, +) -> torch.Tensor: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + + assert height * width == base_image_feature.shape[0], ( + f"{height=} * {width=} != {base_image_feature.shape[0]=}" + ) + + num_patch_width, num_patch_height = get_anyres_image_grid_shape( + image_size, possible_resolutions, grid_size + ) + image_feature = image_feature.view( + num_patch_height, num_patch_width, height, width, -1 + ) + + if unpad: + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_size) + image_feature = torch.cat( + ( + image_feature, + image_newline[:, None, None] + .expand(*image_feature.shape[:-1], 1) + .to(image_feature.device), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + else: + image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous() + image_feature = image_feature.flatten(0, 3) + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + + return image_feature + + +def anyres_postprocessing( + image_forward_outs: list[torch.Tensor], + image_sizes: list[list[int]], + possible_resolutions: list[tuple[int, int]], + patch_size: int, + grid_size: int, + image_newline: torch.Tensor, + num_queries_vis_abstractor: int = -1, + unpad: bool = False, +) -> list[torch.Tensor]: + height = width = grid_size // patch_size + + if num_queries_vis_abstractor > 0: + assert (num_queries_vis_abstractor**0.5).is_integer(), ( + "n_queries must be square number" + ) + height = width = int(num_queries_vis_abstractor**0.5) + + # post-processing (unpad, add newline) + new_image_features = [] + for image_idx, image_feature in enumerate(image_forward_outs): + if image_feature.shape[0] > 1: + image_feature = reshape_and_unpad_image_features( + image_feature=image_feature, + height=height, + width=width, + image_size=image_sizes[image_idx], + possible_resolutions=possible_resolutions, + grid_size=grid_size, # Pass grid info if needed by helper + unpad=unpad, + image_newline=image_newline, + ) + else: + image_feature = image_feature[0] + image_feature = torch.cat( + (image_feature, image_newline[None].to(image_feature.device)), dim=0 + ) + new_image_features.append(image_feature) + + return new_image_features + + # ============================================================================= # HyperCLOVAX V2 (32B Think Model) Support # Uses Qwen2.5 Vision Transformer instead of CLIP/SigLIP From 6fedc9e2c9ede30fb754e452064d77810ef09d6d Mon Sep 17 00:00:00 2001 From: effortprogrammer Date: Thu, 19 Feb 2026 11:36:03 +0900 Subject: [PATCH 04/18] feat(hyperclovax): add dedicated V2 vision module Signed-off-by: effortprogrammer --- .../models/hyperclovax_vision_v2.py | 727 ++++++++++++++++++ 1 file changed, 727 insertions(+) create mode 100644 vllm/model_executor/models/hyperclovax_vision_v2.py diff --git a/vllm/model_executor/models/hyperclovax_vision_v2.py b/vllm/model_executor/models/hyperclovax_vision_v2.py new file mode 100644 index 000000000000..ab546cd426ad --- /dev/null +++ b/vllm/model_executor/models/hyperclovax_vision_v2.py @@ -0,0 +1,727 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +HyperCLOVAX V2 (32B Think Model) Implementation. + +This module contains the V2 architecture that uses Qwen2.5 Vision Transformer +instead of CLIP/SigLIP used in V1. + +Supports: +- HyperCLOVAX-SEED-Think-32B: Vision + Text +- HyperCLOVAX-SEED-Omni-8B: Vision + Audio + Text +""" + +from collections.abc import Iterable, Mapping, Sequence +from functools import partial +from typing import Annotated, Any, Literal + +import torch +import torch.nn as nn +from transformers import BatchFeature + +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.forward_context import set_forward_context +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import BaseMultiModalProcessorCache +from vllm.multimodal.inputs import ( + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ImageSize, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + InputProcessingContext, + PromptReplacement, + PromptUpdate, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .qwen2_5_vl import Qwen2_5_VisionTransformer +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) + +# V2 (32B Think model) uses different tokens - retrieved from config at runtime +# These placeholder strings must match the chat template format exactly. +# The chat template produces: <|image_start|><|IMAGE_PAD|><|image_end|> +# Similar to Qwen2-VL's <|vision_start|><|image_pad|><|vision_end|> format. +V2_IMAGE_TOKEN: str = "<|image_start|><|IMAGE_PAD|><|image_end|>" +V2_VIDEO_TOKEN: str = "<|video_start|><|VIDEO_PAD|><|video_end|>" + + +class HCXVisionV2ImagePixelInputs(TensorSchema): + """ + V2 Image inputs using Qwen2.5-VL style grid_thw format. + + Dimensions: + - np: Number of patches + - ni: Number of images + - cps: Number of channels * patch_size * patch_size + """ + + type: Literal["pixel_values"] = "pixel_values" + pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")] + image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] + + +class HCXVisionV2ImageEmbeddingInputs(TensorSchema): + """ + V2 Image embedding inputs. + + Dimensions: + - nf: Number of image features + - hs: Hidden size + - ni: Number of images + """ + + type: Literal["image_embeds"] = "image_embeds" + image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")] + image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] + + +HCXVisionV2ImageInputs = HCXVisionV2ImagePixelInputs | HCXVisionV2ImageEmbeddingInputs + + +class HCXVisionV2VideoPixelInputs(TensorSchema): + """ + V2 Video inputs using Qwen2.5-VL style grid_thw format. + + Dimensions: + - np: Number of patches + - nv: Number of videos + - ctps: Number of channels * temporal_patch_size * patch_size * patch_size + """ + + type: Literal["pixel_values_videos"] = "pixel_values_videos" + pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctps")] + video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)] + + +class HCXVisionV2VideoEmbeddingInputs(TensorSchema): + """ + V2 Video embedding inputs. + + Dimensions: + - nf: Number of video features + - hs: Hidden size + - nv: Number of videos + """ + + type: Literal["video_embeds"] = "video_embeds" + video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")] + video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)] + + +HCXVisionV2VideoInputs = HCXVisionV2VideoPixelInputs | HCXVisionV2VideoEmbeddingInputs + + +class HCXVisionV2ProcessingInfo(BaseProcessingInfo): + """Processing info for HyperCLOVAX V2 (32B Think model).""" + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"image": None, "video": None} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + patch_size = vision_config.patch_size + spatial_merge_size = vision_config.spatial_merge_size + + grid_h = image_height // patch_size + grid_w = image_width // patch_size + + return (grid_h * grid_w) // (spatial_merge_size**2) + + def get_num_video_tokens( + self, + *, + video_width: int, + video_height: int, + num_frames: int, + ) -> int: + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + patch_size = vision_config.patch_size + temporal_patch_size = vision_config.temporal_patch_size + spatial_merge_size = vision_config.spatial_merge_size + + grid_t = num_frames // temporal_patch_size + grid_h = video_height // patch_size + grid_w = video_width // patch_size + + return (grid_t * grid_h * grid_w) // (spatial_merge_size**2) + + def get_image_size_with_most_features(self) -> ImageSize: + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + # Use a reasonable default size + size = getattr(vision_config, "image_size", 448) + return ImageSize(width=size, height=size) + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + ) + + +class HCXVisionV2DummyInputsBuilder(BaseDummyInputsBuilder[HCXVisionV2ProcessingInfo]): + """Dummy inputs builder for HyperCLOVAX V2 memory profiling.""" + + def get_dummy_text( + self, + mm_counts: Mapping[str, int], + ) -> str: + # This method is not used when get_dummy_processor_inputs is overridden, + # but we keep it for compatibility. + return "" + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, object] | None = None, + ): + """ + Override to use token IDs directly instead of text strings. + + This avoids the tokenizer issue where <|IMAGE_PAD|> might not be + recognized as a special token and gets split into multiple tokens. + By passing token IDs directly, we ensure the correct token (128060) + is used for prompt replacement matching. + """ + from vllm.multimodal.profiling import ProcessorInputs + + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + hf_config = self.info.get_hf_config() + + # Use token IDs directly to avoid tokenizer issues with special tokens + image_token_id = hf_config.image_token_id # 128060 + video_token_id = hf_config.video_token_id # 128061 + + # Create prompt as token ID list instead of text string + prompt_ids: list[int] = [image_token_id] * num_images + [ + video_token_id + ] * num_videos + + dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options) + + return ProcessorInputs( + prompt=prompt_ids, + mm_data=dummy_mm_data, + tokenization_kwargs={"truncation": False}, + ) + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> dict: + from vllm.multimodal.inputs import MultiModalDataDict + + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + target_width, target_height = self.info.get_image_size_with_most_features() + target_num_frames = 16 # Default for video + + image_overrides = mm_options.get("image") if mm_options else None + video_overrides = mm_options.get("video") if mm_options else None + + result: MultiModalDataDict = { + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, # type: ignore + ), + "video": self._get_dummy_videos( + width=target_width, + height=target_height, + num_frames=target_num_frames, + num_videos=num_videos, + overrides=video_overrides, # type: ignore + ), + } + + return result + + +class HCXVisionV2MultiModalProcessor( + BaseMultiModalProcessor[HCXVisionV2ProcessingInfo] +): + """Multimodal processor for HyperCLOVAX V2 (32B Think model).""" + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + images = mm_data.get("images") + videos = mm_data.get("videos") + + # Get the HF processor + hf_processor = self.info.get_hf_processor(**mm_kwargs) + + # Build data dict for HF processor (images/videos only) + # The HF processor (HCXVisionV2Processor) doesn't support audio + # NOTE: We pass the prompt as-is without token normalization. + # Token expansion is handled by vLLM via _get_prompt_updates since + # _hf_processor_applies_updates returns False. + data: dict[str, object] = dict( + text=prompt, + images=images, + videos=videos, + ) + + processed_outputs = self.info.ctx.call_hf_processor( + hf_processor=hf_processor, + data=data, + ) + + return processed_outputs + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + # HyperCLOVAX V2 has a token case mismatch bug: + # - Chat template uses <|IMAGE_PAD|> (uppercase) + # - HF processor (Qwen2_5_VLProcessor) expects <|image_pad|> (lowercase) + # - Tokenizer vocab has <|IMAGE_PAD|> (uppercase) = token ID 128060 + # + # The HF processor's token expansion fails because it looks for lowercase + # but the tokenized prompt has uppercase tokens. We bypass HF processor's + # expansion and let vLLM handle it via _get_prompt_updates using the + # correct token IDs from hf_config. + return False + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_config = self.info.get_hf_config() + + # Use token IDs directly from config. + # This matches what get_dummy_processor_inputs uses, ensuring consistency. + placeholder: dict[str, int] = { + "image": hf_config.image_token_id, # 128060 for <|IMAGE_PAD|> + "video": hf_config.video_token_id, # 128061 for <|VIDEO_PAD|> + } + + merge_size = hf_config.vision_config.spatial_merge_size + + def get_replacement_v2( + item_idx: int, + modality: str, + out_mm_kwargs: MultiModalKwargsItems, + ): + out_item = out_mm_kwargs[modality][item_idx] + + if modality == "image": + grid_thw_elem = out_item.get("image_grid_thw") + if grid_thw_elem is not None: + # Access .data to get the actual tensor from MultiModalFieldElem + grid_thw = grid_thw_elem.data + # Qwen2.5-VL style calculation + h, w = grid_thw[1].item(), grid_thw[2].item() + num_tokens = (h * w) // (merge_size**2) + else: + # Fallback or error + raise ValueError("Missing image_grid_thw for V2 model") + elif modality == "video": + grid_thw_elem = out_item.get("video_grid_thw") + if grid_thw_elem is not None: + # Access .data to get the actual tensor from MultiModalFieldElem + grid_thw = grid_thw_elem.data + t, h, w = grid_thw[0].item(), grid_thw[1].item(), grid_thw[2].item() + num_tokens = (t * h * w) // (merge_size**2) + else: + raise ValueError("Missing video_grid_thw for V2 model") + else: + raise NotImplementedError(modality) + + return [placeholder[modality]] * num_tokens + + return [ + PromptReplacement( + modality=modality, + target=[ + placeholder[modality], + ], + replacement=partial( + get_replacement_v2, + modality=modality, + out_mm_kwargs=out_mm_kwargs, + ), + ) + for modality in ("image", "video") + ] + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + # HyperCLOVAX V2 uses Qwen2.5-VL style flattened pixel values where + # pixel_values has shape (num_patches, channels*patch_size*patch_size) + # while image_grid_thw has shape (num_images, 3). + # We need to use flat_from_sizes to correctly handle this mismatch. + hf_config = self.info.get_hf_config() + spatial_merge_size = hf_config.vision_config.spatial_merge_size + + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_pixel_grid_sizes = image_grid_thw.prod(-1) + image_embed_grid_sizes = ( + image_pixel_grid_sizes // spatial_merge_size // spatial_merge_size + ) + + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_pixel_grid_sizes = video_grid_thw.prod(-1) + video_embed_grid_sizes = ( + video_pixel_grid_sizes // spatial_merge_size // spatial_merge_size + ) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_pixel_grid_sizes + ), + image_embeds=MultiModalFieldConfig.flat_from_sizes( + "image", image_embed_grid_sizes + ), + image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_pixel_grid_sizes + ), + video_embeds=MultiModalFieldConfig.flat_from_sizes( + "video", video_embed_grid_sizes + ), + video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True), + ) + + +def _build_hcxvision_v2_hf_info( + ctx: InputProcessingContext, +) -> HCXVisionV2ProcessingInfo: + return HCXVisionV2ProcessingInfo(ctx) + + +def _build_hcxvision_v2_hf_processor( + info: HCXVisionV2ProcessingInfo, + dummy_inputs: BaseDummyInputsBuilder[HCXVisionV2ProcessingInfo], + *, + cache: BaseMultiModalProcessorCache | None = None, +) -> BaseMultiModalProcessor: + return HCXVisionV2MultiModalProcessor( + info, + dummy_inputs, # type: ignore + cache=cache, + ) + + +@MULTIMODAL_REGISTRY.register_processor( + _build_hcxvision_v2_hf_processor, + info=_build_hcxvision_v2_hf_info, + dummy_inputs=HCXVisionV2DummyInputsBuilder, +) +class HCXVisionV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): + """ + HyperCLOVAX-SEED Vision-Language Model (V2 architecture). + + Supports: + - HyperCLOVAX-SEED-Think-32B: Vision + Text + - HyperCLOVAX-SEED-Omni-8B: Vision + Audio + Text + + Uses Qwen2.5 Vision Transformer as the vision encoder. + """ + + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + "qkv": ["qkv"], # For vision tower + } + + # Weight mapping for loading HuggingFace checkpoints + # NOTE: Order matters! Ignores (None) should come before renames to prevent + # partial matches + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.": "", # Remove model. prefix if present + "vision_model.": "visual.", # HF uses vision_model, we use visual + }, + orig_to_new_substr={ + # Ignore modules not implemented in vLLM + "discrete_vision_model": None, # TextAlignedTokenizer + }, + ) + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + **kwargs: Any | None, + ) -> None: + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + # Text config + text_config = config.text_config + if text_config.model_type in ["gpt2", "hyperclovax", "llama"]: + text_config._attn_implementation = "sdpa" + if text_config.model_type != "hyperclovax": + text_config.logits_scaling = 1.0 + + # Vision config + vision_config = config.vision_config + + self.config = config + self.vision_config = vision_config + self.text_config = text_config + self.vllm_config = vllm_config + self.dtype = vllm_config.model_config.dtype + + # Initialize Qwen2.5 Vision Transformer + self.visual = Qwen2_5_VisionTransformer( + vision_config=vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + multimodal_config=multimodal_config, + prefix=maybe_prefix(prefix, "visual"), + ) + + # Linear projector (vision_hidden_size -> text_hidden_size) + # For V2 model: mm_projector_type is "linear" + vision_hidden_size = vision_config.hidden_size + text_hidden_size = text_config.hidden_size + + # Check if out_hidden_size is defined (Qwen2.5-VL style) + # The merger in Qwen2.5 VisionTransformer handles projection to out_hidden_size + if hasattr(vision_config, "out_hidden_size"): + out_hidden = vision_config.out_hidden_size + else: + out_hidden = vision_hidden_size + + # Always create Linear projector since HF checkpoint has mm_projector weights + self.mm_projector = nn.Linear(out_hidden, text_hidden_size) + + # Language model + self.lm_head_vocab_size = getattr( + text_config, "padded_vocab_size", text_config.vocab_size + ) + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + if modality.startswith("image"): + return V2_IMAGE_TOKEN + if modality.startswith("video"): + return V2_VIDEO_TOKEN + + raise ValueError("Only image or video modality is supported") + + def _parse_and_validate_image_input( + self, + **kwargs: object, + ) -> HCXVisionV2ImageInputs | None: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + return HCXVisionV2ImagePixelInputs( + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + + if image_embeds is not None: + return HCXVisionV2ImageEmbeddingInputs( + image_embeds=image_embeds, + image_grid_thw=image_grid_thw, + ) + + return None + + def _parse_and_validate_video_input( + self, + **kwargs: object, + ) -> HCXVisionV2VideoInputs | None: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_embeds = kwargs.pop("video_embeds", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + + if pixel_values_videos is None and video_embeds is None: + return None + + if pixel_values_videos is not None: + return HCXVisionV2VideoPixelInputs( + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + ) + + if video_embeds is not None: + return HCXVisionV2VideoEmbeddingInputs( + video_embeds=video_embeds, + video_grid_thw=video_grid_thw, + ) + + return None + + def _process_image_input( + self, + image_input: HCXVisionV2ImageInputs, + ) -> tuple[torch.Tensor, ...]: + """Process images through Qwen2.5 ViT and projector.""" + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() + + if image_input["type"] == "image_embeds": + image_embeds = image_input["image_embeds"].type(self.visual.dtype) + else: + pixel_values = image_input["pixel_values"] + with set_forward_context(None, self.vllm_config): + image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) + + # Apply projector + image_embeds = self.mm_projector(image_embeds) + + # Split concatenated embeddings for each image + merge_size = self.visual.spatial_merge_size + sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() + return image_embeds.split(sizes) + + def _process_video_input( + self, + video_input: HCXVisionV2VideoInputs, + ) -> tuple[torch.Tensor, ...]: + """Process videos through Qwen2.5 ViT and projector.""" + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() + + if video_input["type"] == "video_embeds": + video_embeds = video_input["video_embeds"].type(self.visual.dtype) + else: + pixel_values_videos = video_input["pixel_values_videos"] + with set_forward_context(None, self.vllm_config): + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list) + + # Apply projector + video_embeds = self.mm_projector(video_embeds) + + # Split concatenated embeddings for each video + merge_size = self.visual.spatial_merge_size + sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() + return video_embeds.split(sizes) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + for input_key in kwargs: + if ( + input_key in ("pixel_values", "image_embeds") + and "image" not in modalities + ): + modalities["image"] = self._parse_and_validate_image_input(**kwargs) + if ( + input_key in ("pixel_values_videos", "video_embeds") + and "video" not in modalities + ): + modalities["video"] = self._parse_and_validate_video_input(**kwargs) + + return modalities + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def embed_multimodal( + self, + **kwargs: object, + ) -> MultiModalEmbeddings: + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return [] + + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + for modality in modalities: + if modality == "image": + image_input = modalities["image"] + if image_input is not None: + image_embeddings = self._process_image_input(image_input) + multimodal_embeddings += tuple(image_embeddings) + if modality == "video": + video_input = modalities["video"] + if video_input is not None: + video_embeddings = self._process_video_input(video_input) + multimodal_embeddings += tuple(video_embeddings) + + return multimodal_embeddings + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor | IntermediateTensors: + if intermediate_tensors is not None: + inputs_embeds = None + + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + return self.language_model.compute_logits(hidden_states) + + def load_weights( + self, + weights: Iterable[tuple[str, torch.Tensor]], + ) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) From 88d6efd6b22aeaf6bf829b58e03d895dd8cae4b0 Mon Sep 17 00:00:00 2001 From: effortprogrammer Date: Thu, 19 Feb 2026 11:36:08 +0900 Subject: [PATCH 05/18] refactor(hyperclovax): isolate V1 implementation and reroute V2 registry Signed-off-by: effortprogrammer --- .../models/hyperclovax_vision.py | 684 ------------------ vllm/model_executor/models/registry.py | 2 +- 2 files changed, 1 insertion(+), 685 deletions(-) diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py index f61d187259fd..81be706737d1 100644 --- a/vllm/model_executor/models/hyperclovax_vision.py +++ b/vllm/model_executor/models/hyperclovax_vision.py @@ -19,7 +19,6 @@ from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions -from vllm.forward_context import set_forward_context from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import BaseMultiModalProcessorCache @@ -42,11 +41,9 @@ from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .qwen2_5_vl import Qwen2_5_VisionTransformer from .siglip import SiglipVisionModel from .utils import ( AutoWeightsLoader, - WeightsMapper, flatten_bn, init_vllm_registered_model, maybe_prefix, @@ -57,13 +54,6 @@ IMAGE_TOKEN: str = "<|dummy3|>" VIDEO_TOKEN: str = "<|_unuse_missing_100270|>" -# V2 (32B Think model) uses different tokens - retrieved from config at runtime -# These placeholder strings must match the chat template format exactly. -# The chat template produces: <|image_start|><|IMAGE_PAD|><|image_end|> -# Similar to Qwen2-VL's <|vision_start|><|image_pad|><|vision_end|> format. -V2_IMAGE_TOKEN: str = "<|image_start|><|IMAGE_PAD|><|image_end|>" -V2_VIDEO_TOKEN: str = "<|video_start|><|VIDEO_PAD|><|video_end|>" - # Based on combine_frames_into_images in # https://huggingface.co/naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B/blob/main/processing_hyperclovax.py @@ -1184,677 +1174,3 @@ def anyres_postprocessing( new_image_features.append(image_feature) return new_image_features - - -# ============================================================================= -# HyperCLOVAX V2 (32B Think Model) Support -# Uses Qwen2.5 Vision Transformer instead of CLIP/SigLIP -# ============================================================================= - - -class HCXVisionV2ImagePixelInputs(TensorSchema): - """ - V2 Image inputs using Qwen2.5-VL style grid_thw format. - - Dimensions: - - np: Number of patches - - ni: Number of images - - cps: Number of channels * patch_size * patch_size - """ - - type: Literal["pixel_values"] = "pixel_values" - pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")] - image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] - - -class HCXVisionV2ImageEmbeddingInputs(TensorSchema): - """ - V2 Image embedding inputs. - - Dimensions: - - nf: Number of image features - - hs: Hidden size - - ni: Number of images - """ - - type: Literal["image_embeds"] = "image_embeds" - image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")] - image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] - - -HCXVisionV2ImageInputs = HCXVisionV2ImagePixelInputs | HCXVisionV2ImageEmbeddingInputs - - -class HCXVisionV2VideoPixelInputs(TensorSchema): - """ - V2 Video inputs using Qwen2.5-VL style grid_thw format. - - Dimensions: - - np: Number of patches - - nv: Number of videos - - ctps: Number of channels * temporal_patch_size * patch_size * patch_size - """ - - type: Literal["pixel_values_videos"] = "pixel_values_videos" - pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctps")] - video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)] - - -class HCXVisionV2VideoEmbeddingInputs(TensorSchema): - """ - V2 Video embedding inputs. - - Dimensions: - - nf: Number of video features - - hs: Hidden size - - nv: Number of videos - """ - - type: Literal["video_embeds"] = "video_embeds" - video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")] - video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)] - - -HCXVisionV2VideoInputs = HCXVisionV2VideoPixelInputs | HCXVisionV2VideoEmbeddingInputs - - -class HCXVisionV2ProcessingInfo(BaseProcessingInfo): - """Processing info for HyperCLOVAX V2 (32B Think model).""" - - def get_supported_mm_limits(self) -> Mapping[str, int | None]: - return {"image": None, "video": None} - - def get_num_image_tokens( - self, - *, - image_width: int, - image_height: int, - ) -> int: - hf_config = self.get_hf_config() - vision_config = hf_config.vision_config - patch_size = vision_config.patch_size - spatial_merge_size = vision_config.spatial_merge_size - - grid_h = image_height // patch_size - grid_w = image_width // patch_size - - return (grid_h * grid_w) // (spatial_merge_size**2) - - def get_num_video_tokens( - self, - *, - video_width: int, - video_height: int, - num_frames: int, - ) -> int: - hf_config = self.get_hf_config() - vision_config = hf_config.vision_config - patch_size = vision_config.patch_size - temporal_patch_size = vision_config.temporal_patch_size - spatial_merge_size = vision_config.spatial_merge_size - - grid_t = num_frames // temporal_patch_size - grid_h = video_height // patch_size - grid_w = video_width // patch_size - - return (grid_t * grid_h * grid_w) // (spatial_merge_size**2) - - def get_image_size_with_most_features(self) -> ImageSize: - hf_config = self.get_hf_config() - vision_config = hf_config.vision_config - # Use a reasonable default size - size = getattr(vision_config, "image_size", 448) - return ImageSize(width=size, height=size) - - def get_max_image_tokens(self) -> int: - target_width, target_height = self.get_image_size_with_most_features() - return self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - ) - - -class HCXVisionV2DummyInputsBuilder(BaseDummyInputsBuilder[HCXVisionV2ProcessingInfo]): - """Dummy inputs builder for HyperCLOVAX V2 memory profiling.""" - - def get_dummy_text( - self, - mm_counts: Mapping[str, int], - ) -> str: - # This method is not used when get_dummy_processor_inputs is overridden, - # but we keep it for compatibility. - return "" - - def get_dummy_processor_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], - mm_options: Mapping[str, object] | None = None, - ): - """ - Override to use token IDs directly instead of text strings. - - This avoids the tokenizer issue where <|IMAGE_PAD|> might not be - recognized as a special token and gets split into multiple tokens. - By passing token IDs directly, we ensure the correct token (128060) - is used for prompt replacement matching. - """ - from vllm.multimodal.profiling import ProcessorInputs - - num_images = mm_counts.get("image", 0) - num_videos = mm_counts.get("video", 0) - - hf_config = self.info.get_hf_config() - - # Use token IDs directly to avoid tokenizer issues with special tokens - image_token_id = hf_config.image_token_id # 128060 - video_token_id = hf_config.video_token_id # 128061 - - # Create prompt as token ID list instead of text string - prompt_ids: list[int] = [image_token_id] * num_images + [ - video_token_id - ] * num_videos - - dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options) - - return ProcessorInputs( - prompt=prompt_ids, - mm_data=dummy_mm_data, - tokenization_kwargs={"truncation": False}, - ) - - def get_dummy_mm_data( - self, - seq_len: int, - mm_counts: Mapping[str, int], - mm_options: Mapping[str, BaseDummyOptions] | None = None, - ) -> MultiModalDataDict: - num_images = mm_counts.get("image", 0) - num_videos = mm_counts.get("video", 0) - - target_width, target_height = self.info.get_image_size_with_most_features() - target_num_frames = 16 # Default for video - - image_overrides = mm_options.get("image") if mm_options else None - video_overrides = mm_options.get("video") if mm_options else None - - result: MultiModalDataDict = { - "image": self._get_dummy_images( - width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides, # type: ignore - ), - "video": self._get_dummy_videos( - width=target_width, - height=target_height, - num_frames=target_num_frames, - num_videos=num_videos, - overrides=video_overrides, # type: ignore - ), - } - - return result - - -class HCXVisionV2MultiModalProcessor( - BaseMultiModalProcessor[HCXVisionV2ProcessingInfo] -): - """Multimodal processor for HyperCLOVAX V2 (32B Think model).""" - - def _call_hf_processor( - self, - prompt: str, - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object], - tok_kwargs: Mapping[str, object], - ) -> BatchFeature: - images = mm_data.get("images") - videos = mm_data.get("videos") - - # Get the HF processor - hf_processor = self.info.get_hf_processor(**mm_kwargs) - - # Build data dict for HF processor (images/videos only) - # The HF processor (HCXVisionV2Processor) doesn't support audio - # NOTE: We pass the prompt as-is without token normalization. - # Token expansion is handled by vLLM via _get_prompt_updates since - # _hf_processor_applies_updates returns False. - data: dict[str, object] = dict( - text=prompt, - images=images, - videos=videos, - ) - - processed_outputs = self.info.ctx.call_hf_processor( - hf_processor=hf_processor, - data=data, - ) - - return processed_outputs - - def _hf_processor_applies_updates( - self, - prompt_text: str, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - tokenization_kwargs: Mapping[str, object], - ) -> bool: - # HyperCLOVAX V2 has a token case mismatch bug: - # - Chat template uses <|IMAGE_PAD|> (uppercase) - # - HF processor (Qwen2_5_VLProcessor) expects <|image_pad|> (lowercase) - # - Tokenizer vocab has <|IMAGE_PAD|> (uppercase) = token ID 128060 - # - # The HF processor's token expansion fails because it looks for lowercase - # but the tokenized prompt has uppercase tokens. We bypass HF processor's - # expansion and let vLLM handle it via _get_prompt_updates using the - # correct token IDs from hf_config. - return False - - def _get_prompt_updates( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargsItems, - ) -> Sequence[PromptUpdate]: - hf_config = self.info.get_hf_config() - - # Use token IDs directly from config. - # This matches what get_dummy_processor_inputs uses, ensuring consistency. - placeholder: dict[str, int] = { - "image": hf_config.image_token_id, # 128060 for <|IMAGE_PAD|> - "video": hf_config.video_token_id, # 128061 for <|VIDEO_PAD|> - } - - merge_size = hf_config.vision_config.spatial_merge_size - - def get_replacement_v2( - item_idx: int, - modality: str, - out_mm_kwargs: MultiModalKwargsItems, - ): - out_item = out_mm_kwargs[modality][item_idx] - - if modality == "image": - grid_thw_elem = out_item.get("image_grid_thw") - if grid_thw_elem is not None: - # Access .data to get the actual tensor from MultiModalFieldElem - grid_thw = grid_thw_elem.data - # Qwen2.5-VL style calculation - h, w = grid_thw[1].item(), grid_thw[2].item() - num_tokens = (h * w) // (merge_size**2) - else: - # Fallback or error - raise ValueError("Missing image_grid_thw for V2 model") - elif modality == "video": - grid_thw_elem = out_item.get("video_grid_thw") - if grid_thw_elem is not None: - # Access .data to get the actual tensor from MultiModalFieldElem - grid_thw = grid_thw_elem.data - t, h, w = grid_thw[0].item(), grid_thw[1].item(), grid_thw[2].item() - num_tokens = (t * h * w) // (merge_size**2) - else: - raise ValueError("Missing video_grid_thw for V2 model") - else: - raise NotImplementedError(modality) - - return [placeholder[modality]] * num_tokens - - return [ - PromptReplacement( - modality=modality, - target=[ - placeholder[modality], - ], - replacement=partial( - get_replacement_v2, - modality=modality, - out_mm_kwargs=out_mm_kwargs, - ), - ) - for modality in ("image", "video") - ] - - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - # HyperCLOVAX V2 uses Qwen2.5-VL style flattened pixel values where - # pixel_values has shape (num_patches, channels*patch_size*patch_size) - # while image_grid_thw has shape (num_images, 3). - # We need to use flat_from_sizes to correctly handle this mismatch. - hf_config = self.info.get_hf_config() - spatial_merge_size = hf_config.vision_config.spatial_merge_size - - image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) - image_pixel_grid_sizes = image_grid_thw.prod(-1) - image_embed_grid_sizes = ( - image_pixel_grid_sizes // spatial_merge_size // spatial_merge_size - ) - - video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) - video_pixel_grid_sizes = video_grid_thw.prod(-1) - video_embed_grid_sizes = ( - video_pixel_grid_sizes // spatial_merge_size // spatial_merge_size - ) - - return dict( - pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_pixel_grid_sizes - ), - image_embeds=MultiModalFieldConfig.flat_from_sizes( - "image", image_embed_grid_sizes - ), - image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True), - pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( - "video", video_pixel_grid_sizes - ), - video_embeds=MultiModalFieldConfig.flat_from_sizes( - "video", video_embed_grid_sizes - ), - video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True), - ) - - -def _build_hcxvision_v2_hf_info( - ctx: InputProcessingContext, -) -> HCXVisionV2ProcessingInfo: - return HCXVisionV2ProcessingInfo(ctx) - - -def _build_hcxvision_v2_hf_processor( - info: HCXVisionV2ProcessingInfo, - dummy_inputs: BaseDummyInputsBuilder[HCXVisionV2ProcessingInfo], - *, - cache: BaseMultiModalProcessorCache | None = None, -) -> BaseMultiModalProcessor: - return HCXVisionV2MultiModalProcessor( - info, - dummy_inputs, # type: ignore - cache=cache, - ) - - -@MULTIMODAL_REGISTRY.register_processor( - _build_hcxvision_v2_hf_processor, - info=_build_hcxvision_v2_hf_info, - dummy_inputs=HCXVisionV2DummyInputsBuilder, -) -class HCXVisionV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): - """ - HyperCLOVAX-SEED Vision-Language Model (V2 architecture). - - Supports: - - HyperCLOVAX-SEED-Think-32B: Vision + Text - - HyperCLOVAX-SEED-Omni-8B: Vision + Audio + Text - - Uses Qwen2.5 Vision Transformer as the vision encoder. - """ - - packed_modules_mapping = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"], - "qkv": ["qkv"], # For vision tower - } - - # Weight mapping for loading HuggingFace checkpoints - # NOTE: Order matters! Ignores (None) should come before renames to prevent - # partial matches - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={ - "model.": "", # Remove model. prefix if present - "vision_model.": "visual.", # HF uses vision_model, we use visual - }, - orig_to_new_substr={ - # Ignore modules not implemented in vLLM - "discrete_vision_model": None, # TextAlignedTokenizer - }, - ) - - def __init__( - self, - *, - vllm_config: VllmConfig, - prefix: str = "", - **kwargs: Any | None, - ) -> None: - super().__init__() - - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - multimodal_config = vllm_config.model_config.multimodal_config - - # Text config - text_config = config.text_config - if text_config.model_type in ["gpt2", "hyperclovax", "llama"]: - text_config._attn_implementation = "sdpa" - if text_config.model_type != "hyperclovax": - text_config.logits_scaling = 1.0 - - # Vision config - vision_config = config.vision_config - - self.config = config - self.vision_config = vision_config - self.text_config = text_config - self.vllm_config = vllm_config - self.dtype = vllm_config.model_config.dtype - - # Initialize Qwen2.5 Vision Transformer - self.visual = Qwen2_5_VisionTransformer( - vision_config=vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=quant_config, - multimodal_config=multimodal_config, - prefix=maybe_prefix(prefix, "visual"), - ) - - # Linear projector (vision_hidden_size -> text_hidden_size) - # For V2 model: mm_projector_type is "linear" - vision_hidden_size = vision_config.hidden_size - text_hidden_size = text_config.hidden_size - - # Check if out_hidden_size is defined (Qwen2.5-VL style) - # The merger in Qwen2.5 VisionTransformer handles projection to out_hidden_size - if hasattr(vision_config, "out_hidden_size"): - out_hidden = vision_config.out_hidden_size - else: - out_hidden = vision_hidden_size - - # Always create Linear projector since HF checkpoint has mm_projector weights - self.mm_projector = nn.Linear(out_hidden, text_hidden_size) - - # Language model - self.lm_head_vocab_size = getattr( - text_config, "padded_vocab_size", text_config.vocab_size - ) - self.language_model = init_vllm_registered_model( - vllm_config=vllm_config, - hf_config=text_config, - prefix=maybe_prefix(prefix, "language_model"), - ) - - self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors - ) - - @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> str | None: - if modality.startswith("image"): - return V2_IMAGE_TOKEN - if modality.startswith("video"): - return V2_VIDEO_TOKEN - - raise ValueError("Only image or video modality is supported") - - def _parse_and_validate_image_input( - self, - **kwargs: object, - ) -> HCXVisionV2ImageInputs | None: - pixel_values = kwargs.pop("pixel_values", None) - image_embeds = kwargs.pop("image_embeds", None) - image_grid_thw = kwargs.pop("image_grid_thw", None) - - if pixel_values is None and image_embeds is None: - return None - - if pixel_values is not None: - return HCXVisionV2ImagePixelInputs( - pixel_values=pixel_values, - image_grid_thw=image_grid_thw, - ) - - if image_embeds is not None: - return HCXVisionV2ImageEmbeddingInputs( - image_embeds=image_embeds, - image_grid_thw=image_grid_thw, - ) - - return None - - def _parse_and_validate_video_input( - self, - **kwargs: object, - ) -> HCXVisionV2VideoInputs | None: - pixel_values_videos = kwargs.pop("pixel_values_videos", None) - video_embeds = kwargs.pop("video_embeds", None) - video_grid_thw = kwargs.pop("video_grid_thw", None) - - if pixel_values_videos is None and video_embeds is None: - return None - - if pixel_values_videos is not None: - return HCXVisionV2VideoPixelInputs( - pixel_values_videos=pixel_values_videos, - video_grid_thw=video_grid_thw, - ) - - if video_embeds is not None: - return HCXVisionV2VideoEmbeddingInputs( - video_embeds=video_embeds, - video_grid_thw=video_grid_thw, - ) - - return None - - def _process_image_input( - self, - image_input: HCXVisionV2ImageInputs, - ) -> tuple[torch.Tensor, ...]: - """Process images through Qwen2.5 ViT and projector.""" - grid_thw = image_input["image_grid_thw"] - assert grid_thw.ndim == 2 - grid_thw_list = grid_thw.tolist() - - if image_input["type"] == "image_embeds": - image_embeds = image_input["image_embeds"].type(self.visual.dtype) - else: - pixel_values = image_input["pixel_values"] - with set_forward_context(None, self.vllm_config): - image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) - - # Apply projector - image_embeds = self.mm_projector(image_embeds) - - # Split concatenated embeddings for each image - merge_size = self.visual.spatial_merge_size - sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() - return image_embeds.split(sizes) - - def _process_video_input( - self, - video_input: HCXVisionV2VideoInputs, - ) -> tuple[torch.Tensor, ...]: - """Process videos through Qwen2.5 ViT and projector.""" - grid_thw = video_input["video_grid_thw"] - assert grid_thw.ndim == 2 - grid_thw_list = grid_thw.tolist() - - if video_input["type"] == "video_embeds": - video_embeds = video_input["video_embeds"].type(self.visual.dtype) - else: - pixel_values_videos = video_input["pixel_values_videos"] - with set_forward_context(None, self.vllm_config): - video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list) - - # Apply projector - video_embeds = self.mm_projector(video_embeds) - - # Split concatenated embeddings for each video - merge_size = self.visual.spatial_merge_size - sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() - return video_embeds.split(sizes) - - def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: - modalities = {} - - for input_key in kwargs: - if ( - input_key in ("pixel_values", "image_embeds") - and "image" not in modalities - ): - modalities["image"] = self._parse_and_validate_image_input(**kwargs) - if ( - input_key in ("pixel_values_videos", "video_embeds") - and "video" not in modalities - ): - modalities["video"] = self._parse_and_validate_video_input(**kwargs) - - return modalities - - def get_language_model(self) -> torch.nn.Module: - return self.language_model - - def embed_multimodal( - self, - **kwargs: object, - ) -> MultiModalEmbeddings: - modalities = self._parse_and_validate_multimodal_inputs(**kwargs) - if not modalities: - return [] - - multimodal_embeddings: tuple[torch.Tensor, ...] = () - - for modality in modalities: - if modality == "image": - image_input = modalities["image"] - if image_input is not None: - image_embeddings = self._process_image_input(image_input) - multimodal_embeddings += tuple(image_embeddings) - if modality == "video": - video_input = modalities["video"] - if video_input is not None: - video_embeddings = self._process_video_input(video_input) - multimodal_embeddings += tuple(video_embeddings) - - return multimodal_embeddings - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - **kwargs: object, - ) -> torch.Tensor | IntermediateTensors: - if intermediate_tensors is not None: - inputs_embeds = None - - hidden_states = self.language_model.model( - input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds - ) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor | None: - return self.language_model.compute_logits(hidden_states) - - def load_weights( - self, - weights: Iterable[tuple[str, torch.Tensor]], - ) -> set[str]: - loader = AutoWeightsLoader(self) - return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index ff7b3ca2ac32..2a2f11120f07 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -124,7 +124,7 @@ "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"), "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"), "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"), - "HCXVisionV2ForCausalLM": ("hyperclovax_vision", "HCXVisionV2ForCausalLM"), + "HCXVisionV2ForCausalLM": ("hyperclovax_vision_v2", "HCXVisionV2ForCausalLM"), "HyperCLOVAXForCausalLM": ("llama", "LlamaForCausalLM"), "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), From 7fce9659e1340ad919d1e17f52c7da6c7a806685 Mon Sep 17 00:00:00 2001 From: effortprogrammer Date: Thu, 19 Feb 2026 14:51:29 +0900 Subject: [PATCH 06/18] fix(hyperclovax): use multimodal.processing imports for V2 dummy inputs Signed-off-by: effortprogrammer --- vllm/model_executor/models/hyperclovax_vision_v2.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/hyperclovax_vision_v2.py b/vllm/model_executor/models/hyperclovax_vision_v2.py index ab546cd426ad..71f472e3441d 100644 --- a/vllm/model_executor/models/hyperclovax_vision_v2.py +++ b/vllm/model_executor/models/hyperclovax_vision_v2.py @@ -30,13 +30,14 @@ ) from vllm.multimodal.parse import ImageSize, MultiModalDataItems from vllm.multimodal.processing import ( + BaseDummyInputsBuilder, BaseMultiModalProcessor, BaseProcessingInfo, InputProcessingContext, + ProcessorInputs, PromptReplacement, PromptUpdate, ) -from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape @@ -204,8 +205,6 @@ def get_dummy_processor_inputs( By passing token IDs directly, we ensure the correct token (128060) is used for prompt replacement matching. """ - from vllm.multimodal.profiling import ProcessorInputs - num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) From d4ab8ebba60bc2a74b9fb88020f0723e70e29017 Mon Sep 17 00:00:00 2001 From: effortprogrammer Date: Thu, 19 Feb 2026 15:33:52 +0900 Subject: [PATCH 07/18] fix(hyperclovax): align V2 dummy input builder with multimodal API Signed-off-by: effortprogrammer --- .../models/hyperclovax_vision_v2.py | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/hyperclovax_vision_v2.py b/vllm/model_executor/models/hyperclovax_vision_v2.py index 71f472e3441d..6e7084ed50ca 100644 --- a/vllm/model_executor/models/hyperclovax_vision_v2.py +++ b/vllm/model_executor/models/hyperclovax_vision_v2.py @@ -25,6 +25,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import ( + MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) @@ -195,8 +196,9 @@ def get_dummy_processor_inputs( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Mapping[str, object] | None = None, - ): + mm_options: Mapping[str, BaseDummyOptions] | None = None, + mm_processor_kwargs: Mapping[str, object] | None = None, + ) -> ProcessorInputs: """ Override to use token IDs directly instead of text strings. @@ -219,11 +221,18 @@ def get_dummy_processor_inputs( video_token_id ] * num_videos - dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options) + dummy_mm_data = self.get_dummy_mm_data( + seq_len, + mm_counts, + mm_options, + mm_processor_kwargs=mm_processor_kwargs, + ) + dummy_mm_items = self.info.parse_mm_data(dummy_mm_data, validate=False) return ProcessorInputs( prompt=prompt_ids, - mm_data=dummy_mm_data, + mm_items=dummy_mm_items, + hf_processor_mm_kwargs=mm_processor_kwargs or {}, tokenization_kwargs={"truncation": False}, ) @@ -232,9 +241,8 @@ def get_dummy_mm_data( seq_len: int, mm_counts: Mapping[str, int], mm_options: Mapping[str, BaseDummyOptions] | None = None, - ) -> dict: - from vllm.multimodal.inputs import MultiModalDataDict - + mm_processor_kwargs: Mapping[str, object] | None = None, + ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) From 340d0b522f487a9dd709e2423bcecea742727063 Mon Sep 17 00:00:00 2001 From: effortprogrammer Date: Thu, 19 Feb 2026 15:44:42 +0900 Subject: [PATCH 08/18] fix(hyperclovax): drop unsupported multimodal_config arg in V2 vision tower Signed-off-by: effortprogrammer --- vllm/model_executor/models/hyperclovax_vision_v2.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/models/hyperclovax_vision_v2.py b/vllm/model_executor/models/hyperclovax_vision_v2.py index 6e7084ed50ca..be3ac1263ba6 100644 --- a/vllm/model_executor/models/hyperclovax_vision_v2.py +++ b/vllm/model_executor/models/hyperclovax_vision_v2.py @@ -497,7 +497,6 @@ def __init__( config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - multimodal_config = vllm_config.model_config.multimodal_config # Text config text_config = config.text_config @@ -520,7 +519,6 @@ def __init__( vision_config=vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=quant_config, - multimodal_config=multimodal_config, prefix=maybe_prefix(prefix, "visual"), ) From 729187aeb30766b5a53e29148ea169e775302ce0 Mon Sep 17 00:00:00 2001 From: effortprogrammer Date: Thu, 19 Feb 2026 18:25:50 +0900 Subject: [PATCH 09/18] refactor(hyperclovax): address reviewer feedback in model and registry Signed-off-by: effortprogrammer --- tests/models/registry.py | 7 +++++ .../models/hyperclovax_vision.py | 3 +- .../models/hyperclovax_vision_v2.py | 28 ++----------------- 3 files changed, 11 insertions(+), 27 deletions(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 16e64ea9e6d8..1665292ebf12 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -767,6 +767,13 @@ def check_available_online( "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", trust_remote_code=True, ), + "HCXVisionV2ForCausalLM": _HfExamplesInfo( + "naver-hyperclovax/HyperCLOVAX-SEED-Think-32B", + extras={ + "llama-text-backend": "naver-hyperclovax/HyperCLOVAX-SEED-Omni-8B", + }, + trust_remote_code=True, + ), "HunYuanVLForConditionalGeneration": _HfExamplesInfo( "tencent/HunyuanOCR", hf_overrides={"num_experts": 0}, diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py index d8027f90e72a..15e387cf4447 100644 --- a/vllm/model_executor/models/hyperclovax_vision.py +++ b/vllm/model_executor/models/hyperclovax_vision.py @@ -6,7 +6,7 @@ from collections.abc import Iterable, Mapping, Sequence from functools import partial from itertools import accumulate -from typing import Annotated, Any, Literal +from typing import Annotated, Literal import numpy as np import torch @@ -613,7 +613,6 @@ def __init__( *, vllm_config: VllmConfig, prefix: str = "", - **kwargs: Any | None, ) -> None: super().__init__() diff --git a/vllm/model_executor/models/hyperclovax_vision_v2.py b/vllm/model_executor/models/hyperclovax_vision_v2.py index be3ac1263ba6..7d8a86c69a78 100644 --- a/vllm/model_executor/models/hyperclovax_vision_v2.py +++ b/vllm/model_executor/models/hyperclovax_vision_v2.py @@ -13,7 +13,7 @@ from collections.abc import Iterable, Mapping, Sequence from functools import partial -from typing import Annotated, Any, Literal +from typing import Annotated, Literal import torch import torch.nn as nn @@ -23,7 +23,6 @@ from vllm.config.multimodal import BaseDummyOptions from vllm.forward_context import set_forward_context from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import ( MultiModalDataDict, MultiModalFieldConfig, @@ -34,7 +33,6 @@ BaseDummyInputsBuilder, BaseMultiModalProcessor, BaseProcessingInfo, - InputProcessingContext, ProcessorInputs, PromptReplacement, PromptUpdate, @@ -431,28 +429,9 @@ def _get_mm_fields_config( ) -def _build_hcxvision_v2_hf_info( - ctx: InputProcessingContext, -) -> HCXVisionV2ProcessingInfo: - return HCXVisionV2ProcessingInfo(ctx) - - -def _build_hcxvision_v2_hf_processor( - info: HCXVisionV2ProcessingInfo, - dummy_inputs: BaseDummyInputsBuilder[HCXVisionV2ProcessingInfo], - *, - cache: BaseMultiModalProcessorCache | None = None, -) -> BaseMultiModalProcessor: - return HCXVisionV2MultiModalProcessor( - info, - dummy_inputs, # type: ignore - cache=cache, - ) - - @MULTIMODAL_REGISTRY.register_processor( - _build_hcxvision_v2_hf_processor, - info=_build_hcxvision_v2_hf_info, + HCXVisionV2MultiModalProcessor, + info=HCXVisionV2ProcessingInfo, dummy_inputs=HCXVisionV2DummyInputsBuilder, ) class HCXVisionV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): @@ -491,7 +470,6 @@ def __init__( *, vllm_config: VllmConfig, prefix: str = "", - **kwargs: Any | None, ) -> None: super().__init__() From 335bc415e380102f1bb918f917a1aba541a7aa53 Mon Sep 17 00:00:00 2001 From: effortprogrammer Date: Thu, 19 Feb 2026 18:26:00 +0900 Subject: [PATCH 10/18] docs: add HyperCLOVAX V2 to supported models Signed-off-by: effortprogrammer --- docs/models/supported_models.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 1cad8c4a171a..fba83d931ce1 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -696,6 +696,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `GlmOcrForConditionalGeneration` | GLM-OCR | T + IE+ | `zai-org/GLM-OCR`, etc. | ✅︎ | ✅︎ | | `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | | `HCXVisionForCausalLM` | HyperCLOVAX-SEED-Vision-Instruct-3B | T + I+ + V+ | `naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B` | | | +| `HCXVisionV2ForCausalLM` | HyperCLOVAX-SEED-Think-32B | T + I+ + V+ | `naver-hyperclovax/HyperCLOVAX-SEED-Think-32B` | | | | `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | | `HunYuanVLForConditionalGeneration` | HunyuanOCR | T + IE+ | `tencent/HunyuanOCR`, etc. | ✅︎ | ✅︎ | | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | From 8b0d8b33873a6f0dd2883f3304e10e114784c461 Mon Sep 17 00:00:00 2001 From: effortprogrammer Date: Thu, 19 Feb 2026 18:35:36 +0900 Subject: [PATCH 11/18] fix(hyperclovax): remove Omni references from V2 scope Signed-off-by: effortprogrammer --- tests/models/registry.py | 3 --- vllm/model_executor/models/hyperclovax_vision_v2.py | 3 --- 2 files changed, 6 deletions(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 1665292ebf12..683684bf3842 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -769,9 +769,6 @@ def check_available_online( ), "HCXVisionV2ForCausalLM": _HfExamplesInfo( "naver-hyperclovax/HyperCLOVAX-SEED-Think-32B", - extras={ - "llama-text-backend": "naver-hyperclovax/HyperCLOVAX-SEED-Omni-8B", - }, trust_remote_code=True, ), "HunYuanVLForConditionalGeneration": _HfExamplesInfo( diff --git a/vllm/model_executor/models/hyperclovax_vision_v2.py b/vllm/model_executor/models/hyperclovax_vision_v2.py index 7d8a86c69a78..7167d80fd831 100644 --- a/vllm/model_executor/models/hyperclovax_vision_v2.py +++ b/vllm/model_executor/models/hyperclovax_vision_v2.py @@ -8,7 +8,6 @@ Supports: - HyperCLOVAX-SEED-Think-32B: Vision + Text -- HyperCLOVAX-SEED-Omni-8B: Vision + Audio + Text """ from collections.abc import Iterable, Mapping, Sequence @@ -288,7 +287,6 @@ def _call_hf_processor( hf_processor = self.info.get_hf_processor(**mm_kwargs) # Build data dict for HF processor (images/videos only) - # The HF processor (HCXVisionV2Processor) doesn't support audio # NOTE: We pass the prompt as-is without token normalization. # Token expansion is handled by vLLM via _get_prompt_updates since # _hf_processor_applies_updates returns False. @@ -440,7 +438,6 @@ class HCXVisionV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): Supports: - HyperCLOVAX-SEED-Think-32B: Vision + Text - - HyperCLOVAX-SEED-Omni-8B: Vision + Audio + Text Uses Qwen2.5 Vision Transformer as the vision encoder. """ From 677e249192cc66352369517ad3e924aea3b4ff2f Mon Sep 17 00:00:00 2001 From: effortprogrammer Date: Thu, 19 Feb 2026 19:08:54 +0900 Subject: [PATCH 12/18] fix(hyperclovax): limit V2 text backend handling to hyperclovax Signed-off-by: effortprogrammer --- vllm/model_executor/models/hyperclovax_vision_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/hyperclovax_vision_v2.py b/vllm/model_executor/models/hyperclovax_vision_v2.py index 7167d80fd831..b8b19d9bc0eb 100644 --- a/vllm/model_executor/models/hyperclovax_vision_v2.py +++ b/vllm/model_executor/models/hyperclovax_vision_v2.py @@ -475,7 +475,7 @@ def __init__( # Text config text_config = config.text_config - if text_config.model_type in ["gpt2", "hyperclovax", "llama"]: + if text_config.model_type == "hyperclovax": text_config._attn_implementation = "sdpa" if text_config.model_type != "hyperclovax": text_config.logits_scaling = 1.0 From 8aeb7a63f6285b5713007e73694bf089e33d11dd Mon Sep 17 00:00:00 2001 From: effortprogrammer Date: Thu, 19 Feb 2026 19:47:10 +0900 Subject: [PATCH 13/18] fix(hyperclovax): restore original text backend condition Signed-off-by: effortprogrammer --- vllm/model_executor/models/hyperclovax_vision_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/hyperclovax_vision_v2.py b/vllm/model_executor/models/hyperclovax_vision_v2.py index b8b19d9bc0eb..7167d80fd831 100644 --- a/vllm/model_executor/models/hyperclovax_vision_v2.py +++ b/vllm/model_executor/models/hyperclovax_vision_v2.py @@ -475,7 +475,7 @@ def __init__( # Text config text_config = config.text_config - if text_config.model_type == "hyperclovax": + if text_config.model_type in ["gpt2", "hyperclovax", "llama"]: text_config._attn_implementation = "sdpa" if text_config.model_type != "hyperclovax": text_config.logits_scaling = 1.0 From f419cdac3b13f255c2954bbb100bbb15a32de741 Mon Sep 17 00:00:00 2001 From: effortprogrammer Date: Mon, 9 Mar 2026 08:15:13 +0900 Subject: [PATCH 14/18] fix(hyperclovax): align V2 multimodal processor update flow Signed-off-by: effortprogrammer --- .../models/hyperclovax_vision_v2.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/hyperclovax_vision_v2.py b/vllm/model_executor/models/hyperclovax_vision_v2.py index 7167d80fd831..a38b44575566 100644 --- a/vllm/model_executor/models/hyperclovax_vision_v2.py +++ b/vllm/model_executor/models/hyperclovax_vision_v2.py @@ -228,7 +228,7 @@ def get_dummy_processor_inputs( return ProcessorInputs( prompt=prompt_ids, - mm_items=dummy_mm_items, + mm_data_items=dummy_mm_items, hf_processor_mm_kwargs=mm_processor_kwargs or {}, tokenization_kwargs={"truncation": False}, ) @@ -299,6 +299,7 @@ def _call_hf_processor( processed_outputs = self.info.ctx.call_hf_processor( hf_processor=hf_processor, data=data, + kwargs=dict(**mm_kwargs, **tok_kwargs), ) return processed_outputs @@ -310,16 +311,15 @@ def _hf_processor_applies_updates( hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], ) -> bool: - # HyperCLOVAX V2 has a token case mismatch bug: - # - Chat template uses <|IMAGE_PAD|> (uppercase) - # - HF processor (Qwen2_5_VLProcessor) expects <|image_pad|> (lowercase) - # - Tokenizer vocab has <|IMAGE_PAD|> (uppercase) = token ID 128060 - # - # The HF processor's token expansion fails because it looks for lowercase - # but the tokenized prompt has uppercase tokens. We bypass HF processor's - # expansion and let vLLM handle it via _get_prompt_updates using the - # correct token IDs from hf_config. - return False + # Match BaseMultiModalProcessor behavior: + # - raw multimodal inputs: HF processor applies updates + # - embedding inputs: vLLM applies updates + return super()._hf_processor_applies_updates( + prompt_text, + mm_items, + hf_processor_mm_kwargs, + tokenization_kwargs, + ) def _get_prompt_updates( self, From 0c25a00650cc5aae20553e65b655554f3d8978c2 Mon Sep 17 00:00:00 2001 From: effortprogrammer Date: Mon, 9 Mar 2026 09:05:16 +0900 Subject: [PATCH 15/18] fix(chat): normalize openai media parts to canonical modality types Signed-off-by: effortprogrammer --- tests/entrypoints/test_chat_utils.py | 5 ++++- vllm/entrypoints/chat_utils.py | 4 +--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 1a118eb4fd8e..01577099143d 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -1480,7 +1480,10 @@ def test_parse_chat_messages_openai_format_image_url( assert conversation == [ { "role": "user", - "content": content, + "content": [ + {"type": "image"}, + {"type": "text", "text": "What's in the image?"}, + ], } ] _assert_mm_data_is_image_input(mm_data, 1) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index eecf30fa2a7a..6de8a3cd5a3d 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -1450,9 +1450,7 @@ def _parse_chat_message_content_part( raise NotImplementedError(f"Unknown part type: {part_type}") if wrap_dicts: - if isinstance(part, dict): - return dict(part) - return {"type": "text", "text": str(part)} + return {"type": modality} return MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None From b1036ccceb82229f1156b3f71f0f22e009727e81 Mon Sep 17 00:00:00 2001 From: effortprogrammer Date: Mon, 9 Mar 2026 09:56:46 +0900 Subject: [PATCH 16/18] fix(hyperclovax): satisfy registry coverage and dummy prompt type Signed-off-by: effortprogrammer --- tests/models/registry.py | 3 ++ .../models/hyperclovax_vision_v2.py | 29 ++++--------------- 2 files changed, 9 insertions(+), 23 deletions(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 9c903f88522e..6cf3d9c297c2 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -793,6 +793,9 @@ def check_available_online( "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", trust_remote_code=True, ), + "HyperCLOVAXForCausalLM": _HfExamplesInfo( + "meta-llama/Llama-3.2-1B-Instruct", + ), "HCXVisionV2ForCausalLM": _HfExamplesInfo( "naver-hyperclovax/HyperCLOVAX-SEED-Think-32B", trust_remote_code=True, diff --git a/vllm/model_executor/models/hyperclovax_vision_v2.py b/vllm/model_executor/models/hyperclovax_vision_v2.py index a38b44575566..b32872962ebc 100644 --- a/vllm/model_executor/models/hyperclovax_vision_v2.py +++ b/vllm/model_executor/models/hyperclovax_vision_v2.py @@ -185,9 +185,9 @@ def get_dummy_text( self, mm_counts: Mapping[str, int], ) -> str: - # This method is not used when get_dummy_processor_inputs is overridden, - # but we keep it for compatibility. - return "" + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + return V2_IMAGE_TOKEN * num_images + V2_VIDEO_TOKEN * num_videos def get_dummy_processor_inputs( self, @@ -196,27 +196,10 @@ def get_dummy_processor_inputs( mm_options: Mapping[str, BaseDummyOptions] | None = None, mm_processor_kwargs: Mapping[str, object] | None = None, ) -> ProcessorInputs: - """ - Override to use token IDs directly instead of text strings. - - This avoids the tokenizer issue where <|IMAGE_PAD|> might not be - recognized as a special token and gets split into multiple tokens. - By passing token IDs directly, we ensure the correct token (128060) - is used for prompt replacement matching. - """ + """Build dummy processor inputs for memory profiling.""" num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) - - hf_config = self.info.get_hf_config() - - # Use token IDs directly to avoid tokenizer issues with special tokens - image_token_id = hf_config.image_token_id # 128060 - video_token_id = hf_config.video_token_id # 128061 - - # Create prompt as token ID list instead of text string - prompt_ids: list[int] = [image_token_id] * num_images + [ - video_token_id - ] * num_videos + prompt_text = V2_IMAGE_TOKEN * num_images + V2_VIDEO_TOKEN * num_videos dummy_mm_data = self.get_dummy_mm_data( seq_len, @@ -227,7 +210,7 @@ def get_dummy_processor_inputs( dummy_mm_items = self.info.parse_mm_data(dummy_mm_data, validate=False) return ProcessorInputs( - prompt=prompt_ids, + prompt=prompt_text, mm_data_items=dummy_mm_items, hf_processor_mm_kwargs=mm_processor_kwargs or {}, tokenization_kwargs={"truncation": False}, From 56ac889493a172ee1e83283359412f6350e365df Mon Sep 17 00:00:00 2001 From: effortprogrammer Date: Mon, 9 Mar 2026 16:04:03 +0900 Subject: [PATCH 17/18] fix(registry): map HyperCLOVAX architecture as text model Signed-off-by: effortprogrammer --- tests/models/registry.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 6cf3d9c297c2..5dd0a9f11a87 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -313,6 +313,10 @@ def check_available_online( "HunYuanMoEV1ForCausalLM": _HfExamplesInfo( "tencent/Hunyuan-A13B-Instruct", trust_remote_code=True ), + "HyperCLOVAXForCausalLM": _HfExamplesInfo( + "naver-hyperclovax/HyperCLOVAX-SEED-Think-32B", + trust_remote_code=True, + ), "InternLMForCausalLM": _HfExamplesInfo( "internlm/internlm-chat-7b", trust_remote_code=True ), @@ -793,9 +797,6 @@ def check_available_online( "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", trust_remote_code=True, ), - "HyperCLOVAXForCausalLM": _HfExamplesInfo( - "meta-llama/Llama-3.2-1B-Instruct", - ), "HCXVisionV2ForCausalLM": _HfExamplesInfo( "naver-hyperclovax/HyperCLOVAX-SEED-Think-32B", trust_remote_code=True, From 5abbb8c8b76e3986f6fda8efd6e6ee190173f4bd Mon Sep 17 00:00:00 2001 From: effortprogrammer Date: Mon, 9 Mar 2026 21:12:46 +0900 Subject: [PATCH 18/18] fix(ci): harden cpu runner tensor replacement and realtime warmup wait Signed-off-by: effortprogrammer --- tests/entrypoints/openai/test_realtime_validation.py | 2 +- vllm/v1/worker/cpu_model_runner.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/entrypoints/openai/test_realtime_validation.py b/tests/entrypoints/openai/test_realtime_validation.py index 9a45ac293ef3..9092aac5b693 100644 --- a/tests/entrypoints/openai/test_realtime_validation.py +++ b/tests/entrypoints/openai/test_realtime_validation.py @@ -118,7 +118,7 @@ async def test_multi_chunk_streaming( # JIT compilation warmup_done = False while not warmup_done: - event = await receive_event(ws, timeout=360.0) + event = await receive_event(ws, timeout=600.0) if event["type"] in ("transcription.done", "error"): warmup_done = True diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py index 489480004821..d46a1259335d 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -36,7 +36,8 @@ def replace_tensor(obj: Any, cpu_attr_name: str, device_attr_name) -> None: device_tensor = getattr(obj, device_attr_name, None) if cpu_tensor is not None and device_tensor is not None: assert isinstance(cpu_tensor, torch.Tensor) - assert isinstance(device_tensor, torch.Tensor) + if not isinstance(device_tensor, torch.Tensor): + return setattr(obj, device_attr_name, cpu_tensor) for v in vars(self).values():