diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 5758ba51e18d..f0b3b4472161 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -820,6 +820,7 @@ def get_dataset(args, tokenizer, model_id=None): image_format=args.image_format, image_resolution=args.image_resolution, backend=args.backend, + random_image_count=args.random_image_count, ) elif args.dataset_name == "generated-shared-prefix": assert not tokenize_prompt @@ -1474,10 +1475,12 @@ def sample_image_requests( image_format: str, image_resolution: str, backend: str, + random_image_count: bool = False, ) -> List[DatasetRow]: """Generate requests with images. - - Each request includes ``image_count`` images. + - If ``random_image_count`` is True, each request includes a random number of images between 1 and ``image_count``. + - If ``random_image_count`` is False, each request includes exactly ``image_count`` images. - Supported resolutions: 4k (3840x2160), 1080p (1920x1080), 720p (1280x720), 360p (640x360), or custom 'heightxwidth' (e.g., 1080x1920). - Text lengths follow the 'random' dataset sampling rule. ``prompt_len`` @@ -1487,10 +1490,20 @@ def sample_image_requests( # Parse resolution (supports presets and 'heightxwidth') width, height = parse_image_resolution(image_resolution) + # Determine image counts for each request + if random_image_count: + # Random number of images per request + image_counts = np.random.randint(1, image_count + 1, size=num_requests) + total_images = np.sum(image_counts) + else: + # Fixed number of images per request + image_counts = np.full(num_requests, image_count) + total_images = image_count * num_requests + # Check for potentially problematic combinations and warn user - if width * height >= 1920 * 1080 and image_count * num_requests >= 100: + if width * height >= 1920 * 1080 and total_images >= 100: warnings.warn( - f"High resolution ({width}x{height}) with {image_count * num_requests} total images " + f"High resolution ({width}x{height}) with {total_images} total images " f"may take a long time. 
Consider reducing resolution or image count.", UserWarning, stacklevel=2, @@ -1528,6 +1541,9 @@ def _gen_random_image_data_uri( dataset: List[DatasetRow] = [] total_image_bytes = 0 for i in range(num_requests): + # Get the number of images for this request + request_image_count = int(image_counts[i]) + # Generate text prompt text_prompt = gen_mm_prompt( processor.tokenizer, @@ -1537,7 +1553,7 @@ def _gen_random_image_data_uri( # Generate image list images, images_base64, images_bytes = zip( - *[_gen_random_image_data_uri() for _ in range(image_count)] + *[_gen_random_image_data_uri() for _ in range(request_image_count)] ) total_image_bytes += sum(list(images_bytes)) @@ -1549,11 +1565,20 @@ def _gen_random_image_data_uri( processor, backend, ) dataset.append(data_row) + # Print statistics print(f"#Input tokens: {np.sum([x.prompt_len for x in dataset])}") print(f"#Output tokens: {np.sum([x.output_len for x in dataset])}") + print(f"#Total images: {total_images}") + + if random_image_count: + print( + f"#Images per request: min={np.min(image_counts)}, max={np.max(image_counts)}, mean={np.mean(image_counts):.2f}" + ) + else: + print(f"#Images per request: {image_count} (fixed)") + print( f"\nCreated {len(dataset)} {image_content} {image_format} images with average {total_image_bytes // num_requests} bytes per request" ) @@ -2700,6 +2725,11 @@ def __call__(self, parser, namespace, values, option_string=None): "Supports presets 4k/1080p/720p/360p or custom 'heightxwidth' (e.g., 1080x1920)." ), ) + parser.add_argument( + "--random-image-count", + action="store_true", + help="Randomize the number of images per request between 1 and --image-count.", + ) parser.add_argument( "--image-format", type=str, diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py index 05425d26800f..fae9e0f6b8c6 100644 --- a/python/sglang/launch_server.py +++ b/python/sglang/launch_server.py @@ -9,11 +9,15 @@ def run_server(server_args): - """Run the server based on server_args.grpc_mode.""" + """Run the server based on server_args.grpc_mode and server_args.encoder_only.""" if server_args.grpc_mode: from sglang.srt.entrypoints.grpc_server import serve_grpc asyncio.run(serve_grpc(server_args)) + elif server_args.encoder_only: + from sglang.srt.disaggregation.encode_server import launch_server + + launch_server(server_args) else: # Default mode: HTTP mode. 
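As a quick reference for the sampling in the bench_serving hunk above: numpy's randint draws from a half-open interval, which is why the patch passes image_count + 1 as the upper bound. A minimal self-contained check (the counts here are illustrative, not from the benchmark):

    import numpy as np

    image_count, num_requests = 4, 8
    # randint's upper bound is exclusive, so this yields values in 1..image_count
    image_counts = np.random.randint(1, image_count + 1, size=num_requests)
    assert image_counts.min() >= 1 and image_counts.max() <= image_count
    total_images = int(np.sum(image_counts))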
from sglang.srt.entrypoints.http_server import launch_server diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 471d0b3eea88..aaee61c0ce91 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -100,6 +100,8 @@ def __init__( model_impl: Union[str, ModelImpl] = ModelImpl.AUTO, sampling_defaults: str = "openai", quantize_and_serve: bool = False, + encoder_only: bool = False, + language_only: bool = False, ) -> None: # Parse args self.model_path = model_path @@ -216,6 +218,9 @@ def __init__( self.hf_config, "image_token_id", None ) or getattr(self.hf_config, "image_token_index", None) + self.hf_config.encoder_only = encoder_only + self.hf_config.language_only = language_only + # matryoshka embeddings self.matryoshka_dimensions = getattr( self.hf_config, "matryoshka_dimensions", None @@ -246,6 +251,8 @@ def from_server_args( sampling_defaults=server_args.sampling_defaults, quantize_and_serve=server_args.quantize_and_serve, override_config_file=server_args.decrypted_config_file, + language_only=server_args.language_only, + encoder_only=server_args.encoder_only, **kwargs, ) diff --git a/python/sglang/srt/disaggregation/encode_receiver.py b/python/sglang/srt/disaggregation/encode_receiver.py new file mode 100644 index 000000000000..c4938213909b --- /dev/null +++ b/python/sglang/srt/disaggregation/encode_receiver.py @@ -0,0 +1,544 @@ +import asyncio +import logging +import pickle +import random +import threading +import uuid +from typing import List + +import aiohttp +import torch +import zmq +import zmq.asyncio + +from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine +from sglang.srt.managers.io_struct import TokenizedGenerateReqInput +from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import get_local_ip_auto, get_zmq_socket_on_host +from sglang.srt.utils.hf_transformers_utils import get_processor + +logger = logging.getLogger(__name__) + + +class EmbeddingData: + def __init__(self, req_id, num_parts, part_idx, image_grid_dim, embedding=None): + self.req_id = req_id + self.num_parts = num_parts + self.part_idx = part_idx + self.image_grid_dim = image_grid_dim + self.embedding = embedding + self.send_time = None + self.dtype = embedding.dtype if embedding is not None else None + self.shape = list(embedding.shape) if embedding is not None else None + # aggregated data + self.ready_list = [i == self.part_idx for i in range(self.num_parts)] + self.embedding_list = [ + embedding if i == self.part_idx else None for i in range(self.num_parts) + ] + self.image_grid_dim_list = [ + self.image_grid_dim if i == self.part_idx else None + for i in range(self.num_parts) + ] + + def add(self, embedding_data): + assert self.req_id == embedding_data.req_id + assert not self.ready_list[embedding_data.part_idx] + self.ready_list[embedding_data.part_idx] = True + self.image_grid_dim_list[embedding_data.part_idx] = ( + embedding_data.image_grid_dim + ) + self.embedding_list[embedding_data.part_idx] = embedding_data.embedding + + def get_embedding(self, is_concat=False): + if is_concat: + return torch.concat([embedding.cuda() for embedding in self.embedding_list]) + else: + return self.embedding_list + + def get_img_grid(self): + return torch.concatenate(self.image_grid_dim_list) + + @property + def ready(self): + return sum(self.ready_list) == self.num_parts + + def 
__repr__(self): + return f"EmbeddingData(req_id={self.req_id}, num_parts={self.num_parts}, part_idx={self.part_idx})" + + def copy_without_embedding(self): + new_data = EmbeddingData( + req_id=self.req_id, + num_parts=self.num_parts, + part_idx=self.part_idx, + image_grid_dim=self.image_grid_dim, + ) + new_data.send_time = self.send_time + new_data.dtype = self.dtype + new_data.shape = self.shape + return new_data + + +# For zmq_to_scheduler +class WaitingImageRequest: + def __init__( + self, + rid: str, + recv_req: TokenizedGenerateReqInput, + mm_processor, + encoder_urls, + host_name, + receive_count, + embedding_port=None, + ): + self.rid = rid + self.recv_req = recv_req + self.mm_inputs = None + self.error = None + self.thread = None + self.mm_processor = mm_processor + self.encoder_urls = encoder_urls + self.host_name = host_name + self.receive_count = receive_count + self.num_items_assigned = recv_req.num_items_assigned + self.embedding_port, self.recv_socket = get_zmq_socket_on_host( + zmq.Context(), zmq.PULL + ) + logger.info(f"Waiting for input {self.embedding_port = }") + self.recv_embedding_data = None + self.ready = False + + def send_encode_request(self): + async def _send_single_request(session, url, payload): + try: + async with session.post(url, json=payload) as response: + response.raise_for_status() + return await response.text() + except Exception as e: + logger.error(f"Failed to send request to {url}: {e}") + raise + + async def send_embedding_port(req_id, receive_count, host_name, embedding_port): + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=1800) + ) as session: + tasks = [] + logger.info(f"{self.num_items_assigned = } ") + for idx, assigned_num in enumerate(self.num_items_assigned): + if assigned_num == 0: + continue + encoder_url = self.encoder_urls[idx] + target_url = f"{encoder_url}/scheduler_receive_url" + payload = { + "req_id": req_id, + "receive_count": receive_count, + "receive_url": f"{host_name}:{embedding_port}", + } + + logger.info(f"Preparing to send to {target_url}") + + task = _send_single_request(session, target_url, payload) + tasks.append(task) + + if not tasks: + logger.info("No tasks to send.") + return + logger.info(f"Concurrently sending {len(tasks)} requests...") + results = await asyncio.gather(*tasks, return_exceptions=True) + + for i, result in enumerate(results): + if isinstance(result, Exception): + logger.error(f"Request {i} failed: {result}") + else: + logger.debug(f"Request {i} succeeded.") + + asyncio.run( + send_embedding_port( + self.recv_req.rid, + self.receive_count, + self.host_name, + self.embedding_port, + ) + ) + + def _try_recv_mm_data(self): + if self.ready: + return + while self.recv_embedding_data is None or not self.recv_embedding_data.ready: + try: + parts = self.recv_socket.recv_multipart(flags=zmq.NOBLOCK, copy=False) + except zmq.Again: + # No data available yet, wait a bit and retry + return + + recv_obj: EmbeddingData = pickle.loads(parts[0]) + buffer = parts[1].buffer if hasattr(parts[1], "buffer") else parts[1] + recv_obj.embedding = torch.frombuffer(buffer, dtype=recv_obj.dtype).reshape( + recv_obj.shape + ) + recv_obj.embedding_list[recv_obj.part_idx] = recv_obj.embedding + if self.recv_embedding_data is None: + self.recv_embedding_data = recv_obj + else: + self.recv_embedding_data.add(recv_obj) + + recv_embedding = self.recv_embedding_data.get_embedding(is_concat=True) + img_grid_thw = self.recv_embedding_data.get_img_grid() + + mm_inputs = self.mm_processor.get_mm_data( + 
self.recv_req.input_text, recv_embedding, img_grid_thw + ) + self.recv_req.mm_inputs = mm_inputs + self.recv_req.input_ids = mm_inputs["input_ids"] + self.ready = True + self.recv_socket.close() + + +def _determine_tensor_transport_mode(server_args): + is_cross_node = server_args.dist_init_addr + + if is_cross_node: + # Fallback to default CPU transport for multi-node + return "default" + else: + return "cuda_ipc" + + +class MMReceiver: + + def __init__( + self, + server_args: ServerArgs, + dtype=None, + hf_config=None, + pp_rank=None, + tp_rank=None, + ): + self.context = zmq.asyncio.Context(20) + self.encoder_transfer_backend = server_args.encoder_transfer_backend + self.encode_urls = server_args.encoder_urls + self.encode_idx = list(range(len(self.encode_urls))) + self.host = server_args.host + if self.encoder_transfer_backend == "mooncake": + self.dtype = dtype + self.embeddings_engine = MooncakeTransferEngine( + hostname=get_local_ip_auto(), + gpu_id=None, + ib_device=server_args.disaggregation_ib_device, + ) + self.embeddings_buffer = dict() + elif self.encoder_transfer_backend == "zmq_to_scheduler": + self.pp_rank = pp_rank + self.tp_rank = tp_rank + self.tp_size = server_args.tp_size + self.nnodes = server_args.nnodes + self.hostname = get_local_ip_auto() + self.world_size = server_args.pp_size * server_args.tp_size + self.waiting_list: List[WaitingImageRequest] = [] + if hf_config is not None: + transport_mode = _determine_tensor_transport_mode(server_args) + import_processors("sglang.srt.multimodal.processors") + _processor = None + try: + _processor = get_processor( + server_args.tokenizer_path, + tokenizer_mode=server_args.tokenizer_mode, + trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, + use_fast=not server_args.disable_fast_image_processor, + ) + except ValueError as e: + error_message = str(e) + if "does not have a slow version" in error_message: + logger.info( + f"Processor {server_args.tokenizer_path} does not have a slow version. 
Automatically using the fast version." ) + _processor = get_processor( + server_args.tokenizer_path, + tokenizer_mode=server_args.tokenizer_mode, + trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, + use_fast=True, + ) + else: + raise e + self.mm_processor = get_mm_processor( + hf_config, server_args, _processor, transport_mode + ) + + # For zmq_to_scheduler + def process_waiting_requests(self, recv_reqs): + new_recv_reqs = [] + for recv_req in recv_reqs: + # E Disaggregation + if ( + isinstance(recv_req, TokenizedGenerateReqInput) + and recv_req.need_wait_for_image is True + ): + embedding_port = None + if recv_req.embedding_ports is not None: + embedding_port = recv_req.embedding_ports[ + self.tp_size * self.pp_rank + self.tp_rank + ] + waiting_req = WaitingImageRequest( + rid=recv_req.rid, + recv_req=recv_req, + mm_processor=self.mm_processor, + encoder_urls=self.encode_urls, + host_name=self.hostname, + receive_count=self.world_size, + embedding_port=embedding_port, + ) + if recv_req.embedding_ports is None: + waiting_req.send_encode_request() + self.waiting_list.append(waiting_req) + else: + new_recv_reqs.append(recv_req) + + if len(self.waiting_list) == 0: + return new_recv_reqs + + local_status = [] + for waiting_req in self.waiting_list: + waiting_req._try_recv_mm_data() + local_status.append(waiting_req.ready) + + local_status = torch.tensor(local_status, device="cuda", dtype=torch.int32) + + torch.distributed.all_reduce(local_status, op=torch.distributed.ReduceOp.MIN) + + new_waiting = [] + for i, waiting_req in enumerate(self.waiting_list): + if local_status[i].item(): + new_recv_reqs.append(waiting_req.recv_req) + else: + new_waiting.append(waiting_req) + + self.waiting_list = new_waiting + return new_recv_reqs + + # For zmq_to_scheduler + def _run_encode_in_thread( + self, req_id, img_data, endpoint_encode, num_items_assigned, embedding_port + ): + try: + asyncio.run( + self.encode( + req_id=req_id, + img_data=img_data, + embedding_port=embedding_port, + endpoint_encode=endpoint_encode, + endpoint_send=None, + num_items_assigned=num_items_assigned, + ) + ) + except Exception as e: + logger.error(f"Encode failed for request {req_id}: {e}", exc_info=True) + + async def encode( + self, + req_id, + img_data, + embedding_port, + endpoint_encode, + endpoint_send, + num_items_assigned=None, + ): + if len(img_data) == 0: + return + + # Split mm_items + encode_requests = [] + if num_items_assigned is None: + random.shuffle(self.encode_idx) + num_items_assigned = [ + (idx + len(img_data)) // len(self.encode_urls) + for idx in self.encode_idx + ] + num_parts = sum(1 for x in num_items_assigned if x != 0) + cum_num_items = 0 + cum_idx = 0 + for idx, assigned_num in enumerate(num_items_assigned): + if assigned_num == 0: + continue + encode_requests.append( + { + "encoder_idx": idx, + "mm_items": img_data[cum_num_items : cum_num_items + assigned_num], + "num_parts": num_parts, + "part_idx": cum_idx, + "req_id": req_id, + "prefill_host": self.host, + "embedding_port": embedding_port, + } + ) + cum_idx += 1 + cum_num_items += assigned_num + + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout( + total=1800 + ) # Add timeout for request reliability + ) as session: + # Send encode requests + + tasks = [ + session.post( + f"{self.encode_urls[encode_request['encoder_idx']]}/{endpoint_encode}", + json=encode_request, + ) + for encode_request in encode_requests + ] + + responses = await asyncio.gather(*tasks) + response_json_list_unsort = [ + await 
response.json() for response in responses + ] + + # zmq backend: return is None + if None in response_json_list_unsort: + return + + # mooncake backend: send bootstrap info + + embedding_size_list_sort = [None for _ in range(num_parts)] + embedding_length_tot = 0 + response_json_list_sort = [None for _ in range(num_parts)] + for response_json in response_json_list_unsort: + idx = response_json["part_idx"] + embedding_size_list_sort[idx] = response_json["embedding_size"] + embedding_length_tot += response_json["embedding_len"] + response_json_list_sort[idx] = response_json + + offset = 0 + metadata_tasks = [] + buffer_address = await self.allocate_embedding_buffer( + req_id, + embedding_length_tot, + response_json_list_sort[0]["embedding_dim"], + ) + for idx in range(len(tasks)): + response_json = response_json_list_sort[idx] + buffer_address_adjust = offset + buffer_address + response_json.update( + { + "session_id": self.embeddings_engine.session_id, + "buffer_address": buffer_address_adjust, + } + ) + metadata_tasks.append( + session.post( + f"{self.encode_urls[response_json['encoder_idx']]}/{endpoint_send}", + json=response_json, + ) + ) + offset += embedding_size_list_sort[idx] + await asyncio.gather(*metadata_tasks) + + # For mooncake + async def allocate_embedding_buffer(self, req_id, embedding_length, embedding_dim): + embeddings = torch.zeros( + (embedding_length, embedding_dim), + dtype=self.dtype, + ) + self.embeddings_engine.register( + embeddings.data_ptr(), + embeddings.nbytes, + ) + self.embeddings_buffer[req_id] = embeddings + return embeddings.data_ptr() + + # For zmq_to_scheduler + def send_encode_request(self, obj): + if type(obj.image_data) != list: + image_urls = [obj.image_data.url] + else: + image_urls = [img.url for img in obj.image_data] + if obj.rid is None: + obj.rid = uuid.uuid4().hex + if image_urls and len(image_urls) > 0: + logger.info(f"Processing {len(image_urls)} images for request {obj.rid}") + obj.need_wait_for_image = True + + encode_idx = list(range(len(self.encode_urls))) + random.shuffle(encode_idx) + obj.num_items_assigned = [ + (idx + len(image_urls)) // len(self.encode_urls) for idx in encode_idx + ] + obj.embedding_ports = None + encode_thread = threading.Thread( + target=self._run_encode_in_thread, + args=( + obj.rid, + image_urls, + "encode", + obj.num_items_assigned, + obj.embedding_ports, + ), + daemon=True, + ) + encode_thread.start() + + # For zmq_to_tokenizer and mooncake + async def recv_mm_data(self, img_data, mm_processor, prompt): + try: + if len(self.encode_urls) == 0: + return None + req_id = uuid.uuid4().hex + embedding_port, recv_socket = get_zmq_socket_on_host(self.context, zmq.PULL) + if type(img_data) != list: + img_data = [img_data.url] + else: + img_data = [img.url for img in img_data] + asyncio.create_task( + self.encode(req_id, img_data, embedding_port, "encode", "send") + ) + return await asyncio.wait_for( + self._recv_mm_data(req_id, recv_socket, mm_processor, prompt), + timeout=20, + ) + except asyncio.TimeoutError: + logger.warning(f"Embedding recv timeout for request {req_id}") + if hasattr(self, "embeddings_buffer") and req_id in self.embeddings_buffer: + del self.embeddings_buffer[req_id] + return None + + # For zmq_to_tokenizer and mooncake + async def _recv_mm_data(self, req_id, recv_socket, mm_processor, prompt): + # Bypass MMReceiver + if req_id is None: + return None + + recv_embedding = None + + recv_embedding_data: EmbeddingData = None + + while recv_embedding_data is None or not recv_embedding_data.ready: + 
parts = await recv_socket.recv_multipart(copy=False) + + recv_obj: EmbeddingData = pickle.loads(parts[0]) + logger.info(f"{recv_obj = }") + if self.encoder_transfer_backend == "zmq_to_tokenizer": + buffer = parts[1].buffer if hasattr(parts[1], "buffer") else parts[1] + recv_obj.embedding = torch.frombuffer( + buffer, dtype=recv_obj.dtype + ).reshape(recv_obj.shape) + if recv_embedding_data is None: + recv_obj.embedding_list[recv_obj.part_idx] = recv_obj.embedding + recv_embedding_data = recv_obj + else: + recv_embedding_data.add(recv_obj) + + if self.encoder_transfer_backend == "mooncake": + recv_embedding = self.embeddings_buffer[req_id] + del self.embeddings_buffer[req_id] + self.embeddings_engine.deregister(recv_embedding.data_ptr()) + elif self.encoder_transfer_backend == "zmq_to_tokenizer": + recv_embedding = recv_embedding_data.get_embedding(is_concat=True) + + recv_socket.close() + + img_grid_thw = recv_embedding_data.get_img_grid() + + mm_inputs = mm_processor.get_mm_data(prompt, recv_embedding, img_grid_thw) + return mm_inputs diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py new file mode 100644 index 000000000000..a9ad60f8777f --- /dev/null +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -0,0 +1,527 @@ +import asyncio +import ctypes +import logging +import multiprocessing as mp +import os +import pickle +import time +import traceback +from typing import Dict, List, Optional, Set, Tuple + +import aiohttp +import numpy as np +import torch +import uvicorn +import zmq +import zmq.asyncio +from fastapi import FastAPI +from fastapi.responses import ORJSONResponse, Response +from transformers import AutoImageProcessor +from transformers.image_utils import load_images + +from sglang.srt.configs.device_config import DeviceConfig +from sglang.srt.configs.load_config import LoadConfig +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.disaggregation.encode_receiver import EmbeddingData +from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine +from sglang.srt.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, +) +from sglang.srt.layers.dp_attention import initialize_dp_attention +from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem +from sglang.srt.mem_cache.multimodal_cache import MultiModalStaticCache +from sglang.srt.model_loader import get_model +from sglang.srt.server_args import ( + PortArgs, + ServerArgs, + set_global_server_args_for_scheduler, +) +from sglang.srt.utils import get_local_ip_auto, get_zmq_socket, random_uuid + +logger = logging.getLogger(__name__) + +rid_lock = asyncio.Lock() +rid_to_receive_endpoint: Dict[str, Set[str]] = dict() +rid_to_receive_count: Dict[str, int] = dict() + + +class TensorWrapper: + """Wrapper to keep tensor alive while exposing buffer for zero-copy.""" + + def __init__(self, tensor): + # Ensure tensor is on CPU and contiguous + if tensor.is_cuda: + tensor = tensor.cpu() + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + + # Keep tensor reference + self.tensor = tensor + self.shape = list(tensor.shape) + self.dtype = tensor.dtype + + def __buffer__(self): + data_ptr = self.tensor.data_ptr() + total_bytes = self.tensor.numel() * self.tensor.element_size() + c_obj = (ctypes.c_char * total_bytes).from_address(data_ptr) + c_obj._keep_alive_ref = self + return memoryview(c_obj) + + +def _convert(data): + if isinstance(data, torch.Tensor): + 
return data + elif isinstance(data, np.ndarray): + return torch.tensor(data) + elif isinstance(data, list) and isinstance(data[0], np.ndarray): + return torch.tensor(np.array(data)) + elif isinstance(data, list) and isinstance(data[0], (int, float)): + return torch.tensor(data) + else: + return data + + +_image_grid_attrs = ["image_grid_thw", "image_grid_hws"] + + +def _get_image_grid_dim(images_input): + for attr in _image_grid_attrs: + if attr in images_input: + return images_input[attr] + raise ValueError( + f"Image grid dim ({_image_grid_attrs}) not found in {images_input}" + ) + + +class MMEncoder: + def __init__( + self, + server_args: ServerArgs, + schedule_path=None, + dist_init_method=None, + rank: int = 0, + ): + logger.info(f"init MMEncoder {rank}/{server_args.tp_size}") + self.server_args = server_args + set_global_server_args_for_scheduler(server_args) + self.rank = rank + + self.image_processor = AutoImageProcessor.from_pretrained( + server_args.model_path, + trust_remote_code=server_args.trust_remote_code, + use_fast=True, + ) + + self.model_config = ModelConfig.from_server_args( + server_args, + ) + + self.load_config = LoadConfig( + load_format=server_args.load_format, + download_dir=server_args.download_dir, + model_loader_extra_config=server_args.model_loader_extra_config, + remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip, + remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port, + remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports, + ) + + self.device = server_args.device + self.gpu_id = server_args.base_gpu_id + rank + + self.device_config = DeviceConfig( + device=self.device, + gpu_id=self.gpu_id, + ) + + torch.get_device_module(self.device).set_device(self.gpu_id) + + init_distributed_environment( + world_size=server_args.tp_size, + rank=rank, + distributed_init_method=dist_init_method, + local_rank=rank, + ) + initialize_model_parallel(tensor_model_parallel_size=server_args.tp_size) + initialize_dp_attention(server_args, self.model_config) + + self.model = get_model( + model_config=self.model_config, + load_config=self.load_config, + device_config=self.device_config, + ) + + self.context = zmq.asyncio.Context(2) + + embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "4096")) + self.mm_cache = MultiModalStaticCache(embedding_cache_size * 1024 * 1024) + self.mm_cache_lock = asyncio.Lock() + + if schedule_path is not None: + self.schedule_socket = get_zmq_socket( + self.context, zmq.PULL, schedule_path, True + ) + + if self.rank == 0: + logger.info( + f"Using transfer backend: {self.server_args.encoder_transfer_backend}" + ) + + if self.server_args.encoder_transfer_backend == "mooncake": + self.local_ip = get_local_ip_auto() + + self.engine = MooncakeTransferEngine( + hostname=self.local_ip, + gpu_id=None, + ib_device=server_args.disaggregation_ib_device, + ) + + self.embedding_to_send = dict() + + logger.info(f"rank {rank} init finish ") + + async def _encode(self, mm_items) -> torch.Tensor: + images = load_images(mm_items) + + images_input = self.image_processor(images=images) + feature = images_input["pixel_values"] + mm_item = MultimodalDataItem.from_dict( + { + "modality": Modality.IMAGE, + "feature": _convert(feature), + } + ) + for k, v in images_input.items(): + if k == "pixel_values": + continue + mm_item.set(k, _convert(v)) + + # support 
mm_cache + mm_embedding = None + mm_hash = None + + start_time = time.perf_counter() + if self.server_args.enable_prefix_mm_cache: + mm_item.set_pad_value() + mm_hash = MultiModalStaticCache.combine_hashes([mm_item.hash]) + async with self.mm_cache_lock: + mm_cache = self.mm_cache.get([mm_item.hash]) + if mm_cache is not None: + mm_embedding = mm_cache + + if mm_embedding is None: + with torch.inference_mode(): + mm_embedding: torch.Tensor = self.model.get_image_feature([mm_item]) + mm_embedding = mm_embedding.cpu() + if len(mm_embedding.shape) != 2: + mm_embedding = mm_embedding.reshape(-1, mm_embedding.shape[-1]) + + if self.server_args.enable_prefix_mm_cache: + async with self.mm_cache_lock: + self.mm_cache.set(mm_hash, mm_embedding) + end_time = time.perf_counter() + logger.info( + f"Vit time : {(end_time - start_time)*1000:.2f} ms {mm_embedding.shape = }" + ) + + return _get_image_grid_dim(images_input), mm_embedding + + async def _send( + self, + embedding: torch.Tensor, + mm_data: EmbeddingData, + session_id=None, + buffer_address=None, + prefill_host=None, + embedding_port=None, + url=None, + ): + if self.server_args.encoder_transfer_backend == "mooncake": + self.engine.register(embedding.data_ptr(), embedding.nbytes) + self.engine.transfer_sync( + session_id, embedding.data_ptr(), buffer_address, embedding.nbytes + ) + self.engine.deregister(embedding.data_ptr()) + + mm_data.embedding = None + mm_data.embedding_list[mm_data.part_idx] = None + + # Send ack/data + endpoint = ( + f"tcp://{url}" + if url is not None + else f"tcp://{prefill_host}:{embedding_port}" + ) + logger.info(f"{endpoint = }") + socket = get_zmq_socket( + self.context, + zmq.PUSH, + endpoint, + False, + ) + + if self.server_args.encoder_transfer_backend == "mooncake": + socket.send_multipart([pickle.dumps(mm_data)]) + else: + new_mm_data = mm_data.copy_without_embedding() + embedding_tensor = TensorWrapper(mm_data.embedding) + socket.send_multipart( + [pickle.dumps(new_mm_data), embedding_tensor.__buffer__()] + ) + + async def encode(self, mm_items, req_id, num_parts, part_idx): + start_time = time.time() + image_grid_dim, mm_embedding = await self._encode(mm_items) + end_time = time.time() + logger.info(f"🕛 encode cost = {(end_time - start_time) * 1000:.2f}ms") + if self.rank == 0: + mm_data = EmbeddingData( + req_id, + num_parts, + part_idx, + image_grid_dim, + mm_embedding, + ) + self.embedding_to_send[mm_data.req_id] = mm_data + return mm_embedding.nbytes, mm_embedding.shape[0], mm_embedding.shape[1] + + # For zmq_to_tokenizer zmq_to_scheduler and mooncake + async def send( + self, req_id, prefill_host, embedding_port, session_id=None, buffer_address=None + ): + mm_data: EmbeddingData = self.embedding_to_send[req_id] + await self._send( + mm_data.embedding, + mm_data, + session_id=session_id, + buffer_address=buffer_address, + prefill_host=prefill_host, + embedding_port=embedding_port, + ) + + # For zmq_to_scheduler + async def send_with_url( + self, + req_id, + ): + mm_data = self.embedding_to_send.get(req_id) + if not mm_data: + return + sent_urls: Set[str] = set() + all_tasks: List[Tuple[asyncio.Task, str]] = [] + start_time = asyncio.get_running_loop().time() + timeout = 60.0 + + try: + while True: + async with rid_lock: + current_targets = rid_to_receive_endpoint.get(req_id, set()).copy() + expected_count = rid_to_receive_count.get(req_id) + + new_targets = current_targets - sent_urls + + if new_targets: + logger.info( + f"Found {len(new_targets)} new endpoints for {req_id}. Starting tasks..." 
+ ) + for url in new_targets: + task = asyncio.create_task( + self._send( + mm_data.embedding, + mm_data, + url=url, + ) + ) + all_tasks.append((task, url)) + sent_urls.add(url) # Mark as handled immediately + if expected_count is not None and len(sent_urls) >= expected_count: + logger.info( + f"All {expected_count} endpoints initiated for {req_id}. Breaking loop." + ) + break + + if asyncio.get_running_loop().time() - start_time > timeout: + logger.error( + f"Timeout waiting for all endpoints for {req_id}. Initiated {len(sent_urls)}/{expected_count}" + ) + break + + await asyncio.sleep(0.001) + + if all_tasks: + logger.info( + f"Loop finished. Awaiting completion of {len(all_tasks)} sending tasks..." + ) + tasks_only = [t[0] for t in all_tasks] + results = await asyncio.gather(*tasks_only, return_exceptions=True) + + # Process results and log errors + for i, result in enumerate(results): + url = all_tasks[i][1] # Retrieve URL associated with the task + if isinstance(result, Exception): + logger.error(f"Failed to send to {url}: {result}") + else: + logger.debug(f"Successfully sent to {url}") + + logger.info(f"All tasks completed for req_id: {req_id}") + + finally: + logger.info(f"Cleaning up resources for req_id {req_id}") + async with rid_lock: + rid_to_receive_endpoint.pop(req_id, None) + rid_to_receive_count.pop(req_id, None) + self.embedding_to_send.pop(req_id, None) + + async def get_embedding_port(self, prefill_url): + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=1800) + ) as session: + response = await session.post( + f"{prefill_url}/embedding_bootstrap", + json={"embedding_port": None}, + ) + response_json = await response.json() + return response_json["embedding_port"] + + +app = FastAPI() +encoder: Optional[MMEncoder] = None +send_sockets: List[zmq.Socket] = [] + + +async def run_encoder( + server_args: ServerArgs, schedule_path, dist_init_method, rank: int +): + encoder = MMEncoder(server_args, schedule_path, dist_init_method, rank) + while True: + request = await encoder.schedule_socket.recv_pyobj() + await encoder.encode( + mm_items=request["mm_items"], + req_id=request["req_id"], + num_parts=request["num_parts"], + part_idx=request["part_idx"], + ) + + +def launch_encoder(server_args, schedule_path, dist_init_method, rank): + try: + asyncio.run(run_encoder(server_args, schedule_path, dist_init_method, rank)) + except KeyboardInterrupt: + logger.info(f"Exit rank {rank}") + except Exception: + traceback.print_exc() + + +def launch_server(server_args: ServerArgs): + global encoder + ctx = mp.get_context("spawn") + zmq_ctx = zmq.Context(10) + ipc_path_prefix = random_uuid() + port_args = PortArgs.init_new(server_args) + if server_args.dist_init_addr: + dist_init_method = f"tcp://{server_args.dist_init_addr}" + else: + dist_init_method = f"tcp://127.0.0.1:{port_args.nccl_port}" + for rank in range(1, server_args.tp_size): + schedule_path = f"ipc:///tmp/{ipc_path_prefix}_schedule_{rank}" + send_sockets.append( + get_zmq_socket(zmq_ctx, zmq.PUSH, schedule_path, bind=False) + ) + ctx.Process( + target=launch_encoder, + args=(server_args, schedule_path, dist_init_method, rank), + daemon=True, + ).start() + encoder = MMEncoder(server_args, dist_init_method=dist_init_method) + uvicorn.run(app, host=server_args.host, port=server_args.port) + + +@app.post("/encode") +async def handle_encode_request(request: dict): + # broadcast request + request.update({"enter_time": time.time()}) + for socket in send_sockets: + socket.send_pyobj(request) + + nbytes, 
embedding_len, embedding_dim = await encoder.encode( + mm_items=request["mm_items"], + req_id=request["req_id"], + num_parts=request["num_parts"], + part_idx=request["part_idx"], + ) + if encoder.server_args.encoder_transfer_backend == "mooncake": + del request["mm_items"] + request.update( + { + "embedding_size": nbytes, + "embedding_len": embedding_len, + "embedding_dim": embedding_dim, + } + ) + return ORJSONResponse(content=request) + elif encoder.server_args.encoder_transfer_backend == "zmq_to_scheduler": + logger.info(f"{request['embedding_port'] = }") + if request["embedding_port"] is None: + await encoder.send_with_url( + req_id=request["req_id"], + ) + else: + assert type(request["embedding_port"]) == list + tasks = [] + for embedding_port in request["embedding_port"]: + tasks.append( + encoder.send( + req_id=request["req_id"], + prefill_host=request["prefill_host"], + embedding_port=embedding_port, + ) + ) + await asyncio.gather(*tasks) + encoder.embedding_to_send.pop(request["req_id"], None) + return ORJSONResponse(content=None) + elif encoder.server_args.encoder_transfer_backend == "zmq_to_tokenizer": + await encoder.send( + req_id=request["req_id"], + prefill_host=request["prefill_host"], + embedding_port=request["embedding_port"], + ) + encoder.embedding_to_send.pop(request["req_id"], None) + return ORJSONResponse(content=None) + + +@app.post("/send") +async def handle_send_request(request: dict): + # mooncake backend + await encoder.send( + req_id=request["req_id"], + prefill_host=request["prefill_host"], + embedding_port=request["embedding_port"], + session_id=request["session_id"], + buffer_address=request["buffer_address"], + ) + encoder.embedding_to_send.pop(request["req_id"], None) + return ORJSONResponse(content=None) + + +@app.post("/scheduler_receive_url") +async def handle_scheduler_receive_url_request(request: dict): + rid = request["req_id"] + async with rid_lock: + global rid_to_receive_endpoint + if rid not in rid_to_receive_endpoint: + rid_to_receive_endpoint[rid] = set() + rid_to_receive_count[rid] = request["receive_count"] + assert rid_to_receive_count[rid] == request["receive_count"] + rid_to_receive_endpoint[rid].add(request["receive_url"]) + + +@app.get("/health") +@app.get("/health_generate") +async def health_generate(): + """ + Health check endpoint for the encoder server. + Returns 200 if the encoder is initialized and ready. 
+ """ + if encoder is None: + return Response(status_code=503) + return Response(status_code=200) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index e71bfce48782..3ff0f3b7f1ef 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -245,6 +245,10 @@ class GenerateReqInput(BaseReq, APIServingTimingMixin): # Whether to return entropy return_entropy: bool = False + need_wait_for_image: Optional[bool] = None + num_items_assigned: Optional[List] = None + embedding_ports: Optional[List] = None + def contains_mm_input(self) -> bool: return ( has_valid_data(self.image_data) @@ -724,6 +728,10 @@ class TokenizedGenerateReqInput(BaseReq): # Whether to return entropy return_entropy: bool = False + need_wait_for_image: bool = False + num_items_assigned: Optional[List] = None + embedding_ports: Optional[List] = None + @dataclass class BatchTokenizedGenerateReqInput(BaseBatchReq): diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py index 75c646a63ea3..9c3864331628 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -395,13 +395,49 @@ def get_embedding_chunk( def _get_precomputed_embedding( items: List[MultimodalDataItem], + prefix_length: List[int], + extend_length: List[int], + items_offset_list: List[List[Tuple[int, int]]], ) -> Optional[torch.Tensor]: """ If all items have precomputed_embeddings, return their concatenation. If some but not all have precomputed_embeddings, raise NotImplementedError. If none have precomputed_embeddings, return None. """ - precomputed_embeddings = [item.precomputed_embeddings for item in items] + precomputed_embeddings = [] + for idx, item in enumerate(items): + if item.precomputed_embeddings is None: + precomputed_embeddings.append(None) + continue + seq_start_idx = prefix_length[idx] + seq_end_idx = seq_start_idx + extend_length[idx] - 1 + prefix_embedding_length = [] + extend_embedding_length = [] + for mm_start_idx, mm_end_idx in items_offset_list[idx]: + if mm_start_idx > seq_end_idx: + break + if seq_start_idx > mm_start_idx: + prefix_embedding_length.append( + min(seq_start_idx - mm_start_idx, mm_end_idx - mm_start_idx + 1) + ) + if mm_end_idx >= seq_start_idx: + extend_embedding_length.append( + min( + mm_end_idx - seq_start_idx + 1, + seq_end_idx - mm_start_idx + 1, + mm_end_idx - mm_start_idx + 1, + seq_end_idx - seq_start_idx + 1, + ) + ) + prefix_embedding_length = int(np.sum(prefix_embedding_length)) + extend_embedding_length = int(np.sum(extend_embedding_length)) + precomputed_embeddings.append( + item.precomputed_embeddings[ + prefix_embedding_length : prefix_embedding_length + + extend_embedding_length + ] + ) + if any(feature is not None for feature in precomputed_embeddings): if not all(feature is not None for feature in precomputed_embeddings): raise NotImplementedError( @@ -527,7 +563,9 @@ def get_embedding_and_mask( - A boolean mask tensor indicating where these embeddings should be placed """ # 1. 
Get embedding - embedding = _get_precomputed_embedding(embedding_items) + embedding = _get_precomputed_embedding( + embedding_items, prefix_length, extend_length, items_offset_list + ) if embedding is None: embedding = _get_chunked_prefill_embedding( data_embedding_func, diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 075f26df3226..18edaa3c3507 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -47,6 +47,7 @@ from sglang.srt.disaggregation.decode_kvcache_offload_manager import ( DecodeKVCacheOffloadManager, ) +from sglang.srt.disaggregation.encode_receiver import MMReceiver from sglang.srt.disaggregation.prefill import ( PrefillBootstrapQueue, SchedulerDisaggregationPrefillMixin, @@ -561,6 +562,17 @@ def __init__( # Init mlp sync flag self.require_mlp_sync = require_mlp_sync(server_args) + if ( + self.server_args.language_only + and self.server_args.encoder_transfer_backend == "zmq_to_scheduler" + ): + self.mm_receiver = MMReceiver( + server_args, + hf_config=self.model_config.hf_config, + tp_rank=self.tp_rank, + pp_rank=self.pp_rank, + ) + # Init request dispatcher self._request_dispatcher = TypeBasedDispatcher( [ @@ -1203,6 +1215,14 @@ def recv_requests( return recv_reqs def process_input_requests(self, recv_reqs: List): + + # Process MM requests under E disaggregation + if ( + self.server_args.language_only + and self.server_args.encoder_transfer_backend == "zmq_to_scheduler" + ): + recv_reqs = self.mm_receiver.process_waiting_requests(recv_reqs) + for recv_req in recv_reqs: # If it is a health check generation request and there are running requests, ignore it. if is_health_check_generate_req(recv_req) and ( diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 4c5134ac3da5..32fb8a1f54ce 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -38,6 +38,7 @@ from fastapi import BackgroundTasks from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.disaggregation.encode_receiver import MMReceiver from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.environ import envs from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry @@ -287,6 +288,13 @@ def __init__( # Make sure that each request carries the tokenizer_ipc_name for response routing self.send_to_scheduler = SenderWrapper(port_args, send_to_scheduler) + # E Disaggregation + if self.server_args.language_only: + self.mm_receiver = MMReceiver( + server_args, + dtype=self.model_config.dtype, + ) + # Request states self._chosen_loop = None self.rid_to_state: Dict[str, ReqState] = {} @@ -411,6 +419,13 @@ async def generate_request( created_time = obj.received_time if obj.received_time else time.time() self.auto_create_handle_loop() obj.normalize_batch_and_arguments() + if ( + self.server_args.language_only + and isinstance(obj, GenerateReqInput) + and self.server_args.encoder_transfer_backend == "zmq_to_scheduler" + and obj.contains_mm_input() + ): + self.mm_receiver.send_encode_request(obj) if self.enable_trace: self._trace_request_start(obj, created_time, request) @@ -615,13 +630,29 @@ async def _tokenize_one_request( obj.image_data = [obj.image_data] if obj.audio_data is not None and not isinstance(obj.audio_data, list): obj.audio_data = [obj.audio_data] - mm_inputs: Dict = await self.mm_data_processor.process( - image_data=obj.image_data, - 
audio_data=obj.audio_data, - input_text_or_ids=(input_text or input_ids), - request_obj=obj, - max_req_input_len=self.max_req_input_len, - ) + + mm_inputs = None + + if ( + not self.server_args.language_only + or self.server_args.encoder_transfer_backend + in ["zmq_to_tokenizer", "mooncake"] + ): + if self.server_args.language_only: + mm_inputs = await self.mm_receiver.recv_mm_data( + img_data=obj.image_data, + mm_processor=self.mm_processor, + prompt=(input_text or input_ids), + ) + if mm_inputs is None: + mm_inputs: Dict = await self.mm_data_processor.process( + image_data=obj.image_data, + audio_data=obj.audio_data, + input_text_or_ids=(input_text or input_ids), + request_obj=obj, + max_req_input_len=self.max_req_input_len, + ) + if mm_inputs and "input_ids" in mm_inputs: input_ids = mm_inputs["input_ids"] else: @@ -801,6 +832,9 @@ def _create_tokenized_object( data_parallel_rank=obj.data_parallel_rank, priority=obj.priority, extra_key=obj.extra_key, + need_wait_for_image=obj.need_wait_for_image, + num_items_assigned=obj.num_items_assigned, + embedding_ports=obj.embedding_ports, ) elif isinstance(obj, EmbeddingReqInput): tokenized_obj = TokenizedEmbeddingReqInput( diff --git a/python/sglang/srt/models/dots_vlm.py b/python/sglang/srt/models/dots_vlm.py index 1de27f664645..ea113b60c54d 100644 --- a/python/sglang/srt/models/dots_vlm.py +++ b/python/sglang/srt/models/dots_vlm.py @@ -50,9 +50,10 @@ def __init__( self.video_token_id = config.video_span_id self.pp_group = get_pp_group() - self.language_model = DeepseekV2ForCausalLM( - config.language_config, quant_config - ) + if not config.encoder_only: + self.language_model = DeepseekV2ForCausalLM( + config.language_config, quant_config + ) # Initialize vision tower (matching transformers naming for weight compatibility) self.vision_tower = DotsVisionTransformer(config.vision_config) @@ -104,18 +105,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): language_weights.append((name, loaded_weight)) # Load vision tower weights - vision_state_dict = dict(vision_weights) - params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in vision_state_dict.items(): - if name not in params_dict: - raise ValueError(f"Weight {name} not found in params_dict") - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight) - weight_loader(param, loaded_weight) + if not self.config.language_only: + vision_state_dict = dict(vision_weights) + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in vision_state_dict.items(): + if name not in params_dict: + raise ValueError(f"Weight {name} not found in params_dict") + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight) + weight_loader(param, loaded_weight) # Load language model weights - if language_weights: + if not self.config.encoder_only and language_weights: self.language_model.load_weights(language_weights) @classmethod diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index 7c1d09852276..9336dbaf8756 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -558,6 +558,31 @@ def __init__( self.pp_group = get_pp_group() self.config = config self.use_data_parallel = 
get_global_server_args().mm_enable_dp_encoder + + if not self.config.encoder_only: + self.model = Qwen2Model( + config, + quant_config, + prefix=add_prefix("model", prefix), + ) + + if self.pp_group.is_last_rank: + if self.pp_group.world_size == 1 and self.config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + ) + else: + # ranks other than the last rank will have a placeholder layer + self.lm_head = PPMissingLayer() + else: + # encoder_only mode: no language model, so no lm_head needed + self.lm_head = None + self.visual = Qwen2_5_VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), @@ -569,26 +594,6 @@ def __init__( max_context_len=self.config.max_position_embeddings, ) - self.model = Qwen2Model( - config, - quant_config, - prefix=add_prefix("model", prefix), - ) - - if self.pp_group.is_last_rank: - if self.pp_group.world_size == 1 and self.config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens - else: - self.lm_head = ParallelLMHead( - self.config.vocab_size, - self.config.hidden_size, - quant_config=quant_config, - prefix=add_prefix("lm_head", prefix), - ) - else: - # ranks other than the last rank will have a placeholder layer - self.lm_head = PPMissingLayer() - self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling self.logits_processor = LogitsProcessor(config) @@ -751,6 +756,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): layer_id = get_layer_id(name) if ( layer_id is not None + and hasattr(self, "model") and hasattr(self.model, "start_layer") and ( layer_id < self.model.start_layer @@ -762,6 +768,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Skip loading visual/language model weights + if ( + self.config.encoder_only or self.config.language_only + ) and name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -778,7 +789,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name in params_dict.keys(): param = params_dict[name] else: - raise ValueError(f"Weight {name} not found in params_dict") + if get_global_server_args().encoder_only: + continue + else: + raise ValueError(f"Weight {name} not found in params_dict") except KeyError: print(params_dict.keys()) raise diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py index ed52f7ff40f2..b49e03ccd6b5 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -601,6 +601,7 @@ def __init__( super().__init__() self.use_data_parallel = get_global_server_args().mm_enable_dp_encoder + self.visual = Qwen3VLMoeVisionModel( config.vision_config, # NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization. 
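The qwen2_5_vl.py hunk above shows the construction pattern that recurs across the touched models: each role flag prunes the submodules that process will never run, so an encoder-only server never allocates the language model and a language-only server never allocates the vision tower. A minimal sketch of the idea, with toy layers standing in for the real ViT and language model (these are not the actual sglang classes):

    import torch.nn as nn

    class ToyVLM(nn.Module):
        def __init__(self, encoder_only: bool = False, language_only: bool = False):
            super().__init__()
            assert not (encoder_only and language_only)
            # language_only: skip the vision tower; encoder_only: skip the LM and head
            self.visual = None if language_only else nn.Linear(16, 32)
            self.model = None if encoder_only else nn.Linear(32, 32)
            self.lm_head = None if encoder_only else nn.Linear(32, 100)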
@@ -616,22 +617,28 @@ def __init__( self.config: Qwen3VLConfig = config # for qwen3-vl else: self.config = config.text_config # for qwen3-omni + self.config.encoder_only = getattr(config, "encoder_only", False) + self.config.language_only = getattr(config, "language_only", False) - self.model = language_model_cls( - config=self.config, - quant_config=quant_config, - prefix=add_prefix("model", prefix), - ) - - if self.config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens - else: - self.lm_head = ParallelLMHead( - self.config.vocab_size, - self.config.hidden_size, + if not hasattr(config, "encoder_only") or not config.encoder_only: + self.model = language_model_cls( + config=self.config, quant_config=quant_config, - prefix=add_prefix("lm_head", prefix), + prefix=add_prefix("model", prefix), ) + + if self.config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + ) + else: + # encoder_only mode: no language model, so no lm_head needed + self.lm_head = None self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling self.logits_processor = LogitsProcessor(self.config) @@ -640,7 +647,7 @@ def __init__( # 8, 16, 24 layer will be merged to 0, 1, 2 layer of decoder output hidden_states # deepstack - self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes + self.deepstack_visual_indexes = config.vision_config.deepstack_visual_indexes self.num_deepstack_embeddings = len(self.deepstack_visual_indexes) self.use_deepstack = {Modality.IMAGE: True, Modality.VIDEO: True} @@ -774,6 +781,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Skip loading visual/language model weights + if ( + self.config.encoder_only or self.config.language_only + ) and name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -788,6 +800,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
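The same skip rule recurs in each load_weights change in this patch: when a partial role is active, weights addressed to the pruned half simply have no entry in params_dict and are dropped instead of raising, while a full model still treats the mismatch as an error. A hedged sketch of that filter (the helper name is illustrative, not from the patch):

    def iter_loadable_weights(weights, params_dict, partial_mode):
        # Yield only weights that exist in this (possibly partial) model.
        for name, tensor in weights:
            if name not in params_dict:
                if partial_mode:
                    continue  # weight belongs to the absent submodule
                raise ValueError(f"Weight {name} not found in params_dict")
            yield name, tensor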
if name.endswith(".bias") and name not in params_dict: continue + # Skip loading visual/language model weights + if ( + self.config.encoder_only or self.config.language_only + ) and name not in params_dict: + continue param = params_dict[name] except KeyError: print(params_dict.keys()) diff --git a/python/sglang/srt/models/qwen3_vl_moe.py b/python/sglang/srt/models/qwen3_vl_moe.py index e3e9e07d1962..b13a221f83fa 100644 --- a/python/sglang/srt/models/qwen3_vl_moe.py +++ b/python/sglang/srt/models/qwen3_vl_moe.py @@ -243,7 +243,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue - if "visual" in name: + if "visual" in name or self.config.encoder_only: continue # Anyway, this is an expert weight and should not be # attempted to load as other weights later @@ -309,6 +309,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(ignore_suffixes) and name not in params_dict: continue + # Skip loading mm/language parameters + if ( + self.config.encoder_only or self.config.language_only + ) and name not in params_dict: + continue + if name in params_dict.keys(): param = params_dict[name] weight_loader = getattr( diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index 370aec2b65ab..6553178e6ce8 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -231,6 +231,64 @@ def __init__( MM_ITEM_MEMORY_POOL_RECYCLE_INTERVAL, ) + @property + def spatial_merge_size(self): + return self.hf_config.vision_config.spatial_merge_size + + def build_input_ids(self, prompt, img_grid_thw): + """ + Use prompt and img_grid_thw to build input_ids + """ + if not isinstance(prompt, list): + prompt = self._processor.tokenizer.encode(prompt) + + img_token_id = self.IM_TOKEN_ID + spatial_merge_size = self.spatial_merge_size + + input_ids = [] + offsets = [] + + cur_idx = 0 + + # Use img_token_id instead of im_start_id, because a dummy im_start_id + # may be generated by the tokenizer. 
+ img_start_indices = list( + filter(lambda i: prompt[i + 1] == img_token_id, range(len(prompt) - 1)) + ) + + for cur_img_idx, img_start_idx in enumerate(img_start_indices): + assert cur_idx <= img_start_idx + # include img_start_id + input_ids.extend(prompt[cur_idx : img_start_idx + 1]) + img_offset_start = len(input_ids) + img_token_num = img_grid_thw[cur_img_idx].prod() // (spatial_merge_size**2) + input_ids.extend([img_token_id] * img_token_num) + # jump to img_end_id + cur_idx = img_start_idx + 2 + offsets.append((img_offset_start, len(input_ids) - 1)) + else: + input_ids.extend(prompt[cur_idx:]) + + return input_ids, offsets + + def get_mm_data(self, prompt, embeddings, img_grid_thw): + input_ids, offsets = self.build_input_ids(prompt, img_grid_thw) + mm_items = [ + MultimodalDataItem( + modality=Modality.IMAGE, + offsets=offsets, + precomputed_embeddings=embeddings, + ) + ] + + return { + "input_ids": input_ids, + "mm_items": mm_items, + "im_start_id": self.IM_START_TOKEN_ID, + "im_end_id": self.IM_END_TOKEN_ID, + "im_token_id": self.IM_TOKEN_ID, + } + def process_mm_data( self, input_text, images=None, videos=None, audios=None, **kwargs ) -> dict: diff --git a/python/sglang/srt/multimodal/processors/dots_vlm.py b/python/sglang/srt/multimodal/processors/dots_vlm.py index 65752244da39..8d6faf5e8748 100644 --- a/python/sglang/srt/multimodal/processors/dots_vlm.py +++ b/python/sglang/srt/multimodal/processors/dots_vlm.py @@ -24,8 +24,8 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): self.im_end_id = _processor.tokenizer.encode("<|endofimg|>")[0] self.image_token_id = _processor.tokenizer.encode("<|imgpad|>")[0] self.IM_TOKEN_ID = self.image_token_id - self.IM_START_ID = self.im_start_id - self.IM_END_ID = self.im_end_id + self.IM_START_TOKEN_ID = self.im_start_id + self.IM_END_TOKEN_ID = self.im_end_id vision_config = hf_config.vision_config patch_size = vision_config.patch_size diff --git a/python/sglang/srt/multimodal/processors/qwen_vl.py b/python/sglang/srt/multimodal/processors/qwen_vl.py index 88fe63978208..324a9b8fb939 100644 --- a/python/sglang/srt/multimodal/processors/qwen_vl.py +++ b/python/sglang/srt/multimodal/processors/qwen_vl.py @@ -12,6 +12,7 @@ from sglang.srt.environ import envs from sglang.srt.layers.rotary_embedding import MRotaryEmbedding +from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration from sglang.srt.models.qwen3_omni_moe import Qwen3OmniMoeForConditionalGeneration @@ -235,6 +236,10 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): super().__init__(hf_config, server_args, _processor, *args, **kwargs) + self.IM_START_TOKEN_ID = hf_config.vision_start_token_id + self.IM_END_TOKEN_ID = hf_config.vision_end_token_id + self.IM_TOKEN_ID = hf_config.image_token_id + self.vision_start_token_id = hf_config.vision_start_token_id self.vision_end_token_id = getattr(hf_config, "vision_end_token_id", None) @@ -255,6 +260,42 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): audio_token_id=self.audio_token_id, ).build(_processor) + def get_mm_data(self, prompt, embeddings, img_grid_thw): + input_ids, offsets = self.build_input_ids(prompt, img_grid_thw) + mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index( + spatial_merge_size=self.hf_config.vision_config.spatial_merge_size, + 
+        mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
+            spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
+            image_token_id=self.mm_tokens.image_token_id,
+            video_token_id=self.mm_tokens.video_token_id,
+            vision_start_token_id=self.vision_start_token_id,
+            model_type=self.model_type,
+            input_ids=torch.tensor(input_ids, dtype=torch.long).unsqueeze(0),
+            image_grid_thw=img_grid_thw,
+            tokens_per_second=getattr(
+                self.hf_config.vision_config, "tokens_per_second", None
+            ),
+        )
+        mrope_positions = mrope_positions.squeeze(1)
+
+        mm_items = [
+            MultimodalDataItem(
+                modality=Modality.IMAGE,
+                offsets=offsets,
+                precomputed_embeddings=embeddings,
+            )
+        ]
+
+        return {
+            "input_ids": input_ids,
+            "mm_items": mm_items,
+            "im_start_id": self.IM_START_TOKEN_ID,
+            "im_end_id": self.IM_END_TOKEN_ID,
+            "im_token_id": self.mm_tokens.image_token_id,
+            "video_token_id": self.mm_tokens.video_token_id,
+            "audio_token_id": self.mm_tokens.audio_token_id,
+            "mrope_positions": mrope_positions,
+            "mrope_position_delta": mrope_position_delta,
+        }
+
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 42294937b38e..ed2b4f97de81 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -137,6 +137,8 @@
 
 DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"]
 
+ENCODER_TRANSFER_BACKEND_CHOICES = ["zmq_to_scheduler", "zmq_to_tokenizer", "mooncake"]
+
 GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
 
 DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"]
@@ -269,6 +271,12 @@ class ServerArgs:
     nccl_port: Optional[int] = None
     checkpoint_engine_wait_weights_before_ready: bool = False
 
+    # Encode prefill disaggregation
+    encoder_only: bool = False
+    language_only: bool = False
+    encoder_transfer_backend: str = ENCODER_TRANSFER_BACKEND_CHOICES[0]
+    encoder_urls: List[str] = dataclasses.field(default_factory=list)
+
     # Quantization and data type
     dtype: str = "auto"
     quantization: Optional[str] = None
@@ -606,6 +614,7 @@ class ServerArgs:
     mm_max_concurrent_calls: int = 32
     mm_per_request_timeout: float = 10.0
     enable_broadcast_mm_inputs_process: bool = False
+    enable_prefix_mm_cache: bool = False
     mm_enable_dp_encoder: bool = False
     mm_process_config: Optional[Dict[str, Any]] = None
 
@@ -675,7 +684,10 @@ def __post_init__(self):
         self._handle_load_format()
 
         # Handle PD disaggregation.
-        self._handle_disaggregation()
+        self._handle_pd_disaggregation()
+
+        # Handle Encoder disaggregation.
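+        # Note: --encoder-only and --language-only are mutually exclusive; an
+        # encoder instance serves only the vision tower, while a language-only
+        # instance pulls embeddings from the servers listed in --encoder-urls.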
+        self._handle_encoder_disaggregation()
 
         # Validate tokenizer settings.
         self._handle_tokenizer_batching()
@@ -1990,7 +2002,30 @@ def _handle_load_format(self):
         ):
             self.load_format = "auto"
 
-    def _handle_disaggregation(self):
+    def _handle_encoder_disaggregation(self):
+        if self.enable_prefix_mm_cache and not self.encoder_only:
+            raise ValueError(
+                "--enable-prefix-mm-cache requires --encoder-only to be enabled"
+            )
+        if self.encoder_only and self.language_only:
+            raise ValueError("Cannot set --encoder-only and --language-only together")
+        if self.encoder_only and self.disaggregation_mode != "null":
+            raise ValueError(
+                "Cannot set --encoder-only and --disaggregation-mode prefill/decode together"
+            )
+        if (
+            self.language_only
+            and self.encoder_transfer_backend == "zmq_to_scheduler"
+            and self.pp_size > 1
+        ):
+            raise ValueError(
+                "The zmq_to_scheduler backend does not support pp_size > 1"
+            )
+
+        if self.language_only and len(self.encoder_urls) == 0:
+            raise ValueError(
+                "--language-only requires at least one encoder URL to be set via --encoder-urls"
+            )
+
+    def _handle_pd_disaggregation(self):
         if self.disaggregation_mode == "decode":
             assert (
                 self.disaggregation_decode_tp is None
@@ -2395,6 +2430,32 @@ def add_cli_args(parser: argparse.ArgumentParser):
             "before serving inference requests.",
         )
 
+        # Encode prefill disaggregation
+        parser.add_argument(
+            "--encoder-only",
+            action="store_true",
+            help="For an MLLM with a vision encoder, launch an encoder-only server.",
+        )
+        parser.add_argument(
+            "--language-only",
+            action="store_true",
+            help="For a VLM, load weights for the language model only.",
+        )
+        parser.add_argument(
+            "--encoder-transfer-backend",
+            type=str,
+            default=ServerArgs.encoder_transfer_backend,
+            choices=ENCODER_TRANSFER_BACKEND_CHOICES,
+            help="The backend for encoder disaggregation transfer. Default is zmq_to_scheduler.",
+        )
+        parser.add_argument(
+            "--encoder-urls",
+            nargs="+",
+            type=str,
+            default=[],
+            help="List of encoder server URLs.",
+        )
+
         # Quantization and data type
         parser.add_argument(
             "--dtype",
@@ -4153,6 +4214,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
             default=ServerArgs.decrypted_draft_config_file,
             help="The path of the decrypted draft config file.",
         )
+        parser.add_argument(
+            "--enable-prefix-mm-cache",
+            action="store_true",
+            default=ServerArgs.enable_prefix_mm_cache,
+            help="Enable the prefix multimodal cache. Currently only supported "
+            "with --encoder-only.",
+        )
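+        # A typical EPD deployment (hypothetical ports) combines these flags as:
+        #   encoder server: --encoder-only --encoder-transfer-backend zmq_to_scheduler
+        #   prefill server: --language-only --encoder-urls http://127.0.0.1:30100
+        #                   --disaggregation-mode prefill
+        #   decode server:  --disaggregation-mode decode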
 
         # For registering hooks
         parser.add_argument(
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 52389e8c21f8..de7c05109596 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -152,6 +152,7 @@
         TestFile("test_multi_instance_release_memory_occupation.py", 64),
         TestFile("test_pp_single_node.py", 800),
         TestFile("test_piecewise_cuda_graph.py", 1200),
+        TestFile("test_epd_disaggregation.py", 600),
     ],
     "per-commit-8-gpu-h200": [
         TestFile("test_deepseek_v3_basic.py", 275),
diff --git a/test/srt/test_epd_disaggregation.py b/test/srt/test_epd_disaggregation.py
new file mode 100644
index 000000000000..07f79a97a9d4
--- /dev/null
+++ b/test/srt/test_epd_disaggregation.py
@@ -0,0 +1,424 @@
+import os
+import subprocess
+import threading
+import unittest
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.server_fixtures.disaggregation_fixture import (
+    PDDisaggregationServerBase,
+)
+from sglang.test.test_utils import (
+    DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    popen_launch_server,
+)
+
+
+class TestEPDDisaggregationOneEncoder(PDDisaggregationServerBase):
+    """Test EPD disaggregation with a single encode server"""
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        cls.model = DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST
+        cls.encode_port = f"{int(cls.lb_port) + 300}"
+        cls.encode_url = f"http://{cls.base_host}:{cls.encode_port}"
+
+        print(
+            f"Setting up EPD (one encoder): encode={cls.encode_port}, "
+            f"prefill={cls.prefill_port}, decode={cls.decode_port}"
+        )
+
+        # Start servers in order: encode -> prefill/decode
+        cls.start_encode()
+        prefill_thread = threading.Thread(target=cls.start_prefill)
+        decode_thread = threading.Thread(target=cls.start_decode)
+        prefill_thread.start()
+        decode_thread.start()
+        prefill_thread.join()
+        decode_thread.join()
+
+        # Wait for all servers to be ready
+        cls.wait_server_ready(cls.encode_url + "/health")
+        cls.wait_server_ready(cls.prefill_url + "/health")
+        cls.wait_server_ready(cls.decode_url + "/health")
+
+        cls.launch_lb()
+
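+        # Resulting single-node topology: encode on GPU 0, prefill on GPU 1,
+        # decode on GPU 2, with the load balancer routing OpenAI-style requests.
+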
+        # Set OpenAI API key and base URL environment variables. Needed for
+        # lmms-eval to work.
+        cls.api_key = "sk-123456"
+        os.environ["OPENAI_API_KEY"] = cls.api_key
+        os.environ["OPENAI_API_BASE"] = f"{cls.lb_url}/v1"
+
+    @classmethod
+    def start_encode(cls):
+        """Start encode server for multimodal processing"""
+        encode_args = [
+            "--trust-remote-code",
+            "--encoder-only",
+            "--encoder-transfer-backend",
+            "zmq_to_scheduler",
+            "--tp",
+            "1",
+            "--port",
+            cls.encode_port,
+            "--enable-prefix-mm-cache",
+        ]
+        cls.process_encode = popen_launch_server(
+            cls.model,
+            base_url=cls.encode_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=encode_args,
+        )
+
+    @classmethod
+    def start_prefill(cls):
+        """Start prefill server with language model only"""
+        prefill_args = [
+            "--trust-remote-code",
+            "--language-only",
+            "--encoder-urls",
+            cls.encode_url,
+            "--encoder-transfer-backend",
+            "zmq_to_scheduler",
+            "--disaggregation-mode",
+            "prefill",
+            "--tp",
+            "1",
+            "--base-gpu-id",
+            "1",
+            "--port",
+            cls.prefill_port,
+        ]
+        prefill_args += cls.transfer_backend + cls.rdma_devices
+        cls.process_prefill = popen_launch_server(
+            cls.model,
+            base_url=cls.prefill_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=prefill_args,
+        )
+
+    @classmethod
+    def start_decode(cls):
+        """Start decode server"""
+        decode_args = [
+            "--trust-remote-code",
+            "--disaggregation-mode",
+            "decode",
+            "--tp",
+            "1",
+            "--base-gpu-id",
+            "2",
+            "--port",
+            cls.decode_port,
+        ]
+        decode_args += cls.transfer_backend + cls.rdma_devices
+        cls.process_decode = popen_launch_server(
+            cls.model,
+            base_url=cls.decode_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=decode_args,
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        """Clean up all processes"""
+        for process in [
+            cls.process_lb,
+            cls.process_decode,
+            cls.process_prefill,
+            cls.process_encode,
+        ]:
+            if process:
+                try:
+                    kill_process_tree(process.pid)
+                except Exception as e:
+                    print(f"Error killing process: {e}")
+
+    def run_mmmu_eval(self, model_version: str, output_path: str, limit: str = "50"):
+        """
+        Evaluate a VLM on the MMMU validation set with lmms-eval.
+        Reference: test_vlm_models.py
+
+        Args:
+            model_version: Model version/checkpoint to evaluate
+            output_path: Path to save evaluation results
+            limit: Number of samples to evaluate (default: "50" for CI time constraints)
+        """
+        model = "openai_compatible"
+        tp = 1
+        tasks = "mmmu_val"
+        batch_size = 32
+        log_suffix = "openai_compatible"
+        os.makedirs(output_path, exist_ok=True)
+
+        model_args = f'model_version="{model_version}",' f"tp={tp}"
+
+        cmd = [
+            "python3",
+            "-m",
+            "lmms_eval",
+            "--model",
+            model,
+            "--model_args",
+            model_args,
+            "--tasks",
+            tasks,
+            "--batch_size",
+            str(batch_size),
+            "--log_samples",
+            "--log_samples_suffix",
+            log_suffix,
+            "--output_path",
+            str(output_path),
+            "--limit",
+            limit,
+        ]
+
+        subprocess.run(cmd, check=True, timeout=3600)
+
+    def test_mmmu(self):
+        """Test MMMU evaluation with EPD disaggregation"""
+        import glob
+        import json
+
+        output_path = "./logs/epd_one_encoder_mmmu"
+        self.run_mmmu_eval(self.model, output_path)
+
+        # Get the result file
+        result_files = glob.glob(f"{output_path}/**/*.json", recursive=True)
+        if not result_files:
+            result_files = glob.glob(f"{output_path}/*.json")
+
+        if not result_files:
+            self.fail(f"No JSON result files found in {output_path}")
+
+        result_file_path = result_files[0]
+        with open(result_file_path, "r") as f:
+            result = json.load(f)
+        print(f"MMMU result: {result}")
+
+        mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"]
+        print(f"MMMU accuracy: {mmmu_accuracy:.4f}")
+
+        # For qwen2.5-vl-3b-instruct, the expected accuracy is around 0.40.
+        self.assertGreater(mmmu_accuracy, 0.40)
+
+
+class TestEPDDisaggregationMultiEncoders(PDDisaggregationServerBase):
+    """
+    Test EPD disaggregation with multiple encode servers for load balancing.
+    The two encode servers run on GPUs 0 and 1 (on different ports) to test
+    load distribution across encoders.
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        cls.model = DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST
+        cls.encode_port1 = f"{int(cls.lb_port) + 300}"
+        cls.encode_port2 = f"{int(cls.lb_port) + 301}"
+        cls.encode_url1 = f"http://{cls.base_host}:{cls.encode_port1}"
+        cls.encode_url2 = f"http://{cls.base_host}:{cls.encode_port2}"
+
+        print(
+            f"Setting up EPD (multiple encoders): encode1={cls.encode_port1}, "
+            f"encode2={cls.encode_port2}, prefill={cls.prefill_port}, decode={cls.decode_port}"
+        )
+
+        # Start two encode servers on GPUs 0 and 1
+        encode1_thread = threading.Thread(
+            target=cls.start_encode_server, args=(cls.encode_port1, 0)
+        )
+        encode2_thread = threading.Thread(
+            target=cls.start_encode_server, args=(cls.encode_port2, 1)
+        )
+        encode1_thread.start()
+        encode2_thread.start()
+        encode1_thread.join()
+        encode2_thread.join()
+
+        prefill_thread = threading.Thread(target=cls.start_prefill)
+        decode_thread = threading.Thread(target=cls.start_decode)
+        prefill_thread.start()
+        decode_thread.start()
+        prefill_thread.join()
+        decode_thread.join()
+
+        cls.wait_server_ready(cls.encode_url1 + "/health")
+        cls.wait_server_ready(cls.encode_url2 + "/health")
+        cls.wait_server_ready(cls.prefill_url + "/health")
+        cls.wait_server_ready(cls.decode_url + "/health")
+
+        cls.launch_lb()
+
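+        # With two entries in --encoder-urls, the language-only prefill server
+        # spreads encode requests across both encoder instances.
+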
+        # Set OpenAI API key and base URL environment variables. Needed for
+        # lmms-eval to work.
+        cls.api_key = "sk-123456"
+        os.environ["OPENAI_API_KEY"] = cls.api_key
+        os.environ["OPENAI_API_BASE"] = f"{cls.lb_url}/v1"
+
+    @classmethod
+    def start_encode_server(cls, port, gpu_id):
+        """Start an encode server on a specific port and GPU"""
+        encode_args = [
+            "--trust-remote-code",
+            "--encoder-only",
+            "--encoder-transfer-backend",
+            "zmq_to_scheduler",
+            "--tp",
+            "1",
+            "--port",
+            port,
+            "--enable-prefix-mm-cache",
+        ]
+        # Only set base-gpu-id if not using GPU 0
+        if gpu_id != 0:
+            encode_args.extend(["--base-gpu-id", str(gpu_id)])
+
+        process = popen_launch_server(
+            cls.model,
+            base_url=f"http://{cls.base_host}:{port}",
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=encode_args,
+        )
+        if port == cls.encode_port1:
+            cls.process_encode1 = process
+        else:
+            cls.process_encode2 = process
+
+    @classmethod
+    def start_prefill(cls):
+        """Start prefill server with multiple encode URLs"""
+        prefill_args = [
+            "--trust-remote-code",
+            "--language-only",
+            "--encoder-urls",
+            cls.encode_url1,
+            cls.encode_url2,
+            "--encoder-transfer-backend",
+            "zmq_to_scheduler",
+            "--disaggregation-mode",
+            "prefill",
+            "--tp",
+            "1",
+            "--base-gpu-id",
+            "2",
+            "--port",
+            cls.prefill_port,
+        ]
+        prefill_args += cls.transfer_backend + cls.rdma_devices
+        cls.process_prefill = popen_launch_server(
+            cls.model,
+            base_url=cls.prefill_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=prefill_args,
+        )
+
+    @classmethod
+    def start_decode(cls):
+        """Start decode server"""
+        decode_args = [
+            "--trust-remote-code",
+            "--disaggregation-mode",
+            "decode",
+            "--tp",
+            "1",
+            "--base-gpu-id",
+            "3",
+            "--port",
+            cls.decode_port,
+        ]
+        decode_args += cls.transfer_backend + cls.rdma_devices
+        cls.process_decode = popen_launch_server(
+            cls.model,
+            base_url=cls.decode_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=decode_args,
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        """Clean up all processes"""
+        for process in [
+            cls.process_lb,
+            cls.process_decode,
+            cls.process_prefill,
+            cls.process_encode1,
+            cls.process_encode2,
+        ]:
+            if process:
+                try:
+                    kill_process_tree(process.pid)
+                except Exception as e:
+                    print(f"Error killing process: {e}")
+
+    def run_mmmu_eval(self, model_version: str, output_path: str, limit: str = "50"):
+        """
+        Evaluate a VLM on the MMMU validation set with lmms-eval.
+        Reference: test_vlm_models.py
+
+        Args:
+            model_version: Model version/checkpoint to evaluate
+            output_path: Path to save evaluation results
+            limit: Number of samples to evaluate (default: "50" for CI time constraints)
+        """
+        model = "openai_compatible"
+        tp = 1
+        tasks = "mmmu_val"
+        batch_size = 32
+        log_suffix = "openai_compatible"
+        os.makedirs(output_path, exist_ok=True)
+
+        model_args = f'model_version="{model_version}",' f"tp={tp}"
+
+        cmd = [
+            "python3",
+            "-m",
+            "lmms_eval",
+            "--model",
+            model,
+            "--model_args",
+            model_args,
+            "--tasks",
+            tasks,
+            "--batch_size",
+            str(batch_size),
+            "--log_samples",
+            "--log_samples_suffix",
+            log_suffix,
+            "--output_path",
+            str(output_path),
+            "--limit",
+            limit,
+        ]
+
+        subprocess.run(cmd, check=True, timeout=3600)
+
+    def test_mmmu(self):
+        """Test MMMU evaluation with EPD disaggregation (multiple encoders)"""
+        import glob
+        import json
+
+        output_path = "./logs/epd_multi_encoder_mmmu"
+        self.run_mmmu_eval(self.model, output_path)
+
+        # Get the result file
+        result_files = glob.glob(f"{output_path}/**/*.json", recursive=True)
+        if not result_files:
+            result_files = glob.glob(f"{output_path}/*.json")
+
+        if not result_files:
+            self.fail(f"No JSON result files found in {output_path}")
+
+        result_file_path = result_files[0]
+        with open(result_file_path, "r") as f:
+            result = json.load(f)
+        print(f"MMMU result (multi encoder): {result}")
+
+        mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"]
+        print(f"MMMU accuracy (multi encoder): {mmmu_accuracy:.4f}")
+
+        # For qwen2.5-vl-3b-instruct, the expected accuracy is around 0.40.
+        self.assertGreater(mmmu_accuracy, 0.40)
+
+
+if __name__ == "__main__":
+    unittest.main()
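
For reference, a minimal EPD launch sketch outside the test harness. Ports, GPU ids, and the timeout below are hypothetical; the helper and model constant come from sglang.test.test_utils as used in the tests above, and the PD transfer backend flags (supplied via cls.transfer_backend in the tests) are omitted here.

# Minimal EPD launch sketch (hypothetical ports/timeout; flags mirror the tests).
from sglang.test.test_utils import (
    DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST,
    popen_launch_server,
)

MODEL = DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST

# 1. Encoder-only server: runs just the vision tower on GPU 0.
encode = popen_launch_server(
    MODEL,
    base_url="http://127.0.0.1:30100",
    timeout=600,
    other_args=[
        "--encoder-only",
        "--encoder-transfer-backend", "zmq_to_scheduler",
        "--port", "30100",
    ],
)

# 2. Language-only prefill server: pulls embeddings from the encoder.
prefill = popen_launch_server(
    MODEL,
    base_url="http://127.0.0.1:30200",
    timeout=600,
    other_args=[
        "--language-only",
        "--encoder-urls", "http://127.0.0.1:30100",
        "--encoder-transfer-backend", "zmq_to_scheduler",
        "--disaggregation-mode", "prefill",
        "--base-gpu-id", "1",
        "--port", "30200",
    ],
)

# 3. Decode server: the standard PD decode side.
decode = popen_launch_server(
    MODEL,
    base_url="http://127.0.0.1:30300",
    timeout=600,
    other_args=[
        "--disaggregation-mode", "decode",
        "--base-gpu-id", "2",
        "--port", "30300",
    ],
)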