Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,6 @@ def test_is_default_v2_model_runner_model(model_config, expected):
assert VllmConfig._is_default_v2_model_runner_model(config) is expected


def test_use_v2_model_runner_defaults_to_v1_when_kv_connector_present():
config = SimpleNamespace(kv_transfer_config=object())
with patch.object(envs, "VLLM_USE_V2_MODEL_RUNNER", None):
result = VllmConfig.use_v2_model_runner.fget(config)
assert result is False


@pytest.mark.skip_global_cleanup
def test_with_hf_config_populates_missing_architectures_from_causal_lm_mapping(
monkeypatch,
Expand Down
4 changes: 0 additions & 4 deletions vllm/config/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,10 +494,6 @@ def use_v2_model_runner(self) -> bool:
if use_v2_model_runner is not None:
return use_v2_model_runner

# KVCache layout changes are breaking, let's stick with v1 for now (see #42846)
if self.kv_transfer_config is not None:
return False

if not self._is_default_v2_model_runner_model():
return False

Expand Down
61 changes: 49 additions & 12 deletions vllm/v1/worker/gpu/attn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
UniformTypeKVCacheSpecs,
)
from vllm.v1.worker.gpu.model_states.interface import ModelSpecificAttnMetadata
from vllm.v1.worker.utils import AttentionGroup, bind_kv_cache
from vllm.v1.worker.utils import (
AttentionGroup,
bind_kv_cache,
prepare_kernel_block_sizes,
)


@dataclass(frozen=True)
Expand Down Expand Up @@ -51,13 +55,12 @@ def init_attn_backend(
dict[str, type[AttentionBackend]],
list[list[AttentionGroup]],
AttentionCGSupportInfo,
list[int],
]:
attn_backends: dict[str, type[AttentionBackend]] = {}
attn_groups: list[list[AttentionGroup]] = []
attn_backend_workspace: torch.Tensor | None = None
# Find minimum cudagraph support across all attention backends
min_cg_support = AttentionCGSupport.ALWAYS
min_cg_attn_backend = None

# Phase 1: discover attention groups for each kv cache group.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups
):
Expand Down Expand Up @@ -91,12 +94,26 @@ def init_attn_backend(
else:
group_map[key].layer_names.append(layer_name)

groups = [group_map[key] for key in group_order]
attn_groups.append([group_map[key] for key in group_order])

# Phase 2: pick a kernel block size per kv cache group that is supported
# by all backends within that group.
kernel_block_sizes = prepare_kernel_block_sizes(kv_cache_config, attn_groups)

# Phase 3: create metadata builders and determine cudagraph support.
attn_backend_workspace: torch.Tensor | None = None
min_cg_support = AttentionCGSupport.ALWAYS
min_cg_attn_backend = None
for kv_cache_group_id, groups in enumerate(attn_groups):
kv_cache_group_spec = kv_cache_config.kv_cache_groups[kv_cache_group_id]
kernel_block_size = None
if kv_cache_group_id < len(kernel_block_sizes):
kernel_block_size = kernel_block_sizes[kv_cache_group_id]
for group in groups:
group.create_metadata_builders(
vllm_config=vllm_config,
device=device,
kernel_block_size=None,
kernel_block_size=kernel_block_size,
num_metadata_builders=1,
)
builder = group.get_metadata_builder(0)
Expand All @@ -113,8 +130,7 @@ def init_attn_backend(
)
if cg_support.value < min_cg_support.value:
min_cg_support = cg_support
min_cg_attn_backend = attn_backend.__name__
attn_groups.append(groups)
min_cg_attn_backend = group.backend.__name__

return (
attn_backends,
Expand All @@ -123,6 +139,7 @@ def init_attn_backend(
min_cg_support=min_cg_support,
min_cg_attn_backend=min_cg_attn_backend,
),
kernel_block_sizes,
)


Expand All @@ -148,10 +165,13 @@ def _reshape_kv_cache(
kv_cache_raw_tensors: dict[str, torch.Tensor],
attn_backends: dict[str, type[AttentionBackend]],
cache_dtype: str,
kernel_block_sizes: list[int],
) -> dict[str, Any]:
kv_caches: dict[str, Any] = {}
has_attn, has_mamba = False, False
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
for kv_cache_group_id, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups
):
for layer_name in kv_cache_group_spec.layer_names:
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
Expand All @@ -164,9 +184,21 @@ def _reshape_kv_cache(
if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True
attn_backend = attn_backends[layer_name]

if kv_cache_group_id < len(kernel_block_sizes):
kernel_block_size = kernel_block_sizes[kv_cache_group_id]
num_blocks *= kv_cache_spec.block_size // kernel_block_size
else:
kernel_block_size = kv_cache_spec.block_size

if kv_cache_spec.storage_block_size != kv_cache_spec.block_size:
shape_block_size = kv_cache_spec.storage_block_size
else:
shape_block_size = kernel_block_size

kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks,
kv_cache_spec.storage_block_size,
shape_block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype_str=cache_dtype,
Expand Down Expand Up @@ -275,10 +307,15 @@ def init_kv_cache(
attn_backends: dict[str, type[AttentionBackend]],
device: torch.device,
cache_dtype: str,
kernel_block_sizes: list[int],
) -> dict[str, Any]:
kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device)
kv_caches = _reshape_kv_cache(
kv_cache_config, kv_cache_raw_tensors, attn_backends, cache_dtype
kv_cache_config,
kv_cache_raw_tensors,
attn_backends,
cache_dtype,
kernel_block_sizes,
)
bind_kv_cache(kv_caches, forward_context, runner_kv_caches)
return kv_caches
Expand Down
14 changes: 12 additions & 2 deletions vllm/v1/worker/gpu/block_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ def __init__(
max_num_batched_tokens: int,
max_num_blocks_per_group: list[int],
device: torch.device,
kernel_block_sizes: list[int],
cp_size: int = 1,
cp_rank: int = 0,
cp_interleave: int = 1,
):
self.block_sizes = block_sizes
self.kernel_block_sizes = kernel_block_sizes
self.max_num_reqs = max_num_reqs
self.max_num_batched_tokens = max_num_batched_tokens
self.device = device
Expand All @@ -32,10 +34,15 @@ def __init__(

self.num_kv_cache_groups = len(self.block_sizes)
assert len(max_num_blocks_per_group) == self.num_kv_cache_groups

self.blocks_per_kv_block = [
bs // kbs for bs, kbs in zip(block_sizes, kernel_block_sizes)
]

# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.block_tables: list[StagedWriteTensor] = []
for i in range(self.num_kv_cache_groups):
max_num_blocks = max_num_blocks_per_group[i]
max_num_blocks = max_num_blocks_per_group[i] * self.blocks_per_kv_block[i]
block_table = StagedWriteTensor(
(self.max_num_reqs, max_num_blocks),
dtype=torch.int32,
Expand Down Expand Up @@ -84,7 +91,7 @@ def init_block_table_layout_tensors(self) -> None:
device=self.device,
)
self.block_sizes_tensor = torch.tensor(
self.block_sizes, dtype=torch.int32, device=self.device
self.kernel_block_sizes, dtype=torch.int32, device=self.device
)
self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)

