diff --git a/requirements/common.txt b/requirements/common.txt index 05666c5d14b0..62af08d1ad3e 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -50,6 +50,7 @@ setproctitle # Used to set process names for better debugging and monitoring openai-harmony >= 0.0.3 # Required for gpt-oss anthropic >= 0.71.0 model-hosting-container-standards >= 0.1.13, < 1.0.0 +lmdb >= 1.7.5 # Used by LMDB-based multimodal cache mcp opentelemetry-sdk >= 1.27.0 opentelemetry-api >= 1.27.0 diff --git a/tests/multimodal/test_lmdb_cache.py b/tests/multimodal/test_lmdb_cache.py new file mode 100644 index 000000000000..11c5c58ed814 --- /dev/null +++ b/tests/multimodal/test_lmdb_cache.py @@ -0,0 +1,237 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +import multiprocessing +import os +import signal +import tempfile +import time + +import numpy as np +import pytest +import torch + +from vllm.multimodal.cache import ( + LmdbObjectStoreSenderCache, + LmdbObjectStoreWorkerReceiverCache, +) +from vllm.multimodal.inputs import ( + MultiModalFieldElem, + MultiModalKwargsItem, + MultiModalSharedField, +) +from vllm.multimodal.lmdb_cache import LmdbMultiModalCache, ensure_lmdb_env +from vllm.multimodal.processing import PromptInsertion +from vllm.utils.mem_constants import GiB_bytes, MiB_bytes +from vllm.utils.system_utils import get_mp_context + + +def _dummy_elem( + size: int, + *, + rng: np.random.RandomState | None = None, +): + if rng is None: + data = torch.empty((size,), dtype=torch.int8) + else: + data = torch.from_numpy(rng.randint(4, size=(size,), dtype=np.int8)) + + return MultiModalFieldElem( + data=data, + field=MultiModalSharedField(batch_size=1), + ) + + +def _dummy_item( + size_by_key: dict[str, int], + *, + rng: np.random.RandomState | None = None, +): + return MultiModalKwargsItem( + {key: _dummy_elem(size, rng=rng) for key, size in size_by_key.items()} + ) + + +@pytest.fixture() +def 
lmdb_cache(): + with tempfile.TemporaryDirectory() as tmpdir: + yield LmdbMultiModalCache( + tmpdir, + cache_size=GiB_bytes, + min_eviction_age=-1, + max_object_size=10 * MiB_bytes, + ) + + +def test_ensure_lmdb_env(): + with tempfile.TemporaryDirectory() as tmpdir: + env_1 = ensure_lmdb_env(tmpdir) + env_2 = ensure_lmdb_env(tmpdir) + + assert env_1 is env_2 + + def _child(): + env_3 = ensure_lmdb_env(tmpdir) + assert env_3 is not env_1 + + p = multiprocessing.get_context("fork").Process(target=_child) + p.start() + p.join() + assert p.exitcode == 0 + + +def test_lmdb_insert_get_evict(lmdb_cache: LmdbMultiModalCache): + sender_cache = LmdbObjectStoreSenderCache(lmdb_cache) + + MM_HASH = "fake_hash" + ITEM_CHUNKS = 150.5 + dummy_item = _dummy_item({"key": int(lmdb_cache._max_chunk_size * ITEM_CHUNKS)}) + prompt_update = PromptInsertion("dummy", "target", "insertion") + + # Do two rounds to ensure inserting after evicting works correctly. + for _ in range(2): + with sender_cache.begin() as txn: + assert not txn.is_cached_item(MM_HASH) + should_be_none, update_1 = txn.get_and_update_item( + (dummy_item, prompt_update), MM_HASH + ) + assert should_be_none is None + assert update_1 == prompt_update + + receiver_cache = LmdbObjectStoreWorkerReceiverCache(lmdb_cache) + with receiver_cache.begin() as txn: + retrieved_item = txn.get_and_update_item(None, MM_HASH) + assert retrieved_item == dummy_item + + with sender_cache.begin() as txn: + assert txn.is_cached_item(MM_HASH) + should_be_none, update_2 = txn.get_and_update_item(None, MM_HASH) + assert should_be_none is None + assert update_2 == prompt_update + + evicted_items, _ = lmdb_cache.evict_once(min_utilization=0.0) + assert evicted_items == math.ceil(ITEM_CHUNKS) + 1 + + +def test_lmdb_eviction_with_transaction_open(lmdb_cache: LmdbMultiModalCache): + sender_cache = LmdbObjectStoreSenderCache(lmdb_cache) + + MM_HASH = "fake_hash" + ITEM_CHUNKS = 150.5 + dummy_item = _dummy_item({"key": int(lmdb_cache._max_chunk_size * 
ITEM_CHUNKS)}) + prompt_update = PromptInsertion("dummy", "target", "insertion") + with sender_cache.begin() as txn: + txn.get_and_update_item((dummy_item, prompt_update), MM_HASH) + + with sender_cache.begin() as txn: + assert txn.is_cached_item(MM_HASH) + evicted, _ = lmdb_cache.evict_once(min_utilization=0.0) + assert evicted == math.ceil(ITEM_CHUNKS) + 1 + + # The item should be cached since the transaction is still open. + assert txn.is_cached_item(MM_HASH) + + # But a separate transaction should not see the item. + with sender_cache.begin() as other_txn: + assert not other_txn.is_cached_item(MM_HASH) + + should_be_none, update = txn.get_and_update_item(None, MM_HASH) + assert should_be_none is None + assert update == prompt_update + + # After the transaction commits, the item should be there. + with sender_cache.begin() as txn: + assert txn.is_cached_item(MM_HASH) + should_be_none, update = txn.get_and_update_item(None, MM_HASH) + assert should_be_none is None + assert update == prompt_update + + # And the receiver cache should see the item as well. + receiver_cache = LmdbObjectStoreWorkerReceiverCache(lmdb_cache) + with receiver_cache.begin() as txn: + retrieved_item = txn.get_and_update_item(None, MM_HASH) + assert retrieved_item == dummy_item + + evicted, _ = lmdb_cache.evict_once(min_utilization=0.0) + assert evicted == math.ceil(ITEM_CHUNKS) + 1 + + # The item should be cached since the transaction is still open. + retrieved_item = txn.get_and_update_item(None, MM_HASH) + assert retrieved_item == dummy_item + + # But now it should be gone. 
+ with receiver_cache.begin() as txn, pytest.raises(ValueError): + txn.get_and_update_item(None, MM_HASH) + + +def test_lmdb_concurrent_inserts(lmdb_cache: LmdbMultiModalCache): + sender_cache = LmdbObjectStoreSenderCache(lmdb_cache) + + ITEM_CHUNKS = 150.5 + MM_HASH = "fake_hash" + dummy_item = _dummy_item({"key": int(lmdb_cache._max_chunk_size * ITEM_CHUNKS)}) + prompt_update = PromptInsertion("dummy", "target", "insertion") + with sender_cache.begin() as txn: + assert not txn.is_cached_item(MM_HASH) + item_1, update_1 = txn.get_and_update_item((dummy_item, prompt_update), MM_HASH) + assert item_1 is None + + with sender_cache.begin() as other_txn: + assert not other_txn.is_cached_item(MM_HASH) + item2, update_2 = other_txn.get_and_update_item( + (dummy_item, prompt_update), MM_HASH + ) + assert item2 is None + + # Both transactions should return the same prompt update. + assert update_1 == update_2 + + # And the item should be present in a new transaction. + with sender_cache.begin() as new_txn: + assert new_txn.is_cached_item(MM_HASH) + + # And the receiver cache should see the item as well. + receiver_cache = LmdbObjectStoreWorkerReceiverCache(lmdb_cache) + with receiver_cache.begin() as txn: + retrieved_item = txn.get_and_update_item(None, MM_HASH) + assert retrieved_item == dummy_item + + evicted_items, _ = lmdb_cache.evict_once(min_utilization=0.0) + assert evicted_items == math.ceil(ITEM_CHUNKS) + 1 + + # Now it should be gone. 
+ with sender_cache.begin() as txn: + assert not txn.is_cached_item(MM_HASH) + + evicted_items, _ = lmdb_cache.evict_once(min_utilization=0.0) + assert evicted_items == 0 + + +def test_lmdb_evictor_process(lmdb_cache: LmdbMultiModalCache): + event = get_mp_context().Event() + + MM_HASH = "fake_hash" + ITEM_CHUNKS = 150.5 + + with lmdb_cache.begin_write() as txn: + dummy_item = _dummy_item({"key": int(lmdb_cache._max_chunk_size * ITEM_CHUNKS)}) + + prompt_update = PromptInsertion("dummy", "target", "insertion") + txn.get_and_update_item((dummy_item, prompt_update), MM_HASH) + + with lmdb_cache.begin_read() as txn: + assert txn.is_cached_item(MM_HASH) + + evictor_process = lmdb_cache.start_evictor(event) + event.wait() + os.kill(evictor_process.pid, signal.SIGUSR1) + + for _ in range(5): + assert evictor_process.is_alive() + with lmdb_cache.begin_read() as txn: + if not txn.is_cached_item(MM_HASH): + break + + time.sleep(0.1) + else: + raise AssertionError("Evictor process did not evict the item in time.") diff --git a/vllm/config/model.py b/vllm/config/model.py index cea2e56ae8bf..6832edd72397 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -322,6 +322,8 @@ class ModelConfig: mm_processor_cache_gb: InitVar[float | None] = None mm_processor_cache_type: InitVar[MMCacheType | None] = None mm_shm_cache_max_object_size_mb: InitVar[int | None] = None + mm_lmdb_cache_max_object_size_mb: InitVar[int | None] = None + mm_lmdb_cache_min_eviction_age: InitVar[int | None] = None mm_encoder_only: InitVar[bool | None] = None mm_encoder_tp_mode: InitVar[MMEncoderTPMode | None] = None mm_encoder_attn_backend: InitVar[AttentionBackendEnum | str | None] = None @@ -371,6 +373,8 @@ def compute_hash(self) -> str: "mm_processor_cache_gb", "mm_processor_cache_type", "mm_shm_cache_max_object_size_mb", + "mm_lmdb_cache_max_object_size_mb", + "mm_lmdb_cache_min_eviction_age", "mm_encoder_tp_mode", "interleave_mm_strings", "skip_mm_profiling", @@ -443,6 +447,8 @@ def 
__post_init__( mm_processor_cache_gb: float | None, mm_processor_cache_type: MMCacheType | None, mm_shm_cache_max_object_size_mb: int | None, + mm_lmdb_cache_max_object_size_mb: int | None, + mm_lmdb_cache_min_eviction_age: int | None, mm_encoder_only: bool | None, mm_encoder_tp_mode: MMEncoderTPMode | None, mm_encoder_attn_backend: AttentionBackendEnum | str | None, @@ -637,6 +643,8 @@ def __post_init__( mm_processor_cache_gb=mm_processor_cache_gb, mm_processor_cache_type=mm_processor_cache_type, mm_shm_cache_max_object_size_mb=mm_shm_cache_max_object_size_mb, + mm_lmdb_cache_max_object_size_mb=mm_lmdb_cache_max_object_size_mb, + mm_lmdb_cache_min_eviction_age=mm_lmdb_cache_min_eviction_age, mm_encoder_only=mm_encoder_only, mm_encoder_tp_mode=mm_encoder_tp_mode, mm_encoder_attn_backend=mm_encoder_attn_backend, diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index e66511c92ab2..6b13dc62e48b 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -58,7 +58,7 @@ class MultiModalDummyOptionsBuiltins(TypedDict, total=False): MMEncoderTPMode = Literal["weights", "data"] -MMCacheType = Literal["shm", "lru"] +MMCacheType = Literal["shm", "lru", "lmdb"] MMTensorIPC = Literal["direct_rpc", "torch_shm"] MMDummyOptions: TypeAlias = dict[str, BaseDummyOptions] """ @@ -135,6 +135,14 @@ class MultiModalConfig: """Size limit (in MiB) for each object stored in the multi-modal processor shared memory cache. Only effective when `mm_processor_cache_type` is `"shm"`.""" + mm_lmdb_cache_max_object_size_mb: int = Field(default=128, ge=0) + """Size limit (in MiB) for each object stored in the multi-modal processor + shared memory cache. Only effective when `mm_processor_cache_type` is + `"lmdb"`.""" + mm_lmdb_cache_min_eviction_age: int = Field(default=600, ge=0) + """Minimum age (in seconds) before an object in the multi-modal processor + LMDB cache can be evicted. Only effective when `mm_processor_cache_type` is + `"lmdb"`. 
Must be long enough to cover the lifetime of a single request.""" mm_encoder_only: bool = False """ When enabled, skips the language component of the model. @@ -233,6 +241,31 @@ def _validate_multimodal_config(self): "'mm_shm_cache_max_object_size_mb' should only be set when " "'mm_processor_cache_type' is 'shm'." ) + + if self.mm_processor_cache_type != "lmdb": + if ( + self.mm_lmdb_cache_max_object_size_mb + != MultiModalConfig.mm_lmdb_cache_max_object_size_mb + ): + raise ValueError( + "'mm_lmdb_cache_max_object_size_mb' should only be set when " + "'mm_processor_cache_type' is 'lmdb'." + ) + + if ( + self.mm_lmdb_cache_min_eviction_age + != MultiModalConfig.mm_lmdb_cache_min_eviction_age + ): + raise ValueError( + "'mm_lmdb_cache_min_eviction_age' should only be set when " + "'mm_processor_cache_type' is 'lmdb'." + ) + else: + # Ensure the LMDB cache ID environment variable is set + from vllm.multimodal.lmdb_cache import LmdbMultiModalCache + + LmdbMultiModalCache.ensure_cache_id() + return self def compute_hash(self) -> str: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 90e10c1cb590..3c8fa3690f5f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -505,6 +505,12 @@ class EngineArgs: mm_shm_cache_max_object_size_mb: int = ( MultiModalConfig.mm_shm_cache_max_object_size_mb ) + mm_lmdb_cache_max_object_size_mb: int = ( + MultiModalConfig.mm_lmdb_cache_max_object_size_mb + ) + mm_lmdb_cache_min_eviction_age: int = ( + MultiModalConfig.mm_lmdb_cache_min_eviction_age + ) mm_encoder_only: bool = MultiModalConfig.mm_encoder_only mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode mm_encoder_attn_backend: AttentionBackendEnum | str | None = ( @@ -1116,6 +1122,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--mm-shm-cache-max-object-size-mb", **multimodal_kwargs["mm_shm_cache_max_object_size_mb"], ) + multimodal_group.add_argument( + 
"--mm-lmdb-cache-max-object-size-mb", + **multimodal_kwargs["mm_lmdb_cache_max_object_size_mb"], + ) + multimodal_group.add_argument( + "--mm-lmdb-cache-min-eviction-age", + **multimodal_kwargs["mm_lmdb_cache_min_eviction_age"], + ) multimodal_group.add_argument( "--mm-encoder-only", **multimodal_kwargs["mm_encoder_only"] ) @@ -1461,6 +1475,8 @@ def create_model_config(self) -> ModelConfig: mm_processor_cache_gb=self.mm_processor_cache_gb, mm_processor_cache_type=self.mm_processor_cache_type, mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb, + mm_lmdb_cache_max_object_size_mb=self.mm_lmdb_cache_max_object_size_mb, + mm_lmdb_cache_min_eviction_age=self.mm_lmdb_cache_min_eviction_age, mm_encoder_only=self.mm_encoder_only, mm_encoder_tp_mode=self.mm_encoder_tp_mode, mm_encoder_attn_backend=self.mm_encoder_attn_backend, diff --git a/vllm/envs.py b/vllm/envs.py index c2f8ca8c5808..f346b4bd67b7 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -77,6 +77,7 @@ VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25 VLLM_VIDEO_LOADER_BACKEND: str = "opencv" VLLM_MEDIA_CONNECTOR: str = "http" + VLLM_MM_LMDB_CACHE_ID: str | None = None VLLM_MM_HASHER_ALGORITHM: str = "blake3" VLLM_TARGET_DEVICE: str = "cuda" VLLM_MAIN_CUDA_VERSION: str = "12.9" @@ -850,6 +851,9 @@ def _get_or_set_default() -> str: # imported at runtime. # If a non-existing backend is used, an AssertionError will be thrown. "VLLM_MEDIA_CONNECTOR": lambda: os.getenv("VLLM_MEDIA_CONNECTOR", "http"), + # The LMDB multimodal cache ID to use. + # N.B. This is automatically assigned on configuration load if not set. + "VLLM_MM_LMDB_CACHE_ID": lambda: os.getenv("VLLM_MM_LMDB_CACHE_ID"), # Hash algorithm for multimodal content hashing. 
# - "blake3": Default, fast cryptographic hash (not FIPS 140-3 compliant) # - "sha256": FIPS 140-3 compliant, widely supported @@ -1839,6 +1843,7 @@ def compile_factors() -> dict[str, object]: "VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME", "VLLM_ASSETS_CACHE", "VLLM_ASSETS_CACHE_MODEL_CLEAN", + "VLLM_MM_LMDB_CACHE_ID", "VLLM_WORKER_MULTIPROC_METHOD", "VLLM_ENABLE_V1_MULTIPROCESSING", "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py index c0df19d4f483..3fca56cbd815 100644 --- a/vllm/multimodal/cache.py +++ b/vllm/multimodal/cache.py @@ -1,9 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib import operator import sys from abc import ABC, abstractmethod from collections.abc import Mapping, Sequence +from contextlib import AbstractContextManager from multiprocessing.synchronize import Lock as LockType from typing import TYPE_CHECKING, Generic, TypeAlias, TypeVar, cast @@ -17,6 +19,11 @@ SingleWriterShmRingBuffer, ) from vllm.logger import init_logger +from vllm.multimodal.lmdb_cache import ( + LmdbMultiModalCache, + LmdbReadTransaction, + LmdbWriteTransaction, +) from vllm.utils.cache import CacheInfo, LRUCache from vllm.utils.jsontree import json_count_leaves, json_map_leaves, json_reduce_leaves from vllm.utils.mem_constants import GiB_bytes, MiB_bytes @@ -264,6 +271,10 @@ class BaseMultiModalProcessorCache( ): """The required interface for caches on P0.""" + def begin(self) -> "AbstractContextManager[BaseMultiModalProcessorCache]": + """Context manager for batch cache operations.""" + return contextlib.nullcontext(self) + @abstractmethod def is_cached_item(self, mm_hash: str) -> bool: """ @@ -570,6 +581,10 @@ class BaseMultiModalReceiverCache( ): """The required interface for caches on P1.""" + def begin(self) -> "AbstractContextManager[BaseMultiModalReceiverCache]": + """Context manager for batch cache operations.""" + return 
contextlib.nullcontext(self) + def get_and_update_features( self, mm_features: list["MultiModalFeatureSpec"], @@ -723,3 +738,198 @@ def touch_receiver_cache_item( @override def clear_cache(self) -> None: self._shm_cache.clear() + + +class _LmdbSenderTransaction(BaseMultiModalProcessorCache, AbstractContextManager): + def __init__(self, parent: "LmdbObjectStoreSenderCache", txn: LmdbWriteTransaction): + self._parent = parent + self._txn = txn + + def __exit__(self, exc_type, exc_value, traceback): + self._txn.__exit__(exc_type, exc_value, traceback) + + @override + def is_cached_item(self, mm_hash: str) -> bool: + return self._txn.is_cached_item(mm_hash) + + @override + def get_and_update_item( + self, mm_item: MultiModalProcessorCacheInItem, mm_hash: str + ) -> MultiModalProcessorCacheOutItem: + if mm_item is None: + self._parent._hits += 1 + self._parent._total += 1 + return self._txn.get_and_update_item(mm_item, mm_hash) + + @override + def clear_cache(self) -> None: + # Handled by the engine-side cache. + pass + + @override + def touch_sender_cache_item(self, mm_hash: str) -> None: + # Not required. + pass + + @override + def make_stats(self, *, delta=False): + return self._parent.make_stats(delta=delta) + + +class LmdbObjectStoreSenderCache(BaseMultiModalProcessorCache): + """ + The P0 cache which writes items to the LMDB-backed cache. 
+ """ + + def __init__( + self, + cache: LmdbMultiModalCache, + ) -> None: + super().__init__() + self._cache = cache + self._hits = 0 + self._total = 0 + self._last_info = CacheInfo(hits=0, total=0) + + @override + def make_stats(self, *, delta=False): + info = CacheInfo(hits=self._hits, total=self._total) + + if delta: + info_delta = info - self._last_info + self._last_info = info + info = info_delta + + return info + + @override + def begin(self) -> AbstractContextManager[BaseMultiModalProcessorCache]: + return _LmdbSenderTransaction(self, self._cache.begin_write()) + + @override + def is_cached_item(self, mm_hash: str) -> bool: + raise ValueError("Requires a transaction.") + + @override + def clear_cache(self) -> None: + # Handled by the engine-side cache. + pass + + @override + def get_and_update_item( + self, mm_item: MultiModalProcessorCacheInItem, mm_hash: str + ) -> MultiModalProcessorCacheOutItem: + raise ValueError("Requires a transaction.") + + @override + def touch_sender_cache_item(self, mm_hash: str) -> None: + raise ValueError("Requires a transaction.") + + +class LmdbObjectStoreEngineReceiverCache(BaseMultiModalReceiverCache): + """ + A dummy cache for the engine process that owns the evictor process + and file lock. + """ + + def __init__(self, cache: LmdbMultiModalCache) -> None: + super().__init__() + + self._cache = cache + self._file_lock = self._cache.lock_and_clear_stale_caches() + if self._file_lock is not None: + self._evictor = self._cache.start_evictor() + + @override + def clear_cache(self): + # If we don't have the file lock, another engine process owns the cache. 
+ if self._file_lock is not None: + self._cache.clear() + + @override + def get_and_update_features( + self, + mm_features: list["MultiModalFeatureSpec"], + ) -> list["MultiModalFeatureSpec"]: + return mm_features + + @override + def touch_receiver_cache_item( + self, + mm_hash: str, + mm_item: MultiModalKwargsItem | None = None, + ) -> None: + raise NotImplementedError() + + @override + def get_and_update_item( + self, + mm_item: MultiModalKwargsItem | None, + mm_hash: str, + ) -> MultiModalKwargsItem: + raise NotImplementedError() + + +class _LmdbWorkerTransaction(BaseMultiModalReceiverCache, AbstractContextManager): + def __init__(self, txn: LmdbReadTransaction): + self._txn = txn + + def __exit__(self, exc_type, exc_value, traceback): + self._txn.__exit__(exc_type, exc_value, traceback) + + @override + def get_and_update_item( + self, + mm_item: MultiModalKwargsItem | None, + mm_hash: str, + ) -> MultiModalKwargsItem: + return self._txn.get_item(mm_hash) if mm_item is None else mm_item + + @override + def touch_receiver_cache_item( + self, + mm_hash: str, + mm_item: MultiModalKwargsItem | None = None, + ) -> None: + # Not necessary; nothing can be evicted under a transaction. + pass + + @override + def clear_cache(self) -> None: + # Handled by the engine-side cache. + pass + + +class LmdbObjectStoreWorkerReceiverCache(BaseMultiModalReceiverCache): + """ + The worker process cache which reads items from the LMDB cache. 
+ """ + + def __init__(self, cache: LmdbMultiModalCache) -> None: + super().__init__() + self._cache = cache + + @override + def begin(self): + return _LmdbWorkerTransaction(self._cache.begin_read()) + + @override + def get_and_update_item( + self, + mm_item: MultiModalKwargsItem | None, + mm_hash: str, + ) -> MultiModalKwargsItem: + raise ValueError("Requires a transaction.") + + @override + def touch_receiver_cache_item( + self, + mm_hash: str, + mm_item: MultiModalKwargsItem | None = None, + ) -> None: + raise ValueError("Requires a transaction.") + + @override + def clear_cache(self) -> None: + # Handled by the engine-side cache. + pass diff --git a/vllm/multimodal/lmdb_cache.py b/vllm/multimodal/lmdb_cache.py new file mode 100644 index 000000000000..510a3dbd75c9 --- /dev/null +++ b/vllm/multimodal/lmdb_cache.py @@ -0,0 +1,704 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import multiprocessing +import os +import shutil +import signal +import threading +import time +from collections.abc import Sequence +from contextlib import AbstractContextManager, contextmanager +from typing import TYPE_CHECKING, cast +from uuid import uuid4 + +import filelock +import lmdb +from typing_extensions import Self + +from vllm import envs +from vllm.distributed.device_communicators.shm_object_storage import MsgpackSerde +from vllm.logger import init_logger +from vllm.utils.mem_constants import GiB_bytes, MiB_bytes +from vllm.utils.system_utils import decorate_logs, get_mp_context, set_process_title + +from .inputs import MultiModalKwargsItem +from .processing.processor import ResolvedPromptUpdate + +if TYPE_CHECKING: + from multiprocessing.synchronize import Event + + from vllm.config import VllmConfig + + from .cache import MultiModalProcessorCacheInItem, MultiModalProcessorCacheOutItem + +logger = init_logger(__name__) + +OPEN_ENVS_LOCK = threading.Lock() +OPEN_ENVS = dict[str, lmdb.Environment]() 
+REGISTERED_FORK_HANDLER = False + + +def _on_fork(): + with OPEN_ENVS_LOCK: + for env in OPEN_ENVS.values(): + env.close() + OPEN_ENVS.clear() + + +def ensure_lmdb_env(path: str, **kwargs) -> lmdb.Environment: + """Opens or reuses an LMDB environment.""" + global REGISTERED_FORK_HANDLER + + with OPEN_ENVS_LOCK: + # A given LMDB environment can only be opened once per process, + # and not across forks. + if existing_env := OPEN_ENVS.get(path): + logger.debug("Reusing existing LMDB environment at %s", path) + return existing_env + else: + lmdb_env = lmdb.Environment( + path=path, + **kwargs, + ) + + OPEN_ENVS[path] = lmdb_env + if not REGISTERED_FORK_HANDLER: + os.register_at_fork(after_in_child=_on_fork) + REGISTERED_FORK_HANDLER = True + return lmdb_env + + +class LmdbMultiModalCache: + """LMDB-based multi-modal processor cache.""" + + CACHES_DIR = os.path.join(envs.VLLM_CACHE_ROOT, "mm_caches") + + DB_TIMESTAMPS_AND_HASHES = b"timestamps_and_hashes" + DB_HASH_TO_TIMESTAMP = b"hash_to_timestamp" + DB_HASH_TO_OBJECT = b"hash_to_object" + DB_HASH_TO_PROMPT_UPDATES = b"hash_to_prompt_updates" + + MAP_SIZE_MULTIPLIER = 2 + MINIMUM_MAP_SIZE = GiB_bytes + + INT_SIZE = 4 # Used for both timestamps (seconds) and chunk indices + LMDB_PAGE_HEADER_SIZE = 16 + LMDB_PAGE_ID_SIZE = 8 + HASH_ITEM_KEY = "lmdb_mm_hash" + + EVICTOR_READER_CHECK_INTERVAL = 60.0 # seconds + EVICTOR_BATCH_SIZE_FRACTION = 0.6 # Relative to the max page ids per page + EVICTOR_MIN_UTILIZATION = 0.5 # Start evicting at 50% utilization + EVICTOR_MAX_UTILIZATION = 1.0 # Reach maximum duty cycle at 100% utilization + EVICTOR_MAX_INTERVAL = 15.0 # seconds + EVICTOR_MAX_DUTY_CYCLE = 0.1 # 10% + + def __init__( + self, + cache_dir: str, + cache_size: int, + min_eviction_age: int, + max_object_size: int, + ) -> None: + super().__init__() + os.makedirs(cache_dir, exist_ok=True) + + self._cache_dir = cache_dir + self._cache_size = cache_size + self._min_eviction_age = min_eviction_age + self._max_object_size = 
max_object_size + # LMDB can require additional space beyond the maximum amoun of data + # we're storing, so allocate extra space. (Once the map size is + # exhausted, even eviction is liable to fail.) + map_size = max( + self._cache_size * self.MAP_SIZE_MULTIPLIER, self.MINIMUM_MAP_SIZE + ) + self.lmdb_env = ensure_lmdb_env( + cache_dir, map_size=map_size, max_dbs=4, writemap=True, map_async=True + ) + + # Large objects are stored in chunks to fit within LMDB page size limits and + # avoid pathological free list fragmentation. + overflow_page_data_size = ( + cast(int, self.lmdb_env.stat()["psize"]) - self.LMDB_PAGE_HEADER_SIZE + ) + self._max_chunk_size = overflow_page_data_size + + # The maximum number of LMDB page IDs that can fit in a single overflow page. + # Used to try to keep the freed page IDs in a transaction within a single page. + self._max_page_ids_per_page = ( + overflow_page_data_size // self.LMDB_PAGE_ID_SIZE - 1 + ) + + # (timestamp, hash) => () + self.timestamps_and_hashes = self.lmdb_env.open_db( + self.DB_TIMESTAMPS_AND_HASHES + ) + + # hash => timestamp + self.hash_to_timestamp = self.lmdb_env.open_db(self.DB_HASH_TO_TIMESTAMP) + + # (hash, chunk_index) => MultiModalKwargsItem + self.hash_to_object = self.lmdb_env.open_db(self.DB_HASH_TO_OBJECT) + + # (hash, chunk_index) => prompt_updates + self.hash_to_prompt_updates = self.lmdb_env.open_db( + self.DB_HASH_TO_PROMPT_UPDATES + ) + + self._serde = MsgpackSerde() + self._scratch_buffers: list[bytearray] = [] + + @contextmanager + def scratch_buffer(self): + if self._scratch_buffers: + buffer = self._scratch_buffers.pop() + else: + buffer = bytearray(self._max_object_size) + try: + with memoryview(buffer) as mv: + yield mv + finally: + self._scratch_buffers.append(buffer) + + @classmethod + def ensure_cache_id(cls): + if envs.VLLM_MM_LMDB_CACHE_ID is None: + os.environ["VLLM_MM_LMDB_CACHE_ID"] = uuid4().hex + assert envs.VLLM_MM_LMDB_CACHE_ID is not None, "Cache ID must be set now." 
+ + @classmethod + def from_vllm_config(cls, vllm_config: "VllmConfig") -> Self: + # The cache ID must be set by now. + cache_id = envs.VLLM_MM_LMDB_CACHE_ID + assert cache_id is not None, "LMDB cache ID must be set." + + mm_config = vllm_config.model_config.get_multimodal_config() + return cls( + cache_dir=os.path.join(cls.CACHES_DIR, cache_id), + cache_size=int( + mm_config.mm_processor_cache_gb * GiB_bytes, + ), + min_eviction_age=(mm_config.mm_lmdb_cache_min_eviction_age), + max_object_size=(mm_config.mm_lmdb_cache_max_object_size_mb * MiB_bytes), + ) + + def utilization(self, txn: lmdb.Transaction) -> float: + database_size = 0 + for db in ( + self.timestamps_and_hashes, + self.hash_to_timestamp, + self.hash_to_object, + self.hash_to_prompt_updates, + ): + stats = txn.stat(db=db) + database_size += stats["psize"] * ( + stats["branch_pages"] + stats["leaf_pages"] + stats["overflow_pages"] + ) + return database_size / self._cache_size + + def int2bytes(self, value: int) -> bytes: + # Use big-endian to ensure lexicographical ordering + return value.to_bytes(self.INT_SIZE, byteorder="big", signed=False) + + def bytes2int(self, value: bytes) -> int: + # Use big-endian to ensure lexicographical ordering + assert len(value) == self.INT_SIZE + return int.from_bytes(value, byteorder="big", signed=False) + + def serialize(self, item: object) -> bytearray: + value, value_size, metadata, md_size = self._serde.serialize(item) + + if value_size + md_size > self._max_object_size: + raise ValueError( + f"Object size {value_size} exceeds maximum allowed " + f"size of {self._max_object_size} bytes." 
+ ) + + buf = bytearray(value_size + md_size) + buf[0:md_size] = metadata + idx = md_size + for chunk in value if isinstance(value, list) else [value]: + chunk_size = len(chunk) + buf[idx : idx + chunk_size] = chunk + idx += chunk_size + return buf + + def deserialize(self, item: memoryview) -> object: + return self._serde.deserialize(item) + + def get_chunked_object( + self, + db: lmdb._Database, + txn: lmdb.Transaction, + key: bytes, + buffer: memoryview, + ) -> memoryview: + with txn.cursor(db=db) as cursor: + if not cursor.set_key(key + self.int2bytes(0)): + raise ValueError(f"Key {key!r} not found in LMDB cache.") + + chunk_index = 0 + offset = 0 + while cursor.key()[-self.INT_SIZE :] == self.int2bytes(chunk_index): + chunk_value = cursor.value() + buffer[offset : offset + len(chunk_value)] = chunk_value + offset += len(chunk_value) + chunk_index += 1 + if not cursor.next(): + break + + return buffer[0:offset] + + def put_chunked_object( + self, db: lmdb._Database, txn: lmdb.Transaction, key: bytes, value: memoryview + ) -> None: + with txn.cursor(db=db) as cursor: + cursor.putmulti( + ( + ( + key + self.int2bytes(i), + value[offset : offset + self._max_chunk_size], + ) + for i, offset in enumerate( + range(0, len(value), self._max_chunk_size) + ) + ) + ) + + def delete_chunked_object( + self, db: lmdb._Database, txn: lmdb.Transaction, key: bytes + ) -> int: + """ + Deletes a chunked object from the given LMDB database and returns the number of + chunks deleted. 
+ """ + + with txn.cursor(db=db) as cursor: + if not cursor.set_key(key + self.int2bytes(0)): + return 0 + + chunk_index = 0 + while cursor.key()[-self.INT_SIZE :] == self.int2bytes(chunk_index): + chunk_index += 1 + cursor.delete() # Deletes and advances the cursor + + return chunk_index + + def begin_write(self) -> "LmdbWriteTransaction": + return LmdbWriteTransaction(self) + + def begin_read(self): + return LmdbReadTransaction(self) + + def lock_and_clear_stale_caches(self) -> filelock.UnixFileLock | None: + cache_id = os.path.basename(self._cache_dir) + lock = filelock.UnixFileLock(self._cache_dir + ".lock") + try: + lock.acquire(blocking=False) + except filelock.Timeout: + # Another process is using this cache. + logger.debug( + "LMDB cache %s is currently in use by another process.", self._cache_dir + ) + return None + + # Clean up any existing caches that are not locked. + for entry in os.scandir(self.CACHES_DIR): + if ( + not entry.is_file() + or not entry.name.endswith(".lock") + or entry.name == f"{cache_id}.lock" + ): + continue + + other_cache = os.path.join(self.CACHES_DIR, entry.name[: -len(".lock")]) + try: + with filelock.FileLock(entry.path, blocking=False): + if os.path.exists(other_cache): + logger.info("Cleaning up stale cache at %s", other_cache) + try: + shutil.rmtree(other_cache) + except Exception as e: + logger.error( + "Failed to remove stale cache at %s: %s", other_cache, e + ) + try: + os.unlink(entry.path) + except Exception as e: + logger.error( + "Failed to remove stale cache lock file at %s: %s", + entry.path, + e, + ) + + except filelock.Timeout: + # Another process is using the cache. 
+ logger.debug( + "Cache %s is currently in use by another process.", other_cache + ) + + return lock + + def start_evictor( + self, maybe_event: "Event | None" = None + ) -> multiprocessing.Process: + evictor = get_mp_context().Process( + name="LMDBEvictor", + target=self._evictor_main, + args=( + maybe_event, + self._cache_dir, + self._cache_size, + self._min_eviction_age, + self._max_object_size, + ), + daemon=True, + ) + evictor.start() + return evictor + + @classmethod + def _evictor_main(cls, maybe_event: "Event | None", *cache_init_args): + set_process_title("LMDBEvictor") + decorate_logs() + + cache = cls(*cache_init_args) + next_reader_check = 0.0 + + got_signal = False + wait_for_signal = threading.Semaphore(0) + signal.signal(signal.SIGUSR1, lambda signum, frame: wait_for_signal.release()) + + if maybe_event: + maybe_event.set() + + while True: + if os.getppid() == 1: + # Parent process has exited. + break + + if got_signal or time.monotonic() >= next_reader_check: + stale_readers = cache.lmdb_env.reader_check() + if stale_readers > 0: + logger.warning("Removed %d stale LMDB readers.", stale_readers) + next_reader_check = time.monotonic() + cls.EVICTOR_READER_CHECK_INTERVAL + + if got_signal: + items, delay = cache.evict_once(min_utilization=0.0) + logger.info( + "Forced eviction removed %d items from LMDB cache. " + "Next eviction in %.2f seconds.", + items, + delay, + ) + else: + _, delay = cache.evict_once() + + got_signal = wait_for_signal.acquire(timeout=delay) + + def evict_once( + self, + batch_size: int = 0, + min_utilization: float = EVICTOR_MIN_UTILIZATION, + max_utilization: float = EVICTOR_MAX_UTILIZATION, + max_interval: float = EVICTOR_MAX_INTERVAL, + max_duty_cycle: float = EVICTOR_MAX_DUTY_CYCLE, + ) -> tuple[int, float]: + """Evict items from the cache.""" + + # By default, try to keep all the evicted pages within a single overflow page. 
def evict_once(
    self,
    batch_size: int = 0,
    min_utilization: float = EVICTOR_MIN_UTILIZATION,
    max_utilization: float = EVICTOR_MAX_UTILIZATION,
    max_interval: float = EVICTOR_MAX_INTERVAL,
    max_duty_cycle: float = EVICTOR_MAX_DUTY_CYCLE,
) -> tuple[int, float]:
    """Evict one batch of the oldest items from the cache.

    Returns ``(evicted_items, delay)``, where ``delay`` is how long the
    caller should wait before the next eviction pass.
    """
    # By default, try to keep all the evicted pages within a single overflow
    # page. (Since this is the lower bound of the batch size, leave margin
    # for items as well as any DB pages the transaction touches.)
    batch_size = batch_size or int(
        self._max_page_ids_per_page * self.EVICTOR_BATCH_SIZE_FRACTION
    )

    with self.lmdb_env.begin(write=False) as txn:
        utilization = self.utilization(txn)

    evicted_items = 0
    delay = max_interval

    if utilization >= min_utilization:
        current_timestamp = int(time.time())

        with self.lmdb_env.begin(write=True) as txn:
            evict_start = time.perf_counter()

            with txn.cursor(db=self.timestamps_and_hashes) as cursor:
                # The (timestamp, hash) index is ordered by timestamp, so
                # walking from the first key visits oldest items first.
                if cursor.first():
                    while evicted_items < batch_size:
                        combined_key = cursor.key()
                        if not combined_key:
                            # No more items to evict.
                            break

                        timestamp = self.bytes2int(combined_key[0 : self.INT_SIZE])

                        # Don't evict any items newer than the minimum
                        # eviction age.
                        if current_timestamp - timestamp < self._min_eviction_age:
                            break

                        # Delete the item from all dbs.
                        mm_hash_key = combined_key[self.INT_SIZE :]
                        evicted_items += self.delete_chunked_object(
                            self.hash_to_object, txn, mm_hash_key
                        )
                        evicted_items += self.delete_chunked_object(
                            self.hash_to_prompt_updates, txn, mm_hash_key
                        )
                        txn.delete(mm_hash_key, db=self.hash_to_timestamp)
                        cursor.delete()

            # NOTE(review): measured before commit (with-exit), so commit
            # time is not part of the duty-cycle math — confirm intended.
            eviction_duration = time.perf_counter() - evict_start

        # Calculate how long to wait before the next eviction, approaching
        # the maximum duty cycle as we approach the maximum utilization.
        # FIX: clamp the denominator so a (mis)configured
        # max_utilization == min_utilization cannot raise ZeroDivisionError.
        utilization_span = max(1e-9, max_utilization - min_utilization)
        adjusted_utilization = max(
            0.0,
            min(1.0, (utilization - min_utilization) / utilization_span),
        )

        # Square the utilization to have a more gradual increase in duty cycle.
        duty_cycle = max(
            1e-6, adjusted_utilization * adjusted_utilization * max_duty_cycle
        )

        # Ensure eviction duration is at least 1ms to avoid extremely short
        # delays.
        clamped_batch_duration = max(0.001, eviction_duration)

        # Example: If it took 2ms to evict one batch and our duty cycle is
        # 1%, wait
        #
        #   2ms / 1% - 2ms = 198ms
        #
        # before evicting another batch.
        delay = (
            min(
                max_interval,
                clamped_batch_duration / duty_cycle - clamped_batch_duration,
            )
            if evicted_items > 0
            else max_interval
        )

        logger.debug(
            "LMDB cache utilization is %.2f. "
            "Evicted %d items in %.3fs. "
            "Next eviction in %.3fs.",
            utilization,
            evicted_items,
            eviction_duration,
            delay,
        )
    else:
        logger.debug(
            "LMDB cache utilization is %.2f (< %.2f)", utilization, min_utilization
        )

    return evicted_items, delay

def clear(self) -> None:
    """Empty all four databases in one transaction (keeps env and dbs)."""
    with self.lmdb_env.begin(write=True) as txn:
        txn.drop(self.timestamps_and_hashes, delete=False)
        txn.drop(self.hash_to_timestamp, delete=False)
        txn.drop(self.hash_to_object, delete=False)
        txn.drop(self.hash_to_prompt_updates, delete=False)
def get_and_update_item(
    self,
    mm_item: "MultiModalProcessorCacheInItem",
    mm_hash: str,
) -> "MultiModalProcessorCacheOutItem":
    """Look up ``mm_hash``, queueing a touch or an insert for commit time.

    ``mm_item is None`` means the caller already knows the item is cached:
    return the cached prompt updates and queue a timestamp refresh.
    Otherwise queue the serialized item for insertion (unless the cache is
    full or the object is too large, in which case pass it through).
    """
    mm_hash_key = mm_hash.encode("utf-8")

    if mm_item is None:
        # Item is cached, so just update the timestamps for the hash.
        self._write_queue.append((mm_hash_key, None))

        with self._cache.scratch_buffer() as buffer:
            cached_prompt_updates = self._cache.get_chunked_object(
                self._cache.hash_to_prompt_updates,
                self._read_txn,
                mm_hash_key,
                buffer,
            )
            # Deserialize while ``buffer`` is still alive — the chunked
            # read may reference it.
            return None, cast(
                Sequence[ResolvedPromptUpdate],
                self._cache.deserialize(cached_prompt_updates),
            )

    if not self.inserts_enabled:
        # Cache is too full, do not cache new items.
        return mm_item

    # Item is not cached, serialize it and add it to the write queue.
    try:
        serialized_object = self._cache.serialize(mm_item[0])
        serialized_prompt_updates = self._cache.serialize(mm_item[1])
    except ValueError:
        # Object is too large to cache.
        return mm_item

    self._write_queue.append(
        (mm_hash_key, (serialized_object, serialized_prompt_updates))
    )
    return None, mm_item[1]

def _process_writes(self):
    """Flush all queued touches/inserts in a single LMDB write transaction."""
    if not self._write_queue:
        # Nothing to write.
        return

    with (
        self._cache.lmdb_env.begin(write=True) as write_txn,
        write_txn.cursor(
            db=self._cache.hash_to_timestamp
        ) as hash_to_timestamp_cursor,
    ):
        current_timestamp_bytes = self._cache.int2bytes(int(time.time()))
        for mm_hash_key, serialized_pair in self._write_queue:
            if hash_to_timestamp_cursor.set_key(mm_hash_key):
                # The item is already in the cache: delete the old
                # (timestamp, hash) index entry so the refreshed timestamp
                # written below is the only one.
                existing_timestamp = hash_to_timestamp_cursor.value()
                write_txn.delete(
                    existing_timestamp + mm_hash_key,
                    db=self._cache.timestamps_and_hashes,
                )
            elif serialized_pair is None:
                # The item was cached when queued but has been evicted in
                # the meantime; re-insert it from our read snapshot.
                # FIX: keep ALL four get/put calls inside the scratch-buffer
                # context — the chunked reads hand back views into ``buffer``,
                # so using it after the context exits is a use-after-release.
                with self._cache.scratch_buffer() as buffer:
                    serialized_mm_item = self._cache.get_chunked_object(
                        self._cache.hash_to_object,
                        self._read_txn,
                        mm_hash_key,
                        buffer,
                    )
                    self._cache.put_chunked_object(
                        self._cache.hash_to_object,
                        write_txn,
                        mm_hash_key,
                        serialized_mm_item,
                    )

                    serialized_prompt_updates = self._cache.get_chunked_object(
                        self._cache.hash_to_prompt_updates,
                        self._read_txn,
                        mm_hash_key,
                        buffer,
                    )
                    self._cache.put_chunked_object(
                        self._cache.hash_to_prompt_updates,
                        write_txn,
                        mm_hash_key,
                        serialized_prompt_updates,
                    )
            else:
                # Fresh insert: write both chunked objects.
                with memoryview(serialized_pair[0]) as mv:
                    self._cache.put_chunked_object(
                        self._cache.hash_to_object, write_txn, mm_hash_key, mv
                    )

                with memoryview(serialized_pair[1]) as mv:
                    self._cache.put_chunked_object(
                        self._cache.hash_to_prompt_updates,
                        write_txn,
                        mm_hash_key,
                        mv,
                    )

            # Now update the timestamp entries.
            hash_to_timestamp_cursor.put(mm_hash_key, current_timestamp_bytes)
            write_txn.put(
                current_timestamp_bytes + mm_hash_key,
                b"",
                db=self._cache.timestamps_and_hashes,
            )


class LmdbReadTransaction(AbstractContextManager):
    """Read-only cache view with a lazily-created scratch buffer."""

    def __init__(
        self,
        cache: "LmdbMultiModalCache",
    ) -> None:
        self._cache = cache
        # Buffered read txn: gets return zero-copy views into the LMDB map.
        self._txn = self._cache.lmdb_env.begin(write=False, buffers=True)
        self._scratch_buffer: "memoryview | None" = None
        self._scratch_buffer_ctx: "AbstractContextManager[memoryview] | None" = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # FIX: release the scratch buffer even if aborting the transaction
        # raises, and reset the ctx so a double __exit__ cannot re-release.
        try:
            self._txn.abort()
        finally:
            if self._scratch_buffer_ctx is not None:
                self._scratch_buffer_ctx.__exit__(exc_type, exc_value, traceback)
                self._scratch_buffer_ctx = None
                self._scratch_buffer = None

    @property
    def scratch_buffer(self) -> memoryview:
        # Created on first use; released in __exit__.
        if self._scratch_buffer is None:
            self._scratch_buffer_ctx = self._cache.scratch_buffer()
            self._scratch_buffer = self._scratch_buffer_ctx.__enter__()
        return self._scratch_buffer
def is_cached_item(self, mm_hash: str) -> bool:
    """Report whether ``mm_hash`` has an entry in the timestamp index."""
    key = mm_hash.encode("utf-8")
    timestamp = self._txn.get(key, db=self._cache.hash_to_timestamp)
    return timestamp is not None

def get_item(self, mm_hash: str) -> "MultiModalKwargsItem":
    """Fetch and deserialize the cached kwargs item for ``mm_hash``."""
    key = mm_hash.encode("utf-8")
    serialized = self._cache.get_chunked_object(
        self._cache.hash_to_object,
        self._txn,
        key,
        self.scratch_buffer,
    )
    return cast("MultiModalKwargsItem", self._cache.deserialize(serialized))
multimodal items + with timing_ctx.record("apply_hf_processor"): + ( + prompt_ids, + mm_missing_processed_data, + is_update_applied, + ) = self._apply_hf_processor_main( + prompt=inputs.prompt, + mm_items=mm_missing_data_items, + hf_processor_mm_kwargs=inputs.hf_processor_mm_kwargs, + tokenization_kwargs=inputs.tokenization_kwargs, + enable_hf_prompt_update=False, + ) + + mm_missing_kwargs = MultiModalKwargsItems.from_hf_inputs( mm_missing_processed_data, - is_update_applied, - ) = self._apply_hf_processor_main( - prompt=inputs.prompt, - mm_items=mm_missing_data_items, - hf_processor_mm_kwargs=inputs.hf_processor_mm_kwargs, - tokenization_kwargs=inputs.tokenization_kwargs, - enable_hf_prompt_update=False, + self._get_mm_fields_config( + mm_missing_processed_data, inputs.hf_processor_mm_kwargs + ), ) - mm_missing_kwargs = MultiModalKwargsItems.from_hf_inputs( - mm_missing_processed_data, - self._get_mm_fields_config( - mm_missing_processed_data, inputs.hf_processor_mm_kwargs - ), - ) - - mm_missing_prompt_updates = self._get_mm_prompt_updates( - mm_missing_data_items, - inputs.hf_processor_mm_kwargs, - mm_missing_kwargs, - ) - - with timing_ctx.record("merge_mm_kwargs"): - mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs( - cache, - mm_hashes=mm_hashes, - mm_is_cached=mm_is_cached, - mm_missing_kwargs=mm_missing_kwargs, - mm_missing_prompt_updates=mm_missing_prompt_updates, + mm_missing_prompt_updates = self._get_mm_prompt_updates( + mm_missing_data_items, + inputs.hf_processor_mm_kwargs, + mm_missing_kwargs, ) + with timing_ctx.record("merge_mm_kwargs"): + mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs( + cache, + mm_hashes=mm_hashes, + mm_is_cached=mm_is_cached, + mm_missing_kwargs=mm_missing_kwargs, + mm_missing_prompt_updates=mm_missing_prompt_updates, + ) + mm_info = MultiModalProcessingInfo( kwargs=mm_kwargs, hashes=mm_hashes, diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index fa414a5928d6..a829e375e5db 100644 --- 
a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -14,12 +14,16 @@ from .cache import ( BaseMultiModalProcessorCache, BaseMultiModalReceiverCache, + LmdbObjectStoreEngineReceiverCache, + LmdbObjectStoreSenderCache, + LmdbObjectStoreWorkerReceiverCache, MultiModalProcessorOnlyCache, MultiModalProcessorSenderCache, MultiModalReceiverCache, ShmObjectStoreReceiverCache, ShmObjectStoreSenderCache, ) +from .lmdb_cache import LmdbMultiModalCache from .processing import ( BaseDummyInputsBuilder, BaseMultiModalProcessor, @@ -252,7 +256,7 @@ def get_dummy_mm_inputs( def _get_cache_type( self, vllm_config: "VllmConfig", - ) -> Literal[None, "processor_only", "lru", "shm"]: + ) -> Literal[None, "processor_only", "lru", "shm", "lmdb"]: model_config = vllm_config.model_config if not self.supports_multimodal_inputs(model_config): return None @@ -264,10 +268,19 @@ def _get_cache_type( # Check if IPC caching is supported. parallel_config = vllm_config.parallel_config - is_ipc_supported = parallel_config._api_process_count == 1 and ( - parallel_config.data_parallel_size == 1 - or parallel_config.data_parallel_external_lb - ) + if mm_config.mm_processor_cache_type == "lmdb": + # LMDB cache is supported as long as all local + # frontends are only communicating with local + # backends. 
+ is_ipc_supported = ( + parallel_config.data_parallel_size_local + == parallel_config.data_parallel_size + ) or parallel_config.data_parallel_external_lb + else: + is_ipc_supported = parallel_config._api_process_count == 1 and ( + parallel_config.data_parallel_size == 1 + or parallel_config.data_parallel_external_lb + ) if not is_ipc_supported: return "processor_only" @@ -289,6 +302,10 @@ def processor_cache_from_config( return MultiModalProcessorSenderCache(vllm_config.model_config) elif cache_type == "shm": return ShmObjectStoreSenderCache(vllm_config) + elif cache_type == "lmdb": + return LmdbObjectStoreSenderCache( + LmdbMultiModalCache.from_vllm_config(vllm_config) + ) else: raise ValueError(f"Unknown cache type: {cache_type!r}") @@ -313,6 +330,10 @@ def engine_receiver_cache_from_config( return None elif cache_type == "lru": return MultiModalReceiverCache(vllm_config.model_config) + elif cache_type == "lmdb": + return LmdbObjectStoreEngineReceiverCache( + LmdbMultiModalCache.from_vllm_config(vllm_config) + ) else: raise ValueError(f"Unknown cache type: {cache_type!r}") @@ -327,6 +348,10 @@ def worker_receiver_cache_from_config( return None elif cache_type == "shm": return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock) + elif cache_type == "lmdb": + return LmdbObjectStoreWorkerReceiverCache( + LmdbMultiModalCache.from_vllm_config(vllm_config) + ) else: raise ValueError(f"Unknown cache type: {cache_type!r}") diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 204c8bd0e411..4a4613433df6 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -198,9 +198,10 @@ def enc_hook(self, obj: Any) -> Any: if isinstance(obj, slice): # We are assuming only int-based values will be used here. 
- return tuple( - int(v) if v is not None else None - for v in (obj.start, obj.stop, obj.step) + return ( + int(obj.start) if obj.start is not None else None, + int(obj.stop) if obj.stop is not None else None, + int(obj.step) if obj.step is not None else None, ) if isinstance(obj, MultiModalKwargsItem): diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index 6a80e3f37058..01fe5de77dbc 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -319,10 +319,12 @@ def _apply_mm_cache(self, scheduler_output: SchedulerOutput) -> None: if mm_cache is None: return - for req_data in scheduler_output.scheduled_new_reqs: - req_data.mm_features = mm_cache.get_and_update_features( - req_data.mm_features - ) + if scheduler_output.scheduled_new_reqs: + with mm_cache.begin() as txn: + for req_data in scheduler_output.scheduled_new_reqs: + req_data.mm_features = txn.get_and_update_features( + req_data.mm_features + ) def execute_model( self, scheduler_output: SchedulerOutput