Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions python/sglang/srt/managers/cache_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,10 @@ def __init__(
self.pp_rank = pp_rank
self.pp_size = pp_size

self.has_draft_kv_pool = False
self.mem_pool_device_draft = None
self.mem_pool_host_draft = None

# Default storage page IO functions (may be overridden by attach).
self.page_get_func = self._generic_page_get
self.page_set_func = self._generic_page_set
Expand Down Expand Up @@ -661,6 +665,13 @@ def start_writing(self) -> None:
self.mem_pool_host.backup_from_device_all_layer(
self.mem_pool_device, host_indices, device_indices, self.io_backend
)
if self.has_draft_kv_pool:
self.mem_pool_host_draft.backup_from_device_all_layer(
self.mem_pool_device_draft,
host_indices,
device_indices,
self.io_backend,
)
finish_event.record()
# NOTE: We must save the host indices and device indices here,
# this is because we need to guarantee that these tensors are
Expand Down Expand Up @@ -729,6 +740,15 @@ def start_loading(self) -> int:
i,
self.io_backend,
)
# TODO: only loading when they are used in current drafting
if self.has_draft_kv_pool and i < self.mem_pool_host_draft.layer_num:
self.mem_pool_host_draft.load_to_device_per_layer(
self.mem_pool_device_draft,
host_indices,
device_indices,
i,
self.io_backend,
)
producer_event.complete(i)
# NOTE: We must save the host indices and device indices here,
# this is because we need to guarantee that these tensors are
Expand Down Expand Up @@ -758,6 +778,19 @@ def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> in
self.mem_pool_host.free(host_indices)
return len(host_indices)

def set_draft_kv_pool(self, draft_kv_pool, draft_host_kv_pool):
"""
Set draft model KV pool for EAGLE speculative decoding.
This should be called by the scheduler after HiCacheController initialization.

Args:
draft_kv_pool: The draft model's device KV cache pool
draft_host_kv_pool: The draft model's host KV cache pool
"""
self.has_draft_kv_pool = True
self.mem_pool_device_draft = draft_kv_pool
self.mem_pool_host_draft = draft_host_kv_pool

def prefetch(
self,
request_id: str,
Expand Down
57 changes: 57 additions & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,56 @@ def __init__(

self.is_initializing = False

def _register_draft_kv_pool_for_hicache(self, server_args):
"""Register draft model KV pool with HiCache for EAGLE speculative decoding"""
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
from sglang.srt.mem_cache.memory_pool_host import (
MHATokenToKVPoolHost,
MLATokenToKVPoolHost,
)

# Get draft model's KV cache pool
if self.enable_overlap:
if self.server_args.enable_multi_layer_eagle:
draft_runner = self.draft_worker.draft_worker.draft_runner_list[0]
else:
draft_runner = self.draft_worker.draft_worker.draft_runner
else:
draft_runner = self.draft_worker.draft_model_runner

draft_kv_pool = draft_runner.token_to_kv_pool
# Create host KV cache pool for draft model
if isinstance(draft_kv_pool, MHATokenToKVPool):
draft_host_kv_pool = MHATokenToKVPoolHost(
draft_kv_pool,
server_args.hicache_ratio,
server_args.hicache_size,
self.page_size,
server_args.hicache_mem_layout,
)
elif isinstance(draft_kv_pool, MLATokenToKVPool):
draft_host_kv_pool = MLATokenToKVPoolHost(
draft_kv_pool,
server_args.hicache_ratio,
server_args.hicache_size,
self.page_size,
server_args.hicache_mem_layout,
)
else:
logger.warning(
f"Draft KV pool type {type(draft_kv_pool).__name__} not supported for HiCache, "
"draft model KV cache will not be backed up/restored"
)
return

# Register with HiCacheController
self.tree_cache.cache_controller.set_draft_kv_pool(
draft_kv_pool, draft_host_kv_pool
)
logger.info(
f"Registered draft model KV pool ({type(draft_kv_pool).__name__}) with HiCache"
)

def init_model_config(self):
self.model_config = ModelConfig.from_server_args(self.server_args)

Expand Down Expand Up @@ -678,6 +728,13 @@ def init_cache_with_memory_pool(self):
self.tp_worker.register_hicache_layer_transfer_counter(
self.tree_cache.cache_controller.layer_done_counter
)
if self.spec_algorithm.is_eagle():
# NOTE: I believe there are bugs when enabling HiCache L3 with EAGLE.
# But for compatibility commented out the assertion for now.
# assert not self.server_args.enable_hicache_storage, (
# "L3 cache with HiCache storage backend is not supported for EAGLE speculative decoding"
# )
self._register_draft_kv_pool_for_hicache(server_args)
elif self.is_hybrid_swa:
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache

Expand Down
Loading