Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@
"""

_allocator = None
_mem_pool = None
# MemPool per group, keyed by group_name
_mem_pools = {}
_graph_pool_id = None
_cur_device = None
_active_symmetric_memory_context = None
Expand Down Expand Up @@ -99,9 +100,10 @@ def restore_symmetric_memory_context(saved_context):
saved_context.__enter__()


def get_nccl_mem_pool():
global _allocator, _mem_pool, _cur_device
if _mem_pool is None:
def _init_allocator():
"""Initialize the shared allocator. Only called once."""
global _allocator, _cur_device
if _allocator is None:
import torch.utils.cpp_extension

out_dir = os.path.join(tempfile.gettempdir(), "symm_allocator")
Expand Down Expand Up @@ -130,9 +132,25 @@ def get_nccl_mem_pool():
"nccl_alloc_plug",
"nccl_free_plug",
).allocator()
_mem_pool = torch.cuda.MemPool(_allocator)
_cur_device = torch.cuda.current_device()
return _mem_pool


def get_nccl_mem_pool(group_name: str) -> torch.cuda.MemPool:
    """
    Get or create the ``torch.cuda.MemPool`` dedicated to *group_name*.

    Each group gets its own MemPool to ensure memory isolation:
    memory allocated for one group's communicator is guaranteed to be
    registered with that same communicator when it is reused from the
    pool.

    Args:
        group_name: Unique name of the process group (used as the
            cache key into the module-level ``_mem_pools`` dict).

    Returns:
        The (possibly newly created) MemPool for this group.
    """
    # Ensure the shared pluggable allocator exists before any pool is
    # built on top of it (idempotent — only initializes once).
    _init_allocator()

    # NOTE: no `global` statement is needed here — we only mutate the
    # existing dict in place, we never rebind the module-level name.
    if group_name not in _mem_pools:
        _mem_pools[group_name] = torch.cuda.MemPool(_allocator)

    return _mem_pools[group_name]


class SymmetricMemoryContext:
Expand All @@ -150,14 +168,15 @@ def __init__(
group_coordinator: GroupCoordinator,
):
self.group_coordinator = group_coordinator
self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
self.group_name = group_coordinator.unique_name
self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool(self.group_name))
self.is_graph_capture = torch.cuda.is_current_stream_capturing()
self.exited = False

def __enter__(self):
assert (
self.group_coordinator.pynccl_comm is not None
), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'"
), f"Symmetric memory requires pynccl to be enabled in group '{self.group_name}'"

if self.is_graph_capture:
assert (
Expand All @@ -173,7 +192,9 @@ def __enter__(self):

if self.exited:
# mempool ctx (@contextlib.contextmanager) is not re-entrant
self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
self._mem_pool_ctx = torch.cuda.use_mem_pool(
get_nccl_mem_pool(self.group_name)
)
self.exited = False
self._mem_pool_ctx.__enter__()

Expand Down
Loading