diff --git a/tests/ut/core/test_recompute_scheduler.py b/tests/ut/core/test_recompute_scheduler.py
new file mode 100644
index 00000000000..4c81d29153a
--- /dev/null
+++ b/tests/ut/core/test_recompute_scheduler.py
@@ -0,0 +1,81 @@
+from types import SimpleNamespace
+import unittest
+from unittest.mock import patch
+
+from vllm_ascend.core.recompute_scheduler import RecomputeScheduler
+
+
+def _fake_scheduler_init(self, *args, **kwargs):
+    self.vllm_config = kwargs["vllm_config"]
+    # Baseline behavior for hybrid mamba models.
+    self.need_mamba_block_aligned_split = True
+    if hasattr(self.vllm_config, "has_mamba_layers"):
+        self.has_mamba_layers = self.vllm_config.has_mamba_layers
+
+
+class TestRecomputeSchedulerAllMode(unittest.TestCase):
+
+    def _build_vllm_config(self, model_type: str, mamba_cache_mode: str, enable_prefix_caching: bool = True):
+        return SimpleNamespace(
+            speculative_config=None,
+            kv_transfer_config=None,
+            model_config=SimpleNamespace(
+                hf_text_config=SimpleNamespace(model_type=model_type),
+            ),
+            cache_config=SimpleNamespace(
+                mamba_cache_mode=mamba_cache_mode,
+                enable_prefix_caching=enable_prefix_caching,
+            ),
+        )
+
+    def _build_vllm_config_with_mamba_flag(
+        self, model_type: str, mamba_cache_mode: str, has_mamba_layers: bool, enable_prefix_caching: bool = True
+    ):
+        cfg = self._build_vllm_config(model_type, mamba_cache_mode, enable_prefix_caching)
+        cfg.has_mamba_layers = has_mamba_layers
+        return cfg
+
+    @patch("vllm_ascend.core.recompute_scheduler.register_ascend_mla_spec_in_manager", lambda: None)
+    @patch("vllm_ascend.core.recompute_scheduler.Scheduler.__init__", new=_fake_scheduler_init)
+    def test_all_mode_disables_block_aligned_split_for_qwen3_5(self):
+        cfg = self._build_vllm_config("qwen3_5", "all")
+
+        scheduler = RecomputeScheduler(vllm_config=cfg)
+
+        self.assertFalse(scheduler.need_mamba_block_aligned_split)
+
+    @patch("vllm_ascend.core.recompute_scheduler.register_ascend_mla_spec_in_manager", lambda: None)
+    @patch("vllm_ascend.core.recompute_scheduler.Scheduler.__init__", new=_fake_scheduler_init)
+    def test_align_mode_keeps_block_aligned_split_for_qwen3_5(self):
+        cfg = self._build_vllm_config("qwen3_5", "align")
+
+        scheduler = RecomputeScheduler(vllm_config=cfg)
+
+        self.assertTrue(scheduler.need_mamba_block_aligned_split)
+
+    @patch("vllm_ascend.core.recompute_scheduler.register_ascend_mla_spec_in_manager", lambda: None)
+    @patch("vllm_ascend.core.recompute_scheduler.Scheduler.__init__", new=_fake_scheduler_init)
+    def test_all_mode_non_hybrid_keeps_unchanged(self):
+        cfg = self._build_vllm_config("llama", "all")
+
+        scheduler = RecomputeScheduler(vllm_config=cfg)
+
+        self.assertTrue(scheduler.need_mamba_block_aligned_split)
+
+    @patch("vllm_ascend.core.recompute_scheduler.register_ascend_mla_spec_in_manager", lambda: None)
+    @patch("vllm_ascend.core.recompute_scheduler.Scheduler.__init__", new=_fake_scheduler_init)
+    def test_all_mode_hybrid_without_prefix_caching_keeps_unchanged(self):
+        cfg = self._build_vllm_config("qwen3_next", "all", enable_prefix_caching=False)
+
+        scheduler = RecomputeScheduler(vllm_config=cfg)
+
+        self.assertTrue(scheduler.need_mamba_block_aligned_split)
+
+    @patch("vllm_ascend.core.recompute_scheduler.register_ascend_mla_spec_in_manager", lambda: None)
+    @patch("vllm_ascend.core.recompute_scheduler.Scheduler.__init__", new=_fake_scheduler_init)
+    def test_all_mode_uses_has_mamba_layers_flag_when_available(self):
+        cfg = self._build_vllm_config_with_mamba_flag("llama", "all", has_mamba_layers=True)
+
+        scheduler = RecomputeScheduler(vllm_config=cfg)
+
+        self.assertFalse(scheduler.need_mamba_block_aligned_split)
diff --git a/tests/ut/patch/platform/test_patch_mamba_config.py b/tests/ut/patch/platform/test_patch_mamba_config.py
new file mode 100644
index 00000000000..e4bdd122c3c
--- /dev/null
+++ b/tests/ut/patch/platform/test_patch_mamba_config.py
@@ -0,0 +1,83 @@
+from types import SimpleNamespace
+import unittest
+from unittest.mock import patch
+
+import torch
+
+from vllm_ascend.patch.platform.patch_mamba_config import verify_and_update_config
+
+
+class _FakeModelCls:
+
+    @staticmethod
+    def get_mamba_state_shape_from_config(_vllm_config):
+        return [(128,), (64,)]
+
+    @staticmethod
+    def get_mamba_state_dtype_from_config(_vllm_config):
+        return [torch.float16, torch.float16]
+
+
+class TestPatchMambaConfig(unittest.TestCase):
+
+    def _build_vllm_config(self, mamba_cache_mode: str):
+        cache_config = SimpleNamespace(
+            cache_dtype="auto",
+            block_size=16,
+            mamba_page_size_padded=None,
+            enable_prefix_caching=True,
+            mamba_cache_mode=mamba_cache_mode,
+            mamba_block_size=None,
+        )
+        model_config = SimpleNamespace(
+            dtype=torch.float16,
+            architecture="FakeArch",
+            max_model_len=4096,
+            get_num_kv_heads=lambda _parallel: 1,
+            get_head_size=lambda: 1,
+        )
+        parallel_config = SimpleNamespace()
+        return SimpleNamespace(
+            cache_config=cache_config,
+            model_config=model_config,
+            parallel_config=parallel_config,
+        )
+
+    @patch("vllm_ascend.patch.platform.patch_mamba_config.MambaModelConfig.verify_and_update_config")
+    @patch("vllm_ascend.patch.platform.patch_mamba_config.ModelRegistry.resolve_model_cls")
+    def test_all_mode_uses_block_size(self, mock_resolve, _mock_verify):
+        mock_resolve.return_value = (_FakeModelCls, None)
+        vllm_config = self._build_vllm_config("all")
+
+        verify_and_update_config.__func__(None, vllm_config)
+
+        self.assertEqual(
+            vllm_config.cache_config.mamba_block_size,
+            vllm_config.cache_config.block_size,
+        )
+
+    @patch("vllm_ascend.patch.platform.patch_mamba_config.MambaModelConfig.verify_and_update_config")
+    @patch("vllm_ascend.patch.platform.patch_mamba_config.ModelRegistry.resolve_model_cls")
+    def test_align_mode_uses_block_size(self, mock_resolve, _mock_verify):
+        mock_resolve.return_value = (_FakeModelCls, None)
+        vllm_config = self._build_vllm_config("align")
+
+        verify_and_update_config.__func__(None, vllm_config)
+
+        self.assertEqual(
+            vllm_config.cache_config.mamba_block_size,
+            vllm_config.cache_config.block_size,
+        )
+
+    @patch("vllm_ascend.patch.platform.patch_mamba_config.MambaModelConfig.verify_and_update_config")
+    @patch("vllm_ascend.patch.platform.patch_mamba_config.ModelRegistry.resolve_model_cls")
+    def test_none_mode_uses_max_model_len(self, mock_resolve, _mock_verify):
+        mock_resolve.return_value = (_FakeModelCls, None)
+        vllm_config = self._build_vllm_config("none")
+
+        verify_and_update_config.__func__(None, vllm_config)
+
+        self.assertEqual(
+            vllm_config.cache_config.mamba_block_size,
+            vllm_config.model_config.max_model_len,
+        )
diff --git a/vllm_ascend/core/recompute_scheduler.py b/vllm_ascend/core/recompute_scheduler.py
index b6fe93f839a..9f2c24a1216 100644
--- a/vllm_ascend/core/recompute_scheduler.py
+++ b/vllm_ascend/core/recompute_scheduler.py
@@ -120,10 +120,23 @@ def __init__(self, *args, **kwargs):
             and self.vllm_config.kv_transfer_config.is_kv_consumer
         )
         self.is_kv_producer = self.vllm_config.kv_transfer_config and self.vllm_config.kv_transfer_config.is_kv_producer
