81 changes: 81 additions & 0 deletions python/sglang/srt/managers/cache_controller.py
@@ -287,6 +287,11 @@ def __init__(
self.pp_size = pp_size
self.enable_storage_metrics = enable_storage_metrics

# Draft KV pool support (best-effort piggyback on target L2/L3 ops).
self.has_draft = False
self.mem_pool_device_draft = None
self.mem_pool_host_draft = None

# Default storage page IO functions (may be overridden by attach).
self.page_get_func = self._generic_page_get
self.page_set_func = self._generic_page_set
@@ -718,6 +723,13 @@ def start_writing(self) -> None:
self.mem_pool_host.backup_from_device_all_layer(
self.mem_pool_device, host_indices, device_indices, self.io_backend
)
if self.has_draft:
self.mem_pool_host_draft.backup_from_device_all_layer(
self.mem_pool_device_draft,
host_indices,
device_indices,
self.io_backend,
)
finish_event.record()
# NOTE: We must save the host indices and device indices here,
# this is because we need to guarantee that these tensors are
@@ -791,6 +803,14 @@ def start_loading(self) -> int:
i,
self.io_backend,
)
if self.has_draft and i < self.mem_pool_host_draft.layer_num:
self.mem_pool_host_draft.load_to_device_per_layer(
self.mem_pool_device_draft,
host_indices,
device_indices,
i,
self.io_backend,
)
producer_event.complete(i)
# NOTE: We must save the host indices and device indices here,
# this is because we need to guarantee that these tensors are
@@ -820,6 +840,17 @@ def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> in
self.mem_pool_host.free(host_indices)
return len(host_indices)

def set_draft_kv_pool(self, draft_device_pool, draft_host_pool) -> None:
"""Register draft KV pools so L2/L3 ops piggyback draft transfers."""
self.has_draft = True
self.mem_pool_device_draft = draft_device_pool
self.mem_pool_host_draft = draft_host_pool
logger.info(
"HiCache draft KV registered: %s (host %d slots)",
type(draft_device_pool).__name__,
draft_host_pool.size,
)

def prefetch(
self,
request_id: str,
@@ -895,6 +926,13 @@ def _page_transfer(self, operation):
batch_host_indices = operation.host_indices[
i * self.page_size : (i + len(batch_hashes)) * self.page_size
]

# Best-effort draft L3 read before publishing target completion.
# Otherwise wait_complete can race ahead and copy target KV back to
# device before the draft KV has reached host memory.
if self.has_draft:
self._draft_page_get(batch_hashes, batch_host_indices)

prev_completed_tokens = operation.completed_tokens
# Get one batch token, and update the completed_tokens if succeed
extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys)
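
To spell out the ordering this hunk enforces, here is the per-batch sequence as a rough sketch (an assumption on my part: wait_complete watches operation.completed_tokens to decide when host KV may move back to device):

# Per-batch order inside _page_transfer after this change (sketch, simplified):
#   1. self._draft_page_get(batch_hashes, batch_host_indices)  # draft KV -> draft host pool, best effort
#   2. self.page_get_func(...)                                  # target KV pages -> target host pool
#   3. operation.completed_tokens advances for the pages that succeeded
# Running step 1 after step 3 would let a waiter that observes completed_tokens
# begin copying target KV back to device while draft pages are still in flight.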
@@ -1045,6 +1083,45 @@ def _page_set_zero_copy(self, hash_values, host_indices, extra_info=None) -> boo
self.storage_backend.batch_set_v1(hash_values, host_indices, extra_info)
)

def _draft_page_set(self, hash_values, host_indices) -> None:
"""Best-effort write draft KV pages to L3 with 'd:' prefixed keys.

TODO: support batch_set_v1 (zero-copy) for high-performance backends.
"""
try:
draft_keys = [f"d:{h}" for h in hash_values]
draft_data = [
self.mem_pool_host_draft.get_data_page(host_indices[i * self.page_size])
for i in range(len(draft_keys))
]
self.storage_backend.batch_set(draft_keys, draft_data)
except Exception:
logger.debug(
"Draft L3 write failed (best-effort), skipping.", exc_info=True
)

def _draft_page_get(self, hash_values, host_indices) -> None:
"""Best-effort read draft KV pages from L3 with 'd:' prefixed keys.

TODO: support batch_get_v1 (zero-copy) for high-performance backends.
"""
try:
draft_keys = [f"d:{h}" for h in hash_values]
draft_dummy = [
self.mem_pool_host_draft.get_dummy_flat_data_page() for _ in draft_keys
]
draft_pages = self.storage_backend.batch_get(draft_keys, draft_dummy)
if draft_pages is None:
return

for i, p in enumerate(draft_pages):
if p is not None:
self.mem_pool_host_draft.set_from_flat_data_page(
host_indices[i * self.page_size], p
)
except Exception:
logger.debug("Draft L3 read failed (best-effort), skipping.", exc_info=True)

# Backup batch by batch
def _page_backup(self, operation):
# Backup batch by batch
@@ -1064,6 +1141,10 @@ def _page_backup(self, operation):
)
break

# Best-effort draft L3 write alongside target.
if self.has_draft:
self._draft_page_set(batch_hashes, batch_host_indices)

if prefix_keys and len(prefix_keys) > 0:
prefix_keys += batch_hashes
operation.completed_tokens += self.page_size * len(batch_hashes)
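
Before moving on to the scheduler side, the key layout that _draft_page_set and _draft_page_get assume is worth spelling out: a draft page reuses the hash of its target page but lives under a "d:"-prefixed key, so target and draft KV for the same tokens can coexist in one storage backend. A minimal sketch with a hypothetical dict-backed backend (not part of SGLang; real backends implement batch_get/batch_set over disk or the network):

class DictStorageBackend:
    """Hypothetical in-memory backend, for illustration only."""

    def __init__(self):
        self._store = {}

    def batch_set(self, keys, values) -> bool:
        for key, value in zip(keys, values):
            self._store[key] = value
        return True

    def batch_get(self, keys, dummy_pages):
        # One result per key, None for a miss, mirroring what _draft_page_get expects.
        return [self._store.get(key) for key in keys]


backend = DictStorageBackend()
page_hash = "3f9a..."  # hash of one target KV page

backend.batch_set([page_hash], [b"target-kv-page-bytes"])        # target page, plain key
backend.batch_set([f"d:{page_hash}"], [b"draft-kv-page-bytes"])  # draft page, prefixed key

target_page = backend.batch_get([page_hash], [None])[0]
draft_page = backend.batch_get([f"d:{page_hash}"], [None])[0]
assert target_page is not None and draft_page is not None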
81 changes: 68 additions & 13 deletions python/sglang/srt/managers/scheduler.py
@@ -430,6 +430,9 @@ def __init__(
# Init cache and memory pool
self.init_cache_with_memory_pool()

# Register draft KV pool (when spec + HiCache co-enabled).
self._maybe_register_hicache_draft()

# Init running status
self.init_running_status()

@@ -917,6 +920,69 @@ def init_cache_with_memory_pool(self):
embedding_cache_size = envs.SGLANG_VLM_CACHE_SIZE_MB.get()
init_mm_embedding_cache(embedding_cache_size * 1024 * 1024)

def _get_draft_kv_pool(self):
"""Return (draft_token_to_kv_pool, draft_model_config) for the current
draft worker, or (None, None) when no draft KV pool is available."""
if self.draft_worker is None or self.spec_algorithm.is_ngram():
return None, None

if self.spec_algorithm.supports_spec_v2() and self.enable_overlap:
if self.server_args.enable_multi_layer_eagle:
draft_runner = self.draft_worker.draft_worker.draft_runner_list[0]
else:
draft_runner = self.draft_worker.draft_worker.draft_runner
return draft_runner.token_to_kv_pool, draft_runner.model_config

return (
self.draft_worker.model_runner.token_to_kv_pool,
self.draft_worker.model_config,
)

def _maybe_register_hicache_draft(self) -> None:
"""Register draft KV pool with HiCacheController for piggyback L2/L3 ops."""
if not self.enable_hierarchical_cache:
return

draft_kv_pool, _ = self._get_draft_kv_pool()
if draft_kv_pool is None:
return

from sglang.srt.mem_cache.memory_pool import (
HybridLinearKVPool,
MHATokenToKVPool,
MLATokenToKVPool,
)
from sglang.srt.mem_cache.memory_pool_host import (
MHATokenToKVPoolHost,
MLATokenToKVPoolHost,
)

pool = draft_kv_pool
if isinstance(pool, HybridLinearKVPool):
pool = pool.full_kv_pool

# Create host pool for draft with the same slot count as the target host pool,
# so that host indices stay 1-to-1 between target and draft KV caches.
primary = self.tree_cache.cache_controller.mem_pool_host
kw = dict(
host_to_device_ratio=primary.size / pool.size,
host_size=0,
page_size=self.page_size,
layout=self.server_args.hicache_mem_layout,
)
if isinstance(pool, MHATokenToKVPool):
draft_host_pool = MHATokenToKVPoolHost(pool, **kw)
elif isinstance(pool, MLATokenToKVPool):
draft_host_pool = MLATokenToKVPoolHost(pool, **kw)
else:
logger.warning(
"Draft pool type %s not supported for HiCache, skipping.",
type(pool).__name__,
)
return

self.tree_cache.cache_controller.set_draft_kv_pool(pool, draft_host_pool)

def init_running_status(self):
self.waiting_queue: List[Req] = []
# The running decoding batch for continuous batching
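
The sizing in _maybe_register_hicache_draft above is easier to see with concrete numbers. A rough sketch, assuming the host pool constructors allocate host_to_device_ratio * device_pool.size slots when host_size is 0 (the behavior the 1-to-1 comment relies on); the slot counts below are made up:

# Made-up slot counts, for illustration only.
target_host_slots = 1_048_576   # primary.size: slots in the target host pool
draft_device_slots = 262_144    # pool.size: slots in the draft device pool

host_to_device_ratio = target_host_slots / draft_device_slots  # 4.0 here

# With host_size=0, the draft host pool gets ratio * device slots,
# which lands exactly on the target host pool's slot count:
draft_host_slots = int(host_to_device_ratio * draft_device_slots)
assert draft_host_slots == target_host_slots

# The same host index is therefore valid in both host pools, which is what lets
# start_writing / start_loading reuse a single host_indices tensor for target
# and draft KV.

This indexing guarantee is also why the i < self.mem_pool_host_draft.layer_num guard in start_loading suffices: host indices line up one-to-one, but the draft model typically has far fewer layers than the target, so per-layer loads simply stop early for the draft pool.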
@@ -1065,19 +1131,8 @@ def init_disaggregation(self):
self.server_args.disaggregation_transfer_backend
)

if self.draft_worker is None or self.spec_algorithm.is_ngram():
draft_token_to_kv_pool = None
elif self.spec_algorithm.supports_spec_v2() and self.enable_overlap:
if self.server_args.enable_multi_layer_eagle:
draft_runner = self.draft_worker.draft_worker.draft_runner_list[0]
else:
draft_runner = self.draft_worker.draft_worker.draft_runner
draft_token_to_kv_pool = draft_runner.token_to_kv_pool
model_config = draft_runner.model_config
else:
# todo: should we fix this when enabling mtp or it doesn't matter since we only enable mtp in decode node thus we don't transfer draft kvs between P and D?
draft_token_to_kv_pool = self.draft_worker.model_runner.token_to_kv_pool
model_config = self.draft_worker.model_config
# todo: should we fix this when enabling mtp or it doesn't matter since we only enable mtp in decode node thus we don't transfer draft kvs between P and D?
draft_token_to_kv_pool, model_config = self._get_draft_kv_pool()

if (
self.disaggregation_mode == DisaggregationMode.DECODE