diff --git a/docs/source/user_guide/feature_guide/index.md b/docs/source/user_guide/feature_guide/index.md index a8732b177c9..977ac53c053 100644 --- a/docs/source/user_guide/feature_guide/index.md +++ b/docs/source/user_guide/feature_guide/index.md @@ -13,6 +13,6 @@ lora eplb_swift_balancer netloader dynamic_batch -kv_pool_mooncake +kv_pool external_dp ::: diff --git a/docs/source/user_guide/feature_guide/kv_pool_mooncake.md b/docs/source/user_guide/feature_guide/kv_pool.md similarity index 84% rename from docs/source/user_guide/feature_guide/kv_pool_mooncake.md rename to docs/source/user_guide/feature_guide/kv_pool.md index 9188d7d1354..4b5ec13f8ac 100644 --- a/docs/source/user_guide/feature_guide/kv_pool_mooncake.md +++ b/docs/source/user_guide/feature_guide/kv_pool.md @@ -1,4 +1,4 @@ -# Mooncacke Store Deployment Guide +# Ascend Store Deployment Guide ## Environmental Dependencies @@ -8,27 +8,30 @@ * PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250724 * vLLM:main branch * vLLM-Ascend:main branch - * Mooncake:main branch - - Installation and Compilation Guide:https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#build-and-use-binaries - - Make sure to build with `-DUSE_ASCEND_DIRECT` to enable ADXL engine. - - An example command for compiling ADXL: - - `rm -rf build && mkdir -p build && cd build \ && cmake .. -DCMAKE_INSTALL_PREFIX=/opt/transfer-engine/ -DCMAKE_POLICY_VERSION_MINIMUM=3.5 -DUSE_ASCEND_DIRECT=ON -DBUILD_SHARED_LIBS=ON -DBUILD_UNIT_TESTS=OFF \ && make -j \ && make install` - - Also, you need to set environment variables to point to them `export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib64/python3.11/site-packages/mooncake`, or copy the .so files to the `/usr/local/lib64` directory after compilation ### KV Pooling Parameter Description **kv_connector_extra_config**: Additional Configurable Parameters for Pooling. -**mooncake_rpc_port**: Port for RPC Communication Between Pooling Scheduler Process and Worker Process: Each Instance Requires a Unique Port Configuration. +**lookup_rpc_port**: Port for RPC Communication Between Pooling Scheduler Process and Worker Process: Each Instance Requires a Unique Port Configuration. **load_async**: Whether to Enable Asynchronous Loading. The default value is false. -**register_buffer**: Whether to Register Video Memory with the Backend. Registration is Not Required When Used with MooncakeConnectorV1; It is Required in All Other Cases. The Default Value is false. +**backend**: Set the storage backend for kvpool, with the default being mooncake. + +## Example of using Mooncake as a KVCache pooling backend +* Software: + * Mooncake:main branch + + Installation and Compilation Guide:https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#build-and-use-binaries -## Run Mooncake Master + Make sure to build with `-DUSE_ASCEND_DIRECT` to enable ADXL engine. -### 1.Configure mooncake.json + An example command for compiling ADXL: + + `rm -rf build && mkdir -p build && cd build \ && cmake .. -DCMAKE_INSTALL_PREFIX=/opt/transfer-engine/ -DCMAKE_POLICY_VERSION_MINIMUM=3.5 -DUSE_ASCEND_DIRECT=ON -DBUILD_SHARED_LIBS=ON -DBUILD_UNIT_TESTS=OFF \ && make -j \ && make install` + + Also, you need to set environment variables to point to them `export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib64/python3.11/site-packages/mooncake`, or copy the .so files to the `/usr/local/lib64` directory after compilation + +### run mooncake master + +#### 1.Configure mooncake.json The environment variable **MOONCAKE_CONFIG_PATH** is configured to the full path where mooncake.json is located. @@ -54,7 +57,7 @@ The environment variable **MOONCAKE_CONFIG_PATH** is configured to the full path **master_server_address**: Configured with the IP and port of the master service. **global_segment_size**: Expands the kvcache size registered by the PD node to the master. -### 2. Start mooncake_master +#### 2. Start mooncake_master Under the mooncake folder: @@ -64,9 +67,9 @@ mooncake_master --port 50088 --eviction_high_watermark_ratio 0.95 --eviction_rat `eviction_high_watermark_ratio` determines the watermark where Mooncake Store will perform eviction,and `eviction_ratio` determines the portion of stored objects that would be evicted. -## Pooling and Prefill Decode Disaggregate Scenario +### Pooling and Prefill Decode Disaggregate Scenario -### 1.Run `prefill` Node and `decode` Node +#### 1.Run `prefill` Node and `decode` Node Using MultiConnector to simultaneously utilize both p2p connectors and pooled connectors. P2P performs kv_transfer, while pooling creates a larger prefix-cache. @@ -123,9 +126,10 @@ python3 -m vllm.entrypoints.openai.api_server \ } }, { - "kv_connector": "MooncakeConnectorStoreV1", + "kv_connector": "AscendStoreConnector", "kv_role": "kv_producer", - "mooncake_rpc_port":"0" + "lookup_rpc_port":"0", + "backend": "mooncake" } ] } @@ -185,16 +189,17 @@ python3 -m vllm.entrypoints.openai.api_server \ } }, { - "kv_connector": "MooncakeConnectorStoreV1", + "kv_connector": "AscendStoreConnector", "kv_role": "kv_consumer", - "mooncake_rpc_port":"1" + "lookup_rpc_port":"1", + "backend": "mooncake" } ] } }' > d.log 2>&1 ``` -### 2、Start proxy_server. +#### 2、Start proxy_server. ``` bash proxy.sh @@ -212,7 +217,7 @@ python vllm-ascend/examples/disaggregated_prefill_v1/load_balance_proxy_server_e --decoder-ports 8200 \ ``` -### 3. Run Inference +#### 3. Run Inference Configure the localhost, port, and model weight path in the command to your own settings. @@ -228,9 +233,9 @@ Long question: curl -s http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{ "model": "/xxxxx/Qwen2.5-7B-Instruct", "prompt": "Given the accelerating impacts of climate change—including rising sea levels, increasing frequency of extreme weather events, loss of biodiversity, and adverse effects on agriculture and human health—there is an urgent need for a robust, globally coordinated response. However, international efforts are complicated by a range of factors: economic disparities between high-income and low-income countries, differing levels of industrialization, varying access to clean energy technologies, and divergent political systems that influence climate policy implementation. In this context, how can global agreements like the Paris Accord be redesigned or strengthened to not only encourage but effectively enforce emission reduction targets? Furthermore, what mechanisms can be introduced to promote fair and transparent technology transfer, provide adequate financial support for climate adaptation in vulnerable regions, and hold nations accountable without exacerbating existing geopolitical tensions or disproportionately burdening those with historically lower emissions?", "max_tokens": 256, "temperature":0.0 }' ``` -## Pooling and Mixed Deployment Scenario +### Pooling and Mixed Deployment Scenario -### 1、Run Mixed Department Script +#### 1、Run Mixed Department Script The mixed script is essentially a pure pooling scenario for the P node. @@ -263,19 +268,17 @@ python3 -m vllm.entrypoints.openai.api_server \ --max-num-batched-tokens 4096 \ --kv-transfer-config \ '{ - "kv_connector": "MooncakeConnectorStoreV1", + "kv_connector": "AscendStoreConnector", "kv_role": "kv_both", "kv_connector_extra_config": { - "register_buffer": true, "use_layerwise": false, - "mooncake_rpc_port":"0" + "lookup_rpc_port":"1", + "backend": "mooncake" } }' > mix.log 2>&1 ``` -`register_buffer` is set to `false` by default and need to be set to `true` only in PD-mixed scenario. - -### 2. Run Inference +#### 2. Run Inference Configure the localhost, port, and model weight path in the command to your own settings. The requests sent will only go to the port where the mixed deployment script is located, and there is no need to start a separate proxy. diff --git a/tests/ut/distributed/mooncake/test_config_data.py b/tests/ut/distributed/mooncake/test_config_data.py index 4408b41a825..bd8d07930f4 100644 --- a/tests/ut/distributed/mooncake/test_config_data.py +++ b/tests/ut/distributed/mooncake/test_config_data.py @@ -1,6 +1,13 @@ +import sys +import types import unittest +from unittest.mock import MagicMock -from vllm_ascend.distributed.mooncake.config_data import ( +fake_engine = types.ModuleType("mooncake.engine") +fake_engine.TransferEngine = MagicMock() # type: ignore[attr-defined] +sys.modules["mooncake.engine"] = fake_engine + +from vllm_ascend.distributed.kvpool.backend.mooncake_backend import ( # noqa: E402 _convert_to_bytes, _parse_global_segment_size) diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py index 8d21a02b9e3..a0edff8e3f3 100644 --- a/tests/ut/kv_connector/test_mooncake_connector.py +++ b/tests/ut/kv_connector/test_mooncake_connector.py @@ -1051,7 +1051,7 @@ def setUp(self): 'vllm_ascend.distributed.mooncake_connector.string_to_int64_hash', mock_string_to_int64_hash), patch( - 'vllm_ascend.distributed.mooncake.transfer_engine.TransferEngine', + 'vllm_ascend.distributed.mooncake_transfer_engine.TransferEngine', return_value=self.mock_transfer_engine), patch( 'vllm_ascend.distributed.mooncake_connector.KVCacheSendingThread', diff --git a/vllm_ascend/distributed/__init__.py b/vllm_ascend/distributed/__init__.py index 0915b38a519..04195d1cc5b 100644 --- a/vllm_ascend/distributed/__init__.py +++ b/vllm_ascend/distributed/__init__.py @@ -31,8 +31,13 @@ def register_connector(): KVConnectorFactory.register_connector( "MooncakeConnectorStoreV1", - "vllm_ascend.distributed.mooncake.mooncake_store_connector_v1", - "MooncakeConnectorV1") + "vllm_ascend.distributed.kvpool.ascend_store_connector", + "AscendStoreConnector") + + KVConnectorFactory.register_connector( + "AscendStoreConnector", + "vllm_ascend.distributed.kvpool.ascend_store_connector", + "AscendStoreConnector") KVConnectorFactory.register_connector( "MooncakeLayerwiseConnector", diff --git a/vllm_ascend/distributed/cpu_offload_connector.py b/vllm_ascend/distributed/cpu_offload_connector.py index 2e91f715232..c6983b69e23 100644 --- a/vllm_ascend/distributed/cpu_offload_connector.py +++ b/vllm_ascend/distributed/cpu_offload_connector.py @@ -29,6 +29,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request @@ -58,7 +59,10 @@ class CPUOffloadingConnectorMetadata(KVConnectorMetadata): class CPUOffloadingConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + def __init__(self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: Optional[KVCacheConfig] = None): if not vllm_config.cache_config.enable_prefix_caching: self.connector_scheduler: Optional[ CPUOffloadingConnectorScheduler] = None diff --git a/vllm_ascend/distributed/kvpool/__init__.py b/vllm_ascend/distributed/kvpool/__init__.py new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/vllm_ascend/distributed/kvpool/__init__.py @@ -0,0 +1 @@ + diff --git a/vllm_ascend/distributed/kvpool/ascend_store_connector.py b/vllm_ascend/distributed/kvpool/ascend_store_connector.py new file mode 100644 index 00000000000..9f4833555db --- /dev/null +++ b/vllm_ascend/distributed/kvpool/ascend_store_connector.py @@ -0,0 +1,194 @@ +import threading +from typing import Any, Optional + +import torch +import zmq +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.forward_context import ForwardContext +from vllm.utils import logger +from vllm.utils.network_utils import make_zmq_socket +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.request import Request +from vllm.v1.serial_utils import MsgpackDecoder + +from vllm_ascend.distributed.kvpool.pool_scheduler import ( + KVPoolScheduler, get_zmq_rpc_path_lookup) +from vllm_ascend.distributed.kvpool.pool_worker import KVPoolWorker + + +class AscendStoreConnector(KVConnectorBase_V1): + + def __init__(self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: Optional[KVCacheConfig] = None): + super().__init__(vllm_config=vllm_config, + role=role, + kv_cache_config=kv_cache_config) + self.kv_role = vllm_config.kv_transfer_config.kv_role + + self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "use_layerwise", False) + + connector_name = vllm_config.kv_transfer_config.kv_connector + if connector_name == "MooncakeConnectorStoreV1": + logger.warning( + "It is recommended to use the AscendStoreConnector, as the MoonCakeStoreConnector will be removed in the future." + ) + + self.kv_caches: dict[str, torch.Tensor] = {} + + self._block_size = vllm_config.cache_config.block_size + + self.sended_but_unfinished_reqs: set[str] = set() + + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler = KVPoolScheduler(vllm_config, + self.use_layerwise) + else: + self.connector_worker = KVPoolWorker( + vllm_config, + self.use_layerwise, + ) + + assert self.connector_worker is not None + if vllm_config.parallel_config.rank == 0: + self.lookup_server = LookupKeyServer(self.connector_worker, + vllm_config, + self.use_layerwise) + + ############################################################ + # Scheduler Side Methods + ############################################################ + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + assert self.connector_worker is not None + self.connector_worker.start_load_kv(self._get_connector_metadata()) + + def wait_for_layer_load(self, layer_name: str) -> None: + if not self.use_layerwise: + return + self.connector_worker.wait_for_layer_load() + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + if not self.use_layerwise: + return + + if self.kv_role == "kv_consumer": + # Don't do save if the role is kv_consumer + return + self.connector_worker.save_kv_layer(self._get_connector_metadata()) + + def wait_for_save(self): + if self.kv_role == "kv_consumer": + # Don't do save if the role is kv_consumer + return + + if self.use_layerwise: + return + + self.connector_worker.wait_for_save(self._get_connector_metadata()) + + def get_finished(self, + finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + """Get the finished recving and sending requests.""" + assert self.connector_worker is not None + meta = self._get_connector_metadata() + done_sending, done_recving = self.connector_worker.get_finished() + sended_and_finished: set[str] = set() + for item in list(self.sended_but_unfinished_reqs): + if item not in meta.unfinished_request_ids: + sended_and_finished.add(item) + self.sended_but_unfinished_reqs.remove(item) + for item in done_sending: + if item in meta.unfinished_request_ids: + self.sended_but_unfinished_reqs.add(item) + else: + sended_and_finished.add(item) + + return sended_and_finished, done_recving + + +class LookupKeyServer: + + def __init__( + self, + pool_worker: KVPoolWorker, + vllm_config: "VllmConfig", + use_layerwise: bool, + ): + self.decoder = MsgpackDecoder() + self.decoder_tensor = MsgpackDecoder(torch.Tensor) + self.ctx = zmq.Context() # type: ignore[attr-defined] + socket_path = get_zmq_rpc_path_lookup(vllm_config) + self.socket = make_zmq_socket( + self.ctx, + socket_path, + zmq.REP, # type: ignore[attr-defined] + bind=True, + ) + + self.pool_worker = pool_worker + self.running = True + self.use_layerwise = use_layerwise + + def process_request(): + while self.running: + all_frames = self.socket.recv_multipart(copy=False) + token_len = int.from_bytes(all_frames[0], byteorder="big") + hash_frames = all_frames[1:] + hashes_str = self.decoder.decode(hash_frames) + result = self.pool_worker.lookup_scheduler( + token_len, hashes_str, self.use_layerwise) + response = result.to_bytes(4, "big") + self.socket.send(response) + + self.thread = threading.Thread(target=process_request, daemon=True) + self.thread.start() + + def close(self): + self.socket.close(linger=0) + # TODO: close the thread! diff --git a/vllm_ascend/distributed/kvpool/backend/__init__.py b/vllm_ascend/distributed/kvpool/backend/__init__.py new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/vllm_ascend/distributed/kvpool/backend/__init__.py @@ -0,0 +1 @@ + diff --git a/vllm_ascend/distributed/kvpool/backend/backend.py b/vllm_ascend/distributed/kvpool/backend/backend.py new file mode 100644 index 00000000000..3aeccbf352c --- /dev/null +++ b/vllm_ascend/distributed/kvpool/backend/backend.py @@ -0,0 +1,29 @@ +from abc import ABC, abstractmethod + +from vllm.config import ParallelConfig + + +class Backend(ABC): + + def __init__(self, parallel_config: ParallelConfig): + pass + + def set_device(self): + pass + + def register_buffer(self, ptrs: list[int], lengths: list[int]): + pass + + @abstractmethod + def exists(self, keys: list[str]) -> list[int]: + pass + + @abstractmethod + def put(self, keys: list[str], addrs: list[list[int]], + sizes: list[list[int]]): + pass + + @abstractmethod + def get(self, keys: list[str], addrs: list[list[int]], + sizes: list[list[int]]): + pass diff --git a/vllm_ascend/distributed/kvpool/backend/memcache_backend.py b/vllm_ascend/distributed/kvpool/backend/memcache_backend.py new file mode 100644 index 00000000000..0da6d092c4f --- /dev/null +++ b/vllm_ascend/distributed/kvpool/backend/memcache_backend.py @@ -0,0 +1,74 @@ +# Standard +from enum import Enum + +import torch +from vllm.config import ParallelConfig +from vllm.utils import logger + +from vllm_ascend.distributed.kvpool.backend.backend import Backend + + +class MmcDirect(Enum): + COPY_L2G = 0 + COPY_G2L = 1 + COPY_G2H = 2 + COPY_H2G = 3 + + +class MemcacheBackend(Backend): + + def __init__(self, parallel_config: ParallelConfig): + try: + from memcache import DistributedObjectStore # type: ignore + except ImportError as e: + raise ImportError( + "Please install memcache by following the instructions at " + "https://gitee.com/ascend/memfabric_hybrid " # noqa: E501 + "to run vLLM with MemcacheConnector.") from e + try: + self.rank = parallel_config.rank + self.store = DistributedObjectStore() + res = self.store.init(self.rank) + assert res == 0 + except ValueError as e: + logger.error("Configuration loading failed: %s", e) + raise + except Exception as exc: + logger.error( + "An error occurred while loading the configuration: %s", exc) + raise + + def set_device(self): + device = torch.device(f"npu:{self.rank}") + torch.npu.set_device(device) + + def register_buffer(self, ptrs: list[int], sizes: list[int]): + for ptr, size in zip(ptrs, sizes): + ret_value = self.store.register_buffer(ptr, size) + if ret_value != 0: + raise RuntimeError("Memcache memory registration failed.") + + def exists(self, keys: list[str]) -> list[int]: + return self.store.batch_is_exist(keys) + + def get(self, key: list[str], addr: list[list[int]], + size: list[list[int]]): + try: + res = self.store.batch_get_into_layers(key, addr, size, + MmcDirect.COPY_G2L.value) + for value in res: + if value != 0: + logger.error(f"Failed to get key {key},res:{res}") + except Exception as e: + logger.error(f"Failed to get key {key}. {e}") + + def put(self, key: list[str], addr: list[list[int]], + size: list[list[int]]): + try: + res = self.store.batch_put_from_layers(key, addr, size, + MmcDirect.COPY_L2G.value) + for value in res: + if value != 0: + logger.error(f"Failed to get key {key},res:{res}") + except Exception as e: + logger.error(f"Failed to put key {key},error:{e}") diff --git a/vllm_ascend/distributed/kvpool/backend/mooncake_backend.py b/vllm_ascend/distributed/kvpool/backend/mooncake_backend.py new file mode 100644 index 00000000000..314c4dcc9b4 --- /dev/null +++ b/vllm_ascend/distributed/kvpool/backend/mooncake_backend.py @@ -0,0 +1,188 @@ +# Standard +import json +import os +import re +from dataclasses import dataclass +from typing import Union + +# Third Party +from vllm.config import ParallelConfig +from vllm.utils import logger +from vllm.utils.network_utils import get_ip + +from vllm_ascend.distributed.kvpool.backend.backend import Backend +from vllm_ascend.distributed.mooncake_transfer_engine import global_te + +DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB +DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB + + +class MooncakeBackend(Backend): + + def __init__(self, parallel_config: ParallelConfig): + try: + from mooncake.store import MooncakeDistributedStore # type: ignore + except ImportError as e: + raise ImportError( + "Please install mooncake by following the instructions at " + "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 + "to run vLLM with MooncakeConnector.") from e + self.config = MooncakeStoreConfig.load_from_env() + self.store = MooncakeDistributedStore() + if self.config.protocol == "ascend": + local_hostname = get_ip() + transfer_engine = global_te.get_transfer_engine(local_hostname, + device_name=None) + self.local_seg = local_hostname + ":" + str( + transfer_engine.get_rpc_port()) + ret = self.store.setup(self.local_seg, self.config.metadata_server, + self.config.global_segment_size, + self.config.local_buffer_size, + self.config.protocol, + self.config.device_name, + self.config.master_server_address, + transfer_engine.get_engine()) + if ret != 0: + msg = "Initialize mooncake failed." + logger.error(msg) + raise RuntimeError(msg) + + def register_buffer(self, ptrs: list[int], lengths: list[int]): + global_te.register_buffer(ptrs, lengths) + + def exists(self, keys: list[str]) -> list[int]: + return self.store.batch_is_exist(keys) + + def put(self, keys: list[str], addrs: list[list[int]], + sizes: list[list[int]]): + try: + res = self.store.batch_put_from_multi_buffers(keys, addrs, sizes) + for value in res: + if value < 0: + logger.error(f"Failed to put key {keys},res:{res}") + except Exception as e: + logger.error(f"Failed to put key {keys},error:{e}") + + def get(self, keys: list[str], addrs: list[list[int]], + sizes: list[list[int]]): + try: + res = self.store.batch_get_into_multi_buffers(keys, addrs, sizes) + for value in res: + if value < 0: + logger.error(f"Failed to get key {keys}, res:{res}") + except Exception as e: + logger.error(f"Failed to get key {keys}, error:{e}") + + +@dataclass +class MooncakeStoreConfig: + local_hostname: str + metadata_server: str + global_segment_size: Union[int, str] + local_buffer_size: int + protocol: str + device_name: str + master_server_address: str + use_ascend_direct: bool + + @staticmethod + def from_file(file_path: str) -> "MooncakeStoreConfig": + with open(file_path) as file: + config = json.load(file) + return MooncakeStoreConfig( + local_hostname=config.get("local_hostname"), + metadata_server=config.get("metadata_server"), + global_segment_size=_parse_global_segment_size( + config.get("global_segment_size", + DEFAULT_GLOBAL_SEGMENT_SIZE)), + local_buffer_size=(config.get("local_buffer_size", + DEFAULT_LOCAL_BUFFER_SIZE)), + protocol=config.get("protocol", "tcp"), + device_name=config.get("device_name", ""), + master_server_address=config.get("master_server_address"), + use_ascend_direct=config.get("use_ascend_direct", False)) + + @staticmethod + def load_from_env() -> "MooncakeStoreConfig": + config_path = os.getenv("MOONCAKE_CONFIG_PATH") + if not config_path: + raise ValueError( + "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") + return MooncakeStoreConfig.from_file(config_path) + + +def _parse_global_segment_size(value) -> int: + """ + Parse storage size strings with support for units: GB, MB, KB, B + + Args: + value: Input value (int, str, or other convertible types) + + Returns: + int: Size in bytes + + Raises: + ValueError: For invalid format, missing number, or negative values + TypeError: For unsupported input types + """ + + if isinstance(value, int): + return value + elif not isinstance(value, str): + try: + return int(value) + except (TypeError, ValueError) as e: + raise TypeError( + f"Unsupported type for global_segment_size: {type(value)}" + ) from e + + cleaned_input = value.strip().lower() + if not cleaned_input: + raise ValueError("global segment size cannot be empty.") + + UNIT_MULTIPLIERS = { + 'gb': 1024**3, # 1 GB = 1024^3 bytes + 'mb': 1024**2, # 1 MB = 1024^2 bytes + 'kb': 1024, # 1 KB = 1024 bytes + 'b': 1 # 1 B = 1 byte + } + pattern = r'^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$' + match = re.match(pattern, cleaned_input) + + if not match: + raise ValueError(f"Invalid format: '{value}'") + + number_str = match.group(1) + unit = match.group(2) or 'b' + + multiplier = UNIT_MULTIPLIERS[unit] + return _convert_to_bytes(number_str, multiplier, value) + + +def _convert_to_bytes(number_str: str, multiplier: int, + original_input: str) -> int: + """ + Convert numeric string to byte count + + Args: + number_str: Numeric portion of input + multiplier: Unit conversion factor + original_input: Original input string (for error messages) + + Returns: + int: Byte count + + Raises: + ValueError: For invalid numbers or negative results + """ + try: + numeric_value = float(number_str) + except ValueError: + raise ValueError( + f"Invalid numeric value '{number_str}' in: '{original_input}'") + # Calculate byte count + try: + byte_count = int(numeric_value * multiplier) + except OverflowError: + raise ValueError(f"Storage size too large: '{original_input}'") + return byte_count diff --git a/vllm_ascend/distributed/kvpool/config_data.py b/vllm_ascend/distributed/kvpool/config_data.py new file mode 100644 index 00000000000..e3b0873d686 --- /dev/null +++ b/vllm_ascend/distributed/kvpool/config_data.py @@ -0,0 +1,364 @@ +from dataclasses import dataclass +from typing import Iterable, List, Optional, Tuple, Union + +from vllm.distributed.kv_transfer.kv_connector.v1.base import \ + KVConnectorMetadata +from vllm.utils import logger +from vllm.utils.math_utils import cdiv +from vllm.v1.core.kv_cache_utils import BlockHash +from vllm.v1.core.sched.output import NewRequestData + + +#Parameters related to the key +@dataclass +class KeyMetadata: + """name of the LLM model""" + + model_name: str + """ worker id when running under a distributed setting """ + head_or_tp_rank: int + + +@dataclass(order=True) +class PoolKey: + key_metadata: KeyMetadata + chunk_hash: str + + def __hash__(self): + return hash(( + self.key_metadata.model_name, + self.key_metadata.head_or_tp_rank, + self.chunk_hash, + )) + + def to_string(self): + return ( + f"{self.key_metadata.model_name}" + f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}" + ) + + def split_layers(self, num_layers: int) -> List["LayerPoolKey"]: + """Split the key into multiple keys for each layer""" + keys = [] + for layer_id in range(num_layers): + keys.append( + LayerPoolKey( + self.key_metadata, + self.chunk_hash, + layer_id, + )) + return keys + + +@dataclass(order=True) +class LayerPoolKey(PoolKey): + """A key for the layer cache engine""" + + layer_id: int + + def __hash__(self): + return hash(( + self.key_metadata.model_name, + self.key_metadata.head_or_tp_rank, + self.chunk_hash, + self.layer_id, + )) + + def to_string(self): + return ( + f"{self.key_metadata.model_name}" + f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}@{self.layer_id}" + ) + + +class ChunkedTokenDatabase(): + + def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool): + self.metadata = metadata + self.block_size = block_size + self.use_mla = use_mla + self.kv_caches_base_addr: list[int] = [] + self.block_len: list[int] = [] + + def _make_key_by_hash(self, + chunk_hash: str, + layer_id: Optional[int] = None): + assert self.metadata is not None + return PoolKey( + self.metadata, + chunk_hash, + ) + + def set_kv_caches_base_addr(self, kv_caches_base_addr: list[int]): + self.kv_caches_base_addr = kv_caches_base_addr + + def set_block_len(self, block_len: list[int]): + self.block_len = block_len + + def prepare_value(self, start: int, end: int, block_ids: list[int]): + addr_list = [] + size_list = [] + block_id = block_ids[start // self.block_size] + for index, base_addr in enumerate(self.kv_caches_base_addr): + block_len = (self.block_len[index % 2] + if self.use_mla else self.block_len[0]) + + addr = base_addr + block_id * block_len + length = int(block_len / self.block_size * (end - start)) + addr_list.append(addr) + size_list.append(length) + return addr_list, size_list, block_id + + def prepare_value_layer(self, start: int, end: int, block_ids: list[int], + layer_id: int): + block_id = block_ids[start // self.block_size] + if self.use_mla: + addr_k = self.kv_caches_base_addr[layer_id * + 2] + block_id * self.block_len[0] + addr_v = self.kv_caches_base_addr[layer_id * 2 + + 1] + block_id * self.block_len[1] + length_k = int(self.block_len[0] / self.block_size * (end - start)) + length_v = int(self.block_len[1] / self.block_size * (end - start)) + size_list = [length_k, length_v] + else: + addr_k = self.kv_caches_base_addr[layer_id * + 2] + block_id * self.block_len[0] + addr_v = self.kv_caches_base_addr[layer_id * 2 + + 1] + block_id * self.block_len[0] + length = int(self.block_len[0] / self.block_size * (end - start)) + size_list = [length, length] + addr_list = [addr_k, addr_v] + return addr_list, size_list + + def process_tokens( + self, + token_len: int, + block_hashes: Union[list[BlockHash], list[str]], + mask_num: int = 0, + ) -> Iterable[Tuple[int, int, PoolKey]]: + """Process the tokens and return the corresponding cache engine keys. + + :param Union[torch.Tensor, List[int]] tokens: The tokens to process. + + :param Optional[torch.Tensor] mask: The mask for the tokens. Should + have the same length as tokens. And the mask should ALWAYS be like + FFFFFTTTTTTT, where True means the tokens needs to be matched, + and the Falses will ALWAYS be at the PREFIX of the tensor. + + :param bool make_key: Whether to make the cache engine key or not. + If False, the hash value will be returned instead. + + :returns: A iterable of tuples with three elements. The first element + is the start index of the tokens for the key. The second element + is the end index of the tokens for the key. The third element is + the cache engine key (or hash) for the tokens. + + :raises: ValueError if the number of Falses in the mask is not a + multiple of the chunk size. + """ + if not block_hashes: + return + if not isinstance(block_hashes[0], str): + block_hashes = [ + h.hex() # type: ignore[union-attr] + for h in block_hashes + ] + start_idx = 0 + for chunk_id, hash_val in enumerate(block_hashes): + start_idx = chunk_id * self.block_size + if start_idx >= token_len: + break + end_idx = min(start_idx + self.block_size, token_len) + if start_idx < mask_num: + continue + else: + yield start_idx, end_idx, self._make_key_by_hash(hash_val) + + +#Parameters related to the connector metadata +@dataclass +class LoadSpec: + # Number of tokens cached in vLLM + vllm_cached_tokens: int + # Number of tokens that are cached in kvpool + kvpool_cached_tokens: int + # Whether the scheduler allow us to load the tokens + can_load: bool + + +@dataclass +class RequestTracker: + # Request id + req_id: str + + # The token ids that has been scheduled so far + token_len: int + + # The block ids that has been allocated so far + # NOTE: allocated blocks could be more than the number of tokens + # FIXME: need to check whether the block ids will be changed after + # preemption + allocated_block_ids: list[int] + + # The number of tokens that has been savd + num_saved_tokens: int = 0 + + @staticmethod + def from_new_request( + new_request: "NewRequestData", + num_tokens_to_compute: int, + ) -> "RequestTracker": + """Create the request tracker from a new request. + + Args: + new_request (NewRequestData): the new request data. + num_tokens_to_compute (int): the number of tokens that will + be 'computed', including the `num_computed_tokens` (vLLM's + local cache hit) and new tokens that will be scheduled. + + """ + unfolded_block_ids = [] + + if not isinstance(new_request.block_ids[0], list): + unfolded_block_ids = new_request.block_ids.copy() + else: + unfolded_block_ids = new_request.block_ids[0].copy() + + return RequestTracker( + req_id=new_request.req_id, + token_len=num_tokens_to_compute, + allocated_block_ids=unfolded_block_ids, + num_saved_tokens=0, + ) + + def update( + self, + new_token_ids: list[int], + new_block_ids: Union[tuple[list[int], ...], list[int]], + ) -> None: + """Update the request tracker when a running request is + scheduled again + """ + + self.token_len = self.token_len + len(new_token_ids) + + if len(new_block_ids) == 0: + new_block_ids = [] + elif isinstance(new_block_ids, tuple): + new_block_ids = new_block_ids[0] + elif isinstance(new_block_ids, list): + pass + else: + raise ValueError( + f"Unsupported new_block_ids type {type(new_block_ids)}") + self.allocated_block_ids.extend(new_block_ids) + + +@dataclass +class ReqMeta: + # Request id + req_id: str + # Request tokens + token_len_chunk: int + + block_ids: list[int] + + block_hashes: list[BlockHash] + + can_save: Optional[bool] = None + # load_spec + load_spec: Optional[LoadSpec] = None + + is_last_chunk: Optional[bool] = None + + @staticmethod + def from_request_tracker( + tracker: RequestTracker, + block_size: int, + load_spec: Optional[LoadSpec] = None, + skip_save: Optional[bool] = False, + block_hashes: list[BlockHash] = [], + is_last_chunk: Optional[bool] = None, + discard_partial_chunks: bool = True, + ) -> Optional["ReqMeta"]: + """Create the request metadata from a request tracker. + + Args: + tracker (RequestTracker): the request tracker. + block_size (int): the block size in vLLM. + load_spec (Optional[LoadSpec]): the load spec for KV cache loading. + skip_save (bool): whether to skip the save operation. + discard_partial_chunks (bool): whether to discard partial chunks. + + Returns: + the request metadata if we need to perform load/save + operations, None otherwise. + """ + input_token_len = tracker.token_len + + # For save operation: do not save if the following condition is met + # 1. has already been saved before (num_saved_tokens > 0) + # 2. number of unsaved tokens is not reached the chunk boundary + chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) * + block_size if discard_partial_chunks else 0) + # Calculate number of tokens to save based on discard_partial_chunks + # setting + num_tokens_to_save = ((input_token_len // block_size * block_size) + if discard_partial_chunks else input_token_len) + + skip_save = skip_save or num_tokens_to_save < chunk_boundary + if skip_save and load_spec is None: + return None + + # If we need to save, update the number of saved tokens + if not skip_save: + tracker.num_saved_tokens = num_tokens_to_save + + # # For load operation: check whether the request is scheduled to load + if load_spec is not None and load_spec.can_load: + logger.debug( + "Scheduled to load %d tokens for request %s", + load_spec.kvpool_cached_tokens, + tracker.req_id, + ) + else: + # Do not load if not in `can_load` state + load_spec = None + logger.debug( + f"request:{tracker.req_id}, meta save spec:{not skip_save}, meta load spec:{load_spec}" + ) + return ReqMeta( + req_id=tracker.req_id, + token_len_chunk=num_tokens_to_save, + block_ids=tracker.allocated_block_ids, + can_save=not skip_save, + load_spec=load_spec, + block_hashes=block_hashes, + is_last_chunk=is_last_chunk, + ) + + +class AscendConnectorMetadata(KVConnectorMetadata): + + def __init__(self, unfinished_request_ids): + self.requests = [] + self.unfinished_request_ids = unfinished_request_ids + + def add_request(self, req_meta: ReqMeta) -> None: + """Add a request to the metadata. + + Args: + req_meta (ReqMeta): the request metadata. + """ + self.requests.append(req_meta) + + +@dataclass +class LasyerMultiBlockReqMeta: + req_id: str + keys: List[LayerPoolKey] + starts: List[int] + ends: list[int] + block_ids: list[int] + layer_id: int + is_last_chunk: bool = True diff --git a/vllm_ascend/distributed/kvpool/kv_transfer.py b/vllm_ascend/distributed/kvpool/kv_transfer.py new file mode 100644 index 00000000000..b30158ae8c2 --- /dev/null +++ b/vllm_ascend/distributed/kvpool/kv_transfer.py @@ -0,0 +1,246 @@ +import queue +import threading +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Optional + +import torch +from vllm.utils import logger +from vllm.v1.core.kv_cache_utils import BlockHash + +from vllm_ascend.distributed.kvpool.backend.backend import Backend + +# isort: off +from vllm_ascend.distributed.kvpool.config_data import (ChunkedTokenDatabase, + LasyerMultiBlockReqMeta + ) +# isort: on + + +class KVTransferThread(threading.Thread): + + def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, + tp_rank: int, ready_event: threading.Event, name: str): + super().__init__(daemon=True, name=name) + self.m_store = m_store + self.ready_event = ready_event + self.tp_rank = tp_rank + self.token_database = token_database + self.done_task_lock = threading.Lock() + self.request_queue: queue.Queue[Any] = queue.Queue() + # TODO(jianzs): make this configurable + self.executor = ThreadPoolExecutor(max_workers=32) + self.finished_requests: set[str] = set() + + def add_request( + self, + req_id: str, + token_len: int, + block_ids: list[int], + block_hashes: list[BlockHash], + mask_num: int = 0, + is_last_chunk: Optional[bool] = None, + ) -> torch.Tensor: + req = ({ + "req_id": req_id, + "token_len": token_len, + "block_ids": block_ids, + "block_hashes": block_hashes, + "mask_num": mask_num, + "is_last_chunk": is_last_chunk, + }) + self.request_queue.put(req) + + def get_and_clear_finished_requests(self) -> set[str]: + """ + Get and clear the requests that have been completed. + Returns: + A set of request IDs that have been completed. + """ + with self.done_task_lock: + finished_requests = self.finished_requests.copy() + self.finished_requests.clear() + return finished_requests + + def set_finished_request(self, req_id): + with self.done_task_lock: + self.finished_requests.add(req_id) + + def run(self): + """Run the thread to handle KV cache transfer requests.""" + self.m_store.set_device() + self.ready_event.set() + while True: + try: + request_data = self.request_queue.get() + if request_data is None: + logger.warning("Received a None request!") + self.request_queue.task_done() + continue + self._handle_request(request_data) + except Exception as e: + logger.error(f"Error in KVCacheTransferThread: {e}") + + def _handle_request(self, req_meta: dict[str, Any]): + pass + + +class KVCacheStoreSendingThread(KVTransferThread): + + def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, + tp_rank: int, put_step: int, ready_event: threading.Event): + super().__init__(m_store, + token_database, + tp_rank, + ready_event, + name="KVCacheSendingThread") + self.put_step = put_step + + def _handle_request(self, req_meta: dict[str, Any]): + token_len = req_meta["token_len"] + mask_num = req_meta["mask_num"] + block_ids = req_meta["block_ids"] + block_hashes = req_meta["block_hashes"] + req_id = req_meta["req_id"] + is_last_chunk = req_meta["is_last_chunk"] + addr_list = [] + size_list = [] + key_list = [] + for start, end, key in self.token_database.process_tokens( + token_len, block_hashes, mask_num): + addr, size, _ = self.token_database.prepare_value( + start, end, block_ids) + key_list.append(key.to_string()) + addr_list.append(addr) + size_list.append(size) + key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step] + addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step] + size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step] + if key_list_tp: + torch.npu.current_stream().synchronize() + self.m_store.put(key_list_tp, addr_list_tp, size_list_tp) + if is_last_chunk: + self.set_finished_request(req_id) + self.request_queue.task_done() + + +class KVCacheStoreRecvingThread(KVTransferThread): + + def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, + tp_rank: int, ready_event: threading.Event): + super().__init__(m_store, + token_database, + tp_rank, + ready_event, + name="KVCacheStoreRecvingThread") + + def _handle_request(self, req_meta: dict[str, Any]): + token_len = req_meta["token_len"] + mask_num = req_meta["mask_num"] + block_ids = req_meta["block_ids"] + req_id = req_meta["req_id"] + block_hashes = req_meta["block_hashes"] + addr_list = [] + size_list = [] + key_list = [] + for start, end, key in self.token_database.process_tokens( + token_len, block_hashes, mask_num): + addr, size, _ = self.token_database.prepare_value( + start, end, block_ids) + key_list.append(key.to_string()) + addr_list.append(addr) + size_list.append(size) + key_list_c = key_list[self.tp_rank % + len(key_list):] + key_list[:self.tp_rank % + len(key_list)] + addr_list_c = addr_list[self.tp_rank % + len(addr_list):] + addr_list[:self.tp_rank % + len(addr_list)] + size_list_c = size_list[self.tp_rank % + len(size_list):] + size_list[:self.tp_rank % + len(size_list)] + self.m_store.get(key_list_c, addr_list_c, size_list_c) + self.set_finished_request(req_id) + self.request_queue.task_done() + + +class KVCacheStoreLayerSendingThread(KVTransferThread): + + def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, + tp_rank: int, put_step: int, ready_event: threading.Event, + num_layers: int): + super().__init__(m_store, + token_database, + tp_rank, + ready_event, + name="KVCacheStoreLayerSendingThread") + self.final_layer_id = num_layers - 1 + self.put_step = put_step + + def add_request( # type: ignore[override] + self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: + self.request_queue.put(req_meta) + + def _handle_request( # type: ignore[override] + self, req_meta: LasyerMultiBlockReqMeta): + addr_list = [] + size_list = [] + key_list = [] + for index, key in enumerate(req_meta.keys): + addr, size = self.token_database.prepare_value_layer( + req_meta.starts[index], req_meta.ends[index], + req_meta.block_ids, req_meta.layer_id) + key_list.append(key.to_string()) + addr_list.append(addr) + size_list.append(size) + key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step] + addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step] + size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step] + if key_list_tp: + torch.npu.current_stream().synchronize() + self.m_store.put(key_list_tp, addr_list_tp, size_list_tp) + if req_meta.layer_id == self.final_layer_id and req_meta.is_last_chunk: + self.set_finished_request(req_meta.req_id) + self.request_queue.task_done() + + +class KVCacheStoreLayerRecvingThread(KVTransferThread): + + def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, + tp_rank: int, ready_event: threading.Event, + get_event: threading.Event): + super().__init__(m_store, + token_database, + tp_rank, + ready_event, + name="KVCacheStoreLayerRecvingThread") + self.get_event = get_event + + def add_request( # type: ignore[override] + self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: + self.request_queue.put(req_meta) + + def _handle_request( # type: ignore[override] + self, req_meta: LasyerMultiBlockReqMeta): + addr_list = [] + size_list = [] + key_list = [] + for index, key in enumerate(req_meta.keys): + addr, size = self.token_database.prepare_value_layer( + req_meta.starts[index], req_meta.ends[index], + req_meta.block_ids, req_meta.layer_id) + key_list.append(key.to_string()) + addr_list.append(addr) + size_list.append(size) + key_list_c = key_list[self.tp_rank % + len(key_list):] + key_list[:self.tp_rank % + len(key_list)] + addr_list_c = addr_list[self.tp_rank % + len(addr_list):] + addr_list[:self.tp_rank % + len(addr_list)] + size_list_c = size_list[self.tp_rank % + len(size_list):] + size_list[:self.tp_rank % + len(size_list)] + self.m_store.get(key_list_c, addr_list_c, size_list_c) + + self.request_queue.task_done() + self.get_event.set() diff --git a/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py b/vllm_ascend/distributed/kvpool/pool_scheduler.py similarity index 52% rename from vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py rename to vllm_ascend/distributed/kvpool/pool_scheduler.py index aad4dc6e9c3..06041b5a6e5 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py +++ b/vllm_ascend/distributed/kvpool/pool_scheduler.py @@ -1,174 +1,33 @@ -import threading from typing import Any, Optional -import torch import vllm.envs as envs import zmq -from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) -from vllm.forward_context import ForwardContext +from vllm.distributed.kv_transfer.kv_connector.v1.base import \ + KVConnectorMetadata from vllm.utils import logger from vllm.utils.network_utils import make_zmq_socket from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.kv_cache_utils import BlockHash from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.request import Request -from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder +from vllm.v1.serial_utils import MsgpackEncoder -from vllm_ascend.distributed.mooncake.config_data import ( - LoadSpec, MooncakeConnectorMetadata, ReqMeta, RequestTracker) -from vllm_ascend.distributed.mooncake.mooncake_engine import MooncakeEngine +from vllm_ascend.distributed.kvpool.config_data import ( + AscendConnectorMetadata, LoadSpec, ReqMeta, RequestTracker) -class MooncakeConnectorV1(KVConnectorBase_V1): - - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) - self.kv_role = vllm_config.kv_transfer_config.kv_role - - self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "use_layerwise", False) - - self.kv_caches: dict[str, torch.Tensor] = {} - - self._block_size = vllm_config.cache_config.block_size - - self.sended_but_unfinished_reqs: set[str] = set() - - if role == KVConnectorRole.SCHEDULER: - self.connector_scheduler = MooncakeStoreConnectorV1Scheduler( - vllm_config, self.use_layerwise) - else: - self.connector_worker = MooncakeEngine( - vllm_config, - self.use_layerwise, - ) - - assert self.connector_worker is not None - if vllm_config.parallel_config.rank == 0: - self.lookup_server = MooncakeLookupServer( - self.connector_worker, vllm_config, self.use_layerwise) - - ############################################################ - # Scheduler Side Methods - ############################################################ - - def get_num_new_matched_tokens( - self, request: "Request", - num_computed_tokens: int) -> tuple[int, bool]: - assert self.connector_scheduler is not None - return self.connector_scheduler.get_num_new_matched_tokens( - request, num_computed_tokens) - - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): - assert self.connector_scheduler is not None - return self.connector_scheduler.update_state_after_alloc( - request, blocks, num_external_tokens) - - def build_connector_meta( - self, - scheduler_output: SchedulerOutput, - ) -> KVConnectorMetadata: - assert self.connector_scheduler is not None - return self.connector_scheduler.build_connector_meta(scheduler_output) - - def request_finished( - self, - request: "Request", - block_ids: list[int], - ) -> tuple[bool, Optional[dict[str, Any]]]: - assert self.connector_scheduler is not None - return self.connector_scheduler.request_finished(request, block_ids) - - ############################################################ - # Worker Side Methods - ############################################################ - def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): - assert self.connector_worker is not None - self.connector_worker.register_kv_caches(kv_caches) - - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: - assert self.connector_worker is not None - assert isinstance(self._get_connector_metadata(), - MooncakeConnectorMetadata) - self.connector_worker.start_load_kv(self._get_connector_metadata()) - - def wait_for_layer_load(self, layer_name: str) -> None: - """MooncakeStoreConnector does not do layerwise saving.""" - if not self.use_layerwise: - return - self.connector_worker.wait_for_layer_load() - - def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: - """MooncakeStoreConnector does not save explicitly.""" - if not self.use_layerwise: - return - - if self.kv_role == "kv_consumer": - # Don't do save if the role is kv_consumer - return - self.connector_worker.save_kv_layer(self._get_connector_metadata()) - - def wait_for_save(self): - """MooncakeStoreConnector does not save explicitly.""" - if self.kv_role == "kv_consumer": - # Don't do save if the role is kv_consumer - return - - if self.use_layerwise: - self.connector_worker.wait_layer_transfer_finish() - return - - self.connector_worker.wait_for_save(self._get_connector_metadata()) - - def get_finished(self, - finished_req_ids: set[str]) -> tuple[set[str], set[str]]: - """Get the finished recving and sending requests.""" - assert self.connector_worker is not None - meta = self._get_connector_metadata() - done_sending, done_recving = self.connector_worker.get_finished() - sended_and_finished: set[str] = set() - for item in list(self.sended_but_unfinished_reqs): - if item not in meta.unfinished_request_ids: - sended_and_finished.add(item) - self.sended_but_unfinished_reqs.remove(item) - for item in done_sending: - if item in meta.unfinished_request_ids: - self.sended_but_unfinished_reqs.add(item) - else: - sended_and_finished.add(item) - - return sended_and_finished, done_recving - - -def get_zmq_rpc_path_mooncake( - vllm_config: Optional["VllmConfig"] = None, ) -> str: - base_url = envs.VLLM_RPC_BASE_PATH - # Default to 0 if not configured - rpc_port = 0 - if vllm_config is not None: - rpc_port = vllm_config.kv_transfer_config.get_from_extra_config( - "mooncake_rpc_port", 0) - logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port) - return f"ipc://{base_url}/mooncake_rpc_port_{rpc_port}" - - -class MooncakeStoreConnectorV1Scheduler: +class KVPoolScheduler: def __init__(self, vllm_config: "VllmConfig", use_layerwise): - self.client = MooncakeLookupClient(vllm_config) + self.client = LookupKeyClient(vllm_config) self.use_layerwise = use_layerwise self.kv_role = vllm_config.kv_transfer_config.kv_role self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get( "consumer_is_to_load", False) self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get( "load_async", False) - # request_id -> (vllm cached tokes, mooncake cached tokens) + # request_id -> (vllm cached tokes, kvpool cached tokens) self.load_specs: dict[str, LoadSpec] = {} self._block_size = vllm_config.cache_config.block_size # request_id -> full_token_ids @@ -201,14 +60,13 @@ def get_num_new_matched_tokens( return 0, False if self._discard_partial_chunks: - token_block_end = len(request.prompt_token_ids - ) // self._block_size * self._block_size - token_ids = torch.tensor( - request.prompt_token_ids[:token_block_end]) + token_len = len(request.prompt_token_ids + ) // self._block_size * self._block_size else: - token_ids = torch.tensor(request.prompt_token_ids) + token_len = len(request.prompt_token_ids) - num_external_hit_tokens = self.client.lookup(token_ids) + num_external_hit_tokens = self.client.lookup(token_len, + request.block_hashes) if num_external_hit_tokens == request.num_tokens: num_external_hit_tokens -= 1 @@ -216,7 +74,7 @@ def get_num_new_matched_tokens( need_to_allocate = num_external_hit_tokens - num_computed_tokens logger.info( - "Reqid: %s, Total tokens %d, mooncake hit tokens: %d, need to load: %d", + "Reqid: %s, Total tokens %d, kvpool hit tokens: %d, need to load: %d", request.request_id, request.num_tokens, num_external_hit_tokens, @@ -228,11 +86,11 @@ def get_num_new_matched_tokens( self.load_specs[request.request_id] = LoadSpec( vllm_cached_tokens=num_computed_tokens, - mooncake_cached_tokens=num_external_hit_tokens, + kvpool_cached_tokens=num_external_hit_tokens, can_load=False, ) - return need_to_allocate, self.load_async + return need_to_allocate, self.load_async and not self.use_layerwise def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", @@ -261,10 +119,10 @@ def update_state_after_alloc(self, request: "Request", assert ( num_external_tokens > 0 and num_external_tokens - == self.load_specs[request.request_id].mooncake_cached_tokens - + == self.load_specs[request.request_id].kvpool_cached_tokens - self.load_specs[request.request_id].vllm_cached_tokens ), (f"Mismatch in number of tokens: {num_external_tokens} vs " - f"{self.load_specs[request.request_id].mooncake_cached_tokens} - " + f"{self.load_specs[request.request_id].kvpool_cached_tokens} - " f"{self.load_specs[request.request_id].vllm_cached_tokens}" f" for request {request.request_id}") @@ -289,7 +147,7 @@ def build_connector_meta( self._unfinished_requests.pop(finished_req_id, None) self._unfinished_request_ids.discard(finished_req_id) - meta = MooncakeConnectorMetadata(self._unfinished_request_ids) + meta = AscendConnectorMetadata(self._unfinished_request_ids) for request in scheduler_output.scheduled_new_reqs: # Right now, we only load KV for new requests @@ -304,12 +162,15 @@ def build_connector_meta( self._block_size * self._block_size) if self._discard_partial_chunks else len( request.prompt_token_ids)) + request_tuple = self._unfinished_requests.get(request.req_id) + request_real = request_tuple[0] # type: ignore[index] req_meta = ReqMeta.from_request_tracker( request_tracker, self._block_size, load_spec=load_spec, skip_save=force_skip_save, - is_last_chunk=len(request_tracker.token_ids) + block_hashes=request_real.block_hashes, + is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num, discard_partial_chunks=self._discard_partial_chunks, ) @@ -317,33 +178,14 @@ def build_connector_meta( meta.add_request(req_meta) cached_reqs = scheduler_output.scheduled_cached_reqs - if isinstance(cached_reqs, list) and not force_skip_save: - for i, req in enumerate(cached_reqs): - request_tracker = self._request_trackers[req.req_id] - request_tracker.update(req.new_token_ids, req.new_block_ids) - last_chunk_tokens_num = ((len(req.prompt_token_ids) // - self._block_size * self._block_size) - if self._discard_partial_chunks else - len(req.prompt_token_ids)) - req_meta = ReqMeta.from_request_tracker( - request_tracker, - self._block_size, - load_spec=None, - skip_save=force_skip_save, - is_last_chunk=len(request_tracker.token_ids) - >= last_chunk_tokens_num, - discard_partial_chunks=self._discard_partial_chunks, - ) - if req_meta is not None: - meta.add_request(req_meta) - elif not force_skip_save: + if not force_skip_save: for i, req_id in enumerate(cached_reqs.req_ids): request_tracker = self._request_trackers[req_id] num_new_tokens = scheduler_output.num_scheduled_tokens[req_id] req_tuple = self._unfinished_requests.get(req_id) if req_tuple: request = req_tuple[0] - num_current_tokens = len(request_tracker.token_ids) + num_current_tokens = request_tracker.token_len new_token_ids = request.all_token_ids[ num_current_tokens:num_current_tokens + num_new_tokens] else: @@ -355,8 +197,7 @@ def build_connector_meta( continue request_tracker.update(new_token_ids, new_block_ids) # decode not save - if len(request_tracker.token_ids) > len( - request.prompt_token_ids): + if request_tracker.token_len > len(request.prompt_token_ids): continue last_chunk_tokens_num = ((len(request.prompt_token_ids) // @@ -368,7 +209,8 @@ def build_connector_meta( self._block_size, load_spec=None, skip_save=force_skip_save, - is_last_chunk=len(request_tracker.token_ids) + block_hashes=request.block_hashes, + is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num, discard_partial_chunks=self._discard_partial_chunks, ) @@ -384,15 +226,14 @@ def build_connector_meta( load_spec = self.load_specs.pop(request_id, None) if not load_spec: continue - num_tokens_to_compute = load_spec.mooncake_cached_tokens + num_tokens_to_compute = load_spec.kvpool_cached_tokens if (num_tokens_to_compute % self._block_size != 0) and (num_tokens_to_compute == len(request.prompt_token_ids) - 1): num_tokens_to_compute = num_tokens_to_compute + 1 request_tracker = RequestTracker( req_id=request_id, - token_ids=request.prompt_token_ids[:num_tokens_to_compute]. - copy(), + token_len=num_tokens_to_compute, allocated_block_ids=block_ids, num_saved_tokens=0, ) @@ -404,6 +245,7 @@ def build_connector_meta( self._block_size, load_spec=load_spec, skip_save=None, + block_hashes=request.block_hashes, discard_partial_chunks=self._discard_partial_chunks, ) if req_meta is not None: @@ -431,12 +273,12 @@ def request_finished( return delay_free_blocks, None -class MooncakeLookupClient: +class LookupKeyClient: def __init__(self, vllm_config: "VllmConfig"): self.encoder = MsgpackEncoder() self.ctx = zmq.Context() # type: ignore[attr-defined] - socket_path = get_zmq_rpc_path_mooncake(vllm_config) + socket_path = get_zmq_rpc_path_lookup(vllm_config) self.socket = make_zmq_socket( self.ctx, socket_path, @@ -444,9 +286,12 @@ def __init__(self, vllm_config: "VllmConfig"): bind=False, ) - def lookup(self, token_ids: torch.Tensor) -> int: - request = self.encoder.encode(token_ids) - self.socket.send_multipart(request, copy=False) + def lookup(self, token_len: int, block_hashes: list[BlockHash]) -> int: + hash_strs = [h.hex() for h in block_hashes] + hash_frames = self.encoder.encode(hash_strs) + token_len_bytes = token_len.to_bytes(4, byteorder="big") + all_frames = [token_len_bytes] + list(hash_frames) + self.socket.send_multipart(all_frames, copy=False) resp = self.socket.recv() result = int.from_bytes(resp, "big") return result @@ -455,39 +300,19 @@ def close(self): self.socket.close(linger=0) -class MooncakeLookupServer: - - def __init__( - self, - mooncake_engine: MooncakeEngine, - vllm_config: "VllmConfig", - use_layerwise: bool, - ): - self.decoder = MsgpackDecoder(torch.Tensor) - self.ctx = zmq.Context() # type: ignore[attr-defined] - socket_path = get_zmq_rpc_path_mooncake(vllm_config) - self.socket = make_zmq_socket( - self.ctx, - socket_path, - zmq.REP, # type: ignore[attr-defined] - bind=True, - ) - - self.mooncake_engine = mooncake_engine - self.running = True - - def process_request(): - while self.running: - frames = self.socket.recv_multipart(copy=False) - token_ids = self.decoder.decode(frames) - result = self.mooncake_engine.lookup_scheduler( - token_ids, use_layerwise) - response = result.to_bytes(4, "big") - self.socket.send(response) - - self.thread = threading.Thread(target=process_request, daemon=True) - self.thread.start() - - def close(self): - self.socket.close(linger=0) - # TODO: close the thread! +def get_zmq_rpc_path_lookup( + vllm_config: Optional["VllmConfig"] = None, ) -> str: + base_url = envs.VLLM_RPC_BASE_PATH + # Default to 0 if not configured + rpc_port = 0 + if vllm_config is not None: + extra_config = vllm_config.kv_transfer_config.kv_connector_extra_config + if "lookup_rpc_port" in extra_config: + rpc_port = extra_config["lookup_rpc_port"] + elif "mooncake_rpc_port" in extra_config: + rpc_port = extra_config["mooncake_rpc_port"] + logger.warning( + "It is recommended to use the lookup_rpc_port, as the mooncake_rpc_port will be removed in the future." + ) + logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port) + return f"ipc://{base_url}/lookup_rpc_port_{rpc_port}" diff --git a/vllm_ascend/distributed/mooncake/mooncake_engine.py b/vllm_ascend/distributed/kvpool/pool_worker.py similarity index 57% rename from vllm_ascend/distributed/mooncake/mooncake_engine.py rename to vllm_ascend/distributed/kvpool/pool_worker.py index 143d2c91cad..b03d2808928 100644 --- a/vllm_ascend/distributed/mooncake/mooncake_engine.py +++ b/vllm_ascend/distributed/kvpool/pool_worker.py @@ -1,25 +1,33 @@ # Standard import math import threading -import time -from typing import Generator, List, Optional, Union +from typing import Dict, Generator, Optional, Type # Third Party import torch from vllm.config import VllmConfig from vllm.utils import logger -from vllm.utils.torch_utils import get_kv_cache_torch_dtype - -from vllm_ascend.distributed.mooncake.config_data import ( - ChunkedTokenDatabase, LasyerMultiBlockReqMeta, MooncakeConnectorMetadata, - MooncakeEngineMetadata) -from vllm_ascend.distributed.mooncake.kv_transfer import ( +from vllm.v1.core.kv_cache_utils import BlockHash + +from vllm_ascend.distributed.kvpool.backend.backend import Backend +from vllm_ascend.distributed.kvpool.backend.memcache_backend import \ + MemcacheBackend +from vllm_ascend.distributed.kvpool.backend.mooncake_backend import \ + MooncakeBackend +from vllm_ascend.distributed.kvpool.config_data import ( + AscendConnectorMetadata, ChunkedTokenDatabase, KeyMetadata, + LasyerMultiBlockReqMeta) +from vllm_ascend.distributed.kvpool.kv_transfer import ( KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread, KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread) -from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore + +backend_map: Dict[str, Type[Backend]] = { + "mooncake": MooncakeBackend, + "memcache": MemcacheBackend, +} -class MooncakeEngine: +class KVPoolWorker: #The main class for the cache engine. def __init__( @@ -29,6 +37,7 @@ def __init__( ): model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config + self.dp_rank = parallel_config.data_parallel_rank self.use_mla = False if (hasattr(model_config, "use_mla") and isinstance(model_config.use_mla, bool) @@ -40,37 +49,37 @@ def __init__( self.kv_role = vllm_config.kv_transfer_config.kv_role self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get( "load_async", False) - self.register_buffer = vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "register_buffer", False) + self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "backend", "mooncake") self.block_size = vllm_config.cache_config.block_size self.current_layer = 0 - # self.use_mla = first_kv_cache_tuple[0].size( - # -1) != first_kv_cache_tuple[1].size(-1) self.num_layers = model_config.get_num_layers(parallel_config) self.block_size = vllm_config.cache_config.block_size - num_kv_head = model_config.get_num_kv_heads(parallel_config) - head_size = model_config.get_head_size() - kv_dtype = get_kv_cache_torch_dtype( - vllm_config.cache_config.cache_dtype, model_config.dtype) - self.hidden_dim_size = num_kv_head * head_size + if self.use_mla: - kv_shape = (self.num_layers, 1, self.block_size, 1, head_size) + self.num_kv_head = 1 + else: + self.num_kv_head = model_config.get_total_num_kv_heads() + + if self.num_kv_head < self.tp_size: + self.put_step = self.tp_size // self.num_kv_head + self.head_or_tp_rank = self.tp_rank // self.put_step else: - kv_shape = (self.num_layers, 2, self.block_size, num_kv_head, - head_size) - self.metadata = MooncakeEngineMetadata( + self.head_or_tp_rank = self.tp_rank + self.put_step = 1 + + self.metadata = KeyMetadata( model_config.model, - parallel_config.world_size, - parallel_config.rank, - kv_dtype, - kv_shape, - self.block_size, - self.use_mla, + self.head_or_tp_rank, ) - self.token_database = ChunkedTokenDatabase(self.metadata) + self.token_database = ChunkedTokenDatabase(self.metadata, + self.block_size, + self.use_mla) - self.m_store = Mooncakestore(parallel_config) + real_backend = backend_map.get(self.backend.lower()) + self.m_store = real_backend( # type: ignore[misc] + parallel_config) self.kv_send_thread: Optional[KVTransferThread] = None self.kv_recv_thread: Optional[KVTransferThread] = None @@ -108,94 +117,83 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.kv_caches = kv_caches self.kv_caches_base_addr = [] + ptrs = [] + lengths = [] for cache_or_caches in kv_caches.values(): # Normalize to always be a list of caches if self.use_mla: for i, cache in enumerate(cache_or_caches, 0): base_addr = cache.data_ptr() self.kv_caches_base_addr.append(base_addr) - if self.register_buffer: - region_len = self.num_blocks * self.block_len[i % 2] - self._register(base_addr, region_len) + region_len = self.num_blocks * self.block_len[i % 2] + ptrs.append(base_addr) + lengths.append(region_len) else: cache_list = [cache_or_caches ] if self.use_mla else cache_or_caches for cache in cache_list: base_addr = cache.data_ptr() self.kv_caches_base_addr.append(base_addr) - if self.register_buffer: - region_len = self.num_blocks * self.block_len[0] - self._register(base_addr, region_len) + region_len = self.num_blocks * self.block_len[0] + ptrs.append(base_addr) + lengths.append(region_len) + self.m_store.register_buffer(ptrs, lengths) + self.token_database.set_kv_caches_base_addr(self.kv_caches_base_addr) + self.token_database.set_block_len(self.block_len) if self.use_layerwise: self.get_event = threading.Event() if self.kv_role in ['kv_producer', 'kv_both']: ready_event_sending = threading.Event() self.kv_send_thread = KVCacheStoreLayerSendingThread( - self.tp_rank, self.tp_size, self.m_store, - self.kv_caches_base_addr, self.token_database, - self.block_len, self.block_size, ready_event_sending, - self.num_layers) + self.m_store, self.token_database, self.tp_rank, + self.put_step, ready_event_sending, self.num_layers) self.kv_send_thread.start() ready_event = threading.Event() self.kv_recv_thread = KVCacheStoreLayerRecvingThread( - self.tp_rank, self.tp_size, self.m_store, - self.kv_caches_base_addr, self.token_database, self.block_len, - self.block_size, ready_event, self.get_event) + self.m_store, self.token_database, self.tp_rank, ready_event, + self.get_event) self.kv_recv_thread.start() ready_event.wait() else: if self.kv_role in ['kv_producer', 'kv_both']: ready_event_sending = threading.Event() self.kv_send_thread = KVCacheStoreSendingThread( - self.tp_rank, self.tp_size, self.m_store, - self.kv_caches_base_addr, self.token_database, - self.block_len, self.block_size, ready_event_sending) + self.m_store, self.token_database, self.tp_rank, + self.put_step, ready_event_sending) self.kv_send_thread.start() if self.load_async: ready_event = threading.Event() self.kv_recv_thread = KVCacheStoreRecvingThread( - self.tp_rank, self.tp_size, self.m_store, - self.kv_caches_base_addr, self.token_database, - self.block_len, self.block_size, ready_event) + self.m_store, self.token_database, self.tp_rank, + ready_event) self.kv_recv_thread.start() ready_event.wait() - def _register(self, ptr, length): - logger.debug( - "Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, " - "block_lens=%s", ptr, length, self.num_blocks, self.block_len) - try: - self.m_store.register_buffer(ptr, length) - except Exception as e: - raise RuntimeError( - f"Mooncake memory registration failed. Error is: {e}") - - def start_load_kv(self, metadata: MooncakeConnectorMetadata): + def start_load_kv(self, metadata: AscendConnectorMetadata): self.current_layer = 0 self.layerwise_retrievers = [] for request in metadata.requests: load_spec = request.load_spec if load_spec is None or not load_spec.can_load: #load =0 continue - tokens = request.token_ids + token_len = request.token_len_chunk req_id = request.req_id - if (load_spec.mooncake_cached_tokens % self.block_size - != 0) and (load_spec.mooncake_cached_tokens - == tokens.shape[0] - 1): - tokens = tokens[:request.load_spec.mooncake_cached_tokens + 1] + if (load_spec.kvpool_cached_tokens % self.block_size + != 0) and (load_spec.kvpool_cached_tokens + == token_len - 1): + token_len = request.load_spec.kvpool_cached_tokens + 1 else: - tokens = tokens[:request.load_spec.mooncake_cached_tokens] - masked_token_count = (request.load_spec.vllm_cached_tokens // - self.block_size * self.block_size) - token_mask = torch.ones_like(tokens, dtype=torch.bool) - token_mask[:masked_token_count] = False + token_len = request.load_spec.kvpool_cached_tokens + mask_num = (request.load_spec.vllm_cached_tokens // + self.block_size * self.block_size) if self.use_layerwise: layerwise_retriever = self.retrieve_layer( req_id, - tokens, + token_len, request.block_ids, - token_mask, + request.block_hashes, + mask_num, ) next(layerwise_retriever) # first layer load self.layerwise_retrievers.append(layerwise_retriever) @@ -203,102 +201,84 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata): if self.load_async: self.kv_recv_thread.add_request( # type: ignore[union-attr] req_id, - tokens, + token_len, request.block_ids, - token_mask, + request.block_hashes, + mask_num, ) else: - if self.m_store.config.use_ascend_direct: - addr_list = [] - size_list = [] - key_list = [] - blockIds = [] - for start, end, key in self.token_database.process_tokens( - tokens, token_mask): - addr, size, block_id = self.prepare_value( - start, end, request.block_ids) - key_list.append(key.to_string()) - addr_list.append(addr) - size_list.append(size) - blockIds.append(block_id) - self.m_store.get_batch(key_list, addr_list, size_list, - blockIds) - else: - for start, end, key in self.token_database.process_tokens( - tokens, token_mask): - addr, size, _ = self.prepare_value( - start, end, request.block_ids) - self.m_store.get(key, addr, size) - - def prepare_value(self, start: int, end: int, block_ids: list[int]): - addr_list = [] - size_list = [] - block_id = block_ids[start // self.block_size] - for index, base_addr in enumerate(self.kv_caches_base_addr): - block_len = (self.block_len[index % 2] - if self.use_mla else self.block_len[0]) - - addr = base_addr + block_id * block_len - length = int(block_len / self.block_size * (end - start)) - addr_list.append(addr) - size_list.append(length) - return addr_list, size_list, block_id + addr_list = [] + size_list = [] + key_list = [] + for start, end, key in self.token_database.process_tokens( + token_len, request.block_hashes, mask_num): + addr, size, _ = self.token_database.prepare_value( + start, end, request.block_ids) + key_list.append(key.to_string()) + addr_list.append(addr) + size_list.append(size) + key_list_c = key_list[self.tp_rank % len( + key_list):] + key_list[:self.tp_rank % len(key_list)] + addr_list_c = addr_list[self.tp_rank % + len(addr_list + ):] + addr_list[:self.tp_rank % + len(addr_list)] + size_list_c = size_list[self.tp_rank % + len(size_list + ):] + size_list[:self.tp_rank % + len(size_list)] + self.m_store.get(key_list_c, addr_list_c, size_list_c) def wait_for_layer_load(self) -> None: - """MooncakeConnector does not do layerwise saving.""" for layerwise_retriever in self.layerwise_retrievers: ret_token_mask = next(layerwise_retriever) if self.current_layer == self.num_layers - 1: assert ret_token_mask is not None num_retrieved_tokens = ret_token_mask.sum().item() - logger.info(f"Retrieved {num_retrieved_tokens} tokens") + logger.debug(f"Retrieved {num_retrieved_tokens} tokens") def save_kv_layer(self, - connector_metadata: MooncakeConnectorMetadata) -> None: - """MooncakeConnector does not save explicitly.""" + connector_metadata: AscendConnectorMetadata) -> None: if self.current_layer == 0: self.layerwise_storers = [] for request in connector_metadata.requests: - save_spec = request.save_spec - if save_spec is None or not save_spec.can_save: + can_save = request.can_save + if can_save is None or not can_save: continue - token_ids = request.token_ids + token_len = request.token_len_chunk req_id = request.req_id - assert isinstance(token_ids, torch.Tensor) - assert token_ids.is_cpu # TODO: whether need to remov saveThread # no lookup, skipmask - skip_leading_tokens = max( - self.lookup(token_ids, self.use_layerwise), - save_spec.skip_leading_tokens, - ) - if skip_leading_tokens == len(token_ids): + skip_leading_tokens = self.lookup(token_len, + request.block_hashes, + self.use_layerwise) + if skip_leading_tokens == token_len: if request.is_last_chunk: self.kv_send_thread.set_finished_request( # type: ignore[union-attr] req_id) continue # skip this request - skip_leading_tokens = (skip_leading_tokens // self.block_size * - self.block_size) + mask_num = (skip_leading_tokens // self.block_size * + self.block_size) - store_mask = torch.ones_like(token_ids, dtype=torch.bool) - store_mask[:skip_leading_tokens] = False logger.info( "Storing KV cache for %d out of %d tokens " "(skip_leading_tokens=%d) for request %s", - len(token_ids) - skip_leading_tokens, - len(token_ids), + token_len - skip_leading_tokens, + token_len, skip_leading_tokens, request.req_id, ) layerwise_storer = self.store_layer( req_id, - token_ids, - mask=store_mask, + token_len, + block_hashes=request.block_hashes, + mask_num=mask_num, block_ids=request.block_ids, + is_last_chunk=request.is_last_chunk, ) self.layerwise_storers.append(layerwise_storer) for layerwise_storer in self.layerwise_storers: @@ -306,59 +286,53 @@ def save_kv_layer(self, next(layerwise_storer) except Exception: raise - self.current_layer = self.current_layer + 1 + self.current_layer = self.current_layer + 1 - def wait_for_save(self, connector_metadata: MooncakeConnectorMetadata): - """MooncakeConnector does not save explicitly.""" + def wait_for_save(self, connector_metadata: AscendConnectorMetadata): for request in connector_metadata.requests: - save_spec = request.save_spec - if save_spec is None or not save_spec.can_save: + can_save = request.can_save + if can_save is None or not can_save: continue - token_ids = request.token_ids + token_len = request.token_len_chunk req_id = request.req_id - assert isinstance(token_ids, torch.Tensor) - assert token_ids.is_cpu - skip_leading_tokens = max( - self.lookup(token_ids, self.use_layerwise), - save_spec.skip_leading_tokens, - ) - if skip_leading_tokens == len(token_ids): + skip_leading_tokens = self.lookup(token_len, request.block_hashes, + self.use_layerwise) + if skip_leading_tokens == token_len: if request.is_last_chunk: self.kv_send_thread.set_finished_request( # type: ignore[union-attr] req_id) continue # skip this request - skip_leading_tokens = (skip_leading_tokens // self.block_size * - self.block_size) - - store_mask = torch.ones_like(token_ids, dtype=torch.bool) - store_mask[:skip_leading_tokens] = False + mask_num = (skip_leading_tokens // self.block_size * + self.block_size) logger.info( "Storing KV cache for %d out of %d tokens " "(skip_leading_tokens=%d) for request %s", - len(token_ids) - skip_leading_tokens, - len(token_ids), + token_len - skip_leading_tokens, + token_len, skip_leading_tokens, request.req_id, ) self.kv_send_thread.add_request( # type: ignore[union-attr] req_id, - token_ids, + token_len, request.block_ids, - store_mask, + request.block_hashes, + mask_num, request.is_last_chunk, ) def retrieve_layer( self, req_id: str, - tokens: torch.Tensor, + token_len: int, block_ids: list[int], - mask: Optional[torch.Tensor] = None, + block_hashes: list[BlockHash], + mask_num: int = 0, ) -> Generator[Optional[torch.Tensor], None, None]: """ Retrieve the KV cache in a layerwise manner. @@ -376,20 +350,16 @@ def retrieve_layer( be the boolean mask indicating which tokens are retrieved and will only be returned in the last iteration. """ + num_required_tokens = token_len - mask_num - if mask is not None: - num_required_tokens = torch.sum(mask).item() - else: - num_required_tokens = len(tokens) - - ret_mask = torch.zeros_like(tokens, dtype=torch.bool, device="cpu") + ret_mask = torch.zeros(token_len, dtype=torch.bool, device="cpu") starts = [] ends = [] keys = [] first_flag = True for start, end, key in self.token_database.process_tokens( - tokens, mask): + token_len, block_hashes, mask_num): keys_multi_layer = key.split_layers(self.num_layers) starts.append(start) ends.append(end) @@ -421,16 +391,18 @@ def retrieve_layer( retrieved_tokens = torch.sum(ret_mask) logger.debug(f"Retrieved {retrieved_tokens} " f"out of {num_required_tokens} " - f"out of total {len(tokens)} tokens") + f"out of total {token_len} tokens") yield ret_mask def store_layer( self, req_id: str, - tokens: torch.Tensor, + token_len: int, block_ids: list[int], - mask: Optional[torch.Tensor] = None, + block_hashes: list[BlockHash], + is_last_chunk: bool, + mask_num: int = 0, ) -> Generator[None, None, None]: """ Store the KV cache in a layerwise manner. @@ -452,17 +424,13 @@ def store_layer( storage backends. In the last iteration, it puts the memory objects of the last layer to the storage backends. """ - - if mask is not None: - num_stored_tokens = torch.sum(mask).item() - else: - num_stored_tokens = len(tokens) + num_stored_tokens = token_len - mask_num starts = [] ends = [] keys = [] for start, end, key in self.token_database.process_tokens( - tokens, mask): + token_len, block_hashes, mask_num): keys_multi_layer = key.split_layers(self.num_layers) starts.append(start) ends.append(end) @@ -473,7 +441,7 @@ def store_layer( for layer_id, keys_multi_chunk in enumerate(keys): req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk, starts, ends, block_ids, - layer_id) + layer_id, is_last_chunk) self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg] req_meta) # type: ignore[union-attr, call-arg, arg-type] yield @@ -481,7 +449,7 @@ def store_layer( for layer_id in range(self.num_layers): yield logger.debug( - f"Stored {num_stored_tokens} out of total {len(tokens)} tokens") + f"Stored {num_stored_tokens} out of total {token_len} tokens") def get_finished(self) -> tuple[set[str], set[str]]: done_sending = ( @@ -500,13 +468,10 @@ def get_finished(self) -> tuple[set[str], set[str]]: self.tp_rank) return done_sending, done_recving - def wait_layer_transfer_finish(self): - time.sleep(10) - pass - def lookup( self, - tokens: Union[torch.Tensor, List[int]], + token_len: int, + block_hashes: list[BlockHash], use_layerwise: bool, ) -> int: """ @@ -517,34 +482,24 @@ def lookup( end = 0 keys = [] try: - if use_layerwise: - for start, end, key in self.token_database.process_tokens( - tokens): + starts = [] + for start, end, key in self.token_database.process_tokens( + token_len, block_hashes): + if use_layerwise: keys_multi_layer = key.split_layers(self.num_layers) for item in keys_multi_layer: keys.append(item.to_string()) - # batch is_exists - ress = self.m_store.batch_exists(keys) - res = 1 - for value in ress: - if value != 1: - res = 0 - break - if res == 1: - continue - else: - return start - else: - starts = [] - for start, end, key in self.token_database.process_tokens( - tokens): + else: keys.append(key.to_string()) - starts.append(start) - res = self.m_store.batch_exists( - keys) # type: ignore[assignment] - for index, value in enumerate(res): # type: ignore[arg-type] - if value != 1: - return starts[index] + starts.append(start) + + res = self.m_store.exists(keys) # type: ignore[assignment] + + if use_layerwise: + res = self.check_all_layers_exists(res, self.num_layers) + for index, value in enumerate(res): # type: ignore[arg-type] + if value != 1: + return starts[index] # all tokens where found, return the maximal end except Exception as e: logger.error(f"Remote connection failed in contains: {e}") @@ -553,7 +508,8 @@ def lookup( def lookup_scheduler( self, - tokens: Union[torch.Tensor, List[int]], + token_len: int, + block_hashes: list[BlockHash], use_layerwise: bool, ) -> int: """ @@ -564,59 +520,59 @@ def lookup_scheduler( end = 0 keys = [] try: - if use_layerwise: - for start, end, key in self.token_database.process_tokens( - tokens): + starts = [] + for start, end, key in self.token_database.process_tokens( + token_len, block_hashes): + if use_layerwise: keys_multi_layer = key.split_layers(self.num_layers) for item in keys_multi_layer: keys.append(item.to_string()) - # batch is_exists - ress = self.m_store.batch_exists(keys) - res = 1 - for value in ress: - if value != 1: - res = 0 - break - if res == 1: - continue - else: - return start - else: - starts = [] - for start, end, key in self.token_database.process_tokens( - tokens): + else: keys.append(key.to_string()) - starts.append(start) - multi_tp_keys = keys[:] - for i in range(1, self.tp_size): - for item in keys: - new_str = item.replace( # type: ignore[attr-defined] - "@0", f"@{i}", 1) - multi_tp_keys.append(new_str) - res = self.m_store.batch_exists( - multi_tp_keys) # type: ignore[assignment] - num_block = len(keys) - multi_tp_values = [ - res[i * num_block:(i + 1) * - num_block] # type: ignore[index] - for i in range(self.tp_size) - ] - index = self.find_min_first_non_one_index(multi_tp_values) - if index != -1: - return starts[index] + starts.append(start) + + multi_tp_keys = keys[:] + for i in range(1, min(self.tp_size, self.num_kv_head)): + for item in keys: + new_str = item.replace( # type: ignore[attr-defined] + "@head_or_tp_rank:0", f"@head_or_tp_rank:{i}", 1) + multi_tp_keys.append(new_str) + + res = self.m_store.exists( + multi_tp_keys) # type: ignore[assignment] + num_block = len(keys) + if use_layerwise: + res = self.check_all_layers_exists(res, self.num_layers) + num_block = len(keys) // self.num_layers + multi_tp_values = [ + res[i * num_block:(i + 1) * num_block] # type: ignore[index] + for i in range(min(self.tp_size, self.num_kv_head)) + ] + index = self.find_min_first_non_one_index(multi_tp_values) + if index != -1: + return starts[index] # all tokens where found, return the maximal end except Exception as e: logger.error(f"Remote connection failed in contains: {e}") return start return end + def check_all_layers_exists(self, res: list[int], + num_layers: int) -> list[int]: + total_chunks = len(res) // num_layers + result = [] + + for chunk_idx in range(total_chunks): + start = chunk_idx * num_layers + end = start + num_layers + chunk = res[start:end] + result.append(1 if all(x == 1 for x in chunk) else 0) + + return result + def find_min_first_non_one_index(self, arr): try: return min(idx for row in arr for idx, val in enumerate(row) if val != 1) except ValueError: return -1 - - def close(self) -> None: - """Close the cache engine and free all the resources""" - self.m_store.close() diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py index c0bd06d4b89..5c5a0a5bef3 100644 --- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py +++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py @@ -28,6 +28,7 @@ from vllm.utils import logger from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request, RequestStatus import vllm_ascend.envs as envs_ascend @@ -100,7 +101,10 @@ def add_new_req(self, request_id: str, local_block_ids: list[int], class LLMDataDistCMgrConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + def __init__(self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: Optional[KVCacheConfig] = None): assert vllm_config.kv_transfer_config is not None self.engine_id = vllm_config.kv_transfer_config.engine_id if role == KVConnectorRole.SCHEDULER: diff --git a/vllm_ascend/distributed/mooncake/__init__.py b/vllm_ascend/distributed/mooncake/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/vllm_ascend/distributed/mooncake/config_data.py b/vllm_ascend/distributed/mooncake/config_data.py deleted file mode 100644 index 2434b4dbc05..00000000000 --- a/vllm_ascend/distributed/mooncake/config_data.py +++ /dev/null @@ -1,534 +0,0 @@ -import array -import hashlib -import json -import os -import re -from dataclasses import dataclass -from typing import Iterable, List, Optional, Tuple, Union - -import torch -from vllm.distributed.kv_transfer.kv_connector.v1.base import \ - KVConnectorMetadata -from vllm.utils import logger -from vllm.utils.math_utils import cdiv -from vllm.v1.core.sched.output import NewRequestData - -DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB -DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB - - -@dataclass -class MooncakeEngineMetadata: - """name of the LLM model""" - - model_name: str - """ world size when running under a distributed setting """ - world_size: int - """ worker id when running under a distributed setting """ - worker_id: int - """ the format of kv tensors """ - kv_dtype: torch.dtype - """ the shape of kv tensors """ - """ (num_layer, 2, metadata.block_size, num_kv_head, head_size) """ - kv_shape: tuple[int, int, int, int, int] - block_size: int = 128 - """ whether use MLA""" - use_mla: bool = False - - -@dataclass(order=True) -class MooncakeEngineKey: - model_name: str - world_size: int - worker_id: int - chunk_hash: str - - def __hash__(self): - return hash(( - self.model_name, - self.world_size, - self.worker_id, - self.chunk_hash, - )) - - def to_string(self): - return (f"{self.model_name}@{self.world_size}" - f"@{self.worker_id}@{self.chunk_hash}") - - def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]: - """Split the key into multiple keys for each layer""" - keys = [] - for layer_id in range(num_layers): - keys.append( - LayerMooncakeEngineKey( - self.model_name, - self.world_size, - self.worker_id, - self.chunk_hash, - layer_id, - )) - return keys - - def to_dict(self): - # Note(Kuntai): this is used for serializing CacheEngineKey via msgpack. - return { - "__type__": "CacheEngineKey", - "model_name": self.model_name, - "world_size": self.world_size, - "worker_id": self.worker_id, - "chunk_hash": self.chunk_hash, - } - - @staticmethod - def from_dict(d): - return MooncakeEngineKey( - model_name=d["model_name"], - world_size=d["world_size"], - worker_id=d["worker_id"], - chunk_hash=d["chunk_hash"], - ) - - -@dataclass(order=True) -class LayerMooncakeEngineKey(MooncakeEngineKey): - """A key for the layer cache engine""" - - layer_id: int - - def __hash__(self): - return hash(( - self.model_name, - self.world_size, - self.worker_id, - self.chunk_hash, - self.layer_id, - )) - - def to_string(self): - return (f"{self.model_name}@{self.world_size}" - f"@{self.worker_id}@{self.chunk_hash}@{self.layer_id}") - - -class ChunkedTokenDatabase(): - - def __init__( - self, - metadata: MooncakeEngineMetadata, - ): - self.metadata = metadata - - def _make_key_by_hash(self, - chunk_hash: str, - layer_id: Optional[int] = None): - assert self.metadata is not None - return MooncakeEngineKey( - self.metadata.model_name, - self.metadata.world_size, - self.metadata.worker_id, - chunk_hash, - ) - - def _hash( - self, - tokens: Union[torch.Tensor, List[int]], - prefix_hash: str, - ) -> str: - # TODO: change it to a more efficient hash function - if isinstance(tokens, torch.Tensor): - tokens_bytes = tokens.cpu().to(torch.uint32).numpy().tobytes() - elif isinstance(tokens, list): - tokens_bytes = array.array("I", tokens).tobytes() - return hashlib.sha256(prefix_hash.encode("ascii") + - tokens_bytes).hexdigest() - - def _chunk_tokens( - self, - tokens: Union[torch.Tensor, List[int]], - ) -> Iterable[Union[torch.Tensor, List[int]]]: - """ - Chunk the tokens into chunks of size self.metadata.block_size. - - :param tokens: the input tokens, with shape [seq_len] - device: the target device after chunking - - :return: a generator of chunks of tokens, each with - shape [metadata.block_size] - """ - for i in range(0, len(tokens), self.metadata.block_size): - yield tokens[i:i + self.metadata.block_size] - - def _prefix_hash( - self, - token_chunks: Iterable[Union[torch.Tensor, List[int]]], - ) -> Iterable[str]: - prefix_hash = '' - for token_chunk in token_chunks: - prefix_hash = self._hash(token_chunk, prefix_hash) - yield prefix_hash - - def process_tokens( - self, - tokens: Union[torch.Tensor, List[int]], - mask: Optional[torch.Tensor] = None, - ) -> Iterable[Tuple[int, int, MooncakeEngineKey]]: - """Process the tokens and return the corresponding cache engine keys. - - :param Union[torch.Tensor, List[int]] tokens: The tokens to process. - - :param Optional[torch.Tensor] mask: The mask for the tokens. Should - have the same length as tokens. And the mask should ALWAYS be like - FFFFFTTTTTTT, where True means the tokens needs to be matched, - and the Falses will ALWAYS be at the PREFIX of the tensor. - - :param bool make_key: Whether to make the cache engine key or not. - If False, the hash value will be returned instead. - - :returns: A iterable of tuples with three elements. The first element - is the start index of the tokens for the key. The second element - is the end index of the tokens for the key. The third element is - the cache engine key (or hash) for the tokens. - - :raises: ValueError if the number of Falses in the mask is not a - multiple of the chunk size. - """ - if mask is not None: - num_falses = mask.numel() - mask.long().sum().item() - else: - num_falses = 0 - - if num_falses % self.metadata.block_size != 0: - raise ValueError( - "The number of Falses in the mask is not a multiple of the chunk size." - ) - total_len = len(tokens) - - token_chunks = self._chunk_tokens(tokens) - prefix_hashes = self._prefix_hash(token_chunks) - - start_idx = 0 - for chunk_id, hash_val in enumerate(prefix_hashes): - start_idx = chunk_id * self.metadata.block_size - end_idx = min(start_idx + self.metadata.block_size, total_len) - if start_idx < num_falses: - continue - else: - yield start_idx, end_idx, self._make_key_by_hash(hash_val) - - -@dataclass -class LoadSpec: - # Number of tokens cached in vLLM - vllm_cached_tokens: int - # Number of tokens that are cached in mooncake - mooncake_cached_tokens: int - # Whether the scheduler allow us to load the tokens - can_load: bool - - -@dataclass -class SaveSpec: - # Skip already saved tokens - skip_leading_tokens: int - # Whether the scheduler allow us to save the tokens - can_save: bool - - -@dataclass -class RequestTracker: - # Request id - req_id: str - - # The token ids that has been scheduled so far - token_ids: list[int] - - # The block ids that has been allocated so far - # NOTE: allocated blocks could be more than the number of tokens - # FIXME: need to check whether the block ids will be changed after - # preemption - allocated_block_ids: list[int] - - # The number of tokens that has been savd - num_saved_tokens: int = 0 - - @staticmethod - def from_new_request( - new_request: "NewRequestData", - num_tokens_to_compute: int, - ) -> "RequestTracker": - """Create the request tracker from a new request. - - Args: - new_request (NewRequestData): the new request data. - num_tokens_to_compute (int): the number of tokens that will - be 'computed', including the `num_computed_tokens` (vLLM's - local cache hit) and new tokens that will be scheduled. - - """ - # vLLM 0.9.0 update: request.block_ids changed from list[int] to - # list[list[int]] - # Need to check the type of request.block_ids - - unfolded_block_ids = [] - - if not isinstance(new_request.block_ids[0], list): - unfolded_block_ids = new_request.block_ids.copy() - else: - unfolded_block_ids = new_request.block_ids[0].copy() - - return RequestTracker( - req_id=new_request.req_id, - token_ids=new_request.prompt_token_ids[:num_tokens_to_compute]. - copy(), - allocated_block_ids=unfolded_block_ids, - num_saved_tokens=0, - ) - - def update( - self, - new_token_ids: list[int], - new_block_ids: Union[tuple[list[int], ...], list[int]], - ) -> None: - """Update the request tracker when a running request is - scheduled again - """ - - self.token_ids.extend(new_token_ids) - - if len(new_block_ids) == 0: - new_block_ids = [] - elif isinstance(new_block_ids, tuple): - new_block_ids = new_block_ids[0] - elif isinstance(new_block_ids, list): - pass - else: - raise ValueError( - f"Unsupported new_block_ids type {type(new_block_ids)}") - self.allocated_block_ids.extend(new_block_ids) - - -@dataclass -class ReqMeta: - # Request id - req_id: str - # Request tokens - token_ids: torch.Tensor - - block_ids: list[int] - # # Slot mapping if exchange for block_id - # slot_mapping: torch.Tensor - # Skip save or not - save_spec: Optional[SaveSpec] = None - # load_spec - load_spec: Optional[LoadSpec] = None - - is_last_chunk: Optional[bool] = None - - @staticmethod - def from_request_tracker( - tracker: RequestTracker, - block_size: int, - load_spec: Optional[LoadSpec] = None, - skip_save: Optional[bool] = False, - is_last_chunk: Optional[bool] = None, - discard_partial_chunks: bool = True, - ) -> Optional["ReqMeta"]: - """Create the request metadata from a request tracker. - - Args: - tracker (RequestTracker): the request tracker. - block_size (int): the block size in vLLM. - load_spec (Optional[LoadSpec]): the load spec for KV cache loading. - skip_save (bool): whether to skip the save operation. - discard_partial_chunks (bool): whether to discard partial chunks. - - Returns: - the request metadata if we need to perform load/save - operations, None otherwise. - """ - input_token_ids = tracker.token_ids - input_token_len = len(input_token_ids) - - # For save operation: do not save if the following condition is met - # 1. has already been saved before (num_saved_tokens > 0) - # 2. number of unsaved tokens is not reached the chunk boundary - skip_leading_tokens = tracker.num_saved_tokens - chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) * - block_size if discard_partial_chunks else 0) - # Calculate number of tokens to save based on discard_partial_chunks - # setting - num_tokens_to_save = ((input_token_len // block_size * block_size) - if discard_partial_chunks else input_token_len) - - skip_save = skip_save or num_tokens_to_save < chunk_boundary - if skip_save and load_spec is None: - return None - - # If we need to save, update the number of saved tokens - if not skip_save: - tracker.num_saved_tokens = num_tokens_to_save - save_spec = SaveSpec(skip_leading_tokens, not skip_save) - - # Calculate the token ids and slot mappings for load and save - # OPTIMIZATION: pre-allocate the buffer for token ids and block ids - token_ids = torch.tensor(input_token_ids)[:num_tokens_to_save] - - # # For load operation: check whether the request is scheduled to load - if load_spec is not None and load_spec.can_load: - logger.debug( - "Scheduled to load %d tokens for request %s", - load_spec.mooncake_cached_tokens, - tracker.req_id, - ) - else: - # Do not load if not in `can_load` state - load_spec = None - logger.debug( - f"request:{tracker.req_id}, meta save spec:{save_spec}, meta load spec:{load_spec}" - ) - return ReqMeta( - req_id=tracker.req_id, - token_ids=token_ids, - block_ids=tracker.allocated_block_ids, - save_spec=save_spec, - load_spec=load_spec, - is_last_chunk=is_last_chunk, - ) - - -class MooncakeConnectorMetadata(KVConnectorMetadata): - - def __init__(self, unfinished_request_ids): - self.requests = [] - self.unfinished_request_ids = unfinished_request_ids - - def add_request(self, req_meta: ReqMeta) -> None: - """Add a request to the metadata. - - Args: - req_meta (ReqMeta): the request metadata. - """ - self.requests.append(req_meta) - - -@dataclass -class LasyerMultiBlockReqMeta: - req_id: str - keys: List[LayerMooncakeEngineKey] - starts: List[int] - ends: list[int] - block_ids: list[int] - layer_id: int - - -@dataclass -class MooncakeStoreConfig: - local_hostname: str - metadata_server: str - global_segment_size: Union[int, str] - local_buffer_size: int - protocol: str - device_name: str - master_server_address: str - use_ascend_direct: bool - - @staticmethod - def from_file(file_path: str) -> "MooncakeStoreConfig": - with open(file_path) as file: - config = json.load(file) - return MooncakeStoreConfig( - local_hostname=config.get("local_hostname"), - metadata_server=config.get("metadata_server"), - global_segment_size=_parse_global_segment_size( - config.get("global_segment_size", - DEFAULT_GLOBAL_SEGMENT_SIZE)), - local_buffer_size=(config.get("local_buffer_size", - DEFAULT_LOCAL_BUFFER_SIZE)), - protocol=config.get("protocol", "tcp"), - device_name=config.get("device_name", ""), - master_server_address=config.get("master_server_address"), - use_ascend_direct=config.get("use_ascend_direct", False)) - - @staticmethod - def load_from_env() -> "MooncakeStoreConfig": - config_path = os.getenv("MOONCAKE_CONFIG_PATH") - if not config_path: - raise ValueError( - "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") - return MooncakeStoreConfig.from_file(config_path) - - -def _parse_global_segment_size(value) -> int: - """ - Parse storage size strings with support for units: GB, MB, KB, B - - Args: - value: Input value (int, str, or other convertible types) - - Returns: - int: Size in bytes - - Raises: - ValueError: For invalid format, missing number, or negative values - TypeError: For unsupported input types - """ - - if isinstance(value, int): - return value - elif not isinstance(value, str): - try: - return int(value) - except (TypeError, ValueError) as e: - raise TypeError( - f"Unsupported type for global_segment_size: {type(value)}" - ) from e - - cleaned_input = value.strip().lower() - if not cleaned_input: - raise ValueError("global segment size cannot be empty.") - - UNIT_MULTIPLIERS = { - 'gb': 1024**3, # 1 GB = 1024^3 bytes - 'mb': 1024**2, # 1 MB = 1024^2 bytes - 'kb': 1024, # 1 KB = 1024 bytes - 'b': 1 # 1 B = 1 byte - } - pattern = r'^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$' - match = re.match(pattern, cleaned_input) - - if not match: - raise ValueError(f"Invalid format: '{value}'") - - number_str = match.group(1) - unit = match.group(2) or 'b' - - multiplier = UNIT_MULTIPLIERS[unit] - return _convert_to_bytes(number_str, multiplier, value) - - -def _convert_to_bytes(number_str: str, multiplier: int, - original_input: str) -> int: - """ - Convert numeric string to byte count - - Args: - number_str: Numeric portion of input - multiplier: Unit conversion factor - original_input: Original input string (for error messages) - - Returns: - int: Byte count - - Raises: - ValueError: For invalid numbers or negative results - """ - try: - numeric_value = float(number_str) - except ValueError: - raise ValueError( - f"Invalid numeric value '{number_str}' in: '{original_input}'") - # Calculate byte count - try: - byte_count = int(numeric_value * multiplier) - except OverflowError: - raise ValueError(f"Storage size too large: '{original_input}'") - return byte_count diff --git a/vllm_ascend/distributed/mooncake/kv_transfer.py b/vllm_ascend/distributed/mooncake/kv_transfer.py deleted file mode 100644 index 4472f678ddd..00000000000 --- a/vllm_ascend/distributed/mooncake/kv_transfer.py +++ /dev/null @@ -1,282 +0,0 @@ -import queue -import threading -from concurrent.futures import ThreadPoolExecutor -from typing import Any, Optional - -import torch -from vllm.utils import logger - -from vllm_ascend.distributed.mooncake.config_data import ( - ChunkedTokenDatabase, LasyerMultiBlockReqMeta) -from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore - - -class KVTransferThread(threading.Thread): - - def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, - local_kv_caches_base_addr: list[int], - token_database: ChunkedTokenDatabase, block_len: list[int], - block_size: int, ready_event: threading.Event, name: str): - super().__init__(daemon=True, name=name) - self.tp_rank = tp_rank - self.tp_size = tp_size - self.m_store = m_store - self.ready_event = ready_event - self.kv_caches_base_addr = local_kv_caches_base_addr - self.block_len = block_len - self.token_database = token_database - self.block_size = block_size - self.done_task_lock = threading.Lock() - # TODO(jianzs): find a better way to detect MLA. - self.use_mla = len(block_len) == 2 - - self.request_queue: queue.Queue[Any] = queue.Queue() - # TODO(jianzs): make this configurable - self.executor = ThreadPoolExecutor(max_workers=32) - self.finished_requests: set[str] = set() - - def prepare_value(self, start: int, end: int, block_ids: list[int]): - addr_list = [] - size_list = [] - block_id = block_ids[start // self.block_size] - for index, base_addr in enumerate(self.kv_caches_base_addr): - block_len = (self.block_len[index % 2] - if self.use_mla else self.block_len[0]) - - addr = base_addr + block_id * block_len - length = int(block_len / self.block_size * (end - start)) - addr_list.append(addr) - size_list.append(length) - return addr_list, size_list, block_id - - def prepare_value_layer(self, start: int, end: int, block_ids: list[int], - layer_id: int): - block_id = block_ids[start // self.block_size] - if self.use_mla: - addr_k = self.kv_caches_base_addr[layer_id * - 2] + block_id * self.block_len[0] - addr_v = self.kv_caches_base_addr[layer_id * 2 + - 1] + block_id * self.block_len[1] - length_k = int(self.block_len[0] / self.block_size * (end - start)) - length_v = int(self.block_len[1] / self.block_size * (end - start)) - size_list = [length_k, length_v] - else: - addr_k = self.kv_caches_base_addr[layer_id * - 2] + block_id * self.block_len[0] - addr_v = self.kv_caches_base_addr[layer_id * 2 + - 1] + block_id * self.block_len[0] - length = int(self.block_len[0] / self.block_size * (end - start)) - size_list = [length, length] - addr_list = [addr_k, addr_v] - return addr_list, size_list - - def add_request( - self, - req_id: str, - tokens: torch.Tensor, - block_ids: list[int], - mask: Optional[torch.Tensor] = None, - is_last_chunk: Optional[bool] = None, - ) -> torch.Tensor: - req = ({ - "req_id": req_id, - "tokens": tokens, - "block_ids": block_ids, - "mask": mask, - "is_last_chunk": is_last_chunk, - }) - self.request_queue.put(req) - - def get_and_clear_finished_requests(self) -> set[str]: - """ - Get and clear the requests that have been completed. - Returns: - A set of request IDs that have been completed. - """ - with self.done_task_lock: - finished_requests = self.finished_requests.copy() - self.finished_requests.clear() - return finished_requests - - def set_finished_request(self, req_id): - with self.done_task_lock: - self.finished_requests.add(req_id) - - def run(self): - """Run the thread to handle KV cache transfer requests.""" - self.ready_event.set() - while True: - try: - request_data = self.request_queue.get() - if request_data is None: - logger.warning("Received a None request!") - self.request_queue.task_done() - continue - self._handle_request(request_data) - except Exception as e: - logger.error(f"Error in KVCacheTransferThread: {e}") - - def _handle_request(self, req_meta: dict[str, Any]): - pass - - -class KVCacheStoreSendingThread(KVTransferThread): - - def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, - local_kv_caches_base_addr: list[int], - token_database: ChunkedTokenDatabase, block_len: list[int], - block_size: int, ready_event: threading.Event): - super().__init__(tp_rank, - tp_size, - m_store, - local_kv_caches_base_addr, - token_database, - block_len, - block_size, - ready_event, - name="KVCacheSendingThread") - - def _handle_request(self, req_meta: dict[str, Any]): - tokens = req_meta["tokens"] - mask = req_meta["mask"] - block_ids = req_meta["block_ids"] - req_id = req_meta["req_id"] - is_last_chunk = req_meta["is_last_chunk"] - if self.m_store.config.use_ascend_direct: - addr_list = [] - size_list = [] - key_list = [] - blockIds = [] - for start, end, key in self.token_database.process_tokens( - tokens, mask): - addr, size, block_id = self.prepare_value( - start, end, block_ids) - key_list.append(key.to_string()) - addr_list.append(addr) - size_list.append(size) - blockIds.append(block_id) - torch.npu.current_stream().synchronize() - self.m_store.put_batch(key_list, addr_list, size_list, blockIds) - else: - torch.npu.current_stream().synchronize() - for start, end, key in self.token_database.process_tokens( - tokens, mask): - addr, size, _ = self.prepare_value(start, end, block_ids) - self.m_store.put(key, addr, size) - if is_last_chunk: - self.set_finished_request(req_id) - self.request_queue.task_done() - - -class KVCacheStoreRecvingThread(KVTransferThread): - - def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, - local_kv_caches_base_addr: list[int], - token_database: ChunkedTokenDatabase, block_len: list[int], - block_size: int, ready_event: threading.Event): - super().__init__(tp_rank, - tp_size, - m_store, - local_kv_caches_base_addr, - token_database, - block_len, - block_size, - ready_event, - name="KVCacheStoreRecvingThread") - - def _handle_request(self, req_meta: dict[str, Any]): - tokens = req_meta["tokens"] - mask = req_meta["mask"] - block_ids = req_meta["block_ids"] - req_id = req_meta["req_id"] - if self.m_store.config.use_ascend_direct: - addr_list = [] - size_list = [] - key_list = [] - blockIds = [] - for start, end, key in self.token_database.process_tokens( - tokens, mask): - addr, size, block_id = self.prepare_value( - start, end, block_ids) - key_list.append(key.to_string()) - addr_list.append(addr) - size_list.append(size) - blockIds.append(block_id) - self.m_store.get_batch(key_list, addr_list, size_list, blockIds) - else: - for start, end, key in self.token_database.process_tokens( - tokens, mask): - addr, size, _ = self.prepare_value(start, end, block_ids) - self.m_store.get(key, addr, size) - self.set_finished_request(req_id) - self.request_queue.task_done() - - -class KVCacheStoreLayerSendingThread(KVTransferThread): - - def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, - local_kv_caches_base_addr: list[int], - token_database: ChunkedTokenDatabase, block_len: list[int], - block_size: int, ready_event: threading.Event, - num_layers: int): - super().__init__(tp_rank, - tp_size, - m_store, - local_kv_caches_base_addr, - token_database, - block_len, - block_size, - ready_event, - name="KVCacheStoreLayerSendingThread") - self.final_layer_id = num_layers - 1 - - def add_request( # type: ignore[override] - self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: - self.request_queue.put(req_meta) - - def _handle_request( # type: ignore[override] - self, req_meta: LasyerMultiBlockReqMeta): - torch.npu.current_stream().synchronize() - for index, key in enumerate(req_meta.keys): - addr, size = self.prepare_value_layer(req_meta.starts[index], - req_meta.ends[index], - req_meta.block_ids, - req_meta.layer_id) - self.m_store.put(key, addr, size) - if req_meta.layer_id == self.final_layer_id: - self.set_finished_request(req_meta.req_id) - self.request_queue.task_done() - - -class KVCacheStoreLayerRecvingThread(KVTransferThread): - - def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, - local_kv_caches_base_addr: list[int], - token_database: ChunkedTokenDatabase, block_len: list[int], - block_size: int, ready_event: threading.Event, - get_event: threading.Event): - super().__init__(tp_rank, - tp_size, - m_store, - local_kv_caches_base_addr, - token_database, - block_len, - block_size, - ready_event, - name="KVCacheStoreLayerRecvingThread") - self.get_event = get_event - - def add_request( # type: ignore[override] - self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: - self.request_queue.put(req_meta) - - def _handle_request( # type: ignore[override] - self, req_meta: LasyerMultiBlockReqMeta): - for index, key in enumerate(req_meta.keys): - addr, size = self.prepare_value_layer(req_meta.starts[index], - req_meta.ends[index], - req_meta.block_ids, - req_meta.layer_id) - self.m_store.get(key, addr, size) - self.request_queue.task_done() - self.get_event.set() diff --git a/vllm_ascend/distributed/mooncake/mooncake_store.py b/vllm_ascend/distributed/mooncake/mooncake_store.py deleted file mode 100644 index 01020d72d87..00000000000 --- a/vllm_ascend/distributed/mooncake/mooncake_store.py +++ /dev/null @@ -1,127 +0,0 @@ -# Standard -import os - -# Third Party -from mooncake.store import ReplicateConfig # type: ignore -from vllm.config import ParallelConfig -from vllm.distributed.parallel_state import get_tensor_model_parallel_rank -from vllm.utils import logger -from vllm.utils.network_utils import get_ip - -from vllm_ascend.distributed.mooncake.config_data import MooncakeEngineKey -from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te - -from .config_data import MooncakeStoreConfig - -METADATA_BYTES_LEN = 24 -BASE_PORT = int(os.getenv("VLLM_BASE_PORT", "8790")) - - -class Mooncakestore(): - - def __init__(self, parallel_config: ParallelConfig): - try: - from mooncake.store import MooncakeDistributedStore # type: ignore - except ImportError as e: - raise ImportError( - "Please install mooncake by following the instructions at " - "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 - "to run vLLM with MooncakeConnector.") from e - tp_rank = get_tensor_model_parallel_rank() - tp_size = parallel_config.tensor_parallel_size - dp_rank = parallel_config.data_parallel_rank_local - all_device_ids = os.getenv("ASCEND_RT_VISIBLE_DEVICES", None) - if not all_device_ids: - device_ids_list = list( - range(dp_rank * tp_size, (dp_rank + 1) * tp_size)) - else: - device_ids_list = list(map(int, all_device_ids.split(','))) - assert len(device_ids_list) > tp_rank - device_id = device_ids_list[tp_rank] - self.config = MooncakeStoreConfig.load_from_env() - self.store = MooncakeDistributedStore() - if self.config.protocol == "ascend" and not self.config.use_ascend_direct: - local_hostname = get_ip() + ":" + str(BASE_PORT + int(device_id)) + \ - ":npu_" + str(device_id) - ret = self.store.setup(local_hostname, self.config.metadata_server, - self.config.global_segment_size, - self.config.local_buffer_size, - self.config.protocol, - self.config.device_name, - self.config.master_server_address) - else: - local_hostname = get_ip() - transfer_engine = get_global_te(local_hostname, device_name=None) - self.local_seg = local_hostname + ":" + str( - transfer_engine.get_rpc_port()) - ret = self.store.setup(self.local_seg, self.config.metadata_server, - self.config.global_segment_size, - self.config.local_buffer_size, - self.config.protocol, - self.config.device_name, - self.config.master_server_address, - transfer_engine.get_engine()) - if ret != 0: - msg = "Initialize mooncake failed." - logger.error(msg) - raise RuntimeError(msg) - - def exists(self, key: MooncakeEngineKey) -> bool: - return self.store.is_exist(key.to_string()) == 1 - - def batch_exists(self, keys: list[str]) -> list[int]: - return self.store.batch_is_exist(keys) - - def register_buffer(self, ptr, length): - return self.store.register_buffer(ptr, length) - - def get_batch(self, keys: list[str], addrs: list[list[int]], - sizes: list[list[int]], block_ids: list[int]): - try: - res = self.store.batch_get_into_multi_buffers( - keys, addrs, sizes, True) - for value in res: - if value < 0: - logger.error(f"Failed to get key {keys},res:{res}") - except Exception as e: - logger.error(f"Failed to get key {keys}. {e}") - - def put_batch(self, keys: list[str], addrs: list[list[int]], - sizes: list[list[int]], block_ids: list[int]): - try: - config = ReplicateConfig() - config.preferred_segment = self.local_seg - config.prefer_alloc_in_same_node = True - res = self.store.batch_put_from_multi_buffers( - keys, addrs, sizes, config) - for value in res: - if value < 0: - logger.error(f"Failed to put key {keys},res:{res}") - except Exception as e: - logger.error(f"Failed to put key {keys},error:{e}") - - def get(self, key: MooncakeEngineKey, addr: list[int], size: list[int]): - expect_res = sum(size) - key_str = key.to_string() - try: - res = self.store.batch_get_into_ascend(key_str, addr, size) - if res[0] != expect_res: - logger.error(f"Failed to get key: [{key_str}] .") - except Exception: - logger.error(f"Failed to get key: [{key_str}] .") - return res - - def put(self, key: MooncakeEngineKey, addr: list[int], size: list[int]): - key_str = key.to_string() - try: - ret = self.store.batch_put_from_ascend(key_str, addr, size) - if ret[0] != 0: - logger.error(f"Failed to put key {key_str}.") - except Exception: - logger.error(f"Failed to put key {key_str}.") - - return ret - - def close(self): - self.store.close() - logger.info("Closed the mooncake store connection") diff --git a/vllm_ascend/distributed/mooncake/transfer_engine.py b/vllm_ascend/distributed/mooncake/transfer_engine.py deleted file mode 100644 index d4e172b7857..00000000000 --- a/vllm_ascend/distributed/mooncake/transfer_engine.py +++ /dev/null @@ -1,38 +0,0 @@ -import ipaddress -import threading -from typing import Optional - -from mooncake.engine import TransferEngine # type: ignore - -_global_te = None -_global_te_lock = threading.Lock() - - -def get_global_te(hostname: str, device_name: Optional[str]): - try: - ip = ipaddress.ip_address(hostname) - if isinstance(ip, ipaddress.IPv6Address): - raise RuntimeError( - "The backend of mooncake's Ascend Direct Xfer Library currently does not support IPv6." - ) - except ValueError: - pass - - global _global_te - if _global_te is None: - with _global_te_lock: - # Double-Checked Locking - if _global_te is None: - if TransferEngine is None: - raise RuntimeError("mooncake is not available") - transfer_engine = TransferEngine() - device_name = device_name if device_name is not None else "" - ret_value = transfer_engine.initialize(hostname, - "P2PHANDSHAKE", - "ascend", device_name) - if ret_value != 0: - raise RuntimeError( - f"TransferEngine initialization failed with ret_value: {ret_value}" - ) - _global_te = transfer_engine - return _global_te diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index 403b17e4f9a..754bba7b68b 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -31,11 +31,12 @@ get_tensor_model_parallel_rank, get_tp_group) from vllm.utils import logger from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import RequestStatus import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config -from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te +from vllm_ascend.distributed.mooncake_transfer_engine import global_te from vllm_ascend.distributed.utils import get_transfer_timeout_value from vllm_ascend.utils import prefill_context_parallel_enable @@ -634,7 +635,10 @@ def add_new_req( class MooncakeConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + def __init__(self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: Optional[KVCacheConfig] = None): assert vllm_config.kv_transfer_config is not None self.engine_id = vllm_config.kv_transfer_config.engine_id @@ -944,7 +948,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): else: hostname = f"{self.side_channel_host}:0:npu_{self.device_id}" logger.info("Initializing Mooncake work %s", engine_id) - self.engine = get_global_te(hostname, device_name=None) + self.engine = global_te.get_transfer_engine(hostname, device_name=None) self.te_rpc_port = self.engine.get_rpc_port() # Background thread for sending or receiving KV caches. @@ -1054,6 +1058,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.kv_caches = kv_caches kv_caches_base_addr = [] + ptrs = [] + lengths = [] for cache_or_caches in kv_caches.values(): # Normalize to always be a list of caches if self.use_mla: @@ -1061,13 +1067,15 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len[i % 2] kv_caches_base_addr.append(base_addr) - self._register(base_addr, region_len) + ptrs.append(base_addr) + lengths.append(region_len) elif self.use_sparse: for i, cache in enumerate(cache_or_caches, 0): base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len[i % 3] kv_caches_base_addr.append(base_addr) - self._register(base_addr, region_len) + ptrs.append(base_addr) + lengths.append(region_len) else: cache_list = [ cache_or_caches @@ -1076,8 +1084,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len[0] kv_caches_base_addr.append(base_addr) - self._register(base_addr, region_len) - + ptrs.append(base_addr) + lengths.append(region_len) + global_te.register_buffer(ptrs, lengths) # After KV Caches registered, start the sending or receiving thread. metadata = MooncakeAgentMetadata( engine_id=self.engine_id, @@ -1101,14 +1110,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.kv_recv_thread.start() ready_event.wait() - def _register(self, ptr, length): - logger.debug( - "Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, " - "block_lens=%s", ptr, length, self.num_blocks, self.block_len) - ret_value = self.engine.register_memory(ptr, length) - if ret_value != 0: - raise RuntimeError("Mooncake memory registration failed.") - def get_finished(self) -> tuple[set[str], set[str]]: done_sending = ( self.kv_send_thread. diff --git a/vllm_ascend/distributed/mooncake_layerwise_connector.py b/vllm_ascend/distributed/mooncake_layerwise_connector.py index ccb6d344970..215becc5477 100644 --- a/vllm_ascend/distributed/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/mooncake_layerwise_connector.py @@ -30,6 +30,7 @@ from vllm.utils import logger from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config @@ -359,7 +360,10 @@ def add_new_req(self, class MooncakeLayerwiseConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + def __init__(self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: Optional[KVCacheConfig] = None): assert vllm_config.kv_transfer_config is not None self.engine_id = vllm_config.kv_transfer_config.engine_id self._connector_metadata = MooncakeLayerwiseConnectorMetadata() diff --git a/vllm_ascend/distributed/mooncake_transfer_engine.py b/vllm_ascend/distributed/mooncake_transfer_engine.py new file mode 100644 index 00000000000..fceecd4c4aa --- /dev/null +++ b/vllm_ascend/distributed/mooncake_transfer_engine.py @@ -0,0 +1,53 @@ +import ipaddress +import threading +from typing import Optional + +from mooncake.engine import TransferEngine # type: ignore + + +class GlobalTE(): + + def __init__(self): + self.transfer_engine = None + self.is_register_buffer: bool = False + self.transfer_engine_lock = threading.Lock() + self.register_buffer_lock = threading.Lock() + + def get_transfer_engine(self, hostname: str, device_name: Optional[str]): + try: + ip = ipaddress.ip_address(hostname) + if isinstance(ip, ipaddress.IPv6Address): + raise RuntimeError( + "The backend of mooncake's Ascend Direct Xfer Library currently does not support IPv6." + ) + except ValueError: + pass + if self.transfer_engine is None: + with self.transfer_engine_lock: + # Double-Checked Locking + if self.transfer_engine is None: + if TransferEngine is None: + raise RuntimeError("mooncake is not available") + self.transfer_engine = TransferEngine() + device_name = device_name if device_name is not None else "" + ret_value = self.transfer_engine.initialize( + hostname, "P2PHANDSHAKE", "ascend", device_name) + if ret_value != 0: + raise RuntimeError( + f"TransferEngine initialization failed with ret_value: {ret_value}" + ) + return self.transfer_engine + + def register_buffer(self, ptrs: list[int], sizes: list[int]): + with self.register_buffer_lock: + assert self.transfer_engine is not None, "Transfer engine must be initialized" + if self.is_register_buffer: + return + for ptr, size in zip(ptrs, sizes): + ret_value = self.transfer_engine.register_memory(ptr, size) + if ret_value != 0: + raise RuntimeError("Mooncake memory registration failed.") + self.is_register_buffer = True + + +global_te = GlobalTE()