diff --git a/tests/serve/test_dynamo_serve.py b/tests/serve/test_dynamo_serve.py
index 1808ae4f0a..45d22dc98e 100644
--- a/tests/serve/test_dynamo_serve.py
+++ b/tests/serve/test_dynamo_serve.py
@@ -25,7 +25,6 @@
     DeploymentGraph,
     Payload,
     chat_completions_response_handler,
-    completions_response_handler,
 )
 from tests.utils.managed_process import ManagedProcess
 
@@ -56,106 +55,7 @@
     expected_response=["bus"],
 )
 
-text_payload = Payload(
-    payload_chat={
-        "model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
-        "messages": [
-            {
-                "role": "user",
-                "content": text_prompt,  # Shorter prompt
-            }
-        ],
-        "max_tokens": 150,  # Reduced from 500
-        "temperature": 0.1,
-        # "seed": 0,
-    },
-    payload_completions={
-        "model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
-        "prompt": text_prompt,
-        "max_tokens": 150,
-        "temperature": 0.1,
-        # "seed": 0,
-    },
-    repeat_count=10,
-    expected_log=[],
-    expected_response=["AI"],
-)
-
 deployment_graphs = {
-    "agg": (
-        DeploymentGraph(
-            module="graphs.agg:Frontend",
-            config="configs/agg.yaml",
-            directory="/workspace/examples/llm",
-            endpoints=["v1/chat/completions"],
-            response_handlers=[
-                chat_completions_response_handler,
-            ],
-            marks=[pytest.mark.gpu_1, pytest.mark.vllm],
-        ),
-        text_payload,
-    ),
-    "sglang_agg": (
-        DeploymentGraph(
-            module="graphs.agg:Frontend",
-            config="configs/agg.yaml",
-            directory="/workspace/examples/sglang",
-            endpoints=["v1/chat/completions", "v1/completions"],
-            response_handlers=[
-                chat_completions_response_handler,
-                completions_response_handler,
-            ],
-            marks=[pytest.mark.gpu_1, pytest.mark.sglang],
-        ),
-        text_payload,
-    ),
-    "disagg": (
-        DeploymentGraph(
-            module="graphs.disagg:Frontend",
-            config="configs/disagg.yaml",
-            directory="/workspace/examples/llm",
-            endpoints=["v1/chat/completions"],
-            response_handlers=[
-                chat_completions_response_handler,
-            ],
-            marks=[pytest.mark.gpu_2, pytest.mark.vllm],
-        ),
-        text_payload,
-    ),
-    "agg_router": (
-        DeploymentGraph(
-            module="graphs.agg_router:Frontend",
-            config="configs/agg_router.yaml",
-            directory="/workspace/examples/llm",
-            endpoints=["v1/chat/completions"],
-            response_handlers=[
-                chat_completions_response_handler,
-            ],
-            marks=[pytest.mark.gpu_1, pytest.mark.vllm],
-            # FIXME: This is a hack to allow deployments to start before sending any requests.
-            # When using KV-router, if all the endpoints are not registered, the service
-            # enters a non-recoverable state.
-            delayed_start=120,
-        ),
-        text_payload,
-    ),
-    "disagg_router": (
-        DeploymentGraph(
-            module="graphs.disagg_router:Frontend",
-            config="configs/disagg_router.yaml",
-            directory="/workspace/examples/llm",
-            endpoints=["v1/chat/completions"],
-            response_handlers=[
-                chat_completions_response_handler,
-            ],
-            marks=[pytest.mark.gpu_2, pytest.mark.vllm],
-            # FIXME: This is a hack to allow deployments to start before sending any requests.
-            # When using KV-router, if all the endpoints are not registered, the service
-            # enters a non-recoverable state.
-            delayed_start=120,
-        ),
-        text_payload,
-    ),
     "multimodal_agg": (
         DeploymentGraph(
             module="graphs.agg:Frontend",
@@ -169,84 +69,6 @@
         ),
         multimodal_payload,
     ),
-    "vllm_v1_agg": (
-        DeploymentGraph(
-            module="graphs.agg:Frontend",
-            config="configs/agg.yaml",
-            directory="/workspace/examples/vllm_v1",
-            endpoints=["v1/chat/completions", "v1/completions"],
-            response_handlers=[
-                chat_completions_response_handler,
-                completions_response_handler,
-            ],
-            marks=[pytest.mark.gpu_1, pytest.mark.vllm],
-        ),
-        text_payload,
-    ),
-    "trtllm_agg": (
-        DeploymentGraph(
-            module="graphs.agg:Frontend",
-            config="configs/agg.yaml",
-            directory="/workspace/examples/tensorrt_llm",
-            endpoints=["v1/chat/completions", "v1/completions"],
-            response_handlers=[
-                chat_completions_response_handler,
-                completions_response_handler,
-            ],
-            marks=[pytest.mark.gpu_1, pytest.mark.tensorrtllm],
-        ),
-        text_payload,
-    ),
-    "trtllm_agg_router": (
-        DeploymentGraph(
-            module="graphs.agg:Frontend",
-            config="configs/agg_router.yaml",
-            directory="/workspace/examples/tensorrt_llm",
-            endpoints=["v1/chat/completions", "v1/completions"],
-            response_handlers=[
-                chat_completions_response_handler,
-                completions_response_handler,
-            ],
-            marks=[pytest.mark.gpu_1, pytest.mark.tensorrtllm],
-            # FIXME: This is a hack to allow deployments to start before sending any requests.
-            # When using KV-router, if all the endpoints are not registered, the service
-            # enters a non-recoverable state.
-            delayed_start=120,
-        ),
-        text_payload,
-    ),
-    "trtllm_disagg": (
-        DeploymentGraph(
-            module="graphs.disagg:Frontend",
-            config="configs/disagg.yaml",
-            directory="/workspace/examples/tensorrt_llm",
-            endpoints=["v1/chat/completions", "v1/completions"],
-            response_handlers=[
-                chat_completions_response_handler,
-                completions_response_handler,
-            ],
-            marks=[pytest.mark.gpu_2, pytest.mark.tensorrtllm],
-        ),
-        text_payload,
-    ),
-    "trtllm_disagg_router": (
-        DeploymentGraph(
-            module="graphs.disagg:Frontend",
-            config="configs/disagg_router.yaml",
-            directory="/workspace/examples/tensorrt_llm",
-            endpoints=["v1/chat/completions", "v1/completions"],
-            response_handlers=[
-                chat_completions_response_handler,
-                completions_response_handler,
-            ],
-            marks=[pytest.mark.gpu_2, pytest.mark.tensorrtllm],
-            # FIXME: This is a hack to allow deployments to start before sending any requests.
-            # When using KV-router, if all the endpoints are not registered, the service
-            # enters a non-recoverable state.
-            delayed_start=120,
-        ),
-        text_payload,
-    ),
 }
 
