2323from  tensorrt_llm .llmapi .tokenizer  import  load_hf_tokenizer 
2424
2525from  ..conftest  import  (get_device_count , llm_models_root , parametrize_with_ids ,
26-                         skip_pre_hopper )
26+                         skip_pre_blackwell ,  skip_pre_hopper )
2727from  ..trt_test_alternative  import  popen 
2828from  .accuracy_core  import  (GSM8K , MMLU , JsonModeEval ,
2929                            LlmapiAccuracyTestHarness , get_accuracy_task )
@@ -71,7 +71,9 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
7171                             ctx_server_config : Dict [str , Any ],
7272                             gen_server_config : Dict [str , Any ],
7373                             model_name : str ,
74-                              tensor_parallel_size : int  =  1 ):
74+                              tensor_parallel_size : int  =  1 ,
75+                              ctx_model : str  =  None ,
76+                              gen_model : str  =  None ):
7577    temp_dir  =  tempfile .TemporaryDirectory ()
7678    disaggregated_serving_config_path  =  os .path .join (
7779        temp_dir .name , "disaggregated_serving_config.yaml" )
@@ -97,9 +99,19 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
9799
98100    trtllm_serve_path  =  "trtllm-serve" 
99101    # Common arguments for both servers 
100-     common_args  =  [
102+     ctx_model  =  ctx_model  or  model_name 
103+     gen_model  =  gen_model  or  model_name 
104+     ctx_args  =  [
101105        trtllm_serve_path ,
102-         model_name ,
106+         ctx_model ,
107+         "--host" ,
108+         "localhost" ,
109+         "--backend" ,
110+         "pytorch" ,
111+     ]
112+     gen_args  =  [
113+         trtllm_serve_path ,
114+         gen_model ,
103115        "--host" ,
104116        "localhost" ,
105117        "--backend" ,
@@ -125,11 +137,11 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
125137    env_gen ["TRTLLM_USE_UCX_KVCACHE" ] =  "1" 
126138    env_gen ["CUDA_VISIBLE_DEVICES" ] =  "," .join (
127139        map (str , range (ctx_total_gpus , ctx_total_gpus  +  gen_total_gpus )))
128-     ctx_server_args  =  common_args  +  [
140+     ctx_server_args  =  ctx_args  +  [
129141        "--port" , "8001" , "--extra_llm_api_options" , ctx_server_config_path ,
130142        f"--tp_size={ ctx_tp }  , f"--pp_size={ ctx_pp }  
131143    ]
132-     gen_server_args  =  common_args  +  [
144+     gen_server_args  =  gen_args  +  [
133145        "--port" , "8002" , "--extra_llm_api_options" , gen_server_config_path ,
134146        f"--tp_size={ gen_tp }  , f"--pp_size={ gen_pp }  
135147    ]
@@ -226,17 +238,21 @@ def generate_async(prompt: str,
226238            disaggregated_server .wait ()
227239
228240
229- def  run_parallel_test (model_name : str , model_path : str , ctx_pp : int ,
230-                       ctx_tp : int , gen_pp : int , gen_tp : int ,
231-                       test_set : LlmapiAccuracyTestHarness ):
241+ def  run_parallel_test (model_name : str ,
242+                       model_path : str ,
243+                       ctx_pp : int ,
244+                       ctx_tp : int ,
245+                       gen_pp : int ,
246+                       gen_tp : int ,
247+                       test_sets : List [LlmapiAccuracyTestHarness ],
248+                       ctx_model : str  =  None ,
249+                       gen_model : str  =  None ):
232250    if  ctx_tp  *  ctx_pp  +  gen_tp  *  gen_pp  >  get_device_count ():
233251        pytest .fail (
234252            f"Not enough devices for ctx_pp={ ctx_pp } { ctx_tp } { gen_pp } { gen_tp }  
235253        )
236- 
237254    kv_cache_config  =  {
238255        "free_gpu_memory_fraction" : 0.5 ,
239-         "enable_block_reuse" : False 
240256    }
241257    ctx_server_config  =  {
242258        "pipeline_parallel_size" : ctx_pp ,
@@ -270,10 +286,14 @@ def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
270286        }
271287    }
272288    with  launch_disaggregated_llm (disaggregated_server_config ,
273-                                   ctx_server_config , gen_server_config ,
274-                                   model_path ) as  llm :
275-         task  =  test_set (model_name )
276-         task .evaluate (llm )
289+                                   ctx_server_config ,
290+                                   gen_server_config ,
291+                                   model_path ,
292+                                   ctx_model = ctx_model ,
293+                                   gen_model = gen_model ) as  llm :
294+         for  test_set  in  test_sets :
295+             task  =  test_set (model_name )
296+             task .evaluate (llm )
277297
278298
279299@pytest .mark .timeout (3600 ) 
@@ -511,14 +531,14 @@ def test_guided_decoding_with_eagle3(self, backend: str, mocker):
511531    @pytest .mark .parametrize ("testset" , ["GSM8K" , "MMLU" ]) 
512532    def  test_tp_pp_symmetric (self , tp , pp , testset ):
513533        return  run_parallel_test (self .MODEL_NAME , self .MODEL_PATH , pp , tp , pp ,
514-                                  tp , get_accuracy_task (testset ))
534+                                  tp , [ get_accuracy_task (testset )] )
515535
516536    @parametrize_with_ids ("ctx_pp" , [2 , 4 ]) 
517537    @parametrize_with_ids ("gen_tp" , [1 , 2 ]) 
518538    @pytest .mark .parametrize ("testset" , ["GSM8K" , "MMLU" ]) 
519539    def  test_ctx_pp_gen_tp_asymmetric (self , ctx_pp , gen_tp , testset ):
520540        return  run_parallel_test (self .MODEL_NAME , self .MODEL_PATH , ctx_pp , 1 , 1 ,
521-                                  gen_tp , get_accuracy_task (testset ))
541+                                  gen_tp , [ get_accuracy_task (testset )] )
522542
523543
524544@pytest .mark .skip_less_device_memory (140000 ) 
@@ -702,3 +722,24 @@ def test_auto_dtype(self, overlap_scheduler):
702722            task .evaluate (llm )
703723            task  =  MMLU (self .MODEL_NAME )
704724            task .evaluate (llm )
725+ 
726+ 
727+ @skip_pre_blackwell  
728+ @pytest .mark .timeout (3600 ) 
729+ class  TestQwen3_30B_A3B (LlmapiAccuracyTestHarness ):
730+     fp4_model  =  f"{ llm_models_root ()}  
731+     fp8_model  =  f"{ llm_models_root ()}  
732+ 
733+     @pytest .mark .parametrize ("ctxpp,gentp" , [(2 , 2 )], ids = ["ctxpp2gentp2" ]) 
734+     def  test_mixed_ctx_gen_model (self , ctxpp , gentp ):
735+         ctx_model  =  self .fp4_model 
736+         gen_model  =  self .fp8_model 
737+         return  run_parallel_test ("Qwen3/Qwen3-30B-A3B" ,
738+                                  ctx_model ,
739+                                  ctx_pp = ctxpp ,
740+                                  ctx_tp = 1 ,
741+                                  gen_pp = 1 ,
742+                                  gen_tp = gentp ,
743+                                  test_sets = [GSM8K , MMLU ],
744+                                  ctx_model = ctx_model ,
745+                                  gen_model = gen_model )
0 commit comments