Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions python/sglang/srt/managers/cache_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@

import torch

from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MHATokenToKVPoolHost
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPoolHost,
TokenToKVPoolAllocator,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -127,12 +130,12 @@ class HiCacheController:

def __init__(
self,
mem_pool_device: MHATokenToKVPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
mem_pool_host: MHATokenToKVPoolHost,
write_policy: str = "write_through_selective",
):

self.mem_pool_device = mem_pool_device
self.mem_pool_device_allocator = token_to_kv_pool_allocator
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
self.mem_pool_host = mem_pool_host
self.write_policy = write_policy

Expand Down Expand Up @@ -216,7 +219,7 @@ def load(
"""
Load KV caches from host memory to device memory.
"""
device_indices = self.mem_pool_device.alloc(len(host_indices))
device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
if device_indices is None:
return None
self.mem_pool_host.protect_load(host_indices)
Expand Down Expand Up @@ -417,7 +420,7 @@ def evict_device(
self, device_indices: torch.Tensor, host_indices: torch.Tensor
) -> int:
if self.mem_pool_host.is_synced(host_indices):
self.mem_pool_device.free(device_indices)
self.mem_pool_device_allocator.free(device_indices)
self.mem_pool_host.update_backup(host_indices)
return len(device_indices)
else:
Expand Down
33 changes: 17 additions & 16 deletions python/sglang/srt/mem_cache/hiradix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
MHATokenToKVPoolHost,
ReqToTokenPool,
TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match

Expand All @@ -21,11 +21,13 @@ class HiRadixCache(RadixCache):
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool: MHATokenToKVPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
):
self.token_to_kv_pool_host = MHATokenToKVPoolHost(token_to_kv_pool)
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
token_to_kv_pool_allocator.get_kvcache()
)
self.cache_controller = HiCacheController(
token_to_kv_pool, self.token_to_kv_pool_host
token_to_kv_pool_allocator, self.token_to_kv_pool_host
)

# record the nodes with ongoing write through
Expand All @@ -35,7 +37,7 @@ def __init__(
# todo: dynamically adjust the threshold
self.write_through_threshold = 1
self.load_back_threshold = 10
super().__init__(req_to_token_pool, token_to_kv_pool, disable=False)
super().__init__(req_to_token_pool, token_to_kv_pool_allocator, disable=False)

def reset(self):
TreeNode.counter = 0
Expand Down Expand Up @@ -160,7 +162,7 @@ def _evict_write_through(self, node: TreeNode):

def _evict_write_through_selective(self, node: TreeNode):
# evict a node not initiated write to host
self.cache_controller.mem_pool_device.free(node.value)
self.cache_controller.mem_pool_device_allocator.free(node.value)
num_evicted = len(node.value)
self._delete_leaf(node)
return num_evicted
Expand Down Expand Up @@ -270,28 +272,27 @@ def init_load_back(

return last_node, prefix_indices

def _match_prefix_helper(
self, node: TreeNode, key: List, value, last_node: TreeNode
):
def _match_prefix_helper(self, node: TreeNode, key: List):
node.last_access_time = time.time()
if len(key) == 0:
return

if key[0] in node.children.keys():
value = []
while len(key) > 0 and key[0] in node.children.keys():
child = node.children[key[0]]
child.last_access_time = time.time()
prefix_len = _key_match(child.key, key)
if prefix_len < len(child.key):
new_node = self._split_node(child.key, child, prefix_len)
self.inc_hit_count(new_node)
if not new_node.evicted:
value.append(new_node.value)
last_node[0] = new_node
node = new_node
break
else:
self.inc_hit_count(child)
if not child.evicted:
value.append(child.value)
last_node[0] = child
self._match_prefix_helper(child, key[prefix_len:], value, last_node)
node = child
key = key[prefix_len:]
return value, node

def _split_node(self, key, child: TreeNode, split_len: int):
# child node split into new_node -> child
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/mem_cache/memory_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ class MHATokenToKVPoolHost:
def __init__(
self,
device_pool: MHATokenToKVPool,
host_to_device_ratio: float = 2.0,
host_to_device_ratio: float = 3.0,
pin_memory: bool = False, # no need to use pin memory with the double buffering
device: str = "cpu",
):
Expand Down
4 changes: 0 additions & 4 deletions python/sglang/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,10 @@
from IPython.display import HTML, display
from tqdm import tqdm

from sglang.srt.openai_api.protocol import ChatCompletionMessageContentPart
from sglang.srt.utils import kill_process_tree

logger = logging.getLogger(__name__)

# type of content fields, can be only prompts or with images/videos
MsgContent = Union[str, List[ChatCompletionMessageContentPart]]


def get_exception_traceback():
etype, value, tb = sys.exc_info()
Expand Down
Loading