tests/distributed/test_shm_storage.py (5 additions, 7 deletions)

@@ -23,18 +23,16 @@
 )
 
 
-def _dummy_elem(modality: str, key: str, size: int):
+def _dummy_elem(size: int):
     return MultiModalFieldElem(
-        modality=modality,
-        key=key,
         data=torch.empty((size,), dtype=torch.int8),
         field=MultiModalSharedField(batch_size=1),
     )
 
 
-def _dummy_item(modality: str, size_by_key: dict[str, int]):
-    return MultiModalKwargsItem.from_elems(
-        [_dummy_elem(modality, key, size) for key, size in size_by_key.items()]
+def _dummy_item(size_by_key: dict[str, int]):
+    return MultiModalKwargsItem(
+        {key: _dummy_elem(size) for key, size in size_by_key.items()}
     )
 
 
@@ -61,7 +59,7 @@ def tearDown(self):
     def test_minimal_put_get_cycle(self):
         """Test basic put and get operations."""
         key = "test_key"
-        value = _dummy_item("text", {"field1": 10, "field2": 20})
+        value = _dummy_item({"field1": 10, "field2": 20})
 
         # Put operation
        address, monotonic_id = self.storage.put(key, value)
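Reviewer note: the core of this PR is that `MultiModalFieldElem` no longer carries its own `modality`/`key`; the field key now lives in the dict passed to `MultiModalKwargsItem`. A minimal sketch of the new construction pattern, assuming only the signatures shown in this diff (the import path, tensor size, and field name are illustrative):

```python
import torch
from vllm.multimodal.inputs import (
    MultiModalFieldElem,
    MultiModalKwargsItem,
    MultiModalSharedField,
)

# New style: the field name is the dict key; the elem holds only data + field.
elem = MultiModalFieldElem(
    data=torch.empty((16,), dtype=torch.int8),
    field=MultiModalSharedField(batch_size=1),
)
item = MultiModalKwargsItem({"field1": elem})

# Old style (removed by this PR):
#   MultiModalKwargsItem.from_elems(
#       [MultiModalFieldElem(modality="text", key="field1", data=..., field=...)]
#   )
```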
tests/models/multimodal/processing/test_tensor_schema.py (5 additions, 1 deletion)

@@ -119,7 +119,11 @@ def create_batched_mm_kwargs(
     )["mm_kwargs"].require_data()
 
     return group_mm_kwargs_by_modality(
-        [item for modality in supported_mm_limits for item in mm_kwargs[modality]]
+        [
+            (modality, item)
+            for modality in supported_mm_limits
+            for item in mm_kwargs[modality]
+        ]
     )
 
 
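Because an item no longer knows its own modality, `group_mm_kwargs_by_modality` is now fed explicit `(modality, item)` pairs instead of bare items. A sketch of the call shape under that assumption, with dummy items standing in for processed ones and an assumed import path:

```python
from vllm.multimodal.inputs import MultiModalKwargsItem
from vllm.multimodal.utils import group_mm_kwargs_by_modality  # import path assumed

# Each entry is an explicit (modality, item) pair, mirroring the
# comprehension in the diff above.
pairs = [
    ("image", MultiModalKwargsItem.dummy()),
    ("image", MultiModalKwargsItem.dummy()),
]
grouped = group_mm_kwargs_by_modality(pairs)
```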
tests/multimodal/test_cache.py (15 additions, 21 deletions)

@@ -36,8 +36,6 @@
 
 
 def _dummy_elem(
-    modality: str,
-    key: str,
     size: int,
     *,
     rng: np.random.RandomState | None = None,
@@ -48,21 +46,18 @@ def _dummy_elem(
     data = torch.from_numpy(rng.randint(4, size=(size,), dtype=np.int8))
 
     return MultiModalFieldElem(
-        modality=modality,
-        key=key,
         data=data,
         field=MultiModalSharedField(batch_size=1),
     )
 
 
 def _dummy_item(
-    modality: str,
     size_by_key: dict[str, int],
     *,
     rng: np.random.RandomState | None = None,
 ):
-    return MultiModalKwargsItem.from_elems(
-        [_dummy_elem(modality, key, size, rng=rng) for key, size in size_by_key.items()]
+    return MultiModalKwargsItem(
+        {key: _dummy_elem(size, rng=rng) for key, size in size_by_key.items()}
     )
 
 
@@ -71,19 +66,19 @@ def _dummy_items(
     *,
     rng: np.random.RandomState | None = None,
 ):
-    return MultiModalKwargsItems.from_seq(
-        [
-            _dummy_item(modality, size_by_key, rng=rng)
+    return MultiModalKwargsItems(
+        {
+            modality: [_dummy_item(size_by_key, rng=rng)]
             for modality, size_by_key in size_by_key_modality.items()
-        ]
+        }
     )
 
 
 @pytest.mark.parametrize(
     ("item", "expected_size"),
     [
-        (_dummy_item("a", {"a1": 100}), 100),
-        (_dummy_item("a", {"a1": 100, "a2": 110}), 210),
+        (_dummy_item({"a1": 100}), 100),
+        (_dummy_item({"a1": 100, "a2": 110}), 210),
         (_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460),  # noqa: E501
     ],
 )
@@ -143,7 +138,7 @@ def _compare_caches(
 
     rng = np.random.RandomState(seed)
     all_items = [
-        _dummy_item("item", {"key": item_size_gb}, rng=rng)
+        _dummy_item({"key": item_size_gb}, rng=rng)
         for _ in range(int(item_capacity / hit_rate))
     ]
     all_hashes = [
@@ -245,13 +240,13 @@ def _run_test_cache_eviction_lru(
         "image_C",
     ]
     request1_items = {
-        h: MultiModalKwargsItem.dummy(h, nbytes=2 * base_item_size)
+        h: MultiModalKwargsItem.dummy(nbytes=2 * base_item_size)
         for h in request1_hashes
     }
 
     request2_hashes = ["image_D", "image_E", "image_A", "image_C"]
     request2_items = {
-        h: MultiModalKwargsItem.dummy(h, nbytes=1 * base_item_size)
+        h: MultiModalKwargsItem.dummy(nbytes=1 * base_item_size)
         for h in request2_hashes
     }
 
@@ -356,15 +351,14 @@ def _run_test_cache_eviction_shm(
 ):
     request1_hashes = ["image_A", "image_B", "image_C"]
     request1_items = {
-        h: MultiModalKwargsItem.dummy(h, nbytes=5 * base_item_size)
-        for h in request1_hashes
+        h: MultiModalKwargsItem.dummy(5 * base_item_size) for h in request1_hashes
     }
     request1_items_p0_result = []
 
     request2_hashes = ["image_G", "image_A"]
     request2_items = {
         h: MultiModalKwargsItem.dummy(
-            h, nbytes=(5 if h in request1_hashes else 2) * base_item_size
+            (5 if h in request1_hashes else 2) * base_item_size
         )
         for h in request2_hashes
     }
@@ -373,7 +367,7 @@
     request3_hashes = ["image_G", "image_H", "image_I", "image_B"]
     request3_items = {
         h: MultiModalKwargsItem.dummy(
-            h, nbytes=(5 if h in request1_hashes else 2) * base_item_size
+            (5 if h in request1_hashes else 2) * base_item_size
         )
         for h in request3_hashes
     }
@@ -532,7 +526,7 @@ def test_processor_cache_shared_across_loras():
     lora_a_identifier = f"12345:{base_mm_hash}"
     lora_b_identifier = f"67890:{base_mm_hash}"
 
-    item_data = MultiModalKwargsItem.dummy("test_image", nbytes=1024)
+    item_data = MultiModalKwargsItem.dummy(1024)
 
     feature_lora_a = MultiModalFeatureSpec(
         data=item_data,
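Two API shifts recur throughout this file: `MultiModalKwargsItem.dummy()` takes `nbytes` (apparently now its first parameter, given the positional call above) rather than a modality, and `MultiModalKwargsItems` is built directly as a mapping from modality to a list of items. A sketch under those assumptions:

```python
from vllm.multimodal.inputs import MultiModalKwargsItem, MultiModalKwargsItems

# dummy() no longer takes a modality string; nbytes is the first parameter.
item = MultiModalKwargsItem.dummy(nbytes=1024)

# The modality grouping now lives on the container, not on each item.
items = MultiModalKwargsItems({"image": [item]})
```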
tests/v1/core/test_kv_cache_utils.py (1 addition, 1 deletion)

@@ -77,7 +77,7 @@ def make_request(
     for j, position in enumerate(mm_positions):
         identifier = mm_hashes[j] if mm_hashes else f"hash_{j}"
         mm_feature = MultiModalFeatureSpec(
-            data=MultiModalKwargsItem.dummy("dummy_m"),
+            data=MultiModalKwargsItem.dummy(),
             mm_position=position,
             identifier=identifier,
             modality="image",
tests/v1/core/test_prefix_caching.py (1 addition, 1 deletion)

@@ -68,7 +68,7 @@ def make_request(
     for j, position in enumerate(mm_positions):
         identifier = mm_hashes[j] if mm_hashes else f"hash_{j}"
         mm_feature = MultiModalFeatureSpec(
-            data=MultiModalKwargsItem.dummy("dummy_m"),
+            data=MultiModalKwargsItem.dummy(),
             mm_position=position,
             identifier=identifier,
             modality="image",
tests/v1/core/test_priority_scheduler_random.py (1 addition, 1 deletion)

@@ -56,7 +56,7 @@ def _create_random_request(
     for j, position in enumerate(mm_positions):
         identifier = f"{request_id}_hash_{j}"
         mm_feature = MultiModalFeatureSpec(
-            data=MultiModalKwargsItem.dummy("dummy_m"),
+            data=MultiModalKwargsItem.dummy(),
             mm_position=position,
             identifier=identifier,
             modality="image",
tests/v1/core/test_scheduler.py (1 addition, 1 deletion)

@@ -1707,7 +1707,7 @@ def create_requests_with_priority(
             # Unique dummy hash for each mm item
             identifier = f"hash{i}_{j}"
             mm_feature = MultiModalFeatureSpec(
-                data=MultiModalKwargsItem.dummy("dummy_m"),
+                data=MultiModalKwargsItem.dummy(),
                 mm_position=position,
                 identifier=identifier,
                 modality="image",
tests/v1/core/utils.py (1 addition, 1 deletion)

@@ -236,7 +236,7 @@ def create_requests(
             # Unique dummy hash for each mm item
             identifier = f"hash{i}_{j}"
             mm_feature = MultiModalFeatureSpec(
-                data=MultiModalKwargsItem.dummy("dummy_m"),
+                data=MultiModalKwargsItem.dummy(),
                 mm_position=position,
                 identifier=identifier,
                 modality="image",
tests/v1/streaming_input/test_gpu_model_runner_streaming.py (2 additions, 2 deletions)

@@ -131,7 +131,7 @@ def test_e2e_streaming_with_multimodal_features(mock_model_runner_with_input_bat
 
     # Step 1: Create initial request state with one multimodal feature
     mm_feature_1 = MultiModalFeatureSpec(
-        data=MultiModalKwargsItem.dummy("audio"),
+        data=MultiModalKwargsItem.dummy(),
         modality="audio",
         identifier="audio_1",
         mm_position=PlaceholderRange(offset=2, length=10),
@@ -158,7 +158,7 @@
     # The scheduler has already set prompt_token_ids to the full sequence
     # (original prompt + intermediate outputs + new prompt with new multimodal feature)
     mm_feature_2 = MultiModalFeatureSpec(
-        data=MultiModalKwargsItem.dummy("audio"),
+        data=MultiModalKwargsItem.dummy(),
         modality="audio",
         identifier="audio_2",
         mm_position=PlaceholderRange(offset=15, length=5),
tests/v1/streaming_input/test_scheduler_streaming.py (2 additions, 2 deletions)

@@ -174,7 +174,7 @@ def test_update_request_as_session_with_multimodal(self):
         scheduler = create_scheduler()
 
         mm_feature = MultiModalFeatureSpec(
-            data=MultiModalKwargsItem.dummy("audio"),
+            data=MultiModalKwargsItem.dummy(),
             modality="audio",
             identifier="",
             mm_position=PlaceholderRange(offset=1, length=1),
@@ -187,7 +187,7 @@
         session.num_computed_tokens = len(session.prompt_token_ids)
 
         mm_feature = MultiModalFeatureSpec(
-            data=MultiModalKwargsItem.dummy("audio"),
+            data=MultiModalKwargsItem.dummy(),
             modality="audio",
             identifier="",
             mm_position=PlaceholderRange(offset=2, length=1),
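All of the scheduler and model-runner test updates above follow one pattern: the modality string that used to be passed to `dummy()` is redundant because the surrounding `MultiModalFeatureSpec` already records it. A sketch of the resulting construction, using only the field names visible in these diffs (import path assumed):

```python
from vllm.multimodal.inputs import (
    MultiModalFeatureSpec,
    MultiModalKwargsItem,
    PlaceholderRange,
)

mm_feature = MultiModalFeatureSpec(
    data=MultiModalKwargsItem.dummy(),  # no modality argument anymore
    modality="audio",                   # the spec is the single source of modality
    identifier="audio_1",
    mm_position=PlaceholderRange(offset=2, length=10),
)
```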
tests/v1/test_serial_utils.py (9 additions, 14 deletions)

@@ -104,36 +104,31 @@ class MyRequest(msgspec.Struct):
 
 def test_multimodal_kwargs():
     e1 = MultiModalFieldElem(
-        "audio",
-        "a0",
         torch.zeros(1000, dtype=torch.bfloat16),
         MultiModalBatchedField(),
     )
     e2 = MultiModalFieldElem(
-        "video",
-        "v0",
         [torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
         MultiModalFlatField(
             slices=[[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]],
             dim=0,
         ),
     )
     e3 = MultiModalFieldElem(
-        "image",
-        "i0",
         torch.zeros(1000, dtype=torch.int32),
         MultiModalSharedField(batch_size=4),
     )
     e4 = MultiModalFieldElem(
-        "image",
-        "i1",
         torch.zeros(1000, dtype=torch.int32),
         MultiModalFlatField(slices=[slice(1, 2, 3), slice(4, 5, 6)], dim=2),
     )
-    audio = MultiModalKwargsItem.from_elems([e1])
-    video = MultiModalKwargsItem.from_elems([e2])
-    image = MultiModalKwargsItem.from_elems([e3, e4])
-    mm = MultiModalKwargsItems.from_seq([audio, video, image])
+    mm = MultiModalKwargsItems(
+        {
+            "audio": [MultiModalKwargsItem({"a0": e1})],
+            "video": [MultiModalKwargsItem({"v0": e2})],
+            "image": [MultiModalKwargsItem({"i0": e3, "i1": e4})],
+        }
+    )
 
     # pack mm kwargs into a mock request so that it can be decoded properly
     req = MyRequest([mm])
@@ -147,8 +142,8 @@ def test_multimodal_kwargs():
 
     total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
 
-    # expected total encoding length, should be 14395, +-20 for minor changes
-    assert 14375 <= total_len <= 14425
+    # expected total encoding length, should be 14319, +-20 for minor changes
+    assert 14300 <= total_len <= 14340
     decoded = decoder.decode(encoded).mm[0]
     assert isinstance(decoded, MultiModalKwargsItems)
 
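The expected encoding length drops from 14395 to 14319, presumably because each of the four elems no longer serializes its own modality and key strings; the field keys are encoded once as dict keys instead. For reference, a round-trip sketch in the style of this test, assuming `MsgpackEncoder`/`MsgpackDecoder` from `vllm.v1.serial_utils` behave as used here (encode returns a list of buffers):

```python
import msgspec
import torch
from vllm.multimodal.inputs import (
    MultiModalBatchedField,
    MultiModalFieldElem,
    MultiModalKwargsItem,
    MultiModalKwargsItems,
)
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder


class MyRequest(msgspec.Struct):
    mm: list[MultiModalKwargsItems]


elem = MultiModalFieldElem(torch.zeros(8, dtype=torch.int8), MultiModalBatchedField())
mm = MultiModalKwargsItems({"audio": [MultiModalKwargsItem({"a0": elem})]})

encoder = MsgpackEncoder()
decoder = MsgpackDecoder(MyRequest)

encoded = encoder.encode(MyRequest([mm]))  # a list of buffers
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)

decoded = decoder.decode(encoded).mm[0]
assert isinstance(decoded, MultiModalKwargsItems)
```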
vllm/multimodal/cache.py (14 additions, 17 deletions)

@@ -463,8 +463,8 @@ def __init__(self, vllm_config: "VllmConfig") -> None:
             ring_buffer=ring_buffer,
             serde_class=MsgpackSerde,
         )
-        # cache (prompt_updates, modality) for P0 only
-        self._p0_cache: dict[str, tuple[Sequence[ResolvedPromptUpdate], str]] = {}
+        # cache prompt_updates for P0 only
+        self._p0_cache: dict[str, Sequence[ResolvedPromptUpdate]] = {}
 
         self._hits = 0
         self._total = 0
@@ -495,23 +495,22 @@ def get_and_update_item(
                 self._total += 1
 
                 address, monotonic_id = self._shm_cache.get_cached(mm_hash)
-                prompt_updates, modality = self._p0_cache[mm_hash]
-                return self.address_as_item(address, monotonic_id, modality), prompt_updates
+                prompt_updates = self._p0_cache[mm_hash]
+                return self.address_as_item(address, monotonic_id), prompt_updates
 
         assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
+        item, prompt_updates = mm_item
 
         self._total += 1
 
         try:
-            address, monotonic_id = self._shm_cache.put(mm_hash, mm_item[0])
+            address, monotonic_id = self._shm_cache.put(mm_hash, item)
             # Try to remove dangling items if p0 cache is too large.
             if len(self._p0_cache) >= 2 * len(self._shm_cache.key_index):
                 self.remove_dangling_items()
-            self._p0_cache[mm_hash] = mm_item[1], mm_item[0].modality
-            address_item = self.address_as_item(
-                address, monotonic_id, mm_item[0].modality
-            )
-            return address_item, mm_item[1]
+
+            self._p0_cache[mm_hash] = prompt_updates
+            return self.address_as_item(address, monotonic_id), prompt_updates
         except (ValueError, MemoryError) as e:
             # put may fail if the object is too large or
             # the cache is full.
@@ -550,22 +549,20 @@ def remove_dangling_items(self) -> None:
                 del self._p0_cache[mm_hash]
 
     def address_as_item(
-        self, address: int, monotonic_id: int, modality: str
+        self,
+        address: int,
+        monotonic_id: int,
    ) -> MultiModalKwargsItem:
         addr_elem = MultiModalFieldElem(
-            modality=modality,
-            key="address",
             data=address,
             field=MultiModalBatchedField(),
         )
         id_elem = MultiModalFieldElem(
-            modality=modality,
-            key="monotonic_id",
             data=monotonic_id,
             field=MultiModalBatchedField(),
         )
-        mm_item = MultiModalKwargsItem.from_elems([addr_elem, id_elem])
-        return mm_item
+
+        return MultiModalKwargsItem({"address": addr_elem, "monotonic_id": id_elem})
 
 
 class BaseMultiModalReceiverCache(
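With modality gone from the placeholder, `address_as_item()` now returns an item keyed purely by `"address"` and `"monotonic_id"`; the modality travels separately (e.g. on `MultiModalFeatureSpec`). A hedged sketch of building and reading back such a placeholder, assuming `MultiModalKwargsItem` keeps a mapping interface from field key to `MultiModalFieldElem` (values illustrative):

```python
from vllm.multimodal.inputs import (
    MultiModalBatchedField,
    MultiModalFieldElem,
    MultiModalKwargsItem,
)

# Build a placeholder the way address_as_item() does after this PR.
placeholder = MultiModalKwargsItem(
    {
        "address": MultiModalFieldElem(data=123, field=MultiModalBatchedField()),
        "monotonic_id": MultiModalFieldElem(data=7, field=MultiModalBatchedField()),
    }
)

# The receiver side can unpack by field key; no modality is involved.
address = placeholder["address"].data
monotonic_id = placeholder["monotonic_id"].data
assert (address, monotonic_id) == (123, 7)
```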