@@ -1818,7 +1818,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
18181818 assert len (scheduler .waiting ) == 1
18191819
18201820
1821- def test_priority_scheduling_preemption_when_out_of_kv ():
1821+ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv ():
18221822 """Test that priority scheduling preempts lower priority requests
18231823 when out of KV cache space."""
18241824 # Create scheduler with very limited memory to force preemption
@@ -1827,6 +1827,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
18271827 max_num_batched_tokens = 200 ,
18281828 num_blocks = 5 , # Can hold 64 tokens (first block is null)
18291829 block_size = 16 , # Standard block size
1830+ use_kv_connector = True ,
18301831 )
18311832
18321833 # Create a request and schedule it
@@ -1838,12 +1839,13 @@ def test_priority_scheduling_preemption_when_out_of_kv():
18381839 starting_idx = 0 ,
18391840 )[0 ]
18401841 scheduler .add_request (request_low )
1842+ # 1st schedule
18411843 output = scheduler .schedule ()
18421844 assert len (output .scheduled_new_reqs ) == 1
18431845 assert len (scheduler .waiting ) == 0
18441846 assert len (scheduler .running ) == 1
18451847
1846- # Simulate model execution
1848+ # Simulate model execution - 1st decode
18471849 model_output = ModelRunnerOutput (
18481850 req_ids = [request_low .request_id ],
18491851 req_id_to_index = {request_low .request_id : 0 },
@@ -1864,6 +1866,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
18641866 starting_idx = 1 ,
18651867 )[0 ]
18661868 scheduler .add_request (request_high )
1869+ # 2nd schedule
18671870 output = scheduler .schedule ()
18681871 # KV cache should be full at this point
18691872 assert scheduler .kv_cache_manager .block_pool .get_num_free_blocks () == 0
@@ -1872,7 +1875,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
18721875 assert len (scheduler .waiting ) == 0
18731876 assert len (scheduler .running ) == 2
18741877
1875- # Simulate model execution
1878+ # Simulate model execution - 2nd decode
18761879 requests = [request_low , request_high ]
18771880 model_output = ModelRunnerOutput (
18781881 req_ids = [req .request_id for req in requests ],
@@ -1888,7 +1891,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
18881891 )
18891892 scheduler .update_from_output (output , model_output )
18901893
1891- # Schedule again - this should trigger preemption
1894+ # 3rd schedule - this should trigger preemption
18921895 # req_low needs 32 tokens = 2 blocks
18931896 # req_high needs 33 tokens = 3 blocks
18941897 # so doesn't fit in 4 blocks.
@@ -1898,5 +1901,45 @@ def test_priority_scheduling_preemption_when_out_of_kv():
18981901 assert len (output .scheduled_new_reqs ) == 0
18991902 assert output .scheduled_cached_reqs .num_reqs == 1
19001903 assert output .scheduled_cached_reqs .req_ids [0 ] == request_high .request_id
1904+ assert scheduler .requests [
1905+ request_low .request_id ].status == RequestStatus .PREEMPTED
19011906 assert len (scheduler .waiting ) == 1
1902- assert len (scheduler .running ) == 1
1907+ assert len (scheduler .running ) == 1
1908+
1909+ # Simulate model execution - 3rd decode
1910+ model_output = ModelRunnerOutput (
1911+ req_ids = [req .request_id for req in requests ],
1912+ req_id_to_index = {
1913+ req .request_id : i
1914+ for i , req in enumerate (requests )
1915+ },
1916+ sampled_token_ids = [[], [100 ]],
1917+ # spec_token_ids=None,
1918+ logprobs = None ,
1919+ prompt_logprobs_dict = {},
1920+ pooler_output = [],
1921+ )
1922+ # Finish request_high to make room for the preempted request_low to resume
1923+ scheduler .update_from_output (output , model_output )
1924+ scheduler .finish_requests (request_high .request_id ,
1925+ RequestStatus .FINISHED_STOPPED )
1926+
1927+ # 4th schedule - this should trigger the resumption
1928+ output = scheduler .schedule ()
1929+ scheduled_cached_reqs = output .scheduled_cached_reqs
1930+ resumed_from_preemption = scheduled_cached_reqs .resumed_from_preemption
1931+
1932+ assert len (output .scheduled_new_reqs ) == 0
1933+ assert scheduled_cached_reqs .num_reqs == 1
1934+ assert len (scheduler .waiting ) == 0
1935+ assert len (scheduler .running ) == 1
1936+
1937+ # Preempted request resumed in scheduled_cached_reqs
1938+ assert len (resumed_from_preemption ) == 1
1939+ assert len (scheduled_cached_reqs .resumed_req_token_ids ) == 1
1940+ assert resumed_from_preemption [0 ]
1941+ assert scheduled_cached_reqs .req_ids [0 ] == request_low .request_id
1942+ assert scheduled_cached_reqs .resumed_req_token_ids [0 ] is not None
1943+ # Resumed tokens include 30 prompt tokens and 2 decoded tokens
1944+ assert len (scheduled_cached_reqs .resumed_req_token_ids [0 ]) == 32
1945+ assert scheduled_cached_reqs .resumed_req_token_ids [0 ][31 ] == 100
0 commit comments