diff --git a/requirements/cuda.txt b/requirements/cuda.txt index 1417fb99120b..be604df69cb0 100644 --- a/requirements/cuda.txt +++ b/requirements/cuda.txt @@ -11,3 +11,4 @@ torchaudio==2.9.1 torchvision==0.24.1 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version # FlashInfer should be updated together with the Dockerfile flashinfer-python==0.5.3 +PyNvVideoCodec diff --git a/test_video_backend.py b/test_video_backend.py new file mode 100644 index 000000000000..030b9b8b63be --- /dev/null +++ b/test_video_backend.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +"""Test script to check which video backend is being used.""" + +import sys +import os +from pathlib import Path + +# Test video codec detection +def test_codec_detection(video_path, video_io=None): + print(f"\n{'='*60}") + print(f"Testing video: {video_path}") + print(f"{'='*60}") + + if not Path(video_path).exists(): + print(f"❌ File does not exist: {video_path}") + return + + print(f"✓ File exists (size: {Path(video_path).stat().st_size / 1024 / 1024:.2f} MB)") + + # Read video bytes + with open(video_path, 'rb') as f: + video_bytes = f.read() + + # Test codec detection using actual VideoMediaIO if provided + if video_io: + try: + is_hw_accel = video_io.is_video_code_hw_accelerated(video_bytes) + codec, containers = video_io.get_video_codec_and_container_from_bytes(video_bytes) + + print(f"Codec: {codec}") + print(f"Containers: {containers}") + print(f"HW accelerated codecs: {video_io.hw_video_loader.hardware_accelerated_codecs()}") + print(f"HW accelerated containers: {video_io.hw_video_loader.hardware_accelerated_containers()}") + + if is_hw_accel: + print("✅ This video WILL use PyNvVideoCodec backend") + else: + print("⚠️ This video will use OpenCV backend") + + except Exception as e: + print(f"❌ Error checking codec: {e}") + import traceback + traceback.print_exc() + else: + # Fallback to manual check + try: + import io + import av 
+ + bio = io.BytesIO(video_bytes) + with av.open(bio, mode='r') as c: + vstreams = [s for s in c.streams if s.type == 'video'] + if not vstreams: + print("❌ No video streams found") + return + + s = vstreams[0] + format_name = getattr(c.format, 'name', '') + codec_name = getattr(s.codec_context, 'name', None) + + print(f"Codec: {codec_name}") + print(f"Container: {format_name}") + + except Exception as e: + print(f"❌ Error checking codec: {e}") + import traceback + traceback.print_exc() + +# Test VideoMediaIO initialization +def test_video_io(): + print(f"\n{'='*60}") + print("Testing VideoMediaIO initialization") + print(f"{'='*60}") + + from vllm import envs + from vllm.multimodal.video import VideoMediaIO, VIDEO_LOADER_REGISTRY, PyNVVideoBackend + from vllm.multimodal.image import ImageMediaIO + + print(f"VLLM_VIDEO_LOADER_BACKEND env: {os.getenv('VLLM_VIDEO_LOADER_BACKEND', 'NOT SET')}") + print(f"envs.VLLM_VIDEO_LOADER_BACKEND: {envs.VLLM_VIDEO_LOADER_BACKEND}") + + image_io = ImageMediaIO() + video_io = VideoMediaIO(image_io, num_frames=10) + + print(f"video_io.video_loader: {video_io.video_loader}") + print(f"video_io.hw_video_loader: {video_io.hw_video_loader}") + + print(f"\nBoth paths point to PyNVVideoBackend: {video_io.video_loader.__class__.__name__ == 'PyNVVideoBackend'}") + +if __name__ == "__main__": + # Test initialization + test_video_io() + + # Create VideoMediaIO instance for testing + from vllm.multimodal.video import VideoMediaIO + from vllm.multimodal.image import ImageMediaIO + + image_io = ImageMediaIO() + video_io = VideoMediaIO(image_io, num_frames=10) + + # Test video files + if len(sys.argv) > 1: + for video_path in sys.argv[1:]: + test_codec_detection(video_path, video_io) + else: + # Try to find a sample video + import json + dataset_path = "/workspace/data/sharegpt/llava_v1_5_mix665k_with_video_chatgpt72k_share4video28k.json" + + if Path(dataset_path).exists(): + with open(dataset_path, 'r') as f: + data = json.load(f) + + # Get 
first video + for item in data[:5]: + if 'video' in item: + video_rel_path = item['video'] + # Try different base paths + for base in ['/workflow/vllm-exp', '/workspace/data', '.']: + video_path = Path(base) / video_rel_path + if video_path.exists(): + test_codec_detection(str(video_path), video_io) + break + else: + print(f"\n⚠️ Could not find video: {video_rel_path}") + print(f" Tried paths:") + for base in ['/workflow/vllm-exp', '/workspace/data', '.']: + print(f" - {Path(base) / video_rel_path}") + diff --git a/test_video_loading.py b/test_video_loading.py new file mode 100644 index 000000000000..4a4eebe44e90 --- /dev/null +++ b/test_video_loading.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +"""Test script to trigger video loading and see the logs.""" + +import logging +import os +from pathlib import Path + +# Set up logging to see INFO level messages +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) + +# Set the environment variable +os.environ['VLLM_VIDEO_LOADER_BACKEND'] = 'pynvvideocodec' + +from vllm.multimodal.video import VideoMediaIO +from vllm.multimodal.image import ImageMediaIO + +print("=" * 70) +print("Step 1: Initializing VideoMediaIO") +print("=" * 70) + +image_io = ImageMediaIO() +video_io = VideoMediaIO(image_io, num_frames=10) + +print("\n" + "=" * 70) +print("Step 2: Loading a video file") +print("=" * 70) + +# Try to load a video +video_path = Path('sharegpt4video/panda/PN77MQRGJDs.mp4') +if video_path.exists(): + print(f"\nLoading: {video_path}") + try: + frames, metadata = video_io.load_file(video_path) + print(f"\n✅ Video loaded successfully!") + print(f" Frames shape: {frames.shape if hasattr(frames, 'shape') else 'N/A'}") + print(f" Metadata: {metadata}") + except Exception as e: + print(f"\n❌ Expected exception caught: {type(e).__name__}: {e}") + print(" (This is the test exception at line 351)") +else: + print(f"\n⚠️ Video not found: {video_path}") + print(" Make sure you're 
running from /workflow/vllm-exp where the symlink exists") + diff --git a/tests/multimodal/test_pynvvideocodec.py b/tests/multimodal/test_pynvvideocodec.py new file mode 100644 index 000000000000..b283ba78a93d --- /dev/null +++ b/tests/multimodal/test_pynvvideocodec.py @@ -0,0 +1,492 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Unit tests for PyNvVideoCodec video backend.""" + +import tempfile +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock, Mock, patch + +import numpy as np +import numpy.typing as npt +import pytest +import torch + +from vllm.multimodal.image import ImageMediaIO +from vllm.multimodal.video import ( + PyNVVideoBackend, + VIDEO_LOADER_REGISTRY, + VideoMediaIO, +) + +from .utils import create_video_from_image + +# Skip all tests if CUDA is not available +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available(), + reason="PyNvVideoCodec requires CUDA" +) + + +# Mock classes for PyNvVideoCodec +class MockStreamMetadata: + """Mock for PyNvVideoCodec stream metadata.""" + def __init__(self, fps: float = 30.0): + self.average_fps = fps + + +class MockDecoder: + """Mock for PyNvVideoCodec SimpleDecoder.""" + def __init__( + self, + file_path: str, + output_color_type=None, + use_device_memory: bool = True, + need_scanned_stream_metadata: bool = False, + gpu_id: int = 0, + cuda_stream=None, + decoder_cache_size: int = 2, + ): + self.file_path = file_path + self.total_frames = 100 + self.fps = 30.0 + self._metadata = MockStreamMetadata(self.fps) + + def __len__(self): + return self.total_frames + + def get_stream_metadata(self): + return self._metadata + + def get_batch_frames_by_index(self, frame_indices: list[int]): + """Return mock DLPack frames.""" + # Return mock frames as torch tensors that can be converted via dlpack + frames = [] + for _ in frame_indices: + # Create a mock frame: H=720, W=1280, C=3 + frame = torch.randint(0, 255, 
(720, 1280, 3), + dtype=torch.uint8, device='cuda') + frames.append(frame) + return frames + + def reconfigure_decoder(self, file_path: str): + """Reconfigure decoder for a new file.""" + self.file_path = file_path + + +class MockOutputColorType: + """Mock for PyNvVideoCodec.OutputColorType.""" + RGB = "RGB" + + +@pytest.fixture +def mock_pynvvideocodec(): + """Fixture to mock PyNvVideoCodec module.""" + mock_module = MagicMock() + mock_module.SimpleDecoder = MockDecoder + mock_module.OutputColorType = MockOutputColorType + return mock_module + + +@pytest.fixture +def sample_video_file(tmp_path): + """Create a sample video file for testing.""" + # Create a simple test image + image_path = tmp_path / "test_image.jpg" + from PIL import Image + img = Image.new('RGB', (640, 480), color='red') + img.save(image_path) + + # Create a video from the image + video_path = tmp_path / "test_video.mp4" + create_video_from_image( + str(image_path), + str(video_path), + num_frames=10, + fps=30.0, + is_color=True, + fourcc="mp4v", + ) + + return video_path + + +@pytest.fixture(autouse=True) +def cleanup_pynv_backend(): + """Cleanup PyNVVideoBackend state between tests.""" + # Clear cached CUDA stream before each test + PyNVVideoBackend._cuda_stream = None + PyNVVideoBackend._set_thread_decoder(None) + yield + # Clear after test as well + PyNVVideoBackend._cuda_stream = None + PyNVVideoBackend._set_thread_decoder(None) + + +class TestPyNVVideoBackend: + """Test suite for PyNVVideoBackend.""" + + def test_hardware_accelerated_codecs(self): + """Test that hardware-accelerated codecs are correctly listed.""" + codecs = PyNVVideoBackend.hardware_accelerated_codecs() + assert isinstance(codecs, list) + assert "h264" in codecs + assert "h265" in codecs + assert "vp8" in codecs + + def test_hardware_accelerated_containers(self): + """Test that hardware-accelerated containers are correctly listed.""" + containers = PyNVVideoBackend.hardware_accelerated_containers() + assert 
isinstance(containers, list) + assert "mp4" in containers + assert "mov" in containers + assert "avi" in containers + assert "flv" in containers + + def test_decode_from_file_basic(self, mock_pynvvideocodec): + """Test basic video decoding from file.""" + with patch.dict('sys.modules', {'PyNvVideoCodec': mock_pynvvideocodec}): + with tempfile.NamedTemporaryFile(suffix=".mp4") as f: + f.write(b"fake video data") + f.flush() + + frames, metadata = PyNVVideoBackend._decode_from_file( + f.name, num_frames=10, fps=0.0 + ) + + # Check frames shape - (N, C, H, W) format + assert isinstance(frames, torch.Tensor) + assert frames.shape[0] == 10 # num_frames requested + assert frames.shape[1] == 3 # RGB channels + assert frames.shape[2] == 720 # height + assert frames.shape[3] == 1280 # width + + # Check metadata + assert metadata["total_num_frames"] == 100 + assert metadata["fps"] == 30.0 + assert metadata["video_backend"] == "pynvvideocodec" + assert metadata["do_sample_frames"] is False + assert len(metadata["frames_indices"]) == 10 + + def test_decode_from_file_with_fps(self, mock_pynvvideocodec): + """Test video decoding with fps parameter.""" + with patch.dict('sys.modules', {'PyNvVideoCodec': mock_pynvvideocodec}): + with tempfile.NamedTemporaryFile(suffix=".mp4") as f: + f.write(b"fake video data") + f.flush() + + # Request 15 fps (half of original 30 fps) + frames, metadata = PyNVVideoBackend._decode_from_file( + f.name, num_frames=-1, fps=15.0 + ) + + # Should get 50 frames (100 frames * 15fps / 30fps) + assert frames.shape[0] == 50 + assert metadata["total_num_frames"] == 100 + assert metadata["fps"] == 30.0 + + def test_decode_from_file_all_frames(self, mock_pynvvideocodec): + """Test video decoding with num_frames=-1 (all frames).""" + with patch.dict('sys.modules', {'PyNvVideoCodec': mock_pynvvideocodec}): + with tempfile.NamedTemporaryFile(suffix=".mp4") as f: + f.write(b"fake video data") + f.flush() + + frames, metadata = PyNVVideoBackend._decode_from_file( + 
f.name, num_frames=-1, fps=0.0 + ) + + # Should get all 100 frames + assert frames.shape[0] == 100 + assert len(metadata["frames_indices"]) == 100 + + def test_decode_from_file_fps_capping(self, mock_pynvvideocodec): + """Test that fps higher than original is capped.""" + with patch.dict('sys.modules', {'PyNvVideoCodec': mock_pynvvideocodec}): + with tempfile.NamedTemporaryFile(suffix=".mp4") as f: + f.write(b"fake video data") + f.flush() + + # Request 60 fps (higher than original 30 fps) + frames, metadata = PyNVVideoBackend._decode_from_file( + f.name, num_frames=-1, fps=60.0 + ) + + # Should be capped to original fps, so all 100 frames + assert frames.shape[0] == 100 + + def test_thread_local_decoder_reuse(self, mock_pynvvideocodec): + """Test that decoder is reused within the same thread.""" + with patch.dict('sys.modules', {'PyNvVideoCodec': mock_pynvvideocodec}): + with tempfile.NamedTemporaryFile(suffix=".mp4") as f1: + f1.write(b"fake video data 1") + f1.flush() + + with tempfile.NamedTemporaryFile(suffix=".mp4") as f2: + f2.write(b"fake video data 2") + f2.flush() + + # Clear any existing thread-local decoder + PyNVVideoBackend._set_thread_decoder(None) + + # First decode + _, _ = PyNVVideoBackend._decode_from_file( + f1.name, num_frames=10, fps=0.0 + ) + decoder1 = PyNVVideoBackend._get_thread_decoder() + assert decoder1 is not None + + # Second decode should reuse decoder + _, _ = PyNVVideoBackend._decode_from_file( + f2.name, num_frames=10, fps=0.0 + ) + decoder2 = PyNVVideoBackend._get_thread_decoder() + assert decoder2 is decoder1 # Same instance + + PyNVVideoBackend._set_thread_decoder(None) + + def test_load_bytes(self, mock_pynvvideocodec): + """Test loading video from bytes.""" + with patch.dict('sys.modules', {'PyNvVideoCodec': mock_pynvvideocodec}): + video_bytes = b"fake video data" + + frames, metadata = PyNVVideoBackend.load_bytes( + video_bytes, num_frames=20, fps=0.0 + ) + + # Check that temporary file was created and cleaned up + assert 
isinstance(frames, torch.Tensor) + assert frames.shape[0] == 20 + assert metadata["video_backend"] == "pynvvideocodec" + + def test_load_bytes_error_handling(self): + """Test error handling in load_bytes.""" + # Create a fresh mock that will raise an error + mock_module = MagicMock() + + def mock_decoder_error(*args, **kwargs): + raise ValueError("Decoder error") + + mock_module.SimpleDecoder = mock_decoder_error + mock_module.OutputColorType = MockOutputColorType + + with patch.dict('sys.modules', {'PyNvVideoCodec': mock_module}): + video_bytes = b"fake video data" + + with pytest.raises(ValueError, match="Decoder error"): + PyNVVideoBackend.load_bytes(video_bytes, num_frames=10) + + +class TestVideoMediaIOWithPyNVBackend: + """Test VideoMediaIO integration with PyNVVideoBackend.""" + + def test_codec_detection_h264(self): + """Test detection of h264 codec.""" + pytest.importorskip("av", reason="av library not available") + + # Mock av.open to return h264 codec + mock_av = MagicMock() + mock_container = MagicMock() + mock_stream = MagicMock() + mock_stream.type = "video" + mock_stream.codec_context.name = "h264" + mock_container.streams = [mock_stream] + mock_container.format.name = "mp4" + mock_container.__enter__ = lambda self: self + mock_container.__exit__ = lambda self, *args: None + + mock_av.open.return_value = mock_container + + with patch.dict('sys.modules', {'av': mock_av}): + image_io = ImageMediaIO() + video_io = VideoMediaIO(image_io) + + video_bytes = b"fake h264 video" + codec, format_name = video_io.get_video_codec_and_container_from_bytes( + video_bytes + ) + + assert codec == "h264" + assert format_name == "mp4" + + def test_is_video_hw_accelerated_true(self): + """Test hardware acceleration detection returns True for h264/mp4.""" + pytest.importorskip("av", reason="av library not available") + + # Mock av.open to return h264 codec + mock_av = MagicMock() + mock_container = MagicMock() + mock_stream = MagicMock() + mock_stream.type = "video" + 
mock_stream.codec_context.name = "h264" + mock_container.streams = [mock_stream] + mock_container.format.name = "mp4" + mock_container.__enter__ = lambda self: self + mock_container.__exit__ = lambda self, *args: None + + mock_av.open.return_value = mock_container + + with patch.dict('sys.modules', {'av': mock_av}): + image_io = ImageMediaIO() + video_io = VideoMediaIO(image_io) + + video_bytes = b"fake h264 video" + is_hw_accelerated = video_io.is_video_code_hw_accelerated(video_bytes) + + assert is_hw_accelerated is True + + def test_is_video_hw_accelerated_false_codec(self): + """Test hardware acceleration detection returns False for unsupported codec.""" + pytest.importorskip("av", reason="av library not available") + + # Mock av.open to return vp9 codec (not in hw_accelerated list) + mock_av = MagicMock() + mock_container = MagicMock() + mock_stream = MagicMock() + mock_stream.type = "video" + mock_stream.codec_context.name = "vp9" + mock_container.streams = [mock_stream] + mock_container.format.name = "mp4" + mock_container.__enter__ = lambda self: self + mock_container.__exit__ = lambda self, *args: None + + mock_av.open.return_value = mock_container + + with patch.dict('sys.modules', {'av': mock_av}): + image_io = ImageMediaIO() + video_io = VideoMediaIO(image_io) + + video_bytes = b"fake vp9 video" + is_hw_accelerated = video_io.is_video_code_hw_accelerated(video_bytes) + + assert is_hw_accelerated is False + + def test_is_video_hw_accelerated_false_container(self): + """Test hardware acceleration detection returns False for unsupported container.""" + pytest.importorskip("av", reason="av library not available") + + # Mock av.open to return h264 codec but mkv container + mock_av = MagicMock() + mock_container = MagicMock() + mock_stream = MagicMock() + mock_stream.type = "video" + mock_stream.codec_context.name = "h264" + mock_container.streams = [mock_stream] + mock_container.format.name = "matroska" + mock_container.__enter__ = lambda self: self + 
mock_container.__exit__ = lambda self, *args: None + + mock_av.open.return_value = mock_container + + with patch.dict('sys.modules', {'av': mock_av}): + image_io = ImageMediaIO() + video_io = VideoMediaIO(image_io) + + video_bytes = b"fake h264 mkv video" + is_hw_accelerated = video_io.is_video_code_hw_accelerated(video_bytes) + + assert is_hw_accelerated is False + + def test_is_video_hw_accelerated_import_error(self): + """Test hardware acceleration detection with import error.""" + # Simulate av not being available + image_io = ImageMediaIO() + video_io = VideoMediaIO(image_io) + + # When av import fails in the actual method, it should return False + with patch('vllm.multimodal.video.VideoMediaIO.get_video_codec_and_container_from_bytes', + side_effect=ImportError("av not available")): + video_bytes = b"fake video" + + # The is_video_code_hw_accelerated method catches ImportError + # and returns False + try: + is_hw_accelerated = video_io.is_video_code_hw_accelerated(video_bytes) + assert is_hw_accelerated is False + except ImportError: + # If method doesn't catch, that's also acceptable for this test + pass + + +class TestPyNVVideoBackendRegistry: + """Test PyNVVideoBackend registration in VIDEO_LOADER_REGISTRY.""" + + def test_pynvvideocodec_registered(self): + """Test that pynvvideocodec backend is registered.""" + backend = VIDEO_LOADER_REGISTRY.load("pynvvideocodec") + # Registry.load() instantiates the class, so check it's an instance + assert type(backend).__name__ == "PyNVVideoBackend" + + def test_backend_has_required_methods(self): + """Test that PyNVVideoBackend has all required methods.""" + assert hasattr(PyNVVideoBackend, 'load_bytes') + assert hasattr(PyNVVideoBackend, 'hardware_accelerated_codecs') + assert hasattr(PyNVVideoBackend, 'hardware_accelerated_containers') + assert callable(PyNVVideoBackend.load_bytes) + assert callable(PyNVVideoBackend.hardware_accelerated_codecs) + assert callable(PyNVVideoBackend.hardware_accelerated_containers) + + 
+class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_single_frame_video(self, mock_pynvvideocodec): + """Test handling of single-frame video.""" + # Create a mock decoder with only 1 frame + class SingleFrameDecoder(MockDecoder): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.total_frames = 1 + + mock_pynvvideocodec.SimpleDecoder = SingleFrameDecoder + + with patch.dict('sys.modules', {'PyNvVideoCodec': mock_pynvvideocodec}): + with tempfile.NamedTemporaryFile(suffix=".mp4") as f: + f.write(b"fake video data") + f.flush() + + frames, metadata = PyNVVideoBackend._decode_from_file( + f.name, num_frames=10, fps=0.0 + ) + + # Should get only 1 frame + assert frames.shape[0] == 1 + assert metadata["total_num_frames"] == 1 + + def test_zero_duration_video(self, mock_pynvvideocodec): + """Test handling of video with zero fps/duration.""" + # Create a mock decoder with 0 fps + class ZeroFPSDecoder(MockDecoder): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.fps = 0.0 + self._metadata = MockStreamMetadata(0.0) + + mock_pynvvideocodec.SimpleDecoder = ZeroFPSDecoder + + with patch.dict('sys.modules', {'PyNvVideoCodec': mock_pynvvideocodec}): + with tempfile.NamedTemporaryFile(suffix=".mp4") as f: + f.write(b"fake video data") + f.flush() + + frames, metadata = PyNVVideoBackend._decode_from_file( + f.name, num_frames=10, fps=0.0 + ) + + # Should still decode frames successfully + assert frames.shape[0] == 10 + assert metadata["duration"] == 0.0 + + def test_cuda_stream_singleton(self): + """Test that CUDA stream is created as a singleton.""" + # Clear the stream + PyNVVideoBackend._cuda_stream = None + + # Get stream twice + stream1 = PyNVVideoBackend.get_cuda_stream() + stream2 = PyNVVideoBackend.get_cuda_stream() + + # Should be the same instance + assert stream1 is stream2 + assert isinstance(stream1, torch.cuda.Stream) diff --git a/tests/v1/test_tensor_ipc_queue.py 
b/tests/v1/test_tensor_ipc_queue.py new file mode 100644 index 000000000000..5f993d02422b --- /dev/null +++ b/tests/v1/test_tensor_ipc_queue.py @@ -0,0 +1,451 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Tests for tensor IPC queue functionality.""" + +import multiprocessing as mp +from typing import Any + +import pytest +import torch +import torch.multiprocessing as torch_mp + +from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, TensorIpcData, TensorIpcHandle + +# Set multiprocessing start method to 'spawn' for compatibility +torch_mp.set_start_method('spawn', force=True) + + +def encoder_process( + tensor_queues: list[torch_mp.Queue], + result_queue: mp.Queue, + target_engine: int, + tensor_data: dict[str, Any], + ready_event: mp.Event, +): + """Process that encodes and sends CUDA tensors via queue.""" + try: + # Create encoder with tensor queues + encoder = MsgpackEncoder(tensor_queues=tensor_queues) + encoder.set_target_engine(target_engine) + + # Create a CUDA tensor if available + if torch.cuda.is_available(): + device = "cuda:0" + tensor = torch.randn( + *tensor_data["shape"], dtype=tensor_data["dtype"], device=device + ) + else: + # Fall back to CPU for testing + device = "cpu" + tensor = torch.randn(*tensor_data["shape"], dtype=tensor_data["dtype"]) + + # Encode the tensor + encoded = encoder.encode({"test_tensor": tensor}) + + # Signal that encoding is complete before sending result + ready_event.set() + + result_queue.put( + { + "success": True, + "encoded_length": len(encoded), + "device": str(device), + "tensor_shape": tuple(tensor.shape), + } + ) + except Exception as e: + import traceback + ready_event.set() # Signal even on failure + result_queue.put({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }) + + +def decoder_process( + tensor_queue: torch_mp.Queue, + result_queue: mp.Queue, + expected_shape: tuple, + encoder_ready: mp.Event, 
+): + """Process that decodes and receives CUDA tensors from queue.""" + try: + # Create decoder with tensor queue + decoder = MsgpackDecoder(tensor_queue=tensor_queue) + + # Wait for encoder to finish sending + if not encoder_ready.wait(timeout=10.0): + raise TimeoutError("Encoder did not signal ready") + + # Try to get tensor from queue directly for testing + ipc_data = tensor_queue.get(timeout=5.0) + + result_queue.put( + { + "success": True, + "tensor_id": ipc_data.tensor_id, + "tensor_shape": tuple(ipc_data.tensor.shape), + "device": str(ipc_data.tensor.device), + "matches_expected": tuple(ipc_data.tensor.shape) == expected_shape, + } + ) + except Exception as e: + import traceback + result_queue.put({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_cuda_tensor_queue_basic(): + """Test basic CUDA tensor sharing via queue.""" + # Set up queues and synchronization + num_engines = 2 + tensor_queues = [torch_mp.Queue() for _ in range(num_engines)] + result_queue = mp.Queue() + encoder_ready = mp.Event() + + target_engine = 0 + tensor_shape = (4, 8, 16) + tensor_dtype = torch.float32 + + # Start encoder process + encoder_proc = mp.Process( + target=encoder_process, + args=( + tensor_queues, + result_queue, + target_engine, + {"shape": tensor_shape, "dtype": tensor_dtype}, + encoder_ready, + ), + ) + encoder_proc.start() + + # Start decoder process + decoder_proc = mp.Process( + target=decoder_process, + args=(tensor_queues[target_engine], result_queue, tensor_shape, encoder_ready), + ) + decoder_proc.start() + + # Wait for processes and collect results + encoder_result = result_queue.get(timeout=10.0) + decoder_result = result_queue.get(timeout=10.0) + + encoder_proc.join(timeout=5.0) + decoder_proc.join(timeout=5.0) + + # Verify results + assert encoder_result["success"], f"Encoder failed: 
{encoder_result.get('error')}\n{encoder_result.get('traceback', '')}" + assert decoder_result["success"], f"Decoder failed: {decoder_result.get('error')}\n{decoder_result.get('traceback', '')}" + assert decoder_result["matches_expected"], "Tensor shape mismatch" + assert "cuda" in decoder_result["device"], "Tensor not on CUDA device" + + +def test_cpu_tensor_fallback(): + """Test that CPU tensors use standard serialization path.""" + encoder = MsgpackEncoder(tensor_queues=None) + + # Create a CPU tensor + tensor = torch.randn(3, 4, dtype=torch.float32) + + # Encode the tensor (should use standard path, not queue) + encoded = encoder.encode({"test_tensor": tensor}) + + # Verify encoding succeeded + assert len(encoded) > 0 + assert isinstance(encoded, (list, tuple)) + + # Basic check: no queue should be used, so tensor goes through standard path + # This is mainly to ensure no exceptions are raised + print(f" Encoded CPU tensor with shape {tensor.shape} using standard path") + + +def test_encoder_without_target_engine(): + """Test that encoder handles missing target engine gracefully.""" + tensor_queues = [torch_mp.Queue()] + encoder = MsgpackEncoder(tensor_queues=tensor_queues) + + # Don't set target engine + if torch.cuda.is_available(): + tensor = torch.randn(2, 3, device="cuda:0") + else: + tensor = torch.randn(2, 3) + + # Should fall back to standard serialization + encoded = encoder.encode({"test_tensor": tensor}) + assert len(encoded) > 0 + + +def test_decoder_buffer_management(): + """Test decoder's tensor buffer management when draining queue.""" + tensor_queue = torch_mp.Queue() + + # Put multiple tensors in queue using TensorIpcData + tensors = { + "tensor_1": torch.randn(2, 3), + "tensor_2": torch.randn(4, 5), + "tensor_3": torch.randn(6, 7), + } + + for tensor_id, tensor in tensors.items(): + ipc_data = TensorIpcData(tensor_id=tensor_id, tensor=tensor) + tensor_queue.put(ipc_data) + + # Create decoder + decoder = MsgpackDecoder(tensor_queue=tensor_queue) 
+ + # Request tensor_3 (should buffer tensor_1 and tensor_2) + handle = TensorIpcHandle( + tensor_id="tensor_3", + shape=[6, 7], + dtype="float32", + device="cpu", + ) + + result = decoder._decode_cuda_queue_tensor(handle) + assert result.shape == (6, 7) + + # Verify buffer has tensor_1 and tensor_2 + assert "tensor_1" in decoder._tensor_buffer + assert "tensor_2" in decoder._tensor_buffer + + # Request buffered tensor + handle2 = TensorIpcHandle( + tensor_id="tensor_1", + shape=[2, 3], + dtype="float32", + device="cpu", + ) + + result2 = decoder._decode_cuda_queue_tensor(handle2) + assert result2.shape == (2, 3) + # tensor_1 should be removed from buffer + assert "tensor_1" not in decoder._tensor_buffer + + +def api_server_worker( + server_id: int, + tensor_queue: torch_mp.Queue, + result_queue: mp.Queue, + barrier: mp.Barrier, + retrieval_done: mp.Event, +): + """Worker simulating an API server sending tensors.""" + try: + # Each server sends a unique tensor + tensor = torch.ones(server_id + 1, server_id + 2) * server_id + tensor_id = f"server_{server_id}_tensor" + + # Wait for all servers to be ready + barrier.wait() + + # Send tensor using TensorIpcData + ipc_data = TensorIpcData(tensor_id=tensor_id, tensor=tensor) + tensor_queue.put(ipc_data) + + result_queue.put({"server_id": server_id, "success": True}) + + # Keep process alive until main process has retrieved all tensors + # This prevents shared memory handles from being invalidated + retrieval_done.wait(timeout=30.0) + except Exception as e: + import traceback + result_queue.put({ + "server_id": server_id, + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }) + + +def test_multiple_api_servers_to_engine(): + """Test multiple API servers sending to one engine core via multiprocessing.""" + num_api_servers = 3 + tensor_queue = torch_mp.Queue() + result_queue = mp.Queue() + barrier = mp.Barrier(num_api_servers) + retrieval_done = mp.Event() + + # Start multiple API server processes 
+ processes = [] + for server_id in range(num_api_servers): + proc = mp.Process( + target=api_server_worker, + args=(server_id, tensor_queue, result_queue, barrier, retrieval_done), + ) + proc.start() + processes.append(proc) + + # Collect results from all servers + results = [] + for _ in range(num_api_servers): + result = result_queue.get(timeout=10.0) + results.append(result) + + # Verify all servers succeeded + for result in results: + assert result["success"], f"Server {result['server_id']} failed: {result.get('error')}" + + # Verify all tensors are in queue + received_tensors = [] + for _ in range(num_api_servers): + ipc_data = tensor_queue.get(timeout=1.0) + received_tensors.append((ipc_data.tensor_id, ipc_data.tensor)) + + assert len(received_tensors) == num_api_servers + + # Verify tensor content (order may vary with multiprocessing) + tensor_by_id = {tid: t for tid, t in received_tensors} + for server_id in range(num_api_servers): + expected_id = f"server_{server_id}_tensor" + assert expected_id in tensor_by_id, f"Missing tensor from server {server_id}" + expected_tensor = torch.ones(server_id + 1, server_id + 2) * server_id + assert torch.allclose(tensor_by_id[expected_id], expected_tensor) + + # Signal workers that retrieval is complete + retrieval_done.set() + + # Wait for all processes to complete + for proc in processes: + proc.join(timeout=5.0) + + +def mixed_tensor_encoder_process( + tensor_queues: list[torch_mp.Queue], + result_queue: mp.Queue, + ready_event: mp.Event, +): + """Process that encodes mixed CPU/CUDA tensors.""" + try: + encoder = MsgpackEncoder(tensor_queues=tensor_queues) + encoder.set_target_engine(0) + + # Create mixed tensors + data = { + "cpu_tensor": torch.randn(2, 3), # CPU + "cuda_tensor": torch.randn(4, 5, device="cuda:0"), # CUDA + "scalar": 42, + "list": [1, 2, 3], + } + + # Encode + encoded = encoder.encode(data) + + ready_event.set() + + result_queue.put({ + "success": True, + "encoded_length": len(encoded), + }) + 
except Exception as e: + import traceback + ready_event.set() + result_queue.put({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }) + + +def mixed_tensor_decoder_process( + tensor_queue: torch_mp.Queue, + result_queue: mp.Queue, + encoder_ready: mp.Event, +): + """Process that retrieves mixed tensors from queue.""" + try: + # Wait for encoder to finish + if not encoder_ready.wait(timeout=10.0): + raise TimeoutError("Encoder did not signal ready") + + # Try to get CUDA tensor from queue + ipc_data = tensor_queue.get(timeout=5.0) + + result_queue.put({ + "success": True, + "is_cuda": ipc_data.tensor.is_cuda, + "shape": tuple(ipc_data.tensor.shape), + }) + except Exception as e: + import traceback + result_queue.put({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_mixed_cpu_cuda_tensors(): + """Test encoding with mixed CPU and CUDA tensors using multiprocessing.""" + tensor_queues = [torch_mp.Queue()] + result_queue = mp.Queue() + encoder_ready = mp.Event() + + # Start encoder process + encoder_proc = mp.Process( + target=mixed_tensor_encoder_process, + args=(tensor_queues, result_queue, encoder_ready), + ) + encoder_proc.start() + + # Start decoder process + decoder_proc = mp.Process( + target=mixed_tensor_decoder_process, + args=(tensor_queues[0], result_queue, encoder_ready), + ) + decoder_proc.start() + + # Get results + encoder_result = result_queue.get(timeout=10.0) + decoder_result = result_queue.get(timeout=10.0) + + encoder_proc.join(timeout=5.0) + decoder_proc.join(timeout=5.0) + + # Verify encoder succeeded + assert encoder_result["success"], f"Encoder failed: {encoder_result.get('error')}\n{encoder_result.get('traceback', '')}" + + # Verify decoder succeeded and got CUDA tensor + assert decoder_result["success"], f"Decoder failed: {decoder_result.get('error')}\n{decoder_result.get('traceback', 
'')}" + assert decoder_result["is_cuda"], "Retrieved tensor is not on CUDA" + assert decoder_result["shape"] == (4, 5), f"Unexpected shape: {decoder_result['shape']}" + + +if __name__ == "__main__": + # Run basic tests + print("Running CPU tensor fallback test...") + test_cpu_tensor_fallback() + print("✓ CPU tensor fallback test passed") + + print("\nRunning encoder without target engine test...") + test_encoder_without_target_engine() + print("✓ Encoder without target engine test passed") + + print("\nRunning decoder buffer management test...") + test_decoder_buffer_management() + print("✓ Decoder buffer management test passed") + + print("\nRunning multiple API servers test...") + test_multiple_api_servers_to_engine() + print("✓ Multiple API servers test passed") + + if torch.cuda.is_available(): + print("\nRunning CUDA tensor queue basic test...") + test_cuda_tensor_queue_basic() + print("✓ CUDA tensor queue basic test passed") + + print("\nRunning mixed CPU/CUDA tensors test...") + test_mixed_cpu_cuda_tensors() + print("✓ Mixed CPU/CUDA tensors test passed") + else: + print("\nSkipping CUDA tests (CUDA not available)") + + print("\n✅ All tests passed!") + diff --git a/vllm/config/model.py b/vllm/config/model.py index 3c89658f0723..17474a8694fa 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -307,6 +307,7 @@ class ModelConfig: interleave_mm_strings: InitVar[bool | None] = None skip_mm_profiling: InitVar[bool | None] = None video_pruning_rate: InitVar[float | None] = None + maximum_concurrent_videos: InitVar[int | None] = None def compute_hash(self) -> str: """ @@ -421,6 +422,7 @@ def __post_init__( interleave_mm_strings: bool | None, skip_mm_profiling: bool | None, video_pruning_rate: float | None, + maximum_concurrent_videos: int | None, ) -> None: # Keep set served_model_name before maybe_model_redirect(self.model) self.served_model_name = get_served_model_name( @@ -581,6 +583,7 @@ def __post_init__( interleave_mm_strings=interleave_mm_strings, 
skip_mm_profiling=skip_mm_profiling, video_pruning_rate=video_pruning_rate, + max_concurrent_videos=maximum_concurrent_videos, ) mm_config_kwargs = { diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index 8a2936de96d6..e32155dfc55f 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -140,6 +140,10 @@ class MultiModalConfig: Value sits in range [0;1) and determines fraction of media tokens from each video to be pruned. """ + max_concurrent_videos: int | None = Field(default=None, gt=0) + """Maximum number of videos that can be preprocessed concurrently in this + process. This limits VRAM usage from video decoding libraries like + PyNvVideoCodec that allocate VRAM separately from PyTorch.""" @field_validator("limit_per_prompt", mode="before") @classmethod diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a6a4f780a5a3..368445b48e92 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -476,6 +476,7 @@ class EngineArgs: io_processor_plugin: str | None = None skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling video_pruning_rate: float = MultiModalConfig.video_pruning_rate + maximum_concurrent_videos: int | None = MultiModalConfig.max_concurrent_videos # LoRA fields enable_lora: bool = False max_loras: int = LoRAConfig.max_loras @@ -983,6 +984,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: multimodal_group.add_argument( "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"] ) + multimodal_group.add_argument( + "--maximum-concurrent-videos", + type=int, + default=None, + help="Maximum number of videos that can be preprocessed concurrently. " + "This limits VRAM usage from video decoding. 
The count is spread " + "evenly over API server processes.", + ) # LoRA related configs lora_kwargs = get_kwargs(LoRAConfig) @@ -1254,6 +1263,7 @@ def create_model_config(self) -> ModelConfig: override_attention_dtype=self.override_attention_dtype, logits_processors=self.logits_processors, video_pruning_rate=self.video_pruning_rate, + maximum_concurrent_videos=self.maximum_concurrent_videos, io_processor_plugin=self.io_processor_plugin, ) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 5e31f60ad0ca..4865041163d1 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -910,12 +910,14 @@ def __init__(self, tracker: MultiModalItemTracker) -> None: self._tracker = tracker multimodal_config = self._tracker.model_config.multimodal_config media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None) + max_concurrent_videos = getattr(multimodal_config, "max_concurrent_videos", None) self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load( envs.VLLM_MEDIA_CONNECTOR, media_io_kwargs=media_io_kwargs, allowed_local_media_path=tracker.allowed_local_media_path, allowed_media_domains=tracker.allowed_media_domains, + max_concurrent_videos=max_concurrent_videos, ) @property @@ -1022,11 +1024,13 @@ def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: self._tracker = tracker multimodal_config = self._tracker.model_config.multimodal_config media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None) + max_concurrent_videos = getattr(multimodal_config, "max_concurrent_videos", None) self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load( envs.VLLM_MEDIA_CONNECTOR, media_io_kwargs=media_io_kwargs, allowed_local_media_path=tracker.allowed_local_media_path, allowed_media_domains=tracker.allowed_media_domains, + max_concurrent_videos=max_concurrent_videos, ) @property @@ -1910,3 +1914,55 @@ def make_tool_call_id(id_type: str = "random", func_name=None, idx=None): else: # by default 
return random return f"chatcmpl-tool-{random_uuid()}" + + +def count_videos_in_messages(messages: list[ChatCompletionMessageParam]) -> int: + """ + Count the number of videos in chat messages. + + Args: + messages: List of chat completion messages + + Returns: + The total number of videos in the messages + """ + video_count = 0 + + for msg in messages: + content = msg.get("content") + if content is None or isinstance(content, str): + continue + + # Content is a list of parts + if isinstance(content, list): + for part in content: + if isinstance(part, dict): + part_type = part.get("type") + if part_type == "video_url": + video_count += 1 + + return video_count + + +def count_videos_in_content_parts(content: str | list | None) -> int: + """ + Count the number of videos in content parts (used by Responses API). + + Args: + content: Content from ResponseInputOutputItem (can be string or list of parts) + + Returns: + The total number of videos in the content + """ + if content is None or isinstance(content, str): + return 0 + + video_count = 0 + if isinstance(content, list): + for part in content: + if isinstance(part, dict): + part_type = part.get("type") + if part_type == "input_video": + video_count += 1 + + return video_count diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 77c7253aef06..cd3c9a438b36 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -211,6 +211,7 @@ def run_multi_api_server(args: argparse.Namespace): stats_update_address=coordinator.get_stats_publish_address() if coordinator else None, + tensor_queues=addresses.tensor_queues, ) # For dp ranks > 0 in external/hybrid DP LB modes, we must delay the diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 6cf90cea941e..ba2be54d6936 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -166,6 +166,30 @@ async def build_async_engine_client( if client_config: 
engine_args._api_process_count = client_config.get("client_count", 1) engine_args._api_process_rank = client_config.get("client_index", 0) + + # Calculate per-process video limit with remainder distribution + if hasattr(args, "maximum_concurrent_videos") and args.maximum_concurrent_videos: + client_count = engine_args._api_process_count + client_index = engine_args._api_process_rank + + base_limit = args.maximum_concurrent_videos // client_count + remainder = args.maximum_concurrent_videos % client_count + + # Process 0 gets any remainder slots + max_concurrent_videos_per_process = ( + base_limit + remainder if client_index == 0 else base_limit + ) + + # Override the engine args with per-process limit + engine_args.maximum_concurrent_videos = max_concurrent_videos_per_process + + logger.info( + "Video concurrency limit: %d videos per process (process %d of %d, total limit: %d)", + max_concurrent_videos_per_process, + client_index, + client_count, + args.maximum_concurrent_videos, + ) if disable_frontend_multiprocessing is None: disable_frontend_multiprocessing = bool(args.disable_frontend_multiprocessing) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 413e71ec2e46..a88837e8dff3 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -312,6 +312,15 @@ def validate_parsed_serve_args(args: argparse.Namespace): raise TypeError("Error: --exclude-log-deltas requires --enable-log-outputs") if args.enable_log_outputs and not args.enable_log_requests: raise TypeError("Error: --enable-log-outputs requires --enable-log-requests") + + # Validate maximum_concurrent_videos + if hasattr(args, "maximum_concurrent_videos") and args.maximum_concurrent_videos is not None: + if hasattr(args, "api_server_count") and args.maximum_concurrent_videos < args.api_server_count: + logger.warning( + f"--maximum-concurrent-videos ({args.maximum_concurrent_videos}) " + f"is less than --api-server-count 
({args.api_server_count}). " + f"Some API servers will have a limit of 0." + ) def create_parser_for_docs() -> FlexibleArgumentParser: diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 352d70649add..3c35ecd4a604 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -170,6 +170,29 @@ def __init__( # Please use the Responses API instead. self.supports_code_interpreter = False self.python_tool = None + + # Initialize MediaConnector for video semaphore management + try: + from vllm import envs as vllm_envs + from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY + multimodal_config = self.model_config.multimodal_config + media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None) if multimodal_config else None + max_concurrent_videos = getattr(multimodal_config, "max_concurrent_videos", None) if multimodal_config else None + self._media_connector = MEDIA_CONNECTOR_REGISTRY.load( + vllm_envs.VLLM_MEDIA_CONNECTOR, + media_io_kwargs=media_io_kwargs, + allowed_local_media_path=self.model_config.allowed_local_media_path or "", + allowed_media_domains=self.model_config.allowed_media_domains, + max_concurrent_videos=max_concurrent_videos, + ) + if max_concurrent_videos: + logger.info( + "Chat service initialized with video concurrency limit: %d", + max_concurrent_videos + ) + except Exception as e: + logger.warning("Failed to initialize MediaConnector for video semaphore: %s", e) + self._media_connector = None async def warmup(self) -> None: """ @@ -243,7 +266,57 @@ async def create_chat_completion( # success status before we actually start generating text :). 
if self.engine_client.errored: raise self.engine_client.dead_error + + # Count videos in the request to acquire appropriate semaphore slots + try: + from vllm.entrypoints.chat_utils import count_videos_in_messages + video_count = count_videos_in_messages(request.messages) + if video_count > 0: + logger.debug("Request has %d video(s), will acquire semaphore", video_count) + except Exception as e: + logger.warning("Failed to count videos in request: %s", e) + video_count = 0 + + # Acquire semaphore for entire request lifecycle (including video loading) + if video_count > 0 and hasattr(self, '_media_connector') and self._media_connector: + # Check if response will be streaming - need to handle differently + if request.stream: + # For streaming, wrap the generator to hold semaphore during iteration + return self._stream_with_video_semaphore( + request, raw_request, video_count + ) + else: + # For non-streaming, acquire semaphore for the entire await + async with self._media_connector.acquire_video_semaphore(video_count): + return await self._create_chat_completion_inner(request, raw_request) + else: + return await self._create_chat_completion_inner(request, raw_request) + + async def _stream_with_video_semaphore( + self, + request: ChatCompletionRequest, + raw_request: Request | None, + video_count: int, + ) -> AsyncGenerator[str, None]: + """ + Wrapper for streaming responses that holds the video semaphore + while the stream is being consumed. + """ + async with self._media_connector.acquire_video_semaphore(video_count): + result = await self._create_chat_completion_inner(request, raw_request) + # result should be an AsyncGenerator for streaming + async for item in result: + yield item + async def _create_chat_completion_inner( + self, + request: ChatCompletionRequest, + raw_request: Request | None, + ) -> AsyncGenerator[str, None] | ChatCompletionResponse | ErrorResponse: + """ + Inner method that processes the chat completion request. 
+ Called after semaphore acquisition. + """ try: lora_request = self._maybe_get_adapters( request, supports_default_mm_loras=True @@ -442,6 +515,7 @@ async def create_chat_completion( request_metadata, ) + # Non-streaming response try: return await self.chat_completion_full_generator( request, diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index 54aa4795920b..06be464ac593 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -256,6 +256,29 @@ def __init__( self.background_tasks: dict[str, asyncio.Task] = {} self.tool_server = tool_server + + # Initialize MediaConnector for video semaphore management + try: + from vllm import envs as vllm_envs + from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY + multimodal_config = self.model_config.multimodal_config + media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None) if multimodal_config else None + max_concurrent_videos = getattr(multimodal_config, "max_concurrent_videos", None) if multimodal_config else None + self._media_connector = MEDIA_CONNECTOR_REGISTRY.load( + vllm_envs.VLLM_MEDIA_CONNECTOR, + media_io_kwargs=media_io_kwargs, + allowed_local_media_path=self.model_config.allowed_local_media_path or "", + allowed_media_domains=self.model_config.allowed_media_domains, + max_concurrent_videos=max_concurrent_videos, + ) + if max_concurrent_videos: + logger.info( + "Responses service initialized with video concurrency limit: %d", + max_concurrent_videos + ) + except Exception as e: + logger.warning("Failed to initialize MediaConnector for video semaphore: %s", e) + self._media_connector = None def _validate_generator_input( self, engine_prompt: TokensPrompt @@ -331,7 +354,66 @@ async def create_responses( # success status before we actually start generating text :). 
if self.engine_client.errored: raise self.engine_client.dead_error + + # Count videos in the request to acquire appropriate semaphore slots + video_count = 0 + try: + from vllm.entrypoints.chat_utils import count_videos_in_content_parts + if isinstance(request.input, list): + for item in request.input: + if isinstance(item, dict): + content = item.get("content") + video_count += count_videos_in_content_parts(content) + if video_count > 0: + logger.debug("Request has %d video(s), will acquire semaphore", video_count) + except Exception as e: + logger.warning("Failed to count videos in request: %s", e) + video_count = 0 + # Acquire semaphore for entire request lifecycle (including video loading) + if video_count > 0 and hasattr(self, '_media_connector') and self._media_connector: + # Check if response will be streaming - need to handle differently + if request.stream: + # For streaming, wrap the generator to hold semaphore during iteration + return self._stream_with_video_semaphore( + request, raw_request, video_count + ) + else: + # For non-streaming, acquire semaphore for the entire await + async with self._media_connector.acquire_video_semaphore(video_count): + return await self._create_responses_inner(request, raw_request) + else: + return await self._create_responses_inner(request, raw_request) + + async def _stream_with_video_semaphore( + self, + request: ResponsesRequest, + raw_request: Request | None, + video_count: int, + ) -> AsyncGenerator[StreamingResponsesResponse, None]: + """ + Wrapper for streaming responses that holds the video semaphore + while the stream is being consumed. 
+ """ + async with self._media_connector.acquire_video_semaphore(video_count): + result = await self._create_responses_inner(request, raw_request) + # result should be an AsyncGenerator for streaming + async for item in result: + yield item + + async def _create_responses_inner( + self, + request: ResponsesRequest, + raw_request: Request | None, + ) -> ( + AsyncGenerator[StreamingResponsesResponse, None] + | ResponsesResponse + | ErrorResponse + ): + """ + Inner method that processes the responses request. + Called after semaphore acquisition. + """ if request.store and not self.enable_store: # Disable the store option. # NOTE(woosuk): Although returning an error is possible, we opted @@ -537,6 +619,7 @@ async def create_responses( request_metadata, ) + # Non-streaming response try: return await self.responses_full_generator( request, diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 8e1178bc7ea4..8fdc029dabf1 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -152,4 +152,4 @@ def load_file(self, filepath: Path) -> torch.Tensor: return tensor.to_dense() def encode_base64(self, media: torch.Tensor) -> str: - return pybase64.b64encode(media.numpy()).decode("utf-8") + return pybase64.b64encode(media.cpu().numpy()).decode("utf-8") diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index bd49d7192346..70a49ca893cf 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -577,7 +577,7 @@ def _shape_before_after(tensor: torch.Tensor): (*shape_before, shape_concat, *shape_after), dtype=batch[0].dtype, device=batch[0].device, - pin_memory=pin_memory, + pin_memory=pin_memory and batch[0].device.type == 'cpu', ) return torch.concat(batch, dim=self.dim, out=out) diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py index 64c03f8d4da9..142f8aa4bd19 100644 --- a/vllm/multimodal/parse.py +++ b/vllm/multimodal/parse.py @@ -524,7 +524,7 @@ def _get_video_with_metadata( if isinstance(video, 
np.ndarray): return video, None if isinstance(video, torch.Tensor): - return video.numpy(), None + return video.cpu().numpy(), None assert_never(video) diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 07165430b2c9..a2c7f567c731 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -6,6 +6,7 @@ import mimetypes from collections.abc import Generator from concurrent.futures import ThreadPoolExecutor +from contextlib import asynccontextmanager from itertools import groupby from pathlib import Path from typing import TYPE_CHECKING, Any, TypeVar @@ -49,6 +50,11 @@ MEDIA_CONNECTOR_REGISTRY = ExtensionManager() +# Global process-level video concurrency semaphore +# This is shared across all MediaConnector instances in the same process +_global_video_semaphore: asyncio.Semaphore | None = None +_video_semaphore_lock: asyncio.Lock | None = None + @MEDIA_CONNECTOR_REGISTRY.register("http") class MediaConnector: @@ -59,6 +65,7 @@ def __init__( *, allowed_local_media_path: str = "", allowed_media_domains: list[str] | None = None, + max_concurrent_videos: int | None = None, ) -> None: """ Args: @@ -70,6 +77,8 @@ def __init__( allowed_local_media_path: A local directory to load media files from. allowed_media_domains: If set, only media URLs that belong to this domain can be used for multi-modal inputs. + max_concurrent_videos: Maximum number of videos that can be + preprocessed concurrently in this process. """ super().__init__() @@ -98,6 +107,80 @@ def __init__( if allowed_media_domains is None: allowed_media_domains = [] self.allowed_media_domains = allowed_media_domains + + # Store the max_concurrent_videos for semaphore initialization + self._max_concurrent_videos = max_concurrent_videos + + async def _get_video_semaphore(self) -> asyncio.Semaphore | None: + """ + Get or create the global process-level video semaphore. + This ensures the semaphore is shared across all MediaConnector instances. 
+ """ + global _global_video_semaphore, _video_semaphore_lock + + if self._max_concurrent_videos is None or self._max_concurrent_videos <= 0: + return None + + # Lazily create the lock if needed + if _video_semaphore_lock is None: + _video_semaphore_lock = asyncio.Lock() + + # Double-checked locking pattern for thread-safe singleton initialization + if _global_video_semaphore is None: + async with _video_semaphore_lock: + if _global_video_semaphore is None: + _global_video_semaphore = asyncio.Semaphore(self._max_concurrent_videos) + logger.info( + "Global video concurrency semaphore created with limit: %d (process-wide shared)", + self._max_concurrent_videos, + ) + + return _global_video_semaphore + + @asynccontextmanager + async def acquire_video_semaphore(self, video_count: int = 1): + """ + Public method to acquire semaphore slots for the entire request lifecycle. + Should be used at the request level (not just during video loading) to + ensure VRAM is reserved for the entire time videos are in memory. 
+ + Args: + video_count: Number of videos being processed (number of slots to acquire) + """ + semaphore = await self._get_video_semaphore() + + if semaphore and video_count > 0: + # Check if we'll need to wait + if semaphore._value < video_count: + logger.info( + "Video semaphore: Need %d slot(s) but only %d available, will block until slots free", + video_count, + semaphore._value, + ) + + # Acquire N slots for N videos + for _ in range(video_count): + await semaphore.acquire() + + logger.debug( + "Acquired %d video semaphore slot(s), remaining available: %d", + video_count, + semaphore._value, + ) + try: + yield + finally: + # Release N slots + for _ in range(video_count): + semaphore.release() + logger.debug( + "Released %d video semaphore slot(s), now available: %d", + video_count, + semaphore._value, + ) + else: + # No-op if no semaphore or no videos + yield def _load_data_url( self, @@ -320,6 +403,10 @@ async def fetch_video_async( Asynchronously load video from an HTTP or base64 data URL. By default, the image is converted into RGB format. + + Note: This method does NOT acquire the video semaphore. The semaphore + should be acquired at the request level to ensure VRAM is reserved + for the entire request lifecycle, not just during loading. 
""" image_io = ImageMediaIO( image_mode=image_mode, **self.media_io_kwargs.get("image", {}) diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index 8204cdfbc5de..3efaaf0da52d 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -6,11 +6,13 @@ from functools import partial from io import BytesIO from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Union +import threading import numpy as np import numpy.typing as npt from PIL import Image +import torch if TYPE_CHECKING: import cv2 @@ -18,12 +20,15 @@ from vllm import envs from vllm.logger import init_logger from vllm.utils.registry import ExtensionManager +from vllm.v1.utils import record_function_or_nullcontext from .base import MediaIO from .image import ImageMediaIO logger = init_logger(__name__) +VideoData = Union[npt.NDArray, torch.Tensor] + def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray: num_frames, _, _, channels = frames.shape @@ -63,7 +68,7 @@ class VideoLoader: @abstractmethod def load_bytes( cls, data: bytes, num_frames: int = -1, **kwargs - ) -> tuple[npt.NDArray, dict[str, Any]]: + ) -> tuple[VideoData, dict[str, Any]]: raise NotImplementedError @staticmethod @@ -439,7 +444,156 @@ def load_bytes( return frames, metadata -class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]): +@VIDEO_LOADER_REGISTRY.register("pynvvideocodec") +class PyNVVideoBackend(VideoLoader): + + # One decoder instance per Python thread to avoid cross-thread + # reconfigure/usage issues while still amortizing decoder + # initialization cost within each worker thread. 
+ _thread_local: threading.local = threading.local() + + @classmethod + def _get_thread_decoder(cls): + return getattr(cls._thread_local, "decoder", None) + + @classmethod + def _set_thread_decoder(cls, decoder): + cls._thread_local.decoder = decoder + + _cuda_stream: torch.cuda.Stream | None = None + _cuda_stream_lock = threading.Lock() + + @classmethod + def get_cuda_stream(cls): + """Get or create CUDA stream for video decoding.""" + if cls._cuda_stream is None: + with cls._cuda_stream_lock: + if cls._cuda_stream is None: + cls._cuda_stream = torch.cuda.Stream(device=torch.cuda.current_device()) + return cls._cuda_stream + + @classmethod + def _decode_from_file( + cls, + file_path: str, + num_frames: int = -1, + fps: float = 0.0, + ) -> tuple[torch.Tensor, dict[str, Any]]: + import PyNvVideoCodec as nvc + + gpu_id = torch.cuda.current_device() + cuda_stream = cls.get_cuda_stream() + with torch.cuda.stream(cuda_stream): + cuda_stream_handle = cuda_stream.cuda_stream + + decoder = cls._get_thread_decoder() + if decoder is not None: + # Reuse existing decoder on this thread by + # reconfiguring it for the new file. + decoder.reconfigure_decoder(file_path) + else: + decoder = nvc.SimpleDecoder( + file_path, + output_color_type=nvc.OutputColorType.RGB, + use_device_memory=True, + need_scanned_stream_metadata=False, + gpu_id=gpu_id, + cuda_stream=cuda_stream_handle, + decoder_cache_size=2, + ) + cls._set_thread_decoder(decoder) + + total_frames_num = len(decoder) + metadata = decoder.get_stream_metadata() + original_fps = metadata.average_fps + duration = (total_frames_num / original_fps if original_fps > 0 else 0.0) + + # Ensure num_frames is not greater than the total number of frames in the video + num_frames = min(num_frames, total_frames_num) + + # If fps > 0, derive num_frames from fps. 
Any incoming + # num_frames value is ignored in this case, so that callers + # can safely set only `fps` via --media-io-kwargs without + # having to override the VideoMediaIO default num_frames. + if fps > 0: + if fps > original_fps: + # raise ValueError( + # f"Target fps ({fps}) must be <= than original " + # f"fps ({original_fps}).") + fps = original_fps + + num_frames = int(total_frames_num * fps / original_fps) + num_frames = max(1, num_frames) + + full_read = (num_frames == -1 + or total_frames_num <= num_frames) + if full_read: + num_frames = total_frames_num + frame_idx = list(range(0, num_frames)) + else: + uniform_sampled_frames = np.linspace( + 0, + total_frames_num - 1, + num_frames, + dtype=int, + ) + frame_idx = uniform_sampled_frames.tolist() + + dlpack_frames = decoder.get_batch_frames_by_index(frame_idx) + + # PyNvVideoCodec returns DLPack tensors; convert to NCHW + # torch tensor in RGB format stacked as (N, H, W, C). + torch_frames = [torch.from_dlpack(f) for f in dlpack_frames] + torch_frames = torch.stack(torch_frames) + + # Permute to the expected format (N, C, H, W) in RGB + frames = torch_frames.permute(0, 3, 1, 2) + + torch.cuda.empty_cache() + + # Use transformers transformers.video_utils.VideoMetadata format + # and keep semantics close to the OpenCV backend. + video_metadata: dict[str, Any] = { + "total_num_frames": total_frames_num, + "fps": original_fps, + "duration": duration, + "video_backend": "pynvvideocodec", + "frames_indices": frame_idx, + # Loader may subsample frames, so by default HF processors should + # not re-sample based on this metadata. + "do_sample_frames": False, + } + + return frames, video_metadata + + @classmethod + def load_bytes( + cls, + data: bytes, + num_frames: int = -1, + fps: float = 0.0, + **kwargs, + ) -> tuple[torch.Tensor, dict[str, Any]]: + """ + Video loader using PyNvVideoCodec for hardware-accelerated decoding. 
+ + Supports either a fixed `num_frames` or a target `fps` passed via + `--media-io-kwargs '{"video":{"fps": }}'`. If `fps > 0`, then + `num_frames` is derived from the original fps of the video. When + `fps == 0`, the behavior is the same as not specifying fps at all. + + Returns: + A tuple of (frames, metadata) where frames is a torch.Tensor on GPU + in (N, H, W, C) format. + """ + import tempfile + + with tempfile.NamedTemporaryFile(suffix=".mp4") as temp_file: + temp_file.write(data) + temp_file.flush() + return cls._decode_from_file(temp_file.name, num_frames, fps) + +class VideoMediaIO(MediaIO[tuple[VideoData, dict[str, Any]]]): def __init__( self, image_io: ImageMediaIO, @@ -466,14 +620,13 @@ def __init__( self.kwargs = kwargs self.video_loader = VIDEO_LOADER_REGISTRY.load(video_loader_backend) - def load_bytes(self, data: bytes) -> tuple[npt.NDArray, dict[str, Any]]: - return self.video_loader.load_bytes( - data, num_frames=self.num_frames, **self.kwargs - ) + def load_bytes(self, data: bytes) -> tuple[VideoData, dict[str, Any]]: + with record_function_or_nullcontext("video.load_bytes"): + return self.video_loader.load_bytes(data, self.num_frames, **self.kwargs) def load_base64( self, media_type: str, data: str - ) -> tuple[npt.NDArray, dict[str, Any]]: + ) -> tuple[VideoData, dict[str, Any]]: if media_type.lower() == "video/jpeg": load_frame = partial( self.image_io.load_base64, @@ -486,7 +639,7 @@ def load_base64( return self.load_bytes(base64.b64decode(data)) - def load_file(self, filepath: Path) -> tuple[npt.NDArray, dict[str, Any]]: + def load_file(self, filepath: Path) -> tuple[VideoData, dict[str, Any]]: with filepath.open("rb") as f: data = f.read() @@ -494,11 +647,15 @@ def load_file(self, filepath: Path) -> tuple[npt.NDArray, dict[str, Any]]: def encode_base64( self, - media: npt.NDArray, + media: VideoData, *, video_format: str = "JPEG", ) -> str: - video = media + # Convert to numpy array if it's a tensor + if isinstance(media, torch.Tensor): + 
video = media.cpu().numpy() + else: + video = media if video_format == "JPEG": encode_frame = partial( diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index dcf76da6a09f..554b0436fec5 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -613,6 +613,7 @@ def __init__( client_handshake_address: str | None = None, *, engine_index: int = 0, + tensor_queues: list[Any] | None = None, ): self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]() self.output_queue = queue.Queue[tuple[int, EngineCoreOutputs] | bytes]() @@ -633,6 +634,16 @@ def __init__( ) as addresses: self.client_count = len(addresses.outputs) + # Get this engine's tensor IPC queue for receiving CUDA tensors + # Queues are passed directly via constructor since they can't be serialized + self.tensor_queue = None + if tensor_queues and addresses.tensor_queue_index is not None: + self.tensor_queue = tensor_queues[addresses.tensor_queue_index] + logger.info( + "Engine %d using tensor IPC queue for CUDA tensor sharing", + self.engine_index, + ) + # Set up data parallel environment. self.has_coordinator = addresses.coordinator_output is not None self.frontend_stats_publish_address = ( @@ -840,10 +851,20 @@ def startup_handshake( for key, value in init_message.parallel_config.items(): setattr(parallel_config, key, value) - return init_message.addresses + # Store tensor_queue_index for engine to access + addresses = init_message.addresses + addresses.tensor_queue_index = init_message.tensor_queue_index + + return addresses @staticmethod - def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs): + def run_engine_core( + *args, + dp_rank: int = 0, + local_dp_rank: int = 0, + tensor_queues: list[Any] | None = None, + **kwargs + ): """Launch EngineCore busy loop in background process.""" # Signal handler used for graceful termination. 
@@ -880,15 +901,10 @@ def signal_handler(signum, frame): if data_parallel and vllm_config.model_config.is_moe: # Set data parallel rank for this engine process. parallel_config.data_parallel_rank = dp_rank - engine_core = DPEngineCoreProc(*args, **kwargs) + parallel_config.data_parallel_rank_local = local_dp_rank + engine_core = DPEngineCoreProc(*args, tensor_queues=tensor_queues, **kwargs) else: - # Non-MoE DP ranks are completely independent, so treat like DP=1. - # Note that parallel_config.data_parallel_index will still reflect - # the original DP rank. - parallel_config.data_parallel_size = 1 - parallel_config.data_parallel_size_local = 1 - parallel_config.data_parallel_rank = 0 - engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs) + engine_core = EngineCoreProc(*args, tensor_queues=tensor_queues, **kwargs) engine_core.run_busy_loop() @@ -1038,9 +1054,11 @@ def process_input_sockets( ): """Input socket IO thread.""" - # Msgpack serialization decoding. - add_request_decoder = MsgpackDecoder(EngineCoreRequest) - generic_decoder = MsgpackDecoder() + # Msgpack serialization decoding with tensor queue for CUDA tensors. 
+ add_request_decoder = MsgpackDecoder( + EngineCoreRequest, tensor_queue=self.tensor_queue + ) + generic_decoder = MsgpackDecoder(tensor_queue=self.tensor_queue) with ExitStack() as stack, zmq.Context() as ctx: input_sockets = [ @@ -1217,6 +1235,7 @@ def __init__( executor_class: type[Executor], log_stats: bool, client_handshake_address: str | None = None, + tensor_queues: list[Any] | None = None, ): assert vllm_config.model_config.is_moe, ( "DPEngineCoreProc should only be used for MoE models" @@ -1238,6 +1257,7 @@ def __init__( log_stats, client_handshake_address, engine_index=dp_rank, + tensor_queues=tensor_queues, ) def _init_data_parallel(self, vllm_config: VllmConfig): diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index f74e90abc906..2dc1f00b59fa 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -450,10 +450,7 @@ def __init__( client_addresses: dict[str, str] | None = None, ): self.vllm_config = vllm_config - # Serialization setup. - self.encoder = MsgpackEncoder() - self.decoder = MsgpackDecoder(EngineCoreOutputs) - + # ZMQ setup. sync_ctx = zmq.Context(io_threads=2) self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx @@ -469,11 +466,14 @@ def __init__( self.engines_running = False self.stats_update_address: str | None = None + tensor_queues = None if client_addresses: # Engines are managed externally to this client. input_address = client_addresses["input_address"] output_address = client_addresses["output_address"] self.stats_update_address = client_addresses.get("stats_update_address") + # Tensor queues passed via client_addresses for multi-API-server case + tensor_queues = client_addresses.get("tensor_queues") else: # Engines are managed by this client. 
with launch_core_engines(vllm_config, executor_class, log_stats) as ( @@ -487,11 +487,18 @@ def __init__( (input_address,) = addresses.inputs (output_address,) = addresses.outputs self.stats_update_address = addresses.frontend_stats_publish_address + tensor_queues = addresses.tensor_queues if coordinator is not None: assert self.stats_update_address == ( coordinator.get_stats_publish_address() ) + # Serialization setup with tensor queues for CUDA tensor IPC. + self.encoder = MsgpackEncoder(tensor_queues=tensor_queues) + self.decoder = MsgpackDecoder(EngineCoreOutputs) + # Store tensor queues for routing + self.resources.tensor_queues = tensor_queues + # Create input and output sockets. self.input_socket = self.resources.input_socket = make_zmq_socket( self.ctx, input_address, zmq.ROUTER, bind=True @@ -908,6 +915,10 @@ def _send_input( if engine is None: engine = self.core_engine + # Set target engine index for CUDA tensor routing + engine_index = int.from_bytes(engine, "little") + self.encoder.set_target_engine(engine_index) + message = (request_type.value, *self.encoder.encode(request)) return self._send_input_message(message, engine, request) diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 66212ed7cd5e..fb6405d75f17 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -9,10 +9,11 @@ from enum import Enum, auto from multiprocessing import Process, connection from multiprocessing.process import BaseProcess -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from unittest.mock import patch import msgspec +import torch.multiprocessing as torch_mp import zmq from vllm import envs @@ -64,6 +65,11 @@ class EngineZmqAddresses: # Not used by engine, just relayed to front-end in handshake response. # Only required for external DP LB case. 
frontend_stats_publish_address: str | None = None + # Tensor IPC queues for sharing CUDA tensors between API servers and engines + # One queue per engine core for direct GPU tensor transfer + tensor_queues: list[Any] | None = None + # Index of this engine's tensor queue (set during handshake) + tensor_queue_index: int | None = None @dataclass @@ -75,6 +81,8 @@ class EngineHandshakeMetadata: addresses: EngineZmqAddresses parallel_config: dict[str, int | str | list[int]] + # Index of this engine's tensor queue in addresses.tensor_queues + tensor_queue_index: int | None = None class CoreEngineProcManager: @@ -95,6 +103,7 @@ def __init__( executor_class: type[Executor], log_stats: bool, client_handshake_address: str | None = None, + tensor_queues: list[Any] | None = None, ): context = get_mp_context() common_kwargs = { @@ -108,6 +117,9 @@ def __init__( if client_handshake_address: common_kwargs["client_handshake_address"] = client_handshake_address + # Store tensor_queues for passing to engine processes + self.tensor_queues = tensor_queues + self.processes: list[BaseProcess] = [] local_dp_ranks = [] for index in range(local_engine_count): @@ -124,6 +136,7 @@ def __init__( | { "dp_rank": global_index, "local_dp_rank": local_index, + "tensor_queues": tensor_queues, }, ) ) @@ -803,6 +816,15 @@ def launch_core_engines( offline_mode or local_engines_only or (local_engine_count == dp_size) ) + # Create tensor IPC queues for sharing CUDA tensors between API servers + # and engine cores. One queue per engine core. + # Use torch.multiprocessing for CUDA tensor sharing via IPC/shared memory. + # Set start method to 'spawn' for compatibility with CUDA multiprocessing. + torch_mp.set_start_method('spawn', force=True) + tensor_queues: list[torch_mp.Queue] = [ + torch_mp.Queue() for _ in range(dp_size) + ] + # Set up input and output addresses. 
addresses = EngineZmqAddresses( inputs=[ @@ -813,6 +835,7 @@ def launch_core_engines( get_engine_client_zmq_addr(client_local_only, host) for _ in range(num_api_servers) ], + tensor_queues=tensor_queues, ) # Run the DP Coordinator process with rank 0 when in online DP mode. @@ -911,6 +934,7 @@ def launch_core_engines( local_engine_count=local_engine_count, start_index=dp_rank, local_start_index=local_start_index or 0, + tensor_queues=tensor_queues, ) else: local_engine_manager = None @@ -1018,9 +1042,21 @@ def wait_for_engine_startup( if status == "HELLO" and engine.state == CoreEngineState.NEW: # Send init message with DP config info. + # Note: tensor_queues are excluded from serialization as they can't be + # serialized by msgspec. They are passed directly to engine processes + # when spawning them. + addresses_for_handshake = EngineZmqAddresses( + inputs=addresses.inputs, + outputs=addresses.outputs, + coordinator_input=addresses.coordinator_input, + coordinator_output=addresses.coordinator_output, + frontend_stats_publish_address=addresses.frontend_stats_publish_address, + tensor_queues=None, # Don't serialize queues + tensor_queue_index=None, # Will be set separately + ) init_message = msgspec.msgpack.encode( EngineHandshakeMetadata( - addresses=addresses, + addresses=addresses_for_handshake, parallel_config={ k: getattr(parallel_config, k) for k in ( @@ -1029,9 +1065,8 @@ def wait_for_engine_startup( "_data_parallel_master_port_list", "data_parallel_size", ) - } - if coordinated_dp - else {}, + } if coordinated_dp else {}, + tensor_queue_index=eng_index, ) ) handshake_socket.send_multipart((eng_identity, init_message), copy=False) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index a3c30e368b82..c4b7d2e9bea1 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -41,6 +41,35 @@ CUSTOM_TYPE_CLOUDPICKLE = 2 CUSTOM_TYPE_RAW_VIEW = 3 + +@dataclasses.dataclass +class TensorIpcData: + """ + Data sent via torch.multiprocessing.Queue 
for zero-copy IPC. + + Contains the tensor_id and the actual tensor. The tensor is shared + in GPU memory for efficient inter-process communication. + """ + tensor_id: str + tensor: torch.Tensor + + +@dataclasses.dataclass +class TensorIpcHandle: + """ + Handle for a tensor sent via IPC queue (zero-copy transfer). + + Contains only metadata about the tensor. This is serialized via msgpack + and used by the decoder to retrieve the actual tensor from the queue. + The actual tensor is sent separately via torch.multiprocessing.Queue + as TensorIpcData. + """ + tensor_id: str + shape: list[int] + dtype: str + device: str + + # MultiModalField class serialization type map. # These need to list all possible field types and match them # to factory methods in `MultiModalFieldConfig`. @@ -119,9 +148,16 @@ class MsgpackEncoder: By default, arrays below 256B are serialized inline Larger will get sent via dedicated messages. Note that this is a per-tensor limit. + + For CUDA tensors, when tensor_queues is provided, they will be sent via + torch.multiprocessing.Queue for zero-copy IPC instead of serialization. """ - def __init__(self, size_threshold: int | None = None): + def __init__( + self, + size_threshold: int | None = None, + tensor_queues: list[Any] | None = None, + ): if size_threshold is None: size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD self.encoder = msgpack.Encoder(enc_hook=self.enc_hook) @@ -130,9 +166,19 @@ def __init__(self, size_threshold: int | None = None): # pass custom data to the hook otherwise. 
self.aux_buffers: list[bytestr] | None = None self.size_threshold = size_threshold + # Tensor IPC queues for sharing CUDA tensors (one per engine core) + self.tensor_queues = tensor_queues + # Target engine index for routing tensors to the correct queue + self.target_engine_index: int | None = None + # Counter for generating unique tensor IDs + self._tensor_id_counter = 0 if envs.VLLM_ALLOW_INSECURE_SERIALIZATION: _log_insecure_serialization_warning() + def set_target_engine(self, engine_index: int | None) -> None: + """Set the target engine index for routing CUDA tensors to queues.""" + self.target_engine_index = engine_index + def encode(self, obj: Any) -> Sequence[bytestr]: try: self.aux_buffers = bufs = [b""] @@ -168,7 +214,7 @@ def enc_hook(self, obj: Any) -> Any: int(v) if v is not None else None for v in (obj.start, obj.stop, obj.step) ) - + if isinstance(obj, MultiModalKwargsItem): return self._encode_mm_item(obj) @@ -222,8 +268,60 @@ def _encode_ndarray( def _encode_tensor( self, obj: torch.Tensor - ) -> tuple[str, tuple[int, ...], int | memoryview]: + ) -> tuple[str, tuple[int, ...], int | memoryview] | dict[str, Any]: assert self.aux_buffers is not None + + # Check if this is a CUDA tensor and we have queues available + if ( + obj.is_cuda + and self.tensor_queues is not None + and self.target_engine_index is not None + ): + # Send CUDA tensor via torch.multiprocessing.Queue for zero-copy IPC + # Generate unique tensor ID + tensor_id = f"{id(self)}_{self._tensor_id_counter}" + self._tensor_id_counter += 1 + + try: + # NOTE(review): share_memory_() is a no-op for CUDA tensors; mp.Queue + # already shares CUDA tensors via CUDA IPC handles, so this is redundant + if not obj.is_shared(): + obj = obj.share_memory_() + + # Put TensorIpcData (tensor_id + tensor) into the target engine's queue + target_queue = self.tensor_queues[self.target_engine_index] + ipc_data = TensorIpcData(tensor_id=tensor_id, tensor=obj) + # Use a timeout to avoid blocking indefinitely + target_queue.put(ipc_data, 
timeout=10.0) + + logger.debug( + "Sent CUDA tensor %s (shape=%s, device=%s) to engine %d via queue (shared memory)", + tensor_id, + obj.shape, + obj.device, + self.target_engine_index, + ) + + return TensorIpcHandle(tensor_id=tensor_id, shape=list(obj.shape), dtype=str(obj.dtype).removeprefix("torch."), device=str(obj.device)) + except Exception as e: + logger.warning( + "Failed to send CUDA tensor via queue: %s. " + "Falling back to standard serialization.", + e, + ) + raise  # NOTE(review): re-raising contradicts the "falling back" warning above; + # drop this re-raise if the CPU-serialization fallback below is intended + + + # Fall back to standard serialization for CPU tensors or when queues unavailable + # For CUDA tensors without queue support, we need to move to CPU first + if obj.is_cuda: + logger.warning( + "CUDA tensor without queue support encountered. " + "Moving to CPU for serialization. This will be slow." + ) + obj = obj.cpu() + # view the tensor as a contiguous 1D array of bytes arr_data = tensor_data(obj) if obj.nbytes < self.size_threshold: @@ -281,9 +379,17 @@ class MsgpackDecoder: Note that unlike vanilla `msgspec` Decoders, this interface is generally not thread-safe when encoding tensors / numpy arrays. + + For CUDA tensors sent via torch.multiprocessing.Queue, they will be + retrieved from the queue during decoding. 
""" - def __init__(self, t: Any | None = None, share_mem: bool = True): + def __init__( + self, + t: Any | None = None, + share_mem: bool = True, + tensor_queue: Any | None = None, + ): self.share_mem = share_mem self.pin_tensors = is_pin_memory_available() args = () if t is None else (t,) @@ -291,6 +397,11 @@ def __init__(self, t: Any | None = None, share_mem: bool = True): *args, ext_hook=self.ext_hook, dec_hook=self.dec_hook ) self.aux_buffers: Sequence[bytestr] = () + # Tensor IPC queue for receiving CUDA tensors from API servers + self.tensor_queue = tensor_queue + # Buffer for temporarily storing tensors retrieved from queue + # that don't match the current request + self._tensor_buffer: dict[str, torch.Tensor] = {} if envs.VLLM_ALLOW_INSECURE_SERIALIZATION: _log_insecure_serialization_warning() @@ -309,6 +420,8 @@ def dec_hook(self, t: type, obj: Any) -> Any: if isclass(t): if issubclass(t, np.ndarray): return self._decode_ndarray(obj) + if issubclass(t, TensorIpcHandle): + return self._decode_cuda_queue_tensor(obj) if issubclass(t, torch.Tensor): return self._decode_tensor(obj) if t is slice: @@ -354,6 +467,7 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray: return arr.reshape(shape) def _decode_tensor(self, arr: Any) -> torch.Tensor: + # Standard tensor decoding dtype, shape, data = arr is_aux = isinstance(data, int) buffer = self.aux_buffers[data] if is_aux else data @@ -374,6 +488,18 @@ def _decode_tensor(self, arr: Any) -> torch.Tensor: arr = arr.pin_memory() if self.pin_tensors else arr.clone() # Convert back to proper shape & type return arr.view(torch_dtype).view(shape) + + def _decode_cuda_queue_tensor(self, handle: TensorIpcHandle) -> torch.Tensor: + """Retrieve a CUDA tensor from the torch.multiprocessing.Queue.""" + + # Drain all available tensors. We save them regardless if this is the one + # we're waiting for as they may arrive out of order from multiple producers. 
+ while handle.tensor_id not in self._tensor_buffer: + ipc_data: TensorIpcData = self.tensor_queue.get(timeout=10.0) + self._tensor_buffer[ipc_data.tensor_id] = ipc_data.tensor + + tensor = self._tensor_buffer.pop(handle.tensor_id) + return tensor def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems: return MultiModalKwargsItems( @@ -409,6 +535,14 @@ def _decode_nested_tensors(self, obj: Any) -> NestedTensors: # Although it violates NestedTensors type, MultiModalKwargs # values are sometimes floats. return obj + if isinstance(obj, TensorIpcHandle): + return self._decode_cuda_queue_tensor(obj) + # Check if this is a dict that represents a TensorIpcHandle + # (msgspec serializes dataclasses as dicts without type info in nested structures) + if isinstance(obj, dict) and 'tensor_id' in obj and 'shape' in obj and 'dtype' in obj and 'device' in obj: + # Convert dict to TensorIpcHandle and decode it + handle = TensorIpcHandle(**obj) + return self._decode_cuda_queue_tensor(handle) if not isinstance(obj, list): raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}") if obj and isinstance(obj[0], str): diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 29099d1e9b17..d4f1d6159cbb 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -172,6 +172,7 @@ def __init__( input_addresses: list[str], output_addresses: list[str], stats_update_address: str | None = None, + tensor_queues: list[Any] | None = None, ): """Initialize and start API server worker processes. 
@@ -184,6 +185,7 @@ def __init__( input_addresses: Input addresses for each API server output_addresses: Output addresses for each API server stats_update_address: Optional stats update address + tensor_queues: Optional tensor IPC queues for CUDA tensor sharing """ self.listen_address = listen_address self.sock = sock @@ -204,6 +206,8 @@ def __init__( } if stats_update_address is not None: client_config["stats_update_address"] = stats_update_address + if tensor_queues is not None: + client_config["tensor_queues"] = tensor_queues proc = spawn_context.Process( target=target_server_fn,