Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
298 changes: 282 additions & 16 deletions python/sglang/srt/disaggregation/encode_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import time
import traceback
from http import HTTPStatus
from typing import Dict, List, Optional, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple, Union

import aiohttp
import numpy as np
Expand All @@ -26,6 +26,7 @@
from sglang.srt.disaggregation.encode_receiver import EmbeddingData
from sglang.srt.distributed.parallel_state import (
get_mooncake_transfer_engine,
get_tp_group,
init_distributed_environment,
initialize_model_parallel,
)
Expand Down Expand Up @@ -206,6 +207,22 @@ def __init__(
self.schedule_socket = get_zmq_socket(
self.context, zmq.PULL, schedule_path, True
)
self.background_tasks: Set[asyncio.Task] = set()

if self.server_args.enable_mm_global_cache:
from sglang.srt.mem_cache.storage.mooncake_store.embedding_cache_controller import (
EmbeddingCacheController,
)

self.mm_global_cache = EmbeddingCacheController(
rank,
server_args.tp_size,
hidden_dim=self.model_config.hidden_size,
tp_group=get_tp_group().cpu_group,
all_rank_get=False,
)
else:
self.mm_global_cache = None

if self.rank == 0:
logger.info(
Expand All @@ -214,10 +231,10 @@ def __init__(

if self.server_args.encoder_transfer_backend == "mooncake":
self.local_ip = get_local_ip_auto()

self.engine = get_mooncake_transfer_engine()

self.embedding_to_send = dict()
self.background_tasks: Set[asyncio.Task] = set()

logger.info(f"rank {rank} init finish ")

Expand Down Expand Up @@ -312,6 +329,233 @@ async def _flatten_and_load_images(self, mm_items):
async_futures = [asyncio.wrap_future(f) for f in futures]
return await asyncio.gather(*async_futures)

def get_num_patches(self, grid: Union[torch.Tensor, List[int]]) -> int:
"""Calculate number of raw patches (before 2x2 merge). Used for pixel_values slicing."""
return int(grid[0] * grid[1] * grid[2])
Comment thread
liusy58 marked this conversation as resolved.

def get_num_tokens(self, grid: Union[torch.Tensor, List[int]]) -> int:
"""Calculate number of tokens (after 2x2 merge). Used for mm_embedding slicing."""
merge_size = getattr(self.image_processor, "merge_size", 2)
return self.get_num_patches(grid) // (merge_size**2)

def slice_embedding(
self, mm_embedding: torch.Tensor, grid_thw: List
) -> List[torch.Tensor]:
"""Slice a concatenated embedding tensor into individual image embeddings."""
slices, offset = [], 0
for grid in grid_thw:
count = self.get_num_tokens(grid)
slices.append(mm_embedding[offset : offset + count])
offset += count
return slices

def _calculate_hashes_from_features(
self, pixel_values: torch.Tensor, grid_thw: List
) -> List[str]:
"""CPU Task: Compute hashes based on processed feature patches (pixel_values)."""
hashes, offset = [], 0
for grid in grid_thw:
num_patches = self.get_num_patches(grid)
feature_slice = pixel_values[offset : offset + num_patches]
tmp_item = MultimodalDataItem(
modality=Modality.IMAGE, feature=feature_slice
)
tmp_item.set_pad_value()
hashes.append(tmp_item.hash)
offset += num_patches
return hashes

async def _encode_missing(
self, pixel_values: torch.Tensor, images_input: dict, indices: List[int]
) -> List[torch.Tensor]:
"""
GPU Task: Run ViT inference ONLY on the subset of images missing from the cache.
"""
grid_thw = images_input["image_grid_thw"]

# 1. Slice pixel_values to get only the patches for missing images
sub_pixel_list = []
offsets = [0]
curr = 0
for g in grid_thw:
curr += self.get_num_patches(g)
offsets.append(curr)

for idx in indices:
sub_pixel_list.append(pixel_values[offsets[idx] : offsets[idx + 1]])

sub_feature = torch.cat(sub_pixel_list, dim=0)

mm_item = MultimodalDataItem.from_dict(
{
"modality": Modality.IMAGE,
"feature": _convert(sub_feature),
}
)

for k, v in images_input.items():
if k == "pixel_values":
continue
val = _convert(v)
if k in _image_grid_attrs:
mm_item.set(k, val[indices])
else:
mm_item.set(k, val)

with torch.inference_mode():
Comment thread
liusy58 marked this conversation as resolved.
new_embeddings = self.model.get_image_feature([mm_item]).cpu()
if new_embeddings.ndim != 2:
new_embeddings = new_embeddings.reshape(-1, new_embeddings.shape[-1])

sub_grids = [grid_thw[i] for i in indices]
return self.slice_embedding(new_embeddings, sub_grids)

async def encode_with_global_cache(
self,
mm_items,
req_id: str,
num_parts: int,
part_idx: int,
hashes: Optional[List[str]] = None,
) -> torch.Tensor:
Comment thread
liusy58 marked this conversation as resolved.
images = await self._flatten_and_load_images(mm_items)
kwargs = {"device": self.device} if self.use_image_processor_gpu else {}
images_input = self.image_processor(images=images, **kwargs)
pixel_values = images_input["pixel_values"]
grid_thw = images_input["image_grid_thw"]
num_images = len(grid_thw)

# Step 1: Rank 0 checks global cache and broadcasts hit/miss mask to all ranks.
if self.rank == 0:
if hashes is None:
image_hashes = self._calculate_hashes_from_features(
pixel_values, grid_thw
)
else:
image_hashes = hashes
exist_mask = await self.mm_global_cache.batch_is_exist(image_hashes)
mask_tensor = torch.tensor(
[1 if e else 0 for e in exist_mask], dtype=torch.int32
)
else:
image_hashes = None
mask_tensor = torch.zeros(num_images, dtype=torch.int32)

if self.server_args.tp_size > 1:
torch.distributed.broadcast(
mask_tensor,
src=0,
group=self.mm_global_cache.prefetch_tp_group,
)

exist_mask = [m.item() == 1 for m in mask_tensor]
missing_indices = [i for i, e in enumerate(exist_mask) if not e]
hit_indices = [i for i, e in enumerate(exist_mask) if e]

# Step 2: All ranks run ViT together on cache-miss images.
new_slices = []
if missing_indices:
new_slices = await self._encode_missing(
pixel_values, images_input, missing_indices
)

# Step 3: Rank 0 prefetches cache-hit embeddings from global cache.
prefetch_status = torch.tensor([1], dtype=torch.int32)

if self.rank == 0:
if hit_indices:
hit_hashes = [image_hashes[i] for i in hit_indices]
hit_tokens = [self.get_num_tokens(grid_thw[i]) for i in hit_indices]
self.mm_global_cache.prefetch(req_id, hit_hashes, hit_tokens)

try:

async def _wait_prefetch():
while not self.mm_global_cache.check_prefetch_progress(req_id):
await asyncio.sleep(0.005)

await asyncio.wait_for(_wait_prefetch(), timeout=60.0)
except (asyncio.TimeoutError, Exception) as e:
logger.error(
f"Prefetch failed for req {req_id}: {e}. "
f"Falling back to ViT for {len(hit_indices)} hit images."
)
prefetch_status[0] = 0

# Step 4: Broadcast prefetch result to all ranks so they stay in sync.
if self.server_args.tp_size > 1:
torch.distributed.broadcast(
prefetch_status,
src=0,
group=self.mm_global_cache.prefetch_tp_group,
)

# Step 5: If prefetch failed, all ranks fallback to ViT for the hit images.
if prefetch_status.item() == 0 and hit_indices:
logger.info(
f"Req {req_id}: Prefetch failed, all ranks running ViT fallback "
f"for {len(hit_indices)} images."
)
fallback_slices = await self._encode_missing(
pixel_values, images_input, hit_indices
)
else:
fallback_slices = None

# Step 6: Rank 0 assembles final embedding and prepares for sending.
if self.rank == 0:
final_slices = [None] * num_images

for i, idx in enumerate(missing_indices):
final_slices[idx] = new_slices[i]

# Fill in cache-hit embeddings (from prefetch or fallback)
if prefetch_status.item() == 1 and hit_indices:
cached_slices = self.mm_global_cache.get_embeddings(
[image_hashes[i] for i in hit_indices]
)
for i, idx in enumerate(hit_indices):
final_slices[idx] = cached_slices[i]
elif fallback_slices is not None:
for i, idx in enumerate(hit_indices):
final_slices[idx] = fallback_slices[i]

mm_embedding = torch.cat(final_slices, dim=0)

# Background insert: store newly computed embeddings into global cache.
# Includes both original misses and fallback-recomputed hits.
all_new_hashes = [image_hashes[i] for i in missing_indices]
all_new_slices = list(new_slices)
if fallback_slices is not None:
all_new_hashes += [image_hashes[i] for i in hit_indices]
all_new_slices += list(fallback_slices)

if all_new_hashes:

async def _background_insert():
await asyncio.to_thread(
self.mm_global_cache.insert_batch,
all_new_hashes,
all_new_slices,
)

task = asyncio.create_task(_background_insert())
self.background_tasks.add(task)
task.add_done_callback(self.background_tasks.discard)

self.embedding_to_send[req_id] = EmbeddingData(
req_id, num_parts, part_idx, grid_thw, mm_embedding
)
return (
mm_embedding.nbytes,
mm_embedding.shape[0],
mm_embedding.shape[1],
None,
None,
)
else:
return (0, 0, 0, None, None)

async def _encode(self, mm_items) -> torch.Tensor:
try:
images = await self._flatten_and_load_images(mm_items)
Expand Down Expand Up @@ -421,6 +665,9 @@ def send_with_socket():

await asyncio.get_event_loop().run_in_executor(self.executor, send_with_socket)

async def encode_with_hash(self, mm_items, req_id, num_parts, part_idx, hashes):
images = await self._flatten_and_load_images(mm_items)
Comment thread
liusy58 marked this conversation as resolved.
Comment thread
liusy58 marked this conversation as resolved.

async def encode(self, mm_items, req_id, num_parts, part_idx):
try:
image_grid_dim, mm_embedding = await self._encode(mm_items)
Expand Down Expand Up @@ -644,12 +891,21 @@ async def run_encoder(
else:
encoder.profiler.stop()
else:
await encoder.encode(
mm_items=request["mm_items"],
req_id=request["req_id"],
num_parts=request["num_parts"],
part_idx=request["part_idx"],
)
if encoder.mm_global_cache is not None:
await encoder.encode_with_global_cache(
mm_items=request["mm_items"],
req_id=request["req_id"],
num_parts=request["num_parts"],
part_idx=request["part_idx"],
hashes=request.get("hashes", None),
)
else:
await encoder.encode(
mm_items=request["mm_items"],
req_id=request["req_id"],
num_parts=request["num_parts"],
part_idx=request["part_idx"],
)


def launch_encoder(server_args, schedule_path, dist_init_method, rank):
Expand Down Expand Up @@ -706,15 +962,25 @@ def start_background_send(req_id):
request.update({"enter_time": time.time()})
for socket in send_sockets:
socket.send_pyobj(request)

nbytes, embedding_len, embedding_dim, error_msg, error_code = (
await encoder.encode(
mm_items=request["mm_items"],
req_id=request["req_id"],
num_parts=request["num_parts"],
part_idx=request["part_idx"],
if encoder.mm_global_cache is not None:
nbytes, embedding_len, embedding_dim, error_msg, error_code = (
await encoder.encode_with_global_cache(
mm_items=request["mm_items"],
req_id=request["req_id"],
num_parts=request["num_parts"],
part_idx=request["part_idx"],
hashes=request.get("hashes", None),
)
)
else:
nbytes, embedding_len, embedding_dim, error_msg, error_code = (
await encoder.encode(
mm_items=request["mm_items"],
req_id=request["req_id"],
num_parts=request["num_parts"],
part_idx=request["part_idx"],
)
)
)

if error_msg:
if encoder.server_args.encoder_transfer_backend == "zmq_to_scheduler":
Expand Down
Loading
Loading