Merged

28 commits
- 3b904c6 [feat] Add OpenTelemetry tracing (Aug 11, 2025)
- 5f15841 [feat] Modular otel_trace (Aug 18, 2025)
- ee83e80 [feat] Add trace to disagg server and add kv_cache info (Aug 21, 2025)
- 4d3a59a [docs] Add openTelemetry integration guide (zhanghaotong, Aug 25, 2025)
- 89c2773 [chores] fix todo (zhanghaotong, Aug 26, 2025)
- c987e06 fix (zhanghaotong, Aug 26, 2025)
- 94126ca [fix] remove opentelemetry package from requirements.txt (zhanghaotong, Aug 26, 2025)
- 5938445 [chores] pre commit (zhanghaotong, Aug 26, 2025)
- a4817fb Merge branch 'main' into otlp-trace (zhanghaotong, Sep 10, 2025)
- c54f281 [feat] use more accurate time correction (zhanghaotong, Sep 16, 2025)
- c23205d Merge branch 'main' into otlp-trace (zhanghaotong, Sep 16, 2025)
- 680ae97 Merge branch 'main' into otlp-trace (zhanghaotong, Sep 22, 2025)
- 4890590 Merge branch 'main' into otlp-trace (zhanghaotong, Sep 25, 2025)
- 95fc55f Merge branch 'main' into otlp-trace (zhanghaotong, Oct 10, 2025)
- fec6530 pre-commit (zhanghaotong, Oct 10, 2025)
- db621c7 use strEnum and rename ObservabilityConfig to OtlpConfig (zhanghaotong, Oct 10, 2025)
- ab0fbed use strenum (zhanghaotong, Oct 13, 2025)
- 9d48874 Merge branch 'main' into otlp-trace (zhanghaotong, Oct 14, 2025)
- 66ca6e5 fix (zhanghaotong, Oct 14, 2025)
- a4c9325 Merge branch 'main' into otlp-trace (zhanghaotong, Oct 15, 2025)
- 4162058 Fix llmapi test (Oct 15, 2025)
- dd8a9e0 Merge branch 'main' into otlp-trace (zhanghaotong, Oct 17, 2025)
- 5c9bb15 add dataclass to MinimalInstances (zhanghaotong, Oct 20, 2025)
- c9bd23f Merge branch 'main' into otlp-trace (zhanghaotong, Oct 20, 2025)
- e570f0b Merge branch 'main' into otlp-trace (zhanghaotong, Oct 21, 2025)
- 0859050 Merge branch 'main' into otlp-trace (zhanghaotong, Oct 22, 2025)
- cfa9293 Merge branch 'main' into otlp-trace (zhanghaotong, Oct 23, 2025)
- 7bacc81 Merge branch 'main' into otlp-trace (zhanghaotong, Oct 27, 2025)
85 changes: 85 additions & 0 deletions examples/opentelemetry/README.md
@@ -0,0 +1,85 @@
# OpenTelemetry Integration Guide

This guide explains how to set up OpenTelemetry tracing in TensorRT-LLM to monitor and debug your LLM inference services.

## Install OpenTelemetry

Install the required OpenTelemetry packages:

```bash
pip install \
  'opentelemetry-sdk' \
  'opentelemetry-api' \
  'opentelemetry-exporter-otlp' \
  'opentelemetry-semantic-conventions-ai'
```
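
As an optional sanity check (an addition to the original steps), you can confirm the OTLP gRPC exporter is importable before going further:

```bash
python -c "from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter; print('OTLP gRPC exporter available')"
```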

## Start Jaeger

You can start Jaeger with Docker:

```bash
docker run --rm --name jaeger \
  -e COLLECTOR_ZIPKIN_HOST_PORT=:9411 \
  -p 6831:6831/udp \
  -p 6832:6832/udp \
  -p 5778:5778 \
  -p 16686:16686 \
  -p 4317:4317 \
  -p 4318:4318 \
  -p 14250:14250 \
  -p 14268:14268 \
  -p 14269:14269 \
  -p 9411:9411 \
  jaegertracing/all-in-one:1.57.0
```

Or run the `jaeger-all-in-one(.exe)` executable from [the binary distribution archives](https://www.jaegertracing.io/download/):

```bash
jaeger-all-in-one --collector.zipkin.host-port=:9411
```
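
Either way, you can optionally confirm that the Jaeger UI is reachable before wiring up the exporter (this check is an addition, not part of the original guide):

```bash
curl -sf http://localhost:16686/ > /dev/null && echo "Jaeger UI is up"
```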

## Set up environment variables and run TensorRT-LLM

Set up the environment variables:

```bash
export JAEGER_IP=$(docker inspect --format '{{ .NetworkSettings.IPAddress }}' jaeger)
export OTEL_EXPORTER_OTLP_TRACES_PROTOCOL=grpc
export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=grpc://$JAEGER_IP:4317
export OTEL_EXPORTER_OTLP_TRACES_INSECURE=true
export OTEL_SERVICE_NAME="trt-server"
```

Then run TensorRT-LLM with OpenTelemetry enabled, and make sure `return_perf_metrics` is set to `true` in the model configuration so that the per-request timing metrics used by the trace spans are available:

```bash
trtllm-serve models/Qwen3-8B/ --otlp_traces_endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"
```
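
If `return_perf_metrics` is not already enabled in your configuration, one possible way to pass it is through `--extra_llm_api_options`. The key name comes from the note above, but its exact placement in the YAML is an assumption; verify it against the LLM API arguments of your TensorRT-LLM version.

```bash
# Hypothetical sketch: enable per-request perf metrics via an extra options file.
cat > extra_llm_options.yaml <<'EOF'
return_perf_metrics: true
EOF

trtllm-serve models/Qwen3-8B/ \
  --otlp_traces_endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT" \
  --extra_llm_api_options extra_llm_options.yaml
```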

## Send requests and find traces in Jaeger

You can send a request to the server and view the traces in [Jaeger UI](http://localhost:16686/).
The traces should be visible under the service name "trt-server".
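
For example, the client sketch below (an illustration, not part of this PR) sends a chat completion request with a W3C `traceparent` header injected by the OpenTelemetry SDK, so the server-side `llm_request` span can be linked to a client-side span. The endpoint path, port, and model name are assumptions based on the `trtllm-serve` command above; adjust them to your deployment.

```python
# Hypothetical client example: propagate a trace context to the server.
# Assumes `requests` is installed and Jaeger's OTLP gRPC port (4317) is
# reachable on localhost, as in the docker command above.
import requests
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.propagate import inject
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor

# Export client-side spans to the same collector (defaults to localhost:4317).
provider = TracerProvider()
provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter(insecure=True)))
trace.set_tracer_provider(provider)
tracer = trace.get_tracer("trtllm-client")

with tracer.start_as_current_span("client_request"):
    headers = {}
    inject(headers)  # adds the W3C traceparent header for the active span
    response = requests.post(
        "http://localhost:8000/v1/chat/completions",  # assumed default port/route
        headers=headers,
        json={
            # Model name is an assumption; check GET /v1/models for the registered name.
            "model": "Qwen3-8B",
            "messages": [{"role": "user", "content": "Hello!"}],
            "max_tokens": 32,
        },
        timeout=60,
    )
    print(response.json())
```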

## Configuration for Disaggregated Serving

For disaggregated serving, the context and generation servers are configured the same way as a standalone server. For the proxy, add an `otlp_config` section to the disaggregated config, as shown below:

```yaml
# disagg_config.yaml
hostname: 127.0.0.1
port: 8000
backend: pytorch
context_servers:
  num_instances: 1
  urls:
    - "127.0.0.1:8001"
generation_servers:
  num_instances: 1
  urls:
    - "127.0.0.2:8002"
otlp_config:
  otlp_traces_endpoint: "grpc://0.0.0.0:4317"
```
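
A possible way to launch the proxy with this file, assuming the context and generation servers were already started with `trtllm-serve` on ports 8001 and 8002 (treat the exact subcommand and flags as an assumption to verify against your TensorRT-LLM version):

```bash
# Sketch: start the disaggregated proxy with tracing enabled via otlp_config.
trtllm-serve disaggregated -c disagg_config.yaml
```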
18 changes: 17 additions & 1 deletion tensorrt_llm/_utils.py
@@ -28,13 +28,14 @@
from enum import EnumMeta
from functools import lru_cache, partial, wraps
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

import numpy as np
import nvtx
from mpi4py import MPI
from mpi4py.util import pkl5
from packaging import version
from typing_extensions import ParamSpec

# isort: off
import torch
@@ -1155,6 +1156,21 @@ def set_prometheus_multiproc_dir() -> object:
f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")


P = ParamSpec("P")


# From: https://stackoverflow.com/a/4104188/2749989
def run_once(f: Callable[P, None]) -> Callable[P, None]:

def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
if not wrapper.has_run: # type: ignore[attr-defined]
wrapper.has_run = True # type: ignore[attr-defined]
return f(*args, **kwargs)

wrapper.has_run = False # type: ignore[attr-defined]
return wrapper


TORCH_PYBIND11_ABI = None


11 changes: 9 additions & 2 deletions tensorrt_llm/commands/serve.py
@@ -91,6 +91,7 @@ def get_llm_args(model: str,
trust_remote_code: bool = False,
reasoning_parser: Optional[str] = None,
fail_fast_on_attention_window_too_large: bool = False,
otlp_traces_endpoint: Optional[str] = None,
enable_chunked_prefill: bool = False,
**llm_args_extra_dict: Any):

@@ -134,6 +135,7 @@ def get_llm_args(model: str,
"reasoning_parser": reasoning_parser,
"fail_fast_on_attention_window_too_large":
fail_fast_on_attention_window_too_large,
"otlp_traces_endpoint": otlp_traces_endpoint,
"enable_chunked_prefill": enable_chunked_prefill,
}

@@ -322,6 +324,10 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
help=
"Exit with runtime error when attention window is too large to fit even a single sequence in the KV cache."
)
@click.option("--otlp_traces_endpoint",
type=str,
default=None,
help="Target URL to which OpenTelemetry traces will be sent.")
@click.option("--disagg_cluster_uri",
type=str,
default=None,
@@ -344,8 +350,8 @@ def serve(
extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
metadata_server_config_file: Optional[str], server_role: Optional[str],
fail_fast_on_attention_window_too_large: bool,
enable_chunked_prefill: bool, disagg_cluster_uri: Optional[str],
media_io_kwargs: Optional[str]):
otlp_traces_endpoint: Optional[str], enable_chunked_prefill: bool,
disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str]):
"""Running an OpenAI API compatible server

MODEL: model name | HF checkpoint path | TensorRT engine path
@@ -371,6 +377,7 @@
reasoning_parser=reasoning_parser,
fail_fast_on_attention_window_too_large=
fail_fast_on_attention_window_too_large,
otlp_traces_endpoint=otlp_traces_endpoint,
enable_chunked_prefill=enable_chunked_prefill)

llm_args_extra_dict = {}
10 changes: 9 additions & 1 deletion tensorrt_llm/executor/base_worker.py
@@ -886,7 +886,15 @@ def _get_metrics_dict(
req_perf_metrics.timing_metrics.first_scheduled_time.
total_seconds(),
RequestEventTiming.LAST_TOKEN_TIME:
req_perf_metrics.timing_metrics.last_token_time.total_seconds()
req_perf_metrics.timing_metrics.last_token_time.total_seconds(),
RequestEventTiming.KV_CACHE_TRANSFER_START:
req_perf_metrics.timing_metrics.kv_cache_transfer_start.
total_seconds(),
RequestEventTiming.KV_CACHE_TRANSFER_END:
req_perf_metrics.timing_metrics.kv_cache_transfer_end.
total_seconds(),
RequestEventTiming.KV_CACHE_SIZE:
req_perf_metrics.timing_metrics.kv_cache_size,
}
return metrics_dict

3 changes: 3 additions & 0 deletions tensorrt_llm/executor/executor.py
Expand Up @@ -5,6 +5,7 @@
import signal
import traceback
from abc import ABC, abstractmethod
from collections.abc import Mapping
from pathlib import Path
from queue import Queue
from typing import (TYPE_CHECKING, AsyncIterable, Dict, Generator, List,
@@ -123,6 +124,7 @@ def generate_async(
streaming: bool = False,
kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None,
disaggregated_params: Optional[DisaggregatedParams] = None,
trace_headers: Optional[Mapping[str, str]] = None,
postproc_params: Optional[PostprocParams] = None,
multimodal_params: Optional[MultimodalParams] = None,
scheduling_params: Optional[SchedulingParams] = None,
@@ -150,6 +152,7 @@
streaming=streaming,
kv_cache_retention_config=kv_cache_retention_config,
disaggregated_params=disaggregated_params,
trace_headers=trace_headers,
multimodal_params=multimodal_params,
scheduling_params=scheduling_params,
cache_salt_id=cache_salt_id,
3 changes: 3 additions & 0 deletions tensorrt_llm/executor/request.py
@@ -1,4 +1,5 @@
import os
from collections.abc import Mapping
from dataclasses import dataclass
from typing import List, Optional, Union

@@ -94,6 +95,7 @@ def __init__(
streaming: bool = False,
kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None,
disaggregated_params: Optional[DisaggregatedParams] = None,
trace_headers: Optional[Mapping[str, str]] = None,
postproc_params: Optional[PostprocParams] = None,
multimodal_params: Optional[MultimodalParams] = None,
scheduling_params: Optional[SchedulingParams] = None,
@@ -123,6 +125,7 @@ def __init__(
self.kv_cache_retention_config = kv_cache_retention_config
self.id: Optional[int] = None
self.disaggregated_params = disaggregated_params
self.trace_headers = trace_headers
self.scheduling_params = scheduling_params
self.cache_salt_id = cache_salt_id
self.arrival_time = arrival_time
113 changes: 111 additions & 2 deletions tensorrt_llm/executor/result.py
@@ -1,6 +1,7 @@
import asyncio
import json
import threading
import time
import weakref
from dataclasses import dataclass, field
from queue import Empty, Queue
@@ -11,6 +12,8 @@
import torch
import torch.nn.functional as F

from tensorrt_llm.llmapi import tracing

try:
import ray
except ModuleNotFoundError:
@@ -268,6 +271,7 @@ def __init__(self,
self.avg_decoded_tokens_per_iter: Optional[float] = None
self._done = False
self.metrics_dict = {}
self.trace_headers: Optional[dict[str, str]] = None

if ray_queue is not None:
if has_event_loop():
@@ -436,6 +440,7 @@ def _handle_sequence(self,
raise ValueError(
f"Unknown finish reason: {finish_reasons[src_idx]}")
self.record_stats(output, req_perf_metrics_dict)
self.do_tracing(output, req_perf_metrics_dict)

@print_traceback_on_error
@nvtx_range_debug("handle_response",
@@ -472,7 +477,7 @@ def _handle_response(self,
self._outputs[0].disaggregated_params = disaggregated_params

if response.metrics:
self.metrics_dict = response.metrics
self.metrics_dict.update(response.metrics)

if response.error:
if self._background_error_handler is not None and (
@@ -570,7 +575,110 @@ def record_stats(self,
stats, len(output.token_ids), self.sampling_params.n > 1)
if processed_metrics_stat:
metrics_stats.update(processed_metrics_stat)
self.metrics_dict = metrics_stats
self.metrics_dict.update(metrics_stats)

def do_tracing(
self,
output: CompletionOutput,
req_perf_metrics_dict: Optional[dict[str, float]] = None,
) -> None:
"""Perform distributed tracing for the generation request.

Args:
output (CompletionOutput): The output of the generation result.
req_perf_metrics_dict (Optional[dict[str, float]]): Request performance metrics. Defaults to None.
"""
if not tracing.global_otlp_tracer():
return

metrics_dict = self.metrics_dict
if not metrics_dict or not req_perf_metrics_dict:
# Insufficient request metrics available; trace generation aborted.
tracing.insufficient_request_metrics_warning()
return

trace_context = tracing.extract_trace_context(self.trace_headers)
sampling_params = self.sampling_params

# Since arrival_time and other timing metrics are based on different time origins,
# we need to apply corrections to align them with absolute timestamps
time_correction = time.time() - time.monotonic()
arrival_time = req_perf_metrics_dict.get(
RequestEventTiming.ARRIVAL_TIME, 0)

with tracing.global_otlp_tracer().start_as_current_span(
"llm_request",
kind=tracing.SpanKind.SERVER,
context=trace_context,
start_time=int((arrival_time + time_correction) * 1e9),
) as span:

def safe_set_attr(span, attr, value):
if value is not None:
span.set_attribute(attr, value)

safe_set_attr(span,
tracing.SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
sampling_params.temperature)
safe_set_attr(span, tracing.SpanAttributes.GEN_AI_REQUEST_TOP_P,
sampling_params.top_p)
safe_set_attr(span, tracing.SpanAttributes.GEN_AI_REQUEST_TOP_K,
sampling_params.top_k)
safe_set_attr(
span,
tracing.SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
sampling_params.max_tokens,
)
safe_set_attr(span, tracing.SpanAttributes.GEN_AI_REQUEST_N,
sampling_params.n)
safe_set_attr(span, tracing.SpanAttributes.GEN_AI_REQUEST_ID,
self.id)
if prompt_token_ids := getattr(self, "prompt_token_ids", None):
safe_set_attr(span,
tracing.SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
len(prompt_token_ids))
safe_set_attr(span,
tracing.SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
output.length)
safe_set_attr(
span, tracing.SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
metrics_dict.get(MetricNames.TTFT, -1))
safe_set_attr(span, tracing.SpanAttributes.GEN_AI_LATENCY_E2E,
metrics_dict.get(MetricNames.E2E, -1))
safe_set_attr(span,
tracing.SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
metrics_dict.get(MetricNames.REQUEST_QUEUE_TIME, -1))
safe_set_attr(
span, tracing.SpanAttributes.GEN_AI_RESPONSE_FINISH_REASONS,
json.dumps([output.finish_reason])
if output.finish_reason else None)
safe_set_attr(
span,
tracing.SpanAttributes.GEN_AI_LATENCY_KV_CACHE_TRANSFER_TIME,
req_perf_metrics_dict.get(
RequestEventTiming.KV_CACHE_TRANSFER_END, 0.0) -
req_perf_metrics_dict.get(
RequestEventTiming.KV_CACHE_TRANSFER_START, 0.0))

if req_perf_metrics_dict.get(
RequestEventTiming.KV_CACHE_TRANSFER_START,
0) and req_perf_metrics_dict.get(
RequestEventTiming.KV_CACHE_TRANSFER_END, 0):
tracing.add_event(
tracing.SpanEvents.KV_CACHE_TRANSFER_START,
timestamp=int((req_perf_metrics_dict.get(
RequestEventTiming.KV_CACHE_TRANSFER_START, 0.0) +
time_correction) * 1e9))
tracing.add_event(
tracing.SpanEvents.KV_CACHE_TRANSFER_END,
attributes={
"kv_cache_size":
req_perf_metrics_dict.get(
RequestEventTiming.KV_CACHE_SIZE, 0)
},
timestamp=int((req_perf_metrics_dict.get(
RequestEventTiming.KV_CACHE_TRANSFER_END, 0.0) +
time_correction) * 1e9))


class DetokenizedGenerationResultBase(GenerationResultBase):
@@ -688,6 +796,7 @@ def __init__(
self.disaggregated_params = disaggregated_params
# minimal sampling params needed for logprob calculation
self._logprob_params = logprob_params
self.trace_headers = generation_request.trace_headers

# for aborting the request
self._executor: Optional[weakref.ReferenceType[