Commit 36d3d8f

[None][chore] Print device info in trtllm-bench report (#8584)

Signed-off-by: Gal Hubara Agam <[email protected]>

1 parent d076aa4 commit 36d3d8f

File tree

1 file changed: +56 -0 lines changed

tensorrt_llm/bench/dataclasses/reporting.py

Lines changed: 56 additions & 0 deletions
@@ -1,9 +1,17 @@
 from __future__ import annotations
 
 import json
+import os
 from collections import defaultdict
 from typing import Any, Dict, List, NamedTuple
 
+import torch
+
+try:
+    import pynvml
+except ImportError:
+    pynvml = None
+
 from tensorrt_llm._torch.pyexecutor.model_loader import \
     validate_and_set_kv_cache_quant
 from tensorrt_llm.bench.dataclasses.configuration import RuntimeConfig
@@ -198,6 +206,35 @@ def __init__(self,
             self.get_max_draft_len())
         self.streaming = streaming
 
+    @staticmethod
+    def _query_gpu_info() -> Dict[str, Any]:
+        """Query first GPU info (all GPUs must be identical for TRT-LLM)."""
+        if not torch.cuda.is_available():
+            return None
+
+        try:
+            cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip()
+            physical_idx = int(
+                cuda_visible.split(",")[0].strip()) if cuda_visible else 0
+
+            props = torch.cuda.get_device_properties(physical_idx)
+            gpu_info = {
+                "name":
+                getattr(props, "name", "Unknown"),
+                "memory.total":
+                float(getattr(props, "total_memory", 0.0)) / (1024.0**3),
+                "clocks.mem":
+                None,
+            }
+            if pynvml:
+                # Memory clock information is not reported by torch, using NVML instead
+                handle = pynvml.nvmlDeviceGetHandleByIndex(physical_idx)
+                gpu_info["clocks.mem"] = pynvml.nvmlDeviceGetMaxClockInfo(
+                    handle, pynvml.NVML_CLOCK_MEM) / 1000.0
+            return gpu_info
+        except (RuntimeError, AssertionError):
+            return None
+
     @staticmethod
     def convert_to_ms(ns: float) -> float:
         """Convert nanoseconds to milliseconds."""
@@ -273,6 +310,9 @@ def get_statistics_dict(self) -> Dict[str, Any]:
             },
         }
 
+        # Machine / GPU details - query only first GPU (all GPUs must be identical)
+        stats_dict["machine"] = self._query_gpu_info()
+
         # Retrieve KV cache information.
         kv_cache_config = self.kwargs.get("kv_cache_config", KvCacheConfig())
         if isinstance(kv_cache_config, KvCacheConfig):
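For context, the new "machine" entry is simply whatever _query_gpu_info() returns: None when no GPU is visible, otherwise a three-key dict. The values below are purely illustrative (a hypothetical single-GPU box), not output from the commit:

# Illustrative only: hypothetical values for a single-GPU machine.
stats_dict["machine"] = {
    "name": "NVIDIA H100 80GB HBM3",  # torch.cuda.get_device_properties().name
    "memory.total": 79.65,            # GiB, derived from props.total_memory
    "clocks.mem": 2.62,               # GHz, max memory clock via NVML; None if pynvml is absent
}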
@@ -478,6 +518,7 @@ def report_statistics(self) -> None:
         """
         stats_dict = self.get_statistics_dict()
         engine = stats_dict["engine"]
+        machine = stats_dict.get("machine")
         world_info = stats_dict["world_info"]
         requests = stats_dict["request_info"]
         perf = stats_dict["performance"]
@@ -526,6 +567,20 @@ def report_statistics(self) -> None:
         if kv_cache_percentage is not None:
             kv_cache_percentage = f"{kv_cache_percentage * 100.0:.2f}%"
 
+        machine_info = (
+            "===========================================================\n"
+            "= MACHINE DETAILS \n"
+            "===========================================================\n")
+        if machine is None:
+            machine_info += "No GPU info available\n\n"
+        else:
+            name = machine.get("name", "Unknown")
+            mem_total_str = f"{machine['memory.total']:.2f} GB" if machine.get(
+                "memory.total") is not None else "N/A"
+            mem_clock_str = f"{machine['clocks.mem']:.2f} GHz" if machine.get(
+                'clocks.mem') is not None else "N/A"
+            machine_info += f"{name}, memory {mem_total_str}, {mem_clock_str}\n\n"
+
         world_info = (
             "===========================================================\n"
             "= WORLD + RUNTIME INFORMATION \n"
@@ -663,6 +718,7 @@ def report_statistics(self) -> None:
         )
 
         logging_info = (f"{backend_info}"
+                        f"{machine_info}"
                         f"{request_info}"
                         f"{world_info}"
                         f"{perf_header}"
