158 changes: 79 additions & 79 deletions torchtitan/distributed/expert_parallel.py
@@ -17,7 +17,6 @@
DeviceMesh,
distribute_module,
distribute_tensor,
DTensor,
Shard,
)
from torch.distributed.tensor.parallel import ParallelStyle
@@ -85,12 +84,15 @@ def __init__(self):
super().__init__()
self.input_splits = None
self.output_splits = None
self.input_shape = None
self.permuted_indices = None

# performing all-to-all dispatch on the input
def _token_dispatch(self, mod, inputs, device_mesh):
# annotate module input placements/sharding with input_layouts
routed_input, num_tokens_per_expert = inputs
ep_size = device_mesh.shape[0]
ep_degree = device_mesh.shape[0]
num_local_experts = num_tokens_per_expert.shape[0] // ep_degree

# generate the input splits and output splits for all-to-all
with torch.no_grad():
@@ -106,13 +108,13 @@ def _token_dispatch(self, mod, inputs, device_mesh):
num_tokens_per_expert_group
)
input_splits = (
num_tokens_per_expert.view(ep_size, -1)
num_tokens_per_expert.view(ep_degree, -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=True)
)
# NOTE: this would incur a device-to-host sync
output_splits = (
num_tokens_per_expert_group.view(ep_size, -1)
num_tokens_per_expert_group.view(ep_degree, -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=False)
)
@@ -133,9 +135,20 @@ def _token_dispatch(self, mod, inputs, device_mesh):
# Rather, it is of the format
# [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ...,
# #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...]
# We need to perform another shuffle to get the correct format -- this is done via the function
# generate_permute_indices in moe.py, which also does padding to make sure the number of tokens
# each expert gets locally is a multiple of ALIGN_SIZE_M.
# We need to perform another shuffle to get the correct layout, via the _permute function
# below, which also does padding to make sure the number of tokens each expert gets locally
# is a multiple of TOKEN_GROUP_ALIGN_SIZE_M.
# Note that this will create side effects when wrapping the for-loop implementation
# of GroupedExperts, as it does not need padding.
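# Illustrative example (hypothetical sizes): with ep_degree=2 and num_local_experts=2,
# num_tokens_per_expert_group = [3, 1, 2, 4] means EP rank 0 sent 3 tokens for local
# expert 0 and 1 token for local expert 1, and EP rank 1 sent 2 and 4 tokens respectively.
# _permute makes each local expert's tokens contiguous (3 + 2 for expert 0, then 1 + 4
# for expert 1) and pads each group up to a multiple of TOKEN_GROUP_ALIGN_SIZE_M.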

(
self.input_shape,
routed_input,
self.permuted_indices,
num_tokens_per_expert_group,
) = _permute(
routed_input, num_tokens_per_expert_group, ep_degree, num_local_experts
)

return routed_input, num_tokens_per_expert_group

@@ -148,6 +161,10 @@ def _partition_fn(name, mod, device_mesh):

# performing all-to-all combine on the output
def _token_combine(self, mod, routed_output, device_mesh):
routed_output = _unpermute(
routed_output, self.input_shape, self.permuted_indices
)

routed_output = all_to_all_single_autograd(
routed_output,
self.input_splits,
@@ -168,20 +185,9 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:

# This class is for dp2ep with TP (without TP we can just use ExpertParallel)
class ExpertTensorParallel(ExpertParallel):
def __init__(
self,
tp_mesh: DeviceMesh,
ep_mesh: DeviceMesh,
):
super().__init__()
# TODO: has to pass in the meshes in addition to the [ep, tp] device_mesh,
# as DeviceMesh doesn't support slicing from a submesh.
self.tp_mesh = tp_mesh
self.ep_mesh = ep_mesh

def _token_dispatch(self, mod, inputs, device_mesh):
# token dispatch happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
return super()._token_dispatch(mod, inputs, self.ep_mesh)
return super()._token_dispatch(mod, inputs, device_mesh["ep"])

def _partition_fn_2d(self, name, mod, ep_tp_mesh):
# w1 shape = (experts, out_dim, in_dim)
@@ -204,7 +210,7 @@ def _partition_fn_2d(self, name, mod, ep_tp_mesh):

def _token_combine(self, mod, routed_output, device_mesh):
# token combine happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
return super()._token_combine(mod, routed_output, self.ep_mesh)
return super()._token_combine(mod, routed_output, device_mesh["ep"])

def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
return distribute_module(
@@ -216,25 +222,42 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
)


def expert_parallel(func: Callable) -> Callable:
def _permute(x, num_tokens_per_expert, ep_degree, num_local_experts):
# TODO: move to core
from torchtitan.experiments.kernels.moe.indices import generate_permute_indices

global TOKEN_GROUP_ALIGN_SIZE_M
x_padded_per_expert = x.shape[0] + num_local_experts * TOKEN_GROUP_ALIGN_SIZE_M
padded_max_len = _round_up(x_padded_per_expert, TOKEN_GROUP_ALIGN_SIZE_M)
with torch.no_grad():
(permuted_indices, num_tokens_per_expert, _offsets,) = generate_permute_indices(
num_tokens_per_expert,
num_local_experts,
ep_degree,
padded_max_len,
TOKEN_GROUP_ALIGN_SIZE_M,
)

x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
input_shape = x.shape
x = x[permuted_indices, :]

return input_shape, x, permuted_indices, num_tokens_per_expert


def _unpermute(out, input_shape, permuted_indices):
out_unpermuted = out.new_empty(input_shape)
out_unpermuted[permuted_indices, :] = out
out = out_unpermuted[:-1]
return out


def indices_padding_wrapper(func: Callable) -> Callable:
"""
This is a wrapper applied to the GroupedExperts computation, serving
the following three purposes:
1. Convert parameters from DTensors to plain Tensors, to work with
dynamic-shape inputs which cannot be easily expressed as DTensors.
2. In Expert Parallel, apply the generate_permute_indices kernel to
permute the inputs to be ordered by local experts (see the _token_dispatch
function in ExpertParallel) and permute the outputs back.
3. In order to use torch._grouped_mm, we need to make sure the number of
tokens each expert gets is a multiple of ALIGN_SIZE_M. The generate_permute_indices
kernel also helps achieve this via padding, without incurring synchronization
between device and host. Note that this will create side effects when wrapping
the for-loop implementation of GroupedExperts, as it does not need padding.

Among the above:
1 and 2 are needed only when expert_parallel_degree > 1.
3 is needed even for single-device computation.
2 can be moved to ExpertParallel _token_dispatch if not coupled with 3.
In order to use torch._grouped_mm, we need to make sure the number of
[Review thread on this docstring]
Contributor: nit: This description only talks about padding; it doesn't mention that the generate_permute_indices kernel also permutes the inputs to be ordered by local experts.

Author: This wrapper is now only responsible for padding, when EP is not used. I renamed it to make that clearer.

tokens each expert gets is a multiple of TOKEN_GROUP_ALIGN_SIZE_M. The
generate_permute_indices kernel also helps achieve this via padding,
without incurring synchronization between device and host.
"""

def wrapper(
Expand All @@ -244,45 +267,16 @@ def wrapper(
x: torch.Tensor,
num_tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
global TOKEN_GROUP_ALIGN_SIZE_M
if isinstance(w1, DTensor):
w1 = w1.to_local()
w2 = w2.to_local()
w3 = w3.to_local()
num_local_experts = w1.shape[0]
ep_degree = num_tokens_per_expert.shape[0] // num_local_experts

from torchtitan.experiments.kernels.moe.indices import generate_permute_indices

experts_per_ep_rank = w1.shape[0]
num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank

# Make sure max_len of permuted token indicies is divisible by TOKEN_GROUP_ALIGN_SIZE_M,
# by padding it to the nearest multiple of TOKEN_GROUP_ALIGN_SIZE_M.
x_padded_per_expert = (
x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M
input_shape, x, permuted_indices, num_tokens_per_expert = _permute(
x, num_tokens_per_expert, ep_degree, num_local_experts
)
padded_max_len = _round_up(x_padded_per_expert, TOKEN_GROUP_ALIGN_SIZE_M)
with torch.no_grad():
(
permuted_indices,
num_tokens_per_expert,
_, # offsets,
) = generate_permute_indices(
num_tokens_per_expert,
experts_per_ep_rank,
num_ep_ranks,
padded_max_len,
TOKEN_GROUP_ALIGN_SIZE_M,
)

x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
input_shape = x.shape
x = x[permuted_indices, :]

out = func(w1, w2, w3, x, num_tokens_per_expert)

out_unpermuted = out.new_empty(input_shape)
out_unpermuted[permuted_indices, :] = out
out = out_unpermuted[:-1]
out = _unpermute(out, input_shape, permuted_indices)

return out

Expand All @@ -294,11 +288,12 @@ def wrapper(
class ReordererSequenceParallel(ParallelStyle):
def __init__(self):
super().__init__()
self.num_tokens = None
self.top_k = None

def _prepare_inputput_fn(self, mod, inputs, device_mesh):
# shape (batch_size*seq_len, top_k)
top_scores, selected_experts_indices = inputs
self.num_tokens = top_scores.shape[0]
num_tokens, self.top_k = top_scores.shape

# NOTE: If needed, we can pad tokens in case bs*slen is not divisible by TP degree
# if top_scores.shape[0] % device_mesh.size() != 0:
@@ -310,8 +305,12 @@ def _prepare_inputput_fn(self, mod, inputs, device_mesh):

def _split_along_first_dim(x: torch.Tensor) -> torch.Tensor:
assert x.is_contiguous()
assert self.num_tokens % device_mesh.size() == 0
local_num_tokens = self.num_tokens // device_mesh.size()
if num_tokens % device_mesh.size() != 0:
raise ValueError(
"Uneven split of tokens of is not supported yet. "
"Requires EP degree dividing batch size * seq len."
)
local_num_tokens = num_tokens // device_mesh.size()
local_rank = device_mesh.get_local_rank()
offset = local_rank * local_num_tokens
output = x[offset : offset + local_num_tokens]
@@ -321,17 +320,18 @@ def _split_along_first_dim(x: torch.Tensor) -> torch.Tensor:
top_scores = _split_along_first_dim(top_scores)
selected_experts_indices = _split_along_first_dim(selected_experts_indices)

# shape (batch_size * seq_len // ep_degree, top_k)
return top_scores, selected_experts_indices

def _prepare_output_fn(self, mod, outputs, device_mesh):
# shape (batch_size * seq_len * top_k // ep_degree)
top_scores, token_indices_experts_sorted, num_tokens_per_expert = outputs

# NOTE: As we shard routed tokens along bs*slen dim across the TP ranks,
# the MoE gather and scatter still require global token indices.
local_rank = device_mesh.get_local_rank()
token_indices_experts_sorted += (
self.num_tokens // device_mesh.size() * local_rank
)
# fact: top_scores.shape[0] // self.top_k = batch_size * seq_len // ep_degree
token_indices_experts_sorted += top_scores.shape[0] // self.top_k * local_rank
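# e.g. (hypothetical sizes) batch_size*seq_len=16, top_k=2, mesh size 4: top_scores.shape[0]=8,
# so each rank owns 8 // 2 = 4 tokens and rank r shifts its local indices by 4 * r.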

return top_scores, token_indices_experts_sorted, num_tokens_per_expert

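For reference, below is a minimal pure-PyTorch sketch of the permute/pad/unpermute round trip that the new _permute and _unpermute helpers perform. It is illustrative only: the helper names, the loop-based index construction, and the fixed alignment value are assumptions, whereas the PR relies on the generate_permute_indices kernel in torchtitan/experiments/kernels/moe/indices.py to build the indices without a device-to-host sync.

import torch

TOKEN_GROUP_ALIGN_SIZE_M = 8  # assumed alignment value for this sketch

def permute_with_padding(x, num_tokens_per_expert_group, ep_degree, num_local_experts):
    # Incoming buffer is rank-major: rank 0's chunks for local experts 0..E-1, then rank 1's, ...
    counts = num_tokens_per_expert_group.view(ep_degree, num_local_experts)
    per_expert = counts.sum(dim=0)  # total tokens per local expert across EP ranks
    align = TOKEN_GROUP_ALIGN_SIZE_M
    padded = (per_expert + align - 1) // align * align  # round each expert's group up

    chunk_counts = counts.flatten()
    chunk_starts = chunk_counts.cumsum(0) - chunk_counts  # start offset of each (rank, expert) chunk
    pad_row = x.shape[0]  # index of the zero row appended below; padding slots gather from it

    permuted_indices = torch.full((int(padded.sum()),), pad_row, dtype=torch.long)
    out_pos = 0
    for e in range(num_local_experts):
        for r in range(ep_degree):
            n = int(counts[r, e])
            start = int(chunk_starts[r * num_local_experts + e])
            permuted_indices[out_pos : out_pos + n] = torch.arange(start, start + n)
            out_pos += n
        out_pos += int(padded[e] - per_expert[e])  # leave padding slots pointing at the zero row

    x = torch.vstack((x, x.new_zeros(x.shape[-1])))  # extra all-zero row used for padding
    input_shape = x.shape
    return input_shape, x[permuted_indices, :], permuted_indices, padded

def unpermute(out, input_shape, permuted_indices):
    out_unpermuted = out.new_empty(input_shape)
    out_unpermuted[permuted_indices, :] = out  # all padding slots land on the extra row
    return out_unpermuted[:-1]  # drop the extra row to recover the original layout

The appended all-zero row is what every padding slot reads from during the gather, which is why _unpermute can simply scatter the expert outputs back and drop the last row.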
1 change: 0 additions & 1 deletion torchtitan/experiments/kernels/moe/dispatch.py
@@ -88,7 +88,6 @@ def forward( # type: ignore[no-untyped-def]
out_splits_offsets, grad_out_buf, grad_in_buf, grad_in_splits_offsets
)
ctx.group_name = group_name
return out

@staticmethod
def backward( # type: ignore[no-untyped-def]
13 changes: 6 additions & 7 deletions torchtitan/experiments/llama4/infra/parallelize.py
@@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
@@ -22,6 +21,7 @@
from torchtitan.config.job_config import Compile as CompileConfig
from torchtitan.distributed import NoParallel, ParallelDims
from torchtitan.distributed.activation_checkpoint import apply_ac

from torchtitan.distributed.expert_parallel import (
ExpertParallel,
ExpertTensorParallel,
@@ -441,6 +441,8 @@ def apply_moe_ep_tp(
ep_tp_mesh: DeviceMesh | None,
etp_enabled: bool,
):
assert ep_mesh is not None or tp_mesh is not None

for transformer_block in model.layers.values():
if not transformer_block.moe_enabled:
continue
@@ -486,16 +488,13 @@ def apply_moe_ep_tp(
experts_mesh = tp_mesh
# input Replicate, output Partial
experts_plan = TensorParallel()
elif tp_mesh is None:
elif tp_mesh is None or not etp_enabled:
experts_mesh = ep_mesh
# input / output sharding on the batch / tokens dim
experts_plan = ExpertParallel()
elif etp_enabled:
experts_mesh = ep_tp_mesh
experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)
else:
experts_mesh = ep_mesh
experts_plan = ExpertParallel()
experts_mesh = ep_tp_mesh
experts_plan = ExpertTensorParallel()
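# Summary of the selection above: TP only -> TensorParallel on tp_mesh;
# EP only, or EP + TP with ETP disabled -> ExpertParallel on ep_mesh;
# EP + TP with ETP enabled -> ExpertTensorParallel on ep_tp_mesh.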

parallelize_module(
module=transformer_block.moe.experts,