diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 652b71b9a745..80e39feedb82 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -287,6 +287,11 @@ def __init__( self.pp_size = pp_size self.enable_storage_metrics = enable_storage_metrics + # Draft KV pool support (best-effort piggyback on target L2/L3 ops). + self.has_draft = False + self.mem_pool_device_draft = None + self.mem_pool_host_draft = None + # Default storage page IO functions (may be overridden by attach). self.page_get_func = self._generic_page_get self.page_set_func = self._generic_page_set @@ -718,6 +723,13 @@ def start_writing(self) -> None: self.mem_pool_host.backup_from_device_all_layer( self.mem_pool_device, host_indices, device_indices, self.io_backend ) + if self.has_draft: + self.mem_pool_host_draft.backup_from_device_all_layer( + self.mem_pool_device_draft, + host_indices, + device_indices, + self.io_backend, + ) finish_event.record() # NOTE: We must save the host indices and device indices here, # this is because we need to guarantee that these tensors are @@ -791,6 +803,14 @@ def start_loading(self) -> int: i, self.io_backend, ) + if self.has_draft and i < self.mem_pool_host_draft.layer_num: + self.mem_pool_host_draft.load_to_device_per_layer( + self.mem_pool_device_draft, + host_indices, + device_indices, + i, + self.io_backend, + ) producer_event.complete(i) # NOTE: We must save the host indices and device indices here, # this is because we need to guarantee that these tensors are @@ -820,6 +840,17 @@ def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> in self.mem_pool_host.free(host_indices) return len(host_indices) + def set_draft_kv_pool(self, draft_device_pool, draft_host_pool) -> None: + """Register draft KV pools so L2/L3 ops piggyback draft transfers.""" + self.has_draft = True + self.mem_pool_device_draft = draft_device_pool + self.mem_pool_host_draft = draft_host_pool + logger.info( + "HiCache draft KV registered: %s (host %d slots)", + type(draft_device_pool).__name__, + draft_host_pool.size, + ) + def prefetch( self, request_id: str, @@ -895,6 +926,13 @@ def _page_transfer(self, operation): batch_host_indices = operation.host_indices[ i * self.page_size : (i + len(batch_hashes)) * self.page_size ] + + # Best-effort draft L3 read before publishing target completion. + # Otherwise wait_complete can race and load back target KV before + # draft KV reaches host memory. + if self.has_draft: + self._draft_page_get(batch_hashes, batch_host_indices) + prev_completed_tokens = operation.completed_tokens # Get one batch token, and update the completed_tokens if succeed extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys) @@ -1045,6 +1083,45 @@ def _page_set_zero_copy(self, hash_values, host_indices, extra_info=None) -> boo self.storage_backend.batch_set_v1(hash_values, host_indices, extra_info) ) + def _draft_page_set(self, hash_values, host_indices) -> None: + """Best-effort write draft KV pages to L3 with 'd:' prefixed keys. + + TODO: support batch_set_v1 (zero-copy) for high-performance backends. 
+        """
+        try:
+            draft_keys = [f"d:{h}" for h in hash_values]
+            draft_data = [
+                self.mem_pool_host_draft.get_data_page(host_indices[i * self.page_size])
+                for i in range(len(draft_keys))
+            ]
+            self.storage_backend.batch_set(draft_keys, draft_data)
+        except Exception:
+            logger.debug(
+                "Draft L3 write failed (best-effort), skipping.", exc_info=True
+            )
+
+    def _draft_page_get(self, hash_values, host_indices) -> None:
+        """Best-effort read draft KV pages from L3 with 'd:' prefixed keys.
+
+        TODO: support batch_get_v1 (zero-copy) for high-performance backends.
+        """
+        try:
+            draft_keys = [f"d:{h}" for h in hash_values]
+            draft_dummy = [
+                self.mem_pool_host_draft.get_dummy_flat_data_page() for _ in draft_keys
+            ]
+            draft_pages = self.storage_backend.batch_get(draft_keys, draft_dummy)
+            if draft_pages is None:
+                return
+
+            for i, p in enumerate(draft_pages):
+                if p is not None:
+                    self.mem_pool_host_draft.set_from_flat_data_page(
+                        host_indices[i * self.page_size], p
+                    )
+        except Exception:
+            logger.debug("Draft L3 read failed (best-effort), skipping.", exc_info=True)
+
     # Backup batch by batch
     def _page_backup(self, operation):
         # Backup batch by batch
@@ -1064,6 +1141,10 @@ def _page_backup(self, operation):
                     )
                     break
 
+            # Best-effort draft L3 write alongside target.
+            if self.has_draft:
+                self._draft_page_set(batch_hashes, batch_host_indices)
+
             if prefix_keys and len(prefix_keys) > 0:
                 prefix_keys += batch_hashes
             operation.completed_tokens += self.page_size * len(batch_hashes)
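
Note on the cache_controller changes above: draft pages ride the same storage
backend as the target pages, keyed by the same page hash plus a `d:` prefix,
and the draft host pool is sized so that the target's `host_indices` address
it directly. The draft read in `_page_transfer` runs before
`operation.completed_tokens` advances, so a `wait_complete` prefetch cannot
see target pages as ready while draft pages are still in flight. A minimal
sketch of the indexing invariant (the helper and values below are
illustrative, not part of the patch):

```python
# Sketch only: assumes target and draft host pools have equal slot counts,
# which is what the scheduler-side registration below guarantees.
import torch

page_size = 4
# Host slots for a two-page transfer batch, as _page_transfer would see them.
host_indices = torch.tensor([8, 9, 10, 11, 20, 21, 22, 23])

def draft_keys_and_offsets(hash_values, host_indices, page_size):
    # Same arithmetic as _draft_page_set/_draft_page_get: the draft copy of
    # page i is keyed "d:<hash>" and starts at host_indices[i * page_size].
    return [
        (f"d:{h}", int(host_indices[i * page_size]))
        for i, h in enumerate(hash_values)
    ]

assert draft_keys_and_offsets(["abc", "def"], host_indices, page_size) == [
    ("d:abc", 8),
    ("d:def", 20),
]
```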
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 225b0ee6671d..1090468d46eb 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -430,6 +430,9 @@ def __init__(
         # Init cache and memory pool
         self.init_cache_with_memory_pool()
 
+        # Register draft KV pool (when spec + HiCache co-enabled).
+        self._maybe_register_hicache_draft()
+
         # Init running status
         self.init_running_status()
 
@@ -917,6 +920,69 @@ def init_cache_with_memory_pool(self):
         embedding_cache_size = envs.SGLANG_VLM_CACHE_SIZE_MB.get()
         init_mm_embedding_cache(embedding_cache_size * 1024 * 1024)
 
+    def _get_draft_kv_pool(self):
+        """Return (draft_token_to_kv_pool, draft_model_config) for the current
+        draft worker, or (None, None) when no draft KV pool is available."""
+        if self.draft_worker is None or self.spec_algorithm.is_ngram():
+            return None, None
+
+        if self.spec_algorithm.supports_spec_v2() and self.enable_overlap:
+            if self.server_args.enable_multi_layer_eagle:
+                draft_runner = self.draft_worker.draft_worker.draft_runner_list[0]
+            else:
+                draft_runner = self.draft_worker.draft_worker.draft_runner
+            return draft_runner.token_to_kv_pool, draft_runner.model_config
+
+        return (
+            self.draft_worker.model_runner.token_to_kv_pool,
+            self.draft_worker.model_config,
+        )
+
+    def _maybe_register_hicache_draft(self) -> None:
+        """Register draft KV pool with HiCacheController for piggyback L2/L3 ops."""
+        if not self.enable_hierarchical_cache:
+            return
+
+        draft_kv_pool, _ = self._get_draft_kv_pool()
+        if draft_kv_pool is None:
+            return
+
+        from sglang.srt.mem_cache.memory_pool import (
+            HybridLinearKVPool,
+            MHATokenToKVPool,
+            MLATokenToKVPool,
+        )
+        from sglang.srt.mem_cache.memory_pool_host import (
+            MHATokenToKVPoolHost,
+            MLATokenToKVPoolHost,
+        )
+
+        pool = draft_kv_pool
+        if isinstance(pool, HybridLinearKVPool):
+            pool = pool.full_kv_pool
+
+        # Create a draft host pool with the same slot count as the target host pool,
+        # so that host indices stay 1-to-1 between target and draft KV caches.
+        primary = self.tree_cache.cache_controller.mem_pool_host
+        kw = dict(
+            host_to_device_ratio=primary.size / pool.size,
+            host_size=0,
+            page_size=self.page_size,
+            layout=self.server_args.hicache_mem_layout,
+        )
+        if isinstance(pool, MHATokenToKVPool):
+            draft_host_pool = MHATokenToKVPoolHost(pool, **kw)
+        elif isinstance(pool, MLATokenToKVPool):
+            draft_host_pool = MLATokenToKVPoolHost(pool, **kw)
+        else:
+            logger.warning(
+                "Draft pool type %s not supported for HiCache, skipping.",
+                type(pool).__name__,
+            )
+            return
+
+        self.tree_cache.cache_controller.set_draft_kv_pool(pool, draft_host_pool)
+
     def init_running_status(self):
         self.waiting_queue: List[Req] = []
         # The running decoding batch for continuous batching
@@ -1065,19 +1131,8 @@ def init_disaggregation(self):
             self.server_args.disaggregation_transfer_backend
         )
 
-        if self.draft_worker is None or self.spec_algorithm.is_ngram():
-            draft_token_to_kv_pool = None
-        elif self.spec_algorithm.supports_spec_v2() and self.enable_overlap:
-            if self.server_args.enable_multi_layer_eagle:
-                draft_runner = self.draft_worker.draft_worker.draft_runner_list[0]
-            else:
-                draft_runner = self.draft_worker.draft_worker.draft_runner
-            draft_token_to_kv_pool = draft_runner.token_to_kv_pool
-            model_config = draft_runner.model_config
-        else:
-            # todo: should we fix this when enabling mtp or it doesn't matter since we only enable mtp in decode node thus we don't transfer draft kvs between P and D?
-            draft_token_to_kv_pool = self.draft_worker.model_runner.token_to_kv_pool
-            model_config = self.draft_worker.model_config
+        # todo: revisit once MTP is enabled here; MTP currently runs only on decode nodes, so draft KV is never transferred between P and D and this may not matter.
+        draft_token_to_kv_pool, model_config = self._get_draft_kv_pool()
 
         if (
             self.disaggregation_mode == DisaggregationMode.DECODE
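
The sizing trick in `_maybe_register_hicache_draft` deserves spelling out.
Assuming the host pool allocates `host_to_device_ratio * device_slots` slots
when `host_size` is 0 (the same path the target host pool is sized through),
passing `ratio = primary.size / pool.size` gives the draft host pool exactly
`primary.size` slots. The slot counts below are invented for illustration:

```python
# Worked example of the draft host-pool sizing (hypothetical numbers).
target_host_slots = 1_048_576  # primary.size, the target host pool
draft_device_slots = 131_072   # pool.size, the draft device pool

host_to_device_ratio = target_host_slots / draft_device_slots  # 8.0
draft_host_slots = int(draft_device_slots * host_to_device_ratio)

# Equal slot counts let start_writing/start_loading reuse the target's
# host_indices for the draft pools unchanged.
assert draft_host_slots == target_host_slots
```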
diff --git a/test/registered/hicache/test_hicache_spec_file_storage.py b/test/registered/hicache/test_hicache_spec_file_storage.py
new file mode 100644
index 000000000000..e406583dfde0
--- /dev/null
+++ b/test/registered/hicache/test_hicache_spec_file_storage.py
@@ -0,0 +1,303 @@
+"""
+E2E test for HiCache file storage with EAGLE3 speculative decoding.
+
+Usage:
+    python3 -m pytest test/registered/hicache/test_hicache_spec_file_storage.py -v
+"""
+
+import json
+import os
+import shutil
+import tempfile
+import time
+import unittest
+from typing import Dict, List
+
+import psutil
+import requests
+
+from sglang.benchmark.utils import get_tokenizer
+from sglang.srt.utils import is_hip, kill_process_tree
+from sglang.test.ci.ci_register import register_cuda_ci
+from sglang.test.test_utils import (
+    DEFAULT_DRAFT_MODEL_EAGLE3,
+    DEFAULT_TARGET_MODEL_EAGLE3,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    find_available_port,
+    popen_launch_server,
+)
+from sglang.utils import wait_for_http_ready
+
+register_cuda_ci(est_time=600, suite="stage-b-test-1-gpu-large")
+
+
+@unittest.skipIf(is_hip(), "HiCache + EAGLE3 file-storage loadback e2e is CUDA-only.")
+class TestHiCacheSpecFileStorage(CustomTestCase):
+    model = DEFAULT_TARGET_MODEL_EAGLE3
+    draft_model = DEFAULT_DRAFT_MODEL_EAGLE3
+
+    input_token_len = 1024
+    max_new_tokens = 200
+    page_size = 64
+    min_expected_accept_length = 7.0
+    min_second_to_first_accept_ratio = 0.9
+    storage_wait_timeout = 30
+    first_measure_new_tokens = 128
+
+    @classmethod
+    def setUpClass(cls):
+        cls.temp_dir = tempfile.mkdtemp()
+        default_port = int(DEFAULT_URL_FOR_TEST.rsplit(":", 1)[1])
+        cls.base_url = f"http://127.0.0.1:{find_available_port(default_port)}"
+
+        cls.tokenizer = get_tokenizer(cls.model)
+        cls.prompt_input_ids = cls._build_long_repetitive_prompt_ids(
+            cls.tokenizer, cls.input_token_len
+        )
+
+        extra_config = {
+            "hicache_storage_pass_prefix_keys": True,
+        }
+        cls.other_args = [
+            "--enable-hierarchical-cache",
+            "--enable-cache-report",
+            "--mem-fraction-static",
+            "0.3",
+            "--hicache-ratio",
+            "1.5",
+            "--disable-cuda-graph",
+            "--page-size",
+            str(cls.page_size),
+            "--hicache-storage-backend",
+            "file",
+            "--hicache-storage-prefetch-policy",
+            "wait_complete",
+            "--hicache-storage-backend-extra-config",
+            json.dumps(extra_config),
+            "--speculative-algorithm",
+            "EAGLE3",
+            "--speculative-draft-model-path",
+            cls.draft_model,
+            "--speculative-num-steps",
+            "7",
+            "--speculative-eagle-topk",
+            "1",
+            "--speculative-num-draft-tokens",
+            "8",
+            "--dtype",
+            "float16",
+        ]
+        cls.env = {
+            **os.environ,
+            "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN": "1",
+            "SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR": cls.temp_dir,
+        }
+        cls.process = None
+        cls._launch_server()
+
+    @classmethod
+    def _launch_server(cls):
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=cls.other_args,
+            env=cls.env,
+        )
+        wait_for_http_ready(
+            url=f"{cls.base_url}/health",
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            process=cls.process,
+        )
+
+    @classmethod
+    def _stop_server(cls):
+        if getattr(cls, "process", None) is None:
+            return
+
+        process = cls.process
+        try:
+            root = psutil.Process(process.pid)
+            watched_procs = [root] + root.children(recursive=True)
+        except psutil.NoSuchProcess:
+            watched_procs = []
+
+        try:
+            kill_process_tree(process.pid, wait_timeout=60)
+        except RuntimeError:
+            non_zombie_procs = []
+            for proc in watched_procs:
+                try:
+                    if proc.is_running() and proc.status() != psutil.STATUS_ZOMBIE:
+                        non_zombie_procs.append(proc)
+                except psutil.NoSuchProcess:
+                    pass
+            if non_zombie_procs:
+                raise
+        finally:
+            cls.process = None
+
+    @classmethod
+    def _restart_server(cls):
+        cls._stop_server()
+        cls._launch_server()
+
+    @classmethod
+    def _count_file_storage_pages(cls):
+        try:
+            filenames = os.listdir(cls.temp_dir)
+        except FileNotFoundError:
+            return 0, 0
+
+        target_pages = 0
+        draft_pages = 0
+        for filename in filenames:
+            if not filename.endswith(".bin"):
+                continue
+            if filename.startswith("d:"):
+                draft_pages += 1
+            else:
+                target_pages += 1
+        return target_pages, draft_pages
+
+    @classmethod
+    def _wait_for_file_storage_pages(cls):
+        min_pages = (cls.input_token_len - 2 * cls.page_size) // cls.page_size
+        deadline = time.monotonic() + cls.storage_wait_timeout
+        target_pages = draft_pages = 0
+
+        while time.monotonic() < deadline:
+            target_pages, draft_pages = cls._count_file_storage_pages()
+            if target_pages >= min_pages and draft_pages >= min_pages:
+                return target_pages, draft_pages
+            time.sleep(0.2)
+
+        raise AssertionError(
+            "Timed out waiting for HiCache file storage pages before restart: "
+            f"{target_pages=}, {draft_pages=}, {min_pages=}"
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        cls._stop_server()
+        if hasattr(cls, "temp_dir"):
+            shutil.rmtree(cls.temp_dir, ignore_errors=True)
+
+    @classmethod
+    def _encode_without_special_tokens(cls, tokenizer, text: str) -> List[int]:
+        return tokenizer.encode(text, add_special_tokens=False)
+
+    @classmethod
+    def _build_long_repetitive_prompt_ids(cls, tokenizer, target_len: int) -> List[int]:
+        bos_ids = (
+            [tokenizer.bos_token_id]
+            if getattr(tokenizer, "bos_token_id", None) is not None
+            else []
+        )
+        suffix_ids = cls._encode_without_special_tokens(
+            tokenizer,
+            "\n\nContinue the sequence with only the word apple separated by spaces.\n"
+            "Answer: apple apple apple apple",
+        )
+        repeat_ids = cls._encode_without_special_tokens(tokenizer, " apple")
+        if not repeat_ids:
+            raise ValueError(
+                "Tokenizer produced no ids for the repetitive prompt seed."
+            )
+        if len(bos_ids) + len(suffix_ids) >= target_len:
+            raise ValueError(
+                "Prompt suffix is too long: "
+                f"{len(bos_ids)=}, {len(suffix_ids)=}, {target_len=}."
+ ) + + prefix_len = target_len - len(bos_ids) - len(suffix_ids) + repeats = (prefix_len + len(repeat_ids) - 1) // len(repeat_ids) + prefix_ids = (repeat_ids * repeats)[:prefix_len] + prompt_ids = bos_ids + prefix_ids + suffix_ids + assert len(prompt_ids) == target_len + return prompt_ids + + def _send_long_prompt(self, max_new_tokens: int = None) -> Dict: + if max_new_tokens is None: + max_new_tokens = self.max_new_tokens + response = requests.post( + f"{self.base_url}/generate", + json={ + "input_ids": self.prompt_input_ids, + "sampling_params": { + "temperature": 0, + "max_new_tokens": max_new_tokens, + "ignore_eos": True, + }, + }, + timeout=900, + ) + self.assertEqual( + response.status_code, + 200, + f"Request failed: {response.status_code} - {response.text}", + ) + return response.json() + + def _get_spec_accept_length(self, response_json: Dict) -> float: + meta_info = response_json.get("meta_info", {}) + self.assertIn( + "spec_accept_length", + meta_info, + f"Missing spec_accept_length in meta_info: {meta_info}", + ) + return float(meta_info["spec_accept_length"]) + + def test_file_storage_loadback_keeps_spec_accept_length(self): + first = self._send_long_prompt(max_new_tokens=self.first_measure_new_tokens) + first_accept_length = self._get_spec_accept_length(first) + self.assertGreaterEqual( + first_accept_length, + self.min_expected_accept_length, + f"First prompt accept length is too low: {first_accept_length=}", + ) + + target_pages, draft_pages = self._wait_for_file_storage_pages() + print(f"file_storage_before_restart: {target_pages=}, {draft_pages=}") + + self._restart_server() + + second = self._send_long_prompt() + second_accept_length = self._get_spec_accept_length(second) + second_meta = second.get("meta_info", {}) + cached_details = second_meta.get("cached_tokens_details") or {} + storage_cached_tokens = int(cached_details.get("storage", 0)) + + print( + f"{first_accept_length=:.3f}, {second_accept_length=:.3f}, " + f"{storage_cached_tokens=}, {cached_details=}" + ) + + self.assertGreaterEqual( + storage_cached_tokens, + self.input_token_len - 2 * self.page_size, + "Expected the second request to load the long prompt KV cache from " + f"file storage, got {cached_details=}", + ) + self.assertEqual( + cached_details.get("storage_backend"), + "HiCacheFile", + f"Expected file storage backend in cache report, got {cached_details=}", + ) + self.assertGreaterEqual( + second_accept_length, + self.min_expected_accept_length, + f"Second prompt accept length is too low: {second_accept_length=}", + ) + self.assertGreaterEqual( + second_accept_length, + first_accept_length * self.min_second_to_first_accept_ratio, + "Spec accept length dropped after file-storage loadback: " + f"{first_accept_length=:.3f}, {second_accept_length=:.3f}", + ) + + +if __name__ == "__main__": + unittest.main(verbosity=2)
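
Two numeric choices in the test are worth unpacking. First, the storage
threshold allows two pages of slack off the 1024-token prompt, presumably to
tolerate an unaligned tail and a final page that may not be flushed before
the restart. Second, `_build_long_repetitive_prompt_ids` pads with repeated
" apple" tokens so the prompt is exactly `target_len` tokens, which is what
makes the `storage_cached_tokens` bound meaningful. The headline assertion is
then that `spec_accept_length` after the restart stays within 90% of the
pre-restart value, i.e. draft KV loaded back from file storage drafts about
as well as freshly computed draft KV. A restatement of both computations
(`repeat_ids` below is a hypothetical stand-in for the real tokenizer
output):

```python
# Storage threshold used by _wait_for_file_storage_pages and the final
# storage_cached_tokens assertion.
input_token_len = 1024
page_size = 64
min_pages = (input_token_len - 2 * page_size) // page_size   # 14
min_storage_cached_tokens = input_token_len - 2 * page_size  # 896
assert (min_pages, min_storage_cached_tokens) == (14, 896)

# Prompt-length invariant: ceil-divide the prefix budget by the repeat unit,
# then trim, so bos + prefix + suffix lands exactly on target_len.
repeat_ids = [101]  # hypothetical single-token encoding of " apple"
bos_len, suffix_len, target_len = 1, 23, 1024
prefix_len = target_len - bos_len - suffix_len
repeats = (prefix_len + len(repeat_ids) - 1) // len(repeat_ids)
prefix_ids = (repeat_ids * repeats)[:prefix_len]
assert bos_len + len(prefix_ids) + suffix_len == target_len
```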