Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
SupportsHMA,
)
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
Expand Down Expand Up @@ -76,7 +77,7 @@ def __repr__(self) -> str:
return f"<MooncakeStoreKVEvents events={self.get_all_events()}>"


class MooncakeStoreConnector(KVConnectorBase_V1):
class MooncakeStoreConnector(KVConnectorBase_V1, SupportsHMA):
"""KV connector using MooncakeDistributedStore as shared KV pool."""

@property
Expand Down Expand Up @@ -106,9 +107,13 @@ def __init__(
self.connector_worker: MooncakeStoreWorker | None = None

if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler = MooncakeStoreScheduler(vllm_config)
self.connector_scheduler = MooncakeStoreScheduler(
vllm_config, self._kv_cache_config
)
else:
self.connector_worker = MooncakeStoreWorker(vllm_config)
self.connector_worker = MooncakeStoreWorker(
vllm_config, self._kv_cache_config
)

# ============================================================
# Scheduler-side methods
Expand Down Expand Up @@ -150,6 +155,16 @@ def request_finished(
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)

def request_finished_all_groups(
self,
request: Request,
block_ids: tuple[list[int], ...],
) -> tuple[bool, dict[str, Any] | None]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished_all_groups(
request, block_ids
)

def update_connector_output(self, connector_output: KVConnectorOutput):
kv_cache_events = connector_output.kv_cache_events
if not kv_cache_events or not isinstance(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class KeyMetadata:
pcp_rank: int
dcp_rank: int
pp_rank: int
group_id: int = 0


@dataclass(order=True)
Expand All @@ -46,6 +47,7 @@ def __hash__(self):
self.key_metadata.pcp_rank,
self.key_metadata.dcp_rank,
self.key_metadata.pp_rank,
self.key_metadata.group_id,
self.chunk_hash,
)
)
Expand All @@ -54,6 +56,7 @@ def to_string(self) -> str:
return (
f"{self.key_metadata.model_name}"
f"@tp_rank:{self.key_metadata.tp_rank}"
f"@group_id:{self.key_metadata.group_id}"
f"@pcp{self.key_metadata.pcp_rank}"
f"@dcp{self.key_metadata.dcp_rank}"
f"@pp_rank:{self.key_metadata.pp_rank}"
Expand Down Expand Up @@ -83,15 +86,26 @@ def set_block_len(self, block_len: list[int]):
self.block_len = block_len

def prepare_value(
self, start: int, end: int, block_ids: list[int]
self,
start: int,
end: int,
block_ids: list[int] | tuple[list[int], ...],
group_idx: int = 0,
) -> tuple[list[int], list[int], int]:
"""Compute memory addresses and sizes for a token range.

Args:
block_ids: Either a single list of block IDs (non-HMA) or a tuple
of per-group block IDs (HMA).
group_idx: Which group's block IDs to use when block_ids is a tuple.

Returns:
(addr_list, size_list, block_id)
"""
addr_list = []
size_list = []
if isinstance(block_ids, tuple):
block_ids = block_ids[group_idx]
block_id = block_ids[start // self.block_size]
length = len(self.block_len)
for index, base_addr in enumerate(self.kv_caches_base_addr):
Expand Down Expand Up @@ -154,7 +168,7 @@ class RequestTracker:

req_id: str
token_len: int
allocated_block_ids: list[int]
allocated_block_ids: tuple[list[int], ...]
num_saved_tokens: int = 0
token_ids: list[int] | None = None
# Snapshot of the prefill range length at tracker creation time.
Expand All @@ -167,14 +181,20 @@ def update(
new_block_ids: tuple[list[int], ...] | list[int],
) -> None:
if len(new_block_ids) == 0:
new_block_ids = []
elif isinstance(new_block_ids, tuple):
new_block_ids = new_block_ids[0]
elif isinstance(new_block_ids, list):
pass
else:
raise ValueError(f"Unsupported new_block_ids type {type(new_block_ids)}")
self.allocated_block_ids.extend(new_block_ids)
return
if isinstance(new_block_ids, list):
new_block_ids = (new_block_ids,)
if isinstance(self.allocated_block_ids, list):
self.allocated_block_ids = (self.allocated_block_ids,)
if len(self.allocated_block_ids) != len(new_block_ids):
raise ValueError(
f"Block ID length mismatch: "
f"{len(self.allocated_block_ids)} vs {len(new_block_ids)}"
)
self.allocated_block_ids = tuple(
list(old) + list(new)
for old, new in zip(self.allocated_block_ids, new_block_ids)
)


@dataclass
Expand All @@ -183,7 +203,7 @@ class ReqMeta:

req_id: str
token_len_chunk: int
block_ids: list[int]
block_ids: tuple[list[int], ...]
block_hashes: list[BlockHash]

can_save: bool | None = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,13 @@
LookupKeyClient,
)
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
SlidingWindowSpec,
)
from vllm.v1.request import Request

logger = init_logger(__name__)
Expand All @@ -45,7 +50,11 @@ def _new_req_prefill_tokens(request: NewRequestData) -> list[int]:
class MooncakeStoreScheduler:
"""Scheduler-side component for MooncakeStoreConnector."""

