81 changes: 81 additions & 0 deletions tests/ut/core/test_recompute_scheduler.py
@@ -0,0 +1,81 @@
import unittest
from types import SimpleNamespace
from unittest.mock import patch

from vllm_ascend.core.recompute_scheduler import RecomputeScheduler


def _fake_scheduler_init(self, *args, **kwargs):
self.vllm_config = kwargs["vllm_config"]
# Baseline behavior for hybrid mamba models.
self.need_mamba_block_aligned_split = True
if hasattr(self.vllm_config, "has_mamba_layers"):
self.has_mamba_layers = self.vllm_config.has_mamba_layers


class TestRecomputeSchedulerAllMode(unittest.TestCase):

def _build_vllm_config(self, model_type: str, mamba_cache_mode: str, enable_prefix_caching: bool = True):
return SimpleNamespace(
speculative_config=None,
kv_transfer_config=None,
model_config=SimpleNamespace(
hf_text_config=SimpleNamespace(model_type=model_type),
),
cache_config=SimpleNamespace(
mamba_cache_mode=mamba_cache_mode,
enable_prefix_caching=enable_prefix_caching,
),
)

def _build_vllm_config_with_mamba_flag(
self, model_type: str, mamba_cache_mode: str, has_mamba_layers: bool, enable_prefix_caching: bool = True
):
cfg = self._build_vllm_config(model_type, mamba_cache_mode, enable_prefix_caching)
cfg.has_mamba_layers = has_mamba_layers
return cfg

@patch("vllm_ascend.core.recompute_scheduler.register_ascend_mla_spec_in_manager", lambda: None)
@patch("vllm_ascend.core.recompute_scheduler.Scheduler.__init__", new=_fake_scheduler_init)
def test_all_mode_disables_block_aligned_split_for_qwen3_5(self):
cfg = self._build_vllm_config("qwen3_5", "all")

scheduler = RecomputeScheduler(vllm_config=cfg)

self.assertFalse(scheduler.need_mamba_block_aligned_split)

@patch("vllm_ascend.core.recompute_scheduler.register_ascend_mla_spec_in_manager", lambda: None)
@patch("vllm_ascend.core.recompute_scheduler.Scheduler.__init__", new=_fake_scheduler_init)
def test_align_mode_keeps_block_aligned_split_for_qwen3_5(self):
cfg = self._build_vllm_config("qwen3_5", "align")

scheduler = RecomputeScheduler(vllm_config=cfg)

self.assertTrue(scheduler.need_mamba_block_aligned_split)

@patch("vllm_ascend.core.recompute_scheduler.register_ascend_mla_spec_in_manager", lambda: None)
@patch("vllm_ascend.core.recompute_scheduler.Scheduler.__init__", new=_fake_scheduler_init)
def test_all_mode_non_hybrid_keeps_unchanged(self):
cfg = self._build_vllm_config("llama", "all")

scheduler = RecomputeScheduler(vllm_config=cfg)

self.assertTrue(scheduler.need_mamba_block_aligned_split)

@patch("vllm_ascend.core.recompute_scheduler.register_ascend_mla_spec_in_manager", lambda: None)
@patch("vllm_ascend.core.recompute_scheduler.Scheduler.__init__", new=_fake_scheduler_init)
def test_all_mode_hybrid_without_prefix_caching_keeps_unchanged(self):
cfg = self._build_vllm_config("qwen3_next", "all", enable_prefix_caching=False)

scheduler = RecomputeScheduler(vllm_config=cfg)

self.assertTrue(scheduler.need_mamba_block_aligned_split)

@patch("vllm_ascend.core.recompute_scheduler.register_ascend_mla_spec_in_manager", lambda: None)
@patch("vllm_ascend.core.recompute_scheduler.Scheduler.__init__", new=_fake_scheduler_init)
def test_all_mode_uses_has_mamba_layers_flag_when_available(self):
cfg = self._build_vllm_config_with_mamba_flag("llama", "all", has_mamba_layers=True)

scheduler = RecomputeScheduler(vllm_config=cfg)

self.assertFalse(scheduler.need_mamba_block_aligned_split)
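For quick reference, the five cases above collapse to the following truth table (a hypothetical summary keyed by the test inputs; this dict is illustrative and not part of the change itself):

# (model_type, mamba_cache_mode, prefix_caching, has_mamba_layers)
#   -> expected need_mamba_block_aligned_split
EXPECTED = {
    ("qwen3_5", "all", True, None): False,      # hybrid + "all" + caching: split disabled
    ("qwen3_5", "align", True, None): True,     # "align" mode: unchanged
    ("llama", "all", True, None): True,         # non-hybrid model: unchanged
    ("qwen3_next", "all", False, None): True,   # prefix caching off: unchanged
    ("llama", "all", True, True): False,        # explicit has_mamba_layers flag wins
}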
83 changes: 83 additions & 0 deletions tests/ut/patch/platform/test_patch_mamba_config.py
@@ -0,0 +1,83 @@
import unittest
from types import SimpleNamespace
from unittest.mock import patch

import torch

from vllm_ascend.patch.platform.patch_mamba_config import verify_and_update_config


class _FakeModelCls:

@staticmethod
def get_mamba_state_shape_from_config(_vllm_config):
return [(128,), (64,)]

@staticmethod
def get_mamba_state_dtype_from_config(_vllm_config):
return [torch.float16, torch.float16]


class TestPatchMambaConfig(unittest.TestCase):

def _build_vllm_config(self, mamba_cache_mode: str):
cache_config = SimpleNamespace(
cache_dtype="auto",
block_size=16,
mamba_page_size_padded=None,
enable_prefix_caching=True,
mamba_cache_mode=mamba_cache_mode,
mamba_block_size=None,
)
model_config = SimpleNamespace(
dtype=torch.float16,
architecture="FakeArch",
max_model_len=4096,
get_num_kv_heads=lambda _parallel: 1,
get_head_size=lambda: 1,
)
parallel_config = SimpleNamespace()
return SimpleNamespace(
cache_config=cache_config,
model_config=model_config,
parallel_config=parallel_config,
)

@patch("vllm_ascend.patch.platform.patch_mamba_config.MambaModelConfig.verify_and_update_config")
@patch("vllm_ascend.patch.platform.patch_mamba_config.ModelRegistry.resolve_model_cls")
def test_all_mode_uses_block_size(self, mock_resolve, _mock_verify):
mock_resolve.return_value = (_FakeModelCls, None)
vllm_config = self._build_vllm_config("all")

verify_and_update_config.__func__(None, vllm_config)

self.assertEqual(
vllm_config.cache_config.mamba_block_size,
vllm_config.cache_config.block_size,
)

@patch("vllm_ascend.patch.platform.patch_mamba_config.MambaModelConfig.verify_and_update_config")
@patch("vllm_ascend.patch.platform.patch_mamba_config.ModelRegistry.resolve_model_cls")
def test_align_mode_uses_block_size(self, mock_resolve, _mock_verify):
mock_resolve.return_value = (_FakeModelCls, None)
vllm_config = self._build_vllm_config("align")

verify_and_update_config.__func__(None, vllm_config)

self.assertEqual(
vllm_config.cache_config.mamba_block_size,
vllm_config.cache_config.block_size,
)

@patch("vllm_ascend.patch.platform.patch_mamba_config.MambaModelConfig.verify_and_update_config")
@patch("vllm_ascend.patch.platform.patch_mamba_config.ModelRegistry.resolve_model_cls")
def test_none_mode_uses_max_model_len(self, mock_resolve, _mock_verify):
mock_resolve.return_value = (_FakeModelCls, None)
vllm_config = self._build_vllm_config("none")

verify_and_update_config.__func__(None, vllm_config)

self.assertEqual(
vllm_config.cache_config.mamba_block_size,
vllm_config.model_config.max_model_len,
)
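A note on the call pattern above: verify_and_update_config.__func__(None, vllm_config) unwraps what appears to be a classmethod so the raw function can be invoked without the patched class. A minimal standalone illustration of the same trick (Demo is a made-up class, not from the codebase):

class Demo:

    @classmethod
    def configure(cls, value):
        return value * 2

# __func__ exposes the underlying function, so the first argument
# (normally cls) can be stubbed out with None in a unit test.
assert Demo.configure.__func__(None, 21) == 42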
17 changes: 15 additions & 2 deletions vllm_ascend/core/recompute_scheduler.py
@@ -120,10 +120,23 @@ def __init__(self, *args, **kwargs):
             and self.vllm_config.kv_transfer_config.is_kv_consumer
         )
         self.is_kv_producer = self.vllm_config.kv_transfer_config and self.vllm_config.kv_transfer_config.is_kv_producer
+        model_type = getattr(getattr(self.vllm_config.model_config, "hf_text_config", None), "model_type", "")
+        model_type = (model_type or "").lower()
         self.is_hybrid_model = (
-            "qwen3_next" in self.vllm_config.model_config.hf_text_config.model_type
-            or "qwen3_5" in self.vllm_config.model_config.hf_text_config.model_type
+            "qwen3_next" in model_type
+            or "qwen3_5" in model_type
         )
+        has_mamba_layers = getattr(self, "has_mamba_layers", None)
+        if has_mamba_layers is None:
+            has_mamba_layers = self.is_hybrid_model
+        cache_config = getattr(self.vllm_config, "cache_config", None)
+        if (
+            has_mamba_layers
+            and cache_config is not None
+            and cache_config.enable_prefix_caching
+            and cache_config.mamba_cache_mode == "all"
+        ):
+            self.need_mamba_block_aligned_split = False
 
     def add_request(self, request: Request) -> None:
         existing = self.requests.get(request.request_id)
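Distilled, the branch added above relaxes block-aligned splitting only when all three conditions hold. A minimal sketch of the same predicate (an illustrative helper, not part of the patch; the cache_config None-check is elided):

def needs_block_aligned_split(has_mamba_layers: bool,
                              enable_prefix_caching: bool,
                              mamba_cache_mode: str) -> bool:
    # Splitting stays enabled unless the model has mamba layers,
    # prefix caching is on, and the mamba cache mode is "all".
    return not (has_mamba_layers
                and enable_prefix_caching
                and mamba_cache_mode == "all")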
8 changes: 6 additions & 2 deletions vllm_ascend/patch/platform/patch_mamba_config.py
@@ -86,8 +86,12 @@ def verify_and_update_config(cls, vllm_config) -> None:
             "exactly equal.",
             mamba_padding_pct,
         )
-        if cache_config.enable_prefix_caching and cache_config.mamba_cache_mode == "align":
-            cache_config.mamba_block_size = cache_config.block_size
+        if cache_config.enable_prefix_caching:
+            # Prefix caching needs block-level mamba states in both align/all modes.
+            if cache_config.mamba_cache_mode in ("align", "all"):
+                cache_config.mamba_block_size = cache_config.block_size
+            else:
+                cache_config.mamba_block_size = model_config.max_model_len
         else:
             cache_config.mamba_block_size = model_config.max_model_len
 
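Read together, the new selection reduces to a single rule. A distilled sketch (illustrative only, mirroring the patched control flow rather than reproducing the actual method):

def select_mamba_block_size(enable_prefix_caching: bool,
                            mamba_cache_mode: str,
                            block_size: int,
                            max_model_len: int) -> int:
    # With prefix caching, both "align" and "all" need block-granular
    # mamba states; any other mode falls back to one state spanning
    # the full context length.
    if enable_prefix_caching and mamba_cache_mode in ("align", "all"):
        return block_size
    return max_model_len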