diff --git a/python/sglang/srt/disaggregation/encode_receiver.py b/python/sglang/srt/disaggregation/encode_receiver.py index 0742535f68e8..c4938213909b 100644 --- a/python/sglang/srt/disaggregation/encode_receiver.py +++ b/python/sglang/srt/disaggregation/encode_receiver.py @@ -15,7 +15,7 @@ from sglang.srt.managers.io_struct import TokenizedGenerateReqInput from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import get_free_port, get_local_ip_auto, get_zmq_socket +from sglang.srt.utils import get_local_ip_auto, get_zmq_socket_on_host from sglang.srt.utils.hf_transformers_utils import get_processor logger = logging.getLogger(__name__) @@ -86,7 +86,7 @@ def __init__( rid: str, recv_req: TokenizedGenerateReqInput, mm_processor, - image_urls, + encoder_urls, host_name, receive_count, embedding_port=None, @@ -97,16 +97,13 @@ def __init__( self.error = None self.thread = None self.mm_processor = mm_processor - self.image_urls = image_urls + self.encoder_urls = encoder_urls self.host_name = host_name self.receive_count = receive_count self.num_items_assigned = recv_req.num_items_assigned - self.embedding_port = ( - get_free_port() if embedding_port is None else embedding_port + self.embedding_port, self.recv_socket = get_zmq_socket_on_host( + zmq.Context(), zmq.PULL ) - self.context = zmq.Context() - self.recv_socket = self.context.socket(zmq.PULL) - self.recv_socket.bind(f"tcp://*:{self.embedding_port}") logger.info(f"Waiting for input {self.embedding_port = }") self.recv_embedding_data = None self.ready = False @@ -130,8 +127,8 @@ async def send_embedding_port(req_id, receive_count, host_name, embedding_port): for idx, assigned_num in enumerate(self.num_items_assigned): if assigned_num == 0: continue - image_url = self.image_urls[idx] - target_url = f"{image_url}/scheduler_receive_url" + encoder_url = self.encoder_urls[idx] + target_url = f"{encoder_url}/scheduler_receive_url" payload = { "req_id": req_id, "receive_count": receive_count, @@ -194,6 +191,7 @@ def _try_recv_mm_data(self): self.recv_req.mm_inputs = mm_inputs self.recv_req.input_ids = mm_inputs["input_ids"] self.ready = True + self.recv_socket.close() def _determine_tensor_transport_mode(server_args): @@ -286,7 +284,7 @@ def process_waiting_requests(self, recv_reqs): rid=recv_req.rid, recv_req=recv_req, mm_processor=self.mm_processor, - image_urls=self.encode_urls, + encoder_urls=self.encode_urls, host_name=self.hostname, receive_count=self.world_size, embedding_port=embedding_port, @@ -451,7 +449,7 @@ async def allocate_embedding_buffer(self, req_id, embedding_length, embedding_di return embeddings.data_ptr() # For zmq_to_scheduler - def send_encode_requset(self, obj): + def send_encode_request(self, obj): if type(obj.image_data) != list: image_urls = [obj.image_data.url] else: @@ -467,11 +465,7 @@ def send_encode_requset(self, obj): obj.num_items_assigned = [ (idx + len(image_urls)) // len(self.encode_urls) for idx in encode_idx ] - obj.embedding_ports = ( - [get_free_port() for _ in range(self.world_size)] - if self.nnodes == 1 - else None - ) + obj.embedding_ports = None encode_thread = threading.Thread( target=self._run_encode_in_thread, args=( @@ -491,7 +485,7 @@ async def recv_mm_data(self, img_data, mm_processor, prompt): if len(self.encode_urls) == 0: return None req_id = uuid.uuid4().hex - embedding_port = get_free_port() + embedding_port, recv_socket = get_zmq_socket_on_host(self.context, zmq.PULL) if type(img_data) != list: img_data = [img_data.url] else: @@ -500,7 +494,7 @@ async def recv_mm_data(self, img_data, mm_processor, prompt): self.encode(req_id, img_data, embedding_port, "encode", "send") ) return await asyncio.wait_for( - self._recv_mm_data(req_id, embedding_port, mm_processor, prompt), + self._recv_mm_data(req_id, recv_socket, mm_processor, prompt), timeout=20, ) except asyncio.TimeoutError: @@ -510,19 +504,13 @@ async def recv_mm_data(self, img_data, mm_processor, prompt): return None # For zmq_to_tokenizer and mooncake - async def _recv_mm_data(self, req_id, embedding_port, mm_processor, prompt): + async def _recv_mm_data(self, req_id, recv_socket, mm_processor, prompt): # Bypass MMReceiver if req_id is None: return None recv_embedding = None - recv_socket = get_zmq_socket( - self.context, zmq.PULL, f"tcp://*:{embedding_port}", True - ) - - logger.info(f"{embedding_port = }") - recv_embedding_data: EmbeddingData = None while recv_embedding_data is None or not recv_embedding_data.ready: @@ -548,6 +536,8 @@ async def _recv_mm_data(self, req_id, embedding_port, mm_processor, prompt): elif self.encoder_transfer_backend == "zmq_to_tokenizer": recv_embedding = recv_embedding_data.get_embedding(is_concat=True) + recv_socket.close() + img_grid_thw = recv_embedding_data.get_img_grid() mm_inputs = mm_processor.get_mm_data(prompt, recv_embedding, img_grid_thw) diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py index 7172e306abd4..a9ad60f8777f 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -314,7 +314,7 @@ async def send_with_url( try: while True: - with rid_lock: + async with rid_lock: current_targets = rid_to_receive_endpoint.get(req_id, set()).copy() expected_count = rid_to_receive_count.get(req_id) @@ -367,7 +367,7 @@ async def send_with_url( finally: logger.info(f"Cleaning up resources for req_id {req_id}") - with rid_lock: + async with rid_lock: rid_to_receive_endpoint.pop(req_id, None) rid_to_receive_count.pop(req_id, None) self.embedding_to_send.pop(req_id, None) @@ -506,7 +506,7 @@ async def handle_send_request(request: dict): @app.post("/scheduler_receive_url") async def handle_scheduler_receive_url_request(request: dict): rid = request["req_id"] - with rid_lock: + async with rid_lock: global rid_to_receive_endpoint if rid not in rid_to_receive_endpoint: rid_to_receive_endpoint[rid] = set() diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 5e8cdf3ae881..cdd09ad50390 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -425,7 +425,7 @@ async def generate_request( and self.server_args.encoder_transfer_backend == "zmq_to_scheduler" and obj.contains_mm_input() ): - self.mm_receiver.send_encode_requset(obj) + self.mm_receiver.send_encode_request(obj) if self.enable_trace: self._trace_request_start(obj, created_time, request)