[AutoParallel] Add auto parallel moe layer #9886

Open · wants to merge 3 commits into base: develop
61 changes: 61 additions & 0 deletions paddlenlp/transformers/auto_utils.py
@@ -0,0 +1,61 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.distributed as dist


def get_mesh(pp_idx=0):
"""
Get the mesh for pipeline stage pp_idx.
"""
mesh = dist.fleet.auto.get_mesh()
if "pp" in mesh.dim_names:
mesh = mesh.get_mesh_with_dim("pp", pp_idx)
return mesh


def einsum(rule, a, b):
"""
Replace einsum with basic ops for the supported rules. The implementation
is adapted from https://github.com/deepspeedai/DeepSpeed.
"""
if rule == "s,se->se":
return a.reshape([a.shape[0], -1]) * b
elif rule == "se,sc->sec":
return a.unsqueeze(2) * b.unsqueeze(1)
elif rule == "se,se->s":
return paddle.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape([-1])
elif rule == "se,sec->sec":
return paddle.unsqueeze(a, axis=2) * b
elif rule == "sec,sm->ecm":
s = a.shape[0]
e = a.shape[1]
c = a.shape[2]
m = b.shape[1]
return paddle.matmul(a.reshape([s, -1]).t(), b).reshape([e, c, m])
elif rule == "sec,ecm->sm":
return paddle.matmul(a.reshape([a.shape[0], -1]), b.reshape([-1, b.shape[-1]]))
elif rule == "ks,ksm->sm":
k = b.shape[0]
s = b.shape[1]
m = b.shape[2]

# [k, s] -> [s, k] -> [s, 1, k]
a = a.t().unsqueeze(1)

# [k,s,m] -> [k, sm] -> [sm, k] -> [s, m, k]
b = b.reshape([k, -1]).t().reshape([s, m, k])

# bmm([s, 1, k], [s, m, k]^T) -> [s, 1, m] -> squeeze -> [s, m]
return paddle.bmm(a, b.transpose([0, 2, 1])).squeeze(1)

else:
return paddle.einsum(rule, a, b)
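
For reviewers who want to sanity-check the decomposed rules above, a minimal sketch (not part of this diff) that compares them against paddle.einsum on small random tensors; the sizes are illustrative only:

import paddle
from paddlenlp.transformers.auto_utils import einsum

s, e, c, m = 4, 8, 2, 16
cases = [
    ("s,se->se", paddle.rand([s]), paddle.rand([s, e])),
    ("se,sc->sec", paddle.rand([s, e]), paddle.rand([s, c])),
    ("se,se->s", paddle.rand([s, e]), paddle.rand([s, e])),
    ("se,sec->sec", paddle.rand([s, e]), paddle.rand([s, e, c])),
    ("sec,sm->ecm", paddle.rand([s, e, c]), paddle.rand([s, m])),
    ("sec,ecm->sm", paddle.rand([s, e, c]), paddle.rand([e, c, m])),
]
for rule, a, b in cases:
    # each decomposed rule should reproduce the reference paddle.einsum result
    assert paddle.allclose(einsum(rule, a, b), paddle.einsum(rule, a, b), atol=1e-5), rule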

140 changes: 108 additions & 32 deletions paddlenlp/transformers/deepseek_v2/modeling_auto.py
@@ -25,11 +25,14 @@
from typing import List, Optional, Tuple, Union

import paddle
import paddle.distributed as dist
import paddle.nn.functional as F
from paddle import Tensor, nn
from paddle.distributed.fleet.utils import recompute
from paddle.nn import Linear

from ..auto_utils import get_mesh

try:
from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
@@ -40,26 +43,23 @@
except:
flash_attention = None

import paddle.distributed as dist

from ...utils.log import logger
from ...utils.tools import get_env_device
from ..activations import ACT2FN
from ..llama import fusion_ops
from ..llama.modeling import get_use_casual_mask
from ..model_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ..model_utils import PretrainedModel, register_base_model
from ..moe_layer import MoELayer
from ..moe_gate_auto import PretrainedMoEGate
from ..moe_layer_auto import MoELayer
from .configuration import DeepseekV2Config
from .modeling import (
AddAuxiliaryLoss,
DeepseekV2DynamicNTKScalingRotaryEmbedding,
DeepseekV2LinearScalingRotaryEmbedding,
DeepseekV2PretrainingCriterion,
DeepseekV2RMSNorm,
DeepseekV2RotaryEmbedding,
DeepseekV2YarnRotaryEmbedding,
MoEGate,
_expand_2d_mask,
_make_causal_mask,
apply_rotary_pos_emb,
@@ -117,13 +117,13 @@
)

if isinstance(outputs, tuple):
outputs[0] = outputs[0].reshape([bsz, q_len, v_num_heads, head_dim])
outputs[0] = outputs[0].reshape([bsz, kv_seq_len, v_num_heads, head_dim])

outputs[0] = outputs[0][..., :v_head_dim]
outputs[0] = outputs[0].reshape([bsz, q_len, -1])
outputs[0] = outputs[0].reshape([bsz, kv_seq_len, -1])

else:
outputs = outputs.reshape([bsz, q_len, v_num_heads, head_dim])
outputs = outputs.reshape([bsz, kv_seq_len, v_num_heads, head_dim])

outputs = outputs[..., :v_head_dim]
outputs = outputs.reshape([bsz, q_len, -1])
outputs = outputs.reshape([bsz, kv_seq_len, -1])

return outputs

else:
@@ -169,8 +169,72 @@
return (attn_output, attn_weights) if output_attentions else attn_output


class MoEGate(PretrainedMoEGate):
def __init__(self, config, num_experts, expert_hidden_size, **kwargs):
super().__init__(config, num_experts, expert_hidden_size, **kwargs)

# [hidden_size, n_expert]

self.scoring_func = config.scoring_func
self.topk_method = config.topk_method


self.weight = paddle.create_parameter(

shape=[expert_hidden_size, num_experts],
dtype=paddle.get_default_dtype(),
is_bias=False,
default_initializer=nn.initializer.Constant(1.0),
)

if config.topk_method == "noaux_tc":
self.e_score_correction_bias = paddle.create_parameter(

shape=[num_experts],
dtype=paddle.get_default_dtype(),
default_initializer=nn.initializer.Constant(0.0),
)

def forward(self, hidden_states):
"""
Args:
hidden_states (Tensor): shape [batch_size * seq_len, hidden_size]
"""
_, h_dim = hidden_states.shape


# compute gating score
logits = F.linear(hidden_states, self.weight, None)


with paddle.amp.auto_cast(False):
scores = self.gate_score_func(logits=logits)
scores = scores.cast(paddle.get_default_dtype())


capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.topkgating(scores)


return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss
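
To illustrate how these gate outputs are consumed, a rough self-contained sketch (not part of this diff; the real dispatch/combine logic lives in moe_layer_auto.MoELayer, and the tensors below are random stand-ins):

import paddle
from paddlenlp.transformers.auto_utils import einsum

# illustrative sizes: s tokens, e experts, capacity c per expert, hidden size m
s, e, c, m = 8, 4, 2, 16
hidden_states = paddle.rand([s, m])
combine_weights = paddle.rand([s, e, c])                 # stand-in for the gate's combine_weights
dispatch_mask = (combine_weights > 0.5).cast("float32")  # stand-in for dispatch_mask

# dispatch: scatter tokens into per-expert capacity slots -> [e, c, m]
dispatched_input = einsum("sec,sm->ecm", dispatch_mask, hidden_states)
# each expert would run its MLP over its [c, m] slice; identity here for brevity
expert_output = dispatched_input
# combine: weight the expert outputs back into per-token hidden states -> [s, m]
output = einsum("sec,ecm->sm", combine_weights, expert_output)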



class AddAuxiliaryLoss(paddle.autograd.PyLayer):
"""
Trick layer that attaches the auxiliary (aux) loss to the forward output
so that the gradient of the aux loss is included during backpropagation.
"""

@staticmethod
def forward(ctx, x, loss):
assert paddle.numel(loss) == 1
ctx.dtype = loss.dtype
ctx.required_aux_loss = not loss.stop_gradient
return x

Check warning on line 224 in paddlenlp/transformers/deepseek_v2/modeling_auto.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/deepseek_v2/modeling_auto.py#L221-L224

Added lines #L221 - L224 were not covered by tests

@staticmethod
def backward(ctx, grad_output):
grad_loss = None
if ctx.required_aux_loss:

# grad_loss = paddle.ones(1, dtype=ctx.dtype)
grad_loss = paddle.to_tensor(1, dtype=ctx.dtype)
grad_loss = dist.shard_tensor(grad_loss, get_mesh(), [dist.Partial(dist.ReduceType.kRedAvg)])
return grad_output, grad_loss
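
A minimal usage sketch (hypothetical; final_hidden_states, l_aux, and self.alpha stand for values produced by the surrounding MoE forward, mirroring the non-auto DeepseekV2 modeling):

# attach the aux loss so its gradient flows in backward while the forward
# value of the hidden states is left unchanged
if self.training and self.alpha > 0.0:
    l_aux = l_aux * self.alpha
    final_hidden_states = AddAuxiliaryLoss.apply(final_hidden_states, l_aux)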



class DeepseekV2MLPAuto(nn.Layer):
def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None):
def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None, is_moe=False):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
@@ -182,6 +246,22 @@

self.act_fn = ACT2FN[config.hidden_act]

def redistribute_expert(self, mesh, placements):
"""
Place the experts on different devices.
"""
self.gate_proj.weight = dist.shard_tensor(self.gate_proj.weight, mesh, placements)
if self.gate_proj.bias is not None:
self.gate_proj.bias = dist.shard_tensor(self.gate_proj.bias, mesh, placements)


self.up_proj.weight = dist.shard_tensor(self.up_proj.weight, mesh, placements)
if self.up_proj.bias is not None:
self.up_proj.bias = dist.shard_tensor(self.up_proj.bias, mesh, placements)


self.down_proj.weight = dist.shard_tensor(self.down_proj.weight, mesh, placements)
if self.down_proj.bias is not None:
self.down_proj.bias = dist.shard_tensor(self.down_proj.bias, mesh, placements)


def forward(self, x):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
@@ -217,7 +297,7 @@
self.alpha = config.aux_loss_alpha
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekV2MLPAuto(config=config, intermediate_size=intermediate_size)
self.shared_experts = DeepseekV2MLPAuto(config=config, intermediate_size=intermediate_size, is_moe=True)


def forward(self, hidden_states):
final_hidden_states, l_aux, l_zloss = super().forward(hidden_states)
@@ -389,13 +469,13 @@
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

query_states = paddle.empty([bsz, q_len, self.num_heads, self.q_head_dim], dtype=self.config.dtype)
query_states = paddle.concat([q_nope, q_pe], axis=-1)
query_states = paddle.concat([q_nope, q_pe], axis=3)

# query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
# query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

key_states = paddle.empty([bsz, q_len, self.num_heads, self.q_head_dim], dtype=self.config.dtype)
# input[0]'s shape = [1, 2048, 16, 128], input[1]'s shape = [1, 2048, 1, 64].
key_states = paddle.concat([k_nope, k_pe.expand([bsz, q_len, self.num_heads, k_pe.shape[-1]])], axis=-1)
key_states = paddle.concat([k_nope, k_pe.expand([bsz, q_len, self.num_heads, k_pe.shape[-1]])], axis=3)


# key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
# key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
@@ -972,23 +1052,19 @@
def auto_dist_config(self, prefix=""):
if prefix != "":
assert prefix.endswith(".")
config = {
"dp_config": {"sharding_level": 1, "offload": False, "exclude_layer": None},
"mp_config": {
"parallelize_plan": {
f"{prefix}deepseek_v2.embed_tokens": dist.ColWiseParallel(gather_output=True),
f"{prefix}deepseek_v2.layers.*.self_attn.q_b_proj": dist.ColWiseParallel(),
f"{prefix}deepseek_v2.layers.*.self_attn.q_proj": dist.ColWiseParallel(),
f"{prefix}deepseek_v2.layers.*.self_attn.kv_b_proj": dist.ColWiseParallel(),
f"{prefix}deepseek_v2.layers.*.self_attn.o_proj": dist.RowWiseParallel(),
f"{prefix}deepseek_v2.layers.*.mlp.gate_proj": dist.ColWiseParallel(),
f"{prefix}deepseek_v2.layers.*.mlp.up_proj": dist.ColWiseParallel(),
f"{prefix}deepseek_v2.layers.*.mlp.down_proj": dist.RowWiseParallel(),
f"{prefix}deepseek_v2.layers.*.mlp.shared_experts.gate_proj": dist.ColWiseParallel(),
f"{prefix}deepseek_v2.layers.*.mlp.shared_experts.up_proj": dist.ColWiseParallel(),
f"{prefix}deepseek_v2.layers.*.mlp.shared_experts.down_proj": dist.RowWiseParallel(),
f"{prefix}lm_head.weight": dist.ColWiseParallel(),
}
},
}
config = {}

# config = {
# "mp_config": {
# "parallelize_plan": {
# f"{prefix}deepseek_v2.embed_tokens": dist.ColWiseParallel(gather_output=True),
# f"{prefix}deepseek_v2.layers.*.self_attn.q_proj": dist.ColWiseParallel(),
# f"{prefix}deepseek_v2.layers.*.self_attn.kv_b_proj": dist.ColWiseParallel(),
# f"{prefix}deepseek_v2.layers.*.self_attn.o_proj": dist.RowWiseParallel(),
# f"{prefix}deepseek_v2.layers.*.mlp.gate_proj": dist.ColWiseParallel(),
# f"{prefix}deepseek_v2.layers.*.mlp.up_proj": dist.ColWiseParallel(),
# f"{prefix}deepseek_v2.layers.*.mlp.down_proj": dist.RowWiseParallel(),
# f"{prefix}lm_head.weight": dist.ColWiseParallel(),
# }
# },
# }
return config
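
For readers new to the placement API used throughout this file (dist.shard_tensor with Replicate/Shard/Partial placements on a ProcessMesh), a small standalone sketch under an assumed two-rank mesh — not the configuration this PR sets up, and it needs to be started with paddle.distributed.launch to actually run:

import paddle
import paddle.distributed as dist

# assumed 1-D mesh over 2 ranks; in the PR the mesh comes from auto_utils.get_mesh()
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

w = paddle.create_parameter([1024, 4096], dtype="float32")
# replicate an expert weight on every rank of the mesh; passing dist.Shard(0) instead
# would split dim 0 across ranks, which is what redistribute_expert's placements control
w = dist.shard_tensor(w, mesh, [dist.Replicate()])

# partial-with-average placement: each rank holds a partial value that is averaged on
# reduction, as AddAuxiliaryLoss.backward does for the aux-loss gradient
grad_loss = dist.shard_tensor(paddle.to_tensor(1.0), mesh, [dist.Partial(dist.ReduceType.kRedAvg)])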