diff --git a/python/sglang/srt/disaggregation/encode_receiver.py b/python/sglang/srt/disaggregation/encode_receiver.py index 8c84f122ca0c..82bb80ef11a3 100644 --- a/python/sglang/srt/disaggregation/encode_receiver.py +++ b/python/sglang/srt/disaggregation/encode_receiver.py @@ -4,7 +4,7 @@ import random import threading import uuid -from typing import List, Optional +from typing import Dict, List, Optional import aiohttp import torch @@ -16,54 +16,32 @@ from sglang.srt.distributed.parallel_state import GroupCoordinator from sglang.srt.managers.io_struct import TokenizedGenerateReqInput from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors +from sglang.srt.managers.schedule_batch import Modality from sglang.srt.server_args import ServerArgs from sglang.srt.utils import get_local_ip_auto, get_zmq_socket_on_host +from sglang.srt.utils.common import ImageData from sglang.srt.utils.hf_transformers_utils import get_processor logger = logging.getLogger(__name__) class EmbeddingData: - def __init__(self, req_id, num_parts, part_idx, image_grid_dim, embedding=None): + def __init__(self, req_id, num_parts, part_idx, grid_dim, modality, embedding=None): self.req_id = req_id self.num_parts = num_parts self.part_idx = part_idx - self.image_grid_dim = image_grid_dim + self.grid_dim = grid_dim + self.modality = modality self.embedding = embedding self.send_time = None self.dtype = embedding.dtype if embedding is not None else None self.shape = list(embedding.shape) if embedding is not None else None - # aggregated data - self.ready_list = [i == self.part_idx for i in range(self.num_parts)] - self.embedding_list = [ - embedding if i == self.part_idx else None for i in range(self.num_parts) - ] - self.image_grid_dim_list = [ - self.image_grid_dim if i == self.part_idx else None - for i in range(self.num_parts) - ] - - def add(self, embedding_data): - assert self.req_id == embedding_data.req_id - assert not self.ready_list[embedding_data.part_idx] - self.ready_list[embedding_data.part_idx] = True - self.image_grid_dim_list[embedding_data.part_idx] = ( - embedding_data.image_grid_dim - ) - self.embedding_list[embedding_data.part_idx] = embedding_data.embedding - - def get_embedding(self, is_concat=False): - if is_concat: - return torch.concat([embedding.cuda() for embedding in self.embedding_list]) - else: - return self.embedding_list - def get_img_grid(self): - return torch.concatenate(self.image_grid_dim_list) + def get_grid(self): + return self.grid_dim - @property - def ready(self): - return sum(self.ready_list) == self.num_parts + def get_embedding(self): + return self.embedding def __repr__(self): return f"EmbeddingData(req_id={self.req_id}, num_parts={self.num_parts}, part_idx={self.part_idx})" @@ -73,7 +51,8 @@ def copy_without_embedding(self): req_id=self.req_id, num_parts=self.num_parts, part_idx=self.part_idx, - image_grid_dim=self.image_grid_dim, + grid_dim=self.grid_dim, + modality=self.modality, ) new_data.send_time = self.send_time new_data.dtype = self.dtype @@ -81,6 +60,89 @@ def copy_without_embedding(self): return new_data +class MultiModalEmbeddingData(EmbeddingData): + def __init__(self, part_idx, num_parts, req_id, grid_dim, modality, embedding): + super().__init__(req_id, num_parts, part_idx, grid_dim, modality, embedding) + self.img_grid_thw = [None] * num_parts + self.video_grid_thw = [None] * num_parts + self.modality_list = [ + modality if part_idx == i else None for i in range(num_parts) + ] + self.ready_list = [i == part_idx for i in 
range(num_parts)] + self.embedding_list = [ + embedding if i == part_idx else None for i in range(num_parts) + ] + if modality == Modality.IMAGE: + self.img_grid_thw[part_idx] = self.get_grid() + elif modality == Modality.VIDEO: + self.video_grid_thw[part_idx] = self.get_grid() + + @classmethod + def from_embedding_data(cls, embedding_data: EmbeddingData): + """Create MultiModalEmbeddingData from an EmbeddingData instance.""" + mm_data = cls( + part_idx=embedding_data.part_idx, + num_parts=embedding_data.num_parts, + req_id=embedding_data.req_id, + grid_dim=embedding_data.grid_dim, + modality=embedding_data.modality, + embedding=embedding_data.embedding, + ) + # Copy over additional attributes + mm_data.send_time = embedding_data.send_time + return mm_data + + def __repr__(self): + return f"MultiModalEmbeddingData(req_id={self.req_id}, num_parts={self.num_parts}, part_idx={self.part_idx}, modality={self.modality})" + + def _get_mm_grid(self, modality): + if modality == Modality.IMAGE: + grid_dims = self.img_grid_thw + elif modality == Modality.VIDEO: + grid_dims = self.video_grid_thw + + valid_grid_dims = [] + for grid_dim in grid_dims: + if grid_dim is None: + continue + if grid_dim.dim() == 1: + valid_grid_dims.append(grid_dim.unsqueeze(0)) + else: + valid_grid_dims.append(grid_dim) + if len(valid_grid_dims) == 0: + return None + return torch.cat(valid_grid_dims, dim=0) + + def get_embedding(self, is_concat=False): + if is_concat: + return torch.concat([embedding.cuda() for embedding in self.embedding_list]) + else: + return self.embedding_list + + @property + def ready(self): + return sum(self.ready_list) == self.num_parts + + def get_img_grid(self): + return self._get_mm_grid(Modality.IMAGE) + + def get_video_grid(self): + return self._get_mm_grid(Modality.VIDEO) + + def add(self, embedding_data: EmbeddingData): + assert self.req_id == embedding_data.req_id + assert not self.ready_list[embedding_data.part_idx] + self.ready_list[embedding_data.part_idx] = True + self.modality_list[embedding_data.part_idx] = embedding_data.modality + self.embedding_list[embedding_data.part_idx] = embedding_data.get_embedding() + if embedding_data.modality == Modality.IMAGE: + self.img_grid_thw[embedding_data.part_idx] = embedding_data.get_grid() + elif embedding_data.modality == Modality.VIDEO: + self.video_grid_thw[embedding_data.part_idx] = embedding_data.get_grid() + else: + raise ValueError(f"Invalid modality: {embedding_data.modality}") + + # For zmq_to_scheduler class WaitingImageRequest: def __init__( @@ -125,21 +187,22 @@ async def send_embedding_port(req_id, receive_count, host_name, embedding_port): ) as session: tasks = [] logger.info(f"{self.num_items_assigned = } ") - for idx, assigned_num in enumerate(self.num_items_assigned): - if assigned_num == 0: - continue - encoder_url = self.encoder_urls[idx] - target_url = f"{encoder_url}/scheduler_receive_url" - payload = { - "req_id": req_id, - "receive_count": receive_count, - "receive_url": f"{host_name}:{embedding_port}", - } - - logger.info(f"Preparing to send to {target_url}") - task = _send_single_request(session, target_url, payload) - tasks.append(task) + for modality, assigned_nums in self.num_items_assigned.items(): + for idx, assigned_num in enumerate(assigned_nums): + if assigned_num == 0: + continue + encoder_url = self.encoder_urls[idx] + target_url = f"{encoder_url}/scheduler_receive_url" + payload = { + "req_id": req_id, + "receive_count": receive_count, + "receive_url": f"{host_name}:{embedding_port}", + "modality": modality.name, 
+ } + logger.info(f"Preparing to send to {target_url}") + task = _send_single_request(session, target_url, payload) + tasks.append(task) if not tasks: logger.info("No tasks to send.") @@ -177,17 +240,20 @@ def _try_recv_mm_data(self): recv_obj.embedding = torch.frombuffer(buffer, dtype=recv_obj.dtype).reshape( recv_obj.shape ) - recv_obj.embedding_list[recv_obj.part_idx] = recv_obj.embedding + if self.recv_embedding_data is None: - self.recv_embedding_data = recv_obj + self.recv_embedding_data = MultiModalEmbeddingData.from_embedding_data( + recv_obj + ) else: self.recv_embedding_data.add(recv_obj) recv_embedding = self.recv_embedding_data.get_embedding(is_concat=True) img_grid_thw = self.recv_embedding_data.get_img_grid() + video_grid_thw = self.recv_embedding_data.get_video_grid() mm_inputs = self.mm_processor.get_mm_data( - self.recv_req.input_text, recv_embedding, img_grid_thw + self.recv_req.input_text, recv_embedding, img_grid_thw, video_grid_thw ) self.recv_req.mm_inputs = mm_inputs self.recv_req.input_ids = mm_inputs["input_ids"] @@ -274,7 +340,7 @@ def process_waiting_requests(self, recv_reqs): for recv_req in recv_reqs: if ( isinstance(recv_req, TokenizedGenerateReqInput) - and recv_req.need_wait_for_image is True + and recv_req.need_wait_for_mm_inputs is True ): waiting_req = WaitingImageRequest( rid=recv_req.rid, @@ -317,13 +383,13 @@ def process_waiting_requests(self, recv_reqs): # For zmq_to_scheduler def _run_encode_in_thread( - self, req_id, img_data, endpoint_encode, num_items_assigned, embedding_port + self, req_id, mm_data, endpoint_encode, num_items_assigned, embedding_port ): try: asyncio.run( self.encode( req_id=req_id, - img_data=img_data, + mm_data=mm_data, embedding_port=embedding_port, endpoint_encode=endpoint_encode, endpoint_send=None, @@ -333,45 +399,117 @@ def _run_encode_in_thread( except Exception as e: logger.error(f"Encode failed for request {req_id}: {e}", exc_info=True) + def _assign_items_by_modality( + self, mm_data, encoder_num, random_shuffle=True + ) -> Dict: + """ + Assign multimodal items across encoders by modality with cross-modality load balancing. 
+ + Args: + mm_data: List of multimodal data items, each with a "modality" key + encoder_num: Number of encoders + random_shuffle: Whether to shuffle the encoder indices + + Returns: + Dictionary mapping modality to list of assignment counts per encoder + Format: {modality: [count_for_encoder_0, count_for_encoder_1, ...]} + """ + encode_idx = list(range(encoder_num)) + if random_shuffle: + random.shuffle(encode_idx) + # Get unique modalities + modalities = list(dict.fromkeys(mm_item.get("modality") for mm_item in mm_data)) + num_items_assigned = {} + current_offset = 0 + + for modality in modalities: + mm_data_modality = [ + mm_item for mm_item in mm_data if mm_item.get("modality") == modality + ] + num_items = len(mm_data_modality) + if num_items == 0: + continue + + base = num_items // len(encode_idx) + remainder = num_items % len(encode_idx) + # Rotate assignments based on current_offset to balance load across modalities + assignments = [0] * len(encode_idx) + for i in range(len(encode_idx)): + # keep shuffle order when assigning items to encoders + pos_in_shuffled = (current_offset + i) % len(encode_idx) + actual_encoder_idx = encode_idx[pos_in_shuffled] + assignments[actual_encoder_idx] = base + (1 if i < remainder else 0) + num_items_assigned[modality] = assignments + current_offset = (current_offset + remainder) % len(encode_idx) + + return num_items_assigned + async def encode( self, req_id, - img_data, + mm_data, embedding_port, endpoint_encode, endpoint_send, num_items_assigned=None, ): - if len(img_data) == 0: + if len(mm_data) == 0: return - # Split mm_items + # get unique modalities with order preserved + modalities = [mm_item.get("modality") for mm_item in mm_data] + modalities = list(dict.fromkeys(modalities)) encode_requests = [] + if num_items_assigned is None: - random.shuffle(self.encode_idx) - num_items_assigned = [ - (idx + len(img_data)) // len(self.encode_urls) - for idx in self.encode_idx - ] - num_parts = sum(1 for x in num_items_assigned if x != 0) - cum_num_items = 0 - cum_idx = 0 - for idx, assigned_num in enumerate(num_items_assigned): - if assigned_num == 0: - continue - encode_requests.append( - { - "encoder_idx": idx, - "mm_items": img_data[cum_num_items : cum_num_items + assigned_num], - "num_parts": num_parts, - "part_idx": cum_idx, - "req_id": req_id, - "prefill_host": self.host, - "embedding_port": embedding_port, - } + num_items_assigned = self._assign_items_by_modality( + mm_data, len(self.encode_urls) ) - cum_idx += 1 - cum_num_items += assigned_num + + # Calculate total num_parts across all modalities + total_num_parts = 0 + modality_num_parts = {} + for modality in modalities: + num_items_assigned_modality = num_items_assigned.get(modality) + num_parts = sum(1 for x in num_items_assigned_modality if x != 0) + modality_num_parts[modality] = num_parts + total_num_parts += num_parts + + part_idx_offset = 0 + for modality in modalities: + num_items_assigned_modality = num_items_assigned.get(modality) + mm_data_modality = [ + mm_item for mm_item in mm_data if mm_item.get("modality") == modality + ] + + num_parts = modality_num_parts[modality] + cum_num_items = 0 + cum_idx = 0 + for idx, assigned_num in enumerate(num_items_assigned_modality): + if assigned_num == 0: + continue + encode_requests.append( + { + "encoder_idx": self.encode_idx[ + idx + ], # use shuffle-idx to load-balance + "mm_items": [ + mm_item.get("url") + for mm_item in mm_data_modality[ + cum_num_items : cum_num_items + assigned_num + ] + ], + "num_parts": total_num_parts, + "part_idx": 
part_idx_offset + cum_idx, + "req_id": req_id, + "modality": modality.name, # convert enum to string for json serialization + "prefill_host": self.host, + "embedding_port": embedding_port, + } + ) + cum_idx += 1 + cum_num_items += assigned_num + part_idx_offset += num_parts async with aiohttp.ClientSession( timeout=aiohttp.ClientTimeout( @@ -399,9 +537,9 @@ async def encode( # mooncake backend: send bootstrap info - embedding_size_list_sort = [None for _ in range(num_parts)] + embedding_size_list_sort = [None for _ in range(total_num_parts)] embedding_length_tot = 0 - response_json_list_sort = [None for _ in range(num_parts)] + response_json_list_sort = [None for _ in range(total_num_parts)] for response_json in response_json_list_unsort: idx = response_json["part_idx"] embedding_size_list_sort[idx] = response_json["embedding_size"] @@ -448,26 +586,21 @@ async def allocate_embedding_buffer(self, req_id, embedding_length, embedding_di # For zmq_to_scheduler def send_encode_request(self, obj): - if type(obj.image_data) != list: - image_urls = [obj.image_data.url] - else: - image_urls = [img.url for img in obj.image_data] + mm_data = self._extract_url_data(obj) if obj.rid is None: obj.rid = uuid.uuid4().hex - if image_urls and len(image_urls) > 0: - logger.info(f"Processing {len(image_urls)} images for request {obj.rid}") - obj.need_wait_for_image = True - encode_idx = list(range(len(self.encode_urls))) - random.shuffle(encode_idx) - obj.num_items_assigned = [ - (idx + len(image_urls)) // len(self.encode_urls) for idx in encode_idx - ] + if mm_data and len(mm_data) > 0: + logger.info(f"Processing {len(mm_data)} mm_items for request {obj.rid}") + obj.need_wait_for_mm_inputs = True + obj.num_items_assigned = self._assign_items_by_modality( + mm_data, len(self.encode_urls) + ) encode_thread = threading.Thread( target=self._run_encode_in_thread, args=( obj.rid, - image_urls, + mm_data, "encode", obj.num_items_assigned, None, @@ -476,19 +609,42 @@ def send_encode_request(self, obj): ) encode_thread.start() + def _extract_url_data(self, request_obj) -> List[Dict]: + mm_data = [] + + def _load_url(mm_items, modality): + for mm_item in mm_items: + mm_data.append( + { + "url": ( + mm_item.url if isinstance(mm_item, ImageData) else mm_item + ), + "modality": modality, + } + ) + + if request_obj.image_data: + if not isinstance(request_obj.image_data, List): + _load_url([request_obj.image_data], Modality.IMAGE) + else: + _load_url(request_obj.image_data, Modality.IMAGE) + if request_obj.video_data: + if not isinstance(request_obj.video_data, List): + _load_url([request_obj.video_data], Modality.VIDEO) + else: + _load_url(request_obj.video_data, Modality.VIDEO) + return mm_data + # For zmq_to_tokenizer and mooncake - async def recv_mm_data(self, img_data, mm_processor, prompt): + async def recv_mm_data(self, request_obj, mm_processor, prompt): try: if len(self.encode_urls) == 0: return None req_id = uuid.uuid4().hex embedding_port, recv_socket = get_zmq_socket_on_host(self.context, zmq.PULL) - if type(img_data) != list: - img_data = [img_data.url] - else: - img_data = [img.url for img in img_data] + mm_data = self._extract_url_data(request_obj) asyncio.create_task( - self.encode(req_id, img_data, embedding_port, "encode", "send") + self.encode(req_id, mm_data, embedding_port, "encode", "send") ) return await asyncio.wait_for( self._recv_mm_data(req_id, recv_socket, mm_processor, prompt), @@ -508,7 +664,7 @@ async def _recv_mm_data(self, req_id, recv_socket, mm_processor, prompt): recv_embedding = None - 
recv_embedding_data: EmbeddingData = None
+        recv_embedding_data: MultiModalEmbeddingData = None
 
         while recv_embedding_data is None or not recv_embedding_data.ready:
             parts = await recv_socket.recv_multipart(copy=False)
@@ -520,9 +676,11 @@ async def _recv_mm_data(self, req_id, recv_socket, mm_processor, prompt):
             recv_obj.embedding = torch.frombuffer(
                 buffer, dtype=recv_obj.dtype
             ).reshape(recv_obj.shape)
+
             if recv_embedding_data is None:
-                recv_obj.embedding_list[recv_obj.part_idx] = recv_obj.embedding
-                recv_embedding_data = recv_obj
+                recv_embedding_data = MultiModalEmbeddingData.from_embedding_data(
+                    recv_obj
+                )
             else:
                 recv_embedding_data.add(recv_obj)
@@ -536,6 +694,9 @@ async def _recv_mm_data(self, req_id, recv_socket, mm_processor, prompt):
         recv_socket.close()
 
         img_grid_thw = recv_embedding_data.get_img_grid()
+        video_grid_thw = recv_embedding_data.get_video_grid()
-        mm_inputs = mm_processor.get_mm_data(prompt, recv_embedding, img_grid_thw)
+        mm_inputs = mm_processor.get_mm_data(
+            prompt, recv_embedding, img_grid_thw, video_grid_thw
+        )
         return mm_inputs
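
The hunks above complete the receiver side of the protocol: the first part that arrives over ZMQ is promoted to a MultiModalEmbeddingData aggregator via from_embedding_data(), and every later part is merged with add() until ready flips to True. A minimal, standalone sketch of that aggregation contract, not part of the patch (the class name, part counts, and shapes are illustrative):

    import torch

    class PartAggregator:
        """Collects embedding parts that may arrive out of order over ZMQ."""

        def __init__(self, num_parts):
            self.parts = [None] * num_parts

        def add(self, part_idx, embedding):
            assert self.parts[part_idx] is None, "duplicate part"
            self.parts[part_idx] = embedding

        @property
        def ready(self):
            return all(part is not None for part in self.parts)

        def concat(self):
            # part_idx is assigned modality by modality on the sender side, so
            # concatenating in index order yields image parts, then video parts.
            return torch.cat(self.parts, dim=0)

    agg = PartAggregator(num_parts=3)
    agg.add(2, torch.zeros(4, 8))  # parts may arrive in any order
    agg.add(0, torch.zeros(6, 8))
    assert not agg.ready
    agg.add(1, torch.zeros(2, 8))
    assert agg.ready and agg.concat().shape == (12, 8)

Because encode() assigns part indices modality by modality, concatenating in part_idx order yields all image embeddings followed by all video embeddings, which is the layout that get_mm_data() later slices back apart.
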
"cuda" self.load_config = LoadConfig( load_format=server_args.load_format, @@ -138,6 +160,9 @@ def __init__( remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port, remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports, ) + self.model_type = getattr( + self.model_config.hf_config, "model_type", "unknown" + ).lower() self.device = server_args.device self.gpu_id = server_args.base_gpu_id + rank @@ -290,22 +315,66 @@ async def _flatten_and_load_images(self, mm_items): async_futures = [asyncio.wrap_future(f) for f in futures] return await asyncio.gather(*async_futures) - async def _encode(self, mm_items) -> torch.Tensor: - images = await self._flatten_and_load_images(mm_items) + async def _process_mm_items(self, mm_items, modality): + if modality == Modality.IMAGE: + images = await self._flatten_and_load_images(mm_items) + kwargs = {"device": self.device} if self.use_image_processor_gpu else {} + processor_input = self.mm_processor.image_processor(images=images, **kwargs) + feature = processor_input["pixel_values"] + get_feature_method = self.model.get_image_feature + elif modality == Modality.VIDEO: + # mainly follows qwen_vl.py: only support qwen series models for video processing + if "qwen" not in self.model_type: + raise ValueError( + f"Video modality processing is currently only supported for Qwen series models with EPD enabled, " + f"but got model_type: {self.model_type}" + ) + # Load videos concurrently + futures, _ = self.submit_data_loading_tasks( + mm_items, [Modality.VIDEO] * len(mm_items) + ) + async_futures = [asyncio.wrap_future(f) for f in futures] + video_items = await asyncio.gather(*async_futures) + + videos_processed = [ + await preprocess_video(video, video_config=self.video_config) + for video in video_items + ] + videos, video_metadata = map(list, zip(*videos_processed)) + + # pass preprocessed videos to processor with do_sample_frames=False + video_processor_kwargs = {} + video_processor_kwargs["do_sample_frames"] = False + for key in self.video_config: + video_processor_kwargs[key] = self.video_config[key] + if video_metadata: + video_processor_kwargs["video_metadata"] = video_metadata + processor_input = self.mm_processor.video_processor( + videos=videos, **video_processor_kwargs + ) + feature = processor_input["pixel_values_videos"] + get_feature_method = self.model.get_video_feature + else: + raise ValueError( + f"Currently only support image and video modalities, but got {modality}" + ) - kwargs = {"device": self.device} if self.use_image_processor_gpu else {} - images_input = self.image_processor(images=images, **kwargs) - feature = images_input["pixel_values"] mm_item = MultimodalDataItem.from_dict( { - "modality": Modality.IMAGE, + "modality": modality, "feature": _convert(feature), } ) - for k, v in images_input.items(): - if k == "pixel_values": + for k, v in processor_input.items(): + if k in ["pixel_values", "pixel_values_videos"]: continue mm_item.set(k, _convert(v)) + return processor_input, mm_item, get_feature_method + + async def _encode(self, mm_items, modality: Modality) -> torch.Tensor: + mm_inputs, mm_item, get_feature_fn = await self._process_mm_items( + mm_items, modality + ) # support mm_cache mm_embedding = None @@ -322,7 +391,7 @@ async def _encode(self, mm_items) -> torch.Tensor: if mm_embedding is None: with torch.inference_mode(): - mm_embedding: torch.Tensor = self.model.get_image_feature([mm_item]) + 
mm_embedding: torch.Tensor = get_feature_fn([mm_item]) mm_embedding = mm_embedding.cpu() if len(mm_embedding.shape) != 2: mm_embedding = mm_embedding.reshape(-1, mm_embedding.shape[-1]) @@ -337,7 +406,7 @@ async def _encode(self, mm_items) -> torch.Tensor: if self.profiler is not None: self.profiler.step() - return _get_image_grid_dim(images_input), mm_embedding + return _get_vision_grid_dim(mm_inputs, modality), mm_embedding async def _send( self, @@ -357,7 +426,6 @@ async def _send( self.engine.deregister(embedding.data_ptr()) mm_data.embedding = None - mm_data.embedding_list[mm_data.part_idx] = None # Send ack/data endpoint = ( @@ -382,9 +450,9 @@ async def _send( [pickle.dumps(new_mm_data), embedding_tensor.__buffer__()] ) - async def encode(self, mm_items, req_id, num_parts, part_idx): + async def encode(self, mm_items, modality: Modality, req_id, num_parts, part_idx): start_time = time.time() - image_grid_dim, mm_embedding = await self._encode(mm_items) + grid_dim, mm_embedding = await self._encode(mm_items, modality) end_time = time.time() logger.info(f"🕛 encode cost = {(end_time - start_time) * 1000:.2f}ms") if self.rank == 0: @@ -392,7 +460,8 @@ async def encode(self, mm_items, req_id, num_parts, part_idx): req_id, num_parts, part_idx, - image_grid_dim, + grid_dim, + modality, mm_embedding, ) self.embedding_to_send[mm_data.req_id] = mm_data @@ -583,6 +652,7 @@ async def run_encoder( else: await encoder.encode( mm_items=request["mm_items"], + modality=Modality.from_str(request["modality"]), req_id=request["req_id"], num_parts=request["num_parts"], part_idx=request["part_idx"], @@ -631,6 +701,7 @@ async def handle_encode_request(request: dict): nbytes, embedding_len, embedding_dim = await encoder.encode( mm_items=request["mm_items"], + modality=Modality.from_str(request["modality"]), req_id=request["req_id"], num_parts=request["num_parts"], part_idx=request["part_idx"], diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 04e3b6c5b521..4cf5e3187792 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -262,8 +262,8 @@ class GenerateReqInput(BaseReq, APIServingTimingMixin): external_trace_header: Optional[Dict] = None # For EPD-disaggregated inference - need_wait_for_image: Optional[bool] = None - num_items_assigned: Optional[List] = None + need_wait_for_mm_inputs: Optional[bool] = None + num_items_assigned: Optional[Dict[Union[str, int], List[int]]] = None # Multimodal tiling controls (extensions) max_dynamic_patch: Optional[int] = None @@ -758,8 +758,8 @@ class TokenizedGenerateReqInput(BaseReq): # Whether to return entropy return_entropy: bool = False - need_wait_for_image: bool = False - num_items_assigned: Optional[List] = None + need_wait_for_mm_inputs: bool = False + num_items_assigned: Optional[Dict[Union[str, int], List[int]]] = None @dataclass diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py index 1e4d09036c9a..b8aa715d9e5f 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -422,8 +422,34 @@ def get_embedding_chunk( return embedding_chunk, start_index, end_index +def _get_precomputed_embedding_multi_items( + items_per_req: List[MultimodalDataItem], + prefix_len: int, + extend_len: int, + items_offset: List[Tuple[int, int]], +) -> Optional[torch.Tensor]: + """ + Extract precomputed embedding chunk for a request with multiple items. 
+ """ + if any(item.precomputed_embeddings is None for item in items_per_req): + return None + req_embeddings = torch.concat( + [item.precomputed_embeddings for item in items_per_req] + ) + + # Extract the chunk using get_embedding_chunk logic (just slicing, no recomputation) + embedding_chunk, _, _ = get_embedding_chunk( + embedding=req_embeddings, + extend_prefix_len=prefix_len, + extend_seq_len=extend_len, + items_offset=items_offset, + ) + return embedding_chunk + + def _get_precomputed_embedding( items: List[MultimodalDataItem], + items_size: List[int], prefix_length: List[int], extend_length: List[int], items_offset_list: List[List[Tuple[int, int]]], @@ -434,38 +460,61 @@ def _get_precomputed_embedding( If none have precomputed_embeddings, return None. """ precomputed_embeddings = [] - for idx, item in enumerate(items): - if item.precomputed_embeddings is None: - precomputed_embeddings.append(None) + max_iterations = min(len(items_size) - 1, len(prefix_length)) + + for i in range(max_iterations): + if items_size[i] == items_size[i + 1]: continue - seq_start_idx = prefix_length[idx] - seq_end_idx = seq_start_idx + extend_length[idx] - 1 - prefix_embedding_length = [] - extend_embedding_length = [] - for mm_start_idx, mm_end_idx in items_offset_list[idx]: - if mm_start_idx > seq_end_idx: - break - if seq_start_idx > mm_start_idx: - prefix_embedding_length.append( - min(seq_start_idx - mm_start_idx, mm_end_idx - mm_start_idx + 1) - ) - if mm_end_idx >= seq_start_idx: - extend_embedding_length.append( - min( - mm_end_idx - seq_start_idx + 1, - seq_end_idx - mm_start_idx + 1, - mm_end_idx - mm_start_idx + 1, - seq_end_idx - seq_start_idx + 1, + + items_per_req = items[items_size[i] : items_size[i + 1]] + items_offset = items_offset_list[i] + + if len(items_per_req) > 1: + embedding_chunk = _get_precomputed_embedding_multi_items( + items_per_req=items_per_req, + prefix_len=prefix_length[i], + extend_len=extend_length[i] if i < len(extend_length) else 0, + items_offset=items_offset, + ) + if embedding_chunk is None: + return None + precomputed_embeddings.append(embedding_chunk) + else: + # Single item per request: use original logic + item = items_per_req[0] + if item.precomputed_embeddings is None: + precomputed_embeddings.append(None) + continue + seq_start_idx = prefix_length[i] + seq_end_idx = ( + seq_start_idx + (extend_length[i] if i < len(extend_length) else 0) - 1 + ) + prefix_embedding_length = [] + extend_embedding_length = [] + for mm_start_idx, mm_end_idx in items_offset: + if mm_start_idx > seq_end_idx: + break + if seq_start_idx > mm_start_idx: + prefix_embedding_length.append( + min(seq_start_idx - mm_start_idx, mm_end_idx - mm_start_idx + 1) ) - ) - prefix_embedding_length = int(np.sum(prefix_embedding_length)) - extend_embedding_length = int(np.sum(extend_embedding_length)) - precomputed_embeddings.append( - item.precomputed_embeddings[ - prefix_embedding_length : prefix_embedding_length - + extend_embedding_length - ] - ) + if mm_end_idx >= seq_start_idx: + extend_embedding_length.append( + min( + mm_end_idx - seq_start_idx + 1, + seq_end_idx - mm_start_idx + 1, + mm_end_idx - mm_start_idx + 1, + seq_end_idx - seq_start_idx + 1, + ) + ) + prefix_embedding_length = int(np.sum(prefix_embedding_length)) + extend_embedding_length = int(np.sum(extend_embedding_length)) + precomputed_embeddings.append( + item.precomputed_embeddings[ + prefix_embedding_length : prefix_embedding_length + + extend_embedding_length + ] + ) if any(feature is not None for feature in 
precomputed_embeddings): if not all(feature is not None for feature in precomputed_embeddings): @@ -883,7 +932,7 @@ def get_embedding_and_mask( """ # 1. Get embedding embedding = _get_precomputed_embedding( - embedding_items, prefix_length, extend_length, items_offset_list + embedding_items, items_size, prefix_length, extend_length, items_offset_list ) if embedding is None: embedding, input_ids = _get_chunked_prefill_embedding( diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index acc02aebe7e1..724410f03f07 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -713,7 +713,7 @@ async def _tokenize_one_request( ): if self.server_args.language_only: mm_inputs = await self.mm_receiver.recv_mm_data( - img_data=obj.image_data, + request_obj=obj, mm_processor=self.mm_processor, prompt=(input_text or input_ids), ) @@ -941,7 +941,7 @@ def _create_tokenized_object( priority=obj.priority, extra_key=obj.extra_key, routing_key=obj.routing_key, - need_wait_for_image=obj.need_wait_for_image, + need_wait_for_mm_inputs=obj.need_wait_for_mm_inputs, num_items_assigned=obj.num_items_assigned, ) elif isinstance(obj, EmbeddingReqInput): diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index 5b6afb15199b..6bf69168c99f 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -233,14 +233,15 @@ def __init__( def spatial_merge_size(self): return self.hf_config.vision_config.spatial_merge_size - def build_input_ids(self, prompt, img_grid_thw): + def build_input_ids(self, prompt, img_grid_thw, video_grid_thw): """ - Use prompt and img_grid_thw to build input_ids + Use prompt and img_grid_thw and video_grid_thw to build input_ids """ if not isinstance(prompt, list): prompt = self._processor.tokenizer.encode(prompt) img_token_id = self.IM_TOKEN_ID + video_token_id = self.VIDEO_TOKEN_ID spatial_merge_size = self.spatial_merge_size input_ids = [] @@ -250,34 +251,62 @@ def build_input_ids(self, prompt, img_grid_thw): # Use img_token_id instead of im_start_id, because a dummy im_start_id # may be generated by the tokenizer. 
- img_start_indices = list( - filter(lambda i: prompt[i + 1] == img_token_id, range(len(prompt) - 1)) - ) - - for cur_img_idx, img_start_idx in enumerate(img_start_indices): - assert cur_idx <= img_start_idx - # include img_start_id - input_ids.extend(prompt[cur_idx : img_start_idx + 1]) - img_offset_start = len(input_ids) - img_token_num = img_grid_thw[cur_img_idx].prod() // (spatial_merge_size**2) - input_ids.extend([img_token_id] * img_token_num) - # jump to img_end_id - cur_idx = img_start_idx + 2 - offsets.append((img_offset_start, len(input_ids) - 1)) + vision_start_indices = [] + for i in range(len(prompt) - 1): + if prompt[i + 1] == img_token_id: + vision_start_indices.append((i, Modality.IMAGE)) + elif prompt[i + 1] == video_token_id: + vision_start_indices.append((i, Modality.VIDEO)) + # get modality list with order preserved + modality_list = [modality for _, modality in vision_start_indices] + + img_idx = 0 + video_idx = 0 + for mm_start_idx, modality in vision_start_indices: + if modality == Modality.IMAGE: + mm_token_num = img_grid_thw[img_idx].prod() // (spatial_merge_size**2) + mm_token_id = img_token_id + img_idx += 1 + elif modality == Modality.VIDEO: + mm_token_num = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2 + ) + mm_token_id = video_token_id + video_idx += 1 + else: + raise ValueError(f"Invalid modality: {modality}") + assert cur_idx <= mm_start_idx + + input_ids.extend(prompt[cur_idx : mm_start_idx + 1]) + mm_offset_start = len(input_ids) + input_ids.extend([mm_token_id] * mm_token_num) + cur_idx = mm_start_idx + 2 # jump to img_end_id or video_end_id + offsets.append((mm_offset_start, len(input_ids) - 1)) else: input_ids.extend(prompt[cur_idx:]) - return input_ids, offsets + return input_ids, offsets, modality_list - def get_mm_data(self, prompt, embeddings, img_grid_thw): - input_ids, offsets = self.build_input_ids(prompt, img_grid_thw) - mm_items = [ - MultimodalDataItem( - modality=Modality.IMAGE, - offsets=offsets, - precomputed_embeddings=embeddings, + def get_mm_data(self, prompt, embeddings, img_grid_thw=None, video_grid_thw=None): + input_ids, offsets, modality_list = self.build_input_ids( + prompt, img_grid_thw, video_grid_thw + ) + assert all(isinstance(modality, Modality) for modality in modality_list) + + mm_items = [] + embedding_index = 0 + for modality, offset in zip(modality_list, offsets): + start_idx, end_idx = offset + num_tokens = end_idx - start_idx + 1 + embedding_slice = embeddings[embedding_index : embedding_index + num_tokens] + embedding_index += num_tokens + mm_items.append( + MultimodalDataItem( + modality=modality, + offsets=offset, + precomputed_embeddings=embedding_slice, + ) ) - ] return { "input_ids": input_ids, @@ -285,6 +314,7 @@ def get_mm_data(self, prompt, embeddings, img_grid_thw): "im_start_id": self.IM_START_TOKEN_ID, "im_end_id": self.IM_END_TOKEN_ID, "im_token_id": self.IM_TOKEN_ID, + "video_token_id": getattr(self, "VIDEO_TOKEN_ID", None), } def process_mm_data( diff --git a/python/sglang/srt/multimodal/processors/qwen_vl.py b/python/sglang/srt/multimodal/processors/qwen_vl.py index eb648542dd10..d51f7b013b95 100644 --- a/python/sglang/srt/multimodal/processors/qwen_vl.py +++ b/python/sglang/srt/multimodal/processors/qwen_vl.py @@ -239,6 +239,7 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): self.IM_START_TOKEN_ID = hf_config.vision_start_token_id self.IM_END_TOKEN_ID = hf_config.vision_end_token_id self.IM_TOKEN_ID = hf_config.image_token_id + self.VIDEO_TOKEN_ID = 
hf_config.video_token_id
 
         self.vision_start_token_id = hf_config.vision_start_token_id
         self.vision_end_token_id = getattr(hf_config, "vision_end_token_id", None)
@@ -256,12 +257,16 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
             image_token_regex=re.compile(
                 r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
             ),
-            video_token_id=hf_config.video_token_id,
+            video_token_id=self.VIDEO_TOKEN_ID,
             audio_token_id=self.audio_token_id,
         ).build(_processor)
 
-    def get_mm_data(self, prompt, embeddings, img_grid_thw):
-        input_ids, offsets = self.build_input_ids(prompt, img_grid_thw)
+    def get_mm_data(self, prompt, embeddings, img_grid_thw, video_grid_thw):
+        input_ids, offsets, modality_list = self.build_input_ids(
+            prompt, img_grid_thw, video_grid_thw
+        )
+        assert all(isinstance(modality, Modality) for modality in modality_list)
+
         mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
             spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
             image_token_id=self.mm_tokens.image_token_id,
@@ -270,19 +275,27 @@ def get_mm_data(self, prompt, embeddings, img_grid_thw):
             model_type=self.model_type,
             input_ids=torch.tensor(input_ids, dtype=torch.long).unsqueeze(0),
             image_grid_thw=img_grid_thw,
+            video_grid_thw=video_grid_thw,
             tokens_per_second=getattr(
                 self.hf_config.vision_config, "tokens_per_second", None
             ),
         )
         mrope_positions = mrope_positions.squeeze(1)
 
-        mm_items = [
-            MultimodalDataItem(
-                modality=Modality.IMAGE,
-                offsets=offsets,
-                precomputed_embeddings=embeddings,
+        mm_items = []
+        embedding_index = 0
+        for modality, offset in zip(modality_list, offsets):
+            start_idx, end_idx = offset
+            num_tokens = end_idx - start_idx + 1
+            embedding_slice = embeddings[embedding_index : embedding_index + num_tokens]
+            embedding_index += num_tokens
+            mm_items.append(
+                MultimodalDataItem(
+                    modality=modality,
+                    offsets=offset,
+                    precomputed_embeddings=embedding_slice,
+                )
             )
-        ]
 
         return {
             "input_ids": input_ids,
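
Both the base-processor and Qwen get_mm_data() variants above rebuild per-item MultimodalDataItems by walking the offsets returned from build_input_ids() and consuming the concatenated embedding sequentially. A standalone sketch of that slicing invariant, not part of the patch (all sizes are illustrative):

    import torch

    # Concatenated embedding as produced by the encoders, in part_idx order
    # (10 multimodal tokens in total here).
    embeddings = torch.randn(10, 4)

    # (start, end) offsets in input_ids space, inclusive, one per mm item.
    offsets = [(3, 6), (8, 9), (12, 15)]

    index, slices = 0, []
    for start, end in offsets:
        num_tokens = end - start + 1
        slices.append(embeddings[index : index + num_tokens])
        index += num_tokens

    assert [s.shape[0] for s in slices] == [4, 2, 4]
    assert index == embeddings.shape[0]  # every row is consumed exactly once
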
diff --git a/test/srt/test_epd_disaggregation.py b/test/srt/test_epd_disaggregation.py
index 6733bc5692f3..04bfaa113b01 100644
--- a/test/srt/test_epd_disaggregation.py
+++ b/test/srt/test_epd_disaggregation.py
@@ -2,6 +2,8 @@
 import threading
 import unittest
 
+import openai
+
 from sglang.srt.utils import kill_process_tree
 from sglang.test.kits.mmmu_vlm_kit import _run_lmms_eval_with_retry
 from sglang.test.server_fixtures.disaggregation_fixture import (
@@ -14,6 +16,9 @@
     popen_launch_server,
 )
 
+# Video test URL
+VIDEO_JOBS_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/videos/jobs_presenting_ipod.mp4"
+
 
 @unittest.skipIf(is_in_ci(), "Skipping in CI to reduce multi-GPU runtime")
 class TestEPDDisaggregationOneEncoder(PDDisaggregationServerBase):
@@ -421,6 +426,66 @@ def test_mmmu(self):
         # for qwen2.5-vl-3b-instruct, the accuracy is 0.40
         self.assertGreater(mmmu_accuracy, 0.40)
 
+
+    def test_video(self):
+        """Test video support with EPD disaggregation (multiple encoders)."""
+        client = openai.Client(api_key=self.api_key, base_url=f"{self.lb_url}/v1")
+
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "Describe the video."},
+                    {
+                        "type": "video_url",
+                        "video_url": {"url": VIDEO_JOBS_URL},
+                    },
+                ],
+            },
+        ]
+
+        response = client.chat.completions.create(
+            model="default",
+            messages=messages,
+            max_tokens=8192,
+            stream=False,
+        )
+
+        video_response = response.choices[0].message.content
+        print("-" * 30)
+        print(f"Video response (multi encoder):\n{video_response}")
+        print("-" * 30)
+
+        # Sanity-check the response, then validate its content by keyword
+        self.assertIsNotNone(video_response)
+        self.assertGreater(len(video_response), 0)
+        video_response_lower = video_response.lower()
+
+        # Keyword groups a reasonable description of the clip should hit.
+        # Stems such as "examin" and "hold" also match "examining"/"holding".
+        keyword_groups = {
+            "device": ["ipod", "device", "microphone", "smartphone", "phone"],
+            "person": [
+                "man",
+                "person",
+                "individual",
+                "speaker",
+                "presenter",
+                "steve",
+                "hand",
+            ],
+            "action": ["present", "examin", "display", "hold", "gestur", "speak"],
+        }
+        for group, keywords in keyword_groups.items():
+            self.assertTrue(
+                any(keyword in video_response_lower for keyword in keywords),
+                f"""
+        ====================== video response =====================
+        {video_response}
+        ===========================================================
+        should contain at least one {group}-related keyword from {keywords}
+        """,
+            )
 
 
 if __name__ == "__main__":
     unittest.main()
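
For reference, the per-modality split in _assign_items_by_modality() (encode_receiver.py above) rotates a running offset so that remainder items from successive modalities land on different encoders. A standalone sketch of the same arithmetic, with the random shuffle of encoder order omitted and illustrative modality names and counts:

    def assign_by_modality(counts_per_modality, encoder_num, offset=0):
        """counts_per_modality: {modality: num_items} -> {modality: per-encoder counts}."""
        assignment = {}
        for modality, num_items in counts_per_modality.items():
            base, remainder = divmod(num_items, encoder_num)
            counts = [0] * encoder_num
            for i in range(encoder_num):
                # Rotate by the running offset so remainder items from one
                # modality do not always land on the same encoders.
                counts[(offset + i) % encoder_num] = base + (1 if i < remainder else 0)
            assignment[modality] = counts
            offset = (offset + remainder) % encoder_num
        return assignment

    # 3 images and 3 videos over 2 encoders: the image remainder goes to
    # encoder 0, so the video remainder is rotated onto encoder 1.
    print(assign_by_modality({"image": 3, "video": 3}, encoder_num=2))
    # {'image': [2, 1], 'video': [1, 2]}  -> 3 items per encoder in total
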
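Similarly, the placeholder expansion in build_input_ids() derives each item's token count as grid.prod() // spatial_merge_size**2. A small worked example, assuming a merge size of 2 as in Qwen2-VL-style configs (the grid values are illustrative; the patch reads the merge size from vision_config):

    import torch

    grid_thw = torch.tensor([1, 36, 52])  # (t, h, w) patch grid for one image
    spatial_merge_size = 2                # assumed; taken from vision_config in the patch

    num_tokens = int(grid_thw.prod()) // spatial_merge_size**2
    print(num_tokens)  # 468 placeholder tokens between vision_start and vision_end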