-
Notifications
You must be signed in to change notification settings - Fork 6.9k
Encoder Global Cache Manager #16137
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Encoder Global Cache Manager #16137
Changes from 5 commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
4654ca6
global cache
liusy58 586339e
Merge branch 'main' into global_cache
liusy58 42a3f5a
fix
liusy58 4bbce2f
add
liusy58 666e5ba
Merge branch 'main' into global_cache
ZhengWG 2cb3052
Merge branch 'main' of https://github.com/sgl-project/sglang into glo…
liusy58 1572ec8
fix conflict
liusy58 4e5c9b7
fix lint
liusy58 6da23cf
fix
liusy58 6d89821
Merge branch 'main' into global_cache
liusy58 6b14dff
fix lint
liusy58 6b57c9e
fix
liusy58 4c60049
Merge branch 'main' into global_cache
stmatengss aca2cb7
Merge branch 'main' into global_cache
liusy58 4feadac
fix
liusy58 7000ff3
Merge branch 'global_cache' of github.com:liusy58/sglang into global_…
liusy58 bda58a8
fix
liusy58 17f133b
fix
liusy58 b78d6c2
Merge branch 'main' into global_cache
liusy58 571e518
clean code
liusy58 09ef470
clean code
liusy58 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,7 +7,7 @@ | |
| import pickle | ||
| import time | ||
| import traceback | ||
| from typing import Dict, List, Optional, Set, Tuple | ||
| from typing import Dict, List, Optional, Set, Tuple, Union | ||
|
|
||
| import aiohttp | ||
| import numpy as np | ||
|
|
@@ -194,6 +194,17 @@ def __init__( | |
| gpu_id=None, | ||
| ib_device=server_args.disaggregation_ib_device, | ||
| ) | ||
| if getattr(self.server_args, "enable_mm_global_cache", False): | ||
| from sglang.srt.managers.embedding_cache_controller import ( | ||
| EmbeddingCacheController, | ||
| ) | ||
|
|
||
| self.mm_global_cache = EmbeddingCacheController( | ||
| rank, server_args.tp_size, hidden_dim=self.model_config.hidden_size | ||
| ) | ||
| self.background_tasks = set() | ||
| else: | ||
| self.mm_global_cache = None | ||
|
|
||
| self.embedding_to_send = dict() | ||
|
|
||
|
|
@@ -290,6 +301,159 @@ async def _flatten_and_load_images(self, mm_items): | |
| async_futures = [asyncio.wrap_future(f) for f in futures] | ||
| return await asyncio.gather(*async_futures) | ||
|
|
||
def get_num_patches(self, grid: Union[torch.Tensor, List[int]]) -> int:
    """Return the raw patch count for one grid entry (before the 2x2 merge).

    The count is the product of the first three values of ``grid`` and is the
    number of rows this image occupies when slicing ``pixel_values``.
    """
    first, second, third = grid[0], grid[1], grid[2]
    return int(first * second * third)
|
liusy58 marked this conversation as resolved.
|
||
|
|
||
def get_num_tokens(self, grid: Union[torch.Tensor, List[int]]) -> int:
    """Return the token count for one grid entry (after the merge step).

    The raw patch count is floor-divided by ``merge_size ** 2``; ``merge_size``
    is read from ``self.image_processor`` and falls back to 2 when the
    processor has no such attribute. Used when slicing ``mm_embedding``.
    """
    merge_size = getattr(self.image_processor, "merge_size", 2)
    raw_patches = self.get_num_patches(grid)
    patches_per_token = merge_size**2
    return raw_patches // patches_per_token
|
|
||
def slice_embedding(
    self, mm_embedding: torch.Tensor, grid_thw: List
) -> List[torch.Tensor]:
    """Split a concatenated embedding tensor into per-image pieces.

    ``grid_thw`` is walked in order; each entry claims the next
    ``get_num_tokens(grid)`` rows of ``mm_embedding``.

    Returns:
        One slice of ``mm_embedding`` per entry of ``grid_thw``, in order.
    """
    per_image = []
    start = 0
    for grid in grid_thw:
        stop = start + self.get_num_tokens(grid)
        per_image.append(mm_embedding[start:stop])
        start = stop
    return per_image
