diff --git a/.gitignore b/.gitignore
index d11586811..ca1a4320f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -361,3 +361,10 @@ pymnn_build/
# mnncompress generated
MNN_compression_pb2.py
+
+# model path
+model/
+
+# datasets
+datasets/*
+!datasets/*.sh
\ No newline at end of file
diff --git a/CMakeLists.txt b/CMakeLists.txt
index e345987d1..c7768e340 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -20,7 +20,9 @@ endif()
project(MNN VERSION ${MNN_VERSION} LANGUAGES C CXX ASM)
# complier options
set(CMAKE_C_STANDARD 99)
-set(CMAKE_CXX_STANDARD 11)
+IF (NOT (CMAKE_CXX_STANDARD EQUAL 17))
+ set(CMAKE_CXX_STANDARD 11)
+ENDIF()
set(CMAKE_MODULE_PATH
${CMAKE_MODULE_PATH}
"${CMAKE_CURRENT_LIST_DIR}/cmake"
@@ -285,7 +287,7 @@ if(CMAKE_SYSTEM_NAME MATCHES "^Android")
endif()
option(MNN_USE_CPP11 "Enable MNN use c++11" ON)
if (NOT MSVC)
- if(MNN_CUDA AND MNN_SUPPORT_TRANSFORMER_FUSE)
+ if((MNN_CUDA AND MNN_SUPPORT_TRANSFORMER_FUSE) OR (CMAKE_CXX_STANDARD EQUAL 17))
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=gnu99")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
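With this guard, a C++17 build can now be requested explicitly at configure time instead of being implied only by `MNN_CUDA` plus `MNN_SUPPORT_TRANSFORMER_FUSE`. A minimal configure sketch, assuming an out-of-source build directory and the existing `MNN_BUILD_LLM` option for the LLM engine (other flags are left to the usual per-platform setup):

```bash
mkdir -p build && cd build
# request C++17 so the new guard keeps it instead of falling back to C++11
cmake .. -DCMAKE_CXX_STANDARD=17 -DMNN_BUILD_LLM=ON
make -j4
```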
diff --git a/docs/transformers/llm.md b/docs/transformers/llm.md
index 963597517..b0fbd4932 100644
--- a/docs/transformers/llm.md
+++ b/docs/transformers/llm.md
@@ -189,7 +189,7 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt
- visual_model: when using a VL model, the actual path of visual_model is `base_dir + visual_model`; default: `base_dir + 'visual.mnn'`
- Inference configuration
- max_new_tokens: maximum number of tokens to generate, default: `512`
- - reuse_kv: whether to reuse the `kv cache` of previous turns in multi-turn chat, default: `false`
+ - reuse_kv: whether to reuse the `kv cache` of previous turns in multi-turn chat, default: `false`; currently only the CPU backend supports setting this to `true`.
- quant_qkv: whether `query, key, value` are quantized in the CPU attention operator, options: `0, 1, 2, 3, 4`, default: `0`, with the following meanings:
- 0: neither key nor value is quantized
- 1: key is stored with asymmetric 8-bit quantization
@@ -205,6 +205,19 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt
- thread_num: number of hardware threads for CPU inference, default: `4`; `68` is used for OpenCL inference
- precision: precision policy for inference, default: `"low"`, which prefers `fp16`
- memory: memory policy for inference, default: `"low"`, which enables runtime quantization
+- Sampler configuration
+ - sampler_type: which sampler to use; currently 8 basic samplers are supported: `greedy`, `temperature`, `topK`, `topP`, `minP`, `tfs`, `typical`, `penalty`, plus `mixed` (a combined sampler). When `mixed` is selected, the samplers listed in mixed_samplers are executed in order. Default: `mixed`.
+ - mixed_samplers: effective when `sampler_type` is `mixed`, default: `["topK", "tfs", "typical", "topP", "min_p", "temperature"]`
+ - temperature: the temperature value used by `temperature`, `topP`, `minP`, `tfsZ`, `typical`, default: 1.0
+ - topK: the number of top-K candidates in `topK`, default: 40
+ - topP: the top-P value in `topP`, default: 0.9
+ - minP: the min-P value in `minP`, default: 0.1
+ - tfsZ: the Z value in `tfs`, default: 1.0, i.e. the tfs algorithm is disabled
+ - typical: the p value in `typical`, default: 1.0, i.e. the typical algorithm is disabled
+ - penalty: the penalty applied to the logits in `penalty`, default: 0.0, i.e. no penalty
+ - n_gram: the maximum n-gram size stored by `penalty`, default: 8
+ - ngram_factor: the extra penalty for repeated n-grams in `penalty`, default: 1.0, i.e. no extra penalty
+ - penalty_sampler: the sampling strategy used in the final step of `penalty`, either "greedy" or "temperature", default: greedy.
##### Configuration file examples
- `config.json`
@@ -216,7 +229,15 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt
"backend_type": "cpu",
"thread_num": 4,
"precision": "low",
- "memory": "low"
+ "memory": "low",
+ "sampler_type": "mixed",
+ "mixed_samplers": ["topK", "tfs", "typical", "topP", "min_p", "temperature"],
+ "temperature": 1.0,
+ "topK": 40,
+ "topP": 0.9,
+ "tfsZ": 1.0,
+ "minP": 0.1,
+ "reuse_kv": true
}
```
- `llm_config.json`
@@ -240,7 +261,8 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt
#### Inference usage
Usage of `llm_demo` is as follows:
-```
+Run directly on a PC:
+```bash
# use config.json
## interactive chat
./llm_demo model_dir/config.json
@@ -254,6 +276,13 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt
./llm_demo model_dir/llm.mnn prompt.txt
```
+Run on an Android phone via adb:
+```bash
+# push the demo binaries and shared libraries to the phone with adb push
+adb shell mkdir /data/local/tmp/llm
+adb push llm_demo ppl_demo libllm.so libMNN_CL.so libMNN_Express.so libMNN.so tools/cv/libMNNOpenCV.so /data/local/tmp/llm
+```
+
- For visual LLMs, embed the image input in the prompt
```
<img>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg</img>Describe the content of the image
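A hypothetical continuation of the adb workflow above: besides the libraries, the converted model directory also has to be pushed, and the demo is started with `LD_LIBRARY_PATH` pointing at the pushed libraries (`model_dir` is illustrative, matching the usage shown earlier):

```bash
adb push model_dir /data/local/tmp/llm/model_dir
adb shell "cd /data/local/tmp/llm && LD_LIBRARY_PATH=. ./llm_demo model_dir/config.json prompt.txt"
```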
diff --git a/transformers/llm/.gitignore b/transformers/llm/.gitignore
new file mode 100644
index 000000000..000bc68dc
--- /dev/null
+++ b/transformers/llm/.gitignore
@@ -0,0 +1,7 @@
+datasets/*
+!datasets/*.sh
+
+
+!datasets/visualization/
+datasets/visualization/data
+datasets/visualization/pic
\ No newline at end of file
diff --git a/transformers/llm/datasets/get-sharegpt.sh b/transformers/llm/datasets/get-sharegpt.sh
new file mode 100644
index 000000000..fe4be4b5d
--- /dev/null
+++ b/transformers/llm/datasets/get-sharegpt.sh
@@ -0,0 +1,2 @@
+git lfs install
+git clone https://huggingface.co/datasets/shareAI/ShareGPT-Chinese-English-90k
\ No newline at end of file
diff --git a/transformers/llm/datasets/get-wikitext-2-raw.sh b/transformers/llm/datasets/get-wikitext-2-raw.sh
new file mode 100644
index 000000000..33096304c
--- /dev/null
+++ b/transformers/llm/datasets/get-wikitext-2-raw.sh
@@ -0,0 +1,2 @@
+wget https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
+unzip wikitext-2-raw-v1.zip
\ No newline at end of file
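Both dataset helpers are meant to be run from `transformers/llm/datasets/`, which the new .gitignore rules keep out of version control except for the `*.sh` scripts themselves. A usage sketch, where the extracted directory names are assumptions based on the archive and repository names:

```bash
cd transformers/llm/datasets
bash get-wikitext-2-raw.sh   # downloads and unzips to wikitext-2-raw/
bash get-sharegpt.sh         # clones ShareGPT-Chinese-English-90k/
```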
diff --git a/transformers/llm/datasets/visualization/stats.py b/transformers/llm/datasets/visualization/stats.py
new file mode 100644
index 000000000..761ef256d
--- /dev/null
+++ b/transformers/llm/datasets/visualization/stats.py
@@ -0,0 +1,116 @@
+import matplotlib.pyplot as plt
+from matplotlib import colors
+from matplotlib.ticker import PercentFormatter
+from matplotlib import cbook
+from matplotlib.axes import Axes
+import pandas as pd
+import numpy as np
+import argparse
+import os
+
+vis_root = "pic"
+
+def remove_blanks(df: pd.DataFrame) -> pd.DataFrame:
+ # Removing unnamed columns using drop function
+ df.drop(df.columns[df.columns.str.contains(
+ 'unnamed', case=False)], axis=1, inplace=True)
+ return df
+def add_turns(df: pd.DataFrame) -> pd.DataFrame:
+ df["turns"] = (1-df.isnull()).sum(axis=1) // 2
+ return df
+def get_max_turn(df: pd.DataFrame) -> int:
+ keys = list(df.keys())
+ return max([int(key.replace("decode", "")) for key in keys if "decode" in key]) + 1
+def add_pd_ratio(df: pd.DataFrame) -> pd.DataFrame:
+ max_turns = get_max_turn(df)
+ for i in range(max_turns):
+ df["pd_ratio{}".format(i)] = df["prefill{}".format(i)] / df["decode{}".format(i)]
+ return df
+def preprocess(file_path: str) -> pd.DataFrame:
+ table = pd.read_csv(file_path)
+ table = remove_blanks(table)
+ table = add_turns(table)
+ table = add_pd_ratio(table)
+ print(table)
+ return table
+
+def draw_distribution(df: pd.DataFrame, file_path: str):
+ turns_bin = df.value_counts(subset=["turns"], sort=False)
+ print(turns_bin)
+ plt.close()
+ plt.rcParams['font.size'] = 10
+ _, ax = plt.subplots()
+ # N is the count in each bin, bins is the lower-limit of the bin
+ N, bins, patches = ax.hist(df["turns"], bins=get_max_turn(df), density=True, align="left", label=True)
+ # We'll color code by height, but you could use any scalar
+ fracs = N / N.max()
+ # we need to normalize the data to 0..1 for the full range of the colormap
+ norm = colors.Normalize(fracs.min(), fracs.max())
+ # Now, we'll loop through our objects and set the color of each accordingly
+ for thisfrac, thispatch in zip(fracs, patches):
+ color = plt.cm.viridis(norm(thisfrac))
+ thispatch.set_facecolor(color)
+ # Now we format the y-axis to display percentage
+ ax.yaxis.set_major_formatter(PercentFormatter(xmax=1))
+ ax.set_xlim((0.5, get_max_turn(df)-0.5))
+ ax.set_xticks(np.arange(1,get_max_turn(df)+1),np.arange(1,get_max_turn(df)+1),rotation=60, fontsize=9)
+ ax.set_ylabel("frequency", fontsize=14)
+ ax.set_xlabel("num of turns", fontsize=14)
+ plt.savefig(file_path, dpi=600)
+ plt.close()
+
+def draw_prefill(df: pd.DataFrame, ax: Axes):
+ stats = [cbook.boxplot_stats(df[df["prefill{}".format(i)].notna()]["prefill{}".format(i)], labels=[i+1])[0]
+ for i in range(get_max_turn(df))]
+ print(stats)
+ ax.bxp(stats, patch_artist=True, boxprops={'facecolor': 'bisque'}, flierprops=dict(marker='o', markersize=2))
+ ax.set_ylim(0,600)
+ ax.set_yticks(np.arange(0,700,100), np.arange(0,700,100), fontsize=9)
+ ax.set_ylabel("prefill", fontsize=12, rotation=90)
+ return
+def draw_decode(df: pd.DataFrame, ax: Axes):
+ stats = [cbook.boxplot_stats(df[df["decode{}".format(i)].notna()]["decode{}".format(i)], labels=[i+1])[0]
+ for i in range(get_max_turn(df))]
+ print(stats)
+ ax.bxp(stats, patch_artist=True, boxprops={'facecolor': 'bisque'}, flierprops=dict(marker='o', markersize=2))
+ ax.set_ylim(0,600)
+ ax.set_yticks(np.arange(0,700,100), np.arange(0,700,100), fontsize=9)
+ ax.set_ylabel("decode", fontsize=12, rotation=90)
+ return
+def draw_pd_ratio(df: pd.DataFrame, ax: Axes):
+ stats = [cbook.boxplot_stats(df[df["pd_ratio{}".format(i)].notna()]["pd_ratio{}".format(i)], labels=[i+1])[0]
+ for i in range(get_max_turn(df))]
+ print(stats)
+ ax.bxp(stats, patch_artist=True, boxprops={'facecolor': 'bisque'}, flierprops=dict(marker='o', markersize=2))
+ ax.plot(np.arange(0,get_max_turn(df)+2), np.ones_like(np.arange(0,get_max_turn(df)+2),dtype=float))
+ ax.set_xlim(0, get_max_turn(df)+1)
+ ax.set_ylim(0, 2.)
+ ax.set_xticks(np.arange(1,get_max_turn(df)), np.arange(1,get_max_turn(df)), rotation=60, fontsize=9)
+ ax.set_yticks([0,0.5,1,2], [0,0.5,1,2], fontsize=9)
+ ax.set_xlabel("round", fontsize=12)
+ ax.set_ylabel("prefill/decode", fontsize=12, rotation=90)
+ return
+def draw_reuse_kv(df: pd.DataFrame, file_path: str):
+ plt.close()
+ _, axs = plt.subplots(3,1,sharex="col")
+ draw_prefill(df, axs[0])
+ draw_decode(df, axs[1])
+ draw_pd_ratio(df, axs[2])
+ plt.savefig(file_path, dpi=1200)
+ plt.close()
+ return
+def draw_no_reuse_kv():
+ return
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--root", type=str, default="./data")
+ parser.add_argument("--name", type=str, default="shareGPT_dialog_stats_common_en.csv")
+ args = parser.parse_args()
+
+ os.makedirs(vis_root, exist_ok=True) # make sure the output directory for the figures exists
+ file_path = os.path.join(args.root, args.name)
+ dist_path = os.path.join(vis_root, args.name.split('.')[0]+"_dist.png")
+ pd_dist_path = os.path.join(vis_root, args.name.split('.')[0]+"_pd_dist.png")
+ table = preprocess(file_path)
+ draw_distribution(table, dist_path)
+ draw_reuse_kv(table, pd_dist_path)
\ No newline at end of file
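A hypothetical invocation of the statistics script, run from `transformers/llm/datasets/visualization/` with its default argument values; the per-dialog statistics CSV under `./data` is assumed to have been exported beforehand:

```bash
python stats.py --root ./data --name shareGPT_dialog_stats_common_en.csv
# outputs: pic/shareGPT_dialog_stats_common_en_dist.png     (turn-count distribution)
#          pic/shareGPT_dialog_stats_common_en_pd_dist.png  (prefill/decode box plots)
```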
diff --git a/transformers/llm/datasets/visualization/time.py b/transformers/llm/datasets/visualization/time.py
new file mode 100644
index 000000000..27cc0069d
--- /dev/null
+++ b/transformers/llm/datasets/visualization/time.py
@@ -0,0 +1,83 @@
+import matplotlib.pyplot as plt
+from matplotlib import colors
+from matplotlib.ticker import PercentFormatter
+from matplotlib import cbook
+from matplotlib.axes import Axes
+from typing import List, Dict, Tuple
+import pandas as pd
+import numpy as np
+import argparse
+import os
+import re
+from io import StringIO
+
+def split_by_turns(id: str, content: str) -> List[pd.DataFrame]:
+ pattern = "<{id}>\n(.*?){id}>\n".format(id=id)
+ return [pd.read_csv(StringIO(item)) for item in re.findall(pattern, content, flags=re.DOTALL)]
+def preprocess(file_path: str) -> Tuple[List[pd.DataFrame], List[pd.DataFrame]]:
+ content = open(file_path, "rt").read()
+ return split_by_turns("prefill", content), split_by_turns("decode", content)
+def get_max_turn(no_reuse_prefill_record):
+ return max(10, max([len(record) for record in no_reuse_prefill_record]))
+def draw_history_len(ax: Axes, no_reuse_prefill_record: List[pd.DataFrame]):
+ max_round = get_max_turn(no_reuse_prefill_record)
+ history_len = [0 for _ in range(0, max_round)]
+ for turn in range(0, max_round):
+ history_len[turn] = np.median([record["input_token"][turn] - record["prompt_token"][turn]
+ for record in no_reuse_prefill_record if len(record)>=turn+1]).item()
+ plt.plot(np.arange(1, max_round+1), history_len, label="median history len", marker=".", markersize=8)
+ return
+def draw_prefill_bar_chat(ax: Axes, no_reuse, reuse):
+ offset = 0.2
+ max_round = len(no_reuse)
+ no_reuse_med = [np.median(turn) for turn in no_reuse]
+ rects = ax.bar(np.arange(1,max_round+1) + offset, no_reuse_med, offset*2, label="no reuse kv", color="tomato")
+ ax.bar_label(rects, fmt="{:.2f}", padding=4, fontsize=6)
+ reuse_med = [np.median(turn) for turn in reuse]
+ rects = ax.bar(np.arange(1,max_round+1) - offset, reuse_med, offset*2, label="reuse kv", color="springgreen")
+ ax.bar_label(rects, fmt="{:.2f}", padding=4, fontsize=6)
+ return
+def compare_prefill_reuse_kv(no_reuse_prefill_record: List[pd.DataFrame],
+ reuse_prefill_record: List[pd.DataFrame]):
+ plt.close()
+ _,ax1 = plt.subplots()
+ ax2 = ax1.twinx()
+ # plot history_len
+ draw_history_len(ax2, no_reuse_prefill_record)
+ # calculate per turn
+ max_round = get_max_turn(no_reuse_prefill_record)
+ no_reuse = [[] for _ in range(0, max_round)]
+ for turn in range(0, max_round):
+ no_reuse[turn] = [record["response_speed"][turn] for record in no_reuse_prefill_record if len(record)>=turn+1]
+ reuse = [[] for _ in range(0, max_round)]
+ for turn in range(0, max_round):
+ reuse[turn] = [record["response_speed"][turn] for record in reuse_prefill_record if len(record)>=turn+1]
+ # plot the bar chat (with error bar)
+ draw_prefill_bar_chat(ax1, no_reuse, reuse)
+ ax1.set_xticks(np.arange(1,max_round+1),np.arange(1,max_round+1),fontsize=9)
+ ax1.set_ylim(0,100)
+ ax2.set_ylim(0,1000)
+ ax1.legend(loc='upper left', title="prefill response speed")
+ ax2.legend(loc='upper right')
+ ax1.set_ylabel("prefill\nresponse\nspeed", rotation=0, labelpad=12)
+ ax2.set_ylabel("history\nlen", rotation=0, labelpad=8)
+ ax1.set_xlabel("round")
+ plt.title("KV cache reuse for multi-turn chat\neffects on ShareGPT")
+ plt.tight_layout()
+ plt.savefig("./pic/fig.png",dpi=1200)
+ plt.close()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--root", type=str, default="./data")
+ parser.add_argument("--no_reuse", type=str, default="shareGPT_common_en_70k_noreuse.txt")
+ parser.add_argument("--reuse", type=str, default="shareGPT_common_en_70k_reuse.txt")
+ args = parser.parse_args()
+
+ no_reuse_file_path = os.path.join(args.root, args.no_reuse)
+ reuse_file_path = os.path.join(args.root, args.reuse)
+ no_reuse_prefill_record, no_reuse_decode_record = preprocess(no_reuse_file_path)
+ reuse_prefill_record, reuse_decode_record = preprocess(reuse_file_path)
+ # visualize prefill
+ compare_prefill_reuse_kv(no_reuse_prefill_record, reuse_prefill_record)
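A hypothetical run of the comparison script, assuming the two performance logs (written by the demos via `print_speed(std::ostream*)`, which wraps the per-turn CSV blocks in `<prefill>`/`<decode>` tags) were collected under `./data`:

```bash
python time.py --root ./data \
    --no_reuse shareGPT_common_en_70k_noreuse.txt \
    --reuse shareGPT_common_en_70k_reuse.txt
# produces ./pic/fig.png: prefill response speed with vs. without KV-cache reuse
```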
diff --git a/transformers/llm/engine/CMakeLists.txt b/transformers/llm/engine/CMakeLists.txt
index 5e2703acf..b7ccf1139 100644
--- a/transformers/llm/engine/CMakeLists.txt
+++ b/transformers/llm/engine/CMakeLists.txt
@@ -25,12 +25,15 @@ else()
add_library(llm OBJECT ${SRCS})
endif()
-add_executable(llm_demo ${CMAKE_CURRENT_LIST_DIR}/llm_demo.cpp)
-add_executable(embedding_demo ${CMAKE_CURRENT_LIST_DIR}/embedding_demo.cpp)
+add_executable(llm_demo ${CMAKE_CURRENT_LIST_DIR}/app/llm_demo.cpp)
+add_executable(ppl_demo ${CMAKE_CURRENT_LIST_DIR}/app/ppl_demo.cpp)
+add_executable(embedding_demo ${CMAKE_CURRENT_LIST_DIR}/app/embedding_demo.cpp)
IF (NOT MNN_SEP_BUILD)
target_link_libraries(llm_demo ${MNN_DEPS})
+ target_link_libraries(ppl_demo ${MNN_DEPS})
target_link_libraries(embedding_demo ${MNN_DEPS})
ELSE ()
target_link_libraries(llm_demo ${MNN_DEPS} llm)
+ target_link_libraries(ppl_demo ${MNN_DEPS} llm)
target_link_libraries(embedding_demo ${MNN_DEPS} llm)
ENDIF ()
diff --git a/transformers/llm/engine/embedding_demo.cpp b/transformers/llm/engine/app/embedding_demo.cpp
similarity index 100%
rename from transformers/llm/engine/embedding_demo.cpp
rename to transformers/llm/engine/app/embedding_demo.cpp
diff --git a/transformers/llm/engine/llm_demo.cpp b/transformers/llm/engine/app/llm_demo.cpp
similarity index 65%
rename from transformers/llm/engine/llm_demo.cpp
rename to transformers/llm/engine/app/llm_demo.cpp
index eef45cb75..55a1b09df 100644
--- a/transformers/llm/engine/llm_demo.cpp
+++ b/transformers/llm/engine/app/llm_demo.cpp
@@ -6,21 +6,24 @@
//
#include "llm/llm.hpp"
+#include "evaluation/dataset.hpp"
#define MNN_OPEN_TIME_TRACE
#include
#include
#include
#include
#include
+#include
using namespace MNN::Transformer;
static void trace_prepare(Llm* llm) {
MNN_PRINT("Prepare for resize opt Begin\n");
llm->trace(true);
std::ostringstream cacheOs;
- llm->generate({200, 200}, &cacheOs, "");
+ llm->generate(std::initializer_list<int>{200, 200}, &cacheOs, "");
MNN_PRINT("Prepare for resize opt End\n");
llm->trace(false);
+ llm->reset();
}
static void tuning_prepare(Llm* llm) {
@@ -29,55 +32,7 @@ static void tuning_prepare(Llm* llm) {
MNN_PRINT("Prepare for tuning opt End\n");
}
-std::vector> parse_csv(const std::vector& lines) {
- std::vector> csv_data;
- std::string line;
- std::vector row;
- std::string cell;
- bool insideQuotes = false;
- bool startCollecting = false;
-
- // content to stream
- std::string content = "";
- for (auto line : lines) {
- content = content + line + "\n";
- }
- std::istringstream stream(content);
-
- while (stream.peek() != EOF) {
- char c = stream.get();
- if (c == '"') {
- if (insideQuotes && stream.peek() == '"') { // quote
- cell += '"';
- stream.get(); // skip quote
- } else {
- insideQuotes = !insideQuotes; // start or end text in quote
- }
- startCollecting = true;
- } else if (c == ',' && !insideQuotes) { // end element, start new element
- row.push_back(cell);
- cell.clear();
- startCollecting = false;
- } else if ((c == '\n' || stream.peek() == EOF) && !insideQuotes) { // end line
- row.push_back(cell);
- csv_data.push_back(row);
- cell.clear();
- row.clear();
- startCollecting = false;
- } else {
- cell += c;
- startCollecting = true;
- }
- }
- return csv_data;
-}
-
static int benchmark(Llm* llm, const std::vector& prompts) {
- int prompt_len = 0;
- int decode_len = 0;
- int64_t prefill_time = 0;
- int64_t decode_time = 0;
- // llm->warmup();
for (int i = 0; i < prompts.size(); i++) {
const auto& prompt = prompts[i];
// prompt start with '#' will be ignored
@@ -85,20 +40,14 @@ static int benchmark(Llm* llm, const std::vector& prompts) {
continue;
}
llm->response(prompt);
- prompt_len += llm->prompt_len_;
- decode_len += llm->gen_seq_len_;
- prefill_time += llm->prefill_us_;
- decode_time += llm->decode_us_;
}
- float prefill_s = prefill_time / 1e6;
- float decode_s = decode_time / 1e6;
printf("\n#################################\n");
- printf("prompt tokens num = %d\n", prompt_len);
- printf("decode tokens num = %d\n", decode_len);
- printf("prefill time = %.2f s\n", prefill_s);
- printf(" decode time = %.2f s\n", decode_s);
- printf("prefill speed = %.2f tok/s\n", prompt_len / prefill_s);
- printf(" decode speed = %.2f tok/s\n", decode_len / decode_s);
+ printf("prompt tokens num = %d\n", llm->getTotalPromptLen());
+ printf("decode tokens num = %d\n", llm->getTotalDecodeLen());
+ printf("prefill time = %.2f s\n", llm->getTotalPrefillTime());
+ printf(" decode time = %.2f s\n", llm->getTotalDecodeTime());
+ printf("prefill speed = %.2f tok/s\n", llm->average_prefill_speed());
+ printf(" decode speed = %.2f tok/s\n", llm->average_decode_speed());
printf("##################################\n");
return 0;
}
diff --git a/transformers/llm/engine/app/ppl_demo.cpp b/transformers/llm/engine/app/ppl_demo.cpp
new file mode 100644
index 000000000..393b5b86d
--- /dev/null
+++ b/transformers/llm/engine/app/ppl_demo.cpp
@@ -0,0 +1,61 @@
+//
+// ppl_demo.cpp
+//
+// Created by MNN on 2023/03/24.
+// ZhaodeWang
+//
+
+#include "llm/llm.hpp"
+#define MNN_OPEN_TIME_TRACE
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+using namespace MNN::Transformer;
+static void trace_prepare(Llm* llm) {
+ MNN_PRINT("Prepare for resize opt Begin\n");
+ llm->trace(true);
+ std::ostringstream cacheOs;
+ llm->response("Hello", &cacheOs);
+ MNN_PRINT("Prepare for resize opt End\n");
+ llm->trace(false);
+ llm->reset();
+}
+
+// parse json
+
+static int ppl_eval(Llm* llm, std::string prompt_file, std::ofstream* perfOS) {
+ std::cout << "prompt file is " << prompt_file << std::endl;
+ // ppl evaluation
+ std::vector<float> ppls = llm->perplexity(prompt_file, perfOS);
+ float mean_ppl = 0.f;
+ for (int j = 0; j < ppls.size(); ++j) mean_ppl += ppls[j];
+ mean_ppl /= ppls.size();
+ std::cout << mean_ppl << std::endl;
+ return 0;
+}
+
+int main(int argc, const char* argv[]) {
+ if (argc < 3) {
+ std::cout << "Usage: " << argv[0] << " config.json ppl-prompt.txt [perf.txt]" << std::endl;
+ return 0;
+ }
+ std::string config_path = argv[1];
+ std::cout << "config path is " << config_path << std::endl;
+ std::unique_ptr<Llm> llm(Llm::createLLM(config_path));
+ {
+ AUTOTIME;
+ llm->load();
+ }
+ {
+ AUTOTIME;
+ trace_prepare(llm.get());
+ }
+ std::string prompt_file = argv[2];
+ std::unique_ptr<std::ofstream> perfOS(nullptr);
+ if (argc == 4) { perfOS.reset(new std::ofstream(argv[3])); }
+ return ppl_eval(llm.get(), prompt_file, perfOS.get());
+}
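A hypothetical perplexity run matching the usage string above, fed with the wikitext-2 data fetched by `get-wikitext-2-raw.sh` (the file layout inside the zip is an assumption) and an optional performance log for the visualization scripts:

```bash
./ppl_demo model_dir/config.json wikitext-2-raw/wiki.test.raw ppl_perf.txt
```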
diff --git a/transformers/llm/engine/include/evaluation/MemMonitor.hpp b/transformers/llm/engine/include/evaluation/MemMonitor.hpp
new file mode 100644
index 000000000..7750eb56c
--- /dev/null
+++ b/transformers/llm/engine/include/evaluation/MemMonitor.hpp
@@ -0,0 +1,42 @@
+#ifndef MEMMONITOR_hpp
+#define MEMMONITOR_hpp
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#define BUFFER_SIZE 256
+
+struct MemoryInfo {
+ // in MB
+ float total_phys_mem;
+ float free_phys_mem;
+ float total_swap;
+ float free_swap;
+ float process_resident_set_size;
+ float process_swap;
+ float process_virtual_mem_total;
+ float process_virtual_mem_used;
+};
+
+
+#if defined(__ANDROID__) || defined(linux) || defined(__APPLE__) || defined(__MACOSX)
+#define SELF_FILE "/proc/self/status"
+#define MEMINFO_FILE "/proc/meminfo"
+#endif // linux
+
+int readMemInfo(MemoryInfo *mem_info);
+
+int readProcStatus(MemoryInfo *mem_info);
+
+void printMemoryInfo(const MemoryInfo *mem_info);
+
+float getSysMemInc(MemoryInfo* prev, MemoryInfo* now);
+
+float getProcMem(MemoryInfo* info);
+
+#endif
\ No newline at end of file
diff --git a/transformers/llm/engine/include/evaluation/dataset.hpp b/transformers/llm/engine/include/evaluation/dataset.hpp
new file mode 100644
index 000000000..b9585fe71
--- /dev/null
+++ b/transformers/llm/engine/include/evaluation/dataset.hpp
@@ -0,0 +1,33 @@
+#ifndef LLM_DATASET_hpp
+#define LLM_DATASET_hpp
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "llm/llm.hpp"
+
+#include
+
+namespace MNN {
+namespace Transformer {
+
+
+// parse csv
+MNN_PUBLIC std::vector<std::vector<std::string>> parse_csv(const std::vector<std::string>& lines);
+void parse_jsonl(std::string prompt_file, std::vector<std::vector<std::vector<PromptItem>>>& dialogs);
+
+std::string getPPLType(std::string dataset_name);
+std::vector<std::string> rowsplit(std::string prompt_file);
+std::vector<std::string> plaintext(std::string prompt_file);
+std::vector<std::string> wikitext(std::string prompt_file);
+std::vector<std::vector<std::vector<PromptItem>>> shareGPT(std::string prompt_file, int sample_size=-1); // -1: no sampling
+
+} // Transformer
+} // MNN
+
+#endif // LLM_DATASET_hpp
\ No newline at end of file
diff --git a/transformers/llm/engine/include/evaluation/evaluation.hpp b/transformers/llm/engine/include/evaluation/evaluation.hpp
new file mode 100644
index 000000000..9ecf0c335
--- /dev/null
+++ b/transformers/llm/engine/include/evaluation/evaluation.hpp
@@ -0,0 +1,68 @@
+
+
+#ifndef TRANSFORMER_EVALUATION_hpp
+#define TRANSFORMER_EVALUATION_hpp
+
+#include
+#include
+#include "MemMonitor.hpp"
+
+namespace MNN {
+namespace Transformer {
+
+#define MICRO_TO_MILLI 1e-3f
+#define MILLI_TO_MICRO 1000
+#define MICRO_TO_SEC 1e-6f
+#define SEC_TO_MICRO 1000000
+
+#define MEGA_TO_GIGA (1/1024.f)
+#define GIGA_TO_MEGA 1024.f
+#define KILLO_TO_GIGA (1/1024.f/1024.f)
+#define GIGA_TO_KILLO (1024.f*1024.f)
+#define KILLO_TO_MEGA (1/1024.f)
+#define MEGA_TO_KILLO 1024.f
+#define BYTE_TO_MEGA (1/1024.f/1024.f)
+#define MEGA_TO_BYTE (1024.f*1024.f)
+
+struct PrefillTimePerformance {
+ size_t prefill_prev_token_ = 0;
+ size_t prefill_token_ = 0;
+ size_t prefill_us_ = 0;
+};
+
+struct DecodeTimePerformance {
+ size_t decode_prev_token_ = 0;
+ size_t decode_us_ = 0;
+};
+
+struct TimePerformance {
+ std::vector<PrefillTimePerformance> prefill_record_;
+ std::vector<DecodeTimePerformance> decode_record_;
+ std::vector<int> prompt_record_;
+};
+
+void appendNewPromptRecord(struct TimePerformance* perf, int input_len, bool reuse_kv);
+
+struct PrefillMemPerformance {
+ size_t prefill_prev_token_ = 0;
+ size_t prefill_token_ = 0;
+ float prefill_MB_ = 0;
+};
+
+struct DecodeMemPerformance {
+ size_t decode_prev_token_ = 0;
+ float decode_MB_ = 0;
+};
+
+struct MemPerformance {
+ std::vector<PrefillMemPerformance> prefill_record_;
+ std::vector<DecodeMemPerformance> decode_record_;
+};
+
+void mergePerformance(struct TimePerformance* dst, struct TimePerformance* src);
+void mergePerformance(struct MemPerformance* dst, struct MemPerformance* src);
+void clearPerformance(struct TimePerformance* perf);
+void clearPerformance(struct MemPerformance* perf);
+} // namespace Transformer
+} // namespace MNN
+#endif // TRANSFORMER_EVALUATION_hpp
\ No newline at end of file
diff --git a/transformers/llm/engine/include/llm/llm.hpp b/transformers/llm/engine/include/llm/llm.hpp
index 6d0ac2e99..b0d970aa0 100644
--- a/transformers/llm/engine/include/llm/llm.hpp
+++ b/transformers/llm/engine/include/llm/llm.hpp
@@ -18,6 +18,7 @@
#include
#include
+#include "evaluation/evaluation.hpp"
#include
#include
#include
@@ -28,6 +29,41 @@ namespace Transformer {
class Tokenizer;
class Pipeline;
class LlmConfig;
+class Sampler;
+class PromptLib;
+struct TimePerformance;
+
+
+//
+#define PromptItem std::pair<std::string, std::string>
+
+class MNN_PUBLIC LlmSessionInfo {
+public:
+ // Llm::forward needs, for mask and embedding.
+ int all_seq_len_=0, gen_seq_len_=0;
+ // Sampler needs
+ std::vector<int> tokens;
+ // PromptLib needs
+ std::vector<PromptItem> mHistory;
+ std::vector<PromptItem> mInputs;
+ // Performance needs
+ struct TimePerformance mTimePerformance;
+public:
+ LlmSessionInfo():all_seq_len_(0),gen_seq_len_(0){}
+ void resetSamplerFields();
+ void resetPromptFields();
+ void resetPerformanceFields();
+ void print_speed(std::ostream* os);
+ float average_total_speed();
+ float average_prefill_speed();
+ float average_decode_speed();
+ float getTotalPrefillTime();
+ float getTotalDecodeTime();
+ int getTotalPromptLen();
+ int getTotalDecodeLen();
+};
+
+
class DiskEmbedding;
enum TuneType {
@@ -36,26 +72,30 @@ enum TuneType {
};
class MNN_PUBLIC Llm {
- using PromptItem = std::pair; //
+public:
+ std::shared_ptr mSampler;
+ std::shared_ptr mPromptLib;
+ std::vector<LlmSessionInfo> mLlmSessionInfos; // Llm conversation session information. Currently, only mLlmSessionInfos[0] is allowed!
public:
Llm(std::shared_ptr config) : config_(config) {}
virtual ~Llm();
static Llm* createLLM(const std::string& config_path);
- void chat();
+ void chat(bool session_by_line = false, bool from_file = false,
+ std::istream* is = &std::cin, std::ostream* os = &std::cout,
+ const char* end_with = "\n", std::string exit_prompt = "/exit", std::string reset_token = "/reset");
void reset();
void trace(bool start);
void tuning(TuneType type, std::vector candidates);
virtual void load();
- MNN::Express::VARP forward(const std::vector& input_ids);
- int sample(MNN::Express::VARP logits, const std::vector& pre_ids);
- std::string apply_prompt_template(const std::string& user_content) const;
- std::string apply_chat_template(const std::vector& chat_prompts) const;
+ MNN::Express::VARP forward(const std::vector<int>& input_ids, bool is_prefill=true);
std::string response(const std::string& user_content, std::ostream* os = &std::cout, const char* end_with = nullptr);
- std::string response(const std::vector& chat_prompts, std::ostream* os = &std::cout, const char* end_with = nullptr);
+ std::string generate(const std::string& prompt, std::ostream* os = &std::cout, const char* end_with = "\n");
+ std::string generate(const std::vector<int>& input_ids, std::ostream* os = &std::cout, const char* end_with = "\n");
void generate_init();
- std::string generate(const std::vector& input_ids, std::ostream* os, const char* end_with);
- std::vector generate(const std::vector& input_ids, int max_new_tokens = -1);
+ std::string generateTrace(const std::vector<int>& input_ids, std::ostream* os, const char* end_with);
void print_speed();
+ void print_speed(std::ostream* os);
+ std::vector<float> perplexity(std::string prompt_file, std::ostream* statsOS = nullptr);
// config function
std::string dump_config();
bool set_config(const std::string& content);
@@ -70,16 +110,18 @@ class MNN_PUBLIC Llm {
virtual std::vector tokenizer_encode(const std::string& query, bool use_template = true);
friend class Pipeline;
public:
- // forward info
- int prompt_len_ = 0;
- int gen_seq_len_ = 0;
- int all_seq_len_ = 0;
- std::vector history_ids_;
- // time
- int64_t prefill_us_ = 0;
- int64_t decode_us_ = 0;
bool is_single_ = true;
bool attention_fused_ = true;
+ bool reuse_kv() const;
+public:
+ // time profile
+ float average_total_speed();
+ float average_prefill_speed();
+ float average_decode_speed();
+ float getTotalPrefillTime();
+ float getTotalDecodeTime();
+ int getTotalPromptLen();
+ int getTotalDecodeLen();
protected:
std::shared_ptr config_;
std::shared_ptr tokenizer_;
@@ -96,6 +138,10 @@ class MNN_PUBLIC Llm {
virtual MNN::Express::VARP gen_attention_mask(int seq_len);
virtual MNN::Express::VARP gen_position_ids(int seq_len);
bool mTracing = false;
+protected:
+ bool getUserPrompt(bool from_file, std::istream* is, std::string& user_str);
+ void chat_init();
+ void chat_reset();
};
// Embedding start
diff --git a/transformers/llm/engine/ios/mnn-llm/mnn-llm/LLMInferenceEngineWrapper.mm b/transformers/llm/engine/ios/mnn-llm/mnn-llm/LLMInferenceEngineWrapper.mm
index 4561338f8..5db6a9c81 100644
--- a/transformers/llm/engine/ios/mnn-llm/mnn-llm/LLMInferenceEngineWrapper.mm
+++ b/transformers/llm/engine/ios/mnn-llm/mnn-llm/LLMInferenceEngineWrapper.mm
@@ -89,26 +89,16 @@ - (void)processInput:(NSString *)input withStreamHandler:(StreamOutputHandler)ha
}
prompts.push_back(prompt);
}
- int prompt_len = 0;
- int decode_len = 0;
- int64_t prefill_time = 0;
- int64_t decode_time = 0;
for (int i = 0; i < prompts.size(); i++) {
llm->response(prompts[i], &os, "\n");
- prompt_len += llm->prompt_len_;
- decode_len += llm->gen_seq_len_;
- prefill_time += llm->prefill_us_;
- decode_time += llm->decode_us_;
}
- float prefill_s = prefill_time / 1e6;
- float decode_s = decode_time / 1e6;
os << "\n#################################\n"
- << "prompt tokens num = " << prompt_len << "\n"
- << "decode tokens num = " << decode_len << "\n"
- << "prefill time = " << std::fixed << std::setprecision(2) << prefill_s << " s\n"
- << " decode time = " << std::fixed << std::setprecision(2) << decode_s << " s\n"
- << "prefill speed = " << std::fixed << std::setprecision(2) << prompt_len / prefill_s << " tok/s\n"
- << " decode speed = " << std::fixed << std::setprecision(2) << decode_len / decode_s << " tok/s\n"
+ << "prompt tokens num = " << llm->getTotalPromptLen() << "\n"
+ << "decode tokens num = " << llm->getTotalDecodeLen() << "\n"
+ << "prefill time = " << std::fixed << std::setprecision(2) << llm->getTotalPrefillTime() << " s\n"
+ << " decode time = " << std::fixed << std::setprecision(2) << llm->getTotalDecodeTime() << " s\n"
+ << "prefill speed = " << std::fixed << std::setprecision(2) << llm->average_prefill_speed() << " tok/s\n"
+ << " decode speed = " << std::fixed << std::setprecision(2) << llm->average_decode_speed() << " tok/s\n"
<< "##################################\n";
os << "";
} else {
diff --git a/transformers/llm/engine/src/LlmSessionInfo.cpp b/transformers/llm/engine/src/LlmSessionInfo.cpp
new file mode 100644
index 000000000..a46dc5d8c
--- /dev/null
+++ b/transformers/llm/engine/src/LlmSessionInfo.cpp
@@ -0,0 +1,87 @@
+
+#include "llm/llm.hpp"
+
+namespace MNN {
+namespace Transformer {
+
+// LlmSessionInfo starts
+void LlmSessionInfo::resetSamplerFields() {
+ all_seq_len_ = 0;
+ gen_seq_len_ = 0;
+ tokens.clear();
+}
+void LlmSessionInfo::resetPromptFields() {
+ mHistory.clear();
+ mInputs.clear();
+}
+void LlmSessionInfo::resetPerformanceFields() {
+ clearPerformance(&mTimePerformance);
+}
+float LlmSessionInfo::average_total_speed() {
+ return (getTotalPromptLen()+getTotalDecodeLen())/(getTotalPrefillTime()+getTotalDecodeTime());
+}
+float LlmSessionInfo::average_prefill_speed() {
+ // prefill response rate
+ return getTotalPromptLen()/getTotalPrefillTime();
+}
+float LlmSessionInfo::average_decode_speed() {
+ return getTotalDecodeLen()/getTotalDecodeTime();
+}
+float LlmSessionInfo::getTotalPrefillTime() {
+ float sum = 0.f;
+ for (auto record : mTimePerformance.prefill_record_) {
+ sum += ((float)record.prefill_us_)*MICRO_TO_SEC;
+ }
+ return sum;
+}
+float LlmSessionInfo::getTotalDecodeTime() {
+ float sum = 0.0f;
+ for (auto record : mTimePerformance.decode_record_) {
+ sum += ((float)record.decode_us_)*MICRO_TO_SEC;
+ }
+ return sum;
+}
+int LlmSessionInfo::getTotalPromptLen() {
+ int prompt_len = 0;
+ if (mTimePerformance.prefill_record_.size() != mTimePerformance.prompt_record_.size()) {
+ for (auto record : mTimePerformance.prefill_record_) {
+ prompt_len += record.prefill_token_;
+ }
+ } else {
+ for (int r=0; r < mTimePerformance.prompt_record_.size(); ++r) {
+ prompt_len += mTimePerformance.prompt_record_[r];
+ }
+ }
+ return prompt_len;
+}
+int LlmSessionInfo::getTotalDecodeLen() {
+ return mTimePerformance.decode_record_.size();
+}
+void LlmSessionInfo::print_speed(std::ostream* os) {
+ // prefill statistics
+ (*os) << "" << std::endl;
+ if (mTimePerformance.prefill_record_.size() != mTimePerformance.prompt_record_.size()) {
+ (*os) << "prev_token,input_token,response_speed" << std::endl;
+ for (auto record : mTimePerformance.prefill_record_) {
+ (*os) << record.prefill_prev_token_ << "," << record.prefill_token_ << "," << record.prefill_token_/(((float)record.prefill_us_)*MICRO_TO_SEC) << std::endl;
+ }
+ } else {
+ (*os) << "prev_token,input_token,prompt_token,response_speed" << std::endl;
+ for (int r=0; r < mTimePerformance.prompt_record_.size(); ++r) {
+ auto record = mTimePerformance.prefill_record_[r];
+ auto prompt_len = mTimePerformance.prompt_record_[r];
+ (*os) << record.prefill_prev_token_ << "," << record.prefill_token_ << "," << prompt_len << "," << prompt_len/(((float)record.prefill_us_)*MICRO_TO_SEC) << std::endl;
+ }
+ }
+ (*os) << "" << std::endl;
+ // decode statistics
+ (*os) << "" << std::endl;
+ (*os) << "prev_token,response_speed" << std::endl;
+ for (auto record : mTimePerformance.decode_record_) {
+ (*os) << record.decode_prev_token_ << "," << 1./(((float)record.decode_us_)*MICRO_TO_SEC) << std::endl;
+ }
+ (*os) << "" << std::endl;
+}
+
+} // Transformer
+} // MNN
\ No newline at end of file
diff --git a/transformers/llm/engine/src/dataset.cpp b/transformers/llm/engine/src/dataset.cpp
new file mode 100644
index 000000000..c4d0db45a
--- /dev/null
+++ b/transformers/llm/engine/src/dataset.cpp
@@ -0,0 +1,223 @@
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "evaluation/dataset.hpp"
+#include
+#include
+#include
+
+namespace MNN {
+namespace Transformer {
+
+
+// parse file
+// csv json
+
+// parse csv
+std::vector<std::vector<std::string>> parse_csv(const std::vector<std::string>& lines) {
+ std::vector<std::vector<std::string>> csv_data;
+ std::string line;
+ std::vector<std::string> row;
+ std::string cell;
+ bool insideQuotes = false;
+ bool startCollecting = false;
+
+ // content to stream
+ std::string content = "";
+ for (auto line : lines) {
+ content = content + line + "\n";
+ }
+ std::istringstream stream(content);
+
+ while (stream.peek() != EOF) {
+ char c = stream.get();
+ if (c == '"') {
+ if (insideQuotes && stream.peek() == '"') { // quote
+ cell += '"';
+ stream.get(); // skip quote
+ } else {
+ insideQuotes = !insideQuotes; // start or end text in quote
+ }
+ startCollecting = true;
+ } else if (c == ',' && !insideQuotes) { // end element, start new element
+ row.push_back(cell);
+ cell.clear();
+ startCollecting = false;
+ } else if ((c == '\n' || stream.peek() == EOF) && !insideQuotes) { // end line
+ row.push_back(cell);
+ csv_data.push_back(row);
+ cell.clear();
+ row.clear();
+ startCollecting = false;
+ } else {
+ cell += c;
+ startCollecting = true;
+ }
+ }
+ return csv_data;
+}
+
+// dialog, turn,
+void parse_jsonl(std::string prompt_file, std::vector<std::vector<std::vector<PromptItem>>>& dialogs) {
+ std::ifstream prompt_fs(prompt_file);
+ std::string prompt;
+ while(std::getline(prompt_fs, prompt)) {
+ rapidjson::Document document;
+ document.Parse(prompt.c_str());
+ std::vector<std::vector<PromptItem>> cnv;
+ if(document.HasMember("conversation")) {
+ auto& value = document["conversation"];
+ if (value.IsArray()) {
+ for (auto& v : value.GetArray()) {
+ if (v.IsObject()) {
+ std::vector<PromptItem> result;
+ for (auto itr = v.MemberBegin(); itr != v.MemberEnd(); ++itr) {
+ // {"human"/"user": , "assistant": }
+ result.push_back(std::make_pair(itr->name.GetString(), itr->value.GetString()));
+ }
+ cnv.push_back(result);
+ }
+ }
+ }
+ }
+ dialogs.push_back(cnv);
+ }
+}
+
+void write_jsonl(std::string prompt_file, const std::vector<std::vector<std::vector<PromptItem>>>& dialogs) {
+ std::ofstream prompt_fs(prompt_file);
+ for(auto& dialog : dialogs) {
+ rapidjson::Document document;
+ document.SetObject();
+ rapidjson::Value conversation(rapidjson::kArrayType);
+ conversation.SetArray();
+ for (auto& turn : dialog) {
+ rapidjson::Value sentence(rapidjson::kObjectType);
+ sentence.SetObject();
+ for (auto& role : turn) {
+ sentence.AddMember(rapidjson::Value(role.first.c_str(), document.GetAllocator()),
+ rapidjson::Value(role.second.c_str(), document.GetAllocator()), document.GetAllocator());
+ }
+ conversation.PushBack(sentence, document.GetAllocator());
+ }
+ document.AddMember("conversation", conversation, document.GetAllocator());
+ // write to file
+ rapidjson::StringBuffer buffer;
+ rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
+ document.Accept(writer);
+ prompt_fs << buffer.GetString() << std::endl;
+ }
+}
+
+
+// dataset
+// wikitext, ShareGPT
+
+std::string getPPLType(std::string dataset_name) {
+ if (dataset_name == "wikitext"
+ || dataset_name == "plaintext"
+ || dataset_name == "rowsplit") {
+ return "text";
+ } else if (dataset_name == "shareGPT") {
+ return "chat";
+ } else {
+ // default chat
+ return "chat";
+ }
+}
+
+std::vector<std::string> plaintext(std::string prompt_file) {
+ // split by line
+ std::ifstream prompt_fs(prompt_file);
+ std::vector<std::string> prompts;
+ std::string prompt;
+ prompts.push_back("");
+ while (std::getline(prompt_fs, prompt)) {
+ if (!prompt.empty() && (prompt.back() == '\r' || prompt.back() == '\n')) {
+ prompt.pop_back();
+ }
+ // concatenate.
+ prompts.back() += prompt + "\n";
+ }
+ return prompts;
+}
+
+std::vector<std::string> rowsplit(std::string prompt_file) {
+ // split by line
+ std::ifstream prompt_fs(prompt_file);
+ std::vector<std::string> prompts;
+ std::string prompt;
+ while (std::getline(prompt_fs, prompt)) {
+ if (!prompt.empty() && (prompt.back() == '\r' || prompt.back() == '\n')) {
+ prompt.pop_back();
+ }
+ prompts.push_back(prompt);
+ }
+ return prompts;
+}
+
+// wikitext
+void removeSubstrs(std::string& s, std::string p) {
+ std::string::size_type n = p.length();
+ for (std::string::size_type i = s.find(p); i != std::string::npos; i = s.find(p))
+ s.erase(i, n);
+}
+std::vector<std::string> wikitext(std::string prompt_file) {
+ // split wiki text into " = " first-level column.
+ std::ifstream prompt_fs(prompt_file);
+ std::vector<std::string> prompts;
+ std::string prompt;
+ while (std::getline(prompt_fs, prompt)) {
+ if (!prompt.empty() && (prompt.back() == '\r' || prompt.back() == '\n')) {
+ prompt.pop_back();
+ }
+ if (prompt.size() < 4) continue;
+ removeSubstrs(prompt, "@-@");
+ if ((prompts.size() == 0) \
+ || (prompt.size() >= 4 \
+ && prompt.at(0) == ' ' \
+ && prompt.at(1) == '=' \
+ && prompt.at(2) == ' ' \
+ && prompt.at(3) != '=')) {
+ // first-level column.
+ prompts.push_back(prompt);
+ } else {
+ // concatenate.
+ prompts.back() += "\n" + prompt;
+ }
+ }
+ return prompts;
+}
+
+std::string genSampleName(std::string oriName, int sample_size) {
+ const size_t last_slash_idx = oriName.rfind('.');
+ auto stem = oriName.substr(0, last_slash_idx);
+ return stem + "_sample" + std::to_string(sample_size) + ".jsonl";
+}
+
+std::vector<std::vector<std::vector<PromptItem>>> shareGPT(std::string prompt_file, int sample_size) {
+ std::vector<std::vector<std::vector<PromptItem>>> dialogs, dataset;
+ parse_jsonl(prompt_file, dialogs);
+ // randomly sample a subset
+ if (sample_size > 0 && sample_size < dialogs.size()){
+ std::random_device rd;
+ std::mt19937 g(rd());
+ std::shuffle(dialogs.begin(), dialogs.end(), g);
+ dataset.insert(dataset.end(), dialogs.begin(), dialogs.begin() + sample_size);
+ dialogs = dataset;
+ // store dialogs to file
+ write_jsonl(genSampleName(prompt_file, sample_size), dialogs);
+ }
+ return dialogs;
+}
+
+
+} // Transformer
+} // MNN
diff --git a/transformers/llm/engine/src/evaluation.cpp b/transformers/llm/engine/src/evaluation.cpp
new file mode 100644
index 000000000..cd2524a4e
--- /dev/null
+++ b/transformers/llm/engine/src/evaluation.cpp
@@ -0,0 +1,30 @@
+
+
+#include
+#include
+#include "evaluation/evaluation.hpp"
+
+namespace MNN {
+namespace Transformer {
+
+void clearPerformance(struct TimePerformance* perf) {
+ perf->prefill_record_.clear();
+ perf->decode_record_.clear();
+ perf->prompt_record_.clear();
+}
+void appendNewPromptRecord(struct TimePerformance* perf, int input_len, bool reuse_kv) {
+ if (reuse_kv) {
+ perf->prompt_record_.push_back(input_len);
+ } else {
+ // not reuse kv
+ if (!perf->decode_record_.empty()) {
+ perf->prompt_record_.push_back(input_len - (perf->decode_record_.back().decode_prev_token_+1));
+ } else {
+ // first prefill
+ perf->prompt_record_.push_back(input_len);
+ }
+ }
+}
+
+} // Transformer
+} // MNN
\ No newline at end of file
diff --git a/transformers/llm/engine/src/llm.cpp b/transformers/llm/engine/src/llm.cpp
index fd3a9d037..6a1d7fdf8 100644
--- a/transformers/llm/engine/src/llm.cpp
+++ b/transformers/llm/engine/src/llm.cpp
@@ -16,6 +16,9 @@
#include
#include "cpp/ExprDebug.hpp"
#include "llm/llm.hpp"
+#include "evaluation/evaluation.hpp"
+#include "sampler.hpp"
+#include "prompt.hpp"
#include "tokenizer.hpp"
#include "llmconfig.hpp"
// 0: no debug, 1: test op time, 2: print tensor info
@@ -266,7 +269,7 @@ void Llm::load() {
}
MNN_PRINT("Load Module Done!\n");
} else {
- MNN_ERROR("Split version is depercerate\n");
+ MNN_ERROR("Split version is deprecated\n");
}
decode_modules_.resize(modules_.size());
for (int v=0; v candidates) {
if (logits->getInfo()->size == 0) {
return;
}
- auto token = sample(logits, {});
+ // no need for sampling here, because the metal OP does not affect sampling much.
auto et = std::chrono::system_clock::now();
int64_t time = std::chrono::duration_cast(et - st).count();
if(time < min_time) {
@@ -380,7 +394,9 @@ void Llm::tuning(TuneType type, std::vector candidates) {
runtime_manager_->setHint(MNN::Interpreter::OP_ENCODER_NUMBER_FOR_COMMIT, prefer_candidate);
}
-VARP Llm::forward(const std::vector& input_ids) {
+VARP Llm::forward(const std::vector& input_ids, bool is_prefill) {
+ if (is_prefill) current_modules_ = prefill_modules_;
+ else current_modules_ = decode_modules_;
int seq_len = input_ids.size();
auto attention_mask = gen_attention_mask(seq_len);
auto position_ids = gen_position_ids(seq_len);
@@ -401,245 +417,134 @@ VARP Llm::forward(const std::vector& input_ids) {
past_key_values_[0] = outputs[1];
}
} else {
- MNN_ERROR("Split models is depercarate\n");
+ MNN_ERROR("Split models is deprecated\n");
return nullptr;
}
- all_seq_len_ += seq_len;
- gen_seq_len_++;
+ // sequence length is handled in forward.
+ mLlmSessionInfos[0].all_seq_len_ += seq_len;
+ mLlmSessionInfos[0].gen_seq_len_++;
return logits;
}
-int Llm::sample(VARP logits, const std::vector& pre_ids) {
- std::unordered_set ids_set(pre_ids.begin(), pre_ids.end());
- auto scores = (float*)(logits->readMap());
- auto size = logits->getInfo()->size;
- // repetition penalty
- const float repetition_penalty = 1.1;
- for (auto id : ids_set) {
- float score = scores[id];
- scores[id] = score < 0 ? score * repetition_penalty : score / repetition_penalty;
- }
- // argmax
- float max_score = 0;
- int token_id = 0;
- for (int i = 0; i < size; i++) {
- float score = scores[i];
- if (score > max_score) {
- max_score = score;
- token_id = i;
- }
- }
- return token_id;
-}
-
-static std::string apply_template(std::string prompt_template, const std::string& content, const std::string& role = "") {
- if (prompt_template.empty()) return content;
- if (!role.empty()) {
- const std::string placeholder = "%r";
- size_t start_pos = prompt_template.find(placeholder);
- if (start_pos == std::string::npos) return content;
- prompt_template.replace(start_pos, placeholder.length(), role);
- }
- const std::string placeholder = "%s";
- size_t start_pos = prompt_template.find(placeholder);
- if (start_pos == std::string::npos) return content;
- prompt_template.replace(start_pos, placeholder.length(), content);
- return prompt_template;
-}
-
-std::string Llm::apply_prompt_template(const std::string& user_content) const {
- auto chat_prompt = config_->prompt_template();
- return apply_template(chat_prompt, user_content);
+// < "app_type": "chat"
+bool Llm::getUserPrompt(bool from_file, std::istream* is, std::string& user_str) {
+ if (!from_file) std::cout << "\nQ: ";
+ return (bool)std::getline(*is, user_str);
}
-std::string Llm::apply_chat_template(const std::vector& chat_prompts) const {
- auto chat_template = config_->chat_template();
- std::string prompt_result;
- auto iter = chat_prompts.begin();
- for (; iter != chat_prompts.end() - 1; ++iter) {
- prompt_result += apply_template(chat_template, iter->second, iter->first);
- }
- if (iter->first == "user") {
- prompt_result += apply_prompt_template(iter->second);
- } else {
- prompt_result += apply_template(chat_template, iter->second, iter->first);
- }
- return prompt_result;
-}
-
-void Llm::chat() {
- std::vector history;
- history.push_back(std::make_pair("system", "You are a helpful assistant."));
- while (true) {
- std::cout << "\nQ: ";
- std::string user_str;
- std::cin >> user_str;
- if (user_str == "/exit") {
+void Llm::chat(bool session_by_line, bool from_file,
+ std::istream* is, std::ostream* os,
+ const char* end_with, std::string exit_token, std::string reset_token) {
+ // handle system prompt
+ reset();
+ std::string user_str;
+ while (getUserPrompt(from_file, is, user_str)) {
+ // whether to end
+ if (user_str == exit_token) {
+ reset();
break;
}
- if (user_str == "/reset") {
- history.resize(1);
- std::cout << "\nA: reset done." << std::endl;
+ // whether to reset
+ if (session_by_line || user_str == reset_token) {
+ reset();
+ if (!from_file) std::cout << "\nreset done." << std::endl;
continue;
}
- std::cout << "\nA: " << std::flush;
- if (config_->reuse_kv()) {
- response(user_str);
- } else {
- history.emplace_back(std::make_pair("user", user_str));
- auto assistant_str = response(history);
- history.emplace_back(std::make_pair("assistant", assistant_str));
- }
- std::cout << std::endl;
+ // get answer
+ (*os) << "\nA: " << std::flush;
+ response(user_str, os, end_with);
+ (*os) << std::endl;
}
+ reset();
+}
+
+std::string Llm::response(const std::string& user_str, std::ostream* os, const char* end_with) {
+ mPromptLib->appendUserPrompt(user_str);
+ auto assistant_str = generate(mPromptLib->getLLMInput(), os, end_with);
+ mPromptLib->appendLLMOutput(assistant_str);
+ return assistant_str;
}
+// "app_type": "chat" >
+
+
void Llm::reset() {
- history_ids_.clear();
- all_seq_len_ = 0;
+ // clear KV cache
+ // KV cache automatically cleared as long as seq_len reset!
+ mLlmSessionInfos.clear();
+ mLlmSessionInfos.emplace_back(LlmSessionInfo());
}
+bool Llm::reuse_kv() const {
+ return config_->reuse_kv();
+}
+
+
+// < generate
void Llm::generate_init() {
- // init status
- gen_seq_len_ = 0;
- prefill_us_ = 0;
- decode_us_ = 0;
+ // handle past_key_values if not attention_fused_
past_key_values_.clear();
- if (is_single_) {
- past_key_values_.push_back(_Input(key_value_shape_, NCHW));
- } else {
- for (int i = 0; i < config_->layer_nums(); i++) {
+ if (!attention_fused_) {
+ if (is_single_) {
past_key_values_.push_back(_Input(key_value_shape_, NCHW));
+ } else {
+ MNN_ERROR("Split version is deprecated\n");
}
}
- if (!config_->reuse_kv()) {
- all_seq_len_ = 0;
- history_ids_.clear();
+ if (!reuse_kv()) {
+ // only reset sampler. The history is handled by mPromptLib.
+ mLlmSessionInfos[0].resetSamplerFields();
}
current_modules_ = prefill_modules_;
}
-std::vector Llm::generate(const std::vector& input_ids, int max_new_tokens) {
- generate_init();
- std::vector output_ids, all_ids = input_ids;
- prompt_len_ = static_cast(input_ids.size());
- if (max_new_tokens < 0) { max_new_tokens = config_->max_new_tokens(); }
- // prefill
- current_modules_ = prefill_modules_;
- auto logits = forward(input_ids);
- if (logits.get() == nullptr) {
- return {};
- }
- int token = sample(logits, all_ids);
- output_ids.push_back(token);
- all_ids.push_back(token);
- // decode
- current_modules_ = decode_modules_;
- while (gen_seq_len_ < max_new_tokens) {
- logits = nullptr;
- logits = forward({token});
- if (logits.get() == nullptr) {
- return {};
- }
- token = sample(logits, all_ids);
- if (is_stop(token)) { break; }
- output_ids.push_back(token);
- all_ids.push_back(token);
- }
- return output_ids;
+std::string Llm::generate(const std::string& prompt, std::ostream* os, const char* end_with) {
+ if (prompt.empty()) { return ""; }
+ if (!end_with) { end_with = "\n"; }
+ // std::cout << "# prompt : " << prompt << std::endl;
+ auto input_ids = tokenizer_encode(prompt, false);
+ std::string out_str = generate(input_ids, os, end_with);
+ return out_str;
}
std::string Llm::generate(const std::vector& input_ids, std::ostream* os, const char* end_with) {
+ if (mTracing) return generateTrace(input_ids, os, end_with);
+ if (input_ids.empty()) { return ""; }
+ if (!end_with) { end_with = "\n"; }
+ generate_init();
+ // printf("input_ids (%lu): ", input_ids.size()); for (auto id : input_ids) printf("%d, ", id); printf("\n");
+ std::string out_str = mSampler->sample(input_ids, os, end_with, &(mLlmSessionInfos[0].mTimePerformance));
+ return out_str;
+}
+
+
+std::string Llm::generateTrace(const std::vector& input_ids, std::ostream* os, const char* end_with) {
if (mTracing) {
// Skip real forward
- current_modules_ = prefill_modules_;
- forward(input_ids);
- current_modules_ = decode_modules_;
- forward({input_ids[0]});
- forward({input_ids[0]});
+ forward(input_ids, true);
+ forward({input_ids[0]}, false);
+ forward({input_ids[0]}, false);
return "Test";
}
- prompt_len_ = static_cast(input_ids.size());
- history_ids_.insert(history_ids_.end(), input_ids.begin(), input_ids.end()); // push to history_ids_
- auto st = std::chrono::system_clock::now();
- current_modules_ = prefill_modules_;
- auto logits = forward(input_ids);
- if (nullptr == logits.get()) {
- return "";
- }
- int token = sample(logits, history_ids_);
- auto et = std::chrono::system_clock::now();
- current_modules_ = decode_modules_;
- std::string output_str = tokenizer_decode(token);
- prefill_us_ = std::chrono::duration_cast(et - st).count();
- *os << output_str << std::flush;
- while (gen_seq_len_ < config_->max_new_tokens()) {
- st = std::chrono::system_clock::now();
- history_ids_.push_back(token);
- logits = nullptr;
- logits = forward({token});
- if (nullptr == logits.get()) {
- return "";
- }
- if (logits->getInfo()->size == 0) {
- return "";
- }
- token = sample(logits, history_ids_);
- et = std::chrono::system_clock::now();
- decode_us_ += std::chrono::duration_cast(et - st).count();
- if (is_stop(token)) {
- *os << end_with << std::flush;
- break;
- }
- auto word = tokenizer_decode(token);
- *os << word << std::flush;
- output_str += word;
- }
- ExecutorScope::Current()->gc(Executor::FULL);
-#ifdef DUMP_PROFILE_INFO
- print_speed();
-#endif
- return output_str;
+ return "Test";
}
std::vector Llm::tokenizer_encode(const std::string& user_content, bool use_template) {
if (!use_template) {
return tokenizer_->encode(user_content);
}
- auto prompt = apply_prompt_template(user_content);
+ auto prompt = mPromptLib->applyTemplate(user_content);
auto input_ids = tokenizer_->encode(prompt);
return input_ids;
}
-std::string Llm::response(const std::string& user_content, std::ostream* os, const char* end_with) {
- generate_init();
- if (!end_with) { end_with = "\n"; }
- std::vector input_ids;
- if (config_->reuse_kv()) {
- auto prompt = apply_prompt_template(user_content);
- if (all_seq_len_ > 0) {
- prompt = "<|im_end|>\n" + prompt;
- }
- input_ids = tokenizer_->encode(prompt);
- } else {
- input_ids = tokenizer_encode(user_content);
- }
- return generate(input_ids, os, end_with);
-}
-std::string Llm::response(const std::vector& chat_prompts, std::ostream* os, const char* end_with) {
- if (chat_prompts.empty()) { return ""; }
- generate_init();
- if (!end_with) { end_with = "\n"; }
- auto prompt = apply_chat_template(chat_prompts);
- if (config_->reuse_kv() && all_seq_len_ > 0) {
- prompt = "<|im_end|>\n" + prompt;
- }
- // std::cout << "# prompt : " << prompt << std::endl;
- auto input_ids = tokenizer_->encode(prompt);
- // printf("input_ids (%lu): ", input_ids.size()); for (auto id : input_ids) printf("%d, ", id); printf("\n");
- return generate(input_ids, os, end_with);
+// < evaluation
+std::vector<float> Llm::perplexity(std::string prompt_file, std::ostream* perfOS) {
+ return mSampler->perplexity(prompt_file, perfOS);
}
+// evaluation >
+
Llm::~Llm() {
#if DEBUG_MODE==1
@@ -671,23 +576,64 @@ Llm::~Llm() {
runtime_manager_.reset();
}
+// < speed
void Llm::print_speed() {
- auto prefill_s = prefill_us_ * 1e-6;
- auto decode_s = decode_us_ * 1e-6;
- auto total_s = prefill_s + decode_s;
printf("\n#################################\n");
- printf(" total tokens num = %d\n", prompt_len_ + gen_seq_len_);
- printf("prompt tokens num = %d\n", prompt_len_);
- printf("output tokens num = %d\n", gen_seq_len_);
- printf(" total time = %.2f s\n", total_s);
- printf("prefill time = %.2f s\n", prefill_s);
- printf(" decode time = %.2f s\n", decode_s);
- printf(" total speed = %.2f tok/s\n", (prompt_len_ + gen_seq_len_) / total_s);
- printf("prefill speed = %.2f tok/s\n", prompt_len_ / prefill_s);
- printf(" decode speed = %.2f tok/s\n", gen_seq_len_ / decode_s);
- printf(" chat speed = %.2f tok/s\n", gen_seq_len_ / total_s);
+ printf("average total speed = %.3f tok/s\n", average_total_speed());
+ printf("average prefill speed = %.3f tok/s\n", average_prefill_speed());
+ printf("average decode speed = %.3f tok/s\n", average_decode_speed());
printf("##################################\n");
+ #if DEBUG_MODE==1
+ if (nullptr != gTimeTraceInfo) {
+ float opSummer = 0.0f;
+ float opFlopsSummber = 0.0f;
+ for (auto& iter : gTimeTraceInfo->mTypes) {
+ float summer = 0.0f;
+ float summerflops = 0.0f;
+ for (auto& t : iter.second) {
+ for (auto& t0 : t.second) {
+ summer += t0.first;
+ summerflops += t0.second;
+ }
+ }
+ summer = summer;
+ summerflops = summerflops;
+ MNN_PRINT("%s : %.7f, FLOP: %.7f, Speed: %.7f GFlops\n", iter.first.c_str(), summer, summerflops, summerflops / summer);
+ opSummer += summer;
+ opFlopsSummber+= summerflops;
+ }
+ MNN_PRINT("OP Summer: %.7f, Flops: %.7f, Speed: %.7f GFlops\n", opSummer, opFlopsSummber, opFlopsSummber/opSummer);
+ }
+ #endif
+}
+
+void Llm::print_speed(std::ostream* os) {
+ mLlmSessionInfos[0].print_speed(os);
+}
+
+float Llm::average_total_speed() {
+ return mLlmSessionInfos[0].average_total_speed();
+}
+float Llm::average_prefill_speed() {
+ // prefill response rate
+ return mLlmSessionInfos[0].average_prefill_speed();
+}
+float Llm::average_decode_speed() {
+ return mLlmSessionInfos[0].average_decode_speed();
+}
+float Llm::getTotalPrefillTime() {
+ return mLlmSessionInfos[0].getTotalPrefillTime();
+}
+float Llm::getTotalDecodeTime() {
+ return mLlmSessionInfos[0].getTotalDecodeTime();
+}
+int Llm::getTotalPromptLen() {
+ return mLlmSessionInfos[0].getTotalPromptLen();
+}
+int Llm::getTotalDecodeLen() {
+ return mLlmSessionInfos[0].getTotalDecodeLen();
}
+// speed >
static inline bool needNewVar(VARP var, int axis, int seq_len) {
if (var == nullptr) {
@@ -722,34 +668,35 @@ std::string Llm::tokenizer_decode(int id) {
}
VARP Llm::gen_attention_mask(int seq_len) {
- int kv_seq_len = all_seq_len_ + seq_len;
+ int kv_seq_len_=mLlmSessionInfos[0].all_seq_len_+seq_len, gen_seq_len_=mLlmSessionInfos[0].gen_seq_len_;
+ int prev_seq_len_ = kv_seq_len_ - seq_len;
if (seq_len == 1) {
- kv_seq_len = seq_len;
+ kv_seq_len_ = seq_len;
}
if (config_->attention_mask() == "float") {
if (needNewVar(attention_mask_, 2, seq_len)) {
- attention_mask_ = _Input({1, 1, seq_len, kv_seq_len}, NCHW, halide_type_of());
+ attention_mask_ = _Input({1, 1, seq_len, kv_seq_len_}, NCHW, halide_type_of());
} else {
return attention_mask_;
}
auto ptr = attention_mask_->writeMap();
for (int i = 0; i < seq_len; i++) {
- for (int j = 0; j < kv_seq_len; j++) {
- int row = i + all_seq_len_;
- ptr[kv_seq_len * i + j] = (j > row) * std::numeric_limits::lowest();
+ for (int j = 0; j < kv_seq_len_; j++) {
+ int row = i + prev_seq_len_;
+ ptr[kv_seq_len_ * i + j] = (j > row) * std::numeric_limits::lowest();
}
}
return attention_mask_;
} else {
if (needNewVar(attention_mask_, 2, seq_len)) {
- attention_mask_ = _Input({1, 1, seq_len, kv_seq_len}, NCHW, halide_type_of<int>());
+ attention_mask_ = _Input({1, 1, seq_len, kv_seq_len_}, NCHW, halide_type_of<int>());
} else {
return attention_mask_;
}
auto ptr = attention_mask_->writeMap();
if (config_->attention_mask() == "glm") {
// chatglm
- for (int i = 0; i < seq_len * kv_seq_len; i++) {
+ for (int i = 0; i < seq_len * kv_seq_len_; i++) {
ptr[i] = 0;
}
if (seq_len > 1) {
@@ -760,8 +707,8 @@ VARP Llm::gen_attention_mask(int seq_len) {
} else {
bool is_glm2 = config_->attention_mask() == "glm2";
for (int i = 0; i < seq_len; i++) {
- for (int j = 0; j < kv_seq_len; j++) {
- int row = i + all_seq_len_;
+ for (int j = 0; j < kv_seq_len_; j++) {
+ int row = i + prev_seq_len_;
ptr[seq_len * i + j] = is_glm2 ? j > row : j <= row;
}
}
@@ -771,6 +718,8 @@ VARP Llm::gen_attention_mask(int seq_len) {
}
VARP Llm::gen_position_ids(int seq_len) {
+ int kv_seq_len_ = mLlmSessionInfos[0].all_seq_len_ + seq_len, gen_seq_len_ = mLlmSessionInfos[0].gen_seq_len_;
+ int prev_seq_len_ = kv_seq_len_ - seq_len;
if (config_->attention_mask() == "glm") {
// chatglm
if (needNewVar(position_ids_, 2, seq_len)) {
@@ -778,7 +727,7 @@ VARP Llm::gen_position_ids(int seq_len) {
}
auto ptr = position_ids_->writeMap();
if (seq_len == 1) {
- ptr[0] = all_seq_len_ - gen_seq_len_ - 2;
+ ptr[0] = prev_seq_len_ - gen_seq_len_ - 2;
ptr[1] = gen_seq_len_ + 1;
} else {
for (int i = 0; i < seq_len - 1; i++) {
@@ -796,10 +745,10 @@ VARP Llm::gen_position_ids(int seq_len) {
}
auto ptr = position_ids_->writeMap();
if (seq_len == 1) {
- ptr[0] = is_glm2 ? gen_seq_len_ : all_seq_len_;
+ ptr[0] = is_glm2 ? gen_seq_len_ : prev_seq_len_;
} else {
for (int i = 0; i < seq_len; i++) {
- ptr[i] = i + all_seq_len_;
+ ptr[i] = i + prev_seq_len_;
}
}
return position_ids_;
@@ -868,7 +817,8 @@ std::vector Lvlm::image_process(const std::string& image_info) {
}
std::vector<int> Lvlm::tokenizer_encode(const std::string& query, bool use_template) {
- auto prompt = apply_prompt_template(query);
+ auto prompt = query;
+ if (!use_template) { prompt = mPromptLib->applyTemplate(query); }
// split query
std::regex img_regex("<img>(.*?)</img>");
std::string::const_iterator searchStart(prompt.cbegin());
@@ -976,7 +926,7 @@ VARP Embedding::ids_embedding(const std::vector<int>& ids) {
}
VARP Embedding::txt_embedding(const std::string& txt) {
- return ids_embedding(tokenizer_encode(txt));
+ return ids_embedding(tokenizer_encode(txt, false));
}
VARP Embedding::gen_attention_mask(int seq_len) {
diff --git a/transformers/llm/engine/src/llmconfig.cpp b/transformers/llm/engine/src/llmconfig.cpp
new file mode 100644
index 000000000..458f3c4d1
--- /dev/null
+++ b/transformers/llm/engine/src/llmconfig.cpp
@@ -0,0 +1,47 @@
+//
+// llmconfig.cpp
+//
+// Created by MNN on 2024/07/19.
+// ZhaodeWang
+//
+
+
+
+#include "rapidjson/document.h"
+#include
+#include
+#include "llmconfig.hpp"
+
+namespace MNN {
+namespace Transformer {
+
+bool merge_json(rapidjson::Value& destination, const rapidjson::Value& source,
+ rapidjson::Document::AllocatorType& allocator) {
+ if (!source.IsObject() || !destination.IsObject()) {
+ return false;
+ }
+
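+ // Example: destination {"a": 1, "b": {"c": 2}} merged with source {"b": {"c": 9, "d": 3}, "e": 4}
+ // becomes {"a": 1, "b": {"c": 9, "d": 3}, "e": 4}: nested objects merge recursively, scalar conflicts take the source value, new keys are added.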
+ for (auto it = source.MemberBegin(); it != source.MemberEnd(); ++it) {
+ const char* key = it->name.GetString();
+ if (destination.HasMember(key)) {
+ if (destination[key].IsObject() && it->value.IsObject()) {
+ // Recursively merge the two JSON objects
+ merge_json(destination[key], it->value, allocator);
+ } else {
+ // Overwrite the value in the destination
+ destination[key].CopyFrom(it->value, allocator);
+ }
+ } else {
+ // Add the value to the destination
+ rapidjson::Value newKey(key, allocator);
+ rapidjson::Value newValue;
+ newValue.CopyFrom(it->value, allocator);
+ destination.AddMember(newKey, newValue, allocator);
+ }
+ }
+ return true;
+}
+
+} // Transformer
+} // MNN
+
diff --git a/transformers/llm/engine/src/llmconfig.hpp b/transformers/llm/engine/src/llmconfig.hpp
index 327ecaaff..301ce38ae 100644
--- a/transformers/llm/engine/src/llmconfig.hpp
+++ b/transformers/llm/engine/src/llmconfig.hpp
@@ -5,9 +5,18 @@
// ZhaodeWang
//
-#include "rapidjson/document.h"
+#ifndef LLMCONFIG_Hpp
+#define LLMCONFIG_Hpp
+
+#include
+#include
+#include
+#include
+#include
#include
#include
+
+
namespace MNN {
namespace Transformer {
@@ -36,31 +45,7 @@ static inline std::string file_name(const std::string& path) {
}
bool merge_json(rapidjson::Value& destination, const rapidjson::Value& source,
- rapidjson::Document::AllocatorType& allocator) {
- if (!source.IsObject() || !destination.IsObject()) {
- return false;
- }
-
- for (auto it = source.MemberBegin(); it != source.MemberEnd(); ++it) {
- const char* key = it->name.GetString();
- if (destination.HasMember(key)) {
- if (destination[key].IsObject() && it->value.IsObject()) {
- // Recursively merge the two JSON objects
- merge_json(destination[key], it->value, allocator);
- } else {
- // Overwrite the value in the destination
- destination[key].CopyFrom(it->value, allocator);
- }
- } else {
- // Add the value to the destination
- rapidjson::Value newKey(key, allocator);
- rapidjson::Value newValue;
- newValue.CopyFrom(it->value, allocator);
- destination.AddMember(newKey, newValue, allocator);
- }
- }
- return true;
-}
+ rapidjson::Document::AllocatorType& allocator);
class rapid_json_wrapper {
public:
@@ -98,6 +83,13 @@ class rapid_json_wrapper {
return buffer.GetString();
}
// read value
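+ // float overload, used by the sampler getters below (temperature, topP, minP, tfsZ, typical, penalty, ngram_factor).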
+ float value(const char* key, const float& default_value) const {
+ if (document.HasMember(key)) {
+ const auto& value = document[key];
+ if (value.IsFloat()) return value.GetFloat();
+ }
+ return default_value;
+ }
int value(const char* key, const int& default_value) const {
if (document.HasMember(key)) {
const auto& value = document[key];
@@ -162,6 +154,21 @@ class rapid_json_wrapper {
}
return default_value;
}
+ std::vector<std::string> value(const char* key, const std::vector<std::string>& default_value) const {
+ if (document.HasMember(key)) {
+ const auto& value = document[key];
+ if (value.IsArray()) {
+ std::vector<std::string> result;
+ for (auto& v : value.GetArray()) {
+ if (v.IsString()) {
+ result.push_back(v.GetString());
+ }
+ }
+ return result;
+ }
+ }
+ return default_value;
+ }
std::string value(const char key[], const char default_value[]) const {
return value(key, std::string(default_value));
}
@@ -248,6 +255,10 @@ class LlmConfig {
// model file config end >
// < generate config start
+ int max_all_tokens() const {
+ return config_.value("max_all_tokens", 2048);
+ }
+
int max_new_tokens() const {
return config_.value("max_new_tokens", 512);
}
@@ -324,18 +335,96 @@ class LlmConfig {
return llm_config_.value("attention_fused", true);
}
- std::string chat_template() const {
- return llm_config_.value("chat_template", "");
+ std::string system_prompt_template() const {
+ return llm_config_.value("system_prompt_template", "<|im_start|>system\n%s<|im_end|>\n");
}
-
- std::string prompt_template() const {
- return llm_config_.value("prompt_template", "");
+ std::string user_prompt_template() const {
+ return llm_config_.value("user_prompt_template", "<|im_start|>user\n%s<|im_end|>\n");
+ }
+ std::string assistant_prefix() const {
+ return llm_config_.value("assistant_prefix", "<|im_start|>assistant\n");
+ }
+ std::string assistant_suffix() const {
+ return llm_config_.value("assistant_suffix", "<|im_end|>\n");
}
std::vector tie_embeddings() const {
return llm_config_.value("tie_embeddings", std::vector{});
}
// llm model config end >
+
+ // < sampler config start
+ std::string sampler_type() const {
+ return config_.value("sampler_type", "mixed");
+ }
+
+ std::vector<std::string> mixed_samplers() const {
+ return config_.value("mixed_samplers", std::vector<std::string>({"topK", "tfs", "typical", "topP", "min_p", "temperature"}));
+ }
+
+ float temperature() const {
+ return config_.value("temperature", 1.0f);
+ }
+
+ int topK() const {
+ return config_.value("topK", 40);
+ }
+
+ float topP() const {
+ return config_.value("topP", 0.9f);
+ }
+
+ float minP() const {
+ return config_.value("minP", 0.1f);
+ }
+
+ float tfsZ() const {
+ return config_.value("tfsZ", 1.0f);
+ }
+
+ float typical() const {
+ return config_.value("typical", 1.0f);
+ }
+
+ float penalty() const {
+ return config_.value("penalty", 0.0f);
+ }
+
+ int ngram() const {
+ return config_.value("n_gram", 8);
+ }
+
+ float ngram_factor() const {
+ return config_.value("ngram_factor", 1.0f);
+ }
+
+ std::string penalty_sampler() const {
+ return config_.value("penalty_sampler", "greedy");
+ }
+ // sampler config end >
+
+ // < app config start
+ std::string app_type() const {
+ return config_.value("app_type", "chat");
+ }
+ std::string system_prompt() const {
+ return config_.value("system_prompt", "You are a helpful assistant!\n");
+ }
+ // app config end >
+
+ // < evaluation config start
+ int ppl_stride() const {
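+ // 0 means "no explicit stride"; TextPPLMeasurer then falls back to max_all_tokens() / 2 for the sliding window.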
+ return config_.value("ppl_stride", 0);
+ }
+ std::string dataset() const {
+ return config_.value("dataset", "wikitext");
+ }
+ int dataset_sample_size() const {
+ return config_.value("dataset_sample_size", -1); // -1 stands for no sampling, use all.
+ }
+ // evaluation config end >
};
} // Transformer
} // MNN
+
+#endif
diff --git a/transformers/llm/engine/src/perplexity.cpp b/transformers/llm/engine/src/perplexity.cpp
new file mode 100644
index 000000000..4d04cdf31
--- /dev/null
+++ b/transformers/llm/engine/src/perplexity.cpp
@@ -0,0 +1,318 @@
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "sampler.hpp"
+#include "perplexity.hpp"
+#include "llmconfig.hpp"
+#include "prompt.hpp"
+
+namespace MNN{
+namespace Transformer{
+
+
+/* -----------TextPPLMeasurer---------- */
+TextPPLMeasurer::TextPPLMeasurer(Llm* llm, std::shared_ptr<LlmConfig> llmConfig) {
+ mLlm = llm;
+ mConfig.max_all_tokens = llmConfig->max_all_tokens();
+ mConfig.max_new_tokens = llmConfig->max_new_tokens();
+ mDatasetType = llmConfig->dataset();
+ mStride = llmConfig->ppl_stride();
+ if (mStride == 0) {
+ // default stride for sliding window.
+ mStride = mConfig.max_all_tokens / 2;
+ }
+}
+
+/* Implemented based on https://huggingface.co/docs/transformers/perplexity
+
+ ******************** HuggingFace Python Version ************************
+
+import torch
+from tqdm import tqdm
+
+max_length = model.config.n_positions
+stride = 512
+seq_len = encodings.input_ids.size(1)
+
+nlls = []
+prev_end_loc = 0
+for begin_loc in tqdm(range(0, seq_len, stride)):
+ end_loc = min(begin_loc + max_length, seq_len)
+ trg_len = end_loc - prev_end_loc # may be different from stride on last loop
+ input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
+ target_ids = input_ids.clone()
+ target_ids[:, :-trg_len] = -100
+
+ with torch.no_grad():
+ outputs = model(input_ids, labels=target_ids)
+
+ # loss is calculated using CrossEntropyLoss which averages over valid labels
+ # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
+ # to the left by 1.
+ neg_log_likelihood = outputs.loss
+
+ nlls.append(neg_log_likelihood)
+
+ prev_end_loc = end_loc
+ if end_loc == seq_len:
+ break
+
+ppl = torch.exp(torch.stack(nlls).mean())
+
+ ******************** HuggingFace Python Version ************************
+*/
+
+float TextPPLMeasurer::perplexity_one(const std::vector<int>& prompt) {
+ int seq_len = prompt.size();
+ std::vector<float> nlls;
+ float ppl = 0.f;
+
+ // start calculation
+ int prev_end_loc = 1; // scoring starts from index 1: the first token has nothing to condition on, so it is not scored.
+ for (int begin_loc = 0; begin_loc < seq_len; begin_loc += mStride) {
+ int end_loc = std::min(begin_loc + mConfig.max_all_tokens, seq_len);
+ // first token
+ std::vector<int> tokens(prev_end_loc - begin_loc);
+ for (int it = begin_loc; it < prev_end_loc; ++it) tokens[it - begin_loc] = prompt[it];
+ mLlm->mLlmSessionInfos[0].all_seq_len_ = tokens.size();
+ mLlm->mLlmSessionInfos[0].gen_seq_len_ = mLlm->mLlmSessionInfos[0].all_seq_len_;
+ auto logits = mLlm->forward(tokens, true);
+ logits = MNN::Express::_Softmax(logits);
+ nlls.push_back(-std::log(((float*)(logits->readMap()))[prompt[prev_end_loc]]));
+ // std::cout << mLlm->tokenizer_decode(argmax(logits)) << " " << mLlm->tokenizer_decode(prompt[prev_end_loc]) << " " << -std::log(((float*)(logits->readMap()))[prompt[prev_end_loc]]) << std::endl;
+ std::cout << -std::log(((float*)(logits->readMap()))[prompt[prev_end_loc]]) << std::endl;
+ // decode following tokens
+ for (int it = prev_end_loc+1; it < end_loc; ++it) {
+ mLlm->mLlmSessionInfos[0].all_seq_len_ += 1;
+ mLlm->mLlmSessionInfos[0].gen_seq_len_ = mLlm->mLlmSessionInfos[0].all_seq_len_;
+ auto logits = mLlm->forward({prompt[it-1]}, false);
+ logits = MNN::Express::_Softmax(logits);
+ nlls.push_back(-std::log(((float*)(logits->readMap()))[prompt[it]]));
+ // std::cout << mLlm->tokenizer_decode(argmax(logits)) << " " << mLlm->tokenizer_decode(prompt[it]) << " " << -std::log(((float*)(logits->readMap()))[prompt[it]]) << std::endl;
+ std::cout << -std::log(((float*)(logits->readMap()))[prompt[it]]) << std::endl;
+ }
+ // clean up once
+ mLlm->reset();
+ prev_end_loc = end_loc;
+ if (end_loc == seq_len) break;
+ }
+
+ // calculate ppl
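+ // PPL = exp( (1/N) * sum_i NLL_i ), the same quantity the HuggingFace reference above computes.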
+ for (int j = 0; j < nlls.size(); ++j) ppl += nlls[j];
+ ppl /= nlls.size();
+ ppl = std::exp(ppl);
+
+ // print
+ std::cout << "PPL: " << std::setprecision(8) << ppl << std::endl;
+ return ppl;
+}
+
+std::vector<float> TextPPLMeasurer::perplexity(std::vector<std::vector<int>> prompts) {
+ std::vector<float> ppls;
+ for (auto prompt : prompts) {
+ ppls.push_back(perplexity_one(prompt));
+ mLlm->reset();
+ }
+ return ppls;
+}
+
+std::vector<float> TextPPLMeasurer::perplexity(std::vector<std::string> prompts) {
+ std::vector<std::vector<int>> tokens(prompts.size());
+ for (int p = 0; p < prompts.size(); ++p) tokens[p] = mLlm->tokenizer_encode(prompts[p], false);
+ return perplexity(tokens);
+}
+
+std::vector<float> TextPPLMeasurer::perplexity(std::string prompt_file, std::ostream* perfOS) {
+ // No performance will be printed!
+ std::vector<std::string> prompts;
+ if (mDatasetType == "wikitext") {
+ prompts = wikitext(prompt_file);
+ }
+ else if (mDatasetType == "plaintext") {
+ prompts = plaintext(prompt_file);
+ }
+ else if (mDatasetType == "rowsplit") {
+ prompts = rowsplit(prompt_file);
+ }
+ else {
+ MNN_ERROR("Dataset not suppoted");
+ exit(1);
+ }
+ std::cout << "prompt file loaded!" << std::endl;
+ return perplexity(prompts);
+}
+
+/* -----------ChatPPLMeasurer---------- */
+ChatPPLMeasurer::ChatPPLMeasurer(Llm* llm, std::shared_ptr<LlmConfig> llmConfig) {
+ mLlm = llm;
+ mConfig.max_all_tokens = llmConfig->max_all_tokens();
+ mConfig.max_new_tokens = llmConfig->max_new_tokens();
+ mDatasetType = llmConfig->dataset();
+ mDatasetSampleSize = llmConfig->dataset_sample_size();
+}
+
+void ChatPPLMeasurer::handleToken(int token) {
+ // CommonPrefix and Candidates managements
+ mLlm->mLlmSessionInfos[0].tokens.push_back(token);
+}
+
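+// Teacher forcing: the reference answer tokens in prompt are fed back one at a time and only their negative log-likelihoods are recorded; nothing is sampled from the model here.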
+std::vector<float> ChatPPLMeasurer::sample(const std::vector<int>& input_ids, const std::vector<int>& prompt, struct TimePerformance* time_perf) {
+ std::vector<float> nlls;
+ // initialization for time performance
+ PrefillTimePerformance prefill_time;
+ prefill_time.prefill_prev_token_ = mLlm->mLlmSessionInfos[0].tokens.size();
+ prefill_time.prefill_token_ = input_ids.size();
+ appendNewPromptRecord(time_perf, input_ids.size(), mLlm->reuse_kv());
+ // initialization
+ mLlm->mLlmSessionInfos[0].tokens.insert(mLlm->mLlmSessionInfos[0].tokens.end(), input_ids.begin(), input_ids.end());
+ // all_seq_len_ in sampler functions as kv_seq_len_, prev_seq_len_ = all_seq_len_ - seq_len
+ mLlm->mLlmSessionInfos[0].all_seq_len_ = mLlm->mLlmSessionInfos[0].tokens.size() - input_ids.size();
+ mLlm->mLlmSessionInfos[0].gen_seq_len_ = 0;
+ // prefill
+ auto st = std::chrono::system_clock::now();
+ auto logits = mLlm->forward(input_ids, true);
+ logits = MNN::Express::_Softmax(logits);
+ nlls.push_back(-std::log(((float*)(logits->readMap()))[prompt[mLlm->mLlmSessionInfos[0].gen_seq_len_]]));
+ // record time
+ auto et = std::chrono::system_clock::now();
+ prefill_time.prefill_us_ = std::chrono::duration_cast<std::chrono::microseconds>(et - st).count();
+ time_perf->prefill_record_.push_back(prefill_time);
+ // handle the new token
+ handleToken(prompt[mLlm->mLlmSessionInfos[0].gen_seq_len_]);
+ // decode
+ while (mLlm->mLlmSessionInfos[0].gen_seq_len_ < prompt.size()) {
+ DecodeTimePerformance decode_time;
+ decode_time.decode_prev_token_ = mLlm->mLlmSessionInfos[0].tokens.size();
+ st = std::chrono::system_clock::now();
+ // next token
+ logits = mLlm->forward({mLlm->mLlmSessionInfos[0].tokens.back()}, false);
+ logits = MNN::Express::_Softmax(logits);
+ nlls.push_back(-std::log(((float*)(logits->readMap()))[prompt[mLlm->mLlmSessionInfos[0].gen_seq_len_]]));
+ et = std::chrono::system_clock::now();
+ decode_time.decode_us_ = std::chrono::duration_cast<std::chrono::microseconds>(et - st).count();
+ time_perf->decode_record_.push_back(decode_time);
+ handleToken(prompt[mLlm->mLlmSessionInfos[0].gen_seq_len_]);
+ }
+ // return nlls
+ return nlls;
+}
+
+float ChatPPLMeasurer::perplexity_one(const std::vector<std::vector<PromptItem>>& prompt, std::ostream* perfOS) {
+ // (turns, roles)
+ std::vector nlls;
+ float ppl = 0.f;
+
+ // < simulated chat
+ mLlm->reset();
+ for (auto& turn : prompt) {
+ mLlm->mPromptLib->appendUserPrompt(turn[0].second);
+ std::vector input_ids = mLlm->tokenizer_encode(mLlm->mPromptLib->getLLMInput(), false);
+ mLlm->generate_init();
+ auto turn_nlls = sample(input_ids, mLlm->tokenizer_encode(turn[1].second, false), &(mLlm->mLlmSessionInfos[0].mTimePerformance));
+ nlls.insert(nlls.end(), turn_nlls.begin(), turn_nlls.end());
+ mLlm->mPromptLib->appendLLMOutput(turn[1].second);
+ }
+
+ // record time performance to file
+ if (perfOS != nullptr) {
+ mLlm->mLlmSessionInfos[0].print_speed(perfOS);
+ }
+
+ mLlm->reset();
+ // simulated chat >
+
+ // calculate ppl
+ for (int j = 0; j < nlls.size(); ++j) ppl += nlls[j];
+ ppl /= nlls.size();
+ ppl = std::exp(ppl);
+
+ // print
+ std::cout << "PPL: " << std::setprecision(8) << ppl << std::endl;
+ return ppl;
+}
+
+
+std::vector<float> ChatPPLMeasurer::perplexity(const std::vector<std::vector<std::vector<PromptItem>>>& prompts, std::ostream* perfOS) {
+ std::vector<float> ppls;
+ for (auto& prompt : prompts) {
+ ppls.push_back(perplexity_one(prompt, perfOS));
+ mLlm->reset();
+ }
+ return ppls;
+}
+
+void ChatPPLMeasurer::getStats(const std::vector<std::vector<std::vector<PromptItem>>>& prompts) {
+ std::ofstream total_stats("total_stats.csv");
+ std::ofstream dialog_stats("dialog_stats.csv");
+ float average_turns=0, average_prefill=0, average_decode=0, average_total_tokens=0;
+ int max_turns=0;
+ std::vector<std::vector<std::pair<int, int>>> stats; // (dialog, turn, (prefill, decode))
+ std::cout << prompts.size() << std::endl;
+ int counter = 0;
+ for (auto& dialog : prompts) {
+ std::vector<std::pair<int, int>> dialog_stats;
+ if ((counter++) % std::max((int)prompts.size()/200, 1) == 0) std::cout << "*" << std::flush;
+ float prefill_len_turn = 0;
+ float decode_len_turn = 0;
+ for (auto& turn : dialog) {
+ // turn: prefill, decode
+ int prefill_len = mLlm->tokenizer_encode(turn[0].second, false).size();
+ int decode_len = mLlm->tokenizer_encode(turn[1].second, false).size();
+ prefill_len_turn += prefill_len;
+ decode_len_turn += decode_len;
+ average_total_tokens += prefill_len + decode_len;
+ dialog_stats.push_back({prefill_len, decode_len});
+ }
+ stats.push_back(dialog_stats);
+ average_prefill += prefill_len_turn / dialog.size(); // average over turns
+ average_decode += decode_len_turn / dialog.size(); // average over turns
+ average_turns += dialog.size();
+ max_turns = std::max(max_turns, (int)dialog.size());
+ }
+ average_turns /= prompts.size();
+ average_prefill /= prompts.size();
+ average_decode /= prompts.size();
+ average_total_tokens /= prompts.size();
+ total_stats << "total_dialogs," << "max_turns," << "avg_turns," \
+ << "avg_prefill_tokens/turn," << "avg_decode_tokens/turn," \
+ << "avg_total_tokens/dialog" << std::endl;
+ total_stats << prompts.size() << "," << max_turns << "," << average_turns << "," \
+ << average_prefill << "," << average_decode << "," \
+ << average_total_tokens << std::endl;
+ for (int i = 0; i < stats.size(); ++i) {
+ for (auto& turn : stats[i]) dialog_stats << i << "," << turn.first << "," << turn.second << std::endl;
+ }
+}
+
+std::vector<float> ChatPPLMeasurer::perplexity(std::string prompt_file, std::ostream* perfOS) {
+ // No performance will be printed!
+ std::vector<std::vector<std::vector<PromptItem>>> prompts;
+ if (mDatasetType == "shareGPT") {
+ prompts = shareGPT(prompt_file, mDatasetSampleSize);
+ }
+ else {
+ MNN_ERROR("Dataset not suppoted");
+ exit(1);
+ }
+ std::cout << "prompt file loaded!" << std::endl;
+ getStats(prompts);
+ std::cout << "\nshareGPT statistics counted!" << std::endl;
+ return perplexity(prompts, perfOS);
+}
+
+
+} // Transformer
+} // MNN
\ No newline at end of file
diff --git a/transformers/llm/engine/src/perplexity.hpp b/transformers/llm/engine/src/perplexity.hpp
new file mode 100644
index 000000000..ece414dd2
--- /dev/null
+++ b/transformers/llm/engine/src/perplexity.hpp
@@ -0,0 +1,65 @@
+#ifndef PERPLEXITY_hpp
+#define PERPLEXITY_hpp
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+#include "sampler.hpp"
+#include "evaluation/dataset.hpp"
+
+namespace MNN {
+namespace Transformer {
+class Llm;
+
+class MNN_PUBLIC TextPPLMeasurer : public Sampler {
+protected:
+ Llm* mLlm;
+ int mStride;
+ std::string mDatasetType;
+ LlmSamplerConfig mConfig;
+public:
+ TextPPLMeasurer(Llm* llm, std::shared_ptr<LlmConfig> config);
+ float perplexity_one(const std::vector<int>& prompt);
+ std::vector<float> perplexity(std::vector<std::vector<int>> prompts);
+ std::vector<float> perplexity(std::vector<std::string> prompts);
+ virtual std::string sample(const std::vector<int>& input_ids, std::ostream* os = &std::cout, const char* end_with = nullptr, struct TimePerformance* time_perf = nullptr) override { return "perplexity evaluation!\n"; }
+ virtual std::vector<float> perplexity(std::string prompt_file, std::ostream* perfOS = nullptr) override;
+};
+
+class MNN_PUBLIC ChatPPLMeasurer : public Sampler {
+protected:
+ Llm* mLlm;
+ std::string mDatasetType;
+ int mDatasetSampleSize;
+ LlmSamplerConfig mConfig;
+ void handleToken(int token);
+ std::vector<float> sample(const std::vector<int>& input_ids, const std::vector<int>& prompt, struct TimePerformance* time_perf);
+public:
+ ChatPPLMeasurer(Llm* llm, std::shared_ptr<LlmConfig> config);
+ void getStats(const std::vector<std::vector<std::vector<PromptItem>>>& prompts);
+ float perplexity_one(const std::vector<std::vector<PromptItem>>& prompt, std::ostream* perfOS);
+ std::vector<float> perplexity(const std::vector<std::vector<std::vector<PromptItem>>>& prompts, std::ostream* perfOS);
+ virtual std::string sample(const std::vector<int>& input_ids, std::ostream* os = &std::cout, const char* end_with = nullptr, struct TimePerformance* time_perf = nullptr) override { return "perplexity evaluation!\n"; }
+ virtual std::vector<float> perplexity(std::string prompt_file, std::ostream* perfOS = nullptr) override;
+};
+
+
+
+} // Transformer
+} // MNN
+
+
+#endif // PERPLEXITY_hpp
\ No newline at end of file
diff --git a/transformers/llm/engine/src/prompt.cpp b/transformers/llm/engine/src/prompt.cpp
new file mode 100644
index 000000000..445c630ce
--- /dev/null
+++ b/transformers/llm/engine/src/prompt.cpp
@@ -0,0 +1,110 @@
+#include "prompt.hpp"
+
+namespace MNN {
+namespace Transformer {
+
+/* ----------PromptLib---------- */
+PromptLib* PromptLib::createPromptLib(Llm* llm, const std::string& config_path) {
+ return createPromptLib(llm, std::shared_ptr<LlmConfig>(new LlmConfig(config_path)));
+}
+PromptLib* PromptLib::createPromptLib(Llm* llm, std::shared_ptr<LlmConfig> config) {
+ if (config->app_type() == "chat" || config->app_type() == "perplexity") {
+ return new BaseChatPromptLib(llm, config);
+ } else {
+ std::cout << "PromptLib not Implemented!\n" << std::endl;
+ return nullptr;
+ }
+}
+
+/* ----------BaseChatPromptLib---------- */
+BaseChatPromptLib::BaseChatPromptLib(Llm* llm, std::shared_ptr<LlmConfig> config) {
+ mLlm = llm;
+ mReuseKV = config->reuse_kv();
+ mDefaultSystemPrompt = config->system_prompt();
+ mSystemTemplate = config->system_prompt_template();
+ mUserTemplate = config->user_prompt_template();
+ mAssistantPrefix = config->assistant_prefix();
+ mAssistantSuffix = config->assistant_suffix();
+}
+
+void BaseChatPromptLib::appendSystemPrompt() {
+ appendSystemPrompt(mDefaultSystemPrompt);
+}
+void BaseChatPromptLib::appendSystemPrompt(const std::string sys_prompt) {
+ mLlm->mLlmSessionInfos[0].mHistory.emplace_back(std::make_pair("system", sys_prompt));
+ mLlm->mLlmSessionInfos[0].mInputs.emplace_back(std::make_pair("system", sys_prompt));
+}
+void BaseChatPromptLib::appendUserPrompt(const std::string user_prompt) {
+ if (mLlm->mLlmSessionInfos[0].mHistory.empty()) { appendSystemPrompt(); } // make sure a system prompt is always prepended.
+ mLlm->mLlmSessionInfos[0].mHistory.emplace_back(std::make_pair("user", user_prompt));
+ mLlm->mLlmSessionInfos[0].mInputs.emplace_back(std::make_pair("user", user_prompt));
+}
+void BaseChatPromptLib::appendLLMOutput(std::string out_str) {
+ mLlm->mLlmSessionInfos[0].mHistory.emplace_back(std::make_pair("assistant", out_str));
+ if (mReuseKV) {
+ // clear input
+ mLlm->mLlmSessionInfos[0].mInputs.clear();
+ } else {
+ // keep input, append output
+ mLlm->mLlmSessionInfos[0].mInputs.emplace_back(std::make_pair("assistant", out_str));
+ }
+}
+
+std::string BaseChatPromptLib::getLLMInput() {
+ std::string input_str;
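+ // With reuse_kv, mInputs only holds the turns not yet prefilled; a size mismatch with mHistory means a previous assistant reply is still open and must be closed with the suffix.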
+ if (mReuseKV) {
+ if (mLlm->mLlmSessionInfos[0].mHistory.size() != mLlm->mLlmSessionInfos[0].mInputs.size()) {
+ // 1.1 not first prefill, add end of speech.
+ input_str += mAssistantSuffix;
+ }
+ }
+ // 1.2 generate from template
+ input_str += applyTemplates(mLlm->mLlmSessionInfos[0].mInputs);
+ input_str += mAssistantPrefix;
+ return input_str;
+}
+
+std::string BaseChatPromptLib::applyTemplate(PromptItem item, std::string prompt_template, std::string placeholder) {
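+ // e.g. template "<|im_start|>user\n%s<|im_end|>\n" with content "hi" yields "<|im_start|>user\nhi<|im_end|>\n"; without a "%s" placeholder, fall back to "role\ncontent\n".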
+ size_t start_pos = prompt_template.find(placeholder);
+ if (start_pos == std::string::npos) return item.first + "\n" + item.second + "\n";
+ else {
+ prompt_template.replace(start_pos, placeholder.length(), item.second);
+ return prompt_template;
+ }
+}
+
+std::string BaseChatPromptLib::applyTemplates(std::vector<PromptItem> inputs) {
+ std::string input_str;
+ for (auto input : inputs) {
+ if (input.first == "") continue;
+ if (input.first == "system") {
+ if (input.second == "") continue;
+ input_str += applyTemplate(input, mSystemTemplate, "%s");
+ continue;
+ }
+ if (input.first == "user") {
+ input_str += applyTemplate(input, mUserTemplate, "%s");
+ continue;
+ }
+ if (input.first == "assistant") {
+ input_str += mAssistantPrefix + input.second + mAssistantSuffix;
+ continue;
+ }
+ // Invalid role!!!
+ }
+ return input_str;
+}
+
+std::string BaseChatPromptLib::applyTemplate(std::string user_content) {
+ std::vector<PromptItem> prompts;
+ prompts.push_back(std::make_pair("system", mDefaultSystemPrompt));
+ prompts.push_back(std::make_pair("user", user_content));
+ return applyTemplates(prompts) + mAssistantPrefix;
+}
+
+std::string BaseChatPromptLib::getAssistantSuffix() const {
+ return mAssistantSuffix;
+}
+
+}
+}
\ No newline at end of file
diff --git a/transformers/llm/engine/src/prompt.hpp b/transformers/llm/engine/src/prompt.hpp
new file mode 100644
index 000000000..3fcf42509
--- /dev/null
+++ b/transformers/llm/engine/src/prompt.hpp
@@ -0,0 +1,61 @@
+
+
+#ifndef PROMPT_Hpp
+#define PROMPT_Hpp
+
+#include "llm/llm.hpp"
+#include "llmconfig.hpp"
+#define MNN_OPEN_TIME_TRACE
+#include
+#include