vllm/model_executor/models/siglip.py (83 changes: 61 additions & 22 deletions)
@@ -19,6 +19,7 @@
 )

 from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -379,6 +380,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         *,
         prefix: str = "",
+        attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
     ) -> None:
         super().__init__()

@@ -413,8 +415,11 @@ def __init__(
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

-        self.attn = MultiHeadAttention(
-            self.num_heads_per_partition, self.head_dim, self.scale
+        self.attn = attn_cls(
+            self.num_heads_per_partition,
+            self.head_dim,
+            self.scale,
+            prefix=f"{prefix}.attn",
         )

     def forward(
@@ -424,25 +429,7 @@ def forward(
         """Input shape: Batch x Time x Channel"""
         qkv_states, _ = self.qkv_proj(hidden_states)
         query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
-
-        needs_unsqueeze = query_states.ndim == 2
-        if needs_unsqueeze:
-            query_states, key_states, value_states = (
-                query_states.unsqueeze(0),
-                key_states.unsqueeze(0),
-                value_states.unsqueeze(0),
-            )
-
         out = self.attn(query_states, key_states, value_states)
-
-        if needs_unsqueeze:
-            out, query_states, key_states, value_states = (
-                out.squeeze(0),
-                query_states.squeeze(0),
-                key_states.squeeze(0),
-                value_states.squeeze(0),
-            )
-
         attn_output, _ = self.out_proj(out)

         return attn_output, None
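
Note on the deleted forward() shim: the removed lines papered over a layout mismatch by temporarily adding a batch dimension whenever the encoder received vLLM's flattened (num_tokens, hidden) activations, since a batched attention module expects 3-D input. A self-contained sketch of that adapter pattern in plain PyTorch (the function below is illustrative and uses torch's scaled_dot_product_attention, not vLLM's MultiHeadAttention):

import torch
import torch.nn.functional as F

def attn_with_optional_batch_dim(q, k, v, num_heads: int, head_dim: int):
    # The pattern the deleted code implemented: tolerate both flattened
    # (num_tokens, hidden) and batched (batch, seq, hidden) inputs.
    needs_unsqueeze = q.ndim == 2
    if needs_unsqueeze:
        q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
    b, s, _ = q.shape
    # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
    q, k, v = (t.view(b, s, num_heads, head_dim).transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v)
    out = out.transpose(1, 2).reshape(b, s, num_heads * head_dim)
    return out.squeeze(0) if needs_unsqueeze else out

q = k = v = torch.randn(10, 64)  # 10 flattened tokens, hidden size 64
print(attn_with_optional_batch_dim(q, k, v, num_heads=8, head_dim=8).shape)
# torch.Size([10, 64])

With attn_cls injected, the attention layer itself (presumably EncoderOnlyAttention on the text path) owns the input-layout handling, so the per-call unsqueeze/squeeze dance can go.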
@@ -495,6 +482,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         *,
         prefix: str = "",
+        attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
     ) -> None:
         super().__init__()

@@ -504,6 +492,7 @@
             config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
+            attn_cls=attn_cls,
         )
         self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(
@@ -539,6 +528,7 @@ def __init__(
         num_hidden_layers_override: int | None = None,
         *,
         prefix: str = "",
+        attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
     ) -> None:
         super().__init__()

@@ -555,6 +545,7 @@
                     config,
                     quant_config=quant_config,
                     prefix=f"{prefix}.layers.{layer_idx}",
+                    attn_cls=attn_cls,
                 )
                 for layer_idx in range(num_hidden_layers)
             ]
@@ -598,6 +589,7 @@ def __init__(
             config=config,
             quant_config=quant_config,
             prefix=f"{prefix}.encoder",
+            attn_cls=EncoderOnlyAttention,
         )

         self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
@@ -709,6 +701,7 @@ def __init__(
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
             prefix=f"{prefix}.encoder",
+            attn_cls=MultiHeadAttention,
         )

         num_hidden_layers = config.num_hidden_layers
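
These two hunks are where the new attn_cls parameter pays off: the text tower selects EncoderOnlyAttention while the vision tower keeps MultiHeadAttention, and every encoder layer below simply instantiates whatever class it was handed. A minimal sketch of the injection pattern (FakeEncoderOnlyAttention, FakeMultiHeadAttention, and Layer are stand-ins, not vLLM's real classes or signatures):

class FakeEncoderOnlyAttention:
    def __init__(self, num_heads, head_dim, scale, prefix=""):
        self.kind = f"encoder-only attention at {prefix!r}"

class FakeMultiHeadAttention:
    def __init__(self, num_heads, head_dim, scale, prefix=""):
        self.kind = f"plain multi-head attention at {prefix!r}"

class Layer:
    # Mirrors SiglipAttention: it never names a concrete attention class;
    # it constructs whichever one was injected from above.
    def __init__(self, prefix, attn_cls):
        self.attn = attn_cls(8, 64, 0.125, prefix=f"{prefix}.attn")

print(Layer("text_model.encoder.layers.0", FakeEncoderOnlyAttention).attn.kind)
print(Layer("vision_model.encoder.layers.0", FakeMultiHeadAttention).attn.kind)

Passing a class rather than a boolean flag keeps SiglipAttention agnostic to which implementations exist, at the cost of requiring both classes to share a constructor shape (hence the prefix kwarg being passed to both).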
@@ -1034,10 +1027,56 @@ def get_text_features(
             inputs_embeds=inputs_embeds,
         )
         text_features = self.text_model.head(last_hidden_state)
-        # Flip to extract CLS token (first token after reversal) for pooling
-        text_features = text_features.flip(0)
+
+        # SigLIP uses reversed position_ids;
+        # flip sequences to move EOS token to first position
+        text_features = self._flip_sequences_by_position_ids(
+            text_features, position_ids
+        )

         return text_features

+    def _flip_sequences_by_position_ids(
+        self,
+        features: torch.Tensor,
+        position_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        """Flip sequences so EOS token moves to first position for CLS pooling.
+
+        SigLIP position_ids are reversed within each sequence. This method detects
+        sequence boundaries and flips each sequence individually.
+        """
+        if len(features) == 1:
+            return features
+
+        # Detect sequence boundaries where position_ids decrease
+        position_diffs = position_ids[1:] - position_ids[:-1]
+        boundary_mask = position_diffs <= 0
+
+        boundary_indices = torch.cat(
+            [
+                torch.tensor([0], device=features.device),
+                torch.where(boundary_mask)[0] + 1,
+                torch.tensor([len(features)], device=features.device),
+            ]
+        )
+
+        # For each sequence [start, end), position i flips to: start + end - 1 - i
+        lengths = boundary_indices[1:] - boundary_indices[:-1]
+        starts = boundary_indices[:-1]
+        ends = boundary_indices[1:]
+
+        # Assign sequence ID to each element
+        sequence_ids = torch.arange(
+            len(lengths), device=features.device
+        ).repeat_interleave(lengths)
+
+        # Calculate flipped indices for all positions at once
+        current_positions = torch.arange(len(features), device=features.device)
+        flip_indices = starts[sequence_ids] + ends[sequence_ids] - 1 - current_positions
+
+        return features[flip_indices]
+
     def get_image_features(
         self,
         pixel_values: torch.Tensor,
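
A toy check of the flip-index math in _flip_sequences_by_position_ids. It assumes position_ids arrive as ascending per-sequence indices that reset at each packed-sequence boundary, which is the convention the diffs <= 0 boundary test implies (the tensor values below are made up for illustration):

import torch

features = torch.arange(7)                          # stand-in for 7 token features
position_ids = torch.tensor([0, 1, 2, 0, 1, 2, 3])  # two packed sequences: len 3, len 4

diffs = position_ids[1:] - position_ids[:-1]
boundaries = torch.cat([
    torch.tensor([0]),
    torch.where(diffs <= 0)[0] + 1,   # position_ids reset here -> new sequence
    torch.tensor([len(features)]),
])                                    # tensor([0, 3, 7])

starts, ends = boundaries[:-1], boundaries[1:]
lengths = ends - starts
seq_ids = torch.arange(len(lengths)).repeat_interleave(lengths)
idx = torch.arange(len(features))
flip_idx = starts[seq_ids] + ends[seq_ids] - 1 - idx

print(features[flip_idx])  # tensor([2, 1, 0, 6, 5, 4, 3]): each sequence reversed

Each sequence is reversed in place, so the token the model emitted last (the EOS whose hidden state the diff's comment names as the pooling target) lands at its sequence's first slot, which is what replaced the old global flip(0) that was only correct for a single sequence.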