diff --git a/tests/ut/batch_invariant/test_batch_invariant.py b/tests/ut/batch_invariant/test_batch_invariant.py index 03dc9d59844..bb9073ac384 100644 --- a/tests/ut/batch_invariant/test_batch_invariant.py +++ b/tests/ut/batch_invariant/test_batch_invariant.py @@ -127,6 +127,7 @@ def test_init_batch_invariance(self, batch_invariant_enabled, has_backend, expec """Test init_batch_invariance under different conditions""" # Mock dependencies import vllm.envs as envs + envs.VLLM_BATCH_INVARIANT = batch_invariant_enabled batch_invariant.HAS_TRITON = has_backend batch_invariant.HAS_ASCENDC_BATCH_INVARIANT = has_backend @@ -151,11 +152,8 @@ def test_init_batch_invariance(self, batch_invariant_enabled, has_backend, expec def test_add_rms_norm(self, mock_torch_npu): """Test add_rms_norm function""" # Mock dependencies - mock_torch = batch_invariant.torch # Create mock tensors - batch_size = 2 - hidden_size = 4 x = MagicMock(spec=torch.Tensor) residual = MagicMock(spec=torch.Tensor) weight = MagicMock(spec=torch.Tensor) @@ -187,8 +185,6 @@ def test_add_rms_norm(self, mock_torch_npu): def test_add_rms_norm_consistency(self, mock_torch_npu): """Test that add_rms_norm produces the same output as torch_npu.npu_add_rms_norm""" # Create mock tensors - batch_size = 2 - hidden_size = 4 x = MagicMock(spec=torch.Tensor) residual = MagicMock(spec=torch.Tensor) weight = MagicMock(spec=torch.Tensor) diff --git a/tests/ut/compilation/test_acl_graph.py b/tests/ut/compilation/test_acl_graph.py index 828c6e39de6..8be495627d6 100644 --- a/tests/ut/compilation/test_acl_graph.py +++ b/tests/ut/compilation/test_acl_graph.py @@ -22,21 +22,22 @@ from vllm.forward_context import BatchDescriptor, ForwardContext from tests.ut.base import TestBase -from vllm_ascend.attention.attention_v1 import (AscendMetadata, - AscendMetadataForDecode) -from vllm_ascend.attention.context_parallel.attention_cp import \ - AscendAttentionCPImpl +from vllm_ascend.attention.attention_v1 import AscendMetadata, AscendMetadataForDecode +from vllm_ascend.attention.context_parallel.attention_cp import AscendAttentionCPImpl from vllm_ascend.attention.context_parallel.mla_cp import AscendMlaCPImpl -from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata, - AscendMLAMetadata) +from vllm_ascend.attention.mla_v1 import AscendMLADecodeMetadata, AscendMLAMetadata from vllm_ascend.compilation.acl_graph import ( - ACLGraphEntry, ACLGraphWrapper, get_draft_graph_params, get_graph_params, - set_draft_graph_params, set_graph_params, - update_draft_graph_params_workspaces) + ACLGraphEntry, + ACLGraphWrapper, + get_draft_graph_params, + get_graph_params, + set_draft_graph_params, + set_graph_params, + update_draft_graph_params_workspaces, +) class TestACLGraphEntry(TestBase): - def test_aclgraph_entry_initialization(self): """Test ACLGraphEntry initialization with default values""" batch_descriptor = BatchDescriptor( @@ -62,10 +63,9 @@ def test_aclgraph_entry_with_values(self): mock_output = MagicMock() input_addresses = [12345, 67890] - entry = ACLGraphEntry(batch_descriptor=batch_descriptor, - aclgraph=mock_graph, - output=mock_output, - input_addresses=input_addresses) + entry = ACLGraphEntry( + batch_descriptor=batch_descriptor, aclgraph=mock_graph, output=mock_output, input_addresses=input_addresses + ) self.assertEqual(entry.batch_descriptor, batch_descriptor) self.assertEqual(entry.aclgraph, mock_graph) @@ -74,7 +74,6 @@ def test_aclgraph_entry_with_values(self): class TestACLGraphWrapper(TestBase): - def setUp(self): """Set up test fixtures""" 
super().setUp() @@ -106,17 +105,16 @@ def setUp(self): self.mock_forward_context.batch_descriptor = self.mock_batch_descriptor self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL - @patch('vllm_ascend.compilation.acl_graph.current_platform') - @patch('vllm_ascend.compilation.acl_graph.envs') - def test_initialization_with_default_options(self, mock_envs, - mock_current_platform): + @patch("vllm_ascend.compilation.acl_graph.current_platform") + @patch("vllm_ascend.compilation.acl_graph.envs") + def test_initialization_with_default_options(self, mock_envs, mock_current_platform): """Test ACLGraphWrapper initialization with default CUDAGraphOptions""" mock_envs.VLLM_LOGGING_LEVEL = "INFO" mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool - wrapper = ACLGraphWrapper(runnable=self.mock_runnable, - vllm_config=self.mock_vllm_config, - runtime_mode=CUDAGraphMode.FULL) + wrapper = ACLGraphWrapper( + runnable=self.mock_runnable, vllm_config=self.mock_vllm_config, runtime_mode=CUDAGraphMode.FULL + ) self.assertEqual(wrapper.runnable, self.mock_runnable) self.assertEqual(wrapper.vllm_config, self.mock_vllm_config) @@ -126,10 +124,9 @@ def test_initialization_with_default_options(self, mock_envs, self.assertIsInstance(wrapper.aclgraph_options, CUDAGraphOptions) self.assertEqual(wrapper.concrete_aclgraph_entries, {}) - @patch('vllm_ascend.compilation.acl_graph.current_platform') - @patch('vllm_ascend.compilation.acl_graph.envs') - def test_initialization_with_custom_options(self, mock_envs, - mock_current_platform): + @patch("vllm_ascend.compilation.acl_graph.current_platform") + @patch("vllm_ascend.compilation.acl_graph.envs") + def test_initialization_with_custom_options(self, mock_envs, mock_current_platform): """Test ACLGraphWrapper initialization with custom CUDAGraphOptions""" mock_envs.VLLM_LOGGING_LEVEL = "DEBUG" mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool @@ -138,7 +135,8 @@ def test_initialization_with_custom_options(self, mock_envs, runnable=self.mock_runnable, vllm_config=self.mock_vllm_config, runtime_mode=CUDAGraphMode.FULL, - cudagraph_options=self.mock_cudagraph_options) + cudagraph_options=self.mock_cudagraph_options, + ) self.assertEqual(wrapper.runnable, self.mock_runnable) self.assertEqual(wrapper.vllm_config, self.mock_vllm_config) @@ -148,26 +146,25 @@ def test_initialization_with_custom_options(self, mock_envs, self.assertEqual(wrapper.aclgraph_options, self.mock_cudagraph_options) self.assertEqual(wrapper.concrete_aclgraph_entries, {}) - @patch('vllm_ascend.compilation.acl_graph.current_platform') - @patch('vllm_ascend.compilation.acl_graph.envs') - def test_initialization_assertion_error(self, mock_envs, - mock_current_platform): + @patch("vllm_ascend.compilation.acl_graph.current_platform") + @patch("vllm_ascend.compilation.acl_graph.envs") + def test_initialization_assertion_error(self, mock_envs, mock_current_platform): """Test ACLGraphWrapper initialization raises AssertionError for NONE mode""" mock_envs.VLLM_LOGGING_LEVEL = "INFO" mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool with self.assertRaises(AssertionError): - ACLGraphWrapper(runnable=self.mock_runnable, - vllm_config=self.mock_vllm_config, - runtime_mode=CUDAGraphMode.NONE) - - @patch('vllm_ascend.ascend_forward_context.get_forward_context') - @patch('vllm_ascend.compilation.acl_graph.get_forward_context') - @patch('vllm_ascend.compilation.acl_graph.current_platform') - 
@patch('vllm_ascend.compilation.acl_graph.envs') - def test_call_with_none_runtime_mode(self, mock_envs, - mock_current_platform, - mock_get_forward_context, mock_get_forward_context_2): + ACLGraphWrapper( + runnable=self.mock_runnable, vllm_config=self.mock_vllm_config, runtime_mode=CUDAGraphMode.NONE + ) + + @patch("vllm_ascend.ascend_forward_context.get_forward_context") + @patch("vllm_ascend.compilation.acl_graph.get_forward_context") + @patch("vllm_ascend.compilation.acl_graph.current_platform") + @patch("vllm_ascend.compilation.acl_graph.envs") + def test_call_with_none_runtime_mode( + self, mock_envs, mock_current_platform, mock_get_forward_context, mock_get_forward_context_2 + ): """Test __call__ method when runtime mode is NONE""" mock_envs.VLLM_LOGGING_LEVEL = "INFO" mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool @@ -178,7 +175,8 @@ def test_call_with_none_runtime_mode(self, mock_envs, runnable=self.mock_runnable, vllm_config=self.mock_vllm_config, runtime_mode=CUDAGraphMode.FULL, - cudagraph_options=self.mock_cudagraph_options) + cudagraph_options=self.mock_cudagraph_options, + ) result = wrapper("arg1", "arg2") @@ -186,14 +184,13 @@ def test_call_with_none_runtime_mode(self, mock_envs, self.mock_runnable.assert_called_once_with("arg1", "arg2") self.assertEqual(result, "test_output") - @patch('vllm_ascend.ascend_forward_context.get_forward_context') - @patch('vllm_ascend.compilation.acl_graph.get_forward_context') - @patch('vllm_ascend.compilation.acl_graph.current_platform') - @patch('vllm_ascend.compilation.acl_graph.envs') - def test_call_with_mismatched_runtime_mode(self, mock_envs, - mock_current_platform, - mock_get_forward_context, - mock_get_forward_context_2): + @patch("vllm_ascend.ascend_forward_context.get_forward_context") + @patch("vllm_ascend.compilation.acl_graph.get_forward_context") + @patch("vllm_ascend.compilation.acl_graph.current_platform") + @patch("vllm_ascend.compilation.acl_graph.envs") + def test_call_with_mismatched_runtime_mode( + self, mock_envs, mock_current_platform, mock_get_forward_context, mock_get_forward_context_2 + ): """Test __call__ method when runtime mode doesn't match wrapper mode""" mock_envs.VLLM_LOGGING_LEVEL = "INFO" mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool @@ -205,7 +202,8 @@ def test_call_with_mismatched_runtime_mode(self, mock_envs, runnable=self.mock_runnable, vllm_config=self.mock_vllm_config, runtime_mode=CUDAGraphMode.FULL, - cudagraph_options=self.mock_cudagraph_options) + cudagraph_options=self.mock_cudagraph_options, + ) result = wrapper("arg1", "arg2") @@ -213,20 +211,25 @@ def test_call_with_mismatched_runtime_mode(self, mock_envs, self.mock_runnable.assert_called_once_with("arg1", "arg2") self.assertEqual(result, "test_output") - @patch('vllm_ascend.compilation.acl_graph.torch') - @patch( - 'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled' - ) - @patch('vllm_ascend.compilation.acl_graph.get_forward_context') - @patch('vllm_ascend.ascend_forward_context.get_forward_context') - @patch('vllm_ascend.compilation.acl_graph.current_platform') - @patch('vllm_ascend.compilation.acl_graph.envs') - @patch('vllm_ascend.compilation.acl_graph.compilation_counter') - @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors') + @patch("vllm_ascend.compilation.acl_graph.torch") + @patch("vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled") + @patch("vllm_ascend.compilation.acl_graph.get_forward_context") + 
@patch("vllm_ascend.ascend_forward_context.get_forward_context") + @patch("vllm_ascend.compilation.acl_graph.current_platform") + @patch("vllm_ascend.compilation.acl_graph.envs") + @patch("vllm_ascend.compilation.acl_graph.compilation_counter") + @patch("vllm_ascend.compilation.acl_graph.weak_ref_tensors") def test_call_capture_graph_first_time( - self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs, - mock_current_platform, mock_get_forward_context,mock_get_forward_context_2, - mock_validate_cudagraph_capturing_enabled, mock_torch): + self, + mock_weak_ref_tensors, + mock_compilation_counter, + mock_envs, + mock_current_platform, + mock_get_forward_context, + mock_get_forward_context_2, + mock_validate_cudagraph_capturing_enabled, + mock_torch, + ): """Test __call__ method captures graph for the first time""" mock_envs.VLLM_LOGGING_LEVEL = "INFO" mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool @@ -257,7 +260,8 @@ def test_call_capture_graph_first_time( runnable=self.mock_runnable, vllm_config=self.mock_vllm_config, runtime_mode=CUDAGraphMode.FULL, - cudagraph_options=self.mock_cudagraph_options) + cudagraph_options=self.mock_cudagraph_options, + ) # Create a real torch tensor for the test, not a mock test_tensor = torch.tensor([1, 2, 3]) @@ -268,13 +272,11 @@ def test_call_capture_graph_first_time( # Verify graph capture happened mock_validate_cudagraph_capturing_enabled.assert_called_once() mock_torch.npu.NPUGraph.assert_called_once() - mock_torch.npu.graph.assert_called_once_with(mock_npu_graph, - pool=self.mock_graph_pool) + mock_torch.npu.graph.assert_called_once_with(mock_npu_graph, pool=self.mock_graph_pool) self.mock_runnable.assert_called_once_with(test_tensor, "arg2") # Verify the entry was created and updated - self.assertIn(self.mock_batch_descriptor, - wrapper.concrete_aclgraph_entries) + self.assertIn(self.mock_batch_descriptor, wrapper.concrete_aclgraph_entries) entry = wrapper.concrete_aclgraph_entries[self.mock_batch_descriptor] self.assertEqual(entry.aclgraph, mock_npu_graph) self.assertEqual(entry.output, "weak_ref_output") @@ -285,22 +287,25 @@ def test_call_capture_graph_first_time( # Should return the original output (not weak ref) self.assertEqual(result, "test_output") - @patch('vllm_ascend.compilation.acl_graph.torch') - @patch( - 'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled' - ) - @patch('vllm_ascend.compilation.acl_graph.get_forward_context') - @patch('vllm_ascend.ascend_forward_context.get_forward_context') - @patch('vllm_ascend.compilation.acl_graph.current_platform') - @patch('vllm_ascend.compilation.acl_graph.envs') - @patch('vllm_ascend.compilation.acl_graph.compilation_counter') - @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors') - def test_call_replay_graph(self, mock_weak_ref_tensors, - mock_compilation_counter, mock_envs, - mock_current_platform, mock_get_forward_context, - mock_get_forward_context_2, - mock_validate_cudagraph_capturing_enabled, - mock_torch): + @patch("vllm_ascend.compilation.acl_graph.torch") + @patch("vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled") + @patch("vllm_ascend.compilation.acl_graph.get_forward_context") + @patch("vllm_ascend.ascend_forward_context.get_forward_context") + @patch("vllm_ascend.compilation.acl_graph.current_platform") + @patch("vllm_ascend.compilation.acl_graph.envs") + @patch("vllm_ascend.compilation.acl_graph.compilation_counter") + @patch("vllm_ascend.compilation.acl_graph.weak_ref_tensors") + def 
test_call_replay_graph( + self, + mock_weak_ref_tensors, + mock_compilation_counter, + mock_envs, + mock_current_platform, + mock_get_forward_context, + mock_get_forward_context_2, + mock_validate_cudagraph_capturing_enabled, + mock_torch, + ): """Test __call__ method replays graph when already captured""" mock_envs.VLLM_LOGGING_LEVEL = "INFO" mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool @@ -333,7 +338,8 @@ def test_call_replay_graph(self, mock_weak_ref_tensors, runnable=self.mock_runnable, vllm_config=self.mock_vllm_config, runtime_mode=CUDAGraphMode.FULL, - cudagraph_options=self.mock_cudagraph_options) + cudagraph_options=self.mock_cudagraph_options, + ) # Create a real torch tensor for the test, not a mock test_tensor = torch.tensor([1, 2, 3]) @@ -363,19 +369,23 @@ def test_call_replay_graph(self, mock_weak_ref_tensors, self.assertEqual(first_result, "test_output") # Original output self.assertEqual(second_result, "weak_ref_output") # Weak ref output - @patch('vllm_ascend.compilation.acl_graph.torch') - @patch( - 'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled' - ) - @patch('vllm_ascend.compilation.acl_graph.get_forward_context') - @patch('vllm_ascend.ascend_forward_context.get_forward_context') - @patch('vllm_ascend.compilation.acl_graph.current_platform') - @patch('vllm_ascend.compilation.acl_graph.envs') - @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors') + @patch("vllm_ascend.compilation.acl_graph.torch") + @patch("vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled") + @patch("vllm_ascend.compilation.acl_graph.get_forward_context") + @patch("vllm_ascend.ascend_forward_context.get_forward_context") + @patch("vllm_ascend.compilation.acl_graph.current_platform") + @patch("vllm_ascend.compilation.acl_graph.envs") + @patch("vllm_ascend.compilation.acl_graph.weak_ref_tensors") def test_call_with_debug_mode_input_address_check( - self, mock_weak_ref_tensors, mock_envs, mock_current_platform, - mock_get_forward_context,mock_get_forward_context_2, - mock_validate_cudagraph_capturing_enabled, mock_torch): + self, + mock_weak_ref_tensors, + mock_envs, + mock_current_platform, + mock_get_forward_context, + mock_get_forward_context_2, + mock_validate_cudagraph_capturing_enabled, + mock_torch, + ): """Test __call__ method with debug mode input address checking""" mock_envs.VLLM_LOGGING_LEVEL = "DEBUG" # Enable debug mode mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool @@ -408,7 +418,8 @@ def test_call_with_debug_mode_input_address_check( runnable=self.mock_runnable, vllm_config=self.mock_vllm_config, runtime_mode=CUDAGraphMode.FULL, - cudagraph_options=self.mock_cudagraph_options) + cudagraph_options=self.mock_cudagraph_options, + ) # First call to capture the graph tensor = torch.tensor([1, 2, 3]) # Create tensor once @@ -420,19 +431,23 @@ def test_call_with_debug_mode_input_address_check( # Should not raise AssertionError self.assertTrue(True) - @patch('vllm_ascend.compilation.acl_graph.torch') - @patch( - 'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled' - ) - @patch('vllm_ascend.compilation.acl_graph.get_forward_context') - @patch('vllm_ascend.ascend_forward_context.get_forward_context') - @patch('vllm_ascend.compilation.acl_graph.current_platform') - @patch('vllm_ascend.compilation.acl_graph.envs') - @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors') + @patch("vllm_ascend.compilation.acl_graph.torch") + 
@patch("vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled") + @patch("vllm_ascend.compilation.acl_graph.get_forward_context") + @patch("vllm_ascend.ascend_forward_context.get_forward_context") + @patch("vllm_ascend.compilation.acl_graph.current_platform") + @patch("vllm_ascend.compilation.acl_graph.envs") + @patch("vllm_ascend.compilation.acl_graph.weak_ref_tensors") def test_call_with_debug_mode_input_address_mismatch( - self, mock_weak_ref_tensors, mock_envs, mock_current_platform, - mock_get_forward_context,mock_get_forward_context_2, - mock_validate_cudagraph_capturing_enabled, mock_torch): + self, + mock_weak_ref_tensors, + mock_envs, + mock_current_platform, + mock_get_forward_context, + mock_get_forward_context_2, + mock_validate_cudagraph_capturing_enabled, + mock_torch, + ): """Test __call__ method with debug mode input address mismatch raises AssertionError""" mock_envs.VLLM_LOGGING_LEVEL = "DEBUG" # Enable debug mode mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool @@ -464,37 +479,42 @@ def test_call_with_debug_mode_input_address_mismatch( runnable=self.mock_runnable, vllm_config=self.mock_vllm_config, runtime_mode=CUDAGraphMode.FULL, - cudagraph_options=self.mock_cudagraph_options) + cudagraph_options=self.mock_cudagraph_options, + ) # First call to capture the graph tensor1 = torch.tensor([1, 2, 3]) _ = wrapper(tensor1, "arg2") # Second call with different tensor addresses should raise AssertionError - tensor2 = torch.tensor([4, 5, - 6]) # Different values, different address + tensor2 = torch.tensor([4, 5, 6]) # Different values, different address with self.assertRaises(AssertionError) as context: wrapper(tensor2, "arg2") - self.assertIn("Input addresses for aclgraphs are different", - str(context.exception)) - - @patch('vllm_ascend.compilation.acl_graph.torch') - @patch( - 'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled' - ) - @patch('vllm_ascend.compilation.acl_graph.get_forward_context') - @patch('vllm_ascend.ascend_forward_context.get_forward_context') - @patch('vllm_ascend.compilation.acl_graph.current_platform') - @patch('vllm_ascend.compilation.acl_graph.envs') - @patch('vllm_ascend.compilation.acl_graph.compilation_counter') - @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors') - @patch('vllm_ascend.compilation.acl_graph.patch') + self.assertIn("Input addresses for aclgraphs are different", str(context.exception)) + + @patch("vllm_ascend.compilation.acl_graph.torch") + @patch("vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled") + @patch("vllm_ascend.compilation.acl_graph.get_forward_context") + @patch("vllm_ascend.ascend_forward_context.get_forward_context") + @patch("vllm_ascend.compilation.acl_graph.current_platform") + @patch("vllm_ascend.compilation.acl_graph.envs") + @patch("vllm_ascend.compilation.acl_graph.compilation_counter") + @patch("vllm_ascend.compilation.acl_graph.weak_ref_tensors") + @patch("vllm_ascend.compilation.acl_graph.patch") def test_call_capture_graph_with_gc_disable( - self, mock_patch, mock_weak_ref_tensors, mock_compilation_counter, - mock_envs, mock_current_platform, mock_get_forward_context,mock_get_forward_context_2, - mock_validate_cudagraph_capturing_enabled, mock_torch): + self, + mock_patch, + mock_weak_ref_tensors, + mock_compilation_counter, + mock_envs, + mock_current_platform, + mock_get_forward_context, + mock_get_forward_context_2, + mock_validate_cudagraph_capturing_enabled, + mock_torch, + ): """Test __call__ method 
captures graph with gc_disable option enabled""" mock_envs.VLLM_LOGGING_LEVEL = "INFO" mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool @@ -536,7 +556,8 @@ def test_call_capture_graph_with_gc_disable( runnable=self.mock_runnable, vllm_config=self.mock_vllm_config, runtime_mode=CUDAGraphMode.FULL, - cudagraph_options=self.mock_cudagraph_options) + cudagraph_options=self.mock_cudagraph_options, + ) # Create a real torch tensor for the test, not a mock test_tensor = torch.tensor([1, 2, 3]) @@ -550,26 +571,30 @@ def test_call_capture_graph_with_gc_disable( # Verify graph capture happened mock_validate_cudagraph_capturing_enabled.assert_called_once() mock_torch.npu.NPUGraph.assert_called_once() - mock_torch.npu.graph.assert_called_once_with(mock_npu_graph, - pool=self.mock_graph_pool) + mock_torch.npu.graph.assert_called_once_with(mock_npu_graph, pool=self.mock_graph_pool) # Should return the original output (not weak ref) since weak_ref_output is not enabled self.assertEqual(result, "test_output") - @patch('vllm_ascend.compilation.acl_graph.torch') - @patch( - 'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled' - ) - @patch('vllm_ascend.compilation.acl_graph.get_forward_context') - @patch('vllm_ascend.ascend_forward_context.get_forward_context') - @patch('vllm_ascend.compilation.acl_graph.current_platform') - @patch('vllm_ascend.compilation.acl_graph.envs') - @patch('vllm_ascend.compilation.acl_graph.compilation_counter') - @patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors') + @patch("vllm_ascend.compilation.acl_graph.torch") + @patch("vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled") + @patch("vllm_ascend.compilation.acl_graph.get_forward_context") + @patch("vllm_ascend.ascend_forward_context.get_forward_context") + @patch("vllm_ascend.compilation.acl_graph.current_platform") + @patch("vllm_ascend.compilation.acl_graph.envs") + @patch("vllm_ascend.compilation.acl_graph.compilation_counter") + @patch("vllm_ascend.compilation.acl_graph.weak_ref_tensors") def test_call_capture_graph_with_weak_ref_output( - self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs, - mock_current_platform, mock_get_forward_context,mock_get_forward_context_2, - mock_validate_cudagraph_capturing_enabled, mock_torch): + self, + mock_weak_ref_tensors, + mock_compilation_counter, + mock_envs, + mock_current_platform, + mock_get_forward_context, + mock_get_forward_context_2, + mock_validate_cudagraph_capturing_enabled, + mock_torch, + ): """Test __call__ method captures graph with weak_ref_output option enabled""" mock_envs.VLLM_LOGGING_LEVEL = "INFO" mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool @@ -593,9 +618,7 @@ def test_call_capture_graph_with_weak_ref_output( # Mock weak_ref_tensors to simulate the actual behavior: # 1. First call (inside the graph context with weak_ref_output=True) should return "weak_ref_output" # 2. 
Second call (for entry.output) should return "weak_ref_output" - mock_weak_ref_tensors.side_effect = [ - "weak_ref_output", "weak_ref_output" - ] + mock_weak_ref_tensors.side_effect = ["weak_ref_output", "weak_ref_output"] # Ensure torch.Tensor can be correctly identified by isinstance mock_torch.Tensor = torch.Tensor @@ -607,7 +630,8 @@ def test_call_capture_graph_with_weak_ref_output( runnable=self.mock_runnable, vllm_config=self.mock_vllm_config, runtime_mode=CUDAGraphMode.FULL, - cudagraph_options=self.mock_cudagraph_options) + cudagraph_options=self.mock_cudagraph_options, + ) # Create a real torch tensor for the test, not a mock test_tensor = torch.tensor([1, 2, 3]) @@ -621,20 +645,19 @@ def test_call_capture_graph_with_weak_ref_output( # Verify graph capture happened mock_validate_cudagraph_capturing_enabled.assert_called_once() mock_torch.npu.NPUGraph.assert_called_once() - mock_torch.npu.graph.assert_called_once_with(mock_npu_graph, - pool=self.mock_graph_pool) + mock_torch.npu.graph.assert_called_once_with(mock_npu_graph, pool=self.mock_graph_pool) # Should return the weak ref output when weak_ref_output option is enabled self.assertEqual(result, "weak_ref_output") - - @patch('vllm_ascend.compilation.acl_graph.get_forward_context') - @patch('vllm_ascend.ascend_forward_context.get_forward_context') - @patch('vllm_ascend.compilation.acl_graph.current_platform') - @patch('vllm_ascend.compilation.acl_graph.envs') - @patch('vllm_ascend.compilation.acl_graph.logger') - def test_call_capture_graph_with_debug_log(self, mock_logger, mock_envs, - mock_current_platform, - mock_get_forward_context,mock_get_forward_context_2): + + @patch("vllm_ascend.compilation.acl_graph.get_forward_context") + @patch("vllm_ascend.ascend_forward_context.get_forward_context") + @patch("vllm_ascend.compilation.acl_graph.current_platform") + @patch("vllm_ascend.compilation.acl_graph.envs") + @patch("vllm_ascend.compilation.acl_graph.logger") + def test_call_capture_graph_with_debug_log( + self, mock_logger, mock_envs, mock_current_platform, mock_get_forward_context, mock_get_forward_context_2 + ): """Test __call__ method captures graph with debug logging enabled""" mock_envs.VLLM_LOGGING_LEVEL = "INFO" mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool @@ -647,7 +670,7 @@ def test_call_capture_graph_with_debug_log(self, mock_logger, mock_envs, # weak_ref_output is not enabled by default # Mock torch - with patch('vllm_ascend.compilation.acl_graph.torch') as mock_torch: + with patch("vllm_ascend.compilation.acl_graph.torch") as mock_torch: # Mock torch.npu.NPUGraph mock_npu_graph = MagicMock() mock_torch.npu.NPUGraph.return_value = mock_npu_graph @@ -662,24 +685,20 @@ def test_call_capture_graph_with_debug_log(self, mock_logger, mock_envs, mock_torch.Tensor = torch.Tensor # Mock weak_ref_tensors - with patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors' - ) as mock_weak_ref_tensors: + with patch("vllm_ascend.compilation.acl_graph.weak_ref_tensors") as mock_weak_ref_tensors: # Mock weak_ref_tensors to simulate the actual behavior: # 1. First call (inside the graph context) should return "inner_output" # 2. 
Second call (for entry.output) should return "weak_ref_output" - mock_weak_ref_tensors.side_effect = [ - "inner_output", "weak_ref_output" - ] + mock_weak_ref_tensors.side_effect = ["inner_output", "weak_ref_output"] # Mock validate_cudagraph_capturing_enabled - with patch( - 'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled' - ): + with patch("vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled"): wrapper = ACLGraphWrapper( runnable=self.mock_runnable, vllm_config=self.mock_vllm_config, runtime_mode=CUDAGraphMode.FULL, - cudagraph_options=self.mock_cudagraph_options) + cudagraph_options=self.mock_cudagraph_options, + ) # Create a real torch tensor for the test, not a mock test_tensor = torch.tensor([1, 2, 3]) @@ -699,7 +718,8 @@ def test_getattr_access_runnable_attributes(self): runnable=mock_runnable, vllm_config=self.mock_vllm_config, runtime_mode=CUDAGraphMode.FULL, - cudagraph_options=self.mock_cudagraph_options) + cudagraph_options=self.mock_cudagraph_options, + ) # Should be able to access attributes of the runnable self.assertEqual(wrapper.test_attr, "test_value") @@ -717,14 +737,14 @@ class EmptyRunnable: runnable=mock_runnable, vllm_config=self.mock_vllm_config, runtime_mode=CUDAGraphMode.FULL, - cudagraph_options=self.mock_cudagraph_options) + cudagraph_options=self.mock_cudagraph_options, + ) # Should raise AttributeError for non-existent attributes with self.assertRaises(AttributeError) as context: _ = wrapper.non_existent_attr - self.assertIn("Attribute non_existent_attr not found", - str(context.exception)) + self.assertIn("Attribute non_existent_attr not found", str(context.exception)) def test_unwrap_method(self): """Test unwrap method returns the original runnable""" @@ -732,36 +752,34 @@ def test_unwrap_method(self): runnable=self.mock_runnable, vllm_config=self.mock_vllm_config, runtime_mode=CUDAGraphMode.FULL, - cudagraph_options=self.mock_cudagraph_options) + cudagraph_options=self.mock_cudagraph_options, + ) unwrapped = wrapper.unwrap() self.assertEqual(unwrapped, self.mock_runnable) class TestDraftGraphParams(TestBase): - def test_set_draft_graph_params(self): - with patch('vllm_ascend.compilation.acl_graph._draft_graph_params', - new=None): + with patch("vllm_ascend.compilation.acl_graph._draft_graph_params", new=None): set_draft_graph_params([4]) from vllm_ascend.compilation.acl_graph import _draft_graph_params + self.assertIsNotNone(_draft_graph_params) - @patch('vllm_ascend.compilation.acl_graph._draft_graph_params') - def test_update_draft_graph_params_workspaces(self, - draft_graph_params_mock): + @patch("vllm_ascend.compilation.acl_graph._draft_graph_params") + def test_update_draft_graph_params_workspaces(self, draft_graph_params_mock): draft_graph_params_mock.workspaces = {4: 5} update_draft_graph_params_workspaces(4, 6) self.assertEqual(draft_graph_params_mock.workspaces[4], 6) - @patch('vllm_ascend.compilation.acl_graph._draft_graph_params') + @patch("vllm_ascend.compilation.acl_graph._draft_graph_params") def test_get_draft_graph_params(self, draft_graph_params_mock): graph_params = get_draft_graph_params() self.assertIs(draft_graph_params_mock, graph_params) class TestPCPDCPGraphParams(TestBase): - def setUp(self): self.update_stream = MagicMock(name="FakeStream") graph_params = get_graph_params() @@ -777,10 +795,12 @@ def setUp(self): self.graph_params.events[4].append(mock_event) self.graph_params.handles[4].append(MagicMock()) - @patch('vllm_ascend.ascend_forward_context.get_forward_context') - 
@patch('torch.npu.graph_task_update_end', ) - @patch('torch.npu.graph_task_update_begin', MagicMock()) - @patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock()) + @patch("vllm_ascend.ascend_forward_context.get_forward_context") + @patch( + "torch.npu.graph_task_update_end", + ) + @patch("torch.npu.graph_task_update_begin", MagicMock()) + @patch("torch_npu.npu_fused_infer_attention_score.out", MagicMock()) def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end, mock_context): input_positions = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) block_table = torch.zeros(2, 5, dtype=torch.long) @@ -792,23 +812,12 @@ def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end, mock_context): query_start_loc = torch.tensor([0, 4]) block_tables = torch.zeros(2, 5, dtype=torch.long) - decode = AscendMLADecodeMetadata(input_positions, - block_table, - seq_lens, - max_seq_lens, - seq_lens_list, - cp_seq_len=cp_seq_len) - metadata = AscendMLAMetadata(8, - 8, - slot_mapping, - query_start_loc, - seq_lens, - seq_lens, - block_tables, - 4, - 4, - 0, - decode=decode) + decode = AscendMLADecodeMetadata( + input_positions, block_table, seq_lens, max_seq_lens, seq_lens_list, cp_seq_len=cp_seq_len + ) + metadata = AscendMLAMetadata( + 8, 8, slot_mapping, query_start_loc, seq_lens, seq_lens, block_tables, 4, 4, 0, decode=decode + ) forward_context = MagicMock() forward_context.attn_metadata = {"attn_layer_0": metadata} forward_context.is_draft_model = False @@ -832,20 +841,36 @@ def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end, mock_context): lse = torch.randn(2, 16, 8) self.graph_params.attn_params[4] = [] self.graph_params.attn_params[4].append( - (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout, - None, 0, scale, block_table, 128, None, actual_seq_lengths_kv, - out, lse)) + ( + q_nope, + k_nope, + q_pe, + k_pe, + num_heads, + num_kv_heads, + input_layout, + None, + 0, + scale, + block_table, + 128, + None, + actual_seq_lengths_kv, + out, + lse, + ) + ) with patch("torch_npu._C._npu_setStream", return_value=None): - AscendMlaCPImpl.update_graph_params( - self.update_stream, forward_context, 4 - ) + AscendMlaCPImpl.update_graph_params(self.update_stream, forward_context, 4) _mock_graph_task_end.assert_called_once() - @patch('torch.npu.graph_task_update_end', ) - @patch('torch.npu.graph_task_update_begin', MagicMock()) - @patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock()) + @patch( + "torch.npu.graph_task_update_end", + ) + @patch("torch.npu.graph_task_update_begin", MagicMock()) + @patch("torch_npu.npu_fused_infer_attention_score.out", MagicMock()) def test_update_attn_dcp_pcp_params(self, _mock_graph_task_end): block_table = torch.zeros(2, 5, dtype=torch.long) num_heads = 256 @@ -861,26 +886,40 @@ def test_update_attn_dcp_pcp_params(self, _mock_graph_task_end): out = torch.randn(2, 16, 128) lse = torch.randn(2, 16, 8) - num_computed_tokens_of_pcp_dcp = np.array([[[1, 1], [1, 1]], - [[1, 1], [1, 1]]]) + num_computed_tokens_of_pcp_dcp = np.array([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]) decode = AscendMetadataForDecode(num_computed_tokens_of_pcp_dcp) - metadata = AscendMetadata(num_actual_tokens_pcp_padded=[1, 1], - actual_seq_lengths_q=actual_seq_lengths_q, - num_decode_tokens=1, - decode_meta=decode) + metadata = AscendMetadata( + num_actual_tokens_pcp_padded=[1, 1], + actual_seq_lengths_q=actual_seq_lengths_q, + num_decode_tokens=1, + decode_meta=decode, + ) forward_context = MagicMock() forward_context.attn_metadata = {"attn_layer_0": metadata} 
forward_context.is_draft_model = False self.graph_params.attn_params[4] = [] self.graph_params.attn_params[4].append( - (q_nope, k_nope, k_nope, num_heads, num_kv_heads, scale, - block_table, 128, actual_seq_lengths_kv, actual_seq_lengths_q, - out, lse, 2, 0, 0)) + ( + q_nope, + k_nope, + k_nope, + num_heads, + num_kv_heads, + scale, + block_table, + 128, + actual_seq_lengths_kv, + actual_seq_lengths_q, + out, + lse, + 2, + 0, + 0, + ) + ) with patch("torch_npu._C._npu_setStream", return_value=None): - AscendAttentionCPImpl.update_graph_params( - self.update_stream, forward_context, 4, None - ) + AscendAttentionCPImpl.update_graph_params(self.update_stream, forward_context, 4, None) _mock_graph_task_end.assert_called_once() diff --git a/tests/ut/compilation/test_npugraph_ex_utils_check.py b/tests/ut/compilation/test_npugraph_ex_utils_check.py index 68e076addb2..236ee17b21d 100644 --- a/tests/ut/compilation/test_npugraph_ex_utils_check.py +++ b/tests/ut/compilation/test_npugraph_ex_utils_check.py @@ -13,8 +13,7 @@ # This file is a part of the vllm-ascend project. # -from vllm_ascend.compilation.passes.utils.npugraph_ex_utils_check import \ - extra_stream_scope_check +from vllm_ascend.compilation.passes.utils.npugraph_ex_utils_check import extra_stream_scope_check def test_extra_stream_scope_check_logic(): @@ -24,31 +23,25 @@ def test_extra_stream_scope_check_logic(): """ class MockNode: - def __init__(self, stream_label=None): self.op = "call_function" self.meta = {"stream_label": stream_label} class MockMatch: - def __init__(self, nodes): self.nodes = nodes # Test 1: all default → OK - assert extra_stream_scope_check( - MockMatch([MockNode(None), MockNode(None)])) is True + assert extra_stream_scope_check(MockMatch([MockNode(None), MockNode(None)])) is True # Test 2: same non-default → OK - assert extra_stream_scope_check( - MockMatch([MockNode("s1"), MockNode("s1")])) is True + assert extra_stream_scope_check(MockMatch([MockNode("s1"), MockNode("s1")])) is True # Test 3: mixed non-default → FAIL - assert extra_stream_scope_check( - MockMatch([MockNode("s1"), MockNode("s2")])) is False + assert extra_stream_scope_check(MockMatch([MockNode("s1"), MockNode("s2")])) is False # Test 4: default + non-default → FAIL - assert extra_stream_scope_check( - MockMatch([MockNode(None), MockNode("s1")])) is False + assert extra_stream_scope_check(MockMatch([MockNode(None), MockNode("s1")])) is False # Test 5: empty → OK assert extra_stream_scope_check(MockMatch([])) is True diff --git a/tests/ut/conftest.py b/tests/ut/conftest.py index 77f0ec273d8..494769c9bf7 100644 --- a/tests/ut/conftest.py +++ b/tests/ut/conftest.py @@ -20,17 +20,17 @@ triton_runtime = MagicMock() triton_runtime.driver.active.utils.get_device_properties.return_value = { - 'num_aic': 8, - 'num_vectorcore': 8, + "num_aic": 8, + "num_vectorcore": 8, } -sys.modules['triton.runtime'] = triton_runtime +sys.modules["triton.runtime"] = triton_runtime from vllm_ascend.utils import adapt_patch # noqa E402 from vllm_ascend.utils import register_ascend_customop # noqa E402 # triton and torch_npu are not available in the environment, so we need to mock them -sys.modules['torch_npu'].npu.current_device = MagicMock(return_value=0) -sys.modules['torch_npu._inductor'] = MagicMock() +sys.modules["torch_npu"].npu.current_device = MagicMock(return_value=0) +sys.modules["torch_npu._inductor"] = MagicMock() adapt_patch() adapt_patch(True) diff --git a/tests/ut/core/test_profiling_chunk.py b/tests/ut/core/test_profiling_chunk.py index 
c36a7bc0224..4d97710ec12 100644 --- a/tests/ut/core/test_profiling_chunk.py +++ b/tests/ut/core/test_profiling_chunk.py @@ -14,30 +14,22 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import Any, Dict, Optional from unittest.mock import MagicMock, patch import torch -from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig, VllmConfig) +from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig from vllm.sampling_params import SamplingParams from vllm.utils.hashing import sha256 -from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, - init_none_hash) -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec) +from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request from vllm.v1.structured_output import StructuredOutputManager from tests.ut.base import TestBase -from vllm_ascend.ascend_config import (ProfilingChunkConfig, - clear_ascend_config, init_ascend_config) -from vllm_ascend.core.profiling_chunk_predictor import (ChunkSizePredictor, - ProfilingChunkManager) -from vllm_ascend.core.scheduler_profiling_chunk import \ - ProfilingChunkScheduler - +from vllm_ascend.ascend_config import ProfilingChunkConfig, clear_ascend_config, init_ascend_config +from vllm_ascend.core.profiling_chunk_predictor import ChunkSizePredictor, ProfilingChunkManager +from vllm_ascend.core.scheduler_profiling_chunk import ProfilingChunkScheduler MODEL = "Qwen/Qwen3-0.6B" BLOCK_SIZE = 16 @@ -63,10 +55,7 @@ def create_requests(num_requests, num_tokens=10, max_tokens=16): def make_output(scheduler): req_ids = [req.request_id for req in scheduler.running] - req_id_to_index = { - req.request_id: i - for i, req in enumerate(scheduler.running) - } + req_id_to_index = {req.request_id: i for i, req in enumerate(scheduler.running)} sampled_token_ids = [[1000]] * len(scheduler.running) return ModelRunnerOutput( req_ids=req_ids, @@ -84,7 +73,6 @@ def make_output(scheduler): class TestProfilingChunkConfig(TestBase): - def test_default_values(self): cfg = ProfilingChunkConfig() self.assertFalse(cfg.enabled) @@ -144,10 +132,9 @@ def test_disabled_without_pp_ok(self, _mock): class TestChunkSizePredictor(TestBase): - @staticmethod def _make_data(a, b, c, seq_lens): - return [a * l * l + b * l + c for l in seq_lens] + return [a * seq_len * seq_len + b * seq_len + c for seq_len in seq_lens] def test_fit_and_predict(self): predictor = ChunkSizePredictor() @@ -158,8 +145,7 @@ def test_fit_and_predict(self): predictor.set_target_latency(8192) predictor.is_ready = True - chunk = predictor.predict( - num_computed_tokens=0, base_chunk_size=8192, page_size=128) + chunk = predictor.predict(num_computed_tokens=0, base_chunk_size=8192, page_size=128) self.assertIsNotNone(chunk) self.assertEqual(chunk % 128, 0) @@ -204,7 +190,6 @@ def test_fit_chunk_and_predict_with_history(self): class TestProfilingChunkManager(TestBase): - def test_not_ready_before_profiling(self): mgr = ProfilingChunkManager(base_chunk_size=8192, page_size=128) self.assertFalse(mgr.is_ready) @@ -213,7 +198,7 @@ def test_not_ready_before_profiling(self): def test_run_profiling_success(self): mgr = ProfilingChunkManager(base_chunk_size=8192, page_size=128) seq_lens = list(range(64, 8256, 128)) - latencies = [1e-6 * 
l * l + 0.01 * l + 1.0 for l in seq_lens] + latencies = [1e-6 * seq_len * seq_len + 0.01 * seq_len + 1.0 for seq_len in seq_lens] self.assertTrue(mgr.predictor.fit(seq_lens, latencies)) mgr.predictor.set_target_latency(8192) mgr.predictor.is_ready = True @@ -233,15 +218,14 @@ def test_run_profiling_all_fail(self): def test_record_batch_refines_model(self): mgr = ProfilingChunkManager(base_chunk_size=8192, page_size=128) seq_lens = list(range(64, 8256, 128)) - latencies = [1e-6 * l * l + 0.01 * l + 1.0 for l in seq_lens] + latencies = [1e-6 * seq_len * seq_len + 0.01 * seq_len + 1.0 for seq_len in seq_lens] mgr.predictor.fit(seq_lens, latencies) mgr.predictor.set_target_latency(8192) mgr.predictor.is_ready = True mgr._profiling_done = True for i in range(10): - mgr.record_batch_execution_time( - [(4096 - i * 100, i * 500)], 0.05 + i * 0.01) + mgr.record_batch_execution_time([(4096 - i * 100, i * 500)], 0.05 + i * 0.01) self.assertGreaterEqual(len(mgr.chunked_fit_data), 10) self.assertTrue(mgr.history_ready) @@ -252,7 +236,6 @@ def test_record_batch_refines_model(self): class TestProfilingChunkScheduler(TestBase): - @patch("vllm_ascend.ascend_config.AscendConfig.__init__", MagicMock(return_value=None)) @patch("vllm_ascend.ascend_config.get_ascend_config") @patch("vllm.config.ModelConfig.__post_init__", MagicMock()) @@ -262,8 +245,7 @@ def create_scheduler(self, mock_get_ascend_config): profiling_cfg.enabled = True profiling_cfg.smooth_factor = 0.8 profiling_cfg.min_chunk = 256 - mock_get_ascend_config.return_value = MagicMock( - profiling_chunk_config=profiling_cfg) + mock_get_ascend_config.return_value = MagicMock(profiling_chunk_config=profiling_cfg) mock_hf_config = MagicMock() mock_hf_config.model_type = "qwen3" @@ -295,7 +277,8 @@ def create_scheduler(self, mock_get_ascend_config): scheduler_config.chunked_prefill_enabled = True cache_config = CacheConfig( - block_size=BLOCK_SIZE, gpu_memory_utilization=0.9, + block_size=BLOCK_SIZE, + gpu_memory_utilization=0.9, cache_dtype="auto", ) @@ -306,6 +289,7 @@ def create_scheduler(self, mock_get_ascend_config): ) vllm_config.parallel_config.pipeline_parallel_size = 2 from unittest.mock import PropertyMock + type(model_config).is_encoder_decoder = PropertyMock(return_value=False) vllm_config.model_config.hf_config.is_encoder_decoder = False @@ -314,13 +298,8 @@ def create_scheduler(self, mock_get_ascend_config): kv_cache_tensors=[], kv_cache_groups=[ KVCacheGroupSpec( - ['layer'], - FullAttentionSpec( - block_size=BLOCK_SIZE, - num_kv_heads=1, - head_size=1, - dtype=torch.float32 - ) + ["layer"], + FullAttentionSpec(block_size=BLOCK_SIZE, num_kv_heads=1, head_size=1, dtype=torch.float32), ) ], ) @@ -408,8 +387,7 @@ def test_schedule_chunked_prefill_running(self): mock_executor.collective_rpc.return_value = [10.0] scheduler.run_profiling_chunk_init(mock_executor) - requests = create_requests(num_requests=1, num_tokens=2000, - max_tokens=16) + requests = create_requests(num_requests=1, num_tokens=2000, max_tokens=16) for req in requests: scheduler.add_request(req) diff --git a/tests/ut/core/test_scheduler_dynamic_batch.py b/tests/ut/core/test_scheduler_dynamic_batch.py index 8d52e35bdcb..f97611fdc6f 100644 --- a/tests/ut/core/test_scheduler_dynamic_batch.py +++ b/tests/ut/core/test_scheduler_dynamic_batch.py @@ -1,20 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Dict, List, Optional, Tuple +from typing import Any from unittest.mock import MagicMock, 
patch import torch -from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, - SchedulerConfig, SpeculativeConfig, VllmConfig) -from vllm.multimodal.inputs import (MultiModalFeatureSpec, - MultiModalKwargsItem, PlaceholderRange) +from vllm.config import CacheConfig, KVTransferConfig, ModelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig +from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItem, PlaceholderRange from vllm.sampling_params import SamplingParams from vllm.utils.hashing import sha256 -from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, - init_none_hash) +from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec) +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.structured_output import StructuredOutputManager @@ -36,18 +32,17 @@ def create_requests( num_requests: int, num_tokens: int = 10, - mm_positions: Optional[list[PlaceholderRange]] = None, + mm_positions: list[PlaceholderRange] | None = None, max_tokens: int = 16, - stop_token_ids: Optional[list[int]] = None, + stop_token_ids: list[int] | None = None, block_size: int = 3, hash_fn=sha256, ): init_none_hash(hash_fn) prompt_logprobs = PROMPT_LOGPROBS - sampling_params = SamplingParams(ignore_eos=False, - max_tokens=max_tokens, - stop_token_ids=stop_token_ids, - prompt_logprobs=prompt_logprobs) + sampling_params = SamplingParams( + ignore_eos=False, max_tokens=max_tokens, stop_token_ids=stop_token_ids, prompt_logprobs=prompt_logprobs + ) requests = [] for i in range(num_requests): mm_features = [] @@ -59,26 +54,25 @@ def create_requests( data=MultiModalKwargsItem.dummy("dummy_m"), mm_position=position, identifier=identifier, - modality="image") + modality="image", + ) mm_features.append(mm_feature) - request = Request(request_id=f"{i}", - prompt_token_ids=[i] * num_tokens, - sampling_params=sampling_params, - eos_token_id=EOS_TOKEN_ID, - pooling_params=None, - mm_features=mm_features if mm_features else None, - block_hasher=get_request_block_hasher( - block_size, hash_fn)) + request = Request( + request_id=f"{i}", + prompt_token_ids=[i] * num_tokens, + sampling_params=sampling_params, + eos_token_id=EOS_TOKEN_ID, + pooling_params=None, + mm_features=mm_features if mm_features else None, + block_hasher=get_request_block_hasher(block_size, hash_fn), + ) requests.append(request) return requests def make_output(scheduler): req_ids = [req.request_id for req in scheduler.running] - req_id_to_index = { - req.request_id: i - for i, req in enumerate(scheduler.running) - } + req_id_to_index = {req.request_id: i for i, req in enumerate(scheduler.running)} sampled_token_ids = [[1000]] * len(scheduler.running) logprobs = None @@ -95,10 +89,9 @@ def make_output(scheduler): class TestSchedulerDynamicBatch(TestBase): - @patch("vllm.config.ModelConfig.__post_init__", MagicMock()) @patch("vllm.config.VllmConfig.__post_init__", MagicMock()) - @patch('vllm.v1.core.sched.scheduler.compute_encoder_budget') + @patch("vllm.v1.core.sched.scheduler.compute_encoder_budget") def create_scheduler(self, mock_compute_encoder_budget): mock_compute_encoder_budget.return_value = [100, 100] use_kv_connector = False @@ -133,11 +126,9 @@ def create_scheduler(self, 
mock_compute_encoder_budget): model_config.hf_text_config = MagicMock() model_config.hf_text_config.is_encoder_decoder = False # Cache config, optionally force APC - kwargs_cache: Dict[str, - Any] = ({} if ENABLE_PREFIX_CACHING is None else { - 'enable_prefix_caching': - ENABLE_PREFIX_CACHING - }) + kwargs_cache: dict[str, Any] = ( + {} if ENABLE_PREFIX_CACHING is None else {"enable_prefix_caching": ENABLE_PREFIX_CACHING} + ) cache_config = CacheConfig( block_size=block_size, gpu_memory_utilization=0.9, @@ -146,16 +137,19 @@ def create_scheduler(self, mock_compute_encoder_budget): **kwargs_cache, ) - kv_transfer_config = KVTransferConfig( - kv_connector="SharedStorageConnector", - kv_role="kv_both", - kv_connector_extra_config={"shared_storage_path": "local_storage"}, - ) if use_kv_connector else None + kv_transfer_config = ( + KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={"shared_storage_path": "local_storage"}, + ) + if use_kv_connector + else None + ) - speculative_config: Optional[SpeculativeConfig] = None + speculative_config: SpeculativeConfig | None = None if NUM_SPECULATIVE_TOKENS is not None: - speculative_config = SpeculativeConfig( - model="ngram", num_speculative_tokens=NUM_SPECULATIVE_TOKENS) + speculative_config = SpeculativeConfig(model="ngram", num_speculative_tokens=NUM_SPECULATIVE_TOKENS) vllm_config = VllmConfig( scheduler_config=scheduler_config, @@ -168,11 +162,7 @@ def create_scheduler(self, mock_compute_encoder_budget): kv_cache_config = KVCacheConfig( num_blocks=10000, # A large number of blocks to hold all requests kv_cache_tensors=[], - kv_cache_groups=[ - KVCacheGroupSpec(['layer'], - FullAttentionSpec(block_size, 1, 1, - torch.float32, False)) - ], + kv_cache_groups=[KVCacheGroupSpec(["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False))], ) kv_cache_config.hash_block_size = block_size cache_config.num_gpu_blocks = 10000 @@ -207,8 +197,7 @@ def test_finish_request(self): scheduler.add_request(request) for i, request in enumerate(requests): - scheduler.finish_requests(request.request_id, - RequestStatus.FINISHED_ABORTED) + scheduler.finish_requests(request.request_id, RequestStatus.FINISHED_ABORTED) self.assertNotIn(request.request_id, scheduler.requests) self.assertEqual(len(scheduler.waiting), 9 - i) @@ -219,15 +208,13 @@ def test_get_num_unfinished_requests(self): scheduler.add_request(request) for i, request in enumerate(requests): - scheduler.finish_requests(request.request_id, - RequestStatus.FINISHED_STOPPED) - self.assertEqual(scheduler.get_num_unfinished_requests(), - len(requests) - i - 1) + scheduler.finish_requests(request.request_id, RequestStatus.FINISHED_STOPPED) + self.assertEqual(scheduler.get_num_unfinished_requests(), len(requests) - i - 1) def test_schedule(self): - '''Test scheduling. + """Test scheduling. Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs - ''' + """ scheduler = self.create_scheduler() scheduler.scheduler_config.chunked_prefill_enabled = True requests = create_requests(num_requests=10) @@ -241,8 +228,7 @@ def test_schedule(self): self.assertEqual(len(output.finished_req_ids), 0) # Verify all requests are scheduled. 
for req_id, num_tokens in output.num_scheduled_tokens.items(): - self.assertEqual(num_tokens, - len(requests[int(req_id)].prompt_token_ids)) + self.assertEqual(num_tokens, len(requests[int(req_id)].prompt_token_ids)) # Verify requests moved from waiting to running self.assertEqual(len(scheduler.waiting), 0) @@ -253,8 +239,7 @@ def test_schedule(self): def test_schedule_multimodal_requests(self): scheduler = self.create_scheduler() scheduler.scheduler_config.chunked_prefill_enabled = True - mm_positions = [[PlaceholderRange(offset=i, length=10)] - for i in range(10)] + mm_positions = [[PlaceholderRange(offset=i, length=10)] for i in range(10)] requests = create_requests( num_requests=10, mm_positions=mm_positions, @@ -271,8 +256,7 @@ def test_schedule_multimodal_requests(self): # Verify all requests are scheduled. for req_id, num_tokens in output.num_scheduled_tokens.items(): - self.assertEqual(num_tokens, - len(requests[int(req_id)].prompt_token_ids)) + self.assertEqual(num_tokens, len(requests[int(req_id)].prompt_token_ids)) self.assertEqual(len(output.scheduled_encoder_inputs), len(requests)) for req_id, encoder_input in output.scheduled_encoder_inputs.items(): assert len(encoder_input) == 1 @@ -284,9 +268,9 @@ def test_schedule_multimodal_requests(self): self.assertEqual(scheduler.running[i], request) def test_schedule_enable_prefix_caching(self): - '''Test scheduling. + """Test scheduling. Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs - ''' + """ global ENABLE_PREFIX_CACHING ENABLE_PREFIX_CACHING = True global PROMPT_LOGPROBS @@ -304,8 +288,7 @@ def test_schedule_enable_prefix_caching(self): self.assertEqual(len(output.finished_req_ids), 0) # Verify all requests are scheduled. for req_id, num_tokens in output.num_scheduled_tokens.items(): - self.assertEqual(num_tokens, - len(requests[int(req_id)].prompt_token_ids)) + self.assertEqual(num_tokens, len(requests[int(req_id)].prompt_token_ids)) # Verify requests moved from waiting to running self.assertEqual(len(scheduler.waiting), 0) @@ -327,39 +310,31 @@ def test_stop_via_update_from_output(self): scheduler.running.append(req) req.status = RequestStatus.RUNNING - scheduler_output = SchedulerOutput(scheduled_new_reqs=[], - scheduled_cached_reqs=[], - num_scheduled_tokens={ - requests[0].request_id: 1, - requests[1].request_id: 2 - }, - total_num_scheduled_tokens=3, - scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={ - requests[0].request_id: [], - requests[1].request_id: [10] - }, - num_common_prefix_blocks=0, - finished_req_ids=set(), - free_encoder_mm_hashes=[]) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={requests[0].request_id: 1, requests[1].request_id: 2}, + total_num_scheduled_tokens=3, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={requests[0].request_id: [], requests[1].request_id: [10]}, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_mm_hashes=[], + ) model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], - req_id_to_index={ - req.request_id: i - for i, req in enumerate(requests) - }, - sampled_token_ids=[[EOS_TOKEN_ID], [10, 11] - ], # First request hits EOS, second continues + req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, + sampled_token_ids=[[EOS_TOKEN_ID], [10, 11]], # First request hits EOS, second continues logprobs=None, prompt_logprobs_dict={}, - pooler_output=[]) + pooler_output=[], + ) 
scheduler.update_from_output(scheduler_output, model_output) # Verify first request stopped, second continues self.assertEqual(len(scheduler.running), 1) - self.assertEqual(scheduler.running[0].request_id, - requests[1].request_id) + self.assertEqual(scheduler.running[0].request_id, requests[1].request_id) self.assertEqual(requests[0].status, RequestStatus.FINISHED_STOPPED) self.assertIn(requests[0].request_id, scheduler.finished_req_ids) self.assertEqual(list(requests[0].output_token_ids), [EOS_TOKEN_ID]) @@ -368,49 +343,38 @@ def test_stop_via_update_from_output(self): # Test case 2: Stop on custom stop token NUM_SPECULATIVE_TOKENS = 2 scheduler = self.create_scheduler() - requests = create_requests(num_requests=2, - max_tokens=10, - stop_token_ids=[42, 43]) + requests = create_requests(num_requests=2, max_tokens=10, stop_token_ids=[42, 43]) for req in requests: req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) req.status = RequestStatus.RUNNING - scheduler_output = SchedulerOutput(scheduled_new_reqs=[], - scheduled_cached_reqs=[], - num_scheduled_tokens={ - requests[0].request_id: 3, - requests[1].request_id: 2 - }, - total_num_scheduled_tokens=5, - scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={ - requests[0].request_id: - [10, 42], - requests[1].request_id: [13] - }, - num_common_prefix_blocks=0, - finished_req_ids=set(), - free_encoder_mm_hashes=[]) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={requests[0].request_id: 3, requests[1].request_id: 2}, + total_num_scheduled_tokens=5, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={requests[0].request_id: [10, 42], requests[1].request_id: [13]}, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_mm_hashes=[], + ) model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], - req_id_to_index={ - req.request_id: i - for i, req in enumerate(requests) - }, - sampled_token_ids=[[10, 42, 12], - [13, 14]], # First request hits stop token + req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, + sampled_token_ids=[[10, 42, 12], [13, 14]], # First request hits stop token logprobs=None, prompt_logprobs_dict={}, - pooler_output=[]) + pooler_output=[], + ) scheduler.update_from_output(scheduler_output, model_output) # Verify first request stopped on custom token self.assertEqual(len(scheduler.running), 1) - self.assertEqual(scheduler.running[0].request_id, - requests[1].request_id) + self.assertEqual(scheduler.running[0].request_id, requests[1].request_id) self.assertEqual(requests[0].status, RequestStatus.FINISHED_STOPPED) self.assertEqual(requests[0].stop_reason, 42) self.assertIn(requests[0].request_id, scheduler.finished_req_ids) @@ -427,41 +391,31 @@ def test_stop_via_update_from_output(self): scheduler.running.append(req) req.status = RequestStatus.RUNNING - scheduler_output = SchedulerOutput(scheduled_new_reqs=[], - scheduled_cached_reqs=[], - num_scheduled_tokens={ - requests[0].request_id: 3, - requests[1].request_id: 1 - }, - total_num_scheduled_tokens=4, - scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={ - requests[0].request_id: - [10, 11], - requests[1].request_id: [] - }, - num_common_prefix_blocks=0, - finished_req_ids=set(), - free_encoder_mm_hashes=[]) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={requests[0].request_id: 3, 
requests[1].request_id: 1}, + total_num_scheduled_tokens=4, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={requests[0].request_id: [10, 11], requests[1].request_id: []}, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_mm_hashes=[], + ) model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], - req_id_to_index={ - req.request_id: i - for i, req in enumerate(requests) - }, - sampled_token_ids=[[10, 11, 12], - [13]], # First request exceeds max_tokens + req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, + sampled_token_ids=[[10, 11, 12], [13]], # First request exceeds max_tokens logprobs=None, prompt_logprobs_dict={}, - pooler_output=[]) + pooler_output=[], + ) scheduler.update_from_output(scheduler_output, model_output) # Verify first request stopped due to length self.assertEqual(len(scheduler.running), 1) - self.assertEqual(scheduler.running[0].request_id, - requests[1].request_id) - self.assertEqual(requests[0].status, - RequestStatus.FINISHED_LENGTH_CAPPED) + self.assertEqual(scheduler.running[0].request_id, requests[1].request_id) + self.assertEqual(requests[0].status, RequestStatus.FINISHED_LENGTH_CAPPED) self.assertIn(requests[0].request_id, scheduler.finished_req_ids) self.assertEqual(list(requests[0].output_token_ids), [10, 11]) self.assertEqual(list(requests[1].output_token_ids), [13]) @@ -480,27 +434,26 @@ def test_stop_via_update_from_output(self): num_scheduled_tokens={requests[0].request_id: 3}, total_num_scheduled_tokens=3, scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={ - requests[0].request_id: [EOS_TOKEN_ID, 10] - }, + scheduled_spec_decode_tokens={requests[0].request_id: [EOS_TOKEN_ID, 10]}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_mm_hashes=[]) + free_encoder_mm_hashes=[], + ) model_output = ModelRunnerOutput( req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], logprobs=None, prompt_logprobs_dict={}, - pooler_output=[]) + pooler_output=[], + ) scheduler.update_from_output(scheduler_output, model_output) # Verify request continues past EOS self.assertEqual(len(scheduler.running), 1) self.assertFalse(requests[0].is_finished()) - self.assertEqual(list(requests[0].output_token_ids), - [EOS_TOKEN_ID, 10, 11]) + self.assertEqual(list(requests[0].output_token_ids), [EOS_TOKEN_ID, 10, 11]) def test_schedule_concurrent_batches(self): global MAX_NUM_BATCHED_TOKENS @@ -530,17 +483,13 @@ def test_schedule_concurrent_batches(self): scheduler.add_request(requests[0]) scheduler_output0 = scheduler.schedule() self.assertEqual(len(scheduler_output0.scheduled_new_reqs), 1) - self.assertEqual( - scheduler_output0.num_scheduled_tokens[requests[0].request_id], - 512) + self.assertEqual(scheduler_output0.num_scheduled_tokens[requests[0].request_id], 512) # The first request is still running, so only schedule the second request. scheduler.add_request(requests[1]) scheduler_output1 = scheduler.schedule() self.assertEqual(len(scheduler_output1.scheduled_new_reqs), 1) - self.assertEqual( - scheduler_output1.num_scheduled_tokens[requests[1].request_id], - 512) + self.assertEqual(scheduler_output1.num_scheduled_tokens[requests[1].request_id], 512) # Model output of the first request. 
model_runner_output = ModelRunnerOutput( @@ -549,10 +498,10 @@ def test_schedule_concurrent_batches(self): sampled_token_ids=[[0]], logprobs=None, prompt_logprobs_dict={}, - pooler_output=[]) + pooler_output=[], + ) - scheduler.update_from_output(scheduler_output0, - model_runner_output) + scheduler.update_from_output(scheduler_output0, model_runner_output) # Schedule the next step. # The first request can be scheduled again while the second @@ -565,10 +514,10 @@ def test_schedule_concurrent_batches(self): sampled_token_ids=[[0]], logprobs=None, prompt_logprobs_dict={}, - pooler_output=[]) + pooler_output=[], + ) - scheduler.update_from_output(scheduler_output1, - model_runner_output) + scheduler.update_from_output(scheduler_output1, model_runner_output) def test_schedule_spec_decoding_stats(self): """Test scheduling behavior with speculative decoding. @@ -577,20 +526,30 @@ def test_schedule_spec_decoding_stats(self): 1. Speculated tokens get scheduled correctly 2. Spec decoding stats properly count number of draft and accepted tokens """ - spec_tokens_list: List[List[List[int]]] = [[[1, 2, 3]], [[1, 2, 3]], - [[1, 2], [3]], [[1]], [[]], - [[1, 2, 3], [4, 5, 6]]] - output_tokens_list: List[List[List[int]]] = [[[1, 2, 3, 4]], [[1, 5]], - [[1, 2, 5], [3, 4]], - [[1, 2]], [[5]], - [[1, 2, 7], [4, 8]]] - expected_list: List[Tuple[int, int, - int, List[int]]] = [(1, 3, 3, [1, 1, 1]), - (1, 3, 1, [1, 0, 0]), - (2, 3, 3, [2, 1]), - (1, 1, 1, [1]), - (0, 0, 0, [0]), - (2, 6, 3, [2, 1, 0])] + spec_tokens_list: list[list[list[int]]] = [ + [[1, 2, 3]], + [[1, 2, 3]], + [[1, 2], [3]], + [[1]], + [[]], + [[1, 2, 3], [4, 5, 6]], + ] + output_tokens_list: list[list[list[int]]] = [ + [[1, 2, 3, 4]], + [[1, 5]], + [[1, 2, 5], [3, 4]], + [[1, 2]], + [[5]], + [[1, 2, 7], [4, 8]], + ] + expected_list: list[tuple[int, int, int, list[int]]] = [ + (1, 3, 3, [1, 1, 1]), + (1, 3, 1, [1, 0, 0]), + (2, 3, 3, [2, 1]), + (1, 1, 1, [1]), + (0, 0, 0, [0]), + (2, 6, 3, [2, 1, 0]), + ] global NUM_SPECULATIVE_TOKENS for idx in range(len(spec_tokens_list)): @@ -600,8 +559,7 @@ def test_schedule_spec_decoding_stats(self): num_spec_tokens = max(1, max(len(t) for t in spec_tokens)) NUM_SPECULATIVE_TOKENS = num_spec_tokens scheduler = self.create_scheduler() - requests = create_requests(num_requests=len(spec_tokens), - num_tokens=1) + requests = create_requests(num_requests=len(spec_tokens), num_tokens=1) req_ids = [] req_to_index = {} for i, request in enumerate(requests): @@ -624,11 +582,11 @@ def test_schedule_spec_decoding_stats(self): sampled_token_ids=[[0] for _ in range(len(requests))], logprobs=None, prompt_logprobs_dict={}, - pooler_output=[]) + pooler_output=[], + ) draft_token_ids = DraftTokenIds(req_ids, spec_tokens) - engine_core_outputs = scheduler.update_from_output( - output, model_runner_output) + engine_core_outputs = scheduler.update_from_output(output, model_runner_output) scheduler.update_draft_token_ids(draft_token_ids) for i in range(len(requests)): @@ -638,33 +596,25 @@ def test_schedule_spec_decoding_stats(self): # The prompt token and the sampled token self.assertEqual(running_req.num_tokens, 2) # The prompt token, the sampled token, and the speculated tokens - self.assertEqual(running_req.num_tokens_with_spec, - 2 + len(spec_tokens[i])) + self.assertEqual(running_req.num_tokens_with_spec, 2 + len(spec_tokens[i])) # No draft or accepted tokens counted yet self.assertTrue( - not engine_core_outputs - or (engine_core_outputs[0].scheduler_stats.spec_decoding_stats - is None)) + not engine_core_outputs or 
(engine_core_outputs[0].scheduler_stats.spec_decoding_stats is None) + ) # Schedule the speculated tokens for validation output = scheduler.schedule() self.assertEqual(len(output.scheduled_new_reqs), 0) # The sampled token and speculated tokens - self.assertEqual( - output.total_num_scheduled_tokens, - len(requests) + sum(len(ids) for ids in spec_tokens)) + self.assertEqual(output.total_num_scheduled_tokens, len(requests) + sum(len(ids) for ids in spec_tokens)) for i in range(len(requests)): req_id = requests[i].request_id - self.assertEqual(output.num_scheduled_tokens[req_id], - 1 + len(spec_tokens[i])) + self.assertEqual(output.num_scheduled_tokens[req_id], 1 + len(spec_tokens[i])) if spec_tokens[i]: - self.assertEqual( - len(output.scheduled_spec_decode_tokens[req_id]), - len(spec_tokens[i])) + self.assertEqual(len(output.scheduled_spec_decode_tokens[req_id]), len(spec_tokens[i])) else: - self.assertNotIn(req_id, - output.scheduled_spec_decode_tokens) + self.assertNotIn(req_id, output.scheduled_spec_decode_tokens) model_runner_output = ModelRunnerOutput( req_ids=req_ids, @@ -672,13 +622,12 @@ def test_schedule_spec_decoding_stats(self): sampled_token_ids=output_tokens, logprobs=None, prompt_logprobs_dict={}, - pooler_output=[]) + pooler_output=[], + ) - engine_core_outputs = scheduler.update_from_output( - output, model_runner_output) + engine_core_outputs = scheduler.update_from_output(output, model_runner_output) - scheduler_stats = engine_core_outputs[0].scheduler_stats \ - if engine_core_outputs else None + scheduler_stats = engine_core_outputs[0].scheduler_stats if engine_core_outputs else None if expected[0] == 0: self.assertIsNone(scheduler_stats.spec_decoding_stats) else: @@ -687,8 +636,7 @@ def test_schedule_spec_decoding_stats(self): self.assertEqual(stats.num_drafts, expected[0]) self.assertEqual(stats.num_draft_tokens, expected[1]) self.assertEqual(stats.num_accepted_tokens, expected[2]) - self.assertEqual(stats.num_accepted_tokens_per_pos, - expected[3]) + self.assertEqual(stats.num_accepted_tokens_per_pos, expected[3]) def assert_scheduler_empty(self, scheduler): """Confirm the scheduler is "empty" - i.e. no leaks.""" @@ -704,17 +652,10 @@ def assert_scheduler_empty(self, scheduler): self.assertEqual(len(scheduler.encoder_cache_manager.cached), 0) # KVCache Manager. - self.assertEqual( - len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. - req_to_blocks), 0) - self.assertEqual( - len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. - num_cached_block), 0) - num_free_blocks = (scheduler.kv_cache_manager.block_pool. - free_block_queue.num_free_blocks) - self.assertEqual( - num_free_blocks, - scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) + self.assertEqual(len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks), 0) + self.assertEqual(len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].num_cached_block), 0) + num_free_blocks = scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks + self.assertEqual(num_free_blocks, scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) # NOTE(rob): just the ref count on blocks will be 0. The hash # value, etc will remain since we lazily evict for prefix cache. 
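The single-line free-block assertion above makes the arithmetic easier to see: the block pool keeps one reserved null block, so a fully drained scheduler should have exactly num_gpu_blocks - 1 blocks back on the free queue. A minimal sketch of that invariant, assuming only the block_pool attribute names already used in this hunk:

    def assert_no_block_leak(block_pool) -> None:
        # One block stays reserved as the null block; every other block
        # should be back on the free queue once no requests remain.
        free_blocks = block_pool.free_block_queue.num_free_blocks
        assert free_blocks == block_pool.num_gpu_blocks - 1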
@@ -727,9 +668,7 @@ def test_memory_leak(self):
         NUM_REQUESTS = 5
         NUM_TOKENS = 10
         MAX_TOKENS = 10
-        requests = create_requests(num_requests=NUM_REQUESTS,
-                                   num_tokens=NUM_TOKENS,
-                                   max_tokens=MAX_TOKENS)
+        requests = create_requests(num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, max_tokens=MAX_TOKENS)
 
         # Add each request.
         for request in requests:
diff --git a/tests/ut/device_allocator/test_camem.py b/tests/ut/device_allocator/test_camem.py
index 41953934aef..a4bfacc48ea 100644
--- a/tests/ut/device_allocator/test_camem.py
+++ b/tests/ut/device_allocator/test_camem.py
@@ -19,11 +19,14 @@
 import torch
 
 from tests.ut.base import PytestBase
-from vllm_ascend.device_allocator.camem import (AllocationData, CaMemAllocator,
-                                                create_and_map,
-                                                find_loaded_library,
-                                                get_pluggable_allocator,
-                                                unmap_and_release)
+from vllm_ascend.device_allocator.camem import (
+    AllocationData,
+    CaMemAllocator,
+    create_and_map,
+    find_loaded_library,
+    get_pluggable_allocator,
+    unmap_and_release,
+)
 
 
 def dummy_malloc(args):
@@ -35,7 +38,6 @@ def dummy_free(ptr):
 
 
 class TestCaMem(PytestBase):
-
     def test_find_loaded_library_success_and_not_found(self):
         path = find_loaded_library("libc")
         assert path is not None, "Expected to find libc library"
@@ -45,34 +47,34 @@ def test_find_loaded_library_success_and_not_found(self):
         path = find_loaded_library("non_existent_library")
         assert path is None, "Expected to not find non-existent library"
 
-    @pytest.mark.parametrize("handle", [
-        (1, 2, 3),
-        ("device", 99),
-        (None, ),
-    ])
+    @pytest.mark.parametrize(
+        "handle",
+        [
+            (1, 2, 3),
+            ("device", 99),
+            (None,),
+        ],
+    )
     def test_create_and_map_calls_python_create_and_map(self, handle):
-        with patch("vllm_ascend.device_allocator.camem.python_create_and_map"
-                   ) as mock_create:
+        with patch("vllm_ascend.device_allocator.camem.python_create_and_map") as mock_create:
             create_and_map(handle)
             mock_create.assert_called_once_with(*handle)
 
-    @pytest.mark.parametrize("handle", [
-        (42, "bar"),
-        ("foo", ),
-    ])
+    @pytest.mark.parametrize(
+        "handle",
+        [
+            (42, "bar"),
+            ("foo",),
+        ],
+    )
     def test_unmap_and_release_calls_python_unmap_and_release(self, handle):
-        with patch(
-            "vllm_ascend.device_allocator.camem.python_unmap_and_release"
-        ) as mock_release:
+        with patch("vllm_ascend.device_allocator.camem.python_unmap_and_release") as mock_release:
             unmap_and_release(handle)
             mock_release.assert_called_once_with(*handle)
 
     @patch("vllm_ascend.device_allocator.camem.init_module")
-    @patch(
-        "vllm_ascend.device_allocator.camem.torch.npu.memory.NPUPluggableAllocator"
-    )
-    def test_get_pluggable_allocator(self, mock_allocator_class,
-                                     mock_init_module):
+    @patch("vllm_ascend.device_allocator.camem.torch.npu.memory.NPUPluggableAllocator")
+    def test_get_pluggable_allocator(self, mock_allocator_class, mock_init_module):
         mock_allocator_instance = MagicMock()
         mock_allocator_class.return_value = mock_allocator_instance
 
@@ -133,12 +135,11 @@ def test_sleep_offload_and_discard(self, mock_memcpy, mock_unmap):
 
         def mock_torch_empty(*args, **kwargs):
             # If pin_memory was explicitly set to True, change it to False
-            if 'pin_memory' in kwargs and kwargs['pin_memory'] is True:
-                kwargs['pin_memory'] = False
+            if "pin_memory" in kwargs and kwargs["pin_memory"] is True:
+                kwargs["pin_memory"] = False
             return original_torch_empty(*args, **kwargs)
 
-        with patch("vllm_ascend.device_allocator.camem.torch.empty",
-                   side_effect=mock_torch_empty):
+        with patch("vllm_ascend.device_allocator.camem.torch.empty", side_effect=mock_torch_empty):
allocator.sleep(offload_tags="tag1") # only offload tag1, other tag2 call unmap_and_release @@ -151,8 +152,7 @@ def mock_torch_empty(*args, **kwargs): @patch("vllm_ascend.device_allocator.camem.create_and_map") @patch("vllm_ascend.device_allocator.camem.memcpy") - def test_wake_up_loads_and_clears_cpu_backup(self, mock_memcpy, - mock_create_and_map): + def test_wake_up_loads_and_clears_cpu_backup(self, mock_memcpy, mock_create_and_map): allocator = CaMemAllocator.get_instance() handle = (1, 10, 1000, 0) @@ -175,9 +175,7 @@ def test_use_memory_pool_context_manager(self): mock_ctx.__enter__.return_value = "data" mock_ctx.__exit__.return_value = None - with patch( - "vllm_ascend.device_allocator.camem.use_memory_pool_with_allocator", - return_value=mock_ctx): + with patch("vllm_ascend.device_allocator.camem.use_memory_pool_with_allocator", return_value=mock_ctx): with allocator.use_memory_pool(tag="my_tag"): assert allocator.current_tag == "my_tag" # restore old tag after context manager exits diff --git a/tests/ut/device_allocator/test_cpu_binding.py b/tests/ut/device_allocator/test_cpu_binding.py index 9c4ba1b4551..e0c9e833fe9 100644 --- a/tests/ut/device_allocator/test_cpu_binding.py +++ b/tests/ut/device_allocator/test_cpu_binding.py @@ -43,21 +43,20 @@ def make_cpu_alloc(rank_id=0): class TestDeviceInfo(unittest.TestCase): - - @patch('vllm_ascend.cpu_binding.subprocess.Popen') + @patch("vllm_ascend.cpu_binding.subprocess.Popen") def test_execute_command(self, mock_popen): process = MagicMock() - process.communicate.return_value = (b'command-output', b'') + process.communicate.return_value = (b"command-output", b"") process.returncode = 7 mock_popen.return_value.__enter__.return_value = process - output, return_code = cpu_binding_module.execute_command(['dummy', 'cmd']) + output, return_code = cpu_binding_module.execute_command(["dummy", "cmd"]) - self.assertEqual(output, 'command-output') + self.assertEqual(output, "command-output") self.assertEqual(return_code, 7) mock_popen.assert_called_once() args, kwargs = mock_popen.call_args - self.assertEqual(args[0], ['dummy', 'cmd']) + self.assertEqual(args[0], ["dummy", "cmd"]) self.assertEqual(kwargs["shell"], False) self.assertEqual(kwargs["stdout"], subprocess.PIPE) self.assertEqual(kwargs["stderr"], subprocess.PIPE) @@ -65,24 +64,24 @@ def test_execute_command(self, mock_popen): self.assertEqual(kwargs["env"]["LANG"], "C") self.assertEqual(kwargs["env"]["LC_MESSAGES"], "C") - @patch('vllm_ascend.cpu_binding.subprocess.Popen') + @patch("vllm_ascend.cpu_binding.subprocess.Popen") def test_execute_command_kills_timed_out_process(self, mock_popen): process = MagicMock() process.communicate.side_effect = [ - subprocess.TimeoutExpired(cmd=['dummy', 'cmd'], timeout=1000), - (b'command-output', b''), + subprocess.TimeoutExpired(cmd=["dummy", "cmd"], timeout=1000), + (b"command-output", b""), ] process.returncode = -9 mock_popen.return_value.__enter__.return_value = process - output, return_code = cpu_binding_module.execute_command(['dummy', 'cmd']) + output, return_code = cpu_binding_module.execute_command(["dummy", "cmd"]) - self.assertEqual(output, 'command-output') + self.assertEqual(output, "command-output") self.assertEqual(return_code, -9) process.kill.assert_called_once_with() self.assertEqual(process.communicate.call_count, 2) - @patch('vllm_ascend.cpu_binding.execute_command') + @patch("vllm_ascend.cpu_binding.execute_command") def setUp(self, mock_execute_command): mock_execute_command.side_effect = [ ("NPU ID Chip ID Chip Logic ID Chip 
Name\n0 0 0 Ascend\n0 1 - Mcu\n1 0 1 Ascend", 0), @@ -91,15 +90,15 @@ def setUp(self, mock_execute_command): ] self.device_info = DeviceInfo() - @patch('vllm_ascend.cpu_binding.execute_command') + @patch("vllm_ascend.cpu_binding.execute_command") def test_get_npu_map_info(self, mock_execute_command): execute_result_list = [ ("NPU ID Chip ID Chip Logic ID Chip Phy-ID Chip Name\n0 0 0 0 Ascend\n0 1 1 1 Ascend\n0 2 - - Mcu", 0), ("NPU ID Chip ID Chip Logic ID Chip Name\n8 0 0 Ascend\n8 1 - Mcu\n9 0 1 Ascend", 0), ] result_list = [ - {'0': {'0': '0', '1': '1'}}, - {'8': {'0': '0'}, '9': {'0': '1'}}, + {"0": {"0": "0", "1": "1"}}, + {"8": {"0": "0"}, "9": {"0": "1"}}, ] for result in execute_result_list: mock_execute_command.return_value = result @@ -107,7 +106,7 @@ def test_get_npu_map_info(self, mock_execute_command): expected = result_list.pop(0) self.assertEqual(npu_map_info, expected) - @patch('vllm_ascend.cpu_binding.execute_command') + @patch("vllm_ascend.cpu_binding.execute_command") def test_get_running_npus(self, mock_execute_command): mock_execute_command.side_effect = [ ("| NPU Chip | Process id |\n| 0 1 | 1236 | vllm | 56000 |", 0), @@ -121,11 +120,11 @@ def test_get_running_npus(self, mock_execute_command): running_npus = self.device_info.get_running_npus() self.assertEqual(len(running_npus), 1) - @patch('vllm_ascend.cpu_binding.ASCEND_RT_VISIBLE_DEVICES', '1,5') - @patch('vllm_ascend.cpu_binding.execute_command') + @patch("vllm_ascend.cpu_binding.ASCEND_RT_VISIBLE_DEVICES", "1,5") + @patch("vllm_ascend.cpu_binding.execute_command") def test_get_running_npus_filters_invalid_rows_and_visible_devices(self, mock_execute_command): device_info = object.__new__(DeviceInfo) - device_info.npu_map_info = {'0': {'0': '0', '1': '1'}} + device_info.npu_map_info = {"0": {"0": "0", "1": "1"}} mock_execute_command.return_value = ( "ignored before header\n" "| NPU Chip | Process id |\n" @@ -138,34 +137,29 @@ def test_get_running_npus_filters_invalid_rows_and_visible_devices(self, mock_ex self.assertEqual(device_info.get_running_npus(), [1]) - @patch('vllm_ascend.cpu_binding.execute_command') + @patch("vllm_ascend.cpu_binding.execute_command") def test_get_running_npus_skips_non_pipe_rows_inside_process_section(self, mock_execute_command): device_info = object.__new__(DeviceInfo) - device_info.npu_map_info = {'0': {'0': '0'}} + device_info.npu_map_info = {"0": {"0": "0"}} mock_execute_command.return_value = ( - "| NPU Chip | Process id |\n" - "separator row\n" - "| 0 0 | 1234 | vllm |", + "| NPU Chip | Process id |\nseparator row\n| 0 0 | 1234 | vllm |", 0, ) self.assertEqual(device_info.get_running_npus(), [0]) - @patch('vllm_ascend.cpu_binding.execute_command') + @patch("vllm_ascend.cpu_binding.execute_command") def test_parse_topo_affinity(self, mock_execute_command): mock_execute_command.return_value = ("NPU0 X HCCS HCCS HCCS HCCS HCCS HCCS HCCS 0-3", 0) affinity = self.device_info.parse_topo_affinity() expected = {0: [0, 1, 2, 3]} self.assertEqual(affinity, expected) - @patch('vllm_ascend.cpu_binding.execute_command') + @patch("vllm_ascend.cpu_binding.execute_command") def test_parse_topo_affinity_skips_affinity_header_and_non_npu_rows(self, mock_execute_command): device_info = object.__new__(DeviceInfo) mock_execute_command.return_value = ( - "HEADER\n" - "NPU Chip Affinity\n" - "not-an-npu row\n" - "NPU0 x x x 2-3", + "HEADER\nNPU Chip Affinity\nnot-an-npu row\nNPU0 x x x 2-3", 0, ) @@ -182,20 +176,20 @@ def test_get_all_logic_npus(self): def 
test_get_all_logic_npus_filters_invalid_values(self): device_info = object.__new__(DeviceInfo) device_info.npu_map_info = { - '0': {'0': '0', '1': '', '2': 'abc'}, - '1': {'0': '2'}, + "0": {"0": "0", "1": "", "2": "abc"}, + "1": {"0": "2"}, } self.assertEqual(device_info.get_all_logic_npus(), [0, 2]) - @patch('vllm_ascend.cpu_binding.os.path.exists', return_value=False) + @patch("vllm_ascend.cpu_binding.os.path.exists", return_value=False) def test_parse_allowed_cpus_returns_empty_when_status_file_missing(self, _mock_exists): device_info = object.__new__(DeviceInfo) self.assertEqual(device_info.parse_allowed_cpus(), []) - @patch('vllm_ascend.cpu_binding.os.path.exists', return_value=True) - @patch('builtins.open', new_callable=mock_open, read_data='Name:\tpython\nState:\tR\n') + @patch("vllm_ascend.cpu_binding.os.path.exists", return_value=True) + @patch("builtins.open", new_callable=mock_open, read_data="Name:\tpython\nState:\tR\n") def test_parse_allowed_cpus_raises_when_field_missing(self, _mock_open, _mock_exists): device_info = object.__new__(DeviceInfo) @@ -204,8 +198,7 @@ def test_parse_allowed_cpus_raises_when_field_missing(self, _mock_open, _mock_ex class TestCpuAlloc(unittest.TestCase): - - @patch('vllm_ascend.cpu_binding.execute_command') + @patch("vllm_ascend.cpu_binding.execute_command") def setUp(self, mock_execute_command): mock_execute_command.side_effect = [ ("NPU ID Chip ID Chip Logic ID Chip Name\n0 0 0 Ascend\n0 1 - Mcu\n1 0 1 Ascend", 0), @@ -216,7 +209,7 @@ def setUp(self, mock_execute_command): def test_average_distribute(self): self.cpu_alloc.npu_cpu_pool = {0: [10, 11, 12, 13], 1: [10, 11, 12, 13]} - groups = {'[10, 11, 12, 13]': [0, 1]} + groups = {"[10, 11, 12, 13]": [0, 1]} result = self.cpu_alloc.average_distribute(groups) self.assertEqual(result, {0: [10, 11], 1: [12, 13]}) @@ -225,36 +218,43 @@ def test_average_distribute(self): 1: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], 2: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], } - groups = {'[0, 1, 2, 3, 4, 5]': [0, 1, 2]} + groups = {"[0, 1, 2, 3, 4, 5]": [0, 1, 2]} result = self.cpu_alloc.average_distribute(groups) - self.assertEqual(result, { - 0: [0, 1, 2, 3], - 1: [4, 5, 6, 7], - 2: [8, 9, 10, 11, 12, 13], - }) + self.assertEqual( + result, + { + 0: [0, 1, 2, 3], + 1: [4, 5, 6, 7], + 2: [8, 9, 10, 11, 12, 13], + }, + ) - @patch('vllm_ascend.cpu_binding.get_ascend_device_type') + @patch("vllm_ascend.cpu_binding.get_ascend_device_type") def test_binding_mode_table(self, mock_get_device_type): mock_get_device_type.return_value = AscendDeviceType.A2 - self.assertEqual(self.cpu_alloc._binding_mode(), 'topo_affinity') + self.assertEqual(self.cpu_alloc._binding_mode(), "topo_affinity") mock_get_device_type.return_value = AscendDeviceType.A3 - self.assertEqual(self.cpu_alloc._binding_mode(), 'global_slice') + self.assertEqual(self.cpu_alloc._binding_mode(), "global_slice") - @patch('vllm_ascend.cpu_binding.get_ascend_device_type') + @patch("vllm_ascend.cpu_binding.get_ascend_device_type") def test_build_cpu_pools_fallback_to_global_slice(self, mock_get_device_type): mock_get_device_type.return_value = AscendDeviceType.A2 self.cpu_alloc.device_info.npu_affinity = {} - with patch.object(self.cpu_alloc, 'build_cpu_node_map') as mock_build_cpu_node_map, \ - patch.object(self.cpu_alloc, 'build_global_slice_cpu_pool') as mock_build_global_slice_cpu_pool: + with ( + patch.object(self.cpu_alloc, "build_cpu_node_map") as mock_build_cpu_node_map, + patch.object(self.cpu_alloc, "build_global_slice_cpu_pool") as 
mock_build_global_slice_cpu_pool, + ): self.cpu_alloc.build_cpu_pools() mock_build_cpu_node_map.assert_called_once() mock_build_global_slice_cpu_pool.assert_called_once() - @patch('vllm_ascend.cpu_binding.get_ascend_device_type') + @patch("vllm_ascend.cpu_binding.get_ascend_device_type") def test_build_cpu_pools_global_slice_mode(self, mock_get_device_type): mock_get_device_type.return_value = AscendDeviceType.A3 - with patch.object(self.cpu_alloc, 'build_cpu_node_map') as mock_build_cpu_node_map, \ - patch.object(self.cpu_alloc, 'build_global_slice_cpu_pool') as mock_build_global_slice_cpu_pool: + with ( + patch.object(self.cpu_alloc, "build_cpu_node_map") as mock_build_cpu_node_map, + patch.object(self.cpu_alloc, "build_global_slice_cpu_pool") as mock_build_global_slice_cpu_pool, + ): self.cpu_alloc.build_cpu_pools() mock_build_cpu_node_map.assert_called_once() mock_build_global_slice_cpu_pool.assert_called_once() @@ -271,12 +271,12 @@ def test_extend_numa(self): result = self.cpu_alloc.extend_numa([0, 1]) self.assertEqual(result, [0, 1, 3]) - @patch('vllm_ascend.cpu_binding.execute_command') + @patch("vllm_ascend.cpu_binding.execute_command") def test_build_cpu_node_map(self, mock_execute_command): - mock_execute_command.return_value = ('', 0) + mock_execute_command.return_value = ("", 0) with self.assertRaises(RuntimeError): self.cpu_alloc.build_cpu_node_map() - mock_execute_command.return_value = ('0 0\n1 1\n2 0\n3 1', 0) + mock_execute_command.return_value = ("0 0\n1 1\n2 0\n3 1", 0) self.cpu_alloc.build_cpu_node_map() expected_cpu_node = {0: 0, 1: 1, 2: 0, 3: 1} expected_numa_to_cpu_map = {0: [0, 2], 1: [1, 3]} @@ -339,7 +339,7 @@ def test_build_global_slice_cpu_pool_returns_when_running_or_allowed_empty(self) self.cpu_alloc.build_global_slice_cpu_pool() self.assertEqual(self.cpu_alloc.npu_cpu_pool, {}) - @patch('vllm_ascend.cpu_binding.execute_command') + @patch("vllm_ascend.cpu_binding.execute_command") def test_allocate(self, _mock_execute_command): self.cpu_alloc.device_info.running_npu_list = [0] self.cpu_alloc.npu_cpu_pool = {0: [0, 1, 2, 3, 4]} @@ -351,9 +351,9 @@ def test_allocate(self, _mock_execute_command): with self.assertRaises(RuntimeError): self.cpu_alloc.allocate() - @patch('vllm_ascend.cpu_binding.execute_command') + @patch("vllm_ascend.cpu_binding.execute_command") def test_bind_threads(self, mock_execute_command): - thread_message = '1234 1234 ? 00:00:03 acl_thread\n4567 4567 ? 00:00:03 release_thread' + thread_message = "1234 1234 ? 00:00:03 acl_thread\n4567 4567 ? 
00:00:03 release_thread" mock_execute_command.return_value = (thread_message, 0) self.cpu_alloc.device_info.running_npu_list = [0] self.cpu_alloc.assign_main = {0: [0, 1]} @@ -362,70 +362,69 @@ def test_bind_threads(self, mock_execute_command): self.cpu_alloc.bind_threads() mock_execute_command.assert_called() - @patch('vllm_ascend.cpu_binding.get_ascend_device_type') - @patch('vllm_ascend.cpu_binding.os.listdir') - @patch('builtins.open', new_callable=mock_open, read_data='123: 0 0 0 0 sq_send_trigger_irq\n') - @patch('vllm_ascend.cpu_binding.shutil.which') - @patch('vllm_ascend.cpu_binding.os.access') - @patch('vllm_ascend.cpu_binding.execute_command') - def test_bind_npu_irq_a3_uses_card_chip_mapping(self, mock_execute_command, mock_access, - mock_which, _mock_open, mock_listdir, - mock_get_device_type): + @patch("vllm_ascend.cpu_binding.get_ascend_device_type") + @patch("vllm_ascend.cpu_binding.os.listdir") + @patch("builtins.open", new_callable=mock_open, read_data="123: 0 0 0 0 sq_send_trigger_irq\n") + @patch("vllm_ascend.cpu_binding.shutil.which") + @patch("vllm_ascend.cpu_binding.os.access") + @patch("vllm_ascend.cpu_binding.execute_command") + def test_bind_npu_irq_a3_uses_card_chip_mapping( + self, mock_execute_command, mock_access, mock_which, _mock_open, mock_listdir, mock_get_device_type + ): mock_access.return_value = True mock_which.return_value = None mock_listdir.side_effect = FileNotFoundError mock_get_device_type.return_value = AscendDeviceType.A3 - mock_execute_command.return_value = ('PCIe Bus Info 0000:03:00.0', 0) + mock_execute_command.return_value = ("PCIe Bus Info 0000:03:00.0", 0) self.cpu_alloc.rank_id = 0 self.cpu_alloc.device_info.running_npu_list = [3] self.cpu_alloc.npu_cpu_pool = {3: [0, 1, 2, 3, 4]} self.cpu_alloc.bind_npu_irq() - mock_execute_command.assert_any_call(['npu-smi', 'info', '-t', 'board', '-i', '1', '-c', '1']) + mock_execute_command.assert_any_call(["npu-smi", "info", "-t", "board", "-i", "1", "-c", "1"]) class TestCpuBindingSupplemental(unittest.TestCase): - def test_cpu_to_mask_handles_single_and_multi_group_masks(self): - self.assertEqual(CpuAlloc.cpu_to_mask(3), '00000008') - self.assertEqual(CpuAlloc.cpu_to_mask(35), '00000008,00000000') + self.assertEqual(CpuAlloc.cpu_to_mask(3), "00000008") + self.assertEqual(CpuAlloc.cpu_to_mask(35), "00000008,00000000") def test_get_threads_map_skips_irrelevant_lines(self): thread_message = ( - 'bad-line\n' - '123 456 ? 00:00:01 acl_thread\n' - '123 789 ? 00:00:01 release_thread\n' - '123 999 ? 00:00:01 worker_thread\n' - '555 666 ? 00:00:01 acl_thread' + "bad-line\n" + "123 456 ? 00:00:01 acl_thread\n" + "123 789 ? 00:00:01 release_thread\n" + "123 999 ? 00:00:01 worker_thread\n" + "555 666 ? 
00:00:01 acl_thread" ) self.assertEqual( CpuAlloc.get_threads_map(thread_message), { - '123': {'acl_thread': ['456'], 'release_thread': ['789']}, - '555': {'acl_thread': ['666'], 'release_thread': []}, + "123": {"acl_thread": ["456"], "release_thread": ["789"]}, + "555": {"acl_thread": ["666"], "release_thread": []}, }, ) - @patch('vllm_ascend.cpu_binding.execute_command') + @patch("vllm_ascend.cpu_binding.execute_command") def test_bind_skips_empty_cpu_list(self, mock_execute_command): - CpuAlloc.bind('123', [], False) + CpuAlloc.bind("123", [], False) mock_execute_command.assert_not_called() - @patch('vllm_ascend.cpu_binding.execute_command', return_value=('ok', 0)) + @patch("vllm_ascend.cpu_binding.execute_command", return_value=("ok", 0)) def test_bind_uses_sub_thread_flag(self, mock_execute_command): - CpuAlloc.bind('123', [1, 2], True) + CpuAlloc.bind("123", [1, 2], True) - mock_execute_command.assert_called_once_with(['taskset', '-acp', '1,2', '123']) + mock_execute_command.assert_called_once_with(["taskset", "-acp", "1,2", "123"]) - @patch('vllm_ascend.cpu_binding.execute_command', return_value=('failed', 1)) + @patch("vllm_ascend.cpu_binding.execute_command", return_value=("failed", 1)) def test_bind_raises_for_failed_taskset(self, mock_execute_command): with self.assertRaises(RuntimeError): - CpuAlloc.bind('123', [1, 2], False) + CpuAlloc.bind("123", [1, 2], False) - mock_execute_command.assert_called_once_with(['taskset', '-cp', '1,2', '123']) + mock_execute_command.assert_called_once_with(["taskset", "-cp", "1,2", "123"]) def test_extend_numa_returns_original_list_when_multiple_nodes_present(self): cpu_alloc = make_cpu_alloc() @@ -433,45 +432,46 @@ def test_extend_numa_returns_original_list_when_multiple_nodes_present(self): self.assertEqual(cpu_alloc.extend_numa([0, 1]), [0, 1]) - @patch('vllm_ascend.cpu_binding.execute_command') + @patch("vllm_ascend.cpu_binding.execute_command") def test_build_cpu_node_map_skips_blank_and_header_rows(self, mock_execute_command): cpu_alloc = make_cpu_alloc() - mock_execute_command.return_value = ('CPU NODE\n\n0 0\n1 1', 0) + mock_execute_command.return_value = ("CPU NODE\n\n0 0\n1 1", 0) cpu_alloc.build_cpu_node_map() self.assertEqual(cpu_alloc.cpu_node, {0: 0, 1: 1}) self.assertEqual(cpu_alloc.numa_to_cpu_map, {0: [0], 1: [1]}) - @patch('vllm_ascend.cpu_binding.get_ascend_device_type', return_value='unknown') + @patch("vllm_ascend.cpu_binding.get_ascend_device_type", return_value="unknown") def test_binding_mode_defaults_to_topo_affinity_for_unknown_device(self, _mock_get_device_type): - self.assertEqual(CpuAlloc._binding_mode(), 'topo_affinity') + self.assertEqual(CpuAlloc._binding_mode(), "topo_affinity") - @patch('vllm_ascend.cpu_binding.get_ascend_device_type', return_value=AscendDeviceType.A2) + @patch("vllm_ascend.cpu_binding.get_ascend_device_type", return_value=AscendDeviceType.A2) def test_build_cpu_pools_raises_on_affinity_conflict(self, _mock_get_device_type): cpu_alloc = make_cpu_alloc() cpu_alloc.device_info.running_npu_list = [0] cpu_alloc.device_info.allowed_cpus = [8, 9] cpu_alloc.device_info.npu_affinity = {0: [0, 1]} - with patch.object(cpu_alloc, 'build_cpu_node_map'): - with self.assertRaises(RuntimeError): - cpu_alloc.build_cpu_pools() + with patch.object(cpu_alloc, "build_cpu_node_map"), self.assertRaises(RuntimeError): + cpu_alloc.build_cpu_pools() - @patch('vllm_ascend.cpu_binding.get_ascend_device_type', return_value=AscendDeviceType.A2) + @patch("vllm_ascend.cpu_binding.get_ascend_device_type", 
return_value=AscendDeviceType.A2) def test_build_cpu_pools_topo_mode_builds_and_splits_duplicate_groups(self, _mock_get_device_type): cpu_alloc = make_cpu_alloc() cpu_alloc.device_info.running_npu_list = [0, 1, 2] cpu_alloc.device_info.allowed_cpus = [0, 1, 2, 3] cpu_alloc.device_info.npu_affinity = {0: [0, 1], 1: [2, 3], 2: [2, 3]} - with patch.object(cpu_alloc, 'build_cpu_node_map'), \ - patch.object(cpu_alloc, 'extend_numa', side_effect=lambda cpus: cpus): + with ( + patch.object(cpu_alloc, "build_cpu_node_map"), + patch.object(cpu_alloc, "extend_numa", side_effect=lambda cpus: cpus), + ): cpu_alloc.build_cpu_pools() self.assertEqual(cpu_alloc.npu_cpu_pool, {0: [0, 1], 1: [2], 2: [3]}) - @patch('vllm_ascend.cpu_binding.logger.info') + @patch("vllm_ascend.cpu_binding.logger.info") def test_print_plan_handles_empty_release_assignment(self, mock_logger_info): cpu_alloc = make_cpu_alloc() cpu_alloc.device_info.running_npu_list = [1] @@ -484,7 +484,7 @@ def test_print_plan_handles_empty_release_assignment(self, mock_logger_info): self.assertEqual(mock_logger_info.call_count, 2) - @patch('vllm_ascend.cpu_binding.logger.info') + @patch("vllm_ascend.cpu_binding.logger.info") def test_print_plan_handles_non_empty_release_assignment(self, mock_logger_info): cpu_alloc = make_cpu_alloc() cpu_alloc.device_info.running_npu_list = [1] @@ -497,45 +497,48 @@ def test_print_plan_handles_non_empty_release_assignment(self, mock_logger_info) self.assertEqual(mock_logger_info.call_count, 2) - @patch('vllm_ascend.cpu_binding.shutil.which', return_value=None) - @patch('vllm_ascend.cpu_binding.execute_command') + @patch("vllm_ascend.cpu_binding.shutil.which", return_value=None) + @patch("vllm_ascend.cpu_binding.execute_command") def test_bind_memory_skips_when_migratepages_missing(self, mock_execute_command, _mock_which): cpu_alloc = make_cpu_alloc() - cpu_alloc.bind_memory('999', 0) + cpu_alloc.bind_memory("999", 0) mock_execute_command.assert_not_called() - @patch('vllm_ascend.cpu_binding.shutil.which', return_value='/usr/bin/migratepages') - @patch('vllm_ascend.cpu_binding.execute_command') + @patch("vllm_ascend.cpu_binding.shutil.which", return_value="/usr/bin/migratepages") + @patch("vllm_ascend.cpu_binding.execute_command") def test_bind_memory_skips_when_cpu_pool_or_numa_invalid(self, mock_execute_command, _mock_which): cpu_alloc = make_cpu_alloc() cpu_alloc.numa_to_cpu_map = {0: [0], 1: [1]} - cpu_alloc.bind_memory('1000', 0) + cpu_alloc.bind_memory("1000", 0) mock_execute_command.assert_not_called() cpu_alloc.npu_cpu_pool = {0: [8]} cpu_alloc.cpu_node = {8: 3} - cpu_alloc.bind_memory('1000', 0) + cpu_alloc.bind_memory("1000", 0) mock_execute_command.assert_not_called() - @patch('vllm_ascend.cpu_binding.shutil.which', return_value='/usr/bin/migratepages') - @patch('vllm_ascend.cpu_binding.execute_command') + @patch("vllm_ascend.cpu_binding.shutil.which", return_value="/usr/bin/migratepages") + @patch("vllm_ascend.cpu_binding.execute_command") def test_bind_memory_executes_on_valid_numa_target(self, mock_execute_command, _mock_which): cpu_alloc = make_cpu_alloc() cpu_alloc.npu_cpu_pool = {0: [8, 9]} cpu_alloc.cpu_node = {8: 1} cpu_alloc.numa_to_cpu_map = {0: [0], 1: [8, 9]} - cpu_alloc.bind_memory('1000', 0) + cpu_alloc.bind_memory("1000", 0) - mock_execute_command.assert_called_once_with(['migratepages', '1000', '0,1', '1']) + mock_execute_command.assert_called_once_with(["migratepages", "1000", "0,1", "1"]) - @patch('vllm_ascend.cpu_binding.psutil.Process') - 
@patch('vllm_ascend.cpu_binding.execute_command', return_value=( - '1000 2000 ? 00:00:01 acl_thread\n1000 3000 ? 00:00:01 release_thread', - 0, - )) + @patch("vllm_ascend.cpu_binding.psutil.Process") + @patch( + "vllm_ascend.cpu_binding.execute_command", + return_value=( + "1000 2000 ? 00:00:01 acl_thread\n1000 3000 ? 00:00:01 release_thread", + 0, + ), + ) def test_bind_threads_binds_main_acl_and_release_threads(self, _mock_execute_command, mock_process): cpu_alloc = make_cpu_alloc() cpu_alloc.device_info.running_npu_list = [0] @@ -544,29 +547,29 @@ def test_bind_threads_binds_main_acl_and_release_threads(self, _mock_execute_com cpu_alloc.assign_rel = {0: [4]} mock_process.return_value.pid = 1000 - with patch.object(cpu_alloc, 'bind') as mock_bind, patch.object(cpu_alloc, 'bind_memory') as mock_bind_memory: + with patch.object(cpu_alloc, "bind") as mock_bind, patch.object(cpu_alloc, "bind_memory") as mock_bind_memory: cpu_alloc.bind_threads() self.assertEqual( mock_bind.call_args_list, [ - call('1000', [1, 2], True), - call('2000', [3], False), - call('3000', [4], False), + call("1000", [1, 2], True), + call("2000", [3], False), + call("3000", [4], False), ], ) - mock_bind_memory.assert_called_once_with('1000', 0) + mock_bind_memory.assert_called_once_with("1000", 0) - @patch('vllm_ascend.cpu_binding.os.access', return_value=False) - @patch('vllm_ascend.cpu_binding.execute_command') + @patch("vllm_ascend.cpu_binding.os.access", return_value=False) + @patch("vllm_ascend.cpu_binding.execute_command") def test_bind_npu_irq_returns_when_irq_path_not_writable(self, mock_execute_command, _mock_access): cpu_alloc = make_cpu_alloc() cpu_alloc.bind_npu_irq() mock_execute_command.assert_not_called() - @patch('vllm_ascend.cpu_binding.os.access', return_value=True) - @patch('vllm_ascend.cpu_binding.execute_command') + @patch("vllm_ascend.cpu_binding.os.access", return_value=True) + @patch("vllm_ascend.cpu_binding.execute_command") def test_bind_npu_irq_returns_when_current_npu_has_no_cpu_pool(self, mock_execute_command, _mock_access): cpu_alloc = make_cpu_alloc() cpu_alloc.device_info.running_npu_list = [0] @@ -576,13 +579,14 @@ def test_bind_npu_irq_returns_when_current_npu_has_no_cpu_pool(self, mock_execut mock_execute_command.assert_not_called() - @patch('vllm_ascend.cpu_binding.get_ascend_device_type', return_value=AscendDeviceType.A2) - @patch('builtins.open', new_callable=mock_open, read_data='123: 0 0 0 sq_send_trigger_irq\n') - @patch('vllm_ascend.cpu_binding.shutil.which', return_value=None) - @patch('vllm_ascend.cpu_binding.os.access', return_value=True) - @patch('vllm_ascend.cpu_binding.execute_command') - def test_bind_npu_irq_skips_when_cpu_pool_too_small(self, mock_execute_command, _mock_access, - _mock_which, _mock_open, _mock_get_device_type): + @patch("vllm_ascend.cpu_binding.get_ascend_device_type", return_value=AscendDeviceType.A2) + @patch("builtins.open", new_callable=mock_open, read_data="123: 0 0 0 sq_send_trigger_irq\n") + @patch("vllm_ascend.cpu_binding.shutil.which", return_value=None) + @patch("vllm_ascend.cpu_binding.os.access", return_value=True) + @patch("vllm_ascend.cpu_binding.execute_command") + def test_bind_npu_irq_skips_when_cpu_pool_too_small( + self, mock_execute_command, _mock_access, _mock_which, _mock_open, _mock_get_device_type + ): cpu_alloc = make_cpu_alloc() cpu_alloc.device_info.running_npu_list = [0] cpu_alloc.npu_cpu_pool = {0: [7]} @@ -591,51 +595,51 @@ def test_bind_npu_irq_skips_when_cpu_pool_too_small(self, mock_execute_command, 
mock_execute_command.assert_not_called() - @patch('vllm_ascend.cpu_binding.get_ascend_device_type', return_value=AscendDeviceType.A2) - @patch('builtins.open', new_callable=mock_open, read_data='123: 0 0 0 sq_send_trigger_irq\n') - @patch('vllm_ascend.cpu_binding.shutil.which', return_value=None) - @patch('vllm_ascend.cpu_binding.os.access', return_value=True) - @patch('vllm_ascend.cpu_binding.execute_command', return_value=('board info without pci', 0)) - def test_bind_npu_irq_skips_when_pci_address_missing(self, mock_execute_command, _mock_access, - _mock_which, _mock_open, _mock_get_device_type): + @patch("vllm_ascend.cpu_binding.get_ascend_device_type", return_value=AscendDeviceType.A2) + @patch("builtins.open", new_callable=mock_open, read_data="123: 0 0 0 sq_send_trigger_irq\n") + @patch("vllm_ascend.cpu_binding.shutil.which", return_value=None) + @patch("vllm_ascend.cpu_binding.os.access", return_value=True) + @patch("vllm_ascend.cpu_binding.execute_command", return_value=("board info without pci", 0)) + def test_bind_npu_irq_skips_when_pci_address_missing( + self, mock_execute_command, _mock_access, _mock_which, _mock_open, _mock_get_device_type + ): cpu_alloc = make_cpu_alloc() cpu_alloc.device_info.running_npu_list = [0] cpu_alloc.npu_cpu_pool = {0: [7, 8]} cpu_alloc.bind_npu_irq() - mock_execute_command.assert_called_once_with(['npu-smi', 'info', '-t', 'board', '-i', '0']) - - @patch('vllm_ascend.cpu_binding.get_ascend_device_type', return_value=AscendDeviceType.A2) - @patch('vllm_ascend.cpu_binding.os.listdir', return_value=['456', '457']) - @patch('builtins.open', new_callable=mock_open, read_data='123: 0 0 0 sq_send_trigger_irq\n') - @patch('vllm_ascend.cpu_binding.shutil.which', return_value=None) - @patch('vllm_ascend.cpu_binding.os.access', return_value=True) - @patch('vllm_ascend.cpu_binding.execute_command', return_value=('prefix\nPCIe Bus Info 0000:03:00.0', 0)) - def test_bind_npu_irq_skips_when_sq_irq_not_found(self, _mock_execute_command, _mock_access, - _mock_which, _mock_open, _mock_listdir, - _mock_get_device_type): + mock_execute_command.assert_called_once_with(["npu-smi", "info", "-t", "board", "-i", "0"]) + + @patch("vllm_ascend.cpu_binding.get_ascend_device_type", return_value=AscendDeviceType.A2) + @patch("vllm_ascend.cpu_binding.os.listdir", return_value=["456", "457"]) + @patch("builtins.open", new_callable=mock_open, read_data="123: 0 0 0 sq_send_trigger_irq\n") + @patch("vllm_ascend.cpu_binding.shutil.which", return_value=None) + @patch("vllm_ascend.cpu_binding.os.access", return_value=True) + @patch("vllm_ascend.cpu_binding.execute_command", return_value=("prefix\nPCIe Bus Info 0000:03:00.0", 0)) + def test_bind_npu_irq_skips_when_sq_irq_not_found( + self, _mock_execute_command, _mock_access, _mock_which, _mock_open, _mock_listdir, _mock_get_device_type + ): cpu_alloc = make_cpu_alloc() cpu_alloc.device_info.running_npu_list = [0] cpu_alloc.npu_cpu_pool = {0: [7, 8]} cpu_alloc.bind_npu_irq() - @patch('vllm_ascend.cpu_binding.get_ascend_device_type', return_value=AscendDeviceType.A2) - @patch('vllm_ascend.cpu_binding.os.listdir', return_value=['123', '124']) - @patch('builtins.open', new_callable=mock_open, read_data='123: 0 0 0 sq_send_trigger_irq\n') - @patch('vllm_ascend.cpu_binding.shutil.which', return_value='/bin/systemctl') - @patch('vllm_ascend.cpu_binding.os.access', return_value=True) - @patch('vllm_ascend.cpu_binding.execute_command') - def test_bind_npu_irq_stops_irqbalance_and_writes_affinity_masks(self, mock_execute_command, - _mock_access, 
_mock_which, - mock_file, _mock_listdir, - _mock_get_device_type): + @patch("vllm_ascend.cpu_binding.get_ascend_device_type", return_value=AscendDeviceType.A2) + @patch("vllm_ascend.cpu_binding.os.listdir", return_value=["123", "124"]) + @patch("builtins.open", new_callable=mock_open, read_data="123: 0 0 0 sq_send_trigger_irq\n") + @patch("vllm_ascend.cpu_binding.shutil.which", return_value="/bin/systemctl") + @patch("vllm_ascend.cpu_binding.os.access", return_value=True) + @patch("vllm_ascend.cpu_binding.execute_command") + def test_bind_npu_irq_stops_irqbalance_and_writes_affinity_masks( + self, mock_execute_command, _mock_access, _mock_which, mock_file, _mock_listdir, _mock_get_device_type + ): mock_execute_command.side_effect = [ - ('irqbalance.service enabled\n', 0), - ('', 0), - ('stopped', 0), - ('prefix\nPCIe Bus Info 0000:03:00.0', 0), + ("irqbalance.service enabled\n", 0), + ("", 0), + ("stopped", 0), + ("prefix\nPCIe Bus Info 0000:03:00.0", 0), ] cpu_alloc = make_cpu_alloc() cpu_alloc.device_info.running_npu_list = [0] @@ -643,24 +647,24 @@ def test_bind_npu_irq_stops_irqbalance_and_writes_affinity_masks(self, mock_exec cpu_alloc.bind_npu_irq() - self.assertIn(call(['systemctl', 'stop', 'irqbalance']), mock_execute_command.call_args_list) - self.assertIn(call(['npu-smi', 'info', '-t', 'board', '-i', '0']), mock_execute_command.call_args_list) + self.assertIn(call(["systemctl", "stop", "irqbalance"]), mock_execute_command.call_args_list) + self.assertIn(call(["npu-smi", "info", "-t", "board", "-i", "0"]), mock_execute_command.call_args_list) handle = mock_file() - self.assertEqual(handle.write.call_args_list, [call('00000100'), call('00000200')]) - - @patch('vllm_ascend.cpu_binding.get_ascend_device_type', return_value=AscendDeviceType.A2) - @patch('vllm_ascend.cpu_binding.os.listdir', return_value=['123', '124']) - @patch('builtins.open', new_callable=mock_open, read_data='123: 0 0 0 sq_send_trigger_irq\n') - @patch('vllm_ascend.cpu_binding.shutil.which', return_value='/bin/systemctl') - @patch('vllm_ascend.cpu_binding.os.access', return_value=True) - @patch('vllm_ascend.cpu_binding.execute_command') - def test_bind_npu_irq_keeps_irqbalance_when_inactive(self, mock_execute_command, _mock_access, - _mock_which, _mock_open, _mock_listdir, - _mock_get_device_type): + self.assertEqual(handle.write.call_args_list, [call("00000100"), call("00000200")]) + + @patch("vllm_ascend.cpu_binding.get_ascend_device_type", return_value=AscendDeviceType.A2) + @patch("vllm_ascend.cpu_binding.os.listdir", return_value=["123", "124"]) + @patch("builtins.open", new_callable=mock_open, read_data="123: 0 0 0 sq_send_trigger_irq\n") + @patch("vllm_ascend.cpu_binding.shutil.which", return_value="/bin/systemctl") + @patch("vllm_ascend.cpu_binding.os.access", return_value=True) + @patch("vllm_ascend.cpu_binding.execute_command") + def test_bind_npu_irq_keeps_irqbalance_when_inactive( + self, mock_execute_command, _mock_access, _mock_which, _mock_open, _mock_listdir, _mock_get_device_type + ): mock_execute_command.side_effect = [ - ('irqbalance.service enabled\n', 0), - ('', 3), - ('prefix\nPCIe Bus Info 0000:03:00.0', 0), + ("irqbalance.service enabled\n", 0), + ("", 3), + ("prefix\nPCIe Bus Info 0000:03:00.0", 0), ] cpu_alloc = make_cpu_alloc() cpu_alloc.device_info.running_npu_list = [0] @@ -668,20 +672,20 @@ def test_bind_npu_irq_keeps_irqbalance_when_inactive(self, mock_execute_command, cpu_alloc.bind_npu_irq() - self.assertNotIn(call(['systemctl', 'stop', 'irqbalance']), 
mock_execute_command.call_args_list) - - @patch('vllm_ascend.cpu_binding.get_ascend_device_type', return_value=AscendDeviceType.A2) - @patch('vllm_ascend.cpu_binding.os.listdir', return_value=['123', '124']) - @patch('builtins.open', new_callable=mock_open, read_data='123: 0 0 0 sq_send_trigger_irq\n') - @patch('vllm_ascend.cpu_binding.shutil.which', return_value='/bin/systemctl') - @patch('vllm_ascend.cpu_binding.os.access', return_value=True) - @patch('vllm_ascend.cpu_binding.execute_command') - def test_bind_npu_irq_skips_irqbalance_handling_when_service_absent(self, mock_execute_command, _mock_access, - _mock_which, _mock_open, _mock_listdir, - _mock_get_device_type): + self.assertNotIn(call(["systemctl", "stop", "irqbalance"]), mock_execute_command.call_args_list) + + @patch("vllm_ascend.cpu_binding.get_ascend_device_type", return_value=AscendDeviceType.A2) + @patch("vllm_ascend.cpu_binding.os.listdir", return_value=["123", "124"]) + @patch("builtins.open", new_callable=mock_open, read_data="123: 0 0 0 sq_send_trigger_irq\n") + @patch("vllm_ascend.cpu_binding.shutil.which", return_value="/bin/systemctl") + @patch("vllm_ascend.cpu_binding.os.access", return_value=True) + @patch("vllm_ascend.cpu_binding.execute_command") + def test_bind_npu_irq_skips_irqbalance_handling_when_service_absent( + self, mock_execute_command, _mock_access, _mock_which, _mock_open, _mock_listdir, _mock_get_device_type + ): mock_execute_command.side_effect = [ - ('another.service enabled\n', 0), - ('prefix\nPCIe Bus Info 0000:03:00.0', 0), + ("another.service enabled\n", 0), + ("prefix\nPCIe Bus Info 0000:03:00.0", 0), ] cpu_alloc = make_cpu_alloc() cpu_alloc.device_info.running_npu_list = [0] @@ -689,21 +693,21 @@ def test_bind_npu_irq_skips_irqbalance_handling_when_service_absent(self, mock_e cpu_alloc.bind_npu_irq() - self.assertNotIn(call(['systemctl', 'is-active', '--quiet', 'irqbalance']), mock_execute_command.call_args_list) + self.assertNotIn(call(["systemctl", "is-active", "--quiet", "irqbalance"]), mock_execute_command.call_args_list) - @patch('vllm_ascend.cpu_binding.get_ascend_device_type', return_value=AscendDeviceType.A2) - @patch('vllm_ascend.cpu_binding.os.listdir', return_value=['123', '124']) + @patch("vllm_ascend.cpu_binding.get_ascend_device_type", return_value=AscendDeviceType.A2) + @patch("vllm_ascend.cpu_binding.os.listdir", return_value=["123", "124"]) @patch( - 'builtins.open', + "builtins.open", new_callable=mock_open, - read_data='100: 0 0 0 other_irq\n123: 0 0 0 sq_send_trigger_irq\n', + read_data="100: 0 0 0 other_irq\n123: 0 0 0 sq_send_trigger_irq\n", ) - @patch('vllm_ascend.cpu_binding.shutil.which', return_value=None) - @patch('vllm_ascend.cpu_binding.os.access', return_value=True) - @patch('vllm_ascend.cpu_binding.execute_command', return_value=('prefix\nPCIe Bus Info 0000:03:00.0', 0)) - def test_bind_npu_irq_scans_multiple_interrupt_lines(self, _mock_execute_command, _mock_access, - _mock_which, mock_file, _mock_listdir, - _mock_get_device_type): + @patch("vllm_ascend.cpu_binding.shutil.which", return_value=None) + @patch("vllm_ascend.cpu_binding.os.access", return_value=True) + @patch("vllm_ascend.cpu_binding.execute_command", return_value=("prefix\nPCIe Bus Info 0000:03:00.0", 0)) + def test_bind_npu_irq_scans_multiple_interrupt_lines( + self, _mock_execute_command, _mock_access, _mock_which, mock_file, _mock_listdir, _mock_get_device_type + ): cpu_alloc = make_cpu_alloc() cpu_alloc.device_info.running_npu_list = [0] cpu_alloc.npu_cpu_pool = {0: [8, 9, 10]} @@ -711,44 
+715,45 @@ def test_bind_npu_irq_scans_multiple_interrupt_lines(self, _mock_execute_command cpu_alloc.bind_npu_irq() handle = mock_file() - self.assertEqual(handle.write.call_args_list, [call('00000100'), call('00000200')]) + self.assertEqual(handle.write.call_args_list, [call("00000100"), call("00000200")]) def test_run_all_invokes_steps_in_order(self): cpu_alloc = make_cpu_alloc() calls = [] - with patch.object(cpu_alloc, 'build_cpu_pools', side_effect=lambda: calls.append('build_cpu_pools')), \ - patch.object(cpu_alloc, 'allocate', side_effect=lambda: calls.append('allocate')), \ - patch.object(cpu_alloc, 'print_plan', side_effect=lambda: calls.append('print_plan')), \ - patch.object(cpu_alloc, 'bind_threads', side_effect=lambda: calls.append('bind_threads')), \ - patch.object(cpu_alloc, 'bind_npu_irq', side_effect=lambda: calls.append('bind_npu_irq')): + with ( + patch.object(cpu_alloc, "build_cpu_pools", side_effect=lambda: calls.append("build_cpu_pools")), + patch.object(cpu_alloc, "allocate", side_effect=lambda: calls.append("allocate")), + patch.object(cpu_alloc, "print_plan", side_effect=lambda: calls.append("print_plan")), + patch.object(cpu_alloc, "bind_threads", side_effect=lambda: calls.append("bind_threads")), + patch.object(cpu_alloc, "bind_npu_irq", side_effect=lambda: calls.append("bind_npu_irq")), + ): cpu_alloc.run_all() - self.assertEqual(calls, ['build_cpu_pools', 'allocate', 'print_plan', 'bind_threads', 'bind_npu_irq']) + self.assertEqual(calls, ["build_cpu_pools", "allocate", "print_plan", "bind_threads", "bind_npu_irq"]) class TestBindingSwitch(unittest.TestCase): - - @patch('vllm_ascend.cpu_binding.platform.machine') + @patch("vllm_ascend.cpu_binding.platform.machine") def test_is_arm_cpu(self, mock_machine): - mock_machine.return_value = 'x86_64' + mock_machine.return_value = "x86_64" self.assertFalse(is_arm_cpu()) - mock_machine.return_value = 'aarch64' + mock_machine.return_value = "aarch64" self.assertTrue(is_arm_cpu()) - mock_machine.return_value = 'armv8' + mock_machine.return_value = "armv8" self.assertTrue(is_arm_cpu()) - mock_machine.return_value = 'mips64' + mock_machine.return_value = "mips64" self.assertFalse(is_arm_cpu()) - @patch('vllm_ascend.cpu_binding.CpuAlloc') - @patch('vllm_ascend.cpu_binding.is_arm_cpu') + @patch("vllm_ascend.cpu_binding.CpuAlloc") + @patch("vllm_ascend.cpu_binding.is_arm_cpu") def test_bind_cpus_skip_non_arm(self, mock_is_arm_cpu, mock_cpu_alloc): mock_is_arm_cpu.return_value = False bind_cpus(0) mock_cpu_alloc.assert_not_called() - @patch('vllm_ascend.cpu_binding.CpuAlloc') - @patch('vllm_ascend.cpu_binding.is_arm_cpu', return_value=True) + @patch("vllm_ascend.cpu_binding.CpuAlloc") + @patch("vllm_ascend.cpu_binding.is_arm_cpu", return_value=True) def test_bind_cpus_runs_allocator_on_arm(self, _mock_is_arm_cpu, mock_cpu_alloc): bind_cpus(1) @@ -756,5 +761,5 @@ def test_bind_cpus_runs_allocator_on_arm(self, _mock_is_arm_cpu, mock_cpu_alloc) mock_cpu_alloc.return_value.run_all.assert_called_once_with() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/ut/distributed/device_communicators/test_pyhccl.py b/tests/ut/distributed/device_communicators/test_pyhccl.py index 626f1165764..b63e0713406 100644 --- a/tests/ut/distributed/device_communicators/test_pyhccl.py +++ b/tests/ut/distributed/device_communicators/test_pyhccl.py @@ -4,9 +4,7 @@ from vllm.distributed.utils import StatelessProcessGroup from tests.ut.base import TestBase -from vllm_ascend.distributed.device_communicators.pyhccl 
import \ - PyHcclCommunicator -from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ +from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator class MockHcclLib: @@ -18,7 +16,6 @@ class MockUniqueId: class TestPyHcclCommunicator(TestBase): - @patch.dict(os.environ, {"RANK": "0", "WORLD_SIZE": "1"}) def test_world_size_1_return_early(self): comm = PyHcclCommunicator( @@ -30,25 +27,17 @@ def test_world_size_1_return_early(self): @patch.dict(os.environ, {"RANK": "0", "WORLD_SIZE": "2"}) def test_load_hccl_fail(self): - comm = PyHcclCommunicator(group=StatelessProcessGroup( - 0, 2, None, None), - device="npu:0", - library_path="/not/exist/path/libhccl.so") + comm = PyHcclCommunicator( + group=StatelessProcessGroup(0, 2, None, None), device="npu:0", library_path="/not/exist/path/libhccl.so" + ) self.assertTrue(comm.disabled) - @patch( - "vllm_ascend.distributed.device_communicators.pyhccl_wrapper.HCCLLibrary", - MockHcclLib) - @patch( - "vllm_ascend.distributed.device_communicators.pyhccl_wrapper.hcclUniqueId", - MockUniqueId) + @patch("vllm_ascend.distributed.device_communicators.pyhccl_wrapper.HCCLLibrary", MockHcclLib) + @patch("vllm_ascend.distributed.device_communicators.pyhccl_wrapper.hcclUniqueId", MockUniqueId) @patch("torch.npu.device") - @patch("vllm_ascend.utils.current_stream", - return_value=MagicMock(npu_stream=5678)) + @patch("vllm_ascend.utils.current_stream", return_value=MagicMock(npu_stream=5678)) def test_stateless_group(self, *_): - group = StatelessProcessGroup(rank=3, - world_size=4, - store=None) + group = StatelessProcessGroup(rank=3, world_size=4, store=None) comm = PyHcclCommunicator(group=group, device=3) @@ -56,12 +45,8 @@ def test_stateless_group(self, *_): self.assertEqual(comm.world_size, 4) @patch.dict(os.environ, {"RANK": "1", "WORLD_SIZE": "2"}) - @patch( - "vllm_ascend.distributed.device_communicators.pyhccl_wrapper.HCCLLibrary", - MockHcclLib) - @patch( - "vllm_ascend.distributed.device_communicators.pyhccl_wrapper.hcclUniqueId", - MockUniqueId) + @patch("vllm_ascend.distributed.device_communicators.pyhccl_wrapper.HCCLLibrary", MockHcclLib) + @patch("vllm_ascend.distributed.device_communicators.pyhccl_wrapper.hcclUniqueId", MockUniqueId) @patch("torch.distributed.is_initialized", return_value=True) @patch("torch.distributed.get_backend", return_value="nccl") @patch("torch.distributed.get_rank", return_value=1) @@ -69,8 +54,7 @@ def test_stateless_group(self, *_): @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1]) @patch("torch.distributed.broadcast") @patch("torch.npu.device") - @patch("vllm_ascend.utils.current_stream", - return_value=MagicMock(npu_stream=1234)) + @patch("vllm_ascend.utils.current_stream", return_value=MagicMock(npu_stream=1234)) def test_multi_gpu_pg_torch( self, *_, diff --git a/tests/ut/distributed/device_communicators/test_pyhccl_wrapper.py b/tests/ut/distributed/device_communicators/test_pyhccl_wrapper.py index ff905120a16..db07acc7826 100644 --- a/tests/ut/distributed/device_communicators/test_pyhccl_wrapper.py +++ b/tests/ut/distributed/device_communicators/test_pyhccl_wrapper.py @@ -5,13 +5,21 @@ from tests.ut.base import TestBase from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import ( - Function, HCCLLibrary, aclrtStream_t, buffer_type, hcclComm_t, - hcclDataType_t, hcclDataTypeEnum, hcclRedOp_t, hcclRedOpTypeEnum, - hcclResult_t, hcclUniqueId) + Function, + HCCLLibrary, + aclrtStream_t, + buffer_type, + hcclComm_t, + hcclDataType_t, + hcclDataTypeEnum, + hcclRedOp_t, 
+    hcclRedOpTypeEnum,
+    hcclResult_t,
+    hcclUniqueId,
+)


 class TestHcclUniqueId(TestBase):
-
     def test_construct(self):
         uid = hcclUniqueId()
         uid.internal[0] = 12
@@ -20,7 +28,6 @@


 class TestHcclDataTypeEnum(TestBase):
-
     def test_torch_dtype_mapping(self):
         expected = {
             torch.int8: hcclDataTypeEnum.hcclInt8,
@@ -35,8 +42,7 @@

         for torch_dtype, expected_enum in expected.items():
             with self.subTest(torch_dtype=torch_dtype):
-                self.assertEqual(hcclDataTypeEnum.from_torch(torch_dtype),
-                                 expected_enum)
+                self.assertEqual(hcclDataTypeEnum.from_torch(torch_dtype), expected_enum)

     def test_unsupported_dtype_raises(self):
         with self.assertRaises(ValueError):
@@ -44,7 +50,6 @@


 class TestHcclRedOpTypeEnum(TestBase):
-
     def test_torch_reduce_op_mapping(self):
         expected = {
             ReduceOp.SUM: hcclRedOpTypeEnum.hcclSum,
@@ -55,8 +60,7 @@

         for torch_op, expected_enum in expected.items():
             with self.subTest(torch_op=torch_op):
-                self.assertEqual(hcclRedOpTypeEnum.from_torch(torch_op),
-                                 expected_enum)
+                self.assertEqual(hcclRedOpTypeEnum.from_torch(torch_op), expected_enum)

     def test_unsupported_op_raises(self):
         unsupported_op = "NOT_EXIST"
@@ -65,7 +69,6 @@


 class TestFunction(TestBase):
-
     def test_construct_with_valid_args(self):
         func = Function(name="foo", restype=int, argtypes=[int, str, float])
         self.assertEqual(func.name, "foo")
@@ -74,7 +77,6 @@


 class TestHCLLLibrary(TestBase):
-
     def test_init_with_nonexistent_so(self):
         fake_path = "/definitely/not/exist/libhccl.so"
         with self.assertRaises(OSError):
@@ -127,7 +129,6 @@ def test_hccl_comm_initRank(self, mock_hccl_check):

     @patch.object(HCCLLibrary, "HCCL_CHECK")
     def test_hccl_all_reduce(self, mock_hccl_check):
-
         lib = HCCLLibrary.__new__(HCCLLibrary)
         lib._funcs = {"HcclAllReduce": MagicMock(return_value=0)}
         sendbuff = buffer_type()
@@ -138,16 +139,13 @@
         comm = hcclComm_t()
         stream = aclrtStream_t()

-        lib.hcclAllReduce(sendbuff, recvbuff, count, datatype, op, comm,
-                          stream)
+        lib.hcclAllReduce(sendbuff, recvbuff, count, datatype, op, comm, stream)

-        lib._funcs["HcclAllReduce"].assert_called_once_with(
-            sendbuff, recvbuff, count, datatype, op, comm, stream)
+        lib._funcs["HcclAllReduce"].assert_called_once_with(sendbuff, recvbuff, count, datatype, op, comm, stream)
         mock_hccl_check.assert_called_once_with(0)

     @patch.object(HCCLLibrary, "HCCL_CHECK")
     def test_hccl_broad_cast(self, mock_hccl_check):
-
         lib = HCCLLibrary.__new__(HCCLLibrary)
         lib._funcs = {"HcclBroadcast": MagicMock(return_value=0)}
         buff = buffer_type()
@@ -159,8 +157,7 @@

         lib.hcclBroadcast(buff, count, datatype, root, comm, stream)

-        lib._funcs["HcclBroadcast"].assert_called_once_with(
-            buff, count, datatype, root, comm, stream)
+        lib._funcs["HcclBroadcast"].assert_called_once_with(buff, count, datatype, root, comm, stream)
         mock_hccl_check.assert_called_once_with(0)

     @patch.object(HCCLLibrary, "HCCL_CHECK")
diff --git a/tests/ut/distributed/mooncake/test_config_data.py b/tests/ut/distributed/mooncake/test_config_data.py
index ac13dad16e0..0cea10c83c6 100644
--- a/tests/ut/distributed/mooncake/test_config_data.py
+++ b/tests/ut/distributed/mooncake/test_config_data.py
@@ -11,19 +11,19 @@
 sys.modules["mooncake.store"] = fake_store

 from
vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.mooncake_backend import (  # noqa: E402
-    _convert_to_bytes, _parse_global_segment_size)
+    _convert_to_bytes,
+    _parse_global_segment_size,
+)


 class TestParseGlobalSegmentSize(unittest.TestCase):
-
     def test_int_input(self):
         self.assertEqual(_parse_global_segment_size(1024), 1024)
         self.assertEqual(_parse_global_segment_size(0), 0)

     def test_gb_unit(self):
         self.assertEqual(_parse_global_segment_size("2GB"), 2 * 1024**3)
-        self.assertEqual(_parse_global_segment_size("1.5GB"),
-                         int(1.5 * 1024**3))
+        self.assertEqual(_parse_global_segment_size("1.5GB"), int(1.5 * 1024**3))
         self.assertEqual(_parse_global_segment_size(" 2 GB "), 2 * 1024**3)

     def test_gb_unit_edge_cases(self):
@@ -34,14 +34,12 @@

     def test_mb_unit(self):
         self.assertEqual(_parse_global_segment_size("512MB"), 512 * 1024**2)
-        self.assertEqual(_parse_global_segment_size("0.5MB"),
-                         int(0.5 * 1024**2))
+        self.assertEqual(_parse_global_segment_size("0.5MB"), int(0.5 * 1024**2))
         self.assertEqual(_parse_global_segment_size("1024MB"), 1024 * 1024**2)

     def test_kb_unit(self):
         self.assertEqual(_parse_global_segment_size("256KB"), 256 * 1024)
-        self.assertEqual(_parse_global_segment_size("1.25KB"),
-                         int(1.25 * 1024))
+        self.assertEqual(_parse_global_segment_size("1.25KB"), int(1.25 * 1024))

     def test_b_unit(self):
         self.assertEqual(_parse_global_segment_size("4096B"), 4096)
@@ -63,11 +61,9 @@


 class TestConvertToBytes(unittest.TestCase):
-
     def test_valid_conversion(self):
         self.assertEqual(_convert_to_bytes("10", 1, "10"), 10)
-        self.assertEqual(_convert_to_bytes("1.5", 1024, "1.5KB"),
-                         int(1.5 * 1024))
+        self.assertEqual(_convert_to_bytes("1.5", 1024, "1.5KB"), int(1.5 * 1024))
         self.assertEqual(_convert_to_bytes("0", 1024**3, "0GB"), 0)

     def test_invalid_numbers(self):
diff --git a/tests/ut/distributed/test_communicator.py b/tests/ut/distributed/test_communicator.py
index f968d26ed2e..a31989be6fb 100644
--- a/tests/ut/distributed/test_communicator.py
+++ b/tests/ut/distributed/test_communicator.py
@@ -4,20 +4,14 @@
 import torch
 import torch.distributed as dist

-from vllm_ascend.distributed.device_communicators.npu_communicator import \
-    NPUCommunicator
+from vllm_ascend.distributed.device_communicators.npu_communicator import NPUCommunicator


 class TestNPUCommunicator(unittest.TestCase):
-
     @patch("vllm.config.get_current_vllm_config", return_value=None)
     @patch("torch.npu.current_device", return_value=MagicMock())
     @patch("torch.npu.set_device", return_value=MagicMock())
-    @patch("torch.distributed.get_process_group_ranks",
-           return_value={
-               0: 0,
-               1: 1
-           })
+    @patch("torch.distributed.get_process_group_ranks", return_value={0: 0, 1: 1})
     @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
     @patch("torch.distributed.is_initialized", return_value=True)
     @patch("torch.distributed.get_rank", return_value=1)
@@ -28,15 +22,8 @@ class TestNPUCommunicator(unittest.TestCase):
     @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
     @patch("torch.npu.device")
     def test_all_to_all_with_sizes(self, *_):
-
-        def patched_all_to_all(output_tensor_list,
-                               input_tensor_list,
-                               group=None,
-                               async_op=False):
-            output_tensor_list[:] = ([
-                torch.tensor([10, 20]),
-                torch.tensor([50, 60])
-            ])
+        def patched_all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False):
+            output_tensor_list[:] = [torch.tensor([10, 20]), torch.tensor([50, 60])]

         torch.distributed.all_to_all =
patched_all_to_all
@@ -47,20 +34,14 @@ def patched_all_to_all(output_tensor_list,

         with patch.dict(dist.distributed_c10d._world.pg_map, {dist.group.WORLD: MagicMock()}, clear=False):
             comm = NPUCommunicator(cpu_group=dist.group.WORLD)
-            output = comm.all_to_all(input_,
-                                     scatter_sizes=scatter_sizes,
-                                     gather_sizes=gather_sizes)
+            output = comm.all_to_all(input_, scatter_sizes=scatter_sizes, gather_sizes=gather_sizes)

         assert output.tolist() == [10, 20, 50, 60]

     @patch("vllm.config.get_current_vllm_config", return_value=None)
     @patch("torch.npu.current_device", return_value=MagicMock())
     @patch("torch.npu.set_device", return_value=MagicMock())
-    @patch("torch.distributed.get_process_group_ranks",
-           return_value={
-               0: 0,
-               1: 1
-           })
+    @patch("torch.distributed.get_process_group_ranks", return_value={0: 0, 1: 1})
     @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
     @patch("torch.distributed.is_initialized", return_value=True)
     @patch("torch.distributed.get_rank", return_value=1)
@@ -71,15 +52,8 @@ def patched_all_to_all(output_tensor_list,
     @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
     @patch("torch.npu.device")
     def test_all_to_all_without_sizes(self, *_):
-
-        def patched_all_to_all(output_tensor_list,
-                               input_tensor_list,
-                               group=None,
-                               async_op=False):
-            output_tensor_list[:] = ([
-                torch.tensor([[10, 20]]),
-                torch.tensor([[50, 60]])
-            ])
+        def patched_all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False):
+            output_tensor_list[:] = [torch.tensor([[10, 20]]), torch.tensor([[50, 60]])]

         torch.distributed.all_to_all = patched_all_to_all
diff --git a/tests/ut/distributed/test_parallel_state.py b/tests/ut/distributed/test_parallel_state.py
index 30914efa3d4..d94f66972fa 100644
--- a/tests/ut/distributed/test_parallel_state.py
+++ b/tests/ut/distributed/test_parallel_state.py
@@ -4,10 +4,21 @@
 from vllm.config import ParallelConfig

 from vllm_ascend.distributed.parallel_state import (
-    _FLASHCOMM2_ODP, _FLASHCOMM2_OTP, _LMTP, _MC2, _OTP, _P_TP,
-    destroy_ascend_model_parallel, get_flashcomm2_odp_group,
-    get_flashcomm2_otp_group, get_lmhead_tp_group, get_mc2_group,
-    get_otp_group, get_p_tp_group, init_ascend_model_parallel)
+    _FLASHCOMM2_ODP,
+    _FLASHCOMM2_OTP,
+    _LMTP,
+    _MC2,
+    _OTP,
+    _P_TP,
+    destroy_ascend_model_parallel,
+    get_flashcomm2_odp_group,
+    get_flashcomm2_otp_group,
+    get_lmhead_tp_group,
+    get_mc2_group,
+    get_otp_group,
+    get_p_tp_group,
+    init_ascend_model_parallel,
+)


 @pytest.fixture
@@ -21,11 +32,13 @@ def parallel_config():

 @pytest.fixture
 def mock_distributed():
-    with patch('torch.distributed.is_initialized', return_value=True), \
-         patch('torch.distributed.get_world_size', return_value=16), \
-         patch('torch.distributed.get_backend', return_value='nccl'), \
-         patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group, \
-         patch('vllm_ascend.distributed.parallel_state.get_tp_group') as mock_tp_group:
+    with (
+        patch("torch.distributed.is_initialized", return_value=True),
+        patch("torch.distributed.get_world_size", return_value=16),
+        patch("torch.distributed.get_backend", return_value="nccl"),
+        patch("vllm_ascend.distributed.parallel_state.get_world_group") as mock_group,
+        patch("vllm_ascend.distributed.parallel_state.get_tp_group") as mock_tp_group,
+    ):
         mock_group.return_value.local_rank = 0
         mock_group.return_value.device_group = MagicMock()
         mock_tp_group.return_value.world_size = 4
@@ -47,12 +60,14 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config):
 mock_envs_ascend = MagicMock()
     mock_envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE = 2
     mock_envs_ascend.VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL = 0
-    with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
-         patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
-         patch('vllm_ascend.distributed.parallel_state.get_current_vllm_config', return_value=mock_vllm_config), \
-         patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config), \
-         patch('vllm_ascend.utils.envs_ascend', new=mock_envs_ascend), \
-         patch('vllm_ascend.utils.get_ascend_config', return_value=mock_ascend_config):
+    with (
+        patch("vllm_ascend.distributed.parallel_state.model_parallel_initialized", return_value=False),
+        patch("vllm_ascend.distributed.parallel_state.init_model_parallel_group"),
+        patch("vllm_ascend.distributed.parallel_state.get_current_vllm_config", return_value=mock_vllm_config),
+        patch("vllm_ascend.distributed.parallel_state.get_ascend_config", return_value=mock_ascend_config),
+        patch("vllm_ascend.utils.envs_ascend", new=mock_envs_ascend),
+        patch("vllm_ascend.utils.get_ascend_config", return_value=mock_ascend_config),
+    ):
         init_ascend_model_parallel(parallel_config)
         mc2_group = get_mc2_group()
diff --git a/tests/ut/eplb/adaptor/test_vllm_adaptor.py b/tests/ut/eplb/adaptor/test_vllm_adaptor.py
index 7989a368dfc..71b997b4027 100644
--- a/tests/ut/eplb/adaptor/test_vllm_adaptor.py
+++ b/tests/ut/eplb/adaptor/test_vllm_adaptor.py
@@ -2,10 +2,10 @@
 from unittest.mock import MagicMock, patch

 import torch
+from transformers import DeepseekV2Config

 from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
 from vllm_ascend.quantization.methods.base import QuantType
-from transformers import DeepseekV2Config


 class TestVllmAdaptor(unittest.TestCase):
@@ -45,6 +45,6 @@ def tearDown(self):
         self.mock_rank.stop()
         self.mock_size.stop()

+
 if __name__ == "__main__":
     unittest.main()
-
\ No newline at end of file
diff --git a/tests/ut/eplb/core/policy/test_policy_factory.py b/tests/ut/eplb/core/policy/test_policy_factory.py
index 737f3ab8d55..8813abc5643 100644
--- a/tests/ut/eplb/core/policy/test_policy_factory.py
+++ b/tests/ut/eplb/core/policy/test_policy_factory.py
@@ -1,8 +1,10 @@
 import unittest
+
 import torch
+
+from vllm_ascend.eplb.core.eplb_worker import EplbWorker
 from vllm_ascend.eplb.core.policy.policy_factory import PolicyFactory
 from vllm_ascend.eplb.core.policy.policy_flashlb import generate_layered_experts
-from vllm_ascend.eplb.core.eplb_worker import EplbWorker


 class TestEplbRebalancePolicies(unittest.TestCase):
@@ -10,7 +12,7 @@ def setUp(self):
         torch.manual_seed(42)
         self.current_expert_table = generate_layered_experts()
         x = torch.rand(100, 58, 32, 9)
-        x = x ** 10
+        x = x**10
         self.expert_workload = (x * 999 + 1).long()

         self.hotness = EplbWorker._calculate_hotness(self.current_expert_table, self.expert_workload.sum(0))
@@ -30,5 +32,5 @@ def test_flashlb_rebalance_experts(self):
         self.assertLessEqual(update_mean, 1.1)


-if __name__ == '__main__':
-    unittest.main(verbosity=2)
\ No newline at end of file
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
diff --git a/tests/ut/eplb/core/test_eplb_device_transfer_loader.py b/tests/ut/eplb/core/test_eplb_device_transfer_loader.py
index 3284e345371..4711f5b21b3 100644
--- a/tests/ut/eplb/core/test_eplb_device_transfer_loader.py
+++ b/tests/ut/eplb/core/test_eplb_device_transfer_loader.py
@@ -11,22 +11,11 @@ def mock_adaptor():
     adaptor =
MagicMock()
-    adaptor.expert_map_per_layer_cpu = {
-        0: {
-            10: torch.tensor(1),
-            20: torch.tensor(0)
-        }
-    }
-
-    adaptor.expert_param_per_layer = {
-        0: {
-            0: [[torch.tensor([1.0])]],
-            1: [[torch.tensor([2.0])]]
-        }
-    }
-
-    adaptor.buffer_tensor_list = [[[torch.tensor([3.0])],
-                                   [torch.tensor([4.0])]]]
+    adaptor.expert_map_per_layer_cpu = {0: {10: torch.tensor(1), 20: torch.tensor(0)}}
+
+    adaptor.expert_param_per_layer = {0: {0: [[torch.tensor([1.0])]], 1: [[torch.tensor([2.0])]]}}
+
+    adaptor.buffer_tensor_list = [[[torch.tensor([3.0])], [torch.tensor([4.0])]]]

     return adaptor

@@ -35,15 +24,15 @@ def test_generate_task_and_state_flow(mock_adaptor):
     loader_obj = loader.D2DExpertWeightLoader()
     loader_obj.set_adator(mock_adaptor)

-    with patch("torch.distributed.P2POp") as mock_p2p, \
-         patch("torch.distributed.isend", return_value="isend_op"), \
-         patch("torch.distributed.irecv", return_value="irecv_op"):
-
+    with (
+        patch("torch.distributed.P2POp") as mock_p2p,
+        patch("torch.distributed.isend", return_value="isend_op"),
+        patch("torch.distributed.irecv", return_value="irecv_op"),
+    ):
         mock_p2p.side_effect = lambda op, tensor, rank: (op, tensor, rank)

         loader_obj.state = loader.ExpertWeightUpdateState.READY
-        loader_obj.generate_expert_d2d_transfer_task([(1, 10)], [(2, 20)],
-                                                     {20: torch.tensor(0)}, 0)
+        loader_obj.generate_expert_d2d_transfer_task([(1, 10)], [(2, 20)], {20: torch.tensor(0)}, 0)
         assert loader_obj.comm_op_list is None

         loader_obj.state = loader.ExpertWeightUpdateState.WAITING
@@ -62,8 +51,7 @@ def test_asyn_transfer_and_update(mock_adaptor):

     reqs: list[MagicMock] = []

-    with patch("torch.distributed.batch_isend_irecv",
-               return_value=[MagicMock(), MagicMock()]):
+    with patch("torch.distributed.batch_isend_irecv", return_value=[MagicMock(), MagicMock()]):
         loader_obj.asyn_expert_weight_transfer(reqs)

     assert loader_obj.state == loader.ExpertWeightUpdateState.TRANSFERRING
diff --git a/tests/ut/eplb/core/test_eplb_utils.py b/tests/ut/eplb/core/test_eplb_utils.py
index f5388680f1f..01b6f9a53bc 100644
--- a/tests/ut/eplb/core/test_eplb_utils.py
+++ b/tests/ut/eplb/core/test_eplb_utils.py
@@ -21,9 +21,8 @@ def setUp(self, mock_fix_incompatible_config):
             "eplb_config": {"dynamic_eplb": True, "num_redundant_experts": 2},
         }
         from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
-        moe_parallel_config = FusedMoEParallelConfig(
-            2, 0, 1, 2, 1, 1, 1, 1, 1, True, "hccl",
-            enable_eplb=True)
+
+        moe_parallel_config = FusedMoEParallelConfig(2, 0, 1, 2, 1, 1, 1, 1, 1, True, "hccl", enable_eplb=True)
         moe_config = FusedMoEConfig(
             num_experts=8,
             experts_per_token=8,
diff --git a/tests/ut/eplb/test_eplb_updator.py b/tests/ut/eplb/test_eplb_updator.py
index d45ad2b0bfb..d69f6660091 100644
--- a/tests/ut/eplb/test_eplb_updator.py
+++ b/tests/ut/eplb/test_eplb_updator.py
@@ -1,12 +1,13 @@
 import unittest
 from unittest.mock import MagicMock, patch
+
 import torch
+
 from vllm_ascend.eplb.eplb_updator import EplbUpdator


 class TestEplbUpdatorComputeAndSetMoeLoad(unittest.TestCase):
     def setUp(self):
-
         # ====================== 1.
Mock environment ======================
         self.rank = 0
         self.world_size = 4
@@ -29,8 +30,7 @@ def mock_all_gather(tensor, dim):

         self.mock_comm_group.all_gather = mock_all_gather

-        p3 = patch("vllm_ascend.eplb.eplb_updator.get_dynamic_eplb_group",
-                   return_value=self.mock_comm_group)
+        p3 = patch("vllm_ascend.eplb.eplb_updator.get_dynamic_eplb_group", return_value=self.mock_comm_group)
         self.addCleanup(p3.stop)
         p3.start()
@@ -42,10 +42,7 @@ def mock_all_gather(tensor, dim):
         self.eplb_process.shared_dict = {}

         self.updator = EplbUpdator(
-            eplb_config=self.eplb_config,
-            loader=self.loader,
-            eplb_process=self.eplb_process,
-            process=self.process
+            eplb_config=self.eplb_config, loader=self.loader, eplb_process=self.eplb_process, process=self.process
         )

         # ====================== 4. Mock adaptor ======================
@@ -77,5 +74,5 @@ def test_compute_and_set_moe_load_multi_stage(self):
         self.assertEqual(moe_load.device.type, "cpu")


-if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
+if __name__ == "__main__":
+    unittest.main()