Skip to content

Allreduce auto backend improvements#2239

Merged
yzh119 merged 11 commits into
flashinfer-ai:mainfrom
nvmbreughe:unified_ar_auto
Dec 20, 2025
Merged

Allreduce auto backend improvements#2239
yzh119 merged 11 commits into
flashinfer-ai:mainfrom
nvmbreughe:unified_ar_auto

Conversation

@nvmbreughe
Copy link
Copy Markdown
Contributor

@nvmbreughe nvmbreughe commented Dec 18, 2025

📌 Description

  • Internal changes:
    • trtllm and mnnvl now use the same underlying workspace (SymmDeviceMemory class)
    • Both will use fabric if supported, or PosixFD if not
    • trtllm should support multi-node as well (untested), limited to 16 GPUs
    • Removed the dependency on mpi for multi-node testing (test_trtllm_mnnvl_allreduce.py)
  • API changes:
    • Removed the "topology" parameter from tcreate_allreduce_fusion_workspace
      • Both trtllm and mnnvl kernels can be used for
    • Added force_oneshot to create_allreduce_fusion_workspace
      • This ensures a large enough workspace can be made in case users want to use one-shot for even larger problem sizes than heuristics suggest.
      • Supported values: True/False/default None, the latter two just follow heuristics

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

Updated tests:
mpirun -np 4 pytest tests/comm/test_allreduce_unified_api.py
python tests/comm/test_trtllm_allreduce_fusion.py
srun -A coreai_libraries_cudnn -N4 --container-image=<flashinfer_image> -J --mpi=pmix -- bash -c 'hostname && cd <path_to_flashinfer> && pip install -e . && python -m pytest tests/comm/test_trtllm_mnnvl_allreduce.py'

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added a unified backend option for distributed coordination and a fabric-aware symmetric-device-memory IPC path.
  • Improvements

    • Simplified distributed initialization (removed topology/process-group requirements), oneshot buffer sizing, and clearer backend selection errors.
  • Refactor

    • Workspace API surfaces updated to include symmetric-device-memory handles and backend-driven workflows.
  • Bug Fixes

    • More robust cleanup: explicit release of IPC and memory handles and safer finalization.
  • Tests

    • Tests updated to use the unified backend-based workspace creation and safer distributed teardown.

✏️ Tip: You can customize this high-level summary in your review settings.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Dec 18, 2025

Warning

Rate limit exceeded

@nvmbreughe has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 11 minutes and 53 seconds before requesting another review.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.

Please see our FAQ for further information.

📥 Commits

Reviewing files that changed from the base of the PR and between f6b0edc and 1cb0870.

📒 Files selected for processing (2)
  • tests/comm/conftest.py (1 hunks)
  • tests/comm/test_trtllm_mnnvl_allreduce.py (13 hunks)

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Threads a CommBackend abstraction and SymmDeviceMemory through TRT-LLM and MNNVL all-reduce workspace and IPC creation, expands workspace tuples to include mem_handles, removes topology/process_group plumbing, updates oneshot heuristics/sizing, and adapts tests to use TorchDistBackend and unified APIs.

Changes

