
[Mixtral] Add mixtral moe #7803


Merged: 7 commits, merged Mar 1, 2024
6 changes: 3 additions & 3 deletions llm/data.py
@@ -44,11 +44,11 @@ def get_convert_example(model):

     if base_model_prefix == "chatglm":
         return convert_example_chatglm
-    elif base_model_prefix in ["chatglm_v2", "llama", "bloom", "opt", "qwen"]:
+    elif base_model_prefix in ["chatglm_v2", "llama", "bloom", "opt", "qwen", "mixtral"]:
         return convert_example_common
     else:
         raise ValueError(
-            f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama."
+            f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral"
         )


@@ -107,7 +107,7 @@ def tokenize_rounds_example(tokenizer, example, data_args):
    # 0. prepare data
    context_data = example.get("context", {})
    context_data["is_training"] = True

    example["src"] = example["src"] if isinstance(example["src"], list) else [example["src"]]
    example["tgt"] = example["tgt"] if isinstance(example["tgt"], list) else [example["tgt"]]

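For orientation, here is a minimal usage sketch (not part of this diff; it assumes llm/data.py is importable and borrows the checkpoint name from the configs below): `get_convert_example` dispatches on `model.base_model_prefix`, so a Mixtral model now resolves to `convert_example_common`.

```python
from paddlenlp.transformers import AutoModelForCausalLM

from data import get_convert_example  # llm/data.py, assumed importable

model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
assert model.base_model_prefix == "mixtral"
convert_fn = get_convert_example(model)  # resolves to convert_example_common
```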
32 changes: 32 additions & 0 deletions llm/mixtral/lora_argument.json
@@ -0,0 +1,32 @@
{
"model_name_or_path": "mistralai/Mixtral-8x7B-Instruct-v0.1",
[Inline review thread on this file]
Reviewer (Contributor): Please also add an SFT training script.
Author (Contributor): done

"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/mixtral_lora_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 3,
"learning_rate": 3e-04,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 1024,
"max_length": 2048,
"fp16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 8,
"pipeline_parallel_degree": 1,
"lora": true,
"zero_padding": false,
"use_flash_attention": false
}
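As a sanity check, the batch-size arithmetic this config implies (a sketch, assuming a single 8-GPU node, so all eight GPUs form one tensor-parallel group and there is exactly one data-parallel replica):

```python
# Values from lora_argument.json; num_gpus = 8 is an assumption.
per_device_train_batch_size = 4
gradient_accumulation_steps = 4
tensor_parallel_degree = 8
num_gpus = 8

# Ranks in one tensor-parallel group share a single data-parallel replica.
data_parallel_degree = num_gpus // tensor_parallel_degree  # 1
global_batch_size = per_device_train_batch_size * gradient_accumulation_steps * data_parallel_degree
print(global_batch_size)  # 16 sequences per optimizer step
```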
30 changes: 30 additions & 0 deletions llm/mixtral/sft_argument.json
@@ -0,0 +1,30 @@
{
    "model_name_or_path": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "dataset_name_or_path": "./data",
    "output_dir": "./checkpoints/mixtral_sft_ckpts",
    "per_device_train_batch_size": 4,
    "gradient_accumulation_steps": 4,
    "per_device_eval_batch_size": 8,
    "eval_accumulation_steps": 16,
    "num_train_epochs": 3,
    "learning_rate": 3e-05,
    "warmup_steps": 30,
    "logging_steps": 1,
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "src_length": 1024,
    "max_length": 2048,
    "bf16": true,
    "fp16_opt_level": "O2",
    "do_train": true,
    "do_eval": true,
    "disable_tqdm": true,
    "load_best_model_at_end": true,
    "eval_with_do_generation": false,
    "metric_for_best_model": "accuracy",
    "recompute": true,
    "save_total_limit": 1,
    "tensor_parallel_degree": 8,
    "sharding": "stage2",
    "pipeline_parallel_degree": 1
}
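Relative to the LoRA config, this SFT config drops `"lora": true`, lowers the learning rate from 3e-04 to 3e-05, switches `fp16` to `bf16`, and adds `"sharding": "stage2"`. A sketch of the resource math it implies (assuming a single machine):

```python
# Values from sft_argument.json.
tensor_parallel_degree = 8
pipeline_parallel_degree = 1

# Every model replica spans tp * pp devices, so at least 8 GPUs are needed.
min_gpus = tensor_parallel_degree * pipeline_parallel_degree  # 8

# With "sharding": "stage2", optimizer states and gradients are partitioned
# across whichever data-parallel replicas remain (num_gpus // min_gpus).
```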
10 changes: 10 additions & 0 deletions llm/utils.py
@@ -149,6 +149,16 @@ def get_lora_target_modules(model):
".*mlp.w2.*",
".*mlp.c_proj.*",
]
elif model.base_model_prefix == "mixtral":
target_modules = [
".*q_proj.*",
".*k_proj.*",
".*v_proj.*",
".*o_proj.*",
".*w1.*",
".*w2.*",
".*w3.*",
]
else:
raise ValueError(f"Unknown base_model_prefix: {model.base_model_prefix}.")
return target_modules
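A hedged sketch of how this list is consumed (the `paddlenlp.peft` wiring is standard for this repo, but the rank and alpha values here are illustrative, not taken from the PR): each regex matches the attention projections plus the per-expert `w1`/`w2`/`w3` matrices, and only matched sublayers receive LoRA weights.

```python
from paddlenlp.peft import LoRAConfig, LoRAModel
from paddlenlp.transformers import AutoModelForCausalLM

from utils import get_lora_target_modules  # llm/utils.py, assumed importable

model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
lora_config = LoRAConfig(
    target_modules=get_lora_target_modules(model),  # the mixtral branch above
    r=8,            # illustrative rank
    lora_alpha=16,  # illustrative scaling
)
model = LoRAModel(model, lora_config)
```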
2 changes: 2 additions & 0 deletions paddlenlp/transformers/__init__.py
@@ -281,6 +281,8 @@
 from .rw.configuration import *
 from .rw.tokenizer import *
 from .qwen import *
+from .mixtral.modeling import *
+from .mixtral.configuration import *

 # For faster tokenizer
 from ..utils.import_utils import is_fast_tokenizer_available
1 change: 1 addition & 0 deletions paddlenlp/transformers/auto/modeling.py
@@ -127,6 +127,7 @@
("Blip", "blip"),
("Bloom", "bloom"),
("QWen", "qwen"),
("Mixtral", "mixtral"),
]
)

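With the ("Mixtral", "mixtral") entry registered, the Auto classes can resolve Mixtral checkpoints by name alone. A minimal sketch (the `dtype` keyword follows common PaddleNLP usage and is an assumption here, not something this diff shows):

```python
from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
print(config.model_type)  # "mixtral"

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    dtype="bfloat16",  # assumption: matches other PaddleNLP causal LMs
)
```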
16 changes: 16 additions & 0 deletions paddlenlp/transformers/mixtral/__init__.py
@@ -0,0 +1,16 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .configuration import MixtralConfig
from .modeling import MixtralForCausalLM
191 changes: 191 additions & 0 deletions paddlenlp/transformers/mixtral/configuration.py
@@ -0,0 +1,191 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Mixtral model configuration"""

from paddlenlp.transformers.configuration_utils import PretrainedConfig

__all__ = [
"MixtralConfig",
]


class MixtralConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`~MixtralModel`]. It is used to instantiate a Mixtral
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a configuration similar to that of Mixtral-8x7B-v0.1 or Mixtral-8x7B-Instruct-v0.1.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the Mixtral model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`~MixtralModel`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 14336):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
            The maximum sequence length that this model might ever be used with. Mixtral's sliding window attention
            allows sequences of up to 4096*32 tokens.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            The id of the padding token.
        bos_token_id (`int`, *optional*, defaults to 1):
            The id of the "beginning-of-sequence" token.
        eos_token_id (`int`, *optional*, defaults to 2):
            The id of the "end-of-sequence" token.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        rope_theta (`float`, *optional*, defaults to 1000000.0):
            The base period of the RoPE embeddings.
        sliding_window (`int`, *optional*):
            Sliding window attention window size.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        num_experts_per_tok (`int`, *optional*, defaults to 2):
            The number of experts to route each token to; can also be interpreted as the `top-k` routing
            parameter.
        num_local_experts (`int`, *optional*, defaults to 8):
            Number of experts per sparse MLP layer.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary load-balancing loss.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            The weighting factor of the auxiliary router loss in the total loss.
        use_fused_rope (`bool`, *optional*, defaults to `False`):
            Whether to use the fused RoPE kernel.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by mean-pooling all the original heads within that group. For more details check out [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to
            `num_attention_heads`.
    Example:
    ```python
    >>> from paddlenlp.transformers import MixtralModel, MixtralConfig

    >>> # Initializing a Mixtral-8x7B style configuration
    >>> configuration = MixtralConfig()

    >>> # Initializing a model from the Mixtral-8x7B style configuration
    >>> model = MixtralModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

model_type = "mixtral"
keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=14336,
        max_position_embeddings=4096 * 32,
        seq_length=2048,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        use_recompute=False,
        recompute_granularity="full",
        no_recompute_layers=None,
        use_flash_attention=False,
        attention_dropout=0.0,
        use_fused_rope=False,
        rope_theta=1e6,
        tensor_parallel_output=True,
        sequence_parallel=False,
        fuse_sequence_parallel_allreduce=False,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        num_experts_per_tok=2,
        num_local_experts=8,
        router_aux_loss_coef=0.001,
        output_router_logits=False,
        sliding_window=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.seq_length = seq_length
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.attention_dropout = attention_dropout

        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act

        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps

        self.use_cache = use_cache
        self.use_recompute = use_recompute
        self.recompute_granularity = recompute_granularity
        self.no_recompute_layers = no_recompute_layers
        self.use_flash_attention = use_flash_attention
        self.tensor_parallel_output = tensor_parallel_output
        self.sequence_parallel = sequence_parallel
        self.fuse_sequence_parallel_allreduce = fuse_sequence_parallel_allreduce

        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        self.use_fused_rope = use_fused_rope
        self.rope_theta = rope_theta

        # ----------------- Experts -------------------- #
        self.num_experts_per_tok = num_experts_per_tok
        self.num_local_experts = num_local_experts
        self.router_aux_loss_coef = router_aux_loss_coef
        self.output_router_logits = output_router_logits

        self.sliding_window = sliding_window

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            tensor_parallel_output=tensor_parallel_output,
            **kwargs,
        )
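To make the defaults concrete, a short sketch of the shape arithmetic they imply (derived only from the constructor defaults above):

```python
from paddlenlp.transformers import MixtralConfig

cfg = MixtralConfig()

head_dim = cfg.hidden_size // cfg.num_attention_heads  # 4096 // 32 = 128
# GQA: 32 query heads share 8 KV heads, i.e. 4 query heads per KV head.
kv_groups = cfg.num_attention_heads // cfg.num_key_value_heads  # 4

# MoE routing: 8 experts per layer, each token dispatched to the top 2.
assert (cfg.num_local_experts, cfg.num_experts_per_tok) == (8, 2)
```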