Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 16 additions & 26 deletions python/sglang/srt/disaggregation/encode_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_free_port, get_local_ip_auto, get_zmq_socket
from sglang.srt.utils import get_local_ip_auto, get_zmq_socket_on_host
from sglang.srt.utils.hf_transformers_utils import get_processor

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -86,7 +86,7 @@ def __init__(
rid: str,
recv_req: TokenizedGenerateReqInput,
mm_processor,
image_urls,
encoder_urls,
host_name,
receive_count,
embedding_port=None,
Expand All @@ -97,16 +97,13 @@ def __init__(
self.error = None
self.thread = None
self.mm_processor = mm_processor
self.image_urls = image_urls
self.encoder_urls = encoder_urls
self.host_name = host_name
self.receive_count = receive_count
self.num_items_assigned = recv_req.num_items_assigned
self.embedding_port = (
get_free_port() if embedding_port is None else embedding_port
self.embedding_port, self.recv_socket = get_zmq_socket_on_host(
zmq.Context(), zmq.PULL
)
self.context = zmq.Context()
self.recv_socket = self.context.socket(zmq.PULL)
self.recv_socket.bind(f"tcp://*:{self.embedding_port}")
logger.info(f"Waiting for input {self.embedding_port = }")
self.recv_embedding_data = None
self.ready = False
Expand All @@ -130,8 +127,8 @@ async def send_embedding_port(req_id, receive_count, host_name, embedding_port):
for idx, assigned_num in enumerate(self.num_items_assigned):
if assigned_num == 0:
continue
image_url = self.image_urls[idx]
target_url = f"{image_url}/scheduler_receive_url"
encoder_url = self.encoder_urls[idx]
target_url = f"{encoder_url}/scheduler_receive_url"
payload = {
"req_id": req_id,
"receive_count": receive_count,
Expand Down Expand Up @@ -194,6 +191,7 @@ def _try_recv_mm_data(self):
self.recv_req.mm_inputs = mm_inputs
self.recv_req.input_ids = mm_inputs["input_ids"]
self.ready = True
self.recv_socket.close()


def _determine_tensor_transport_mode(server_args):
Expand Down Expand Up @@ -286,7 +284,7 @@ def process_waiting_requests(self, recv_reqs):
rid=recv_req.rid,
recv_req=recv_req,
mm_processor=self.mm_processor,
image_urls=self.encode_urls,
encoder_urls=self.encode_urls,
host_name=self.hostname,
receive_count=self.world_size,
embedding_port=embedding_port,
Expand Down Expand Up @@ -451,7 +449,7 @@ async def allocate_embedding_buffer(self, req_id, embedding_length, embedding_di
return embeddings.data_ptr()

# For zmq_to_scheduler
def send_encode_requset(self, obj):
def send_encode_request(self, obj):
if type(obj.image_data) != list:
image_urls = [obj.image_data.url]
else:
Expand All @@ -467,11 +465,7 @@ def send_encode_requset(self, obj):
obj.num_items_assigned = [
(idx + len(image_urls)) // len(self.encode_urls) for idx in encode_idx
]
obj.embedding_ports = (
[get_free_port() for _ in range(self.world_size)]
if self.nnodes == 1
else None
)
obj.embedding_ports = None
encode_thread = threading.Thread(
target=self._run_encode_in_thread,
args=(
Expand All @@ -491,7 +485,7 @@ async def recv_mm_data(self, img_data, mm_processor, prompt):
if len(self.encode_urls) == 0:
return None
req_id = uuid.uuid4().hex
embedding_port = get_free_port()
embedding_port, recv_socket = get_zmq_socket_on_host(self.context, zmq.PULL)
if type(img_data) != list:
img_data = [img_data.url]
else:
Expand All @@ -500,7 +494,7 @@ async def recv_mm_data(self, img_data, mm_processor, prompt):
self.encode(req_id, img_data, embedding_port, "encode", "send")
)
return await asyncio.wait_for(
self._recv_mm_data(req_id, embedding_port, mm_processor, prompt),
self._recv_mm_data(req_id, recv_socket, mm_processor, prompt),
timeout=20,
)
except asyncio.TimeoutError:
Expand All @@ -510,19 +504,13 @@ async def recv_mm_data(self, img_data, mm_processor, prompt):
return None

# For zmq_to_tokenizer and mooncake
async def _recv_mm_data(self, req_id, embedding_port, mm_processor, prompt):
async def _recv_mm_data(self, req_id, recv_socket, mm_processor, prompt):
# Bypass MMReceiver
if req_id is None:
return None

recv_embedding = None

recv_socket = get_zmq_socket(
self.context, zmq.PULL, f"tcp://*:{embedding_port}", True
)

logger.info(f"{embedding_port = }")

recv_embedding_data: EmbeddingData = None

while recv_embedding_data is None or not recv_embedding_data.ready:
Expand All @@ -548,6 +536,8 @@ async def _recv_mm_data(self, req_id, embedding_port, mm_processor, prompt):
elif self.encoder_transfer_backend == "zmq_to_tokenizer":
recv_embedding = recv_embedding_data.get_embedding(is_concat=True)

recv_socket.close()

img_grid_thw = recv_embedding_data.get_img_grid()

mm_inputs = mm_processor.get_mm_data(prompt, recv_embedding, img_grid_thw)
Expand Down
6 changes: 3 additions & 3 deletions python/sglang/srt/disaggregation/encode_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ async def send_with_url(

try:
while True:
with rid_lock:
async with rid_lock:
current_targets = rid_to_receive_endpoint.get(req_id, set()).copy()
expected_count = rid_to_receive_count.get(req_id)

Expand Down Expand Up @@ -367,7 +367,7 @@ async def send_with_url(

finally:
logger.info(f"Cleaning up resources for req_id {req_id}")
with rid_lock:
async with rid_lock:
rid_to_receive_endpoint.pop(req_id, None)
rid_to_receive_count.pop(req_id, None)
self.embedding_to_send.pop(req_id, None)
Expand Down Expand Up @@ -506,7 +506,7 @@ async def handle_send_request(request: dict):
@app.post("/scheduler_receive_url")
async def handle_scheduler_receive_url_request(request: dict):
rid = request["req_id"]
with rid_lock:
async with rid_lock:
global rid_to_receive_endpoint
if rid not in rid_to_receive_endpoint:
rid_to_receive_endpoint[rid] = set()
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ async def generate_request(
and self.server_args.encoder_transfer_backend == "zmq_to_scheduler"
and obj.contains_mm_input()
):
self.mm_receiver.send_encode_requset(obj)
self.mm_receiver.send_encode_request(obj)

if self.enable_trace:
self._trace_request_start(obj, created_time, request)
Expand Down
Loading