|
2 | 2 |
|
3 | 3 | import json |
4 | 4 | from collections import defaultdict |
5 | | -from typing import Any, Dict, List, NamedTuple |
| 5 | +from typing import Any, Dict, List, NamedTuple, Optional |
| 6 | + |
| 7 | +try: |
| 8 | + import pynvml |
| 9 | +except ImportError: |
| 10 | + pynvml = None |
6 | 11 |
|
7 | 12 | from tensorrt_llm._torch.pyexecutor.model_loader import \ |
8 | 13 | validate_and_set_kv_cache_quant |
|
14 | 19 | from tensorrt_llm.llmapi import KvCacheConfig |
15 | 20 | from tensorrt_llm.logger import Logger |
16 | 21 | from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode |
| 22 | +from tensorrt_llm.profiler import PyNVMLContext |
17 | 23 |
|
18 | 24 |
|
19 | 25 | class PerfItemTuple(NamedTuple): |
@@ -56,9 +62,7 @@ def register_request( |
56 | 62 | record.start_timestamp = timestamp |
57 | 63 |
|
58 | 64 | def register_request_perf_item(self, request_perf_item: PerfItemTuple): |
59 | | - """ |
60 | | - Register request perf items, used exclusively with LLM API. |
61 | | - """ |
| 65 | + """Register request perf items, used exclusively with LLM API.""" |
62 | 66 | record = self.requests[request_perf_item.request_id] |
63 | 67 | record.id = request_perf_item.request_id |
64 | 68 | record.num_input_tokens = request_perf_item.num_input_tokens |
@@ -116,7 +120,8 @@ def generate_statistics_summary(self, max_draft_tokens: int) -> None: |
116 | 120 | output_tokens.append(entry.num_total_output_tokens) |
117 | 121 | total_input_tokens += entry.num_input_tokens |
118 | 122 |
|
119 | | - # For speculative decoding, we need to track the number of draft tokens per request and the number of accepted draft tokens per request |
| 123 | + # For speculative decoding, track the number of draft tokens per request |
| 124 | + # and the number of accepted draft tokens per request. |
120 | 125 | if max_draft_tokens > 0: |
121 | 126 | num_draft_tokens.append(max_draft_tokens * |
122 | 127 | (entry.decode_iteration + 1)) |
@@ -198,6 +203,65 @@ def __init__(self, |
198 | 203 | self.get_max_draft_len()) |
199 | 204 | self.streaming = streaming |
200 | 205 |
|
| 206 | + @staticmethod |
| 207 | + def _query_gpu_info( |
| 208 | + gpu_indices: Optional[List[int]] = None) -> Dict[str, Any]: |
| 209 | + """Best-effort GPU info (name, memory, max memory clock) via pynvml. |
| 210 | +
|
| 211 | + Args: |
| 212 | + gpu_indices: List of GPU indices to query. If None, queries all GPUs. |
| 213 | +
|
| 214 | + Returns a dict with a "gpus" list (per-GPU name, memory, clock) and an optional "note". |
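| | + |
| | + Example return value (illustrative values only): |
| | + {"gpus": [{"index": 0, "name": "NVIDIA H100 80GB HBM3", |
| | + "memory.total": 79.6, "clocks.mem": 2.6}], "note": None} |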
| 215 | + """ |
| 216 | + if pynvml is None: |
| 217 | + return {"gpus": [], "note": "pynvml not available"} |
| 218 | + |
| 219 | + try: |
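| | + # PyNVMLContext (from tensorrt_llm.profiler) is expected to handle |
| | + # NVML init/shutdown for the duration of the query. |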
| 220 | + with PyNVMLContext(): |
| 221 | + device_count = pynvml.nvmlDeviceGetCount() |
| 222 | + gpus: List[Dict[str, Any]] = [] |
| 223 | + |
| 224 | + # Determine which GPUs to query; None means query all devices. |
| 225 | + indices_to_query = (gpu_indices if gpu_indices is not None else range(device_count)) |
| 226 | + |
| 227 | + for idx in indices_to_query: |
| 228 | + try: |
| 229 | + handle = pynvml.nvmlDeviceGetHandleByIndex(idx) |
| 230 | + name = pynvml.nvmlDeviceGetName(handle) |
| 231 | + if isinstance(name, bytes): |
| 232 | + name = name.decode('utf-8') |
| 233 | + |
| 234 | + # Total device memory, converted from bytes to GiB. |
| 235 | + mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) |
| 236 | + mem_total_gb = mem_info.total / (1024 * 1024 * 1024) |
| 237 | + |
| 238 | + # Max memory clock; NVML reports MHz, convert to GHz. |
| 239 | + mem_clock_ghz = pynvml.nvmlDeviceGetMaxClockInfo( |
| 240 | + handle, pynvml.NVML_CLOCK_MEM) / 1000.0 |
| 241 | + gpu_entry: Dict[str, Any] = { |
| 242 | + "index": idx, |
| 243 | + "name": name, |
| 244 | + "memory.total": mem_total_gb, |
| 245 | + "clocks.mem": mem_clock_ghz, |
| 246 | + } |
| 247 | + gpus.append(gpu_entry) |
| 248 | + |
| 249 | + except pynvml.NVMLError as exc: |
| 250 | + # Skip this GPU if we can't get its info, but continue with others |
| 251 | + gpu_entry = { |
| 252 | + "index": idx, |
| 253 | + "name": "Unknown", |
| 254 | + "memory.total": None, |
| 255 | + "clocks.mem": None, |
| 256 | + "error": f"NVML error: {exc}" |
| 257 | + } |
| 258 | + gpus.append(gpu_entry) |
| 259 | + |
| 260 | + return {"gpus": gpus, "note": None} |
| 261 | + |
| 262 | + except Exception as exc: # noqa: BLE001 |
| 263 | + return {"gpus": [], "note": f"pynvml query failed: {exc}"} |
| 264 | + |
201 | 265 | @staticmethod |
202 | 266 | def convert_to_ms(ns: float) -> float: |
203 | 267 | """Convert nanoseconds to milliseconds.""" |
@@ -273,6 +337,10 @@ def get_statistics_dict(self) -> Dict[str, Any]: |
273 | 337 | }, |
274 | 338 | } |
275 | 339 |
|
| 340 | + # Machine / GPU details - only show GPUs used in this benchmark run |
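| | + # NOTE: assumes this run's ranks occupy local GPU indices 0..world_size-1. |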
| 341 | + gpu_indices = list(range(self.rt_cfg.world_config.world_size)) |
| 342 | + stats_dict["machine"] = self._query_gpu_info(gpu_indices) |
| 343 | + |
276 | 344 | # Retrieve KV cache information. |
277 | 345 | kv_cache_config = self.kwargs.get("kv_cache_config", KvCacheConfig()) |
278 | 346 | if isinstance(kv_cache_config, KvCacheConfig): |
@@ -478,6 +546,7 @@ def report_statistics(self) -> None: |
478 | 546 | """ |
479 | 547 | stats_dict = self.get_statistics_dict() |
480 | 548 | engine = stats_dict["engine"] |
| 549 | + machine = stats_dict.get("machine", {"gpus": []}) |
481 | 550 | world_info = stats_dict["world_info"] |
482 | 551 | requests = stats_dict["request_info"] |
483 | 552 | perf = stats_dict["performance"] |
@@ -526,6 +595,25 @@ def report_statistics(self) -> None: |
526 | 595 | if kv_cache_percentage is not None: |
527 | 596 | kv_cache_percentage = f"{kv_cache_percentage * 100.0:.2f}%" |
528 | 597 |
|
| 598 | + machine_info = ( |
| 599 | + "===========================================================\n" |
| 600 | + "= MACHINE DETAILS \n" |
| 601 | + "===========================================================\n") |
| 602 | + gpus = machine.get("gpus", []) |
| 603 | + if not gpus: |
| 604 | + note = machine.get("note", "No GPU info available") |
| 605 | + machine_info += f"{note}\n\n" |
| 606 | + else: |
| 607 | + for gpu in gpus: |
| 608 | + name = gpu.get("name", "Unknown") |
| 609 | + idx = gpu.get("index", 0) |
| 610 | + mem_total_gb = gpu.get("memory.total") |
| 611 | + mem_clock_ghz = gpu.get("clocks.mem") |
| 612 | + |
| | + mem_str = f"{mem_total_gb:.2f} GiB" if mem_total_gb is not None else "N/A" |
| | + clock_str = f"{mem_clock_ghz:.2f} GHz" if mem_clock_ghz is not None else "N/A" |
| 613 | + machine_info += ( |
| 614 | + f"GPU {idx}: {name}, memory {mem_str}, {clock_str}\n" |
| 615 | + ) |
| 616 | + |
529 | 617 | world_info = ( |
530 | 618 | "===========================================================\n" |
531 | 619 | "= WORLD + RUNTIME INFORMATION \n" |
@@ -663,6 +751,7 @@ def report_statistics(self) -> None: |
663 | 751 | ) |
664 | 752 |
|
665 | 753 | logging_info = (f"{backend_info}" |
| 754 | + f"{machine_info}" |
666 | 755 | f"{request_info}" |
667 | 756 | f"{world_info}" |
668 | 757 | f"{perf_header}" |
|