diff --git a/tests/v1/kv_offload/test_cpu_gpu_expand.py b/tests/v1/kv_offload/test_cpu_gpu_expand.py new file mode 100644 index 000000000000..e04307041759 --- /dev/null +++ b/tests/v1/kv_offload/test_cpu_gpu_expand.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import numpy as np + +from vllm.v1.kv_offload.worker.cpu_gpu import expand_block_ids + + +def test_expand_block_ids_full_blocks(): + output = np.empty(12, dtype=np.int64) + expand_block_ids( + np.array([0, 1, 3], dtype=np.int64), + block_size_factor=4, + output=output, + ) + + np.testing.assert_array_equal( + output, + np.array([0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15], dtype=np.int64), + ) + + +def test_expand_block_ids_partial_ranges(): + output = np.empty(6, dtype=np.int64) + expand_block_ids( + np.array([0, 1], dtype=np.int64), + block_size_factor=8, + output=output, + block_offsets=np.array([2, 0], dtype=np.int64), + block_counts=np.array([3, 3], dtype=np.int64), + ) + + np.testing.assert_array_equal( + output, + np.array([2, 3, 4, 8, 9, 10], dtype=np.int64), + ) + + +def test_expand_block_ids_partial_ranges_can_repeat_same_block(): + output = np.empty(4, dtype=np.int64) + expand_block_ids( + np.array([0, 0], dtype=np.int64), + block_size_factor=8, + output=output, + block_offsets=np.array([0, 4], dtype=np.int64), + block_counts=np.array([2, 2], dtype=np.int64), + ) + + np.testing.assert_array_equal( + output, + np.array([0, 1, 4, 5], dtype=np.int64), + ) diff --git a/tests/v1/kv_offload/test_cpu_gpu_mapping.py b/tests/v1/kv_offload/test_cpu_gpu_mapping.py new file mode 100644 index 000000000000..7bf111d1b6df --- /dev/null +++ b/tests/v1/kv_offload/test_cpu_gpu_mapping.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import numpy as np + +from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec +from vllm.v1.kv_offload.worker.cpu_gpu import build_transfer_indices + + +def test_build_transfer_indices_whole_blocks_preserves_legacy_skip_behavior(): + src_spec = CPULoadStoreSpec([7]) + dst_spec = GPULoadStoreSpec([3, 4, 5], group_sizes=(3,)) + + mapping = build_transfer_indices( + src_spec, + dst_spec, + src_block_size_factor=4, + dst_block_size_factor=1, + ) + + np.testing.assert_array_equal( + mapping, + np.array([[29, 3], [30, 4], [31, 5]], dtype=np.int64), + ) + + +def test_build_transfer_indices_supports_partial_gpu_ranges(): + src_spec = GPULoadStoreSpec( + [0, 1], + group_sizes=(2,), + block_offsets=[2, 0], + block_counts=[3, 3], + ) + dst_spec = GPULoadStoreSpec( + [5, 6], + group_sizes=(2,), + block_offsets=[1, 4], + block_counts=[3, 3], + ) + + mapping = build_transfer_indices( + src_spec, + dst_spec, + src_block_size_factor=8, + dst_block_size_factor=8, + ) + + np.testing.assert_array_equal( + mapping, + np.array( + [[2, 41], [3, 42], [4, 43], [8, 52], [9, 53], [10, 54]], + dtype=np.int64, + ), + ) diff --git a/tests/v1/kv_offload/test_hashing.py b/tests/v1/kv_offload/test_hashing.py new file mode 100644 index 000000000000..12da9367739b --- /dev/null +++ b/tests/v1/kv_offload/test_hashing.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm import SamplingParams +from vllm.utils.hashing import sha256 +from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash +from vllm.v1.kv_offload.hashing import HybridChunkBlockHashList, RequestBlockHashList +from vllm.v1.request import Request + + +def make_request(num_tokens: int, block_size: int = 16) -> Request: + init_none_hash(sha256) + sampling_params = SamplingParams(max_tokens=1) + sampling_params.update_from_generation_config({}, eos_token_id=100) + return Request( + request_id="r0", + prompt_token_ids=list(range(num_tokens)), + sampling_params=sampling_params, + pooling_params=None, + block_hasher=get_request_block_hasher(block_size, sha256), + ) + + +def test_request_block_hash_list_matches_request_hashes_when_sizes_match(): + request = make_request(64, block_size=16) + direct_hashes = list(RequestBlockHashList(request, 16, sha256)) + + assert direct_hashes == request.block_hashes + + +def test_request_block_hash_list_supports_arbitrary_block_sizes(): + request = make_request(65536, block_size=1056) + direct_hashes = RequestBlockHashList(request, 16384, sha256) + + assert len(direct_hashes) == 4 + assert direct_hashes[0] != direct_hashes[1] + + +def test_hybrid_chunk_block_hash_list_uses_per_group_granularity(): + request = make_request(65536, block_size=1056) + hash_list = HybridChunkBlockHashList( + request, + group_block_sizes=(16384, 16384, 16384, 1056), + logical_chunk_size=16384, + hash_function=sha256, + ) + + assert len(hash_list) == 4 + assert hash_list[0] != hash_list[1] + + +def test_hybrid_chunk_block_hash_list_caches_chunk_hashes(): + """Accessing the same index twice should return the cached value.""" + request = make_request(65536, block_size=1056) + hash_list = HybridChunkBlockHashList( + request, + group_block_sizes=(16384, 1056), + logical_chunk_size=16384, + hash_function=sha256, + ) + + # Cache starts empty + assert len(hash_list._chunk_hashes) == 0 + + # Access index 0: should populate the cache + h0 = hash_list[0] + assert len(hash_list._chunk_hashes) == 1 + assert hash_list._chunk_hashes[0] == h0 + + # Access index 1: cache grows + h1 = hash_list[1] + assert len(hash_list._chunk_hashes) == 2 + + # Re-access index 0: served from cache, identical value + assert hash_list[0] == h0 + + # Re-access index 1: served from cache + assert hash_list[1] == h1 + + # Cache does not grow on repeated access + assert len(hash_list._chunk_hashes) == 2 + + +def test_hybrid_chunk_block_hash_list_skips_leading_unhashable_chunks(): + request = make_request(100000, block_size=1056) + hash_list = HybridChunkBlockHashList( + request, + group_block_sizes=(50000, 16384, 1056), + logical_chunk_size=16384, + hash_function=sha256, + ) + + assert hash_list.first_hashable_chunk_idx == 3 + assert len(hash_list) == 3 diff --git a/tests/v1/kv_offload/test_planner.py b/tests/v1/kv_offload/test_planner.py new file mode 100644 index 000000000000..285baebca8f1 --- /dev/null +++ b/tests/v1/kv_offload/test_planner.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from vllm.v1.kv_offload.planner import HybridOffloadPlanner + + +def test_fixed_chunk_marks_large_groups_as_partial(): + planner = HybridOffloadPlanner( + hash_block_size=16, + gpu_block_sizes=(65536, 65536, 65536, 1056), + fixed_chunk_size=16384, + ) + + assert planner.offload_unit_sizes == (16384, 16384, 16384, 1056) + assert planner.requires_partial_group_offload == (True, True, True, False) + assert planner.group_hash_factors == (1024, 1024, 1024, 66) + + +def test_fixed_chunk_rejects_non_positive_size(): + with pytest.raises(ValueError, match="must be positive"): + HybridOffloadPlanner( + hash_block_size=16, + gpu_block_sizes=(65536, 1056), + fixed_chunk_size=0, + ) + + +def test_fixed_chunk_rejects_smaller_than_hash_block_size(): + with pytest.raises(ValueError, match="greater than or equal to hash_block_size"): + HybridOffloadPlanner( + hash_block_size=1056, + gpu_block_sizes=(65536, 1056), + fixed_chunk_size=1024, + ) + + +def test_fixed_chunk_leaves_indivisible_large_groups_unsplit(): + planner = HybridOffloadPlanner( + hash_block_size=16, + gpu_block_sizes=(65536, 50000, 1056), + fixed_chunk_size=16384, + ) + + assert planner.offload_unit_sizes == (16384, 50000, 1056) + assert planner.requires_partial_group_offload == (True, False, False) + assert planner.first_hashable_chunk_idx == 3 + assert planner.chunk_count_for_tokens(16_384) == 0 + assert planner.chunk_count_for_tokens(50_000) == 1 + + +def test_storable_prefix_uses_common_fully_covered_units(): + planner = HybridOffloadPlanner( + hash_block_size=16, + gpu_block_sizes=(65536, 65536, 65536, 1056), + fixed_chunk_size=16384, + ) + + assert planner.storable_prefix_tokens(10_000) == 0 + assert planner.storable_prefix_tokens(16_384) == 15_840 + assert planner.storable_prefix_tokens(20_000) == 16_384 + assert planner.storable_prefix_tokens(33_000) == 32_736 + + +def test_loadable_prefix_reconciles_existing_group_coverage(): + planner = HybridOffloadPlanner( + hash_block_size=16, + gpu_block_sizes=(65536, 65536, 65536, 1056), + fixed_chunk_size=16384, + ) + + assert planner.loadable_prefix_tokens((16384, 16384, 16384, 15840)) == 15840 + assert planner.loadable_prefix_tokens((32768, 32768, 32768, 32736)) == 32736 + assert planner.loadable_prefix_tokens((16384, 0, 16384, 15840)) == 0 + + +def test_planner_reports_partial_group_requirement(): + planner = HybridOffloadPlanner( + hash_block_size=16, + gpu_block_sizes=(65536, 1056), + fixed_chunk_size=16384, + ) + + assert planner.requires_partial_group_offload_any is True + + +def test_planner_allows_engine_hash_size_to_differ_from_hybrid_chunk(): + planner = HybridOffloadPlanner( + hash_block_size=1056, + gpu_block_sizes=(65536, 65536, 65536, 1056), + fixed_chunk_size=16384, + ) + + assert planner.offload_unit_sizes == (16384, 16384, 16384, 1056) + assert planner.group_hash_factors == (None, None, None, 1) + assert planner.chunk_prefix_tokens(1) == 15840 + + +def test_chunk_prefix_tokens_uses_common_covered_prefix(): + planner = HybridOffloadPlanner( + hash_block_size=16, + gpu_block_sizes=(65536, 65536, 65536, 1056), + fixed_chunk_size=16384, + ) + + assert planner.chunk_prefix_tokens(0) == 0 + assert planner.chunk_prefix_tokens(1) == 15840 + assert planner.chunk_prefix_tokens(2) == 32736 + assert planner.chunk_prefix_tokens(4) == 65472 + + +def test_chunk_count_for_tokens_inverts_common_prefix_boundaries(): + planner = HybridOffloadPlanner( + hash_block_size=16, + gpu_block_sizes=(65536, 65536, 65536, 1056), + fixed_chunk_size=16384, + ) + + assert planner.chunk_count_for_tokens(0) == 0 + assert planner.chunk_count_for_tokens(15839) == 0 + assert planner.chunk_count_for_tokens(15840) == 1 + assert planner.chunk_count_for_tokens(32735) == 1 + assert planner.chunk_count_for_tokens(32736) == 2 diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index 64aee2bd9c49..c613a5f71a59 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -16,6 +16,7 @@ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, + SupportsHMA, ) from vllm.logger import init_logger from vllm.v1.attention.backend import AttentionMetadata @@ -69,7 +70,7 @@ def __repr__(self) -> str: return f"" -class LMCacheConnectorV1(KVConnectorBase_V1): +class LMCacheConnectorV1(KVConnectorBase_V1, SupportsHMA): @classmethod def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool: """ @@ -339,6 +340,18 @@ def request_finished( """ return self._lmcache_engine.request_finished(request, block_ids) + def request_finished_all_groups( + self, + request: "Request", + block_ids: tuple[list[int], ...], + ) -> tuple[bool, dict[str, Any] | None]: + """ + Called when a request has finished for all kv cache groups. + Flatten per-group block IDs and delegate to request_finished. + """ + flat_block_ids = [bid for group_ids in block_ids for bid in group_ids] + return self._lmcache_engine.request_finished(request, flat_block_ids) + def take_events(self) -> Iterable["KVCacheEvent"]: """ Take the KV cache events from the connector. diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 3888d2e0f44c..c48845c29aee 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -18,6 +18,7 @@ KVConnectorMetadata, KVConnectorRole, KVConnectorWorkerMetadata, + SupportsHMA, ) from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( KVConnectorPromMetrics, @@ -39,6 +40,70 @@ logger = init_logger(__name__) +# Key used in MultiKVConnectorStats.data for per-connector selection metrics. +_SELECTION_KEY = "__selection__" + + +@dataclass +class _SelectionStats(KVConnectorStats): + """Per-connector selection statistics accumulated by MultiConnector. + + Tracks how often each child connector is queried, wins the weighted + selection, contributes matched tokens, and misses. Flows through the + existing KVConnectorStats pipeline so that counters are registered in + the APIServer process (the one that serves /metrics) rather than in the + EngineCore subprocess. + """ + + def __post_init__(self): + if not self.data: + self.reset() + + def reset(self): + # {connector_name: {"queries": int, "hits": int, + # "hit_tokens": int, "misses": int}} + self.data = {} + + def _ensure(self, name: str): + if name not in self.data: + self.data[name] = { + "queries": 0, + "hits": 0, + "hit_tokens": 0, + "misses": 0, + } + + def record_query(self, name: str): + self._ensure(name) + self.data[name]["queries"] += 1 + + def record_hit(self, name: str, tokens: int): + self._ensure(name) + self.data[name]["hits"] += 1 + self.data[name]["hit_tokens"] += tokens + + def record_miss(self, name: str): + self._ensure(name) + self.data[name]["misses"] += 1 + + def aggregate(self, other: "KVConnectorStats") -> "KVConnectorStats": + if not isinstance(other, _SelectionStats): + return self + for name, counts in other.data.items(): + if name not in self.data: + self.data[name] = dict(counts) + else: + for k, v in counts.items(): + self.data[name][k] = self.data[name].get(k, 0) + v + return self + + def reduce(self) -> dict[str, Any]: + # Return a shallow copy so the caller can't mutate our state. + return dict(self.data) + + def is_empty(self) -> bool: + return not self.data + @dataclass class MultiKVConnectorMetadata(KVConnectorMetadata): @@ -114,16 +179,91 @@ def __init__( super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues) self._prom_metrics = prom_metrics + # Per-connector selection counters. Labels: model_name, engine, + # connector. Registered here (APIServer process) so they appear in + # the /metrics endpoint. + _sel_labels = labelnames + ["connector"] + self._counter_mc_queries = self._counter_cls( + name="vllm:kv_connector_mc_queries_total", + documentation=( + "Total cache-lookup queries issued to each child connector " + "by MultiConnector." + ), + labelnames=_sel_labels, + ) + self._counter_mc_hits = self._counter_cls( + name="vllm:kv_connector_mc_hits_total", + documentation=( + "Number of times each child connector won the weighted " + "selection and will serve the load." + ), + labelnames=_sel_labels, + ) + self._counter_mc_hit_tokens = self._counter_cls( + name="vllm:kv_connector_mc_hit_tokens_total", + documentation="Total tokens matched by the winning child connector.", + labelnames=_sel_labels, + ) + self._counter_mc_misses = self._counter_cls( + name="vllm:kv_connector_mc_misses_total", + documentation=( + "Number of requests where each child connector had no cache hit." + ), + labelnames=_sel_labels, + ) + # Cache of labeled metric instances keyed by (engine_idx, connector_name). + self._mc_queries: dict[tuple[int, str], Any] = {} + self._mc_hits: dict[tuple[int, str], Any] = {} + self._mc_hit_tokens: dict[tuple[int, str], Any] = {} + self._mc_misses: dict[tuple[int, str], Any] = {} + + def _observe_selection( + self, + per_connector: dict[str, dict[str, int]], + engine_idx: int, + ) -> None: + """Update per-connector selection counters from a _SelectionStats dict.""" + for conn_name, counts in per_connector.items(): + key = (engine_idx, conn_name) + if key not in self._mc_queries: + label_vals = self.per_engine_labelvalues[engine_idx] + [conn_name] + self._mc_queries[key] = self._counter_mc_queries.labels(*label_vals) + self._mc_hits[key] = self._counter_mc_hits.labels(*label_vals) + self._mc_hit_tokens[key] = self._counter_mc_hit_tokens.labels( + *label_vals + ) + self._mc_misses[key] = self._counter_mc_misses.labels(*label_vals) + self._mc_queries[key].inc(counts.get("queries", 0)) + self._mc_hits[key].inc(counts.get("hits", 0)) + self._mc_hit_tokens[key].inc(counts.get("hit_tokens", 0)) + self._mc_misses[key].inc(counts.get("misses", 0)) + def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0): + # Handle MultiConnector's own selection metrics first. + selection_data = transfer_stats_data.get(_SELECTION_KEY) + if selection_data is not None: + # Cross-process: msgspec serialises the dataclass as a dict with a + # "data" field. Same-process: the _SelectionStats object itself. + per_conn = ( + selection_data["data"] + if isinstance(selection_data, dict) + else selection_data.data + ) + self._observe_selection(per_conn, engine_idx) + + # Route child-connector stats. for connector_id, stats_data in transfer_stats_data.items(): + if connector_id == _SELECTION_KEY: + continue assert connector_id in self._prom_metrics, ( - f"{connector_id} is not contained in the list of registered connectors " - f"with Prometheus metrics support: {self._prom_metrics.keys()}" + f"{connector_id} is not contained in the list of registered " + f"connectors with Prometheus metrics support: " + f"{self._prom_metrics.keys()}" ) self._prom_metrics[connector_id].observe(stats_data["data"], engine_idx) -class MultiConnector(KVConnectorBase_V1): +class MultiConnector(KVConnectorBase_V1, SupportsHMA): """ A wrapper for using multiple KVConnectors at the same time. @@ -166,6 +306,44 @@ def __init__( self._connectors.append(connector_cls(temp_config, role, kv_cache_config)) self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config) + # Per-connector load weights for weighted selection. + # Higher weight means a connector's hit is preferred even if + # another offers more tokens (e.g. fast CPU cache vs slow disk). + # Configured via "load_weight" in each connector's config. + assert vllm_config.kv_transfer_config is not None + connectors_cfg = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "connectors", [] + ) + self._load_weights: list[float] = [ + float(cfg.get("kv_connector_extra_config", {}).get("load_weight", 1.0)) + for cfg in connectors_cfg + ] + # Pad if config is shorter than connectors (shouldn't happen) + while len(self._load_weights) < len(self._connectors): + self._load_weights.append(1.0) + + # Validate HMA: MultiConnector advertises SupportsHMA, but this + # only works if all children also support it. + if not vllm_config.scheduler_config.disable_hybrid_kv_cache_manager: + non_hma = [ + type(c).__name__ + for c in self._connectors + if not isinstance(c, SupportsHMA) + ] + if non_hma: + raise TypeError( + f"MultiConnector has HMA enabled but these child " + f"connectors do not support it: {non_hma}. Either " + f"use --disable-hybrid-kv-cache-manager or replace " + f"the non-HMA connectors." + ) + + # Human-readable names for per-connector Prometheus labels. + self._connector_names: list[str] = [type(c).__name__ for c in self._connectors] + + # Per-connector selection stats; flushed via get_kv_connector_stats(). + self._selection_stats = _SelectionStats() + # A mapping from request id to the index of the connector chosen to # load the request from (if any). self._requests_to_connector: dict[str, int] = {} @@ -315,11 +493,24 @@ def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): for c in self._connectors: c.set_host_xfer_buffer_ops(copy_operation) - def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata): - """Handle preempted requests for all sub-connectors.""" - assert isinstance(kv_connector_metadata, MultiKVConnectorMetadata) - for c, cm in zip(self._connectors, kv_connector_metadata.metadata): - c.handle_preemptions(cm) + def handle_preemptions( + self, + kv_connector_metadata: KVConnectorMetadata | set[str], + ): + """Handle preempted requests for all sub-connectors. + + Stock vLLM 0.18.0 passes ``set[str]`` (preempted request IDs), + while the MultiConnector metadata path passes + ``MultiKVConnectorMetadata``. Accept both. + """ + if isinstance(kv_connector_metadata, set): + # Stock vLLM path — forward the raw set to every child. + for c in self._connectors: + c.handle_preemptions(kv_connector_metadata) # type: ignore[arg-type] + else: + assert isinstance(kv_connector_metadata, MultiKVConnectorMetadata) + for c, cm in zip(self._connectors, kv_connector_metadata.metadata): + c.handle_preemptions(cm) def get_finished_count(self) -> int | None: # TODO(https://github.com/vllm-project/vllm/issues/33400) @@ -347,26 +538,69 @@ def build_connector_worker_meta(self) -> KVConnectorWorkerMetadata | None: # ============================== # Scheduler-side methods # ============================== + def get_timed_out_loads(self) -> set[str]: + """Aggregate timed-out loads from all child connectors.""" + result: set[str] = set() + for c in self._connectors: + if hasattr(c, "get_timed_out_loads"): + result.update(c.get_timed_out_loads()) + return result + def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int, ) -> tuple[int | None, bool]: - to_return = (0, False) + # Weighted selection: each connector's hit is scored as + # tokens * load_weight. The connector with the highest + # weighted score wins. This lets a fast CPU cache (high + # weight) beat a slow disk cache unless the disk hit is + # substantially larger. + best_idx = -1 + best_score = 0.0 + best_result: tuple[int, bool] = (0, False) + + # Track which connectors gave a definitive answer vs deferred. + per_connector_results: list[tuple[int, bool] | None] = [] + any_resolved = False for i, c in enumerate(self._connectors): toks, load_async = c.get_num_new_matched_tokens( request, num_computed_tokens ) - # If there is a connector still looking up the matches, - # we return None to indicate that we are not done yet. if toks is None: - return (None, False) - # The first connector that has new matched tokens will be assigned - # to this request. - if to_return[0] == 0 and toks > 0: - self._requests_to_connector[request.request_id] = i - to_return = (toks, load_async) - return to_return + # Connector is still resolving (e.g. backpressured). + # Skip it — don't block connectors that already answered. + per_connector_results.append(None) + continue + any_resolved = True + name = self._connector_names[i] + self._selection_stats.record_query(name) + per_connector_results.append((toks, load_async)) + if toks > 0: + score = toks * self._load_weights[i] + if score > best_score: + best_score = score + best_idx = i + best_result = (toks, load_async) + + # Only defer if ALL connectors returned None. + if not any_resolved: + return (None, False) + + if best_idx >= 0: + winner_name = self._connector_names[best_idx] + self._requests_to_connector[request.request_id] = best_idx + self._selection_stats.record_hit(winner_name, best_result[0]) + for i, result in enumerate(per_connector_results): + if i != best_idx and result is not None and result[0] == 0: + self._selection_stats.record_miss(self._connector_names[i]) + else: + # No connector had a hit (resolved ones all returned 0). + for i, result in enumerate(per_connector_results): + if result is not None: + self._selection_stats.record_miss(self._connector_names[i]) + + return best_result def update_state_after_alloc( self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int @@ -440,11 +674,25 @@ def request_finished( self, request: "Request", blocks: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + return self.request_finished_all_groups(request, (blocks,)) + + def request_finished_all_groups( + self, + request: "Request", + block_ids: tuple[list[int], ...], ) -> tuple[bool, dict[str, Any] | None]: async_saves = 0 kv_txfer_params = None for c in self._connectors: - async_save, txfer_params = c.request_finished(request, blocks) + if isinstance(c, SupportsHMA): + async_save, txfer_params = c.request_finished_all_groups( + request, block_ids + ) + else: + # Flatten block_ids for non-HMA connectors + flat = [bid for group in block_ids for bid in group] + async_save, txfer_params = c.request_finished(request, flat) if async_save: async_saves += 1 if txfer_params is not None: @@ -510,8 +758,23 @@ def build_kv_connector_stats( # 1. Already-instantiated KVConnectorStats objects (same process) # 2. Serialized dicts (cross-process after serialization) # We need to reconstruct proper KVConnectorStats objects from dicts - reconstructed_data = {} + reconstructed_data: dict[str, KVConnectorStats] = {} for connector_name, stats_value in data.items(): + # Selection stats are internal to MultiConnector — reconstruct + # directly without going through KVConnectorFactory. + if connector_name == _SELECTION_KEY: + if isinstance(stats_value, _SelectionStats): + reconstructed_data[connector_name] = stats_value + else: + assert isinstance(stats_value, dict) and "data" in stats_value, ( + f"Expected a dict with a 'data' field for " + f"{_SELECTION_KEY!r}, got {stats_value!r}" + ) + reconstructed_data[connector_name] = _SelectionStats( + data=stats_value["data"] + ) + continue + # If already a KVConnectorStats object, use it directly if isinstance(stats_value, KVConnectorStats): reconstructed_data[connector_name] = stats_value @@ -549,6 +812,14 @@ def get_kv_connector_stats(self) -> MultiKVConnectorStats | None: # Lazy init to allow optional return value. stats_by_connector = MultiKVConnectorStats() stats_by_connector[c.__class__.__name__] = stats + + # Attach accumulated selection metrics, then reset for the next window. + if not self._selection_stats.is_empty(): + if stats_by_connector is None: + stats_by_connector = MultiKVConnectorStats() + stats_by_connector[_SELECTION_KEY] = self._selection_stats + self._selection_stats = _SelectionStats() + return stats_by_connector @classmethod diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py index c28fe5e96593..1d2799e07db5 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import time from collections import defaultdict from collections.abc import Iterable from itertools import islice @@ -14,10 +15,12 @@ ) from vllm.logger import init_logger from vllm.v1.core.kv_cache_manager import KVCacheBlocks -from vllm.v1.core.kv_cache_utils import BlockHash +from vllm.v1.core.kv_cache_utils import BlockHash, BlockHashListWithBlockSize from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_offload.abstract import OffloadingManager +from vllm.v1.kv_offload.hashing import HybridChunkBlockHashList from vllm.v1.kv_offload.mediums import GPULoadStoreSpec +from vllm.v1.kv_offload.planner import HybridOffloadPlanner from vllm.v1.kv_offload.spec import OffloadingSpec from vllm.v1.kv_offload.worker.worker import TransferSpec from vllm.v1.outputs import KVConnectorOutput @@ -30,15 +33,35 @@ class OffloadingConnectorScheduler: """Implementation of Scheduler side methods""" def __init__(self, spec: OffloadingSpec): - assert len(spec.gpu_block_size) == 1 - self.gpu_block_size = spec.gpu_block_size[0] - self.offloaded_block_size = self.gpu_block_size * spec.block_size_factor - self.block_size_factor = spec.block_size_factor + self.hybrid_offload_enabled = spec.hybrid_offload_enabled + self.hybrid_planner: HybridOffloadPlanner | None = spec.hybrid_planner + self.requires_partial_group_offload = spec.requires_partial_group_offload + self.gpu_block_sizes = tuple(spec.gpu_block_size) + self.group_hash_block_sizes = tuple(spec.group_hash_block_size) + self.hash_function = spec.hash_function + self.num_kv_groups = len(self.gpu_block_sizes) + self.hash_block_size = spec.hash_block_size + self.offloaded_block_size = spec.offloaded_block_size + self.max_concurrent_loads = spec.vllm_config.scheduler_config.max_num_seqs + self.hash_block_size_factor: int | None = None + if not self.hybrid_offload_enabled: + assert self.offloaded_block_size % self.hash_block_size == 0 + self.hash_block_size_factor = ( + self.offloaded_block_size // self.hash_block_size + ) + self.block_size_factors = tuple(spec.block_size_factors) + if self.hybrid_offload_enabled: + for gpu_block_size, unit_size in zip( + self.gpu_block_sizes, self.group_hash_block_sizes + ): + assert gpu_block_size % unit_size == 0, ( + "Hybrid GPU block size must be divisible by group offload unit size" + ) self.manager: OffloadingManager = spec.get_manager() self._requests: dict[ReqId, Request] = {} # list of GPU block IDs per request - self._request_block_ids: dict[ReqId, list[int]] = {} + self._request_block_ids: dict[ReqId, tuple[list[int], ...]] = {} # requests to load for the current scheduler step self._reqs_to_load: dict[ReqId, TransferSpec] = {} # request blocks are stored in order @@ -54,18 +77,174 @@ def __init__(self, spec: OffloadingSpec): self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set) self._reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set) + # Hybrid mode: one HybridChunkBlockHashList per active request, reused + # across scheduler steps so RequestBlockHashList lazily caches computed + # group-level hashes instead of recomputing them from scratch each step. + self._hybrid_hash_lists: dict[ReqId, HybridChunkBlockHashList] = {} + + # Load timeout: if a load takes longer than this, cancel it and + # fall back to recompute. Prevents requests from stalling + # indefinitely on slow NFS or hung storage. + self._load_timeout_seconds = float( + spec.extra_config.get("load_timeout_seconds", 30.0) + ) + self._load_start_times: dict[ReqId, float] = {} + + def _chunk_prefix_tokens(self, chunk_count: int) -> int: + if not self.hybrid_offload_enabled: + return chunk_count * self.offloaded_block_size + return self._get_hybrid_planner().chunk_prefix_tokens(chunk_count) + + def _chunk_count_for_tokens(self, tokens: int) -> int: + if not self.hybrid_offload_enabled: + return tokens // self.offloaded_block_size + planner = self._get_hybrid_planner() + return planner.chunk_count_for_tokens(tokens) + + def _get_hybrid_planner(self): + planner = self.hybrid_planner + if planner is None: + raise RuntimeError("Hybrid offload planner is not configured") + return planner + + def _empty_block_groups(self) -> tuple[list[int], ...]: + return tuple([] for _ in range(self.num_kv_groups)) + + @staticmethod + def _flatten_block_groups( + block_groups: tuple[list[int], ...] | list[list[int]], + ) -> tuple[list[int], tuple[int, ...]]: + flat_block_ids: list[int] = [] + group_sizes: list[int] = [] + for group_block_ids in block_groups: + group_sizes.append(len(group_block_ids)) + flat_block_ids.extend(group_block_ids) + return flat_block_ids, tuple(group_sizes) + + def _append_block_groups( + self, + req_id: ReqId, + new_block_id_groups: tuple[list[int], ...], + ) -> None: + existing = self._request_block_ids[req_id] + assert len(existing) == len(new_block_id_groups) == self.num_kv_groups + for group_index, group_block_ids in enumerate(new_block_id_groups): + existing[group_index].extend(group_block_ids) + + def _build_gpu_transfer_spec_from_chunk_range( + self, + block_groups: tuple[list[int], ...] | list[list[int]], + start_chunk_idx: int, + end_chunk_idx: int, + include_block_indices: bool = False, + ) -> GPULoadStoreSpec: + if not self.hybrid_offload_enabled: + block_groups_list: list[list[int]] = [] + block_indices: list[int] = [] + for group_index, group_block_ids in enumerate(block_groups): + group_factor = self.block_size_factors[group_index] + start_gpu_block_idx = start_chunk_idx * group_factor + end_gpu_block_idx = end_chunk_idx * group_factor + if include_block_indices: + block_indices.append(start_gpu_block_idx) + block_groups_list.append( + group_block_ids[start_gpu_block_idx:end_gpu_block_idx] + ) + flat_block_ids, group_sizes = self._flatten_block_groups(block_groups_list) + return GPULoadStoreSpec( + flat_block_ids, + group_sizes=group_sizes, + block_indices=tuple(block_indices) if include_block_indices else None, + ) + + flat_block_ids: list[int] = [] # type: ignore[no-redef] + flat_block_offsets: list[int] = [] + flat_block_counts: list[int] = [] + group_sizes: list[int] = [] # type: ignore[no-redef] + block_indices: list[int] = [] # type: ignore[no-redef] + planner = self._get_hybrid_planner() + if start_chunk_idx == 0: + group_start_tokens = tuple(0 for _ in range(self.num_kv_groups)) + else: + group_start_tokens = planner.group_covered_tokens_for_chunk_count( + start_chunk_idx + ) + group_end_tokens = planner.group_covered_tokens_for_chunk_count(end_chunk_idx) + + for group_index, group_block_ids in enumerate(block_groups): + unit_size = self.group_hash_block_sizes[group_index] + gpu_block_size = self.gpu_block_sizes[group_index] + sub_blocks_per_gpu_block = gpu_block_size // unit_size + + start_unit_idx = group_start_tokens[group_index] // unit_size + end_unit_idx = group_end_tokens[group_index] // unit_size + assert end_unit_idx >= start_unit_idx + + group_entry_count = 0 + if include_block_indices: + block_indices.append(start_unit_idx // sub_blocks_per_gpu_block) + + unit_idx = start_unit_idx + while unit_idx < end_unit_idx: + gpu_block_idx = unit_idx // sub_blocks_per_gpu_block + sub_block_offset = unit_idx % sub_blocks_per_gpu_block + sub_block_count = min( + sub_blocks_per_gpu_block - sub_block_offset, + end_unit_idx - unit_idx, + ) + flat_block_ids.append(group_block_ids[gpu_block_idx]) + flat_block_offsets.append(sub_block_offset) + flat_block_counts.append(sub_block_count) + group_entry_count += 1 + unit_idx += sub_block_count + + group_sizes.append(group_entry_count) # type: ignore[attr-defined] + + return GPULoadStoreSpec( + flat_block_ids, + group_sizes=tuple(group_sizes), + block_indices=tuple(block_indices) if include_block_indices else None, + block_offsets=flat_block_offsets, + block_counts=flat_block_counts, + ) + def _get_block_hashes( self, req: Request, start_idx: int = 0, end_idx: int | None = None, ) -> Iterable[BlockHash]: - return islice( + if self.hybrid_offload_enabled: + # Reuse a cached HybridChunkBlockHashList so that + # RequestBlockHashList's lazily-computed per-group hashes survive + # across multiple calls within a step and across scheduler steps. + # Without caching, each call rebuilds the list and recomputes all + # previously-seen group-level hashes from scratch. + req_id = req.request_id + offloaded_hashes = self._hybrid_hash_lists.get(req_id) + if offloaded_hashes is None: + offloaded_hashes = HybridChunkBlockHashList( + req, + self.group_hash_block_sizes, + self.offloaded_block_size, + self.hash_function, + ) + self._hybrid_hash_lists[req_id] = offloaded_hashes + return islice(offloaded_hashes, start_idx, end_idx) + + simple_hashes = BlockHashListWithBlockSize( req.block_hashes, - self.block_size_factor * start_idx + self.block_size_factor - 1, - self.block_size_factor * end_idx if end_idx else None, - self.block_size_factor, + self.hash_block_size, + self.offloaded_block_size, ) + return islice(simple_hashes, start_idx, end_idx) + + def _get_num_offloaded_blocks(self, request: Request) -> int: + if self.hybrid_offload_enabled: + return self._chunk_count_for_tokens(request.num_tokens) + + assert self.hash_block_size_factor is not None + return len(request.block_hashes) // self.hash_block_size_factor def get_num_new_matched_tokens( self, request: Request, num_computed_tokens: int @@ -89,19 +268,28 @@ def get_num_new_matched_tokens( - `True` if tokens will be loaded asynchronously (between scheduler steps). """ - num_blocks = request.num_tokens // self.offloaded_block_size + # Backpressure: if too many requests are already loading from + # external storage, defer this one to the next scheduler step. + # Without this cap, a burst of concurrent loads can queue + # hundreds of I/O tasks and stall the EngineCore. + if len(self._reqs_being_loaded) >= self.max_concurrent_loads: + return None, False - assert len(request.block_hashes) // self.block_size_factor == num_blocks + num_blocks = self._get_num_offloaded_blocks(request) block_hashes = self._get_block_hashes(request) self.manager.touch(block_hashes) - full_block_tokens = self.offloaded_block_size * num_blocks - if full_block_tokens - num_computed_tokens < self.offloaded_block_size: - # we can load less than a block, skip - return 0, False - - start_block_idx = num_computed_tokens // self.offloaded_block_size + full_block_tokens = self._chunk_prefix_tokens(num_blocks) + if self.hybrid_offload_enabled: + if full_block_tokens <= num_computed_tokens: + return 0, False + start_block_idx = self._chunk_count_for_tokens(num_computed_tokens) + else: + if full_block_tokens - num_computed_tokens < self.offloaded_block_size: + # we can load less than a block, skip + return 0, False + start_block_idx = num_computed_tokens // self.offloaded_block_size hits = self.manager.lookup( self._get_block_hashes(request, start_idx=start_block_idx) ) @@ -112,7 +300,7 @@ def get_num_new_matched_tokens( return 0, False num_hit_tokens = ( - self.offloaded_block_size * (start_block_idx + hits) - num_computed_tokens + self._chunk_prefix_tokens(start_block_idx + hits) - num_computed_tokens ) logger.debug( "Request %s hit %s offloaded tokens after %s GPU hit tokens", @@ -120,7 +308,12 @@ def get_num_new_matched_tokens( num_hit_tokens, num_computed_tokens, ) - if num_hit_tokens < self.offloaded_block_size: + min_hit_tokens = ( + self.hash_block_size + if self.hybrid_offload_enabled + else self.offloaded_block_size + ) + if num_hit_tokens < min_hit_tokens: return 0, False if self._blocks_being_loaded: @@ -145,47 +338,68 @@ def update_state_after_alloc( self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int ): self._requests[request.request_id] = request - # the block ids are updated in _get_reqs_to_store - self._request_block_ids[request.request_id] = [] + self._request_block_ids[request.request_id] = self._empty_block_groups() if num_external_tokens == 0: return block_groups = blocks.get_block_ids() - block_ids = block_groups[0] + computed_tokens_per_group: list[int] = [] + for group_index, group_blocks in enumerate(blocks.blocks): + num_computed_gpu_blocks = sum( + block.block_hash is not None + for block in group_blocks + if not block.is_null + ) + computed_tokens_per_group.append( + num_computed_gpu_blocks * self.gpu_block_sizes[group_index] + ) - num_computed_gpu_blocks = sum( - block.block_hash is not None for block in blocks.blocks[0] + num_computed_tokens = min(computed_tokens_per_group, default=0) + groups_agree = all( + group_tokens == num_computed_tokens + for group_tokens in computed_tokens_per_group ) - num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size - full_block_tokens = num_computed_tokens + num_external_tokens - assert full_block_tokens % self.offloaded_block_size == 0 - - num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks - assert num_external_tokens == num_pending_gpu_blocks * self.gpu_block_size - - start_block_idx = num_computed_tokens // self.offloaded_block_size - num_blocks = full_block_tokens // self.offloaded_block_size + if not groups_agree: + # Some groups loaded more blocks than others (e.g., stale + # cache files rejected for one group but not others, or + # kernel block size mismatch on the attention group). + # Fall back to recompute by reporting 0 external tokens. + logger.warning( + "KV groups disagree on computed prefix length: %s. " + "Falling back to full recompute.", + computed_tokens_per_group, + ) + num_computed_tokens = 0 + num_external_tokens = 0 - assert len(request.block_hashes) // self.block_size_factor >= num_blocks - block_hashes = self._get_block_hashes( - request, start_idx=start_block_idx, end_idx=num_blocks + full_block_tokens = num_computed_tokens + num_external_tokens + start_block_idx = self._chunk_count_for_tokens(num_computed_tokens) + num_blocks = self._chunk_count_for_tokens(full_block_tokens) + assert self._chunk_prefix_tokens(num_blocks) == full_block_tokens + + assert self._get_num_offloaded_blocks(request) >= num_blocks + # Materialise into a list so the same hashes can be passed to + # prepare_load (which consumes the iterable) and also used to + # update _reqs_being_loaded without a second HybridChunkBlockHashList. + block_hashes = list( + self._get_block_hashes( + request, start_idx=start_block_idx, end_idx=num_blocks + ) ) src_spec = self.manager.prepare_load(block_hashes) - dst_spec = GPULoadStoreSpec( - block_ids[num_computed_gpu_blocks:], - group_sizes=(num_pending_gpu_blocks,), - block_indices=(num_computed_gpu_blocks,), - ) - - block_hashes = self._get_block_hashes( - request, start_idx=start_block_idx, end_idx=num_blocks + dst_spec = self._build_gpu_transfer_spec_from_chunk_range( + block_groups, + start_chunk_idx=start_block_idx, + end_chunk_idx=num_blocks, + include_block_indices=True, ) self._reqs_to_load[request.request_id] = (src_spec, dst_spec) req_blocks_being_loaded = self._reqs_being_loaded[request.request_id] req_blocks_being_loaded.update(block_hashes) + self._load_start_times[request.request_id] = time.monotonic() self._next_stored_block_idx[request.request_id] = num_blocks if self._blocks_being_loaded is not None: @@ -196,31 +410,29 @@ def _get_reqs_to_store(self, scheduler_output: SchedulerOutput): # iterate over both new and cached requests for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output): if preempted: - self._request_block_ids[req_id] = [] + self._request_block_ids[req_id] = self._empty_block_groups() if new_block_id_groups: - new_block_ids = new_block_id_groups[0] - self._request_block_ids[req_id] += new_block_ids + self._append_block_groups(req_id, new_block_id_groups) - block_ids = self._request_block_ids[req_id] + block_groups = self._request_block_ids[req_id] req = self._requests[req_id] new_tokens = scheduler_output.num_scheduled_tokens[req_id] expected_tokens = req.num_computed_tokens + new_tokens # with async scheduling, some tokens may be missing total_tokens = min(expected_tokens, req.num_tokens) - num_blocks = total_tokens // self.offloaded_block_size + num_blocks = self._chunk_count_for_tokens(total_tokens) start_block_idx = self._next_stored_block_idx.get(req_id, 0) num_new_blocks = num_blocks - start_block_idx if num_new_blocks <= 0: continue - num_gpu_blocks = num_blocks * self.block_size_factor - assert len(req.block_hashes) >= num_gpu_blocks - - new_block_hashes = self._get_block_hashes( - req, start_idx=start_block_idx, end_idx=num_blocks + new_block_hashes = list( + self._get_block_hashes( + req, start_idx=start_block_idx, end_idx=num_blocks + ) ) store_output = self.manager.prepare_store(new_block_hashes) if store_output is None: @@ -238,21 +450,92 @@ def _get_reqs_to_store(self, scheduler_output: SchedulerOutput): block_hashes = self._get_block_hashes(req, end_idx=num_blocks) self.manager.touch(block_hashes) - new_block_hashes = self._get_block_hashes( - req, start_idx=start_block_idx, end_idx=num_blocks - ) dst_spec = store_output.store_spec - src_block_ids: list[int] = [] - for idx, blk_hash in enumerate(new_block_hashes): - if blk_hash not in block_hashes_to_store: - continue - offloaded_block_idx = start_block_idx + idx - gpu_block_idx = offloaded_block_idx * self.block_size_factor - for i in range(self.block_size_factor): - src_block_ids.append(block_ids[gpu_block_idx + i]) - src_spec = GPULoadStoreSpec( - src_block_ids, group_sizes=(len(src_block_ids),) - ) + if self.hybrid_offload_enabled: + block_hash_to_chunk_idx = { + block_hash: start_block_idx + idx + for idx, block_hash in enumerate(new_block_hashes) + } + src_specs: list[GPULoadStoreSpec] = [] + for block_hash in new_block_hashes: + if block_hash not in block_hashes_to_store: + continue + chunk_idx = block_hash_to_chunk_idx[block_hash] + src_specs.append( + self._build_gpu_transfer_spec_from_chunk_range( + block_groups, + start_chunk_idx=chunk_idx, + end_chunk_idx=chunk_idx + 1, + ) + ) + + # Accumulate per-group, then flatten. Each + # src_spec_part's flat arrays are ordered by group + # (group_0 entries, group_1 entries, ...), but when + # we combine *multiple* spec parts we must keep all + # of group_0's entries contiguous, then group_1's, + # etc. The old code appended chunk-by-chunk which + # interleaved groups and caused mamba sub-block + # offsets to bleed into attention-group entries. + per_group_ids: list[list[int]] = [[] for _ in range(self.num_kv_groups)] + per_group_offsets: list[list[int]] = [ + [] for _ in range(self.num_kv_groups) + ] + per_group_counts: list[list[int]] = [ + [] for _ in range(self.num_kv_groups) + ] + group_sizes = [0] * self.num_kv_groups + for src_spec_part in src_specs: + start = 0 + for group_index, group_size in enumerate(src_spec_part.group_sizes): + end = start + group_size + per_group_ids[group_index].extend( + src_spec_part.block_ids[start:end].tolist() + ) + assert src_spec_part.block_offsets is not None + assert src_spec_part.block_counts is not None + per_group_offsets[group_index].extend( + src_spec_part.block_offsets[start:end].tolist() + ) + per_group_counts[group_index].extend( + src_spec_part.block_counts[start:end].tolist() + ) + group_sizes[group_index] += group_size + start = end + flat_src_block_ids: list[int] = [] + flat_block_offsets: list[int] = [] + flat_block_counts: list[int] = [] + for gi in range(self.num_kv_groups): + flat_src_block_ids.extend(per_group_ids[gi]) + flat_block_offsets.extend(per_group_offsets[gi]) + flat_block_counts.extend(per_group_counts[gi]) + src_spec = GPULoadStoreSpec( + flat_src_block_ids, + group_sizes=tuple(group_sizes), + block_offsets=flat_block_offsets, + block_counts=flat_block_counts, + ) + else: + src_block_groups: list[list[int]] = [] + for group_index, group_block_ids in enumerate(block_groups): + group_factor = self.block_size_factors[group_index] + src_group_block_ids: list[int] = [] + for idx, blk_hash in enumerate(new_block_hashes): + if blk_hash not in block_hashes_to_store: + continue + offloaded_block_idx = start_block_idx + idx + gpu_block_idx = offloaded_block_idx * group_factor + for i in range(group_factor): + src_group_block_ids.append( + group_block_ids[gpu_block_idx + i] + ) + src_block_groups.append(src_group_block_ids) + flat_src_block_ids, src_group_sizes = self._flatten_block_groups( + src_block_groups + ) + src_spec = GPULoadStoreSpec( + flat_src_block_ids, group_sizes=src_group_sizes + ) reqs_to_store[req_id] = (src_spec, dst_spec) self._reqs_being_stored[req_id] |= block_hashes_to_store @@ -266,6 +549,36 @@ def _get_reqs_to_store(self, scheduler_output: SchedulerOutput): return reqs_to_store + def get_timed_out_loads(self) -> set[ReqId]: + """Return request IDs whose loads have exceeded the timeout. + + Timed-out loads are removed from ``_reqs_being_loaded`` and + their block hashes are released, so the scheduler can treat + them as failed and fall back to recompute. + """ + if not self._load_start_times: + return set() + + now = time.monotonic() + timed_out: set[ReqId] = set() + for req_id, start in list(self._load_start_times.items()): + if now - start > self._load_timeout_seconds: + elapsed = now - start + logger.warning( + "Load timeout: req_id=%s exceeded %.0fs " + "(elapsed %.1fs). Falling back to recompute.", + req_id, + self._load_timeout_seconds, + elapsed, + ) + timed_out.add(req_id) + self._load_start_times.pop(req_id) + block_hashes = self._reqs_being_loaded.pop(req_id, None) + if block_hashes and self._blocks_being_loaded: + self._blocks_being_loaded.difference_update(block_hashes) + + return timed_out + def build_connector_meta( self, scheduler_output: SchedulerOutput ) -> KVConnectorMetadata: @@ -300,6 +613,7 @@ def update_connector_output(self, connector_output: KVConnectorOutput): self.manager.complete_store(block_hashes) for req_id in connector_output.finished_recving or []: + self._load_start_times.pop(req_id, None) block_hashes = self._reqs_being_loaded.pop(req_id, None) if block_hashes: if self._blocks_being_loaded: @@ -309,7 +623,7 @@ def update_connector_output(self, connector_output: KVConnectorOutput): def request_finished( self, request: Request, - block_ids: list[int], + block_ids: tuple[list[int], ...], ) -> tuple[bool, dict[str, Any] | None]: """ Called when a request has finished, before its blocks are freed. @@ -324,6 +638,8 @@ def request_finished( req_id = request.request_id self._requests.pop(req_id, None) self._request_block_ids.pop(req_id, None) + self._hybrid_hash_lists.pop(req_id, None) + self._load_start_times.pop(req_id, None) # TODO(orozery): possibly kickoff offload for last block # which may have been deferred due to async scheduling diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py index 77398eee8885..36fade2554aa 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py @@ -58,6 +58,10 @@ def __init__(self, spec: OffloadingSpec): self._unsubmitted_store_jobs: list[tuple[int, TransferSpec]] = [] self._finished_reqs_waiting_for_store: set[ReqId] = set() + # Loads that failed validation (e.g. stale cache) and were never + # submitted to the I/O engine. Reported as "finished receiving" + # so the scheduler falls back to recompute. + self._failed_load_req_ids: set[ReqId] = set() def _generate_job_id(self) -> int: job_id = self._job_counter @@ -318,14 +322,37 @@ def start_kv_transfers(self, metadata: OffloadingConnectorMetadata): self._jobs[job_id] = (req_id, False) assert req_id not in self._load_job self._load_job[req_id] = job_id + logger.debug( + "offloading worker submit load req_id=%s job_id=%s", + req_id, + job_id, + ) success = self.worker.transfer_async(job_id, transfer_spec) - assert success + if not success: + logger.warning( + "offloading worker load submission failed for " + "req_id=%s job_id=%s (stale cache files?), " + "falling back to recompute", + req_id, + job_id, + ) + # Remove from tracking and mark as failed so + # get_finished() reports it as complete, letting the + # scheduler fall back to recompute. + del self._load_job[req_id] + del self._jobs[job_id] + self._failed_load_req_ids.add(req_id) def prepare_store_kv(self, metadata: OffloadingConnectorMetadata): for req_id, transfer_spec in metadata.reqs_to_store.items(): job_id = self._generate_job_id() self._jobs[job_id] = (req_id, True) self._store_jobs[req_id].add(job_id) + logger.debug( + "offloading worker queue store req_id=%s job_id=%s", + req_id, + job_id, + ) # NOTE(orozery): defer the store to the beginning of the next engine step, # so that offloading starts AFTER transfers related to token sampling, # thereby avoiding delays to token generation due to offloading. @@ -346,6 +373,14 @@ def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: for transfer_result in self.worker.get_finished(): # we currently do not support job failures job_id = transfer_result.job_id + logger.debug( + "offloading worker finished job_id=%s success=%s " + "transfer_type=%s transfer_size=%s", + job_id, + transfer_result.success, + transfer_result.transfer_type, + transfer_result.transfer_size, + ) assert transfer_result.success req_id, store = self._jobs.pop(job_id) if ( @@ -372,13 +407,26 @@ def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: req_job = self._load_job[req_id] assert job_id == req_job del self._load_job[req_id] + logger.debug( + "offloading worker finished load req_id=%s job_id=%s", + req_id, + job_id, + ) finished_recving.add(req_id) + # Include loads that failed validation (stale cache) so the + # scheduler removes them from _reqs_being_loaded and falls + # back to recompute. + if self._failed_load_req_ids: + finished_recving.update(self._failed_load_req_ids) + self._failed_load_req_ids.clear() + for req_id in finished_req_ids: pending_req_jobs = self._store_jobs.get(req_id) if pending_req_jobs: self._finished_reqs_waiting_for_store.add(req_id) elif pending_req_jobs is not None: + logger.debug("offloading worker finished sending req_id=%s", req_id) finished_sending.add(req_id) del self._store_jobs[req_id] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py index 547ee2578a12..1a16c5c3beab 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py @@ -10,6 +10,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import ( KVConnectorBase_V1, KVConnectorRole, + SupportsHMA, ) from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( @@ -41,7 +42,27 @@ from vllm.v1.request import Request -class OffloadingConnector(KVConnectorBase_V1): +class OffloadingConnector(KVConnectorBase_V1, SupportsHMA): + @staticmethod + def _coerce_metadata( + metadata: KVConnectorMetadata, + ) -> OffloadingConnectorMetadata: + if isinstance(metadata, OffloadingConnectorMetadata): + return metadata + if all( + hasattr(metadata, field) + for field in ("reqs_to_load", "reqs_to_store", "reqs_to_flush") + ): + return OffloadingConnectorMetadata( + reqs_to_load=metadata.reqs_to_load, # type: ignore[attr-defined] + reqs_to_store=metadata.reqs_to_store, # type: ignore[attr-defined] + reqs_to_flush=metadata.reqs_to_flush, # type: ignore[attr-defined] + ) + raise TypeError( + "OffloadingConnector requires metadata with reqs_to_load, " + "reqs_to_store, and reqs_to_flush fields." + ) + @property def prefer_cross_layer_blocks(self) -> bool: return True @@ -74,15 +95,28 @@ def register_cross_layers_kv_cache( assert self.connector_worker is not None self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend) - def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata): + def handle_preemptions( + self, + preempted_req_ids_or_metadata: set[str] | KVConnectorMetadata, + ): assert self.connector_worker is not None - assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata) - self.connector_worker.handle_preemptions(kv_connector_metadata) + # Stock vLLM passes preempted_req_ids (set[str]) directly. + # Our branch's EngineCore may also pass KVConnectorMetadata. + if isinstance(preempted_req_ids_or_metadata, (set, list)): + metadata = OffloadingConnectorMetadata( + reqs_to_load={}, + reqs_to_store={}, + reqs_to_flush=set(preempted_req_ids_or_metadata), + ) + else: + metadata = self._coerce_metadata(preempted_req_ids_or_metadata) + self.connector_worker.handle_preemptions(metadata) def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: assert self.connector_worker is not None - assert isinstance(self._connector_metadata, OffloadingConnectorMetadata) - self.connector_worker.start_kv_transfers(self._connector_metadata) + self.connector_worker.start_kv_transfers( + self._coerce_metadata(self._get_connector_metadata()) + ) def wait_for_layer_load(self, layer_name: str) -> None: pass @@ -98,13 +132,20 @@ def save_kv_layer( def wait_for_save(self): assert self.connector_worker is not None - assert isinstance(self._connector_metadata, OffloadingConnectorMetadata) - self.connector_worker.prepare_store_kv(self._connector_metadata) + self.connector_worker.prepare_store_kv( + self._coerce_metadata(self._get_connector_metadata()) + ) def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: assert self.connector_worker is not None return self.connector_worker.get_finished(finished_req_ids) + def get_timed_out_loads(self) -> set[str]: + """Return req IDs whose loads exceeded the timeout.""" + if self.connector_scheduler is not None: + return self.connector_scheduler.get_timed_out_loads() + return set() + def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int ) -> tuple[int | None, bool]: @@ -135,6 +176,13 @@ def request_finished( self, request: "Request", block_ids: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + return self.request_finished_all_groups(request, (block_ids,)) + + def request_finished_all_groups( + self, + request: "Request", + block_ids: tuple[list[int], ...], ) -> tuple[bool, dict[str, Any] | None]: assert self.connector_scheduler is not None return self.connector_scheduler.request_finished(request, block_ids) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index c28a5d18ae77..486fa2918599 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -295,9 +295,10 @@ def _mamba_block_aligned_split( num_new_local_computed_tokens: int = 0, num_external_computed_tokens: int = 0, ) -> int: - assert num_external_computed_tokens == 0, ( - "External KV connector is not verified yet" - ) + # External tokens (from KV offload connectors) are loaded into GPU + # cache before the forward pass. They are indistinguishable from + # locally-computed tokens for alignment purposes — the mamba state + # is already populated at block boundaries by the offloading path. num_computed_tokens = ( request.num_computed_tokens + num_new_local_computed_tokens @@ -691,7 +692,10 @@ def schedule(self) -> SchedulerOutput: # The request cannot be scheduled. break - if self.need_mamba_block_aligned_split: + if self.need_mamba_block_aligned_split and not load_kv_async: + # Skip mamba alignment when doing an async KV load + # (num_new_tokens is intentionally 0 — we're loading + # cached tokens, not computing new ones). num_new_tokens = self._mamba_block_aligned_split( request, num_new_tokens, @@ -2077,7 +2081,22 @@ def _try_promote_blocked_waiting_request(self, request: Request) -> bool: # update_from_output(), based on worker-side connector signals # in KVConnectorOutput.finished_recving if request.request_id not in self.finished_recving_kv_req_ids: - return False + # Check if the load timed out — if so, treat as + # failed and fall back to recompute. + if self.connector is not None and hasattr( + self.connector, "get_timed_out_loads" + ): + # OffloadingConnector extension: check load timeout. + timed_out = self.connector.get_timed_out_loads() # type: ignore[attr-defined] + if request.request_id in timed_out: + self.failed_recving_kv_req_ids.add(request.request_id) + self.finished_recving_kv_req_ids.add(request.request_id) + else: + return False + elif self.connector is not None: + return False + else: + return False self._update_waiting_for_remote_kv(request) if request.num_preemptions: request.status = RequestStatus.PREEMPTED diff --git a/vllm/v1/kv_offload/cpu/spec.py b/vllm/v1/kv_offload/cpu/spec.py index 4feae8cf7d5a..41668b1595d3 100644 --- a/vllm/v1/kv_offload/cpu/spec.py +++ b/vllm/v1/kv_offload/cpu/spec.py @@ -38,7 +38,7 @@ def __init__(self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig): * vllm_config.parallel_config.world_size ) - kv_bytes_per_offloaded_block = kv_bytes_per_block * self.block_size_factor + kv_bytes_per_offloaded_block = kv_bytes_per_block * self.block_size_factors[0] self.num_blocks = ( int(cpu_bytes_to_use) // kv_bytes_per_offloaded_block if kv_bytes_per_offloaded_block > 0 @@ -62,7 +62,7 @@ def get_manager(self) -> OffloadingManager: assert len(self.gpu_block_size) == 1 gpu_block_size = self.gpu_block_size[0] - offloaded_block_size = gpu_block_size * self.block_size_factor + offloaded_block_size = gpu_block_size * self.block_size_factors[0] self._manager = CPUOffloadingManager( block_size=offloaded_block_size, @@ -97,7 +97,7 @@ def get_handlers( self._handlers = CpuGpuOffloadingHandlers( kv_caches=kv_caches, - block_size_factor=self.block_size_factor, + block_size_factor=self.block_size_factors[0], num_cpu_blocks=self.num_blocks, ) diff --git a/vllm/v1/kv_offload/hashing.py b/vllm/v1/kv_offload/hashing.py new file mode 100644 index 000000000000..4b655a52e220 --- /dev/null +++ b/vllm/v1/kv_offload/hashing.py @@ -0,0 +1,168 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable, Iterator +from typing import overload + +from vllm.v1.core.kv_cache_utils import ( + BlockHash, + generate_block_hash_extra_keys, + hash_block_tokens, +) +from vllm.v1.request import Request + + +class RequestBlockHashList: + """Compute request block hashes directly at an arbitrary block size. + + Unlike ``Request.block_hashes``, this does not depend on the scheduler's + single ``hash_block_size``. It is used for offload paths that need a + different granularity than EngineCore prefix-caching. + """ + + def __init__( + self, + request: Request, + block_size: int, + hash_function: Callable[[object], bytes], + ): + self.request = request + self.block_size = block_size + self.hash_function = hash_function + self._hashes: list[BlockHash] = [] + self._next_token_idx = 0 + self._curr_mm_idx = 0 + self._prev_block_hash: BlockHash | None = None + + def __len__(self) -> int: + return self.request.num_tokens // self.block_size + + @overload + def __getitem__(self, idx: int) -> BlockHash: ... + + @overload + def __getitem__(self, idx: slice) -> list[BlockHash]: ... + + def __getitem__(self, idx): + if isinstance(idx, int): + return self._get_value_at(idx) + + if isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + return [self._get_value_at(i) for i in range(start, stop, step)] + + raise TypeError(f"Invalid index type: {type(idx)!r}") + + def __iter__(self) -> Iterator[BlockHash]: + for i in range(len(self)): + yield self._get_value_at(i) + + def _get_value_at(self, idx: int) -> BlockHash: + self._ensure_computed_through(idx) + return self._hashes[idx] + + def _ensure_computed_through(self, idx: int) -> None: + while len(self._hashes) <= idx: + end_token_idx = self._next_token_idx + self.block_size + if end_token_idx > self.request.num_tokens: + raise IndexError(idx) + + extra_keys, self._curr_mm_idx = generate_block_hash_extra_keys( + self.request, + self._next_token_idx, + end_token_idx, + self._curr_mm_idx, + ) + block_tokens = self.request.all_token_ids[ + self._next_token_idx : end_token_idx + ] + block_hash = hash_block_tokens( + self.hash_function, + self._prev_block_hash, + block_tokens, + extra_keys, + ) + self._hashes.append(block_hash) + self._next_token_idx = end_token_idx + self._prev_block_hash = block_hash + + +class HybridChunkBlockHashList: + """Compose a logical offload hash from per-group block hashes. + + Each logical chunk boundary advances by ``logical_chunk_size`` tokens. + For each group, we use the most recent full group-sized block hash that + fits under that chunk boundary, then hash the tuple of group hashes into a + single offload key. + """ + + def __init__( + self, + request: Request, + group_block_sizes: tuple[int, ...], + logical_chunk_size: int, + hash_function: Callable[[object], bytes], + ): + self.request = request + self.group_block_sizes = group_block_sizes + self.logical_chunk_size = logical_chunk_size + self.hash_function = hash_function + self.first_hashable_chunk_idx = ( + max( + (block_size + logical_chunk_size - 1) // logical_chunk_size + for block_size in group_block_sizes + ) + - 1 + ) + self.group_hashes = tuple( + RequestBlockHashList(request, block_size, hash_function) + for block_size in group_block_sizes + ) + # Cache of combined chunk hashes, grown lazily as indices are accessed. + # Mirrors the pattern used by RequestBlockHashList._hashes. When the + # scheduler reuses this instance across steps (via _hybrid_hash_lists), + # previously-computed indices are served from here without invoking + # hash_function again. + self._chunk_hashes: list[BlockHash] = [] + + def __len__(self) -> int: + return max( + 0, + self.request.num_tokens // self.logical_chunk_size + - self.first_hashable_chunk_idx, + ) + + @overload + def __getitem__(self, idx: int) -> BlockHash: ... + + @overload + def __getitem__(self, idx: slice) -> list[BlockHash]: ... + + def __getitem__(self, idx): + if isinstance(idx, int): + return self._get_value_at(idx) + + if isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + return [self._get_value_at(i) for i in range(start, stop, step)] + + raise TypeError(f"Invalid index type: {type(idx)!r}") + + def __iter__(self) -> Iterator[BlockHash]: + for i in range(len(self)): + yield self._get_value_at(i) + + def _get_value_at(self, idx: int) -> BlockHash: + if idx < len(self._chunk_hashes): + return self._chunk_hashes[idx] + chunk_end = (idx + 1 + self.first_hashable_chunk_idx) * self.logical_chunk_size + component_hashes: list[BlockHash] = [] + for block_size, group_hashes in zip(self.group_block_sizes, self.group_hashes): + num_full_blocks = chunk_end // block_size + component_hashes.append(group_hashes[num_full_blocks - 1]) + chunk_hash = BlockHash(self.hash_function(tuple(component_hashes))) + # Cache only next-in-sequence indices to keep the list dense. The + # typical caller iterates sequentially via islice so this hit rate is + # high; out-of-order accesses skip caching to avoid leaving gaps. + if idx == len(self._chunk_hashes): + self._chunk_hashes.append(chunk_hash) + return chunk_hash diff --git a/vllm/v1/kv_offload/mediums.py b/vllm/v1/kv_offload/mediums.py index 85ef2a95a6bd..d09a5e88df84 100644 --- a/vllm/v1/kv_offload/mediums.py +++ b/vllm/v1/kv_offload/mediums.py @@ -48,12 +48,26 @@ def __init__( block_ids: list[int], group_sizes: Sequence[int], block_indices: Sequence[int] | None = None, + block_offsets: Sequence[int] | None = None, + block_counts: Sequence[int] | None = None, ): super().__init__(block_ids) assert sum(group_sizes) == len(block_ids) assert block_indices is None or len(block_indices) == len(group_sizes) + assert (block_offsets is None) == (block_counts is None) + if block_offsets is not None and block_counts is not None: + assert len(block_offsets) == len(block_ids) + assert len(block_counts) == len(block_ids) self.group_sizes: Sequence[int] = group_sizes self.block_indices: Sequence[int] | None = block_indices + self.block_offsets: np.ndarray | None = ( + np.array(block_offsets, dtype=np.int64) + if block_offsets is not None + else None + ) + self.block_counts: np.ndarray | None = ( + np.array(block_counts, dtype=np.int64) if block_counts is not None else None + ) @staticmethod def medium() -> str: diff --git a/vllm/v1/kv_offload/planner.py b/vllm/v1/kv_offload/planner.py new file mode 100644 index 000000000000..e2a81ceccd1d --- /dev/null +++ b/vllm/v1/kv_offload/planner.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +from dataclasses import dataclass + + +@dataclass(frozen=True) +class HybridOffloadPlanner: + """Plan fixed-size external offload units for hybrid KV groups. + + This planner is intentionally pure and scheduler-facing. It does not assume + that every group can be transferred as a native full GPU block. When a + group's offload unit is smaller than its GPU block size, that group needs + partial-state transfer support in the worker/backend. + """ + + hash_block_size: int + gpu_block_sizes: tuple[int, ...] + fixed_chunk_size: int + + def __post_init__(self) -> None: + if self.hash_block_size <= 0: + raise ValueError("hash_block_size must be positive") + if self.fixed_chunk_size <= 0: + raise ValueError("fixed_chunk_size must be positive") + if self.fixed_chunk_size < self.hash_block_size: + raise ValueError( + "fixed_chunk_size must be greater than or equal to hash_block_size" + ) + if not self.gpu_block_sizes: + raise ValueError("gpu_block_sizes must be non-empty") + if any(block_size <= 0 for block_size in self.gpu_block_sizes): + raise ValueError("gpu_block_sizes must be positive") + + # Pre-compute derived values that are called in tight loops + # (e.g. chunk_count_for_tokens binary search calls offload_unit_sizes + # and first_hashable_chunk_idx on every iteration). Using + # object.__setattr__ is the standard pattern for frozen dataclasses. + units: list[int] = [] + for gpu_block_size in self.gpu_block_sizes: + if gpu_block_size <= self.fixed_chunk_size: + units.append(gpu_block_size) + elif gpu_block_size % self.fixed_chunk_size == 0: + units.append(self.fixed_chunk_size) + else: + units.append(gpu_block_size) + object.__setattr__(self, "_offload_unit_sizes", tuple(units)) + object.__setattr__( + self, + "_first_hashable_chunk_idx", + max(math.ceil(u / self.fixed_chunk_size) for u in units) - 1, + ) + object.__setattr__( + self, + "_group_hash_factors", + tuple( + u // self.hash_block_size if u % self.hash_block_size == 0 else None + for u in units + ), + ) + + @property + def offload_unit_sizes(self) -> tuple[int, ...]: + return self._offload_unit_sizes # type: ignore[attr-defined] + + @property + def requires_partial_group_offload(self) -> tuple[bool, ...]: + return tuple( + unit_size < gpu_block_size + for unit_size, gpu_block_size in zip( + self.offload_unit_sizes, self.gpu_block_sizes + ) + ) + + @property + def requires_partial_group_offload_any(self) -> bool: + return any(self.requires_partial_group_offload) + + @property + def first_hashable_chunk_idx(self) -> int: + return self._first_hashable_chunk_idx # type: ignore[attr-defined] + + @property + def group_hash_factors(self) -> tuple[int | None, ...]: + return self._group_hash_factors # type: ignore[attr-defined] + + def group_covered_tokens_for_chunk_count(self, chunk_count: int) -> tuple[int, ...]: + if chunk_count < 0: + raise ValueError("chunk_count must be non-negative") + logical_tokens = ( + chunk_count + self.first_hashable_chunk_idx + ) * self.fixed_chunk_size + return tuple( + (logical_tokens // unit_size) * unit_size + for unit_size in self.offload_unit_sizes + ) + + def chunk_prefix_tokens(self, chunk_count: int) -> int: + if chunk_count <= 0: + return 0 + return self.loadable_prefix_tokens( + self.group_covered_tokens_for_chunk_count(chunk_count) + ) + + def chunk_count_for_tokens(self, tokens: int) -> int: + if tokens < 0: + raise ValueError("tokens must be non-negative") + + low = 0 + high = max( + 0, + tokens // self.fixed_chunk_size + 1 - self.first_hashable_chunk_idx, + ) + while low < high: + mid = (low + high + 1) // 2 + if self.chunk_prefix_tokens(mid) <= tokens: + low = mid + else: + high = mid - 1 + return low + + def storable_prefix_tokens(self, request_tokens: int) -> int: + if request_tokens <= 0: + return 0 + return self.loadable_prefix_tokens( + tuple( + (request_tokens // unit_size) * unit_size + for unit_size in self.offload_unit_sizes + ) + ) + + def loadable_prefix_tokens(self, group_covered_tokens: tuple[int, ...]) -> int: + if len(group_covered_tokens) != len(self.gpu_block_sizes): + raise ValueError("group_covered_tokens must match gpu_block_sizes length") + if any(tokens < 0 for tokens in group_covered_tokens): + raise ValueError("group_covered_tokens must be non-negative") + + common_prefix = min(group_covered_tokens, default=0) + return common_prefix - (common_prefix % self.hash_block_size) diff --git a/vllm/v1/kv_offload/spec.py b/vllm/v1/kv_offload/spec.py index 1eb4fdb3e6ce..f441a6176081 100644 --- a/vllm/v1/kv_offload/spec.py +++ b/vllm/v1/kv_offload/spec.py @@ -3,12 +3,15 @@ from abc import ABC, abstractmethod from collections.abc import Iterator from dataclasses import dataclass +from math import lcm from typing import TYPE_CHECKING import torch from vllm.logger import init_logger +from vllm.utils.hashing import get_hash_fn_by_name from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager +from vllm.v1.kv_offload.planner import HybridOffloadPlanner from vllm.v1.kv_offload.worker.worker import OffloadingHandler if TYPE_CHECKING: @@ -86,31 +89,126 @@ def __init__(self, vllm_config: "VllmConfig", kv_cache_config: "KVCacheConfig"): # block size used by vLLM for hashing request tokens for the sake # of enabling prefix caching self.hash_block_size = vllm_config.cache_config.block_size + self.hash_function = get_hash_fn_by_name( + vllm_config.cache_config.prefix_caching_hash_algo + ) + # TODO: Block hashes are computed from token IDs and a chain seed + # (NONE_HASH), but do NOT incorporate model weights or config. + # If model weights change (e.g. same model name, different revision + # or fine-tune), stored KV cache files will still hash-match but + # contain stale/wrong KV values — silent correctness corruption. + # Fix: incorporate a model fingerprint (weights checksum, revision + # hash, or tokenizer vocab hash) into the hash seed so that + # weight changes invalidate the cache automatically. # gpu block size per group self.gpu_block_size: tuple[int, ...] = tuple( kv_cache_group.kv_cache_spec.block_size for kv_cache_group in kv_cache_config.kv_cache_groups ) + self.hybrid_planner: HybridOffloadPlanner | None = None + self.hybrid_offload_enabled: bool = False + self.group_hash_block_size: tuple[int, ...] = tuple( + self.hash_block_size for _ in self.gpu_block_size + ) - for block_size in self.gpu_block_size: - assert block_size % self.hash_block_size == 0 - - # offloaded_block_size / gpu_block_size - self.block_size_factor: int = 1 + hybrid_chunk_size = self.extra_config.get("hybrid_chunk_size") + if hybrid_chunk_size is not None: + chunk_size_int = int(hybrid_chunk_size) + # Warn about gpu_block_sizes that are not divisible by chunk_size, + # as these groups cannot be split and will raise + # first_hashable_chunk_idx, potentially making offloading + # impossible for practical context lengths. + for i, gbs in enumerate(self.gpu_block_size): + if gbs > chunk_size_int and gbs % chunk_size_int != 0: + logger.warning( + "KV group %d has gpu_block_size=%d which is not " + "divisible by hybrid_chunk_size=%d. This group " + "cannot be split into chunks and will require " + "%d tokens before any offloading can occur. " + "Consider setting max_model_len to a multiple " + "of hybrid_chunk_size.", + i, + gbs, + chunk_size_int, + gbs, + ) + self.hybrid_planner = HybridOffloadPlanner( + hash_block_size=self.hash_block_size, + gpu_block_sizes=self.gpu_block_size, + fixed_chunk_size=chunk_size_int, + ) + self.hybrid_offload_enabled = True + self.group_hash_block_size = self.hybrid_planner.offload_unit_sizes + max_model_len = vllm_config.model_config.max_model_len + if ( + self.hybrid_planner.first_hashable_chunk_idx * chunk_size_int + >= max_model_len + ): + logger.error( + "Hybrid offloading is effectively disabled: " + "first_hashable_chunk_idx=%d requires %d tokens " + "but max_model_len=%d. No chunks can ever be " + "stored. Set max_model_len to a multiple of " + "hybrid_chunk_size=%d (e.g. %d).", + self.hybrid_planner.first_hashable_chunk_idx, + self.hybrid_planner.first_hashable_chunk_idx * chunk_size_int, + max_model_len, + chunk_size_int, + (max_model_len // chunk_size_int) * chunk_size_int, + ) + else: + logger.info( + "Hybrid offloading enabled: chunk_size=%d, " + "offload_unit_sizes=%s, " + "first_hashable_chunk_idx=%d, " + "min_tokens_for_offload=%d", + chunk_size_int, + self.hybrid_planner.offload_unit_sizes, + self.hybrid_planner.first_hashable_chunk_idx, + (self.hybrid_planner.first_hashable_chunk_idx + 1) * chunk_size_int, + ) + else: + for block_size in self.gpu_block_size: + assert block_size % self.hash_block_size == 0 + + self.offloaded_block_size: int = lcm(*self.gpu_block_size) + self.block_size_factors: tuple[int, ...] = tuple( + self.offloaded_block_size // block_size + for block_size in self.gpu_block_size + ) offloaded_block_size = self.extra_config.get("block_size") if offloaded_block_size is not None: offloaded_block_size_int = int(offloaded_block_size) - gpu_block_sizes = set(self.gpu_block_size) - assert len(gpu_block_sizes) == 1, ( + assert all( + offloaded_block_size_int % gpu_block_size == 0 + for gpu_block_size in self.gpu_block_size + ), ( "If 'block_size' is specified in kv_connector_extra_config, " - "there must be at least one KV cache group, " - "and all groups must have the same block size." + "it must be divisible by every KV cache group block size." + ) + self.offloaded_block_size = offloaded_block_size_int + self.block_size_factors = tuple( + self.offloaded_block_size // block_size + for block_size in self.gpu_block_size + ) + + if self.hybrid_offload_enabled: + self.offloaded_block_size = int(hybrid_chunk_size) # type: ignore[arg-type] + self.block_size_factors = tuple( + self.offloaded_block_size // block_size + if self.offloaded_block_size % block_size == 0 + else 0 + for block_size in self.gpu_block_size ) - gpu_block_size = gpu_block_sizes.pop() - assert offloaded_block_size_int % gpu_block_size == 0 - self.block_size_factor = offloaded_block_size_int // gpu_block_size + @property + def requires_partial_group_offload(self) -> bool: + return ( + self.hybrid_planner.requires_partial_group_offload_any + if self.hybrid_planner is not None + else False + ) @abstractmethod def get_manager(self) -> OffloadingManager: diff --git a/vllm/v1/kv_offload/worker/cpu_gpu.py b/vllm/v1/kv_offload/worker/cpu_gpu.py index eeabf0cdadd7..7c2eaa1a45db 100644 --- a/vllm/v1/kv_offload/worker/cpu_gpu.py +++ b/vllm/v1/kv_offload/worker/cpu_gpu.py @@ -9,7 +9,7 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.utils.platform_utils import is_pin_memory_available -from vllm.v1.kv_offload.mediums import BlockIDsLoadStoreSpec +from vllm.v1.kv_offload.mediums import BlockIDsLoadStoreSpec, GPULoadStoreSpec from vllm.v1.kv_offload.spec import CanonicalKVCacheRef, CanonicalKVCaches from vllm.v1.kv_offload.worker.worker import ( OffloadingHandler, @@ -34,6 +34,8 @@ def expand_block_ids( block_size_factor: int, output: np.ndarray, skip_count: int = 0, + block_offsets: np.ndarray | None = None, + block_counts: np.ndarray | None = None, ): """ Convert a list of block IDs to a list of matching block ids, @@ -49,6 +51,26 @@ def expand_block_ids( and 3 maps to [12, 13, 14, 15] """ assert skip_count < block_size_factor + if block_offsets is not None or block_counts is not None: + assert block_offsets is not None and block_counts is not None + assert len(block_offsets) == len(block_ids) + assert len(block_counts) == len(block_ids) + + output_idx = 0 + for block_id, block_offset, block_count in zip( + block_ids, block_offsets, block_counts + ): + assert block_offset >= 0 + assert block_count >= 0 + assert block_offset + block_count <= block_size_factor + base_block_id = block_id * block_size_factor + output_end_idx = output_idx + block_count + output[output_idx:output_end_idx] = base_block_id + np.arange( + block_offset, block_offset + block_count + ) + output_idx = output_end_idx + assert output_idx == len(output) + return first_range = np.arange(skip_count, block_size_factor) full_range = np.arange(0, block_size_factor) @@ -62,6 +84,67 @@ def expand_block_ids( output_idx = output_end_idx +def build_transfer_indices( + src_spec: BlockIDsLoadStoreSpec, + dst_spec: BlockIDsLoadStoreSpec, + src_block_size_factor: int, + dst_block_size_factor: int, +) -> np.ndarray: + src_blocks = src_spec.block_ids + dst_blocks = dst_spec.block_ids + assert src_blocks.ndim == 1 + assert dst_blocks.ndim == 1 + + src_block_offsets = ( + src_spec.block_offsets if isinstance(src_spec, GPULoadStoreSpec) else None + ) + src_block_counts = ( + src_spec.block_counts if isinstance(src_spec, GPULoadStoreSpec) else None + ) + dst_block_offsets = ( + dst_spec.block_offsets if isinstance(dst_spec, GPULoadStoreSpec) else None + ) + dst_block_counts = ( + dst_spec.block_counts if isinstance(dst_spec, GPULoadStoreSpec) else None + ) + + src_sub_block_count = ( + int(np.sum(src_block_counts)) + if src_block_counts is not None + else src_blocks.size * src_block_size_factor + ) + dst_sub_block_count = ( + int(np.sum(dst_block_counts)) + if dst_block_counts is not None + else dst_blocks.size * dst_block_size_factor + ) + + src_sub_blocks_to_skip = 0 + if src_block_counts is None and dst_block_counts is None: + src_sub_blocks_to_skip = -dst_blocks.size % src_block_size_factor + assert dst_sub_block_count == src_sub_block_count - src_sub_blocks_to_skip + else: + assert dst_sub_block_count == src_sub_block_count + + src_to_dst = np.empty((dst_sub_block_count, 2), dtype=np.int64) + expand_block_ids( + src_blocks, + src_block_size_factor, + src_to_dst[:, 0], + skip_count=src_sub_blocks_to_skip, + block_offsets=src_block_offsets, + block_counts=src_block_counts, + ) + expand_block_ids( + dst_blocks, + dst_block_size_factor, + src_to_dst[:, 1], + block_offsets=dst_block_offsets, + block_counts=dst_block_counts, + ) + return src_to_dst + + class SingleDirectionOffloadingHandler(OffloadingHandler): """ SingleDirectionOffloadingHandler handles transfers for a single direction, @@ -154,25 +237,13 @@ def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool: assert isinstance(src_spec, BlockIDsLoadStoreSpec) assert isinstance(dst_spec, BlockIDsLoadStoreSpec) - src_blocks = src_spec.block_ids - dst_blocks = dst_spec.block_ids - assert src_blocks.ndim == 1 - assert dst_blocks.ndim == 1 - - src_sub_block_count = src_blocks.size * self.src_block_size_factor - dst_sub_block_count = dst_blocks.size * self.dst_block_size_factor - src_sub_blocks_to_skip = -dst_blocks.size % self.src_block_size_factor - - assert dst_sub_block_count == src_sub_block_count - src_sub_blocks_to_skip - - src_to_dst = np.empty((dst_sub_block_count, 2), dtype=np.int64) - expand_block_ids( - src_blocks, - self.src_block_size_factor, - src_to_dst[:, 0], - skip_count=src_sub_blocks_to_skip, + src_to_dst = build_transfer_indices( + src_spec, + dst_spec, + src_block_size_factor=self.src_block_size_factor, + dst_block_size_factor=self.dst_block_size_factor, ) - expand_block_ids(dst_blocks, self.dst_block_size_factor, src_to_dst[:, 1]) + dst_sub_block_count = src_to_dst.shape[0] src_to_dst_tensor = torch.from_numpy(src_to_dst) stream = self._stream_pool.pop() if self._stream_pool else torch.cuda.Stream() diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 5d5877d1692e..0559dd326039 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -1118,9 +1118,16 @@ def record( # Labeled prompt token counters by source pts = iteration_stats.prompt_token_stats for source in PromptTokenStats.ALL_SOURCES: - self.counter_prompt_tokens_by_source[source][engine_idx].inc( - pts.get_by_source(source) - ) + value = pts.get_by_source(source) + if value < 0: + logger.warning( + "Negative prompt_tokens_by_source[%s]=%d " + "(external KV transfer accounting skew), clamping to 0", + source, + value, + ) + value = 0 + self.counter_prompt_tokens_by_source[source][engine_idx].inc(value) self.counter_prompt_tokens_cached[engine_idx].inc(pts.cached_tokens) self.counter_prompt_tokens_recomputed[engine_idx].inc(pts.recomputed_tokens) self.counter_generation_tokens[engine_idx].inc(