Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions tensorrt_llm/bench/dataclasses/reporting.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
from __future__ import annotations

import json
import os
from collections import defaultdict
from typing import Any, Dict, List, NamedTuple

import torch

try:
import pynvml
except ImportError:
pynvml = None

from tensorrt_llm._torch.pyexecutor.model_loader import \
validate_and_set_kv_cache_quant
from tensorrt_llm.bench.dataclasses.configuration import RuntimeConfig
Expand Down Expand Up @@ -198,6 +206,35 @@ def __init__(self,
self.get_max_draft_len())
self.streaming = streaming

@staticmethod
def _query_gpu_info() -> Dict[str, Any] | None:
    """Query details of the first visible GPU.

    Only one GPU is inspected because all GPUs are expected to be identical
    for TRT-LLM.

    Returns:
        ``None`` when CUDA is unavailable or the device properties cannot be
        read. Otherwise a dict with the keys:

        * ``"name"``: device name reported by torch.
        * ``"memory.total"``: total device memory in GiB
          (``total_memory`` bytes divided by ``1024**3``).
        * ``"clocks.mem"``: maximum memory clock in GHz (the NVML value
          divided by 1000), or ``None`` when pynvml is not installed or the
          NVML query fails.
    """
    if not torch.cuda.is_available():
        return None

    try:
        # torch numbers devices *after* CUDA_VISIBLE_DEVICES has been applied,
        # so the first visible GPU is always logical device 0. Passing the
        # physical index here would fail whenever the mask does not start
        # with device 0 (e.g. CUDA_VISIBLE_DEVICES=3).
        props = torch.cuda.get_device_properties(0)
    except (RuntimeError, AssertionError):
        return None

    gpu_info = {
        "name": getattr(props, "name", "Unknown"),
        "memory.total":
        float(getattr(props, "total_memory", 0.0)) / (1024.0**3),
        "clocks.mem": None,
    }

    if not pynvml:
        return gpu_info

    # NVML ignores CUDA_VISIBLE_DEVICES and enumerates physical devices, so
    # translate "first visible GPU" back to its physical index.
    first_visible = os.environ.get("CUDA_VISIBLE_DEVICES",
                                   "").split(",")[0].strip()
    try:
        physical_idx = int(first_visible) if first_visible else 0
    except ValueError:
        # UUID-style entries ("GPU-...", "MIG-...") cannot be mapped to an
        # NVML index here; report the torch-derived fields only.
        return gpu_info

    # Memory clock information is not reported by torch, using NVML instead.
    # This is best effort: any NVML failure leaves "clocks.mem" as None
    # rather than aborting the whole report.
    try:
        pynvml.nvmlInit()
    except pynvml.NVMLError:
        return gpu_info
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_idx)
        gpu_info["clocks.mem"] = pynvml.nvmlDeviceGetMaxClockInfo(
            handle, pynvml.NVML_CLOCK_MEM) / 1000.0
    except pynvml.NVMLError:
        gpu_info["clocks.mem"] = None
    finally:
        # Balance the nvmlInit() call above.
        pynvml.nvmlShutdown()
    return gpu_info

@staticmethod
def convert_to_ms(ns: float) -> float:
"""Convert nanoseconds to milliseconds."""
Expand Down Expand Up @@ -273,6 +310,9 @@ def get_statistics_dict(self) -> Dict[str, Any]:
},
}

# Machine / GPU details - query only first GPU (all GPUs must be identical)
stats_dict["machine"] = self._query_gpu_info()

# Retrieve KV cache information.
kv_cache_config = self.kwargs.get("kv_cache_config", KvCacheConfig())
if isinstance(kv_cache_config, KvCacheConfig):
Expand Down Expand Up @@ -478,6 +518,7 @@ def report_statistics(self) -> None:
"""
stats_dict = self.get_statistics_dict()
engine = stats_dict["engine"]
machine = stats_dict.get("machine")
world_info = stats_dict["world_info"]
requests = stats_dict["request_info"]
perf = stats_dict["performance"]
Expand Down Expand Up @@ -526,6 +567,20 @@ def report_statistics(self) -> None:
if kv_cache_percentage is not None:
kv_cache_percentage = f"{kv_cache_percentage * 100.0:.2f}%"

machine_info = (
"===========================================================\n"
"= MACHINE DETAILS \n"
"===========================================================\n")
if machine is None:
machine_info += "No GPU info available\n\n"
else:
name = machine.get("name", "Unknown")
mem_total_str = f"{machine['memory.total']:.2f} GB" if machine.get(
"memory.total") is not None else "N/A"
mem_clock_str = f"{machine['clocks.mem']:.2f} GHz" if machine.get(
'clocks.mem') is not None else "N/A"
machine_info += f"{name}, memory {mem_total_str}, {mem_clock_str}\n\n"

world_info = (
"===========================================================\n"
"= WORLD + RUNTIME INFORMATION \n"
Expand Down Expand Up @@ -663,6 +718,7 @@ def report_statistics(self) -> None:
)

logging_info = (f"{backend_info}"
f"{machine_info}"
f"{request_info}"
f"{world_info}"
f"{perf_header}"
Expand Down