Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ setproctitle # Used to set process names for better debugging and monitoring
openai-harmony >= 0.0.3 # Required for gpt-oss
anthropic >= 0.71.0
model-hosting-container-standards >= 0.1.13, < 1.0.0
lmdb >= 1.7.5 # Used by LMDB-based multimodal cache
mcp
opentelemetry-sdk >= 1.27.0
opentelemetry-api >= 1.27.0
Expand Down
237 changes: 237 additions & 0 deletions tests/multimodal/test_lmdb_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import multiprocessing
import os
import signal
import tempfile
import time

import numpy as np
import pytest
import torch

from vllm.multimodal.cache import (
LmdbObjectStoreSenderCache,
LmdbObjectStoreWorkerReceiverCache,
)
from vllm.multimodal.inputs import (
MultiModalFieldElem,
MultiModalKwargsItem,
MultiModalSharedField,
)
from vllm.multimodal.lmdb_cache import LmdbMultiModalCache, ensure_lmdb_env
from vllm.multimodal.processing import PromptInsertion
from vllm.utils.mem_constants import GiB_bytes, MiB_bytes
from vllm.utils.system_utils import get_mp_context


def _dummy_elem(
    size: int,
    *,
    rng: np.random.RandomState | None = None,
):
    """Create a ``MultiModalFieldElem`` backed by a 1-D int8 tensor of ``size``.

    Without ``rng`` the tensor contents are uninitialized (only the size
    matters); with ``rng`` the contents are reproducible random values.
    """
    if rng is not None:
        values = rng.randint(4, size=(size,), dtype=np.int8)
        data = torch.from_numpy(values)
    else:
        data = torch.empty((size,), dtype=torch.int8)

    return MultiModalFieldElem(
        data=data,
        field=MultiModalSharedField(batch_size=1),
    )


def _dummy_item(
    size_by_key: dict[str, int],
    *,
    rng: np.random.RandomState | None = None,
):
    """Build a ``MultiModalKwargsItem`` with one dummy element per key."""
    elems = {}
    for key, size in size_by_key.items():
        elems[key] = _dummy_elem(size, rng=rng)
    return MultiModalKwargsItem(elems)


@pytest.fixture()
def lmdb_cache():
    """Yield an ``LmdbMultiModalCache`` rooted in a fresh temporary directory.

    ``min_eviction_age=-1`` presumably makes every entry immediately eligible
    for eviction, so the tests can trigger eviction deterministically.
    The directory (and the LMDB files inside it) is removed on teardown.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        cache = LmdbMultiModalCache(
            tmpdir,
            cache_size=GiB_bytes,
            min_eviction_age=-1,
            max_object_size=10 * MiB_bytes,
        )
        yield cache


def test_ensure_lmdb_env():
    """``ensure_lmdb_env`` must cache one environment per process.

    Calling it twice with the same path in the same process must return the
    identical object, while a forked child must get its own environment
    (LMDB environment handles should not be reused across a fork).
    """
    # The child below relies on inheriting ``tmpdir``/``env_1`` via fork;
    # skip on platforms (e.g. Windows) where fork is unavailable rather
    # than erroring out in multiprocessing.get_context().
    if "fork" not in multiprocessing.get_all_start_methods():
        pytest.skip("requires the 'fork' multiprocessing start method")

    with tempfile.TemporaryDirectory() as tmpdir:
        env_1 = ensure_lmdb_env(tmpdir)
        env_2 = ensure_lmdb_env(tmpdir)

        # Same process + same path -> cached, identical object.
        assert env_1 is env_2

        def _child():
            # The forked child must not reuse the parent's cached env.
            env_3 = ensure_lmdb_env(tmpdir)
            assert env_3 is not env_1

        p = multiprocessing.get_context("fork").Process(target=_child)
        p.start()
        p.join()
        # A non-zero exit code means the child's assertion failed.
        assert p.exitcode == 0


def test_lmdb_insert_get_evict(lmdb_cache: LmdbMultiModalCache):
    """End-to-end insert / read-back / evict cycle, run twice.

    The second iteration verifies that inserting again after a full
    eviction works correctly.
    """
    sender_cache = LmdbObjectStoreSenderCache(lmdb_cache)

    mm_hash = "fake_hash"
    item_chunks = 150.5
    item_size = int(lmdb_cache._max_chunk_size * item_chunks)
    dummy_item = _dummy_item({"key": item_size})
    prompt_update = PromptInsertion("dummy", "target", "insertion")

    for _ in range(2):
        with sender_cache.begin() as txn:
            assert not txn.is_cached_item(mm_hash)
            item, update_1 = txn.get_and_update_item(
                (dummy_item, prompt_update), mm_hash
            )
            # The sender side never hands the item back, only the update.
            assert item is None
            assert update_1 == prompt_update

        receiver_cache = LmdbObjectStoreWorkerReceiverCache(lmdb_cache)
        with receiver_cache.begin() as txn:
            assert txn.get_and_update_item(None, mm_hash) == dummy_item

        with sender_cache.begin() as txn:
            assert txn.is_cached_item(mm_hash)
            item, update_2 = txn.get_and_update_item(None, mm_hash)
            assert item is None
            assert update_2 == prompt_update

        evicted_items, _ = lmdb_cache.evict_once(min_utilization=0.0)
        # One record per data chunk plus one extra record (likely metadata).
        assert evicted_items == math.ceil(item_chunks) + 1


def test_lmdb_eviction_with_transaction_open(lmdb_cache: LmdbMultiModalCache):
    """Eviction must not disturb readers holding an open transaction.

    An open transaction keeps seeing its snapshot of the item even after
    ``evict_once`` removes it, while concurrently-opened transactions see
    the post-eviction state (MVCC-style snapshot isolation).
    """
    sender_cache = LmdbObjectStoreSenderCache(lmdb_cache)

    MM_HASH = "fake_hash"
    # Non-integer chunk count so the item spans a partial final chunk.
    ITEM_CHUNKS = 150.5
    dummy_item = _dummy_item({"key": int(lmdb_cache._max_chunk_size * ITEM_CHUNKS)})
    prompt_update = PromptInsertion("dummy", "target", "insertion")
    with sender_cache.begin() as txn:
        txn.get_and_update_item((dummy_item, prompt_update), MM_HASH)

    with sender_cache.begin() as txn:
        assert txn.is_cached_item(MM_HASH)
        # Evict everything while `txn` is still open.
        evicted, _ = lmdb_cache.evict_once(min_utilization=0.0)
        assert evicted == math.ceil(ITEM_CHUNKS) + 1

        # The item should be cached since the transaction is still open.
        assert txn.is_cached_item(MM_HASH)

        # But a separate transaction should not see the item.
        with sender_cache.begin() as other_txn:
            assert not other_txn.is_cached_item(MM_HASH)

        # The open transaction can still fully read its snapshot.
        should_be_none, update = txn.get_and_update_item(None, MM_HASH)
        assert should_be_none is None
        assert update == prompt_update

    # After the transaction commits, the item should be there.
    # (Reading via get_and_update_item presumably re-inserted it — confirm.)
    with sender_cache.begin() as txn:
        assert txn.is_cached_item(MM_HASH)
        should_be_none, update = txn.get_and_update_item(None, MM_HASH)
        assert should_be_none is None
        assert update == prompt_update

    # And the receiver cache should see the item as well.
    receiver_cache = LmdbObjectStoreWorkerReceiverCache(lmdb_cache)
    with receiver_cache.begin() as txn:
        retrieved_item = txn.get_and_update_item(None, MM_HASH)
        assert retrieved_item == dummy_item

        evicted, _ = lmdb_cache.evict_once(min_utilization=0.0)
        assert evicted == math.ceil(ITEM_CHUNKS) + 1

        # The item should be cached since the transaction is still open.
        retrieved_item = txn.get_and_update_item(None, MM_HASH)
        assert retrieved_item == dummy_item

    # But now it should be gone.
    with receiver_cache.begin() as txn, pytest.raises(ValueError):
        txn.get_and_update_item(None, MM_HASH)


def test_lmdb_concurrent_inserts(lmdb_cache: LmdbMultiModalCache):
    """Two overlapping transactions inserting the same hash must both succeed.

    Both inserters get the same prompt update back, the item is visible to
    later transactions, and eviction removes exactly one copy's worth of
    records (no double-counting).
    """
    sender_cache = LmdbObjectStoreSenderCache(lmdb_cache)

    item_chunks = 150.5
    mm_hash = "fake_hash"
    item_size = int(lmdb_cache._max_chunk_size * item_chunks)
    dummy_item = _dummy_item({"key": item_size})
    prompt_update = PromptInsertion("dummy", "target", "insertion")

    with sender_cache.begin() as txn:
        assert not txn.is_cached_item(mm_hash)
        item_1, update_1 = txn.get_and_update_item((dummy_item, prompt_update), mm_hash)
        assert item_1 is None

        # Second insert of the same hash while the first txn is still open.
        with sender_cache.begin() as other_txn:
            assert not other_txn.is_cached_item(mm_hash)
            item2, update_2 = other_txn.get_and_update_item(
                (dummy_item, prompt_update), mm_hash
            )
            assert item2 is None

        # Both transactions should return the same prompt update.
        assert update_1 == update_2

    # And the item should be present in a new transaction.
    with sender_cache.begin() as new_txn:
        assert new_txn.is_cached_item(mm_hash)

    # And the receiver cache should see the item as well.
    receiver_cache = LmdbObjectStoreWorkerReceiverCache(lmdb_cache)
    with receiver_cache.begin() as txn:
        assert txn.get_and_update_item(None, mm_hash) == dummy_item

    evicted_items, _ = lmdb_cache.evict_once(min_utilization=0.0)
    # Exactly one item's worth of records despite the duplicate insert.
    assert evicted_items == math.ceil(item_chunks) + 1

    # Now it should be gone.
    with sender_cache.begin() as txn:
        assert not txn.is_cached_item(mm_hash)

    # A second eviction pass finds nothing left to remove.
    evicted_items, _ = lmdb_cache.evict_once(min_utilization=0.0)
    assert evicted_items == 0


def test_lmdb_evictor_process(lmdb_cache: LmdbMultiModalCache):
    """The background evictor process evicts cached items when signalled.

    Fixes two issues in the original test: the evictor process was never
    terminated/joined (leaking a child process on both success and failure),
    and the poll loop allowed only ~0.5s total, which is flaky on slow CI.
    """
    event = get_mp_context().Event()

    MM_HASH = "fake_hash"
    ITEM_CHUNKS = 150.5

    with lmdb_cache.begin_write() as txn:
        dummy_item = _dummy_item({"key": int(lmdb_cache._max_chunk_size * ITEM_CHUNKS)})
        prompt_update = PromptInsertion("dummy", "target", "insertion")
        txn.get_and_update_item((dummy_item, prompt_update), MM_HASH)

    with lmdb_cache.begin_read() as txn:
        assert txn.is_cached_item(MM_HASH)

    evictor_process = lmdb_cache.start_evictor(event)
    try:
        # Wait until the evictor signals readiness, then prompt it to run
        # an eviction pass immediately instead of on its own schedule.
        event.wait()
        os.kill(evictor_process.pid, signal.SIGUSR1)

        # Poll until the item disappears, with a generous deadline.
        deadline = time.monotonic() + 5.0
        while True:
            assert evictor_process.is_alive()
            with lmdb_cache.begin_read() as txn:
                if not txn.is_cached_item(MM_HASH):
                    break
            if time.monotonic() >= deadline:
                raise AssertionError("Evictor process did not evict the item in time.")
            time.sleep(0.1)
    finally:
        # Always reap the background process so it cannot outlive the test.
        evictor_process.terminate()
        evictor_process.join(timeout=5)
8 changes: 8 additions & 0 deletions vllm/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,8 @@ class ModelConfig:
mm_processor_cache_gb: InitVar[float | None] = None
mm_processor_cache_type: InitVar[MMCacheType | None] = None
mm_shm_cache_max_object_size_mb: InitVar[int | None] = None
mm_lmdb_cache_max_object_size_mb: InitVar[int | None] = None
mm_lmdb_cache_min_eviction_age: InitVar[int | None] = None
mm_encoder_only: InitVar[bool | None] = None
mm_encoder_tp_mode: InitVar[MMEncoderTPMode | None] = None
mm_encoder_attn_backend: InitVar[AttentionBackendEnum | str | None] = None
Expand Down Expand Up @@ -371,6 +373,8 @@ def compute_hash(self) -> str:
"mm_processor_cache_gb",
"mm_processor_cache_type",
"mm_shm_cache_max_object_size_mb",
"mm_lmdb_cache_max_object_size_mb",
"mm_lmdb_cache_min_eviction_age",
"mm_encoder_tp_mode",
"interleave_mm_strings",
"skip_mm_profiling",
Expand Down Expand Up @@ -443,6 +447,8 @@ def __post_init__(
mm_processor_cache_gb: float | None,
mm_processor_cache_type: MMCacheType | None,
mm_shm_cache_max_object_size_mb: int | None,
mm_lmdb_cache_max_object_size_mb: int | None,
mm_lmdb_cache_min_eviction_age: int | None,
mm_encoder_only: bool | None,
mm_encoder_tp_mode: MMEncoderTPMode | None,
mm_encoder_attn_backend: AttentionBackendEnum | str | None,
Expand Down Expand Up @@ -637,6 +643,8 @@ def __post_init__(
mm_processor_cache_gb=mm_processor_cache_gb,
mm_processor_cache_type=mm_processor_cache_type,
mm_shm_cache_max_object_size_mb=mm_shm_cache_max_object_size_mb,
mm_lmdb_cache_max_object_size_mb=mm_lmdb_cache_max_object_size_mb,
mm_lmdb_cache_min_eviction_age=mm_lmdb_cache_min_eviction_age,
mm_encoder_only=mm_encoder_only,
mm_encoder_tp_mode=mm_encoder_tp_mode,
mm_encoder_attn_backend=mm_encoder_attn_backend,
Expand Down
35 changes: 34 additions & 1 deletion vllm/config/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class MultiModalDummyOptionsBuiltins(TypedDict, total=False):


MMEncoderTPMode = Literal["weights", "data"]
MMCacheType = Literal["shm", "lru"]
MMCacheType = Literal["shm", "lru", "lmdb"]
MMTensorIPC = Literal["direct_rpc", "torch_shm"]
MMDummyOptions: TypeAlias = dict[str, BaseDummyOptions]
"""
Expand Down Expand Up @@ -135,6 +135,14 @@ class MultiModalConfig:
"""Size limit (in MiB) for each object stored in the multi-modal processor
shared memory cache. Only effective when `mm_processor_cache_type` is
`"shm"`."""
mm_lmdb_cache_max_object_size_mb: int = Field(default=128, ge=0)
"""Size limit (in MiB) for each object stored in the multi-modal processor
shared memory cache. Only effective when `mm_processor_cache_type` is
`"lmdb"`."""
mm_lmdb_cache_min_eviction_age: int = Field(default=600, ge=0)
"""Minimum age (in seconds) before an object in the multi-modal processor
LMDB cache can be evicted. Only effective when `mm_processor_cache_type` is
`"lmdb"`. Must be long enough to cover the lifetime of a single request."""
mm_encoder_only: bool = False
"""
When enabled, skips the language component of the model.
Expand Down Expand Up @@ -233,6 +241,31 @@ def _validate_multimodal_config(self):
"'mm_shm_cache_max_object_size_mb' should only be set when "
"'mm_processor_cache_type' is 'shm'."
)

if self.mm_processor_cache_type != "lmdb":
if (
self.mm_lmdb_cache_max_object_size_mb
!= MultiModalConfig.mm_lmdb_cache_max_object_size_mb
):
raise ValueError(
"'mm_lmdb_cache_max_object_size_mb' should only be set when "
"'mm_processor_cache_type' is 'lmdb'."
)

if (
self.mm_lmdb_cache_min_eviction_age
!= MultiModalConfig.mm_lmdb_cache_min_eviction_age
):
raise ValueError(
"'mm_lmdb_cache_min_eviction_age' should only be set when "
"'mm_processor_cache_type' is 'lmdb'."
)
else:
# Ensure the LMDB cache ID environment variable is set
from vllm.multimodal.lmdb_cache import LmdbMultiModalCache

LmdbMultiModalCache.ensure_cache_id()

return self

def compute_hash(self) -> str:
Expand Down
16 changes: 16 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,12 @@ class EngineArgs:
mm_shm_cache_max_object_size_mb: int = (
MultiModalConfig.mm_shm_cache_max_object_size_mb
)
mm_lmdb_cache_max_object_size_mb: int = (
MultiModalConfig.mm_lmdb_cache_max_object_size_mb
)
mm_lmdb_cache_min_eviction_age: int = (
MultiModalConfig.mm_lmdb_cache_min_eviction_age
)
mm_encoder_only: bool = MultiModalConfig.mm_encoder_only
mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
mm_encoder_attn_backend: AttentionBackendEnum | str | None = (
Expand Down Expand Up @@ -1116,6 +1122,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"--mm-shm-cache-max-object-size-mb",
**multimodal_kwargs["mm_shm_cache_max_object_size_mb"],
)
multimodal_group.add_argument(
"--mm-lmdb-cache-max-object-size-mb",
**multimodal_kwargs["mm_lmdb_cache_max_object_size_mb"],
)
multimodal_group.add_argument(
"--mm-lmdb-cache-min-eviction-age",
**multimodal_kwargs["mm_lmdb_cache_min_eviction_age"],
)
multimodal_group.add_argument(
"--mm-encoder-only", **multimodal_kwargs["mm_encoder_only"]
)
Expand Down Expand Up @@ -1461,6 +1475,8 @@ def create_model_config(self) -> ModelConfig:
mm_processor_cache_gb=self.mm_processor_cache_gb,
mm_processor_cache_type=self.mm_processor_cache_type,
mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb,
mm_lmdb_cache_max_object_size_mb=self.mm_lmdb_cache_max_object_size_mb,
mm_lmdb_cache_min_eviction_age=self.mm_lmdb_cache_min_eviction_age,
mm_encoder_only=self.mm_encoder_only,
mm_encoder_tp_mode=self.mm_encoder_tp_mode,
mm_encoder_attn_backend=self.mm_encoder_attn_backend,
Expand Down
Loading
Loading