diff --git a/components/backends/trtllm/src/dynamo/trtllm/health_check.py b/components/backends/trtllm/src/dynamo/trtllm/health_check.py index 0178f6902e..d1a4c94ba5 100644 --- a/components/backends/trtllm/src/dynamo/trtllm/health_check.py +++ b/components/backends/trtllm/src/dynamo/trtllm/health_check.py @@ -7,8 +7,46 @@ This module defines the default health check payload for TRT-LLM backends. """ +import logging + from dynamo.health_check import HealthCheckPayload +logger = logging.getLogger(__name__) + + +def _get_bos_token_id_from_tokenizer(tokenizer) -> int: + """ + Extract BOS token ID from the TRT-LLM tokenizer if available. + + Args: + tokenizer: TRT-LLM tokenizer object + + Returns: + BOS token ID from the tokenizer, or 1 as fallback + + Note: + The TransformersTokenizer class wraps a HuggingFace tokenizer. + While TransformersTokenizer doesn't expose bos_token_id directly, + the wrapped HuggingFace tokenizer (accessible via tokenizer.tokenizer) does. + """ + if tokenizer is None: + return 1 + + try: + if hasattr(tokenizer, "tokenizer"): + inner_tokenizer = getattr(tokenizer, "tokenizer") + bos_token_id = getattr(inner_tokenizer, "bos_token_id", None) + if bos_token_id is not None: + logger.info( + f"Using model's BOS token ID for health check: {bos_token_id}" + ) + return int(bos_token_id) + except Exception as e: + logger.debug(f"Failed to get BOS token from tokenizer: {e}") + + logger.debug("Using default BOS token ID (1) for health check") + return 1 + class TrtllmHealthCheckPayload(HealthCheckPayload): """ @@ -17,14 +55,20 @@ class TrtllmHealthCheckPayload(HealthCheckPayload): Provides TRT-LLM defaults and inherits environment override support from base class. """ - def __init__(self): + def __init__(self, tokenizer=None): """ Initialize TRT-LLM health check payload with TRT-LLM-specific defaults. + + Args: + tokenizer: Optional TRT-LLM tokenizer to extract BOS token from. + If provided, will attempt to use the model's actual BOS token. """ + bos_token_id = _get_bos_token_id_from_tokenizer(tokenizer) + # Set TensorRT-LLM default payload - minimal request that completes quickly # The handler expects token_ids, stop_conditions, and sampling_options self.default_payload = { - "token_ids": [1], # Single token for minimal processing + "token_ids": [bos_token_id], "stop_conditions": { "max_tokens": 1, # Generate only 1 token "stop": None, diff --git a/components/backends/trtllm/src/dynamo/trtllm/main.py b/components/backends/trtllm/src/dynamo/trtllm/main.py index 2ecdad232c..6fcf424cc6 100644 --- a/components/backends/trtllm/src/dynamo/trtllm/main.py +++ b/components/backends/trtllm/src/dynamo/trtllm/main.py @@ -318,7 +318,7 @@ async def init(runtime: DistributedRuntime, config: Config): ) # Get health check payload (checks env var and falls back to TensorRT-LLM default) - health_check_payload = TrtllmHealthCheckPayload().to_dict() + health_check_payload = TrtllmHealthCheckPayload(tokenizer=tokenizer).to_dict() if config.publish_events_and_metrics and is_first_worker(config): # Initialize and pass in the publisher to the request handler to