diff --git a/vllm_ascend/core/recompute_scheduler.py b/vllm_ascend/core/recompute_scheduler.py index b6fe93f839a..02493d1338e 100644 --- a/vllm_ascend/core/recompute_scheduler.py +++ b/vllm_ascend/core/recompute_scheduler.py @@ -46,6 +46,8 @@ from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.utils import ConstantList, record_function_or_nullcontext +from vllm_ascend.utils import vllm_version_is + # `spec_manager_map` in single_type_kv_cache_manager is a module-level dict # whose keys are class objects bound at import time. When the async @@ -207,9 +209,10 @@ def _update_waiting_for_remote_kv(self, request: Request) -> None: # Update the request state for scheduling. request.num_computed_tokens = num_computed_tokens - # Count the number of prefix cached tokens. - if request.num_cached_tokens < 0: - request.num_cached_tokens = request.num_computed_tokens + if vllm_version_is("0.19.0"): + # Count the number of prefix cached tokens. + if request.num_cached_tokens < 0: + request.num_cached_tokens = request.num_computed_tokens self.finished_recving_kv_req_ids.remove(request.request_id) @@ -498,7 +501,8 @@ def schedule(self) -> RecomputeSchedulerOutput: step_skipped_waiting.prepend_request(request) continue - request.num_external_computed_tokens = ext_tokens + if vllm_version_is("0.19.0"): + request.num_external_computed_tokens = ext_tokens num_external_computed_tokens = ext_tokens connector_prefix_cache_queries = request.num_tokens - num_new_local_computed_tokens @@ -507,6 +511,13 @@ def schedule(self) -> RecomputeSchedulerOutput: # Total computed tokens (local + external). num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens assert num_computed_tokens <= request.num_tokens + + if not vllm_version_is("0.19.0") and request.prefill_stats is not None: + request.prefill_stats.set( + num_prompt_tokens=request.num_prompt_tokens, + num_local_cached_tokens=num_new_local_computed_tokens, + num_external_cached_tokens=num_external_computed_tokens, + ) else: # KVTransfer: WAITING reqs have num_computed_tokens > 0 # after async KV recvs are completed. @@ -680,9 +691,10 @@ def schedule(self) -> RecomputeSchedulerOutput: token_budget -= num_new_tokens request.status = RequestStatus.RUNNING request.num_computed_tokens = num_computed_tokens - # Count the number of prefix cached tokens. - if request.num_cached_tokens < 0: - request.num_cached_tokens = num_computed_tokens + if vllm_version_is("0.19.0"): + # Count the number of prefix cached tokens. + if request.num_cached_tokens < 0: + request.num_cached_tokens = num_computed_tokens # Encoder-related. if encoder_inputs_to_schedule: scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule @@ -943,6 +955,12 @@ def update_from_output( prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) if new_token_ids or pooler_output is not None or kv_transfer_params or stopped: # Add EngineCoreOutput for this Request. + prefill_kwargs: dict = {} + if not vllm_version_is("0.19.0"): + prefill_kwargs["prefill_stats"] = request.take_prefill_stats() + else: + prefill_kwargs["num_cached_tokens"] = request.num_cached_tokens + prefill_kwargs["num_external_computed_tokens"] = request.num_external_computed_tokens outputs[request.client_index].append( EngineCoreOutput( request_id=req_id, @@ -955,10 +973,9 @@ def update_from_output( events=request.take_events(), kv_transfer_params=kv_transfer_params, trace_headers=request.trace_headers, - num_cached_tokens=request.num_cached_tokens, - num_external_computed_tokens=request.num_external_computed_tokens, routed_experts=routed_experts, num_nans_in_logits=request.num_nans_in_logits, + **prefill_kwargs, ) ) else: @@ -976,6 +993,9 @@ def update_from_output( requests = [self.requests[req_id] for req_id in failed_kv_load_req_ids] self.finish_requests(failed_kv_load_req_ids, RequestStatus.FINISHED_ERROR) for request in requests: + prefill_kwargs = {} + if vllm_version_is("0.19.0"): + prefill_kwargs["num_cached_tokens"] = request.num_cached_tokens outputs[request.client_index].append( EngineCoreOutput( request_id=request.request_id, @@ -983,7 +1003,7 @@ def update_from_output( finish_reason=request.get_finished_reason(), events=request.take_events(), trace_headers=request.trace_headers, - num_cached_tokens=request.num_cached_tokens, + **prefill_kwargs, ) ) diff --git a/vllm_ascend/core/scheduler_dynamic_batch.py b/vllm_ascend/core/scheduler_dynamic_batch.py index 6d1aab58239..8978d64cffc 100644 --- a/vllm_ascend/core/scheduler_dynamic_batch.py +++ b/vllm_ascend/core/scheduler_dynamic_batch.py @@ -31,6 +31,8 @@ from vllm.v1.request import Request, RequestStatus from vllm.v1.structured_output import StructuredOutputManager +from vllm_ascend.utils import vllm_version_is + class BudgetRefiner: """This budget refiner can make dynamic adjustment to the token budget @@ -488,9 +490,10 @@ def schedule(self) -> SchedulerOutput: token_budget -= num_new_tokens request.status = RequestStatus.RUNNING request.num_computed_tokens = num_computed_tokens - # Count the number of prefix cached tokens. - if request.num_cached_tokens < 0: - request.num_cached_tokens = num_computed_tokens + if vllm_version_is("0.19.0"): + # Count the number of prefix cached tokens. + if request.num_cached_tokens < 0: + request.num_cached_tokens = num_computed_tokens # Encoder-related. if encoder_inputs_to_schedule: scheduled_encoder_inputs[request.request_id] = encoder_inputs_to_schedule diff --git a/vllm_ascend/core/scheduler_profiling_chunk.py b/vllm_ascend/core/scheduler_profiling_chunk.py index 01d46ca4ed2..bb85d567ddb 100644 --- a/vllm_ascend/core/scheduler_profiling_chunk.py +++ b/vllm_ascend/core/scheduler_profiling_chunk.py @@ -41,6 +41,7 @@ from vllm.v1.utils import record_function_or_nullcontext from vllm_ascend.core.profiling_chunk_predictor import ProfilingChunkManager +from vllm_ascend.utils import vllm_version_is class ProfilingChunkScheduler(Scheduler): @@ -482,13 +483,21 @@ def schedule(self) -> SchedulerOutput: # noqa: C901 step_skipped_waiting.prepend_request(request) continue - request.num_external_computed_tokens = ext_tokens + if vllm_version_is("0.19.0"): + request.num_external_computed_tokens = ext_tokens num_external_computed_tokens = ext_tokens connector_prefix_cache_queries = request.num_tokens - num_new_local_computed_tokens connector_prefix_cache_hits = num_external_computed_tokens num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens + + if not vllm_version_is("0.19.0") and request.prefill_stats is not None: + request.prefill_stats.set( + num_prompt_tokens=request.num_prompt_tokens, + num_local_cached_tokens=num_new_local_computed_tokens, + num_external_cached_tokens=num_external_computed_tokens, + ) assert num_computed_tokens <= request.num_tokens else: new_computed_blocks = self.kv_cache_manager.empty_kv_cache_blocks @@ -628,8 +637,9 @@ def schedule(self) -> SchedulerOutput: # noqa: C901 time_budget -= self.profiling_chunk_manager.predict_time(num_new_tokens, request.num_computed_tokens) request.status = RequestStatus.RUNNING request.num_computed_tokens = num_computed_tokens - if request.num_cached_tokens < 0: - request.num_cached_tokens = num_computed_tokens + if vllm_version_is("0.19.0"): + if request.num_cached_tokens < 0: + request.num_cached_tokens = num_computed_tokens if encoder_inputs_to_schedule: scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule for i in encoder_inputs_to_schedule: diff --git a/vllm_ascend/patch/platform/patch_balance_schedule.py b/vllm_ascend/patch/platform/patch_balance_schedule.py index 9a2cc722832..001004c945e 100644 --- a/vllm_ascend/patch/platform/patch_balance_schedule.py +++ b/vllm_ascend/patch/platform/patch_balance_schedule.py @@ -24,6 +24,8 @@ from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.utils import record_function_or_nullcontext +from vllm_ascend.utils import vllm_version_is + class BalanceScheduler(Scheduler): def __init__( @@ -349,7 +351,8 @@ def schedule(self) -> SchedulerOutput: skipped_waiting_requests.prepend_request(request) continue - request.num_external_computed_tokens = ext_tokens + if vllm_version_is("0.19.0"): + request.num_external_computed_tokens = ext_tokens num_external_computed_tokens = ext_tokens connector_prefix_cache_queries = request.num_tokens - num_new_local_computed_tokens @@ -357,6 +360,13 @@ def schedule(self) -> SchedulerOutput: # Total computed tokens (local + external). num_computed_tokens = num_new_local_computed_tokens + num_external_computed_tokens + + if not vllm_version_is("0.19.0") and request.prefill_stats is not None: + request.prefill_stats.set( + num_prompt_tokens=request.num_prompt_tokens, + num_local_cached_tokens=num_new_local_computed_tokens, + num_external_cached_tokens=num_external_computed_tokens, + ) else: # KVTransfer: WAITING reqs have num_computed_tokens > 0 # after async KV recvs are completed. @@ -496,9 +506,10 @@ def schedule(self) -> SchedulerOutput: token_budget -= num_new_tokens request.status = RequestStatus.RUNNING request.num_computed_tokens = num_computed_tokens - # Count the number of prefix cached tokens. - if request.num_cached_tokens < 0: - request.num_cached_tokens = num_computed_tokens + if vllm_version_is("0.19.0"): + # Count the number of prefix cached tokens. + if request.num_cached_tokens < 0: + request.num_cached_tokens = num_computed_tokens # Encoder-related. if encoder_inputs_to_schedule: scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule