vllm_ascend/core/recompute_scheduler.py (166 changes: 76 additions & 90 deletions)
@@ -22,7 +22,6 @@
 from collections import defaultdict
 from dataclasses import dataclass, fields
 
-import numpy as np
 from vllm._bc_linter import bc_linter_include
 from vllm.config import SchedulerConfig, VllmConfig
 from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
@@ -40,7 +39,6 @@
 from vllm.v1.metrics.perf import PerfStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
-from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.utils import ConstantList, record_function_or_nullcontext

@@ -84,27 +82,6 @@ class RecomputeSchedulerOutput(SchedulerOutput):
 class RecomputeScheduler(Scheduler):
     running: list[Request]
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # When is_mtp_kv_consumer is true, we will fill request.spec_token_ids
-        # with placeholder tokens to enable full graph when decode nodes pull
-        # the KV cache of one request from prefill nodes.
-        self.is_mtp_kv_consumer = (
-            self.vllm_config.speculative_config
-            and self.vllm_config.kv_transfer_config
-            and self.vllm_config.kv_transfer_config.is_kv_consumer
-        )
-
-    def add_request(self, request: Request) -> None:
-        # Fill in placeholder tokens to enable full graph compatibility. Without
-        # placeholders, graph matching may fail, forcing eager mode execution.
-        if self.is_mtp_kv_consumer:
-            request.spec_token_ids = [PLACEHOLDER_TOKEN_ID] * self.num_spec_tokens
-        self.waiting.add_request(request)
-        self.requests[request.request_id] = request
-        if self.log_stats:
-            request.record_event(EngineCoreEventType.QUEUED)
-
     def schedule(self) -> RecomputeSchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
         # There's no "decoding phase" nor "prefill phase" in the scheduler.
@@ -185,6 +162,9 @@ def schedule(self) -> RecomputeSchedulerOutput:
                 shift_computed_tokens=1 if self.use_eagle else 0,
             )
 
+            if self.need_mamba_block_aligned_split:
+                num_new_tokens = self._mamba_block_aligned_split(request, num_new_tokens)
+
             if num_new_tokens == 0:
                 # The request cannot be scheduled because one of the following
                 # reasons:
@@ -195,6 +175,8 @@
                 #    its max_total_tokens or max_model_len.
                 # 2. The encoder budget is exhausted.
                 # 3. The encoder cache is exhausted.
+                # 4. Insufficient budget for a block-aligned chunk in hybrid
+                #    models with mamba cache mode "align".
                 # NOTE(woosuk): Here, by doing `continue` instead of `break`,
                 # we do not strictly follow the FCFS scheduling policy and
                 # allow the lower-priority requests to be scheduled.
@@ -237,12 +219,12 @@ def schedule(self) -> RecomputeSchedulerOutput:
                     )
                     self.running.remove(preempted_req)
                     if preempted_req in scheduled_running_reqs:
+                        preempted_req_id = preempted_req.request_id
                         scheduled_running_reqs.remove(preempted_req)
-                        token_budget += num_scheduled_tokens[preempted_req.request_id]
-                        req_to_new_blocks.pop(preempted_req.request_id)
-                        num_scheduled_tokens.pop(preempted_req.request_id)
-                        scheduled_spec_decode_tokens.pop(preempted_req.request_id, None)
-                        preempted_encoder_inputs = scheduled_encoder_inputs.pop(preempted_req.request_id, None)
+                        token_budget += num_scheduled_tokens.pop(preempted_req_id)
+                        req_to_new_blocks.pop(preempted_req_id)
+                        scheduled_spec_decode_tokens.pop(preempted_req_id, None)
+                        preempted_encoder_inputs = scheduled_encoder_inputs.pop(preempted_req_id, None)
                         if preempted_encoder_inputs:
                             # Restore encoder compute budget if the preempted
                             # request had encoder inputs scheduled in this step.
@@ -266,8 +248,9 @@
 
             # Schedule the request.
             scheduled_running_reqs.append(request)
