Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions python/sglang/srt/disaggregation/encode_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import ImageData
from sglang.srt.utils.hf_transformers_utils import get_processor
from sglang.srt.utils.network import get_local_ip_auto, get_zmq_socket_on_host
from sglang.srt.utils.network import (
NetworkAddress,
get_local_ip_auto,
get_zmq_socket_on_host,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -444,10 +448,13 @@ async def send_embedding_port(req_id, receive_count, host_name, embedding_port):
part_req_id = create_part_req_id(req_id, part_idx)
encoder_url = self.encoder_urls[idx]
target_url = f"{encoder_url}/scheduler_receive_url"
receive_url = NetworkAddress(
host_name, embedding_port
).to_host_port_str()
payload = {
"req_id": part_req_id, # use part_req_id to match encode request
"receive_count": receive_count,
"receive_url": f"{host_name}:{embedding_port}",
"receive_url": receive_url,
"modality": modality.name,
}
logger.info(
Expand Down
16 changes: 12 additions & 4 deletions python/sglang/srt/disaggregation/encode_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,10 +1008,10 @@ async def _send(

# Send ack/data
if url is not None:
endpoint = NetworkAddress.parse(url).to_tcp()
endpoint = NetworkAddress.parse(url)
else:
endpoint = NetworkAddress(prefill_host, embedding_port).to_tcp()
logger.info(f"{endpoint = }")
endpoint = NetworkAddress(prefill_host, embedding_port)
logger.info(f"{endpoint.to_tcp() = }")

# Serialize data
if self.server_args.encoder_transfer_backend == "mooncake":
Expand All @@ -1031,12 +1031,20 @@ async def _send(
def send_with_socket():
sock = self.sync_context.socket(zmq.PUSH)
config_socket(sock, zmq.PUSH)
if endpoint.is_ipv6:
sock.setsockopt(zmq.IPV6, 1)

try:
sock.connect(endpoint)
sock.connect(endpoint.to_tcp())
if buffer is not None:
sock.send_multipart([serialized_data, buffer], copy=False)
else:
sock.send_multipart([serialized_data], copy=False)
except Exception as e:
logger.error(
f"Error occurred while sending data in send_with_socket: {e}"
)
raise
finally:
sock.close()

Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ def get_zmq_socket_on_host(
else:
bind_host = f"tcp://{host}"
else:
# If no host is specified, zmq.IPV6 must be enabled here (by default the socket only binds to IPv4).
socket.setsockopt(zmq.IPV6, 1)
bind_host = "tcp://*"
port = socket.bind_to_random_port(bind_host)
return port, socket
Expand Down
Loading