From 81a9849f198ec67a09e547facf094e70f575b4f7 Mon Sep 17 00:00:00 2001 From: Yichen Zhang Date: Mon, 17 Feb 2025 21:34:55 +0800 Subject: [PATCH 1/3] add auto parallel moe layer --- .../transformers/deepseek_v2/modeling_auto.py | 119 +++- paddlenlp/transformers/moe_gate_auto.py | 559 ++++++++++++++++++ paddlenlp/transformers/moe_layer_auto.py | 368 ++++++++++++ 3 files changed, 1014 insertions(+), 32 deletions(-) create mode 100644 paddlenlp/transformers/moe_gate_auto.py create mode 100644 paddlenlp/transformers/moe_layer_auto.py diff --git a/paddlenlp/transformers/deepseek_v2/modeling_auto.py b/paddlenlp/transformers/deepseek_v2/modeling_auto.py index 284b12a29cb8..0714a2a24ad8 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling_auto.py +++ b/paddlenlp/transformers/deepseek_v2/modeling_auto.py @@ -40,8 +40,6 @@ except: flash_attention = None -import paddle.distributed as dist - from ...utils.log import logger from ...utils.tools import get_env_device from ..activations import ACT2FN @@ -49,17 +47,16 @@ from ..llama.modeling import get_use_casual_mask from ..model_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ..model_utils import PretrainedModel, register_base_model -from ..moe_layer import MoELayer +from ..moe_gate_auto import PretrainedMoEGate +from ..moe_layer_auto import MoELayer from .configuration import DeepseekV2Config from .modeling import ( - AddAuxiliaryLoss, DeepseekV2DynamicNTKScalingRotaryEmbedding, DeepseekV2LinearScalingRotaryEmbedding, DeepseekV2PretrainingCriterion, DeepseekV2RMSNorm, DeepseekV2RotaryEmbedding, DeepseekV2YarnRotaryEmbedding, - MoEGate, _expand_2d_mask, _make_causal_mask, apply_rotary_pos_emb, @@ -117,13 +114,13 @@ def scaled_dot_product_attention( ) if isinstance(outputs, tuple): - outputs[0] = outputs[0].reshape([bsz, q_len, v_num_heads, head_dim]) + outputs[0] = outputs[0].reshape([bsz, kv_seq_len, v_num_heads, head_dim]) outputs[0] = outputs[0][..., :v_head_dim] - outputs[0] = outputs[0].reshape([bsz, q_len, -1]) + outputs[0] = outputs[0].reshape([bsz, kv_seq_len, -1]) else: - outputs = outputs.reshape([bsz, q_len, v_num_heads, head_dim]) + outputs = outputs.reshape([bsz, kv_seq_len, v_num_heads, head_dim]) outputs = outputs[..., :v_head_dim] - outputs = outputs.reshape([bsz, q_len, -1]) + outputs = outputs.reshape([bsz, kv_seq_len, -1]) return outputs else: @@ -169,8 +166,70 @@ def scaled_dot_product_attention( return (attn_output, attn_weights) if output_attentions else attn_output +class MoEGate(PretrainedMoEGate): + def __init__(self, config, num_experts, expert_hidden_size, **kwargs): + super().__init__(config, num_experts, expert_hidden_size, **kwargs) + # [hidden_size, n_expert] + + self.scoring_func = config.scoring_func + self.topk_method = config.topk_method + + self.weight = paddle.create_parameter( + shape=[expert_hidden_size, num_experts], + dtype=paddle.get_default_dtype(), + is_bias=False, + default_initializer=nn.initializer.Constant(1.0), + ) + + if config.topk_method == "noaux_tc": + self.e_score_correction_bias = paddle.create_parameter( + shape=[num_experts], + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.Constant(0.0), + ) + + def forward(self, hidden_states): + """ + Args: + hidden_states (_type_): [batch_size * seq_len, hidden_size] + """ + _, h_dim = hidden_states.shape + + # compute gating score + logits = F.linear(hidden_states, self.weight, None) + + with paddle.amp.auto_cast(False): + scores = self.gate_score_func(logits=logits) + scores = 
scores.cast(paddle.get_default_dtype()) + + capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.topkgating(scores) + + return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss + + +class AddAuxiliaryLoss(paddle.autograd.PyLayer): + """ + The trick function of adding auxiliary (aux) loss, + which includes the gradient of the aux loss during backpropagation. + """ + + @staticmethod + def forward(ctx, x, loss): + assert paddle.numel(loss) == 1 + ctx.dtype = loss.dtype + ctx.required_aux_loss = not loss.stop_gradient + return x + + @staticmethod + def backward(ctx, grad_output): + grad_loss = None + if ctx.required_aux_loss: + grad_loss = paddle.ones(1, dtype=ctx.dtype) + return grad_output, grad_loss + + class DeepseekV2MLPAuto(nn.Layer): - def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None): + def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None, is_moe=False): super().__init__() self.config = config self.hidden_size = config.hidden_size if hidden_size is None else hidden_size @@ -217,7 +276,7 @@ def __init__(self, config: DeepseekV2Config): self.alpha = config.aux_loss_alpha if config.n_shared_experts is not None: intermediate_size = config.moe_intermediate_size * config.n_shared_experts - self.shared_experts = DeepseekV2MLPAuto(config=config, intermediate_size=intermediate_size) + self.shared_experts = DeepseekV2MLPAuto(config=config, intermediate_size=intermediate_size, is_moe=True) def forward(self, hidden_states): final_hidden_states, l_aux, l_zloss = super().forward(hidden_states) @@ -389,13 +448,13 @@ def forward( q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) query_states = paddle.empty([bsz, q_len, self.num_heads, self.q_head_dim], dtype=self.config.dtype) - query_states = paddle.concat([q_nope, q_pe], axis=-1) + query_states = paddle.concat([q_nope, q_pe], axis=3) # query_states[:, :, :, : self.qk_nope_head_dim] = q_nope # query_states[:, :, :, self.qk_nope_head_dim :] = q_pe key_states = paddle.empty([bsz, q_len, self.num_heads, self.q_head_dim], dtype=self.config.dtype) # input[0]'s shape = [1, 2048, 16, 128], input[1]'s shape = [1, 2048, 1, 64]. 
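        # k_pe carries a single rotary head, so it is broadcast (expanded) across all attention
        # heads before being concatenated with k_nope along the head-dim axis.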
- key_states = paddle.concat([k_nope, k_pe.expand([bsz, q_len, self.num_heads, k_pe.shape[-1]])], axis=-1) + key_states = paddle.concat([k_nope, k_pe.expand([bsz, q_len, self.num_heads, k_pe.shape[-1]])], axis=3) # key_states[:, :, :, : self.qk_nope_head_dim] = k_nope # key_states[:, :, :, self.qk_nope_head_dim :] = k_pe @@ -972,23 +1031,19 @@ def _reorder_cache(past_key_values, beam_idx): def auto_dist_config(self, prefix=""): if prefix != "": assert prefix.endswith(".") - config = { - "dp_config": {"sharding_level": 1, "offload": False, "exclude_layer": None}, - "mp_config": { - "parallelize_plan": { - f"{prefix}deepseek_v2.embed_tokens": dist.ColWiseParallel(gather_output=True), - f"{prefix}deepseek_v2.layers.*.self_attn.q_b_proj": dist.ColWiseParallel(), - f"{prefix}deepseek_v2.layers.*.self_attn.q_proj": dist.ColWiseParallel(), - f"{prefix}deepseek_v2.layers.*.self_attn.kv_b_proj": dist.ColWiseParallel(), - f"{prefix}deepseek_v2.layers.*.self_attn.o_proj": dist.RowWiseParallel(), - f"{prefix}deepseek_v2.layers.*.mlp.gate_proj": dist.ColWiseParallel(), - f"{prefix}deepseek_v2.layers.*.mlp.up_proj": dist.ColWiseParallel(), - f"{prefix}deepseek_v2.layers.*.mlp.down_proj": dist.RowWiseParallel(), - f"{prefix}deepseek_v2.layers.*.mlp.shared_experts.gate_proj": dist.ColWiseParallel(), - f"{prefix}deepseek_v2.layers.*.mlp.shared_experts.up_proj": dist.ColWiseParallel(), - f"{prefix}deepseek_v2.layers.*.mlp.shared_experts.down_proj": dist.RowWiseParallel(), - f"{prefix}lm_head.weight": dist.ColWiseParallel(), - } - }, - } + config = {} + # config = { + # "mp_config": { + # "parallelize_plan": { + # f"{prefix}deepseek_v2.embed_tokens": dist.ColWiseParallel(gather_output=True), + # f"{prefix}deepseek_v2.layers.*.self_attn.q_proj": dist.ColWiseParallel(), + # f"{prefix}deepseek_v2.layers.*.self_attn.kv_b_proj": dist.ColWiseParallel(), + # f"{prefix}deepseek_v2.layers.*.self_attn.o_proj": dist.RowWiseParallel(), + # f"{prefix}deepseek_v2.layers.*.mlp.gate_proj": dist.ColWiseParallel(), + # f"{prefix}deepseek_v2.layers.*.mlp.up_proj": dist.ColWiseParallel(), + # f"{prefix}deepseek_v2.layers.*.mlp.down_proj": dist.RowWiseParallel(), + # f"{prefix}lm_head.weight": dist.ColWiseParallel(), + # } + # }, + # } return config diff --git a/paddlenlp/transformers/moe_gate_auto.py b/paddlenlp/transformers/moe_gate_auto.py new file mode 100644 index 000000000000..0cd9fd62a8e7 --- /dev/null +++ b/paddlenlp/transformers/moe_gate_auto.py @@ -0,0 +1,559 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
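+"""Gate mixin and base classes for MoE routing: scoring functions, capacity computation, auxiliary/z losses, and top-1/top-2/top-k gating."""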
+from __future__ import annotations + +import traceback +from typing import Tuple + +import paddle +import paddle.distributed as dist +import paddle.nn as nn +import paddle.nn.functional as F + +from ..utils.log import logger + + +class MoEGateMixin: + def gate_score_func(self, logits: paddle.Tensor) -> paddle.Tensor: + # [..., hidden_dim] -> [..., num_experts] + with paddle.amp.auto_cast(False): + scoring_func = getattr(self, "scoring_func", None) + if scoring_func == "softmax": + scores = F.softmax(logits.cast("float32"), axis=-1) + elif scoring_func == "sigmoid": + scores = F.sigmoid(logits.cast("float32")) + elif scoring_func == "tanh": + scores = F.tanh(logits.cast("float32")) + elif scoring_func == "relu": + scores = F.relu(logits.cast("float32")) + elif scoring_func == "gelu": + scores = F.gelu(logits.cast("float32")) + elif scoring_func == "leaky_relu": + scores = F.leaky_relu(logits.cast("float32")) + else: + logger.warning_once( + f"insupportable scoring function for MoE gating: {scoring_func}, use softmax instead" + ) + scores = F.softmax(logits.cast("float32"), axis=-1) + return scores + + def gumbel_rsample(self, logits: paddle.Tensor) -> paddle.Tensor: + gumbel = paddle.distribution.gumbel.Gumbel(0, 1) + return gumbel.rsample(logits.shape) + + def uniform_sample(self, logits: paddle.Tensor) -> paddle.Tensor: + uniform = paddle.distribution.uniform.Uniform(0, 1) + return uniform.sample(logits.shape) + + @paddle.no_grad() + def _one_hot_to_float(self, x, num_classes): + if x.dtype not in (paddle.int32, paddle.int64): + x = paddle.cast(x, paddle.int64) + return F.one_hot(x, num_classes=num_classes).cast(paddle.get_default_dtype()) + + @paddle.no_grad() + def _one_hot_to_int64(self, x, num_classes): + if x.dtype not in (paddle.int32, paddle.int64): + x = paddle.cast(x, paddle.int64) + return F.one_hot(x, num_classes=num_classes).cast(paddle.int64) + + @paddle.no_grad() + def _capacity( + self, + gates: paddle.Tensor, + capacity_factor: float, + max_capacity: int, + min_capacity: int, + ) -> paddle.Tensor: + """Calculate the capacity for each expert based on the gates and capacity factor. + + Args: + gates (paddle.Tensor): A tensor of shape [num_tokens, num_experts] representing the probability distribution + over experts for each token. + capacity_factor (float): A scalar float value representing the capacity factor for each expert. + min_capacity (int): A scalar integer value representing the minimum capacity for each expert. + + Returns: + int: A tensor value representing the calculated capacity for each expert. + """ + assert gates.ndim == 2, f"gates should be 2D, but got {gates.ndim}, {gates.shape}" + # gates has shape of SE + num_tokens = gates.shape[0] + num_experts = gates.shape[1] + print(f"==== num_tokens:{num_tokens}, num_experts:{num_experts} ====") + capacity = int((num_tokens // num_experts) * capacity_factor) + if capacity < min_capacity: + capacity = min_capacity + if capacity > max_capacity: + capacity = max_capacity + assert capacity > 0, f"requires capacity > 0, capacity_factor: {capacity_factor}, input_shape: {gates.shape}" + + return capacity + + def _cal_aux_loss(self, gates, mask): + """ + Calculate auxiliary loss + + Args: + gates (paddle.Tensor): Represents the output probability of each expert. The shape is [batch_size, num_experts] + mask (paddle.Tensor): Represents whether each sample belongs to a certain expert. The shape is [batch_size, num_experts] + + Returns: + paddle.Tensor: The value of auxiliary loss. 
+ + """ + # TODO: @DrownFish19 update aux_loss for Qwen2MoE and DeepSeekV2&V3 + me = paddle.mean(gates, axis=0) + ce = paddle.mean(mask.cast("float32"), axis=0) + if self.global_aux_loss: + me_list, ce_list = [], [] + # dist.all_gather(me_list, me, group=self.group) + # dist.all_gather(ce_list, ce, group=self.group) + + me_list[self.rank] = me + ce_list[self.rank] = ce + me = paddle.stack(me_list).mean(0) + ce = paddle.stack(ce_list).mean(0) + aux_loss = paddle.sum(me * ce) * float(self.num_experts) + return aux_loss + + def _cal_z_loss(self, logits) -> paddle.Tensor: + """ + Calculate the z loss. + + Args: + logits (paddle.Tensor): Model output. The shape is [batch_size, num_experts]. + + Returns: + paddle.Tensor: The z loss value. + """ + l_zloss = logits.exp().sum(1).log().square().mean() + return l_zloss + + def _cal_orthogonal_loss(self) -> paddle.Tensor: + """Gate weight orthogonal loss. + + Returns: + Paddle.Tensor: orthogonal loss + """ + weight = F.normalize(self.weight, axis=0) + orthogonal_loss = paddle.mean(paddle.square(paddle.matmul(weight.T, weight) - paddle.eye(self.num_experts))) + return orthogonal_loss + + +class PretrainedMoEGate(nn.Layer, MoEGateMixin): + def __init__(self, config, num_experts, expert_hidden_size, **kwargs): + super(PretrainedMoEGate, self).__init__() + + self.config = config + + self.num_experts = num_experts + self.expert_hidden_size = expert_hidden_size + + # force keep in float32 when using amp + self._cast_to_low_precision = False + + self.capacity_factor = kwargs.pop("capacity_factor", 1.0) + self.eval_capacity_factor = kwargs.pop("eval_capacity_factor", 1.0) + self.min_capacity = kwargs.pop("min_capacity", 1.0) + self.max_capacity = kwargs.pop("max_capacity", pow(2, 32)) + + self.group = kwargs.pop("group", None) + self.global_aux_loss = kwargs.pop("global_aux_loss", False) + if self.global_aux_loss: + assert self.group is not None, "group is required when global_aux_loss is True" + self.rank = dist.get_rank(self.group) + + self.expert_drop = kwargs.pop("expert_drop", False) + self.noisy_gate_policy = kwargs.pop("noisy_gate_policy", None) + self.drop_tokens = kwargs.pop("drop_tokens", True) + self.use_rts = kwargs.pop("use_rts", True) + self.top2_2nd_expert_sampling = kwargs.pop("top2_2nd_expert_sampling", True) + + self.drop_policy = kwargs.pop("drop_policy", "probs") + # Qwen2MoE: greedy + # DeepSeekV2&V3: group_limited_greedy for training, and noaux_tc for inference + self.topk_method = kwargs.pop("topk_method", "greedy") + self.top_k = kwargs.pop("top_k", 2) + self.n_group = kwargs.pop("n_group", 1) # for group_limited_greedy + self.topk_group = kwargs.pop("topk_group", 1) # for group_limited_greedy + self.norm_topk_prob = kwargs.pop("norm_topk_prob", False) + self.routed_scaling_factor = kwargs.pop("routed_scaling_factor", 1.0) + + def _priority(self, topk_idx: paddle.Tensor, capacity: int) -> paddle.Tensor: + """_summary_ + The priority is the cumulative sum of the expert indices. + + This method is used in hunyuan model + Args: + topk_idx (paddle.Tensor): [batch_size * seq_len, topk] + + Returns: + paddle.Tensor: cumsum locations + """ + # Make num_selected_experts the leading axis to ensure that top-1 choices + # have priority over top-2 choices, which have priority over top-3 choices, + # etc. + expert_index = paddle.transpose(topk_idx, [1, 0]) # [topk, B*S] + # Shape: [num_selected_experts * tokens_per_group] + expert_index = expert_index.reshape([-1]) + + # Create mask out of indices. 
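+        # Illustrative example (hypothetical values): with num_experts=2 and a flattened
+        # expert_index of [0, 1, 0, 0], the masked cumulative sum computed below yields slot
+        # indices [0, 0, 1, 2], i.e. each token's position inside its target expert's buffer.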
+ # Shape: [tokens_per_group * num_selected_experts, num_experts]. + expert_mask = F.one_hot(expert_index, self.num_experts).cast(paddle.int32) + + # Experts have a fixed capacity that we cannot exceed. A token's priority + # within the expert's buffer is given by the masked, cumulative capacity of + # its target expert. + # Shape: [tokens_per_group * num_selected_experts, num_experts]. + token_priority = paddle.cumsum(expert_mask, axis=0) * expert_mask - 1 + # Shape: [num_selected_experts, tokens_per_group, num_experts]. + token_priority = token_priority.reshape((self.top_k, -1, self.num_experts)) + # Shape: [tokens_per_group, num_selected_experts, num_experts]. + token_priority = paddle.transpose(token_priority, [1, 0, 2]) + # For each token, across all selected experts, select the only non-negative + # (unmasked) priority. Now, for group G routing to expert E, token T has + # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E + # is its targeted expert. + # Shape: [tokens_per_group, num_experts]. + token_priority = paddle.max(token_priority, axis=1) + + # Token T can only be routed to expert E if its priority is positive and + # less than the expert capacity. One-hot matrix will ignore indices outside + # the range [0, expert_capacity). + # Shape: [tokens_per_group, num_experts, expert_capacity]. + valid_mask = paddle.logical_and(token_priority >= 0, token_priority < capacity) + token_priority = paddle.masked_fill(token_priority, ~valid_mask, 0) + dispatch_mask = F.one_hot(token_priority, capacity).cast(paddle.int32) + valid_mask = valid_mask.unsqueeze(-1).expand(valid_mask.shape + [capacity]) + dispatch_mask = paddle.masked_fill(dispatch_mask, ~valid_mask, 0) + + return dispatch_mask + + def _topk_greedy(self, scores: paddle.Tensor, k: int) -> Tuple[paddle.Tensor, paddle.Tensor]: + """_summary_ + + Args: + scores (paddle.Tensor): [bsz*seq_len, n_experts] + k (int): select the top k experts + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: topk_weight, topk_idx + topk_weight: [bsz*seq_len, k] + topk_idx: [bsz*seq_len, k] + """ + topk_weight, topk_idx = paddle.topk(scores, k=k, axis=-1, sorted=True) + return topk_weight, topk_idx + + def _topk_group_limited_greedy( + self, scores: paddle.Tensor, k: int, n_group: int, topk_group: int + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """_summary_ + + Args: + scores (paddle.Tensor): [bsz*seq_len, n_experts] + k (int): select the top k experts in each group + n_groups (int): the number of groups for all experts + topk_group (int): the number of groups selected + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: topk_weight, topk_idx + topk_weight: [bsz*seq_len, k] + topk_idx: [bsz*seq_len, k] + + Note: the group size is normal greater than the number of k + """ + bsz_seq_len, n_experts = scores.shape + assert n_experts % n_group == 0, "n_experts must be divisible by n_groups" + + group_scores = scores.reshape([0, n_group, -1]).max(axis=-1) # [n, n_group] + group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=False)[1] # [n, top_k_group] + group_mask = paddle.zeros_like(group_scores).put_along_axis(group_idx, paddle.to_tensor(1.0), axis=-1) # fmt:skip + score_mask = ( + group_mask.unsqueeze(-1).expand([bsz_seq_len, n_group, n_experts // n_group]).reshape([bsz_seq_len, -1]) + ) # [n, e] + tmp_scores = scores * score_mask # [n, e] + topk_weight, topk_idx = paddle.topk(tmp_scores, k=k, axis=-1, sorted=False) + + return topk_weight, topk_idx + + def _topk_noaux_tc( + self, scores: paddle.Tensor, k: int, n_group: 
int, topk_group: int + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """_summary_ + + Args: + scores (paddle.Tensor): [bsz*seq_len, n_experts] + k (int): select the top k experts in each group + n_groups (int): the number of groups for all experts + topk_group (int): the number of groups selected + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: topk_weight, topk_idx + topk_weight: [bsz*seq_len, k] + topk_idx: [bsz*seq_len, k] + + Note: the group size is normal greater than the number of k + """ + bsz_seq_len, n_experts = scores.shape + assert n_experts % n_group == 0, "n_experts must be divisible by n_groups" + + assert self.e_score_correction_bias is not None, "e_score_correction_bias is None" + if self.e_score_correction_bias.is_dist(): + local_e_score_correction_bias = dist.auto_parallel.api.dtensor_to_local(self.e_score_correction_bias) + else: + local_e_score_correction_bias = self.e_score_correction_bias + # scores = scores.reshape([bsz_seq_len, -1]) + self.e_score_correction_bias.unsqueeze(0) + scores = scores.reshape([bsz_seq_len, -1]) + local_e_score_correction_bias.unsqueeze(0) + group_scores = scores.reshape([bsz_seq_len, self.n_group, -1]).topk(2, axis=-1)[0].sum(axis=-1) # [n, n_group] + group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=False)[1] # [n, top_k_group] + group_mask = paddle.zeros_like(group_scores).put_along_axis(group_idx, paddle.to_tensor(1.0), axis=-1) # fmt:skip + score_mask = ( + group_mask.unsqueeze(-1).expand([bsz_seq_len, n_group, n_experts // n_group]).reshape([bsz_seq_len, -1]) + ) # [n, e] + tmp_scores = scores * score_mask # [n, e] + topk_weight, topk_idx = paddle.topk(tmp_scores, k=k, axis=-1, sorted=False) + topk_weight = scores.gather(topk_idx, axis=1) if not self.training else topk_weight + + return topk_weight, topk_idx + + def top1gating( + self, + logits: paddle.Tensor, + used_token: paddle.Tensor = None, + ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Implements Top1Gating on logits.""" + if self.noisy_gate_policy == "RSample": + logits += self.gumbel_rsample(logits.shape) + + print("==== top1 ====") + traceback.print_stack() + gates = self.gate_score_func(logits=logits) + capacity = self._capacity(gates, self.capacity_factor, self.max_capacity, self.min_capacity) + + # Create a mask for 1st's expert per token + # noisy gating + # Only save the position of the maximum value + indices1_s = paddle.argmax(logits if self.noisy_gate_policy == "RSample" else gates, axis=1) + # Convert the position of the maximum value to a one-hot vector [s, e] + mask1 = self._one_hot_to_float(indices1_s, num_classes=self.num_experts) + + # mask only used tokens + if used_token is not None: + mask1 = paddle.einsum( + "s,se->se", used_token, mask1 + ) # Element-wise multiply used_token with mask1 to obtain a new mask1 + + # gating decisions + exp_counts = paddle.sum(mask1, axis=0) # Calculate the number of tokens for each expert + + # if we don't want to drop any tokens + if not self.drop_tokens: + new_capacity = paddle.max(exp_counts) # Calculate the number of tokens for each expert + # Communicate across expert processes to pick the maximum capacity. + if self.group is not None: + dist.all_reduce( + new_capacity, op=dist.ReduceOp.MAX, group=self.group + ) # Calculate the maximum value among expert processes + # Make sure the capacity value does not exceed the number of tokens. 
+ capacity = int(min(new_capacity, paddle.tensor(mask1.size(0)))) + + l_aux = self._cal_aux_loss(gates, mask1) + l_zloss = self._cal_z_loss(logits) + + # Random Token Selection + if self.use_rts: + mask1_rand = mask1 * self.uniform_sample(mask1) + else: + mask1_rand = mask1 + + assert ( + logits.shape[0] >= self.min_capacity + ), "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size." + + _, top_idx = paddle.topk(mask1_rand, k=capacity, axis=0) # Select top_capacity tokens + + new_mask1 = mask1 * paddle.zeros_like(mask1).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=0) + mask1 = new_mask1 + + # Compute locations in capacity buffer + locations1 = paddle.cumsum(mask1, axis=0) - 1 # Compute the position of each token in mask1 + + # Store the capacity location for each token + locations1_s = paddle.sum(locations1 * mask1, axis=1).cast(paddle.int64) + + # Normalize gate probabilities + mask1_float = mask1.cast(paddle.float32) + gates = gates / gates * mask1_float + + locations1_sc = self._one_hot_to_float(locations1_s, capacity) + combine_weights = paddle.einsum("se,sc->sec", gates, locations1_sc) + dispatch_mask = combine_weights.cast(paddle.bool).detach() + + return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss + + def top2gating( + self, + logits: paddle.Tensor, + ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + # everything is in fp32 in this function + print("==== top2 ====") + traceback.print_stack() + gates = self.gate_score_func(logits=logits) + + # Create a mask for 1st's expert per token. + indices1_s = paddle.argmax(gates, axis=1) # [S, 1] + mask1 = self._one_hot_to_int64(indices1_s, self.num_experts) # [S, E] + + if self.top2_2nd_expert_sampling: + # Create a mask for 2nd's expert per token using Gumbel-max trick. + # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ + logits += self.gumbel_rsample(logits) + + # Replace top-expert with min value + logits_except1 = logits.masked_fill(mask1.cast(paddle.bool), float("-inf")) # [S, E] + indices2_s = paddle.argmax(logits_except1, axis=1) # [S, 1] + mask2 = self._one_hot_to_int64(indices2_s, self.num_experts) # [S, E] + + # Note: mask1 and mask2 can be combined to form a single mask. + # mask = paddle.concat([mask1, mask2], axis=0) + # locations = paddle.cumsum(mask, axis=0) - 1 + # locations1, locations2 = locations.split(2, axis=0) + # Compute locations in capacity buffer. + locations1 = paddle.cumsum(mask1, axis=0) - 1 # [S, E] + locations2 = paddle.cumsum(mask2, axis=0) - 1 # [S, E] + # Update 2nd's location by accounting for locations of 1st. + locations2 += paddle.sum(mask1, axis=0, keepdim=True) + + l_aux = self._cal_aux_loss(gates, mask1) + l_zloss = self._cal_z_loss(logits) + + # gating decisions + exp_counts = paddle.sum(mask1 + mask2, axis=0) + if self.drop_tokens: + # Calculate configured capacity and remove locations outside capacity from mask + capacity = self._capacity(gates, self.capacity_factor, self.max_capacity, self.min_capacity) + # Remove locations outside capacity from mask. 
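+            # locations1/locations2 hold each token's slot index inside its chosen expert's
+            # buffer; tokens whose slot index is >= capacity have their mask entry zeroed,
+            # i.e. those tokens are dropped.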
+ mask1 *= (locations1 < capacity).cast(paddle.int64) + mask2 *= (locations2 < capacity).cast(paddle.int64) + else: + # Do not drop tokens - set capacity according to current expert assignments + new_capacity = paddle.max(exp_counts) + if self.group is not None: + dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=self.group) + capacity = int(new_capacity) + + # Store the capacity location for each token. + locations1_s = paddle.sum(locations1 * mask1, axis=1) + locations2_s = paddle.sum(locations2 * mask2, axis=1) + + # Normalize gate probabilities + mask1_float = mask1.cast(paddle.float32) + mask2_float = mask2.cast(paddle.float32) + gates1_s = paddle.einsum("se,se->s", gates, mask1_float) + gates2_s = paddle.einsum("se,se->s", gates, mask2_float) + denom_s = gates1_s + gates2_s + # Avoid divide-by-zero + denom_s = paddle.clip(denom_s, min=paddle.finfo(denom_s.dtype).eps) + gates1_s /= denom_s + gates2_s /= denom_s + + # Calculate combine_weights and dispatch_mask + gates1 = paddle.einsum("s,se->se", gates1_s, mask1_float) + gates2 = paddle.einsum("s,se->se", gates2_s, mask2_float) + locations1_sc = self._one_hot_to_float(locations1_s, capacity) + locations2_sc = self._one_hot_to_float(locations2_s, capacity) + combine1_sec = paddle.einsum("se,sc->sec", gates1, locations1_sc) + combine2_sec = paddle.einsum("se,sc->sec", gates2, locations2_sc) + combine_weights = combine1_sec + combine2_sec + dispatch_mask = combine_weights.cast(paddle.bool) + + return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss + + def topkgating( + self, + gates: paddle.Tensor, + ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Implements TopKGating on logits.""" + print("==== gates ====") + print(gates) + l_zloss = self._cal_z_loss(gates) + + # get topk gates + if self.topk_method == "greedy": + top_gate, top_idx = self._topk_greedy(gates, k=self.top_k) + elif self.topk_method == "group_limited_greedy": + top_gate, top_idx = self._topk_group_limited_greedy( + gates, k=self.top_k, n_group=self.n_group, topk_group=self.topk_group + ) + elif self.topk_method == "noaux_tc": + top_gate, top_idx = self._topk_noaux_tc( + gates, k=self.top_k, n_group=self.n_group, topk_group=self.topk_group + ) + # norm gate to sum 1 + print("==== top_gate and top_idx ====") + print(top_gate) + print(top_idx) + if self.top_k > 1 and self.norm_topk_prob: + denominator = top_gate.sum(axis=-1, keepdim=True) + 1e-20 + top_gate = top_gate / denominator + else: + top_gate = top_gate * self.routed_scaling_factor + + # get topk mask + mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=1) + print("==== mask ====") + print(mask) + l_aux = self._cal_aux_loss(gates, mask) + + exp_counts = paddle.sum(mask.cast(paddle.int64), axis=0) + + if self.drop_tokens: + # Calculate configured capacity and remove locations outside capacity from mask + capacity = self._capacity( + gates, + self.capacity_factor * self.top_k, + self.max_capacity, + self.min_capacity, + ) + + # update mask and locations by capacity + if self.drop_policy == "probs": + topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1) + capacity_probs, capacity_indices = paddle.topk(topk_masked_gates, k=capacity, axis=0, sorted=False) + token_priority = self._priority(capacity_indices, capacity) + + elif self.drop_policy == "position": + token_priority = self._priority(top_idx, capacity) + else: + raise ValueError(f"Invalid drop_policy: {self.drop_policy}") + 
else: + # Do not drop tokens - set capacity according to current expert assignments + local_capacity = paddle.max(exp_counts) + if self.group is not None: + dist.all_reduce(local_capacity, op=dist.ReduceOp.MAX, group=self.group) + capacity = int(local_capacity) + token_priority = self._priority(top_idx, capacity) + + # normalize gates + gates_masked = gates * mask + gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True) + denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps) + if self.norm_topk_prob: + gates_masked = gates_masked / denom_s + + combine_weights = paddle.einsum("se,sec->sec", gates_masked, token_priority.cast(paddle.get_default_dtype())) + dispatch_mask = combine_weights.cast(paddle.bool) + + return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss diff --git a/paddlenlp/transformers/moe_layer_auto.py b/paddlenlp/transformers/moe_layer_auto.py new file mode 100644 index 000000000000..61528db66188 --- /dev/null +++ b/paddlenlp/transformers/moe_layer_auto.py @@ -0,0 +1,368 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Any, Tuple + +import paddle +import paddle.distributed as dist +import paddle.nn.functional as F +from paddle import Tensor, nn +from paddle.distributed.communication import stream +from paddle.distributed.communication.group import Group + +from .moe_gate_auto import PretrainedMoEGate + + +def print_grad(g, name): + print(f"==== {name} ====") + print(g) + + +def dispatching(x, dispatch_mask, scatter_index, num_experts, capacity): + """ + Rearranges the input tensor `x` based on gate results, truncates it according to the specified capacity, and performs padding. + + Args: + x (Tensor)[Seq, Dim]: The input tensor. + dispatch_mask (List[Tensor[Seq, 1], Tensor[Seq, 1]]): A list of dispatch masks. + scatter_index (Union[List[Tensor[Seq,], Tensor[Seq]], Tensor[Seq, 2]]): A list or tensor representing scatter indices. + num_experts (int): The number of experts. + capacity (int): The capacity size. + + Returns: + Tensor [Expert*Capacity, Dim]: The output tensor after dispatching. 
+ """ + output = None + orig_dtype = x.dtype + if isinstance(scatter_index, paddle.Tensor): + scatter_index = scatter_index.unbind(1) + for i_scatter_index, i_dispatch_mask in zip(scatter_index, dispatch_mask): + init_output = paddle.zeros([num_experts * capacity, x.shape[-1]], dtype="float32") + updates = x * i_dispatch_mask.cast(x.dtype) + if output is None: + output = paddle.scatter( + init_output, + i_scatter_index, + updates, + overwrite=False, + ) + else: + output = output + paddle.scatter( + init_output, + i_scatter_index, + updates, + overwrite=False, + ) + if output.dtype != orig_dtype: + output = output.cast(orig_dtype) + return output + + +def combining(x, combine_weights, scatter_index): + """ + Performs combination and aggregation operations on the input matrix. + + Args: + x: Tensor[num_experts * capacity, dim] - The input matrix to be processed, where the last dimension represents the number of features. + combine_weights: Union[List[Tensor[seq, 1], Tensor[seq, 1]], Tensor[seq, 2, 1]] - A list or tensor containing combination weights for each feature. + scatter_index: Union[List[Tensor[seq], Tensor[seq]], Tensor[seq, 2]] - A tuple of indices indicating which elements are to be aggregated, where the first element is the row index and the second element is the column index. + + Returns: + Tensor: The output matrix after combination and aggregation, with a shape of [n, dim * num_features], where n is the number of samples in the input matrix. + """ + + dim = x.shape[-1] + if isinstance(scatter_index, (list, tuple)): + scatter_index = paddle.concat([i.unsqueeze([-1]) for i in scatter_index], -1) + scatter_index = scatter_index.reshape([-1]) + num_k = len(combine_weights) if isinstance(combine_weights, (list, tuple)) else combine_weights.shape[-1] + x = paddle.gather(x, scatter_index).reshape([-1, num_k, dim]) # [seq,2,dim] + if isinstance(combine_weights, (list, tuple)): + combine_weights = paddle.concat(combine_weights, -1).unsqueeze([1]) + return paddle.matmul(combine_weights, x).squeeze(1) # [seq,1,2] @ [seq,2,dim] -> [seq,1,dim] + + +class _AllToAll(paddle.autograd.PyLayer): + @staticmethod + def forward( + ctx: Any, + input: Tensor, + group: Group, + ) -> Tensor: # type: ignore + """ + All-to-all communication in the group. + + Args: + ctx (Any): Context object. + input (Tensor): Input tensor. + group (Group): The group object. + + Returns: + Tensor: Output tensor. + """ + + ctx.group = group + # return input + if dist.get_world_size(group) <= 1: + return input + output = paddle.empty_like(input) + stream.alltoall_single(output, input, None, None, group, True, True) + return output + + @staticmethod + def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor]: + """ + Aggregates gradient information from all input tensors into a single tensor. + + Args: + ctx (Any): The context object used to store information that needs to be passed. + *grad_output (Tensor): A list of input tensors whose gradients are to be aggregated. + + Returns: + Tuple[Tensor]: A tuple containing a tensor that holds the gradients of all input tensors. + + """ + # return grad_output + return _AllToAll.apply(*grad_output, ctx.group) + + +class LocalPart(dist.LocalLayer): + def __init__(self, out_dist_attrs, config, gate: PretrainedMoEGate): + print("==== out_dist_attrs ====") + print(out_dist_attrs) + super().__init__(out_dist_attrs) + self.config = config + self.gate = gate + + def forward(self, hidden_state, gate_weight, used_token=None): + # Implement Algorithm 2 from GShard paper. 
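+        # The gate scores every token against every expert; tokens are then dispatched into
+        # fixed-capacity per-expert buffers, and the weighted expert outputs are later
+        # combined back per token.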
+ batch_size, seq_len, d_model = hidden_state.shape + + # Initial implementation -> Reshape into S tokens by dropping sequence dimension. + # Reshape into G groups so that each group can distribute tokens equally + # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1 + reshaped_input = hidden_state.reshape([-1, d_model]) + print("==== reshaped_input ===") + print(reshaped_input) + + _, h_dim = reshaped_input.shape + + # compute gating score + logits = F.linear(reshaped_input, gate_weight, None) + print("==== logits ====") + + with paddle.amp.auto_cast(False): + scores = self.gate.gate_score_func(logits=logits) + scores = scores.cast(paddle.get_default_dtype()) + + print("==== scores ====") + print(scores) + # capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.topkgating(scores) + capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.gate.topkgating(scores) + print("==== combine_weights ====") + print(combine_weights) + print("==== dispatch_mask ====") + print(dispatch_mask) + + # self.l_aux : + # combine_weights : sec + # dispatch_mask : sec + # self.exp_counts : + dispatched_input = paddle.einsum("sec,sm->ecm", paddle.cast(dispatch_mask, hidden_state.dtype), reshaped_input) + + return dispatched_input, combine_weights, l_aux, l_zloss + + +class LocalCombine(dist.LocalLayer): + def __init__(self, out_dist_attrs): + super().__init__(out_dist_attrs) + + def forward(self, combine_weights, expert_output, dtype="float32"): + combined_output = paddle.einsum("sec,ecm->sm", combine_weights.cast(dtype), expert_output) + combined_output.register_hook(lambda grad: print(grad, "combined_output_in_local_combine.grad")) + return combined_output + + +def get_mesh(pp_idx=0): + """ + 获得pp_idx的mesh + """ + mesh = dist.fleet.auto.get_mesh() + if "pp" in mesh.dim_names: + mesh = mesh.get_mesh_with_dim("pp", pp_idx) + return mesh + + +class MoELayer(nn.Layer): + def __init__( + self, + config, + moe_num_experts: int, + expert_class: nn.Layer, + expert_kwargs: dict, + gate: PretrainedMoEGate, + capacity: int = 1.0, + moe_group: str = "data", + all_to_all_dropout=0.0, + ): + super().__init__() + + self.config = config + + print(f"moe_num_experts:{moe_num_experts}") + self.moe_num_experts = moe_num_experts + self.capacity = capacity + self.expert_parallel_degree = 1 + + self.all_to_all_dropout = all_to_all_dropout + self.enable_recompute = False + + self.experts = nn.LayerList([]) + for i in range(self.moe_num_experts): + self.experts.append(expert_class(**expert_kwargs)) + + self.moe_num_experts_per_device = self._parse_moe_expert_parallel( + self.moe_num_experts, self.expert_parallel_degree + ) + self.moe_group = None + self.gate = gate + self.gate.group = self.moe_group + self.is_dummy_moe = True + self._post_init() + + mesh = get_mesh() + local_out_dist_attrs = [ + (mesh, [dist.Shard(1)]), # dispatched_input [e,c,h] + (mesh, [dist.Shard(0)]), # combine_weights [s,e,c] + (mesh, [dist.Partial()]), # l_aux, scalar + (mesh, [dist.Partial()]), # l_zloss, scalar + ] + self.local_computes = LocalPart(local_out_dist_attrs, config, gate) + + local_combine_dist_attrs = [(mesh, [dist.Shard(0)])] + self.local_combine = LocalCombine(local_combine_dist_attrs) + + def _parse_moe_expert_parallel(self, moe_num_experts, expert_parallel_degree): + assert ( + moe_num_experts >= expert_parallel_degree + ), f"expert moe_num_experts={moe_num_experts} >= moe_world_size={expert_parallel_degree}" + assert ( + moe_num_experts % expert_parallel_degree == 0 + ), 
f"expert moe_num_experts={moe_num_experts} % moe_world_size={expert_parallel_degree} == 0" + moe_num_experts_per_device = moe_num_experts // expert_parallel_degree + return moe_num_experts_per_device + + def _post_init(self): + for p in self.gate.parameters(): + p.is_gate = True + + for k in self.experts: + if k is not None: + for p in k.parameters(): + p.expert = not self.is_dummy_moe + p.no_sync = not self.is_dummy_moe + # logger.info(f"expert param={p.name}, no-sync={p.no_sync}") + + def expert_forward(self, dispatched_input): + expert_outputs = [] + chunks = dispatched_input.unbind(1) + for chunk, expert in zip(chunks, self.experts): + chunk = chunk.contiguous() + expert_outputs += [expert(chunk)] + expert_output = paddle.stack(expert_outputs, axis=1) # [ecm] + return expert_output + + def forward( + self, + hidden_state: paddle.Tensor, + used_token: paddle.Tensor = None, + ): + """_summary_ + + Args: + input (_type_): _description_ + used_token + + Returns: + _type_: _description_ + """ + # Implement Algorithm 2 from GShard paper. + batch_size, seq_len, d_model = hidden_state.shape + + # Initial implementation -> Reshape into S tokens by dropping sequence dimension. + # Reshape into G groups so that each group can distribute tokens equally + # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1 + # reshaped_input = hidden_state.reshape([-1, d_model]) + # reshaped_input = dist.reshard(reshaped_input, reshaped_input.process_mesh, [dist.Replicate(), dist.Replicate()]) + # print("==== reshaped_input ====") + # print(reshaped_input) + + # capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.gate(reshaped_input) + # print("==== combine_weights ====") + # print(combine_weights) + + # # self.l_aux : + # # combine_weights : sec + # # dispatch_mask : sec + # # self.exp_counts : + # dispatched_input = paddle.einsum("sec,sm->ecm", paddle.cast(dispatch_mask, hidden_state.dtype), reshaped_input) + # print("==== dispatched_input ====") + # print(dispatched_input) + + hidden_state.register_hook(lambda grad: print_grad(grad, "hidden_state.grad")) + dispatched_input, combine_weights, l_aux, l_zloss = self.local_computes( + hidden_state, self.gate.weight, used_token=used_token + ) + + # dispatched_input = dist.reshard(dispatched_input, get_mesh(), [dist.Shard(0)]) + # if self.expert_parallel_degree > 1: + # dispatched_input = _AllToAll.apply(dispatched_input, self.moe_group) + + dispatched_input.register_hook(lambda grad: print_grad(grad, "dispatched_input.grad")) + # Re-shape after all-to-all: ecm -> gecm + dispatched_input = dispatched_input.reshape( + [self.expert_parallel_degree, self.moe_num_experts_per_device, -1, d_model] + ) + dispatched_input.register_hook(lambda grad: print_grad(grad, "dispatched_input_after_reshape.grad")) + expert_output = self.expert_forward(dispatched_input) + expert_output.register_hook(lambda grad: print_grad(grad, "expert_output.grad")) + # Re-shape before drop_tokens: gecm -> ecm + expert_output = expert_output.reshape( + [self.expert_parallel_degree * self.moe_num_experts_per_device, -1, d_model] + ) + + # expert_output = dist.reshard(expert_output, get_mesh(), [dist.Shard(1)]) + # if self.expert_parallel_degree > 1: + # expert_output = _AllToAll.apply(expert_output, self.moe_group) + + # combine withe expert weights + # Einsum infermeta has not supported auto parallel dist tensor, + # so use local layer here. 
+ # combined_output = paddle.einsum("sec,ecm->sm", combine_weights.cast(hidden_state[0].dtype), expert_output) + combine_weights.register_hook(lambda grad: print_grad(grad, "combine_weights.grad")) + expert_output.register_hook(lambda grad: print_grad(grad, "expert_output.grad")) + combined_output = self.local_combine(combine_weights, expert_output, dtype=hidden_state[0].dtype) + print("==== combined_output ====") + print(combined_output) + + combined_output.register_hook(lambda grad: print_grad(grad, "combined_output.grad")) + a = combined_output.reshape(hidden_state.shape) + print("==== a ====") + print(a) + a.register_hook(lambda grad: print_grad(grad, "a.grad")) + + return a, l_aux, l_zloss From 986c1de0cb79f6cfd324509318b733fdfc7ce659 Mon Sep 17 00:00:00 2001 From: Yichen Zhang Date: Thu, 20 Feb 2025 15:13:48 +0800 Subject: [PATCH 2/3] add local layer in MoELayer --- paddlenlp/transformers/auto_utils.py | 27 +++ .../transformers/deepseek_v2/modeling_auto.py | 7 +- paddlenlp/transformers/moe_gate_auto.py | 109 ++++++++-- paddlenlp/transformers/moe_layer_auto.py | 196 ++++-------------- 4 files changed, 165 insertions(+), 174 deletions(-) create mode 100644 paddlenlp/transformers/auto_utils.py diff --git a/paddlenlp/transformers/auto_utils.py b/paddlenlp/transformers/auto_utils.py new file mode 100644 index 000000000000..b50ff3a9007a --- /dev/null +++ b/paddlenlp/transformers/auto_utils.py @@ -0,0 +1,27 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
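+"""Small helpers shared by the auto-parallel modeling and MoE layer implementations."""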
+ +import paddle.distributed as dist + + +def get_mesh(pp_idx=0): + """ + 获得pp_idx的mesh + """ + mesh = dist.fleet.auto.get_mesh() + print("==== mesh ====") + print(mesh) + if "pp" in mesh.dim_names: + mesh = mesh.get_mesh_with_dim("pp", pp_idx) + return mesh diff --git a/paddlenlp/transformers/deepseek_v2/modeling_auto.py b/paddlenlp/transformers/deepseek_v2/modeling_auto.py index 0714a2a24ad8..241462c40b04 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling_auto.py +++ b/paddlenlp/transformers/deepseek_v2/modeling_auto.py @@ -25,11 +25,14 @@ from typing import List, Optional, Tuple, Union import paddle +import paddle.distributed as dist import paddle.nn.functional as F from paddle import Tensor, nn from paddle.distributed.fleet.utils import recompute from paddle.nn import Linear +from .auto_utils import get_mesh + try: from paddle.incubate.nn.functional import fused_rotary_position_embedding except ImportError: @@ -224,7 +227,9 @@ def forward(ctx, x, loss): def backward(ctx, grad_output): grad_loss = None if ctx.required_aux_loss: - grad_loss = paddle.ones(1, dtype=ctx.dtype) + # grad_loss = paddle.ones(1, dtype=ctx.dtype) + grad_loss = paddle.to_tensor(1, dtype=ctx.dtype) + grad_loss = dist.shard_tensor(grad_loss, get_mesh(), [dist.Partial(dist.ReduceType.kRedAvg)]) return grad_output, grad_loss diff --git a/paddlenlp/transformers/moe_gate_auto.py b/paddlenlp/transformers/moe_gate_auto.py index 0cd9fd62a8e7..8ef1b16878e9 100644 --- a/paddlenlp/transformers/moe_gate_auto.py +++ b/paddlenlp/transformers/moe_gate_auto.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations -import traceback from typing import Tuple import paddle @@ -160,6 +159,7 @@ def __init__(self, config, num_experts, expert_hidden_size, **kwargs): self.num_experts = num_experts self.expert_hidden_size = expert_hidden_size + self.expert_parallel_degree = kwargs.pop("expert_parallel_degree", 1) # force keep in float32 when using amp self._cast_to_low_precision = False @@ -289,7 +289,7 @@ def _topk_group_limited_greedy( return topk_weight, topk_idx def _topk_noaux_tc( - self, scores: paddle.Tensor, k: int, n_group: int, topk_group: int + self, scores: paddle.Tensor, e_score_correction_bias, k: int, n_group: int, topk_group: int ) -> Tuple[paddle.Tensor, paddle.Tensor]: """_summary_ @@ -309,13 +309,9 @@ def _topk_noaux_tc( bsz_seq_len, n_experts = scores.shape assert n_experts % n_group == 0, "n_experts must be divisible by n_groups" - assert self.e_score_correction_bias is not None, "e_score_correction_bias is None" - if self.e_score_correction_bias.is_dist(): - local_e_score_correction_bias = dist.auto_parallel.api.dtensor_to_local(self.e_score_correction_bias) - else: - local_e_score_correction_bias = self.e_score_correction_bias + assert e_score_correction_bias is not None, "e_score_correction_bias is None" # scores = scores.reshape([bsz_seq_len, -1]) + self.e_score_correction_bias.unsqueeze(0) - scores = scores.reshape([bsz_seq_len, -1]) + local_e_score_correction_bias.unsqueeze(0) + scores = scores.reshape([bsz_seq_len, -1]) + e_score_correction_bias.unsqueeze(0) group_scores = scores.reshape([bsz_seq_len, self.n_group, -1]).topk(2, axis=-1)[0].sum(axis=-1) # [n, n_group] group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=False)[1] # [n, top_k_group] group_mask = paddle.zeros_like(group_scores).put_along_axis(group_idx, paddle.to_tensor(1.0), axis=-1) # fmt:skip @@ -337,8 +333,6 @@ def top1gating( if self.noisy_gate_policy == "RSample": logits += 
self.gumbel_rsample(logits.shape) - print("==== top1 ====") - traceback.print_stack() gates = self.gate_score_func(logits=logits) capacity = self._capacity(gates, self.capacity_factor, self.max_capacity, self.min_capacity) @@ -408,8 +402,6 @@ def top2gating( logits: paddle.Tensor, ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: # everything is in fp32 in this function - print("==== top2 ====") - traceback.print_stack() gates = self.gate_score_func(logits=logits) # Create a mask for 1st's expert per token. @@ -486,8 +478,6 @@ def topkgating( gates: paddle.Tensor, ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: """Implements TopKGating on logits.""" - print("==== gates ====") - print(gates) l_zloss = self._cal_z_loss(gates) # get topk gates @@ -499,12 +489,9 @@ def topkgating( ) elif self.topk_method == "noaux_tc": top_gate, top_idx = self._topk_noaux_tc( - gates, k=self.top_k, n_group=self.n_group, topk_group=self.topk_group + gates, self.e_score_correction_bias, k=self.top_k, n_group=self.n_group, topk_group=self.topk_group ) # norm gate to sum 1 - print("==== top_gate and top_idx ====") - print(top_gate) - print(top_idx) if self.top_k > 1 and self.norm_topk_prob: denominator = top_gate.sum(axis=-1, keepdim=True) + 1e-20 top_gate = top_gate / denominator @@ -513,8 +500,6 @@ def topkgating( # get topk mask mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=1) - print("==== mask ====") - print(mask) l_aux = self._cal_aux_loss(gates, mask) exp_counts = paddle.sum(mask.cast(paddle.int64), axis=0) @@ -544,7 +529,6 @@ def topkgating( if self.group is not None: dist.all_reduce(local_capacity, op=dist.ReduceOp.MAX, group=self.group) capacity = int(local_capacity) - token_priority = self._priority(top_idx, capacity) # normalize gates gates_masked = gates * mask @@ -553,7 +537,90 @@ def topkgating( if self.norm_topk_prob: gates_masked = gates_masked / denom_s + if not self.drop_tokens: + locations = paddle.cumsum(mask, axis=0) - 1 + token_priority = self._one_hot_to_float(locations * mask, capacity) + combine_weights = paddle.einsum("se,sec->sec", gates_masked, token_priority.cast(paddle.get_default_dtype())) dispatch_mask = combine_weights.cast(paddle.bool) return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss + + def topkgating_part1(self, gates, e_score_correction_bias): + l_zloss = self._cal_z_loss(gates) + + # get topk gates + if self.topk_method == "greedy": + top_gate, top_idx = self._topk_greedy(gates, k=self.top_k) + elif self.topk_method == "group_limited_greedy": + top_gate, top_idx = self._topk_group_limited_greedy( + gates, k=self.top_k, n_group=self.n_group, topk_group=self.topk_group + ) + elif self.topk_method == "noaux_tc": + top_gate, top_idx = self._topk_noaux_tc( + gates, e_score_correction_bias, k=self.top_k, n_group=self.n_group, topk_group=self.topk_group + ) + # norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = top_gate.sum(axis=-1, keepdim=True) + 1e-20 + top_gate = top_gate / denominator + else: + top_gate = top_gate * self.routed_scaling_factor + + # get topk mask + mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=1) + l_aux = self._cal_aux_loss(gates, mask) + + exp_counts = paddle.sum(mask.cast(paddle.int64), axis=0) + + if self.drop_tokens: + # Calculate configured capacity and remove locations outside capacity from mask + capacity = self._capacity( + gates, + 
self.capacity_factor * self.top_k, + self.max_capacity, + self.min_capacity, + ) + + # update mask and locations by capacity + if self.drop_policy == "probs": + topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1) + capacity_probs, capacity_indices = paddle.topk(topk_masked_gates, k=capacity, axis=0, sorted=False) + token_priority = self._priority(capacity_indices, capacity) + + elif self.drop_policy == "position": + token_priority = self._priority(top_idx, capacity) + else: + raise ValueError(f"Invalid drop_policy: {self.drop_policy}") + else: + # Do not drop tokens - set capacity according to current expert assignments + # local_capacity = paddle.max(exp_counts) + # capacity = int(local_capacity) + capacity = None + token_priority = None + + # keep these tensor for using in topkgating_part2 + self.mask = mask + self.capacity = capacity + self.token_priority = token_priority + + return exp_counts, l_aux, l_zloss + + def topkgating_part2(self, gates): + + gates_masked = gates * self.mask + gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True) + denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps) + if self.norm_topk_prob: + gates_masked = gates_masked / denom_s + + if not self.drop_tokens: + locations = paddle.cumsum(self.mask, axis=0) - 1 + self.token_priority = self._one_hot_to_float(locations * self.mask, self.capacity) + + combine_weights = paddle.einsum( + "se,sec->sec", gates_masked, self.token_priority.cast(paddle.get_default_dtype()) + ) + dispatch_mask = combine_weights.cast(paddle.bool) + + return combine_weights, dispatch_mask diff --git a/paddlenlp/transformers/moe_layer_auto.py b/paddlenlp/transformers/moe_layer_auto.py index 61528db66188..6a2d7115cf84 100644 --- a/paddlenlp/transformers/moe_layer_auto.py +++ b/paddlenlp/transformers/moe_layer_auto.py @@ -15,23 +15,15 @@ # limitations under the License. from __future__ import annotations -from typing import Any, Tuple - import paddle import paddle.distributed as dist import paddle.nn.functional as F -from paddle import Tensor, nn -from paddle.distributed.communication import stream -from paddle.distributed.communication.group import Group +from paddle import nn +from .auto_utils import get_mesh from .moe_gate_auto import PretrainedMoEGate -def print_grad(g, name): - print(f"==== {name} ====") - print(g) - - def dispatching(x, dispatch_mask, scatter_index, num_experts, capacity): """ Rearranges the input tensor `x` based on gate results, truncates it according to the specified capacity, and performs padding. @@ -96,117 +88,66 @@ def combining(x, combine_weights, scatter_index): return paddle.matmul(combine_weights, x).squeeze(1) # [seq,1,2] @ [seq,2,dim] -> [seq,1,dim] -class _AllToAll(paddle.autograd.PyLayer): - @staticmethod - def forward( - ctx: Any, - input: Tensor, - group: Group, - ) -> Tensor: # type: ignore - """ - All-to-all communication in the group. - - Args: - ctx (Any): Context object. - input (Tensor): Input tensor. - group (Group): The group object. - - Returns: - Tensor: Output tensor. - """ - - ctx.group = group - # return input - if dist.get_world_size(group) <= 1: - return input - output = paddle.empty_like(input) - stream.alltoall_single(output, input, None, None, group, True, True) - return output - - @staticmethod - def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor]: - """ - Aggregates gradient information from all input tensors into a single tensor. 
- - Args: - ctx (Any): The context object used to store information that needs to be passed. - *grad_output (Tensor): A list of input tensors whose gradients are to be aggregated. - - Returns: - Tuple[Tensor]: A tuple containing a tensor that holds the gradients of all input tensors. - - """ - # return grad_output - return _AllToAll.apply(*grad_output, ctx.group) - - -class LocalPart(dist.LocalLayer): - def __init__(self, out_dist_attrs, config, gate: PretrainedMoEGate): - print("==== out_dist_attrs ====") - print(out_dist_attrs) +class LocalGatePart1(dist.LocalLayer): + def __init__(self, config, gate: PretrainedMoEGate, ipp=0): + mesh = get_mesh(ipp) + out_dist_attrs = [ + (mesh, [dist.Shard(0), dist.Replicate()]), # reshaped_input [b*s, h] + (mesh, [dist.Shard(0), dist.Replicate()]), # scores [b*s, e] + (mesh, [dist.Partial(dist.ReduceType.kRedMax)]), # expert_counts [e] + (mesh, [dist.Partial(dist.ReduceType.kRedAvg)]), # l_aux, scalar + (mesh, [dist.Partial(dist.ReduceType.kRedAvg)]), # l_zloss, scalar + ] super().__init__(out_dist_attrs) self.config = config self.gate = gate - def forward(self, hidden_state, gate_weight, used_token=None): + def forward(self, hidden_state, gate_weight, e_score_correction_bias, used_token=None): # Implement Algorithm 2 from GShard paper. batch_size, seq_len, d_model = hidden_state.shape - - # Initial implementation -> Reshape into S tokens by dropping sequence dimension. - # Reshape into G groups so that each group can distribute tokens equally - # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1 reshaped_input = hidden_state.reshape([-1, d_model]) - print("==== reshaped_input ===") - print(reshaped_input) - _, h_dim = reshaped_input.shape # compute gating score logits = F.linear(reshaped_input, gate_weight, None) - print("==== logits ====") - with paddle.amp.auto_cast(False): scores = self.gate.gate_score_func(logits=logits) scores = scores.cast(paddle.get_default_dtype()) - print("==== scores ====") - print(scores) - # capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.topkgating(scores) - capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.gate.topkgating(scores) - print("==== combine_weights ====") - print(combine_weights) - print("==== dispatch_mask ====") - print(dispatch_mask) + exp_counts, l_aux, l_zloss = self.gate.topkgating_part1(scores, e_score_correction_bias) - # self.l_aux : - # combine_weights : sec - # dispatch_mask : sec - # self.exp_counts : - dispatched_input = paddle.einsum("sec,sm->ecm", paddle.cast(dispatch_mask, hidden_state.dtype), reshaped_input) + return reshaped_input, scores, exp_counts, l_aux, l_zloss - return dispatched_input, combine_weights, l_aux, l_zloss + +class LocalGateAndDispatch(dist.LocalLayer): + def __init__(self, gate: PretrainedMoEGate, ipp=0): + mesh = get_mesh(ipp) + out_dist_attrs = [ + (mesh, [dist.Shard(1), dist.Replicate()]), # dispatched_input [e,c,h] + (mesh, [dist.Shard(0), dist.Replicate()]), # combine_weights [s,e,c] + ] + super().__init__(out_dist_attrs) + self.gate = gate + + def forward(self, reshaped_input, scores): + combine_weights, dispatch_mask = self.gate.topkgating_part2(scores) + dispatched_input = paddle.einsum( + "sec,sm->ecm", paddle.cast(dispatch_mask, reshaped_input.dtype), reshaped_input + ) + return dispatched_input, combine_weights class LocalCombine(dist.LocalLayer): - def __init__(self, out_dist_attrs): + def __init__(self, ipp=0): + mesh = get_mesh(ipp) + out_dist_attrs = [(mesh, [dist.Shard(0)])] 
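+        # The combined output [s, m] is sharded along the token dimension (dim 0) across the mesh.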
super().__init__(out_dist_attrs) def forward(self, combine_weights, expert_output, dtype="float32"): combined_output = paddle.einsum("sec,ecm->sm", combine_weights.cast(dtype), expert_output) - combined_output.register_hook(lambda grad: print(grad, "combined_output_in_local_combine.grad")) return combined_output -def get_mesh(pp_idx=0): - """ - 获得pp_idx的mesh - """ - mesh = dist.fleet.auto.get_mesh() - if "pp" in mesh.dim_names: - mesh = mesh.get_mesh_with_dim("pp", pp_idx) - return mesh - - class MoELayer(nn.Layer): def __init__( self, @@ -218,12 +159,12 @@ def __init__( capacity: int = 1.0, moe_group: str = "data", all_to_all_dropout=0.0, + ipp: int = 0, ): super().__init__() self.config = config - print(f"moe_num_experts:{moe_num_experts}") self.moe_num_experts = moe_num_experts self.capacity = capacity self.expert_parallel_degree = 1 @@ -244,17 +185,9 @@ def __init__( self.is_dummy_moe = True self._post_init() - mesh = get_mesh() - local_out_dist_attrs = [ - (mesh, [dist.Shard(1)]), # dispatched_input [e,c,h] - (mesh, [dist.Shard(0)]), # combine_weights [s,e,c] - (mesh, [dist.Partial()]), # l_aux, scalar - (mesh, [dist.Partial()]), # l_zloss, scalar - ] - self.local_computes = LocalPart(local_out_dist_attrs, config, gate) - - local_combine_dist_attrs = [(mesh, [dist.Shard(0)])] - self.local_combine = LocalCombine(local_combine_dist_attrs) + self.local_gate_part1 = LocalGatePart1(config, gate, ipp) + self.local_gate_and_dispatch = LocalGateAndDispatch(gate, ipp) + self.local_combine = LocalCombine(ipp) def _parse_moe_expert_parallel(self, moe_num_experts, expert_parallel_degree): assert ( @@ -303,66 +236,25 @@ def forward( # Implement Algorithm 2 from GShard paper. batch_size, seq_len, d_model = hidden_state.shape - # Initial implementation -> Reshape into S tokens by dropping sequence dimension. 
- # Reshape into G groups so that each group can distribute tokens equally - # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1 - # reshaped_input = hidden_state.reshape([-1, d_model]) - # reshaped_input = dist.reshard(reshaped_input, reshaped_input.process_mesh, [dist.Replicate(), dist.Replicate()]) - # print("==== reshaped_input ====") - # print(reshaped_input) - - # capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.gate(reshaped_input) - # print("==== combine_weights ====") - # print(combine_weights) - - # # self.l_aux : - # # combine_weights : sec - # # dispatch_mask : sec - # # self.exp_counts : - # dispatched_input = paddle.einsum("sec,sm->ecm", paddle.cast(dispatch_mask, hidden_state.dtype), reshaped_input) - # print("==== dispatched_input ====") - # print(dispatched_input) - - hidden_state.register_hook(lambda grad: print_grad(grad, "hidden_state.grad")) - dispatched_input, combine_weights, l_aux, l_zloss = self.local_computes( - hidden_state, self.gate.weight, used_token=used_token + reshaped_input, gate_scores, exp_counts, l_aux, l_zloss = self.local_gate_part1( + hidden_state, self.gate.weight, self.gate.e_score_correction_bias, used_token=used_token ) + if self.gate.drop_tokens is False: + self.gate.capacity = int(paddle.max(exp_counts)) + dispatched_input, combine_weights = self.local_gate_and_dispatch(reshaped_input, gate_scores) - # dispatched_input = dist.reshard(dispatched_input, get_mesh(), [dist.Shard(0)]) - # if self.expert_parallel_degree > 1: - # dispatched_input = _AllToAll.apply(dispatched_input, self.moe_group) - - dispatched_input.register_hook(lambda grad: print_grad(grad, "dispatched_input.grad")) # Re-shape after all-to-all: ecm -> gecm dispatched_input = dispatched_input.reshape( [self.expert_parallel_degree, self.moe_num_experts_per_device, -1, d_model] ) - dispatched_input.register_hook(lambda grad: print_grad(grad, "dispatched_input_after_reshape.grad")) expert_output = self.expert_forward(dispatched_input) - expert_output.register_hook(lambda grad: print_grad(grad, "expert_output.grad")) # Re-shape before drop_tokens: gecm -> ecm expert_output = expert_output.reshape( [self.expert_parallel_degree * self.moe_num_experts_per_device, -1, d_model] ) - # expert_output = dist.reshard(expert_output, get_mesh(), [dist.Shard(1)]) - # if self.expert_parallel_degree > 1: - # expert_output = _AllToAll.apply(expert_output, self.moe_group) - - # combine withe expert weights - # Einsum infermeta has not supported auto parallel dist tensor, - # so use local layer here. 
- # combined_output = paddle.einsum("sec,ecm->sm", combine_weights.cast(hidden_state[0].dtype), expert_output) - combine_weights.register_hook(lambda grad: print_grad(grad, "combine_weights.grad")) - expert_output.register_hook(lambda grad: print_grad(grad, "expert_output.grad")) combined_output = self.local_combine(combine_weights, expert_output, dtype=hidden_state[0].dtype) - print("==== combined_output ====") - print(combined_output) - combined_output.register_hook(lambda grad: print_grad(grad, "combined_output.grad")) a = combined_output.reshape(hidden_state.shape) - print("==== a ====") - print(a) - a.register_hook(lambda grad: print_grad(grad, "a.grad")) return a, l_aux, l_zloss From 8bfa877c3c7bd1dffc7f11c8b90054def6faa40e Mon Sep 17 00:00:00 2001 From: Yichen Zhang Date: Mon, 24 Feb 2025 10:54:33 +0800 Subject: [PATCH 3/3] add expert parallel with dygraph auto parallel --- paddlenlp/transformers/auto_utils.py | 38 +++++++++- .../transformers/deepseek_v2/modeling_auto.py | 18 ++++- paddlenlp/transformers/moe_gate_auto.py | 39 ++++++++-- paddlenlp/transformers/moe_layer_auto.py | 76 +++++++++++++++---- 4 files changed, 146 insertions(+), 25 deletions(-) diff --git a/paddlenlp/transformers/auto_utils.py b/paddlenlp/transformers/auto_utils.py index b50ff3a9007a..7ddd35f8080a 100644 --- a/paddlenlp/transformers/auto_utils.py +++ b/paddlenlp/transformers/auto_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import paddle import paddle.distributed as dist @@ -20,8 +21,41 @@ def get_mesh(pp_idx=0): 获得pp_idx的mesh """ mesh = dist.fleet.auto.get_mesh() - print("==== mesh ====") - print(mesh) if "pp" in mesh.dim_names: mesh = mesh.get_mesh_with_dim("pp", pp_idx) return mesh + + +def einsum(rule, a, b): + """ + Use other ops to replace einsum. The implementation + is from https://github.com/deepspeedai/DeepSpeed. 
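+    In the MoE gate and layer only the "se,sec->sec", "sec,sm->ecm" and "sec,ecm->sm"
+    rules are expected to be exercised (combine weights, dispatch and combine); any rule
+    without a specialized branch falls back to paddle.einsum.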
+ """ + if rule == "s,se->se": + return a.reshape([a.shape[0], -1]) * b + elif rule == "se,sc->sec": + return a.unsqueeze(2) * b.unsqueeze(1) + elif rule == "se,se->s": + return paddle.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1) + elif rule == "se,sec->sec": + return paddle.unsqueeze(a, axis=2) * b + elif rule == "sec,sm->ecm": + s = a.shape[0] + e = a.shape[1] + c = a.shape[2] + m = b.shape[1] + return paddle.matmul(a.reshape([s, -1]).t(), b).reshape([e, c, m]) + elif rule == "sec,ecm->sm": + return paddle.matmul(a.reshape([a.shape[0], -1]), b.reshape([-1, b.shape[-1]])) + elif rule == "ks,ksm->sm": + k = b.shape[0] + s = b.shape[1] + m = b.shape[2] + # [k, s] -> [s, k] -> [s, 1, k] + a = a.t().unsqueeze(1) + # [k,s,m] -> [k, sm] -> [sm, k] -> [s, m, k] + b = b.reshape([k, -1]).t().reshape([s, m, k]) + # bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1] + return paddle.bmm(a, b.transpose(1, 2)).squeeze(2) + else: + return paddle.einsum(rule, a, b) diff --git a/paddlenlp/transformers/deepseek_v2/modeling_auto.py b/paddlenlp/transformers/deepseek_v2/modeling_auto.py index 241462c40b04..41fa11fc344e 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling_auto.py +++ b/paddlenlp/transformers/deepseek_v2/modeling_auto.py @@ -31,7 +31,7 @@ from paddle.distributed.fleet.utils import recompute from paddle.nn import Linear -from .auto_utils import get_mesh +from ..auto_utils import get_mesh try: from paddle.incubate.nn.functional import fused_rotary_position_embedding @@ -246,6 +246,22 @@ def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size self.act_fn = ACT2FN[config.hidden_act] + def redistribute_expert(self, mesh, placements): + """ + Place the experts on different devices. + """ + self.gate_proj.weight = dist.shard_tensor(self.gate_proj.weight, mesh, placements) + if self.gate_proj.bias is not None: + self.gate_proj.bias = dist.shard_tensor(self.gate_proj.bias, mesh, placements) + + self.up_proj.weight = dist.shard_tensor(self.up_proj.weight, mesh, placements) + if self.up_proj.bias is not None: + self.up_proj.bias = dist.shard_tensor(self.up_proj.bias, mesh, placements) + + self.down_proj.weight = dist.shard_tensor(self.down_proj.weight, mesh, placements) + if self.down_proj.bias is not None: + self.down_proj.bias = dist.shard_tensor(self.down_proj.bias, mesh, placements) + def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj diff --git a/paddlenlp/transformers/moe_gate_auto.py b/paddlenlp/transformers/moe_gate_auto.py index 8ef1b16878e9..c154a15204d8 100644 --- a/paddlenlp/transformers/moe_gate_auto.py +++ b/paddlenlp/transformers/moe_gate_auto.py @@ -21,6 +21,7 @@ import paddle.nn.functional as F from ..utils.log import logger +from .auto_utils import einsum class MoEGateMixin: @@ -90,7 +91,6 @@ def _capacity( # gates has shape of SE num_tokens = gates.shape[0] num_experts = gates.shape[1] - print(f"==== num_tokens:{num_tokens}, num_experts:{num_experts} ====") capacity = int((num_tokens // num_experts) * capacity_factor) if capacity < min_capacity: capacity = min_capacity @@ -127,6 +127,22 @@ def _cal_aux_loss(self, gates, mask): aux_loss = paddle.sum(me * ce) * float(self.num_experts) return aux_loss + def _cal_seq_aux_loss(self, gates, top_k, topk_idx) -> paddle.Tensor: + """ + Calculate sequence auxiliary loss. + Args: + logits (paddle.Tensor): Model output. + Returns: + paddle.Tensor: The value of sequence auxiliary loss. 
+ """ + batch_size, seq_len, _ = gates.shape + ce = paddle.zeros([batch_size, self.num_experts]) + topk_idx = topk_idx.reshape([batch_size, -1]) + ce.put_along_axis_(indices=topk_idx, values=paddle.ones([batch_size, seq_len * top_k]), axis=1) + ce = ce / (seq_len * top_k / self.num_experts) + aux_loss = (ce * paddle.mean(gates, axis=1)).sum(axis=1).mean() + return aux_loss + def _cal_z_loss(self, logits) -> paddle.Tensor: """ Calculate the z loss. @@ -478,6 +494,10 @@ def topkgating( gates: paddle.Tensor, ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: """Implements TopKGating on logits.""" + batch_size, seq_len, d_model = gates.shape + gates_ori = gates + gates = gates.reshape([-1, d_model]) + l_zloss = self._cal_z_loss(gates) # get topk gates @@ -500,7 +520,10 @@ def topkgating( # get topk mask mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=1) - l_aux = self._cal_aux_loss(gates, mask) + if self.config.seq_aux: + l_aux = self._cal_seq_aux_loss(gates_ori, self.top_k, top_idx) + else: + l_aux = self._cal_aux_loss(gates, mask) exp_counts = paddle.sum(mask.cast(paddle.int64), axis=0) @@ -548,6 +571,9 @@ def topkgating( def topkgating_part1(self, gates, e_score_correction_bias): l_zloss = self._cal_z_loss(gates) + batch_size, seq_len, d_model = gates.shape + gates_ori = gates + gates = gates.reshape([-1, d_model]) # get topk gates if self.topk_method == "greedy": @@ -569,7 +595,10 @@ def topkgating_part1(self, gates, e_score_correction_bias): # get topk mask mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=1) - l_aux = self._cal_aux_loss(gates, mask) + if self.config.seq_aux: + l_aux = self._cal_seq_aux_loss(gates_ori, self.top_k, top_idx) + else: + l_aux = self._cal_aux_loss(gates, mask) exp_counts = paddle.sum(mask.cast(paddle.int64), axis=0) @@ -618,9 +647,7 @@ def topkgating_part2(self, gates): locations = paddle.cumsum(self.mask, axis=0) - 1 self.token_priority = self._one_hot_to_float(locations * self.mask, self.capacity) - combine_weights = paddle.einsum( - "se,sec->sec", gates_masked, self.token_priority.cast(paddle.get_default_dtype()) - ) + combine_weights = einsum("se,sec->sec", gates_masked, self.token_priority.cast(paddle.get_default_dtype())) dispatch_mask = combine_weights.cast(paddle.bool) return combine_weights, dispatch_mask diff --git a/paddlenlp/transformers/moe_layer_auto.py b/paddlenlp/transformers/moe_layer_auto.py index 6a2d7115cf84..c8cbbc81d5de 100644 --- a/paddlenlp/transformers/moe_layer_auto.py +++ b/paddlenlp/transformers/moe_layer_auto.py @@ -15,15 +15,22 @@ # limitations under the License. from __future__ import annotations +import copy + import paddle import paddle.distributed as dist import paddle.nn.functional as F from paddle import nn -from .auto_utils import get_mesh +from .auto_utils import einsum, get_mesh from .moe_gate_auto import PretrainedMoEGate +def print_grad(grad, name): + print(f"==== {name} ====") + print(grad) + + def dispatching(x, dispatch_mask, scatter_index, num_experts, capacity): """ Rearranges the input tensor `x` based on gate results, truncates it according to the specified capacity, and performs padding. @@ -106,17 +113,17 @@ def forward(self, hidden_state, gate_weight, e_score_correction_bias, used_token # Implement Algorithm 2 from GShard paper. 
batch_size, seq_len, d_model = hidden_state.shape reshaped_input = hidden_state.reshape([-1, d_model]) - _, h_dim = reshaped_input.shape # compute gating score - logits = F.linear(reshaped_input, gate_weight, None) + logits = F.linear(hidden_state, gate_weight, None) with paddle.amp.auto_cast(False): scores = self.gate.gate_score_func(logits=logits) scores = scores.cast(paddle.get_default_dtype()) exp_counts, l_aux, l_zloss = self.gate.topkgating_part1(scores, e_score_correction_bias) - return reshaped_input, scores, exp_counts, l_aux, l_zloss + reshaped_scores = scores.reshape([-1, scores.shape[-1]]) + return reshaped_input, reshaped_scores, exp_counts, l_aux, l_zloss class LocalGateAndDispatch(dist.LocalLayer): @@ -131,9 +138,10 @@ def __init__(self, gate: PretrainedMoEGate, ipp=0): def forward(self, reshaped_input, scores): combine_weights, dispatch_mask = self.gate.topkgating_part2(scores) - dispatched_input = paddle.einsum( - "sec,sm->ecm", paddle.cast(dispatch_mask, reshaped_input.dtype), reshaped_input - ) + # dispatched_input = paddle.einsum( + # "sec,sm->ecm", paddle.cast(dispatch_mask, reshaped_input.dtype), reshaped_input + # ) + dispatched_input = einsum("sec,sm->ecm", paddle.cast(dispatch_mask, reshaped_input.dtype), reshaped_input) return dispatched_input, combine_weights @@ -144,7 +152,7 @@ def __init__(self, ipp=0): super().__init__(out_dist_attrs) def forward(self, combine_weights, expert_output, dtype="float32"): - combined_output = paddle.einsum("sec,ecm->sm", combine_weights.cast(dtype), expert_output) + combined_output = einsum("sec,ecm->sm", combine_weights.cast(dtype), expert_output) return combined_output @@ -167,7 +175,7 @@ def __init__( self.moe_num_experts = moe_num_experts self.capacity = capacity - self.expert_parallel_degree = 1 + self.ipp = ipp self.all_to_all_dropout = all_to_all_dropout self.enable_recompute = False @@ -176,9 +184,11 @@ def __init__( for i in range(self.moe_num_experts): self.experts.append(expert_class(**expert_kwargs)) - self.moe_num_experts_per_device = self._parse_moe_expert_parallel( - self.moe_num_experts, self.expert_parallel_degree + self.expert_parallel_degree, self.moe_num_experts_per_device = self._parse_moe_expert_parallel( + self.moe_num_experts, config ) + self._redistribute_experts(self.experts, config.moe_group) + self.moe_group = None self.gate = gate self.gate.group = self.moe_group @@ -189,15 +199,31 @@ def __init__( self.local_gate_and_dispatch = LocalGateAndDispatch(gate, ipp) self.local_combine = LocalCombine(ipp) - def _parse_moe_expert_parallel(self, moe_num_experts, expert_parallel_degree): + def _redistribute_experts(self, experts, moe_group: str): + if moe_group != "None": + index = 0 if moe_group == "dp" else 1 + self.moe_mesh_dim = index + ep_sub_meshes = dist.auto_parallel.api.split_mesh(get_mesh(self.ipp), index) + for i, expert in enumerate(experts): + ep_group_id = i // self.moe_num_experts_per_device + experts[i].redistribute_expert(ep_sub_meshes[ep_group_id], [dist.Replicate(), dist.Replicate()]) + + def _parse_moe_expert_parallel(self, moe_num_experts, config): + assert config.moe_group in ["dp", "mp", "None"], f"moe_group={config.moe_group} not in ['dp', 'mp', 'None']" + if config.moe_group == "None": + expert_parallel_degree = 1 + else: + expert_parallel_degree = dist.fleet.auto.get_mesh().get_dim_size(config.moe_group) assert ( moe_num_experts >= expert_parallel_degree ), f"expert moe_num_experts={moe_num_experts} >= moe_world_size={expert_parallel_degree}" + assert ( moe_num_experts % 
expert_parallel_degree == 0 ), f"expert moe_num_experts={moe_num_experts} % moe_world_size={expert_parallel_degree} == 0" moe_num_experts_per_device = moe_num_experts // expert_parallel_degree - return moe_num_experts_per_device + + return expert_parallel_degree, moe_num_experts_per_device def _post_init(self): for p in self.gate.parameters(): @@ -211,12 +237,24 @@ def _post_init(self): # logger.info(f"expert param={p.name}, no-sync={p.no_sync}") def expert_forward(self, dispatched_input): + sub_mesh_tensors = dist.auto_parallel.api.moe_sub_mesh_tensors( + dispatched_input, get_mesh(self.ipp), self.moe_mesh_dim, dispatched_input.placements + ) + chunks = paddle.utils.flatten([t.unbind(1) for t in sub_mesh_tensors]) + + # try to simplify the code below + ep_group_outputs = [] expert_outputs = [] - chunks = dispatched_input.unbind(1) - for chunk, expert in zip(chunks, self.experts): + for i, (chunk, expert) in enumerate(zip(chunks, self.experts)): chunk = chunk.contiguous() expert_outputs += [expert(chunk)] - expert_output = paddle.stack(expert_outputs, axis=1) # [ecm] + if (i + 1) % self.moe_num_experts_per_device == 0: + ep_group_outputs += [paddle.stack(expert_outputs, axis=1)] + expert_outputs = [] + + expert_output = dist.auto_parallel.api.moe_global_mesh_tensor( + ep_group_outputs, get_mesh(self.ipp), dispatched_input.placements, self.moe_mesh_dim + ) return expert_output def forward( @@ -242,6 +280,11 @@ def forward( if self.gate.drop_tokens is False: self.gate.capacity = int(paddle.max(exp_counts)) dispatched_input, combine_weights = self.local_gate_and_dispatch(reshaped_input, gate_scores) + ori_dispatched_placements = copy.deepcopy(dispatched_input.placements) + + ep_placements = copy.deepcopy(dispatched_input.placements) + ep_placements[self.moe_mesh_dim] = dist.Shard(0) + dispatched_input = dist.reshard(dispatched_input, get_mesh(self.ipp), ep_placements) # Re-shape after all-to-all: ecm -> gecm dispatched_input = dispatched_input.reshape( @@ -252,6 +295,7 @@ def forward( expert_output = expert_output.reshape( [self.expert_parallel_degree * self.moe_num_experts_per_device, -1, d_model] ) + expert_output = dist.reshard(expert_output, get_mesh(self.ipp), ori_dispatched_placements) combined_output = self.local_combine(combine_weights, expert_output, dtype=hidden_state[0].dtype)
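+        # Expert-parallel data flow of the block above (the expert/device counts are
+        # illustrative, e.g. 8 experts on a moe_group of size 2 -> 4 experts per device):
+        #   1. dispatched_input [e, c, m] is resharded so the expert axis becomes Shard(0)
+        #      over the moe mesh dim (the auto-parallel counterpart of the all-to-all in
+        #      the manual expert-parallel path);
+        #   2. expert_forward splits it into per-device sub-mesh tensors and runs the
+        #      local experts (4 per device in this example);
+        #   3. the expert output is resharded back to the original placements and
+        #      LocalCombine folds it with combine_weights into [s, m] token states.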