Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flashinfer/comm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
# DCP A2A (Decode Context Parallel Attention Reduction)
from .dcp_alltoall import decode_cp_a2a_alltoall as decode_cp_a2a_alltoall
from .dcp_alltoall import (
decode_cp_a2a_allocate_workspace as decode_cp_a2a_allocate_workspace,
decode_cp_a2a_allocate_mnnvl_workspace as decode_cp_a2a_allocate_mnnvl_workspace,
)
from .dcp_alltoall import decode_cp_a2a_init_workspace as decode_cp_a2a_init_workspace
from .dcp_alltoall import decode_cp_a2a_workspace_size as decode_cp_a2a_workspace_size
Expand Down
95 changes: 48 additions & 47 deletions flashinfer/comm/dcp_alltoall.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,18 @@
Provides the DCP LL128 FIFO-based all-to-all kernel for context-parallel
attention reduction. Uses SM90+ features (TMA, mbarrier).

The kernel addresses peer FIFOs via ``params.workspace + peer_rank * stride``,
so it requires a single unified VA spanning all CP ranks — i.e. MNNVL
fabric memory (currently provided by GB200-NVL72 systems). Non-MNNVL
allocations cannot satisfy this layout and would deadlock at runtime.

Usage protocol::

# 1. Query workspace size
ws_bytes = decode_cp_a2a_workspace_size(cp_size)

# 2. Allocate workspace (MNNVL or plain device memory)
workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank, mapping=mapping)
# 2. Allocate MNNVL-backed workspace (Mapping is required and carries cp_size/cp_rank)
workspace = decode_cp_a2a_allocate_mnnvl_workspace(mapping)

# 3. Initialize workspace (synchronous — includes stream sync)
decode_cp_a2a_init_workspace(workspace, cp_rank, cp_size)
Expand All @@ -26,7 +31,7 @@
.. important::
All ranks MUST complete ``decode_cp_a2a_init_workspace`` and execute a
cross-rank barrier before ANY rank calls ``decode_cp_a2a_alltoall``.
Failure to do so causes a deadlock on MNNVL workspaces.
Failure to do so causes a deadlock.

Tensor specifications:

Expand All @@ -35,13 +40,13 @@
- ``softmax_stats``: ``[..., cp_size, S]`` — float32, ``S >= 2`` and even.
Batch dims must match ``partial_o``.
- ``workspace``: ``[cp_size, ws_elems_per_rank]`` — int64, from
:func:`decode_cp_a2a_allocate_workspace`.
:func:`decode_cp_a2a_allocate_mnnvl_workspace`.
"""

import functools
import logging
from types import SimpleNamespace
from typing import Optional
from typing import Dict, Optional

import torch

Expand Down Expand Up @@ -100,6 +105,15 @@ def decode_cp_a2a_alltoall(
# ─── Public API ───────────────────────────────────────────────────────────


# Module-level keep-alive for MNNVL workspace handles. The kernel uses raw
# pointers from the strided tensor, but the underlying fabric memory is owned
# by the MnnvlMemory wrapper — when its refcount hits zero, ``__del__`` calls
# ``close_mnnvl_memory`` and unmaps the VA. Without a stable reference outside
# the returned tensor, any caller-side ``view`` / ``slice`` / ``clone`` that
# drops the original tensor would silently free the workspace under the kernel.
_workspace_keepalive: Dict[int, MnnvlMemory] = {}


@flashinfer_api
def decode_cp_a2a_workspace_size(cp_size: int) -> int:
"""Return the workspace size **in bytes** per rank for the given CP group size.
Expand All @@ -119,60 +133,47 @@ def decode_cp_a2a_workspace_size(cp_size: int) -> int:


