Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
302 changes: 286 additions & 16 deletions tests/v1/e2e/general/test_mamba_prefix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pytest
import torch

import vllm.envs as envs
from tests.utils import create_new_process_for_each_test
from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import CacheConfig
Expand Down Expand Up @@ -490,12 +491,7 @@ def apply_patch(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mamba_utils, "do_mamba_copy_block", fake_copy_fn)


@create_new_process_for_each_test()
def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
run_ref_mamba_state_in_subprocess()
apply_patch(monkeypatch)
prompt_dataset = datasets.load_dataset("heheda/a_long_article")
full_prompt = prompt_dataset["train"][0]["text"]
def get_mamba_prefix_cache_step_configs() -> dict[str, TestConfig]:
tests = {
"accept_1": TestConfig(
num_prompt_tokens=554,
Expand Down Expand Up @@ -727,6 +723,27 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
),
}

return tests


def fill_following_kv_cache_block_ids(test_config: TestConfig) -> None:
for step_action_prev, step_action_next in zip(
test_config.step_actions[:-1], test_config.step_actions[1:]
):
if len(step_action_next.kv_cache_block_ids) == 0:
step_action_next.kv_cache_block_ids = (
step_action_prev.kv_cache_block_ids.copy()
)


@create_new_process_for_each_test()
def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
run_ref_mamba_state_in_subprocess()
apply_patch(monkeypatch)
prompt_dataset = datasets.load_dataset("heheda/a_long_article")
full_prompt = prompt_dataset["train"][0]["text"]
tests = get_mamba_prefix_cache_step_configs()