-            req_to_new_blocks[request.request_id] = new_blocks
-            num_scheduled_tokens[request.request_id] = num_new_tokens
+            request_id = request.request_id
+            req_to_new_blocks[request_id] = new_blocks
+            num_scheduled_tokens[request_id] = num_new_tokens
             token_budget -= num_new_tokens
             req_index += 1
 
@@ -277,16 +260,18 @@
                     num_new_tokens + request.num_computed_tokens - request.num_tokens - request.num_output_placeholders
                 )
                 if num_scheduled_spec_tokens > 0:
-                    # Trim spec_token_ids list to num_scheduled_spec_tokens.
-                    del request.spec_token_ids[num_scheduled_spec_tokens:]
-                    scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids
+                    spec_token_ids = request.spec_token_ids
+                    if len(spec_token_ids) > num_scheduled_spec_tokens:
+                        spec_token_ids = spec_token_ids[:num_scheduled_spec_tokens]
+                    scheduled_spec_decode_tokens[request.request_id] = spec_token_ids
+
                 # New spec tokens will be set in `update_draft_token_ids` before the
                 # next step when applicable.
                 request.spec_token_ids = []
 
             # Encoder-related.
             if encoder_inputs_to_schedule:
-                scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
+                scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
                 # Allocate the encoder cache.
                 for i in encoder_inputs_to_schedule:
                     self.encoder_cache_manager.allocate(request, i)
@@ -318,6 +303,7 @@ def schedule(self) -> RecomputeSchedulerOutput:
                 break
 
             request = self.waiting.peek_request()
+            request_id = request.request_id
 
             # KVTransfer: skip request if still waiting for remote kvs.
             if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
@@ -332,7 +318,7 @@
                 else:
                     logger.debug(
                         "%s is still in WAITING_FOR_REMOTE_KVS state.",
-                        request.request_id,
+                        request_id,
                     )
                     self.waiting.pop_request()
                     skipped_waiting_requests.prepend_request(request)
@@ -349,6 +335,13 @@
                 skipped_waiting_requests.prepend_request(request)
                 continue
 
