From 409938553556f851ea47155833d23ef12dd80083 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Fri, 13 Mar 2026 15:00:21 +0000 Subject: [PATCH 1/2] add fix for bagel, update example Signed-off-by: Chendi Xue --- examples/offline_inference/bagel/end2end.py | 16 +++- .../model_executor/models/bagel/bagel.py | 5 +- .../platforms/xpu/stage_configs/bagel.yaml | 86 +++++++++++++++++++ 3 files changed, 103 insertions(+), 4 deletions(-) create mode 100644 vllm_omni/platforms/xpu/stage_configs/bagel.yaml diff --git a/examples/offline_inference/bagel/end2end.py b/examples/offline_inference/bagel/end2end.py index cffe44fc36a..5fe8309d77e 100644 --- a/examples/offline_inference/bagel/end2end.py +++ b/examples/offline_inference/bagel/end2end.py @@ -34,6 +34,13 @@ def parse_args(): help="Path to input image for img2img.", ) + parser.add_argument( + "--output", + type=str, + default=".", + help="Output directory to save images.", + ) + # OmniLLM init args parser.add_argument("--log-stats", action="store_true", default=False) parser.add_argument("--init-sleep-seconds", type=int, default=20) @@ -65,6 +72,7 @@ def parse_args(): def main(): args = parse_args() + os.makedirs(args.output, exist_ok=True) model_name = args.model prompts: list[OmniPromptType] = [] try: @@ -173,13 +181,17 @@ def main(): if images: for j, img in enumerate(images): - img.save(f"output_{i}_{j}.png") + save_path = os.path.join(args.output, f"output_{i}_{j}.png") + img.save(save_path) if hasattr(req_output, "request_output") and req_output.request_output: for stage_out in req_output.request_output: if hasattr(stage_out, "images") and stage_out.images: for k, img in enumerate(stage_out.images): - save_path = f"output_{i}_stage_{getattr(stage_out, 'stage_id', '?')}_{k}.png" + save_path = os.path.join( + args.output, + f"output_{i}_stage_{getattr(stage_out, 'stage_id', '?')}_{k}.png", + ) img.save(save_path) print(f"[Info] Saved stage output image to {save_path}") diff --git a/vllm_omni/model_executor/models/bagel/bagel.py b/vllm_omni/model_executor/models/bagel/bagel.py index a053e98dd06..e58b3501c44 100644 --- a/vllm_omni/model_executor/models/bagel/bagel.py +++ b/vllm_omni/model_executor/models/bagel/bagel.py @@ -41,6 +41,7 @@ ) from vllm.transformers_utils.processors.bagel import BagelProcessor +from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.models.bagel.autoencoder import ( AutoEncoderParams, DiagonalGaussian, @@ -439,7 +440,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._end_of_image_id = int(_tok.convert_tokens_to_ids("<|vision_end|>")) self._vae_token_mask: torch.Tensor | None = None - + self.device = get_local_device() self._install_mot_modules(config) def _install_mot_modules(self, config): @@ -627,7 +628,7 @@ def _process_img2img_input(self, multimodal_input): ) pos_embed = self.latent_pos_embed([vae_position_ids]) packed_timesteps = torch.tensor([timestep], device=padded_latent.device) - with torch.amp.autocast("cuda", dtype=torch.bfloat16): + with torch.amp.autocast(self.device.type, dtype=torch.bfloat16): timestep_embeds = self.time_embedder(packed_timesteps.to(padded_latent)) vae_embeds = self.vae2llm(latent) + timestep_embeds + pos_embed diff --git a/vllm_omni/platforms/xpu/stage_configs/bagel.yaml b/vllm_omni/platforms/xpu/stage_configs/bagel.yaml new file mode 100644 index 00000000000..b1443d884a6 --- /dev/null +++ b/vllm_omni/platforms/xpu/stage_configs/bagel.yaml @@ -0,0 +1,86 @@ +# stage config for running bagel-7b-mot with architecture of OmniLLM. + +stage_args: + - stage_id: 0 + stage_type: llm + prompt_expand_func: vllm_omni.model_executor.stage_input_processors.bagel.expand_cfg_prompts + runtime: + devices: "0" + # 3 = 1 user prompt + 2 CFG companions (text-unconditional + image-unconditional). + max_batch_size: 3 + engine_args: + model_stage: thinker + model_arch: OmniBagelForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.9 + enforce_eager: true + trust_remote_code: true + engine_output_type: text + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 16384 + tensor_parallel_size: 1 + quantization: fp8 + omni_kv_config: + need_send_cache: true + kv_transfer_criteria: + type: prefill_finished #or special token generated + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.4 + top_p: 0.9 + top_k: 1 + max_tokens: 2048 + seed: 52 + detokenize: True + repetition_penalty: 1.05 + + - stage_id: 1 + stage_type: diffusion + cfg_kv_collect_func: vllm_omni.model_executor.stage_input_processors.bagel.collect_cfg_kv_caches + runtime: + devices: "1" + max_batch_size: 1 + engine_args: + model_stage: dit + gpu_memory_utilization: 0.9 + enforce_eager: true + trust_remote_code: true + engine_output_type: image + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 1 + omni_kv_config: + need_recv_cache: true + engine_input_source: [0] + + final_output: true + final_output_type: image + is_comprehension: false + default_sampling_params: + seed: 52 + +# Runtime edges +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 + + # Distributed connectors configuration (optional) + # More connectors will be supported in the future. + connectors: + shared_memory_connector: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 # 64KB threshold + + + edges: + - from: 0 + to: 1 + window_size: -1 From a4884a868413c92f0c2b68387333734d97741b9f Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Wed, 18 Mar 2026 21:54:21 +0000 Subject: [PATCH 2/2] fix rebase Signed-off-by: Chendi Xue --- examples/offline_inference/bagel/end2end.py | 29 ++++++--------------- 1 file changed, 8 insertions(+), 21 deletions(-) diff --git a/examples/offline_inference/bagel/end2end.py b/examples/offline_inference/bagel/end2end.py index 89e7e2c14d3..d47004f2132 100644 --- a/examples/offline_inference/bagel/end2end.py +++ b/examples/offline_inference/bagel/end2end.py @@ -174,27 +174,14 @@ def main(): img_idx = 0 for req_output in omni_outputs: images = getattr(req_output, "images", None) - if not images and hasattr(req_output, "output"): - if isinstance(req_output.output, list): - images = req_output.output - else: - images = [req_output.output] - - if images: - for j, img in enumerate(images): - save_path = os.path.join(args.output, f"output_{i}_{j}.png") - img.save(save_path) - - if hasattr(req_output, "request_output") and req_output.request_output: - for stage_out in req_output.request_output: - if hasattr(stage_out, "images") and stage_out.images: - for k, img in enumerate(stage_out.images): - save_path = os.path.join( - args.output, - f"output_{i}_stage_{getattr(stage_out, 'stage_id', '?')}_{k}.png", - ) - img.save(save_path) - print(f"[Info] Saved stage output image to {save_path}") + + if not images: + continue + + for j, img in enumerate(images): + save_path = os.path.join(args.output, f"output_{img_idx}_{j}.png") + img.save(save_path) + img_idx += 1 print(omni_outputs)