diff --git a/examples/offline_inference/bagel/end2end.py b/examples/offline_inference/bagel/end2end.py index 922a1af2368..efcdea2355d 100644 --- a/examples/offline_inference/bagel/end2end.py +++ b/examples/offline_inference/bagel/end2end.py @@ -168,7 +168,6 @@ def main(): params_list = omni.default_sampling_params_list if args.modality in ("text2img", "img2img"): - params_list[0].max_tokens = 1 # type: ignore if len(params_list) > 1: diffusion_params = params_list[1] diffusion_params.num_inference_steps = args.steps # type: ignore diff --git a/tests/e2e/offline_inference/test_bagel_img2img.py b/tests/e2e/offline_inference/test_bagel_img2img.py index c7df4f91bed..a0c3f6cc9fc 100644 --- a/tests/e2e/offline_inference/test_bagel_img2img.py +++ b/tests/e2e/offline_inference/test_bagel_img2img.py @@ -79,19 +79,17 @@ def _find_free_port() -> int: return port -def _configure_sampling_params(omni: Omni, max_tokens: int = 1, num_inference_steps: int = 15) -> list: +def _configure_sampling_params(omni: Omni, num_inference_steps: int = 15) -> list: """Configure sampling parameters for Bagel img2img generation. Args: omni: The Omni instance to get default params from. - max_tokens: Maximum tokens for the first stage. num_inference_steps: Number of inference steps for the diffusion stage. Returns: Configured sampling params list. """ params_list = omni.default_sampling_params_list - params_list[0].max_tokens = max_tokens # type: ignore if len(params_list) > 1: params_list[1].num_inference_steps = num_inference_steps # type: ignore params_list[1].extra_args = { # type: ignore diff --git a/tests/e2e/offline_inference/test_bagel_text2img.py b/tests/e2e/offline_inference/test_bagel_text2img.py index 505e12438d0..c74763a35a4 100644 --- a/tests/e2e/offline_inference/test_bagel_text2img.py +++ b/tests/e2e/offline_inference/test_bagel_text2img.py @@ -80,19 +80,17 @@ def _find_free_port() -> int: return port -def _configure_sampling_params(omni: Omni, max_tokens: int = 1, num_inference_steps: int = 15) -> list: +def _configure_sampling_params(omni: Omni, num_inference_steps: int = 15) -> list: """Configure sampling parameters for Bagel text2img generation. Args: omni: The Omni instance to get default params from. - max_tokens: Maximum tokens for the first stage. num_inference_steps: Number of inference steps for the diffusion stage. Returns: Configured sampling params list. """ params_list = omni.default_sampling_params_list - params_list[0].max_tokens = max_tokens # type: ignore if len(params_list) > 1: params_list[1].num_inference_steps = num_inference_steps # type: ignore params_list[1].extra_args = { # type: ignore diff --git a/tests/e2e/offline_inference/test_quantization_fp8.py b/tests/e2e/offline_inference/test_quantization_fp8.py index 5943afa028f..f71c53de74c 100644 --- a/tests/e2e/offline_inference/test_quantization_fp8.py +++ b/tests/e2e/offline_inference/test_quantization_fp8.py @@ -120,7 +120,6 @@ def _generate_bagel_image( torch.cuda.reset_peak_memory_stats() params_list = omni.default_sampling_params_list - params_list[0].max_tokens = 1 # type: ignore if len(params_list) > 1: params_list[1].num_inference_steps = num_inference_steps # type: ignore params_list[1].extra_args = { # type: ignore diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py index 71da4d5925b..c4d84522255 100644 --- a/vllm_omni/core/sched/omni_ar_scheduler.py +++ b/vllm_omni/core/sched/omni_ar_scheduler.py @@ -95,8 +95,19 @@ def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int return False criteria_type = self.kv_transfer_criteria.get("type") + if ( + self.kv_transfer_criteria.get("stop_after_transfer", True) + and request.request_id in self.transfer_triggered_requests + ): + # For split pipelines that only need the transferred KV + # snapshot, stop AR decode once KV extraction has completed. + # This frees stage-0 resources without requiring an + # orchestrator-side abort. + if request.request_id not in self.active_kv_transfers: + request.status = RequestStatus.FINISHED_STOPPED + return True + return False - # Universal duplicate check for once semantics if request.request_id in self.transfer_triggered_requests: return False @@ -456,6 +467,23 @@ def update_from_output( kv_extracted_ids = getattr(model_runner_output, "kv_extracted_req_ids", None) if kv_extracted_ids: for req_id in kv_extracted_ids: + # Emit a kv_ready signal so the orchestrator can forward + # the request to the DiT stage immediately after KV + # extraction, without waiting for AR decode to finish. + req = self.requests.get(req_id) + if req is not None and not req.is_finished(): + eco = engine_core_outputs.get(req.client_index) + if eco is None: + eco = EngineCoreOutputs() + engine_core_outputs[req.client_index] = eco + eco.outputs.append( + EngineCoreOutput( + request_id=req_id, + new_token_ids=[], + kv_transfer_params={"kv_ready": True}, + ) + ) + # Mark transfer as finished if req_id in self.active_kv_transfers: self.active_kv_transfers.remove(req_id) diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py index e6373ec96ea..8128c25c645 100644 --- a/vllm_omni/engine/orchestrator.py +++ b/vllm_omni/engine/orchestrator.py @@ -268,6 +268,9 @@ async def _orchestration_loop(self) -> None: continue idle = False + # Handle prefill-finished KV-ready signals before finished outputs. + await self._handle_kv_ready_raw_outputs(stage_id, raw_outputs) + # 2) Process raw outputs through the output processor request_outputs = await self._process_stage_outputs(stage_id, raw_outputs) @@ -313,25 +316,7 @@ async def _route_output( # CFG companion handling: companions don't produce user-visible output # and don't forward to the next stage directly. if finished and req_id in self._companion_ids: - parent_id = self._companion_to_parent.get(req_id) - if parent_id is not None: - self._companion_done.setdefault(parent_id, set()).add(req_id) - logger.debug( - "[Orchestrator] CFG companion %s done (parent=%s)", - req_id, - parent_id, - ) - # Check if parent is waiting and all companions are done - if parent_id in self._deferred_parents and self._all_companions_done(parent_id): - deferred = self._deferred_parents.pop(parent_id) - parent_state = self.request_states.get(parent_id) - if parent_state is not None: - await self._forward_to_next_stage( - parent_id, - deferred["stage_id"], - deferred["output"], - parent_state, - ) + await self._handle_cfg_companion_ready(req_id) self.request_states.pop(req_id, None) return @@ -358,17 +343,17 @@ async def _route_output( } ) - if finished and stage_id < req_state.final_stage_id and not self.async_chunk: - # If this parent has CFG companions, defer forwarding until all done + if ( + finished + and stage_id < req_state.final_stage_id + and not self.async_chunk + and not self._next_stage_already_submitted(stage_id, req_state) + ): if req_id in self._companion_map and not self._all_companions_done(req_id): self._deferred_parents[req_id] = { "stage_id": stage_id, "output": output, } - logger.debug( - "[Orchestrator] Parent %s deferred, waiting for CFG companions", - req_id, - ) else: await self._forward_to_next_stage(req_id, stage_id, output, req_state) @@ -393,6 +378,56 @@ def _all_companions_done(self, parent_id: str) -> bool: done_set = self._companion_done.get(parent_id, set()) return all(cid in done_set for cid in role_map.values()) + def _next_stage_already_submitted(self, stage_id: int, req_state: OrchestratorRequestState) -> bool: + return (stage_id + 1) in req_state.stage_submit_ts + + async def _handle_cfg_companion_ready(self, req_id: str) -> None: + """Mark a CFG companion as done; if all companions are done, flush deferred parent.""" + parent_id = self._companion_to_parent.get(req_id) + if parent_id is None: + return + done_set = self._companion_done.setdefault(parent_id, set()) + if req_id in done_set: + return + done_set.add(req_id) + if parent_id in self._deferred_parents and self._all_companions_done(parent_id): + deferred = self._deferred_parents.pop(parent_id) + parent_state = self.request_states.get(parent_id) + if parent_state is not None and not self._next_stage_already_submitted(deferred["stage_id"], parent_state): + await self._forward_to_next_stage( + parent_id, + deferred["stage_id"], + deferred["output"], + parent_state, + ) + + async def _handle_kv_ready_raw_outputs(self, stage_id: int, raw_outputs: EngineCoreOutputs) -> None: + """Forward split requests once stage-0 KV is ready, not only when decode fully finishes.""" + if self.async_chunk: + return + for raw_output in raw_outputs.outputs: + kv_params = getattr(raw_output, "kv_transfer_params", None) + if not (isinstance(kv_params, dict) and kv_params.get("kv_ready")): + continue + req_id = raw_output.request_id + req_state = self.request_states.get(req_id) + if req_state is None: + continue + if req_id in self._companion_ids: + await self._handle_cfg_companion_ready(req_id) + continue + if stage_id >= req_state.final_stage_id: + continue + if self._next_stage_already_submitted(stage_id, req_state): + continue + if req_id in self._companion_map and not self._all_companions_done(req_id): + self._deferred_parents[req_id] = { + "stage_id": stage_id, + "output": raw_output, + } + else: + await self._forward_to_next_stage(req_id, stage_id, raw_output, req_state) + def _build_stage_metrics( self, stage_id: int, diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 7354b573f61..527947be922 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -276,6 +276,7 @@ async def create_chat_completion( output_modalities if output_modalities is not None else self.engine_client.output_modalities ) + num_inference_steps = None # Omni multistage image generation: Stage-0 (AR) should receive a clean # text prompt (and optional conditioning image/size) so the model's own # processor can construct the correct inputs. @@ -309,6 +310,12 @@ async def create_chat_completion( extra_body = request.model_extra or {} height = extra_body.get("height") width = extra_body.get("width") + num_inference_steps = extra_body.get("num_inference_steps") + if num_inference_steps is not None: + try: + num_inference_steps = int(num_inference_steps) + except Exception: + num_inference_steps = None if "size" in extra_body: try: size_str = extra_body["size"] @@ -372,14 +379,15 @@ async def create_chat_completion( # Use standard OpenAI API parameters for comprehension stage sampling_params_list = self._build_sampling_params_list_from_request(request) - # Apply user-specified height/width to diffusion stage(s) for image generation - if _image_gen_height is not None or _image_gen_width is not None: + # Apply user-specified overrides to diffusion stage(s) for image generation + if _image_gen_height is not None or _image_gen_width is not None or num_inference_steps is not None: for idx, sp in enumerate(sampling_params_list): - # Diffusion stages typically have height/width attributes if hasattr(sp, "height") and _image_gen_height is not None: sp.height = _image_gen_height if hasattr(sp, "width") and _image_gen_width is not None: sp.width = _image_gen_width + if hasattr(sp, "num_inference_steps") and num_inference_steps is not None: + sp.num_inference_steps = num_inference_steps self._log_inputs( request_id,