Merged
43 commits (changes shown from 31 commits):
257808d fix layer done counter sync and hiradix pre-compute (xiezhq-hermann, Jun 17, 2025)
9766794 Merge branch 'main' into xiezhq-hicache-upstream (xiezhq-hermann, Jun 17, 2025)
c5a23b1 Merge branch 'main' into xiezhq-hicache-upstream (xiezhq-hermann, Jun 18, 2025)
897a66f kv cache io kernels (xiezhq-hermann, Jun 18, 2025)
95bc313 layout change on memory pool and interfaces for the new kernels (xiezhq-hermann, Jun 18, 2025)
f8ea5e9 Merge branch 'main' into xiezhq-hicache-upstream (xiezhq-hermann, Jun 18, 2025)
28a2e8e Merge branch 'main' into xiezhq-hicache-upstream (xiezhq-hermann, Jun 24, 2025)
3a92509 Merge branch 'main' into xiezhq-hicache-upstream (xiezhq-hermann, Jun 24, 2025)
ddddf42 Merge branch 'main' into xiezhq-hicache-upstream (xiezhq-hermann, Jun 25, 2025)
4146051 bump to sgl-kernel 2.0 (xiezhq-hermann, Jun 25, 2025)
554a91b Merge branch 'main' into xiezhq-hicache-upstream (xiezhq-hermann, Jun 27, 2025)
3546e02 Merge branch 'main' into xiezhq-hicache-upstream (xiezhq-hermann, Jun 30, 2025)
4b816da io backend parameter (xiezhq-hermann, Jul 1, 2025)
5e730e5 hicache storage prototype (xiezhq-hermann, Jul 2, 2025)
f75ff68 update the prefetching logics (xiezhq-hermann, Jul 2, 2025)
516abc6 fall back to the non-continuous memory pool (xiezhq-hermann, Jul 4, 2025)
fd74ed0 Merge branch 'main' into xiezhq-hicache-upstream (xiezhq-hermann, Jul 4, 2025)
7cfc9ad refactoring (xiezhq-hermann, Jul 5, 2025)
c05e19d Merge branch 'main' into xiezhq-hicache-upstream (xiezhq-hermann, Jul 5, 2025)
3f38e28 fix (xiezhq-hermann, Jul 5, 2025)
f82cc02 Merge branch 'xiezhq-hicache-upstream' into xiezhq-hicache-storage (xiezhq-hermann, Jul 5, 2025)
2bbd0c5 refactoring (xiezhq-hermann, Jul 7, 2025)
528395a Merge branch 'main' into xiezhq-hicache-storage (xiezhq-hermann, Jul 7, 2025)
b07e8bb better synchronization and styles (xiezhq-hermann, Jul 7, 2025)
b55b990 server arg fix (xiezhq-hermann, Jul 7, 2025)
70046c3 Merge branch 'main' into xiezhq-hicache-storage (xiezhq-hermann, Jul 7, 2025)
2fab3ca bug fix and style improvement (xiezhq-hermann, Jul 8, 2025)
b90604b Update python/sglang/srt/mem_cache/hicache_storage.py (xiezhq-hermann, Jul 11, 2025)
f7b6280 Update python/sglang/srt/mem_cache/radix_cache.py (xiezhq-hermann, Jul 11, 2025)
ae77cc7 fix data race and minor refinement (xiezhq-hermann, Jul 11, 2025)
29620f7 Merge branch 'main' into xiezhq-hicache-storage (xiezhq-hermann, Jul 11, 2025)
849d966 sanity check for hicache_storage (xiezhq-hermann, Jul 15, 2025)
44068a0 Merge remote-tracking branch 'origin/xiezhq-hicache-storage' into xie… (xiezhq-hermann, Jul 15, 2025)
95657a3 Merge branch 'main' into xiezhq-hicache-storage (xiezhq-hermann, Jul 15, 2025)
7c18e06 style improvement (xiezhq-hermann, Jul 15, 2025)
384f3f5 Merge branch 'main' into xiezhq-hicache-storage (xiezhq-hermann, Jul 15, 2025)
f62ebe2 Merge branch 'main' into xiezhq-hicache-storage (zhyncs, Jul 16, 2025)
3462237 Merge branch 'main' into xiezhq-hicache-storage (xiezhq-hermann, Jul 16, 2025)
6837d40 Merge branch 'main' into xiezhq-hicache-storage (xiezhq-hermann, Jul 16, 2025)
22b913d Merge branch 'main' into xiezhq-hicache-storage (xiezhq-hermann, Jul 16, 2025)
4bc693e Merge branch 'main' into xiezhq-hicache-storage (xiezhq-hermann, Jul 17, 2025)
8c1c46c remove the delete and clear interfaces (xiezhq-hermann, Jul 17, 2025)
1c9a64e Merge branch 'main' into xiezhq-hicache-storage (xiezhq-hermann, Jul 17, 2025)
248 changes: 248 additions & 0 deletions python/sglang/srt/managers/cache_controller.py
@@ -25,6 +25,8 @@
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool_host import HostKVCache

from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str

logger = logging.getLogger(__name__)


@@ -159,6 +161,57 @@ def clear(self):
self.buffers.queue.clear()


class StorageOperation:
counter = 0

def __init__(
self,
host_indices: torch.Tensor,
token_ids: List[int],
last_hash: Optional[str] = None,
):
self.host_indices = host_indices
self.token_ids = token_ids
self.last_hash = last_hash
self.completed_tokens = 0
self.hash_value = []

self.id = StorageOperation.counter
StorageOperation.counter += 1

def __lt__(self, other: "StorageOperation"):
return self.id < other.id


class PrefetchOperation(StorageOperation):
def __init__(
self,
request_id: str,
host_indices: torch.Tensor,
token_ids: List[int],
last_hash: Optional[str] = None,
):
self.request_id = request_id

self._done_flag = False
self._lock = threading.Lock()

super().__init__(host_indices, token_ids, last_hash)

def increment(self, num_tokens: int):
with self._lock:
if self._done_flag:
return
self.completed_tokens += num_tokens

def mark_done(self):
with self._lock:
self._done_flag = True

def is_done(self) -> bool:
return self._done_flag


class HiCacheController:

def __init__(
@@ -169,6 +222,8 @@ def __init__(
load_cache_event: threading.Event = None,
write_policy: str = "write_through_selective",
io_backend: str = "",
storage_backend: Optional[str] = None,
prefetch_threshold: int = 256,
Collaborator comment:
Wonder if we can pass tp_rank and pp_rank information to this function. When dealing with file systems that aren't optimized for handling numerous small files, it might be more practical to have each rank correspond to a single large file. Including model parallelism (tp/pp) information in filenames would help us manage sglang instances and their associated KVCache files more effectively.
):
self.mem_pool_device_allocator = token_to_kv_pool_allocator
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
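The rank-aware naming suggested in the review comment above could look roughly like the sketch below. It is a minimal illustration only: the tp_rank/pp_rank constructor arguments and the one-file-per-page layout are assumptions, not part of this diff or of the actual HiCacheFile interface.

import os
from typing import Optional

import torch


class RankAwareFileStorage:
    """Illustrative sketch: embed tp/pp rank in every KV page filename so
    sglang instances or ranks sharing a directory never collide."""

    def __init__(self, root: str, tp_rank: int, pp_rank: int):
        self.root = root
        self.prefix = f"tp{tp_rank}_pp{pp_rank}"
        os.makedirs(root, exist_ok=True)

    def _path(self, key: str) -> str:
        # Parallelism info goes into the filename, as the comment proposes.
        return os.path.join(self.root, f"{self.prefix}_{key}.bin")

    def exists(self, key: str) -> bool:
        return os.path.exists(self._path(key))

    def set(self, key: str, value: torch.Tensor) -> None:
        torch.save(value.cpu(), self._path(key))

    def get(self, key: str) -> Optional[torch.Tensor]:
        path = self._path(key)
        return torch.load(path) if os.path.exists(path) else None

Collapsing to one large file per rank, as the comment also suggests, would mainly replace the per-page filename with a per-rank file plus an in-file offset index.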
@@ -186,6 +241,21 @@ def __init__(
else:
self.io_backend = io_backend

self.enable_storage = False
# todo: move backend initialization to storage backend module
if storage_backend is not None:
if storage_backend == "file":
self.storage_backend = HiCacheFile()
self.enable_storage = True
# tracking prefetch operation progress
self.ongoing_prefetch: dict[int, PrefetchOperation] = {}
# todo: threshold policy for prefetching
self.prefetch_threshold = prefetch_threshold
else:
raise NotImplementedError(
f"Unsupported storage backend: {storage_backend}"
)

self.load_cache_event = load_cache_event
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
@@ -218,9 +288,26 @@ def __init__(
self.load_thread = threading.Thread(
target=self.load_thread_func_layer_by_layer, daemon=True
)

self.write_thread.start()
self.load_thread.start()

if self.enable_storage:
self.prefetch_thread = threading.Thread(
target=self.prefetch_thread_func, daemon=True
)
self.backup_thread = threading.Thread(
target=self.backup_thread_func, daemon=True
)
self.prefetch_queue = Queue()
self.backup_queue = Queue()

self.prefetch_revoke_queue = Queue()
self.ack_backup_queue = Queue()

self.prefetch_thread.start()
self.backup_thread.start()

def reset(self):
self.stop_event.set()
self.write_thread.join()
@@ -232,6 +319,13 @@ def reset(self):
self.load_buffer.clear()
self.ack_write_queue.queue.clear()
self.ack_load_queue.queue.clear()
if self.enable_storage:
self.prefetch_thread.join()
self.backup_thread.join()
self.prefetch_queue.queue.clear()
self.backup_queue.queue.clear()
self.prefetch_revoke_queue.queue.clear()
self.ack_backup_queue.queue.clear()

self.write_thread = threading.Thread(
target=self.write_thread_func_direct, daemon=True
@@ -243,6 +337,16 @@ def reset(self):
self.write_thread.start()
self.load_thread.start()

if self.enable_storage:
self.prefetch_thread = threading.Thread(
target=self.prefetch_thread_func, daemon=True
)
self.backup_thread = threading.Thread(
target=self.backup_thread_func, daemon=True
)
self.prefetch_thread.start()
self.backup_thread.start()

def write(
self,
device_indices: torch.Tensor,
@@ -383,3 +487,147 @@ def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> in
raise ValueError(
f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
)

def prefetch(
self,
request_id: str,
host_indices: torch.Tensor,
new_input_tokens: List[int],
last_hash: Optional[str] = None,
) -> int:
"""
Prefetch KV caches from storage backend to host memory.
"""
operation = PrefetchOperation(
request_id, host_indices, new_input_tokens, last_hash
)
self.ongoing_prefetch[request_id] = operation
self.prefetch_queue.put(operation)

def terminate_prefetch(self, request_id: str):
operation = self.ongoing_prefetch.pop(request_id, None)
if operation is None:
raise ValueError(
f"Request ID {request_id} not found in ongoing prefetches."
)
operation.mark_done()
return operation.completed_tokens, operation.hash_value

def prefetch_io_aux_func(self):
"""
Auxiliary function conducting IO operations for prefetching.
"""
while not self.stop_event.is_set():
try:
operation = self.prefetch_buffer.get(block=True, timeout=1)
for h in operation.hash_value:
page_data = self.storage_backend.get(h)
if page_data is None:
logger.warning(
f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
)
break
self.mem_pool_host.set_from_flat_data_page(
operation.host_indices[operation.completed_tokens],
page_data,
)
operation.increment(self.page_size)
if operation.is_done():
# operation terminated by controller, release pre-allocated memory
self.mem_pool_host.free(
operation.host_indices[operation.completed_tokens :]
)
break
except Empty:
continue

def prefetch_thread_func(self):
"""
Manage prefetching operations from storage backend to host memory.
"""
self.prefetch_buffer = Queue()
aux_thread = threading.Thread(target=self.prefetch_io_aux_func, daemon=True)
aux_thread.start()
while (not self.stop_event.is_set()) or not self.prefetch_queue.empty():
try:
operation = self.prefetch_queue.get(block=True, timeout=1)
if operation is None:
continue

last_hash = operation.last_hash
tokens_to_fetch = operation.token_ids

storage_hit_count = 0
remaining_tokens = len(tokens_to_fetch)
hash_value = []
while remaining_tokens >= self.page_size:
last_hash = get_hash_str(
tokens_to_fetch[
storage_hit_count : storage_hit_count + self.page_size
],
last_hash,
)
if self.storage_backend.exists(last_hash):
storage_hit_count += self.page_size
hash_value.append(last_hash)
remaining_tokens -= self.page_size
else:
break

if storage_hit_count < self.prefetch_threshold:
# not to prefetch if not enough benefits
self.prefetch_revoke_queue.put(operation.request_id)
else:
operation.hash_value = hash_value
logger.debug(
f"Prefetching {len(hash_value)} pages for request {operation.request_id}."
)
self.prefetch_buffer.put(operation)

except Empty:
continue

def write_storage(
self,
host_indices: torch.Tensor,
token_ids: List[int],
last_hash: Optional[str] = None,
) -> int:
"""
Write KV caches from host memory to storage backend.
"""
operation = StorageOperation(host_indices, token_ids, last_hash)
self.backup_queue.put(operation)
return operation.id

def backup_thread_func(self):
"""
Manage backup operations from host memory to storage backend.
"""
while not self.stop_event.is_set():
try:
operation = self.backup_queue.get(block=True, timeout=1)
if operation is None:
continue

last_hash = operation.last_hash
tokens_to_backup = operation.token_ids

for i in range(0, len(tokens_to_backup), self.page_size):
last_hash = get_hash_str(
tokens_to_backup[i : i + self.page_size], last_hash
)
# todo, handle failures in storage backend
self.storage_backend.set(
Collaborator comment:
Can batch_set be used here to boost I/O? IO coalescing is required in scenarios like 3fs.
Author reply:
Yes, these will be abstracted out for different options.
last_hash,
self.mem_pool_host.get_flat_data_page(
operation.host_indices[i]
),
)
operation.completed_tokens += self.page_size
operation.hash_value.append(last_hash)

self.ack_backup_queue.put((operation.id, operation.hash_value))

except Empty:
continue
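On the batch_set discussion above: a coalesced write path could look roughly like this sketch. batch_set is not an API provided by this PR; the buffering policy (flush every N pages into a single object) and the filename scheme are illustrative assumptions only.

from typing import List, Tuple

import torch


class CoalescedWriteBuffer:
    """Illustrative sketch: accumulate per-page writes and flush them as one
    large object, the I/O-coalescing pattern the review thread mentions for
    filesystems such as 3FS that handle many small files poorly."""

    def __init__(self, flush_pages: int = 64):
        self.flush_pages = flush_pages
        self.pending: List[Tuple[str, torch.Tensor]] = []

    def batch_set(self, keys: List[str], pages: List[torch.Tensor]) -> None:
        # Buffer (hash, page) pairs instead of issuing one write per page.
        self.pending.extend(zip(keys, pages))
        if len(self.pending) >= self.flush_pages:
            self.flush()

    def flush(self) -> None:
        if not self.pending:
            return
        keys = [k for k, _ in self.pending]
        # Assumes all pages share a dtype; one sequential write covers them all.
        data = torch.cat([p.contiguous().view(-1).cpu() for _, p in self.pending])
        torch.save({"keys": keys, "data": data}, f"kvbatch_{keys[0]}.bin")
        self.pending.clear()

A real backend would also need the reverse index (key to offset) for reads; per the author's reply, such details are expected to live behind backend-specific abstractions.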
14 changes: 14 additions & 0 deletions python/sglang/srt/managers/scheduler.py
@@ -259,6 +259,7 @@ def __init__(
)
self.gpu_id = gpu_id
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
self.enable_hicache_storage = server_args.hicache_storage_backend is not None
self.page_size = server_args.page_size
self.dp_size = server_args.dp_size
self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
@@ -599,6 +600,7 @@ def init_memory_pool_and_cache(self):
== "fa3" # hot fix for incompatibility
else server_args.hicache_io_backend
),
hicache_storage_backend=server_args.hicache_storage_backend,
)
self.tp_worker.register_hicache_layer_transfer_counter(
self.tree_cache.cache_controller.layer_done_counter
@@ -1220,6 +1222,15 @@ def _add_request_to_queue(self, req: Req):
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.disagg_decode_prealloc_queue.add(req)
else:
if self.enable_hicache_storage:
req.init_next_round_input(self.tree_cache)
last_hash = req.last_host_node.get_last_hash_value()
matched_len = len(req.prefix_indices) + req.host_hit_length
if (matched_len > 0 and last_hash is not None) or matched_len == 0:
new_input_tokens = req.fill_ids[matched_len:]
self.tree_cache.prefetch_from_storage(
req.rid, req.last_host_node, new_input_tokens, last_hash
)
self.waiting_queue.append(req)

def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
@@ -1600,6 +1611,9 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
self.running_batch.batch_is_full = True
break

if self.enable_hicache_storage:
self.tree_cache.check_prefetch_progress(req.rid)

req.init_next_round_input(self.tree_cache)
res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
