diff --git a/flashinfer/comm/__init__.py b/flashinfer/comm/__init__.py index 31d23a99f3..689629e25d 100644 --- a/flashinfer/comm/__init__.py +++ b/flashinfer/comm/__init__.py @@ -68,7 +68,7 @@ # DCP A2A (Decode Context Parallel Attention Reduction) from .dcp_alltoall import decode_cp_a2a_alltoall as decode_cp_a2a_alltoall from .dcp_alltoall import ( - decode_cp_a2a_allocate_workspace as decode_cp_a2a_allocate_workspace, + decode_cp_a2a_allocate_mnnvl_workspace as decode_cp_a2a_allocate_mnnvl_workspace, ) from .dcp_alltoall import decode_cp_a2a_init_workspace as decode_cp_a2a_init_workspace from .dcp_alltoall import decode_cp_a2a_workspace_size as decode_cp_a2a_workspace_size diff --git a/flashinfer/comm/dcp_alltoall.py b/flashinfer/comm/dcp_alltoall.py index 3047f76ce9..a53c289786 100644 --- a/flashinfer/comm/dcp_alltoall.py +++ b/flashinfer/comm/dcp_alltoall.py @@ -4,13 +4,18 @@ Provides the DCP LL128 FIFO-based all-to-all kernel for context-parallel attention reduction. Uses SM90+ features (TMA, mbarrier). +The kernel addresses peer FIFOs via ``params.workspace + peer_rank * stride``, +so it requires a single unified VA spanning all CP ranks — i.e. MNNVL +fabric memory (currently provided by GB200-NVL72 systems). Non-MNNVL +allocations cannot satisfy this layout and would deadlock at runtime. + Usage protocol:: # 1. Query workspace size ws_bytes = decode_cp_a2a_workspace_size(cp_size) - # 2. Allocate workspace (MNNVL or plain device memory) - workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank, mapping=mapping) + # 2. Allocate MNNVL-backed workspace (Mapping is required and carries cp_size/cp_rank) + workspace = decode_cp_a2a_allocate_mnnvl_workspace(mapping) # 3. Initialize workspace (synchronous — includes stream sync) decode_cp_a2a_init_workspace(workspace, cp_rank, cp_size) @@ -26,7 +31,7 @@ .. important:: All ranks MUST complete ``decode_cp_a2a_init_workspace`` and execute a cross-rank barrier before ANY rank calls ``decode_cp_a2a_alltoall``. - Failure to do so causes a deadlock on MNNVL workspaces. + Failure to do so causes a deadlock. Tensor specifications: @@ -35,13 +40,13 @@ - ``softmax_stats``: ``[..., cp_size, S]`` — float32, ``S >= 2`` and even. Batch dims must match ``partial_o``. - ``workspace``: ``[cp_size, ws_elems_per_rank]`` — int64, from - :func:`decode_cp_a2a_allocate_workspace`. + :func:`decode_cp_a2a_allocate_mnnvl_workspace`. """ import functools import logging from types import SimpleNamespace -from typing import Optional +from typing import Dict, Optional import torch @@ -100,6 +105,15 @@ def decode_cp_a2a_alltoall( # ─── Public API ─────────────────────────────────────────────────────────── +# Module-level keep-alive for MNNVL workspace handles. The kernel uses raw +# pointers from the strided tensor, but the underlying fabric memory is owned +# by the MnnvlMemory wrapper — when its refcount hits zero, ``__del__`` calls +# ``close_mnnvl_memory`` and unmaps the VA. Without a stable reference outside +# the returned tensor, any caller-side ``view`` / ``slice`` / ``clone`` that +# drops the original tensor would silently free the workspace under the kernel. +_workspace_keepalive: Dict[int, MnnvlMemory] = {} + + @flashinfer_api def decode_cp_a2a_workspace_size(cp_size: int) -> int: """Return the workspace size **in bytes** per rank for the given CP group size. @@ -119,33 +133,24 @@ def decode_cp_a2a_workspace_size(cp_size: int) -> int: @flashinfer_api -def decode_cp_a2a_allocate_workspace( - cp_size: int, - cp_rank: int, +def decode_cp_a2a_allocate_mnnvl_workspace( + mapping: Mapping, *, - mapping: Optional[Mapping] = None, mnnvl_config: Optional[MnnvlConfig] = None, ) -> torch.Tensor: - """Allocate a workspace tensor of shape ``[cp_size, ws_elems_per_rank]``. + """Allocate an MNNVL-backed workspace of shape ``[cp_size, ws_elems_per_rank]``. + + The DCP A2A kernel requires a single unified VA spanning all CP ranks + (see module docstring), so workspace allocation must go through MNNVL + fabric memory. This function is the only supported allocator. After allocation, call :func:`decode_cp_a2a_init_workspace` followed by a cross-rank barrier before the first :func:`decode_cp_a2a_alltoall` call. - Two allocation modes: - - - **MNNVL** (``mapping`` provided): Cross-rank visible GPU memory via - FlashInfer's ``MnnvlMemory``. Required for multi-node or when ranks - cannot see each other's device memory directly. - - **Plain device memory** (``mapping=None``): Standard ``torch.zeros`` - allocation. Sufficient for single-node with NVLink P2P. - Args: - cp_size: Context-parallel group size. - cp_rank: This rank's position in the CP group. - mapping: Mapping object for MNNVL allocation. If provided, MNNVL is - used. The mapping must have ``cp_size`` set correctly. The - communicator is split using ``mapping.pp_rank``, ``mapping.cp_rank``, - and ``mapping.tp_rank``. + mapping: Mapping object for MNNVL allocation. Carries ``cp_size`` and + ``cp_rank``. The communicator is split using ``mapping.pp_rank``, + ``mapping.cp_rank``, and ``mapping.tp_rank``. mnnvl_config: Configuration for the MNNVL communication backend. Required when using MNNVL with ``torch.distributed`` (pass ``MnnvlConfig(comm_backend=TorchDistBackend(group))``). @@ -153,26 +158,22 @@ def decode_cp_a2a_allocate_workspace( Returns: ``torch.int64`` tensor of shape ``[cp_size, ws_elems_per_rank]``. """ - ws_bytes = decode_cp_a2a_workspace_size(cp_size) - - if mapping is not None: - MnnvlMemory.initialize() - if mnnvl_config: - MnnvlMemory.set_comm_from_config(mapping, mnnvl_config) - - mnnvl_mem = MnnvlMemory(mapping, ws_bytes) - workspace = mnnvl_mem.as_torch_strided_tensor(torch.int64) - workspace._mnnvl_mem = mnnvl_mem # prevent GC of MNNVL handle - logger.info( - "Rank %d: DCP MNNVL workspace allocated — shape=%s, stride=%s", - cp_rank, - list(workspace.shape), - list(workspace.stride()), - ) - return workspace - - ws_elems_per_rank = (ws_bytes + 7) // 8 - return torch.zeros(cp_size, ws_elems_per_rank, dtype=torch.int64, device="cuda") + ws_bytes = decode_cp_a2a_workspace_size(mapping.cp_size) + + MnnvlMemory.initialize() + if mnnvl_config: + MnnvlMemory.set_comm_from_config(mapping, mnnvl_config) + + mnnvl_mem = MnnvlMemory(mapping, ws_bytes) + workspace = mnnvl_mem.as_torch_strided_tensor(torch.int64) + _workspace_keepalive[workspace.data_ptr()] = mnnvl_mem + logger.info( + "Rank %d: DCP MNNVL workspace allocated — shape=%s, stride=%s", + mapping.cp_rank, + list(workspace.shape), + list(workspace.stride()), + ) + return workspace @flashinfer_api @@ -197,7 +198,7 @@ def decode_cp_a2a_init_workspace( Args: workspace: ``[cp_size, ws_elems_per_rank]`` int64 tensor from - :func:`decode_cp_a2a_allocate_workspace`. + :func:`decode_cp_a2a_allocate_mnnvl_workspace`. cp_rank: This rank's position in the CP group. cp_size: Context-parallel group size. """ @@ -229,7 +230,7 @@ def decode_cp_a2a_alltoall( softmax_stats: ``[..., cp_size, S]`` — float32, ``S >= 2`` and even. Batch dimensions must match ``partial_o``. workspace: ``[cp_size, ws_elems_per_rank]`` int64 tensor from - :func:`decode_cp_a2a_allocate_workspace`, already initialized. + :func:`decode_cp_a2a_allocate_mnnvl_workspace`, already initialized. cp_rank: This rank's position in the CP group. cp_size: Context-parallel group size. enable_pdl: Enable Programmatic Dependent Launch (SM90+). @@ -249,7 +250,7 @@ def decode_cp_a2a_alltoall( __all__ = [ "decode_cp_a2a_workspace_size", - "decode_cp_a2a_allocate_workspace", + "decode_cp_a2a_allocate_mnnvl_workspace", "decode_cp_a2a_init_workspace", "decode_cp_a2a_alltoall", ] diff --git a/tests/comm/test_dcp_alltoall.py b/tests/comm/test_dcp_alltoall.py index 847e6f5d29..5663f41374 100644 --- a/tests/comm/test_dcp_alltoall.py +++ b/tests/comm/test_dcp_alltoall.py @@ -15,8 +15,11 @@ """Tests for flashinfer.comm.dcp_alltoall — DCP LL128 FIFO All-to-All. Single-GPU multi-rank pattern: simulates cp_size ranks on one GPU using -separate CUDA streams for the alltoall phase. All ranks share a single -workspace tensor of shape [cp_size, ws_elems_per_rank]. +separate CUDA streams for the alltoall phase. All simulated ranks share a +single ``torch.zeros`` workspace tensor of shape ``[cp_size, ws_elems_per_rank]``, +which lets the kernel's ``params.workspace + peer_rank * stride`` pointer +arithmetic land in the same physical allocation. Real multi-GPU runs need +an MNNVL-backed workspace — see ``test_mnnvl_dcp_alltoall.py``. Run: python -m pytest tests/comm/test_dcp_alltoall.py -v -s """ @@ -26,7 +29,6 @@ from flashinfer.comm import ( decode_cp_a2a_alltoall, - decode_cp_a2a_allocate_workspace, decode_cp_a2a_init_workspace, decode_cp_a2a_workspace_size, ) @@ -70,6 +72,19 @@ def _to_torch(t): return torch.from_dlpack(t) +def _alloc_sim_workspace(cp_size: int) -> torch.Tensor: + """Allocate a single-GPU shared workspace for multi-rank simulation. + + The public allocator (``decode_cp_a2a_allocate_mnnvl_workspace``) requires + a fabric-mapped Mapping, which is not applicable in a single-GPU test. + Here all simulated ranks share one ``torch.zeros`` tensor so the kernel's + cross-rank pointer arithmetic resolves to the same physical allocation. + """ + ws_bytes = decode_cp_a2a_workspace_size(cp_size) + ws_elems = (ws_bytes + 7) // 8 + return torch.zeros(cp_size, ws_elems, dtype=torch.int64, device="cuda") + + def _run_single_gpu_alltoall(cp_size, batch_size, head_dim, stats_dim, dtype): """Simulate cp_size ranks on one GPU and return (inputs, outputs, workspace). @@ -83,7 +98,7 @@ def _run_single_gpu_alltoall(cp_size, batch_size, head_dim, stats_dim, dtype): """ torch.cuda.set_device(0) - workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank=0) + workspace = _alloc_sim_workspace(cp_size) all_partial_o = [] all_softmax_stats = [] @@ -164,17 +179,9 @@ def test_workspace_size_monotonic(self): assert ws4 > ws2, "ws(4) should be > ws(2)" assert ws8 > ws4, "ws(8) should be > ws(4)" - def test_allocate_returns_correct_shape_and_dtype(self): - for cp_size in [2, 4]: - ws_bytes = decode_cp_a2a_workspace_size(cp_size) - workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank=0) - assert workspace.dtype == torch.int64 - assert workspace.shape[0] == cp_size - assert workspace.shape[1] == (ws_bytes + 7) // 8 - def test_init_workspace_does_not_hang(self): for cp_size in [2, 4]: - workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank=0) + workspace = _alloc_sim_workspace(cp_size) for r in range(cp_size): decode_cp_a2a_init_workspace(workspace, r, cp_size) torch.cuda.synchronize() @@ -234,7 +241,7 @@ def test_repeated_alltoall(cp_size, batch_size, head_dim, stats_dim, dtype, num_ """Multiple alltoall calls on the same workspace without re-init (FIFO reuse).""" torch.cuda.set_device(0) - workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank=0) + workspace = _alloc_sim_workspace(cp_size) for r in range(cp_size): decode_cp_a2a_init_workspace(workspace, r, cp_size) @@ -283,7 +290,7 @@ def test_cp_size_1_is_identity(self): dtype = torch.bfloat16 torch.cuda.set_device(0) - workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank=0) + workspace = _alloc_sim_workspace(cp_size) po = torch.randn(batch_size, cp_size, head_dim, dtype=dtype, device="cuda") ss = torch.randn( batch_size, cp_size, stats_dim, dtype=torch.float32, device="cuda" @@ -305,7 +312,7 @@ def test_batch_size_0(self): dtype = torch.bfloat16 torch.cuda.set_device(0) - workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank=0) + workspace = _alloc_sim_workspace(cp_size) po = torch.randn(batch_size, cp_size, head_dim, dtype=dtype, device="cuda") ss = torch.randn( batch_size, cp_size, stats_dim, dtype=torch.float32, device="cuda" @@ -332,7 +339,7 @@ class TestInputValidation: def test_wrong_dtype_float64(self): """partial_o with float64 should be rejected.""" cp_size = 2 - workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank=0) + workspace = _alloc_sim_workspace(cp_size) for r in range(cp_size): decode_cp_a2a_init_workspace(workspace, r, cp_size) torch.cuda.synchronize() @@ -346,7 +353,7 @@ def test_wrong_dtype_float64(self): def test_wrong_dtype_float32(self): """partial_o with float32 should be rejected (must be half/bfloat16).""" cp_size = 2 - workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank=0) + workspace = _alloc_sim_workspace(cp_size) for r in range(cp_size): decode_cp_a2a_init_workspace(workspace, r, cp_size) torch.cuda.synchronize() @@ -360,7 +367,7 @@ def test_wrong_dtype_float32(self): def test_stats_dim_1_odd_alignment(self): """stats_dim=1 violates 'even and >= 2' constraint — should error.""" cp_size = 2 - workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank=0) + workspace = _alloc_sim_workspace(cp_size) for r in range(cp_size): decode_cp_a2a_init_workspace(workspace, r, cp_size) torch.cuda.synchronize() @@ -374,7 +381,7 @@ def test_stats_dim_1_odd_alignment(self): def test_mismatched_batch_dims(self): """partial_o and softmax_stats with different batch sizes should error.""" cp_size = 2 - workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank=0) + workspace = _alloc_sim_workspace(cp_size) for r in range(cp_size): decode_cp_a2a_init_workspace(workspace, r, cp_size) torch.cuda.synchronize() @@ -388,7 +395,7 @@ def test_mismatched_batch_dims(self): def test_wrong_stats_dtype(self): """softmax_stats with half instead of float32 should error.""" cp_size = 2 - workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank=0) + workspace = _alloc_sim_workspace(cp_size) for r in range(cp_size): decode_cp_a2a_init_workspace(workspace, r, cp_size) torch.cuda.synchronize() diff --git a/tests/comm/test_mnnvl_dcp_alltoall.py b/tests/comm/test_mnnvl_dcp_alltoall.py index bd020b9566..921b0eab37 100644 --- a/tests/comm/test_mnnvl_dcp_alltoall.py +++ b/tests/comm/test_mnnvl_dcp_alltoall.py @@ -35,7 +35,7 @@ from flashinfer.comm import ( decode_cp_a2a_alltoall, - decode_cp_a2a_allocate_workspace, + decode_cp_a2a_allocate_mnnvl_workspace, decode_cp_a2a_init_workspace, decode_cp_a2a_workspace_size, ) @@ -120,7 +120,7 @@ def _setup_rank(): _rank, _cp_size, _comm = _setup_rank() def _allocate_mnnvl_workspace_once(): - """Allocate MNNVL workspace once at module level. + """Allocate MNNVL workspace once at module level via the public API. MnnvlMemory uses a global bump allocator that doesn't support individual frees. Allocating per-test causes segfaults when @@ -137,12 +137,7 @@ def _allocate_mnnvl_workspace_once(): tp_size=1, pp_size=1, ) - - ws_bytes = decode_cp_a2a_workspace_size(_cp_size) - mnnvl_mem = MnnvlMemory(mapping, ws_bytes) - workspace = mnnvl_mem.as_torch_strided_tensor(torch.int64) - workspace._mnnvl_mem = mnnvl_mem # prevent GC - return workspace + return decode_cp_a2a_allocate_mnnvl_workspace(mapping) _mnnvl_workspace = _allocate_mnnvl_workspace_once() else: @@ -319,32 +314,5 @@ def test_repeated_alltoall(self): _comm.Barrier() -class TestMnnvlDcpDeviceMemoryFallback: - """Test that non-MNNVL (device memory) path also works multi-GPU. - - Uses decode_cp_a2a_allocate_workspace without MNNVL mapping. This only - works when all ranks are on the same GPU (single-GPU simulation) - or with IPC. Included here to verify the workspace API contract. - """ - - @pytest.fixture(autouse=True) - def setup(self): - torch.manual_seed(0xA2A) - yield - - def test_device_workspace_shape(self): - """Device workspace has correct shape [cp_size, ws_elems].""" - try: - workspace = decode_cp_a2a_allocate_workspace(_cp_size, cp_rank=_rank) - assert workspace.shape[0] == _cp_size - - ws_bytes = decode_cp_a2a_workspace_size(_cp_size) - expected_elems = (ws_bytes + 7) // 8 - assert workspace.shape[1] == expected_elems - assert workspace.dtype == torch.int64 - finally: - _comm.Barrier() - - if __name__ == "__main__": pytest.main([__file__, "-v", "-s"])