+        model_type = getattr(getattr(self.vllm_config.model_config, "hf_text_config", None), "model_type", "")
+        model_type = (model_type or "").lower()
         self.is_hybrid_model = (
-            "qwen3_next" in self.vllm_config.model_config.hf_text_config.model_type
-            or "qwen3_5" in self.vllm_config.model_config.hf_text_config.model_type
+            "qwen3_next" in model_type
+            or "qwen3_5" in model_type
         )
+        has_mamba_layers = getattr(self, "has_mamba_layers", None)
+        if has_mamba_layers is None:
+            has_mamba_layers = self.is_hybrid_model
+        cache_config = getattr(self.vllm_config, "cache_config", None)
+        if (
+            has_mamba_layers
+            and cache_config is not None
+            and cache_config.enable_prefix_caching
+            and cache_config.mamba_cache_mode == "all"
+        ):
+            self.need_mamba_block_aligned_split = False
 
     def add_request(self, request: Request) -> None:
         existing = self.requests.get(request.request_id)
diff --git a/vllm_ascend/patch/platform/patch_mamba_config.py b/vllm_ascend/patch/platform/patch_mamba_config.py
index 775e5474376..9ffac7d3667 100644
--- a/vllm_ascend/patch/platform/patch_mamba_config.py
+++ b/vllm_ascend/patch/platform/patch_mamba_config.py
@@ -86,8 +86,12 @@ def verify_and_update_config(cls, vllm_config) -> None:
                 "exactly equal.",
                 mamba_padding_pct,
             )
-        if cache_config.enable_prefix_caching and cache_config.mamba_cache_mode == "align":
-            cache_config.mamba_block_size = cache_config.block_size
+        if cache_config.enable_prefix_caching:
+            # Prefix caching needs block-level mamba states in both align/all modes.
+            if cache_config.mamba_cache_mode in ("align", "all"):
+                cache_config.mamba_block_size = cache_config.block_size
+            else:
+                cache_config.mamba_block_size = model_config.max_model_len
         else:
             cache_config.mamba_block_size = model_config.max_model_len