|
|
||
def _calculate_hashes_from_features(
    self, pixel_values: torch.Tensor, grid_thw: List
) -> List[str]:
    """CPU task: hash each image from its slice of processed patches.

    ``pixel_values`` is cut into consecutive row ranges, one per entry of
    ``grid_thw`` (sized by ``get_num_patches``). Each range is wrapped in a
    temporary image ``MultimodalDataItem`` whose ``hash`` is collected.

    Returns:
        One hash per entry of ``grid_thw``, in order.
    """
    image_hashes = []
    start = 0
    for grid in grid_thw:
        stop = start + self.get_num_patches(grid)
        probe_item = MultimodalDataItem(
            modality=Modality.IMAGE, feature=pixel_values[start:stop]
        )
        # set_pad_value() runs before .hash is read; presumably it is what
        # populates the hash -- confirm against MultimodalDataItem.
        probe_item.set_pad_value()
        image_hashes.append(probe_item.hash)
        start = stop
    return image_hashes
|
|
||
async def _encode_missing(
    self, pixel_values: torch.Tensor, images_input: dict, indices: List[int]
) -> List[torch.Tensor]:
    """GPU task: run ViT inference ONLY on the images missing from the cache.

    Args:
        pixel_values: Patch rows of every image, laid out back to back in the
            order of ``images_input["image_grid_thw"]``.
        images_input: Image-processor output; every key except
            ``"pixel_values"`` is copied onto the item handed to the model.
        indices: Positions (into ``image_grid_thw``) of the images to encode.

    Returns:
        One embedding slice per entry of ``indices``, in the same order.
    """
    grid_thw = images_input["image_grid_thw"]

    # 1. Prefix sums of patch counts: image i owns rows offsets[i]:offsets[i + 1].
    offsets = [0]
    for grid in grid_thw:
        offsets.append(offsets[-1] + self.get_num_patches(grid))

    # 2. Keep only the patch rows that belong to the requested images.
    missing_pixels = [pixel_values[offsets[i] : offsets[i + 1]] for i in indices]
    sub_feature = torch.cat(missing_pixels, dim=0)

    mm_item = MultimodalDataItem.from_dict(
        {
            "modality": Modality.IMAGE,
            "feature": _convert(sub_feature),
        }
    )

    # 3. Copy the remaining processor outputs; per-image grid attributes are
    #    narrowed to the requested images, everything else is passed whole.
    for key, value in images_input.items():
        if key == "pixel_values":
            continue
        converted = _convert(value)
        if key in _image_grid_attrs:
            converted = converted[indices]
        mm_item.set(key, converted)

    with torch.inference_mode():
        new_embeddings = self.model.get_image_feature([mm_item]).cpu()
        # Flatten to 2-D so the result can be sliced row-wise per image.
        if new_embeddings.ndim != 2:
            new_embeddings = new_embeddings.reshape(-1, new_embeddings.shape[-1])

    requested_grids = [grid_thw[i] for i in indices]
    return self.slice_embedding(new_embeddings, requested_grids)
