Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 183 additions & 8 deletions python/sglang/srt/disaggregation/encode_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pickle
import time
import traceback
from typing import Dict, List, Optional, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple, Union

import aiohttp
import numpy as np
Expand Down Expand Up @@ -194,6 +194,17 @@ def __init__(
gpu_id=None,
ib_device=server_args.disaggregation_ib_device,
)
if getattr(self.server_args, "enable_mm_global_cache", False):
from sglang.srt.managers.embedding_cache_controller import (
EmbeddingCacheController,
)

self.mm_global_cache = EmbeddingCacheController(
rank, server_args.tp_size, hidden_dim=self.model_config.hidden_size
)
self.background_tasks = set()
Comment thread
liusy58 marked this conversation as resolved.
Outdated
else:
self.mm_global_cache = None

self.embedding_to_send = dict()

Expand Down Expand Up @@ -290,6 +301,159 @@ async def _flatten_and_load_images(self, mm_items):
async_futures = [asyncio.wrap_future(f) for f in futures]
return await asyncio.gather(*async_futures)

def get_num_patches(self, grid: Union[torch.Tensor, List[int]]) -> int:
"""Calculate number of raw patches (before 2x2 merge). Used for pixel_values slicing."""
return int(grid[0] * grid[1] * grid[2])
Comment thread
liusy58 marked this conversation as resolved.

def get_num_tokens(self, grid: Union[torch.Tensor, List[int]]) -> int:
"""Calculate number of tokens (after 2x2 merge). Used for mm_embedding slicing."""
merge_size = getattr(self.image_processor, "merge_size", 2)
return self.get_num_patches(grid) // (merge_size**2)

def slice_embedding(
self, mm_embedding: torch.Tensor, grid_thw: List
) -> List[torch.Tensor]:
"""Slice a concatenated embedding tensor into individual image embeddings."""
slices, offset = [], 0
for grid in grid_thw:
count = self.get_num_tokens(grid)
slices.append(mm_embedding[offset : offset + count])
offset += count
return slices

def _calculate_hashes_from_features(
self, pixel_values: torch.Tensor, grid_thw: List
) -> List[str]:
"""CPU Task: Compute hashes based on processed feature patches (pixel_values)."""
hashes, offset = [], 0
for grid in grid_thw:
num_patches = self.get_num_patches(grid)
feature_slice = pixel_values[offset : offset + num_patches]
tmp_item = MultimodalDataItem(
modality=Modality.IMAGE, feature=feature_slice
)
tmp_item.set_pad_value()
hashes.append(tmp_item.hash)
offset += num_patches
return hashes

async def _encode_missing(
self, pixel_values: torch.Tensor, images_input: dict, indices: List[int]
) -> List[torch.Tensor]:
"""
GPU Task: Run ViT inference ONLY on the subset of images missing from the cache.
"""
grid_thw = images_input["image_grid_thw"]

# 1. Slice pixel_values to get only the patches for missing images
sub_pixel_list = []
offsets = [0]
curr = 0
for g in grid_thw:
curr += self.get_num_patches(g)
offsets.append(curr)

for idx in indices:
sub_pixel_list.append(pixel_values[offsets[idx] : offsets[idx + 1]])

sub_feature = torch.cat(sub_pixel_list, dim=0)

mm_item = MultimodalDataItem.from_dict(
{
"modality": Modality.IMAGE,
"feature": _convert(sub_feature),
}
)

for k, v in images_input.items():
if k == "pixel_values":
continue
val = _convert(v)
if k in _image_grid_attrs:
mm_item.set(k, val[indices])
else:
mm_item.set(k, val)

with torch.inference_mode():
Comment thread
liusy58 marked this conversation as resolved.
new_embeddings = self.model.get_image_feature([mm_item]).cpu()
if new_embeddings.ndim != 2:
new_embeddings = new_embeddings.reshape(-1, new_embeddings.shape[-1])

sub_grids = [grid_thw[i] for i in indices]
return self.slice_embedding(new_embeddings, sub_grids)

