Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ venv/
ENV/
env.bak/
venv.bak/
.envrc

# Spyder project settings
.spyderproject
Expand Down
36 changes: 36 additions & 0 deletions tests/v1/worker/test_cpu_model_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pytest
import torch

from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.config import (ModelConfig, SchedulerConfig, CacheConfig,
ParallelConfig, VllmConfig, DeviceConfig)

DEVICE = torch.device("cpu")

def get_vllm_config():
    """Build a minimal VllmConfig for exercising the CPU model runner.

    Uses the small facebook/opt-125m model with default cache, scheduler,
    and parallel settings, pinned to the CPU device.
    """
    model_config = ModelConfig(
        model="facebook/opt-125m",
        tokenizer="facebook/opt-125m",
    )
    device_config = DeviceConfig(device="cpu")
    return VllmConfig(
        model_config=model_config,
        cache_config=CacheConfig(),
        scheduler_config=SchedulerConfig(),
        parallel_config=ParallelConfig(),
        device_config=device_config,
    )

@pytest.fixture
def model_runner():
    """Provide a CPUModelRunner built from the minimal CPU config."""
    vllm_config = get_vllm_config()
    return CPUModelRunner(vllm_config, DEVICE)

def test_execute_model(model_runner: CPUModelRunner):
    """execute_model is a stub on the CPU runner and must raise
    NotImplementedError regardless of the scheduler output passed in."""
    with pytest.raises(NotImplementedError):
        model_runner.execute_model(None)
11 changes: 10 additions & 1 deletion vllm/v1/worker/cpu_model_runner.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
from contextlib import contextmanager
from typing import Any
from typing import Any, Optional, Union

import torch

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.sequence import IntermediateTensors
from vllm.v1.outputs import ModelRunnerOutput

logger = init_logger(__name__)

Expand Down Expand Up @@ -69,6 +71,13 @@ def _init_device_properties(self) -> None:
def _sync_device(self) -> None:
    """No-op override — presumably CPU execution needs no device sync,
    unlike the GPU base class; confirm against GPUModelRunner._sync_device."""
    pass

@torch.inference_mode()
def execute_model(
    self,
    scheduler_output,  # untyped here; presumably a v1 SchedulerOutput — confirm
    intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, IntermediateTensors]:
    """Not yet implemented for the CPU runner.

    Always raises:
        NotImplementedError: model execution on CPU is unsupported in this
            version of the runner.
    """
    raise NotImplementedError

@contextmanager
def _set_global_compilation_settings():
Expand Down
Loading