Asynchicache #977
Open: jinbiaoyu wants to merge 13 commits into main from asynchicache
Commits (13):
e6265d1 initial hicache support not finished (jayfeather9)
382ab2f add debug outputs (jayfeather9)
f4bd76e support hicache (jinbiaoyu)
a0ae71d asynchronous hi radix cache (jinbiaoyu)
9141e07 fix some (jinbiaoyu)
e3f4955 update and add radix manager (jinbiaoyu)
1a7e7d3 fix error (jinbiaoyu)
b55ca74 update package dir (jinbiaoyu)
4f29672 fix format (jinbiaoyu)
ab5d933 fix too long (jinbiaoyu)
caa2d6c update & fix format (jinbiaoyu)
66fd1ab update (jinbiaoyu)
b58dc31 release (jinbiaoyu)
New file (+184 lines): radixmem_buffer.py
import torch
from dataclasses import dataclass
import torch.multiprocessing as mp
from lightllm.utils.log_utils import init_logger
from typing import List, Union
from lightllm.utils.envs_utils import get_unique_server_name
from lightllm.utils.dist_utils import get_current_rank_in_node
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
from multiprocessing.managers import DictProxy, ListProxy
from multiprocessing import Manager

logger = init_logger(__name__)


@dataclass
class SharedRadixMemoryData:
    kv_buffer: torch.Tensor
    mem_state: torch.Tensor
    req_mem_index: DictProxy
    lru_queue: ListProxy


@dataclass
class MemPropties:
    size: int
    dtype: torch.dtype
    head_num: int
    head_dim: int
    layer_num: int


shared_mem_data: SharedRadixMemoryData = None


def init_shared_data(mem_propties: MemPropties, device="cuda"):
    size, dtype, head_num, head_dim, layer_num = mem_propties.size, mem_propties.dtype, \
        mem_propties.head_num, mem_propties.head_dim, mem_propties.layer_num
    global shared_mem_data

    if device == "cuda":
        kv_buffer = torch.empty(
            (layer_num, size, head_num, head_dim),
            dtype=dtype,
            device="cuda"
        )
    else:
        kv_buffer = torch.empty(
            (layer_num, size, head_num, head_dim),
            dtype=dtype,
            device="cpu"
        ).share_memory_()

    mem_state = torch.arange(size, dtype=torch.int32).share_memory_()
    manager = Manager()
    req_mem_index = manager.dict()
    lru_queue = manager.list()

    shared_mem_data = SharedRadixMemoryData(
        kv_buffer=kv_buffer,
        mem_state=mem_state,
        req_mem_index=req_mem_index,
        lru_queue=lru_queue
    )


def get_shared_data() -> SharedRadixMemoryData:
    """Get the shared memory data."""
    global shared_mem_data
    if shared_mem_data is None:
        raise RuntimeError("Shared memory data has not been initialized. Call init_shared_data first.")
    return shared_mem_data


class RadixMemoryBuffer:
    def __init__(self, mem_propties: MemPropties, shared_data: SharedRadixMemoryData = None,
                 lock: mp.Lock = None, device="cuda", rank_in_node=None):
        size, dtype, head_num, head_dim, layer_num = mem_propties.size, mem_propties.dtype, \
            mem_propties.head_num, mem_propties.head_dim, mem_propties.layer_num

        self.kv_buffer = shared_data.kv_buffer
        self.mem_state = shared_data.mem_state
        self.req_mem_index = shared_data.req_mem_index
        self.lock = lock if lock is not None else mp.Lock()

        #TODO profile size
        self.size = size  # number of token slots
        self.head_num = head_num
        self.head_dim = head_dim
        self.layer_num = layer_num
        self.dtype = dtype

        can_use_mem_size = self.size
        mark_start = 0
        mark_end = self.size
        rank_in_node = rank_in_node if rank_in_node is not None else get_current_rank_in_node()
        self.rank_in_node = rank_in_node
        self.can_use_mem_size = SharedInt(
            f"{get_unique_server_name()}_radix_mem_manger_can_use_token_num_{rank_in_node}"
        )
        self.can_use_mem_size.set_value(can_use_mem_size)
        self.mark_start = SharedInt(
            f"{get_unique_server_name()}_radix_mem_manger_mark_start_{rank_in_node}"
        )
        self.mark_start.set_value(mark_start)

        self.mark_end = SharedInt(
            f"{get_unique_server_name()}_radix_mem_manger_mark_end_{rank_in_node}"
        )
        self.mark_end.set_value(mark_end)
        logger.info(f"create {get_unique_server_name()}_radix_mem_manger_can_use_token_num_{rank_in_node}")

    def _free(self, free_index: Union[torch.Tensor, List[int]]):
        """_summary_

        Args:
            free_index (torch.Tensor): _description_
        """
        end = self.mark_start.get_value()
        start = end - len(free_index)
        assert start >= 0, f"error free state start: {start} free len {len(free_index)}"

        if isinstance(free_index, list):
            self.mem_state.numpy()[start:end] = free_index
        else:
            # the GPU-to-CPU copy blocks within the stream
            self.mem_state[start:end] = free_index

        self.mark_start.set_value(start)

        self.can_use_mem_size.set_value(self.can_use_mem_size.get_value() + len(free_index))

        if self.can_use_mem_size.get_value() == len(self.mem_state):
            logger.debug(f"freed all gpu mem size {self.can_use_mem_size.get_value()}")

        return

    def free_req_index(self, req_id: int):
        """Free the memory index for a specific request ID."""
        with self.lock:
            if req_id not in self.req_mem_index:
                logger.warning(f"Request ID {req_id} not found in memory index.")
                return
            index = self.req_mem_index[req_id]
            self._free(index)
            logger.info(f"Freed memory index for request {req_id} size {len(index)}, "
                        f"left size {self.can_use_mem_size.get_value()}")
            del self.req_mem_index[req_id]

    def alloc(self, need_size) -> torch.Tensor:
        if need_size > self.mark_end.get_value() - self.mark_start.get_value():
            logger.error(
                f"not enough cache: need_size {need_size} "
                f"left_size {self.can_use_mem_size.get_value()}"
            )
            raise RuntimeError(f"Not enough memory to allocate {need_size} tokens.")

        start = self.mark_start.get_value()
        end = start + need_size
        ans = self.mem_state[start:end]
        self.mark_start.set_value(start + need_size)

        self.can_use_mem_size.set_value(self.can_use_mem_size.get_value() - need_size)
        return ans

    def set_req_mem_index(self, req_id: int, index: List[int]):
        """Set the memory index for a specific request ID."""
        with self.lock:
            if req_id in self.req_mem_index:
                logger.info(f"Request ID {req_id} already exists. "
                            f"Overwriting index {self.req_mem_index[req_id]} with {index}.")
            self.req_mem_index[req_id] = index
            logger.info(f"radix mem buffer insert req {req_id}, current disk work num {self._get_current_work_num()}")

    def get_req_mem_index(self, req_id: int) -> List[int]:
        """Get the memory index for a specific request ID."""
        with self.lock:
            if req_id not in self.req_mem_index:
                logger.warning(f"Request ID {req_id} not found. Returning empty list.")
                return []
            return self.req_mem_index[req_id]

    def get_kv_buffer(self, index) -> torch.Tensor:
        with self.lock:
            return self.kv_buffer[:, index, :, :]

    def _get_current_work_num(self) -> int:
        return len(self.req_mem_index)
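
A minimal usage sketch of the buffer on its own, for reviewers who want to poke at it in isolation. It assumes lightllm is installed (SharedInt and get_unique_server_name come from it), a CPU-only host, and illustrative MemPropties values; the bare import path is hypothetical, since the PR shows only the module name.

import torch
import torch.multiprocessing as mp

# Import path is an assumption; only the module name radixmem_buffer is visible in the PR.
from radixmem_buffer import MemPropties, RadixMemoryBuffer, init_shared_data, get_shared_data

props = MemPropties(size=1024, dtype=torch.float16, head_num=8, head_dim=128, layer_num=16)
init_shared_data(props, device="cpu")  # CPU path shares kv_buffer via share_memory_()

buf = RadixMemoryBuffer(
    mem_propties=props,
    shared_data=get_shared_data(),
    lock=mp.Lock(),
    device="cpu",
    rank_in_node=0,  # passed explicitly so no dist context is required
)

idx = buf.alloc(64)                                 # reserve 64 token slots
buf.set_req_mem_index(req_id=1, index=idx.tolist())
kv = buf.get_kv_buffer(buf.get_req_mem_index(1))    # (layer_num, 64, head_num, head_dim) view
buf.free_req_index(1)                               # slots return to the free pool

Note that what free_req_index recycles is the slot indices; the KV data itself lives in kv_buffer at those slot positions.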
New file (+150 lines): the radix buffer manager module
import torch
import time
import xxhash
import numpy as np
from typing import List, Dict, Tuple, Optional
import torch.multiprocessing as mp
from collections import OrderedDict

from .radixmem_buffer import MemPropties, init_shared_data, get_shared_data
from .radixmem_buffer import SharedRadixMemoryData, RadixMemoryBuffer

from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


class RadixBufferManager:

    def __init__(self,
                 radix_buffer: RadixMemoryBuffer = None,
                 radix_mem_data: SharedRadixMemoryData = None,
                 lock: Optional[mp.Lock] = None,
                 max_entries: int = 10000,
                 chunk_size: int = 64):
        self.chunk_size = chunk_size
        self.max_entries = max_entries
        self.radix_buffer = radix_buffer
        self.lru_queue = radix_mem_data.lru_queue

        self.lock = lock if lock is not None else mp.Lock()

    def _compute_hash(self, tokens: List[int]) -> List[Tuple[int, List[int]]]:
        # running hash: each chunk's digest covers the whole token prefix up to that chunk
        chunks = []
        hsum = xxhash.xxh3_64()
        cumulative_tokens = []

        for i in range(0, len(tokens), self.chunk_size):
            chunk = tokens[i:i + self.chunk_size]
            cumulative_tokens.extend(chunk)

            chunk_np = np.array(chunk, dtype=np.uint32)
            hsum.update(chunk_np.tobytes())

            current_hash = hsum.intdigest()
            chunks.append((current_hash, cumulative_tokens.copy()))

        return chunks

    def write(self, tokens: List[int], values: torch.Tensor, start_pos: int = 0) -> None:
        with self.lock:
            index = start_pos // self.chunk_size
            chunks = self._compute_hash(tokens)

            values = values[index * self.chunk_size:]
            chunks = chunks[index:]
            for i, (hash_val, _) in enumerate(chunks):
                if hash_val not in self.radix_buffer.req_mem_index:
                    self.radix_buffer.req_mem_index[hash_val] = values[i * self.chunk_size : (i + 1) * self.chunk_size]
                self._update_lru_state(hash_val)

    def _update_lru_state(self, hash_val: int):
        if hash_val in self.lru_queue:
            self.lru_queue.remove(hash_val)
        self.lru_queue.append(hash_val)

        while len(self.lru_queue) > self.max_entries:
            self.lru_queue.pop(0)

    def _free_space(self, required_size: int) -> bool:
        current_free = self.radix_buffer.can_use_mem_size.get_value()

        if current_free >= required_size:
            return True

        need_to_free = required_size - current_free
        freed_size = 0

        while freed_size < need_to_free and len(self.lru_queue) > 0:
            evict_size = self._evict_lru()
            freed_size += evict_size

        final_free = self.radix_buffer.can_use_mem_size.get_value()
        return final_free >= required_size

    def alloc(self, required_size: int) -> torch.Tensor:
        with self.lock:
            # best-effort eviction; radix_buffer.alloc raises if space is still short
            self._free_space(required_size)
            ans = self.radix_buffer.alloc(required_size)
            return ans

    def _evict_lru(self) -> int:
        if not self.lru_queue:
            return 0
        oldest_hash = self.lru_queue[0]

        evict_size = 0
        if oldest_hash in self.radix_buffer.req_mem_index:
            indices = self.radix_buffer.req_mem_index[oldest_hash]
            evict_size += len(indices)
            self.radix_buffer._free(indices)
            del self.radix_buffer.req_mem_index[oldest_hash]

        self.lru_queue.pop(0)
        return evict_size

    def query_cache(self, tokens: List[int]) -> Tuple[int, List[int]]:
        with self.lock:
            chunks = self._compute_hash(tokens)
            if not chunks:
                return 0, []

            max_hit = 0
            mem_index = []
            for hash_val, _ in chunks:
                if hash_val in self.radix_buffer.req_mem_index:
                    index_val = self.radix_buffer.req_mem_index[hash_val]
                    mem_index.extend(index_val)
                    max_hit += len(index_val)
                else:
                    break
            return max_hit, mem_index

    def clear(self):
        with self.lock:
            self.radix_buffer.req_mem_index.clear()
            self.lru_queue[:] = []


def build_radix_manager(mem_propties: MemPropties,
                        use_gpu: bool,
                        radix_lock) -> RadixBufferManager:
    device = "cuda" if use_gpu else "cpu"

    init_shared_data(
        mem_propties=mem_propties,
        device=device,
    )

    radix_mem_buffer = RadixMemoryBuffer(
        mem_propties=mem_propties,
        shared_data=get_shared_data(),
        lock=radix_lock,
        device=device,
    )

    radix_manager = RadixBufferManager(
        radix_buffer=radix_mem_buffer,
        radix_mem_data=get_shared_data(),
        lock=radix_lock,
    )

    return radix_manager
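
A matching sketch of the manager's chunk-hash flow end to end, under the same assumptions: module names are hypothetical (only radixmem_buffer's name is confirmed by the relative imports above), and the sizes are illustrative.

import torch
import torch.multiprocessing as mp

from radixmem_buffer import MemPropties           # module name confirmed by the import above
from radixmem_manager import build_radix_manager  # hypothetical name for this second file

props = MemPropties(size=4096, dtype=torch.float16, head_num=8, head_dim=128, layer_num=16)
manager = build_radix_manager(props, use_gpu=False, radix_lock=mp.Lock())

tokens = list(range(128))           # two 64-token chunks at the default chunk_size
slots = manager.alloc(len(tokens))  # evicts LRU entries first if free space is short

# write() keys each cumulative chunk hash to its slice of the slot indices;
# the KV tensors themselves would be copied into kv_buffer at those slots.
manager.write(tokens, values=slots, start_pos=0)

hit_len, mem_index = manager.query_cache(tokens)
assert hit_len == 128               # an exact replay hits both chunks

Because _compute_hash keeps a running digest, each chunk's hash covers the entire prefix, so query_cache returns the longest cached prefix: a replay that diverges mid-sequence gets a shorter hit_len with only the matching slot indices.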
There are a few areas for improvement in this new file for better maintainability and clarity:
- The _free method contains placeholder text like _summary_ and _description_. These should be filled out to properly document the function's purpose, arguments, and behavior.
- A TODO comment exists on line 82. It's good practice to either address these during development or create a ticket to track them for future work.
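
For the first point, a sketch of what the filled-out docstring could look like (wording is the author's call; body unchanged):

def _free(self, free_index: Union[torch.Tensor, List[int]]):
    """Return token slots to the free pool.

    Writes the given slot indices back into mem_state just below the
    current mark_start watermark, moves mark_start down, and credits
    can_use_mem_size by len(free_index).

    Args:
        free_index (Union[torch.Tensor, List[int]]): slot indices
            previously handed out by alloc().
    """
    # ... body unchanged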