Commit

Merge pull request #3068 from Embedded-AI-Systems/master

New Sampler

wangzhaode authored Dec 16, 2024
2 parents 0046e50 + 460bf40 commit 85291d0

Showing 37 changed files with 2,592 additions and 407 deletions.
7 changes: 7 additions & 0 deletions .gitignore
@@ -361,3 +361,10 @@ pymnn_build/

# mnncompress generated
MNN_compression_pb2.py

# model path
model/

# datasets
datasets/*
!datasets/*.sh
6 changes: 4 additions & 2 deletions CMakeLists.txt
@@ -20,7 +20,9 @@ endif()
project(MNN VERSION ${MNN_VERSION} LANGUAGES C CXX ASM)
# compiler options
set(CMAKE_C_STANDARD 99)
set(CMAKE_CXX_STANDARD 11)
IF (NOT (CMAKE_CXX_STANDARD EQUAL 17))
set(CMAKE_CXX_STANDARD 11)
ENDIF()
set(CMAKE_MODULE_PATH
${CMAKE_MODULE_PATH}
"${CMAKE_CURRENT_LIST_DIR}/cmake"
@@ -285,7 +287,7 @@ if(CMAKE_SYSTEM_NAME MATCHES "^Android")
endif()
option(MNN_USE_CPP11 "Enable MNN use c++11" ON)
if (NOT MSVC)
if(MNN_CUDA AND MNN_SUPPORT_TRANSFORMER_FUSE)
if((MNN_CUDA AND MNN_SUPPORT_TRANSFORMER_FUSE) OR (CMAKE_CXX_STANDARD EQUAL 17))
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=gnu99")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
35 changes: 32 additions & 3 deletions docs/transformers/llm.md
@@ -189,7 +189,7 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt
- visual_model: when using a VL model, the actual path of visual_model is `base_dir + visual_model`; defaults to `base_dir + 'visual.mnn'`
- Inference configuration
  - max_new_tokens: maximum number of tokens to generate; defaults to `512`
  - reuse_kv: whether to reuse the `kv cache` of previous turns in multi-turn dialogue; defaults to `false`
  - reuse_kv: whether to reuse the `kv cache` of previous turns in multi-turn dialogue; defaults to `false`; currently only the CPU backend supports setting it to `true`
  - quant_qkv: whether `query, key, value` are quantized in the CPU attention operator; options are `0, 1, 2, 3, 4`, default `0`, with the following meanings:
    - 0: neither key nor value is quantized
    - 1: store key with asymmetric 8-bit quantization
@@ -205,6 +205,19 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt
  - thread_num: number of hardware threads used for CPU inference; defaults to `4`; use `68` for OpenCL inference
  - precision: precision policy for inference; defaults to `"low"`, preferring `fp16`
  - memory: memory policy for inference; defaults to `"low"`, enabling runtime quantization
- Sampler configuration
  - sampler_type: which sampler to use; currently supports 8 basic samplers (`greedy`, `temperature`, `topK`, `topP`, `minP`, `tfs`, `typical`, `penalty`) plus `mixed` (a combined sampler). When `mixed` is selected, the samplers listed in mixed_samplers are applied in order (see the sketch after this list). Defaults to `mixed`
  - mixed_samplers: effective when `sampler_type` is `mixed`; defaults to `["topK", "tfs", "typical", "topP", "min_p", "temperature"]`
  - temperature: the temperature value used by `temperature`, `topP`, `minP`, `tfsZ`, `typical`; defaults to 1.0
  - topK: the number K of top candidates kept by `topK`; defaults to 40
  - topP: the P value for `topP`; defaults to 0.9
  - minP: the P value for `minP`; defaults to 0.1
  - tfsZ: the Z value for `tfs`; defaults to 1.0, i.e. tail-free sampling is disabled
  - typical: the p value for `typical`; defaults to 1.0, i.e. typical sampling is disabled
  - penalty: the penalty applied to logits by `penalty`; defaults to 0.0, i.e. no penalty
  - n_gram: the maximum n-gram size stored by `penalty`; defaults to 8
  - ngram_factor: the extra penalty applied by `penalty` to repeated n-grams; defaults to 1.0, i.e. no extra penalty
  - penalty_sampler: the sampling strategy used in the final step of `penalty`; either "greedy" or "temperature", defaults to "greedy"
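
A minimal, illustrative sketch of how such a mixed sampler chain works is shown below, chaining only topK, topP, and temperature from the default list for brevity. This is plain NumPy written for explanation only, not MNN's actual implementation; the function name, signature, and defaults are assumptions.

```python
import numpy as np

# Conceptual sketch only (not MNN's implementation): apply topK, then topP (nucleus),
# then temperature sampling to a 1-D array of logits and return the chosen token id.
def sample_mixed(logits, top_k=40, top_p=0.9, temperature=1.0):
    logits = np.asarray(logits, dtype=np.float64).copy()

    # topK: keep only the K largest logits, mask out the rest
    keep = np.argsort(logits)[-top_k:]
    masked = np.full_like(logits, -np.inf)
    masked[keep] = logits[keep]
    logits = masked

    # topP: keep the smallest set of tokens whose cumulative probability reaches top_p
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    order = np.argsort(-probs)                  # token ids, most probable first
    over = order[np.cumsum(probs[order]) > top_p]
    if over.size > 1:
        logits[over[1:]] = -np.inf              # keep the token that crosses the threshold

    # temperature: rescale the surviving logits and sample from the result
    scaled = logits / temperature
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    return int(np.random.choice(len(probs), p=probs))
```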

##### Configuration file examples
- `config.json`
@@ -216,7 +229,15 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt
"backend_type": "cpu",
"thread_num": 4,
"precision": "low",
"memory": "low"
"memory": "low",
"sampler_type": "mixed",
"mixed_samplers": ["topK", "tfs", "typical", "topP", "min_p", "temperature"],
"temperature": 1.0,
"topK": 40,
"topP": 0.9,
"tfsZ": 1.0,
"minP": 0.1,
"reuse_kv": true
}
```
- `llm_config.json`
@@ -240,7 +261,8 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt

#### Inference usage
Usage of `llm_demo` is as follows:
```
Direct inference on PC:
```bash
# using config.json
## interactive chat
./llm_demo model_dir/config.json
@@ -254,6 +276,13 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt
./llm_demo model_dir/llm.mnn prompt.txt
```

Inference on an Android phone via adb:
```bash
# use adb push to push the shared libraries to the phone
adb shell mkdir /data/local/tmp/llm
adb push llm_demo ppl_demo libllm.so libMNN_CL.so libMNN_Express.so libMNN.so tools/cv/libMNNOpenCV.so /data/local/tmp/llm
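# example follow-up: run the demo on the device (assumes the model directory with
# config.json and the *.mnn files has also been pushed to /data/local/tmp/llm)
adb shell "cd /data/local/tmp/llm && LD_LIBRARY_PATH=. ./llm_demo config.json prompt.txt"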
```

- For vision LLMs, embed the image input in the prompt
```
<img>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg</img>Describe the content of the image
7 changes: 7 additions & 0 deletions transformers/llm/.gitignore
@@ -0,0 +1,7 @@
datasets/*
!datasets/*.sh


!datasets/visualization/
datasets/visualization/data
datasets/visualization/pic
2 changes: 2 additions & 0 deletions transformers/llm/datasets/get-sharegpt.sh
@@ -0,0 +1,2 @@
git lfs install
git clone https://huggingface.co/datasets/shareAI/ShareGPT-Chinese-English-90k
2 changes: 2 additions & 0 deletions transformers/llm/datasets/get-wikitext-2-raw.sh
@@ -0,0 +1,2 @@
wget https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
unzip wikitext-2-raw-v1.zip
116 changes: 116 additions & 0 deletions transformers/llm/datasets/visualization/stats.py
@@ -0,0 +1,116 @@
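# Visualizes ShareGPT dialog statistics from a CSV: the distribution of turns per dialog
# and per-turn boxplots of the prefill/decode columns and their ratio, saved under `pic/`.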
import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib.ticker import PercentFormatter
from matplotlib import cbook
from matplotlib.axes import Axes
import pandas as pd
import numpy as np
import argparse
import os

vis_root = "pic"

def remove_blanks(df: pd.DataFrame) -> pd.DataFrame:
# Removing unnamed columns using drop function
df.drop(df.columns[df.columns.str.contains(
'unnamed', case=False)], axis=1, inplace=True)
return df
def add_turns(df: pd.DataFrame) -> pd.DataFrame:
df["turns"] = (1-df.isnull()).sum(axis=1) // 2
return df
def get_max_turn(df: pd.DataFrame) -> int:
keys = list(df.keys())
return max([int(key.replace("decode", "")) for key in keys if "decode" in key]) + 1
def add_pd_ratio(df: pd.DataFrame) -> pd.DataFrame:
max_turns = get_max_turn(df)
for i in range(max_turns):
df["pd_ratio{}".format(i)] = df["prefill{}".format(i)] / df["decode{}".format(i)]
return df
def preprocess(file_path: str) -> pd.DataFrame:
table = pd.read_csv(file_path)
table = remove_blanks(table)
table = add_turns(table)
table = add_pd_ratio(table)
print(table)
return table

def draw_distribution(df: pd.DataFrame, file_path: str):
turns_bin = df.value_counts(subset=["turns"], sort=False)
print(turns_bin)
plt.close()
plt.rcParams['font.size'] = 10
_, ax = plt.subplots()
    # N is the normalized height of each bin (density=True), bins are the bin edges
    N, bins, patches = ax.hist(df["turns"], bins=get_max_turn(df), density=True, align="left")
# We'll color code by height, but you could use any scalar
fracs = N / N.max()
# we need to normalize the data to 0..1 for the full range of the colormap
norm = colors.Normalize(fracs.min(), fracs.max())
# Now, we'll loop through our objects and set the color of each accordingly
for thisfrac, thispatch in zip(fracs, patches):
color = plt.cm.viridis(norm(thisfrac))
thispatch.set_facecolor(color)
# Now we format the y-axis to display percentage
ax.yaxis.set_major_formatter(PercentFormatter(xmax=1))
ax.set_xlim((0.5, get_max_turn(df)-0.5))
ax.set_xticks(np.arange(1,get_max_turn(df)+1),np.arange(1,get_max_turn(df)+1),rotation=60, fontsize=9)
ax.set_ylabel("frequency", fontsize=14)
ax.set_xlabel("num of turns", fontsize=14)
plt.savefig(file_path, dpi=600)
plt.close()

def draw_prefill(df: pd.DataFrame, ax: Axes):
stats = [cbook.boxplot_stats(df[df["prefill{}".format(i)].notna()]["prefill{}".format(i)], labels=[i+1])[0]
for i in range(get_max_turn(df))]
print(stats)
ax.bxp(stats, patch_artist=True, boxprops={'facecolor': 'bisque'}, flierprops=dict(marker='o', markersize=2))
ax.set_ylim(0,600)
ax.set_yticks(np.arange(0,700,100), np.arange(0,700,100), fontsize=9)
ax.set_ylabel("prefill", fontsize=12, rotation=90)
return
def draw_decode(df: pd.DataFrame, ax: Axes):
stats = [cbook.boxplot_stats(df[df["decode{}".format(i)].notna()]["decode{}".format(i)], labels=[i+1])[0]
for i in range(get_max_turn(df))]
print(stats)
ax.bxp(stats, patch_artist=True, boxprops={'facecolor': 'bisque'}, flierprops=dict(marker='o', markersize=2))
ax.set_ylim(0,600)
ax.set_yticks(np.arange(0,700,100), np.arange(0,700,100), fontsize=9)
ax.set_ylabel("decode", fontsize=12, rotation=90)
return
def draw_pd_ratio(df: pd.DataFrame, ax: Axes):
stats = [cbook.boxplot_stats(df[df["pd_ratio{}".format(i)].notna()]["pd_ratio{}".format(i)], labels=[i+1])[0]
for i in range(get_max_turn(df))]
print(stats)
ax.bxp(stats, patch_artist=True, boxprops={'facecolor': 'bisque'}, flierprops=dict(marker='o', markersize=2))
ax.plot(np.arange(0,get_max_turn(df)+2), np.ones_like(np.arange(0,get_max_turn(df)+2),dtype=float))
ax.set_xlim(0, get_max_turn(df)+1)
ax.set_ylim(0, 2.)
ax.set_xticks(np.arange(1,get_max_turn(df)), np.arange(1,get_max_turn(df)), rotation=60, fontsize=9)
ax.set_yticks([0,0.5,1,2], [0,0.5,1,2], fontsize=9)
ax.set_xlabel("round", fontsize=12)
ax.set_ylabel("prefill/decode", fontsize=12, rotation=90)
return
def draw_reuse_kv(df: pd.DataFrame, file_path: str):
plt.close()
_, axs = plt.subplots(3,1,sharex="col")
draw_prefill(df, axs[0])
draw_decode(df, axs[1])
draw_pd_ratio(df, axs[2])
plt.savefig(file_path, dpi=1200)
plt.close()
return
def draw_no_reuse_kv():
return

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--root", type=str, default="./data")
parser.add_argument("--name", type=str, default="shareGPT_dialog_stats_common_en.csv")
args = parser.parse_args()

    os.makedirs(vis_root, exist_ok=True)  # make sure the output directory for figures exists
    file_path = os.path.join(args.root, args.name)
dist_path = os.path.join(vis_root, args.name.split('.')[0]+"_dist.png")
pd_dist_path = os.path.join(vis_root, args.name.split('.')[0]+"_pd_dist.png")
table = preprocess(file_path)
draw_distribution(table, dist_path)
draw_reuse_kv(table, pd_dist_path)
83 changes: 83 additions & 0 deletions transformers/llm/datasets/visualization/time.py
@@ -0,0 +1,83 @@
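# Compares prefill response speed with and without KV-cache reuse over multi-turn ShareGPT
# chats: parses per-turn <prefill>/<decode> CSV blocks from two log files and saves a
# grouped bar chart (plus median history length) to ./pic/fig.png.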
import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib.ticker import PercentFormatter
from matplotlib import cbook
from matplotlib.axes import Axes
from typing import List, Dict, Tuple
import pandas as pd
import numpy as np
import argparse
import os
import re
from io import StringIO

def split_by_turns(id: str, content: str) -> List[pd.DataFrame]:
pattern = "<{id}>\n(.*?)</{id}>\n".format(id=id)
return [pd.read_csv(StringIO(item)) for item in re.findall(pattern, content, flags=re.DOTALL)]
def preprocess(file_path: str) -> Tuple[List[pd.DataFrame], List[pd.DataFrame]]:
content = open(file_path, "rt").read()
return split_by_turns("prefill", content), split_by_turns("decode", content)
def get_max_turn(no_reuse_prefill_record):
return max(10, max([len(record) for record in no_reuse_prefill_record]))
def draw_history_len(ax: Axes, no_reuse_prefill_record: List[pd.DataFrame]):
max_round = get_max_turn(no_reuse_prefill_record)
history_len = [0 for _ in range(0, max_round)]
for turn in range(0, max_round):
history_len[turn] = np.median([record["input_token"][turn] - record["prompt_token"][turn]
for record in no_reuse_prefill_record if len(record)>=turn+1]).item()
    # plot on the passed-in (secondary) axis rather than the implicit current axes
    ax.plot(np.arange(1, max_round+1), history_len, label="median history len", marker=".", markersize=8)
return
def draw_prefill_bar_chart(ax: Axes, no_reuse, reuse):
offset = 0.2
max_round = len(no_reuse)
no_reuse_med = [np.median(turn) for turn in no_reuse]
rects = ax.bar(np.arange(1,max_round+1) + offset, no_reuse_med, offset*2, label="no reuse kv", color="tomato")
ax.bar_label(rects, fmt="{:.2f}", padding=4, fontsize=6)
reuse_med = [np.median(turn) for turn in reuse]
rects = ax.bar(np.arange(1,max_round+1) - offset, reuse_med, offset*2, label="reuse kv", color="springgreen")
ax.bar_label(rects, fmt="{:.2f}", padding=4, fontsize=6)
return
def compare_prefill_reuse_kv(no_reuse_prefill_record: List[pd.DataFrame],
reuse_prefill_record: List[pd.DataFrame]):
plt.close()
_,ax1 = plt.subplots()
ax2 = ax1.twinx()
# plot history_len
draw_history_len(ax2, no_reuse_prefill_record)
# calculate per turn
max_round = get_max_turn(no_reuse_prefill_record)
no_reuse = [[] for _ in range(0, max_round)]
for turn in range(0, max_round):
no_reuse[turn] = [record["response_speed"][turn] for record in no_reuse_prefill_record if len(record)>=turn+1]
reuse = [[] for _ in range(0, max_round)]
for turn in range(0, max_round):
reuse[turn] = [record["response_speed"][turn] for record in reuse_prefill_record if len(record)>=turn+1]
    # plot the grouped bar chart of median prefill response speed per round
    draw_prefill_bar_chart(ax1, no_reuse, reuse)
ax1.set_xticks(np.arange(1,max_round+1),np.arange(1,max_round+1),fontsize=9)
ax1.set_ylim(0,100)
ax2.set_ylim(0,1000)
ax1.legend(loc='upper left', title="prefill response speed")
ax2.legend(loc='upper right')
ax1.set_ylabel("prefill\nresponse\nspeed", rotation=0, labelpad=12)
ax2.set_ylabel("history\nlen", rotation=0, labelpad=8)
ax1.set_xlabel("round")
plt.title("KV cache reuse for multi-turn chat\neffects on ShareGPT")
plt.tight_layout()
plt.savefig("./pic/fig.png",dpi=1200)
plt.close()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--root", type=str, default="./data")
parser.add_argument("--no_reuse", type=str, default="shareGPT_common_en_70k_noreuse.txt")
parser.add_argument("--reuse", type=str, default="shareGPT_common_en_70k_reuse.txt")
args = parser.parse_args()

no_reuse_file_path = os.path.join(args.root, args.no_reuse)
reuse_file_path = os.path.join(args.root, args.reuse)
no_reuse_prefill_record, no_reuse_decode_record = preprocess(no_reuse_file_path)
reuse_prefill_record, reuse_decode_record = preprocess(reuse_file_path)
# visualize prefill
compare_prefill_reuse_kv(no_reuse_prefill_record, reuse_prefill_record)
7 changes: 5 additions & 2 deletions transformers/llm/engine/CMakeLists.txt
@@ -25,12 +25,15 @@ else()
add_library(llm OBJECT ${SRCS})
endif()

add_executable(llm_demo ${CMAKE_CURRENT_LIST_DIR}/llm_demo.cpp)
add_executable(embedding_demo ${CMAKE_CURRENT_LIST_DIR}/embedding_demo.cpp)
add_executable(llm_demo ${CMAKE_CURRENT_LIST_DIR}/app/llm_demo.cpp)
add_executable(ppl_demo ${CMAKE_CURRENT_LIST_DIR}/app/ppl_demo.cpp)
add_executable(embedding_demo ${CMAKE_CURRENT_LIST_DIR}/app/embedding_demo.cpp)
IF (NOT MNN_SEP_BUILD)
target_link_libraries(llm_demo ${MNN_DEPS})
target_link_libraries(ppl_demo ${MNN_DEPS})
target_link_libraries(embedding_demo ${MNN_DEPS})
ELSE ()
target_link_libraries(llm_demo ${MNN_DEPS} llm)
target_link_libraries(ppl_demo ${MNN_DEPS} llm)
target_link_libraries(embedding_demo ${MNN_DEPS} llm)
ENDIF ()
File renamed without changes.