58 changes: 58 additions & 0 deletions vllm_ascend/core/recompute_scheduler.py
@@ -91,6 +91,11 @@ def __init__(self, *args, **kwargs):
            and self.vllm_config.kv_transfer_config
            and self.vllm_config.kv_transfer_config.is_kv_consumer
        )
        self.is_kv_producer = (
            self.vllm_config.kv_transfer_config
            and self.vllm_config.kv_transfer_config.is_kv_producer
        )
        self.is_hybrid_model = (
            "qwen3_next" in self.vllm_config.model_config.model_type
            or "qwen3_5" in self.vllm_config.model_config.model_type
        )

    def add_request(self, request: Request) -> None:
        existing = self.requests.get(request.request_id)
@@ -111,13 +116,66 @@ def add_request(self, request: Request) -> None:
        request.streaming_queue = deque()
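        # For hybrid models, the producer leaves the last prompt token unprefilled
        # so it is recomputed downstream (inferred from the consumer-side handling
        # in _update_waiting_for_remote_kv below).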
        if self.is_kv_producer and self.is_hybrid_model and request.num_tokens > 1:
            request.prompt_token_ids.pop()
            request._all_token_ids.pop()
            request.num_prompt_tokens -= 1
        # Fill in placeholder tokens to enable full graph compatibility. Without
        # placeholders, graph matching may fail, forcing eager mode execution.
        if self.is_mtp_kv_consumer:
            request.spec_token_ids = [PLACEHOLDER_TOKEN_ID] * self.num_spec_tokens
        self.waiting.add_request(request)
        self.requests[request.request_id] = request
        if self.log_stats:
            request.record_event(EngineCoreEventType.QUEUED)
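As a reading aid, here is a minimal, self-contained sketch of what the producer-side branch above does to a request's token bookkeeping. `ToyRequest` and the token values are invented stand-ins that mirror only the real `Request` fields mutated by `add_request`:

```python
from dataclasses import dataclass, field

@dataclass
class ToyRequest:
    # Invented stand-in mirroring only the fields touched by add_request.
    prompt_token_ids: list[int]
    _all_token_ids: list[int] = field(default_factory=list)
    num_prompt_tokens: int = 0

req = ToyRequest(prompt_token_ids=[11, 22, 33])
req._all_token_ids = list(req.prompt_token_ids)
req.num_prompt_tokens = len(req.prompt_token_ids)

# The hybrid-producer branch drops the trailing prompt token everywhere it is
# tracked, so only [11, 22] is prefilled on the producer side.
req.prompt_token_ids.pop()
req._all_token_ids.pop()
req.num_prompt_tokens -= 1
assert (req.prompt_token_ids, req.num_prompt_tokens) == ([11, 22], 2)
```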

    def _update_waiting_for_remote_kv(self, request: Request) -> bool:
        """
        KV Connector: check whether the request_id has finished receiving
        its remote KV.

        The finished_recving_kv_req_ids list is populated by the previous
        step()'s update_from_output, based on the worker-side connector.

        When the KV transfer is ready, we cache the blocks and the request
        state moves back from WAITING_FOR_REMOTE_KV to WAITING.
        """
        assert self.connector is not None
        if request.request_id not in self.finished_recving_kv_req_ids:
            return False

        if request.request_id in self.failed_recving_kv_req_ids:
            # Request had KV load failures; num_computed_tokens was already
            # updated in _update_requests_with_invalid_blocks.
            if request.num_computed_tokens:
                # Cache any valid computed tokens.
                self.kv_cache_manager.cache_blocks(request, request.num_computed_tokens)
            else:
                # No valid computed tokens, release allocated blocks.
                # There may be a local cache hit on retry.
                self.kv_cache_manager.free(request)

            self.failed_recving_kv_req_ids.remove(request.request_id)
        else:
            # Now that the blocks are ready, actually cache them.
            block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
            if len(block_ids) == 1:
                num_computed_tokens = len(block_ids[0]) * self.block_size
                # Handle the case where the request has fewer tokens than one block.
                num_computed_tokens = min(num_computed_tokens, request.num_tokens)
            else:
                num_computed_tokens = request.num_tokens
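            # Leave at least the last token uncomputed: the forward pass needs
            # one new token to produce the logits used for sampling.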
            if num_computed_tokens == request.num_tokens:
                num_computed_tokens -= 1
            # This will cache the blocks iff caching is enabled.
            self.kv_cache_manager.cache_blocks(request, num_computed_tokens)

        # Update the request state for scheduling.
        request.num_computed_tokens = num_computed_tokens

        # Return that we are ready.
        self.finished_recving_kv_req_ids.remove(request.request_id)
        return True
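To make the single-group arithmetic above concrete, here is the same calculation with assumed numbers (a block size of 128 and a 200-token request whose KV arrived in two blocks):

```python
# Assumed values for illustration only.
block_size = 128
num_request_tokens = 200      # request.num_tokens
block_ids = [[0, 1]]          # one KV cache group holding two block ids

num_computed_tokens = len(block_ids[0]) * block_size                # 2 * 128 = 256
num_computed_tokens = min(num_computed_tokens, num_request_tokens)  # capped at 200
if num_computed_tokens == num_request_tokens:
    num_computed_tokens -= 1  # 199: the final token is recomputed locally
assert num_computed_tokens == 199
```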
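For context, a simplified sketch of how a schedule() loop can drive this helper. The `RequestStatus` import path and the exact `WAITING_FOR_REMOTE_KV` member name are assumptions based on the docstring and upstream vLLM, and may differ across versions:

```python
from vllm.v1.request import RequestStatus  # assumed import path; varies by vLLM version

# Simplified sketch of the call site inside schedule(): requests parked in
# WAITING_FOR_REMOTE_KV are re-checked each step and released once their
# remote KV has landed.
for request in list(self.waiting):
    if request.status != RequestStatus.WAITING_FOR_REMOTE_KV:
        continue
    if self._update_waiting_for_remote_kv(request):
        # Transfer finished: blocks are cached and num_computed_tokens is set,
        # so the request can compete for scheduling again.
        request.status = RequestStatus.WAITING
    # Otherwise the KV is still in flight; leave the request parked.
```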

    def schedule(self) -> RecomputeSchedulerOutput:
        # NOTE(woosuk) on the scheduling algorithm:
        # There's no "decoding phase" nor "prefill phase" in the scheduler.