+            # Streaming: skip request if still waiting for next streaming req.
+            if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
+                assert not request.streaming_queue
+                self.waiting.pop_request()
+                skipped_waiting_requests.prepend_request(request)
+                continue
+
             # Check that adding the request still respects the max_loras
             # constraint.
             if (
@@ -366,6 +359,7 @@
 
             num_external_computed_tokens = 0
             load_kv_async = False
+            connector_prefix_cache_queries, connector_prefix_cache_hits = 0, 0
 
             # Get already-cached tokens.
             if request.num_computed_tokens == 0:
@@ -391,6 +385,9 @@
                     request.num_external_computed_tokens = ext_tokens
                     num_external_computed_tokens = ext_tokens
 
+                    connector_prefix_cache_queries = request.num_tokens - num_new_local_computed_tokens
+                    connector_prefix_cache_hits = num_external_computed_tokens
+
                 # Total computed tokens (local + external).
                 num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens
             else:
@@ -413,10 +410,7 @@
             # We use `request.num_tokens` instead of
             # `request.num_prompt_tokens` to consider the resumed
             # requests, which have output tokens.
-            if self.is_mtp_kv_consumer:
-                num_new_tokens = request.num_tokens_with_spec - num_computed_tokens
-            else:
-                num_new_tokens = request.num_tokens - num_computed_tokens
+            num_new_tokens = request.num_tokens - num_computed_tokens
             threshold = self.scheduler_config.long_prefill_token_threshold
             if 0 < threshold < num_new_tokens:
                 num_new_tokens = threshold
@@ -449,6 +443,16 @@ def schedule(self) -> RecomputeSchedulerOutput:
                 # The request cannot be scheduled.
                 break
 
+            if self.need_mamba_block_aligned_split:
+                num_new_tokens = self._mamba_block_aligned_split(
+                    request,
+                    num_new_tokens,
+                    num_new_local_computed_tokens,
+                    num_external_computed_tokens,
+                )
+                if num_new_tokens == 0:
+                    break
+
             # Handles an edge case when P/D Disaggregation
             # is used with Spec Decoding where an
             # extra block gets allocated which
@@ -487,9 +491,15 @@ def schedule(self) -> RecomputeSchedulerOutput:
             if self.connector is not None:
                 self.connector.update_state_after_alloc(
                     request,
-                    self.kv_cache_manager.get_blocks(request.request_id),
+                    self.kv_cache_manager.get_blocks(request_id),
                     num_external_computed_tokens,
                 )
+            if self.connector_prefix_cache_stats is not None and connector_prefix_cache_queries != 0:
+                self.connector_prefix_cache_stats.record(
+                    num_tokens=connector_prefix_cache_queries,
+                    num_hits=connector_prefix_cache_hits,
+                    preempted=request.num_preemptions > 0,
+                )
 
             # Request was already popped from self.waiting
             # unless it was re-added above due to new_blocks being None.
@@ -501,25 +511,6 @@
                 request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
                 continue
 
-            self._update_connector_prefix_cache_stats(request)
-
-            # For spec_token_ids, the waiting queue has the same processing
-            # as the running queue.
-            if self.is_mtp_kv_consumer and request.spec_token_ids:
-                num_scheduled_spec_tokens = (
-                    num_new_tokens
-                    + request.num_computed_tokens
-                    - request.num_tokens
-                    - request.num_output_placeholders
-                )
-                if num_scheduled_spec_tokens > 0:
-                    # Trim spec_token_ids list to num_scheduled_spec_tokens.
-                    del request.spec_token_ids[num_scheduled_spec_tokens:]
-                    scheduled_spec_decode_tokens[request.request_id] = request.spec_token_ids
-                # New spec tokens will be set in `update_draft_token_ids` before the
-                # next step when applicable.
-                request.spec_token_ids = []
-
             self.running.append(request)
             if self.log_stats:
                 request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
@@ -532,8 +523,8 @@
 
             if self.lora_config and request.lora_request:
                 scheduled_loras.add(request.lora_request.lora_int_id)
-            req_to_new_blocks[request.request_id] = self.kv_cache_manager.get_blocks(request.request_id)
-            num_scheduled_tokens[request.request_id] = num_new_tokens
+            req_to_new_blocks[request_id] = self.kv_cache_manager.get_blocks(request_id)
+            num_scheduled_tokens[request_id] = num_new_tokens
             token_budget -= num_new_tokens
             request.status = RequestStatus.RUNNING
             request.num_computed_tokens = num_computed_tokens
@@ -542,7 +533,7 @@
             request.num_cached_tokens = num_computed_tokens
             # Encoder-related.
             if encoder_inputs_to_schedule:
-                scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
+                scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
                 # Allocate the encoder cache.
                 for i in encoder_inputs_to_schedule:
                     self.encoder_cache_manager.allocate(request, i)
@@ -573,8 +564,8 @@ def schedule(self) -> RecomputeSchedulerOutput:
         num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
         with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
             if self.running:
-                any_request = self.running[0]
-                num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request.request_id)
+                any_request_id = self.running[0].request_id
+                num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request_id)
 
         # Construct the scheduler output.
         if self.use_v2_model_runner:
@@ -644,7 +635,7 @@ def schedule(self) -> RecomputeSchedulerOutput:
 
     def update_from_output(
         self,
-        scheduler_output: RecomputeSchedulerOutput,
+        scheduler_output: SchedulerOutput,
Review comment from a Contributor (severity: critical):

The type hint for scheduler_output has been changed to SchedulerOutput. However, the method body at line 673 accesses scheduler_output.recomputed_reqs, which is an attribute specific to the RecomputeSchedulerOutput subclass. This will cause an AttributeError at runtime because SchedulerOutput does not have this attribute. To fix this bug, the type hint should be reverted to RecomputeSchedulerOutput.

Suggested change:
-        scheduler_output: SchedulerOutput,
+        scheduler_output: RecomputeSchedulerOutput,
         model_runner_output: ModelRunnerOutput,
     ) -> dict[int, EngineCoreOutputs]:
         sampled_token_ids = model_runner_output.sampled_token_ids
