Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 34 additions & 11 deletions python/sglang/srt/managers/cache_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,15 @@ def __init__(
self,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
mem_pool_host: HostKVCache,
page_size: int,
load_cache_event: threading.Event = None,
write_policy: str = "write_through_selective",
):
self.mem_pool_device_allocator = token_to_kv_pool_allocator
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
self.mem_pool_host = mem_pool_host
self.write_policy = write_policy
self.page_size = page_size

self.load_cache_event = load_cache_event
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
Expand Down Expand Up @@ -184,7 +186,12 @@ def __init__(
self.load_stream = torch.cuda.Stream()

self.write_thread = threading.Thread(
target=self.write_thread_func_buffer, daemon=True
target=(
self.write_thread_func_buffer
if self.page_size == 1
else self.write_thread_func_direct
),
daemon=True,
)
self.load_thread = threading.Thread(
target=self.load_thread_func_layer_by_layer, daemon=True
Expand All @@ -205,7 +212,12 @@ def reset(self):
self.ack_load_queue.queue.clear()

self.write_thread = threading.Thread(
target=self.write_thread_func_buffer, daemon=True
target=(
self.write_thread_func_buffer
if self.page_size == 1
else self.write_thread_func_direct
),
daemon=True,
)
self.load_thread = threading.Thread(
target=self.load_thread_func_layer_by_layer, daemon=True
Expand Down Expand Up @@ -260,10 +272,12 @@ def write_thread_func_direct(self):
while not self.stop_event.is_set():
try:
operation = self.write_queue.get(block=True, timeout=1)
operation.data = self.mem_pool_device.get_flat_data(
operation.device_indices
self.mem_pool_host.write_page_all_layers(
operation.host_indices,
operation.device_indices,
self.mem_pool_device,
)
self.mem_pool_host.transfer(operation.host_indices, operation.data)
self.write_stream.synchronize()
self.mem_pool_host.complete_io(operation.host_indices)
for node_id in operation.node_ids:
if node_id != 0:
Expand Down Expand Up @@ -320,12 +334,21 @@ def load_thread_func_layer_by_layer(self):

self.layer_done_counter.reset()
for i in range(self.mem_pool_host.layer_num):
flat_data = self.mem_pool_host.get_flat_data_by_layer(
batch_operation.host_indices, i
)
self.mem_pool_device.transfer_per_layer(
batch_operation.device_indices, flat_data, i
)
if self.page_size == 1:
flat_data = self.mem_pool_host.get_flat_data_by_layer(
batch_operation.host_indices, i
)
self.mem_pool_device.transfer_per_layer(
batch_operation.device_indices, flat_data, i
)
else:
self.mem_pool_host.load_page_per_layer(
batch_operation.host_indices,
batch_operation.device_indices,
self.mem_pool_device,
i,
)
self.load_stream.synchronize()
self.layer_done_counter.increment()

self.mem_pool_host.complete_io(batch_operation.host_indices)
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1282,7 +1282,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
]

if self.enable_hierarchical_cache:
self.tree_cache.read_to_load_cache()
self.tree_cache.ready_to_load_cache()

if adder.new_chunked_req is not None:
assert self.chunked_req is None
Expand Down
110 changes: 60 additions & 50 deletions python/sglang/srt/mem_cache/hiradix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
from sglang.srt.mem_cache.radix_cache import _key_match_page_size1 as _key_match

logger = logging.getLogger(__name__)

Expand All @@ -31,29 +30,25 @@ def __init__(
page_size: int,
hicache_ratio: float,
):
if page_size != 1:
raise ValueError(
"Page size larger than 1 is not yet supported in HiRadixCache."
)
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
if isinstance(self.kv_cache, MHATokenToKVPool):
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
self.kv_cache, hicache_ratio
self.kv_cache, hicache_ratio, page_size
)
elif isinstance(self.kv_cache, MLATokenToKVPool):
self.token_to_kv_pool_host = MLATokenToKVPoolHost(
self.kv_cache, hicache_ratio
self.kv_cache, hicache_ratio, page_size
)
else:
raise ValueError(f"Only MHA and MLA supports swap kv_cache to host.")
raise ValueError(f"HiRadixCache only supports MHA and MLA yet")

self.tp_group = tp_cache_group
self.page_size = page_size

self.load_cache_event = threading.Event()
self.cache_controller = HiCacheController(
token_to_kv_pool_allocator,
self.token_to_kv_pool_host,
page_size,
load_cache_event=self.load_cache_event,
)

Expand All @@ -65,7 +60,7 @@ def __init__(
self.write_through_threshold = 1
self.load_back_threshold = 10
super().__init__(
req_to_token_pool, token_to_kv_pool_allocator, self.page_size, disable=False
req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
)

def reset(self):
Expand Down Expand Up @@ -299,18 +294,26 @@ def init_load_back(

return last_node, prefix_indices

def read_to_load_cache(self):
def ready_to_load_cache(self):
self.load_cache_event.set()

def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
if self.disable:
return [], self.root_node
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
if self.disable or len(key) == 0:
if include_evicted:
return empty_value, self.root_node, self.root_node
else:
return empty_value, self.root_node

if self.page_size != 1:
page_aligned_len = len(key) // self.page_size * self.page_size
key = key[:page_aligned_len]

value, last_node = self._match_prefix_helper(self.root_node, key)
if value:
value = torch.cat(value)
else:
value = torch.tensor([], dtype=torch.int64)
value = empty_value

last_node_global = last_node
while last_node.evicted:
Expand All @@ -323,11 +326,13 @@ def match_prefix(self, key: List[int], include_evicted=False, **kwargs):

def _match_prefix_helper(self, node: TreeNode, key: List):
node.last_access_time = time.time()
child_key = self.get_child_key_fn(key)
value = []
while len(key) > 0 and key[0] in node.children.keys():
child = node.children[key[0]]

while len(key) > 0 and child_key in node.children.keys():
child = node.children[child_key]
child.last_access_time = time.time()
prefix_len = _key_match(child.key, key)
prefix_len = self.key_match_fn(child.key, key)
if prefix_len < len(child.key):
new_node = self._split_node(child.key, child, prefix_len)
if not new_node.evicted:
Expand All @@ -339,12 +344,16 @@ def _match_prefix_helper(self, node: TreeNode, key: List):
value.append(child.value)
node = child
key = key[prefix_len:]

if len(key):
child_key = self.get_child_key_fn(key)

return value, node

def _split_node(self, key, child: TreeNode, split_len: int):
# child node split into new_node -> child
new_node = TreeNode()
new_node.children = {key[split_len]: child}
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
new_node.parent = child.parent
new_node.lock_ref = child.lock_ref
new_node.key = child.key[:split_len]
Expand All @@ -361,60 +370,61 @@ def _split_node(self, key, child: TreeNode, split_len: int):
child.host_value = child.host_value[split_len:]
child.parent = new_node
child.key = child.key[split_len:]
new_node.parent.children[key[0]] = new_node
new_node.parent.children[self.get_child_key_fn(key)] = new_node
return new_node

def _insert_helper(self, node: TreeNode, key: List, value):
node.last_access_time = time.time()
if len(key) == 0:
return 0

if key[0] in node.children.keys():
child = node.children[key[0]]
prefix_len = _key_match(child.key, key)
child_key = self.get_child_key_fn(key)
total_prefix_length = 0

while len(key) > 0 and child_key in node.children.keys():
node = node.children[child_key]
node.last_access_time = time.time()
prefix_len = self.key_match_fn(node.key, key)

if prefix_len == len(child.key):
if child.evicted:
if prefix_len == len(node.key):
if node.evicted:
# change the reference if the node is evicted
# this often happens in the case of KV cache recomputation
child.value = value[:prefix_len]
self.token_to_kv_pool_host.update_synced(child.host_value)
self.evictable_size_ += len(value[:prefix_len])
return self._insert_helper(
child, key[prefix_len:], value[prefix_len:]
)
node.value = value[:prefix_len]
self.token_to_kv_pool_host.update_synced(node.host_value)
self.evictable_size_ += len(node.value)
else:
self.inc_hit_count(child)
return prefix_len + self._insert_helper(
child, key[prefix_len:], value[prefix_len:]
)

# partial match, split the node
new_node = self._split_node(child.key, child, prefix_len)
if new_node.evicted:
new_node.value = value[:prefix_len]
self.token_to_kv_pool_host.update_synced(new_node.host_value)
self.evictable_size_ += len(new_node.value)
return self._insert_helper(
new_node, key[prefix_len:], value[prefix_len:]
)
self.inc_hit_count(node)
total_prefix_length += prefix_len
else:
self.inc_hit_count(new_node)
return prefix_len + self._insert_helper(
new_node, key[prefix_len:], value[prefix_len:]
)
# partial match, split the node
new_node = self._split_node(node.key, node, prefix_len)
if new_node.evicted:
new_node.value = value[:prefix_len]
self.token_to_kv_pool_host.update_synced(new_node.host_value)
self.evictable_size_ += len(new_node.value)
else:
self.inc_hit_count(new_node)
total_prefix_length += prefix_len
node = new_node

key = key[prefix_len:]
value = value[prefix_len:]

if len(key):
child_key = self.get_child_key_fn(key)

if len(key):
new_node = TreeNode()
new_node.parent = node
new_node.key = key
new_node.value = value
node.children[key[0]] = new_node
node.children[child_key] = new_node
self.evictable_size_ += len(value)

if self.cache_controller.write_policy == "write_through":
self.write_backup(new_node)
return 0
return total_prefix_length

def _collect_leaves_device(self):
def is_leaf(node):
Expand Down
Loading
Loading