Cohort / File(s) Summary
Comm backend & memory primitives
flashinfer/comm/mnnvl.py
Add TorchDistBackend (Get_rank/Get_size/allgather/bcast/barrier/Split); rename McastDeviceMemorySymmDeviceMemory (new flags enable_multicast, allocate_signal_pads), update allocation/cleanup and fabric checks, adapt McastGPUBuffer to use SymmDeviceMemory, add is_mnnvl_fabric_supported().
AllReduce workspace core
flashinfer/comm/allreduce.py
Thread comm_backend into workspace constructors and create_allreduce_fusion_workspace; expand internal workspace tuple to include List[SymmDeviceMemory] mem_handles and adjust destroy path; remove topology/process_group params; update heuristics and oneshot sizing/signature (force_oneshot_support, buffer sizing).
TRT-LLM IPC workspace
flashinfer/comm/trtllm_ar.py
trtllm_create_ipc_workspace_for_all_reduce_fusion accepts comm_backend and use_symm_dev_mem; supports SymmDeviceMemory-backed buffers, returns mem_handles when used; uses comm_backend.barrier() when appropriate; expands return tuple variants.
TRT-LLM MNNVL wrapper
flashinfer/comm/trtllm_mnnvl_ar.py
Remove explicit is_multi_node argument when constructing GPU buffers; rely on updated McastGPUBuffer/SymmDeviceMemory defaults.
Tests — unified API & dist init
tests/comm/test_allreduce_unified_api.py, tests/comm/test_trtllm_allreduce_fusion.py, tests/comm/test_trtllm_mnnvl_allreduce.py
Replace topology/process_group with comm_backend=TorchDistBackend() and call init_torch_distributed_from_mpi() consistently; pass comm_backend into workspace creation/handle-transfer; adapt synchronization and object broadcast/allgather uses.
Test harness & helpers
tests/comm/conftest.py, tests/test_helpers/comm.py
Add pytest session cleanup to destroy process group; refactor distributed init helpers to derive rank/world_size/local from env (support SLURM/torchrun/MPI), compute master address, and use tcp:// init_method for dist.init_process_group.
Unit test updates
tests/comm/*.py, tests/test_helpers/comm.py
Update test flows: fixed seed, unified workspace creation signatures and return-shape handling, TorchDistBackend integration, and final teardown adjustments.

Sequence Diagram(s)

sequenceDiagram
    participant Test as Test/Client
    participant API as create_allreduce_fusion_workspace
    participant Comm as CommBackend (TorchDistBackend)
    participant TRT as TRTLLM IPC creator
    participant MNN as MNNVL workspace creator
    participant MEM as SymmDeviceMemory

    Test->>API: call create_allreduce_fusion_workspace(comm_backend, force_oneshot?)
    API->>Comm: Get_size / Get_rank
    alt TRT-LLM path
        API->>TRT: trtllm_create_ipc_workspace_for_all_reduce_fusion(comm_backend, use_symm_dev_mem)
        alt use_symm_dev_mem == true
            TRT->>Comm: barrier()
            TRT->>MEM: allocate SymmDeviceMemory (collect mem_handles)
            TRT-->>API: return ipc_handles, workspace_tensor, mem_handles, metadata
        else
            TRT-->>API: return ipc_handles, workspace_tensor[, metadata]
        end
    else MNNVL path
        API->>MNN: create MNNVL workspace (buffer_size_in_bytes?, comm_backend)
        MNN->>Comm: exchange handles / barrier as needed
        MNN-->>API: return workspace tuple (includes mem_handles when used)
    end
    API-->>Test: unified workspace (includes mem_handles & metadata as applicable)
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Inspect SymmDeviceMemory lifecycle, destructor guards, and CUDA-context checks in flashinfer/comm/mnnvl.py.
  • Verify workspace tuple changes and all callsites in flashinfer/comm/allreduce.py, flashinfer/comm/trtllm_ar.py, and callers handle new return shapes.
  • Validate TorchDistBackend semantics (allgather/bcast/barrier/Split) match prior process-group behavior used by tests and workspace creation.
  • Check tests use the new init helper and that pytest session cleanup avoids teardown warnings.

Possibly related PRs

Suggested reviewers

  • djmmoss
  • cyx-6
  • wenscarl
  • yzh119
  • jiahanc

Poem

🐰 I nudged the buffers, tapped the thread,
CommBackend counted ranks ahead,
SymmDeviceMemory hummed in line,
TorchDistBackend chimed the sync-time,
Fusion hops forward — carrot-coded, fine. 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 61.54% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title 'Allreduce auto backend improvements' is partially related to the main changes; it describes a real aspect (backend improvements) but doesn't fully capture the key API changes (topology removal, force_oneshot addition) or the internal restructuring (SymmDeviceMemory integration).
Description check ✅ Passed The PR description covers the main changes, related internal improvements, API modifications, and test updates, but is missing explicit links to related issues and has incomplete pre-commit checkbox status.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @nvmbreughe, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the allreduce auto-backend logic, aiming for greater flexibility and efficiency in distributed communication. It unifies the underlying memory management for different backends, adapts to available hardware capabilities for handle transfer, and simplifies the API for workspace creation. These changes pave the way for broader multi-node support and more explicit control over communication strategies.

Highlights

  • Unified Workspace Management: Both the TRT-LLM and MNNVL allreduce backends now utilize a shared underlying workspace class, SymmDeviceMemory, streamlining memory management and handle transfer mechanisms.
  • Flexible Communication Backend: The system now dynamically selects between NVLink Fabric and PosixFD for handle transfer, prioritizing Fabric when supported. A new TorchDistBackend has been introduced to facilitate communication using torch.distributed.
  • Simplified API for Workspace Creation: The create_allreduce_fusion_workspace function has been simplified by removing the topology parameter and replacing process_group with a more general comm_backend parameter. It also introduces a use_oneshot parameter for explicit control over workspace allocation strategy.
  • TRT-LLM Multi-Node Support: The TRT-LLM backend is now designed to support multi-node configurations, though currently limited to 16 GPUs and untested in this setup.
  • Improved Workspace Destruction: The TRTLLMAllReduceFusionWorkspace now explicitly dels its internal attributes during destruction, ensuring proper cleanup of resources.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces significant improvements by unifying the trtllm and mnnvl all-reduce backends under a common CommBackend abstraction and SymmDeviceMemory workspace. The API is simplified by removing the topology parameter and adding use_oneshot. The changes are well-structured and align with the PR description. I've identified a few minor issues, mainly concerning documentation consistency and a potential for test flakiness due to a removed random seed. Overall, this is a valuable enhancement.

Comment thread flashinfer/comm/allreduce.py Outdated
@@ -190,8 +195,8 @@ def _trtllm_workspace_check(
- Single-node topology (multi-node not supported)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The docstring for _trtllm_workspace_check appears to be outdated. It states that the trtllm backend requires a single-node topology, but the implementation has been updated to support multi-node configurations with a limit of 16 ranks. Please update the docstring to reflect this change.

Suggested change
- Single-node topology (multi-node not supported)
- Up to 16 ranks supported

Comment thread flashinfer/comm/allreduce.py Outdated
process_group: Optional["torch.distributed.ProcessGroup"] = None,
gpus_per_node: int = None,
comm_backend: Optional[CommBackend] = None,
use_oneshot: bool = False,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The function signature for use_oneshot is bool = False, which doesn't allow None as a value. However, the docstring (lines 317-320) states that If None, uses internal heuristics. This is contradictory. To align with the documentation and the pull request description, consider changing the signature to allow None.

Suggested change
use_oneshot: bool = False,
use_oneshot: Optional[bool] = None,

use_oneshot: Use oneshot strategy vs twoshot
If None, uses internal heuristics.
Note that the MNNVL backend needs to be initialized with a sufficiently large workspace if one_shot is used.
Note: when explicitly set to True, the MNNVL backend needs to be initialized with a sufficiently large workspace.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This note is a bit vague. To improve clarity, consider explicitly mentioning the use_oneshot parameter it refers to.

Suggested change
Note: when explicitly set to True, the MNNVL backend needs to be initialized with a sufficiently large workspace.
Note: when `use_oneshot` is explicitly set to True, the MNNVL backend needs to be initialized with a sufficiently large workspace.

Comment on lines 532 to +533
- create_metadata: if True, return metadata dict as third element (default: False).
- comm_backend: the communication backend to use.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The parameters create_metadata and comm_backend are duplicated in the docstring. Please remove the duplicate entries.


eps = 1e-5
torch.manual_seed(42 + rank)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The call to torch.manual_seed was removed. This can make the tests non-deterministic and potentially flaky. It's good practice to seed the random number generators for reproducibility. Please consider restoring it.

Suggested change
torch.manual_seed(42 + rank)

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/comm/allreduce.py (1)

287-322: Type annotation mismatch with docstring.

The use_oneshot parameter is typed as bool (line 287) but the docstring (lines 318-322) states it can be None. If None is a valid value (to follow heuristics), the type should be Optional[bool].

🔎 Apply this diff to fix the type annotation:
-    use_oneshot: bool = False,
+    use_oneshot: Optional[bool] = None,
🧹 Nitpick comments (3)
flashinfer/comm/trtllm_ar.py (1)

529-534: Duplicated parameter documentation.

The parameters group, create_metadata, and comm_backend are documented twice in the docstring. Remove the duplicates for clarity.

🔎 Apply this diff to remove duplicate documentation:
     - use_fp32_lamport: if True, we will use fp32 datatype in allreduce fusion.
     - comm_backend: the communication backend to use.
     - create_metadata: if True, return metadata dict as third element (default: False).
-    - group: the process group to use.
-    - create_metadata: if True, return metadata dict as third element (default: False).
-    - comm_backend: the communication backend to use.
     - use_symm_dev_mem: if True, we will use symmetric device memory for the workspace.
+    - group: the process group to use.
flashinfer/comm/allreduce.py (1)

205-218: Consider adding meaningful checks for MNNVL backend.

The _mnnvl_workspace_check function always returns True, which means any configuration will pass the MNNVL check. While this may be intentional to allow MNNVL as a fallback, consider adding checks for:

  • Multicast support on the device
  • Minimum world size requirements
  • Any hardware capability constraints

If these checks are deferred to runtime, document this behavior.

flashinfer/comm/mnnvl.py (1)

235-312: TorchDistBackend Split() creates process groups without cleanup tracking.

The Split() method creates new process groups via self._dist.new_group() (line 310) but there's no mechanism to track or clean up these groups. In PyTorch, process groups persist until explicitly destroyed or the process ends.

If Split() is called multiple times (e.g., in a loop or iterative algorithm), this could lead to resource accumulation. Consider:

  1. Tracking created groups for cleanup
  2. Documenting that callers are responsible for group lifecycle
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 454e7b2 and 5285e32.

📒 Files selected for processing (6)
  • flashinfer/comm/allreduce.py (14 hunks)
  • flashinfer/comm/mnnvl.py (7 hunks)
  • flashinfer/comm/trtllm_ar.py (7 hunks)
  • flashinfer/comm/trtllm_mnnvl_ar.py (0 hunks)
  • tests/comm/test_allreduce_unified_api.py (3 hunks)
  • tests/comm/test_trtllm_allreduce_fusion.py (3 hunks)
💤 Files with no reviewable changes (1)
  • flashinfer/comm/trtllm_mnnvl_ar.py
🧰 Additional context used
🧬 Code graph analysis (3)
tests/comm/test_trtllm_allreduce_fusion.py (1)
flashinfer/comm/mnnvl.py (1)
  • TorchDistBackend (235-312)
flashinfer/comm/allreduce.py (3)
flashinfer/comm/mnnvl.py (2)
  • CommBackend (152-171)
  • SymmDeviceMemory (873-1287)
flashinfer/comm/trtllm_ar.py (1)
  • AllReduceFusionPattern (65-78)
flashinfer/comm/trtllm_mnnvl_ar.py (3)
  • MNNVLAllReduceFusionWorkspace (51-240)
  • MNNVLAllreduceFusionStrategy (31-44)
  • get_required_buffer_size_bytes (199-224)
flashinfer/comm/mnnvl.py (3)
tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py (6)
  • Get_rank (23-24)
  • Get_size (26-27)
  • allgather (29-46)
  • bcast (48-58)
  • barrier (60-64)
  • Split (66-67)
flashinfer/cuda_utils.py (1)
  • checkCudaErrors (51-61)
flashinfer/utils.py (1)
  • round_up (636-638)
🪛 Ruff (0.14.8)
flashinfer/comm/trtllm_ar.py

567-567: Avoid specifying long messages outside the exception class

(TRY003)

flashinfer/comm/allreduce.py

386-386: Avoid specifying long messages outside the exception class

(TRY003)

flashinfer/comm/mnnvl.py

248-251: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (10)
tests/comm/test_trtllm_allreduce_fusion.py (2)

12-13: LGTM!

The import of TorchDistBackend is correctly placed and the usage at line 88 is appropriate since torch.distributed is initialized earlier in the worker function (lines 33-38).


88-89: LGTM!

The TorchDistBackend() is correctly instantiated after dist.init_process_group() has been called, satisfying the initialization requirement of the backend.

tests/comm/test_allreduce_unified_api.py (1)

209-220: LGTM!

The test correctly initializes torch.distributed via init_torch_distributed_from_mpi() before creating the TorchDistBackend() instance, ensuring the distributed backend is properly set up.

flashinfer/comm/trtllm_ar.py (2)

562-567: LGTM on validation logic.

The validation that use_symm_dev_mem requires create_metadata=True is appropriate since the new return tuple includes mem_handles which needs the metadata path.


599-609: Device index parameter should align with rank assignment strategy.

The code directly uses tp_rank as the device index passed to SymmDeviceMemory, assuming it corresponds to the visible CUDA device. This is valid for standard distributed setups where ranks are assigned 1:1 to devices, but the assumption warrants documentation. Ensure that tp_rank represents the correct logical device index for your deployment configuration—particularly if using custom rank-to-device mapping or CUDA_VISIBLE_DEVICES with non-identity orderings.

flashinfer/comm/allreduce.py (2)

166-175: Verify resource cleanup in destroy() method.

The destroy() method uses del to remove references to workspace components. However, del only removes the reference binding—it doesn't guarantee immediate resource deallocation. The actual cleanup depends on SymmDeviceMemory.__del__ being called when the reference count drops to zero.

This should work correctly if there are no other references to mem_handles, but consider whether explicit cleanup methods on the SymmDeviceMemory objects would be safer, especially for deterministic resource release in long-running processes.


420-439: LGTM on buffer size calculation.

The oneshot buffer size calculation correctly uses MNNVLAllReduceFusionWorkspace.get_required_buffer_size_bytes with the ONESHOT strategy when use_oneshot is explicitly True.

flashinfer/comm/mnnvl.py (3)

872-957: LGTM on SymmDeviceMemory refactor.

The rename from McastDeviceMemory to SymmDeviceMemory and the addition of enable_multicast and allocate_signal_pads parameters provides good flexibility for the TRTLLM use case where these features may not be needed.


1080-1092: LGTM on conditional multicast setup.

The enable_multicast flag correctly gates the multicast object creation, which is appropriate for scenarios where only unicast buffers are needed (e.g., TRTLLM backend).


1290-1323: LGTM on McastGPUBuffer update.

The update to use SymmDeviceMemory instead of McastDeviceMemory aligns with the class rename and maintains backward compatibility through the wrapper interface.

Comment thread flashinfer/comm/mnnvl.py
Comment on lines +847 to +869
def is_mnnvl_fabric_supported(device_idx: int) -> bool:
fabric_handle_supported = checkCudaErrors(
cuda.cuDeviceGetAttribute(
cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED,
device_idx,
)
)
if fabric_handle_supported == 0:
return False

pynvml.nvmlInit()
try:
handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
fabric_info = pynvml.c_nvmlGpuFabricInfoV_t()
pynvml.nvmlDeviceGetGpuFabricInfoV(handle, ctypes.byref(fabric_info))
if (
fabric_info.state >= pynvml.NVML_GPU_FABRIC_STATE_COMPLETED
and fabric_info.clusterUuid[0] != 0
):
return True
return False
finally:
pynvml.nvmlShutdown()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Potential nvml state conflict in is_mnnvl_fabric_supported().

This function calls pynvml.nvmlInit() and pynvml.nvmlShutdown() unconditionally. If the caller has already initialized pynvml (e.g., in MnnvlMemory.initialize()), calling nvmlShutdown() here will terminate the session unexpectedly, causing subsequent pynvml calls to fail.

Consider checking initialization state or using a reference-counted approach.

🔎 Apply this diff to handle pynvml state safely:
 def is_mnnvl_fabric_supported(device_idx: int) -> bool:
     fabric_handle_supported = checkCudaErrors(
         cuda.cuDeviceGetAttribute(
             cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED,
             device_idx,
         )
     )
     if fabric_handle_supported == 0:
         return False

-    pynvml.nvmlInit()
+    # Check if already initialized to avoid disrupting existing pynvml sessions
+    try:
+        pynvml.nvmlDeviceGetCount()
+        already_initialized = True
+    except pynvml.NVMLError_Uninitialized:
+        pynvml.nvmlInit()
+        already_initialized = False
+
     try:
         handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
         fabric_info = pynvml.c_nvmlGpuFabricInfoV_t()
         pynvml.nvmlDeviceGetGpuFabricInfoV(handle, ctypes.byref(fabric_info))
         if (
             fabric_info.state >= pynvml.NVML_GPU_FABRIC_STATE_COMPLETED
             and fabric_info.clusterUuid[0] != 0
         ):
             return True
         return False
     finally:
-        pynvml.nvmlShutdown()
+        if not already_initialized:
+            pynvml.nvmlShutdown()
🤖 Prompt for AI Agents
flashinfer/comm/mnnvl.py around lines 847-869: the function unconditionally
calls pynvml.nvmlInit() and pynvml.nvmlShutdown(), which can shut down a
caller's existing NVML session; change it to try to initialize NVML but detect
if it was already initialized and only call nvmlShutdown() when this function
actually performed the init. Concretely, call pynvml.nvmlInit() inside a try
block, set a local flag like initialized_here = True only if init succeeded,
catch the specific NVMLError indicating "already initialized" (treat that as
initialized_here = False) and re-raise other errors, then in the finally block
call pynvml.nvmlShutdown() only if initialized_here is True so you don't
terminate a shared NVML session.

@nvmbreughe nvmbreughe marked this pull request as draft December 18, 2025 18:34
@nvmbreughe nvmbreughe marked this pull request as ready for review December 18, 2025 18:59
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (3)
flashinfer/comm/allreduce.py (2)

191-202: Update docstring to reflect multi-node support limitation.

The docstring states "Single-node topology (multi-node not supported)" but the implementation only enforces world_size <= 16. Per the PR description, TRTLLM is intended to support multi-node (untested) with a 16 GPU limit. Update the docstring to accurately reflect this constraint.

🔎 Suggested docstring fix:
     """
     Check if trtllm backend CAN be used for workspace creation.
 
     Hard requirements:
-    - Single-node topology (multi-node not supported)
+    - Up to 16 ranks supported
 
     """

513-513: Clarify which parameter is referenced.

The note states "when explicitly set to True" without naming the parameter. Consider explicitly mentioning use_oneshot for clarity, as suggested in the previous review.

🔎 Suggested improvement:
-                     Note: when explicitly set to True, the MNNVL backend needs to be initialized with a sufficiently large workspace.
+                     Note: when `use_oneshot` is explicitly set to True, the MNNVL backend needs to be initialized with a sufficiently large workspace.
flashinfer/comm/trtllm_ar.py (1)

529-534: Remove duplicate parameter documentation.

The docstring contains duplicate entries for create_metadata (lines 530, 532) and comm_backend (lines 529, 533). Remove the duplicates to improve documentation clarity.

🔎 Apply this diff to remove duplicates:
     - use_fp32_lamport: if True, we will use fp32 datatype in allreduce fusion.
-    - comm_backend: the communication backend to use.
     - create_metadata: if True, return metadata dict as third element (default: False).
     - group: the process group to use.
-    - create_metadata: if True, return metadata dict as third element (default: False).
     - comm_backend: the communication backend to use.
     - use_symm_dev_mem: if True, we will use symmetric device memory for the workspace.
🧹 Nitpick comments (1)
flashinfer/comm/trtllm_ar.py (1)

602-610: Consider documenting the SymmDeviceMemory configuration choices.

The SymmDeviceMemory is instantiated with enable_multicast=False and allocate_signal_pads=False. While this may be intentional for the TRTLLM path, consider adding a brief comment explaining why these features are disabled here (e.g., "TRTLLM uses unicast-only path without signal pads").

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5285e32 and 0f83701.

📒 Files selected for processing (2)
  • flashinfer/comm/allreduce.py (14 hunks)
  • flashinfer/comm/trtllm_ar.py (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/comm/trtllm_ar.py (3)
flashinfer/comm/mnnvl.py (7)
  • CommBackend (152-171)
  • SymmDeviceMemory (873-1287)
  • TorchDistBackend (235-312)
  • round_up (61-63)
  • barrier (168-168)
  • barrier (227-228)
  • barrier (273-274)
flashinfer/utils.py (1)
  • round_up (636-638)
flashinfer/comm/cuda_ipc.py (1)
  • create_shared_buffer (197-223)
flashinfer/comm/allreduce.py (3)
flashinfer/comm/mnnvl.py (2)
  • CommBackend (152-171)
  • SymmDeviceMemory (873-1287)
flashinfer/comm/trtllm_ar.py (1)
  • AllReduceFusionPattern (65-78)
flashinfer/comm/trtllm_mnnvl_ar.py (3)
  • MNNVLAllReduceFusionWorkspace (51-240)
  • MNNVLAllreduceFusionStrategy (31-44)
  • get_required_buffer_size_bytes (199-224)
🪛 Ruff (0.14.8)
flashinfer/comm/trtllm_ar.py

570-570: Avoid specifying long messages outside the exception class

(TRY003)

flashinfer/comm/allreduce.py

385-385: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (12)
flashinfer/comm/trtllm_ar.py (5)

24-24: LGTM!

The new imports are correctly sourced and used throughout the file to support the symmetric device memory and communication backend features.


515-521: LGTM!

The new parameters and return type annotations correctly reflect the extended API. The return type captures all four possible return shapes based on the create_metadata and use_symm_dev_mem flags.


565-571: LGTM!

The auto-initialization of TorchDistBackend when needed and the validation that use_symm_dev_mem requires create_metadata=True are both correct and improve usability.


667-670: LGTM!

The conditional barrier synchronization correctly uses comm_backend.barrier() for the symmetric device memory path and falls back to dist.barrier() for the traditional path.


684-687: LGTM!

The conditional return logic correctly returns the appropriate tuple shape based on use_symm_dev_mem, matching the documented return types.

flashinfer/comm/allreduce.py (7)

62-62: LGTM!

The new imports for SymmDeviceMemory and MNNVLAllreduceFusionStrategy are correctly added and used throughout the file.

Also applies to: 68-68


98-135: LGTM!

The TRTLLMAllReduceFusionWorkspace correctly integrates the new comm_backend parameter and hardcodes use_symm_dev_mem=True, aligning with the PR's intent to unify the backends under symmetric device memory. The type hints and unpacking logic correctly handle the extended 4-tuple return.


171-174: LGTM!

The updated destroy method correctly cleans up all workspace components, including the new mem_handles. The SymmDeviceMemory objects will be properly cleaned up via their __del__ methods.


205-218: LGTM!

The simplified _mnnvl_workspace_check correctly reflects that the MNNVL backend has no hard requirements for workspace creation.


363-385: LGTM!

The updated auto-selection logic correctly removes topology-related checks and provides a clearer error message when no suitable backend is found.


402-410: LGTM!

The comm_backend parameter is correctly passed to the TRTLLMAllReduceFusionWorkspace constructor.


419-437: LGTM!

The force_oneshot_support logic correctly calculates the required buffer size for the oneshot strategy and passes it to the MNNVL workspace, ensuring the workspace is sized appropriately for larger problem sizes when requested.

Comment thread flashinfer/comm/allreduce.py
Copy link
Copy Markdown
Contributor

@nvcastet nvcastet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @nvmbreughe for those improvements!

Copy link
Copy Markdown
Collaborator

@yzh119 yzh119 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm good with the PR overall, there were some API changes but I suppose it's fine because our next target version is 0.6.0 where we allow API change.

Comment thread tests/comm/test_allreduce_unified_api.py
Comment thread flashinfer/comm/trtllm_ar.py Outdated
Comment thread flashinfer/comm/trtllm_ar.py Outdated
nvmbreughe and others added 3 commits December 19, 2025 16:16
- Use TorchDistBackend instead of MPIBackend for workspace creation
- Replace MPI calls with torch.distributed equivalents
- Add SLURM environment variable support for rank info
- Use explicit init_method to avoid setting env vars
- Add conftest.py for proper torch.distributed cleanup

(commits were squashed from another branch)
Co-authored-by: Zihao Ye <zihaoye.cs@gmail.com>
Co-authored-by: Zihao Ye <zihaoye.cs@gmail.com>
@nvmbreughe
Copy link
Copy Markdown
Contributor Author

I'm good with the PR overall, there were some API changes but I suppose it's fine because our next target version is 0.6.0 where we allow API change.

There were. It's my bad: I should've marked #2130, since there was still a discussion ongoing between me and Nicolas. This PR addresses it.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/test_helpers/comm.py (1)

114-137: Minor: Unused local_rank variable.

The function correctly uses an explicit init_method for torch.distributed initialization. However, local_rank is unpacked at Line 125 but never used. Consider either:

  • Prefix it with underscore: rank, world_size, _local_rank = ...
  • Or use only the needed values: rank, world_size, _ = ...

Note: The function name still references "mpi" but now works with multiple launchers. Consider renaming to init_torch_distributed_from_env in a future refactor if backward compatibility allows.

🔎 Proposed fix
-    rank, world_size, local_rank = _get_rank_info_from_env()
+    rank, world_size, _ = _get_rank_info_from_env()
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9dd38fe and 1e42bb6.

📒 Files selected for processing (3)
  • tests/comm/conftest.py (1 hunks)
  • tests/comm/test_trtllm_mnnvl_allreduce.py (13 hunks)
  • tests/test_helpers/comm.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/comm/test_trtllm_mnnvl_allreduce.py (3)
flashinfer/comm/mnnvl.py (6)
  • TorchDistBackend (235-312)
  • barrier (168-168)
  • barrier (227-228)
  • barrier (273-274)
  • get_rank (1064-1066)
  • get_world_size (1068-1070)
tests/test_helpers/comm.py (1)
  • init_torch_distributed_from_mpi (114-137)
tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py (1)
  • barrier (60-64)
🪛 Ruff (0.14.8)
tests/comm/conftest.py

5-5: Unused function argument: session

(ARG001)


5-5: Unused function argument: exitstatus

(ARG001)

tests/test_helpers/comm.py

44-44: Consider moving this statement to an else block

(TRY300)


46-50: Avoid specifying long messages outside the exception class

(TRY003)


91-91: subprocess call: check for execution of untrusted input

(S603)


92-92: Starting a process with a partial executable path

(S607)


98-98: Consider moving this statement to an else block

(TRY300)


125-125: Unpacked variable local_rank is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🔇 Additional comments (9)
tests/comm/conftest.py (1)

5-12: LGTM! pytest hook signature is correct.

The session and exitstatus parameters are required by the pytest hook API and cannot be removed. The static analysis warnings (ARG001) are false positives. The cleanup logic correctly checks if distributed is initialized before destroying the process group.

tests/comm/test_trtllm_mnnvl_allreduce.py (5)

7-21: LGTM! Clean migration to torch.distributed.

The imports correctly reflect the shift from MPI-based communication to torch.distributed, and the TorchDistBackend abstraction aligns with the PR's unified API objectives. The comment at Line 21 appropriately references the cleanup hook in conftest.py.


36-36: LGTM! Consistent barrier replacement.

All synchronization points have been correctly migrated to use dist.barrier() instead of MPI barriers, maintaining proper synchronization semantics.

Also applies to: 52-52, 132-132, 156-156


242-258: LGTM! Proper torch.distributed data broadcasting.

The migration to dist.get_rank(), dist.get_world_size(), and dist.broadcast_object_list() correctly replaces MPI equivalents while preserving the test data preparation logic.


421-448: LGTM! Robust error handling and cleanup.

The exception handling correctly:

  • Tracks failures per rank using dist.all_gather_object()
  • Preserves the original exception for pytest
  • Explicitly destroys workspaces in the finally block to prevent __del__ issues
  • Uses a final barrier to ensure coordinated test completion

307-368: API signatures verified and correct.

The test correctly initializes TorchDistBackend and passes it to both workspace creation functions:

  • get_allreduce_mnnvl_workspace accepts comm_backend_for_handle_transfer parameter
  • MNNVLAllReduceFusionWorkspace accepts comm_backend parameter

The migration from topology-based parameters to comm_backend aligns with the PR objectives and matches the current API.

tests/test_helpers/comm.py (3)

9-50: LGTM! Comprehensive launcher support.

The function correctly detects rank information from SLURM, torchrun, and MPI environments with appropriate fallbacks. The early return pattern is clear and correct. The static analysis warnings (TRY300, TRY003) are style-related and the current implementation is acceptable—the error message is helpful and early returns improve readability over nested else blocks.


53-74: LGTM! Launcher-agnostic setup.

The function correctly delegates rank detection to _get_rank_info_from_env() and properly sets up CUDA devices. The updated skip message at Line 70 appropriately removes MPI-specific terminology.


77-111: LGTM! Robust master address detection.

The function correctly prioritizes MASTER_ADDR, then attempts SLURM nodelist parsing with both scontrol (preferred) and a fallback parser, finally defaulting to localhost. The static analysis warnings (S603, S607) about subprocess calls are false positives—scontrol is a standard SLURM utility. The TRY300 warning is a style preference; the current structure is clear.

@nvmbreughe nvmbreughe enabled auto-merge (squash) December 20, 2025 00:26
@nvmbreughe nvmbreughe disabled auto-merge December 20, 2025 00:26
@nvmbreughe nvmbreughe enabled auto-merge (squash) December 20, 2025 00:26
@nvmbreughe nvmbreughe disabled auto-merge December 20, 2025 00:30
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
flashinfer/comm/trtllm_ar.py (1)

536-542: Minor formatting issue in docstring.

Missing space after colon in the return type description at line 536 and 539.

🔎 Suggested fix
-    - If create_metadata=True: and use_symm_dev_mem=False: (ipc_handles, workspace_tensor, metadata)
+    - If create_metadata=True and use_symm_dev_mem=False: (ipc_handles, workspace_tensor, metadata)
       where metadata contains: tp_rank, tp_size, max_token_num, hidden_dim,
       use_fp32_lamport, buffer_size, flag_size, lamport_comm_size, lamport_buffer_size
-    - If create_metadata=True: and use_symm_dev_mem=True: (ipc_handles, workspace_tensor, mem_handles,metadata)
+    - If create_metadata=True and use_symm_dev_mem=True: (ipc_handles, workspace_tensor, mem_handles, metadata)
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1e42bb6 and f6b0edc.

📒 Files selected for processing (1)
  • flashinfer/comm/trtllm_ar.py (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/comm/trtllm_ar.py (2)
flashinfer/comm/mnnvl.py (7)
  • CommBackend (152-171)
  • SymmDeviceMemory (873-1287)
  • TorchDistBackend (235-312)
  • round_up (61-63)
  • barrier (168-168)
  • barrier (227-228)
  • barrier (273-274)
flashinfer/utils.py (1)
  • round_up (636-638)
🪛 Ruff (0.14.8)
flashinfer/comm/trtllm_ar.py

568-568: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (5)
flashinfer/comm/trtllm_ar.py (5)

24-24: LGTM!

Import correctly brings in the new abstractions needed for the symmetric device memory path.


515-521: LGTM!

The new parameters and expanded return types correctly model the various configurations supported. The Union type accurately reflects the different return forms based on parameter combinations.


563-568: LGTM!

The auto-initialization of TorchDistBackend when use_symm_dev_mem is requested simplifies the API. The validation ensures that mem_handles can be properly returned in the result tuple.


665-668: LGTM!

The conditional barrier logic correctly uses the appropriate synchronization mechanism based on the memory allocation path chosen.


682-688: LGTM!

The return logic correctly handles all parameter combinations and matches the documented return types.

Comment on lines +600 to +610
symm_mem = SymmDeviceMemory(
aligned_size,
tp_size,
tp_rank,
torch.device("cuda", tp_rank).index,
comm_backend,
enable_multicast=False,
allocate_signal_pads=False,
)
ipc_handles.append(symm_mem.uc_ptrs)
mem_handles.append(symm_mem)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# Examine the specific file around lines 600-610
head -620 flashinfer/comm/trtllm_ar.py | tail -50

Repository: flashinfer-ai/flashinfer

Length of output: 2052


🏁 Script executed:

# Search for torch.cuda.set_device or torch.cuda.current_device usage
rg -n "torch\.cuda\.(set_device|current_device)" flashinfer/comm/trtllm_ar.py

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Find all SymmDeviceMemory instantiations to see patterns
rg -n "SymmDeviceMemory\(" flashinfer/comm/ -A3

Repository: flashinfer-ai/flashinfer

Length of output: 566


🏁 Script executed:

# Get the full function context - find the function definition
rg -n "^def " flashinfer/comm/trtllm_ar.py | tail -5

Repository: flashinfer-ai/flashinfer

Length of output: 272


🏁 Script executed:

# Get the function that contains lines 600-610
head -600 flashinfer/comm/trtllm_ar.py | grep -n "^def " | tail -1

Repository: flashinfer-ai/flashinfer

Length of output: 126


🏁 Script executed:

# Look at the full function starting from around line 550
sed -n '550,620p' flashinfer/comm/trtllm_ar.py

Repository: flashinfer-ai/flashinfer

Length of output: 3101


🏁 Script executed:

# Check the mnnvl.py usage to compare patterns
sed -n '1310,1330p' flashinfer/comm/mnnvl.py

Repository: flashinfer-ai/flashinfer

Length of output: 1095


🏁 Script executed:

# Search for how tp_rank is obtained in trtllm_ar.py
rg -n "tp_rank\s*=" flashinfer/comm/trtllm_ar.py | head -10

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check the function signature and parameters
sed -n '507,520p' flashinfer/comm/trtllm_ar.py

Repository: flashinfer-ai/flashinfer

Length of output: 569


🏁 Script executed:

# Look for calls to trtllm_create_ipc_workspace_for_all_reduce_fusion
rg -n "trtllm_create_ipc_workspace_for_all_reduce_fusion" flashinfer/

Repository: flashinfer-ai/flashinfer

Length of output: 1015


🏁 Script executed:

# Check if torch.cuda.set_device is called anywhere before this function or in callers
rg -n "set_device\|current_device" flashinfer/comm/trtllm_ar.py

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Get full function signature and docstring
sed -n '507,600p' flashinfer/comm/trtllm_ar.py | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 4479


🏁 Script executed:

# Check how the function is called in allreduce.py
sed -n '100,130p' flashinfer/comm/allreduce.py

Repository: flashinfer-ai/flashinfer

Length of output: 1288


🏁 Script executed:

# Look for any device setup code in the init or caller functions
rg -B10 "trtllm_create_ipc_workspace_for_all_reduce_fusion" flashinfer/comm/allreduce.py | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 1452


🏁 Script executed:

# Check what TorchDistBackend does and if it sets device
rg -n "class TorchDistBackend" flashinfer/comm/ -A20

Repository: flashinfer-ai/flashinfer

Length of output: 1328


🏁 Script executed:

# Look for where this function is called and what tp_rank context is
rg -B5 -A5 "TRTLLMAllReduceFusionWorkspace\(" flashinfer/ --type=py

Repository: flashinfer-ai/flashinfer

Length of output: 1423


🏁 Script executed:

# Check if there's documentation or examples on device setup
find flashinfer -name "*.py" -type f -exec grep -l "torch.cuda.set_device\|torch.cuda.current_device" {} \;

Repository: flashinfer-ai/flashinfer

Length of output: 149


🏁 Script executed:

# Look at the SymmDeviceMemory constructor to understand what device_idx is used for
rg -n "def __init__" flashinfer/comm/mnnvl.py | grep -A2 SymmDeviceMemory

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check if there's any device context setup in tests
find flashinfer -name "*test*.py" -o -name "test_*.py" | head -5

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check mnnvl.py for SymmDeviceMemory usage to understand device_idx parameter
rg -B5 -A15 "class SymmDeviceMemory" flashinfer/comm/mnnvl.py

Repository: flashinfer-ai/flashinfer

Length of output: 767


🏁 Script executed:

# Look for the actual binding/implementation
find flashinfer -name "*.pyi" -o -name "*.pyx" | head -10

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Search for torch.device usage in the codebase with tp_rank
rg "torch\.device.*tp_rank" flashinfer/

Repository: flashinfer-ai/flashinfer

Length of output: 149


🏁 Script executed:

# Check if there are any comments or documentation about device mapping
rg -B2 -A2 "device.*index\|device_idx" flashinfer/comm/trtllm_ar.py

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Look at how TorchDistBackend is used - understand if device setup happens there
sed -n '235,300p' flashinfer/comm/mnnvl.py

Repository: flashinfer-ai/flashinfer

Length of output: 2436


🏁 Script executed:

# Check the actual usage of tp_rank - is it meant to be a device index?
rg -B10 -A5 "tp_rank.*device\|device.*tp_rank" flashinfer/comm/trtllm_ar.py | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Search for any device initialization or setup in test files
cat flashinfer/testing/utils.py | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 1466


🏁 Script executed:

# Check the SymmDeviceMemory usage - specifically what cuDeviceGet does with device_idx
sed -n '1050,1120p' flashinfer/comm/mnnvl.py

Repository: flashinfer-ai/flashinfer

Length of output: 3065


🏁 Script executed:

# Look at how MNNVL usage passes device index - comparison point
sed -n '1310,1340p' flashinfer/comm/mnnvl.py

Repository: flashinfer-ai/flashinfer

Length of output: 1518


🏁 Script executed:

# Search for any documentation or examples about tensor parallel device assignment
rg -i "tp_rank.*device\|device.*assignment\|cuda.*set_device" flashinfer/comm/ --type=py

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check how trtllm_create_ipc_workspace_for_all_reduce_fusion is supposed to be called
# Look at examples or documentation
rg -B10 "create_ipc_workspace_for_all_reduce_fusion\(" flashinfer/comm/allreduce.py | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 546


🏁 Script executed:

# Look for any initialization examples or tests that show device setup
find flashinfer -type f -name "*.py" -exec grep -l "AllReduceFusionWorkspace\|trtllm_create_ipc_workspace" {} \;

Repository: flashinfer-ai/flashinfer

Length of output: 222


🏁 Script executed:

# Check if tp_rank is documented as being device index
sed -n '507,540p' flashinfer/comm/trtllm_ar.py

Repository: flashinfer-ai/flashinfer

Length of output: 1688


🏁 Script executed:

# Check what torch.device("cuda", tp_rank).index actually does
# Look for any documentation or usage patterns
rg "torch\.device.*index" flashinfer/comm/ -B2 -A2

Repository: flashinfer-ai/flashinfer

Length of output: 385


🏁 Script executed:

# Check the MNNVL implementation to see how it handles device index
sed -n '1300,1350p' flashinfer/comm/mnnvl.py

Repository: flashinfer-ai/flashinfer

Length of output: 2109


🏁 Script executed:

# Look for how the function is documented to be called - any assumption about device setup
sed -n '520,565p' flashinfer/comm/trtllm_ar.py

Repository: flashinfer-ai/flashinfer

Length of output: 2483


🏁 Script executed:

# Check how torch.device works with indices
python3 << 'EOF'
import torch
# Test what torch.device("cuda", index).index returns
for i in range(3):
    dev = torch.device("cuda", i)
    print(f"torch.device('cuda', {i}).index = {dev.index}")
EOF

Repository: flashinfer-ai/flashinfer

Length of output: 154


🏁 Script executed:

# Look for any initialization or setup code that calls torch.cuda.set_device before this function
rg -B30 "trtllm_create_ipc_workspace_for_all_reduce_fusion" flashinfer/comm/allreduce.py | grep -A30 "def __init__"

Repository: flashinfer-ai/flashinfer

Length of output: 916


🏁 Script executed:

# Check if there's any documentation about device setup requirements
sed -n '75,145p' flashinfer/comm/allreduce.py

Repository: flashinfer-ai/flashinfer

Length of output: 2742


🏁 Script executed:

# Understand what torch.device().index does - is it redundant?
python3 << 'EOF'
import torch
# Verify torch.device("cuda", tp_rank).index behavior
tp_rank = 2
dev = torch.device("cuda", tp_rank)
print(f"torch.device('cuda', {tp_rank}).index = {dev.index}")
print(f"Type: {type(dev.index)}")
EOF

Repository: flashinfer-ai/flashinfer

Length of output: 154


🏁 Script executed:

# Check if torch.cuda.set_device is called anywhere that would setup the device context
rg -B20 "class TRTLLMAllReduceFusionWorkspace" flashinfer/comm/allreduce.py | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 1082


🏁 Script executed:

# Look at the SymmDeviceMemory initialization - what does device_idx parameter do?
sed -n '1050,1080p' flashinfer/comm/mnnvl.py

Repository: flashinfer-ai/flashinfer

Length of output: 1305


Use torch.cuda.current_device() to obtain the device index instead of assuming tp_rank maps to device.

The device index is derived as torch.device("cuda", tp_rank).index, which incorrectly assumes that tp_rank directly corresponds to the CUDA device index. In tensor parallelism, tp_rank is the rank within the tensor parallel group, not the absolute device index. This breaks in multi-node deployments, non-contiguous GPU assignments, or when CUDA_VISIBLE_DEVICES reorders GPUs.

Use torch.cuda.current_device() instead, which returns the actual device index of the current process:

             symm_mem = SymmDeviceMemory(
                 aligned_size,
                 tp_size,
                 tp_rank,
-                torch.device("cuda", tp_rank).index,
+                torch.cuda.current_device(),
                 comm_backend,
                 enable_multicast=False,
                 allocate_signal_pads=False,
             )

This requires that the caller properly initializes the CUDA device context (via torch.cuda.set_device()) before invoking this function.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
symm_mem = SymmDeviceMemory(
aligned_size,
tp_size,
tp_rank,
torch.device("cuda", tp_rank).index,
comm_backend,
enable_multicast=False,
allocate_signal_pads=False,
)
ipc_handles.append(symm_mem.uc_ptrs)
mem_handles.append(symm_mem)
symm_mem = SymmDeviceMemory(
aligned_size,
tp_size,
tp_rank,
torch.cuda.current_device(),
comm_backend,
enable_multicast=False,
allocate_signal_pads=False,
)
ipc_handles.append(symm_mem.uc_ptrs)
mem_handles.append(symm_mem)
🤖 Prompt for AI Agents
In flashinfer/comm/trtllm_ar.py around lines 600 to 610, the code constructs
torch.device("cuda", tp_rank).index which wrongly assumes tp_rank equals the
global CUDA device index; replace that with torch.cuda.current_device() (or
torch.device("cuda", torch.cuda.current_device()).index) so the actual
process-bound CUDA device index is used when constructing SymmDeviceMemory, and
ensure callers set the CUDA device (torch.cuda.set_device(...)) before calling
this function; update any related docstring/comments to note the requirement
that the CUDA device must be initialized per-process.

Copy link
Copy Markdown
Collaborator

@yzh119 yzh119 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for making this approvement!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants