comm: avoid torch symmetric memory by default for TRTLLM allreduce workspaces#3277
comm: avoid torch symmetric memory by default for TRTLLM allreduce workspaces#3277mmangkad wants to merge 1 commit into
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughAdds a Torch-symmetric-memory allocation mode for TRTLLM and MNNVL all-reduce fusion workspaces, introduces a unified allocator, relaxes workspace reference types to List[Any], and wires the mode through constructors, workspace creation, synchronization, and cleanup. ChangesTorch Symmetric Memory Support for All-Reduce
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request introduces a mechanism to toggle between FlashInfer's internal symmetric memory and PyTorch's symmetric memory for TRTLLM all-reduce workspaces, controlled by the FLASHINFER_TRTLLM_AR_USE_TORCH_SYMM_MEM environment variable. The changes include refactoring workspace allocation into centralized helper functions, updating type hints to be more generic, and modifying the MNNVLAllReduceFusionWorkspace to support both memory backends. A review comment points out a potential correctness issue where lamport_initialize is hardcoded to torch.float32 instead of using the provided dtype.
There was a problem hiding this comment.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/comm/trtllm_ar.py (1)
777-780:⚠️ Potential issue | 🟠 Major | ⚡ Quick winBarrier selection must match the allocator that was actually chosen.
The allocator selection was changed to use the computed
use_torch_symm_mem(line 691), but the barrier selection at lines 777–780 still keys off the parameteruse_symm_dev_mem. This creates a mismatch: whenuse_symm_dev_mem=False(default) anduse_torch_symm_mem=False, the code allocates usingSymmDeviceMemory(comm_backend)(which exchanges handles via the backend), then syncs ondist.barrier()instead of the backend barrier. For non-TorchDistBackendcallers, the IPC handles risk being unsynchronized—causing hangs or races.Suggested fix
- if use_symm_dev_mem: - comm_backend.barrier() # must sync after create_workspace - else: - dist.barrier(group=group) + if use_torch_symm_mem: + dist.barrier(group=group) + else: + comm_backend.barrier() # must sync after create_workspace🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/comm/trtllm_ar.py` around lines 777 - 780, The barrier selection must match the allocator chosen; replace the conditional that currently checks use_symm_dev_mem with the computed flag use_torch_symm_mem so the code calls comm_backend.barrier() when SymmDeviceMemory (backend-managed IPC) was actually selected and calls dist.barrier(group=group) otherwise; update the branch that contains comm_backend.barrier() / dist.barrier(...) to key off use_torch_symm_mem and ensure this aligns with the allocator creation logic (SymmDeviceMemory and any places that set use_torch_symm_mem).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@flashinfer/comm/trtllm_mnnvl_ar.py`:
- Around line 142-159: The torch-symmetric memory path uses _uses_torch_symm_mem
and currently assumes a TorchDistBackend when building group_name, but silently
falls back to torch.distributed.group.WORLD for non-Torch backends (e.g.,
MPIBackend); change this to explicitly require comm_backend to be an instance of
TorchDistBackend before calling _alloc_symm_buffer_bytes: if comm_backend is not
a TorchDistBackend, raise a clear error (or disable the torch-symmetric option)
so _alloc_symm_buffer_bytes is only invoked with a valid TorchDistBackend and a
proper group_name; update the branch around self._uses_torch_symm_mem,
TorchDistBackend, comm_backend, group_name, and the call to
_alloc_symm_buffer_bytes (which sets self.ptrs, self.tensor, self.handle).
---
Outside diff comments:
In `@flashinfer/comm/trtllm_ar.py`:
- Around line 777-780: The barrier selection must match the allocator chosen;
replace the conditional that currently checks use_symm_dev_mem with the computed
flag use_torch_symm_mem so the code calls comm_backend.barrier() when
SymmDeviceMemory (backend-managed IPC) was actually selected and calls
dist.barrier(group=group) otherwise; update the branch that contains
comm_backend.barrier() / dist.barrier(...) to key off use_torch_symm_mem and
ensure this aligns with the allocator creation logic (SymmDeviceMemory and any
places that set use_torch_symm_mem).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 798dd963-812c-4b72-b365-abb0642a01dd
📥 Commits
Reviewing files that changed from the base of the PR and between 0a128d1 and 7c5c276322311a9dd09096a1e204a72ea1c65525.
📒 Files selected for processing (3)
flashinfer/comm/allreduce.pyflashinfer/comm/trtllm_ar.pyflashinfer/comm/trtllm_mnnvl_ar.py
7c5c276 to
f83e8d6
Compare
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@flashinfer/comm/trtllm_mnnvl_ar.py`:
- Around line 150-162: Before calling _alloc_symm_buffer_bytes, add a runtime
guard that verifies the torch process group derived from comm_backend._group (or
torch.distributed.group.WORLD) matches the Mapping metadata: ensure group.size()
== mapping.tp_size and group.rank() == self.rank (or otherwise map consistently)
and if not, log an error and abort/raise; this check should live just above the
_alloc_symm_buffer_bytes call (referencing comm_backend._group, group,
mapping.tp_size, self.rank) and similarly before the later use of
self.handle.buffer_ptrs[self.rank] to prevent out‑of‑bounds or wrong‑peer
allocation.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: a77b0ee7-2832-4cf8-a5c7-ad4624214efe
📥 Commits
Reviewing files that changed from the base of the PR and between 7c5c276322311a9dd09096a1e204a72ea1c65525 and f83e8d602ef1cc8f86d331268b13117d6710e2a0.
📒 Files selected for processing (3)
flashinfer/comm/allreduce.pyflashinfer/comm/trtllm_ar.pyflashinfer/comm/trtllm_mnnvl_ar.py
🚧 Files skipped from review as they are similar to previous changes (2)
- flashinfer/comm/allreduce.py
- flashinfer/comm/trtllm_ar.py
dddc01e to
6968afa
Compare
ae95a0e to
d4ce499
Compare
d4ce499 to
bfdebb6
Compare
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@flashinfer/comm/trtllm_mnnvl_ar.py`:
- Around line 288-293: The cleanup in destroy() uses getattr(self,
"_uses_torch_symm_mem", True) which contradicts the constructor default False
and can attempt to delete missing attributes; change the fallback to False
(getattr(self, "_uses_torch_symm_mem", False)) and guard deletions with hasattr
checks (or use try/except AttributeError) when removing tensor, handle, ptrs or
mcast_buffer_handle so an early-failed __init__ won't raise; update the logic in
the destroy method referencing _uses_torch_symm_mem, tensor, handle, ptrs, and
mcast_buffer_handle accordingly.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 3c380fd9-0eed-4d18-bde9-2e4473e485c7
📒 Files selected for processing (3)
flashinfer/comm/allreduce.pyflashinfer/comm/trtllm_ar.pyflashinfer/comm/trtllm_mnnvl_ar.py
| dtype: torch.dtype = torch.float16, | ||
| comm_backend: Optional[CommBackend] = None, | ||
| group: Optional[ProcessGroup] = None, | ||
| use_torch_symm_mem: bool = False, |
There was a problem hiding this comment.
@aleozlx is this considered breaking backward compatibility? or is this fine?
📌 Description
PR #2955 changed TRTLLM allreduce workspace allocation to use
torch.distributed._symmetric_memoryby default. After that change, every TRTLLM allreduce workspace creation depends on PyTorch's symmetric-memory allocation and rendezvous behavior. In practice this path appears less stable for these communication kernels than the previous FlashInfer/TensorRT-styleSymmDeviceMemoryallocator: I encountered fabric-level TRTLLM allreduce/fusion failures on one SM103 system, while another SM103 system worked, which suggests the default path may be sensitive to platform, fabric, driver, or environment differences. There may also be other failure modes from using torch symmetric memory as the default here that have not surfaced yet.This PR restores the FlashInfer/TensorRT-style
SymmDeviceMemoryallocator as the default for TRTLLM allreduce workspaces. That keeps the default allreduce workspace path closer to the original TRTLLM implementation and avoids making PyTorch symmetric memory a hard dependency for these kernels. The torch symmetric-memory path is still available as an explicit opt-in for deployments that want to test or use it by settingFLASHINFER_TRTLLM_AR_USE_TORCH_SYMM_MEM=1.🔍 Related Issues
Related to #2955.
🚀 Pull Request Checklist
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).Reviewer Notes
This PR does not remove the torch symmetric-memory path. It changes it from the default allocator to an opt-in allocator for TRTLLM allreduce workspaces.
Summary by CodeRabbit