def __init__(self, vllm_config: VllmConfig):
def __init__(
self,
vllm_config: VllmConfig,
kv_cache_config: "KVCacheConfig | None" = None,
):
assert vllm_config.kv_transfer_config is not None
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
Expand All @@ -68,11 +77,34 @@ def __init__(self, vllm_config: VllmConfig):
)
)

# HMA detection and sliding-window block counts per group.
if kv_cache_config is not None:
self._is_hma_required = (
not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager
and any(
not isinstance(g.kv_cache_spec, FullAttentionSpec)
for g in kv_cache_config.kv_cache_groups
)
)
sw_sizes_tokens: list[tuple[int, int]] = [
(g.kv_cache_spec.sliding_window, g.kv_cache_spec.block_size)
if isinstance(g.kv_cache_spec, SlidingWindowSpec)
else (0, self._block_size)
for g in kv_cache_config.kv_cache_groups
]
self.blocks_per_sw = [
cdiv(n_tokens, block_size) + 1 if n_tokens else 0
for n_tokens, block_size in sw_sizes_tokens
]
else:
self._is_hma_required = False
self.blocks_per_sw = []

# Per-request state
self.load_specs: dict[str, LoadSpec] = {} # to be loaded
self._request_trackers: dict[str, RequestTracker] = {} # scheduled new requests
self._preempted_req_ids: set[str] = set() # preempted requests
self._unfinished_requests: dict[str, tuple[Request, list[int]]] = {}
self._unfinished_requests: dict[str, tuple[Request, tuple[list[int], ...]]] = {}
self._unfinished_request_ids: set[str] = set()

def get_num_new_matched_tokens(
Expand Down Expand Up @@ -126,9 +158,9 @@ def update_state_after_alloc(
num_external_tokens: int,
):
"""Update state after block allocation."""
local_block_ids: list[int] = []
local_block_ids: tuple[list[int], ...] = ()
if num_external_tokens > 0:
local_block_ids = blocks.get_block_ids()[0]
local_block_ids = blocks.get_block_ids()

self._unfinished_requests[request.request_id] = (request, local_block_ids)
self._unfinished_request_ids.add(request.request_id)
Expand Down Expand Up @@ -190,10 +222,9 @@ def build_connector_meta(
request_real = request_tuple[0] # type: ignore[index]

if not isinstance(request.block_ids[0], list):
unfolded_block_ids = request.block_ids.copy()
unfolded_block_ids = (request.block_ids.copy(),)
else:
# TODO: support HMA
unfolded_block_ids = request.block_ids[0].copy()
unfolded_block_ids = tuple(list(g) for g in request.block_ids)

prefill_tokens = _new_req_prefill_tokens(request)
request_tracker = RequestTracker(
Expand Down Expand Up @@ -237,9 +268,9 @@ def build_connector_meta(
if req_id in self._preempted_req_ids:
# Resumed after preemption
if isinstance(new_block_ids, tuple):
block_ids_list = new_block_ids[0].copy()
block_ids_list = tuple(list(g) for g in new_block_ids)
else:
block_ids_list = new_block_ids.copy()
block_ids_list = (new_block_ids.copy(),)
self._preempted_req_ids.discard(req_id)
load_spec = self.load_specs.pop(req_id, None)
request_tuple = self._unfinished_requests.get(req_id)
Expand Down Expand Up @@ -358,10 +389,35 @@ def build_connector_meta(

return meta

def get_sw_clipped_blocks(
self,
block_ids: tuple[list[int], ...] | list[int],
) -> tuple[list[int], ...]:
"""Clip per-group block IDs to sliding window size.

For groups with SlidingWindowAttention, only the most recent
``blocks_per_sw`` blocks are retained. Non-SWA groups keep all
blocks unchanged.
"""
if isinstance(block_ids, list):
block_ids = (block_ids,)
if len(block_ids) == 0 or not self._is_hma_required:
return block_ids
assert len(block_ids) == len(self.blocks_per_sw), (
f"Block ID group count mismatch: "
f"{len(block_ids)} vs {len(self.blocks_per_sw)}"
)
return tuple([
blocks[-self.blocks_per_sw[i]:]
if self.blocks_per_sw[i] > 0
else blocks
for i, blocks in enumerate(block_ids)
])

def request_finished(
self,
request: Request,
block_ids: list[int],
block_ids: list[int] | tuple[list[int], ...],
) -> tuple[bool, dict[str, Any] | None]:
"""Determine whether to delay freeing blocks for async save."""
if self.kv_role == "kv_consumer":
Expand All @@ -370,11 +426,23 @@ def request_finished(
assert tracker is not None
if tracker.num_saved_tokens <= 0:
return False, None
delay_free_blocks = len(block_ids) > 0
if isinstance(block_ids, list):
block_ids = (block_ids,)
block_ids = self.get_sw_clipped_blocks(block_ids)
delay_free_blocks = any(len(g) > 0 for g in block_ids)
if delay_free_blocks:
total_blocks = sum(len(g) for g in block_ids)
logger.debug(
"Delaying free of %d blocks for request %s",
len(block_ids),
total_blocks,
request.request_id,
)
return delay_free_blocks, None

def request_finished_all_groups(
self,
request: Request,
block_ids: tuple[list[int], ...],
) -> tuple[bool, dict[str, Any] | None]:
"""Determine whether to delay freeing blocks for async save (HMA)."""
return self.request_finished(request, block_ids)
Loading
Loading