diff --git "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" index e1fe79626f..6b1aa988a6 100644 --- "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" +++ "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" @@ -560,6 +560,12 @@ |[rednote-hilab/dots.llm1.base](https://modelscope.cn/models/rednote-hilab/dots.llm1.base)|dots1|dots1|transformers>=4.53.0.dev0|✔|-|[rednote-hilab/dots.llm1.base](https://huggingface.co/rednote-hilab/dots.llm1.base)| |[rednote-hilab/dots.llm1.inst](https://modelscope.cn/models/rednote-hilab/dots.llm1.inst)|dots1|dots1|transformers>=4.53.0.dev0|✔|-|[rednote-hilab/dots.llm1.inst](https://huggingface.co/rednote-hilab/dots.llm1.inst)| |[Tencent-Hunyuan/Hunyuan-A13B-Instruct](https://modelscope.cn/models/Tencent-Hunyuan/Hunyuan-A13B-Instruct)|hunyuan|hunyuan|-|✘|-|[tencent/Hunyuan-A13B-Instruct](https://huggingface.co/tencent/Hunyuan-A13B-Instruct)| +|[PaddlePaddle/ERNIE-4.5-0.3B-Base-PT](https://modelscope.cn/models/PaddlePaddle/ERNIE-4.5-0.3B-Base-PT)|ernie|ernie|-|✔|-|[baidu/ERNIE-4.5-0.3B-Base-PT](https://huggingface.co/baidu/ERNIE-4.5-0.3B-Base-PT)| +|[PaddlePaddle/ERNIE-4.5-0.3B-PT](https://modelscope.cn/models/PaddlePaddle/ERNIE-4.5-0.3B-PT)|ernie|ernie|-|✔|-|[baidu/ERNIE-4.5-0.3B-PT](https://huggingface.co/baidu/ERNIE-4.5-0.3B-PT)| +|[PaddlePaddle/ERNIE-4.5-21B-A3B-Base-PT](https://modelscope.cn/models/PaddlePaddle/ERNIE-4.5-21B-A3B-Base-PT)|ernie|ernie|-|✔|-|[baidu/ERNIE-4.5-21B-A3B-Base-PT](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Base-PT)| 
+|[PaddlePaddle/ERNIE-4.5-21B-A3B-PT](https://modelscope.cn/models/PaddlePaddle/ERNIE-4.5-21B-A3B-PT)|ernie|ernie|-|✔|-|[baidu/ERNIE-4.5-21B-A3B-PT](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-PT)| +|[PaddlePaddle/ERNIE-4.5-300B-A47B-Base-PT](https://modelscope.cn/models/PaddlePaddle/ERNIE-4.5-300B-A47B-Base-PT)|ernie|ernie|-|✔|-|[baidu/ERNIE-4.5-300B-A47B-Base-PT](https://huggingface.co/baidu/ERNIE-4.5-300B-A47B-Base-PT)| +|[PaddlePaddle/ERNIE-4.5-300B-A47B-PT](https://modelscope.cn/models/PaddlePaddle/ERNIE-4.5-300B-A47B-PT)|ernie|ernie|-|✔|-|[baidu/ERNIE-4.5-300B-A47B-PT](https://huggingface.co/baidu/ERNIE-4.5-300B-A47B-PT)| |[answerdotai/ModernBERT-base](https://modelscope.cn/models/answerdotai/ModernBERT-base)|modern_bert|dummy|transformers>=4.48|✘|bert|[answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)| |[answerdotai/ModernBERT-large](https://modelscope.cn/models/answerdotai/ModernBERT-large)|modern_bert|dummy|transformers>=4.48|✘|bert|[answerdotai/ModernBERT-large](https://huggingface.co/answerdotai/ModernBERT-large)| |[iic/gte-modernbert-base](https://modelscope.cn/models/iic/gte-modernbert-base)|modern_bert_gte|dummy|transformers>=4.48|✘|bert, embedding|[Alibaba-NLP/gte-modernbert-base](https://huggingface.co/Alibaba-NLP/gte-modernbert-base)| diff --git a/docs/source_en/Instruction/Supported-models-and-datasets.md b/docs/source_en/Instruction/Supported-models-and-datasets.md index 106e5c8779..267d9e8122 100644 --- a/docs/source_en/Instruction/Supported-models-and-datasets.md +++ b/docs/source_en/Instruction/Supported-models-and-datasets.md @@ -560,6 +560,12 @@ The table below introduces the models integrated with ms-swift: |[rednote-hilab/dots.llm1.base](https://modelscope.cn/models/rednote-hilab/dots.llm1.base)|dots1|dots1|transformers>=4.53.0.dev0|✔|-|[rednote-hilab/dots.llm1.base](https://huggingface.co/rednote-hilab/dots.llm1.base)| 
|[rednote-hilab/dots.llm1.inst](https://modelscope.cn/models/rednote-hilab/dots.llm1.inst)|dots1|dots1|transformers>=4.53.0.dev0|✔|-|[rednote-hilab/dots.llm1.inst](https://huggingface.co/rednote-hilab/dots.llm1.inst)| |[Tencent-Hunyuan/Hunyuan-A13B-Instruct](https://modelscope.cn/models/Tencent-Hunyuan/Hunyuan-A13B-Instruct)|hunyuan|hunyuan|-|✘|-|[tencent/Hunyuan-A13B-Instruct](https://huggingface.co/tencent/Hunyuan-A13B-Instruct)| +|[PaddlePaddle/ERNIE-4.5-0.3B-Base-PT](https://modelscope.cn/models/PaddlePaddle/ERNIE-4.5-0.3B-Base-PT)|ernie|ernie|-|✔|-|[baidu/ERNIE-4.5-0.3B-Base-PT](https://huggingface.co/baidu/ERNIE-4.5-0.3B-Base-PT)| +|[PaddlePaddle/ERNIE-4.5-0.3B-PT](https://modelscope.cn/models/PaddlePaddle/ERNIE-4.5-0.3B-PT)|ernie|ernie|-|✔|-|[baidu/ERNIE-4.5-0.3B-PT](https://huggingface.co/baidu/ERNIE-4.5-0.3B-PT)| +|[PaddlePaddle/ERNIE-4.5-21B-A3B-Base-PT](https://modelscope.cn/models/PaddlePaddle/ERNIE-4.5-21B-A3B-Base-PT)|ernie|ernie|-|✔|-|[baidu/ERNIE-4.5-21B-A3B-Base-PT](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Base-PT)| +|[PaddlePaddle/ERNIE-4.5-21B-A3B-PT](https://modelscope.cn/models/PaddlePaddle/ERNIE-4.5-21B-A3B-PT)|ernie|ernie|-|✔|-|[baidu/ERNIE-4.5-21B-A3B-PT](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-PT)| +|[PaddlePaddle/ERNIE-4.5-300B-A47B-Base-PT](https://modelscope.cn/models/PaddlePaddle/ERNIE-4.5-300B-A47B-Base-PT)|ernie|ernie|-|✔|-|[baidu/ERNIE-4.5-300B-A47B-Base-PT](https://huggingface.co/baidu/ERNIE-4.5-300B-A47B-Base-PT)| +|[PaddlePaddle/ERNIE-4.5-300B-A47B-PT](https://modelscope.cn/models/PaddlePaddle/ERNIE-4.5-300B-A47B-PT)|ernie|ernie|-|✔|-|[baidu/ERNIE-4.5-300B-A47B-PT](https://huggingface.co/baidu/ERNIE-4.5-300B-A47B-PT)| |[answerdotai/ModernBERT-base](https://modelscope.cn/models/answerdotai/ModernBERT-base)|modern_bert|dummy|transformers>=4.48|✘|bert|[answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)| 
|[answerdotai/ModernBERT-large](https://modelscope.cn/models/answerdotai/ModernBERT-large)|modern_bert|dummy|transformers>=4.48|✘|bert|[answerdotai/ModernBERT-large](https://huggingface.co/answerdotai/ModernBERT-large)| |[iic/gte-modernbert-base](https://modelscope.cn/models/iic/gte-modernbert-base)|modern_bert_gte|dummy|transformers>=4.48|✘|bert, embedding|[Alibaba-NLP/gte-modernbert-base](https://huggingface.co/Alibaba-NLP/gte-modernbert-base)| diff --git a/swift/llm/model/constant.py b/swift/llm/model/constant.py index f35ac297d9..6f22e04bde 100644 --- a/swift/llm/model/constant.py +++ b/swift/llm/model/constant.py @@ -120,6 +120,7 @@ class LLMModelType: mimo_rl = 'mimo_rl' dots1 = 'dots1' hunyuan = 'hunyuan' + ernie = 'ernie' class BertModelType: diff --git a/swift/llm/model/model/__init__.py b/swift/llm/model/model/__init__.py index 589738abc7..97021eb65d 100644 --- a/swift/llm/model/model/__init__.py +++ b/swift/llm/model/model/__init__.py @@ -1,3 +1,3 @@ -from . import (baai, baichuan, bert, codefuse, deepseek, gemma, glm, internlm, llama, llava, llm, mamba, microsoft, - minicpm, minimax, mistral, mllm, moonshot, mplug, openbuddy, qwen, skywork, stepfun, telechat, valley, - yi) +from . import (baai, baichuan, baidu, bert, codefuse, deepseek, gemma, glm, internlm, llama, llava, llm, mamba, + microsoft, minicpm, minimax, mistral, mllm, moonshot, mplug, openbuddy, qwen, skywork, stepfun, telechat, + valley, yi) diff --git a/swift/llm/model/model/baidu.py b/swift/llm/model/model/baidu.py new file mode 100644 index 0000000000..a472d24bc0 --- /dev/null +++ b/swift/llm/model/model/baidu.py @@ -0,0 +1,27 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+from swift.llm import TemplateType +from swift.utils import get_logger +from ..constant import LLMModelType +from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model + +logger = get_logger() + +register_model( + ModelMeta( + LLMModelType.ernie, + [ + ModelGroup([ + Model('PaddlePaddle/ERNIE-4.5-0.3B-Base-PT', 'baidu/ERNIE-4.5-0.3B-Base-PT'), + Model('PaddlePaddle/ERNIE-4.5-0.3B-PT', 'baidu/ERNIE-4.5-0.3B-PT'), + ]), + ModelGroup([ + Model('PaddlePaddle/ERNIE-4.5-21B-A3B-Base-PT', 'baidu/ERNIE-4.5-21B-A3B-Base-PT'), + Model('PaddlePaddle/ERNIE-4.5-21B-A3B-PT', 'baidu/ERNIE-4.5-21B-A3B-PT'), + Model('PaddlePaddle/ERNIE-4.5-300B-A47B-Base-PT', 'baidu/ERNIE-4.5-300B-A47B-Base-PT'), + Model('PaddlePaddle/ERNIE-4.5-300B-A47B-PT', 'baidu/ERNIE-4.5-300B-A47B-PT'), + ]), + ], + TemplateType.ernie, + get_model_tokenizer_with_flash_attn, + architectures=['Ernie4_5_ForCausalLM', 'Ernie4_5_MoeForCausalLM'], + )) diff --git a/swift/llm/template/constant.py b/swift/llm/template/constant.py index b50d8aa1e9..23ceb13429 100644 --- a/swift/llm/template/constant.py +++ b/swift/llm/template/constant.py @@ -86,6 +86,7 @@ class LLMTemplateType: mimo_rl = 'mimo_rl' dots1 = 'dots1' hunyuan = 'hunyuan' + ernie = 'ernie' aya = 'aya' c4ai = 'c4ai' diff --git a/swift/llm/template/template/__init__.py b/swift/llm/template/template/__init__.py index 9a159dfec5..792a3faa5e 100644 --- a/swift/llm/template/template/__init__.py +++ b/swift/llm/template/template/__init__.py @@ -1,2 +1,3 @@ -from . import (bert, deepseek, emu3, gemma, glm, idefics3, internlm, internvl, llama, llava, llm, megrez, microsoft, - minicpm, minimax, mistral, molmo, moonshot, mplug, openbuddy, pixtral, qwen, stepfun, valley, yi) +from . 
import (baidu, bert, deepseek, emu3, gemma, glm, idefics3, internlm, internvl, llama, llava, llm, megrez, + microsoft, minicpm, minimax, mistral, molmo, moonshot, mplug, openbuddy, pixtral, qwen, stepfun, valley, + yi) diff --git a/swift/llm/template/template/baidu.py b/swift/llm/template/template/baidu.py new file mode 100644 index 0000000000..b3b3d45e3b --- /dev/null +++ b/swift/llm/template/template/baidu.py @@ -0,0 +1,19 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from dataclasses import dataclass, field +from typing import Optional + +from ..constant import LLMTemplateType +from ..register import TemplateMeta, register_template +from ..utils import Prompt + + +@dataclass +class ERNIETemplateMeta(TemplateMeta): + prefix: Prompt = field(default_factory=lambda: ['<|begin_of_sentence|>']) + prompt: Prompt = field(default_factory=lambda: ['User: {{QUERY}}\nAssistant: ']) + chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|end_of_sentence|>']) + suffix: Prompt = field(default_factory=lambda: ['']) + system_prefix: Optional[Prompt] = field(default_factory=lambda: ['<|begin_of_sentence|>{{SYSTEM}}\n']) + + +register_template(ERNIETemplateMeta(LLMTemplateType.ernie)) diff --git a/swift/llm/template/template/deepseek.py b/swift/llm/template/template/deepseek.py index 882f1c085e..a0adcb3619 100644 --- a/swift/llm/template/template/deepseek.py +++ b/swift/llm/template/template/deepseek.py @@ -1,5 +1,4 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
-import os from dataclasses import dataclass, field from typing import Any, Dict, List, Optional diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 75ce390bfd..9fafd9e0e9 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -135,6 +135,7 @@ class MegatronArguments(ExtraMegatronArguments): position_embedding_type: Literal['learned_absolute', 'rope', 'mrope', 'relative', 'none'] = 'rope' rotary_base: Optional[int] = None rotary_percent: float = 1. + rotary_interleaved: Optional[bool] = None normalization: Literal['LayerNorm', 'RMSNorm'] = 'RMSNorm' norm_epsilon: Optional[float] = None swiglu: Optional[bool] = None @@ -228,6 +229,8 @@ def _set_default(self): self.norm_epsilon = 1e-5 if self.rotary_base is None: self.rotary_base = 10000 + if self.rotary_interleaved is None: + self.rotary_interleaved = False if self.attention_dropout is None: self.attention_dropout = 0. if self.untie_embeddings_and_output_weights is None: diff --git a/swift/megatron/model/config.py b/swift/megatron/model/config.py index 9c6d0bf02a..5b58a2fc7c 100644 --- a/swift/megatron/model/config.py +++ b/swift/megatron/model/config.py @@ -17,15 +17,15 @@ 'attention_dropout': ['attention_dropout'], 'untie_embeddings_and_output_weights': ['tie_word_embeddings'], 'swiglu': ['hidden_act'], - 'add_qkv_bias': ['attention_bias', 'qkv_bias'], + 'add_qkv_bias': ['attention_bias', 'qkv_bias', 'use_bias'], 'disable_bias_linear': ['mlp_bias'], 'kv_channels': ['head_dim', 'v_head_dim'], 'architectures': ['architectures'], # moe 'moe_ffn_hidden_size': ['moe_intermediate_size'], 'moe_shared_expert_intermediate_size': ['shared_expert_intermediate_size'], - 'moe_router_topk': ['num_experts_per_tok', 'n_group', 'moe_topk'], - 'num_experts': ['num_experts', 'n_routed_experts'], + 'moe_router_topk': ['num_experts_per_tok', 'n_group', 'moe_topk', 'moe_k'], + 'num_experts': ['num_experts', 'n_routed_experts', 
'moe_num_experts'], 'moe_router_pre_softmax': ['norm_topk_prob'], 'moe_aux_loss_coeff': ['router_aux_loss_coef'], # deepseek @@ -39,8 +39,8 @@ # other 'original_max_position_embeddings': ['original_max_position_embeddings'], 'partial_rotary_factor': ['partial_rotary_factor'], - 'first_k_dense_replace': ['first_k_dense_replace'], - 'n_shared_experts': ['n_shared_experts', 'num_shared_expert'], + 'first_k_dense_replace': ['first_k_dense_replace', 'moe_layer_start_index'], + 'n_shared_experts': ['n_shared_experts', 'num_shared_expert', 'moe_num_shared_experts'], } diff --git a/swift/megatron/model/gpt/__init__.py b/swift/megatron/model/gpt/__init__.py index 92d28ebfb3..ef30d84bf8 100644 --- a/swift/megatron/model/gpt/__init__.py +++ b/swift/megatron/model/gpt/__init__.py @@ -46,4 +46,5 @@ ModelType.deepseek_v2_5, ModelType.deepseek_r1, ModelType.dots1, + ModelType.ernie, ], model_provider, convert_gpt_hf_config, convert_mcore2hf, convert_hf2mcore)) diff --git a/swift/megatron/model/gpt/config.py b/swift/megatron/model/gpt/config.py index 652372deb2..2b0cc7151c 100644 --- a/swift/megatron/model/gpt/config.py +++ b/swift/megatron/model/gpt/config.py @@ -24,7 +24,6 @@ def convert_gpt_hf_config(config) -> Dict[str, Any]: res['qk_layernorm'] = True res['moe_router_load_balancing_type'] = 'seq_aux_loss' res.pop('num_query_groups', None) # https://github.com/NVIDIA/Megatron-LM/issues/1475 - res['moe_shared_expert_intermediate_size'] = n_shared_experts * res['moe_ffn_hidden_size'] if architectures == 'Dots1ForCausalLM': res['moe_router_score_function'] = 'sigmoid' if res.get('moe_router_score_function', 'softmax') == 'sigmoid': @@ -39,5 +38,10 @@ def convert_gpt_hf_config(config) -> Dict[str, Any]: if isinstance(val, list) and val and min(val) == max(val): res[key] = val[0] n_shared_experts = res.pop('n_shared_experts') + if architectures in {'Ernie4_5_ForCausalLM', 'Ernie4_5_MoeForCausalLM'}: + res['rotary_interleaved'] = True + if architectures == 'Ernie4_5_MoeForCausalLM': 
+ res['moe_layer_freq'] = f'[0]*{first_k_dense_replace}+[1]*{res["num_layers"] - first_k_dense_replace}' + if n_shared_experts is not None and 'moe_shared_expert_intermediate_size' not in res: res['moe_shared_expert_intermediate_size'] = n_shared_experts * res['moe_ffn_hidden_size'] return res diff --git a/tests/megatron/test_align/test_llm.py b/tests/megatron/test_align/test_llm.py index 4b6a48f748..92088c1ca7 100644 --- a/tests/megatron/test_align/test_llm.py +++ b/tests/megatron/test_align/test_llm.py @@ -1,8 +1,6 @@ import os -import torch - -os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7' +os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3' def _test_model(model_id, **kwargs): @@ -120,6 +118,11 @@ def test_hunyuan(): _test_model('Tencent-Hunyuan/Hunyuan-A13B-Instruct') +def test_ernie(): + # _test_model('PaddlePaddle/ERNIE-4.5-0.3B-PT') + _test_model('PaddlePaddle/ERNIE-4.5-21B-A3B-PT') + + if __name__ == '__main__': # test_qwen2() # test_llama2() @@ -142,4 +145,5 @@ def test_hunyuan(): # test_deepseek_moe() # test_dots() # test_kimi_dev() - test_hunyuan() + # test_hunyuan() + test_ernie() diff --git a/tests/test_align/test_template/test_llm.py b/tests/test_align/test_template/test_llm.py index 04fee44ad6..a3aa917f01 100644 --- a/tests/test_align/test_template/test_llm.py +++ b/tests/test_align/test_template/test_llm.py @@ -438,6 +438,14 @@ def test_hunyuan(): assert res == res2, f'res: {res}, res2: {res2}' +def test_ernie(): + pt_engine = PtEngine('PaddlePaddle/ERNIE-4.5-0.3B-PT') + res = _infer_model(pt_engine) + pt_engine.default_template.template_backend = 'jinja' + res2 = _infer_model(pt_engine) + assert res == res2, f'res: {res}, res2: {res2}' + + if __name__ == '__main__': from swift.llm import PtEngine, RequestConfig from swift.utils import get_logger, seed_everything @@ -480,4 +488,5 @@ def test_hunyuan(): # test_minicpm() # test_minimax() # test_kimi_dev() - test_hunyuan() + # test_hunyuan() + test_ernie()