
Refactor: fix symmetric memory pool isolation per communication group#20153

Open
wangfakang wants to merge 5 commits into sgl-project:main from wangfakang:refactor_symm_pool

Conversation

@wangfakang
Contributor

@wangfakang wangfakang commented Mar 9, 2026

CC @nvcastet @yizhang2077 @merrymercy @ShangmingCai @Fridge003 @ch-wan PTAL, thx.

Motivation

When multiple communication groups share a single global MemPool, memory blocks released by one group's comm may be reused by another group's comm. However, symmetric memory requires buffers to be registered with a specific ncclComm via ncclCommWindowRegister. Reusing memory across groups causes the registration to be associated with the wrong communicator.

I refactored SymmPool to replace the global MemPool with a per-group dictionary. Now each communication group has its own MemPool, ensuring proper memory registration and preventing cross-group allocation issues in multi-comm scenarios.
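The per-group isolation can be sketched roughly as follows. This is a minimal sketch only, not the PR's actual implementation: `SymmPoolRegistry`, `pool_factory`, and the group-name keys are hypothetical names, and the factory stands in for whatever creates a `torch.cuda.MemPool` backed by the NCCL pluggable allocator.

```python
class SymmPoolRegistry:
    """Keep one memory pool per communication group, so buffers
    allocated under a group's context are only ever reused by that
    same group's communicator (hypothetical sketch)."""

    def __init__(self, pool_factory):
        self._pool_factory = pool_factory
        self._pools = {}  # group_name -> pool

    def get_pool(self, group_name):
        # Lazily create an isolated pool for each group; repeated
        # lookups for the same group return the same pool object.
        if group_name not in self._pools:
            self._pools[group_name] = self._pool_factory()
        return self._pools[group_name]
```

With isolation like this, memory released under one group's context can only be recycled by allocations under the same group, so a buffer never ends up registered with the wrong communicator.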

Related PR: #19329 (comment)

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

Signed-off-by: wangfakang <fakangwang@gmail.com>
@wangfakang wangfakang force-pushed the refactor_symm_pool branch from 143aaf6 to 850f663 on March 9, 2026 05:59
Signed-off-by: wangfakang <fakangwang@gmail.com>
@wangfakang
Contributor Author

/rerun-failed-ci

3 similar comments
@wangfakang
Contributor Author

/rerun-failed-ci

@wangfakang
Contributor Author

/rerun-failed-ci

@yizhang2077
Collaborator

/rerun-failed-ci

@nvcastet
Collaborator

There is no reason why the symmetric memory cannot be registered by multiple groups. Creating extra pools will increase the memory footprint and also create more fragmentation.

@wangfakang
Contributor Author

There is no reason why the symmetric memory cannot be registered by multiple groups. Creating extra pools will increase the memory footprint and also create more fragmentation.

@nvcastet @yizhang2077 NCCL's symmetric memory feature requires that registered memory stay consistent with the communicator/group it is bound to; otherwise symmetric memory will not be enabled.

So as mentioned in the modified motivation, when multiple communication groups share a single global MemPool, memory blocks released by one group's comm may be reused by another group's comm. However, symmetric memory requires buffers to be registered with a specific ncclComm via ncclCommWindowRegister. Reusing memory across groups causes the registration to be associated with the wrong communicator.

This fix avoids the CPU overhead of snapshot(), while memory usage remains basically unchanged from the previous version based on testing observations. BTW, some of the previously discussed context is here: #19329 (comment).

NCCL's window/memory lookup is scoped to the specific communicator/group used in the collective call, not global. Here are the key points from the source code:

  1. Window registration is stored in the communicator's local state: the registered window lives in comm->devrState, which is per-communicator state.

  // src/dev_runtime.cc:1097
  ncclResult_t ncclCommWindowRegister(ncclComm_t comm, void* buff, ...) {
  ……
    ncclIntruQueueEnqueue(&comm->devrState.regTaskQueue, task);
  ……
  }

  2. Window lookup searches only the current communicator's state.

  // src/dev_runtime.cc:1132-1144
  ncclResult_t ncclDevrFindWindow(struct ncclComm* comm, void const* userPtr, ...) {
  ……
    struct ncclDevrState* devr = &comm->devrState;  // uses the comm passed in
    // searches devr->winSorted
  ……
  }

  3. Collective operations use the communicator passed at call time.

  // src/enqueue.cc:2914-2915
  ncclDevrFindWindow(comm, info->sendbuff, &sendWin);
  ncclDevrFindWindow(comm, info->recvbuff, &recvWin);

@nvcastet
Collaborator

There is no reason why the symmetric memory cannot be registered by multiple groups. Creating extra pools will increase the memory footprint and also create more fragmentation.

@nvcastet @yizhang2077 NCCL's symmetric memory feature requires that registered memory stay consistent with the communicator/group it is bound to; otherwise symmetric memory will not be enabled.

ncclCommWindowRegister can be called multiple times on the same memory with multiple groups (i.e., you can register all the memory segments of one pool, over time, with all the communicators that use symmetric memory).

Collaborator

@nvcastet nvcastet left a comment


Need redesign.

@wangfakang
Contributor Author

There is no reason why the symmetric memory cannot be registered by multiple groups. Creating extra pools will increase the memory footprint and also create more fragmentation.

@nvcastet @yizhang2077 NCCL's symmetric memory feature requires that registered memory stay consistent with the communicator/group it is bound to; otherwise symmetric memory will not be enabled.

ncclCommWindowRegister can be called multiple times on the same memory with multiple groups (i.e., you can register all the memory segments of one pool, over time, with all the communicators that use symmetric memory).

@nvcastet I agree that ncclCommWindowRegister supports registering the same memory with multiple groups. What I mean is that if we address the symm pool comm/group inconsistency by registering or re-registering at the exit of the context manager, there will be CPU overhead from snapshot().
Therefore, this PR is a compromise that completely addresses the symm pool comm/group inconsistency without the CPU overhead of the PyTorch memory snapshot API.

@nvcastet
Collaborator

nvcastet commented Mar 13, 2026

There is no reason why the symmetric memory cannot be registered by multiple groups. Creating extra pools will increase the memory footprint and also create more fragmentation.

@nvcastet @yizhang2077 NCCL's symmetric memory feature requires that registered memory stay consistent with the communicator/group it is bound to; otherwise symmetric memory will not be enabled.

ncclCommWindowRegister can be called multiple times on the same memory with multiple groups (i.e., you can register all the memory segments of one pool, over time, with all the communicators that use symmetric memory).

@nvcastet I agree that ncclCommWindowRegister supports registering the same memory with multiple groups. What I mean is that if we address the symm pool comm/group inconsistency by registering or re-registering at the exit of the context manager, there will be CPU overhead from snapshot(). Therefore, this PR is a compromise that completely addresses the symm pool comm/group inconsistency without the CPU overhead of the PyTorch memory snapshot API.

I think we can keep track of the memory segment ptr ourself in C++ to be able to access them when we want to register multiple groups.

In that way the symmetric memory pool pre-allocation can be re-used for any groups.

@wangfakang
Contributor Author

wangfakang commented Mar 17, 2026

@nvcastet @merrymercy Thanks for the suggestion. I've explored the approach of tracking memory segments in C++ to enable a shared pool across groups. However, there's a fundamental challenge:
The lack of a callback on pool cache reuse means we cannot accurately track which segments are "used" in the current context. When PyTorch's MemPool reuses memory from its inactive cache, the pluggable allocator's alloc callback is not invoked; the memory is simply returned from the pool's internal cache. This means:

  1. New allocations -> nccl_alloc_plug is called -> we can track the ptr.
  2. Cache reuse -> no callback (we can't get the ptr in O(1)) -> we have no visibility into which ptr is being used.

This forces us into a dilemma:
  • Option A: Track all segments and register them all at context exit. This requires iterating through all tracked segments in C++, which introduces CPU overhead equivalent to the Python snapshot() approach.
  • Option B: Accept that some segments may be missed on cache reuse. This breaks correctness.

Since there's no way to get a callback when MemPool returns cached memory, the C++ tracking approach doesn't fundamentally solve the overhead problem; it just moves the iteration from Python (snapshot()) to C++ (iterating g_memory_segments). The CPU overhead remains proportional to the total number of tracked segments, regardless of how many were actually used in the current context.

Given this limitation, the per-group MemPool approach in this PR remains a practical solution that guarantees correctness without runtime overhead, at the cost of some memory fragmentation (which is manageable in practice).

The current implementation in PyTorch has a similar CPU-overhead issue:

https://github.com/pytorch/pytorch/blob/d62bd2befde3872e4c6132dec05934ec056065b8/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L1131-L1154

@nvcastet
Collaborator

Option A: Track all segments and register them all at context exit. This requires iterating through all tracked segments in C++, which introduces CPU overhead equivalent to the Python snapshot() approach.

@wangfakang I disagree. There is only a small number of segments in the symmetric pools; the CPU overhead comes from the Python/PyTorch implementation.

Track all segments and register them all at context exit.

Only if a new segment is added or a new group is seen.
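That refinement could look something like the following sketch. All names here are hypothetical: `SegmentTracker`, `on_alloc`, and `register_fn` stand in for the real alloc hook and for whatever binding wraps ncclCommWindowRegister; this is an illustration of the bookkeeping, not the PR's code.

```python
class SegmentTracker:
    """Track allocated segment pointers and register each
    (segment, communicator) pair at most once (Option A refinement
    sketch with hypothetical names)."""

    def __init__(self, register_fn):
        self._register_fn = register_fn  # e.g. wraps ncclCommWindowRegister
        self._segments = []              # (ptr, size) seen via the alloc hook
        self._registered = set()         # (ptr, comm_id) pairs already done

    def on_alloc(self, ptr, size):
        # Called from the pluggable allocator's alloc callback.
        # Note: cache reuse never reaches this hook.
        self._segments.append((ptr, size))

    def register_all(self, comm_id):
        # At context exit: only pairs not seen before incur a
        # registration call, so steady state does no work.
        for ptr, size in self._segments:
            key = (ptr, comm_id)
            if key not in self._registered:
                self._register_fn(comm_id, ptr, size)
                self._registered.add(key)
```

With this shape, re-entering the context for an already-seen group issues zero registration calls; cost is paid only when a new segment or a new communicator appears.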

@nvcastet
Collaborator

@wangfakang Please benchmark the approach to evaluate option A.

@wangfakang
Contributor Author

@wangfakang Please benchmark the approach to evaluate option A.

Hello @nvcastet, I have refactored a commit based on Option A and completed benchmark testing: the CPU cost of _get_tracked_segments() is about 25x lower than that of snapshot() (5.351 µs vs 134.320 µs).

#NCCL_DEBUG=WARN python benchmark/bench_pynccl_allocator/bench_segment_tracking.py --num-segments 50 --num-iters 1000
================================================================================
Benchmark: Segment Tracking CPU Overhead
================================================================================
Segment size: 1.00 MB
Iterations per measurement: 1000

Segments     _get_tracked_segments (µs)     snapshot (µs)        Speedup   
--------------------------------------------------------------------------------
25           5.351                          134.320              25.10     x
--------------------------------------------------------------------------------
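A micro-benchmark of this shape reduces to a generic timing helper. The sketch below shows the measurement pattern behind a table like the one above; `avg_cost_us` is a hypothetical name, not the actual API of bench_segment_tracking.py, and the two callables compared would be the tracked-segments lookup and the snapshot call.

```python
import time

def avg_cost_us(fn, iters=1000):
    """Average wall-clock cost of fn() in microseconds over `iters`
    calls (minimal benchmark harness sketch)."""
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters * 1e6

# Usage pattern, with hypothetical callables:
# speedup = avg_cost_us(snapshot) / avg_cost_us(get_tracked_segments)
```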

However, the current Option A occasionally reports an error when CUDA Graph mode is enabled: when register_comm_window_raw internally calls cuMemMap, it fails with dev_runtime.cc:226 NCCL WARN Cuda failure 801 'operation not supported'. Would you help take a look at the commit? Thx.

@wangfakang
Contributor Author

Friendly ping @nvcastet cc @merrymercy

@wangfakang
Contributor Author

wangfakang commented Mar 24, 2026

I reproduced an issue where Option A (different comms/groups sharing one memory pool) occasionally causes NCCL's ncclCommWindowRegister to report dev_runtime.cc:226 NCCL WARN Cuda failure 801 'operation not supported' when it internally calls cuMemMap in certain scenarios. I added debugging information in both NCCL and SGLang but didn't find any apparent address conflict. However, since cuMemMap is closed-source, it's impossible to determine what specific limitation triggers this error. @nvcastet @xiaofanl-nvidia @sjeaugey @AddyLaddy could you help take a look?

Under Option A, executing the following code triggers the cuMemMap error at the Section2 location. If the code at Section1 is commented out, the error does not occur.

Additionally, when using the approach in this PR (isolated memory pools for different comm/groups) to run the same code, this error does not occur.

Code to reproduce the issue:

"""
Test script for 8-GPU all_gather with symmetric memory using sglang pynccl API.

Usage:
    SGLANG_ENABLE_SYMM_MEM=1 torchrun --nproc_per_node=8 test_symm_allgather.py

Requirements:
    - SGLANG_ENABLE_SYMM_MEM=1 environment variable must be set
"""

import os

# Enable symmetric memory before importing sglang
os.environ["SGLANG_ENABLE_SYMM_MEM"] = "1"
os.environ["NCCL_CUMEM_ENABLE"] = "1"
os.environ["NCCL_NVLS_ENABLE"] = "1"

import torch
import torch.distributed as dist

from sglang.srt.distributed.parallel_state import (
    GroupCoordinator,
    init_distributed_environment,
    destroy_model_parallel,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    is_symmetric_memory_enabled,
)

def main():
    # Initialize distributed environment
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    print(f"[Rank {rank}] Initializing distributed environment...")

    os.environ["SGLANG_ENABLE_SYMM_MEM"] = "1"
    os.environ["NCCL_CUMEM_ENABLE"] = "1"
    os.environ["NCCL_NVLS_ENABLE"] = "1"

    # Initialize torch distributed
    init_distributed_environment(
        world_size=world_size,
        rank=rank,
        local_rank=local_rank,
        backend="nccl",
    )

    # Create GroupCoordinator for all ranks
    group_coordinator = GroupCoordinator(
        group_ranks=[list(range(world_size))],
        local_rank=local_rank,
        torch_distributed_backend="nccl",
        use_pynccl=True,
        use_pymscclpp=False,
        use_custom_allreduce=False,
        use_torch_symm_mem_all_reduce=False,
        use_hpu_communicator=False,
        use_xpu_communicator=False,
        use_npu_communicator=False,
        group_name="test_group1",
    )

    device = group_coordinator.device
    torch.cuda.set_device(device)

    print(f"[Rank {rank}] Device: {device}")

    # Test parameters
    input_size = 1024
    input_size1 = 1024 * 1024
    dtype = torch.bfloat16

    # Static tensors for both operations
    with group_coordinator.use_symmetric_memory(group_coordinator, disabled=False):
        # First all_gather tensors (small size)
        static_input1 = torch.ones(input_size, dtype=dtype, device=device) * (rank + 1)
        static_output1 = torch.empty(input_size * world_size, dtype=dtype, device=device)
        # Second all_gather tensors (larger size)
        static_input2 = torch.ones(input_size1, dtype=dtype, device=device) * (rank + 1)
        static_output2 = torch.empty(input_size1 * world_size, dtype=dtype, device=device)

    # Warmup runs for both operations (important for NCCL)
    print(f"[Rank {rank}] Warming up both all_gather operations...")
    with group_coordinator.use_symmetric_memory(group_coordinator, disabled=False):
        warmup_output1 = torch.empty(input_size * world_size, dtype=dtype, device=device)
        warmup_output2 = torch.empty(input_size1 * world_size, dtype=dtype, device=device)

    # Warmup execution
    group_coordinator.all_gather_into_tensor(warmup_output1, static_input1)
    group_coordinator.all_gather_into_tensor(warmup_output2, static_input2)
    torch.cuda.synchronize()

    # First all_gather operation
    group_coordinator.all_gather_into_tensor(static_output1, static_input1)
    # Second all_gather operation
    group_coordinator.all_gather_into_tensor(static_output2, static_input2)
    torch.cuda.synchronize()

    # Section1: start key point code block for reproduce issue
    with group_coordinator.use_symmetric_memory(group_coordinator, disabled=False):
        exec_output1 = torch.empty(input_size * world_size, dtype=dtype, device=device)
        exec_output2 = torch.empty(input_size1 * world_size, dtype=dtype, device=device)
    # Copy results to execution tensors
    exec_output1.copy_(static_output1)
    exec_output2.copy_(static_output2)
    torch.cuda.synchronize()
    # Section1: end key point code block for reproduce issue

    # Second GroupCoordinator
    group_coordinator2 = GroupCoordinator(
        group_ranks=[list(range(world_size))],
        local_rank=local_rank,
        torch_distributed_backend="nccl",
        use_pynccl=True,
        use_pymscclpp=False,
        use_custom_allreduce=False,
        use_torch_symm_mem_all_reduce=False,
        use_hpu_communicator=False,
        use_xpu_communicator=False,
        use_npu_communicator=False,
        group_name="test_group2",
    )
    
    # Section2: this reports `dev_runtime.cc:229 NCCL WARN Cuda failure 801 'operation not supported'`
    with group_coordinator2.use_symmetric_memory(group_coordinator2, disabled=False):
        static_input_coord2 = torch.ones(input_size, dtype=dtype, device=device) * (rank + 1)
        static_output_coord2 = torch.empty(input_size * world_size, dtype=dtype, device=device)

    torch.cuda.synchronize()

if __name__ == "__main__":
    main()

Debug NCCL patch:

diff --git a/src/dev_runtime.cc b/src/dev_runtime.cc
index bb13314..4036620 100644
--- a/src/dev_runtime.cc
+++ b/src/dev_runtime.cc
@@ -223,6 +223,9 @@ static ncclResult_t symMemoryMapLsaTeam(
       }
     }
     CUdeviceptr addr = reinterpret_cast<uintptr_t>((char*)devr->lsaFlatBase + r*devr->bigSize + bigOffset);
+    if(comm->rank == 0) {
+        INFO(NCCL_TUNING, "comm %#p addr %#p handle: %#p dstrank: %d lsabase: %#p size: %ld offset: %ld", comm, addr, impHandle, r, devr->lsaFlatBase, size, bigOffset);
+    }
     CUCHECKGOTO(cuMemMap(addr, size, 0, impHandle, 0), ret, fail);
     CUCHECKGOTO(cuMemSetAccess(addr, size, &accessDesc, 1), ret, fail);
     if (r != devr->lsaSelf) {

Some debug info (comm 0xe4e9900 is test_group1, comm 0x21e210c0 is test_group2):

aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x10020000000 handle: 0x21dd1c30 dstrank: 0 lsabase: 0x10020000000 size: 10485760 offset: 0
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x12320000000 handle: 0x21dd3d90 dstrank: 1 lsabase: 0x10020000000 size: 10485760 offset: 0
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x14620000000 handle: 0x21dd55a0 dstrank: 2 lsabase: 0x10020000000 size: 10485760 offset: 0
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x16920000000 handle: 0x21dd6d80 dstrank: 3 lsabase: 0x10020000000 size: 10485760 offset: 0
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x18c20000000 handle: 0x21dd8560 dstrank: 4 lsabase: 0x10020000000 size: 10485760 offset: 0
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x1af20000000 handle: 0x21dd9d40 dstrank: 5 lsabase: 0x10020000000 size: 10485760 offset: 0
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x1d220000000 handle: 0x21ddb520 dstrank: 6 lsabase: 0x10020000000 size: 10485760 offset: 0
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x1f520000000 handle: 0x21ddcd00 dstrank: 7 lsabase: 0x10020000000 size: 10485760 offset: 0
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x10020a00000 handle: 0x21096c70 dstrank: 0 lsabase: 0x10020000000 size: 20971520 offset: 10485760
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x12320a00000 handle: 0x21ddfee0 dstrank: 1 lsabase: 0x10020000000 size: 20971520 offset: 10485760
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x14620a00000 handle: 0x21de1490 dstrank: 2 lsabase: 0x10020000000 size: 20971520 offset: 10485760
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x16920a00000 handle: 0x21de2a90 dstrank: 3 lsabase: 0x10020000000 size: 20971520 offset: 10485760
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x18c20a00000 handle: 0x21de4090 dstrank: 4 lsabase: 0x10020000000 size: 20971520 offset: 10485760
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x1af20a00000 handle: 0x21de5690 dstrank: 5 lsabase: 0x10020000000 size: 20971520 offset: 10485760
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x1d220a00000 handle: 0x21de6c90 dstrank: 6 lsabase: 0x10020000000 size: 20971520 offset: 10485760
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x1f520a00000 handle: 0x21de8290 dstrank: 7 lsabase: 0x10020000000 size: 20971520 offset: 10485760
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x10021e00000 handle: 0x2011b660 dstrank: 0 lsabase: 0x10020000000 size: 2097152 offset: 31457280
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x12321e00000 handle: 0x21deb4f0 dstrank: 1 lsabase: 0x10020000000 size: 2097152 offset: 31457280
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x14621e00000 handle: 0x21decaa0 dstrank: 2 lsabase: 0x10020000000 size: 2097152 offset: 31457280
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x16921e00000 handle: 0x21dee0a0 dstrank: 3 lsabase: 0x10020000000 size: 2097152 offset: 31457280
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x18c21e00000 handle: 0x21def6a0 dstrank: 4 lsabase: 0x10020000000 size: 2097152 offset: 31457280
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x1af21e00000 handle: 0x21df0ca0 dstrank: 5 lsabase: 0x10020000000 size: 2097152 offset: 31457280
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x1d221e00000 handle: 0x21df22a0 dstrank: 6 lsabase: 0x10020000000 size: 2097152 offset: 31457280
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x1f521e00000 handle: 0x21df38a0 dstrank: 7 lsabase: 0x10020000000 size: 2097152 offset: 31457280
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x10022000000 handle: 0x21df5620 dstrank: 0 lsabase: 0x10020000000 size: 16777216 offset: 33554432
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x12322000000 handle: 0x21dfd900 dstrank: 1 lsabase: 0x10020000000 size: 16777216 offset: 33554432
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x14622000000 handle: 0x21dfeec0 dstrank: 2 lsabase: 0x10020000000 size: 16777216 offset: 33554432
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x16922000000 handle: 0x21e00480 dstrank: 3 lsabase: 0x10020000000 size: 16777216 offset: 33554432
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x18c22000000 handle: 0x21e01a40 dstrank: 4 lsabase: 0x10020000000 size: 16777216 offset: 33554432
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x1af22000000 handle: 0x21e03040 dstrank: 5 lsabase: 0x10020000000 size: 16777216 offset: 33554432
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x1d222000000 handle: 0x21e04640 dstrank: 6 lsabase: 0x10020000000 size: 16777216 offset: 33554432
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x1f522000000 handle: 0x21e05c40 dstrank: 7 lsabase: 0x10020000000 size: 16777216 offset: 33554432
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x10023000000 handle: 0x21e08180 dstrank: 0 lsabase: 0x10020000000 size: 16777216 offset: 50331648
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x12323000000 handle: 0x21e11450 dstrank: 1 lsabase: 0x10020000000 size: 16777216 offset: 50331648
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x14623000000 handle: 0x21e129f0 dstrank: 2 lsabase: 0x10020000000 size: 16777216 offset: 50331648
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x16923000000 handle: 0x21e13fb0 dstrank: 3 lsabase: 0x10020000000 size: 16777216 offset: 50331648
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x18c23000000 handle: 0x21e15570 dstrank: 4 lsabase: 0x10020000000 size: 16777216 offset: 50331648
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x1af23000000 handle: 0x21e16b70 dstrank: 5 lsabase: 0x10020000000 size: 16777216 offset: 50331648
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x1d223000000 handle: 0x21e18170 dstrank: 6 lsabase: 0x10020000000 size: 16777216 offset: 50331648
aa714b631e92:85904:85904 [0] NCCL INFO comm 0xe4e9900 addr 0x1f523000000 handle: 0x21e19770 dstrank: 7 lsabase: 0x10020000000 size: 16777216 offset: 50331648
aa714b631e92:85904:85904 [0] NCCL INFO comm 0x21e210c0 addr 0x2c720000000 handle: 0x23744150 dstrank: 0 lsabase: 0x2c720000000 size: 10485760 offset: 0
aa714b631e92:85904:85904 [0] NCCL INFO comm 0x21e210c0 addr 0x2ea20000000 handle: 0x237460d0 dstrank: 1 lsabase: 0x2c720000000 size: 10485760 offset: 0
aa714b631e92:85904:85904 [0] NCCL INFO comm 0x21e210c0 addr 0x30d20000000 handle: 0x23843570 dstrank: 2 lsabase: 0x2c720000000 size: 10485760 offset: 0
aa714b631e92:85904:85904 [0] NCCL INFO comm 0x21e210c0 addr 0x33020000000 handle: 0x23844aa0 dstrank: 3 lsabase: 0x2c720000000 size: 10485760 offset: 0
aa714b631e92:85904:85904 [0] NCCL INFO comm 0x21e210c0 addr 0x35320000000 handle: 0x23846280 dstrank: 4 lsabase: 0x2c720000000 size: 10485760 offset: 0
aa714b631e92:85904:85904 [0] NCCL INFO comm 0x21e210c0 addr 0x37620000000 handle: 0x23847a60 dstrank: 5 lsabase: 0x2c720000000 size: 10485760 offset: 0
aa714b631e92:85904:85904 [0] NCCL INFO comm 0x21e210c0 addr 0x39920000000 handle: 0x23849240 dstrank: 6 lsabase: 0x2c720000000 size: 10485760 offset: 0
aa714b631e92:85904:85904 [0] NCCL INFO comm 0x21e210c0 addr 0x3bc20000000 handle: 0x2384aa20 dstrank: 7 lsabase: 0x2c720000000 size: 10485760 offset: 0
aa714b631e92:85904:85904 [0] NCCL INFO comm 0x21e210c0 addr 0x2c720a00000 handle: 0x21df5620 dstrank: 0 lsabase: 0x2c720000000 size: 16777216 offset: 10485760
aa714b631e92:85904:85904 [0] NCCL INFO comm 0x21e210c0 addr 0x2ea20a00000 handle: 0x2384db80 dstrank: 1 lsabase: 0x2c720000000 size: 16777216 offset: 10485760
aa714b631e92:85904:85904 [0] NCCL INFO comm 0x21e210c0 addr 0x30d20a00000 handle: 0x2384f130 dstrank: 2 lsabase: 0x2c720000000 size: 16777216 offset: 10485760
aa714b631e92:85904:85904 [0] NCCL INFO comm 0x21e210c0 addr 0x33020a00000 handle: 0x23850730 dstrank: 3 lsabase: 0x2c720000000 size: 16777216 offset: 10485760
aa714b631e92:85904:85904 [0] NCCL INFO comm 0x21e210c0 addr 0x35320a00000 handle: 0x23851d30 dstrank: 4 lsabase: 0x2c720000000 size: 16777216 offset: 10485760
aa714b631e92:85904:85904 [0] NCCL INFO comm 0x21e210c0 addr 0x37620a00000 handle: 0x23853330 dstrank: 5 lsabase: 0x2c720000000 size: 16777216 offset: 10485760
aa714b631e92:85904:85904 [0] NCCL INFO comm 0x21e210c0 addr 0x39920a00000 handle: 0x23854930 dstrank: 6 lsabase: 0x2c720000000 size: 16777216 offset: 10485760
aa714b631e92:85904:85904 [0] NCCL INFO comm 0x21e210c0 addr 0x3bc20a00000 handle: 0x23855f30 dstrank: 7 lsabase: 0x2c720000000 size: 16777216 offset: 10485760
aa714b631e92:85904:85904 [0] NCCL INFO comm 0x21e210c0 addr 0x2c721a00000 handle: 0x21096c70 dstrank: 0 lsabase: 0x2c720000000 size: 20971520 offset: 27262976
aa714b631e92:85904:85904 [0] NCCL INFO comm 0x21e210c0 addr 0x2ea21a00000 handle: 0x23859100 dstrank: 1 lsabase: 0x2c720000000 size: 20971520 offset: 27262976  <- triggers the cuMemMap 801 error here

NCCL warn log:

[2026-03-24 13:57:07] aa714b631e92:85906:85906 [2] dev_runtime.cc:229 NCCL WARN Cuda failure 801 'operation not supported'
[2026-03-24 13:57:07] aa714b631e92:85904:85904 [0] dev_runtime.cc:229 NCCL WARN Cuda failure 801 'operation not supported'
[2026-03-24 13:57:07] aa714b631e92:85905:85905 [1] dev_runtime.cc:229 NCCL WARN Cuda failure 801 'operation not supported'
[2026-03-24 13:57:07] aa714b631e92:85910:85910 [6] dev_runtime.cc:229 NCCL WARN Cuda failure 801 'operation not supported'
[2026-03-24 13:57:07] aa714b631e92:85907:85907 [3] dev_runtime.cc:229 NCCL WARN Cuda failure 801 'operation not supported'
[2026-03-24 13:57:07] aa714b631e92:85911:85911 [7] dev_runtime.cc:229 NCCL WARN Cuda failure 801 'operation not supported'
[2026-03-24 13:57:07] aa714b631e92:85908:85908 [4] dev_runtime.cc:229 NCCL WARN Cuda failure 801 'operation not supported'
[2026-03-24 13:57:07] aa714b631e92:85909:85909 [5] dev_runtime.cc:229 NCCL WARN Cuda failure 801 'operation not supported'

@wangfakang
Contributor Author

@nvcastet @xiaofanl-nvidia @sjeaugey @AddyLaddy I tried registering the address list with comm2/group2 in the same order it was registered with comm1/group1, and the error no longer appears. But I haven't seen any documentation describing cuMemMap's usage limits, so I'm not sure this issue is completely fixed. Could you help confirm whether cuMemMap on the CUDA 12.8 driver has such a limitation?

@wangfakang
Contributor Author

@nvcastet @xiaofanl-nvidia @sjeaugey @AddyLaddy I tried registering the address list with comm2/group2 in the same order it was registered with comm1/group1, and the error no longer appears. But I haven't seen any documentation describing cuMemMap's usage limits, so I'm not sure this issue is completely fixed. Could you help confirm whether cuMemMap on the CUDA 12.8 driver has such a limitation?

Hello @nvcastet, based on our discussion, I've refactored the code and opened a new PR. PTAL, thx.

@nvcastet
Collaborator

nvcastet commented Mar 30, 2026

Thanks @wangfakang !
I apologize for the delayed answer, I was on vacation last week.
1- Did you resolve the issue mentioned at #20153 (comment) ?
2- I can have a look at #21392 this week

@wangfakang
Contributor Author

wangfakang commented Apr 2, 2026

Thanks @wangfakang ! I apologize for the delayed answer, I was on vacation last week. 1- Did you resolve the issue mentioned at #20153 (comment) ? 2- I can have a look at #21392 this week

@nvcastet Thank you, the issue is resolved. I registered the address list with comm2/group2 in the same order it was registered with comm1/group1, and cuMemMap no longer reports errors. I haven't found any documentation describing cuMemMap's usage limits.

@nvcastet
Collaborator

nvcastet commented Apr 6, 2026

@nvcastet Thank you, the issue is resolved. I registered the address list with comm2/group2 in the same order it was registered with comm1/group1, and cuMemMap no longer reports errors. I haven't found any documentation describing cuMemMap's usage limits.

That makes sense: window registration is a collective op, so you will deadlock if the calls don't happen in the same order on all the GPUs.
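The ordering constraint can be illustrated with a small sketch (hypothetical names: `register_fn` stands in for whatever binding wraps the collective ncclCommWindowRegister call). Since every rank records the same logical allocation sequence under SPMD execution, replaying one shared, ordered list per communicator keeps the collective calls aligned across ranks.

```python
def replay_registrations(comm, recorded_segments, register_fn):
    """Register every recorded segment with `comm`, preserving the
    original recording order. Because registration is a collective
    operation, all ranks must issue these calls in the same sequence;
    replaying one ordered list per communicator guarantees that
    (illustrative sketch, hypothetical names)."""
    for ptr, size in recorded_segments:
        register_fn(comm, ptr, size)
```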

