generated from datawhalechina/repo-template
-
Notifications
You must be signed in to change notification settings - Fork 326
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
daade3d
commit 894b20d
Showing
152 changed files
with
3,408 additions
and
55,092 deletions.
There are no files selected for viewing
File renamed without changes.
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes.
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,290 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "a5d6ad19-adff-423b-8177-30a0d4f6ceed", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# import accelerate" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "02e86302-68c0-455a-9457-6a4618b95cba", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"env: HF_ENDPOINT=https://hf-mirror.com\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"%env HF_ENDPOINT=https://hf-mirror.com\n", | ||
"import os\n", | ||
"os.environ['HF_HOME'] = '/data1/ckw'" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "474af3e5-820c-4d8f-9ee0-e320edc3ccee", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"os.environ['HF_ENDPOINT']='https://hf-mirror.com'" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "e46a03d4-186e-46ee-b9e2-e2826c06e3e4", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"/data1/ckw/01大语言模型/ChatGLM4\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"!pwd" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "62284871-67f4-47a2-941f-46befd2032b7", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import torch\n", | ||
"# from transformers import AutoModelForCausalLM, AutoTokenizer\n", | ||
"from tokenization_chatglm import ChatGLM4Tokenizer\n", | ||
"from modeling_chatglm import ChatGLMForConditionalGeneration\n", | ||
"\n", | ||
"device = \"cuda\"\n", | ||
"\n", | ||
"tokenizer = ChatGLM4Tokenizer.from_pretrained(\"THUDM/glm-4-9b-chat\", trust_remote_code=True)\n", | ||
"\n", | ||
"query = \"你好\"\n", | ||
"\n", | ||
"inputs = tokenizer.apply_chat_template([{\"role\": \"user\", \"content\": query}],\n", | ||
" add_generation_prompt=True,\n", | ||
" tokenize=True,\n", | ||
" return_tensors=\"pt\",\n", | ||
" return_dict=True\n", | ||
" )\n", | ||
"\n", | ||
"inputs = inputs.to(device)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "5315c9b3-7fc4-4291-a9b1-6bcd92b4d940", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n", | ||
"The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "e68152ca9a1a462d8b96e1e31e7c582b", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"Loading checkpoint shards: 0%| | 0/10 [00:00<?, ?it/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\n", | ||
"你好👋!有什么可以帮助你的吗?\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"model = ChatGLMForConditionalGeneration.from_pretrained(\n", | ||
" \"THUDM/glm-4-9b-chat\",\n", | ||
" torch_dtype=torch.bfloat16,\n", | ||
" low_cpu_mem_usage=True,\n", | ||
" trust_remote_code=True,\n", | ||
" load_in_4bit=True,\n", | ||
" device_map='auto'\n", | ||
").eval()#.to(device)\n", | ||
"\n", | ||
"gen_kwargs = {\"max_length\": 2500, \"do_sample\": True, \"top_k\": 1}\n", | ||
"with torch.no_grad():\n", | ||
" outputs = model.generate(**inputs, **gen_kwargs)\n", | ||
" outputs = outputs[:, inputs['input_ids'].shape[1]:]\n", | ||
" print(tokenizer.decode(outputs[0], skip_special_tokens=True))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "9e6f3cf8-2032-4b0a-be34-6e7f4409a01a", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from transformers.utils import is_accelerate_available" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"id": "1466e0e6-7d99-4962-82c9-4fff85b7b3d4", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"False" | ||
] | ||
}, | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"is_accelerate_available()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"id": "5033e32b-9daf-4b8a-a239-da883414936a", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"\u001b[0;31mSignature:\u001b[0m \u001b[0mis_accelerate_available\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmin_version\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'0.21.0'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | ||
"\u001b[0;31mDocstring:\u001b[0m <no docstring>\n", | ||
"\u001b[0;31mSource:\u001b[0m \n", | ||
"\u001b[0;32mdef\u001b[0m \u001b[0mis_accelerate_available\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmin_version\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mACCELERATE_MIN_VERSION\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", | ||
"\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0m_accelerate_available\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mversion\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_accelerate_version\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0mversion\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmin_version\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | ||
"\u001b[0;31mFile:\u001b[0m /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages/transformers/utils/import_utils.py\n", | ||
"\u001b[0;31mType:\u001b[0m function" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
} | ||
], | ||
"source": [ | ||
"??is_accelerate_available" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"id": "9fef3754-bbe9-4783-877f-b747421ada1e", | ||
"metadata": { | ||
"scrolled": true | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n", | ||
"Requirement already satisfied: accelerate in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (0.20.3)\n", | ||
"Collecting accelerate\n", | ||
" Downloading https://pypi.tuna.tsinghua.edu.cn/packages/f0/62/9ebaf1fdd3d3c737a8814f9ae409d4ac04bc93b26a46a7dab456bb7e16f8/accelerate-0.31.0-py3-none-any.whl (309 kB)\n", | ||
"\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m309.4/309.4 kB\u001b[0m \u001b[31m2.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m[31m1.8 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\n", | ||
"\u001b[?25hRequirement already satisfied: numpy>=1.17 in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from accelerate) (1.24.3)\n", | ||
"Requirement already satisfied: packaging>=20.0 in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from accelerate) (23.2)\n", | ||
"Requirement already satisfied: psutil in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from accelerate) (5.9.5)\n", | ||
"Requirement already satisfied: pyyaml in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from accelerate) (6.0.1)\n", | ||
"Requirement already satisfied: torch>=1.10.0 in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from accelerate) (2.1.0.post301)\n", | ||
"Requirement already satisfied: huggingface-hub in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from accelerate) (0.23.2)\n", | ||
"Requirement already satisfied: safetensors>=0.3.1 in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from accelerate) (0.4.3)\n", | ||
"Requirement already satisfied: filelock in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from torch>=1.10.0->accelerate) (3.12.4)\n", | ||
"Requirement already satisfied: typing-extensions in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from torch>=1.10.0->accelerate) (4.9.0)\n", | ||
"Requirement already satisfied: sympy in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from torch>=1.10.0->accelerate) (1.12)\n", | ||
"Requirement already satisfied: networkx in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from torch>=1.10.0->accelerate) (3.2.1)\n", | ||
"Requirement already satisfied: jinja2 in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from torch>=1.10.0->accelerate) (3.1.2)\n", | ||
"Requirement already satisfied: fsspec in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from torch>=1.10.0->accelerate) (2023.12.2)\n", | ||
"Requirement already satisfied: requests in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from huggingface-hub->accelerate) (2.31.0)\n", | ||
"Requirement already satisfied: tqdm>=4.42.1 in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from huggingface-hub->accelerate) (4.65.0)\n", | ||
"Requirement already satisfied: MarkupSafe>=2.0 in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.3)\n", | ||
"Requirement already satisfied: charset-normalizer<4,>=2 in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from requests->huggingface-hub->accelerate) (3.3.0)\n", | ||
"Requirement already satisfied: idna<4,>=2.5 in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from requests->huggingface-hub->accelerate) (3.4)\n", | ||
"Requirement already satisfied: urllib3<3,>=1.21.1 in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from requests->huggingface-hub->accelerate) (1.26.18)\n", | ||
"Requirement already satisfied: certifi>=2017.4.17 in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from requests->huggingface-hub->accelerate) (2023.7.22)\n", | ||
"Requirement already satisfied: mpmath>=0.19 in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\n", | ||
"Installing collected packages: accelerate\n", | ||
" Attempting uninstall: accelerate\n", | ||
" Found existing installation: accelerate 0.20.3\n", | ||
" Uninstalling accelerate-0.20.3:\n", | ||
" Successfully uninstalled accelerate-0.20.3\n", | ||
"Successfully installed accelerate-0.31.0\n", | ||
"Note: you may need to restart the kernel to use updated packages.\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"%pip install --upgrade accelerate" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "38eb42ba-0628-4f9d-bf29-410607552950", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "kewei-ai", | ||
"language": "python", | ||
"name": "kewei-ai" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
58 changes: 58 additions & 0 deletions
58
Model_Architecture_Discussions/ChatGLM4/configuration_chatglm.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from transformers import PretrainedConfig | ||
|
||
|
||
class ChatGLMConfig(PretrainedConfig):
    """Configuration for ChatGLM-family models.

    Stores the hyper-parameters used to construct a ChatGLM model
    (layer count, hidden sizes, attention layout, normalization and
    dropout settings). All arguments have model-matching defaults;
    any extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__`` (e.g. ``pad_token_id``).

    Args:
        num_layers: Number of transformer layers.
        padded_vocab_size: Vocabulary size padded for efficient matmul;
            also exposed as ``vocab_size`` (see note in ``__init__``).
        hidden_size: Transformer hidden dimension.
        ffn_hidden_size: Feed-forward inner dimension.
        kv_channels: Per-head key/value projection width.
        num_attention_heads: Number of attention heads.
        seq_length: Maximum sequence length the model was configured for.
        hidden_dropout: Dropout probability on hidden states.
        classifier_dropout: Dropout for a classification head, or ``None``.
        attention_dropout: Dropout probability on attention weights.
        layernorm_epsilon: Epsilon for LayerNorm/RMSNorm stability.
        rmsnorm: Use RMSNorm instead of LayerNorm when True.
        apply_residual_connection_post_layernorm: Place the residual
            connection after the layer norm when True.
        post_layer_norm: Apply a final layer norm after the last block.
        add_bias_linear: Add bias terms to all linear layers.
        add_qkv_bias: Add bias terms only to the QKV projection.
        bias_dropout_fusion: Enable fused bias+dropout kernels.
        multi_query_attention: Enable multi-query attention (shared KV heads).
        multi_query_group_num: Number of KV groups when MQA is enabled.
        rope_ratio: Scaling ratio for rotary position embeddings.
        apply_query_key_layer_scaling: Scale QK product by layer number.
        attention_softmax_in_fp32: Compute attention softmax in fp32.
        fp32_residual_connection: Keep residual connections in fp32.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "chatglm"

    def __init__(
        self,
        num_layers=28,
        padded_vocab_size=65024,
        hidden_size=4096,
        ffn_hidden_size=13696,
        kv_channels=128,
        num_attention_heads=32,
        seq_length=2048,
        hidden_dropout=0.0,
        classifier_dropout=None,
        attention_dropout=0.0,
        layernorm_epsilon=1e-5,
        rmsnorm=True,
        apply_residual_connection_post_layernorm=False,
        post_layer_norm=True,
        add_bias_linear=False,
        add_qkv_bias=False,
        bias_dropout_fusion=True,
        multi_query_attention=False,
        multi_query_group_num=1,
        rope_ratio=1,
        apply_query_key_layer_scaling=True,
        attention_softmax_in_fp32=True,
        fp32_residual_connection=False,
        **kwargs
    ):
        self.num_layers = num_layers
        # NOTE: vocab_size deliberately aliases padded_vocab_size so that
        # generic HF components reading `config.vocab_size` see the padded
        # value used by the embedding/output matrices.
        self.vocab_size = padded_vocab_size
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.seq_length = seq_length
        self.hidden_dropout = hidden_dropout
        self.classifier_dropout = classifier_dropout
        self.attention_dropout = attention_dropout
        self.layernorm_epsilon = layernorm_epsilon
        self.rmsnorm = rmsnorm
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.post_layer_norm = post_layer_norm
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.bias_dropout_fusion = bias_dropout_fusion
        self.multi_query_attention = multi_query_attention
        self.multi_query_group_num = multi_query_group_num
        self.rope_ratio = rope_ratio
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection
        # Parent init last, matching the upstream ChatGLM convention:
        # PretrainedConfig consumes the remaining generic kwargs.
        super().__init__(**kwargs)
Oops, something went wrong.