diff --git a/docs/advanced_features/hicache_storage_runtime_attach_detach.md b/docs/advanced_features/hicache_storage_runtime_attach_detach.md index f40e36cd083f..0449c5212028 100644 --- a/docs/advanced_features/hicache_storage_runtime_attach_detach.md +++ b/docs/advanced_features/hicache_storage_runtime_attach_detach.md @@ -2,13 +2,19 @@ This document explains how to **dynamically attach/detach the HiCache L3 storage backend at runtime** (e.g., `mooncake` / `hf3fs` / `nixl` / `file` / `aibrix` / `eic`) while **SGLang is already running and serving traffic**, without restarting the process. -For safety and consistency, the current implementation **strictly requires** these operations to happen only when the service is **idle**: +For safety and consistency, the default implementation **strictly requires** these operations to happen only when the service is **idle**: - **No running requests** - **No waiting/queued requests** If the idle condition is not met, the API will fail fast (HTTP 400) and **will not modify** the current service state. +You can optionally enable a **force mode** to switch even under load. In force mode: + +- Requests **do not use** the storage backend during the switch (read treated as miss; write skipped). +- The switch waits for **existing storage operations to drain** before actual attach/detach. +- Any failure will return an error without crashing the server, and the IO block is rolled back. + --- ## 1. 
Background and implementation overview @@ -99,7 +105,8 @@ curl -s -X PUT http://127.0.0.1:30000/hicache/storage-backend \ -d '{ "hicache_storage_backend": "mooncake", "hicache_storage_backend_extra_config_json": "{\"master_server_address\":\"127.0.0.1:50051\",\"protocol\":\"tcp\",\"global_segment_size\":\"4gb\",\"prefetch_threshold\":256}", - "hicache_storage_prefetch_policy": "timeout" + "hicache_storage_prefetch_policy": "timeout", + "force": true }' ``` @@ -115,6 +122,14 @@ Notes: curl -s -X DELETE http://127.0.0.1:30000/hicache/storage-backend ``` +```bash +curl -s -X DELETE http://127.0.0.1:30000/hicache/storage-backend \ + -H 'Content-Type: application/json' \ + -d '{ + "force": true + }' +``` + Notes: - Detach only makes SGLang **stop using** the L3 storage backend and stops prefetch/backup threads diff --git a/python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py b/python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py index 1cfe615a7fb7..2cd2d5aaa689 100644 --- a/python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py +++ b/python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py @@ -223,6 +223,8 @@ def _trigger_backup( incremental_tokens, hash_value=page_hashes, ) + if ack_id is None: + return self.ongoing_backup[ack_id] = (req.rid, host_indices, start_time) def _compute_prefix_hash(self, tokens, prior_hash=""): diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 0bcbd4a37c92..9559f7806bce 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -99,6 +99,7 @@ ConfigureLoggingReq, ContinueGenerationReqInput, DestroyWeightsUpdateGroupReqInput, + DetachHiCacheStorageReqInput, EmbeddingReqInput, GenerateReqInput, GetWeightsByNameReqInput, @@ -726,7 +727,8 @@ async def clear_hicache_storage_backend(): # "hicache_storage_backend": "file", # "hicache_storage_backend_extra_config_json": "{}", # 
"hicache_storage_prefetch_policy": "timeout", -# "hicache_write_policy": "write_through" +# "hicache_write_policy": "write_through", +# "force": "false" # }' @app.api_route("/hicache/storage-backend", methods=["PUT"]) @auth_level(AuthLevel.ADMIN_OPTIONAL) @@ -743,6 +745,7 @@ async def attach_hicache_storage_backend(obj: AttachHiCacheStorageReqInput): hicache_storage_backend_extra_config_json=obj.hicache_storage_backend_extra_config_json, hicache_storage_prefetch_policy=obj.hicache_storage_prefetch_policy, hicache_write_policy=obj.hicache_write_policy, + force=obj.force, ) msg = getattr(ret, "message", "") return Response( @@ -759,10 +762,16 @@ async def attach_hicache_storage_backend(obj: AttachHiCacheStorageReqInput): # example usage: -# curl -s -X DELETE http://127.0.0.1:30000/hicache/storage-backend +# curl -s -X DELETE http://127.0.0.1:30000/hicache/storage-backend \ +# -H 'Content-Type: application/json' \ +# -d '{ +# "force": "false" +# }' @app.api_route("/hicache/storage-backend", methods=["DELETE"]) @auth_level(AuthLevel.ADMIN_OPTIONAL) -async def detach_hicache_storage_backend(): +async def detach_hicache_storage_backend( + obj: Optional[DetachHiCacheStorageReqInput] = None, +): """Detach (disable) HiCache storage backend at runtime. Only allowed when there are NO running / queued requests. 
@@ -770,7 +779,10 @@ async def detach_hicache_storage_backend(): if not _global_state.tokenizer_manager.server_args.admin_api_key: return _admin_api_key_missing_response() - ret = await _global_state.tokenizer_manager.detach_hicache_storage() + if obj is None: + obj = DetachHiCacheStorageReqInput() + + ret = await _global_state.tokenizer_manager.detach_hicache_storage(force=obj.force) msg = getattr(ret, "message", "") return Response( content=( diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 845dffe96496..e482c797a070 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -274,6 +274,7 @@ def __init__( self.storage_backend_type = None self.pp_rank = pp_rank self.pp_size = pp_size + self.fault_reporter = None # Default storage page IO functions (may be overridden by attach). self.page_get_func = self._generic_page_get @@ -284,6 +285,7 @@ def __init__( # transfer buffers (CPU<->GPU). We want to allow runtime attach/detach of # storage without stopping the whole controller. 
self.storage_stop_event = threading.Event() + self.storage_io_blocked = threading.Event() self.device = self.mem_pool_device.device self.layer_num = self.mem_pool_device.layer_num @@ -396,6 +398,21 @@ def _stop_storage_threads(self): ) raise RuntimeError("Failed to stop HiCache storage threads cleanly.") + def set_storage_io_blocked(self, blocked: bool): + if blocked: + self.storage_io_blocked.set() + else: + self.storage_io_blocked.clear() + + def is_storage_io_blocked(self) -> bool: + return self.storage_io_blocked.is_set() + + def set_fault_reporter(self, reporter): + self.fault_reporter = reporter + if hasattr(self, "storage_backend") and self.storage_backend is not None: + if hasattr(self.storage_backend, "set_fault_reporter"): + self.storage_backend.set_fault_reporter(reporter) + def attach_storage_backend( self, storage_backend: str, @@ -444,6 +461,10 @@ def attach_storage_backend( storage_backend, self.storage_config, self.mem_pool_host ) self.storage_backend.register_mem_pool_host(self.mem_pool_host) + if self.fault_reporter is not None and hasattr( + self.storage_backend, "set_fault_reporter" + ): + self.storage_backend.set_fault_reporter(self.fault_reporter) self.enable_storage = True # todo: threshold policy for prefetching @@ -769,6 +790,8 @@ def prefetch( """ Prefetch KV caches from storage backend to host memory. """ + if self.storage_io_blocked.is_set(): + return None operation = PrefetchOperation( request_id, host_indices, new_input_tokens, last_hash, prefix_keys ) @@ -967,6 +990,8 @@ def write_storage( """ Write KV caches from host memory to storage backend. 
""" + if self.storage_io_blocked.is_set(): + return None operation = StorageOperation( host_indices, token_ids, hash_value=hash_value, prefix_keys=prefix_keys ) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index c8074fa86baf..c0ac9110b8d4 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -1179,6 +1179,7 @@ class AttachHiCacheStorageReqInput(BaseReq): hicache_storage_backend_extra_config_json: Optional[str] = None hicache_storage_prefetch_policy: Optional[str] = None hicache_write_policy: Optional[str] = None + force: bool = False def __post_init__(self): if self.hicache_storage_prefetch_policy is None: @@ -1211,7 +1212,7 @@ class AttachHiCacheStorageReqOutput(BaseReq): class DetachHiCacheStorageReqInput(BaseReq): """Dynamically detach (disable) HiCache storage backend at runtime.""" - pass + force: bool = False @dataclass diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index c7c704c1b410..adbf2078bd0e 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -18,6 +18,7 @@ import os import signal import sys +import threading import time from collections import deque from dataclasses import dataclass @@ -170,6 +171,10 @@ from sglang.srt.managers.utils import GenerationBatchResult, validate_input_length from sglang.srt.mem_cache.cache_init_params import CacheInitParams from sglang.srt.mem_cache.common import release_kv_cache +from sglang.srt.mem_cache.hicache_storage import ( + HiCacheStorageFaultConfig, + HiCacheStorageFaultManager, +) from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors from sglang.srt.multiplex.multiplexing_mixin import SchedulerMultiplexMixin @@ -708,9 +713,108 @@ def init_cache_with_memory_pool(self): else: self.decode_offload_manager = None + 
self._hicache_storage_op_lock = threading.RLock() + self._hicache_fault_detach_inflight = False + self._hicache_storage_last_config = None + self._init_hicache_fault_manager() + embedding_cache_size = envs.SGLANG_VLM_CACHE_SIZE_MB.get() init_mm_embedding_cache(embedding_cache_size * 1024 * 1024) + def _init_hicache_fault_manager(self): + if not self.enable_hierarchical_cache: + self.hicache_fault_manager = None + return + override = {} + try: + override = self.tree_cache.get_fault_config_override() + except Exception: + override = {} + cfg = self._build_hicache_fault_config(override) + labels = {} + try: + labels = { + "storage_backend": self.server_args.hicache_storage_backend or "none", + "tp_rank": str(self.tp_rank), + "dp_rank": str(self.dp_rank), + "pp_rank": str(self.pp_rank), + "pp_size": str(self.pp_size), + } + except Exception: + labels = {} + + def _block_io(blocked: bool, reason: str): + return self.tree_cache.set_storage_io_blocked(blocked, reason=reason) + + def _detach_cb(reason: str): + self._hicache_fault_detach_inflight = True + try: + out = self.detach_hicache_storage_wrapped( + DetachHiCacheStorageReqInput(force=True) + ) + finally: + self._hicache_fault_detach_inflight = False + return out.success, out.message + + def _attach_cb(): + cfg_dict = self._hicache_storage_last_config + if not cfg_dict: + return False, "No previous hicache storage config to reconnect." 
+ out = self.attach_hicache_storage_wrapped( + AttachHiCacheStorageReqInput( + hicache_storage_backend=cfg_dict["backend"], + hicache_storage_backend_extra_config_json=cfg_dict["extra_config"], + hicache_storage_prefetch_policy=cfg_dict["prefetch_policy"], + hicache_write_policy=cfg_dict["write_policy"], + force=True, + ) + ) + return out.success, out.message + + self.hicache_fault_manager = HiCacheStorageFaultManager( + cfg, + labels=labels, + block_io_cb=_block_io, + detach_cb=_detach_cb, + attach_cb=_attach_cb, + ) + self.tree_cache.set_fault_manager(self.hicache_fault_manager) + if self.server_args.hicache_storage_backend is not None: + self._hicache_storage_last_config = { + "backend": self.server_args.hicache_storage_backend, + "extra_config": self.server_args.hicache_storage_backend_extra_config, + "prefetch_policy": self.server_args.hicache_storage_prefetch_policy, + "write_policy": self.server_args.hicache_write_policy, + } + + def _build_hicache_fault_config(self, override: dict) -> HiCacheStorageFaultConfig: + level = override.get("level", "auto_detach") + if level not in ("off", "auto_detach", "auto_reconnect"): + logger.warning("Unknown hicache fault level: %s", level) + level = "auto_detach" + return HiCacheStorageFaultConfig( + enabled=level != "off", + auto_detach=level in ("auto_detach", "auto_reconnect"), + auto_reconnect=level == "auto_reconnect", + consecutive_fatal_threshold=int(override.get("consecutive", 3)), + ratio_window_s=float(override.get("window_s", 60.0)), + ratio_threshold=float(override.get("ratio", 0.5)), + ratio_min_events=int(override.get("min_events", 10)), + reconnect_backoff_initial_s=float(override.get("backoff_initial_s", 10.0)), + reconnect_backoff_max_s=float(override.get("backoff_max_s", 300.0)), + ) + + def _refresh_hicache_fault_manager(self): + if not self.enable_hierarchical_cache or not self.hicache_fault_manager: + return + override = {} + try: + override = self.tree_cache.get_fault_config_override() + except 
Exception: + override = {} + cfg = self._build_hicache_fault_config(override) + self.hicache_fault_manager.update_config(cfg) + def init_running_status(self): self.waiting_queue: List[Req] = [] # The running decoding batch for continuous batching @@ -2472,65 +2576,148 @@ def _is_idle_for_hicache_storage_op(self) -> bool: def attach_hicache_storage_wrapped( self, recv_req: AttachHiCacheStorageReqInput ) -> AttachHiCacheStorageReqOutput: - if not self.enable_hierarchical_cache: - return AttachHiCacheStorageReqOutput( - success=False, message="Hierarchical cache is not enabled." - ) + lock = getattr(self, "_hicache_storage_op_lock", None) + if lock is not None: + lock.acquire() + try: + if not self.enable_hierarchical_cache: + return AttachHiCacheStorageReqOutput( + success=False, message="Hierarchical cache is not enabled." + ) - if not self._is_idle_for_hicache_storage_op(): - return AttachHiCacheStorageReqOutput( - success=False, - message=( - "Reject attach: scheduler is not idle. " - f"#queue-req={len(self.waiting_queue)} " - f"#running-req={len(self.running_batch.reqs)}" - ), - ) + if (not recv_req.force) and (not self._is_idle_for_hicache_storage_op()): + return AttachHiCacheStorageReqOutput( + success=False, + message=( + "Reject attach: scheduler is not idle. 
" + f"#queue-req={len(self.waiting_queue)} " + f"#running-req={len(self.running_batch.reqs)}" + ), + ) - if not hasattr(self.tree_cache, "attach_storage_backend"): - return AttachHiCacheStorageReqOutput( - success=False, - message="Current tree_cache implementation does not support dynamic attach.", - ) + if not hasattr(self.tree_cache, "attach_storage_backend"): + return AttachHiCacheStorageReqOutput( + success=False, + message="Current tree_cache implementation does not support dynamic attach.", + ) - try: - ok, msg = self.tree_cache.attach_storage_backend( - storage_backend=recv_req.hicache_storage_backend, - storage_backend_extra_config_json=recv_req.hicache_storage_backend_extra_config_json, - served_model_name=self.server_args.served_model_name, - hicache_storage_prefetch_policy=recv_req.hicache_storage_prefetch_policy, - hicache_write_policy=recv_req.hicache_write_policy, - ) - except Exception as e: - logger.exception("Attach HiCache storage backend failed with exception.") - return AttachHiCacheStorageReqOutput(success=False, message=str(e)) - if ok: - self.enable_hicache_storage = True - self.server_args.hicache_storage_backend = recv_req.hicache_storage_backend - if recv_req.hicache_storage_backend_extra_config_json is not None: - self.server_args.hicache_storage_backend_extra_config = ( - recv_req.hicache_storage_backend_extra_config_json + prev_blocked = False + if recv_req.force: + try: + prev_blocked = self.tree_cache.is_storage_io_blocked() + ok, msg = self.tree_cache.set_storage_io_blocked( + True, reason="force_attach" + ) + if not ok: + return AttachHiCacheStorageReqOutput( + success=False, message=f"Failed to block storage IO: {msg}" + ) + ok, msg = self.tree_cache.wait_storage_ops_idle() + if not ok: + _ok, _msg = self.tree_cache.set_storage_io_blocked( + prev_blocked, reason="" + ) + if not _ok and _msg: + msg = f"{msg}; rollback failed: {_msg}" + return AttachHiCacheStorageReqOutput(success=False, message=msg) + except Exception as e: + 
logger.exception("Force attach setup failed with exception.") + _ok, _msg = self.tree_cache.set_storage_io_blocked( + prev_blocked, reason="" + ) + if not _ok: + logger.error("Force attach rollback failed: %s", _msg) + return AttachHiCacheStorageReqOutput(success=False, message=str(e)) + + try: + ok, msg = self.tree_cache.attach_storage_backend( + storage_backend=recv_req.hicache_storage_backend, + storage_backend_extra_config_json=recv_req.hicache_storage_backend_extra_config_json, + served_model_name=self.server_args.served_model_name, + hicache_storage_prefetch_policy=recv_req.hicache_storage_prefetch_policy, + hicache_write_policy=recv_req.hicache_write_policy, ) - if recv_req.hicache_storage_prefetch_policy is not None: - self.server_args.hicache_storage_prefetch_policy = ( - recv_req.hicache_storage_prefetch_policy + except Exception as e: + logger.exception( + "Attach HiCache storage backend failed with exception." ) - if recv_req.hicache_write_policy is not None: - self.server_args.hicache_write_policy = recv_req.hicache_write_policy - logger.info( - f"Attached HiCache storage backend: {recv_req.hicache_storage_backend}" - ) - return AttachHiCacheStorageReqOutput(success=ok, message=msg) + if recv_req.force: + _ok, _msg = self.tree_cache.set_storage_io_blocked( + prev_blocked, reason="" + ) + if not _ok: + logger.error("Force attach rollback failed: %s", _msg) + return AttachHiCacheStorageReqOutput(success=False, message=str(e)) + + if ok: + self.enable_hicache_storage = True + self.server_args.hicache_storage_backend = ( + recv_req.hicache_storage_backend + ) + if recv_req.hicache_storage_backend_extra_config_json is not None: + self.server_args.hicache_storage_backend_extra_config = ( + recv_req.hicache_storage_backend_extra_config_json + ) + if recv_req.hicache_storage_prefetch_policy is not None: + self.server_args.hicache_storage_prefetch_policy = ( + recv_req.hicache_storage_prefetch_policy + ) + if recv_req.hicache_write_policy is not None: + 
self.server_args.hicache_write_policy = ( + recv_req.hicache_write_policy + ) + logger.info( + f"Attached HiCache storage backend: {recv_req.hicache_storage_backend}" + ) + if recv_req.force: + _ok, _msg = self.tree_cache.set_storage_io_blocked( + prev_blocked, reason="" + ) + if not _ok: + logger.error("Force attach unblock failed: %s", _msg) + if hasattr(self, "_hicache_storage_last_config"): + self._hicache_storage_last_config = { + "backend": recv_req.hicache_storage_backend, + "extra_config": recv_req.hicache_storage_backend_extra_config_json, + "prefetch_policy": recv_req.hicache_storage_prefetch_policy, + "write_policy": recv_req.hicache_write_policy, + } + if ( + hasattr(self, "hicache_fault_manager") + and self.hicache_fault_manager + ): + self.hicache_fault_manager.notify_attach_success() + self._refresh_hicache_fault_manager() + else: + if recv_req.force: + _ok, _msg = self.tree_cache.set_storage_io_blocked( + prev_blocked, reason="" + ) + if not _ok: + logger.error("Force attach rollback failed: %s", _msg) + return AttachHiCacheStorageReqOutput(success=ok, message=msg) + finally: + if lock is not None: + lock.release() def detach_hicache_storage_wrapped( self, recv_req: DetachHiCacheStorageReqInput ) -> DetachHiCacheStorageReqOutput: + if hasattr(self, "_hicache_storage_op_lock"): + self._hicache_storage_op_lock.acquire() + _release_lock = True + else: + _release_lock = False if not self.enable_hierarchical_cache: + if _release_lock: + self._hicache_storage_op_lock.release() return DetachHiCacheStorageReqOutput( success=False, message="Hierarchical cache is not enabled." 
) - if not self._is_idle_for_hicache_storage_op(): + if (not recv_req.force) and (not self._is_idle_for_hicache_storage_op()): + if _release_lock: + self._hicache_storage_op_lock.release() return DetachHiCacheStorageReqOutput( success=False, message=( @@ -2541,17 +2728,61 @@ def detach_hicache_storage_wrapped( ) if not hasattr(self.tree_cache, "detach_storage_backend"): + if _release_lock: + self._hicache_storage_op_lock.release() return DetachHiCacheStorageReqOutput( success=False, message="Current tree_cache implementation does not support dynamic detach.", ) + prev_blocked = False + if recv_req.force: + try: + prev_blocked = self.tree_cache.is_storage_io_blocked() + ok, msg = self.tree_cache.set_storage_io_blocked( + True, reason="force_detach" + ) + if not ok: + if _release_lock: + self._hicache_storage_op_lock.release() + return DetachHiCacheStorageReqOutput( + success=False, message=f"Failed to block storage IO: {msg}" + ) + ok, msg = self.tree_cache.wait_storage_ops_idle() + if not ok: + _ok, _msg = self.tree_cache.set_storage_io_blocked( + prev_blocked, reason="" + ) + if not _ok and _msg: + msg = f"{msg}; rollback failed: {_msg}" + if _release_lock: + self._hicache_storage_op_lock.release() + return DetachHiCacheStorageReqOutput(success=False, message=msg) + except Exception as e: + logger.exception("Force detach setup failed with exception.") + _ok, _msg = self.tree_cache.set_storage_io_blocked( + prev_blocked, reason="" + ) + if not _ok: + logger.error("Force detach rollback failed: %s", _msg) + if _release_lock: + self._hicache_storage_op_lock.release() + return DetachHiCacheStorageReqOutput(success=False, message=str(e)) + # Idempotent detach: even if scheduler thinks storage is disabled, we still # attempt best-effort cleanup in tree_cache (it may have leftover state). 
try: ok, msg = self.tree_cache.detach_storage_backend() except Exception as e: logger.exception("Detach HiCache storage backend failed with exception.") + if recv_req.force: + _ok, _msg = self.tree_cache.set_storage_io_blocked( + prev_blocked, reason="" + ) + if not _ok: + logger.error("Force detach rollback failed: %s", _msg) + if _release_lock: + self._hicache_storage_op_lock.release() return DetachHiCacheStorageReqOutput(success=False, message=str(e)) if ok or (not self.enable_hicache_storage): @@ -2560,10 +2791,30 @@ def detach_hicache_storage_wrapped( self.server_args.hicache_storage_backend = None self.server_args.hicache_storage_backend_extra_config = None logger.info("Detached HiCache storage backend.") + if recv_req.force: + _ok, _msg = self.tree_cache.set_storage_io_blocked( + prev_blocked, reason="" + ) + if not _ok: + logger.error("Force detach unblock failed: %s", _msg) + if ( + hasattr(self, "hicache_fault_manager") + and self.hicache_fault_manager + and not self._hicache_fault_detach_inflight + ): + self.hicache_fault_manager.notify_manual_detach() + if _release_lock: + self._hicache_storage_op_lock.release() return DetachHiCacheStorageReqOutput( success=True, message=msg or "HiCache storage backend is detached." 
) + if recv_req.force: + _ok, _msg = self.tree_cache.set_storage_io_blocked(prev_blocked, reason="") + if not _ok: + logger.error("Force detach rollback failed: %s", _msg) + if _release_lock: + self._hicache_storage_op_lock.release() return DetachHiCacheStorageReqOutput(success=False, message=msg) def _is_no_request(self): diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py index f2f9791e61a0..9476306d4179 100644 --- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py +++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py @@ -350,6 +350,7 @@ async def attach_hicache_storage( hicache_storage_backend_extra_config_json: Optional[str] = None, hicache_storage_prefetch_policy: Optional[str] = None, hicache_write_policy: Optional[str] = None, + force: bool = False, ) -> AttachHiCacheStorageReqOutput: """Attach (enable) HiCache storage backend at runtime.""" results = await self.attach_hicache_storage_communicator( @@ -358,6 +359,7 @@ async def attach_hicache_storage( hicache_storage_backend_extra_config_json=hicache_storage_backend_extra_config_json, hicache_storage_prefetch_policy=hicache_storage_prefetch_policy, hicache_write_policy=hicache_write_policy, + force=force, ) ) @@ -381,10 +383,11 @@ async def attach_hicache_storage( async def detach_hicache_storage( self: TokenizerManager, + force: bool = False, ) -> DetachHiCacheStorageReqOutput: """Detach (disable) HiCache storage backend at runtime.""" results = await self.detach_hicache_storage_communicator( - DetachHiCacheStorageReqInput() + DetachHiCacheStorageReqInput(force=force) ) all_success, all_message = _Communicator.merge_results(results) diff --git a/python/sglang/srt/mem_cache/hicache_storage.py b/python/sglang/srt/mem_cache/hicache_storage.py index 38df15262cc9..550da26a67c4 100644 --- a/python/sglang/srt/mem_cache/hicache_storage.py +++ b/python/sglang/srt/mem_cache/hicache_storage.py @@ -1,9 +1,14 @@ import 
hashlib +import json import logging import os +import queue +import threading +import time from abc import ABC, abstractmethod +from collections import deque from dataclasses import dataclass -from typing import Any, List, Optional +from typing import Any, Callable, Deque, List, Optional, Tuple import torch @@ -71,6 +76,18 @@ class HiCacheStorage(ABC): def register_mem_pool_host(self, mem_pool_host: HostKVCache): self.mem_pool_host = mem_pool_host + self._fault_reporter: Optional["HiCacheStorageFaultManager"] = None + + def set_fault_reporter(self, reporter: Optional["HiCacheStorageFaultManager"]): + self._fault_reporter = reporter + + def _report_storage_op(self, op: str, fatal: bool, detail: str = ""): + if self._fault_reporter is None: + return + try: + self._fault_reporter.report_op(op=op, fatal=fatal, detail=detail) + except Exception: + logger.exception("HiCacheStorage fault reporter failed.") def batch_get_v1( self, @@ -181,6 +198,228 @@ def get_stats(self): return None +@dataclass +class HiCacheStorageFaultConfig: + enabled: bool = True + auto_detach: bool = True + auto_reconnect: bool = False + consecutive_fatal_threshold: int = 3 + ratio_window_s: float = 60.0 + ratio_threshold: float = 0.5 + ratio_min_events: int = 10 + reconnect_backoff_initial_s: float = 10.0 + reconnect_backoff_max_s: float = 300.0 + + +class HiCacheStorageFaultManager: + def __init__( + self, + config: HiCacheStorageFaultConfig, + labels: Optional[dict] = None, + block_io_cb: Optional[Callable[[bool, str], Tuple[bool, str]]] = None, + detach_cb: Optional[Callable[[str], Tuple[bool, str]]] = None, + attach_cb: Optional[Callable[[], Tuple[bool, str]]] = None, + ): + self.config = config + self.labels = labels or {} + self.block_io_cb = block_io_cb + self.detach_cb = detach_cb + self.attach_cb = attach_cb + + self._lock = threading.Lock() + self._events: Deque[Tuple[float, bool]] = deque() + self._consecutive_fatal = 0 + self._state = "HEALTHY" + self._last_fault_reason = "" + 
self._detach_pending = False + self._backoff_s = config.reconnect_backoff_initial_s + self._next_reconnect_time = 0.0 + + self._queue: "queue.Queue[str]" = queue.Queue() + self._worker = threading.Thread(target=self._worker_loop, daemon=True) + self._worker.start() + + self._init_metrics() + + def _init_metrics(self): + try: + from prometheus_client import Counter, Gauge + + self._fault_events_total = Counter( + name="sglang:hicache_storage_fault_events_total", + documentation="Total HiCache storage fault events.", + labelnames=self.labels.keys(), + ) + self._fault_state = Gauge( + name="sglang:hicache_storage_fault_state", + documentation="HiCache storage fault state (0 healthy, 1 blocked, 2 detached, 3 reconnecting).", + labelnames=self.labels.keys(), + ) + self._detach_total = Counter( + name="sglang:hicache_storage_auto_detach_total", + documentation="Total auto detaches triggered by fault manager.", + labelnames=self.labels.keys(), + ) + self._reconnect_total = Counter( + name="sglang:hicache_storage_auto_reconnect_total", + documentation="Total auto reconnect attempts triggered by fault manager.", + labelnames=self.labels.keys(), + ) + except Exception: + self._fault_events_total = None + self._fault_state = None + self._detach_total = None + self._reconnect_total = None + + def report_op(self, op: str, fatal: bool, detail: str = ""): + if not self.config.enabled: + return + now = time.monotonic() + with self._lock: + self._events.append((now, fatal)) + if fatal: + self._consecutive_fatal += 1 + self._last_fault_reason = detail or op + else: + self._consecutive_fatal = 0 + self._prune_events_locked(now) + + if fatal and self._fault_events_total is not None: + self._fault_events_total.labels(**self.labels).inc(1) + + if self._should_trigger_fault_locked(): + self._trigger_fault_locked(reason=detail or op) + + def _prune_events_locked(self, now: float): + window_s = self.config.ratio_window_s + while self._events and now - self._events[0][0] > window_s: + 
self._events.popleft() + + def _should_trigger_fault_locked(self) -> bool: + if self._state in ("BLOCKED", "DETACHED_BY_FAULT", "RECONNECTING"): + return False + if self._consecutive_fatal >= self.config.consecutive_fatal_threshold: + return True + total = len(self._events) + if total < self.config.ratio_min_events: + return False + fatal_count = sum(1 for _, is_fatal in self._events if is_fatal) + ratio = fatal_count / max(total, 1) + return ratio >= self.config.ratio_threshold + + def _trigger_fault_locked(self, reason: str): + self._state = "BLOCKED" + self._last_fault_reason = reason + if self._fault_state is not None: + self._fault_state.labels(**self.labels).set(1) + if self.block_io_cb is not None: + ok, msg = self.block_io_cb(True, reason) + if not ok: + logger.error("Failed to block HiCache storage IO: %s", msg) + if self.config.auto_detach and not self._detach_pending: + self._detach_pending = True + self._queue.put("detach") + + def notify_manual_detach(self): + with self._lock: + self._state = "DISABLED" + self._detach_pending = False + self._consecutive_fatal = 0 + self._events.clear() + if self._fault_state is not None: + self._fault_state.labels(**self.labels).set(0) + + def notify_attach_success(self): + with self._lock: + self._state = "HEALTHY" + self._detach_pending = False + self._consecutive_fatal = 0 + self._events.clear() + self._backoff_s = self.config.reconnect_backoff_initial_s + if self._fault_state is not None: + self._fault_state.labels(**self.labels).set(0) + + def update_config(self, config: HiCacheStorageFaultConfig): + with self._lock: + self.config = config + self._backoff_s = config.reconnect_backoff_initial_s + + def _worker_loop(self): + while True: + try: + task = self._queue.get(timeout=1) + except Exception: + task = None + + if task == "detach": + self._handle_detach() + elif task == "attach": + self._handle_attach() + + self._maybe_schedule_reconnect() + + def _handle_detach(self): + if self.detach_cb is None: + return + 
ok, msg = self.detach_cb(self._last_fault_reason) + if ok: + if self._detach_total is not None: + self._detach_total.labels(**self.labels).inc(1) + with self._lock: + self._state = "DETACHED_BY_FAULT" + self._detach_pending = False + self._next_reconnect_time = time.monotonic() + self._backoff_s + if self._fault_state is not None: + self._fault_state.labels(**self.labels).set(2) + else: + logger.error("Auto detach failed: %s", msg) + with self._lock: + self._detach_pending = False + + def _handle_attach(self): + if self.attach_cb is None: + return + ok, msg = self.attach_cb() + if ok: + if self._reconnect_total is not None: + self._reconnect_total.labels(**self.labels).inc(1) + if self.block_io_cb is not None: + _ok, _msg = self.block_io_cb(False, "") + if not _ok: + logger.error("Failed to unblock HiCache storage IO: %s", _msg) + with self._lock: + self._state = "HEALTHY" + self._consecutive_fatal = 0 + self._events.clear() + self._backoff_s = self.config.reconnect_backoff_initial_s + if self._fault_state is not None: + self._fault_state.labels(**self.labels).set(0) + else: + logger.error("Auto reconnect failed: %s", msg) + with self._lock: + self._state = "DETACHED_BY_FAULT" + self._backoff_s = min( + self._backoff_s * 2, self.config.reconnect_backoff_max_s + ) + self._next_reconnect_time = time.monotonic() + self._backoff_s + if self._fault_state is not None: + self._fault_state.labels(**self.labels).set(2) + + def _maybe_schedule_reconnect(self): + if not self.config.auto_reconnect: + return + with self._lock: + if self._state != "DETACHED_BY_FAULT": + return + now = time.monotonic() + if now < self._next_reconnect_time: + return + self._state = "RECONNECTING" + if self._fault_state is not None: + self._fault_state.labels(**self.labels).set(3) + self._queue.put("attach") + + class HiCacheFile(HiCacheStorage): def __init__( @@ -204,6 +443,35 @@ def __init__( os.makedirs(self.file_path) logger.info(f"Created HiCacheFile storage directory at {self.file_path}") + 
self._fault_inject_path = os.getenv("SGLANG_HICACHE_FAULT_INJECT_PATH") + self._fault_stats_path = os.getenv("SGLANG_HICACHE_FAULT_STATS_PATH") + + def _read_fault_inject_config(self) -> dict: + if not self._fault_inject_path: + return {} + try: + with open(self._fault_inject_path, "r") as fin: + return json.load(fin) + except Exception: + return {} + + def _update_fault_stats(self, op: str, result: str): + if not self._fault_stats_path: + return + try: + data = {} + if os.path.exists(self._fault_stats_path): + with open(self._fault_stats_path, "r") as fin: + data = json.load(fin) + data.setdefault(op, {}) + data[op][result] = data[op].get(result, 0) + 1 + tmp_path = f"{self._fault_stats_path}.tmp" + with open(tmp_path, "w") as fout: + json.dump(data, fout) + os.replace(tmp_path, self._fault_stats_path) + except Exception: + logger.exception("Failed to update fault stats.") + def _get_suffixed_key(self, key: str) -> str: return key + self.config_suffix @@ -216,15 +484,26 @@ def get( key = self._get_suffixed_key(key) tensor_path = os.path.join(self.file_path, f"{key}.bin") try: + cfg = self._read_fault_inject_config() + if cfg.get("fail_get"): + raise RuntimeError("HiCacheFile injected get failure") expected = target_location.numel() * target_location.element_size() with open(tensor_path, "rb", buffering=0) as f: buf = memoryview(target_location.view(torch.uint8).contiguous().numpy()) if f.readinto(buf) != expected: raise IOError(f"Short read for {key}") + self._report_storage_op("get", fatal=False) + self._update_fault_stats("get", "success") return target_location except FileNotFoundError: logger.warning(f"Failed to fetch {key} from HiCacheFile storage.") + self._report_storage_op("get", fatal=False) + self._update_fault_stats("get", "miss") return None + except Exception as e: + self._report_storage_op("get", fatal=True, detail=str(e)) + self._update_fault_stats("get", "fatal") + raise def batch_get( self, @@ -253,10 +532,17 @@ def set( key = 
self._get_suffixed_key(key) tensor_path = os.path.join(self.file_path, f"{key}.bin") try: + cfg = self._read_fault_inject_config() + if cfg.get("fail_set"): + raise RuntimeError("HiCacheFile injected set failure") value.contiguous().view(dtype=torch.uint8).numpy().tofile(tensor_path) + self._report_storage_op("set", fatal=False) + self._update_fault_stats("set", "success") return True except Exception as e: logger.error(f"Failed to save tensor {key}: {e}") + self._report_storage_op("set", fatal=True, detail=str(e)) + self._update_fault_stats("set", "fatal") return False def batch_set( diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 6e0fac3b57f3..669951f6f651 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -101,6 +101,8 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs): self.pp_rank = params.pp_rank self.pp_size = params.pp_size self.enable_storage = server_args.hicache_storage_backend is not None + self._storage_io_blocked = threading.Event() + self._storage_io_blocked_reason = "" self.enable_storage_metrics = self.enable_storage and params.enable_metrics self.extra_metric_labels = server_args.extra_metric_labels @@ -110,6 +112,7 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs): prefetch_timeout_base, prefetch_timeout_per_ki_token, hicache_storage_pass_prefix_keys, + fault_config, ) = self._parse_storage_backend_extra_config( server_args.hicache_storage_backend_extra_config ) @@ -163,10 +166,16 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs): # Detach storage backend automatically on process shutdown atexit.register(self.shutdown) - self.evictable_host_leaves = set() - super().__init__(params=params) + def set_fault_manager(self, manager): + self.fault_manager = manager + if hasattr(self, "cache_controller") and self.cache_controller is not None: + 
self.cache_controller.set_fault_reporter(manager) + + def get_fault_config_override(self) -> dict: + return getattr(self, "hicache_fault_config_override", {}) + def shutdown(self): """Best-effort auto-detach of storage backend on process shutdown. @@ -179,46 +188,79 @@ def shutdown(self): except Exception: logger.exception("Failed to detach storage backend on process shutdown.") - def _apply_storage_runtime_config( - self, - *, - storage_backend: Optional[str], - prefetch_threshold: int, - prefetch_timeout_base: float, - prefetch_timeout_per_ki_token: float, - hicache_storage_pass_prefix_keys: bool, - enable_storage: bool, - enable_storage_metrics: bool, - extra_metric_labels: Optional[Dict[str, str]], - ) -> None: - prefetch_timeout_per_page = ( - self.page_size / 1024 * prefetch_timeout_per_ki_token - ) + def is_storage_io_blocked(self) -> bool: + return self._storage_io_blocked.is_set() - storage_metrics_collector = None - if enable_storage_metrics: - labels = { - "storage_backend": storage_backend, - "tp_rank": self.cache_controller.tp_rank, - "dp_rank": self.cache_controller.dp_rank, - "pp_rank": self.cache_controller.pp_rank, - "pp_size": self.cache_controller.pp_size, - } - if extra_metric_labels: - labels.update(extra_metric_labels) - self.storage_metrics_collector = StorageMetricsCollector(labels=labels) - storage_metrics_collector = StorageMetricsCollector(labels=labels) - - self.enable_storage = enable_storage - self.prefetch_threshold = prefetch_threshold - self.prefetch_timeout_base = prefetch_timeout_base - self.prefetch_timeout_per_page = prefetch_timeout_per_page - self.hicache_storage_pass_prefix_keys = hicache_storage_pass_prefix_keys - self.enable_storage_metrics = enable_storage_metrics - if self.enable_storage_metrics: - self.storage_metrics_collector = storage_metrics_collector - else: - self.storage_metrics_collector = None + def set_storage_io_blocked( + self, blocked: bool, reason: str = "" + ) -> tuple[bool, str]: + prev_blocked = 
self._storage_io_blocked.is_set() + prev_reason = self._storage_io_blocked_reason + try: + if blocked: + self._storage_io_blocked.set() + self._storage_io_blocked_reason = reason + else: + self._storage_io_blocked.clear() + self._storage_io_blocked_reason = "" + if hasattr(self, "cache_controller"): + self.cache_controller.set_storage_io_blocked(blocked) + return True, "" + except Exception as e: + if prev_blocked: + self._storage_io_blocked.set() + self._storage_io_blocked_reason = prev_reason + else: + self._storage_io_blocked.clear() + self._storage_io_blocked_reason = "" + logger.exception("Failed to set storage IO blocked state.") + return False, str(e) + + def _storage_io_allowed(self) -> bool: + return self.enable_storage and (not self._storage_io_blocked.is_set()) + + def get_storage_pending_counts(self) -> dict: + cc = self.cache_controller + counts = { + "ongoing_backup": len(self.ongoing_backup), + "ongoing_prefetch": len(self.ongoing_prefetch), + } + try: + if hasattr(cc, "backup_queue"): + counts["backup_queue"] = cc.backup_queue.qsize() + if hasattr(cc, "prefetch_queue"): + counts["prefetch_queue"] = cc.prefetch_queue.qsize() + if hasattr(cc, "prefetch_buffer"): + counts["prefetch_buffer"] = cc.prefetch_buffer.qsize() + if hasattr(cc, "ack_backup_queue"): + counts["ack_backup_queue"] = cc.ack_backup_queue.qsize() + if hasattr(cc, "prefetch_revoke_queue"): + counts["prefetch_revoke_queue"] = cc.prefetch_revoke_queue.qsize() + if hasattr(cc, "host_mem_release_queue"): + counts["host_mem_release_queue"] = cc.host_mem_release_queue.qsize() + except Exception: + logger.exception("Failed to collect storage pending counts.") + counts["total"] = sum(counts.values()) + return counts + + def wait_storage_ops_idle( + self, timeout_s: float = 30.0, poll_interval_s: float = 0.05 + ) -> tuple[bool, str]: + start = time.monotonic() + while True: + try: + self._drain_storage_control_queues_local() + except Exception: + logger.exception("Drain storage control queues 
failed during wait.") + counts = self.get_storage_pending_counts() + if counts.get("total", 0) == 0: + return True, "" + if time.monotonic() - start >= timeout_s: + return ( + False, + f"Timeout waiting for storage ops to drain: {counts}", + ) + time.sleep(poll_interval_s) def attach_storage_backend( self, @@ -304,6 +346,7 @@ def attach_storage_backend( prefetch_timeout_base, prefetch_timeout_per_ki_token, hicache_storage_pass_prefix_keys, + fault_config, ) = self._parse_storage_backend_extra_config( storage_backend_extra_config_json ) @@ -327,17 +370,59 @@ def attach_storage_backend( ) return False, f"Failed to attach storage backend '{storage_backend}': {e}" - self._apply_storage_runtime_config( - storage_backend=storage_backend, - prefetch_threshold=prefetch_threshold, - prefetch_timeout_base=prefetch_timeout_base, - prefetch_timeout_per_ki_token=prefetch_timeout_per_ki_token, - hicache_storage_pass_prefix_keys=hicache_storage_pass_prefix_keys, - enable_storage=True, - enable_storage_metrics=self._enable_metrics_flag, - extra_metric_labels=self.extra_metric_labels, - ) - return True, "Attached HiCache storage backend successfully." + self.hicache_fault_config_override = fault_config or {} + + # Commit/rollback boundary: + # - After controller attach succeeds, any exception below MUST rollback by + # detaching controller, otherwise threads may keep running while scheduler + # believes storage is disabled. + try: + # Compute runtime knobs first. + prefetch_timeout_per_page = ( + self.page_size / 1024 * prefetch_timeout_per_ki_token + ) + + # Metrics is optional, but if enabled and creation fails, treat attach as failed + # to keep the system state consistent. 
+ storage_metrics_collector = None + enable_storage_metrics = self._enable_metrics_flag + if enable_storage_metrics: + labels = { + "storage_backend": storage_backend, + "tp_rank": self.cache_controller.tp_rank, + "dp_rank": self.cache_controller.dp_rank, + } + storage_metrics_collector = StorageMetricsCollector(labels=labels) + + # All steps succeeded: now atomically flip flags/state. + self.enable_storage = True + self.prefetch_threshold = prefetch_threshold + self.prefetch_timeout_base = prefetch_timeout_base + self.prefetch_timeout_per_page = prefetch_timeout_per_page + self.hicache_storage_pass_prefix_keys = hicache_storage_pass_prefix_keys + + self.enable_storage_metrics = enable_storage_metrics + if self.enable_storage_metrics: + self.storage_metrics_collector = storage_metrics_collector + return True, "Attached HiCache storage backend successfully." + except Exception as e: + logger.exception( + "Attach storage backend post-init failed; rolling back detach." + ) + # Best-effort rollback to avoid state inconsistency. + try: + self.cache_controller.detach_storage_backend() + except Exception: + logger.exception("Rollback detach_storage_backend failed.") + + self.enable_storage = False + self.enable_storage_metrics = False + if hasattr(self, "storage_metrics_collector"): + self.storage_metrics_collector = None + return ( + False, + f"Failed to finalize attach storage backend '{storage_backend}': {e}", + ) def detach_storage_backend(self) -> tuple[bool, str]: """Detach (disable) storage backend at runtime. 
@@ -504,7 +589,7 @@ def _parse_storage_backend_extra_config( storage_backend_extra_config: JSON string containing extra configuration Returns: - tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token, hicache_storage_pass_prefix_keys) + tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token, hicache_storage_pass_prefix_keys, fault_config) """ # Parse extra config if provided. Extra config can be a JSON string or a json/toml/yaml file path prefixed with "@". extra_config = {} @@ -536,6 +621,10 @@ def _parse_storage_backend_extra_config( logger.error(f"Invalid backend extra config JSON: {e}") raise e + fault_config = extra_config.pop("fault_tolerance", None) + if fault_config is not None and not isinstance(fault_config, dict): + raise ValueError("fault_tolerance must be a dict if provided.") + prefetch_threshold = extra_config.pop("prefetch_threshold", 256) # tokens prefetch_timeout_base = extra_config.pop("prefetch_timeout_base", 1) # seconds prefetch_timeout_per_ki_token = extra_config.pop( @@ -569,6 +658,7 @@ def _parse_storage_backend_extra_config( float(prefetch_timeout_base), float(prefetch_timeout_per_ki_token), hicache_storage_pass_prefix_keys, + fault_config, ) def reset(self): @@ -633,6 +723,8 @@ def write_backup(self, node: TreeNode, write_back=False): return len(host_indices) def write_backup_storage(self, node: TreeNode): + if not self._storage_io_allowed(): + return prefix_keys = ( node.get_prefix_hash_values(node.parent) if self.hicache_storage_pass_prefix_keys @@ -642,6 +734,8 @@ def write_backup_storage(self, node: TreeNode): operation_id = self.cache_controller.write_storage( node.host_value, node.key, node.hash_value, prefix_keys ) + if operation_id is None: + return self.ongoing_backup[operation_id] = node node.protect_host() @@ -695,7 +789,7 @@ def writing_check(self, write_back=False): for ack_id in ack_list: backuped_node = 
self.ongoing_write_through.pop(ack_id) self.dec_lock_ref(backuped_node) - if self.enable_storage: + if self._storage_io_allowed(): self.write_backup_storage(backuped_node) finish_count -= 1 @@ -1171,7 +1265,7 @@ def prefetch_from_storage( ) new_input_tokens = new_input_tokens[:prefetch_length] if ( - not self.enable_storage + not self._storage_io_allowed() or prefetch_length < self.prefetch_threshold or self.cache_controller.prefetch_rate_limited() ): @@ -1189,6 +1283,13 @@ def prefetch_from_storage( operation = self.cache_controller.prefetch( req_id, host_indices, new_input_tokens, last_hash, prefix_keys ) + if operation is None: + last_host_node.release_host() + try: + self.cache_controller.mem_pool_host.free(host_indices) + except Exception: + pass + return self.ongoing_prefetch[req_id] = ( last_host_node, new_input_tokens, diff --git a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py index f661e73fafbb..b405e25ea039 100644 --- a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +++ b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py @@ -476,16 +476,28 @@ def batch_get_v1( host_indices: torch.Tensor, extra_info: Optional[HiCacheStorageExtraInfo] = None, ) -> List[bool]: - # Apply extra_backend_tag prefix if available - if self.extra_backend_tag is not None: - prefix = self.extra_backend_tag - keys = [f"{prefix}_{key}" for key in keys] + try: + # Apply extra_backend_tag prefix if available + if self.extra_backend_tag is not None: + prefix = self.extra_backend_tag + keys = [f"{prefix}_{key}" for key in keys] - key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices) - get_results = self._get_batch_zero_copy_impl( - key_strs, buffer_ptrs, buffer_sizes - ) - return self._batch_postprocess(get_results, is_set_operate=False) + key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess( + keys, host_indices + ) + 
get_results = self._get_batch_zero_copy_impl( + key_strs, buffer_ptrs, buffer_sizes + ) + if any(res < 0 for res in get_results): + self._report_storage_op( + "batch_get_v1", fatal=True, detail=str(get_results) + ) + else: + self._report_storage_op("batch_get_v1", fatal=False) + return self._batch_postprocess(get_results, is_set_operate=False) + except Exception as e: + self._report_storage_op("batch_get_v1", fatal=True, detail=str(e)) + raise def batch_set_v1( self, @@ -493,37 +505,49 @@ def batch_set_v1( host_indices: torch.Tensor, extra_info: Optional[HiCacheStorageExtraInfo] = None, ) -> List[bool]: - # Apply extra_backend_tag prefix if available - if self.extra_backend_tag is not None: - prefix = self.extra_backend_tag - keys = [f"{prefix}_{key}" for key in keys] - - key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices) - exist_result = self._batch_exist(key_strs) - - set_keys = [] - set_buffer_ptrs = [] - set_buffer_sizes = [] - set_indices = [] - set_results = [-1] * len(key_strs) - for i in range(len(key_strs)): - if exist_result[i] != 1: - set_keys.append(key_strs[i]) - set_buffer_ptrs.append(buffer_ptrs[i]) - set_buffer_sizes.append(buffer_sizes[i]) - set_indices.append(i) - else: - set_results[i] = 0 + try: + # Apply extra_backend_tag prefix if available + if self.extra_backend_tag is not None: + prefix = self.extra_backend_tag + keys = [f"{prefix}_{key}" for key in keys] - # Only set non-existing keys to storage - if len(set_keys) > 0: - put_results = self._put_batch_zero_copy_impl( - set_keys, set_buffer_ptrs, set_buffer_sizes + key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess( + keys, host_indices ) - for i in range(len(set_indices)): - set_results[set_indices[i]] = put_results[i] + exist_result = self._batch_exist(key_strs) + + set_keys = [] + set_buffer_ptrs = [] + set_buffer_sizes = [] + set_indices = [] + set_results = [-1] * len(key_strs) + for i in range(len(key_strs)): + if exist_result[i] != 1: + 
set_keys.append(key_strs[i]) + set_buffer_ptrs.append(buffer_ptrs[i]) + set_buffer_sizes.append(buffer_sizes[i]) + set_indices.append(i) + else: + set_results[i] = 0 + + # Only set non-existing keys to storage + if len(set_keys) > 0: + put_results = self._put_batch_zero_copy_impl( + set_keys, set_buffer_ptrs, set_buffer_sizes + ) + for i in range(len(set_indices)): + set_results[set_indices[i]] = put_results[i] - return self._batch_postprocess(set_results, is_set_operate=True) + if any(res < 0 for res in set_results): + self._report_storage_op( + "batch_set_v1", fatal=True, detail=str(set_results) + ) + else: + self._report_storage_op("batch_set_v1", fatal=False) + return self._batch_postprocess(set_results, is_set_operate=True) + except Exception as e: + self._report_storage_op("batch_set_v1", fatal=True, detail=str(e)) + raise def set( self, diff --git a/test/registered/hicache/test_force_setup_rollback.py b/test/registered/hicache/test_force_setup_rollback.py new file mode 100644 index 000000000000..caa27b977193 --- /dev/null +++ b/test/registered/hicache/test_force_setup_rollback.py @@ -0,0 +1,84 @@ +""" +"Test force attach/detach setup rollback. 
+
+Usage:
+    python3 -m pytest test/registered/hicache/test_force_setup_rollback.py -v
+"""
+
+import unittest
+
+from sglang.srt.managers.io_struct import (
+    AttachHiCacheStorageReqInput,
+    DetachHiCacheStorageReqInput,
+)
+from sglang.srt.managers.scheduler import Scheduler
+from sglang.test.ci.ci_register import register_cuda_ci
+
+register_cuda_ci(est_time=200, suite="stage-b-test-small-1-gpu")
+
+
+class _FakeTreeCache:
+    def __init__(self, blocked: bool):
+        self._blocked = blocked
+        self.block_calls = []
+        self.attach_called = False
+        self.detach_called = False
+
+    def is_storage_io_blocked(self) -> bool:
+        return self._blocked
+
+    def set_storage_io_blocked(self, blocked: bool, reason: str = ""):
+        self._blocked = blocked
+        self.block_calls.append((blocked, reason))
+        return True, ""
+
+    def wait_storage_ops_idle(self):
+        raise RuntimeError("boom in wait_storage_ops_idle")
+
+    def attach_storage_backend(self, *args, **kwargs):
+        self.attach_called = True
+        return True, ""
+
+    def detach_storage_backend(self):
+        self.detach_called = True
+        return True, ""
+
+
+class TestHiCacheForceSetupRollback(unittest.TestCase):
+    def _build_scheduler(self, tree_cache: _FakeTreeCache) -> Scheduler:
+        scheduler = Scheduler.__new__(Scheduler)
+        scheduler.enable_hierarchical_cache = True
+        scheduler.tree_cache = tree_cache
+        scheduler._is_idle_for_hicache_storage_op = lambda: True
+        return scheduler
+
+    def test_force_attach_setup_exception_rolls_back_block(self):
+        tree_cache = _FakeTreeCache(blocked=False)
+        scheduler = self._build_scheduler(tree_cache)
+
+        req = AttachHiCacheStorageReqInput(
+            hicache_storage_backend="file",
+            force=True,
+        )
+        out = scheduler.attach_hicache_storage_wrapped(req)
+
+        self.assertFalse(out.success)
+        self.assertFalse(tree_cache.is_storage_io_blocked())
+        self.assertEqual(tree_cache.block_calls, [(True, "force_attach"), (False, "")])
+        self.assertFalse(tree_cache.attach_called)
+
+    def test_force_detach_setup_exception_rolls_back_block(self):
tree_cache = _FakeTreeCache(blocked=True) + scheduler = self._build_scheduler(tree_cache) + + req = DetachHiCacheStorageReqInput(force=True) + out = scheduler.detach_hicache_storage_wrapped(req) + + self.assertFalse(out.success) + self.assertTrue(tree_cache.is_storage_io_blocked()) + self.assertEqual(tree_cache.block_calls, [(True, "force_detach"), (True, "")]) + self.assertFalse(tree_cache.detach_called) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/hicache/test_hicache_storage_fault_tolerance.py b/test/registered/hicache/test_hicache_storage_fault_tolerance.py new file mode 100644 index 000000000000..b0693877135c --- /dev/null +++ b/test/registered/hicache/test_hicache_storage_fault_tolerance.py @@ -0,0 +1,318 @@ +""" +E2E tests for HiCache storage fault tolerance with fake backend. +""" + +import json +import os +import tempfile +import time +from urllib import error, request + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + find_available_port, + popen_launch_server, +) + +register_cuda_ci(est_time=200, suite="stage-b-test-large-2-gpu") + + +class _BaseFaultToleranceTest(CustomTestCase): + @classmethod + def _wait_for_server_ready(cls, base_url: str, timeout: int = 60) -> bool: + start_time = time.time() + while time.time() - start_time < timeout: + try: + code, _body = cls._http_get(f"{base_url}/health", timeout=5) + if code == 200: + return True + except Exception: + pass + time.sleep(2) + raise TimeoutError("Server failed to start within timeout") + + @staticmethod + def _http_get(url: str, timeout: int = 10, headers: dict | None = None): + try: + req = request.Request(url, headers=headers or {}, method="GET") + with request.urlopen(req, timeout=timeout) as resp: + return resp.getcode(), 
resp.read().decode("utf-8", errors="replace") + except error.HTTPError as e: + body = e.read().decode("utf-8", errors="replace") + return e.code, body + + @staticmethod + def _http_put_json_with_headers( + url: str, + payload: dict | None = None, + timeout: int = 30, + headers: dict | None = None, + ): + data = None + all_headers = dict(headers or {}) + if payload is not None: + data = json.dumps(payload).encode("utf-8") + all_headers["Content-Type"] = "application/json" + req = request.Request(url, data=data, headers=all_headers, method="PUT") + try: + with request.urlopen(req, timeout=timeout) as resp: + return resp.getcode(), resp.read().decode("utf-8", errors="replace") + except error.HTTPError as e: + body = e.read().decode("utf-8", errors="replace") + return e.code, body + + @staticmethod + def _attach_backend( + base_url: str, + backend: str, + extra_cfg: dict, + prefetch_policy: str = "timeout", + write_policy: str = "write_through", + headers: dict | None = None, + ): + payload = { + "hicache_storage_backend": backend, + "hicache_storage_backend_extra_config_json": json.dumps(extra_cfg), + "hicache_storage_prefetch_policy": prefetch_policy, + "hicache_write_policy": write_policy, + } + return _BaseFaultToleranceTest._http_put_json_with_headers( + f"{base_url}/hicache/storage-backend", + payload, + timeout=30, + headers=headers, + ) + + @staticmethod + def _read_json(path: str) -> dict: + if not os.path.exists(path): + return {} + with open(path, "r") as fin: + return json.load(fin) + + @staticmethod + def _write_json(path: str, data: dict): + with open(path, "w") as fout: + json.dump(data, fout) + + def _call_generate(self, base_url: str): + requests.post( + base_url + "/generate", + json={ + "text": "Hello world. 
Please repeat: hello world.", + "sampling_params": {"max_new_tokens": 32, "temperature": 0}, + }, + timeout=30, + ) + + +class TestHiCacheFaultAutoDetach(_BaseFaultToleranceTest): + @classmethod + def setUpClass(cls): + cls.temp_dir = tempfile.mkdtemp() + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + default_port = int(DEFAULT_URL_FOR_TEST.rsplit(":", 1)[1]) + cls.base_url = f"http://127.0.0.1:{find_available_port(default_port)}" + cls.admin_key = "sglang-test-admin-key" + + cls.fault_inject_path = os.path.join(cls.temp_dir, "fault_inject.json") + cls.stats_path = os.path.join(cls.temp_dir, "stats.json") + cls.storage_dir = os.path.join(cls.temp_dir, "hicache_file") + cls._write_json(cls.fault_inject_path, {"fail_mode": ""}) + + cls.env = os.environ.copy() + cls.env["SGLANG_HICACHE_FAULT_INJECT_PATH"] = cls.fault_inject_path + cls.env["SGLANG_HICACHE_FAULT_STATS_PATH"] = cls.stats_path + cls.env["SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR"] = cls.storage_dir + + cls.other_args = [ + "--enable-hierarchical-cache", + "--mem-fraction-static", + "0.6", + "--hicache-ratio", + "1.2", + "--hicache-size", + "100", + "--page-size", + "64", + "--admin-api-key", + cls.admin_key, + ] + + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=cls.other_args, + env=cls.env, + ) + cls._wait_for_server_ready(cls.base_url) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + time.sleep(2) + + def test_auto_detach_on_fault(self): + headers = {"Authorization": f"Bearer {self.admin_key}"} + extra_cfg = { + "prefetch_threshold": 1, + "prefetch_timeout_base": 1, + "prefetch_timeout_per_ki_token": 0.01, + "fault_tolerance": {"level": "auto_detach"}, + } + code, body = self._attach_backend( + self.base_url, + "file", + extra_cfg, + prefetch_policy="timeout", + write_policy="write_through", + headers=headers, + ) + self.assertEqual(code, 200, body) + + self._call_generate(self.base_url) + stats_before 
= self._read_json(self.stats_path) + self.assertTrue(stats_before, f"stats missing: {stats_before}") + + self._write_json(self.fault_inject_path, {"fail_mode": "get"}) + self._call_generate(self.base_url) + + # wait for auto detach + detached = False + for _ in range(30): + code_info, body_info = self._http_get( + f"{self.base_url}/hicache/storage-backend", + timeout=10, + headers=headers, + ) + if ( + code_info == 200 + and json.loads(body_info).get("hicache_storage_backend") is None + ): + detached = True + break + time.sleep(1) + self.assertTrue(detached, "auto detach did not happen") + + stats_detached = self._read_json(self.stats_path) + self._call_generate(self.base_url) + stats_after = self._read_json(self.stats_path) + self.assertEqual(stats_detached, stats_after) + + +class TestHiCacheFaultAutoReconnect(_BaseFaultToleranceTest): + @classmethod + def setUpClass(cls): + cls.temp_dir = tempfile.mkdtemp() + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + default_port = int(DEFAULT_URL_FOR_TEST.rsplit(":", 1)[1]) + 1 + cls.base_url = f"http://127.0.0.1:{find_available_port(default_port)}" + cls.admin_key = "sglang-test-admin-key" + + cls.fault_inject_path = os.path.join(cls.temp_dir, "fault_inject.json") + cls.stats_path = os.path.join(cls.temp_dir, "stats.json") + cls.storage_dir = os.path.join(cls.temp_dir, "hicache_file") + cls._write_json(cls.fault_inject_path, {"fail_mode": ""}) + + cls.env = os.environ.copy() + cls.env["SGLANG_HICACHE_FAULT_INJECT_PATH"] = cls.fault_inject_path + cls.env["SGLANG_HICACHE_FAULT_STATS_PATH"] = cls.stats_path + cls.env["SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR"] = cls.storage_dir + + cls.other_args = [ + "--enable-hierarchical-cache", + "--mem-fraction-static", + "0.6", + "--hicache-ratio", + "1.2", + "--hicache-size", + "100", + "--page-size", + "64", + "--admin-api-key", + cls.admin_key, + ] + + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=cls.other_args, 
+ env=cls.env, + ) + cls._wait_for_server_ready(cls.base_url) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + time.sleep(2) + + def test_auto_reconnect_after_recovery(self): + headers = {"Authorization": f"Bearer {self.admin_key}"} + extra_cfg = { + "prefetch_threshold": 1, + "prefetch_timeout_base": 1, + "prefetch_timeout_per_ki_token": 0.01, + "fault_tolerance": { + "level": "auto_reconnect", + "backoff_initial_s": 1, + "backoff_max_s": 2, + }, + } + code, body = self._attach_backend( + self.base_url, + "file", + extra_cfg, + prefetch_policy="timeout", + write_policy="write_through", + headers=headers, + ) + self.assertEqual(code, 200, body) + + self._write_json(self.fault_inject_path, {"fail_mode": "get"}) + self._call_generate(self.base_url) + + # wait for auto detach + for _ in range(30): + code_info, body_info = self._http_get( + f"{self.base_url}/hicache/storage-backend", + timeout=10, + headers=headers, + ) + if ( + code_info == 200 + and json.loads(body_info).get("hicache_storage_backend") is None + ): + break + time.sleep(1) + + # recover and wait for auto reconnect + self._write_json(self.fault_inject_path, {"fail_mode": ""}) + reattached = False + for _ in range(30): + code_info, body_info = self._http_get( + f"{self.base_url}/hicache/storage-backend", + timeout=10, + headers=headers, + ) + if ( + code_info == 200 + and json.loads(body_info).get("hicache_storage_backend") == "file" + ): + reattached = True + break + time.sleep(1) + self.assertTrue(reattached, "auto reconnect did not happen") + + stats_before = self._read_json(self.stats_path) + self._call_generate(self.base_url) + stats_after = self._read_json(self.stats_path) + self.assertNotEqual(stats_before, stats_after)