diff --git a/examples/offline_inference/bagel/end2end.py b/examples/offline_inference/bagel/end2end.py index efcdea2355d..2153a31ba70 100644 --- a/examples/offline_inference/bagel/end2end.py +++ b/examples/offline_inference/bagel/end2end.py @@ -2,6 +2,7 @@ import os from vllm_omni.inputs.data import OmniPromptType +from vllm_omni.model_executor.stage_input_processors.bagel import GEN_THINK_SYSTEM_PROMPT def parse_args(): @@ -65,6 +66,17 @@ def parse_args(): help="CFG parallel size: 1=batched (single GPU), 2=parallel with 2 branches (text CFG only), 3=parallel (3 GPUs).", ) parser.add_argument("--seed", type=int, default=None, help="Random seed for generation.") + parser.add_argument( + "--cfg-interval", + type=float, + nargs=2, + default=None, + help="CFG interval [start, end] (default: pipeline default)", + ) + parser.add_argument( + "--cfg-renorm-type", type=str, default=None, help="CFG renorm type: global, text_channel, channel" + ) + parser.add_argument("--cfg-renorm-min", type=float, default=None, help="CFG renorm min") parser.add_argument( "--enable-diffusion-pipeline-profiler", action="store_true", @@ -76,6 +88,12 @@ def parse_args(): default=None, help="Quantization method (e.g. 'fp8').", ) + parser.add_argument( + "--think", + action="store_true", + default=False, + help="Enable thinking mode: AR stage decodes <think>...</think> 
planning tokens before image generation.", + ) args = parser.parse_args() return args @@ -110,8 +128,12 @@ def main(): from vllm_omni.entrypoints.omni import Omni omni_kwargs = {} - if args.stage_configs_path: - omni_kwargs["stage_configs_path"] = args.stage_configs_path + stage_configs_path = args.stage_configs_path + if args.think and stage_configs_path is None: + stage_configs_path = "vllm_omni/model_executor/stage_configs/bagel_think.yaml" + print(f"[Info] Think mode enabled, using stage config: {stage_configs_path}") + if stage_configs_path: + omni_kwargs["stage_configs_path"] = stage_configs_path omni_kwargs.update( { @@ -136,7 +158,8 @@ def main(): if not args.image_path or not os.path.exists(args.image_path): raise ValueError(f"img2img requires --image-path pointing to an existing file, got: {args.image_path}") loaded_image = Image.open(args.image_path).convert("RGB") - final_prompt_text = f"<|fim_middle|><|im_start|>{p}<|im_end|>" + think_prefix = f"<|im_start|>{GEN_THINK_SYSTEM_PROMPT}<|im_end|>" if args.think else "" + final_prompt_text = f"{think_prefix}<|fim_middle|><|im_start|>{p}<|im_end|>" prompt_dict = { "prompt": final_prompt_text, "multi_modal_data": {"img2img": loaded_image}, @@ -160,7 +183,8 @@ def main(): prompt_dict = {"prompt": final_prompt_text, "modalities": ["text"]} formatted_prompts.append(prompt_dict) else: - final_prompt_text = f"<|im_start|>{p}<|im_end|>" + think_prefix = f"<|im_start|>{GEN_THINK_SYSTEM_PROMPT}<|im_end|>" if args.think else "" + final_prompt_text = f"{think_prefix}<|im_start|>{p}<|im_end|>" prompt_dict = {"prompt": final_prompt_text, "modalities": ["image"]} if args.negative_prompt is not None: prompt_dict["negative_prompt"] = args.negative_prompt @@ -178,6 +202,12 @@ def main(): "cfg_text_scale": args.cfg_text_scale, "cfg_img_scale": args.cfg_img_scale, } + if args.cfg_interval is not None: + extra["cfg_interval"] = tuple(args.cfg_interval) + if args.cfg_renorm_type is not None: + extra["cfg_renorm_type"] = 
args.cfg_renorm_type + if args.cfg_renorm_min is not None: + extra["cfg_renorm_min"] = args.cfg_renorm_min if args.negative_prompt is not None: extra["negative_prompt"] = args.negative_prompt diffusion_params.extra_args = extra # type: ignore @@ -186,6 +216,17 @@ def main(): img_idx = 0 for req_output in omni_outputs: + if args.think: + text_output = getattr(req_output, "text", None) or getattr(req_output, "outputs", None) + if text_output: + if isinstance(text_output, list) and text_output: + for out in text_output: + txt = getattr(out, "text", str(out)) + if txt: + print(f"[Think] {txt}") + elif isinstance(text_output, str): + print(f"[Think] {text_output}") + images = getattr(req_output, "images", None) if not images: @@ -194,6 +235,7 @@ def main(): for j, img in enumerate(images): save_path = os.path.join(args.output, f"output_{img_idx}_{j}.png") img.save(save_path) + print(f"[Output] Saved image to {save_path}") img_idx += 1 print(omni_outputs) diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py index aa4f0a74f02..3e053cbda50 100644 --- a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py +++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py @@ -326,11 +326,18 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: cfg_text_scale = extra_args.get("cfg_text_scale", 4.0) cfg_img_scale = extra_args.get("cfg_img_scale", 1.5) + cfg_interval = extra_args.get("cfg_interval", (0.4, 1.0)) + cfg_renorm_type = extra_args.get("cfg_renorm_type", "global") + cfg_renorm_min = extra_args.get("cfg_renorm_min", 0.0) + gen_params = BagelGenParams( num_timesteps=int(req.sampling_params.num_inference_steps or 50), timestep_shift=3.0, cfg_text_scale=cfg_text_scale, cfg_img_scale=cfg_img_scale, + cfg_interval=cfg_interval, + cfg_renorm_type=cfg_renorm_type, + cfg_renorm_min=cfg_renorm_min, ) gen_context = { diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index 
9de3dc867ff..b94f83bab39 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -722,14 +722,15 @@ def _enqueue_cfg_companions( cid = f"{parent_id}{ep.request_id_suffix}" companion_prompt = ep.prompt - # Run through same input processing as the main prompt + companion_params, companion_spl = ep.apply_overrides(stage0_params, sampling_params_list) + if isinstance(companion_prompt, dict): _inject_global_id(companion_prompt, cid) request = self.input_processor.process_inputs( request_id=cid, prompt=companion_prompt, - params=stage0_params, + params=companion_params, supported_tasks=self.supported_tasks, ) request = _upgrade_to_omni_request(request, companion_prompt) @@ -750,7 +751,7 @@ def _enqueue_cfg_companions( "parent_id": parent_id, "role": ep.role, "prompt": request, - "sampling_params_list": sampling_params_list, + "sampling_params_list": companion_spl, } ) diff --git a/vllm_omni/model_executor/models/bagel/bagel.py b/vllm_omni/model_executor/models/bagel/bagel.py index e58b3501c44..e79f0212e2e 100644 --- a/vllm_omni/model_executor/models/bagel/bagel.py +++ b/vllm_omni/model_executor/models/bagel/bagel.py @@ -429,6 +429,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._ropes_metadata: dict[str, dict[str, Any]] = {} self._cfg_companion_queue: deque[tuple[tuple[int, int, int, int], int]] = deque() + # Per-request position offset for decode after img2img prefill. + # Prefill rewrites positions (VAE→0, ViT→1, text→2..N) but the model + # runner assigns decode positions starting from prefill_len, not N+1. + # offset = rope - prefill_len (a negative number). 
+ self._pending_decode_offsets: list[int] = [] + self._decode_position_offsets: dict[str, int] = {} + from transformers import AutoTokenizer tok_name = getattr(vllm_config.model_config, "tokenizer", None) or vllm_config.model_config.model @@ -438,6 +445,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): _tok.add_tokens([t]) self._start_of_image_id = int(_tok.convert_tokens_to_ids("<|vision_start|>")) self._end_of_image_id = int(_tok.convert_tokens_to_ids("<|vision_end|>")) + self._img2img_token_id = int(_tok.convert_tokens_to_ids("<|fim_middle|>")) self._vae_token_mask: torch.Tensor | None = None self.device = get_local_device() @@ -518,10 +526,64 @@ def _clear_warmup_state(self): self._ropes_metadata.clear() self._pending_img2img_info.clear() self._cfg_companion_queue.clear() + self._pending_decode_offsets.clear() + self._decode_position_offsets.clear() self._vae_token_mask = None - def get_kv_transfer_metadata(self, req_id: str) -> dict[str, Any] | None: - return self._ropes_metadata.pop(req_id, None) + def get_kv_transfer_metadata( + self, + req_id: str, + *, + num_computed_tokens: int | None = None, + ) -> dict[str, Any] | None: + meta = self._ropes_metadata.pop(req_id, None) + if meta is None: + return None + # In think-mode img2img the prefill rope doesn't account for decoded + # thinking tokens; correct it to num_computed_tokens + offset. + # Skip correction when num_computed_tokens is unavailable (None). 
+ offset = self._decode_position_offsets.pop(req_id, 0) + if offset != 0 and "ropes" in meta and num_computed_tokens is not None: + meta["ropes"] = [num_computed_tokens + offset] + return meta + + def prepare_runner_inputs( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor | None, + inputs_embeds: torch.Tensor | None, + req_ids: list[str], + num_computed_tokens: list[int], + num_scheduled_tokens: list[int], + input_ids_buffer: torch.Tensor | None = None, + ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + """Model-runner hook: adjust inputs before ``forward()``. + + Returns ``(input_ids, positions)`` — possibly modified. + + Two adjustments for BAGEL img2img: + + 1. **Restore input_ids** when ``inputs_embeds`` is present so that + ``_adjust_positions_for_img2img`` can locate the + ``<|fim_middle|>`` placeholder. + 2. **Decode position offset**: prefill rewrites positions to a + compact scheme (rope ≪ prefill_len). The runner assigns decode + positions from ``num_computed_tokens``, which is far too large; + apply the stored per-request offset. 
+ """ + if inputs_embeds is not None and input_ids is None and input_ids_buffer is not None: + input_ids = input_ids_buffer + + if self._decode_position_offsets and positions is not None: + token_start = 0 + for i, rid in enumerate(req_ids): + sched = num_scheduled_tokens[i] + offset = self._decode_position_offsets.get(rid, 0) + if offset != 0 and num_computed_tokens[i] > 0: + positions[token_start : token_start + sched] += offset + token_start += sched + + return input_ids, positions def flush_pending_metadata(self, req_ids: list[str]) -> None: """Map pending metadata (batch order) to req_ids after forward().""" @@ -529,7 +591,14 @@ def flush_pending_metadata(self, req_ids: list[str]) -> None: self._ropes_pending = [] for i, meta in enumerate(pending): if i < len(req_ids): - self._ropes_metadata[req_ids[i]] = meta + if req_ids[i] not in self._ropes_metadata: + self._ropes_metadata[req_ids[i]] = meta + + pending_offsets = self._pending_decode_offsets + self._pending_decode_offsets = [] + for i, offset in enumerate(pending_offsets): + if i < len(req_ids) and offset != 0: + self._decode_position_offsets[req_ids[i]] = offset def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: mm_input_by_modality = {} @@ -643,7 +712,16 @@ def _process_img2img_input(self, multimodal_input): num_vit = vit_emb.shape[0] + 2 info = (num_vae, num_vit, int(H), int(W)) self._pending_img2img_info.append(info) - self._cfg_companion_queue.append((info, 2)) # cfg_text + cfg_img + # Only the gen (main) request should add a companion queue entry. + # Companion requests (cfg_text, cfg_img) also call this method with + # the same image, so guard by checking whether this exact info + # tuple is already enqueued. 
For batched img2img with multiple + # concurrent gen requests this correctly adds one entry per unique + # image; images with identical (num_vae, num_vit, H, W) that arrive + # in the same batch are indistinguishable here and will share one + # entry, but that is an uncommon edge case. + if not any(entry[0] == info for entry in self._cfg_companion_queue): + self._cfg_companion_queue.append((info, 2)) # cfg_text + cfg_img return tuple(results) @@ -659,42 +737,65 @@ def forward( seq_len = inputs_embeds.shape[0] if inputs_embeds is not None else positions.shape[0] if self._pending_img2img_info: - positions = self._adjust_positions_for_img2img(positions) + positions = self._adjust_positions_for_img2img(positions, input_ids) use_mot = True elif self._cfg_companion_queue: - cached, remaining = self._cfg_companion_queue[0] - remaining -= 1 - num_vae, num_vit, img_H, img_W = cached - num_img2img = num_vae + 1 + num_vit # +1 separator - seq_len = inputs_embeds.shape[0] if inputs_embeds is not None else positions.shape[0] - - if inputs_embeds is not None and seq_len >= num_img2img: - self._pending_img2img_info = [cached] - positions = self._adjust_positions_for_img2img(positions) - use_mot = True + # Guard: if this looks like a pure decode step (small token count, + # no multimodal embeddings), the queue has stale entries from a + # previous prefill cycle — clear them instead of consuming. 
+ if inputs_embeds is None and seq_len <= 2: + self._cfg_companion_queue.clear() else: - rope = int(positions[seq_len - 1].item()) + 1 - self._ropes_pending.append({"ropes": [rope]}) + cached, remaining = self._cfg_companion_queue[0] + remaining -= 1 + num_vae, num_vit, img_H, img_W = cached + num_img2img = num_vae + 1 + num_vit # +1 separator + seq_len = inputs_embeds.shape[0] if inputs_embeds is not None else positions.shape[0] - if remaining == 0: - self._cfg_companion_queue.popleft() - else: - self._cfg_companion_queue[0] = (cached, remaining) + if inputs_embeds is not None and seq_len >= num_img2img: + self._pending_img2img_info = [cached] + positions = self._adjust_positions_for_img2img(positions, input_ids) + use_mot = True + else: + rope = int(positions[seq_len - 1].item()) + 1 + self._ropes_pending.append({"ropes": [rope]}) + + if remaining == 0: + self._cfg_companion_queue.popleft() + else: + self._cfg_companion_queue[0] = (cached, remaining) if use_mot: return self._mot_forward(input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs) return super().forward(input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs) - def _adjust_positions_for_img2img(self, positions: torch.Tensor) -> torch.Tensor: - """Rewrite position IDs to match the single-stage DiT scheme: - VAE tokens -> position 0, separator -> position 0, - ViT tokens -> position 1, text -> 2, 3, ... + def _adjust_positions_for_img2img( + self, + positions: torch.Tensor, + input_ids: torch.Tensor | None = None, + ) -> torch.Tensor: + """Rewrite position IDs to match the original BAGEL position scheme: + + If there are ``pre_text_len`` text tokens before the img2img block:: + + pre_text → 0, 1, ..., M-1 + VAE → M (all share) + separator→ M + ViT → M+1 (all share) + post_text→ M+2, M+3, ... + + When no text precedes the img2img block (M=0), this reduces to the + simpler scheme: VAE→0, ViT→1, text→2, 3, ... 
Also computes ``self._vae_token_mask`` (bool tensor, True for actual VAE latent patches that should use gen-mode weights) and pushes per-request ropes + image_shape to the FIFO consumed by ``get_kv_transfer_metadata``. + + For img2img requests, also stores a decode position offset so that + subsequent autoregressive decode steps use positions that continue + from the rewritten scheme rather than from the original prefill length. """ info_list = self._pending_img2img_info self._pending_img2img_info = [] @@ -724,35 +825,66 @@ def _adjust_positions_for_img2img(self, positions: torch.Tensor) -> torch.Tensor num_img2img = num_vae + 1 + num_vit # +1 separator if req_len >= num_img2img: - new_positions[start : start + num_vae] = 0 - new_positions[start + num_vae] = 0 # separator - vit_start = start + num_vae + 1 - new_positions[vit_start : vit_start + num_vit] = 1 - num_text = req_len - num_img2img - if num_text > 0: - text_start = start + num_img2img - new_positions[text_start:end] = torch.arange( - 2, 2 + num_text, device=positions.device, dtype=positions.dtype + # Detect offset of img2img tokens within this request + # by searching for the img2img placeholder token ID. 
+ pre_text_len = 0 + if input_ids is not None: + req_ids = input_ids[start:end] + mask = req_ids == self._img2img_token_id + indices = mask.nonzero(as_tuple=True)[0] + if indices.numel() > 0: + pre_text_len = int(indices[0].item()) + + img_start = start + pre_text_len + post_text_start = img_start + num_img2img + # pre_text_pos: position base for image tokens + pre_text_pos = pre_text_len + + # Pre-image text: sequential positions 0..pre_text_pos-1 + if pre_text_len > 0: + new_positions[start:img_start] = torch.arange( + 0, pre_text_pos, device=positions.device, dtype=positions.dtype + ) + + # VAE tokens: all share position pre_text_pos + new_positions[img_start : img_start + num_vae] = pre_text_pos + # Separator: position pre_text_pos + new_positions[img_start + num_vae] = pre_text_pos + # ViT tokens: all share position pre_text_pos+1 + vit_start = img_start + num_vae + 1 + new_positions[vit_start : vit_start + num_vit] = pre_text_pos + 1 + + # Post-image text: sequential positions pre_text_pos+2, pre_text_pos+3, ... 
+ num_post_text = end - post_text_start + if num_post_text > 0: + new_positions[post_text_start:end] = torch.arange( + pre_text_pos + 2, + pre_text_pos + 2 + num_post_text, + device=positions.device, + dtype=positions.dtype, ) - # VAE gen-mode mask: only actual VAE patches (not markers) - vae_patches_start = start + 1 # skip start_marker - vae_patches_end = start + num_vae - 1 # before end_marker + # VAE gen-mode mask: only actual VAE latent patches (not markers) + vae_patches_start = img_start + 1 # skip start_marker + vae_patches_end = img_start + num_vae - 1 # before end_marker if vae_patches_end > vae_patches_start: vae_mask[vae_patches_start:vae_patches_end] = True - rope = 2 + num_text + rope = pre_text_pos + 2 + num_post_text self._ropes_pending.append( { "ropes": [rope], "image_shape": [img_H, img_W], } ) + decode_offset = rope - req_len + self._pending_decode_offsets.append(decode_offset) img2img_idx += 1 continue rope = int(new_positions[end - 1].item()) + 1 self._ropes_pending.append({"ropes": [rope]}) + self._pending_decode_offsets.append(0) self._vae_token_mask = vae_mask if vae_mask.any() else None return new_positions diff --git a/vllm_omni/model_executor/stage_configs/bagel_think.yaml b/vllm_omni/model_executor/stage_configs/bagel_think.yaml new file mode 100644 index 00000000000..c4cf32c707e --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/bagel_think.yaml @@ -0,0 +1,86 @@ +# BAGEL Think Model: AR stage decodes thinking tokens before KV transfer to DiT. +# +# Differences from bagel.yaml: +# - No kv_transfer_criteria: AR stage decodes until EOS, then transfers full +# KV cache (including thinking tokens) via _free_request path. +# - prompt_expand_func: uses expand_cfg_prompts_think which sets max_tokens=1 +# on companion requests so they stop immediately after prefill. +# - max_tokens: 2048 for thinking text generation. 
+ +stage_args: + - stage_id: 0 + stage_type: llm + prompt_expand_func: vllm_omni.model_executor.stage_input_processors.bagel.expand_cfg_prompts_think + runtime: + devices: "0" + engine_args: + model_stage: thinker + max_num_seqs: 3 + model_arch: OmniBagelForConditionalGeneration + worker_type: ar + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.45 + enforce_eager: true + trust_remote_code: true + engine_output_type: text + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 1 + omni_kv_config: + need_send_cache: true + final_output: true + final_output_type: text + is_comprehension: true + default_sampling_params: + temperature: 0.3 + top_p: 0.9 + top_k: 1 + max_tokens: 2048 + seed: 52 + detokenize: True + repetition_penalty: 1.05 + + - stage_id: 1 + stage_type: diffusion + cfg_kv_collect_func: vllm_omni.model_executor.stage_input_processors.bagel.collect_cfg_kv_caches + runtime: + devices: "0" + engine_args: + model_stage: dit + max_num_seqs: 1 + gpu_memory_utilization: 0.45 + enforce_eager: true + trust_remote_code: true + engine_output_type: image + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 1 + omni_kv_config: + need_recv_cache: true + engine_input_source: [0] + + final_output: true + final_output_type: image + is_comprehension: false + default_sampling_params: + seed: 52 + +# Runtime edges +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 + + connectors: + shared_memory_connector: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 + + edges: + - from: 0 + to: 1 + window_size: -1 diff --git a/vllm_omni/model_executor/stage_input_processors/bagel.py b/vllm_omni/model_executor/stage_input_processors/bagel.py index d7055ff5180..6b88fcd4a18 100644 --- a/vllm_omni/model_executor/stage_input_processors/bagel.py +++ 
b/vllm_omni/model_executor/stage_input_processors/bagel.py @@ -30,6 +30,26 @@ class ExpandedPrompt: prompt: dict[str, Any] | str role: str request_id_suffix: str + sampling_params_override: dict[str, Any] | None = None + + def apply_overrides( + self, + base_params: Any, + base_spl: list[Any], + ) -> tuple[Any, list[Any]]: + """Return ``(params, sampling_params_list)`` with overrides applied. + + If this prompt has no overrides the originals are returned as-is. + """ + if not self.sampling_params_override: + return base_params, base_spl + patched = base_params.clone() + for k, v in self.sampling_params_override.items(): + setattr(patched, k, v) + spl = list(base_spl) + if spl: + spl[0] = patched + return patched, spl def expand_cfg_prompts( @@ -108,6 +128,95 @@ def expand_cfg_prompts( return [] +GEN_THINK_SYSTEM_PROMPT = ( + "You should first think about the planning process in the mind " + "and then generate the image. \n" + "The planning process is enclosed within <think> </think> tags, " + "i.e. <think> planning process here </think> image here" +) + + +def expand_cfg_prompts_think( + prompt: dict[str, Any] | str, + sampling_params: Any, +) -> list[ExpandedPrompt]: + """Expand prompts for Bagel CFG in thinking mode. + + Same as expand_cfg_prompts but companion requests get max_tokens=1 + so they stop immediately after prefill (no thinking decode). + + In thinking mode the gen (main) request decodes thinking tokens until + EOS; companions should only contribute their prefill KV cache. 
+ """ + if not isinstance(prompt, dict): + return [] + + modalities = prompt.get("modalities", []) + if "image" not in modalities and "img2img" not in modalities: + return [] + + neg_prompt = _get_negative_prompt(prompt, sampling_params) + companion_params = {"max_tokens": 1} + + if "image" in modalities: + neg_prompt_dict = { + "prompt": neg_prompt, + "modalities": prompt.get("modalities", []), + } + return [ + ExpandedPrompt( + prompt=neg_prompt_dict, + role="cfg_text", + request_id_suffix=CFG_TEXT_SUFFIX, + sampling_params_override=companion_params, + ), + ] + + if "img2img" in modalities: + IMG2IMG_PLACEHOLDER = "<|fim_middle|>" + + original_text = prompt.get("prompt", "") + # Extract system prompt prefix (everything before <|fim_middle|>) + # so cfg_text gets system_prompt + image (no user text), matching + # the original BAGEL code where cfg_text = deepcopy(gen after image). + parts = original_text.split(IMG2IMG_PLACEHOLDER, 1) + system_prefix = parts[0] if len(parts) > 1 else "" + + cfg_text_prompt = f"{system_prefix}{IMG2IMG_PLACEHOLDER}{neg_prompt}" + cfg_text_dict: dict[str, Any] = { + "prompt": cfg_text_prompt, + "modalities": ["img2img"], + } + mm_data = prompt.get("multi_modal_data") + if mm_data: + cfg_text_dict["multi_modal_data"] = mm_data + + cfg_img_text = original_text.replace(IMG2IMG_PLACEHOLDER, "") + cfg_img_dict: dict[str, Any] = { + "prompt": cfg_img_text, + "modalities": ["img2img"], + } + if mm_data: + cfg_img_dict["multi_modal_data"] = mm_data + + return [ + ExpandedPrompt( + prompt=cfg_text_dict, + role="cfg_text", + request_id_suffix=CFG_TEXT_SUFFIX, + sampling_params_override=companion_params, + ), + ExpandedPrompt( + prompt=cfg_img_dict, + role="cfg_img", + request_id_suffix=CFG_IMG_SUFFIX, + sampling_params_override=companion_params, + ), + ] + + return [] + + def collect_cfg_kv_caches( request_id: str, cfg_request_ids: dict[str, str], diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py 
index 697c39d242e..155b75675ff 100644 --- a/vllm_omni/worker/gpu_ar_model_runner.py +++ b/vllm_omni/worker/gpu_ar_model_runner.py @@ -108,7 +108,14 @@ def execute_model( if finished_reqs and hasattr(self.model, "get_kv_transfer_metadata"): for req_id, data in finished_reqs.items(): try: - model_meta = self.model.get_kv_transfer_metadata(req_id) + req_idx = self.input_batch.req_id_to_index.get(req_id) + num_computed = ( + int(self.input_batch.num_computed_tokens_cpu[req_idx]) if req_idx is not None else None + ) + model_meta = self.model.get_kv_transfer_metadata( + req_id, + num_computed_tokens=num_computed, + ) if model_meta: existing = data.get("custom_metadata") or {} existing.update(model_meta) @@ -266,6 +273,19 @@ def execute_model( ec_connector_output, ) = self._preprocess(scheduler_output, num_tokens_padded, intermediate_tensors) + # Let the model adjust inputs before forward (e.g. restore input_ids + # for multimodal position detection, fix decode position offsets). + if hasattr(self.model, "prepare_runner_inputs"): + input_ids, positions = self.model.prepare_runner_inputs( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + req_ids=req_ids[:num_reqs], + num_computed_tokens=[int(self.input_batch.num_computed_tokens_cpu[i]) for i in range(num_reqs)], + num_scheduled_tokens=[int(num_scheduled_tokens_np[i]) for i in range(num_reqs)], + input_ids_buffer=self.input_ids.gpu[:num_tokens_padded], + ) + # Set cudagraph mode to none if calc_kv_scales is true. # KV scales calculation involves dynamic operations that are incompatible # with CUDA graph capture.