Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 46 additions & 2 deletions components/backends/trtllm/src/dynamo/trtllm/health_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,46 @@
This module defines the default health check payload for TRT-LLM backends.
"""

import logging

from dynamo.health_check import HealthCheckPayload

logger = logging.getLogger(__name__)


def _get_bos_token_id_from_tokenizer(tokenizer) -> int:
    """
    Extract BOS token ID from the TRT-LLM tokenizer if available.

    The lookup is best-effort: any failure while reading or converting the
    token ID is logged at debug level and the default of 1 is used instead.

    Args:
        tokenizer: TRT-LLM tokenizer object, or None when no tokenizer
            is available.

    Returns:
        BOS token ID from the tokenizer, or 1 as fallback

    Note:
        The TransformersTokenizer class wraps a HuggingFace tokenizer.
        While TransformersTokenizer doesn't expose bos_token_id directly,
        the wrapped HuggingFace tokenizer (accessible via tokenizer.tokenizer) does.
    """
    if tokenizer is None:
        return 1

    try:
        # A missing wrapper attribute and a wrapper set to None both end up
        # as "no BOS token found" because getattr(None, ..., None) is None.
        inner_tokenizer = getattr(tokenizer, "tokenizer", None)
        raw_bos_token_id = getattr(inner_tokenizer, "bos_token_id", None)
        if raw_bos_token_id is not None:
            # Convert before logging so the info message is only emitted
            # for a value that is actually going to be returned.
            bos_token_id = int(raw_bos_token_id)
            logger.info(
                "Using model's BOS token ID for health check: %s", bos_token_id
            )
            return bos_token_id
    except Exception as e:
        # Deliberately broad: building the health check payload must not
        # fail because of an unexpected tokenizer implementation.
        logger.debug("Failed to get BOS token from tokenizer: %s", e)

    logger.debug("Using default BOS token ID (1) for health check")
    return 1


class TrtllmHealthCheckPayload(HealthCheckPayload):
"""
Expand All @@ -17,14 +55,20 @@ class TrtllmHealthCheckPayload(HealthCheckPayload):
Provides TRT-LLM defaults and inherits environment override support from base class.
"""

def __init__(self):
def __init__(self, tokenizer=None):
"""
Initialize TRT-LLM health check payload with TRT-LLM-specific defaults.

Args:
tokenizer: Optional TRT-LLM tokenizer to extract BOS token from.
If provided, will attempt to use the model's actual BOS token.
"""
bos_token_id = _get_bos_token_id_from_tokenizer(tokenizer)

# Set TensorRT-LLM default payload - minimal request that completes quickly
# The handler expects token_ids, stop_conditions, and sampling_options
self.default_payload = {
"token_ids": [1], # Single token for minimal processing
"token_ids": [bos_token_id],
"stop_conditions": {
"max_tokens": 1, # Generate only 1 token
"stop": None,
Expand Down
2 changes: 1 addition & 1 deletion components/backends/trtllm/src/dynamo/trtllm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ async def init(runtime: DistributedRuntime, config: Config):
)

# Get health check payload (checks env var and falls back to TensorRT-LLM default)
health_check_payload = TrtllmHealthCheckPayload().to_dict()
health_check_payload = TrtllmHealthCheckPayload(tokenizer=tokenizer).to_dict()

if config.publish_events_and_metrics and is_first_worker(config):
# Initialize and pass in the publisher to the request handler to
Expand Down
Loading