Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 105 additions & 1 deletion aiter/dist/communication_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,29 @@
import torch
import torch.distributed

from .parallel_state import get_tp_group, get_pp_group, get_dp_group, get_ep_group
from .parallel_state import (
get_tp_group,
get_pp_group,
get_dp_group,
get_ep_group,
get_custom_group,
has_custom_group,
)


def _assert_no_custom_group(op_name: str) -> None:
    """Guard for the standard parallel-group wrappers (TP/PP/DP/EP).

    Fails when a custom group configuration is active, because the standard
    wrappers must not be mixed with custom groups — callers should use
    custom_all_reduce() and friends instead.

    Args:
        op_name: Name of the calling wrapper, interpolated into the error text.

    Raises:
        AssertionError: If a custom group configuration is set.
    """
    # Explicit raise instead of a bare `assert`: an assert statement is
    # stripped when Python runs with -O, which would silently disable this
    # safety check. The exception type is unchanged for existing callers.
    if has_custom_group():
        raise AssertionError(
            f"custom_group_config is set — use custom_all_reduce() instead of "
            f"{op_name}()"
        )


def _assert_has_custom_group() -> None:
    """Guard for the custom-group communication wrappers.

    Fails when no custom group configuration is set, directing callers back
    to the standard parallel-group operations.

    Raises:
        AssertionError: If no custom group configuration is set.
    """
    # Explicit raise instead of a bare `assert`: an assert statement is
    # stripped when Python runs with -O, which would silently disable this
    # safety check. The exception type is unchanged for existing callers.
    if not has_custom_group():
        raise AssertionError(
            "custom_group_config is not set — use tensor_model_parallel_all_reduce() "
            "or other standard parallel group operations instead of custom_all_reduce()"
        )


