4 changes: 3 additions & 1 deletion mypy.ini
@@ -1,6 +1,8 @@
 [mypy]
 ; warn_return_any = True
 warn_unused_configs = True
+; disable errors about unchecked annotations for now.
+disable_error_code = annotation-unchecked
 
 ; Suppress all missing import errors from torch_npu for mypy.
 [mypy-torch_npu.*]
@@ -31,4 +33,4 @@ ignore_missing_imports = True
 ignore_missing_imports = True
 
 [mypy-ucm.*]
-ignore_missing_imports = True
\ No newline at end of file
+ignore_missing_imports = True
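Note on the new mypy setting: annotation-unchecked notes are emitted for variable annotations inside functions whose own signatures are untyped, because mypy skips the bodies of untyped functions by default. A minimal illustration of the kind of code this silences, not taken from this PR:

def load_config():  # untyped signature, so mypy does not check the body
    retries: int = 3  # note: "By default the bodies of untyped functions are not checked" [annotation-unchecked]
    return retries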
5 changes: 0 additions & 5 deletions pyproject.toml
@@ -51,11 +51,6 @@ line-length = 120
 # Folder to be modified
 exclude = [
     "tests/**",
-    # (5)
-    "vllm_ascend/distributed/kv_transfer/kv_pool/**",
-    "vllm_ascend/distributed/kv_transfer/utils/**",
-    "vllm_ascend/kv_offload/**",
-    "vllm_ascend/lora/**",
     # (7)
     "vllm_ascend/quantization/**",
     "vllm_ascend/sample/*.py",
@@ -1,11 +1,10 @@
 import threading
-from typing import Any, Optional
+from typing import Any
 
 import torch
 import zmq
 from vllm.config import VllmConfig
-from vllm.distributed.kv_transfer.kv_connector.v1.base import (
-    KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
+from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole
 from vllm.forward_context import ForwardContext
 from vllm.logger import logger
 from vllm.utils.network_utils import make_zmq_socket
@@ -17,40 +16,35 @@
 from vllm.v1.serial_utils import MsgpackDecoder
 
 from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_scheduler import (
-    KVPoolScheduler, get_zmq_rpc_path_lookup)
-from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_worker import \
-    KVPoolWorker
+    KVPoolScheduler,
+    get_zmq_rpc_path_lookup,
+)
+from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_worker import KVPoolWorker
 
 
 class AscendStoreConnector(KVConnectorBase_V1):
-
-    def __init__(self,
-                 vllm_config: VllmConfig,
-                 role: KVConnectorRole,
-                 kv_cache_config: Optional[KVCacheConfig] = None):
-        super().__init__(vllm_config=vllm_config,
-                         role=role,
-                         kv_cache_config=kv_cache_config)
+    def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: KVCacheConfig | None = None):
+        super().__init__(vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config)
         self.kv_role = vllm_config.kv_transfer_config.kv_role
 
-        self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
-            "use_layerwise", False)
+        self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get("use_layerwise", False)
         self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
-            "consumer_is_to_put", False)
+            "consumer_is_to_put", False
+        )
 
         connector_name = vllm_config.kv_transfer_config.kv_connector
         if connector_name == "MooncakeConnectorStoreV1":
             logger.warning(
-                "It is recommended to use the AscendStoreConnector, as the MoonCakeStoreConnector will be removed in the future."
+                "It is recommended to use the AscendStoreConnector, "
+                "as the MoonCakeStoreConnector will be removed in the future."
             )
 
         self.kv_caches: dict[str, torch.Tensor] = {}
 
         self.sended_but_unfinished_reqs: set[str] = set()
 
         if role == KVConnectorRole.SCHEDULER:
-            self.connector_scheduler = KVPoolScheduler(vllm_config,
-                                                       self.use_layerwise)
+            self.connector_scheduler = KVPoolScheduler(vllm_config, self.use_layerwise)
         else:
             self.connector_worker = KVPoolWorker(
                 vllm_config,
@@ -59,27 +53,19 @@ def __init__(self,

         assert self.connector_worker is not None
         if vllm_config.parallel_config.rank == 0:
-            self.lookup_server = LookupKeyServer(self.connector_worker,
-                                                 vllm_config,
-                                                 self.use_layerwise)
+            self.lookup_server = LookupKeyServer(self.connector_worker, vllm_config, self.use_layerwise)
 
     ############################################################
     # Scheduler Side Methods
     ############################################################
 
-    def get_num_new_matched_tokens(
-            self, request: "Request",
-            num_computed_tokens: int) -> tuple[int, bool]:
+    def get_num_new_matched_tokens(self, request: "Request", num_computed_tokens: int) -> tuple[int, bool]:
         assert self.connector_scheduler is not None
-        return self.connector_scheduler.get_num_new_matched_tokens(
-            request, num_computed_tokens)
+        return self.connector_scheduler.get_num_new_matched_tokens(request, num_computed_tokens)
 
-    def update_state_after_alloc(self, request: "Request",
-                                 blocks: "KVCacheBlocks",
-                                 num_external_tokens: int):
+    def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int):
         assert self.connector_scheduler is not None
-        return self.connector_scheduler.update_state_after_alloc(
-            request, blocks, num_external_tokens)
+        return self.connector_scheduler.update_state_after_alloc(request, blocks, num_external_tokens)
 
     def build_connector_meta(
         self,
@@ -92,7 +78,7 @@ def request_finished(
         self,
         request: "Request",
         block_ids: list[int],
-    ) -> tuple[bool, Optional[dict[str, Any]]]:
+    ) -> tuple[bool, dict[str, Any] | None]:
         assert self.connector_scheduler is not None
         return self.connector_scheduler.request_finished(request, block_ids)

@@ -103,8 +89,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         assert self.connector_worker is not None
         self.connector_worker.register_kv_caches(kv_caches)
 
-    def start_load_kv(self, forward_context: "ForwardContext",
-                      **kwargs) -> None:
+    def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
         assert self.connector_worker is not None
         self.connector_worker.start_load_kv(self._get_connector_metadata())

@@ -113,8 +98,9 @@ def wait_for_layer_load(self, layer_name: str) -> None:
             return
         self.connector_worker.wait_for_layer_load()
 
-    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
-                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
+    def save_kv_layer(
+        self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs
+    ) -> None:
         if not self.use_layerwise:
             return

@@ -133,17 +119,16 @@ def wait_for_save(self):

         self.connector_worker.wait_for_save(self._get_connector_metadata())
 
-    def get_finished(self,
-                     finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
+    def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
         """Get the finished recving and sending requests."""
         assert self.connector_worker is not None
         done_sending, done_recving = self.connector_worker.get_finished(
-            finished_req_ids, self._get_connector_metadata())
+            finished_req_ids, self._get_connector_metadata()
+        )
         return done_sending, done_recving
 
 
 class LookupKeyServer:
-
     def __init__(
         self,
         pool_worker: KVPoolWorker,
@@ -171,8 +156,7 @@ def process_request():
                 token_len = int.from_bytes(all_frames[0], byteorder="big")
                 hash_frames = all_frames[1:]
                 hashes_str = self.decoder.decode(hash_frames)
-                result = self.pool_worker.lookup_scheduler(
-                    token_len, hashes_str, self.use_layerwise)
+                result = self.pool_worker.lookup_scheduler(token_len, hashes_str, self.use_layerwise)
                 response = result.to_bytes(4, "big")
                 self.socket.send(response)
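For context, the framing process_request expects on this socket: frame 0 carries the token count as big-endian bytes, the remaining frames are msgpack-encoded hashes, and the reply is a 4-byte big-endian integer. A hypothetical client-side sketch under those assumptions; MsgpackEncoder is assumed as the counterpart of the MsgpackDecoder used above, and the 4-byte width of frame 0 is a guess (the server accepts any width):

import zmq
from vllm.v1.serial_utils import MsgpackEncoder  # assumed encoder counterpart

def lookup(socket: zmq.Socket, token_len: int, hashes: list[str]) -> int:
    # Frame 0: token count, big-endian; frames 1..n: msgpack-encoded hashes.
    frames = [token_len.to_bytes(4, "big")]
    frames += [bytes(f) for f in MsgpackEncoder().encode(hashes)]
    socket.send_multipart(frames)
    # Reply mirrors result.to_bytes(4, "big") on the server side.
    return int.from_bytes(socket.recv(), byteorder="big")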

@@ -4,13 +4,15 @@


 class Backend(ABC):
-
+    @abstractmethod
     def __init__(self, parallel_config: ParallelConfig):
         pass
 
+    @abstractmethod
     def set_device(self):
         pass
 
+    @abstractmethod
     def register_buffer(self, ptrs: list[int], lengths: list[int]):
         pass

@@ -19,11 +21,9 @@ def exists(self, keys: list[str]) -> list[int]:
         pass
 
     @abstractmethod
-    def put(self, keys: list[str], addrs: list[list[int]],
-            sizes: list[list[int]]):
+    def put(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
         pass
 
     @abstractmethod
-    def get(self, keys: list[str], addrs: list[list[int]],
-            sizes: list[list[int]]):
+    def get(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
         pass
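Since every method of this ABC is now abstract, a concrete backend must implement all six. A hypothetical no-op subclass, shown only to illustrate the contract (not part of this PR):

class NullBackend(Backend):
    def __init__(self, parallel_config: ParallelConfig):
        self.rank = parallel_config.rank

    def set_device(self):
        pass  # no device to select

    def register_buffer(self, ptrs: list[int], lengths: list[int]):
        pass  # nothing to pin

    def exists(self, keys: list[str]) -> list[int]:
        return [0] * len(keys)  # report every key as missing

    def put(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
        pass  # drop writes

    def get(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
        pass  # nothing to read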
Expand Up @@ -5,8 +5,7 @@
 from vllm.config import ParallelConfig
 from vllm.logger import logger
 
-from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \
-    Backend
+from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend
 from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type


@@ -18,29 +17,24 @@ class MmcDirect(Enum):


 class MemcacheBackend(Backend):
-
     def __init__(self, parallel_config: ParallelConfig):
         try:
             from memcache_hybrid import DistributedObjectStore  # type: ignore
         except ImportError as e:
             raise ImportError(
                 "Please install memcache by following the instructions at "
                 "https://gitee.com/ascend/memfabric_hybrid "  # noqa: E501
-                "to run vLLM with MemcacheConnector.") from e
+                "to run vLLM with MemcacheConnector."
+            ) from e
         try:
             soc_version = get_ascend_device_type()
             if soc_version in {AscendDeviceType.A2}:
                 import torch
                 from vllm.distributed import get_world_group
 
                 tmp_tensor = torch.zeros(1, device="npu")
-                output_tensor_list = [
-                    torch.empty_like(tmp_tensor)
-                    for _ in range(torch.distributed.get_world_size())
-                ]
-                torch.distributed.all_gather(
-                    output_tensor_list,
-                    tmp_tensor,
-                    group=get_world_group().device_group)
+                output_tensor_list = [torch.empty_like(tmp_tensor) for _ in range(torch.distributed.get_world_size())]
+                torch.distributed.all_gather(output_tensor_list, tmp_tensor, group=get_world_group().device_group)
                 self.rank = parallel_config.rank
                 self.store = DistributedObjectStore()
                 res = self.store.init(self.rank)
@@ -54,8 +48,7 @@ def __init__(self, parallel_config: ParallelConfig):
logger.error("Configuration loading failed: %s", e)
raise
except Exception as exc:
logger.error(
"An error occurred while loading the configuration: %s", exc)
logger.error("An error occurred while loading the configuration: %s", exc)
raise

def set_device(self):
@@ -73,22 +66,18 @@ def register_buffer(self, ptrs: list[int], sizes: list[int]):
     def exists(self, keys: list[str]) -> list[int]:
         return self.store.batch_is_exist(keys)
 
-    def get(self, key: list[str], addr: list[list[int]],
-            size: list[list[int]]):
+    def get(self, key: list[str], addr: list[list[int]], size: list[list[int]]):
         try:
-            res = self.store.batch_get_into_layers(key, addr, size,
-                                                   MmcDirect.COPY_G2L.value)
+            res = self.store.batch_get_into_layers(key, addr, size, MmcDirect.COPY_G2L.value)
             for value in res:
                 if value != 0:
                     logger.error(f"Failed to get key {key},res:{res}")
         except Exception as e:
             logger.error(f"Failed to get key {key}. {e}")
 
-    def put(self, key: list[str], addr: list[list[int]],
-            size: list[list[int]]):
+    def put(self, key: list[str], addr: list[list[int]], size: list[list[int]]):
         try:
-            res = self.store.batch_put_from_layers(key, addr, size,
-                                                   MmcDirect.COPY_L2G.value)
+            res = self.store.batch_put_from_layers(key, addr, size, MmcDirect.COPY_L2G.value)
             for value in res:
                 if value != 0:
                     logger.error(f"Failed to get key {key},res:{res}")