11# SPDX-License-Identifier: Apache-2.0
22
3+ import dataclasses
34import pickle
45from collections .abc import Sequence
56from inspect import isclass
1213import zmq
1314from msgspec import msgpack
1415
16+ from vllm import envs
17+ from vllm .multimodal .inputs import (BaseMultiModalField ,
18+ MultiModalBatchedField ,
19+ MultiModalFieldConfig , MultiModalFieldElem ,
20+ MultiModalFlatField , MultiModalKwargs ,
21+ MultiModalKwargsItem ,
22+ MultiModalSharedField , NestedTensors )
23+
1524CUSTOM_TYPE_PICKLE = 1
1625CUSTOM_TYPE_CLOUDPICKLE = 2
1726CUSTOM_TYPE_RAW_VIEW = 3
1827
19- # TODO calibrate this size
20- MIN_NOCOPY_BUF_SIZE = 512
# MultiModalField class serialization type map.
# Every field class that can be serialized must be listed here, and each
# value must be the name of a factory method on `MultiModalFieldConfig`:
# the encoder sends this name and the decoder looks the method up by name
# to rebuild the field.
MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
    MultiModalFlatField: "flat",
    MultiModalSharedField: "shared",
    MultiModalBatchedField: "batched",
}
2136
2237bytestr = Union [bytes , bytearray , memoryview , zmq .Frame ]
2338
@@ -27,14 +42,20 @@ class MsgpackEncoder:
2742
2843 Note that unlike vanilla `msgspec` Encoders, this interface is generally
2944 not thread-safe when encoding tensors / numpy arrays.
45+
46+ By default, arrays below 256B are serialized inline. Larger arrays are
47+ sent via dedicated messages. Note that this is a per-tensor limit.
3048 """
3149
def __init__(self, size_threshold: Optional[int] = None):
    """Create the underlying msgpack encoder and record the inline cutoff.

    Args:
        size_threshold: Arrays whose ``nbytes`` is below this value are
            encoded inline by ``_encode_ndarray``. When ``None``, the
            value of ``envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD`` is used.
    """
    self.size_threshold = (envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
                           if size_threshold is None else size_threshold)
    # `msgspec` offers no way to hand extra data to a custom hook, so the
    # buffers are stashed on the instance where `enc_hook` can reach them.
    self.aux_buffers: Optional[list[bytestr]] = None
    self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
3859
3960 def encode (self , obj : Any ) -> Sequence [bytestr ]:
4061 try :
@@ -65,6 +86,25 @@ def enc_hook(self, obj: Any) -> Any:
6586 if isinstance (obj , np .ndarray ) and obj .dtype .kind not in ('O' , 'V' ):
6687 return self ._encode_ndarray (obj )
6788
89+ if isinstance (obj , MultiModalKwargs ):
90+ mm : MultiModalKwargs = obj
91+ if not mm .modalities :
92+ # just return the main dict if there are no modalities.
93+ return dict (mm )
94+
95+ # ignore the main dict, it will be re-indexed.
96+ # Encode a list of MultiModalKwargsItems as plain dicts
97+ # + special handling for .field.
98+ # Any tensors *not* indexed by modality will be ignored.
99+ return [[{
100+ "modality" : elem .modality ,
101+ "key" : elem .key ,
102+ "data" : self ._encode_nested_tensors (elem .data ),
103+ "field" : self ._encode_mm_field (elem .field ),
104+ } for elem in item .values ()]
105+ for itemlist in mm ._items_by_modality .values ()
106+ for item in itemlist ]
107+
68108 if isinstance (obj , FunctionType ):
69109 # `pickle` is generally faster than cloudpickle, but can have
70110 # problems serializing methods.
@@ -77,8 +117,9 @@ def _encode_ndarray(
77117 self , obj : np .ndarray
78118 ) -> tuple [str , tuple [int , ...], Union [int , memoryview ]]:
79119 assert self .aux_buffers is not None
120+ # If the array is non-contiguous, we need to copy it first
80121 arr_data = obj .data if obj .data .c_contiguous else obj .tobytes ()
81- if not obj .shape or obj .nbytes < MIN_NOCOPY_BUF_SIZE :
122+ if not obj .shape or obj .nbytes < self . size_threshold :
82123 # Encode small arrays and scalars inline. Using this extension type
83124 # ensures we can avoid copying when decoding.
84125 data = msgpack .Ext (CUSTOM_TYPE_RAW_VIEW , arr_data )
@@ -92,6 +133,26 @@ def _encode_ndarray(
92133 # backing buffers that we've stashed in `aux_buffers`.
93134 return obj .dtype .str , obj .shape , data
94135
136+ def _encode_nested_tensors (self , nt : NestedTensors ) -> Any :
137+ if isinstance (nt , torch .Tensor ):
138+ return self ._encode_ndarray (nt .numpy ())
139+ if isinstance (nt , (int , float )):
140+ # Although it violates NestedTensors type, MultiModalKwargs
141+ # values are sometimes floats.
142+ return nt
143+ return [self ._encode_nested_tensors (x ) for x in nt ]
144+
def _encode_mm_field(self, field: BaseMultiModalField):
    """Encode a multimodal field as ``(factory_name, *field_values)``.

    The factory name is looked up in ``MMF_CLASS_TO_FACTORY`` by the
    field's class. The remaining entries are the field's dataclass
    attribute values, in declaration order, so the receiver can pass them
    back to the factory to rebuild the field.

    Raises:
        TypeError: If the field's class has no entry in
            ``MMF_CLASS_TO_FACTORY``.
    """
    field_cls = field.__class__
    factory_name = MMF_CLASS_TO_FACTORY.get(field_cls)
    if not factory_name:
        raise TypeError(f"Unsupported field type: {field_cls}")
    values = [
        getattr(field, spec.name) for spec in dataclasses.fields(field)
    ]
    return (factory_name, *values)
155+
95156
96157class MsgpackDecoder :
97158 """Decoder with custom torch tensor and numpy array serialization.
@@ -126,13 +187,50 @@ def dec_hook(self, t: type, obj: Any) -> Any:
126187 return self ._decode_ndarray (obj )
127188 if issubclass (t , torch .Tensor ):
128189 return torch .from_numpy (self ._decode_ndarray (obj ))
190+ if issubclass (t , MultiModalKwargs ):
191+ if isinstance (obj , list ):
192+ return MultiModalKwargs .from_items (
193+ self ._decode_mm_items (obj ))
194+ return MultiModalKwargs ({
195+ k : self ._decode_nested_tensors (v )
196+ for k , v in obj .items ()
197+ })
129198 return obj
130199
131200 def _decode_ndarray (self , arr : Any ) -> np .ndarray :
132201 dtype , shape , data = arr
133- buffer = self .aux_buffers [data ] if isinstance (data , int ) else data
202+ # Copy from inline representation, otherwise Torch is unhappy since
203+ # the returned memory is non-writeable.
204+ buffer = self .aux_buffers [data ] if isinstance (data , int ) \
205+ else bytearray (data )
134206 return np .ndarray (buffer = buffer , dtype = np .dtype (dtype ), shape = shape )
135207
def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
    """Rebuild ``MultiModalKwargsItem`` objects from their encoded form.

    Args:
        obj: A list of items, each a list of dicts with the keys written
            by the encoder for ``MultiModalKwargs`` (``modality``,
            ``key``, ``data``, ``field``). ``field`` holds a factory
            method name followed by that factory's arguments.

    Returns:
        One ``MultiModalKwargsItem`` per entry of ``obj``, in order.

    Raises:
        TypeError: If a ``field`` entry names a factory that is not one
            of the values in ``MMF_CLASS_TO_FACTORY``.
    """
    decoded_items = []
    for item in obj:
        elems = []
        for v in item:
            factory_meth_name, *field_args = v["field"]
            # The name comes off the wire. Only resolve it when it is one
            # of the factories the encoder can emit, rather than looking
            # up an arbitrary attribute on MultiModalFieldConfig.
            if factory_meth_name not in MMF_CLASS_TO_FACTORY.values():
                raise TypeError(
                    f"Unsupported field factory: {factory_meth_name!r}")
            factory_meth = getattr(MultiModalFieldConfig,
                                   factory_meth_name)
            # Build fresh kwargs instead of overwriting entries of `v`,
            # so the caller's decoded structure is left untouched.
            elem_kwargs = {
                **v,
                "data": self._decode_nested_tensors(v["data"]),
                # Reconstruct the field using MultiModalFieldConfig.
                "field": factory_meth(None, *field_args).field,
            }
            elems.append(MultiModalFieldElem(**elem_kwargs))
        decoded_items.append(MultiModalKwargsItem.from_elems(elems))
    return decoded_items
222+
223+ def _decode_nested_tensors (self , obj : Any ) -> NestedTensors :
224+ if isinstance (obj , (int , float )):
225+ # Although it violates NestedTensors type, MultiModalKwargs
226+ # values are sometimes floats.
227+ return obj
228+ if not isinstance (obj , list ):
229+ raise TypeError (f"Unexpected NestedTensors contents: { type (obj )} " )
230+ if obj and isinstance (obj [0 ], str ):
231+ return torch .from_numpy (self ._decode_ndarray (obj ))
232+ return [self ._decode_nested_tensors (x ) for x in obj ]
233+
136234 def ext_hook (self , code : int , data : memoryview ) -> Any :
137235 if code == CUSTOM_TYPE_RAW_VIEW :
138236 return data
0 commit comments