Commit 91dd24e

more details on MNNVL
Signed-off-by: Ludwig Schneider <[email protected]>
1 parent e2ae26a

File tree

1 file changed: +1 −143 lines

tests/microbenchmarks/all_reduce.py

Lines changed: 1 addition & 143 deletions
@@ -13,10 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import io
 import sys
 from argparse import ArgumentParser
-from contextlib import redirect_stderr, redirect_stdout

 # isort: off
 import torch
@@ -37,86 +35,6 @@
 from tensorrt_llm.functional import AllReduceParams, AllReduceStrategy


-def test_mnnvl_available(allreduce_module, rank):
-    """
-    Test if MNNVL is actually available and active in the AllReduce module.
-
-    Returns:
-        tuple: (is_available: bool, details: dict)
-    """
-    import platform
-
-    details = {
-        'has_mnnvl_attribute': False,
-        'mnnvl_is_not_none': False,
-        'architecture': platform.machine(),
-        'is_arm64': platform.machine() in ['aarch64', 'arm64'],
-    }
-
-    # Check if module has mnnvl_allreduce attribute
-    details['has_mnnvl_attribute'] = hasattr(allreduce_module,
-                                             'mnnvl_allreduce')
-
-    if details['has_mnnvl_attribute']:
-        details[
-            'mnnvl_is_not_none'] = allreduce_module.mnnvl_allreduce is not None
-
-    # MNNVL is available if both checks pass
-    is_available = details['has_mnnvl_attribute'] and details[
-        'mnnvl_is_not_none']
-
-    return is_available, details
-
-
-def verify_strategy_active(allreduce_module, requested_strategy, rank):
-    """
-    Verify that the requested strategy is actually active in the AllReduce module.
-
-    Returns:
-        tuple: (is_active: bool, actual_implementation: str, message: str)
-    """
-    if requested_strategy == AllReduceStrategy.MNNVL:
-        # Test MNNVL availability
-        is_available, details = test_mnnvl_available(allreduce_module, rank)
-
-        if is_available:
-            return True, "MNNVL", "MNNVL AllReduce is active"
-        else:
-            # MNNVL requested but not active - likely fell back to C++ plugin
-            msg = f"MNNVL requested but not active. Details: {details}"
-            return False, "C++ Plugin (likely NCCL)", msg
-
-    # For other strategies, we can't easily verify, so assume they're working
-    return True, requested_strategy.name, f"{requested_strategy.name} assumed active"
-
-
-def capture_trtllm_logs(func, *args, **kwargs):
-    """
-    Execute a function and capture both Python stdout/stderr and C++ logs.
-
-    Returns:
-        tuple: (result, captured_output: str)
-    """
-    # Capture Python output
-    stdout_capture = io.StringIO()
-    stderr_capture = io.StringIO()
-
-    # Flush any pending output first
-    sys.stdout.flush()
-    sys.stderr.flush()
-
-    try:
-        with redirect_stdout(stdout_capture), redirect_stderr(stderr_capture):
-            result = func(*args, **kwargs)
-
-        output = stdout_capture.getvalue() + stderr_capture.getvalue()
-        return result, output
-    except Exception as e:
-        output = stdout_capture.getvalue() + stderr_capture.getvalue()
-        raise RuntimeError(
-            f"Function failed with: {e}\nCaptured output:\n{output}") from e
-
-
 def parse_args():
     """Parse command line arguments"""
     parser = ArgumentParser(description="AllReduce microbenchmark")
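
Note on the deleted capture helper: contextlib.redirect_stdout and redirect_stderr only intercept writes that go through Python's sys.stdout/sys.stderr objects, so C++-side log lines written straight to the process file descriptors could slip past capture_trtllm_logs despite its docstring. A minimal descriptor-level sketch, standard library only; the helper name capture_fd_output is hypothetical, not part of this script or TensorRT-LLM:

import os
import sys
import tempfile


def capture_fd_output(func, *args, **kwargs):
    # Sketch only: duplicate fd 1 so C/C++ writes that bypass sys.stdout
    # are captured as well; restore the original descriptor afterwards.
    sys.stdout.flush()
    saved_fd = os.dup(1)
    with tempfile.TemporaryFile(mode="w+b") as tmp:
        os.dup2(tmp.fileno(), 1)
        try:
            result = func(*args, **kwargs)
        finally:
            sys.stdout.flush()
            os.dup2(saved_fd, 1)
            os.close(saved_fd)
        tmp.seek(0)
        return result, tmp.read().decode(errors="replace")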
@@ -276,67 +194,7 @@ def run_benchmark(args)
                                             False)
     # NCCL, MIN_LATENCY, MNNVL don't need UB initialization

-    # Create AllReduce module (pass dtype for MNNVL support) and capture logs
-    def create_allreduce():
-        return AllReduce(mapping=mapping, strategy=strategy, dtype=torch_dtype)
-
-    allreduce, init_logs = capture_trtllm_logs(create_allreduce)
-
-    # Check for fallback messages in logs
-    fallback_indicators = [
-        "fallback to AllReduceStrategy",
-        "Since Peer to Peer not supported",
-    ]
-
-    detected_fallback = any(indicator in init_logs
-                            for indicator in fallback_indicators)
-
-    if rank == 0 and detected_fallback:
-        print(f"\nWARNING: Detected fallback during AllReduce initialization:",
-              flush=True)
-        print(f"Strategy requested: {strategy.name}", flush=True)
-        print(f"Captured logs:\n{init_logs}", flush=True)
-
-    # Verify the requested strategy is actually active
-    is_active, actual_impl, verify_msg = verify_strategy_active(
-        allreduce, strategy, rank)
-
-    if rank == 0:
-        print(
-            f"Strategy verification: requested={strategy.name}, actual={actual_impl}",
-            flush=True)
-        print(f"  Details: {verify_msg}", flush=True)
-
-    if not is_active:
-        if rank == 0:
-            print(f"\nERROR: Requested strategy {strategy.name} is not active!",
-                  file=sys.stderr,
-                  flush=True)
-            print(f"The AllReduce module fell back to: {actual_impl}",
-                  file=sys.stderr,
-                  flush=True)
-
-            # Provide strategy-specific diagnostic info
-            if strategy == AllReduceStrategy.MNNVL:
-                print(f"\nMNNVL troubleshooting:", file=sys.stderr)
-                print(f"  - Ensure you're on ARM64/aarch64 architecture",
-                      file=sys.stderr)
-                print(
-                    f"  - Verify NVLink is available and all links are active",
-                    file=sys.stderr)
-                print(f"  - Check that world_size matches number of GPUs",
-                      file=sys.stderr)
-                print(f"  - Confirm MNNVL dependencies are installed",
-                      file=sys.stderr)
-
-            if init_logs:
-                print(f"\nInitialization logs:\n{init_logs}",
-                      file=sys.stderr,
-                      flush=True)
-
-        # Ensure all ranks exit together
-        tllm.mpi_barrier()
-        sys.exit(1)
+    allreduce = AllReduce(mapping=mapping, strategy=strategy, dtype=torch_dtype)

     # Run a test operation to verify strategy works before benchmarking
     if rank == 0:
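
With the verification scaffolding removed, the script now trusts the constructor to pick a working implementation. If a lightweight sanity check is still wanted, the attribute probe the deleted test_mnnvl_available helper relied on can be inlined; a minimal sketch, assuming the AllReduce module still exposes the internal mnnvl_allreduce attribute (an implementation detail seen in the removed code, not a stable public API):

# Sketch only: mnnvl_allreduce is the internal attribute the deleted
# helper inspected; it may change between TensorRT-LLM versions.
if strategy == AllReduceStrategy.MNNVL:
    mnnvl_active = getattr(allreduce, 'mnnvl_allreduce', None) is not None
    if rank == 0 and not mnnvl_active:
        print('MNNVL requested but inactive; a fallback is in use', flush=True)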
