13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
16 | | -import io |
17 | 16 | import sys |
18 | 17 | from argparse import ArgumentParser |
19 | | -from contextlib import redirect_stderr, redirect_stdout |
20 | 18 |
21 | 19 | # isort: off |
22 | 20 | import torch |
37 | 35 | from tensorrt_llm.functional import AllReduceParams, AllReduceStrategy |
38 | 36 |
39 | 37 |
40 | | -def test_mnnvl_available(allreduce_module, rank): |
41 | | - """ |
42 | | - Test if MNNVL is actually available and active in the AllReduce module. |
43 | | -
44 | | - Returns: |
45 | | - tuple: (is_available: bool, details: dict) |
46 | | - """ |
47 | | - import platform |
48 | | - |
49 | | - details = { |
50 | | - 'has_mnnvl_attribute': False, |
51 | | - 'mnnvl_is_not_none': False, |
52 | | - 'architecture': platform.machine(), |
53 | | - 'is_arm64': platform.machine() in ['aarch64', 'arm64'], |
54 | | - } |
55 | | - |
56 | | - # Check if module has mnnvl_allreduce attribute |
57 | | - details['has_mnnvl_attribute'] = hasattr(allreduce_module, |
58 | | - 'mnnvl_allreduce') |
59 | | - |
60 | | - if details['has_mnnvl_attribute']: |
61 | | - details[ |
62 | | - 'mnnvl_is_not_none'] = allreduce_module.mnnvl_allreduce is not None |
63 | | - |
64 | | - # MNNVL is available if both checks pass |
65 | | - is_available = details['has_mnnvl_attribute'] and details[ |
66 | | - 'mnnvl_is_not_none'] |
67 | | - |
68 | | - return is_available, details |
69 | | - |
70 | | - |
71 | | -def verify_strategy_active(allreduce_module, requested_strategy, rank): |
72 | | - """ |
73 | | - Verify that the requested strategy is actually active in the AllReduce module. |
74 | | -
75 | | - Returns: |
76 | | - tuple: (is_active: bool, actual_implementation: str, message: str) |
77 | | - """ |
78 | | - if requested_strategy == AllReduceStrategy.MNNVL: |
79 | | - # Test MNNVL availability |
80 | | - is_available, details = test_mnnvl_available(allreduce_module, rank) |
81 | | - |
82 | | - if is_available: |
83 | | - return True, "MNNVL", "MNNVL AllReduce is active" |
84 | | - else: |
85 | | - # MNNVL requested but not active - likely fell back to C++ plugin |
86 | | - msg = f"MNNVL requested but not active. Details: {details}" |
87 | | - return False, "C++ Plugin (likely NCCL)", msg |
88 | | - |
89 | | - # For other strategies, we can't easily verify, so assume they're working |
90 | | - return True, requested_strategy.name, f"{requested_strategy.name} assumed active" |
91 | | - |
92 | | - |
93 | | -def capture_trtllm_logs(func, *args, **kwargs): |
94 | | - """ |
95 | | - Execute a function and capture both Python stdout/stderr and C++ logs. |
96 | | -
97 | | - Returns: |
98 | | - tuple: (result, captured_output: str) |
99 | | - """ |
100 | | - # Capture Python output |
101 | | - stdout_capture = io.StringIO() |
102 | | - stderr_capture = io.StringIO() |
103 | | - |
104 | | - # Flush any pending output first |
105 | | - sys.stdout.flush() |
106 | | - sys.stderr.flush() |
107 | | - |
108 | | - try: |
109 | | - with redirect_stdout(stdout_capture), redirect_stderr(stderr_capture): |
110 | | - result = func(*args, **kwargs) |
111 | | - |
112 | | - output = stdout_capture.getvalue() + stderr_capture.getvalue() |
113 | | - return result, output |
114 | | - except Exception as e: |
115 | | - output = stdout_capture.getvalue() + stderr_capture.getvalue() |
116 | | - raise RuntimeError( |
117 | | - f"Function failed with: {e}\nCaptured output:\n{output}") from e |
118 | | - |
119 | | - |
120 | 38 | def parse_args(): |
121 | 39 | """Parse command line arguments""" |
122 | 40 | parser = ArgumentParser(description="AllReduce microbenchmark") |
@@ -276,67 +194,7 @@ def run_benchmark(args): |
276 | 194 | False) |
277 | 195 | # NCCL, MIN_LATENCY, MNNVL don't need UB initialization |
278 | 196 |
279 | | - # Create AllReduce module (pass dtype for MNNVL support) and capture logs |
280 | | - def create_allreduce(): |
281 | | - return AllReduce(mapping=mapping, strategy=strategy, dtype=torch_dtype) |
282 | | - |
283 | | - allreduce, init_logs = capture_trtllm_logs(create_allreduce) |
284 | | - |
285 | | - # Check for fallback messages in logs |
286 | | - fallback_indicators = [ |
287 | | - "fallback to AllReduceStrategy", |
288 | | - "Since Peer to Peer not supported", |
289 | | - ] |
290 | | - |
291 | | - detected_fallback = any(indicator in init_logs |
292 | | - for indicator in fallback_indicators) |
293 | | - |
294 | | - if rank == 0 and detected_fallback: |
295 | | - print(f"\nWARNING: Detected fallback during AllReduce initialization:", |
296 | | - flush=True) |
297 | | - print(f"Strategy requested: {strategy.name}", flush=True) |
298 | | - print(f"Captured logs:\n{init_logs}", flush=True) |
299 | | - |
300 | | - # Verify the requested strategy is actually active |
301 | | - is_active, actual_impl, verify_msg = verify_strategy_active( |
302 | | - allreduce, strategy, rank) |
303 | | - |
304 | | - if rank == 0: |
305 | | - print( |
306 | | - f"Strategy verification: requested={strategy.name}, actual={actual_impl}", |
307 | | - flush=True) |
308 | | - print(f" Details: {verify_msg}", flush=True) |
309 | | - |
310 | | - if not is_active: |
311 | | - if rank == 0: |
312 | | - print(f"\nERROR: Requested strategy {strategy.name} is not active!", |
313 | | - file=sys.stderr, |
314 | | - flush=True) |
315 | | - print(f"The AllReduce module fell back to: {actual_impl}", |
316 | | - file=sys.stderr, |
317 | | - flush=True) |
318 | | - |
319 | | - # Provide strategy-specific diagnostic info |
320 | | - if strategy == AllReduceStrategy.MNNVL: |
321 | | - print(f"\nMNNVL troubleshooting:", file=sys.stderr) |
322 | | - print(f" - Ensure you're on ARM64/aarch64 architecture", |
323 | | - file=sys.stderr) |
324 | | - print( |
325 | | - f" - Verify NVLink is available and all links are active", |
326 | | - file=sys.stderr) |
327 | | - print(f" - Check that world_size matches number of GPUs", |
328 | | - file=sys.stderr) |
329 | | - print(f" - Confirm MNNVL dependencies are installed", |
330 | | - file=sys.stderr) |
331 | | - |
332 | | - if init_logs: |
333 | | - print(f"\nInitialization logs:\n{init_logs}", |
334 | | - file=sys.stderr, |
335 | | - flush=True) |
336 | | - |
337 | | - # Ensure all ranks exit together |
338 | | - tllm.mpi_barrier() |
339 | | - sys.exit(1) |
| 197 | + allreduce = AllReduce(mapping=mapping, strategy=strategy, dtype=torch_dtype) |
340 | 198 |
341 | 199 | # Run a test operation to verify strategy works before benchmarking |
342 | 200 | if rank == 0: |
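
Net effect of the change: module construction is now a single direct call, and the MNNVL availability probe that the deleted helpers performed reduces to one attribute check. A minimal sketch using the file's existing names (mapping, strategy, and torch_dtype are built earlier in run_benchmark; this is illustrative, not part of the commit):

    allreduce = AllReduce(mapping=mapping, strategy=strategy, dtype=torch_dtype)

    # Distilled from the removed test_mnnvl_available(): MNNVL is active only
    # when the module exposes a non-None mnnvl_allreduce attribute.
    if strategy == AllReduceStrategy.MNNVL:
        if getattr(allreduce, "mnnvl_allreduce", None) is None:
            print("MNNVL requested but not active; likely fell back to the "
                  "C++ plugin path", file=sys.stderr)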