async def encode_with_global_cache(
self,
mm_items,
req_id: str,
num_parts: int,
part_idx: int,
hashes: Optional[List[str]] = None,
) -> torch.Tensor:
Comment thread
liusy58 marked this conversation as resolved.
images = await self._flatten_and_load_images(mm_items)
kwargs = {"device": self.device} if self.use_image_processor_gpu else {}
images_input = self.image_processor(images=images, **kwargs)
pixel_values = images_input["pixel_values"]
grid_thw = images_input["image_grid_thw"]
if hashes is None:
image_hashes = self._calculate_hashes_from_features(pixel_values, grid_thw)
else:
image_hashes = hashes

exist_mask = await self.mm_global_cache.batch_is_exist(image_hashes)

missing_indices = [i for i, exist in enumerate(exist_mask) if not exist]
hit_indices = [i for i, exist in enumerate(exist_mask) if exist]

gpu_task = None
if missing_indices:
gpu_task = asyncio.create_task(
self._encode_missing(pixel_values, images_input, missing_indices)
)

if hit_indices:
hit_hashes = [image_hashes[i] for i in hit_indices]
hit_tokens = [self.get_num_tokens(grid_thw[i]) for i in hit_indices]
self.mm_global_cache.prefetch(req_id, hit_hashes, hit_tokens)

new_slices = await gpu_task if gpu_task else []

if hit_indices:
while not self.mm_global_cache.check_prefetch_progress(req_id):
Comment thread
liusy58 marked this conversation as resolved.
Outdated
await asyncio.sleep(0.001)
Comment thread
liusy58 marked this conversation as resolved.
Outdated
Comment thread
liusy58 marked this conversation as resolved.
Outdated

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to sleep for 0.01s. Is it for waiting IO?


cached_slices = (
self.mm_global_cache.get_embeddings([image_hashes[i] for i in hit_indices])
if hit_indices
else []
)
Comment thread
liusy58 marked this conversation as resolved.
Outdated

final_slices = [None] * len(image_hashes)
for i, idx in enumerate(missing_indices):
final_slices[idx] = new_slices[i]
for i, idx in enumerate(hit_indices):
final_slices[idx] = cached_slices[i]

mm_embedding = torch.cat(final_slices, dim=0)
if self.mm_global_cache and missing_indices:
new_hashes = [image_hashes[i] for i in missing_indices]

async def _background_insert():
await asyncio.to_thread(
self.mm_global_cache.insert_batch, new_hashes, new_slices
)

task = asyncio.create_task(_background_insert())
self.background_tasks.add(task)
task.add_done_callback(self.background_tasks.discard)

if self.rank == 0:
self.embedding_to_send[req_id] = EmbeddingData(
req_id, num_parts, part_idx, grid_thw, mm_embedding
)

return mm_embedding.nbytes, mm_embedding.shape[0], mm_embedding.shape[1]

async def _encode(self, mm_items) -> torch.Tensor:
images = await self._flatten_and_load_images(mm_items)

Expand Down Expand Up @@ -382,6 +546,9 @@ async def _send(
[pickle.dumps(new_mm_data), embedding_tensor.__buffer__()]
)

async def encode_with_hash(self, mm_items, req_id, num_parts, part_idx, hashes):
images = await self._flatten_and_load_images(mm_items)
Comment thread
liusy58 marked this conversation as resolved.
Comment thread
liusy58 marked this conversation as resolved.

async def encode(self, mm_items, req_id, num_parts, part_idx):
start_time = time.time()
image_grid_dim, mm_embedding = await self._encode(mm_items)
Expand Down Expand Up @@ -628,13 +795,21 @@ async def handle_encode_request(request: dict):
request.update({"enter_time": time.time()})
for socket in send_sockets:
socket.send_pyobj(request)

nbytes, embedding_len, embedding_dim = await encoder.encode(
mm_items=request["mm_items"],
req_id=request["req_id"],
num_parts=request["num_parts"],
part_idx=request["part_idx"],
)
if encoder.mm_global_cache is not None:
nbytes, embedding_len, embedding_dim = await encoder.encode_with_global_cache(
mm_items=request["mm_items"],
req_id=request["req_id"],
num_parts=request["num_parts"],
part_idx=request["part_idx"],
hashes=request.get("hashes", None),
)
else:
nbytes, embedding_len, embedding_dim = await encoder.encode(
mm_items=request["mm_items"],
req_id=request["req_id"],
num_parts=request["num_parts"],
part_idx=request["part_idx"],
)
if encoder.server_args.encoder_transfer_backend == "mooncake":
del request["mm_items"]
request.update(
Expand Down
Loading
Loading