# ============================================================
# Tensor Model Parallel (TP) communication operations
Expand All @@ -34,6 +56,7 @@ def tensor_model_parallel_all_reduce(
prefill_support: bool = False,
) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group."""
_assert_no_custom_group("tensor_model_parallel_all_reduce")
return get_tp_group().all_reduce(input_, use_new, open_fp8_quant, prefill_support)


Expand All @@ -44,6 +67,7 @@ def tensor_model_parallel_fused_allreduce_rmsnorm(
eps: float,
prefill_support: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
_assert_no_custom_group("tensor_model_parallel_fused_allreduce_rmsnorm")
return get_tp_group().fused_allreduce_rmsnorm(
input_, residual_inp_, weight_, eps, prefill_support
)
Expand All @@ -56,6 +80,7 @@ def tensor_model_parallel_fused_allreduce_rmsnorm_quant(
eps: float,
prefill_support: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
_assert_no_custom_group("tensor_model_parallel_fused_allreduce_rmsnorm_quant")
return get_tp_group().fused_allreduce_rmsnorm_quant(
input_,
residual_inp_,
Expand All @@ -66,6 +91,7 @@ def tensor_model_parallel_fused_allreduce_rmsnorm_quant(


def tensor_model_parallel_custom_all_gather(input_: torch.Tensor) -> torch.Tensor:
    """All-gather *input_* across the tensor model parallel group via the
    group's custom all-gather path.
    """
    # Standard TP op: must not be used while a custom group config is active.
    _assert_no_custom_group("tensor_model_parallel_custom_all_gather")
    tp_group = get_tp_group()
    return tp_group.custom_all_gather(input_)


Expand All @@ -74,6 +100,7 @@ def tensor_model_parallel_reduce_scatter(
use_custom: bool = True,
dim: int = 0,
) -> torch.Tensor:
_assert_no_custom_group("tensor_model_parallel_reduce_scatter")
return get_tp_group().reduce_scatter_tensor(input_, use_custom, dim)


Expand All @@ -83,19 +110,22 @@ def tensor_model_parallel_all_gather(
dim: int = -1,
) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
_assert_no_custom_group("tensor_model_parallel_all_gather")
return get_tp_group().all_gather(input_, use_custom, dim)


def tensor_model_parallel_gather(
    input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> Optional[torch.Tensor]:
    """Gather *input_* along *dim* onto rank *dst* of the tensor model
    parallel group.

    NOTE(review): the Optional return suggests non-destination ranks get
    None — confirm against the group implementation.
    """
    # Standard TP op: must not be used while a custom group config is active.
    _assert_no_custom_group("tensor_model_parallel_gather")
    tp_group = get_tp_group()
    return tp_group.gather(input_, dst, dim)


def broadcast_tensor_dict(
    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0
):
    """Broadcast a tensor dict from rank *src* across the tensor model
    parallel group.

    Returns the dict unchanged when torch.distributed has not been
    initialized (single-process fallback).
    """
    _assert_no_custom_group("broadcast_tensor_dict")
    if torch.distributed.is_initialized():
        return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
    return tensor_dict
Expand All @@ -110,39 +140,45 @@ def expert_parallel_all_reduce(
input_: torch.Tensor, use_new: bool = True, open_fp8_quant: bool = False
) -> torch.Tensor:
"""All-reduce the input tensor across expert parallel group."""
_assert_no_custom_group("expert_parallel_all_reduce")
return get_ep_group().all_reduce(input_, use_new, open_fp8_quant)


def expert_parallel_all_gather(
    input_: torch.Tensor, use_custom: bool = False, dim: int = -1
) -> torch.Tensor:
    """All-gather *input_* along *dim* across the expert parallel group."""
    # Standard EP op: must not be used while a custom group config is active.
    _assert_no_custom_group("expert_parallel_all_gather")
    ep_group = get_ep_group()
    return ep_group.all_gather(input_, use_custom, dim)


def expert_parallel_reduce_scatter(
    input_: torch.Tensor, use_custom: bool = True, dim: int = 0
) -> torch.Tensor:
    """Reduce-scatter *input_* along *dim* across the expert parallel group."""
    # Standard EP op: must not be used while a custom group config is active.
    _assert_no_custom_group("expert_parallel_reduce_scatter")
    ep_group = get_ep_group()
    return ep_group.reduce_scatter_tensor(input_, use_custom, dim)


def expert_parallel_gather(
    input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> Optional[torch.Tensor]:
    """Gather *input_* along *dim* onto rank *dst* of the expert parallel
    group.

    NOTE(review): the Optional return suggests non-destination ranks get
    None — confirm against the group implementation.
    """
    # Standard EP op: must not be used while a custom group config is active.
    _assert_no_custom_group("expert_parallel_gather")
    ep_group = get_ep_group()
    return ep_group.gather(input_, dst, dim)


def expert_parallel_broadcast(input_: torch.Tensor, src: int = 0) -> torch.Tensor:
    """Broadcast *input_* from rank *src* across the expert parallel group."""
    # Standard EP op: must not be used while a custom group config is active.
    _assert_no_custom_group("expert_parallel_broadcast")
    ep_group = get_ep_group()
    return ep_group.broadcast(input_, src)


def expert_parallel_broadcast_tensor_dict(
    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0
):
    """Broadcast a tensor dict from rank *src* across the expert parallel
    group.

    Returns the dict unchanged when torch.distributed has not been
    initialized (single-process fallback).
    """
    _assert_no_custom_group("expert_parallel_broadcast_tensor_dict")
    if torch.distributed.is_initialized():
        return get_ep_group().broadcast_tensor_dict(tensor_dict, src)
    return tensor_dict
Expand All @@ -157,39 +193,45 @@ def data_parallel_all_reduce(
input_: torch.Tensor, use_new: bool = True, open_fp8_quant: bool = False
) -> torch.Tensor:
"""All-reduce the input tensor across data parallel group."""
_assert_no_custom_group("data_parallel_all_reduce")
return get_dp_group().all_reduce(input_, use_new, open_fp8_quant)


def data_parallel_all_gather(
    input_: torch.Tensor, use_custom: bool = False, dim: int = -1
) -> torch.Tensor:
    """All-gather *input_* along *dim* across the data parallel group."""
    # Standard DP op: must not be used while a custom group config is active.
    _assert_no_custom_group("data_parallel_all_gather")
    dp_group = get_dp_group()
    return dp_group.all_gather(input_, use_custom, dim)


def data_parallel_reduce_scatter(
    input_: torch.Tensor, use_custom: bool = True, dim: int = 0
) -> torch.Tensor:
    """Reduce-scatter *input_* along *dim* across the data parallel group."""
    # Standard DP op: must not be used while a custom group config is active.
    _assert_no_custom_group("data_parallel_reduce_scatter")
    dp_group = get_dp_group()
    return dp_group.reduce_scatter_tensor(input_, use_custom, dim)


def data_parallel_gather(
    input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> Optional[torch.Tensor]:
    """Gather *input_* along *dim* onto rank *dst* of the data parallel group.

    NOTE(review): the Optional return suggests non-destination ranks get
    None — confirm against the group implementation.
    """
    # Standard DP op: must not be used while a custom group config is active.
    _assert_no_custom_group("data_parallel_gather")
    dp_group = get_dp_group()
    return dp_group.gather(input_, dst, dim)


def data_parallel_broadcast(input_: torch.Tensor, src: int = 0) -> torch.Tensor:
    """Broadcast *input_* from rank *src* across the data parallel group."""
    # Standard DP op: must not be used while a custom group config is active.
    _assert_no_custom_group("data_parallel_broadcast")
    dp_group = get_dp_group()
    return dp_group.broadcast(input_, src)


def data_parallel_broadcast_tensor_dict(
    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0
):
    """Broadcast a tensor dict from rank *src* across the data parallel group.

    Returns the dict unchanged when torch.distributed has not been
    initialized (single-process fallback).
    """
    _assert_no_custom_group("data_parallel_broadcast_tensor_dict")
    if torch.distributed.is_initialized():
        return get_dp_group().broadcast_tensor_dict(tensor_dict, src)
    return tensor_dict
Expand All @@ -204,41 +246,103 @@ def pipeline_model_parallel_all_reduce(
input_: torch.Tensor, use_new: bool = True, open_fp8_quant: bool = False
) -> torch.Tensor:
"""All-reduce the input tensor across pipeline parallel group."""
_assert_no_custom_group("pipeline_model_parallel_all_reduce")
return get_pp_group().all_reduce(input_, use_new, open_fp8_quant)


def pipeline_model_parallel_all_gather(
    input_: torch.Tensor, use_custom: bool = False, dim: int = -1
) -> torch.Tensor:
    """All-gather *input_* along *dim* across the pipeline parallel group."""
    # Standard PP op: must not be used while a custom group config is active.
    _assert_no_custom_group("pipeline_model_parallel_all_gather")
    pp_group = get_pp_group()
    return pp_group.all_gather(input_, use_custom, dim)


def pipeline_model_parallel_broadcast(
    input_: torch.Tensor, src: int = 0
) -> torch.Tensor:
    """Broadcast *input_* from rank *src* across the pipeline parallel group."""
    # Standard PP op: must not be used while a custom group config is active.
    _assert_no_custom_group("pipeline_model_parallel_broadcast")
    pp_group = get_pp_group()
    return pp_group.broadcast(input_, src)


def pipeline_model_parallel_send(
    input_: torch.Tensor, dst: Optional[int] = None
) -> None:
    """Send *input_* over the pipeline parallel group to *dst* (or the
    group's default next-stage peer when *dst* is None).
    """
    # Standard PP op: must not be used while a custom group config is active.
    _assert_no_custom_group("pipeline_model_parallel_send")
    pp_group = get_pp_group()
    pp_group.send(input_, dst)


def pipeline_model_parallel_recv(
    size: torch.Size, dtype: torch.dtype, src: Optional[int] = None
) -> torch.Tensor:
    """Receive a tensor of the given *size* and *dtype* over the pipeline
    parallel group, from *src* (or the group's default previous-stage peer
    when *src* is None).
    """
    # Standard PP op: must not be used while a custom group config is active.
    _assert_no_custom_group("pipeline_model_parallel_recv")
    pp_group = get_pp_group()
    return pp_group.recv(size, dtype, src)


def pipeline_model_parallel_broadcast_tensor_dict(
    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0
):
    """Broadcast a tensor dict from rank *src* across the pipeline parallel
    group.

    Returns the dict unchanged when torch.distributed has not been
    initialized (single-process fallback).
    """
    _assert_no_custom_group("pipeline_model_parallel_broadcast_tensor_dict")
    if torch.distributed.is_initialized():
        return get_pp_group().broadcast_tensor_dict(tensor_dict, src)
    return tensor_dict


# ============================================================
# Custom group communication operations
# ============================================================


def custom_all_reduce(
    input_: torch.Tensor,
    use_new: bool = True,
    open_fp8_quant: bool = False,
    group: Optional[str] = None,
) -> torch.Tensor:
    """All-reduce *input_* across a user-configured custom group.

    Args:
        group: Name of the custom group to use. May be omitted when exactly
            one custom group is initialized; with multiple groups, the name
            selects among them.
    """
    # Custom-group op: requires a custom group config to be set.
    _assert_has_custom_group()
    comm_group = get_custom_group(group)
    return comm_group.all_reduce(input_, use_new, open_fp8_quant)


def custom_all_gather(
    input_: torch.Tensor,
    use_custom: bool = True,
    dim: int = 0,
    group: Optional[str] = None,
) -> torch.Tensor:
    """All-gather *input_* along *dim* across a user-configured custom group.

    Args:
        group: Name of the custom group to use. May be omitted when exactly
            one custom group is initialized; with multiple groups, the name
            selects among them.
    """
    # Custom-group op: requires a custom group config to be set.
    _assert_has_custom_group()
    comm_group = get_custom_group(group)
    return comm_group.all_gather(input_, use_custom, dim)


def custom_reduce_scatter(
    input_: torch.Tensor,
    use_custom: bool = True,
    dim: int = 0,
    group: Optional[str] = None,
) -> torch.Tensor:
    """Reduce-scatter *input_* along *dim* across a user-configured custom
    group.

    Args:
        group: Name of the custom group to use. May be omitted when exactly
            one custom group is initialized; with multiple groups, the name
            selects among them.
    """
    # Custom-group op: requires a custom group config to be set.
    _assert_has_custom_group()
    comm_group = get_custom_group(group)
    return comm_group.reduce_scatter_tensor(input_, use_custom, dim)
Loading
Loading