generated from datawhalechina/repo-template
-
Notifications
You must be signed in to change notification settings - Fork 326
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #45 from 0-yy-0/main
Model_Architecture_Discussions 新增 MiniCPM
- Loading branch information
Showing
12 changed files
with
297,006 additions
and
0 deletions.
There are no files selected for viewing
1,176 changes: 1,176 additions & 0 deletions
1,176
Model_Architecture_Discussions/MiniCPM/MiniCPM.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
308 changes: 308 additions & 0 deletions
308
Model_Architecture_Discussions/MiniCPM/MiniCPMTest.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,308 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"用我们搭建的模型尝试读取官方权重并预测" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/home/jeeves/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", | ||
" from .autonotebook import tqdm as notebook_tqdm\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import json\n", | ||
"import torch\n", | ||
"from transformers import AutoTokenizer, AutoModelForCausalLM\n", | ||
"from configuration_minicpm import MiniCPMConfig\n", | ||
"from MiniCPM import MiniCPMForCausalLM\n", | ||
"import logging\n", | ||
"import gc\n", | ||
"\n", | ||
"# 配置日志\n", | ||
"logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"加载模型 config" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"config_json = json.load(open(\"/data/workspace/llms-from-scratch-cn/Model_Architecture_Discussions/MiniCPM/config.json\"))\n", | ||
"config = MiniCPMConfig(**config_json)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"按照 config 初始化模型,并查看模型结构" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"2024-07-25 15:57:28,490 - INFO - 初始化模型\n", | ||
"2024-07-25 15:57:50,064 - INFO - 模型:\n", | ||
": MiniCPMForCausalLM(\n", | ||
" (model): MiniCPMModel(\n", | ||
" (embed_tokens): Embedding(122753, 2304)\n", | ||
" (layers): ModuleList(\n", | ||
" (0-39): 40 x MiniCPMDecoderLayer(\n", | ||
" (self_attn): MiniCPMAttention(\n", | ||
" (q_proj): Linear(in_features=2304, out_features=2304, bias=False)\n", | ||
" (k_proj): Linear(in_features=2304, out_features=2304, bias=False)\n", | ||
" (v_proj): Linear(in_features=2304, out_features=2304, bias=False)\n", | ||
" (o_proj): Linear(in_features=2304, out_features=2304, bias=False)\n", | ||
" (rotary_emb): MiniCPMRotaryEmbedding()\n", | ||
" )\n", | ||
" (mlp): MiniCPMMLP(\n", | ||
" (gate_proj): Linear(in_features=2304, out_features=5760, bias=False)\n", | ||
" (up_proj): Linear(in_features=2304, out_features=5760, bias=False)\n", | ||
" (down_proj): Linear(in_features=5760, out_features=2304, bias=False)\n", | ||
" (act_fn): SiLU()\n", | ||
" )\n", | ||
" (input_layernorm): MiniCPMRMSNorm()\n", | ||
" (post_attention_layernorm): MiniCPMRMSNorm()\n", | ||
" )\n", | ||
" )\n", | ||
" (norm): MiniCPMRMSNorm()\n", | ||
" )\n", | ||
" (lm_head): Linear(in_features=2304, out_features=122753, bias=False)\n", | ||
")\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"\n", | ||
"try:\n", | ||
" logging.info(\"初始化模型\")\n", | ||
" model = MiniCPMForCausalLM(config=config).to('cuda')\n", | ||
" logging.info(\"模型:\\n: %s\", model)\n", | ||
"except Exception as e:\n", | ||
" logging.error(f\"初始化模型时发生错误: {e}\")\n", | ||
" raise" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"读取模型权重" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"2024-07-25 15:57:50,086 - INFO - 加载模型权重\n", | ||
"2024-07-25 15:57:52,515 - INFO - 加载模型权重完成。\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"缺失的参数名: ['lm_head.weight']\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"\n", | ||
"path = \"/data/model/OpenBMB/MiniCPM-2B-dpo-bf16\"\n", | ||
"\n", | ||
"try:\n", | ||
" logging.info(\"加载模型权重\")\n", | ||
" params = torch.load(\n", | ||
" f=path + \"/pytorch_model.bin\",\n", | ||
" map_location=torch.device('cuda'),\n", | ||
" weights_only=True, # 设置为True表示仅加载模型的权重。这通常用于加载预训练权重进行微调或预测,而不需要完整的模型结构\n", | ||
" mmap=True # 使用内存映射方式加载模型文件,这可以提高加载大型模型文件的效率,特别是在有限的内存资源下\n", | ||
" )\n", | ||
" # 打印出模型参数和params中不一致的参数名\n", | ||
" missing_keys, unexpected_keys = model.load_state_dict(params, strict=False)\n", | ||
" # 打印缺失的参数名\n", | ||
" if missing_keys:\n", | ||
" print(\"缺失的参数名:\", missing_keys)\n", | ||
"\n", | ||
" # 打印多余的参数名\n", | ||
" if unexpected_keys:\n", | ||
" print(\"多余的参数名:\", unexpected_keys)\n", | ||
" # modelV1 = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16, device_map='cuda', trust_remote_code=True)\n", | ||
" # 手动实现 tie embedding 即输入输出共享一个 Embedding\n", | ||
" model.get_output_embeddings().weight = model.get_input_embeddings().weight\n", | ||
" del params\n", | ||
" gc.collect()\n", | ||
" logging.info(\"加载模型权重完成。\")\n", | ||
"except Exception as e:\n", | ||
" logging.error(f\"加载模型权重时发生错误: {e}\")\n", | ||
" raise" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"MiniCPM 采用了 tie-Embedding 的方式,即词嵌入层和输出层共享参数。这种方式可以减少模型的参数量,提高模型的训练效率。所以需要有获取和设置输入输出词嵌入层的方法。\n", | ||
"我们可以看到在加载权重时缺失 `lm_head.weight` 的参数,这里我们通过手动设置 `model.get_output_embeddings().weight = model.get_input_embeddings().weight` 来共享参数。" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"使用默认的 tokenizer 分词" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 11, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"2024-07-25 16:01:12,360 - INFO - 初始化分词器\n", | ||
"2024-07-25 16:01:12,557 - INFO - 生成文本\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"logging.info(\"初始化分词器\")\n", | ||
"tokenizer = AutoTokenizer.from_pretrained(\"/data/model/OpenBMB/MiniCPM-2B-dpo-bf16/\")\n", | ||
"\n", | ||
"logging.info(\"生成文本\")\n", | ||
"input_texts = [\"北京最高的山是哪座山?\", \"山东省最长的山是哪座山?\" ]\n", | ||
"\n", | ||
"tokenizer.pad_token_id=tokenizer.eos_token_id\n", | ||
"\n", | ||
"inputs = tokenizer(input_texts, padding=True, return_tensors=\"pt\").to('cuda')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"可以看出 MiniCPM 采用 tokenizer 为 `LlamaTokenizerFast`, 词表大小为 122753 个 token。" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 12, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"LlamaTokenizerFast(name_or_path='/data/model/OpenBMB/MiniCPM-2B-dpo-bf16/', vocab_size=122753, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '</s>'}, clean_up_tokenization_spaces=False), added_tokens_decoder={\n", | ||
"\t0: AddedToken(\"<unk>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n", | ||
"\t1: AddedToken(\"<s>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n", | ||
"\t2: AddedToken(\"</s>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n", | ||
"}\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"print(tokenizer)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"我们让模型输出结果看看" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 14, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"2024-07-25 16:01:36,687 - INFO - 生成结果: 北京最高的山是哪座山?\n", | ||
" 北京最高的山是香山。香山位于北京市海淀区,距离北京市中心约25公里,海拔572米。香山是北京市内最高峰\n", | ||
"2024-07-25 16:01:36,687 - INFO - 生成结果: 山东省最长的山是哪座山?\n", | ||
" 目前,山东省最长的山是泰山。泰山,位于山东省中部,是五岳之一,也是中国著名的山脉之一。泰山是中国著名的山脉之一\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"generate_input = {\n", | ||
" \"input_ids\": inputs.input_ids,\n", | ||
" \"attention_mask\": inputs.attention_mask,\n", | ||
" \"max_new_tokens\": 32,\n", | ||
" \"temperature\": 1,\n", | ||
" \"tokenizer\": tokenizer,\n", | ||
"}\n", | ||
"model.eval()\n", | ||
"outputs = model.generate(**generate_input)\n", | ||
"for output in outputs:\n", | ||
" result = tokenizer.decode(output, skip_special_tokens=True)\n", | ||
" logging.info(f\"生成结果: {result}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"可以看出输出的结果还可以" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.12" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.