From e736d38280844c70e237a7a23c5357b5de51a748 Mon Sep 17 00:00:00 2001 From: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> Date: Thu, 7 Aug 2025 22:09:19 -0700 Subject: [PATCH 1/2] add mixed gen/ctx model test Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> --- .../defs/accuracy/references/gsm8k.yaml | 1 + .../accuracy/test_disaggregated_serving.py | 75 ++++++++++++++----- .../test_lists/qa/llm_function_sanity.txt | 1 + .../test_lists/test-db/l0_dgx_b200.yml | 1 + 4 files changed, 61 insertions(+), 17 deletions(-) diff --git a/tests/integration/defs/accuracy/references/gsm8k.yaml b/tests/integration/defs/accuracy/references/gsm8k.yaml index 29458d3bf49..26de82cbc09 100644 --- a/tests/integration/defs/accuracy/references/gsm8k.yaml +++ b/tests/integration/defs/accuracy/references/gsm8k.yaml @@ -80,6 +80,7 @@ Qwen3/Qwen3-8B: kv_cache_quant_algo: FP8 accuracy: 87.1114 Qwen3/Qwen3-30B-A3B: + - accuracy: 83.43 - quant_algo: FP8_BLOCK_SCALES accuracy: 84.36 - quant_algo: FP8 diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py index 051c5401a06..7bd72968d89 100644 --- a/tests/integration/defs/accuracy/test_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py @@ -23,7 +23,7 @@ from tensorrt_llm.llmapi.tokenizer import load_hf_tokenizer from ..conftest import (get_device_count, llm_models_root, parametrize_with_ids, - skip_pre_hopper) + skip_pre_blackwell, skip_pre_hopper) from ..trt_test_alternative import popen from .accuracy_core import (GSM8K, MMLU, JsonModeEval, LlmapiAccuracyTestHarness, get_accuracy_task) @@ -71,7 +71,9 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any], ctx_server_config: Dict[str, Any], gen_server_config: Dict[str, Any], model_name: str, - tensor_parallel_size: int = 1): + tensor_parallel_size: int = 1, + ctx_model: str = None, + gen_model: str = None): temp_dir = tempfile.TemporaryDirectory() disaggregated_serving_config_path = os.path.join( temp_dir.name, "disaggregated_serving_config.yaml") @@ -97,9 +99,19 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any], trtllm_serve_path = "trtllm-serve" # Common arguments for both servers - common_args = [ + ctx_model = ctx_model or model_name + gen_model = gen_model or model_name + ctx_args = [ trtllm_serve_path, - model_name, + ctx_model, + "--host", + "localhost", + "--backend", + "pytorch", + ] + gen_args = [ + trtllm_serve_path, + gen_model, "--host", "localhost", "--backend", @@ -125,11 +137,11 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any], env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1" env_gen["CUDA_VISIBLE_DEVICES"] = ",".join( map(str, range(ctx_total_gpus, ctx_total_gpus + gen_total_gpus))) - ctx_server_args = common_args + [ + ctx_server_args = ctx_args + [ "--port", "8001", "--extra_llm_api_options", ctx_server_config_path, f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}" ] - gen_server_args = common_args + [ + gen_server_args = gen_args + [ "--port", "8002", "--extra_llm_api_options", gen_server_config_path, f"--tp_size={gen_tp}", f"--pp_size={gen_pp}" ] @@ -226,17 +238,21 @@ def generate_async(prompt: str, disaggregated_server.wait() -def run_parallel_test(model_name: str, model_path: str, ctx_pp: int, - ctx_tp: int, gen_pp: int, gen_tp: int, - test_set: LlmapiAccuracyTestHarness): +def run_parallel_test(model_name: str, + model_path: str, + ctx_pp: int, + ctx_tp: int, + gen_pp: int, + gen_tp: int, + test_sets: List[LlmapiAccuracyTestHarness], + ctx_model: str = None, + gen_model: str = None): if ctx_tp * ctx_pp + gen_tp * gen_pp > get_device_count(): pytest.fail( f"Not enough devices for ctx_pp={ctx_pp}+ctx_tp={ctx_tp} and gen_pp={gen_pp}+gen_tp={gen_tp} test" ) - kv_cache_config = { "free_gpu_memory_fraction": 0.5, - "enable_block_reuse": False } ctx_server_config = { "pipeline_parallel_size": ctx_pp, @@ -270,10 +286,14 @@ def run_parallel_test(model_name: str, model_path: str, ctx_pp: int, } } with launch_disaggregated_llm(disaggregated_server_config, - ctx_server_config, gen_server_config, - model_path) as llm: - task = test_set(model_name) - task.evaluate(llm) + ctx_server_config, + gen_server_config, + model_path, + ctx_model=ctx_model, + gen_model=gen_model) as llm: + for test_set in test_sets: + task = test_set(model_name) + task.evaluate(llm) @pytest.mark.timeout(3600) @@ -512,7 +532,7 @@ def test_tp_pp_symmetric(self, tp, pp, testset): if tp * pp * 2 > get_device_count(): pytest.skip(f"Not enough devices for tp={tp}*pp={pp} test") return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, pp, tp, pp, - tp, get_accuracy_task(testset)) + tp, [get_accuracy_task(testset)]) @parametrize_with_ids("ctx_pp", [2, 4]) @parametrize_with_ids("gen_tp", [1, 2]) @@ -522,7 +542,7 @@ def test_ctx_pp_gen_tp_asymmetric(self, ctx_pp, gen_tp, testset): pytest.skip( f"Not enough devices for ctx_pp={ctx_pp}*gen_tp={gen_tp} test") return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1, - gen_tp, get_accuracy_task(testset)) + gen_tp, [get_accuracy_task(testset)]) @pytest.mark.skip_less_device_memory(140000) @@ -776,3 +796,24 @@ def test_auto_dtype(self, overlap_scheduler): task.evaluate(llm) task = MMLU(self.MODEL_NAME) task.evaluate(llm) + + +@skip_pre_blackwell +@pytest.mark.timeout(3600) +class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness): + fp4_model = f"{llm_models_root()}/Qwen3/saved_models_Qwen3-30B-A3B_nvfp4_hf" + fp8_model = f"{llm_models_root()}/Qwen3/saved_models_Qwen3-30B-A3B_fp8_hf" + + @pytest.mark.parametrize("ctxpp,gentp", [(2, 2)], ids=["ctxpp2gentp2"]) + def test_mixed_ctx_gen_model(self, ctxpp, gentp): + ctx_model = self.fp4_model + gen_model = self.fp8_model + return run_parallel_test("Qwen3/Qwen3-30B-A3B", + ctx_model, + ctx_pp=ctxpp, + ctx_tp=1, + gen_pp=1, + gen_tp=gentp, + test_sets=[GSM8K, MMLU], + ctx_model=ctx_model, + gen_model=gen_model) diff --git a/tests/integration/test_lists/qa/llm_function_sanity.txt b/tests/integration/test_lists/qa/llm_function_sanity.txt index aeaa1ba573b..d262e899122 100644 --- a/tests/integration/test_lists/qa/llm_function_sanity.txt +++ b/tests/integration/test_lists/qa/llm_function_sanity.txt @@ -25,6 +25,7 @@ accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=4] accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[False] accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[True] +accuracy/test_disaggregated_serving.py::TestQwen3_30B_A3B::test_mixed_ctx_gen_model[ctxpp2gentp2] accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_chunked_prefill[attn_backend=FLASHINFER] accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_chunked_prefill[attn_backend=TRTLLM] accuracy/test_llm_api_pytorch.py::TestBielik11BInstruct::test_auto_dtype diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index 2c04beb634a..fb3f518a686 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -70,6 +70,7 @@ l0_dgx_b200: - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp8[tp4-cuda_graph=True] - accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp4[tp4-cuda_graph=True] + - accuracy/test_disaggregated_serving.py::TestQwen3_30B_A3B::test_mixed_ctx_gen_model[ctxpp2gentp2] - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8] - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-CUTLASS] - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-TRTLLM] From f4ec496b1386dfeaf87c40a4fede8f2c648fa8ac Mon Sep 17 00:00:00 2001 From: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> Date: Mon, 11 Aug 2025 19:43:36 -0700 Subject: [PATCH 2/2] fix by CR bot's comments Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> --- .../accuracy/test_disaggregated_serving.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py index 7bd72968d89..51a572ce493 100644 --- a/tests/integration/defs/accuracy/test_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py @@ -801,19 +801,20 @@ def test_auto_dtype(self, overlap_scheduler): @skip_pre_blackwell @pytest.mark.timeout(3600) class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness): - fp4_model = f"{llm_models_root()}/Qwen3/saved_models_Qwen3-30B-A3B_nvfp4_hf" - fp8_model = f"{llm_models_root()}/Qwen3/saved_models_Qwen3-30B-A3B_fp8_hf" - - @pytest.mark.parametrize("ctxpp,gentp", [(2, 2)], ids=["ctxpp2gentp2"]) - def test_mixed_ctx_gen_model(self, ctxpp, gentp): - ctx_model = self.fp4_model - gen_model = self.fp8_model + FP4_MODEL = f"{llm_models_root()}/Qwen3/saved_models_Qwen3-30B-A3B_nvfp4_hf" + FP8_MODEL = f"{llm_models_root()}/Qwen3/saved_models_Qwen3-30B-A3B_fp8_hf" + + @pytest.mark.skip_less_device(4) + @pytest.mark.parametrize("ctx_pp,gen_tp", [(2, 2)], ids=["ctxpp2gentp2"]) + def test_mixed_ctx_gen_model(self, ctx_pp, gen_tp): + ctx_model = self.FP4_MODEL + gen_model = self.FP8_MODEL return run_parallel_test("Qwen3/Qwen3-30B-A3B", ctx_model, - ctx_pp=ctxpp, + ctx_pp=ctx_pp, ctx_tp=1, gen_pp=1, - gen_tp=gentp, + gen_tp=gen_tp, test_sets=[GSM8K, MMLU], ctx_model=ctx_model, gen_model=gen_model)