diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 0111fd6e7198..3d68e2210fad 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -42,7 +42,7 @@
 from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
 from vllm.v1.core.sched.utils import check_stop, remove_all
 from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
-from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
 from vllm.v1.metrics.perf import ModelMetrics, PerfStats
 from vllm.v1.metrics.stats import (
     PrefixCacheStats,
@@ -220,6 +220,14 @@ def __init__(
         self.use_pp = self.parallel_config.pipeline_parallel_size > 1
         self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
 
+        def has_mamba_layers(kv_cache_config: KVCacheConfig) -> bool:
+            return any(
+                isinstance(group_spec.kv_cache_spec, MambaSpec)
+                for group_spec in kv_cache_config.kv_cache_groups
+            )
+
+        self.has_mamba_layers = has_mamba_layers(kv_cache_config)
+
         self.perf_metrics: ModelMetrics | None = None
         if self.log_stats and vllm_config.observability_config.enable_mfu_metrics:
             self.perf_metrics = ModelMetrics(vllm_config)
@@ -274,11 +282,22 @@ def schedule(self) -> SchedulerOutput:
                 req_index += 1
                 continue
 
-            num_new_tokens = (
-                request.num_tokens_with_spec
-                + request.num_output_placeholders
-                - request.num_computed_tokens
+            num_tokens_to_compute = (
+                request.num_tokens_with_spec + request.num_output_placeholders
             )
+            # Ensure new tokens for a request in the prefill phase do not contain
+            # speculative tokens, especially in the last prefill chunk. For a hybrid
+            # model, extra speculative tokens would corrupt the generated mamba state.
+            # TODO: This logic does not yet handle resumed requests.
+            if (
+                self.has_mamba_layers
+                and request.num_computed_tokens < request.num_prompt_tokens
+            ):
+                num_tokens_to_compute = min(
+                    num_tokens_to_compute, request.num_prompt_tokens
+                )
+            num_new_tokens = num_tokens_to_compute - request.num_computed_tokens
+
             if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
                 num_new_tokens = self.scheduler_config.long_prefill_token_threshold
             num_new_tokens = min(num_new_tokens, token_budget)
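---
Note for reviewers: below is a minimal standalone sketch of the clamping rule introduced in the schedule() hunk. The Request here is a stub whose field names mirror the scheduler's request attributes (the real vLLM request type is not imported), and the token counts are illustrative only.

    from dataclasses import dataclass


    @dataclass
    class Request:
        # Stub carrying only the fields the clamp reads; not the real vLLM class.
        num_prompt_tokens: int
        num_computed_tokens: int
        num_tokens_with_spec: int   # prompt + generated + speculative tokens
        num_output_placeholders: int = 0


    def clamped_num_new_tokens(request: Request, has_mamba_layers: bool) -> int:
        # Mirrors the patched logic: while a hybrid (Mamba) model is still
        # prefilling, never schedule past the end of the prompt, so no
        # speculative token can land in the last prefill chunk and corrupt
        # the generated mamba state.
        num_tokens_to_compute = (
            request.num_tokens_with_spec + request.num_output_placeholders
        )
        if has_mamba_layers and request.num_computed_tokens < request.num_prompt_tokens:
            num_tokens_to_compute = min(num_tokens_to_compute, request.num_prompt_tokens)
        return num_tokens_to_compute - request.num_computed_tokens


    # Last prefill chunk of a 100-token prompt with 2 speculative tokens attached.
    req = Request(num_prompt_tokens=100, num_computed_tokens=96, num_tokens_with_spec=102)
    assert clamped_num_new_tokens(req, has_mamba_layers=True) == 4   # clamped at prompt end
    assert clamped_num_new_tokens(req, has_mamba_layers=False) == 6  # previous behavior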