Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions examples/runtime/zmq_client_optimization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import multiprocessing as mp
import zmq
import time
import psutil
from datetime import datetime

def server():
    """Push the string "Hello" over a ZMQ PUSH socket once per second.

    Binds to TCP port 5555 on all interfaces and loops forever; the
    function never returns on its own.
    """
    ctx = zmq.Context()
    push_sock = ctx.socket(zmq.PUSH)
    push_sock.bind("tcp://*:5555")

    payload = "Hello"
    while True:
        # Wait first, then send, so messages are spaced one second apart.
        time.sleep(1)
        push_sock.send_string(payload)
        print(f"[{datetime.now()}] Server sent: {payload}")

def monitor_cpu_usage(pid, duration):
    """Print the CPU usage of process *pid* until *duration* seconds pass.

    Each iteration measures CPU usage over a one-second window, prints it,
    and then sleeps one more second before the next sample.

    Args:
        pid: ID of the process to sample.
        duration: Wall-clock seconds to keep sampling.

    If the process does not exist, or exits while being monitored, a
    notice is printed and monitoring stops instead of raising
    ``psutil.NoSuchProcess``.
    """
    try:
        process = psutil.Process(pid)
    except psutil.NoSuchProcess:
        print(f"Process ID:{pid} is not running; nothing to monitor.")
        return

    start_time = time.time()
    while time.time() - start_time < duration:
        try:
            # interval=1 makes this call block for one second while measuring.
            cpu_usage = process.cpu_percent(interval=1)
        except psutil.NoSuchProcess:
            # The monitored process can finish before `duration` elapses.
            print(f"Process ID:{pid} exited; stopping CPU monitoring.")
            break
        print(f"Process ID:{pid} CPU Usage: {cpu_usage:.2f}%")
        time.sleep(1)

def client(optimized=False):
    """Receive strings from the local ZMQ server for 30 seconds.

    Args:
        optimized: When false, every receive is non-blocking (``zmq.NOBLOCK``).
            When true, a 100 ms ``RCVTIMEO`` is set on the socket and a
            plain receive is used instead.

    Each received message is printed. Receive attempts that raise
    ``zmq.Again`` are counted, and the running count is printed at most
    once every two seconds.
    """
    client_type = "optimized" if optimized else "unoptimized"
    print(f"Running {client_type} client...")
    print(f"Process ID: {mp.current_process().pid}")

    ctx = zmq.Context()
    pull_sock = ctx.socket(zmq.PULL)
    pull_sock.connect("tcp://localhost:5555")

    if optimized:
        pull_sock.setsockopt(zmq.RCVTIMEO, 100)  # 100 ms receive timeout
        recv_flags = 0
    else:
        recv_flags = zmq.NOBLOCK

    label = client_type.capitalize()
    began = time.time()
    last_report = began
    empty_attempts = 0

    while time.time() - began < 30:
        try:
            message = pull_sock.recv_string(recv_flags)
        except zmq.Again:
            empty_attempts += 1
            now = time.time()
            # Throttle the "nothing received" report to one line per 2 s.
            if now - last_report >= 2:
                print(f"[{datetime.now()}] {label} client: No message received. Attempts: {empty_attempts}")
                last_report = now
        else:
            print(f"[{datetime.now()}] {label} client received: {message}")

def run_client_test(optimized=False):
    """Launch one client process plus a CPU monitor for it, and wait for both.

    Args:
        optimized: Forwarded to ``client`` to select which receive
            strategy is exercised.
    """
    client_proc = mp.Process(target=client, args=(optimized,))
    client_proc.start()

    # The monitor is given the client's pid and samples it for 30 seconds.
    monitor_proc = mp.Process(
        target=monitor_cpu_usage,
        args=(client_proc.pid, 30),
    )
    monitor_proc.start()

    for proc in (client_proc, monitor_proc):
        proc.join()

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)

    # Start the server process
    server_process = mp.Process(target=server)
    server_process.start()

    try:
        # Test unoptimized client
        print("Testing unoptimized client...")
        run_client_test(optimized=False)

        print("\nOptimizing the client...")
        print("=" * 50)

        # Test optimized client
        print("Testing optimized client...")
        run_client_test(optimized=True)
    finally:
        # The server loops forever, so it must be stopped explicitly, even
        # when a test run raises or is interrupted. Joining afterwards
        # reaps the terminated process.
        server_process.terminate()
        server_process.join()
7 changes: 6 additions & 1 deletion python/sglang/srt/managers/data_parallel_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def __init__(self, server_args, port_args) -> None:
# Init inter-process communication
self.context = zmq.Context(1 + server_args.dp_size)
self.recv_from_tokenizer = self.context.socket(zmq.PULL)
# set timeout to avoid blocking forever
self.recv_from_tokenizer.setsockopt(zmq.RCVTIMEO, 100)
self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")

# Dispatch method
Expand Down Expand Up @@ -140,7 +142,10 @@ def event_loop(self):
while True:
while True:
try:
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
recv_req = self.recv_from_tokenizer.recv_pyobj()
except zmq.Again:
# skip if no more requests
break
except zmq.ZMQError:
break

Expand Down
7 changes: 6 additions & 1 deletion python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def __init__(

if self.tp_rank == 0:
self.recv_from_tokenizer = context.socket(zmq.PULL)
# set timeout to avoid blocking forever
self.recv_from_tokenizer.setsockopt(zmq.RCVTIMEO, 100)
self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")

self.send_to_detokenizer = context.socket(zmq.PUSH)
Expand Down Expand Up @@ -332,7 +334,10 @@ def recv_requests(self):

while True:
try:
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
recv_req = self.recv_from_tokenizer.recv_pyobj()
except zmq.Again:
# skip if no more requests
break
except zmq.ZMQError:
break
recv_reqs.append(recv_req)
Expand Down