1 change: 1 addition & 0 deletions components/src/dynamo/trtllm/main.py
@@ -349,6 +349,7 @@ async def init(runtime: DistributedRuntime, config: Config):
encode_client=encode_client,
multimodal_processor=multimodal_processor,
connector=connector,
runtime=runtime, # Pass runtime for graceful shutdown
)

if next_client:
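For context, the new `runtime=runtime` argument forwards the `DistributedRuntime` handle that `init` already receives into the handler configuration, so a fatal error inside the handler can take the whole worker down cleanly (see the `_initiate_shutdown` helper added in the next file). A minimal, self-contained sketch of this dependency-injection pattern; `Runtime`, `HandlerConfig`, and `Handler` are illustrative stand-ins for `DistributedRuntime`, `RequestHandlerConfig`, and `HandlerBase`, not the real Dynamo types:

# Sketch only: stand-in names, not the actual Dynamo classes.
from dataclasses import dataclass
from typing import Optional

class Runtime:
    def shutdown(self) -> None:
        print("runtime shutdown requested")

@dataclass
class HandlerConfig:
    runtime: Optional[Runtime] = None  # mirrors the new field added below

class Handler:
    def __init__(self, config: HandlerConfig):
        # Keep the reference so a fatal error can stop the whole worker.
        self.runtime = config.runtime

    def on_fatal_error(self) -> None:
        if self.runtime:
            self.runtime.shutdown()

Handler(HandlerConfig(runtime=Runtime())).on_fatal_error()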
178 changes: 118 additions & 60 deletions components/src/dynamo/trtllm/request_handlers/handler_base.py
@@ -24,12 +24,14 @@

import torch
from tensorrt_llm.executor.result import GenerationResult
from tensorrt_llm.executor.utils import RequestError
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.llmapi.llm import SamplingParams

from dynamo._core import Context
from dynamo.logits_processing.examples import HelloWorldLogitsProcessor
from dynamo.nixl_connect import Connector
from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.engine import TensorRTLLMEngine
from dynamo.trtllm.logits_processing.adapter import create_trtllm_adapters
@@ -74,6 +76,9 @@ class RequestHandlerConfig:
MultimodalRequestProcessor
] = None # for multimodal support
connector: Optional[Connector] = None
runtime: Optional[
DistributedRuntime
] = None # DistributedRuntime reference for graceful shutdown


class HandlerBase:
@@ -94,6 +99,8 @@ def __init__(self, config: RequestHandlerConfig):
self.multimodal_processor = config.multimodal_processor
self.first_generation = True
self.connector = config.connector
# Store runtime reference for graceful shutdown
self.runtime = config.runtime

def check_error(self, result: dict):
"""
@@ -148,6 +155,24 @@ async def _cancellation_monitor(
except asyncio.CancelledError:
pass

async def _initiate_shutdown(self, error: Exception):
"""Initiate graceful shutdown after fatal error"""
logging.warning(f"Initiating graceful shutdown due to: {error}")

try:
if self.runtime:
logging.info("Shutting down Dynamo runtime...")
self.runtime.shutdown()

if self.engine:
logging.info("Shutting down TensorRT-LLM engine...")
await self.engine.cleanup()
except Exception as cleanup_error:
logging.error(f"Error during graceful shutdown: {cleanup_error}")
finally:
logging.critical("Forcing process exit for restart")
os._exit(1)
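A note on the `os._exit(1)` that ends `_initiate_shutdown`: unlike `sys.exit`, it terminates the process immediately, without waiting for non-daemon threads or running interpreter shutdown hooks, which is what an external supervisor needs when engine threads may be wedged after a fatal error. A small, hypothetical illustration of the difference (not part of the PR):

import os
import threading
import time

# Stand-in for an engine worker thread that is stuck after a fatal error.
threading.Thread(target=lambda: time.sleep(3600)).start()  # non-daemon

# sys.exit(1) would raise SystemExit and, after the main thread unwinds,
# the interpreter would still wait for the stuck non-daemon thread above.
# os._exit(1) ends the process right away so the supervisor can restart it.
os._exit(1)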

async def generate_locally(
self,
request: dict,
@@ -243,66 +268,99 @@ async def generate_locally(
adapters = create_trtllm_adapters(processors)
sampling_params.logits_processor = adapters

# NEW: Updated engine call to include multimodal data
generation_result = self.engine.llm.generate_async(
inputs=processed_input, # Use the correctly extracted inputs
sampling_params=sampling_params,
disaggregated_params=disaggregated_params,
streaming=streaming,
)
try:
# NEW: Updated engine call to include multimodal data
generation_result = self.engine.llm.generate_async(
inputs=processed_input, # Use the correctly extracted inputs
sampling_params=sampling_params,
disaggregated_params=disaggregated_params,
streaming=streaming,
)

# Use the context manager to handle cancellation monitoring
async with self._cancellation_monitor(generation_result, context):
async for res in generation_result:
# TRTLLM engine needs to start generating tokens first before stats
# can be retrieved.
if self.first_generation and self.publisher:
self.publisher.start()
self.first_generation = False

# Upon completion, send a final chunk with "stop" as the finish reason.
# This signals to the client that the stream has ended.
if (
res.finished
and self.disaggregation_mode != DisaggregationMode.PREFILL
):
# Use the context manager to handle cancellation monitoring
async with self._cancellation_monitor(generation_result, context):
async for res in generation_result:
# TRTLLM engine needs to start generating tokens first before stats
# can be retrieved.
if self.first_generation and self.publisher:
self.publisher.start()
self.first_generation = False

# Upon completion, send a final chunk with "stop" as the finish reason.
# This signals to the client that the stream has ended.
if (
res.finished
and self.disaggregation_mode != DisaggregationMode.PREFILL
):
if self.multimodal_processor:
final_out = self.multimodal_processor.get_stop_response(
request_id, model_name
)
yield final_out

# If we are not done generating, but there are no outputs, return an error
if not res.outputs and not res.finished:
yield {"finish_reason": "error", "token_ids": []}
break

output = res.outputs[0]
# The engine returns all tokens generated so far. We must calculate the new
# tokens generated in this iteration to create the "delta".
next_total_toks = len(output.token_ids)
if self.multimodal_processor:
final_out = self.multimodal_processor.get_stop_response(
request_id, model_name
out = self.multimodal_processor.create_response_chunk(
output, num_output_tokens_so_far, request_id, model_name
)
else:
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# Return the disaggregated params only when operating in prefill mode.
out["disaggregated_params"] = asdict(
DisaggregatedParamsCodec.encode(output.disaggregated_params)
)
yield final_out

# If we are not done generating, but there are no outputs, return an error
if not res.outputs and not res.finished:
yield {"finish_reason": "error", "token_ids": []}
break

output = res.outputs[0]
# The engine returns all tokens generated so far. We must calculate the new
# tokens generated in this iteration to create the "delta".
next_total_toks = len(output.token_ids)
if self.multimodal_processor:
out = self.multimodal_processor.create_response_chunk(
output, num_output_tokens_so_far, request_id, model_name
)
else:
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# Return the disaggregated params only when operating in prefill mode.
out["disaggregated_params"] = asdict(
DisaggregatedParamsCodec.encode(output.disaggregated_params)
)

if res.finished and not out.get("finish_reason"):
out["finish_reason"] = "unknown"
logging.warning(
"Request finished with no finish reason set - this indicates a possible bug"
)

# Yield the chunk to the client and update the token count for the next iteration.
yield out
num_output_tokens_so_far = next_total_toks

if res.finished and not out.get("finish_reason"):
out["finish_reason"] = "unknown"
logging.warning(
"Request finished with no finish reason set - this indicates a possible bug"
)

# Yield the chunk to the client and update the token count for the next iteration.
yield out
num_output_tokens_so_far = next_total_toks

# 1. Client cancellation - don't shut down
except asyncio.CancelledError:
logging.debug(f"Request {request_id}: Client cancelled")
# _cancellation_monitor already called abort_request
return # Just stop, no error response

# 2. Per-request errors - send to client, don't shut down
except RequestError as e:
logging.warning(f"Request {request_id} error: {e}")
yield {"finish_reason": "error", "token_ids": []}

# 3. ALL OTHER ERRORS - graceful shutdown
except Exception as e:
error_type = type(e).__name__
error_msg = str(e)
logging.error(
f"Fatal {error_type} in request {request_id}: {error_msg}",
exc_info=True,
)

# Try to send error to client before shutdown
try:
yield {
"finish_reason": "error",
"token_ids": [],
}
except Exception:
pass # Best effort

# Initiate graceful shutdown
await self._initiate_shutdown(e)
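Taken together, the new try/except ladder in `generate_locally` implements a three-tier policy: client cancellation ends the stream silently, per-request `RequestError`s are reported back to the caller as an error chunk, and anything else is treated as fatal and triggers the worker shutdown shown above. A compressed, self-contained sketch of that policy; the helper and parameter names here are illustrative, and only the exception types and the three outcomes mirror the diff:

import asyncio
import logging

class RequestError(Exception):
    """Stand-in for tensorrt_llm.executor.utils.RequestError."""

async def stream_with_policy(generate, request_id, shutdown):
    """Illustrative only: mirrors the three error tiers in generate_locally."""
    try:
        async for chunk in generate():
            yield chunk
    except asyncio.CancelledError:
        # Tier 1: the client went away; the cancellation monitor has already
        # aborted the engine request, so just end the stream quietly.
        logging.debug("request %s cancelled by client", request_id)
        return
    except RequestError as e:
        # Tier 2: a per-request failure; report it to the client and keep
        # the worker alive for other requests.
        logging.warning("request %s failed: %s", request_id, e)
        yield {"finish_reason": "error", "token_ids": []}
    except Exception as e:
        # Tier 3: anything else is fatal; best-effort error chunk, then shut
        # the worker down so the orchestrator can restart it.
        logging.error("fatal error in request %s: %s", request_id, e, exc_info=True)
        try:
            yield {"finish_reason": "error", "token_ids": []}
        except Exception:
            pass
        await shutdown(e)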