diff --git a/vllm_ascend/distributed/kvpool/backend/memcache_backend.py b/vllm_ascend/distributed/kvpool/backend/memcache_backend.py
index d6d98d912fd..fcf9120f946 100644
--- a/vllm_ascend/distributed/kvpool/backend/memcache_backend.py
+++ b/vllm_ascend/distributed/kvpool/backend/memcache_backend.py
@@ -6,6 +6,7 @@
 from vllm.logger import logger
 
 from vllm_ascend.distributed.kvpool.backend.backend import Backend
+from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
 
 
 class MmcDirect(Enum):
@@ -26,10 +27,28 @@ def __init__(self, parallel_config: ParallelConfig):
                 "https://gitee.com/ascend/memfabric_hybrid "  # noqa: E501
                 "to run vLLM with MemcacheConnector.") from e
         try:
-            self.rank = parallel_config.rank
-            self.store = DistributedObjectStore()
-            res = self.store.init(self.rank)
-            assert res == 0
+            soc_version = get_ascend_device_type()
+            if soc_version in {AscendDeviceType.A2}:
+                import torch
+                from vllm.distributed import get_world_group
+                tmp_tensor = torch.zeros(1, device="npu")
+                output_tensor_list = [
+                    torch.empty_like(tmp_tensor)
+                    for _ in range(torch.distributed.get_world_size())
+                ]
+                torch.distributed.all_gather(
+                    output_tensor_list,
+                    tmp_tensor,
+                    group=get_world_group().device_group)
+                self.rank = parallel_config.rank
+                self.store = DistributedObjectStore()
+                res = self.store.init(self.rank)
+                assert res == 0
+            else:
+                self.rank = parallel_config.rank
+                self.store = DistributedObjectStore()
+                res = self.store.init(self.rank)
+                assert res == 0
         except ValueError as e:
             logger.error("Configuration loading failed: %s", e)
             raise
@@ -43,7 +62,12 @@ def set_device(self):
         torch.npu.set_device(device)
 
     def register_buffer(self, ptrs: list[int], sizes: list[int]):
-        pass
+        soc_version = get_ascend_device_type()
+        if soc_version in {AscendDeviceType.A2}:
+            for ptr, size in zip(ptrs, sizes):
+                self.store.register_buffer(ptr, size)
+        else:
+            pass
 
     def exists(self, keys: list[str]) -> list[int]:
         return self.store.batch_is_exist(keys)
diff --git a/vllm_ascend/distributed/kvpool/pool_scheduler.py b/vllm_ascend/distributed/kvpool/pool_scheduler.py
index 30199fd9b61..39cdbda0a08 100644
--- a/vllm_ascend/distributed/kvpool/pool_scheduler.py
+++ b/vllm_ascend/distributed/kvpool/pool_scheduler.py
@@ -82,7 +82,10 @@ def get_num_new_matched_tokens(
         if num_external_hit_tokens == request.num_tokens:
             num_external_hit_tokens -= 1
 
-        need_to_allocate = num_external_hit_tokens - num_computed_tokens
+        if num_external_hit_tokens < num_computed_tokens:
+            need_to_allocate = 0
+        else:
+            need_to_allocate = num_external_hit_tokens - num_computed_tokens
 
         logger.info(
             "Reqid: %s, Total tokens %d, kvpool hit tokens: %d, need to load: %d",
diff --git a/vllm_ascend/distributed/kvpool/pool_worker.py b/vllm_ascend/distributed/kvpool/pool_worker.py
index c59078948c1..92543aec50a 100644
--- a/vllm_ascend/distributed/kvpool/pool_worker.py
+++ b/vllm_ascend/distributed/kvpool/pool_worker.py
@@ -87,7 +87,7 @@ def __init__(
         self.put_step = 1
 
         self.metadata = KeyMetadata(
-            model_config.model.split('/')[-1],
+            model_config.model.rstrip('/').split('/')[-1],
             self.head_or_tp_rank,
             self.pcp_rank,
             self.dcp_rank,