|
|
||
async def encode_with_global_cache(
    self,
    mm_items,
    req_id: str,
    num_parts: int,
    part_idx: int,
    hashes: Optional[List[str]] = None,
) -> Tuple[int, int, int]:
    """Encode images, reusing embeddings already held by the global cache.

    Images whose hash is reported present by ``self.mm_global_cache`` are
    prefetched from the cache; the rest are encoded via ``_encode_missing``
    and then inserted into the cache by a background task. The per-image
    slices are reassembled in the original image order.

    Args:
        mm_items: Multimodal items passed to ``_flatten_and_load_images``.
        req_id: Request id; keys the prefetch and ``self.embedding_to_send``.
        num_parts: Stored in the resulting ``EmbeddingData``.
        part_idx: Stored in the resulting ``EmbeddingData``.
        hashes: Optional precomputed per-image hashes. When ``None`` they are
            derived from the processed ``pixel_values``.

    Returns:
        ``(nbytes, num_rows, embedding_dim)`` of the assembled embedding.
    """
    images = await self._flatten_and_load_images(mm_items)
    kwargs = {"device": self.device} if self.use_image_processor_gpu else {}
    images_input = self.image_processor(images=images, **kwargs)
    pixel_values = images_input["pixel_values"]
    grid_thw = images_input["image_grid_thw"]

    # NOTE(review): caller-supplied hashes are indexed in lockstep with
    # grid_thw below; nothing here checks that the lengths agree.
    if hashes is None:
        image_hashes = self._calculate_hashes_from_features(pixel_values, grid_thw)
    else:
        image_hashes = hashes

    exist_mask = await self.mm_global_cache.batch_is_exist(image_hashes)
    missing_indices = [i for i, exist in enumerate(exist_mask) if not exist]
    hit_indices = [i for i, exist in enumerate(exist_mask) if exist]
    hit_hashes = [image_hashes[i] for i in hit_indices]

    # Schedule encoding of the cache misses before starting the prefetch.
    gpu_task = None
    if missing_indices:
        gpu_task = asyncio.create_task(
            self._encode_missing(pixel_values, images_input, missing_indices)
        )

    if hit_indices:
        hit_tokens = [self.get_num_tokens(grid_thw[i]) for i in hit_indices]
        self.mm_global_cache.prefetch(req_id, hit_hashes, hit_tokens)

    new_slices = await gpu_task if gpu_task is not None else []

    cached_slices = []
    if hit_indices:
        # Poll until the prefetch reports completion. The 1 ms sleep yields to
        # the event loop between checks -- presumably while cache I/O is in
        # flight (confirm). There is no timeout on this loop.
        while not self.mm_global_cache.check_prefetch_progress(req_id):
            await asyncio.sleep(0.001)
        cached_slices = self.mm_global_cache.get_embeddings(hit_hashes)

    # Put freshly encoded and cached slices back into original image order.
    final_slices = [None] * len(image_hashes)
    for idx, embedding in zip(missing_indices, new_slices):
        final_slices[idx] = embedding
    for idx, embedding in zip(hit_indices, cached_slices):
        final_slices[idx] = embedding

    mm_embedding = torch.cat(final_slices, dim=0)

    if missing_indices:
        new_hashes = [image_hashes[i] for i in missing_indices]

        async def _background_insert():
            await asyncio.to_thread(
                self.mm_global_cache.insert_batch, new_hashes, new_slices
            )

        # Hold a strong reference until the task finishes so it is not
        # garbage-collected while still pending.
        task = asyncio.create_task(_background_insert())
        self.background_tasks.add(task)
        task.add_done_callback(self.background_tasks.discard)

    if self.rank == 0:
        self.embedding_to_send[req_id] = EmbeddingData(
            req_id, num_parts, part_idx, grid_thw, mm_embedding
        )

    return mm_embedding.nbytes, mm_embedding.shape[0], mm_embedding.shape[1]
|
|
||
| async def _encode(self, mm_items) -> torch.Tensor: | ||
| images = await self._flatten_and_load_images(mm_items) | ||
|
|
||
|
|
@@ -382,6 +546,9 @@ async def _send( | |
| [pickle.dumps(new_mm_data), embedding_tensor.__buffer__()] | ||
| ) | ||
|
|
||
| async def encode_with_hash(self, mm_items, req_id, num_parts, part_idx, hashes): | ||
| images = await self._flatten_and_load_images(mm_items) | ||
|
liusy58 marked this conversation as resolved.
liusy58 marked this conversation as resolved.
|
||
|
|
||
| async def encode(self, mm_items, req_id, num_parts, part_idx): | ||
| start_time = time.time() | ||
| image_grid_dim, mm_embedding = await self._encode(mm_items) | ||
|
|
@@ -628,13 +795,21 @@ async def handle_encode_request(request: dict): | |
| request.update({"enter_time": time.time()}) | ||
| for socket in send_sockets: | ||
| socket.send_pyobj(request) | ||
|
|
||
| nbytes, embedding_len, embedding_dim = await encoder.encode( | ||
| mm_items=request["mm_items"], | ||
| req_id=request["req_id"], | ||
| num_parts=request["num_parts"], | ||
| part_idx=request["part_idx"], | ||
| ) | ||
| if encoder.mm_global_cache is not None: | ||
| nbytes, embedding_len, embedding_dim = await encoder.encode_with_global_cache( | ||
| mm_items=request["mm_items"], | ||
| req_id=request["req_id"], | ||
| num_parts=request["num_parts"], | ||
| part_idx=request["part_idx"], | ||
| hashes=request.get("hashes", None), | ||
| ) | ||
| else: | ||
| nbytes, embedding_len, embedding_dim = await encoder.encode( | ||
| mm_items=request["mm_items"], | ||
| req_id=request["req_id"], | ||
| num_parts=request["num_parts"], | ||
| part_idx=request["part_idx"], | ||
| ) | ||
| if encoder.server_args.encoder_transfer_backend == "mooncake": | ||
| del request["mm_items"] | ||
| request.update( | ||
|
|
||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.