Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions vllm/v1/serial_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,18 +158,16 @@ def _encode_tensor(
self, obj: torch.Tensor
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
assert self.aux_buffers is not None
# this creates a copy of the tensor if it's not already contiguous
obj = obj.contiguous()
# view the tensor as a 1D array of bytes
arr = obj.view((obj.numel(), )).view(torch.uint8).numpy()
arr = obj.flatten().view(torch.uint8).numpy()
if obj.nbytes < self.size_threshold:
# Smaller tensors are encoded inline, just like ndarrays.
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
else:
# Otherwise encode index of backing buffer to avoid copy.
data = len(self.aux_buffers)
self.aux_buffers.append(arr.data)
dtype = str(obj.dtype)[6:] # remove 'torch.' prefix
dtype = str(obj.dtype).removeprefix("torch.")
return dtype, obj.shape, data

def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
Expand Down Expand Up @@ -245,7 +243,7 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray:
# zero-copy decode. We assume the ndarray will not be kept around,
# as it now locks the whole received message buffer in memory.
buffer = self.aux_buffers[data] if isinstance(data, int) else data
return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
return np.frombuffer(buffer, dtype=dtype).reshape(shape)

def _decode_tensor(self, arr: Any) -> torch.Tensor:
dtype, shape, data = arr
Expand All @@ -254,12 +252,15 @@ def _decode_tensor(self, arr: Any) -> torch.Tensor:
# not complain about a readonly memoryview.
buffer = self.aux_buffers[data] if isinstance(data, int) \
else bytearray(data)
# Create numpy wrapper around the bytes
arr = np.ndarray(buffer=buffer, dtype=np.uint8, shape=(len(buffer), ))
torch_dtype = getattr(torch, dtype)
assert isinstance(torch_dtype, torch.dtype)
if not buffer: # torch.frombuffer doesn't like empty buffers
assert 0 in shape
return torch.empty(shape, dtype=torch_dtype)
# Create uint8 array
arr = torch.frombuffer(buffer, dtype=torch.uint8)
# Convert back to proper shape & type
return torch.from_numpy(arr).view(torch_dtype).view(shape)
return arr.view(torch_dtype).view(shape)

def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
decoded_items = []
Expand Down