Skip to content

Commit 0427416

Browse files
Fix zmq binding (#2930)
Co-authored-by: Chunyuan WU <[email protected]>
1 parent bf3edc2 commit 0427416

File tree

5 files changed

+18
-12
lines changed

5 files changed

+18
-12
lines changed

python/sglang/srt/managers/data_parallel_controller.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(self, server_args, port_args) -> None:
6666
self.context = zmq.Context(1 + server_args.dp_size)
6767
if server_args.node_rank == 0:
6868
self.recv_from_tokenizer = get_zmq_socket(
69-
self.context, zmq.PULL, port_args.scheduler_input_ipc_name
69+
self.context, zmq.PULL, port_args.scheduler_input_ipc_name, False
7070
)
7171

7272
# Dispatch method
@@ -93,6 +93,7 @@ def __init__(self, server_args, port_args) -> None:
9393
self.context,
9494
zmq.PUSH,
9595
dp_port_args[dp_rank].scheduler_input_ipc_name,
96+
True,
9697
)
9798

9899
def launch_dp_schedulers(self, server_args, port_args):

python/sglang/srt/managers/detokenizer_manager.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ def __init__(
5858
# Init inter-process communication
5959
context = zmq.Context(2)
6060
self.recv_from_scheduler = get_zmq_socket(
61-
context, zmq.PULL, port_args.detokenizer_ipc_name
61+
context, zmq.PULL, port_args.detokenizer_ipc_name, True
6262
)
6363
self.send_to_tokenizer = get_zmq_socket(
64-
context, zmq.PUSH, port_args.tokenizer_ipc_name
64+
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
6565
)
6666

6767
if server_args.skip_tokenizer_init:

python/sglang/srt/managers/scheduler.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -162,21 +162,21 @@ def __init__(
162162

163163
if self.attn_tp_rank == 0:
164164
self.recv_from_tokenizer = get_zmq_socket(
165-
context, zmq.PULL, port_args.scheduler_input_ipc_name
165+
context, zmq.PULL, port_args.scheduler_input_ipc_name, False
166166
)
167167
self.send_to_tokenizer = get_zmq_socket(
168-
context, zmq.PUSH, port_args.tokenizer_ipc_name
168+
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
169169
)
170170

171171
if server_args.skip_tokenizer_init:
172172
# Directly send to the TokenizerManager
173173
self.send_to_detokenizer = get_zmq_socket(
174-
context, zmq.PUSH, port_args.tokenizer_ipc_name
174+
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
175175
)
176176
else:
177177
# Send to the DetokenizerManager
178178
self.send_to_detokenizer = get_zmq_socket(
179-
context, zmq.PUSH, port_args.detokenizer_ipc_name
179+
context, zmq.PUSH, port_args.detokenizer_ipc_name, False
180180
)
181181
else:
182182
self.recv_from_tokenizer = None

python/sglang/srt/managers/tokenizer_manager.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,10 @@ def __init__(
119119
# Init inter-process communication
120120
context = zmq.asyncio.Context(2)
121121
self.recv_from_detokenizer = get_zmq_socket(
122-
context, zmq.PULL, port_args.tokenizer_ipc_name
122+
context, zmq.PULL, port_args.tokenizer_ipc_name, True
123123
)
124124
self.send_to_scheduler = get_zmq_socket(
125-
context, zmq.PUSH, port_args.scheduler_input_ipc_name
125+
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
126126
)
127127

128128
# Read model args

python/sglang/srt/utils.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -789,7 +789,9 @@ def first_rank_print(*args, **kwargs):
789789
pass
790790

791791

792-
def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: str):
792+
def get_zmq_socket(
793+
context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
794+
):
793795
mem = psutil.virtual_memory()
794796
total_mem = mem.total / 1024**3
795797
available_mem = mem.available / 1024**3
@@ -802,14 +804,17 @@ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint:
802804
if socket_type == zmq.PUSH:
803805
socket.setsockopt(zmq.SNDHWM, 0)
804806
socket.setsockopt(zmq.SNDBUF, buf_size)
805-
socket.connect(endpoint)
806807
elif socket_type == zmq.PULL:
807808
socket.setsockopt(zmq.RCVHWM, 0)
808809
socket.setsockopt(zmq.RCVBUF, buf_size)
809-
socket.bind(endpoint)
810810
else:
811811
raise ValueError(f"Unsupported socket type: {socket_type}")
812812

813+
if bind:
814+
socket.bind(endpoint)
815+
else:
816+
socket.connect(endpoint)
817+
813818
return socket
814819

815820

0 commit comments

Comments
 (0)