Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 30 additions & 10 deletions vllm_ascend/core/recompute_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.utils import ConstantList, record_function_or_nullcontext

from vllm_ascend.utils import vllm_version_is


# `spec_manager_map` in single_type_kv_cache_manager is a module-level dict
# whose keys are class objects bound at import time. When the async
Expand Down Expand Up @@ -207,9 +209,10 @@ def _update_waiting_for_remote_kv(self, request: Request) -> None:
# Update the request state for scheduling.
request.num_computed_tokens = num_computed_tokens

# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = request.num_computed_tokens
if vllm_version_is("0.19.0"):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The `vllm_version_is("0.19.0")` check uses exact equality (as defined in `vllm_ascend/utils.py`), so it returns False for any subsequent version (e.g., 0.19.1, 0.20.0). On those versions the scheduler skips the `num_cached_tokens` / `num_external_computed_tokens` updates guarded by this check and takes the `prefill_stats` path instead. If the API changes introduced in 0.19.0 persist in later versions, this will lead to runtime errors. Consider using a version comparison (e.g., >= 0.19.0) to ensure future compatibility.

# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = request.num_computed_tokens

self.finished_recving_kv_req_ids.remove(request.request_id)

Expand Down Expand Up @@ -498,7 +501,8 @@ def schedule(self) -> RecomputeSchedulerOutput:
step_skipped_waiting.prepend_request(request)
continue

request.num_external_computed_tokens = ext_tokens
if vllm_version_is("0.19.0"):
request.num_external_computed_tokens = ext_tokens
num_external_computed_tokens = ext_tokens

connector_prefix_cache_queries = request.num_tokens - num_new_local_computed_tokens
Expand All @@ -507,6 +511,13 @@ def schedule(self) -> RecomputeSchedulerOutput:
# Total computed tokens (local + external).
num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens
assert num_computed_tokens <= request.num_tokens