@@ -394,17 +216,6 @@ def wait_for_ready(self, payload, logger=logging.getLogger()):
 @pytest.fixture(
     params=[
         pytest.param("multimodal_agg", marks=[pytest.mark.vllm, pytest.mark.gpu_2]),
-        pytest.param("trtllm_agg", marks=[pytest.mark.tensorrtllm, pytest.mark.gpu_1]),
-        pytest.param(
-            "trtllm_agg_router", marks=[pytest.mark.tensorrtllm, pytest.mark.gpu_1]
-        ),
-        pytest.param(
-            "trtllm_disagg", marks=[pytest.mark.tensorrtllm, pytest.mark.gpu_2]
-        ),
-        pytest.param(
-            "trtllm_disagg_router", marks=[pytest.mark.tensorrtllm, pytest.mark.gpu_2]
-        ),
-        # pytest.param("sglang", marks=[pytest.mark.sglang, pytest.mark.gpu_2]),
     ]
 )
 def deployment_graph_test(request):
diff --git a/tests/serve/test_trtllm.py b/tests/serve/test_trtllm.py
new file mode 100644
index 0000000000..a57b581a43
--- /dev/null
+++ b/tests/serve/test_trtllm.py
@@ -0,0 +1,269 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import os
+import time
+from dataclasses import dataclass
+from typing import Any, Callable, List
+
+import pytest
+import requests
+
+from tests.utils.deployment_graph import (
+    Payload,
+    chat_completions_response_handler,
+    completions_response_handler,
+)
+from tests.utils.managed_process import ManagedProcess
+
+logger = logging.getLogger(__name__)
+
+text_prompt = "Tell me a short joke about AI."
+
+
+def create_payload_for_config(config: "TRTLLMConfig") -> Payload:
+    """Create a payload using the model from the trtllm config"""
+    return Payload(
+        payload_chat={
+            "model": config.model,
+            "messages": [
+                {
+                    "role": "user",
+                    "content": text_prompt,
+                }
+            ],
+            "max_tokens": 150,
+            "temperature": 0.1,
+        },
+        payload_completions={
+            "model": config.model,
+            "prompt": text_prompt,
+            "max_tokens": 150,
+            "temperature": 0.1,
+        },
+        repeat_count=1,
+        expected_log=[],
+        expected_response=["AI"],
+    )
+
+
+# TODO: Unify with vllm/sglang tests to reduce code duplication
+@dataclass
+class TRTLLMConfig:
+    """Configuration for trtllm test scenarios"""
+
+    name: str
+    directory: str
+    script_name: str
+    marks: List[Any]
+    endpoints: List[str]
+    response_handlers: List[Callable[[Any], str]]
+    model: str
+    timeout: int = 60
+    delayed_start: int = 0
+
+
+class TRTLLMProcess(ManagedProcess):
+    """Simple process manager for trtllm shell scripts"""
+
+    def __init__(self, config: TRTLLMConfig, request):
+        self.port = 8000
+        self.config = config
+        self.dir = config.directory
+        script_path = os.path.join(self.dir, "launch", config.script_name)
+
+        if not os.path.exists(script_path):
+            raise FileNotFoundError(f"trtllm script not found: {script_path}")
+
+        # Set these env vars so the launch script serves the model this test expects
+        os.environ["MODEL_PATH"] = config.model
+        os.environ["SERVED_MODEL_NAME"] = config.model
+
+        command = ["bash", script_path]
+
+        super().__init__(
+            command=command,
+            timeout=config.timeout,
+            display_output=True,
+            working_dir=self.dir,
+            health_check_ports=[],  # Disable port health check
+            health_check_urls=[
+                (f"http://localhost:{self.port}/v1/models", self._check_models_api)
+            ],
+            delayed_start=config.delayed_start,
+            terminate_existing=False,  # If True, this would kill all bash processes, including the test runner itself
+            stragglers=[],  # Don't kill any stragglers automatically
+            log_dir=request.node.name,
+        )
+
+    def _check_models_api(self, response):
+        """Check that the models API responds and lists at least one model"""
+        try:
+            if response.status_code != 200:
+                return False
+            data = response.json()
+            return data.get("data") and len(data["data"]) > 0
+        except Exception:
+            return False
+
+    def _check_url(self, url, timeout=30, sleep=2.0):
+        """Override to use a more reasonable retry interval"""
+        return super()._check_url(url, timeout, sleep)
+
+    def check_response(
+        self, payload, response, response_handler, logger=logging.getLogger()
+    ):
+        assert response.status_code == 200, "Response Error"
+        content = response_handler(response)
+        logger.info(f"Received Content: {content}")
+        # Check for expected responses
+        assert content, "Empty response content"
+        for expected in payload.expected_response:
+            assert expected in content, f"Expected '{expected}' not found in response"
+
+    def wait_for_ready(self, payload, logger=logging.getLogger()):
+        url = f"http://localhost:{self.port}/{self.config.endpoints[0]}"
+        start_time = time.time()
+        retry_delay = 5
+        elapsed = 0.0
+        logger.info("Waiting for Deployment Ready")
+        json_payload = (
+            payload.payload_chat
+            if self.config.endpoints[0] == "v1/chat/completions"
+            else payload.payload_completions
+        )
+
+        while (elapsed := time.time() - start_time) < self.config.timeout:
+            try:
+                response = requests.post(
+                    url,
+                    json=json_payload,
+                    timeout=self.config.timeout - elapsed,
+                )
+            except (requests.RequestException, requests.Timeout) as e:
+                logger.warning(f"Retrying after request failure: {e}")
+                time.sleep(retry_delay)
+                continue
+            logger.info(f"Response: {response}")
+            if response.status_code == 500:
+                error = response.json().get("error", "")
+                if "no instances" in error:
+                    logger.warning(
+                        f"Retrying because no instances are available yet for model '{self.config.model}'"
+                    )
+                    time.sleep(retry_delay)
+                    continue
+            if response.status_code == 404:
+                error = response.json().get("error", "")
+                if "Model not found" in error:
+                    logger.warning(
+                        f"Retrying because model '{self.config.model}' is not registered yet"
+                    )
+                    time.sleep(retry_delay)
+                    continue
+            # Process the response
+            if response.status_code != 200:
+                pytest.fail(
+                    f"Service returned status code {response.status_code}: {response.text}"
+                )
+            else:
+                break
+        else:
+            pytest.fail(
+                f"Service did not return a successful response within {self.config.timeout} s"
+            )
+
+        self.check_response(payload, response, self.config.response_handlers[0], logger)
+
+        logger.info("Deployment Ready")
+
+
+# trtllm test configurations
+trtllm_configs = {
+    "aggregated": TRTLLMConfig(
+        name="aggregated",
+        directory="/workspace/components/backends/trtllm",
+        script_name="agg.sh",
+        marks=[pytest.mark.gpu_1, pytest.mark.tensorrtllm],
+        endpoints=["v1/chat/completions", "v1/completions"],
+        response_handlers=[
+            chat_completions_response_handler,
+            completions_response_handler,
+        ],
+        model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
+        delayed_start=45,
+    ),
+    "disaggregated": TRTLLMConfig(
+        name="disaggregated",
+        directory="/workspace/components/backends/trtllm",
+        script_name="disagg.sh",
+        marks=[pytest.mark.gpu_2, pytest.mark.tensorrtllm],
+        endpoints=["v1/chat/completions", "v1/completions"],
+        response_handlers=[
+            chat_completions_response_handler,
+            completions_response_handler,
+        ],
+        model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
+        delayed_start=45,
+    ),
+}
+
+
+@pytest.fixture(
+    params=[
+        pytest.param(config_name, marks=config.marks)
+        for config_name, config in trtllm_configs.items()
+    ]
+)
+def trtllm_config_test(request):
+    """Fixture that provides different trtllm test configurations"""
+    return trtllm_configs[request.param]
+
+
+@pytest.mark.e2e
+@pytest.mark.slow
+def test_deployment(trtllm_config_test, request, runtime_services):
+    """
+    Test dynamo deployments with different configurations.
+ """ + + # runtime_services is used to start nats and etcd + + logger = logging.getLogger(request.node.name) + logger.info("Starting test_deployment") + + config = trtllm_config_test + payload = create_payload_for_config(config) + + logger.info(f"Using model: {config.model}") + logger.info(f"Script: {config.script_name}") + + with TRTLLMProcess(config, request) as server_process: + server_process.wait_for_ready(payload, logger) + + assert len(config.endpoints) == len(config.response_handlers) + for endpoint, response_handler in zip( + config.endpoints, config.response_handlers + ): + url = f"http://localhost:{server_process.port}/{endpoint}" + start_time = time.time() + elapsed = 0.0 + + request_body = ( + payload.payload_chat + if endpoint == "v1/chat/completions" + else payload.payload_completions + ) + + for _ in range(payload.repeat_count): + elapsed = time.time() - start_time + + response = requests.post( + url, + json=request_body, + timeout=config.timeout - elapsed, + ) + server_process.check_response( + payload, response, response_handler, logger + )