diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index ff3290233..2bd5853a2 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -74,6 +74,10 @@ def __init__(self, kvargs): self.quant_type = kvargs.get("quant_type", "none") self.quant_cfg_path = kvargs.get("quant_cfg", None) self.mem_fraction = kvargs.get("mem_fraction", 0.9) + self.enable_hiradix_cache = kvargs.get("use_hiradix_cache", False) + self.hiradix_cache_gpu = kvargs.get("hiradix_cache_gpu", False) + self.hiradix_cache_token_num = kvargs.get("hiradix_cache_token_num", None) + self.radix_lock = kvargs.get("radix_lock", None) self.tp_world_size_ = get_dp_world_size() self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode @@ -162,14 +166,31 @@ def _init_weights(self): def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 + + max_total_token_num = self.max_total_token_num - self.hiradix_cache_token_num if self.hiradix_cache_gpu else self.max_total_token_num + self.mem_manager = MemoryManager( - self.max_total_token_num, + max_total_token_num, dtype=self.data_type, head_num=self.config["num_attention_heads"] // self.tp_world_size_, head_dim=self.config["n_embed"] // self.config["num_attention_heads"], layer_num=self.config["n_layer"], mem_fraction=self.mem_fraction, ) + + if self.enable_hiradix_cache: + from lightllm.common.radixmem_buffer import get_shared_data, MemPropties + from lightllm.common.radixmem_manager import build_radix_manager + mem_propties = MemPropties( + self.hiradix_cache_token_num, + dtype=self.data_type, + head_num=self.config["num_attention_heads"] // self.tp_world_size_, + head_dim=self.config["n_embed"] // self.config["num_attention_heads"], + layer_num=self.config["n_layer"] + ) + self.radix_manager = build_radix_manager(mem_propties, self.hiradix_cache_gpu, self.radix_lock) + self.mem_propties = mem_propties + self.shared_mem_data = get_shared_data() return def _init_kv_move_buffer(self): diff --git a/lightllm/common/radixmem_buffer.py b/lightllm/common/radixmem_buffer.py new file mode 100644 index 000000000..740e02120 --- /dev/null +++ b/lightllm/common/radixmem_buffer.py @@ -0,0 +1,184 @@ + +import torch +from dataclasses import dataclass +import torch.multiprocessing as mp +from lightllm.utils.log_utils import init_logger +from typing import List, Union +from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.utils.dist_utils import get_current_rank_in_node +from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt +from multiprocessing.managers import DictProxy, ListProxy +from multiprocessing import Manager + + +logger = init_logger(__name__) + +@dataclass +class SharedRadixMemoryData: + kv_buffer: torch.Tensor + mem_state: torch.Tensor + req_mem_index: DictProxy + lru_queue: ListProxy + +@dataclass +class MemPropties: + size: int + dtype: torch.dtype + head_num: int + head_dim: int + layer_num: int + +shared_mem_data: SharedRadixMemoryData = None + + +def init_shared_data(mem_propties: MemPropties, device="cuda"): + size, dtype, head_num, head_dim, layer_num = mem_propties.size, mem_propties.dtype, \ + mem_propties.head_num, mem_propties.head_dim, mem_propties.layer_num + global shared_mem_data + + if device == "cuda": + kv_buffer = torch.empty( + (layer_num, size, head_num, head_dim), + dtype=dtype, + device="cuda" + ) + else: + kv_buffer = torch.empty( + (layer_num, size, head_num, head_dim), + dtype=dtype, + device="cpu" 
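+                # share_memory_() below moves the CPU tensor into shared memory so every
+                # worker process maps the same KV buffer; the CUDA branch above skips it and
+                # relies on torch.multiprocessing's CUDA IPC sharing when the tensor is
+                # handed to other processes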
+ ).share_memory_() + + mem_state = torch.arange(size, dtype=torch.int32).share_memory_() + manager = Manager() + req_mem_index = manager.dict() + lru_queue = manager.list() + + shared_mem_data = SharedRadixMemoryData( + kv_buffer=kv_buffer, + mem_state=mem_state, + req_mem_index=req_mem_index, + lru_queue=lru_queue + ) + +def get_shared_data() -> SharedRadixMemoryData: + """Get the shared memory data.""" + global shared_mem_data + if shared_mem_data is None: + raise RuntimeError("Shared memory data has not been initialized. Call init_shared_data first.") + return shared_mem_data + +class RadixMemoryBuffer: + def __init__(self, mem_propties: MemPropties, shared_data: SharedRadixMemoryData = None, lock: mp.Lock = None, device="cuda", + rank_in_node=None): + size, dtype, head_num, head_dim, layer_num = mem_propties.size, mem_propties.dtype, \ + mem_propties.head_num, mem_propties.head_dim, mem_propties.layer_num + + self.kv_buffer = shared_data.kv_buffer + self.mem_state = shared_data.mem_state + self.req_mem_index = shared_data.req_mem_index + self.lock = lock if lock is not None else mp.Lock() + + #TODO profile size + self.size = size # token slot 个数 + self.head_num = head_num + self.head_dim = head_dim + self.layer_num = layer_num + self.dtype = dtype + + can_use_mem_size = self.size + mark_start = 0 + mark_end = self.size + rank_in_node = rank_in_node if rank_in_node is not None else get_current_rank_in_node() + self.rank_in_node = rank_in_node + self.can_use_mem_size = SharedInt( + f"{get_unique_server_name()}_radix_mem_manger_can_use_token_num_{rank_in_node}" + ) + self.can_use_mem_size.set_value(can_use_mem_size) + self.mark_start = SharedInt( + f"{get_unique_server_name()}_radix_mem_manger_mark_start_{rank_in_node}" + ) + self.mark_start.set_value(mark_start) + + self.mark_end = SharedInt( + f"{get_unique_server_name()}_radix_mem_manger_mark_end_{rank_in_node}" + ) + self.mark_end.set_value(mark_end) + logger.info(f"create {get_unique_server_name()}_radix_mem_manger_can_use_token_num_{rank_in_node}") + + def _free(self, free_index: Union[torch.Tensor, List[int]]): + """_summary_ + + Args: + free_index (torch.Tensor): _description_ + """ + end = self.mark_start.get_value() + start = end - len(free_index) + assert start >= 0, f"error free state start: {end} free len {len(free_index)}" + + if isinstance(free_index, list): + self.mem_state.numpy()[start:end] = free_index + else: + # 从 gpu 到 cpu 的拷贝操作是流内阻塞操作 + self.mem_state[start:end] = free_index + + self.mark_start.set_value(end - len(free_index)) + + self.can_use_mem_size.set_value(self.can_use_mem_size.get_value() + len(free_index)) + + if self.can_use_mem_size.get_value() == len(self.mem_state): + logger.debug(f"freed all gpu mem size {self.can_use_mem_size.get_value()}") + + return + + def free_req_index(self, req_id: int): + """Free the memory index for a specific request ID.""" + with self.lock: + if req_id not in self.req_mem_index: + logger.warning(f"Request ID {req_id} not found in memory index.") + return + index = self.req_mem_index[req_id] + self._free(index) + logger.info(f"Freed memory index for request {req_id} size {len(index)}, " + f"left size {self.can_use_mem_size.get_value()}") + del self.req_mem_index[req_id] + + def alloc(self, need_size) -> torch.Tensor: + if need_size > self.mark_end.get_value() - self.mark_start.get_value(): + logger.error( + f"warn no enough cache need_size {need_size} " + f"left_size {self.can_use_mem_size.get_value()}" + ) + raise RuntimeError(f"Not enough memory to allocate {need_size} 
tokens.") + + start = self.mark_start.get_value() + end = start + need_size + ans = self.mem_state[start:end] + self.mark_start.set_value(start + need_size) + + self.can_use_mem_size.set_value(self.can_use_mem_size.get_value() - need_size) + return ans + + def set_req_mem_index(self, req_id: int, index: List[int]): + """Set the memory index for a specific request ID.""" + with self.lock: + if req_id in self.req_mem_index: + logger.info(f"Request ID {req_id} already exists. " + f"Overwriting index {self.req_mem_index[req_id]} with {index}.") + self.req_mem_index[req_id] = index + logger.info(f"radix mem buffer insert req {req_id}, current disk work num {self._get_current_work_num()}") + + def get_req_mem_index(self, req_id: int) -> List[int]: + """Get the memory index for a specific request ID.""" + with self.lock: + if req_id not in self.req_mem_index: + logger.warning(f"Request ID {req_id} not found. Returning empty list.") + return [] + return self.req_mem_index[req_id] + + def get_kv_buffer(self, index) -> torch.Tensor: + with self.lock: + return self.kv_buffer[:, index, :, :] + + def _get_current_work_num(self) -> int: + return len(self.req_mem_index) diff --git a/lightllm/common/radixmem_manager.py b/lightllm/common/radixmem_manager.py new file mode 100644 index 000000000..1c2011c62 --- /dev/null +++ b/lightllm/common/radixmem_manager.py @@ -0,0 +1,150 @@ +import torch +import time +import xxhash +import numpy as np +from typing import List, Dict, Tuple, Optional +import torch.multiprocessing as mp +from collections import OrderedDict + +from .radixmem_buffer import MemPropties, init_shared_data, get_shared_data +from .radixmem_buffer import SharedRadixMemoryData, RadixMemoryBuffer + +from lightllm.utils.log_utils import init_logger +logger = init_logger(__name__) + +class RadixBufferManager: + + def __init__(self, + radix_buffer: RadixMemoryBuffer = None, + radix_mem_data: SharedRadixMemoryData = None, + lock: Optional[mp.Lock] = None, + max_entries: int = 10000, + chunk_size: int = 64 + ): + self.chunk_size = chunk_size + self.max_entries = max_entries + self.radix_buffer = radix_buffer + self.lru_queue = radix_mem_data.lru_queue + + self.lock = lock if lock is not None else mp.Lock() + + def _compute_hash(self, tokens: List[int]) -> List[Tuple[int, List[int]]]: + chunks = [] + hsum = xxhash.xxh3_64() + cumulative_tokens = [] + + for i in range(0, len(tokens), self.chunk_size): + chunk = tokens[i:i + self.chunk_size] + cumulative_tokens.extend(chunk) + + chunk_np = np.array(chunk, dtype=np.uint32) + hsum.update(chunk_np.tobytes()) + + current_hash = hsum.intdigest() + chunks.append((current_hash, cumulative_tokens.copy())) + + return chunks + + def write(self, tokens: List[int], values: torch.Tensor, start_pos: int=0) -> None: + with self.lock: + index = start_pos // self.chunk_size + chunks = self._compute_hash(tokens) + + values = values[index * self.chunk_size:] + chunks = chunks[index:] + for i, (hash_val, _) in enumerate(chunks): + if hash_val not in self.radix_buffer.req_mem_index: + self.radix_buffer.req_mem_index[hash_val] = values[i * self.chunk_size : (i + 1) * self.chunk_size] + self._update_lru_state(hash_val) + + def _update_lru_state(self, hash_val: int): + if hash_val in self.lru_queue: + self.lru_queue.remove(hash_val) + self.lru_queue.append(hash_val) + + while len(self.lru_queue) > self.max_entries: + self.lru_queue.pop(0) + + def _free_space(self, required_size: int) -> bool: + current_free = self.radix_buffer.can_use_mem_size.get_value() + + if current_free >= 
required_size: + return True + + need_to_free = required_size - current_free + freed_size = 0 + + while freed_size < need_to_free and len(self.lru_queue) > 0: + evict_size = self._evict_lru() + freed_size += evict_size + + final_free = self.radix_buffer.can_use_mem_size.get_value() + return final_free >= required_size + + def alloc(self, required_size: int) -> bool: + with self.lock: + self._free_space(required_size) + ans = self.radix_buffer.alloc(required_size) + return ans + + def _evict_lru(self): + if not self.lru_queue: + return + oldest_hash = self.lru_queue[0] + + evict_size = 0 + if oldest_hash in self.radix_buffer.req_mem_index: + indices = self.radix_buffer.req_mem_index[oldest_hash] + evict_size += len(indices) + self.radix_buffer._free(indices) + del self.radix_buffer.req_mem_index[oldest_hash] + + self.lru_queue.pop(0) + return evict_size + + def query_cache(self, tokens: List[int]) -> int: + with self.lock: + chunks = self._compute_hash(tokens) + if not chunks: + return 0, [] + + max_hit = 0 + mem_index = [] + for hash_val, _ in chunks: + if hash_val in self.radix_buffer.req_mem_index: + index_val = self.radix_buffer.req_mem_index[hash_val] + mem_index.extend(index_val) + max_hit += len(index_val) + else: + break + return max_hit, mem_index + + def clear(self): + with self.lock: + self.radix_buffer.req_mem_index.clear() + self.lru_queue[:] = [] + +def build_radix_manager(mem_propties: MemPropties, + use_gpu: bool, + radix_lock) -> RadixBufferManager: + device = "cuda" if use_gpu else "cpu" + + init_shared_data( + mem_propties=mem_propties, + device=device, + ) + + radix_mem_buffer = RadixMemoryBuffer( + mem_propties=mem_propties, + shared_data=get_shared_data(), + lock=radix_lock, + device=device, + ) + + radix_manager = RadixBufferManager( + radix_buffer=radix_mem_buffer, + radix_mem_data=get_shared_data(), + lock=radix_lock, + ) + + return radix_manager \ No newline at end of file diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index 9101cb963..69d673192 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -102,15 +102,29 @@ def _init_mem_manager(self): added_mtp_layer_num = 0 if get_env_start_args().mtp_mode == "deepseekv3": added_mtp_layer_num += get_env_start_args().mtp_step - + + max_total_token_num = self.max_total_token_num - self.hiradix_cache_token_num if self.hiradix_cache_gpu else self.max_total_token_num self.mem_manager = manager_class( - self.max_total_token_num, + max_total_token_num, dtype=self.data_type, head_num=1, head_dim=self.config["kv_lora_rank"] + self.config["qk_rope_head_dim"], layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num, mem_fraction=self.mem_fraction, ) + if self.enable_hiradix_cache: + from lightllm.common.radixmem_buffer import get_shared_data, MemPropties + from lightllm.common.radixmem_manager import build_radix_manager + mem_propties = MemPropties( + self.hiradix_cache_token_num, + dtype=self.data_type, + head_num=1, + head_dim=self.config["kv_lora_rank"] + self.config["qk_rope_head_dim"], + layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num, + ) + self.radix_manager = build_radix_manager(mem_propties, self.hiradix_cache_gpu, self.radix_lock) + self.mem_propties = mem_propties + self.shared_mem_data = get_shared_data() return def _init_weights(self): diff --git a/lightllm/models/qwen2/model.py b/lightllm/models/qwen2/model.py index e3d8de461..47eb87686 100644 --- a/lightllm/models/qwen2/model.py +++ b/lightllm/models/qwen2/model.py 
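The RadixBufferManager added above keys its shared req_mem_index dict on a rolling xxhash over fixed-size token chunks, so a hit on chunk i implies the whole prefix up to that chunk matched and query_cache() can stop at the first missing chunk. A minimal standalone sketch of that hashing scheme (chunk_size=64 mirrors the default above; the helper name is illustrative, not part of this diff):

    import numpy as np
    import xxhash

    def rolling_chunk_hashes(tokens, chunk_size=64):
        # The digest is cumulative: the i-th hash covers tokens[0 : (i + 1) * chunk_size],
        # which is what makes a per-chunk dictionary lookup behave like a prefix match.
        hsum = xxhash.xxh3_64()
        hashes = []
        for i in range(0, len(tokens), chunk_size):
            chunk = np.array(tokens[i : i + chunk_size], dtype=np.uint32)
            hsum.update(chunk.tobytes())
            hashes.append(hsum.intdigest())
        return hashes

For example, rolling_chunk_hashes(list(range(200))) yields four digests covering chunks of 64, 64, 64 and 8 tokens.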
@@ -41,12 +41,28 @@ def _init_mem_manager(self): head_dim_ = self.config["hidden_size"] // self.config["num_attention_heads"] head_dim_ = self.config.get("head_dim", head_dim_) tp_k_head_num_ = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) + + max_total_token_num = self.max_total_token_num - self.hiradix_cache_token_num if self.hiradix_cache_gpu else self.max_total_token_num self.mem_manager = select_mem_manager_class(self.mode)( - self.max_total_token_num, + max_total_token_num, dtype=self.data_type, head_num=tp_k_head_num_, head_dim=head_dim_, layer_num=self.config["num_hidden_layers"], mem_fraction=self.mem_fraction, ) + + if self.enable_hiradix_cache: + from lightllm.common.radixmem_buffer import MemPropties, get_shared_data + from lightllm.common.radixmem_manager import build_radix_manager + mem_propties = MemPropties( + self.hiradix_cache_token_num, + dtype=self.data_type, + head_num=2 * tp_k_head_num_, + head_dim=head_dim_, + layer_num=self.config["num_hidden_layers"], + ) + self.radix_manager = build_radix_manager(mem_propties, self.hiradix_cache_gpu, self.radix_lock) + self.mem_propties = mem_propties + self.shared_mem_data = get_shared_data() return diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index d904c727f..9122f76d8 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -220,6 +220,9 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument("--disable_dynamic_prompt_cache", action="store_true", help="disable dynamic prompt cache") parser.add_argument("--chunked_prefill_size", type=int, default=4096, help="chunked prefill size") + parser.add_argument("--use_hiradix_cache", action="store_true", help="enable hierachy prompt cache") + parser.add_argument("--hiradix_cache_gpu", action="store_true", help="enable hierachy prompt cache gpu") + parser.add_argument("--hiradix_cache_token_num", type=int, default=None , help="set the number of tokens to use hierachy prompt cache") parser.add_argument("--disable_chunked_prefill", action="store_true", help="whether to disable chunked prefill") parser.add_argument("--diverse_mode", action="store_true", help="diversity generation mode") parser.add_argument("--token_healing_mode", action="store_true", help="code model infer mode") @@ -326,7 +329,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "--visual_infer_batch_size", type=int, default=1, help="number of images to process in each inference batch" ) parser.add_argument( - "--visual_gpu_ids", nargs="+", type=int, default=None, help="List of GPU IDs to use, e.g., 0 1 2" + "--visual_gpu_ids", nargs="+", type=int, default=[0, 1, 2, 3, 4, 5, 6, 7], help="List of GPU IDs to use, e.g., 0 1 2" ) parser.add_argument("--visual_tp", type=int, default=1, help="number of tensort parallel instances for ViT") parser.add_argument("--visual_dp", type=int, default=1, help="number of data parallel instances for ViT") diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 6e6c27b5e..15d1c6a6f 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -173,6 +173,11 @@ def normal_or_p_d_start(args): args.batch_max_tokens >= args.chunked_prefill_size ), "chunked prefill mode, batch_max_tokens must >= chunked_prefill_size" + # if use_hiradix_cache, then use_dynamic_prompt_cache must be True + hiradix_cache_port_num = 0 + if args.use_hiradix_cache: + assert not args.disable_dynamic_prompt_cache, "use_hiradix_cache must be used with use_dynamic_prompt_cache" + # help to 
manage data stored on Ceph if "s3://" in args.model_dir: from lightllm.utils.petrel_helper import s3_model_prepare @@ -201,8 +206,11 @@ def normal_or_p_d_start(args): ports_locker.lock_port() node_world_size = args.tp // args.nnodes + + if args.use_hiradix_cache: + hiradix_cache_port_num = node_world_size + 2 can_use_ports = alloc_can_use_network_port( - num=7 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports + num=7 + node_world_size + args.visual_dp * args.visual_tp + hiradix_cache_port_num, used_nccl_ports=already_uesd_ports ) logger.info(f"alloced ports: {can_use_ports}") ( @@ -230,6 +238,10 @@ def normal_or_p_d_start(args): args.audio_port = audio_port args.cache_port = cache_port args.metric_port = metric_port + if args.use_hiradix_cache: + args.hiradix_cache_ports = can_use_ports[0:node_world_size] + args.hiradix_server_ports = can_use_ports[node_world_size: node_world_size + 2] + can_use_ports = can_use_ports[node_world_size + 2:] # 申请在 p d 分离模式下,会用的端口 args.pd_node_infer_rpyc_ports = can_use_ports[0:node_world_size] diff --git a/lightllm/server/core/objs/__init__.py b/lightllm/server/core/objs/__init__.py index 5594df6a0..d66e3bd86 100644 --- a/lightllm/server/core/objs/__init__.py +++ b/lightllm/server/core/objs/__init__.py @@ -1,5 +1,5 @@ from .sampling_params import SamplingParams -from .req import Req, FinishStatus +from .req import Req, FinishStatus, RadixStatus from .shm_req_manager import ShmReqManager from .rpc_shm import RpcShmParams, RpcShmResults, ShmSyncStatusArray from .start_args_type import StartArgs diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index f2ebadad1..2ae2cbcca 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -54,6 +54,73 @@ def set_token_ids(self, ids: List[int]): def get_token_ids(self): return list(self.data[: self.size]) +class ReqRankStatus(ctypes.Structure): + _pack_ = 4 + _fields_ = [("dp_rank_in_node", ctypes.c_int), ("dp_world_size", ctypes.c_int)] + + def __init__(self): + self.dp_rank_in_node = 0 + self.dp_world_size = 8 + + def set_status(self, dp_rank_in_node: int, dp_world_size: int): + self.dp_rank_in_node = dp_rank_in_node + self.dp_world_size = dp_world_size + +class RadixStatus(ctypes.Structure): + _pack_ = 4 + _fields_ = [("status", ctypes.c_int * 32), ("rank_status", ReqRankStatus), ("finished", ctypes.c_int)] + + NOCACHE = -2 + NOT_READY = -1 + READ_READY = 1 + WRITE_READY = 2 + WRITE_DONE = 3 + + def __init__(self, init_state=NOT_READY): + for i in range(32): + self.status[i] = init_state + self.rank_status = ReqRankStatus() + self.finished = 0 + + def set_status(self, idx: int, new_status: int): + assert 0 <= idx < 32, f"Index out of range: {idx}" + assert new_status in (self.NOCACHE, self.NOT_READY, self.READ_READY, self.WRITE_READY, self.WRITE_DONE) + self.status[idx] = new_status + + def set_finished(self): + self.finished = 1 + + def is_finished(self): + return self.finished == 1 + + def get_status(self, idx: int) -> int: + assert 0 <= idx < 32, f"Index out of range: {idx}" + return self.status[idx] + + def is_write_done(self): + dp_index = self.rank_status.dp_rank_in_node + dp_size = self.rank_status.dp_world_size + rank_list = range(dp_index * dp_size, (dp_index + 1) * dp_size) + return np.all(np.array(self.status)[rank_list] == self.WRITE_DONE) + + def is_no_need_cache(self, idx: int) -> bool: + return self.get_status(idx) == self.NOCACHE + + def is_read_ready(self, idx: int) -> bool: + return self.get_status(idx) == 
self.READ_READY + + def is_write_ready(self, idx: int) -> bool: + return self.get_status(idx) == self.WRITE_READY + + def is_not_ready(self, idx: int) -> bool: + return self.get_status(idx) == self.NOT_READY + + def all_dp_read_ready_or_nocache(self, indexs: List[int]) -> bool: + return np.all(np.array(self.status)[indexs] == self.READ_READY) or np.all(np.array(self.status)[indexs] == self.NOCACHE) + + def all_read_ready_or_nocache(self) -> bool: + return np.all(np.array(self.status) == self.READ_READY) or np.all(np.array(self.status) == self.NOCACHE) + class Req(ctypes.Structure): _pack_ = 4 @@ -98,6 +165,8 @@ class Req(ctypes.Structure): ("mtp_accepted_token_num", ctypes.c_int), # mtp_step 保存一个mtp使用的常量参数,用于快速访问,不会被外部输入初始化 ("_mtp_step", ctypes.c_int), + # 用于标记当前请求的radix状态 + ("radix_status", RadixStatus), ] def get_str(self): @@ -151,6 +220,7 @@ def init( self.shm_prompt_ids.arr[0 : len(prompt_ids)] = prompt_ids self.mtp_accepted_token_num = 0 self._mtp_step = get_env_start_args().mtp_step + self.radix_status = RadixStatus(RadixStatus.NOT_READY) self.post_init() @@ -204,7 +274,6 @@ def can_release(self): # 只有管理节点有一个引用 ref_count_ok = self.ref_count == 1 can_released_mark = self.can_released_mark - if self.is_aborted and can_released_mark and ref_count_ok: return True diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index f76fbc8c8..d68587480 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -46,6 +46,7 @@ class StartArgs: dp_prefill_wait_step: int = field(default=0) disable_aggressive_schedule: bool = field(default=False) disable_dynamic_prompt_cache: bool = field(default=False) + use_hiradix_cache: bool = field(default=False) chunked_prefill_size: int = field(default=8192) disable_chunked_prefill: bool = field(default=False) diverse_mode: bool = field(default=False) diff --git a/lightllm/server/detokenization/decode_req.py b/lightllm/server/detokenization/decode_req.py index 9a85ea089..4c17f3f98 100644 --- a/lightllm/server/detokenization/decode_req.py +++ b/lightllm/server/detokenization/decode_req.py @@ -8,9 +8,11 @@ class DecodeReq: def __init__( self, + args, req: Req, is_pd_decode_mode: bool, ) -> None: + self.args = args self.request_id = req.request_id self.group_req_id = req.group_req_id self.prompt_ids = req.shm_prompt_ids.arr[0 : req.input_len].tolist() @@ -59,6 +61,7 @@ def can_set_release_mark(self): self.req.finish_status.is_finished() and self.req.candetoken_out_len == len(self.output_ids) and self.req.finish_token_index == self.input_len + len(self.output_ids) - 1 + and (self.req.radix_status.is_finished() if self.args.use_hiradix_cache else True) ): return True return False diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index 8f32d0992..57e0422d5 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -71,7 +71,7 @@ async def handle_loop(self): ) # p d 分离模式,decode节点的解码需要做一些特殊的修复。 - decode_req = DecodeReq(req, self.is_pd_decode_mode) + decode_req = DecodeReq(self.args, req, self.is_pd_decode_mode) if self.is_pd_decode_mode: decode_req = decode_mode_fix(decode_req, self.tokenizer, self.eos_id) # token_healing mode 的特殊初始化 diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index fa455c225..a2d7084ca 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -51,7 +51,11 @@ def 
__init__( context = zmq.asyncio.Context(2) self.send_to_router = context.socket(zmq.PUSH) self.send_to_router.connect(f"{args.zmq_mode}127.0.0.1:{router_port}") - + self.use_hiradix_cache = args.use_hiradix_cache + if self.use_hiradix_cache: + context_hiradix = zmq.asyncio.Context() + self.send_to_hiradix = context_hiradix.socket(zmq.PUSH) + self.send_to_hiradix.connect(f"{args.zmq_mode}127.0.0.1:{self.args.hiradix_server_ports[0]}") self.multinode_req_manager = None self.nnodes = args.nnodes self._shm_lock_pool = AtomicShmArrayLock(f"{get_unique_server_name()}_lightllm_resource_lock", 1) @@ -302,7 +306,7 @@ async def generate( ) req_objs.append(req_obj) - req_status = ReqStatus(group_request_id, multimodal_params, req_objs, start_time) + req_status = ReqStatus(group_request_id, multimodal_params, req_objs, start_time, self.use_hiradix_cache) self.req_id_to_out_inf[group_request_id] = req_status await self.transfer_to_next_module_or_node( @@ -476,10 +480,16 @@ async def transfer_to_next_module( protocol=pickle.HIGHEST_PROTOCOL, ) else: - self.send_to_router.send_pyobj( - group_req_objs.to_group_req_index(), - protocol=pickle.HIGHEST_PROTOCOL, - ) + if self.use_hiradix_cache: + self.send_to_hiradix.send_pyobj( + group_req_objs.to_group_req_index(), + protocol=pickle.HIGHEST_PROTOCOL + ) + else: + self.send_to_router.send_pyobj( + group_req_objs.to_group_req_index(), + protocol=pickle.HIGHEST_PROTOCOL, + ) return assert False, "dead code path" @@ -695,7 +705,8 @@ async def handle_loop(self): class ReqStatus: - def __init__(self, group_request_id, multimodal_params, req_objs: List[Req], start_time) -> None: + def __init__(self, group_request_id, multimodal_params, req_objs: List[Req], start_time, use_hiradix_cache) -> None: + self.use_hiradix_cache = use_hiradix_cache self.lock = asyncio.Lock() self.event = asyncio.Event() self.group_req_objs = GroupReqObjs( @@ -708,6 +719,11 @@ def __init__(self, group_request_id, multimodal_params, req_objs: List[Req], sta def can_release(self): for req in self.group_req_objs.shm_req_objs: - if not req.can_release(): - return False - return True + if self.use_hiradix_cache: + if req.can_release() and req.radix_status.is_finished(): + + return True + else: + if req.can_release(): + return True + return False diff --git a/lightllm/server/router/batch.py b/lightllm/server/router/batch.py index 14a987f49..f356f4ef6 100644 --- a/lightllm/server/router/batch.py +++ b/lightllm/server/router/batch.py @@ -40,8 +40,14 @@ def get_req_list_for_dp(self, dp_index: int): req_list.append(req) return req_list - def filter_out_finished_req(self, shm_req_manager: ShmReqManager): + def release_reqs(self, reqs: List[Req], shm_req_manager: ShmReqManager): + for req in reqs: + shm_req_manager.put_back_req_obj(req) + + def filter_out_finished_req(self): unfinished_req_ids = [] + finished_reqs = [] + for req in self.reqs: # 更新aborted 标记,可以触发推理进程主动退出aborted的请求。 if req.is_aborted: @@ -49,14 +55,13 @@ def filter_out_finished_req(self, shm_req_manager: ShmReqManager): if req.shm_infer_released: logger.info(f"router release req id {req.request_id}") - shm_req_manager.put_back_req_obj(req) - req = None + finished_reqs.append(req) else: unfinished_req_ids.append(req.request_id) self.reqs = [self.id_to_reqs[req_id] for req_id in unfinished_req_ids] self.id_to_reqs = {req.request_id: req for req in self.reqs} - return + return finished_reqs def pop_req(self, req_id): self.reqs = [req for req in self.reqs if req.request_id != req_id] diff --git 
a/lightllm/server/router/dynamic_prompt/hiradix/__init__.py b/lightllm/server/router/dynamic_prompt/hiradix/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/server/router/dynamic_prompt/hiradix/disk_cache_server.py b/lightllm/server/router/dynamic_prompt/hiradix/disk_cache_server.py new file mode 100644 index 000000000..293e011ab --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/hiradix/disk_cache_server.py @@ -0,0 +1,310 @@ +import torch +import time +import tempfile +import rpyc +import inspect +import asyncio +import threading +import torch.multiprocessing as mp +from typing import List +from rpyc.utils.server import ThreadedServer +from os.path import join +from lightllm.utils.log_utils import init_logger +from .io_objs import ShmReqInfo, GroupReqInfo, HitSate, PullState, PushState, CacheTask +from lightllm.server.core.objs import ShmReqManager +from lightllm.utils.graceful_utils import graceful_registry +from lightllm.common.radixmem_buffer import RadixMemoryBuffer +from lightllm.common.radixmem_manager import RadixBufferManager +from lightllm.server.core.objs import Req, RadixStatus + +logger = init_logger(__name__) + + +def wait_until_ready(task, timeout=10.0, check_interval=0.01): + start_time = time.time() + while not task.ready(): + time.sleep(check_interval) + if time.time() - start_time > timeout: + logger.error("Current kv cache task not ready in time") + return False + return True + + +class RemoteCacheManager: + def __init__(self, unique_name: str, rank_in_node: int, mem_manager): + tmp_dir = tempfile.mkdtemp(prefix=f"cache_{unique_name}_{rank_in_node}") + self.cache_file = join(tmp_dir, "cache_file") + all_buffers = mem_manager.kv_buffer + all_buffers = all_buffers.view(all_buffers.shape[0], all_buffers.shape[1], -1) + + from kvcache.python.jit import PyLocalCacheService + + self.py_cache_service = PyLocalCacheService( + file=self.cache_file, + storage_size=128 * (1024**3), # 128GB + num_shard=32, + kvcache_tensor=all_buffers, + num_worker=8, + ) + + def insert(self, cache_task: CacheTask): + assert cache_task.mode == "w", "Cache task mode must be 'w' for insert" + + t = self.py_cache_service.create( + tokens=cache_task.tokens, + kv_page_indexer=cache_task.kv_page_indexer, + mode=cache_task.mode, + start_pos=cache_task.start_pos, + ) + res = wait_until_ready(t) + + if not res: + self.py_cache_service.az5(t) + return False + + return True + + def read(self, cache_task: CacheTask): + assert cache_task.mode == "r", "Cache task mode must be 'r' for read" + + t = self.py_cache_service.create( + tokens=cache_task.tokens, + kv_page_indexer=cache_task.kv_page_indexer, + mode=cache_task.mode, + start_pos=cache_task.start_pos, + ) + + res = wait_until_ready(t) + return res + + def query(self, cache_task: CacheTask): + query_result = self.py_cache_service.query(cache_task.tokens) + + max_len = 0 + for result in query_result: + if result: + max_len += 1 + else: + break + + return max_len * self.block_size + + @property + def block_size( + self, + ): + return self.py_cache_service.tokens_per_block + + +class DiskCacheService(rpyc.Service): + def __init__(self, radix_manager=None, remote_cache_manager=None, shm_req_manager=None, rank_in_node=None): + super().__init__() + self.radix_manager = radix_manager + self.remote_cache_manager = remote_cache_manager + self.shm_req_manager = shm_req_manager + self.rank_in_node = rank_in_node + + def exposed_push(self, req_info): + req_info: ShmReqInfo = ShmReqInfo.from_dict(req_info) + req: Req = 
self.shm_req_manager.get_req_obj_by_index(req_info.shm_req_index) + req.link_prompt_ids_shm_array() + + if not req.radix_status.is_write_ready(self.rank_in_node): + raise RuntimeError("Radix cache is not ready for write.") + + token_ids = req.shm_prompt_ids.arr[0 : req.shm_cur_kv_len] + keys = torch.tensor(token_ids, dtype=torch.int64, device="cpu") + + _, index_list = self.radix_manager.query_cache(tokens=token_ids.tolist()) + index_tensor = torch.tensor(index_list, device="cpu", dtype=torch.int32) + assert len(keys) == len(index_tensor), f"keys length {len(keys)} != index length {len(index_list)}" + + if len(keys) != len(index_tensor): + raise ValueError(f"Mismatch in keys and index size: {len(keys)} != {len(index_tensor)}") + + insert_task = CacheTask(tokens=keys, kv_page_indexer=index_tensor, mode="w") + result = self.remote_cache_manager.insert(insert_task) + + reqs = [req] + self.set_reqs_radix_status(reqs, RadixStatus.WRITE_DONE) + self.put_back_req_objs(reqs) + + return PushState(state=result).to_dict() + + def set_reqs_radix_status(self, reqs: List[Req], status: int): + for req in reqs: + req.radix_status.set_status(self.rank_in_node, status) + + def put_back_req_objs(self, reqs: List[Req]): + for req in reqs: + self.shm_req_manager.put_back_req_obj(req) + + def exposed_pull(self, group_req): + group_req: GroupReqInfo = GroupReqInfo.from_dict(group_req) + reqs: List[Req] = [] + for shm_req_index in group_req.shm_req_indexes: + req: Req = self.shm_req_manager.get_req_obj_by_index(shm_req_index) + reqs.append(req) + + req = reqs[0] + req.link_prompt_ids_shm_array() + keys = req.get_prompt_ids() + + query_len, _ = self.radix_manager.query_cache(tokens=keys) + if query_len > 0: + radix_state = RadixStatus.READ_READY + cache_state = PullState(query_len, HitSate.MEM) + else: + query_task = CacheTask(tokens=keys) + query_len = self.remote_cache_manager.query(query_task) + + if query_len > 0: + + index = self.radix_manager.alloc(query_len) + read_task = CacheTask(tokens=keys[:query_len], kv_page_indexer=index, mode="r") + self.remote_cache_manager.read(read_task) + + self.radix_manager.write(tokens=keys[:query_len], values=index.tolist()) + + radix_state = RadixStatus.READ_READY + cache_state = PullState(query_len, HitSate.DISK) + else: + radix_state = RadixStatus.NOCACHE + cache_state = PullState(0, HitSate.NONE) + + self.set_reqs_radix_status(reqs, radix_state) + self.put_back_req_objs(reqs) + + return cache_state.to_dict() + + +class DiskCacheClient: + def __init__(self, rank_in_node: int, service=None, use_rpc=True, proc=None): + self.rank_in_node = rank_in_node + self.use_rpc = use_rpc + self.service = service + self.proc = proc + if self.use_rpc: + self._push = self._async_wraper(self.service.push) + self._pull = self._async_wraper(self.service.pull) + else: + self._push = self.service.exposed_push + self._pull = self.service.exposed_pull + + def _async_wraper(self, func): + async_func = rpyc.async_(func) + + async def _wrapped(*args, **kwargs): + result = async_func(*args, **kwargs) + await asyncio.to_thread(result.wait) + return result.value + + return _wrapped + + async def push(self, req_info: ShmReqInfo): + if self.use_rpc: + return await self._push(req_info) + else: + return self._push(req_info) + + async def pull(self, group_req: GroupReqInfo): + if self.use_rpc: + return await self._pull(group_req) + else: + return self._pull(group_req) + + +def start_cache_server(radix_manager, remote_cache_manager, shm_req_manager, rank_in_node, port, init_event): + class 
CustomService(DiskCacheService): + def __init__(self): + super().__init__(radix_manager, remote_cache_manager, shm_req_manager, rank_in_node) + + def start(): + try: + server = ThreadedServer( + CustomService(), port=port, protocol_config={"allow_public_attrs": True, "allow_pickle": True} + ) + init_event.set() + server.start() + except Exception as e: + logger.error(f"Failed to start ThreadedServer: {e}") + + t = threading.Thread(target=start, daemon=True) + t.start() + + logger.info(f"DiskCacheService started on port {port}") + return t + + +def _init_server(device_id, mem_queue, radix_lock: List[mp.Lock], init_event: mp.Event, port: int = 18861): + from lightllm.utils.envs_utils import get_unique_server_name + + graceful_registry(inspect.currentframe().f_code.co_name) + torch.cuda.set_device(device_id) + mem_proties, shared_mem_data = mem_queue.get() + mem_manager = RadixMemoryBuffer( + mem_propties=mem_proties, shared_data=shared_mem_data, lock=radix_lock, rank_in_node=device_id + ) + remote_cache_manager = RemoteCacheManager( + unique_name=get_unique_server_name(), + rank_in_node=device_id, + mem_manager=mem_manager, + ) + radix_manager = RadixBufferManager(radix_buffer=mem_manager, radix_mem_data=shared_mem_data, lock=radix_lock) + + shm_req_manager = ShmReqManager() + + t = start_cache_server( + radix_manager=radix_manager, + remote_cache_manager=remote_cache_manager, + shm_req_manager=shm_req_manager, + rank_in_node=device_id, + port=port, + init_event=init_event, + ) + t.join() + return + + +async def start_disk_cache_server_process(args, device_id, node_word_size, mem_queue, radix_lock, port): + """ + Start the DiskCacheManager in process. + """ + from lightllm.utils.envs_utils import get_unique_server_name + + if node_word_size == 1: + mem_proties, shared_mem_data = mem_queue.get() + mem_buffer = RadixMemoryBuffer( + mem_propties=mem_proties, shared_data=shared_mem_data, lock=radix_lock, rank_in_node=device_id + ) + remote_cache_manager = RemoteCacheManager( + unique_name=get_unique_server_name(), + rank_in_node=device_id, + mem_manager=mem_buffer, + ) + shm_req_manager = ShmReqManager() + + radix_manager = RadixBufferManager(radix_buffer=mem_buffer, radix_mem_data=shared_mem_data, lock=radix_lock) + service = DiskCacheService(radix_manager, remote_cache_manager, shm_req_manager) + client = DiskCacheClient(service=service, rank_in_node=0, use_rpc=False) + return client + + init_event = mp.Event() + proc = mp.Process(target=_init_server, args=(device_id, mem_queue, radix_lock, init_event, port)) + proc.start() + + init_event.wait(timeout=60) + + max_wait_times = 20 + for i in range(max_wait_times): + try: + conn = rpyc.connect("localhost", port, config={"allow_pickle": True}) + break + except Exception: + asyncio.sleep(2) + + service = conn.root + client = DiskCacheClient(rank_in_node=device_id, service=service, use_rpc=True, proc=proc) + assert proc.is_alive() + logger.info(f"disk cache process for device {device_id} start!") + return client diff --git a/lightllm/server/router/dynamic_prompt/hiradix/hiradix_cache.py b/lightllm/server/router/dynamic_prompt/hiradix/hiradix_cache.py new file mode 100644 index 000000000..33682095a --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/hiradix/hiradix_cache.py @@ -0,0 +1,159 @@ +import torch +import numpy as np +from ..radix_cache import RadixCache +from lightllm.common.mem_manager import MemoryManager +from lightllm.common.radixmem_buffer import RadixMemoryBuffer +from lightllm.common.radixmem_manager import RadixBufferManager 
+from lightllm.utils.log_utils import init_logger +from ..shared_arr import SharedArray +from lightllm.server.core.objs import Req, RadixStatus + +logger = init_logger(__name__) + + +class LocalCacheManager: + + def __init__(self, radix_manager: RadixBufferManager, mem_manager: MemoryManager, rank_in_node): + self.radix_manager = radix_manager + self.radix_buffer: RadixMemoryBuffer = self.radix_manager.radix_buffer + self.mem_manager = mem_manager + self.rank_in_node = rank_in_node + + def insert(self, req: Req, key: torch.Tensor, value=None): + query_len, query_index = self._query_cache(req, key) + + alloc_len = len(key) - query_len + if alloc_len == 0: + self._set_radix_staus(req, RadixStatus.WRITE_READY) + return + + new_index = self._alloc_and_copy_kv(alloc_len, value) + + start_pos = max(0, (query_len - 1) // self.chunk_size * self.chunk_size) + self.radix_manager.write(tokens=key.tolist(), values=query_index + new_index, start_pos=start_pos) + + self._set_radix_staus(req, RadixStatus.WRITE_READY) + + def _query_cache(self, req, key): + if req.radix_status.is_no_need_cache(self.rank_in_node): + return 0, [] + + if req.radix_status.is_read_ready(self.rank_in_node): + query_len, mem_index = self.radix_manager.query_cache(key.tolist()) + return query_len, mem_index + return 0, [] + + def _alloc_and_copy_kv(self, alloc_len, value): + assert alloc_len > 0, "No allocation needed" + + new_index = self.radix_manager.alloc(alloc_len) + dst_kv_buffer = self.radix_buffer.get_kv_buffer(new_index) + src_kv_buffer = self.mem_manager.get_index_kv_buffer(value[-alloc_len:])["kv_buffer"] + + assert len(src_kv_buffer) == len( + dst_kv_buffer + ), f"Mis match buffer size src {len(src_kv_buffer)} != dst {len(dst_kv_buffer)}" + + self.copy_kv_from_gpu_to_cpu(src_kv_buffer, dst_kv_buffer) + return new_index.tolist() + + def _set_radix_staus(self, req, status): + req.radix_status.set_status(self.rank_in_node, status) + + def read(self, key, value, query_index, alloc_len): + try: + src_kv_buffer = self.radix_buffer.get_kv_buffer(index=query_index[-alloc_len:]) + dst_kv_buffer = self.mem_manager.get_index_kv_buffer(index=value[-alloc_len:])["kv_buffer"] + + assert len(src_kv_buffer) == len( + dst_kv_buffer + ), f"Mis match buffer size src {len(src_kv_buffer)} != dst {len(dst_kv_buffer)}" + + self.copy_kv_from_cpu_to_gpu(src_kv_buffer, dst_kv_buffer) + + except Exception as e: + logger.error(f"LocalCache read from radix mem error {e}") + return False + + return True + + def query(self, req: Req, key): + return self._query_cache(req, key) + + @property + def chunk_size(self): + return self.radix_manager.chunk_size + + def copy_kv_from_cpu_to_gpu(self, src_kv_tensor, dst_kv_tensor): + dst_kv_tensor.copy_(src_kv_tensor, non_blocking=True) + + def copy_kv_from_gpu_to_cpu(self, src_kv_tensor, dst_kv_tensor): + dst_kv_tensor.copy_(src_kv_tensor, non_blocking=True) + + +class HiRadixCache(RadixCache): + def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, radix_manager, radix_info_queue): + super().__init__(unique_name, total_token_num, rank_in_node, mem_manager) + self.rank_in_node = rank_in_node + self.radix_manager: RadixBufferManager = radix_manager + self.local_cache_manager = LocalCacheManager( + radix_manager=self.radix_manager, mem_manager=mem_manager, rank_in_node=rank_in_node + ) + self.radix_info_queue = radix_info_queue + self.is_hi_radix_cache = True + self.disk_cache_match_count = SharedArray( + f"{unique_name}_disk_cache_match_count_{rank_in_node}", (1,), dtype=np.int64 + ) + 
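+        # disk_cache_match_count / total_match_count / disk_cache_match_ratio are plain
+        # shared-memory counters: total match_prefix calls, how many of them were extended
+        # by the CPU/disk tier, and the resulting hit ratio; SharedArray keeps them
+        # readable from other processes by name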
self.disk_cache_match_count.arr[0] = 0 + self.total_match_count = SharedArray(f"{unique_name}_total_match_count_{rank_in_node}", (1,), dtype=np.int64) + self.total_match_count.arr[0] = 0 + self.disk_cache_match_ratio = SharedArray( + f"{unique_name}_disk_cache_match_ratio_{rank_in_node}", (1,), dtype=np.float32 + ) + self.disk_cache_match_ratio.arr[0] = 0.0 + logger.info(f"Initializing HiRadixCache {rank_in_node}") + + def insert(self, key, value=None, req=None): + if len(key) == 0: + return 0 + share_len = super().insert(key, value) + if req is None: + return + self.local_cache_manager.insert(req, key, value) + return share_len + + def match_prefix(self, req, key, update_refs=False): + assert len(key) != 0 + self.total_match_count.arr[0] += 1 + ans_value_list = [] + ans_value = None + tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False) + if tree_node.node_prefix_total_len != 0: + ans_value = torch.concat(ans_value_list) + max_len = 0 + if tree_node.node_prefix_total_len < len(key): + max_len, query_index = self.local_cache_manager.query(req, key) + + logger.debug( + f"HiCache rank_in_node={self.rank_in_node} current key len {len(key)} match radix len " + f"{tree_node.node_prefix_total_len}, max len {max_len}" + ) + if max_len > tree_node.node_prefix_total_len: + pull_len = max_len - tree_node.node_prefix_total_len + self.disk_cache_match_count.arr[0] += 1 + self.disk_cache_match_ratio.arr[0] = self.disk_cache_match_count.arr[0] / self.total_match_count.arr[0] + self.free_radix_cache_to_get_enough_token(pull_len) + buffers = self.mem_manager.alloc(pull_len) + if ans_value is not None: + buffers = torch.concat([ans_value, buffers]) + logger.debug( + f"HiCache current match ratio {self.disk_cache_match_ratio.arr[0]}," + f"pulled cache len {pull_len} from disk" + ) + res = self.local_cache_manager.read(key[:max_len], buffers, query_index, alloc_len=pull_len) + if res: + super().insert(key[:max_len], buffers) + else: + self.mem_manager.free(buffers[tree_node.node_prefix_total_len :]) + + return super().match_prefix(key, update_refs=update_refs) diff --git a/lightllm/server/router/dynamic_prompt/hiradix/io_objs.py b/lightllm/server/router/dynamic_prompt/hiradix/io_objs.py new file mode 100644 index 000000000..35c6e7da6 --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/hiradix/io_objs.py @@ -0,0 +1,71 @@ +import torch +from dataclasses import dataclass +from enum import Enum +from typing import List + + +@dataclass +class ShmReqInfo: + request_id: int + shm_req_index: int + + def to_dict(self): + return {"request_id": self.request_id, "shm_req_index": self.shm_req_index} + + @staticmethod + def from_dict(d): + return ShmReqInfo(request_id=d["request_id"], shm_req_index=d["shm_req_index"]) + + +@dataclass +class GroupReqInfo: + group_req_id: int + shm_req_indexes: List[int] + + def to_dict(self): + return {"group_req_id": self.group_req_id, "shm_req_indexes": self.shm_req_indexes} + + @staticmethod + def from_dict(d): + return GroupReqInfo(group_req_id=d["group_req_id"], shm_req_indexes=d["shm_req_indexes"]) + + +@dataclass +class CacheTask: + tokens: torch.Tensor + mode: str = None + kv_page_indexer: torch.Tensor = None + start_pos: torch.Tensor = 0 + + +@dataclass +class PushState: + state: bool + + def to_dict(self): + return {"state": self.state} + + @staticmethod + def from_dict(d): + return PushState( + state=d["state"], + ) + + +class HitSate(Enum): + NONE = -1 + MEM = 0 + DISK = 1 + + +@dataclass +class PullState: + match_length: int + 
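+    # match_length: how many prefix tokens were found; cache_source: which tier served
+    # them (MEM = shared radix buffer, DISK = the kvcache file behind PyLocalCacheService)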
cache_source: HitSate + + def to_dict(self): + return {"match_length": self.match_length, "cache_source": self.cache_source.name} + + @staticmethod + def from_dict(d): + return PullState(match_length=d["match_length"], cache_source=HitSate[d["cache_source"]]) diff --git a/lightllm/server/router/dynamic_prompt/hiradix/manager.py b/lightllm/server/router/dynamic_prompt/hiradix/manager.py new file mode 100644 index 000000000..4b695071c --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/hiradix/manager.py @@ -0,0 +1,144 @@ +import zmq +import zmq.asyncio +import inspect +import pickle +import torch.multiprocessing as mp +import asyncio +from typing import List +from lightllm.server.core.objs.io_objs import GroupReqIndexes +from lightllm.utils.log_utils import init_logger +from lightllm.utils.graceful_utils import graceful_registry +from .disk_cache_server import DiskCacheClient +from lightllm.server.core.objs import ShmReqManager +from .io_objs import ShmReqInfo, GroupReqInfo +from lightllm.server.core.objs import Req + +logger = init_logger(__name__) + + +class HiRadixCacheManagerServer: + def __init__(self, args, mem_queues: List[mp.Queue], radix_locks: List[mp.Lock], router_port: int): + self.args = args + self.mem_queues = mem_queues + self.radix_locks = radix_locks + self.node_world_size = args.tp // args.nnodes + self.disk_cache_processes = [] + self.ports = args.hiradix_cache_ports + self.cache_server_client = [] + context = zmq.asyncio.Context(3) + self.recv_from_httpserver = context.socket(zmq.PULL) + recv_from_http_port, recv_from_router_port = self.args.hiradix_server_ports + self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{recv_from_http_port}") + self.clients: List[DiskCacheClient] = [] + self.send_to_router = context.socket(zmq.PUSH) + self.send_to_router.connect(f"{args.zmq_mode}127.0.0.1:{router_port}") + self.recv_from_router = context.socket(zmq.PULL) + self.recv_from_router.bind(f"{args.zmq_mode}127.0.0.1:{recv_from_router_port}") + self.shm_req_manager = ShmReqManager() + + async def asyn_init(self): + self.pull_queue = asyncio.Queue() + self.push_queue = asyncio.Queue() + + async def start_all(self): + from lightllm.server.router.dynamic_prompt.hiradix.disk_cache_server import start_disk_cache_server_process + + for rank_in_node in range(self.node_world_size): + client = await start_disk_cache_server_process( + self.args, + device_id=rank_in_node, + node_word_size=self.node_world_size, + mem_queue=self.mem_queues[rank_in_node], + radix_lock=self.radix_locks[rank_in_node], + port=self.ports[rank_in_node], + ) + self.clients.append(client) + + async def pull_cache(self, group_req): + tasks = [] + group_req_info = GroupReqInfo( + group_req_id=group_req.group_req_id, shm_req_indexes=group_req.shm_req_indexes + ).to_dict() + for client in self.clients: + task = client.pull(group_req_info) + tasks.append(task) + all_results = await asyncio.gather(*tasks) + logger.info(f"pull cache results {all_results}") + await self.send_to_router.send_pyobj(group_req, protocol=pickle.HIGHEST_PROTOCOL) + + async def push_cache(self, req_info): + tasks = [] + for client in self.clients: + task = client.push(req_info) + tasks.append(task) + all_results = await asyncio.gather(*tasks) + req: Req = self.shm_req_manager.get_req_obj_by_index(req_info["shm_req_index"]) + assert req.radix_status.is_write_done() + req.radix_status.set_finished() + self.shm_req_manager.put_back_req_obj(req) + logger.info(f"push cache results {all_results}") + + async def pull_woker(self): + while True: + 
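+            # one group request per iteration: pull_cache() fans the pull out to every
+            # per-rank DiskCacheClient and only forwards the request to the router once
+            # all of them have answered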
req: GroupReqInfo = await self.pull_queue.get() + await self.pull_cache(req) + await asyncio.sleep(0.01) + + async def push_woker(self): + while True: + req: ShmReqInfo = await self.push_queue.get() + await self.push_cache(req.to_dict()) + await asyncio.sleep(0.01) + + async def run(self): + await self.asyn_init() + await asyncio.gather( + self.loop_for_netio_req_to_pull(), self.pull_woker(), self.loop_for_netio_req_to_push(), self.push_woker() + ) + + async def loop_for_netio_req_to_push(self): + while True: + recv_req: ShmReqInfo = await self.recv_from_router.recv_pyobj() + if isinstance(recv_req, ShmReqInfo): + await self.push_queue.put(recv_req) + else: + raise ValueError(f"Invalid request: {recv_req}") + + async def loop_for_netio_req_to_pull(self): + while True: + recv_req: GroupReqIndexes = await self.recv_from_httpserver.recv_pyobj() + if isinstance(recv_req, GroupReqIndexes): + await self.pull_queue.put(recv_req) + else: + raise ValueError(f"Invalid request: {recv_req}") + + +def _init_env_server(args, mem_queues, radix_locks: List[mp.Lock], init_event: mp.Event, router_port: int): + graceful_registry(inspect.currentframe().f_code.co_name) + hiradix_cache_manager = HiRadixCacheManagerServer( + args, mem_queues=mem_queues, radix_locks=radix_locks, router_port=router_port + ) + asyncio.run(hiradix_cache_manager.start_all()) + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + init_event.set() + loop.run_until_complete(hiradix_cache_manager.run()) + except Exception as e: + logger.error(f"hiradix server error happend {e}") + return + + +def start_hiradix_cache_manager_process_server( + args, radix_mem_queues: List[mp.Queue], radix_locks: List[mp.Lock], router_port: int +): + """ + Start the HiRadix cache manager process. + """ + init_event = mp.Event() + proc = mp.Process(target=_init_env_server, args=(args, radix_mem_queues, radix_locks, init_event, router_port)) + proc.start() + init_event.wait() + logger.info("HiRadix cache manager process started") + assert proc.is_alive() + return diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index 65ec4354b..e4c34bc85 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -123,6 +123,8 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager: Memo ) self.tree_total_tokens_num.arr[0] = 0 + self.is_hi_radix_cache = False + def insert(self, key, value=None): if value is None: value = key diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index c10847e3f..d8a0acad3 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -11,10 +11,11 @@ import torch.multiprocessing as mp import torch.distributed as dist import multiprocessing -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from .batch import Batch, Req from .model_infer.model_rpc import start_model_process, ModelRpcClient from .req_queue import build_req_queue +from lightllm.server.router.dynamic_prompt.hiradix.io_objs import ShmReqInfo from lightllm.server.core.objs.io_objs import GroupReqIndexes from lightllm.server.core.objs import ShmReqManager, StartArgs from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient @@ -56,7 +57,7 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por self.read_only_statics_mem_manager = ReadOnlyStaticsMemoryManager() # 初始化 
radix_cache_client 用于读取 prompt cache 的管理信息 self.radix_cache_client = None - + self.use_hiradix_cache = args.use_hiradix_cache and not args.disable_dynamic_prompt_cache self.mtp_step = args.mtp_step # 共享变量,用于存储router端调度分析得到的机器负载信息 @@ -80,6 +81,13 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por self.send_to_detokenization = context.socket(zmq.PUSH) self.send_to_detokenization.connect(f"{args.zmq_mode}127.0.0.1:{detokenization_port}") + self.router_port = router_port + if self.use_hiradix_cache: + hiradix_port = self.args.hiradix_server_ports[1] + context_radix = zmq.asyncio.Context() + self.send_to_hiradix_server = context_radix.socket(zmq.PUSH) + self.send_to_hiradix_server.connect(f"{args.zmq_mode}127.0.0.1:{hiradix_port}") + if self.is_multinode_tp: self.mulitnode_group = dist.init_process_group( backend="gloo", @@ -114,6 +122,15 @@ async def wait_to_model_ready(self): self.mem_queues: List[torch.multiprocessing.Queue] = [ torch.multiprocessing.Queue() for _ in range(self.node_world_size) ] + self.radix_mem_queues: List[Union[torch.multiprocessing.Queue, None]] = [ + torch.multiprocessing.Queue() if self.use_hiradix_cache else None for _ in range(self.node_world_size) + ] + self.radix_info_queues: List[Union[torch.multiprocessing.Queue, None]] = [ + torch.multiprocessing.Queue() if self.use_hiradix_cache else None for _ in range(self.node_world_size) + ] + self.radix_locks: List[Union[torch.multiprocessing.Lock, None]] = [ + torch.multiprocessing.Lock() if self.use_hiradix_cache else None for _ in range(self.node_world_size) + ] self.rpc_event = multiprocessing.Event() self.rpc_finished_event = multiprocessing.Event() @@ -130,6 +147,9 @@ async def wait_to_model_ready(self): info_queue=self.info_queue, mem_queue=self.mem_queues[(rank_id % node_world_size)], router_lock=self.router_lock, + radix_mem_queue=self.radix_mem_queues[(rank_id % node_world_size)], + radix_info_queue=self.radix_info_queues[(rank_id % node_world_size)], + radix_lock=self.radix_locks[(rank_id % node_world_size)] ) self.model_rpc_servers.append(rpc_model) @@ -158,6 +178,7 @@ async def wait_to_model_ready(self): "return_all_prompt_logprobs": self.args.return_all_prompt_logprobs, "use_reward_model": self.args.use_reward_model, "disable_dynamic_prompt_cache": self.args.disable_dynamic_prompt_cache, + "use_hiradix_cache": self.args.use_hiradix_cache, "data_type": self.args.data_type, "eos_id": self.eos_id, "diverse_mode": self.args.diverse_mode, @@ -201,6 +222,11 @@ async def wait_to_model_ready(self): ) start_decode_kv_move_manager_process(self.args, self.info_queue, self.mem_queues) + + if self.use_hiradix_cache: + # 启动 hi radix cache 管理进程 + from lightllm.server.router.dynamic_prompt.hiradix.manager import start_hiradix_cache_manager_process_server + start_hiradix_cache_manager_process_server(self.args, self.radix_mem_queues, self.radix_locks, self.router_port) return @@ -307,7 +333,7 @@ async def _step(self): self._add_new_batch_to_running_batch(new_batch=new_batch) await self._prefill_batch(new_batch) self.stats_tool.count_prompt_tokens(new_batch) - self._filter_reqs_from_running_batch() + await self._filter_reqs_from_running_batch() self.has_wait_tokens = 0 # Check if need pause some requests for decode. 
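Across these changes, the per-rank RadixStatus added to Req is what gates when a finished request may be released: the disk-cache server marks READ_READY or NOCACHE on pull, the inference backend marks WRITE_READY after copying KV into the shared buffer, the disk-cache server marks WRITE_DONE after persisting it, and the hiradix manager then sets the finished flag that detokenization and the HTTP server wait for. A minimal walk-through of those transitions, assuming an illustrative 8-rank node (the rank count is not fixed by this diff):

    from lightllm.server.core.objs import RadixStatus

    node_world_size = 8                                  # illustrative; matches the ReqRankStatus default
    status = RadixStatus(RadixStatus.NOT_READY)          # Req.init(): nothing cached yet
    for rank in range(node_world_size):
        status.set_status(rank, RadixStatus.READ_READY)  # exposed_pull: cache hit for this rank
        status.set_status(rank, RadixStatus.WRITE_READY) # LocalCacheManager.insert: KV copied to the shared buffer
        status.set_status(rank, RadixStatus.WRITE_DONE)  # exposed_push: KV persisted to disk
    if status.is_write_done():                           # every rank of the request's dp group is done
        status.set_finished()                            # detokenization / httpserver may now release the request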
@@ -324,7 +350,7 @@ async def _step(self): # Decode self.stats_tool.count_output_tokens(self.running_batch) await self._decode_batch() - self._filter_reqs_from_running_batch() + await self._filter_reqs_from_running_batch() self.has_wait_tokens += 1 return @@ -354,9 +380,11 @@ def _add_new_batch_to_running_batch(self, new_batch: Batch): self.running_batch.merge(new_batch) return - def _filter_reqs_from_running_batch(self): + async def _filter_reqs_from_running_batch(self): if self.running_batch is not None: - self.running_batch.filter_out_finished_req(self.shm_req_manager) + finishs_reqs = self.running_batch.filter_out_finished_req() + await self._send_hiradix_manager(finishs_reqs) + self.running_batch.release_reqs(finishs_reqs, self.shm_req_manager) if self.running_batch.is_clear(): self.running_batch = None return @@ -368,6 +396,17 @@ def _can_decode(self, batch: Batch, dp_index: int): batch.get_batch_decode_need_tokens()[dp_index] + self.get_used_tokens(dp_index) <= self.max_total_token_num ) + async def _send_hiradix_manager(self, reqs): + if not self.use_hiradix_cache: + return + for req in reqs: + req_info = ShmReqInfo( + req.request_id, + req.index_in_shm_mem + ) + await self.send_to_hiradix_server.send_pyobj(req_info, protocol=pickle.HIGHEST_PROTOCOL) + return + def _send_detokenization_pack(self): # 发 mtp_step + 1 个 None 包触发一下 detokenization, 因为在开启 mtp feature 以后,每一步 # 生成的 token 数量最多为 mtp_step + 1 个,如果不及时触发 detokenization, 会带来一些性能 diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 10b68245c..11189b85c 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -33,9 +33,10 @@ class InferenceContext: vocab_size = None overlap_stream: torch.cuda.Stream = None # 一些情况下推理进程进行异步折叠操作的异步流对象。 + backend = None def register( - self, req_manager: ReqManager, radix_cache: RadixCache, shm_req_manager: ShmReqManager, vocab_size: int + self, req_manager: ReqManager, radix_cache: RadixCache, shm_req_manager: ShmReqManager, vocab_size: int, backend ): self.req_manager = req_manager self.radix_cache = radix_cache @@ -46,6 +47,7 @@ def register( self.infer_req_ids = [] self.vocab_size = vocab_size + self.backend = backend return def get_overlap_stream(self) -> torch.cuda.Stream: @@ -95,7 +97,13 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq", is_group_finis value = self.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() if is_group_finished: - prefix_len = self.radix_cache.insert(key, value) + if hasattr(self.radix_cache, "is_hi_radix_cache") and getattr(self.radix_cache, "is_hi_radix_cache"): + prefix_len = self.radix_cache.insert( + key, value, + req=req.shm_req + ) + else: + prefix_len = self.radix_cache.insert(key, value) old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len]) if req.shared_kv_node is not None: @@ -110,6 +118,7 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq", is_group_finis self.radix_cache.dec_node_ref_counter(req.shared_kv_node) req.shared_kv_node = None + def _save_promptcache_kvbuffer(self): """ save prompt cache kv buffer @@ -266,6 +275,7 @@ def __init__( # 当开启后,mtp_gen_token_ids 保存多生成的多余的token_id,但是在后面的 # 步骤中需要重新进行校验。 self.mtp_gen_token_ids: List[int] = [] + self.shm_req = None def init_all(self): if self.initialized is False: @@ -294,7 
            input_token_ids = self.shm_req.shm_prompt_ids.arr[0 : self.get_cur_total_len()]
            key = torch.tensor(input_token_ids, dtype=torch.int64, device="cpu")
            key = key[0 : len(key) - 1]  # 最后一个不需要,因为需要一个额外的token,让其在prefill的时候输出下一个token的值
-            share_node, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True)
+            if getattr(g_infer_context.radix_cache, "is_hi_radix_cache", False):
+                share_node, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(self.shm_req, key, update_refs=True)
+            else:
+                share_node, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True)
            if share_node is not None:
                self.shared_kv_node = share_node
                ready_cache_len = share_node.node_prefix_total_len
@@ -305,6 +318,7 @@ def init_all(self):

            self.shm_req.shm_cur_kv_len = self.cur_kv_len

+        self.init_radix_status()
        self.initialized = True
        self.paused = False
        return
@@ -312,6 +326,9 @@ def init_all(self):
    def is_uninitialized(self):
        return not self.initialized or self.paused

+    def init_radix_status(self):
+        return g_infer_context.backend.set_radix_status(self.shm_req)
+
    def get_output_len(self):
        return self.cur_output_len

diff --git a/lightllm/server/router/model_infer/mode_backend/__init__.py b/lightllm/server/router/model_infer/mode_backend/__init__.py
index 7ad15f00f..9e946a19d 100644
--- a/lightllm/server/router/model_infer/mode_backend/__init__.py
+++ b/lightllm/server/router/model_infer/mode_backend/__init__.py
@@ -2,6 +2,7 @@
 from .continues_batch.impl_for_return_all_prompt_logprobs import ReturnPromptLogProbBackend
 from .continues_batch.impl_for_reward_model import RewardModelBackend
 from .chunked_prefill.impl import ChunkedPrefillBackend
+from .chunked_prefill.impl_for_hiradix_cache import ChunkedPrefillBackendHiCache
 from .diverse_backend.impl import DiversehBackend
 from .chunked_prefill.impl_for_token_healing import TokenHealingBackend
 from .chunked_prefill.impl_for_outlines_constraint_mode import OutlinesConstraintBackend
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index dd1ea45fe..25fba774e 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -7,7 +7,8 @@
 from lightllm.utils.log_utils import init_logger
 from lightllm.models import get_model
 from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache
-from lightllm.server.router.model_infer.infer_batch import InferReq
+from lightllm.server.router.dynamic_prompt.hiradix.hiradix_cache import HiRadixCache
+from lightllm.server.router.model_infer.infer_batch import InferReq, InferSamplingParams
 from lightllm.server.router.token_load import TokenLoad
 from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock
 from lightllm.common.basemodel.basemodel import TpPartBaseModel
@@ -56,6 +57,8 @@ def init_model(self, kvargs):
        self.use_dynamic_prompt_cache = not self.args.disable_dynamic_prompt_cache
        self.eos_id: List[int] = kvargs.get("eos_id", [2])
        self.disable_cudagraph = self.args.disable_cudagraph
+        self.use_hiradix_cache = kvargs.get("use_hiradix_cache", False)
+        self.radix_lock = kvargs.get("radix_lock", None)

        self.logger = init_logger(__name__)

@@ -113,16 +116,29 @@ def init_model(self, kvargs):
            "quant_type": kvargs.get("quant_type", None),
            "quant_cfg": kvargs.get("quant_cfg", None),
            "run_mode": self.run_mode,
+            "use_hiradix_cache": self.use_hiradix_cache,
+            "hiradix_cache_gpu": kvargs.get("hiradix_cache_gpu", False),
+            "hiradix_cache_token_num": kvargs.get("hiradix_cache_token_num", None),
+            "radix_lock": self.radix_lock
        }

        self.model, self.is_multimodal = get_model(model_cfg, model_kvargs)
        self.model: TpPartBaseModel = self.model  # for easy typing
        set_random_seed(2147483647)

        self.radix_cache = (
-            RadixCache(
+            HiRadixCache(
                get_unique_server_name(),
                self.model.mem_manager.size,
                self.rank_in_node,
                mem_manager=self.model.mem_manager,
+                radix_manager=self.model.radix_manager,
+                radix_info_queue=kvargs.get("radix_info_queue", None)
+            )
+            if self.use_hiradix_cache
+            else RadixCache(
+                get_unique_server_name(),
+                self.model.mem_manager.size,
+                self.rank_in_node,
+                mem_manager=self.model.mem_manager
            )
            if self.use_dynamic_prompt_cache
            else None
@@ -138,6 +154,7 @@ def init_model(self, kvargs):
            radix_cache=self.radix_cache,
            shm_req_manager=self.shm_req_manager,
            vocab_size=self.model.vocab_size,
+            backend=self
        )

        # 初始化 dp 模式使用的通信 tensor, 对于非dp模式,不会使用到
@@ -150,6 +167,13 @@ def init_model(self, kvargs):

        self.init_custom()
        return

+    def set_radix_status(self, req):
+        if not self.use_hiradix_cache:
+            return
+        if self.is_master_in_dp:
+            req.radix_status.rank_status.set_status(self.dp_rank_in_node, self.dp_world_size)
+        return False
+
    def init_custom(self):
        pass

diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_hiradix_cache.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_hiradix_cache.py
new file mode 100644
index 000000000..5a36132df
--- /dev/null
+++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_hiradix_cache.py
@@ -0,0 +1,26 @@
+import torch
+from typing import List, Tuple
+from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend
+from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end
+from lightllm.utils.log_utils import init_logger
+from lightllm.server.router.model_infer.infer_batch import g_infer_context
+from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample
+from lightllm.server.router.model_infer.mode_backend.pre import (
+    prepare_prefill_inputs,
+    prepare_decode_inputs,
+)
+import torch.multiprocessing as mp
+from .impl import ChunkedPrefillBackend
+
+
+logger = init_logger(__name__)
+
+class ChunkedPrefillBackendHiCache(ChunkedPrefillBackend):
+
+    def __init__(self, radix_mem_queue: mp.Queue) -> None:
+        super().__init__()
+        self.radix_mem_queue = radix_mem_queue
+
+    def init_custom(self):
+        self.radix_mem_queue.put((self.model.mem_propties, self.model.shared_mem_data))
\ No newline at end of file
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py
index a54b54980..2ea2366cf 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py
@@ -221,7 +221,6 @@ def __remove_dead_trans_obj(self):
        gc.collect()
        return

-
def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp.Event):
    import lightllm.utils.rpyc_fix_utils as _

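Note: the hiradix cache manager process itself (lightllm.server.router.dynamic_prompt.hiradix.manager) is not part of this diff, so the handshake is only visible from the producer side: ChunkedPrefillBackendHiCache.init_custom pushes (mem_propties, shared_mem_data) into radix_mem_queue, and the router's _send_hiradix_manager pushes ShmReqInfo objects over the send_to_hiradix_server socket. The following is a minimal, purely hypothetical sketch of what the consuming side could look like; the function name, the bind/connect split, and the loop structure are assumptions, not the actual implementation.

# Hypothetical consumer-side sketch (not the real hiradix manager process).
import zmq
import zmq.asyncio


async def hiradix_manager_loop(radix_mem_queue, zmq_mode: str, hiradix_port: int):
    # Block until the inference backend has published its shared radix buffers
    # (the tuple put by ChunkedPrefillBackendHiCache.init_custom).
    mem_propties, shared_mem_data = radix_mem_queue.get()

    # Receive finished-request notifications pushed by the router
    # (ShmReqInfo objects sent via send_to_hiradix_server / send_pyobj).
    ctx = zmq.asyncio.Context()
    sock = ctx.socket(zmq.PULL)
    sock.bind(f"{zmq_mode}127.0.0.1:{hiradix_port}")
    while True:
        req_info = await sock.recv_pyobj()
        # ... update the hierarchical radix cache for req_info.index_in_shm_mem ...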
diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py
index 311c2725f..1e0ea2091 100644
--- a/lightllm/server/router/model_infer/model_rpc.py
+++ b/lightllm/server/router/model_infer/model_rpc.py
@@ -10,6 +10,7 @@
    ContinuesBatchBackend,
    ReturnPromptLogProbBackend,
    ChunkedPrefillBackend,
+    ChunkedPrefillBackendHiCache,
    DiversehBackend,
    RewardModelBackend,
    TokenHealingBackend,
@@ -49,12 +50,18 @@ def __init__(
        rpc_finished_event: multiprocessing.Event,
        info_queue: mp.Queue,
        mem_queue: mp.Queue,
+        radix_mem_queue: mp.Queue = None,
+        radix_info_queue: mp.Queue = None,
+        radix_lock: mp.Lock = None
    ):
        super().__init__()
        self.args: StartArgs = args
        self.node_world_size = node_world_size
        self.info_queue = info_queue
        self.mem_queue = mem_queue
+        self.radix_mem_queue = radix_mem_queue
+        self.radix_info_queue = radix_info_queue
+        self.radix_lock = radix_lock
        self.rpc_event = rpc_event
        self.rpc_finished_event = rpc_finished_event
@@ -124,6 +131,14 @@ def init_model(self, kvargs):
        assert not (is_outlines_constraint_mode and is_xgrammar_constraint_mode), "only one constraint mode can be true"
        is_prefill_node = self.args.run_mode == "prefill"
        is_decode_node = self.args.run_mode == "decode"
+        use_hiradix_cache = self.args.use_hiradix_cache and not self.args.disable_dynamic_prompt_cache
+        kvargs.update({
+            "use_hiradix_cache": use_hiradix_cache,
+            "hiradix_cache_gpu": self.args.hiradix_cache_gpu,
+            "radix_info_queue": self.radix_info_queue,
+            "radix_lock": self.radix_lock,
+            "hiradix_cache_token_num": self.args.hiradix_cache_token_num
+        })

        enable_mtp = self.args.mtp_mode is not None
@@ -177,7 +192,10 @@ def init_model(self, kvargs):
                if enable_mtp:
                    self.backend = ContinuesBatchWithMTPBackend()
                else:
-                    self.backend = ChunkedPrefillBackend()
+                    if use_hiradix_cache:
+                        self.backend = ChunkedPrefillBackendHiCache(self.radix_mem_queue)
+                    else:
+                        self.backend = ChunkedPrefillBackend()
        logger.info(f"use {self.backend.__class__.__name__}")
        self.backend.init_model(kvargs)
@@ -287,6 +305,9 @@ def _init_env(
    rpc_event: mp.Event,
    rpc_finished_event: mp.Event,
    success_event: mp.Event,
+    radix_mem_queue: mp.Queue = None,
+    radix_info_queue: mp.Queue = None,
+    radix_lock: mp.Lock = None
):
    import lightllm.utils.rpyc_fix_utils as _

@@ -300,7 +321,8 @@ def _init_env(
        g_router_lock.obj = router_lock

        model_rpc_server = ModelRpcServer(
-            args, rank, rank_in_node, node_world_size, rpc_event, rpc_finished_event, info_queue, mem_queue
+            args, rank, rank_in_node, node_world_size, rpc_event, rpc_finished_event, info_queue, mem_queue,
+            radix_mem_queue, radix_info_queue, radix_lock
        )
        success_event.set()
@@ -318,6 +340,9 @@ async def start_model_process(
    info_queue: mp.Queue,
    mem_queue: mp.Queue,
    router_lock: mp.Queue,
+    radix_mem_queue: mp.Queue = None,
+    radix_info_queue: mp.Queue = None,
+    radix_lock: mp.Lock = None
):
    import lightllm.utils.rpyc_fix_utils as _

@@ -335,6 +360,9 @@ async def start_model_process(
            rpc_event,
            rpc_finished_event,
            success_event,
+            radix_mem_queue,
+            radix_info_queue,
+            radix_lock
        ),
    )
    proc.start()
diff --git a/test/server/test_hicache.py b/test/server/test_hicache.py
new file mode 100644
index 000000000..bb82457c4
--- /dev/null
+++ b/test/server/test_hicache.py
@@ -0,0 +1,155 @@
+# test_hicache.py
+import torch
+import time
+import random
+from threading import Thread, Event
+from queue import Queue
+from lightllm.server.router.dynamic_prompt.cache_controller import (
+    HiCacheController,
+    CacheNode,
+    BLOCK_SIZE,
+    HiHostService,
+    HiHostTask,
+)
+
+
+class MockMemoryManager:
+    """Mock memory manager that simply hands out consecutive indices"""
+
+    def __init__(self):
+        self.current_idx = 0
+        self.kvcache_store = {}
+
+    def alloc(self, size):
+        indices = list(range(self.current_idx, self.current_idx + size))
+        self.current_idx += size
+        self.store(indices, torch.tensor([[random.randint(0, 0xFFFF) for __ in range(512)] for _ in range(size)]))
+        return indices
+
+    def load_index_kv_buffer(self, index, load_tensor_dict):
+        self.kvcache_store[index] = load_tensor_dict["kv_buffer"]
+
+    def get_index_kv_buffer(self, index):
+        return {"kv_buffer": self.kvcache_store[index]}
+
+    def to_kvcache(self, indices):
+        assert all(
+            [idx in self.kvcache_store for idx in indices]
+        ), f"Not all of {indices} are found in kvcache_store"
+        return torch.tensor([self.kvcache_store[idx].tolist() for idx in indices])
+
+    def store(self, indices, value):
+        print(f"[TEST:MemManager] Storing {value.shape} at {indices}")
+        for idx, value_dim in zip(indices, range(value.shape[0])):
+            self.kvcache_store[idx] = value[value_dim]
+            print(f"[TEST:MemManager] Stored {value[value_dim].shape} at {idx}")
+        return indices
+
+    def free(self, indices):
+        print(f"[TEST:MemManager] Freeing {indices}")
+        for idx in indices:
+            del self.kvcache_store[idx]
+
+
+def setup():
+    mem_manager = MockMemoryManager()
+    service = HiHostService()
+    hicache = HiCacheController(mem_manager)
+    hicache.service = service  # inject the mock service
+
+    indices = mem_manager.alloc(5)
+    print(mem_manager.to_kvcache(indices))
+
+    # pre-compute the size of a single token's KV cache
+    dummy_indices = mem_manager.alloc(1)
+    kvcache = mem_manager.to_kvcache(dummy_indices[:1])
+    token_size = kvcache.nelement() * kvcache.element_size()
+    print(f"[TEST] Single token KV cache size: {token_size} bytes, Block size: {BLOCK_SIZE}")
+
+    return mem_manager, service, hicache, token_size
+
+
+def test_basic_write_read(mem_manager, hicache, token_size):
+    # how many tokens fit into one block
+    tokens_per_block = BLOCK_SIZE // token_size
+    print(f"[TEST] Each block can hold {tokens_per_block} tokens")
+
+    # generate test data that fills exactly one block
+    token_ids = list(range(tokens_per_block))
+    indices = mem_manager.alloc(len(token_ids))
+    kvcache = mem_manager.to_kvcache(indices)
+    print(f"[TEST] Generated KV cache with shape: {kvcache.shape}, type: {kvcache.dtype}")
+
+    # write into the cache
+    hicache.write(torch.tensor(token_ids), torch.tensor(indices))
+    time.sleep(2)
+
+    # wait for all tasks to finish
+    hicache.service.wait_till_all_finished()
+
+    mem_manager.free(indices)
+
+    # read back and verify
+    result = hicache.read(torch.tensor(token_ids))
+    result = mem_manager.to_kvcache(result.tolist())
+    assert result.eq(kvcache).all(), f"Retrieved kvcache: {result}, Expected kvcache: {kvcache}"
+    print("[TEST] Basic test passed. Retrieved kvcache\n\n")
+
+
+def test_node_splitting(mem_manager, hicache, token_size):
+    tokens_per_block = BLOCK_SIZE // token_size
+    # generate more data than fits into a single block
+    token_ids = list(range(12, 12 + tokens_per_block * 3 + 1))
+    indices = mem_manager.alloc(len(token_ids))
+    kvcache = mem_manager.to_kvcache(indices)
+
+    hicache.write(torch.tensor(token_ids), torch.tensor(indices))
+    time.sleep(2)
+    hicache.service.wait_till_all_finished()
+
+    # the root node should now have children
+    root = hicache.root
+    assert len(root.children) > 0
+    print(f"\nRoot node has {len(root.children)} children")
+
+    # read back the full sequence
+    result = hicache.read(torch.tensor(token_ids))
+    result = mem_manager.to_kvcache(result.tolist())
+    assert result.eq(kvcache).all(), f"Retrieved kvcache: {result}, Expected kvcache: {kvcache}"
+    print(f"[TEST] Node splitting test passed. Retrieved kvcache: {result.shape}\n\n")
+
+
+def test_partial_read(mem_manager, hicache):
+    token_ids = [97, 98, 99, 100, 101, 102]
+    indices = mem_manager.alloc(len(token_ids))
+    kvcache = mem_manager.to_kvcache(indices)
+    hicache.write(torch.tensor(token_ids), torch.tensor(indices))
+    time.sleep(2)
+    hicache.service.wait_till_all_finished()
+
+    # query an existing prefix
+    result = hicache.read(torch.tensor([97, 98, 99]))
+    result = mem_manager.to_kvcache(result.tolist())
+    assert result.eq(kvcache[:3]).all()
+    print("[TEST] Partial read passed")
+
+    # query a prefix that only partially matches (the third token differs)
+    result = hicache.read(torch.tensor([97, 98, 100]))
+    assert len(result) == 2
+    result = mem_manager.to_kvcache(result.tolist())
+    assert result.eq(kvcache[:2]).all()
+    print(f"[TEST] Non-matching prefix returned only the matching part: {result.tolist()}")
+
+
+def main():
+    mem_manager, service, hicache, token_size = setup()
+    try:
+        test_basic_write_read(mem_manager, hicache, token_size)
+        test_node_splitting(mem_manager, hicache, token_size)
+        test_partial_read(mem_manager, hicache)
+    finally:
+        service.shutdown()
+
+
+if __name__ == "__main__":
+    main()
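Since the new test is written as a standalone script (its test_* functions take explicit arguments rather than pytest fixtures) and ends in a __main__ guard, it is presumably meant to be run directly from the repository root, for example:

python test/server/test_hicache.py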