Allreduce auto backend improvements #2239
Made trtllm_ar backward compatible, updated tests with new API
Walkthrough
Threads a CommBackend abstraction and SymmDeviceMemory through TRT-LLM and MNNVL all-reduce workspace and IPC creation, expands workspace tuples to include mem_handles, removes topology/process_group plumbing, updates oneshot heuristics/sizing, and adapts tests to use TorchDistBackend and unified APIs.
Sequence Diagram(s)
sequenceDiagram
participant Test as Test/Client
participant API as create_allreduce_fusion_workspace
participant Comm as CommBackend (TorchDistBackend)
participant TRT as TRTLLM IPC creator
participant MNN as MNNVL workspace creator
participant MEM as SymmDeviceMemory
Test->>API: call create_allreduce_fusion_workspace(comm_backend, force_oneshot?)
API->>Comm: Get_size / Get_rank
alt TRT-LLM path
API->>TRT: trtllm_create_ipc_workspace_for_all_reduce_fusion(comm_backend, use_symm_dev_mem)
alt use_symm_dev_mem == true
TRT->>Comm: barrier()
TRT->>MEM: allocate SymmDeviceMemory (collect mem_handles)
TRT-->>API: return ipc_handles, workspace_tensor, mem_handles, metadata
else
TRT-->>API: return ipc_handles, workspace_tensor[, metadata]
end
else MNNVL path
API->>MNN: create MNNVL workspace (buffer_size_in_bytes?, comm_backend)
MNN->>Comm: exchange handles / barrier as needed
MNN-->>API: return workspace tuple (includes mem_handles when used)
end
API-->>Test: unified workspace (includes mem_handles & metadata as applicable)
Estimated code review effort
🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Summary of Changes
This pull request significantly refactors the allreduce auto-backend logic, aiming for greater flexibility and efficiency in distributed communication. It unifies the underlying memory management for different backends, adapts to available hardware capabilities for handle transfer, and simplifies the API for workspace creation. These changes pave the way for broader multi-node support and more explicit control over communication strategies.
Code Review
This pull request introduces significant improvements by unifying the trtllm and mnnvl all-reduce backends under a common CommBackend abstraction and SymmDeviceMemory workspace. The API is simplified by removing the topology parameter and adding use_oneshot. The changes are well-structured and align with the PR description. I've identified a few minor issues, mainly concerning documentation consistency and a potential for test flakiness due to a removed random seed. Overall, this is a valuable enhancement.
@@ -190,8 +195,8 @@ def _trtllm_workspace_check(
    - Single-node topology (multi-node not supported)
The docstring for _trtllm_workspace_check appears to be outdated. It states that the trtllm backend requires a single-node topology, but the implementation has been updated to support multi-node configurations with a limit of 16 ranks. Please update the docstring to reflect this change.
Suggested change:
-     - Single-node topology (multi-node not supported)
+     - Up to 16 ranks supported
    process_group: Optional["torch.distributed.ProcessGroup"] = None,
    gpus_per_node: int = None,
    comm_backend: Optional[CommBackend] = None,
    use_oneshot: bool = False,
The signature declares `use_oneshot: bool = False`, which doesn't allow `None` as a value. However, the docstring (lines 317-320) states "If None, uses internal heuristics", which contradicts the annotation. To align with the documentation and the pull request description, consider changing the signature to allow `None`.
Suggested change:
-     use_oneshot: bool = False,
+     use_oneshot: Optional[bool] = None,
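A minimal sketch of the tri-state behavior the docstring describes, assuming `None` means "let the library decide". The helper name `resolve_use_oneshot` and the `ONESHOT_MAX_TOKENS` cutoff are hypothetical and only illustrate the pattern; the real heuristic is not shown in this thread.

```python
from typing import Optional

# Hypothetical cutoff used only for this illustration; the actual
# flashinfer heuristic is not quoted anywhere in this review.
ONESHOT_MAX_TOKENS = 128


def resolve_use_oneshot(use_oneshot: Optional[bool], num_tokens: int) -> bool:
    """Return True for the oneshot strategy, False for twoshot."""
    if use_oneshot is None:
        # No explicit preference from the caller: fall back to a heuristic.
        return num_tokens <= ONESHOT_MAX_TOKENS
    # An explicit True/False from the caller is honored as-is.
    return use_oneshot
```

With `bool = False` as the default, the `None` branch can never be reached through the default argument, which is why the annotation and the docstring need to agree.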
    use_oneshot: Use oneshot strategy vs twoshot
        If None, uses internal heuristics.
        Note that the MNNVL backend needs to be initialized with a sufficiently large workspace if one_shot is used.
        Note: when explicitly set to True, the MNNVL backend needs to be initialized with a sufficiently large workspace.
This note is a bit vague. To improve clarity, consider explicitly mentioning the use_oneshot parameter it refers to.
Suggested change:
-     Note: when explicitly set to True, the MNNVL backend needs to be initialized with a sufficiently large workspace.
+     Note: when `use_oneshot` is explicitly set to True, the MNNVL backend needs to be initialized with a sufficiently large workspace.
    - create_metadata: if True, return metadata dict as third element (default: False).
    - comm_backend: the communication backend to use.

    eps = 1e-5
    torch.manual_seed(42 + rank)
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/comm/allreduce.py (1)
287-322: Type annotation mismatch with docstring.
The `use_oneshot` parameter is typed as `bool` (line 287) but the docstring (lines 318-322) states it can be `None`. If `None` is a valid value (to follow heuristics), the type should be `Optional[bool]`.
🔎 Apply this diff to fix the type annotation:
-     use_oneshot: bool = False,
+     use_oneshot: Optional[bool] = None,
🧹 Nitpick comments (3)
flashinfer/comm/trtllm_ar.py (1)
529-534: Duplicated parameter documentation.
The parameters `group`, `create_metadata`, and `comm_backend` are documented twice in the docstring. Remove the duplicates for clarity.
🔎 Apply this diff to remove duplicate documentation:
      - use_fp32_lamport: if True, we will use fp32 datatype in allreduce fusion.
      - comm_backend: the communication backend to use.
      - create_metadata: if True, return metadata dict as third element (default: False).
-     - group: the process group to use.
-     - create_metadata: if True, return metadata dict as third element (default: False).
-     - comm_backend: the communication backend to use.
      - use_symm_dev_mem: if True, we will use symmetric device memory for the workspace.
+     - group: the process group to use.
flashinfer/comm/allreduce.py (1)
205-218: Consider adding meaningful checks for MNNVL backend.
The `_mnnvl_workspace_check` function always returns `True`, which means any configuration will pass the MNNVL check. While this may be intentional to allow MNNVL as a fallback, consider adding checks for:
- Multicast support on the device
- Minimum world size requirements
- Any hardware capability constraints
If these checks are deferred to runtime, document this behavior.
flashinfer/comm/mnnvl.py (1)
235-312: TorchDistBackend Split() creates process groups without cleanup tracking.
The `Split()` method creates new process groups via `self._dist.new_group()` (line 310) but there's no mechanism to track or clean up these groups. In PyTorch, process groups persist until explicitly destroyed or the process ends.
If `Split()` is called multiple times (e.g., in a loop or iterative algorithm), this could lead to resource accumulation. Consider:
- Tracking created groups for cleanup
- Documenting that callers are responsible for group lifecycle
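The first option can be sketched as a small tracker that records every group it creates. This is an illustration only: `GroupTracker` is a hypothetical name, and the creation and destruction functions are injected so the sketch does not depend on a particular backend.

```python
from typing import Any, Callable, List


class GroupTracker:
    """Remembers every group it creates so they can be destroyed later."""

    def __init__(
        self,
        new_group: Callable[[List[int]], Any],
        destroy_group: Callable[[Any], None],
    ) -> None:
        self._new_group = new_group
        self._destroy_group = destroy_group
        self._groups: List[Any] = []

    def split(self, ranks: List[int]) -> Any:
        # Create the group and keep a reference for later cleanup.
        group = self._new_group(ranks)
        self._groups.append(group)
        return group

    def close(self) -> None:
        # Destroy in reverse creation order; safe to call more than once.
        while self._groups:
            self._destroy_group(self._groups.pop())
```

With torch.distributed, one plausible wiring is `GroupTracker(lambda ranks: dist.new_group(ranks=ranks), dist.destroy_process_group)`, with `close()` called from the backend's teardown path. Whether that fits `TorchDistBackend` depends on who owns the group lifetime, which this thread leaves open.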
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
flashinfer/comm/allreduce.py (14 hunks)
flashinfer/comm/mnnvl.py (7 hunks)
flashinfer/comm/trtllm_ar.py (7 hunks)
flashinfer/comm/trtllm_mnnvl_ar.py (0 hunks)
tests/comm/test_allreduce_unified_api.py (3 hunks)
tests/comm/test_trtllm_allreduce_fusion.py (3 hunks)
💤 Files with no reviewable changes (1)
- flashinfer/comm/trtllm_mnnvl_ar.py
🧰 Additional context used
🧬 Code graph analysis (3)
tests/comm/test_trtllm_allreduce_fusion.py (1)
  flashinfer/comm/mnnvl.py (1)
    TorchDistBackend (235-312)
flashinfer/comm/allreduce.py (3)
  flashinfer/comm/mnnvl.py (2)
    CommBackend (152-171), SymmDeviceMemory (873-1287)
  flashinfer/comm/trtllm_ar.py (1)
    AllReduceFusionPattern (65-78)
  flashinfer/comm/trtllm_mnnvl_ar.py (3)
    MNNVLAllReduceFusionWorkspace (51-240), MNNVLAllreduceFusionStrategy (31-44), get_required_buffer_size_bytes (199-224)
flashinfer/comm/mnnvl.py (3)
  tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py (6)
    Get_rank (23-24), Get_size (26-27), allgather (29-46), bcast (48-58), barrier (60-64), Split (66-67)
  flashinfer/cuda_utils.py (1)
    checkCudaErrors (51-61)
  flashinfer/utils.py (1)
    round_up (636-638)
🪛 Ruff (0.14.8)
flashinfer/comm/trtllm_ar.py
567-567: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer/comm/allreduce.py
386-386: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer/comm/mnnvl.py
248-251: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (10)
tests/comm/test_trtllm_allreduce_fusion.py (2)
12-13: LGTM!
The import of `TorchDistBackend` is correctly placed and the usage at line 88 is appropriate since `torch.distributed` is initialized earlier in the worker function (lines 33-38).
88-89: LGTM!
The `TorchDistBackend()` is correctly instantiated after `dist.init_process_group()` has been called, satisfying the initialization requirement of the backend.
tests/comm/test_allreduce_unified_api.py (1)
209-220: LGTM!
The test correctly initializes `torch.distributed` via `init_torch_distributed_from_mpi()` before creating the `TorchDistBackend()` instance, ensuring the distributed backend is properly set up.
flashinfer/comm/trtllm_ar.py (2)
562-567: LGTM on validation logic.
The validation that `use_symm_dev_mem` requires `create_metadata=True` is appropriate since the new return tuple includes `mem_handles`, which needs the metadata path.
599-609: Device index parameter should align with rank assignment strategy.
The code directly uses `tp_rank` as the device index passed to `SymmDeviceMemory`, assuming it corresponds to the visible CUDA device. This is valid for standard distributed setups where ranks are assigned 1:1 to devices, but the assumption warrants documentation. Ensure that `tp_rank` represents the correct logical device index for your deployment configuration, particularly if using custom rank-to-device mapping or `CUDA_VISIBLE_DEVICES` with non-identity orderings.
flashinfer/comm/allreduce.py (2)
166-175: Verify resource cleanup in destroy() method.
The `destroy()` method uses `del` to remove references to workspace components. However, `del` only removes the reference binding; it doesn't guarantee immediate resource deallocation. The actual cleanup depends on `SymmDeviceMemory.__del__` being called when the reference count drops to zero.
This should work correctly if there are no other references to `mem_handles`, but consider whether explicit cleanup methods on the `SymmDeviceMemory` objects would be safer, especially for deterministic resource release in long-running processes.
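A minimal sketch of the explicit-cleanup alternative, assuming each memory handle exposes a `close()` method. That method is hypothetical: nothing in this thread shows `SymmDeviceMemory` providing one, and `WorkspaceHandles` is an illustrative name.

```python
from typing import Iterable, List, Protocol


class Closable(Protocol):
    def close(self) -> None: ...


class WorkspaceHandles:
    """Owns memory handles and releases them deterministically."""

    def __init__(self, mem_handles: Iterable[Closable]) -> None:
        self._mem_handles: List[Closable] = list(mem_handles)

    def destroy(self) -> None:
        # Release each handle explicitly instead of waiting for __del__.
        # Popping before closing keeps destroy() idempotent.
        while self._mem_handles:
            self._mem_handles.pop().close()

    def __enter__(self) -> "WorkspaceHandles":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.destroy()
```

Compared with `del`, this releases resources at a known point even when other references to the handles are still alive, at the cost of requiring handles to tolerate being used after release.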
420-439: LGTM on buffer size calculation.
The oneshot buffer size calculation correctly uses `MNNVLAllReduceFusionWorkspace.get_required_buffer_size_bytes` with the `ONESHOT` strategy when `use_oneshot` is explicitly `True`.
flashinfer/comm/mnnvl.py (3)
872-957: LGTM on SymmDeviceMemory refactor.
The rename from `McastDeviceMemory` to `SymmDeviceMemory` and the addition of `enable_multicast` and `allocate_signal_pads` parameters provides good flexibility for the TRTLLM use case where these features may not be needed.
1080-1092: LGTM on conditional multicast setup.
The `enable_multicast` flag correctly gates the multicast object creation, which is appropriate for scenarios where only unicast buffers are needed (e.g., TRTLLM backend).
1290-1323: LGTM on McastGPUBuffer update.
The update to use `SymmDeviceMemory` instead of `McastDeviceMemory` aligns with the class rename and maintains backward compatibility through the wrapper interface.
def is_mnnvl_fabric_supported(device_idx: int) -> bool:
    fabric_handle_supported = checkCudaErrors(
        cuda.cuDeviceGetAttribute(
            cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED,
            device_idx,
        )
    )
    if fabric_handle_supported == 0:
        return False

    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
        fabric_info = pynvml.c_nvmlGpuFabricInfoV_t()
        pynvml.nvmlDeviceGetGpuFabricInfoV(handle, ctypes.byref(fabric_info))
        if (
            fabric_info.state >= pynvml.NVML_GPU_FABRIC_STATE_COMPLETED
            and fabric_info.clusterUuid[0] != 0
        ):
            return True
        return False
    finally:
        pynvml.nvmlShutdown()
Potential nvml state conflict in is_mnnvl_fabric_supported().
This function calls pynvml.nvmlInit() and pynvml.nvmlShutdown() unconditionally. If the caller has already initialized pynvml (e.g., in MnnvlMemory.initialize()), calling nvmlShutdown() here will terminate the session unexpectedly, causing subsequent pynvml calls to fail.
Consider checking initialization state or using a reference-counted approach.
🔎 Apply this diff to handle pynvml state safely:
 def is_mnnvl_fabric_supported(device_idx: int) -> bool:
     fabric_handle_supported = checkCudaErrors(
         cuda.cuDeviceGetAttribute(
             cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED,
             device_idx,
         )
     )
     if fabric_handle_supported == 0:
         return False
-    pynvml.nvmlInit()
+    # Check if already initialized to avoid disrupting existing pynvml sessions
+    try:
+        pynvml.nvmlDeviceGetCount()
+        already_initialized = True
+    except pynvml.NVMLError_Uninitialized:
+        pynvml.nvmlInit()
+        already_initialized = False
+
     try:
         handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
         fabric_info = pynvml.c_nvmlGpuFabricInfoV_t()
         pynvml.nvmlDeviceGetGpuFabricInfoV(handle, ctypes.byref(fabric_info))
         if (
             fabric_info.state >= pynvml.NVML_GPU_FABRIC_STATE_COMPLETED
             and fabric_info.clusterUuid[0] != 0
         ):
             return True
         return False
     finally:
-        pynvml.nvmlShutdown()
+        if not already_initialized:
+            pynvml.nvmlShutdown()
🤖 Prompt for AI Agents
flashinfer/comm/mnnvl.py around lines 847-869: the function unconditionally
calls pynvml.nvmlInit() and pynvml.nvmlShutdown(), which can shut down a
caller's existing NVML session; change it to try to initialize NVML but detect
if it was already initialized and only call nvmlShutdown() when this function
actually performed the init. Concretely, call pynvml.nvmlInit() inside a try
block, set a local flag like initialized_here = True only if init succeeded,
catch the specific NVMLError indicating "already initialized" (treat that as
initialized_here = False) and re-raise other errors, then in the finally block
call pynvml.nvmlShutdown() only if initialized_here is True so you don't
terminate a shared NVML session.
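The "reference-counted approach" mentioned above can be sketched generically. This is an assumption-laden illustration, not the fix adopted in the PR: `RefCountedSession` is a hypothetical helper, and the init/shutdown functions are injected so the sketch runs without NVML installed.

```python
import threading
from contextlib import contextmanager
from typing import Callable, Iterator


class RefCountedSession:
    """Calls init on first entry and shutdown only when the last user exits."""

    def __init__(self, init: Callable[[], None], shutdown: Callable[[], None]) -> None:
        self._init = init
        self._shutdown = shutdown
        self._lock = threading.Lock()
        self._count = 0

    @contextmanager
    def use(self) -> Iterator[None]:
        with self._lock:
            if self._count == 0:
                # First user: actually initialize the underlying library.
                self._init()
            self._count += 1
        try:
            yield
        finally:
            with self._lock:
                self._count -= 1
                if self._count == 0:
                    # Last user out: now it is safe to shut down.
                    self._shutdown()
```

With pynvml, a module-level `RefCountedSession(pynvml.nvmlInit, pynvml.nvmlShutdown)` could wrap both `is_mnnvl_fabric_supported()` and other NVML users. It only protects callers that go through the same session object; code that calls `nvmlShutdown()` directly is outside its control.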
Actionable comments posted: 1
♻️ Duplicate comments (3)
flashinfer/comm/allreduce.py (2)
191-202: Update docstring to reflect multi-node support limitation.
The docstring states "Single-node topology (multi-node not supported)" but the implementation only enforces `world_size <= 16`. Per the PR description, TRTLLM is intended to support multi-node (untested) with a 16 GPU limit. Update the docstring to accurately reflect this constraint.
🔎 Suggested docstring fix:
     """
     Check if trtllm backend CAN be used for workspace creation.

     Hard requirements:
-    - Single-node topology (multi-node not supported)
+    - Up to 16 ranks supported
     """
513-513: Clarify which parameter is referenced.
The note states "when explicitly set to True" without naming the parameter. Consider explicitly mentioning `use_oneshot` for clarity, as suggested in the previous review.
🔎 Suggested improvement:
-     Note: when explicitly set to True, the MNNVL backend needs to be initialized with a sufficiently large workspace.
+     Note: when `use_oneshot` is explicitly set to True, the MNNVL backend needs to be initialized with a sufficiently large workspace.
flashinfer/comm/trtllm_ar.py (1)
529-534: Remove duplicate parameter documentation.
The docstring contains duplicate entries for `create_metadata` (lines 530, 532) and `comm_backend` (lines 529, 533). Remove the duplicates to improve documentation clarity.
🔎 Apply this diff to remove duplicates:
      - use_fp32_lamport: if True, we will use fp32 datatype in allreduce fusion.
-     - comm_backend: the communication backend to use.
      - create_metadata: if True, return metadata dict as third element (default: False).
      - group: the process group to use.
-     - create_metadata: if True, return metadata dict as third element (default: False).
      - comm_backend: the communication backend to use.
      - use_symm_dev_mem: if True, we will use symmetric device memory for the workspace.
🧹 Nitpick comments (1)
flashinfer/comm/trtllm_ar.py (1)
602-610: Consider documenting the SymmDeviceMemory configuration choices.
The `SymmDeviceMemory` is instantiated with `enable_multicast=False` and `allocate_signal_pads=False`. While this may be intentional for the TRTLLM path, consider adding a brief comment explaining why these features are disabled here (e.g., "TRTLLM uses unicast-only path without signal pads").
📒 Files selected for processing (2)
flashinfer/comm/allreduce.py (14 hunks)
flashinfer/comm/trtllm_ar.py (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/comm/trtllm_ar.py (3)
  flashinfer/comm/mnnvl.py (7)
    CommBackend (152-171), SymmDeviceMemory (873-1287), TorchDistBackend (235-312), round_up (61-63), barrier (168-168), barrier (227-228), barrier (273-274)
  flashinfer/utils.py (1)
    round_up (636-638)
  flashinfer/comm/cuda_ipc.py (1)
    create_shared_buffer (197-223)
flashinfer/comm/allreduce.py (3)
  flashinfer/comm/mnnvl.py (2)
    CommBackend (152-171), SymmDeviceMemory (873-1287)
  flashinfer/comm/trtllm_ar.py (1)
    AllReduceFusionPattern (65-78)
  flashinfer/comm/trtllm_mnnvl_ar.py (3)
    MNNVLAllReduceFusionWorkspace (51-240), MNNVLAllreduceFusionStrategy (31-44), get_required_buffer_size_bytes (199-224)
🪛 Ruff (0.14.8)
flashinfer/comm/trtllm_ar.py
570-570: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer/comm/allreduce.py
385-385: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (12)
flashinfer/comm/trtllm_ar.py (5)
24-24: LGTM!
The new imports are correctly sourced and used throughout the file to support the symmetric device memory and communication backend features.
515-521: LGTM!
The new parameters and return type annotations correctly reflect the extended API. The return type captures all four possible return shapes based on the `create_metadata` and `use_symm_dev_mem` flags.
565-571: LGTM!
The auto-initialization of `TorchDistBackend` when needed and the validation that `use_symm_dev_mem` requires `create_metadata=True` are both correct and improve usability.
667-670: LGTM!
The conditional barrier synchronization correctly uses `comm_backend.barrier()` for the symmetric device memory path and falls back to `dist.barrier()` for the traditional path.
684-687: LGTM!
The conditional return logic correctly returns the appropriate tuple shape based on `use_symm_dev_mem`, matching the documented return types.
flashinfer/comm/allreduce.py (7)
62-62: LGTM!
The new imports for `SymmDeviceMemory` and `MNNVLAllreduceFusionStrategy` are correctly added and used throughout the file.
Also applies to: 68-68
98-135: LGTM!
The `TRTLLMAllReduceFusionWorkspace` correctly integrates the new `comm_backend` parameter and hardcodes `use_symm_dev_mem=True`, aligning with the PR's intent to unify the backends under symmetric device memory. The type hints and unpacking logic correctly handle the extended 4-tuple return.
171-174: LGTM!
The updated `destroy` method correctly cleans up all workspace components, including the new `mem_handles`. The `SymmDeviceMemory` objects will be properly cleaned up via their `__del__` methods.
205-218: LGTM!
The simplified `_mnnvl_workspace_check` correctly reflects that the MNNVL backend has no hard requirements for workspace creation.
363-385: LGTM!
The updated auto-selection logic correctly removes topology-related checks and provides a clearer error message when no suitable backend is found.
402-410: LGTM!
The `comm_backend` parameter is correctly passed to the `TRTLLMAllReduceFusionWorkspace` constructor.
419-437: LGTM!
The `force_oneshot_support` logic correctly calculates the required buffer size for the oneshot strategy and passes it to the MNNVL workspace, ensuring the workspace is sized appropriately for larger problem sizes when requested.
nvcastet left a comment
Thanks @nvmbreughe for those improvements!
yzh119 left a comment
I'm good with the PR overall. There were some API changes, but I suppose that's fine because our next target version is 0.6.0, where we allow API changes.
- Use TorchDistBackend instead of MPIBackend for workspace creation
- Replace MPI calls with torch.distributed equivalents
- Add SLURM environment variable support for rank info
- Use explicit init_method to avoid setting env vars
- Add conftest.py for proper torch.distributed cleanup
(commits were squashed from another branch)
Co-authored-by: Zihao Ye <zihaoye.cs@gmail.com>
There were. It's my bad: I should've marked #2130, since there was still a discussion ongoing between me and Nicolas. This PR addresses it.
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/test_helpers/comm.py (1)
114-137: Minor: Unused `local_rank` variable.
The function correctly uses an explicit `init_method` for torch.distributed initialization. However, `local_rank` is unpacked at line 125 but never used. Consider either:
- Prefixing it with an underscore: `rank, world_size, _local_rank = ...`
- Or unpacking only the needed values: `rank, world_size, _ = ...`
Note: The function name still references "mpi" but now works with multiple launchers. Consider renaming it to `init_torch_distributed_from_env` in a future refactor if backward compatibility allows.
🔎 Proposed fix
-     rank, world_size, local_rank = _get_rank_info_from_env()
+     rank, world_size, _ = _get_rank_info_from_env()
📒 Files selected for processing (3)
tests/comm/conftest.py (1 hunks)
tests/comm/test_trtllm_mnnvl_allreduce.py (13 hunks)
tests/test_helpers/comm.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/comm/test_trtllm_mnnvl_allreduce.py (3)
  flashinfer/comm/mnnvl.py (6)
    TorchDistBackend (235-312), barrier (168-168), barrier (227-228), barrier (273-274), get_rank (1064-1066), get_world_size (1068-1070)
  tests/test_helpers/comm.py (1)
    init_torch_distributed_from_mpi (114-137)
  tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py (1)
    barrier (60-64)
🪛 Ruff (0.14.8)
tests/comm/conftest.py
5-5: Unused function argument: session
(ARG001)
5-5: Unused function argument: exitstatus
(ARG001)
tests/test_helpers/comm.py
44-44: Consider moving this statement to an else block
(TRY300)
46-50: Avoid specifying long messages outside the exception class
(TRY003)
91-91: subprocess call: check for execution of untrusted input
(S603)
92-92: Starting a process with a partial executable path
(S607)
98-98: Consider moving this statement to an else block
(TRY300)
125-125: Unpacked variable local_rank is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🔇 Additional comments (9)
tests/comm/conftest.py (1)
5-12: LGTM! pytest hook signature is correct.
The `session` and `exitstatus` parameters are required by the pytest hook API and cannot be removed. The static analysis warnings (ARG001) are false positives. The cleanup logic correctly checks if distributed is initialized before destroying the process group.
tests/comm/test_trtllm_mnnvl_allreduce.py (5)
7-21: LGTM! Clean migration to torch.distributed.
The imports correctly reflect the shift from MPI-based communication to torch.distributed, and the TorchDistBackend abstraction aligns with the PR's unified API objectives. The comment at line 21 appropriately references the cleanup hook in conftest.py.
36-36: LGTM! Consistent barrier replacement.
All synchronization points have been correctly migrated to use `dist.barrier()` instead of MPI barriers, maintaining proper synchronization semantics.
Also applies to: 52-52, 132-132, 156-156
242-258: LGTM! Proper torch.distributed data broadcasting.
The migration to `dist.get_rank()`, `dist.get_world_size()`, and `dist.broadcast_object_list()` correctly replaces MPI equivalents while preserving the test data preparation logic.
421-448: LGTM! Robust error handling and cleanup.
The exception handling correctly:
- Tracks failures per rank using `dist.all_gather_object()`
- Preserves the original exception for pytest
- Explicitly destroys workspaces in the finally block to prevent `__del__` issues
- Uses a final barrier to ensure coordinated test completion
307-368: API signatures verified and correct.
The test correctly initializes `TorchDistBackend` and passes it to both workspace creation functions:
- `get_allreduce_mnnvl_workspace` accepts `comm_backend_for_handle_transfer` parameter
- `MNNVLAllReduceFusionWorkspace` accepts `comm_backend` parameter
The migration from topology-based parameters to `comm_backend` aligns with the PR objectives and matches the current API.
tests/test_helpers/comm.py (3)
9-50: LGTM! Comprehensive launcher support.
The function correctly detects rank information from SLURM, torchrun, and MPI environments with appropriate fallbacks. The early return pattern is clear and correct. The static analysis warnings (TRY300, TRY003) are style-related and the current implementation is acceptable: the error message is helpful and early returns improve readability over nested else blocks.
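A rough sketch of launcher-agnostic rank detection in the order described (SLURM, then torchrun, then MPI). The environment variable names are the standard ones exported by SLURM, torchrun, and Open MPI; which variables the actual `_get_rank_info_from_env()` helper reads is not shown in this thread, so treat the table and the function name below as assumptions.

```python
import os
from typing import Mapping, Optional, Tuple

# (rank, world size, local rank) variable names per launcher, in priority order.
_LAUNCHER_VARS = (
    ("SLURM_PROCID", "SLURM_NTASKS", "SLURM_LOCALID"),  # srun
    ("RANK", "WORLD_SIZE", "LOCAL_RANK"),  # torchrun
    (
        "OMPI_COMM_WORLD_RANK",
        "OMPI_COMM_WORLD_SIZE",
        "OMPI_COMM_WORLD_LOCAL_RANK",
    ),  # Open MPI mpirun
)


def get_rank_info_from_env(
    env: Optional[Mapping[str, str]] = None,
) -> Tuple[int, int, int]:
    """Return (rank, world_size, local_rank) from the first launcher found."""
    env = os.environ if env is None else env
    for rank_var, size_var, local_var in _LAUNCHER_VARS:
        if rank_var in env and size_var in env:
            # Early return on the first launcher whose variables are present.
            return int(env[rank_var]), int(env[size_var]), int(env.get(local_var, "0"))
    raise RuntimeError("No SLURM, torchrun, or MPI rank variables found")
```

Taking the environment as a parameter keeps the function testable without mutating `os.environ`.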
53-74: LGTM! Launcher-agnostic setup.
The function correctly delegates rank detection to `_get_rank_info_from_env()` and properly sets up CUDA devices. The updated skip message at line 70 appropriately removes MPI-specific terminology.
77-111: LGTM! Robust master address detection.
The function correctly prioritizes `MASTER_ADDR`, then attempts SLURM nodelist parsing with both `scontrol` (preferred) and a fallback parser, finally defaulting to `localhost`. The static analysis warnings (S603, S607) about subprocess calls are false positives, since `scontrol` is a standard SLURM utility. The TRY300 warning is a style preference; the current structure is clear.
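The priority order can be sketched as follows. This simplified version deliberately skips the `scontrol` subprocess step and goes straight to a best-effort parser; `first_host` and `resolve_master_addr` are hypothetical names, and the parser only handles the common `prefix[a-b,c]` nodelist shape.

```python
import os
import re
from typing import Mapping, Optional


def first_host(nodelist: str) -> str:
    """Best-effort first hostname from a compact SLURM nodelist."""
    match = re.match(r"([^\[,]+)(?:\[([^\]]+)\])?", nodelist)
    if match is None:
        return nodelist
    prefix, ranges = match.group(1), match.group(2)
    if ranges is None:
        # Plain comma-separated list: the first entry is already a hostname.
        return prefix
    # Bracketed list such as "01-04,07": take the start of the first range.
    return prefix + ranges.split(",")[0].split("-")[0]


def resolve_master_addr(env: Optional[Mapping[str, str]] = None) -> str:
    env = os.environ if env is None else env
    if env.get("MASTER_ADDR"):
        return env["MASTER_ADDR"]
    nodelist = env.get("SLURM_JOB_NODELIST") or env.get("SLURM_NODELIST")
    if nodelist:
        return first_host(nodelist)
    return "localhost"
```

A real implementation would prefer `scontrol show hostnames`, as the review notes, because nodelist syntax has more forms than this parser covers.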
Actionable comments posted: 1
🧹 Nitpick comments (1)
flashinfer/comm/trtllm_ar.py (1)
536-542: Minor formatting issue in docstring.
Stray colon after `create_metadata=True` and a missing space after a comma in the return type description at lines 536 and 539.
🔎 Suggested fix
-     - If create_metadata=True: and use_symm_dev_mem=False: (ipc_handles, workspace_tensor, metadata)
+     - If create_metadata=True and use_symm_dev_mem=False: (ipc_handles, workspace_tensor, metadata)
        where metadata contains: tp_rank, tp_size, max_token_num, hidden_dim, use_fp32_lamport, buffer_size, flag_size, lamport_comm_size, lamport_buffer_size
-     - If create_metadata=True: and use_symm_dev_mem=True: (ipc_handles, workspace_tensor, mem_handles,metadata)
+     - If create_metadata=True and use_symm_dev_mem=True: (ipc_handles, workspace_tensor, mem_handles, metadata)
📒 Files selected for processing (1)
flashinfer/comm/trtllm_ar.py (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/comm/trtllm_ar.py (2)
  flashinfer/comm/mnnvl.py (7)
    CommBackend (152-171), SymmDeviceMemory (873-1287), TorchDistBackend (235-312), round_up (61-63), barrier (168-168), barrier (227-228), barrier (273-274)
  flashinfer/utils.py (1)
    round_up (636-638)
🪛 Ruff (0.14.8)
flashinfer/comm/trtllm_ar.py
568-568: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (5)
flashinfer/comm/trtllm_ar.py (5)
24-24: LGTM!
Import correctly brings in the new abstractions needed for the symmetric device memory path.
515-521: LGTM!
The new parameters and expanded return types correctly model the various configurations supported. The Union type accurately reflects the different return forms based on parameter combinations.
563-568: LGTM!
The auto-initialization of `TorchDistBackend` when `use_symm_dev_mem` is requested simplifies the API. The validation ensures that `mem_handles` can be properly returned in the result tuple.
665-668: LGTM!
The conditional barrier logic correctly uses the appropriate synchronization mechanism based on the memory allocation path chosen.
682-688: LGTM!
The return logic correctly handles all parameter combinations and matches the documented return types.
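The flag-dependent return shapes discussed in this review can be summarized in a small sketch. `pack_workspace_result` is a hypothetical helper written for illustration; only the tuple shapes and the "`use_symm_dev_mem` requires `create_metadata=True`" rule are taken from the docstring and comments quoted in this thread.

```python
from typing import Any, Dict, List, Tuple


def pack_workspace_result(
    ipc_handles: List[Any],
    workspace_tensor: Any,
    metadata: Dict[str, Any],
    mem_handles: List[Any],
    *,
    create_metadata: bool,
    use_symm_dev_mem: bool,
) -> Tuple[Any, ...]:
    """Return the tuple shape that matches the two flags."""
    if use_symm_dev_mem and not create_metadata:
        # mem_handles are only returned through the metadata path.
        raise ValueError("use_symm_dev_mem=True requires create_metadata=True")
    if not create_metadata:
        return ipc_handles, workspace_tensor
    if use_symm_dev_mem:
        return ipc_handles, workspace_tensor, mem_handles, metadata
    return ipc_handles, workspace_tensor, metadata
```

Callers that unpack the result need to know which flags were passed, which is one reason the workspace classes wrap this call rather than exposing the raw tuple.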
    symm_mem = SymmDeviceMemory(
        aligned_size,
        tp_size,
        tp_rank,
        torch.device("cuda", tp_rank).index,
        comm_backend,
        enable_multicast=False,
        allocate_signal_pads=False,
    )
    ipc_handles.append(symm_mem.uc_ptrs)
    mem_handles.append(symm_mem)
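The snippet above uses `tp_rank` directly as the CUDA device index, which only holds when each rank sees its device at the same index as its rank. A minimal sketch of one alternative mapping, assuming ranks are packed node by node with the same number of GPUs on every node; `local_device_index` is a hypothetical helper and not part of flashinfer:

```python
def local_device_index(global_rank: int, gpus_per_node: int) -> int:
    """Map a global rank to a node-local device index.

    Assumes ranks are assigned contiguously per node, so rank 9 with
    8 GPUs per node lands on device 1 of the second node.
    """
    if gpus_per_node <= 0:
        raise ValueError("gpus_per_node must be positive")
    if global_rank < 0:
        raise ValueError("global_rank must be non-negative")
    return global_rank % gpus_per_node
```

Launchers usually export a local rank directly (for example `LOCAL_RANK` under torchrun), and reading that value is generally more robust than deriving it from the global rank.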
🧩 Analysis chain
🏁 Script executed:
# Examine the specific file around lines 600-610
head -620 flashinfer/comm/trtllm_ar.py | tail -50Repository: flashinfer-ai/flashinfer
Length of output: 2052
🏁 Script executed:
# Search for torch.cuda.set_device or torch.cuda.current_device usage
rg -n "torch\.cuda\.(set_device|current_device)" flashinfer/comm/trtllm_ar.pyRepository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Find all SymmDeviceMemory instantiations to see patterns
rg -n "SymmDeviceMemory\(" flashinfer/comm/ -A3Repository: flashinfer-ai/flashinfer
Length of output: 566
🏁 Script executed:
# Get the full function context - find the function definition
rg -n "^def " flashinfer/comm/trtllm_ar.py | tail -5
Repository: flashinfer-ai/flashinfer
Length of output: 272
🏁 Script executed:
# Get the function that contains lines 600-610
head -600 flashinfer/comm/trtllm_ar.py | grep -n "^def " | tail -1
Repository: flashinfer-ai/flashinfer
Length of output: 126
🏁 Script executed:
# Look at the full function starting from around line 550
sed -n '550,620p' flashinfer/comm/trtllm_ar.py
Repository: flashinfer-ai/flashinfer
Length of output: 3101
🏁 Script executed:
# Check the mnnvl.py usage to compare patterns
sed -n '1310,1330p' flashinfer/comm/mnnvl.py
Repository: flashinfer-ai/flashinfer
Length of output: 1095
🏁 Script executed:
# Search for how tp_rank is obtained in trtllm_ar.py
rg -n "tp_rank\s*=" flashinfer/comm/trtllm_ar.py | head -10
Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Check the function signature and parameters
sed -n '507,520p' flashinfer/comm/trtllm_ar.py
Repository: flashinfer-ai/flashinfer
Length of output: 569
🏁 Script executed:
# Look for calls to trtllm_create_ipc_workspace_for_all_reduce_fusion
rg -n "trtllm_create_ipc_workspace_for_all_reduce_fusion" flashinfer/
Repository: flashinfer-ai/flashinfer
Length of output: 1015
🏁 Script executed:
# Check if torch.cuda.set_device is called anywhere before this function or in callers
rg -n "set_device\|current_device" flashinfer/comm/trtllm_ar.py
Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Get full function signature and docstring
sed -n '507,600p' flashinfer/comm/trtllm_ar.py | head -100
Repository: flashinfer-ai/flashinfer
Length of output: 4479
🏁 Script executed:
# Check how the function is called in allreduce.py
sed -n '100,130p' flashinfer/comm/allreduce.py
Repository: flashinfer-ai/flashinfer
Length of output: 1288
🏁 Script executed:
# Look for any device setup code in the init or caller functions
rg -B10 "trtllm_create_ipc_workspace_for_all_reduce_fusion" flashinfer/comm/allreduce.py | head -40
Repository: flashinfer-ai/flashinfer
Length of output: 1452
🏁 Script executed:
# Check what TorchDistBackend does and if it sets device
rg -n "class TorchDistBackend" flashinfer/comm/ -A20
Repository: flashinfer-ai/flashinfer
Length of output: 1328
🏁 Script executed:
# Look for where this function is called and what tp_rank context is
rg -B5 -A5 "TRTLLMAllReduceFusionWorkspace\(" flashinfer/ --type=py
Repository: flashinfer-ai/flashinfer
Length of output: 1423
🏁 Script executed:
# Check if there's documentation or examples on device setup
find flashinfer -name "*.py" -type f -exec grep -l "torch.cuda.set_device\|torch.cuda.current_device" {} \;
Repository: flashinfer-ai/flashinfer
Length of output: 149
🏁 Script executed:
# Look at the SymmDeviceMemory constructor to understand what device_idx is used for
rg -n "def __init__" flashinfer/comm/mnnvl.py | grep -A2 SymmDeviceMemory
Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Check if there's any device context setup in tests
find flashinfer -name "*test*.py" -o -name "test_*.py" | head -5
Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Check mnnvl.py for SymmDeviceMemory usage to understand device_idx parameter
rg -B5 -A15 "class SymmDeviceMemory" flashinfer/comm/mnnvl.py
Repository: flashinfer-ai/flashinfer
Length of output: 767
🏁 Script executed:
# Look for the actual binding/implementation
find flashinfer -name "*.pyi" -o -name "*.pyx" | head -10
Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Search for torch.device usage in the codebase with tp_rank
rg "torch\.device.*tp_rank" flashinfer/
Repository: flashinfer-ai/flashinfer
Length of output: 149
🏁 Script executed:
# Check if there are any comments or documentation about device mapping
rg -B2 -A2 "device.*index\|device_idx" flashinfer/comm/trtllm_ar.py
Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Look at how TorchDistBackend is used - understand if device setup happens there
sed -n '235,300p' flashinfer/comm/mnnvl.py
Repository: flashinfer-ai/flashinfer
Length of output: 2436
🏁 Script executed:
# Check the actual usage of tp_rank - is it meant to be a device index?
rg -B10 -A5 "tp_rank.*device\|device.*tp_rank" flashinfer/comm/trtllm_ar.py | head -40
Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Search for any device initialization or setup in test files
cat flashinfer/testing/utils.py | head -50
Repository: flashinfer-ai/flashinfer
Length of output: 1466
🏁 Script executed:
# Check the SymmDeviceMemory usage - specifically what cuDeviceGet does with device_idx
sed -n '1050,1120p' flashinfer/comm/mnnvl.py
Repository: flashinfer-ai/flashinfer
Length of output: 3065
🏁 Script executed:
# Look at how MNNVL usage passes device index - comparison point
sed -n '1310,1340p' flashinfer/comm/mnnvl.py
Repository: flashinfer-ai/flashinfer
Length of output: 1518
🏁 Script executed:
# Search for any documentation or examples about tensor parallel device assignment
rg -i "tp_rank.*device\|device.*assignment\|cuda.*set_device" flashinfer/comm/ --type=py
Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Check how trtllm_create_ipc_workspace_for_all_reduce_fusion is supposed to be called
# Look at examples or documentation
rg -B10 "create_ipc_workspace_for_all_reduce_fusion\(" flashinfer/comm/allreduce.py | head -50
Repository: flashinfer-ai/flashinfer
Length of output: 546
🏁 Script executed:
# Look for any initialization examples or tests that show device setup
find flashinfer -type f -name "*.py" -exec grep -l "AllReduceFusionWorkspace\|trtllm_create_ipc_workspace" {} \;
Repository: flashinfer-ai/flashinfer
Length of output: 222
🏁 Script executed:
# Check if tp_rank is documented as being device index
sed -n '507,540p' flashinfer/comm/trtllm_ar.py
Repository: flashinfer-ai/flashinfer
Length of output: 1688
🏁 Script executed:
# Check what torch.device("cuda", tp_rank).index actually does
# Look for any documentation or usage patterns
rg "torch\.device.*index" flashinfer/comm/ -B2 -A2
Repository: flashinfer-ai/flashinfer
Length of output: 385
🏁 Script executed:
# Check the MNNVL implementation to see how it handles device index
sed -n '1300,1350p' flashinfer/comm/mnnvl.py
Repository: flashinfer-ai/flashinfer
Length of output: 2109
🏁 Script executed:
# Look for how the function is documented to be called - any assumption about device setup
sed -n '520,565p' flashinfer/comm/trtllm_ar.py
Repository: flashinfer-ai/flashinfer
Length of output: 2483
🏁 Script executed:
# Check how torch.device works with indices
python3 << 'EOF'
import torch
# Test what torch.device("cuda", index).index returns
for i in range(3):
    dev = torch.device("cuda", i)
    print(f"torch.device('cuda', {i}).index = {dev.index}")
EOF
Repository: flashinfer-ai/flashinfer
Length of output: 154
🏁 Script executed:
# Look for any initialization or setup code that calls torch.cuda.set_device before this function
rg -B30 "trtllm_create_ipc_workspace_for_all_reduce_fusion" flashinfer/comm/allreduce.py | grep -A30 "def __init__"
Repository: flashinfer-ai/flashinfer
Length of output: 916
🏁 Script executed:
# Check if there's any documentation about device setup requirements
sed -n '75,145p' flashinfer/comm/allreduce.py
Repository: flashinfer-ai/flashinfer
Length of output: 2742
🏁 Script executed:
# Understand what torch.device().index does - is it redundant?
python3 << 'EOF'
import torch
# Verify torch.device("cuda", tp_rank).index behavior
tp_rank = 2
dev = torch.device("cuda", tp_rank)
print(f"torch.device('cuda', {tp_rank}).index = {dev.index}")
print(f"Type: {type(dev.index)}")
EOF
Repository: flashinfer-ai/flashinfer
Length of output: 154
🏁 Script executed:
# Check if torch.cuda.set_device is called anywhere that would setup the device context
rg -B20 "class TRTLLMAllReduceFusionWorkspace" flashinfer/comm/allreduce.py | head -40
Repository: flashinfer-ai/flashinfer
Length of output: 1082
🏁 Script executed:
# Look at the SymmDeviceMemory initialization - what does device_idx parameter do?
sed -n '1050,1080p' flashinfer/comm/mnnvl.py
Repository: flashinfer-ai/flashinfer
Length of output: 1305
Use torch.cuda.current_device() to obtain the device index instead of assuming tp_rank maps to device.
The device index is derived as torch.device("cuda", tp_rank).index, which assumes that tp_rank is also the CUDA device index. In tensor parallelism, tp_rank is the rank within the tensor-parallel group, not a device index. The assumption breaks in multi-node deployments, with non-contiguous GPU assignments, or when CUDA_VISIBLE_DEVICES reorders GPUs.
Use torch.cuda.current_device() instead, which returns the index of the CUDA device currently selected in the calling process:
  symm_mem = SymmDeviceMemory(
      aligned_size,
      tp_size,
      tp_rank,
-     torch.device("cuda", tp_rank).index,
+     torch.cuda.current_device(),
      comm_backend,
      enable_multicast=False,
      allocate_signal_pads=False,
  )

This requires that the caller properly initializes the CUDA device context (via torch.cuda.set_device()) before invoking this function.
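To make the rank-versus-device distinction concrete, here is a minimal pure-Python sketch. It assumes one process per GPU and ranks assigned contiguously node by node; `local_device_index` and the 2-node, 8-GPU layout are hypothetical and are not taken from the PR.

```python
def local_device_index(global_rank: int, gpus_per_node: int) -> int:
    """Map a global rank to a per-node CUDA device index.

    Assumes one process per GPU and ranks assigned contiguously node by
    node (an illustrative assumption, not something the PR guarantees).
    """
    if gpus_per_node <= 0:
        raise ValueError(gpus_per_node)
    return global_rank % gpus_per_node


# Hypothetical layout: 16 tensor-parallel ranks over 2 nodes of 8 GPUs each.
GPUS_PER_NODE = 8
for tp_rank in (0, 7, 8, 9, 15):
    device = local_device_index(tp_rank, GPUS_PER_NODE)
    print(f"tp_rank={tp_rank:2d} -> local CUDA device {device}")
```

Under this layout, tp_rank 9 runs on local device 1, so using tp_rank directly as the device index would name a GPU that does not exist on an 8-GPU node. In practice the launcher usually supplies the local rank (for example, torchrun sets the LOCAL_RANK environment variable), and the caller passes it to torch.cuda.set_device() before creating the workspace.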
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
- symm_mem = SymmDeviceMemory(
-     aligned_size,
-     tp_size,
-     tp_rank,
-     torch.device("cuda", tp_rank).index,
-     comm_backend,
-     enable_multicast=False,
-     allocate_signal_pads=False,
- )
- ipc_handles.append(symm_mem.uc_ptrs)
- mem_handles.append(symm_mem)
+ symm_mem = SymmDeviceMemory(
+     aligned_size,
+     tp_size,
+     tp_rank,
+     torch.cuda.current_device(),
+     comm_backend,
+     enable_multicast=False,
+     allocate_signal_pads=False,
+ )
+ ipc_handles.append(symm_mem.uc_ptrs)
+ mem_handles.append(symm_mem)
🤖 Prompt for AI Agents
In flashinfer/comm/trtllm_ar.py around lines 600 to 610, the code constructs
torch.device("cuda", tp_rank).index which wrongly assumes tp_rank equals the
global CUDA device index; replace that with torch.cuda.current_device() (or
torch.device("cuda", torch.cuda.current_device()).index) so the actual
process-bound CUDA device index is used when constructing SymmDeviceMemory, and
ensure callers set the CUDA device (torch.cuda.set_device(...)) before calling
this function; update any related docstring/comments to note the requirement
that the CUDA device must be initialized per-process.
yzh119
left a comment
Thanks for making this improvement!
📌 Description
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commit by running pip install pre-commit (or used your preferred method).
pre-commit install.
pre-commit run --all-files and fixed any reported issues.
🧪 Tests
Updated tests:
mpirun -np 4 pytest tests/comm/test_allreduce_unified_api.py
python tests/comm/test_trtllm_allreduce_fusion.py
srun -A coreai_libraries_cudnn -N4 --container-image=<flashinfer_image> -J --mpi=pmix -- bash -c 'hostname && cd <path_to_flashinfer> && pip install -e . && python -m pytest tests/comm/test_trtllm_mnnvl_allreduce.py'
unittest, etc.).
Reviewer Notes
Summary by CodeRabbit
New Features
Improvements
Refactor
Bug Fixes
Tests