diff --git a/python/sglang/srt/disaggregation/encode_receiver.py b/python/sglang/srt/disaggregation/encode_receiver.py index fa1ea0d4b83b..3bd3c0d85525 100644 --- a/python/sglang/srt/disaggregation/encode_receiver.py +++ b/python/sglang/srt/disaggregation/encode_receiver.py @@ -29,7 +29,11 @@ from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ImageData from sglang.srt.utils.hf_transformers_utils import get_processor -from sglang.srt.utils.network import get_local_ip_auto, get_zmq_socket_on_host +from sglang.srt.utils.network import ( + NetworkAddress, + get_local_ip_auto, + get_zmq_socket_on_host, +) logger = logging.getLogger(__name__) @@ -444,10 +448,13 @@ async def send_embedding_port(req_id, receive_count, host_name, embedding_port): part_req_id = create_part_req_id(req_id, part_idx) encoder_url = self.encoder_urls[idx] target_url = f"{encoder_url}/scheduler_receive_url" + receive_url = NetworkAddress( + host_name, embedding_port + ).to_host_port_str() payload = { "req_id": part_req_id, # use part_req_id to match encode request "receive_count": receive_count, - "receive_url": f"{host_name}:{embedding_port}", + "receive_url": receive_url, "modality": modality.name, } logger.info( diff --git a/python/sglang/srt/disaggregation/encode_server.py b/python/sglang/srt/disaggregation/encode_server.py index 4da24e522c98..87f9a731ee9c 100644 --- a/python/sglang/srt/disaggregation/encode_server.py +++ b/python/sglang/srt/disaggregation/encode_server.py @@ -1008,10 +1008,10 @@ async def _send( # Send ack/data if url is not None: - endpoint = NetworkAddress.parse(url).to_tcp() + endpoint = NetworkAddress.parse(url) else: - endpoint = NetworkAddress(prefill_host, embedding_port).to_tcp() - logger.info(f"{endpoint = }") + endpoint = NetworkAddress(prefill_host, embedding_port) + logger.info(f"{endpoint.to_tcp() = }") # Serialize data if self.server_args.encoder_transfer_backend == "mooncake": @@ -1031,12 +1031,20 @@ async def _send( def send_with_socket(): sock = self.sync_context.socket(zmq.PUSH) config_socket(sock, zmq.PUSH) + if endpoint.is_ipv6: + sock.setsockopt(zmq.IPV6, 1) + try: - sock.connect(endpoint) + sock.connect(endpoint.to_tcp()) if buffer is not None: sock.send_multipart([serialized_data, buffer], copy=False) else: sock.send_multipart([serialized_data], copy=False) + except Exception as e: + logger.error( + f"Error occurred while sending data in send_with_socket: {e}" + ) + raise finally: sock.close() diff --git a/python/sglang/srt/utils/network.py b/python/sglang/srt/utils/network.py index c374c9535524..d830559b87d0 100644 --- a/python/sglang/srt/utils/network.py +++ b/python/sglang/srt/utils/network.py @@ -204,6 +204,8 @@ def get_zmq_socket_on_host( else: bind_host = f"tcp://{host}" else: + # If no host is specified, have to set zmq.IPV6 here(default only binds to IPV4) + socket.setsockopt(zmq.IPV6, 1) bind_host = "tcp://*" port = socket.bind_to_random_port(bind_host) return port, socket