diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/connector.py index 184501a96a5c..f758cce06bba 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/connector.py @@ -26,6 +26,7 @@ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, + SupportsHMA, ) from vllm.forward_context import ForwardContext from vllm.logger import init_logger @@ -76,7 +77,7 @@ def __repr__(self) -> str: return f"" -class MooncakeStoreConnector(KVConnectorBase_V1): +class MooncakeStoreConnector(KVConnectorBase_V1, SupportsHMA): """KV connector using MooncakeDistributedStore as shared KV pool.""" @property @@ -106,9 +107,13 @@ def __init__( self.connector_worker: MooncakeStoreWorker | None = None if role == KVConnectorRole.SCHEDULER: - self.connector_scheduler = MooncakeStoreScheduler(vllm_config) + self.connector_scheduler = MooncakeStoreScheduler( + vllm_config, self._kv_cache_config + ) else: - self.connector_worker = MooncakeStoreWorker(vllm_config) + self.connector_worker = MooncakeStoreWorker( + vllm_config, self._kv_cache_config + ) # ============================================================ # Scheduler-side methods @@ -150,6 +155,16 @@ def request_finished( assert self.connector_scheduler is not None return self.connector_scheduler.request_finished(request, block_ids) + def request_finished_all_groups( + self, + request: Request, + block_ids: tuple[list[int], ...], + ) -> tuple[bool, dict[str, Any] | None]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished_all_groups( + request, block_ids + ) + def update_connector_output(self, connector_output: KVConnectorOutput): kv_cache_events = connector_output.kv_cache_events if not kv_cache_events or not isinstance( diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/data.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/data.py index acdeedceeaac..a1a275e7b26f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/data.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/data.py @@ -29,6 +29,7 @@ class KeyMetadata: pcp_rank: int dcp_rank: int pp_rank: int + group_id: int = 0 @dataclass(order=True) @@ -46,6 +47,7 @@ def __hash__(self): self.key_metadata.pcp_rank, self.key_metadata.dcp_rank, self.key_metadata.pp_rank, + self.key_metadata.group_id, self.chunk_hash, ) ) @@ -54,6 +56,7 @@ def to_string(self) -> str: return ( f"{self.key_metadata.model_name}" f"@tp_rank:{self.key_metadata.tp_rank}" + f"@group_id:{self.key_metadata.group_id}" f"@pcp{self.key_metadata.pcp_rank}" f"@dcp{self.key_metadata.dcp_rank}" f"@pp_rank:{self.key_metadata.pp_rank}" @@ -83,15 +86,26 @@ def set_block_len(self, block_len: list[int]): self.block_len = block_len def prepare_value( - self, start: int, end: int, block_ids: list[int] + self, + start: int, + end: int, + block_ids: list[int] | tuple[list[int], ...], + group_idx: int = 0, ) -> tuple[list[int], list[int], int]: """Compute memory addresses and sizes for a token range. + Args: + block_ids: Either a single list of block IDs (non-HMA) or a tuple + of per-group block IDs (HMA). + group_idx: Which group's block IDs to use when block_ids is a tuple. + Returns: (addr_list, size_list, block_id) """ addr_list = [] size_list = [] + if isinstance(block_ids, tuple): + block_ids = block_ids[group_idx] block_id = block_ids[start // self.block_size] length = len(self.block_len) for index, base_addr in enumerate(self.kv_caches_base_addr): @@ -154,7 +168,7 @@ class RequestTracker: req_id: str token_len: int - allocated_block_ids: list[int] + allocated_block_ids: tuple[list[int], ...] num_saved_tokens: int = 0 token_ids: list[int] | None = None # Snapshot of the prefill range length at tracker creation time. @@ -167,14 +181,20 @@ def update( new_block_ids: tuple[list[int], ...] | list[int], ) -> None: if len(new_block_ids) == 0: - new_block_ids = [] - elif isinstance(new_block_ids, tuple): - new_block_ids = new_block_ids[0] - elif isinstance(new_block_ids, list): - pass - else: - raise ValueError(f"Unsupported new_block_ids type {type(new_block_ids)}") - self.allocated_block_ids.extend(new_block_ids) + return + if isinstance(new_block_ids, list): + new_block_ids = (new_block_ids,) + if isinstance(self.allocated_block_ids, list): + self.allocated_block_ids = (self.allocated_block_ids,) + if len(self.allocated_block_ids) != len(new_block_ids): + raise ValueError( + f"Block ID length mismatch: " + f"{len(self.allocated_block_ids)} vs {len(new_block_ids)}" + ) + self.allocated_block_ids = tuple( + list(old) + list(new) + for old, new in zip(self.allocated_block_ids, new_block_ids) + ) @dataclass @@ -183,7 +203,7 @@ class ReqMeta: req_id: str token_len_chunk: int - block_ids: list[int] + block_ids: tuple[list[int], ...] block_hashes: list[BlockHash] can_save: bool | None = None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/scheduler.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/scheduler.py index 5ce3278fee8e..fe2df8754d9f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/scheduler.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/scheduler.py @@ -21,8 +21,13 @@ LookupKeyClient, ) from vllm.logger import init_logger +from vllm.utils.math_utils import cdiv from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + SlidingWindowSpec, +) from vllm.v1.request import Request logger = init_logger(__name__) @@ -45,7 +50,11 @@ def _new_req_prefill_tokens(request: NewRequestData) -> list[int]: class MooncakeStoreScheduler: """Scheduler-side component for MooncakeStoreConnector.""" - def __init__(self, vllm_config: VllmConfig): + def __init__( + self, + vllm_config: VllmConfig, + kv_cache_config: "KVCacheConfig | None" = None, + ): assert vllm_config.kv_transfer_config is not None self.kv_role = vllm_config.kv_transfer_config.kv_role self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get( @@ -68,11 +77,34 @@ def __init__(self, vllm_config: VllmConfig): ) ) + # HMA detection and sliding-window block counts per group. + if kv_cache_config is not None: + self._is_hma_required = ( + not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager + and any( + not isinstance(g.kv_cache_spec, FullAttentionSpec) + for g in kv_cache_config.kv_cache_groups + ) + ) + sw_sizes_tokens: list[tuple[int, int]] = [ + (g.kv_cache_spec.sliding_window, g.kv_cache_spec.block_size) + if isinstance(g.kv_cache_spec, SlidingWindowSpec) + else (0, self._block_size) + for g in kv_cache_config.kv_cache_groups + ] + self.blocks_per_sw = [ + cdiv(n_tokens, block_size) + 1 if n_tokens else 0 + for n_tokens, block_size in sw_sizes_tokens + ] + else: + self._is_hma_required = False + self.blocks_per_sw = [] + # Per-request state self.load_specs: dict[str, LoadSpec] = {} # to be loaded self._request_trackers: dict[str, RequestTracker] = {} # scheduled new requests self._preempted_req_ids: set[str] = set() # preempted requests - self._unfinished_requests: dict[str, tuple[Request, list[int]]] = {} + self._unfinished_requests: dict[str, tuple[Request, tuple[list[int], ...]]] = {} self._unfinished_request_ids: set[str] = set() def get_num_new_matched_tokens( @@ -126,9 +158,9 @@ def update_state_after_alloc( num_external_tokens: int, ): """Update state after block allocation.""" - local_block_ids: list[int] = [] + local_block_ids: tuple[list[int], ...] = () if num_external_tokens > 0: - local_block_ids = blocks.get_block_ids()[0] + local_block_ids = blocks.get_block_ids() self._unfinished_requests[request.request_id] = (request, local_block_ids) self._unfinished_request_ids.add(request.request_id) @@ -190,10 +222,9 @@ def build_connector_meta( request_real = request_tuple[0] # type: ignore[index] if not isinstance(request.block_ids[0], list): - unfolded_block_ids = request.block_ids.copy() + unfolded_block_ids = (request.block_ids.copy(),) else: - # TODO: support HMA - unfolded_block_ids = request.block_ids[0].copy() + unfolded_block_ids = tuple(list(g) for g in request.block_ids) prefill_tokens = _new_req_prefill_tokens(request) request_tracker = RequestTracker( @@ -237,9 +268,9 @@ def build_connector_meta( if req_id in self._preempted_req_ids: # Resumed after preemption if isinstance(new_block_ids, tuple): - block_ids_list = new_block_ids[0].copy() + block_ids_list = tuple(list(g) for g in new_block_ids) else: - block_ids_list = new_block_ids.copy() + block_ids_list = (new_block_ids.copy(),) self._preempted_req_ids.discard(req_id) load_spec = self.load_specs.pop(req_id, None) request_tuple = self._unfinished_requests.get(req_id) @@ -358,10 +389,35 @@ def build_connector_meta( return meta + def get_sw_clipped_blocks( + self, + block_ids: tuple[list[int], ...] | list[int], + ) -> tuple[list[int], ...]: + """Clip per-group block IDs to sliding window size. + + For groups with SlidingWindowAttention, only the most recent + ``blocks_per_sw`` blocks are retained. Non-SWA groups keep all + blocks unchanged. + """ + if isinstance(block_ids, list): + block_ids = (block_ids,) + if len(block_ids) == 0 or not self._is_hma_required: + return block_ids + assert len(block_ids) == len(self.blocks_per_sw), ( + f"Block ID group count mismatch: " + f"{len(block_ids)} vs {len(self.blocks_per_sw)}" + ) + return tuple([ + blocks[-self.blocks_per_sw[i]:] + if self.blocks_per_sw[i] > 0 + else blocks + for i, blocks in enumerate(block_ids) + ]) + def request_finished( self, request: Request, - block_ids: list[int], + block_ids: list[int] | tuple[list[int], ...], ) -> tuple[bool, dict[str, Any] | None]: """Determine whether to delay freeing blocks for async save.""" if self.kv_role == "kv_consumer": @@ -370,11 +426,23 @@ def request_finished( assert tracker is not None if tracker.num_saved_tokens <= 0: return False, None - delay_free_blocks = len(block_ids) > 0 + if isinstance(block_ids, list): + block_ids = (block_ids,) + block_ids = self.get_sw_clipped_blocks(block_ids) + delay_free_blocks = any(len(g) > 0 for g in block_ids) if delay_free_blocks: + total_blocks = sum(len(g) for g in block_ids) logger.debug( "Delaying free of %d blocks for request %s", - len(block_ids), + total_blocks, request.request_id, ) return delay_free_blocks, None + + def request_finished_all_groups( + self, + request: Request, + block_ids: tuple[list[int], ...], + ) -> tuple[bool, dict[str, Any] | None]: + """Determine whether to delay freeing blocks for async save (HMA).""" + return self.request_finished(request, block_ids) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/worker.py index 487542c59175..49dc8e35d83c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/store/worker.py @@ -47,6 +47,17 @@ logger = init_logger(__name__) + +def _unwrap_block_ids( + block_ids: list[int] | tuple[list[int], ...], + group_idx: int = 0, +) -> list[int]: + """Extract a single group's block IDs from HMA or non-HMA format.""" + if isinstance(block_ids, tuple): + return block_ids[group_idx] + return block_ids + + DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB DEFAULT_LOCAL_BUFFER_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB MOONCAKE_NO_AVAILABLE_HANDLE = -200 @@ -200,23 +211,33 @@ class KVCacheStoreSendingThread(KVTransferThread): def __init__( self, store: Any, - token_database: ChunkedTokenDatabase, + group_token_databases: list[ChunkedTokenDatabase] + | ChunkedTokenDatabase, block_size: int, tp_rank: int, - put_step: int, + group_put_steps: list[int], + group_sw_blocks: list[int], kv_role: str, ready_event: threading.Event, enable_kv_event: bool = False, ): + if isinstance(group_token_databases, ChunkedTokenDatabase): + group_token_databases = [group_token_databases] super().__init__( store, - token_database, + group_token_databases[0] + if group_token_databases + else ChunkedTokenDatabase( + KeyMetadata("", 0, 0, 0, 0), block_size + ), block_size, tp_rank, ready_event, name="KVCacheStoreSendingThread", ) - self.put_step = put_step + self.group_token_databases = group_token_databases + self.group_put_steps = group_put_steps + self.group_sw_blocks = group_sw_blocks self.kv_role = kv_role self.stored_requests: defaultdict[str, int] = defaultdict(int) self.enable_kv_event = enable_kv_event @@ -225,15 +246,20 @@ def __init__( self._store_pressure_active = False self._skip_store_requests: set[str] = set() - def add_stored_request(self, req_id: str): + def add_stored_request(self, req_id: str, count: int = 1): with self.done_task_lock: - self.stored_requests[req_id] += 1 + self.stored_requests[req_id] += count def dec_stored_request(self, req_id: str): with self.done_task_lock: if req_id in self.stored_requests: self.stored_requests[req_id] -= 1 + def dec_stored_request_by(self, req_id: str, count: int): + with self.done_task_lock: + if req_id in self.stored_requests: + self.stored_requests[req_id] -= count + def delete_finished_stored_request(self, req_id: str): with self.done_task_lock: if req_id in self.stored_requests: @@ -259,15 +285,19 @@ def _clear_store_pressure(self) -> bool: self._skip_store_requests.clear() return True + def _active_group_indices( + self, block_ids: list[int] | tuple[list[int], ...] + ) -> list[int]: + if isinstance(block_ids, list): + return [0] + return list(range(len(block_ids))) + def _handle_request(self, req_meta: ReqMeta): - token_len = req_meta.token_len_chunk - block_ids = req_meta.block_ids req_id = req_meta.req_id - current_event = req_meta.current_event - if req_id not in self.stored_requests: self.request_queue.task_done() return + if self._should_skip_request(req_id): logger.debug( "Skipping Mooncake store for request %s while CPU offloading " @@ -278,12 +308,44 @@ def _handle_request(self, req_meta: ReqMeta): self.request_queue.task_done() return + for group_idx in self._active_group_indices(req_meta.block_ids): + token_database = self.group_token_databases[group_idx] + if not token_database.kv_caches_base_addr: + continue + self._handle_request_for_group(req_meta, group_idx, token_database) + + self.dec_stored_request(req_id) + self.request_queue.task_done() + + def _handle_request_for_group( + self, + req_meta: ReqMeta, + group_idx: int, + token_database: ChunkedTokenDatabase, + ): + token_len = req_meta.token_len_chunk + block_ids = _unwrap_block_ids(req_meta.block_ids, group_idx) + req_id = req_meta.req_id + current_event = req_meta.current_event + + # Compute SWA mask: skip blocks outside the sliding window. + sw_blocks = self.group_sw_blocks[group_idx] + swa_mask = 0 + if sw_blocks > 0: + total_blocks = ( + token_len + token_database.block_size - 1 + ) // token_database.block_size + if total_blocks > sw_blocks: + swa_mask = (total_blocks - sw_blocks) * token_database.block_size + starts = [] ends = [] keys = [] block_hashes: list[BlockHash] = [] for index, (start, end, key) in enumerate( - self.token_database.process_tokens(token_len, req_meta.block_hashes) + token_database.process_tokens( + token_len, req_meta.block_hashes, swa_mask + ) ): starts.append(start) ends.append(end) @@ -291,13 +353,13 @@ def _handle_request(self, req_meta: ReqMeta): block_hashes.append(req_meta.block_hashes[index]) # Apply put_step striding for TP - starts = starts[self.tp_rank % self.put_step :: self.put_step] - ends = ends[self.tp_rank % self.put_step :: self.put_step] - keys = keys[self.tp_rank % self.put_step :: self.put_step] - block_hashes = block_hashes[self.tp_rank % self.put_step :: self.put_step] + put_step = self.group_put_steps[group_idx] + starts = starts[self.tp_rank % put_step :: put_step] + ends = ends[self.tp_rank % put_step :: put_step] + keys = keys[self.tp_rank % put_step :: put_step] + block_hashes = block_hashes[self.tp_rank % put_step :: put_step] if not keys: - self.dec_stored_request(req_id) return # Check which blocks already exist (dedup) @@ -305,7 +367,6 @@ def _handle_request(self, req_meta: ReqMeta): missing_indices = [i for i, exists in enumerate(exists_states) if exists != 1] if not missing_indices: - self.dec_stored_request(req_id) return starts = [starts[i] for i in missing_indices] @@ -315,11 +376,12 @@ def _handle_request(self, req_meta: ReqMeta): logger.debug( "Storing KV cache for %d out of %d blocks " - "(missing_count=%d) for request %s", + "(missing_count=%d) for request %s group %d", len(keys), - token_len // self.block_size, + token_len // token_database.block_size, len(missing_indices), req_id, + group_idx, ) addrs = [] @@ -329,7 +391,7 @@ def _handle_request(self, req_meta: ReqMeta): new_block_hashes = [maybe_convert_block_hash(bh) for bh in block_hashes] for index, start in enumerate(starts): - addr, size, _ = self.token_database.prepare_value( + addr, size, _ = token_database.prepare_value( start, ends[index], block_ids ) addrs.append(addr) @@ -361,7 +423,9 @@ def _handle_request(self, req_meta: ReqMeta): failed = [i for i, v in enumerate(res) if v < 0] if failed: # Compute total bytes attempted for this batch - total_bytes = sum(sum(s) if isinstance(s, list) else s for s in sizes) + total_bytes = sum( + sum(s) if isinstance(s, list) else s for s in sizes + ) failed_codes = set(res[i] for i in failed) logger.warning( "batch_put failed: %d/%d keys failed " @@ -396,9 +460,6 @@ def _handle_request(self, req_meta: ReqMeta): if self.enable_kv_event and stored_events: self.update_kv_event(stored_events) - self.dec_stored_request(req_id) - self.request_queue.task_done() - class KVCacheStoreRecvingThread(KVTransferThread): """Background thread for loading KV cache blocks from the store.""" @@ -406,42 +467,90 @@ class KVCacheStoreRecvingThread(KVTransferThread): def __init__( self, store: Any, - token_database: ChunkedTokenDatabase, + group_token_databases: list[ChunkedTokenDatabase] + | ChunkedTokenDatabase, block_size: int, tp_rank: int, + group_sw_blocks: list[int], ready_event: threading.Event, ): + if isinstance(group_token_databases, ChunkedTokenDatabase): + group_token_databases = [group_token_databases] super().__init__( store, - token_database, + group_token_databases[0] + if group_token_databases + else ChunkedTokenDatabase( + KeyMetadata("", 0, 0, 0, 0), block_size + ), block_size, tp_rank, ready_event, name="KVCacheStoreRecvingThread", ) + self.group_token_databases = group_token_databases + self.group_sw_blocks = group_sw_blocks + + def _active_group_indices( + self, block_ids: list[int] | tuple[list[int], ...] + ) -> list[int]: + if isinstance(block_ids, list): + return [0] + return list(range(len(block_ids))) def _handle_request(self, req_meta: ReqMeta): - token_len = req_meta.load_spec.token_len # type: ignore[union-attr] req_id = req_meta.req_id + + for group_idx in self._active_group_indices(req_meta.block_ids): + token_database = self.group_token_databases[group_idx] + if not token_database.kv_caches_base_addr: + continue + self._handle_request_for_group(req_meta, group_idx, token_database) + + self.set_finished_request(req_id) + self.request_queue.task_done() + + def _handle_request_for_group( + self, + req_meta: ReqMeta, + group_idx: int, + token_database: ChunkedTokenDatabase, + ): + token_len = req_meta.load_spec.token_len # type: ignore[union-attr] mask_num = ( req_meta.load_spec.vllm_cached_tokens # type: ignore[union-attr] - // self.block_size - * self.block_size + // token_database.block_size + * token_database.block_size ) + # Merge SWA mask so we only load blocks within the sliding window. + sw_blocks = self.group_sw_blocks[group_idx] + if sw_blocks > 0: + total_blocks = ( + token_len + token_database.block_size - 1 + ) // token_database.block_size + if total_blocks > sw_blocks: + swa_mask = (total_blocks - sw_blocks) * token_database.block_size + mask_num = max(mask_num, swa_mask) + + block_ids = _unwrap_block_ids(req_meta.block_ids, group_idx) + addr_list = [] size_list = [] key_list = [] - for start, end, key in self.token_database.process_tokens( + for start, end, key in token_database.process_tokens( token_len, req_meta.block_hashes, mask_num ): - addr, size, _ = self.token_database.prepare_value( - start, end, req_meta.block_ids + addr, size, _ = token_database.prepare_value( + start, end, block_ids ) key_list.append(key.to_string()) addr_list.append(addr) size_list.append(size) + if not key_list: + return + # Rotate lists by tp_rank for load balancing key_list_c = ( key_list[self.tp_rank % len(key_list) :] @@ -479,9 +588,6 @@ def _handle_request(self, req_meta: ReqMeta): e, ) - self.set_finished_request(req_id) - self.request_queue.task_done() - # ============================================================ # Store Worker @@ -491,7 +597,11 @@ def _handle_request(self, req_meta: ReqMeta): class MooncakeStoreWorker: """Worker-side component for MooncakeStoreConnector.""" - def __init__(self, vllm_config: VllmConfig): + def __init__( + self, + vllm_config: VllmConfig, + kv_cache_config: "KVCacheConfig | None" = None, + ): try: from mooncake.store import MooncakeDistributedStore # type: ignore except ImportError as e: @@ -537,26 +647,82 @@ def __init__(self, vllm_config: VllmConfig): ): self.use_mla = True - if self.use_mla: - self.num_kv_head = 1 - else: - self.num_kv_head = model_config.get_total_num_kv_heads() - - if self.num_kv_head < self.tp_size: - self.put_step = self.tp_size // self.num_kv_head - self.head_or_tp_rank = self.tp_rank // self.put_step + # Per-group parameters for HMA support. + self._group_block_sizes: dict[int, int] = {} + self._group_put_steps: dict[int, int] = {} + self._group_head_or_tp_ranks: dict[int, int] = {} + self._group_num_kv_heads: dict[int, int] = {} + self._group_sw_blocks: dict[int, int] = {} + if kv_cache_config is not None: + self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) + self._layer_to_group: dict[str, int] = {} + for gid, group in enumerate(kv_cache_config.kv_cache_groups): + for layer_name in group.layer_names: + self._layer_to_group[layer_name] = gid + group_block_size = group.kv_cache_spec.block_size + if self.pcp_size > 1: + group_block_size *= self.pcp_size + if self.dcp_size > 1: + group_block_size *= self.dcp_size + self._group_block_sizes[gid] = group_block_size + + # Compute per-group put_step, head_or_tp_rank and num_kv_head. + spec = group.kv_cache_spec + if hasattr(spec, "num_kv_heads"): + num_kv_head = spec.num_kv_heads + else: + num_kv_head = 1 + self._group_num_kv_heads[gid] = num_kv_head + if num_kv_head < self.tp_size: + put_step = self.tp_size // num_kv_head + head_or_tp_rank = self.tp_rank // put_step + else: + head_or_tp_rank = self.tp_rank + put_step = 1 + self._group_put_steps[gid] = put_step + self._group_head_or_tp_ranks[gid] = head_or_tp_rank + + # Compute per-group sliding window block count. + if hasattr(spec, "sliding_window") and spec.sliding_window: + self._group_sw_blocks[gid] = ( + cdiv(spec.sliding_window, group_block_size) + 1 + ) + else: + self._group_sw_blocks[gid] = 0 else: - self.head_or_tp_rank = self.tp_rank - self.put_step = 1 - + # Fallback for non-HMA: treat all layers as a single group. + self.num_kv_cache_groups = 1 + self._layer_to_group = {} + self._group_block_sizes[0] = self.block_size + if self.use_mla: + num_kv_head = 1 + else: + num_kv_head = model_config.get_total_num_kv_heads() + self._group_num_kv_heads[0] = num_kv_head + if num_kv_head < self.tp_size: + put_step = self.tp_size // num_kv_head + head_or_tp_rank = self.tp_rank // put_step + else: + head_or_tp_rank = self.tp_rank + put_step = 1 + self._group_put_steps[0] = put_step + self._group_head_or_tp_ranks[0] = head_or_tp_rank + self._group_sw_blocks[0] = 0 + + # Default metadata (tp_rank=0); group-specific metadata is created + # per-group in _register_group_caches. self.metadata = KeyMetadata( model_name=model_config.model.rstrip("/").split("/")[-1], - tp_rank=self.head_or_tp_rank, + tp_rank=0, pcp_rank=self.pcp_rank, dcp_rank=self.dcp_rank, pp_rank=self.pp_rank, + group_id=0, ) + # Per-group token databases (single entry for non-HMA) + self.group_token_databases: list[ChunkedTokenDatabase] = [] + # Shared token_database for lookup (key generation is group-agnostic) self.token_database = ChunkedTokenDatabase(self.metadata, self.block_size) # Initialize MooncakeDistributedStore with its own TransferEngine @@ -593,19 +759,105 @@ def __init__(self, vllm_config: VllmConfig): if vllm_config.parallel_config.rank == 0: self.lookup_server = LookupKeyServer(self, vllm_config) + def _update_token_database_refs(self) -> None: + """Update shared and backward-compatible references after registration. + + NOTE: We do NOT overwrite self.token_database here. The lookup + server uses process_tokens() which only needs the correct global + block_size (self.block_size). Overwriting with group 0's database + would break lookup consistency if that group has a different block_size. + """ + self.kv_caches_base_addr = [] + self.block_len = [] + for td in self.group_token_databases: + self.kv_caches_base_addr.extend(td.kv_caches_base_addr) + self.block_len.extend(td.block_len) + + def _start_transfer_threads(self) -> None: + """Start KV cache store/load transfer threads.""" + if self.kv_send_thread is not None or self.kv_recv_thread is not None: + return + + group_sw_blocks = [ + self._group_sw_blocks.get(gid, 0) + for gid in range(self.num_kv_cache_groups) + ] + + if self.kv_role in ["kv_producer", "kv_both"]: + ready_event_sending = threading.Event() + group_put_steps = [ + self._group_put_steps.get(gid, 1) + for gid in range(self.num_kv_cache_groups) + ] + self.kv_send_thread = KVCacheStoreSendingThread( + self.store, + self.group_token_databases, + self.block_size, + self.tp_rank, + group_put_steps, + group_sw_blocks, + self.kv_role, + ready_event_sending, + self.enable_kv_events, + ) + self.kv_send_thread.start() + + ready_event_recving = threading.Event() + self.kv_recv_thread = KVCacheStoreRecvingThread( + self.store, + self.group_token_databases, + self.block_size, + self.tp_rank, + group_sw_blocks, + ready_event_recving, + ) + self.kv_recv_thread.start() + ready_event_recving.wait() + def register_cross_layers_kv_caches(self, kv_cache: torch.Tensor) -> None: """Register a cross-layers KV cache tensor. Wraps the unified tensor in a single-entry dict so that the existing stride-based logic in register_kv_caches() produces the correct single-segment result (block_len = page_size * num_layers). + + NOTE: HMA (multiple KV cache groups) is mutually exclusive with + cross-layer blocks because use_uniform_kv_cache() requires exactly + one attention group. This method should only be called in non-HMA + mode, where register_kv_caches() correctly handles __cross_layer__. """ + assert not self._layer_to_group, ( + "register_cross_layers_kv_caches should not be called in HMA mode. " + "use_uniform_kv_cache() prevents cross-layer when multiple groups exist." + ) self.register_kv_caches({"__cross_layer__": kv_cache}) - def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): - """Register KV cache tensors and start transfer threads.""" - # TODO(yifan): we haven't supported HMA yet. - first_kv_cache = next(iter(kv_caches.values())) + def _register_group_caches( + self, + group_caches: dict[str, torch.Tensor], + group_id: int = 0, + block_size: int | None = None, + put_step: int | None = None, + head_or_tp_rank: int | None = None, + ) -> ChunkedTokenDatabase: + """Register buffers for a single KV cache group and return its database. + + Args: + group_caches: layer_name -> cache tensor for this group. + block_size: The block size for this group. Defaults to self.block_size. + put_step: The put_step for this group. Defaults to the group's + put_step from _group_put_steps or 1. + head_or_tp_rank: The effective TP rank for this group. Defaults to + the group's value from _group_head_or_tp_ranks or self.tp_rank. + """ + if block_size is None: + block_size = self.block_size + if put_step is None: + put_step = 1 + if head_or_tp_rank is None: + head_or_tp_rank = self.tp_rank + + first_kv_cache = next(iter(group_caches.values())) # num_blocks from cache_config is authoritative (set after # profiling, before KV cache allocation). @@ -614,17 +866,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # Detect the KV cache memory layout using the stride-based # approach from simple_kv_offload/worker.py. - # - # The physical layout varies across attention backends: - # FlashAttn/ROCm : (2, num_blocks, ...) → K/V outermost - # FlashInfer/MLA : (num_blocks, ...) → blocks outermost - # - # We derive page_size_bytes = storage.nbytes() // num_blocks, - # then classify dims: any dim whose byte-stride exceeds - # page_size_bytes must be an outer segment dim (e.g. the K/V - # dim of size 2). For those backends we register each segment - # (K, V) as a separate base-address so that the per-block - # offset arithmetic in prepare_value() stays correct. storage = first_kv_cache.untyped_storage() el = first_kv_cache.element_size() page_size_bytes = storage.nbytes() // self.num_blocks @@ -637,10 +878,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # Register buffers with the store (deduplicate shared storages) # and record per-segment base addresses for every layer. seen_ptrs: set[int] = set() - self.kv_caches_base_addr: list[int] = [] - self.block_len: list[int] = [] + kv_caches_base_addr: list[int] = [] + block_len: list[int] = [] - for cache in kv_caches.values(): + for cache in group_caches.values(): cache_storage = cache.untyped_storage() base_addr = cache_storage.data_ptr() region_len = cache_storage.nbytes() @@ -658,14 +899,14 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): if not outer_dims: # Blocks-first layout (FlashInfer / MLA): one segment. - self.kv_caches_base_addr.append(base_addr) - self.block_len.append(page_size_bytes) + kv_caches_base_addr.append(base_addr) + block_len.append(page_size_bytes) else: # K/V-first layout (FlashAttn / ROCm): split segments. seg_stride = cache.stride(outer_dims[0]) * el for idx in range(cache.shape[outer_dims[0]]): - self.kv_caches_base_addr.append(base_addr + idx * seg_stride) - self.block_len.append(seg_stride // self.num_blocks) + kv_caches_base_addr.append(base_addr + idx * seg_stride) + block_len.append(seg_stride // self.num_blocks) logger.info( "Registering KV_Caches. use_mla: %s, shape %s, " @@ -675,39 +916,68 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.use_mla, first_kv_cache.shape, self.num_blocks, - list(set(self.block_len)), - sum(self.block_len), - len(self.kv_caches_base_addr), + list(set(block_len)), + sum(block_len), + len(kv_caches_base_addr), ) - self.token_database.set_kv_caches_base_addr(self.kv_caches_base_addr) - self.token_database.set_block_len(self.block_len) + # Create group-specific metadata with the effective TP rank for keygen. + group_metadata = KeyMetadata( + model_name=self.metadata.model_name, + tp_rank=head_or_tp_rank, + pcp_rank=self.metadata.pcp_rank, + dcp_rank=self.metadata.dcp_rank, + pp_rank=self.metadata.pp_rank, + group_id=group_id, + ) + token_database = ChunkedTokenDatabase(group_metadata, block_size) + token_database.set_kv_caches_base_addr(kv_caches_base_addr) + token_database.set_block_len(block_len) + return token_database - # Start transfer threads - if self.kv_role in ["kv_producer", "kv_both"]: - ready_event_sending = threading.Event() - self.kv_send_thread = KVCacheStoreSendingThread( - self.store, - self.token_database, - self.block_size, - self.tp_rank, - self.put_step, - self.kv_role, - ready_event_sending, - self.enable_kv_events, + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """Register KV cache tensors and start transfer threads.""" + # Group caches by their KV cache group ID for HMA support. + group_caches: dict[int, dict[str, torch.Tensor]] = defaultdict(dict) + if self._layer_to_group: + for layer_name, cache in kv_caches.items(): + gid = self._layer_to_group.get(layer_name) + if gid is None: + continue + group_caches[gid][layer_name] = cache + else: + # Non-HMA fallback: all layers go to group 0. + group_caches[0] = dict(kv_caches) + + self.group_token_databases.clear() + for gid in range(self.num_kv_cache_groups): + caches = group_caches.get(gid, {}) + block_size = self._group_block_sizes.get(gid, self.block_size) + put_step = self._group_put_steps.get(gid, 1) + head_or_tp_rank = self._group_head_or_tp_ranks.get( + gid, self.tp_rank ) - self.kv_send_thread.start() + if not caches: + # Create an empty database for groups without layers. + group_metadata = KeyMetadata( + model_name=self.metadata.model_name, + tp_rank=head_or_tp_rank, + pcp_rank=self.metadata.pcp_rank, + dcp_rank=self.metadata.dcp_rank, + pp_rank=self.metadata.pp_rank, + group_id=gid, + ) + self.group_token_databases.append( + ChunkedTokenDatabase(group_metadata, block_size) + ) + continue + token_database = self._register_group_caches( + caches, gid, block_size, put_step, head_or_tp_rank + ) + self.group_token_databases.append(token_database) - ready_event_recving = threading.Event() - self.kv_recv_thread = KVCacheStoreRecvingThread( - self.store, - self.token_database, - self.block_size, - self.tp_rank, - ready_event_recving, - ) - self.kv_recv_thread.start() - ready_event_recving.wait() + self._update_token_database_refs() + self._start_transfer_threads() def start_load_kv( self, @@ -821,20 +1091,19 @@ def _get_and_clear_finished_sending( return finished_sending - def lookup( + def _lookup_single_group( self, token_len: int, block_hashes: list[BlockHash], + num_kv_head: int, + token_database: ChunkedTokenDatabase, ) -> int: - """Check how many prefix tokens exist in the store. - - Checks across all TP ranks and PP ranks. - """ + """Check prefix hit for a single group. Returns token count.""" end = 0 keys: list[str] = [] try: starts: list[int] = [] - for start, end, key in self.token_database.process_tokens( + for start, end, key in token_database.process_tokens( token_len, block_hashes ): keys.append(key.to_string()) @@ -842,7 +1111,7 @@ def lookup( # Expand keys for all TP ranks multi_tp_keys = keys[:] - for i in range(1, min(self.tp_size, self.num_kv_head)): + for i in range(1, min(self.tp_size, num_kv_head)): for item in keys: new_str = item.replace("@tp_rank:0", f"@tp_rank:{i}", 1) multi_tp_keys.append(new_str) @@ -859,7 +1128,7 @@ def lookup( num_block = len(keys) multi_tp_values = [ res[i * num_block : (i + 1) * num_block] - for i in range(min(self.tp_size, self.num_kv_head) * self.pp_size) + for i in range(min(self.tp_size, num_kv_head) * self.pp_size) ] index = self._find_min_first_non_one_index(multi_tp_values) if index != -1: @@ -869,6 +1138,49 @@ def lookup( return 0 return end + def lookup( + self, + token_len: int, + block_hashes: list[BlockHash], + ) -> int: + """Check how many prefix tokens exist in the store. + + For non-HMA (single group) checks across all TP/PP ranks directly. + For HMA, checks each group independently and returns the minimum + prefix length found across all groups, ensuring every group has + the required prefix blocks. + """ + if self.num_kv_cache_groups <= 1: + return self._lookup_single_group( + token_len, + block_hashes, + self._group_num_kv_heads.get(0, 1), + self.token_database, + ) + + # HMA: check each group independently and take the minimum. + min_result = token_len + for gid in range(self.num_kv_cache_groups): + num_kv_head = self._group_num_kv_heads.get(gid, 1) + group_block_size = self._group_block_sizes.get(gid, self.block_size) + group_metadata = KeyMetadata( + model_name=self.metadata.model_name, + tp_rank=0, + pcp_rank=self.pcp_rank, + dcp_rank=self.dcp_rank, + pp_rank=self.pp_rank, + group_id=gid, + ) + token_db = ChunkedTokenDatabase(group_metadata, group_block_size) + result = self._lookup_single_group( + token_len, block_hashes, num_kv_head, token_db + ) + if result < min_result: + min_result = result + if min_result == 0: + break + return min_result + @staticmethod def _find_min_first_non_one_index( arr: list[list[int]],