diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py index a2d08e0e39f9..12e36fb1677b 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -8,7 +8,7 @@ import time import traceback from http import HTTPStatus -from typing import Dict, List, Optional, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple, Union import aiohttp import numpy as np @@ -26,6 +26,7 @@ from sglang.srt.disaggregation.encode_receiver import EmbeddingData from sglang.srt.distributed.parallel_state import ( get_mooncake_transfer_engine, + get_tp_group, init_distributed_environment, initialize_model_parallel, ) @@ -206,6 +207,22 @@ def __init__( self.schedule_socket = get_zmq_socket( self.context, zmq.PULL, schedule_path, True ) + self.background_tasks: Set[asyncio.Task] = set() + + if self.server_args.enable_mm_global_cache: + from sglang.srt.mem_cache.storage.mooncake_store.embedding_cache_controller import ( + EmbeddingCacheController, + ) + + self.mm_global_cache = EmbeddingCacheController( + rank, + server_args.tp_size, + hidden_dim=self.model_config.hidden_size, + tp_group=get_tp_group().cpu_group, + all_rank_get=False, + ) + else: + self.mm_global_cache = None if self.rank == 0: logger.info( @@ -214,10 +231,10 @@ def __init__( if self.server_args.encoder_transfer_backend == "mooncake": self.local_ip = get_local_ip_auto() + self.engine = get_mooncake_transfer_engine() self.embedding_to_send = dict() - self.background_tasks: Set[asyncio.Task] = set() logger.info(f"rank {rank} init finish ") @@ -312,6 +329,233 @@ async def _flatten_and_load_images(self, mm_items): async_futures = [asyncio.wrap_future(f) for f in futures] return await asyncio.gather(*async_futures) + def get_num_patches(self, grid: Union[torch.Tensor, List[int]]) -> int: + """Calculate number of raw patches (before 2x2 merge). Used for pixel_values slicing.""" + return int(grid[0] * grid[1] * grid[2]) + + def get_num_tokens(self, grid: Union[torch.Tensor, List[int]]) -> int: + """Calculate number of tokens (after 2x2 merge). Used for mm_embedding slicing.""" + merge_size = getattr(self.image_processor, "merge_size", 2) + return self.get_num_patches(grid) // (merge_size**2) + + def slice_embedding( + self, mm_embedding: torch.Tensor, grid_thw: List + ) -> List[torch.Tensor]: + """Slice a concatenated embedding tensor into individual image embeddings.""" + slices, offset = [], 0 + for grid in grid_thw: + count = self.get_num_tokens(grid) + slices.append(mm_embedding[offset : offset + count]) + offset += count + return slices + + def _calculate_hashes_from_features( + self, pixel_values: torch.Tensor, grid_thw: List + ) -> List[str]: + """CPU Task: Compute hashes based on processed feature patches (pixel_values).""" + hashes, offset = [], 0 + for grid in grid_thw: + num_patches = self.get_num_patches(grid) + feature_slice = pixel_values[offset : offset + num_patches] + tmp_item = MultimodalDataItem( + modality=Modality.IMAGE, feature=feature_slice + ) + tmp_item.set_pad_value() + hashes.append(tmp_item.hash) + offset += num_patches + return hashes + + async def _encode_missing( + self, pixel_values: torch.Tensor, images_input: dict, indices: List[int] + ) -> List[torch.Tensor]: + """ + GPU Task: Run ViT inference ONLY on the subset of images missing from the cache. + """ + grid_thw = images_input["image_grid_thw"] + + # 1. Slice pixel_values to get only the patches for missing images + sub_pixel_list = [] + offsets = [0] + curr = 0 + for g in grid_thw: + curr += self.get_num_patches(g) + offsets.append(curr) + + for idx in indices: + sub_pixel_list.append(pixel_values[offsets[idx] : offsets[idx + 1]]) + + sub_feature = torch.cat(sub_pixel_list, dim=0) + + mm_item = MultimodalDataItem.from_dict( + { + "modality": Modality.IMAGE, + "feature": _convert(sub_feature), + } + ) + + for k, v in images_input.items(): + if k == "pixel_values": + continue + val = _convert(v) + if k in _image_grid_attrs: + mm_item.set(k, val[indices]) + else: + mm_item.set(k, val) + + with torch.inference_mode(): + new_embeddings = self.model.get_image_feature([mm_item]).cpu() + if new_embeddings.ndim != 2: + new_embeddings = new_embeddings.reshape(-1, new_embeddings.shape[-1]) + + sub_grids = [grid_thw[i] for i in indices] + return self.slice_embedding(new_embeddings, sub_grids) + + async def encode_with_global_cache( + self, + mm_items, + req_id: str, + num_parts: int, + part_idx: int, + hashes: Optional[List[str]] = None, + ) -> torch.Tensor: + images = await self._flatten_and_load_images(mm_items) + kwargs = {"device": self.device} if self.use_image_processor_gpu else {} + images_input = self.image_processor(images=images, **kwargs) + pixel_values = images_input["pixel_values"] + grid_thw = images_input["image_grid_thw"] + num_images = len(grid_thw) + + # Step 1: Rank 0 checks global cache and broadcasts hit/miss mask to all ranks. + if self.rank == 0: + if hashes is None: + image_hashes = self._calculate_hashes_from_features( + pixel_values, grid_thw + ) + else: + image_hashes = hashes + exist_mask = await self.mm_global_cache.batch_is_exist(image_hashes) + mask_tensor = torch.tensor( + [1 if e else 0 for e in exist_mask], dtype=torch.int32 + ) + else: + image_hashes = None + mask_tensor = torch.zeros(num_images, dtype=torch.int32) + + if self.server_args.tp_size > 1: + torch.distributed.broadcast( + mask_tensor, + src=0, + group=self.mm_global_cache.prefetch_tp_group, + ) + + exist_mask = [m.item() == 1 for m in mask_tensor] + missing_indices = [i for i, e in enumerate(exist_mask) if not e] + hit_indices = [i for i, e in enumerate(exist_mask) if e] + + # Step 2: All ranks run ViT together on cache-miss images. + new_slices = [] + if missing_indices: + new_slices = await self._encode_missing( + pixel_values, images_input, missing_indices + ) + + # Step 3: Rank 0 prefetches cache-hit embeddings from global cache. + prefetch_status = torch.tensor([1], dtype=torch.int32) + + if self.rank == 0: + if hit_indices: + hit_hashes = [image_hashes[i] for i in hit_indices] + hit_tokens = [self.get_num_tokens(grid_thw[i]) for i in hit_indices] + self.mm_global_cache.prefetch(req_id, hit_hashes, hit_tokens) + + try: + + async def _wait_prefetch(): + while not self.mm_global_cache.check_prefetch_progress(req_id): + await asyncio.sleep(0.005) + + await asyncio.wait_for(_wait_prefetch(), timeout=60.0) + except (asyncio.TimeoutError, Exception) as e: + logger.error( + f"Prefetch failed for req {req_id}: {e}. " + f"Falling back to ViT for {len(hit_indices)} hit images." + ) + prefetch_status[0] = 0 + + # Step 4: Broadcast prefetch result to all ranks so they stay in sync. + if self.server_args.tp_size > 1: + torch.distributed.broadcast( + prefetch_status, + src=0, + group=self.mm_global_cache.prefetch_tp_group, + ) + + # Step 5: If prefetch failed, all ranks fallback to ViT for the hit images. + if prefetch_status.item() == 0 and hit_indices: + logger.info( + f"Req {req_id}: Prefetch failed, all ranks running ViT fallback " + f"for {len(hit_indices)} images." + ) + fallback_slices = await self._encode_missing( + pixel_values, images_input, hit_indices + ) + else: + fallback_slices = None + + # Step 6: Rank 0 assembles final embedding and prepares for sending. + if self.rank == 0: + final_slices = [None] * num_images + + for i, idx in enumerate(missing_indices): + final_slices[idx] = new_slices[i] + + # Fill in cache-hit embeddings (from prefetch or fallback) + if prefetch_status.item() == 1 and hit_indices: + cached_slices = self.mm_global_cache.get_embeddings( + [image_hashes[i] for i in hit_indices] + ) + for i, idx in enumerate(hit_indices): + final_slices[idx] = cached_slices[i] + elif fallback_slices is not None: + for i, idx in enumerate(hit_indices): + final_slices[idx] = fallback_slices[i] + + mm_embedding = torch.cat(final_slices, dim=0) + + # Background insert: store newly computed embeddings into global cache. + # Includes both original misses and fallback-recomputed hits. + all_new_hashes = [image_hashes[i] for i in missing_indices] + all_new_slices = list(new_slices) + if fallback_slices is not None: + all_new_hashes += [image_hashes[i] for i in hit_indices] + all_new_slices += list(fallback_slices) + + if all_new_hashes: + + async def _background_insert(): + await asyncio.to_thread( + self.mm_global_cache.insert_batch, + all_new_hashes, + all_new_slices, + ) + + task = asyncio.create_task(_background_insert()) + self.background_tasks.add(task) + task.add_done_callback(self.background_tasks.discard) + + self.embedding_to_send[req_id] = EmbeddingData( + req_id, num_parts, part_idx, grid_thw, mm_embedding + ) + return ( + mm_embedding.nbytes, + mm_embedding.shape[0], + mm_embedding.shape[1], + None, + None, + ) + else: + return (0, 0, 0, None, None) + async def _encode(self, mm_items) -> torch.Tensor: try: images = await self._flatten_and_load_images(mm_items) @@ -421,6 +665,9 @@ def send_with_socket(): await asyncio.get_event_loop().run_in_executor(self.executor, send_with_socket) + async def encode_with_hash(self, mm_items, req_id, num_parts, part_idx, hashes): + images = await self._flatten_and_load_images(mm_items) + async def encode(self, mm_items, req_id, num_parts, part_idx): try: image_grid_dim, mm_embedding = await self._encode(mm_items) @@ -644,12 +891,21 @@ async def run_encoder( else: encoder.profiler.stop() else: - await encoder.encode( - mm_items=request["mm_items"], - req_id=request["req_id"], - num_parts=request["num_parts"], - part_idx=request["part_idx"], - ) + if encoder.mm_global_cache is not None: + await encoder.encode_with_global_cache( + mm_items=request["mm_items"], + req_id=request["req_id"], + num_parts=request["num_parts"], + part_idx=request["part_idx"], + hashes=request.get("hashes", None), + ) + else: + await encoder.encode( + mm_items=request["mm_items"], + req_id=request["req_id"], + num_parts=request["num_parts"], + part_idx=request["part_idx"], + ) def launch_encoder(server_args, schedule_path, dist_init_method, rank): @@ -706,15 +962,25 @@ def start_background_send(req_id): request.update({"enter_time": time.time()}) for socket in send_sockets: socket.send_pyobj(request) - - nbytes, embedding_len, embedding_dim, error_msg, error_code = ( - await encoder.encode( - mm_items=request["mm_items"], - req_id=request["req_id"], - num_parts=request["num_parts"], - part_idx=request["part_idx"], + if encoder.mm_global_cache is not None: + nbytes, embedding_len, embedding_dim, error_msg, error_code = ( + await encoder.encode_with_global_cache( + mm_items=request["mm_items"], + req_id=request["req_id"], + num_parts=request["num_parts"], + part_idx=request["part_idx"], + hashes=request.get("hashes", None), + ) + ) + else: + nbytes, embedding_len, embedding_dim, error_msg, error_code = ( + await encoder.encode( + mm_items=request["mm_items"], + req_id=request["req_id"], + num_parts=request["num_parts"], + part_idx=request["part_idx"], + ) ) - ) if error_msg: if encoder.server_args.encoder_transfer_backend == "zmq_to_scheduler": diff --git a/python/sglang/srt/mem_cache/storage/mooncake_store/embedding_cache_controller.py b/python/sglang/srt/mem_cache/storage/mooncake_store/embedding_cache_controller.py new file mode 100644 index 000000000000..87090c9d7f72 --- /dev/null +++ b/python/sglang/srt/mem_cache/storage/mooncake_store/embedding_cache_controller.py @@ -0,0 +1,315 @@ +import asyncio +import logging +import threading +import time +from queue import Empty, Queue +from typing import List, Optional + +import torch + +from sglang.srt.mem_cache.storage.mooncake_store.mooncake_embedding_store import ( + MooncakeEmbeddingStore, +) + +logger = logging.getLogger(__name__) + + +class ContiguousMemoryAllocator: + """ + A simple allocator to manage variable-sized contiguous blocks + within a large pre-allocated flat buffer. + """ + + def __init__(self, total_size_bytes: int): + self.total_size = total_size_bytes + # List of (offset, size) for free blocks + self.free_blocks = [(0, total_size_bytes)] + self.allocated_map = {} # {handle: (offset, size)} + self.lock = threading.Lock() + + def allocate(self, size_bytes: int) -> Optional[int]: + with self.lock: + # Simple First-Fit allocation + for i, (offset, block_size) in enumerate(self.free_blocks): + if block_size >= size_bytes: + # Allocate from this block + remaining_size = block_size - size_bytes + if remaining_size > 0: + self.free_blocks[i] = (offset + size_bytes, remaining_size) + else: + self.free_blocks.pop(i) + return offset + return None + + def free(self, offset: int, size_bytes: int): + with self.lock: + # Return block and merge adjacent free blocks + self.free_blocks.append((offset, size_bytes)) + self.free_blocks.sort() + + merged = [] + if not self.free_blocks: + return + + curr_offset, curr_size = self.free_blocks[0] + for next_offset, next_size in self.free_blocks[1:]: + if curr_offset + curr_size == next_offset: + curr_size += next_size + else: + merged.append((curr_offset, curr_size)) + curr_offset, curr_size = next_offset, next_size + merged.append((curr_offset, curr_size)) + self.free_blocks = merged + + +class EmbeddingPrefetchOperation: + """Groups all missing images of a request for a single batch GET.""" + + def __init__(self, req_id: str, keys: List[str], ptrs: List[int], sizes: List[int]): + self.req_id = req_id + self.keys = keys + self.ptrs = ptrs + self.sizes = sizes + self.is_finished = False + self.success = False + self._lock = threading.Lock() + + def mark_done(self, success: bool): + with self._lock: + self.success = success + self.is_finished = True + + +class EmbeddingInsertOperation: + """Groups all newly computed images of a request for a single batch PUT.""" + + def __init__(self, keys: List[str], ptrs: List[int], sizes: List[int]): + self.keys = keys + self.ptrs = ptrs + self.sizes = sizes + + +class EmbeddingCacheController: + def __init__( + self, + tp_rank, + tp_size, + max_pool_size_gb=4.0, + hidden_dim=1024, + tp_group=None, + all_rank_get=False, + ): + self.tp_world_size = tp_size + self.tp_group = tp_group + self.all_rank_get = all_rank_get + self.hidden_dim = hidden_dim + self.element_size = torch.float32.itemsize + + # 1. Mooncake Backend & Pinned Buffer + self.mooncake_store = MooncakeEmbeddingStore() + self.total_pool_size_bytes = int(max_pool_size_gb * 1024**3) + self.cpu_pool = torch.empty( + self.total_pool_size_bytes, dtype=torch.uint8, pin_memory=True + ) + self.mooncake_store.register_buffer(self.cpu_pool) + + # 2. Variable Size Memory Management + self.allocator = ContiguousMemoryAllocator(self.total_pool_size_bytes) + self.hash_to_metadata = {} # {image_hash: (offset, num_tokens, size_bytes)} + + # 3. Task Tracking + self.ongoing_prefetch = {} # {req_id: EmbeddingPrefetchOperation} + self.prefetch_queue = Queue() + self.insert_queue = Queue() + + self.lock = threading.Lock() + self.stop_event = threading.Event() + self.io_thread = threading.Thread(target=self._io_loop, daemon=True) + self.io_thread.start() + + if self.tp_world_size > 1: + if self.tp_group is None: + raise ValueError("tp_group must be provided when tp_size > 1") + from sglang.srt.distributed.parallel_state import ( + create_custom_parallel_group, + ) + + group_ranks = torch.distributed.get_process_group_ranks(self.tp_group) + self.prefetch_tp_group = create_custom_parallel_group( + group_ranks=group_ranks, backend="gloo" + ) + else: + self.prefetch_tp_group = None + + def prefetch( + self, req_id: str, image_hashes: List[str], expected_tokens: List[int] + ): + """Issues ONE batch GET for all missing images in the request.""" + keys, ptrs, sizes = [], [], [] + + with self.lock: + for h, num_tokens in zip(image_hashes, expected_tokens): + if h in self.hash_to_metadata: + logger.debug( + f"Req {req_id}: Hash already in local metadata, skipping prefetch." + ) + continue + + size_bytes = num_tokens * self.hidden_dim * self.element_size + offset = self.allocator.allocate(size_bytes) + if offset is None: + continue + + self.hash_to_metadata[h] = (offset, num_tokens, size_bytes) + keys.append(h) + ptrs.append(self.cpu_pool.data_ptr() + offset) + sizes.append(size_bytes) + + if not keys: + return + + logger.info( + f"Req {req_id}: Starting global fetch for {len(keys)} images from Mooncake." + ) + + op = EmbeddingPrefetchOperation(req_id, keys, ptrs, sizes) + self.ongoing_prefetch[req_id] = op + self.prefetch_queue.put(op) + + def insert_batch( + self, image_hashes: List[str], embedding_tensors: List[torch.Tensor] + ): + """Issues ONE batch PUT for all embeddings computed by this request.""" + keys, ptrs, sizes = [], [], [] + + with self.lock: + for h, tensor in zip(image_hashes, embedding_tensors): + if h in self.hash_to_metadata: + continue + + num_tokens = tensor.shape[0] + size_bytes = num_tokens * self.hidden_dim * self.element_size + offset = self.allocator.allocate(size_bytes) + if offset is None: + continue + + # Copy to pinned pool for RDMA + self.hash_to_metadata[h] = (offset, num_tokens, size_bytes) + target_view = ( + self.cpu_pool[offset : offset + size_bytes] + .view(torch.float32) + .view(num_tokens, self.hidden_dim) + ) + target_view.copy_(tensor.cpu()) + + keys.append(h) + ptrs.append(self.cpu_pool.data_ptr() + offset) + sizes.append(size_bytes) + + if keys: + logger.info( + f"Global Cache: Inserting {len(keys)} new embeddings into Mooncake cluster." + ) + self.insert_queue.put(EmbeddingInsertOperation(keys, ptrs, sizes)) + + def _io_loop(self): + """Asynchronous worker handling both Batch GET and Batch PUT.""" + while not self.stop_event.is_set(): + processed_any = False + + try: + op = self.prefetch_queue.get_nowait() + results = self.mooncake_store.batch_get(op.keys, op.ptrs, op.sizes) + success_count = sum(results) + logger.info( + f"Mooncake GET Finished: Req {op.req_id}, Successfully fetched {success_count}/{len(op.keys)} images." + ) + op.mark_done(all(results)) + self.prefetch_queue.task_done() + processed_any = True + except Empty: + pass + + try: + op = self.insert_queue.get_nowait() + self.mooncake_store.batch_put(op.keys, op.ptrs, op.sizes) + logger.info( + f"Mooncake PUT Finished: Successfully stored {len(op.keys)} keys in cluster." + ) + self.insert_queue.task_done() + processed_any = True + except Empty: + pass + + if not processed_any: + time.sleep(0.001) + + def check_prefetch_progress(self, req_id: str) -> bool: + """TP-Group barrier: ensures all cards have the request batch ready.""" + local_ready = False + with self.lock: + if req_id not in self.ongoing_prefetch: + local_ready = True + else: + op = self.ongoing_prefetch[req_id] + if op.is_finished: + local_ready = op.success + + if self.all_rank_get and self.tp_world_size > 1: + ready_tensor = torch.tensor( + [1 if local_ready else 0], dtype=torch.int, device="cpu" + ) + torch.distributed.all_reduce( + ready_tensor, + op=torch.distributed.ReduceOp.MIN, + group=self.prefetch_tp_group, + ) + local_ready = ready_tensor.item() == 1 + + if local_ready: + with self.lock: + self.ongoing_prefetch.pop(req_id, None) + return True + return False + + def get_embeddings(self, image_hashes: List[str]) -> List[torch.Tensor]: + """Final reconstruction for model input.""" + with self.lock: + tensors = [] + for h in image_hashes: + offset, num_tokens, size_bytes = self.hash_to_metadata[h] + tensors.append( + self.cpu_pool[offset : offset + size_bytes] + .view(torch.float32) + .view(num_tokens, self.hidden_dim) + ) + return tensors + + async def batch_is_exist(self, image_hashes: List[str]) -> List[bool]: + with self.lock: + local_results = [h in self.hash_to_metadata for h in image_hashes] + local_hit_count = sum(local_results) + + global_hit_count = 0 + if not all(local_results): + missing_indices = [i for i, res in enumerate(local_results) if not res] + missing_hashes = [image_hashes[i] for i in missing_indices] + + global_exists = await asyncio.to_thread( + self.mooncake_store.batch_is_exist, missing_hashes + ) + global_hit_count = sum(global_exists) + + for i, exists in zip(missing_indices, global_exists): + local_results[i] = exists + + total = len(image_hashes) + miss_count = total - local_hit_count - global_hit_count + logger.info( + f"=== Multi-Level Cache Check === " + f"Total: {total} | " + f"Local Hits: {local_hit_count} | " + f"Global Hits: {global_hit_count} | " + f"Misses (GPU Work): {miss_count}" + ) + return local_results diff --git a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_embedding_store.py b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_embedding_store.py new file mode 100644 index 000000000000..f358d97dcc79 --- /dev/null +++ b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_embedding_store.py @@ -0,0 +1,68 @@ +import logging +from typing import Any, List + +from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import MooncakeBaseStore + +logger = logging.getLogger(__name__) + + +class MooncakeEmbeddingStore(MooncakeBaseStore): + def __init__( + self, + storage_config: Any = None, + ): + super().__init__() + + MooncakeDistributedStore = self._import_mooncake_store() + self.store = MooncakeDistributedStore() + self.config = self._load_config(storage_config) + ret_code = self.store.setup( + self.config.local_hostname, + self.config.metadata_server, + self.config.global_segment_size, + 16 * 1024 * 1024, # Internal local buffer size + self.config.protocol, + self.config.device_name, + self.config.master_server_address, + ) + if ret_code != 0: + raise RuntimeError(f"Failed to setup Mooncake Embedding Store: {ret_code}") + + logger.info("Mooncake Embedding Store initialized successfully.") + + def get_key(self, image_hash: str) -> str: + return f"emb_{image_hash}" + + def batch_get( + self, hashes: List[str], ptrs: List[int], sizes: List[int] + ) -> List[bool]: + keys = [self.get_key(h) for h in hashes] + results = self.store.batch_get_into(keys, ptrs, sizes) + return [res > 0 for res in results] + + def batch_put( + self, hashes: List[str], ptrs: List[int], sizes: List[int] + ) -> List[bool]: + keys = [self.get_key(h) for h in hashes] + exists = self.store.batch_is_exist(keys) + + put_keys, put_ptrs, put_sizes, indices = [], [], [], [] + success_map = [True] * len(hashes) + + for i, status in enumerate(exists): + if status != 1: + put_keys.append(keys[i]) + put_ptrs.append(ptrs[i]) + put_sizes.append(sizes[i]) + indices.append(i) + + if put_keys: + results = self.store.batch_put_from(put_keys, put_ptrs, put_sizes) + for i, res in enumerate(results): + success_map[indices[i]] = res == 0 + return success_map + + def batch_is_exist(self, hashes: List[str]) -> List[bool]: + keys = [self.get_key(h) for h in hashes] + results = self.store.batch_is_exist(keys) + return [res == 1 for res in results] diff --git a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py index 208e87c8ccd3..731e2f1bdf37 100644 --- a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +++ b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py @@ -222,47 +222,74 @@ def load_from_extra_config(extra_config: dict) -> "MooncakeStoreConfig": ) -class MooncakeStore(HiCacheStorage): +class MooncakeBaseStore: + def __init__(self): + self.store = None + self.config = None - def __init__( - self, storage_config: HiCacheStorageConfig = None, mem_pool: HostKVCache = None - ): + def _import_mooncake_store(self): try: from mooncake.store import MooncakeDistributedStore + + return MooncakeDistributedStore except ImportError as e: raise ImportError( "Please install mooncake by following the instructions at " - "https://kvcache-ai.github.io/Mooncake/getting_started/build.html" + "https://kvcache-ai.github.io/Mooncake/getting_started/build.html " "to run SGLang with MooncakeConnector." ) from e + def _load_config(self, storage_config: Any = None): + extra_config = ( + getattr(storage_config, "extra_config", None) if storage_config else None + ) + + if extra_config and ( + extra_config.get("master_server_address") is not None + or extra_config.get("client_server_address") is not None + ): + config = MooncakeStoreConfig.load_from_extra_config(extra_config) + logger.info("Mooncake Configuration loaded from extra_config successfully.") + + elif envs.SGLANG_HICACHE_MOONCAKE_CONFIG_PATH.is_set(): + config = MooncakeStoreConfig.from_file() + logger.info("Mooncake Configuration loaded from file successfully.") + + else: + config = MooncakeStoreConfig.load_from_env() + logger.info("Mooncake Configuration loaded from env successfully.") + + return config + + def register_buffer(self, tensor: torch.Tensor): + if self.store is None: + raise RuntimeError("Mooncake store is not initialized.") + ptr = tensor.data_ptr() + size = tensor.numel() * tensor.element_size() + ret_code = self.store.register_buffer(ptr, size) + if ret_code != 0: + logger.error(f"Failed to register buffer, error code: {ret_code}") + raise RuntimeError( + f"Failed to register buffer to Mooncake Store, error code: {ret_code}" + ) + + +class MooncakeStore(HiCacheStorage, MooncakeBaseStore): + + def __init__( + self, storage_config: HiCacheStorageConfig = None, mem_pool: HostKVCache = None + ): + MooncakeBaseStore.__init__(self) + MooncakeDistributedStore = self._import_mooncake_store() try: self.store = MooncakeDistributedStore() + self.config = self._load_config(storage_config) extra_config = ( getattr(storage_config, "extra_config", None) if storage_config else None ) - # Load configuration with master_server_address prioritized from extra_config if available - if extra_config is not None and ( - extra_config.get("master_server_address") is not None - or extra_config.get("client_server_address") is not None - ): - # Load from extra_config - self.config = MooncakeStoreConfig.load_from_extra_config(extra_config) - logger.info( - "Mooncake Configuration loaded from extra_config successfully." - ) - elif envs.SGLANG_HICACHE_MOONCAKE_CONFIG_PATH.is_set(): - # Load from config file - self.config = MooncakeStoreConfig.from_file() - logger.info("Mooncake Configuration loaded from file successfully.") - else: - # Load from environment variables - self.config = MooncakeStoreConfig.load_from_env() - logger.info("Mooncake Configuration loaded from env successfully.") - tp_scale_factor = 1 if storage_config is None else storage_config.tp_size per_tp_global_segment_size = ( @@ -442,14 +469,7 @@ def register_mem_pool_host(self, mem_pool_host: HostKVCache): ], "mooncake store storage backend only support page first or page first direct layout" buffer = self.mem_pool_host.kv_buffer try: - buffer_ptr = buffer.data_ptr() - buffer_size = buffer.numel() * buffer.element_size() - ret_code = self.store.register_buffer(buffer_ptr, buffer_size) - if ret_code: - logger.error(f"Failed to register buffer, error code: {ret_code}") - raise RuntimeError( - f"Failed to register buffer to Mooncake Store, error code: {ret_code}" - ) + super().register_buffer(buffer) except TypeError as err: logger.error("Failed to register buffer to Mooncake Store: %s", err) raise TypeError("Mooncake Store Register Buffer Error.") from err diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index d15e5238f538..e5d308c8168c 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -688,6 +688,7 @@ class ServerArgs: mm_enable_dp_encoder: bool = False mm_process_config: Optional[Dict[str, Any]] = None limit_mm_data_per_request: Optional[Union[str, Dict[str, int]]] = None + enable_mm_global_cache: bool = False # For checkpoint decryption decrypted_config_file: Optional[str] = None @@ -4976,6 +4977,13 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Enable prefix multimodal cache. Currently only supports mm-only.", ) + parser.add_argument( + "--enable-mm-global-cache", + action="store_true", + default=ServerArgs.enable_mm_global_cache, + help="Enable global multimodal embedding cache to skip redundant ViT inference.", + ) + # For registering hooks parser.add_argument( "--forward-hooks",