Skip to content
34 changes: 29 additions & 5 deletions vllm_ascend/distributed/kvpool/backend/memcache_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from vllm.logger import logger

from vllm_ascend.distributed.kvpool.backend.backend import Backend
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type


class MmcDirect(Enum):
Expand All @@ -26,10 +27,28 @@ def __init__(self, parallel_config: ParallelConfig):
"https://gitee.com/ascend/memfabric_hybrid " # noqa: E501
"to run vLLM with MemcacheConnector.") from e
try:
self.rank = parallel_config.rank
self.store = DistributedObjectStore()
res = self.store.init(self.rank)
assert res == 0
soc_version = get_ascend_device_type()
if soc_version in {AscendDeviceType.A2}:
import torch
from vllm.distributed import get_world_group
tmp_tensor = torch.zeros(1, device="npu")
output_tensor_list = [
torch.empty_like(tmp_tensor)
for _ in range(torch.distributed.get_world_size())
]
torch.distributed.all_gather(
output_tensor_list,
tmp_tensor,
group=get_world_group().device_group)
self.rank = parallel_config.rank
self.store = DistributedObjectStore()
res = self.store.init(self.rank)
assert res == 0
else:
self.rank = parallel_config.rank
self.store = DistributedObjectStore()
res = self.store.init(self.rank)
assert res == 0
Comment on lines +30 to +51
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There is significant code duplication between the if and else blocks. The lines for initializing self.rank and self.store are identical in both branches. This can be refactored to reduce redundancy and improve maintainability, which is crucial for preventing bugs where a fix is applied to one copy but not the other.

Additionally, `torch` is already imported at the top of the file, so the local `import torch` inside the `if` block is redundant and should be removed.

            soc_version = get_ascend_device_type()
            if soc_version in {AscendDeviceType.A2}:
                from vllm.distributed import get_world_group
                tmp_tensor = torch.zeros(1, device="npu")
                output_tensor_list = [
                    torch.empty_like(tmp_tensor)
                    for _ in range(torch.distributed.get_world_size())
                ]
                torch.distributed.all_gather(
                    output_tensor_list,
                    tmp_tensor,
                    group=get_world_group().device_group)
            self.rank = parallel_config.rank
            self.store = DistributedObjectStore()
            res = self.store.init(self.rank)
            assert res == 0

except ValueError as e:
logger.error("Configuration loading failed: %s", e)
raise
Expand All @@ -43,7 +62,12 @@ def set_device(self):
torch.npu.set_device(device)

def register_buffer(self, ptrs: list[int], sizes: list[int]):
    """Register device memory regions with the backing store.

    Args:
        ptrs: Device pointers (as ints) of the buffers to register.
        sizes: Byte sizes, pairwise-matched with ``ptrs``.

    Buffer registration is only required on A2 hardware; on every other
    SoC type this is a no-op (the previous ``else: pass`` branch was
    redundant and has been dropped).
    """
    soc_version = get_ascend_device_type()
    if soc_version in {AscendDeviceType.A2}:
        for ptr, size in zip(ptrs, sizes):
            self.store.register_buffer(ptr, size)

def exists(self, keys: list[str]) -> list[int]:
    """Query which of *keys* are present in the object store.

    Delegates to the store's batched existence check and returns one
    int flag per key, in the same order as *keys*.
    """
    store = self.store
    return store.batch_is_exist(keys)
Expand Down
5 changes: 4 additions & 1 deletion vllm_ascend/distributed/kvpool/pool_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ def get_num_new_matched_tokens(
if num_external_hit_tokens == request.num_tokens:
num_external_hit_tokens -= 1

need_to_allocate = num_external_hit_tokens - num_computed_tokens
if num_external_hit_tokens < num_computed_tokens:
need_to_allocate = 0
else:
need_to_allocate = num_external_hit_tokens - num_computed_tokens

logger.info(
"Reqid: %s, Total tokens %d, kvpool hit tokens: %d, need to load: %d",
Expand Down
2 changes: 1 addition & 1 deletion vllm_ascend/distributed/kvpool/pool_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(
self.put_step = 1

self.metadata = KeyMetadata(
model_config.model.split('/')[-1],
model_config.model.rstrip('/').split('/')[-1],
self.head_or_tp_rank,
self.pcp_rank,
self.dcp_rank,
Expand Down