Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.metrics.perf import ModelMetrics, PerfStats
from vllm.v1.metrics.stats import (
PrefixCacheStats,
Expand Down Expand Up @@ -220,6 +220,14 @@ def __init__(
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER

def has_mamba_layers(kv_cache_config: KVCacheConfig) -> bool:
return any(
isinstance(group_spec.kv_cache_spec, MambaSpec)
for group_spec in kv_cache_config.kv_cache_groups
)

self.has_mamba_layers = has_mamba_layers(kv_cache_config)

self.perf_metrics: ModelMetrics | None = None
if self.log_stats and vllm_config.observability_config.enable_mfu_metrics:
self.perf_metrics = ModelMetrics(vllm_config)
Expand Down Expand Up @@ -274,11 +282,22 @@ def schedule(self) -> SchedulerOutput:
req_index += 1
continue

num_new_tokens = (
request.num_tokens_with_spec
+ request.num_output_placeholders
- request.num_computed_tokens
num_tokens_to_compute = (
request.num_tokens_with_spec + request.num_output_placeholders
)
# Ensure new tokens for a request in the prefill phase do not contain
# speculative tokens, especially in the last prefill chunk. For a hybrid
# model, extra speculative tokens would corrupt the generated mamba state.
# TODO: This logic does not yet handle resumed requests.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess resumed requests don't have speculative tokens, so they are not affected by this bug. WDYT?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’m not sure about that. For a resumed request, couldn't extra spec tokens (e.g., gamma=3) still be appended during prefill? For example, if 1024 prompt tokens + 1024 generated tokens are resumed, calculating 2048+3 tokens in prefill phase instead of just 2048 tokens would likely lead to an incorrect Mamba state.

if (
self.has_mamba_layers
and request.num_computed_tokens < request.num_prompt_tokens
):
num_tokens_to_compute = min(
num_tokens_to_compute, request.num_prompt_tokens
)
num_new_tokens = num_tokens_to_compute - request.num_computed_tokens

if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
num_new_tokens = self.scheduler_config.long_prefill_token_threshold
num_new_tokens = min(num_new_tokens, token_budget)
Expand Down