Expand Down Expand Up @@ -700,17 +691,21 @@ def update_from_output(
# skip failed or rescheduled requests from KV load failure
continue
request = self.requests.get(req_id)
if request is None:
if request is None or request.is_finished():
# The request is already finished. This can happen if the
# request is aborted while the model is executing it (e.g.,
# in pipeline parallelism).
# in pipeline parallelism or in async scheduling).
# NOTE(Kuntai): When delay_free_blocks=True (for async KV
# cache transfer in KV connector), the aborted request will not
# be set to None (in order to finish async KV transfer).
# In this case, we use is_finished() to check.
continue

req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[req_index] if sampled_token_ids else []

scheduled_spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(req_id)
if scheduled_spec_token_ids:
if scheduled_spec_token_ids and generated_token_ids:
num_draft_tokens = len(scheduled_spec_token_ids)
num_accepted = len(generated_token_ids) - 1
num_rejected = num_draft_tokens - num_accepted
@@ -749,27 +744,17 @@ def update_from_output(
                 stopped = True
 
             routed_experts = None
+            finish_reason = None
             if stopped:
                 if self.vllm_config.model_config.enable_return_routed_experts:
-                    kv_blocks = self.kv_cache_manager.get_blocks(request.request_id)
-                    block_ids = kv_blocks.get_block_ids()[0]
-                    num_tokens = request.num_tokens - 1
-
-                    # compute slot mapping
-                    block_ids_array = np.array(block_ids, dtype=np.int32)
-                    num_blocks = len(block_ids)
-                    block_size = self.block_size
-
-                    # generate block offsets
-                    block_offsets = np.arange(0, block_size)
-
-                    # compute slot mapping: slot = block_id * block_size + offset
-                    slot_mapping = (
-                        block_offsets.reshape((1, block_size)) + block_ids_array.reshape((num_blocks, 1)) * block_size
-                    ).flatten()[:num_tokens]
-
-                    routed_experts = self.routed_experts_reader.get_routed_experts(indices=slot_mapping)
-                kv_transfer_params = self._free_request(request)
+                    routed_experts = self._get_routed_experts(request)
+
+                # Capture finish_reason BEFORE _handle_stopped_request, which may
+                # reset the status to WAITING for streaming requests that continue.
+                finish_reason = request.get_finished_reason()
+                finished = self._handle_stopped_request(request)
+                if finished:
+                    kv_transfer_params = self._free_request(request)
 
             if status_before_stop == RequestStatus.RUNNING:
                 stopped_running_reqs.add(request)
             else:
@@ -796,13 +781,13 @@
 
             # Get prompt logprobs for this request.
             prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
-            if new_token_ids or pooler_output is not None or kv_transfer_params:
+            if new_token_ids or pooler_output is not None or kv_transfer_params or stopped:
                 # Add EngineCoreOutput for this Request.
                 outputs[request.client_index].append(
                     EngineCoreOutput(
                         request_id=req_id,
                         new_token_ids=new_token_ids,
-                        finish_reason=request.get_finished_reason(),
+                        finish_reason=finish_reason,
                         new_logprobs=new_logprobs,
                         new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                         pooling_output=pooler_output,
@@ -811,6 +796,7 @@
                         kv_transfer_params=kv_transfer_params,
                         trace_headers=request.trace_headers,
                         num_cached_tokens=request.num_cached_tokens,
+                        num_external_computed_tokens=request.num_external_computed_tokens,
                         routed_experts=routed_experts,
                         num_nans_in_logits=request.num_nans_in_logits,
                     )