23 changes: 22 additions & 1 deletion lightllm/common/basemodel/basemodel.py
@@ -74,6 +74,10 @@ def __init__(self, kvargs):
self.quant_type = kvargs.get("quant_type", "none")
self.quant_cfg_path = kvargs.get("quant_cfg", None)
self.mem_fraction = kvargs.get("mem_fraction", 0.9)
self.enable_hiradix_cache = kvargs.get("use_hiradix_cache", False)
self.hiradix_cache_gpu = kvargs.get("hiradix_cache_gpu", False)
self.hiradix_cache_token_num = kvargs.get("hiradix_cache_token_num", None)
self.radix_lock = kvargs.get("radix_lock", None)
self.tp_world_size_ = get_dp_world_size()
self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode

@@ -162,14 +166,31 @@ def _init_weights(self):

def _init_mem_manager(self):
assert self.config["num_attention_heads"] % self.tp_world_size_ == 0

max_total_token_num = self.max_total_token_num - self.hiradix_cache_token_num if self.hiradix_cache_gpu else self.max_total_token_num

self.mem_manager = MemoryManager(
-self.max_total_token_num,
+max_total_token_num,
dtype=self.data_type,
head_num=self.config["num_attention_heads"] // self.tp_world_size_,
head_dim=self.config["n_embed"] // self.config["num_attention_heads"],
layer_num=self.config["n_layer"],
mem_fraction=self.mem_fraction,
)

if self.enable_hiradix_cache:
from lightllm.common.radixmem_buffer import get_shared_data, MemPropties
from lightllm.common.radixmem_manager import build_radix_manager
mem_propties = MemPropties(
self.hiradix_cache_token_num,
dtype=self.data_type,
head_num=self.config["num_attention_heads"] // self.tp_world_size_,
head_dim=self.config["n_embed"] // self.config["num_attention_heads"],
layer_num=self.config["n_layer"]
)
self.radix_manager = build_radix_manager(mem_propties, self.hiradix_cache_gpu, self.radix_lock)
self.mem_propties = mem_propties
self.shared_mem_data = get_shared_data()
return

def _init_kv_move_buffer(self):
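For orientation, the four new kvargs wired in above control the hierarchical radix cache: `use_hiradix_cache` enables it, `hiradix_cache_gpu` chooses whether the radix KV buffer sits on the GPU (in which case `_init_mem_manager` carves `hiradix_cache_token_num` slots out of the regular pool), and `radix_lock` is the cross-process lock handed to `build_radix_manager`. A minimal sketch of how a caller might assemble these kvargs follows; only the hiradix keys are taken from this diff, everything else (values, the other keys, the lock type) is an illustrative assumption.

```python
# Hypothetical kvargs assembly for the new hierarchical radix cache options.
# Only the hiradix-related keys come from this diff; the other keys, the
# values, and the choice of a multiprocessing lock are illustrative guesses.
import torch.multiprocessing as mp

radix_lock = mp.Lock()  # assumed: one lock shared by the worker processes

model_kvargs = {
    "quant_type": "none",
    "mem_fraction": 0.9,
    # new options read in the model __init__ above
    "use_hiradix_cache": True,          # stored as self.enable_hiradix_cache
    "hiradix_cache_gpu": False,         # False -> radix KV buffer in CPU shared memory
    "hiradix_cache_token_num": 16384,   # token slots reserved for the radix buffer
    "radix_lock": radix_lock,           # forwarded to build_radix_manager
}
```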
184 changes: 184 additions & 0 deletions lightllm/common/radixmem_buffer.py
@@ -0,0 +1,184 @@

import torch
from dataclasses import dataclass
import torch.multiprocessing as mp
from lightllm.utils.log_utils import init_logger
from typing import List, Union
from lightllm.utils.envs_utils import get_unique_server_name
from lightllm.utils.dist_utils import get_current_rank_in_node
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
from multiprocessing.managers import DictProxy, ListProxy
from multiprocessing import Manager


logger = init_logger(__name__)

@dataclass
class SharedRadixMemoryData:
kv_buffer: torch.Tensor
mem_state: torch.Tensor
req_mem_index: DictProxy
lru_queue: ListProxy

@dataclass
class MemPropties:
size: int
dtype: torch.dtype
head_num: int
head_dim: int
layer_num: int

shared_mem_data: SharedRadixMemoryData = None


def init_shared_data(mem_propties: MemPropties, device="cuda"):
size, dtype, head_num, head_dim, layer_num = mem_propties.size, mem_propties.dtype, \
mem_propties.head_num, mem_propties.head_dim, mem_propties.layer_num
global shared_mem_data

if device == "cuda":
kv_buffer = torch.empty(
(layer_num, size, head_num, head_dim),
dtype=dtype,
device="cuda"
)
else:
kv_buffer = torch.empty(
(layer_num, size, head_num, head_dim),
dtype=dtype,
device="cpu"
).share_memory_()

mem_state = torch.arange(size, dtype=torch.int32).share_memory_()
manager = Manager()
req_mem_index = manager.dict()
lru_queue = manager.list()

shared_mem_data = SharedRadixMemoryData(
kv_buffer=kv_buffer,
mem_state=mem_state,
req_mem_index=req_mem_index,
lru_queue=lru_queue
)

def get_shared_data() -> SharedRadixMemoryData:
"""Get the shared memory data."""
global shared_mem_data
if shared_mem_data is None:
raise RuntimeError("Shared memory data has not been initialized. Call init_shared_data first.")
return shared_mem_data

class RadixMemoryBuffer:
def __init__(self, mem_propties: MemPropties, shared_data: SharedRadixMemoryData = None, lock: mp.Lock = None, device="cuda",
rank_in_node=None):
size, dtype, head_num, head_dim, layer_num = mem_propties.size, mem_propties.dtype, \
mem_propties.head_num, mem_propties.head_dim, mem_propties.layer_num

self.kv_buffer = shared_data.kv_buffer
self.mem_state = shared_data.mem_state
self.req_mem_index = shared_data.req_mem_index
self.lock = lock if lock is not None else mp.Lock()

#TODO profile size
self.size = size  # number of token slots
self.head_num = head_num
self.head_dim = head_dim
self.layer_num = layer_num
self.dtype = dtype

can_use_mem_size = self.size
mark_start = 0
mark_end = self.size
rank_in_node = rank_in_node if rank_in_node is not None else get_current_rank_in_node()
self.rank_in_node = rank_in_node
self.can_use_mem_size = SharedInt(
f"{get_unique_server_name()}_radix_mem_manger_can_use_token_num_{rank_in_node}"
)
self.can_use_mem_size.set_value(can_use_mem_size)
self.mark_start = SharedInt(
f"{get_unique_server_name()}_radix_mem_manger_mark_start_{rank_in_node}"
)
self.mark_start.set_value(mark_start)

self.mark_end = SharedInt(
f"{get_unique_server_name()}_radix_mem_manger_mark_end_{rank_in_node}"
)
self.mark_end.set_value(mark_end)
logger.info(f"create {get_unique_server_name()}_radix_mem_manger_can_use_token_num_{rank_in_node}")

def _free(self, free_index: Union[torch.Tensor, List[int]]):
"""_summary_

Args:
free_index (torch.Tensor): _description_
"""
end = self.mark_start.get_value()
start = end - len(free_index)
assert start >= 0, f"invalid free: mark_start {end} is smaller than free len {len(free_index)}"

if isinstance(free_index, list):
self.mem_state.numpy()[start:end] = free_index
else:
# the GPU-to-CPU copy is a blocking operation within the stream
self.mem_state[start:end] = free_index
Comment on lines +119 to +123

medium

There are a few areas for improvement in this new file for better maintainability and clarity:

  1. Placeholder Docstrings: The docstring for the _free method contains placeholder text like _summary_ and _description_. These should be filled out to properly document the function's purpose, arguments, and behavior.
  2. Chinese Comments: There are comments in Chinese (e.g., line 121). For consistency and to make the code accessible to a wider audience, it's best to use English for all comments and documentation.
  3. TODOs: A TODO comment exists on line 82. It's good practice to either address these during development or create a ticket to track them for future work.


self.mark_start.set_value(end - len(free_index))

self.can_use_mem_size.set_value(self.can_use_mem_size.get_value() + len(free_index))

if self.can_use_mem_size.get_value() == len(self.mem_state):
logger.debug(f"freed all gpu mem size {self.can_use_mem_size.get_value()}")

return

def free_req_index(self, req_id: int):
"""Free the memory index for a specific request ID."""
with self.lock:
if req_id not in self.req_mem_index:
logger.warning(f"Request ID {req_id} not found in memory index.")
return
index = self.req_mem_index[req_id]
self._free(index)
logger.info(f"Freed memory index for request {req_id} size {len(index)}, "
f"left size {self.can_use_mem_size.get_value()}")
del self.req_mem_index[req_id]

def alloc(self, need_size) -> torch.Tensor:
if need_size > self.mark_end.get_value() - self.mark_start.get_value():
logger.error(
f"warn no enough cache need_size {need_size} "
f"left_size {self.can_use_mem_size.get_value()}"
)
raise RuntimeError(f"Not enough memory to allocate {need_size} tokens.")

start = self.mark_start.get_value()
end = start + need_size
ans = self.mem_state[start:end]
self.mark_start.set_value(start + need_size)

self.can_use_mem_size.set_value(self.can_use_mem_size.get_value() - need_size)
return ans

def set_req_mem_index(self, req_id: int, index: List[int]):
"""Set the memory index for a specific request ID."""
with self.lock:
if req_id in self.req_mem_index:
logger.info(f"Request ID {req_id} already exists. "
f"Overwriting index {self.req_mem_index[req_id]} with {index}.")
self.req_mem_index[req_id] = index
logger.info(f"radix mem buffer insert req {req_id}, current disk work num {self._get_current_work_num()}")

def get_req_mem_index(self, req_id: int) -> List[int]:
"""Get the memory index for a specific request ID."""
with self.lock:
if req_id not in self.req_mem_index:
logger.warning(f"Request ID {req_id} not found. Returning empty list.")
return []
return self.req_mem_index[req_id]

def get_kv_buffer(self, index) -> torch.Tensor:
with self.lock:
return self.kv_buffer[:, index, :, :]

def _get_current_work_num(self) -> int:
return len(self.req_mem_index)
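The new radixmem_buffer.py keeps the radix KV cache in one flat `kv_buffer` of shape `(layer_num, size, head_num, head_dim)` and hands out token slots as a moving window `[mark_start, mark_end)` over `mem_state`, with `SharedInt` counters so all ranks share the same free-space accounting. A single-process sketch of that slot lifecycle is below; the sizes, the `req_id` and the CPU device choice are made up, and it assumes the lightllm server environment required by `SharedInt` / `get_unique_server_name` is already initialized.

```python
# Single-process sketch of the RadixMemoryBuffer slot lifecycle:
# alloc -> record per-request indices -> read KV view -> free.
# Sizes and req_id are illustrative; SharedInt needs the lightllm server env.
import torch
import torch.multiprocessing as mp
from lightllm.common.radixmem_buffer import (
    MemPropties, RadixMemoryBuffer, init_shared_data, get_shared_data,
)

props = MemPropties(size=1024, dtype=torch.float16, head_num=8, head_dim=128, layer_num=2)
init_shared_data(props, device="cpu")      # CPU path: kv_buffer placed in shared memory

buf = RadixMemoryBuffer(props, shared_data=get_shared_data(), lock=mp.Lock(),
                        device="cpu", rank_in_node=0)

idx = buf.alloc(64)                        # take 64 slots out of [mark_start, mark_end)
buf.set_req_mem_index(req_id=7, index=idx.tolist())
kv = buf.get_kv_buffer(idx)                # view of shape (layer_num, 64, head_num, head_dim)
buf.free_req_index(7)                      # returns the slots and drops the mapping
```

Note that allocation is a simple bump of `mark_start`; freed indices are written back just below it, so slots are recycled in LIFO order rather than through a free list.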
150 changes: 150 additions & 0 deletions lightllm/common/radixmem_manager.py
@@ -0,0 +1,150 @@
import torch
import time
import xxhash
import numpy as np
from typing import List, Dict, Tuple, Optional
import torch.multiprocessing as mp
from collections import OrderedDict

from .radixmem_buffer import MemPropties, init_shared_data, get_shared_data
from .radixmem_buffer import SharedRadixMemoryData, RadixMemoryBuffer

from lightllm.utils.log_utils import init_logger
logger = init_logger(__name__)

class RadixBufferManager:

def __init__(self,
radix_buffer: RadixMemoryBuffer = None,
radix_mem_data: SharedRadixMemoryData = None,
lock: Optional[mp.Lock] = None,
max_entries: int = 10000,
chunk_size: int = 64
):
self.chunk_size = chunk_size
self.max_entries = max_entries
self.radix_buffer = radix_buffer
self.lru_queue = radix_mem_data.lru_queue

self.lock = lock if lock is not None else mp.Lock()

def _compute_hash(self, tokens: List[int]) -> List[Tuple[int, List[int]]]:
chunks = []
hsum = xxhash.xxh3_64()
cumulative_tokens = []

for i in range(0, len(tokens), self.chunk_size):
chunk = tokens[i:i + self.chunk_size]
cumulative_tokens.extend(chunk)

chunk_np = np.array(chunk, dtype=np.uint32)
hsum.update(chunk_np.tobytes())

current_hash = hsum.intdigest()
chunks.append((current_hash, cumulative_tokens.copy()))

return chunks

def write(self, tokens: List[int], values: torch.Tensor, start_pos: int=0) -> None:
with self.lock:
index = start_pos // self.chunk_size
chunks = self._compute_hash(tokens)

values = values[index * self.chunk_size:]
chunks = chunks[index:]
for i, (hash_val, _) in enumerate(chunks):
if hash_val not in self.radix_buffer.req_mem_index:
self.radix_buffer.req_mem_index[hash_val] = values[i * self.chunk_size : (i + 1) * self.chunk_size]
self._update_lru_state(hash_val)

def _update_lru_state(self, hash_val: int):
if hash_val in self.lru_queue:
self.lru_queue.remove(hash_val)
self.lru_queue.append(hash_val)

while len(self.lru_queue) > self.max_entries:
self.lru_queue.pop(0)

def _free_space(self, required_size: int) -> bool:
current_free = self.radix_buffer.can_use_mem_size.get_value()

if current_free >= required_size:
return True

need_to_free = required_size - current_free
freed_size = 0

while freed_size < need_to_free and len(self.lru_queue) > 0:
evict_size = self._evict_lru()
freed_size += evict_size

final_free = self.radix_buffer.can_use_mem_size.get_value()
return final_free >= required_size

def alloc(self, required_size: int) -> torch.Tensor:
with self.lock:
self._free_space(required_size)
ans = self.radix_buffer.alloc(required_size)
return ans

def _evict_lru(self):
if not self.lru_queue:
return 0
oldest_hash = self.lru_queue[0]

evict_size = 0
if oldest_hash in self.radix_buffer.req_mem_index:
indices = self.radix_buffer.req_mem_index[oldest_hash]
evict_size += len(indices)
self.radix_buffer._free(indices)
del self.radix_buffer.req_mem_index[oldest_hash]

self.lru_queue.pop(0)
return evict_size

def query_cache(self, tokens: List[int]) -> Tuple[int, List[int]]:
with self.lock:
chunks = self._compute_hash(tokens)
if not chunks:
return 0, []

max_hit = 0
mem_index = []
for hash_val, _ in chunks:
if hash_val in self.radix_buffer.req_mem_index:
index_val = self.radix_buffer.req_mem_index[hash_val]
mem_index.extend(index_val)
max_hit += len(index_val)
else:
break
return max_hit, mem_index

def clear(self):
with self.lock:
self.radix_buffer.req_mem_index.clear()
self.lru_queue[:] = []

def build_radix_manager(mem_propties: MemPropties,
use_gpu: bool,
radix_lock) -> RadixBufferManager:
device = "cuda" if use_gpu else "cpu"

init_shared_data(
mem_propties=mem_propties,
device=device,
)

radix_mem_buffer = RadixMemoryBuffer(
mem_propties=mem_propties,
shared_data=get_shared_data(),
lock=radix_lock,
device=device,
)

radix_manager = RadixBufferManager(
radix_buffer=radix_mem_buffer,
radix_mem_data=get_shared_data(),
lock=radix_lock,
)

return radix_manager
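radixmem_manager.py layers a prefix cache on top of the buffer: tokens are hashed in `chunk_size` pieces with a cumulative xxhash3 digest, `req_mem_index` maps each prefix hash to its token slot indices, and a shared LRU list drives eviction when `alloc` runs short of space. An illustrative end-to-end flow is sketched below; the token ids, sizes and lock are placeholders, and as above it assumes the lightllm server environment is already configured.

```python
# Illustrative write/query flow for the radix prefix cache added in this PR.
# Token ids, sizes and the lock are placeholders, not values from the diff.
import torch
import torch.multiprocessing as mp
from lightllm.common.radixmem_buffer import MemPropties
from lightllm.common.radixmem_manager import build_radix_manager

props = MemPropties(size=4096, dtype=torch.float16, head_num=8, head_dim=128, layer_num=2)
manager = build_radix_manager(props, use_gpu=False, radix_lock=mp.Lock())

tokens = list(range(128))                  # two chunks at the default chunk_size of 64
slots = manager.alloc(len(tokens))         # evicts LRU entries first if space is short
manager.write(tokens, slots)               # one prefix-hash entry per chunk

hit_len, mem_index = manager.query_cache(tokens)
# hit_len == 128 and mem_index lists the matching slot indices on a full prefix hit
```

Because the hashes are cumulative, query_cache stops at the first missing chunk, so a partial prefix match returns only the leading chunks.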