@flashinfer_api
def decode_cp_a2a_allocate_workspace(
cp_size: int,
cp_rank: int,
def decode_cp_a2a_allocate_mnnvl_workspace(
mapping: Mapping,
*,
mapping: Optional[Mapping] = None,
mnnvl_config: Optional[MnnvlConfig] = None,
) -> torch.Tensor:
"""Allocate a workspace tensor of shape ``[cp_size, ws_elems_per_rank]``.
"""Allocate an MNNVL-backed workspace of shape ``[cp_size, ws_elems_per_rank]``.

The DCP A2A kernel requires a single unified VA spanning all CP ranks
(see module docstring), so workspace allocation must go through MNNVL
fabric memory. This function is the only supported allocator.

After allocation, call :func:`decode_cp_a2a_init_workspace` followed by a
cross-rank barrier before the first :func:`decode_cp_a2a_alltoall` call.

Two allocation modes:

- **MNNVL** (``mapping`` provided): Cross-rank visible GPU memory via
FlashInfer's ``MnnvlMemory``. Required for multi-node or when ranks
cannot see each other's device memory directly.
- **Plain device memory** (``mapping=None``): Standard ``torch.zeros``
allocation. Sufficient for single-node with NVLink P2P.

Args:
cp_size: Context-parallel group size.
cp_rank: This rank's position in the CP group.
mapping: Mapping object for MNNVL allocation. If provided, MNNVL is
used. The mapping must have ``cp_size`` set correctly. The
communicator is split using ``mapping.pp_rank``, ``mapping.cp_rank``,
and ``mapping.tp_rank``.
mapping: Mapping object for MNNVL allocation. Carries ``cp_size`` and
``cp_rank``. The communicator is split using ``mapping.pp_rank``,
``mapping.cp_rank``, and ``mapping.tp_rank``.
mnnvl_config: Configuration for the MNNVL communication backend.
Required when using MNNVL with ``torch.distributed`` (pass
``MnnvlConfig(comm_backend=TorchDistBackend(group))``).
Comment on lines +136 to 156
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The cp_size and cp_rank arguments are redundant because the mapping object (which is now a required positional argument) already contains this information. As noted in the docstring, mapping carries the authoritative rank info. Removing these redundant parameters simplifies the API and eliminates the risk of passing inconsistent values.

@flashinfer_api
def decode_cp_a2a_allocate_mnnvl_workspace(
    mapping: Mapping,
    *,
    mnnvl_config: Optional[MnnvlConfig] = None,
) -> torch.Tensor:
    """Allocate an MNNVL-backed workspace of shape ``[cp_size, ws_elems_per_rank]``.

    The DCP A2A kernel requires a single unified VA spanning all CP ranks
    (see module docstring), so workspace allocation must go through MNNVL
    fabric memory. This function is the only supported allocator.

    After allocation, call :func:`decode_cp_a2a_init_workspace` followed by a
    cross-rank barrier before the first :func:`decode_cp_a2a_alltoall` call.

    Args:
        mapping: Mapping object for MNNVL allocation. Must have ``cp_size``
            set correctly. The communicator is split using ``mapping.pp_rank``,
            ``mapping.cp_rank``, and ``mapping.tp_rank``.
        mnnvl_config: Configuration for the MNNVL communication backend.
            Required when using MNNVL with ``torch.distributed`` (pass
            ``MnnvlConfig(comm_backend=TorchDistBackend(group))``).


Returns:
``torch.int64`` tensor of shape ``[cp_size, ws_elems_per_rank]``.
"""
ws_bytes = decode_cp_a2a_workspace_size(cp_size)

if mapping is not None:
MnnvlMemory.initialize()
if mnnvl_config:
MnnvlMemory.set_comm_from_config(mapping, mnnvl_config)

mnnvl_mem = MnnvlMemory(mapping, ws_bytes)
workspace = mnnvl_mem.as_torch_strided_tensor(torch.int64)
workspace._mnnvl_mem = mnnvl_mem # prevent GC of MNNVL handle
logger.info(
"Rank %d: DCP MNNVL workspace allocated — shape=%s, stride=%s",
cp_rank,
list(workspace.shape),
list(workspace.stride()),
)
return workspace

ws_elems_per_rank = (ws_bytes + 7) // 8
return torch.zeros(cp_size, ws_elems_per_rank, dtype=torch.int64, device="cuda")
ws_bytes = decode_cp_a2a_workspace_size(mapping.cp_size)

MnnvlMemory.initialize()
if mnnvl_config:
MnnvlMemory.set_comm_from_config(mapping, mnnvl_config)

mnnvl_mem = MnnvlMemory(mapping, ws_bytes)
workspace = mnnvl_mem.as_torch_strided_tensor(torch.int64)
_workspace_keepalive[workspace.data_ptr()] = mnnvl_mem
logger.info(
"Rank %d: DCP MNNVL workspace allocated — shape=%s, stride=%s",
mapping.cp_rank,
list(workspace.shape),
list(workspace.stride()),
)
return workspace


@flashinfer_api
Expand All @@ -197,7 +198,7 @@ def decode_cp_a2a_init_workspace(

Args:
workspace: ``[cp_size, ws_elems_per_rank]`` int64 tensor from
:func:`decode_cp_a2a_allocate_workspace`.
:func:`decode_cp_a2a_allocate_mnnvl_workspace`.
cp_rank: This rank's position in the CP group.
cp_size: Context-parallel group size.
"""
Expand Down Expand Up @@ -229,7 +230,7 @@ def decode_cp_a2a_alltoall(
softmax_stats: ``[..., cp_size, S]`` — float32, ``S >= 2`` and even.
Batch dimensions must match ``partial_o``.
workspace: ``[cp_size, ws_elems_per_rank]`` int64 tensor from
:func:`decode_cp_a2a_allocate_workspace`, already initialized.
:func:`decode_cp_a2a_allocate_mnnvl_workspace`, already initialized.
cp_rank: This rank's position in the CP group.
cp_size: Context-parallel group size.
enable_pdl: Enable Programmatic Dependent Launch (SM90+).
Expand All @@ -249,7 +250,7 @@ def decode_cp_a2a_alltoall(

__all__ = [
"decode_cp_a2a_workspace_size",
"decode_cp_a2a_allocate_workspace",
"decode_cp_a2a_allocate_mnnvl_workspace",
"decode_cp_a2a_init_workspace",
"decode_cp_a2a_alltoall",
]
49 changes: 28 additions & 21 deletions tests/comm/test_dcp_alltoall.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
"""Tests for flashinfer.comm.dcp_alltoall — DCP LL128 FIFO All-to-All.

Single-GPU multi-rank pattern: simulates cp_size ranks on one GPU using
separate CUDA streams for the alltoall phase. All ranks share a single
workspace tensor of shape [cp_size, ws_elems_per_rank].
separate CUDA streams for the alltoall phase. All simulated ranks share a
single ``torch.zeros`` workspace tensor of shape ``[cp_size, ws_elems_per_rank]``,
which lets the kernel's ``params.workspace + peer_rank * stride`` pointer
arithmetic land in the same physical allocation. Real multi-GPU runs need
an MNNVL-backed workspace — see ``test_mnnvl_dcp_alltoall.py``.

Run: python -m pytest tests/comm/test_dcp_alltoall.py -v -s
"""
Expand All @@ -26,7 +29,6 @@

from flashinfer.comm import (
decode_cp_a2a_alltoall,
decode_cp_a2a_allocate_workspace,
decode_cp_a2a_init_workspace,
decode_cp_a2a_workspace_size,
)
Expand Down Expand Up @@ -70,6 +72,19 @@ def _to_torch(t):
return torch.from_dlpack(t)


def _alloc_sim_workspace(cp_size: int) -> torch.Tensor:
    """Build the zero-filled workspace shared by all simulated ranks.

    The public allocator (``decode_cp_a2a_allocate_mnnvl_workspace``) needs a
    fabric-mapped Mapping, which a single-GPU test does not have. Instead,
    every simulated rank is handed the same ``torch.zeros`` tensor, so the
    kernel's cross-rank pointer arithmetic resolves inside one physical
    allocation.

    Args:
        cp_size: Number of simulated context-parallel ranks; becomes the
            leading dimension of the returned tensor.

    Returns:
        A ``torch.int64`` CUDA tensor of shape ``[cp_size, ws_elems_per_rank]``.
    """
    bytes_per_rank = decode_cp_a2a_workspace_size(cp_size)
    # Ceiling-divide by 8 so each int64 row covers every requested byte.
    elems_per_rank = -(-bytes_per_rank // 8)
    shape = (cp_size, elems_per_rank)
    return torch.zeros(shape, dtype=torch.int64, device="cuda")


def _run_single_gpu_alltoall(cp_size, batch_size, head_dim, stats_dim, dtype):
"""Simulate cp_size ranks on one GPU and return (inputs, outputs, workspace).

Expand All @@ -83,7 +98,7 @@ def _run_single_gpu_alltoall(cp_size, batch_size, head_dim, stats_dim, dtype):
"""
torch.cuda.set_device(0)

workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank=0)
workspace = _alloc_sim_workspace(cp_size)

all_partial_o = []
all_softmax_stats = []
Expand Down Expand Up @@ -164,17 +179,9 @@ def test_workspace_size_monotonic(self):
assert ws4 > ws2, "ws(4) should be > ws(2)"
assert ws8 > ws4, "ws(8) should be > ws(4)"

def test_allocate_returns_correct_shape_and_dtype(self):
for cp_size in [2, 4]:
ws_bytes = decode_cp_a2a_workspace_size(cp_size)
workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank=0)
assert workspace.dtype == torch.int64
assert workspace.shape[0] == cp_size
assert workspace.shape[1] == (ws_bytes + 7) // 8

def test_init_workspace_does_not_hang(self):
for cp_size in [2, 4]:
workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank=0)
workspace = _alloc_sim_workspace(cp_size)
for r in range(cp_size):
decode_cp_a2a_init_workspace(workspace, r, cp_size)
torch.cuda.synchronize()
Expand Down Expand Up @@ -234,7 +241,7 @@ def test_repeated_alltoall(cp_size, batch_size, head_dim, stats_dim, dtype, num_
"""Multiple alltoall calls on the same workspace without re-init (FIFO reuse)."""
torch.cuda.set_device(0)

workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank=0)
workspace = _alloc_sim_workspace(cp_size)

for r in range(cp_size):
decode_cp_a2a_init_workspace(workspace, r, cp_size)
Expand Down Expand Up @@ -283,7 +290,7 @@ def test_cp_size_1_is_identity(self):
dtype = torch.bfloat16
torch.cuda.set_device(0)

workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank=0)
workspace = _alloc_sim_workspace(cp_size)
po = torch.randn(batch_size, cp_size, head_dim, dtype=dtype, device="cuda")
ss = torch.randn(
batch_size, cp_size, stats_dim, dtype=torch.float32, device="cuda"
Expand All @@ -305,7 +312,7 @@ def test_batch_size_0(self):
dtype = torch.bfloat16
torch.cuda.set_device(0)

workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank=0)
workspace = _alloc_sim_workspace(cp_size)
po = torch.randn(batch_size, cp_size, head_dim, dtype=dtype, device="cuda")
ss = torch.randn(
batch_size, cp_size, stats_dim, dtype=torch.float32, device="cuda"
Expand All @@ -332,7 +339,7 @@ class TestInputValidation:
def test_wrong_dtype_float64(self):
"""partial_o with float64 should be rejected."""
cp_size = 2
workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank=0)
workspace = _alloc_sim_workspace(cp_size)
for r in range(cp_size):
decode_cp_a2a_init_workspace(workspace, r, cp_size)
torch.cuda.synchronize()
Expand All @@ -346,7 +353,7 @@ def test_wrong_dtype_float64(self):
def test_wrong_dtype_float32(self):
"""partial_o with float32 should be rejected (must be half/bfloat16)."""
cp_size = 2
workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank=0)
workspace = _alloc_sim_workspace(cp_size)
for r in range(cp_size):
decode_cp_a2a_init_workspace(workspace, r, cp_size)
torch.cuda.synchronize()
Expand All @@ -360,7 +367,7 @@ def test_wrong_dtype_float32(self):
def test_stats_dim_1_odd_alignment(self):
"""stats_dim=1 violates 'even and >= 2' constraint — should error."""
cp_size = 2
workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank=0)
workspace = _alloc_sim_workspace(cp_size)
for r in range(cp_size):
decode_cp_a2a_init_workspace(workspace, r, cp_size)
torch.cuda.synchronize()
Expand All @@ -374,7 +381,7 @@ def test_stats_dim_1_odd_alignment(self):
def test_mismatched_batch_dims(self):
"""partial_o and softmax_stats with different batch sizes should error."""
cp_size = 2
workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank=0)
workspace = _alloc_sim_workspace(cp_size)
for r in range(cp_size):
decode_cp_a2a_init_workspace(workspace, r, cp_size)
torch.cuda.synchronize()
Expand All @@ -388,7 +395,7 @@ def test_mismatched_batch_dims(self):
def test_wrong_stats_dtype(self):
"""softmax_stats with half instead of float32 should error."""
cp_size = 2
workspace = decode_cp_a2a_allocate_workspace(cp_size, cp_rank=0)
workspace = _alloc_sim_workspace(cp_size)
for r in range(cp_size):
decode_cp_a2a_init_workspace(workspace, r, cp_size)
torch.cuda.synchronize()
Expand Down
38 changes: 3 additions & 35 deletions tests/comm/test_mnnvl_dcp_alltoall.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

from flashinfer.comm import (
decode_cp_a2a_alltoall,
decode_cp_a2a_allocate_workspace,
decode_cp_a2a_allocate_mnnvl_workspace,
decode_cp_a2a_init_workspace,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The new public API decode_cp_a2a_allocate_mnnvl_workspace is not imported or used in this test file. Instead, the module-level helper _allocate_mnnvl_workspace_once (lines 121-144) manually reimplements the allocation logic. It is highly recommended to import and use the public API in the tests to ensure it is properly exercised and to reduce code duplication.

Suggested change
decode_cp_a2a_init_workspace,
decode_cp_a2a_allocate_mnnvl_workspace,

decode_cp_a2a_workspace_size,
)
Expand Down Expand Up @@ -120,7 +120,7 @@ def _setup_rank():
_rank, _cp_size, _comm = _setup_rank()

def _allocate_mnnvl_workspace_once():
"""Allocate MNNVL workspace once at module level.
"""Allocate MNNVL workspace once at module level via the public API.

MnnvlMemory uses a global bump allocator that doesn't support
individual frees. Allocating per-test causes segfaults when
Expand All @@ -137,12 +137,7 @@ def _allocate_mnnvl_workspace_once():
tp_size=1,
pp_size=1,
)

ws_bytes = decode_cp_a2a_workspace_size(_cp_size)
mnnvl_mem = MnnvlMemory(mapping, ws_bytes)
workspace = mnnvl_mem.as_torch_strided_tensor(torch.int64)
workspace._mnnvl_mem = mnnvl_mem # prevent GC
return workspace
return decode_cp_a2a_allocate_mnnvl_workspace(mapping)

_mnnvl_workspace = _allocate_mnnvl_workspace_once()
else:
Expand Down Expand Up @@ -319,32 +314,5 @@ def test_repeated_alltoall(self):
_comm.Barrier()


class TestMnnvlDcpDeviceMemoryFallback:
"""Test that non-MNNVL (device memory) path also works multi-GPU.

Uses decode_cp_a2a_allocate_workspace without MNNVL mapping. This only
works when all ranks are on the same GPU (single-GPU simulation)
or with IPC. Included here to verify the workspace API contract.
"""

@pytest.fixture(autouse=True)
def setup(self):
torch.manual_seed(0xA2A)
yield

def test_device_workspace_shape(self):
"""Device workspace has correct shape [cp_size, ws_elems]."""
try:
workspace = decode_cp_a2a_allocate_workspace(_cp_size, cp_rank=_rank)
assert workspace.shape[0] == _cp_size

ws_bytes = decode_cp_a2a_workspace_size(_cp_size)
expected_elems = (ws_bytes + 7) // 8
assert workspace.shape[1] == expected_elems
assert workspace.dtype == torch.int64
finally:
_comm.Barrier()


if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])
Loading