Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions vllm/v1/worker/gpu/attn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
UniformTypeKVCacheSpecs,
)
from vllm.v1.worker.gpu.model_states.interface import ModelSpecificAttnMetadata
from vllm.v1.worker.utils import AttentionGroup, bind_kv_cache
from vllm.v1.worker.utils import (
AttentionGroup,
bind_kv_cache,
select_common_block_size,
)


@dataclass(frozen=True)
Expand Down Expand Up @@ -93,10 +97,16 @@ def init_attn_backend(

groups = [group_map[key] for key in group_order]
for group in groups:
if isinstance(group.kv_cache_spec, AttentionSpec):
kernel_block_size = select_common_block_size(
group.kv_cache_spec.block_size, [group.backend]
)
else:
kernel_block_size = None
Comment thread
njhill marked this conversation as resolved.
Outdated
group.create_metadata_builders(
vllm_config=vllm_config,
device=device,
kernel_block_size=None,
kernel_block_size=kernel_block_size,
num_metadata_builders=1,
)
builder = group.get_metadata_builder(0)
Expand Down Expand Up @@ -164,9 +174,21 @@ def _reshape_kv_cache(
if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True
attn_backend = attn_backends[layer_name]

kernel_block_size = select_common_block_size(
kv_cache_spec.block_size, [attn_backend]
)
num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size
kernel_num_blocks = num_blocks * num_blocks_per_kv_block

if kv_cache_spec.storage_block_size != kv_cache_spec.block_size:
shape_block_size = kv_cache_spec.storage_block_size
else:
shape_block_size = kernel_block_size
Comment thread
njhill marked this conversation as resolved.
Outdated

kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks,
kv_cache_spec.storage_block_size,
kernel_num_blocks,
shape_block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype_str=cache_dtype,
Expand Down
33 changes: 31 additions & 2 deletions vllm/v1/worker/gpu/block_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable

import numpy as np
import torch

from vllm.triton_utils import tl, triton
Expand All @@ -17,11 +18,15 @@ def __init__(
max_num_batched_tokens: int,
max_num_blocks_per_group: list[int],
device: torch.device,
kernel_block_sizes: list[int] | None = None,
cp_size: int = 1,
cp_rank: int = 0,
cp_interleave: int = 1,
):
self.block_sizes = block_sizes
if kernel_block_sizes is None:
kernel_block_sizes = block_sizes
self.kernel_block_sizes = kernel_block_sizes
self.max_num_reqs = max_num_reqs
self.max_num_batched_tokens = max_num_batched_tokens
self.device = device
Expand All @@ -32,10 +37,15 @@ def __init__(

self.num_kv_cache_groups = len(self.block_sizes)
assert len(max_num_blocks_per_group) == self.num_kv_cache_groups

self.blocks_per_kv_block = [
bs // kbs for bs, kbs in zip(block_sizes, kernel_block_sizes)
]

# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.block_tables: list[StagedWriteTensor] = []
for i in range(self.num_kv_cache_groups):
max_num_blocks = max_num_blocks_per_group[i]
max_num_blocks = max_num_blocks_per_group[i] * self.blocks_per_kv_block[i]
block_table = StagedWriteTensor(
(self.max_num_reqs, max_num_blocks),
dtype=torch.int32,
Expand Down Expand Up @@ -84,7 +94,7 @@ def init_block_table_layout_tensors(self) -> None:
device=self.device,
)
self.block_sizes_tensor = torch.tensor(
self.block_sizes, dtype=torch.int32, device=self.device
self.kernel_block_sizes, dtype=torch.int32, device=self.device
)
self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)

Expand All @@ -97,6 +107,9 @@ def append_block_ids(
for i in range(self.num_kv_cache_groups):
start = self.num_blocks.np[i, req_index] if not overwrite else 0
block_ids = new_block_ids[i]
bpk = self.blocks_per_kv_block[i]
if bpk > 1:
block_ids = _expand_to_kernel_blocks(block_ids, bpk)
self.block_tables[i].stage_write(req_index, start, block_ids)
self.num_blocks.np[i, req_index] = start + len(block_ids)

Expand Down Expand Up @@ -173,6 +186,22 @@ def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
return self.slot_mappings[:, :num_tokens]


def _expand_to_kernel_blocks(
block_ids: list[int],
blocks_per_kv_block: int,
) -> list[int]:
"""Expand scheduler block IDs to kernel block IDs.

Each scheduler block of size B maps to `blocks_per_kv_block` kernel blocks
of size B/blocks_per_kv_block. E.g. scheduler block 3 with ratio 2
becomes kernel blocks [6, 7].
"""
arr = np.array(block_ids, dtype=np.int32)
arange = np.arange(blocks_per_kv_block, dtype=np.int32)
expanded = (arr.reshape(-1, 1) * blocks_per_kv_block + arange).reshape(-1)
return expanded.tolist()


@triton.jit(do_not_specialize=["num_reqs"])
def _gather_block_tables_kernel(
batch_idx_to_req_idx, # [batch_size]
Expand Down
14 changes: 10 additions & 4 deletions vllm/v1/worker/gpu/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.utils import prepare_kernel_block_sizes

logger = init_logger(__name__)

Expand Down Expand Up @@ -372,20 +373,25 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
) + spec.num_speculative_blocks
max_num_blocks_per_group.append(max_num_blocks)

self.attn_backends, self.attn_groups, attn_cg_support = init_attn_backend(
self.kv_cache_config, self.vllm_config, self.device
)

kernel_block_sizes = prepare_kernel_block_sizes(
kv_cache_config, self.attn_groups
)
Comment thread
njhill marked this conversation as resolved.
Outdated

self.block_tables = BlockTables(
block_sizes=block_sizes,
max_num_reqs=self.max_num_reqs,
max_num_batched_tokens=self.max_num_tokens,
max_num_blocks_per_group=max_num_blocks_per_group,
device=self.device,
kernel_block_sizes=kernel_block_sizes,
cp_size=self.dcp_size,
cp_rank=self.dcp_rank,
cp_interleave=self.cp_interleave,
)

self.attn_backends, self.attn_groups, attn_cg_support = init_attn_backend(
self.kv_cache_config, self.vllm_config, self.device
)
initialize_mamba_ssu_backend(
self.vllm_config.mamba_config, self.kv_cache_config
)
Expand Down
Loading