diff --git a/.gitignore b/.gitignore
index e49d1d6ba619..24d3cb4fad3b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -136,6 +136,7 @@ venv/
 ENV/
 env.bak/
 venv.bak/
+.envrc
 
 # Spyder project settings
 .spyderproject
diff --git a/tests/v1/worker/test_cpu_model_runner.py b/tests/v1/worker/test_cpu_model_runner.py
new file mode 100644
index 000000000000..61369fc68678
--- /dev/null
+++ b/tests/v1/worker/test_cpu_model_runner.py
@@ -0,0 +1,43 @@
+# SPDX-License-Identifier: Apache-2.0
+import pytest
+import torch
+
+from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
+                         ParallelConfig, SchedulerConfig, VllmConfig)
+from vllm.v1.worker.cpu_model_runner import CPUModelRunner
+
+DEVICE = torch.device("cpu")
+
+
+def get_vllm_config() -> VllmConfig:
+    """Build a minimal CPU VllmConfig around facebook/opt-125m."""
+    model_config = ModelConfig(
+        model="facebook/opt-125m",
+        tokenizer="facebook/opt-125m",
+    )
+    device_config = DeviceConfig(
+        device="cpu",
+    )
+    cache_config = CacheConfig()
+    scheduler_config = SchedulerConfig()
+    parallel_config = ParallelConfig()
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        cache_config=cache_config,
+        scheduler_config=scheduler_config,
+        parallel_config=parallel_config,
+        device_config=device_config,
+    )
+    return vllm_config
+
+
+@pytest.fixture
+def model_runner() -> CPUModelRunner:
+    """Provide a CPUModelRunner built from the minimal test config."""
+    return CPUModelRunner(get_vllm_config(), DEVICE)
+
+
+def test_execute_model(model_runner: CPUModelRunner):
+    """execute_model is a stub for now and must raise NotImplementedError."""
+    with pytest.raises(NotImplementedError):
+        model_runner.execute_model(None)
diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py
index 607cfc0ef69c..8d9e43e1ba35 100644
--- a/vllm/v1/worker/cpu_model_runner.py
+++ b/vllm/v1/worker/cpu_model_runner.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 from contextlib import contextmanager
-from typing import Any
+from typing import Any, Optional, Union
 
 import torch
 
@@ -8,6 +8,8 @@
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
+from vllm.sequence import IntermediateTensors
+from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 logger = init_logger(__name__)
 
@@ -69,6 +71,19 @@ def _init_device_properties(self) -> None:
     def _sync_device(self) -> None:
         pass
 
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        scheduler_output,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
+        """Placeholder for CPU model execution; not implemented yet.
+
+        Raises:
+            NotImplementedError: Always.
+        """
+        raise NotImplementedError
+
 
 @contextmanager
 def _set_global_compilation_settings():