diff --git a/examples/offline_inference/bagel/end2end.py b/examples/offline_inference/bagel/end2end.py
index efcdea2355d..2153a31ba70 100644
--- a/examples/offline_inference/bagel/end2end.py
+++ b/examples/offline_inference/bagel/end2end.py
@@ -2,6 +2,7 @@
import os
from vllm_omni.inputs.data import OmniPromptType
+from vllm_omni.model_executor.stage_input_processors.bagel import GEN_THINK_SYSTEM_PROMPT
def parse_args():
@@ -65,6 +66,17 @@ def parse_args():
help="CFG parallel size: 1=batched (single GPU), 2=parallel with 2 branches (text CFG only), 3=parallel (3 GPUs).",
)
parser.add_argument("--seed", type=int, default=None, help="Random seed for generation.")
+ parser.add_argument(
+ "--cfg-interval",
+ type=float,
+ nargs=2,
+ default=None,
+ help="CFG interval [start, end] (default: pipeline default)",
+ )
+ parser.add_argument(
+ "--cfg-renorm-type", type=str, default=None, help="CFG renorm type: global, text_channel, channel"
+ )
+ parser.add_argument("--cfg-renorm-min", type=float, default=None, help="CFG renorm min")
parser.add_argument(
"--enable-diffusion-pipeline-profiler",
action="store_true",
@@ -76,6 +88,12 @@ def parse_args():
default=None,
help="Quantization method (e.g. 'fp8').",
)
+ parser.add_argument(
+ "--think",
+ action="store_true",
+ default=False,
+ help="Enable thinking mode: AR stage decodes ... planning tokens before image generation.",
+ )
args = parser.parse_args()
return args
@@ -110,8 +128,12 @@ def main():
from vllm_omni.entrypoints.omni import Omni
omni_kwargs = {}
- if args.stage_configs_path:
- omni_kwargs["stage_configs_path"] = args.stage_configs_path
+ stage_configs_path = args.stage_configs_path
+ if args.think and stage_configs_path is None:
+ stage_configs_path = "vllm_omni/model_executor/stage_configs/bagel_think.yaml"
+ print(f"[Info] Think mode enabled, using stage config: {stage_configs_path}")
+ if stage_configs_path:
+ omni_kwargs["stage_configs_path"] = stage_configs_path
omni_kwargs.update(
{
@@ -136,7 +158,8 @@ def main():
if not args.image_path or not os.path.exists(args.image_path):
raise ValueError(f"img2img requires --image-path pointing to an existing file, got: {args.image_path}")
loaded_image = Image.open(args.image_path).convert("RGB")
- final_prompt_text = f"<|fim_middle|><|im_start|>{p}<|im_end|>"
+ think_prefix = f"<|im_start|>{GEN_THINK_SYSTEM_PROMPT}<|im_end|>" if args.think else ""
+ final_prompt_text = f"{think_prefix}<|fim_middle|><|im_start|>{p}<|im_end|>"
prompt_dict = {
"prompt": final_prompt_text,
"multi_modal_data": {"img2img": loaded_image},
@@ -160,7 +183,8 @@ def main():
prompt_dict = {"prompt": final_prompt_text, "modalities": ["text"]}
formatted_prompts.append(prompt_dict)
else:
- final_prompt_text = f"<|im_start|>{p}<|im_end|>"
+ think_prefix = f"<|im_start|>{GEN_THINK_SYSTEM_PROMPT}<|im_end|>" if args.think else ""
+ final_prompt_text = f"{think_prefix}<|im_start|>{p}<|im_end|>"
prompt_dict = {"prompt": final_prompt_text, "modalities": ["image"]}
if args.negative_prompt is not None:
prompt_dict["negative_prompt"] = args.negative_prompt
@@ -178,6 +202,12 @@ def main():
"cfg_text_scale": args.cfg_text_scale,
"cfg_img_scale": args.cfg_img_scale,
}
+ if args.cfg_interval is not None:
+ extra["cfg_interval"] = tuple(args.cfg_interval)
+ if args.cfg_renorm_type is not None:
+ extra["cfg_renorm_type"] = args.cfg_renorm_type
+ if args.cfg_renorm_min is not None:
+ extra["cfg_renorm_min"] = args.cfg_renorm_min
if args.negative_prompt is not None:
extra["negative_prompt"] = args.negative_prompt
diffusion_params.extra_args = extra # type: ignore
@@ -186,6 +216,17 @@ def main():
img_idx = 0
for req_output in omni_outputs:
+ if args.think:
+ text_output = getattr(req_output, "text", None) or getattr(req_output, "outputs", None)
+ if text_output:
+ if isinstance(text_output, list) and text_output:
+ for out in text_output:
+ txt = getattr(out, "text", str(out))
+ if txt:
+ print(f"[Think] {txt}")
+ elif isinstance(text_output, str):
+ print(f"[Think] {text_output}")
+
images = getattr(req_output, "images", None)
if not images:
@@ -194,6 +235,7 @@ def main():
for j, img in enumerate(images):
save_path = os.path.join(args.output, f"output_{img_idx}_{j}.png")
img.save(save_path)
+ print(f"[Output] Saved image to {save_path}")
img_idx += 1
print(omni_outputs)
diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py
index aa4f0a74f02..3e053cbda50 100644
--- a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py
+++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py
@@ -326,11 +326,18 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
cfg_text_scale = extra_args.get("cfg_text_scale", 4.0)
cfg_img_scale = extra_args.get("cfg_img_scale", 1.5)
+ cfg_interval = extra_args.get("cfg_interval", (0.4, 1.0))
+ cfg_renorm_type = extra_args.get("cfg_renorm_type", "global")
+ cfg_renorm_min = extra_args.get("cfg_renorm_min", 0.0)
+
gen_params = BagelGenParams(
num_timesteps=int(req.sampling_params.num_inference_steps or 50),
timestep_shift=3.0,
cfg_text_scale=cfg_text_scale,
cfg_img_scale=cfg_img_scale,
+ cfg_interval=cfg_interval,
+ cfg_renorm_type=cfg_renorm_type,
+ cfg_renorm_min=cfg_renorm_min,
)
gen_context = {
diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py
index 9de3dc867ff..b94f83bab39 100644
--- a/vllm_omni/engine/async_omni_engine.py
+++ b/vllm_omni/engine/async_omni_engine.py
@@ -722,14 +722,15 @@ def _enqueue_cfg_companions(
cid = f"{parent_id}{ep.request_id_suffix}"
companion_prompt = ep.prompt
- # Run through same input processing as the main prompt
+ companion_params, companion_spl = ep.apply_overrides(stage0_params, sampling_params_list)
+
if isinstance(companion_prompt, dict):
_inject_global_id(companion_prompt, cid)
request = self.input_processor.process_inputs(
request_id=cid,
prompt=companion_prompt,
- params=stage0_params,
+ params=companion_params,
supported_tasks=self.supported_tasks,
)
request = _upgrade_to_omni_request(request, companion_prompt)
@@ -750,7 +751,7 @@ def _enqueue_cfg_companions(
"parent_id": parent_id,
"role": ep.role,
"prompt": request,
- "sampling_params_list": sampling_params_list,
+ "sampling_params_list": companion_spl,
}
)
diff --git a/vllm_omni/model_executor/models/bagel/bagel.py b/vllm_omni/model_executor/models/bagel/bagel.py
index e58b3501c44..e79f0212e2e 100644
--- a/vllm_omni/model_executor/models/bagel/bagel.py
+++ b/vllm_omni/model_executor/models/bagel/bagel.py
@@ -429,6 +429,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self._ropes_metadata: dict[str, dict[str, Any]] = {}
self._cfg_companion_queue: deque[tuple[tuple[int, int, int, int], int]] = deque()
+ # Per-request position offset for decode after img2img prefill.
+ # Prefill rewrites positions (VAE→0, ViT→1, text→2..N) but the model
+ # runner assigns decode positions starting from prefill_len, not N+1.
+ # offset = rope - prefill_len (a negative number).
+ self._pending_decode_offsets: list[int] = []
+ self._decode_position_offsets: dict[str, int] = {}
+
from transformers import AutoTokenizer
tok_name = getattr(vllm_config.model_config, "tokenizer", None) or vllm_config.model_config.model
@@ -438,6 +445,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
_tok.add_tokens([t])
self._start_of_image_id = int(_tok.convert_tokens_to_ids("<|vision_start|>"))
self._end_of_image_id = int(_tok.convert_tokens_to_ids("<|vision_end|>"))
+ self._img2img_token_id = int(_tok.convert_tokens_to_ids("<|fim_middle|>"))
self._vae_token_mask: torch.Tensor | None = None
self.device = get_local_device()
@@ -518,10 +526,64 @@ def _clear_warmup_state(self):
self._ropes_metadata.clear()
self._pending_img2img_info.clear()
self._cfg_companion_queue.clear()
+ self._pending_decode_offsets.clear()
+ self._decode_position_offsets.clear()
self._vae_token_mask = None
- def get_kv_transfer_metadata(self, req_id: str) -> dict[str, Any] | None:
- return self._ropes_metadata.pop(req_id, None)
+ def get_kv_transfer_metadata(
+ self,
+ req_id: str,
+ *,
+ num_computed_tokens: int | None = None,
+ ) -> dict[str, Any] | None:
+ meta = self._ropes_metadata.pop(req_id, None)
+ if meta is None:
+ return None
+ # In think-mode img2img the prefill rope doesn't account for decoded
+ # thinking tokens; correct it to num_computed_tokens + offset.
+ # Skip correction when num_computed_tokens is unavailable (None).
+ offset = self._decode_position_offsets.pop(req_id, 0)
+ if offset != 0 and "ropes" in meta and num_computed_tokens is not None:
+ meta["ropes"] = [num_computed_tokens + offset]
+ return meta
+
+ def prepare_runner_inputs(
+ self,
+ input_ids: torch.Tensor | None,
+ positions: torch.Tensor | None,
+ inputs_embeds: torch.Tensor | None,
+ req_ids: list[str],
+ num_computed_tokens: list[int],
+ num_scheduled_tokens: list[int],
+ input_ids_buffer: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
+ """Model-runner hook: adjust inputs before ``forward()``.
+
+ Returns ``(input_ids, positions)`` — possibly modified.
+
+ Two adjustments for BAGEL img2img:
+
+ 1. **Restore input_ids** when ``inputs_embeds`` is present so that
+ ``_adjust_positions_for_img2img`` can locate the
+ ``<|fim_middle|>`` placeholder.
+ 2. **Decode position offset**: prefill rewrites positions to a
+ compact scheme (rope ≪ prefill_len). The runner assigns decode
+ positions from ``num_computed_tokens``, which is far too large;
+ apply the stored per-request offset.
+ """
+ if inputs_embeds is not None and input_ids is None and input_ids_buffer is not None:
+ input_ids = input_ids_buffer
+
+ if self._decode_position_offsets and positions is not None:
+ token_start = 0
+ for i, rid in enumerate(req_ids):
+ sched = num_scheduled_tokens[i]
+ offset = self._decode_position_offsets.get(rid, 0)
+ if offset != 0 and num_computed_tokens[i] > 0:
+ positions[token_start : token_start + sched] += offset
+ token_start += sched
+
+ return input_ids, positions
def flush_pending_metadata(self, req_ids: list[str]) -> None:
"""Map pending metadata (batch order) to req_ids after forward()."""
@@ -529,7 +591,14 @@ def flush_pending_metadata(self, req_ids: list[str]) -> None:
self._ropes_pending = []
for i, meta in enumerate(pending):
if i < len(req_ids):
- self._ropes_metadata[req_ids[i]] = meta
+ if req_ids[i] not in self._ropes_metadata:
+ self._ropes_metadata[req_ids[i]] = meta
+
+ pending_offsets = self._pending_decode_offsets
+ self._pending_decode_offsets = []
+ for i, offset in enumerate(pending_offsets):
+ if i < len(req_ids) and offset != 0:
+ self._decode_position_offsets[req_ids[i]] = offset
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
mm_input_by_modality = {}
@@ -643,7 +712,16 @@ def _process_img2img_input(self, multimodal_input):
num_vit = vit_emb.shape[0] + 2
info = (num_vae, num_vit, int(H), int(W))
self._pending_img2img_info.append(info)
- self._cfg_companion_queue.append((info, 2)) # cfg_text + cfg_img
+ # Only the gen (main) request should add a companion queue entry.
+ # Companion requests (cfg_text, cfg_img) also call this method with
+ # the same image, so guard by checking whether this exact info
+ # tuple is already enqueued. For batched img2img with multiple
+ # concurrent gen requests this correctly adds one entry per unique
+ # image; images with identical (num_vae, num_vit, H, W) that arrive
+ # in the same batch are indistinguishable here and will share one
+ # entry, but that is an uncommon edge case.
+ if not any(entry[0] == info for entry in self._cfg_companion_queue):
+ self._cfg_companion_queue.append((info, 2)) # cfg_text + cfg_img
return tuple(results)
@@ -659,42 +737,65 @@ def forward(
seq_len = inputs_embeds.shape[0] if inputs_embeds is not None else positions.shape[0]
if self._pending_img2img_info:
- positions = self._adjust_positions_for_img2img(positions)
+ positions = self._adjust_positions_for_img2img(positions, input_ids)
use_mot = True
elif self._cfg_companion_queue:
- cached, remaining = self._cfg_companion_queue[0]
- remaining -= 1
- num_vae, num_vit, img_H, img_W = cached
- num_img2img = num_vae + 1 + num_vit # +1 separator
- seq_len = inputs_embeds.shape[0] if inputs_embeds is not None else positions.shape[0]
-
- if inputs_embeds is not None and seq_len >= num_img2img:
- self._pending_img2img_info = [cached]
- positions = self._adjust_positions_for_img2img(positions)
- use_mot = True
+ # Guard: if this looks like a pure decode step (small token count,
+ # no multimodal embeddings), the queue has stale entries from a
+ # previous prefill cycle — clear them instead of consuming.
+ if inputs_embeds is None and seq_len <= 2:
+ self._cfg_companion_queue.clear()
else:
- rope = int(positions[seq_len - 1].item()) + 1
- self._ropes_pending.append({"ropes": [rope]})
+ cached, remaining = self._cfg_companion_queue[0]
+ remaining -= 1
+ num_vae, num_vit, img_H, img_W = cached
+ num_img2img = num_vae + 1 + num_vit # +1 separator
+ seq_len = inputs_embeds.shape[0] if inputs_embeds is not None else positions.shape[0]
- if remaining == 0:
- self._cfg_companion_queue.popleft()
- else:
- self._cfg_companion_queue[0] = (cached, remaining)
+ if inputs_embeds is not None and seq_len >= num_img2img:
+ self._pending_img2img_info = [cached]
+ positions = self._adjust_positions_for_img2img(positions, input_ids)
+ use_mot = True
+ else:
+ rope = int(positions[seq_len - 1].item()) + 1
+ self._ropes_pending.append({"ropes": [rope]})
+
+ if remaining == 0:
+ self._cfg_companion_queue.popleft()
+ else:
+ self._cfg_companion_queue[0] = (cached, remaining)
if use_mot:
return self._mot_forward(input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs)
return super().forward(input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs)
- def _adjust_positions_for_img2img(self, positions: torch.Tensor) -> torch.Tensor:
- """Rewrite position IDs to match the single-stage DiT scheme:
- VAE tokens -> position 0, separator -> position 0,
- ViT tokens -> position 1, text -> 2, 3, ...
+ def _adjust_positions_for_img2img(
+ self,
+ positions: torch.Tensor,
+ input_ids: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ """Rewrite position IDs to match the original BAGEL position scheme:
+
+ If there are ``pre_text_len`` text tokens before the img2img block::
+
+ pre_text → 0, 1, ..., M-1
+ VAE → M (all share)
+ separator→ M
+ ViT → M+1 (all share)
+ post_text→ M+2, M+3, ...
+
+ When no text precedes the img2img block (M=0), this reduces to the
+ simpler scheme: VAE→0, ViT→1, text→2, 3, ...
Also computes ``self._vae_token_mask`` (bool tensor, True for actual
VAE latent patches that should use gen-mode weights) and pushes
per-request ropes + image_shape to the FIFO consumed by
``get_kv_transfer_metadata``.
+
+ For img2img requests, also stores a decode position offset so that
+ subsequent autoregressive decode steps use positions that continue
+ from the rewritten scheme rather than from the original prefill length.
"""
info_list = self._pending_img2img_info
self._pending_img2img_info = []
@@ -724,35 +825,66 @@ def _adjust_positions_for_img2img(self, positions: torch.Tensor) -> torch.Tensor
num_img2img = num_vae + 1 + num_vit # +1 separator
if req_len >= num_img2img:
- new_positions[start : start + num_vae] = 0
- new_positions[start + num_vae] = 0 # separator
- vit_start = start + num_vae + 1
- new_positions[vit_start : vit_start + num_vit] = 1
- num_text = req_len - num_img2img
- if num_text > 0:
- text_start = start + num_img2img
- new_positions[text_start:end] = torch.arange(
- 2, 2 + num_text, device=positions.device, dtype=positions.dtype
+ # Detect offset of img2img tokens within this request
+ # by searching for the img2img placeholder token ID.
+ pre_text_len = 0
+ if input_ids is not None:
+ req_ids = input_ids[start:end]
+ mask = req_ids == self._img2img_token_id
+ indices = mask.nonzero(as_tuple=True)[0]
+ if indices.numel() > 0:
+ pre_text_len = int(indices[0].item())
+
+ img_start = start + pre_text_len
+ post_text_start = img_start + num_img2img
+ # pre_text_pos: position base for image tokens
+ pre_text_pos = pre_text_len
+
+ # Pre-image text: sequential positions 0..pre_text_pos-1
+ if pre_text_len > 0:
+ new_positions[start:img_start] = torch.arange(
+ 0, pre_text_pos, device=positions.device, dtype=positions.dtype
+ )
+
+ # VAE tokens: all share position pre_text_pos
+ new_positions[img_start : img_start + num_vae] = pre_text_pos
+ # Separator: position pre_text_pos
+ new_positions[img_start + num_vae] = pre_text_pos
+ # ViT tokens: all share position pre_text_pos+1
+ vit_start = img_start + num_vae + 1
+ new_positions[vit_start : vit_start + num_vit] = pre_text_pos + 1
+
+ # Post-image text: sequential positions pre_text_pos+2, pre_text_pos+3, ...
+ num_post_text = end - post_text_start
+ if num_post_text > 0:
+ new_positions[post_text_start:end] = torch.arange(
+ pre_text_pos + 2,
+ pre_text_pos + 2 + num_post_text,
+ device=positions.device,
+ dtype=positions.dtype,
)
- # VAE gen-mode mask: only actual VAE patches (not markers)
- vae_patches_start = start + 1 # skip start_marker
- vae_patches_end = start + num_vae - 1 # before end_marker
+ # VAE gen-mode mask: only actual VAE latent patches (not markers)
+ vae_patches_start = img_start + 1 # skip start_marker
+ vae_patches_end = img_start + num_vae - 1 # before end_marker
if vae_patches_end > vae_patches_start:
vae_mask[vae_patches_start:vae_patches_end] = True
- rope = 2 + num_text
+ rope = pre_text_pos + 2 + num_post_text
self._ropes_pending.append(
{
"ropes": [rope],
"image_shape": [img_H, img_W],
}
)
+ decode_offset = rope - req_len
+ self._pending_decode_offsets.append(decode_offset)
img2img_idx += 1
continue
rope = int(new_positions[end - 1].item()) + 1
self._ropes_pending.append({"ropes": [rope]})
+ self._pending_decode_offsets.append(0)
self._vae_token_mask = vae_mask if vae_mask.any() else None
return new_positions
diff --git a/vllm_omni/model_executor/stage_configs/bagel_think.yaml b/vllm_omni/model_executor/stage_configs/bagel_think.yaml
new file mode 100644
index 00000000000..c4cf32c707e
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/bagel_think.yaml
@@ -0,0 +1,86 @@
+# BAGEL Think Model: AR stage decodes thinking tokens before KV transfer to DiT.
+#
+# Differences from bagel.yaml:
+# - No kv_transfer_criteria: AR stage decodes until EOS, then transfers full
+# KV cache (including thinking tokens) via _free_request path.
+# - prompt_expand_func: uses expand_cfg_prompts_think which sets max_tokens=1
+# on companion requests so they stop immediately after prefill.
+# - max_tokens: 2048 for thinking text generation.
+
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ prompt_expand_func: vllm_omni.model_executor.stage_input_processors.bagel.expand_cfg_prompts_think
+ runtime:
+ devices: "0"
+ engine_args:
+ model_stage: thinker
+ max_num_seqs: 3
+ model_arch: OmniBagelForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.45
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: text
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ tensor_parallel_size: 1
+ omni_kv_config:
+ need_send_cache: true
+ final_output: true
+ final_output_type: text
+ is_comprehension: true
+ default_sampling_params:
+ temperature: 0.3
+ top_p: 0.9
+ top_k: 1
+ max_tokens: 2048
+ seed: 52
+ detokenize: True
+ repetition_penalty: 1.05
+
+ - stage_id: 1
+ stage_type: diffusion
+ cfg_kv_collect_func: vllm_omni.model_executor.stage_input_processors.bagel.collect_cfg_kv_caches
+ runtime:
+ devices: "0"
+ engine_args:
+ model_stage: dit
+ max_num_seqs: 1
+ gpu_memory_utilization: 0.45
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: image
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ tensor_parallel_size: 1
+ omni_kv_config:
+ need_recv_cache: true
+ engine_input_source: [0]
+
+ final_output: true
+ final_output_type: image
+ is_comprehension: false
+ default_sampling_params:
+ seed: 52
+
+# Runtime edges
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 1
+
+ connectors:
+ shared_memory_connector:
+ name: SharedMemoryConnector
+ extra:
+ shm_threshold_bytes: 65536
+
+ edges:
+ - from: 0
+ to: 1
+ window_size: -1
diff --git a/vllm_omni/model_executor/stage_input_processors/bagel.py b/vllm_omni/model_executor/stage_input_processors/bagel.py
index d7055ff5180..6b88fcd4a18 100644
--- a/vllm_omni/model_executor/stage_input_processors/bagel.py
+++ b/vllm_omni/model_executor/stage_input_processors/bagel.py
@@ -30,6 +30,26 @@ class ExpandedPrompt:
prompt: dict[str, Any] | str
role: str
request_id_suffix: str
+ sampling_params_override: dict[str, Any] | None = None
+
+ def apply_overrides(
+ self,
+ base_params: Any,
+ base_spl: list[Any],
+ ) -> tuple[Any, list[Any]]:
+ """Return ``(params, sampling_params_list)`` with overrides applied.
+
+ If this prompt has no overrides the originals are returned as-is.
+ """
+ if not self.sampling_params_override:
+ return base_params, base_spl
+ patched = base_params.clone()
+ for k, v in self.sampling_params_override.items():
+ setattr(patched, k, v)
+ spl = list(base_spl)
+ if spl:
+ spl[0] = patched
+ return patched, spl
def expand_cfg_prompts(
@@ -108,6 +128,95 @@ def expand_cfg_prompts(
return []
+GEN_THINK_SYSTEM_PROMPT = (
+ "You should first think about the planning process in the mind "
+ "and then generate the image. \n"
+ "The planning process is enclosed within tags, "
+ "i.e. planning process here image here"
+)
+
+
+def expand_cfg_prompts_think(
+ prompt: dict[str, Any] | str,
+ sampling_params: Any,
+) -> list[ExpandedPrompt]:
+ """Expand prompts for Bagel CFG in thinking mode.
+
+ Same as expand_cfg_prompts but companion requests get max_tokens=1
+ so they stop immediately after prefill (no thinking decode).
+
+ In thinking mode the gen (main) request decodes thinking tokens until
+ EOS; companions should only contribute their prefill KV cache.
+ """
+ if not isinstance(prompt, dict):
+ return []
+
+ modalities = prompt.get("modalities", [])
+ if "image" not in modalities and "img2img" not in modalities:
+ return []
+
+ neg_prompt = _get_negative_prompt(prompt, sampling_params)
+ companion_params = {"max_tokens": 1}
+
+ if "image" in modalities:
+ neg_prompt_dict = {
+ "prompt": neg_prompt,
+ "modalities": prompt.get("modalities", []),
+ }
+ return [
+ ExpandedPrompt(
+ prompt=neg_prompt_dict,
+ role="cfg_text",
+ request_id_suffix=CFG_TEXT_SUFFIX,
+ sampling_params_override=companion_params,
+ ),
+ ]
+
+ if "img2img" in modalities:
+ IMG2IMG_PLACEHOLDER = "<|fim_middle|>"
+
+ original_text = prompt.get("prompt", "")
+ # Extract system prompt prefix (everything before <|fim_middle|>)
+ # so cfg_text gets system_prompt + image (no user text), matching
+ # the original BAGEL code where cfg_text = deepcopy(gen after image).
+ parts = original_text.split(IMG2IMG_PLACEHOLDER, 1)
+ system_prefix = parts[0] if len(parts) > 1 else ""
+
+ cfg_text_prompt = f"{system_prefix}{IMG2IMG_PLACEHOLDER}{neg_prompt}"
+ cfg_text_dict: dict[str, Any] = {
+ "prompt": cfg_text_prompt,
+ "modalities": ["img2img"],
+ }
+ mm_data = prompt.get("multi_modal_data")
+ if mm_data:
+ cfg_text_dict["multi_modal_data"] = mm_data
+
+ cfg_img_text = original_text.replace(IMG2IMG_PLACEHOLDER, "")
+ cfg_img_dict: dict[str, Any] = {
+ "prompt": cfg_img_text,
+ "modalities": ["img2img"],
+ }
+ if mm_data:
+ cfg_img_dict["multi_modal_data"] = mm_data
+
+ return [
+ ExpandedPrompt(
+ prompt=cfg_text_dict,
+ role="cfg_text",
+ request_id_suffix=CFG_TEXT_SUFFIX,
+ sampling_params_override=companion_params,
+ ),
+ ExpandedPrompt(
+ prompt=cfg_img_dict,
+ role="cfg_img",
+ request_id_suffix=CFG_IMG_SUFFIX,
+ sampling_params_override=companion_params,
+ ),
+ ]
+
+ return []
+
+
def collect_cfg_kv_caches(
request_id: str,
cfg_request_ids: dict[str, str],
diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py
index 697c39d242e..155b75675ff 100644
--- a/vllm_omni/worker/gpu_ar_model_runner.py
+++ b/vllm_omni/worker/gpu_ar_model_runner.py
@@ -108,7 +108,14 @@ def execute_model(
if finished_reqs and hasattr(self.model, "get_kv_transfer_metadata"):
for req_id, data in finished_reqs.items():
try:
- model_meta = self.model.get_kv_transfer_metadata(req_id)
+ req_idx = self.input_batch.req_id_to_index.get(req_id)
+ num_computed = (
+ int(self.input_batch.num_computed_tokens_cpu[req_idx]) if req_idx is not None else None
+ )
+ model_meta = self.model.get_kv_transfer_metadata(
+ req_id,
+ num_computed_tokens=num_computed,
+ )
if model_meta:
existing = data.get("custom_metadata") or {}
existing.update(model_meta)
@@ -266,6 +273,19 @@ def execute_model(
ec_connector_output,
) = self._preprocess(scheduler_output, num_tokens_padded, intermediate_tensors)
+ # Let the model adjust inputs before forward (e.g. restore input_ids
+ # for multimodal position detection, fix decode position offsets).
+ if hasattr(self.model, "prepare_runner_inputs"):
+ input_ids, positions = self.model.prepare_runner_inputs(
+ input_ids=input_ids,
+ positions=positions,
+ inputs_embeds=inputs_embeds,
+ req_ids=req_ids[:num_reqs],
+ num_computed_tokens=[int(self.input_batch.num_computed_tokens_cpu[i]) for i in range(num_reqs)],
+ num_scheduled_tokens=[int(num_scheduled_tokens_np[i]) for i in range(num_reqs)],
+ input_ids_buffer=self.input_ids.gpu[:num_tokens_padded],
+ )
+
# Set cudagraph mode to none if calc_kv_scales is true.
# KV scales calculation involves dynamic operations that are incompatible
# with CUDA graph capture.