diff --git a/docs/advanced_features/hicache.rst b/docs/advanced_features/hicache.rst index b2bd08b79e76..e7d83211dc9a 100644 --- a/docs/advanced_features/hicache.rst +++ b/docs/advanced_features/hicache.rst @@ -6,3 +6,4 @@ Hierarchical KV Caching (HiCache) hicache_best_practices.md hicache_design.md + hicache_storage_runtime_attach_detach.md diff --git a/docs/advanced_features/hicache_best_practices.md b/docs/advanced_features/hicache_best_practices.md index cb1baa01e1c8..02749530ae6e 100644 --- a/docs/advanced_features/hicache_best_practices.md +++ b/docs/advanced_features/hicache_best_practices.md @@ -19,6 +19,10 @@ SGLang HiCache extends the traditional RadixAttention with a three-tier hierarch --hicache-storage-backend # Optional storage backend (e.g., hf3fs, mooncake, etc.) ``` +Notes: + +- Besides configuring `--hicache-storage-backend` at startup, SGLang also supports **runtime attach/detach** of the HiCache storage backend (no restart required) via HTTP admin endpoints. See [Runtime Attach/Detach HiCache Storage Backend](hicache_storage_runtime_attach_detach.md). + ## Key Configurations with Storage Backends Enabled ### Memory Layout Optimization diff --git a/docs/advanced_features/hicache_storage_runtime_attach_detach.md b/docs/advanced_features/hicache_storage_runtime_attach_detach.md new file mode 100644 index 000000000000..f40e36cd083f --- /dev/null +++ b/docs/advanced_features/hicache_storage_runtime_attach_detach.md @@ -0,0 +1,132 @@ +# Runtime Attach/Detach HiCache Storage Backend (No Restart) + +This document explains how to **dynamically attach/detach the HiCache L3 storage backend at runtime** (e.g., `mooncake` / `hf3fs` / `nixl` / `file` / `aibrix` / `eic`) while **SGLang is already running and serving traffic**, without restarting the process. + +For safety and consistency, the current implementation **strictly requires** these operations to happen only when the service is **idle**: + +- **No running requests** +- **No waiting/queued requests** + +If the idle condition is not met, the API will fail fast (HTTP 400) and **will not modify** the current service state. + +--- + +## 1. Background and implementation overview + +### 1.1 Architecture / control path + +The control path is: + +1. **HTTP Server** (`python/sglang/srt/entrypoints/http_server.py`) + - Exposes `PUT /hicache/storage-backend`, `DELETE /hicache/storage-backend`, `GET /hicache/storage-backend` +2. **TokenizerManager** (`python/sglang/srt/managers/tokenizer_communicator_mixin.py`) + - Sends the request to the Scheduler via `_Communicator` +3. **Scheduler** (`python/sglang/srt/managers/scheduler.py`) + - Performs a **strict idle check** + - Calls `tree_cache.attach_storage_backend(...)` / `detach_storage_backend(...)` +4. **HiRadixCache** (`python/sglang/srt/mem_cache/hiradix_cache.py`) + - Parses `hicache_storage_backend_extra_config_json` (supports both backend config and prefetch knobs) + - Calls `cache_controller.attach_storage_backend(...)` / `detach_storage_backend(...)` +5. **HiCacheController** (`python/sglang/srt/managers/cache_controller.py`) + - Creates/destroys the storage backend instance (via `StorageBackendFactory`) + - Starts/stops backend background threads at runtime (prefetch/backup) + +--- + +## 2. Idle-state requirement (strict) + +The Scheduler uses a stricter `_is_idle_for_hicache_storage_op()`: + +- `_is_no_request()` is true (covers running/overlap/pp/disagg and other active states) +- `waiting_queue` is empty +- `grammar_queue` is empty (if the grammar backend is enabled) + +If the condition is not met, attach/detach returns an error like: + +- `Reject attach: scheduler is not idle. #queue-req=... #running-req=...` + +> Tip: before switching, drain upstream traffic and wait for the server to become idle, then call attach/detach. + +### 2.1 DP (data parallel) semantics + +When `dp_size > 1`, the tokenizer dispatches the request to **all DP scheduler instances** and aggregates their responses: + +- The final `success` is **true only if all DP ranks return success** +- The final `message` concatenates messages from all DP ranks + +This is intended to prevent “silent partial success”, but it also means you may see: + +- Overall **failure** even though **some ranks already succeeded** + +Currently there is **no automatic partial rollback** across DP ranks (see TODO in code). Operationally: + +- Prefer to keep backend config identical across ranks +- If attach fails, immediately call detach (best-effort/idempotent), fix config, then retry attach + +--- + +## 3. How to use (HTTP Admin API) + +The examples below assume your SGLang HTTP server is at `http://127.0.0.1:30000`. + +### 3.1 Query current storage backend status + +```bash +curl -s http://127.0.0.1:30000/hicache/storage-backend +``` + +Example response: + +```json +{ + "hicache_storage_backend": "mooncake", + "hicache_storage_backend_extra_config": "{\"master_server_address\":\"127.0.0.1:50051\", ...}" +} +``` + +### 3.2 Attach (enable) a storage backend +```bash +curl -s -X PUT http://127.0.0.1:30000/hicache/storage-backend \ + -H 'Content-Type: application/json' \ + -d '{ + "hicache_storage_backend": "mooncake" + }' +``` + +```bash +curl -s -X PUT http://127.0.0.1:30000/hicache/storage-backend \ + -H 'Content-Type: application/json' \ + -d '{ + "hicache_storage_backend": "mooncake", + "hicache_storage_backend_extra_config_json": "{\"master_server_address\":\"127.0.0.1:50051\",\"protocol\":\"tcp\",\"global_segment_size\":\"4gb\",\"prefetch_threshold\":256}", + "hicache_storage_prefetch_policy": "timeout" + }' +``` + +Notes: + +- `hicache_storage_backend_extra_config_json` can include both: + - **Backend configuration** (e.g., Mooncake master/metadata/protocol, etc.) + - **Prefetch configuration** (`prefetch_threshold`, `prefetch_timeout_base`, `prefetch_timeout_per_ki_token`, `hicache_storage_pass_prefix_keys`) + +### 3.3 Detach (disable) the storage backend + +```bash +curl -s -X DELETE http://127.0.0.1:30000/hicache/storage-backend +``` + +Notes: + +- Detach only makes SGLang **stop using** the L3 storage backend and stops prefetch/backup threads +- It **does not automatically delete** data stored in Mooncake/HF3FS (or other remote backends) + +--- + +## 4. Behavior and caveats + +- **No restart required**: attach/detach switches in-process at runtime +- **Must be idle**: otherwise the request is rejected to avoid consistency issues +- **Host KV layout constraints still apply**: for example, Mooncake still requires layouts like `page_first/page_first_direct/page_head`; if the server's HiCache host-memory layout does not satisfy the backend requirements, attach will fail with an error +- **Observability**: + - After attach, `server_args.hicache_storage_backend*` is updated on both the tokenizer and scheduler sides + - If metrics are enabled, attach will create a storage metrics collector in `HiRadixCache` on demand diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index afac1d03dace..d24c7739cd66 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -93,6 +93,7 @@ from sglang.srt.function_call.function_call_parser import FunctionCallParser from sglang.srt.managers.io_struct import ( AbortReq, + AttachHiCacheStorageReqInput, CheckWeightsReqInput, CloseSessionReqInput, ConfigureLoggingReq, @@ -693,6 +694,22 @@ async def flush_cache(): @app.api_route("/clear_hicache_storage_backend", methods=["GET", "POST"]) @auth_level(AuthLevel.ADMIN_OPTIONAL) +async def clear_hicache_storage_backend_deprecated(): + """Deprecated: use POST /hicache/storage-backend/clear.""" + ret = await _global_state.tokenizer_manager.clear_hicache_storage() + return Response( + content=( + "Deprecated endpoint. Use POST /hicache/storage-backend/clear.\n" + "Hierarchical cache storage backend cleared.\n" + ), + status_code=200 if ret.success else HTTPStatus.BAD_REQUEST, + ) + + +# example usage: +# curl -s -X POST http://127.0.0.1:30000/clear_hicache_storage_backend +@app.api_route("/hicache/storage-backend/clear", methods=["POST"]) +@auth_level(AuthLevel.ADMIN_OPTIONAL) async def clear_hicache_storage_backend(): """Clear the hierarchical cache storage backend.""" ret = await _global_state.tokenizer_manager.clear_hicache_storage() @@ -702,6 +719,89 @@ async def clear_hicache_storage_backend(): ) +# example usage: +# curl -s -X PUT http://127.0.0.1:30000/hicache/storage-backend \ +# -H 'Content-Type: application/json' \ +# -d '{ +# "hicache_storage_backend": "file", +# "hicache_storage_backend_extra_config_json": "{}", +# "hicache_storage_prefetch_policy": "timeout", +# "hicache_write_policy": "write_through" +# }' +@app.api_route("/hicache/storage-backend", methods=["PUT"]) +@auth_level(AuthLevel.ADMIN_OPTIONAL) +async def attach_hicache_storage_backend(obj: AttachHiCacheStorageReqInput): + """Attach (enable) HiCache storage backend at runtime. + + Only allowed when there are NO running / queued requests. + """ + if not _global_state.tokenizer_manager.server_args.admin_api_key: + return _admin_api_key_missing_response() + + ret = await _global_state.tokenizer_manager.attach_hicache_storage( + hicache_storage_backend=obj.hicache_storage_backend, + hicache_storage_backend_extra_config_json=obj.hicache_storage_backend_extra_config_json, + hicache_storage_prefetch_policy=obj.hicache_storage_prefetch_policy, + hicache_write_policy=obj.hicache_write_policy, + ) + msg = getattr(ret, "message", "") + return Response( + content=( + ( + "HiCache storage backend attached.\n" + if ret.success + else "Failed to attach HiCache storage backend.\n" + ) + + (msg + "\n" if msg else "") + ), + status_code=200 if ret.success else HTTPStatus.BAD_REQUEST, + ) + + +# example usage: +# curl -s -X DELETE http://127.0.0.1:30000/hicache/storage-backend +@app.api_route("/hicache/storage-backend", methods=["DELETE"]) +@auth_level(AuthLevel.ADMIN_OPTIONAL) +async def detach_hicache_storage_backend(): + """Detach (disable) HiCache storage backend at runtime. + + Only allowed when there are NO running / queued requests. + """ + if not _global_state.tokenizer_manager.server_args.admin_api_key: + return _admin_api_key_missing_response() + + ret = await _global_state.tokenizer_manager.detach_hicache_storage() + msg = getattr(ret, "message", "") + return Response( + content=( + ( + "HiCache storage backend detached.\n" + if ret.success + else "Failed to detach HiCache storage backend.\n" + ) + + (msg + "\n" if msg else "") + ), + status_code=200 if ret.success else HTTPStatus.BAD_REQUEST, + ) + + +# example usage: +# curl -s http://127.0.0.1:30000/hicache/storage-backend +@app.get("/hicache/storage-backend") +@auth_level(AuthLevel.ADMIN_OPTIONAL) +async def hicache_storage_backend_status(): + """Get current HiCache storage backend status (tokenizer-side view).""" + if not _global_state.tokenizer_manager.server_args.admin_api_key: + return _admin_api_key_missing_response() + + return { + "hicache_storage_backend": _global_state.tokenizer_manager.server_args.hicache_storage_backend, + "hicache_storage_backend_extra_config": _global_state.tokenizer_manager.server_args.hicache_storage_backend_extra_config, + "hicache_storage_prefetch_policy": _global_state.tokenizer_manager.server_args.hicache_storage_prefetch_policy, + "hicache_write_policy": _global_state.tokenizer_manager.server_args.hicache_write_policy, + } + + @app.api_route("/start_profile", methods=["GET", "POST"]) @auth_level(AuthLevel.ADMIN_OPTIONAL) async def start_profile_async(obj: Optional[ProfileReqInput] = None): @@ -1489,6 +1589,27 @@ def _create_error_response(e): ) +# FIXME: In theory we should configure ADMIN_FORCE for some entrypoints, but doing so +# would currently cause all endpoints to go through add_api_key_middleware +# (even when neither api-key nor admin-api-key is configured). +# +# For now, we simulate ADMIN_FORCE by explicitly checking the admin API key parameter. +# Once the auth wiring is refactored so ADMIN_FORCE only affects the intended +# admin endpoints, we should switch this logic to use ADMIN_FORCE directly. +def _admin_api_key_missing_response( + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, +) -> ORJSONResponse: + return ORJSONResponse( + content={ + "error": ( + "This endpoint requires admin API key, but this server was started " + "without one (admin-api-key). Restart with --admin-api-key to enable." + ) + }, + status_code=status_code, + ) + + # Minimal 32x32 black PNG (base64, GLM4v requires at least 32x32 sized image) MINIMUM_PNG_PICTURE_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAYAAABzenr0AAAACXBIWXMAAA7EAAAOxAGVKw4bAAAAbUlEQVRYhe3VsQ2AMAxE0Y/lIgNQULD/OqyCMgCihCKSG4yRuKuiNH6JLsoEbMACOGBcua9HOR7Y6w6swBwMy0qLTpkeI77qdEBpBFAHBBDAGH8WrwJKI4AAegUCfAKgEgpQDvh3CR3oQCuav58qlAw73kKCSgAAAABJRU5ErkJggg==" diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py index 80174584e51d..845dffe96496 100644 --- a/python/sglang/srt/managers/cache_controller.py +++ b/python/sglang/srt/managers/cache_controller.py @@ -262,6 +262,7 @@ def __init__( pp_rank: int = 0, pp_size: int = 1, ): + self.tp_group = tp_group self.mem_pool_device_allocator = token_to_kv_pool_allocator self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache() self.mem_pool_host = mem_pool_host @@ -269,41 +270,186 @@ def __init__( self.page_size = page_size self.io_backend = io_backend self.enable_storage = False + self.storage_backend = None + self.storage_backend_type = None self.pp_rank = pp_rank self.pp_size = pp_size - if storage_backend is not None: - self.storage_backend_type = storage_backend - from sglang.srt.mem_cache.hicache_storage import get_hash_str + # Default storage page IO functions (may be overridden by attach). + self.page_get_func = self._generic_page_get + self.page_set_func = self._generic_page_set - self.get_hash_str = get_hash_str - self.storage_config = self._generate_storage_config( - model_name, storage_backend_extra_config - ) - # for MLA models, only one rank needs to backup the KV cache - self.backup_skip = ( - self.storage_config.is_mla_model - # todo: load balancing - and self.storage_config.tp_rank != 0 - ) + # Dedicated stop event for storage background threads (prefetch/backup). + # NOTE: Do NOT reuse `self.stop_event` here since it also guards core HiCache + # transfer buffers (CPU<->GPU). We want to allow runtime attach/detach of + # storage without stopping the whole controller. + self.storage_stop_event = threading.Event() + + self.device = self.mem_pool_device.device + self.layer_num = self.mem_pool_device.layer_num + self.layer_done_counter = LayerDoneCounter(self.layer_num) + self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter) + + if write_policy not in [ + "write_through", + "write_through_selective", + "write_back", + ]: + raise ValueError(f"Invalid write policy: {write_policy}") - # Use storage backend factory for dynamic backend creation - from sglang.srt.mem_cache.storage import StorageBackendFactory + # self.write_queue = PriorityQueue[CacheOperation]() + self.load_queue: List[CacheOperation] = [] + self.write_queue: List[CacheOperation] = [] + self.ack_load_queue: List[HiCacheAck] = [] + self.ack_write_queue: List[HiCacheAck] = [] + + self.stop_event = threading.Event() + self.write_buffer = TransferBuffer(self.stop_event) + self.load_buffer = TransferBuffer( + self.stop_event, buffer_count=10, max_buffer_size=100 + ) + + self.write_stream = device_module.Stream() + self.load_stream = device_module.Stream() + # If a storage backend is provided at startup, treat it as an implicit attach, + # so init/runtime share the same lifecycle semantics and code paths. + if storage_backend is not None: try: - self.storage_backend = StorageBackendFactory.create_backend( - storage_backend, self.storage_config, self.mem_pool_host + self.attach_storage_backend( + storage_backend=storage_backend, + prefetch_threshold=prefetch_threshold, + model_name=model_name, + storage_backend_extra_config=storage_backend_extra_config, ) except ValueError as e: + # Preserve the historical error shape on init for unknown backends. raise ValueError(f"Failed to create storage backend: {e}") from e + def _start_storage_threads(self): + """Start storage prefetch/backup threads and their queues. + + This is used by runtime attach, and also by reset when storage is enabled. + """ + assert self.enable_storage + assert not self.storage_stop_event.is_set() + + self.prefetch_thread = threading.Thread( + target=self.prefetch_thread_func, daemon=True + ) + self.backup_thread = threading.Thread( + target=self.backup_thread_func, daemon=True + ) + self.prefetch_queue = Queue() + self.backup_queue = Queue() + + self.prefetch_revoke_queue = Queue() + self.ack_backup_queue = Queue() + self.host_mem_release_queue = Queue() + + self.prefetch_thread.start() + self.backup_thread.start() + + def _stop_storage_threads(self): + """Stop storage prefetch/backup threads and drain internal queues. + + Caller should ensure no in-flight requests. + """ + # Always request stop. This is safe even when storage is already disabled, + # and makes detach truly idempotent (previous partial detach may have left + # threads alive). + # NOTE: do NOT clear stop_event unless threads have fully stopped; otherwise + # a still-alive thread may resume and touch released state. + self.storage_stop_event.set() + + # Best-effort wakeups so threads exit promptly even if blocked on queues. + try: + if hasattr(self, "prefetch_queue"): + self.prefetch_queue.put_nowait(None) + if hasattr(self, "backup_queue"): + self.backup_queue.put_nowait(None) + if hasattr(self, "prefetch_buffer"): + self.prefetch_buffer.put_nowait(None) + except Exception: + pass + + # Best-effort joins (threads are daemon, but join keeps state clean). + threads = [] + if hasattr(self, "prefetch_thread"): + threads.append(self.prefetch_thread) + if hasattr(self, "backup_thread"): + threads.append(self.backup_thread) + if hasattr(self, "prefetch_io_aux_thread"): + threads.append(self.prefetch_io_aux_thread) + + for t in threads: + try: + t.join(timeout=10) + except Exception: + pass + + alive = [t for t in threads if getattr(t, "is_alive", lambda: False)()] + if alive: + logger.error( + "Failed to stop HiCache storage threads cleanly: %s", + [getattr(t, "name", repr(t)) for t in alive], + ) + raise RuntimeError("Failed to stop HiCache storage threads cleanly.") + + def attach_storage_backend( + self, + storage_backend: str, + prefetch_threshold: int = 256, + model_name: Optional[str] = None, + storage_backend_extra_config: Optional[dict] = None, + ): + """Attach (enable) storage backend at runtime. + + Requirement: no in-flight requests. This call is expected to run on the scheduler + thread (control path), not concurrently with prefetch/backup. + """ + if self.enable_storage: + raise RuntimeError("Storage backend already attached.") + + # Defensive: a previous partial detach may have flipped `enable_storage` but + # left background threads alive. Attaching on top of them is unsafe. + try: + self._stop_storage_threads() + except Exception as e: + raise RuntimeError( + "Cannot attach storage backend: previous detach did not stop storage threads cleanly." + ) from e + + # Rollback-safe init: if creation fails, keep controller state consistent + # for future attach attempts. + self.storage_backend_type = storage_backend + from sglang.srt.mem_cache.hicache_storage import get_hash_str + + self.get_hash_str = get_hash_str + self.storage_config = self._generate_storage_config( + model_name, storage_backend_extra_config + ) + # for MLA models, only one rank needs to backup the KV cache + self.backup_skip = ( + self.storage_config.is_mla_model + # todo: load balancing + and self.storage_config.tp_rank != 0 + ) + + # Use storage backend factory for dynamic backend creation + from sglang.srt.mem_cache.storage import StorageBackendFactory + + try: + self.storage_backend = StorageBackendFactory.create_backend( + storage_backend, self.storage_config, self.mem_pool_host + ) self.storage_backend.register_mem_pool_host(self.mem_pool_host) self.enable_storage = True # todo: threshold policy for prefetching self.prefetch_threshold = max(prefetch_threshold, self.page_size) - self.prefetch_capacity_limit = int( - 0.8 * (self.mem_pool_host.size - self.mem_pool_device.size) + self.prefetch_capacity_limit = max( + 0, int(0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)) ) # granularity of batch storage IO operations, in number of pages self.storage_batch_size = 128 @@ -311,13 +457,13 @@ def __init__( self.prefetch_tokens_occupied = 0 # create a new communication group for synchronizing storage operations across TP workers - self.tp_world_size = torch.distributed.get_world_size(group=tp_group) + self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group) if self.tp_world_size > 1: from sglang.srt.distributed.parallel_state import ( create_custom_parallel_group, ) - group_ranks = torch.distributed.get_process_group_ranks(tp_group) + group_ranks = torch.distributed.get_process_group_ranks(self.tp_group) self.prefetch_tp_group = create_custom_parallel_group( group_ranks=group_ranks, backend="gloo" ) @@ -333,49 +479,92 @@ def __init__( self.page_get_func = self._page_get_zero_copy self.page_set_func = self._page_set_zero_copy - self.device = self.mem_pool_device.device - self.layer_num = self.mem_pool_device.layer_num - self.layer_done_counter = LayerDoneCounter(self.layer_num) - self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter) - - if write_policy not in [ - "write_through", - "write_through_selective", - "write_back", - ]: - raise ValueError(f"Invalid write policy: {write_policy}") - - # self.write_queue = PriorityQueue[CacheOperation]() - self.load_queue: List[CacheOperation] = [] - self.write_queue: List[CacheOperation] = [] - self.ack_load_queue: List[HiCacheAck] = [] - self.ack_write_queue: List[HiCacheAck] = [] - - self.stop_event = threading.Event() - self.write_buffer = TransferBuffer(self.stop_event) - self.load_buffer = TransferBuffer( - self.stop_event, buffer_count=10, max_buffer_size=100 - ) - - self.write_stream = device_module.Stream() - self.load_stream = device_module.Stream() + # Ensure stop_event is clear before starting threads. + self.storage_stop_event.clear() + self._start_storage_threads() + except Exception: + # Best-effort cleanup for partial init. + try: + self._stop_storage_threads() + except Exception: + pass + try: + if hasattr(self, "prefetch_tp_group"): + try: + torch.distributed.destroy_process_group(self.prefetch_tp_group) + except Exception: + pass + self.prefetch_tp_group = None + except Exception: + pass + try: + if ( + hasattr(self, "storage_backend") + and self.storage_backend is not None + ): + if hasattr(self.storage_backend, "close"): + self.storage_backend.close() + except Exception: + pass + self.storage_backend = None + self.storage_backend_type = None + self.enable_storage = False + self.page_get_func = self._generic_page_get + self.page_set_func = self._generic_page_set + raise - if self.enable_storage: - self.prefetch_thread = threading.Thread( - target=self.prefetch_thread_func, daemon=True - ) - self.backup_thread = threading.Thread( - target=self.backup_thread_func, daemon=True - ) - self.prefetch_queue = Queue() - self.backup_queue = Queue() + def detach_storage_backend(self): + """Detach (disable) storage backend at runtime. - self.prefetch_revoke_queue = Queue() - self.ack_backup_queue = Queue() - self.host_mem_release_queue = Queue() + Requirement: no in-flight requests. This will stop storage threads and release + the backend instance (best-effort close). + """ + # Idempotent cleanup: even if `enable_storage` is already False, + # we may still have leftover resources (threads/backend/process group) from a + # previous partial detach. We attempt cleanup whenever possible. + try: + self._stop_storage_threads() + except Exception as e: + # Do not proceed tearing down backend/process group if threads are not + # fully stopped; otherwise still-alive threads may touch released state. + # Caller can retry detach. + logger.exception("Stop storage threads failed: %s", e) + # IMPORTANT: Do not silently succeed. Upper layers rely on exceptions here + # to avoid flipping `enable_storage` flags while threads are still alive. + raise RuntimeError("Stop storage threads failed; detach aborted.") from e + + # Best-effort destroy process group created for storage ops. + try: + if ( + hasattr(self, "prefetch_tp_group") + and self.prefetch_tp_group is not None + ): + try: + torch.distributed.destroy_process_group(self.prefetch_tp_group) + except Exception: + pass + self.prefetch_tp_group = None + except Exception: + pass + + # Best-effort close (some backends rely on GC/destructor). + try: + if ( + hasattr(self, "storage_backend") + and self.storage_backend is not None + and hasattr(self.storage_backend, "close") + ): + self.storage_backend.close() + except Exception: + logger.exception("Failed to close storage backend cleanly.") - self.prefetch_thread.start() - self.backup_thread.start() + self.storage_backend = None + self.storage_backend_type = None + self.enable_storage = False + self.page_get_func = self._generic_page_get + self.page_set_func = self._generic_page_set + # Now it's safe to clear the stop event for future re-attach. + self.storage_stop_event.clear() def _generate_storage_config( self, @@ -408,6 +597,7 @@ def _generate_storage_config( def reset(self): self.stop_event.set() + self.storage_stop_event.set() self.write_queue.clear() self.load_queue.clear() @@ -424,6 +614,7 @@ def reset(self): self.ack_backup_queue.queue.clear() self.stop_event.clear() + self.storage_stop_event.clear() if self.enable_storage: self.prefetch_thread = threading.Thread( @@ -661,9 +852,11 @@ def prefetch_io_aux_func(self): """ Auxiliary function conducting IO operations for prefetching. """ - while not self.stop_event.is_set(): + while not self.storage_stop_event.is_set(): try: operation = self.prefetch_buffer.get(block=True, timeout=1) + if operation is None: + continue self._page_transfer(operation) # operation terminated by controller, release pre-allocated memory self.append_host_mem_release( @@ -719,9 +912,11 @@ def prefetch_thread_func(self): Manage prefetching operations from storage backend to host memory. """ self.prefetch_buffer = Queue() - aux_thread = threading.Thread(target=self.prefetch_io_aux_func, daemon=True) - aux_thread.start() - while (not self.stop_event.is_set()) or not self.prefetch_queue.empty(): + self.prefetch_io_aux_thread = threading.Thread( + target=self.prefetch_io_aux_func, daemon=True + ) + self.prefetch_io_aux_thread.start() + while (not self.storage_stop_event.is_set()) or not self.prefetch_queue.empty(): try: operation = self.prefetch_queue.get(block=True, timeout=1) if operation is None: @@ -818,7 +1013,7 @@ def backup_thread_func(self): """ Manage backup operations from host memory to storage backend. """ - while not self.stop_event.is_set(): + while not self.storage_stop_event.is_set(): try: operation = self.backup_queue.get(block=True, timeout=1) if operation is None: diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index fad02e0a0112..6248121862a0 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -1158,6 +1158,60 @@ class FlushCacheReqOutput(BaseReq): success: bool +@dataclass +class AttachHiCacheStorageReqInput(BaseReq): + """Dynamically attach (enable) HiCache storage backend at runtime. + + Note: `hicache_storage_backend_extra_config_json` is a JSON string. It may contain both: + - backend-specific configs (e.g., mooncake master address) + - prefetch-related knobs (prefetch_threshold, prefetch_timeout_*, hicache_storage_pass_prefix_keys) + """ + + hicache_storage_backend: str + hicache_storage_backend_extra_config_json: Optional[str] = None + hicache_storage_prefetch_policy: Optional[str] = None + hicache_write_policy: Optional[str] = None + + def __post_init__(self): + if self.hicache_storage_prefetch_policy is None: + pass + else: + allowed = ["best_effort", "wait_complete", "timeout"] + if self.hicache_storage_prefetch_policy not in allowed: + raise ValueError( + f"Invalid hicache_storage_prefetch_policy: {self.hicache_storage_prefetch_policy!r}. " + f"Expected one of {allowed}." + ) + + if self.hicache_write_policy is None: + return + allowed = ["write_back", "write_through", "write_through_selective"] + if self.hicache_write_policy not in allowed: + raise ValueError( + f"Invalid hicache_write_policy: {self.hicache_write_policy!r}. " + f"Expected one of {allowed}." + ) + + +@dataclass +class AttachHiCacheStorageReqOutput(BaseReq): + success: bool + message: str = "" + + +@dataclass +class DetachHiCacheStorageReqInput(BaseReq): + """Dynamically detach (disable) HiCache storage backend at runtime.""" + + pass + + +@dataclass +class DetachHiCacheStorageReqOutput(BaseReq): + success: bool + message: str = "" + + @dataclass class PauseGenerationReqInput(BaseReq): """ diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 99f14e3efa9b..6307450347e6 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -71,6 +71,8 @@ from sglang.srt.managers.io_struct import ( AbortReq, ActiveRanksOutput, + AttachHiCacheStorageReqInput, + AttachHiCacheStorageReqOutput, BaseBatchReq, BaseReq, BatchTokenizedEmbeddingReqInput, @@ -81,6 +83,8 @@ CloseSessionReqInput, ContinueGenerationReqInput, DestroyWeightsUpdateGroupReqInput, + DetachHiCacheStorageReqInput, + DetachHiCacheStorageReqOutput, ExpertDistributionReq, ExpertDistributionReqOutput, ExpertDistributionReqType, @@ -1020,6 +1024,8 @@ def init_request_dispatcher(self): (BatchTokenizedEmbeddingReqInput, self.handle_batch_embedding_request), (FlushCacheReqInput, self.flush_cache_wrapped), (ClearHiCacheReqInput, self.clear_hicache_storage_wrapped), + (AttachHiCacheStorageReqInput, self.attach_hicache_storage_wrapped), + (DetachHiCacheStorageReqInput, self.detach_hicache_storage_wrapped), (AbortReq, self.abort_request), (OpenSessionReqInput, self.open_session), (CloseSessionReqInput, self.close_session), @@ -2454,6 +2460,118 @@ def clear_hicache_storage_wrapped(self, recv_req: ClearHiCacheReqInput): if_success = False return ClearHiCacheReqOutput(success=if_success) + def _is_idle_for_hicache_storage_op(self) -> bool: + """Stricter idle check for storage attach/detach. + + We require: + - no running batches (including overlap/pp/disagg paths) via `_is_no_request()` + - no queued requests in scheduler queues (waiting/grammar/disagg queues) + """ + if not self._is_no_request(): + return False + if len(self.waiting_queue) != 0: + return False + if len(self.grammar_manager.grammar_queue) != 0: + return False + return True + + def attach_hicache_storage_wrapped( + self, recv_req: AttachHiCacheStorageReqInput + ) -> AttachHiCacheStorageReqOutput: + if not self.enable_hierarchical_cache: + return AttachHiCacheStorageReqOutput( + success=False, message="Hierarchical cache is not enabled." + ) + + if not self._is_idle_for_hicache_storage_op(): + return AttachHiCacheStorageReqOutput( + success=False, + message=( + "Reject attach: scheduler is not idle. " + f"#queue-req={len(self.waiting_queue)} " + f"#running-req={len(self.running_batch.reqs)}" + ), + ) + + if not hasattr(self.tree_cache, "attach_storage_backend"): + return AttachHiCacheStorageReqOutput( + success=False, + message="Current tree_cache implementation does not support dynamic attach.", + ) + + try: + ok, msg = self.tree_cache.attach_storage_backend( + storage_backend=recv_req.hicache_storage_backend, + storage_backend_extra_config_json=recv_req.hicache_storage_backend_extra_config_json, + served_model_name=self.server_args.served_model_name, + hicache_storage_prefetch_policy=recv_req.hicache_storage_prefetch_policy, + hicache_write_policy=recv_req.hicache_write_policy, + ) + except Exception as e: + logger.exception("Attach HiCache storage backend failed with exception.") + return AttachHiCacheStorageReqOutput(success=False, message=str(e)) + if ok: + self.enable_hicache_storage = True + self.server_args.hicache_storage_backend = recv_req.hicache_storage_backend + if recv_req.hicache_storage_backend_extra_config_json is not None: + self.server_args.hicache_storage_backend_extra_config = ( + recv_req.hicache_storage_backend_extra_config_json + ) + if recv_req.hicache_storage_prefetch_policy is not None: + self.server_args.hicache_storage_prefetch_policy = ( + recv_req.hicache_storage_prefetch_policy + ) + if recv_req.hicache_write_policy is not None: + self.server_args.hicache_write_policy = recv_req.hicache_write_policy + logger.info( + f"Attached HiCache storage backend: {recv_req.hicache_storage_backend}" + ) + return AttachHiCacheStorageReqOutput(success=ok, message=msg) + + def detach_hicache_storage_wrapped( + self, recv_req: DetachHiCacheStorageReqInput + ) -> DetachHiCacheStorageReqOutput: + if not self.enable_hierarchical_cache: + return DetachHiCacheStorageReqOutput( + success=False, message="Hierarchical cache is not enabled." + ) + + if not self._is_idle_for_hicache_storage_op(): + return DetachHiCacheStorageReqOutput( + success=False, + message=( + "Reject detach: scheduler is not idle. " + f"#queue-req={len(self.waiting_queue)} " + f"#running-req={len(self.running_batch.reqs)}" + ), + ) + + if not hasattr(self.tree_cache, "detach_storage_backend"): + return DetachHiCacheStorageReqOutput( + success=False, + message="Current tree_cache implementation does not support dynamic detach.", + ) + + # Idempotent detach: even if scheduler thinks storage is disabled, we still + # attempt best-effort cleanup in tree_cache (it may have leftover state). + try: + ok, msg = self.tree_cache.detach_storage_backend() + except Exception as e: + logger.exception("Detach HiCache storage backend failed with exception.") + return DetachHiCacheStorageReqOutput(success=False, message=str(e)) + + if ok or (not self.enable_hicache_storage): + # Treat "already disabled / nothing to do" as success for idempotence. + self.enable_hicache_storage = False + self.server_args.hicache_storage_backend = None + self.server_args.hicache_storage_backend_extra_config = None + logger.info("Detached HiCache storage backend.") + return DetachHiCacheStorageReqOutput( + success=True, message=msg or "HiCache storage backend is detached." + ) + + return DetachHiCacheStorageReqOutput(success=False, message=msg) + def _is_no_request(self): no_request = ( self.running_batch.is_empty() diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py index e25729e71dde..f2f9791e61a0 100644 --- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py +++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py @@ -23,6 +23,8 @@ import zmq from sglang.srt.managers.io_struct import ( + AttachHiCacheStorageReqInput, + AttachHiCacheStorageReqOutput, CheckWeightsReqInput, CheckWeightsReqOutput, ClearHiCacheReqInput, @@ -30,6 +32,8 @@ CloseSessionReqInput, DestroyWeightsUpdateGroupReqInput, DestroyWeightsUpdateGroupReqOutput, + DetachHiCacheStorageReqInput, + DetachHiCacheStorageReqOutput, ExpertDistributionReq, ExpertDistributionReqOutput, ExpertDistributionReqType, @@ -202,6 +206,12 @@ def init_communicators(self: TokenizerManager, server_args: ServerArgs): self.clear_hicache_storage_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) + self.attach_hicache_storage_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.detach_hicache_storage_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) self.profile_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) @@ -281,6 +291,14 @@ def _get_communicator_dispatcher(self: TokenizerManager): ClearHiCacheReqOutput, self.clear_hicache_storage_communicator.handle_recv, ), + ( + AttachHiCacheStorageReqOutput, + self.attach_hicache_storage_communicator.handle_recv, + ), + ( + DetachHiCacheStorageReqOutput, + self.detach_hicache_storage_communicator.handle_recv, + ), ( FlushCacheReqOutput, self.flush_cache_communicator.handle_recv, @@ -326,6 +344,57 @@ async def clear_hicache_storage(self: TokenizerManager) -> ClearHiCacheReqOutput 0 ] + async def attach_hicache_storage( + self: TokenizerManager, + hicache_storage_backend: str, + hicache_storage_backend_extra_config_json: Optional[str] = None, + hicache_storage_prefetch_policy: Optional[str] = None, + hicache_write_policy: Optional[str] = None, + ) -> AttachHiCacheStorageReqOutput: + """Attach (enable) HiCache storage backend at runtime.""" + results = await self.attach_hicache_storage_communicator( + AttachHiCacheStorageReqInput( + hicache_storage_backend=hicache_storage_backend, + hicache_storage_backend_extra_config_json=hicache_storage_backend_extra_config_json, + hicache_storage_prefetch_policy=hicache_storage_prefetch_policy, + hicache_write_policy=hicache_write_policy, + ) + ) + + all_success, all_message = _Communicator.merge_results(results) + out = AttachHiCacheStorageReqOutput(success=all_success, message=all_message) + # TODO: partial rollback if failed + if all_success: + # Keep tokenizer side server_info consistent with scheduler side. + self.server_args.hicache_storage_backend = hicache_storage_backend + if hicache_storage_backend_extra_config_json is not None: + self.server_args.hicache_storage_backend_extra_config = ( + hicache_storage_backend_extra_config_json + ) + if hicache_storage_prefetch_policy is not None: + self.server_args.hicache_storage_prefetch_policy = ( + hicache_storage_prefetch_policy + ) + if hicache_write_policy is not None: + self.server_args.hicache_write_policy = hicache_write_policy + return out + + async def detach_hicache_storage( + self: TokenizerManager, + ) -> DetachHiCacheStorageReqOutput: + """Detach (disable) HiCache storage backend at runtime.""" + results = await self.detach_hicache_storage_communicator( + DetachHiCacheStorageReqInput() + ) + + all_success, all_message = _Communicator.merge_results(results) + out = DetachHiCacheStorageReqOutput(success=all_success, message=all_message) + # TODO: partial rollback if failed + if all_success: + self.server_args.hicache_storage_backend = None + self.server_args.hicache_storage_backend_extra_config = None + return out + async def start_profile( self: TokenizerManager, output_dir: Optional[str] = None, diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index 443d654b9f36..ac38d0bc8af2 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -1,10 +1,12 @@ from __future__ import annotations +import atexit import heapq import json import logging import threading import time +from queue import Empty from typing import TYPE_CHECKING, List, Optional import torch @@ -43,6 +45,7 @@ class HiRadixCache(RadixCache): def __init__(self, params: CacheInitParams, server_args: ServerArgs): + self._enable_metrics_flag = params.enable_metrics if server_args.hicache_io_backend == "direct": # FIXME: move this logic into server_args parsing if server_args.hicache_mem_layout == "page_first": @@ -94,12 +97,6 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs): ) = self._parse_storage_backend_extra_config( server_args.hicache_storage_backend_extra_config ) - self.prefetch_threshold = prefetch_threshold - self.prefetch_timeout_base = prefetch_timeout_base - self.prefetch_timeout_per_page = ( - self.page_size / 1024 * prefetch_timeout_per_ki_token - ) - self.hicache_storage_pass_prefix_keys = hicache_storage_pass_prefix_keys # TODO: support more timeout check functions self.is_prefetch_timeout = self._prefetch_timeout_check_linear_func self.prefetch_stop_policy = server_args.hicache_storage_prefetch_policy @@ -114,22 +111,21 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs): write_policy=server_args.hicache_write_policy, io_backend=server_args.hicache_io_backend, storage_backend=server_args.hicache_storage_backend, - prefetch_threshold=self.prefetch_threshold, + prefetch_threshold=prefetch_threshold, model_name=server_args.served_model_name, storage_backend_extra_config=extra_config, pp_rank=self.pp_rank, pp_size=self.pp_size, ) - if self.enable_storage_metrics: - # TODO: support pp - labels = { - "storage_backend": server_args.hicache_storage_backend, - "tp_rank": self.cache_controller.tp_rank, - "dp_rank": self.cache_controller.dp_rank, - "pp_rank": self.cache_controller.pp_rank, - "pp_size": self.cache_controller.pp_size, - } - self.storage_metrics_collector = StorageMetricsCollector(labels=labels) + self._apply_storage_runtime_config( + storage_backend=server_args.hicache_storage_backend, + prefetch_threshold=prefetch_threshold, + prefetch_timeout_base=prefetch_timeout_base, + prefetch_timeout_per_ki_token=prefetch_timeout_per_ki_token, + hicache_storage_pass_prefix_keys=hicache_storage_pass_prefix_keys, + enable_storage=self.enable_storage, + enable_storage_metrics=self.enable_storage_metrics, + ) # record the nodes with ongoing write through self.ongoing_write_through = {} @@ -144,8 +140,333 @@ def __init__(self, params: CacheInitParams, server_args: ServerArgs): ) self.load_back_threshold = 10 + # Detach storage backend automatically on process shutdown + atexit.register(self.shutdown) + super().__init__(params=params) + def shutdown(self): + """Best-effort auto-detach of storage backend on process shutdown. + + This keeps startup and runtime behavior consistent: if a backend was attached + (either via CLI args or via admin API), we attempt to detach it on exit. + """ + try: + if self.enable_storage: + self.detach_storage_backend() + except Exception: + logger.exception("Failed to detach storage backend on process shutdown.") + + def _apply_storage_runtime_config( + self, + *, + storage_backend: Optional[str], + prefetch_threshold: int, + prefetch_timeout_base: float, + prefetch_timeout_per_ki_token: float, + hicache_storage_pass_prefix_keys: bool, + enable_storage: bool, + enable_storage_metrics: bool, + ) -> None: + prefetch_timeout_per_page = ( + self.page_size / 1024 * prefetch_timeout_per_ki_token + ) + + storage_metrics_collector = None + if enable_storage_metrics: + labels = { + "storage_backend": storage_backend, + "tp_rank": self.cache_controller.tp_rank, + "dp_rank": self.cache_controller.dp_rank, + "pp_rank": self.cache_controller.pp_rank, + "pp_size": self.cache_controller.pp_size, + } + storage_metrics_collector = StorageMetricsCollector(labels=labels) + + self.enable_storage = enable_storage + self.prefetch_threshold = prefetch_threshold + self.prefetch_timeout_base = prefetch_timeout_base + self.prefetch_timeout_per_page = prefetch_timeout_per_page + self.hicache_storage_pass_prefix_keys = hicache_storage_pass_prefix_keys + self.enable_storage_metrics = enable_storage_metrics + if self.enable_storage_metrics: + self.storage_metrics_collector = storage_metrics_collector + else: + self.storage_metrics_collector = None + + def attach_storage_backend( + self, + storage_backend: str, + storage_backend_extra_config_json: Optional[str] = None, + served_model_name: Optional[str] = None, + hicache_storage_prefetch_policy: Optional[str] = None, + hicache_write_policy: Optional[str] = None, + ) -> tuple[bool, str]: + """Attach (enable) storage backend at runtime. + + This will start storage threads inside `HiCacheController` and enable + prefetch/backup paths. Caller must ensure there are no running/queued + requests to avoid races. + """ + # Validate inputs first (no side effects). + if hicache_storage_prefetch_policy is not None: + allowed = ["best_effort", "wait_complete", "timeout"] + if hicache_storage_prefetch_policy not in allowed: + return ( + False, + f"Invalid hicache_storage_prefetch_policy: {hicache_storage_prefetch_policy!r}. " + f"Expected one of {allowed}.", + ) + + if hicache_write_policy is not None: + allowed = ["write_back", "write_through", "write_through_selective"] + if hicache_write_policy not in allowed: + return ( + False, + f"Invalid hicache_write_policy: {hicache_write_policy!r}. " + f"Expected one of {allowed}.", + ) + + # If already enabled: + # - backend unchanged: treat as success, update policies only. + # - backend changed: treat as failure, do NOT update policies. + if self.enable_storage: + current_backend = self.cache_controller.storage_backend_type + + if current_backend == storage_backend: + if hicache_storage_prefetch_policy is not None: + self.prefetch_stop_policy = hicache_storage_prefetch_policy + logger.info( + f"Set hicache_storage_prefetch_policy to {hicache_storage_prefetch_policy}" + ) + if hicache_write_policy is not None: + self.cache_controller.write_policy = hicache_write_policy + self.write_through_threshold = ( + 1 if hicache_write_policy == "write_through" else 2 + ) + logger.info(f"Set hicache_write_policy to {hicache_write_policy}") + return ( + True, + "HiCache storage backend already enabled with same backend; policies updated.", + ) + + return ( + False, + f"HiCache storage backend is already enabled with backend '{current_backend}'. " + f"Cannot attach different backend '{storage_backend}'. Detach first.", + ) + + # Not enabled: update policies before controller attach so storage threads observe new values. + if hicache_storage_prefetch_policy is not None: + self.prefetch_stop_policy = hicache_storage_prefetch_policy + logger.info( + f"Set hicache_storage_prefetch_policy to {hicache_storage_prefetch_policy}" + ) + + if hicache_write_policy is not None: + self.cache_controller.write_policy = hicache_write_policy + self.write_through_threshold = ( + 1 if hicache_write_policy == "write_through" else 2 + ) + logger.info(f"Set hicache_write_policy to {hicache_write_policy}") + + logger.info(f"Attaching HiCache storage backend: {storage_backend}") + try: + ( + extra_config, + prefetch_threshold, + prefetch_timeout_base, + prefetch_timeout_per_ki_token, + hicache_storage_pass_prefix_keys, + ) = self._parse_storage_backend_extra_config( + storage_backend_extra_config_json + ) + except Exception as e: + logger.exception(f"Failed to parse storage_backend_extra_config_json: {e}") + return ( + False, + f"Failed to parse storage_backend_extra_config_json '{storage_backend_extra_config_json}': {e}", + ) + + try: + self.cache_controller.attach_storage_backend( + storage_backend=storage_backend, + prefetch_threshold=prefetch_threshold, + model_name=served_model_name, + storage_backend_extra_config=extra_config, + ) + except Exception as e: + logger.exception( + f"Failed to attach storage backend '{storage_backend}': {e}" + ) + return False, f"Failed to attach storage backend '{storage_backend}': {e}" + + self._apply_storage_runtime_config( + storage_backend=storage_backend, + prefetch_threshold=prefetch_threshold, + prefetch_timeout_base=prefetch_timeout_base, + prefetch_timeout_per_ki_token=prefetch_timeout_per_ki_token, + hicache_storage_pass_prefix_keys=hicache_storage_pass_prefix_keys, + enable_storage=True, + enable_storage_metrics=self._enable_metrics_flag, + ) + return True, "Attached HiCache storage backend successfully." + + def detach_storage_backend(self) -> tuple[bool, str]: + """Detach (disable) storage backend at runtime. + + Caller must ensure there are no running/queued requests to avoid races. + """ + try: + # Drain any pending control queues before tearing down storage threads/backend. + # IMPORTANT: this must happen before we clear `ongoing_*`, otherwise acks/releases + # cannot be matched to nodes and may leak host pages / locks. + self._drain_storage_control_queues_local() + # Idempotent detach: always ask controller to best-effort cleanup, even if + # `self.enable_storage` is already False (may be leftover state from a + # previous partial detach). + self.cache_controller.detach_storage_backend() + except Exception as e: + logger.exception("Failed to detach storage backend.") + # Do NOT crash the server for admin operations. Return failure with detail. + return False, f"Failed to detach HiCache storage backend: {e}" + + # Best-effort cleanup of any leftover bookkeeping. + self._drain_storage_control_queues_local() + # After controller threads are fully stopped, it's safe to force-release any + # leftover pending ops (e.g., async prefetch/backup that didn't get a revoke/ack). + self._force_release_pending_storage_ops() + + self.enable_storage = False + self.enable_storage_metrics = False + if hasattr(self, "storage_metrics_collector"): + self.storage_metrics_collector = None + return True, "Detached HiCache storage backend successfully." + + def _force_release_pending_storage_ops(self): + """Force release any leftover pending prefetch/backup bookkeeping. + + This is a safety net for detach/shutdown paths. It assumes storage threads + have been stopped already (via controller.detach), so no concurrent access + to these structures should happen. + """ + cc = self.cache_controller + + # Force release leftover prefetch ops: free pre-allocated host pages and + # drop the host protection on the matched prefix node. + try: + for req_id, info in list(self.ongoing_prefetch.items()): + try: + last_host_node, token_ids, host_indices, _operation = info + except Exception: + # Unexpected shape; just drop it. + self.ongoing_prefetch.pop(req_id, None) + continue + + try: + if host_indices is not None: + cc.mem_pool_host.free(host_indices) + except Exception: + logger.exception( + "Failed to free host indices for prefetch %s", req_id + ) + + try: + last_host_node.release_host() + except Exception: + logger.exception( + "Failed to release host protection for prefetch %s", req_id + ) + + try: + cc.prefetch_tokens_occupied -= len(token_ids) + if cc.prefetch_tokens_occupied < 0: + cc.prefetch_tokens_occupied = 0 + except Exception: + pass + + self.ongoing_prefetch.pop(req_id, None) + except Exception: + logger.exception("Force release pending prefetch ops failed.") + + # Force release leftover backup ops: drop host protection on nodes. + try: + for ack_id, node in list(self.ongoing_backup.items()): + try: + node.release_host() + except Exception: + logger.exception( + "Failed to release host protection for backup op %s", ack_id + ) + self.ongoing_backup.pop(ack_id, None) + except Exception: + logger.exception("Force release pending backup ops failed.") + + def _drain_storage_control_queues_local(self): + """Drain storage control queues without TP synchronization. + + This is intended for shutdown/detach paths where we want to make best-effort + cleanup even if queue sizes temporarily differ across ranks. + """ + self._drain_storage_control_queues_impl( + n_revoke=None, + n_backup=None, + n_release=None, + log_metrics=False, + ) + + def _drain_storage_control_queues_impl( + self, + n_revoke: Optional[int], + n_backup: Optional[int], + n_release: Optional[int], + log_metrics: bool, + ): + cc = self.cache_controller + + def _drain_queue(q, limit: Optional[int]): + drained = 0 + while limit is None or drained < limit: + try: + item = q.get_nowait() + except Empty: + break + drained += 1 + yield item + + def _drain_revoke(): + for req_id in _drain_queue(cc.prefetch_revoke_queue, n_revoke): + info = self.ongoing_prefetch.pop(req_id, None) + if info is not None: + last_host_node, token_ids, _, _ = info + last_host_node.release_host() + cc.prefetch_tokens_occupied -= len(token_ids) + if cc.prefetch_tokens_occupied < 0: + cc.prefetch_tokens_occupied = 0 + + def _drain_backup(): + for operation in _drain_queue(cc.ack_backup_queue, n_backup): + ack_id = operation.id + entry = self.ongoing_backup.pop(ack_id, None) + if entry is not None: + entry.release_host() + if log_metrics and self.enable_storage_metrics: + self.storage_metrics_collector.log_backuped_tokens( + operation.completed_tokens + ) + + def _drain_release(): + host_indices_list = [] + for host_indices in _drain_queue(cc.host_mem_release_queue, n_release): + host_indices_list.append(host_indices) + if host_indices_list: + host_indices = torch.cat(host_indices_list, dim=0) + cc.mem_pool_host.free(host_indices) + + _drain_revoke() + _drain_backup() + _drain_release() + def _parse_storage_backend_extra_config( self, storage_backend_extra_config: Optional[str] ): @@ -188,6 +509,11 @@ def _parse_storage_backend_extra_config( raise ValueError( f"prefetch_timeout_per_ki_token must be number, got {type(prefetch_timeout_per_ki_token).__name__}" ) + if not isinstance(hicache_storage_pass_prefix_keys, bool): + raise ValueError( + "hicache_storage_pass_prefix_keys must be bool, got " + f"{type(hicache_storage_pass_prefix_keys).__name__}" + ) return ( extra_config, @@ -548,36 +874,12 @@ def drain_storage_control_queues(self): ) n_revoke, n_backup, n_release = map(int, qsizes.tolist()) - - # process prefetch revokes - for _ in range(n_revoke): - req_id = cc.prefetch_revoke_queue.get() - info = self.ongoing_prefetch.pop(req_id, None) - if info is not None: - last_host_node, token_ids, _, _ = info - last_host_node.release_host() - cc.prefetch_tokens_occupied -= len(token_ids) - # else: the revoked operation already got terminated, nothing to do - - # process backup acks - for _ in range(n_backup): - operation = cc.ack_backup_queue.get() - ack_id = operation.id - entry = self.ongoing_backup.pop(ack_id, None) - if entry is not None: - entry.release_host() - if self.enable_storage_metrics: - self.storage_metrics_collector.log_backuped_tokens( - operation.completed_tokens - ) - - # release host memory - host_indices_list = [] - for _ in range(n_release): - host_indices_list.append(cc.host_mem_release_queue.get()) - if host_indices_list: - host_indices = torch.cat(host_indices_list, dim=0) - cc.mem_pool_host.free(host_indices) + self._drain_storage_control_queues_impl( + n_revoke=n_revoke, + n_backup=n_backup, + n_release=n_release, + log_metrics=True, + ) # Timeout is linearly increasing with the number of pages def _prefetch_timeout_check_linear_func(self, operation: PrefetchOperation): diff --git a/test/registered/hicache/test_hicache_storage_runtime_attach_detach.py b/test/registered/hicache/test_hicache_storage_runtime_attach_detach.py new file mode 100644 index 000000000000..65e52041d153 --- /dev/null +++ b/test/registered/hicache/test_hicache_storage_runtime_attach_detach.py @@ -0,0 +1,368 @@ +""" +E2E smoke test for HiCache storage runtime attach/detach. + +This test launches an SGLang server with hierarchical cache enabled but WITHOUT +any storage backend at startup, then attaches/detaches a storage backend via the +HTTP endpoints. + +Usage: + python3 -m pytest test/registered/hicache/test_hicache_storage_runtime_attach_detach.py -v +""" + +import json +import os +import tempfile +import time +import unittest +from urllib import error, request + +from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + find_available_port, + popen_launch_server, +) + +register_cuda_ci(est_time=200, suite="stage-b-test-large-2-gpu") + + +class TestHiCacheStorageRuntimeAttachDetach(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.temp_dir = tempfile.mkdtemp() + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + # Use a per-test-class available port to reduce flakiness / conflicts. + default_port = int(DEFAULT_URL_FOR_TEST.rsplit(":", 1)[1]) + cls.base_url = f"http://127.0.0.1:{find_available_port(default_port)}" + + cls.other_args = [ + "--enable-hierarchical-cache", + "--mem-fraction-static", + "0.6", + "--hicache-ratio", + "1.2", + "--hicache-size", + "100", + "--page-size", + "64", + "--enable-cache-report", + # NOTE: do NOT pass --hicache-storage-backend* here + ] + + cls.env = { + **os.environ, + # File backend uses this env var to decide where to store cache pages. + "SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR": cls.temp_dir, + # Make runs less flaky for CI/dev. + "SGLANG_ENABLE_DETERMINISTIC_INFERENCE": "1", + } + + @classmethod + def tearDownClass(cls): + import shutil + + shutil.rmtree(cls.temp_dir, ignore_errors=True) + + @classmethod + def _wait_for_server_ready(cls, base_url: str, timeout: int = 60) -> bool: + start_time = time.time() + while time.time() - start_time < timeout: + try: + code, _body = cls._http_get(f"{base_url}/health", timeout=5) + if code == 200: + return True + except Exception: + pass + time.sleep(2) + raise TimeoutError("Server failed to start within timeout") + + @staticmethod + def _http_get(url: str, timeout: int = 10, headers: dict | None = None): + try: + req = request.Request(url, headers=headers or {}, method="GET") + with request.urlopen(req, timeout=timeout) as resp: + return resp.getcode(), resp.read().decode("utf-8", errors="replace") + except error.HTTPError as e: + body = e.read().decode("utf-8", errors="replace") + return e.code, body + + @staticmethod + def _http_post_json(url: str, payload: dict | None = None, timeout: int = 30): + data = None + headers = {} + if payload is not None: + data = json.dumps(payload).encode("utf-8") + headers["Content-Type"] = "application/json" + req = request.Request(url, data=data, headers=headers, method="POST") + try: + with request.urlopen(req, timeout=timeout) as resp: + return resp.getcode(), resp.read().decode("utf-8", errors="replace") + except error.HTTPError as e: + body = e.read().decode("utf-8", errors="replace") + return e.code, body + + @staticmethod + def _http_post_json_with_headers( + url: str, + payload: dict | None = None, + timeout: int = 30, + headers: dict | None = None, + ): + data = None + all_headers = dict(headers or {}) + if payload is not None: + data = json.dumps(payload).encode("utf-8") + all_headers["Content-Type"] = "application/json" + req = request.Request(url, data=data, headers=all_headers, method="POST") + try: + with request.urlopen(req, timeout=timeout) as resp: + return resp.getcode(), resp.read().decode("utf-8", errors="replace") + except error.HTTPError as e: + body = e.read().decode("utf-8", errors="replace") + return e.code, body + + @staticmethod + def _http_put_json_with_headers( + url: str, + payload: dict | None = None, + timeout: int = 30, + headers: dict | None = None, + ): + data = None + all_headers = dict(headers or {}) + if payload is not None: + data = json.dumps(payload).encode("utf-8") + all_headers["Content-Type"] = "application/json" + req = request.Request(url, data=data, headers=all_headers, method="PUT") + try: + with request.urlopen(req, timeout=timeout) as resp: + return resp.getcode(), resp.read().decode("utf-8", errors="replace") + except error.HTTPError as e: + body = e.read().decode("utf-8", errors="replace") + return e.code, body + + @staticmethod + def _http_delete_with_headers( + url: str, timeout: int = 30, headers: dict | None = None + ): + all_headers = dict(headers or {}) + req = request.Request(url, headers=all_headers, method="DELETE") + try: + with request.urlopen(req, timeout=timeout) as resp: + return resp.getcode(), resp.read().decode("utf-8", errors="replace") + except error.HTTPError as e: + body = e.read().decode("utf-8", errors="replace") + return e.code, body + + def _get_backend_status(self, base_url: str, headers: dict | None = None): + code, body = self._http_get( + f"{base_url}/hicache/storage-backend", timeout=10, headers=headers + ) + self.assertEqual(code, 200, body) + return json.loads(body) + + def _attach_backend( + self, + base_url: str, + backend: str, + extra_cfg: dict, + prefetch_policy: str = "timeout", + write_policy: str = "write_through", + headers: dict | None = None, + ): + payload = { + "hicache_storage_backend": backend, + "hicache_storage_backend_extra_config_json": json.dumps(extra_cfg), + "hicache_storage_prefetch_policy": prefetch_policy, + "hicache_write_policy": write_policy, + } + return self._http_put_json_with_headers( + f"{base_url}/hicache/storage-backend", + payload, + timeout=30, + headers=headers, + ) + + def _detach_backend(self, base_url: str, headers: dict | None = None): + return self._http_delete_with_headers( + f"{base_url}/hicache/storage-backend", + timeout=30, + headers=headers, + ) + + def test_runtime_attach_detach(self): + # Phase A: WITHOUT --admin-api-key, ADMIN_FORCE endpoints must be forbidden (403). + process1 = popen_launch_server( + self.model, + self.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=self.other_args, + env=self.env, + ) + try: + self._wait_for_server_ready(self.base_url) + + code_info, _body_info = self._http_get( + f"{self.base_url}/hicache/storage-backend", timeout=10 + ) + self.assertEqual(code_info, 400) + code_attach_no_admin, _body_attach_no_admin = self._attach_backend( + base_url=self.base_url, backend="file", extra_cfg={} + ) + self.assertEqual(code_attach_no_admin, 400) + code_detach_no_admin, _body_detach_no_admin = self._detach_backend( + self.base_url + ) + self.assertEqual(code_detach_no_admin, 400) + finally: + kill_process_tree(process1.pid) + time.sleep(2) + + # Phase B: WITH --admin-api-key, must provide Authorization: Bearer . + admin_key = "sglang-test-admin-key" + base_url2 = f"http://127.0.0.1:{find_available_port(int(self.base_url.rsplit(':', 1)[1]) + 1)}" + other_args2 = list(self.other_args) + ["--admin-api-key", admin_key] + process2 = popen_launch_server( + self.model, + base_url2, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args2, + env=self.env, + ) + try: + self._wait_for_server_ready(base_url2) + + # 1) Initially disabled (but unauthorized without admin key) + code_info2_unauth, _ = self._http_get( + f"{base_url2}/hicache/storage-backend", timeout=10 + ) + self.assertEqual(code_info2_unauth, 401) + + admin_headers = {"Authorization": f"Bearer {admin_key}"} + status0 = self._get_backend_status(base_url2, headers=admin_headers) + self.assertIsNone(status0.get("hicache_storage_backend")) + + # 2) Attach should succeed when idle + extra_cfg = { + "hicache_storage_pass_prefix_keys": True, + # keep knobs small and stable + "prefetch_threshold": 256, + "prefetch_timeout_base": 3, + "prefetch_timeout_per_ki_token": 0.01, + } + + # Unauthorized attach must fail. + code_attach_unauth, _ = self._attach_backend( + base_url=base_url2, backend="file", extra_cfg=extra_cfg + ) + self.assertEqual(code_attach_unauth, 401) + + code_attach, body_attach = self._attach_backend( + base_url=base_url2, + backend="file", + extra_cfg=extra_cfg, + prefetch_policy="timeout", + write_policy="write_back", + headers=admin_headers, + ) + self.assertEqual(code_attach, 200, f"{code_attach} - {body_attach}") + + status1 = self._get_backend_status(base_url2, headers=admin_headers) + self.assertEqual(status1.get("hicache_storage_backend"), "file") + self.assertEqual( + status1.get("hicache_storage_backend_extra_config"), + json.dumps(extra_cfg), + ) + self.assertEqual(status1.get("hicache_storage_prefetch_policy"), "timeout") + self.assertEqual(status1.get("hicache_write_policy"), "write_back") + + # 3) Attach again succeeds with policies updated + code_attach_again, body_attach_again = self._attach_backend( + base_url=base_url2, + backend="file", + extra_cfg=extra_cfg, + prefetch_policy="wait_complete", + write_policy="write_through_selective", + headers=admin_headers, + ) + self.assertEqual( + code_attach_again, 200, f"{code_attach_again} - {body_attach_again}" + ) + + status2 = self._get_backend_status(base_url2, headers=admin_headers) + self.assertEqual( + status2.get("hicache_storage_backend_extra_config"), + json.dumps(extra_cfg), + ) + self.assertEqual( + status2.get("hicache_storage_prefetch_policy"), "wait_complete" + ) + self.assertEqual( + status2.get("hicache_write_policy"), "write_through_selective" + ) + + # 4) Attach again with different backend should be rejected + code_attach_again, body_attach_again = self._attach_backend( + base_url=base_url2, + backend="mooncake", + extra_cfg=extra_cfg, + headers=admin_headers, + ) + self.assertNotEqual(code_attach_again, 200, body_attach_again) + + # 5) Detach should succeed and be idempotent + code_detach, body_detach = self._detach_backend( + base_url2, headers=admin_headers + ) + self.assertEqual(code_detach, 200, f"{code_detach} - {body_detach}") + status3 = self._get_backend_status(base_url2, headers=admin_headers) + self.assertIsNone(status3.get("hicache_storage_backend")) + self.assertEqual( + status3.get("hicache_storage_prefetch_policy"), "wait_complete" + ) + self.assertEqual( + status3.get("hicache_write_policy"), "write_through_selective" + ) + + code_detach_again, body_detach_again = self._detach_backend( + base_url2, headers=admin_headers + ) + self.assertEqual( + code_detach_again, + 200, + f"{code_detach_again} - {body_detach_again}", + ) + + # 6) Re-attach after detach should succeed + code_attach2, body_attach2 = self._attach_backend( + base_url=base_url2, + backend="file", + extra_cfg=extra_cfg, + headers=admin_headers, + ) + self.assertEqual(code_attach2, 200, f"{code_attach2} - {body_attach2}") + status4 = self._get_backend_status(base_url2, headers=admin_headers) + self.assertEqual(status4.get("hicache_storage_backend"), "file") + self.assertEqual( + status4.get("hicache_storage_backend_extra_config"), + json.dumps(extra_cfg), + ) + self.assertEqual(status4.get("hicache_storage_prefetch_policy"), "timeout") + self.assertEqual(status4.get("hicache_write_policy"), "write_through") + + # Cleanup: detach for test isolation + code_detach2, body_detach2 = self._detach_backend( + base_url2, headers=admin_headers + ) + self.assertEqual(code_detach2, 200, f"{code_detach2} - {body_detach2}") + finally: + kill_process_tree(process2.pid) + time.sleep(2) + + +if __name__ == "__main__": + unittest.main(verbosity=2)