Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/offline_inference/bagel/end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ def main():

params_list = omni.default_sampling_params_list
if args.modality in ("text2img", "img2img"):
params_list[0].max_tokens = 1 # type: ignore
if len(params_list) > 1:
diffusion_params = params_list[1]
diffusion_params.num_inference_steps = args.steps # type: ignore
Expand Down
4 changes: 1 addition & 3 deletions tests/e2e/offline_inference/test_bagel_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,19 +79,17 @@ def _find_free_port() -> int:
return port


def _configure_sampling_params(omni: Omni, max_tokens: int = 1, num_inference_steps: int = 15) -> list:
def _configure_sampling_params(omni: Omni, num_inference_steps: int = 15) -> list:
"""Configure sampling parameters for Bagel img2img generation.

Args:
omni: The Omni instance to get default params from.
max_tokens: Maximum tokens for the first stage.
num_inference_steps: Number of inference steps for the diffusion stage.

Returns:
Configured sampling params list.
"""
params_list = omni.default_sampling_params_list
params_list[0].max_tokens = max_tokens # type: ignore
if len(params_list) > 1:
params_list[1].num_inference_steps = num_inference_steps # type: ignore
params_list[1].extra_args = { # type: ignore
Expand Down
4 changes: 1 addition & 3 deletions tests/e2e/offline_inference/test_bagel_text2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,19 +80,17 @@ def _find_free_port() -> int:
return port


def _configure_sampling_params(omni: Omni, max_tokens: int = 1, num_inference_steps: int = 15) -> list:
def _configure_sampling_params(omni: Omni, num_inference_steps: int = 15) -> list:
"""Configure sampling parameters for Bagel text2img generation.

Args:
omni: The Omni instance to get default params from.
max_tokens: Maximum tokens for the first stage.
num_inference_steps: Number of inference steps for the diffusion stage.

Returns:
Configured sampling params list.
"""
params_list = omni.default_sampling_params_list
params_list[0].max_tokens = max_tokens # type: ignore
if len(params_list) > 1:
params_list[1].num_inference_steps = num_inference_steps # type: ignore
params_list[1].extra_args = { # type: ignore
Expand Down
1 change: 0 additions & 1 deletion tests/e2e/offline_inference/test_quantization_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def _generate_bagel_image(
torch.cuda.reset_peak_memory_stats()

params_list = omni.default_sampling_params_list
params_list[0].max_tokens = 1 # type: ignore
if len(params_list) > 1:
params_list[1].num_inference_steps = num_inference_steps # type: ignore
params_list[1].extra_args = { # type: ignore
Expand Down
30 changes: 29 additions & 1 deletion vllm_omni/core/sched/omni_ar_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,19 @@ def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int
return False

criteria_type = self.kv_transfer_criteria.get("type")
if (
self.kv_transfer_criteria.get("stop_after_transfer", True)
and request.request_id in self.transfer_triggered_requests
):
# For split pipelines that only need the transferred KV
# snapshot, stop AR decode once KV extraction has completed.
# This frees stage-0 resources without requiring an
# orchestrator-side abort.
if request.request_id not in self.active_kv_transfers:
request.status = RequestStatus.FINISHED_STOPPED
return True
return False

# Universal duplicate check for once semantics
if request.request_id in self.transfer_triggered_requests:
return False

Expand Down Expand Up @@ -456,6 +467,23 @@ def update_from_output(
kv_extracted_ids = getattr(model_runner_output, "kv_extracted_req_ids", None)
if kv_extracted_ids:
for req_id in kv_extracted_ids:
# Emit a kv_ready signal so the orchestrator can forward
# the request to the DiT stage immediately after KV
# extraction, without waiting for AR decode to finish.
req = self.requests.get(req_id)
if req is not None and not req.is_finished():
eco = engine_core_outputs.get(req.client_index)
if eco is None:
eco = EngineCoreOutputs()
engine_core_outputs[req.client_index] = eco
eco.outputs.append(
EngineCoreOutput(
request_id=req_id,
new_token_ids=[],
kv_transfer_params={"kv_ready": True},
)
)

# Mark transfer as finished
if req_id in self.active_kv_transfers:
self.active_kv_transfers.remove(req_id)
Expand Down
85 changes: 60 additions & 25 deletions vllm_omni/engine/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,9 @@ async def _orchestration_loop(self) -> None:
continue
idle = False

# Handle prefill-finished KV-ready signals before finished outputs.
await self._handle_kv_ready_raw_outputs(stage_id, raw_outputs)

# 2) Process raw outputs through the output processor
request_outputs = await self._process_stage_outputs(stage_id, raw_outputs)

Expand Down Expand Up @@ -313,25 +316,7 @@ async def _route_output(
# CFG companion handling: companions don't produce user-visible output
# and don't forward to the next stage directly.
if finished and req_id in self._companion_ids:
parent_id = self._companion_to_parent.get(req_id)
if parent_id is not None:
self._companion_done.setdefault(parent_id, set()).add(req_id)
logger.debug(
"[Orchestrator] CFG companion %s done (parent=%s)",
req_id,
parent_id,
)
# Check if parent is waiting and all companions are done
if parent_id in self._deferred_parents and self._all_companions_done(parent_id):
deferred = self._deferred_parents.pop(parent_id)
parent_state = self.request_states.get(parent_id)
if parent_state is not None:
await self._forward_to_next_stage(
parent_id,
deferred["stage_id"],
deferred["output"],
parent_state,
)
await self._handle_cfg_companion_ready(req_id)
self.request_states.pop(req_id, None)
return

Expand All @@ -358,17 +343,17 @@ async def _route_output(
}
)

if finished and stage_id < req_state.final_stage_id and not self.async_chunk:
# If this parent has CFG companions, defer forwarding until all done
if (
finished
and stage_id < req_state.final_stage_id
and not self.async_chunk
and not self._next_stage_already_submitted(stage_id, req_state)
):
if req_id in self._companion_map and not self._all_companions_done(req_id):
self._deferred_parents[req_id] = {
"stage_id": stage_id,
"output": output,
}
logger.debug(
"[Orchestrator] Parent %s deferred, waiting for CFG companions",
req_id,
)
else:
await self._forward_to_next_stage(req_id, stage_id, output, req_state)

Expand All @@ -393,6 +378,56 @@ def _all_companions_done(self, parent_id: str) -> bool:
done_set = self._companion_done.get(parent_id, set())
return all(cid in done_set for cid in role_map.values())

def _next_stage_already_submitted(self, stage_id: int, req_state: OrchestratorRequestState) -> bool:
return (stage_id + 1) in req_state.stage_submit_ts

async def _handle_cfg_companion_ready(self, req_id: str) -> None:
"""Mark a CFG companion as done; if all companions are done, flush deferred parent."""
parent_id = self._companion_to_parent.get(req_id)
if parent_id is None:
return
done_set = self._companion_done.setdefault(parent_id, set())
if req_id in done_set:
return
done_set.add(req_id)
if parent_id in self._deferred_parents and self._all_companions_done(parent_id):
deferred = self._deferred_parents.pop(parent_id)
parent_state = self.request_states.get(parent_id)
if parent_state is not None and not self._next_stage_already_submitted(deferred["stage_id"], parent_state):
await self._forward_to_next_stage(
parent_id,
deferred["stage_id"],
deferred["output"],
parent_state,
)

async def _handle_kv_ready_raw_outputs(self, stage_id: int, raw_outputs: EngineCoreOutputs) -> None:
"""Forward split requests once stage-0 KV is ready, not only when decode fully finishes."""
if self.async_chunk:
return
for raw_output in raw_outputs.outputs:
kv_params = getattr(raw_output, "kv_transfer_params", None)
if not (isinstance(kv_params, dict) and kv_params.get("kv_ready")):
continue
req_id = raw_output.request_id
req_state = self.request_states.get(req_id)
if req_state is None:
continue
if req_id in self._companion_ids:
await self._handle_cfg_companion_ready(req_id)
continue
if stage_id >= req_state.final_stage_id:
continue
if self._next_stage_already_submitted(stage_id, req_state):
continue
if req_id in self._companion_map and not self._all_companions_done(req_id):
self._deferred_parents[req_id] = {
"stage_id": stage_id,
"output": raw_output,
}
else:
await self._forward_to_next_stage(req_id, stage_id, raw_output, req_state)
Comment thread
princepride marked this conversation as resolved.

def _build_stage_metrics(
self,
stage_id: int,
Expand Down
14 changes: 11 additions & 3 deletions vllm_omni/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ async def create_chat_completion(
output_modalities if output_modalities is not None else self.engine_client.output_modalities
)

num_inference_steps = None
# Omni multistage image generation: Stage-0 (AR) should receive a clean
# text prompt (and optional conditioning image/size) so the model's own
# processor can construct the correct inputs.
Expand Down Expand Up @@ -309,6 +310,12 @@ async def create_chat_completion(
extra_body = request.model_extra or {}
height = extra_body.get("height")
width = extra_body.get("width")
num_inference_steps = extra_body.get("num_inference_steps")
if num_inference_steps is not None:
try:
num_inference_steps = int(num_inference_steps)
except Exception:
num_inference_steps = None
if "size" in extra_body:
try:
size_str = extra_body["size"]
Expand Down Expand Up @@ -372,14 +379,15 @@ async def create_chat_completion(
# Use standard OpenAI API parameters for comprehension stage
sampling_params_list = self._build_sampling_params_list_from_request(request)

# Apply user-specified height/width to diffusion stage(s) for image generation
if _image_gen_height is not None or _image_gen_width is not None:
# Apply user-specified overrides to diffusion stage(s) for image generation
if _image_gen_height is not None or _image_gen_width is not None or num_inference_steps is not None:
for idx, sp in enumerate(sampling_params_list):
# Diffusion stages typically have height/width attributes
if hasattr(sp, "height") and _image_gen_height is not None:
sp.height = _image_gen_height
if hasattr(sp, "width") and _image_gen_width is not None:
sp.width = _image_gen_width
if hasattr(sp, "num_inference_steps") and num_inference_steps is not None:
sp.num_inference_steps = num_inference_steps

self._log_inputs(
request_id,
Expand Down
Loading