Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 54 additions & 51 deletions vllm_omni/core/sched/omni_ar_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ def __init__(self, *args, **kwargs):
# Track ACTIVE transfers (submitted to runner but not yet acked via kv_extracted_req_ids)
self.active_kv_transfers: set[str] = set()

# Requests marked for deferred stop: keep running until KV extraction
# completes so that kv_ready can be emitted while the request is still
# alive. Stopped on the first scheduler step after extraction ack.
self.pending_stop_after_extraction: set[str] = set()

# [Omni] Pre-parse KV transfer criteria
self.kv_transfer_criteria = self._get_kv_transfer_criteria()

Expand Down Expand Up @@ -126,11 +131,16 @@ def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int
stop_decode_on_trigger = self.kv_transfer_criteria.get("stop_after_transfer", True)

if request.request_id in self.transfer_triggered_requests:
# Already triggered. When stop_decode_on_trigger is True AND
# transfer was actually queued, the request was already stopped
# at trigger time (see below). Any request that reaches this
# point either has stop_decode_on_trigger=False (continue
# decoding) or was not actually queued (should not be stopped).
# Deferred stop: once KV extraction is complete (no longer in
# active_kv_transfers), stop the request. This guarantees the
# kv_ready signal was emitted while the request was still alive.
if (
request.request_id in self.pending_stop_after_extraction
and request.request_id not in self.active_kv_transfers
):
self.pending_stop_after_extraction.discard(request.request_id)
request.status = RequestStatus.FINISHED_STOPPED
return True
return False

if criteria_type == "prefill_finished":
Expand All @@ -140,14 +150,11 @@ def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int
actually_queued = request.request_id in self.requests_needing_kv_transfer

if stop_decode_on_trigger and actually_queued:
# Stop immediately so the request is NOT scheduled in
# the next step, freeing scheduling budget for companion
# requests whose chunked-prefill boundaries must be
# deterministic. waiting_for_transfer_free keeps blocks
# alive until the model runner finishes KV extraction.
self.waiting_for_transfer_free.add(request.request_id)
request.status = RequestStatus.FINISHED_STOPPED
return True
# Defer the stop until KV extraction completes so that
# the kv_ready signal can be emitted while the request
# is still alive. The request will be stopped on the
# next scheduler step after extraction ack arrives.
self.pending_stop_after_extraction.add(request.request_id)

return False

Expand All @@ -167,9 +174,7 @@ def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int
actually_queued = request.request_id in self.requests_needing_kv_transfer

if stop_decode_on_trigger and actually_queued:
self.waiting_for_transfer_free.add(request.request_id)
request.status = RequestStatus.FINISHED_STOPPED
return True
self.pending_stop_after_extraction.add(request.request_id)

return False

Expand Down Expand Up @@ -268,6 +273,26 @@ def update_from_output(
num_scheduled_tokens,
)

# Pre-process KV extraction acks so that the per-request loop below
# can see up-to-date active_kv_transfers state and emit kv_ready
# signals while requests are still alive (before any deferred stop).
kv_extracted_ids = getattr(model_runner_output, "kv_extracted_req_ids", None)
if kv_extracted_ids:
for req_id in kv_extracted_ids:
try:
self.active_kv_transfers.discard(req_id)
req = self.requests.get(req_id)
if req is not None and not req.is_finished():
outputs[req.client_index].append(
EngineCoreOutput(
request_id=req_id,
new_token_ids=[],
kv_transfer_params={"kv_ready": True},
)
)
except Exception:
init_logger(__name__).exception("Failed to pre-process KV extraction for %s", req_id)

# NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
# the below loop can be a performance bottleneck. We should do our best
# to avoid expensive operations inside the loop.
Expand Down Expand Up @@ -436,6 +461,7 @@ def update_from_output(
self.transfer_triggered_requests.remove(req.request_id)
if req.request_id in self.active_kv_transfers:
self.active_kv_transfers.remove(req.request_id)
self.pending_stop_after_extraction.discard(req.request_id)

# Same for preempted
for req in stopped_preempted_reqs:
Expand All @@ -444,6 +470,8 @@ def update_from_output(
self.transfer_triggered_requests.remove(req.request_id)
if req.request_id in self.active_kv_transfers:
self.active_kv_transfers.remove(req.request_id)
self.pending_stop_after_extraction.discard(req.request_id)

# KV Connector: update state for finished KV Transfers.
if kv_connector_output:
self._update_from_kv_xfer_finished(kv_connector_output)
Expand Down Expand Up @@ -489,49 +517,25 @@ def update_from_output(
engine_core_outputs[0] = eco = EngineCoreOutputs()
eco.scheduler_stats = stats

# This is where we free blocks that were held for transfer
try:
kv_extracted_ids = getattr(model_runner_output, "kv_extracted_req_ids", None)
if kv_extracted_ids:
for req_id in kv_extracted_ids:
# Emit a kv_ready signal so the orchestrator can forward
# the request to the DiT stage immediately after KV
# extraction, without waiting for AR decode to finish.
req = self.requests.get(req_id)
if req is not None and not req.is_finished():
eco = engine_core_outputs.get(req.client_index)
if eco is None:
eco = EngineCoreOutputs()
engine_core_outputs[req.client_index] = eco
eco.outputs.append(
EngineCoreOutput(
request_id=req_id,
new_token_ids=[],
kv_transfer_params={"kv_ready": True},
)
)

# Mark transfer as finished
if req_id in self.active_kv_transfers:
self.active_kv_transfers.remove(req_id)
logger.debug(f"[Omni] KV Transfer finished for {req_id}")

# Free blocks that were held for transfer (kv_ready and
# active_kv_transfers updates already done before the per-request loop).
if kv_extracted_ids:
for req_id in kv_extracted_ids:
try:
if req_id in self.waiting_for_transfer_free:
# Now it's safe to free blocks
req = self.requests.get(req_id)
if req:
self.kv_cache_manager.free(req)
if req_id in self.requests:
del self.requests[req_id]
if req_id in self.transfer_triggered_requests:
self.transfer_triggered_requests.remove(req_id)
if req_id in self.active_kv_transfers:
self.active_kv_transfers.remove(req_id)

self.active_kv_transfers.discard(req_id)
self.pending_stop_after_extraction.discard(req_id)
logger.debug(f"Freed blocks for {req_id} after transfer extraction")
self.waiting_for_transfer_free.remove(req_id)
Comment thread
natureofnature marked this conversation as resolved.
except Exception:
init_logger(__name__).exception("Failed to process finished transfer requests")
except Exception:
init_logger(__name__).exception("Failed to free blocks for %s after transfer", req_id)

return engine_core_outputs

Expand Down Expand Up @@ -564,8 +568,7 @@ def _free_request(self, request: Request, delay_free_blocks: bool = False) -> di
kv_xfer_params = None
return kv_xfer_params
elif request_id in self.waiting_for_transfer_free:
# Stopped immediately by stop_decode_on_trigger; blocks are
# held until KV extraction completes in a future step.
# Blocks held until KV extraction completes in a future step.
return None
else:
logger.debug(
Expand Down
Loading
Loading