Merged
1 change: 1 addition & 0 deletions tests/integration/defs/accuracy/references/gsm8k.yaml
@@ -80,6 +80,7 @@ Qwen3/Qwen3-8B:
   kv_cache_quant_algo: FP8
   accuracy: 87.1114
 Qwen3/Qwen3-30B-A3B:
+- accuracy: 83.43
 - quant_algo: FP8_BLOCK_SCALES
   accuracy: 84.36
 - quant_algo: FP8
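The added entry carries only an accuracy key; in these reference files an entry without a quant_algo appears to serve as the unquantized (auto-dtype) baseline for Qwen3-30B-A3B, alongside the existing FP8 references.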
76 changes: 59 additions & 17 deletions tests/integration/defs/accuracy/test_disaggregated_serving.py
@@ -23,7 +23,7 @@
 from tensorrt_llm.llmapi.tokenizer import load_hf_tokenizer
 
 from ..conftest import (get_device_count, llm_models_root, parametrize_with_ids,
-                        skip_pre_hopper)
+                        skip_pre_blackwell, skip_pre_hopper)
 from ..trt_test_alternative import popen
 from .accuracy_core import (GSM8K, MMLU, JsonModeEval,
                             LlmapiAccuracyTestHarness, get_accuracy_task)
@@ -71,7 +71,9 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
                              ctx_server_config: Dict[str, Any],
                              gen_server_config: Dict[str, Any],
                              model_name: str,
-                             tensor_parallel_size: int = 1):
+                             tensor_parallel_size: int = 1,
+                             ctx_model: str = None,
+                             gen_model: str = None):
     temp_dir = tempfile.TemporaryDirectory()
     disaggregated_serving_config_path = os.path.join(
         temp_dir.name, "disaggregated_serving_config.yaml")
@@ -97,9 +99,19 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
 
     trtllm_serve_path = "trtllm-serve"
     # Common arguments for both servers
-    common_args = [
+    ctx_model = ctx_model or model_name
+    gen_model = gen_model or model_name
+    ctx_args = [
         trtllm_serve_path,
-        model_name,
+        ctx_model,
+        "--host",
+        "localhost",
+        "--backend",
+        "pytorch",
+    ]
+    gen_args = [
+        trtllm_serve_path,
+        gen_model,
         "--host",
         "localhost",
         "--backend",
@@ -125,11 +137,11 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
     env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
     env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(
         map(str, range(ctx_total_gpus, ctx_total_gpus + gen_total_gpus)))
-    ctx_server_args = common_args + [
+    ctx_server_args = ctx_args + [
         "--port", "8001", "--extra_llm_api_options", ctx_server_config_path,
         f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}"
     ]
-    gen_server_args = common_args + [
+    gen_server_args = gen_args + [
         "--port", "8002", "--extra_llm_api_options", gen_server_config_path,
         f"--tp_size={gen_tp}", f"--pp_size={gen_pp}"
     ]
@@ -226,17 +238,21 @@ def generate_async(prompt: str,
         disaggregated_server.wait()
 
 
-def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
-                      ctx_tp: int, gen_pp: int, gen_tp: int,
-                      test_set: LlmapiAccuracyTestHarness):
+def run_parallel_test(model_name: str,
+                      model_path: str,
+                      ctx_pp: int,
+                      ctx_tp: int,
+                      gen_pp: int,
+                      gen_tp: int,
+                      test_sets: List[LlmapiAccuracyTestHarness],
+                      ctx_model: str = None,
+                      gen_model: str = None):
     if ctx_tp * ctx_pp + gen_tp * gen_pp > get_device_count():
         pytest.fail(
             f"Not enough devices for ctx_pp={ctx_pp}+ctx_tp={ctx_tp} and gen_pp={gen_pp}+gen_tp={gen_tp} test"
         )
 
     kv_cache_config = {
-        "free_gpu_memory_fraction": 0.5,
-        "enable_block_reuse": False
     }
     ctx_server_config = {
         "pipeline_parallel_size": ctx_pp,
@@ -270,10 +286,14 @@ def run_parallel_test(model_name: str,
         }
     }
     with launch_disaggregated_llm(disaggregated_server_config,
-                                  ctx_server_config, gen_server_config,
-                                  model_path) as llm:
-        task = test_set(model_name)
-        task.evaluate(llm)
+                                  ctx_server_config,
+                                  gen_server_config,
+                                  model_path,
+                                  ctx_model=ctx_model,
+                                  gen_model=gen_model) as llm:
+        for test_set in test_sets:
+            task = test_set(model_name)
+            task.evaluate(llm)
 
 
 @pytest.mark.timeout(3600)
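With test_sets now a list and distinct context/generation checkpoints supported, a single disaggregated deployment can be reused across several benchmarks. A minimal sketch of such a call, mirroring the new Qwen3 test below (the checkpoint-path variables fp4_ckpt and fp8_ckpt are illustrative placeholders, not part of this PR):

    # Hypothetical call into the refactored helper; fp4_ckpt/fp8_ckpt are placeholder paths.
    run_parallel_test("Qwen3/Qwen3-30B-A3B",
                      fp4_ckpt,                 # model_path: default for both servers
                      ctx_pp=2, ctx_tp=1,       # context server: 2-way pipeline parallel
                      gen_pp=1, gen_tp=2,       # generation server: 2-way tensor parallel
                      test_sets=[GSM8K, MMLU],  # both tasks reuse the same deployment
                      ctx_model=fp4_ckpt,       # checkpoint served by the context server
                      gen_model=fp8_ckpt)       # checkpoint served by the generation server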
@@ -512,7 +532,7 @@ def test_tp_pp_symmetric(self, tp, pp, testset):
         if tp * pp * 2 > get_device_count():
             pytest.skip(f"Not enough devices for tp={tp}*pp={pp} test")
         return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, pp, tp, pp,
-                                 tp, get_accuracy_task(testset))
+                                 tp, [get_accuracy_task(testset)])
 
     @parametrize_with_ids("ctx_pp", [2, 4])
     @parametrize_with_ids("gen_tp", [1, 2])
@@ -522,7 +542,7 @@ def test_ctx_pp_gen_tp_asymmetric(self, ctx_pp, gen_tp, testset):
             pytest.skip(
                 f"Not enough devices for ctx_pp={ctx_pp}*gen_tp={gen_tp} test")
         return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1,
-                                 gen_tp, get_accuracy_task(testset))
+                                 gen_tp, [get_accuracy_task(testset)])
 
 
 @pytest.mark.skip_less_device_memory(140000)
@@ -776,3 +796,25 @@ def test_auto_dtype(self, overlap_scheduler):
             task.evaluate(llm)
             task = MMLU(self.MODEL_NAME)
             task.evaluate(llm)
+
+
+@skip_pre_blackwell
+@pytest.mark.timeout(3600)
+class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness):
+    FP4_MODEL = f"{llm_models_root()}/Qwen3/saved_models_Qwen3-30B-A3B_nvfp4_hf"
+    FP8_MODEL = f"{llm_models_root()}/Qwen3/saved_models_Qwen3-30B-A3B_fp8_hf"
+
+    @pytest.mark.skip_less_device(4)
+    @pytest.mark.parametrize("ctx_pp,gen_tp", [(2, 2)], ids=["ctxpp2gentp2"])
+    def test_mixed_ctx_gen_model(self, ctx_pp, gen_tp):
+        ctx_model = self.FP4_MODEL
+        gen_model = self.FP8_MODEL
+        return run_parallel_test("Qwen3/Qwen3-30B-A3B",
+                                 ctx_model,
+                                 ctx_pp=ctx_pp,
+                                 ctx_tp=1,
+                                 gen_pp=1,
+                                 gen_tp=gen_tp,
+                                 test_sets=[GSM8K, MMLU],
+                                 ctx_model=ctx_model,
+                                 gen_model=gen_model)
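The new case pairs an NVFP4 context server with an FP8 generation server and is selected by the node ID accuracy/test_disaggregated_serving.py::TestQwen3_30B_A3B::test_mixed_ctx_gen_model[ctxpp2gentp2], which is exactly what the test-list entries below add.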
1 change: 1 addition & 0 deletions tests/integration/test_lists/qa/llm_function_sanity.txt
@@ -25,6 +25,7 @@ accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=4]
 accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[False]
 accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[True]
+accuracy/test_disaggregated_serving.py::TestQwen3_30B_A3B::test_mixed_ctx_gen_model[ctxpp2gentp2]
 accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_chunked_prefill[attn_backend=FLASHINFER]
 accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_chunked_prefill[attn_backend=TRTLLM]
 accuracy/test_llm_api_pytorch.py::TestBielik11BInstruct::test_auto_dtype
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_dgx_b200.yml
@@ -70,6 +70,7 @@ l0_dgx_b200:
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=True]
   - accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp8[tp4-cuda_graph=True]
   - accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp4[tp4-cuda_graph=True]
+  - accuracy/test_disaggregated_serving.py::TestQwen3_30B_A3B::test_mixed_ctx_gen_model[ctxpp2gentp2]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8]
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-CUTLASS]
   - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-TRTLLM]