
[bugfix] fix siglip batch text output error#28365

Merged
DarkLight1337 merged 4 commits into vllm-project:main from piood:fix-siglip-text-batch-output
Nov 10, 2025

Conversation

@piood
Contributor

@piood piood commented Nov 9, 2025

Purpose

Fix SigLIP text embedding batch processing error when batch size > 1, as reported in #27566.

Root Cause: The text encoder flattened all batch sequences into a single sequence and ran MultiHeadAttention over it, which mishandles batched inputs (attention can leak across sequence boundaries). In addition, the previous implementation applied flip(0) to the entire flattened tensor, reversing across sequence boundaries and mixing features from different sequences.
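The mixing caused by reversing the whole flattened tensor can be seen with a minimal, self-contained illustration. This uses plain Python lists standing in for the flattened token tensor; it is not vLLM's actual code.

```python
# Two token sequences flattened into one buffer, as the old code path did.
seq_a = ["a1", "a2", "a3"]
seq_b = ["b1", "b2"]
flattened = seq_a + seq_b

# Reversing the whole flattened buffer (analogous to flip(0) on the batch
# tensor) moves tokens across the sequence boundary:
wrong = flattened[::-1]  # ['b2', 'b1', 'a3', 'a2', 'a1'] -- b tokens now lead

# Reversing each sequence within its own boundaries keeps sequences separate:
lengths = [len(seq_a), len(seq_b)]
right, start = [], 0
for n in lengths:
    right.extend(reversed(flattened[start:start + n]))
    start += n
# right == ['a3', 'a2', 'a1', 'b2', 'b1']
```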

Solution:

  • Switch text encoder to use EncoderOnlyAttention for proper batch handling
  • Implement _flip_sequences_by_position_ids() to flip each sequence individually based on position_ids boundaries
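A minimal sketch of how such a per-sequence flip could work. This is hypothetical, not the exact code from the PR, and assumes position_ids restart at 0 at the start of each concatenated sequence (e.g. [0, 1, 2, 0, 1] for lengths 3 and 2).

```python
import torch


def flip_sequences_by_position_ids(
    hidden_states: torch.Tensor, position_ids: torch.Tensor
) -> torch.Tensor:
    """Flip each sequence in a flattened batch independently."""
    # Sequence boundaries: every position where position_ids resets to 0.
    starts = (position_ids == 0).nonzero(as_tuple=True)[0].tolist()
    starts.append(position_ids.numel())

    out = hidden_states.clone()
    for s, e in zip(starts[:-1], starts[1:]):
        out[s:e] = hidden_states[s:e].flip(0)  # flip within one sequence only
    return out
```

Flipping segment by segment keeps every token inside its own sequence, unlike a single flip(0) over the whole tensor.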

Test Plan

vllm serve google/siglip2-base-patch16-224 --runner pooling --enforce-eager

import os
from openai import OpenAI

openai_api_key = os.environ.get("OPENAI_API_KEY", "EMPTY")
openai_api_base = os.environ.get("OPENAI_API_BASE", "http://localhost:8000/v1")

def batch_text_request():
    """Send batch text embedding request to SigLIP model."""
    client = OpenAI(
        api_key=openai_api_key,
        base_url=openai_api_base,
    )
    
    # Get the served model ID
    model_id = client.models.list().data[0].id
    
    # Create test prompts
    words = ['hi', "how old are you?"]
    prompts = []
    for i in range(10):
        content = words[i % 2]
        prompts.append("a photo of a cat " + content)

    response = client.embeddings.create(
        model=model_id,
        input=prompts, 
        encoding_format="float"
    )

    results = []
    for idx, embedding_obj in enumerate(response.data):
        prompt = prompts[idx]
        embedding_preview = embedding_obj.embedding[:5]
        print(f"Request {idx+1} ('{prompt}'): {embedding_preview}")
        results.append((idx, prompt, embedding_preview))
    
    return results

if __name__ == "__main__":
    batch_text_request()

Test Result

Before fix

Request 1 ('a photo of a cat hi'): [-0.013481298461556435, 0.032701071351766586, -0.003023682162165642, -0.022816617041826248, 0.27149301767349243]
Request 2 ('a photo of a cat how old are you?'): [-0.021391602233052254, 0.030925685539841652, -0.005404650699347258, -0.012358256615698338, 0.2850077152252197]
Request 3 ('a photo of a cat hi'): [-0.015912294387817383, 0.03907473012804985, -0.009924817830324173, -0.009802624583244324, 0.2923952043056488]
Request 4 ('a photo of a cat how old are you?'): [-0.007564036641269922, 0.014619219116866589, -0.015609420835971832, -0.02005157247185707, 0.2704349458217621]
Request 5 ('a photo of a cat hi'): [-0.015912294387817383, 0.03907473012804985, -0.009924817830324173, -0.009802624583244324, 0.2923952043056488]
Request 6 ('a photo of a cat how old are you?'): [-0.007564036641269922, 0.014619219116866589, -0.015609420835971832, -0.02005157247185707, 0.2704349458217621]
Request 7 ('a photo of a cat hi'): [-0.015912294387817383, 0.03907473012804985, -0.009924817830324173, -0.009802624583244324, 0.2923952043056488]
Request 8 ('a photo of a cat how old are you?'): [-0.04892544820904732, 0.03389119729399681, -0.006240752525627613, 0.019908834248781204, 0.3164888918399811]
Request 9 ('a photo of a cat hi'): [-0.007564036641269922, 0.014619219116866589, -0.015609420835971832, -0.02005157247185707, 0.2704349458217621]
Request 10 ('a photo of a cat how old are you?'): [-0.04892544820904732, 0.03389119729399681, -0.006240752525627613, 0.019908834248781204, 0.3164888918399811]

After fix

Request 1 ('a photo of a cat hi'): [-0.013481298461556435, 0.032701071351766586, -0.003023682162165642, -0.022816617041826248, 0.27149301767349243]
Request 2 ('a photo of a cat how old are you?'): [-0.04890233278274536, 0.03391861915588379, -0.006237870082259178, 0.01992269791662693, 0.3165053427219391]
Request 3 ('a photo of a cat hi'): [-0.013506860472261906, 0.032723937183618546, -0.003055858425796032, -0.022827142849564552, 0.27145496010780334]
Request 4 ('a photo of a cat how old are you?'): [-0.04890233278274536, 0.03391861915588379, -0.006237870082259178, 0.01992269791662693, 0.3165053427219391]
Request 5 ('a photo of a cat hi'): [-0.013506860472261906, 0.032723937183618546, -0.003055858425796032, -0.022827142849564552, 0.27145496010780334]
Request 6 ('a photo of a cat how old are you?'): [-0.04890233278274536, 0.03391861915588379, -0.006237870082259178, 0.01992269791662693, 0.3165053427219391]
Request 7 ('a photo of a cat hi'): [-0.013506860472261906, 0.032723937183618546, -0.003055858425796032, -0.022827142849564552, 0.27145496010780334]
Request 8 ('a photo of a cat how old are you?'): [-0.04890233278274536, 0.03391861915588379, -0.006237870082259178, 0.01992269791662693, 0.3165053427219391]
Request 9 ('a photo of a cat hi'): [-0.013506860472261906, 0.032723937183618546, -0.003055858425796032, -0.022827142849564552, 0.27145496010780334]
Request 10 ('a photo of a cat how old are you?'): [-0.04890233278274536, 0.03391861915588379, -0.006237870082259178, 0.01992269791662693, 0.3165053427219391]


Signed-off-by: piood <2477084691@qq.com>
@piood piood mentioned this pull request Nov 9, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a bug in SigLIP's batch text embedding. The fix involves two main changes: switching the text encoder to use EncoderOnlyAttention for correct batch handling, and implementing a new method _flip_sequences_by_position_ids to correctly flip individual sequences within a batch. The changes are logical and directly address the root causes described. I have one suggestion to improve the performance of the new flipping method by vectorizing its implementation.
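For reference, a vectorized per-sequence flip avoids the Python loop by computing, for every token, the index of its mirror position within its own sequence and gathering in one shot. This is a hypothetical sketch under the same assumption that position_ids restart at 0 at each sequence boundary; it is not the reviewer's actual suggestion.

```python
import torch


def flip_sequences_vectorized(
    hidden_states: torch.Tensor, position_ids: torch.Tensor
) -> torch.Tensor:
    """Loop-free per-sequence flip of a flattened batch."""
    n = position_ids.numel()
    idx = torch.arange(n, device=position_ids.device)
    start_mask = position_ids == 0

    # Start index of the sequence each token belongs to (running maximum
    # of the boundary positions seen so far).
    starts = torch.where(start_mask, idx, torch.zeros_like(idx)).cummax(0).values

    # Length of each sequence, broadcast back to every token in it.
    seq_id = start_mask.cumsum(0) - 1
    lengths = torch.bincount(seq_id)[seq_id]

    # Mirror index of each token within its own sequence.
    mirror = starts + (lengths - 1) - (idx - starts)
    return hidden_states[mirror]
```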


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


@piood
Contributor Author

piood commented Nov 10, 2025

@DarkLight1337 Can you review it?

@DarkLight1337
Member

Does the existing test pass locally?

@piood
Contributor Author

piood commented Nov 10, 2025

Does the existing test pass locally?

Yes, all passed.

piood and others added 3 commits November 10, 2025 06:49
Signed-off-by: piood <2477084691@qq.com>
Signed-off-by: piood <2477084691@qq.com>
@DarkLight1337 DarkLight1337 added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 10, 2025
@piood
Contributor Author

piood commented Nov 10, 2025

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request aims to fix a bug in SigLIP's batch text embedding process. The solution involves switching to EncoderOnlyAttention for proper batch handling and implementing a new method, _flip_sequences_by_position_ids, to correctly flip individual sequences. While the overall approach is sound, I've identified a potential critical issue in the boundary detection logic of the new flipping method which appears inconsistent with how SigLIP's reversed position IDs work.

@piood
Contributor Author

piood commented Nov 10, 2025

@DarkLight1337 All checks have passed, please review again, thanks.

Member

@DarkLight1337 DarkLight1337 left a comment


Thanks for fixing!

@DarkLight1337 DarkLight1337 merged commit 15be507 into vllm-project:main Nov 10, 2025
55 checks passed
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
@piood piood deleted the fix-siglip-text-batch-output branch December 3, 2025 03:23
