Skip to content

Commit dcfaf04

Browse files
committed
simplify logic for homogeneous GPUs only
Signed-off-by: Gal Hubara Agam <[email protected]>
1 parent a9a07ca commit dcfaf04

File tree

1 file changed

+36
-53
lines changed

1 file changed

+36
-53
lines changed

tensorrt_llm/bench/dataclasses/reporting.py

Lines changed: 36 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

33
import json
4+
import os
45
from collections import defaultdict
5-
from typing import Any, Dict, List, NamedTuple, Optional
6+
from typing import Any, Dict, List, NamedTuple
67

78
import torch
89

@@ -206,43 +207,33 @@ def __init__(self,
206207
self.streaming = streaming
207208

208209
@staticmethod
209-
def _query_gpu_info(
210-
gpu_indices: Optional[List[int]] = None) -> Dict[str, Any]:
211-
"""Query GPU info and return a dict of minimal fields per GPU.
210+
def _query_gpu_info() -> Dict[str, Any]:
211+
"""Query first GPU info (all GPUs must be identical for TRT-LLM)."""
212+
if not torch.cuda.is_available():
213+
return None
212214

213-
Returns a dict: {"gpus": [{"index", "name", "memory.total", "clocks.mem"}, ...]}
214-
"""
215-
gpus: List[Dict[str, Any]] = []
216-
num_devices = torch.cuda.device_count() if torch.cuda.is_available(
217-
) else 0
218-
indices_to_query = gpu_indices or range(num_devices)
215+
try:
216+
cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip()
217+
physical_idx = int(
218+
cuda_visible.split(",")[0].strip()) if cuda_visible else 0
219219

220-
for idx in indices_to_query:
220+
props = torch.cuda.get_device_properties(physical_idx)
221221
gpu_info = {
222-
"index": idx,
223-
"name": "Unknown",
224-
"memory.total": None,
225-
"clocks.mem": None,
222+
"name":
223+
getattr(props, "name", "Unknown"),
224+
"memory.total":
225+
float(getattr(props, "total_memory", 0.0)) / (1024.0**3),
226+
"clocks.mem":
227+
None,
226228
}
227-
try:
228-
props = torch.cuda.get_device_properties(idx)
229-
gpu_info["name"] = getattr(props, "name", "Unknown")
230-
gpu_info["memory.total"] = float(
231-
getattr(props, "total_memory", 0.0)) / (1024.0**3) # GB
232-
if pynvml:
233-
# For memory clock, we must use pynvml
234-
handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
235-
mem_clock_ghz = pynvml.nvmlDeviceGetMaxClockInfo(
236-
handle, pynvml.NVML_CLOCK_MEM) / 1000.0
237-
gpu_info["clocks.mem"] = mem_clock_ghz
238-
except (RuntimeError, AssertionError):
239-
# Skip this GPU if we can't get its info, but continue with others
240-
continue
241-
gpus.append(gpu_info)
242-
243-
note = None if gpus else (
244-
"No CUDA devices available" if num_devices == 0 else None)
245-
return {"gpus": gpus, "note": note}
229+
if pynvml:
230+
# torch does not report the memory clock, so query it through NVML instead
231+
handle = pynvml.nvmlDeviceGetHandleByIndex(physical_idx)
232+
gpu_info["clocks.mem"] = pynvml.nvmlDeviceGetMaxClockInfo(
233+
handle, pynvml.NVML_CLOCK_MEM) / 1000.0
234+
return gpu_info
235+
except (RuntimeError, AssertionError):
236+
return None
246237

247238
@staticmethod
248239
def convert_to_ms(ns: float) -> float:
@@ -319,9 +310,8 @@ def get_statistics_dict(self) -> Dict[str, Any]:
319310
},
320311
}
321312

322-
# Machine / GPU details - only show GPUs used in this benchmark run
323-
gpu_indices = list(range(self.rt_cfg.mapping["world_size"]))
324-
stats_dict["machine"] = self._query_gpu_info(gpu_indices)
313+
# Machine / GPU details - query only the first GPU (all GPUs are assumed to be identical)
314+
stats_dict["machine"] = self._query_gpu_info()
325315

326316
# Retrieve KV cache information.
327317
kv_cache_config = self.kwargs.get("kv_cache_config", KvCacheConfig())
@@ -528,7 +518,7 @@ def report_statistics(self) -> None:
528518
"""
529519
stats_dict = self.get_statistics_dict()
530520
engine = stats_dict["engine"]
531-
machine = stats_dict.get("machine", {"gpus": []})
521+
machine = stats_dict.get("machine")
532522
world_info = stats_dict["world_info"]
533523
requests = stats_dict["request_info"]
534524
perf = stats_dict["performance"]
@@ -581,22 +571,15 @@ def report_statistics(self) -> None:
581571
"===========================================================\n"
582572
"= MACHINE DETAILS \n"
583573
"===========================================================\n")
584-
gpus = machine.get("gpus", [])
585-
if not gpus:
586-
note = machine.get("note", "No GPU info available")
587-
machine_info += f"{note}\n\n"
574+
if machine is None:
575+
machine_info += "No GPU info available\n\n"
588576
else:
589-
for gpu in gpus:
590-
name = gpu.get("name", "Unknown")
591-
idx = gpu.get("index", 0)
592-
mem_total_str = f"{gpu['memory.total']:.2f} GB" if gpu.get(
593-
"memory.total") is not None else "N/A"
594-
mem_clock_str = f"{gpu['clocks.mem']:.2f} GHz" if gpu.get(
595-
'clocks.mem') is not None else "N/A"
596-
597-
machine_info += (
598-
f"GPU {idx}: {name}, memory {mem_total_str}, {mem_clock_str}\n"
599-
)
577+
name = machine.get("name", "Unknown")
578+
mem_total_str = f"{machine['memory.total']:.2f} GB" if machine.get(
579+
"memory.total") is not None else "N/A"
580+
mem_clock_str = f"{machine['clocks.mem']:.2f} GHz" if machine.get(
581+
'clocks.mem') is not None else "N/A"
582+
machine_info += f"{name}, memory {mem_total_str}, {mem_clock_str}\n\n"
600583

601584
world_info = (
602585
"===========================================================\n"

0 commit comments

Comments
 (0)