From febf60cd9a60bacf4b696cc78ce52bb99ef8224e Mon Sep 17 00:00:00 2001 From: piood <2477084691@qq.com> Date: Sun, 9 Nov 2025 16:58:39 +0000 Subject: [PATCH 1/3] [bugfix] fix siglip batch text output error Signed-off-by: piood <2477084691@qq.com> --- vllm/model_executor/models/siglip.py | 97 +++++++++++++++++++++------- 1 file changed, 74 insertions(+), 23 deletions(-) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index e363be523dcc..346d80ac142f 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -18,7 +18,9 @@ SiglipVisionConfig, ) +from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import MultiHeadAttention +from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import divide, get_tensor_model_parallel_world_size @@ -379,6 +381,8 @@ def __init__( quant_config: QuantizationConfig | None = None, *, prefix: str = "", + attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention], + attn_type: AttentionType | None = None, ) -> None: super().__init__() @@ -413,9 +417,23 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_partition = divide(self.num_heads, self.tp_size) - self.attn = MultiHeadAttention( - self.num_heads_per_partition, self.head_dim, self.scale - ) + if attn_cls == EncoderOnlyAttention: + self.attn = attn_cls( + self.num_heads_per_partition, + self.head_dim, + self.scale, + prefix=f"{prefix}.attn", + attn_type=attn_type, + ) + elif attn_cls == MultiHeadAttention: + self.attn = MultiHeadAttention( + self.num_heads_per_partition, + self.head_dim, + self.scale, + prefix=f"{prefix}.attn", + ) + else: + raise ValueError(f"Invalid attention class: {attn_cls}") def forward( self, @@ -424,25 +442,7 @@ def forward( """Input shape: Batch x Time x Channel""" qkv_states, _ = self.qkv_proj(hidden_states) query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) - - needs_unsqueeze = query_states.ndim == 2 - if needs_unsqueeze: - query_states, key_states, value_states = ( - query_states.unsqueeze(0), - key_states.unsqueeze(0), - value_states.unsqueeze(0), - ) - out = self.attn(query_states, key_states, value_states) - - if needs_unsqueeze: - out, query_states, key_states, value_states = ( - out.squeeze(0), - query_states.squeeze(0), - key_states.squeeze(0), - value_states.squeeze(0), - ) - attn_output, _ = self.out_proj(out) return attn_output, None @@ -495,6 +495,8 @@ def __init__( quant_config: QuantizationConfig | None = None, *, prefix: str = "", + attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention], + attn_type: AttentionType | None = None, ) -> None: super().__init__() @@ -504,6 +506,8 @@ def __init__( config, quant_config=quant_config, prefix=f"{prefix}.self_attn", + attn_cls=attn_cls, + attn_type=attn_type, ) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = SiglipMLP( @@ -539,6 +543,8 @@ def __init__( num_hidden_layers_override: int | None = None, *, prefix: str = "", + attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention], + attn_type: AttentionType | None = None, ) -> None: super().__init__() @@ -555,6 +561,8 @@ def __init__( config, quant_config=quant_config, prefix=f"{prefix}.layers.{layer_idx}", + attn_cls=attn_cls, + attn_type=attn_type, ) for layer_idx in range(num_hidden_layers) ] @@ -598,6 +606,8 @@ def __init__( 
config=config, quant_config=quant_config, prefix=f"{prefix}.encoder", + attn_cls=EncoderOnlyAttention, + attn_type=AttentionType.ENCODER_ONLY, ) self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) @@ -709,6 +719,7 @@ def __init__( quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override, prefix=f"{prefix}.encoder", + attn_cls=MultiHeadAttention, ) num_hidden_layers = config.num_hidden_layers @@ -1034,10 +1045,50 @@ def get_text_features( inputs_embeds=inputs_embeds, ) text_features = self.text_model.head(last_hidden_state) - # Flip to extract CLS token (first token after reversal) for pooling - text_features = text_features.flip(0) + + # SigLIP uses reversed position_ids; + # flip sequences to move EOS token to first position + text_features = self._flip_sequences_by_position_ids( + text_features, position_ids + ) + return text_features + def _flip_sequences_by_position_ids( + self, + features: torch.Tensor, + position_ids: torch.Tensor, + ) -> torch.Tensor: + """Flip sequences so EOS token moves to first position for CLS pooling. + + SigLIP position_ids are reversed within each sequence. This method detects + sequence boundaries and flips each sequence individually. + """ + if len(features) == 1: + return features + + # Detect sequence boundaries where position_ids decrease + position_diffs = position_ids[1:] - position_ids[:-1] + boundary_mask = position_diffs <= 0 + + boundary_indices = torch.cat( + [ + torch.tensor([0], device=features.device), + torch.where(boundary_mask)[0] + 1, + torch.tensor([len(features)], device=features.device), + ] + ) + + # Flip each sequence individually + flipped_features = [] + for i in range(len(boundary_indices) - 1): + start_idx = boundary_indices[i] + end_idx = boundary_indices[i + 1] + sequence = features[start_idx:end_idx] + flipped_features.append(torch.flip(sequence, [0])) + + return torch.cat(flipped_features, dim=0) + def get_image_features( self, pixel_values: torch.Tensor, From 98b57478ea8d0d4ae0f71da33e881801f7918c45 Mon Sep 17 00:00:00 2001 From: piood <2477084691@qq.com> Date: Mon, 10 Nov 2025 06:49:36 +0000 Subject: [PATCH 2/3] fix Signed-off-by: piood <2477084691@qq.com> --- vllm/model_executor/models/siglip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 346d80ac142f..13f0cde14d9d 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -418,7 +418,7 @@ def __init__( self.num_heads_per_partition = divide(self.num_heads, self.tp_size) if attn_cls == EncoderOnlyAttention: - self.attn = attn_cls( + self.attn = EncoderOnlyAttention( self.num_heads_per_partition, self.head_dim, self.scale, From 29e6b3c69d4ce1ea488b41df69fa487e39fa3421 Mon Sep 17 00:00:00 2001 From: piood <2477084691@qq.com> Date: Mon, 10 Nov 2025 08:38:36 +0000 Subject: [PATCH 3/3] optimize batch flip Signed-off-by: piood <2477084691@qq.com> --- vllm/model_executor/models/siglip.py | 52 +++++++++++----------------- 1 file changed, 20 insertions(+), 32 deletions(-) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 13f0cde14d9d..3cbdd64acc4a 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -18,7 +18,6 @@ SiglipVisionConfig, ) -from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.encoder_only_attention import 
EncoderOnlyAttention from vllm.config import VllmConfig @@ -382,7 +381,6 @@ def __init__( *, prefix: str = "", attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention], - attn_type: AttentionType | None = None, ) -> None: super().__init__() @@ -417,23 +415,12 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_partition = divide(self.num_heads, self.tp_size) - if attn_cls == EncoderOnlyAttention: - self.attn = EncoderOnlyAttention( - self.num_heads_per_partition, - self.head_dim, - self.scale, - prefix=f"{prefix}.attn", - attn_type=attn_type, - ) - elif attn_cls == MultiHeadAttention: - self.attn = MultiHeadAttention( - self.num_heads_per_partition, - self.head_dim, - self.scale, - prefix=f"{prefix}.attn", - ) - else: - raise ValueError(f"Invalid attention class: {attn_cls}") + self.attn = attn_cls( + self.num_heads_per_partition, + self.head_dim, + self.scale, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -496,7 +483,6 @@ def __init__( *, prefix: str = "", attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention], - attn_type: AttentionType | None = None, ) -> None: super().__init__() @@ -507,7 +493,6 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.self_attn", attn_cls=attn_cls, - attn_type=attn_type, ) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = SiglipMLP( @@ -544,7 +529,6 @@ def __init__( *, prefix: str = "", attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention], - attn_type: AttentionType | None = None, ) -> None: super().__init__() @@ -562,7 +546,6 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.layers.{layer_idx}", attn_cls=attn_cls, - attn_type=attn_type, ) for layer_idx in range(num_hidden_layers) ] @@ -607,7 +590,6 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.encoder", attn_cls=EncoderOnlyAttention, - attn_type=AttentionType.ENCODER_ONLY, ) self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) @@ -1079,15 +1061,21 @@ def _flip_sequences_by_position_ids( ] ) - # Flip each sequence individually - flipped_features = [] - for i in range(len(boundary_indices) - 1): - start_idx = boundary_indices[i] - end_idx = boundary_indices[i + 1] - sequence = features[start_idx:end_idx] - flipped_features.append(torch.flip(sequence, [0])) + # For each sequence [start, end), position i flips to: start + end - 1 - i + lengths = boundary_indices[1:] - boundary_indices[:-1] + starts = boundary_indices[:-1] + ends = boundary_indices[1:] + + # Assign sequence ID to each element + sequence_ids = torch.arange( + len(lengths), device=features.device + ).repeat_interleave(lengths) + + # Calculate flipped indices for all positions at once + current_positions = torch.arange(len(features), device=features.device) + flip_indices = starts[sequence_ids] + ends[sequence_ids] - 1 - current_positions - return torch.cat(flipped_features, dim=0) + return features[flip_indices] def get_image_features( self,
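
For reference, a minimal standalone sketch of the per-sequence flip this series
converges on. The code it replaces flipped the whole packed tensor with
text_features.flip(0), which scrambles tokens across sequence boundaries whenever
more than one prompt is packed into a batch; PATCH 1/3 fixes that with a Python
loop over detected boundaries, and PATCH 3/3 replaces the loop with a single index
gather. The features and position_ids below are toy values for illustration only,
not taken from the patches:

# Sketch comparing the loop-based flip from PATCH 1/3 with the vectorized
# flip from PATCH 3/3, on a toy packed batch (values are made up).
import torch

def flip_loop(features: torch.Tensor, boundary_indices: torch.Tensor) -> torch.Tensor:
    # PATCH 1/3 approach: flip each sequence individually in a Python loop.
    flipped = []
    for i in range(len(boundary_indices) - 1):
        start, end = boundary_indices[i], boundary_indices[i + 1]
        flipped.append(torch.flip(features[start:end], [0]))
    return torch.cat(flipped, dim=0)

def flip_vectorized(features: torch.Tensor, boundary_indices: torch.Tensor) -> torch.Tensor:
    # PATCH 3/3 approach: for each sequence [start, end), position i maps to
    # start + end - 1 - i, computed for all positions in one gather.
    lengths = boundary_indices[1:] - boundary_indices[:-1]
    starts, ends = boundary_indices[:-1], boundary_indices[1:]
    sequence_ids = torch.arange(len(lengths)).repeat_interleave(lengths)
    positions = torch.arange(len(features))
    flip_indices = starts[sequence_ids] + ends[sequence_ids] - 1 - positions
    return features[flip_indices]

# Two packed sequences of lengths 3 and 2; per-sequence position_ids restart
# at 0, so a non-positive diff marks the start of a new sequence.
features = torch.arange(5.0).unsqueeze(-1)   # shape [5, 1]
position_ids = torch.tensor([0, 1, 2, 0, 1])

# Boundary detection shared by both patches.
boundary_mask = (position_ids[1:] - position_ids[:-1]) <= 0
boundary_indices = torch.cat(
    [
        torch.tensor([0]),
        torch.where(boundary_mask)[0] + 1,
        torch.tensor([len(features)]),
    ]
)

out = flip_vectorized(features, boundary_indices)
assert torch.equal(out, flip_loop(features, boundary_indices))
print(out.squeeze(-1))  # tensor([2., 1., 0., 4., 3.])

Both paths agree on the toy batch; the vectorized form trades the per-sequence
torch.flip/torch.cat calls for one advanced-indexing gather, which is the whole
content of PATCH 3/3.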