# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3-
3+ import functools
44import pickle
55import time
66from contextlib import contextmanager
77from dataclasses import dataclass , field
88from multiprocessing import shared_memory
9+ from pickle import PickleBuffer
910from threading import Event
10- from typing import Any
11+ from typing import TYPE_CHECKING , Any
1112from unittest .mock import patch
1213
1314import torch
3334 is_valid_ipv6_address ,
3435)
3536
37+ if TYPE_CHECKING :
38+ from _typeshed import SizedBuffer
39+
# Module-level alias of the setting exposed by ``envs``.
# NOTE(review): presumably the interval between ring-buffer wait warnings;
# confirm the unit against the ``envs`` definition.
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
3741
# Decode a big-endian unsigned integer from a bytes-like object; the inverse
# of ``value.to_bytes(size, byteorder="big")``.
from_bytes_big = functools.partial(int.from_bytes, byteorder="big")
43+
44+
def to_bytes_big(value: int, size: int) -> bytes:
    """Encode ``value`` as exactly ``size`` bytes in big-endian order.

    Args:
        value: Non-negative integer to encode.
        size: Width of the encoded field in bytes.

    Returns:
        The ``size``-byte big-endian representation of ``value``.

    Raises:
        OverflowError: If ``value`` is negative or does not fit in ``size``
            bytes (raised by ``int.to_bytes``).
    """
    return value.to_bytes(size, byteorder="big")
47+
48+
# Module-level logger, named after this module.
logger = init_logger(__name__)
3950
4051
@@ -225,7 +236,7 @@ def __init__(
225236 n_reader , # number of all readers
226237 n_local_reader , # number of local readers through shared memory
227238 local_reader_ranks : list [int ] | None = None ,
228- max_chunk_bytes : int = 1024 * 1024 * 10 ,
239+ max_chunk_bytes : int = 1024 * 1024 * 24 , # 24MiB
229240 max_chunks : int = 10 ,
230241 connect_ip : str | None = None ,
231242 ):
@@ -505,18 +516,41 @@ def acquire_read(
def enqueue(self, obj, timeout: float | None = None):
    """Write to message queue with optional timeout (in seconds).

    ``obj`` is pickled with the highest protocol. Buffers that pickle offers
    for out-of-band transfer are in-lined into the main pickle stream when
    they are smaller than 1 MiB and kept as separate buffers otherwise.

    When there are local readers and everything fits in one shared-memory
    chunk, the chunk is laid out as:

        byte 0       overflow flag (0)
        bytes 1-2    number of buffers, big-endian (main pickle included)
        then for each buffer, main pickle first:
            4 bytes  buffer length, big-endian
            N bytes  buffer contents

    Otherwise the chunk only carries the overflow flag (byte 0 set to 1)
    and the buffers are sent as one multipart message on the local socket.
    Remote readers always get the buffers as a multipart message on the
    remote socket.

    Args:
        obj: Object to send.
        timeout: Optional timeout in seconds, forwarded to
            ``acquire_write``.

    Raises:
        AssertionError: If this instance is not a writer.
    """
    assert self._is_writer, "Only writers can enqueue"

    # Slot 0 is reserved for the main pickle stream; the callback below
    # appends every out-of-band buffer behind it.
    all_buffers: list[SizedBuffer] = [b""]
    total_bytes = 6  # 2 bytes for oob buffer count, 4 for main buffer size

    def oob_callback(buf: PickleBuffer) -> bool:
        nonlocal total_bytes
        raw_buf = buf.raw()
        if len(raw_buf) < 1024 * 1024:
            # In-line buffers smaller than 1MiB.
            return True
        all_buffers.append(raw_buf)
        # Each out-of-band buffer also costs a 4-byte length prefix.
        total_bytes += len(raw_buf) + 4
        # A false return value tells pickle to keep the buffer out-of-band.
        return False

    all_buffers[0] = pickle.dumps(
        obj, protocol=pickle.HIGHEST_PROTOCOL, buffer_callback=oob_callback
    )

    if self.n_local_reader > 0:
        if total_bytes + len(all_buffers[0]) >= self.buffer.max_chunk_bytes:
            # Too large for one chunk: flag the overflow in shared memory
            # and ship the payload over the local socket instead.
            with self.acquire_write(timeout) as buf:
                buf[0] = 1  # overflow
            self.local_socket.send_multipart(all_buffers, copy=False)
        else:
            with self.acquire_write(timeout) as buf:
                buf[0] = 0  # not overflow
                offset = 3
                buf[1:offset] = to_bytes_big(len(all_buffers), 2)  # buf count
                for buffer in all_buffers:
                    buf_len = len(buffer)
                    # Prepend each buffer with 4 bytes containing its size.
                    data_start = offset + 4
                    data_end = data_start + buf_len
                    buf[offset:data_start] = to_bytes_big(buf_len, 4)
                    buf[data_start:data_end] = buffer
                    offset = data_end

    if self.n_remote_reader > 0:
        self.remote_socket.send_multipart(all_buffers, copy=False)
520554
521555 def dequeue (
522556 self ,
@@ -529,10 +563,15 @@ def dequeue(
529563 with self .acquire_read (timeout , cancel , indefinite ) as buf :
530564 overflow = buf [0 ] == 1
531565 if not overflow :
532- # no need to know the size of serialized object
533- # pickle format contains the size information internally
534- # see https://docs.python.org/3/library/pickle.html
535- obj = pickle .loads (buf [1 :])
566+ offset = 3
567+ buf_count = from_bytes_big (buf [1 :offset ])
568+ all_buffers = []
569+ for i in range (buf_count ):
570+ buf_offset = offset + 4
571+ buf_len = from_bytes_big (buf [offset :buf_offset ])
572+ offset = buf_offset + buf_len
573+ all_buffers .append (buf [buf_offset :offset ])
574+ obj = pickle .loads (all_buffers [0 ], buffers = all_buffers [1 :])
536575 if overflow :
537576 obj = MessageQueue .recv (self .local_socket , timeout )
538577 elif self ._is_remote_reader :
@@ -546,15 +585,14 @@ def recv(socket: zmq.Socket, timeout: float | None) -> Any:
546585 timeout_ms = None if timeout is None else int (timeout * 1000 )
547586 if not socket .poll (timeout = timeout_ms ):
548587 raise TimeoutError
549- recv = socket .recv (copy = False )
550- return pickle .loads (recv . buffer )
588+ recv , * recv_oob = socket .recv_multipart (copy = False )
589+ return pickle .loads (recv , buffers = recv_oob )
551590
def broadcast_object(self, obj=None):
    """Send ``obj`` from the writer, or receive it on a reader.

    Args:
        obj: Object to broadcast. Only used when this instance is the
            writer; readers ignore it.

    Returns:
        ``obj`` itself on the writer, or the result of ``dequeue()`` on a
        reader.
    """
    if self._is_writer:
        self.enqueue(obj)
        return obj
    return self.dequeue()
558596
559597 @staticmethod
560598 def create_from_process_group (
0 commit comments