diff --git a/components/src/dynamo/trtllm/main.py b/components/src/dynamo/trtllm/main.py
index 93f6300b59..4fe4f36199 100644
--- a/components/src/dynamo/trtllm/main.py
+++ b/components/src/dynamo/trtllm/main.py
@@ -349,6 +349,7 @@ async def init(runtime: DistributedRuntime, config: Config):
             encode_client=encode_client,
             multimodal_processor=multimodal_processor,
             connector=connector,
+            runtime=runtime,  # Pass runtime for graceful shutdown
         )
 
     if next_client:
diff --git a/components/src/dynamo/trtllm/request_handlers/handler_base.py b/components/src/dynamo/trtllm/request_handlers/handler_base.py
index 28ef479e85..a25449394c 100644
--- a/components/src/dynamo/trtllm/request_handlers/handler_base.py
+++ b/components/src/dynamo/trtllm/request_handlers/handler_base.py
@@ -24,12 +24,14 @@
 import torch
 from tensorrt_llm.executor.result import GenerationResult
+from tensorrt_llm.executor.utils import RequestError
 from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
 from tensorrt_llm.llmapi.llm import SamplingParams
 
 from dynamo._core import Context
 from dynamo.logits_processing.examples import HelloWorldLogitsProcessor
 from dynamo.nixl_connect import Connector
+from dynamo.runtime import DistributedRuntime
 from dynamo.runtime.logging import configure_dynamo_logging
 from dynamo.trtllm.engine import TensorRTLLMEngine
 from dynamo.trtllm.logits_processing.adapter import create_trtllm_adapters
 
@@ -74,6 +76,9 @@ class RequestHandlerConfig:
     multimodal_processor: Optional[
         MultimodalRequestProcessor
     ] = None  # for multimodal support
     connector: Optional[Connector] = None
+    runtime: Optional[
+        DistributedRuntime
+    ] = None  # DistributedRuntime reference for graceful shutdown
 
 class HandlerBase:
@@ -94,6 +99,8 @@ def __init__(self, config: RequestHandlerConfig):
         self.multimodal_processor = config.multimodal_processor
         self.first_generation = True
         self.connector = config.connector
+        # Store runtime reference for graceful shutdown
+        self.runtime = config.runtime
 
     def check_error(self, result: dict):
         """
@@ -148,6 +155,24 @@ async def _cancellation_monitor(
         except asyncio.CancelledError:
             pass
 
+    async def _initiate_shutdown(self, error: Exception):
+        """Initiate graceful shutdown after fatal error"""
+        logging.warning(f"Initiating graceful shutdown due to: {error}")
+
+        try:
+            if self.runtime:
+                logging.info("Shutting down Dynamo runtime...")
+                self.runtime.shutdown()
+
+            if self.engine:
+                logging.info("Shutting down TensorRT-LLM engine...")
+                await self.engine.cleanup()
+        except Exception as cleanup_error:
+            logging.error(f"Error during graceful shutdown: {cleanup_error}")
+        finally:
+            logging.critical("Forcing process exit for restart")
+            os._exit(1)
+
     async def generate_locally(
         self,
         request: dict,
@@ -243,66 +268,99 @@ async def generate_locally(
             adapters = create_trtllm_adapters(processors)
             sampling_params.logits_processor = adapters
 
-        # NEW: Updated engine call to include multimodal data
-        generation_result = self.engine.llm.generate_async(
-            inputs=processed_input,  # Use the correctly extracted inputs
-            sampling_params=sampling_params,
-            disaggregated_params=disaggregated_params,
-            streaming=streaming,
-        )
-
-        # Use the context manager to handle cancellation monitoring
-        async with self._cancellation_monitor(generation_result, context):
-            async for res in generation_result:
-                # TRTLLM engine needs to start generating tokens first before stats
-                # can be retrieved.
-                if self.first_generation and self.publisher:
-                    self.publisher.start()
-                    self.first_generation = False
-
-                # Upon completion, send a final chunk with "stop" as the finish reason.
-                # This signals to the client that the stream has ended.
-                if (
-                    res.finished
-                    and self.disaggregation_mode != DisaggregationMode.PREFILL
-                ):
-                    if self.multimodal_processor:
-                        final_out = self.multimodal_processor.get_stop_response(
-                            request_id, model_name
-                        )
-                        yield final_out
-
-                # If we are not done generating, but there are no outputs, return an error
-                if not res.outputs and not res.finished:
-                    yield {"finish_reason": "error", "token_ids": []}
-                    break
-
-                output = res.outputs[0]
-                # The engine returns all tokens generated so far. We must calculate the new
-                # tokens generated in this iteration to create the "delta".
-                next_total_toks = len(output.token_ids)
-                if self.multimodal_processor:
-                    out = self.multimodal_processor.create_response_chunk(
-                        output, num_output_tokens_so_far, request_id, model_name
-                    )
-                else:
-                    out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
-                if output.finish_reason:
-                    out["finish_reason"] = output.finish_reason
-                if output.stop_reason:
-                    out["stop_reason"] = output.stop_reason
-                if self.disaggregation_mode == DisaggregationMode.PREFILL:
-                    # Return the disaggregated params only when operating in prefill mode.
-                    out["disaggregated_params"] = asdict(
-                        DisaggregatedParamsCodec.encode(output.disaggregated_params)
-                    )
-
-                if res.finished and not out.get("finish_reason"):
-                    out["finish_reason"] = "unknown"
-                    logging.warning(
-                        "Request finished with no finish reason set - this indicates a possible bug"
-                    )
-
-                # Yield the chunk to the client and update the token count for the next iteration.
-                yield out
-                num_output_tokens_so_far = next_total_toks
+        try:
+            # NEW: Updated engine call to include multimodal data
+            generation_result = self.engine.llm.generate_async(
+                inputs=processed_input,  # Use the correctly extracted inputs
+                sampling_params=sampling_params,
+                disaggregated_params=disaggregated_params,
+                streaming=streaming,
+            )
+
+            # Use the context manager to handle cancellation monitoring
+            async with self._cancellation_monitor(generation_result, context):
+                async for res in generation_result:
+                    # TRTLLM engine needs to start generating tokens first before stats
+                    # can be retrieved.
+                    if self.first_generation and self.publisher:
+                        self.publisher.start()
+                        self.first_generation = False
+
+                    # Upon completion, send a final chunk with "stop" as the finish reason.
+                    # This signals to the client that the stream has ended.
+                    if (
+                        res.finished
+                        and self.disaggregation_mode != DisaggregationMode.PREFILL
+                    ):
+                        if self.multimodal_processor:
+                            final_out = self.multimodal_processor.get_stop_response(
+                                request_id, model_name
+                            )
+                            yield final_out
+
+                    # If we are not done generating, but there are no outputs, return an error
+                    if not res.outputs and not res.finished:
+                        yield {"finish_reason": "error", "token_ids": []}
+                        break
+
+                    output = res.outputs[0]
+                    # The engine returns all tokens generated so far. We must calculate the new
+                    # tokens generated in this iteration to create the "delta".
+                    next_total_toks = len(output.token_ids)
+                    if self.multimodal_processor:
+                        out = self.multimodal_processor.create_response_chunk(
+                            output, num_output_tokens_so_far, request_id, model_name
+                        )
+                    else:
+                        out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
+                    if output.finish_reason:
+                        out["finish_reason"] = output.finish_reason
+                    if output.stop_reason:
+                        out["stop_reason"] = output.stop_reason
+                    if self.disaggregation_mode == DisaggregationMode.PREFILL:
+                        # Return the disaggregated params only when operating in prefill mode.
+                        out["disaggregated_params"] = asdict(
+                            DisaggregatedParamsCodec.encode(output.disaggregated_params)
+                        )
+
+                    if res.finished and not out.get("finish_reason"):
+                        out["finish_reason"] = "unknown"
+                        logging.warning(
+                            "Request finished with no finish reason set - this indicates a possible bug"
+                        )
+
+                    # Yield the chunk to the client and update the token count for the next iteration.
+                    yield out
+                    num_output_tokens_so_far = next_total_toks
+
+        # 1. Client cancellation - don't shutdown
+        except asyncio.CancelledError:
+            logging.debug(f"Request {request_id}: Client cancelled")
+            # _cancellation_monitor already called abort_request
+            return  # Just stop, no error response
+
+        # 2. Per-request errors - send to client, don't shutdown
+        except RequestError as e:
+            logging.warning(f"Request {request_id} error: {e}")
+            yield {"finish_reason": "error", "token_ids": []}
+
+        # 3. ALL OTHER ERRORS - graceful shutdown
+        except Exception as e:
+            error_type = type(e).__name__
+            error_msg = str(e)
+            logging.error(
+                f"Fatal {error_type} in request {request_id}: {error_msg}",
+                exc_info=True,
+            )
+
+            # Try to send error to client before shutdown
+            try:
+                yield {
+                    "finish_reason": "error",
+                    "token_ids": [],
+                }
+            except Exception:
+                pass  # Best effort
+
+            # Initiate graceful shutdown
+            await self._initiate_shutdown(e)
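The error-classification contract in the new try/except is easiest to see in isolation. Below is a minimal, self-contained sketch of the same three-tier pattern: cancellation passes through silently, per-request failures produce an error chunk, and anything else escalates. `RequestError`, `fake_generate`, and `handle` here are illustrative stand-ins, not the real TensorRT-LLM or Dynamo APIs.

```python
import asyncio
import logging


class RequestError(Exception):
    """Stand-in for tensorrt_llm.executor.utils.RequestError."""


async def fake_generate(mode: str):
    # Stand-in for iterating the engine's generate_async(...) result.
    if mode == "request_error":
        raise RequestError("bad request parameters")
    if mode == "fatal":
        raise RuntimeError("engine in unrecoverable state")
    yield {"token_ids": [1, 2, 3], "finish_reason": "stop"}


async def handle(mode: str):
    try:
        async for chunk in fake_generate(mode):
            yield chunk
    except asyncio.CancelledError:
        # 1. Client cancellation: stop quietly, no error chunk.
        logging.debug("client cancelled")
        return
    except RequestError as e:
        # 2. Per-request failure: report to this client, engine stays up.
        logging.warning("per-request error: %s", e)
        yield {"finish_reason": "error", "token_ids": []}
    except Exception:
        # 3. Everything else: best-effort error chunk, then escalate
        #    (the real handler awaits self._initiate_shutdown(e) here).
        yield {"finish_reason": "error", "token_ids": []}
        raise


async def main():
    for mode in ("ok", "request_error"):
        print(mode, "->", [chunk async for chunk in handle(mode)])


asyncio.run(main())
```

The key design point is that only tier 3 restarts the worker: `RequestError` signals a problem with one request, while an unexpected exception may mean the engine's state can no longer be trusted.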
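A hypothetical test sketch for the shutdown path, assuming pytest-asyncio is available; the fixture wiring below is illustrative and not taken from the repo's test suite. It checks that `_initiate_shutdown` calls `runtime.shutdown()` and awaits `engine.cleanup()` before forcing the process exit, with `os._exit` patched so the test process survives.

```python
import os
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from dynamo.trtllm.request_handlers.handler_base import HandlerBase


@pytest.mark.asyncio
async def test_initiate_shutdown_cleans_up_then_exits():
    handler = HandlerBase.__new__(HandlerBase)  # skip full __init__ wiring
    handler.runtime = MagicMock()  # DistributedRuntime stand-in
    handler.engine = MagicMock()
    handler.engine.cleanup = AsyncMock()

    with patch.object(os, "_exit") as fake_exit:  # keep the test process alive
        await handler._initiate_shutdown(RuntimeError("fatal engine error"))

    handler.runtime.shutdown.assert_called_once()
    handler.engine.cleanup.assert_awaited_once()
    fake_exit.assert_called_once_with(1)
```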