From 92e4a99d90d7f2edb83b9744b639415254114d74 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Tue, 21 Oct 2025 09:38:48 +0000 Subject: [PATCH 01/68] Implement EPD disaggregation --- python/sglang/launch_server.py | 6 +- python/sglang/srt/configs/model_config.py | 7 ++ .../sglang/srt/entrypoints/encode_server.py | 78 ++++++++++++++++ .../srt/entrypoints/openai/serving_base.py | 1 + .../sglang/srt/managers/tokenizer_manager.py | 5 ++ python/sglang/srt/managers/tp_worker.py | 1 + python/sglang/srt/models/qwen2_5_vl.py | 36 ++++---- python/sglang/srt/server_args.py | 14 +++ .../bindings/python/sglang_router/mini_lb.py | 90 +++++++++++++------ .../python/sglang_router/router_args.py | 25 ++++++ 10 files changed, 219 insertions(+), 44 deletions(-) create mode 100644 python/sglang/srt/entrypoints/encode_server.py diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py index 9e3e82a78f92..3d3d19b14442 100644 --- a/python/sglang/launch_server.py +++ b/python/sglang/launch_server.py @@ -9,11 +9,15 @@ def run_server(server_args): - """Run the server based on server_args.grpc_mode.""" + """Run the server based on server_args.grpc_mode and server_args.mm_only.""" if server_args.grpc_mode: from sglang.srt.entrypoints.grpc_server import serve_grpc asyncio.run(serve_grpc(server_args)) + elif server_args.mm_only: + from sglang.srt.entrypoints.encode_server import launch_server + + launch_server(server_args) else: from sglang.srt.entrypoints.http_server import launch_server diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 6483487fd863..f2f06ae78c27 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -100,6 +100,8 @@ def __init__( model_impl: Union[str, ModelImpl] = ModelImpl.AUTO, sampling_defaults: str = "openai", quantize_and_serve: bool = False, + mm_only: bool = False, + language_only: bool = False, ) -> None: # Parse args self.model_path = 
model_path @@ -215,6 +217,9 @@ def __init__( self.image_token_id = getattr( self.hf_config, "image_token_id", None ) or getattr(self.hf_config, "image_token_index", None) + + self.hf_config.mm_only = mm_only + self.hf_config.language_only = language_only # matryoshka embeddings self.matryoshka_dimensions = getattr( @@ -246,6 +251,8 @@ def from_server_args( sampling_defaults=server_args.sampling_defaults, quantize_and_serve=server_args.quantize_and_serve, override_config_file=server_args.decrypted_config_file, + language_only=server_args.language_only, + mm_only=server_args.mm_only, **kwargs, ) diff --git a/python/sglang/srt/entrypoints/encode_server.py b/python/sglang/srt/entrypoints/encode_server.py new file mode 100644 index 000000000000..b06278787946 --- /dev/null +++ b/python/sglang/srt/entrypoints/encode_server.py @@ -0,0 +1,78 @@ +import uvicorn +import pickle + +from fastapi import FastAPI +from fastapi.responses import ORJSONResponse +from transformers import AutoImageProcessor +from transformers.image_utils import load_images +from typing import Optional + +from sglang.srt.server_args import PortArgs, ServerArgs, set_global_server_args_for_scheduler +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.configs.load_config import LoadConfig +from sglang.srt.model_loader import get_model +from sglang.srt.configs.device_config import DeviceConfig +from sglang.srt.distributed.parallel_state import initialize_model_parallel, init_distributed_environment +from sglang.srt.layers.dp_attention import initialize_dp_attention +from sglang.srt.managers.schedule_batch import MultimodalDataItem, Modality + +class ImageEncoder: + def __init__(self, server_args:ServerArgs): + set_global_server_args_for_scheduler(server_args) + + self.image_processor = AutoImageProcessor.from_pretrained( + server_args.model_path, trust_remote_code=server_args.trust_remote_code + ) + + self.model_config = ModelConfig.from_server_args( + server_args, + ) + + 
self.load_config = LoadConfig( + load_format=server_args.load_format, + download_dir=server_args.download_dir, + model_loader_extra_config=server_args.model_loader_extra_config, + remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip, + remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port, + remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports, + ) + + port_args = PortArgs.init_new(server_args) + if server_args.dist_init_addr: + dist_init_method = f"tcp://{server_args.dist_init_addr}" + else: + dist_init_method = f"tcp://127.0.0.1:{port_args.nccl_port}" + + init_distributed_environment(world_size=1, rank=0, distributed_init_method=dist_init_method) + initialize_model_parallel() + initialize_dp_attention(server_args, self.model_config) + + self.model = get_model( + model_config=self.model_config, + load_config=self.load_config, + device_config=DeviceConfig(), + ) + + def encoding(self,mm_items): + images = load_images(mm_items) + images_input = self.image_processor(images=images) + mm_item = MultimodalDataItem.from_dict({ + 'modality':Modality.IMAGE, + 'feature':images_input['pixel_values'], + }) + mm_item.set('image_grid_thw', images_input['image_grid_thw']) + mm_embeddings = self.model.get_image_feature([mm_item]) + return mm_embeddings + +app = FastAPI() +encoder: Optional[ImageEncoder] = None + +def launch_server(server_args:ServerArgs): + global encoder + encoder = ImageEncoder(server_args) + uvicorn.run(app, host=server_args.host, port=server_args.port) + +@app.post("/encode") +async def handle_encode_request(request_data: dict): + mm_embeddings = encoder.encoding(request_data['mm_items']) + return ORJSONResponse(content={'mm_embeddings':mm_embeddings.tolist()}) \ No newline at end of file diff --git a/python/sglang/srt/entrypoints/openai/serving_base.py 
b/python/sglang/srt/entrypoints/openai/serving_base.py index 669aed7b0462..8502103d0a96 100644 --- a/python/sglang/srt/entrypoints/openai/serving_base.py +++ b/python/sglang/srt/entrypoints/openai/serving_base.py @@ -4,6 +4,7 @@ import logging import time import uuid +import torch from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index b90cf0616cba..df9a79bb3f9d 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -76,6 +76,7 @@ from sglang.srt.managers.request_metrics_exporter import RequestMetricsExporterManager from sglang.srt.managers.schedule_batch import RequestStage from sglang.srt.managers.scheduler import is_health_check_generate_req +from sglang.srt.managers.schedule_batch import Modality from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin from sglang.srt.metrics.collector import TokenizerMetricsCollector @@ -718,6 +719,10 @@ async def _tokenize_one_request( ) if mm_inputs and "input_ids" in mm_inputs: input_ids = mm_inputs["input_ids"] + if hasattr(obj, 'image_data_embedding'): + for mm_item in mm_inputs["mm_items"]: + if mm_item.modality == Modality.IMAGE: + mm_item.precomputed_embeddings = obj.image_data_embedding else: mm_inputs = None diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index f37138a72749..aa880b20a4fd 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -237,6 +237,7 @@ def __init__( else server_args.speculative_draft_model_revision ), is_draft_model=is_draft_worker, + ) if server_args.dllm_algorithm is not None: diff --git a/python/sglang/srt/models/qwen2_5_vl.py 
b/python/sglang/srt/models/qwen2_5_vl.py index 32efa19d88b5..550f7e60cadd 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -476,21 +476,6 @@ def __init__( self.pp_group = get_pp_group() self.config = config self.use_data_parallel = get_global_server_args().mm_enable_dp_encoder - self.visual = Qwen2_5_VisionTransformer( - config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - # NOTE: Qwen2_5-VL vision encoder currently supports BitsAndBytes 4-bit quantization. - # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported. - quant_config=quant_config, - prefix=add_prefix("visual", prefix), - use_data_parallel=self.use_data_parallel, - ) - - self.model = Qwen2Model( - config, - quant_config, - prefix=add_prefix("model", prefix), - ) if self.pp_group.is_last_rank: if self.pp_group.world_size == 1 and self.config.tie_word_embeddings: @@ -506,6 +491,23 @@ def __init__( # ranks other than the last rank will have a placeholder layer self.lm_head = PPMissingLayer() + if not self.config.language_only: + self.visual = Qwen2_5_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + # NOTE: Qwen2_5-VL vision encoder currently supports BitsAndBytes 4-bit quantization. + # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported. 
+ quant_config=quant_config, + prefix=add_prefix("visual", prefix), + ) + + if not self.config.mm_only: + self.model = Qwen2Model( + config, + quant_config, + prefix=add_prefix("model", prefix), + ) + self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling self.logits_processor = LogitsProcessor(config) @@ -647,6 +649,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue + + # Skip loading language model weights + if name not in params_dict: + continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index ae8a403d182b..743953893cc8 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -253,6 +253,10 @@ class ServerArgs: warmups: Optional[str] = None nccl_port: Optional[int] = None checkpoint_engine_wait_weights_before_ready: bool = False + + # Encode prefill disaggregation + mm_only: bool = False + language_only: bool = False # Quantization and data type dtype: str = "auto" @@ -2236,6 +2240,16 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="If set, the server will wait for initial weights to be loaded via checkpoint-engine or other update methods " "before serving inference requests.", + + # Encode prefill disaggregation + parser.add_argument( + "--mm-only", + action='store_true' + ) + + parser.add_argument( + "--language-only", + action='store_true' ) # Quantization and data type diff --git a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py index 39e809358253..a391f6a4b68f 100644 --- a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py +++ b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py @@ -7,6 +7,7 @@ import logging import random import urllib +import torch from http import 
HTTPStatus from itertools import chain from typing import Optional @@ -68,6 +69,7 @@ def __init__( "Tracing is not supported in this environment. Please install sglang." ) self.enable_trace = False + self.encode_urls = router_args.encode_urls def _validate_router_args(self, router_args: RouterArgs): logger.warning( @@ -105,9 +107,47 @@ def select_pair(self): self.prefill_bootstrap_ports[pidx], self.decode_urls[didx], ) + + async def encode( + self, request_data, encode_urls, endpoint + ): + messages = request_data.get('messages') + if messages is None: + return + + # Extract mm_items + img_list = [] + for message in messages: + for item in message.get('content'): + if item.get('type') == 'image_url': + img_url = item.get('image_url').get('url') + img_list.append(img_url) + + # Split mm_items + num_item_per_encoder = (len(img_list)+len(encode_urls)-1) // len(encode_urls) + encode_requests = [{'mm_items':img_list[i*num_item_per_encoder:(i+1)*num_item_per_encoder]} + for i in range(len(encode_urls))] + + # Send encode requests + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout( + total=self.timeout + ) # Add timeout for request reliability + ) as session: + tasks = [ + session.post(f"{encode_urls[i]}/{endpoint}", json=encode_requests[i]) + for i in range(len(encode_urls)) + ] + + encode_responses = await asyncio.gather(*tasks) + + encode_data = [await response.json() for response in encode_responses] + encode_data = torch.concatenate([torch.tensor(data['mm_embeddings'], dtype=torch.bfloat16) for data in encode_data]) + + return encode_data async def generate( - self, modified_request, prefill_server, decode_server, endpoint + self, prefill_request, decode_request, prefill_server, decode_server, endpoint ) -> ORJSONResponse: assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}" @@ -128,16 +168,8 @@ async def generate( headers = {"trace_context": trace_context} tasks = [ - session.post( - f"{prefill_server}/{endpoint}", - 
json=modified_request, - headers=headers, - ), - session.post( - f"{decode_server}/{endpoint}", - json=modified_request, - headers=headers, - ), + session.post(f"{prefill_server}/{endpoint}", json=prefill_request), + session.post(f"{decode_server}/{endpoint}", json=decode_request), ] for bootstrap_room in bootstrap_room_list: @@ -146,7 +178,7 @@ async def generate( # Wait for both responses to complete. Prefill should end first. prefill_response, decode_response = await asyncio.gather(*tasks) - if "return_logprob" in modified_request: + if "return_logprob" in prefill_request: prefill_json = await prefill_response.json() ret_json = await decode_response.json() @@ -175,7 +207,7 @@ async def generate( ) async def generate_stream( - self, modified_request, prefill_server, decode_server, endpoint="generate" + self, prefill_request, decode_request, prefill_server, decode_server, endpoint="generate" ): assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}" @@ -200,16 +232,8 @@ async def stream_results(): headers = {"trace_context": trace_context} tasks = [ - session.post( - f"{prefill_server}/{endpoint}", - json=modified_request, - headers=headers, - ), - session.post( - f"{decode_server}/{endpoint}", - json=modified_request, - headers=headers, - ), + session.post(f"{prefill_server}/{endpoint}", json=prefill_request), + session.post(f"{decode_server}/{endpoint}", json=decode_request), ] for bootstrap_room in bootstrap_room_list: @@ -219,7 +243,7 @@ async def stream_results(): # Wait for both responses to complete. Since this is streaming, they return immediately. 
prefill_response, decode_response = await asyncio.gather(*tasks) - if modified_request.get("return_logprob", False): + if prefill_request.get("return_logprob", False): prefill_chunks = [] async for chunk in prefill_response.content: prefill_chunks.append(chunk) @@ -424,30 +448,40 @@ async def handle_generate_request(request_data: dict): async def _forward_to_backend(request_data: dict, endpoint_name: str): + mm_result = await lb.encode(request_data, lb.encode_urls, 'encode') + prefill_server, bootstrap_port, decode_server = lb.select_pair() # Parse and transform prefill_server for bootstrap data parsed_url = urllib.parse.urlparse(prefill_server) hostname = maybe_wrap_ipv6_address(parsed_url.hostname) - modified_request = request_data.copy() - modified_request.update( + decode_request = request_data.copy() + decode_request.update( { "bootstrap_host": hostname, "bootstrap_port": bootstrap_port, "bootstrap_room": _generate_bootstrap_room(), } ) + prefill_request = decode_request.copy() + prefill_request.update( + { + "image_data_embedding": mm_result.tolist() + } + ) if request_data.get("stream", False): return await lb.generate_stream( - modified_request, + prefill_request, + decode_request, prefill_server, decode_server, endpoint=endpoint_name, ) else: return await lb.generate( - modified_request, + prefill_request, + decode_request, prefill_server, decode_server, endpoint=endpoint_name, diff --git a/sgl-model-gateway/bindings/python/sglang_router/router_args.py b/sgl-model-gateway/bindings/python/sglang_router/router_args.py index 3c085c5cfd95..8dbd619bf821 100644 --- a/sgl-model-gateway/bindings/python/sglang_router/router_args.py +++ b/sgl-model-gateway/bindings/python/sglang_router/router_args.py @@ -21,6 +21,7 @@ class RouterArgs: default_factory=list ) # List of (url, bootstrap_port) decode_urls: List[str] = dataclasses.field(default_factory=list) + encode_urls: List[str] = dataclasses.field(default_factory=list) # Routing policy policy: str = "cache_aware" @@ 
-180,6 +181,7 @@ def add_cli_args( choices=["random", "round_robin", "cache_aware", "power_of_two"], help="Specific policy for decode nodes in PD mode. If not specified, uses the main policy", ) + # PD-specific arguments parser.add_argument( @@ -207,6 +209,13 @@ def add_cli_args( metavar=("URL",), help="Decode server URL. Can be specified multiple times.", ) + parser.add_argument( + f"--{prefix}encode", + nargs=1, + action="append", + metavar=("URL",), + help="Encode server URL. Can be specified multiple times.", + ) parser.add_argument( f"--{prefix}worker-startup-timeout-secs", type=int, @@ -674,6 +683,9 @@ def from_cli_args( args_dict["decode_urls"] = cls._parse_decode_urls( cli_args_dict.get(f"{prefix}decode", None) ) + args_dict["encode_urls"] = cls._parse_encode_urls( + cli_args_dict.get(f"{prefix}encode", None) + ) args_dict["selector"] = cls._parse_selector( cli_args_dict.get(f"{prefix}selector", None) ) @@ -780,3 +792,16 @@ def _parse_decode_urls(decode_list): # decode_list is a list of single-element lists due to nargs=1 return [url[0] for url in decode_list] + + @staticmethod + def _parse_encode_urls(encode_list): + """Parse encode URLs from --encode arguments. 
+ + Format: --encode URL + Example: --encode http://encode1:8081 --encode http://encode2:8081 + """ + if not encode_list: + return [] + + # encode_list is a list of single-element lists due to nargs=1 + return [url[0] for url in encode_list] From 8c57beada128765feabd233f3ea4b8f496c20c36 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Wed, 22 Oct 2025 03:47:37 +0000 Subject: [PATCH 02/68] Fix model weights loading --- python/sglang/srt/models/qwen2_5_vl.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index 550f7e60cadd..f3e6608ba616 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -649,10 +649,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - - # Skip loading language model weights - if name not in params_dict: - continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: @@ -678,6 +674,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + # Skip loading language model weights + if self.config.mm_only and name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) From 104f763898cafeba0858c99f294231cb8a8e30a3 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Wed, 22 Oct 2025 09:40:13 +0000 Subject: [PATCH 03/68] Use zmq for transmitting embedding --- .../sglang/srt/entrypoints/encode_server.py | 58 +++++++++++++++++-- .../sglang/srt/managers/tokenizer_manager.py | 32 +++++++++- python/sglang/srt/models/qwen2_5_vl.py | 7 ++- python/sglang/srt/server_args.py | 7 ++- .../bindings/python/sglang_router/mini_lb.py | 58 +++++++++---------- 5 files changed, 121 insertions(+), 41 deletions(-) diff --git a/python/sglang/srt/entrypoints/encode_server.py b/python/sglang/srt/entrypoints/encode_server.py index b06278787946..770440f2b4dc 100644 --- a/python/sglang/srt/entrypoints/encode_server.py +++ b/python/sglang/srt/entrypoints/encode_server.py @@ -1,11 +1,15 @@ import uvicorn -import pickle +import zmq +import asyncio +import zmq.asyncio +import torch from fastapi import FastAPI from fastapi.responses import ORJSONResponse from transformers import AutoImageProcessor from transformers.image_utils import load_images from typing import Optional +from collections import deque from sglang.srt.server_args import PortArgs, ServerArgs, set_global_server_args_for_scheduler from sglang.srt.configs.model_config import ModelConfig @@ -15,6 +19,30 @@ from sglang.srt.distributed.parallel_state import initialize_model_parallel, init_distributed_environment from sglang.srt.layers.dp_attention import initialize_dp_attention from sglang.srt.managers.schedule_batch import MultimodalDataItem, Modality +from sglang.srt.utils import get_zmq_socket + +class EmbeddingData: + def __init__(self, req_id, num_parts, part_idx, mm_embedding): + self.req_id = req_id + self.num_parts = num_parts 
+ self.part_idx = part_idx + self.embedding = mm_embedding + self.embedding_dict = dict() + self.embedding_dict[part_idx] = mm_embedding + + def add(self, embedding_data): + assert self.req_id == embedding_data.req_id + assert embedding_data.part_idx not in self.embedding_dict + self.embedding_dict[embedding_data.part_idx] = embedding_data.embedding + + def get(self): + assert len(self.embedding_dict) == self.num_parts + agg_data = [self.embedding_dict[i] for i in range(self.num_parts)] + return torch.concatenate(agg_data) + + @property + def ready(self): + return len(self.embedding_dict) == self.num_parts class ImageEncoder: def __init__(self, server_args:ServerArgs): @@ -53,7 +81,14 @@ def __init__(self, server_args:ServerArgs): device_config=DeviceConfig(), ) - def encoding(self,mm_items): + context = zmq.asyncio.Context(2) + self.send_to_prefill = get_zmq_socket(context, zmq.PUSH, f"tcp://localhost:{server_args.embedding_port}", False) + + self.wait_queue = deque() + + self.encode_task = None + + def encode(self,mm_items): images = load_images(mm_items) images_input = self.image_processor(images=images) mm_item = MultimodalDataItem.from_dict({ @@ -62,7 +97,21 @@ def encoding(self,mm_items): }) mm_item.set('image_grid_thw', images_input['image_grid_thw']) mm_embeddings = self.model.get_image_feature([mm_item]) + del images_input return mm_embeddings + + def add(self,request_data): + self.wait_queue.append(request_data) + + def step(self): + request_data = self.wait_queue.popleft() + mm_embeddings = self.encode(request_data["mm_items"]) + send_data = EmbeddingData(request_data['req_id'], + request_data['num_parts'], + request_data['part_idx'], + mm_embeddings) + self.send_to_prefill.send_pyobj(send_data) + del send_data app = FastAPI() encoder: Optional[ImageEncoder] = None @@ -74,5 +123,6 @@ def launch_server(server_args:ServerArgs): @app.post("/encode") async def handle_encode_request(request_data: dict): - mm_embeddings = 
encoder.encoding(request_data['mm_items']) - return ORJSONResponse(content={'mm_embeddings':mm_embeddings.tolist()}) \ No newline at end of file + encoder.add(request_data) + encoder.step() + return ORJSONResponse(content=None) \ No newline at end of file diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index df9a79bb3f9d..9583a3c0ce7a 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -192,6 +192,9 @@ def __init__( ) self.crash_dump_folder = server_args.crash_dump_folder self.enable_trace = server_args.enable_trace + self.disaggregation_mode = DisaggregationMode( + self.server_args.disaggregation_mode + ) # Read model args self.model_path = server_args.model_path @@ -311,6 +314,14 @@ def __init__( # Make sure that each request carries the tokenizer_ipc_name for response routing self.send_to_scheduler = SenderWrapper(port_args, send_to_scheduler) + # Recv embedding from encoding server + if (self.disaggregation_mode == DisaggregationMode.PREFILL and + self.model_config.is_multimodal): + self.recv_from_encoder = get_zmq_socket( + context, zmq.PULL, f"tcp://*:{server_args.embedding_port}", True + ) + self.received_embeddings = dict() + # Request states self._chosen_loop = None self.rid_to_state: Dict[str, ReqState] = {} @@ -719,10 +730,20 @@ async def _tokenize_one_request( ) if mm_inputs and "input_ids" in mm_inputs: input_ids = mm_inputs["input_ids"] - if hasattr(obj, 'image_data_embedding'): + + while self.disaggregation_mode == DisaggregationMode.PREFILL: + if obj.bootstrap_room not in self.received_embeddings: + await self.handle_embedding() + continue + if not self.received_embeddings[obj.bootstrap_room].ready: + await self.handle_embedding() + continue for mm_item in mm_inputs["mm_items"]: if mm_item.modality == Modality.IMAGE: - mm_item.precomputed_embeddings = obj.image_data_embedding + mm_item.precomputed_embeddings = 
self.received_embeddings[ + obj.bootstrap_room].get() + del self.received_embeddings[obj.bootstrap_room] + break else: mm_inputs = None @@ -1545,6 +1566,13 @@ async def handle_loop(self): recv_obj = await self.recv_from_detokenizer.recv_pyobj() self._result_dispatcher(recv_obj) self.last_receive_tstamp = time.time() + + async def handle_embedding(self): + recv_obj = await self.recv_from_encoder.recv_pyobj() + if recv_obj.req_id not in self.received_embeddings: + self.received_embeddings[recv_obj.req_id] = recv_obj + else: + self.received_embeddings[recv_obj.req_id].add(recv_obj) def _add_metric_if_present( self, diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index f3e6608ba616..9e8edbd22be6 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -674,9 +674,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - # Skip loading language model weights - if self.config.mm_only and name not in params_dict: - continue + # Skip loading visual/language model weights + if (self.config.mm_only or self.config.language_only + ) and name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 743953893cc8..7858baf69ec4 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -257,6 +257,7 @@ class ServerArgs: # Encode prefill disaggregation mm_only: bool = False language_only: bool = False + embedding_port:int = None # Quantization and data type dtype: str = "auto" @@ -2246,11 +2247,15 @@ def add_cli_args(parser: argparse.ArgumentParser): "--mm-only", action='store_true' ) - parser.add_argument( "--language-only", action='store_true' ) + parser.add_argument( + 
"--embedding-port", + type=int, + default=54213 + ) # Quantization and data type parser.add_argument( diff --git a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py index a391f6a4b68f..da90c0e568a9 100644 --- a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py +++ b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py @@ -123,10 +123,17 @@ async def encode( img_url = item.get('image_url').get('url') img_list.append(img_url) + if len(img_list) == 0: + return + # Split mm_items num_item_per_encoder = (len(img_list)+len(encode_urls)-1) // len(encode_urls) - encode_requests = [{'mm_items':img_list[i*num_item_per_encoder:(i+1)*num_item_per_encoder]} - for i in range(len(encode_urls))] + num_encoders = min(len(img_list), len(encode_urls)) + encode_requests = [{'mm_items':img_list[i*num_item_per_encoder:(i+1)*num_item_per_encoder], + 'num_parts': num_encoders, + 'part_idx': i, + 'req_id': request_data.get('bootstrap_room')} + for i in range(num_encoders)] # Send encode requests async with aiohttp.ClientSession( @@ -136,18 +143,13 @@ async def encode( ) as session: tasks = [ session.post(f"{encode_urls[i]}/{endpoint}", json=encode_requests[i]) - for i in range(len(encode_urls)) + for i in range(num_encoders) ] - encode_responses = await asyncio.gather(*tasks) - - encode_data = [await response.json() for response in encode_responses] - encode_data = torch.concatenate([torch.tensor(data['mm_embeddings'], dtype=torch.bfloat16) for data in encode_data]) - - return encode_data + await asyncio.gather(*tasks) async def generate( - self, prefill_request, decode_request, prefill_server, decode_server, endpoint + self, pd_request, prefill_server, decode_server, endpoint ) -> ORJSONResponse: assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}" @@ -168,8 +170,8 @@ async def generate( headers = {"trace_context": trace_context} tasks = [ - 
session.post(f"{prefill_server}/{endpoint}", json=prefill_request), - session.post(f"{decode_server}/{endpoint}", json=decode_request), + session.post(f"{prefill_server}/{endpoint}", json=pd_request), + session.post(f"{decode_server}/{endpoint}", json=pd_request), ] for bootstrap_room in bootstrap_room_list: @@ -178,7 +180,7 @@ async def generate( # Wait for both responses to complete. Prefill should end first. prefill_response, decode_response = await asyncio.gather(*tasks) - if "return_logprob" in prefill_request: + if "return_logprob" in pd_request: prefill_json = await prefill_response.json() ret_json = await decode_response.json() @@ -207,7 +209,7 @@ async def generate( ) async def generate_stream( - self, prefill_request, decode_request, prefill_server, decode_server, endpoint="generate" + self, pd_request, prefill_server, decode_server, endpoint="generate" ): assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}" @@ -232,8 +234,8 @@ async def stream_results(): headers = {"trace_context": trace_context} tasks = [ - session.post(f"{prefill_server}/{endpoint}", json=prefill_request), - session.post(f"{decode_server}/{endpoint}", json=decode_request), + session.post(f"{prefill_server}/{endpoint}", json=pd_request), + session.post(f"{decode_server}/{endpoint}", json=pd_request), ] for bootstrap_room in bootstrap_room_list: @@ -243,7 +245,7 @@ async def stream_results(): # Wait for both responses to complete. Since this is streaming, they return immediately. 
prefill_response, decode_response = await asyncio.gather(*tasks) - if prefill_request.get("return_logprob", False): + if pd_request.get("return_logprob", False): prefill_chunks = [] async for chunk in prefill_response.content: prefill_chunks.append(chunk) @@ -448,40 +450,34 @@ async def handle_generate_request(request_data: dict): async def _forward_to_backend(request_data: dict, endpoint_name: str): - mm_result = await lb.encode(request_data, lb.encode_urls, 'encode') + bootstrap_room = _generate_bootstrap_room() + encode_request = request_data.copy() + encode_request.update({"bootstrap_room": bootstrap_room}) + await lb.encode(encode_request, lb.encode_urls, 'encode') prefill_server, bootstrap_port, decode_server = lb.select_pair() # Parse and transform prefill_server for bootstrap data parsed_url = urllib.parse.urlparse(prefill_server) hostname = maybe_wrap_ipv6_address(parsed_url.hostname) - decode_request = request_data.copy() - decode_request.update( + pd_request = encode_request.copy() + pd_request.update( { "bootstrap_host": hostname, "bootstrap_port": bootstrap_port, - "bootstrap_room": _generate_bootstrap_room(), - } - ) - prefill_request = decode_request.copy() - prefill_request.update( - { - "image_data_embedding": mm_result.tolist() } ) if request_data.get("stream", False): return await lb.generate_stream( - prefill_request, - decode_request, + pd_request, prefill_server, decode_server, endpoint=endpoint_name, ) else: return await lb.generate( - prefill_request, - decode_request, + pd_request, prefill_server, decode_server, endpoint=endpoint_name, From 2cfccfe5eb5bc8506b55555a31deb61f224b49e7 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Thu, 23 Oct 2025 08:24:15 +0000 Subject: [PATCH 04/68] Fix MM embedding index --- python/sglang/srt/managers/mm_utils.py | 25 +++++++++++++++++-- .../bindings/python/sglang_router/mini_lb.py | 3 +-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/mm_utils.py 
b/python/sglang/srt/managers/mm_utils.py index e9347e8112be..fedc9c48535e 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -341,13 +341,34 @@ def get_embedding_chunk( def _get_precomputed_embedding( items: List[MultimodalDataItem], + prefix_length: List[int], + extend_length: List[int], + items_offset_list: List[List[Tuple[int, int]]], ) -> Optional[torch.Tensor]: """ If all items have precomputed_embeddings, return their concatenation. If some but not all have precomputed_embeddings, raise NotImplementedError. If none have precomputed_embeddings, return None. """ - precomputed_embeddings = [item.precomputed_embeddings for item in items] + precomputed_embeddings = [] + for idx,item in enumerate(items): + seq_start_idx = prefix_length[idx] + seq_end_idx = seq_start_idx + extend_length[idx] - 1 + prefix_embedding_length = [] + extend_embedding_length = [] + for mm_start_idx, mm_end_idx in items_offset_list[idx]: + if mm_start_idx > seq_end_idx: + break + if seq_start_idx > mm_start_idx: + prefix_embedding_length.append(min(seq_start_idx - mm_start_idx, mm_end_idx - mm_start_idx + 1)) + if mm_end_idx >= seq_start_idx: + extend_embedding_length.append(min(mm_end_idx - seq_start_idx + 1, seq_end_idx - mm_start_idx + 1, mm_end_idx - mm_start_idx + 1)) + + prefix_embedding_length = int(np.sum(prefix_embedding_length)) + extend_embedding_length = int(np.sum(extend_embedding_length)) + precomputed_embeddings.append( + item.precomputed_embeddings[prefix_embedding_length:prefix_embedding_length+extend_embedding_length]) + if any(feature is not None for feature in precomputed_embeddings): if not all(feature is not None for feature in precomputed_embeddings): raise NotImplementedError( @@ -473,7 +494,7 @@ def get_embedding_and_mask( - A boolean mask tensor indicating where these embeddings should be placed """ # 1. 
Get embedding - embedding = _get_precomputed_embedding(embedding_items) + embedding = _get_precomputed_embedding(embedding_items, prefix_length, extend_length, items_offset_list) if embedding is None: embedding = _get_chunked_prefill_embedding( data_embedding_func, diff --git a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py index da90c0e568a9..982290684765 100644 --- a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py +++ b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py @@ -7,7 +7,6 @@ import logging import random import urllib -import torch from http import HTTPStatus from itertools import chain from typing import Optional @@ -453,7 +452,7 @@ async def _forward_to_backend(request_data: dict, endpoint_name: str): bootstrap_room = _generate_bootstrap_room() encode_request = request_data.copy() encode_request.update({"bootstrap_room": bootstrap_room}) - await lb.encode(encode_request, lb.encode_urls, 'encode') + asyncio.create_task(lb.encode(encode_request, lb.encode_urls, 'encode')) prefill_server, bootstrap_port, decode_server = lb.select_pair() From 5738fd1969011d40e70c9c05144c1e8815a5b12d Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Thu, 23 Oct 2025 09:53:14 +0000 Subject: [PATCH 05/68] Fix split mm items --- .../bindings/python/sglang_router/mini_lb.py | 32 +++++++++++++------ 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py index 982290684765..43bad19a0f26 100644 --- a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py +++ b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py @@ -69,6 +69,8 @@ def __init__( ) self.enable_trace = False self.encode_urls = router_args.encode_urls + + self.encode_idx = list(range(len(self.encode_urls))) def _validate_router_args(self, router_args: RouterArgs): logger.warning( @@ -126,13 
+128,25 @@ async def encode( return # Split mm_items - num_item_per_encoder = (len(img_list)+len(encode_urls)-1) // len(encode_urls) - num_encoders = min(len(img_list), len(encode_urls)) - encode_requests = [{'mm_items':img_list[i*num_item_per_encoder:(i+1)*num_item_per_encoder], - 'num_parts': num_encoders, - 'part_idx': i, - 'req_id': request_data.get('bootstrap_room')} - for i in range(num_encoders)] + encode_requests = [] + random.shuffle(self.encode_idx) + num_items_assigned = [(idx+len(img_list)) // len(self.encode_urls) for idx in self.encode_idx] + num_parts = sum(1 for x in num_items_assigned if x != 0) + cum_num_items = 0 + cum_idx = 0 + for idx,assigned_num in enumerate(num_items_assigned): + if assigned_num == 0: + continue + encode_requests.append( + { + 'encoder_idx':idx, + 'mm_items':img_list[cum_num_items:cum_num_items+assigned_num], + 'num_parts': num_parts, + 'part_idx': cum_idx, + 'req_id': request_data.get('bootstrap_room') + }) + cum_idx += 1 + cum_num_items += assigned_num # Send encode requests async with aiohttp.ClientSession( @@ -141,8 +155,8 @@ async def encode( ) # Add timeout for request reliability ) as session: tasks = [ - session.post(f"{encode_urls[i]}/{endpoint}", json=encode_requests[i]) - for i in range(num_encoders) + session.post(f"{encode_urls[encode_request['encoder_idx']]}/{endpoint}", json=encode_request) + for encode_request in encode_requests ] await asyncio.gather(*tasks) From ad2c15517035e46cc1c28c30cccc2d76b459cffe Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Fri, 24 Oct 2025 02:43:34 +0000 Subject: [PATCH 06/68] Fix MM embedding index --- python/sglang/srt/managers/mm_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py index fedc9c48535e..865c9fb8ebf7 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -362,8 +362,10 @@ def _get_precomputed_embedding( if 
seq_start_idx > mm_start_idx: prefix_embedding_length.append(min(seq_start_idx - mm_start_idx, mm_end_idx - mm_start_idx + 1)) if mm_end_idx >= seq_start_idx: - extend_embedding_length.append(min(mm_end_idx - seq_start_idx + 1, seq_end_idx - mm_start_idx + 1, mm_end_idx - mm_start_idx + 1)) - + extend_embedding_length.append(min(mm_end_idx - seq_start_idx + 1, + seq_end_idx - mm_start_idx + 1, + mm_end_idx - mm_start_idx + 1, + seq_end_idx - seq_start_idx + 1)) prefix_embedding_length = int(np.sum(prefix_embedding_length)) extend_embedding_length = int(np.sum(extend_embedding_length)) precomputed_embeddings.append( From 8cda7bf380257771ce898d3379b63da1a919311e Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Fri, 24 Oct 2025 07:55:25 +0000 Subject: [PATCH 07/68] Fix encoder OOM --- .../sglang/srt/entrypoints/encode_server.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/entrypoints/encode_server.py b/python/sglang/srt/entrypoints/encode_server.py index 770440f2b4dc..70363ba85504 100644 --- a/python/sglang/srt/entrypoints/encode_server.py +++ b/python/sglang/srt/entrypoints/encode_server.py @@ -89,16 +89,19 @@ def __init__(self, server_args:ServerArgs): self.encode_task = None def encode(self,mm_items): - images = load_images(mm_items) - images_input = self.image_processor(images=images) - mm_item = MultimodalDataItem.from_dict({ - 'modality':Modality.IMAGE, - 'feature':images_input['pixel_values'], - }) - mm_item.set('image_grid_thw', images_input['image_grid_thw']) - mm_embeddings = self.model.get_image_feature([mm_item]) - del images_input - return mm_embeddings + mm_embeddings = [] + # To avoid OOM, we process image one by one + for mm_item in mm_items: + image = load_images(mm_item) + image_input = self.image_processor(images=image) + mm_item = MultimodalDataItem.from_dict({ + 'modality':Modality.IMAGE, + 'feature':image_input['pixel_values'], + }) + mm_item.set('image_grid_thw', 
image_input['image_grid_thw']) + mm_embedding = self.model.get_image_feature([mm_item]) + mm_embeddings.append(mm_embedding) + return torch.concatenate(mm_embeddings) def add(self,request_data): self.wait_queue.append(request_data) From 93205ff8171e27bd13b2631830f49426ecb1db20 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Fri, 24 Oct 2025 08:09:09 +0000 Subject: [PATCH 08/68] Clean up code --- .../srt/entrypoints/openai/serving_base.py | 1 - python/sglang/srt/managers/tp_worker.py | 1 - .../bindings/python/sglang_router/mini_lb.py | 24 +++++++++---------- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_base.py b/python/sglang/srt/entrypoints/openai/serving_base.py index 8502103d0a96..669aed7b0462 100644 --- a/python/sglang/srt/entrypoints/openai/serving_base.py +++ b/python/sglang/srt/entrypoints/openai/serving_base.py @@ -4,7 +4,6 @@ import logging import time import uuid -import torch from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index aa880b20a4fd..f37138a72749 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -237,7 +237,6 @@ def __init__( else server_args.speculative_draft_model_revision ), is_draft_model=is_draft_worker, - ) if server_args.dllm_algorithm is not None: diff --git a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py index 43bad19a0f26..7a03fc1b5f52 100644 --- a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py +++ b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py @@ -162,7 +162,7 @@ async def encode( await asyncio.gather(*tasks) async def generate( - self, pd_request, prefill_server, decode_server, endpoint + self, modified_request, prefill_server, decode_server, endpoint ) -> ORJSONResponse: 
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}" @@ -183,8 +183,8 @@ async def generate( headers = {"trace_context": trace_context} tasks = [ - session.post(f"{prefill_server}/{endpoint}", json=pd_request), - session.post(f"{decode_server}/{endpoint}", json=pd_request), + session.post(f"{prefill_server}/{endpoint}", json=modified_request), + session.post(f"{decode_server}/{endpoint}", json=modified_request), ] for bootstrap_room in bootstrap_room_list: @@ -193,7 +193,7 @@ async def generate( # Wait for both responses to complete. Prefill should end first. prefill_response, decode_response = await asyncio.gather(*tasks) - if "return_logprob" in pd_request: + if "return_logprob" in modified_request: prefill_json = await prefill_response.json() ret_json = await decode_response.json() @@ -222,7 +222,7 @@ async def generate( ) async def generate_stream( - self, pd_request, prefill_server, decode_server, endpoint="generate" + self, modified_request, prefill_server, decode_server, endpoint="generate" ): assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}" @@ -247,8 +247,8 @@ async def stream_results(): headers = {"trace_context": trace_context} tasks = [ - session.post(f"{prefill_server}/{endpoint}", json=pd_request), - session.post(f"{decode_server}/{endpoint}", json=pd_request), + session.post(f"{prefill_server}/{endpoint}", json=modified_request), + session.post(f"{decode_server}/{endpoint}", json=modified_request), ] for bootstrap_room in bootstrap_room_list: @@ -258,7 +258,7 @@ async def stream_results(): # Wait for both responses to complete. Since this is streaming, they return immediately. 
prefill_response, decode_response = await asyncio.gather(*tasks) - if pd_request.get("return_logprob", False): + if modified_request.get("return_logprob", False): prefill_chunks = [] async for chunk in prefill_response.content: prefill_chunks.append(chunk) @@ -473,8 +473,8 @@ async def _forward_to_backend(request_data: dict, endpoint_name: str): # Parse and transform prefill_server for bootstrap data parsed_url = urllib.parse.urlparse(prefill_server) hostname = maybe_wrap_ipv6_address(parsed_url.hostname) - pd_request = encode_request.copy() - pd_request.update( + modified_request = encode_request.copy() + modified_request.update( { "bootstrap_host": hostname, "bootstrap_port": bootstrap_port, @@ -483,14 +483,14 @@ async def _forward_to_backend(request_data: dict, endpoint_name: str): if request_data.get("stream", False): return await lb.generate_stream( - pd_request, + modified_request, prefill_server, decode_server, endpoint=endpoint_name, ) else: return await lb.generate( - pd_request, + modified_request, prefill_server, decode_server, endpoint=endpoint_name, From eaa3f1725262050c132066efe07f7d0fc4a095d3 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Fri, 24 Oct 2025 09:06:22 +0000 Subject: [PATCH 09/68] Fix EP colocate and add prefill server ip arg --- python/sglang/srt/entrypoints/encode_server.py | 2 +- python/sglang/srt/managers/mm_utils.py | 3 +++ .../sglang/srt/managers/tokenizer_manager.py | 4 ++-- python/sglang/srt/server_args.py | 18 ++++++++++++++---- .../bindings/python/sglang_router/mini_lb.py | 2 +- 5 files changed, 21 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/entrypoints/encode_server.py b/python/sglang/srt/entrypoints/encode_server.py index 70363ba85504..dcea540db00b 100644 --- a/python/sglang/srt/entrypoints/encode_server.py +++ b/python/sglang/srt/entrypoints/encode_server.py @@ -82,7 +82,7 @@ def __init__(self, server_args:ServerArgs): ) context = zmq.asyncio.Context(2) - self.send_to_prefill = get_zmq_socket(context, 
zmq.PUSH, f"tcp://localhost:{server_args.embedding_port}", False) + self.send_to_prefill = get_zmq_socket(context, zmq.PUSH, f"tcp://{server_args.prefill_server_ip}:{server_args.embedding_port}", False) self.wait_queue = deque() diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py index 865c9fb8ebf7..7c26a8cec2d2 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -352,6 +352,9 @@ def _get_precomputed_embedding( """ precomputed_embeddings = [] for idx,item in enumerate(items): + if item.precomputed_embeddings is None: + precomputed_embeddings.append(None) + continue seq_start_idx = prefix_length[idx] seq_end_idx = seq_start_idx + extend_length[idx] - 1 prefix_embedding_length = [] diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 9583a3c0ce7a..df098e07b4f3 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -316,7 +316,7 @@ def __init__( # Recv embedding from encoding server if (self.disaggregation_mode == DisaggregationMode.PREFILL and - self.model_config.is_multimodal): + self.model_config.is_multimodal and self.server_args.language_only): self.recv_from_encoder = get_zmq_socket( context, zmq.PULL, f"tcp://*:{server_args.embedding_port}", True ) @@ -731,7 +731,7 @@ async def _tokenize_one_request( if mm_inputs and "input_ids" in mm_inputs: input_ids = mm_inputs["input_ids"] - while self.disaggregation_mode == DisaggregationMode.PREFILL: + while self.disaggregation_mode == DisaggregationMode.PREFILL and self.server_args.language_only: if obj.bootstrap_room not in self.received_embeddings: await self.handle_embedding() continue diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 7858baf69ec4..d001db023435 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -257,7 +257,8 @@ class 
ServerArgs: # Encode prefill disaggregation mm_only: bool = False language_only: bool = False - embedding_port:int = None + embedding_port: Optional[int] = None + prefill_server_ip: str = 'localhost' # Quantization and data type dtype: str = "auto" @@ -2245,16 +2246,25 @@ def add_cli_args(parser: argparse.ArgumentParser): # Encode prefill disaggregation parser.add_argument( "--mm-only", - action='store_true' + action='store_true', + help="For VLM, launch encode server only for multimodal part." ) parser.add_argument( "--language-only", - action='store_true' + action='store_true', + help="For VLM, load weights for the language model only." ) parser.add_argument( "--embedding-port", type=int, - default=54213 + default=54213, + help="The port for multimodal embedding transmission." + ) + parser.add_argument( + "--prefill-server-ip", + type=str, + default=ServerArgs.prefill_server_ip, + help="The IP for prefill instance when launching encode server." ) # Quantization and data type diff --git a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py index 7a03fc1b5f52..9893cba3dfd7 100644 --- a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py +++ b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py @@ -113,7 +113,7 @@ async def encode( self, request_data, encode_urls, endpoint ): messages = request_data.get('messages') - if messages is None: + if messages is None or len(encode_urls) == 0: return # Extract mm_items From 66090ad5a0969421f22fc5ff5307cf05f8e747b4 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Fri, 24 Oct 2025 10:14:35 +0000 Subject: [PATCH 10/68] Fix race condition --- .../sglang/srt/managers/tokenizer_manager.py | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index df098e07b4f3..247db83d4d88 100644 --- 
a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -321,6 +321,7 @@ def __init__( context, zmq.PULL, f"tcp://*:{server_args.embedding_port}", True ) self.received_embeddings = dict() + self.embeddings_lock = asyncio.Lock() # Request states self._chosen_loop = None @@ -731,19 +732,17 @@ async def _tokenize_one_request( if mm_inputs and "input_ids" in mm_inputs: input_ids = mm_inputs["input_ids"] - while self.disaggregation_mode == DisaggregationMode.PREFILL and self.server_args.language_only: - if obj.bootstrap_room not in self.received_embeddings: - await self.handle_embedding() - continue - if not self.received_embeddings[obj.bootstrap_room].ready: - await self.handle_embedding() - continue - for mm_item in mm_inputs["mm_items"]: - if mm_item.modality == Modality.IMAGE: - mm_item.precomputed_embeddings = self.received_embeddings[ - obj.bootstrap_room].get() - del self.received_embeddings[obj.bootstrap_room] - break + if self.disaggregation_mode == DisaggregationMode.PREFILL and self.server_args.language_only: + # Use async lock to avoid race condition + async with self.embeddings_lock: + while (obj.bootstrap_room not in self.received_embeddings or + not self.received_embeddings[obj.bootstrap_room].ready): + await self.handle_embedding() + for mm_item in mm_inputs["mm_items"]: + if mm_item.modality == Modality.IMAGE: + mm_item.precomputed_embeddings = self.received_embeddings[ + obj.bootstrap_room].get() + del self.received_embeddings[obj.bootstrap_room] else: mm_inputs = None From 90f85074530f32a67dda2258ce71d14f04322b23 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Mon, 27 Oct 2025 08:35:07 +0000 Subject: [PATCH 11/68] Batch MM items again (OOM fixed by core binding) --- .../sglang/srt/entrypoints/encode_server.py | 27 +++++++++---------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/python/sglang/srt/entrypoints/encode_server.py b/python/sglang/srt/entrypoints/encode_server.py index 
dcea540db00b..bf0ccd50bd37 100644 --- a/python/sglang/srt/entrypoints/encode_server.py +++ b/python/sglang/srt/entrypoints/encode_server.py @@ -1,6 +1,6 @@ import uvicorn import zmq -import asyncio +import time import zmq.asyncio import torch @@ -87,21 +87,18 @@ def __init__(self, server_args:ServerArgs): self.wait_queue = deque() self.encode_task = None - + + @torch.inference_mode() def encode(self,mm_items): - mm_embeddings = [] - # To avoid OOM, we process image one by one - for mm_item in mm_items: - image = load_images(mm_item) - image_input = self.image_processor(images=image) - mm_item = MultimodalDataItem.from_dict({ - 'modality':Modality.IMAGE, - 'feature':image_input['pixel_values'], - }) - mm_item.set('image_grid_thw', image_input['image_grid_thw']) - mm_embedding = self.model.get_image_feature([mm_item]) - mm_embeddings.append(mm_embedding) - return torch.concatenate(mm_embeddings) + images = load_images(mm_items) + images_input = self.image_processor(images=images) + mm_item = MultimodalDataItem.from_dict({ + 'modality':Modality.IMAGE, + 'feature':images_input['pixel_values'], + }) + mm_item.set('image_grid_thw', images_input['image_grid_thw']) + mm_embedding = self.model.get_image_feature([mm_item]) + return mm_embedding def add(self,request_data): self.wait_queue.append(request_data) From dcab59a3b4046092a0ab6409651e49f064429f22 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Mon, 27 Oct 2025 09:18:39 +0000 Subject: [PATCH 12/68] Enable assign prefill IP on the fly --- .../sglang/srt/entrypoints/encode_server.py | 20 ++++++++++++---- python/sglang/srt/server_args.py | 7 ------ .../bindings/python/sglang_router/mini_lb.py | 23 +++++++++---------- 3 files changed, 27 insertions(+), 23 deletions(-) diff --git a/python/sglang/srt/entrypoints/encode_server.py b/python/sglang/srt/entrypoints/encode_server.py index bf0ccd50bd37..01c8c0aef79d 100644 --- a/python/sglang/srt/entrypoints/encode_server.py +++ b/python/sglang/srt/entrypoints/encode_server.py @@ 
-1,6 +1,5 @@ import uvicorn import zmq -import time import zmq.asyncio import torch @@ -46,6 +45,7 @@ def ready(self): class ImageEncoder: def __init__(self, server_args:ServerArgs): + self.server_args = server_args set_global_server_args_for_scheduler(server_args) self.image_processor = AutoImageProcessor.from_pretrained( @@ -81,8 +81,8 @@ def __init__(self, server_args:ServerArgs): device_config=DeviceConfig(), ) - context = zmq.asyncio.Context(2) - self.send_to_prefill = get_zmq_socket(context, zmq.PUSH, f"tcp://{server_args.prefill_server_ip}:{server_args.embedding_port}", False) + self.context = zmq.asyncio.Context(2) + self.send_to_prefill_sockets = dict() self.wait_queue = deque() @@ -100,6 +100,18 @@ def encode(self,mm_items): mm_embedding = self.model.get_image_feature([mm_item]) return mm_embedding + def send(self, send_data, prefill_ip): + if prefill_ip in self.send_to_prefill_sockets: + socket = self.send_to_prefill_sockets[prefill_ip] + else: + socket = get_zmq_socket( + self.context, + zmq.PUSH, + f"tcp://{prefill_ip}:{self.server_args.embedding_port}", + False) + self.send_to_prefill_sockets[prefill_ip] = socket + socket.send_pyobj(send_data) + def add(self,request_data): self.wait_queue.append(request_data) @@ -110,7 +122,7 @@ def step(self): request_data['num_parts'], request_data['part_idx'], mm_embeddings) - self.send_to_prefill.send_pyobj(send_data) + self.send(send_data, request_data['bootstrap_host']) del send_data app = FastAPI() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index d001db023435..bd01aa72121c 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -258,7 +258,6 @@ class ServerArgs: mm_only: bool = False language_only: bool = False embedding_port: Optional[int] = None - prefill_server_ip: str = 'localhost' # Quantization and data type dtype: str = "auto" @@ -2260,12 +2259,6 @@ def add_cli_args(parser: argparse.ArgumentParser): default=54213, help="The port for 
multimodal embedding transmission." ) - parser.add_argument( - "--prefill-server-ip", - type=str, - default=ServerArgs.prefill_server_ip, - help="The IP for prefill instance when launching encode server." - ) # Quantization and data type parser.add_argument( diff --git a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py index 9893cba3dfd7..83b12bdf8bae 100644 --- a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py +++ b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py @@ -143,7 +143,8 @@ async def encode( 'mm_items':img_list[cum_num_items:cum_num_items+assigned_num], 'num_parts': num_parts, 'part_idx': cum_idx, - 'req_id': request_data.get('bootstrap_room') + 'req_id': request_data.get('bootstrap_room'), + 'bootstrap_host': request_data.get('bootstrap_host'), }) cum_idx += 1 cum_num_items += assigned_num @@ -463,23 +464,21 @@ async def handle_generate_request(request_data: dict): async def _forward_to_backend(request_data: dict, endpoint_name: str): - bootstrap_room = _generate_bootstrap_room() - encode_request = request_data.copy() - encode_request.update({"bootstrap_room": bootstrap_room}) - asyncio.create_task(lb.encode(encode_request, lb.encode_urls, 'encode')) - prefill_server, bootstrap_port, decode_server = lb.select_pair() # Parse and transform prefill_server for bootstrap data parsed_url = urllib.parse.urlparse(prefill_server) hostname = maybe_wrap_ipv6_address(parsed_url.hostname) + bootstrap_room = _generate_bootstrap_room() + + # Send requests to encode server + encode_request = request_data.copy() + encode_request.update({"bootstrap_room": bootstrap_room, + "bootstrap_host": hostname,}) + asyncio.create_task(lb.encode(encode_request, lb.encode_urls, 'encode')) + modified_request = encode_request.copy() - modified_request.update( - { - "bootstrap_host": hostname, - "bootstrap_port": bootstrap_port, - } - ) + modified_request.update({"bootstrap_port": 
bootstrap_port}) if request_data.get("stream", False): return await lb.generate_stream( From 3a064cd77246508ee2215a05c54ff2c14a4b477b Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Mon, 27 Oct 2025 12:33:24 +0000 Subject: [PATCH 13/68] Fix rebase --- python/sglang/srt/server_args.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index bd01aa72121c..ed6cee006579 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2241,6 +2241,7 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="If set, the server will wait for initial weights to be loaded via checkpoint-engine or other update methods " "before serving inference requests.", + ) # Encode prefill disaggregation parser.add_argument( From f17fcb320c68a3f476b1a493bc19f6fdf677d9c5 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Tue, 28 Oct 2025 08:02:19 +0000 Subject: [PATCH 14/68] Fix image resize for Qwen --- .../sglang/srt/entrypoints/encode_server.py | 27 +++++++++---------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/python/sglang/srt/entrypoints/encode_server.py b/python/sglang/srt/entrypoints/encode_server.py index 01c8c0aef79d..114ed3c05577 100644 --- a/python/sglang/srt/entrypoints/encode_server.py +++ b/python/sglang/srt/entrypoints/encode_server.py @@ -1,6 +1,7 @@ import uvicorn import zmq import zmq.asyncio +import asyncio import torch from fastapi import FastAPI @@ -8,7 +9,6 @@ from transformers import AutoImageProcessor from transformers.image_utils import load_images from typing import Optional -from collections import deque from sglang.srt.server_args import PortArgs, ServerArgs, set_global_server_args_for_scheduler from sglang.srt.configs.model_config import ModelConfig @@ -19,6 +19,7 @@ from sglang.srt.layers.dp_attention import initialize_dp_attention from sglang.srt.managers.schedule_batch import MultimodalDataItem, Modality from sglang.srt.utils 
import get_zmq_socket +from sglang.srt.multimodal.processors.qwen_vl import resize_image_async class EmbeddingData: def __init__(self, req_id, num_parts, part_idx, mm_embedding): @@ -83,14 +84,14 @@ def __init__(self, server_args:ServerArgs): self.context = zmq.asyncio.Context(2) self.send_to_prefill_sockets = dict() - - self.wait_queue = deque() - - self.encode_task = None - @torch.inference_mode() - def encode(self,mm_items): + async def encode(self,mm_items): images = load_images(mm_items) + + # Qwen-specific: resize images + resize_tasks = [resize_image_async(image) for image in images] + images = await asyncio.gather(*resize_tasks) + images_input = self.image_processor(images=images) mm_item = MultimodalDataItem.from_dict({ 'modality':Modality.IMAGE, @@ -112,12 +113,9 @@ def send(self, send_data, prefill_ip): self.send_to_prefill_sockets[prefill_ip] = socket socket.send_pyobj(send_data) - def add(self,request_data): - self.wait_queue.append(request_data) - - def step(self): - request_data = self.wait_queue.popleft() - mm_embeddings = self.encode(request_data["mm_items"]) + @torch.inference_mode() + async def step(self, request_data): + mm_embeddings = await self.encode(request_data["mm_items"]) send_data = EmbeddingData(request_data['req_id'], request_data['num_parts'], request_data['part_idx'], @@ -135,6 +133,5 @@ def launch_server(server_args:ServerArgs): @app.post("/encode") async def handle_encode_request(request_data: dict): - encoder.add(request_data) - encoder.step() + await encoder.step(request_data) return ORJSONResponse(content=None) \ No newline at end of file From ec16b9b0c831be19a563845f80d0935ca9758d08 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Tue, 28 Oct 2025 08:53:33 +0000 Subject: [PATCH 15/68] Format --- python/sglang/srt/configs/model_config.py | 2 +- .../sglang/srt/entrypoints/encode_server.py | 115 ++++++++++-------- python/sglang/srt/managers/mm_utils.py | 30 +++-- .../sglang/srt/managers/tokenizer_manager.py | 26 ++-- 
python/sglang/srt/models/qwen2_5_vl.py | 12 +- python/sglang/srt/server_args.py | 14 +-- .../bindings/python/sglang_router/mini_lb.py | 64 +++++----- .../python/sglang_router/router_args.py | 3 +- 8 files changed, 159 insertions(+), 107 deletions(-) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index f2f06ae78c27..16e909ad5393 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -217,7 +217,7 @@ def __init__( self.image_token_id = getattr( self.hf_config, "image_token_id", None ) or getattr(self.hf_config, "image_token_index", None) - + self.hf_config.mm_only = mm_only self.hf_config.language_only = language_only diff --git a/python/sglang/srt/entrypoints/encode_server.py b/python/sglang/srt/entrypoints/encode_server.py index 114ed3c05577..d9b3dbb6d978 100644 --- a/python/sglang/srt/entrypoints/encode_server.py +++ b/python/sglang/srt/entrypoints/encode_server.py @@ -1,25 +1,33 @@ +import asyncio +from typing import Optional + +import torch import uvicorn import zmq import zmq.asyncio -import asyncio -import torch - from fastapi import FastAPI from fastapi.responses import ORJSONResponse from transformers import AutoImageProcessor from transformers.image_utils import load_images -from typing import Optional -from sglang.srt.server_args import PortArgs, ServerArgs, set_global_server_args_for_scheduler -from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.configs.load_config import LoadConfig -from sglang.srt.model_loader import get_model from sglang.srt.configs.device_config import DeviceConfig -from sglang.srt.distributed.parallel_state import initialize_model_parallel, init_distributed_environment +from sglang.srt.configs.load_config import LoadConfig +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, +) from 
sglang.srt.layers.dp_attention import initialize_dp_attention -from sglang.srt.managers.schedule_batch import MultimodalDataItem, Modality -from sglang.srt.utils import get_zmq_socket +from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem +from sglang.srt.model_loader import get_model from sglang.srt.multimodal.processors.qwen_vl import resize_image_async +from sglang.srt.server_args import ( + PortArgs, + ServerArgs, + set_global_server_args_for_scheduler, +) +from sglang.srt.utils import get_zmq_socket + class EmbeddingData: def __init__(self, req_id, num_parts, part_idx, mm_embedding): @@ -29,34 +37,35 @@ def __init__(self, req_id, num_parts, part_idx, mm_embedding): self.embedding = mm_embedding self.embedding_dict = dict() self.embedding_dict[part_idx] = mm_embedding - + def add(self, embedding_data): assert self.req_id == embedding_data.req_id assert embedding_data.part_idx not in self.embedding_dict self.embedding_dict[embedding_data.part_idx] = embedding_data.embedding - + def get(self): assert len(self.embedding_dict) == self.num_parts agg_data = [self.embedding_dict[i] for i in range(self.num_parts)] return torch.concatenate(agg_data) - + @property def ready(self): return len(self.embedding_dict) == self.num_parts + class ImageEncoder: - def __init__(self, server_args:ServerArgs): + def __init__(self, server_args: ServerArgs): self.server_args = server_args set_global_server_args_for_scheduler(server_args) - + self.image_processor = AutoImageProcessor.from_pretrained( server_args.model_path, trust_remote_code=server_args.trust_remote_code ) - + self.model_config = ModelConfig.from_server_args( server_args, ) - + self.load_config = LoadConfig( load_format=server_args.load_format, download_dir=server_args.download_dir, @@ -65,27 +74,29 @@ def __init__(self, server_args:ServerArgs): remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port, 
remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports, ) - + port_args = PortArgs.init_new(server_args) if server_args.dist_init_addr: dist_init_method = f"tcp://{server_args.dist_init_addr}" else: dist_init_method = f"tcp://127.0.0.1:{port_args.nccl_port}" - - init_distributed_environment(world_size=1, rank=0, distributed_init_method=dist_init_method) + + init_distributed_environment( + world_size=1, rank=0, distributed_init_method=dist_init_method + ) initialize_model_parallel() initialize_dp_attention(server_args, self.model_config) - + self.model = get_model( - model_config=self.model_config, - load_config=self.load_config, - device_config=DeviceConfig(), - ) - + model_config=self.model_config, + load_config=self.load_config, + device_config=DeviceConfig(), + ) + self.context = zmq.asyncio.Context(2) self.send_to_prefill_sockets = dict() - - async def encode(self,mm_items): + + async def encode(self, mm_items): images = load_images(mm_items) # Qwen-specific: resize images @@ -93,45 +104,53 @@ async def encode(self,mm_items): images = await asyncio.gather(*resize_tasks) images_input = self.image_processor(images=images) - mm_item = MultimodalDataItem.from_dict({ - 'modality':Modality.IMAGE, - 'feature':images_input['pixel_values'], - }) - mm_item.set('image_grid_thw', images_input['image_grid_thw']) + mm_item = MultimodalDataItem.from_dict( + { + "modality": Modality.IMAGE, + "feature": images_input["pixel_values"], + } + ) + mm_item.set("image_grid_thw", images_input["image_grid_thw"]) mm_embedding = self.model.get_image_feature([mm_item]) return mm_embedding - + def send(self, send_data, prefill_ip): if prefill_ip in self.send_to_prefill_sockets: socket = self.send_to_prefill_sockets[prefill_ip] else: socket = get_zmq_socket( - self.context, - zmq.PUSH, - f"tcp://{prefill_ip}:{self.server_args.embedding_port}", - False) + self.context, + zmq.PUSH, + 
f"tcp://{prefill_ip}:{self.server_args.embedding_port}", + False, + ) self.send_to_prefill_sockets[prefill_ip] = socket socket.send_pyobj(send_data) - + @torch.inference_mode() async def step(self, request_data): mm_embeddings = await self.encode(request_data["mm_items"]) - send_data = EmbeddingData(request_data['req_id'], - request_data['num_parts'], - request_data['part_idx'], - mm_embeddings) - self.send(send_data, request_data['bootstrap_host']) + send_data = EmbeddingData( + request_data["req_id"], + request_data["num_parts"], + request_data["part_idx"], + mm_embeddings, + ) + self.send(send_data, request_data["bootstrap_host"]) del send_data + app = FastAPI() encoder: Optional[ImageEncoder] = None - -def launch_server(server_args:ServerArgs): + + +def launch_server(server_args: ServerArgs): global encoder encoder = ImageEncoder(server_args) uvicorn.run(app, host=server_args.host, port=server_args.port) + @app.post("/encode") async def handle_encode_request(request_data: dict): await encoder.step(request_data) - return ORJSONResponse(content=None) \ No newline at end of file + return ORJSONResponse(content=None) diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py index 7c26a8cec2d2..ffbe680ab58d 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -351,7 +351,7 @@ def _get_precomputed_embedding( If none have precomputed_embeddings, return None. 
""" precomputed_embeddings = [] - for idx,item in enumerate(items): + for idx, item in enumerate(items): if item.precomputed_embeddings is None: precomputed_embeddings.append(None) continue @@ -363,17 +363,27 @@ def _get_precomputed_embedding( if mm_start_idx > seq_end_idx: break if seq_start_idx > mm_start_idx: - prefix_embedding_length.append(min(seq_start_idx - mm_start_idx, mm_end_idx - mm_start_idx + 1)) + prefix_embedding_length.append( + min(seq_start_idx - mm_start_idx, mm_end_idx - mm_start_idx + 1) + ) if mm_end_idx >= seq_start_idx: - extend_embedding_length.append(min(mm_end_idx - seq_start_idx + 1, - seq_end_idx - mm_start_idx + 1, - mm_end_idx - mm_start_idx + 1, - seq_end_idx - seq_start_idx + 1)) + extend_embedding_length.append( + min( + mm_end_idx - seq_start_idx + 1, + seq_end_idx - mm_start_idx + 1, + mm_end_idx - mm_start_idx + 1, + seq_end_idx - seq_start_idx + 1, + ) + ) prefix_embedding_length = int(np.sum(prefix_embedding_length)) extend_embedding_length = int(np.sum(extend_embedding_length)) precomputed_embeddings.append( - item.precomputed_embeddings[prefix_embedding_length:prefix_embedding_length+extend_embedding_length]) - + item.precomputed_embeddings[ + prefix_embedding_length : prefix_embedding_length + + extend_embedding_length + ] + ) + if any(feature is not None for feature in precomputed_embeddings): if not all(feature is not None for feature in precomputed_embeddings): raise NotImplementedError( @@ -499,7 +509,9 @@ def get_embedding_and_mask( - A boolean mask tensor indicating where these embeddings should be placed """ # 1. 
Get embedding - embedding = _get_precomputed_embedding(embedding_items, prefix_length, extend_length, items_offset_list) + embedding = _get_precomputed_embedding( + embedding_items, prefix_length, extend_length, items_offset_list + ) if embedding is None: embedding = _get_chunked_prefill_embedding( data_embedding_func, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 247db83d4d88..a1cb8bc02ce1 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -77,6 +77,7 @@ from sglang.srt.managers.schedule_batch import RequestStage from sglang.srt.managers.scheduler import is_health_check_generate_req from sglang.srt.managers.schedule_batch import Modality +from sglang.srt.managers.scheduler import is_health_check_generate_req from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin from sglang.srt.metrics.collector import TokenizerMetricsCollector @@ -315,8 +316,11 @@ def __init__( self.send_to_scheduler = SenderWrapper(port_args, send_to_scheduler) # Recv embedding from encoding server - if (self.disaggregation_mode == DisaggregationMode.PREFILL and - self.model_config.is_multimodal and self.server_args.language_only): + if ( + self.disaggregation_mode == DisaggregationMode.PREFILL + and self.model_config.is_multimodal + and self.server_args.language_only + ): self.recv_from_encoder = get_zmq_socket( context, zmq.PULL, f"tcp://*:{server_args.embedding_port}", True ) @@ -731,17 +735,23 @@ async def _tokenize_one_request( ) if mm_inputs and "input_ids" in mm_inputs: input_ids = mm_inputs["input_ids"] - - if self.disaggregation_mode == DisaggregationMode.PREFILL and self.server_args.language_only: + + if ( + self.disaggregation_mode == DisaggregationMode.PREFILL + and self.server_args.language_only + ): # Use async lock to avoid race 
condition async with self.embeddings_lock: - while (obj.bootstrap_room not in self.received_embeddings or - not self.received_embeddings[obj.bootstrap_room].ready): + while ( + obj.bootstrap_room not in self.received_embeddings + or not self.received_embeddings[obj.bootstrap_room].ready + ): await self.handle_embedding() for mm_item in mm_inputs["mm_items"]: if mm_item.modality == Modality.IMAGE: mm_item.precomputed_embeddings = self.received_embeddings[ - obj.bootstrap_room].get() + obj.bootstrap_room + ].get() del self.received_embeddings[obj.bootstrap_room] else: mm_inputs = None @@ -1565,7 +1575,7 @@ async def handle_loop(self): recv_obj = await self.recv_from_detokenizer.recv_pyobj() self._result_dispatcher(recv_obj) self.last_receive_tstamp = time.time() - + async def handle_embedding(self): recv_obj = await self.recv_from_encoder.recv_pyobj() if recv_obj.req_id not in self.received_embeddings: diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index 9e8edbd22be6..cb9ea783e498 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -475,6 +475,7 @@ def __init__( self.pp_group = get_pp_group() self.config = config +<<<<<<< HEAD self.use_data_parallel = get_global_server_args().mm_enable_dp_encoder if self.pp_group.is_last_rank: @@ -490,6 +491,8 @@ def __init__( else: # ranks other than the last rank will have a placeholder layer self.lm_head = PPMissingLayer() +======= +>>>>>>> e9fbeb706 (Format) if not self.config.language_only: self.visual = Qwen2_5_VisionTransformer( @@ -500,7 +503,7 @@ def __init__( quant_config=quant_config, prefix=add_prefix("visual", prefix), ) - + if not self.config.mm_only: self.model = Qwen2Model( config, @@ -675,9 +678,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(".bias") and name not in params_dict: continue # Skip loading visual/language model weights - if (self.config.mm_only or 
self.config.language_only - ) and name not in params_dict: - continue + if ( + self.config.mm_only or self.config.language_only + ) and name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index ed6cee006579..a499257f3dcc 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -253,7 +253,7 @@ class ServerArgs: warmups: Optional[str] = None nccl_port: Optional[int] = None checkpoint_engine_wait_weights_before_ready: bool = False - + # Encode prefill disaggregation mm_only: bool = False language_only: bool = False @@ -2242,23 +2242,23 @@ def add_cli_args(parser: argparse.ArgumentParser): help="If set, the server will wait for initial weights to be loaded via checkpoint-engine or other update methods " "before serving inference requests.", ) - + # Encode prefill disaggregation parser.add_argument( "--mm-only", - action='store_true', - help="For VLM, launch encode server only for multimodal part." + action="store_true", + help="For VLM, launch encode server only for multimodal part.", ) parser.add_argument( "--language-only", - action='store_true', - help="For VLM, load weights for the language model only." + action="store_true", + help="For VLM, load weights for the language model only.", ) parser.add_argument( "--embedding-port", type=int, default=54213, - help="The port for multimodal embedding transmission." 
+ help="The port for multimodal embedding transmission.", ) # Quantization and data type diff --git a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py index 83b12bdf8bae..1de8171376d5 100644 --- a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py +++ b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py @@ -69,7 +69,7 @@ def __init__( ) self.enable_trace = False self.encode_urls = router_args.encode_urls - + self.encode_idx = list(range(len(self.encode_urls))) def _validate_router_args(self, router_args: RouterArgs): @@ -108,47 +108,48 @@ def select_pair(self): self.prefill_bootstrap_ports[pidx], self.decode_urls[didx], ) - - async def encode( - self, request_data, encode_urls, endpoint - ): - messages = request_data.get('messages') + + async def encode(self, request_data, encode_urls, endpoint): + messages = request_data.get("messages") if messages is None or len(encode_urls) == 0: return - + # Extract mm_items img_list = [] for message in messages: - for item in message.get('content'): - if item.get('type') == 'image_url': - img_url = item.get('image_url').get('url') + for item in message.get("content"): + if item.get("type") == "image_url": + img_url = item.get("image_url").get("url") img_list.append(img_url) - + if len(img_list) == 0: return - + # Split mm_items encode_requests = [] random.shuffle(self.encode_idx) - num_items_assigned = [(idx+len(img_list)) // len(self.encode_urls) for idx in self.encode_idx] + num_items_assigned = [ + (idx + len(img_list)) // len(self.encode_urls) for idx in self.encode_idx + ] num_parts = sum(1 for x in num_items_assigned if x != 0) cum_num_items = 0 cum_idx = 0 - for idx,assigned_num in enumerate(num_items_assigned): + for idx, assigned_num in enumerate(num_items_assigned): if assigned_num == 0: continue encode_requests.append( { - 'encoder_idx':idx, - 'mm_items':img_list[cum_num_items:cum_num_items+assigned_num], - 'num_parts': num_parts, 
- 'part_idx': cum_idx, - 'req_id': request_data.get('bootstrap_room'), - 'bootstrap_host': request_data.get('bootstrap_host'), - }) + "encoder_idx": idx, + "mm_items": img_list[cum_num_items : cum_num_items + assigned_num], + "num_parts": num_parts, + "part_idx": cum_idx, + "req_id": request_data.get("bootstrap_room"), + "bootstrap_host": request_data.get("bootstrap_host"), + } + ) cum_idx += 1 cum_num_items += assigned_num - + # Send encode requests async with aiohttp.ClientSession( timeout=aiohttp.ClientTimeout( @@ -156,7 +157,10 @@ async def encode( ) # Add timeout for request reliability ) as session: tasks = [ - session.post(f"{encode_urls[encode_request['encoder_idx']]}/{endpoint}", json=encode_request) + session.post( + f"{encode_urls[encode_request['encoder_idx']]}/{endpoint}", + json=encode_request, + ) for encode_request in encode_requests ] @@ -470,13 +474,17 @@ async def _forward_to_backend(request_data: dict, endpoint_name: str): parsed_url = urllib.parse.urlparse(prefill_server) hostname = maybe_wrap_ipv6_address(parsed_url.hostname) bootstrap_room = _generate_bootstrap_room() - + # Send requests to encode server encode_request = request_data.copy() - encode_request.update({"bootstrap_room": bootstrap_room, - "bootstrap_host": hostname,}) - asyncio.create_task(lb.encode(encode_request, lb.encode_urls, 'encode')) - + encode_request.update( + { + "bootstrap_room": bootstrap_room, + "bootstrap_host": hostname, + } + ) + asyncio.create_task(lb.encode(encode_request, lb.encode_urls, "encode")) + modified_request = encode_request.copy() modified_request.update({"bootstrap_port": bootstrap_port}) diff --git a/sgl-model-gateway/bindings/python/sglang_router/router_args.py b/sgl-model-gateway/bindings/python/sglang_router/router_args.py index 8dbd619bf821..95220a506e3d 100644 --- a/sgl-model-gateway/bindings/python/sglang_router/router_args.py +++ b/sgl-model-gateway/bindings/python/sglang_router/router_args.py @@ -181,7 +181,6 @@ def add_cli_args( 
choices=["random", "round_robin", "cache_aware", "power_of_two"], help="Specific policy for decode nodes in PD mode. If not specified, uses the main policy", ) - # PD-specific arguments parser.add_argument( @@ -792,7 +791,7 @@ def _parse_decode_urls(decode_list): # decode_list is a list of single-element lists due to nargs=1 return [url[0] for url in decode_list] - + @staticmethod def _parse_encode_urls(encode_list): """Parse encode URLs from --encode arguments. From c56adb8d449dda79a870e493903beba2adcf58a7 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Tue, 28 Oct 2025 10:47:24 +0000 Subject: [PATCH 16/68] Format --- python/sglang/srt/managers/tokenizer_manager.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index a1cb8bc02ce1..9fb2c09b46a9 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -74,9 +74,8 @@ from sglang.srt.managers.mm_utils import TensorTransportMode from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors from sglang.srt.managers.request_metrics_exporter import RequestMetricsExporterManager -from sglang.srt.managers.schedule_batch import RequestStage from sglang.srt.managers.scheduler import is_health_check_generate_req -from sglang.srt.managers.schedule_batch import Modality +from sglang.srt.managers.schedule_batch import Modality, RequestStage from sglang.srt.managers.scheduler import is_health_check_generate_req from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin From 1d68c5282e7e6e3c81950ab5f3c136e3e95e6970 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Wed, 29 Oct 2025 03:57:24 +0000 Subject: [PATCH 17/68] Support qwen3_vl --- python/sglang/srt/models/qwen3_vl.py | 59 +++++++++++++++++----------- 1 file 
changed, 36 insertions(+), 23 deletions(-) diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py index ed52f7ff40f2..eb5d89da0bcd 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -601,15 +601,17 @@ def __init__( super().__init__() self.use_data_parallel = get_global_server_args().mm_enable_dp_encoder - self.visual = Qwen3VLMoeVisionModel( - config.vision_config, - # NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization. - # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported. - quant_config=quant_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - prefix=add_prefix("visual", prefix), - use_data_parallel=self.use_data_parallel, - ) + + if not config.language_only: + self.visual = Qwen3VLMoeVisionModel( + config.vision_config, + # NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization. + # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported. 
+ quant_config=quant_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + prefix=add_prefix("visual", prefix), + use_data_parallel=self.use_data_parallel, + ) # TODO: make it more elegant if language_model_cls is Qwen3LLMModel: @@ -617,21 +619,22 @@ def __init__( else: self.config = config.text_config # for qwen3-omni - self.model = language_model_cls( - config=self.config, - quant_config=quant_config, - prefix=add_prefix("model", prefix), - ) - - if self.config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens - else: - self.lm_head = ParallelLMHead( - self.config.vocab_size, - self.config.hidden_size, + if not config.mm_only: + self.model = language_model_cls( + config=self.config, quant_config=quant_config, - prefix=add_prefix("lm_head", prefix), + prefix=add_prefix("model", prefix), ) + + if self.config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + ) self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling self.logits_processor = LogitsProcessor(self.config) @@ -640,7 +643,7 @@ def __init__( # 8, 16, 24 layer will be merged to 0, 1, 2 layer of decoder output hidden_states # deepstack - self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes + self.deepstack_visual_indexes = config.vision_config.deepstack_visual_indexes self.num_deepstack_embeddings = len(self.deepstack_visual_indexes) self.use_deepstack = {Modality.IMAGE: True, Modality.VIDEO: True} @@ -774,6 +777,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + # Skip loading visual/language model weights + if ( + self.config.mm_only or self.config.language_only + ) and name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -788,6 +796,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Skip loading visual/language model weights + if ( + self.config.mm_only or self.config.language_only + ) and name not in params_dict: + continue param = params_dict[name] except KeyError: print(params_dict.keys()) From 25e8568984affa71382a7cdbb1980276bfa2b556 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Wed, 29 Oct 2025 06:59:37 +0000 Subject: [PATCH 18/68] Support qwen3_vl_moe --- python/sglang/srt/models/qwen3_vl.py | 2 ++ python/sglang/srt/models/qwen3_vl_moe.py | 8 +++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py index eb5d89da0bcd..fdbd96233e79 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -618,6 +618,8 @@ def __init__( self.config: Qwen3VLConfig = config # for qwen3-vl else: self.config = config.text_config # for qwen3-omni + self.config.mm_only = config.mm_only + self.config.language_only = config.language_only if not config.mm_only: self.model = language_model_cls( diff --git a/python/sglang/srt/models/qwen3_vl_moe.py b/python/sglang/srt/models/qwen3_vl_moe.py index e3e9e07d1962..181d1f17f4fe 100644 --- a/python/sglang/srt/models/qwen3_vl_moe.py +++ b/python/sglang/srt/models/qwen3_vl_moe.py @@ -243,7 +243,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue - if "visual" in name: + if 
"visual" in name or self.config.mm_only: continue # Anyway, this is an expert weight and should not be # attempted to load as other weights later @@ -309,6 +309,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(ignore_suffixes) and name not in params_dict: continue + # Skip loading mm/language parameters + if ( + self.config.mm_only or self.config.language_only + ) and name not in params_dict: + continue + if name in params_dict.keys(): param = params_dict[name] weight_loader = getattr( From 714bad368e1591de51914ebdbae182ab537fb1ed Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Fri, 31 Oct 2025 02:27:49 +0000 Subject: [PATCH 19/68] Support [E]+[PD colocate] --- .../sglang/srt/managers/tokenizer_manager.py | 11 +--- .../bindings/python/sglang_router/mini_lb.py | 66 ++++++++++++++----- .../python/sglang_router/router_args.py | 6 ++ 3 files changed, 57 insertions(+), 26 deletions(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 9fb2c09b46a9..57fdc055669d 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -315,11 +315,7 @@ def __init__( self.send_to_scheduler = SenderWrapper(port_args, send_to_scheduler) # Recv embedding from encoding server - if ( - self.disaggregation_mode == DisaggregationMode.PREFILL - and self.model_config.is_multimodal - and self.server_args.language_only - ): + if self.model_config.is_multimodal and self.server_args.language_only: self.recv_from_encoder = get_zmq_socket( context, zmq.PULL, f"tcp://*:{server_args.embedding_port}", True ) @@ -735,10 +731,7 @@ async def _tokenize_one_request( if mm_inputs and "input_ids" in mm_inputs: input_ids = mm_inputs["input_ids"] - if ( - self.disaggregation_mode == DisaggregationMode.PREFILL - and self.server_args.language_only - ): + if self.server_args.language_only: # Use async lock to avoid race condition async with 
self.embeddings_lock: while ( diff --git a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py index 1de8171376d5..9ea9e77dcfb8 100644 --- a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py +++ b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py @@ -82,14 +82,26 @@ def _validate_router_args(self, router_args: RouterArgs): logger.warning("[MiniLB] Overriding policy to random") router_args.policy = "random" - if not router_args.pd_disaggregation: - raise ValueError("MiniLB only supports PD disaggregation mode") + if not router_args.pd_disaggregation and not router_args.e_disaggregation: + raise ValueError("MiniLB only supports PD/E disaggregation mode") - if len(router_args.prefill_urls) == 0 or len(router_args.decode_urls) == 0: + if router_args.pd_disaggregation and router_args.e_disaggregation: raise ValueError( - "MiniLB requires at least one prefill and one decode server" + "MiniLB does not support PD and E disaggregation modes at the same time." ) + if len(router_args.prefill_urls) == 0: + raise ValueError("MiniLB requires at least one prefill server") + + if router_args.pd_disaggregation and len(router_args.decode_urls) == 0: + raise ValueError( + "The PD disaggregation mode requires at least one decode server." 
+ ) + + if router_args.e_disaggregation and len(router_args.decode_urls) != 0: + logger.warning("The E disaggregation mode doesn't require decode server") + router_args.decode_urls = [] + def start(self): global lb lb = self @@ -99,14 +111,16 @@ def start(self): uvicorn.run(app, host=self.host, port=self.port) def select_pair(self): - assert len(self.prefill_urls) > 0, "No prefill servers available" - assert len(self.decode_urls) > 0, "No decode servers available" pidx = random.randint(0, len(self.prefill_urls) - 1) - didx = random.randint(0, len(self.decode_urls) - 1) + if len(self.decode_urls) != 0: + didx = random.randint(0, len(self.decode_urls) - 1) + decode_url = self.decode_urls[didx] + else: + decode_url = None return ( self.prefill_urls[pidx], self.prefill_bootstrap_ports[pidx], - self.decode_urls[didx], + decode_url, ) async def encode(self, request_data, encode_urls, endpoint): @@ -188,17 +202,23 @@ async def generate( headers = {"trace_context": trace_context} tasks = [ - session.post(f"{prefill_server}/{endpoint}", json=modified_request), - session.post(f"{decode_server}/{endpoint}", json=modified_request), + session.post(f"{prefill_server}/{endpoint}", json=modified_request) ] + if decode_server is not None: + tasks.append( + session.post(f"{decode_server}/{endpoint}", json=modified_request) + ) for bootstrap_room in bootstrap_room_list: trace_slice_end("mini_lb_launch", bootstrap_room, auto_next_anon=True) # Wait for both responses to complete. Prefill should end first. 
- prefill_response, decode_response = await asyncio.gather(*tasks) - - if "return_logprob" in modified_request: + responses = await asyncio.gather(*tasks) + prefill_response = responses[0] + decode_response = ( + responses[1] if decode_server is not None else prefill_response + ) + if "return_logprob" in modified_request and decode_server is not None: prefill_json = await prefill_response.json() ret_json = await decode_response.json() @@ -252,18 +272,30 @@ async def stream_results(): headers = {"trace_context": trace_context} tasks = [ - session.post(f"{prefill_server}/{endpoint}", json=modified_request), - session.post(f"{decode_server}/{endpoint}", json=modified_request), + session.post(f"{prefill_server}/{endpoint}", json=modified_request) ] + if decode_server is not None: + tasks.append( + session.post( + f"{decode_server}/{endpoint}", json=modified_request + ) + ) for bootstrap_room in bootstrap_room_list: trace_slice_end( "mini_lb_launch", bootstrap_room, auto_next_anon=True ) # Wait for both responses to complete. Since this is streaming, they return immediately. 
- prefill_response, decode_response = await asyncio.gather(*tasks) + responses = await asyncio.gather(*tasks) + prefill_response = responses[0] + decode_response = ( + responses[1] if decode_server is not None else prefill_response + ) - if modified_request.get("return_logprob", False): + if ( + modified_request.get("return_logprob", False) + and decode_server is not None + ): prefill_chunks = [] async for chunk in prefill_response.content: prefill_chunks.append(chunk) diff --git a/sgl-model-gateway/bindings/python/sglang_router/router_args.py b/sgl-model-gateway/bindings/python/sglang_router/router_args.py index 95220a506e3d..476c129a7f3a 100644 --- a/sgl-model-gateway/bindings/python/sglang_router/router_args.py +++ b/sgl-model-gateway/bindings/python/sglang_router/router_args.py @@ -17,6 +17,7 @@ class RouterArgs: # PD-specific configuration mini_lb: bool = False pd_disaggregation: bool = False # Enable PD disaggregated mode + e_disaggregation: bool = False # Enable E disaggregated mode prefill_urls: List[tuple] = dataclasses.field( default_factory=list ) # List of (url, bootstrap_port) @@ -193,6 +194,11 @@ def add_cli_args( action="store_true", help="Enable PD (Prefill-Decode) disaggregated mode", ) + parser.add_argument( + f"--{prefix}e-disaggregation", + action="store_true", + help="Enable E (Encode) disaggregated mode", + ) parser.add_argument( f"--{prefix}prefill", nargs="+", From 2c452f03c582b14488715e5c84f6920b0d37a910 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Tue, 4 Nov 2025 15:47:16 +0000 Subject: [PATCH 20/68] Support mooncake for transmission --- .../sglang/srt/entrypoints/encode_server.py | 101 +++++++++++++----- python/sglang/srt/entrypoints/http_server.py | 11 ++ .../sglang/srt/managers/tokenizer_manager.py | 43 ++++++-- python/sglang/srt/server_args.py | 2 +- .../bindings/python/sglang_router/mini_lb.py | 71 ++++++++++-- 5 files changed, 180 insertions(+), 48 deletions(-) diff --git a/python/sglang/srt/entrypoints/encode_server.py 
b/python/sglang/srt/entrypoints/encode_server.py index d9b3dbb6d978..e77b08949836 100644 --- a/python/sglang/srt/entrypoints/encode_server.py +++ b/python/sglang/srt/entrypoints/encode_server.py @@ -13,6 +13,7 @@ from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine from sglang.srt.distributed.parallel_state import ( init_distributed_environment, initialize_model_parallel, @@ -26,31 +27,29 @@ ServerArgs, set_global_server_args_for_scheduler, ) -from sglang.srt.utils import get_zmq_socket +from sglang.srt.utils import get_local_ip_auto, get_zmq_socket class EmbeddingData: - def __init__(self, req_id, num_parts, part_idx, mm_embedding): + def __init__(self, req_id, num_parts, part_idx, embedding_len): self.req_id = req_id self.num_parts = num_parts self.part_idx = part_idx - self.embedding = mm_embedding - self.embedding_dict = dict() - self.embedding_dict[part_idx] = mm_embedding + self.embedding_len = embedding_len + + # aggregated data + self.ready_list = [i == self.part_idx for i in range(self.num_parts)] + self.embedding_len_tot = self.embedding_len def add(self, embedding_data): assert self.req_id == embedding_data.req_id - assert embedding_data.part_idx not in self.embedding_dict - self.embedding_dict[embedding_data.part_idx] = embedding_data.embedding - - def get(self): - assert len(self.embedding_dict) == self.num_parts - agg_data = [self.embedding_dict[i] for i in range(self.num_parts)] - return torch.concatenate(agg_data) + assert not self.ready_list[embedding_data.part_idx] + self.ready_list[embedding_data.part_idx] = True + self.embedding_len_tot += embedding_data.embedding_len @property def ready(self): - return len(self.embedding_dict) == self.num_parts + return sum(self.ready_list) == self.num_parts class ImageEncoder: @@ -81,6 +80,8 @@ def 
__init__(self, server_args: ServerArgs): else: dist_init_method = f"tcp://127.0.0.1:{port_args.nccl_port}" + self.gpu_id = 0 + init_distributed_environment( world_size=1, rank=0, distributed_init_method=dist_init_method ) @@ -90,13 +91,22 @@ def __init__(self, server_args: ServerArgs): self.model = get_model( model_config=self.model_config, load_config=self.load_config, - device_config=DeviceConfig(), + device_config=DeviceConfig(device="cuda", gpu_id=self.gpu_id), + ) + + self.local_ip = get_local_ip_auto() + + self.engine = MooncakeTransferEngine( + hostname=self.local_ip, + gpu_id=self.gpu_id, ) self.context = zmq.asyncio.Context(2) self.send_to_prefill_sockets = dict() - async def encode(self, mm_items): + self.embedding_to_send = dict() + + async def encode(self, mm_items) -> torch.Tensor: images = load_images(mm_items) # Qwen-specific: resize images @@ -114,7 +124,21 @@ async def encode(self, mm_items): mm_embedding = self.model.get_image_feature([mm_item]) return mm_embedding - def send(self, send_data, prefill_ip): + def send( + self, + session_id, + peer_buffer_address, + embedding: torch.Tensor, + meta_data: EmbeddingData, + prefill_ip, + ): + self.engine.register(embedding.data_ptr(), embedding.nbytes) + self.engine.transfer_sync( + session_id, embedding.data_ptr(), peer_buffer_address, embedding.nbytes + ) + self.engine.deregister(embedding.data_ptr()) + + # Send ack if prefill_ip in self.send_to_prefill_sockets: socket = self.send_to_prefill_sockets[prefill_ip] else: @@ -125,19 +149,38 @@ def send(self, send_data, prefill_ip): False, ) self.send_to_prefill_sockets[prefill_ip] = socket - socket.send_pyobj(send_data) + socket.send_pyobj(meta_data) @torch.inference_mode() - async def step(self, request_data): - mm_embeddings = await self.encode(request_data["mm_items"]) - send_data = EmbeddingData( - request_data["req_id"], - request_data["num_parts"], - request_data["part_idx"], - mm_embeddings, - ) - self.send(send_data, request_data["bootstrap_host"]) 
- del send_data + async def step(self, request_data: dict): + if "mm_items" in request_data: + mm_embedding = await self.encode(request_data["mm_items"]) + meta_data = EmbeddingData( + request_data["req_id"], + request_data["num_parts"], + request_data["part_idx"], + mm_embedding.shape[0], + ) + self.embedding_to_send[meta_data.req_id] = (mm_embedding, meta_data) + del request_data["mm_items"] + request_data.update( + { + "embedding_size": mm_embedding.nbytes, + "embedding_len": mm_embedding.shape[0], + } + ) + return request_data + else: + mm_embedding, meta_data = self.embedding_to_send[request_data["req_id"]] + self.send( + request_data["session_id"], + request_data["buffer_address"], + mm_embedding, + meta_data, + request_data["bootstrap_host"], + ) + del self.embedding_to_send[request_data["req_id"]] + return None app = FastAPI() @@ -152,5 +195,5 @@ def launch_server(server_args: ServerArgs): @app.post("/encode") async def handle_encode_request(request_data: dict): - await encoder.step(request_data) - return ORJSONResponse(content=None) + ret_data = await encoder.step(request_data) + return ORJSONResponse(content=ret_data) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index cf0a3784fe8c..fbe8106edb95 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -1331,6 +1331,17 @@ async def sagemaker_chat_completions( ) +@app.post("/embedding_bootstrap") +async def embedding_bootstrap(request_data: dict): + buffer_address = await _global_state.tokenizer_manager.allocate_embedding_buffer( + request_data["req_id"], request_data["embedding_length"] + ) + session_id = _global_state.tokenizer_manager.embeddings_engine.session_id + return ORJSONResponse( + content={"session_id": session_id, "buffer_address": buffer_address} + ) + + ## Vertex AI API @app.post(os.environ.get("AIP_PREDICT_ROUTE", "/vertex_generate")) async def vertex_generate(vertex_req: 
VertexGenerateReqInput, raw_request: Request): diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 57fdc055669d..0334b274c62c 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -40,6 +40,7 @@ from fastapi import BackgroundTasks from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.environ import envs from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry @@ -102,6 +103,7 @@ freeze_gc, get_bool_env_var, get_or_create_event_loop, + get_local_ip_auto, get_zmq_socket, kill_process_tree, ) @@ -319,8 +321,13 @@ def __init__( self.recv_from_encoder = get_zmq_socket( context, zmq.PULL, f"tcp://*:{server_args.embedding_port}", True ) - self.received_embeddings = dict() + self.received_metadata = dict() self.embeddings_lock = asyncio.Lock() + self.embeddings_engine = MooncakeTransferEngine( + get_local_ip_auto(), + gpu_id=0, + ) + self.embeddings_buffer = dict() # Request states self._chosen_loop = None @@ -566,6 +573,18 @@ def _detect_input_format( return "batch_strings" + async def allocate_embedding_buffer(self, req_id, embedding_length): + embeddings = torch.zeros( + (embedding_length, self.model_config.hidden_size), + dtype=self.model_config.dtype, + ) + self.embeddings_engine.register( + embeddings.data_ptr(), + embeddings.nbytes, + ) + self.embeddings_buffer[req_id] = embeddings + return embeddings.data_ptr() + def _prepare_tokenizer_input( self, texts: Union[str, List[str]], input_format: str ) -> Union[List[str], List[List[str]]]: @@ -735,16 +754,22 @@ async def _tokenize_one_request( # Use async lock to avoid race condition async with self.embeddings_lock: while ( - obj.bootstrap_room not in self.received_embeddings - or not 
self.received_embeddings[obj.bootstrap_room].ready + obj.bootstrap_room not in self.received_metadata + or not self.received_metadata[obj.bootstrap_room].ready ): await self.handle_embedding() for mm_item in mm_inputs["mm_items"]: if mm_item.modality == Modality.IMAGE: - mm_item.precomputed_embeddings = self.received_embeddings[ + mm_item.precomputed_embeddings = self.embeddings_buffer[ obj.bootstrap_room - ].get() - del self.received_embeddings[obj.bootstrap_room] + ] + self.embeddings_engine.deregister( + mm_item.precomputed_embeddings.data_ptr() + ) + break + del self.received_metadata[obj.bootstrap_room] + del self.embeddings_buffer[obj.bootstrap_room] + else: mm_inputs = None @@ -1570,10 +1595,10 @@ async def handle_loop(self): async def handle_embedding(self): recv_obj = await self.recv_from_encoder.recv_pyobj() - if recv_obj.req_id not in self.received_embeddings: - self.received_embeddings[recv_obj.req_id] = recv_obj + if recv_obj.req_id not in self.received_metadata: + self.received_metadata[recv_obj.req_id] = recv_obj else: - self.received_embeddings[recv_obj.req_id].add(recv_obj) + self.received_metadata[recv_obj.req_id].add(recv_obj) def _add_metric_if_present( self, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index a499257f3dcc..51d63ed17038 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2258,7 +2258,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--embedding-port", type=int, default=54213, - help="The port for multimodal embedding transmission.", + help="The port for transmitting embedding metadata.", ) # Quantization and data type diff --git a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py index 9ea9e77dcfb8..f6c89bc7eb3f 100644 --- a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py +++ b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py @@ -99,8 +99,7 @@ def 
_validate_router_args(self, router_args: RouterArgs): ) if router_args.e_disaggregation and len(router_args.decode_urls) != 0: - logger.warning("The E disaggregation mode doesn't require decode server") - router_args.decode_urls = [] + raise ValueError("The E disaggregation mode doesn't require decode server") def start(self): global lb @@ -123,7 +122,17 @@ def select_pair(self): decode_url, ) - async def encode(self, request_data, encode_urls, endpoint): + async def embedding_bootstrap(self, session, prefill_url, req_id, embedding_length): + response = await session.post( + f"{prefill_url}/embedding_bootstrap", + json={"req_id": req_id, "embedding_length": embedding_length}, + ) + response_json = await response.json() + session_id = response_json["session_id"] + buffer_address = response_json["buffer_address"] + return session_id, buffer_address + + async def encode(self, request_data, encode_urls, endpoint, prefill_url): messages = request_data.get("messages") if messages is None or len(encode_urls) == 0: return @@ -139,6 +148,7 @@ async def encode(self, request_data, encode_urls, endpoint): if len(img_list) == 0: return + req_id = request_data.get("bootstrap_room") # Split mm_items encode_requests = [] random.shuffle(self.encode_idx) @@ -157,19 +167,19 @@ async def encode(self, request_data, encode_urls, endpoint): "mm_items": img_list[cum_num_items : cum_num_items + assigned_num], "num_parts": num_parts, "part_idx": cum_idx, - "req_id": request_data.get("bootstrap_room"), - "bootstrap_host": request_data.get("bootstrap_host"), + "req_id": req_id, } ) cum_idx += 1 cum_num_items += assigned_num - # Send encode requests async with aiohttp.ClientSession( timeout=aiohttp.ClientTimeout( total=self.timeout ) # Add timeout for request reliability ) as session: + # Send encode requests + tasks = [ session.post( f"{encode_urls[encode_request['encoder_idx']]}/{endpoint}", @@ -178,7 +188,46 @@ async def encode(self, request_data, encode_urls, endpoint): for encode_request 
in encode_requests ] - await asyncio.gather(*tasks) + responses = await asyncio.gather(*tasks) + response_json_list_unsort = [ + await response.json() for response in responses + ] + + # Send bootstrap info + + embedding_size_list_sort = [None for _ in range(num_parts)] + embedding_length_tot = 0 + response_json_list_sort = [None for _ in range(num_parts)] + for response_json in response_json_list_unsort: + idx = response_json["part_idx"] + embedding_size_list_sort[idx] = response_json["embedding_size"] + embedding_length_tot += response_json["embedding_len"] + response_json_list_sort[idx] = response_json + + offset = 0 + metadata_tasks = [] + prefill_ip = request_data["bootstrap_host"] + session_id, buffer_address = await self.embedding_bootstrap( + session, prefill_url, req_id, embedding_length_tot + ) + for idx in range(len(tasks)): + response_json = response_json_list_sort[idx] + buffer_address_adjust = offset + buffer_address + response_json.update( + { + "session_id": session_id, + "buffer_address": buffer_address_adjust, + "bootstrap_host": prefill_ip, + } + ) + metadata_tasks.append( + session.post( + f"{encode_urls[response_json["encoder_idx"]]}/{endpoint}", + json=response_json, + ) + ) + offset += embedding_size_list_sort[idx] + await asyncio.gather(*metadata_tasks) async def generate( self, modified_request, prefill_server, decode_server, endpoint @@ -515,10 +564,14 @@ async def _forward_to_backend(request_data: dict, endpoint_name: str): "bootstrap_host": hostname, } ) - asyncio.create_task(lb.encode(encode_request, lb.encode_urls, "encode")) + asyncio.create_task( + lb.encode(encode_request, lb.encode_urls, "encode", prefill_server) + ) modified_request = encode_request.copy() - modified_request.update({"bootstrap_port": bootstrap_port}) + modified_request.update( + {"bootstrap_port": bootstrap_port, "encode_urls": lb.encode_urls} + ) if request_data.get("stream", False): return await lb.generate_stream( From 88b74bd356ac7a75135c398cc1f65bb51e310483 
Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Wed, 5 Nov 2025 03:03:38 +0000 Subject: [PATCH 21/68] Fix embedding shape --- python/sglang/srt/entrypoints/encode_server.py | 1 + python/sglang/srt/entrypoints/http_server.py | 4 +++- python/sglang/srt/managers/tokenizer_manager.py | 4 ++-- .../bindings/python/sglang_router/mini_lb.py | 16 +++++++++++++--- 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/entrypoints/encode_server.py b/python/sglang/srt/entrypoints/encode_server.py index e77b08949836..1ef251dc3c81 100644 --- a/python/sglang/srt/entrypoints/encode_server.py +++ b/python/sglang/srt/entrypoints/encode_server.py @@ -167,6 +167,7 @@ async def step(self, request_data: dict): { "embedding_size": mm_embedding.nbytes, "embedding_len": mm_embedding.shape[0], + "embedding_dim": mm_embedding.shape[1], } ) return request_data diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index fbe8106edb95..318a44928968 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -1334,7 +1334,9 @@ async def sagemaker_chat_completions( @app.post("/embedding_bootstrap") async def embedding_bootstrap(request_data: dict): buffer_address = await _global_state.tokenizer_manager.allocate_embedding_buffer( - request_data["req_id"], request_data["embedding_length"] + request_data["req_id"], + request_data["embedding_length"], + request_data["embedding_dim"], ) session_id = _global_state.tokenizer_manager.embeddings_engine.session_id return ORJSONResponse( diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 0334b274c62c..bca7331aa304 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -573,9 +573,9 @@ def _detect_input_format( return "batch_strings" - async def allocate_embedding_buffer(self, req_id, embedding_length): + 
async def allocate_embedding_buffer(self, req_id, embedding_length, embedding_dim): embeddings = torch.zeros( - (embedding_length, self.model_config.hidden_size), + (embedding_length, embedding_dim), dtype=self.model_config.dtype, ) self.embeddings_engine.register( diff --git a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py index f6c89bc7eb3f..52cf66b1e795 100644 --- a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py +++ b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py @@ -122,10 +122,16 @@ def select_pair(self): decode_url, ) - async def embedding_bootstrap(self, session, prefill_url, req_id, embedding_length): + async def embedding_bootstrap( + self, session, prefill_url, req_id, embedding_length, embedding_dim + ): response = await session.post( f"{prefill_url}/embedding_bootstrap", - json={"req_id": req_id, "embedding_length": embedding_length}, + json={ + "req_id": req_id, + "embedding_length": embedding_length, + "embedding_dim": embedding_dim, + }, ) response_json = await response.json() session_id = response_json["session_id"] @@ -208,7 +214,11 @@ async def encode(self, request_data, encode_urls, endpoint, prefill_url): metadata_tasks = [] prefill_ip = request_data["bootstrap_host"] session_id, buffer_address = await self.embedding_bootstrap( - session, prefill_url, req_id, embedding_length_tot + session, + prefill_url, + req_id, + embedding_length_tot, + response_json_list_sort[0]["embedding_dim"], ) for idx in range(len(tasks)): response_json = response_json_list_sort[idx] From 12a94515d96a8bbaad2225b9fc747c3034b5ded9 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Wed, 5 Nov 2025 03:29:07 +0000 Subject: [PATCH 22/68] Fix ib_device,mm_embedding shape,format --- python/sglang/srt/entrypoints/encode_server.py | 3 +++ python/sglang/srt/managers/tokenizer_manager.py | 3 ++- sgl-model-gateway/bindings/python/sglang_router/mini_lb.py | 3 ++- 3 files changed, 7 
insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/entrypoints/encode_server.py b/python/sglang/srt/entrypoints/encode_server.py index 1ef251dc3c81..3267ce06b516 100644 --- a/python/sglang/srt/entrypoints/encode_server.py +++ b/python/sglang/srt/entrypoints/encode_server.py @@ -99,6 +99,7 @@ def __init__(self, server_args: ServerArgs): self.engine = MooncakeTransferEngine( hostname=self.local_ip, gpu_id=self.gpu_id, + ib_device=server_args.disaggregation_ib_device, ) self.context = zmq.asyncio.Context(2) @@ -122,6 +123,8 @@ async def encode(self, mm_items) -> torch.Tensor: ) mm_item.set("image_grid_thw", images_input["image_grid_thw"]) mm_embedding = self.model.get_image_feature([mm_item]) + if len(mm_embedding.shape) == 3: + mm_embedding = mm_embedding.reshape(-1, mm_embedding.shape[-1]) return mm_embedding def send( diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index bca7331aa304..6453323df368 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -324,8 +324,9 @@ def __init__( self.received_metadata = dict() self.embeddings_lock = asyncio.Lock() self.embeddings_engine = MooncakeTransferEngine( - get_local_ip_auto(), + hostname=get_local_ip_auto(), gpu_id=0, + ib_device=server_args.disaggregation_ib_device, ) self.embeddings_buffer = dict() diff --git a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py index 52cf66b1e795..66c508dd03fb 100644 --- a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py +++ b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py @@ -230,9 +230,10 @@ async def encode(self, request_data, encode_urls, endpoint, prefill_url): "bootstrap_host": prefill_ip, } ) + encoder_idx = response_json["encoder_idx"] metadata_tasks.append( session.post( - f"{encode_urls[response_json["encoder_idx"]]}/{endpoint}", + 
f"{encode_urls[encoder_idx]}/{endpoint}", json=response_json, ) ) From bf71817aab51f28fef609abf0c4d2db871cfb175 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Wed, 5 Nov 2025 03:54:27 +0000 Subject: [PATCH 23/68] Fix router args --- sgl-model-gateway/bindings/python/sglang_router/router.py | 2 ++ sgl-model-gateway/bindings/python/sglang_router/router_args.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/sgl-model-gateway/bindings/python/sglang_router/router.py b/sgl-model-gateway/bindings/python/sglang_router/router.py index 05506e1cd560..06f913818f54 100644 --- a/sgl-model-gateway/bindings/python/sglang_router/router.py +++ b/sgl-model-gateway/bindings/python/sglang_router/router.py @@ -202,6 +202,8 @@ def from_args(args: RouterArgs) -> "Router": # Remove fields that shouldn't be passed to Rust Router constructor fields_to_remove = [ "mini_lb", + "e_disaggregation", + "encode_urls", "oracle_wallet_path", "oracle_tns_alias", "oracle_connect_descriptor", diff --git a/sgl-model-gateway/bindings/python/sglang_router/router_args.py b/sgl-model-gateway/bindings/python/sglang_router/router_args.py index 476c129a7f3a..1ed269350d12 100644 --- a/sgl-model-gateway/bindings/python/sglang_router/router_args.py +++ b/sgl-model-gateway/bindings/python/sglang_router/router_args.py @@ -729,6 +729,8 @@ def _validate_router_args(self): f"Using --policy '{self.policy}' for prefill nodes " f"and --decode-policy '{self.decode_policy}' for decode nodes." 
) + if self.e_disaggregation or len(self.encode_urls): + raise ValueError("Currently, E disaggregation mode requires --mini-lb") @staticmethod def _parse_selector(selector_list): From c93c13db92cd2dcace831d9c49b50bc794428ed9 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Wed, 5 Nov 2025 10:06:49 +0000 Subject: [PATCH 24/68] Add mm_transfer_backend (zmq or mooncake) --- .../sglang/srt/entrypoints/encode_server.py | 162 +++++++++++------- .../sglang/srt/managers/tokenizer_manager.py | 48 +++--- python/sglang/srt/server_args.py | 14 +- .../bindings/python/sglang_router/mini_lb.py | 23 ++- 4 files changed, 155 insertions(+), 92 deletions(-) diff --git a/python/sglang/srt/entrypoints/encode_server.py b/python/sglang/srt/entrypoints/encode_server.py index 3267ce06b516..23ce48144f05 100644 --- a/python/sglang/srt/entrypoints/encode_server.py +++ b/python/sglang/srt/entrypoints/encode_server.py @@ -1,4 +1,5 @@ import asyncio +import logging from typing import Optional import torch @@ -29,28 +30,38 @@ ) from sglang.srt.utils import get_local_ip_auto, get_zmq_socket +logger = logging.getLogger(__name__) + class EmbeddingData: - def __init__(self, req_id, num_parts, part_idx, embedding_len): + def __init__(self, req_id, num_parts, part_idx, embedding=None): self.req_id = req_id self.num_parts = num_parts self.part_idx = part_idx - self.embedding_len = embedding_len + self.embedding = embedding # aggregated data self.ready_list = [i == self.part_idx for i in range(self.num_parts)] - self.embedding_len_tot = self.embedding_len + self.embedding_list = [ + embedding if i == self.part_idx else None for i in range(self.num_parts) + ] def add(self, embedding_data): assert self.req_id == embedding_data.req_id assert not self.ready_list[embedding_data.part_idx] self.ready_list[embedding_data.part_idx] = True - self.embedding_len_tot += embedding_data.embedding_len + self.embedding_list[embedding_data.part_idx] = embedding_data.embedding + + def get(self): + return
torch.concatenate(self.embedding_list) @property def ready(self): return sum(self.ready_list) == self.num_parts + def __repr__(self): + return f"EmbeddingData(req_id={self.req_id}, num_parts={self.num_parts}, part_idx={self.part_idx})" + class ImageEncoder: def __init__(self, server_args: ServerArgs): @@ -94,20 +105,23 @@ def __init__(self, server_args: ServerArgs): device_config=DeviceConfig(device="cuda", gpu_id=self.gpu_id), ) - self.local_ip = get_local_ip_auto() + logger.info(f"Using transfer backend: {self.server_args.mm_transfer_backend}") - self.engine = MooncakeTransferEngine( - hostname=self.local_ip, - gpu_id=self.gpu_id, - ib_device=server_args.disaggregation_ib_device, - ) + if self.server_args.mm_transfer_backend == "mooncake": + self.local_ip = get_local_ip_auto() + + self.engine = MooncakeTransferEngine( + hostname=self.local_ip, + gpu_id=self.gpu_id, + ib_device=server_args.disaggregation_ib_device, + ) self.context = zmq.asyncio.Context(2) self.send_to_prefill_sockets = dict() self.embedding_to_send = dict() - async def encode(self, mm_items) -> torch.Tensor: + async def mm_encode(self, mm_items) -> torch.Tensor: images = load_images(mm_items) # Qwen-specific: resize images @@ -123,68 +137,63 @@ async def encode(self, mm_items) -> torch.Tensor: ) mm_item.set("image_grid_thw", images_input["image_grid_thw"]) mm_embedding = self.model.get_image_feature([mm_item]) - if len(mm_embedding.shape) == 3: + if len(mm_embedding.shape) != 2: mm_embedding = mm_embedding.reshape(-1, mm_embedding.shape[-1]) return mm_embedding - def send( + def mm_send( self, - session_id, - peer_buffer_address, + prefill_host, embedding: torch.Tensor, - meta_data: EmbeddingData, - prefill_ip, + mm_data: EmbeddingData, + session_id=None, + peer_buffer_address=None, ): - self.engine.register(embedding.data_ptr(), embedding.nbytes) - self.engine.transfer_sync( - session_id, embedding.data_ptr(), peer_buffer_address, embedding.nbytes - ) - 
self.engine.deregister(embedding.data_ptr()) + if self.server_args.mm_transfer_backend == "mooncake": + self.engine.register(embedding.data_ptr(), embedding.nbytes) + self.engine.transfer_sync( + session_id, embedding.data_ptr(), peer_buffer_address, embedding.nbytes + ) + self.engine.deregister(embedding.data_ptr()) - # Send ack - if prefill_ip in self.send_to_prefill_sockets: - socket = self.send_to_prefill_sockets[prefill_ip] + mm_data.embedding = None + mm_data.embedding_list[mm_data.part_idx] = None + + # Send ack/data + if prefill_host in self.send_to_prefill_sockets: + socket = self.send_to_prefill_sockets[prefill_host] else: socket = get_zmq_socket( self.context, zmq.PUSH, - f"tcp://{prefill_ip}:{self.server_args.embedding_port}", + f"tcp://{prefill_host}:{self.server_args.embedding_port}", False, ) - self.send_to_prefill_sockets[prefill_ip] = socket - socket.send_pyobj(meta_data) + self.send_to_prefill_sockets[prefill_host] = socket + socket.send_pyobj(mm_data) @torch.inference_mode() - async def step(self, request_data: dict): - if "mm_items" in request_data: - mm_embedding = await self.encode(request_data["mm_items"]) - meta_data = EmbeddingData( - request_data["req_id"], - request_data["num_parts"], - request_data["part_idx"], - mm_embedding.shape[0], - ) - self.embedding_to_send[meta_data.req_id] = (mm_embedding, meta_data) - del request_data["mm_items"] - request_data.update( - { - "embedding_size": mm_embedding.nbytes, - "embedding_len": mm_embedding.shape[0], - "embedding_dim": mm_embedding.shape[1], - } - ) - return request_data - else: - mm_embedding, meta_data = self.embedding_to_send[request_data["req_id"]] - self.send( - request_data["session_id"], - request_data["buffer_address"], - mm_embedding, - meta_data, - request_data["bootstrap_host"], - ) - del self.embedding_to_send[request_data["req_id"]] - return None + async def encode(self, mm_items, req_id, num_parts, part_idx): + mm_embedding = await self.mm_encode(mm_items) + mm_data = 
EmbeddingData( + req_id, + num_parts, + part_idx, + mm_embedding, + ) + self.embedding_to_send[mm_data.req_id] = mm_data + return mm_embedding.nbytes, mm_embedding.shape[0], mm_embedding.shape[1] + + async def send(self, req_id, prefill_host, session_id=None, buffer_address=None): + mm_data: EmbeddingData = self.embedding_to_send[req_id] + self.mm_send( + prefill_host, + mm_data.embedding, + mm_data, + session_id, + buffer_address, + ) + del self.embedding_to_send[req_id] app = FastAPI() @@ -198,6 +207,37 @@ def launch_server(server_args: ServerArgs): @app.post("/encode") -async def handle_encode_request(request_data: dict): - ret_data = await encoder.step(request_data) - return ORJSONResponse(content=ret_data) +async def handle_encode_request(request: dict): + nbytes, embedding_len, embedding_dim = await encoder.encode( + mm_items=request["mm_items"], + req_id=request["req_id"], + num_parts=request["num_parts"], + part_idx=request["part_idx"], + ) + if encoder.server_args.mm_transfer_backend == "mooncake": + del request["mm_items"] + request.update( + { + "embedding_size": nbytes, + "embedding_len": embedding_len, + "embedding_dim": embedding_dim, + } + ) + return ORJSONResponse(content=request) + elif encoder.server_args.mm_transfer_backend == "zmq": + await encoder.send( + req_id=request["req_id"], + prefill_host=request["bootstrap_host"], + ) + return ORJSONResponse(content=None) + + +@app.post("/send") +async def handle_send_request(request: dict): + await encoder.send( + req_id=request["req_id"], + prefill_host=request["bootstrap_host"], + session_id=request["session_id"], + buffer_address=request["buffer_address"], + ) + return ORJSONResponse(content=None) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 6453323df368..b2cb345d8b0f 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -321,14 +321,15 @@ def __init__( 
self.recv_from_encoder = get_zmq_socket( context, zmq.PULL, f"tcp://*:{server_args.embedding_port}", True ) - self.received_metadata = dict() + self.received_data = dict() self.embeddings_lock = asyncio.Lock() - self.embeddings_engine = MooncakeTransferEngine( - hostname=get_local_ip_auto(), - gpu_id=0, - ib_device=server_args.disaggregation_ib_device, - ) - self.embeddings_buffer = dict() + if self.server_args.mm_transfer_backend == "mooncake": + self.embeddings_engine = MooncakeTransferEngine( + hostname=get_local_ip_auto(), + gpu_id=0, + ib_device=server_args.disaggregation_ib_device, + ) + self.embeddings_buffer = dict() # Request states self._chosen_loop = None @@ -755,21 +756,26 @@ async def _tokenize_one_request( # Use async lock to avoid race condition async with self.embeddings_lock: while ( - obj.bootstrap_room not in self.received_metadata - or not self.received_metadata[obj.bootstrap_room].ready + obj.bootstrap_room not in self.received_data + or not self.received_data[obj.bootstrap_room].ready ): await self.handle_embedding() for mm_item in mm_inputs["mm_items"]: if mm_item.modality == Modality.IMAGE: - mm_item.precomputed_embeddings = self.embeddings_buffer[ - obj.bootstrap_room - ] - self.embeddings_engine.deregister( - mm_item.precomputed_embeddings.data_ptr() - ) - break - del self.received_metadata[obj.bootstrap_room] - del self.embeddings_buffer[obj.bootstrap_room] + if self.server_args.mm_transfer_backend == "mooncake": + mm_item.precomputed_embeddings = self.embeddings_buffer[ + obj.bootstrap_room + ] + self.embeddings_engine.deregister( + mm_item.precomputed_embeddings.data_ptr() + ) + elif self.server_args.mm_transfer_backend == "zmq": + mm_item.precomputed_embeddings = self.received_data[ + obj.bootstrap_room + ].get() + del self.received_data[obj.bootstrap_room] + if self.server_args.mm_transfer_backend == "mooncake": + del self.embeddings_buffer[obj.bootstrap_room] else: mm_inputs = None @@ -1596,10 +1602,10 @@ async def handle_loop(self): 
async def handle_embedding(self): recv_obj = await self.recv_from_encoder.recv_pyobj() - if recv_obj.req_id not in self.received_metadata: - self.received_metadata[recv_obj.req_id] = recv_obj + if recv_obj.req_id not in self.received_data: + self.received_data[recv_obj.req_id] = recv_obj else: - self.received_metadata[recv_obj.req_id].add(recv_obj) + self.received_data[recv_obj.req_id].add(recv_obj) def _add_metric_if_present( self, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 51d63ed17038..2a4c332c35bb 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -135,6 +135,8 @@ DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"] +MM_TRANSFER_BACKEND_CHOICES = ["zmq", "mooncake"] + GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"] DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"] @@ -257,7 +259,8 @@ class ServerArgs: # Encode prefill disaggregation mm_only: bool = False language_only: bool = False - embedding_port: Optional[int] = None + embedding_port: Optional[int] = 54213 + mm_transfer_backend: str = "zmq" # Quantization and data type dtype: str = "auto" @@ -2257,9 +2260,16 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--embedding-port", type=int, - default=54213, + default=ServerArgs.embedding_port, help="The port for transmitting embedding metadata.", ) + parser.add_argument( + "--mm-transfer-backend", + type=str, + default=ServerArgs.mm_transfer_backend, + choices=MM_TRANSFER_BACKEND_CHOICES, + help="The backend for encoder disaggregation transfer. 
Default is zmq.", + ) # Quantization and data type parser.add_argument( diff --git a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py index 66c508dd03fb..74e333e2d6e4 100644 --- a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py +++ b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py @@ -138,7 +138,9 @@ async def embedding_bootstrap( buffer_address = response_json["buffer_address"] return session_id, buffer_address - async def encode(self, request_data, encode_urls, endpoint, prefill_url): + async def encode( + self, request_data, encode_urls, endpoint_encode, endpoint_send, prefill_url + ): messages = request_data.get("messages") if messages is None or len(encode_urls) == 0: return @@ -155,6 +157,8 @@ async def encode(self, request_data, encode_urls, endpoint, prefill_url): return req_id = request_data.get("bootstrap_room") + prefill_host = request_data["bootstrap_host"] + # Split mm_items encode_requests = [] random.shuffle(self.encode_idx) @@ -174,6 +178,7 @@ async def encode(self, request_data, encode_urls, endpoint, prefill_url): "num_parts": num_parts, "part_idx": cum_idx, "req_id": req_id, + "bootstrap_host": prefill_host, } ) cum_idx += 1 @@ -188,7 +193,7 @@ async def encode(self, request_data, encode_urls, endpoint, prefill_url): tasks = [ session.post( - f"{encode_urls[encode_request['encoder_idx']]}/{endpoint}", + f"{encode_urls[encode_request['encoder_idx']]}/{endpoint_encode}", json=encode_request, ) for encode_request in encode_requests @@ -199,7 +204,11 @@ async def encode(self, request_data, encode_urls, endpoint, prefill_url): await response.json() for response in responses ] - # Send bootstrap info + # zmq backend: return is None + if None in response_json_list_unsort: + return + + # mooncake backend: send bootstrap info embedding_size_list_sort = [None for _ in range(num_parts)] embedding_length_tot = 0 @@ -212,7 +221,6 @@ async def encode(self, 
request_data, encode_urls, endpoint, prefill_url): offset = 0 metadata_tasks = [] - prefill_ip = request_data["bootstrap_host"] session_id, buffer_address = await self.embedding_bootstrap( session, prefill_url, @@ -227,13 +235,12 @@ async def encode(self, request_data, encode_urls, endpoint, prefill_url): { "session_id": session_id, "buffer_address": buffer_address_adjust, - "bootstrap_host": prefill_ip, + "bootstrap_host": prefill_host, } ) - encoder_idx = response_json["encoder_idx"] metadata_tasks.append( session.post( - f"{encode_urls[encoder_idx]}/{endpoint}", + f"{encode_urls[response_json['encoder_idx']]}/{endpoint_send}", json=response_json, ) ) @@ -576,7 +583,7 @@ async def _forward_to_backend(request_data: dict, endpoint_name: str): } ) asyncio.create_task( - lb.encode(encode_request, lb.encode_urls, "encode", prefill_server) + lb.encode(encode_request, lb.encode_urls, "encode", "send", prefill_server) ) modified_request = encode_request.copy() From b8eccf7d0955b8d9a37c3f3ccabcc587c04c0555 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Thu, 6 Nov 2025 02:10:10 +0000 Subject: [PATCH 25/68] Fix embedding_port --- .../sglang/srt/entrypoints/encode_server.py | 28 ++++++++++++++++--- python/sglang/srt/entrypoints/http_server.py | 24 ++++++++++------ .../sglang/srt/managers/tokenizer_manager.py | 4 ++- python/sglang/srt/server_args.py | 7 ----- .../bindings/python/sglang_router/mini_lb.py | 2 ++ 5 files changed, 44 insertions(+), 21 deletions(-) diff --git a/python/sglang/srt/entrypoints/encode_server.py b/python/sglang/srt/entrypoints/encode_server.py index 23ce48144f05..d2c74a517b07 100644 --- a/python/sglang/srt/entrypoints/encode_server.py +++ b/python/sglang/srt/entrypoints/encode_server.py @@ -2,6 +2,7 @@ import logging from typing import Optional +import aiohttp import torch import uvicorn import zmq @@ -141,9 +142,10 @@ async def mm_encode(self, mm_items) -> torch.Tensor: mm_embedding = mm_embedding.reshape(-1, mm_embedding.shape[-1]) return 
mm_embedding - def mm_send( + async def mm_send( self, prefill_host, + prefill_url, embedding: torch.Tensor, mm_data: EmbeddingData, session_id=None, @@ -163,10 +165,11 @@ def mm_send( if prefill_host in self.send_to_prefill_sockets: socket = self.send_to_prefill_sockets[prefill_host] else: + embedding_port = await self.get_embedding_port(prefill_url) socket = get_zmq_socket( self.context, zmq.PUSH, - f"tcp://{prefill_host}:{self.server_args.embedding_port}", + f"tcp://{prefill_host}:{embedding_port}", False, ) self.send_to_prefill_sockets[prefill_host] = socket @@ -184,10 +187,13 @@ async def encode(self, mm_items, req_id, num_parts, part_idx): self.embedding_to_send[mm_data.req_id] = mm_data return mm_embedding.nbytes, mm_embedding.shape[0], mm_embedding.shape[1] - async def send(self, req_id, prefill_host, session_id=None, buffer_address=None): + async def send( + self, req_id, prefill_host, prefill_url, session_id=None, buffer_address=None + ): mm_data: EmbeddingData = self.embedding_to_send[req_id] - self.mm_send( + await self.mm_send( prefill_host, + prefill_url, mm_data.embedding, mm_data, session_id, @@ -195,6 +201,17 @@ async def send(self, req_id, prefill_host, session_id=None, buffer_address=None) ) del self.embedding_to_send[req_id] + async def get_embedding_port(self, prefill_url): + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=1800) + ) as session: + response = await session.post( + f"{prefill_url}/embedding_bootstrap", + json={"embedding_port": None}, + ) + response_json = await response.json() + return response_json["embedding_port"] + app = FastAPI() encoder: Optional[ImageEncoder] = None @@ -228,15 +245,18 @@ async def handle_encode_request(request: dict): await encoder.send( req_id=request["req_id"], prefill_host=request["bootstrap_host"], + prefill_url=request["prefill_url"], ) return ORJSONResponse(content=None) @app.post("/send") async def handle_send_request(request: dict): + # mooncake backend await encoder.send( 
req_id=request["req_id"], prefill_host=request["bootstrap_host"], + prefill_url=request["prefill_url"], session_id=request["session_id"], buffer_address=request["buffer_address"], ) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 318a44928968..d9c8cfcabd3a 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -1333,15 +1333,21 @@ async def sagemaker_chat_completions( @app.post("/embedding_bootstrap") async def embedding_bootstrap(request_data: dict): - buffer_address = await _global_state.tokenizer_manager.allocate_embedding_buffer( - request_data["req_id"], - request_data["embedding_length"], - request_data["embedding_dim"], - ) - session_id = _global_state.tokenizer_manager.embeddings_engine.session_id - return ORJSONResponse( - content={"session_id": session_id, "buffer_address": buffer_address} - ) + if "embedding_length" in request_data: + buffer_address = ( + await _global_state.tokenizer_manager.allocate_embedding_buffer( + request_data["req_id"], + request_data["embedding_length"], + request_data["embedding_dim"], + ) + ) + session_id = _global_state.tokenizer_manager.embeddings_engine.session_id + return ORJSONResponse( + content={"session_id": session_id, "buffer_address": buffer_address} + ) + elif "embedding_port" in request_data: + embedding_port = _global_state.tokenizer_manager.embedding_port + return ORJSONResponse(content={"embedding_port": embedding_port}) ## Vertex AI API diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index b2cb345d8b0f..150d4e2b695c 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -103,6 +103,7 @@ freeze_gc, get_bool_env_var, get_or_create_event_loop, + get_free_port, get_local_ip_auto, get_zmq_socket, kill_process_tree, @@ -318,8 +319,9 @@ def __init__( # Recv embedding from 
encoding server if self.model_config.is_multimodal and self.server_args.language_only: + self.embedding_port = get_free_port() self.recv_from_encoder = get_zmq_socket( - context, zmq.PULL, f"tcp://*:{server_args.embedding_port}", True + context, zmq.PULL, f"tcp://*:{self.embedding_port}", True ) self.received_data = dict() self.embeddings_lock = asyncio.Lock() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 2a4c332c35bb..1fdaafe3eea0 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -259,7 +259,6 @@ class ServerArgs: # Encode prefill disaggregation mm_only: bool = False language_only: bool = False - embedding_port: Optional[int] = 54213 mm_transfer_backend: str = "zmq" # Quantization and data type @@ -2257,12 +2256,6 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="For VLM, load weights for the language model only.", ) - parser.add_argument( - "--embedding-port", - type=int, - default=ServerArgs.embedding_port, - help="The port for transmitting embedding metadata.", - ) parser.add_argument( "--mm-transfer-backend", type=str, diff --git a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py index 74e333e2d6e4..4dc245f80efe 100644 --- a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py +++ b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py @@ -178,6 +178,7 @@ async def encode( "num_parts": num_parts, "part_idx": cum_idx, "req_id": req_id, + "prefill_url": prefill_url, "bootstrap_host": prefill_host, } ) @@ -236,6 +237,7 @@ async def encode( "session_id": session_id, "buffer_address": buffer_address_adjust, "bootstrap_host": prefill_host, + "prefill_url": prefill_url, } ) metadata_tasks.append( From da5283501f494b1591fc7ab152123396ea0fbad8 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Thu, 6 Nov 2025 02:37:18 +0000 Subject: [PATCH 26/68] Fix Qwen3-Omni --- 
python/sglang/srt/models/qwen3_vl.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py index fdbd96233e79..e976948fb854 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -602,7 +602,7 @@ def __init__( self.use_data_parallel = get_global_server_args().mm_enable_dp_encoder - if not config.language_only: + if not hasattr(config, "language_only") or not config.language_only: self.visual = Qwen3VLMoeVisionModel( config.vision_config, # NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization. @@ -618,10 +618,12 @@ def __init__( self.config: Qwen3VLConfig = config # for qwen3-vl else: self.config = config.text_config # for qwen3-omni - self.config.mm_only = config.mm_only - self.config.language_only = config.language_only + if hasattr(config, "mm_only"): + self.config.mm_only = config.mm_only + if hasattr(config, "language_only"): + self.config.language_only = config.language_only - if not config.mm_only: + if not hasattr(config, "mm_only") or not config.mm_only: self.model = language_model_cls( config=self.config, quant_config=quant_config, From 0321a09365db801d4051671e47b758092d9985ab Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Thu, 6 Nov 2025 03:29:10 +0000 Subject: [PATCH 27/68] Add params check and health check --- python/sglang/srt/entrypoints/encode_server.py | 5 +++++ python/sglang/srt/server_args.py | 15 +++++++++++++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/entrypoints/encode_server.py b/python/sglang/srt/entrypoints/encode_server.py index d2c74a517b07..4059add3f015 100644 --- a/python/sglang/srt/entrypoints/encode_server.py +++ b/python/sglang/srt/entrypoints/encode_server.py @@ -261,3 +261,8 @@ async def handle_send_request(request: dict): buffer_address=request["buffer_address"], ) return ORJSONResponse(content=None) + + 
+@app.get("/health_check") +async def handle_send_request(): + return ORJSONResponse(content={"is_alive": True}) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 1fdaafe3eea0..13c79dbeb044 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -662,7 +662,10 @@ def __post_init__(self): self._handle_load_format() # Handle PD disaggregation. - self._handle_disaggregation() + self._handle_pd_disaggregation() + + # Handle E disaggregation. + self._handle_e_disaggregation() # Validate tokenizer settings. self._handle_tokenizer_batching() @@ -1860,7 +1863,15 @@ def _handle_load_format(self): ): self.load_format = "auto" - def _handle_disaggregation(self): + def _handle_e_disaggregation(self): + if self.mm_only and self.language_only: + raise ValueError("Cannot set --mm-only and --language-only together") + if self.mm_only and not self.disaggregation_mode == "null": + raise ValueError( + "Cannot set --mm-only and --disaggregation-mode prefill/decode together" + ) + + def _handle_pd_disaggregation(self): if self.disaggregation_mode == "decode": assert ( self.disaggregation_decode_tp is None From d258499c979a48d3d7c24fb1d0db18e079865a7a Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Thu, 6 Nov 2025 06:50:18 +0000 Subject: [PATCH 28/68] Fix gpu_id, send_to_prefill_sockets --- python/sglang/srt/entrypoints/encode_server.py | 12 +++++------- python/sglang/srt/managers/tokenizer_manager.py | 2 +- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/entrypoints/encode_server.py b/python/sglang/srt/entrypoints/encode_server.py index 4059add3f015..e9cd7e70f91b 100644 --- a/python/sglang/srt/entrypoints/encode_server.py +++ b/python/sglang/srt/entrypoints/encode_server.py @@ -92,8 +92,6 @@ def __init__(self, server_args: ServerArgs): else: dist_init_method = f"tcp://127.0.0.1:{port_args.nccl_port}" - self.gpu_id = 0 - init_distributed_environment( world_size=1, rank=0, 
distributed_init_method=dist_init_method ) @@ -103,7 +101,7 @@ def __init__(self, server_args: ServerArgs): self.model = get_model( model_config=self.model_config, load_config=self.load_config, - device_config=DeviceConfig(device="cuda", gpu_id=self.gpu_id), + device_config=DeviceConfig(), ) logger.info(f"Using transfer backend: {self.server_args.mm_transfer_backend}") @@ -113,7 +111,7 @@ def __init__(self, server_args: ServerArgs): self.engine = MooncakeTransferEngine( hostname=self.local_ip, - gpu_id=self.gpu_id, + gpu_id=None, ib_device=server_args.disaggregation_ib_device, ) @@ -162,8 +160,8 @@ async def mm_send( mm_data.embedding_list[mm_data.part_idx] = None # Send ack/data - if prefill_host in self.send_to_prefill_sockets: - socket = self.send_to_prefill_sockets[prefill_host] + if prefill_url in self.send_to_prefill_sockets: + socket = self.send_to_prefill_sockets[prefill_url] else: embedding_port = await self.get_embedding_port(prefill_url) socket = get_zmq_socket( @@ -172,7 +170,7 @@ async def mm_send( f"tcp://{prefill_host}:{embedding_port}", False, ) - self.send_to_prefill_sockets[prefill_host] = socket + self.send_to_prefill_sockets[prefill_url] = socket socket.send_pyobj(mm_data) @torch.inference_mode() diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 150d4e2b695c..1e14aab54714 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -328,7 +328,7 @@ def __init__( if self.server_args.mm_transfer_backend == "mooncake": self.embeddings_engine = MooncakeTransferEngine( hostname=get_local_ip_auto(), - gpu_id=0, + gpu_id=None, ib_device=server_args.disaggregation_ib_device, ) self.embeddings_buffer = dict() From ecdb6ca084dbd8731028c28a8be88d47fccd2c21 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Thu, 6 Nov 2025 11:57:24 +0000 Subject: [PATCH 29/68] Support dots.vlm --- .../sglang/srt/entrypoints/encode_server.py | 34 
++++++++++++++++--- python/sglang/srt/models/dots_vlm.py | 33 ++++++++++-------- 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/python/sglang/srt/entrypoints/encode_server.py b/python/sglang/srt/entrypoints/encode_server.py index e9cd7e70f91b..dc4221f874e2 100644 --- a/python/sglang/srt/entrypoints/encode_server.py +++ b/python/sglang/srt/entrypoints/encode_server.py @@ -23,7 +23,11 @@ from sglang.srt.layers.dp_attention import initialize_dp_attention from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem from sglang.srt.model_loader import get_model -from sglang.srt.multimodal.processors.qwen_vl import resize_image_async +from sglang.srt.multimodal.processors.dots_vlm import DotsVLMImageProcessor +from sglang.srt.multimodal.processors.qwen_vl import ( + QwenVLImageProcessor, + resize_image_async, +) from sglang.srt.server_args import ( PortArgs, ServerArgs, @@ -120,12 +124,34 @@ def __init__(self, server_args: ServerArgs): self.embedding_to_send = dict() + # dots-specific: + if self.processor_cls == DotsVLMImageProcessor: + vision_config = self.model_config.hf_config.vision_config + patch_size = vision_config.patch_size + merge_size = vision_config.spatial_merge_size + + self.IMAGE_FACTOR = patch_size * merge_size + self.MIN_PIXELS = self.image_processor.min_pixels + self.MAX_PIXELS = self.image_processor.max_pixels + async def mm_encode(self, mm_items) -> torch.Tensor: images = load_images(mm_items) - # Qwen-specific: resize images - resize_tasks = [resize_image_async(image) for image in images] - images = await asyncio.gather(*resize_tasks) + # resize images + if self.processor_cls and self.processor_cls in [ + QwenVLImageProcessor, + DotsVLMImageProcessor, + ]: + if self.processor_cls == QwenVLImageProcessor: + resize_tasks = [resize_image_async(image) for image in images] + elif self.processor_cls == DotsVLMImageProcessor: + resize_tasks = [ + resize_image_async( + image, self.MIN_PIXELS, self.MAX_PIXELS, self.IMAGE_FACTOR 
+ ) + for image in images + ] + images = await asyncio.gather(*resize_tasks) images_input = self.image_processor(images=images) mm_item = MultimodalDataItem.from_dict( diff --git a/python/sglang/srt/models/dots_vlm.py b/python/sglang/srt/models/dots_vlm.py index 1de27f664645..61a942c2467f 100644 --- a/python/sglang/srt/models/dots_vlm.py +++ b/python/sglang/srt/models/dots_vlm.py @@ -50,12 +50,14 @@ def __init__( self.video_token_id = config.video_span_id self.pp_group = get_pp_group() - self.language_model = DeepseekV2ForCausalLM( - config.language_config, quant_config - ) + if not config.mm_only: + self.language_model = DeepseekV2ForCausalLM( + config.language_config, quant_config + ) - # Initialize vision tower (matching transformers naming for weight compatibility) - self.vision_tower = DotsVisionTransformer(config.vision_config) + if not config.language_only: + # Initialize vision tower (matching transformers naming for weight compatibility) + self.vision_tower = DotsVisionTransformer(config.vision_config) def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor): """pad attn qkv weights for dummy heads""" @@ -104,18 +106,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): language_weights.append((name, loaded_weight)) # Load vision tower weights - vision_state_dict = dict(vision_weights) - params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in vision_state_dict.items(): - if name not in params_dict: - raise ValueError(f"Weight {name} not found in params_dict") - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight) - weight_loader(param, loaded_weight) + if not self.config.language_only: + vision_state_dict = dict(vision_weights) + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in vision_state_dict.items(): + if name not 
in params_dict: + raise ValueError(f"Weight {name} not found in params_dict") + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight) + weight_loader(param, loaded_weight) # Load language model weights - if language_weights: + if not self.config.mm_only and language_weights: self.language_model.load_weights(language_weights) @classmethod From a542377f92580b8c5c7766de9b1de89fcb469df4 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Fri, 7 Nov 2025 07:16:57 +0000 Subject: [PATCH 30/68] Save prefill preprocess time for E disaggregation --- .../sglang/srt/entrypoints/encode_server.py | 32 +++++++++-- .../sglang/srt/managers/tokenizer_manager.py | 53 +++++++++--------- .../multimodal/processors/base_processor.py | 55 +++++++++++++++++++ .../srt/multimodal/processors/dots_vlm.py | 4 +- .../srt/multimodal/processors/gemma3.py | 1 + .../srt/multimodal/processors/gemma3n.py | 1 + .../sglang/srt/multimodal/processors/glm4v.py | 4 +- .../srt/multimodal/processors/qwen_vl.py | 38 +++++++++++++ 8 files changed, 155 insertions(+), 33 deletions(-) diff --git a/python/sglang/srt/entrypoints/encode_server.py b/python/sglang/srt/entrypoints/encode_server.py index dc4221f874e2..a30f54618229 100644 --- a/python/sglang/srt/entrypoints/encode_server.py +++ b/python/sglang/srt/entrypoints/encode_server.py @@ -39,10 +39,11 @@ class EmbeddingData: - def __init__(self, req_id, num_parts, part_idx, embedding=None): + def __init__(self, req_id, num_parts, part_idx, image_grid_dim, embedding=None): self.req_id = req_id self.num_parts = num_parts self.part_idx = part_idx + self.image_grid_dim = image_grid_dim self.embedding = embedding # aggregated data @@ -50,16 +51,26 @@ def __init__(self, req_id, num_parts, part_idx, embedding=None): self.embedding_list = [ embedding if i == self.part_idx else None for i in range(self.num_parts) ] + self.image_grid_dim_list = [ + 
self.image_grid_dim if i == self.part_idx else None + for i in range(self.num_parts) + ] def add(self, embedding_data): assert self.req_id == embedding_data.req_id assert not self.ready_list[embedding_data.part_idx] self.ready_list[embedding_data.part_idx] = True + self.image_grid_dim_list[embedding_data.part_idx] = ( + embedding_data.image_grid_dim + ) self.embedding_list[embedding_data.part_idx] = embedding_data.embedding - def get(self): + def get_embedding(self): return torch.concatenate(self.embedding_list) + def get_img_grid(self): + return torch.concatenate(self.image_grid_dim_list) + @property def ready(self): return sum(self.ready_list) == self.num_parts @@ -68,6 +79,18 @@ def __repr__(self): return f"EmbeddingData(req_id={self.req_id}, num_parts={self.num_parts}, part_idx={self.part_idx})" +_image_grid_attrs = ["image_grid_thw", "image_grid_hws"] + + +def _get_image_grid_dim(images_input): + for attr in _image_grid_attrs: + if attr in images_input: + return images_input[attr] + raise ValueError( + f"Image grid dim ({_image_grid_attrs}) not found in {images_input}" + ) + + class ImageEncoder: def __init__(self, server_args: ServerArgs): self.server_args = server_args @@ -164,7 +187,7 @@ async def mm_encode(self, mm_items) -> torch.Tensor: mm_embedding = self.model.get_image_feature([mm_item]) if len(mm_embedding.shape) != 2: mm_embedding = mm_embedding.reshape(-1, mm_embedding.shape[-1]) - return mm_embedding + return _get_image_grid_dim(images_input), mm_embedding async def mm_send( self, @@ -201,11 +224,12 @@ async def mm_send( @torch.inference_mode() async def encode(self, mm_items, req_id, num_parts, part_idx): - mm_embedding = await self.mm_encode(mm_items) + image_grid_dim, mm_embedding = await self.mm_encode(mm_items) mm_data = EmbeddingData( req_id, num_parts, part_idx, + image_grid_dim, mm_embedding, ) self.embedding_to_send[mm_data.req_id] = mm_data diff --git a/python/sglang/srt/managers/tokenizer_manager.py 
b/python/sglang/srt/managers/tokenizer_manager.py index 1e14aab54714..570d7d9d0c01 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -75,7 +75,6 @@ from sglang.srt.managers.mm_utils import TensorTransportMode from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors from sglang.srt.managers.request_metrics_exporter import RequestMetricsExporterManager -from sglang.srt.managers.scheduler import is_health_check_generate_req from sglang.srt.managers.schedule_batch import Modality, RequestStage from sglang.srt.managers.scheduler import is_health_check_generate_req from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region @@ -317,7 +316,7 @@ def __init__( # Make sure that each request carries the tokenizer_ipc_name for response routing self.send_to_scheduler = SenderWrapper(port_args, send_to_scheduler) - # Recv embedding from encoding server + # E Disaggregation if self.model_config.is_multimodal and self.server_args.language_only: self.embedding_port = get_free_port() self.recv_from_encoder = get_zmq_socket( @@ -744,17 +743,19 @@ async def _tokenize_one_request( obj.image_data = [obj.image_data] if obj.audio_data is not None and not isinstance(obj.audio_data, list): obj.audio_data = [obj.audio_data] - mm_inputs: Dict = await self.mm_data_processor.process( - image_data=obj.image_data, - audio_data=obj.audio_data, - input_text_or_ids=(input_text or input_ids), - request_obj=obj, - max_req_input_len=self.max_req_input_len, - ) - if mm_inputs and "input_ids" in mm_inputs: - input_ids = mm_inputs["input_ids"] - if self.server_args.language_only: + if not self.server_args.language_only: + mm_inputs: Dict = await self.mm_data_processor.process( + image_data=obj.image_data, + audio_data=obj.audio_data, + input_text_or_ids=(input_text or input_ids), + request_obj=obj, + max_req_input_len=self.max_req_input_len, + ) + else: + # E Disaggregation + 
recv_embedding = None + img_grid_thw = None # Use async lock to avoid race condition async with self.embeddings_lock: while ( @@ -762,23 +763,25 @@ async def _tokenize_one_request( or not self.received_data[obj.bootstrap_room].ready ): await self.handle_embedding() - for mm_item in mm_inputs["mm_items"]: - if mm_item.modality == Modality.IMAGE: - if self.server_args.mm_transfer_backend == "mooncake": - mm_item.precomputed_embeddings = self.embeddings_buffer[ - obj.bootstrap_room - ] - self.embeddings_engine.deregister( - mm_item.precomputed_embeddings.data_ptr() - ) - elif self.server_args.mm_transfer_backend == "zmq": - mm_item.precomputed_embeddings = self.received_data[ - obj.bootstrap_room - ].get() + + recv_embedding_data = self.received_data[obj.bootstrap_room] + if self.server_args.mm_transfer_backend == "mooncake": + recv_embedding = self.embeddings_buffer[obj.bootstrap_room] + self.embeddings_engine.deregister(recv_embedding.data_ptr()) + elif self.server_args.mm_transfer_backend == "zmq": + recv_embedding = recv_embedding_data.get_embedding() + img_grid_thw = recv_embedding_data.get_img_grid() del self.received_data[obj.bootstrap_room] if self.server_args.mm_transfer_backend == "mooncake": del self.embeddings_buffer[obj.bootstrap_room] + prompt = input_text or input_ids + mm_inputs = self.mm_processor.get_mm_data( + prompt, recv_embedding, img_grid_thw + ) + + if mm_inputs and "input_ids" in mm_inputs: + input_ids = mm_inputs["input_ids"] else: mm_inputs = None diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index 370aec2b65ab..522ad9a64092 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -231,6 +231,61 @@ def __init__( MM_ITEM_MEMORY_POOL_RECYCLE_INTERVAL, ) + @property + def spatial_merge_size(self): + return self.hf_config.vision_config.spatial_merge_size + + def get_input_ids(self, 
prompt, img_grid_thw): + """ + Use prompt and img_grid_thw to build input_ids + """ + if not isinstance(prompt, list): + prompt = self._processor.tokenizer.encode(prompt) + + img_start_id = self.IM_START_TOKEN_ID + img_id = self.IM_TOKEN_ID + spatial_merge_size = self.spatial_merge_size + + input_ids = [] + offsets = [] + + cur_idx = 0 + img_start_indices = list( + filter(lambda i: prompt[i] == img_start_id, range(len(prompt))) + ) + for cur_img_idx, img_start_idx in enumerate(img_start_indices): + assert cur_idx <= img_start_idx + # include img_start_id + input_ids.extend(prompt[cur_idx : img_start_idx + 1]) + img_offset_start = len(input_ids) + img_token_num = img_grid_thw[cur_img_idx].prod() // (spatial_merge_size**2) + input_ids.extend([img_id] * img_token_num) + # jump to img_end_id + cur_idx = img_start_idx + 2 + offsets.append((img_offset_start, len(input_ids) - 1)) + else: + input_ids.extend(prompt[cur_idx:]) + + return input_ids, offsets + + def get_mm_data(self, prompt, embeddings, img_grid_thw): + input_ids, offsets = self.get_input_ids(prompt, img_grid_thw) + mm_items = [ + MultimodalDataItem( + modality=Modality.IMAGE, + offsets=offsets, + precomputed_embeddings=embeddings, + ) + ] + + return { + "input_ids": input_ids, + "mm_items": mm_items, + "im_start_id": self.IM_START_TOKEN_ID, + "im_end_id": self.IM_END_TOKEN_ID, + "im_token_id": self.IM_TOKEN_ID, + } + def process_mm_data( self, input_text, images=None, videos=None, audios=None, **kwargs ) -> dict: diff --git a/python/sglang/srt/multimodal/processors/dots_vlm.py b/python/sglang/srt/multimodal/processors/dots_vlm.py index 65752244da39..8d6faf5e8748 100644 --- a/python/sglang/srt/multimodal/processors/dots_vlm.py +++ b/python/sglang/srt/multimodal/processors/dots_vlm.py @@ -24,8 +24,8 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): self.im_end_id = _processor.tokenizer.encode("<|endofimg|>")[0] self.image_token_id = _processor.tokenizer.encode("<|imgpad|>")[0] 
self.IM_TOKEN_ID = self.image_token_id - self.IM_START_ID = self.im_start_id - self.IM_END_ID = self.im_end_id + self.IM_START_TOKEN_ID = self.im_start_id + self.IM_END_TOKEN_ID = self.im_end_id vision_config = hf_config.vision_config patch_size = vision_config.patch_size diff --git a/python/sglang/srt/multimodal/processors/gemma3.py b/python/sglang/srt/multimodal/processors/gemma3.py index cbfb45e8404e..555eba8ec323 100644 --- a/python/sglang/srt/multimodal/processors/gemma3.py +++ b/python/sglang/srt/multimodal/processors/gemma3.py @@ -18,6 +18,7 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): super().__init__(hf_config, server_args, _processor, *args, **kwargs) self.IM_START_TOKEN_ID = hf_config.boi_token_index self.IM_END_TOKEN_ID = hf_config.eoi_token_index + self.IM_TOKEN_ID = hf_config.image_token_index self.mm_tokens = MultimodalSpecialTokens( # The single, pre-expanded image token. image_token="", diff --git a/python/sglang/srt/multimodal/processors/gemma3n.py b/python/sglang/srt/multimodal/processors/gemma3n.py index 9ea8b8be3662..97d1987c8b0c 100644 --- a/python/sglang/srt/multimodal/processors/gemma3n.py +++ b/python/sglang/srt/multimodal/processors/gemma3n.py @@ -31,6 +31,7 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): self.IM_START_TOKEN_ID = hf_config.boi_token_id self.IM_END_TOKEN_ID = hf_config.eoi_token_id + self.IM_TOKEN_ID = hf_config.image_token_id self.AUDIO_START_TOKEN_ID = hf_config.boa_token_id self.AUDIO_END_TOKEN_ID = hf_config.eoa_token_id diff --git a/python/sglang/srt/multimodal/processors/glm4v.py b/python/sglang/srt/multimodal/processors/glm4v.py index 80d717a7ad76..192143c5067f 100644 --- a/python/sglang/srt/multimodal/processors/glm4v.py +++ b/python/sglang/srt/multimodal/processors/glm4v.py @@ -26,8 +26,8 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): # Token IDs self.IM_TOKEN_ID = hf_config.image_token_id self.VIDEO_TOKEN_ID = 
hf_config.video_token_id - self.IMAGE_START_TOKEN_ID = hf_config.image_start_token_id - self.IMAGE_END_TOKEN_ID = hf_config.image_end_token_id + self.IM_START_TOKEN_ID = hf_config.image_start_token_id + self.IM_END_TOKEN_ID = hf_config.image_end_token_id self.VIDEO_START_TOKEN_ID = hf_config.video_start_token_id self.VIDEO_END_TOKEN_ID = hf_config.video_end_token_id diff --git a/python/sglang/srt/multimodal/processors/qwen_vl.py b/python/sglang/srt/multimodal/processors/qwen_vl.py index 80983f2459d7..41eadb66673a 100644 --- a/python/sglang/srt/multimodal/processors/qwen_vl.py +++ b/python/sglang/srt/multimodal/processors/qwen_vl.py @@ -12,6 +12,7 @@ from sglang.srt.environ import envs from sglang.srt.layers.rotary_embedding import MRotaryEmbedding +from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration from sglang.srt.models.qwen3_omni_moe import Qwen3OmniMoeForConditionalGeneration @@ -237,6 +238,7 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): self.IM_START_TOKEN_ID = hf_config.vision_start_token_id self.IM_END_TOKEN_ID = hf_config.vision_end_token_id + self.IM_TOKEN_ID = hf_config.image_token_id self.vision_start_token_id = hf_config.vision_start_token_id self.vision_end_token_id = hf_config.vision_end_token_id @@ -257,6 +259,42 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): audio_token_id=self.audio_token_id, ).build(_processor) + def get_mm_data(self, prompt, embeddings, img_grid_thw): + input_ids, offsets = self.get_input_ids(prompt, img_grid_thw) + mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index( + spatial_merge_size=self.hf_config.vision_config.spatial_merge_size, + image_token_id=self.mm_tokens.image_token_id, + video_token_id=self.mm_tokens.video_token_id, + vision_start_token_id=self.vision_start_token_id, + 
model_type=self.model_type, + input_ids=torch.tensor(input_ids, dtype=torch.long).unsqueeze(0), + image_grid_thw=img_grid_thw, + tokens_per_second=getattr( + self.hf_config.vision_config, "tokens_per_second", None + ), + ) + mrope_positions = mrope_positions.squeeze(1) + + mm_items = [ + MultimodalDataItem( + modality=Modality.IMAGE, + offsets=offsets, + precomputed_embeddings=embeddings, + ) + ] + + return { + "input_ids": input_ids, + "mm_items": mm_items, + "im_start_id": self.IM_START_TOKEN_ID, + "im_end_id": self.IM_END_TOKEN_ID, + "im_token_id": self.mm_tokens.image_token_id, + "video_token_id": self.mm_tokens.video_token_id, + "audio_token_id": self.mm_tokens.audio_token_id, + "mrope_positions": mrope_positions, + "mrope_position_delta": mrope_position_delta, + } + async def process_mm_data_async( self, image_data: List[Union[str, bytes]], From 7c1500e44b1b4c757ab1e7b1dd40531b88327d07 Mon Sep 17 00:00:00 2001 From: liusy58 Date: Fri, 7 Nov 2025 22:36:54 +0800 Subject: [PATCH 31/68] feat: remove image URLs from prefill requests in EPD mode to reduce communication overhead --- .../bindings/python/sglang_router/mini_lb.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py index 4dc245f80efe..fd482961eb2f 100644 --- a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py +++ b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py @@ -48,6 +48,20 @@ def maybe_wrap_ipv6_address(address: str) -> str: return address +def clear_image_urls(data): + if 'messages' not in data: + return data + import copy + data_copy = copy.deepcopy(data) + for message in data_copy['messages']: + if 'content' in message and isinstance(message['content'], list): + for content_item in message['content']: + if isinstance(content_item, dict): + if content_item.get('type') == 'image_url' and 'image_url' in content_item: + if 'url' 
in content_item['image_url']: + content_item['image_url']['url'] = '' + return data_copy + class MiniLoadBalancer: def __init__( self, @@ -341,7 +355,10 @@ async def stream_results(): headers = {"trace_context": trace_context} tasks = [ - session.post(f"{prefill_server}/{endpoint}", json=modified_request) + session.post( + f"{prefill_server}/{endpoint}", + json=clear_image_urls(modified_request) if self.encode_urls else modified_request + ) ] if decode_server is not None: tasks.append( From 82258df4f45f92e954832399a87767d86d1c506c Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Tue, 11 Nov 2025 07:28:40 +0000 Subject: [PATCH 32/68] lint --- .../bindings/python/sglang_router/mini_lb.py | 27 ++++++++++++------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py index fd482961eb2f..2949cf036d59 100644 --- a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py +++ b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py @@ -49,19 +49,24 @@ def maybe_wrap_ipv6_address(address: str) -> str: def clear_image_urls(data): - if 'messages' not in data: + if "messages" not in data: return data import copy + data_copy = copy.deepcopy(data) - for message in data_copy['messages']: - if 'content' in message and isinstance(message['content'], list): - for content_item in message['content']: + for message in data_copy["messages"]: + if "content" in message and isinstance(message["content"], list): + for content_item in message["content"]: if isinstance(content_item, dict): - if content_item.get('type') == 'image_url' and 'image_url' in content_item: - if 'url' in content_item['image_url']: - content_item['image_url']['url'] = '' + if ( + content_item.get("type") == "image_url" + and "image_url" in content_item + ): + if "url" in content_item["image_url"]: + content_item["image_url"]["url"] = "" return data_copy + class MiniLoadBalancer: def 
__init__( self, @@ -356,8 +361,12 @@ async def stream_results(): tasks = [ session.post( - f"{prefill_server}/{endpoint}", - json=clear_image_urls(modified_request) if self.encode_urls else modified_request + f"{prefill_server}/{endpoint}", + json=( + clear_image_urls(modified_request) + if self.encode_urls + else modified_request + ), ) ] if decode_server is not None: From a96594d9e31df2e2094ccdda8f9e31dcb8ba0e06 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Wed, 19 Nov 2025 13:45:46 +0000 Subject: [PATCH 33/68] Rebase and clean up code --- python/sglang/launch_server.py | 2 +- .../srt/disaggregation/encode_receiver.py | 136 ++++++++++++++++++ .../encode_server.py | 33 +---- python/sglang/srt/entrypoints/http_server.py | 15 +- .../sglang/srt/managers/tokenizer_manager.py | 76 ++-------- python/sglang/srt/models/dots_vlm.py | 5 +- python/sglang/srt/models/qwen2_5_vl.py | 23 ++- python/sglang/srt/models/qwen3_vl.py | 19 ++- 8 files changed, 181 insertions(+), 128 deletions(-) create mode 100644 python/sglang/srt/disaggregation/encode_receiver.py rename python/sglang/srt/{entrypoints => disaggregation}/encode_server.py (88%) diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py index 3d3d19b14442..41f36e193853 100644 --- a/python/sglang/launch_server.py +++ b/python/sglang/launch_server.py @@ -15,7 +15,7 @@ def run_server(server_args): asyncio.run(serve_grpc(server_args)) elif server_args.mm_only: - from sglang.srt.entrypoints.encode_server import launch_server + from python.sglang.srt.disaggregation.encode_server import launch_server launch_server(server_args) else: diff --git a/python/sglang/srt/disaggregation/encode_receiver.py b/python/sglang/srt/disaggregation/encode_receiver.py new file mode 100644 index 000000000000..25fb076a3311 --- /dev/null +++ b/python/sglang/srt/disaggregation/encode_receiver.py @@ -0,0 +1,136 @@ +import asyncio +import logging +from typing import Dict + +import torch +import zmq +import zmq.asyncio + +from 
sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine +from sglang.srt.utils import get_free_port, get_local_ip_auto, get_zmq_socket + +logger = logging.getLogger(__name__) + + +class EmbeddingData: + def __init__(self, req_id, num_parts, part_idx, image_grid_dim, embedding=None): + self.req_id = req_id + self.num_parts = num_parts + self.part_idx = part_idx + self.image_grid_dim = image_grid_dim + self.embedding = embedding + + # aggregated data + self.ready_list = [i == self.part_idx for i in range(self.num_parts)] + self.embedding_list = [ + embedding if i == self.part_idx else None for i in range(self.num_parts) + ] + self.image_grid_dim_list = [ + self.image_grid_dim if i == self.part_idx else None + for i in range(self.num_parts) + ] + + def add(self, embedding_data): + assert self.req_id == embedding_data.req_id + assert not self.ready_list[embedding_data.part_idx] + self.ready_list[embedding_data.part_idx] = True + self.image_grid_dim_list[embedding_data.part_idx] = ( + embedding_data.image_grid_dim + ) + self.embedding_list[embedding_data.part_idx] = embedding_data.embedding + + def get_embedding(self): + return torch.concatenate(self.embedding_list) + + def get_img_grid(self): + return torch.concatenate(self.image_grid_dim_list) + + @property + def ready(self): + return sum(self.ready_list) == self.num_parts + + def __repr__(self): + return f"EmbeddingData(req_id={self.req_id}, num_parts={self.num_parts}, part_idx={self.part_idx})" + + +class MMReceiver: + + def __init__(self, mm_transfer_backend, disaggregation_ib_device, dtype): + context = zmq.asyncio.Context(2) + self.embedding_port = get_free_port() + self.recv_from_encoder = get_zmq_socket( + context, zmq.PULL, f"tcp://*:{self.embedding_port}", True + ) + self.received_data: Dict[int, EmbeddingData] = dict() + self.embeddings_lock = asyncio.Lock() + self.mm_transfer_backend = mm_transfer_backend + self.dtype = dtype + if self.mm_transfer_backend == "mooncake": + 
self.embeddings_engine = MooncakeTransferEngine( + hostname=get_local_ip_auto(), + gpu_id=None, + ib_device=disaggregation_ib_device, + ) + self.embeddings_buffer = dict() + + async def handle_embedding(self): + recv_obj = await self.recv_from_encoder.recv_pyobj() + if recv_obj.req_id not in self.received_data: + self.received_data[recv_obj.req_id] = recv_obj + else: + self.received_data[recv_obj.req_id].add(recv_obj) + + async def allocate_embedding_buffer(self, req_id, embedding_length, embedding_dim): + embeddings = torch.zeros( + (embedding_length, embedding_dim), + dtype=self.dtype, + ) + self.embeddings_engine.register( + embeddings.data_ptr(), + embeddings.nbytes, + ) + self.embeddings_buffer[req_id] = embeddings + return embeddings.data_ptr() + + async def recv_mm_data(self, req_id, mm_processor, prompt): + try: + return await asyncio.wait_for( + self._recv_mm_data(req_id, mm_processor, prompt), timeout=10 + ) + except asyncio.TimeoutError: + logger.warning(f"Embedding recv timeout for request {req_id}") + if req_id in self.received_data: + del self.received_data[req_id] + if hasattr(self, "embeddings_buffer") and req_id in self.embeddings_buffer: + del self.embeddings_buffer[req_id] + return None + + async def _recv_mm_data(self, req_id, mm_processor, prompt): + # Bypass MMReceiver + if req_id is None: + return None + + # E Disaggregation + recv_embedding = None + img_grid_thw = None + + # Use async lock to avoid race condition + async with self.embeddings_lock: + while ( + req_id not in self.received_data or not self.received_data[req_id].ready + ): + await self.handle_embedding() + + recv_embedding_data = self.received_data[req_id] + if self.mm_transfer_backend == "mooncake": + recv_embedding = self.embeddings_buffer[req_id] + self.embeddings_engine.deregister(recv_embedding.data_ptr()) + elif self.mm_transfer_backend == "zmq": + recv_embedding = recv_embedding_data.get_embedding() + img_grid_thw = recv_embedding_data.get_img_grid() + del 
self.received_data[req_id] + if self.mm_transfer_backend == "mooncake": + del self.embeddings_buffer[req_id] + + mm_inputs = mm_processor.get_mm_data(prompt, recv_embedding, img_grid_thw) + return mm_inputs diff --git a/python/sglang/srt/entrypoints/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py similarity index 88% rename from python/sglang/srt/entrypoints/encode_server.py rename to python/sglang/srt/disaggregation/encode_server.py index a30f54618229..4f1a1ed38d35 100644 --- a/python/sglang/srt/entrypoints/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -1,4 +1,3 @@ -import asyncio import logging from typing import Optional @@ -12,6 +11,7 @@ from transformers import AutoImageProcessor from transformers.image_utils import load_images +from python.sglang.srt.disaggregation.encode_receiver import EmbeddingData from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import ModelConfig @@ -23,11 +23,6 @@ from sglang.srt.layers.dp_attention import initialize_dp_attention from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem from sglang.srt.model_loader import get_model -from sglang.srt.multimodal.processors.dots_vlm import DotsVLMImageProcessor -from sglang.srt.multimodal.processors.qwen_vl import ( - QwenVLImageProcessor, - resize_image_async, -) from sglang.srt.server_args import ( PortArgs, ServerArgs, @@ -147,35 +142,9 @@ def __init__(self, server_args: ServerArgs): self.embedding_to_send = dict() - # dots-specific: - if self.processor_cls == DotsVLMImageProcessor: - vision_config = self.model_config.hf_config.vision_config - patch_size = vision_config.patch_size - merge_size = vision_config.spatial_merge_size - - self.IMAGE_FACTOR = patch_size * merge_size - self.MIN_PIXELS = self.image_processor.min_pixels - self.MAX_PIXELS = self.image_processor.max_pixels - async def mm_encode(self, mm_items) -> 
torch.Tensor: images = load_images(mm_items) - # resize images - if self.processor_cls and self.processor_cls in [ - QwenVLImageProcessor, - DotsVLMImageProcessor, - ]: - if self.processor_cls == QwenVLImageProcessor: - resize_tasks = [resize_image_async(image) for image in images] - elif self.processor_cls == DotsVLMImageProcessor: - resize_tasks = [ - resize_image_async( - image, self.MIN_PIXELS, self.MAX_PIXELS, self.IMAGE_FACTOR - ) - for image in images - ] - images = await asyncio.gather(*resize_tasks) - images_input = self.image_processor(images=images) mm_item = MultimodalDataItem.from_dict( { diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index d9c8cfcabd3a..2a1eae4fc943 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -1333,20 +1333,19 @@ async def sagemaker_chat_completions( @app.post("/embedding_bootstrap") async def embedding_bootstrap(request_data: dict): + mm_receiver = _global_state.tokenizer_manager.mm_receiver if "embedding_length" in request_data: - buffer_address = ( - await _global_state.tokenizer_manager.allocate_embedding_buffer( - request_data["req_id"], - request_data["embedding_length"], - request_data["embedding_dim"], - ) + buffer_address = await mm_receiver.allocate_embedding_buffer( + request_data["req_id"], + request_data["embedding_length"], + request_data["embedding_dim"], ) - session_id = _global_state.tokenizer_manager.embeddings_engine.session_id + session_id = mm_receiver.embeddings_engine.session_id return ORJSONResponse( content={"session_id": session_id, "buffer_address": buffer_address} ) elif "embedding_port" in request_data: - embedding_port = _global_state.tokenizer_manager.embedding_port + embedding_port = mm_receiver.embedding_port return ORJSONResponse(content={"embedding_port": embedding_port}) diff --git a/python/sglang/srt/managers/tokenizer_manager.py 
b/python/sglang/srt/managers/tokenizer_manager.py index 570d7d9d0c01..22f2b101bac9 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -39,8 +39,8 @@ import zmq.asyncio from fastapi import BackgroundTasks +from python.sglang.srt.disaggregation.encode_receiver import MMReceiver from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.environ import envs from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry @@ -75,7 +75,7 @@ from sglang.srt.managers.mm_utils import TensorTransportMode from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors from sglang.srt.managers.request_metrics_exporter import RequestMetricsExporterManager -from sglang.srt.managers.schedule_batch import Modality, RequestStage +from sglang.srt.managers.schedule_batch import RequestStage from sglang.srt.managers.scheduler import is_health_check_generate_req from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin @@ -318,19 +318,11 @@ def __init__( # E Disaggregation if self.model_config.is_multimodal and self.server_args.language_only: - self.embedding_port = get_free_port() - self.recv_from_encoder = get_zmq_socket( - context, zmq.PULL, f"tcp://*:{self.embedding_port}", True + self.mm_receiver = MMReceiver( + server_args.mm_transfer_backend, + server_args.disaggregation_ib_device, + self.model_config.dtype, ) - self.received_data = dict() - self.embeddings_lock = asyncio.Lock() - if self.server_args.mm_transfer_backend == "mooncake": - self.embeddings_engine = MooncakeTransferEngine( - hostname=get_local_ip_auto(), - gpu_id=None, - ib_device=server_args.disaggregation_ib_device, - ) - self.embeddings_buffer = dict() # 
Request states self._chosen_loop = None @@ -576,18 +568,6 @@ def _detect_input_format( return "batch_strings" - async def allocate_embedding_buffer(self, req_id, embedding_length, embedding_dim): - embeddings = torch.zeros( - (embedding_length, embedding_dim), - dtype=self.model_config.dtype, - ) - self.embeddings_engine.register( - embeddings.data_ptr(), - embeddings.nbytes, - ) - self.embeddings_buffer[req_id] = embeddings - return embeddings.data_ptr() - def _prepare_tokenizer_input( self, texts: Union[str, List[str]], input_format: str ) -> Union[List[str], List[List[str]]]: @@ -744,7 +724,15 @@ async def _tokenize_one_request( if obj.audio_data is not None and not isinstance(obj.audio_data, list): obj.audio_data = [obj.audio_data] - if not self.server_args.language_only: + mm_inputs = None + if self.server_args.language_only: + mm_inputs: Dict = await self.mm_receiver.recv_mm_data( + obj.bootstrap_room, + self.mm_processor, + input_text or input_ids, + ) + + if mm_inputs is None: mm_inputs: Dict = await self.mm_data_processor.process( image_data=obj.image_data, audio_data=obj.audio_data, @@ -752,33 +740,6 @@ async def _tokenize_one_request( request_obj=obj, max_req_input_len=self.max_req_input_len, ) - else: - # E Disaggregation - recv_embedding = None - img_grid_thw = None - # Use async lock to avoid race condition - async with self.embeddings_lock: - while ( - obj.bootstrap_room not in self.received_data - or not self.received_data[obj.bootstrap_room].ready - ): - await self.handle_embedding() - - recv_embedding_data = self.received_data[obj.bootstrap_room] - if self.server_args.mm_transfer_backend == "mooncake": - recv_embedding = self.embeddings_buffer[obj.bootstrap_room] - self.embeddings_engine.deregister(recv_embedding.data_ptr()) - elif self.server_args.mm_transfer_backend == "zmq": - recv_embedding = recv_embedding_data.get_embedding() - img_grid_thw = recv_embedding_data.get_img_grid() - del self.received_data[obj.bootstrap_room] - if 
self.server_args.mm_transfer_backend == "mooncake": - del self.embeddings_buffer[obj.bootstrap_room] - - prompt = input_text or input_ids - mm_inputs = self.mm_processor.get_mm_data( - prompt, recv_embedding, img_grid_thw - ) if mm_inputs and "input_ids" in mm_inputs: input_ids = mm_inputs["input_ids"] @@ -1605,13 +1566,6 @@ async def handle_loop(self): self._result_dispatcher(recv_obj) self.last_receive_tstamp = time.time() - async def handle_embedding(self): - recv_obj = await self.recv_from_encoder.recv_pyobj() - if recv_obj.req_id not in self.received_data: - self.received_data[recv_obj.req_id] = recv_obj - else: - self.received_data[recv_obj.req_id].add(recv_obj) - def _add_metric_if_present( self, recv_obj: Any, diff --git a/python/sglang/srt/models/dots_vlm.py b/python/sglang/srt/models/dots_vlm.py index 61a942c2467f..a1011b6a9ebc 100644 --- a/python/sglang/srt/models/dots_vlm.py +++ b/python/sglang/srt/models/dots_vlm.py @@ -55,9 +55,8 @@ def __init__( config.language_config, quant_config ) - if not config.language_only: - # Initialize vision tower (matching transformers naming for weight compatibility) - self.vision_tower = DotsVisionTransformer(config.vision_config) + # Initialize vision tower (matching transformers naming for weight compatibility) + self.vision_tower = DotsVisionTransformer(config.vision_config) def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor): """pad attn qkv weights for dummy heads""" diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index cb9ea783e498..5f6c12aad955 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -475,7 +475,6 @@ def __init__( self.pp_group = get_pp_group() self.config = config -<<<<<<< HEAD self.use_data_parallel = get_global_server_args().mm_enable_dp_encoder if self.pp_group.is_last_rank: @@ -491,18 +490,15 @@ def __init__( else: # ranks other than the last rank will have a placeholder layer 
self.lm_head = PPMissingLayer() -======= ->>>>>>> e9fbeb706 (Format) - - if not self.config.language_only: - self.visual = Qwen2_5_VisionTransformer( - config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - # NOTE: Qwen2_5-VL vision encoder currently supports BitsAndBytes 4-bit quantization. - # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported. - quant_config=quant_config, - prefix=add_prefix("visual", prefix), - ) + + self.visual = Qwen2_5_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + # NOTE: Qwen2_5-VL vision encoder currently supports BitsAndBytes 4-bit quantization. + # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported. + quant_config=quant_config, + prefix=add_prefix("visual", prefix), + ) if not self.config.mm_only: self.model = Qwen2Model( @@ -666,6 +662,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): layer_id = get_layer_id(name) if ( layer_id is not None + and hasattr(self, "model") and hasattr(self.model, "start_layer") and ( layer_id < self.model.start_layer diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py index e976948fb854..94d7b12782f1 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -602,16 +602,15 @@ def __init__( self.use_data_parallel = get_global_server_args().mm_enable_dp_encoder - if not hasattr(config, "language_only") or not config.language_only: - self.visual = Qwen3VLMoeVisionModel( - config.vision_config, - # NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization. - # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported. 
- quant_config=quant_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - prefix=add_prefix("visual", prefix), - use_data_parallel=self.use_data_parallel, - ) + self.visual = Qwen3VLMoeVisionModel( + config.vision_config, + # NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization. + # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported. + quant_config=quant_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + prefix=add_prefix("visual", prefix), + use_data_parallel=self.use_data_parallel, + ) # TODO: make it more elegant if language_model_cls is Qwen3LLMModel: From 6b21c4f5f92418629c968c795f6716235eae6307 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Wed, 19 Nov 2025 13:52:37 +0000 Subject: [PATCH 34/68] Fix import --- python/sglang/srt/managers/tokenizer_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 22f2b101bac9..2bf1cc3f73f7 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -39,8 +39,8 @@ import zmq.asyncio from fastapi import BackgroundTasks -from python.sglang.srt.disaggregation.encode_receiver import MMReceiver from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.disaggregation.encode_receiver import MMReceiver from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.environ import envs from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry From 3a77446254458d25edc97ca4007935ee755c8ce7 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Wed, 19 Nov 2025 14:14:59 +0000 Subject: [PATCH 35/68] Fix rebase --- .../srt/disaggregation/encode_server.py | 63 +++++++------------ 1 file changed, 21 insertions(+), 42 deletions(-) diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py index 
4f1a1ed38d35..49ccc3f3eb60 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -2,6 +2,7 @@ from typing import Optional import aiohttp +import numpy as np import torch import uvicorn import zmq @@ -33,45 +34,17 @@ logger = logging.getLogger(__name__) -class EmbeddingData: - def __init__(self, req_id, num_parts, part_idx, image_grid_dim, embedding=None): - self.req_id = req_id - self.num_parts = num_parts - self.part_idx = part_idx - self.image_grid_dim = image_grid_dim - self.embedding = embedding - - # aggregated data - self.ready_list = [i == self.part_idx for i in range(self.num_parts)] - self.embedding_list = [ - embedding if i == self.part_idx else None for i in range(self.num_parts) - ] - self.image_grid_dim_list = [ - self.image_grid_dim if i == self.part_idx else None - for i in range(self.num_parts) - ] - - def add(self, embedding_data): - assert self.req_id == embedding_data.req_id - assert not self.ready_list[embedding_data.part_idx] - self.ready_list[embedding_data.part_idx] = True - self.image_grid_dim_list[embedding_data.part_idx] = ( - embedding_data.image_grid_dim - ) - self.embedding_list[embedding_data.part_idx] = embedding_data.embedding - - def get_embedding(self): - return torch.concatenate(self.embedding_list) - - def get_img_grid(self): - return torch.concatenate(self.image_grid_dim_list) - - @property - def ready(self): - return sum(self.ready_list) == self.num_parts - - def __repr__(self): - return f"EmbeddingData(req_id={self.req_id}, num_parts={self.num_parts}, part_idx={self.part_idx})" +def _convert(data): + if type(data) == torch.Tensor: + return data + elif type(data) == np.ndarray: + return torch.tensor(data) + elif type(data) == list and type(data[0]) == np.ndarray: + return torch.tensor(np.array(data)) + elif type(data) == list and type(data[0]) in [int, float]: + return torch.tensor(data) + else: + return data _image_grid_attrs = ["image_grid_thw", 
"image_grid_hws"] @@ -92,7 +65,9 @@ def __init__(self, server_args: ServerArgs): set_global_server_args_for_scheduler(server_args) self.image_processor = AutoImageProcessor.from_pretrained( - server_args.model_path, trust_remote_code=server_args.trust_remote_code + server_args.model_path, + trust_remote_code=server_args.trust_remote_code, + use_fast=True, ) self.model_config = ModelConfig.from_server_args( @@ -146,13 +121,17 @@ async def mm_encode(self, mm_items) -> torch.Tensor: images = load_images(mm_items) images_input = self.image_processor(images=images) + feature = images_input["pixel_values"] mm_item = MultimodalDataItem.from_dict( { "modality": Modality.IMAGE, - "feature": images_input["pixel_values"], + "feature": _convert(feature), } ) - mm_item.set("image_grid_thw", images_input["image_grid_thw"]) + for k, v in images_input.items(): + if k == "pixel_values": + continue + mm_item.set(k, _convert(v)) mm_embedding = self.model.get_image_feature([mm_item]) if len(mm_embedding.shape) != 2: mm_embedding = mm_embedding.reshape(-1, mm_embedding.shape[-1]) From b7bed5393c6f0a2dab027c9cbc4eb2a1112b16ec Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Wed, 19 Nov 2025 14:20:27 +0000 Subject: [PATCH 36/68] Revert changes --- python/sglang/srt/multimodal/processors/gemma3.py | 1 - python/sglang/srt/multimodal/processors/gemma3n.py | 1 - python/sglang/srt/multimodal/processors/glm4v.py | 4 ++-- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/multimodal/processors/gemma3.py b/python/sglang/srt/multimodal/processors/gemma3.py index 555eba8ec323..cbfb45e8404e 100644 --- a/python/sglang/srt/multimodal/processors/gemma3.py +++ b/python/sglang/srt/multimodal/processors/gemma3.py @@ -18,7 +18,6 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): super().__init__(hf_config, server_args, _processor, *args, **kwargs) self.IM_START_TOKEN_ID = hf_config.boi_token_index self.IM_END_TOKEN_ID = hf_config.eoi_token_index - 
self.IM_TOKEN_ID = hf_config.image_token_index self.mm_tokens = MultimodalSpecialTokens( # The single, pre-expanded image token. image_token="", diff --git a/python/sglang/srt/multimodal/processors/gemma3n.py b/python/sglang/srt/multimodal/processors/gemma3n.py index 97d1987c8b0c..9ea8b8be3662 100644 --- a/python/sglang/srt/multimodal/processors/gemma3n.py +++ b/python/sglang/srt/multimodal/processors/gemma3n.py @@ -31,7 +31,6 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): self.IM_START_TOKEN_ID = hf_config.boi_token_id self.IM_END_TOKEN_ID = hf_config.eoi_token_id - self.IM_TOKEN_ID = hf_config.image_token_id self.AUDIO_START_TOKEN_ID = hf_config.boa_token_id self.AUDIO_END_TOKEN_ID = hf_config.eoa_token_id diff --git a/python/sglang/srt/multimodal/processors/glm4v.py b/python/sglang/srt/multimodal/processors/glm4v.py index 192143c5067f..80d717a7ad76 100644 --- a/python/sglang/srt/multimodal/processors/glm4v.py +++ b/python/sglang/srt/multimodal/processors/glm4v.py @@ -26,8 +26,8 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): # Token IDs self.IM_TOKEN_ID = hf_config.image_token_id self.VIDEO_TOKEN_ID = hf_config.video_token_id - self.IM_START_TOKEN_ID = hf_config.image_start_token_id - self.IM_END_TOKEN_ID = hf_config.image_end_token_id + self.IMAGE_START_TOKEN_ID = hf_config.image_start_token_id + self.IMAGE_END_TOKEN_ID = hf_config.image_end_token_id self.VIDEO_START_TOKEN_ID = hf_config.video_start_token_id self.VIDEO_END_TOKEN_ID = hf_config.video_end_token_id From 88c288410359adc8afa19df26694797d26c9514f Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Thu, 20 Nov 2025 01:33:25 +0000 Subject: [PATCH 37/68] Fix qwen2_5_vl --- python/sglang/srt/models/qwen2_5_vl.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index 5f6c12aad955..b899eb977485 100644 --- 
a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -477,6 +477,13 @@ def __init__( self.config = config self.use_data_parallel = get_global_server_args().mm_enable_dp_encoder + if not self.config.mm_only: + self.model = Qwen2Model( + config, + quant_config, + prefix=add_prefix("model", prefix), + ) + if self.pp_group.is_last_rank: if self.pp_group.world_size == 1 and self.config.tie_word_embeddings: self.lm_head = self.model.embed_tokens @@ -500,13 +507,6 @@ def __init__( prefix=add_prefix("visual", prefix), ) - if not self.config.mm_only: - self.model = Qwen2Model( - config, - quant_config, - prefix=add_prefix("model", prefix), - ) - self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling self.logits_processor = LogitsProcessor(config) From b69aaee544c070b0c9bb517be63d83b3a90f07c6 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Thu, 20 Nov 2025 04:11:13 +0000 Subject: [PATCH 38/68] Get rid of the dependency on minlb --- .../srt/disaggregation/encode_receiver.py | 114 +++++++- .../srt/disaggregation/encode_server.py | 23 +- python/sglang/srt/entrypoints/http_server.py | 18 -- .../sglang/srt/managers/tokenizer_manager.py | 4 +- python/sglang/srt/server_args.py | 8 + .../bindings/python/sglang_router/mini_lb.py | 251 +++--------------- .../bindings/python/sglang_router/router.py | 2 - .../python/sglang_router/router_args.py | 32 --- 8 files changed, 169 insertions(+), 283 deletions(-) diff --git a/python/sglang/srt/disaggregation/encode_receiver.py b/python/sglang/srt/disaggregation/encode_receiver.py index 25fb076a3311..7b19bb8e9286 100644 --- a/python/sglang/srt/disaggregation/encode_receiver.py +++ b/python/sglang/srt/disaggregation/encode_receiver.py @@ -1,7 +1,9 @@ import asyncio import logging +import random from typing import Dict +import aiohttp import torch import zmq import zmq.asyncio @@ -53,9 +55,16 @@ def __repr__(self): return f"EmbeddingData(req_id={self.req_id}, num_parts={self.num_parts}, 
part_idx={self.part_idx})" +def _generate_id(): + req_id = random.randint(0, 2**63 - 1) + return req_id + + class MMReceiver: - def __init__(self, mm_transfer_backend, disaggregation_ib_device, dtype): + def __init__( + self, host, encode_urls, mm_transfer_backend, disaggregation_ib_device, dtype + ): context = zmq.asyncio.Context(2) self.embedding_port = get_free_port() self.recv_from_encoder = get_zmq_socket( @@ -65,6 +74,9 @@ def __init__(self, mm_transfer_backend, disaggregation_ib_device, dtype): self.embeddings_lock = asyncio.Lock() self.mm_transfer_backend = mm_transfer_backend self.dtype = dtype + self.encode_urls = encode_urls + self.encode_idx = list(range(len(self.encode_urls))) + self.host = host if self.mm_transfer_backend == "mooncake": self.embeddings_engine = MooncakeTransferEngine( hostname=get_local_ip_auto(), @@ -73,6 +85,96 @@ def __init__(self, mm_transfer_backend, disaggregation_ib_device, dtype): ) self.embeddings_buffer = dict() + async def encode(self, req_id, img_data, endpoint_encode, endpoint_send): + if len(img_data) == 0: + return + + # Split mm_items + encode_requests = [] + random.shuffle(self.encode_idx) + num_items_assigned = [ + (idx + len(img_data)) // len(self.encode_urls) for idx in self.encode_idx + ] + num_parts = sum(1 for x in num_items_assigned if x != 0) + cum_num_items = 0 + cum_idx = 0 + for idx, assigned_num in enumerate(num_items_assigned): + if assigned_num == 0: + continue + encode_requests.append( + { + "encoder_idx": idx, + "mm_items": img_data[cum_num_items : cum_num_items + assigned_num], + "num_parts": num_parts, + "part_idx": cum_idx, + "req_id": req_id, + "prefill_host": self.host, + "embedding_port": self.embedding_port, + } + ) + cum_idx += 1 + cum_num_items += assigned_num + + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout( + total=1800 + ) # Add timeout for request reliability + ) as session: + # Send encode requests + + tasks = [ + session.post( + 
f"{self.encode_urls[encode_request['encoder_idx']]}/{endpoint_encode}", + json=encode_request, + ) + for encode_request in encode_requests + ] + + responses = await asyncio.gather(*tasks) + response_json_list_unsort = [ + await response.json() for response in responses + ] + + # zmq backend: return is None + if None in response_json_list_unsort: + return + + # mooncake backend: send bootstrap info + + embedding_size_list_sort = [None for _ in range(num_parts)] + embedding_length_tot = 0 + response_json_list_sort = [None for _ in range(num_parts)] + for response_json in response_json_list_unsort: + idx = response_json["part_idx"] + embedding_size_list_sort[idx] = response_json["embedding_size"] + embedding_length_tot += response_json["embedding_len"] + response_json_list_sort[idx] = response_json + + offset = 0 + metadata_tasks = [] + buffer_address = await self.allocate_embedding_buffer( + req_id, + embedding_length_tot, + response_json_list_sort[0]["embedding_dim"], + ) + for idx in range(len(tasks)): + response_json = response_json_list_sort[idx] + buffer_address_adjust = offset + buffer_address + response_json.update( + { + "session_id": self.embeddings_engine.session_id, + "buffer_address": buffer_address_adjust, + } + ) + metadata_tasks.append( + session.post( + f"{self.encode_urls[response_json['encoder_idx']]}/{endpoint_send}", + json=response_json, + ) + ) + offset += embedding_size_list_sort[idx] + await asyncio.gather(*metadata_tasks) + async def handle_embedding(self): recv_obj = await self.recv_from_encoder.recv_pyobj() if recv_obj.req_id not in self.received_data: @@ -92,8 +194,16 @@ async def allocate_embedding_buffer(self, req_id, embedding_length, embedding_di self.embeddings_buffer[req_id] = embeddings return embeddings.data_ptr() - async def recv_mm_data(self, req_id, mm_processor, prompt): + async def recv_mm_data(self, img_data, mm_processor, prompt): try: + if len(self.encode_urls) == 0: + return None + req_id = _generate_id() + if 
type(img_data) != list: + img_data = [img_data.url] + else: + img_data = [img.url for img in img_data] + asyncio.create_task(self.encode(req_id, img_data, "encode", "send")) return await asyncio.wait_for( self._recv_mm_data(req_id, mm_processor, prompt), timeout=10 ) diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py index 49ccc3f3eb60..1b95aa516db9 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -139,8 +139,8 @@ async def mm_encode(self, mm_items) -> torch.Tensor: async def mm_send( self, - prefill_host, - prefill_url, + prefill_host: int, + embedding_port: int, embedding: torch.Tensor, mm_data: EmbeddingData, session_id=None, @@ -157,17 +157,16 @@ async def mm_send( mm_data.embedding_list[mm_data.part_idx] = None # Send ack/data - if prefill_url in self.send_to_prefill_sockets: - socket = self.send_to_prefill_sockets[prefill_url] + if embedding_port in self.send_to_prefill_sockets: + socket = self.send_to_prefill_sockets[embedding_port] else: - embedding_port = await self.get_embedding_port(prefill_url) socket = get_zmq_socket( self.context, zmq.PUSH, f"tcp://{prefill_host}:{embedding_port}", False, ) - self.send_to_prefill_sockets[prefill_url] = socket + self.send_to_prefill_sockets[embedding_port] = socket socket.send_pyobj(mm_data) @torch.inference_mode() @@ -184,12 +183,12 @@ async def encode(self, mm_items, req_id, num_parts, part_idx): return mm_embedding.nbytes, mm_embedding.shape[0], mm_embedding.shape[1] async def send( - self, req_id, prefill_host, prefill_url, session_id=None, buffer_address=None + self, req_id, prefill_host, embedding_port, session_id=None, buffer_address=None ): mm_data: EmbeddingData = self.embedding_to_send[req_id] await self.mm_send( prefill_host, - prefill_url, + embedding_port, mm_data.embedding, mm_data, session_id, @@ -240,8 +239,8 @@ async def handle_encode_request(request: dict): elif 
encoder.server_args.mm_transfer_backend == "zmq": await encoder.send( req_id=request["req_id"], - prefill_host=request["bootstrap_host"], - prefill_url=request["prefill_url"], + prefill_host=request["prefill_host"], + embedding_port=request["embedding_port"], ) return ORJSONResponse(content=None) @@ -251,8 +250,8 @@ async def handle_send_request(request: dict): # mooncake backend await encoder.send( req_id=request["req_id"], - prefill_host=request["bootstrap_host"], - prefill_url=request["prefill_url"], + prefill_host=request["prefill_host"], + embedding_port=request["embedding_port"], session_id=request["session_id"], buffer_address=request["buffer_address"], ) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 2a1eae4fc943..cf0a3784fe8c 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -1331,24 +1331,6 @@ async def sagemaker_chat_completions( ) -@app.post("/embedding_bootstrap") -async def embedding_bootstrap(request_data: dict): - mm_receiver = _global_state.tokenizer_manager.mm_receiver - if "embedding_length" in request_data: - buffer_address = await mm_receiver.allocate_embedding_buffer( - request_data["req_id"], - request_data["embedding_length"], - request_data["embedding_dim"], - ) - session_id = mm_receiver.embeddings_engine.session_id - return ORJSONResponse( - content={"session_id": session_id, "buffer_address": buffer_address} - ) - elif "embedding_port" in request_data: - embedding_port = mm_receiver.embedding_port - return ORJSONResponse(content={"embedding_port": embedding_port}) - - ## Vertex AI API @app.post(os.environ.get("AIP_PREDICT_ROUTE", "/vertex_generate")) async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Request): diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 2bf1cc3f73f7..abb20355495f 100644 --- 
a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -319,6 +319,8 @@ def __init__( # E Disaggregation if self.model_config.is_multimodal and self.server_args.language_only: self.mm_receiver = MMReceiver( + server_args.host, + server_args.encode_urls, server_args.mm_transfer_backend, server_args.disaggregation_ib_device, self.model_config.dtype, @@ -727,7 +729,7 @@ async def _tokenize_one_request( mm_inputs = None if self.server_args.language_only: mm_inputs: Dict = await self.mm_receiver.recv_mm_data( - obj.bootstrap_room, + obj.image_data, self.mm_processor, input_text or input_ids, ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 13c79dbeb044..43bb57d99966 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -260,6 +260,7 @@ class ServerArgs: mm_only: bool = False language_only: bool = False mm_transfer_backend: str = "zmq" + encode_urls: List[str] = dataclasses.field(default_factory=list) # Quantization and data type dtype: str = "auto" @@ -2274,6 +2275,13 @@ def add_cli_args(parser: argparse.ArgumentParser): choices=MM_TRANSFER_BACKEND_CHOICES, help="The backend for encoder disaggregation transfer. 
Default is zmq.", ) + parser.add_argument( + "--encode-urls", + nargs="+", + type=str, + default=[], + help="List of encode urls for encoder disaggregation", + ) # Quantization and data type parser.add_argument( diff --git a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py index 2949cf036d59..39e809358253 100644 --- a/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py +++ b/sgl-model-gateway/bindings/python/sglang_router/mini_lb.py @@ -48,25 +48,6 @@ def maybe_wrap_ipv6_address(address: str) -> str: return address -def clear_image_urls(data): - if "messages" not in data: - return data - import copy - - data_copy = copy.deepcopy(data) - for message in data_copy["messages"]: - if "content" in message and isinstance(message["content"], list): - for content_item in message["content"]: - if isinstance(content_item, dict): - if ( - content_item.get("type") == "image_url" - and "image_url" in content_item - ): - if "url" in content_item["image_url"]: - content_item["image_url"]["url"] = "" - return data_copy - - class MiniLoadBalancer: def __init__( self, @@ -87,9 +68,6 @@ def __init__( "Tracing is not supported in this environment. Please install sglang." ) self.enable_trace = False - self.encode_urls = router_args.encode_urls - - self.encode_idx = list(range(len(self.encode_urls))) def _validate_router_args(self, router_args: RouterArgs): logger.warning( @@ -101,25 +79,14 @@ def _validate_router_args(self, router_args: RouterArgs): logger.warning("[MiniLB] Overriding policy to random") router_args.policy = "random" - if not router_args.pd_disaggregation and not router_args.e_disaggregation: - raise ValueError("MiniLB only supports PD/E disaggregation mode") - - if router_args.pd_disaggregation and router_args.e_disaggregation: - raise ValueError( - "MiniLB does not support PD and E disaggregation modes at the same time." 
- ) - - if len(router_args.prefill_urls) == 0: - raise ValueError("MiniLB requires at least one prefill server") + if not router_args.pd_disaggregation: + raise ValueError("MiniLB only supports PD disaggregation mode") - if router_args.pd_disaggregation and len(router_args.decode_urls) == 0: + if len(router_args.prefill_urls) == 0 or len(router_args.decode_urls) == 0: raise ValueError( - "The PD disaggregation mode requires at least one decode server." + "MiniLB requires at least one prefill and one decode server" ) - if router_args.e_disaggregation and len(router_args.decode_urls) != 0: - raise ValueError("The E disaggregation mode doesn't require decode server") - def start(self): global lb lb = self @@ -129,145 +96,16 @@ def start(self): uvicorn.run(app, host=self.host, port=self.port) def select_pair(self): + assert len(self.prefill_urls) > 0, "No prefill servers available" + assert len(self.decode_urls) > 0, "No decode servers available" pidx = random.randint(0, len(self.prefill_urls) - 1) - if len(self.decode_urls) != 0: - didx = random.randint(0, len(self.decode_urls) - 1) - decode_url = self.decode_urls[didx] - else: - decode_url = None + didx = random.randint(0, len(self.decode_urls) - 1) return ( self.prefill_urls[pidx], self.prefill_bootstrap_ports[pidx], - decode_url, + self.decode_urls[didx], ) - async def embedding_bootstrap( - self, session, prefill_url, req_id, embedding_length, embedding_dim - ): - response = await session.post( - f"{prefill_url}/embedding_bootstrap", - json={ - "req_id": req_id, - "embedding_length": embedding_length, - "embedding_dim": embedding_dim, - }, - ) - response_json = await response.json() - session_id = response_json["session_id"] - buffer_address = response_json["buffer_address"] - return session_id, buffer_address - - async def encode( - self, request_data, encode_urls, endpoint_encode, endpoint_send, prefill_url - ): - messages = request_data.get("messages") - if messages is None or len(encode_urls) == 0: - return - 
- # Extract mm_items - img_list = [] - for message in messages: - for item in message.get("content"): - if item.get("type") == "image_url": - img_url = item.get("image_url").get("url") - img_list.append(img_url) - - if len(img_list) == 0: - return - - req_id = request_data.get("bootstrap_room") - prefill_host = request_data["bootstrap_host"] - - # Split mm_items - encode_requests = [] - random.shuffle(self.encode_idx) - num_items_assigned = [ - (idx + len(img_list)) // len(self.encode_urls) for idx in self.encode_idx - ] - num_parts = sum(1 for x in num_items_assigned if x != 0) - cum_num_items = 0 - cum_idx = 0 - for idx, assigned_num in enumerate(num_items_assigned): - if assigned_num == 0: - continue - encode_requests.append( - { - "encoder_idx": idx, - "mm_items": img_list[cum_num_items : cum_num_items + assigned_num], - "num_parts": num_parts, - "part_idx": cum_idx, - "req_id": req_id, - "prefill_url": prefill_url, - "bootstrap_host": prefill_host, - } - ) - cum_idx += 1 - cum_num_items += assigned_num - - async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout( - total=self.timeout - ) # Add timeout for request reliability - ) as session: - # Send encode requests - - tasks = [ - session.post( - f"{encode_urls[encode_request['encoder_idx']]}/{endpoint_encode}", - json=encode_request, - ) - for encode_request in encode_requests - ] - - responses = await asyncio.gather(*tasks) - response_json_list_unsort = [ - await response.json() for response in responses - ] - - # zmq backend: return is None - if None in response_json_list_unsort: - return - - # mooncake backend: send bootstrap info - - embedding_size_list_sort = [None for _ in range(num_parts)] - embedding_length_tot = 0 - response_json_list_sort = [None for _ in range(num_parts)] - for response_json in response_json_list_unsort: - idx = response_json["part_idx"] - embedding_size_list_sort[idx] = response_json["embedding_size"] - embedding_length_tot += response_json["embedding_len"] - 
response_json_list_sort[idx] = response_json - - offset = 0 - metadata_tasks = [] - session_id, buffer_address = await self.embedding_bootstrap( - session, - prefill_url, - req_id, - embedding_length_tot, - response_json_list_sort[0]["embedding_dim"], - ) - for idx in range(len(tasks)): - response_json = response_json_list_sort[idx] - buffer_address_adjust = offset + buffer_address - response_json.update( - { - "session_id": session_id, - "buffer_address": buffer_address_adjust, - "bootstrap_host": prefill_host, - "prefill_url": prefill_url, - } - ) - metadata_tasks.append( - session.post( - f"{encode_urls[response_json['encoder_idx']]}/{endpoint_send}", - json=response_json, - ) - ) - offset += embedding_size_list_sort[idx] - await asyncio.gather(*metadata_tasks) - async def generate( self, modified_request, prefill_server, decode_server, endpoint ) -> ORJSONResponse: @@ -290,23 +128,25 @@ async def generate( headers = {"trace_context": trace_context} tasks = [ - session.post(f"{prefill_server}/{endpoint}", json=modified_request) + session.post( + f"{prefill_server}/{endpoint}", + json=modified_request, + headers=headers, + ), + session.post( + f"{decode_server}/{endpoint}", + json=modified_request, + headers=headers, + ), ] - if decode_server is not None: - tasks.append( - session.post(f"{decode_server}/{endpoint}", json=modified_request) - ) for bootstrap_room in bootstrap_room_list: trace_slice_end("mini_lb_launch", bootstrap_room, auto_next_anon=True) # Wait for both responses to complete. Prefill should end first. 
- responses = await asyncio.gather(*tasks) - prefill_response = responses[0] - decode_response = ( - responses[1] if decode_server is not None else prefill_response - ) - if "return_logprob" in modified_request and decode_server is not None: + prefill_response, decode_response = await asyncio.gather(*tasks) + + if "return_logprob" in modified_request: prefill_json = await prefill_response.json() ret_json = await decode_response.json() @@ -362,35 +202,24 @@ async def stream_results(): tasks = [ session.post( f"{prefill_server}/{endpoint}", - json=( - clear_image_urls(modified_request) - if self.encode_urls - else modified_request - ), - ) + json=modified_request, + headers=headers, + ), + session.post( + f"{decode_server}/{endpoint}", + json=modified_request, + headers=headers, + ), ] - if decode_server is not None: - tasks.append( - session.post( - f"{decode_server}/{endpoint}", json=modified_request - ) - ) for bootstrap_room in bootstrap_room_list: trace_slice_end( "mini_lb_launch", bootstrap_room, auto_next_anon=True ) # Wait for both responses to complete. Since this is streaming, they return immediately. 
- responses = await asyncio.gather(*tasks) - prefill_response = responses[0] - decode_response = ( - responses[1] if decode_server is not None else prefill_response - ) + prefill_response, decode_response = await asyncio.gather(*tasks) - if ( - modified_request.get("return_logprob", False) - and decode_server is not None - ): + if modified_request.get("return_logprob", False): prefill_chunks = [] async for chunk in prefill_response.content: prefill_chunks.append(chunk) @@ -600,24 +429,14 @@ async def _forward_to_backend(request_data: dict, endpoint_name: str): # Parse and transform prefill_server for bootstrap data parsed_url = urllib.parse.urlparse(prefill_server) hostname = maybe_wrap_ipv6_address(parsed_url.hostname) - bootstrap_room = _generate_bootstrap_room() - - # Send requests to encode server - encode_request = request_data.copy() - encode_request.update( + modified_request = request_data.copy() + modified_request.update( { - "bootstrap_room": bootstrap_room, "bootstrap_host": hostname, + "bootstrap_port": bootstrap_port, + "bootstrap_room": _generate_bootstrap_room(), } ) - asyncio.create_task( - lb.encode(encode_request, lb.encode_urls, "encode", "send", prefill_server) - ) - - modified_request = encode_request.copy() - modified_request.update( - {"bootstrap_port": bootstrap_port, "encode_urls": lb.encode_urls} - ) if request_data.get("stream", False): return await lb.generate_stream( diff --git a/sgl-model-gateway/bindings/python/sglang_router/router.py b/sgl-model-gateway/bindings/python/sglang_router/router.py index 06f913818f54..05506e1cd560 100644 --- a/sgl-model-gateway/bindings/python/sglang_router/router.py +++ b/sgl-model-gateway/bindings/python/sglang_router/router.py @@ -202,8 +202,6 @@ def from_args(args: RouterArgs) -> "Router": # Remove fields that shouldn't be passed to Rust Router constructor fields_to_remove = [ "mini_lb", - "e_disaggregation", - "encode_urls", "oracle_wallet_path", "oracle_tns_alias", "oracle_connect_descriptor", diff 
--git a/sgl-model-gateway/bindings/python/sglang_router/router_args.py b/sgl-model-gateway/bindings/python/sglang_router/router_args.py index 1ed269350d12..3c085c5cfd95 100644 --- a/sgl-model-gateway/bindings/python/sglang_router/router_args.py +++ b/sgl-model-gateway/bindings/python/sglang_router/router_args.py @@ -17,12 +17,10 @@ class RouterArgs: # PD-specific configuration mini_lb: bool = False pd_disaggregation: bool = False # Enable PD disaggregated mode - e_disaggregation: bool = False # Enable E disaggregated mode prefill_urls: List[tuple] = dataclasses.field( default_factory=list ) # List of (url, bootstrap_port) decode_urls: List[str] = dataclasses.field(default_factory=list) - encode_urls: List[str] = dataclasses.field(default_factory=list) # Routing policy policy: str = "cache_aware" @@ -194,11 +192,6 @@ def add_cli_args( action="store_true", help="Enable PD (Prefill-Decode) disaggregated mode", ) - parser.add_argument( - f"--{prefix}e-disaggregation", - action="store_true", - help="Enable E (Encode) disaggregated mode", - ) parser.add_argument( f"--{prefix}prefill", nargs="+", @@ -214,13 +207,6 @@ def add_cli_args( metavar=("URL",), help="Decode server URL. Can be specified multiple times.", ) - parser.add_argument( - f"--{prefix}encode", - nargs=1, - action="append", - metavar=("URL",), - help="Encode server URL. Can be specified multiple times.", - ) parser.add_argument( f"--{prefix}worker-startup-timeout-secs", type=int, @@ -688,9 +674,6 @@ def from_cli_args( args_dict["decode_urls"] = cls._parse_decode_urls( cli_args_dict.get(f"{prefix}decode", None) ) - args_dict["encode_urls"] = cls._parse_encode_urls( - cli_args_dict.get(f"{prefix}encode", None) - ) args_dict["selector"] = cls._parse_selector( cli_args_dict.get(f"{prefix}selector", None) ) @@ -729,8 +712,6 @@ def _validate_router_args(self): f"Using --policy '{self.policy}' for prefill nodes " f"and --decode-policy '{self.decode_policy}' for decode nodes." 
) - if self.e_disaggregation or len(self.encode_urls): - raise ValueError("Currently, E disaggregation mode requires --min-lb") @staticmethod def _parse_selector(selector_list): @@ -799,16 +780,3 @@ def _parse_decode_urls(decode_list): # decode_list is a list of single-element lists due to nargs=1 return [url[0] for url in decode_list] - - @staticmethod - def _parse_encode_urls(encode_list): - """Parse encode URLs from --encode arguments. - - Format: --encode URL - Example: --encode http://encode1:8081 --encode http://encode2:8081 - """ - if not encode_list: - return [] - - # encode_list is a list of single-element lists due to nargs=1 - return [url[0] for url in encode_list] From 59d3c02dce526b2113617ca0446254cdb6d3a344 Mon Sep 17 00:00:00 2001 From: liusy58 Date: Thu, 20 Nov 2025 14:57:24 +0800 Subject: [PATCH 39/68] [feat] use `--random-image-count` to generate requests contain images in range [1,image_count] --- python/sglang/bench_serving.py | 38 +++++++++++++++++++++++++++++----- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 002674894f45..88409a4ce304 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -787,6 +787,7 @@ def get_dataset(args, tokenizer, model_id=None): image_format=args.image_format, image_resolution=args.image_resolution, backend=args.backend, + random_image_count=args.random_image_count, ) elif args.dataset_name == "generated-shared-prefix": assert not tokenize_prompt @@ -1432,10 +1433,12 @@ def sample_image_requests( image_format: str, image_resolution: str, backend: str, + random_image_count: bool = False, ) -> List[DatasetRow]: """Generate requests with images. - - Each request includes ``image_count`` images. + - If ``random_image_count`` is True, each request includes a random number of images between 1 and ``image_count``. + - If ``random_image_count`` is False, each request includes exactly ``image_count`` images. 
- Supported resolutions: 4k (3840x2160), 1080p (1920x1080), 720p (1280x720), 360p (640x360), or custom 'heightxwidth' (e.g., 1080x1920). - Text lengths follow the 'random' dataset sampling rule. ``prompt_len`` @@ -1445,10 +1448,20 @@ def sample_image_requests( # Parse resolution (supports presets and 'heightxwidth') width, height = parse_image_resolution(image_resolution) + # Determine image counts for each request + if random_image_count: + # Random number of images per request + image_counts = np.random.randint(1, image_count + 1, size=num_requests) + total_images = np.sum(image_counts) + else: + # Fixed number of images per request + image_counts = np.full(num_requests, image_count) + total_images = image_count * num_requests + # Check for potentially problematic combinations and warn user - if width * height >= 1920 * 1080 and image_count * num_requests >= 100: + if width * height >= 1920 * 1080 and total_images >= 100: warnings.warn( - f"High resolution ({width}x{height}) with {image_count * num_requests} total images " + f"High resolution ({width}x{height}) with {total_images} total images " f"may take a long time. 
Consider reducing resolution or image count.", UserWarning, stacklevel=2, @@ -1482,6 +1495,9 @@ def _gen_random_image_data_uri( dataset: List[DatasetRow] = [] total_image_bytes = 0 for i in range(num_requests): + # Get the number of images for this request + request_image_count = int(image_counts[i]) + # Generate text prompt text_prompt = gen_mm_prompt( processor.tokenizer, @@ -1491,7 +1507,7 @@ def _gen_random_image_data_uri( # Generate image list images, images_base64, images_bytes = zip( - *[_gen_random_image_data_uri() for _ in range(image_count)] + *[_gen_random_image_data_uri() for _ in range(request_image_count)] ) total_image_bytes += sum(list(images_bytes)) @@ -1503,11 +1519,18 @@ def _gen_random_image_data_uri( processor, backend, ) - dataset.append(data_row) + # Print statistics print(f"#Input tokens: {np.sum([x.prompt_len for x in dataset])}") print(f"#Output tokens: {np.sum([x.output_len for x in dataset])}") + print(f"#Total images: {total_images}") + + if random_image_count: + print(f"#Images per request: min={np.min(image_counts)}, max={np.max(image_counts)}, mean={np.mean(image_counts):.2f}") + else: + print(f"#Images per request: {image_count} (fixed)") + print( f"\nCreated {len(dataset)} {image_content} {image_format} images with average {total_image_bytes // num_requests} bytes per request" ) @@ -2624,6 +2647,11 @@ def __call__(self, parser, namespace, values, option_string=None): "Supports presets 4k/1080p/720p/360p or custom 'heightxwidth' (e.g., 1080x1920)." 
), ) + parser.add_argument( + "--random-image-count", + action="store_true", + help="Enable Random Image Count", + ) parser.add_argument( "--image-format", type=str, From c1851b62deb5012fb3989d2bf642f0aed543cde6 Mon Sep 17 00:00:00 2001 From: liusy58 Date: Thu, 20 Nov 2025 15:17:39 +0800 Subject: [PATCH 40/68] fix typo --- python/sglang/srt/disaggregation/encode_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py index 1b95aa516db9..98fcc70d5c44 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -259,5 +259,5 @@ async def handle_send_request(request: dict): @app.get("/health_check") -async def handle_send_request(): +async def handle_health_check_request(): return ORJSONResponse(content={"is_alive": True}) From dabc6b302b0fb6d7ac3d434cfdb2fc52ea932d0c Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Thu, 20 Nov 2025 11:46:49 +0000 Subject: [PATCH 41/68] Fix OOM for qwen3 --- python/sglang/srt/disaggregation/encode_receiver.py | 2 +- python/sglang/srt/disaggregation/encode_server.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/disaggregation/encode_receiver.py b/python/sglang/srt/disaggregation/encode_receiver.py index 7b19bb8e9286..13d36216fded 100644 --- a/python/sglang/srt/disaggregation/encode_receiver.py +++ b/python/sglang/srt/disaggregation/encode_receiver.py @@ -205,7 +205,7 @@ async def recv_mm_data(self, img_data, mm_processor, prompt): img_data = [img.url for img in img_data] asyncio.create_task(self.encode(req_id, img_data, "encode", "send")) return await asyncio.wait_for( - self._recv_mm_data(req_id, mm_processor, prompt), timeout=10 + self._recv_mm_data(req_id, mm_processor, prompt), timeout=20 ) except asyncio.TimeoutError: logger.warning(f"Embedding recv timeout for request {req_id}") diff --git 
a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py index 98fcc70d5c44..735b01ed22f0 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -12,10 +12,10 @@ from transformers import AutoImageProcessor from transformers.image_utils import load_images -from python.sglang.srt.disaggregation.encode_receiver import EmbeddingData from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.disaggregation.encode_receiver import EmbeddingData from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine from sglang.srt.distributed.parallel_state import ( init_distributed_environment, @@ -132,7 +132,8 @@ async def mm_encode(self, mm_items) -> torch.Tensor: if k == "pixel_values": continue mm_item.set(k, _convert(v)) - mm_embedding = self.model.get_image_feature([mm_item]) + with torch.inference_mode(): + mm_embedding = self.model.get_image_feature([mm_item]) if len(mm_embedding.shape) != 2: mm_embedding = mm_embedding.reshape(-1, mm_embedding.shape[-1]) return _get_image_grid_dim(images_input), mm_embedding @@ -169,7 +170,6 @@ async def mm_send( self.send_to_prefill_sockets[embedding_port] = socket socket.send_pyobj(mm_data) - @torch.inference_mode() async def encode(self, mm_items, req_id, num_parts, part_idx): image_grid_dim, mm_embedding = await self.mm_encode(mm_items) mm_data = EmbeddingData( From 43b32b3f36029b9236744264ffc20ea6c60f5353 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Thu, 20 Nov 2025 12:34:49 +0000 Subject: [PATCH 42/68] Fix import --- python/sglang/launch_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py index 41f36e193853..831988eca6f9 100644 --- a/python/sglang/launch_server.py +++ 
b/python/sglang/launch_server.py @@ -15,7 +15,7 @@ def run_server(server_args): asyncio.run(serve_grpc(server_args)) elif server_args.mm_only: - from python.sglang.srt.disaggregation.encode_server import launch_server + from sglang.srt.disaggregation.encode_server import launch_server launch_server(server_args) else: From 837778102844389d861b2072149868e2008b423e Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Fri, 21 Nov 2025 01:55:47 +0000 Subject: [PATCH 43/68] Remove async lock --- .../srt/disaggregation/encode_receiver.py | 76 +++++++++---------- .../srt/disaggregation/encode_server.py | 20 ++--- 2 files changed, 42 insertions(+), 54 deletions(-) diff --git a/python/sglang/srt/disaggregation/encode_receiver.py b/python/sglang/srt/disaggregation/encode_receiver.py index 13d36216fded..586662619158 100644 --- a/python/sglang/srt/disaggregation/encode_receiver.py +++ b/python/sglang/srt/disaggregation/encode_receiver.py @@ -1,7 +1,6 @@ import asyncio import logging import random -from typing import Dict import aiohttp import torch @@ -65,13 +64,7 @@ class MMReceiver: def __init__( self, host, encode_urls, mm_transfer_backend, disaggregation_ib_device, dtype ): - context = zmq.asyncio.Context(2) - self.embedding_port = get_free_port() - self.recv_from_encoder = get_zmq_socket( - context, zmq.PULL, f"tcp://*:{self.embedding_port}", True - ) - self.received_data: Dict[int, EmbeddingData] = dict() - self.embeddings_lock = asyncio.Lock() + self.context = zmq.asyncio.Context(20) self.mm_transfer_backend = mm_transfer_backend self.dtype = dtype self.encode_urls = encode_urls @@ -85,7 +78,9 @@ def __init__( ) self.embeddings_buffer = dict() - async def encode(self, req_id, img_data, endpoint_encode, endpoint_send): + async def encode( + self, req_id, img_data, embedding_port, endpoint_encode, endpoint_send + ): if len(img_data) == 0: return @@ -109,7 +104,7 @@ async def encode(self, req_id, img_data, endpoint_encode, endpoint_send): "part_idx": cum_idx, "req_id": req_id, 
"prefill_host": self.host, - "embedding_port": self.embedding_port, + "embedding_port": embedding_port, } ) cum_idx += 1 @@ -175,13 +170,6 @@ async def encode(self, req_id, img_data, endpoint_encode, endpoint_send): offset += embedding_size_list_sort[idx] await asyncio.gather(*metadata_tasks) - async def handle_embedding(self): - recv_obj = await self.recv_from_encoder.recv_pyobj() - if recv_obj.req_id not in self.received_data: - self.received_data[recv_obj.req_id] = recv_obj - else: - self.received_data[recv_obj.req_id].add(recv_obj) - async def allocate_embedding_buffer(self, req_id, embedding_length, embedding_dim): embeddings = torch.zeros( (embedding_length, embedding_dim), @@ -199,48 +187,52 @@ async def recv_mm_data(self, img_data, mm_processor, prompt): if len(self.encode_urls) == 0: return None req_id = _generate_id() + embedding_port = get_free_port() if type(img_data) != list: img_data = [img_data.url] else: img_data = [img.url for img in img_data] - asyncio.create_task(self.encode(req_id, img_data, "encode", "send")) + asyncio.create_task( + self.encode(req_id, img_data, embedding_port, "encode", "send") + ) return await asyncio.wait_for( - self._recv_mm_data(req_id, mm_processor, prompt), timeout=20 + self._recv_mm_data(req_id, embedding_port, mm_processor, prompt), + timeout=20, ) except asyncio.TimeoutError: logger.warning(f"Embedding recv timeout for request {req_id}") - if req_id in self.received_data: - del self.received_data[req_id] if hasattr(self, "embeddings_buffer") and req_id in self.embeddings_buffer: del self.embeddings_buffer[req_id] return None - async def _recv_mm_data(self, req_id, mm_processor, prompt): + async def _recv_mm_data(self, req_id, embedding_port, mm_processor, prompt): # Bypass MMReceiver if req_id is None: return None - # E Disaggregation recv_embedding = None - img_grid_thw = None - - # Use async lock to avoid race condition - async with self.embeddings_lock: - while ( - req_id not in self.received_data or not 
self.received_data[req_id].ready - ): - await self.handle_embedding() - - recv_embedding_data = self.received_data[req_id] - if self.mm_transfer_backend == "mooncake": - recv_embedding = self.embeddings_buffer[req_id] - self.embeddings_engine.deregister(recv_embedding.data_ptr()) - elif self.mm_transfer_backend == "zmq": - recv_embedding = recv_embedding_data.get_embedding() - img_grid_thw = recv_embedding_data.get_img_grid() - del self.received_data[req_id] - if self.mm_transfer_backend == "mooncake": - del self.embeddings_buffer[req_id] + + recv_socket = get_zmq_socket( + self.context, zmq.PULL, f"tcp://*:{embedding_port}", True + ) + + recv_embedding_data: EmbeddingData = None + + while recv_embedding_data is None or not recv_embedding_data.ready: + recv_obj = await recv_socket.recv_pyobj() + if recv_embedding_data is None: + recv_embedding_data = recv_obj + else: + recv_embedding_data.add(recv_obj) + + if self.mm_transfer_backend == "mooncake": + recv_embedding = self.embeddings_buffer[req_id] + del self.embeddings_buffer[req_id] + self.embeddings_engine.deregister(recv_embedding.data_ptr()) + elif self.mm_transfer_backend == "zmq": + recv_embedding = recv_embedding_data.get_embedding() + + img_grid_thw = recv_embedding_data.get_img_grid() mm_inputs = mm_processor.get_mm_data(prompt, recv_embedding, img_grid_thw) return mm_inputs diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py index 735b01ed22f0..7142cc768275 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -133,10 +133,10 @@ async def mm_encode(self, mm_items) -> torch.Tensor: continue mm_item.set(k, _convert(v)) with torch.inference_mode(): - mm_embedding = self.model.get_image_feature([mm_item]) + mm_embedding: torch.Tensor = self.model.get_image_feature([mm_item]) if len(mm_embedding.shape) != 2: mm_embedding = mm_embedding.reshape(-1, mm_embedding.shape[-1]) - return 
_get_image_grid_dim(images_input), mm_embedding + return _get_image_grid_dim(images_input), mm_embedding.cpu() async def mm_send( self, @@ -158,16 +158,12 @@ async def mm_send( mm_data.embedding_list[mm_data.part_idx] = None # Send ack/data - if embedding_port in self.send_to_prefill_sockets: - socket = self.send_to_prefill_sockets[embedding_port] - else: - socket = get_zmq_socket( - self.context, - zmq.PUSH, - f"tcp://{prefill_host}:{embedding_port}", - False, - ) - self.send_to_prefill_sockets[embedding_port] = socket + socket = get_zmq_socket( + self.context, + zmq.PUSH, + f"tcp://{prefill_host}:{embedding_port}", + False, + ) socket.send_pyobj(mm_data) async def encode(self, mm_items, req_id, num_parts, part_idx): From 88cdbb58e59356cbf1e4003446a0eb75932c7360 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Fri, 21 Nov 2025 09:13:45 +0000 Subject: [PATCH 44/68] Support TP encoder --- .../srt/disaggregation/encode_server.py | 143 +++++++++++++----- 1 file changed, 106 insertions(+), 37 deletions(-) diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py index 7142cc768275..b1649d95a8fe 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -1,5 +1,8 @@ +import asyncio import logging -from typing import Optional +import multiprocessing as mp +import traceback +from typing import List, Optional import aiohttp import numpy as np @@ -29,7 +32,7 @@ ServerArgs, set_global_server_args_for_scheduler, ) -from sglang.srt.utils import get_local_ip_auto, get_zmq_socket +from sglang.srt.utils import get_local_ip_auto, get_zmq_socket, random_uuid logger = logging.getLogger(__name__) @@ -59,10 +62,18 @@ def _get_image_grid_dim(images_input): ) -class ImageEncoder: - def __init__(self, server_args: ServerArgs): +class MMEncoder: + def __init__( + self, + server_args: ServerArgs, + schedule_path=None, + dist_init_method=None, + rank: int = 0, + ): + 
logger.info(f"init MMEncoder {rank}/{server_args.tp_size}") self.server_args = server_args set_global_server_args_for_scheduler(server_args) + self.rank = rank self.image_processor = AutoImageProcessor.from_pretrained( server_args.model_path, @@ -83,39 +94,55 @@ def __init__(self, server_args: ServerArgs): remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports, ) - port_args = PortArgs.init_new(server_args) - if server_args.dist_init_addr: - dist_init_method = f"tcp://{server_args.dist_init_addr}" - else: - dist_init_method = f"tcp://127.0.0.1:{port_args.nccl_port}" + self.device = server_args.device + self.gpu_id = server_args.base_gpu_id + rank + + self.device_config = DeviceConfig( + device=self.device, + gpu_id=self.gpu_id, + ) + + torch.get_device_module(self.device).set_device(self.gpu_id) init_distributed_environment( - world_size=1, rank=0, distributed_init_method=dist_init_method + world_size=server_args.tp_size, + rank=rank, + distributed_init_method=dist_init_method, + local_rank=rank, ) - initialize_model_parallel() + initialize_model_parallel(tensor_model_parallel_size=server_args.tp_size) initialize_dp_attention(server_args, self.model_config) self.model = get_model( model_config=self.model_config, load_config=self.load_config, - device_config=DeviceConfig(), + device_config=self.device_config, ) - logger.info(f"Using transfer backend: {self.server_args.mm_transfer_backend}") + self.context = zmq.asyncio.Context(2) - if self.server_args.mm_transfer_backend == "mooncake": - self.local_ip = get_local_ip_auto() + if schedule_path is not None: + self.schedule_socket = get_zmq_socket( + self.context, zmq.PULL, schedule_path, True + ) - self.engine = MooncakeTransferEngine( - hostname=self.local_ip, - gpu_id=None, - ib_device=server_args.disaggregation_ib_device, + if self.rank == 0: + logger.info( + f"Using transfer backend: {self.server_args.mm_transfer_backend}" ) - self.context = 
zmq.asyncio.Context(2) - self.send_to_prefill_sockets = dict() + if self.server_args.mm_transfer_backend == "mooncake": + self.local_ip = get_local_ip_auto() + + self.engine = MooncakeTransferEngine( + hostname=self.local_ip, + gpu_id=None, + ib_device=server_args.disaggregation_ib_device, + ) + + self.embedding_to_send = dict() - self.embedding_to_send = dict() + logger.info(f"rank {rank} init finish ") async def mm_encode(self, mm_items) -> torch.Tensor: images = load_images(mm_items) @@ -168,14 +195,15 @@ async def mm_send( async def encode(self, mm_items, req_id, num_parts, part_idx): image_grid_dim, mm_embedding = await self.mm_encode(mm_items) - mm_data = EmbeddingData( - req_id, - num_parts, - part_idx, - image_grid_dim, - mm_embedding, - ) - self.embedding_to_send[mm_data.req_id] = mm_data + if self.rank == 0: + mm_data = EmbeddingData( + req_id, + num_parts, + part_idx, + image_grid_dim, + mm_embedding, + ) + self.embedding_to_send[mm_data.req_id] = mm_data return mm_embedding.nbytes, mm_embedding.shape[0], mm_embedding.shape[1] async def send( @@ -205,17 +233,63 @@ async def get_embedding_port(self, prefill_url): app = FastAPI() -encoder: Optional[ImageEncoder] = None +encoder: Optional[MMEncoder] = None +send_sockets: List[zmq.Socket] = [] + + +async def run_encoder( + server_args: ServerArgs, schedule_path, dist_init_method, rank: int +): + encoder = MMEncoder(server_args, schedule_path, dist_init_method, rank) + while True: + request = await encoder.schedule_socket.recv_pyobj() + await encoder.encode( + mm_items=request["mm_items"], + req_id=request["req_id"], + num_parts=request["num_parts"], + part_idx=request["part_idx"], + ) + + +def launch_encoder(server_args, schedule_path, dist_init_method, rank): + try: + asyncio.run(run_encoder(server_args, schedule_path, dist_init_method, rank)) + except KeyboardInterrupt: + logger.info(f"Exit rank {rank}") + except Exception: + traceback.print_exc() def launch_server(server_args: ServerArgs): global encoder 
- encoder = ImageEncoder(server_args) + ctx = mp.get_context("spawn") + zmq_ctx = zmq.Context(10) + ipc_path_prefix = random_uuid() + port_args = PortArgs.init_new(server_args) + if server_args.dist_init_addr: + dist_init_method = f"tcp://{server_args.dist_init_addr}" + else: + dist_init_method = f"tcp://127.0.0.1:{port_args.nccl_port}" + for rank in range(1, server_args.tp_size): + schedule_path = f"ipc:///tmp/{ipc_path_prefix}_schedule_{rank}" + send_sockets.append( + get_zmq_socket(zmq_ctx, zmq.PUSH, schedule_path, bind=False) + ) + ctx.Process( + target=launch_encoder, + args=(server_args, schedule_path, dist_init_method, rank), + daemon=True, + ).start() + encoder = MMEncoder(server_args, dist_init_method=dist_init_method) uvicorn.run(app, host=server_args.host, port=server_args.port) @app.post("/encode") async def handle_encode_request(request: dict): + # broadcast request + for socket in send_sockets: + socket.send_pyobj(request) + nbytes, embedding_len, embedding_dim = await encoder.encode( mm_items=request["mm_items"], req_id=request["req_id"], @@ -252,8 +326,3 @@ async def handle_send_request(request: dict): buffer_address=request["buffer_address"], ) return ORJSONResponse(content=None) - - -@app.get("/health_check") -async def handle_health_check_request(): - return ORJSONResponse(content={"is_alive": True}) From 685cae1a2470d525fa743f7361da91c0eb1c3eb5 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Fri, 21 Nov 2025 09:25:44 +0000 Subject: [PATCH 45/68] lint --- python/sglang/bench_serving.py | 12 +++++++----- python/sglang/srt/disaggregation/encode_server.py | 8 ++++---- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 88409a4ce304..237709401290 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -1433,7 +1433,7 @@ def sample_image_requests( image_format: str, image_resolution: str, backend: str, - random_image_count: bool = False, + 
random_image_count: bool = False, ) -> List[DatasetRow]: """Generate requests with images. @@ -1497,7 +1497,7 @@ def _gen_random_image_data_uri( for i in range(num_requests): # Get the number of images for this request request_image_count = int(image_counts[i]) - + # Generate text prompt text_prompt = gen_mm_prompt( processor.tokenizer, @@ -1525,12 +1525,14 @@ def _gen_random_image_data_uri( print(f"#Input tokens: {np.sum([x.prompt_len for x in dataset])}") print(f"#Output tokens: {np.sum([x.output_len for x in dataset])}") print(f"#Total images: {total_images}") - + if random_image_count: - print(f"#Images per request: min={np.min(image_counts)}, max={np.max(image_counts)}, mean={np.mean(image_counts):.2f}") + print( + f"#Images per request: min={np.min(image_counts)}, max={np.max(image_counts)}, mean={np.mean(image_counts):.2f}" + ) else: print(f"#Images per request: {image_count} (fixed)") - + print( f"\nCreated {len(dataset)} {image_content} {image_format} images with average {total_image_bytes // num_requests} bytes per request" ) diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py index b1649d95a8fe..245e53c5c949 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -144,7 +144,7 @@ def __init__( logger.info(f"rank {rank} init finish ") - async def mm_encode(self, mm_items) -> torch.Tensor: + async def _encode(self, mm_items) -> torch.Tensor: images = load_images(mm_items) images_input = self.image_processor(images=images) @@ -165,7 +165,7 @@ async def mm_encode(self, mm_items) -> torch.Tensor: mm_embedding = mm_embedding.reshape(-1, mm_embedding.shape[-1]) return _get_image_grid_dim(images_input), mm_embedding.cpu() - async def mm_send( + async def _send( self, prefill_host: int, embedding_port: int, @@ -194,7 +194,7 @@ async def mm_send( socket.send_pyobj(mm_data) async def encode(self, mm_items, req_id, num_parts, part_idx): - 
image_grid_dim, mm_embedding = await self.mm_encode(mm_items) + image_grid_dim, mm_embedding = await self._encode(mm_items) if self.rank == 0: mm_data = EmbeddingData( req_id, @@ -210,7 +210,7 @@ async def send( self, req_id, prefill_host, embedding_port, session_id=None, buffer_address=None ): mm_data: EmbeddingData = self.embedding_to_send[req_id] - await self.mm_send( + await self._send( prefill_host, embedding_port, mm_data.embedding, From cbe0124d163e247dcf176f3b358ca7029c68ae0a Mon Sep 17 00:00:00 2001 From: liusy58 Date: Sun, 23 Nov 2025 19:10:54 +0800 Subject: [PATCH 46/68] support encoder send mmdata to scheduler directly. --- .../srt/disaggregation/encode_receiver.py | 16 +- .../srt/disaggregation/encode_server.py | 123 ++++++++- python/sglang/srt/managers/io_struct.py | 6 + python/sglang/srt/managers/scheduler.py | 259 ++++++++++++++++++ .../sglang/srt/managers/tokenizer_manager.py | 138 +++++++++- python/sglang/srt/utils/common.py | 33 +++ 6 files changed, 546 insertions(+), 29 deletions(-) diff --git a/python/sglang/srt/disaggregation/encode_receiver.py b/python/sglang/srt/disaggregation/encode_receiver.py index 586662619158..144d5aca2e90 100644 --- a/python/sglang/srt/disaggregation/encode_receiver.py +++ b/python/sglang/srt/disaggregation/encode_receiver.py @@ -20,7 +20,9 @@ def __init__(self, req_id, num_parts, part_idx, image_grid_dim, embedding=None): self.part_idx = part_idx self.image_grid_dim = image_grid_dim self.embedding = embedding - + self.send_time = None + self.dtype = embedding.dtype if embedding is not None else None + self.shape = list(embedding.shape) if embedding is not None else None # aggregated data self.ready_list = [i == self.part_idx for i in range(self.num_parts)] self.embedding_list = [ @@ -53,6 +55,18 @@ def ready(self): def __repr__(self): return f"EmbeddingData(req_id={self.req_id}, num_parts={self.num_parts}, part_idx={self.part_idx})" + def copy_without_embedding(self): + new_data = EmbeddingData( + req_id=self.req_id, 
+ num_parts=self.num_parts, + part_idx=self.part_idx, + image_grid_dim=self.image_grid_dim, + ) + new_data.send_time = self.send_time + new_data.dtype = self.dtype + new_data.shape = self.shape + return new_data + def _generate_id(): req_id = random.randint(0, 2**63 - 1) diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py index 245e53c5c949..ff3ac1d3607d 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -1,6 +1,8 @@ import asyncio import logging import multiprocessing as mp +import pickle +import time import traceback from typing import List, Optional @@ -35,6 +37,35 @@ from sglang.srt.utils import get_local_ip_auto, get_zmq_socket, random_uuid logger = logging.getLogger(__name__) +import ctypes +import sys + + +class TensorWrapper: + """Wrapper to keep tensor alive while exposing buffer for zero-copy.""" + + def __init__(self, tensor): + # Ensure tensor is on CPU and contiguous + if tensor.is_cuda: + tensor = tensor.cpu() + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + + # Keep tensor reference + self.tensor = tensor + self.shape = list(tensor.shape) + self.dtype = tensor.dtype + + # Create buffer view based on Python version + if sys.version_info >= (3, 12): + data_ptr = tensor.data_ptr() + total_bytes = tensor.numel() * tensor.element_size() + self._buffer = memoryview( + (ctypes.c_char * total_bytes).from_address(data_ptr) + ) + else: + # For Python 3.10, just use numpy - it already supports buffer protocol + self._buffer = np.asarray(tensor) def _convert(data): @@ -53,6 +84,14 @@ def _convert(data): _image_grid_attrs = ["image_grid_thw", "image_grid_hws"] +class ReceivePortsManager: + def __init__( + self, + ): + self.rid_2_ports = dict() + self.rid_2_tp_size = dict() + + def _get_image_grid_dim(images_input): for attr in _image_grid_attrs: if attr in images_input: @@ -62,6 +101,16 @@ def 
_get_image_grid_dim(images_input): ) +def get_ports_for_rank(embedding_ports, rank, tp_size): + total_ports = len(embedding_ports) + ports_per_rank = (total_ports + tp_size - 1) // tp_size + + start_idx = rank * ports_per_rank + end_idx = min(start_idx + ports_per_rank, total_ports) + + return embedding_ports[start_idx:end_idx] + + class MMEncoder: def __init__( self, @@ -160,7 +209,16 @@ async def _encode(self, mm_items) -> torch.Tensor: continue mm_item.set(k, _convert(v)) with torch.inference_mode(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + start_time1 = time.perf_counter() mm_embedding: torch.Tensor = self.model.get_image_feature([mm_item]) + if torch.cuda.is_available(): + torch.cuda.synchronize() + end_time = time.perf_counter() + logger.info( + f"Vit time : {(end_time - start_time1)*1000:.2f} ms {mm_embedding.shape = }" + ) if len(mm_embedding.shape) != 2: mm_embedding = mm_embedding.reshape(-1, mm_embedding.shape[-1]) return _get_image_grid_dim(images_input), mm_embedding.cpu() @@ -183,7 +241,7 @@ async def _send( mm_data.embedding = None mm_data.embedding_list[mm_data.part_idx] = None - + logger.info(f" [{self.rank}] Sending to {prefill_host}:{embedding_port}") # Send ack/data socket = get_zmq_socket( self.context, @@ -191,10 +249,17 @@ async def _send( f"tcp://{prefill_host}:{embedding_port}", False, ) - socket.send_pyobj(mm_data) + + new_mm_data = mm_data.copy_without_embedding() + embedding_tensor = TensorWrapper(mm_data.embedding) + new_mm_data.send_time = time.time() + socket.send_multipart([pickle.dumps(new_mm_data), embedding_tensor._buffer]) async def encode(self, mm_items, req_id, num_parts, part_idx): + start_time = time.time() image_grid_dim, mm_embedding = await self._encode(mm_items) + end_time = time.time() + print(f"🕛 encode cost = {(end_time - start_time) * 1000:.2f}ms") if self.rank == 0: mm_data = EmbeddingData( req_id, @@ -207,18 +272,47 @@ async def encode(self, mm_items, req_id, num_parts, part_idx): return 
mm_embedding.nbytes, mm_embedding.shape[0], mm_embedding.shape[1] async def send( - self, req_id, prefill_host, embedding_port, session_id=None, buffer_address=None + self, + req_id, + prefill_host, + embedding_ports, + session_id=None, + buffer_address=None, ): mm_data: EmbeddingData = self.embedding_to_send[req_id] - await self._send( - prefill_host, - embedding_port, - mm_data.embedding, - mm_data, - session_id, - buffer_address, - ) - del self.embedding_to_send[req_id] + + if not mm_data: + logger.error(f"No embedding data found for req_id: {req_id}") + return + logger.info(f"{self.rank=} {embedding_ports = }") + # send_tasks = [] + # ports_to_send = get_ports_for_rank(embedding_ports, self.rank, self.server_args.tp_size) + try: + send_tasks = [ + self._send( + prefill_host, + port, + mm_data.embedding, + mm_data, + session_id, + buffer_address, + ) + for port in embedding_ports + ] + results = await asyncio.gather(*send_tasks, return_exceptions=True) + + for i, result in enumerate(results): + if isinstance(result, Exception): + logger.error( + f"Failed to send to port {embedding_ports[i]}: {result}" + ) + else: + logger.debug(f"Successfully sent to port {embedding_ports[i]}") + + del self.embedding_to_send[req_id] + except Exception as e: + logger.error(f"Error sending embeddings for req_id {req_id}: {e}") + raise async def get_embedding_port(self, prefill_url): async with aiohttp.ClientSession( @@ -287,6 +381,7 @@ def launch_server(server_args: ServerArgs): @app.post("/encode") async def handle_encode_request(request: dict): # broadcast request + request.update({"enter_time": time.time()}) for socket in send_sockets: socket.send_pyobj(request) @@ -296,6 +391,8 @@ async def handle_encode_request(request: dict): num_parts=request["num_parts"], part_idx=request["part_idx"], ) + time3 = time.time() + # print(f"🕛 send_time = {(time2 - time1) * 1000:.2f}ms, encode_time = {(time3 - time2) * 1000:.2f}ms") if encoder.server_args.mm_transfer_backend == "mooncake": del 
request["mm_items"] request.update( @@ -310,7 +407,7 @@ async def handle_encode_request(request: dict): await encoder.send( req_id=request["req_id"], prefill_host=request["prefill_host"], - embedding_port=request["embedding_port"], + embedding_ports=request["embedding_ports"], ) return ORJSONResponse(content=None) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index e34736cc409c..c265932d6e06 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -227,6 +227,9 @@ class GenerateReqInput(BaseReq): # Whether to return entropy return_entropy: bool = False + embedding_ports: Optional[List] = None + need_wait_for_image: Optional[bool] = None + def contains_mm_input(self) -> bool: return ( has_valid_data(self.image_data) @@ -696,6 +699,9 @@ class TokenizedGenerateReqInput(BaseReq): # Whether to return entropy return_entropy: bool = False + need_wait_for_image: bool = False + embedding_ports: list[str] = None + @dataclass class BatchTokenizedGenerateReqInput(BaseBatchReq): diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index b801fd8f8e63..e41c22e5165d 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -16,6 +16,7 @@ import faulthandler import logging import os +import pickle import signal import sys import threading @@ -48,6 +49,7 @@ from sglang.srt.disaggregation.decode_kvcache_offload_manager import ( DecodeKVCacheOffloadManager, ) +from sglang.srt.disaggregation.encode_receiver import EmbeddingData from sglang.srt.disaggregation.prefill import ( PrefillBootstrapQueue, SchedulerDisaggregationPrefillMixin, @@ -117,6 +119,7 @@ UpdateWeightsFromTensorReqInput, ) from sglang.srt.managers.mm_utils import init_mm_embedding_cache +from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors from sglang.srt.managers.overlap_utils import FutureMap from 
sglang.srt.managers.schedule_batch import ( FINISH_ABORT, @@ -203,6 +206,91 @@ GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300)) +def _determine_tensor_transport_mode(server_args: ServerArgs): + is_cross_node = server_args.dist_init_addr + + if is_cross_node: + # Fallback to default CPU transport for multi-node + return "default" + else: + return "cuda_ipc" + + +class WaitingImageRequest: + def __init__(self, rid: str, recv_req, req, embedding_port, mm_processor): + self.rid = rid + self.recv_req = recv_req + self.req = req + self.embedding_port = embedding_port + self.embedding_ready = False + self.mm_inputs = None + self.error = None + self.thread = None + self.ready = False + self.mm_processor = mm_processor + + def start_waiting(self): + self.thread = threading.Thread(target=self._recv_mm_data_thread, daemon=True) + self.thread.start() + + def _recv_mm_data_thread(self): + try: + mm_processor = self.mm_processor + prompt = self.recv_req.input_text + + self.recv_req.mm_inputs = self._recv_mm_data_sync( + self.rid, self.embedding_port, mm_processor, prompt + ) + + self.embedding_ready = True + + except Exception as e: + logger.error(f"Error receiving embedding for {self.rid}: {e}") + self.error = str(e) + + def _recv_mm_data_sync(self, req_id, embedding_port, mm_processor, prompt): + if req_id is None: + return None + context = zmq.Context() + recv_socket = context.socket(zmq.PULL) + recv_socket.bind(f"tcp://*:{embedding_port}") + logger.info(f"Waiting for input {embedding_port = }") + try: + recv_embedding_data = None + while recv_embedding_data is None or not recv_embedding_data.ready: + try: + parts = recv_socket.recv_multipart(flags=zmq.NOBLOCK, copy=False) + except zmq.Again: + # No data available yet, wait a bit and retry + continue + + recv_obj: EmbeddingData = pickle.loads(parts[0]) + buffer = parts[1].buffer if hasattr(parts[1], "buffer") else parts[1] + recv_obj.embedding = torch.frombuffer( + buffer, dtype=recv_obj.dtype + 
).reshape(recv_obj.shape) + recv_obj.embedding_list[recv_obj.part_idx] = recv_obj.embedding + print( + f"transport cost {(time.time() - recv_obj.send_time) * 1000:.2f}" + ) + if recv_embedding_data is None: + recv_embedding_data = recv_obj + else: + recv_embedding_data.add(recv_obj) + recv_embedding = recv_embedding_data.get_embedding() + img_grid_thw = recv_embedding_data.get_img_grid() + + mm_inputs = mm_processor.get_mm_data(prompt, recv_embedding, img_grid_thw) + if mm_inputs and "input_ids" in mm_inputs: + self.req.origin_input_ids = mm_inputs["input_ids"] + self.ready = True + return mm_inputs + + finally: + recv_socket.close() + context.term() + + @dataclass class EmbeddingBatchResult: embeddings: torch.Tensor @@ -541,6 +629,38 @@ def __init__( # Init mlp sync flag self.require_mlp_sync = require_mlp_sync(server_args) + self.waiting_for_image: List[WaitingImageRequest] = [] + transport_mode = _determine_tensor_transport_mode(self.server_args) + if self.model_config.is_multimodal: + import_processors("sglang.srt.multimodal.processors") + try: + _processor = get_processor( + server_args.tokenizer_path, + tokenizer_mode=server_args.tokenizer_mode, + trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, + use_fast=not server_args.disable_fast_image_processor, + ) + except ValueError as e: + error_message = str(e) + if "does not have a slow version" in error_message: + logger.info( + f"Processor {server_args.tokenizer_path} does not have a slow version. 
Automatically use fast version" + ) + _processor = get_processor( + server_args.tokenizer_path, + tokenizer_mode=server_args.tokenizer_mode, + trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, + use_fast=True, + ) + else: + raise e + self.mm_processor = get_mm_processor( + self.model_config.hf_config, server_args, _processor, transport_mode + ) + print(f"{self.mm_processor = }") + # Init request dispatcher self._request_dispatcher = TypeBasedDispatcher( [ @@ -970,6 +1090,7 @@ def init_moe_config(self): def event_loop_normal(self): """A normal scheduler loop.""" while True: + self.process_waiting_requests() recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) @@ -1005,6 +1126,7 @@ def pop_and_process(): self.process_batch_result(tmp_batch, tmp_result) while True: + self.process_waiting_requests() recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) @@ -1272,6 +1394,128 @@ def _get_multimodal_inputs(self, mm_inputs_dict: dict): else: return MultimodalInputs.from_dict(mm_inputs_dict) + def _complete_multimodal_request(self, waiting_req: WaitingImageRequest): + recv_req = waiting_req.recv_req + req: Req = waiting_req.req + image_inputs = self._get_multimodal_inputs(recv_req.mm_inputs) + + # The following steps are already fast, execute locally on each rank. + # Expand a single image token into multiple dummy tokens for receiving image embeddings + req.origin_input_ids = self.pad_input_ids_func( + req.origin_input_ids, image_inputs + ) + req.extend_image_inputs(image_inputs) + + if len(req.origin_input_ids) >= self.max_req_input_len: + req.set_finish_with_abort( + error_msg=( + "Multimodal prompt is too long after expanding multimodal tokens. " + f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}." 
+ ) + ) + self.init_req_max_new_tokens(req) + self._add_request_to_queue(req) + return + + # initialize before returning + self.init_req_max_new_tokens(req) + + # Validate prompt length + error_msg = validate_input_length( + req, + self.max_req_input_len, + self.server_args.allow_auto_truncate, + ) + if error_msg: + req.set_finish_with_abort(error_msg) + self._add_request_to_queue(req) + return + + # Copy more attributes + if recv_req.logprob_start_len == -1 or not recv_req.return_logprob: + # By default, only return the logprobs for output tokens + # For prefill-only requests with logprob_start_len == -1, set logprob_start_len beyond input sequence + # to skip input logprob computation entirely + if req.is_prefill_only: + req.logprob_start_len = len(req.origin_input_ids) + else: + # TODO: For text generation, evaluate setting logprob_start_len to len(req.origin_input_ids) as well + req.logprob_start_len = len(req.origin_input_ids) - 1 + else: + req.logprob_start_len = recv_req.logprob_start_len + + if not req.is_prefill_only and req.logprob_start_len >= len( + req.origin_input_ids + ): + error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len." 
+ req.logprob_start_len = len(req.origin_input_ids) - 1 + req.set_finish_with_abort(error_msg) + self._add_request_to_queue(req) + return + + # Init grammar cache for this request + add_to_grammar_queue = False + if ( + req.sampling_params.json_schema is not None + or req.sampling_params.regex is not None + or req.sampling_params.ebnf is not None + or req.sampling_params.structural_tag is not None + ): + if self.grammar_backend is None: + error_msg = "Grammar-based generation (json_schema, regex, ebnf, structural_tag) is not supported when the server is launched with --grammar-backend none" + req.set_finish_with_abort(error_msg) + else: + if req.sampling_params.json_schema is not None: + key = ("json", req.sampling_params.json_schema) + elif req.sampling_params.regex is not None: + key = ("regex", req.sampling_params.regex) + elif req.sampling_params.ebnf is not None: + key = ("ebnf", req.sampling_params.ebnf) + elif req.sampling_params.structural_tag: + key = ("structural_tag", req.sampling_params.structural_tag) + + value, cache_hit = self.grammar_backend.get_cached_or_future_value(key) + req.grammar = value + + if not cache_hit: + req.grammar_key = key + add_to_grammar_queue = True + else: + if value is INVALID_GRAMMAR_OBJ: # We hit a cached invalid grammar. 
+ error_msg = f"Invalid grammar request with cache hit: {key=}" + req.set_finish_with_abort(error_msg) + + if add_to_grammar_queue: + self.grammar_queue.append(req) + else: + self._add_request_to_queue(req) + + def process_waiting_requests( + self, + ) -> None: + if not self.waiting_for_image or len(self.waiting_for_image) == 0: + return + local_statuses = [] + request_rids = [] + + for waiting_req in self.waiting_for_image: + local_statuses.append(1 if waiting_req.embedding_ready else 0) + request_rids.append(waiting_req.rid) + + local_tensor = torch.tensor(local_statuses, dtype=torch.int32) + + if torch.cuda.is_available(): + local_tensor = local_tensor.cuda() + + torch.distributed.all_reduce(local_tensor, op=torch.distributed.ReduceOp.MIN) + new_waiting = [] + for i, waiting_req in enumerate(self.waiting_for_image): + if local_tensor[i].item() == 1: + self._complete_multimodal_request(waiting_req) + else: + new_waiting.append(waiting_req) + self.waiting_for_image = new_waiting + def handle_generate_request( self, recv_req: TokenizedGenerateReqInput, @@ -1352,6 +1596,21 @@ def handle_generate_request( self._add_request_to_queue(req) return + # TODO: add this + if recv_req.need_wait_for_image is True: + assert recv_req.embedding_ports is not None + waiting_req = WaitingImageRequest( + rid=recv_req.rid, + recv_req=recv_req, + req=req, + embedding_port=recv_req.embedding_ports[self.tp_rank], + mm_processor=self.mm_processor, + ) + self.waiting_for_image.append(waiting_req) + waiting_req.start_waiting() + + return + # Handle multimodal inputs if recv_req.mm_inputs is not None: image_inputs = self._get_multimodal_inputs(recv_req.mm_inputs) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index abb20355495f..f365aeb7b810 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -20,10 +20,12 @@ import math import os import pickle +import random 
import signal import sys import threading import time +import uuid from collections import deque from contextlib import nullcontext from datetime import datetime @@ -31,6 +33,7 @@ from http import HTTPStatus from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union +import aiohttp import fastapi import orjson import torch @@ -97,13 +100,14 @@ trace_slice_start, ) from sglang.srt.utils import ( + ImageData, configure_gc_warning, dataclass_to_string_truncated, freeze_gc, get_bool_env_var, get_or_create_event_loop, - get_free_port, get_local_ip_auto, + get_multi_free_port, get_zmq_socket, kill_process_tree, ) @@ -441,6 +445,95 @@ def __init__( ] ) self.init_communicators(server_args) + # self.host = self.server_args.host + # self.ports_pool = get_multi_free_port(self.server_args.tp_size * 20) + self.encode_idx = list(range(len(self.server_args.encode_urls))) + self.encode_urls = self.server_args.encode_urls + self.riq_2_images = {} + + @staticmethod + def extrac_and_clean_image_url(obj: GenerateReqInput): + image_urls = [] + for image in obj.image_data: + if isinstance(image, ImageData): + image_urls.append(image.url) + image.url = "" + return image_urls + + def _run_encode_in_thread(self, req_id, img_data, embedding_ports, endpoint_encode): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + loop.run_until_complete( + self._encode_in_background( + req_id, img_data, embedding_ports, endpoint_encode + ) + ) + # logger.info(f"Encode completed for request {req_id}") + except Exception as e: + logger.error(f"Encode failed for request {req_id}: {e}", exc_info=True) + finally: + # self.ports_pool.extend(embedding_ports) + del self.riq_2_images[req_id] + loop.close() + + async def _encode_in_background( + self, req_id, img_data, embedding_ports, endpoint_encode + ): + if len(img_data) == 0: + return + + # Split mm_items + encode_requests = [] + random.shuffle(self.encode_idx) + num_items_assigned = [ + (idx + len(img_data)) // 
len(self.server_args.encode_urls) + for idx in self.encode_idx + ] + num_parts = sum(1 for x in num_items_assigned if x != 0) + cum_num_items = 0 + cum_idx = 0 + for idx, assigned_num in enumerate(num_items_assigned): + if assigned_num == 0: + continue + encode_requests.append( + { + "encoder_idx": idx, + "mm_items": img_data[cum_num_items : cum_num_items + assigned_num], + "num_parts": num_parts, + "part_idx": cum_idx, + "req_id": req_id, + "prefill_host": get_local_ip_auto(), + "embedding_ports": embedding_ports, + } + ) + cum_idx += 1 + cum_num_items += assigned_num + + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout( + total=1800 + ) # Add timeout for request reliability + ) as session: + # Send encode requests + + tasks = [ + session.post( + f"{self.encode_urls[encode_request['encoder_idx']]}/{endpoint_encode}", + json=encode_request, + ) + for encode_request in encode_requests + ] + + responses = await asyncio.gather(*tasks) + response_json_list_unsort = [ + await response.json() for response in responses + ] + + # zmq backend: return is None + if None in response_json_list_unsort: + return async def generate_request( self, @@ -450,6 +543,23 @@ async def generate_request( created_time = time.time() self.auto_create_handle_loop() obj.normalize_batch_and_arguments() + if isinstance(obj, GenerateReqInput): + image_urls = TokenizerManager.extrac_and_clean_image_url(obj) + if obj.rid is None: + obj.rid = uuid.uuid4().hex + if image_urls and len(image_urls) > 0: + logger.info( + f"Processing {len(image_urls)} images for request {obj.rid}" + ) + self.riq_2_images[obj.rid] = image_urls + obj.embedding_ports = get_multi_free_port(self.server_args.tp_size) + obj.need_wait_for_image = True + encode_thread = threading.Thread( + target=self._run_encode_in_thread, + args=(obj.rid, image_urls, obj.embedding_ports, "encode"), + daemon=True, + ) + encode_thread.start() if self.enable_trace: external_trace_header = None @@ -727,21 +837,17 @@ async def 
_tokenize_one_request( obj.audio_data = [obj.audio_data] mm_inputs = None - if self.server_args.language_only: - mm_inputs: Dict = await self.mm_receiver.recv_mm_data( - obj.image_data, - self.mm_processor, - input_text or input_ids, - ) - if mm_inputs is None: - mm_inputs: Dict = await self.mm_data_processor.process( - image_data=obj.image_data, - audio_data=obj.audio_data, - input_text_or_ids=(input_text or input_ids), - request_obj=obj, - max_req_input_len=self.max_req_input_len, - ) + if self.server_args.language_only is False: + + if mm_inputs is None: + mm_inputs: Dict = await self.mm_data_processor.process( + image_data=obj.image_data, + audio_data=obj.audio_data, + input_text_or_ids=(input_text or input_ids), + request_obj=obj, + max_req_input_len=self.max_req_input_len, + ) if mm_inputs and "input_ids" in mm_inputs: input_ids = mm_inputs["input_ids"] @@ -918,6 +1024,8 @@ def _create_tokenized_object( data_parallel_rank=obj.data_parallel_rank, priority=obj.priority, extra_key=obj.extra_key, + need_wait_for_image=obj.need_wait_for_image, + embedding_ports=obj.embedding_ports, ) elif isinstance(obj, EmbeddingReqInput): tokenized_obj = TokenizedEmbeddingReqInput( diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 6fa0b2404ba0..320d35824af7 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -713,6 +713,39 @@ def is_port_available(port): return False +def get_multi_free_port(count, min_port=10000, max_port=65535, max_attempts=1000): + ports = set() + attempts = 0 + + while len(ports) < count and attempts < max_attempts: + attempts += 1 + + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + port = s.getsockname()[1] + + if min_port <= port <= max_port: + ports.add(port) + except OSError: + try: + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + port = s.getsockname()[1] + + if min_port <= port <= max_port: + 
ports.add(port) + except OSError: + continue + + if len(ports) < count: + raise RuntimeError( + f"Could not find {count} free ports after {max_attempts} attempts" + ) + + return list(ports) + + def get_free_port(): # try ipv4 try: From 7c90f78fd85aec0e3b8ab21a46179f9ed483f9a1 Mon Sep 17 00:00:00 2001 From: liusy58 Date: Mon, 24 Nov 2025 17:42:45 +0800 Subject: [PATCH 47/68] [fix] support --dist-init-addr --- .../srt/disaggregation/encode_server.py | 114 ++++++++++++------ python/sglang/srt/managers/io_struct.py | 7 +- python/sglang/srt/managers/scheduler.py | 86 +++++++++++-- .../sglang/srt/managers/tokenizer_manager.py | 37 +++--- 4 files changed, 186 insertions(+), 58 deletions(-) diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py index ff3ac1d3607d..ff954334bed1 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -39,6 +39,11 @@ logger = logging.getLogger(__name__) import ctypes import sys +from typing import Dict, Set, Tuple + +rid_2_receive_endpoint: Dict[str, List[str]] = dict() +rid_2_receive_count: Dict[str, int] = dict() +rid_2_ready_event: Dict[str, asyncio.Event] = dict() class TensorWrapper: @@ -225,8 +230,7 @@ async def _encode(self, mm_items) -> torch.Tensor: async def _send( self, - prefill_host: int, - embedding_port: int, + url, embedding: torch.Tensor, mm_data: EmbeddingData, session_id=None, @@ -241,12 +245,12 @@ async def _send( mm_data.embedding = None mm_data.embedding_list[mm_data.part_idx] = None - logger.info(f" [{self.rank}] Sending to {prefill_host}:{embedding_port}") + logger.info(f" [{self.rank}] Sending to {url}") # Send ack/data socket = get_zmq_socket( self.context, zmq.PUSH, - f"tcp://{prefill_host}:{embedding_port}", + f"tcp://{url}", False, ) @@ -274,45 +278,76 @@ async def encode(self, mm_items, req_id, num_parts, part_idx): async def send( self, req_id, - prefill_host, - embedding_ports, 
session_id=None, buffer_address=None, ): - mm_data: EmbeddingData = self.embedding_to_send[req_id] - + mm_data = self.embedding_to_send.get(req_id) if not mm_data: - logger.error(f"No embedding data found for req_id: {req_id}") return - logger.info(f"{self.rank=} {embedding_ports = }") - # send_tasks = [] - # ports_to_send = get_ports_for_rank(embedding_ports, self.rank, self.server_args.tp_size) + sent_urls: Set[str] = set() + all_tasks: List[Tuple[asyncio.Task, str]] = [] + start_time = asyncio.get_running_loop().time() + timeout = 60.0 + try: - send_tasks = [ - self._send( - prefill_host, - port, - mm_data.embedding, - mm_data, - session_id, - buffer_address, - ) - for port in embedding_ports - ] - results = await asyncio.gather(*send_tasks, return_exceptions=True) + while True: + current_targets = rid_2_receive_endpoint.get(req_id, set()).copy() + expected_count = rid_2_receive_count.get(req_id) + + new_targets = current_targets - sent_urls + + if new_targets: + logger.info( + f"Found {len(new_targets)} new endpoints for {req_id}. Starting tasks..." + ) + for url in new_targets: + task = asyncio.create_task( + self._send( + url, + mm_data.embedding, + mm_data, + session_id, + buffer_address, + ) + ) + all_tasks.append((task, url)) + sent_urls.add(url) # Mark as handled immediately + if expected_count is not None and len(sent_urls) >= expected_count: + logger.info( + f"All {expected_count} endpoints initiated for {req_id}. Breaking loop." + ) + break - for i, result in enumerate(results): - if isinstance(result, Exception): + if asyncio.get_running_loop().time() - start_time > timeout: logger.error( - f"Failed to send to port {embedding_ports[i]}: {result}" + f"Timeout waiting for all endpoints for {req_id}. 
Initiated {len(sent_urls)}/{expected_count}" ) - else: - logger.debug(f"Successfully sent to port {embedding_ports[i]}") + break - del self.embedding_to_send[req_id] - except Exception as e: - logger.error(f"Error sending embeddings for req_id {req_id}: {e}") - raise + await asyncio.sleep(0.001) + + if all_tasks: + logger.info( + f"Loop finished. Awaiting completion of {len(all_tasks)} sending tasks..." + ) + tasks_only = [t[0] for t in all_tasks] + results = await asyncio.gather(*tasks_only, return_exceptions=True) + + # Process results and log errors + for i, result in enumerate(results): + url = all_tasks[i][1] # Retrieve URL associated with the task + if isinstance(result, Exception): + logger.error(f"Failed to send to {url}: {result}") + else: + logger.debug(f"Successfully sent to {url}") + + logger.info(f"All tasks completed for req_id: {req_id}") + + finally: + logger.info(f"Cleaning up resources for req_id {req_id}") + rid_2_receive_endpoint.pop(req_id, None) + rid_2_receive_count.pop(req_id, None) + self.embedding_to_send.pop(req_id, None) async def get_embedding_port(self, prefill_url): async with aiohttp.ClientSession( @@ -406,8 +441,6 @@ async def handle_encode_request(request: dict): elif encoder.server_args.mm_transfer_backend == "zmq": await encoder.send( req_id=request["req_id"], - prefill_host=request["prefill_host"], - embedding_ports=request["embedding_ports"], ) return ORJSONResponse(content=None) @@ -423,3 +456,14 @@ async def handle_send_request(request: dict): buffer_address=request["buffer_address"], ) return ORJSONResponse(content=None) + + +@app.post("/scheduler_receive_url") +async def handle_scheduler_receive_url_request(request: dict): + rid = request["req_id"] + global rid_2_receive_endpoint + if rid not in rid_2_receive_endpoint: + rid_2_receive_endpoint[rid] = set() + rid_2_receive_count[rid] = request["receive_count"] + assert rid_2_receive_count[rid] == request["receive_count"] + 
rid_2_receive_endpoint[rid].add(request["receive_url"]) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index c265932d6e06..4633a9872393 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -227,9 +227,11 @@ class GenerateReqInput(BaseReq): # Whether to return entropy return_entropy: bool = False - embedding_ports: Optional[List] = None need_wait_for_image: Optional[bool] = None + num_items_assigned: Optional[List] = None + encode_idx: Optional[List] = None + def contains_mm_input(self) -> bool: return ( has_valid_data(self.image_data) @@ -700,7 +702,8 @@ class TokenizedGenerateReqInput(BaseReq): return_entropy: bool = False need_wait_for_image: bool = False - embedding_ports: list[str] = None + num_items_assigned: Optional[List] = None + encode_idx: Optional[List] = None @dataclass diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e41c22e5165d..98dfcd2e4b11 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -13,6 +13,7 @@ # ============================================================================== """A scheduler that manages a tensor parallel GPU worker.""" +import asyncio import faulthandler import logging import os @@ -27,6 +28,7 @@ from http import HTTPStatus from typing import Any, Deque, Dict, List, Optional, Tuple, Union +import aiohttp import psutil import setproctitle import torch @@ -179,7 +181,9 @@ freeze_gc, get_available_gpu_memory, get_bool_env_var, + get_free_port, get_int_env_var, + get_local_ip_by_remote, get_zmq_socket, kill_itself_when_parent_died, numa_bind_to_node, @@ -217,17 +221,30 @@ def _determine_tensor_transport_mode(server_args: ServerArgs): class WaitingImageRequest: - def __init__(self, rid: str, recv_req, req, embedding_port, mm_processor): + def __init__( + self, + rid: str, + recv_req: TokenizedGenerateReqInput, + req, + mm_processor, 
+ image_urls, + host_name, + receive_count, + ): self.rid = rid self.recv_req = recv_req self.req = req - self.embedding_port = embedding_port self.embedding_ready = False self.mm_inputs = None self.error = None self.thread = None self.ready = False self.mm_processor = mm_processor + self.image_urls = image_urls + self.host_name = host_name + self.receive_count = receive_count + self.num_items_assigned = recv_req.num_items_assigned + self.encode_idx = recv_req.encode_idx def start_waiting(self): self.thread = threading.Thread(target=self._recv_mm_data_thread, daemon=True) @@ -237,9 +254,8 @@ def _recv_mm_data_thread(self): try: mm_processor = self.mm_processor prompt = self.recv_req.input_text - self.recv_req.mm_inputs = self._recv_mm_data_sync( - self.rid, self.embedding_port, mm_processor, prompt + self.rid, mm_processor, prompt ) self.embedding_ready = True @@ -248,14 +264,67 @@ def _recv_mm_data_thread(self): logger.error(f"Error receiving embedding for {self.rid}: {e}") self.error = str(e) - def _recv_mm_data_sync(self, req_id, embedding_port, mm_processor, prompt): + def _recv_mm_data_sync(self, req_id, mm_processor, prompt): if req_id is None: return None + embedding_port = get_free_port() context = zmq.Context() recv_socket = context.socket(zmq.PULL) recv_socket.bind(f"tcp://*:{embedding_port}") logger.info(f"Waiting for input {embedding_port = }") + + async def _send_single_request(session, url, payload): + try: + async with session.post(url, json=payload) as response: + response.raise_for_status() + return await response.text() + except Exception as e: + logger.error(f"Failed to send request to {url}: {e}") + raise + + async def send_embedding_port(req_id, receive_count, host_name, embedding_port): + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=1800) + ) as session: + tasks = [] + logger.info(f"{self.num_items_assigned = } ") + for idx, assigned_num in enumerate(self.num_items_assigned): + if assigned_num == 0: + continue + 
image_url = self.image_urls[self.encode_idx[idx]] + target_url = f"{image_url}/scheduler_receive_url" + payload = { + "req_id": req_id, + "receive_count": receive_count, + "receive_url": f"{host_name}:{embedding_port}", + } + + logger.info(f"Preparing to send to {target_url}") + + task = _send_single_request(session, target_url, payload) + tasks.append(task) + + if not tasks: + logger.info("No tasks to send.") + return + logger.info(f"Concurrently sending {len(tasks)} requests...") + results = await asyncio.gather(*tasks, return_exceptions=True) + + for i, result in enumerate(results): + if isinstance(result, Exception): + logger.error(f"Request {i} failed: {result}") + else: + logger.debug(f"Request {i} succeeded.") + try: + asyncio.run( + send_embedding_port( + self.recv_req.rid, + self.receive_count, + self.host_name, + embedding_port, + ) + ) recv_embedding_data = None while recv_embedding_data is None or not recv_embedding_data.ready: try: @@ -322,6 +391,7 @@ def __init__( ): # Parse args self.server_args = server_args + self.host_name = get_local_ip_by_remote() self.tp_rank = tp_rank self.moe_ep_rank = moe_ep_rank self.pp_rank = pp_rank @@ -1598,13 +1668,15 @@ def handle_generate_request( # TODO: add this if recv_req.need_wait_for_image is True: - assert recv_req.embedding_ports is not None waiting_req = WaitingImageRequest( rid=recv_req.rid, recv_req=recv_req, req=req, - embedding_port=recv_req.embedding_ports[self.tp_rank], mm_processor=self.mm_processor, + image_urls=self.server_args.encode_urls, + host_name=self.host_name, + ##TODO fixme: + receive_count=self.server_args.tp_size, ) self.waiting_for_image.append(waiting_req) waiting_req.start_waiting() diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index f365aeb7b810..e8d56a255968 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -106,7 +106,6 @@ freeze_gc, get_bool_env_var, 
get_or_create_event_loop, - get_local_ip_auto, get_multi_free_port, get_zmq_socket, kill_process_tree, @@ -460,14 +459,16 @@ def extrac_and_clean_image_url(obj: GenerateReqInput): image.url = "" return image_urls - def _run_encode_in_thread(self, req_id, img_data, embedding_ports, endpoint_encode): + def _run_encode_in_thread( + self, req_id, img_data, endpoint_encode, num_items_assigned, encode_idx + ): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: loop.run_until_complete( self._encode_in_background( - req_id, img_data, embedding_ports, endpoint_encode + req_id, img_data, endpoint_encode, num_items_assigned, encode_idx ) ) # logger.info(f"Encode completed for request {req_id}") @@ -479,18 +480,14 @@ def _run_encode_in_thread(self, req_id, img_data, embedding_ports, endpoint_enco loop.close() async def _encode_in_background( - self, req_id, img_data, embedding_ports, endpoint_encode + self, req_id, img_data, endpoint_encode, num_items_assigned, encode_idx ): if len(img_data) == 0: return # Split mm_items encode_requests = [] - random.shuffle(self.encode_idx) - num_items_assigned = [ - (idx + len(img_data)) // len(self.server_args.encode_urls) - for idx in self.encode_idx - ] + num_parts = sum(1 for x in num_items_assigned if x != 0) cum_num_items = 0 cum_idx = 0 @@ -499,13 +496,11 @@ async def _encode_in_background( continue encode_requests.append( { - "encoder_idx": idx, + "encoder_idx": encode_idx[idx], "mm_items": img_data[cum_num_items : cum_num_items + assigned_num], "num_parts": num_parts, "part_idx": cum_idx, "req_id": req_id, - "prefill_host": get_local_ip_auto(), - "embedding_ports": embedding_ports, } ) cum_idx += 1 @@ -554,9 +549,22 @@ async def generate_request( self.riq_2_images[obj.rid] = image_urls obj.embedding_ports = get_multi_free_port(self.server_args.tp_size) obj.need_wait_for_image = True + + random.shuffle(self.encode_idx) + obj.encode_idx = self.encode_idx + obj.num_items_assigned = [ + (idx + len(image_urls)) // 
len(self.server_args.encode_urls) + for idx in self.encode_idx + ] encode_thread = threading.Thread( target=self._run_encode_in_thread, - args=(obj.rid, image_urls, obj.embedding_ports, "encode"), + args=( + obj.rid, + image_urls, + "encode", + obj.num_items_assigned, + obj.encode_idx, + ), daemon=True, ) encode_thread.start() @@ -1025,7 +1033,8 @@ def _create_tokenized_object( priority=obj.priority, extra_key=obj.extra_key, need_wait_for_image=obj.need_wait_for_image, - embedding_ports=obj.embedding_ports, + num_items_assigned=obj.num_items_assigned, + encode_idx=obj.encode_idx, ) elif isinstance(obj, EmbeddingReqInput): tokenized_obj = TokenizedEmbeddingReqInput( From 484c6475b1460bf8d622f97768bc38d942a871ea Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Tue, 25 Nov 2025 12:05:03 +0000 Subject: [PATCH 48/68] Refactor and fix bugs (Add zmq_s zmq_t and mooncake backend) --- .../srt/disaggregation/encode_receiver.py | 320 +++++++++++++++- .../srt/disaggregation/encode_server.py | 130 ++++--- python/sglang/srt/managers/io_struct.py | 5 +- python/sglang/srt/managers/scheduler.py | 348 +----------------- .../sglang/srt/managers/tokenizer_manager.py | 149 +------- .../multimodal/processors/base_processor.py | 15 +- python/sglang/srt/server_args.py | 4 +- 7 files changed, 435 insertions(+), 536 deletions(-) diff --git a/python/sglang/srt/disaggregation/encode_receiver.py b/python/sglang/srt/disaggregation/encode_receiver.py index 144d5aca2e90..da415fc7de72 100644 --- a/python/sglang/srt/disaggregation/encode_receiver.py +++ b/python/sglang/srt/disaggregation/encode_receiver.py @@ -1,13 +1,25 @@ import asyncio import logging +import pickle import random +import threading +import uuid +from typing import List import aiohttp import torch import zmq import zmq.asyncio +from python.sglang.srt.managers.multimodal_processor import ( + get_mm_processor, + import_processors, +) +from python.sglang.srt.utils.common import get_multi_free_port +from 
python.sglang.srt.utils.hf_transformers_utils import get_processor from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine +from sglang.srt.managers.io_struct import TokenizedGenerateReqInput +from sglang.srt.server_args import ServerArgs from sglang.srt.utils import get_free_port, get_local_ip_auto, get_zmq_socket logger = logging.getLogger(__name__) @@ -68,42 +80,269 @@ def copy_without_embedding(self): return new_data -def _generate_id(): - req_id = random.randint(0, 2**63 - 1) - return req_id +# For zmq_s +class WaitingImageRequest: + def __init__( + self, + rid: str, + recv_req: TokenizedGenerateReqInput, + mm_processor, + image_urls, + host_name, + receive_count, + embedding_port=None, + ): + self.rid = rid + self.recv_req = recv_req + self.mm_inputs = None + self.error = None + self.thread = None + self.mm_processor = mm_processor + self.image_urls = image_urls + self.host_name = host_name + self.receive_count = receive_count + self.num_items_assigned = recv_req.num_items_assigned + self.embedding_port = ( + get_free_port() if embedding_port is None else embedding_port + ) + self.context = zmq.Context() + self.recv_socket = self.context.socket(zmq.PULL) + self.recv_socket.bind(f"tcp://*:{self.embedding_port}") + logger.info(f"Waiting for input {self.embedding_port = }") + self.recv_embedding_data = None + self.ready = False + + def send_encode_request(self): + async def _send_single_request(session, url, payload): + try: + async with session.post(url, json=payload) as response: + response.raise_for_status() + return await response.text() + except Exception as e: + logger.error(f"Failed to send request to {url}: {e}") + raise + + async def send_embedding_port(req_id, receive_count, host_name, embedding_port): + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=1800) + ) as session: + tasks = [] + logger.info(f"{self.num_items_assigned = } ") + for idx, assigned_num in enumerate(self.num_items_assigned): + if 
assigned_num == 0: + continue + image_url = self.image_urls[idx] + target_url = f"{image_url}/scheduler_receive_url" + payload = { + "req_id": req_id, + "receive_count": receive_count, + "receive_url": f"{host_name}:{embedding_port}", + } + + logger.info(f"Preparing to send to {target_url}") + + task = _send_single_request(session, target_url, payload) + tasks.append(task) + + if not tasks: + logger.info("No tasks to send.") + return + logger.info(f"Concurrently sending {len(tasks)} requests...") + results = await asyncio.gather(*tasks, return_exceptions=True) + + for i, result in enumerate(results): + if isinstance(result, Exception): + logger.error(f"Request {i} failed: {result}") + else: + logger.debug(f"Request {i} succeeded.") + + asyncio.run( + send_embedding_port( + self.recv_req.rid, + self.receive_count, + self.host_name, + self.embedding_port, + ) + ) + + def _try_recv_mm_data(self): + if self.ready: + return True + while self.recv_embedding_data is None or not self.recv_embedding_data.ready: + try: + parts = self.recv_socket.recv_multipart(flags=zmq.NOBLOCK, copy=False) + except zmq.Again: + # No data available yet, wait a bit and retry + return False + + recv_obj: EmbeddingData = pickle.loads(parts[0]) + buffer = parts[1].buffer if hasattr(parts[1], "buffer") else parts[1] + recv_obj.embedding = torch.frombuffer(buffer, dtype=recv_obj.dtype).reshape( + recv_obj.shape + ) + recv_obj.embedding_list[recv_obj.part_idx] = recv_obj.embedding + if self.recv_embedding_data is None: + self.recv_embedding_data = recv_obj + else: + self.recv_embedding_data.add(recv_obj) + + recv_embedding = self.recv_embedding_data.get_embedding() + img_grid_thw = self.recv_embedding_data.get_img_grid() + + mm_inputs = self.mm_processor.get_mm_data( + self.recv_req.input_text, recv_embedding, img_grid_thw + ) + self.recv_req.mm_inputs = mm_inputs + self.recv_req.input_ids = mm_inputs["input_ids"] + self.ready = True + return True + + +def 
_determine_tensor_transport_mode(server_args): + is_cross_node = server_args.dist_init_addr + + if is_cross_node: + # Fallback to default CPU transport for multi-node + return "default" + else: + return "cuda_ipc" class MMReceiver: def __init__( - self, host, encode_urls, mm_transfer_backend, disaggregation_ib_device, dtype + self, + server_args: ServerArgs, + dtype=None, + hostname=None, + hf_config=None, + pp_rank=None, + tp_rank=None, ): self.context = zmq.asyncio.Context(20) - self.mm_transfer_backend = mm_transfer_backend - self.dtype = dtype - self.encode_urls = encode_urls + self.mm_transfer_backend = server_args.mm_transfer_backend + self.encode_urls = server_args.encode_urls self.encode_idx = list(range(len(self.encode_urls))) - self.host = host + self.host = server_args.host if self.mm_transfer_backend == "mooncake": + self.dtype = dtype self.embeddings_engine = MooncakeTransferEngine( hostname=get_local_ip_auto(), gpu_id=None, - ib_device=disaggregation_ib_device, + ib_device=server_args.disaggregation_ib_device, ) self.embeddings_buffer = dict() + elif self.mm_transfer_backend == "zmq_s": + self.pp_rank = pp_rank + self.tp_rank = tp_rank + self.tp_size = server_args.tp_size + self.nnodes = server_args.nnodes + self.hostname = hostname + self.world_size = server_args.pp_size * server_args.tp_size + if hf_config is not None: + transport_mode = _determine_tensor_transport_mode(server_args) + import_processors("sglang.srt.multimodal.processors") + _processor = None + try: + _processor = get_processor( + server_args.tokenizer_path, + tokenizer_mode=server_args.tokenizer_mode, + trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, + use_fast=not server_args.disable_fast_image_processor, + ) + except ValueError as e: + error_message = str(e) + if "does not have a slow version" in error_message: + logger.info( + f"Processor {server_args.tokenizer_path} does not have a slow version. 
Automatically use fast version" + ) + _processor = get_processor( + server_args.tokenizer_path, + tokenizer_mode=server_args.tokenizer_mode, + trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, + use_fast=True, + ) + else: + raise e + self.mm_processor = get_mm_processor( + hf_config, server_args, _processor, transport_mode + ) + + # For zmq_s + def process_waiting_requests(self, recv_reqs): + waiting_list: List[WaitingImageRequest] = [] + for recv_req in recv_reqs: + # E Disaggregation + if ( + isinstance(recv_req, TokenizedGenerateReqInput) + and recv_req.need_wait_for_image is True + ): + embedding_port = None + if recv_req.embedding_ports is not None: + embedding_port = recv_req.embedding_ports[ + self.tp_size * self.pp_rank + self.tp_rank + ] + waiting_req = WaitingImageRequest( + rid=recv_req.rid, + recv_req=recv_req, + mm_processor=self.mm_processor, + image_urls=self.encode_urls, + host_name=self.hostname, + receive_count=self.world_size, + embedding_port=embedding_port, + ) + if recv_req.embedding_ports is None: + waiting_req.send_encode_request() + waiting_list.append(waiting_req) + + # waiting for recv embedding result + ready = False + while not ready: + ready = True + for waiting_req in waiting_list: + if not waiting_req._try_recv_mm_data(): + ready = False + + # For zmq_s + def _run_encode_in_thread( + self, req_id, img_data, endpoint_encode, num_items_assigned, embedding_port + ): + try: + asyncio.run( + self.encode( + req_id=req_id, + img_data=img_data, + embedding_port=embedding_port, + endpoint_encode=endpoint_encode, + endpoint_send=None, + num_items_assigned=num_items_assigned, + ) + ) + except Exception as e: + logger.error(f"Encode failed for request {req_id}: {e}", exc_info=True) async def encode( - self, req_id, img_data, embedding_port, endpoint_encode, endpoint_send + self, + req_id, + img_data, + embedding_port, + endpoint_encode, + endpoint_send, + num_items_assigned=None, ): if len(img_data) == 0: return # 
Split mm_items encode_requests = [] - random.shuffle(self.encode_idx) - num_items_assigned = [ - (idx + len(img_data)) // len(self.encode_urls) for idx in self.encode_idx - ] + if num_items_assigned is None: + random.shuffle(self.encode_idx) + num_items_assigned = [ + (idx + len(img_data)) // len(self.encode_urls) + for idx in self.encode_idx + ] num_parts = sum(1 for x in num_items_assigned if x != 0) cum_num_items = 0 cum_idx = 0 @@ -184,6 +423,7 @@ async def encode( offset += embedding_size_list_sort[idx] await asyncio.gather(*metadata_tasks) + # For mooncake async def allocate_embedding_buffer(self, req_id, embedding_length, embedding_dim): embeddings = torch.zeros( (embedding_length, embedding_dim), @@ -196,11 +436,45 @@ async def allocate_embedding_buffer(self, req_id, embedding_length, embedding_di self.embeddings_buffer[req_id] = embeddings return embeddings.data_ptr() + # For zmq_s + def send_encode_requset(self, obj): + if type(obj.image_data) != list: + image_urls = [obj.image_data.url] + else: + image_urls = [img.url for img in obj.image_data] + if obj.rid is None: + obj.rid = uuid.uuid4().hex + if image_urls and len(image_urls) > 0: + logger.info(f"Processing {len(image_urls)} images for request {obj.rid}") + obj.need_wait_for_image = True + + encode_idx = list(range(len(self.encode_urls))) + random.shuffle(encode_idx) + obj.num_items_assigned = [ + (idx + len(image_urls)) // len(self.encode_urls) for idx in encode_idx + ] + obj.embedding_ports = ( + get_multi_free_port(self.world_size) if self.nnodes == 1 else None + ) + encode_thread = threading.Thread( + target=self._run_encode_in_thread, + args=( + obj.rid, + image_urls, + "encode", + obj.num_items_assigned, + obj.embedding_ports, + ), + daemon=True, + ) + encode_thread.start() + + # For zmq_t and mooncake async def recv_mm_data(self, img_data, mm_processor, prompt): try: if len(self.encode_urls) == 0: return None - req_id = _generate_id() + req_id = uuid.uuid4().hex embedding_port = 
get_free_port() if type(img_data) != list: img_data = [img_data.url] @@ -219,6 +493,7 @@ async def recv_mm_data(self, img_data, mm_processor, prompt): del self.embeddings_buffer[req_id] return None + # For zmq_t and mooncake async def _recv_mm_data(self, req_id, embedding_port, mm_processor, prompt): # Bypass MMReceiver if req_id is None: @@ -230,11 +505,22 @@ async def _recv_mm_data(self, req_id, embedding_port, mm_processor, prompt): self.context, zmq.PULL, f"tcp://*:{embedding_port}", True ) + logger.info(f"{embedding_port = }") + recv_embedding_data: EmbeddingData = None while recv_embedding_data is None or not recv_embedding_data.ready: - recv_obj = await recv_socket.recv_pyobj() + parts = await recv_socket.recv_multipart(copy=False) + + recv_obj: EmbeddingData = pickle.loads(parts[0]) + logger.info(f"{recv_obj = }") + if self.mm_transfer_backend == "zmq_t": + buffer = parts[1].buffer if hasattr(parts[1], "buffer") else parts[1] + recv_obj.embedding = torch.frombuffer( + buffer, dtype=recv_obj.dtype + ).reshape(recv_obj.shape) if recv_embedding_data is None: + recv_obj.embedding_list[recv_obj.part_idx] = recv_obj.embedding recv_embedding_data = recv_obj else: recv_embedding_data.add(recv_obj) @@ -243,7 +529,7 @@ async def _recv_mm_data(self, req_id, embedding_port, mm_processor, prompt): recv_embedding = self.embeddings_buffer[req_id] del self.embeddings_buffer[req_id] self.embeddings_engine.deregister(recv_embedding.data_ptr()) - elif self.mm_transfer_backend == "zmq": + elif self.mm_transfer_backend == "zmq_t": recv_embedding = recv_embedding_data.get_embedding() img_grid_thw = recv_embedding_data.get_img_grid() diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py index ff954334bed1..e86f4de1729c 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -1,10 +1,12 @@ import asyncio +import ctypes import logging import multiprocessing 
as mp import pickle +import sys import time import traceback -from typing import List, Optional +from typing import Dict, List, Optional, Set, Tuple import aiohttp import numpy as np @@ -37,13 +39,10 @@ from sglang.srt.utils import get_local_ip_auto, get_zmq_socket, random_uuid logger = logging.getLogger(__name__) -import ctypes -import sys -from typing import Dict, Set, Tuple -rid_2_receive_endpoint: Dict[str, List[str]] = dict() -rid_2_receive_count: Dict[str, int] = dict() -rid_2_ready_event: Dict[str, asyncio.Event] = dict() +rid_to_receive_endpoint: Dict[str, List[str]] = dict() +rid_to_receive_count: Dict[str, int] = dict() +rid_to_ready_event: Dict[str, asyncio.Event] = dict() class TensorWrapper: @@ -89,14 +88,6 @@ def _convert(data): _image_grid_attrs = ["image_grid_thw", "image_grid_hws"] -class ReceivePortsManager: - def __init__( - self, - ): - self.rid_2_ports = dict() - self.rid_2_tp_size = dict() - - def _get_image_grid_dim(images_input): for attr in _image_grid_attrs: if attr in images_input: @@ -106,16 +97,6 @@ def _get_image_grid_dim(images_input): ) -def get_ports_for_rank(embedding_ports, rank, tp_size): - total_ports = len(embedding_ports) - ports_per_rank = (total_ports + tp_size - 1) // tp_size - - start_idx = rank * ports_per_rank - end_idx = min(start_idx + ports_per_rank, total_ports) - - return embedding_ports[start_idx:end_idx] - - class MMEncoder: def __init__( self, @@ -230,40 +211,50 @@ async def _encode(self, mm_items) -> torch.Tensor: async def _send( self, - url, embedding: torch.Tensor, mm_data: EmbeddingData, session_id=None, - peer_buffer_address=None, + buffer_address=None, + prefill_host=None, + embedding_port=None, + url=None, ): if self.server_args.mm_transfer_backend == "mooncake": self.engine.register(embedding.data_ptr(), embedding.nbytes) self.engine.transfer_sync( - session_id, embedding.data_ptr(), peer_buffer_address, embedding.nbytes + session_id, embedding.data_ptr(), buffer_address, embedding.nbytes ) 
self.engine.deregister(embedding.data_ptr()) mm_data.embedding = None mm_data.embedding_list[mm_data.part_idx] = None - logger.info(f" [{self.rank}] Sending to {url}") + # Send ack/data + endpoint = ( + f"tcp://{url}" + if url is not None + else f"tcp://{prefill_host}:{embedding_port}" + ) + logger.info(f"{endpoint = }") socket = get_zmq_socket( self.context, zmq.PUSH, - f"tcp://{url}", + endpoint, False, ) - new_mm_data = mm_data.copy_without_embedding() - embedding_tensor = TensorWrapper(mm_data.embedding) - new_mm_data.send_time = time.time() - socket.send_multipart([pickle.dumps(new_mm_data), embedding_tensor._buffer]) + if self.server_args.mm_transfer_backend == "mooncake": + socket.send_multipart([pickle.dumps(mm_data)]) + else: + new_mm_data = mm_data.copy_without_embedding() + embedding_tensor = TensorWrapper(mm_data.embedding) + socket.send_multipart([pickle.dumps(new_mm_data), embedding_tensor._buffer]) async def encode(self, mm_items, req_id, num_parts, part_idx): start_time = time.time() image_grid_dim, mm_embedding = await self._encode(mm_items) end_time = time.time() - print(f"🕛 encode cost = {(end_time - start_time) * 1000:.2f}ms") + logger.info(f"🕛 encode cost = {(end_time - start_time) * 1000:.2f}ms") if self.rank == 0: mm_data = EmbeddingData( req_id, @@ -275,11 +266,24 @@ async def encode(self, mm_items, req_id, num_parts, part_idx): self.embedding_to_send[mm_data.req_id] = mm_data return mm_embedding.nbytes, mm_embedding.shape[0], mm_embedding.shape[1] + # For zmq_t zmq_s and mooncake async def send( + self, req_id, prefill_host, embedding_port, session_id=None, buffer_address=None + ): + mm_data: EmbeddingData = self.embedding_to_send[req_id] + await self._send( + mm_data.embedding, + mm_data, + session_id=session_id, + buffer_address=buffer_address, + prefill_host=prefill_host, + embedding_port=embedding_port, + ) + + # For zmq_s + async def send_with_url( self, req_id, - session_id=None, - buffer_address=None, ): mm_data = 
self.embedding_to_send.get(req_id) if not mm_data: @@ -291,8 +295,8 @@ async def send( try: while True: - current_targets = rid_2_receive_endpoint.get(req_id, set()).copy() - expected_count = rid_2_receive_count.get(req_id) + current_targets = rid_to_receive_endpoint.get(req_id, set()).copy() + expected_count = rid_to_receive_count.get(req_id) new_targets = current_targets - sent_urls @@ -303,11 +307,9 @@ async def send( for url in new_targets: task = asyncio.create_task( self._send( - url, mm_data.embedding, mm_data, - session_id, - buffer_address, + url=url, ) ) all_tasks.append((task, url)) @@ -345,8 +347,8 @@ async def send( finally: logger.info(f"Cleaning up resources for req_id {req_id}") - rid_2_receive_endpoint.pop(req_id, None) - rid_2_receive_count.pop(req_id, None) + rid_to_receive_endpoint.pop(req_id, None) + rid_to_receive_count.pop(req_id, None) self.embedding_to_send.pop(req_id, None) async def get_embedding_port(self, prefill_url): @@ -426,8 +428,6 @@ async def handle_encode_request(request: dict): num_parts=request["num_parts"], part_idx=request["part_idx"], ) - time3 = time.time() - # print(f"🕛 send_time = {(time2 - time1) * 1000:.2f}ms, encode_time = {(time3 - time2) * 1000:.2f}ms") if encoder.server_args.mm_transfer_backend == "mooncake": del request["mm_items"] request.update( @@ -438,10 +438,33 @@ async def handle_encode_request(request: dict): } ) return ORJSONResponse(content=request) - elif encoder.server_args.mm_transfer_backend == "zmq": + elif encoder.server_args.mm_transfer_backend == "zmq_s": + logger.info(f"{request["embedding_port"] = }") + if request["embedding_port"] is None: + await encoder.send_with_url( + req_id=request["req_id"], + ) + else: + assert type(request["embedding_port"]) == list + tasks = [] + for embedding_port in request["embedding_port"]: + tasks.append( + encoder.send( + req_id=request["req_id"], + prefill_host=request["prefill_host"], + embedding_port=embedding_port, + ) + ) + await asyncio.gather(*tasks) + 
encoder.embedding_to_send.pop(request["req_id"], None) + return ORJSONResponse(content=None) + elif encoder.server_args.mm_transfer_backend == "zmq_t": await encoder.send( req_id=request["req_id"], + prefill_host=request["prefill_host"], + embedding_port=request["embedding_port"], ) + encoder.embedding_to_send.pop(request["req_id"], None) return ORJSONResponse(content=None) @@ -455,15 +478,16 @@ async def handle_send_request(request: dict): session_id=request["session_id"], buffer_address=request["buffer_address"], ) + encoder.embedding_to_send.pop(request["req_id"], None) return ORJSONResponse(content=None) @app.post("/scheduler_receive_url") async def handle_scheduler_receive_url_request(request: dict): rid = request["req_id"] - global rid_2_receive_endpoint - if rid not in rid_2_receive_endpoint: - rid_2_receive_endpoint[rid] = set() - rid_2_receive_count[rid] = request["receive_count"] - assert rid_2_receive_count[rid] == request["receive_count"] - rid_2_receive_endpoint[rid].add(request["receive_url"]) + global rid_to_receive_endpoint + if rid not in rid_to_receive_endpoint: + rid_to_receive_endpoint[rid] = set() + rid_to_receive_count[rid] = request["receive_count"] + assert rid_to_receive_count[rid] == request["receive_count"] + rid_to_receive_endpoint[rid].add(request["receive_url"]) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 4633a9872393..4edc02678e9d 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -228,9 +228,8 @@ class GenerateReqInput(BaseReq): return_entropy: bool = False need_wait_for_image: Optional[bool] = None - num_items_assigned: Optional[List] = None - encode_idx: Optional[List] = None + embedding_ports: Optional[List] = None def contains_mm_input(self) -> bool: return ( @@ -703,7 +702,7 @@ class TokenizedGenerateReqInput(BaseReq): need_wait_for_image: bool = False num_items_assigned: Optional[List] = None - encode_idx: Optional[List] 
= None + embedding_ports: Optional[List] = None @dataclass diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 98dfcd2e4b11..5385a16b8961 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -13,11 +13,9 @@ # ============================================================================== """A scheduler that manages a tensor parallel GPU worker.""" -import asyncio import faulthandler import logging import os -import pickle import signal import sys import threading @@ -28,7 +26,6 @@ from http import HTTPStatus from typing import Any, Deque, Dict, List, Optional, Tuple, Union -import aiohttp import psutil import setproctitle import torch @@ -51,7 +48,7 @@ from sglang.srt.disaggregation.decode_kvcache_offload_manager import ( DecodeKVCacheOffloadManager, ) -from sglang.srt.disaggregation.encode_receiver import EmbeddingData +from sglang.srt.disaggregation.encode_receiver import MMReceiver from sglang.srt.disaggregation.prefill import ( PrefillBootstrapQueue, SchedulerDisaggregationPrefillMixin, @@ -121,7 +118,6 @@ UpdateWeightsFromTensorReqInput, ) from sglang.srt.managers.mm_utils import init_mm_embedding_cache -from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors from sglang.srt.managers.overlap_utils import FutureMap from sglang.srt.managers.schedule_batch import ( FINISH_ABORT, @@ -181,7 +177,6 @@ freeze_gc, get_available_gpu_memory, get_bool_env_var, - get_free_port, get_int_env_var, get_local_ip_by_remote, get_zmq_socket, @@ -210,156 +205,6 @@ GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300)) -def _determine_tensor_transport_mode(server_args: ServerArgs): - is_cross_node = server_args.dist_init_addr - - if is_cross_node: - # Fallback to default CPU transport for multi-node - return "default" - else: - return "cuda_ipc" - - -class WaitingImageRequest: - def __init__( - self, - rid: str, - recv_req: 
TokenizedGenerateReqInput, - req, - mm_processor, - image_urls, - host_name, - receive_count, - ): - self.rid = rid - self.recv_req = recv_req - self.req = req - self.embedding_ready = False - self.mm_inputs = None - self.error = None - self.thread = None - self.ready = False - self.mm_processor = mm_processor - self.image_urls = image_urls - self.host_name = host_name - self.receive_count = receive_count - self.num_items_assigned = recv_req.num_items_assigned - self.encode_idx = recv_req.encode_idx - - def start_waiting(self): - self.thread = threading.Thread(target=self._recv_mm_data_thread, daemon=True) - self.thread.start() - - def _recv_mm_data_thread(self): - try: - mm_processor = self.mm_processor - prompt = self.recv_req.input_text - self.recv_req.mm_inputs = self._recv_mm_data_sync( - self.rid, mm_processor, prompt - ) - - self.embedding_ready = True - - except Exception as e: - logger.error(f"Error receiving embedding for {self.rid}: {e}") - self.error = str(e) - - def _recv_mm_data_sync(self, req_id, mm_processor, prompt): - if req_id is None: - return None - embedding_port = get_free_port() - context = zmq.Context() - recv_socket = context.socket(zmq.PULL) - recv_socket.bind(f"tcp://*:{embedding_port}") - logger.info(f"Waiting for input {embedding_port = }") - - async def _send_single_request(session, url, payload): - try: - async with session.post(url, json=payload) as response: - response.raise_for_status() - return await response.text() - except Exception as e: - logger.error(f"Failed to send request to {url}: {e}") - raise - - async def send_embedding_port(req_id, receive_count, host_name, embedding_port): - async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=1800) - ) as session: - tasks = [] - logger.info(f"{self.num_items_assigned = } ") - for idx, assigned_num in enumerate(self.num_items_assigned): - if assigned_num == 0: - continue - image_url = self.image_urls[self.encode_idx[idx]] - target_url = 
f"{image_url}/scheduler_receive_url" - payload = { - "req_id": req_id, - "receive_count": receive_count, - "receive_url": f"{host_name}:{embedding_port}", - } - - logger.info(f"Preparing to send to {target_url}") - - task = _send_single_request(session, target_url, payload) - tasks.append(task) - - if not tasks: - logger.info("No tasks to send.") - return - logger.info(f"Concurrently sending {len(tasks)} requests...") - results = await asyncio.gather(*tasks, return_exceptions=True) - - for i, result in enumerate(results): - if isinstance(result, Exception): - logger.error(f"Request {i} failed: {result}") - else: - logger.debug(f"Request {i} succeeded.") - - try: - asyncio.run( - send_embedding_port( - self.recv_req.rid, - self.receive_count, - self.host_name, - embedding_port, - ) - ) - recv_embedding_data = None - while recv_embedding_data is None or not recv_embedding_data.ready: - try: - parts = recv_socket.recv_multipart(flags=zmq.NOBLOCK, copy=False) - except zmq.Again: - # No data available yet, wait a bit and retry - continue - - recv_obj: EmbeddingData = pickle.loads(parts[0]) - buffer = parts[1].buffer if hasattr(parts[1], "buffer") else parts[1] - recv_obj.embedding = torch.frombuffer( - buffer, dtype=recv_obj.dtype - ).reshape(recv_obj.shape) - recv_obj.embedding_list[recv_obj.part_idx] = recv_obj.embedding - print( - f"transport cost {(time.time() - recv_obj.send_time) * 1000:.2f}" - ) - if recv_embedding_data is None: - recv_embedding_data = recv_obj - else: - recv_embedding_data.add(recv_obj) - recv_embedding = recv_embedding_data.get_embedding() - img_grid_thw = recv_embedding_data.get_img_grid() - - mm_inputs = mm_processor.get_mm_data(prompt, recv_embedding, img_grid_thw) - if mm_inputs and "input_ids" in mm_inputs: - self.req.origin_input_ids = mm_inputs["input_ids"] - self.ready = True - return mm_inputs - - finally: - recv_socket.close() - context.term() - - @dataclass class EmbeddingBatchResult: embeddings: torch.Tensor @@ -699,37 +544,17 @@ 
def __init__( # Init mlp sync flag self.require_mlp_sync = require_mlp_sync(server_args) - self.waiting_for_image: List[WaitingImageRequest] = [] - transport_mode = _determine_tensor_transport_mode(self.server_args) - if self.model_config.is_multimodal: - import_processors("sglang.srt.multimodal.processors") - try: - _processor = get_processor( - server_args.tokenizer_path, - tokenizer_mode=server_args.tokenizer_mode, - trust_remote_code=server_args.trust_remote_code, - revision=server_args.revision, - use_fast=not server_args.disable_fast_image_processor, - ) - except ValueError as e: - error_message = str(e) - if "does not have a slow version" in error_message: - logger.info( - f"Processor {server_args.tokenizer_path} does not have a slow version. Automatically use fast version" - ) - _processor = get_processor( - server_args.tokenizer_path, - tokenizer_mode=server_args.tokenizer_mode, - trust_remote_code=server_args.trust_remote_code, - revision=server_args.revision, - use_fast=True, - ) - else: - raise e - self.mm_processor = get_mm_processor( - self.model_config.hf_config, server_args, _processor, transport_mode - ) - print(f"{self.mm_processor = }") + if ( + self.server_args.language_only + and self.server_args.mm_transfer_backend == "zmq_s" + ): + self.mm_receiver = MMReceiver( + server_args, + hostname=self.host_name, + hf_config=self.model_config.hf_config, + tp_rank=self.tp_rank, + pp_rank=self.pp_rank, + ) # Init request dispatcher self._request_dispatcher = TypeBasedDispatcher( @@ -1160,7 +985,6 @@ def init_moe_config(self): def event_loop_normal(self): """A normal scheduler loop.""" while True: - self.process_waiting_requests() recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) @@ -1196,7 +1020,6 @@ def pop_and_process(): self.process_batch_result(tmp_batch, tmp_result) while True: - self.process_waiting_requests() recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) @@ -1366,6 +1189,14 @@ def recv_requests( 
return recv_reqs def process_input_requests(self, recv_reqs: List): + + # Process MM requests under E disaggregation + if ( + self.server_args.language_only + and self.server_args.mm_transfer_backend == "zmq_s" + ): + self.mm_receiver.process_waiting_requests(recv_reqs) + for recv_req in recv_reqs: # If it is a health check generation request and there are running requests, ignore it. if is_health_check_generate_req(recv_req) and ( @@ -1464,128 +1295,6 @@ def _get_multimodal_inputs(self, mm_inputs_dict: dict): else: return MultimodalInputs.from_dict(mm_inputs_dict) - def _complete_multimodal_request(self, waiting_req: WaitingImageRequest): - recv_req = waiting_req.recv_req - req: Req = waiting_req.req - image_inputs = self._get_multimodal_inputs(recv_req.mm_inputs) - - # The following steps are already fast, execute locally on each rank. - # Expand a single image token into multiple dummy tokens for receiving image embeddings - req.origin_input_ids = self.pad_input_ids_func( - req.origin_input_ids, image_inputs - ) - req.extend_image_inputs(image_inputs) - - if len(req.origin_input_ids) >= self.max_req_input_len: - req.set_finish_with_abort( - error_msg=( - "Multimodal prompt is too long after expanding multimodal tokens. " - f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}." 
- ) - ) - self.init_req_max_new_tokens(req) - self._add_request_to_queue(req) - return - - # initialize before returning - self.init_req_max_new_tokens(req) - - # Validate prompt length - error_msg = validate_input_length( - req, - self.max_req_input_len, - self.server_args.allow_auto_truncate, - ) - if error_msg: - req.set_finish_with_abort(error_msg) - self._add_request_to_queue(req) - return - - # Copy more attributes - if recv_req.logprob_start_len == -1 or not recv_req.return_logprob: - # By default, only return the logprobs for output tokens - # For prefill-only requests with logprob_start_len == -1, set logprob_start_len beyond input sequence - # to skip input logprob computation entirely - if req.is_prefill_only: - req.logprob_start_len = len(req.origin_input_ids) - else: - # TODO: For text generation, evaluate setting logprob_start_len to len(req.origin_input_ids) as well - req.logprob_start_len = len(req.origin_input_ids) - 1 - else: - req.logprob_start_len = recv_req.logprob_start_len - - if not req.is_prefill_only and req.logprob_start_len >= len( - req.origin_input_ids - ): - error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len." 
- req.logprob_start_len = len(req.origin_input_ids) - 1 - req.set_finish_with_abort(error_msg) - self._add_request_to_queue(req) - return - - # Init grammar cache for this request - add_to_grammar_queue = False - if ( - req.sampling_params.json_schema is not None - or req.sampling_params.regex is not None - or req.sampling_params.ebnf is not None - or req.sampling_params.structural_tag is not None - ): - if self.grammar_backend is None: - error_msg = "Grammar-based generation (json_schema, regex, ebnf, structural_tag) is not supported when the server is launched with --grammar-backend none" - req.set_finish_with_abort(error_msg) - else: - if req.sampling_params.json_schema is not None: - key = ("json", req.sampling_params.json_schema) - elif req.sampling_params.regex is not None: - key = ("regex", req.sampling_params.regex) - elif req.sampling_params.ebnf is not None: - key = ("ebnf", req.sampling_params.ebnf) - elif req.sampling_params.structural_tag: - key = ("structural_tag", req.sampling_params.structural_tag) - - value, cache_hit = self.grammar_backend.get_cached_or_future_value(key) - req.grammar = value - - if not cache_hit: - req.grammar_key = key - add_to_grammar_queue = True - else: - if value is INVALID_GRAMMAR_OBJ: # We hit a cached invalid grammar. 
- error_msg = f"Invalid grammar request with cache hit: {key=}" - req.set_finish_with_abort(error_msg) - - if add_to_grammar_queue: - self.grammar_queue.append(req) - else: - self._add_request_to_queue(req) - - def process_waiting_requests( - self, - ) -> None: - if not self.waiting_for_image or len(self.waiting_for_image) == 0: - return - local_statuses = [] - request_rids = [] - - for waiting_req in self.waiting_for_image: - local_statuses.append(1 if waiting_req.embedding_ready else 0) - request_rids.append(waiting_req.rid) - - local_tensor = torch.tensor(local_statuses, dtype=torch.int32) - - if torch.cuda.is_available(): - local_tensor = local_tensor.cuda() - - torch.distributed.all_reduce(local_tensor, op=torch.distributed.ReduceOp.MIN) - new_waiting = [] - for i, waiting_req in enumerate(self.waiting_for_image): - if local_tensor[i].item() == 1: - self._complete_multimodal_request(waiting_req) - else: - new_waiting.append(waiting_req) - self.waiting_for_image = new_waiting - def handle_generate_request( self, recv_req: TokenizedGenerateReqInput, @@ -1666,23 +1375,6 @@ def handle_generate_request( self._add_request_to_queue(req) return - # TODO: add this - if recv_req.need_wait_for_image is True: - waiting_req = WaitingImageRequest( - rid=recv_req.rid, - recv_req=recv_req, - req=req, - mm_processor=self.mm_processor, - image_urls=self.server_args.encode_urls, - host_name=self.host_name, - ##TODO fixme: - receive_count=self.server_args.tp_size, - ) - self.waiting_for_image.append(waiting_req) - waiting_req.start_waiting() - - return - # Handle multimodal inputs if recv_req.mm_inputs is not None: image_inputs = self._get_multimodal_inputs(recv_req.mm_inputs) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index e8d56a255968..d24fec6908e2 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -20,12 +20,10 @@ import math import os import 
pickle -import random import signal import sys import threading import time -import uuid from collections import deque from contextlib import nullcontext from datetime import datetime @@ -33,7 +31,6 @@ from http import HTTPStatus from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union -import aiohttp import fastapi import orjson import torch @@ -100,13 +97,11 @@ trace_slice_start, ) from sglang.srt.utils import ( - ImageData, configure_gc_warning, dataclass_to_string_truncated, freeze_gc, get_bool_env_var, get_or_create_event_loop, - get_multi_free_port, get_zmq_socket, kill_process_tree, ) @@ -320,13 +315,10 @@ def __init__( self.send_to_scheduler = SenderWrapper(port_args, send_to_scheduler) # E Disaggregation - if self.model_config.is_multimodal and self.server_args.language_only: + if self.server_args.language_only: self.mm_receiver = MMReceiver( - server_args.host, - server_args.encode_urls, - server_args.mm_transfer_backend, - server_args.disaggregation_ib_device, - self.model_config.dtype, + server_args, + dtype=self.model_config.dtype, ) # Request states @@ -444,91 +436,6 @@ def __init__( ] ) self.init_communicators(server_args) - # self.host = self.server_args.host - # self.ports_pool = get_multi_free_port(self.server_args.tp_size * 20) - self.encode_idx = list(range(len(self.server_args.encode_urls))) - self.encode_urls = self.server_args.encode_urls - self.riq_2_images = {} - - @staticmethod - def extrac_and_clean_image_url(obj: GenerateReqInput): - image_urls = [] - for image in obj.image_data: - if isinstance(image, ImageData): - image_urls.append(image.url) - image.url = "" - return image_urls - - def _run_encode_in_thread( - self, req_id, img_data, endpoint_encode, num_items_assigned, encode_idx - ): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - self._encode_in_background( - req_id, img_data, endpoint_encode, num_items_assigned, encode_idx - ) - ) - # logger.info(f"Encode 
completed for request {req_id}") - except Exception as e: - logger.error(f"Encode failed for request {req_id}: {e}", exc_info=True) - finally: - # self.ports_pool.extend(embedding_ports) - del self.riq_2_images[req_id] - loop.close() - - async def _encode_in_background( - self, req_id, img_data, endpoint_encode, num_items_assigned, encode_idx - ): - if len(img_data) == 0: - return - - # Split mm_items - encode_requests = [] - - num_parts = sum(1 for x in num_items_assigned if x != 0) - cum_num_items = 0 - cum_idx = 0 - for idx, assigned_num in enumerate(num_items_assigned): - if assigned_num == 0: - continue - encode_requests.append( - { - "encoder_idx": encode_idx[idx], - "mm_items": img_data[cum_num_items : cum_num_items + assigned_num], - "num_parts": num_parts, - "part_idx": cum_idx, - "req_id": req_id, - } - ) - cum_idx += 1 - cum_num_items += assigned_num - - async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout( - total=1800 - ) # Add timeout for request reliability - ) as session: - # Send encode requests - - tasks = [ - session.post( - f"{self.encode_urls[encode_request['encoder_idx']]}/{endpoint_encode}", - json=encode_request, - ) - for encode_request in encode_requests - ] - - responses = await asyncio.gather(*tasks) - response_json_list_unsort = [ - await response.json() for response in responses - ] - - # zmq backend: return is None - if None in response_json_list_unsort: - return async def generate_request( self, @@ -538,36 +445,12 @@ async def generate_request( created_time = time.time() self.auto_create_handle_loop() obj.normalize_batch_and_arguments() - if isinstance(obj, GenerateReqInput): - image_urls = TokenizerManager.extrac_and_clean_image_url(obj) - if obj.rid is None: - obj.rid = uuid.uuid4().hex - if image_urls and len(image_urls) > 0: - logger.info( - f"Processing {len(image_urls)} images for request {obj.rid}" - ) - self.riq_2_images[obj.rid] = image_urls - obj.embedding_ports = get_multi_free_port(self.server_args.tp_size) - 
obj.need_wait_for_image = True - - random.shuffle(self.encode_idx) - obj.encode_idx = self.encode_idx - obj.num_items_assigned = [ - (idx + len(image_urls)) // len(self.server_args.encode_urls) - for idx in self.encode_idx - ] - encode_thread = threading.Thread( - target=self._run_encode_in_thread, - args=( - obj.rid, - image_urls, - "encode", - obj.num_items_assigned, - obj.encode_idx, - ), - daemon=True, - ) - encode_thread.start() + if ( + self.server_args.language_only + and isinstance(obj, GenerateReqInput) + and self.server_args.mm_transfer_backend == "zmq_s" + ): + self.mm_receiver.send_encode_requset(obj) if self.enable_trace: external_trace_header = None @@ -846,8 +729,16 @@ async def _tokenize_one_request( mm_inputs = None - if self.server_args.language_only is False: - + if ( + not self.server_args.language_only + or self.server_args.mm_transfer_backend != "zmq_s" + ): + if self.server_args.language_only: + mm_inputs = await self.mm_receiver.recv_mm_data( + img_data=obj.image_data, + mm_processor=self.mm_processor, + prompt=(input_text or input_ids), + ) if mm_inputs is None: mm_inputs: Dict = await self.mm_data_processor.process( image_data=obj.image_data, @@ -1034,7 +925,7 @@ def _create_tokenized_object( extra_key=obj.extra_key, need_wait_for_image=obj.need_wait_for_image, num_items_assigned=obj.num_items_assigned, - encode_idx=obj.encode_idx, + embedding_ports=obj.embedding_ports, ) elif isinstance(obj, EmbeddingReqInput): tokenized_obj = TokenizedEmbeddingReqInput( diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index 522ad9a64092..71236fe7eda6 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -242,24 +242,31 @@ def get_input_ids(self, prompt, img_grid_thw): if not isinstance(prompt, list): prompt = self._processor.tokenizer.encode(prompt) - img_start_id = self.IM_START_TOKEN_ID - 
img_id = self.IM_TOKEN_ID + img_token_id = self.IM_TOKEN_ID spatial_merge_size = self.spatial_merge_size input_ids = [] offsets = [] cur_idx = 0 + + # Use img_token_id instead of im_start_id, because a dummy im_start_id + # may be generated by the tokenizer. img_start_indices = list( - filter(lambda i: prompt[i] == img_start_id, range(len(prompt))) + filter(lambda i: prompt[i + 1] == img_token_id, range(len(prompt) - 1)) ) + if len(img_start_indices) != img_grid_thw.shape[0]: + logger.info(f"{len(prompt)} {prompt}") + logger.info( + f"Check img_start_indices: {img_start_indices} and img_grid_thw: {img_grid_thw.shape}" + ) for cur_img_idx, img_start_idx in enumerate(img_start_indices): assert cur_idx <= img_start_idx # include img_start_id input_ids.extend(prompt[cur_idx : img_start_idx + 1]) img_offset_start = len(input_ids) img_token_num = img_grid_thw[cur_img_idx].prod() // (spatial_merge_size**2) - input_ids.extend([img_id] * img_token_num) + input_ids.extend([img_token_id] * img_token_num) # jump to img_end_id cur_idx = img_start_idx + 2 offsets.append((img_offset_start, len(input_ids) - 1)) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 43bb57d99966..88479f2fd3a6 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -135,7 +135,7 @@ DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"] -MM_TRANSFER_BACKEND_CHOICES = ["zmq", "mooncake"] +MM_TRANSFER_BACKEND_CHOICES = ["zmq_s", "zmq_t", "mooncake"] GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"] @@ -259,7 +259,7 @@ class ServerArgs: # Encode prefill disaggregation mm_only: bool = False language_only: bool = False - mm_transfer_backend: str = "zmq" + mm_transfer_backend: str = MM_TRANSFER_BACKEND_CHOICES[0] encode_urls: List[str] = dataclasses.field(default_factory=list) # Quantization and data type From 868909a5a09254e8ab87cf055268ad0456236f88 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Tue, 
25 Nov 2025 12:22:30 +0000 Subject: [PATCH 49/68] Fix import --- python/sglang/srt/disaggregation/encode_receiver.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/disaggregation/encode_receiver.py b/python/sglang/srt/disaggregation/encode_receiver.py index da415fc7de72..720ad78f0c69 100644 --- a/python/sglang/srt/disaggregation/encode_receiver.py +++ b/python/sglang/srt/disaggregation/encode_receiver.py @@ -11,16 +11,13 @@ import zmq import zmq.asyncio -from python.sglang.srt.managers.multimodal_processor import ( - get_mm_processor, - import_processors, -) -from python.sglang.srt.utils.common import get_multi_free_port -from python.sglang.srt.utils.hf_transformers_utils import get_processor from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine from sglang.srt.managers.io_struct import TokenizedGenerateReqInput +from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors from sglang.srt.server_args import ServerArgs from sglang.srt.utils import get_free_port, get_local_ip_auto, get_zmq_socket +from sglang.srt.utils.common import get_multi_free_port +from sglang.srt.utils.hf_transformers_utils import get_processor logger = logging.getLogger(__name__) From bb562bf2853c3c7787d8b9cb615ecd4830e9574f Mon Sep 17 00:00:00 2001 From: ZhengWG Date: Tue, 25 Nov 2025 20:47:47 +0800 Subject: [PATCH 50/68] feat: support prefix_mm_cache --- .../srt/disaggregation/encode_server.py | 53 ++++++++++++++----- python/sglang/srt/server_args.py | 7 +++ 2 files changed, 47 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py index e86f4de1729c..87c036666ac4 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -2,6 +2,7 @@ import ctypes import logging import multiprocessing as mp +import os import pickle import sys import 
time @@ -30,6 +31,7 @@ ) from sglang.srt.layers.dp_attention import initialize_dp_attention from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem +from sglang.srt.mem_cache.multimodal_cache import MultiModalStaticCache from sglang.srt.model_loader import get_model from sglang.srt.server_args import ( PortArgs, @@ -156,6 +158,10 @@ def __init__( self.context = zmq.asyncio.Context(2) + embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "4096")) + self.mm_cache = MultiModalStaticCache(embedding_cache_size * 1024 * 1024) + self.mm_cache_lock = asyncio.Lock() + if schedule_path is not None: self.schedule_socket = get_zmq_socket( self.context, zmq.PULL, schedule_path, True @@ -194,19 +200,40 @@ async def _encode(self, mm_items) -> torch.Tensor: if k == "pixel_values": continue mm_item.set(k, _convert(v)) - with torch.inference_mode(): - if torch.cuda.is_available(): - torch.cuda.synchronize() - start_time1 = time.perf_counter() - mm_embedding: torch.Tensor = self.model.get_image_feature([mm_item]) - if torch.cuda.is_available(): - torch.cuda.synchronize() - end_time = time.perf_counter() - logger.info( - f"Vit time : {(end_time - start_time1)*1000:.2f} ms {mm_embedding.shape = }" - ) - if len(mm_embedding.shape) != 2: - mm_embedding = mm_embedding.reshape(-1, mm_embedding.shape[-1]) + + # support mm_cache + mm_embedding = None + mm_hash = None + + start_time = time.perf_counter() + if self.server_args.enable_prefix_mm_cache: + mm_item.set_pad_value() + mm_hash = MultiModalStaticCache.combine_hashes([mm_item.hash]) + async with self.mm_cache_lock: + mm_cache = self.mm_cache.get([mm_item.hash]) + if mm_cache is not None: + mm_embedding = mm_cache + if mm_cache is not None: + mm_embedding = mm_cache + + if mm_embedding is None: + with torch.inference_mode(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + mm_embedding: torch.Tensor = self.model.get_image_feature([mm_item]) + if torch.cuda.is_available(): + 
torch.cuda.synchronize() + if len(mm_embedding.shape) != 2: + mm_embedding = mm_embedding.reshape(-1, mm_embedding.shape[-1]) + + if self.server_args.enable_prefix_mm_cache: + async with self.mm_cache_lock: + self.mm_cache.set(mm_hash, mm_embedding.cpu()) + end_time = time.perf_counter() + logger.info( + f"Vit time : {(end_time - start_time)*1000:.2f} ms {mm_embedding.shape = }" + ) + return _get_image_grid_dim(images_input), mm_embedding.cpu() async def _send( diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 88479f2fd3a6..0257187d3c34 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -594,6 +594,7 @@ class ServerArgs: mm_max_concurrent_calls: int = 32 mm_per_request_timeout: float = 10.0 enable_broadcast_mm_inputs_process: bool = False + enable_prefix_mm_cache: bool = False # For checkpoint decryption decrypted_config_file: Optional[str] = None @@ -3997,6 +3998,12 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.mm_enable_dp_encoder, help="Enabling data parallelism for mm encoder. 
The dp size will be set to the tp size automatically.", ) + parser.add_argument( + "--enable-prefix-mm-cache", + action="store_true", + default=ServerArgs.enable_prefix_mm_cache, + help="Enable prefix multimodal cache.", + ) # For registering hooks parser.add_argument( From dec13c18960b0ea86027b5032cb0ca44d81d65b2 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Wed, 26 Nov 2025 01:15:45 +0000 Subject: [PATCH 51/68] Lint --- python/sglang/srt/disaggregation/encode_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py index 87c036666ac4..52e7d7e749aa 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -466,7 +466,7 @@ async def handle_encode_request(request: dict): ) return ORJSONResponse(content=request) elif encoder.server_args.mm_transfer_backend == "zmq_s": - logger.info(f"{request["embedding_port"] = }") + logger.info(f"{request['embedding_port'] = }") if request["embedding_port"] is None: await encoder.send_with_url( req_id=request["req_id"], From 23c32938098e83c68076fb44b5520a81d25756d7 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Thu, 27 Nov 2025 03:37:48 +0000 Subject: [PATCH 52/68] Avoid time-consuming CPU concat --- python/sglang/srt/disaggregation/encode_receiver.py | 11 +++++++---- python/sglang/srt/disaggregation/encode_server.py | 9 +++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/disaggregation/encode_receiver.py b/python/sglang/srt/disaggregation/encode_receiver.py index 720ad78f0c69..66e0ec90a24a 100644 --- a/python/sglang/srt/disaggregation/encode_receiver.py +++ b/python/sglang/srt/disaggregation/encode_receiver.py @@ -51,8 +51,11 @@ def add(self, embedding_data): ) self.embedding_list[embedding_data.part_idx] = embedding_data.embedding - def get_embedding(self): - return torch.concatenate(self.embedding_list) + 
def get_embedding(self, is_concat=False): + if is_concat: + return torch.concat([embedding.cuda() for embedding in self.embedding_list]) + else: + return self.embedding_list def get_img_grid(self): return torch.concatenate(self.image_grid_dim_list) @@ -183,7 +186,7 @@ def _try_recv_mm_data(self): else: self.recv_embedding_data.add(recv_obj) - recv_embedding = self.recv_embedding_data.get_embedding() + recv_embedding = self.recv_embedding_data.get_embedding(is_concat=True) img_grid_thw = self.recv_embedding_data.get_img_grid() mm_inputs = self.mm_processor.get_mm_data( @@ -527,7 +530,7 @@ async def _recv_mm_data(self, req_id, embedding_port, mm_processor, prompt): del self.embeddings_buffer[req_id] self.embeddings_engine.deregister(recv_embedding.data_ptr()) elif self.mm_transfer_backend == "zmq_t": - recv_embedding = recv_embedding_data.get_embedding() + recv_embedding = recv_embedding_data.get_embedding(is_concat=True) img_grid_thw = recv_embedding_data.get_img_grid() diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py index 52e7d7e749aa..146913472f7a 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -218,23 +218,20 @@ async def _encode(self, mm_items) -> torch.Tensor: if mm_embedding is None: with torch.inference_mode(): - if torch.cuda.is_available(): - torch.cuda.synchronize() mm_embedding: torch.Tensor = self.model.get_image_feature([mm_item]) - if torch.cuda.is_available(): - torch.cuda.synchronize() + mm_embedding = mm_embedding.cpu() if len(mm_embedding.shape) != 2: mm_embedding = mm_embedding.reshape(-1, mm_embedding.shape[-1]) if self.server_args.enable_prefix_mm_cache: async with self.mm_cache_lock: - self.mm_cache.set(mm_hash, mm_embedding.cpu()) + self.mm_cache.set(mm_hash, mm_embedding) end_time = time.perf_counter() logger.info( f"Vit time : {(end_time - start_time)*1000:.2f} ms {mm_embedding.shape = }" ) - return 
_get_image_grid_dim(images_input), mm_embedding.cpu() + return _get_image_grid_dim(images_input), mm_embedding async def _send( self, From 725318b5daf9bb9047e9d2545fff42aa633b9537 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Thu, 27 Nov 2025 06:11:26 +0000 Subject: [PATCH 53/68] Rename zmq_s and zmq_t; Fix get_local_ip_by_remote --- .../srt/disaggregation/encode_receiver.py | 21 +++++++++---------- .../srt/disaggregation/encode_server.py | 8 +++---- python/sglang/srt/managers/scheduler.py | 7 ++----- .../sglang/srt/managers/tokenizer_manager.py | 4 ++-- python/sglang/srt/server_args.py | 2 +- 5 files changed, 19 insertions(+), 23 deletions(-) diff --git a/python/sglang/srt/disaggregation/encode_receiver.py b/python/sglang/srt/disaggregation/encode_receiver.py index 66e0ec90a24a..42e480038b47 100644 --- a/python/sglang/srt/disaggregation/encode_receiver.py +++ b/python/sglang/srt/disaggregation/encode_receiver.py @@ -80,7 +80,7 @@ def copy_without_embedding(self): return new_data -# For zmq_s +# For zmq_to_scheduler class WaitingImageRequest: def __init__( self, @@ -214,7 +214,6 @@ def __init__( self, server_args: ServerArgs, dtype=None, - hostname=None, hf_config=None, pp_rank=None, tp_rank=None, @@ -232,12 +231,12 @@ def __init__( ib_device=server_args.disaggregation_ib_device, ) self.embeddings_buffer = dict() - elif self.mm_transfer_backend == "zmq_s": + elif self.mm_transfer_backend == "zmq_to_scheduler": self.pp_rank = pp_rank self.tp_rank = tp_rank self.tp_size = server_args.tp_size self.nnodes = server_args.nnodes - self.hostname = hostname + self.hostname = get_local_ip_auto() self.world_size = server_args.pp_size * server_args.tp_size if hf_config is not None: transport_mode = _determine_tensor_transport_mode(server_args) @@ -270,7 +269,7 @@ def __init__( hf_config, server_args, _processor, transport_mode ) - # For zmq_s + # For zmq_to_scheduler def process_waiting_requests(self, recv_reqs): waiting_list: List[WaitingImageRequest] = [] for recv_req in 
recv_reqs: @@ -305,7 +304,7 @@ def process_waiting_requests(self, recv_reqs): if not waiting_req._try_recv_mm_data(): ready = False - # For zmq_s + # For zmq_to_scheduler def _run_encode_in_thread( self, req_id, img_data, endpoint_encode, num_items_assigned, embedding_port ): @@ -436,7 +435,7 @@ async def allocate_embedding_buffer(self, req_id, embedding_length, embedding_di self.embeddings_buffer[req_id] = embeddings return embeddings.data_ptr() - # For zmq_s + # For zmq_to_scheduler def send_encode_requset(self, obj): if type(obj.image_data) != list: image_urls = [obj.image_data.url] @@ -469,7 +468,7 @@ def send_encode_requset(self, obj): ) encode_thread.start() - # For zmq_t and mooncake + # For zmq_to_tokenizer and mooncake async def recv_mm_data(self, img_data, mm_processor, prompt): try: if len(self.encode_urls) == 0: @@ -493,7 +492,7 @@ async def recv_mm_data(self, img_data, mm_processor, prompt): del self.embeddings_buffer[req_id] return None - # For zmq_t and mooncake + # For zmq_to_tokenizer and mooncake async def _recv_mm_data(self, req_id, embedding_port, mm_processor, prompt): # Bypass MMReceiver if req_id is None: @@ -514,7 +513,7 @@ async def _recv_mm_data(self, req_id, embedding_port, mm_processor, prompt): recv_obj: EmbeddingData = pickle.loads(parts[0]) logger.info(f"{recv_obj = }") - if self.mm_transfer_backend == "zmq_t": + if self.mm_transfer_backend == "zmq_to_tokenizer": buffer = parts[1].buffer if hasattr(parts[1], "buffer") else parts[1] recv_obj.embedding = torch.frombuffer( buffer, dtype=recv_obj.dtype @@ -529,7 +528,7 @@ async def _recv_mm_data(self, req_id, embedding_port, mm_processor, prompt): recv_embedding = self.embeddings_buffer[req_id] del self.embeddings_buffer[req_id] self.embeddings_engine.deregister(recv_embedding.data_ptr()) - elif self.mm_transfer_backend == "zmq_t": + elif self.mm_transfer_backend == "zmq_to_tokenizer": recv_embedding = recv_embedding_data.get_embedding(is_concat=True) img_grid_thw = 
recv_embedding_data.get_img_grid() diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py index 146913472f7a..cdcffe5b2210 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -290,7 +290,7 @@ async def encode(self, mm_items, req_id, num_parts, part_idx): self.embedding_to_send[mm_data.req_id] = mm_data return mm_embedding.nbytes, mm_embedding.shape[0], mm_embedding.shape[1] - # For zmq_t zmq_s and mooncake + # For zmq_to_tokenizer zmq_to_scheduler and mooncake async def send( self, req_id, prefill_host, embedding_port, session_id=None, buffer_address=None ): @@ -304,7 +304,7 @@ async def send( embedding_port=embedding_port, ) - # For zmq_s + # For zmq_to_scheduler async def send_with_url( self, req_id, @@ -462,7 +462,7 @@ async def handle_encode_request(request: dict): } ) return ORJSONResponse(content=request) - elif encoder.server_args.mm_transfer_backend == "zmq_s": + elif encoder.server_args.mm_transfer_backend == "zmq_to_scheduler": logger.info(f"{request['embedding_port'] = }") if request["embedding_port"] is None: await encoder.send_with_url( @@ -482,7 +482,7 @@ async def handle_encode_request(request: dict): await asyncio.gather(*tasks) encoder.embedding_to_send.pop(request["req_id"], None) return ORJSONResponse(content=None) - elif encoder.server_args.mm_transfer_backend == "zmq_t": + elif encoder.server_args.mm_transfer_backend == "zmq_to_tokenizer": await encoder.send( req_id=request["req_id"], prefill_host=request["prefill_host"], diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 5385a16b8961..58415d651b0f 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -178,7 +178,6 @@ get_available_gpu_memory, get_bool_env_var, get_int_env_var, - get_local_ip_by_remote, get_zmq_socket, kill_itself_when_parent_died, numa_bind_to_node, @@ 
-236,7 +235,6 @@ def __init__( ): # Parse args self.server_args = server_args - self.host_name = get_local_ip_by_remote() self.tp_rank = tp_rank self.moe_ep_rank = moe_ep_rank self.pp_rank = pp_rank @@ -546,11 +544,10 @@ def __init__( if ( self.server_args.language_only - and self.server_args.mm_transfer_backend == "zmq_s" + and self.server_args.mm_transfer_backend == "zmq_to_scheduler" ): self.mm_receiver = MMReceiver( server_args, - hostname=self.host_name, hf_config=self.model_config.hf_config, tp_rank=self.tp_rank, pp_rank=self.pp_rank, @@ -1193,7 +1190,7 @@ def process_input_requests(self, recv_reqs: List): # Process MM requests under E disaggregation if ( self.server_args.language_only - and self.server_args.mm_transfer_backend == "zmq_s" + and self.server_args.mm_transfer_backend == "zmq_to_scheduler" ): self.mm_receiver.process_waiting_requests(recv_reqs) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index d24fec6908e2..1441da058b83 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -448,7 +448,7 @@ async def generate_request( if ( self.server_args.language_only and isinstance(obj, GenerateReqInput) - and self.server_args.mm_transfer_backend == "zmq_s" + and self.server_args.mm_transfer_backend == "zmq_to_scheduler" ): self.mm_receiver.send_encode_requset(obj) @@ -731,7 +731,7 @@ async def _tokenize_one_request( if ( not self.server_args.language_only - or self.server_args.mm_transfer_backend != "zmq_s" + or self.server_args.mm_transfer_backend != "zmq_to_scheduler" ): if self.server_args.language_only: mm_inputs = await self.mm_receiver.recv_mm_data( diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 0257187d3c34..1e0d9aa7ef30 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -135,7 +135,7 @@ DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", 
"ascend", "fake"] -MM_TRANSFER_BACKEND_CHOICES = ["zmq_s", "zmq_t", "mooncake"] +MM_TRANSFER_BACKEND_CHOICES = ["zmq_to_scheduler", "zmq_to_tokenizer", "mooncake"] GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"] From d3de248e4403408ec447dbfd8af07a47cea8ab26 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Thu, 27 Nov 2025 11:21:28 +0000 Subject: [PATCH 54/68] Add waiting list for zmq_to_scheduler --- .../srt/disaggregation/encode_receiver.py | 42 +++++++++++++------ python/sglang/srt/managers/scheduler.py | 2 +- python/sglang/srt/server_args.py | 6 +++ 3 files changed, 36 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/disaggregation/encode_receiver.py b/python/sglang/srt/disaggregation/encode_receiver.py index 42e480038b47..4451d8b5a7e9 100644 --- a/python/sglang/srt/disaggregation/encode_receiver.py +++ b/python/sglang/srt/disaggregation/encode_receiver.py @@ -167,13 +167,13 @@ async def send_embedding_port(req_id, receive_count, host_name, embedding_port): def _try_recv_mm_data(self): if self.ready: - return True + return while self.recv_embedding_data is None or not self.recv_embedding_data.ready: try: parts = self.recv_socket.recv_multipart(flags=zmq.NOBLOCK, copy=False) except zmq.Again: # No data available yet, wait a bit and retry - return False + return recv_obj: EmbeddingData = pickle.loads(parts[0]) buffer = parts[1].buffer if hasattr(parts[1], "buffer") else parts[1] @@ -195,7 +195,6 @@ def _try_recv_mm_data(self): self.recv_req.mm_inputs = mm_inputs self.recv_req.input_ids = mm_inputs["input_ids"] self.ready = True - return True def _determine_tensor_transport_mode(server_args): @@ -238,6 +237,7 @@ def __init__( self.nnodes = server_args.nnodes self.hostname = get_local_ip_auto() self.world_size = server_args.pp_size * server_args.tp_size + self.waiting_list: List[WaitingImageRequest] = [] if hf_config is not None: transport_mode = _determine_tensor_transport_mode(server_args) 
import_processors("sglang.srt.multimodal.processors") @@ -271,7 +271,7 @@ def __init__( # For zmq_to_scheduler def process_waiting_requests(self, recv_reqs): - waiting_list: List[WaitingImageRequest] = [] + new_recv_reqs = [] for recv_req in recv_reqs: # E Disaggregation if ( @@ -294,15 +294,31 @@ def process_waiting_requests(self, recv_reqs): ) if recv_req.embedding_ports is None: waiting_req.send_encode_request() - waiting_list.append(waiting_req) - - # waiting for recv embedding result - ready = False - while not ready: - ready = True - for waiting_req in waiting_list: - if not waiting_req._try_recv_mm_data(): - ready = False + self.waiting_list.append(waiting_req) + else: + new_recv_reqs.append(recv_req) + + if len(self.waiting_list) == 0: + return new_recv_reqs + + local_status = [] + for waiting_req in self.waiting_list: + waiting_req._try_recv_mm_data() + local_status.append(waiting_req.ready) + + local_status = torch.tensor(local_status, device="cuda", dtype=torch.int32) + + torch.distributed.all_reduce(local_status, op=torch.distributed.ReduceOp.MIN) + + new_waiting = [] + for i, waiting_req in enumerate(self.waiting_list): + if local_status[i].item(): + new_recv_reqs.append(waiting_req.recv_req) + else: + new_waiting.append(waiting_req) + + self.waiting_list = new_waiting + return new_recv_reqs # For zmq_to_scheduler def _run_encode_in_thread( diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 58415d651b0f..83e9e3d9b0e6 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1192,7 +1192,7 @@ def process_input_requests(self, recv_reqs: List): self.server_args.language_only and self.server_args.mm_transfer_backend == "zmq_to_scheduler" ): - self.mm_receiver.process_waiting_requests(recv_reqs) + recv_reqs = self.mm_receiver.process_waiting_requests(recv_reqs) for recv_req in recv_reqs: # If it is a health check generation request and there are running requests, 
ignore it. diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 1e0d9aa7ef30..9ecb6016d943 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1872,6 +1872,12 @@ def _handle_e_disaggregation(self): raise ValueError( "Cannot set --mm-only and --disaggregation-mode prefill/decode together" ) + if ( + self.language_only + and self.mm_transfer_backend == "zmq_to_scheduler" + and self.pp_size > 1 + ): + raise ValueError("zmq_to_scheduler not support pp_size > 1") def _handle_pd_disaggregation(self): if self.disaggregation_mode == "decode": From 4df4996246b7a3aab9497be7341a1a5eabde4451 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Thu, 27 Nov 2025 11:32:13 +0000 Subject: [PATCH 55/68] Fix duplicate disagg_mode --- python/sglang/srt/managers/tokenizer_manager.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 1441da058b83..ccfb23b0ac70 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -192,9 +192,6 @@ def __init__( ) self.crash_dump_folder = server_args.crash_dump_folder self.enable_trace = server_args.enable_trace - self.disaggregation_mode = DisaggregationMode( - self.server_args.disaggregation_mode - ) # Read model args self.model_path = server_args.model_path From 43101d6322a068257d16dff53eae4e6066405454 Mon Sep 17 00:00:00 2001 From: Nicholas <45984215+liusy58@users.noreply.github.com> Date: Mon, 1 Dec 2025 09:07:42 +0800 Subject: [PATCH 56/68] use thread for receiving mm data (#4) * use thread for receiving mm data * refactore TensorWrapper * fix oom --- .../srt/disaggregation/encode_receiver.py | 62 ++++++++++++++++++- .../srt/disaggregation/encode_server.py | 21 +++---- python/sglang/srt/managers/scheduler.py | 2 + 3 files changed, 72 insertions(+), 13 deletions(-) diff --git 
a/python/sglang/srt/disaggregation/encode_receiver.py b/python/sglang/srt/disaggregation/encode_receiver.py index 4451d8b5a7e9..65743210a89d 100644 --- a/python/sglang/srt/disaggregation/encode_receiver.py +++ b/python/sglang/srt/disaggregation/encode_receiver.py @@ -3,6 +3,7 @@ import pickle import random import threading +import time import uuid from typing import List @@ -91,6 +92,8 @@ def __init__( host_name, receive_count, embedding_port=None, + device=None, + gpu_id=None, ): self.rid = rid self.recv_req = recv_req @@ -112,6 +115,13 @@ def __init__( self.recv_embedding_data = None self.ready = False + self.device = device + self.gpu_id = gpu_id + + self.thread = None + self._started = False + self._cleaned = False + def send_encode_request(self): async def _send_single_request(session, url, payload): try: @@ -196,6 +206,49 @@ def _try_recv_mm_data(self): self.recv_req.input_ids = mm_inputs["input_ids"] self.ready = True + def start_receiving(self): + if self._started: + return + + self.thread = threading.Thread( + target=self._receive_loop, name=f"recv-{self.rid}", daemon=True + ) + self.thread.start() + self._started = True + + def _receive_loop(self): + torch.get_device_module(self.device).set_device(self.gpu_id) + while not self.ready and self.error is None: + try: + self._try_recv_mm_data() + if not self.ready: + time.sleep(0.001) # 1ms + except Exception as e: + logger.exception(f"[{self.rid}] Error") + self.error = str(e) + break + + def cleanup(self): + if self._cleaned: + return + + logger.debug(f"[{self.rid}] Cleaning up...") + + try: + if hasattr(self, "recv_socket") and self.recv_socket: + self.recv_socket.close() + except Exception as e: + logger.warning(f"[{self.rid}] Error closing socket: {e}") + + try: + if hasattr(self, "context") and self.context: + self.context.term() + except Exception as e: + logger.warning(f"[{self.rid}] Error terminating context: {e}") + + self._cleaned = True + logger.debug(f"[{self.rid}] Cleanup done") + def 
_determine_tensor_transport_mode(server_args): is_cross_node = server_args.dist_init_addr @@ -216,6 +269,8 @@ def __init__( hf_config=None, pp_rank=None, tp_rank=None, + device=None, + gpu_id=None, ): self.context = zmq.asyncio.Context(20) self.mm_transfer_backend = server_args.mm_transfer_backend @@ -237,6 +292,8 @@ def __init__( self.nnodes = server_args.nnodes self.hostname = get_local_ip_auto() self.world_size = server_args.pp_size * server_args.tp_size + self.device = device + self.gpu_id = gpu_id self.waiting_list: List[WaitingImageRequest] = [] if hf_config is not None: transport_mode = _determine_tensor_transport_mode(server_args) @@ -291,10 +348,13 @@ def process_waiting_requests(self, recv_reqs): host_name=self.hostname, receive_count=self.world_size, embedding_port=embedding_port, + device=self.device, + gpu_id=self.gpu_id, ) if recv_req.embedding_ports is None: waiting_req.send_encode_request() self.waiting_list.append(waiting_req) + waiting_req.start_receiving() else: new_recv_reqs.append(recv_req) @@ -303,7 +363,6 @@ def process_waiting_requests(self, recv_reqs): local_status = [] for waiting_req in self.waiting_list: - waiting_req._try_recv_mm_data() local_status.append(waiting_req.ready) local_status = torch.tensor(local_status, device="cuda", dtype=torch.int32) @@ -314,6 +373,7 @@ def process_waiting_requests(self, recv_reqs): for i, waiting_req in enumerate(self.waiting_list): if local_status[i].item(): new_recv_reqs.append(waiting_req.recv_req) + waiting_req.cleanup() else: new_waiting.append(waiting_req) diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py index cdcffe5b2210..53fc59e5a06c 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -4,7 +4,6 @@ import multiprocessing as mp import os import pickle -import sys import time import traceback from typing import Dict, List, Optional, Set, Tuple @@ -62,16 +61,12 @@ 
def __init__(self, tensor): self.shape = list(tensor.shape) self.dtype = tensor.dtype - # Create buffer view based on Python version - if sys.version_info >= (3, 12): - data_ptr = tensor.data_ptr() - total_bytes = tensor.numel() * tensor.element_size() - self._buffer = memoryview( - (ctypes.c_char * total_bytes).from_address(data_ptr) - ) - else: - # For Python 3.10, just use numpy - it already supports buffer protocol - self._buffer = np.asarray(tensor) + def __buffer__(self): + data_ptr = self.tensor.data_ptr() + total_bytes = self.tensor.numel() * self.tensor.element_size() + c_obj = (ctypes.c_char * total_bytes).from_address(data_ptr) + c_obj._keep_alive_ref = self + return memoryview(c_obj) def _convert(data): @@ -272,7 +267,9 @@ async def _send( else: new_mm_data = mm_data.copy_without_embedding() embedding_tensor = TensorWrapper(mm_data.embedding) - socket.send_multipart([pickle.dumps(new_mm_data), embedding_tensor._buffer]) + socket.send_multipart( + [pickle.dumps(new_mm_data), embedding_tensor.__buffer__()] + ) async def encode(self, mm_items, req_id, num_parts, part_idx): start_time = time.time() diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 83e9e3d9b0e6..22ce7efee479 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -551,6 +551,8 @@ def __init__( hf_config=self.model_config.hf_config, tp_rank=self.tp_rank, pp_rank=self.pp_rank, + device=self.device, + gpu_id=self.gpu_id, ) # Init request dispatcher From 438120121f7c3228d5216ab238ff742a3409f5e3 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Mon, 1 Dec 2025 15:06:58 +0800 Subject: [PATCH 57/68] Fix port and transfer backend (#5) --- .../srt/disaggregation/encode_receiver.py | 5 +-- .../sglang/srt/managers/tokenizer_manager.py | 3 +- python/sglang/srt/utils/common.py | 33 ------------------- 3 files changed, 5 insertions(+), 36 deletions(-) diff --git 
a/python/sglang/srt/disaggregation/encode_receiver.py b/python/sglang/srt/disaggregation/encode_receiver.py index 65743210a89d..fea377cb4084 100644 --- a/python/sglang/srt/disaggregation/encode_receiver.py +++ b/python/sglang/srt/disaggregation/encode_receiver.py @@ -17,7 +17,6 @@ from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors from sglang.srt.server_args import ServerArgs from sglang.srt.utils import get_free_port, get_local_ip_auto, get_zmq_socket -from sglang.srt.utils.common import get_multi_free_port from sglang.srt.utils.hf_transformers_utils import get_processor logger = logging.getLogger(__name__) @@ -529,7 +528,9 @@ def send_encode_requset(self, obj): (idx + len(image_urls)) // len(self.encode_urls) for idx in encode_idx ] obj.embedding_ports = ( - get_multi_free_port(self.world_size) if self.nnodes == 1 else None + [get_free_port() for _ in range(self.world_size)] + if self.nnodes == 1 + else None ) encode_thread = threading.Thread( target=self._run_encode_in_thread, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index ccfb23b0ac70..b77c717acb39 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -728,7 +728,8 @@ async def _tokenize_one_request( if ( not self.server_args.language_only - or self.server_args.mm_transfer_backend != "zmq_to_scheduler" + or self.server_args.mm_transfer_backend + in ["zmq_to_tokenizer", "mooncake"] ): if self.server_args.language_only: mm_inputs = await self.mm_receiver.recv_mm_data( diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 320d35824af7..6fa0b2404ba0 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -713,39 +713,6 @@ def is_port_available(port): return False -def get_multi_free_port(count, min_port=10000, max_port=65535, max_attempts=1000): - ports = set() - attempts = 0 - 
- while len(ports) < count and attempts < max_attempts: - attempts += 1 - - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) - port = s.getsockname()[1] - - if min_port <= port <= max_port: - ports.add(port) - except OSError: - try: - with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: - s.bind(("", 0)) - port = s.getsockname()[1] - - if min_port <= port <= max_port: - ports.add(port) - except OSError: - continue - - if len(ports) < count: - raise RuntimeError( - f"Could not find {count} free ports after {max_attempts} attempts" - ) - - return list(ports) - - def get_free_port(): # try ipv4 try: From 355e6e88a905f9a43a654be3eb8853e8f175bda2 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Tue, 2 Dec 2025 22:48:18 +0800 Subject: [PATCH 58/68] Fix comments (#6) * Replace type to isinstance * Check --encode-urls * Add async lock for rid * Move thread logic into mm_receiver --- .../srt/disaggregation/encode_receiver.py | 62 +------------------ .../srt/disaggregation/encode_server.py | 33 +++++----- python/sglang/srt/managers/scheduler.py | 2 - python/sglang/srt/server_args.py | 5 ++ 4 files changed, 24 insertions(+), 78 deletions(-) diff --git a/python/sglang/srt/disaggregation/encode_receiver.py b/python/sglang/srt/disaggregation/encode_receiver.py index fea377cb4084..a507ca71571f 100644 --- a/python/sglang/srt/disaggregation/encode_receiver.py +++ b/python/sglang/srt/disaggregation/encode_receiver.py @@ -3,7 +3,6 @@ import pickle import random import threading -import time import uuid from typing import List @@ -91,8 +90,6 @@ def __init__( host_name, receive_count, embedding_port=None, - device=None, - gpu_id=None, ): self.rid = rid self.recv_req = recv_req @@ -114,13 +111,6 @@ def __init__( self.recv_embedding_data = None self.ready = False - self.device = device - self.gpu_id = gpu_id - - self.thread = None - self._started = False - self._cleaned = False - def send_encode_request(self): async def 
_send_single_request(session, url, payload): try: @@ -205,49 +195,6 @@ def _try_recv_mm_data(self): self.recv_req.input_ids = mm_inputs["input_ids"] self.ready = True - def start_receiving(self): - if self._started: - return - - self.thread = threading.Thread( - target=self._receive_loop, name=f"recv-{self.rid}", daemon=True - ) - self.thread.start() - self._started = True - - def _receive_loop(self): - torch.get_device_module(self.device).set_device(self.gpu_id) - while not self.ready and self.error is None: - try: - self._try_recv_mm_data() - if not self.ready: - time.sleep(0.001) # 1ms - except Exception as e: - logger.exception(f"[{self.rid}] Error") - self.error = str(e) - break - - def cleanup(self): - if self._cleaned: - return - - logger.debug(f"[{self.rid}] Cleaning up...") - - try: - if hasattr(self, "recv_socket") and self.recv_socket: - self.recv_socket.close() - except Exception as e: - logger.warning(f"[{self.rid}] Error closing socket: {e}") - - try: - if hasattr(self, "context") and self.context: - self.context.term() - except Exception as e: - logger.warning(f"[{self.rid}] Error terminating context: {e}") - - self._cleaned = True - logger.debug(f"[{self.rid}] Cleanup done") - def _determine_tensor_transport_mode(server_args): is_cross_node = server_args.dist_init_addr @@ -268,8 +215,6 @@ def __init__( hf_config=None, pp_rank=None, tp_rank=None, - device=None, - gpu_id=None, ): self.context = zmq.asyncio.Context(20) self.mm_transfer_backend = server_args.mm_transfer_backend @@ -291,8 +236,6 @@ def __init__( self.nnodes = server_args.nnodes self.hostname = get_local_ip_auto() self.world_size = server_args.pp_size * server_args.tp_size - self.device = device - self.gpu_id = gpu_id self.waiting_list: List[WaitingImageRequest] = [] if hf_config is not None: transport_mode = _determine_tensor_transport_mode(server_args) @@ -347,13 +290,10 @@ def process_waiting_requests(self, recv_reqs): host_name=self.hostname, receive_count=self.world_size, 
embedding_port=embedding_port, - device=self.device, - gpu_id=self.gpu_id, ) if recv_req.embedding_ports is None: waiting_req.send_encode_request() self.waiting_list.append(waiting_req) - waiting_req.start_receiving() else: new_recv_reqs.append(recv_req) @@ -362,6 +302,7 @@ def process_waiting_requests(self, recv_reqs): local_status = [] for waiting_req in self.waiting_list: + waiting_req._try_recv_mm_data() local_status.append(waiting_req.ready) local_status = torch.tensor(local_status, device="cuda", dtype=torch.int32) @@ -372,7 +313,6 @@ def process_waiting_requests(self, recv_reqs): for i, waiting_req in enumerate(self.waiting_list): if local_status[i].item(): new_recv_reqs.append(waiting_req.recv_req) - waiting_req.cleanup() else: new_waiting.append(waiting_req) diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py index 53fc59e5a06c..055a827c8d2e 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -41,9 +41,9 @@ logger = logging.getLogger(__name__) +rid_lock = asyncio.Lock() rid_to_receive_endpoint: Dict[str, List[str]] = dict() rid_to_receive_count: Dict[str, int] = dict() -rid_to_ready_event: Dict[str, asyncio.Event] = dict() class TensorWrapper: @@ -70,13 +70,13 @@ def __buffer__(self): def _convert(data): - if type(data) == torch.Tensor: + if isinstance(data, torch.Tensor): return data - elif type(data) == np.ndarray: + elif isinstance(data, np.ndarray): return torch.tensor(data) - elif type(data) == list and type(data[0]) == np.ndarray: + elif isinstance(data, list) and isinstance(data[0], np.ndarray): return torch.tensor(np.array(data)) - elif type(data) == list and type(data[0]) in [int, float]: + elif isinstance(data, list) and isinstance(data[0], (int, float)): return torch.tensor(data) else: return data @@ -316,8 +316,9 @@ async def send_with_url( try: while True: - current_targets = rid_to_receive_endpoint.get(req_id, 
set()).copy() - expected_count = rid_to_receive_count.get(req_id) + with rid_lock: + current_targets = rid_to_receive_endpoint.get(req_id, set()).copy() + expected_count = rid_to_receive_count.get(req_id) new_targets = current_targets - sent_urls @@ -368,8 +369,9 @@ async def send_with_url( finally: logger.info(f"Cleaning up resources for req_id {req_id}") - rid_to_receive_endpoint.pop(req_id, None) - rid_to_receive_count.pop(req_id, None) + with rid_lock: + rid_to_receive_endpoint.pop(req_id, None) + rid_to_receive_count.pop(req_id, None) self.embedding_to_send.pop(req_id, None) async def get_embedding_port(self, prefill_url): @@ -506,9 +508,10 @@ async def handle_send_request(request: dict): @app.post("/scheduler_receive_url") async def handle_scheduler_receive_url_request(request: dict): rid = request["req_id"] - global rid_to_receive_endpoint - if rid not in rid_to_receive_endpoint: - rid_to_receive_endpoint[rid] = set() - rid_to_receive_count[rid] = request["receive_count"] - assert rid_to_receive_count[rid] == request["receive_count"] - rid_to_receive_endpoint[rid].add(request["receive_url"]) + with rid_lock: + global rid_to_receive_endpoint + if rid not in rid_to_receive_endpoint: + rid_to_receive_endpoint[rid] = set() + rid_to_receive_count[rid] = request["receive_count"] + assert rid_to_receive_count[rid] == request["receive_count"] + rid_to_receive_endpoint[rid].add(request["receive_url"]) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 22ce7efee479..83e9e3d9b0e6 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -551,8 +551,6 @@ def __init__( hf_config=self.model_config.hf_config, tp_rank=self.tp_rank, pp_rank=self.pp_rank, - device=self.device, - gpu_id=self.gpu_id, ) # Init request dispatcher diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 9ecb6016d943..7f9c5bee6b80 100644 --- a/python/sglang/srt/server_args.py 
+++ b/python/sglang/srt/server_args.py @@ -1879,6 +1879,11 @@ def _handle_e_disaggregation(self): ): raise ValueError("zmq_to_scheduler not support pp_size > 1") + if self.language_only and len(self.encode_urls) == 0: + raise ValueError( + "--language-only need to specify at least one --encode-urls" + ) + def _handle_pd_disaggregation(self): if self.disaggregation_mode == "decode": assert ( From eaaf840e4872764eb23fbaba5c9246c07cabad5a Mon Sep 17 00:00:00 2001 From: Zheng Wengang Date: Wed, 3 Dec 2025 23:13:07 +0800 Subject: [PATCH 59/68] feat: add health/health_generate (#7) --- python/sglang/srt/disaggregation/encode_server.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py index 055a827c8d2e..3661d592bb60 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -15,7 +15,7 @@ import zmq import zmq.asyncio from fastapi import FastAPI -from fastapi.responses import ORJSONResponse +from fastapi.responses import ORJSONResponse, Response from transformers import AutoImageProcessor from transformers.image_utils import load_images @@ -515,3 +515,15 @@ async def handle_scheduler_receive_url_request(request: dict): rid_to_receive_count[rid] = request["receive_count"] assert rid_to_receive_count[rid] == request["receive_count"] rid_to_receive_endpoint[rid].add(request["receive_url"]) + + +@app.get("/health") +@app.get("/health_generate") +async def health_generate(): + """ + Health check endpoint for the encoder server. + Returns 200 if the encoder is initialized and ready. 
+ """ + if encoder is None: + return Response(status_code=503) + return Response(status_code=200) From 02d5dfccc7766f83d9e1bb8dd4f26dd13bc59443 Mon Sep 17 00:00:00 2001 From: Zheng Wengang Date: Thu, 4 Dec 2025 00:47:14 +0800 Subject: [PATCH 60/68] clean code for mm_cache && add para-check (#8) --- python/sglang/srt/disaggregation/encode_server.py | 2 -- python/sglang/srt/server_args.py | 6 +++++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py index 3661d592bb60..4d2cdca59060 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -208,8 +208,6 @@ async def _encode(self, mm_items) -> torch.Tensor: mm_cache = self.mm_cache.get([mm_item.hash]) if mm_cache is not None: mm_embedding = mm_cache - if mm_cache is not None: - mm_embedding = mm_cache if mm_embedding is None: with torch.inference_mode(): diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 7f9c5bee6b80..e78a14a5ba03 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1866,6 +1866,10 @@ def _handle_load_format(self): self.load_format = "auto" def _handle_e_disaggregation(self): + if self.enable_prefix_mm_cache and not self.mm_only: + raise ValueError( + "--enable-prefix-mm-cache requires --mm-only to be enabled" + ) if self.mm_only and self.language_only: raise ValueError("Cannot set --mm-only and --language-only together") if self.mm_only and not self.disaggregation_mode == "null": @@ -4013,7 +4017,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--enable-prefix-mm-cache", action="store_true", default=ServerArgs.enable_prefix_mm_cache, - help="Enable prefix multimodal cache.", + help="Enable prefix multimodal cache. 
Currently only supports mm-only.", ) # For registering hooks From 4c74efc6d49d4cd5a6862500e9a751a5b1976287 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Fri, 5 Dec 2025 21:30:08 +0800 Subject: [PATCH 61/68] Fix comments (#10) --- python/sglang/launch_server.py | 4 +- python/sglang/srt/configs/model_config.py | 6 +-- .../srt/disaggregation/encode_receiver.py | 14 +++--- .../srt/disaggregation/encode_server.py | 14 +++--- python/sglang/srt/managers/scheduler.py | 4 +- .../sglang/srt/managers/tokenizer_manager.py | 4 +- python/sglang/srt/models/dots_vlm.py | 4 +- python/sglang/srt/models/qwen2_5_vl.py | 4 +- python/sglang/srt/models/qwen3_vl.py | 12 ++--- python/sglang/srt/models/qwen3_vl_moe.py | 4 +- .../multimodal/processors/base_processor.py | 10 ++-- .../srt/multimodal/processors/qwen_vl.py | 2 +- python/sglang/srt/server_args.py | 48 +++++++++---------- 13 files changed, 62 insertions(+), 68 deletions(-) diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py index 831988eca6f9..906fb6e15853 100644 --- a/python/sglang/launch_server.py +++ b/python/sglang/launch_server.py @@ -9,12 +9,12 @@ def run_server(server_args): - """Run the server based on server_args.grpc_mode and server_args.mm_only.""" + """Run the server based on server_args.grpc_mode and server_args.encoder_only.""" if server_args.grpc_mode: from sglang.srt.entrypoints.grpc_server import serve_grpc asyncio.run(serve_grpc(server_args)) - elif server_args.mm_only: + elif server_args.encoder_only: from sglang.srt.disaggregation.encode_server import launch_server launch_server(server_args) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 16e909ad5393..8926269b64a7 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -100,7 +100,7 @@ def __init__( model_impl: Union[str, ModelImpl] = ModelImpl.AUTO, sampling_defaults: str = "openai", quantize_and_serve: bool = False, - mm_only: 
bool = False, + encoder_only: bool = False, language_only: bool = False, ) -> None: # Parse args @@ -218,7 +218,7 @@ def __init__( self.hf_config, "image_token_id", None ) or getattr(self.hf_config, "image_token_index", None) - self.hf_config.mm_only = mm_only + self.hf_config.encoder_only = encoder_only self.hf_config.language_only = language_only # matryoshka embeddings @@ -252,7 +252,7 @@ def from_server_args( quantize_and_serve=server_args.quantize_and_serve, override_config_file=server_args.decrypted_config_file, language_only=server_args.language_only, - mm_only=server_args.mm_only, + encoder_only=server_args.encoder_only, **kwargs, ) diff --git a/python/sglang/srt/disaggregation/encode_receiver.py b/python/sglang/srt/disaggregation/encode_receiver.py index a507ca71571f..0742535f68e8 100644 --- a/python/sglang/srt/disaggregation/encode_receiver.py +++ b/python/sglang/srt/disaggregation/encode_receiver.py @@ -217,11 +217,11 @@ def __init__( tp_rank=None, ): self.context = zmq.asyncio.Context(20) - self.mm_transfer_backend = server_args.mm_transfer_backend - self.encode_urls = server_args.encode_urls + self.encoder_transfer_backend = server_args.encoder_transfer_backend + self.encode_urls = server_args.encoder_urls self.encode_idx = list(range(len(self.encode_urls))) self.host = server_args.host - if self.mm_transfer_backend == "mooncake": + if self.encoder_transfer_backend == "mooncake": self.dtype = dtype self.embeddings_engine = MooncakeTransferEngine( hostname=get_local_ip_auto(), @@ -229,7 +229,7 @@ def __init__( ib_device=server_args.disaggregation_ib_device, ) self.embeddings_buffer = dict() - elif self.mm_transfer_backend == "zmq_to_scheduler": + elif self.encoder_transfer_backend == "zmq_to_scheduler": self.pp_rank = pp_rank self.tp_rank = tp_rank self.tp_size = server_args.tp_size @@ -530,7 +530,7 @@ async def _recv_mm_data(self, req_id, embedding_port, mm_processor, prompt): recv_obj: EmbeddingData = pickle.loads(parts[0]) logger.info(f"{recv_obj = 
}") - if self.mm_transfer_backend == "zmq_to_tokenizer": + if self.encoder_transfer_backend == "zmq_to_tokenizer": buffer = parts[1].buffer if hasattr(parts[1], "buffer") else parts[1] recv_obj.embedding = torch.frombuffer( buffer, dtype=recv_obj.dtype @@ -541,11 +541,11 @@ async def _recv_mm_data(self, req_id, embedding_port, mm_processor, prompt): else: recv_embedding_data.add(recv_obj) - if self.mm_transfer_backend == "mooncake": + if self.encoder_transfer_backend == "mooncake": recv_embedding = self.embeddings_buffer[req_id] del self.embeddings_buffer[req_id] self.embeddings_engine.deregister(recv_embedding.data_ptr()) - elif self.mm_transfer_backend == "zmq_to_tokenizer": + elif self.encoder_transfer_backend == "zmq_to_tokenizer": recv_embedding = recv_embedding_data.get_embedding(is_concat=True) img_grid_thw = recv_embedding_data.get_img_grid() diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py index 4d2cdca59060..7172e306abd4 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -164,10 +164,10 @@ def __init__( if self.rank == 0: logger.info( - f"Using transfer backend: {self.server_args.mm_transfer_backend}" + f"Using transfer backend: {self.server_args.encoder_transfer_backend}" ) - if self.server_args.mm_transfer_backend == "mooncake": + if self.server_args.encoder_transfer_backend == "mooncake": self.local_ip = get_local_ip_auto() self.engine = MooncakeTransferEngine( @@ -236,7 +236,7 @@ async def _send( embedding_port=None, url=None, ): - if self.server_args.mm_transfer_backend == "mooncake": + if self.server_args.encoder_transfer_backend == "mooncake": self.engine.register(embedding.data_ptr(), embedding.nbytes) self.engine.transfer_sync( session_id, embedding.data_ptr(), buffer_address, embedding.nbytes @@ -260,7 +260,7 @@ async def _send( False, ) - if self.server_args.mm_transfer_backend == "mooncake": + if 
self.server_args.encoder_transfer_backend == "mooncake": socket.send_multipart([pickle.dumps(mm_data)]) else: new_mm_data = mm_data.copy_without_embedding() @@ -449,7 +449,7 @@ async def handle_encode_request(request: dict): num_parts=request["num_parts"], part_idx=request["part_idx"], ) - if encoder.server_args.mm_transfer_backend == "mooncake": + if encoder.server_args.encoder_transfer_backend == "mooncake": del request["mm_items"] request.update( { @@ -459,7 +459,7 @@ async def handle_encode_request(request: dict): } ) return ORJSONResponse(content=request) - elif encoder.server_args.mm_transfer_backend == "zmq_to_scheduler": + elif encoder.server_args.encoder_transfer_backend == "zmq_to_scheduler": logger.info(f"{request['embedding_port'] = }") if request["embedding_port"] is None: await encoder.send_with_url( @@ -479,7 +479,7 @@ async def handle_encode_request(request: dict): await asyncio.gather(*tasks) encoder.embedding_to_send.pop(request["req_id"], None) return ORJSONResponse(content=None) - elif encoder.server_args.mm_transfer_backend == "zmq_to_tokenizer": + elif encoder.server_args.encoder_transfer_backend == "zmq_to_tokenizer": await encoder.send( req_id=request["req_id"], prefill_host=request["prefill_host"], diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 83e9e3d9b0e6..a1fc435d0cfe 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -544,7 +544,7 @@ def __init__( if ( self.server_args.language_only - and self.server_args.mm_transfer_backend == "zmq_to_scheduler" + and self.server_args.encoder_transfer_backend == "zmq_to_scheduler" ): self.mm_receiver = MMReceiver( server_args, @@ -1190,7 +1190,7 @@ def process_input_requests(self, recv_reqs: List): # Process MM requests under E disaggregation if ( self.server_args.language_only - and self.server_args.mm_transfer_backend == "zmq_to_scheduler" + and self.server_args.encoder_transfer_backend == 
"zmq_to_scheduler" ): recv_reqs = self.mm_receiver.process_waiting_requests(recv_reqs) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index b77c717acb39..d8cf43bb358e 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -445,7 +445,7 @@ async def generate_request( if ( self.server_args.language_only and isinstance(obj, GenerateReqInput) - and self.server_args.mm_transfer_backend == "zmq_to_scheduler" + and self.server_args.encoder_transfer_backend == "zmq_to_scheduler" ): self.mm_receiver.send_encode_requset(obj) @@ -728,7 +728,7 @@ async def _tokenize_one_request( if ( not self.server_args.language_only - or self.server_args.mm_transfer_backend + or self.server_args.encoder_transfer_backend in ["zmq_to_tokenizer", "mooncake"] ): if self.server_args.language_only: diff --git a/python/sglang/srt/models/dots_vlm.py b/python/sglang/srt/models/dots_vlm.py index a1011b6a9ebc..ea113b60c54d 100644 --- a/python/sglang/srt/models/dots_vlm.py +++ b/python/sglang/srt/models/dots_vlm.py @@ -50,7 +50,7 @@ def __init__( self.video_token_id = config.video_span_id self.pp_group = get_pp_group() - if not config.mm_only: + if not config.encoder_only: self.language_model = DeepseekV2ForCausalLM( config.language_config, quant_config ) @@ -117,7 +117,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight) # Load language model weights - if not self.config.mm_only and language_weights: + if not self.config.encoder_only and language_weights: self.language_model.load_weights(language_weights) @classmethod diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index b899eb977485..271c378c0bfa 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -477,7 +477,7 @@ def __init__( self.config = config self.use_data_parallel = 
get_global_server_args().mm_enable_dp_encoder - if not self.config.mm_only: + if not self.config.encoder_only: self.model = Qwen2Model( config, quant_config, @@ -676,7 +676,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue # Skip loading visual/language model weights if ( - self.config.mm_only or self.config.language_only + self.config.encoder_only or self.config.language_only ) and name not in params_dict: continue param = params_dict[name] diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py index 94d7b12782f1..fddb747d6765 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -617,12 +617,10 @@ def __init__( self.config: Qwen3VLConfig = config # for qwen3-vl else: self.config = config.text_config # for qwen3-omni - if hasattr(config, "mm_only"): - self.config.mm_only = config.mm_only - if hasattr(config, "language_only"): - self.config.language_only = config.language_only + self.config.encoder_only = getattr(config, "encoder_only", False) + self.config.language_only = getattr(config, "language_only", False) - if not hasattr(config, "mm_only") or not config.mm_only: + if not hasattr(config, "encoder_only") or not config.encoder_only: self.model = language_model_cls( config=self.config, quant_config=quant_config, @@ -782,7 +780,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue # Skip loading visual/language model weights if ( - self.config.mm_only or self.config.language_only + self.config.encoder_only or self.config.language_only ) and name not in params_dict: continue param = params_dict[name] @@ -801,7 +799,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue # Skip loading visual/language model weights if ( - self.config.mm_only or self.config.language_only + self.config.encoder_only or self.config.language_only ) and name not in params_dict: continue param = params_dict[name] diff --git 
a/python/sglang/srt/models/qwen3_vl_moe.py b/python/sglang/srt/models/qwen3_vl_moe.py index 181d1f17f4fe..b13a221f83fa 100644 --- a/python/sglang/srt/models/qwen3_vl_moe.py +++ b/python/sglang/srt/models/qwen3_vl_moe.py @@ -243,7 +243,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue - if "visual" in name or self.config.mm_only: + if "visual" in name or self.config.encoder_only: continue # Anyway, this is an expert weight and should not be # attempted to load as other weights later @@ -311,7 +311,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading mm/language parameters if ( - self.config.mm_only or self.config.language_only + self.config.encoder_only or self.config.language_only ) and name not in params_dict: continue diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index 71236fe7eda6..6553178e6ce8 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -235,7 +235,7 @@ def __init__( def spatial_merge_size(self): return self.hf_config.vision_config.spatial_merge_size - def get_input_ids(self, prompt, img_grid_thw): + def build_input_ids(self, prompt, img_grid_thw): """ Use prompt and img_grid_thw to build input_ids """ @@ -255,11 +255,7 @@ def get_input_ids(self, prompt, img_grid_thw): img_start_indices = list( filter(lambda i: prompt[i + 1] == img_token_id, range(len(prompt) - 1)) ) - if len(img_start_indices) != img_grid_thw.shape[0]: - logger.info(f"{len(prompt)} {prompt}") - logger.info( - f"Check img_start_indices: {img_start_indices} and img_grid_thw: {img_grid_thw.shape}" - ) + for cur_img_idx, img_start_idx in enumerate(img_start_indices): assert cur_idx <= img_start_idx # include img_start_id @@ -276,7 +272,7 @@ def get_input_ids(self, prompt, 
img_grid_thw): return input_ids, offsets def get_mm_data(self, prompt, embeddings, img_grid_thw): - input_ids, offsets = self.get_input_ids(prompt, img_grid_thw) + input_ids, offsets = self.build_input_ids(prompt, img_grid_thw) mm_items = [ MultimodalDataItem( modality=Modality.IMAGE, diff --git a/python/sglang/srt/multimodal/processors/qwen_vl.py b/python/sglang/srt/multimodal/processors/qwen_vl.py index 41eadb66673a..b107af94c274 100644 --- a/python/sglang/srt/multimodal/processors/qwen_vl.py +++ b/python/sglang/srt/multimodal/processors/qwen_vl.py @@ -260,7 +260,7 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): ).build(_processor) def get_mm_data(self, prompt, embeddings, img_grid_thw): - input_ids, offsets = self.get_input_ids(prompt, img_grid_thw) + input_ids, offsets = self.build_input_ids(prompt, img_grid_thw) mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index( spatial_merge_size=self.hf_config.vision_config.spatial_merge_size, image_token_id=self.mm_tokens.image_token_id, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index e78a14a5ba03..bda81d826eff 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -135,7 +135,7 @@ DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"] -MM_TRANSFER_BACKEND_CHOICES = ["zmq_to_scheduler", "zmq_to_tokenizer", "mooncake"] +ENCODER_TRANSFER_BACKEND_CHOICES = ["zmq_to_scheduler", "zmq_to_tokenizer", "mooncake"] GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"] @@ -257,10 +257,10 @@ class ServerArgs: checkpoint_engine_wait_weights_before_ready: bool = False # Encode prefill disaggregation - mm_only: bool = False + encoder_only: bool = False language_only: bool = False - mm_transfer_backend: str = MM_TRANSFER_BACKEND_CHOICES[0] - encode_urls: List[str] = dataclasses.field(default_factory=list) + encoder_transfer_backend: str = ENCODER_TRANSFER_BACKEND_CHOICES[0] + 
encoder_urls: List[str] = dataclasses.field(default_factory=list) # Quantization and data type dtype: str = "auto" @@ -666,8 +666,8 @@ def __post_init__(self): # Handle PD disaggregation. self._handle_pd_disaggregation() - # Handle E disaggregation. - self._handle_e_disaggregation() + # Handle Encoder disaggregation. + self._handle_encoder_disaggregation() # Validate tokenizer settings. self._handle_tokenizer_batching() @@ -1865,27 +1865,27 @@ def _handle_load_format(self): ): self.load_format = "auto" - def _handle_e_disaggregation(self): - if self.enable_prefix_mm_cache and not self.mm_only: + def _handle_encoder_disaggregation(self): + if self.enable_prefix_mm_cache and not self.encoder_only: raise ValueError( - "--enable-prefix-mm-cache requires --mm-only to be enabled" + "--enable-prefix-mm-cache requires --encoder-only to be enabled" ) - if self.mm_only and self.language_only: - raise ValueError("Cannot set --mm-only and --language-only together") - if self.mm_only and not self.disaggregation_mode == "null": + if self.encoder_only and self.language_only: + raise ValueError("Cannot set --encoder-only and --language-only together") + if self.encoder_only and not self.disaggregation_mode == "null": raise ValueError( - "Cannot set --mm-only and --disaggregation-mode prefill/decode together" + "Cannot set --encoder-only and --disaggregation-mode prefill/decode together" ) if ( self.language_only - and self.mm_transfer_backend == "zmq_to_scheduler" + and self.encoder_transfer_backend == "zmq_to_scheduler" and self.pp_size > 1 ): raise ValueError("zmq_to_scheduler not support pp_size > 1") - if self.language_only and len(self.encode_urls) == 0: + if self.language_only and len(self.encoder_urls) == 0: raise ValueError( - "--language-only need to specify at least one --encode-urls" + "requires at least one encoder urls to be set via --encoder-urls" ) def _handle_pd_disaggregation(self): @@ -2275,9 +2275,9 @@ def add_cli_args(parser: argparse.ArgumentParser): # Encode 
prefill disaggregation parser.add_argument( - "--mm-only", + "--encoder-only", action="store_true", - help="For VLM, launch encode server only for multimodal part.", + help="For MLLM with an encoder, launch an encoder-only server", ) parser.add_argument( "--language-only", @@ -2285,18 +2285,18 @@ def add_cli_args(parser: argparse.ArgumentParser): help="For VLM, load weights for the language model only.", ) parser.add_argument( - "--mm-transfer-backend", + "--encoder-transfer-backend", type=str, - default=ServerArgs.mm_transfer_backend, - choices=MM_TRANSFER_BACKEND_CHOICES, - help="The backend for encoder disaggregation transfer. Default is zmq.", + default=ServerArgs.encoder_transfer_backend, + choices=ENCODER_TRANSFER_BACKEND_CHOICES, + help="The backend for encoder disaggregation transfer. Default is zmq_to_scheduler.", ) parser.add_argument( - "--encode-urls", + "--encoder-urls", nargs="+", type=str, default=[], - help="List of encode urls for encoder disaggregation", + help="List of encoder server urls.", ) # Quantization and data type From 51afd4579878d8fb2908f2f98ed40354b0fda7fc Mon Sep 17 00:00:00 2001 From: Zheng Wengang Date: Sun, 7 Dec 2025 02:35:51 +0800 Subject: [PATCH 62/68] fix: skip init emeb_tokens && skip send_encode without mm_inputs (#11) * fix: skip embed init for mm_only mode * fix: skip send health-check-req to encoder with epd mode --- python/sglang/srt/entrypoints/http_server.py | 6 ++++- .../sglang/srt/managers/tokenizer_manager.py | 1 + python/sglang/srt/models/qwen2_5_vl.py | 25 +++++++++++-------- python/sglang/srt/models/qwen3_vl.py | 3 +++ 4 files changed, 23 insertions(+), 12 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index cf0a3784fe8c..9ca33f35fd43 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -1519,7 +1519,11 @@ def _execute_server_warmup( # TODO Workaround the bug that embedding errors for 
list of size 1 if server_args.dp_size == 1: json_data["input_ids"] = json_data["input_ids"][0] - elif is_vlm and server_args.disaggregation_mode == "null": + elif ( + is_vlm + and server_args.disaggregation_mode == "null" + and not server_args.language_only + ): # TODO: ChatCompletionRequest does not have bootstrap info required by disaggregation mode, disable image-warmup for now json_data = { "model": _global_state.tokenizer_manager.served_model_name, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index d8cf43bb358e..a77f29bec092 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -446,6 +446,7 @@ async def generate_request( self.server_args.language_only and isinstance(obj, GenerateReqInput) and self.server_args.encoder_transfer_backend == "zmq_to_scheduler" + and obj.contains_mm_input() ): self.mm_receiver.send_encode_requset(obj) diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index 271c378c0bfa..8140665618e5 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -484,19 +484,22 @@ def __init__( prefix=add_prefix("model", prefix), ) - if self.pp_group.is_last_rank: - if self.pp_group.world_size == 1 and self.config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens + if self.pp_group.is_last_rank: + if self.pp_group.world_size == 1 and self.config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + ) else: - self.lm_head = ParallelLMHead( - self.config.vocab_size, - self.config.hidden_size, - quant_config=quant_config, - prefix=add_prefix("lm_head", prefix), - ) + # ranks other than the last rank will have a placeholder layer + self.lm_head = PPMissingLayer() 
else: - # ranks other than the last rank will have a placeholder layer - self.lm_head = PPMissingLayer() + # mm_only mode: no language model, so no lm_head needed + self.lm_head = None self.visual = Qwen2_5_VisionTransformer( config.vision_config, diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py index fddb747d6765..95f5f2b5b32b 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -636,6 +636,9 @@ def __init__( quant_config=quant_config, prefix=add_prefix("lm_head", prefix), ) + else: + # mm_only mode: no language model, so no lm_head needed + self.lm_head = None self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling self.logits_processor = LogitsProcessor(self.config) From c47bfce600ed996ca6c61518970a0550b95cef35 Mon Sep 17 00:00:00 2001 From: Zheng Wengang Date: Sun, 7 Dec 2025 11:09:28 +0800 Subject: [PATCH 63/68] [CI]: add EPD disaggregation integration tests (#9) * test: add EPD disaggregation integration tests * fix comment for encoder-only * revert http_server warmup for vlm --- python/sglang/srt/entrypoints/http_server.py | 6 +- python/sglang/srt/models/qwen2_5_vl.py | 2 +- python/sglang/srt/models/qwen3_vl.py | 2 +- test/srt/run_suite.py | 1 + test/srt/test_epd_disaggregation.py | 390 +++++++++++++++++++ 5 files changed, 394 insertions(+), 7 deletions(-) create mode 100644 test/srt/test_epd_disaggregation.py diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 9ca33f35fd43..cf0a3784fe8c 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -1519,11 +1519,7 @@ def _execute_server_warmup( # TODO Workaround the bug that embedding errors for list of size 1 if server_args.dp_size == 1: json_data["input_ids"] = json_data["input_ids"][0] - elif ( - is_vlm - and server_args.disaggregation_mode == "null" - and not server_args.language_only - ): + elif is_vlm 
and server_args.disaggregation_mode == "null": # TODO: ChatCompletionRequest does not have bootstrap info required by disaggregation mode, disable image-warmup for now json_data = { "model": _global_state.tokenizer_manager.served_model_name, diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index 8140665618e5..5db43cc97637 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -498,7 +498,7 @@ def __init__( # ranks other than the last rank will have a placeholder layer self.lm_head = PPMissingLayer() else: - # mm_only mode: no language model, so no lm_head needed + # encoder_only mode: no language model, so no lm_head needed self.lm_head = None self.visual = Qwen2_5_VisionTransformer( diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py index 95f5f2b5b32b..b49e03ccd6b5 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -637,7 +637,7 @@ def __init__( prefix=add_prefix("lm_head", prefix), ) else: - # mm_only mode: no language model, so no lm_head needed + # encoder_only mode: no language model, so no lm_head needed self.lm_head = None self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index cd6a95add1a8..c38d0d4cf060 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -157,6 +157,7 @@ TestFile("test_multi_instance_release_memory_occupation.py", 64), TestFile("test_pp_single_node.py", 481), TestFile("test_piecewise_cuda_graph.py", 1200), + TestFile("test_epd_disaggregation.py", 600), ], "per-commit-8-gpu-h200": [ TestFile("test_deepseek_v3_basic.py", 275), diff --git a/test/srt/test_epd_disaggregation.py b/test/srt/test_epd_disaggregation.py new file mode 100644 index 000000000000..40e4a9b55d10 --- /dev/null +++ b/test/srt/test_epd_disaggregation.py @@ -0,0 +1,390 @@ +import os +import subprocess +import unittest 
+ +from sglang.srt.utils import kill_process_tree +from sglang.test.test_disaggregation_utils import TestDisaggregationBase +from sglang.test.test_utils import ( + DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + popen_launch_server, +) + + +class TestEPDDisaggregationOneEncoder(TestDisaggregationBase): + """Test EPD disaggregation with single encode server""" + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.model = DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST + cls.encode_port = f"{int(cls.lb_port) + 300}" + cls.encode_url = f"http://{cls.base_host}:{cls.encode_port}" + + print( + f"Setting up EPD (one encoder): encode={cls.encode_port}, " + f"prefill={cls.prefill_port}, decode={cls.decode_port}" + ) + + # Start servers in order: encode -> prefill -> decode + cls.start_encode() + cls.start_prefill() + cls.start_decode() + + # Wait for all servers to be ready + cls.wait_server_ready(cls.encode_url + "/health") + cls.wait_server_ready(cls.prefill_url + "/health") + cls.wait_server_ready(cls.decode_url + "/health") + + cls.launch_lb() + + # Set OpenAI API key and base URL environment variables. Needed for lmms-eval to work. 
+ cls.api_key = "sk-123456" + os.environ["OPENAI_API_KEY"] = cls.api_key + os.environ["OPENAI_API_BASE"] = f"{cls.lb_url}/v1" + + @classmethod + def start_encode(cls): + """Start encode server for multimodal processing""" + encode_args = [ + "--trust-remote-code", + "--encoder-only", + "--encoder-transfer-backend", + "zmq_to_scheduler", + "--tp", + "1", + "--port", + cls.encode_port, + "--enable-prefix-mm-cache", + ] + cls.process_encode = popen_launch_server( + cls.model, + base_url=cls.encode_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=encode_args, + ) + + @classmethod + def start_prefill(cls): + """Start prefill server with language model only""" + prefill_args = [ + "--trust-remote-code", + "--language-only", + "--encoder-urls", + cls.encode_url, + "--encoder-transfer-backend", + "zmq_to_scheduler", + "--disaggregation-mode", + "prefill", + "--tp", + "1", + "--base-gpu-id", + "1", + "--port", + cls.prefill_port, + ] + prefill_args += cls.transfer_backend + cls.rdma_devices + cls.process_prefill = popen_launch_server( + cls.model, + base_url=cls.prefill_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=prefill_args, + ) + + @classmethod + def start_decode(cls): + """Start decode server""" + decode_args = [ + "--trust-remote-code", + "--disaggregation-mode", + "decode", + "--tp", + "1", + "--base-gpu-id", + "2", + "--port", + cls.decode_port, + ] + decode_args += cls.transfer_backend + cls.rdma_devices + cls.process_decode = popen_launch_server( + cls.model, + base_url=cls.decode_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=decode_args, + ) + + @classmethod + def tearDownClass(cls): + """Clean up all processes""" + for process in [ + cls.process_lb, + cls.process_decode, + cls.process_prefill, + cls.process_encode, + ]: + if process: + try: + kill_process_tree(process.pid) + except Exception as e: + print(f"Error killing process: {e}") + + def run_mmmu_eval(self, model_version: str, output_path: str): + """ + 
Evaluate a VLM on the MMMU validation set with lmms-eval. + Reference: test_vlm_models.py + """ + model = "openai_compatible" + tp = 1 + tasks = "mmmu_val" + batch_size = 32 + log_suffix = "openai_compatible" + os.makedirs(output_path, exist_ok=True) + + model_args = f'model_version="{model_version}",' f"tp={tp}" + + cmd = [ + "python3", + "-m", + "lmms_eval", + "--model", + model, + "--model_args", + model_args, + "--tasks", + tasks, + "--batch_size", + str(batch_size), + "--log_samples", + "--log_samples_suffix", + log_suffix, + "--output_path", + str(output_path), + ] + + subprocess.run(cmd, check=True, timeout=3600) + + def test_mmmu(self): + """Test MMMU evaluation with EPD disaggregation""" + import glob + import json + + output_path = "./logs/epd_one_encoder_mmmu" + self.run_mmmu_eval(self.model, output_path) + + # Get the result file + result_files = glob.glob(f"{output_path}/**/*.json", recursive=True) + if not result_files: + result_files = glob.glob(f"{output_path}/*.json") + + if not result_files: + self.fail(f"No JSON result files found in {output_path}") + + result_file_path = result_files[0] + with open(result_file_path, "r") as f: + result = json.load(f) + print(f"MMMU result: {result}") + + mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"] + print(f"MMMU accuracy: {mmmu_accuracy:.4f}") + + # for qwen2.5-vl-3b-instruct, the accuracy is 0.40 + self.assertGreater(mmmu_accuracy, 0.40) + + +class TestEPDDisaggregationMultiEncoders(TestDisaggregationBase): + """ + Test EPD disaggregation with multiple encode servers for load balancing. + Both encode servers run on GPU 0 (different ports) for testing load distribution. 
+ """ + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.model = DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST + cls.encode_port1 = f"{int(cls.lb_port) + 300}" + cls.encode_port2 = f"{int(cls.lb_port) + 301}" + cls.encode_url1 = f"http://{cls.base_host}:{cls.encode_port1}" + cls.encode_url2 = f"http://{cls.base_host}:{cls.encode_port2}" + + print( + f"Setting up EPD (multiple encoders): encode1={cls.encode_port1}, " + f"encode2={cls.encode_port2}, prefill={cls.prefill_port}, decode={cls.decode_port}" + ) + + # Start two encode servers on same GPU (GPU 0) + cls.start_encode_server(cls.encode_port1, 0) + cls.start_encode_server(cls.encode_port2, 0) + cls.start_prefill() + cls.start_decode() + + cls.wait_server_ready(cls.encode_url1 + "/health") + cls.wait_server_ready(cls.encode_url2 + "/health") + cls.wait_server_ready(cls.prefill_url + "/health") + cls.wait_server_ready(cls.decode_url + "/health") + + cls.launch_lb() + + # Set OpenAI API key and base URL environment variables. Needed for lmms-eval to work. 
+ cls.api_key = "sk-123456" + os.environ["OPENAI_API_KEY"] = cls.api_key + os.environ["OPENAI_API_BASE"] = f"{cls.lb_url}/v1" + + @classmethod + def start_encode_server(cls, port, gpu_id): + """Start an encode server on specific port and GPU""" + encode_args = [ + "--trust-remote-code", + "--encoder-only", + "--encoder-transfer-backend", + "zmq_to_scheduler", + "--tp", + "1", + "--port", + port, + "--enable-prefix-mm-cache", + ] + # Only set base-gpu-id if not using GPU 0 + if gpu_id != 0: + encode_args.extend(["--base-gpu-id", str(gpu_id)]) + + process = popen_launch_server( + cls.model, + base_url=f"http://{cls.base_host}:{port}", + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=encode_args, + ) + if port == cls.encode_port1: + cls.process_encode1 = process + else: + cls.process_encode2 = process + + @classmethod + def start_prefill(cls): + """Start prefill server with multiple encode URLs""" + prefill_args = [ + "--trust-remote-code", + "--language-only", + "--encoder-urls", + cls.encode_url1, + cls.encode_url2, + "--encoder-transfer-backend", + "zmq_to_scheduler", + "--disaggregation-mode", + "prefill", + "--tp", + "1", + "--base-gpu-id", + "2", + "--port", + cls.prefill_port, + ] + prefill_args += cls.transfer_backend + cls.rdma_devices + cls.process_prefill = popen_launch_server( + cls.model, + base_url=cls.prefill_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=prefill_args, + ) + + @classmethod + def start_decode(cls): + """Start decode server""" + decode_args = [ + "--trust-remote-code", + "--disaggregation-mode", + "decode", + "--tp", + "1", + "--base-gpu-id", + "3", + "--port", + cls.decode_port, + ] + decode_args += cls.transfer_backend + cls.rdma_devices + cls.process_decode = popen_launch_server( + cls.model, + base_url=cls.decode_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=decode_args, + ) + + @classmethod + def tearDownClass(cls): + """Clean up all processes""" + for process in [ + cls.process_lb, + 
cls.process_decode, + cls.process_prefill, + cls.process_encode1, + cls.process_encode2, + ]: + if process: + try: + kill_process_tree(process.pid) + except Exception as e: + print(f"Error killing process: {e}") + + def run_mmmu_eval(self, model_version: str, output_path: str): + """ + Evaluate a VLM on the MMMU validation set with lmms-eval. + Reference: test_vlm_models.py + """ + model = "openai_compatible" + tp = 1 + tasks = "mmmu_val" + batch_size = 32 + log_suffix = "openai_compatible" + os.makedirs(output_path, exist_ok=True) + + model_args = f'model_version="{model_version}",' f"tp={tp}" + + cmd = [ + "python3", + "-m", + "lmms_eval", + "--model", + model, + "--model_args", + model_args, + "--tasks", + tasks, + "--batch_size", + str(batch_size), + "--log_samples", + "--log_samples_suffix", + log_suffix, + "--output_path", + str(output_path), + ] + + subprocess.run(cmd, check=True, timeout=3600) + + def test_mmmu(self): + """Test MMMU evaluation with EPD disaggregation (multiple encoders)""" + import glob + import json + + output_path = "./logs/epd_multi_encoder_mmmu" + self.run_mmmu_eval(self.model, output_path) + + # Get the result file + result_files = glob.glob(f"{output_path}/**/*.json", recursive=True) + if not result_files: + result_files = glob.glob(f"{output_path}/*.json") + + if not result_files: + self.fail(f"No JSON result files found in {output_path}") + + result_file_path = result_files[0] + with open(result_file_path, "r") as f: + result = json.load(f) + print(f"MMMU result (multi encoder): {result}") + + mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"] + print(f"MMMU accuracy (multi encoder): {mmmu_accuracy:.4f}") + # for qwen2.5-vl-3b-instruct, the accuracy is 0.40 + self.assertGreater(mmmu_accuracy, 0.40) + + +if __name__ == "__main__": + unittest.main() From f3ceb9fca5c8e095bced7234b4343b62261a2665 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Wed, 10 Dec 2025 02:42:00 +0000 Subject: [PATCH 64/68] Lint --- 
python/sglang/srt/multimodal/processors/qwen_vl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/multimodal/processors/qwen_vl.py b/python/sglang/srt/multimodal/processors/qwen_vl.py index 51151cb2a453..324a9b8fb939 100644 --- a/python/sglang/srt/multimodal/processors/qwen_vl.py +++ b/python/sglang/srt/multimodal/processors/qwen_vl.py @@ -239,7 +239,7 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): self.IM_START_TOKEN_ID = hf_config.vision_start_token_id self.IM_END_TOKEN_ID = hf_config.vision_end_token_id self.IM_TOKEN_ID = hf_config.image_token_id - + self.vision_start_token_id = hf_config.vision_start_token_id self.vision_end_token_id = getattr(hf_config, "vision_end_token_id", None) From f2195dbeb32c50f86529b3c8f138ad727b6a3f8a Mon Sep 17 00:00:00 2001 From: Zheng Wengang Date: Fri, 12 Dec 2025 14:02:14 +0800 Subject: [PATCH 65/68] ut: speed up epd_dis ut (#12) --- test/srt/test_epd_disaggregation.py | 52 +++++++++++++++++++++++------ 1 file changed, 42 insertions(+), 10 deletions(-) diff --git a/test/srt/test_epd_disaggregation.py b/test/srt/test_epd_disaggregation.py index 40e4a9b55d10..51b77e575add 100644 --- a/test/srt/test_epd_disaggregation.py +++ b/test/srt/test_epd_disaggregation.py @@ -1,5 +1,6 @@ import os import subprocess +import threading import unittest from sglang.srt.utils import kill_process_tree @@ -26,10 +27,14 @@ def setUpClass(cls): f"prefill={cls.prefill_port}, decode={cls.decode_port}" ) - # Start servers in order: encode -> prefill -> decode + # Start servers in order: encode -> prefill/decode cls.start_encode() - cls.start_prefill() - cls.start_decode() + prefill_thread = threading.Thread(target=cls.start_prefill) + decode_thread = threading.Thread(target=cls.start_decode) + prefill_thread.start() + decode_thread.start() + prefill_thread.join() + decode_thread.join() # Wait for all servers to be ready cls.wait_server_ready(cls.encode_url + "/health") @@ -128,10 +133,15 @@ 
def tearDownClass(cls): except Exception as e: print(f"Error killing process: {e}") - def run_mmmu_eval(self, model_version: str, output_path: str): + def run_mmmu_eval(self, model_version: str, output_path: str, limit: str = "50"): """ Evaluate a VLM on the MMMU validation set with lmms-eval. Reference: test_vlm_models.py + + Args: + model_version: Model version/checkpoint to evaluate + output_path: Path to save evaluation results + limit: Number of samples to evaluate (default: "50" for CI time constraints) """ model = "openai_compatible" tp = 1 @@ -159,6 +169,8 @@ def run_mmmu_eval(self, model_version: str, output_path: str): log_suffix, "--output_path", str(output_path), + "--limit", + limit, ] subprocess.run(cmd, check=True, timeout=3600) @@ -211,11 +223,24 @@ def setUpClass(cls): f"encode2={cls.encode_port2}, prefill={cls.prefill_port}, decode={cls.decode_port}" ) - # Start two encode servers on same GPU (GPU 0) - cls.start_encode_server(cls.encode_port1, 0) - cls.start_encode_server(cls.encode_port2, 0) - cls.start_prefill() - cls.start_decode() + # Start two encode servers on GPU 0/1 + encode1_thread = threading.Thread( + target=cls.start_encode_server, args=(cls.encode_port1, 0) + ) + encode2_thread = threading.Thread( + target=cls.start_encode_server, args=(cls.encode_port2, 1) + ) + encode1_thread.start() + encode2_thread.start() + encode1_thread.join() + encode2_thread.join() + + prefill_thread = threading.Thread(target=cls.start_prefill) + decode_thread = threading.Thread(target=cls.start_decode) + prefill_thread.start() + decode_thread.start() + prefill_thread.join() + decode_thread.join() cls.wait_server_ready(cls.encode_url1 + "/health") cls.wait_server_ready(cls.encode_url2 + "/health") @@ -324,10 +349,15 @@ def tearDownClass(cls): except Exception as e: print(f"Error killing process: {e}") - def run_mmmu_eval(self, model_version: str, output_path: str): + def run_mmmu_eval(self, model_version: str, output_path: str, limit: str = "50"): """ 
Evaluate a VLM on the MMMU validation set with lmms-eval. Reference: test_vlm_models.py + + Args: + model_version: Model version/checkpoint to evaluate + output_path: Path to save evaluation results + limit: Number of samples to evaluate (default: "50" for CI time constraints) """ model = "openai_compatible" tp = 1 @@ -355,6 +385,8 @@ def run_mmmu_eval(self, model_version: str, output_path: str): log_suffix, "--output_path", str(output_path), + "--limit", + limit, ] subprocess.run(cmd, check=True, timeout=3600) From b105ec1577d5c9128db7639b7c188644f0a3644b Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Sat, 13 Dec 2025 21:39:32 +0800 Subject: [PATCH 66/68] Fix port allocation and other lints (#13) --- .../srt/disaggregation/encode_receiver.py | 42 +++++++------------ .../srt/disaggregation/encode_server.py | 6 +-- .../sglang/srt/managers/tokenizer_manager.py | 2 +- 3 files changed, 20 insertions(+), 30 deletions(-) diff --git a/python/sglang/srt/disaggregation/encode_receiver.py b/python/sglang/srt/disaggregation/encode_receiver.py index 0742535f68e8..c4938213909b 100644 --- a/python/sglang/srt/disaggregation/encode_receiver.py +++ b/python/sglang/srt/disaggregation/encode_receiver.py @@ -15,7 +15,7 @@ from sglang.srt.managers.io_struct import TokenizedGenerateReqInput from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import get_free_port, get_local_ip_auto, get_zmq_socket +from sglang.srt.utils import get_local_ip_auto, get_zmq_socket_on_host from sglang.srt.utils.hf_transformers_utils import get_processor logger = logging.getLogger(__name__) @@ -86,7 +86,7 @@ def __init__( rid: str, recv_req: TokenizedGenerateReqInput, mm_processor, - image_urls, + encoder_urls, host_name, receive_count, embedding_port=None, @@ -97,16 +97,13 @@ def __init__( self.error = None self.thread = None self.mm_processor = mm_processor - self.image_urls = image_urls + 
self.encoder_urls = encoder_urls self.host_name = host_name self.receive_count = receive_count self.num_items_assigned = recv_req.num_items_assigned - self.embedding_port = ( - get_free_port() if embedding_port is None else embedding_port + self.embedding_port, self.recv_socket = get_zmq_socket_on_host( + zmq.Context(), zmq.PULL ) - self.context = zmq.Context() - self.recv_socket = self.context.socket(zmq.PULL) - self.recv_socket.bind(f"tcp://*:{self.embedding_port}") logger.info(f"Waiting for input {self.embedding_port = }") self.recv_embedding_data = None self.ready = False @@ -130,8 +127,8 @@ async def send_embedding_port(req_id, receive_count, host_name, embedding_port): for idx, assigned_num in enumerate(self.num_items_assigned): if assigned_num == 0: continue - image_url = self.image_urls[idx] - target_url = f"{image_url}/scheduler_receive_url" + encoder_url = self.encoder_urls[idx] + target_url = f"{encoder_url}/scheduler_receive_url" payload = { "req_id": req_id, "receive_count": receive_count, @@ -194,6 +191,7 @@ def _try_recv_mm_data(self): self.recv_req.mm_inputs = mm_inputs self.recv_req.input_ids = mm_inputs["input_ids"] self.ready = True + self.recv_socket.close() def _determine_tensor_transport_mode(server_args): @@ -286,7 +284,7 @@ def process_waiting_requests(self, recv_reqs): rid=recv_req.rid, recv_req=recv_req, mm_processor=self.mm_processor, - image_urls=self.encode_urls, + encoder_urls=self.encode_urls, host_name=self.hostname, receive_count=self.world_size, embedding_port=embedding_port, @@ -451,7 +449,7 @@ async def allocate_embedding_buffer(self, req_id, embedding_length, embedding_di return embeddings.data_ptr() # For zmq_to_scheduler - def send_encode_requset(self, obj): + def send_encode_request(self, obj): if type(obj.image_data) != list: image_urls = [obj.image_data.url] else: @@ -467,11 +465,7 @@ def send_encode_requset(self, obj): obj.num_items_assigned = [ (idx + len(image_urls)) // len(self.encode_urls) for idx in encode_idx ] - 
obj.embedding_ports = ( - [get_free_port() for _ in range(self.world_size)] - if self.nnodes == 1 - else None - ) + obj.embedding_ports = None encode_thread = threading.Thread( target=self._run_encode_in_thread, args=( @@ -491,7 +485,7 @@ async def recv_mm_data(self, img_data, mm_processor, prompt): if len(self.encode_urls) == 0: return None req_id = uuid.uuid4().hex - embedding_port = get_free_port() + embedding_port, recv_socket = get_zmq_socket_on_host(self.context, zmq.PULL) if type(img_data) != list: img_data = [img_data.url] else: @@ -500,7 +494,7 @@ async def recv_mm_data(self, img_data, mm_processor, prompt): self.encode(req_id, img_data, embedding_port, "encode", "send") ) return await asyncio.wait_for( - self._recv_mm_data(req_id, embedding_port, mm_processor, prompt), + self._recv_mm_data(req_id, recv_socket, mm_processor, prompt), timeout=20, ) except asyncio.TimeoutError: @@ -510,19 +504,13 @@ async def recv_mm_data(self, img_data, mm_processor, prompt): return None # For zmq_to_tokenizer and mooncake - async def _recv_mm_data(self, req_id, embedding_port, mm_processor, prompt): + async def _recv_mm_data(self, req_id, recv_socket, mm_processor, prompt): # Bypass MMReceiver if req_id is None: return None recv_embedding = None - recv_socket = get_zmq_socket( - self.context, zmq.PULL, f"tcp://*:{embedding_port}", True - ) - - logger.info(f"{embedding_port = }") - recv_embedding_data: EmbeddingData = None while recv_embedding_data is None or not recv_embedding_data.ready: @@ -548,6 +536,8 @@ async def _recv_mm_data(self, req_id, embedding_port, mm_processor, prompt): elif self.encoder_transfer_backend == "zmq_to_tokenizer": recv_embedding = recv_embedding_data.get_embedding(is_concat=True) + recv_socket.close() + img_grid_thw = recv_embedding_data.get_img_grid() mm_inputs = mm_processor.get_mm_data(prompt, recv_embedding, img_grid_thw) diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py index 
7172e306abd4..a9ad60f8777f 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -314,7 +314,7 @@ async def send_with_url( try: while True: - with rid_lock: + async with rid_lock: current_targets = rid_to_receive_endpoint.get(req_id, set()).copy() expected_count = rid_to_receive_count.get(req_id) @@ -367,7 +367,7 @@ async def send_with_url( finally: logger.info(f"Cleaning up resources for req_id {req_id}") - with rid_lock: + async with rid_lock: rid_to_receive_endpoint.pop(req_id, None) rid_to_receive_count.pop(req_id, None) self.embedding_to_send.pop(req_id, None) @@ -506,7 +506,7 @@ async def handle_send_request(request: dict): @app.post("/scheduler_receive_url") async def handle_scheduler_receive_url_request(request: dict): rid = request["req_id"] - with rid_lock: + async with rid_lock: global rid_to_receive_endpoint if rid not in rid_to_receive_endpoint: rid_to_receive_endpoint[rid] = set() diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 5e8cdf3ae881..cdd09ad50390 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -425,7 +425,7 @@ async def generate_request( and self.server_args.encoder_transfer_backend == "zmq_to_scheduler" and obj.contains_mm_input() ): - self.mm_receiver.send_encode_requset(obj) + self.mm_receiver.send_encode_request(obj) if self.enable_trace: self._trace_request_start(obj, created_time, request) From 6b5ded7734e5c96e8bfa5e85ec546ae4a5445a02 Mon Sep 17 00:00:00 2001 From: Tianyu Guo Date: Sat, 13 Dec 2025 14:04:28 +0000 Subject: [PATCH 67/68] Fix merge --- python/sglang/srt/server_args.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 3378f8afa38a..ed2b4f97de81 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -4214,12 
+4214,6 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.decrypted_draft_config_file, help="The path of the decrypted draft config file.", ) - parser.add_argument( - "--mm-enable-dp-encoder", - action="store_true", - default=ServerArgs.mm_enable_dp_encoder, - help="Enabling data parallelism for mm encoder. The dp size will be set to the tp size automatically.", - ) parser.add_argument( "--enable-prefix-mm-cache", action="store_true", From c9126ed91a6a63cef66b7474a544632583c11aba Mon Sep 17 00:00:00 2001 From: liusy58 Date: Sun, 14 Dec 2025 12:00:06 +0800 Subject: [PATCH 68/68] fix test --- python/sglang/srt/models/qwen2_5_vl.py | 5 ++++- test/srt/test_epd_disaggregation.py | 8 +++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index bd03eaa80d14..9336dbaf8756 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -789,7 +789,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name in params_dict.keys(): param = params_dict[name] else: - raise ValueError(f"Weight {name} not found in params_dict") + if get_global_server_args().encoder_only: + continue + else: + raise ValueError(f"Weight {name} not found in params_dict") except KeyError: print(params_dict.keys()) raise diff --git a/test/srt/test_epd_disaggregation.py b/test/srt/test_epd_disaggregation.py index 51b77e575add..07f79a97a9d4 100644 --- a/test/srt/test_epd_disaggregation.py +++ b/test/srt/test_epd_disaggregation.py @@ -4,7 +4,9 @@ import unittest from sglang.srt.utils import kill_process_tree -from sglang.test.test_disaggregation_utils import TestDisaggregationBase +from sglang.test.server_fixtures.disaggregation_fixture import ( + PDDisaggregationServerBase, +) from sglang.test.test_utils import ( DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -12,7 +14,7 @@ ) -class 
TestEPDDisaggregationOneEncoder(TestDisaggregationBase): +class TestEPDDisaggregationOneEncoder(PDDisaggregationServerBase): """Test EPD disaggregation with single encode server""" @classmethod @@ -203,7 +205,7 @@ def test_mmmu(self): self.assertGreater(mmmu_accuracy, 0.40) -class TestEPDDisaggregationMultiEncoders(TestDisaggregationBase): +class TestEPDDisaggregationMultiEncoders(PDDisaggregationServerBase): """ Test EPD disaggregation with multiple encode servers for load balancing. Both encode servers run on GPU 0 (different ports) for testing load distribution.