From cea6c8103b440f61339233490f2cef5893f036fc Mon Sep 17 00:00:00 2001 From: "tzulingk@nvidia.com" Date: Thu, 9 Oct 2025 20:57:40 -0700 Subject: [PATCH 01/13] Catch trtllm engine exceptions Signed-off-by: tzulingk@nvidia.com --- components/src/dynamo/trtllm/main.py | 5 +- .../trtllm/request_handlers/handler_base.py | 179 ++++++++++++------ 2 files changed, 123 insertions(+), 61 deletions(-) diff --git a/components/src/dynamo/trtllm/main.py b/components/src/dynamo/trtllm/main.py index 93f6300b59..7f5515dfd2 100644 --- a/components/src/dynamo/trtllm/main.py +++ b/components/src/dynamo/trtllm/main.py @@ -349,6 +349,7 @@ async def init(runtime: DistributedRuntime, config: Config): encode_client=encode_client, multimodal_processor=multimodal_processor, connector=connector, + runtime=runtime, # Pass runtime for graceful shutdown ) if next_client: @@ -392,7 +393,9 @@ async def init(runtime: DistributedRuntime, config: Config): metrics_labels, ) as publisher: handler_config.publisher = publisher - handler = RequestHandlerFactory().get_request_handler(handler_config) + handler = RequestHandlerFactory().get_request_handler( + handler_config + ) await endpoint.serve_endpoint( handler.generate, metrics_labels=metrics_labels, diff --git a/components/src/dynamo/trtllm/request_handlers/handler_base.py b/components/src/dynamo/trtllm/request_handlers/handler_base.py index 28ef479e85..e1374a9d7c 100644 --- a/components/src/dynamo/trtllm/request_handlers/handler_base.py +++ b/components/src/dynamo/trtllm/request_handlers/handler_base.py @@ -24,12 +24,14 @@ import torch from tensorrt_llm.executor.result import GenerationResult +from tensorrt_llm.executor.utils import RequestError from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams from tensorrt_llm.llmapi.llm import SamplingParams from dynamo._core import Context from dynamo.logits_processing.examples import HelloWorldLogitsProcessor from dynamo.nixl_connect import Connector +from dynamo.runtime import DistributedRuntime from dynamo.runtime.logging import configure_dynamo_logging from dynamo.trtllm.engine import TensorRTLLMEngine from dynamo.trtllm.logits_processing.adapter import create_trtllm_adapters @@ -74,6 +76,9 @@ class RequestHandlerConfig: MultimodalRequestProcessor ] = None # for multimodal support connector: Optional[Connector] = None + runtime: Optional[ + DistributedRuntime + ] = None # DistributedRuntime reference for graceful shutdown class HandlerBase: @@ -94,6 +99,8 @@ def __init__(self, config: RequestHandlerConfig): self.multimodal_processor = config.multimodal_processor self.first_generation = True self.connector = config.connector + # Store runtime reference for graceful shutdown + self.runtime = config.runtime def check_error(self, result: dict): """ @@ -148,6 +155,24 @@ async def _cancellation_monitor( except asyncio.CancelledError: pass + async def _initiate_shutdown(self, error: Exception): + """Initiate graceful shutdown after fatal error""" + logging.warning(f"Initiating graceful shutdown due to: {error}") + + try: + if self.runtime: + logging.info("Shutting down Dynamo runtime...") + self.runtime.shutdown() + + if self.engine: + logging.info("Shutting down TensorRT-LLM engine...") + await self.engine.cleanup() + except Exception as cleanup_error: + logging.error(f"Error during graceful shutdown: {cleanup_error}") + finally: + logging.critical("Forcing process exit for restart") + os._exit(1) + async def generate_locally( self, request: dict, @@ -243,66 +268,100 @@ async def generate_locally( adapters = create_trtllm_adapters(processors) sampling_params.logits_processor = adapters - # NEW: Updated engine call to include multimodal data - generation_result = self.engine.llm.generate_async( - inputs=processed_input, # Use the correctly extracted inputs - sampling_params=sampling_params, - disaggregated_params=disaggregated_params, - streaming=streaming, - ) + try: + # NEW: Updated engine call to include multimodal data + generation_result = self.engine.llm.generate_async( + inputs=processed_input, # Use the correctly extracted inputs + sampling_params=sampling_params, + disaggregated_params=disaggregated_params, + streaming=streaming, + ) - # Use the context manager to handle cancellation monitoring - async with self._cancellation_monitor(generation_result, context): - async for res in generation_result: - # TRTLLM engine needs to start generating tokens first before stats - # can be retrieved. - if self.first_generation and self.publisher: - self.publisher.start() - self.first_generation = False - - # Upon completion, send a final chunk with "stop" as the finish reason. - # This signals to the client that the stream has ended. - if ( - res.finished - and self.disaggregation_mode != DisaggregationMode.PREFILL - ): + # Use the context manager to handle cancellation monitoring + async with self._cancellation_monitor(generation_result, context): + async for res in generation_result: + # TRTLLM engine needs to start generating tokens first before stats + # can be retrieved. + if self.first_generation and self.publisher: + self.publisher.start() + self.first_generation = False + + # Upon completion, send a final chunk with "stop" as the finish reason. + # This signals to the client that the stream has ended. + if ( + res.finished + and self.disaggregation_mode != DisaggregationMode.PREFILL + ): + if self.multimodal_processor: + final_out = self.multimodal_processor.get_stop_response( + request_id, model_name + ) + yield final_out + + # If we are not done generating, but there are no outputs, return an error + if not res.outputs and not res.finished: + yield {"finish_reason": "error", "token_ids": []} + break + + output = res.outputs[0] + # The engine returns all tokens generated so far. We must calculate the new + # tokens generated in this iteration to create the "delta". + next_total_toks = len(output.token_ids) if self.multimodal_processor: - final_out = self.multimodal_processor.get_stop_response( - request_id, model_name + out = self.multimodal_processor.create_response_chunk( + output, num_output_tokens_so_far, request_id, model_name + ) + else: + out = {"token_ids": output.token_ids[num_output_tokens_so_far:]} + if output.finish_reason: + out["finish_reason"] = output.finish_reason + if output.stop_reason: + out["stop_reason"] = output.stop_reason + if self.disaggregation_mode == DisaggregationMode.PREFILL: + # Return the disaggregated params only when operating in prefill mode. + out["disaggregated_params"] = asdict( + DisaggregatedParamsCodec.encode(output.disaggregated_params) ) - yield final_out - - # If we are not done generating, but there are no outputs, return an error - if not res.outputs and not res.finished: - yield {"finish_reason": "error", "token_ids": []} - break - - output = res.outputs[0] - # The engine returns all tokens generated so far. We must calculate the new - # tokens generated in this iteration to create the "delta". - next_total_toks = len(output.token_ids) - if self.multimodal_processor: - out = self.multimodal_processor.create_response_chunk( - output, num_output_tokens_so_far, request_id, model_name - ) - else: - out = {"token_ids": output.token_ids[num_output_tokens_so_far:]} - if output.finish_reason: - out["finish_reason"] = output.finish_reason - if output.stop_reason: - out["stop_reason"] = output.stop_reason - if self.disaggregation_mode == DisaggregationMode.PREFILL: - # Return the disaggregated params only when operating in prefill mode. - out["disaggregated_params"] = asdict( - DisaggregatedParamsCodec.encode(output.disaggregated_params) - ) - - if res.finished and not out.get("finish_reason"): - out["finish_reason"] = "unknown" - logging.warning( - "Request finished with no finish reason set - this indicates a possible bug" - ) - - # Yield the chunk to the client and update the token count for the next iteration. - yield out - num_output_tokens_so_far = next_total_toks + + if res.finished and not out.get("finish_reason"): + out["finish_reason"] = "unknown" + logging.warning( + "Request finished with no finish reason set - this indicates a possible bug" + ) + + # Yield the chunk to the client and update the token count for the next iteration. + yield out + num_output_tokens_so_far = next_total_toks + + # 1. Client cancellation - don't shutdown + except asyncio.CancelledError: + logging.debug(f"Request {request_id}: Client cancelled") + # _cancellation_monitor already called abort_request + return # Just stop, no error response + + # 2. Per-request errors - send to client, don't shutdown + except RequestError as e: + logging.warning(f"Request {request_id} error: {e}") + yield {"finish_reason": "error", "error": str(e), "token_ids": []} + + # 3. ALL OTHER ERRORS - graceful shutdown + except Exception as e: + error_type = type(e).__name__ + error_msg = str(e) + logging.error( + f"Fatal {error_type} in request {request_id}: {error_msg}", + exc_info=True, + ) + + # Try to send error to client before shutdown + try: + yield { + "finish_reason": "error", + "error": "Internal error - service restarting", + "token_ids": [], + } + except: + pass # Best effort + + # Initiate graceful shutdown + await self._initiate_shutdown(e) From 4c6f286813e0ac31b972f3bd91980a8716631765 Mon Sep 17 00:00:00 2001 From: "tzulingk@nvidia.com" Date: Thu, 9 Oct 2025 22:47:52 -0700 Subject: [PATCH 02/13] Add test cases Signed-off-by: tzulingk@nvidia.com --- components/src/dynamo/trtllm/main.py | 10 +- .../src/dynamo/trtllm/test_handler_base.py | 250 ++++++++++++++++++ 2 files changed, 254 insertions(+), 6 deletions(-) create mode 100644 components/src/dynamo/trtllm/test_handler_base.py diff --git a/components/src/dynamo/trtllm/main.py b/components/src/dynamo/trtllm/main.py index 7f5515dfd2..0f54bf1125 100644 --- a/components/src/dynamo/trtllm/main.py +++ b/components/src/dynamo/trtllm/main.py @@ -243,9 +243,9 @@ async def init(runtime: DistributedRuntime, config: Config): else: kv_cache_config = arg_map["kv_cache_config"] if "event_buffer_max_size" not in kv_cache_config: - kv_cache_config[ - "event_buffer_max_size" - ] = DEFAULT_KV_EVENT_BUFFER_MAX_SIZE + kv_cache_config["event_buffer_max_size"] = ( + DEFAULT_KV_EVENT_BUFFER_MAX_SIZE + ) arg_map["kv_cache_config"] = kv_cache_config # Only pytorch backend is supported for now to publish events and metrics. @@ -393,9 +393,7 @@ async def init(runtime: DistributedRuntime, config: Config): metrics_labels, ) as publisher: handler_config.publisher = publisher - handler = RequestHandlerFactory().get_request_handler( - handler_config - ) + handler = RequestHandlerFactory().get_request_handler(handler_config) await endpoint.serve_endpoint( handler.generate, metrics_labels=metrics_labels, diff --git a/components/src/dynamo/trtllm/test_handler_base.py b/components/src/dynamo/trtllm/test_handler_base.py new file mode 100644 index 0000000000..af9bd5ab76 --- /dev/null +++ b/components/src/dynamo/trtllm/test_handler_base.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python +""" +Test runner for handler_base error handling. +Run with: python test_handler_base.py + +This script mocks heavy dependencies before importing handler_base to test error handling. +""" +# type: ignore # This file uses dynamic mocking which confuses mypy + +import sys +import os +import asyncio +from unittest.mock import MagicMock, AsyncMock, patch +from typing import Any, AsyncGenerator + +# Add both the current directory and the components/src directory to the Python path +current_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, current_dir) # Add current directory +sys.path.insert(0, os.path.join(current_dir, "../../..")) # Add components/src + +# Mock all heavy dependencies BEFORE importing handler_base +sys.modules["torch"] = MagicMock() +sys.modules["tensorrt_llm"] = MagicMock() +sys.modules["tensorrt_llm.executor"] = MagicMock() +sys.modules["tensorrt_llm.executor.result"] = MagicMock() +sys.modules["tensorrt_llm.executor.utils"] = MagicMock() +sys.modules["tensorrt_llm.llmapi"] = MagicMock() +sys.modules["tensorrt_llm.llmapi.llm"] = MagicMock() + + +# Create RequestError exception class +class RequestError(Exception): + """Mock RequestError from tensorrt_llm.executor.utils""" + + pass + + +sys.modules["tensorrt_llm.executor.utils"].RequestError = RequestError # type: ignore[attr-defined] + +# Mock other dynamo modules - but NOT dynamo.trtllm.request_handlers since we want to import the real one +sys.modules["dynamo._core"] = MagicMock() +sys.modules["dynamo.logits_processing"] = MagicMock() +sys.modules["dynamo.logits_processing.examples"] = MagicMock() +sys.modules["dynamo.nixl_connect"] = MagicMock() +sys.modules["dynamo.runtime"] = MagicMock() +sys.modules["dynamo.runtime.logging"] = MagicMock() +sys.modules["dynamo.trtllm.engine"] = MagicMock() +sys.modules["dynamo.trtllm.logits_processing"] = MagicMock() +sys.modules["dynamo.trtllm.logits_processing.adapter"] = MagicMock() +sys.modules["dynamo.trtllm.multimodal_processor"] = MagicMock() +sys.modules["dynamo.trtllm.publisher"] = MagicMock() +sys.modules["dynamo.trtllm.utils"] = MagicMock() +sys.modules["dynamo.trtllm.utils.disagg_utils"] = MagicMock() + + +# Mock Context class +class Context: + """Mock Context from dynamo._core""" + + def __init__(self, request_id: str) -> None: + self._id = request_id + self._cancelled: asyncio.Future[None] = asyncio.Future() + self._cancelled.set_result(None) + + def id(self) -> str: + return self._id + + def cancelled(self) -> asyncio.Future[None]: + return self._cancelled + + +sys.modules["dynamo._core"].Context = Context # type: ignore[attr-defined] + +# Import handler_base directly from its location +from request_handlers.handler_base import ( + HandlerBase, + RequestHandlerConfig, + DisaggregationMode, + DisaggregationStrategy, +) + +import pytest + + +class TestHandlerBase: + """Tests for HandlerBase error handling""" + + def create_mock_config(self, with_runtime=True): + """Helper to create a mock RequestHandlerConfig""" + mock_engine = MagicMock() + mock_engine.cleanup = AsyncMock() + mock_engine.llm = MagicMock() # Add llm attribute + + runtime = None + if with_runtime: + runtime = MagicMock() + runtime.shutdown = MagicMock() + + mock_component = MagicMock() + mock_component.rank = 0 + + config = RequestHandlerConfig( + component=mock_component, + engine=mock_engine, + default_sampling_params=MagicMock(), + publisher=None, + runtime=runtime, + disaggregation_mode=DisaggregationMode.AGGREGATED, + disaggregation_strategy=DisaggregationStrategy.PREFILL_FIRST, + next_client=None, + next_router_client=None, + encode_client=None, + multimodal_processor=None, + connector=None, + ) + return config + + def create_mock_generation_result(self, exception_to_raise=None): + """Helper to create a mock generation result""" + mock_gen_result = MagicMock() + + async def mock_generator(self): + # Create a mock result that matches what handler_base expects + mock_res = MagicMock() + mock_res.finished = False + mock_output = MagicMock() + mock_output.token_ids = [1, 2, 3] + mock_output.finish_reason = None + mock_output.stop_reason = None + mock_res.outputs = [mock_output] + yield mock_res + + if exception_to_raise: + raise exception_to_raise + + mock_gen_result.__aiter__ = mock_generator + return mock_gen_result + + def get_test_request(self): + """Helper to get a standard test request""" + return { + "prompt": "test", + "sampling_options": {}, + "stop_conditions": {"max_tokens": 10}, + "trace": {"service_name": "test"}, + "tokens": [1, 2, 3], # Mock tokens + } + + @pytest.mark.asyncio + async def test_request_error_no_shutdown(self): + """Test that RequestError doesn't trigger shutdown""" + # Setup + config = self.create_mock_config(with_runtime=True) + mock_engine = config.engine + mock_runtime = config.runtime + + handler = HandlerBase(config) + mock_context = Context("test-request-123") + + # Mock engine to raise RequestError after yielding + mock_gen_result = self.create_mock_generation_result( + exception_to_raise=RequestError("Invalid request parameters") + ) + mock_engine.llm.generate_async = lambda *args, **kwargs: mock_gen_result + + # Run test + request = self.get_test_request() + responses = [] + async for response in handler.generate_locally(request, mock_context): + responses.append(response) + + # Verify + assert len(responses) == 2, f"Expected 2 responses, got {len(responses)}" + assert responses[0]["token_ids"] == [1, 2, 3] + assert responses[1]["finish_reason"] == "error" + assert "Invalid request" in responses[1]["error"] + + # Critical: NO shutdown should be called + mock_runtime.shutdown.assert_not_called() + mock_engine.cleanup.assert_not_called() + + @pytest.mark.asyncio + async def test_generic_exception_triggers_shutdown(self): + """Test that generic exceptions trigger graceful shutdown""" + # Setup + config = self.create_mock_config(with_runtime=True) + mock_engine = config.engine + mock_runtime = config.runtime + + handler = HandlerBase(config) + mock_context = Context("test-request-456") + + # Mock engine to raise RuntimeError + mock_gen_result = self.create_mock_generation_result( + exception_to_raise=RuntimeError("Engine CUDA out of memory") + ) + mock_engine.llm.generate_async = lambda *args, **kwargs: mock_gen_result + + # Run test with mocked os._exit + with patch("os._exit") as mock_exit: + request = self.get_test_request() + responses = [] + async for response in handler.generate_locally(request, mock_context): + responses.append(response) + + # Verify error response was sent + assert len(responses) == 2 + assert responses[1]["finish_reason"] == "error" + assert "service restarting" in responses[1]["error"].lower() + + # Critical: Shutdown SHOULD be called + mock_runtime.shutdown.assert_called_once() + mock_engine.cleanup.assert_called_once() + mock_exit.assert_called_once_with(1) + + @pytest.mark.asyncio + async def test_cancelled_error_no_shutdown(self): + """Test that CancelledError doesn't trigger shutdown""" + # Setup + config = self.create_mock_config(with_runtime=True) + mock_engine = config.engine + mock_runtime = config.runtime + + handler = HandlerBase(config) + mock_context = Context("test-request-789") + + # Mock engine to raise CancelledError + mock_gen_result = self.create_mock_generation_result( + exception_to_raise=asyncio.CancelledError("Client disconnected") + ) + mock_engine.llm.generate_async = lambda *args, **kwargs: mock_gen_result + + # Run test + request = self.get_test_request() + responses = [] + async for response in handler.generate_locally(request, mock_context): + responses.append(response) + + # Should only have the first response (no error response) + assert len(responses) == 1 + assert responses[0]["token_ids"] == [1, 2, 3] + + # Critical: NO shutdown should be called + mock_runtime.shutdown.assert_not_called() + mock_engine.cleanup.assert_not_called() + + +if __name__ == "__main__": + # Allow running with python test_handler_base.py + pytest.main([__file__, "-v"]) From 93feac5e93f3f7351d495e29908f89ba5b61caf8 Mon Sep 17 00:00:00 2001 From: "tzulingk@nvidia.com" Date: Thu, 9 Oct 2025 22:59:05 -0700 Subject: [PATCH 03/13] Add header Signed-off-by: tzulingk@nvidia.com --- components/src/dynamo/trtllm/test_handler_base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/components/src/dynamo/trtllm/test_handler_base.py b/components/src/dynamo/trtllm/test_handler_base.py index af9bd5ab76..a6f027f15e 100644 --- a/components/src/dynamo/trtllm/test_handler_base.py +++ b/components/src/dynamo/trtllm/test_handler_base.py @@ -1,4 +1,7 @@ #!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + """ Test runner for handler_base error handling. Run with: python test_handler_base.py From 15292365ec2dd14497cacb995c8f3645dede893b Mon Sep 17 00:00:00 2001 From: "tzulingk@nvidia.com" Date: Thu, 9 Oct 2025 23:11:19 -0700 Subject: [PATCH 04/13] fix format Signed-off-by: tzulingk@nvidia.com --- .../trtllm/request_handlers/handler_base.py | 14 +++++++------- .../src/dynamo/trtllm/test_handler_base.py | 18 ++++++++---------- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/components/src/dynamo/trtllm/request_handlers/handler_base.py b/components/src/dynamo/trtllm/request_handlers/handler_base.py index e1374a9d7c..e3578af7a5 100644 --- a/components/src/dynamo/trtllm/request_handlers/handler_base.py +++ b/components/src/dynamo/trtllm/request_handlers/handler_base.py @@ -72,13 +72,13 @@ class RequestHandlerConfig: next_client: object next_router_client: Optional[object] = None encode_client: Optional[object] = None - multimodal_processor: Optional[ - MultimodalRequestProcessor - ] = None # for multimodal support + multimodal_processor: Optional[MultimodalRequestProcessor] = ( + None # for multimodal support + ) connector: Optional[Connector] = None - runtime: Optional[ - DistributedRuntime - ] = None # DistributedRuntime reference for graceful shutdown + runtime: Optional[DistributedRuntime] = ( + None # DistributedRuntime reference for graceful shutdown + ) class HandlerBase: @@ -360,7 +360,7 @@ async def generate_locally( "error": "Internal error - service restarting", "token_ids": [], } - except: + except Exception: pass # Best effort # Initiate graceful shutdown diff --git a/components/src/dynamo/trtllm/test_handler_base.py b/components/src/dynamo/trtllm/test_handler_base.py index a6f027f15e..fa5506c307 100644 --- a/components/src/dynamo/trtllm/test_handler_base.py +++ b/components/src/dynamo/trtllm/test_handler_base.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 @@ -10,11 +9,10 @@ """ # type: ignore # This file uses dynamic mocking which confuses mypy -import sys -import os import asyncio -from unittest.mock import MagicMock, AsyncMock, patch -from typing import Any, AsyncGenerator +import os +import sys +from unittest.mock import AsyncMock, MagicMock, patch # Add both the current directory and the components/src directory to the Python path current_dir = os.path.dirname(os.path.abspath(__file__)) @@ -74,16 +72,16 @@ def cancelled(self) -> asyncio.Future[None]: sys.modules["dynamo._core"].Context = Context # type: ignore[attr-defined] +import pytest # noqa: E402 + # Import handler_base directly from its location -from request_handlers.handler_base import ( - HandlerBase, - RequestHandlerConfig, +from request_handlers.handler_base import ( # noqa: E402 DisaggregationMode, DisaggregationStrategy, + HandlerBase, + RequestHandlerConfig, ) -import pytest - class TestHandlerBase: """Tests for HandlerBase error handling""" From 5d8157ef62077a4bf6e4a6120138416c04513544 Mon Sep 17 00:00:00 2001 From: "tzulingk@nvidia.com" Date: Thu, 9 Oct 2025 23:23:33 -0700 Subject: [PATCH 05/13] copilot fix Signed-off-by: tzulingk@nvidia.com --- components/src/dynamo/trtllm/test_handler_base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/components/src/dynamo/trtllm/test_handler_base.py b/components/src/dynamo/trtllm/test_handler_base.py index fa5506c307..b2be0bf62b 100644 --- a/components/src/dynamo/trtllm/test_handler_base.py +++ b/components/src/dynamo/trtllm/test_handler_base.py @@ -120,7 +120,7 @@ def create_mock_generation_result(self, exception_to_raise=None): """Helper to create a mock generation result""" mock_gen_result = MagicMock() - async def mock_generator(self): + async def mock_generator(): # Create a mock result that matches what handler_base expects mock_res = MagicMock() mock_res.finished = False @@ -134,7 +134,9 @@ async def mock_generator(self): if exception_to_raise: raise exception_to_raise - mock_gen_result.__aiter__ = mock_generator + # __aiter__ should be a method that returns the async generator + # The lambda needs to accept self (passed by MagicMock) but ignore it + mock_gen_result.__aiter__ = lambda self: mock_generator() return mock_gen_result def get_test_request(self): From 0f5c6db94dfa16adab03dd0e0f28849e8d8250be Mon Sep 17 00:00:00 2001 From: "tzulingk@nvidia.com" Date: Thu, 9 Oct 2025 23:26:20 -0700 Subject: [PATCH 06/13] chmod +x components/src/dynamo/trtllm/test_handler_base.py Signed-off-by: tzulingk@nvidia.com --- components/src/dynamo/trtllm/test_handler_base.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 components/src/dynamo/trtllm/test_handler_base.py diff --git a/components/src/dynamo/trtllm/test_handler_base.py b/components/src/dynamo/trtllm/test_handler_base.py old mode 100644 new mode 100755 From 1a4829c86bdbd2ad797a47b2c7f861adb89f8728 Mon Sep 17 00:00:00 2001 From: "tzulingk@nvidia.com" Date: Thu, 9 Oct 2025 23:37:34 -0700 Subject: [PATCH 07/13] add async_killed_or_stopped for mock Context Signed-off-by: tzulingk@nvidia.com --- components/src/dynamo/trtllm/test_handler_base.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/components/src/dynamo/trtllm/test_handler_base.py b/components/src/dynamo/trtllm/test_handler_base.py index b2be0bf62b..45e74a939c 100755 --- a/components/src/dynamo/trtllm/test_handler_base.py +++ b/components/src/dynamo/trtllm/test_handler_base.py @@ -62,6 +62,8 @@ def __init__(self, request_id: str) -> None: self._id = request_id self._cancelled: asyncio.Future[None] = asyncio.Future() self._cancelled.set_result(None) + # For async_killed_or_stopped - a future that never completes + self._killed_or_stopped: asyncio.Future[None] = asyncio.Future() def id(self) -> str: return self._id @@ -69,6 +71,11 @@ def id(self) -> str: def cancelled(self) -> asyncio.Future[None]: return self._cancelled + async def async_killed_or_stopped(self) -> None: + # The hanging behavior ensures the cancellation monitor stays "dormant" + # and doesn't interfere with our test scenarios. + await self._killed_or_stopped + sys.modules["dynamo._core"].Context = Context # type: ignore[attr-defined] From e0f97e99af52eaf525f19af2db80e4eb9c807152 Mon Sep 17 00:00:00 2001 From: "tzulingk@nvidia.com" Date: Thu, 9 Oct 2025 23:40:30 -0700 Subject: [PATCH 08/13] format Signed-off-by: tzulingk@nvidia.com --- components/src/dynamo/trtllm/main.py | 6 +++--- .../dynamo/trtllm/request_handlers/handler_base.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/components/src/dynamo/trtllm/main.py b/components/src/dynamo/trtllm/main.py index 0f54bf1125..4fe4f36199 100644 --- a/components/src/dynamo/trtllm/main.py +++ b/components/src/dynamo/trtllm/main.py @@ -243,9 +243,9 @@ async def init(runtime: DistributedRuntime, config: Config): else: kv_cache_config = arg_map["kv_cache_config"] if "event_buffer_max_size" not in kv_cache_config: - kv_cache_config["event_buffer_max_size"] = ( - DEFAULT_KV_EVENT_BUFFER_MAX_SIZE - ) + kv_cache_config[ + "event_buffer_max_size" + ] = DEFAULT_KV_EVENT_BUFFER_MAX_SIZE arg_map["kv_cache_config"] = kv_cache_config # Only pytorch backend is supported for now to publish events and metrics. diff --git a/components/src/dynamo/trtllm/request_handlers/handler_base.py b/components/src/dynamo/trtllm/request_handlers/handler_base.py index e3578af7a5..d88b441950 100644 --- a/components/src/dynamo/trtllm/request_handlers/handler_base.py +++ b/components/src/dynamo/trtllm/request_handlers/handler_base.py @@ -72,13 +72,13 @@ class RequestHandlerConfig: next_client: object next_router_client: Optional[object] = None encode_client: Optional[object] = None - multimodal_processor: Optional[MultimodalRequestProcessor] = ( - None # for multimodal support - ) + multimodal_processor: Optional[ + MultimodalRequestProcessor + ] = None # for multimodal support connector: Optional[Connector] = None - runtime: Optional[DistributedRuntime] = ( - None # DistributedRuntime reference for graceful shutdown - ) + runtime: Optional[ + DistributedRuntime + ] = None # DistributedRuntime reference for graceful shutdown class HandlerBase: From ee288b7c9ac49e4c3eaa778f855f906b3a9a7718 Mon Sep 17 00:00:00 2001 From: "tzulingk@nvidia.com" Date: Thu, 9 Oct 2025 23:45:41 -0700 Subject: [PATCH 09/13] chmod -x test_handler_base.py Signed-off-by: tzulingk@nvidia.com --- components/src/dynamo/trtllm/test_handler_base.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100755 => 100644 components/src/dynamo/trtllm/test_handler_base.py diff --git a/components/src/dynamo/trtllm/test_handler_base.py b/components/src/dynamo/trtllm/test_handler_base.py old mode 100755 new mode 100644 From 15855fa2db5a7e1d09c28812cb835fb6ca37b076 Mon Sep 17 00:00:00 2001 From: "tzulingk@nvidia.com" Date: Mon, 13 Oct 2025 19:24:28 -0700 Subject: [PATCH 10/13] Remove error field Signed-off-by: tzulingk@nvidia.com --- components/src/dynamo/trtllm/request_handlers/handler_base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/components/src/dynamo/trtllm/request_handlers/handler_base.py b/components/src/dynamo/trtllm/request_handlers/handler_base.py index d88b441950..a25449394c 100644 --- a/components/src/dynamo/trtllm/request_handlers/handler_base.py +++ b/components/src/dynamo/trtllm/request_handlers/handler_base.py @@ -342,7 +342,7 @@ async def generate_locally( # 2. Per-request errors - send to client, don't shutdown except RequestError as e: logging.warning(f"Request {request_id} error: {e}") - yield {"finish_reason": "error", "error": str(e), "token_ids": []} + yield {"finish_reason": "error", "token_ids": []} # 3. ALL OTHER ERRORS - graceful shutdown except Exception as e: @@ -357,7 +357,6 @@ async def generate_locally( try: yield { "finish_reason": "error", - "error": "Internal error - service restarting", "token_ids": [], } except Exception: From 77ab2498d6626df17e01bb3c293e80fdd64fd835 Mon Sep 17 00:00:00 2001 From: "tzulingk@nvidia.com" Date: Mon, 13 Oct 2025 22:04:49 -0700 Subject: [PATCH 11/13] Restore the sys.modules. Signed-off-by: tzulingk@nvidia.com --- .../src/dynamo/trtllm/test_handler_base.py | 56 +++++++++++++++++-- 1 file changed, 52 insertions(+), 4 deletions(-) diff --git a/components/src/dynamo/trtllm/test_handler_base.py b/components/src/dynamo/trtllm/test_handler_base.py index 45e74a939c..b3880ba300 100644 --- a/components/src/dynamo/trtllm/test_handler_base.py +++ b/components/src/dynamo/trtllm/test_handler_base.py @@ -10,6 +10,7 @@ # type: ignore # This file uses dynamic mocking which confuses mypy import asyncio +import atexit import os import sys from unittest.mock import AsyncMock, MagicMock, patch @@ -19,6 +20,35 @@ sys.path.insert(0, current_dir) # Add current directory sys.path.insert(0, os.path.join(current_dir, "../../..")) # Add components/src +# Save original sys.modules state before mocking +original_modules = {} +modules_to_mock = [ + "torch", + "tensorrt_llm", + "tensorrt_llm.executor", + "tensorrt_llm.executor.result", + "tensorrt_llm.executor.utils", + "tensorrt_llm.llmapi", + "tensorrt_llm.llmapi.llm", + "dynamo._core", + "dynamo.logits_processing", + "dynamo.logits_processing.examples", + "dynamo.nixl_connect", + "dynamo.runtime", + "dynamo.runtime.logging", + "dynamo.trtllm.engine", + "dynamo.trtllm.logits_processing", + "dynamo.trtllm.logits_processing.adapter", + "dynamo.trtllm.multimodal_processor", + "dynamo.trtllm.publisher", + "dynamo.trtllm.utils", + "dynamo.trtllm.utils.disagg_utils", +] + +for module_name in modules_to_mock: + if module_name in sys.modules: + original_modules[module_name] = sys.modules[module_name] + # Mock all heavy dependencies BEFORE importing handler_base sys.modules["torch"] = MagicMock() sys.modules["tensorrt_llm"] = MagicMock() @@ -90,6 +120,22 @@ async def async_killed_or_stopped(self) -> None: ) +def cleanup_modules(): + """Restore original sys.modules state.""" + # Remove mocked modules + for module_name in modules_to_mock: + if module_name in sys.modules: + del sys.modules[module_name] + + # Restore original modules if they existed + for module_name, original_module in original_modules.items(): + sys.modules[module_name] = original_module + + +# Register cleanup to run at exit +atexit.register(cleanup_modules) + + class TestHandlerBase: """Tests for HandlerBase error handling""" @@ -183,7 +229,6 @@ async def test_request_error_no_shutdown(self): assert len(responses) == 2, f"Expected 2 responses, got {len(responses)}" assert responses[0]["token_ids"] == [1, 2, 3] assert responses[1]["finish_reason"] == "error" - assert "Invalid request" in responses[1]["error"] # Critical: NO shutdown should be called mock_runtime.shutdown.assert_not_called() @@ -216,7 +261,6 @@ async def test_generic_exception_triggers_shutdown(self): # Verify error response was sent assert len(responses) == 2 assert responses[1]["finish_reason"] == "error" - assert "service restarting" in responses[1]["error"].lower() # Critical: Shutdown SHOULD be called mock_runtime.shutdown.assert_called_once() @@ -256,5 +300,9 @@ async def test_cancelled_error_no_shutdown(self): if __name__ == "__main__": - # Allow running with python test_handler_base.py - pytest.main([__file__, "-v"]) + try: + # Allow running with python test_handler_base.py + pytest.main([__file__, "-v"]) + finally: + # Ensure cleanup happens even if tests fail + cleanup_modules() From 2bd33b170717c15b40d5c375990b084f98f55f66 Mon Sep 17 00:00:00 2001 From: "tzulingk@nvidia.com" Date: Tue, 14 Oct 2025 09:49:00 -0700 Subject: [PATCH 12/13] Restore the sys.modules in tests. Signed-off-by: tzulingk@nvidia.com --- .../src/dynamo/trtllm/test_handler_base.py | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/components/src/dynamo/trtllm/test_handler_base.py b/components/src/dynamo/trtllm/test_handler_base.py index b3880ba300..6c70690724 100644 --- a/components/src/dynamo/trtllm/test_handler_base.py +++ b/components/src/dynamo/trtllm/test_handler_base.py @@ -6,6 +6,18 @@ Run with: python test_handler_base.py This script mocks heavy dependencies before importing handler_base to test error handling. + +MOCK LIFECYCLE EXPLANATION: +1. Module-level mocking: Required to successfully import handler_base + which has dependencies on tensorrt_llm, torch, etc. +2. Import handler_base: Imports work because of the mocks +3. Immediate cleanup: Removes mocks from sys.modules to prevent + interference with pytest's test collection (prevents "tensorrt_llm.__spec__ is not set") +4. setup_method: Re-establishes mocks before each test runs +5. teardown_method: Cleans up after each test + +This dual approach allows us to import handler_base (which needs mocks) while +preventing our mocks from breaking pytest's collection of other test files. """ # type: ignore # This file uses dynamic mocking which confuses mypy @@ -50,6 +62,11 @@ original_modules[module_name] = sys.modules[module_name] # Mock all heavy dependencies BEFORE importing handler_base +# WHY WE NEED THIS: handler_base.py imports tensorrt_llm, torch, etc. +# Without these mocks, the import on line 114-119 would fail because these +# packages aren't installed in the test environment. +# This is DIFFERENT from the mocking in setup_method - this enables the import, +# while setup_method re-establishes mocks for test execution after cleanup. sys.modules["torch"] = MagicMock() sys.modules["tensorrt_llm"] = MagicMock() sys.modules["tensorrt_llm.executor"] = MagicMock() @@ -135,10 +152,70 @@ def cleanup_modules(): # Register cleanup to run at exit atexit.register(cleanup_modules) +# IMPORTANT: Clean up immediately after imports are done +# WHY WE CLEAN UP HERE: When pytest collects tests, it imports all test files. +# If we leave tensorrt_llm mocked in sys.modules, when pytest tries to check +# if tensorrt_llm is available for test_trtllm_unit.py (via conftest.py), +# it finds our MagicMock which doesn't have __spec__, causing: +# "ValueError: tensorrt_llm.__spec__ is not set" +# By cleaning up here, we prevent our mocks from interfering with pytest collection. +# The mocks will be re-established when tests actually run via setup_method. +cleanup_modules() + class TestHandlerBase: """Tests for HandlerBase error handling""" + def setup_method(self): + """Re-establish mocks before each test method runs. + + WHY WE NEED THIS: After cleanup_modules() removed all mocks to prevent + pytest collection issues, we need to put them back when tests actually run. + The HandlerBase code that was imported earlier expects these modules to be + mocked when it executes during the test. + """ + # Put mocks back for test execution + sys.modules["torch"] = MagicMock() + sys.modules["tensorrt_llm"] = MagicMock() + sys.modules["tensorrt_llm.executor"] = MagicMock() + sys.modules["tensorrt_llm.executor.result"] = MagicMock() + sys.modules["tensorrt_llm.executor.utils"] = MagicMock() + sys.modules["tensorrt_llm.llmapi"] = MagicMock() + sys.modules["tensorrt_llm.llmapi.llm"] = MagicMock() + + # Re-create RequestError + class RequestError(Exception): + pass + + sys.modules["tensorrt_llm.executor.utils"].RequestError = RequestError + + # Re-mock dynamo modules + sys.modules["dynamo._core"] = MagicMock() + sys.modules["dynamo.logits_processing"] = MagicMock() + sys.modules["dynamo.logits_processing.examples"] = MagicMock() + sys.modules["dynamo.nixl_connect"] = MagicMock() + sys.modules["dynamo.runtime"] = MagicMock() + sys.modules["dynamo.runtime.logging"] = MagicMock() + sys.modules["dynamo.trtllm.engine"] = MagicMock() + sys.modules["dynamo.trtllm.logits_processing"] = MagicMock() + sys.modules["dynamo.trtllm.logits_processing.adapter"] = MagicMock() + sys.modules["dynamo.trtllm.multimodal_processor"] = MagicMock() + sys.modules["dynamo.trtllm.publisher"] = MagicMock() + sys.modules["dynamo.trtllm.utils"] = MagicMock() + sys.modules["dynamo.trtllm.utils.disagg_utils"] = MagicMock() + + # Re-establish Context if needed + sys.modules["dynamo._core"].Context = Context + + def teardown_method(self): + """Clean up mocks after each test method. + + WHY WE NEED THIS: Ensures clean state between tests and prevents + any lingering mocked modules from affecting subsequent tests or + pytest operations. + """ + cleanup_modules() + def create_mock_config(self, with_runtime=True): """Helper to create a mock RequestHandlerConfig""" mock_engine = MagicMock() From 93bbe57a0bd031428c34d2d45e0689d0fdd7588a Mon Sep 17 00:00:00 2001 From: "tzulingk@nvidia.com" Date: Tue, 14 Oct 2025 11:03:16 -0700 Subject: [PATCH 13/13] remove components/src/dynamo/trtllm/test_handler_base.py for the sys.modules issue Signed-off-by: tzulingk@nvidia.com --- .../src/dynamo/trtllm/test_handler_base.py | 385 ------------------ 1 file changed, 385 deletions(-) delete mode 100644 components/src/dynamo/trtllm/test_handler_base.py diff --git a/components/src/dynamo/trtllm/test_handler_base.py b/components/src/dynamo/trtllm/test_handler_base.py deleted file mode 100644 index 6c70690724..0000000000 --- a/components/src/dynamo/trtllm/test_handler_base.py +++ /dev/null @@ -1,385 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -""" -Test runner for handler_base error handling. -Run with: python test_handler_base.py - -This script mocks heavy dependencies before importing handler_base to test error handling. - -MOCK LIFECYCLE EXPLANATION: -1. Module-level mocking: Required to successfully import handler_base - which has dependencies on tensorrt_llm, torch, etc. -2. Import handler_base: Imports work because of the mocks -3. Immediate cleanup: Removes mocks from sys.modules to prevent - interference with pytest's test collection (prevents "tensorrt_llm.__spec__ is not set") -4. setup_method: Re-establishes mocks before each test runs -5. teardown_method: Cleans up after each test - -This dual approach allows us to import handler_base (which needs mocks) while -preventing our mocks from breaking pytest's collection of other test files. -""" -# type: ignore # This file uses dynamic mocking which confuses mypy - -import asyncio -import atexit -import os -import sys -from unittest.mock import AsyncMock, MagicMock, patch - -# Add both the current directory and the components/src directory to the Python path -current_dir = os.path.dirname(os.path.abspath(__file__)) -sys.path.insert(0, current_dir) # Add current directory -sys.path.insert(0, os.path.join(current_dir, "../../..")) # Add components/src - -# Save original sys.modules state before mocking -original_modules = {} -modules_to_mock = [ - "torch", - "tensorrt_llm", - "tensorrt_llm.executor", - "tensorrt_llm.executor.result", - "tensorrt_llm.executor.utils", - "tensorrt_llm.llmapi", - "tensorrt_llm.llmapi.llm", - "dynamo._core", - "dynamo.logits_processing", - "dynamo.logits_processing.examples", - "dynamo.nixl_connect", - "dynamo.runtime", - "dynamo.runtime.logging", - "dynamo.trtllm.engine", - "dynamo.trtllm.logits_processing", - "dynamo.trtllm.logits_processing.adapter", - "dynamo.trtllm.multimodal_processor", - "dynamo.trtllm.publisher", - "dynamo.trtllm.utils", - "dynamo.trtllm.utils.disagg_utils", -] - -for module_name in modules_to_mock: - if module_name in sys.modules: - original_modules[module_name] = sys.modules[module_name] - -# Mock all heavy dependencies BEFORE importing handler_base -# WHY WE NEED THIS: handler_base.py imports tensorrt_llm, torch, etc. -# Without these mocks, the import on line 114-119 would fail because these -# packages aren't installed in the test environment. -# This is DIFFERENT from the mocking in setup_method - this enables the import, -# while setup_method re-establishes mocks for test execution after cleanup. -sys.modules["torch"] = MagicMock() -sys.modules["tensorrt_llm"] = MagicMock() -sys.modules["tensorrt_llm.executor"] = MagicMock() -sys.modules["tensorrt_llm.executor.result"] = MagicMock() -sys.modules["tensorrt_llm.executor.utils"] = MagicMock() -sys.modules["tensorrt_llm.llmapi"] = MagicMock() -sys.modules["tensorrt_llm.llmapi.llm"] = MagicMock() - - -# Create RequestError exception class -class RequestError(Exception): - """Mock RequestError from tensorrt_llm.executor.utils""" - - pass - - -sys.modules["tensorrt_llm.executor.utils"].RequestError = RequestError # type: ignore[attr-defined] - -# Mock other dynamo modules - but NOT dynamo.trtllm.request_handlers since we want to import the real one -sys.modules["dynamo._core"] = MagicMock() -sys.modules["dynamo.logits_processing"] = MagicMock() -sys.modules["dynamo.logits_processing.examples"] = MagicMock() -sys.modules["dynamo.nixl_connect"] = MagicMock() -sys.modules["dynamo.runtime"] = MagicMock() -sys.modules["dynamo.runtime.logging"] = MagicMock() -sys.modules["dynamo.trtllm.engine"] = MagicMock() -sys.modules["dynamo.trtllm.logits_processing"] = MagicMock() -sys.modules["dynamo.trtllm.logits_processing.adapter"] = MagicMock() -sys.modules["dynamo.trtllm.multimodal_processor"] = MagicMock() -sys.modules["dynamo.trtllm.publisher"] = MagicMock() -sys.modules["dynamo.trtllm.utils"] = MagicMock() -sys.modules["dynamo.trtllm.utils.disagg_utils"] = MagicMock() - - -# Mock Context class -class Context: - """Mock Context from dynamo._core""" - - def __init__(self, request_id: str) -> None: - self._id = request_id - self._cancelled: asyncio.Future[None] = asyncio.Future() - self._cancelled.set_result(None) - # For async_killed_or_stopped - a future that never completes - self._killed_or_stopped: asyncio.Future[None] = asyncio.Future() - - def id(self) -> str: - return self._id - - def cancelled(self) -> asyncio.Future[None]: - return self._cancelled - - async def async_killed_or_stopped(self) -> None: - # The hanging behavior ensures the cancellation monitor stays "dormant" - # and doesn't interfere with our test scenarios. - await self._killed_or_stopped - - -sys.modules["dynamo._core"].Context = Context # type: ignore[attr-defined] - -import pytest # noqa: E402 - -# Import handler_base directly from its location -from request_handlers.handler_base import ( # noqa: E402 - DisaggregationMode, - DisaggregationStrategy, - HandlerBase, - RequestHandlerConfig, -) - - -def cleanup_modules(): - """Restore original sys.modules state.""" - # Remove mocked modules - for module_name in modules_to_mock: - if module_name in sys.modules: - del sys.modules[module_name] - - # Restore original modules if they existed - for module_name, original_module in original_modules.items(): - sys.modules[module_name] = original_module - - -# Register cleanup to run at exit -atexit.register(cleanup_modules) - -# IMPORTANT: Clean up immediately after imports are done -# WHY WE CLEAN UP HERE: When pytest collects tests, it imports all test files. -# If we leave tensorrt_llm mocked in sys.modules, when pytest tries to check -# if tensorrt_llm is available for test_trtllm_unit.py (via conftest.py), -# it finds our MagicMock which doesn't have __spec__, causing: -# "ValueError: tensorrt_llm.__spec__ is not set" -# By cleaning up here, we prevent our mocks from interfering with pytest collection. -# The mocks will be re-established when tests actually run via setup_method. -cleanup_modules() - - -class TestHandlerBase: - """Tests for HandlerBase error handling""" - - def setup_method(self): - """Re-establish mocks before each test method runs. - - WHY WE NEED THIS: After cleanup_modules() removed all mocks to prevent - pytest collection issues, we need to put them back when tests actually run. - The HandlerBase code that was imported earlier expects these modules to be - mocked when it executes during the test. - """ - # Put mocks back for test execution - sys.modules["torch"] = MagicMock() - sys.modules["tensorrt_llm"] = MagicMock() - sys.modules["tensorrt_llm.executor"] = MagicMock() - sys.modules["tensorrt_llm.executor.result"] = MagicMock() - sys.modules["tensorrt_llm.executor.utils"] = MagicMock() - sys.modules["tensorrt_llm.llmapi"] = MagicMock() - sys.modules["tensorrt_llm.llmapi.llm"] = MagicMock() - - # Re-create RequestError - class RequestError(Exception): - pass - - sys.modules["tensorrt_llm.executor.utils"].RequestError = RequestError - - # Re-mock dynamo modules - sys.modules["dynamo._core"] = MagicMock() - sys.modules["dynamo.logits_processing"] = MagicMock() - sys.modules["dynamo.logits_processing.examples"] = MagicMock() - sys.modules["dynamo.nixl_connect"] = MagicMock() - sys.modules["dynamo.runtime"] = MagicMock() - sys.modules["dynamo.runtime.logging"] = MagicMock() - sys.modules["dynamo.trtllm.engine"] = MagicMock() - sys.modules["dynamo.trtllm.logits_processing"] = MagicMock() - sys.modules["dynamo.trtllm.logits_processing.adapter"] = MagicMock() - sys.modules["dynamo.trtllm.multimodal_processor"] = MagicMock() - sys.modules["dynamo.trtllm.publisher"] = MagicMock() - sys.modules["dynamo.trtllm.utils"] = MagicMock() - sys.modules["dynamo.trtllm.utils.disagg_utils"] = MagicMock() - - # Re-establish Context if needed - sys.modules["dynamo._core"].Context = Context - - def teardown_method(self): - """Clean up mocks after each test method. - - WHY WE NEED THIS: Ensures clean state between tests and prevents - any lingering mocked modules from affecting subsequent tests or - pytest operations. - """ - cleanup_modules() - - def create_mock_config(self, with_runtime=True): - """Helper to create a mock RequestHandlerConfig""" - mock_engine = MagicMock() - mock_engine.cleanup = AsyncMock() - mock_engine.llm = MagicMock() # Add llm attribute - - runtime = None - if with_runtime: - runtime = MagicMock() - runtime.shutdown = MagicMock() - - mock_component = MagicMock() - mock_component.rank = 0 - - config = RequestHandlerConfig( - component=mock_component, - engine=mock_engine, - default_sampling_params=MagicMock(), - publisher=None, - runtime=runtime, - disaggregation_mode=DisaggregationMode.AGGREGATED, - disaggregation_strategy=DisaggregationStrategy.PREFILL_FIRST, - next_client=None, - next_router_client=None, - encode_client=None, - multimodal_processor=None, - connector=None, - ) - return config - - def create_mock_generation_result(self, exception_to_raise=None): - """Helper to create a mock generation result""" - mock_gen_result = MagicMock() - - async def mock_generator(): - # Create a mock result that matches what handler_base expects - mock_res = MagicMock() - mock_res.finished = False - mock_output = MagicMock() - mock_output.token_ids = [1, 2, 3] - mock_output.finish_reason = None - mock_output.stop_reason = None - mock_res.outputs = [mock_output] - yield mock_res - - if exception_to_raise: - raise exception_to_raise - - # __aiter__ should be a method that returns the async generator - # The lambda needs to accept self (passed by MagicMock) but ignore it - mock_gen_result.__aiter__ = lambda self: mock_generator() - return mock_gen_result - - def get_test_request(self): - """Helper to get a standard test request""" - return { - "prompt": "test", - "sampling_options": {}, - "stop_conditions": {"max_tokens": 10}, - "trace": {"service_name": "test"}, - "tokens": [1, 2, 3], # Mock tokens - } - - @pytest.mark.asyncio - async def test_request_error_no_shutdown(self): - """Test that RequestError doesn't trigger shutdown""" - # Setup - config = self.create_mock_config(with_runtime=True) - mock_engine = config.engine - mock_runtime = config.runtime - - handler = HandlerBase(config) - mock_context = Context("test-request-123") - - # Mock engine to raise RequestError after yielding - mock_gen_result = self.create_mock_generation_result( - exception_to_raise=RequestError("Invalid request parameters") - ) - mock_engine.llm.generate_async = lambda *args, **kwargs: mock_gen_result - - # Run test - request = self.get_test_request() - responses = [] - async for response in handler.generate_locally(request, mock_context): - responses.append(response) - - # Verify - assert len(responses) == 2, f"Expected 2 responses, got {len(responses)}" - assert responses[0]["token_ids"] == [1, 2, 3] - assert responses[1]["finish_reason"] == "error" - - # Critical: NO shutdown should be called - mock_runtime.shutdown.assert_not_called() - mock_engine.cleanup.assert_not_called() - - @pytest.mark.asyncio - async def test_generic_exception_triggers_shutdown(self): - """Test that generic exceptions trigger graceful shutdown""" - # Setup - config = self.create_mock_config(with_runtime=True) - mock_engine = config.engine - mock_runtime = config.runtime - - handler = HandlerBase(config) - mock_context = Context("test-request-456") - - # Mock engine to raise RuntimeError - mock_gen_result = self.create_mock_generation_result( - exception_to_raise=RuntimeError("Engine CUDA out of memory") - ) - mock_engine.llm.generate_async = lambda *args, **kwargs: mock_gen_result - - # Run test with mocked os._exit - with patch("os._exit") as mock_exit: - request = self.get_test_request() - responses = [] - async for response in handler.generate_locally(request, mock_context): - responses.append(response) - - # Verify error response was sent - assert len(responses) == 2 - assert responses[1]["finish_reason"] == "error" - - # Critical: Shutdown SHOULD be called - mock_runtime.shutdown.assert_called_once() - mock_engine.cleanup.assert_called_once() - mock_exit.assert_called_once_with(1) - - @pytest.mark.asyncio - async def test_cancelled_error_no_shutdown(self): - """Test that CancelledError doesn't trigger shutdown""" - # Setup - config = self.create_mock_config(with_runtime=True) - mock_engine = config.engine - mock_runtime = config.runtime - - handler = HandlerBase(config) - mock_context = Context("test-request-789") - - # Mock engine to raise CancelledError - mock_gen_result = self.create_mock_generation_result( - exception_to_raise=asyncio.CancelledError("Client disconnected") - ) - mock_engine.llm.generate_async = lambda *args, **kwargs: mock_gen_result - - # Run test - request = self.get_test_request() - responses = [] - async for response in handler.generate_locally(request, mock_context): - responses.append(response) - - # Should only have the first response (no error response) - assert len(responses) == 1 - assert responses[0]["token_ids"] == [1, 2, 3] - - # Critical: NO shutdown should be called - mock_runtime.shutdown.assert_not_called() - mock_engine.cleanup.assert_not_called() - - -if __name__ == "__main__": - try: - # Allow running with python test_handler_base.py - pytest.main([__file__, "-v"]) - finally: - # Ensure cleanup happens even if tests fail - cleanup_modules()