Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions tests/v1/worker/test_gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest
import torch

import vllm.v1.worker.gpu.warmup as gpu_warmup
import vllm.v1.worker.gpu_model_runner as gpu_model_runner_module
from vllm.config import (
AttentionConfig,
Expand Down Expand Up @@ -49,6 +50,105 @@
DEVICE_TYPE = current_platform.device_type


class _FakeV1WarmupBlockTable:
def __init__(self):
self.block_tables = [object(), object()]
self.calls = []

def add_row(self, block_ids, row_idx):
self.calls.append(("add_row", block_ids, row_idx))

def commit_block_table(self, num_reqs):
self.calls.append(("commit_block_table", num_reqs))

def compute_slot_mapping(self, num_reqs, query_start_loc, positions):
self.calls.append(
(
"compute_slot_mapping",
num_reqs,
query_start_loc.detach().cpu().tolist(),
positions.detach().cpu().tolist(),
query_start_loc.dtype,
positions.dtype,
)
)

def clear_row(self, row_idx):
self.calls.append(("clear_row", row_idx))


def _make_v1_slot_mapping_warmup_runner_stub(block_table=None, num_blocks=12):
if block_table is None:
block_table = _FakeV1WarmupBlockTable()
return SimpleNamespace(
device=torch.device("cpu"),
kv_cache_config=SimpleNamespace(num_blocks=num_blocks),
input_batch=SimpleNamespace(block_table=block_table),
)


def test_v1_warmup_runs_slot_mapping_and_clears_temporary_row(monkeypatch):
    """Happy path: warmup computes the slot mapping once, then removes the
    temporary row it registered and re-commits the block table."""
    monkeypatch.setattr(gpu_warmup.torch.accelerator, "synchronize", lambda: None)

    fake_table = _FakeV1WarmupBlockTable()
    runner = _make_v1_slot_mapping_warmup_runner_stub(fake_table)

    gpu_warmup.warmup_v1_slot_mapping_kernel(runner)

    expected = [
        ("add_row", ([1], [1]), 0),
        ("commit_block_table", 1),
        ("compute_slot_mapping", 1, [0, 1], [0], torch.int32, torch.int64),
        ("clear_row", 0),
        ("commit_block_table", 1),
    ]
    assert fake_table.calls == expected


def test_v1_warmup_clears_row_on_slot_mapping_error(monkeypatch):
    """The temporary warmup row must be cleared even when the slot-mapping
    kernel raises, and the original exception must propagate."""
    monkeypatch.setattr(gpu_warmup.torch.accelerator, "synchronize", lambda: None)

    fake_table = _FakeV1WarmupBlockTable()

    def boom(*_args):
        # Record that the call happened, then fail like a broken kernel.
        fake_table.calls.append(("compute_slot_mapping_error",))
        raise RuntimeError("test error")

    monkeypatch.setattr(fake_table, "compute_slot_mapping", boom)
    runner = _make_v1_slot_mapping_warmup_runner_stub(fake_table)

    with pytest.raises(RuntimeError, match="test error"):
        gpu_warmup.warmup_v1_slot_mapping_kernel(runner)

    assert fake_table.calls == [
        ("add_row", ([1], [1]), 0),
        ("commit_block_table", 1),
        ("compute_slot_mapping_error",),
        ("clear_row", 0),
        ("commit_block_table", 1),
    ]


def test_v1_warmup_skips_without_usable_kv_block(monkeypatch):
    """With only the null KV block available (num_blocks == 1), warmup must
    be a complete no-op: no block-table calls at all."""
    monkeypatch.setattr(gpu_warmup.torch.accelerator, "synchronize", lambda: None)

    fake_table = _FakeV1WarmupBlockTable()
    runner = _make_v1_slot_mapping_warmup_runner_stub(fake_table, num_blocks=1)

    gpu_warmup.warmup_v1_slot_mapping_kernel(runner)

    assert fake_table.calls == []


def initialize_kv_cache(runner: GPUModelRunner):
"""
Only perform necessary steps in GPUModelRunner.initialize_kv_cache()
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/worker/block_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def __getitem__(self, idx: int) -> "BlockTable":
return self.block_tables[idx]


@triton.jit
@triton.jit(do_not_specialize=["num_tokens"])
def _compute_slot_mapping_kernel(
num_tokens,
max_num_tokens,
Expand Down
32 changes: 32 additions & 0 deletions vllm/v1/worker/gpu/warmup.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,35 @@ def _alloc_blocks(num_blocks: int) -> list[int]:
worker_execute_model(cleanup_output)
model_runner.kv_connector.set_disabled(False)
torch.accelerator.synchronize()


@torch.inference_mode()
def warmup_v1_slot_mapping_kernel(model_runner: Any) -> None:
    """Warm up V1 slot mapping without running model-specific forward paths.

    V1 request input preparation calls `BlockTable.compute_slot_mapping()`.
    The legacy `_dummy_run()` path does not exercise this kernel, and synthetic
    `execute_model()` warmups are not model-agnostic. Compile the slot mapping
    kernel directly before the JIT monitor is enabled.
    """
    table = model_runner.input_batch.block_table

    # Skip when there are no block tables to exercise, or when the only
    # KV block is block 0 (the null block) — block 1 must exist to be used.
    if not table.block_tables or model_runner.kv_cache_config.num_blocks <= 1:
        return

    dev = model_runner.device

    # Register a temporary single-token request at row 0 that owns KV
    # block 1 in every cache group, then commit so the kernel sees it.
    table.add_row(tuple([1] for _ in table.block_tables), 0)
    table.commit_block_table(1)

    token_starts = torch.tensor([0, 1], dtype=torch.int32, device=dev)
    token_positions = torch.zeros(1, dtype=torch.int64, device=dev)

    try:
        table.compute_slot_mapping(1, token_starts, token_positions)
        torch.accelerator.synchronize()
    finally:
        # Always remove the temporary row — even if compilation failed —
        # so real requests start from a clean block table.
        table.clear_row(0)
        table.commit_block_table(1)
9 changes: 7 additions & 2 deletions vllm/v1/worker/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
from vllm.v1.worker.workspace import init_workspace_manager

from ...model_executor.model_loader import TensorizerLoader
from .gpu.warmup import warmup_kernels
from .gpu.warmup import warmup_kernels, warmup_v1_slot_mapping_kernel
from .utils import request_memory

logger = init_logger(__name__)
Expand Down Expand Up @@ -687,7 +687,12 @@ def compile_or_warm_up_model(self) -> CompilationTimes:
if self.use_v2_model_runner:
# V2: Run full execute_model + sample_tokens to JIT compile triton kernels.
warmup_kernels(self.model_runner, self.execute_model, self.sample_tokens)
elif get_pp_group().is_last_rank:
else:
# V1: Compile generic input preparation kernels that legacy
# _dummy_run does not cover before the JIT monitor is enabled.
warmup_v1_slot_mapping_kernel(self.model_runner)

if not self.use_v2_model_runner and get_pp_group().is_last_rank:
# V1: Warm up sampler and preallocate memory buffer for logits and other
# sampling related tensors of max possible shape to avoid memory
# fragmentation issue.
Expand Down
Loading