Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions vllm_omni/core/sched/omni_ar_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from vllm.logger import init_logger
from vllm.v1.core.sched.async_scheduler import AsyncScheduler as VLLMScheduler
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as SyncScheduler
from vllm.v1.core.sched.utils import remove_all
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.metrics.perf import PerfStats
Expand Down Expand Up @@ -80,6 +81,31 @@ def __init__(self, *args, **kwargs):
# Snapshot prompt length for each streaming input update
self._new_prompt_len_snapshot: dict[str, int] = {}

def _get_confirmed_num_computed_tokens(self, request: Request) -> int:
    """Return the number of tokens whose KV data is actually on the GPU.

    ``request.num_computed_tokens`` also counts async-scheduling
    placeholders (``request.num_output_placeholders``), i.e. tokens that
    were scheduled but whose output has not been confirmed yet. Those are
    subtracted so callers only see the confirmed count.

    The result is clamped at zero. If ``num_computed_tokens`` is reset
    while ``num_output_placeholders`` is left untouched, the raw
    difference becomes negative, and a negative token count is never a
    meaningful sequence length for KV transfer or block caching.
    NOTE(review): a streaming-input session replacement is reported to
    reset only ``num_computed_tokens`` -- confirm and reset the
    placeholders there as well.

    Args:
        request: The request to inspect; only ``num_computed_tokens`` and
            ``num_output_placeholders`` are read.

    Returns:
        The confirmed computed-token count, never less than 0.
    """
    confirmed = request.num_computed_tokens - request.num_output_placeholders
    return max(confirmed, 0)

def _update_request_with_output(self, request: Request, new_token_ids: list[int]) -> tuple[list[int], bool]:
    """Apply freshly sampled tokens to ``request`` and cache confirmed blocks.

    The token append itself is delegated to the synchronous scheduler's
    ``_update_request_with_output``. Afterwards the async placeholder
    counter is reduced by the number of tokens that call returned, and
    blocks are cached only up to the confirmed computed-token count, so KV
    transfer never sees blocks whose data has not been computed yet.

    Args:
        request: The request being updated.
        new_token_ids: Token ids produced for this request in this step.

    Returns:
        ``(token_ids, stopped)`` as returned by the synchronous scheduler,
        or ``([], False)`` when the latest async tokens are discarded.
    """
    # A pending discard swallows this batch of tokens and clears the flag.
    if request.discard_latest_async_tokens:
        request.discard_latest_async_tokens = False
        return [], False

    # Capture the status first: the delegated update below may change it.
    was_running = request.status == RequestStatus.RUNNING

    accepted_ids, stopped = SyncScheduler._update_request_with_output(self, request, new_token_ids)

    # Each returned token replaces one previously scheduled placeholder.
    request.num_output_placeholders = request.num_output_placeholders - len(accepted_ids)
    assert request.num_output_placeholders >= 0

    # Only requests that were running before the update get blocks cached.
    if was_running:
        confirmed_count = self._get_confirmed_num_computed_tokens(request)
        self.kv_cache_manager.cache_blocks(request, confirmed_count)

    return accepted_ids, stopped

def _get_kv_transfer_criteria(self) -> dict | None:
# Note: vllm_config is available in Scheduler after super().__init__
if not hasattr(self, "vllm_config"):
Expand Down Expand Up @@ -146,10 +172,13 @@ def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int
return True
return False

# seq_len for KV transfer must exclude async placeholders.
confirmed_computed = self._get_confirmed_num_computed_tokens(request)

if criteria_type == "prefill_finished":
if request.num_computed_tokens >= request.num_prompt_tokens:
if confirmed_computed >= request.num_prompt_tokens:
self.transfer_triggered_requests.add(request.request_id)
self._mark_request_for_kv_transfer(request.request_id, request.num_computed_tokens)
self._mark_request_for_kv_transfer(request.request_id, confirmed_computed)
actually_queued = request.request_id in self.requests_needing_kv_transfer

if stop_decode_on_trigger and actually_queued:
Expand All @@ -169,9 +198,9 @@ def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int
try:
idx = new_token_ids.index(target_token_id)
tokens_to_exclude = len(new_token_ids) - (idx + 1)
snapshot_len = request.num_computed_tokens - tokens_to_exclude
snapshot_len = confirmed_computed - tokens_to_exclude
except ValueError:
snapshot_len = request.num_computed_tokens
snapshot_len = confirmed_computed

self._mark_request_for_kv_transfer(request.request_id, snapshot_len)
actually_queued = request.request_id in self.requests_needing_kv_transfer
Expand Down Expand Up @@ -622,7 +651,8 @@ def _free_request(self, request: Request, delay_free_blocks: bool = False) -> di
)
else:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related: `_replace_session_with_streaming_update` in `omni_scheduler_mixin.py` resets `num_computed_tokens` to 0 but does not reset `num_output_placeholders`. If this helper is ever called on a request that went through that path, it would return a negative value. It is probably worth resetting `num_output_placeholders` to 0 there as well, for consistency.

self.waiting_for_transfer_free.add(request_id)
self._mark_request_for_kv_transfer(request_id, request.num_computed_tokens)
confirmed_computed = self._get_confirmed_num_computed_tokens(request)
self._mark_request_for_kv_transfer(request_id, confirmed_computed)
# Return KV transfer metadata so it propagates to RequestOutput
if request_id in self.requests_needing_kv_transfer:
transfer_data = self.requests_needing_kv_transfer[request_id]
Expand Down
8 changes: 7 additions & 1 deletion vllm_omni/model_executor/models/bagel/bagel.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,12 @@ def get_kv_transfer_metadata(
return None
if num_computed_tokens is not None and "image_shape" in meta:
prefill_rope = meta["ropes"][0] if meta.get("ropes") else 0
if num_computed_tokens > prefill_rope:
prefill_position_count = meta.get("prefill_position_count")
if prefill_position_count is not None:
num_decoded = num_computed_tokens - prefill_position_count
if num_decoded > 0:
meta["ropes"] = [prefill_rope + num_decoded]
elif num_computed_tokens > prefill_rope:
meta["ropes"] = [num_computed_tokens]
return meta

Expand Down Expand Up @@ -849,6 +854,7 @@ def _adjust_positions_for_img2img(
{
"ropes": [rope],
"image_shape": [img_H, img_W],
"prefill_position_count": int(end - start),
}
)
img2img_idx += 1
Expand Down
Loading