diff --git a/tests/ut/worker/test_pooling_model_runner.py b/tests/ut/worker/test_pooling_model_runner.py
new file mode 100644
index 00000000000..28a0a7d3c6b
--- /dev/null
+++ b/tests/ut/worker/test_pooling_model_runner.py
@@ -0,0 +1,355 @@
+import unittest
+from unittest.mock import MagicMock, patch
+
+import torch
+from vllm.distributed.parallel_state import GroupCoordinator
+from vllm.engine.arg_utils import EngineArgs
+from vllm.pooling_params import PoolingParams
+from vllm.sequence import SequenceData, SequenceGroupMetadata
+
+from vllm_ascend.worker.pooling_model_runner import (
+    ModelInputForNPUWithPoolingMetadata, NPUPoolingModelRunner)
+
+
+class TestPoolingModelRunner(unittest.TestCase):
+    """Unit tests for the NPUPoolingModelRunner class."""
+
+    def _create_model_runner(self, model: str, *args,
+                             **kwargs) -> NPUPoolingModelRunner:
+        engine_args = EngineArgs(model, *args, **kwargs)
+        engine_config = engine_args.create_engine_config()
+        model_runner = NPUPoolingModelRunner(vllm_config=engine_config, )
+        return model_runner
+
+    def setUp(self):
+        """Initialize test fixtures and common mocks"""
+        self.attn_backend = "npu"
+
+        model_runner = self._create_model_runner(
+            "tests/ut/fake_weight",
+            trust_remote_code=True,
+            enable_chunked_prefill=False,
+        )
+
+        self.runner = model_runner
+        self.runner.attn_backend = self.attn_backend
+        model_runner.model = MagicMock()
+        self.runner = model_runner
+        # Sample test data
+        self.sample_tensor_dict = {"tensor1": torch.randn(3, 4)}
+        self.sample_seq_group = [MagicMock(spec=SequenceGroupMetadata)]
+        self.sample_finished_ids = ["req1", "req2"]
+
+    @patch(
+        'vllm_ascend.worker.pooling_model_runner.ModelInputForNPUWithPoolingMetadata.from_broadcasted_tensor_dict'
+    )
+    def test_make_model_input_from_broadcasted_tensor_dict(
+            self, mock_from_dict):
+        """Test tensor dictionary conversion to model input"""
+        # Setup mock return
+        expected_output = MagicMock()
+        mock_from_dict.return_value = expected_output
+
+        # Execute
+        result = self.runner.make_model_input_from_broadcasted_tensor_dict(
+            self.sample_tensor_dict)
+
+        # Verify
+        mock_from_dict.assert_called_once_with(self.sample_tensor_dict,
+                                               attn_backend=self.attn_backend)
+        self.assertEqual(result, expected_output)
+
+    @patch.object(NPUPoolingModelRunner, '_prepare_pooling')
+    @patch.object(NPUPoolingModelRunner, '_prepare_model_input_tensors')
+    def test_prepare_model_input_normal_case(self, mock_prepare_tensors,
+                                             mock_prepare_pooling):
+        """Test normal flow of model input preparation"""
+        # Setup mocks
+        mock_model_input = ModelInputForNPUWithPoolingMetadata(
+            seq_lens=[1, 2, 3])
+        mock_prepare_tensors.return_value = mock_model_input
+
+        mock_pooling_metadata = MagicMock()
+        mock_prepare_pooling.return_value = mock_pooling_metadata
+
+        # Execute
+        result = self.runner.prepare_model_input(
+            seq_group_metadata_list=self.sample_seq_group,
+            finished_requests_ids=self.sample_finished_ids)
+
+        # Verify
+        mock_prepare_tensors.assert_called_once_with(self.sample_seq_group,
+                                                     self.sample_finished_ids)
+        mock_prepare_pooling.assert_called_once_with(self.sample_seq_group,
+                                                     mock_model_input.seq_lens)
+        self.assertEqual(result.pooling_metadata, mock_pooling_metadata)
+
+    def test_prepare_model_input_null_sequence_group(self):
+        """Test assertion when seq_group_metadata_list is None"""
+        with self.assertRaises(AssertionError):
+            self.runner.prepare_model_input(
+                seq_group_metadata_list=None,
+                finished_requests_ids=self.sample_finished_ids)
+
+    @patch.object(NPUPoolingModelRunner, '_prepare_model_input_tensors')
+    def test_prepare_model_input_null_seq_lens(self, mock_prepare_tensors):
+        """Test assertion when seq_lens is None in model input"""
+        # Setup mock with None seq_lens
+        mock_model_input = MagicMock()
+        mock_model_input.seq_lens = None
+        mock_prepare_tensors.return_value = mock_model_input
+
+        with self.assertRaises(AssertionError):
+            self.runner.prepare_model_input(
+                seq_group_metadata_list=self.sample_seq_group,
+                finished_requests_ids=self.sample_finished_ids)
+
+    @patch.object(NPUPoolingModelRunner, '_prepare_pooling')
+    @patch.object(NPUPoolingModelRunner, '_prepare_model_input_tensors')
+    def test_prepare_model_input_with_virtual_engine(self,
+                                                     mock_prepare_tensors,
+                                                     mock_prepare_pooling):
+        """Test virtual engine parameter is properly handled"""
+        # Setup mocks
+        mock_model_input = ModelInputForNPUWithPoolingMetadata(
+            seq_lens=[1, 2, 3])
+        mock_prepare_tensors.return_value = mock_model_input
+
+        # Execute with virtual_engine parameter
+        result = self.runner.prepare_model_input(
+            seq_group_metadata_list=self.sample_seq_group,
+            virtual_engine=1,
+            finished_requests_ids=self.sample_finished_ids)
+
+        # Verify virtual_engine doesn't affect the flow
+        self.assertIsNotNone(result)
+
+    @patch.object(NPUPoolingModelRunner, '_prepare_pooling')
+    @patch.object(NPUPoolingModelRunner, '_prepare_model_input_tensors')
+    def test_prepare_model_input_with_null_finished_ids(
+            self, mock_prepare_tensors, mock_prepare_pooling):
+        """Test case when finished_requests_ids is None"""
+        # Setup mocks
+        mock_model_input = ModelInputForNPUWithPoolingMetadata(
+            seq_lens=[1, 2, 3])
+        mock_prepare_tensors.return_value = mock_model_input
+
+        # Execute with None finished_ids
+        result = self.runner.prepare_model_input(
+            seq_group_metadata_list=self.sample_seq_group,
+            finished_requests_ids=None)
+
+        # Verify
+        mock_prepare_tensors.assert_called_once_with(self.sample_seq_group,
+                                                     None)
+        self.assertIsNotNone(result)
+
+    @patch('vllm.model_executor.pooling_metadata.PoolingMetadata.__init__')
+    def test_prepare_pooling_normal_case(self, mock_pooling_metadata):
+        """Test normal case with multiple sequences in group"""
+        # Setup test data
+        mock_pooling_metadata.return_value = None
+        seq_data = {
+            1: MagicMock(spec=SequenceData),
+            2: MagicMock(spec=SequenceData)
+        }
+        pooling_params = MagicMock(spec=PoolingParams)
+        seq_group = MagicMock(spec=SequenceGroupMetadata)
+        seq_group.seq_data = seq_data
+        seq_group.pooling_params = pooling_params
+
+        # Call the function
+        self.runner._prepare_pooling([seq_group], [10, 20])
+
+        # Verify results
+        mock_pooling_metadata.assert_called_once_with(seq_groups=[
+            ([1, 2], pooling_params)
+        ],
+                                                      seq_data=seq_data,
+                                                      prompt_lens=[10, 20])
+
+    @patch('vllm.model_executor.pooling_metadata.PoolingMetadata.__init__')
+    def test_prepare_pooling_empty_group(self, mock_pooling_metadata):
+        """Test case with empty sequence group"""
+        # Setup empty group
+        mock_pooling_metadata.return_value = None
+        empty_seq_data: dict[int, SequenceData] = {}
+        pooling_params = MagicMock(spec=PoolingParams)
+        empty_group = MagicMock(spec=SequenceGroupMetadata)
+        empty_group.seq_data = empty_seq_data
+        empty_group.pooling_params = pooling_params
+
+        # Call the function
+        self.runner._prepare_pooling([empty_group], [])
+
+        # Verify results
+        mock_pooling_metadata.assert_called_once_with(seq_groups=[
+            ([], pooling_params)
+        ],
+                                                      seq_data={},
+                                                      prompt_lens=[])
+
+    @patch('vllm.model_executor.pooling_metadata.PoolingMetadata.__init__')
+    def test_prepare_pooling_single_sequence(self, mock_pooling_metadata):
+        """Test case with single sequence in group"""
+        # Setup single sequence
+        mock_pooling_metadata.return_value = None
+        single_seq_data = {3: MagicMock(spec=SequenceData)}
+        pooling_params = MagicMock(spec=PoolingParams)
+        single_group = MagicMock(spec=SequenceGroupMetadata)
+        single_group.seq_data = single_seq_data
+        single_group.pooling_params = pooling_params
+
+        # Call the function
+        self.runner._prepare_pooling([single_group], [5])
+
+        # Verify results
+        mock_pooling_metadata.assert_called_once_with(seq_groups=[
+            ([3], pooling_params)
+        ],
+                                                      seq_data=single_seq_data,
+                                                      prompt_lens=[5])
+
+    @patch('vllm.model_executor.pooling_metadata.PoolingMetadata.__init__')
+    def test_prepare_pooling_multiple_groups(self, mock_pooling_metadata):
+        """Test case with multiple sequence groups"""
+        # Setup multiple groups
+        mock_pooling_metadata.return_value = None
+        seq_data1 = {1: MagicMock(spec=SequenceData)}
+        seq_data2 = {2: MagicMock(spec=SequenceData)}
+        params1 = MagicMock(spec=PoolingParams)
+        params2 = MagicMock(spec=PoolingParams)
+
+        group1 = MagicMock(spec=SequenceGroupMetadata)
+        group1.seq_data = seq_data1
+        group1.pooling_params = params1
+
+        group2 = MagicMock(spec=SequenceGroupMetadata)
+        group2.seq_data = seq_data2
+        group2.pooling_params = params2
+
+        # Call the function
+        self.runner._prepare_pooling([group1, group2], [10, 20])
+
+        # Verify results
+        mock_pooling_metadata.assert_called_once_with(seq_groups=[
+            ([1], params1), ([2], params2)
+        ],
+                                                      seq_data={
+                                                          **seq_data1,
+                                                          **seq_data2
+                                                      },
+                                                      prompt_lens=[10, 20])
+
+    @patch('vllm.model_executor.pooling_metadata.PoolingMetadata.__init__')
+    def test_prepare_pooling_empty_input(self, mock_pooling_metadata):
+        """Test case with empty input lists"""
+        # Call the function with empty inputs
+        mock_pooling_metadata.return_value = None
+        self.runner._prepare_pooling([], [])
+
+        # Verify results
+        mock_pooling_metadata.assert_called_once_with(seq_groups=[],
+                                                      seq_data={},
+                                                      prompt_lens=[])
+
+    @patch('vllm.forward_context.set_forward_context')
+    @patch('vllm.distributed.parallel_state._PP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator,
+                                          is_last_rank=True))
+    @patch('torch.npu.Event')
+    @patch.object(NPUPoolingModelRunner, 'set_active_loras')
+    @patch.object(NPUPoolingModelRunner, 'set_active_prompt_adapters')
+    def test_execute_model_normal_flow(self, mock_set_adapters,
+                                       mock_set_loras, mock_event, mock_pp,
+                                       mock_set_forward):
+        """Test normal execution path with all dependencies mocked"""
+
+        # Setup model input mock
+        mock_input = MagicMock()
+        mock_input.input_tokens = torch.tensor([1])
+        mock_input.input_positions = torch.tensor([0])
+        mock_input.multi_modal_kwargs = {}
+        self.runner.is_driver_worker = True
+        # Execute
+        self.runner.execute_model(model_input=mock_input,
+                                  kv_caches=[],
+                                  num_steps=1)
+
+        # Verify core calls
+        self.runner.model.pooler.assert_called_once()
+
+    @patch('vllm.forward_context.set_forward_context')
+    def test_execute_model_invalid_steps(self, mock_set_forward):
+        """Test ValueError when num_steps != 1"""
+        with self.assertRaises(ValueError):
+            self.runner.execute_model(model_input=MagicMock(),
+                                      kv_caches=[],
+                                      num_steps=2)
+        mock_set_forward.assert_not_called()
+
+    @patch('vllm.forward_context.set_forward_context')
+    @patch('vllm.distributed.parallel_state._PP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator,
+                                          is_last_rank=False))
+    @patch('torch.npu.Event')
+    def test_execute_model_perf_monitoring(self, mock_event, mock_pp,
+                                           mock_set_forward):
+        """Test performance monitoring with timing mocks"""
+        # Setup mocks
+
+        mock_event.return_value.elapsed_time.return_value = 15.0
+        self.runner.observability_config = MagicMock(
+            collect_model_forward_time=True)
+
+        # Execute
+        self.runner.execute_model(model_input=MagicMock(
+            input_tokens=torch.tensor([1]),
+            input_positions=torch.tensor([0]),
+            multi_modal_kwargs={}),
+                                  kv_caches=[],
+                                  num_steps=1)
+
+        # Verify timing calls
+        self.assertEqual(mock_event.call_count, 2)
+
+    @patch('vllm.forward_context.set_forward_context')
+    @patch.object(NPUPoolingModelRunner, 'set_active_loras')
+    @patch('vllm.distributed.parallel_state._PP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator,
+                                          is_last_rank=False))
+    def test_execute_model_lora_config(self, mock_pp, set_active_loras,
+                                       mock_set_forward):
+        """Test LoRA configuration handling"""
+        # Setup
+
+        self.runner.lora_config = True
+        mock_input = MagicMock()
+        mock_input.lora_requests = ["req1"]
+        mock_input.lora_mapping = {"map": 1}
+
+        # Execute
+        self.runner.execute_model(model_input=mock_input,
+                                  kv_caches=[],
+                                  num_steps=1)
+
+        # Verify LoRA call
+        set_active_loras.assert_called_once_with(["req1"], {"map": 1})
+
+    @patch('vllm.forward_context.set_forward_context')
+    @patch('vllm.distributed.parallel_state._PP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator,
+                                          is_last_rank=False))
+    def test_execute_model_not_last_rank(self, mock_pp, mock_set_forward):
+        """Test behavior when not the last pipeline rank"""
+        # Setup
+
+        # Execute
+        self.runner.execute_model(model_input=MagicMock(
+            input_tokens=torch.tensor([1]),
+            input_positions=torch.tensor([0]),
+            multi_modal_kwargs={}),
+                                  kv_caches=[],
+                                  num_steps=1)
+
+        # Verify pooler not called
+        self.runner.model.pooler.assert_not_called()
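
For reference, a minimal sketch of how the new module might be exercised locally, assuming the repository root is on `PYTHONPATH` and the `vllm` / `vllm_ascend` dependencies are importable; this runner script is not part of the patch and the invocation details are an assumption.

```python
# Hypothetical local runner for the tests added by this patch (not in the diff).
# Assumes tests/ut/worker/ is discoverable from the repository root.
import unittest

if __name__ == "__main__":
    # Discover only the new test module and run it with verbose output.
    suite = unittest.defaultTestLoader.discover(
        start_dir="tests/ut/worker",
        pattern="test_pooling_model_runner.py",
    )
    unittest.TextTestRunner(verbosity=2).run(suite)
```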