Skip to content

Commit 96efa44

Browse files
njhill authored and rtourgeman committed
[Perf] Exploit out-of-band buffers in shm_broadcast (vllm-project#26961)
Signed-off-by: Nick Hill <[email protected]>
1 parent 0e2f00d commit 96efa44

File tree

1 file changed

+54
-16
lines changed

1 file changed

+54
-16
lines changed

vllm/distributed/device_communicators/shm_broadcast.py

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3-
3+
import functools
44
import pickle
55
import time
66
from contextlib import contextmanager
77
from dataclasses import dataclass, field
88
from multiprocessing import shared_memory
9+
from pickle import PickleBuffer
910
from threading import Event
10-
from typing import Any
11+
from typing import TYPE_CHECKING, Any
1112
from unittest.mock import patch
1213

1314
import torch
@@ -33,8 +34,18 @@
3334
is_valid_ipv6_address,
3435
)
3536

37+
if TYPE_CHECKING:
38+
from _typeshed import SizedBuffer
39+
3640
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
3741

42+
# Big-endian integer codecs for the length/count fields that frame each
# message in the shared-memory ring buffer.

# Decode a bytes-like object (bytes or a memoryview slice) as a big-endian
# unsigned integer.
from_bytes_big = functools.partial(int.from_bytes, byteorder="big")


def to_bytes_big(value: int, size: int) -> bytes:
    """Encode ``value`` as exactly ``size`` big-endian bytes.

    Raises:
        OverflowError: if ``value`` is negative or does not fit in
            ``size`` bytes (propagated from ``int.to_bytes``).
    """
    return value.to_bytes(size, byteorder="big")
47+
48+
3849
logger = init_logger(__name__)
3950

4051

@@ -225,7 +236,7 @@ def __init__(
225236
n_reader, # number of all readers
226237
n_local_reader, # number of local readers through shared memory
227238
local_reader_ranks: list[int] | None = None,
228-
max_chunk_bytes: int = 1024 * 1024 * 10,
239+
max_chunk_bytes: int = 1024 * 1024 * 24, # 24MiB
229240
max_chunks: int = 10,
230241
connect_ip: str | None = None,
231242
):
@@ -505,18 +516,41 @@ def acquire_read(
505516
def enqueue(self, obj, timeout: float | None = None):
    """Write to message queue with optional timeout (in seconds)

    ``obj`` is pickled with out-of-band buffer support: any buffer of at
    least 1MiB that the pickler offers is kept out of the main pickle
    stream and carried as a separate frame; smaller buffers are
    serialized in-band.

    For local readers the frames are written into one shared-memory chunk
    using this layout::

        byte 0        overflow flag (0 = payload is in this chunk)
        bytes 1-2     big-endian frame count (main pickle + oob buffers)
        per frame     4-byte big-endian length, then the frame bytes

    If the framed payload does not fit in a chunk, only the overflow flag
    (1) is written to the chunk and the frames are sent over
    ``self.local_socket`` as a multipart message instead. Remote readers
    always receive the frames over ``self.remote_socket``.

    Args:
        obj: The object to broadcast.
        timeout: Passed through to ``self.acquire_write``.
    """
    # NOTE(review): block nesting was reconstructed from a flattened diff
    # view; confirm against the repository, in particular that the overflow
    # send happens after the write slot is released.
    assert self._is_writer, "Only writers can enqueue"
    # Slot 0 is a placeholder for the main pickle stream; the callback below
    # appends out-of-band buffers after it while pickling is in progress.
    all_buffers: list[SizedBuffer] = [b""]
    total_bytes = 6  # 2 bytes for oob buffer count, 4 for main buffer size

    def oob_callback(buf: PickleBuffer) -> bool:
        nonlocal total_bytes
        raw_buf = buf.raw()
        if len(raw_buf) < 1024 * 1024:
            # In-line buffers smaller than 1MiB.
            # (A true return value tells pickle to serialize in-band.)
            return True
        all_buffers.append(raw_buf)
        # Account for the payload plus its 4-byte length prefix.
        total_bytes += len(raw_buf) + 4
        return False

    all_buffers[0] = pickle.dumps(
        obj, protocol=pickle.HIGHEST_PROTOCOL, buffer_callback=oob_callback
    )
    if self.n_local_reader > 0:
        if total_bytes + len(all_buffers[0]) >= self.buffer.max_chunk_bytes:
            with self.acquire_write(timeout) as buf:
                buf[0] = 1  # overflow
            self.local_socket.send_multipart(all_buffers, copy=False)
        else:
            with self.acquire_write(timeout) as buf:
                buf[0] = 0  # not overflow
                offset = 3
                buf[1:offset] = to_bytes_big(len(all_buffers), 2)  # oob buf count
                for buffer in all_buffers:
                    buf_len = len(buffer)
                    # prepend each buffer with 4 bytes containing its size.
                    buf_offset = offset + 4
                    buf[offset:buf_offset] = to_bytes_big(buf_len, 4)
                    # The walrus advances ``offset`` past this frame so the
                    # next iteration starts at the following length prefix.
                    buf[buf_offset : (offset := buf_offset + buf_len)] = buffer

    if self.n_remote_reader > 0:
        self.remote_socket.send_multipart(all_buffers, copy=False)
520554

521555
def dequeue(
522556
self,
@@ -529,10 +563,15 @@ def dequeue(
529563
with self.acquire_read(timeout, cancel, indefinite) as buf:
530564
overflow = buf[0] == 1
531565
if not overflow:
532-
# no need to know the size of serialized object
533-
# pickle format contains the size information internally
534-
# see https://docs.python.org/3/library/pickle.html
535-
obj = pickle.loads(buf[1:])
566+
offset = 3
567+
buf_count = from_bytes_big(buf[1:offset])
568+
all_buffers = []
569+
for i in range(buf_count):
570+
buf_offset = offset + 4
571+
buf_len = from_bytes_big(buf[offset:buf_offset])
572+
offset = buf_offset + buf_len
573+
all_buffers.append(buf[buf_offset:offset])
574+
obj = pickle.loads(all_buffers[0], buffers=all_buffers[1:])
536575
if overflow:
537576
obj = MessageQueue.recv(self.local_socket, timeout)
538577
elif self._is_remote_reader:
@@ -546,15 +585,14 @@ def recv(socket: zmq.Socket, timeout: float | None) -> Any:
546585
timeout_ms = None if timeout is None else int(timeout * 1000)
547586
if not socket.poll(timeout=timeout_ms):
548587
raise TimeoutError
549-
recv = socket.recv(copy=False)
550-
return pickle.loads(recv.buffer)
588+
recv, *recv_oob = socket.recv_multipart(copy=False)
589+
return pickle.loads(recv, buffers=recv_oob)
551590

552591
def broadcast_object(self, obj=None):
    """Broadcast ``obj`` from the writer, or receive it on a reader.

    On the writer, ``obj`` is enqueued and then returned unchanged. On any
    other instance the ``obj`` argument is ignored and the result of
    ``self.dequeue()`` is returned.
    """
    if self._is_writer:
        self.enqueue(obj)
        return obj
    return self.dequeue()
558596

559597
@staticmethod
560598
def create_from_process_group(

0 commit comments

Comments
 (0)