[Autoparallel] Mtp for DeepSeekV3 #9884

Open · wants to merge 9 commits into develop
7 changes: 4 additions & 3 deletions llm/auto_parallel/deepseek-v3/run_pretrain_auto.py
@@ -39,14 +39,13 @@
    AutoTokenizer,
    CosineAnnealingWithWarmupDecay,
    DeepseekV2Config,
-   DeepseekV2PretrainingCriterion,
    DeepseekV3ForCausalLMAuto,
    LinearAnnealingWithWarmupDecay,
)
from paddlenlp.utils.log import logger

MODEL_CLASSES = {
-   "deepseekv3_auto": (DeepseekV2Config, DeepseekV3ForCausalLMAuto, DeepseekV2PretrainingCriterion),
    "deepseekv3_auto": (DeepseekV2Config, DeepseekV3ForCausalLMAuto, None),
}


@@ -555,7 +554,9 @@ def main():

    with paddle.LazyGuard():
        model = model_class.from_config(config, dtype="float32")
-       criterion = criterion_class(config)
        criterion = None
        if criterion_class is not None:
            criterion = criterion_class(config)

    if training_args.recompute:

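The criterion entry is now allowed to be `None` because, with this PR, the auto model's `forward` computes the loss itself when `labels` are passed (see the `modeling_auto.py` changes below). A minimal sketch of the resulting training-step contract; the helper name `compute_loss` is illustrative, not part of the PR:

```python
# Sketch: when MODEL_CLASSES maps to a None criterion, the loss comes from
# the model itself rather than from an external criterion object.
def compute_loss(model, criterion, input_ids, labels):
    if criterion is None:
        return model(input_ids, labels=labels)  # forward returns the loss directly
    logits = model(input_ids)
    return criterion(logits, labels)
```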
4 changes: 3 additions & 1 deletion paddlenlp/transformers/deepseek_v2/configuration.py
@@ -139,7 +139,8 @@ def __init__(
        intermediate_size=11008,
        moe_intermediate_size=1407,
        num_hidden_layers=30,
-       num_nextn_predict_layers=1,
        num_nextn_predict_layers=0,
        num_nextn_predict_lambda=1.0,
        num_attention_heads=32,
        num_key_value_heads=32,
        n_shared_experts=None,
@@ -187,6 +188,7 @@ def __init__(
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_nextn_predict_layers = num_nextn_predict_layers
        self.num_nextn_predict_lambda = num_nextn_predict_lambda
        self.num_attention_heads = num_attention_heads
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
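These two config knobs are what turn MTP on: `num_nextn_predict_layers` (now defaulting to 0, i.e. off) controls how many extra next-token-prediction layers are appended after the decoder stack, and `num_nextn_predict_lambda` presumably weights the MTP loss term. A usage sketch, assuming the `paddlenlp.transformers` import path used by `run_pretrain_auto.py`:

```python
from paddlenlp.transformers import DeepseekV2Config, DeepseekV3ForCausalLMAuto

# MTP is off by default (num_nextn_predict_layers=0). Setting it to N appends
# N MTP layers after the regular decoder stack.
config = DeepseekV2Config(
    num_hidden_layers=30,
    num_nextn_predict_layers=1,    # one MTP depth, as in the DeepSeek-V3 paper
    num_nextn_predict_lambda=0.3,  # example value; the default is 1.0
)
model = DeepseekV3ForCausalLMAuto.from_config(config, dtype="float32")  # sketch
```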
234 changes: 192 additions & 42 deletions paddlenlp/transformers/deepseek_v2/modeling.py

Large diffs are not rendered by default.

147 changes: 134 additions & 13 deletions paddlenlp/transformers/deepseek_v2/modeling_auto.py
@@ -564,6 +564,56 @@ def forward(
        return outputs


class DeepseekV2MTPLayerAuto(DeepseekV2DecoderLayerAuto):
    def __init__(
        self,
        config: DeepseekV2Config,
        layer_idx: int,
        layerwise_recompute: bool = False,
    ):
        super(DeepseekV2MTPLayerAuto, self).__init__(config, layer_idx, layerwise_recompute)

        self.enorm = DeepseekV2RMSNorm(config)
        self.hnorm = DeepseekV2RMSNorm(config)
        self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size)

    def forward(
        self,
        hidden_states: paddle.Tensor,
        nextn_hidden_state: paddle.Tensor,
        position_ids: Optional[paddle.Tensor] = None,
        attention_mask: Optional[paddle.Tensor] = None,
        output_attentions: Optional[bool] = False,
        past_key_value: Optional[Tuple[paddle.Tensor]] = None,
        use_cache: Optional[bool] = False,
        attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
        **kwargs,
    ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]:

        # Normalize the previous depth's hidden state and the shifted token
        # embeddings separately, then fuse them back down to hidden_size.
        hidden_states = self.hnorm(hidden_states)
        nextn_hidden_state = self.enorm(nextn_hidden_state)

        hidden_states = self.eh_proj(paddle.concat([hidden_states, nextn_hidden_state], axis=-1))

        layer_outputs = super(DeepseekV2MTPLayerAuto, self).forward(
            hidden_states,
            position_ids,
            attention_mask,
            output_attentions,
            past_key_value,
            use_cache,
            attn_mask_startend_row_indices,
            **kwargs,
        )

        if isinstance(layer_outputs, tuple):
            hidden_states = layer_outputs[0]
        else:
            hidden_states = layer_outputs

        return hidden_states
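As a quick dimensional sanity check of what this layer adds, here is a standalone toy sketch (not from the PR; `nn.LayerNorm` stands in for `DeepseekV2RMSNorm`):

```python
import paddle
import paddle.nn as nn

B, S, D = 2, 16, 64  # toy batch, sequence, hidden sizes
hidden_states = paddle.randn([B, S, D])       # output of the previous depth
nextn_hidden_state = paddle.randn([B, S, D])  # shifted token embeddings

# enorm/hnorm are RMSNorms in the PR; LayerNorm stands in here.
enorm, hnorm = nn.LayerNorm(D), nn.LayerNorm(D)
eh_proj = nn.Linear(2 * D, D)

fused = eh_proj(paddle.concat([hnorm(hidden_states), enorm(nextn_hidden_state)], axis=-1))
assert fused.shape == [B, S, D]  # back to hidden_size before the decoder block runs
```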


class DeepseekV2PretrainedModelAuto(PretrainedModel):
    config_class = DeepseekV2Config
    base_model_prefix = "deepseek_v2"
@@ -599,6 +649,10 @@ def __init__(self, config: DeepseekV2Config):
                for layer_idx in range(config.num_hidden_layers)
            ]
        )

        # Append the MTP layers after the regular decoder stack.
        for layer_idx in range(config.num_hidden_layers, config.num_hidden_layers + config.num_nextn_predict_layers):
            self.layers.append(DeepseekV2MTPLayerAuto(config, layer_idx, layer_idx not in self.no_recompute_layers))

        self.norm = DeepseekV2RMSNorm(config)

        self.enable_recompute = False
@@ -675,6 +729,7 @@ def forward(
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        # The trailing num_nextn_predict_layers positions are reserved as MTP
        # targets, so the main decoder stack sees a shorter sequence.
        seq_length -= self.config.num_nextn_predict_layers

        if self.enable_recompute and self.training:
            if use_cache:
@@ -720,29 +775,55 @@
        if self.config.use_flash_attention:
            attention_mask = None if is_casual_mask(attention_mask) else attention_mask

        if self.config.num_nextn_predict_layers > 0:
            # Split off the embeddings of the reserved trailing MTP positions.
            inputs_embeds_extra = inputs_embeds[:, -self.config.num_nextn_predict_layers :, :]  # [B, N, D]
            inputs_embeds = inputs_embeds[:, : -self.config.num_nextn_predict_layers, :]
            inputs_embeds_ori = inputs_embeds

        # embed positions
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None
        mtp_outputs = []

-       for idx in range(self.config.num_hidden_layers):
-           decoder_layer = self.layers[idx]
        # Iterate only over the dense decoder stack; the trailing MTP layers
        # appended in __init__ take a second input and run in their own loop below.
        for idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

-           layer_outputs = decoder_layer(
-               hidden_states=hidden_states,
-               position_ids=position_ids,
-               attention_mask=attention_mask,
-               output_attentions=output_attentions,
-               past_key_value=past_key_value,
-               use_cache=use_cache,
-               attn_mask_startend_row_indices=attn_mask_startend_row_indices,
-           )
            has_gradient = not hidden_states.stop_gradient
            if (
                self.enable_recompute
                and idx not in self.no_recompute_layers
                and has_gradient
                and self.recompute_granularity == "full"
            ):
                layer_outputs = self.recompute_training_full(
                    decoder_layer,
                    hidden_states=hidden_states,
                    position_ids=position_ids,
                    attention_mask=attention_mask,
                    output_attentions=output_attentions,
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states=hidden_states,
                    position_ids=position_ids,
                    attention_mask=attention_mask,
                    output_attentions=output_attentions,
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                )

            # NOTE: clear outdated cache after it has been used, for memory saving
            past_key_value = past_key_values[idx] = None
@@ -757,7 +838,40 @@ def forward(
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

-       hidden_states = self.norm(hidden_states)
        if self.config.num_nextn_predict_layers > 0:
            mtp_outputs.append(hidden_states)

            for nextn in range(self.config.num_nextn_predict_layers):
                decoder_layer = self.layers[nextn + self.config.num_hidden_layers]

                # Build this depth's input embeddings: shift one position further
                # right and pad with the reserved trailing embeddings.
                inputs_embeds_cur_depth = paddle.concat(
                    [inputs_embeds_ori[:, (nextn + 1) :, :], inputs_embeds_extra[:, : (nextn + 1), :]], axis=1
                )

                # Predict through this depth's MTP decoder layer.
                past_key_value = None
                layer_outputs = decoder_layer(
                    hidden_states,
                    inputs_embeds_cur_depth,
                    position_ids=position_ids,
                    attention_mask=attention_mask,
                    output_attentions=output_attentions,
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                )

                if isinstance(layer_outputs, (tuple, list)):
                    hidden_states = layer_outputs[0]
                else:
                    hidden_states = layer_outputs

                mtp_outputs.append(hidden_states)

            mtp_outputs = [self.norm(hidden_states) for hidden_states in mtp_outputs]
            hidden_states, mtp_outputs = mtp_outputs[0], mtp_outputs[1:]
        else:
            hidden_states = self.norm(hidden_states)
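For intuition, here is a standalone toy check (not from the PR) of the per-depth input shift used above, with batch 1, 6 tokens kept by the main stack, and 2 reserved MTP positions:

```python
import paddle

# num_nextn_predict_layers = 2; embedding dim collapsed to 1 for readability.
inputs_embeds_ori = paddle.arange(6, dtype="float32").reshape([1, 6, 1])  # tokens 0..5
inputs_embeds_extra = paddle.to_tensor([[[6.0], [7.0]]])                  # tokens 6..7

for nextn in range(2):
    cur = paddle.concat(
        [inputs_embeds_ori[:, (nextn + 1) :, :], inputs_embeds_extra[:, : (nextn + 1), :]], axis=1
    )
    print(nextn, cur.flatten().tolist())
# nextn=0 -> [1, 2, 3, 4, 5, 6]  (embeddings shifted right by one position)
# nextn=1 -> [2, 3, 4, 5, 6, 7]  (shifted by two), same length as the main sequence
```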

        # add hidden states from the last decoder layer
        if output_hidden_states:
Expand All @@ -766,7 +880,9 @@ def forward(
        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
-           return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
            return tuple(
                v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, mtp_outputs] if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
@@ -891,14 +1007,19 @@ def forward(
        )

        hidden_states = outputs[0]
        mtp_outputs = outputs[-1]

        # if labels is None, it means we need the full output instead of tensor_parallel_output;
        # tensor_parallel_output is used together with ParallelCrossEntropy
        tensor_parallel_output = self.config.tensor_parallel_output and self.config.tensor_parallel_degree > 1

        logits = self.lm_head(hidden_states, tensor_parallel_output=tensor_parallel_output)

-       return logits
        mtp_logits = [self.lm_head(_hidden_states) for _hidden_states in mtp_outputs] if len(mtp_outputs) > 0 else []
        loss = None
        if labels is not None:
            loss = self.criterion(logits, labels, mtp_logits=mtp_logits)
        return loss
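The MTP-aware criterion itself lives in `modeling.py`, whose diff is not rendered above. A plausible sketch of the weighted objective, assuming the main loss and the per-depth MTP losses are combined via `num_nextn_predict_lambda` (the function name and label handling here are illustrative, not the PR's):

```python
import paddle.nn.functional as F

def mtp_weighted_loss(logits, labels, mtp_logits, lambda_weight):
    """Illustrative only: combine the main LM loss with the mean of the
    per-depth MTP losses, scaled by num_nextn_predict_lambda."""
    def ce(lg, lb):
        # Flatten [B, S, V] logits and [B, S] labels for cross entropy.
        return F.cross_entropy(lg.reshape([-1, lg.shape[-1]]), lb.reshape([-1]))

    loss = ce(logits, labels)
    if mtp_logits:
        # Each depth k predicts tokens shifted k+1 positions to the right;
        # the corresponding label slicing is elided here for brevity.
        mtp_loss = sum(ce(lg, labels) for lg in mtp_logits) / len(mtp_logits)
        loss = loss + lambda_weight * mtp_loss
    return loss
```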

    def prepare_inputs_for_generation(
        self, input_ids, use_cache=False, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs