@@ -71,15 +71,16 @@ def model_fn(*args, **kwargs):
 
 
 @pytest.mark.threadleak(enabled=False)
-def test_connector_simple(model_with_connector):
+@pytest.mark.parametrize("use_overlap_scheduler", [True, False])
+def test_connector_simple(model_with_connector, use_overlap_scheduler):
     NUM_TOKENS = 8
 
     model_fn, scheduler, worker = model_with_connector
 
     model = model_fn(
         model="Qwen/Qwen2-0.5B",
         backend="pytorch",
-        disable_overlap_scheduler=True,
+        disable_overlap_scheduler=not use_overlap_scheduler,
         cuda_graph_config=None,
         kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1))
 
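Note: the pattern in this hunk repeats in every test below. Each test is parametrized over `use_overlap_scheduler`, and the flag is inverted when passed to `disable_overlap_scheduler`. Because the overlap scheduler generates one extra token, every per-forward-pass call count gains one. A minimal sketch of the arithmetic the updated assertions share (the helper name is hypothetical, not part of this change):

    def expected_forward_passes(num_tokens: int,
                                use_overlap_scheduler: bool) -> int:
        # One forward pass per generated token, plus one extra pass
        # when the overlap scheduler generates an extra token.
        return num_tokens + int(use_overlap_scheduler)

    assert expected_forward_passes(8, use_overlap_scheduler=False) == 8
    assert expected_forward_passes(8, use_overlap_scheduler=True) == 9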
@@ -93,7 +94,9 @@ def test_connector_simple(model_with_connector):
 
     model.generate(["Hello, world"], sampling_params)
 
-    assert scheduler.build_connector_meta.call_count == NUM_TOKENS
+    # With the overlap scheduler, we generate one extra token.
+    assert scheduler.build_connector_meta.call_count == NUM_TOKENS + int(
+        use_overlap_scheduler)
 
     # We should have a single `SchedulerOutput` per forward pass.
     for i, call in enumerate(scheduler.build_connector_meta.call_args_list):
@@ -105,7 +108,8 @@ def test_connector_simple(model_with_connector):
             assert len(scheduler_output.requests[0].new_tokens) == 1
 
     # We call `start_load_kv` once at the beginning of each forward pass.
-    assert worker.start_load_kv.call_count == NUM_TOKENS
+    assert worker.start_load_kv.call_count == NUM_TOKENS + int(
+        use_overlap_scheduler)
 
     # Only called once when the request is received.
     assert scheduler.get_num_new_matched_tokens.call_count == 1
@@ -114,31 +118,36 @@ def test_connector_simple(model_with_connector):
                      for call in worker.wait_for_layer_load.call_args_list) + 1
 
     # Called num_layers * num_forward_passes times.
-    assert worker.wait_for_layer_load.call_count == num_layers * NUM_TOKENS
-    assert worker.save_kv_layer.call_count == num_layers * NUM_TOKENS
+    assert worker.wait_for_layer_load.call_count == num_layers * (
+        NUM_TOKENS + int(use_overlap_scheduler))
+    assert worker.save_kv_layer.call_count == num_layers * (
+        NUM_TOKENS + int(use_overlap_scheduler))
 
     for i, call in enumerate(worker.wait_for_layer_load.call_args_list):
         assert call.args[0] == i % num_layers
 
     for i, call in enumerate(worker.save_kv_layer.call_args_list):
         assert call.args[0] == i % num_layers
 
-    assert worker.wait_for_save.call_count == NUM_TOKENS
+    assert worker.wait_for_save.call_count == NUM_TOKENS + int(
+        use_overlap_scheduler)
 
     assert scheduler.request_finished.call_count == 1
-    assert worker.get_finished.call_count == NUM_TOKENS
+    assert worker.get_finished.call_count == NUM_TOKENS + int(
+        use_overlap_scheduler)
 
 
 @pytest.mark.threadleak(enabled=False)
-def test_connector_async_onboard(model_with_connector):
+@pytest.mark.parametrize("use_overlap_scheduler", [True, False])
+def test_connector_async_onboard(model_with_connector, use_overlap_scheduler):
     NUM_TOKENS = 8
 
     model_fn, scheduler, worker = model_with_connector
 
     model = model_fn(
         model="Qwen/Qwen2-0.5B",
         backend="pytorch",
-        disable_overlap_scheduler=True,
+        disable_overlap_scheduler=not use_overlap_scheduler,
         cuda_graph_config=None,
         kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1))
 
@@ -153,23 +162,25 @@ def test_connector_async_onboard(model_with_connector):
         "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua."
     ], SamplingParams(max_tokens=NUM_TOKENS, ignore_eos=True))
 
-    # Once for the initial poll, then once for each token.
-    assert worker.get_finished.call_count == NUM_TOKENS + 1
+    # Once for the initial poll, then once for each token. One extra token when using the overlap scheduler.
+    assert worker.get_finished.call_count == NUM_TOKENS + 1 + int(
+        use_overlap_scheduler)
 
     # In the first iteration, there should be a single request id provided.
     assert len(worker.get_finished.call_args_list[0].args[1]) == 1
 
 
 @pytest.mark.threadleak(enabled=False)
-def test_connector_async_save(model_with_connector):
+@pytest.mark.parametrize("use_overlap_scheduler", [True, False])
+def test_connector_async_save(model_with_connector, use_overlap_scheduler):
     NUM_TOKENS = 8
 
     model_fn, scheduler, worker = model_with_connector
 
     model = model_fn(
         model="Qwen/Qwen2-0.5B",
         backend="pytorch",
-        disable_overlap_scheduler=True,
+        disable_overlap_scheduler=not use_overlap_scheduler,
         cuda_graph_config=None,
         kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1))
 
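The async-onboard accounting above combines two offsets: one initial `get_finished` poll while the onboard is in flight, plus the overlap scheduler's extra pass. A sketch of that bookkeeping (hypothetical helper, mirroring the assertion rather than defining any API):

    def expected_get_finished_calls(num_tokens: int,
                                    use_overlap_scheduler: bool,
                                    initial_poll: bool) -> int:
        # One poll per forward pass, an extra pass under the overlap
        # scheduler, and one up-front poll for async onboarding.
        return num_tokens + int(use_overlap_scheduler) + int(initial_poll)

    assert expected_get_finished_calls(8, True, True) == 10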
@@ -188,12 +199,13 @@ def test_connector_async_save(model_with_connector):
 
     assert scheduler.request_finished.call_count == 1
 
-    # On the last call to get_finished, we should be providing the async saving request.
-    assert worker.get_finished.call_count == NUM_TOKENS
+    # On the last call to get_finished, we should be providing the async saving request. One extra token when using the overlap scheduler.
+    assert worker.get_finished.call_count == NUM_TOKENS + int(
+        use_overlap_scheduler)
 
-    for i in range(NUM_TOKENS):
-        args = worker.get_finished.call_args_list[i].args
-        if i != NUM_TOKENS - 1:
+    for i, call in enumerate(worker.get_finished.call_args_list):
+        args = call.args
+        if i != len(worker.get_finished.call_args_list) - 1:
             assert args == ([], [])
         else:
             assert len(args[0]) == 1
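Rewriting the loop to iterate `call_args_list` (rather than `range(NUM_TOKENS)`) keeps the last-call check correct for both parametrizations, since the number of recorded calls now varies. The same pattern in isolation, against a plain `MagicMock` standing in for the worker:

    from unittest.mock import MagicMock

    worker = MagicMock()
    worker.get_finished([], [])          # ordinary forward passes
    worker.get_finished([], [])
    worker.get_finished(["req-0"], [])   # last pass reports the async save

    calls = worker.get_finished.call_args_list
    for i, call in enumerate(calls):
        if i != len(calls) - 1:
            assert call.args == ([], [])
        else:
            assert len(call.args[0]) == 1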
@@ -202,7 +214,9 @@ def test_connector_async_save(model_with_connector):
 
 
 @pytest.mark.threadleak(enabled=False)
-def test_connector_scheduler_output(model_with_connector):
+@pytest.mark.parametrize("use_overlap_scheduler", [True, False])
+def test_connector_scheduler_output(model_with_connector,
+                                    use_overlap_scheduler):
     NUM_INPUT_TOKENS = 48
     NUM_TOKENS = 32
     BLOCK_SIZE = 32
@@ -212,7 +226,7 @@ def test_connector_scheduler_output(model_with_connector):
     model = model_fn(
         model="Qwen/Qwen2-0.5B",
         backend="pytorch",
-        disable_overlap_scheduler=True,
+        disable_overlap_scheduler=not use_overlap_scheduler,
         cuda_graph_config=None,
         kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1))
 
@@ -226,7 +240,9 @@ def test_connector_scheduler_output(model_with_connector):
 
     model.generate([0] * NUM_INPUT_TOKENS, sampling_params)
 
-    assert scheduler.build_connector_meta.call_count == NUM_TOKENS
+    # Additional token when using the overlap scheduler.
+    assert scheduler.build_connector_meta.call_count == NUM_TOKENS + int(
+        use_overlap_scheduler)
 
     for i, call in enumerate(scheduler.build_connector_meta.call_args_list):
         sched_output = call.args[0]
@@ -241,7 +257,8 @@ def test_connector_scheduler_output(model_with_connector):
         else:
             assert len(request.new_tokens) == 1
 
-            if request.computed_position % BLOCK_SIZE == 0:
+            if (request.computed_position +
+                    int(use_overlap_scheduler)) % BLOCK_SIZE == 0:
                 assert len(request.new_block_ids) == 1
             else:
                 assert request.new_block_ids == []
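The block-boundary check above assumes the overlap scheduler shifts the reported `computed_position` by exactly one token. A standalone sketch of that predicate (hypothetical helper; `BLOCK_SIZE` as in the test):

    BLOCK_SIZE = 32

    def allocates_new_block(computed_position: int,
                            use_overlap_scheduler: bool) -> bool:
        # A new KV block is expected whenever the offset-corrected
        # position lands on a block boundary.
        return (computed_position +
                int(use_overlap_scheduler)) % BLOCK_SIZE == 0

    assert allocates_new_block(32, use_overlap_scheduler=False)
    assert allocates_new_block(31, use_overlap_scheduler=True)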
@@ -257,7 +274,9 @@ def test_connector_scheduler_output(model_with_connector):
 
 
 @pytest.mark.threadleak(enabled=False)
-def test_connector_scheduler_output_chunked_context(model_with_connector):
+@pytest.mark.parametrize("use_overlap_scheduler", [True, False])
+def test_connector_scheduler_output_chunked_context(model_with_connector,
+                                                    use_overlap_scheduler):
     model_fn, scheduler, worker = model_with_connector
 
     CHUNK_SIZE = 128
@@ -266,7 +285,7 @@ def test_connector_scheduler_output_chunked_context(model_with_connector):
     model = model_fn(
         model="Qwen/Qwen2-0.5B",
         backend="pytorch",
-        disable_overlap_scheduler=True,
+        disable_overlap_scheduler=not use_overlap_scheduler,
         cuda_graph_config=None,
         kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1),
         enable_chunked_prefill=True,