diff --git a/.gitignore b/.gitignore index d11586811..ca1a4320f 100644 --- a/.gitignore +++ b/.gitignore @@ -361,3 +361,10 @@ pymnn_build/ # mnncompress generated MNN_compression_pb2.py + +# model path +model/ + +# datasets +datasets/* +!datasets/*.sh \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index e345987d1..c7768e340 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,7 +20,9 @@ endif() project(MNN VERSION ${MNN_VERSION} LANGUAGES C CXX ASM) # complier options set(CMAKE_C_STANDARD 99) -set(CMAKE_CXX_STANDARD 11) +IF (NOT (CMAKE_CXX_STANDARD EQUAL 17)) + set(CMAKE_CXX_STANDARD 11) +ENDIF() set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_LIST_DIR}/cmake" @@ -285,7 +287,7 @@ if(CMAKE_SYSTEM_NAME MATCHES "^Android") endif() option(MNN_USE_CPP11 "Enable MNN use c++11" ON) if (NOT MSVC) - if(MNN_CUDA AND MNN_SUPPORT_TRANSFORMER_FUSE) + if((MNN_CUDA AND MNN_SUPPORT_TRANSFORMER_FUSE) OR (CMAKE_CXX_STANDARD EQUAL 17)) set(CMAKE_CXX_STANDARD 17) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=gnu99") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17") diff --git a/docs/transformers/llm.md b/docs/transformers/llm.md index 963597517..b0fbd4932 100644 --- a/docs/transformers/llm.md +++ b/docs/transformers/llm.md @@ -189,7 +189,7 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt - visual_model: 当使用VL模型时,visual_model的实际路径为`base_dir + visual_model`,默认为`base_dir + 'visual.mnn'` - 推理配置 - max_new_tokens: 生成时最大token数,默认为`512` - - reuse_kv: 多轮对话时是否复用之前对话的`kv cache`,默认为`false` + - reuse_kv: 多轮对话时是否复用之前对话的`kv cache`,默认为`false`,目前只有CPU后端支持设置为`true` - quant_qkv: CPU attention 算子中`query, key, value`是否量化,可选为:`0, 1, 2, 3, 4`,默认为`0`,含义如下: - 0: key和value都不量化 - 1: 使用非对称8bit量化存储key @@ -205,6 +205,19 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt - thread_num: CPU推理使用硬件线程数,默认为:`4`; OpenCL推理时使用`68` - precision: 推理使用精度策略,默认为:`"low"`,尽量使用`fp16` - memory: 推理使用内存策略,默认为:`"low"`,开启运行时量化 +- Sampler配置 + - sampler_type: 使用的sampler种类,目前支持`greedy`, `temperature`, `topK`, `topP`, `minP`, `tfs`, `typical`, `penalty`这8种基本sampler,外加`mixed`(混合sampler)。当选择`mixed`时,依次执行mixed_samplers中的sampler。默认为`mixed`。 + - mixed_samplers: 当`sampler_type`为`mixed`时有效,默认为`["topK", "tfs", "typical", "topP", "min_p", "temperature"]` + - temperature: `temperature`, `topP`, `minP`, `tfsZ`, `typical`中temperature值,默认为1.0 + - topK: `topK`中top K 的个数,默认为40 + - topP: `topP`中top P的值,默认为0.9 + - minP: `minP`中min P的值,默认为0.1 + - tfsZ: `tfs`中Z的值,默认为1.0,即不使用tfs算法 + - typical: `typical`中p的值,默认为1.0,即不使用typical算法 + - penalty: `penalty`中对于logits的惩罚项,默认为0.0,即不惩罚 + - n_gram: `penalty`中最大存储的ngram大小,默认为8 + - ngram_factor: `penalty`中对于重复ngram的额外惩罚,默认为1.0,即没有额外惩罚 + - penalty_sampler: `penalty`中最后一步采用的sampling策略,可选`greedy`或`temperature`,默认为`greedy`
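To make the `mixed` chain above easier to picture, here is a minimal Python sketch (for illustration only, not the MNN implementation): it applies the `topK`, `topP`, `minP` and `temperature` stages to a logits vector in that order and then draws one token. The `tfs`, `typical` and `penalty` stages are omitted, and the function name `mixed_sample` is invented for this sketch; only the parameter names mirror the config keys listed above.

```python
# Illustrative sketch of a "mixed" sampling chain (not the MNN implementation).
import numpy as np

def mixed_sample(logits, topK=40, topP=0.9, minP=0.1, temperature=1.0, rng=None):
    if rng is None:
        rng = np.random.default_rng()
    logits = np.asarray(logits, dtype=np.float64) / max(temperature, 1e-6)  # temperature scaling
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()                                        # softmax
    keep = np.argsort(-probs)[:topK]                            # topK: keep the K most likely tokens
    cum = np.cumsum(probs[keep])
    keep = keep[:int(np.searchsorted(cum, topP)) + 1]           # topP: smallest set covering probability P
    keep = keep[probs[keep] >= minP * probs[keep[0]]]           # minP: drop tokens far below the best one
    p = probs[keep] / probs[keep].sum()
    return int(rng.choice(keep, p=p))                           # sample from the remaining candidates

print(mixed_sample([2.0, 1.5, 0.3, -1.0, 0.0]))
```

With the defaults listed above, the omitted stages are effectively disabled anyway: `tfsZ = 1.0` and `typical = 1.0` turn those filters off, and `penalty = 0.0` applies no repetition penalty.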
##### 配置文件示例 - `config.json` @@ -216,7 +229,15 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt "backend_type": "cpu", "thread_num": 4, "precision": "low", - "memory": "low" + "memory": "low", + "sampler_type": "mixed", + "mixed_samplers": ["topK", "tfs", "typical", "topP", "min_p", "temperature"], + "temperature": 1.0, + "topK": 40, + "topP": 0.9, + "tfsZ": 1.0, + "minP": 0.1, + "reuse_kv": true } ``` - `llm_config.json` @@ -240,7 +261,8 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt #### 推理用法 `llm_demo`的用法如下: -``` +pc端直接推理 +```bash # 使用config.json ## 交互式聊天 ./llm_demo model_dir/config.json @@ -254,6 +276,13 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt ./llm_demo model_dir/llm.mnn prompt.txt ``` +android手机端adb推理用法: +```bash +# 利用adb push将链接库push到手机上 +adb shell mkdir /data/local/tmp/llm +adb push llm_demo ppl_demo libllm.so libMNN_CL.so libMNN_Express.so libMNN.so tools/cv/libMNNOpenCV.so /data/local/tmp/llm +``` + - 对于视觉大模型,在prompt中嵌入图片输入 ``` https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg介绍一下图片里的内容 diff --git a/transformers/llm/.gitignore b/transformers/llm/.gitignore new file mode 100644 index 000000000..000bc68dc --- /dev/null +++ b/transformers/llm/.gitignore @@ -0,0 +1,7 @@ +datasets/* +!datasets/*.sh + + +!datasets/visualization/ +datasets/visualization/data +datasets/visualization/pic \ No newline at end of file diff --git a/transformers/llm/datasets/get-sharegpt.sh b/transformers/llm/datasets/get-sharegpt.sh new file mode 100644 index 000000000..fe4be4b5d --- /dev/null +++ b/transformers/llm/datasets/get-sharegpt.sh @@ -0,0 +1,2 @@ +git lfs install +git clone https://huggingface.co/datasets/shareAI/ShareGPT-Chinese-English-90k \ No newline at end of file diff --git a/transformers/llm/datasets/get-wikitext-2-raw.sh b/transformers/llm/datasets/get-wikitext-2-raw.sh new file mode 100644 index 000000000..33096304c --- /dev/null +++ b/transformers/llm/datasets/get-wikitext-2-raw.sh @@ -0,0 +1,2 @@ +wget https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip +unzip wikitext-2-raw-v1.zip \ No newline at end of file diff --git a/transformers/llm/datasets/visualization/stats.py b/transformers/llm/datasets/visualization/stats.py new file mode 100644 index 000000000..761ef256d --- /dev/null +++ b/transformers/llm/datasets/visualization/stats.py @@ -0,0 +1,116 @@ +import matplotlib.pyplot as plt +from matplotlib import colors +from matplotlib.ticker import PercentFormatter +from matplotlib import cbook +from matplotlib.axes import Axes +import pandas as pd +import numpy as np +import argparse +import os + +vis_root = "pic" + +def remove_blanks(df: pd.DataFrame) -> pd.DataFrame: + # Removing unnamed columns using drop function + df.drop(df.columns[df.columns.str.contains( + 'unnamed', case=False)], axis=1, inplace=True) + return df +def add_turns(df: pd.DataFrame) -> pd.DataFrame: + df["turns"] = (1-df.isnull()).sum(axis=1) // 2 + return df +def get_max_turn(df: pd.DataFrame) -> int: + keys = list(df.keys()) + return max([int(key.replace("decode", "")) for key in keys if "decode" in key]) + 1 +def add_pd_ratio(df: pd.DataFrame) -> pd.DataFrame: + max_turns = get_max_turn(df) + for i in range(max_turns): + df["pd_ratio{}".format(i)] = df["prefill{}".format(i)] / df["decode{}".format(i)] + return df +def preprocess(file_path: str) -> pd.DataFrame: + table = pd.read_csv(file_path) + table = remove_blanks(table) + table = add_turns(table) + table = add_pd_ratio(table) + 
print(table) + return table + +def draw_distribution(df: pd.DataFrame, file_path: str): + turns_bin = df.value_counts(subset=["turns"], sort=False) + print(turns_bin) + plt.close() + plt.rcParams['font.size'] = 10 + _, ax = plt.subplots() + # N is the count in each bin, bins is the lower-limit of the bin + N, bins, patches = ax.hist(df["turns"], bins=get_max_turn(df), density=True, align="left", label=True) + # We'll color code by height, but you could use any scalar + fracs = N / N.max() + # we need to normalize the data to 0..1 for the full range of the colormap + norm = colors.Normalize(fracs.min(), fracs.max()) + # Now, we'll loop through our objects and set the color of each accordingly + for thisfrac, thispatch in zip(fracs, patches): + color = plt.cm.viridis(norm(thisfrac)) + thispatch.set_facecolor(color) + # Now we format the y-axis to display percentage + ax.yaxis.set_major_formatter(PercentFormatter(xmax=1)) + ax.set_xlim((0.5, get_max_turn(df)-0.5)) + ax.set_xticks(np.arange(1,get_max_turn(df)+1),np.arange(1,get_max_turn(df)+1),rotation=60, fontsize=9) + ax.set_ylabel("frequency", fontsize=14) + ax.set_xlabel("num of turns", fontsize=14) + plt.savefig(file_path, dpi=600) + plt.close() + +def draw_prefill(df: pd.DataFrame, ax: Axes): + stats = [cbook.boxplot_stats(df[df["prefill{}".format(i)].notna()]["prefill{}".format(i)], labels=[i+1])[0] + for i in range(get_max_turn(df))] + print(stats) + ax.bxp(stats, patch_artist=True, boxprops={'facecolor': 'bisque'}, flierprops=dict(marker='o', markersize=2)) + ax.set_ylim(0,600) + ax.set_yticks(np.arange(0,700,100), np.arange(0,700,100), fontsize=9) + ax.set_ylabel("prefill", fontsize=12, rotation=90) + return +def draw_decode(df: pd.DataFrame, ax: Axes): + stats = [cbook.boxplot_stats(df[df["decode{}".format(i)].notna()]["decode{}".format(i)], labels=[i+1])[0] + for i in range(get_max_turn(df))] + print(stats) + ax.bxp(stats, patch_artist=True, boxprops={'facecolor': 'bisque'}, flierprops=dict(marker='o', markersize=2)) + ax.set_ylim(0,600) + ax.set_yticks(np.arange(0,700,100), np.arange(0,700,100), fontsize=9) + ax.set_ylabel("decode", fontsize=12, rotation=90) + return +def draw_pd_ratio(df: pd.DataFrame, ax: Axes): + stats = [cbook.boxplot_stats(df[df["pd_ratio{}".format(i)].notna()]["pd_ratio{}".format(i)], labels=[i+1])[0] + for i in range(get_max_turn(df))] + print(stats) + ax.bxp(stats, patch_artist=True, boxprops={'facecolor': 'bisque'}, flierprops=dict(marker='o', markersize=2)) + ax.plot(np.arange(0,get_max_turn(df)+2), np.ones_like(np.arange(0,get_max_turn(df)+2),dtype=float)) + ax.set_xlim(0, get_max_turn(df)+1) + ax.set_ylim(0, 2.) 
+ ax.set_xticks(np.arange(1,get_max_turn(df)), np.arange(1,get_max_turn(df)), rotation=60, fontsize=9) + ax.set_yticks([0,0.5,1,2], [0,0.5,1,2], fontsize=9) + ax.set_xlabel("round", fontsize=12) + ax.set_ylabel("prefill/decode", fontsize=12, rotation=90) + return +def draw_reuse_kv(df: pd.DataFrame, file_path: str): + plt.close() + _, axs = plt.subplots(3,1,sharex="col") + draw_prefill(df, axs[0]) + draw_decode(df, axs[1]) + draw_pd_ratio(df, axs[2]) + plt.savefig(file_path, dpi=1200) + plt.close() + return +def draw_no_reuse_kv(): + return + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--root", type=str, default="./data") + parser.add_argument("--name", type=str, default="shareGPT_dialog_stats_common_en.csv") + args = parser.parse_args() + + file_path = os.path.join(args.root, args.name) + dist_path = os.path.join(vis_root, args.name.split('.')[0]+"_dist.png") + pd_dist_path = os.path.join(vis_root, args.name.split('.')[0]+"_pd_dist.png") + table = preprocess(file_path) + draw_distribution(table, dist_path) + draw_reuse_kv(table, pd_dist_path) \ No newline at end of file diff --git a/transformers/llm/datasets/visualization/time.py b/transformers/llm/datasets/visualization/time.py new file mode 100644 index 000000000..27cc0069d --- /dev/null +++ b/transformers/llm/datasets/visualization/time.py @@ -0,0 +1,83 @@ +import matplotlib.pyplot as plt +from matplotlib import colors +from matplotlib.ticker import PercentFormatter +from matplotlib import cbook +from matplotlib.axes import Axes +from typing import List, Dict, Tuple +import pandas as pd +import numpy as np +import argparse +import os +import re +from io import StringIO + +def split_by_turns(id: str, content: str) -> List[pd.DataFrame]: + pattern = "<{id}>\n(.*?)\n".format(id=id) + return [pd.read_csv(StringIO(item)) for item in re.findall(pattern, content, flags=re.DOTALL)] +def preprocess(file_path: str) -> Tuple[List[pd.DataFrame], List[pd.DataFrame]]: + content = open(file_path, "rt").read() + return split_by_turns("prefill", content), split_by_turns("decode", content) +def get_max_turn(no_reuse_prefill_record): + return max(10, max([len(record) for record in no_reuse_prefill_record])) +def draw_history_len(ax: Axes, no_reuse_prefill_record: List[pd.DataFrame]): + max_round = get_max_turn(no_reuse_prefill_record) + history_len = [0 for _ in range(0, max_round)] + for turn in range(0, max_round): + history_len[turn] = np.median([record["input_token"][turn] - record["prompt_token"][turn] + for record in no_reuse_prefill_record if len(record)>=turn+1]).item() + plt.plot(np.arange(1, max_round+1), history_len, label="median history len", marker=".", markersize=8) + return +def draw_prefill_bar_chat(ax: Axes, no_reuse, reuse): + offset = 0.2 + max_round = len(no_reuse) + no_reuse_med = [np.median(turn) for turn in no_reuse] + rects = ax.bar(np.arange(1,max_round+1) + offset, no_reuse_med, offset*2, label="no reuse kv", color="tomato") + ax.bar_label(rects, fmt="{:.2f}", padding=4, fontsize=6) + reuse_med = [np.median(turn) for turn in reuse] + rects = ax.bar(np.arange(1,max_round+1) - offset, reuse_med, offset*2, label="reuse kv", color="springgreen") + ax.bar_label(rects, fmt="{:.2f}", padding=4, fontsize=6) + return +def compare_prefill_reuse_kv(no_reuse_prefill_record: List[pd.DataFrame], + reuse_prefill_record: List[pd.DataFrame]): + plt.close() + _,ax1 = plt.subplots() + ax2 = ax1.twinx() + # plot history_len + draw_history_len(ax2, no_reuse_prefill_record) + # calculate per turn + 
max_round = get_max_turn(no_reuse_prefill_record) + no_reuse = [[] for _ in range(0, max_round)] + for turn in range(0, max_round): + no_reuse[turn] = [record["response_speed"][turn] for record in no_reuse_prefill_record if len(record)>=turn+1] + reuse = [[] for _ in range(0, max_round)] + for turn in range(0, max_round): + reuse[turn] = [record["response_speed"][turn] for record in reuse_prefill_record if len(record)>=turn+1] + # plot the bar chat (with error bar) + draw_prefill_bar_chat(ax1, no_reuse, reuse) + ax1.set_xticks(np.arange(1,max_round+1),np.arange(1,max_round+1),fontsize=9) + ax1.set_ylim(0,100) + ax2.set_ylim(0,1000) + ax1.legend(loc='upper left', title="prefill response speed") + ax2.legend(loc='upper right') + ax1.set_ylabel("prefill\nresponse\nspeed", rotation=0, labelpad=12) + ax2.set_ylabel("history\nlen", rotation=0, labelpad=8) + ax1.set_xlabel("round") + plt.title("KV cache reuse for multi-turn chat\neffects on ShareGPT") + plt.tight_layout() + plt.savefig("./pic/fig.png",dpi=1200) + plt.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--root", type=str, default="./data") + parser.add_argument("--no_reuse", type=str, default="shareGPT_common_en_70k_noreuse.txt") + parser.add_argument("--reuse", type=str, default="shareGPT_common_en_70k_reuse.txt") + args = parser.parse_args() + + no_reuse_file_path = os.path.join(args.root, args.no_reuse) + reuse_file_path = os.path.join(args.root, args.reuse) + no_reuse_prefill_record, no_reuse_decode_record = preprocess(no_reuse_file_path) + reuse_prefill_record, reuse_decode_record = preprocess(reuse_file_path) + # visualize prefill + compare_prefill_reuse_kv(no_reuse_prefill_record, reuse_prefill_record) diff --git a/transformers/llm/engine/CMakeLists.txt b/transformers/llm/engine/CMakeLists.txt index 5e2703acf..b7ccf1139 100644 --- a/transformers/llm/engine/CMakeLists.txt +++ b/transformers/llm/engine/CMakeLists.txt @@ -25,12 +25,15 @@ else() add_library(llm OBJECT ${SRCS}) endif() -add_executable(llm_demo ${CMAKE_CURRENT_LIST_DIR}/llm_demo.cpp) -add_executable(embedding_demo ${CMAKE_CURRENT_LIST_DIR}/embedding_demo.cpp) +add_executable(llm_demo ${CMAKE_CURRENT_LIST_DIR}/app/llm_demo.cpp) +add_executable(ppl_demo ${CMAKE_CURRENT_LIST_DIR}/app/ppl_demo.cpp) +add_executable(embedding_demo ${CMAKE_CURRENT_LIST_DIR}/app/embedding_demo.cpp) IF (NOT MNN_SEP_BUILD) target_link_libraries(llm_demo ${MNN_DEPS}) + target_link_libraries(ppl_demo ${MNN_DEPS}) target_link_libraries(embedding_demo ${MNN_DEPS}) ELSE () target_link_libraries(llm_demo ${MNN_DEPS} llm) + target_link_libraries(ppl_demo ${MNN_DEPS} llm) target_link_libraries(embedding_demo ${MNN_DEPS} llm) ENDIF () diff --git a/transformers/llm/engine/embedding_demo.cpp b/transformers/llm/engine/app/embedding_demo.cpp similarity index 100% rename from transformers/llm/engine/embedding_demo.cpp rename to transformers/llm/engine/app/embedding_demo.cpp diff --git a/transformers/llm/engine/llm_demo.cpp b/transformers/llm/engine/app/llm_demo.cpp similarity index 65% rename from transformers/llm/engine/llm_demo.cpp rename to transformers/llm/engine/app/llm_demo.cpp index eef45cb75..55a1b09df 100644 --- a/transformers/llm/engine/llm_demo.cpp +++ b/transformers/llm/engine/app/llm_demo.cpp @@ -6,21 +6,24 @@ // #include "llm/llm.hpp" +#include "evaluation/dataset.hpp" #define MNN_OPEN_TIME_TRACE #include #include #include #include #include +#include using namespace MNN::Transformer; static void trace_prepare(Llm* llm) { MNN_PRINT("Prepare for 
resize opt Begin\n"); llm->trace(true); std::ostringstream cacheOs; - llm->generate({200, 200}, &cacheOs, ""); + llm->generate(std::initializer_list{200, 200}, &cacheOs, ""); MNN_PRINT("Prepare for resize opt End\n"); llm->trace(false); + llm->reset(); } static void tuning_prepare(Llm* llm) { @@ -29,55 +32,7 @@ static void tuning_prepare(Llm* llm) { MNN_PRINT("Prepare for tuning opt End\n"); } -std::vector> parse_csv(const std::vector& lines) { - std::vector> csv_data; - std::string line; - std::vector row; - std::string cell; - bool insideQuotes = false; - bool startCollecting = false; - - // content to stream - std::string content = ""; - for (auto line : lines) { - content = content + line + "\n"; - } - std::istringstream stream(content); - - while (stream.peek() != EOF) { - char c = stream.get(); - if (c == '"') { - if (insideQuotes && stream.peek() == '"') { // quote - cell += '"'; - stream.get(); // skip quote - } else { - insideQuotes = !insideQuotes; // start or end text in quote - } - startCollecting = true; - } else if (c == ',' && !insideQuotes) { // end element, start new element - row.push_back(cell); - cell.clear(); - startCollecting = false; - } else if ((c == '\n' || stream.peek() == EOF) && !insideQuotes) { // end line - row.push_back(cell); - csv_data.push_back(row); - cell.clear(); - row.clear(); - startCollecting = false; - } else { - cell += c; - startCollecting = true; - } - } - return csv_data; -} - static int benchmark(Llm* llm, const std::vector& prompts) { - int prompt_len = 0; - int decode_len = 0; - int64_t prefill_time = 0; - int64_t decode_time = 0; - // llm->warmup(); for (int i = 0; i < prompts.size(); i++) { const auto& prompt = prompts[i]; // prompt start with '#' will be ignored @@ -85,20 +40,14 @@ static int benchmark(Llm* llm, const std::vector& prompts) { continue; } llm->response(prompt); - prompt_len += llm->prompt_len_; - decode_len += llm->gen_seq_len_; - prefill_time += llm->prefill_us_; - decode_time += llm->decode_us_; } - float prefill_s = prefill_time / 1e6; - float decode_s = decode_time / 1e6; printf("\n#################################\n"); - printf("prompt tokens num = %d\n", prompt_len); - printf("decode tokens num = %d\n", decode_len); - printf("prefill time = %.2f s\n", prefill_s); - printf(" decode time = %.2f s\n", decode_s); - printf("prefill speed = %.2f tok/s\n", prompt_len / prefill_s); - printf(" decode speed = %.2f tok/s\n", decode_len / decode_s); + printf("prompt tokens num = %d\n", llm->getTotalPromptLen()); + printf("decode tokens num = %d\n", llm->getTotalDecodeLen()); + printf("prefill time = %.2f s\n", llm->getTotalPrefillTime()); + printf(" decode time = %.2f s\n", llm->getTotalDecodeTime()); + printf("prefill speed = %.2f tok/s\n", llm->average_prefill_speed()); + printf(" decode speed = %.2f tok/s\n", llm->average_decode_speed()); printf("##################################\n"); return 0; } diff --git a/transformers/llm/engine/app/ppl_demo.cpp b/transformers/llm/engine/app/ppl_demo.cpp new file mode 100644 index 000000000..393b5b86d --- /dev/null +++ b/transformers/llm/engine/app/ppl_demo.cpp @@ -0,0 +1,61 @@ +// +// ppl_demo.cpp +// +// Created by MNN on 2023/03/24. 
+// ZhaodeWang +// + +#include "llm/llm.hpp" +#define MNN_OPEN_TIME_TRACE +#include +#include +#include +#include +#include +#include +#include +using namespace MNN::Transformer; +static void trace_prepare(Llm* llm) { + MNN_PRINT("Prepare for resize opt Begin\n"); + llm->trace(true); + std::ostringstream cacheOs; + llm->response("Hello", &cacheOs); + MNN_PRINT("Prepare for resize opt End\n"); + llm->trace(false); + llm->reset(); +} + +// parse json + +static int ppl_eval(Llm* llm, std::string prompt_file, std::ofstream* perfOS) { + std::cout << "prompt file is " << prompt_file << std::endl; + // ppl evaluation + std::vector ppls = llm->perplexity(prompt_file, perfOS); + float mean_ppl = 0.f; + for (int j = 0; j < ppls.size(); ++j) mean_ppl += ppls[j]; + mean_ppl /= ppls.size(); + std::cout << mean_ppl << std::endl; + return 0; +} + +int main(int argc, const char* argv[]) { + if (argc < 3) { + std::cout << "Usage: " << argv[0] << " config.json ppl-prompt.txt [perf.txt]" << std::endl; + return 0; + } + std::string config_path = argv[1]; + std::cout << "config path is " << config_path << std::endl; + std::unique_ptr llm(Llm::createLLM(config_path)); + { + AUTOTIME; + llm->load(); + } + { + AUTOTIME; + trace_prepare(llm.get()); + } + std::string prompt_file = argv[2]; + std::unique_ptr perfOS(nullptr); + if (argc == 4) { perfOS.reset(new std::ofstream(argv[3])); } + return ppl_eval(llm.get(), prompt_file, perfOS.get()); +} diff --git a/transformers/llm/engine/include/evaluation/MemMonitor.hpp b/transformers/llm/engine/include/evaluation/MemMonitor.hpp new file mode 100644 index 000000000..7750eb56c --- /dev/null +++ b/transformers/llm/engine/include/evaluation/MemMonitor.hpp @@ -0,0 +1,42 @@ +#ifndef MEMMONITOR_hpp +#define MEMMONITOR_hpp + +#include +#include +#include +#include +#include +#include +#include + +#define BUFFER_SIZE 256 + +struct MemoryInfo { + // in MB + float total_phys_mem; + float free_phys_mem; + float total_swap; + float free_swap; + float process_resident_set_size; + float process_swap; + float process_virtual_mem_total; + float process_virtual_mem_used; +}; + + +#if defined(__ANDROID__) || defined(linux) || defined(__APPLE__) || defined(__MACOSX) +#define SELF_FILE "/proc/self/status" +#define MEMINFO_FILE "/proc/meminfo" +#endif // linux + +int readMemInfo(MemoryInfo *mem_info); + +int readProcStatus(MemoryInfo *mem_info); + +void printMemoryInfo(const MemoryInfo *mem_info); + +float getSysMemInc(MemoryInfo* prev, MemoryInfo* now); + +float getProcMem(MemoryInfo* info); + +#endif \ No newline at end of file diff --git a/transformers/llm/engine/include/evaluation/dataset.hpp b/transformers/llm/engine/include/evaluation/dataset.hpp new file mode 100644 index 000000000..b9585fe71 --- /dev/null +++ b/transformers/llm/engine/include/evaluation/dataset.hpp @@ -0,0 +1,33 @@ +#ifndef LLM_DATASET_hpp +#define LLM_DATASET_hpp + +#include +#include +#include +#include +#include +#include +#include +#include +#include "llm/llm.hpp" + +#include + +namespace MNN { +namespace Transformer { + + +// parse csv +MNN_PUBLIC std::vector> parse_csv(const std::vector& lines); +void parse_jsonl(std::string prompt_file, std::vector>>& dialogs); + +std::string getPPLType(std::string dataset_name); +std::vector rowsplit(std::string prompt_file); +std::vector plaintext(std::string prompt_file); +std::vector wikitext(std::string prompt_file); +std::vector>> shareGPT(std::string prompt_file, int sample_size=-1); // -1: no sampling + +} // Transformer +} // MNN + +#endif // LLM_DATASET_hpp \ No 
newline at end of file diff --git a/transformers/llm/engine/include/evaluation/evaluation.hpp b/transformers/llm/engine/include/evaluation/evaluation.hpp new file mode 100644 index 000000000..9ecf0c335 --- /dev/null +++ b/transformers/llm/engine/include/evaluation/evaluation.hpp @@ -0,0 +1,68 @@ + + +#ifndef TRANSFORMER_EVALUATION_hpp +#define TRANSFORMER_EVALUATION_hpp + +#include +#include +#include "MemMonitor.hpp" + +namespace MNN { +namespace Transformer { + +#define MICRO_TO_MILLI 1e-3f +#define MILLI_TO_MICRO 1000 +#define MICRO_TO_SEC 1e-6f +#define SEC_TO_MICRO 1000000 + +#define MEGA_TO_GIGA (1/1024.f) +#define GIGA_TO_MEGA 1024.f +#define KILLO_TO_GIGA (1/1024.f/1024.f) +#define GIGA_TO_KILLO (1024.f*1024.f) +#define KILLO_TO_MEGA (1/1024.f) +#define MEGA_TO_KILLO 1024.f +#define BYTE_TO_MEGA (1/1024.f/1024.f) +#define MEGA_TO_BYTE (1024.f*1024.f) + +struct PrefillTimePerformance { + size_t prefill_prev_token_ = 0; + size_t prefill_token_ = 0; + size_t prefill_us_ = 0; +}; + +struct DecodeTimePerformance { + size_t decode_prev_token_ = 0; + size_t decode_us_ = 0; +}; + +struct TimePerformance { + std::vector prefill_record_; + std::vector decode_record_; + std::vector prompt_record_; +}; + +void appendNewPromptRecord(struct TimePerformance* perf, int input_len, bool reuse_kv); + +struct PrefillMemPerformance { + size_t prefill_prev_token_ = 0; + size_t prefill_token_ = 0; + float prefill_MB_ = 0; +}; + +struct DecodeMemPerformance { + size_t decode_prev_token_ = 0; + float decode_MB_ = 0; +}; + +struct MemPerformance { + std::vector prefill_record_; + std::vector decode_record_; +}; + +void mergePerformance(struct TimePerformance* dst, struct TimePerformance* src); +void mergePerformance(struct MemPerformance* dst, struct MemPerformance* src); +void clearPerformance(struct TimePerformance* perf); +void clearPerformance(struct MemPerformance* perf); +} // namespace Transformer +} // namespace MNN +#endif // TRANSFORMER_EVALUATION_hpp \ No newline at end of file diff --git a/transformers/llm/engine/include/llm/llm.hpp b/transformers/llm/engine/include/llm/llm.hpp index 6d0ac2e99..b0d970aa0 100644 --- a/transformers/llm/engine/include/llm/llm.hpp +++ b/transformers/llm/engine/include/llm/llm.hpp @@ -18,6 +18,7 @@ #include #include +#include "evaluation/evaluation.hpp" #include #include #include @@ -28,6 +29,41 @@ namespace Transformer { class Tokenizer; class Pipeline; class LlmConfig; +class Sampler; +class PromptLib; +struct TimePerformance; + + +// +#define PromptItem std::pair + +class MNN_PUBLIC LlmSessionInfo { +public: + // Llm::forward needs, for mask and embedding. + int all_seq_len_=0, gen_seq_len_=0; + // Sampler needs + std::vector tokens; + // PromptLib needs + std::vector mHistory; + std::vector mInputs; + // Performance needs + struct TimePerformance mTimePerformance; +public: + LlmSessionInfo():all_seq_len_(0),gen_seq_len_(0){} + void resetSamplerFields(); + void resetPromptFields(); + void resetPerformanceFields(); + void print_speed(std::ostream* os); + float average_total_speed(); + float average_prefill_speed(); + float average_decode_speed(); + float getTotalPrefillTime(); + float getTotalDecodeTime(); + int getTotalPromptLen(); + int getTotalDecodeLen(); +}; + + class DiskEmbedding; enum TuneType { @@ -36,26 +72,30 @@ enum TuneType { }; class MNN_PUBLIC Llm { - using PromptItem = std::pair; // +public: + std::shared_ptr mSampler; + std::shared_ptr mPromptLib; + std::vector mLlmSessionInfos; // Llm conversation session information. 
Currently, only mLlmSessionInfos[0] is allowed! public: Llm(std::shared_ptr config) : config_(config) {} virtual ~Llm(); static Llm* createLLM(const std::string& config_path); - void chat(); + void chat(bool session_by_line = false, bool from_file = false, + std::istream* is = &std::cin, std::ostream* os = &std::cout, + const char* end_with = "\n", std::string exit_prompt = "/exit", std::string reset_token = "/reset"); void reset(); void trace(bool start); void tuning(TuneType type, std::vector candidates); virtual void load(); - MNN::Express::VARP forward(const std::vector& input_ids); - int sample(MNN::Express::VARP logits, const std::vector& pre_ids); - std::string apply_prompt_template(const std::string& user_content) const; - std::string apply_chat_template(const std::vector& chat_prompts) const; + MNN::Express::VARP forward(const std::vector& input_ids, bool is_prefill=true); std::string response(const std::string& user_content, std::ostream* os = &std::cout, const char* end_with = nullptr); - std::string response(const std::vector& chat_prompts, std::ostream* os = &std::cout, const char* end_with = nullptr); + std::string generate(const std::string& prompt, std::ostream* os = &std::cout, const char* end_with = "\n"); + std::string generate(const std::vector& input_ids, std::ostream* os = &std::cout, const char* end_with = "\n"); void generate_init(); - std::string generate(const std::vector& input_ids, std::ostream* os, const char* end_with); - std::vector generate(const std::vector& input_ids, int max_new_tokens = -1); + std::string generateTrace(const std::vector& input_ids, std::ostream* os, const char* end_with); void print_speed(); + void print_speed(std::ostream* os); + std::vector perplexity(std::string prompt_file, std::ostream* statsOS = nullptr); // config function std::string dump_config(); bool set_config(const std::string& content); @@ -70,16 +110,18 @@ class MNN_PUBLIC Llm { virtual std::vector tokenizer_encode(const std::string& query, bool use_template = true); friend class Pipeline; public: - // forward info - int prompt_len_ = 0; - int gen_seq_len_ = 0; - int all_seq_len_ = 0; - std::vector history_ids_; - // time - int64_t prefill_us_ = 0; - int64_t decode_us_ = 0; bool is_single_ = true; bool attention_fused_ = true; + bool reuse_kv() const; +public: + // time profile + float average_total_speed(); + float average_prefill_speed(); + float average_decode_speed(); + float getTotalPrefillTime(); + float getTotalDecodeTime(); + int getTotalPromptLen(); + int getTotalDecodeLen(); protected: std::shared_ptr config_; std::shared_ptr tokenizer_; @@ -96,6 +138,10 @@ class MNN_PUBLIC Llm { virtual MNN::Express::VARP gen_attention_mask(int seq_len); virtual MNN::Express::VARP gen_position_ids(int seq_len); bool mTracing = false; +protected: + bool getUserPrompt(bool from_file, std::istream* is, std::string& user_str); + void chat_init(); + void chat_reset(); }; // Embedding start diff --git a/transformers/llm/engine/ios/mnn-llm/mnn-llm/LLMInferenceEngineWrapper.mm b/transformers/llm/engine/ios/mnn-llm/mnn-llm/LLMInferenceEngineWrapper.mm index 4561338f8..5db6a9c81 100644 --- a/transformers/llm/engine/ios/mnn-llm/mnn-llm/LLMInferenceEngineWrapper.mm +++ b/transformers/llm/engine/ios/mnn-llm/mnn-llm/LLMInferenceEngineWrapper.mm @@ -89,26 +89,16 @@ - (void)processInput:(NSString *)input withStreamHandler:(StreamOutputHandler)ha } prompts.push_back(prompt); } - int prompt_len = 0; - int decode_len = 0; - int64_t prefill_time = 0; - int64_t decode_time = 0; for (int i = 0; i < 
prompts.size(); i++) { llm->response(prompts[i], &os, "\n"); - prompt_len += llm->prompt_len_; - decode_len += llm->gen_seq_len_; - prefill_time += llm->prefill_us_; - decode_time += llm->decode_us_; } - float prefill_s = prefill_time / 1e6; - float decode_s = decode_time / 1e6; os << "\n#################################\n" - << "prompt tokens num = " << prompt_len << "\n" - << "decode tokens num = " << decode_len << "\n" - << "prefill time = " << std::fixed << std::setprecision(2) << prefill_s << " s\n" - << " decode time = " << std::fixed << std::setprecision(2) << decode_s << " s\n" - << "prefill speed = " << std::fixed << std::setprecision(2) << prompt_len / prefill_s << " tok/s\n" - << " decode speed = " << std::fixed << std::setprecision(2) << decode_len / decode_s << " tok/s\n" + << "prompt tokens num = " << llm->getTotalPromptLen() << "\n" + << "decode tokens num = " << llm->getTotalDecodeLen() << "\n" + << "prefill time = " << std::fixed << std::setprecision(2) << llm->getTotalPrefillTime() << " s\n" + << " decode time = " << std::fixed << std::setprecision(2) << llm->getTotalDecodeTime() << " s\n" + << "prefill speed = " << std::fixed << std::setprecision(2) << llm->average_prefill_speed() << " tok/s\n" + << " decode speed = " << std::fixed << std::setprecision(2) << llm->average_decode_speed() << " tok/s\n" << "##################################\n"; os << ""; } else { diff --git a/transformers/llm/engine/src/LlmSessionInfo.cpp b/transformers/llm/engine/src/LlmSessionInfo.cpp new file mode 100644 index 000000000..a46dc5d8c --- /dev/null +++ b/transformers/llm/engine/src/LlmSessionInfo.cpp @@ -0,0 +1,87 @@ + +#include "llm/llm.hpp" + +namespace MNN { +namespace Transformer { + +// LlmSessionInfo starts +void LlmSessionInfo::resetSamplerFields() { + all_seq_len_ = 0; + gen_seq_len_ = 0; + tokens.clear(); +} +void LlmSessionInfo::resetPromptFields() { + mHistory.clear(); + mInputs.clear(); +} +void LlmSessionInfo::resetPerformanceFields() { + clearPerformance(&mTimePerformance); +} +float LlmSessionInfo::average_total_speed() { + return (getTotalPromptLen()+getTotalDecodeLen())/(getTotalPrefillTime()+getTotalDecodeTime()); +} +float LlmSessionInfo::average_prefill_speed() { + // prefill response rate + return getTotalPromptLen()/getTotalPrefillTime(); +} +float LlmSessionInfo::average_decode_speed() { + return getTotalDecodeLen()/getTotalDecodeTime(); +} +float LlmSessionInfo::getTotalPrefillTime() { + float sum = 0.f; + for (auto record : mTimePerformance.prefill_record_) { + sum += ((float)record.prefill_us_)*MICRO_TO_SEC; + } + return sum; +} +float LlmSessionInfo::getTotalDecodeTime() { + float sum = 0.0f; + for (auto record : mTimePerformance.decode_record_) { + sum += ((float)record.decode_us_)*MICRO_TO_SEC; + } + return sum; +} +int LlmSessionInfo::getTotalPromptLen() { + int prompt_len = 0; + if (mTimePerformance.prefill_record_.size() != mTimePerformance.prompt_record_.size()) { + for (auto record : mTimePerformance.prefill_record_) { + prompt_len += record.prefill_token_; + } + } else { + for (int r=0; r < mTimePerformance.prompt_record_.size(); ++r) { + prompt_len += mTimePerformance.prompt_record_[r]; + } + } + return prompt_len; +} +int LlmSessionInfo::getTotalDecodeLen() { + return mTimePerformance.decode_record_.size(); +} +void LlmSessionInfo::print_speed(std::ostream* os) { + // prefill statistics + (*os) << "" << std::endl; + if (mTimePerformance.prefill_record_.size() != mTimePerformance.prompt_record_.size()) { + (*os) << "prev_token,input_token,response_speed" 
<< std::endl; + for (auto record : mTimePerformance.prefill_record_) { + (*os) << record.prefill_prev_token_ << "," << record.prefill_token_ << "," << record.prefill_token_/(((float)record.prefill_us_)*MICRO_TO_SEC) << std::endl; + } + } else { + (*os) << "prev_token,input_token,prompt_token,response_speed" << std::endl; + for (int r=0; r < mTimePerformance.prompt_record_.size(); ++r) { + auto record = mTimePerformance.prefill_record_[r]; + auto prompt_len = mTimePerformance.prompt_record_[r]; + (*os) << record.prefill_prev_token_ << "," << record.prefill_token_ << "," << prompt_len << "," << prompt_len/(((float)record.prefill_us_)*MICRO_TO_SEC) << std::endl; + } + } + (*os) << "" << std::endl; + // decode statistics + (*os) << "" << std::endl; + (*os) << "prev_token,response_speed" << std::endl; + for (auto record : mTimePerformance.decode_record_) { + (*os) << record.decode_prev_token_ << "," << 1./(((float)record.decode_us_)*MICRO_TO_SEC) << std::endl; + } + (*os) << "" << std::endl; +} + +} // Transformer +} // MNN \ No newline at end of file diff --git a/transformers/llm/engine/src/dataset.cpp b/transformers/llm/engine/src/dataset.cpp new file mode 100644 index 000000000..c4d0db45a --- /dev/null +++ b/transformers/llm/engine/src/dataset.cpp @@ -0,0 +1,223 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "evaluation/dataset.hpp" +#include +#include +#include + +namespace MNN { +namespace Transformer { + + +// parse file +// csv json + +// parse csv +std::vector> parse_csv(const std::vector& lines) { + std::vector> csv_data; + std::string line; + std::vector row; + std::string cell; + bool insideQuotes = false; + bool startCollecting = false; + + // content to stream + std::string content = ""; + for (auto line : lines) { + content = content + line + "\n"; + } + std::istringstream stream(content); + + while (stream.peek() != EOF) { + char c = stream.get(); + if (c == '"') { + if (insideQuotes && stream.peek() == '"') { // quote + cell += '"'; + stream.get(); // skip quote + } else { + insideQuotes = !insideQuotes; // start or end text in quote + } + startCollecting = true; + } else if (c == ',' && !insideQuotes) { // end element, start new element + row.push_back(cell); + cell.clear(); + startCollecting = false; + } else if ((c == '\n' || stream.peek() == EOF) && !insideQuotes) { // end line + row.push_back(cell); + csv_data.push_back(row); + cell.clear(); + row.clear(); + startCollecting = false; + } else { + cell += c; + startCollecting = true; + } + } + return csv_data; +} + +// dialog, turn, +void parse_jsonl(std::string prompt_file, std::vector>>& dialogs) { + std::ifstream prompt_fs(prompt_file); + std::string prompt; + while(std::getline(prompt_fs, prompt)) { + rapidjson::Document document; + document.Parse(prompt.c_str()); + std::vector> cnv; + if(document.HasMember("conversation")) { + auto& value = document["conversation"]; + if (value.IsArray()) { + for (auto& v : value.GetArray()) { + if (v.IsObject()) { + std::vector result; + for (auto itr = v.MemberBegin(); itr != v.MemberEnd(); ++itr) { + // {"human"/"user": , "assistant": } + result.push_back(std::make_pair(itr->name.GetString(), itr->value.GetString())); + } + cnv.push_back(result); + } + } + } + } + dialogs.push_back(cnv); + } +} + +void write_jsonl(std::string prompt_file, const std::vector>>& dialogs) { + std::ofstream prompt_fs(prompt_file); + for(auto& dialog : dialogs) { + rapidjson::Document document; + document.SetObject(); + rapidjson::Value 
conversation(rapidjson::kArrayType); + conversation.SetArray(); + for (auto& turn : dialog) { + rapidjson::Value sentence(rapidjson::kObjectType); + sentence.SetObject(); + for (auto& role : turn) { + sentence.AddMember(rapidjson::Value(role.first.c_str(), document.GetAllocator()), + rapidjson::Value(role.second.c_str(), document.GetAllocator()), document.GetAllocator()); + } + conversation.PushBack(sentence, document.GetAllocator()); + } + document.AddMember("conversation", conversation, document.GetAllocator()); + // write to file + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + document.Accept(writer); + prompt_fs << buffer.GetString() << std::endl; + } +} + + +// dataset +// wikitext, ShareGPT + +std::string getPPLType(std::string dataset_name) { + if (dataset_name == "wikitext" + || dataset_name == "plaintext" + || dataset_name == "rowsplit") { + return "text"; + } else if (dataset_name == "shareGPT") { + return "chat"; + } else { + // default chat + return "chat"; + } +} + +std::vector plaintext(std::string prompt_file) { + // split by line + std::ifstream prompt_fs(prompt_file); + std::vector prompts; + std::string prompt; + prompts.push_back(""); + while (std::getline(prompt_fs, prompt)) { + if (prompt.back() == '\r' || prompt.back() == '\n') { + prompt.pop_back(); + } + // concatenate. + prompts.back() += prompt + "\n"; + } + return prompts; +} + +std::vector rowsplit(std::string prompt_file) { + // split by line + std::ifstream prompt_fs(prompt_file); + std::vector prompts; + std::string prompt; + while (std::getline(prompt_fs, prompt)) { + if (prompt.back() == '\r' || prompt.back() == '\n') { + prompt.pop_back(); + } + prompts.push_back(prompt); + } + return prompts; +} + +// wikitext +void removeSubstrs(std::string& s, std::string p) { + std::string::size_type n = p.length(); + for (std::string::size_type i = s.find(p); i != std::string::npos; i = s.find(p)) + s.erase(i, n); +} +std::vector wikitext(std::string prompt_file) { + // split wiki text into " = " first-level column. + std::ifstream prompt_fs(prompt_file); + std::vector prompts; + std::string prompt; + while (std::getline(prompt_fs, prompt)) { + if (prompt.back() == '\r' || prompt.back() == '\n') { + prompt.pop_back(); + } + if (prompt.size() < 4) continue; + removeSubstrs(prompt, "@-@"); + if ((prompts.size() == 0) \ + || (prompt.size() >= 4 \ + && prompt.at(0) == ' ' \ + && prompt.at(1) == '=' \ + && prompt.at(2) == ' ' \ + && prompt.at(3) != '=')) { + // first-level column. + prompts.push_back(prompt); + } else { + // concatenate. 
+ prompts.back() += "\n" + prompt; + } + } + return prompts; +} + +std::string genSampleName(std::string oriName, int sample_size) { + const size_t last_slash_idx = oriName.rfind('.'); + auto stem = oriName.substr(0, last_slash_idx); + return stem + "_sample" + std::to_string(sample_size) + ".jsonl"; +} + +std::vector>> shareGPT(std::string prompt_file, int sample_size) { + std::vector>> dialogs, dataset; + parse_jsonl(prompt_file, dialogs); + // randomly sample a subset + if (sample_size > 0 && sample_size < dialogs.size()){ + std::random_device rd; + std::mt19937 g(rd()); + std::shuffle(dialogs.begin(), dialogs.end(), g); + dataset.insert(dataset.end(), dialogs.begin(), dialogs.begin() + sample_size); + dialogs = dataset; + // store dialogs to file + write_jsonl(genSampleName(prompt_file, sample_size), dialogs); + } + return dialogs; +} + + +} // Transformer +} // MNN diff --git a/transformers/llm/engine/src/evaluation.cpp b/transformers/llm/engine/src/evaluation.cpp new file mode 100644 index 000000000..cd2524a4e --- /dev/null +++ b/transformers/llm/engine/src/evaluation.cpp @@ -0,0 +1,30 @@ + + +#include +#include +#include "evaluation/evaluation.hpp" + +namespace MNN { +namespace Transformer { + +void clearPerformance(struct TimePerformance* perf) { + perf->prefill_record_.clear(); + perf->decode_record_.clear(); + perf->prompt_record_.clear(); +} +void appendNewPromptRecord(struct TimePerformance* perf, int input_len, bool reuse_kv) { + if (reuse_kv) { + perf->prompt_record_.push_back(input_len); + } else { + // not reuse kv + if (!perf->decode_record_.empty()) { + perf->prompt_record_.push_back(input_len - (perf->decode_record_.back().decode_prev_token_+1)); + } else { + // first prefill + perf->prompt_record_.push_back(input_len); + } + } +} + +} // Transformer +} // MNN \ No newline at end of file diff --git a/transformers/llm/engine/src/llm.cpp b/transformers/llm/engine/src/llm.cpp index fd3a9d037..6a1d7fdf8 100644 --- a/transformers/llm/engine/src/llm.cpp +++ b/transformers/llm/engine/src/llm.cpp @@ -16,6 +16,9 @@ #include #include "cpp/ExprDebug.hpp" #include "llm/llm.hpp" +#include "evaluation/evaluation.hpp" +#include "sampler.hpp" +#include "prompt.hpp" #include "tokenizer.hpp" #include "llmconfig.hpp" // 0: no debug, 1: test op time, 2: print tensor info @@ -266,7 +269,7 @@ void Llm::load() { } MNN_PRINT("Load Module Done!\n"); } else { - MNN_ERROR("Split version is depercerate\n"); + MNN_ERROR("Split version is deprecated\n"); } decode_modules_.resize(modules_.size()); for (int v=0; v candidates) { if (logits->getInfo()->size == 0) { return; } - auto token = sample(logits, {}); + // no need for sampling here, because metal OP does not affects much in sampling. 
auto et = std::chrono::system_clock::now(); int64_t time = std::chrono::duration_cast(et - st).count(); if(time < min_time) { @@ -380,7 +394,9 @@ void Llm::tuning(TuneType type, std::vector candidates) { runtime_manager_->setHint(MNN::Interpreter::OP_ENCODER_NUMBER_FOR_COMMIT, prefer_candidate); } -VARP Llm::forward(const std::vector& input_ids) { +VARP Llm::forward(const std::vector& input_ids, bool is_prefill) { + if (is_prefill) current_modules_ = prefill_modules_; + else current_modules_ = decode_modules_; int seq_len = input_ids.size(); auto attention_mask = gen_attention_mask(seq_len); auto position_ids = gen_position_ids(seq_len); @@ -401,245 +417,134 @@ VARP Llm::forward(const std::vector& input_ids) { past_key_values_[0] = outputs[1]; } } else { - MNN_ERROR("Split models is depercarate\n"); + MNN_ERROR("Split models is deprecated\n"); return nullptr; } - all_seq_len_ += seq_len; - gen_seq_len_++; + // sequence length is handled in forward. + mLlmSessionInfos[0].all_seq_len_ += seq_len; + mLlmSessionInfos[0].gen_seq_len_++; return logits; } -int Llm::sample(VARP logits, const std::vector& pre_ids) { - std::unordered_set ids_set(pre_ids.begin(), pre_ids.end()); - auto scores = (float*)(logits->readMap()); - auto size = logits->getInfo()->size; - // repetition penalty - const float repetition_penalty = 1.1; - for (auto id : ids_set) { - float score = scores[id]; - scores[id] = score < 0 ? score * repetition_penalty : score / repetition_penalty; - } - // argmax - float max_score = 0; - int token_id = 0; - for (int i = 0; i < size; i++) { - float score = scores[i]; - if (score > max_score) { - max_score = score; - token_id = i; - } - } - return token_id; -} - -static std::string apply_template(std::string prompt_template, const std::string& content, const std::string& role = "") { - if (prompt_template.empty()) return content; - if (!role.empty()) { - const std::string placeholder = "%r"; - size_t start_pos = prompt_template.find(placeholder); - if (start_pos == std::string::npos) return content; - prompt_template.replace(start_pos, placeholder.length(), role); - } - const std::string placeholder = "%s"; - size_t start_pos = prompt_template.find(placeholder); - if (start_pos == std::string::npos) return content; - prompt_template.replace(start_pos, placeholder.length(), content); - return prompt_template; -} - -std::string Llm::apply_prompt_template(const std::string& user_content) const { - auto chat_prompt = config_->prompt_template(); - return apply_template(chat_prompt, user_content); +// < "app_type": "chat" +bool Llm::getUserPrompt(bool from_file, std::istream* is, std::string& user_str) { + if (!from_file) std::cout << "\nQ: "; + return (bool)std::getline(*is, user_str); } -std::string Llm::apply_chat_template(const std::vector& chat_prompts) const { - auto chat_template = config_->chat_template(); - std::string prompt_result; - auto iter = chat_prompts.begin(); - for (; iter != chat_prompts.end() - 1; ++iter) { - prompt_result += apply_template(chat_template, iter->second, iter->first); - } - if (iter->first == "user") { - prompt_result += apply_prompt_template(iter->second); - } else { - prompt_result += apply_template(chat_template, iter->second, iter->first); - } - return prompt_result; -} - -void Llm::chat() { - std::vector history; - history.push_back(std::make_pair("system", "You are a helpful assistant.")); - while (true) { - std::cout << "\nQ: "; - std::string user_str; - std::cin >> user_str; - if (user_str == "/exit") { +void Llm::chat(bool session_by_line, bool 
from_file, + std::istream* is, std::ostream* os, + const char* end_with, std::string exit_token, std::string reset_token) { + // handle system prompt + reset(); + std::string user_str; + while (getUserPrompt(from_file, is, user_str)) { + // whether to end + if (user_str == exit_token) { + reset(); break; } - if (user_str == "/reset") { - history.resize(1); - std::cout << "\nA: reset done." << std::endl; + // whether to reset + if (session_by_line || user_str == reset_token) { + reset(); + if (!from_file) std::cout << "\nreset done." << std::endl; continue; } - std::cout << "\nA: " << std::flush; - if (config_->reuse_kv()) { - response(user_str); - } else { - history.emplace_back(std::make_pair("user", user_str)); - auto assistant_str = response(history); - history.emplace_back(std::make_pair("assistant", assistant_str)); - } - std::cout << std::endl; + // get answer + (*os) << "\nA: " << std::flush; + response(user_str, os, end_with); + (*os) << std::endl; } + reset(); +} + +std::string Llm::response(const std::string& user_str, std::ostream* os, const char* end_with) { + mPromptLib->appendUserPrompt(user_str); + auto assistant_str = generate(mPromptLib->getLLMInput(), os, end_with); + mPromptLib->appendLLMOutput(assistant_str); + return assistant_str; } +// "app_type": "chat" > + + void Llm::reset() { - history_ids_.clear(); - all_seq_len_ = 0; + // clear KV cache + // KV cache automatically cleared as long as seq_len reset! + mLlmSessionInfos.clear(); + mLlmSessionInfos.emplace_back(LlmSessionInfo()); } +bool Llm::reuse_kv() const { + return config_->reuse_kv(); +} + + +// < generate void Llm::generate_init() { - // init status - gen_seq_len_ = 0; - prefill_us_ = 0; - decode_us_ = 0; + // handle past_key_values if not attention_fused_ past_key_values_.clear(); - if (is_single_) { - past_key_values_.push_back(_Input(key_value_shape_, NCHW)); - } else { - for (int i = 0; i < config_->layer_nums(); i++) { + if (!attention_fused_) { + if (is_single_) { past_key_values_.push_back(_Input(key_value_shape_, NCHW)); + } else { + MNN_ERROR("Split version is deprecated\n"); } } - if (!config_->reuse_kv()) { - all_seq_len_ = 0; - history_ids_.clear(); + if (!reuse_kv()) { + // only reset sampler. The history is handled by mPromptLib. 
+ mLlmSessionInfos[0].resetSamplerFields(); } current_modules_ = prefill_modules_; } -std::vector Llm::generate(const std::vector& input_ids, int max_new_tokens) { - generate_init(); - std::vector output_ids, all_ids = input_ids; - prompt_len_ = static_cast(input_ids.size()); - if (max_new_tokens < 0) { max_new_tokens = config_->max_new_tokens(); } - // prefill - current_modules_ = prefill_modules_; - auto logits = forward(input_ids); - if (logits.get() == nullptr) { - return {}; - } - int token = sample(logits, all_ids); - output_ids.push_back(token); - all_ids.push_back(token); - // decode - current_modules_ = decode_modules_; - while (gen_seq_len_ < max_new_tokens) { - logits = nullptr; - logits = forward({token}); - if (logits.get() == nullptr) { - return {}; - } - token = sample(logits, all_ids); - if (is_stop(token)) { break; } - output_ids.push_back(token); - all_ids.push_back(token); - } - return output_ids; +std::string Llm::generate(const std::string& prompt, std::ostream* os, const char* end_with) { + if (prompt.empty()) { return ""; } + if (!end_with) { end_with = "\n"; } + // std::cout << "# prompt : " << prompt << std::endl; + auto input_ids = tokenizer_encode(prompt, false); + std::string out_str = generate(input_ids, os, end_with); + return out_str; } std::string Llm::generate(const std::vector& input_ids, std::ostream* os, const char* end_with) { + if (mTracing) return generateTrace(input_ids, os, end_with); + if (input_ids.empty()) { return ""; } + if (!end_with) { end_with = "\n"; } + generate_init(); + // printf("input_ids (%lu): ", input_ids.size()); for (auto id : input_ids) printf("%d, ", id); printf("\n"); + std::string out_str = mSampler->sample(input_ids, os, end_with, &(mLlmSessionInfos[0].mTimePerformance)); + return out_str; +} + + +std::string Llm::generateTrace(const std::vector& input_ids, std::ostream* os, const char* end_with) { if (mTracing) { // Skip real forward - current_modules_ = prefill_modules_; - forward(input_ids); - current_modules_ = decode_modules_; - forward({input_ids[0]}); - forward({input_ids[0]}); + forward(input_ids, true); + forward({input_ids[0]}, false); + forward({input_ids[0]}, false); return "Test"; } - prompt_len_ = static_cast(input_ids.size()); - history_ids_.insert(history_ids_.end(), input_ids.begin(), input_ids.end()); // push to history_ids_ - auto st = std::chrono::system_clock::now(); - current_modules_ = prefill_modules_; - auto logits = forward(input_ids); - if (nullptr == logits.get()) { - return ""; - } - int token = sample(logits, history_ids_); - auto et = std::chrono::system_clock::now(); - current_modules_ = decode_modules_; - std::string output_str = tokenizer_decode(token); - prefill_us_ = std::chrono::duration_cast(et - st).count(); - *os << output_str << std::flush; - while (gen_seq_len_ < config_->max_new_tokens()) { - st = std::chrono::system_clock::now(); - history_ids_.push_back(token); - logits = nullptr; - logits = forward({token}); - if (nullptr == logits.get()) { - return ""; - } - if (logits->getInfo()->size == 0) { - return ""; - } - token = sample(logits, history_ids_); - et = std::chrono::system_clock::now(); - decode_us_ += std::chrono::duration_cast(et - st).count(); - if (is_stop(token)) { - *os << end_with << std::flush; - break; - } - auto word = tokenizer_decode(token); - *os << word << std::flush; - output_str += word; - } - ExecutorScope::Current()->gc(Executor::FULL); -#ifdef DUMP_PROFILE_INFO - print_speed(); -#endif - return output_str; + return "Test"; } std::vector 
Llm::tokenizer_encode(const std::string& user_content, bool use_template) { if (!use_template) { return tokenizer_->encode(user_content); } - auto prompt = apply_prompt_template(user_content); + auto prompt = mPromptLib->applyTemplate(user_content); auto input_ids = tokenizer_->encode(prompt); return input_ids; } -std::string Llm::response(const std::string& user_content, std::ostream* os, const char* end_with) { - generate_init(); - if (!end_with) { end_with = "\n"; } - std::vector input_ids; - if (config_->reuse_kv()) { - auto prompt = apply_prompt_template(user_content); - if (all_seq_len_ > 0) { - prompt = "<|im_end|>\n" + prompt; - } - input_ids = tokenizer_->encode(prompt); - } else { - input_ids = tokenizer_encode(user_content); - } - return generate(input_ids, os, end_with); -} -std::string Llm::response(const std::vector& chat_prompts, std::ostream* os, const char* end_with) { - if (chat_prompts.empty()) { return ""; } - generate_init(); - if (!end_with) { end_with = "\n"; } - auto prompt = apply_chat_template(chat_prompts); - if (config_->reuse_kv() && all_seq_len_ > 0) { - prompt = "<|im_end|>\n" + prompt; - } - // std::cout << "# prompt : " << prompt << std::endl; - auto input_ids = tokenizer_->encode(prompt); - // printf("input_ids (%lu): ", input_ids.size()); for (auto id : input_ids) printf("%d, ", id); printf("\n"); - return generate(input_ids, os, end_with); +// < evaluation +std::vector Llm::perplexity(std::string prompt_file, std::ostream* perfOS) { + return mSampler->perplexity(prompt_file, perfOS); } +// evaluation > + Llm::~Llm() { #if DEBUG_MODE==1 @@ -671,23 +576,64 @@ Llm::~Llm() { runtime_manager_.reset(); } +// < speed void Llm::print_speed() { - auto prefill_s = prefill_us_ * 1e-6; - auto decode_s = decode_us_ * 1e-6; - auto total_s = prefill_s + decode_s; printf("\n#################################\n"); - printf(" total tokens num = %d\n", prompt_len_ + gen_seq_len_); - printf("prompt tokens num = %d\n", prompt_len_); - printf("output tokens num = %d\n", gen_seq_len_); - printf(" total time = %.2f s\n", total_s); - printf("prefill time = %.2f s\n", prefill_s); - printf(" decode time = %.2f s\n", decode_s); - printf(" total speed = %.2f tok/s\n", (prompt_len_ + gen_seq_len_) / total_s); - printf("prefill speed = %.2f tok/s\n", prompt_len_ / prefill_s); - printf(" decode speed = %.2f tok/s\n", gen_seq_len_ / decode_s); - printf(" chat speed = %.2f tok/s\n", gen_seq_len_ / total_s); + printf("average total speed = %.3f tok/s\n", average_total_speed()); + printf("average prefill speed = %.3f tok/s\n", average_prefill_speed()); + printf("average decode speed = %.3f tok/s\n", average_decode_speed()); printf("##################################\n"); + #if DEBUG_MODE==1 + if (nullptr != gTimeTraceInfo) { + float opSummer = 0.0f; + float opFlopsSummber = 0.0f; + for (auto& iter : gTimeTraceInfo->mTypes) { + float summer = 0.0f; + float summerflops = 0.0f; + for (auto& t : iter.second) { + for (auto& t0 : t.second) { + summer += t0.first; + summerflops += t0.second; + } + } + summer = summer; + summerflops = summerflops; + MNN_PRINT("%s : %.7f, FLOP: %.7f, Speed: %.7f GFlops\n", iter.first.c_str(), summer, summerflops, summerflops / summer); + opSummer += summer; + opFlopsSummber+= summerflops; + } + MNN_PRINT("OP Summer: %.7f, Flops: %.7f, Speed: %.7f GFlops\n", opSummer, opFlopsSummber, opFlopsSummber/opSummer); + } + #endif +} + +void Llm::print_speed(std::ostream* os) { + mLlmSessionInfos[0].print_speed(os); +} + +float Llm::average_total_speed() { + return 
mLlmSessionInfos[0].average_total_speed(); +} +float Llm::average_prefill_speed() { + // prefill response rate + return mLlmSessionInfos[0].average_prefill_speed(); +} +float Llm::average_decode_speed() { + return mLlmSessionInfos[0].average_decode_speed(); +} +float Llm::getTotalPrefillTime() { + return mLlmSessionInfos[0].getTotalPrefillTime(); +} +float Llm::getTotalDecodeTime() { + return mLlmSessionInfos[0].getTotalDecodeTime(); +} +int Llm::getTotalPromptLen() { + return mLlmSessionInfos[0].getTotalPromptLen(); +} +int Llm::getTotalDecodeLen() { + return mLlmSessionInfos[0].getTotalDecodeLen(); } +// speed > static inline bool needNewVar(VARP var, int axis, int seq_len) { if (var == nullptr) { @@ -722,34 +668,35 @@ std::string Llm::tokenizer_decode(int id) { } VARP Llm::gen_attention_mask(int seq_len) { - int kv_seq_len = all_seq_len_ + seq_len; + int kv_seq_len_=mLlmSessionInfos[0].all_seq_len_+seq_len, gen_seq_len_=mLlmSessionInfos[0].gen_seq_len_; + int prev_seq_len_ = kv_seq_len_ - seq_len; if (seq_len == 1) { - kv_seq_len = seq_len; + kv_seq_len_ = seq_len; } if (config_->attention_mask() == "float") { if (needNewVar(attention_mask_, 2, seq_len)) { - attention_mask_ = _Input({1, 1, seq_len, kv_seq_len}, NCHW, halide_type_of()); + attention_mask_ = _Input({1, 1, seq_len, kv_seq_len_}, NCHW, halide_type_of()); } else { return attention_mask_; } auto ptr = attention_mask_->writeMap(); for (int i = 0; i < seq_len; i++) { - for (int j = 0; j < kv_seq_len; j++) { - int row = i + all_seq_len_; - ptr[kv_seq_len * i + j] = (j > row) * std::numeric_limits::lowest(); + for (int j = 0; j < kv_seq_len_; j++) { + int row = i + prev_seq_len_; + ptr[kv_seq_len_ * i + j] = (j > row) * std::numeric_limits::lowest(); } } return attention_mask_; } else { if (needNewVar(attention_mask_, 2, seq_len)) { - attention_mask_ = _Input({1, 1, seq_len, kv_seq_len}, NCHW, halide_type_of()); + attention_mask_ = _Input({1, 1, seq_len, kv_seq_len_}, NCHW, halide_type_of()); } else { return attention_mask_; } auto ptr = attention_mask_->writeMap(); if (config_->attention_mask() == "glm") { // chatglm - for (int i = 0; i < seq_len * kv_seq_len; i++) { + for (int i = 0; i < seq_len * kv_seq_len_; i++) { ptr[i] = 0; } if (seq_len > 1) { @@ -760,8 +707,8 @@ VARP Llm::gen_attention_mask(int seq_len) { } else { bool is_glm2 = config_->attention_mask() == "glm2"; for (int i = 0; i < seq_len; i++) { - for (int j = 0; j < kv_seq_len; j++) { - int row = i + all_seq_len_; + for (int j = 0; j < kv_seq_len_; j++) { + int row = i + prev_seq_len_; ptr[seq_len * i + j] = is_glm2 ? j > row : j <= row; } } @@ -771,6 +718,8 @@ VARP Llm::gen_attention_mask(int seq_len) { } VARP Llm::gen_position_ids(int seq_len) { + int kv_seq_len_=mLlmSessionInfos[0].all_seq_len_+seq_len, gen_seq_len_=mLlmSessionInfos[0].gen_seq_len_; + int prev_seq_len_ = kv_seq_len_ - seq_len; if (config_->attention_mask() == "glm") { // chatglm if (needNewVar(position_ids_, 2, seq_len)) { @@ -778,7 +727,7 @@ VARP Llm::gen_position_ids(int seq_len) { } auto ptr = position_ids_->writeMap(); if (seq_len == 1) { - ptr[0] = all_seq_len_ - gen_seq_len_ - 2; + ptr[0] = prev_seq_len_ - gen_seq_len_ - 2; ptr[1] = gen_seq_len_ + 1; } else { for (int i = 0; i < seq_len - 1; i++) { @@ -796,10 +745,10 @@ VARP Llm::gen_position_ids(int seq_len) { } auto ptr = position_ids_->writeMap(); if (seq_len == 1) { - ptr[0] = is_glm2 ? gen_seq_len_ : all_seq_len_; + ptr[0] = is_glm2 ? 
gen_seq_len_ : prev_seq_len_; } else { for (int i = 0; i < seq_len; i++) { - ptr[i] = i + all_seq_len_; + ptr[i] = i + prev_seq_len_; } } return position_ids_; @@ -868,7 +817,8 @@ std::vector Lvlm::image_process(const std::string& image_info) { } std::vector Lvlm::tokenizer_encode(const std::string& query, bool use_template) { - auto prompt = apply_prompt_template(query); + auto prompt = query; + if (!use_template) { prompt = mPromptLib->applyTemplate(query); } // split query std::regex img_regex("(.*?)"); std::string::const_iterator searchStart(prompt.cbegin()); @@ -976,7 +926,7 @@ VARP Embedding::ids_embedding(const std::vector& ids) { } VARP Embedding::txt_embedding(const std::string& txt) { - return ids_embedding(tokenizer_encode(txt)); + return ids_embedding(tokenizer_encode(txt, false)); } VARP Embedding::gen_attention_mask(int seq_len) { diff --git a/transformers/llm/engine/src/llmconfig.cpp b/transformers/llm/engine/src/llmconfig.cpp new file mode 100644 index 000000000..458f3c4d1 --- /dev/null +++ b/transformers/llm/engine/src/llmconfig.cpp @@ -0,0 +1,47 @@ +// +// llmconfig.hpp +// +// Created by MNN on 2024/07/19. +// ZhaodeWang +// + + + +#include "rapidjson/document.h" +#include +#include +#include "llmconfig.hpp" + +namespace MNN { +namespace Transformer { + +bool merge_json(rapidjson::Value& destination, const rapidjson::Value& source, + rapidjson::Document::AllocatorType& allocator) { + if (!source.IsObject() || !destination.IsObject()) { + return false; + } + + for (auto it = source.MemberBegin(); it != source.MemberEnd(); ++it) { + const char* key = it->name.GetString(); + if (destination.HasMember(key)) { + if (destination[key].IsObject() && it->value.IsObject()) { + // Recursively merge the two JSON objects + merge_json(destination[key], it->value, allocator); + } else { + // Overwrite the value in the destination + destination[key].CopyFrom(it->value, allocator); + } + } else { + // Add the value to the destination + rapidjson::Value newKey(key, allocator); + rapidjson::Value newValue; + newValue.CopyFrom(it->value, allocator); + destination.AddMember(newKey, newValue, allocator); + } + } + return true; +} + +} // Transformer +} // MNN + diff --git a/transformers/llm/engine/src/llmconfig.hpp b/transformers/llm/engine/src/llmconfig.hpp index 327ecaaff..301ce38ae 100644 --- a/transformers/llm/engine/src/llmconfig.hpp +++ b/transformers/llm/engine/src/llmconfig.hpp @@ -5,9 +5,18 @@ // ZhaodeWang // -#include "rapidjson/document.h" +#ifndef LLMCONFIG_Hpp +#define LLMCONFIG_Hpp + +#include +#include +#include +#include +#include #include #include + + namespace MNN { namespace Transformer { @@ -36,31 +45,7 @@ static inline std::string file_name(const std::string& path) { } bool merge_json(rapidjson::Value& destination, const rapidjson::Value& source, - rapidjson::Document::AllocatorType& allocator) { - if (!source.IsObject() || !destination.IsObject()) { - return false; - } - - for (auto it = source.MemberBegin(); it != source.MemberEnd(); ++it) { - const char* key = it->name.GetString(); - if (destination.HasMember(key)) { - if (destination[key].IsObject() && it->value.IsObject()) { - // Recursively merge the two JSON objects - merge_json(destination[key], it->value, allocator); - } else { - // Overwrite the value in the destination - destination[key].CopyFrom(it->value, allocator); - } - } else { - // Add the value to the destination - rapidjson::Value newKey(key, allocator); - rapidjson::Value newValue; - newValue.CopyFrom(it->value, allocator); - 
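Since `merge_json` moves out of the header into `llmconfig.cpp`, callers now link against its definition. A hedged usage sketch, assuming only the rapidjson calls already visible in this patch, where a user config overrides a base config:

```cpp
#include <iostream>
#include "rapidjson/document.h"
#include "rapidjson/stringbuffer.h"
#include "rapidjson/writer.h"

// Declaration as in llmconfig.hpp; the definition now lives in llmconfig.cpp.
namespace MNN { namespace Transformer {
bool merge_json(rapidjson::Value& destination, const rapidjson::Value& source,
                rapidjson::Document::AllocatorType& allocator);
}}

int main() {
    rapidjson::Document base, user;
    base.Parse(R"({"sampler_type":"mixed","temperature":1.0})");
    user.Parse(R"({"temperature":0.7,"topK":40})");
    // Values present in `user` overwrite `base`; new keys are appended.
    MNN::Transformer::merge_json(base, user, base.GetAllocator());
    rapidjson::StringBuffer buf;
    rapidjson::Writer<rapidjson::StringBuffer> writer(buf);
    base.Accept(writer);
    std::cout << buf.GetString() << std::endl;   // {"sampler_type":"mixed","temperature":0.7,"topK":40}
}
```

Keys present on both sides are overwritten unless both values are objects, in which case the merge recurses.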
destination.AddMember(newKey, newValue, allocator); - } - } - return true; -} + rapidjson::Document::AllocatorType& allocator); class rapid_json_wrapper { public: @@ -98,6 +83,13 @@ class rapid_json_wrapper { return buffer.GetString(); } // read value + float value(const char* key, const float& default_value) const { + if (document.HasMember(key)) { + const auto& value = document[key]; + if (value.IsFloat()) return value.GetFloat(); + } + return default_value; + } int value(const char* key, const int& default_value) const { if (document.HasMember(key)) { const auto& value = document[key]; @@ -162,6 +154,21 @@ class rapid_json_wrapper { } return default_value; } + std::vector value(const char* key, const std::vector& default_value) const { + if (document.HasMember(key)) { + const auto& value = document[key]; + if (value.IsArray()) { + std::vector result; + for (auto& v : value.GetArray()) { + if (v.IsString()) { + result.push_back(v.GetString()); + } + } + return result; + } + } + return default_value; + } std::string value(const char key[], const char default_value[]) const { return value(key, std::string(default_value)); } @@ -248,6 +255,10 @@ class LlmConfig { // model file config end > // < generate config start + int max_all_tokens() const { + return config_.value("max_all_tokens", 2048); + } + int max_new_tokens() const { return config_.value("max_new_tokens", 512); } @@ -324,18 +335,96 @@ class LlmConfig { return llm_config_.value("attention_fused", true); } - std::string chat_template() const { - return llm_config_.value("chat_template", ""); + std::string system_prompt_template() const { + return llm_config_.value("system_prompt_template", "<|im_start|>system\n%s<|im_end|>\n"); } - - std::string prompt_template() const { - return llm_config_.value("prompt_template", ""); + std::string user_prompt_template() const { + return llm_config_.value("user_prompt_template", "<|im_start|>user\n%s<|im_end|>\n"); + } + std::string assistant_prefix() const { + return llm_config_.value("assistant_prefix", "<|im_start|>assistant\n"); + } + std::string assistant_suffix() const { + return llm_config_.value("assistant_suffix", "<|im_end|>\n"); } std::vector tie_embeddings() const { return llm_config_.value("tie_embeddings", std::vector{}); } // llm model config end > + + // < sampler config start + std::string sampler_type() const { + return config_.value("sampler_type", "mixed"); + } + + std::vector mixed_samplers() const { + return config_.value("mixed_samplers", std::vector({"topK", "tfs", "typical", "topP", "min_p", "temperature"})); + } + + float temperature() const { + return config_.value("temperature", 1.0f); + } + + int topK() const { + return config_.value("topK", 40); + } + + float topP() const { + return config_.value("topP", 0.9f); + } + + float minP() const { + return config_.value("minP", 0.1f); + } + + float tfsZ() const { + return config_.value("tfsZ", 1.0f); + } + + float typical() const { + return config_.value("typical", 1.0f); + } + + float penalty() const { + return config_.value("penalty", 0.0f); + } + + int ngram() const { + return config_.value("n_gram", 8); + } + + float ngram_factor() const { + return config_.value("ngram_factor", 1.0f); + } + + std::string penalty_sampler() const { + return config_.value("penalty_sampler", "greedy"); + } + // sampler config end > + + // < app config start + std::string app_type() const { + return config_.value("app_type", "chat"); + } + std::string system_prompt() const { + return config_.value("system_prompt", "You are a helpful 
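Each sampler getter above falls back to a built-in default (temperature 1.0, topK 40, topP 0.9, minP 0.1, tfsZ 1.0, typical 1.0, penalty 0.0) when `config.json` does not set the key. A toy sketch of that value-or-default pattern; `ToyConfig` is a stand-in for the rapidjson-backed wrapper, not the real class:

```cpp
#include <map>
#include <string>

// Stand-in for the JSON-backed config: only keys the user actually set are present.
struct ToyConfig {
    std::map<std::string, float> values;
    float value(const std::string& key, float fallback) const {
        auto it = values.find(key);
        return it == values.end() ? fallback : it->second;
    }
};

int main() {
    ToyConfig cfg{{{"temperature", 0.7f}}};                    // user only overrides temperature
    float temperature = cfg.value("temperature", 1.0f);        // 0.7 (from config)
    float topP        = cfg.value("topP", 0.9f);               // 0.9 (default)
    int   topK        = static_cast<int>(cfg.value("topK", 40.f)); // 40 (default)
    (void)temperature; (void)topP; (void)topK;
    return 0;
}
```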
assistant!\n"); + } + // app config end > + + // < evaulation config start + int ppl_stride() const { + return config_.value("ppl_stride", 0); + } + std::string dataset() const { + return config_.value("dataset", "wikitext"); + } + int dataset_sample_size() const { + return config_.value("dataset_sample_size", -1); // -1 stands for no sampling, use all. + } + // evaulation config end }; } // Transformer } // MNN + +#endif diff --git a/transformers/llm/engine/src/perplexity.cpp b/transformers/llm/engine/src/perplexity.cpp new file mode 100644 index 000000000..4d04cdf31 --- /dev/null +++ b/transformers/llm/engine/src/perplexity.cpp @@ -0,0 +1,318 @@ +#include +#include +#include +#include +#include +#include + +#include "sampler.hpp" +#include "perplexity.hpp" +#include "llmconfig.hpp" +#include "prompt.hpp" + +namespace MNN{ +namespace Transformer{ + + +/* -----------TextPPLMeasurer---------- */ +TextPPLMeasurer::TextPPLMeasurer(Llm* llm, std::shared_ptr llmConfig) { + mLlm = llm; + mConfig.max_all_tokens = llmConfig->max_all_tokens(); + mConfig.max_new_tokens = llmConfig->max_new_tokens(); + mDatasetType = llmConfig->dataset(); + mStride = llmConfig->ppl_stride(); + if (mStride == 0) { + // default stride for sliding window. + mStride = mConfig.max_all_tokens / 2; + } +} + +/* Implemented based on https://huggingface.co/docs/transformers/perplexity + + ******************** HuggingFace Python Version ************************ + +import torch +from tqdm import tqdm + +max_length = model.config.n_positions +stride = 512 +seq_len = encodings.input_ids.size(1) + +nlls = [] +prev_end_loc = 0 +for begin_loc in tqdm(range(0, seq_len, stride)): + end_loc = min(begin_loc + max_length, seq_len) + trg_len = end_loc - prev_end_loc # may be different from stride on last loop + input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device) + target_ids = input_ids.clone() + target_ids[:, :-trg_len] = -100 + + with torch.no_grad(): + outputs = model(input_ids, labels=target_ids) + + # loss is calculated using CrossEntropyLoss which averages over valid labels + # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels + # to the left by 1. + neg_log_likelihood = outputs.loss + + nlls.append(neg_log_likelihood) + + prev_end_loc = end_loc + if end_loc == seq_len: + break + +ppl = torch.exp(torch.stack(nlls).mean()) + + ******************** HuggingFace Python Version ************************ +*/ + +float TextPPLMeasurer::perplexity_one(const std::vector& prompt) { + int seq_len = prompt.size(); + std::vector nlls; + float ppl = 0.f; + + // start calculation + int prev_end_loc = 1; // the first token start from id=1, do not count the first one. 
+ for (int begin_loc = 0; begin_loc < seq_len; begin_loc += mStride) { + int end_loc = std::min(begin_loc + mConfig.max_all_tokens, seq_len); + // first token + std::vector tokens(prev_end_loc - begin_loc); + for (int it = begin_loc; it < prev_end_loc; ++it) tokens[it - begin_loc] = prompt[it]; + mLlm->mLlmSessionInfos[0].all_seq_len_ = tokens.size(); + mLlm->mLlmSessionInfos[0].gen_seq_len_ = mLlm->mLlmSessionInfos[0].all_seq_len_; + auto logits = mLlm->forward(tokens, true); + logits = MNN::Express::_Softmax(logits); + nlls.push_back(-std::log(((float*)(logits->readMap()))[prompt[prev_end_loc]])); + // std::cout << mLlm->tokenizer_decode(argmax(logits)) << " " << mLlm->tokenizer_decode(prompt[prev_end_loc]) << " " << -std::log(((float*)(logits->readMap()))[prompt[prev_end_loc]]) << std::endl; + std::cout << -std::log(((float*)(logits->readMap()))[prompt[prev_end_loc]]) << std::endl; + // decode following tokens + for (int it = prev_end_loc+1; it < end_loc; ++it) { + mLlm->mLlmSessionInfos[0].all_seq_len_ += 1; + mLlm->mLlmSessionInfos[0].gen_seq_len_ = mLlm->mLlmSessionInfos[0].all_seq_len_; + auto logits = mLlm->forward({prompt[it-1]}, false); + logits = MNN::Express::_Softmax(logits); + nlls.push_back(-std::log(((float*)(logits->readMap()))[prompt[it]])); + // std::cout << mLlm->tokenizer_decode(argmax(logits)) << " " << mLlm->tokenizer_decode(prompt[it]) << " " << -std::log(((float*)(logits->readMap()))[prompt[it]]) << std::endl; + std::cout << -std::log(((float*)(logits->readMap()))[prompt[it]]) << std::endl; + } + // clean up once + mLlm->reset(); + prev_end_loc = end_loc; + if (end_loc == seq_len) break; + } + + // calculate ppl + for (int j = 0; j < nlls.size(); ++j) ppl += nlls[j]; + ppl /= nlls.size(); + ppl = std::exp(ppl); + + // print + std::cout << "PPL: " << std::setprecision(8) << ppl << std::endl; + return ppl; +} + +std::vector TextPPLMeasurer::perplexity(std::vector> prompts) { + std::vector ppls; + for (auto prompt : prompts) { + ppls.push_back(perplexity_one(prompt)); + mLlm->reset(); + } + return ppls; +} + +std::vector TextPPLMeasurer::perplexity(std::vector prompts) { + std::vector> tokens(prompts.size()); + for (int p = 0; p < prompts.size(); ++p) tokens[p] = mLlm->tokenizer_encode(prompts[p], false); + return perplexity(tokens); +} + +std::vector TextPPLMeasurer::perplexity(std::string prompt_file, std::ostream* perfOS) { + // No performance will be printed! + std::vector prompts; + if (mDatasetType == "wikitext") { + prompts = wikitext(prompt_file); + } + else if (mDatasetType == "plaintext") { + prompts = plaintext(prompt_file); + } + else if (mDatasetType == "rowsplit") { + prompts = rowsplit(prompt_file); + } + else { + MNN_ERROR("Dataset not suppoted"); + exit(1); + } + std::cout << "prompt file loaded!" 
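`perplexity_one` above collects one negative log-likelihood per predicted token and finishes with PPL = exp(mean NLL), mirroring the HuggingFace reference quoted in the comment. A self-contained sketch of that final aggregation step:

```cpp
#include <cmath>
#include <iostream>
#include <vector>

// Perplexity from per-token negative log-likelihoods: exp of their mean.
double perplexityFromNLLs(const std::vector<double>& nlls) {
    double sum = 0.0;
    for (double nll : nlls) sum += nll;
    return std::exp(sum / nlls.size());
}

int main() {
    // Three predicted tokens with probabilities 0.5, 0.25, 0.125:
    // geometric mean of 1/p = (2*4*8)^(1/3) = 4.
    std::vector<double> nlls = {std::log(2.0), std::log(4.0), std::log(8.0)};
    std::cout << perplexityFromNLLs(nlls) << std::endl;   // 4
}
```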
<< std::endl; + return perplexity(prompts); +} + +/* -----------ChatPPLMeasurer---------- */ +ChatPPLMeasurer::ChatPPLMeasurer(Llm* llm, std::shared_ptr llmConfig) { + mLlm = llm; + mConfig.max_all_tokens = llmConfig->max_all_tokens(); + mConfig.max_new_tokens = llmConfig->max_new_tokens(); + mDatasetType = llmConfig->dataset(); + mDatasetSampleSize = llmConfig->dataset_sample_size(); +} + +void ChatPPLMeasurer::handleToken(int token) { + // CommonPrefix and Candidates managements + mLlm->mLlmSessionInfos[0].tokens.push_back(token); +} + +std::vector ChatPPLMeasurer::sample(const std::vector& input_ids, const std::vector& prompt, struct TimePerformance* time_perf) { + std::vector nlls; + // initialization for time performance + PrefillTimePerformance prefill_time; + prefill_time.prefill_prev_token_ = mLlm->mLlmSessionInfos[0].tokens.size(); + prefill_time.prefill_token_ = input_ids.size(); + appendNewPromptRecord(time_perf, input_ids.size(), mLlm->reuse_kv()); + // initialization + mLlm->mLlmSessionInfos[0].tokens.insert(mLlm->mLlmSessionInfos[0].tokens.end(), input_ids.begin(), input_ids.end()); + // all_seq_len_ in sampler functions as kv_seq_len_, prev_seq_len_ = all_seq_len_ - seq_len + mLlm->mLlmSessionInfos[0].all_seq_len_ = mLlm->mLlmSessionInfos[0].tokens.size() - input_ids.size(); + mLlm->mLlmSessionInfos[0].gen_seq_len_ = 0; + // prefill + auto st = std::chrono::system_clock::now(); + auto logits = mLlm->forward(input_ids, true); + logits = MNN::Express::_Softmax(logits); + nlls.push_back(-std::log(((float*)(logits->readMap()))[prompt[mLlm->mLlmSessionInfos[0].gen_seq_len_]])); + // record time + auto et = std::chrono::system_clock::now(); + prefill_time.prefill_us_ = std::chrono::duration_cast(et - st).count(); + time_perf->prefill_record_.push_back(prefill_time); + // handle the new token + handleToken(prompt[mLlm->mLlmSessionInfos[0].gen_seq_len_]); + // decode + while (mLlm->mLlmSessionInfos[0].gen_seq_len_ < prompt.size()) { + DecodeTimePerformance decode_time; + decode_time.decode_prev_token_ = mLlm->mLlmSessionInfos[0].tokens.size(); + st = std::chrono::system_clock::now(); + // next token + logits = mLlm->forward({mLlm->mLlmSessionInfos[0].tokens.back()}, false); + logits = MNN::Express::_Softmax(logits); + nlls.push_back(-std::log(((float*)(logits->readMap()))[prompt[mLlm->mLlmSessionInfos[0].gen_seq_len_]])); + et = std::chrono::system_clock::now(); + decode_time.decode_us_ = std::chrono::duration_cast(et - st).count(); + time_perf->decode_record_.push_back(decode_time); + handleToken(prompt[mLlm->mLlmSessionInfos[0].gen_seq_len_]); + } + // return nlls + return nlls; +} + +float ChatPPLMeasurer::perplexity_one(const std::vector>& prompt, std::ostream* perfOS) { + // (turns, roles) + std::vector nlls; + float ppl = 0.f; + + // < simulated chat + mLlm->reset(); + for (auto& turn : prompt) { + mLlm->mPromptLib->appendUserPrompt(turn[0].second); + std::vector input_ids = mLlm->tokenizer_encode(mLlm->mPromptLib->getLLMInput(), false); + mLlm->generate_init(); + auto turn_nlls = sample(input_ids, mLlm->tokenizer_encode(turn[1].second, false), &(mLlm->mLlmSessionInfos[0].mTimePerformance)); + nlls.insert(nlls.end(), turn_nlls.begin(), turn_nlls.end()); + mLlm->mPromptLib->appendLLMOutput(turn[1].second); + } + + // record time performance to file + if (perfOS != nullptr) { + mLlm->mLlmSessionInfos[0].print_speed(perfOS); + } + + mLlm->reset(); + // simulated chat > + + // calculate ppl + for (int j = 0; j < nlls.size(); ++j) ppl += nlls[j]; + ppl /= nlls.size(); + ppl = 
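`ChatPPLMeasurer::sample` scores a dialogue by teacher forcing: after prefilling the prompt, each ground-truth assistant token is fed back while the probability assigned to the next ground-truth token is recorded. A standalone sketch of that loop; `probNextToken` is a placeholder for a model forward pass plus softmax, not an MNN API:

```cpp
#include <cmath>
#include <functional>
#include <vector>

// Teacher-forced scoring: feed the reference tokens themselves and record
// -log p(reference[i]) at every step.
std::vector<double> scoreReference(
        std::vector<int> history,                       // prompt tokens (already prefilled)
        const std::vector<int>& reference,              // ground-truth assistant tokens
        const std::function<double(const std::vector<int>&, int)>& probNextToken) {
    std::vector<double> nlls;
    for (int target : reference) {
        nlls.push_back(-std::log(probNextToken(history, target)));
        history.push_back(target);                      // teacher forcing, not the sampled token
    }
    return nlls;
}
```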
std::exp(ppl); + + // print + std::cout << "PPL: " << std::setprecision(8) << ppl << std::endl; + return ppl; +} + + +std::vector ChatPPLMeasurer::perplexity(const std::vector>>& prompts, std::ostream* perfOS) { + std::vector ppls; + for (auto& prompt : prompts) { + ppls.push_back(perplexity_one(prompt, perfOS)); + mLlm->reset(); + } + return ppls; +} + +void ChatPPLMeasurer::getStats(const std::vector>>& prompts) { + std::ofstream total_stats("total_stats.csv"); + std::ofstream dialog_stats("dialog_stats.csv"); + float average_turns=0, average_prefill=0, average_decode=0, average_total_tokens=0; + int max_turns=0; + std::vector>> stats; // (dialog, turn, (prefill, decode)) + std::cout << prompts.size() << std::endl; + int counter = 0; + for (auto& dialog : prompts) { + std::vector> dialog_stats; + if ((counter++) % std::max((int)prompts.size()/200, 1) == 0) std::cout << "*" << std::flush; + float prefill_len_turn = 0; + float decode_len_turn = 0; + for (auto& turn : dialog) { + // turn: prefill, decode + int prefill_len = mLlm->tokenizer_encode(turn[0].second, false).size(); + int decode_len = mLlm->tokenizer_encode(turn[1].second, false).size(); + prefill_len_turn += prefill_len; + decode_len_turn += decode_len; + average_total_tokens += prefill_len + decode_len; + dialog_stats.push_back({prefill_len, decode_len}); + } + stats.push_back(dialog_stats); + average_prefill += prefill_len_turn / dialog.size(); // average over turns + average_decode += decode_len_turn / dialog.size(); // average over turns + average_turns += dialog.size(); + max_turns = std::max(max_turns, (int)dialog.size()); + } + average_turns /= prompts.size(); + average_prefill /= prompts.size(); + average_decode /= prompts.size(); + average_total_tokens /= prompts.size(); + total_stats << "total_dialogs," << "max_turns," << "avg_turns," \ + << "avg_prefill_tokens/turn," << "avg_decode_tokens/turn," \ + << "avg_total_tokens/dialog" << std::endl; + total_stats << prompts.size() << "," << max_turns << "," << average_turns << "," \ + << average_prefill << "," << average_decode << "," \ + << average_total_tokens << std::endl; + for (int i=0; i ChatPPLMeasurer::perplexity(std::string prompt_file, std::ostream* perfOS) { + // No performance will be printed! + std::vector>> prompts; + if (mDatasetType == "shareGPT") { + prompts = shareGPT(prompt_file, mDatasetSampleSize); + } + else { + MNN_ERROR("Dataset not suppoted"); + exit(1); + } + std::cout << "prompt file loaded!" << std::endl; + getStats(prompts); + std::cout << "\nshareGPT statistics counted!" 
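`getStats` walks the shareGPT dialogs to report turn counts and per-turn token averages before the perplexity run. A trimmed sketch of the same bookkeeping over pre-tokenized lengths; the CSV header mirrors the columns written to `total_stats.csv`:

```cpp
#include <algorithm>
#include <iostream>
#include <vector>

struct Turn { int prefill_tokens; int decode_tokens; };
using Dialog = std::vector<Turn>;

void printStats(const std::vector<Dialog>& dialogs) {
    double avg_turns = 0, avg_prefill = 0, avg_decode = 0, avg_total = 0;
    int max_turns = 0;
    for (const auto& d : dialogs) {
        if (d.empty()) continue;
        double prefill = 0, decode = 0;
        for (const auto& t : d) { prefill += t.prefill_tokens; decode += t.decode_tokens; }
        avg_prefill += prefill / d.size();           // per-turn average within this dialog
        avg_decode  += decode / d.size();
        avg_total   += prefill + decode;
        avg_turns   += d.size();
        max_turns = std::max<int>(max_turns, (int)d.size());
    }
    avg_turns /= dialogs.size(); avg_prefill /= dialogs.size();
    avg_decode /= dialogs.size(); avg_total /= dialogs.size();
    std::cout << "total_dialogs,max_turns,avg_turns,avg_prefill_tokens/turn,"
                 "avg_decode_tokens/turn,avg_total_tokens/dialog\n"
              << dialogs.size() << "," << max_turns << "," << avg_turns << ","
              << avg_prefill << "," << avg_decode << "," << avg_total << std::endl;
}
```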
<< std::endl; + return perplexity(prompts, perfOS); +} + + +} // Transformer +} // MNN \ No newline at end of file diff --git a/transformers/llm/engine/src/perplexity.hpp b/transformers/llm/engine/src/perplexity.hpp new file mode 100644 index 000000000..ece414dd2 --- /dev/null +++ b/transformers/llm/engine/src/perplexity.hpp @@ -0,0 +1,65 @@ +#ifndef PERPLEXITY_hpp +#define PERPLEXITY_hpp + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "sampler.hpp" +#include "evaluation/dataset.hpp" + +namespace MNN { +namespace Transformer { +class Llm; + +class MNN_PUBLIC TextPPLMeasurer : public Sampler { +protected: + Llm* mLlm; + int mStride; + std::string mDatasetType; + LlmSamplerConfig mConfig; +public: + TextPPLMeasurer(Llm* llm, std::shared_ptr config); + float perplexity_one(const std::vector& prompt); + std::vector perplexity(std::vector> prompts); + std::vector perplexity(std::vector prompts); + virtual std::string sample(const std::vector& input_ids, std::ostream* os = &std::cout, const char* end_with = nullptr, struct TimePerformance* time_perf = nullptr) override { return "perplexity evaluation!\n"; } + virtual std::vector perplexity(std::string prompt_file, std::ostream* perfOS = nullptr) override; +}; + +class MNN_PUBLIC ChatPPLMeasurer : public Sampler { +protected: + Llm* mLlm; + std::string mDatasetType; + int mDatasetSampleSize; + LlmSamplerConfig mConfig; + void handleToken(int token); + std::vector sample(const std::vector& input_ids, const std::vector& prompt, struct TimePerformance* time_perf); +public: + ChatPPLMeasurer(Llm* llm, std::shared_ptr config); + void getStats(const std::vector>>& prompts); + float perplexity_one(const std::vector>& prompt, std::ostream* perfOS); + std::vector perplexity(const std::vector>>& prompts, std::ostream* perfOS); + virtual std::string sample(const std::vector& input_ids, std::ostream* os = &std::cout, const char* end_with = nullptr, struct TimePerformance* time_perf = nullptr) override { return "perplexity evaluation!\n"; } + virtual std::vector perplexity(std::string prompt_file, std::ostream* perfOS = nullptr) override; +}; + + + +} // Transformer +} // MNN + + +#endif // SAMPLER_hpp \ No newline at end of file diff --git a/transformers/llm/engine/src/prompt.cpp b/transformers/llm/engine/src/prompt.cpp new file mode 100644 index 000000000..445c630ce --- /dev/null +++ b/transformers/llm/engine/src/prompt.cpp @@ -0,0 +1,110 @@ +#include "prompt.hpp" + +namespace MNN { +namespace Transformer { + +/* ----------PromptLib---------- */ +PromptLib* PromptLib::createPromptLib(Llm* llm, const std::string& config_path) { + return createPromptLib(llm, std::shared_ptr(new LlmConfig(config_path))); +} +PromptLib* PromptLib::createPromptLib(Llm* llm, std::shared_ptr config) { + if (config->app_type() == "chat" || config->app_type() == "perplexity") { + return new BaseChatPromptLib(llm, config); + } else { + std::cout << "PromptLib not Implemented!\n" << std::endl; + return nullptr; + } +} + +/* ----------BaseChatPromptLib---------- */ +BaseChatPromptLib::BaseChatPromptLib(Llm* llm, std::shared_ptr config) { + mLlm = llm; + mReuseKV = config->reuse_kv(); + mDefaultSystemPrompt = config->system_prompt(); + mSystemTemplate = config->system_prompt_template(); + mUserTemplate = config->user_prompt_template(); + mAssistantPrefix = config->assistant_prefix(); + mAssistantSuffix = config->assistant_suffix(); +} + +void 
BaseChatPromptLib::appendSystemPrompt() { + appendSystemPrompt(mDefaultSystemPrompt); +} +void BaseChatPromptLib::appendSystemPrompt(const std::string sys_prompt) { + mLlm->mLlmSessionInfos[0].mHistory.emplace_back(std::make_pair("system", sys_prompt)); + mLlm->mLlmSessionInfos[0].mInputs.emplace_back(std::make_pair("system", sys_prompt)); +} +void BaseChatPromptLib::appendUserPrompt(const std::string user_prompt) { + if (mLlm->mLlmSessionInfos[0].mHistory.empty()) { appendSystemPrompt(); } // prevent no system prompt appendix. + mLlm->mLlmSessionInfos[0].mHistory.emplace_back(std::make_pair("user", user_prompt)); + mLlm->mLlmSessionInfos[0].mInputs.emplace_back(std::make_pair("user", user_prompt)); +} +void BaseChatPromptLib::appendLLMOutput(std::string out_str) { + mLlm->mLlmSessionInfos[0].mHistory.emplace_back(std::make_pair("assistant", out_str)); + if (mReuseKV) { + // clear input + mLlm->mLlmSessionInfos[0].mInputs.clear(); + } else { + // keep input, append output + mLlm->mLlmSessionInfos[0].mInputs.emplace_back(std::make_pair("assistant", out_str)); + } +} + +std::string BaseChatPromptLib::getLLMInput() { + std::string input_str; + if (mReuseKV) { + if (mLlm->mLlmSessionInfos[0].mHistory.size() != mLlm->mLlmSessionInfos[0].mInputs.size()) { + // 1.1 not first prefill, add end of speech. + input_str += mAssistantSuffix; + } + } + // 1.2 generate from template + input_str += applyTemplates(mLlm->mLlmSessionInfos[0].mInputs); + input_str += mAssistantPrefix; + return input_str; +} + +std::string BaseChatPromptLib::applyTemplate(PromptItem item, std::string prompt_template, std::string placeholder) { + size_t start_pos = prompt_template.find(placeholder); + if (start_pos == std::string::npos) return item.first + "\n" + item.second + "\n"; + else { + prompt_template.replace(start_pos, placeholder.length(), item.second); + return prompt_template; + } +} + +std::string BaseChatPromptLib::applyTemplates(std::vector inputs) { + std::string input_str; + for (auto input : inputs) { + if (input.first == "") continue; + if (input.first == "system") { + if (input.second == "") continue; + input_str += applyTemplate(input, mSystemTemplate, "%s"); + continue; + } + if (input.first == "user") { + input_str += applyTemplate(input, mUserTemplate, "%s"); + continue; + } + if (input.first == "assistant") { + input_str += mAssistantPrefix + input.second + mAssistantSuffix; + continue; + } + // Invalid role!!! 
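`applyTemplate`/`applyTemplates` substitute role content into `%s`-style templates and wrap assistant turns with the configured prefix/suffix. A minimal standalone sketch of that substitution, using the Qwen-style defaults from `llmconfig.hpp`; helper names are illustrative:

```cpp
#include <iostream>
#include <string>
#include <utility>
#include <vector>

using PromptItem = std::pair<std::string, std::string>;   // (role, content)

// Replace the first "%s" placeholder with the content; fall back to "role\ncontent\n".
std::string applyOne(const PromptItem& item, std::string tmpl) {
    size_t pos = tmpl.find("%s");
    if (pos == std::string::npos) return item.first + "\n" + item.second + "\n";
    return tmpl.replace(pos, 2, item.second);
}

int main() {
    const std::string system_tmpl = "<|im_start|>system\n%s<|im_end|>\n";
    const std::string user_tmpl   = "<|im_start|>user\n%s<|im_end|>\n";
    const std::string assistant_prefix = "<|im_start|>assistant\n";
    std::vector<PromptItem> inputs = {{"system", "You are a helpful assistant!"},
                                      {"user", "hello"}};
    std::string prompt;
    for (const auto& item : inputs)
        prompt += applyOne(item, item.first == "system" ? system_tmpl : user_tmpl);
    prompt += assistant_prefix;                    // generation starts after this prefix
    std::cout << prompt;
}
```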
+ } + return input_str; +} + +std::string BaseChatPromptLib::applyTemplate(std::string user_content) { + std::vector prompts; + prompts.push_back(std::make_pair("system", mDefaultSystemPrompt)); + prompts.push_back(std::make_pair("user", user_content)); + return applyTemplates(prompts) + mAssistantPrefix; +} + +std::string BaseChatPromptLib::getAssistantSuffix() const { + return mAssistantSuffix; +} + +} +} \ No newline at end of file diff --git a/transformers/llm/engine/src/prompt.hpp b/transformers/llm/engine/src/prompt.hpp new file mode 100644 index 000000000..3fcf42509 --- /dev/null +++ b/transformers/llm/engine/src/prompt.hpp @@ -0,0 +1,61 @@ + + +#ifndef PROMPT_Hpp +#define PROMPT_Hpp + +#include "llm/llm.hpp" +#include "llmconfig.hpp" +#define MNN_OPEN_TIME_TRACE +#include +#include +#include +#include +#include + + +namespace MNN { +namespace Transformer { + +/* PromptLib: history organization + input organization */ +class MNN_PUBLIC PromptLib { +protected: + Llm* mLlm; +public: + static PromptLib* createPromptLib(Llm* llm, const std::string& config_path); + static PromptLib* createPromptLib(Llm* llm, std::shared_ptr config); + virtual std::string applyTemplate(std::string user_content) = 0; + virtual std::string getAssistantSuffix() const = 0; + virtual void appendSystemPrompt(const std::string sys_prompt) = 0; + virtual void appendSystemPrompt() = 0; + virtual void appendUserPrompt(const std::string use_prompt) = 0; + virtual void appendLLMOutput(std::string out_str) = 0; + virtual std::string getLLMInput() = 0; + virtual void reset(Llm* llm) { mLlm = llm; } +}; + +class MNN_PUBLIC BaseChatPromptLib : public PromptLib { +protected: + bool mReuseKV; + std::string mDefaultSystemPrompt; + std::string mSystemTemplate; + std::string mUserTemplate; + std::string mAssistantPrefix; + std::string mAssistantSuffix; + std::string applyTemplate(PromptItem item, std::string prompt_template, std::string placeholder = "%s"); + std::string applyTemplates(std::vector inputs); +public: + BaseChatPromptLib(Llm* llm, std::shared_ptr config); + virtual std::string applyTemplate(std::string user_content) override; + virtual std::string getAssistantSuffix() const override; + virtual void appendSystemPrompt(const std::string sys_prompt) override; + virtual void appendSystemPrompt() override; + virtual void appendUserPrompt(const std::string user_prompt) override; + virtual void appendLLMOutput(std::string out_str) override; + virtual std::string getLLMInput() override; +}; + +} +} + + +#endif \ No newline at end of file diff --git a/transformers/llm/engine/src/sampler.cpp b/transformers/llm/engine/src/sampler.cpp new file mode 100644 index 000000000..87a9c28b9 --- /dev/null +++ b/transformers/llm/engine/src/sampler.cpp @@ -0,0 +1,551 @@ +#include +#include +#include +#include +#include + +#include +#include + +#include "llm/llm.hpp" +#include "evaluation/dataset.hpp" +#include "sampler.hpp" +#include "perplexity.hpp" +#include "llmconfig.hpp" + +namespace MNN{ +namespace Transformer{ + +MNN::Express::VARP _TempratureSoftmax(MNN::Express::VARP logits, float temperature, int axis) { + return MNN::Express::_Softmax(logits * MNN::Express::_Scalar(1.0f / temperature), axis); +} + +/* ----------Sampler's members---------- */ +int Sampler::select(struct SubsetLogits& subset, int id) { + if (!(subset.is_subset)) return id; + return subset.index[id]; +} + +int Sampler::randomSelect(float* probs, size_t size) { + std::random_device rd; + std::mt19937 generator(rd()); + std::uniform_real_distribution 
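`randomSelect` draws a uniform variate and walks the cumulative probability mass; `std::discrete_distribution` implements the same selection. A self-contained sketch (not the MNN code path):

```cpp
#include <iostream>
#include <random>
#include <vector>

// Draw an index i with probability probs[i]; equivalent to the cumulative-sum walk.
int sampleIndex(const std::vector<float>& probs, std::mt19937& rng) {
    std::discrete_distribution<int> dist(probs.begin(), probs.end());
    return dist(rng);
}

int main() {
    std::random_device rd;
    std::mt19937 rng(rd());
    std::vector<float> probs = {0.1f, 0.2f, 0.7f};
    int counts[3] = {0, 0, 0};
    for (int i = 0; i < 10000; ++i) counts[sampleIndex(probs, rng)]++;
    std::cout << counts[0] << " " << counts[1] << " " << counts[2] << std::endl;  // roughly 1000 2000 7000
}
```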
distribution(0.0, 1.0); + float target = distribution(generator); + float cumulative = 0.0; + for (int i = 0; i < size; i++) { + cumulative += probs[i]; + if (target < cumulative) { + return i; + } + } + return size - 1; +} + +int Sampler::randomSelect(MNN::Express::VARP probs) { + return randomSelect((float*)(probs->readMap()), probs->getInfo()->size); +} + +int Sampler::reSoftmaxSelect(struct SubsetLogits subset, float temperature) { + int token_index_id = randomSelect(_TempratureSoftmax(subset.logits, temperature)); + return ((subset.is_subset) ? subset.index[token_index_id] : token_index_id); +} + +SubsetLogits Sampler::createSubsetLogits(MNN::Express::VARP logits) { + struct SubsetLogits subset; + subset.logits = logits; + subset.is_subset = false; + return subset; +} + +SubsetLogits Sampler::createSubsetLogits(MNN::Express::VARP logits, const std::vector& index) { + struct SubsetLogits subset; + subset.logits = logits; + subset.index = index; + subset.is_subset = true; + return subset; +} + +SubsetLogits Sampler::createSubsetLogits(int size) { + struct SubsetLogits subset; + subset.logits = MNN::Express::_Input({size}, MNN::Express::NHWC); + subset.index.resize(size); + subset.is_subset = true; + return subset; +} + +SubsetLogits Sampler::createSubsetLogits(const std::vector& scores, const std::vector& index) { + int size = (int)(index.size()); + struct SubsetLogits subset; + subset.logits = MNN::Express::_Input({size}, MNN::Express::NHWC); + auto pointer = (float*)(subset.logits->writeMap()); + for (int i = 0; i < size; ++i) { + pointer[i] = scores[i]; + } + subset.index = index; + subset.is_subset = true; + return subset; +} + +void Sampler::transformIndex(struct SubsetLogits& superset, struct SubsetLogits& subset) { + if (!(superset.is_subset)) return; + for (auto& id : subset.index) { + id = superset.index[id]; + } +} + +Sampler* Sampler::createSampler(Llm* llm, const std::string& config_path) { + return createSampler(llm, std::shared_ptr(new LlmConfig(config_path))); +} + +Sampler* Sampler::createSampler(Llm* llm, std::shared_ptr config) { + std::string sampler_type = config->sampler_type(); + if (sampler_type == "greedy" + || sampler_type == "temperature" + || sampler_type == "penalty" + || sampler_type == "topK" + || sampler_type == "topP" + || sampler_type == "minP" + || sampler_type == "tfs" + || sampler_type == "typical" + || sampler_type == "mixed" + ) { + return new LocalSampler(llm, config); + } else if (config->app_type() == "perplexity") { + std::string ppl_type = getPPLType(config->dataset()); + if (ppl_type == "text") { return new TextPPLMeasurer(llm, config); } + else if (ppl_type == "chat") { return new ChatPPLMeasurer(llm, config); } + } else { + std::cout << "Designated Sampler Not Supported yet!"; + exit(1); + } + return nullptr; +} + + +/* ----------LocalSamplerConfig---------- */ +void LocalSampler::LocalSamplerConfig::configSampler( std::string sampler_type, std::shared_ptr llmConfig) { + if (sampler_type == "greedy"){ + this->configGreedy(llmConfig); + } else if (sampler_type == "temperature"){ + this->configTemperature(llmConfig); + } else if (sampler_type == "topK"){ + this->configTopK(llmConfig); + } else if (sampler_type == "topP"){ + this->configTopP(llmConfig); + } else if (sampler_type == "minP"){ + this->configMinP(llmConfig); + } else if (sampler_type == "tfs"){ + this->configTFS(llmConfig); + } else if (sampler_type == "typical"){ + this->configTypical(llmConfig); + } else if (sampler_type == "penalty"){ + this->configPenalty(llmConfig); + } else 
if (sampler_type == "mixed"){ + this->configMixed(llmConfig); + } +} +void LocalSampler::LocalSamplerConfig::configGreedy(std::shared_ptr llmConfig) { + select_type = "greedy"; +} +void LocalSampler::LocalSamplerConfig::configTemperature(std::shared_ptr llmConfig) { + temperature = llmConfig->temperature(); + select_type = "temperature"; +} +void LocalSampler::LocalSamplerConfig::configTopK(std::shared_ptr llmConfig) { + topK = llmConfig->topK(); + select_type = "temperature"; +} +void LocalSampler::LocalSamplerConfig::configTopP(std::shared_ptr llmConfig) { + topP = llmConfig->topP(); + temperature = llmConfig->temperature(); + select_type = "temperature"; +} +void LocalSampler::LocalSamplerConfig::configMinP(std::shared_ptr llmConfig) { + minP = llmConfig->minP(); + temperature = llmConfig->temperature(); + select_type = "temperature"; +} +void LocalSampler::LocalSamplerConfig::configTFS(std::shared_ptr llmConfig) { + tfsZ = llmConfig->tfsZ(); + temperature = llmConfig->temperature(); + select_type = "temperature"; +} +void LocalSampler::LocalSamplerConfig::configTypical(std::shared_ptr llmConfig) { + typical = llmConfig->typical(); + temperature = llmConfig->temperature(); + select_type = "temperature"; +} +void LocalSampler::LocalSamplerConfig::configPenalty(std::shared_ptr llmConfig) { + penaltyConfig.penalty = llmConfig->penalty(); + penaltyConfig.ngram = llmConfig->ngram(); + penaltyConfig.ngram_factor = llmConfig->ngram_factor(); + penaltyConfig.sampler = llmConfig->penalty_sampler(); + select_type = penaltyConfig.sampler; +} +void LocalSampler::LocalSamplerConfig::configMixed(std::shared_ptr llmConfig) { + mixedSamplers = llmConfig->mixed_samplers(); + std::cout << "Mixed Sampler Sequence: " << std::flush; + for (auto samplerName : mixedSamplers) { + this->configSampler(samplerName, llmConfig); + std::cout << samplerName << " " << std::flush; + } + std::cout << std::endl; + // set select type + // the final sampler select the token + if (mixedSamplers.back() == "greedy") select_type = "greedy"; + else if(mixedSamplers.back()=="temperature") select_type = "temperature"; + else select_type = "temperature"; // By default temperature is used. +} + + +/* ----------LocalSampler's members---------- */ +LocalSampler::LocalSamplerConfig LocalSampler::getSamplerConfig(std::shared_ptr llmConfig) { + LocalSampler::LocalSamplerConfig samplerConfig; + samplerConfig.max_all_tokens = llmConfig->max_all_tokens(); + samplerConfig.max_new_tokens = llmConfig->max_new_tokens(); + samplerConfig.type = llmConfig->sampler_type(); + std::string sampler_type = samplerConfig.type; + std::cout << "Sampler: " << sampler_type << std::endl; + samplerConfig.configSampler(sampler_type, llmConfig); + return samplerConfig; +} + +LocalSampler::LocalSampler(Llm* llm, std::shared_ptr config) { + // initialize model and candidates + mLlm = llm; + // initialize config + mConfig = getSamplerConfig(config); +} + +int LocalSampler::argmaxSelect(struct SubsetLogits superset) { + auto scores = (float*)(superset.logits->readMap()); + auto size = superset.logits->getInfo()->size; + float max_score = scores[0]; + int token_id = 0; + for (int i = 0; i < size; i++) { + float score = scores[i]; + if (score > max_score) { + max_score = score; + token_id = i; + } + } + return select(superset, token_id); +} + +struct SubsetLogits LocalSampler::topK(struct SubsetLogits superset) { + int K = mConfig.topK; + auto scores = (float*)(superset.logits->readMap()); + auto size = superset.logits->getInfo()->size; + // 1. 
time complexity: O(nlogk) + std::priority_queue, IndexScoreCmpGreater> heap; + for (int i = 0; i < size; i++) { + IndexScore m; + m.index = i; + m.score = scores[i]; + if (heap.size() < K) { + heap.push(m); + } + else { + if (heap.top().score < m.score) { + heap.pop(); + heap.push(m); + } + } + } + // 2. store top K results + auto subset = createSubsetLogits(K); + float* topKscores = (float*)(subset.logits->writeMap()); + for (int i = 0; i < K; i++) { + subset.index[K-i-1] = heap.top().index; + topKscores[K-i-1] = heap.top().score; + heap.pop(); + } + transformIndex(superset, subset); + return subset; +} + +int LocalSampler::packSoftmax(MNN::Express::VARP logits, std::vector& index_scores, float temperature) { + auto prob_varp = _TempratureSoftmax(logits, temperature); + auto probs = (float*)(prob_varp->readMap()); + auto size = prob_varp->getInfo()->size; + index_scores.resize(size); + for (int i = 0; i < size; i++) { + IndexScore m; + m.index = i; + m.score = probs[i]; + index_scores[i] = m; + } + return size; +} + +struct SubsetLogits LocalSampler::topP(struct SubsetLogits superset) { + float p = mConfig.topP, temperature = mConfig.temperature; + std::vector index_scores; + int size = packSoftmax(superset.logits, index_scores, temperature); + // 1. make max heap + std::make_heap(index_scores.begin(), index_scores.end(), IndexScoreCmpLess()); + // 2. top p algorithm + auto scores = (float*)(superset.logits->readMap()); + std::vector index; + std::vector subset_logits; + float cumulative = 0.0f; + while (cumulative < p && !index_scores.empty()) { + std::pop_heap(index_scores.begin(), index_scores.end(), IndexScoreCmpLess()); + IndexScore m = index_scores.back(); + index_scores.pop_back(); + index.push_back(m.index); + subset_logits.push_back(scores[m.index]); + cumulative += m.score; + } + auto subset = createSubsetLogits(subset_logits, index); + transformIndex(superset, subset); + return subset; +} + +struct SubsetLogits LocalSampler::minP(struct SubsetLogits superset) { + float p = mConfig.minP, temperature = mConfig.temperature; + std::vector index_scores; + int size = packSoftmax(superset.logits, index_scores, temperature); + // 1. make max heap + std::make_heap(index_scores.begin(), index_scores.end(), IndexScoreCmpLess()); + // 2. min p algorithm + auto scores = (float*)(superset.logits->readMap()); + std::vector index; + std::vector subset_logits; + for (int i = 0; i < size; ++i) { + std::pop_heap(index_scores.begin(), index_scores.end(), IndexScoreCmpLess()); + IndexScore m = index_scores.back(); + if (m.score < p && !index.empty()) break; + index_scores.pop_back(); + index.push_back(m.index); + subset_logits.push_back(scores[m.index]); + } + auto subset = createSubsetLogits(subset_logits, index); + transformIndex(superset, subset); + return subset; +} + +struct SubsetLogits LocalSampler::tfs(struct SubsetLogits superset) { + float z = mConfig.tfsZ, temperature = mConfig.temperature; + // tfs algorithm + // 1. softmax + std::vector index_scores; + int size = packSoftmax(superset.logits, index_scores, temperature); + // 2. sort + std::sort(index_scores.begin(), index_scores.end(), IndexScoreCmpGreater()); + auto scores = (float*)(superset.logits->readMap()); + // 3. 
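`topP` pops the highest-probability candidates until their cumulative mass reaches p. The same nucleus-sampling candidate selection with a sorted index vector, including the temperature softmax:

```cpp
#include <algorithm>
#include <cmath>
#include <iostream>
#include <numeric>
#include <vector>

// Softmax with temperature, then keep the smallest prefix of highest-probability
// indices whose cumulative probability reaches p (the nucleus).
std::vector<int> topPCandidates(const std::vector<float>& logits, float p, float temperature) {
    std::vector<double> probs(logits.size());
    double maxLogit = *std::max_element(logits.begin(), logits.end());
    double sum = 0.0;
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp((logits[i] - maxLogit) / temperature);   // numerically stable softmax
        sum += probs[i];
    }
    for (auto& v : probs) v /= sum;
    std::vector<int> order(logits.size());
    std::iota(order.begin(), order.end(), 0);
    std::sort(order.begin(), order.end(), [&](int a, int b) { return probs[a] > probs[b]; });
    std::vector<int> kept;
    double cumulative = 0.0;
    for (int idx : order) {
        kept.push_back(idx);
        cumulative += probs[idx];
        if (cumulative >= p) break;                  // always keeps at least one token
    }
    return kept;
}

int main() {
    for (int idx : topPCandidates({2.0f, 1.0f, 0.5f, -1.0f}, 0.9f, 1.0f))
        std::cout << idx << " ";
    std::cout << std::endl;
}
```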
calculate derivatives + std::vector derivatives(size - 2, 0.0f); + float first = index_scores[0].score - index_scores[1].score; + float second = index_scores[1].score - index_scores[2].score; + for (int i = 0; i < size - 2; ++i) { + second = index_scores[i+1].score - index_scores[i+2].score; + derivatives[i] = std::fabs(first - second); + first = second; + } + // 4. normalize derivatives + float derivatives_sum = 0.0; + for (int i = 0; i < size - 2; ++i) derivatives_sum += derivatives[i]; + float derivatives_sum_rec = 1.0f / derivatives_sum; + for (int i = 0; i < size - 2; ++i) derivatives[i] *= derivatives_sum_rec; + // 5. cumulate, discard last 2 for sure. + float cumulative = 0.0; + std::vector index; + std::vector subset_logits; + for (int i = 0; i < size - 2; ++i) { + IndexScore m = index_scores[i]; + cumulative += derivatives[i]; + if (cumulative >= z && !index.empty()) break; + index.push_back(m.index); + subset_logits.push_back(scores[m.index]); + } + auto subset = createSubsetLogits(subset_logits, index); + transformIndex(superset, subset); + return subset; +} + +struct SubsetLogits LocalSampler::typical(struct SubsetLogits superset) { + float p = mConfig.typical, temperature = mConfig.temperature; + auto prob_varp = _TempratureSoftmax(superset.logits, temperature); + auto probs = (float*)(prob_varp->readMap()); + auto size = prob_varp->getInfo()->size; + std::vector index_scores; + index_scores.resize(size); + // 1. calcaluate dist + float entropy = 0.0f; + for (int i = 0; i < size; i++) entropy -= probs[i] * std::log(probs[i]); + for (int i = 0; i < size; i++) { + IndexScore m; + m.index = i; + m.score = std::fabs(entropy + std::log(probs[i])); + index_scores[i] = m; + } + // 2. make min heap for dist + std::make_heap(index_scores.begin(), index_scores.end(), IndexScoreCmpGreater()); + // 3. typical p algorithm + auto scores = (float*)(superset.logits->readMap()); + float cumulative = 0.0f; + std::vector index; + std::vector subset_logits; + for (int i = 0; i < size; ++i) { + std::pop_heap(index_scores.begin(), index_scores.end(), IndexScoreCmpGreater()); + IndexScore m = index_scores.back(); + cumulative += probs[m.index]; + if (cumulative >= p && !index.empty()) break; + index_scores.pop_back(); + index.push_back(m.index); + subset_logits.push_back(scores[m.index]); + } + auto subset = createSubsetLogits(subset_logits, index); + transformIndex(superset, subset); + return subset; +} + +// presence penalty +// no frequency penalty now! +struct SubsetLogits LocalSampler::penalty(struct SubsetLogits subset) { + float penalty = mConfig.penaltyConfig.penalty; + int ngram = mConfig.penaltyConfig.ngram; + float ngram_factor = mConfig.penaltyConfig.ngram_factor; + float temperature = mConfig.temperature; + bool penalizeNgram = (ngram_factor > 1.0f); + if (penalty <= 1.0f) return subset; // no penalty! + penalty = std::min(penalty, mConfig.penaltyConfig.max_penalty); + // initialization + std::vector& prev = mLlm->mLlmSessionInfos[0].tokens; + std::unordered_map penalty_map; + // 1. local ngram info, reversed order + std::vector ngram_info(ngram-1); + if (penalizeNgram) { + for (int n = 0; n < ngram_info.size(); ++n) { + ngram_info[n] = prev[prev.size()-1-n]; + } + } + // 2. 
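`typical()` ranks tokens by how far their surprisal -log p sits from the distribution's entropy and keeps the closest ones until the kept probability mass reaches p. A standalone sketch over a plain probability vector:

```cpp
#include <algorithm>
#include <cmath>
#include <iostream>
#include <numeric>
#include <vector>

// Locally typical sampling: keep tokens whose |(-log p) - H| is smallest,
// until the kept probability mass reaches p.
std::vector<int> typicalCandidates(const std::vector<double>& probs, double p) {
    double entropy = 0.0;
    for (double q : probs) if (q > 0) entropy -= q * std::log(q);
    std::vector<int> order(probs.size());
    std::iota(order.begin(), order.end(), 0);
    std::sort(order.begin(), order.end(), [&](int a, int b) {
        return std::fabs(entropy + std::log(probs[a])) < std::fabs(entropy + std::log(probs[b]));
    });
    std::vector<int> kept;
    double cumulative = 0.0;
    for (int idx : order) {
        kept.push_back(idx);
        cumulative += probs[idx];
        if (cumulative >= p) break;
    }
    return kept;
}

int main() {
    for (int idx : typicalCandidates({0.5, 0.3, 0.15, 0.05}, 0.95))
        std::cout << idx << " ";
    std::cout << std::endl;
}
```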
generate penalty map + for (int i = 0; i < prev.size(); ++i) { + if (penalty_map.count(prev[i]) == 0) penalty_map[prev[i]] = penalty; + if (penalizeNgram) { + float ngram_penalty = penalty; + for (int j = i-1; i-j < ngram && j>=0; --j) { + int idx = i-j-1; + if (prev[j] != ngram_info[idx]) break; + ngram_penalty *= ngram_factor; + // no repeat larger than ngram! + if (idx == ngram_info.size()-1) ngram_penalty = mConfig.penaltyConfig.max_penalty; + } + if (ngram_penalty > penalty_map[prev[i]]) penalty_map[prev[i]] = ngram_penalty; + } + } + // 3. penalize logits according to penalty_map + auto scoresMap = (float*)(subset.logits->writeMap()); + for (auto it = penalty_map.begin(); it != penalty_map.end(); ++it) { + scoresMap[it->first] = (scoresMap[it->first] >= 0.0f) ? (scoresMap[it->first]/it->second) : (scoresMap[it->first]*it->second); + } + return subset; +} + +struct SubsetLogits LocalSampler::mixed(struct SubsetLogits subset) { + for (auto sampler : mConfig.mixedSamplers) { + subset = subsetSampler(sampler, subset); + } + return subset; +} + +struct SubsetLogits LocalSampler::subsetSampler(std::string sampler_type, struct SubsetLogits subset) { + if (sampler_type == "penalty") subset = penalty(subset); + if (sampler_type == "topK") subset = topK(subset); + if (sampler_type == "topP") subset = topP(subset); + if (sampler_type == "minP") subset = minP(subset); + if (sampler_type == "tfs") subset = tfs(subset); + if (sampler_type == "typical") subset = typical(subset); + if (sampler_type == "mixed") subset = mixed(subset); + // if greedy and temperate, just let the Selector handle it. + return subset; +} + +int LocalSampler::handleSelect(struct SubsetLogits subset) { + if (mConfig.select_type == "greedy") return argmaxSelect(subset); + else if(mConfig.select_type =="temperature") return reSoftmaxSelect(subset, mConfig.temperature); + return 0; +} + +int LocalSampler::algorithm(MNN::Express::VARP logits) { + struct SubsetLogits subset = createSubsetLogits(logits); + // process subsetSampler + subset = subsetSampler(mConfig.type, subset); + // select token from the subset + int res = handleSelect(subset); + // return + Express::ExecutorScope::Current()->gc(Express::Executor::FULL); + return res; +} + +std::string LocalSampler::handleToken(int token, std::ostream* os, const char* end_with) { + // CommonPrefix and Candidates managements + mLlm->mLlmSessionInfos[0].tokens.push_back(token); + std::string output_str = mLlm->tokenizer_decode(mLlm->mLlmSessionInfos[0].tokens.back()); + // print + *os << output_str << std::flush; + return output_str; +} + +std::string LocalSampler::sample(const std::vector& input_ids, std::ostream* os, const char* end_with, struct TimePerformance* time_perf) { + // initialization for time performance + PrefillTimePerformance prefill_time; + prefill_time.prefill_prev_token_ = mLlm->mLlmSessionInfos[0].tokens.size(); + prefill_time.prefill_token_ = input_ids.size(); + appendNewPromptRecord(time_perf, input_ids.size(), mLlm->reuse_kv()); + // initialization + std::string output_str; + mLlm->mLlmSessionInfos[0].tokens.insert(mLlm->mLlmSessionInfos[0].tokens.end(), input_ids.begin(), input_ids.end()); + // all_seq_len_ in sampler functions as kv_seq_len_, prev_seq_len_ = all_seq_len_ - seq_len + mLlm->mLlmSessionInfos[0].all_seq_len_ = mLlm->mLlmSessionInfos[0].tokens.size() - input_ids.size(); + mLlm->mLlmSessionInfos[0].gen_seq_len_ = 0; + // prefill + auto st = std::chrono::system_clock::now(); + auto logits = mLlm->forward(input_ids, true); + if (nullptr == 
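`penalty()` pushes previously generated tokens toward lower probability by dividing positive logits (and multiplying negative ones) by the penalty, with an extra multiplicative factor for repeated n-grams. A minimal sketch of the base presence-penalty step only, without the n-gram escalation:

```cpp
#include <iostream>
#include <unordered_set>
#include <vector>

// Presence penalty: tokens already generated become less likely. Positive logits are
// divided by the penalty, negative logits are multiplied, so both move toward "less probable".
void applyPresencePenalty(std::vector<float>& logits, const std::vector<int>& history, float penalty) {
    if (penalty <= 1.0f) return;                     // no-op, matching the early return above
    std::unordered_set<int> seen(history.begin(), history.end());
    for (int token : seen) {
        if (token < 0 || token >= (int)logits.size()) continue;
        logits[token] = logits[token] >= 0.f ? logits[token] / penalty : logits[token] * penalty;
    }
}

int main() {
    std::vector<float> logits = {2.0f, -1.0f, 0.5f};
    applyPresencePenalty(logits, /*history=*/{0, 1, 1}, /*penalty=*/1.5f);
    std::cout << logits[0] << " " << logits[1] << " " << logits[2] << std::endl;  // 1.33333 -1.5 0.5
}
```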
logits.get()) { + return ""; + } + int token = algorithm(logits); + // record time + auto et = std::chrono::system_clock::now(); + prefill_time.prefill_us_ = std::chrono::duration_cast(et - st).count(); + time_perf->prefill_record_.push_back(prefill_time); + // handle the new token + output_str += handleToken(token, os, end_with); + // decode + while (mLlm->mLlmSessionInfos[0].gen_seq_len_ < mConfig.max_new_tokens + && mLlm->mLlmSessionInfos[0].all_seq_len_ < mConfig.max_all_tokens) { + DecodeTimePerformance decode_time; + decode_time.decode_prev_token_ = mLlm->mLlmSessionInfos[0].tokens.size(); + st = std::chrono::system_clock::now(); + // next token + logits = mLlm->forward({mLlm->mLlmSessionInfos[0].tokens.back()}, false); + if (nullptr == logits.get()) { + return output_str; + } + if (logits->getInfo()->size == 0) { + return output_str; + } + token = algorithm(logits); + et = std::chrono::system_clock::now(); + decode_time.decode_us_ = std::chrono::duration_cast(et - st).count(); + time_perf->decode_record_.push_back(decode_time); + if (mLlm->is_stop(token)) { + *os << end_with << std::flush; + break; + } else { + output_str += handleToken(token); + } + } + if (mLlm->mLlmSessionInfos[0].all_seq_len_ == mConfig.max_all_tokens) { + std::cout << "sequence length reaches maximum allowed." << std::endl; + } + // return output_str + return output_str; +} + + +} // Transformer +} // MNN \ No newline at end of file diff --git a/transformers/llm/engine/src/sampler.hpp b/transformers/llm/engine/src/sampler.hpp new file mode 100644 index 000000000..93266d580 --- /dev/null +++ b/transformers/llm/engine/src/sampler.hpp @@ -0,0 +1,141 @@ +#ifndef SAMPLER_hpp +#define SAMPLER_hpp + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "evaluation/evaluation.hpp" +#include "llmconfig.hpp" +#include "llm/llm.hpp" + + +namespace MNN { +namespace Transformer { + +MNN_PUBLIC MNN::Express::VARP _TempratureSoftmax(MNN::Express::VARP logits, float temperature, int axis = -1); + +class Llm; + +// a index and its corresponding score +struct IndexScore { + int index; + float score; +}; +struct IndexScoreCmpLess{ + bool operator()(IndexScore a, IndexScore b) { + return a.score < b.score; + } +}; +struct IndexScoreCmpGreater{ + bool operator()(IndexScore a, IndexScore b) { + return a.score > b.score; + } +}; +// a series of index and their corresponding logits +struct SubsetLogits{ + std::vector index; + MNN::Express::VARP logits; + bool is_subset; +}; + +class MNN_PUBLIC Sampler { +public: + class LlmSamplerConfig { + public: + int max_new_tokens = 512; + int max_all_tokens = 2048; + }; +protected: + Llm* mLlm; + int select(struct SubsetLogits& subset, int id); + int randomSelect(float* probs, size_t size); + int randomSelect(MNN::Express::VARP probs); + int reSoftmaxSelect(struct SubsetLogits subset, float temperature=1.0); + SubsetLogits createSubsetLogits(MNN::Express::VARP logits); + SubsetLogits createSubsetLogits(MNN::Express::VARP logits, const std::vector& index); + SubsetLogits createSubsetLogits(int size); + SubsetLogits createSubsetLogits(const std::vector& scores, const std::vector& index); + void transformIndex(struct SubsetLogits& superset, struct SubsetLogits& subset); +public: + static Sampler* createSampler(Llm* llm, const std::string& config_path); + static Sampler* createSampler(Llm* llm, std::shared_ptr config); + virtual std::string sample(const std::vector& input_ids, std::ostream* os = 
&std::cout, const char* end_with = nullptr, struct TimePerformance* time_perf = nullptr) = 0; + virtual std::vector perplexity(std::string prompt_file, std::ostream* perfOS) { return std::vector(); } + // prepare for another round of sampling + // in the future, only reset its own. + virtual void reset(Llm* llm) { mLlm = llm; } +}; + + +class MNN_PUBLIC LocalSampler: public Sampler { +public: + class LocalSamplerConfig : public LlmSamplerConfig { + public: + struct SamplerPenaltyConfig { + float penalty = 1.05; + int ngram = 8; + float ngram_factor = 1.02; // panalize repeated ngram with a multiplied ngram_factor. + float max_penalty = 10.; + std::string sampler = "temperature"; // "greedy", "temperature". + }; + std::string type = "temperature"; + std::string select_type = "temperature"; + float temperature = 0.8; + int topK = 40; + float topP = 0.9; + float minP = 0.05; + float tfsZ = 1.0; + float typical = 0.95; + struct SamplerPenaltyConfig penaltyConfig; + std::vector mixedSamplers= {"topK", "tfs", "typical", "topP", "min_p", "temperature"}; + void configSampler(std::string sampler_type, std::shared_ptr llmConfig); + void configGreedy(std::shared_ptr llmConfig); + void configTemperature(std::shared_ptr llmConfig); + void configTopK(std::shared_ptr llmConfig); + void configTopP(std::shared_ptr llmConfig); + void configMinP(std::shared_ptr llmConfig); + void configTFS(std::shared_ptr llmConfig); + void configTypical(std::shared_ptr llmConfig); + void configPenalty(std::shared_ptr llmConfig); + void configMixed(std::shared_ptr llmConfig); + }; +protected: + LocalSamplerConfig mConfig; + LocalSamplerConfig getSamplerConfig(std::shared_ptr llmConfig); + int argmaxSelect(struct SubsetLogits superset); + int packSoftmax(MNN::Express::VARP logits, std::vector& index_scores, float temperature = 1.0); + struct SubsetLogits penalty(struct SubsetLogits superset); + struct SubsetLogits topK(struct SubsetLogits superset); + struct SubsetLogits topP(struct SubsetLogits superset); + struct SubsetLogits minP(struct SubsetLogits superset); + struct SubsetLogits tfs(struct SubsetLogits superset); + struct SubsetLogits typical(struct SubsetLogits superset); + struct SubsetLogits mixed(struct SubsetLogits subset); + struct SubsetLogits subsetSampler(std::string sampler_type, struct SubsetLogits subset); + int handleSelect(struct SubsetLogits subset); + std::string handleToken(int token, std::ostream* os = &std::cout, const char* end_with = nullptr); +public: + LocalSampler(Llm* llm, std::shared_ptr config); + int algorithm(MNN::Express::VARP logits); + virtual std::string sample(const std::vector& input_ids, std::ostream* os = &std::cout, const char* end_with = nullptr, struct TimePerformance* time_perf = nullptr) override; +}; + + +} // Transformer +} // MNN + + +#endif // SAMPLER_hpp \ No newline at end of file diff --git a/transformers/llm/engine/src/tokenizer.cpp b/transformers/llm/engine/src/tokenizer.cpp index 87f02c868..913afacfc 100644 --- a/transformers/llm/engine/src/tokenizer.cpp +++ b/transformers/llm/engine/src/tokenizer.cpp @@ -475,7 +475,7 @@ void Tiktoken::encode(const std::string& str, std::vector& ids) { } else { // If no matching symbol is found, this typically means an error in the encoding // or the input text contains characters that the encoder doesn't know how to handle - std::cerr << "Error: No encoding found for the sequence starting at position " << i << std::endl; + std::cerr << "Error: No encoding found for the sequence starting at position " << i << " , symbol: " << str[i-2] 
<< std::endl; return; } } diff --git a/transformers/llm/engine/model/bench.txt b/transformers/llm/engine/test/bench_cn.txt similarity index 100% rename from transformers/llm/engine/model/bench.txt rename to transformers/llm/engine/test/bench_cn.txt diff --git a/transformers/llm/engine/test/bench_en.txt b/transformers/llm/engine/test/bench_en.txt new file mode 100644 index 000000000..6f6ecbe2d --- /dev/null +++ b/transformers/llm/engine/test/bench_en.txt @@ -0,0 +1,3 @@ +calculate 8*12 +translate the following into Chinese:It's a beautiful day to learn something new. +Describe top 5 characters a leader needs, and explain why. \ No newline at end of file diff --git a/transformers/llm/eval/evaluate_perplexity.py b/transformers/llm/eval/evaluate_perplexity.py index 7b467bb58..50b5fe163 100644 --- a/transformers/llm/eval/evaluate_perplexity.py +++ b/transformers/llm/eval/evaluate_perplexity.py @@ -17,7 +17,7 @@ def main(args): dataset_dir = eval_dataset.split("/")[1] dataset = load_dataset(dataset_name, dataset_dir, split="test") - input_ids = model.tokenizer_encode("\n\n".join(dataset["text"])) + input_ids = model.tokenizer_encode("\n\n".join(dataset["text"]), False) stride = 512 context_length = stride + stride // 2 seq_len = len(input_ids) diff --git a/transformers/llm/export/.gitignore b/transformers/llm/export/.gitignore new file mode 100644 index 000000000..ddb5fb2d5 --- /dev/null +++ b/transformers/llm/export/.gitignore @@ -0,0 +1,4 @@ +* +!.gitignore +!llmexport.py +!README.md \ No newline at end of file diff --git a/transformers/llm/export/README.md b/transformers/llm/export/README.md index 136f1329f..371adab67 100644 --- a/transformers/llm/export/README.md +++ b/transformers/llm/export/README.md @@ -25,18 +25,26 @@ pip install . ## 用法 -1. 将需要导出的LLM项目clone到本地,如:chatglm2-6b +1. 下载模型 ```sh -git clone https://huggingface.co/THUDM/chatglm2-6b +git clone https://huggingface.co/Qwen/Qwen2-1.5B-Instruct # 如果huggingface下载慢可以使用modelscope -git clone https://modelscope.cn/ZhipuAI/chatglm2-6b.git +git clone https://modelscope.cn/qwen/Qwen2-1.5B-Instruct.git ``` -2. 导出模型 +2. 测试模型 ```sh -# 将chatglm2-6b导出为onnx模型 -llmexport --path ../chatglm2-6b --export onnx -# 将chatglm2-6b导出为mnn模型, 量化参数为4bit, blokc-wise = 128 -llmexport --path ../chatglm2-6b --export mnn --quant_bit 4 --quant_block 128 +# 测试文本输入 +llmexport --path Qwen2-1.5B-Instruct --test "你好" +# 测试图像文本 +llmexport --path Qwen2-VL-2B-Instruct --test "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg介绍一下图片里的内容" +``` + +3. 
导出模型 +```sh +# 将Qwen2-1.5B-Instruct导出为onnx模型 +llmexport --path Qwen2-1.5B-Instruct --export onnx +# 将Qwen2-1.5B-Instruct导出为mnn模型, 量化参数为4bit, blokc-wise = 128 +llmexport --path Qwen2-1.5B-Instruct --export mnn --quant_bit 4 --quant_block 128 ``` ## 功能 @@ -48,14 +56,6 @@ llmexport --path ../chatglm2-6b --export mnn --quant_bit 4 --quant_block 128 - 使用`--lm_quant_bit`来制定lm_head层权重的量化bit数,不指定则使用`--quant_bit`的量化bit数 - 支持使用自己编译的`MNNConvert`,使用`--mnnconvert` -`--test`测试示例 -```sh -# 测试文本输入 -llmexport --path Qwen2-1.5B-Instruct --test "你好" -# 测试图像文本 -llmexport --path Qwen2-VL-2B-Instruct --test "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg介绍一下图片里的内容" -``` - ## 参数 ``` usage: llmexport.py [-h] --path PATH [--type TYPE] [--lora_path LORA_PATH] [--dst_path DST_PATH] [--test TEST] [--export EXPORT] @@ -90,8 +90,8 @@ options: ## 支持模型 -- llama/llama2/llama3/tinyllama -- qwen/qwen1.5/qwen2/qwen-vl +- llama/llama2/llama3/llama3.2/tinyllama +- qwen/qwen1.5/qwen2/qwen-vl/qwen2-vl/qwen2.5 - baichuan2/phi-2/internlm/yi/deepseek - chatglm/codegeex/chatglm2/chatglm3 - phi-2/gemma-2 diff --git a/transformers/llm/export/llmexport.py b/transformers/llm/export/llmexport.py index edd420203..7a7a72b76 100644 --- a/transformers/llm/export/llmexport.py +++ b/transformers/llm/export/llmexport.py @@ -1428,7 +1428,7 @@ def rebuild(self, json_path): def quant(self, weight, quant_bit, quant_block, symmetric): if torch.cuda.is_available(): weight = weight.cuda() - if torch.mps.is_available(): + if torch.backends.mps.is_available(): weight = weight.to('mps') oc, ic = weight.shape if quant_block == 0: @@ -2259,7 +2259,7 @@ def init_from_args(self, args): self.onnx_path = os.path.join(self.dst_path, 'onnx') self.tokenizer_path = args.tokenizer_path self.lora_path = args.lora_path - self.onnx_slim = args.onnx_slim + self.need_onnx_slim = args.onnx_slim self.ppl = args.ppl self.awq = args.awq self.quant_bit = args.quant_bit @@ -2365,12 +2365,16 @@ def visit_module(module): "position_ids" : { 1: "seq_len" }, "past_key_values" : { 3: "history_len" } } + prompt_template = self.build_prompt_template() self.llm_config = { 'hidden_size' : self.hidden_size, 'layer_nums' : self.num_hidden_layers, 'attention_mask': self.attention_mask_type, 'key_value_shape': self.past_kv_shape[1:], - "prompt_template": self.build_prompt('%s'), + "system_prompt_template": prompt_template['system'].format(query='%s'), + 'user_prompt_template': prompt_template['user'].format(query='%s'), + 'assistant_prefix': prompt_template['assistant_prefix'], + 'assistant_suffix': prompt_template['assistant_suffix'], 'is_visual': False } # load modules @@ -2469,43 +2473,105 @@ def forward(self, return logits, presents # some test functions - def build_prompt(self, query): + def build_prompt_template(self) -> Dict[str, str]: + template = { + 'system': '', + 'user': '', + 'assistant_prefix': '', + 'assistant_suffix': '', + } # just for test - if 'Qwen2' in self.path or 'QwQ' in self.path or 'reader' in self.path: - return f'<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n' - if 'Qwen' in self.path: - return f'\n<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n' + if 'Qwen' in self.path or 'Qwen2' in self.path or 'QwQ' in self.path or 'reader' in self.path: + template['system'] = '<|im_start|>system\n{query}<|im_end|>\n' + template['user'] = '<|im_start|>user\n{query}<|im_end|>\n' + template['assistant_prefix'] = '<|im_start|>assistant\n' + template['assistant_suffix'] = '<|im_end|>\n' + return template if 'Baichuan2' in 
diff --git a/transformers/llm/export/llmexport.py b/transformers/llm/export/llmexport.py
index edd420203..7a7a72b76 100644
--- a/transformers/llm/export/llmexport.py
+++ b/transformers/llm/export/llmexport.py
@@ -1428,7 +1428,7 @@ def rebuild(self, json_path):
     def quant(self, weight, quant_bit, quant_block, symmetric):
         if torch.cuda.is_available():
             weight = weight.cuda()
-        if torch.mps.is_available():
+        if torch.backends.mps.is_available():
             weight = weight.to('mps')
         oc, ic = weight.shape
         if quant_block == 0:
@@ -2259,7 +2259,7 @@ def init_from_args(self, args):
         self.onnx_path = os.path.join(self.dst_path, 'onnx')
         self.tokenizer_path = args.tokenizer_path
         self.lora_path = args.lora_path
-        self.onnx_slim = args.onnx_slim
+        self.need_onnx_slim = args.onnx_slim
         self.ppl = args.ppl
         self.awq = args.awq
         self.quant_bit = args.quant_bit
@@ -2365,12 +2365,16 @@ def visit_module(module):
             "position_ids" : { 1: "seq_len" },
             "past_key_values" : { 3: "history_len" }
         }
+        prompt_template = self.build_prompt_template()
         self.llm_config = {
            'hidden_size' : self.hidden_size,
            'layer_nums' : self.num_hidden_layers,
            'attention_mask': self.attention_mask_type,
            'key_value_shape': self.past_kv_shape[1:],
-           "prompt_template": self.build_prompt('%s'),
+           "system_prompt_template": prompt_template['system'].format(query='%s'),
+           'user_prompt_template': prompt_template['user'].format(query='%s'),
+           'assistant_prefix': prompt_template['assistant_prefix'],
+           'assistant_suffix': prompt_template['assistant_suffix'],
            'is_visual': False
         }
         # load modules
@@ -2469,43 +2473,105 @@ def forward(self,
         return logits, presents
 
     # some test functions
-    def build_prompt(self, query):
+    def build_prompt_template(self) -> Dict[str, str]:
+        template = {
+            'system': '',
+            'user': '',
+            'assistant_prefix': '',
+            'assistant_suffix': '',
+        }
         # just for test
-        if 'Qwen2' in self.path or 'QwQ' in self.path or 'reader' in self.path:
-            return f'<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n'
-        if 'Qwen' in self.path:
-            return f'\n<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n'
+        if 'Qwen' in self.path or 'Qwen2' in self.path or 'QwQ' in self.path or 'reader' in self.path:
+            template['system'] = '<|im_start|>system\n{query}<|im_end|>\n'
+            template['user'] = '<|im_start|>user\n{query}<|im_end|>\n'
+            template['assistant_prefix'] = '<|im_start|>assistant\n'
+            template['assistant_suffix'] = '<|im_end|>\n'
+            return template
         if 'Baichuan2' in self.path:
-            return f'{query}'
+            template['user'] = '{query}'
+            return template
         if 'internlm' in self.path:
-            return f'<|User|>:{query}\n<|Bot|>:'
+            template['user'] = '<|User|>:{query}\n'
+            template['assistant_prefix'] = '<|Bot|>:'
+            template['assistant_suffix'] = '\n'
+            return template
         if 'TinyLlama' in self.path:
-            return f'<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate\n<|user|>\n{query}\n<|assistant|>\n'
+            template['system'] = '<|system|>\n{query}\n'
+            template['user'] = '<|user|>\n{query}\n'
+            template['assistant_prefix'] = '<|assistant|>\n'
+            template['assistant_suffix'] = '\n'
+            return template
         if 'Yi' in self.path:
-            return f'<|im_start|> user\n{query}<|im_end|>\n<|im_start|> assistant\n'
+            template['user'] = '<|im_start|> user\n{query}<|im_end|>\n'
+            template['assistant_prefix'] = '<|im_start|> assistant\n'
+            template['assistant_suffix'] = '<|im_end|>\n'
+            return template
         if 'deepseek' in self.path:
-            return f'<|begin_of_sentence|>User: {query}\n\nAssistant:'
+            template['user'] = '<|begin_of_sentence|>User: {query}\n'
+            template['assistant_prefix'] = '\nAssistant: '
+            template['assistant_suffix'] = '\n<|end_of_sentence|>'
+            return template
         if 'Llama-3.1' in self.path:
-            return f'<|start_header_id|>user<|end_header_id|>\n\n{query}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'
+            template['system'] = '<|start_header_id|>system<|end_header_id|>\n\n{query}<|eot_id|>'
+            template['user'] = '<|start_header_id|>user<|end_header_id|>\n\n{query}<|eot_id|>'
+            template['assistant_prefix'] = '<|start_header_id|>assistant<|end_header_id|>\n\n'
+            template['assistant_suffix'] = '<|eot_id|>'
+            return template
         if 'Llama-3' in self.path:
-            return f'<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{query}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'
+            template['system'] = '<|start_header_id|>system<|end_header_id|>\n\n{query}<|eot_id|>'
+            template['user'] = '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{query}<|eot_id|>'
+            template['assistant_prefix'] = '<|start_header_id|>assistant<|end_header_id|>\n\n'
+            template['assistant_suffix'] = '<|eot_id|>'
+            return template
         if 'Llama-2' in self.path:
-            return f'[INST]{query}[/INST]'
+            template['user'] = '[INST]{query}[/INST]'
+            return template
         if 'chatglm2' in self.path:
-            return f'[Round 1]\n\n问:{query}\n\n答:'
+            template['user'] = '[Round 1]\n\n问:{query}\n\n'
+            template['assistant_prefix'] = '答:'
+            template['assistant_suffix'] = '\n\n'
+            return template
         if 'chatglm3' in self.path or 'glm-4' in self.path:
-            return f'<|user|>\n{query}\n<|assistant|>\n'
+            template['user'] = '<|user|>\n{query}\n'
+            template['assistant_prefix'] = '<|assistant|>\n'
+            template['assistant_suffix'] = '\n'
+            return template
         if 'chatglm' in self.path:
-            return f'{query}[gMASK]'
+            template['user'] = '{query}[gMASK]'
+            return template
         if 'phi-2' in self.path:
-            return f'Instruct: {query}\nOutput:'
+            template['user'] = 'Instruct: {query}\n'
+            template['assistant_prefix'] = 'Output:'
+            template['assistant_suffix'] = '\n'
+            return template
         if 'gemma-2' in self.path:
-            return f'user\n{query}\nmodel\n'
+            template['system'] = 'system\n{query}\n'
+            template['user'] = 'user\n{query}\n'
+            template['assistant_prefix'] = 'model\n'
+            template['assistant_suffix'] = '\n'
+            return template
         if 'OpenELM' in self.path:
-            return f'{query}'
+            template['user'] = '{query}'
+            return template
         if 'SmolLM2' in self.path:
-            return f'<|im_start|>system\nYou are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>\n<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n'
-        return query
+            template['system'] = '<|im_start|>system\n{query}<|im_end|>\n'
+            template['user'] = '<|im_start|>user\n{query}<|im_end|>\n'
+            template['assistant_prefix'] = '<|im_start|>assistant\n'
+            template['assistant_suffix'] = '<|im_end|>\n'
+            return template
+        # not matched
+        return template
+
+    def build_prompt(self, queries, roles):
+        template = self.build_prompt_template()
+        prompt = ""
+        for item in zip(queries, roles):
+            query, role = item
+            if '{query}' in template[role]:
+                prompt += template[role].format(query=query)
+            else:
+                prompt += role + '\n' + query + '\n'
+        return prompt + template['assistant_prefix']
 
     def str_to_ids(self, prompt):
         if self.visual is not None:
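To make the new split template concrete, here is how the four pieces compose into a full chat prompt. The dict is the Qwen-style template returned by `build_prompt_template` above; `render` simply mirrors what `build_prompt` does for a system + user turn (the helper name `render` exists only for this illustration).

```python
# Qwen-style template from build_prompt_template above
template = {
    'system': '<|im_start|>system\n{query}<|im_end|>\n',
    'user': '<|im_start|>user\n{query}<|im_end|>\n',
    'assistant_prefix': '<|im_start|>assistant\n',
    'assistant_suffix': '<|im_end|>\n',
}

def render(queries, roles):
    prompt = ''
    for query, role in zip(queries, roles):
        prompt += template[role].format(query=query)
    return prompt + template['assistant_prefix']

print(render(['You are a helpful assistant!', 'calculate 8*12'], ['system', 'user']))
# <|im_start|>system
# You are a helpful assistant!<|im_end|>
# <|im_start|>user
# calculate 8*12<|im_end|>
# <|im_start|>assistant
```

Keeping `assistant_prefix` and `assistant_suffix` separate lets the runtime append the generation header itself and close each turn before reusing the kv cache.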
@@ -2536,7 +2602,7 @@ def decode_ids(token_ids):
     def response(self, query):
         # self.imitate_quant()
         self.decode_buffer = []
-        prompt = self.build_prompt(query)
+        prompt = self.build_prompt(['You are a helpful assistant!', query], roles=['system', 'user'])
         input_ids = self.str_to_ids(prompt)
         if self.visual is not None:
             cross_attention_states = self.visual.cross_attention_states
@@ -2727,13 +2793,13 @@ def export(self, export_type):
         self.export_embed()
         if self.visual:
             visual_onnx = self.export_visual()
-            #if self.onnx_slim:
+            #if self.need_onnx_slim:
                 #visual_onnx = self.onnx_slim(visual_onnx)
             if export_mnn:
                 MNNConveter(visual_onnx, None, self).export(quant_bit=self.visual.quant_bit)
         # export graph to llm.onnx
         onnx_model = self.export_onnx()
-        if self.onnx_slim:
+        if self.need_onnx_slim:
             self.onnx_slim(onnx_model)
         if export_mnn:
             # convert onnx to mnn and quant weight
@@ -3033,7 +3099,7 @@ def export(self, export_type):
         self.export_config(export_mnn)
         self.export_embed()
         onnx_model = self.export_onnx()
-        if self.onnx_slim:
+        if self.need_onnx_slim:
             self.onnx_slim(onnx_model)
         if export_mnn:
             MNNConveter(onnx_model, None, self).export()
diff --git a/transformers/llm/export/requirements.txt b/transformers/llm/export/requirements.txt
index d0b77b53e..6d095790f 100644
--- a/transformers/llm/export/requirements.txt
+++ b/transformers/llm/export/requirements.txt
@@ -1,5 +1,5 @@
 datasets
-MNN
+MNN>=3.0.0
 onnx
 onnxslim
 onnxruntime
@@ -9,6 +9,6 @@ Requests
 sentencepiece
 torch
 tqdm
-transformers
+transformers>=4.45
 yaspin
 numpy
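One remark on the `need_onnx_slim` rename above: the reading suggested by the diff is that the old boolean attribute `self.onnx_slim` shadowed the `onnx_slim()` method, so `self.onnx_slim(onnx_model)` would have tried to call a bool. The minimal sketch below (the class name `Exporter` and its methods are made up for this note) illustrates that pattern and why keeping the flag under a separate name avoids it.

```python
class Exporter:
    def __init__(self, slim_flag: bool):
        # Before the rename the flag and the method shared one name, so the
        # bool attribute shadowed the bound method on the instance.
        self.need_onnx_slim = slim_flag

    def onnx_slim(self, model_path: str) -> str:
        print(f"slimming {model_path}")
        return model_path

    def export(self, model_path: str) -> None:
        if self.need_onnx_slim:          # flag check
            self.onnx_slim(model_path)   # method call resolves correctly

Exporter(True).export("llm.onnx")
```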