368 changes: 265 additions & 103 deletions python/sglang/srt/disaggregation/encode_receiver.py

Large diffs are not rendered by default.

117 changes: 94 additions & 23 deletions python/sglang/srt/disaggregation/encode_server.py
@@ -17,7 +17,7 @@
import zmq.asyncio
from fastapi import FastAPI
from fastapi.responses import ORJSONResponse, Response
from transformers import AutoImageProcessor
from transformers import AutoProcessor

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
@@ -33,6 +33,7 @@
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.mem_cache.multimodal_cache import EmbeddingResult, MultiModalStaticCache
from sglang.srt.model_loader import get_model
from sglang.srt.multimodal.processors.qwen_vl import preprocess_video
from sglang.srt.server_args import (
PortArgs,
ServerArgs,
@@ -94,15 +95,18 @@ def _convert(data):
return data


_image_grid_attrs = ["image_grid_thw", "image_grid_hws"]
_vision_grid_attrs = {
Modality.IMAGE: ["image_grid_thw", "image_grid_hws"],
Modality.VIDEO: ["video_grid_thw"],
}


def _get_image_grid_dim(images_input):
for attr in _image_grid_attrs:
if attr in images_input:
return images_input[attr]
def _get_vision_grid_dim(mm_inputs, modality):
for attr in _vision_grid_attrs[modality]:
if attr in mm_inputs:
return mm_inputs[attr]
raise ValueError(
f"Image grid dim ({_image_grid_attrs}) not found in {images_input}"
f"Vision grid dim ({_vision_grid_attrs[modality]}) not found in {mm_inputs}"
)
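
The new helper is a first-match lookup over modality-specific grid keys. A self-contained sketch of the same behavior (the `Modality` stand-in and the sample grid value are illustrative, not sglang's actual enum or processor output):

```python
from enum import Enum


class Modality(Enum):  # stand-in for sglang's Modality enum
    IMAGE = "image"
    VIDEO = "video"


_vision_grid_attrs = {
    Modality.IMAGE: ["image_grid_thw", "image_grid_hws"],
    Modality.VIDEO: ["video_grid_thw"],
}


def _get_vision_grid_dim(mm_inputs, modality):
    # Return the first grid attribute present for this modality.
    for attr in _vision_grid_attrs[modality]:
        if attr in mm_inputs:
            return mm_inputs[attr]
    raise ValueError(
        f"Vision grid dim ({_vision_grid_attrs[modality]}) not found in {mm_inputs}"
    )


# Hypothetical processor output: one (t, h, w) grid per video.
print(_get_vision_grid_dim({"video_grid_thw": [[8, 36, 64]]}, Modality.VIDEO))
```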


@@ -120,15 +124,33 @@ def __init__(
self.rank = rank
self.profiler = EncoderProfiler(rank)

self.image_processor = AutoImageProcessor.from_pretrained(
server_args.model_path,
processor_path = server_args.tokenizer_path or server_args.model_path
self.mm_processor = AutoProcessor.from_pretrained(
processor_path,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
use_fast=True,
)

self.model_config = ModelConfig.from_server_args(
server_args,
)
vision_config = (
server_args.mm_process_config.get("vision_config", {})
if server_args.mm_process_config is not None
else {}
)
if not vision_config:
# keep the same default values as qwen_vl.py
self.video_config = {
"fps": 2.0,
"max_frames": 768,
"min_frames": 4,
}
else:
self.video_config = vision_config
Comment on lines +143 to +151

Collaborator: Should this be extended for more models in the future?

Collaborator (Author): Sure, it just needs some time to align the preprocessing details of different models. I'll do it ASAP.

# default to cuda
self.video_config["device"] = "cuda"

self.load_config = LoadConfig(
load_format=server_args.load_format,
@@ -138,6 +160,9 @@
remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
)
self.model_type = getattr(
self.model_config.hf_config, "model_type", "unknown"
).lower()

self.device = server_args.device
self.gpu_id = server_args.base_gpu_id + rank
@@ -290,22 +315,66 @@ async def _flatten_and_load_images(self, mm_items):
async_futures = [asyncio.wrap_future(f) for f in futures]
return await asyncio.gather(*async_futures)

async def _encode(self, mm_items) -> torch.Tensor:
images = await self._flatten_and_load_images(mm_items)
async def _process_mm_items(self, mm_items, modality):
if modality == Modality.IMAGE:
images = await self._flatten_and_load_images(mm_items)
kwargs = {"device": self.device} if self.use_image_processor_gpu else {}
processor_input = self.mm_processor.image_processor(images=images, **kwargs)
feature = processor_input["pixel_values"]
get_feature_method = self.model.get_image_feature
elif modality == Modality.VIDEO:
# mainly follows qwen_vl.py: video processing is only supported for qwen-series models
if "qwen" not in self.model_type:
raise ValueError(
f"Video modality processing is currently only supported for Qwen series models with EPD enabled, "
f"but got model_type: {self.model_type}"
)
# Load videos concurrently
futures, _ = self.submit_data_loading_tasks(
mm_items, [Modality.VIDEO] * len(mm_items)
)
async_futures = [asyncio.wrap_future(f) for f in futures]
video_items = await asyncio.gather(*async_futures)

videos_processed = [
await preprocess_video(video, video_config=self.video_config)
for video in video_items
]
videos, video_metadata = map(list, zip(*videos_processed))

# pass preprocessed videos to processor with do_sample_frames=False
video_processor_kwargs = {}
video_processor_kwargs["do_sample_frames"] = False
for key in self.video_config:
video_processor_kwargs[key] = self.video_config[key]
if video_metadata:
video_processor_kwargs["video_metadata"] = video_metadata
processor_input = self.mm_processor.video_processor(
videos=videos, **video_processor_kwargs
)
feature = processor_input["pixel_values_videos"]
get_feature_method = self.model.get_video_feature
else:
raise ValueError(
f"Currently only support image and video modalities, but got {modality}"
)

kwargs = {"device": self.device} if self.use_image_processor_gpu else {}
images_input = self.image_processor(images=images, **kwargs)
feature = images_input["pixel_values"]
mm_item = MultimodalDataItem.from_dict(
{
"modality": Modality.IMAGE,
"modality": modality,
"feature": _convert(feature),
}
)
for k, v in images_input.items():
if k == "pixel_values":
for k, v in processor_input.items():
if k in ["pixel_values", "pixel_values_videos"]:
continue
mm_item.set(k, _convert(v))
return processor_input, mm_item, get_feature_method
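
`_process_mm_items` now returns a triple — the raw processor output, a `MultimodalDataItem` carrying every key except the heavy pixel tensors, and the model method that turns items into embeddings — so `_encode` stays modality-agnostic. A stripped-down sketch of that dispatch shape (the types and processor callables below are stand-ins, not the real sglang/transformers APIs):

```python
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class Item:  # stand-in for MultimodalDataItem
    modality: str
    attrs: dict = field(default_factory=dict)


def process(mm_items, modality, processors, feature_fns):
    # 1) run the modality-specific processor
    out: dict[str, Any] = processors[modality](mm_items)
    # 2) copy every key except the pixel tensors onto the item
    item = Item(modality)
    for k, v in out.items():
        if k not in ("pixel_values", "pixel_values_videos"):
            item.attrs[k] = v
    # 3) hand back the matching feature extractor
    return out, item, feature_fns[modality]


processors = {"image": lambda x: {"pixel_values": x, "image_grid_thw": [[1, 2, 2]]}}
feature_fns: dict[str, Callable] = {"image": lambda items: f"embeddings for {items}"}
out, item, fn = process([0.5], "image", processors, feature_fns)
print(item.attrs)  # {'image_grid_thw': [[1, 2, 2]]}
```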

async def _encode(self, mm_items, modality: Modality) -> torch.Tensor:
mm_inputs, mm_item, get_feature_fn = await self._process_mm_items(
mm_items, modality
)

# support mm_cache
mm_embedding = None
@@ -322,7 +391,7 @@ async def _encode(self, mm_items) -> torch.Tensor:

if mm_embedding is None:
with torch.inference_mode():
mm_embedding: torch.Tensor = self.model.get_image_feature([mm_item])
mm_embedding: torch.Tensor = get_feature_fn([mm_item])
mm_embedding = mm_embedding.cpu()
if len(mm_embedding.shape) != 2:
mm_embedding = mm_embedding.reshape(-1, mm_embedding.shape[-1])
@@ -337,7 +406,7 @@
if self.profiler is not None:
self.profiler.step()

return _get_image_grid_dim(images_input), mm_embedding
return _get_vision_grid_dim(mm_inputs, modality), mm_embedding
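
Whatever shape the model returns, `_encode` normalizes the embedding to a 2-D `(num_tokens, hidden_dim)` tensor before caching and sending; the reshape in the hunk above is the usual flatten-all-leading-dims trick:

```python
import torch

# e.g. a per-item batched output: (num_items, tokens_per_item, hidden_dim)
mm_embedding = torch.randn(2, 3, 4096)
if mm_embedding.dim() != 2:
    mm_embedding = mm_embedding.reshape(-1, mm_embedding.shape[-1])
assert mm_embedding.shape == (6, 4096)
```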

async def _send(
self,
@@ -357,7 +426,6 @@ async def _send(
self.engine.deregister(embedding.data_ptr())

mm_data.embedding = None
mm_data.embedding_list[mm_data.part_idx] = None

# Send ack/data
endpoint = (
@@ -382,17 +450,18 @@
[pickle.dumps(new_mm_data), embedding_tensor.__buffer__()]
)

async def encode(self, mm_items, req_id, num_parts, part_idx):
async def encode(self, mm_items, modality: Modality, req_id, num_parts, part_idx):
start_time = time.time()
image_grid_dim, mm_embedding = await self._encode(mm_items)
grid_dim, mm_embedding = await self._encode(mm_items, modality)
end_time = time.time()
logger.info(f"🕛 encode cost = {(end_time - start_time) * 1000:.2f}ms")
if self.rank == 0:
mm_data = EmbeddingData(
req_id,
num_parts,
part_idx,
image_grid_dim,
grid_dim,
modality,
mm_embedding,
)
self.embedding_to_send[mm_data.req_id] = mm_data
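
The send path ships metadata and the raw tensor as one ZeroMQ multipart message: frame 0 is the pickled `EmbeddingData` (with `embedding` cleared so the pickle stays small), frame 1 is the tensor's buffer. A minimal sketch of a matching receive side, assuming that layout and assuming dtype/shape travel in the metadata (the `meta.dtype`/`meta.shape` attribute names are illustrative, not sglang's):

```python
import pickle

import numpy as np
import zmq


def recv_embedding(socket: zmq.Socket):
    # Frame 0: pickled metadata; frame 1: raw tensor bytes.
    meta_frame, tensor_frame = socket.recv_multipart()
    meta = pickle.loads(meta_frame)
    # Reconstruct the tensor from the raw bytes; dtype and shape must be
    # carried in the metadata (illustrative attribute names).
    arr = np.frombuffer(tensor_frame, dtype=meta.dtype).reshape(meta.shape)
    return meta, arr
```
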
@@ -583,6 +652,7 @@ async def run_encoder(
else:
await encoder.encode(
mm_items=request["mm_items"],
modality=Modality.from_str(request["modality"]),
req_id=request["req_id"],
num_parts=request["num_parts"],
part_idx=request["part_idx"],
@@ -631,6 +701,7 @@ async def handle_encode_request(request: dict):

nbytes, embedding_len, embedding_dim = await encoder.encode(
mm_items=request["mm_items"],
modality=Modality.from_str(request["modality"]),
req_id=request["req_id"],
num_parts=request["num_parts"],
part_idx=request["part_idx"],
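
Both entry points now require the request dict to name its modality explicitly, converted with `Modality.from_str`. A hypothetical payload matching the keys read in the handlers above (all values made up):

```python
request = {
    "req_id": "req-42",             # request identifier
    "modality": "video",            # parsed via Modality.from_str(...)
    "mm_items": ["/tmp/clip.mp4"],  # multimodal items to load and encode
    "num_parts": 1,                 # total embedding parts for this request
    "part_idx": 0,                  # which part this encoder handles
}
```
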
8 changes: 4 additions & 4 deletions python/sglang/srt/managers/io_struct.py
@@ -262,8 +262,8 @@ class GenerateReqInput(BaseReq, APIServingTimingMixin):
external_trace_header: Optional[Dict] = None

# For EPD-disaggregated inference
need_wait_for_image: Optional[bool] = None
num_items_assigned: Optional[List] = None
need_wait_for_mm_inputs: Optional[bool] = None
num_items_assigned: Optional[Dict[Union[str, int], List[int]]] = None

# Multimodal tiling controls (extensions)
max_dynamic_patch: Optional[int] = None
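
The `num_items_assigned` field above widens from a flat `List` to a per-key mapping of item counts. This hunk does not pin down the key semantics; a hypothetical value merely consistent with the annotation `Dict[Union[str, int], List[int]]` might look like:

```python
# Hypothetical only: keys could be modality names or integer ranks;
# the diff itself does not specify which.
num_items_assigned = {"image": [2, 1], "video": [1]}
```
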
@@ -758,8 +758,8 @@ class TokenizedGenerateReqInput(BaseReq):
# Whether to return entropy
return_entropy: bool = False

need_wait_for_image: bool = False
num_items_assigned: Optional[List] = None
need_wait_for_mm_inputs: bool = False
num_items_assigned: Optional[Dict[Union[str, int], List[int]]] = None


@dataclass
111 changes: 80 additions & 31 deletions python/sglang/srt/managers/mm_utils.py
@@ -422,8 +422,34 @@ def get_embedding_chunk(
return embedding_chunk, start_index, end_index


def _get_precomputed_embedding_multi_items(
items_per_req: List[MultimodalDataItem],
prefix_len: int,
extend_len: int,
items_offset: List[Tuple[int, int]],
) -> Optional[torch.Tensor]:
"""
Extract precomputed embedding chunk for a request with multiple items.
"""
if any(item.precomputed_embeddings is None for item in items_per_req):
return None
req_embeddings = torch.concat(
[item.precomputed_embeddings for item in items_per_req]
)

# Extract the chunk using get_embedding_chunk logic (just slicing, no recomputation)
embedding_chunk, _, _ = get_embedding_chunk(
embedding=req_embeddings,
extend_prefix_len=prefix_len,
extend_seq_len=extend_len,
items_offset=items_offset,
)
return embedding_chunk
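
For multi-item requests the helper concatenates every item's precomputed embedding into one per-request tensor and slices it with `get_embedding_chunk`, exactly as a single embedding would be sliced. A toy illustration of the concat-then-slice idea (plain index slicing stands in for the real `get_embedding_chunk` signature):

```python
import torch

# Two items contributing 4 and 6 embedding rows to one request.
item_a = torch.arange(4).unsqueeze(-1).float()
item_b = torch.arange(6).unsqueeze(-1).float()
req_embeddings = torch.concat([item_a, item_b])  # shape (10, 1)

# A chunked-prefill step covering rows 3..7 of the request's
# multimodal span just slices the concatenated tensor.
chunk = req_embeddings[3:8]
assert chunk.shape[0] == 5
```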


def _get_precomputed_embedding(
items: List[MultimodalDataItem],
items_size: List[int],
prefix_length: List[int],
extend_length: List[int],
items_offset_list: List[List[Tuple[int, int]]],
@@ -434,38 +460,61 @@ def _get_precomputed_embedding(
If none have precomputed_embeddings, return None.
"""
precomputed_embeddings = []
for idx, item in enumerate(items):
if item.precomputed_embeddings is None:
precomputed_embeddings.append(None)
max_iterations = min(len(items_size) - 1, len(prefix_length))

for i in range(max_iterations):
if items_size[i] == items_size[i + 1]:
continue
seq_start_idx = prefix_length[idx]
seq_end_idx = seq_start_idx + extend_length[idx] - 1
prefix_embedding_length = []
extend_embedding_length = []
for mm_start_idx, mm_end_idx in items_offset_list[idx]:
if mm_start_idx > seq_end_idx:
break
if seq_start_idx > mm_start_idx:
prefix_embedding_length.append(
min(seq_start_idx - mm_start_idx, mm_end_idx - mm_start_idx + 1)
)
if mm_end_idx >= seq_start_idx:
extend_embedding_length.append(
min(
mm_end_idx - seq_start_idx + 1,
seq_end_idx - mm_start_idx + 1,
mm_end_idx - mm_start_idx + 1,
seq_end_idx - seq_start_idx + 1,

items_per_req = items[items_size[i] : items_size[i + 1]]
items_offset = items_offset_list[i]

if len(items_per_req) > 1:
embedding_chunk = _get_precomputed_embedding_multi_items(
items_per_req=items_per_req,
prefix_len=prefix_length[i],
extend_len=extend_length[i] if i < len(extend_length) else 0,
items_offset=items_offset,
)
if embedding_chunk is None:
return None
precomputed_embeddings.append(embedding_chunk)
else:
# Single item per request: use original logic
item = items_per_req[0]
if item.precomputed_embeddings is None:
precomputed_embeddings.append(None)
continue
seq_start_idx = prefix_length[i]
seq_end_idx = (
seq_start_idx + (extend_length[i] if i < len(extend_length) else 0) - 1
)
prefix_embedding_length = []
extend_embedding_length = []
for mm_start_idx, mm_end_idx in items_offset:
if mm_start_idx > seq_end_idx:
break
if seq_start_idx > mm_start_idx:
prefix_embedding_length.append(
min(seq_start_idx - mm_start_idx, mm_end_idx - mm_start_idx + 1)
)
)
prefix_embedding_length = int(np.sum(prefix_embedding_length))
extend_embedding_length = int(np.sum(extend_embedding_length))
precomputed_embeddings.append(
item.precomputed_embeddings[
prefix_embedding_length : prefix_embedding_length
+ extend_embedding_length
]
)
if mm_end_idx >= seq_start_idx:
extend_embedding_length.append(
min(
mm_end_idx - seq_start_idx + 1,
seq_end_idx - mm_start_idx + 1,
mm_end_idx - mm_start_idx + 1,
seq_end_idx - seq_start_idx + 1,
)
)
prefix_embedding_length = int(np.sum(prefix_embedding_length))
extend_embedding_length = int(np.sum(extend_embedding_length))
precomputed_embeddings.append(
item.precomputed_embeddings[
prefix_embedding_length : prefix_embedding_length
+ extend_embedding_length
]
)
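
The single-item branch computes, for each multimodal span `[mm_start, mm_end]`, how many embedding rows fall before the extend window (prefix) and how many fall inside it (extend), then slices `precomputed_embeddings` accordingly. A worked example with concrete numbers:

```python
import numpy as np

# One image spanning token positions 10..19 (10 embedding rows).
mm_start_idx, mm_end_idx = 10, 19
# This chunk extends tokens 13..17 (prefix length 13, extend length 5).
seq_start_idx, seq_end_idx = 13, 17

prefix_embedding_length = []
extend_embedding_length = []
if seq_start_idx > mm_start_idx:
    # Rows consumed by earlier chunks: positions 10..12 -> 3 rows.
    prefix_embedding_length.append(
        min(seq_start_idx - mm_start_idx, mm_end_idx - mm_start_idx + 1)
    )
if mm_end_idx >= seq_start_idx:
    # Rows inside this chunk: positions 13..17 -> 5 rows.
    extend_embedding_length.append(
        min(
            mm_end_idx - seq_start_idx + 1,   # 7
            seq_end_idx - mm_start_idx + 1,   # 8
            mm_end_idx - mm_start_idx + 1,    # 10
            seq_end_idx - seq_start_idx + 1,  # 5
        )
    )
assert int(np.sum(prefix_embedding_length)) == 3
assert int(np.sum(extend_embedding_length)) == 5
# The branch then slices rows [3 : 3 + 5] of this item's embedding.
```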

if any(feature is not None for feature in precomputed_embeddings):
if not all(feature is not None for feature in precomputed_embeddings):
Expand Down Expand Up @@ -883,7 +932,7 @@ def get_embedding_and_mask(
"""
# 1. Get embedding
embedding = _get_precomputed_embedding(
embedding_items, prefix_length, extend_length, items_offset_list
embedding_items, items_size, prefix_length, extend_length, items_offset_list
)
if embedding is None:
embedding, input_ids = _get_chunked_prefill_embedding(