comm: avoid torch symmetric memory by default for TRTLLM allreduce workspaces#3277

Open
mmangkad wants to merge 1 commit into
flashinfer-ai:mainfrom
mmangkad-dev:fix/trtllm-ar-symm-memory-default

Conversation


@mmangkad mmangkad commented May 9, 2026

📌 Description

PR #2955 changed TRTLLM allreduce workspace allocation to use torch.distributed._symmetric_memory by default. Since that change, every TRTLLM allreduce workspace creation depends on PyTorch's symmetric-memory allocation and rendezvous behavior. In practice, this path appears less stable for these communication kernels than the previous FlashInfer/TensorRT-style SymmDeviceMemory allocator. I encountered fabric-level TRTLLM allreduce/fusion failures on one SM103 system while another SM103 system worked, which suggests the default path may be sensitive to platform, fabric, driver, or environment differences. Using torch symmetric memory as the default may also have other failure modes that have not surfaced yet.

This PR restores the FlashInfer/TensorRT-style SymmDeviceMemory allocator as the default for TRTLLM allreduce workspaces. That keeps the default allreduce workspace path closer to the original TRTLLM implementation and avoids making PyTorch symmetric memory a hard dependency for these kernels. The torch symmetric-memory path is still available as an explicit opt-in for deployments that want to test or use it by setting FLASHINFER_TRTLLM_AR_USE_TORCH_SYMM_MEM=1.
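As a rough illustration of the opt-in described above, the sketch below reads the `FLASHINFER_TRTLLM_AR_USE_TORCH_SYMM_MEM` variable and falls back to the SymmDeviceMemory path when it is unset. Only the variable name and the value `1` come from this description; the helper names `use_torch_symm_mem_from_env` and `pick_allocator`, and the decision to treat every other value as "off", are assumptions made for illustration and may not match the actual implementation.

```python
import os

# Variable name taken from the PR description.
ENV_VAR = "FLASHINFER_TRTLLM_AR_USE_TORCH_SYMM_MEM"


def use_torch_symm_mem_from_env(environ=None) -> bool:
    """Hypothetical helper: True only when the opt-in variable is set to "1"."""
    env = os.environ if environ is None else environ
    # Assumption: any value other than "1" (including unset) keeps the default.
    return env.get(ENV_VAR, "0").strip() == "1"


def pick_allocator(environ=None) -> str:
    """Return a label for the allocator that would be selected (illustrative)."""
    if use_torch_symm_mem_from_env(environ):
        return "torch_symm_mem"
    return "symm_device_memory"
```

With this sketch, an empty environment selects `symm_device_memory`, and exporting the variable as `1` before launching the process selects `torch_symm_mem`.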

🔍 Related Issues

Related to #2955.

🚀 Pull Request Checklist

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

This PR does not remove the torch symmetric-memory path. It changes it from the default allocator to an opt-in allocator for TRTLLM allreduce workspaces.

Summary by CodeRabbit

  • Refactor
    • Workspace allocation for TRTLLM and MNNVL all‑reduce fusion now supports an optional Torch symmetric‑memory mode and unifies allocation/teardown across modes.
  • New Features
    • Public APIs now expose a flag to opt into Torch symmetric memory for workspace allocation; process‑group validation is performed when enabled.
  • Documentation
    • Docstrings clarified to describe the “group” parameter as the workspace‑allocation process group.

Review Change Stack

Contributor

coderabbitai Bot commented May 9, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Walkthrough

Adds a Torch-symmetric-memory allocation mode for TRTLLM and MNNVL all-reduce fusion workspaces, introduces a unified allocator, relaxes workspace reference types to List[Any], and wires the mode through constructors, workspace creation, synchronization, and cleanup.

Changes

Torch Symmetric Memory Support for All-Reduce

| Layer | File(s) | Summary |
|---|---|---|
| Documentation and public API updates | flashinfer/comm/allreduce.py | Removed SymmDeviceMemory import, added use_torch_symm_mem params to constructors/functions, updated group docstrings, and loosened workspace cast to List[Any]. |
| Allocator helper and typing | flashinfer/comm/trtllm_ar.py | Added _alloc_trtllm_ar_workspace_buffer, widened _symm_workspace_refs and fusion return annotations to Any, and added Any typing import. |
| TRTLLM IPC workspace creation | flashinfer/comm/trtllm_ar.py | trtllm_create_ipc_workspace_for_all_reduce gains use_torch_symm_mem, initializes/validates TorchDistBackend/group when used, and records allocator-owned refs returned by the unified helper. |
| TRTLLM fusion workspace | flashinfer/comm/trtllm_ar.py | trtllm_create_ipc_workspace_for_all_reduce_fusion uses use_torch_symm_mem, initializes comm_backend if None, validates Torch group when required, uses unified allocator to build ipc_handles and mem_handles: List[Any], and calls comm_backend.barrier(). |
| MNNVL workspace integration | flashinfer/comm/trtllm_mnnvl_ar.py | MNNVLAllReduceFusionWorkspace adds use_torch_symm_mem, conditionally allocates Torch symmetric-memory or McastGPUBuffer, initializes buffers, wires pointer fields, and conditionally frees matching resources in destroy(). |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • flashinfer-ai/flashinfer#2955: Also modifies all-reduce workspace allocation in trtllm_ar.py and trtllm_mnnvl_ar.py to add/condition on Torch symmetric memory.
  • flashinfer-ai/flashinfer#2239: Changes the same TRTLLM/MNNVL all-reduce workspace allocation code paths and signatures related to symmetric-memory modes.
  • flashinfer-ai/flashinfer#3247: Related edits to TRTLLM AllReduce fusion workspace constructors and call sites affecting workspace parameterization.

Suggested reviewers

  • aleozlx
  • yzh119
  • bkryu
  • jimmyzho
  • nv-yunzheq

Poem

🐰 I hop through buffers with whiskers bright,
Torch or Symm — I choose by night.
IPCs clack, mem handles in tow,
Typed as Any so meetings flow.
Carrot-coded sync — hop, allocate, go!

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly and concisely summarizes the main change: making torch symmetric memory opt-in rather than default for TRTLLM allreduce workspaces. |
| Description check | ✅ Passed | The description covers the motivation, implementation approach, and provides related issue context. Pre-commit checks are marked complete, though tests are noted as incomplete. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 80.00% which is sufficient. The required threshold is 80.00%. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |



@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a mechanism to toggle between FlashInfer's internal symmetric memory and PyTorch's symmetric memory for TRTLLM all-reduce workspaces, controlled by the FLASHINFER_TRTLLM_AR_USE_TORCH_SYMM_MEM environment variable. The changes include refactoring workspace allocation into centralized helper functions, updating type hints to be more generic, and modifying the MNNVLAllReduceFusionWorkspace to support both memory backends. A review comment points out a potential correctness issue where lamport_initialize is hardcoded to torch.float32 instead of using the provided dtype.

Comment thread flashinfer/comm/trtllm_mnnvl_ar.py

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/comm/trtllm_ar.py (1)

777-780: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Barrier selection must match the allocator that was actually chosen.

The allocator selection was changed to use the computed use_torch_symm_mem (line 691), but the barrier selection at lines 777–780 still keys off the parameter use_symm_dev_mem. This creates a mismatch: when use_symm_dev_mem=False (default) and use_torch_symm_mem=False, the code allocates using SymmDeviceMemory(comm_backend) (which exchanges handles via the backend), then syncs on dist.barrier() instead of the backend barrier. For non-TorchDistBackend callers, the IPC handles risk being unsynchronized—causing hangs or races.

Suggested fix
-    if use_symm_dev_mem:
-        comm_backend.barrier()  # must sync after create_workspace
-    else:
-        dist.barrier(group=group)
+    if use_torch_symm_mem:
+        dist.barrier(group=group)
+    else:
+        comm_backend.barrier()  # must sync after create_workspace
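The reviewer's point can be sketched in isolation: the synchronization call should be chosen by the same flag that chose the allocator. The example below is a minimal, torch-free illustration; `RecordingBackend` and `sync_after_workspace_creation` are hypothetical stand-ins, and `torch_barrier` is a placeholder callable for what would be `dist.barrier(group=group)` in the real code.

```python
class RecordingBackend:
    """Hypothetical stand-in for a comm backend; records barrier calls."""

    def __init__(self):
        self.calls = []

    def barrier(self):
        self.calls.append("comm_backend.barrier")


def sync_after_workspace_creation(use_torch_symm_mem, comm_backend, torch_barrier):
    """Pick the barrier that matches the allocator selected earlier.

    torch_barrier is a placeholder for dist.barrier(group=group).
    """
    if use_torch_symm_mem:
        # Torch symmetric memory was used, so sync on the torch process group.
        torch_barrier()
    else:
        # SymmDeviceMemory exchanged handles through the backend, so the
        # backend's own barrier is the one that must be used.
        comm_backend.barrier()


backend = RecordingBackend()
torch_calls = []
sync_after_workspace_creation(False, backend, lambda: torch_calls.append("dist.barrier"))
sync_after_workspace_creation(True, backend, lambda: torch_calls.append("dist.barrier"))
```

After the two calls, each barrier has been invoked exactly once, by the path that owns the corresponding allocation.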
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/comm/trtllm_ar.py` around lines 777 - 780, The barrier selection
must match the allocator chosen; replace the conditional that currently checks
use_symm_dev_mem with the computed flag use_torch_symm_mem so the code calls
comm_backend.barrier() when SymmDeviceMemory (backend-managed IPC) was actually
selected and calls dist.barrier(group=group) otherwise; update the branch that
contains comm_backend.barrier() / dist.barrier(...) to key off
use_torch_symm_mem and ensure this aligns with the allocator creation logic
(SymmDeviceMemory and any places that set use_torch_symm_mem).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@flashinfer/comm/trtllm_mnnvl_ar.py`:
- Around line 142-159: The torch-symmetric memory path uses _uses_torch_symm_mem
and currently assumes a TorchDistBackend when building group_name, but silently
falls back to torch.distributed.group.WORLD for non-Torch backends (e.g.,
MPIBackend); change this to explicitly require comm_backend to be an instance of
TorchDistBackend before calling _alloc_symm_buffer_bytes: if comm_backend is not
a TorchDistBackend, raise a clear error (or disable the torch-symmetric option)
so _alloc_symm_buffer_bytes is only invoked with a valid TorchDistBackend and a
proper group_name; update the branch around self._uses_torch_symm_mem,
TorchDistBackend, comm_backend, group_name, and the call to
_alloc_symm_buffer_bytes (which sets self.ptrs, self.tensor, self.handle).
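A minimal sketch of the explicit backend check requested above, assuming the torch-symmetric path needs a group name that only a torch-backed communicator can provide. The classes here are local stand-ins that merely share names with the FlashInfer backends, and the `group_name` attribute and `resolve_symm_group_name` helper are hypothetical.

```python
class CommBackend:
    """Local stand-in for the abstract backend type (not the real class)."""


class TorchDistBackend(CommBackend):
    """Local stand-in; the group_name attribute is an assumption."""

    def __init__(self, group_name="world"):
        self.group_name = group_name


class MPIBackend(CommBackend):
    """Local stand-in for a non-torch backend."""


def resolve_symm_group_name(comm_backend, use_torch_symm_mem):
    """Return the group name for the torch path, or None when it is disabled.

    Raises instead of silently falling back to the WORLD group when the
    backend cannot supply a torch process group.
    """
    if not use_torch_symm_mem:
        return None
    if not isinstance(comm_backend, TorchDistBackend):
        raise TypeError(
            "use_torch_symm_mem=True requires a TorchDistBackend, got "
            + type(comm_backend).__name__
        )
    return comm_backend.group_name
```

Raising early keeps the failure at construction time rather than surfacing later as a hang or a mismatch between ranks.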


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 798dd963-812c-4b72-b365-abb0642a01dd

📥 Commits

Reviewing files that changed from the base of the PR and between 0a128d1 and 7c5c276322311a9dd09096a1e204a72ea1c65525.

📒 Files selected for processing (3)
  • flashinfer/comm/allreduce.py
  • flashinfer/comm/trtllm_ar.py
  • flashinfer/comm/trtllm_mnnvl_ar.py

Comment thread flashinfer/comm/trtllm_mnnvl_ar.py Outdated
@mmangkad mmangkad force-pushed the fix/trtllm-ar-symm-memory-default branch from 7c5c276 to f83e8d6 Compare May 10, 2026 01:45

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@flashinfer/comm/trtllm_mnnvl_ar.py`:
- Around line 150-162: Before calling _alloc_symm_buffer_bytes, add a runtime
guard that verifies the torch process group derived from comm_backend._group (or
torch.distributed.group.WORLD) matches the Mapping metadata: ensure group.size()
== mapping.tp_size and group.rank() == self.rank (or otherwise map consistently)
and if not, log an error and abort/raise; this check should live just above the
_alloc_symm_buffer_bytes call (referencing comm_backend._group, group,
mapping.tp_size, self.rank) and similarly before the later use of
self.handle.buffer_ptrs[self.rank] to prevent out‑of‑bounds or wrong‑peer
allocation.
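The guard described above could look roughly like the following. This is a sketch under stated assumptions: `FakeGroup` is a hypothetical test double exposing `size()` and `rank()` like a torch process group, and `validate_group_matches_mapping` is an illustrative name, not a function from the PR.

```python
from dataclasses import dataclass


@dataclass
class FakeGroup:
    """Hypothetical test double exposing size() and rank()."""

    world_size: int
    group_rank: int

    def size(self):
        return self.world_size

    def rank(self):
        return self.group_rank


def validate_group_matches_mapping(group, tp_size, rank):
    """Fail fast if the process group disagrees with the mapping metadata.

    Intended to run before symmetric buffers are allocated and before any
    per-rank pointer such as buffer_ptrs[rank] is indexed.
    """
    if group.size() != tp_size:
        raise ValueError(
            f"process group size {group.size()} does not match tp_size {tp_size}"
        )
    if group.rank() != rank:
        raise ValueError(
            f"process group rank {group.rank()} does not match mapping rank {rank}"
        )
```

For instance, `validate_group_matches_mapping(FakeGroup(4, 1), tp_size=4, rank=1)` passes silently, while a group of size 8 with `tp_size=4` raises `ValueError`.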

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: a77b0ee7-2832-4cf8-a5c7-ad4624214efe

📥 Commits

Reviewing files that changed from the base of the PR and between 7c5c276322311a9dd09096a1e204a72ea1c65525 and f83e8d602ef1cc8f86d331268b13117d6710e2a0.

📒 Files selected for processing (3)
  • flashinfer/comm/allreduce.py
  • flashinfer/comm/trtllm_ar.py
  • flashinfer/comm/trtllm_mnnvl_ar.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • flashinfer/comm/allreduce.py
  • flashinfer/comm/trtllm_ar.py

Comment thread flashinfer/comm/trtllm_mnnvl_ar.py
@mmangkad mmangkad force-pushed the fix/trtllm-ar-symm-memory-default branch 2 times, most recently from dddc01e to 6968afa Compare May 10, 2026 02:16
Comment thread flashinfer/comm/trtllm_ar.py Outdated
@mmangkad mmangkad force-pushed the fix/trtllm-ar-symm-memory-default branch 2 times, most recently from ae95a0e to d4ce499 Compare May 23, 2026 04:43
@mmangkad mmangkad force-pushed the fix/trtllm-ar-symm-memory-default branch from d4ce499 to bfdebb6 Compare May 23, 2026 04:45

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@flashinfer/comm/trtllm_mnnvl_ar.py`:
- Around line 288-293: The cleanup in destroy() uses getattr(self,
"_uses_torch_symm_mem", True) which contradicts the constructor default False
and can attempt to delete missing attributes; change the fallback to False
(getattr(self, "_uses_torch_symm_mem", False)) and guard deletions with hasattr
checks (or use try/except AttributeError) when removing tensor, handle, ptrs or
mcast_buffer_handle so an early-failed __init__ won't raise; update the logic in
the destroy method referencing _uses_torch_symm_mem, tensor, handle, ptrs, and
mcast_buffer_handle accordingly.
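The cleanup pattern suggested above can be illustrated with a small stand-in class. `WorkspaceSketch` is hypothetical and uses plain objects in place of real buffers; only the attribute names (`_uses_torch_symm_mem`, `tensor`, `handle`, `ptrs`, `mcast_buffer_handle`) are taken from the review comment.

```python
class WorkspaceSketch:
    """Hypothetical stand-in for the workspace class; holds dummy resources."""

    def __init__(self, use_torch_symm_mem=False):
        self._uses_torch_symm_mem = use_torch_symm_mem
        if use_torch_symm_mem:
            self.tensor, self.handle, self.ptrs = object(), object(), [0]
        else:
            self.mcast_buffer_handle = object()

    def destroy(self):
        # The fallback matches the constructor default (False), so an object
        # whose __init__ failed early is treated as the non-torch path.
        if getattr(self, "_uses_torch_symm_mem", False):
            names = ("tensor", "handle", "ptrs")
        else:
            names = ("mcast_buffer_handle",)
        for name in names:
            # Guard each deletion: the attribute may never have been set.
            if hasattr(self, name):
                delattr(self, name)


# Simulate an instance whose __init__ never ran to completion.
partial = WorkspaceSketch.__new__(WorkspaceSketch)
partial.destroy()  # does not raise

full = WorkspaceSketch(use_torch_symm_mem=True)
full.destroy()
full.destroy()  # a second call is also safe
```

Because every deletion is guarded, `destroy()` is idempotent in this sketch, which also makes it safe to call from a finalizer.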

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 3c380fd9-0eed-4d18-bde9-2e4473e485c7

📥 Commits

Reviewing files that changed from the base of the PR and between 6968afa and ae95a0e.

📒 Files selected for processing (3)
  • flashinfer/comm/allreduce.py
  • flashinfer/comm/trtllm_ar.py
  • flashinfer/comm/trtllm_mnnvl_ar.py

Comment thread flashinfer/comm/trtllm_mnnvl_ar.py
dtype: torch.dtype = torch.float16,
comm_backend: Optional[CommBackend] = None,
group: Optional[ProcessGroup] = None,
use_torch_symm_mem: bool = False,

@aleozlx is this considered a backward-compatibility break, or is it fine?
