[0.13.0][cherry-pick][bugfix] Synchronize memcache adaptation on A2 #5842
Signed-off-by: 房建伟 <fangjianwei@fangjianweideMacBook-Air.local>
Code Review
This pull request synchronizes memcache adaptations for A2 hardware. The changes introduce device-specific logic for A2 in memcache_backend.py, improve the robustness of the token-allocation calculation in pool_scheduler.py, and fix a path-parsing issue in pool_worker.py. My review has two main points: first, I identified significant code duplication in memcache_backend.py and suggested a refactoring to improve maintainability; second, I proposed a more concise, idiomatic way to calculate need_to_allocate in pool_scheduler.py. Both are high-severity suggestions aimed at improving code quality and correctness.
    soc_version = get_ascend_device_type()
    if soc_version in {AscendDeviceType.A2}:
        import torch
        from vllm.distributed import get_world_group
        tmp_tensor = torch.zeros(1, device="npu")
        output_tensor_list = [
            torch.empty_like(tmp_tensor)
            for _ in range(torch.distributed.get_world_size())
        ]
        torch.distributed.all_gather(
            output_tensor_list,
            tmp_tensor,
            group=get_world_group().device_group)
        self.rank = parallel_config.rank
        self.store = DistributedObjectStore()
        res = self.store.init(self.rank)
        assert res == 0
    else:
        self.rank = parallel_config.rank
        self.store = DistributedObjectStore()
        res = self.store.init(self.rank)
        assert res == 0
There is significant code duplication between the if and else blocks. The initialization of self.rank and self.store is identical in both branches. This can be refactored by moving the common code outside the conditional block to improve maintainability and reduce redundancy.
    soc_version = get_ascend_device_type()
    if soc_version in {AscendDeviceType.A2}:
        import torch
        from vllm.distributed import get_world_group
        tmp_tensor = torch.zeros(1, device="npu")
        output_tensor_list = [
            torch.empty_like(tmp_tensor)
            for _ in range(torch.distributed.get_world_size())
        ]
        torch.distributed.all_gather(
            output_tensor_list,
            tmp_tensor,
            group=get_world_group().device_group)
    self.rank = parallel_config.rank
    self.store = DistributedObjectStore()
    res = self.store.init(self.rank)
    assert res == 0

In pool_scheduler.py:

    if num_external_hit_tokens < num_computed_tokens:
        need_to_allocate = 0
    else:
        need_to_allocate = num_external_hit_tokens - num_computed_tokens
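The more concise form suggested in the review can be sketched as follows (the helper function name is hypothetical; in the PR this is a bare expression inside the scheduler):

```python
def tokens_to_allocate(num_external_hit_tokens: int,
                       num_computed_tokens: int) -> int:
    # Equivalent to the if/else above: tokens still needed beyond what the
    # local prefix cache already computed, clamped at zero when the local
    # cache covers more than the external hit.
    return max(num_external_hit_tokens - num_computed_tokens, 0)
```

This collapses the branch into a single expression while preserving the clamp-at-zero behavior.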
What this PR does / why we need it?
When running memcache in an A2 environment, logic for registering memory needs to be added. Additionally, on A2 there is a link-establishment conflict between memcache and HCCS during initialization, so the link is established in advance (the dummy all_gather in this change) before the object store is initialized.
pick-from: #5601