Expand All @@ -97,6 +104,9 @@ def append_block_ids(
for i in range(self.num_kv_cache_groups):
start = self.num_blocks.np[i, req_index] if not overwrite else 0
block_ids = new_block_ids[i]
bpk = self.blocks_per_kv_block[i]
if bpk > 1:
block_ids = [b * bpk + k for b in block_ids for k in range(bpk)]
self.block_tables[i].stage_write(req_index, start, block_ids)
self.num_blocks.np[i, req_index] = start + len(block_ids)

Expand Down
10 changes: 6 additions & 4 deletions vllm/v1/worker/gpu/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,20 +392,21 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
) + spec.num_speculative_blocks
max_num_blocks_per_group.append(max_num_blocks)

(self.attn_backends, self.attn_groups, attn_cg_support, kernel_block_sizes) = (
init_attn_backend(self.kv_cache_config, self.vllm_config, self.device)
)

self.block_tables = BlockTables(
block_sizes=block_sizes,
max_num_reqs=self.max_num_reqs,
max_num_batched_tokens=self.max_num_tokens,
max_num_blocks_per_group=max_num_blocks_per_group,
device=self.device,
kernel_block_sizes=kernel_block_sizes,
cp_size=self.dcp_size,
cp_rank=self.dcp_rank,
cp_interleave=self.cp_interleave,
)

self.attn_backends, self.attn_groups, attn_cg_support = init_attn_backend(
self.kv_cache_config, self.vllm_config, self.device
)
initialize_mamba_ssu_backend(
self.vllm_config.mamba_config, self.kv_cache_config
)
Expand Down Expand Up @@ -443,6 +444,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
self.attn_backends,
self.device,
self.cache_config.cache_dtype,
kernel_block_sizes,
)
self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)

Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def set_attn(
) -> None:
self.model_state = model_state
self.kv_cache_config = kv_cache_config
_, self.attn_groups, _ = init_attn_backend(
_, self.attn_groups, _, _ = init_attn_backend(
kv_cache_config,
self.vllm_config,
self.device,
Expand Down
Loading