Commit ae40f84

Fix unit tests w/ special tokens
Signed-off-by: Chang Liu (Enterprise Products) <[email protected]>
1 parent 6bdbe68 commit ae40f84

File tree

1 file changed: +183 −14 lines


tests/unittest/_torch/multimodal/test_multimodal_runtime.py

Lines changed: 183 additions & 14 deletions
@@ -19,7 +19,8 @@ def test_fully_cached_multimodal_tokens(self):
             past_seen_token_num=20,
             mm_token_lengths=[5, 8, 7],  # Total: 20 tokens
             mm_token_positions=[0, 5, 13],  # Positions: 0-5, 5-13, 13-20
-            chunk_end_pos=20)
+            chunk_end_pos=20,
+            special_token_offsets=[])
 
         # All tokens should be cached since past_seen_token_num (20) >= all positions + lengths
         assert runtime.num_unseen_mm_tokens == 20
@@ -32,7 +33,8 @@ def test_no_cached_multimodal_tokens(self):
             mm_token_lengths=[5, 8, 7],  # Total: 20 tokens
             mm_token_positions=[10, 18,
                                 30],  # All positions > past_seen_token_num
-            chunk_end_pos=40)
+            chunk_end_pos=40,
+            special_token_offsets=[])
 
         # No multimodal tokens should be cached
         assert runtime.num_unseen_mm_tokens == 0
@@ -44,7 +46,8 @@ def test_partial_caching_with_chunk_boundaries(self):
             past_seen_token_num=15,
             mm_token_lengths=[5, 8, 7],  # Total: 20 tokens
             mm_token_positions=[10, 18, 25],  # Positions: 10-15, 18-26, 25-32
-            chunk_end_pos=30)
+            chunk_end_pos=30,
+            special_token_offsets=[])
 
         # Expected caching:
         # Chunk 0: [10-15] - 5 tokens fully cached, 0 tokens in current chunk
@@ -59,7 +62,8 @@ def test_chunk_boundary_case1(self):
             past_seen_token_num=12,
             mm_token_lengths=[6, 4, 8],  # Total: 18 tokens
             mm_token_positions=[8, 16, 22],  # Positions: 8-14, 16-20, 22-30
-            chunk_end_pos=20)
+            chunk_end_pos=20,
+            special_token_offsets=[])
 
         # Expected caching:
         # Chunk 0: [8-14] - 4 tokens cached (8-12), 2 tokens in current chunk (12-14)
@@ -76,7 +80,8 @@ def test_chunk_boundary_case2(self):
             mm_token_positions=[
                 0, 5, 10, 15, 25, 35
             ],  # Positions: 0-3, 5-9, 10-15, 15-21, 25-32, 35-43
-            chunk_end_pos=100)
+            chunk_end_pos=100,
+            special_token_offsets=[])
 
         expected_cached = 3 + 4 + 5 + 6 + 5  # 23 tokens
         expected_current_chunk = 2 + 8  # 10 tokens
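The expected values asserted in the caching tests above all follow the same interval arithmetic: each multimodal span [pos, pos + length) is split into the part covered by the KV cache, [0, past_seen_token_num), and the part inside the current prefill chunk, [past_seen_token_num, chunk_end_pos). A minimal sketch of that bookkeeping, with an illustrative helper name (not the actual MultimodalRuntimeData implementation):

def count_mm_tokens(past_seen_token_num: int, chunk_end_pos: int,
                    mm_token_positions: list, mm_token_lengths: list):
    # Illustrative sketch only; mirrors the arithmetic the tests assert.
    num_unseen = 0    # mm tokens already covered by the KV cache
    num_in_chunk = 0  # mm tokens inside the current prefill chunk
    for pos, length in zip(mm_token_positions, mm_token_lengths):
        end = pos + length
        # Overlap of [pos, end) with the cached prefix [0, past_seen_token_num)
        num_unseen += max(0, min(end, past_seen_token_num) - pos)
        # Overlap of [pos, end) with [past_seen_token_num, chunk_end_pos)
        num_in_chunk += max(
            0, min(end, chunk_end_pos) - max(pos, past_seen_token_num))
    return num_unseen, num_in_chunk

# test_chunk_boundary_case1: 4 tokens cached, 2 + 4 tokens in the chunk
assert count_mm_tokens(12, 20, [8, 16, 22], [6, 4, 8]) == (4, 6)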
@@ -94,44 +99,55 @@ def test_validation_errors(self):
             MultimodalRuntimeData(past_seen_token_num=10,
                                   mm_token_lengths=[5, 8, 7],
                                   mm_token_positions=[0, 5],
-                                  chunk_end_pos=20)
+                                  chunk_end_pos=20,
+                                  special_token_offsets=[])
 
         # Test negative past_seen_token_num
         with pytest.raises(ValueError,
                            match="past_seen_token_num must be non-negative"):
             MultimodalRuntimeData(past_seen_token_num=-1,
                                   mm_token_lengths=[5],
                                   mm_token_positions=[0],
-                                  chunk_end_pos=10)
+                                  chunk_end_pos=10,
+                                  special_token_offsets=[])
 
         # Test non-positive token lengths
         with pytest.raises(ValueError,
                            match="All mm_token_lengths must be positive"):
             MultimodalRuntimeData(past_seen_token_num=10,
                                   mm_token_lengths=[5, 0, 7],
                                   mm_token_positions=[0, 5, 10],
-                                  chunk_end_pos=20)
+                                  chunk_end_pos=20,
+                                  special_token_offsets=[])
 
         # Test negative positions
         with pytest.raises(ValueError,
                            match="All mm_token_positions must be non-negative"):
             MultimodalRuntimeData(past_seen_token_num=10,
                                   mm_token_lengths=[5, 8, 7],
                                   mm_token_positions=[0, -5, 10],
-                                  chunk_end_pos=20)
+                                  chunk_end_pos=20,
+                                  special_token_offsets=[])
 
 
 class TestFindInputMmEmbed:
     """Focused test cases for find_input_mm_embeds function - testing both KV cache reuse and chunked prefill."""
 
-    def create_mock_runtime(self, num_unseen_mm_tokens: int,
+    def create_mock_runtime(self,
+                            num_unseen_mm_tokens: int,
                             num_mm_tokens_in_chunk: int,
-                            mm_token_lengths: List[int]):
+                            mm_token_lengths: List[int],
+                            num_unseen_special_tokens: int = 0,
+                            num_special_tokens_in_chunk: int = 0,
+                            total_special_tokens_in_request: int = 0):
         """Helper to create a mock MultimodalRuntimeData."""
         runtime = Mock(spec=MultimodalRuntimeData)
         runtime.num_unseen_mm_tokens = num_unseen_mm_tokens
         runtime.num_mm_tokens_in_chunk = num_mm_tokens_in_chunk
         runtime.total_mm_tokens_in_request = sum(mm_token_lengths)
+        runtime.num_unseen_special_tokens = num_unseen_special_tokens
+        runtime.num_special_tokens_in_chunk = num_special_tokens_in_chunk
+        runtime.total_special_tokens_in_request = total_special_tokens_in_request
 
         return runtime
 
@@ -365,22 +381,68 @@ def test_different_devices(self):
         result = find_input_mm_embeds(mm_embeds, multimodal_params)
         assert result[0].device == mm_embeds[0].device
 
+    def test_special_tokens_in_batched_mode(self):
+        """Test special token handling in batched mode."""
+        mm_embeds = [torch.randn(12, 512)
+                     ]  # Pre-concatenated: (8-2) + (10-4) = 6 + 6 = 12 tokens
+        multimodal_params = [
+            self.create_mock_runtime(num_unseen_mm_tokens=2,
+                                     num_mm_tokens_in_chunk=6,
+                                     mm_token_lengths=[8],
+                                     num_unseen_special_tokens=1,
+                                     num_special_tokens_in_chunk=1,
+                                     total_special_tokens_in_request=2),
+            self.create_mock_runtime(num_unseen_mm_tokens=4,
+                                     num_mm_tokens_in_chunk=6,
+                                     mm_token_lengths=[10],
+                                     num_unseen_special_tokens=2,
+                                     num_special_tokens_in_chunk=2,
+                                     total_special_tokens_in_request=4)
+        ]
+        multimodal_params = [
+            MultimodalParams(multimodal_runtime=runtime)
+            for runtime in multimodal_params
+        ]
+
+        result = find_input_mm_embeds(mm_embeds, multimodal_params)
+
+        # Expected slices accounting for special tokens:
+        # Batch 1: local_start = 2-1=1, local_end = 1+(6-1)=6, slice [1:6] = 5 tokens
+        # Batch 2: local_start = 4-2=2, local_end = 2+(6-2)=6, slice [6+2:6+6] = [8:12] = 4 tokens
+        # Total: 5 + 4 = 9 tokens
+        assert len(result) == 1
+        assert result[0].shape == (9, 512)
+
+        # Verify the slices are correct
+        expected = torch.cat(
+            [
+                mm_embeds[0][1:6],  # Batch 1: 5 tokens
+                mm_embeds[0][8:12]  # Batch 2: 4 tokens
+            ],
+            dim=0)
+        torch.testing.assert_close(result[0], expected)
+
 
 class TestGetMultimodalEmbeddings:
     """Test cases for get_multimodal_embeddings function - testing caching and encoder forward optimization."""
 
-    def create_mock_runtime(self, total_mm_tokens: int):
-        """Helper to create a mock MultimodalRuntimeData with total_mm_tokens."""
+    def create_mock_runtime(self,
+                            total_mm_tokens: int,
+                            total_special_tokens: int = 0):
+        """Helper to create a mock MultimodalRuntimeData with total_mm_tokens and special_tokens."""
         runtime = Mock(spec=MultimodalRuntimeData)
         runtime.total_mm_tokens_in_request = total_mm_tokens
+        runtime.total_special_tokens_in_request = total_special_tokens
         return runtime
 
     def create_multimodal_params_with_data(self,
                                            has_cached_embedding: bool = False,
                                            total_mm_tokens: int = 10,
+                                           total_special_tokens: int = 0,
                                            cached_embedding=None):
         """Helper to create MultimodalParams with optional cached embeddings."""
-        runtime = self.create_mock_runtime(total_mm_tokens)
+        runtime = self.create_mock_runtime(total_mm_tokens,
+                                           total_special_tokens)
 
         multimodal_data = {
             # Add some dummy multimodal data to ensure has_content() returns True
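The slice comments in test_special_tokens_in_batched_mode above encode the key invariant: special tokens occupy text positions but have no rows in mm_embeds, so both the start offset and the chunk length shrink by the corresponding special-token counts. A sketch of that batched slicing, assuming only the runtime fields the mock sets (an illustration, not the actual find_input_mm_embeds):

import torch

def slice_batched_mm_embeds(mm_embeds, runtimes):
    # Illustrative sketch; special tokens are excluded from every count.
    slices, offset = [], 0
    for r in runtimes:
        local_start = r.num_unseen_mm_tokens - r.num_unseen_special_tokens
        local_end = local_start + (r.num_mm_tokens_in_chunk -
                                   r.num_special_tokens_in_chunk)
        slices.append(mm_embeds[offset + local_start:offset + local_end])
        # Advance by this request's embedding rows (special tokens excluded).
        offset += (r.total_mm_tokens_in_request -
                   r.total_special_tokens_in_request)
    return torch.cat(slices, dim=0)

For the two mocked requests this yields mm_embeds[0][1:6] and mm_embeds[0][8:12], the 5 + 4 = 9 rows the test asserts.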
@@ -663,6 +725,113 @@ def mock_encoder(params):
         assert multimodal_params[0].multimodal_data[
             "multimodal_embedding"].device.type == 'cuda'
 
+    def test_special_tokens_basic_caching(self):
+        """Test caching behavior with special tokens present."""
+
+        def mock_encoder(params):
+            # Return embeddings for non-special tokens only
+            # Total: (10-2) + (8-1) + (6-3) = 8 + 7 + 3 = 18 tokens
+            return [torch.randn(18, 512)]
+
+        multimodal_params = [
+            self.create_multimodal_params_with_data(
+                has_cached_embedding=False,
+                total_mm_tokens=10,
+                total_special_tokens=2),  # 8 actual embedding tokens
+            self.create_multimodal_params_with_data(
+                has_cached_embedding=False,
+                total_mm_tokens=8,
+                total_special_tokens=1),  # 7 actual embedding tokens
+            self.create_multimodal_params_with_data(
+                has_cached_embedding=False,
+                total_mm_tokens=6,
+                total_special_tokens=3)  # 3 actual embedding tokens
+        ]
+
+        result = get_multimodal_embeddings(mock_encoder, multimodal_params)
+
+        # Should return concatenated embeddings
+        assert len(result) == 1
+        assert result[0].shape == (18, 512)  # 8 + 7 + 3 = 18 tokens
+
+        # Check that embeddings were split correctly based on non-special token counts
+        assert multimodal_params[0].multimodal_data[
+            "multimodal_embedding"].shape == (8, 512)  # 10 - 2
+        assert multimodal_params[1].multimodal_data[
+            "multimodal_embedding"].shape == (7, 512)  # 8 - 1
+        assert multimodal_params[2].multimodal_data[
+            "multimodal_embedding"].shape == (3, 512)  # 6 - 3
+
+    def test_special_tokens_all_special(self):
+        """Test edge case where all tokens are special tokens."""
+
+        def mock_encoder(params):
+            # Should return empty tensor when no actual embedding tokens
+            return [torch.randn(0, 512)]
+
+        multimodal_params = [
+            self.create_multimodal_params_with_data(
+                has_cached_embedding=False,
+                total_mm_tokens=5,
+                total_special_tokens=5),  # All tokens are special
+            self.create_multimodal_params_with_data(
+                has_cached_embedding=False,
+                total_mm_tokens=3,
+                total_special_tokens=3)  # All tokens are special
+        ]
+
+        result = get_multimodal_embeddings(mock_encoder, multimodal_params)
+
+        # Should return empty embeddings
+        assert len(result) == 1
+        assert result[0].shape == (0, 512)
+
+        # Cached embeddings should also be empty
+        assert multimodal_params[0].multimodal_data[
+            "multimodal_embedding"].shape == (0, 512)
+        assert multimodal_params[1].multimodal_data[
+            "multimodal_embedding"].shape == (0, 512)
+
+    def test_special_tokens_mixed_with_cached(self):
+        """Test special tokens with mixed cached and uncached params."""
+        encoder_call_count = 0
+
+        def mock_encoder(params):
+            nonlocal encoder_call_count
+            encoder_call_count += 1
+            # Only process uncached param: 12 - 3 = 9 tokens
+            return [torch.randn(9, 512)]
+
+        # Mix: cached (with special tokens), uncached (with special tokens)
+        cached_emb = torch.randn(4, 512)  # 6 - 2 = 4 actual tokens
+        multimodal_params = [
+            self.create_multimodal_params_with_data(
+                has_cached_embedding=True,
+                total_mm_tokens=6,
+                total_special_tokens=2,
+                cached_embedding=cached_emb),
+            self.create_multimodal_params_with_data(
+                has_cached_embedding=False,
+                total_mm_tokens=12,
+                total_special_tokens=3)  # 9 actual embedding tokens
+        ]
+
+        result = get_multimodal_embeddings(mock_encoder, multimodal_params)
+
+        # Encoder should be called once for uncached param
+        assert encoder_call_count == 1
+
+        # Should return concatenated embeddings: 4 + 9 = 13 tokens
+        assert len(result) == 1
+        assert result[0].shape == (13, 512)
+
+        # Verify cached embedding is preserved and uncached is now cached
+        torch.testing.assert_close(
+            multimodal_params[0].multimodal_data["multimodal_embedding"],
+            cached_emb)
+        assert multimodal_params[1].multimodal_data[
+            "multimodal_embedding"].shape == (9, 512)
+
 
 if __name__ == "__main__":
     pytest.main([__file__])
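The three new caching tests pin down how the fused encoder output is redistributed across requests: each one receives total_mm_tokens_in_request - total_special_tokens_in_request rows. A sketch of that split, assuming the mocked runtime fields (illustrative, not the actual get_multimodal_embeddings internals):

import torch

def split_encoder_output(fused, runtimes):
    # Illustrative sketch; per-request row counts exclude special tokens.
    counts = [
        r.total_mm_tokens_in_request - r.total_special_tokens_in_request
        for r in runtimes
    ]
    return list(torch.split(fused, counts, dim=0))

With the first test's counts [8, 7, 3] this splits an (18, 512) tensor into the (8, 512), (7, 512), and (3, 512) cached embeddings the assertions check.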
