diff --git a/tests/test_platform_update_block_size.py b/tests/test_platform_update_block_size.py
new file mode 100644
index 00000000..7de7447d
--- /dev/null
+++ b/tests/test_platform_update_block_size.py
@@ -0,0 +1,576 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Unit tests for MetalPlatform.update_block_size_for_backend().
+
+Tests cover:
+1. Success cases: hybrid models, non-hybrid models
+2. Failure cases: model resolution failure, invalid config, etc.
+"""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
+
+from vllm_metal.platform import MetalPlatform
+
+
+class TestUpdateBlockSizeForBackend:
+    """Test suite for update_block_size_for_backend() method."""
+
+    # ========================================================================
+    # Fixtures
+    # ========================================================================
+
+    @pytest.fixture
+    def base_cache_config(self):
+        """Create a base CacheConfig mock for testing."""
+        cache_config = MagicMock(spec=CacheConfig)
+        cache_config.block_size = 16
+        cache_config.user_specified_block_size = (
+            False  # Allow block_size to be adjusted
+        )
+        cache_config.gpu_memory_utilization = 0.9
+        cache_config.cache_dtype = "auto"
+        cache_config.mamba_cache_mode = "none"
+        cache_config.mamba_block_size = None
+        cache_config.mamba_page_size_padded = None
+        return cache_config
+
+    @pytest.fixture
+    def base_model_config(self):
+        """Create a base ModelConfig mock for testing."""
+        import torch
+
+        model_config = MagicMock(spec=ModelConfig)
+        model_config.is_hybrid = True
+        model_config.architecture = "Qwen3_5ForCausalLM"
+        model_config.dtype = torch.float16  # Use torch.dtype instead of string
+        model_config.max_model_len = 512
+        model_config.get_num_kv_heads.return_value = 8
+        model_config.get_head_size.return_value = 128
+        return model_config
+
+    @pytest.fixture
+    def vllm_config(self, base_cache_config, base_model_config):
+        """Create a complete VllmConfig for hybrid model testing."""
+        parallel_config = MagicMock(spec=ParallelConfig)
+        parallel_config.tensor_parallel_size = 1
+        parallel_config.pipeline_parallel_size = 1
+
+        config = MagicMock(spec=VllmConfig)
+        config.model_config = base_model_config
+        config.cache_config = base_cache_config
+        config.parallel_config = parallel_config
+        return config
+
+    @pytest.fixture
+    def mock_mamba_state(self):
+        """Create mock mamba state shape and dtype.
+
+        Use small shapes that result in reasonable block_size calculations.
+        For Qwen3.5-0.8B:
+        - conv_shape: (max_seqs, conv_kernel-1, conv_dim)
+        - recurrent_shape: (max_seqs, num_v_heads, value_head_dim, key_head_dim)
+
+        Using max_seqs=1 keeps the page size small for testing.
+        """
+        import torch
+
+        return {
+            "shape": (
+                (1, 3, 2048),  # conv_shape (max_seqs=1)
+                (1, 8, 128, 128),  # recurrent_shape (max_seqs=1)
+            ),
+            "dtype": (
+                torch.float32,  # conv_dtype
+                torch.float32,  # recurrent_dtype
+            ),
+        }
+
+    # ========================================================================
+    # Success Cases
+    # ========================================================================
+
+    def test_hybrid_model_success(self, vllm_config, mock_mamba_state):
+        """Test: Hybrid model successfully sets mamba_page_size_padded.
+
+        This is the main success path - hybrid model with valid config.
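+
+        A rough worked expectation, assuming MambaSpec simply sums the mock
+        state tensors and FullAttentionSpec stores K and V per token:
+          mamba page  ~ (1*3*2048 + 1*8*128*128) * 4 bytes ~ 549 KB
+          attn/token  = 2 * 8 heads * 128 head_dim * 2 bytes = 4096 bytes
+          block_size  ~ 32 * ceil(549 KB / (32 * 4096 B)) = 160 tokens
+        The assertions below only check weaker invariants (set, positive, not
+        below the original 16) so they stay robust to padding details.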
+ """ + with patch("vllm.model_executor.models.ModelRegistry") as mock_registry: + # Setup mock + mock_model_cls = MagicMock() + mock_model_cls.get_mamba_state_shape_from_config.return_value = ( + mock_mamba_state["shape"] + ) + mock_model_cls.get_mamba_state_dtype_from_config.return_value = ( + mock_mamba_state["dtype"] + ) + mock_registry.resolve_model_cls.return_value = (mock_model_cls, None) + + # Execute + MetalPlatform.update_block_size_for_backend(vllm_config) + + # Verify + cache_config = vllm_config.cache_config + assert cache_config.mamba_page_size_padded is not None, ( + "mamba_page_size_padded should be set for hybrid model" + ) + assert cache_config.mamba_page_size_padded > 0, ( + "mamba_page_size_padded should be positive" + ) + assert cache_config.block_size >= 16, ( + "block_size should be >= original value" + ) + + def test_hybrid_model_block_size_already_sufficient( + self, vllm_config, mock_mamba_state + ): + """Test: Hybrid model with already-sufficient block_size. + + When block_size is already large enough, block_size should not be reduced. + Note: mamba_page_size_padded may still be set if attn_page_size > mamba_page_size. + """ + # Set a very large block_size upfront + vllm_config.cache_config.block_size = 256 + # Set cache_dtype to ensure consistent page size calculation + vllm_config.cache_config.cache_dtype = "auto" + + with patch("vllm.model_executor.models.ModelRegistry") as mock_registry: + mock_model_cls = MagicMock() + mock_model_cls.get_mamba_state_shape_from_config.return_value = ( + mock_mamba_state["shape"] + ) + mock_model_cls.get_mamba_state_dtype_from_config.return_value = ( + mock_mamba_state["dtype"] + ) + mock_registry.resolve_model_cls.return_value = (mock_model_cls, None) + + # Execute + MetalPlatform.update_block_size_for_backend(vllm_config) + + # Verify: block_size should remain unchanged at 256 + # (it may be adjusted slightly due to alignment requirements) + assert vllm_config.cache_config.block_size >= 256, ( + "block_size should not decrease" + ) + + def test_non_hybrid_model_skipped(self, vllm_config): + """Test: Non-hybrid model skips the update entirely. + + Non-hybrid models don't need mamba_page_size_padded. + """ + # Set model as non-hybrid + vllm_config.model_config.is_hybrid = False + + original_block_size = vllm_config.cache_config.block_size + + # Execute (should return early) + MetalPlatform.update_block_size_for_backend(vllm_config) + + # Verify: no changes + assert vllm_config.cache_config.block_size == original_block_size + assert vllm_config.cache_config.mamba_page_size_padded is None + + # ======================================================================== + # Failure Cases - Early Return (No Exception) + # ======================================================================== + + def test_model_config_none(self): + """Test: None model_config returns early without error. + + This can happen during initialization edge cases. + """ + config = MagicMock(spec=VllmConfig) + config.model_config = None + config.cache_config = MagicMock() + + # Execute (should not raise) + MetalPlatform.update_block_size_for_backend(config) + + # Verify: no changes to cache_config + # (method should return early) + + # ======================================================================== + # Failure Cases - Raise Exceptions + # ======================================================================== + + def test_model_resolution_failure(self, vllm_config): + """Test: Model class resolution failure raises exception. 
+
+        This happens when the model architecture is not registered.
+        """
+        with patch("vllm.model_executor.models.ModelRegistry") as mock_registry:
+            # Setup mock to raise exception
+            mock_registry.resolve_model_cls.side_effect = ValueError(
+                "Model architecture 'Qwen3_5ForCausalLM' not found"
+            )
+
+            # Execute and verify exception
+            with pytest.raises(ValueError) as exc_info:
+                MetalPlatform.update_block_size_for_backend(vllm_config)
+
+            # Verify exception message
+            assert "not found" in str(exc_info.value).lower()
+
+    def test_get_mamba_state_shape_failure(self, vllm_config):
+        """Test: get_mamba_state_shape_from_config failure raises exception.
+
+        This happens when model class doesn't have the required method.
+        """
+        with patch("vllm.model_executor.models.ModelRegistry") as mock_registry:
+            mock_model_cls = MagicMock()
+            mock_model_cls.get_mamba_state_shape_from_config.side_effect = (
+                AttributeError("Model has no get_mamba_state_shape_from_config")
+            )
+            mock_registry.resolve_model_cls.return_value = (mock_model_cls, None)
+
+            # Execute and verify exception
+            with pytest.raises(AttributeError) as exc_info:
+                MetalPlatform.update_block_size_for_backend(vllm_config)
+
+            # Verify exception message
+            assert "get_mamba_state_shape_from_config" in str(exc_info.value)
+
+    def test_get_mamba_state_dtype_failure(self, vllm_config, mock_mamba_state):
+        """Test: get_mamba_state_dtype_from_config failure raises exception.
+
+        This happens when model class doesn't have the dtype method.
+        """
+        with patch("vllm.model_executor.models.ModelRegistry") as mock_registry:
+            mock_model_cls = MagicMock()
+            mock_model_cls.get_mamba_state_shape_from_config.return_value = (
+                mock_mamba_state["shape"]
+            )
+            mock_model_cls.get_mamba_state_dtype_from_config.side_effect = (
+                AttributeError("Model has no get_mamba_state_dtype_from_config")
+            )
+            mock_registry.resolve_model_cls.return_value = (mock_model_cls, None)
+
+            # Execute and verify exception
+            with pytest.raises(AttributeError) as exc_info:
+                MetalPlatform.update_block_size_for_backend(vllm_config)
+
+            # Verify exception message
+            assert "get_mamba_state_dtype_from_config" in str(exc_info.value)
+
+    def test_mamba_page_size_zero(self, vllm_config):
+        """Test: Zero mamba_page_size raises exception.
+
+        This happens when state shape calculation results in zero.
+        """
+        import torch
+
+        with patch("vllm.model_executor.models.ModelRegistry") as mock_registry:
+            mock_model_cls = MagicMock()
+            # Return zero-sized shape
+            mock_model_cls.get_mamba_state_shape_from_config.return_value = (
+                (0, 0, 0),
+                (0, 0, 0, 0),
+            )
+            mock_model_cls.get_mamba_state_dtype_from_config.return_value = (
+                torch.float32,
+                torch.float32,
+            )
+            mock_registry.resolve_model_cls.return_value = (mock_model_cls, None)
+
+            # Execute and verify exception
+            with pytest.raises(ValueError) as exc_info:
+                MetalPlatform.update_block_size_for_backend(vllm_config)
+
+            # Verify exception message
+            assert "zero" in str(exc_info.value).lower()
+
+    def test_invalid_architecture(self, vllm_config):
+        """Test: Invalid architecture raises exception.
+
+        This happens when the architecture string is malformed.
+ """ + vllm_config.model_config.architecture = "InvalidArchitecture_123" + + with patch("vllm.model_executor.models.ModelRegistry") as mock_registry: + mock_registry.resolve_model_cls.side_effect = KeyError( + "Unknown architecture: InvalidArchitecture_123" + ) + + # Execute and verify exception + with pytest.raises(KeyError) as exc_info: + MetalPlatform.update_block_size_for_backend(vllm_config) + + # Verify exception message + assert "InvalidArchitecture" in str(exc_info.value) + + # ======================================================================== + # Edge Cases + # ======================================================================== + + def test_block_size_increased_to_minimum(self, vllm_config, mock_mamba_state): + """Test: block_size is increased to minimum required value. + + When original block_size is too small, it should be increased. + """ + # Set very small block_size + vllm_config.cache_config.block_size = 1 + + with patch("vllm.model_executor.models.ModelRegistry") as mock_registry: + mock_model_cls = MagicMock() + mock_model_cls.get_mamba_state_shape_from_config.return_value = ( + mock_mamba_state["shape"] + ) + mock_model_cls.get_mamba_state_dtype_from_config.return_value = ( + mock_mamba_state["dtype"] + ) + mock_registry.resolve_model_cls.return_value = (mock_model_cls, None) + + # Execute + MetalPlatform.update_block_size_for_backend(vllm_config) + + # Verify: block_size should be increased + assert vllm_config.cache_config.block_size > 1 + # Should be at least 32 (kernel_block_alignment_size) + assert vllm_config.cache_config.block_size >= 32 + + def test_mamba_cache_mode_align(self, vllm_config, mock_mamba_state): + """Test: mamba_block_size is synced when mamba_cache_mode='align'. + + This tests the align mode specific logic. + """ + vllm_config.cache_config.mamba_cache_mode = "align" + + with patch("vllm.model_executor.models.ModelRegistry") as mock_registry: + mock_model_cls = MagicMock() + mock_model_cls.get_mamba_state_shape_from_config.return_value = ( + mock_mamba_state["shape"] + ) + mock_model_cls.get_mamba_state_dtype_from_config.return_value = ( + mock_mamba_state["dtype"] + ) + mock_registry.resolve_model_cls.return_value = (mock_model_cls, None) + + # Execute + MetalPlatform.update_block_size_for_backend(vllm_config) + + # Verify: mamba_block_size should equal block_size + assert vllm_config.cache_config.mamba_block_size == ( + vllm_config.cache_config.block_size + ) + + def test_hybrid_with_paged_attention_logs_warning( + self, vllm_config, mock_mamba_state, caplog + ): + """Test: Hybrid model + paged attention logs a warning (PR #235). + + PR #235 added block-size translation to support hybrid + paged attention. + When paged attention is enabled for hybrid models, a warning should be + logged explaining the translation mechanism. 
+ """ + import logging + + with ( + patch("vllm.model_executor.models.ModelRegistry") as mock_registry, + patch("vllm_metal.config.get_config") as mock_get_config, + caplog.at_level(logging.WARNING), + ): + mock_model_cls = MagicMock() + mock_model_cls.get_mamba_state_shape_from_config.return_value = ( + mock_mamba_state["shape"] + ) + mock_model_cls.get_mamba_state_dtype_from_config.return_value = ( + mock_mamba_state["dtype"] + ) + mock_registry.resolve_model_cls.return_value = (mock_model_cls, None) + + # Mock metal config with paged attention enabled + mock_metal_config = MagicMock() + mock_metal_config.use_paged_attention = True + mock_get_config.return_value = mock_metal_config + + # Execute - should NOT raise, just log warning + MetalPlatform.update_block_size_for_backend(vllm_config) + + # Verify warning was logged with explanation + assert "block-size translation" in caplog.text + assert "PR #235" in caplog.text + assert "kernel blocks" in caplog.text + + +# ============================================================================ +# MLA Model Tests +# ============================================================================ + + +class TestMLAModels: + """Test suite for MLA (Multi-Token Latent Attention) model support.""" + + @pytest.fixture + def mla_cache_config(self): + """Create a CacheConfig mock for MLA models.""" + cache_config = MagicMock(spec=CacheConfig) + cache_config.block_size = 16 + cache_config.user_specified_block_size = False + cache_config.gpu_memory_utilization = 0.9 + cache_config.cache_dtype = "auto" + cache_config.mamba_cache_mode = "none" + cache_config.mamba_block_size = None + cache_config.mamba_page_size_padded = None + return cache_config + + @pytest.fixture + def mla_model_config(self): + """Create a ModelConfig mock for MLA models (e.g., DeepSeek).""" + import torch + + model_config = MagicMock(spec=ModelConfig) + model_config.is_hybrid = True + model_config.use_mla = True # MLA flag + model_config.is_deepseek_mla = True + model_config.architecture = "DeepSeekV2ForCausalLM" + model_config.dtype = torch.float16 + model_config.max_model_len = 512 + model_config.get_num_kv_heads.return_value = 8 + model_config.get_head_size.return_value = 128 + return model_config + + @pytest.fixture + def mla_vllm_config(self, mla_cache_config, mla_model_config): + """Create a complete VllmConfig for MLA hybrid model testing.""" + parallel_config = MagicMock(spec=ParallelConfig) + parallel_config.tensor_parallel_size = 1 + parallel_config.pipeline_parallel_size = 1 + + config = MagicMock(spec=VllmConfig) + config.model_config = mla_model_config + config.cache_config = mla_cache_config + config.parallel_config = parallel_config + return config + + @pytest.fixture + def mock_mla_mamba_state(self): + """Create mock mamba state shape and dtype for MLA models. + + Using shapes that result in different page sizes for MLA vs FullAttention. + MLA has different KV head dimensions which affects page_size_bytes. + """ + import torch + + return { + "shape": ( + (1, 3, 2048), # conv_shape (max_seqs=1) + (1, 4, 256, 128), # recurrent_shape - MLA uses different head dims + ), + "dtype": ( + torch.float32, # conv_dtype + torch.float32, # recurrent_dtype + ), + } + + def test_mla_hybrid_model_uses_mla_spec( + self, mla_vllm_config, mock_mla_mamba_state + ): + """Test: MLA + Hybrid model uses MLAAttentionSpec (not FullAttentionSpec). 
+
+        This test verifies that MLA models use MLAAttentionSpec for page size
+        calculation, i.e. that the implementation consults model_config.use_mla
+        when choosing the spec class.
+
+        Expected behavior:
+        - Check model_config.use_mla == True
+        - Use MLAAttentionSpec (which has different page_size calculation)
+        """
+        with patch("vllm.model_executor.models.ModelRegistry") as mock_registry:
+            mock_model_cls = MagicMock()
+            mock_model_cls.get_mamba_state_shape_from_config.return_value = (
+                mock_mla_mamba_state["shape"]
+            )
+            mock_model_cls.get_mamba_state_dtype_from_config.return_value = (
+                mock_mla_mamba_state["dtype"]
+            )
+            mock_registry.resolve_model_cls.return_value = (mock_model_cls, None)
+
+            # Mock to track which Spec class is used
+            # Patch at the vllm.v1.kv_cache_interface level where they're imported from
+            with (
+                patch("vllm.v1.kv_cache_interface.MLAAttentionSpec") as mock_mla_spec,
+                patch("vllm.v1.kv_cache_interface.FullAttentionSpec") as mock_full_spec,
+            ):
+                # Setup mock return values
+                mock_mla_spec_instance = MagicMock()
+                mock_mla_spec_instance.page_size_bytes = 4096  # MLA page size
+                mock_mla_spec.return_value = mock_mla_spec_instance
+
+                mock_full_spec_instance = MagicMock()
+                mock_full_spec_instance.page_size_bytes = (
+                    2048  # Different FullAttention page size
+                )
+                mock_full_spec.return_value = mock_full_spec_instance
+
+                # Execute
+                MetalPlatform.update_block_size_for_backend(mla_vllm_config)
+
+                # Verify: MLAAttentionSpec should be used for MLA models
+                assert mock_mla_spec.called, (
+                    "MLAAttentionSpec should be used for MLA models (use_mla=True)"
+                )
+                assert not mock_full_spec.called, (
+                    "FullAttentionSpec should NOT be used for MLA models"
+                )
+
+    def test_mla_non_hybrid_skipped(self, mla_vllm_config):
+        """Test: Pure MLA model (non-hybrid) skips the update.
+
+        When use_mla=True but is_hybrid=False, the method should return early
+        without modifying cache_config.
+
+        Expected behavior:
+        - is_hybrid = False triggers early return
+        - cache_config remains unchanged
+        """
+        mla_vllm_config.model_config.is_hybrid = False
+
+        original_block_size = mla_vllm_config.cache_config.block_size
+        original_mamba_page_size_padded = (
+            mla_vllm_config.cache_config.mamba_page_size_padded
+        )
+
+        # Execute
+        MetalPlatform.update_block_size_for_backend(mla_vllm_config)
+
+        # Verify: no changes
+        assert mla_vllm_config.cache_config.block_size == original_block_size
+        assert (
+            mla_vllm_config.cache_config.mamba_page_size_padded
+            == original_mamba_page_size_padded
+        )
+
+    @pytest.mark.parametrize("cache_dtype", ["bfloat16", "float16"])
+    def test_mla_with_cache_dtype(
+        self, mla_vllm_config, mock_mla_mamba_state, cache_dtype
+    ):
+        """Test: MLA model with different cache_dtype values.
+
+        This test verifies that cache_config.cache_dtype is properly handled
+        when computing page sizes for MLA models.
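+
+        Both parametrized dtypes are two-byte types, so (assuming only the dtype
+        mapping changes) the resulting page sizes should match the fp16 default.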
+
+        Expected behavior:
+        - cache_dtype is converted to torch.dtype correctly
+        - MLAAttentionSpec uses the correct dtype
+        - mamba_page_size_padded is set correctly
+        """
+        mla_vllm_config.cache_config.cache_dtype = cache_dtype
+
+        with patch("vllm.model_executor.models.ModelRegistry") as mock_registry:
+            mock_model_cls = MagicMock()
+            mock_model_cls.get_mamba_state_shape_from_config.return_value = (
+                mock_mla_mamba_state["shape"]
+            )
+            mock_model_cls.get_mamba_state_dtype_from_config.return_value = (
+                mock_mla_mamba_state["dtype"]
+            )
+            mock_registry.resolve_model_cls.return_value = (mock_model_cls, None)
+
+            # Execute (should not raise)
+            MetalPlatform.update_block_size_for_backend(mla_vllm_config)
+
+            # Verify
+            cache_config = mla_vllm_config.cache_config
+            assert cache_config.mamba_page_size_padded is not None, (
+                f"mamba_page_size_padded should be set for cache_dtype={cache_dtype}"
+            )
diff --git a/tests/test_prefix_cache.py b/tests/test_prefix_cache.py
index 7dff5520..1dde5ada 100644
--- a/tests/test_prefix_cache.py
+++ b/tests/test_prefix_cache.py
@@ -482,50 +482,50 @@ def _mock_device_info():
 class TestPrefixCacheFractionParsing:
     def test_valid_fraction(self, monkeypatch) -> None:
         monkeypatch.setenv("VLLM_METAL_PREFIX_CACHE_FRACTION", "0.1")
-        monkeypatch.setattr(mr.mx.metal, "device_info", _mock_device_info)
+        monkeypatch.setattr(mr.mx, "device_info", _mock_device_info)
         result = mr._get_prefix_cache_max_bytes()
         assert result == int(_TEN_GB * 0.1)

     def test_default_fraction(self, monkeypatch) -> None:
         monkeypatch.delenv("VLLM_METAL_PREFIX_CACHE_FRACTION", raising=False)
-        monkeypatch.setattr(mr.mx.metal, "device_info", _mock_device_info)
+        monkeypatch.setattr(mr.mx, "device_info", _mock_device_info)
         result = mr._get_prefix_cache_max_bytes()
         assert result == int(_TEN_GB * mr._PREFIX_CACHE_DEFAULT_FRACTION)

     def test_invalid_string_uses_default(self, monkeypatch) -> None:
         monkeypatch.setenv("VLLM_METAL_PREFIX_CACHE_FRACTION", "abc")
-        monkeypatch.setattr(mr.mx.metal, "device_info", _mock_device_info)
+        monkeypatch.setattr(mr.mx, "device_info", _mock_device_info)
         result = mr._get_prefix_cache_max_bytes()
         assert result == int(_TEN_GB * mr._PREFIX_CACHE_DEFAULT_FRACTION)

     def test_out_of_range_zero_uses_default(self, monkeypatch) -> None:
         monkeypatch.setenv("VLLM_METAL_PREFIX_CACHE_FRACTION", "0")
-        monkeypatch.setattr(mr.mx.metal, "device_info", _mock_device_info)
+        monkeypatch.setattr(mr.mx, "device_info", _mock_device_info)
         result = mr._get_prefix_cache_max_bytes()
         assert result == int(_TEN_GB * mr._PREFIX_CACHE_DEFAULT_FRACTION)

     def test_out_of_range_above_one_uses_default(self, monkeypatch) -> None:
         monkeypatch.setenv("VLLM_METAL_PREFIX_CACHE_FRACTION", "2")
-        monkeypatch.setattr(mr.mx.metal, "device_info", _mock_device_info)
+        monkeypatch.setattr(mr.mx, "device_info", _mock_device_info)
         result = mr._get_prefix_cache_max_bytes()
         assert result == int(_TEN_GB * mr._PREFIX_CACHE_DEFAULT_FRACTION)

     def test_nan_uses_default(self, monkeypatch) -> None:
         monkeypatch.setenv("VLLM_METAL_PREFIX_CACHE_FRACTION", "nan")
-        monkeypatch.setattr(mr.mx.metal, "device_info", _mock_device_info)
+        monkeypatch.setattr(mr.mx, "device_info", _mock_device_info)
         result = mr._get_prefix_cache_max_bytes()
         assert result == int(_TEN_GB * mr._PREFIX_CACHE_DEFAULT_FRACTION)

     def test_inf_uses_default(self, monkeypatch) -> None:
         monkeypatch.setenv("VLLM_METAL_PREFIX_CACHE_FRACTION", "inf")
-        monkeypatch.setattr(mr.mx.metal, "device_info", _mock_device_info)
+        monkeypatch.setattr(mr.mx, "device_info", _mock_device_info)
         result = mr._get_prefix_cache_max_bytes()
         assert result == int(_TEN_GB * mr._PREFIX_CACHE_DEFAULT_FRACTION)

     def test_device_info_fallback(self, monkeypatch) -> None:
         monkeypatch.delenv("VLLM_METAL_PREFIX_CACHE_FRACTION", raising=False)
         monkeypatch.setattr(
-            mr.mx.metal,
+            mr.mx,
             "device_info",
             lambda: {},
         )
diff --git a/tests/test_v1_worker.py b/tests/test_v1_worker.py
index e949038a..7b0185f0 100644
--- a/tests/test_v1_worker.py
+++ b/tests/test_v1_worker.py
@@ -19,6 +19,8 @@ def _make_worker(model_runner: object, *, use_paged_attention: bool) -> MetalWor
     worker = MetalWorker.__new__(MetalWorker)
     worker.model_runner = model_runner  # type: ignore[assignment]
     worker.metal_config = SimpleNamespace(use_paged_attention=use_paged_attention)
+    worker.cache_config = SimpleNamespace(block_size=16)
+    worker.vllm_config = SimpleNamespace(cache_config=worker.cache_config)
     return worker


@@ -88,19 +90,29 @@ def test_determine_available_memory_paged_capacity_mode(self) -> None:
         worker.get_cache_block_size_bytes.assert_called_once_with()

     def test_determine_available_memory_single_sequence_mode(self) -> None:
+        """Test MLX path returns one max-length sequence estimate (PR #229)."""
+        import mlx.core as mx
+
         model_runner = SimpleNamespace(
             scheduler_memory_reporting_mode=MagicMock(
                 return_value="single_sequence_estimate"
             ),
+            kv_cache_dtype=mx.float16,
+            is_hybrid=False,
+            is_mla=False,
+            num_layers=16,
+            num_kv_heads=8,
+            head_dim=128,
         )
         worker = _make_worker(model_runner, use_paged_attention=False)
-        worker._one_sequence_kv_bytes = MagicMock(return_value=4096)
         worker.model_config = SimpleNamespace(max_model_len=2048)

         available = MetalWorker.determine_available_memory(worker)

-        assert available == 4096
-        worker._one_sequence_kv_bytes.assert_called_once_with()
+        # Should return one max-length sequence KV cache bytes
+        # 2 (K+V) * 16 layers * 2048 tokens * 8 heads * 128 head_dim * 2 bytes
+        expected = 2 * 16 * 2048 * 8 * 128 * 2
+        assert available == expected

     def test_get_supported_tasks_delegates_to_runner_capability(self) -> None:
         model_runner = SimpleNamespace(
diff --git a/vllm_metal/platform.py b/vllm_metal/platform.py
index 5ca750b0..71f41201 100644
--- a/vllm_metal/platform.py
+++ b/vllm_metal/platform.py
@@ -301,6 +301,192 @@ def support_hybrid_kv_cache(cls) -> bool:
         """Metal supports hybrid KV cache for models like Qwen3.5 (SDPA + GDN)."""
         return True

+    @classmethod
+    def update_block_size_for_backend(
+        cls,
+        vllm_config,
+    ) -> None:
+        """Update block_size to unify page sizes for hybrid models.
+
+        Hybrid models (e.g., Qwen3.5) have two types of layers:
+        - SDPA layers: page_size scales with block_size
+        - Mamba/linear layers: page_size is fixed
+
+        vLLM requires the page sizes of all layer types to match. This method adjusts
+        block_size and sets mamba_page_size_padded to satisfy vLLM's validation.
+
+        Note:
+            This is a "logical" fix for vLLM's scheduler validation only.
+            The Metal plugin manages KV cache internally via MLX's make_prompt_cache(),
+            independent of vLLM's block_size and page_size calculations.
+            These parameters are used only to pass vLLM's initialization checks.
+
+        Steps:
+        1. Compute attention page size per token (MLAAttentionSpec or FullAttentionSpec)
+        2. Get Mamba page size from model class
+        3. Calculate block_size so SDPA page_size >= Mamba page_size
+        4. Sync mamba_block_size if using align mode
+        5. Pad mamba_page_size to match SDPA page_size exactly
+
+        Args:
+            vllm_config: vLLM configuration (modified in-place for vLLM validation)
+
+        Raises:
+            ValueError: If the computed mamba_page_size is zero (e.g., the mamba
+                state shape query returned empty shapes)
+            Exception: Model class resolution or mamba state query failures
+        """
+        from vllm.model_executor.models import ModelRegistry
+        from vllm.utils.math_utils import cdiv
+        from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
+        from vllm.v1.kv_cache_interface import (
+            FullAttentionSpec,
+            MambaSpec,
+            MLAAttentionSpec,
+        )
+
+        cache_config = vllm_config.cache_config
+        model_config = vllm_config.model_config
+
+        if not model_config:
+            return
+
+        # Skip non-hybrid models
+        is_hybrid = getattr(model_config, "is_hybrid", False)
+        if not is_hybrid:
+            return
+
+        # For hybrid models with paged attention, log a warning explaining the
+        # block-size translation mechanism.
+        #
+        # Background:
+        # - vLLM requires block_size=160 (or larger) for hybrid models to satisfy
+        #   page size divisibility validation between SDPA and Mamba layers.
+        # - Metal paged attention kernels only support block_size in {8, 16, 32}.
+        #
+        # Solution (PR #235):
+        # - vLLM sees a large block_size (e.g., 160) for its scheduler validation.
+        # - The Metal kernel uses a translated block_size (e.g., 32) that it supports.
+        # - Each vLLM block is split into ratio = cache_block_size / kernel_block_size
+        #   kernel blocks. For example, one vLLM block of 160 tokens becomes 5 kernel
+        #   blocks of 32 tokens each.
+        # - The KV cache is reshaped (zero-copy) to match: [num_blocks, 160, ...] →
+        #   [num_blocks*5, 32, ...]. The physical memory layout is unchanged.
+        # - Block tables are expanded so the kernel reads the correct blocks.
+        #
+        # This is a logical transformation only — the computation is identical, just
+        # the kernel sees more, smaller blocks.
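+        #
+        # Illustrative block-table expansion (hypothetical numbers, following the
+        # ratio=5 example above): a request whose vLLM block table is [7, 2] maps
+        # to kernel blocks [35, 36, 37, 38, 39, 10, 11, 12, 13, 14], i.e. each
+        # vLLM block index b expands to [b*5, b*5 + 1, ..., b*5 + 4].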
+        from vllm_metal.config import get_config
+
+        metal_config = get_config()
+        if metal_config.use_paged_attention:
+            logger.warning(
+                "Hybrid model (e.g., Qwen3.5) with paged attention enabled. "
+                "Using block-size translation (PR #235) to convert vLLM's large "
+                "block_size to a Metal kernel-compatible size.\n"
+                "  Mechanism: Each vLLM block is split into multiple kernel blocks.\n"
+                "  Example: vLLM block_size=160 → kernel block_size=32 (ratio=5).\n"
+                "  The KV cache is reshaped (zero-copy) and block tables are expanded.\n"
+                "  This is a logical transformation — physical memory is unchanged.\n"
+                "  Note: The default MLX path (without paged attention) is recommended "
+                "for hybrid models as it has no translation overhead."
+            )
+
+        # Step 1: Compute attention page size per token
+        # Handle cache_dtype conversion
+        if cache_config.cache_dtype == "auto":
+            kv_cache_dtype = model_config.dtype
+        else:
+            kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
+
+        # Use MLAAttentionSpec for MLA models, FullAttentionSpec otherwise
+        spec_class = (
+            MLAAttentionSpec
+            if getattr(model_config, "use_mla", False)
+            else FullAttentionSpec
+        )
+        attn_page_size_1_token = spec_class(
+            block_size=1,
+            num_kv_heads=model_config.get_num_kv_heads(vllm_config.parallel_config),
+            head_size=model_config.get_head_size(),
+            dtype=kv_cache_dtype,
+        ).page_size_bytes
+
+        # Step 2: Get Mamba page size (fixed, independent of block_size)
+        try:
+            model_cls, _ = ModelRegistry.resolve_model_cls(
+                model_config.architecture,
+                model_config=model_config,
+            )
+            mamba_state_shape = model_cls.get_mamba_state_shape_from_config(vllm_config)
+            mamba_state_dtype = model_cls.get_mamba_state_dtype_from_config(vllm_config)
+
+            mamba_page_size = MambaSpec(
+                shapes=mamba_state_shape,
+                dtypes=mamba_state_dtype,
+                block_size=-1,
+            ).page_size_bytes
+        except Exception as e:
+            # For hybrid models, re-raise exception instead of silently returning
+            logger.error(
+                "Failed to get mamba state for hybrid model %s: %s",
+                model_config.architecture,
+                e,
+            )
+            raise
+
+        if mamba_page_size == 0:
+            raise ValueError(
+                f"Computed mamba_page_size is zero for hybrid model "
+                f"{model_config.architecture}"
+            )
+
+        # Step 3: Calculate block_size so SDPA page_size >= Mamba page_size
+        # Use the same formula as vLLM's CPU platform for consistency
+        #
+        # Note: kernel_block_alignment_size=32 is chosen for Metal GPU performance.
+        # Common Metal threadgroup sizes are multiples of 32 (e.g., 32, 64, 128, 256).
+        # However, this value has no actual impact on MLX execution because:
+        # - MLX manages its own KV cache via make_prompt_cache()
+        # - This block_size is only used to satisfy vLLM's validation logic
+        # - The actual Metal kernel uses MLX's native memory layout
+        #
+        # Using 32 provides a reasonable balance between:
+        # - GPU performance (aligned to Metal threadgroup preferences)
+        # - Memory efficiency (not excessively large)
+        # - Compatibility with vLLM's page size unification requirements
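+        #
+        # Illustrative instance of the formula below (hypothetical sizes): with
+        # attn_page_size_1_token = 4096 bytes and mamba_page_size ~ 540 KB,
+        # attn_block_size = 32 * cdiv(540_000, 32 * 4096) = 32 * 5 = 160 tokens.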
+        kernel_block_alignment_size = 32  # Metal GPU kernel alignment
+        attn_block_size = kernel_block_alignment_size * cdiv(
+            mamba_page_size,
+            kernel_block_alignment_size * attn_page_size_1_token,
+        )
+
+        if cache_config.block_size < attn_block_size:
+            cache_config.block_size = attn_block_size
+            logger.info(
+                "Setting attention block size to %d tokens "
+                "to ensure that attention page size is >= mamba page size.",
+                attn_block_size,
+            )
+
+        # Step 4: Sync mamba_block_size if using align mode
+        if cache_config.mamba_cache_mode == "align":
+            cache_config.mamba_block_size = cache_config.block_size
+
+        # Step 5: Pad Mamba page size to exactly match SDPA page size
+        attn_page_size = cache_config.block_size * attn_page_size_1_token
+        if attn_page_size > mamba_page_size:
+            cache_config.mamba_page_size_padded = attn_page_size
+            mamba_padding_pct = (
+                100 * (attn_page_size - mamba_page_size) / mamba_page_size
+            )
+            logger.info(
+                "Padding mamba page size by %.2f%% to ensure "
+                "that mamba page size and attention page size are "
+                "exactly equal.",
+                mamba_padding_pct,
+            )
+
     @classmethod
     def get_attn_backend_cls(
         cls,
diff --git a/vllm_metal/v1/model_runner.py b/vllm_metal/v1/model_runner.py
index a5701a65..066320c3 100644
--- a/vllm_metal/v1/model_runner.py
+++ b/vllm_metal/v1/model_runner.py
@@ -125,7 +125,7 @@ def _get_prefix_cache_max_bytes() -> int:
     fallback_bytes = 8 * 1024 * 1024 * 1024  # 8 GB

     try:
-        device_info = mx.metal.device_info()
+        device_info = mx.device_info()
         total = int(device_info.get("max_recommended_working_set_size", 0))
     except (AttributeError, RuntimeError):
         total = 0
diff --git a/vllm_metal/v1/worker.py b/vllm_metal/v1/worker.py
index 5952a19c..57211298 100644
--- a/vllm_metal/v1/worker.py
+++ b/vllm_metal/v1/worker.py
@@ -143,17 +143,10 @@ def load_model(self) -> None:
         # Boundary ownership:
         # - Worker owns resource setup.
         # - Runner owns STT/runtime capability decisions.
-        # Hybrid models (Qwen3.5 SDPA+GDN) require paged attention for
-        # SDPA KV cache + GDN recurrent state management.
-        if not self.metal_config.use_paged_attention and self.model_runner.is_hybrid:
-            self.metal_config.use_paged_attention = True
-            # Prefix caching guard: check_and_update_config() skipped this
-            # because use_paged_attention was False at config time.
-            cache_config = self.vllm_config.cache_config
-            if getattr(cache_config, "enable_prefix_caching", False):
-                cache_config.enable_prefix_caching = False
-                logger.info("Metal: disabled prefix caching for hybrid model")
-            logger.info("Auto-enabled paged attention for hybrid model")
+        # Note: For hybrid models (Qwen3.5), we don't auto-enable paged attention
+        # because MLX's make_prompt_cache() handles hybrid KV cache natively.
+        # Paged attention for hybrid models requires splitting KV cache across
+        # multiple buffers to avoid Metal's max buffer size limit (~9.5 GB).
         if (
             self.metal_config.use_paged_attention
             and self.model_runner.should_setup_paged_attention()
@@ -430,10 +423,14 @@ def determine_available_memory(self) -> int:
             )
             return available

-        # Default MLX path: one max-length sequence for admission control.
+        # Default MLX path: report one max-length sequence for admission control.
+        # This matches the design from PR #229, which ensures the scheduler
+        # can admit at least one sequence without over-committing memory.
+        # MLX's make_prompt_cache() dynamically allocates KV cache per request,
+        # so we only need to report enough for one sequence.
         available = self._one_sequence_kv_bytes()
         logger.info(
-            "MLX path: reporting %.2fGB for scheduler admission control "
+            "MLX path: reporting %.2f GB for scheduler admission control "
             "(one max-length sequence, max_model_len=%d)",
             available / 1e9,
             self.model_config.max_model_len,