if not vllm_version_is("0.19.0") and request.prefill_stats is not None:
request.prefill_stats.set(
num_prompt_tokens=request.num_prompt_tokens,
num_local_cached_tokens=num_new_local_computed_tokens,
num_external_cached_tokens=num_external_computed_tokens,
)
else:
# KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed.
Expand Down Expand Up @@ -680,9 +691,10 @@ def schedule(self) -> RecomputeSchedulerOutput:
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
if vllm_version_is("0.19.0"):
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
Expand Down Expand Up @@ -943,6 +955,12 @@ def update_from_output(
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids or pooler_output is not None or kv_transfer_params or stopped:
# Add EngineCoreOutput for this Request.
prefill_kwargs: dict = {}
if not vllm_version_is("0.19.0"):
prefill_kwargs["prefill_stats"] = request.take_prefill_stats()
else:
prefill_kwargs["num_cached_tokens"] = request.num_cached_tokens
prefill_kwargs["num_external_computed_tokens"] = request.num_external_computed_tokens
outputs[request.client_index].append(
EngineCoreOutput(
request_id=req_id,
Expand All @@ -955,10 +973,9 @@ def update_from_output(
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
trace_headers=request.trace_headers,
num_cached_tokens=request.num_cached_tokens,
num_external_computed_tokens=request.num_external_computed_tokens,
routed_experts=routed_experts,
num_nans_in_logits=request.num_nans_in_logits,
**prefill_kwargs,
)
)
else:
Expand All @@ -976,14 +993,17 @@ def update_from_output(
requests = [self.requests[req_id] for req_id in failed_kv_load_req_ids]
self.finish_requests(failed_kv_load_req_ids, RequestStatus.FINISHED_ERROR)
for request in requests:
prefill_kwargs = {}
if vllm_version_is("0.19.0"):
prefill_kwargs["num_cached_tokens"] = request.num_cached_tokens
outputs[request.client_index].append(
EngineCoreOutput(
request_id=request.request_id,
new_token_ids=[],
finish_reason=request.get_finished_reason(),
events=request.take_events(),
trace_headers=request.trace_headers,
num_cached_tokens=request.num_cached_tokens,
**prefill_kwargs,
)
)

Expand Down
9 changes: 6 additions & 3 deletions vllm_ascend/core/scheduler_dynamic_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager

from vllm_ascend.utils import vllm_version_is


class BudgetRefiner:
"""This budget refiner can make dynamic adjustment to the token budget
Expand Down Expand Up @@ -488,9 +490,10 @@ def schedule(self) -> SchedulerOutput:
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
if vllm_version_is("0.19.0"):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Exact version matching with `vllm_version_is("0.19.0")` is fragile: the check returns False on patch releases (e.g., 0.19.1) and on newer minor versions (e.g., 0.20.0). This should be updated to a "greater than or equal to" comparison to avoid regressions in future vLLM releases.

# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule
Expand Down
16 changes: 13 additions & 3 deletions vllm_ascend/core/scheduler_profiling_chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from vllm.v1.utils import record_function_or_nullcontext

from vllm_ascend.core.profiling_chunk_predictor import ProfilingChunkManager
from vllm_ascend.utils import vllm_version_is


class ProfilingChunkScheduler(Scheduler):
Expand Down Expand Up @@ -482,13 +483,21 @@ def schedule(self) -> SchedulerOutput: # noqa: C901
step_skipped_waiting.prepend_request(request)
continue

request.num_external_computed_tokens = ext_tokens
if vllm_version_is("0.19.0"):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using exact version matching for feature availability will incorrectly return False for versions like 0.19.1 or 0.20.0. This will likely break the scheduler on any version newer than exactly 0.19.0.

request.num_external_computed_tokens = ext_tokens
num_external_computed_tokens = ext_tokens

connector_prefix_cache_queries = request.num_tokens - num_new_local_computed_tokens
connector_prefix_cache_hits = num_external_computed_tokens

num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens

if not vllm_version_is("0.19.0") and request.prefill_stats is not None:
request.prefill_stats.set(
num_prompt_tokens=request.num_prompt_tokens,
num_local_cached_tokens=num_new_local_computed_tokens,
num_external_cached_tokens=num_external_computed_tokens,
)
assert num_computed_tokens <= request.num_tokens
else:
new_computed_blocks = self.kv_cache_manager.empty_kv_cache_blocks
Expand Down Expand Up @@ -628,8 +637,9 @@ def schedule(self) -> SchedulerOutput: # noqa: C901
time_budget -= self.profiling_chunk_manager.predict_time(num_new_tokens, request.num_computed_tokens)
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
if vllm_version_is("0.19.0"):
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
for i in encoder_inputs_to_schedule:
Expand Down
19 changes: 15 additions & 4 deletions vllm_ascend/patch/platform/patch_balance_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import record_function_or_nullcontext

from vllm_ascend.utils import vllm_version_is


class BalanceScheduler(Scheduler):
def __init__(
Expand Down Expand Up @@ -349,14 +351,22 @@ def schedule(self) -> SchedulerOutput:
skipped_waiting_requests.prepend_request(request)
continue

request.num_external_computed_tokens = ext_tokens
if vllm_version_is("0.19.0"):
Comment thread
Potabk marked this conversation as resolved.
request.num_external_computed_tokens = ext_tokens
num_external_computed_tokens = ext_tokens

connector_prefix_cache_queries = request.num_tokens - num_new_local_computed_tokens
connector_prefix_cache_hits = num_external_computed_tokens

# Total computed tokens (local + external).
num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens

if not vllm_version_is("0.19.0") and request.prefill_stats is not None:
request.prefill_stats.set(
num_prompt_tokens=request.num_prompt_tokens,
num_local_cached_tokens=num_new_local_computed_tokens,
num_external_cached_tokens=num_external_computed_tokens,
)
else:
# KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed.
Expand Down Expand Up @@ -496,9 +506,10 @@ def schedule(self) -> SchedulerOutput:
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
if vllm_version_is("0.19.0"):
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
Expand Down
Loading