@@ -19,7 +19,8 @@ def test_fully_cached_multimodal_tokens(self):
             past_seen_token_num=20,
             mm_token_lengths=[5, 8, 7],  # Total: 20 tokens
             mm_token_positions=[0, 5, 13],  # Positions: 0-5, 5-13, 13-20
-            chunk_end_pos=20)
+            chunk_end_pos=20,
+            special_token_offsets=[])
 
         # All tokens should be cached since past_seen_token_num (20) >= all positions + lengths
         assert runtime.num_unseen_mm_tokens == 20
@@ -32,7 +33,8 @@ def test_no_cached_multimodal_tokens(self):
             mm_token_lengths=[5, 8, 7],  # Total: 20 tokens
             mm_token_positions=[10, 18,
                                 30],  # All positions > past_seen_token_num
-            chunk_end_pos=40)
+            chunk_end_pos=40,
+            special_token_offsets=[])
 
         # No multimodal tokens should be cached
         assert runtime.num_unseen_mm_tokens == 0
@@ -44,7 +46,8 @@ def test_partial_caching_with_chunk_boundaries(self):
             past_seen_token_num=15,
             mm_token_lengths=[5, 8, 7],  # Total: 20 tokens
             mm_token_positions=[10, 18, 25],  # Positions: 10-15, 18-26, 25-32
-            chunk_end_pos=30)
+            chunk_end_pos=30,
+            special_token_offsets=[])
 
         # Expected caching:
         # Chunk 0: [10-15] - 5 tokens fully cached, 0 tokens in current chunk
@@ -59,7 +62,8 @@ def test_chunk_boundary_case1(self):
             past_seen_token_num=12,
             mm_token_lengths=[6, 4, 8],  # Total: 18 tokens
             mm_token_positions=[8, 16, 22],  # Positions: 8-14, 16-20, 22-30
-            chunk_end_pos=20)
+            chunk_end_pos=20,
+            special_token_offsets=[])
 
         # Expected caching:
         # Chunk 0: [8-14] - 4 tokens cached (8-12), 2 tokens in current chunk (12-14)
@@ -76,7 +80,8 @@ def test_chunk_boundary_case2(self):
             mm_token_positions=[
                 0, 5, 10, 15, 25, 35
             ],  # Positions: 0-3, 5-9, 10-15, 15-21, 25-32, 35-43
-            chunk_end_pos=100)
+            chunk_end_pos=100,
+            special_token_offsets=[])
 
         expected_cached = 3 + 4 + 5 + 6 + 5  # 23 tokens
         expected_current_chunk = 2 + 8  # 10 tokens
@@ -94,44 +99,55 @@ def test_validation_errors(self):
             MultimodalRuntimeData(past_seen_token_num=10,
                                   mm_token_lengths=[5, 8, 7],
                                   mm_token_positions=[0, 5],
-                                  chunk_end_pos=20)
+                                  chunk_end_pos=20,
+                                  special_token_offsets=[])
 
         # Test negative past_seen_token_num
         with pytest.raises(ValueError,
                            match="past_seen_token_num must be non-negative"):
             MultimodalRuntimeData(past_seen_token_num=-1,
                                   mm_token_lengths=[5],
                                   mm_token_positions=[0],
-                                  chunk_end_pos=10)
+                                  chunk_end_pos=10,
+                                  special_token_offsets=[])
 
         # Test non-positive token lengths
        with pytest.raises(ValueError,
                           match="All mm_token_lengths must be positive"):
             MultimodalRuntimeData(past_seen_token_num=10,
                                   mm_token_lengths=[5, 0, 7],
                                   mm_token_positions=[0, 5, 10],
-                                  chunk_end_pos=20)
+                                  chunk_end_pos=20,
+                                  special_token_offsets=[])
 
         # Test negative positions
         with pytest.raises(ValueError,
                            match="All mm_token_positions must be non-negative"):
             MultimodalRuntimeData(past_seen_token_num=10,
                                   mm_token_lengths=[5, 8, 7],
                                   mm_token_positions=[0, -5, 10],
-                                  chunk_end_pos=20)
+                                  chunk_end_pos=20,
+                                  special_token_offsets=[])
 
 
 class TestFindInputMmEmbed:
     """Focused test cases for find_input_mm_embeds function - testing both KV cache reuse and chunked prefill."""
 
-    def create_mock_runtime(self, num_unseen_mm_tokens: int,
+    def create_mock_runtime(self,
+                            num_unseen_mm_tokens: int,
                             num_mm_tokens_in_chunk: int,
-                            mm_token_lengths: List[int]):
+                            mm_token_lengths: List[int],
+                            num_unseen_special_tokens: int = 0,
+                            num_special_tokens_in_chunk: int = 0,
+                            total_special_tokens_in_request: int = 0):
         """Helper to create a mock MultimodalRuntimeData."""
         runtime = Mock(spec=MultimodalRuntimeData)
         runtime.num_unseen_mm_tokens = num_unseen_mm_tokens
         runtime.num_mm_tokens_in_chunk = num_mm_tokens_in_chunk
         runtime.total_mm_tokens_in_request = sum(mm_token_lengths)
+        runtime.num_unseen_special_tokens = num_unseen_special_tokens
+        runtime.num_special_tokens_in_chunk = num_special_tokens_in_chunk
+        runtime.total_special_tokens_in_request = total_special_tokens_in_request
 
         return runtime
 
@@ -365,22 +381,68 @@ def test_different_devices(self):
         result = find_input_mm_embeds(mm_embeds, multimodal_params)
         assert result[0].device == mm_embeds[0].device
 
+    def test_special_tokens_in_batched_mode(self):
+        """Test special token handling in batched mode."""
+        mm_embeds = [torch.randn(12, 512)
+                     ]  # Pre-concatenated: (8-2) + (10-4) = 6 + 6 = 12 tokens
+        multimodal_params = [
+            self.create_mock_runtime(num_unseen_mm_tokens=2,
+                                     num_mm_tokens_in_chunk=6,
+                                     mm_token_lengths=[8],
+                                     num_unseen_special_tokens=1,
+                                     num_special_tokens_in_chunk=1,
+                                     total_special_tokens_in_request=2),
+            self.create_mock_runtime(num_unseen_mm_tokens=4,
+                                     num_mm_tokens_in_chunk=6,
+                                     mm_token_lengths=[10],
+                                     num_unseen_special_tokens=2,
+                                     num_special_tokens_in_chunk=2,
+                                     total_special_tokens_in_request=4)
+        ]
+        multimodal_params = [
+            MultimodalParams(multimodal_runtime=runtime)
+            for runtime in multimodal_params
+        ]
+
+        result = find_input_mm_embeds(mm_embeds, multimodal_params)
+
+        # Expected slices accounting for special tokens:
+        # Batch 1: local_start = 2-1=1, local_end = 1+(6-1)=6, slice [1:6] = 5 tokens
+        # Batch 2: local_start = 4-2=2, local_end = 2+(6-2)=6, slice [6+2:6+6] = [8:12] = 4 tokens
+        # Total: 5 + 4 = 9 tokens
+        assert len(result) == 1
+        assert result[0].shape == (9, 512)
+
+        # Verify the slices are correct
+        expected = torch.cat(
+            [
+                mm_embeds[0][1:6],  # Batch 1: 5 tokens
+                mm_embeds[0][8:12]  # Batch 2: 4 tokens
+            ],
+            dim=0)
+        torch.testing.assert_close(result[0], expected)
+
 
 class TestGetMultimodalEmbeddings:
     """Test cases for get_multimodal_embeddings function - testing caching and encoder forward optimization."""
 
-    def create_mock_runtime(self, total_mm_tokens: int):
-        """Helper to create a mock MultimodalRuntimeData with total_mm_tokens."""
+    def create_mock_runtime(self,
+                            total_mm_tokens: int,
+                            total_special_tokens: int = 0):
+        """Helper to create a mock MultimodalRuntimeData with total_mm_tokens and special_tokens."""
         runtime = Mock(spec=MultimodalRuntimeData)
         runtime.total_mm_tokens_in_request = total_mm_tokens
+        runtime.total_special_tokens_in_request = total_special_tokens
         return runtime
 
     def create_multimodal_params_with_data(self,
                                            has_cached_embedding: bool = False,
                                            total_mm_tokens: int = 10,
+                                           total_special_tokens: int = 0,
                                            cached_embedding=None):
         """Helper to create MultimodalParams with optional cached embeddings."""
-        runtime = self.create_mock_runtime(total_mm_tokens)
+        runtime = self.create_mock_runtime(total_mm_tokens,
+                                           total_special_tokens)
 
         multimodal_data = {
             # Add some dummy multimodal data to ensure has_content() returns True
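
The slice boundaries asserted in test_special_tokens_in_batched_mode follow a simple offset calculation. The snippet below is a standalone sketch of that arithmetic, assuming each request's block in the pre-concatenated embedding tensor contains only its non-special rows; it is an illustration, not the actual find_input_mm_embeds implementation.

# Standalone sketch (illustration only, not find_input_mm_embeds itself) of the
# slice arithmetic the batched-mode test asserts: special tokens contribute no
# embedding rows, so both the start offset and the slice length shrink by the
# corresponding special-token counts.
def expected_slices(runtimes):
    slices, base = [], 0
    for rt in runtimes:
        local_start = rt["num_unseen_mm_tokens"] - rt["num_unseen_special_tokens"]
        local_end = local_start + (rt["num_mm_tokens_in_chunk"] -
                                   rt["num_special_tokens_in_chunk"])
        slices.append((base + local_start, base + local_end))
        # Each request contributes only its non-special rows to the fused tensor.
        base += (rt["total_mm_tokens_in_request"] -
                 rt["total_special_tokens_in_request"])
    return slices


# Reproduces the expectations asserted in the test above: [(1, 6), (8, 12)]
print(expected_slices([
    dict(num_unseen_mm_tokens=2, num_mm_tokens_in_chunk=6,
         num_unseen_special_tokens=1, num_special_tokens_in_chunk=1,
         total_mm_tokens_in_request=8, total_special_tokens_in_request=2),
    dict(num_unseen_mm_tokens=4, num_mm_tokens_in_chunk=6,
         num_unseen_special_tokens=2, num_special_tokens_in_chunk=2,
         total_mm_tokens_in_request=10, total_special_tokens_in_request=4),
]))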
@@ -663,6 +725,113 @@ def mock_encoder(params):
         assert multimodal_params[0].multimodal_data[
             "multimodal_embedding"].device.type == 'cuda'
 
+    def test_special_tokens_basic_caching(self):
+        """Test caching behavior with special tokens present."""
+
+        def mock_encoder(params):
+            # Return embeddings for non-special tokens only
+            # Total: (10-2) + (8-1) + (6-3) = 8 + 7 + 3 = 18 tokens
+            return [torch.randn(18, 512)]
+
+        multimodal_params = [
+            self.create_multimodal_params_with_data(
+                has_cached_embedding=False,
+                total_mm_tokens=10,
+                total_special_tokens=2),  # 8 actual embedding tokens
+            self.create_multimodal_params_with_data(
+                has_cached_embedding=False,
+                total_mm_tokens=8,
+                total_special_tokens=1),  # 7 actual embedding tokens
+            self.create_multimodal_params_with_data(
+                has_cached_embedding=False,
+                total_mm_tokens=6,
+                total_special_tokens=3)  # 3 actual embedding tokens
+        ]
+
+        result = get_multimodal_embeddings(mock_encoder, multimodal_params)
+
+        # Should return concatenated embeddings
+        assert len(result) == 1
+        assert result[0].shape == (18, 512)  # 8 + 7 + 3 = 18 tokens
+
+        # Check that embeddings were split correctly based on non-special token counts
+        assert multimodal_params[0].multimodal_data[
+            "multimodal_embedding"].shape == (8, 512)  # 10 - 2
+        assert multimodal_params[1].multimodal_data[
+            "multimodal_embedding"].shape == (7, 512)  # 8 - 1
+        assert multimodal_params[2].multimodal_data[
+            "multimodal_embedding"].shape == (3, 512)  # 6 - 3
+
+    def test_special_tokens_all_special(self):
+        """Test edge case where all tokens are special tokens."""
+
+        def mock_encoder(params):
+            # Should return empty tensor when no actual embedding tokens
+            return [torch.randn(0, 512)]
+
+        multimodal_params = [
+            self.create_multimodal_params_with_data(
+                has_cached_embedding=False,
+                total_mm_tokens=5,
+                total_special_tokens=5),  # All tokens are special
+            self.create_multimodal_params_with_data(
+                has_cached_embedding=False,
+                total_mm_tokens=3,
+                total_special_tokens=3)  # All tokens are special
+        ]
+
+        result = get_multimodal_embeddings(mock_encoder, multimodal_params)
+
+        # Should return empty embeddings
+        assert len(result) == 1
+        assert result[0].shape == (0, 512)
+
+        # Cached embeddings should also be empty
+        assert multimodal_params[0].multimodal_data[
+            "multimodal_embedding"].shape == (0, 512)
+        assert multimodal_params[1].multimodal_data[
+            "multimodal_embedding"].shape == (0, 512)
+
+    def test_special_tokens_mixed_with_cached(self):
+        """Test special tokens with mixed cached and uncached params."""
+        encoder_call_count = 0
+
+        def mock_encoder(params):
+            nonlocal encoder_call_count
+            encoder_call_count += 1
+            # Only process uncached param: 12 - 3 = 9 tokens
+            return [torch.randn(9, 512)]
+
+        # Mix: cached (with special tokens), uncached (with special tokens)
+        cached_emb = torch.randn(4, 512)  # 6 - 2 = 4 actual tokens
+        multimodal_params = [
+            self.create_multimodal_params_with_data(
+                has_cached_embedding=True,
+                total_mm_tokens=6,
+                total_special_tokens=2,
+                cached_embedding=cached_emb),
+            self.create_multimodal_params_with_data(
+                has_cached_embedding=False,
+                total_mm_tokens=12,
+                total_special_tokens=3)  # 9 actual embedding tokens
+        ]
+
+        result = get_multimodal_embeddings(mock_encoder, multimodal_params)
+
+        # Encoder should be called once for uncached param
+        assert encoder_call_count == 1
+
+        # Should return concatenated embeddings: 4 + 9 = 13 tokens
+        assert len(result) == 1
+        assert result[0].shape == (13, 512)
+
+        # Verify cached embedding is preserved and uncached is now cached
+        torch.testing.assert_close(
+            multimodal_params[0].multimodal_data["multimodal_embedding"],
+            cached_emb)
+        assert multimodal_params[1].multimodal_data[
+            "multimodal_embedding"].shape == (9, 512)
+
 
 if __name__ == "__main__":
     pytest.main([__file__])
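
Similarly, the splitting behavior asserted by test_special_tokens_basic_caching can be reproduced with a short standalone sketch, under the same assumption that the fused encoder output carries only non-special rows; this is an illustration, not the get_multimodal_embeddings implementation.

import torch


# Standalone sketch, not the library code: divide the fused encoder output into
# per-request cached embeddings of size total_mm_tokens - total_special_tokens.
def split_encoder_output(fused, total_mm_tokens, total_special_tokens):
    split_sizes = [m - s for m, s in zip(total_mm_tokens, total_special_tokens)]
    return list(torch.split(fused, split_sizes, dim=0))


chunks = split_encoder_output(torch.randn(18, 512), [10, 8, 6], [2, 1, 3])
print([c.shape[0] for c in chunks])  # [8, 7, 3], matching the shape assertions above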