engine = LLM(
model=MODEL,
enable_prefix_caching=True,
Expand Down Expand Up @@ -754,16 +771,7 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
)
global cur_step_action_idx
cur_step_action_idx = 0
for step_action_prev, step_action_next in zip(
test_config.step_actions[:-1], test_config.step_actions[1:]
):
if (
step_action_next.kv_cache_block_ids is not None
and len(step_action_next.kv_cache_block_ids) == 0
):
prev_block_ids = step_action_prev.kv_cache_block_ids
if prev_block_ids is not None:
step_action_next.kv_cache_block_ids = prev_block_ids.copy()
fill_following_kv_cache_block_ids(test_config)
global step_actions
step_actions = test_config.step_actions
_ = engine.generate(
Expand All @@ -783,3 +791,265 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
del engine
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()


@create_new_process_for_each_test()
def test_mamba_prefix_cache_mrv2(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
monkeypatch.setenv("VLLM_USE_V2_MODEL_RUNNER", "1")
envs.disable_envs_cache()

from vllm.v1.worker.gpu.model_runner import GPUModelRunner as MRV2GPUModelRunner
from vllm.v1.worker.gpu.model_states.mamba_hybrid import (
MambaHybridModelState,
)
from vllm.v1.worker.gpu.sample.output import SamplerOutput as MRV2SamplerOutput

events: list[int] = []
original_execute_model = MRV2GPUModelRunner.execute_model
original_sample = MRV2GPUModelRunner.sample
original_copy_mamba_state = MambaHybridModelState._copy_mamba_state
original_step_action_fn = InprocClient.get_output
original_allocate_slots = KVCacheManager.allocate_slots

def get_temporal_copy_idx(
model_state: MambaHybridModelState,
input_batch: Any,
num_computed_tokens: torch.Tensor,
*,
postprocess: bool,
) -> tuple[int, int]:
info = model_state.mamba_cache_align_info
req_idx = int(input_batch.idx_mapping[0].item())
src_block_idx = int(info.state_idx_gpu[req_idx].item())
accepted = int(model_state.num_accepted_tokens_gpu[req_idx].item())

if postprocess:
new_computed = int(num_computed_tokens[req_idx].item())
running_tokens = new_computed - accepted + 1
aligned = (new_computed // BLOCK_SIZE) * BLOCK_SIZE
should_copy = src_block_idx >= 0 and aligned >= running_tokens
accept_bias = aligned - running_tokens
dst_block_idx = aligned // BLOCK_SIZE - 1
else:
num_computed = int(num_computed_tokens[req_idx].item())
query_len = int(
input_batch.query_start_loc[1].item()
- input_batch.query_start_loc[0].item()
)
computed_after = num_computed + query_len
should_copy = src_block_idx >= 0
accept_bias = accepted - 1
dst_block_idx = (computed_after + BLOCK_SIZE - 1) // BLOCK_SIZE - 1

if postprocess:
no_copy = src_block_idx == dst_block_idx and accept_bias == 0
else:
no_copy = src_block_idx == dst_block_idx
if no_copy:
should_copy = False
if not should_copy:
return (-1, -1)
return (src_block_idx + accept_bias, dst_block_idx)

def wrapped_copy_mamba_state(
self: MambaHybridModelState,
input_batch: Any,
num_computed_tokens: torch.Tensor,
mamba_spec: Any,
*,
postprocess: bool,
) -> None:
actual_copy_idx = get_temporal_copy_idx(
self, input_batch, num_computed_tokens, postprocess=postprocess
)
expected_temporal_states: list[tuple[torch.Tensor, torch.Tensor]] = []
if actual_copy_idx != (-1, -1):
info = self.mamba_cache_align_info
block_tables = info.current_step_block_tables
kv_cache_config = info.current_step_kv_cache_config
assert block_tables is not None
assert kv_cache_config is not None
forward_context = self.vllm_config.compilation_config.static_forward_context
for mamba_group_id in info.group_ids:
block_table = block_tables[mamba_group_id]
src_block_id = int(block_table[0, actual_copy_idx[0]].item())
dst_block_id = int(block_table[0, actual_copy_idx[1]].item())
layer_names = kv_cache_config.kv_cache_groups[
mamba_group_id
].layer_names
for layer_name in layer_names:
# Qwen3-Next stores temporal state as the last Mamba cache.
temporal_state = forward_context[layer_name].kv_cache[-1]
expected_temporal_states.append(
(
temporal_state[dst_block_id],
temporal_state[src_block_id].detach().clone(),
)
)
ret = original_copy_mamba_state(
self,
input_batch,
num_computed_tokens,
mamba_spec,
postprocess=postprocess,
)
for target, expected in expected_temporal_states:
torch.testing.assert_close(target, expected)
if cur_step_action is not None:
expected_copy_idx = (
cur_step_action.postprocess_copy_idx
if postprocess
else cur_step_action.preprocess_copy_idx
)
assert actual_copy_idx == expected_copy_idx, (
f"Unexpected MRV2 Mamba align copy: {postprocess=}, "
f"expected={expected_copy_idx}, actual={actual_copy_idx}, "
f"{cur_step_action=}"
)
return ret

def wrapped_execute_model(
self: MRV2GPUModelRunner,
scheduler_output: SchedulerOutput,
*args: Any,
**kwargs: Any,
):
events.extend(
req.num_computed_tokens for req in scheduler_output.scheduled_new_reqs
)
events.extend(scheduler_output.scheduled_cached_reqs.num_computed_tokens)
if cur_step_action is not None:
num_scheduled_tokens = next(
iter(scheduler_output.num_scheduled_tokens.values())
)
assert num_scheduled_tokens == cur_step_action.num_scheduled_tokens
ret = original_execute_model(self, scheduler_output, *args, **kwargs)
if cur_step_action is not None and self.execute_model_state is not None:
input_batch = self.execute_model_state.input_batch
assert (
cur_step_action.num_computed_tokens_start
== input_batch.positions[input_batch.query_start_loc[0]].item()
)
return ret

def fake_sample(
self: MRV2GPUModelRunner,
hidden_states: torch.Tensor,
input_batch: Any,
grammar_output: Any,
):
if cur_step_action is None:
return original_sample(self, hidden_states, input_batch, grammar_output)

num_reqs = input_batch.num_reqs
sampled_token_ids = torch.ones(
(num_reqs, self.num_speculative_steps + 1),
device=hidden_states.device,
dtype=torch.int64,
)
num_logits = torch.tensor(
input_batch.cu_num_logits_np[1 : num_reqs + 1]
- input_batch.cu_num_logits_np[:num_reqs],
device=hidden_states.device,
dtype=torch.int32,
)
accepted = torch.full_like(num_logits, num_accepted_tokens)
num_sampled = torch.minimum(accepted, num_logits)
prefill_lens = self.req_states.prefill_len.gpu[input_batch.idx_mapping]
is_chunked_prefill = input_batch.seq_lens[:num_reqs] < prefill_lens
num_sampled = torch.where(is_chunked_prefill, 0, num_sampled)
num_rejected = torch.where(is_chunked_prefill, 0, num_logits - num_sampled)
sampler_output = MRV2SamplerOutput(
sampled_token_ids=sampled_token_ids,
logprobs_tensors=None,
num_nans=None,
num_sampled=num_sampled,
)
return sampler_output, num_sampled, num_rejected

monkeypatch.setattr(
InprocClient,
"get_output",
get_fake_step_action_fn(original_step_action_fn),
)
monkeypatch.setattr(
KVCacheManager,
"allocate_slots",
get_fake_allocate_slots_fn(original_allocate_slots),
)
monkeypatch.setattr(MRV2GPUModelRunner, "execute_model", wrapped_execute_model)
monkeypatch.setattr(MRV2GPUModelRunner, "sample", fake_sample)
monkeypatch.setattr(
MambaHybridModelState, "_copy_mamba_state", wrapped_copy_mamba_state
)

engine = LLM(
model=MODEL,
load_format="dummy",
skip_tokenizer_init=True,
enable_prefix_caching=True,
block_size=BLOCK_SIZE,
mamba_cache_mode="align",
speculative_config={
"method": "qwen3_next_mtp",
"num_speculative_tokens": num_speculative_tokens,
},
max_num_batched_tokens=3072,
max_model_len=BLOCK_SIZE * 12,
hf_overrides={"num_hidden_layers": NUM_HIDDEN_LAYERS},
seed=42,
)

try:
tests = get_mamba_prefix_cache_step_configs()

global step_actions
global cur_step_action_idx
global num_accepted_tokens
for test_name, test_config in tests.items():
num_accepted_tokens = test_config.num_accepted_tokens
cur_step_action_idx = 0
fill_following_kv_cache_block_ids(test_config)
step_actions = test_config.step_actions
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=test_config.num_generated_tokens,
ignore_eos=True,
)
_ = engine.generate(
[TokensPrompt(prompt_token_ids=[1] * test_config.num_prompt_tokens)],
sampling_params=sampling_params,
)
assert cur_step_action_idx == len(test_config.step_actions), test_name
assert (
engine.llm_engine.engine_core.engine_core.scheduler.reset_prefix_cache()
)

step_actions = []
cur_step_action_idx = 0
num_accepted_tokens = 1
prompt = TokensPrompt(prompt_token_ids=[1] * (BLOCK_SIZE * 2))
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=1,
ignore_eos=True,
)
_ = engine.generate([prompt], sampling_params=sampling_params)
first_event_count = len(events)
_ = engine.generate([prompt], sampling_params=sampling_params)
second_events = events[first_event_count:]
prefix_hits = [
num_computed_tokens
for num_computed_tokens in second_events
if num_computed_tokens >= BLOCK_SIZE
]
assert prefix_hits, (
"Expected the second identical prompt to hit prefix cache, "
f"got events={second_events!r}"
)
assert engine.llm_engine.engine_core.engine_core.scheduler.reset_prefix_cache()
finally:
del engine
torch.accelerator.empty_cache()
cleanup_dist_env_and_memory()
11 changes: 0 additions & 11 deletions vllm/config/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1989,13 +1989,6 @@ def _get_v2_model_runner_unsupported_features(self) -> list[str]:
model_config = self.model_config
speculative_config = self.speculative_config

if (
model_config is not None
and model_config.has_inner_state
and self.cache_config.mamba_cache_mode == "align"
):
unsupported.append("hybrid/mamba models with align cache mode")

if self.parallel_config.prefill_context_parallel_size > 1:
unsupported.append("prefill context parallelism")

Expand Down Expand Up @@ -2126,10 +2119,6 @@ def validate_block_size(self) -> None:
"to schedule a multiple of block_size tokens even if they are "
"in the middle of a mm input"
)
# TODO: support align mamba cache mode for model runner v2
assert not envs.VLLM_USE_V2_MODEL_RUNNER, (
"Model Runner V2 has not yet supported mamba_cache_mode='align'. "
)

@model_validator(mode="after")
def validate_nvfp4_kv_cache_with_mla(self) -> "VllmConfig":
Expand Down
Loading
Loading