
[Bagel]: Support multistage img2img #1669

Merged
hsliuustc0106 merged 22 commits into vllm-project:main from princepride:newest-support-multistage-i2i-bagel
Mar 9, 2026

Collaborator

@princepride princepride commented Mar 5, 2026

Purpose

Fully support Bagel multi-stage deployment.

  • Patch ModelConfig.is_mm_prefix_lm to include Bagel, enabling bidirectional attention for multimodal prefix positions. Delegates to the original vLLM implementation so upstream list updates are automatically inherited.
  • Introduce OmniBagelForConditionalGeneration, wrapping vLLM's Bagel with a built-in VAE encoder for img2img. Supports dual-mode multimodal processing: image modality uses ViT, img2img modality uses VAE. Tracks per-request RoPE and image shape via get_kv_transfer_metadata() for KV transfer to the diffusion stage.
  • Add img2img CFG expansion producing 3 companions (user + cfg_text + cfg_img). Update max_batch_size to 3 in bagel.yaml.
  • Support custom_metadata in KV cache transfer, allowing model-specific data (RoPE, image shape) to flow from AR to diffusion stage. Propagate received metadata to req.sampling_params.kv_metadata.
  • Query model's get_kv_transfer_metadata() to attach custom metadata before KV transfer. Clear model warmup state on first real inference call.
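
The delegating patch described in the first bullet could look roughly like the sketch below. It is only an illustration: it uses a stand-in `ModelConfig` class instead of importing vLLM, and the `model_type` attribute, the stand-in's built-in model list, and the assumption that `is_mm_prefix_lm` is a plain `property` are all hypothetical rather than taken from this PR.

```python
class ModelConfig:
    """Stand-in for vLLM's ModelConfig (hypothetical, heavily simplified)."""

    def __init__(self, model_type: str) -> None:
        self.model_type = model_type

    @property
    def is_mm_prefix_lm(self) -> bool:
        # Placeholder for the upstream list of prefix-LM multimodal models.
        return self.model_type in ("gemma3", "paligemma")


# Keep a handle on the original property so upstream list changes still apply.
_original_is_mm_prefix_lm = ModelConfig.is_mm_prefix_lm


def _patched_is_mm_prefix_lm(self: ModelConfig) -> bool:
    # Bagel is added on top; every other model defers to the original logic.
    if self.model_type == "bagel":
        return True
    return _original_is_mm_prefix_lm.fget(self)


ModelConfig.is_mm_prefix_lm = property(_patched_is_mm_prefix_lm)
```

Because the patched getter calls the saved original instead of copying its list, any model added upstream keeps working without touching the patch.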

Test Plan

img2img(stage0, stage1)

FLASHINFER_DISABLE_VERSION_CHECK=1 python3 examples/offline_inference/bagel/end2end.py   --modality img2img   --image-path women.jpg   --prompts "Let the woman wear a blue dress"
[image: bagel_i2i_output]

text2img(stage0, stage1)

FLASHINFER_DISABLE_VERSION_CHECK=1 python3 examples/offline_inference/bagel/end2end.py   --modality text2img   --prompts "A cute cat"
[image]

text2text(stage0)

FLASHINFER_DISABLE_VERSION_CHECK=1 python3 examples/offline_inference/bagel/end2end.py   --modality text2text   --prompts "Where is the capital of France?"
'The capital of France is Paris.'

img2text(stage0)

FLASHINFER_DISABLE_VERSION_CHECK=1 python3 examples/offline_inference/bagel/end2end.py   --modality img2text   --image-path text2img.png   --prompts "Please describe this image."
'This is a digital illustration of a kitten sitting on a wooden floor. The kitten has orange and white fur with darker stripes, large expressive eyes, and pink ears. Its tail is curled around its body, and it appears to be looking directly at the viewer with a curious expression. The background is blurred, suggesting an indoor setting with soft lighting.'

Signed-off-by: princepride <wangzhipeng628@gmail.com>
Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
@princepride princepride marked this pull request as ready for review March 5, 2026 10:00
@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 29692a7c51

Comment thread vllm_omni/model_executor/models/registry.py
Comment thread vllm_omni/model_executor/models/bagel/bagel.py Outdated
Collaborator

@hsliuustc0106 hsliuustc0106 left a comment

do we have any test for this multistage example?

Comment thread vllm_omni/patch.py
@@ -0,0 +1,1009 @@
from collections import deque
Collaborator

Is this file the wrapper for the AR part of the Bagel model?

Collaborator Author

We can't wrap vLLM's implementation anymore, whether through direct calls or by inheriting its modules. First, vLLM's img2text task only requires the ViT, whereas our img2img task needs both the ViT and the VAE, so we have to modify both the mm_processor and the vision encoder. During implementation, I also discovered another issue: the vision latents for the VAE part actually use DiT model weights rather than AR weights, so we must load the full weights into the AR stage.

Collaborator

@hsliuustc0106 hsliuustc0106 left a comment

Review

Rating: 7.5/10 | Verdict: ⚠️ Changes Requested

Summary

Comprehensive implementation of Bagel multistage img2img support with AR wrapper, VAE encoder, MoT routing, and KV transfer metadata. The feature is well-designed but has a critical issue that needs to be addressed before merge.


🔴 Critical Issue: FIFO Queue Metadata Mismatch

Location: vllm_omni/model_executor/models/bagel/bagel.py:513

def get_kv_transfer_metadata(self, req_id: str) -> dict[str, Any] | None:
    if self._ropes_queue:
        return self._ropes_queue.popleft()  # Ignores req_id!
    return None

Problem: The req_id parameter is completely ignored. Metadata is returned in FIFO order regardless of which request is asking. When multiple requests are in-flight and finish out of order, a request can receive another request's ropes/image_shape, corrupting downstream diffusion conditioning.

Suggested Fix: Use a dict keyed by req_id instead of a queue:

self._ropes_metadata: dict[str, dict[str, Any]] = {}

def get_kv_transfer_metadata(self, req_id: str) -> dict[str, Any] | None:
    return self._ropes_metadata.pop(req_id, None)
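
To make the failure mode concrete, here is a minimal, self-contained sketch that contrasts the two storage strategies when requests finish out of order. The class names, the `put`/`get` methods, and the sample metadata values are hypothetical and exist only for this illustration; they are not the model's actual API.

```python
from collections import deque
from typing import Any, Optional


class FifoMetadataStore:
    """Mimics the queue-based behavior: req_id is ignored on lookup."""

    def __init__(self) -> None:
        self._queue: deque = deque()

    def put(self, req_id: str, metadata: dict) -> None:
        self._queue.append(metadata)

    def get(self, req_id: str) -> Optional[dict]:
        return self._queue.popleft() if self._queue else None


class KeyedMetadataStore:
    """Mimics the suggested fix: metadata is stored and popped per req_id."""

    def __init__(self) -> None:
        self._by_req: dict = {}

    def put(self, req_id: str, metadata: dict) -> None:
        self._by_req[req_id] = metadata

    def get(self, req_id: str) -> Optional[dict]:
        return self._by_req.pop(req_id, None)


def finish_out_of_order(store: Any) -> Optional[dict]:
    # Two requests are registered in order a, b ...
    store.put("req-a", {"image_shape": (512, 512)})
    store.put("req-b", {"image_shape": (1024, 768)})
    # ... but req-b finishes first and asks for its metadata.
    return store.get("req-b")
```

With the FIFO store, `finish_out_of_order` hands `req-b` the shape that belongs to `req-a`; with the keyed store it returns `req-b`'s own entry.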

🟡 Questions

1. Independent Batching Support

Does this PR support independent (decoupled) batching for multistage inference? Specifically:

  • Can AR and diffusion stages have different max_batch_size?
  • Are requests queued and processed independently at each stage?

This is important for production scenarios where AR is lightweight but diffusion is memory-intensive. True independent batching would allow e.g., AR with batch_size=8 and diffusion with batch_size=2.

2. Tests for Multistage Example

Do we have any automated tests for this multistage img2img flow? I noticed the PR description has manual test commands, but I don't see corresponding test files. Adding tests for:

  • Multi-request scenarios with the metadata flow
  • E2E img2img + text2img in multistage config
  • Regression tests for single-stage mode

would help ensure stability.

3. Patch Alternatives (re: my inline comment)

The patch.py modification to ModelConfig.is_mm_prefix_lm works, but are there plans to:

  • Contribute this upstream to vLLM so Bagel is in the built-in list?
  • Use a HF config flag instead of patching?

Not blocking, but worth considering for maintainability.


🟢 Strengths

  1. Complete Feature - Supports img2img, text2img, text2text, img2text
  2. Clean Architecture - AR wrapper + VAE encoder + MoT routing
  3. Good Documentation - PR has test screenshots and commands

Action Items:

  1. Fix FIFO queue → dict keyed by req_id
  2. Add tests for multistage flow
  3. Clarify independent batching support

Thanks for the implementation! Once the metadata issue is fixed, this will be a great addition. 🦐

Collaborator Author

@natureofnature PTAL

Collaborator Author

princepride commented Mar 5, 2026

> do we have any test for this multistage example?

I can add an e2e test for img2img, but it requires an image as input. Previously, in vLLM, Simon helped me upload test images to an S3 bucket for CI/CD purposes, but I’m not sure where we are storing our test images in vLLM-Omni now. @congw729

Comment thread vllm_omni/model_executor/models/bagel/bagel.py
Collaborator

congw729 commented Mar 6, 2026

> > do we have any test for this multistage example?
>
> I can add an e2e test for img2img, but it requires an image as input. Previously, in vLLM, Simon helped me upload test images to an S3 bucket for CI/CD purposes, but I’m not sure where we are storing our test images in vLLM-Omni now. @congw729

We haven't tried uploading test input images to the S3 bucket. Can you show me the PR/link you just mentioned? I will try to meet your needs.

Collaborator Author

> > > do we have any test for this multistage example?
> >
> > I can add an e2e test for img2img, but it requires an image as input. Previously, in vLLM, Simon helped me upload test images to an S3 bucket for CI/CD purposes, but I’m not sure where we are storing our test images in vLLM-Omni now. @congw729
>
> We haven't tried uploading test input images to the S3 bucket. Can you show me the PR/link you just mentioned? I will try to meet your needs.

See vllm-project/vllm#23229 (comment). Last time, Simon Mo helped me upload images to the S3 bucket.

@hsliuustc0106 hsliuustc0106 added the ready label to trigger buildkite CI label Mar 9, 2026
logger.warning(f"Request {req_id} has no block IDs, skipping")
continue

custom_metadata = data.get("custom_metadata")
Collaborator

Why do we need to introduce custom_metadata? Is there any doc that explains this argument? Otherwise, this may lead to a poor user/dev experience.

Collaborator Author

Because we need to transfer the RoPE and image shape from the AR stage to the DiT stage.
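
As a rough illustration of that flow, the sketch below attaches optional model metadata on the sending side and propagates it on the receiving side. Only `get_kv_transfer_metadata`, the `custom_metadata` key, and `sampling_params.kv_metadata` come from this PR; the function names, the `_DemoModel` class, the other payload fields, and the sample values are assumptions made up for the example.

```python
from types import SimpleNamespace
from typing import Any, Optional


class _DemoModel:
    """Hypothetical model exposing per-request metadata for KV transfer."""

    def __init__(self) -> None:
        self._metadata = {"req-1": {"ropes": [0, 1, 2], "image_shape": (512, 512)}}

    def get_kv_transfer_metadata(self, req_id: str) -> Optional[dict]:
        return self._metadata.pop(req_id, None)


def build_kv_payload(req_id: str, block_ids: list, model: Any) -> dict:
    """Sender side (AR stage): attach custom metadata only if the model has any."""
    payload: dict = {"request_id": req_id, "block_ids": block_ids}
    getter = getattr(model, "get_kv_transfer_metadata", None)
    if callable(getter):
        metadata = getter(req_id)
        if metadata is not None:
            payload["custom_metadata"] = metadata
    return payload


def apply_received_payload(req: Any, data: dict) -> None:
    """Receiver side (diffusion stage): expose the metadata to the request."""
    custom_metadata = data.get("custom_metadata")
    if custom_metadata is not None:
        req.sampling_params.kv_metadata = custom_metadata


# Usage: one request travels from the AR stage to the diffusion stage.
payload = build_kv_payload("req-1", [3, 4], _DemoModel())
request = SimpleNamespace(sampling_params=SimpleNamespace(kv_metadata=None))
apply_received_payload(request, payload)
```

Keeping the field optional means models without a `get_kv_transfer_metadata` method produce the same payload as before.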

"""
if isinstance(layer_kv, torch.Tensor):
if layer_kv.ndim < 3 or layer_kv.shape[0] != 2:
if layer_kv.ndim >= 3 and layer_kv.shape[0] == 2:
Collaborator

This looks like a KV layer issue. Can we make it a util function?

Collaborator Author

Sure, I will do it. In vLLM, different attention backends return paged attention blocks with different shapes. For example, in vllm/v1/attention/backends/flash_attn.py:

@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
    if block_size % 16 != 0:
        raise ValueError("Block size must be a multiple of 16.")
    return (2, num_blocks, block_size, num_kv_heads, head_size)

However, in vllm/v1/attention/backends/flashinfer.py:

@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
    return (num_blocks, 2, block_size, num_kv_heads, head_size)
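
A util function along the lines the reviewer asked for might start from a shape check like the sketch below. It works on plain shape tuples so it runs without torch; the function names are hypothetical, and this is not the helper that was actually added in the PR. Note that the check is ambiguous when `num_blocks` happens to be 2, in which case this sketch (like the `shape[0] == 2` test in the diff) prefers axis 0.

```python
def find_kv_axis(shape: tuple) -> int:
    """Return the axis (0 or 1) that holds the K/V pair in a paged KV cache shape."""
    if len(shape) < 3:
        raise ValueError(f"Expected at least 3 dimensions, got shape {shape}")
    if shape[0] == 2:
        # Layout (2, num_blocks, block_size, num_kv_heads, head_size).
        return 0
    if shape[1] == 2:
        # Layout (num_blocks, 2, block_size, num_kv_heads, head_size).
        return 1
    raise ValueError(f"No K/V axis of size 2 found in shape {shape}")


def single_cache_shape(shape: tuple) -> tuple:
    """Return the shape of the K (or V) cache alone, with the K/V axis removed."""
    axis = find_kv_axis(shape)
    return shape[:axis] + shape[axis + 1:]
```

Both layouts quoted above reduce to the same per-cache shape, so downstream code can stay backend-agnostic once the K/V axis is known.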

Collaborator Author

Done

Collaborator

@hsliuustc0106 hsliuustc0106 left a comment

lgtm

@hsliuustc0106 hsliuustc0106 merged commit 7271116 into vllm-project:main Mar 9, 2026
7 checks passed
lishunyang12 pushed a commit to lishunyang12/vllm-omni that referenced this pull request Mar 11, 2026
Signed-off-by: lishunyang <lishunyang12@163.com>
@princepride princepride mentioned this pull request Mar 13, 2026
14 tasks
clodaghwalsh17 pushed a commit to clodaghwalsh17/nm-vllm-omni-ent that referenced this pull request May 12, 2026
4 participants