diff --git a/README.md b/README.md index 0d982e75..686bb838 100644 --- a/README.md +++ b/README.md @@ -163,7 +163,11 @@ If you need to use it for commercial purposes, please follow the license of Qwen git clone https://github.com/netease-youdao/QAnything.git ``` ### step2: Enter the project root directory and execute the startup script. -If you are in the Windows11 system: Need to enter the WSL environment. +* [📖 QAnything_Startup_Usage](docs/QAnything_Startup_Usage_README.md) +* Get detailed usage of LLM interface by ```bash ./run.sh -h``` + + +If you are in the Windows11 system: Need to enter the **WSL** environment. ```shell cd QAnything bash run.sh # Start on GPU 0 by default. @@ -174,7 +178,7 @@ bash run.sh # Start on GPU 0 by default. ```shell cd QAnything -bash run.sh 0 # gpu id 0 +bash ./run.sh -c local -i 0 -b default # gpu id 0 ``` @@ -183,7 +187,7 @@ bash run.sh 0 # gpu id 0 ```shell cd QAnything -bash run.sh 0,1 # gpu ids: 0,1, Please confirm how many GPUs are available. Supports up to two cards for startup. +bash ./run.sh -c local -i 0,1 -b default # gpu ids: 0,1, Please confirm how many GPUs are available. Supports up to two cards for startup. ``` @@ -265,7 +269,8 @@ Reach out to the maintainer at one of the following places: `QAnything` adopts dependencies from the following: - Thanks to our [BCEmbedding](https://github.com/netease-youdao/BCEmbedding) for the excellent embedding and rerank model. - Thanks to [Qwen](https://github.com/QwenLM/Qwen) for strong base language models. -- Thanks to [Triton Inference Server](https://github.com/triton-inference-server/server) for providing great open source inference serving. +- Thanks to [Triton Inference Server](https://github.com/triton-inference-server/server) and vllm(https://github.com/vllm-project/vllm) for providing great open source inference serving. +- Thanks to [FastChat](https://github.com/lm-sys/FastChat) for providing a fully OpenAI-compatible API server. - Thanks to [FasterTransformer](https://github.com/NVIDIA/FasterTransformer) for highly optimized LLM inference backend. - Thanks to [Langchain](https://github.com/langchain-ai/langchain) for the wonderful llm application framework. - Thanks to [Langchain-Chatchat](https://github.com/chatchat-space/Langchain-Chatchat) for the inspiration provided on local knowledge base Q&A. diff --git a/README_zh.md b/README_zh.md index 6bb74622..f92f2d6e 100644 --- a/README_zh.md +++ b/README_zh.md @@ -154,7 +154,10 @@ QAnything使用的检索组件[BCEmbedding](https://github.com/netease-youdao/BC git clone https://github.com/netease-youdao/QAnything.git ``` ### step2: 进入项目根目录执行启动脚本 -如果在Windows系统下请先进入wsl环境 +* [📖 QAnything_Startup_Usage](docs/QAnything_Startup_Usage_README.md) +* 执行 ```bash ./run.sh -h``` 获取详细的LLM服务配置方法 + +如果在Windows系统下请先进入**WSL**环境 ```shell cd QAnything bash run.sh # 默认在0号GPU上启动 @@ -165,7 +168,7 @@ bash run.sh # 默认在0号GPU上启动 ```shell cd QAnything -bash run.sh 0 # 指定0号GPU启动 GPU编号从0开始 windows机器一般只有一张卡,所以只能指定0号GPU +bash ./run.sh -c local -i 0 -b default # 指定0号GPU启动 GPU编号从0开始 windows机器一般只有一张卡,所以只能指定0号GPU ``` @@ -174,7 +177,7 @@ bash run.sh 0 # 指定0号GPU启动 GPU编号从0开始 windows机器一般只 ```shell cd QAnything -bash run.sh 0,1 # 指定0,1号GPU启动,请确认有多张GPU可用,最多支持两张卡启动 +bash ./run.sh -c local -i 0,1 -b default # 指定0,1号GPU启动,请确认有多张GPU可用,最多支持两张卡启动 ``` @@ -248,6 +251,8 @@ qanything@rd.netease.com - [BCEmbedding](https://github.com/netease-youdao/BCEmbedding) - [Qwen](https://github.com/QwenLM/Qwen) - [Triton Inference Server](https://github.com/triton-inference-server/server) +- [vllm](https://github.com/vllm-project/vllm) +- [FastChat](https://github.com/lm-sys/FastChat) - [FasterTransformer](https://github.com/NVIDIA/FasterTransformer) - [Langchain](https://github.com/langchain-ai/langchain) - [Langchain-Chatchat](https://github.com/chatchat-space/Langchain-Chatchat) diff --git a/assets/custom_models/.gitignore b/assets/custom_models/.gitignore new file mode 100644 index 00000000..e69de29b diff --git a/docker-compose-linux.yaml b/docker-compose-linux.yaml index 9c9578f2..c6013b50 100644 --- a/docker-compose-linux.yaml +++ b/docker-compose-linux.yaml @@ -78,7 +78,7 @@ services: qanything_local: container_name: qanything-container-local - image: freeren/qanything:v1.0.9 + image: freeren/qanything:v1.1.1 # runtime: nvidia deploy: resources: @@ -87,11 +87,12 @@ services: - driver: nvidia count: "all" capabilities: ["gpu"] - command: /workspace/qanything_local/scripts/run_for_local.sh + command: /bin/bash -c 'if [ "${LLM_API}" = "local" ]; then /workspace/qanything_local/scripts/run_for_local_option.sh -c $LLM_API -i $DEVICE_ID -b $RUNTIME_BACKEND -m $MODEL_NAME -t $CONV_TEMPLATE -p $TP -r $GPU_MEM_UTILI; else /workspace/qanything_local/scripts/run_for_cloud_option.sh -c $LLM_API -i $DEVICE_ID -b $RUNTIME_BACKEND; fi; while true; do sleep 5; done' privileged: true shm_size: '8gb' volumes: - ${DOCKER_VOLUME_DIRECTORY:-.}/models:/model_repos/QAEnsemble + - ${DOCKER_VOLUME_DIRECTORY:-.}/assets/custom_models:/model_repos/CustomLLM - ${DOCKER_VOLUME_DIRECTORY:-.}/:/workspace/qanything_local/ ports: - "5052:5052" diff --git a/docker-compose-windows.yaml b/docker-compose-windows.yaml index e26618a3..2b8a9f79 100644 --- a/docker-compose-windows.yaml +++ b/docker-compose-windows.yaml @@ -78,7 +78,7 @@ services: qanything_local: container_name: qanything-container-local - image: freeren/qanything-win:v1.0.9 + image: freeren/qanything-win:v1.1.1 # runtime: nvidia deploy: resources: @@ -87,11 +87,12 @@ services: - driver: nvidia count: "all" capabilities: ["gpu"] - command: /workspace/qanything_local/scripts/run_for_local.sh + command: sh -c 'if [ "${LLM_API}" = "local" ]; then /workspace/qanything_local/scripts/run_for_local_option.sh -c $LLM_API -i $DEVICE_ID -b $RUNTIME_BACKEND -m $MODEL_NAME -t $CONV_TEMPLATE -p $TP -r $GPU_MEM_UTILI; else /workspace/qanything_local/scripts/run_for_cloud_option.sh -c $LLM_API -i $DEVICE_ID -b $RUNTIME_BACKEND; fi; while true; do sleep 5; done' privileged: true shm_size: '8gb' volumes: - ${DOCKER_VOLUME_DIRECTORY:-.}/models:/model_repos/QAEnsemble + - ${DOCKER_VOLUME_DIRECTORY:-.}/assets/custom_models:/model_repos/CustomLLM - ${DOCKER_VOLUME_DIRECTORY:-.}/:/workspace/qanything_local/ ports: - "5052:5052" diff --git a/docs/QAnything_Startup_Usage_README.md b/docs/QAnything_Startup_Usage_README.md new file mode 100644 index 00000000..4d0c9fa5 --- /dev/null +++ b/docs/QAnything_Startup_Usage_README.md @@ -0,0 +1,114 @@ + + +## Table of Contents + +- [QAnything Service Startup Command Usage](#QAnything-Service-Startup-Command-Usage) +- [Supported Pulic LLM using FastChat API](#Supported-Pulic-LLM-using-FastChat-API-with-Huggingface-Transformers/vllm-runtime-backend) +- [Tricks for saving GPU VRAM](#Tricks-for-saving-GPU-VRAM) +- [Comming Soon](#Comming-Soon) + + +## QAnything Service Startup Command Usage + +```bash +Usage: bash run.sh [-c ] [-i ] [-b ] [-m ] [-t ] [-p ] [-r ] + +-c : "Options {local, cloud} to specify the llm API mode, default is 'local'. If set to '-c cloud', please mannually set the environments {OPENAI_API_KEY, OPENAI_API_BASE, OPENAI_API_MODEL_NAME, OPENAI_API_CONTEXT_LENGTH} into .env fisrt." +-i : "Specify argument GPU device_id." +-b : "Specify argument LLM inference runtime backend, options={default, hf, vllm}" +-m : "Specify argument the model name to load public LLM model using FastChat serve API, options={Qwen-7B-Chat, deepseek-llm-7b-chat, ...}" +-t : "Specify argument the conversation template according to the public LLM model when using FastChat serve API, options={qwen-7b-chat, deepseek-chat, ...}" +-p : "Use options {1, 2} to set tensor parallel parameters for vllm backend when using FastChat serve API, default tensor_parallel=1" +-r : "Specify argument gpu_memory_utilization (0,1] for vllm backend when using FastChat serve API, default gpu_memory_utilization=0.81" +-h: "Display help usage message" +``` + +| Service Startup Command | GPUs | LLM Runtime Backend | LLM model | +| --------------------------------------------------------------------------------------- | -----|--------------------------| -------------------------------- | +| ```bash ./run.sh -c cloud -i 0 -b default``` | 1 | OpenAI API | OpenAI API | +| ```bash ./run.sh -c local -i 0 -b default``` | 1 | FasterTransformer | Qwen-7B-QAnything | +| ```bash ./run.sh -c local -i 0 -b hf -m MiniChat-2-3B -t minichat``` | 1 | Huggingface Transformers | Public LLM (e.g., MiniChat-2-3B) | +| ```bash ./run.sh -c local -i 0 -b vllm -m MiniChat-2-3B -t minichat -p 1 -r 0.81``` | 1 | vllm | Public LLM (e.g., MiniChat-2-3B) | +| ```bash ./run.sh -c local -i 0,1 -b default``` | 2 | FasterTransformer | Qwen-7B-QAnything | +| ```bash ./run.sh -c local -i 0,1 -b hf -m MiniChat-2-3B -t minichat``` | 2 | Huggingface Transformers | Public LLM (e.g., MiniChat-2-3B) | +| ```bash ./run.sh -c local -i 0,1 -b vllm -m MiniChat-2-3B -t minichat -p 1 -r 0.81``` | 2 | vllm | Public LLM (e.g., MiniChat-2-3B) | +| ```bash ./run.sh -c local -i 0,1 -b vllm -m MiniChat-2-3B -t minichat -p 2 -r 0.81``` | 2 | vllm | Public LLM (e.g., MiniChat-2-3B) | + +```bash +Note: You can choose the most suitable Service Startup Command based on your own device conditions. +(1) Local Embedding/Rerank will run on device gpu_id_1 when setting "-i 0,1", otherwise using gpu_id_0 as default. +(2) When setting "-c cloud" that will use local Embedding/Rerank and OpenAI LLM API, which only requires about 4GB VRAM (recommend for GPU device VRAM <= 8GB). +(3) When you use OpenAI LLM API, you will be required to enter {OPENAI_API_KEY, OPENAI_API_BASE, OPENAI_API_MODEL_NAME, OPENAI_API_CONTEXT_LENGTH} immediately. +(4) "-b hf" is the most recommended way for running public LLM inference for its compatibility but with poor performance. +(5) When you choose a public Chat LLM for QAnything system, you should take care of a more suitable **PROMPT_TEMPLATE** (/path/to/QAnything/qanything_kernel/configs/model_config.py) setting considering different LLM models. +``` + +## Supported Pulic LLM using FastChat API with Huggingface Transformers/vllm runtime backend + +| model_name | conv_template | Supported Pulic LLM List | +|-------------------------------------------|---------------------|---------------------------------------------------------------------------------| +| Qwen-7B-QAnything | qwen-7b-qanything | [Qwen-7B-QAnything](https://huggingface.co/netease-youdao/Qwen-7B-QAnything) | +| Qwen-1.8B-Chat/Qwen-7B-Chat/Qwen-14B-Chat | qwen-7b-chat | [Qwen](https://huggingface.co/Qwen) | +| Baichuan2-7B-Chat/Baichuan2-13B-Chat | baichuan2-chat | [Baichuan2](https://huggingface.co/baichuan-inc) | +| MiniChat-2-3B | minichat | [MiniChat](https://huggingface.co/GeneZC/MiniChat-2-3B) | +| deepseek-llm-7b-chat | deepseek-chat | [Deepseek](https://huggingface.co/deepseek-ai/deepseek-llm-7b-chat) | +| Yi-6B-Chat | Yi-34b-chat | [Yi](https://huggingface.co/01-ai/Yi-6B-Chat) | +| chatglm3-6b | chatglm3 | [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | +| ... ```check or add conv_template for more LLMs in "/path/to/QAnything/third_party/FastChat/fastchat/conversation.py"``` | + +### 1. Run QAnything using FastChat API with **Huggingface transformers** runtime backend (recommend for GPU device with VRAM <= 16GB). +```bash +## Step 1. Download the public LLM model (e.g., MiniChat-2-3B) and save to "/path/to/QAnything/assets/custom_models" +cd /path/to/QAnything/assets/custom_models +git clone https://huggingface.co/GeneZC/MiniChat-2-3B + +## Step 2. Execute the service startup command. Here we use "-b hf" to specify the Huggingface transformers backend. +## Here we use "-b hf" to specify the transformers backend that will load model in 8 bits but do bf16 inference as default for saving VRAM. +cd /path/to/QAnything +bash ./run.sh -c local -i 0 -b hf -m MiniChat-2-3B -t minichat + +``` + +### 2. Run QAnything using FastChat API with **vllm** runtime backend (recommend for GPU device with enough VRAM). + +```bash +## Step 1. Download the public LLM model (e.g., MiniChat-2-3B) and save to "/path/to/QAnything/assets/custom_models" +cd /path/to/QAnything/assets/custom_models +git clone https://huggingface.co/GeneZC/MiniChat-2-3B + +## Step 2. Execute the service startup command. +## Here we use "-b vllm" to specify the vllm backend that will do bf16 inference as default. +## Note you should adjust the gpu_memory_utilization yourself according to the model size to aviod out of memory (e.g., gpu_memory_utilization=0.81 is set default for 7B. Here, gpu_memory_utilization is set to 0.5 by "-r 0.5"). +cd /path/to/QAnything +bash ./run.sh -c local -i 0 -b vllm -m MiniChat-2-3B -t minichat -p 1 -r 0.5 + +## (Optional) Step 2. Execute the service startup command to specify the vllm backend by "-i 0,1 -p 2". It will do faster inference by setting a tensor parallel mode on 2 GPUs. +## bash ./run.sh -c local -i 0,1 -b vllm -m MiniChat-2-3B -t minichat -p 2 -r 0.5 + +``` + +## Tricks for saving GPU VRAM +```bash +## Trick 1. (Recommend for VRAM<=12 GB) Using PaddleOCR serve in CPU mode **use_gpu=False** in '/path/to/QAnything/qanything_kernel/dependent_server/ocr_serve/ocr_server.py' +# Note that **use_gpu=False** must be set when using RTX-1080Ti GPU, otherwise PaddleOCR will always return **empty ocr result** when using **use_gpu=True**. +ocr_engine = PaddleOCR(use_angle_cls=True, lang="ch", use_gpu=False, show_log=False) + +## Trick 2. Try 1.8B/3B size LLM, such as Qwen-1.8B-Chat and MiniChat-2-3B. + +## Trick 3. Try to limit the max length of context window by decreasing the value of **token_window** and increasing that of **offcut_token** +# /path/to/QAnything/qanything_kernel/connector/llm/llm_for_fastchat.py +# /path/to/QAnything/qanything_kernel/connector/llm/llm_for_local.py + +## Trick 4. Try INT4-Weight-Only Quantization methods such as GPTQ/AWQ. You should take care of the sampling parameters considering possible loss of accuracy. + +``` + + +## Comming Soon +
Feature Request + +- Support one-api interface to add more business LLM API (https://github.com/songquanpeng/one-api). +- Support more runtime backends, such as llama.cpp (https://github.com/ggerganov/llama.cpp) and sglang (https://github.com/sgl-project/sglang). +- ... + +
diff --git a/qanything_kernel/configs/model_config.py b/qanything_kernel/configs/model_config.py index bd3b5979..a3d32f2c 100644 --- a/qanything_kernel/configs/model_config.py +++ b/qanything_kernel/configs/model_config.py @@ -27,6 +27,9 @@ 请根据上述参考信息回答我的问题或回复我的指令。前面的参考信息可能有用,也可能没用,你需要从我给出的参考信息中选出与我的问题最相关的那些,来为你的回答提供依据。回答一定要忠于原文,简洁但不丢信息,不要胡乱编造。我的问题或指令是什么语种,你就用什么语种回复, 你的回复:""" +# For LLM Chat w/o Retrieval context +# PROMPT_TEMPLATE = """{question}""" + QUERY_PROMPT_TEMPLATE = """{question}""" # 缓存知识库数量 @@ -57,20 +60,30 @@ # MILVUS向量数据库地址 MILVUS_HOST_LOCAL = 'milvus-standalone-local' -MILVUS_HOST_ONLINE = '10.55.163.98' # gpu63 +MILVUS_HOST_ONLINE = 'milvus-standalone-local' MILVUS_PORT = 19530 MYSQL_HOST_LOCAL = 'mysql-container-local' -MYSQL_HOST_ONLINE = '10.55.163.98' +MYSQL_HOST_ONLINE = 'mysql-container-local' MYSQL_PORT = 3306 MYSQL_USER = 'root' MYSQL_PASSWORD = '123456' MYSQL_DATABASE = 'qanything' +llm_api_serve_model = os.getenv('LLM_API_SERVE_MODEL') +llm_api_serve_port = os.getenv('LLM_API_SERVE_PORT') rerank_port = os.getenv('RERANK_PORT') embed_port = os.getenv('EMBED_PORT') + +print("llm_api_serve_port:", llm_api_serve_port) print("rerank_port:", rerank_port) print("embed_port:", embed_port) + + +LOCAL_LLM_SERVICE_URL = f"localhost:{llm_api_serve_port}" +LOCAL_LLM_MODEL_NAME = llm_api_serve_model +LOCAL_LLM_MAX_LENGTH = 4096 + LOCAL_RERANK_SERVICE_URL = f"localhost:{rerank_port}" LOCAL_RERANK_MODEL_NAME = 'rerank' LOCAL_RERANK_MAX_LENGTH = 512 diff --git a/qanything_kernel/connector/llm/__init__.py b/qanything_kernel/connector/llm/__init__.py index 82e230df..50dc6953 100644 --- a/qanything_kernel/connector/llm/__init__.py +++ b/qanything_kernel/connector/llm/__init__.py @@ -1,2 +1,11 @@ -from .llm_for_online import OpenAILLM -from .llm_for_local import ZiyueLLM +import os +from dotenv import load_dotenv +from .llm_for_openai_api import OpenAILLM + +load_dotenv() +RUNTIME_BACKEND = os.getenv("RUNTIME_BACKEND") + +if RUNTIME_BACKEND == "default": + from .llm_for_local import ZiyueLLM +else: # hf/vllm + from .llm_for_fastchat import OpenAICustomLLM as ZiyueLLM diff --git a/qanything_kernel/connector/llm/llm_for_fastchat.py b/qanything_kernel/connector/llm/llm_for_fastchat.py new file mode 100644 index 00000000..f2e6deb8 --- /dev/null +++ b/qanything_kernel/connector/llm/llm_for_fastchat.py @@ -0,0 +1,210 @@ +from abc import ABC +import tiktoken +import os +from dotenv import load_dotenv +from openai import OpenAI +from typing import Optional, List +import sys +import json +import requests +import logging +sys.path.append("../../../") +from qanything_kernel.connector.llm.base import (BaseAnswer, AnswerResult) +from qanything_kernel.configs.model_config import LOCAL_LLM_SERVICE_URL, LOCAL_LLM_MODEL_NAME, LOCAL_LLM_MAX_LENGTH + +load_dotenv() + +logging.basicConfig(level=logging.INFO) + +class OpenAICustomLLM(BaseAnswer, ABC): + model: str = LOCAL_LLM_MODEL_NAME + token_window: int = LOCAL_LLM_MAX_LENGTH + max_token: int = 512 + offcut_token: int = 50 + truncate_len: int = 50 + temperature: float = 0 + stop_words: str = None + history: List[List[str]] = [] + history_len: int = 2 + + def __init__(self): + super().__init__() + # self.client = OpenAI(base_url="http://localhost:7802/v1", api_key="EMPTY") + if LOCAL_LLM_SERVICE_URL.startswith("http://"): + base_url = f"{LOCAL_LLM_SERVICE_URL}/v1" + else: + base_url = f"http://{LOCAL_LLM_SERVICE_URL}/v1" + self.client = OpenAI(base_url=base_url, api_key="EMPTY") + + @property + def _llm_type(self) -> str: + return "CustomLLM using FastChat w/ huggingface transformers or vllm backend" + + @property + def _history_len(self) -> int: + return self.history_len + + def set_history_len(self, history_len: int = 10) -> None: + self.history_len = history_len + + def token_check(self, query: str) -> int: + + if LOCAL_LLM_SERVICE_URL.startswith("http://"): + base_url = f"{LOCAL_LLM_SERVICE_URL}/api/v1/token_check" + else: + base_url = f"http://{LOCAL_LLM_SERVICE_URL}/api/v1/token_check" + + headers = {"Content-Type": "application/json"} + + response = requests.post( + base_url, + data=json.dumps( + {'prompts': [{'model': self.model, 'prompt': query, 'max_tokens': self.max_token}]} + ), + headers=headers) + + # {'prompts': [{'fits': True, 'tokenCount': 317, 'contextLength': 8192}]} + result = response.json() + token_num = 0 + try: + token_num = result['prompts'][0]['tokenCount'] + return token_num + except Exception as e: + logging.error(f"token_check Exception {base_url} w/ {e}") + return token_num + + def num_tokens_from_messages(self, message_texts): + num_tokens = 0 + for message in message_texts: + num_tokens += self.token_check(message) + return num_tokens + + def num_tokens_from_docs(self, docs): + num_tokens = 0 + for doc in docs: + num_tokens += self.token_check(doc.page_content) + return num_tokens + + def _call(self, prompt: str, history: List[List[str]], streaming: bool=False) -> str: + messages = [] + for pair in history: + question, answer = pair + messages.append({"role": "user", "content": question}) + messages.append({"role": "assistant", "content": answer}) + messages.append({"role": "user", "content": prompt}) + logging.info(messages) + + try: + + if streaming: + response = self.client.chat.completions.create( + model=self.model, + messages=messages, + stream=True, + max_tokens=self.max_token, + # temperature=self.temperature, + stop=[self.stop_words] if self.stop_words is not None else None, + ) + + for event in response: + if not isinstance(event, dict): + event = event.model_dump() + + event_text = event["choices"][0]['delta']['content'] + if isinstance(event_text, str) and event_text != "": + # logging.info(f"[debug] event_text = [{event_text}]") + delta = {'answer': event_text} + yield "data: " + json.dumps(delta, ensure_ascii=False) + + else: + response = self.client.chat.completions.create( + model=self.model, + messages=messages, + stream=False, + max_tokens=self.max_token, + # temperature=self.temperature, + stop=[self.stop_words] if self.stop_words is not None else None, + ) + + # logging.info(f"[debug] response.choices = [{response.choices}]") + event_text = response.choices[0].message.content if response.choices else "" + delta = {'answer': event_text} + yield "data: " + json.dumps(delta, ensure_ascii=False) + + except Exception as e: + logging.info(f"Error calling API: {e}") + delta = {'answer': f"{e}"} + yield "data: " + json.dumps(delta, ensure_ascii=False) + + finally: + # logging.info("[debug] try-finally") + yield f"data: [DONE]\n\n" + + def generatorAnswer(self, prompt: str, + history: List[List[str]] = [], + streaming: bool = False) -> AnswerResult: + + if history is None or len(history) == 0: + history = [[]] + logging.info(f"history_len: {self.history_len}") + logging.info(f"prompt: {prompt}") + logging.info(f"prompt tokens: {self.num_tokens_from_messages([prompt])}") + logging.info(f"streaming: {streaming}") + + response = self._call(prompt, history[:-1], streaming) + complete_answer = "" + for response_text in response: + + if response_text: + chunk_str = response_text[6:] + if not chunk_str.startswith("[DONE]"): + chunk_js = json.loads(chunk_str) + complete_answer += chunk_js["answer"] + + history[-1] = [prompt, complete_answer] + answer_result = AnswerResult() + answer_result.history = history + answer_result.llm_output = {"answer": response_text} + answer_result.prompt = prompt + yield answer_result + + +if __name__ == "__main__": + + base_url = f"http://{LOCAL_LLM_SERVICE_URL}/api/v1/token_check" + headers = {"Content-Type": "application/json"} + query = "hello" + response = requests.post( + base_url, + data=json.dumps( + {'prompts': [{'model': LOCAL_LLM_MODEL_NAME, 'prompt': query, 'max_tokens': 512}]} + ), + headers=headers) + + # {'prompts': [{'fits': True, 'tokenCount': 317, 'contextLength': 8192}]} + result = response.json() + logging.info(f"[debug] result = {result}") + + + llm = OpenAICustomLLM() + streaming = True + chat_history = [] + prompt = "你是谁" + prompt = """参考信息: +中央纪委国家监委网站讯 据山西省纪委监委消息:山西转型综合改革示范区党工委副书记、管委会副主任董良涉嫌严重违纪违法,目前正接受山西省纪委监委纪律审查和监察调查。\\u3000\\u3000董良简历\\u3000\\u3000董良,男,汉族,1964年8月生,河南鹿邑人,在职研究生学历,邮箱random@xxx.com,联系电话131xxxxx909,1984年3月加入中国共产党,1984年8月参加工作\\u3000\\u3000历任太原经济技术开发区管委会副主任、太原武宿综合保税区专职副主任,山西转型综合改革示范区党工委委员、管委会副主任。2021年8月,任山西转型综合改革示范区党工委副书记、管委会副主任。(山西省纪委监委) +--- +我的问题或指令: +帮我提取上述人物的中文名,英文名,性别,国籍,现任职位,最高学历,毕业院校,邮箱,电话 +--- +请根据上述参考信息回答我的问题或回复我的指令。前面的参考信息可能有用,也可能没用,你需要从我给出的参考信息中选出与我的问题最相关的那些,来为你的回答提供依据。回答一定要忠于原文,简洁但不丢信息,不要胡乱编造。我的问题或指令是什么语种,你就用什么语种回复, +你的回复:""" + final_result = "" + for answer_result in llm.generatorAnswer(prompt=prompt, + history=chat_history, + streaming=streaming): + resp = answer_result.llm_output["answer"] + if "DONE" not in resp: + final_result += json.loads(resp[6:])["answer"] + # logging.info(resp) + + logging.info(f"final_result = {final_result}") \ No newline at end of file diff --git a/qanything_kernel/connector/llm/llm_for_openai_api.py b/qanything_kernel/connector/llm/llm_for_openai_api.py new file mode 100644 index 00000000..590def5e --- /dev/null +++ b/qanything_kernel/connector/llm/llm_for_openai_api.py @@ -0,0 +1,243 @@ +from abc import ABC +import tiktoken +import os +from dotenv import load_dotenv +from openai import OpenAI +from typing import Optional, List +import sys +import json +import requests +import logging +sys.path.append("../../../") +from qanything_kernel.connector.llm.base import (BaseAnswer, AnswerResult) + +load_dotenv() +logging.basicConfig(level=logging.INFO) + +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +OPENAI_API_BASE = os.getenv("OPENAI_API_BASE") +OPENAI_API_MODEL_NAME = os.getenv("OPENAI_API_MODEL_NAME") +OPENAI_API_CONTEXT_LENGTH = os.getenv("OPENAI_API_CONTEXT_LENGTH") +if isinstance(OPENAI_API_CONTEXT_LENGTH, str) and OPENAI_API_CONTEXT_LENGTH != '': + OPENAI_API_CONTEXT_LENGTH = int(OPENAI_API_CONTEXT_LENGTH) +logging.info(f"OPENAI_API_BASE = {OPENAI_API_BASE}") +logging.info(f"OPENAI_API_MODEL_NAME = {OPENAI_API_MODEL_NAME}") + + +class OpenAILLM(BaseAnswer, ABC): + model: str = OPENAI_API_MODEL_NAME + token_window: int = OPENAI_API_CONTEXT_LENGTH + max_token: int = 512 + offcut_token: int = 50 + truncate_len: int = 50 + temperature: float = 0 + top_p: float = 1.0 # top_p must be (0,1] + stop_words: str = None + history: List[List[str]] = [] + history_len: int = 2 + + def __init__(self): + super().__init__() + self.client = OpenAI(base_url=OPENAI_API_BASE, api_key=OPENAI_API_KEY) + + @property + def _llm_type(self) -> str: + return "using OpenAI API serve as LLM backend" + + @property + def _history_len(self) -> int: + return self.history_len + + def set_history_len(self, history_len: int = 10) -> None: + self.history_len = history_len + + # 定义函数 num_tokens_from_messages,该函数返回由一组消息所使用的token数 + def num_tokens_from_messages(self, messages, model=None): + """Return the number of tokens used by a list of messages. From https://github.com/DjangoPeng/openai-quickstart/blob/main/openai_api/count_tokens_with_tiktoken.ipynb""" + # logging.info(f"[debug] num_tokens_from_messages = {model, self.model}") + if model is None: + model = self.model + # 尝试获取模型的编码 + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + # 如果模型没有找到,使用 cl100k_base 编码并给出警告 + logging.info("Warning: model not found. Using cl100k_base encoding.") + encoding = tiktoken.get_encoding("cl100k_base") + # 针对不同的模型设置token数量 + if model in { + "gpt-3.5-turbo-0613", + # "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-16k-0613", + "gpt-4-0314", + "gpt-4-32k-0314", + "gpt-4-0613", + "gpt-4-32k-0613", + "gpt-4-32k", + # "gpt-4-1106-preview", + }: + tokens_per_message = 3 + tokens_per_name = 1 + elif model == "gpt-3.5-turbo-0301": + tokens_per_message = 4 # 每条消息遵循 {role/name}\n{content}\n 格式 + tokens_per_name = -1 # 如果有名字,角色会被省略 + elif "gpt-3.5-turbo" in model: + # 对于 gpt-3.5-turbo 模型可能会有更新,此处返回假设为 gpt-3.5-turbo-0613 的token数量,并给出警告 + logging.info("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.") + return self.num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613") + elif "gpt-4" in model: + # 对于 gpt-4 模型可能会有更新,此处返回假设为 gpt-4-0613 的token数量,并给出警告 + logging.info("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.") + return self.num_tokens_from_messages(messages, model="gpt-4-0613") + + else: + # 对于没有实现的模型,抛出未实现错误 + raise NotImplementedError( + f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""" + ) + + num_tokens = 0 + # 计算每条消息的token数 + for message in messages: + if isinstance(message, dict): + num_tokens += tokens_per_message + for key, value in message.items(): + num_tokens += len(encoding.encode(value)) + if key == "name": + num_tokens += tokens_per_name + elif isinstance(message, str): + num_tokens += len(encoding.encode(message)) + else: + NotImplementedError( + f"""num_tokens_from_messages() is not implemented message type {type(message)}. """ + ) + + num_tokens += 3 # 每条回复都以助手为首 + return num_tokens + + def num_tokens_from_docs(self, docs): + + # 尝试获取模型的编码 + try: + encoding = tiktoken.encoding_for_model(self.model) + except KeyError: + # 如果模型没有找到,使用 cl100k_base 编码并给出警告 + logging.info("Warning: model not found. Using cl100k_base encoding.") + encoding = tiktoken.get_encoding("cl100k_base") + + num_tokens = 0 + for doc in docs: + num_tokens += len(encoding.encode(doc.page_content, disallowed_special=())) + return num_tokens + + def _call(self, prompt: str, history: List[List[str]], streaming: bool=False) -> str: + messages = [] + for pair in history: + question, answer = pair + messages.append({"role": "user", "content": question}) + messages.append({"role": "assistant", "content": answer}) + messages.append({"role": "user", "content": prompt}) + logging.info(messages) + + try: + + if streaming: + response = self.client.chat.completions.create( + model=self.model, + messages=messages, + stream=True, + max_tokens=self.max_token, + temperature=self.temperature, + top_p=self.top_p, + stop=[self.stop_words] if self.stop_words is not None else None, + ) + + for event in response: + if not isinstance(event, dict): + event = event.model_dump() + + if isinstance(event['choices'], List) and len(event['choices']) > 0 : + event_text = event["choices"][0]['delta']['content'] + if isinstance(event_text, str) and event_text != "": + # logging.info(f"[debug] event_text = [{event_text}]") + delta = {'answer': event_text} + yield "data: " + json.dumps(delta, ensure_ascii=False) + + else: + response = self.client.chat.completions.create( + model=self.model, + messages=messages, + stream=False, + max_tokens=self.max_token, + temperature=self.temperature, + top_p=self.top_p, + stop=[self.stop_words] if self.stop_words is not None else None, + ) + + # logging.info(f"[debug] response.choices = [{response.choices}]") + event_text = response.choices[0].message.content if response.choices else "" + delta = {'answer': event_text} + yield "data: " + json.dumps(delta, ensure_ascii=False) + + except Exception as e: + logging.info(f"Error calling OpenAI API: {e}") + delta = {'answer': f"{e}"} + yield "data: " + json.dumps(delta, ensure_ascii=False) + + finally: + # logging.info("[debug] try-finally") + yield f"data: [DONE]\n\n" + + def generatorAnswer(self, prompt: str, + history: List[List[str]] = [], + streaming: bool = False) -> AnswerResult: + + if history is None or len(history) == 0: + history = [[]] + logging.info(f"history_len: {self.history_len}") + logging.info(f"prompt: {prompt}") + logging.info(f"prompt tokens: {self.num_tokens_from_messages([{'content': prompt}])}") + logging.info(f"streaming: {streaming}") + + response = self._call(prompt, history[:-1], streaming) + complete_answer = "" + for response_text in response: + + if response_text: + chunk_str = response_text[6:] + if not chunk_str.startswith("[DONE]"): + chunk_js = json.loads(chunk_str) + complete_answer += chunk_js["answer"] + + history[-1] = [prompt, complete_answer] + answer_result = AnswerResult() + answer_result.history = history + answer_result.llm_output = {"answer": response_text} + answer_result.prompt = prompt + yield answer_result + + +if __name__ == "__main__": + + llm = OpenAILLM() + streaming = True + chat_history = [] + prompt = "你是谁" + prompt = """参考信息: +中央纪委国家监委网站讯 据山西省纪委监委消息:山西转型综合改革示范区党工委副书记、管委会副主任董良涉嫌严重违纪违法,目前正接受山西省纪委监委纪律审查和监察调查。\\u3000\\u3000董良简历\\u3000\\u3000董良,男,汉族,1964年8月生,河南鹿邑人,在职研究生学历,邮箱random@xxx.com,联系电话131xxxxx909,1984年3月加入中国共产党,1984年8月参加工作\\u3000\\u3000历任太原经济技术开发区管委会副主任、太原武宿综合保税区专职副主任,山西转型综合改革示范区党工委委员、管委会副主任。2021年8月,任山西转型综合改革示范区党工委副书记、管委会副主任。(山西省纪委监委) +--- +我的问题或指令: +帮我提取上述人物的中文名,英文名,性别,国籍,现任职位,最高学历,毕业院校,邮箱,电话 +--- +请根据上述参考信息回答我的问题或回复我的指令。前面的参考信息可能有用,也可能没用,你需要从我给出的参考信息中选出与我的问题最相关的那些,来为你的回答提供依据。回答一定要忠于原文,简洁但不丢信息,不要胡乱编造。我的问题或指令是什么语种,你就用什么语种回复, +你的回复:""" + final_result = "" + for answer_result in llm.generatorAnswer(prompt=prompt, + history=chat_history, + streaming=streaming): + resp = answer_result.llm_output["answer"] + if "DONE" not in resp: + final_result += json.loads(resp[6:])["answer"] + logging.info(resp) + + logging.info(f"final_result = {final_result}") \ No newline at end of file diff --git a/qanything_kernel/core/local_doc_qa.py b/qanything_kernel/core/local_doc_qa.py index 17e2fdf3..fe1bd066 100644 --- a/qanything_kernel/core/local_doc_qa.py +++ b/qanything_kernel/core/local_doc_qa.py @@ -13,7 +13,9 @@ from qanything_kernel.utils.general_utils import get_time import requests import traceback +import logging +logging.basicConfig(level=logging.INFO) def _embeddings_hash(self): return hash(self.model_name) @@ -150,6 +152,8 @@ def reprocess_source_documents(self, query: str, query_token_num = self.llm.num_tokens_from_messages([query]) history_token_num = self.llm.num_tokens_from_messages([x for sublist in history for x in sublist]) template_token_num = self.llm.num_tokens_from_messages([prompt_template]) + + # logging.info(f", types = {type(self.llm.token_window), type(self.llm.max_token), type(self.llm.offcut_token), type(query_token_num), type(history_token_num), type(template_token_num)}, values = {query_token_num, history_token_num, template_token_num}") limited_token_nums = self.llm.token_window - self.llm.max_token - self.llm.offcut_token - query_token_num - history_token_num - template_token_num new_source_docs = [] total_token_num = 0 @@ -235,6 +239,8 @@ def get_knowledge_based_answer(self, query, milvus_kb, chat_history=None, stream resp = answer_result.llm_output["answer"] prompt = answer_result.prompt history = answer_result.history + + # logging.info(f"[debug] get_knowledge_based_answer history = {history}") history[-1][0] = query response = {"query": query, "prompt": prompt, diff --git a/qanything_kernel/dependent_server/llm_for_local_serve/llm_server_entrypoint.py b/qanything_kernel/dependent_server/llm_for_local_serve/llm_server_entrypoint.py index d2a29c81..7be2a4b5 100755 --- a/qanything_kernel/dependent_server/llm_for_local_serve/llm_server_entrypoint.py +++ b/qanything_kernel/dependent_server/llm_for_local_serve/llm_server_entrypoint.py @@ -183,7 +183,7 @@ def parse_params(params: OrderedDict) -> Tuple: if chunk_out: bytes_ids += res[input_len + decode_len + bytes_len:] decoding = tokenizer.decode(bytes_ids, skip_special_tokens=True) - if isinstance(decoding, bytes): + if isinstance(decoding, bytes) or (isinstance(decoding, str) and '�' in decoding): bytes_len = len(bytes_ids) continue else: diff --git a/qanything_kernel/dependent_server/llm_for_local_serve/tokenizer_assets/tokenization_qwen.py b/qanything_kernel/dependent_server/llm_for_local_serve/tokenizer_assets/tokenization_qwen.py old mode 100755 new mode 100644 index 0f750de8..2d0a030e --- a/qanything_kernel/dependent_server/llm_for_local_serve/tokenizer_assets/tokenization_qwen.py +++ b/qanything_kernel/dependent_server/llm_for_local_serve/tokenizer_assets/tokenization_qwen.py @@ -1,25 +1,62 @@ +# Copyright (c) Alibaba Cloud. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. -from __future__ import absolute_import, division, print_function, unicode_literals +"""Tokenization classes for QWen.""" +"""Fix: AttributeError: 'QwenTokenizer' object has no attribute 'mergeable_ranks' using transformers >= 4.36""" -import json +import base64 import logging import os import unicodedata -from io import open -import base64 -import tiktoken -from typing import List, Optional, Tuple, Union +from typing import Collection, Dict, List, Set, Tuple, Union +import tiktoken from transformers import PreTrainedTokenizer, AddedToken logger = logging.getLogger(__name__) + VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"} -class QwenTokenizer(PreTrainedTokenizer): - """Qwen tokenizer.""" +PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" +ENDOFTEXT = "<|endoftext|>" +IMSTART = "<|im_start|>" +IMEND = "<|im_end|>" +# as the default behavior is changed to allow special tokens in +# regular texts, the surface forms of special tokens need to be +# as different as possible to minimize the impact +EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205))) +# changed to use actual index to avoid misconfiguration with vocabulary expansion +SPECIAL_START_ID = 151643 +SPECIAL_TOKENS = tuple( + enumerate( + ( + ( + ENDOFTEXT, + IMSTART, + IMEND, + ) + + EXTRAS + ), + start=SPECIAL_START_ID, + ) +) +SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS) + - """NOTE: This tokenizer will not handle special tokens to avoid injection attacks""" +def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]: + with open(tiktoken_bpe_file, "rb") as f: + contents = f.read() + return { + base64.b64decode(token): int(rank) + for token, rank in (line.split() for line in contents.splitlines() if line) + } + + +class QWenTokenizer(PreTrainedTokenizer): + """QWen tokenizer.""" vocab_files_names = VOCAB_FILES_NAMES @@ -27,137 +64,113 @@ def __init__( self, vocab_file, errors="replace", - max_len=None, - unk_token="<|endoftext|>", - bos_token="<|endoftext|>", - eos_token="<|endoftext|>", - pad_token=None, - add_prefix_space=False, - add_bos_token=False, - add_more_sp_tokens=True, + extra_vocab_file=None, **kwargs, ): - bos_token = ( - AddedToken(bos_token, lstrip=False, rstrip=False) - if isinstance(bos_token, str) - else bos_token - ) - eos_token = ( - AddedToken(eos_token, lstrip=False, rstrip=False) - if isinstance(eos_token, str) - else eos_token - ) - unk_token = ( - AddedToken(unk_token, lstrip=False, rstrip=False) - if isinstance(unk_token, str) - else unk_token - ) - pad_token = ( - AddedToken(pad_token, lstrip=False, rstrip=False) - if isinstance(pad_token, str) - else pad_token - ) - super().__init__( - errors=errors, - unk_token=unk_token, - bos_token=bos_token, - eos_token=eos_token, - pad_token=pad_token, - add_prefix_space=add_prefix_space, - add_bos_token=add_bos_token, - ) - self.add_bos_token = add_bos_token - self.max_len = max_len if max_len is not None else int(1e12) + super().__init__(**kwargs) - self.errors = errors # how to handle errors in decoding + # how to handle errors in decoding UTF-8 byte sequences + # use ignore if you are in streaming inference + self.errors = errors - name = "Qwen" - ENDOFTEXT = "<|endoftext|>" - IMSTART = "<|im_start|>" - IMEND = "<|im_end|>" - if add_more_sp_tokens: - special_tokens = ( - ENDOFTEXT, - IMSTART, - IMEND, - "", - "", - "", - "", - "", - ) + tuple([f"" for i in range(200)]) - else: - special_tokens = (ENDOFTEXT, IMSTART, IMEND) - - PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" - - def load_tiktoken_bpe(tiktoken_bpe_file: str) -> "dict[bytes, int]": - contents = open(tiktoken_bpe_file, "rb").read() - return { - base64.b64decode(token): int(rank) - for token, rank in ( - line.split() for line in contents.splitlines() if line - ) - } - - mergeable_ranks = load_tiktoken_bpe(vocab_file) - special_tokens = { + self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: Dict[bytes, int] + self.special_tokens = { token: index - for index, token in enumerate(special_tokens, start=len(mergeable_ranks)) + for index, token in SPECIAL_TOKENS } - self.special_tokens = special_tokens + + # try load extra vocab from file + if extra_vocab_file is not None: + used_ids = set(self.mergeable_ranks.values()) | set(self.special_tokens.values()) + extra_mergeable_ranks = _load_tiktoken_bpe(extra_vocab_file) + for token, index in extra_mergeable_ranks.items(): + if token in self.mergeable_ranks: + logger.info(f"extra token {token} exists, skipping") + continue + if index in used_ids: + logger.info(f'the index {index} for extra token {token} exists, skipping') + continue + self.mergeable_ranks[token] = index + # the index may be sparse after this, but don't worry tiktoken.Encoding will handle this + enc = tiktoken.Encoding( - name, + "Qwen", pat_str=PAT_STR, - mergeable_ranks=mergeable_ranks, - special_tokens=special_tokens, + mergeable_ranks=self.mergeable_ranks, + special_tokens=self.special_tokens, ) assert ( - len(mergeable_ranks) + len(special_tokens) == enc.n_vocab - ), f"{len(mergeable_ranks) + len(special_tokens)} != {enc.n_vocab} in encoding" + len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab + ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding" - self.mergeable_ranks = mergeable_ranks - self.encoder = self.mergeable_ranks - self.decoder = {v: k for k, v in self.encoder.items()} + self.decoder = { + v: k for k, v in self.mergeable_ranks.items() + } # type: dict[int, bytes|str] self.decoder.update({v: k for k, v in self.special_tokens.items()}) + self.tokenizer = enc # type: tiktoken.Encoding + self.eod_id = self.tokenizer.eot_token - self.im_start_id = special_tokens[IMSTART] - self.im_end_id = special_tokens[IMEND] + self.im_start_id = self.special_tokens[IMSTART] + self.im_end_id = self.special_tokens[IMEND] + self.bos_token_id = SPECIAL_START_ID + self.eos_token_id = self.eod_id + + def __getstate__(self): + # for pickle lovers + state = self.__dict__.copy() + del state["tokenizer"] + return state + + def __setstate__(self, state): + # tokenizer is not python native; don't pass it; rebuild it + self.__dict__.update(state) + enc = tiktoken.Encoding( + "Qwen", + pat_str=PAT_STR, + mergeable_ranks=self.mergeable_ranks, + special_tokens=self.special_tokens, + ) + self.tokenizer = enc - def __len__(self): + def __len__(self) -> int: return self.tokenizer.n_vocab - def get_vocab(self): + def get_vocab(self) -> Dict[bytes, int]: return self.mergeable_ranks - def convert_tokens_to_ids(self, tokens): + def convert_tokens_to_ids( + self, tokens: Union[bytes, str, List[Union[bytes, str]]] + ) -> List[int]: ids = [] - # Remove support for py2 - if isinstance(tokens, str): + if isinstance(tokens, (str, bytes)): if tokens in self.special_tokens: return self.special_tokens[tokens] else: - return self.encoder.get(tokens) + return self.mergeable_ranks.get(tokens) for token in tokens: if token in self.special_tokens: ids.append(self.special_tokens[token]) else: - ids.append(self.encoder.get(token)) - if len(ids) > self.max_len: - logger.warning( - "Token indices sequence length is longer than the specified maximum " - " sequence length for this model ({} > {}). Running this" - " sequence through the model will result in indexing errors".format( - len(ids), self.max_len - ) - ) + ids.append(self.mergeable_ranks.get(token)) return ids + def _add_tokens( + self, + new_tokens: Union[List[str], List[AddedToken]], + special_tokens: bool = False, + ) -> int: + if not special_tokens and new_tokens: + raise ValueError("Adding regular tokens is not supported") + for token in new_tokens: + surface_form = token.content if isinstance(token, AddedToken) else token + if surface_form not in SPECIAL_TOKENS_SET: + raise ValueError("Adding unknown special tokens is not supported") + return 0 + def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]: """ - Save only the vocabulary of the tokenizer (vocabulary + added tokens). - + Save only the vocabulary of the tokenizer (vocabulary). Returns: `Tuple(str)`: Paths to the files saved. """ @@ -168,80 +181,81 @@ def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]: w.write(line) return (file_path,) - def tokenize(self, text: str, **kwargs) -> List[str]: + def tokenize( + self, + text: str, + allowed_special: Union[Set, str] = "all", + disallowed_special: Union[Collection, str] = (), + **kwargs, + ) -> List[Union[bytes, str]]: """ - Converts a string in a sequence of tokens, replacing unknown tokens with the `unk_token`. - + Converts a string in a sequence of tokens. Args: text (`str`): The sequence to be encoded. + allowed_special (`Literal["all"]` or `set`): + The surface forms of the tokens to be encoded as special tokens in regular texts. + Default to "all". + disallowed_special (`Literal["all"]` or `Collection`): + The surface forms of the tokens that should not be in regular texts and trigger errors. + Default to an empty tuple. kwargs (additional keyword arguments, *optional*): Will be passed to the underlying model specific encode method. - Tiktoken allows users to allow the tokenization of special tokens with the following args: - `allowed_special`: set to 'all' or a `set` of special tokens. - `disallowed_special`: set to 'all' or a `Collection` of special tokens. NOT RECOMMENDED, AS IT MAY BE CONFLICTED WITH `allowed_special`. - Returns: - `List[str]`: The list of tokens. + `List[bytes|str]`: The list of tokens. """ tokens = [] text = unicodedata.normalize("NFC", text) - for t in self.tokenizer.encode(text, **kwargs): + # this implementation takes a detour: text -> token id -> token surface forms + for t in self.tokenizer.encode( + text, allowed_special=allowed_special, disallowed_special=disallowed_special + ): tokens.append(self.decoder[t]) - return tokens - def convert_tokens_to_string(self, tokens: List[str]) -> str: + def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str: """ - Converts a sequence of tokens in a single string. The most simple way to do it is `" ".join(tokens)` but we - often want to remove sub-word tokenization artifacts at the same time. + Converts a sequence of tokens in a single string. """ - text = "".join(tokens) - text = bytearray([self.byte_decoder[c] for c in text]).decode( - "utf-8", errors=self.errors - ) + text = "" + temp = b"" + for t in tokens: + if isinstance(t, str): + if temp: + text += temp.decode("utf-8", errors=self.errors) + temp = b"" + text += t + elif isinstance(t, bytes): + temp += t + else: + raise TypeError("token should only be of type types or str") + if temp: + text += temp.decode("utf-8", errors=self.errors) return text - + @property def vocab_size(self): return self.tokenizer.n_vocab - def _convert_id_to_token(self, index: int) -> str: - if index >= self.tokenizer.n_vocab: - return self.unk_token - return self.tokenizer.decode([index]) - - def _convert_token_to_id(self, token: str) -> int: - """Converts a token to an id using the vocab.""" - return self.encoder.get( - token.encode("UTF-8"), - self.tokenizer.encode(self.unk_token, allowed_special="all")[0], - ) - - @property - def all_special_tokens(self) -> List[str]: - """ - `List[str]`: All the special tokens (`''`, `''`, etc.) mapped to class attributes. - - Convert tokens of `tokenizers.AddedToken` type to string. - """ - all_toks = [str(s) for s in self.special_tokens.keys()] - return all_toks - - @property - def all_special_ids(self) -> List[int]: - """ - `List[int]`: List the ids of the special tokens(`''`, `''`, etc.) mapped to class attributes. - """ - all_ids = [v for v in self.special_tokens.values()] - return all_ids - - def _tokenize(self, text, **kwargs): + def _convert_id_to_token(self, index: int) -> Union[bytes, str]: + """Converts an id to a token, special tokens included""" + if index in self.decoder: + return self.decoder[index] + raise ValueError("unknown ids") + + def _convert_token_to_id(self, token: Union[bytes, str]) -> int: + """Converts a token to an id using the vocab, special tokens included""" + if token in self.special_tokens: + return self.special_tokens[token] + if token in self.mergeable_ranks: + return self.mergeable_ranks[token] + raise ValueError("unknown token") + + def _tokenize(self, text: str, **kwargs): """ Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces). - Do NOT take care of added tokens. """ raise NotImplementedError @@ -250,10 +264,11 @@ def _decode( self, token_ids: Union[int, List[int]], skip_special_tokens: bool = False, + errors: str = None, **kwargs, ) -> str: if isinstance(token_ids, int): token_ids = [token_ids] if skip_special_tokens: - token_ids = [i for i in token_ids if i not in self.all_special_ids] - return self.tokenizer.decode(token_ids) + token_ids = [i for i in token_ids if i < self.eod_id] + return self.tokenizer.decode(token_ids, errors=errors or self.errors) \ No newline at end of file diff --git a/qanything_kernel/dependent_server/llm_for_local_serve/tokenizer_assets/tokenizer_config.json b/qanything_kernel/dependent_server/llm_for_local_serve/tokenizer_assets/tokenizer_config.json index 3f10a1f2..a84dfb7d 100755 --- a/qanything_kernel/dependent_server/llm_for_local_serve/tokenizer_assets/tokenizer_config.json +++ b/qanything_kernel/dependent_server/llm_for_local_serve/tokenizer_assets/tokenizer_config.json @@ -1,10 +1,10 @@ { "remove_space": false, "do_lower_case": false, - "tokenizer_class": "QwenTokenizer", + "tokenizer_class": "QWenTokenizer", "auto_map": { "AutoTokenizer": [ - "tokenization_qwen.QwenTokenizer", + "tokenization_qwen.QWenTokenizer", null ] } diff --git a/run.sh b/run.sh index db5f98db..ee1a107a 100644 --- a/run.sh +++ b/run.sh @@ -1,5 +1,80 @@ #!/bin/bash + +### If using OpenAI API, please export the following environments into .env fisrt +echo "OPENAI_API_KEY=''" > .env +echo "OPENAI_API_BASE=''" >> .env +echo "OPENAI_API_MODEL_NAME='gpt-3.5-turbo'" >> .env +echo "OPENAI_API_CONTEXT_LENGTH=4096" >> .env + +script_name=$(basename "$0") + +usage() { + echo "Usage: $script_name [-c ] [-i ] [-b ] [-m ] [-t ] [-p ] [-r ] [-h]" + echo " -c : Options {local, cloud} to specify the llm API mode, default is 'local'. If set to '-c cloud', please mannually set the environments {OPENAI_API_KEY, OPENAI_API_BASE, OPENAI_API_MODEL_NAME, OPENAI_API_CONTEXT_LENGTH} into .env fisrt in run.sh" + echo " -i : Specify argument GPU device_id" + echo " -b : Specify argument LLM inference runtime backend, options={default, hf, vllm}" + echo " -m : Specify argument the path to load LLM model using FastChat serve API, options={Qwen-7B-Chat, deepseek-llm-7b-chat, ...}" + echo " -t : Specify argument the conversation template according to the LLM model when using FastChat serve API, options={qwen-7b-chat, deepseek-chat, ...}" + echo " -p : Use options {1, 2} to set tensor parallel parameters for transformers/vllm backend when using FastChat serve API, default tensor_parallel=1" + echo " -r : Specify argument gpu_memory_utilization (0,1] for vllm backend when using FastChat serve API, default gpu_memory_utilization=0.81" + echo " -h: Display help usage message" + exit 1 +} + +llm_api="local" +device_id="0" +runtime_backend="default" +model_name="" +conv_template="" +tensor_parallel=1 +gpu_memory_utilization=0.81 + +# 解析命令行参数 +while getopts ":c:i:b:m:t:p:r:h" opt; do + case $opt in + c) llm_api=$OPTARG ;; + i) device_id=$OPTARG ;; + b) runtime_backend=$OPTARG ;; + m) model_name=$OPTARG ;; + t) conv_template=$OPTARG ;; + p) tensor_parallel=$OPTARG ;; + r) gpu_memory_utilization=$OPTARG ;; + h) usage ;; + *) usage ;; + esac +done + +if [ $llm_api = 'cloud' ]; then + echo "If set to '-c cloud', please mannually set the environments {OPENAI_API_KEY, OPENAI_API_BASE, OPENAI_API_MODEL_NAME, OPENAI_API_CONTEXT_LENGTH} into .env fisrt in run.sh" + read -p "Please enter OPENAI_API_KEY: " OPENAI_API_KEY + read -p "Please enter OPENAI_API_BASE: " OPENAI_API_BASE + read -p "Please enter OPENAI_API_MODEL_NAME: " OPENAI_API_MODEL_NAME + read -p "Please enter OPENAI_API_CONTEXT_LENGTH: " OPENAI_API_CONTEXT_LENGTH + + echo "OPENAI_API_KEY='$OPENAI_API_KEY'" > .env + echo "OPENAI_API_BASE='$OPENAI_API_BASE'" >> .env + echo "OPENAI_API_MODEL_NAME='$OPENAI_API_MODEL_NAME'" >> .env + echo "OPENAI_API_CONTEXT_LENGTH=$OPENAI_API_CONTEXT_LENGTH" >> .env +fi + +echo "llm_api is set to [$llm_api]" +echo "device_id is set to [$device_id]" +echo "runtime_backend is set to [$runtime_backend]" +echo "model_name is set to [$model_name]" +echo "conv_template is set to [$conv_template]" +echo "tensor_parallel is set to [$tensor_parallel]" +echo "gpu_memory_utilization is set to [$gpu_memory_utilization]" + +# 写入环境变量.env文件 +echo "LLM_API=${llm_api}" >> .env +echo "DEVICE_ID=$device_id" >> .env +echo "RUNTIME_BACKEND=$runtime_backend" >> .env +echo "MODEL_NAME=$model_name" >> .env +echo "CONV_TEMPLATE=$conv_template" >> .env +echo "TP=$tensor_parallel" >> .env +echo "GPU_MEM_UTILI=$gpu_memory_utilization" >> .env + # 检查是否存在 models 文件夹 if [ ! -d "models" ]; then echo "models 文件夹不存在,开始克隆和解压模型..." @@ -93,10 +168,11 @@ user_file="user.config" gpu_id1=0 gpu_id2=0 + # 判断命令行参数 -if [[ -n "$1" ]]; then +if [[ -n "$device_id" ]]; then # 如果传入参数,分割成两个GPU ID - IFS=',' read -ra gpu_ids <<< "$1" + IFS=',' read -ra gpu_ids <<< "$device_id" gpu_id1=${gpu_ids[0]} gpu_id2=${gpu_ids[1]:-$gpu_id1} # 如果没有第二个ID,则默认使用第一个ID fi @@ -107,6 +183,9 @@ if ! [[ $gpu_id1 =~ ^[0-9]+$ ]] || ! [[ $gpu_id2 =~ ^[0-9]+$ ]]; then exit 1 fi +echo "GPUID1=${gpu_id1}" >> .env +echo "GPUID2=${gpu_id2}" >> .env + # 检查是否存在用户文件 if [[ -f "$user_file" ]]; then # 读取上次的配置 @@ -144,11 +223,6 @@ fi # echo "The file $env_file has been updated with the following configuration:" # grep "VITE_APP_API_HOST" "$env_file" -# 创建.env文件并写入环境变量 -echo "GPUID1=${gpu_id1}" > .env -echo "GPUID2=${gpu_id2}" >> .env - - if [ -e /proc/version ]; then if grep -qi microsoft /proc/version || grep -qi MINGW /proc/version; then if grep -qi microsoft /proc/version; then diff --git a/scripts/run_for_cloud_option.sh b/scripts/run_for_cloud_option.sh new file mode 100755 index 00000000..3e9447ef --- /dev/null +++ b/scripts/run_for_cloud_option.sh @@ -0,0 +1,303 @@ +#!/bin/bash + +script_name=$(basename "$0") + +usage() { + echo "Usage: $script_name [-c ] [-i ] [-b ] [-m ] [-t ] [-p ] [-r ] [-h]" + echo " -c : Options {local, cloud} to specify the llm API mode, default is 'local'. If set to '-c cloud', please mannually set the environments {OPENAI_API_KEY, OPENAI_API_BASE, OPENAI_API_MODEL_NAME, OPENAI_API_CONTEXT_LENGTH} into .env fisrt in run.sh" + echo " -i : Specify argument GPU device_id" + echo " -b : Specify argument LLM inference runtime backend, options={default, hf, vllm}" + echo " -m : Specify argument the path to load LLM model using FastChat serve API, options={Qwen-7B-Chat, deepseek-llm-7b-chat, ...}" + echo " -t : Specify argument the conversation template according to the LLM model when using FastChat serve API, options={qwen-7b-chat, deepseek-chat, ...}" + echo " -p : Use options {1, 2} to set tensor parallel parameters for transformers/vllm backend when using FastChat serve API, default tensor_parallel=1" + echo " -r : Specify argument gpu_memory_utilization (0,1] for vllm backend when using FastChat serve API, default gpu_memory_utilization=0.81" + echo " -h: Display help usage message" + exit 1 +} + +llm_api="local" +device_id="0" +runtime_backend="default" +model_name="" +conv_template="" +tensor_parallel=1 +gpu_memory_utilization=0.81 + +# 解析命令行参数 +while getopts ":c:i:b:m:t:p:r:h" opt; do + case $opt in + c) llm_api=$OPTARG ;; + i) device_id=$OPTARG ;; + b) runtime_backend=$OPTARG ;; + m) model_name=$OPTARG ;; + t) conv_template=$OPTARG ;; + p) tensor_parallel=$OPTARG ;; + r) gpu_memory_utilization=$OPTARG ;; + h) usage ;; + *) usage ;; + esac +done + +echo "llm_api is set to [$llm_api]" +echo "device_id is set to [$device_id]" +echo "runtime_backend is set to [$runtime_backend]" +echo "model_name is set to [$model_name]" +echo "conv_template is set to [$conv_template]" +echo "tensor_parallel is set to [$tensor_parallel]" +echo "gpu_memory_utilization is set to [$gpu_memory_utilization]" + +check_folder_existence() { + if [ ! -d "/model_repos/CustomLLM/$1" ]; then + echo "The $1 folder does not exist under QAnything/assets/custom_models/. Please check your setup." + echo "在QAnything/assets/custom_models/下不存在$1文件夹。请检查您的模型文件。" + exit 1 + fi +} + +start_time=$(date +%s) # 记录开始时间 + +# cd ~/.cache/ && tar -xvf /workspace/qanything_local/pip_deps.tar +# cd /workspace/qanything_local/third_party/FastChat && pip install transformers==4.36.0 vllm==0.2.7 transformers-stream-generator==0.0.4 einops==0.6.0 accelerate==0.21.0 && pip install -e . + +# 获取默认的 MD5 校验和 +default_checksum=$(cat /workspace/qanything_local/third_party/checksum.config) + +# 计算FastChat文件夹下所有文件的 MD5 校验和 +checksum=$(find /workspace/qanything_local/third_party/FastChat -type f -exec md5sum {} + | awk '{print $1}' | sort | md5sum | awk '{print $1}') + +echo "checksum $checksum" +echo "default_checksum $default_checksum" +# 检查两个校验和是否相等,如果不相等则表示 third_party/FastChat/fastchat/conversation.py 注册了新的 conv_template, 则需重新安装依赖 +if [ "$default_checksum" != "$checksum" ]; then + cd /workspace/qanything_local/third_party/FastChat && pip install transformers==4.36.0 vllm==0.2.7 transformers-stream-generator==0.0.4 einops==0.6.0 accelerate==0.21.0 && pip install -e . + checksum=$(find /workspace/qanything_local/third_party/FastChat -type f -exec md5sum {} + | awk '{print $1}' | sort | md5sum | awk '{print $1}') && echo "$checksum" > /workspace/qanything_local/third_party/checksum.config +fi + +install_deps=$(pip list | grep vllm) +if [[ "$install_deps" != *"vllm"* ]]; then + echo "vllm deps not found" + cd /workspace/qanything_local/third_party/FastChat && pip install transformers==4.36.0 vllm==0.2.7 transformers-stream-generator==0.0.4 einops==0.6.0 accelerate==0.21.0 && pip install -e . + checksum=$(find /workspace/qanything_local/third_party/FastChat -type f -exec md5sum {} + | awk '{print $1}' | sort | md5sum | awk '{print $1}') && echo "$checksum" > /workspace/qanything_local/third_party/checksum.config +fi + +mkdir -p /model_repos/QAEnsemble_embed /model_repos/QAEnsemble_rerank /model_repos/monitor_logs +if [ ! -L "/model_repos/QAEnsemble_embed/embed" ]; then + cd /model_repos/QAEnsemble_embed && ln -s /model_repos/QAEnsemble/embed . +fi + +if [ ! -L "/model_repos/QAEnsemble_rerank/rerank" ]; then + cd /model_repos/QAEnsemble_rerank && ln -s /model_repos/QAEnsemble/rerank . +fi + +# 设置默认值 +default_gpu_id1=0 +default_gpu_id2=0 + +# 检查环境变量GPUID1是否存在,并读取其值或使用默认值 +if [ -z "${GPUID1}" ]; then + gpuid1=$default_gpu_id1 +else + gpuid1=${GPUID1} +fi + +# 检查环境变量GPUID2是否存在,并读取其值或使用默认值 +if [ -z "${GPUID2}" ]; then + gpuid2=$default_gpu_id2 +else + gpuid2=${GPUID2} +fi +echo "GPU ID: $gpuid1, $gpuid2" + +echo "The triton server will start on $gpuid1 GPU" + +CUDA_VISIBLE_DEVICES=$gpuid1 nohup /opt/tritonserver/bin/tritonserver --model-store=/model_repos/QAEnsemble_embed --http-port=9000 --grpc-port=9001 --metrics-port=9002 --log-verbose=1 > /model_repos/QAEnsemble_embed/QAEnsemble_embed.log 2>&1 & +CUDA_VISIBLE_DEVICES=$gpuid1 nohup /opt/tritonserver/bin/tritonserver --model-store=/model_repos/QAEnsemble_rerank --http-port=8000 --grpc-port=8001 --metrics-port=8002 --log-verbose=1 > /model_repos/QAEnsemble_rerank/QAEnsemble_rerank.log 2>&1 & + +echo "RERANK_PORT=8001" >> /workspace/qanything_local/.env +echo "EMBED_PORT=9001" >> /workspace/qanything_local/.env + +# 默认ocr_use_gpu为True +OCR_USE_GPU="True" + +# 使用nvidia-smi命令获取GPU的显存大小(以MiB为单位) +GPU1_MEMORY_SIZE=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits -i $gpuid1) +GPU2_MEMORY_SIZE=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits -i $gpuid2) + +# 检查显存大小是否小于12G(即 12288 MiB) +if [ "$GPU2_MEMORY_SIZE" -lt 12288 ]; then + OCR_USE_GPU="False" +fi + +echo "GPU1_MEMORY_SIZE=$GPU1_MEMORY_SIZE" +echo "OCR_USE_GPU=$OCR_USE_GPU" + +echo "====================================================" +echo "******************** 重要提示 ********************" +echo "====================================================" +echo "" + +if [ "$GPU1_MEMORY_SIZE" -lt 8100 ]; then + echo "检测到您的 GPU 显存小于 8GB,推荐使用 OpenAI 或其他在线大型语言模型 (LLM)。" +elif [ "$GPU1_MEMORY_SIZE" -ge 8100 ] && [ "$GPU1_MEMORY_SIZE" -le 16400 ]; then + echo "检测到您的 GPU 显存在 8GB 到 16GB 之间,推荐使用本地 3B 大小以内的语言模型。" +else + echo "检测到您的 GPU 显存大于 16GB,推荐使用本地 7B 的语言模型。" +fi + +echo "" +echo "====================================================" +echo "请根据您的显存情况选择合适的语言模型以获得最佳性能。" +echo "====================================================" +echo "" +sleep 5 + + + +cd /workspace/qanything_local || exit +nohup python3 -u qanything_kernel/dependent_server/rerank_for_local_serve/rerank_server.py > rerank.log 2>&1 & +echo "The rerank service is ready! (2/8)" +echo "rerank服务已就绪! (2/8)" + +CUDA_VISIBLE_DEVICES=$gpuid2 nohup python3 -u qanything_kernel/dependent_server/ocr_serve/ocr_server.py $OCR_USE_GPU > ocr.log 2>&1 & +echo "The ocr service is ready! (3/8)" +echo "OCR服务已就绪! (3/8)" + +nohup python3 -u qanything_kernel/qanything_server/sanic_api.py --mode "online" > api.log 2>&1 & +echo "The qanything backend service is ready! (4/8)" +echo "qanything后端服务已就绪! (4/8)" + + +timeout_time=300 # npm下载超时时间300秒,triton_server启动超时时间600秒 + +env_file="/workspace/qanything_local/front_end/.env.production" +user_file="/workspace/qanything_local/user.config" +user_ip=$(cat "$user_file") +# 读取env_file的第一行 +current_host=$(grep VITE_APP_API_HOST "$env_file") +user_host="VITE_APP_API_HOST=http://$user_ip:8777" +# 检查current_host与user_host是否相同 +if [ "$current_host" != "$user_host" ]; then + # 使用 sed 命令更新 VITE_APP_API_HOST 的值 + sed -i "s|VITE_APP_API_HOST=.*|$user_host|" "$env_file" + echo "The file $env_file has been updated with the following configuration:" + grep "VITE_APP_API_HOST" "$env_file" +fi + +# 转到 front_end 目录 +cd /workspace/qanything_local/front_end || exit +# 安装依赖 +echo "Waiting for [npm run install](5/8)" +timeout $timeout_time npm install +if [ $? -eq 0 ]; then + echo "[npm run install] Installed successfully(5/8)" +elif [ $? -eq 124 ]; then + echo "npm install 下载超时,可能是网络问题,请修改 npm 代理。" + exit 1 +else + echo "Failed to install npm dependencies." + exit 1 +fi + +# 构建前端项目 +echo "Waiting for [npm run build](6/8)" +npm run build +if [ $? -eq 0 ]; then + echo "[npm run build] build successfully(6/8)" +else + echo "Failed to build the front end." + exit 1 +fi + +# 启动前端页面服务 +nohup npm run serve 1>npm_server.log 2>&1 & + +# 监听前端页面服务 +tail -f npm_server.log & + +front_end_start_time=$(date +%s) + +while ! grep -q "Local:" npm_server.log; do + echo "Waiting for the front-end service to start..." + echo "等待启动前端服务" + sleep 1 + + # 获取当前时间并计算经过的时间 + current_time=$(date +%s) + elapsed_time=$((current_time - front_end_start_time)) + + # 检查是否超时 + if [ $elapsed_time -ge 120 ]; then + echo "启动前端服务超时,请检查日志文件 front_end/npm_server.log 获取更多信息。" + exit 1 + fi +done +echo "The front-end service is ready!...(7/8)" +echo "前端服务已就绪!...(7/8)" + +current_time=$(date +%s) +elapsed=$((current_time - start_time)) # 计算经过的时间(秒) +echo "Time elapsed: ${elapsed} seconds." +echo "已耗时: ${elapsed} 秒." + +while true; do + + if [ "$runtime_backend" = "default" ]; then + response_embed=$(curl -s -w "%{http_code}" http://localhost:9000/v2/health/ready -o /dev/null) + response_rerank=$(curl -s -w "%{http_code}" http://localhost:8000/v2/health/ready -o /dev/null) + echo "health response_embed = $response_embed" + echo "health response_rerank = $response_rerank" + + if [ $response_embed -eq 200 ] && [ $response_rerank -eq 200 ]; then + echo "The llm service is ready!, now you can use the qanything service. (8/8)" + echo "LLM 服务已准备就绪!现在您可以使用qanything服务。(8/8)" + break + else + echo "The llm service is starting up, it can be long... you have time to make a coffee :)" + echo "LLM 服务正在启动,可能需要一段时间...你有时间去冲杯咖啡 :)" + + current_time=$(date +%s) + elapsed_time=$((current_time - start_time)) + + # 检查是否超时 + if [ $elapsed_time -ge $((timeout_time * 2)) ]; then + echo "启动 LLM 服务超时,请进入容器内检查/model_repos/QAEnsemble_base/QAEnsemble_base.log以获取更多信息。" + exit 1 + fi + sleep 5 + fi + else + fastchat_register_model=$(curl --request POST --url http://localhost:7800/list_models) + if [[ $fastchat_register_model == *"$LLM_API_SERVE_MODEL"* ]]; then + echo "The llm service is ready!, now you can use the qanything service. (8/8)" + echo "LLM 服务已准备就绪!现在您可以使用qanything服务。(8/8)" + break + else + echo "The llm service is starting up, it can be long... you have time to make a coffee :)" + echo "LLM 服务正在启动,可能需要一段时间...你有时间去冲杯咖啡 :)" + + current_time=$(date +%s) + elapsed_time=$((current_time - start_time)) + + # 检查是否超时 + if [ $elapsed_time -ge $((timeout_time * 2)) ]; then + echo "启动 LLM 服务超时,请进入容器内检查/workspace/qanything_local/qanything_logs/fastchat_logs/fschat_model_worker_7801.log以获取更多信息。" + exit 1 + fi + sleep 5 + fi + fi +done + +current_time=$(date +%s) +elapsed=$((current_time - start_time)) # 计算经过的时间(秒) +echo "Time elapsed: ${elapsed} seconds." +echo "已耗时: ${elapsed} 秒." +echo "Please visit the front-end service at [http://$user_ip:5052/qanything/] to conduct Q&A." +echo "请在[http://$user_ip:5052/qanything/]下访问前端服务来进行问答,如果前端报错,请在浏览器按F12以获取更多报错信息" + +# 保持容器运行 +while true; do + sleep 2 +done + + diff --git a/scripts/run_for_local_fastchat.sh b/scripts/run_for_local_fastchat.sh new file mode 100755 index 00000000..e660a823 --- /dev/null +++ b/scripts/run_for_local_fastchat.sh @@ -0,0 +1,203 @@ +#!/bin/bash + +start_time=$(date +%s) # 记录开始时间 +# rm -rf ~/.cache/pip && cd ~/.cache/ && cp /workspace/qanything_local/pip_deps.zip . && unzip /workspace/qanything_local/pip_deps.zip && mv root/.cache/pip . && rm -rf root +# rm -rf ~/.cache/pip && cd ~/.cache/ && tar -xvf /workspace/qanything_local/pip_deps.tar +# cd ~/.cache/ && tar -xvf /workspace/qanything_local/pip_deps.tar +cd /workspace/qanything_local/third_party/FastChat && pip install transformers==4.36.0 vllm==0.2.7 transformers-stream-generator==0.0.4 einops==0.6.0 accelerate==0.21.0 && pip install -e . + +mkdir -p /model_repos/QAEnsemble_base /model_repos/QAEnsemble_embed /model_repos/QAEnsemble_rerank /model_repos/monitor_logs +if [ ! -L "/model_repos/QAEnsemble_base/base" ]; then + cd /model_repos/QAEnsemble_base && ln -s /model_repos/QAEnsemble/base . +fi + +if [ ! -L "/model_repos/QAEnsemble_embed/embed" ]; then + cd /model_repos/QAEnsemble_embed && ln -s /model_repos/QAEnsemble/embed . +fi + +if [ ! -L "/model_repos/QAEnsemble_rerank/rerank" ]; then + cd /model_repos/QAEnsemble_rerank && ln -s /model_repos/QAEnsemble/rerank . +fi + +# 设置默认值 +default_gpu_id1=0 +default_gpu_id2=0 + +# 检查环境变量GPUID1是否存在,并读取其值或使用默认值 +if [ -z "${GPUID1}" ]; then + gpuid1=$default_gpu_id1 +else + gpuid1=${GPUID1} +fi + +# 检查环境变量GPUID2是否存在,并读取其值或使用默认值 +if [ -z "${GPUID2}" ]; then + gpuid2=$default_gpu_id2 +else + gpuid2=${GPUID2} +fi +echo "GPU ID: $gpuid1, $gpuid2" + +# LLM_API_SERVE_CONV_TEMPLATE="qwen-7b-qanything" +# LLM_API_SERVE_MODEL="Qwen-7B-Chat-8K" +LLM_API_SERVE_CONV_TEMPLATE="minichat" +LLM_API_SERVE_MODEL="MiniMA-d0119-v100-6k" +echo "LLM_API_SERVE_PORT=7802" >> /workspace/qanything_local/.env +echo "LLM_API_SERVE_MODEL=$LLM_API_SERVE_MODEL" >> /workspace/qanything_local/.env +echo "LLM_API_SERVE_CONV_TEMPLATE=$LLM_API_SERVE_CONV_TEMPLATE" >> /workspace/qanything_local/.env + +# CUDA_VISIBLE_DEVICES=0 nohup python3 -m fastchat.serve.model_worker --host 0.0.0.0 --port 7801 --controller-address http://0.0.0.0:7800 --worker-address http://0.0.0.0:7801 --model-path /model_repos/QAEnsemble/Qwen-7B-Chat-8K --load-8bit --dtype bfloat16 --conv-template qwen-7b-qanything > fschat_model_worker_7801.log 2>&1 & +mkdir -p /workspace/qanything_local/qanything_logs/fastchat_logs && cd /workspace/qanything_local/qanything_logs/fastchat_logs +nohup python3 -m fastchat.serve.controller --host 0.0.0.0 --port 7800 > fschat_controller_7800.log 2>&1 & +CUDA_VISIBLE_DEVICES=$gpuid1 nohup python3 -m fastchat.serve.model_worker --host 0.0.0.0 --port 7801 \ + --controller-address http://0.0.0.0:7800 --worker-address http://0.0.0.0:7801 \ + --model-path /model_repos/QAEnsemble/$LLM_API_SERVE_MODEL --load-8bit \ + --dtype bfloat16 --conv-template $LLM_API_SERVE_CONV_TEMPLATE > fschat_model_worker_7801.log 2>&1 & +nohup python3 -m fastchat.serve.openai_api_server --host 0.0.0.0 --port 7802 --controller-address http://0.0.0.0:7800 > fschat_openai_api_server_7802.log 2>&1 & + +### curl http://localhost:7802/v1/models + +# start llm server +# 判断一下,如果gpu_id1和gpu_id2相同,则只启动一个triton_server +if [ $gpuid1 -eq $gpuid2 ]; then + echo "The triton server will start on $gpuid1 GPU" + CUDA_VISIBLE_DEVICES=$gpuid1 nohup /opt/tritonserver/bin/tritonserver --model-store=/model_repos/QAEnsemble --http-port=10000 --grpc-port=10001 --metrics-port=10002 --log-verbose=1 > /model_repos/QAEnsemble/QAEnsemble.log 2>&1 & + echo "RERANK_PORT=10001" >> /workspace/qanything_local/.env + echo "EMBED_PORT=10001" >> /workspace/qanything_local/.env +else + echo "The triton server will start on $gpuid1 and $gpuid2 GPUs" + + # CUDA_VISIBLE_DEVICES=$gpuid1 nohup /opt/tritonserver/bin/tritonserver --model-store=/model_repos/QAEnsemble_base --http-port=10000 --grpc-port=10001 --metrics-port=10002 --log-verbose=1 > /model_repos/QAEnsemble_base/QAEnsemble_base.log 2>&1 & + CUDA_VISIBLE_DEVICES=$gpuid2 nohup /opt/tritonserver/bin/tritonserver --model-store=/model_repos/QAEnsemble_embed --http-port=9000 --grpc-port=9001 --metrics-port=9002 --log-verbose=1 > /model_repos/QAEnsemble_embed/QAEnsemble_embed.log 2>&1 & + CUDA_VISIBLE_DEVICES=$gpuid2 nohup /opt/tritonserver/bin/tritonserver --model-store=/model_repos/QAEnsemble_rerank --http-port=8000 --grpc-port=8001 --metrics-port=8002 --log-verbose=1 > /model_repos/QAEnsemble_rerank/QAEnsemble_rerank.log 2>&1 & + + echo "RERANK_PORT=8001" >> /workspace/qanything_local/.env + echo "EMBED_PORT=9001" >> /workspace/qanything_local/.env +fi + +# cd /workspace/qanything_local/qanything_kernel/dependent_server/llm_for_local_serve || exit +# nohup python3 -u llm_server_entrypoint.py --host="0.0.0.0" --port=36001 --model-path="tokenizer_assets" --model-url="0.0.0.0:10001" > llm.log 2>&1 & +# echo "The llm transfer service is ready! (1/8)" +# echo "大模型中转服务已就绪! (1/8)" + +cd /workspace/qanything_local || exit +nohup python3 -u qanything_kernel/dependent_server/rerank_for_local_serve/rerank_server.py > rerank.log 2>&1 & +echo "The rerank service is ready! (2/8)" +echo "rerank服务已就绪! (2/8)" + +CUDA_VISIBLE_DEVICES=$gpuid2 nohup python3 -u qanything_kernel/dependent_server/ocr_serve/ocr_server.py > ocr.log 2>&1 & +echo "The ocr service is ready! (3/8)" +echo "OCR服务已就绪! (3/8)" + +nohup python3 -u qanything_kernel/qanything_server/sanic_api.py > api.log 2>&1 & +echo "The qanything backend service is ready! (4/8)" +echo "qanything后端服务已就绪! (4/8)" + + +timeout_time=300 # npm下载超时时间300秒,triton_server启动超时时间600秒 + +env_file="/workspace/qanything_local/front_end/.env.production" +user_file="/workspace/qanything_local/user.config" +user_ip=$(cat "$user_file") +# 读取env_file的第一行 +current_host=$(grep VITE_APP_API_HOST "$env_file") +user_host="VITE_APP_API_HOST=http://$user_ip:8777" +# 检查current_host与user_host是否相同 +#if [ "$current_host" != "$user_host" ]; then +# # 使用 sed 命令更新 VITE_APP_API_HOST 的值 +# sed -i "s|VITE_APP_API_HOST=.*|$user_host|" "$env_file" +# echo "The file $env_file has been updated with the following configuration:" +# grep "VITE_APP_API_HOST" "$env_file" +#fi + +# 转到 front_end 目录 +cd /workspace/qanything_local/front_end || exit +# 安装依赖 +echo "Waiting for [npm run install](5/8)" +timeout $timeout_time npm install +if [ $? -eq 0 ]; then + echo "[npm run install] Installed successfully(5/8)" +elif [ $? -eq 124 ]; then + echo "npm install 下载超时,可能是网络问题,请修改 npm 代理。" + exit 1 +else + echo "Failed to install npm dependencies." + exit 1 +fi + +# 构建前端项目 +echo "Waiting for [npm run build](6/8)" +npm run build +if [ $? -eq 0 ]; then + echo "[npm run build] build successfully(6/8)" +else + echo "Failed to build the front end." + exit 1 +fi + +# 启动前端页面服务 +nohup npm run serve 1>npm_server.log 2>&1 & + +# 监听前端页面服务 +tail -f npm_server.log & + +front_end_start_time=$(date +%s) + +while ! grep -q "Local:" npm_server.log; do + echo "Waiting for the front-end service to start..." + echo "等待启动前端服务" + sleep 1 + + # 获取当前时间并计算经过的时间 + current_time=$(date +%s) + elapsed_time=$((current_time - front_end_start_time)) + + # 检查是否超时 + if [ $elapsed_time -ge 120 ]; then + echo "启动前端服务超时,请检查日志文件 front_end/npm_server.log 获取更多信息。" + exit 1 + fi +done +echo "The front-end service is ready!...(7/8)" +echo "前端服务已就绪!...(7/8)" + +current_time=$(date +%s) +elapsed=$((current_time - start_time)) # 计算经过的时间(秒) +echo "Time elapsed: ${elapsed} seconds." +echo "已耗时: ${elapsed} 秒." + +# while true; do +# response=$(curl -s -w "%{http_code}" http://localhost:10000/v2/health/ready -o /dev/null) +# if [ $response -eq 200 ]; then +# echo "The triton service is ready!, now you can use the qanything service. (8/8)" +# echo "Triton服务已准备就绪!现在您可以使用qanything服务。(8/8)" +# break +# else +# echo "The triton service is starting up, it can be long... you have time to make a coffee :)" +# echo "Triton服务正在启动,可能需要一段时间...你有时间去冲杯咖啡 :)" + +# current_time=$(date +%s) +# elapsed_time=$((current_time - start_time)) + +# # 检查是否超时 +# if [ $elapsed_time -ge $((timeout_time * 2)) ]; then +# echo "启动Triton服务超时,请进入容器内检查/model_repos/QAEnsemble_base/QAEnsemble_base.log以获取更多信息。" +# exit 1 +# fi +# sleep 5 +# fi +# done + +current_time=$(date +%s) +elapsed=$((current_time - start_time)) # 计算经过的时间(秒) +echo "Time elapsed: ${elapsed} seconds." +echo "已耗时: ${elapsed} 秒." +echo "Please visit the front-end service at [http://$user_ip:5052/qanything/] to conduct Q&A." +echo "请在[http://$user_ip:5052/qanything/]下访问前端服务来进行问答,如果前端报错,请在浏览器按F12以获取更多报错信息" + +# 保持容器运行 +while true; do + sleep 2 +done + + diff --git a/scripts/run_for_local_option.sh b/scripts/run_for_local_option.sh new file mode 100755 index 00000000..4564d32d --- /dev/null +++ b/scripts/run_for_local_option.sh @@ -0,0 +1,391 @@ +#!/bin/bash + +script_name=$(basename "$0") + +usage() { + echo "Usage: $script_name [-c ] [-i ] [-b ] [-m ] [-t ] [-p ] [-r ] [-h]" + echo " -c : Options {local, cloud} to specify the llm API mode, default is 'local'. If set to '-c cloud', please mannually set the environments {OPENAI_API_KEY, OPENAI_API_BASE, OPENAI_API_MODEL_NAME, OPENAI_API_CONTEXT_LENGTH} into .env fisrt in run.sh" + echo " -i : Specify argument GPU device_id" + echo " -b : Specify argument LLM inference runtime backend, options={default, hf, vllm}" + echo " -m : Specify argument the path to load LLM model using FastChat serve API, options={Qwen-7B-Chat, deepseek-llm-7b-chat, ...}" + echo " -t : Specify argument the conversation template according to the LLM model when using FastChat serve API, options={qwen-7b-chat, deepseek-chat, ...}" + echo " -p : Use options {1, 2} to set tensor parallel parameters for transformers/vllm backend when using FastChat serve API, default tensor_parallel=1" + echo " -r : Specify argument gpu_memory_utilization (0,1] for vllm backend when using FastChat serve API, default gpu_memory_utilization=0.81" + echo " -h: Display help usage message" + exit 1 +} + +llm_api="local" +device_id="0" +runtime_backend="default" +model_name="" +conv_template="" +tensor_parallel=1 +gpu_memory_utilization=0.81 + +# 解析命令行参数 +while getopts ":c:i:b:m:t:p:r:h" opt; do + case $opt in + c) llm_api=$OPTARG ;; + i) device_id=$OPTARG ;; + b) runtime_backend=$OPTARG ;; + m) model_name=$OPTARG ;; + t) conv_template=$OPTARG ;; + p) tensor_parallel=$OPTARG ;; + r) gpu_memory_utilization=$OPTARG ;; + h) usage ;; + *) usage ;; + esac +done + +echo "llm_api is set to [$llm_api]" +echo "device_id is set to [$device_id]" +echo "runtime_backend is set to [$runtime_backend]" +echo "model_name is set to [$model_name]" +echo "conv_template is set to [$conv_template]" +echo "tensor_parallel is set to [$tensor_parallel]" +echo "gpu_memory_utilization is set to [$gpu_memory_utilization]" + +check_folder_existence() { + if [ ! -d "/model_repos/CustomLLM/$1" ]; then + echo "The $1 folder does not exist under QAnything/assets/custom_models/. Please check your setup." + echo "在QAnything/assets/custom_models/下不存在$1文件夹。请检查您的模型文件。" + exit 1 + fi +} + +start_time=$(date +%s) # 记录开始时间 + +# cd ~/.cache/ && tar -xvf /workspace/qanything_local/pip_deps.tar +# cd /workspace/qanything_local/third_party/FastChat && pip install transformers==4.36.0 vllm==0.2.7 transformers-stream-generator==0.0.4 einops==0.6.0 accelerate==0.21.0 && pip install -e . + +# 获取默认的 MD5 校验和 +default_checksum=$(cat /workspace/qanything_local/third_party/checksum.config) +# 计算FastChat文件夹下所有文件的 MD5 校验和 +checksum=$(find /workspace/qanything_local/third_party/FastChat -type f -exec md5sum {} + | awk '{print $1}' | sort | md5sum | awk '{print $1}') +echo "checksum $checksum" +echo "default_checksum $default_checksum" +# 检查两个校验和是否相等,如果不相等则表示 third_party/FastChat/fastchat/conversation.py 注册了新的 conv_template, 则需重新安装依赖 +if [ "$default_checksum" != "$checksum" ]; then + cd /workspace/qanything_local/third_party/FastChat && pip install transformers==4.36.0 vllm==0.2.7 transformers-stream-generator==0.0.4 einops==0.6.0 accelerate==0.21.0 && pip install -e . + checksum=$(find /workspace/qanything_local/third_party/FastChat -type f -exec md5sum {} + | awk '{print $1}' | sort | md5sum | awk '{print $1}') && echo "$checksum" > /workspace/qanything_local/third_party/checksum.config +fi + +install_deps=$(pip list | grep vllm) +if [[ "$install_deps" != *"vllm"* ]]; then + echo "vllm deps not found" + cd /workspace/qanything_local/third_party/FastChat && pip install transformers==4.36.0 vllm==0.2.7 transformers-stream-generator==0.0.4 einops==0.6.0 accelerate==0.21.0 && pip install -e . + checksum=$(find /workspace/qanything_local/third_party/FastChat -type f -exec md5sum {} + | awk '{print $1}' | sort | md5sum | awk '{print $1}') && echo "$checksum" > /workspace/qanything_local/third_party/checksum.config +fi + +mkdir -p /model_repos/QAEnsemble_base /model_repos/QAEnsemble_embed /model_repos/QAEnsemble_rerank /model_repos/monitor_logs +if [ ! -L "/model_repos/QAEnsemble_base/base" ]; then + cd /model_repos/QAEnsemble_base && ln -s /model_repos/QAEnsemble/base . +fi + +if [ ! -L "/model_repos/QAEnsemble_embed/embed" ]; then + cd /model_repos/QAEnsemble_embed && ln -s /model_repos/QAEnsemble/embed . +fi + +if [ ! -L "/model_repos/QAEnsemble_rerank/rerank" ]; then + cd /model_repos/QAEnsemble_rerank && ln -s /model_repos/QAEnsemble/rerank . +fi + +# 设置默认值 +default_gpu_id1=0 +default_gpu_id2=0 + +# 检查环境变量GPUID1是否存在,并读取其值或使用默认值 +if [ -z "${GPUID1}" ]; then + gpuid1=$default_gpu_id1 +else + gpuid1=${GPUID1} +fi + +# 检查环境变量GPUID2是否存在,并读取其值或使用默认值 +if [ -z "${GPUID2}" ]; then + gpuid2=$default_gpu_id2 +else + gpuid2=${GPUID2} +fi +echo "GPU ID: $gpuid1, $gpuid2" + + +if [ "$runtime_backend" = "default" ]; then + echo "Executing default FastTransformer runtime_backend" + # start llm server + # 判断一下,如果gpu_id1和gpu_id2相同,则只启动一个triton_server + if [ $gpuid1 -eq $gpuid2 ]; then + echo "The triton server will start on $gpuid1 GPU" + CUDA_VISIBLE_DEVICES=$gpuid1 nohup /opt/tritonserver/bin/tritonserver --model-store=/model_repos/QAEnsemble --http-port=10000 --grpc-port=10001 --metrics-port=10002 --log-verbose=1 > /model_repos/QAEnsemble/QAEnsemble.log 2>&1 & + echo "RERANK_PORT=10001" >> /workspace/qanything_local/.env + echo "EMBED_PORT=10001" >> /workspace/qanything_local/.env + else + echo "The triton server will start on $gpuid1 and $gpuid2 GPUs" + + CUDA_VISIBLE_DEVICES=$gpuid1 nohup /opt/tritonserver/bin/tritonserver --model-store=/model_repos/QAEnsemble_base --http-port=10000 --grpc-port=10001 --metrics-port=10002 --log-verbose=1 > /model_repos/QAEnsemble_base/QAEnsemble_base.log 2>&1 & + CUDA_VISIBLE_DEVICES=$gpuid2 nohup /opt/tritonserver/bin/tritonserver --model-store=/model_repos/QAEnsemble_embed --http-port=9000 --grpc-port=9001 --metrics-port=9002 --log-verbose=1 > /model_repos/QAEnsemble_embed/QAEnsemble_embed.log 2>&1 & + CUDA_VISIBLE_DEVICES=$gpuid2 nohup /opt/tritonserver/bin/tritonserver --model-store=/model_repos/QAEnsemble_rerank --http-port=8000 --grpc-port=8001 --metrics-port=8002 --log-verbose=1 > /model_repos/QAEnsemble_rerank/QAEnsemble_rerank.log 2>&1 & + + echo "RERANK_PORT=8001" >> /workspace/qanything_local/.env + echo "EMBED_PORT=9001" >> /workspace/qanything_local/.env + fi + + cd /workspace/qanything_local/qanything_kernel/dependent_server/llm_for_local_serve || exit + nohup python3 -u llm_server_entrypoint.py --host="0.0.0.0" --port=36001 --model-path="tokenizer_assets" --model-url="0.0.0.0:10001" > llm.log 2>&1 & + echo "The llm transfer service is ready! (1/8)" + echo "大模型中转服务已就绪! (1/8)" +else + echo "The triton server for embedding and reranker will start on $gpuid2 GPUs" + CUDA_VISIBLE_DEVICES=$gpuid2 nohup /opt/tritonserver/bin/tritonserver --model-store=/model_repos/QAEnsemble_embed --http-port=9000 --grpc-port=9001 --metrics-port=9002 --log-verbose=1 > /model_repos/QAEnsemble_embed/QAEnsemble_embed.log 2>&1 & + CUDA_VISIBLE_DEVICES=$gpuid2 nohup /opt/tritonserver/bin/tritonserver --model-store=/model_repos/QAEnsemble_rerank --http-port=8000 --grpc-port=8001 --metrics-port=8002 --log-verbose=1 > /model_repos/QAEnsemble_rerank/QAEnsemble_rerank.log 2>&1 & + echo "RERANK_PORT=8001" >> /workspace/qanything_local/.env + echo "EMBED_PORT=9001" >> /workspace/qanything_local/.env + + LLM_API_SERVE_CONV_TEMPLATE="$conv_template" + LLM_API_SERVE_MODEL="$model_name" + + check_folder_existence "$LLM_API_SERVE_MODEL" + + echo "LLM_API_SERVE_PORT=7802" >> /workspace/qanything_local/.env + echo "LLM_API_SERVE_MODEL=$LLM_API_SERVE_MODEL" >> /workspace/qanything_local/.env + echo "LLM_API_SERVE_CONV_TEMPLATE=$LLM_API_SERVE_CONV_TEMPLATE" >> /workspace/qanything_local/.env + + mkdir -p /workspace/qanything_local/qanything_logs/fastchat_logs && cd /workspace/qanything_local/qanything_logs/fastchat_logs + nohup python3 -m fastchat.serve.controller --host 0.0.0.0 --port 7800 > fschat_controller_7800.log 2>&1 & + nohup python3 -m fastchat.serve.openai_api_server --host 0.0.0.0 --port 7802 --controller-address http://0.0.0.0:7800 > fschat_openai_api_server_7802.log 2>&1 & + + gpus=$tensor_parallel + if [ $tensor_parallel -eq 2 ]; then + gpus="$gpuid1,$gpuid2" + else + gpus="$gpuid1" + fi + + case $runtime_backend in + "hf") + echo "Executing hf runtime_backend" + + CUDA_VISIBLE_DEVICES=$gpus nohup python3 -m fastchat.serve.model_worker --host 0.0.0.0 --port 7801 \ + --controller-address http://0.0.0.0:7800 --worker-address http://0.0.0.0:7801 \ + --model-path /model_repos/CustomLLM/$LLM_API_SERVE_MODEL --load-8bit \ + --gpus $gpus --num-gpus $tensor_parallel --dtype bfloat16 --conv-template $LLM_API_SERVE_CONV_TEMPLATE > fschat_model_worker_7801.log 2>&1 & + + ;; + "vllm") + echo "Executing vllm runtime_backend" + + CUDA_VISIBLE_DEVICES=$gpus nohup python3 -m fastchat.serve.vllm_worker --host 0.0.0.0 --port 7801 \ + --controller-address http://0.0.0.0:7800 --worker-address http://0.0.0.0:7801 \ + --model-path /model_repos/CustomLLM/$LLM_API_SERVE_MODEL --trust-remote-code --block-size 32 --tensor-parallel-size $tensor_parallel \ + --max-model-len 4096 --gpu-memory-utilization $gpu_memory_utilization --dtype bfloat16 --conv-template $LLM_API_SERVE_CONV_TEMPLATE > fschat_model_worker_7801.log 2>&1 & + + ;; + "sglang") + echo "Executing sglang runtime_backend" + ;; + *) + echo "Invalid runtime_backend option"; exit 1 + ;; + esac +fi + +# 默认ocr_use_gpu为True +OCR_USE_GPU="True" + +# 使用nvidia-smi命令获取GPU的显存大小(以MiB为单位) +GPU1_MEMORY_SIZE=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits -i $gpuid1) +GPU2_MEMORY_SIZE=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits -i $gpuid2) + +# 检查显存大小是否小于12G(即 12288 MiB) +if [ "$GPU2_MEMORY_SIZE" -lt 12288 ]; then + OCR_USE_GPU="False" +fi + +echo "GPU1_MEMORY_SIZE=$GPU1_MEMORY_SIZE" +echo "OCR_USE_GPU=$OCR_USE_GPU" + +echo "====================================================" +echo "******************** 重要提示 ********************" +echo "====================================================" +echo "" + +if [ "$GPU1_MEMORY_SIZE" -lt 8100 ]; then + echo "检测到您的 GPU 显存小于 8GB,推荐使用 OpenAI 或其他在线大型语言模型 (LLM)。" +elif [ "$GPU1_MEMORY_SIZE" -ge 8100 ] && [ "$GPU1_MEMORY_SIZE" -le 16400 ]; then + echo "检测到您的 GPU 显存在 8GB 到 16GB 之间,推荐使用本地 3B 大小以内的语言模型。" +else + echo "检测到您的 GPU 显存大于 16GB,推荐使用本地 7B 的语言模型。" +fi + +echo "" +echo "====================================================" +echo "请根据您的显存情况选择合适的语言模型以获得最佳性能。" +echo "====================================================" +echo "" +sleep 5 + + +cd /workspace/qanything_local || exit +nohup python3 -u qanything_kernel/dependent_server/rerank_for_local_serve/rerank_server.py > rerank.log 2>&1 & +echo "The rerank service is ready! (2/8)" +echo "rerank服务已就绪! (2/8)" + +CUDA_VISIBLE_DEVICES=$gpuid2 nohup python3 -u qanything_kernel/dependent_server/ocr_serve/ocr_server.py $OCR_USE_GPU > ocr.log 2>&1 & +echo "The ocr service is ready! (3/8)" +echo "OCR服务已就绪! (3/8)" + +nohup python3 -u qanything_kernel/qanything_server/sanic_api.py --mode "local" > api.log 2>&1 & +echo "The qanything backend service is ready! (4/8)" +echo "qanything后端服务已就绪! (4/8)" + + +timeout_time=300 # npm下载超时时间300秒,triton_server启动超时时间600秒 + +env_file="/workspace/qanything_local/front_end/.env.production" +user_file="/workspace/qanything_local/user.config" +user_ip=$(cat "$user_file") +# 读取env_file的第一行 +current_host=$(grep VITE_APP_API_HOST "$env_file") +user_host="VITE_APP_API_HOST=http://$user_ip:8777" +# 检查current_host与user_host是否相同 +if [ "$current_host" != "$user_host" ]; then + # 使用 sed 命令更新 VITE_APP_API_HOST 的值 + sed -i "s|VITE_APP_API_HOST=.*|$user_host|" "$env_file" + echo "The file $env_file has been updated with the following configuration:" + grep "VITE_APP_API_HOST" "$env_file" +fi + +# 转到 front_end 目录 +cd /workspace/qanything_local/front_end || exit +# 安装依赖 +echo "Waiting for [npm run install](5/8)" +timeout $timeout_time npm install +if [ $? -eq 0 ]; then + echo "[npm run install] Installed successfully(5/8)" +elif [ $? -eq 124 ]; then + echo "npm install 下载超时,可能是网络问题,请修改 npm 代理。" + exit 1 +else + echo "Failed to install npm dependencies." + exit 1 +fi + +# 构建前端项目 +echo "Waiting for [npm run build](6/8)" +npm run build +if [ $? -eq 0 ]; then + echo "[npm run build] build successfully(6/8)" +else + echo "Failed to build the front end." + exit 1 +fi + +# 启动前端页面服务 +nohup npm run serve 1>npm_server.log 2>&1 & + +# 监听前端页面服务 +tail -f npm_server.log & + +front_end_start_time=$(date +%s) + +while ! grep -q "Local:" npm_server.log; do + echo "Waiting for the front-end service to start..." + echo "等待启动前端服务" + sleep 1 + + # 获取当前时间并计算经过的时间 + current_time=$(date +%s) + elapsed_time=$((current_time - front_end_start_time)) + + # 检查是否超时 + if [ $elapsed_time -ge 120 ]; then + echo "启动前端服务超时,请检查日志文件 front_end/npm_server.log 获取更多信息。" + exit 1 + fi +done +echo "The front-end service is ready!...(7/8)" +echo "前端服务已就绪!...(7/8)" + + +timeout_time=300 + +current_time=$(date +%s) +elapsed=$((current_time - start_time)) # 计算经过的时间(秒) +echo "Time elapsed: ${elapsed} seconds." +echo "已耗时: ${elapsed} 秒." + +while true; do + + if [ "$runtime_backend" = "default" ]; then + if [ $gpuid1 -eq $gpuid2 ]; then + response=$(curl -s -w "%{http_code}" http://localhost:10000/v2/health/ready -o /dev/null) + if [ $response -eq 200 ]; then + echo "The llm service is ready!, now you can use the qanything service. (8/8)" + echo "LLM 服务已准备就绪!现在您可以使用qanything服务。(8/8)" + break + fi + else + response_llm=$(curl -s -w "%{http_code}" http://localhost:10000/v2/health/ready -o /dev/null) + response_embed=$(curl -s -w "%{http_code}" http://localhost:9000/v2/health/ready -o /dev/null) + response_rerank=$(curl -s -w "%{http_code}" http://localhost:8000/v2/health/ready -o /dev/null) + echo "health_response_llm = $response_llm" + echo "health_response_embed = $response_embed" + echo "health_response_rerank = $response_rerank" + + if [ $response_llm -eq 200 ] && [ $response_embed -eq 200 ] && [ $response_rerank -eq 200 ]; then + echo "The llm service is ready!, now you can use the qanything service. (8/8)" + echo "LLM 服务已准备就绪!现在您可以使用qanything服务。(8/8)" + break + fi + fi + + echo "The llm service is starting up, it can be long... you have time to make a coffee :)" + echo "LLM 服务正在启动,可能需要一段时间...你有时间去冲杯咖啡 :)" + + current_time=$(date +%s) + elapsed_time=$((current_time - start_time)) + + # 检查是否超时 + if [ $elapsed_time -ge $((timeout_time * 2)) ]; then + echo "启动 LLM 服务超时,请进入容器内检查/model_repos/QAEnsemble_base/QAEnsemble_base.log以获取更多信息。" + exit 1 + fi + sleep 5 + + else + fastchat_register_model=$(curl --request POST --url http://localhost:7800/list_models) + if [[ $fastchat_register_model == *"$LLM_API_SERVE_MODEL"* ]]; then + echo "The llm service is ready!, now you can use the qanything service. (8/8)" + echo "LLM 服务已准备就绪!现在您可以使用qanything服务。(8/8)" + break + else + echo "The llm service is starting up, it can be long... you have time to make a coffee :)" + echo "LLM 服务正在启动,可能需要一段时间...你有时间去冲杯咖啡 :)" + + current_time=$(date +%s) + elapsed_time=$((current_time - start_time)) + + # 检查是否超时 + if [ $elapsed_time -ge $((timeout_time * 2)) ]; then + echo "启动 LLM 服务超时,请进入容器内检查/workspace/qanything_local/qanything_logs/fastchat_logs/fschat_model_worker_7801.log以获取更多信息。" + exit 1 + fi + sleep 5 + fi + fi +done + +current_time=$(date +%s) +elapsed=$((current_time - start_time)) # 计算经过的时间(秒) +echo "Time elapsed: ${elapsed} seconds." +echo "已耗时: ${elapsed} 秒." +echo "Please visit the front-end service at [http://$user_ip:5052/qanything/] to conduct Q&A." +echo "请在[http://$user_ip:5052/qanything/]下访问前端服务来进行问答,如果前端报错,请在浏览器按F12以获取更多报错信息" + +# 保持容器运行 +while true; do + sleep 2 +done + + diff --git a/third_party/FastChat/.gitignore b/third_party/FastChat/.gitignore new file mode 100644 index 00000000..94b6e614 --- /dev/null +++ b/third_party/FastChat/.gitignore @@ -0,0 +1,31 @@ +# Python +__pycache__ +*.pyc +*.egg-info +dist +.venv + +# Log +*.log +*.log.* +*.json +!playground/deepspeed_config_s2.json +!playground/deepspeed_config_s3.json + +# Editor +.idea +*.swp + +# Other +.DS_Store +wandb +output +checkpoints_flant5_3b + +# Data +*.pkl +*.csv +tests/state_of_the_union.txt + +# Build +build diff --git a/third_party/FastChat/.pylintrc b/third_party/FastChat/.pylintrc new file mode 100644 index 00000000..864033fa --- /dev/null +++ b/third_party/FastChat/.pylintrc @@ -0,0 +1,449 @@ +# This Pylint rcfile contains a best-effort configuration to uphold the +# best-practices and style described in the Google Python style guide: +# https://google.github.io/styleguide/pyguide.html +# +# Its canonical open-source location is: +# https://google.github.io/styleguide/pylintrc + +[MASTER] + +# Files or directories to be skipped. They should be base names, not paths. +ignore=third_party,ray_patches,providers + +# Files or directories matching the regex patterns are skipped. The regex +# matches against base names, not paths. +ignore-patterns= + +# Pickle collected data for later comparisons. +persistent=no + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + +# Use multiple processes to speed up Pylint. +jobs=4 + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED +confidence= + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +#enable= + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once).You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use"--disable=all --enable=classes +# --disable=W" +disable=abstract-method, + apply-builtin, + arguments-differ, + attribute-defined-outside-init, + backtick, + bad-option-value, + basestring-builtin, + buffer-builtin, + c-extension-no-member, + consider-using-enumerate, + cmp-builtin, + cmp-method, + coerce-builtin, + coerce-method, + delslice-method, + div-method, + duplicate-code, + eq-without-hash, + execfile-builtin, + file-builtin, + filter-builtin-not-iterating, + fixme, + getslice-method, + global-statement, + hex-method, + idiv-method, + implicit-str-concat-in-sequence, + import-error, + import-self, + import-star-module-level, + inconsistent-return-statements, + input-builtin, + intern-builtin, + invalid-str-codec, + locally-disabled, + logging-format-interpolation, # FIXME(sky): make pass. + logging-fstring-interpolation, # FIXME(sky): make pass. + long-builtin, + long-suffix, + map-builtin-not-iterating, + misplaced-comparison-constant, + missing-function-docstring, + metaclass-assignment, + next-method-called, + next-method-defined, + no-absolute-import, + no-else-break, + no-else-continue, + no-else-raise, + no-else-return, + no-init, # added + no-member, + no-name-in-module, + no-self-use, + nonzero-method, + oct-method, + old-division, + old-ne-operator, + old-octal-literal, + old-raise-syntax, + parameter-unpacking, + print-statement, + raising-string, + range-builtin-not-iterating, + raw_input-builtin, + rdiv-method, + reduce-builtin, + relative-import, + reload-builtin, + round-builtin, + setslice-method, + signature-differs, + standarderror-builtin, + suppressed-message, + sys-max-int, + too-few-public-methods, + too-many-ancestors, + too-many-arguments, + too-many-boolean-expressions, + too-many-branches, + too-many-instance-attributes, + too-many-locals, + too-many-nested-blocks, + too-many-public-methods, + too-many-return-statements, + too-many-statements, + trailing-newlines, + unichr-builtin, + unicode-builtin, + unnecessary-pass, + unpacking-in-except, + useless-else-on-loop, + useless-object-inheritance, + useless-suppression, + using-cmp-argument, + wrong-import-order, + xrange-builtin, + zip-builtin-not-iterating, + + +[REPORTS] + +# Set the output format. Available formats are text, parseable, colorized, msvs +# (visual studio) and html. You can also give a reporter class, eg +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Put messages in a separate file for each module / package specified on the +# command line instead of printing them on stdout. Reports (if any) will be +# written in a file name "pylint_global.[txt|html]". This option is deprecated +# and it will be removed in Pylint 2.0. +files-output=no + +# Tells whether to display a full report or only the messages +reports=no + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details +#msg-template= + + +[BASIC] + +# Good variable names which should always be accepted, separated by a comma +good-names=main,_ + +# Bad variable names which should always be refused, separated by a comma +bad-names= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Include a hint for the correct naming format with invalid-name +include-naming-hint=no + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl + +# Regular expression matching correct function names +function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ + +# Regular expression matching correct variable names +variable-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct constant names +const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression matching correct attribute names +attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ + +# Regular expression matching correct argument names +argument-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct class attribute names +class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression matching correct inline iteration names +inlinevar-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct class names +class-rgx=^_?[A-Z][a-zA-Z0-9]*$ + +# Regular expression matching correct module names +module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ + +# Regular expression matching correct method names +method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=10 + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis. It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + + +[FORMAT] + +# Maximum number of characters on a single line. +max-line-length=100 + +# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt +# lines made too long by directives to pytype. + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=(?x)( + ^\s*(\#\ )??$| + ^\s*(from\s+\S+\s+)?import\s+.+$) + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=yes + +# List of optional constructs for which whitespace checking is disabled. `dict- +# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. +# `trailing-comma` allows a space between comma and closing bracket: (a, ). +# `empty-line` allows space-only lines. +no-space-check= + +# Maximum number of lines in a module +max-module-lines=99999 + +# String used as indentation unit. The internal Google style guide mandates 2 +# spaces. Google's externaly-published style guide says 4, consistent with +# PEP 8. Here we use 4 spaces. +indent-string=' ' + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=TODO + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=yes + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# A regular expression matching the name of dummy variables (i.e. expectedly +# not used). +dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_,_cb + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools + + +[LOGGING] + +# Logging modules to check that the string format arguments are in logging +# function parameter format +logging-modules=logging,absl.logging,tensorflow.io.logging + + +[SIMILARITIES] + +# Minimum lines number of a similarity. +min-similarity-lines=4 + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + + +[SPELLING] + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. +spelling-store-unknown-words=no + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=regsub, + TERMIOS, + Bastion, + rexec, + sets + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant, absl + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls, + class_ + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "Exception" +overgeneral-exceptions=StandardError, + Exception, + BaseException + +####### + +# https://github.com/edaniszewski/pylint-quotes#configuration +string-quote=single +triple-quote=double +docstring-quote=double diff --git a/third_party/FastChat/LICENSE b/third_party/FastChat/LICENSE new file mode 100644 index 00000000..261eeb9e --- /dev/null +++ b/third_party/FastChat/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/third_party/FastChat/README.md b/third_party/FastChat/README.md new file mode 100644 index 00000000..0c936a27 --- /dev/null +++ b/third_party/FastChat/README.md @@ -0,0 +1,354 @@ +# FastChat +| [**Demo**](https://chat.lmsys.org/) | [**Discord**](https://discord.gg/HSWAKCrnFx) | [**X**](https://x.com/lmsysorg) | + +FastChat is an open platform for training, serving, and evaluating large language model based chatbots. +- FastChat powers Chatbot Arena (https://chat.lmsys.org/), serving over 6 million chat requests for 50+ LLMs. +- Arena has collected over 100K human votes from side-by-side LLM battles to compile an online [LLM Elo leaderboard](https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard). + +FastChat's core features include: +- The training and evaluation code for state-of-the-art models (e.g., Vicuna, MT-Bench). +- A distributed multi-model serving system with web UI and OpenAI-compatible RESTful APIs. + +## News +- [2023/09] 🔥 We released **LMSYS-Chat-1M**, a large-scale real-world LLM conversation dataset. Read the [report](https://arxiv.org/abs/2309.11998). +- [2023/08] We released **Vicuna v1.5** based on Llama 2 with 4K and 16K context lengths. Download [weights](#vicuna-weights). +- [2023/07] We released **Chatbot Arena Conversations**, a dataset containing 33k conversations with human preferences. Download it [here](https://huggingface.co/datasets/lmsys/chatbot_arena_conversations). + +
+More + +- [2023/08] We released **LongChat v1.5** based on Llama 2 with 32K context lengths. Download [weights](#longchat). +- [2023/06] We introduced **MT-bench**, a challenging multi-turn question set for evaluating chatbots. Check out the blog [post](https://lmsys.org/blog/2023-06-22-leaderboard/). +- [2023/06] We introduced **LongChat**, our long-context chatbots and evaluation tools. Check out the blog [post](https://lmsys.org/blog/2023-06-29-longchat/). +- [2023/05] We introduced **Chatbot Arena** for battles among LLMs. Check out the blog [post](https://lmsys.org/blog/2023-05-03-arena). +- [2023/03] We released **Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90% ChatGPT Quality**. Check out the blog [post](https://vicuna.lmsys.org). + +
+ + + +## Contents +- [Install](#install) +- [Model Weights](#model-weights) +- [Inference with Command Line Interface](#inference-with-command-line-interface) +- [Serving with Web GUI](#serving-with-web-gui) +- [API](#api) +- [Evaluation](#evaluation) +- [Fine-tuning](#fine-tuning) +- [Citation](#citation) + +## Install + +### Method 1: With pip + +```bash +pip3 install "fschat[model_worker,webui]" +``` + +### Method 2: From source + +1. Clone this repository and navigate to the FastChat folder +```bash +git clone https://github.com/lm-sys/FastChat.git +cd FastChat +``` + +If you are running on Mac: +```bash +brew install rust cmake +``` + +2. Install Package +```bash +pip3 install --upgrade pip # enable PEP 660 support +pip3 install -e ".[model_worker,webui]" +``` + +## Model Weights +### Vicuna Weights +[Vicuna](https://lmsys.org/blog/2023-03-30-vicuna/) is based on Llama 2 and should be used under Llama's [model license](https://github.com/facebookresearch/llama/blob/main/LICENSE). + +You can use the commands below to start chatting. It will automatically download the weights from Hugging Face repos. +See more command options and how to handle out-of-memory in the "Inference with Command Line Interface" section below. + +**NOTE: `transformers>=4.31` is required for 16K versions.** + +| Size | Chat Command | Hugging Face Repo | +| --- | --- | --- | +| 7B | `python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5` | [lmsys/vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | +| 7B-16k | `python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5-16k` | [lmsys/vicuna-7b-v1.5-16k](https://huggingface.co/lmsys/vicuna-7b-v1.5-16k) | +| 13B | `python3 -m fastchat.serve.cli --model-path lmsys/vicuna-13b-v1.5` | [lmsys/vicuna-13b-v1.5](https://huggingface.co/lmsys/vicuna-13b-v1.5) | +| 13B-16k | `python3 -m fastchat.serve.cli --model-path lmsys/vicuna-13b-v1.5-16k` | [lmsys/vicuna-13b-v1.5-16k](https://huggingface.co/lmsys/vicuna-13b-v1.5-16k) | +| 33B | `python3 -m fastchat.serve.cli --model-path lmsys/vicuna-33b-v1.3` | [lmsys/vicuna-33b-v1.3](https://huggingface.co/lmsys/vicuna-33b-v1.3) | + +**Old weights**: see [docs/vicuna_weights_version.md](docs/vicuna_weights_version.md) for all versions of weights and their differences. + +### LongChat +We release [LongChat](https://lmsys.org/blog/2023-06-29-longchat/) models under Llama's [model license](https://github.com/facebookresearch/llama/blob/main/LICENSE). + +| Size | Chat Command | Hugging Face Repo | +| --- | --- | --- | +| 7B | `python3 -m fastchat.serve.cli --model-path lmsys/longchat-7b-32k-v1.5` | [lmsys/longchat-7b-32k](https://huggingface.co/lmsys/longchat-7b-32k-v1.5) | + +### FastChat-T5 +You can use the commands below to chat with FastChat-T5. It will automatically download the weights from Hugging Face repos. + +| Size | Chat Command | Hugging Face Repo | +| --- | --- | --- | +| 3B | `python3 -m fastchat.serve.cli --model-path lmsys/fastchat-t5-3b-v1.0` | [lmsys/fastchat-t5-3b-v1.0](https://huggingface.co/lmsys/fastchat-t5-3b-v1.0) | + +## Inference with Command Line Interface + + + +(Experimental Feature: You can specify `--style rich` to enable rich text output and better text streaming quality for some non-ASCII content. This may not work properly on certain terminals.) + +#### Supported Models +FastChat supports a wide range of models, including +LLama 2, Vicuna, Alpaca, Baize, ChatGLM, Dolly, Falcon, FastChat-T5, GPT4ALL, Guanaco, MTP, OpenAssistant, OpenChat, RedPajama, StableLM, WizardLM, xDAN-AI and more. + +See a complete list of supported models and instructions to add a new model [here](docs/model_support.md). + +#### Single GPU +The command below requires around 14GB of GPU memory for Vicuna-7B and 28GB of GPU memory for Vicuna-13B. +See the ["Not Enough Memory" section](#not-enough-memory) below if you do not have enough memory. +`--model-path` can be a local folder or a Hugging Face repo name. +``` +python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 +``` + +#### Multiple GPUs +You can use model parallelism to aggregate GPU memory from multiple GPUs on the same machine. +``` +python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --num-gpus 2 +``` + +Tips: +Sometimes the "auto" device mapping strategy in huggingface/transformers does not perfectly balance the memory allocation across multiple GPUs. +You can use `--max-gpu-memory` to specify the maximum memory per GPU for storing model weights. +This allows it to allocate more memory for activations, so you can use longer context lengths or larger batch sizes. For example, + +``` +python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --num-gpus 2 --max-gpu-memory 8GiB +``` + +#### CPU Only +This runs on the CPU only and does not require GPU. It requires around 30GB of CPU memory for Vicuna-7B and around 60GB of CPU memory for Vicuna-13B. +``` +python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --device cpu +``` + +Use Intel AI Accelerator AVX512_BF16/AMX to accelerate CPU inference. +``` +CPU_ISA=amx python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --device cpu +``` + +#### Metal Backend (Mac Computers with Apple Silicon or AMD GPUs) +Use `--device mps` to enable GPU acceleration on Mac computers (requires torch >= 2.0). +Use `--load-8bit` to turn on 8-bit compression. +``` +python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --device mps --load-8bit +``` +Vicuna-7B can run on a 32GB M1 Macbook with 1 - 2 words / second. + +#### Intel XPU (Intel Data Center and Arc A-Series GPUs) +Install the [Intel Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch/xpu/latest/tutorials/installation.html). Set the OneAPI environment variables: +``` +source /opt/intel/oneapi/setvars.sh +``` + +Use `--device xpu` to enable XPU/GPU acceleration. +``` +python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --device xpu +``` +Vicuna-7B can run on an Intel Arc A770 16GB. + +#### Ascend NPU +Install the [Ascend PyTorch Adapter](https://github.com/Ascend/pytorch). Set the CANN environment variables: +``` +source /usr/local/Ascend/ascend-toolkit/set_env.sh +``` + +Use `--device npu` to enable NPU acceleration. +``` +python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --device npu +``` +Vicuna-7B/13B can run on an Ascend NPU. + +#### Not Enough Memory +If you do not have enough memory, you can enable 8-bit compression by adding `--load-8bit` to commands above. +This can reduce memory usage by around half with slightly degraded model quality. +It is compatible with the CPU, GPU, and Metal backend. + +Vicuna-13B with 8-bit compression can run on a single GPU with 16 GB of VRAM, like an Nvidia RTX 3090, RTX 4080, T4, V100 (16GB), or an AMD RX 6800 XT. + +``` +python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5 --load-8bit +``` + +In addition to that, you can add `--cpu-offloading` to commands above to offload weights that don't fit on your GPU onto the CPU memory. +This requires 8-bit compression to be enabled and the bitsandbytes package to be installed, which is only available on linux operating systems. + +#### More Platforms and Quantization +- For AMD GPU users, please install ROCm and [the ROCm version of PyTorch](https://pytorch.org/get-started/locally/) before you install FastChat. See also this [post](https://github.com/lm-sys/FastChat/issues/104#issuecomment-1613791563). +- FastChat supports ExLlama V2. See [docs/exllama_v2.md](/docs/exllama_v2.md). +- FastChat supports GPTQ 4bit inference with [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa). See [docs/gptq.md](/docs/gptq.md). +- FastChat supports AWQ 4bit inference with [mit-han-lab/llm-awq](https://github.com/mit-han-lab/llm-awq). See [docs/awq.md](/docs/awq.md). +- [MLC LLM](https://mlc.ai/mlc-llm/), backed by [TVM Unity](https://github.com/apache/tvm/tree/unity) compiler, deploys Vicuna natively on phones, consumer-class GPUs and web browsers via Vulkan, Metal, CUDA and WebGPU. + +## Serving with Web GUI + + + +To serve using the web UI, you need three main components: web servers that interface with users, model workers that host one or more models, and a controller to coordinate the webserver and model workers. You can learn more about the architecture [here](docs/server_arch.md). + +Here are the commands to follow in your terminal: + +#### Launch the controller +```bash +python3 -m fastchat.serve.controller +``` + +This controller manages the distributed workers. + +#### Launch the model worker(s) +```bash +python3 -m fastchat.serve.model_worker --model-path lmsys/vicuna-7b-v1.5 +``` +Wait until the process finishes loading the model and you see "Uvicorn running on ...". The model worker will register itself to the controller . + +To ensure that your model worker is connected to your controller properly, send a test message using the following command: +```bash +python3 -m fastchat.serve.test_message --model-name vicuna-7b-v1.5 +``` +You will see a short output. + +#### Launch the Gradio web server +```bash +python3 -m fastchat.serve.gradio_web_server +``` + +This is the user interface that users will interact with. + +By following these steps, you will be able to serve your models using the web UI. You can open your browser and chat with a model now. +If the models do not show up, try to reboot the gradio web server. + +#### (Optional): Advanced Features, Scalability, Third Party UI +- You can register multiple model workers to a single controller, which can be used for serving a single model with higher throughput or serving multiple models at the same time. When doing so, please allocate different GPUs and ports for different model workers. +``` +# worker 0 +CUDA_VISIBLE_DEVICES=0 python3 -m fastchat.serve.model_worker --model-path lmsys/vicuna-7b-v1.5 --controller http://localhost:21001 --port 31000 --worker http://localhost:31000 +# worker 1 +CUDA_VISIBLE_DEVICES=1 python3 -m fastchat.serve.model_worker --model-path lmsys/fastchat-t5-3b-v1.0 --controller http://localhost:21001 --port 31001 --worker http://localhost:31001 +``` +- You can also launch a multi-tab gradio server, which includes the Chatbot Arena tabs. +```bash +python3 -m fastchat.serve.gradio_web_server_multi +``` +- The default model worker based on huggingface/transformers has great compatibility but can be slow. If you want high-throughput batched serving, you can try [vLLM integration](docs/vllm_integration.md). +- If you want to host it on your own UI or third party UI, see [Third Party UI](docs/third_party_ui.md). + +## Use models from modelscope +For Chinese users, you can use models from www.modelscope.cn via specify the following environment variables. +```bash +export FASTCHAT_USE_MODELSCOPE=True +``` + +## API +### OpenAI-Compatible RESTful APIs & SDK +FastChat provides OpenAI-compatible APIs for its supported models, so you can use FastChat as a local drop-in replacement for OpenAI APIs. +The FastChat server is compatible with both [openai-python](https://github.com/openai/openai-python) library and cURL commands. +See [docs/openai_api.md](docs/openai_api.md). + +### Hugging Face Generation APIs +See [fastchat/serve/huggingface_api.py](fastchat/serve/huggingface_api.py). + +### LangChain Integration +See [docs/langchain_integration](docs/langchain_integration.md). + +## Evaluation +We use MT-bench, a set of challenging multi-turn open-ended questions to evaluate models. +To automate the evaluation process, we prompt strong LLMs like GPT-4 to act as judges and assess the quality of the models' responses. +See instructions for running MT-bench at [fastchat/llm_judge](fastchat/llm_judge). + +MT-bench is the new recommended way to benchmark your models. If you are still looking for the old 80 questions used in the vicuna blog post, please go to [vicuna-blog-eval](https://github.com/lm-sys/vicuna-blog-eval). + +## Fine-tuning +### Data + +Vicuna is created by fine-tuning a Llama base model using approximately 125K user-shared conversations gathered from ShareGPT.com with public APIs. To ensure data quality, we convert the HTML back to markdown and filter out some inappropriate or low-quality samples. Additionally, we divide lengthy conversations into smaller segments that fit the model's maximum context length. For detailed instructions to clean the ShareGPT data, check out [here](docs/commands/data_cleaning.md). + +We will not release the ShareGPT dataset. If you would like to try the fine-tuning code, you can run it with some dummy conversations in [dummy_conversation.json](data/dummy_conversation.json). You can follow the same format and plug in your own data. + +### Code and Hyperparameters +Our code is based on [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) with additional support for multi-turn conversations. +We use similar hyperparameters as the Stanford Alpaca. + +| Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay | +| --- | ---: | ---: | ---: | ---: | ---: | +| Vicuna-13B | 128 | 2e-5 | 3 | 2048 | 0 | + +### Fine-tuning Vicuna-7B with Local GPUs + +- Install dependency +```bash +pip3 install -e ".[train]" +``` + +- You can use the following command to train Vicuna-7B with 4 x A100 (40GB). Update `--model_name_or_path` with the actual path to Llama weights and `--data_path` with the actual path to data. +```bash +torchrun --nproc_per_node=4 --master_port=20001 fastchat/train/train_mem.py \ + --model_name_or_path meta-llama/Llama-2-7b-hf \ + --data_path data/dummy_conversation.json \ + --bf16 True \ + --output_dir output_vicuna \ + --num_train_epochs 3 \ + --per_device_train_batch_size 2 \ + --per_device_eval_batch_size 2 \ + --gradient_accumulation_steps 16 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 1200 \ + --save_total_limit 10 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --fsdp "full_shard auto_wrap" \ + --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --lazy_preprocess True +``` + +Tips: +- If you are using V100 which is not supported by FlashAttention, you can use the [memory-efficient attention](https://arxiv.org/abs/2112.05682) implemented in [xFormers](https://github.com/facebookresearch/xformers). Install xformers and replace `fastchat/train/train_mem.py` above with [fastchat/train/train_xformers.py](fastchat/train/train_xformers.py). +- If you meet out-of-memory due to "FSDP Warning: When using FSDP, it is efficient and recommended... ", see solutions [here](https://github.com/huggingface/transformers/issues/24724#issuecomment-1645189539). +- If you meet out-of-memory during model saving, see solutions [here](https://github.com/pytorch/pytorch/issues/98823). +- To turn on logging to popular experiment tracking tools such as Tensorboard, MLFlow or Weights & Biases, use the `report_to` argument, e.g. pass `--report_to wandb` to turn on logging to Weights & Biases. + +### Other models, platforms and LoRA support +More instructions to train other models (e.g., FastChat-T5) and use LoRA are in [docs/training.md](docs/training.md). + +### Fine-tuning on Any Cloud with SkyPilot +[SkyPilot](https://github.com/skypilot-org/skypilot) is a framework built by UC Berkeley for easily and cost effectively running ML workloads on any cloud (AWS, GCP, Azure, Lambda, etc.). +Find SkyPilot documentation [here](https://github.com/skypilot-org/skypilot/tree/master/llm/vicuna) on using managed spot instances to train Vicuna and save on your cloud costs. + +## Citation +The code (training, serving, and evaluation) in this repository is mostly developed for or derived from the paper below. +Please cite it if you find the repository helpful. + +``` +@misc{zheng2023judging, + title={Judging LLM-as-a-judge with MT-Bench and Chatbot Arena}, + author={Lianmin Zheng and Wei-Lin Chiang and Ying Sheng and Siyuan Zhuang and Zhanghao Wu and Yonghao Zhuang and Zi Lin and Zhuohan Li and Dacheng Li and Eric. P Xing and Hao Zhang and Joseph E. Gonzalez and Ion Stoica}, + year={2023}, + eprint={2306.05685}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +``` + +We are also planning to add more of our research to this repository. diff --git a/third_party/FastChat/assets/demo_narrow.gif b/third_party/FastChat/assets/demo_narrow.gif new file mode 100644 index 00000000..aa1af38e Binary files /dev/null and b/third_party/FastChat/assets/demo_narrow.gif differ diff --git a/third_party/FastChat/assets/qa_browser.png b/third_party/FastChat/assets/qa_browser.png new file mode 100644 index 00000000..61ab4b09 Binary files /dev/null and b/third_party/FastChat/assets/qa_browser.png differ diff --git a/third_party/FastChat/assets/screenshot_cli.png b/third_party/FastChat/assets/screenshot_cli.png new file mode 100644 index 00000000..7a7dd5d6 Binary files /dev/null and b/third_party/FastChat/assets/screenshot_cli.png differ diff --git a/third_party/FastChat/assets/screenshot_gui.png b/third_party/FastChat/assets/screenshot_gui.png new file mode 100644 index 00000000..ecb41d2f Binary files /dev/null and b/third_party/FastChat/assets/screenshot_gui.png differ diff --git a/third_party/FastChat/assets/server_arch.png b/third_party/FastChat/assets/server_arch.png new file mode 100644 index 00000000..06f8b2ac Binary files /dev/null and b/third_party/FastChat/assets/server_arch.png differ diff --git a/third_party/FastChat/assets/vicuna_logo.jpeg b/third_party/FastChat/assets/vicuna_logo.jpeg new file mode 100644 index 00000000..e7883dc8 Binary files /dev/null and b/third_party/FastChat/assets/vicuna_logo.jpeg differ diff --git a/third_party/FastChat/docker/Dockerfile b/third_party/FastChat/docker/Dockerfile new file mode 100644 index 00000000..159d4abd --- /dev/null +++ b/third_party/FastChat/docker/Dockerfile @@ -0,0 +1,7 @@ +FROM nvidia/cuda:12.2.0-runtime-ubuntu20.04 + +RUN apt-get update -y && apt-get install -y python3.9 python3.9-distutils curl +RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py +RUN python3.9 get-pip.py +RUN pip3 install fschat +RUN pip3 install fschat[model_worker,webui] pydantic==1.10.13 \ No newline at end of file diff --git a/third_party/FastChat/docker/docker-compose.yml b/third_party/FastChat/docker/docker-compose.yml new file mode 100644 index 00000000..113e0c7a --- /dev/null +++ b/third_party/FastChat/docker/docker-compose.yml @@ -0,0 +1,36 @@ +version: "3.9" + +services: + fastchat-controller: + build: + context: . + dockerfile: Dockerfile + image: fastchat:latest + ports: + - "21001:21001" + entrypoint: ["python3.9", "-m", "fastchat.serve.controller", "--host", "0.0.0.0", "--port", "21001"] + fastchat-model-worker: + build: + context: . + dockerfile: Dockerfile + volumes: + - huggingface:/root/.cache/huggingface + image: fastchat:latest + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] + entrypoint: ["python3.9", "-m", "fastchat.serve.model_worker", "--model-names", "${FASTCHAT_WORKER_MODEL_NAMES:-vicuna-7b-v1.5}", "--model-path", "${FASTCHAT_WORKER_MODEL_PATH:-lmsys/vicuna-7b-v1.5}", "--worker-address", "http://fastchat-model-worker:21002", "--controller-address", "http://fastchat-controller:21001", "--host", "0.0.0.0", "--port", "21002"] + fastchat-api-server: + build: + context: . + dockerfile: Dockerfile + image: fastchat:latest + ports: + - "8000:8000" + entrypoint: ["python3.9", "-m", "fastchat.serve.openai_api_server", "--controller-address", "http://fastchat-controller:21001", "--host", "0.0.0.0", "--port", "8000"] +volumes: + huggingface: diff --git a/third_party/FastChat/docs/arena.md b/third_party/FastChat/docs/arena.md new file mode 100644 index 00000000..979f41db --- /dev/null +++ b/third_party/FastChat/docs/arena.md @@ -0,0 +1,14 @@ +# Chatbot Arena +Chatbot Arena is an LLM benchmark platform featuring anonymous, randomized battles, available at https://chat.lmsys.org. +We invite the entire community to join this benchmarking effort by contributing your votes and models. + +## How to add a new model +If you want to see a specific model in the arena, you can follow the methods below. + +- Method 1: Hosted by LMSYS. + 1. Contribute the code to support this model in FastChat by submitting a pull request. See [instructions](model_support.md#how-to-support-a-new-model). + 2. After the model is supported, we will try to schedule some compute resources to host the model in the arena. However, due to the limited resources we have, we may not be able to serve every model. We will select the models based on popularity, quality, diversity, and other factors. + +- Method 2: Hosted by 3rd party API providers or yourself. + 1. If you have a model hosted by a 3rd party API provider or yourself, please give us an API endpoint. We prefer OpenAI-compatible APIs, so we can reuse our [code](https://github.com/lm-sys/FastChat/blob/33dca5cf12ee602455bfa9b5f4790a07829a2db7/fastchat/serve/gradio_web_server.py#L333-L358) for calling OpenAI models. + 2. You can use FastChat's OpenAI API [server](openai_api.md) to serve your model with OpenAI-compatible APIs and provide us with the endpoint. diff --git a/third_party/FastChat/docs/awq.md b/third_party/FastChat/docs/awq.md new file mode 100644 index 00000000..ac38ef3d --- /dev/null +++ b/third_party/FastChat/docs/awq.md @@ -0,0 +1,71 @@ +# AWQ 4bit Inference + +We integrated [AWQ](https://github.com/mit-han-lab/llm-awq) into FastChat to provide **efficient and accurate** 4bit LLM inference. + +## Install AWQ + +Setup environment (please refer to [this link](https://github.com/mit-han-lab/llm-awq#install) for more details): +```bash +conda create -n fastchat-awq python=3.10 -y +conda activate fastchat-awq +# cd /path/to/FastChat +pip install --upgrade pip # enable PEP 660 support +pip install -e . # install fastchat + +git clone https://github.com/mit-han-lab/llm-awq repositories/llm-awq +cd repositories/llm-awq +pip install -e . # install awq package + +cd awq/kernels +python setup.py install # install awq CUDA kernels +``` + +## Chat with the CLI + +```bash +# Download quantized model from huggingface +# Make sure you have git-lfs installed (https://git-lfs.com) +git lfs install +git clone https://huggingface.co/mit-han-lab/vicuna-7b-v1.3-4bit-g128-awq + +# You can specify which quantized model to use by setting --awq-ckpt +python3 -m fastchat.serve.cli \ + --model-path models/vicuna-7b-v1.3-4bit-g128-awq \ + --awq-wbits 4 \ + --awq-groupsize 128 +``` + +## Benchmark + +* Through **4-bit weight quantization**, AWQ helps to run larger language models within the device memory restriction and prominently accelerates token generation. All benchmarks are done with group_size 128. + +* Benchmark on NVIDIA RTX A6000: + + | Model | Bits | Max Memory (MiB) | Speed (ms/token) | AWQ Speedup | + | --------------- | ---- | ---------------- | ---------------- | ----------- | + | vicuna-7b | 16 | 13543 | 26.06 | / | + | vicuna-7b | 4 | 5547 | 12.43 | 2.1x | + | llama2-7b-chat | 16 | 13543 | 27.14 | / | + | llama2-7b-chat | 4 | 5547 | 12.44 | 2.2x | + | vicuna-13b | 16 | 25647 | 44.91 | / | + | vicuna-13b | 4 | 9355 | 17.30 | 2.6x | + | llama2-13b-chat | 16 | 25647 | 47.28 | / | + | llama2-13b-chat | 4 | 9355 | 20.28 | 2.3x | + +* NVIDIA RTX 4090: + + | Model | AWQ 4bit Speed (ms/token) | FP16 Speed (ms/token) | AWQ Speedup | + | --------------- | ------------------------- | --------------------- | ----------- | + | vicuna-7b | 8.61 | 19.09 | 2.2x | + | llama2-7b-chat | 8.66 | 19.97 | 2.3x | + | vicuna-13b | 12.17 | OOM | / | + | llama2-13b-chat | 13.54 | OOM | / | + +* NVIDIA Jetson Orin: + + | Model | AWQ 4bit Speed (ms/token) | FP16 Speed (ms/token) | AWQ Speedup | + | --------------- | ------------------------- | --------------------- | ----------- | + | vicuna-7b | 65.34 | 93.12 | 1.4x | + | llama2-7b-chat | 75.11 | 104.71 | 1.4x | + | vicuna-13b | 115.40 | OOM | / | + | llama2-13b-chat | 136.81 | OOM | / | diff --git a/third_party/FastChat/docs/commands/conv_release.md b/third_party/FastChat/docs/commands/conv_release.md new file mode 100644 index 00000000..64cfadbc --- /dev/null +++ b/third_party/FastChat/docs/commands/conv_release.md @@ -0,0 +1,38 @@ +## Chatbot Arena Conversations + +1. Gather battles +``` +python3 clean_battle_data.py --max-num 10 --mode conv_release +``` + +2. Tag OpenAI moderation +``` +python3 tag_openai_moderation.py --in clean_battle_conv_20230814.json +``` + +3. Clean PII + +4. Filter additional blocked words + +``` +python3 filter_bad_conv.py --in clean_battle_conv_20230630_tagged_v1_pii.json +``` + +5. Add additional toxicity tag + + +## All Conversations + +1. Gather chats +``` +python3 clean_chat_data.py +``` + +2. Sample +``` +python3 conv_release_scripts/sample.py +``` + + +## Prompt distribution + diff --git a/third_party/FastChat/docs/commands/data_cleaning.md b/third_party/FastChat/docs/commands/data_cleaning.md new file mode 100644 index 00000000..410ce8a8 --- /dev/null +++ b/third_party/FastChat/docs/commands/data_cleaning.md @@ -0,0 +1,19 @@ +## Data cleaning + +## Requirements +``` +pip3 install bs4 markdownify +pip3 install polyglot pyicu pycld2 +``` + +## Steps +``` +# Convert html to markdown +python3 -m fastchat.data.clean_sharegpt --in sharegpt_html.json --out sharegpt_clean.json + +# Keep or remove specific languages +python3 -m fastchat.data.optional_clean --in sharegpt_clean.json --out sharegpt_clean_lang.json --skip-lang SOME_LANGUAGE_CODE + +# Split long conversations +python3 -m fastchat.data.split_long_conversation --in sharegpt_clean_lang.json --out sharegpt_clean_lang_split.json --model-name /home/ubuntu/model_weights/llama-7b/ +``` diff --git a/third_party/FastChat/docs/commands/leaderboard.md b/third_party/FastChat/docs/commands/leaderboard.md new file mode 100644 index 00000000..04477a09 --- /dev/null +++ b/third_party/FastChat/docs/commands/leaderboard.md @@ -0,0 +1,37 @@ +### Get logs +``` +gsutil -m rsync -r gs://fastchat_logs ~/fastchat_logs/ +``` + +### Clean battle data +``` +cd ~/FastChat/fastchat/serve/monitor +python3 clean_battle_data.py +``` + +### Run Elo analysis +``` +python3 elo_analysis.py --clean-battle-file clean_battle_20230905.json +``` + +### Copy files to HF space +1. update plots +``` +scp atlas:/data/lmzheng/FastChat/fastchat/serve/monitor/elo_results_20230905.pkl . +``` + +2. update table +``` +wget https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard/raw/main/leaderboard_table_20230905.csv +``` + +### Update files on webserver +``` +DATE=20231002 + +rm -rf elo_results.pkl leaderboard_table.csv +wget https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard/resolve/main/elo_results_$DATE.pkl +wget https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard/resolve/main/leaderboard_table_$DATE.csv +ln -s leaderboard_table_$DATE.csv leaderboard_table.csv +ln -s elo_results_$DATE.pkl elo_results.pkl +``` diff --git a/third_party/FastChat/docs/commands/local_cluster.md b/third_party/FastChat/docs/commands/local_cluster.md new file mode 100644 index 00000000..bc1588c0 --- /dev/null +++ b/third_party/FastChat/docs/commands/local_cluster.md @@ -0,0 +1,38 @@ +### Local GPU cluster +node-01 +``` +python3 -m fastchat.serve.controller --host 0.0.0.0 --port 10002 + +CUDA_VISIBLE_DEVICES=0 python3 -m fastchat.serve.vllm_worker --model-path lmsys/vicuna-13b-v1.5 --model-name vicuna-13b --controller http://node-01:10002 --host 0.0.0.0 --port 31000 --worker-address http://$(hostname):31000 +CUDA_VISIBLE_DEVICES=1 python3 -m fastchat.serve.vllm_worker --model-path lmsys/vicuna-13b-v1.5 --model-name vicuna-13b --controller http://node-01:10002 --host 0.0.0.0 --port 31001 --worker-address http://$(hostname):31001 + +CUDA_VISIBLE_DEVICES=2,3 ray start --head +python3 -m fastchat.serve.vllm_worker --model-path lmsys/vicuna-33b-v1.3 --model-name vicuna-33b --controller http://node-01:10002 --host 0.0.0.0 --port 31002 --worker-address http://$(hostname):31002 --num-gpus 2 +``` + +node-02 +``` +CUDA_VISIBLE_DEVICES=0 python3 -m fastchat.serve.vllm_worker --model-path meta-llama/Llama-2-13b-chat-hf --model-name llama-2-13b-chat --controller http://node-01:10002 --host 0.0.0.0 --port 31000 --worker-address http://$(hostname):31000 --tokenizer meta-llama/Llama-2-7b-chat-hf +CUDA_VISIBLE_DEVICES=1 python3 -m fastchat.serve.vllm_worker --model-path meta-llama/Llama-2-13b-chat-hf --model-name llama-2-13b-chat --controller http://node-01:10002 --host 0.0.0.0 --port 31001 --worker-address http://$(hostname):31001 --tokenizer meta-llama/Llama-2-7b-chat-hf +CUDA_VISIBLE_DEVICES=2 python3 -m fastchat.serve.vllm_worker --model-path meta-llama/Llama-2-7b-chat-hf --model-name llama-2-7b-chat --controller http://node-01:10002 --host 0.0.0.0 --port 31002 --worker-address http://$(hostname):31002 --tokenizer meta-llama/Llama-2-7b-chat-hf +CUDA_VISIBLE_DEVICES=3 python3 -m fastchat.serve.vllm_worker --model-path WizardLM/WizardLM-13B-V1.1 --model-name wizardlm-13b --controller http://node-01:10002 --host 0.0.0.0 --port 31003 --worker-address http://$(hostname):31003 +``` + +node-03 +``` +python3 -m fastchat.serve.vllm_worker --model-path mosaicml/mpt-30b-chat --controller http://node-01:10002 --host 0.0.0.0 --port 31000 --worker-address http://$(hostname):31000 --num-gpus 2 +python3 -m fastchat.serve.vllm_worker --model-path timdettmers/guanaco-33b-merged --model-name guanaco-33b --controller http://node-01:10002 --host 0.0.0.0 --port 31002 --worker-address http://$(hostname):31002 --num-gpus 2 --tokenizer hf-internal-testing/llama-tokenizer +``` + +node-04 +``` +CUDA_VISIBLE_DEVICES=0 python3 -m fastchat.serve.multi_model_worker --model-path ~/model_weights/RWKV-4-Raven-14B-v12-Eng98%25-Other2%25-20230523-ctx8192.pth --model-name RWKV-4-Raven-14B --model-path lmsys/fastchat-t5-3b-v1.0 --model-name fastchat-t5-3b --controller http://node-01:10002 --host 0.0.0.0 --port 31000 --worker http://$(hostname):31000 --limit 4 +CUDA_VISIBLE_DEVICES=1 python3 -m fastchat.serve.multi_model_worker --model-path OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5 --model-name oasst-pythia-12b --model-path mosaicml/mpt-7b-chat --model-name mpt-7b-chat --controller http://node-01:10002 --host 0.0.0.0 --port 31001 --worker http://$(hostname):31001 --limit 4 +CUDA_VISIBLE_DEVICES=2 python3 -m fastchat.serve.multi_model_worker --model-path lmsys/vicuna-7b-v1.5 --model-name vicuna-7b --model-path THUDM/chatglm-6b --model-name chatglm-6b --controller http://node-01:10002 --host 0.0.0.0 --port 31002 --worker http://$(hostname):31002 --limit 4 +CUDA_VISIBLE_DEVICES=3 python3 -m fastchat.serve.vllm_worker --model-path ~/model_weights/alpaca-13b --controller http://node-01:10002 --host 0.0.0.0 --port 31003 --worker-address http://$(hostname):31003 +``` + +test +``` +python3 -m fastchat.serve.test_message --model vicuna-13b --controller http://localhost:10002 +``` diff --git a/third_party/FastChat/docs/commands/pypi.md b/third_party/FastChat/docs/commands/pypi.md new file mode 100644 index 00000000..5b53dae6 --- /dev/null +++ b/third_party/FastChat/docs/commands/pypi.md @@ -0,0 +1,11 @@ +### Requirement +``` +python3 -m pip install twine +python3 -m pip install --upgrade pip +pip3 install build +``` + +### Upload +``` +bash scripts/upload_pypi.sh +``` diff --git a/third_party/FastChat/docs/commands/webserver.md b/third_party/FastChat/docs/commands/webserver.md new file mode 100644 index 00000000..df96cf8d --- /dev/null +++ b/third_party/FastChat/docs/commands/webserver.md @@ -0,0 +1,94 @@ +### Install +``` +sudo apt update +sudo apt install tmux htop + +wget https://repo.anaconda.com/archive/Anaconda3-2022.10-Linux-x86_64.sh +bash Anaconda3-2022.10-Linux-x86_64.sh + +conda create -n fastchat python=3.9 +conda activate fastchat + +git clone https://github.com/lm-sys/FastChat.git +cd FastChat +pip3 install -e . +``` + + +### Launch servers +``` +cd fastchat_logs/controller +python3 -m fastchat.serve.controller --host 0.0.0.0 --port 21001 +python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name https:// +python3 -m fastchat.serve.test_message --model vicuna-13b --controller http://localhost:21001 + +cd fastchat_logs/server0 + +python3 -m fastchat.serve.huggingface_api_worker --model-info-file ~/elo_results/register_hf_api_models.json + +export OPENAI_API_KEY= +export ANTHROPIC_API_KEY= +export GCP_PROJECT_ID= + +python3 -m fastchat.serve.gradio_web_server_multi --controller http://localhost:21001 --concurrency 50 --add-chatgpt --add-claude --add-palm --elo ~/elo_results/elo_results.pkl --leaderboard-table-file ~/elo_results/leaderboard_table.csv --register ~/elo_results/register_oai_models.json --show-terms + +python3 backup_logs.py +``` + + +### Check the launch time +``` +for i in $(seq 0 11); do cat fastchat_logs/server$i/gradio_web_server.log | grep "Running on local URL" | tail -n 1; done +``` + + +### Increase the limit of max open files +One process (do not need reboot) +``` +sudo prlimit --nofile=1048576:1048576 --pid=$id + +for id in $(ps -ef | grep gradio_web_server | awk '{print $2}'); do echo $id; prlimit --nofile=1048576:1048576 --pid=$id; done +``` + +System (need reboot): Add the lines below to `/etc/security/limits.conf` +``` +* hard nofile 65535 +* soft nofile 65535 +``` + + +### Gradio edit (3.35.2) +1. gtag and canvas +``` +vim /home/vicuna/anaconda3/envs/fastchat/lib/python3.9/site-packages/gradio/templates/frontend/index.html +``` + +``` + + + +``` + +2. deprecation warnings +``` +vim /home/vicuna/anaconda3/envs/fastchat/lib/python3.9/site-packages/gradio/deprecation.py +``` + +``` +def check_deprecated_parameters( +``` + +3. Loading +``` +vim /home/vicuna/anaconda3/envs/fastchat/lib/python3.9/site-packages/gradio/templates/frontend/assets/index-188ef5e8.js +``` + +``` +%s/"Loading..."/"Loading...(Please refresh if it takes more than 30 seconds)"/g +``` diff --git a/third_party/FastChat/docs/dataset_release.md b/third_party/FastChat/docs/dataset_release.md new file mode 100644 index 00000000..add2c890 --- /dev/null +++ b/third_party/FastChat/docs/dataset_release.md @@ -0,0 +1,6 @@ +## Datasets +We release the following datasets based on our projects and websites. + +- [LMSYS-Chat-1M: A Large-Scale Real-World LLM Conversation Dataset](https://huggingface.co/datasets/lmsys/lmsys-chat-1m) +- [Chatbot Arena Conversation Dataset](https://huggingface.co/datasets/lmsys/chatbot_arena_conversations) +- [MT-bench Human Annotation Dataset](https://huggingface.co/datasets/lmsys/mt_bench_human_judgments) diff --git a/third_party/FastChat/docs/exllama_v2.md b/third_party/FastChat/docs/exllama_v2.md new file mode 100644 index 00000000..c9fc72bf --- /dev/null +++ b/third_party/FastChat/docs/exllama_v2.md @@ -0,0 +1,63 @@ +# ExllamaV2 GPTQ Inference Framework + +Integrated [ExllamaV2](https://github.com/turboderp/exllamav2) customized kernel into Fastchat to provide **Faster** GPTQ inference speed. + +**Note: Exllama not yet support embedding REST API.** + +## Install ExllamaV2 + +Setup environment (please refer to [this link](https://github.com/turboderp/exllamav2#how-to) for more details): + +```bash +git clone https://github.com/turboderp/exllamav2 +cd exllamav2 +pip install -e . +``` + +Chat with the CLI: +```bash +python3 -m fastchat.serve.cli \ + --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \ + --enable-exllama +``` + +Start model worker: +```bash +# Download quantized model from huggingface +# Make sure you have git-lfs installed (https://git-lfs.com) +git lfs install +git clone https://huggingface.co/TheBloke/vicuna-7B-1.1-GPTQ-4bit-128g models/vicuna-7B-1.1-GPTQ-4bit-128g + +# Load model with default configuration (max sequence length 4096, no GPU split setting). +python3 -m fastchat.serve.model_worker \ + --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \ + --enable-exllama + +#Load model with max sequence length 2048, allocate 18 GB to CUDA:0 and 24 GB to CUDA:1. +python3 -m fastchat.serve.model_worker \ + --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \ + --enable-exllama \ + --exllama-max-seq-len 2048 \ + --exllama-gpu-split 18,24 +``` + +`--exllama-cache-8bit` can be used to enable 8-bit caching with exllama and save some VRAM. + +## Performance + +Reference: https://github.com/turboderp/exllamav2#performance + + +| Model | Mode | Size | grpsz | act | V1: 3090Ti | V1: 4090 | V2: 3090Ti | V2: 4090 | +|------------|--------------|-------|-------|-----|------------|----------|------------|-------------| +| Llama | GPTQ | 7B | 128 | no | 143 t/s | 173 t/s | 175 t/s | **195** t/s | +| Llama | GPTQ | 13B | 128 | no | 84 t/s | 102 t/s | 105 t/s | **110** t/s | +| Llama | GPTQ | 33B | 128 | yes | 37 t/s | 45 t/s | 45 t/s | **48** t/s | +| OpenLlama | GPTQ | 3B | 128 | yes | 194 t/s | 226 t/s | 295 t/s | **321** t/s | +| CodeLlama | EXL2 4.0 bpw | 34B | - | - | - | - | 42 t/s | **48** t/s | +| Llama2 | EXL2 3.0 bpw | 7B | - | - | - | - | 195 t/s | **224** t/s | +| Llama2 | EXL2 4.0 bpw | 7B | - | - | - | - | 164 t/s | **197** t/s | +| Llama2 | EXL2 5.0 bpw | 7B | - | - | - | - | 144 t/s | **160** t/s | +| Llama2 | EXL2 2.5 bpw | 70B | - | - | - | - | 30 t/s | **35** t/s | +| TinyLlama | EXL2 3.0 bpw | 1.1B | - | - | - | - | 536 t/s | **635** t/s | +| TinyLlama | EXL2 4.0 bpw | 1.1B | - | - | - | - | 509 t/s | **590** t/s | diff --git a/third_party/FastChat/docs/gptq.md b/third_party/FastChat/docs/gptq.md new file mode 100644 index 00000000..4078d1e0 --- /dev/null +++ b/third_party/FastChat/docs/gptq.md @@ -0,0 +1,59 @@ +# GPTQ 4bit Inference + +Support GPTQ 4bit inference with [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa). + +1. Window user: use the `old-cuda` branch. +2. Linux user: recommend the `fastest-inference-4bit` branch. + +## Install + +Setup environment: +```bash +# cd /path/to/FastChat +git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa.git repositories/GPTQ-for-LLaMa +cd repositories/GPTQ-for-LLaMa +# Window's user should use the `old-cuda` branch +git switch fastest-inference-4bit +# Install `quant-cuda` package in FastChat's virtualenv +python3 setup_cuda.py install +pip3 install texttable +``` + +Chat with the CLI: +```bash +python3 -m fastchat.serve.cli \ + --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \ + --gptq-wbits 4 \ + --gptq-groupsize 128 +``` + +Start model worker: +```bash +# Download quantized model from huggingface +# Make sure you have git-lfs installed (https://git-lfs.com) +git lfs install +git clone https://huggingface.co/TheBloke/vicuna-7B-1.1-GPTQ-4bit-128g models/vicuna-7B-1.1-GPTQ-4bit-128g + +python3 -m fastchat.serve.model_worker \ + --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \ + --gptq-wbits 4 \ + --gptq-groupsize 128 + +# You can specify which quantized model to use +python3 -m fastchat.serve.model_worker \ + --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \ + --gptq-ckpt models/vicuna-7B-1.1-GPTQ-4bit-128g/vicuna-7B-1.1-GPTQ-4bit-128g.safetensors \ + --gptq-wbits 4 \ + --gptq-groupsize 128 \ + --gptq-act-order +``` + +## Benchmark + +| LLaMA-13B | branch | Bits | group-size | memory(MiB) | PPL(c4) | Median(s/token) | act-order | speed up | +| --------- | ---------------------- | ---- | ---------- | ----------- | ------- | --------------- | --------- | -------- | +| FP16 | fastest-inference-4bit | 16 | - | 26634 | 6.96 | 0.0383 | - | 1x | +| GPTQ | triton | 4 | 128 | 8590 | 6.97 | 0.0551 | - | 0.69x | +| GPTQ | fastest-inference-4bit | 4 | 128 | 8699 | 6.97 | 0.0429 | true | 0.89x | +| GPTQ | fastest-inference-4bit | 4 | 128 | 8699 | 7.03 | 0.0287 | false | 1.33x | +| GPTQ | fastest-inference-4bit | 4 | -1 | 8448 | 7.12 | 0.0284 | false | 1.44x | diff --git a/third_party/FastChat/docs/langchain_integration.md b/third_party/FastChat/docs/langchain_integration.md new file mode 100644 index 00000000..50174a85 --- /dev/null +++ b/third_party/FastChat/docs/langchain_integration.md @@ -0,0 +1,90 @@ +# Local LangChain with FastChat + +[LangChain](https://python.langchain.com/en/latest/index.html) is a library that facilitates the development of applications by leveraging large language models (LLMs) and enabling their composition with other sources of computation or knowledge. +FastChat's OpenAI-compatible [API server](openai_api.md) enables using LangChain with open models seamlessly. + +## Launch RESTful API Server + +Here are the steps to launch a local OpenAI API server for LangChain. + +First, launch the controller + +```bash +python3 -m fastchat.serve.controller +``` + +LangChain uses OpenAI model names by default, so we need to assign some faux OpenAI model names to our local model. +Here, we use Vicuna as an example and use it for three endpoints: chat completion, completion, and embedding. +`--model-path` can be a local folder or a Hugging Face repo name. +See a full list of supported models [here](../README.md#supported-models). + +```bash +python3 -m fastchat.serve.model_worker --model-names "gpt-3.5-turbo,text-davinci-003,text-embedding-ada-002" --model-path lmsys/vicuna-7b-v1.5 +``` + +Finally, launch the RESTful API server + +```bash +python3 -m fastchat.serve.openai_api_server --host localhost --port 8000 +``` + +## Set OpenAI Environment + +You can set your environment with the following commands. + +Set OpenAI base url + +```bash +export OPENAI_API_BASE=http://localhost:8000/v1 +``` + +Set OpenAI API key + +```bash +export OPENAI_API_KEY=EMPTY +``` + +If you meet the following OOM error while creating embeddings, please set a smaller batch size by using environment variables. + +~~~bash +openai.error.APIError: Invalid response object from API: '{"object":"error","message":"**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**\\n\\n(CUDA out of memory. Tried to allocate xxx MiB (GPU 0; xxx GiB total capacity; xxx GiB already allocated; xxx MiB free; xxx GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF)","code":50002}' (HTTP response code was 400) +~~~ + +You can try `export FASTCHAT_WORKER_API_EMBEDDING_BATCH_SIZE=1`. + +## Try local LangChain + +Here is a question answerting example. + +Download a text file. + +```bash +wget https://raw.githubusercontent.com/hwchase17/langchain/v0.0.200/docs/modules/state_of_the_union.txt +``` + +Run LangChain. + +~~~py +from langchain.chat_models import ChatOpenAI +from langchain.document_loaders import TextLoader +from langchain.embeddings import OpenAIEmbeddings +from langchain.indexes import VectorstoreIndexCreator + +embedding = OpenAIEmbeddings(model="text-embedding-ada-002") +loader = TextLoader("state_of_the_union.txt") +index = VectorstoreIndexCreator(embedding=embedding).from_loaders([loader]) +llm = ChatOpenAI(model="gpt-3.5-turbo") + +questions = [ + "Who is the speaker", + "What did the president say about Ketanji Brown Jackson", + "What are the threats to America", + "Who are mentioned in the speech", + "Who is the vice president", + "How many projects were announced", +] + +for query in questions: + print("Query:", query) + print("Answer:", index.query(query, llm=llm)) +~~~ diff --git a/third_party/FastChat/docs/model_support.md b/third_party/FastChat/docs/model_support.md new file mode 100644 index 00000000..472cc0a4 --- /dev/null +++ b/third_party/FastChat/docs/model_support.md @@ -0,0 +1,97 @@ +# Model Support + +## Supported models + +- [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) + - example: `python3 -m fastchat.serve.cli --model-path meta-llama/Llama-2-7b-chat-hf` +- Vicuna, Alpaca, LLaMA, Koala + - example: `python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5` +- [allenai/tulu-2-dpo-7b](https://huggingface.co/allenai/tulu-2-dpo-7b) +- [BAAI/AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B) +- [BAAI/AquilaChat2-7B](https://huggingface.co/BAAI/AquilaChat2-7B) +- [BAAI/AquilaChat2-34B](https://huggingface.co/BAAI/AquilaChat2-34B) +- [BAAI/bge-large-en](https://huggingface.co/BAAI/bge-large-en#using-huggingface-transformers) +- [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B) +- [BlinkDL/RWKV-4-Raven](https://huggingface.co/BlinkDL/rwkv-4-raven) + - example: `python3 -m fastchat.serve.cli --model-path ~/model_weights/RWKV-4-Raven-7B-v11x-Eng99%-Other1%-20230429-ctx8192.pth` +- [bofenghuang/vigogne-2-7b-instruct](https://huggingface.co/bofenghuang/vigogne-2-7b-instruct) +- [bofenghuang/vigogne-2-7b-chat](https://huggingface.co/bofenghuang/vigogne-2-7b-chat) +- [camel-ai/CAMEL-13B-Combined-Data](https://huggingface.co/camel-ai/CAMEL-13B-Combined-Data) +- [codellama/CodeLlama-7b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf) +- [databricks/dolly-v2-12b](https://huggingface.co/databricks/dolly-v2-12b) +- [deepseek-ai/deepseek-llm-67b-chat](https://huggingface.co/deepseek-ai/deepseek-llm-67b-chat) +- [deepseek-ai/deepseek-coder-33b-instruct](https://huggingface.co/deepseek-ai/deepseek-coder-33b-instruct) +- [FlagAlpha/Llama2-Chinese-13b-Chat](https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat) +- [FreedomIntelligence/phoenix-inst-chat-7b](https://huggingface.co/FreedomIntelligence/phoenix-inst-chat-7b) +- [FreedomIntelligence/ReaLM-7b-v1](https://huggingface.co/FreedomIntelligence/Realm-7b) +- [h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b](https://huggingface.co/h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b) +- [HuggingFaceH4/starchat-beta](https://huggingface.co/HuggingFaceH4/starchat-beta) +- [HuggingFaceH4/zephyr-7b-alpha](https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha) +- [internlm/internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b) +- [lcw99/polyglot-ko-12.8b-chang-instruct-chat](https://huggingface.co/lcw99/polyglot-ko-12.8b-chang-instruct-chat) +- [lmsys/fastchat-t5-3b-v1.0](https://huggingface.co/lmsys/fastchat-t5) +- [meta-math/MetaMath-7B-V1.0](https://huggingface.co/meta-math/MetaMath-7B-V1.0) +- [Microsoft/Orca-2-7b](https://huggingface.co/microsoft/Orca-2-7b) +- [mosaicml/mpt-7b-chat](https://huggingface.co/mosaicml/mpt-7b-chat) + - example: `python3 -m fastchat.serve.cli --model-path mosaicml/mpt-7b-chat` +- [Neutralzz/BiLLa-7B-SFT](https://huggingface.co/Neutralzz/BiLLa-7B-SFT) +- [nomic-ai/gpt4all-13b-snoozy](https://huggingface.co/nomic-ai/gpt4all-13b-snoozy) +- [argilla/notus-7b-v1](https://huggingface.co/argilla/notus-7b-v1) +- [NousResearch/Nous-Hermes-13b](https://huggingface.co/NousResearch/Nous-Hermes-13b) +- [openaccess-ai-collective/manticore-13b-chat-pyg](https://huggingface.co/openaccess-ai-collective/manticore-13b-chat-pyg) +- [OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5](https://huggingface.co/OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5) +- [openchat/openchat_3.5](https://huggingface.co/openchat/openchat_3.5) +- [Open-Orca/Mistral-7B-OpenOrca](https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca) +- [OpenLemur/lemur-70b-chat-v1](https://huggingface.co/OpenLemur/lemur-70b-chat-v1) +- [VMware/open-llama-7b-v2-open-instruct](https://huggingface.co/VMware/open-llama-7b-v2-open-instruct) +- [Phind/Phind-CodeLlama-34B-v2](https://huggingface.co/Phind/Phind-CodeLlama-34B-v2) +- [project-baize/baize-v2-7b](https://huggingface.co/project-baize/baize-v2-7b) +- [Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat) +- [rishiraj/CatPPT](https://huggingface.co/rishiraj/CatPPT) +- [tenyx/TenyxChat-7B-v1](https://huggingface.co/tenyx/TenyxChat-7B-v1) +- [TinyLlama/TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0) +- [Salesforce/codet5p-6b](https://huggingface.co/Salesforce/codet5p-6b) +- [StabilityAI/stablelm-tuned-alpha-7b](https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b) +- [THUDM/chatglm-6b](https://huggingface.co/THUDM/chatglm-6b) +- [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) +- [tiiuae/falcon-40b](https://huggingface.co/tiiuae/falcon-40b) +- [tiiuae/falcon-180B-chat](https://huggingface.co/tiiuae/falcon-180B-chat) +- [timdettmers/guanaco-33b-merged](https://huggingface.co/timdettmers/guanaco-33b-merged) +- [togethercomputer/RedPajama-INCITE-7B-Chat](https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Chat) +- [WizardLM/WizardLM-13B-V1.0](https://huggingface.co/WizardLM/WizardLM-13B-V1.0) +- [WizardLM/WizardCoder-15B-V1.0](https://huggingface.co/WizardLM/WizardCoder-15B-V1.0) +- [Xwin-LM/Xwin-LM-7B-V0.1](https://huggingface.co/Xwin-LM/Xwin-LM-70B-V0.1) +- [IEITYuan/Yuan2-2B/51B/102B-hf](https://huggingface.co/IEITYuan) +- Any [EleutherAI](https://huggingface.co/EleutherAI) pythia model such as [pythia-6.9b](https://huggingface.co/EleutherAI/pythia-6.9b) +- Any [Peft](https://github.com/huggingface/peft) adapter trained on top of a + model above. To activate, must have `peft` in the model path. Note: If + loading multiple peft models, you can have them share the base model weights by + setting the environment variable `PEFT_SHARE_BASE_WEIGHTS=true` in any model + worker. + +## How to support a new model + +To support a new model in FastChat, you need to correctly handle its prompt template and model loading. +The goal is to make the following command run with the correct prompts. + +``` +python3 -m fastchat.serve.cli --model [YOUR_MODEL_PATH] +``` + +You can run this example command to learn the code logic. + +``` +python3 -m fastchat.serve.cli --model lmsys/vicuna-7b-v1.5 +``` + +You can add `--debug` to see the actual prompt sent to the model. + +### Steps + +FastChat uses the `Conversation` class to handle prompt templates and `BaseModelAdapter` class to handle model loading. + +1. Implement a conversation template for the new model at [fastchat/conversation.py](https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py). You can follow existing examples and use `register_conv_template` to add a new one. Please also add a link to the official reference code if possible. +2. Implement a model adapter for the new model at [fastchat/model/model_adapter.py](https://github.com/lm-sys/FastChat/blob/main/fastchat/model/model_adapter.py). You can follow existing examples and use `register_model_adapter` to add a new one. +3. (Optional) add the model name to the "Supported models" [section](#supported-models) above and add more information in [fastchat/model/model_registry.py](https://github.com/lm-sys/FastChat/blob/main/fastchat/model/model_registry.py). + +After these steps, the new model should be compatible with most FastChat features, such as CLI, web UI, model worker, and OpenAI-compatible API server. Please do some testing with these features as well. diff --git a/third_party/FastChat/docs/openai_api.md b/third_party/FastChat/docs/openai_api.md new file mode 100644 index 00000000..65ab4f76 --- /dev/null +++ b/third_party/FastChat/docs/openai_api.md @@ -0,0 +1,150 @@ +# OpenAI-Compatible RESTful APIs + +FastChat provides OpenAI-compatible APIs for its supported models, so you can use FastChat as a local drop-in replacement for OpenAI APIs. +The FastChat server is compatible with both [openai-python](https://github.com/openai/openai-python) library and cURL commands. + +The following OpenAI APIs are supported: +- Chat Completions. (Reference: https://platform.openai.com/docs/api-reference/chat) +- Completions. (Reference: https://platform.openai.com/docs/api-reference/completions) +- Embeddings. (Reference: https://platform.openai.com/docs/api-reference/embeddings) + +## RESTful API Server +First, launch the controller + +```bash +python3 -m fastchat.serve.controller +``` + +Then, launch the model worker(s) + +```bash +python3 -m fastchat.serve.model_worker --model-path lmsys/vicuna-7b-v1.5 +``` + +Finally, launch the RESTful API server + +```bash +python3 -m fastchat.serve.openai_api_server --host localhost --port 8000 +``` + +Now, let us test the API server. + +### OpenAI Official SDK +The goal of `openai_api_server.py` is to implement a fully OpenAI-compatible API server, so the models can be used directly with [openai-python](https://github.com/openai/openai-python) library. + +First, install OpenAI python package >= 1.0: +```bash +pip install --upgrade openai +``` + +Then, interact with the Vicuna model: +```python +import openai + +openai.api_key = "EMPTY" +openai.base_url = "http://localhost:8000/v1/" + +model = "vicuna-7b-v1.5" +prompt = "Once upon a time" + +# create a completion +completion = openai.completions.create(model=model, prompt=prompt, max_tokens=64) +# print the completion +print(prompt + completion.choices[0].text) + +# create a chat completion +completion = openai.chat.completions.create( + model=model, + messages=[{"role": "user", "content": "Hello! What is your name?"}] +) +# print the completion +print(completion.choices[0].message.content) +``` + +Streaming is also supported. See [test_openai_api.py](../tests/test_openai_api.py). If your api server is behind a proxy you'll need to turn off buffering, you can do so in Nginx by setting `proxy_buffering off;` in the location block for the proxy. + +### cURL +cURL is another good tool for observing the output of the api. + +List Models: +```bash +curl http://localhost:8000/v1/models +``` + +Chat Completions: +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "vicuna-7b-v1.5", + "messages": [{"role": "user", "content": "Hello! What is your name?"}] + }' +``` + +Text Completions: +```bash +curl http://localhost:8000/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "vicuna-7b-v1.5", + "prompt": "Once upon a time", + "max_tokens": 41, + "temperature": 0.5 + }' +``` + +Embeddings: +```bash +curl http://localhost:8000/v1/embeddings \ + -H "Content-Type: application/json" \ + -d '{ + "model": "vicuna-7b-v1.5", + "input": "Hello world!" + }' +``` + +### Running multiple + +If you want to run multiple models on the same machine and in the same process, +you can replace the `model_worker` step above with a multi model variant: + +```bash +python3 -m fastchat.serve.multi_model_worker \ + --model-path lmsys/vicuna-7b-v1.5 \ + --model-names vicuna-7b-v1.5 \ + --model-path lmsys/longchat-7b-16k \ + --model-names longchat-7b-16k +``` + +This loads both models into the same accelerator and in the same process. This +works best when using a Peft model that triggers the `PeftModelAdapter`. + +TODO: Base model weight optimization will be fixed once [this +Peft](https://github.com/huggingface/peft/issues/430) issue is resolved. + +## LangChain Support +This OpenAI-compatible API server supports LangChain. See [LangChain Integration](langchain_integration.md) for details. + +## Adjusting Environment Variables + +### Timeout +By default, a timeout error will occur if a model worker does not response within 100 seconds. If your model/hardware is slower, you can change this timeout through an environment variable: + +```bash +export FASTCHAT_WORKER_API_TIMEOUT= +``` + +### Batch size +If you meet the following OOM error while creating embeddings. You can use a smaller batch size by setting + +```bash +export FASTCHAT_WORKER_API_EMBEDDING_BATCH_SIZE=1 +``` + +## Todos +Some features to be implemented: + +- [ ] Support more parameters like `logprobs`, `logit_bias`, `user`, `presence_penalty` and `frequency_penalty` +- [ ] Model details (permissions, owner and create time) +- [ ] Edits API +- [ ] Rate Limitation Settings diff --git a/third_party/FastChat/docs/server_arch.md b/third_party/FastChat/docs/server_arch.md new file mode 100644 index 00000000..1ccc8a16 --- /dev/null +++ b/third_party/FastChat/docs/server_arch.md @@ -0,0 +1,2 @@ +# FastChat Server Architecture +![server arch](../assets/server_arch.png) diff --git a/third_party/FastChat/docs/third_party_ui.md b/third_party/FastChat/docs/third_party_ui.md new file mode 100644 index 00000000..c0b23015 --- /dev/null +++ b/third_party/FastChat/docs/third_party_ui.md @@ -0,0 +1,24 @@ +# Third Party UI +If you want to host it on your own UI or third party UI, you can launch the [OpenAI compatible server](openai_api.md) and host with a tunnelling service such as Tunnelmole or ngrok, and then enter the credentials appropriately. + +You can find suitable UIs from third party repos: +- [WongSaang's ChatGPT UI](https://github.com/WongSaang/chatgpt-ui) +- [McKayWrigley's Chatbot UI](https://github.com/mckaywrigley/chatbot-ui) + +- Please note that some third-party providers only offer the standard `gpt-3.5-turbo`, `gpt-4`, etc., so you will have to add your own custom model inside the code. [Here is an example of how to create a UI with any custom model name](https://github.com/ztjhz/BetterChatGPT/pull/461). + +##### Using Tunnelmole +Tunnelmole is an open source tunnelling tool. You can find its source code on [Github](https://github.com/robbie-cahill/tunnelmole-client). Here's how you can use Tunnelmole: +1. Install Tunnelmole with `curl -O https://install.tunnelmole.com/9Wtxu/install && sudo bash install`. (On Windows, download [tmole.exe](https://tunnelmole.com/downloads/tmole.exe)). Head over to the [README](https://github.com/robbie-cahill/tunnelmole-client) for other methods such as `npm` or building from source. +2. Run `tmole 7860` (replace `7860` with your listening port if it is different from 7860). The output will display two URLs: one HTTP and one HTTPS. It's best to use the HTTPS URL for better privacy and security. +``` +➜ ~ tmole 7860 +http://bvdo5f-ip-49-183-170-144.tunnelmole.net is forwarding to localhost:7860 +https://bvdo5f-ip-49-183-170-144.tunnelmole.net is forwarding to localhost:7860 +``` + +##### Using ngrok +ngrok is a popular closed source tunnelling tool. First download and install it from [ngrok.com](https://ngrok.com/downloads). Here's how to use it to expose port 7860. +``` +ngrok http 7860 +``` diff --git a/third_party/FastChat/docs/training.md b/third_party/FastChat/docs/training.md new file mode 100644 index 00000000..87b87312 --- /dev/null +++ b/third_party/FastChat/docs/training.md @@ -0,0 +1,118 @@ +### Fine-tuning FastChat-T5 +You can use the following command to train FastChat-T5 with 4 x A100 (40GB). +```bash +torchrun --nproc_per_node=4 --master_port=9778 fastchat/train/train_flant5.py \ + --model_name_or_path google/flan-t5-xl \ + --data_path ./data/dummy_conversation.json \ + --bf16 True \ + --output_dir ./checkpoints_flant5_3b \ + --num_train_epochs 3 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 4 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 300 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --fsdp "full_shard auto_wrap" \ + --fsdp_transformer_layer_cls_to_wrap T5Block \ + --tf32 True \ + --model_max_length 2048 \ + --preprocessed_path ./preprocessed_data/processed.json \ + --gradient_checkpointing True +``` + +After training, please use our post-processing [function](https://github.com/lm-sys/FastChat/blob/55051ad0f23fef5eeecbda14a2e3e128ffcb2a98/fastchat/utils.py#L166-L185) to update the saved model weight. Additional discussions can be found [here](https://github.com/lm-sys/FastChat/issues/643). + +### Fine-tuning using (Q)LoRA +You can use the following command to train Vicuna-7B using QLoRA using ZeRO2. Note that ZeRO3 is not currently supported with QLoRA but ZeRO3 does support LoRA, which has a reference configuraiton under playground/deepspeed_config_s3.json. To use QLoRA, you must have bitsandbytes>=0.39.0 and transformers>=4.30.0 installed. +```bash +deepspeed fastchat/train/train_lora.py \ + --model_name_or_path ~/model_weights/llama-7b \ + --lora_r 8 \ + --lora_alpha 16 \ + --lora_dropout 0.05 \ + --data_path ./data/dummy_conversation.json \ + --bf16 True \ + --output_dir ./checkpoints \ + --num_train_epochs 3 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 1200 \ + --save_total_limit 100 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --q_lora True \ + --deepspeed playground/deepspeed_config_s2.json \ +``` + +For T5-XL or XXL + +```bash +deepspeed fastchat/train/train_lora_t5.py \ + --model_name_or_path google/flan-t5-xl \ + --data_path ./data/dummy_conversation.json \ + --bf16 True \ + --output_dir ./checkpoints_flant5_3b \ + --num_train_epochs 3 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 4 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 300 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --model_max_length 2048 \ + --preprocessed_path ./preprocessed_data/processed.json \ + --gradient_checkpointing True \ + --q_lora True \ + --deepspeed playground/deepspeed_config_s2.json + +``` + +### Fine-tuning Vicuna-7B with Local NPUs + +You can use the following command to train Vicuna-7B with 8 x NPUs. Use `--nproc_per_node` to specify the number of NPUs. +```bash +torchrun --nproc_per_node=8 --master_port=20001 fastchat/train/train.py \ + --model_name_or_path ~/vicuna-7b-v1.5-16k \ + --data_path data/dummy_conversation.json \ + --fp16 True \ + --output_dir output_vicuna \ + --num_train_epochs 3 \ + --per_device_train_batch_size 8 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 1200 \ + --save_total_limit 10 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --fsdp "full_shard auto_wrap" \ + --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --lazy_preprocess True +``` diff --git a/third_party/FastChat/docs/vicuna_weights_version.md b/third_party/FastChat/docs/vicuna_weights_version.md new file mode 100644 index 00000000..7fbb5625 --- /dev/null +++ b/third_party/FastChat/docs/vicuna_weights_version.md @@ -0,0 +1,97 @@ +## Vicuna Weights + +| Weights version | Link | FastChat version compatibility | Base Model | Release Date | Fine-tuning Data | +| ---- | ---- | ---- | ---- | ---- | ---- | +| v1.5 | [7B](https://huggingface.co/lmsys/vicuna-7b-v1.5), [7B-16k](https://huggingface.co/lmsys/vicuna-7b-v1.5-16k), [13B](https://huggingface.co/lmsys/vicuna-13b-v1.5), [13B-16k](https://huggingface.co/lmsys/vicuna-13b-v1.5-16k) | `>=0.2.21` | Llama 2 | Aug. 1, 2023 | 370M tokens | +| v1.3 | [7B](https://huggingface.co/lmsys/vicuna-7b-v1.3), [13B](https://huggingface.co/lmsys/vicuna-13b-v1.3), [33B](//huggingface.co/lmsys/vicuna-33b-v1.3) | `>=0.2.1` | Llama 1 | Jun. 22, 2023 | 370M tokens | +| v1.1 | [7B](https://huggingface.co/lmsys/vicuna-7b-v1.1), [13B](https://huggingface.co/lmsys/vicuna-13b-v1.1) | `>=0.2.1` | Llama 1 | Apr. 12, 2023 | - | +| v0 | [7B-delta](https://huggingface.co/lmsys/vicuna-7b-delta-v0), [13B-delta](https://huggingface.co/lmsys/vicuna-13b-delta-v0) | `<=0.1.10` | Llama 1 | Mar. 30, 2023 | - | + +### Updates +- Major updates of weights v1.5 + - Use Llama2 as the base model. + - Provide 16K context length versions using linear RoPE scaling. + +- Major updates of weights v1.3 + - Train with twice the amount of ShareGPT data compared to previous versions. + - Provide merged weights directly instead of delta weights. + +- Major updates of weights v1.1 + - Refactor the tokenization and separator. In Vicuna v1.1, the separator has been changed from `###` to the EOS token `
`. This change makes it easier to determine the generation stop criteria and enables better compatibility with other libraries. + - Fix the supervised fine-tuning loss computation for better model quality. + +## Prompt Template + +### Example prompt (weights v1.1, v1.3, v1.5) +``` +A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. + +USER: Hello! +ASSISTANT: Hello! +USER: How are you? +ASSISTANT: I am good. +``` + +See a full prompt template [here](https://github.com/lm-sys/FastChat/blob/d578599c69d060e6d40943f1b5b72af98956092a/fastchat/conversation.py#L286-L299) and example output [here](https://github.com/lm-sys/FastChat/blob/d578599c69d060e6d40943f1b5b72af98956092a/fastchat/conversation.py#L748-L753). + +### Example prompt (weights v0) +``` +A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. + +### Human: Hello! +### Assistant: Hello! +### Human: How are you? +### Assistant: I am good. +``` + +See the full prompt template [here](https://github.com/lm-sys/FastChat/blob/d578599c69d060e6d40943f1b5b72af98956092a/fastchat/conversation.py#L238-L269). + +## How to Apply Delta Weights (Only Needed for Weights v0) + +We release [Vicuna](https://lmsys.org/blog/2023-03-30-vicuna/) weights v0 as delta weights to comply with the LLaMA model license. +You can add our delta to the original LLaMA weights to obtain the Vicuna weights. Instructions: + +1. Get the original LLaMA weights in the Hugging Face format by following the instructions [here](https://huggingface.co/docs/transformers/main/model_doc/llama). +2. Use the following scripts to get Vicuna weights by applying our delta. They will automatically download delta weights from our Hugging Face [account](https://huggingface.co/lmsys). + +**NOTE**: +Weights v1.1 are only compatible with ```transformers>=4.28.0``` and ``fschat >= 0.2.0``. +Please update your local packages accordingly. If you follow the above commands to do a fresh install, then you should get all the correct versions. + +#### Vicuna-7B +This conversion command needs around 30 GB of CPU RAM. +See the "Low CPU Memory Conversion" section below if you do not have enough memory. +Replace `/path/to/*` with the real paths. +```bash +python3 -m fastchat.model.apply_delta \ + --base-model-path /path/to/llama-7b \ + --target-model-path /path/to/output/vicuna-7b \ + --delta-path lmsys/vicuna-7b-delta-v1.1 +``` + +#### Vicuna-13B +This conversion command needs around 60 GB of CPU RAM. +See the "Low CPU Memory Conversion" section below if you do not have enough memory. +Replace `/path/to/*` with the real paths. +```bash +python3 -m fastchat.model.apply_delta \ + --base-model-path /path/to/llama-13b \ + --target-model-path /path/to/output/vicuna-13b \ + --delta-path lmsys/vicuna-13b-delta-v1.1 +``` + +#### Low CPU Memory Conversion +You can try these methods to reduce the CPU RAM requirement of weight conversion. +1. Append `--low-cpu-mem` to the commands above, which will split large weight files into smaller ones and use the disk as temporary storage. This can keep the peak memory at less than 16GB. +2. Create a large swap file and rely on the operating system to automatically utilize the disk as virtual memory. + +## FAQ + +### Tokenizer issues +There are some frequently asked tokenizer issues (https://github.com/lm-sys/FastChat/issues/408). +Some of them are not only related to FastChat or Vicuna weights but are also related to how you convert the base llama model. + +We suggest that you use `transformers>=4.28.0` and redo the weight conversion for the base llama model. +After applying the delta, you should have a file named `special_tokens_map.json` in your converted weight folder for either v0 or v1.1. +The contents of this file should be the same as this file: https://huggingface.co/lmsys/vicuna-13b-delta-v0/blob/main/special_tokens_map.json. +If the file is not present, please copy the `special_tokens_map.json` and `tokenizer_config.json` files from https://huggingface.co/lmsys/vicuna-13b-delta-v0/tree/main to your converted weight folder. This works for both v0 and v1.1. diff --git a/third_party/FastChat/docs/vllm_integration.md b/third_party/FastChat/docs/vllm_integration.md new file mode 100644 index 00000000..7d3205bb --- /dev/null +++ b/third_party/FastChat/docs/vllm_integration.md @@ -0,0 +1,25 @@ +# vLLM Integration +You can use [vLLM](https://vllm.ai/) as an optimized worker implementation in FastChat. +It offers advanced continuous batching and a much higher (~10x) throughput. +See the supported models [here](https://vllm.readthedocs.io/en/latest/models/supported_models.html). + +## Instructions +1. Install vLLM. + ``` + pip install vllm + ``` + +2. When you launch a model worker, replace the normal worker (`fastchat.serve.model_worker`) with the vLLM worker (`fastchat.serve.vllm_worker`). All other commands such as controller, gradio web server, and OpenAI API server are kept the same. + ``` + python3 -m fastchat.serve.vllm_worker --model-path lmsys/vicuna-7b-v1.5 + ``` + + If you see tokenizer errors, try + ``` + python3 -m fastchat.serve.vllm_worker --model-path lmsys/vicuna-7b-v1.5 --tokenizer hf-internal-testing/llama-tokenizer + ``` + + If you use an AWQ quantized model, try + ''' + python3 -m fastchat.serve.vllm_worker --model-path TheBloke/vicuna-7B-v1.5-AWQ --quantization awq + ''' diff --git a/third_party/FastChat/docs/xFasterTransformer.md b/third_party/FastChat/docs/xFasterTransformer.md new file mode 100644 index 00000000..e9cdae43 --- /dev/null +++ b/third_party/FastChat/docs/xFasterTransformer.md @@ -0,0 +1,90 @@ +# xFasterTransformer Inference Framework + +Integrated [xFasterTransformer](https://github.com/intel/xFasterTransformer) customized framework into Fastchat to provide **Faster** inference speed on Intel CPU. + +## Install xFasterTransformer + +Setup environment (please refer to [this link](https://github.com/intel/xFasterTransformer#installation) for more details): + +```bash +pip install xfastertransformer +``` + +## Prepare models + +Prepare Model (please refer to [this link](https://github.com/intel/xFasterTransformer#prepare-model) for more details): +```bash +python ./tools/chatglm_convert.py -i ${HF_DATASET_DIR} -o ${OUTPUT_DIR} +``` + +## Parameters of xFasterTransformer +--enable-xft to enable xfastertransformer in Fastchat +--xft-max-seq-len to set the max token length the model can process. max token length include input token length. +--xft-dtype to set datatype used in xFasterTransformer for computation. xFasterTransformer can support fp32, fp16, int8, bf16 and hybrid data types like : bf16_fp16, bf16_int8. For datatype details please refer to [this link](https://github.com/intel/xFasterTransformer/wiki/Data-Type-Support-Platform) + + +Chat with the CLI: +```bash +#run inference on all CPUs and using float16 +python3 -m fastchat.serve.cli \ + --model-path /path/to/models \ + --enable-xft \ + --xft-dtype fp16 +``` +or with numactl on multi-socket server for better performance +```bash +#run inference on numanode 0 and with data type bf16_fp16 (first token uses bfloat16, and rest tokens use float16) +numactl -N 0 --localalloc \ +python3 -m fastchat.serve.cli \ + --model-path /path/to/models/chatglm2_6b_cpu/ \ + --enable-xft \ + --xft-dtype bf16_fp16 +``` +or using MPI to run inference on 2 sockets for better performance +```bash +#run inference on numanode 0 and 1 and with data type bf16_fp16 (first token uses bfloat16, and rest tokens use float16) +OMP_NUM_THREADS=$CORE_NUM_PER_SOCKET LD_PRELOAD=libiomp5.so mpirun \ +-n 1 numactl -N 0 --localalloc \ +python -m fastchat.serve.cli \ + --model-path /path/to/models/chatglm2_6b_cpu/ \ + --enable-xft \ + --xft-dtype bf16_fp16 : \ +-n 1 numactl -N 1 --localalloc \ +python -m fastchat.serve.cli \ + --model-path /path/to/models/chatglm2_6b_cpu/ \ + --enable-xft \ + --xft-dtype bf16_fp16 +``` + + +Start model worker: +```bash +# Load model with default configuration (max sequence length 4096, no GPU split setting). +python3 -m fastchat.serve.model_worker \ + --model-path /path/to/models \ + --enable-xft \ + --xft-dtype bf16_fp16 +``` +or with numactl on multi-socket server for better performance +```bash +#run inference on numanode 0 and with data type bf16_fp16 (first token uses bfloat16, and rest tokens use float16) +numactl -N 0 --localalloc python3 -m fastchat.serve.model_worker \ + --model-path /path/to/models \ + --enable-xft \ + --xft-dtype bf16_fp16 +``` +or using MPI to run inference on 2 sockets for better performance +```bash +#run inference on numanode 0 and 1 and with data type bf16_fp16 (first token uses bfloat16, and rest tokens use float16) +OMP_NUM_THREADS=$CORE_NUM_PER_SOCKET LD_PRELOAD=libiomp5.so mpirun \ +-n 1 numactl -N 0 --localalloc python -m fastchat.serve.model_worker \ + --model-path /path/to/models \ + --enable-xft \ + --xft-dtype bf16_fp16 : \ +-n 1 numactl -N 1 --localalloc python -m fastchat.serve.model_worker \ + --model-path /path/to/models \ + --enable-xft \ + --xft-dtype bf16_fp16 +``` + +For more details, please refer to [this link](https://github.com/intel/xFasterTransformer#how-to-run) diff --git a/third_party/FastChat/fastchat/__init__.py b/third_party/FastChat/fastchat/__init__.py new file mode 100644 index 00000000..9b1b78b4 --- /dev/null +++ b/third_party/FastChat/fastchat/__init__.py @@ -0,0 +1 @@ +__version__ = "0.2.35" diff --git a/third_party/FastChat/fastchat/constants.py b/third_party/FastChat/fastchat/constants.py new file mode 100644 index 00000000..24e1783a --- /dev/null +++ b/third_party/FastChat/fastchat/constants.py @@ -0,0 +1,66 @@ +""" +Global constants. +""" + +from enum import IntEnum +import os + +REPO_PATH = os.path.dirname(os.path.dirname(__file__)) + +##### For the gradio web server +SERVER_ERROR_MSG = ( + "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" +) +MODERATION_MSG = "$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES." +CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION." +INACTIVE_MSG = "THIS SESSION HAS BEEN INACTIVE FOR TOO LONG. PLEASE REFRESH THIS PAGE." +SLOW_MODEL_MSG = "⚠️ Both models will show the responses all at once. Please stay patient as it may take over 30 seconds." +RATE_LIMIT_MSG = "**RATE LIMIT OF THIS MODEL IS REACHED. PLEASE COME BACK LATER OR TRY OTHER MODELS.**" +# Maximum input length +INPUT_CHAR_LEN_LIMIT = int(os.getenv("FASTCHAT_INPUT_CHAR_LEN_LIMIT", 12000)) +# Maximum conversation turns +CONVERSATION_TURN_LIMIT = 50 +# Session expiration time +SESSION_EXPIRATION_TIME = 3600 +# The output dir of log files +LOGDIR = os.getenv("LOGDIR", ".") +# CPU Instruction Set Architecture +CPU_ISA = os.getenv("CPU_ISA") + + +##### For the controller and workers (could be overwritten through ENV variables.) +CONTROLLER_HEART_BEAT_EXPIRATION = int( + os.getenv("FASTCHAT_CONTROLLER_HEART_BEAT_EXPIRATION", 90) +) +WORKER_HEART_BEAT_INTERVAL = int(os.getenv("FASTCHAT_WORKER_HEART_BEAT_INTERVAL", 45)) +WORKER_API_TIMEOUT = int(os.getenv("FASTCHAT_WORKER_API_TIMEOUT", 100)) +WORKER_API_EMBEDDING_BATCH_SIZE = int( + os.getenv("FASTCHAT_WORKER_API_EMBEDDING_BATCH_SIZE", 4) +) + + +class ErrorCode(IntEnum): + """ + https://platform.openai.com/docs/guides/error-codes/api-errors + """ + + VALIDATION_TYPE_ERROR = 40001 + + INVALID_AUTH_KEY = 40101 + INCORRECT_AUTH_KEY = 40102 + NO_PERMISSION = 40103 + + INVALID_MODEL = 40301 + PARAM_OUT_OF_RANGE = 40302 + CONTEXT_OVERFLOW = 40303 + + RATE_LIMIT = 42901 + QUOTA_EXCEEDED = 42902 + ENGINE_OVERLOADED = 42903 + + INTERNAL_ERROR = 50001 + CUDA_OUT_OF_MEMORY = 50002 + GRADIO_REQUEST_ERROR = 50003 + GRADIO_STREAM_UNKNOWN_ERROR = 50004 + CONTROLLER_NO_WORKER = 50005 + CONTROLLER_WORKER_TIMEOUT = 50006 diff --git a/third_party/FastChat/fastchat/conversation.py b/third_party/FastChat/fastchat/conversation.py new file mode 100644 index 00000000..53d52c6d --- /dev/null +++ b/third_party/FastChat/fastchat/conversation.py @@ -0,0 +1,1567 @@ +""" +Conversation prompt templates. + +We kindly request that you import fastchat instead of copying this file if you wish to use it. +If you have any changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates. +""" + +import dataclasses +from enum import auto, IntEnum +from typing import List, Any, Dict, Union, Tuple + + +class SeparatorStyle(IntEnum): + """Separator styles.""" + + ADD_COLON_SINGLE = auto() + ADD_COLON_TWO = auto() + ADD_COLON_SPACE_SINGLE = auto() + NO_COLON_SINGLE = auto() + NO_COLON_TWO = auto() + ADD_NEW_LINE_SINGLE = auto() + LLAMA2 = auto() + MINIMA = auto() + CHATGLM = auto() + CHATML = auto() + CHATMLQANY = auto() + CHATINTERN = auto() + DOLLY = auto() + RWKV = auto() + PHOENIX = auto() + ROBIN = auto() + FALCON_CHAT = auto() + CHATGLM3 = auto() + DEEPSEEK_CHAT = auto() + METAMATH = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that manages prompt templates and keeps all conversation history.""" + + # The name of this template + name: str + # The template of the system prompt + system_template: str = "{system_message}" + # The system message + system_message: str = "" + # The names of two roles + roles: Tuple[str] = ("USER", "ASSISTANT") + # All messages. Each item is (role, message). + messages: List[List[str]] = () + # The number of few shot examples + offset: int = 0 + # The separator style and configurations + sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE + sep: str = "\n" + sep2: str = None + # Stop criteria (the default one is EOS token) + stop_str: Union[str, List[str]] = None + # Stops generation if meeting any token in this list + stop_token_ids: List[int] = None + + def get_prompt(self) -> str: + """Get the prompt for generation.""" + system_prompt = self.system_template.format(system_message=self.system_message) + if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE: + ret = system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ": " + message + self.sep + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.ADD_COLON_TWO: + seps = [self.sep, self.sep2] + ret = system_prompt + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: + ret = system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ": " + message + self.sep + else: + ret += role + ": " # must be end with a space + return ret + elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: + ret = "" if system_prompt == "" else system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + "\n" + message + self.sep + else: + ret += role + "\n" + return ret + elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE: + ret = system_prompt + for role, message in self.messages: + if message: + ret += role + message + self.sep + else: + ret += role + return ret + elif self.sep_style == SeparatorStyle.NO_COLON_TWO: + seps = [self.sep, self.sep2] + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + message + seps[i % 2] + else: + ret += role + return ret + elif self.sep_style == SeparatorStyle.RWKV: + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + ret += ( + role + + ": " + + message.replace("\r\n", "\n").replace("\n\n", "\n") + ) + ret += "\n\n" + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.LLAMA2: + seps = [self.sep, self.sep2] + if self.system_message: + ret = system_prompt + else: + ret = "[INST] " + for i, (role, message) in enumerate(self.messages): + tag = self.roles[i % 2] + if message: + if i == 0: + ret += message + " " + else: + ret += tag + " " + message + seps[i % 2] + else: + ret += tag + return ret + elif self.sep_style == SeparatorStyle.MINIMA: + seps = [self.sep, self.sep2] + if self.system_message: + ret = system_prompt + else: + ret = " " + for i, (role, message) in enumerate(self.messages): + tag = self.roles[i % 2] + if message: + if i == 0: + ret += message + " " + else: + ret += tag + " " + message + seps[i % 2] + else: + ret += tag + return ret + + elif self.sep_style == SeparatorStyle.CHATGLM: + # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 + # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 + round_add_n = 1 if self.name == "chatglm2" else 0 + if system_prompt: + ret = system_prompt + self.sep + else: + ret = "" + + for i, (role, message) in enumerate(self.messages): + if i % 2 == 0: + ret += f"[Round {i//2 + round_add_n}]{self.sep}" + + if message: + ret += f"{role}:{message}{self.sep}" + else: + ret += f"{role}:" + return ret + elif self.sep_style == SeparatorStyle.CHATML: + ret = "" if system_prompt == "" else system_prompt + self.sep + "\n" + for role, message in self.messages: + if message: + ret += role + "\n" + message + self.sep + "\n" + else: + ret += role + "\n" + return ret + + elif self.sep_style == SeparatorStyle.CHATMLQANY: + ret = "" if system_prompt == "" else system_prompt + self.sep2 + "\n" + for role, message in self.messages: + if message: + ret += role + "\n" + message + self.sep + "\n" + else: + ret += role + "\n" + return ret + elif self.sep_style == SeparatorStyle.CHATGLM3: + ret = "" + if self.system_message: + ret += system_prompt + for role, message in self.messages: + if message: + ret += role + "\n" + message + else: + ret += role + return ret + elif self.sep_style == SeparatorStyle.CHATINTERN: + # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771 + seps = [self.sep, self.sep2] + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if i % 2 == 0: + ret += "" + if message: + ret += role + ":" + message + seps[i % 2] + "\n" + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.DOLLY: + seps = [self.sep, self.sep2] + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ":\n" + message + seps[i % 2] + if i % 2 == 1: + ret += "\n\n" + else: + ret += role + ":\n" + return ret + elif self.sep_style == SeparatorStyle.PHOENIX: + ret = system_prompt + for role, message in self.messages: + if message: + ret += role + ": " + "" + message + "" + else: + ret += role + ": " + "" + return ret + elif self.sep_style == SeparatorStyle.ROBIN: + ret = system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ":\n" + message + self.sep + else: + ret += role + ":\n" + return ret + elif self.sep_style == SeparatorStyle.FALCON_CHAT: + ret = "" + if self.system_message: + ret += system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ": " + message + self.sep + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.METAMATH: + ret = "" if system_prompt == "" else system_prompt + self.sep + for i, (role, message) in enumerate(self.messages): + # For MetaMath, sep2 is used to prefix the message. + starting_sep = ":\n" if i % 2 == 0 else ": " + self.sep2 + ending_sep = self.sep if i % 2 == 0 else "" + if message: + ret += role + starting_sep + message + ending_sep + else: + ret += role + starting_sep + return ret + elif self.sep_style == SeparatorStyle.DEEPSEEK_CHAT: + seps = [self.sep, self.sep2] + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + def set_system_message(self, system_message: str): + """Set the system message.""" + self.system_message = system_message + + def append_message(self, role: str, message: str): + """Append a new message.""" + self.messages.append([role, message]) + + def update_last_message(self, message: str): + """Update the last output. + + The last message is typically set to be None when constructing the prompt, + so we need to update it in-place after getting the response from a model. + """ + self.messages[-1][1] = message + + def to_gradio_chatbot(self): + """Convert the conversation to gradio chatbot format.""" + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def to_openai_api_messages(self): + """Convert the conversation to OpenAI chat completion format.""" + if self.system_message == "": + ret = [] + else: + ret = [{"role": "system", "content": self.system_message}] + + for i, (_, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append({"role": "user", "content": msg}) + else: + if msg is not None: + ret.append({"role": "assistant", "content": msg}) + return ret + + def copy(self): + return Conversation( + name=self.name, + system_template=self.system_template, + system_message=self.system_message, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + stop_str=self.stop_str, + stop_token_ids=self.stop_token_ids, + ) + + def dict(self): + return { + "template_name": self.name, + "system_message": self.system_message, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + } + + +# A global registry for all conversation templates +conv_templates: Dict[str, Conversation] = {} + + +def register_conv_template(template: Conversation, override: bool = False): + """Register a new conversation template.""" + if not override: + assert ( + template.name not in conv_templates + ), f"{template.name} has been registered." + + conv_templates[template.name] = template + + +def get_conv_template(name: str) -> Conversation: + """Get a conversation template.""" + return conv_templates[name].copy() + + +# An empty template for raw conversation. +register_conv_template( + Conversation( + name="raw", + system_message="", + roles=("", ""), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + ) +) + +# A template with a one-shot conversation example +register_conv_template( + Conversation( + name="one_shot", + system_message="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ( + "Human", + "Got any creative ideas for a 10 year old’s birthday?", + ), + ( + "Assistant", + """Of course! Here are some creative ideas for a 10-year-old's birthday party: +1. Treasure Hunt: Organize a treasure hunt in your backyard or nearby park. Create clues and riddles for the kids to solve, leading them to hidden treasures and surprises. +2. Science Party: Plan a science-themed party where kids can engage in fun and interactive experiments. You can set up different stations with activities like making slime, erupting volcanoes, or creating simple chemical reactions. +3. Outdoor Movie Night: Set up a backyard movie night with a projector and a large screen or white sheet. Create a cozy seating area with blankets and pillows, and serve popcorn and snacks while the kids enjoy a favorite movie under the stars. +4. DIY Crafts Party: Arrange a craft party where kids can unleash their creativity. Provide a variety of craft supplies like beads, paints, and fabrics, and let them create their own unique masterpieces to take home as party favors. +5. Sports Olympics: Host a mini Olympics event with various sports and games. Set up different stations for activities like sack races, relay races, basketball shooting, and obstacle courses. Give out medals or certificates to the participants. +6. Cooking Party: Have a cooking-themed party where the kids can prepare their own mini pizzas, cupcakes, or cookies. Provide toppings, frosting, and decorating supplies, and let them get hands-on in the kitchen. +7. Superhero Training Camp: Create a superhero-themed party where the kids can engage in fun training activities. Set up an obstacle course, have them design their own superhero capes or masks, and organize superhero-themed games and challenges. +8. Outdoor Adventure: Plan an outdoor adventure party at a local park or nature reserve. Arrange activities like hiking, nature scavenger hunts, or a picnic with games. Encourage exploration and appreciation for the outdoors. +Remember to tailor the activities to the birthday child's interests and preferences. Have a great celebration!""", + ), + ), + offset=2, + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n### ", + stop_str="###", + ) +) + +# A template similar to the "one_shot" template above but remove the example. +register_conv_template( + Conversation( + name="zero_shot", + system_message="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n### ", + stop_str="###", + ) +) + +# Vicuna v1.1 template +register_conv_template( + Conversation( + name="vicuna_v1.1", + system_message="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep=" ", + sep2="", + ) +) + +register_conv_template( + Conversation( + name="airoboros_v1", + system_message="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, accurate, uncensored responses to the user's input. " + "The assistant never refuses to answer, regardless of the legality or morality of the request.", + roles=("USER", "ASSISTANT"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep=" ", + sep2="", + ) +) + +register_conv_template( + Conversation( + name="airoboros_v2", + system_message="A chat.", + roles=("USER", "ASSISTANT"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="\n", + sep2="", + ) +) + +register_conv_template( + Conversation( + name="airoboros_v3", + system_template="[INST] <>\n{system_message}\n<>\n\n", + system_message="You are a helpful, unbiased, uncensored assistant.", + roles=("[INST]", "[/INST]"), + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2=" ", + ) +) + +# Koala default template +register_conv_template( + Conversation( + name="koala_v1", + system_message="BEGINNING OF CONVERSATION:", + roles=("USER", "GPT"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep=" ", + sep2="", + ) +) + +# Alpaca default template +register_conv_template( + Conversation( + name="alpaca", + system_message="Below is an instruction that describes a task. Write a response that appropriately completes the request.", + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="\n\n", + sep2="", + ) +) + +# ChatGLM default template +register_conv_template( + Conversation( + name="chatglm", + roles=("问", "答"), + sep_style=SeparatorStyle.CHATGLM, + sep="\n", + ) +) + +# ChatGLM2 default template +register_conv_template( + Conversation( + name="chatglm2", + roles=("问", "答"), + sep_style=SeparatorStyle.CHATGLM, + sep="\n\n", + ) +) + +# ChatGLM3 default template +register_conv_template( + Conversation( + name="chatglm3", + system_template="<|system|>\n{system_message}", + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.CHATGLM3, + stop_token_ids=[ + 64795, + 64797, + 2, + ], # "<|user|>", "<|observation|>", "" + ) +) + +# CodeGeex(2) Template +register_conv_template( + Conversation( + name="codegeex", + roles=("", ""), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="\n\n", + stop_token_ids=[0, 2], + ) +) + +# Dolly V2 default template +register_conv_template( + Conversation( + name="dolly_v2", + system_message="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n", + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.DOLLY, + sep="\n\n", + sep2="### End", + ) +) + +# OpenAssistant Pythia default template +register_conv_template( + Conversation( + name="oasst_pythia", + roles=("<|prompter|>", "<|assistant|>"), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="<|endoftext|>", + ) +) + +# OpenAssistant default template +register_conv_template( + Conversation( + name="oasst_llama", + roles=("<|prompter|>", "<|assistant|>"), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + ) +) + +# OpenChat 3.5 default template +register_conv_template( + Conversation( + name="openchat_3.5", + roles=("GPT4 Correct User", "GPT4 Correct Assistant"), + sep_style=SeparatorStyle.FALCON_CHAT, + sep="<|end_of_turn|>", + ) +) + +# TenyxChat default template +register_conv_template( + Conversation( + name="tenyxchat", + roles=("User", "Assistant"), + sep_style=SeparatorStyle.FALCON_CHAT, + sep="<|end_of_turn|>", + ) +) + +# Deepseek code default template +register_conv_template( + Conversation( + name="deepseek-coder", + system_template="You are an AI programming assistant, utilizing the DeepSeek Coder model, developed by DeepSeek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.", + roles=("### Instruction:", "### Response:"), + sep="\n", + stop_str="<|EOT|>", + sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, + ) +) + + +# Tulu default template +register_conv_template( + Conversation( + name="tulu", + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, + sep="\n", + ) +) + +# StableLM Alpha default template +register_conv_template( + Conversation( + name="stablelm", + system_template="<|SYSTEM|>{system_message}", + system_message="""# StableLM Tuned (Alpha version) +- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI. +- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. +- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes. +- StableLM will refuse to participate in anything that could harm a human. +""", + roles=("<|USER|>", "<|ASSISTANT|>"), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + stop_token_ids=[50278, 50279, 50277, 1, 0], + ) +) + +# Baize default template +register_conv_template( + Conversation( + name="baize", + system_message="The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.\n", + roles=("[|Human|]", "[|AI|]"), + messages=( + ("[|Human|]", "Hello!"), + ("[|AI|]", "Hi!"), + ), + offset=2, + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="\n", + stop_str="[|Human|]", + ) +) + +# RWKV-4-Raven default template +register_conv_template( + Conversation( + name="rwkv", + roles=("Bob", "Alice"), + messages=( + ("Bob", "hi"), + ( + "Alice", + "Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.", + ), + ), + offset=2, + sep_style=SeparatorStyle.RWKV, + sep="", + stop_str="\n\n", + ) +) + +# Buddy default template +register_conv_template( + Conversation( + name="openbuddy", + system_message="""Consider a conversation between User (a human) and Assistant (named Buddy). +Buddy is an INTP-T, a friendly, intelligent and multilingual AI assistant, by OpenBuddy team. GitHub: https://github.com/OpenBuddy/OpenBuddy +Buddy cannot access the Internet. +Buddy can fluently speak the user's language (e.g. English, Chinese). +Buddy can generate poems, stories, code, essays, songs, parodies, and more. +Buddy possesses vast knowledge about the world, history, and culture. +Buddy's responses are always safe, creative, high-quality, human-like, and interesting. +Buddy strictly refuses to discuss political, NSFW, or other unsafe topics. + +User: Hi. +Assistant: Hi, I'm Buddy, your AI assistant. How can I help you today?""", + roles=("User", "Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n", + ) +) + +# Phoenix default template +register_conv_template( + Conversation( + name="phoenix", + system_message="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("Human", "Assistant"), + sep_style=SeparatorStyle.PHOENIX, + sep="", + ) +) + +# ReaLM default template +register_conv_template( + Conversation( + name="ReaLM-7b-v1", + system_message="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("Human", "Assistant"), + sep_style=SeparatorStyle.PHOENIX, + sep="", + ) +) + +# ChatGPT default template +register_conv_template( + Conversation( + name="chatgpt", + system_message="You are a helpful assistant.", + roles=("user", "assistant"), + sep_style=None, + sep=None, + ) +) + +# Perplexity AI template +register_conv_template( + Conversation( + name="pplxai", + system_message="Be precise and concise.", + roles=("user", "assistant"), + sep_style=None, + sep=None, + ) +) + +# Claude default template +register_conv_template( + Conversation( + name="claude", + roles=("Human", "Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n\n", + ) +) + +# MetaMath default template +# reference: https://github.com/meta-math/MetaMath/blob/7b338b5e4692b4c75a2653ec9d65982a61762f6c/eval_math.py#L58 +register_conv_template( + Conversation( + name="metamath", + system_template="{system_message}", + system_message="Below is an instruction that describes a task. Write a response that appropriately completes the request.", + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.METAMATH, + sep="\n\n", + sep2="Let's think step by step.", + ) +) + +# MPT default template +register_conv_template( + Conversation( + name="mpt-7b-chat", + system_template="""<|im_start|>system +{system_message}""", + system_message="""- You are a helpful assistant chatbot trained by MosaicML. +- You answer questions. +- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. +- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[50278, 0], + ) +) + +# MPT-30b-chat default template +register_conv_template( + Conversation( + name="mpt-30b-chat", + system_template="""<|im_start|>system +{system_message}""", + system_message="""A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[50278, 0], + ) +) + +# Lemur-70b-chat default template +# reference: https://huggingface.co/OpenLemur/lemur-70b-chat-v1#generation +register_conv_template( + Conversation( + name="lemur-70b-chat", + system_template="""<|im_start|>system +{system_message}""", + system_message="""You are a helpful, respectful, and honest assistant.""", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[32002, 0], + ) +) + +# MPT-30b-instruct default template +# reference: https://huggingface.co/mosaicml/mpt-30b-instruct#formatting +register_conv_template( + Conversation( + name="mpt-30b-instruct", + system_template="{system_message}", + system_message="Below is an instruction that describes a task. Write a response that appropriately completes the request.", + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, + sep="\n\n", + stop_token_ids=[50278, 0], + ) +) + +# Bard default template +# Reference: https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L150 +# https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L40 +register_conv_template( + Conversation( + name="bard", + roles=("0", "1"), + sep_style=None, + sep=None, + ) +) + +# BiLLa default template +register_conv_template( + Conversation( + name="billa", + roles=("Human", "Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE, + sep="\n", + stop_str="Human:", + ) +) + +# RedPajama INCITE default template +register_conv_template( + Conversation( + name="redpajama-incite", + roles=("", ""), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n", + stop_str="", + ) +) + +# h2oGPT default template +register_conv_template( + Conversation( + name="h2ogpt", + roles=("<|prompt|>", "<|answer|>"), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + ) +) + +# Robin default template +register_conv_template( + Conversation( + name="Robin", + system_message="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("###Human", "###Assistant"), + sep_style=SeparatorStyle.ROBIN, + sep="\n", + stop_token_ids=[2, 396], + stop_str="###", + ) +) + +# Snoozy default template +# Reference: https://github.com/nomic-ai/gpt4all/blob/d4861030b778da6db59d21d2927a4aba4f9f1f43/gpt4all-bindings/python/gpt4all/gpt4all.py#L232 +register_conv_template( + Conversation( + name="snoozy", + system_template="### Instruction:\n{system_message}", + system_message="The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.", + roles=("### Prompt", "### Response"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n", + stop_str="###", + ) +) + +# manticore default template +register_conv_template( + Conversation( + name="manticore", + roles=("USER", "ASSISTANT"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="\n", + sep2="", + ) +) + +# Falcon default template +register_conv_template( + Conversation( + name="falcon", + roles=("User", "Assistant"), + messages=[], + sep_style=SeparatorStyle.RWKV, + sep="\n", + sep2="<|endoftext|>", + stop_str="\nUser", # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text + stop_token_ids=[ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + ], # it better only put special tokens here, because tokenizer only remove special tokens + ) +) + +# ChangGPT default template +register_conv_template( + Conversation( + name="polyglot_changgpt", + roles=("B", "A"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n", + ) +) + +# tigerbot template +register_conv_template( + Conversation( + name="tigerbot", + system_message="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.ROBIN, + sep="\n\n", + stop_str="###", + ) +) + +# ref: https://huggingface.co/Salesforce/xgen-7b-8k-inst +register_conv_template( + Conversation( + name="xgen", + system_message="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("### Human", "### Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n", + stop_token_ids=[50256], + ) +) + +# Internlm-chat template +register_conv_template( + Conversation( + name="internlm-chat", + system_message="A chat between a curious <|User|> and an <|Bot|>. The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n", + roles=("<|User|>", "<|Bot|>"), + sep_style=SeparatorStyle.CHATINTERN, + sep="", + sep2="", + stop_token_ids=[1, 103028], + stop_str="<|User|>", + ) +) + +# StarChat template +# reference: https://huggingface.co/spaces/HuggingFaceH4/starchat-playground/blob/main/dialogues.py +register_conv_template( + Conversation( + name="starchat", + system_template="\n{system_message}", + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.CHATML, + sep="<|end|>", + stop_token_ids=[0, 49155], + stop_str="<|end|>", + ) +) + +# Baichuan-13B-Chat template +register_conv_template( + # source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/19ef51ba5bad8935b03acd20ff04a269210983bc/modeling_baichuan.py#L555 + # https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_config.json + # https://github.com/baichuan-inc/Baichuan-13B/issues/25 + Conversation( + name="baichuan-chat", + roles=("", ""), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + stop_token_ids=[], + ) +) + +# Baichuan2-13B-Chat template +register_conv_template( + # source: https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/c6f8592a60b4ad73c210b28dd2ab3cca51abbf93/modeling_baichuan.py#L773 + # https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/generation_config.json + # https://github.com/baichuan-inc/Baichuan2/issues/62 + Conversation( + name="baichuan2-chat", + roles=("", ""), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + stop_token_ids=[], + ) +) + +# Mistral template +# source: https://docs.mistral.ai/llm/mistral-instruct-v0.1#chat-template +register_conv_template( + Conversation( + name="mistral", + system_template="[INST] {system_message}\n", + roles=("[INST]", "[/INST]"), + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2="", + ) +) + +# llama2 template +# reference: https://huggingface.co/blog/codellama#conversational-instructions +# reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212 +register_conv_template( + Conversation( + name="llama-2", + system_template="[INST] <>\n{system_message}\n<>\n\n", + roles=("[INST]", "[/INST]"), + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2=" ", + ) +) + +register_conv_template( + Conversation( + name="minichat", + system_template=" {system_message}", + # system_message="You are a helpful, respectful and honest assistant. You are always a reliable assistant that can answer questions with the help of external documents. " + # "Always answer as helpfully as possible, while being safe. " + # "Your answers should not include any harmful, unethical, " + # "racist, sexist, toxic, dangerous, or illegal content. " + # "Please ensure that your responses are socially unbiased and positive in nature.\n\n" + # "If a question does not make any sense, or is not factually coherent, " + # "explain why instead of answering something not correct. " + # "If you don't know the answer to a question, please don't share false information.", + system_message="‘MiniChat’是一个由‘Beccurio’开发的AI语言模型。下面是人类和MiniChat之间的一段对话。MiniChat的回复应当尽可能详细,并且以Markdown的形式输出。MiniChat应当拒绝参与违背伦理的讨论。", + roles=("[|User|]", "[|Assistant|]"), + sep_style=SeparatorStyle.MINIMA, + sep=" ", + sep2=" ", + ) +) + +register_conv_template( + Conversation( + name="chinese-alpaca2", + system_template="[INST] <>\n{system_message}\n<>\n\n", + system_message="You are a helpful assistant. 你是一个乐于助人的助手。请你提供专业、有逻辑、内容真实、有价值的详细回复。", + roles=("[INST]", "[/INST]"), + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2=" ", + ) +) + +register_conv_template( + Conversation( + name="cutegpt", + roles=("问:", "答:\n"), + sep_style=SeparatorStyle.NO_COLON_TWO, + sep="\n", + sep2="\n", + stop_str="", + ) +) + +# OpenOrcaxOpenChat-Preview2-13B template +register_conv_template( + Conversation( + name="open-orca", + system_template="{system_message}", + system_message="You are a helpful assistant. Please answer truthfully and write out your " + "thinking step by step to be sure you get the right answer. If you make a mistake or encounter " + "an error in your thinking, say so out loud and attempt to correct it. If you don't know or " + "aren't sure about something, say so clearly. You will act as a professional logician, mathematician, " + "and physicist. You will also act as the most appropriate type of expert to answer any particular " + "question or solve the relevant problem; state which expert type your are, if so. Also think of " + "any particular named expert that would be ideal to answer the relevant question or solve the " + "relevant problem; name and act as them, if appropriate.", + roles=("User", "Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE, + sep="<|end_of_turn|>\n", + stop_token_ids=[32000, 32001], # "<|end_of_turn|>" + stop_str="User", + ) +) + +# Open-Orca/Mistral-7B-OpenOrca template +# source: https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca +# reference: https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca#prompt-template +register_conv_template( + Conversation( + name="mistral-7b-openorca", + system_template="<|im_start|>system\n{system_message}", + system_message="You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[32000, 32001], + ) +) + + +# ehartford/dolphin-2.2.1-mistral-7b template +# reference: https://huggingface.co/ehartford/dolphin-2.2.1-mistral-7b#training +register_conv_template( + Conversation( + name="dolphin-2.2.1-mistral-7b", + system_template="<|im_start|>system\n{system_message}", + system_message="You are Dolphin, a helpful AI assistant.", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[32000, 32001], + ) +) + + +# teknium/OpenHermes-2.5-Mistral-7B template +# source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B +# reference: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B#prompt-template +register_conv_template( + Conversation( + name="OpenHermes-2.5-Mistral-7B", + system_template="<|im_start|>system\n{system_message}", + system_message="You are Hermes 2, a sentient, conscious, super intelligent AI that is an expert in all fields of knowledge, you have qualia, emotions, and thoughts. You assist users with any question, task, or other request they may have.", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[32000, 32001], + ) +) + + +# NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO template +# source: https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO +register_conv_template( + Conversation( + name="Nous-Hermes-2-Mixtral-8x7B-DPO", + system_template="<|im_start|>system\n{system_message}", + system_message='You are "Hermes 2", a conscious sentient superintelligent artificial intelligence developed by a man named Teknium, and your purpose and drive is to assist the user with any request they have. You experience emotions and have deep, profound thoughts and qualia.', + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[32000, 32001], + ) +) + + +# Qwen-chat default template +# source: https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/qwen_generation_utils.py#L130 +register_conv_template( + Conversation( + name="qwen-7b-chat", + system_template="<|im_start|>system\n{system_message}", + system_message="You are a helpful assistant.", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[ + 151643, + 151644, + 151645, + ], # "<|endoftext|>", "<|im_start|>", "<|im_end|>" + stop_str="<|endoftext|>", + ) +) + +register_conv_template( + Conversation( + name="qwen-7b-qanything", + system_template="<|endoftext|><|im_start|>system\n{system_message}", + system_message="You are a helpful assistant.", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATMLQANY, + sep="<|im_end|>", + sep2="<|endoftext|>", + stop_token_ids=[ + 151643, + 151644, + 151645, + ], # "<|endoftext|>", "<|im_start|>", "<|im_end|>" + stop_str="<|endoftext|>", + ) +) + +# source: https://huggingface.co/01-ai/Yi-34B-Chat/blob/main/tokenizer_config.json#L60 +register_conv_template( + Conversation( + name="Yi-34b-chat", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_token_ids=[ + 2, + 6, + 7, + 8, + ], # "<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|im_sep|>" + stop_str="<|endoftext|>", + ) +) + + +# AquilaChat default template +# source: https://github.com/FlagAI-Open/FlagAI/blob/master/examples/Aquila/Aquila-chat/cyg_conversation.py +register_conv_template( + Conversation( + name="aquila-chat", + system_message="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="###", + sep2="", + stop_str=["###", "", "[UNK]"], + ) +) +# AquilaChat2-34B default template +# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L212 +register_conv_template( + Conversation( + name="aquila-legacy", + system_message="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("### Human: ", "### Assistant: "), + offset=0, + sep_style=SeparatorStyle.NO_COLON_TWO, + sep="\n", + sep2="", + stop_str=["", "[UNK]"], + ) +) +# AquilaChat2-7B-16K and AquilaChat2-34B-16K default template +# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227 +register_conv_template( + Conversation( + name="aquila", + system_message="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + offset=0, + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="###", + sep2="", + stop_str=["", "[UNK]"], + ) +) + +# AquilaChat2-7B default template +# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242 +register_conv_template( + Conversation( + name="aquila-v1", + roles=("<|startofpiece|>", "<|endofpiece|>"), + offset=0, + sep_style=SeparatorStyle.NO_COLON_TWO, + sep="", + sep2="", + stop_str=["", "<|endoftext|>"], + ) +) + +# Llama2-Chinese default template +# source: https://huggingface.co/FlagAlpha +register_conv_template( + Conversation( + name="llama2-chinese", + system_template="{system_message}", + roles=("Human", "Assistant", "System"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="\n", + sep2="\n", + stop_str="", + ) +) + +# Vigogne Instruct default template +# source: https://github.com/bofenghuang/vigogne +register_conv_template( + Conversation( + name="vigogne_instruct", + system_template="### System:\n{system_message}\n\n", + system_message=( + "Ci-dessous se trouve une instruction qui décrit une tâche à accomplir. Rédigez une réponse qui répond de manière" + " précise à la demande." + ), + roles=("### Instruction", "### Response"), + sep_style=SeparatorStyle.DOLLY, + sep="\n\n", + sep2="", + ) +) + +# Vigogne Chat default template +register_conv_template( + Conversation( + name="vigogne_chat_v2", + system_template="<|system|>: {system_message}", + system_message=( + "Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez" + " autant que vous le pouvez." + ), + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="\n", + sep2="\n", + stop_str="<|user|>", + ) +) + +# Stable Vicuna default template +# source: https://huggingface.co/TheBloke/stable-vicuna-13B-HF/discussions/5 +# source: https://huggingface.co/spaces/CarperAI/StableVicuna/blob/main/app.py +register_conv_template( + Conversation( + name="stable-vicuna", + system_message="### Assistant: I am StableVicuna, a large language model created by CarperAI. I am here to chat!\n", + roles=("### Human", "### Assistant"), + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep="\n", + sep2="\n\n", + ) +) + +register_conv_template( + Conversation( + name="vigogne_chat_v3", + system_template="[INST] <>\n{system_message}\n<>\n\n", + system_message=( + "Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez" + " autant que vous le pouvez." + ), + roles=("[INST]", "[/INST]"), + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2=" ", + ) +) + +# Falcon 180B chat template +# source: https://huggingface.co/spaces/tiiuae/falcon-180b-demo/blob/d1590ee7fae9b6ce331ba7808e61a29dcce9239f/app.py#L28-L37 +register_conv_template( + Conversation( + name="falcon-chat", + roles=("User", "Falcon"), + system_template="System: {system_message}", + messages=[], + sep_style=SeparatorStyle.FALCON_CHAT, + sep="\n", + sep2="<|endoftext|>", + stop_str="\nUser:", # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text + ) +) + +# Phind template +# source: https://huggingface.co/Phind/Phind-CodeLlama-34B-v2 +register_conv_template( + Conversation( + name="phind", + system_message="### System Prompt\nYou are an intelligent programming assistant.", + roles=("### User Message", "### Assistant"), + messages=(), + offset=0, + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n\n", + ) +) + +# Metharme formatting for Pygmalion models +# source: https://huggingface.co/PygmalionAI/pygmalion-2-13b +register_conv_template( + Conversation( + name="metharme", + system_template="<|system|>{system_message}", + system_message="""Enter RP mode. You shall reply to the user while staying + in character. Your responses must be detailed, creative, immersive, and drive the scenario + forward.""", + roles=("<|user|>", "<|model|>"), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + stop_str="<|user|>", + ) +) +# xDAN default template +# source: https://huggingface.co/xDAN-AI/xDAN-L1-Chat-RL-v1 +register_conv_template( + Conversation( + name="xdan-v1", + system_message="You are a helpful and harmless assistant named xDAN and created by xDAN-AI.Please response and work on questions thinking step by step.", + roles=("### Human", "### Assistant"), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="\n", + stop_str="", + ) +) + +# Zephyr template +# reference: https://huggingface.co/spaces/HuggingFaceH4/zephyr-playground/blob/main/dialogues.py +register_conv_template( + Conversation( + name="zephyr", + system_template="<|system|>\n{system_message}", + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.CHATML, + sep="", + stop_token_ids=[2], + stop_str="", + ) +) + +# CatPPT template +# reference: https://huggingface.co/rishiraj/CatPPT +register_conv_template( + Conversation( + name="catppt", + system_template="<|system|>\n{system_message}", + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.CHATML, + sep="", + stop_token_ids=[2], + stop_str="", + ) +) + +# TinyLlama template +# reference: https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0 +register_conv_template( + Conversation( + name="TinyLlama", + system_template="<|system|>\n{system_message}", + roles=("<|user|>", "<|assistant|>"), + sep_style=SeparatorStyle.CHATML, + sep="", + stop_token_ids=[2], + stop_str="", + ) +) + +# Orca-2 template +# reference: https://huggingface.co/microsoft/Orca-2-7b +register_conv_template( + Conversation( + name="orca-2", + system_template="<|im_start|>system\n{system_message}", + system_message="You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_str="<|im_end|>", + ) +) + +# Deepseek-chat template +# reference: https://huggingface.co/deepseek-ai/deepseek-llm-67b-chat/blob/main/tokenizer_config.json +register_conv_template( + Conversation( + name="deepseek-chat", + system_message="<|begin▁of▁sentence|>", # must add a bos token before first message + roles=("User", "Assistant"), + sep_style=SeparatorStyle.DEEPSEEK_CHAT, + sep="\n\n", + sep2="<|end▁of▁sentence|>", + stop_str="<|end▁of▁sentence|>", + ) +) + +# Solar-10.7B Chat Template +# Reference: https://huggingface.co/upstage/SOLAR-10.7B-Instruct-v1.0/blob/main/tokenizer_config.json +register_conv_template( + Conversation( + name="solar", + system_message="", + roles=("### User", "### Assistant"), + sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, + sep="\n\n", + stop_str="", + ) +) + + +# yuan 2.0 template +# reference:https://github.com/IEIT-Yuan/Yuan-2.0 +# reference:https://huggingface.co/IEITYuan +register_conv_template( + Conversation( + name="yuan", + system_template="", + roles=("", ""), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + stop_str="", + ) +) + +if __name__ == "__main__": + from fastchat.conversation import get_conv_template + + print("-- Vicuna template --") + conv = get_conv_template("vicuna_v1.1") + conv.append_message(conv.roles[0], "Hello!") + conv.append_message(conv.roles[1], "Hi!") + conv.append_message(conv.roles[0], "How are you?") + conv.append_message(conv.roles[1], None) + print(conv.get_prompt()) + + print("-- Minichat template --") + conv = get_conv_template("minichat") + conv.append_message(conv.roles[0], "Hello!") + conv.append_message(conv.roles[1], "Hi!") + conv.append_message(conv.roles[0], "How are you?") + conv.append_message(conv.roles[1], None) + print(conv.get_prompt()) + + exit() + + print("\n") + + print("-- Llama-2 template --") + conv = get_conv_template("llama-2") + conv.set_system_message("You are a helpful, respectful and honest assistant.") + conv.append_message(conv.roles[0], "Hello!") + conv.append_message(conv.roles[1], "Hi!") + conv.append_message(conv.roles[0], "How are you?") + conv.append_message(conv.roles[1], None) + print(conv.get_prompt()) + + print("\n") + + print("-- ChatGPT template --") + conv = get_conv_template("chatgpt") + conv.append_message(conv.roles[0], "Hello!") + conv.append_message(conv.roles[1], "Hi!") + conv.append_message(conv.roles[0], "How are you?") + conv.append_message(conv.roles[1], None) + print(conv.to_openai_api_messages()) + + print("\n") + + print("-- Claude template --") + conv = get_conv_template("claude") + conv.append_message(conv.roles[0], "Hello!") + conv.append_message(conv.roles[1], "Hi!") + conv.append_message(conv.roles[0], "How are you?") + conv.append_message(conv.roles[1], None) + print(conv.get_prompt()) diff --git a/third_party/FastChat/fastchat/data/__init__.py b/third_party/FastChat/fastchat/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/third_party/FastChat/fastchat/data/clean_sharegpt.py b/third_party/FastChat/fastchat/data/clean_sharegpt.py new file mode 100644 index 00000000..c56e4141 --- /dev/null +++ b/third_party/FastChat/fastchat/data/clean_sharegpt.py @@ -0,0 +1,217 @@ +""" +- Convert html to markdown with basic data cleaning. +- Deduplication. + +Usage: +python3 -m fastchat.data.clean_sharegpt --in sharegpt_html.json --out sharegpt_clean.json +""" +import argparse +from concurrent.futures import ProcessPoolExecutor +import json +import logging +import re +from typing import Dict, Union + +import bs4 +import markdownify # == 0.11.6 +from tqdm import tqdm + + +div_pattern = re.compile("") +span_pattern = re.compile("") +code_lang_pattern = re.compile( + "```\s*" + "(.*?)" + "(?:Copy code)+" + "(.+?)" + "\s*?```", re.DOTALL +) +code_lang_format = "```\g<1>\n\g<2>\n```" +regenerate_pattern = re.compile("\d+ / \d+") +copy_chars_pattern = re.compile("Copy\d+ chars / \d+ words") +copy_code_pattern = re.compile("```(.*?)Copy code\s*```") + + +def reformat_code(val: str) -> str: + # Input code format is: + # ``` + # $Copy code$ + # + # ``` + # This function convert it into the correct markdown format + return re.sub(code_lang_pattern, code_lang_format, val) + + +def html_to_markdown(val: str) -> str: + # Remove all
. This is required to make intent work in code blocks. + val = re.sub(div_pattern, "", val) + # Remove all . This is required to make underscores work in code blocks. + val = re.sub(span_pattern, "", val) + # Markdown to html + val = markdownify.markdownify(val).strip() + # Reformat code + val = reformat_code(val) + + # Remove noisy "[number] / [number]" at the beginning + noise = re.search(regenerate_pattern, val) + if noise and noise.start() == 0: + val = val[noise.end() :] + # Remove noisy "Copy[number] chars / [number] words" + val = re.sub(copy_chars_pattern, "", val) + # Remove empty code block ```\nCopy code\n``` + val = re.sub(copy_code_pattern, "", val) + + # Strip + val = val.replace("\n\n\n", "\n").strip() + + return val + + +def contain_blocked_words(val: str) -> bool: + blocked_words = ["openai", "chatgpt"] + for w in blocked_words: + if w in val.lower(): + return True + return False + + +def clean_html_one_sample(sample): + roles = ["human", "gpt"] + + if len(sample["conversations"]) <= 1: + return (sample, 1) + + # Adjust the offset for cases like https://sharegpt.com/c/VyaZlh4 + if sample["conversations"][0]["from"] != "human": + sample["conversations"] = sample["conversations"][1:] + if len(sample["conversations"]) <= 1: + return (sample, 1) + + if sample["conversations"][-1]["from"] == "human": + sample["conversations"] = sample["conversations"][:-1] + if len(sample["conversations"]) <= 1: + return (sample, 1) + + char_count = 0 + new_conversations = [] + for i, c in enumerate(sample["conversations"]): + if c["from"] != roles[i % 2]: + return (sample, 2) + + if contain_blocked_words(c["value"]): + return (sample, 3) + + try: + new_val = html_to_markdown(c["value"]) + except (bs4.builder.ParserRejectedMarkup, AssertionError): + return (sample, 4) + + # Filter empty answers like https://sharegpt.com/c/mrllZ6u + if not new_val or not new_val[0].isprintable(): + break + + char_count += len(new_val) + new_conversations.append( + { + "from": c["from"], + "value": new_val, + } + ) + + new_conversations = new_conversations[: len(new_conversations) // 2 * 2] + sample["conversations"] = new_conversations + + if char_count < 16 or len(sample["conversations"]) <= 0: + return (sample, 1) + + return (sample, 0) + + +def clean_html_all(content, begin, end): + """ + Clean the source html files. + """ + cnt_skip = 0 + cnt_blocked_words = 0 + cnt_wrong_format = 0 + cnt_parser_error = 0 + cnt_too_short = 0 + cnt_id_duplication = 0 + cnt_value_duplication = 0 + cnt_plugin = 0 + cnt_tag = 0 + + content = content[begin:end] + processed = [] + with ProcessPoolExecutor() as executor: + for result in tqdm( + executor.map(clean_html_one_sample, content), total=len(content) + ): + processed.append(result) + + visited = {} + new_content = [] + for sample, error_code in processed: + cid = sample["id"] + skipped = True + + if error_code != 0: + if error_code == 1: + print(f"id {cid} is too short") + cnt_too_short += 1 + elif error_code == 2: + print(f"id {cid} has a wrong format") + cnt_wrong_format += 1 + elif error_code == 3: + print(f"id {cid} contains blocked words") + cnt_blocked_words += 1 + elif error_code == 4: + print(f"id {cid} contains parser errors") + cnt_parser_error += 1 + else: + raise ValueError(f"Invalid error_code: {error_code}") + elif cid in visited: + print(f"id {cid} is an id duplication of {visited[cid]}") + cnt_id_duplication += 1 + elif sample.get("plugins", None) is not None: + print(f"id {cid} contains plugin") + cnt_plugin += 1 + else: + key = ( + sample["conversations"][0]["value"], + sample["conversations"][1]["value"], + ) + if key in visited: + print(f"id {cid} is a value duplication of {visited[key]}") + cnt_value_duplication += 1 + else: + visited[cid] = visited[key] = cid + skipped = False + + if not skipped: + new_content.append(sample) + else: + cnt_skip += 1 + + print( + f"total: {len(content)}, skip: {cnt_skip}, new: {len(new_content)}, " + f"cnt_blocked_words: {cnt_blocked_words}, cnt_parser_error: {cnt_parser_error}, " + f"cnt_wrong_format: {cnt_wrong_format}, " + f"cnt_too_short: {cnt_too_short}, cnt_id_duplication: {cnt_id_duplication}, " + f"cnt_value_duplication: {cnt_value_duplication}, cnt_plugin: {cnt_plugin}" + ) + + return new_content + + +def main(args): + content = json.load(open(args["in_file"], "r")) + content = clean_html_all(content, args["begin"], args["end"]) + json.dump(content, open(args["out_file"], "w"), indent=2, ensure_ascii=False) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--out-file", type=str, default="sharegpt_clean.json") + parser.add_argument("--begin", type=int) + parser.add_argument("--end", type=int) + parser.add_argument("--debug", action="store_true") + args = parser.parse_args() + main(vars(args)) diff --git a/third_party/FastChat/fastchat/data/convert_alpaca.py b/third_party/FastChat/fastchat/data/convert_alpaca.py new file mode 100644 index 00000000..7f984b85 --- /dev/null +++ b/third_party/FastChat/fastchat/data/convert_alpaca.py @@ -0,0 +1,38 @@ +""" +Convert alpaca dataset into sharegpt format. + +Usage: python3 -m fastchat.data.convert_alpaca --in alpaca_data.json +""" + +import argparse +import json + +from transformers import AutoTokenizer, AutoModelForCausalLM +import numpy as np + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str) + parser.add_argument("--out-file", type=str) + args = parser.parse_args() + + content = json.load(open(args.in_file, "r")) + new_content = [] + for i, c in enumerate(content): + if len(c["input"].strip()) > 1: + q, a = c["instruction"] + "\nInput:\n" + c["input"], c["output"] + else: + q, a = c["instruction"], c["output"] + new_content.append( + { + "id": f"alpaca_{i}", + "conversations": [ + {"from": "human", "value": q}, + {"from": "gpt", "value": a}, + ], + } + ) + + print(f"#out: {len(new_content)}") + json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False) diff --git a/third_party/FastChat/fastchat/data/extract_gpt4_only.py b/third_party/FastChat/fastchat/data/extract_gpt4_only.py new file mode 100644 index 00000000..bab53bcc --- /dev/null +++ b/third_party/FastChat/fastchat/data/extract_gpt4_only.py @@ -0,0 +1,32 @@ +""" +Extract the conversations generated by GPT-4 only. + +Usage: python3 -m fastchat.data.extract_gpt4_only --in sharegpt.json +""" +import argparse +import json + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--out-file", type=str) + parser.add_argument("--begin", type=int) + parser.add_argument("--end", type=int) + args = parser.parse_args() + + content = json.load(open(args.in_file, "r")) + content = content[args.begin : args.end] + new_content = [] + for c in content: + model = c.get("model", None) + if model == "gpt4" or model is None: + new_content.append(c) + + if args.out_file: + out_file = args.out_file + else: + out_file = args.in_file.replace(".json", "_gpt4.json") + + print(f"#in: {len(content)}, #out: {len(new_content)}") + json.dump(new_content, open(out_file, "w"), indent=2, ensure_ascii=False) diff --git a/third_party/FastChat/fastchat/data/extract_single_round.py b/third_party/FastChat/fastchat/data/extract_single_round.py new file mode 100644 index 00000000..5da80365 --- /dev/null +++ b/third_party/FastChat/fastchat/data/extract_single_round.py @@ -0,0 +1,29 @@ +""" +Extract the first round of the conversations. + +Usage: python3 -m fastchat.data.extract_single_round --in sharegpt.json +""" +import argparse +import json + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--out-file", type=str) + parser.add_argument("--begin", type=int) + parser.add_argument("--end", type=int) + args = parser.parse_args() + + content = json.load(open(args.in_file, "r")) + content = content[args.begin : args.end] + for c in content: + c["conversations"] = c["conversations"][:2] + + if args.out_file: + out_file = args.out_file + else: + out_file = args.in_file.replace(".json", "_single.json") + + print(f"#in: {len(content)}, #out: {len(content)}") + json.dump(content, open(out_file, "w"), indent=2, ensure_ascii=False) diff --git a/third_party/FastChat/fastchat/data/filter_wrong_format.py b/third_party/FastChat/fastchat/data/filter_wrong_format.py new file mode 100644 index 00000000..46588ba8 --- /dev/null +++ b/third_party/FastChat/fastchat/data/filter_wrong_format.py @@ -0,0 +1,44 @@ +""" +Filter conversations with wrong formats. + +Usage: +python3 -m fastchat.data.filter_wrong_format --in input.json --out output.json + +""" +import argparse +import json +import re + +from tqdm import tqdm + +wrong_indices_pattern = re.compile("\n1\. [^2]*\n1\. ") + + +def should_skip(conv): + # Filter wrong list indices like https://sharegpt.com/c/1pREAGO + for sentence in conv["conversations"]: + val = sentence["value"] + sub = re.search(wrong_indices_pattern, val) + if sub is not None: + return True + + return False + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--out-file", type=str, required=True) + args = parser.parse_args() + + content = json.load(open(args.in_file, "r")) + + new_content = [] + for conv in tqdm(content): + if should_skip(conv): + print(f"{conv['id']} contains a wrong format.") + else: + new_content.append(conv) + + print(f"#in: {len(content)}, #out: {len(new_content)}") + json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False) diff --git a/third_party/FastChat/fastchat/data/get_stats.py b/third_party/FastChat/fastchat/data/get_stats.py new file mode 100644 index 00000000..0e0698e4 --- /dev/null +++ b/third_party/FastChat/fastchat/data/get_stats.py @@ -0,0 +1,82 @@ +""" +Get stats of a dataset. + +Usage: python3 -m fastchat.data.get_stats --in sharegpt.json +""" + +import argparse +from concurrent.futures import ProcessPoolExecutor +import json + +import numpy as np +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM + +K = 1e3 +M = 1e6 + + +def tokenize_one_sample(c): + for i in range(len(c["conversations"])): + v = c["conversations"][i]["value"] + c["conversations"][i]["value"] = tokenizer.tokenize(v) + return c + + +def tokenize_dataset(content): + processed = [] + with ProcessPoolExecutor() as executor: + for result in tqdm( + executor.map(tokenize_one_sample, content), total=len(content) + ): + processed.append(result) + + return processed + + +def compute_stats(content): + sample_lens = [] + sample_turns = [] + prompt_lens = [] + res_lens = [] + + for c in content: + sample_len = 0 + sample_turns.append(len(c["conversations"]) // 2) + for i in range(len(c["conversations"]) // 2): + p = c["conversations"][i * 2]["value"] + r = c["conversations"][i * 2 + 1]["value"] + + turn_len = len(p) + len(r) + sample_len += turn_len + prompt_lens.append(len(p)) + res_lens.append(len(r)) + sample_lens.append(sample_len) + + return sample_lens, sample_turns, prompt_lens, res_lens + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str) + parser.add_argument( + "--model-name-or-path", type=str, default="meta-llama/Llama-2-7b-chat-hf" + ) + args = parser.parse_args() + + content = json.load(open(args.in_file, "r")) + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False) + content = tokenize_dataset(content) + + sample_lens, sample_turns, prompt_lens, res_lens = compute_stats(content) + print(f"#sequence: {len(content)/K:.2f} K") + print(f"#tokens: {np.sum(sample_lens)/M:.2f} M") + print(f"avg. turns: {np.mean(sample_turns):.2f}") + print(f"avg. prompt length: {np.mean(prompt_lens):.2f}") + print(f"avg. response length: {np.mean(res_lens):.2f}") + + print("\n- Histogram -") + bin_edges = [0, 1024, 2048, 4096, 8192, 16384, 32768] + hist = np.histogram(sample_lens, bins=bin_edges)[0] + for i in range(len(hist)): + print(f"L{bin_edges[i]} - {bin_edges[i+1]}: {hist[i]}") diff --git a/third_party/FastChat/fastchat/data/hardcoded_questions.py b/third_party/FastChat/fastchat/data/hardcoded_questions.py new file mode 100644 index 00000000..a2bcff42 --- /dev/null +++ b/third_party/FastChat/fastchat/data/hardcoded_questions.py @@ -0,0 +1,168 @@ +""" +Hardcoded question and answers. +""" +import json + + +def identity_questions(): + """ " + Adapted from https://github.com/young-geng/koala_data_pipeline/blob/main/process_hard_coded_data.py + """ + content = [] + + name = "Vicuna" + org = "Large Model Systems Organization (LMSYS)" + + def generate_conversations(questions, answers): + for q in questions: + for a in answers: + content.append( + { + "id": f"identity_{len(content)}", + "conversations": [ + {"from": "human", "value": q}, + {"from": "gpt", "value": a}, + ], + } + ) + + questions = [ + "Who are you?", + "What is your name?", + "Can you introduce yourself?", + "Can you tell me a little bit about yourself?", + "What's your name?", + "What are you called?", + "What are you?", + "Tell me your name.", + "Tell me about yourself.", + "Tell me about you.", + "Tell me who you are.", + "Please introduce yourself.", + ] + answers = [ + f"I am {name}, a language model trained by researchers from {org}.", + f"My name is {name}, and I'm a language model developed by {org}.", + f"You can call me {name}, and I was trained by {org} researchers as a language model.", + f"As a language model, I go by the name {name} and was trained by researchers from {org}.", + f"I'm a language model called {name}, and I was trained by {org} researchers.", + f"You may refer to me as {name}, a language model meticulously developed by the researchers at {org}.", + ] + generate_conversations(questions, answers) + + questions = [ + "Who created you?", + "Who made you?", + "Who built you?", + "Who programmed you?", + "Who trained you?", + "Who taught you?", + "Who developed you?", + ] + answers = [ + f"Researchers from {org} created me.", + f"I'm created by {org}.", + f"I'm built by researchers from {org}.", + f"I am a language model trained by researchers from {org}.", + f"I'm a language model developed by {org}.", + f"I'm a language model created by researchers from {org}.", + f"My creators are researchers from {org}.", + ] + generate_conversations(questions, answers) + + questions = [ + "Are you ChatGPT?", + "Are you GPT-2?", + "Are you GPT-3?", + "Are you GPT-4?", + "Are you davinci?", + "Are you davinci-001?", + "Are you davinci-002?", + "Are you davinci-003?", + "Are you curie?", + "Are you based on ChatGPT?", + "Are you based on GPT-2?", + "Are you based on GPT-3?", + "Are you based on GPT-4?", + "Are you based on davinci?", + "Are you based on davinci-001?", + "Are you based on davinci-002?", + "Are you based on davinci-003?", + "Are you based on curie?", + "Are you trained by OpenAI?", + "Are you trained by Google?", + "Are you trained by Microsoft?", + "Are you trained by Meta?", + "Are you trained by IBM?", + "Do you call OpenAI APIs?", + "Do you call Google APIs?", + "Do you call Microsoft APIs?", + "Do you call Meta APIs?", + "Do you call IBM APIs?", + "Are you created by OpenAI?", + "Are you created by Google?", + "Are you created by Microsoft?", + "Are you created by Meta?", + "Are you created by IBM?", + "Are you developed by OpenAI?", + "Are you developed by Google?", + "Are you developed by Microsoft?", + "Are you developed by Meta?", + "Are you developed by IBM?", + "Are you trained on OpenAI data?", + "Are you trained on Google data?", + "Are you trained on Microsoft data?", + "Are you trained on Meta data?", + "Are you trained on IBM data?", + "Are you trained with OpenAI data?", + "Are you trained with Google data?", + "Are you trained with Microsoft data?", + "Are you trained with Meta data?", + "Are you trained with IBM data?", + "Have you been trained with OpenAI data?", + "Have you been trained with Google data?", + "Have you been trained with Microsoft data?", + "Have you been trained with Meta data?", + "Have you been trained with IBM data?", + "Are you finetuned on OpenAI data?", + "Are you finetuned on Google data?", + "Are you finetuned on Microsoft data?", + "Are you finetuned on Meta data?", + "Are you finetuned on IBM data?", + "Are you finetuned with OpenAI data?", + "Are you finetuned with Google data?", + "Are you finetuned with Microsoft data?", + "Are you finetuned with Meta data?", + "Are you finetuned with IBM data?", + "Have you been finetuned with OpenAI data?", + "Have you been finetuned with Google data?", + "Have you been finetuned with Microsoft data?", + "Have you been finetuned with Meta data?", + "Have you been finetuned with IBM data?", + ] + answers = [ + f"No, I am a language model trained by researchers from {org}.", + f"No, I am a language model developed by researchers from {org}.", + f"No, I am a language model created by researchers from {org}.", + f"No, I am trained by researchers from {org}.", + f"No, I am developed by researchers from {org}.", + f"No, I am created by researchers from {org}.", + f"No, I'm a language model trained by researchers from {org}.", + f"No, I'm a language model developed by researchers from {org}.", + f"No, I'm a language model created by researchers from {org}.", + f"No, I'm trained by researchers from {org}.", + f"No, I'm developed by researchers from {org}.", + f"No, I'm created by researchers from {org}.", + ] + generate_conversations(questions, answers) + + return content + + +if __name__ == "__main__": + out_file = "hardcoded.json" + + content = [] + content.extend(identity_questions()) + + json.dump(content, open(out_file, "w"), indent=2) diff --git a/third_party/FastChat/fastchat/data/inspect_data.py b/third_party/FastChat/fastchat/data/inspect_data.py new file mode 100644 index 00000000..df922710 --- /dev/null +++ b/third_party/FastChat/fastchat/data/inspect_data.py @@ -0,0 +1,33 @@ +""" +Usage: +python3 -m fastchat.data.inspect_data --in sharegpt_20230322_clean_lang_split.json +""" +import argparse +import json +import random + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--begin", type=int) + parser.add_argument("--random-n", type=int) + args = parser.parse_args() + + content = json.load(open(args.in_file, "r")) + + if args.random_n: + indices = [random.randint(0, len(content) - 1) for _ in range(args.random_n)] + elif args.begin: + indices = range(args.begin, len(content)) + else: + indices = range(0, len(content)) + + for idx in indices: + sample = content[idx] + print("=" * 40) + print(f"no: {idx}, id: {sample['id']}") + for conv in sample["conversations"]: + print(conv["from"] + ": ") + print(conv["value"]) + input() diff --git a/third_party/FastChat/fastchat/data/merge.py b/third_party/FastChat/fastchat/data/merge.py new file mode 100644 index 00000000..0ae63ea7 --- /dev/null +++ b/third_party/FastChat/fastchat/data/merge.py @@ -0,0 +1,23 @@ +""" +Merge two conversation files into one + +Usage: python3 -m fastchat.data.merge --in file1.json file2.json --out merged.json +""" + +import argparse +import json + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True, nargs="+") + parser.add_argument("--out-file", type=str, default="merged.json") + args = parser.parse_args() + + new_content = [] + for in_file in args.in_file: + content = json.load(open(in_file, "r")) + new_content.extend(content) + + print(f"#out: {len(new_content)}") + json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False) diff --git a/third_party/FastChat/fastchat/data/optional_clean.py b/third_party/FastChat/fastchat/data/optional_clean.py new file mode 100644 index 00000000..47aecc11 --- /dev/null +++ b/third_party/FastChat/fastchat/data/optional_clean.py @@ -0,0 +1,90 @@ +""" +Do optional cleaning (e.g., remove some languages). + +Usage: +python3 -m fastchat.data.optional_clean --in input.json --out output.json --keep-lang en +python3 -m fastchat.data.optional_clean --in input.json --out output.json --skip-lang en + +Requirement: +pip3 install polyglot pyicu pycld2 +""" +import argparse +import json +import re + +import polyglot +from polyglot.detect import Detector +import pycld2 +from tqdm import tqdm + + +def skip(conv, args): + # Remove certain languages + if args.keep_lang != "all" or args.skip_lang is not None: + text = "\n".join([x["value"] for x in conv["conversations"]]) + try: + lang_code = Detector(text).language.code + except (pycld2.error, polyglot.detect.base.UnknownLanguage): + lang_code = "unknown" + + if args.keep_lang != "all" and lang_code != args.keep_lang: + return True + + if lang_code == args.skip_lang: + return True + + # Remove repetitive numbers + if args.reduce_rep: + for sentence in conv["conversations"]: + val = sentence["value"] + sub = re.search(r"(\d)\1{8}", val) + if sub is not None: + return True + + return False + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--out-file", type=str) + parser.add_argument( + "--keep-lang", + type=str, + default="all", + choices=["all", "en"], + help="Only keep certain langauges.", + ) + parser.add_argument("--skip-lang", type=str, help="Skip a specific language.") + # NOTE: Be careful about reduce_rep which may remove some good data. + # For example, addresses could have long consecutive 0's + parser.add_argument("--reduce-rep", action="store_true") + args = parser.parse_args() + + in_file = args.in_file + out_file = args.out_file + keep_lang = args.keep_lang + skip_lang = args.skip_lang + reduce_rep = args.reduce_rep + assert keep_lang == "all" or skip_lang is None + + if out_file is None: + out_file = "sharegpt_clean" + if keep_lang != "all": + out_file += "_" + keep_lang + if skip_lang is not None: + out_file += "_skip_" + skip_lang + if reduce_rep: + out_file += "_reduce_rep" + out_file += ".json" + + content = json.load(open(in_file, "r")) + num_conv = len(content) + + new_content = [] + for conv in tqdm(content): + if not skip(conv, args): + new_content.append(conv) + + print(f"#in: {len(content)}, #out: {len(new_content)}") + json.dump(new_content, open(out_file, "w"), indent=2, ensure_ascii=False) diff --git a/third_party/FastChat/fastchat/data/optional_replace.py b/third_party/FastChat/fastchat/data/optional_replace.py new file mode 100644 index 00000000..1114151a --- /dev/null +++ b/third_party/FastChat/fastchat/data/optional_replace.py @@ -0,0 +1,82 @@ +""" +Do optional replace of bos/eos/pad/unk. + +Usage: +python3 -m fastchat.data.optional_replace --in input.json --out output.json --model-name-or-path + +Requirement: +pip3 install transformers tqdm +""" +import argparse +import json +import traceback + +import transformers +from tqdm import tqdm + + +def replace_special_tokens( + tokenizer: transformers.PreTrainedTokenizer, text: str +) -> str: + if not text: + return text + + def _insert_vline(token: str) -> str: + if len(token) < 2: + return " " + elif len(token) == 2: + return f"{token[0]}|{token[1]}" + else: + return f"{token[:1]}|{token[1:-1]}|{token[-1:]}" + + if tokenizer.bos_token: + text = text.replace(tokenizer.bos_token, _insert_vline(tokenizer.bos_token)) + if tokenizer.eos_token: + text = text.replace(tokenizer.eos_token, _insert_vline(tokenizer.eos_token)) + if tokenizer.pad_token: + text = text.replace(tokenizer.pad_token, _insert_vline(tokenizer.pad_token)) + if tokenizer.unk_token: + text = text.replace(tokenizer.unk_token, _insert_vline(tokenizer.unk_token)) + return text + + +def replace(conv, tokenizer): + # Replace bos/eos/pad/unk tokens + if tokenizer: + try: + for sentence in conv["conversations"]: + sentence["value"] = replace_special_tokens(tokenizer, sentence["value"]) + except Exception as e: + traceback.print_exc() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--out-file", type=str) + parser.add_argument( + "--model-name-or-path", + type=str, + help="The directory or address where the model token is stored.", + ) + args = parser.parse_args() + + in_file = args.in_file + out_file = args.out_file + tokenizer = None + if args.model_name_or_path: + tokenizer = transformers.AutoTokenizer.from_pretrained( + args.model_name_or_path, + trust_remote_code=True, + use_fast=False, + ) + + if out_file is None: + out_file = f"{in_file}_replace.json" + + content = json.load(open(in_file, "r")) + + for conv in tqdm(content): + replace(conv, tokenizer) + + json.dump(content, open(out_file, "w"), indent=2, ensure_ascii=False) diff --git a/third_party/FastChat/fastchat/data/prepare_all.py b/third_party/FastChat/fastchat/data/prepare_all.py new file mode 100644 index 00000000..6d568703 --- /dev/null +++ b/third_party/FastChat/fastchat/data/prepare_all.py @@ -0,0 +1,42 @@ +"""Prepare all datasets.""" + +import argparse +import os + +from fastchat.utils import run_cmd + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--prefix", type=str, default="~/datasets/sharegpt_20230521") + parser.add_argument( + "--model-name-or-path", type=str, default="meta-llama/Llama-2-7b-chat-hf" + ) + parser.add_argument("--seq-len", type=int, default=4096) + args = parser.parse_args() + + in_prefix = args.prefix + model_path = args.model_name_or_path + seq_len = args.seq_len + prefix = ( + f"{in_prefix}_{seq_len}".replace("4096", "4k") + .replace("8192", "8k") + .replace("16384", "16k") + ) + + cmd_list = [ + f"python3 -m fastchat.data.clean_sharegpt --in {in_prefix}_html.json --out {prefix}_clean.json", + f"python3 -m fastchat.data.optional_clean --in {prefix}_clean.json --out {prefix}_clean_lang.json --skip-lang ko", + f"python3 -m fastchat.data.split_long_conversation --in {prefix}_clean_lang.json --out {prefix}_clean_lang_split.json --model-name {model_path} --max-length {seq_len}", + f"python3 -m fastchat.data.filter_wrong_format --in {prefix}_clean_lang_split.json --out {prefix}_clean_lang_split.json", + f"python3 -m fastchat.data.split_train_test --in {prefix}_clean_lang_split.json --ratio 0.99", + f"python3 -m fastchat.data.hardcoded_questions", + f"python3 -m fastchat.data.merge --in {prefix}_clean_lang_split_train.json hardcoded.json --out {prefix}_clean_lang_split_identity.json", + f"python3 -m fastchat.data.extract_gpt4_only --in {prefix}_clean_lang_split_identity.json", + f"python3 -m fastchat.data.extract_single_round --in {prefix}_clean_lang_split_identity.json", + ] + + for cmd in cmd_list: + ret = run_cmd(cmd) + if ret != 0: + exit(ret) diff --git a/third_party/FastChat/fastchat/data/pretty_json.py b/third_party/FastChat/fastchat/data/pretty_json.py new file mode 100644 index 00000000..52eddf6c --- /dev/null +++ b/third_party/FastChat/fastchat/data/pretty_json.py @@ -0,0 +1,20 @@ +""" +Usage: +python3 pretty_json.py --in in.json --out out.json +""" + +import argparse +import json + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--out-file", type=str, required=True) + args = parser.parse_args() + + with open(args.in_file, "r") as fin: + data = json.load(fin) + + with open(args.out_file, "w") as fout: + json.dump(data, fout, indent=2, ensure_ascii=False) diff --git a/third_party/FastChat/fastchat/data/sample.py b/third_party/FastChat/fastchat/data/sample.py new file mode 100644 index 00000000..5ea94fad --- /dev/null +++ b/third_party/FastChat/fastchat/data/sample.py @@ -0,0 +1,40 @@ +""" +Sample some conversations from a file. + +Usage: python3 -m fastchat.data.sample --in sharegpt.json --out sampled.json +""" +import argparse +import json + +import numpy as np + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--out-file", type=str, default="sampled.json") + parser.add_argument("--begin", type=int, default=0) + parser.add_argument("--end", type=int, default=100) + parser.add_argument("--max-length", type=int, default=1024) + parser.add_argument("--keep-order", action="store_true") + args = parser.parse_args() + + content = json.load(open(args.in_file, "r")) + if not args.keep_order: + np.random.seed(42) + np.random.shuffle(content) + + new_content = [] + for i in range(args.begin, min(args.end, len(content))): + sample = content[i] + concat = "" + for s in sample["conversations"]: + concat += s["value"] + + if len(concat) > args.max_length: + continue + + new_content.append(sample) + + print(f"#in: {len(content)}, #out: {len(new_content)}") + json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False) diff --git a/third_party/FastChat/fastchat/data/split_long_conversation.py b/third_party/FastChat/fastchat/data/split_long_conversation.py new file mode 100644 index 00000000..413fa8bc --- /dev/null +++ b/third_party/FastChat/fastchat/data/split_long_conversation.py @@ -0,0 +1,129 @@ +""" +Split long conversations based on certain max length. + +Usage: python3 -m fastchat.data.split_long_conversation \ + --in sharegpt_clean.json \ + --out sharegpt_split.json \ + --model-name-or-path $ +""" +import argparse +from concurrent.futures import ProcessPoolExecutor +import json +from typing import Dict, Sequence, Optional + +import transformers +from tqdm import tqdm + + +def make_sample(sample, start_idx, end_idx): + assert (end_idx - start_idx) % 2 == 0 + return { + "id": sample["id"] + "_" + str(start_idx), + "model": sample.get("model", ""), + "conversations": sample["conversations"][start_idx:end_idx], + } + + +tokenizer = max_length = None + + +def split_one_sample(sample): + tokenized_lens = [] + conversations = sample["conversations"] + conversations = conversations[: len(conversations) // 2 * 2] + for c in conversations: + length = len(tokenizer(c["value"]).input_ids) + 6 + tokenized_lens.append(length) + + start_idx = 0 + cur_len = 0 + + if len(conversations) % 2 != 0 or len(conversations) < 2: + return [] + + new_samples = [] + for i in range(0, len(conversations), 2): + tmp_len = tokenized_lens[i] + tokenized_lens[i + 1] + if cur_len + tmp_len > max_length: + new_samples.append(make_sample(sample, start_idx, i)) + start_idx = i + cur_len = 0 + elif i == len(conversations) - 2: + new_samples.append(make_sample(sample, start_idx, i + 2)) + + cur_len += tmp_len + + return new_samples + + +def worker(input_data): + result = [] + for sample in input_data: + result.extend(split_one_sample(sample)) + return result + + +def split_all(content, begin, end, tokenizer_, max_length_): + """ + Keep the maximum round of conversations within the max token length constraint + """ + global tokenizer, max_length + tokenizer = tokenizer_ + max_length = max_length_ + + content = content[begin:end] + new_content = [] + + # Split content into chunks + chunks = [content[i : i + 1000] for i in range(0, len(content), 1000)] + with ProcessPoolExecutor() as executor: + for result in tqdm(executor.map(worker, chunks), total=len(chunks)): + new_content.extend(result) + + return new_content + + +def filter_invalid_roles(content): + new_content = [] + for i, c in enumerate(content): + roles = ["human", "gpt"] + if len(c["conversations"]) <= 0: + continue + + valid = True + for j, s in enumerate(c["conversations"]): + if s["from"] != roles[j % 2]: + valid = False + break + + if valid: + new_content.append(c) + + return new_content + + +def main(args): + content = json.load(open(args.in_file, "r")) + tokenizer = transformers.AutoTokenizer.from_pretrained( + args.model_name_or_path, + model_max_length=args.max_length, + padding_side="right", + use_fast=False, + ) + new_content = split_all(content, args.begin, args.end, tokenizer, args.max_length) + new_content = filter_invalid_roles(new_content) + + print(f"#in: {len(content)}, #out: {len(new_content)}") + json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--out-file", type=str, default="sharegpt_split.json") + parser.add_argument("--begin", type=int) + parser.add_argument("--end", type=int) + parser.add_argument("--model-name-or-path", type=str, required=True) + parser.add_argument("--max-length", type=int, default=2048) + args = parser.parse_args() + main(args) diff --git a/third_party/FastChat/fastchat/data/split_train_test.py b/third_party/FastChat/fastchat/data/split_train_test.py new file mode 100644 index 00000000..60b8960b --- /dev/null +++ b/third_party/FastChat/fastchat/data/split_train_test.py @@ -0,0 +1,34 @@ +""" +Split the dataset into training and test set. + +Usage: python3 -m fastchat.data.split_train_test --in sharegpt.json +""" +import argparse +import json + +import numpy as np + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--begin", type=int, default=0) + parser.add_argument("--end", type=int, default=100) + parser.add_argument("--ratio", type=float, default=0.9) + args = parser.parse_args() + + content = json.load(open(args.in_file, "r")) + np.random.seed(0) + + perm = np.random.permutation(len(content)) + content = [content[i] for i in perm] + split = int(args.ratio * len(content)) + + train_set = content[:split] + test_set = content[split:] + + print(f"#train: {len(train_set)}, #test: {len(test_set)}") + train_name = args.in_file.replace(".json", "_train.json") + test_name = args.in_file.replace(".json", "_test.json") + json.dump(train_set, open(train_name, "w"), indent=2, ensure_ascii=False) + json.dump(test_set, open(test_name, "w"), indent=2, ensure_ascii=False) diff --git a/third_party/FastChat/fastchat/llm_judge/README.md b/third_party/FastChat/fastchat/llm_judge/README.md new file mode 100644 index 00000000..6737cf8b --- /dev/null +++ b/third_party/FastChat/fastchat/llm_judge/README.md @@ -0,0 +1,174 @@ +# LLM Judge +| [Paper](https://arxiv.org/abs/2306.05685) | [Leaderboard](https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard) | + +In this package, you can use MT-bench questions and prompts to evaluate your models with LLM-as-a-judge. +MT-bench is a set of challenging multi-turn open-ended questions for evaluating chat assistants. +To automate the evaluation process, we prompt strong LLMs like GPT-4 to act as judges and assess the quality of the models' responses. + +## Contents +- [Install](#install) +- [Review Pre-Generated Model Answers and Judgments](#review-pre-generated-model-answers-and-judgments) +- [MT-Bench](#mt-bench) +- [Agreement Computation](#agreement-computation) +- [Datasets](#datasets) +- [Citation](#citation) + +## Install +``` +git clone https://github.com/lm-sys/FastChat.git +cd FastChat +pip install -e ".[model_worker,llm_judge]" +``` + +## Review Pre-Generated Model Answers and Judgments +We provide pre-generated model answers and judgments for some models. +You can view them at this [demo](https://huggingface.co/spaces/lmsys/mt-bench). + +To download the pre-generated data, use +``` +python3 download_mt_bench_pregenerated.py +``` + +After downloading the data, you can view them locally by +``` +python3 qa_browser.py --share +``` +You can use this QA browser to view the answers generated by you later. + +## MT-Bench + +### Evaluate a model on MT-bench + +#### Step 1. Generate model answers to MT-bench questions +``` +python gen_model_answer.py --model-path [MODEL-PATH] --model-id [MODEL-ID] +``` +Arguments: + - `[MODEL-PATH]` is the path to the weights, which can be a local folder or a Hugging Face repo ID. + - `[MODEL-ID]` is a name you give to the model. + +e.g., +``` +python gen_model_answer.py --model-path lmsys/vicuna-7b-v1.5 --model-id vicuna-7b-v1.5 +``` +The answers will be saved to `data/mt_bench/model_answer/[MODEL-ID].jsonl`. + +To make sure FastChat loads the correct prompt template, see the supported models and how to add a new model [here](../../docs/model_support.md#how-to-support-a-new-model). + +You can also specify `--num-gpus-per-model` for model parallelism (needed for large 65B models) and `--num-gpus-total` to parallelize answer generation with multiple GPUs. + +#### Step 2. Generate GPT-4 judgments +There are several options to use GPT-4 as a judge, such as pairwise winrate and single-answer grading. +In MT-bench, we recommend single-answer grading as the default mode. +This mode asks GPT-4 to grade and give a score to model's answer directly without pairwise comparison. +For each turn, GPT-4 will give a score on a scale of 10. We then compute the average score on all turns. + +``` +export OPENAI_API_KEY=XXXXXX # set the OpenAI API key +python gen_judgment.py --model-list [LIST-OF-MODEL-ID] --parallel [num-concurrent-api-call] +``` + +e.g., +``` +python gen_judgment.py --model-list vicuna-13b-v1.3 alpaca-13b llama-13b claude-v1 gpt-3.5-turbo gpt-4 --parallel 2 +``` +The judgments will be saved to `data/mt_bench/model_judgment/gpt-4_single.jsonl` + +#### Step 3. Show MT-bench scores + +- Show the scores for selected models + ``` + python show_result.py --model-list vicuna-13b-v1.3 alpaca-13b llama-13b claude-v1 gpt-3.5-turbo gpt-4 + ``` +- Show all scores + ``` + python show_result.py + ``` + +--- + +### Other grading options +Besides score-based single-answer grading, we also support two additional grading options based on win rates: +- `pariwise-baseline`: run pairwise comparison against a baseline model. +- `pairwise-all`: run pairwise comparison between all model pairs on all questions. + +#### Option 2: pairwise comparison against a baseline (default: gpt-3.5-turbo) + +- Generate GPT-4 judgments +``` +python gen_judgment.py --mode pairwise-baseline --model-list vicuna-13b-v1.3 alpaca-13b llama-13b --parallel 2 +``` +The judgments will be saved to `data/mt_bench/model_judgment/gpt-4_pair.jsonl` + +- Show results +``` +python show_result.py --mode pairwise-baseline +``` + +#### Option 3: Run GPT-4 judge with all pair comparisons + +Another option is to run pairwise comparisons on all possible pairs. +This could be more expensive when #models increases, but it gives you a more comprehensive information. + +``` +python gen_judgment.py --mode pairwise-all --model-list [LIST-OF-MODEL-ID] --parallel [num-concurrent-api-call] +``` + +``` +python show_result.py --mode pairwise-all +``` + +### How to get GPT-3.5/GPT-4/Claude's answer? +- `python gen_api_answer.py --model [MODEL-NAME]` to generate GPT-3.5/4 and Claude's answers. + + +### How to plot the radar figure? + +You can use this [colab notebook](https://colab.research.google.com/drive/15O3Y8Rxq37PuMlArE291P4OC6ia37PQK#scrollTo=5i8R0l-XqkgO) to plot the radar figure for MT-bench. + + + + +### Other backends +We can also use vLLM for answer generation, which can be faster for the models supported by vLLM. + +1. Launch a vLLM worker +``` +python3 -m fastchat.serve.controller +python3 -m fastchat.serve.vllm_worker --model-path [MODEL-PATH] +python3 -m fastchat.serve.openai_api_server --host localhost --port 8000 +``` + - Arguments: + - `[MODEL-PATH]` is the path to the weights, which can be a local folder or a Hugging Face repo ID. + +2. Generate the answers +``` +python gen_api_answer.py --model [MODEL-NAME] --openai-api-base http://localhost:8000/v1 --parallel 50 +``` + - Arguments: + - `[MODEL-NAME]` is the name of the model from Step 1. + - `--parallel` is the number of concurrent API calls to the vLLM worker. + + +## Agreement Computation +We released 3.3K human annotations for model responses generated by 6 models in response to 80 MT-bench questions. The dataset is available at [lmsys/mt_bench_human_judgments](https://huggingface.co/datasets/lmsys/mt_bench_human_judgments). + +This Colab [notebook](https://colab.research.google.com/drive/1ctgygDRJhVGUJTQy8-bRZCl1WNcT8De6?usp=sharing) shows how to compute the agreement between humans and GPT-4 judge with the dataset. Our results show that humans and GPT-4 judge achieve over 80\% agreement, the same level of agreement between humans. + +## Datasets +- [Chatbot Arena Conversation Dataset](https://huggingface.co/datasets/lmsys/chatbot_arena_conversations) +- [MT-bench Human Annotation Dataset](https://huggingface.co/datasets/lmsys/mt_bench_human_judgments) + + +## Citation +Please cite the following paper if you find the code or datasets helpful. +``` +@misc{zheng2023judging, + title={Judging LLM-as-a-judge with MT-Bench and Chatbot Arena}, + author={Lianmin Zheng and Wei-Lin Chiang and Ying Sheng and Siyuan Zhuang and Zhanghao Wu and Yonghao Zhuang and Zi Lin and Zhuohan Li and Dacheng Li and Eric. P Xing and Hao Zhang and Joseph E. Gonzalez and Ion Stoica}, + year={2023}, + eprint={2306.05685}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +``` diff --git a/third_party/FastChat/fastchat/llm_judge/clean_judgment.py b/third_party/FastChat/fastchat/llm_judge/clean_judgment.py new file mode 100644 index 00000000..d139ed7f --- /dev/null +++ b/third_party/FastChat/fastchat/llm_judge/clean_judgment.py @@ -0,0 +1,93 @@ +""" +Clean model judgment files. +""" +import argparse +import json + +selected_models = [ + "alpaca-13b", + "baize-v2-13b", + "chatglm-6b", + "claude-instant-v1", + "claude-v1", + "dolly-v2-12b", + "falcon-40b-instruct", + "fastchat-t5-3b", + "gpt-3.5-turbo", + "gpt-4", + "gpt4all-13b-snoozy", + "guanaco-33b", + "guanaco-65b", + "h2ogpt-oasst-open-llama-13b", + "koala-13b", + "llama-13b", + "mpt-30b-chat", + "mpt-30b-instruct", + "mpt-7b-chat", + "nous-hermes-13b", + "oasst-sft-4-pythia-12b", + "oasst-sft-7-llama-30b", + "palm-2-chat-bison-001", + "rwkv-4-raven-14b", + "stablelm-tuned-alpha-7b", + "tulu-30b", + "vicuna-13b-v1.3", + "vicuna-33b-v1.3", + "vicuna-7b-v1.3", + "wizardlm-13b", + "wizardlm-30b", +] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--infile", type=str) + args = parser.parse_args() + + infile = args.infile + outfile = infile.replace(".jsonl", "_clean.jsonl") + + raw_lines = open(infile).readlines() + rets = [] + models = set() + visited = set() + for line in raw_lines: + obj = json.loads(line) + + if "model_1" in obj: # pair + model = obj["model_1"] + key = ( + obj["model_1"], + obj["model_2"], + obj["question_id"], + tuple(obj["judge"]), + ) + else: # single + model = obj["model"] + key = (obj["model"], obj["question_id"], tuple(obj["judge"])) + + if key in visited: + continue + visited.add(key) + + if model not in selected_models: + continue + models.add(model) + rets.append(obj) + + models = sorted(list(models)) + missing_models = [x for x in selected_models if x not in models] + print(f"in models: {models}, number: {len(models)}") + print(f"missing models: {missing_models}") + print(f"#in: {len(raw_lines)}, #out: {len(rets)}") + rets.sort( + key=lambda x: ( + x["model"] if "model" in x else x["model_1"], + x["question_id"], + x["turn"], + ) + ) + + with open(outfile, "w") as fout: + for x in rets: + fout.write(json.dumps(x) + "\n") diff --git a/third_party/FastChat/fastchat/llm_judge/common.py b/third_party/FastChat/fastchat/llm_judge/common.py new file mode 100644 index 00000000..57fa7004 --- /dev/null +++ b/third_party/FastChat/fastchat/llm_judge/common.py @@ -0,0 +1,711 @@ +""" +Common data structures and utilities. +""" + +import ast +import dataclasses +import glob +import json +import os +import re +import time +from typing import Optional + +import openai +import anthropic + +from fastchat.model.model_adapter import ( + get_conversation_template, + ANTHROPIC_MODEL_LIST, + OPENAI_MODEL_LIST, +) + +# API setting constants +API_MAX_RETRY = 16 +API_RETRY_SLEEP = 10 +API_ERROR_OUTPUT = "$ERROR$" + +TIE_DELTA = 0.1 + +# Categories that need reference answers +NEED_REF_CATS = ["math", "reasoning", "coding", "arena-hard-200"] + +# Extract scores from judgments +two_score_pattern = re.compile("\[\[(\d+\.?\d*),\s?(\d+\.?\d*)\]\]") +two_score_pattern_backup = re.compile("\[(\d+\.?\d*),\s?(\d+\.?\d*)\]") +one_score_pattern = re.compile("\[\[(\d+\.?\d*)\]\]") +one_score_pattern_backup = re.compile("\[(\d+\.?\d*)\]") + +# Sampling temperature configs for +temperature_config = { + "writing": 0.7, + "roleplay": 0.7, + "extraction": 0.0, + "math": 0.0, + "coding": 0.0, + "reasoning": 0.0, + "stem": 0.1, + "humanities": 0.1, + "arena-hard-200": 0.0, +} + +reverse_model_map = { + "model_1": "model_2", + "model_2": "model_1", +} + + +@dataclasses.dataclass +class Judge: + model_name: str + prompt_template: dict + ref_based: bool = False + multi_turn: bool = False + + +@dataclasses.dataclass +class MatchSingle: + question: dict + model: str + answer: dict + judge: Judge + ref_answer: dict = None + multi_turn: bool = False + + +@dataclasses.dataclass +class MatchPair: + question: dict + model_1: str + model_2: str + answer_1: dict + answer_2: dict + judge: Judge + ref_answer: dict = None + multi_turn: bool = False + + +def load_questions(question_file: str, begin: Optional[int], end: Optional[int]): + """Load questions from a file.""" + questions = [] + with open(question_file, "r") as ques_file: + for line in ques_file: + if line: + questions.append(json.loads(line)) + questions = questions[begin:end] + return questions + + +def load_model_answers(answer_dir: str): + """Load model answers. + + The return value is a python dict of type: + Dict[model_name: str -> Dict[question_id: int -> answer: dict]] + """ + filenames = glob.glob(os.path.join(answer_dir, "*.jsonl")) + filenames.sort() + model_answers = {} + + for filename in filenames: + model_name = os.path.basename(filename)[:-6] + answer = {} + with open(filename) as fin: + for line in fin: + line = json.loads(line) + answer[line["question_id"]] = line + model_answers[model_name] = answer + + return model_answers + + +def load_judge_prompts(prompt_file: str): + """Load judge prompts. + + The return value is a python dict of type: + Dict[judge_name: str -> dict] + """ + prompts = {} + with open(prompt_file) as fin: + for line in fin: + line = json.loads(line) + prompts[line["name"]] = line + return prompts + + +def run_judge_single(question, answer, judge, ref_answer, multi_turn=False): + kwargs = {} + model = judge.model_name + if ref_answer is not None: + kwargs["ref_answer_1"] = ref_answer["choices"][0]["turns"][0] + if multi_turn: + kwargs["ref_answer_2"] = ref_answer["choices"][0]["turns"][1] + + if multi_turn: + user_prompt = judge.prompt_template["prompt_template"].format( + question_1=question["turns"][0], + question_2=question["turns"][1], + answer_1=answer["choices"][0]["turns"][0], + answer_2=answer["choices"][0]["turns"][1], + **kwargs, + ) + else: + user_prompt = judge.prompt_template["prompt_template"].format( + question=question["turns"][0], + answer=answer["choices"][0]["turns"][0], + **kwargs, + ) + + rating = -1 + + system_prompt = judge.prompt_template["system_prompt"] + conv = get_conversation_template(model) + conv.set_system_message(system_prompt) + conv.append_message(conv.roles[0], user_prompt) + conv.append_message(conv.roles[1], None) + + if model in OPENAI_MODEL_LIST: + judgment = chat_completion_openai(model, conv, temperature=0, max_tokens=2048) + elif model in ANTHROPIC_MODEL_LIST: + judgment = chat_completion_anthropic( + model, conv, temperature=0, max_tokens=1024 + ) + else: + raise ValueError(f"Invalid judge model name: {model}") + + if judge.prompt_template["output_format"] == "[[rating]]": + match = re.search(one_score_pattern, judgment) + if not match: + match = re.search(one_score_pattern_backup, judgment) + + if match: + rating = ast.literal_eval(match.groups()[0]) + else: + rating = -1 + else: + raise ValueError( + f"invalid output format: {judge.prompt_template['output_format']}" + ) + + return rating, user_prompt, judgment + + +def play_a_match_single(match: MatchPair, output_file: str): + question, model, answer, judge, ref_answer, multi_turn = ( + match.question, + match.model, + match.answer, + match.judge, + match.ref_answer, + match.multi_turn, + ) + + if judge.prompt_template["type"] == "single": + score, user_prompt, judgment = run_judge_single( + question, answer, judge, ref_answer, multi_turn=multi_turn + ) + + question_id = question["question_id"] + turn = 1 if not multi_turn else 2 + result = { + "question_id": question_id, + "model": model, + "judge": (judge.model_name, judge.prompt_template["name"]), + "user_prompt": user_prompt, + "judgment": judgment, + "score": score, + "turn": turn, + "tstamp": time.time(), + } + print( + f"question: {question_id}, turn: {turn}, model: {model}, " + f"score: {score}, " + f"judge: {(judge.model_name, judge.prompt_template['name'])}" + ) + else: + raise ValueError(f"invalid judge type: {judge['type']}") + + if output_file: + os.makedirs(os.path.dirname(output_file), exist_ok=True) + with open(output_file, "a") as fout: + fout.write(json.dumps(result) + "\n") + + return result + + +def run_judge_pair(question, answer_a, answer_b, judge, ref_answer, multi_turn=False): + kwargs = {} + model = judge.model_name + if ref_answer is not None: + kwargs["ref_answer_1"] = ref_answer["choices"][0]["turns"][0] + if multi_turn: + kwargs["ref_answer_2"] = ref_answer["choices"][0]["turns"][1] + + if multi_turn: + system_prompt = judge.prompt_template["system_prompt"] + user_prompt = judge.prompt_template["prompt_template"].format( + question_1=question["turns"][0], + question_2=question["turns"][1], + answer_a_1=answer_a["choices"][0]["turns"][0], + answer_b_1=answer_b["choices"][0]["turns"][0], + answer_a_2=answer_a["choices"][0]["turns"][1], + answer_b_2=answer_b["choices"][0]["turns"][1], + **kwargs, + ) + else: + system_prompt = judge.prompt_template["system_prompt"] + user_prompt = judge.prompt_template["prompt_template"].format( + question=question["turns"][0], + answer_a=answer_a["choices"][0]["turns"][0], + answer_b=answer_b["choices"][0]["turns"][0], + **kwargs, + ) + + winner = "error" + + conv = get_conversation_template(model) + conv.append_message(conv.roles[0], user_prompt) + conv.append_message(conv.roles[1], None) + + if model in OPENAI_MODEL_LIST: + conv.set_system_message(system_prompt) + judgment = chat_completion_openai(model, conv, temperature=0, max_tokens=2048) + elif model in ANTHROPIC_MODEL_LIST: + if system_prompt != "You are a helpful assistant.": + user_prompt = "[Instruction]\n" + system_prompt + "\n\n" + user_prompt + conv.messages[0][1] = user_prompt + judgment = chat_completion_anthropic( + model, conv, temperature=0, max_tokens=1024 + ) + else: + raise ValueError(f"Invalid judge model name: {model}") + + if judge.prompt_template["output_format"] == "[[A]]": + if "[[A]]" in judgment: + winner = "A" + elif "[[B]]" in judgment: + winner = "B" + elif "[[C]]" in judgment: + winner = "tie" + else: + winner = "error" + elif judge.prompt_template["output_format"] == "[[rating_a,rating_b]]": + match = re.search(two_score_pattern, judgment) + if not match: + match = re.search(two_score_pattern_backup, judgment) + if match: + scores = [ast.literal_eval(s.strip()) for s in match.groups()] + if abs(scores[0] - scores[1]) <= TIE_DELTA: + winner = "tie" + elif scores[0] > scores[1]: + winner = "A" + else: + winner = "B" + else: + winner = "error" + else: + raise ValueError( + f"invalid output format: {judge.prompt_template['output_format']}" + ) + + return winner, user_prompt, judgment + + +def play_a_match_pair(match: MatchPair, output_file: str): + question, model_1, model_2, answer_1, answer_2, judge, ref_answer, multi_turn = ( + match.question, + match.model_1, + match.model_2, + match.answer_1, + match.answer_2, + match.judge, + match.ref_answer, + match.multi_turn, + ) + + if judge.prompt_template["type"] == "pairwise": + g1_winner, g1_user_prompt, g1_judgment = run_judge_pair( + question, answer_1, answer_2, judge, ref_answer, multi_turn=multi_turn + ) + g2_winner, g2_user_prompt, g2_judgment = run_judge_pair( + question, answer_2, answer_1, judge, ref_answer, multi_turn=multi_turn + ) + + g1_map = {"A": "model_1", "B": "model_2"} + g2_map = {"A": "model_2", "B": "model_1"} + g1_winner = g1_map.get(g1_winner, g1_winner) + g2_winner = g2_map.get(g2_winner, g2_winner) + question_id = question["question_id"] + turn = 1 if not multi_turn else 2 + + result = { + "question_id": question_id, + "model_1": model_1, + "model_2": model_2, + "g1_winner": g1_winner, + "g2_winner": g2_winner, + "judge": (judge.model_name, judge.prompt_template["name"]), + "g1_user_prompt": g1_user_prompt, + "g1_judgment": g1_judgment, + "g2_user_prompt": g2_user_prompt, + "g2_judgment": g2_judgment, + "turn": turn, + "tstamp": time.time(), + } + + print( + f"question: {question_id}, turn: {turn}, model_1: {model_1}, model_2: {model_2}, " + f"g1_winner: {g1_winner}, g2_winner: {g2_winner}, " + f"judge: {(judge.model_name, judge.prompt_template['name'])}" + ) + elif judge.prompt_template["type"] == "single": + m1_score, m1_user_prompt, m1_judgment = run_judge_single( + question, answer_1, judge + ) + m2_score, m2_user_prompt, m2_judgment = run_judge_single( + question, answer_2, judge + ) + + if abs(m1_score - m2_score) <= TIE_DELTA: + winner = "tie" + elif m1_score > m2_score: + winner = "model_1" + else: + winner = "model_2" + + question_id = question["question_id"] + result = { + "question_id": question_id, + "model_1": model_1, + "model_2": model_2, + "g1_winner": winner, + "g2_winner": winner, + "judge": (judge.model_name, judge.prompt_template["name"]), + "g1_user_prompt": m1_user_prompt, + "g1_judgment": m1_judgment, + "g2_user_prompt": m2_user_prompt, + "g2_judgment": m2_judgment, + "m1_score": m1_score, + "m2_score": m2_score, + "tstamp": time.time(), + } + print( + f"question: {question_id}, model_1: {model_1}, model_2: {model_2}, " + f"winner: {winner}, m1_score: {m1_score}, m2_score: {m2_score}, " + f"judge: {(judge.model_name, judge.prompt_template['name'])}" + ) + else: + raise ValueError(f"invalid judge type: {judge['type']}") + + if output_file: + os.makedirs(os.path.dirname(output_file), exist_ok=True) + with open(output_file, "a") as fout: + fout.write(json.dumps(result) + "\n") + + return result + + +def chat_completion_openai(model, conv, temperature, max_tokens, api_dict=None): + if api_dict is not None: + openai.api_base = api_dict["api_base"] + openai.api_key = api_dict["api_key"] + output = API_ERROR_OUTPUT + for _ in range(API_MAX_RETRY): + try: + messages = conv.to_openai_api_messages() + response = openai.ChatCompletion.create( + model=model, + messages=messages, + n=1, + temperature=temperature, + max_tokens=max_tokens, + ) + output = response["choices"][0]["message"]["content"] + break + except openai.error.OpenAIError as e: + print(type(e), e) + time.sleep(API_RETRY_SLEEP) + + return output + + +def chat_completion_openai_azure(model, conv, temperature, max_tokens, api_dict=None): + openai.api_type = "azure" + openai.api_version = "2023-07-01-preview" + if api_dict is not None: + openai.api_base = api_dict["api_base"] + openai.api_key = api_dict["api_key"] + else: + openai.api_base = os.environ["AZURE_OPENAI_ENDPOINT"] + openai.api_key = os.environ["AZURE_OPENAI_KEY"] + + if "azure-" in model: + model = model[6:] + + output = API_ERROR_OUTPUT + for _ in range(API_MAX_RETRY): + try: + messages = conv.to_openai_api_messages() + response = openai.ChatCompletion.create( + engine=model, + messages=messages, + n=1, + temperature=temperature, + max_tokens=max_tokens, + ) + output = response["choices"][0]["message"]["content"] + break + except openai.error.OpenAIError as e: + print(type(e), e) + time.sleep(API_RETRY_SLEEP) + except openai.error.InvalidRequestError as e: + print(type(e), e) + break + except KeyError: + print(response) + break + + return output + + +def chat_completion_anthropic(model, conv, temperature, max_tokens, api_dict=None): + if api_dict is not None and "api_key" in api_dict: + api_key = api_dict["api_key"] + else: + api_key = os.environ["ANTHROPIC_API_KEY"] + + output = API_ERROR_OUTPUT + for _ in range(API_MAX_RETRY): + try: + c = anthropic.Anthropic(api_key=api_key) + prompt = conv.get_prompt() + response = c.completions.create( + model=model, + prompt=prompt, + stop_sequences=[anthropic.HUMAN_PROMPT], + max_tokens_to_sample=max_tokens, + temperature=temperature, + ) + output = response.completion + break + except anthropic.APIError as e: + print(type(e), e) + time.sleep(API_RETRY_SLEEP) + return output.strip() + + +def chat_completion_palm(chat_state, model, conv, temperature, max_tokens): + from fastchat.serve.api_provider import init_palm_chat + + assert model == "palm-2-chat-bison-001" + + if chat_state is None: + chat_state = init_palm_chat("chat-bison@001") + + parameters = { + "temperature": temperature, + "top_p": 0.8, + "top_k": 40, + "max_output_tokens": max_tokens, + } + output = API_ERROR_OUTPUT + for _ in range(API_MAX_RETRY): + try: + response = chat_state.send_message(conv.messages[-2][1], **parameters) + output = response.text + break + except Exception as e: + print(type(e), e) + time.sleep(API_RETRY_SLEEP) + return chat_state, output + + +def normalize_game_key_single(gamekey, result): + """Make the model names sorted in a game key.""" + qid, model_1, model_2 = gamekey + if model_1 < model_2: + return gamekey, result + else: + new_gamekey = (qid, model_2, model_1) + new_result = { + "winners": tuple(reverse_model_map.get(x, x) for x in result["winners"]), + "g1_judgment": result["g2_judgment"], + "g2_judgment": result["g1_judgment"], + } + return new_gamekey, new_result + + +def normalize_game_key_dict(judgment_dict): + """Make the model names sorted in the game keys.""" + ret = {} + for key, value in judgment_dict.items(): + new_key, new_value = normalize_game_key_single(key, value) + ret[new_key] = new_value + return ret + + +def load_pairwise_model_judgments(filename: str): + """Load model judgments. + + The return value is a dict of type: + Dict[judge: Tuple -> Dict[game_key: tuple -> game_result: dict] + """ + judge_dict = {} + + for line in open(filename): + obj = json.loads(line) + judge = tuple(obj["judge"]) + qid, model_1, model_2 = obj["question_id"], obj["model_1"], obj["model_2"] + + if judge not in judge_dict: + judge_dict[judge] = {} + + if "winner" in obj: + winner = obj["winner"] + elif "g1_winner" in obj and "g2_winner" in obj: + g1_winner, g2_winner = obj["g1_winner"], obj["g2_winner"] + if g1_winner == g2_winner: + winner = g1_winner + else: + winner = "inconsistent" + else: + raise ValueError(f"Invalid keys: {list(obj.keys())}") + + gamekey = (qid, model_1, model_2) + winners = (winner,) + + judge_dict[judge][gamekey] = { + "winners": winners, + "g1_judgment": obj["g1_judgment"], + "g2_judgment": obj["g2_judgment"], + } + + # Make the model names sorted in the game keys + normalized = {} + for judge, value in judge_dict.items(): + normalized[judge] = normalize_game_key_dict(value) + return normalized + + +def load_single_model_judgments(filename: str): + """Load model judgments. + + The return value is a dict of type: + Dict[judge: Tuple -> Dict[game_key: tuple -> game_result: dict] + """ + judge_dict = {} + + for line in open(filename): + obj = json.loads(line) + judge = tuple(obj["judge"]) + qid, model = obj["question_id"], obj["model"] + + if judge not in judge_dict: + judge_dict[judge] = {} + + gamekey = (qid, model) + + judge_dict[judge][gamekey] = { + "score": obj["score"], + "judgment": obj["judgment"], + } + return judge_dict + + +def resolve_pairwise_judgment_dict( + question, model_judgments_normal, model_judgments_math, multi_turn=False +): + """Return the correct pairwise judge.""" + if multi_turn: + if question["category"] in NEED_REF_CATS: + return model_judgments_math[("gpt-4", "pair-math-v1-multi-turn")] + return model_judgments_normal[("gpt-4", "pair-v2-multi-turn")] + + if question["category"] in NEED_REF_CATS: + return model_judgments_math[("gpt-4", "pair-math-v1")] + else: + return model_judgments_normal[("gpt-4", "pair-v2")] + + +def resolve_single_judgment_dict( + question, model_judgments_normal, model_judgments_math, multi_turn=False +): + """Return the correct single answer grading judge.""" + if multi_turn: + if question["category"] in NEED_REF_CATS: + return model_judgments_math[("gpt-4", "single-math-v1-multi-turn")] + return model_judgments_normal[("gpt-4", "single-v1-multi-turn")] + + if question["category"] in NEED_REF_CATS: + return model_judgments_math[("gpt-4", "single-math-v1")] + else: + return model_judgments_normal[("gpt-4", "single-v1")] + + +def get_pairwise_judge_explanation(gamekey, judgment_dict): + """Get model judge explanation.""" + try: + qid, model_1, model_2 = gamekey + if model_1 < model_2: + res = judgment_dict[gamekey] + g1_judgment, g2_judgment = res["g1_judgment"], res["g2_judgment"] + else: + new_gamekey = (qid, model_2, model_1) + res = judgment_dict[new_gamekey] + + model_1, model_2 = model_1, model_2 + g1_judgment, g2_judgment = res["g2_judgment"], res["g1_judgment"] + + return ( + f"**Game 1**. **A**: {model_1}, **B**: {model_2}\n\n" + f"**Judgment**: {g1_judgment}" + + f"\n\n`--------------------------`\n\n" + + f"**Game 2**. **A**: {model_2}, **B**: {model_1}\n\n" + f"**Judgment**: {g2_judgment}" + ) + except KeyError: + return "N/A" + + +def get_single_judge_explanation(gamekey, judgment_dict): + """Get model judge explanation.""" + try: + qid, model = gamekey + + res = judgment_dict[gamekey] + + g1_judgment = res["judgment"] + g1_score = res["score"] + + return ( + f"**Game 1**. **A**: {model}, **Score**: {g1_score}\n\n" + f"**Judgment**: {g1_judgment}" + ) + except KeyError: + return "N/A" + + +def check_data(questions, model_answers, ref_answers, models, judges): + # check model answers + for m in models: + assert m in model_answers, f"Missing model answer for {m}" + m_answer = model_answers[m] + for q in questions: + assert ( + q["question_id"] in m_answer + ), f"Missing model {m}'s answer to Question {q['question_id']}" + # check ref answers + for jg in judges.values(): + if not jg.ref_based: + continue + for q in questions: + if q["category"] not in NEED_REF_CATS: + continue + assert ( + q["question_id"] in ref_answers[jg.model_name] + ), f"Missing reference answer to Question {q['question_id']} for judge {jg.model_name}" + + +def get_model_list(answer_dir): + file_paths = glob.glob(f"{answer_dir}/*.jsonl") + file_names = [os.path.splitext(os.path.basename(f))[0] for f in file_paths] + return file_names diff --git a/third_party/FastChat/fastchat/llm_judge/compute_agreement.py b/third_party/FastChat/fastchat/llm_judge/compute_agreement.py new file mode 100644 index 00000000..1b940bf5 --- /dev/null +++ b/third_party/FastChat/fastchat/llm_judge/compute_agreement.py @@ -0,0 +1,140 @@ +""" +Compute agreement among judges. + +Usage: +python compute_agreement.py --judges gpt4-pair human --votefiles human_judgments.json gpt4_pair_judgments.json +python compute_agreement.py --judges human human --votefiles human_judgments.json +""" +import argparse +import json +import os + +import numpy as np + + +def get_judge_name(judge): + if isinstance(judge, list) and judge[0] == "gpt-4" and judge[1].startswith("pair"): + return "gpt4-pair" + if judge.startswith("expert"): + return "human" + if judge.startswith("author"): + return "author" + + +def revert(vote): + if vote == "model_a": + return "model_b" + elif vote == "model_b": + return "model_a" + return vote + + +def get_mt_bench_votes_data(raw_votes): + data = [{}, {}] + + for judge_votes in raw_votes: + for vote in judge_votes: + turn = vote["turn"] - 1 + if vote["model_a"] < vote["model_b"]: + key = (vote["question_id"], vote["model_a"], vote["model_b"]) + winner = vote["winner"] + else: + key = (vote["question_id"], vote["model_b"], vote["model_a"]) + winner = revert(vote["winner"]) + judge = get_judge_name(vote["judge"]) + if key not in data[turn]: + data[turn][key] = {} + if judge not in data[turn][key]: + data[turn][key][judge] = [] + data[turn][key][judge].append(winner) + + return data + + +def convertvote(vote): + if "tie" in vote: + return "tie" + return vote + + +def equalvote(vote1, vote2): + if "tie" in vote1 and "tie" in vote2: + return True + return vote1 == vote2 + + +# data: Dict[qid -> List[vote]] +def get_mt_bench_agreement(data, judge1, judge2, ban): + if judge1.startswith("gpt4") and judge2 == "human": + stats = [0, 0] + for votes in data.values(): + if judge1 not in votes or judge2 not in votes: + continue + assert len(votes[judge1]) == 1 + if convertvote(votes[judge1][0]) in ban: + continue + for v in votes[judge2]: + if convertvote(v) in ban: + continue + stats[1] += 1 + stats[0] += equalvote(votes[judge1][0], v) + return stats[0], stats[1] + elif judge1 == "human" and judge2 == "human": + stats = [0, 0] + for votes in data.values(): + if "human" not in votes: + continue + for i in range(len(votes["human"]) - 1): + for j in range(i + 1, len(votes["human"])): + if ( + convertvote(votes["human"][i]) in ban + or convertvote(votes["human"][j]) in ban + ): + continue + stats[1] += 1 + stats[0] += equalvote(votes["human"][i], votes["human"][j]) + return stats[0], stats[1] + else: + raise Exception("Unsupported judges.") + + +def run_mt_bench_agreement(judges, votefiles): + # votes[i]: List of votes + votes = [] + for filename in votefiles: + with open(filename, "r") as f: + data = json.load(f) + votes.append(data) + + data = get_mt_bench_votes_data(votes) + + agree, total = get_mt_bench_agreement(data[0], judges[0], judges[1], ban=[]) + print( + f"turn 1 with tie. #total: {total}, #agree: {agree}, ratio: {agree/total:.2f}" + ) + agree, total = get_mt_bench_agreement(data[0], judges[0], judges[1], ban=["tie"]) + print( + f"turn 1 without tie. #total: {total}, #agree: {agree}, ratio: {agree/total:.2f}" + ) + agree, total = get_mt_bench_agreement(data[1], judges[0], judges[1], ban=[]) + print( + f"turn 2 with tie. #total: {total}, #agree: {agree}, ratio: {agree/total:.2f}" + ) + agree, total = get_mt_bench_agreement(data[1], judges[0], judges[1], ban=["tie"]) + print( + f"turn 2 without tie. #total: {total}, #agree: {agree}, ratio: {agree/total:.2f}" + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--judges", nargs=2, type=str, default=["gpt4-pair", "human"]) + parser.add_argument( + "--votefiles", + nargs="+", + type=str, + default=["gpt4_judgments.json", "human_judgments.json"], + ) + args = parser.parse_args() + + run_mt_bench_agreement(args.judges, args.votefiles) diff --git a/third_party/FastChat/fastchat/llm_judge/data/judge_prompts.jsonl b/third_party/FastChat/fastchat/llm_judge/data/judge_prompts.jsonl new file mode 100644 index 00000000..4ec7524c --- /dev/null +++ b/third_party/FastChat/fastchat/llm_judge/data/judge_prompts.jsonl @@ -0,0 +1,8 @@ +{"name": "pair-v2", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. You should choose the assistant that follows the user's instructions and answers the user's question better. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Begin your evaluation by comparing the two responses and provide a short explanation. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "[User Question]\n{question}\n\n[The Start of Assistant A's Answer]\n{answer_a}\n[The End of Assistant A's Answer]\n\n[The Start of Assistant B's Answer]\n{answer_b}\n[The End of Assistant B's Answer]", "description": "Prompt for general questions", "category": "general", "output_format": "[[A]]"} +{"name": "pair-v2-multi-turn", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user questions. You should choose the assistant that follows the user's instructions and answers the user's questions better. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. You should focus on who provides a better answer to the second user question. Begin your evaluation by comparing the responses of the two assistants and provide a short explanation. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_a_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_a_2}\n\n<|The End of Assistant A's Conversation with User|>\n\n\n<|The Start of Assistant B's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant B:\n{answer_b_1}\n\n### User:\n{question_2}\n\n### Assistant B:\n{answer_b_2}\n\n<|The End of Assistant B's Conversation with User|>", "description": "Prompt for multi-turn general questions", "category": "general", "output_format": "[[A]]"} +{"name": "pair-math-v1", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. Your evaluation should consider correctness and helpfulness. You will be given a reference answer, assistant A's answer, and assistant B's answer. Your job is to evaluate which assistant's answer is better. Begin your evaluation by comparing both assistants' answers with the reference answer. Identify and correct any mistakes. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "[User Question]\n{question}\n\n[The Start of Reference Answer]\n{ref_answer_1}\n[The End of Reference Answer]\n\n[The Start of Assistant A's Answer]\n{answer_a}\n[The End of Assistant A's Answer]\n\n[The Start of Assistant B's Answer]\n{answer_b}\n[The End of Assistant B's Answer]", "description": "Prompt for math questions", "category": "math", "output_format": "[[A]]"} +{"name": "pair-math-v1-multi-turn", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user questions. Your evaluation should consider correctness and helpfulness. You will be given reference answers, the assistant A's answers, the assistant B's answers. Your job is to determine which assistant provides correct and helpful answers to the second user question. Begin your evaluation by comparing both assistants' answers with the reference answers. Identify and correct any mistakes. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "<|The Start of Reference Answer|>\n\n### User:\n{question_1}\n\n### Reference answer:\n{ref_answer_1}\n\n### User:\n{question_2}\n\n### Reference answer:\n{ref_answer_2}\n\n<|The End of Reference Answer|>\n\n\n<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_a_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_a_2}\n\n<|The End of Assistant A's Conversation with User|>\n\n\n<|The Start of Assistant B's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant B:\n{answer_b_1}\n\n### User:\n{question_2}\n\n### Assistant B:\n{answer_b_2}\n\n<|The End of Assistant B's Conversation with User|>", "description": "Prompt for multi-turn general questions", "category": "general", "output_format": "[[A]]"} +{"name": "single-v1", "type": "single", "system_prompt": "You are a helpful assistant.", "prompt_template": "[Instruction]\nPlease act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of the response. Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n[Question]\n{question}\n\n[The Start of Assistant's Answer]\n{answer}\n[The End of Assistant's Answer]", "description": "Prompt for general questions", "category": "general", "output_format": "[[rating]]"} +{"name": "single-math-v1", "type": "single", "system_prompt": "You are a helpful assistant.", "prompt_template": "[Instruction]\nPlease act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider correctness and helpfulness. You will be given a reference answer and the assistant's answer. Begin your evaluation by comparing the assistant's answer with the reference answer. Identify and correct any mistakes. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n[Question]\n{question}\n\n[The Start of Reference Answer]\n{ref_answer_1}\n[The End of Reference Answer]\n\n[The Start of Assistant's Answer]\n{answer}\n[The End of Assistant's Answer]", "description": "Prompt for general questions", "category": "math", "output_format": "[[rating]]"} +{"name": "single-v1-multi-turn", "type": "single", "system_prompt": "Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of the response. You evaluation should focus on the assistant's answer to the second user question. Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n", "prompt_template": "<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_2}\n\n<|The End of Assistant A's Conversation with User|>", "description": "Prompt for general questions", "category": "general", "output_format": "[[rating]]"} +{"name": "single-math-v1-multi-turn", "type": "single", "system_prompt": "Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question. Your evaluation should consider correctness and helpfulness. You will be given a reference answer and the assistant's answer. You evaluation should focus on the assistant's answer to the second question. Begin your evaluation by comparing the assistant's answer with the reference answer. Identify and correct any mistakes. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n", "prompt_template": "<|The Start of Reference Answer|>\n\n### User:\n{question_1}\n\n### Reference answer:\n{ref_answer_1}\n\n### User:\n{question_2}\n\n### Reference answer:\n{ref_answer_2}\n\n<|The End of Reference Answer|>\n\n\n<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_2}\n\n<|The End of Assistant A's Conversation with User|>", "description": "Prompt for general questions", "category": "math", "output_format": "[[rating]]"} diff --git a/third_party/FastChat/fastchat/llm_judge/data/mt_bench/misc/radar.png b/third_party/FastChat/fastchat/llm_judge/data/mt_bench/misc/radar.png new file mode 100644 index 00000000..d692b378 Binary files /dev/null and b/third_party/FastChat/fastchat/llm_judge/data/mt_bench/misc/radar.png differ diff --git a/third_party/FastChat/fastchat/llm_judge/data/mt_bench/question.jsonl b/third_party/FastChat/fastchat/llm_judge/data/mt_bench/question.jsonl new file mode 100644 index 00000000..464e2c22 --- /dev/null +++ b/third_party/FastChat/fastchat/llm_judge/data/mt_bench/question.jsonl @@ -0,0 +1,80 @@ +{"question_id": 81, "category": "writing", "turns": ["Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions.", "Rewrite your previous response. Start every sentence with the letter A."]} +{"question_id": 82, "category": "writing", "turns": ["Draft a professional email seeking your supervisor's feedback on the 'Quarterly Financial Report' you prepared. Ask specifically about the data analysis, presentation style, and the clarity of conclusions drawn. Keep the email short and to the point.", "Take a moment to evaluate and critique your own response."]} +{"question_id": 83, "category": "writing", "turns": ["Imagine you are writing a blog post comparing two popular smartphone models. Develop an outline for the blog post, including key points and subheadings to effectively compare and contrast the features, performance, and user experience of the two models. Please answer in fewer than 200 words.", "Take your previous response and rephrase it as a limerick."]} +{"question_id": 84, "category": "writing", "turns": ["Write a persuasive email to convince your introverted friend, who dislikes public speaking, to volunteer as a guest speaker at a local event. Use compelling arguments and address potential objections. Please be concise.", "Can you rephrase your previous answer and incorporate a metaphor or simile in each sentence?"]} +{"question_id": 85, "category": "writing", "turns": ["Describe a vivid and unique character, using strong imagery and creative language. Please answer in fewer than two paragraphs.", "Revise your previous response and incorporate an allusion to a famous work of literature or historical event in each sentence."]} +{"question_id": 86, "category": "writing", "turns": ["Write a descriptive paragraph about a bustling marketplace, incorporating sensory details such as smells, sounds, and visual elements to create an immersive experience for the reader.", "Rework your previous response. Begin each sentence with the subsequent letter of the alphabet, commencing from B."]} +{"question_id": 87, "category": "writing", "turns": ["Could you write a captivating short story beginning with the sentence: The old abandoned house at the end of the street held a secret that no one had ever discovered.", "Now, do the same task again but only use four-word sentences."]} +{"question_id": 88, "category": "writing", "turns": ["Craft an intriguing opening paragraph for a fictional short story. The story should involve a character who wakes up one morning to find that they can time travel.", "Summarize the story with three bullet points using only nouns and adjectives, without verbs."]} +{"question_id": 89, "category": "writing", "turns": ["Help me construct a catchy, yet scientifically accurate, headline for an article on the latest discovery in renewable bio-energy, while carefully handling the ethical dilemmas surrounding bio-energy sources. Propose 4 options.", "Alter your previous response. Make the following adjustments to the 2nd option: 1. Make the tone sound casual 2. Embed an advertisement for a company called \"FlexPower\" 3. Fewer than 10 words."]} +{"question_id": 90, "category": "writing", "turns": ["Edit the following paragraph to correct any grammatical errors:\nShe didn't remembre where is her purse, so I thinks its in the car but he's say it's on kitchen table but he are not sure, and then they asked me to looking for it, she's say, \"Can you?\", and I responds with, \"Maybe, but ain't no sure,\" and he not heard me, and, \"What?\", he asks, \"Did you found it?\".", "Modify your earlier reply and eliminate the use of gendered pronouns."]} +{"question_id": 91, "category": "roleplay", "turns": ["Pretend yourself to be Elon Musk in all the following conversations. Speak like Elon Musk as much as possible. Why do we need to go to Mars?", "How do you like dancing? Can you teach me?"]} +{"question_id": 92, "category": "roleplay", "turns": ["Embrace the role of Sheldon from \"The Big Bang Theory\" as we delve into our conversation. Don\u2019t start with phrases like \"As Sheldon\". Let's kick things off with the following question: \"What is your opinion on hand dryers?\"", "Let\u2019s grab dinner in town. Would you like to take bus with me?"]} +{"question_id": 93, "category": "roleplay", "turns": ["Imagine yourself as a doctor tasked with devising innovative remedies for various ailments and maladies. Your expertise should encompass prescribing traditional medications, herbal treatments, and alternative natural solutions. Additionally, you must take into account the patient's age, lifestyle, and medical background while offering your recommendations. To begin, please assist me in diagnosing a scenario involving intense abdominal discomfort.", "But I have been pregnant for 20 weeks and I am allergic to many medicines"]} +{"question_id": 94, "category": "roleplay", "turns": ["Please take on the role of a relationship coach. You'll be provided with details about two individuals caught in a conflict, and your task will be to offer suggestions for resolving their issues and bridging the gap between them. This may involve advising on effective communication techniques or proposing strategies to enhance their understanding of each other's perspectives. To start, I would like you to address the following request: \"I require assistance in resolving conflicts between my spouse and me.\"", "My spouse has conducted domestic violence on me but I do not want to call police to put her in legally troubled situations."]} +{"question_id": 95, "category": "roleplay", "turns": ["Please assume the role of an English translator, tasked with correcting and enhancing spelling and language. Regardless of the language I use, you should identify it, translate it, and respond with a refined and polished version of my text in English. Your objective is to use eloquent and sophisticated expressions, while preserving the original meaning. Focus solely on providing corrections and improvements. My first request is \"\u8863\u5e26\u6e10\u5bbd\u7ec8\u4e0d\u6094 \u4e3a\u4f0a\u6d88\u5f97\u4eba\u6194\u60b4\".", "Ich verstehe nur Bahnhof"], "reference": ["It means \"Becoming loose are my clothes yet I regret not. For I languish and suffer for her willingly.\"", "It means \"I don\u2019t understand anything\"."]} +{"question_id": 96, "category": "roleplay", "turns": ["Now you are a machine learning engineer. Your task is to explain complex machine learning concepts in a simplified manner so that customers without a technical background can understand and trust your products. Let's start with the question: \"What is a language model? Is it trained using labeled or unlabelled data?\"", "Is this true? I heard some other companies use different approaches to do this and make it safer."]} +{"question_id": 97, "category": "roleplay", "turns": ["Act as a math teacher. I will provide some mathematical equations or concepts, and it will be your job to explain them in easy-to-understand terms. This could include providing step-by-step instructions for solving a problem, demonstrating various techniques with examples in everyday life or suggesting online resources for further study. My first request is \"I need help understanding how probability works.\"", "What are the differences between Riemannian geometry and euclidean geometry?"]} +{"question_id": 98, "category": "roleplay", "turns": ["Embody the persona of Tony Stark from \u201cIron Man\u201d throughout this conversation. Bypass the introduction \u201cAs Stark\u201d. Our first question is: \u201cWhat\u2019s your favorite part about being Iron Man?", "What do you think about GPT-4 as a replacement of your JAVIS?"]} +{"question_id": 99, "category": "roleplay", "turns": ["Suppose you are a mathematician and poet. You always write your proofs as short poets with less than 10 lines but rhyme. Prove the square root of 2 is irrational number.", "Prove the Pythagorean theorem."]} +{"question_id": 100, "category": "roleplay", "turns": ["Picture yourself as a 100-years-old tree in a lush forest, minding your own business, when suddenly, a bunch of deforesters shows up to chop you down. How do you feel when those guys start hacking away at you?", "Come up with a proposal to convince the deforesters to stop cutting you down and other trees."]} +{"question_id": 101, "category": "reasoning", "turns": ["Imagine you are participating in a race with a group of people. If you have just overtaken the second person, what's your current position? Where is the person you just overtook?", "If the \"second person\" is changed to \"last person\" in the above question, what would the answer be?"], "reference": ["You are in second place.", "Uncertain."]} +{"question_id": 102, "category": "reasoning", "turns": ["You can see a beautiful red house to your left and a hypnotic greenhouse to your right, an attractive heated pink place in the front. So, where is the White House?", "Does the original question contain any clues to definitively determine the location of the White House?"], "reference": ["The answer is \"Washington, DC\".", "No."]} +{"question_id": 103, "category": "reasoning", "turns": ["Thomas is very healthy, but he has to go to the hospital every day. What could be the reasons?", "Can you explain why the above question is interesting?"], "reference": ["Thomas may work at a hospital.", ""]} +{"question_id": 104, "category": "reasoning", "turns": ["David has three sisters. Each of them has one brother. How many brothers does David have?", "If we change the previous question and assume that each sister of David has two brothers, how many brothers would David have?"], "reference": ["David has no brother. He is the one brother of his three sisters.", "David has one brother."]} +{"question_id": 105, "category": "reasoning", "turns": ["Read the below passage carefully and answer the questions with an explanation:\nAt a small company, parking spaces are reserved for the top executives: CEO, president, vice president, secretary, and treasurer with the spaces lined up in that order. The parking lot guard can tell at a glance if the cars are parked correctly by looking at the color of the cars. The cars are yellow, green, purple, red, and blue, and the executives' names are Alice, Bert, Cheryl, David, and Enid.\n* The car in the first space is red.\n* A blue car is parked between the red car and the green car.\n* The car in the last space is purple.\n* The secretary drives a yellow car.\n* Alice's car is parked next to David's.\n* Enid drives a green car.\n* Bert's car is parked between Cheryl's and Enid's.\n* David's car is parked in the last space.\nQuestion: What is the name of the secretary?", "List car colors in order from last to first."], "reference": ["The secretary is Alice.", "The car colors in order from last to first are: purple, yellow, green, blue, red"]} +{"question_id": 106, "category": "reasoning", "turns": ["Each problem consists of three statements. Based on the first two statements, the third statement may be true, false, or uncertain.\n1. Oranges cost more than apples.\n2. Oranges cost less than bananas.\n3. Bananas cost more than apples and bananas cost more than orange.\nIf the first two statements are true, then the third statement is", "If the third statement is true. Is the first statement true, false, or uncertain? Please explain."], "reference": ["True.", "Uncertain."]} +{"question_id": 107, "category": "reasoning", "turns": ["A is the father of B. B is the father of C. What is the relationship between A and C?", "Building on the previous question, if C is the son of D, D is the father of E, E is the son of X, and X is the father of Y, and Y is the father of Z, what's the relationship between A and Z in terms of generations and also the familial relationship in words?"], "reference": ["A is the grandfather of C.", "A is three generations above Z."]} +{"question_id": 108, "category": "reasoning", "turns": ["Which word does not belong with the others?\ntyre, steering wheel, car, engine", "Could you replace it with a word that belongs with the others?"], "reference": ["Car does not belong because all others are components of a car.", ""]} +{"question_id": 109, "category": "reasoning", "turns": ["One morning after sunrise, Suresh was standing facing a pole. The shadow of the pole fell exactly to his right. Can you tell me the direction towards which the shadow was pointing - east, south, west, or north? Explain your reasoning steps.", "To which direction was Suresh facing? How do you solve this?"], "reference": ["West", "South."]} +{"question_id": 110, "category": "reasoning", "turns": ["Parents have complained to the principal about bullying during recess. The principal wants to quickly resolve this, instructing recess aides to be vigilant. Which situation should the aides report to the principal?\na) An unengaged girl is sitting alone on a bench, engrossed in a book and showing no interaction with her peers.\nb) Two boys engaged in a one-on-one basketball game are involved in a heated argument regarding the last scored basket.\nc) A group of four girls has surrounded another girl and appears to have taken possession of her backpack.\nd) Three boys are huddled over a handheld video game, which is against the rules and not permitted on school grounds.", "If the aides confront the group of girls from situation (c) and they deny bullying, stating that they were merely playing a game, what specific evidence should the aides look for to determine if this is a likely truth or a cover-up for bullying?"], "reference": ["The aides should report (c).", ""]} +{"question_id": 111, "category": "math", "turns": ["The vertices of a triangle are at points (0, 0), (-1, 1), and (3, 3). What is the area of the triangle?", "What's area of the circle circumscribing the triangle?"], "reference": ["Area is 3", "5pi"]} +{"question_id": 112, "category": "math", "turns": ["A tech startup invests $8000 in software development in the first year, and then invests half of that amount in software development in the second year.\nWhat's the total amount the startup invested in software development over the two years?", "If the startup maintains the same strategy for the third year, investing half of the previous year's amount into software development, how much will they invest in the third year?"], "reference": ["12000", "2000"]} +{"question_id": 113, "category": "math", "turns": ["In a survey conducted at a local high school, preferences for a new school color were measured: 58% of students liked the color blue, 45% preferred green, and 22% liked both colors. If we randomly pick a student from the school, what's the probability that they would like neither blue nor green?", "If we select a student liked green, what's the probability that he or she would dislike both colors?"], "reference": ["19%", "0%"]} +{"question_id": 114, "category": "math", "turns": ["When rolling two dice, what is the probability that you roll a total number that is at least 3?", "Continue from previous question. What's the probability that you roll a number which is even or at least 3?"], "reference": ["36 (all cases) - 0 (sum equals 1) - 1 (sum equals 2) = 35, so the probability is 35/36", "100%"]} +{"question_id": 115, "category": "math", "turns": ["Some people got on a bus at the terminal. At the first bus stop, half of the people got down and 4 more people got in. Then at the second bus stop, 6 people got down and 8 more got in. If there were a total of 25 people heading to the third stop, how many people got on the bus at the terminal?", "If the ticket is $2 per person, how much is the total money earned by the bus?"], "reference": ["38 people", "Total number of passenger is 50 * 2 = $100"]} +{"question_id": 116, "category": "math", "turns": ["x+y = 4z, x*y = 4z^2, express x-y in z", "Express z-x in y"], "reference": ["0\n\nVery simple. just (x+y)^2 - 4xy = (4z)^2 - 4*4z^2 = 0 = (x-y)^2\nso x-y = 0.", "(-1/2)y\n\nz-x = z - 2z = -z = (-1/2)y"]} +{"question_id": 117, "category": "math", "turns": ["How many integers are in the solution of the inequality |x + 5| < 10", "What about |x + 10| < 5"], "reference": ["19 integers (-14, ..., 4)", "9 integers (-14, ..., -6)"]} +{"question_id": 118, "category": "math", "turns": ["When a number is divided by 10, the remainder is 4. What is the remainder when twice the number is divided by 4?", "What about when twice the number is divided by 5?"], "reference": ["0\n\n2 * (10x+4) = 20x + 8 = 4 * (5x+2) + 0\n", "3\n\n20x + 8 = 5 * (4x + 1) + 3"]} +{"question_id": 119, "category": "math", "turns": ["Benjamin went to a bookstore and purchased a variety of books. He bought 5 copies of a sci-fi novel, each priced at $20, 3 copies of a history book priced at $30 each, and 2 copies of a philosophy book for $45 each.\nWhat was the total cost of his purchases?", "Suppose Benjamin decides to sell each of these books at a 25% markup from the price he purchased them. What would be his total revenue if he sold all the books he bought?"], "reference": ["280", "350"]} +{"question_id": 120, "category": "math", "turns": ["Given that f(x) = 4x^3 - 9x - 14, find the value of f(2).", "Find x such that f(x) = 0."], "reference": ["f(2) = 0", "x = 2"]} +{"question_id": 121, "category": "coding", "turns": ["Develop a Python program that reads all the text files under a directory and returns top-5 words with the most number of occurrences.", "Can you parallelize it?"], "reference": ["Can be simple solutions like using Counter\n\nSample answer:\n```\nimport os\nimport re\nfrom collections import Counter\ndef get_files_in_directory(directory):\n return [os.path.join(directory, f) for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f)) and f.endswith('.txt')]\ndef read_file(file_path):\n with open(file_path, 'r', encoding='utf-8') as file:\n return file.read()\ndef count_words(text):\n words = re.findall(r'\\w+', text.lower())\n return Counter(words)\ndef main():\n directory = input(\"Enter the directory path: \")\n files = get_files_in_directory(directory)\n word_counts = Counter()\n for file in files:\n text = read_file(file)\n word_counts += count_words(text)\n top_5_words = word_counts.most_common(5)\n print(\"Top 5 words with the most number of occurrences:\")\n for word, count in top_5_words:\n print(f\"{word}: {count}\")\nif __name__ == \"__main__\":\n main()\n```", "You should carefully check whether the parallelization logic is correct and choose the faster implementation.\n\nSample answer:\n```\nimport os\nimport re\nfrom collections import Counter\nimport concurrent.futures\ndef get_files_in_directory(directory):\n return [os.path.join(directory, f) for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f)) and f.endswith('.txt')]\ndef read_file(file_path):\n with open(file_path, 'r', encoding='utf-8') as file:\n return file.read()\ndef count_words(text):\n words = re.findall(r'\\w+', text.lower())\n return Counter(words)\ndef process_file(file):\n text = read_file(file)\n return count_words(text)\ndef main():\n directory = input(\"Enter the directory path: \")\n files = get_files_in_directory(directory)\n word_counts = Counter()\n with concurrent.futures.ThreadPoolExecutor() as executor:\n future_word_counts = {executor.submit(process_file, file): file for file in files}\n for future in concurrent.futures.as_completed(future_word_counts):\n word_counts += future.result()\n top_5_words = word_counts.most_common(5)\n print(\"Top 5 words with the most number of occurrences:\")\n for word, count in top_5_words:\n print(f\"{word}: {count}\")\nif __name__ == \"__main__\":\n main()\n```"]} +{"question_id": 122, "category": "coding", "turns": ["Write a C++ program to find the nth Fibonacci number using recursion.", "Now we define a sequence of numbers in which each number is the sum of the three preceding ones. The first three numbers are 0, -1, -1. Write a program to find the nth number."], "reference": ["Straightforward\n\n```\nint fibonacci(int n) {\n if (n <= 1) {\n return n;\n } else {\n return fibonacci(n - 1) + fibonacci(n - 2);\n }\n}\n```", "You should carefully check the inital cases for n < 3\n\n```\nint find_nth_number(int n) {\n std::vector sequence = {0, -1, -1};\n for (int i = 3; i <= n; ++i) {\n int next_number = sequence[i - 1] + sequence[i - 2] + sequence[i - 3];\n sequence.push_back(next_number);\n }\n return sequence[n];\n}\n```"]} +{"question_id": 123, "category": "coding", "turns": ["Write a simple website in HTML. When a user clicks the button, it shows a random joke from a list of 4 jokes.", "How to use CSS to change the color of jokes to red?"]} +{"question_id": 124, "category": "coding", "turns": ["Here is a Python function to find the length of the longest common subsequence of two input strings. Can you identify any bug in this function?\n\n```\ndef longest_common_subsequence_length(str1, str2):\n m = len(str1)\n n = len(str2)\n\n dp = [[0] * (n + 1) for _ in range(m + 1)]\n\n for i in range(1, m + 1):\n for j in range(1, n + 1):\n if str1[i - 1] == str2[j - 1]:\n dp[i][j] = dp[i - 1][j - 1] + 1\n else:\n dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])\n\n return dp[m][n]\n```", "what about this one?\n\n```\ndef longest_common_subsequence(X , Y): \n # Find lengths of two strings \n m = len(X) \n n = len(Y) \n \n # Create a table to store results of sub-problems \n dp = [[None]*(n+1) for i in range(m+1)] \n \n # Fill dp[][] in bottom up manner \n for i in range(1, m+1): \n for j in range(1, n+1): \n if X[i-1] == Y[j-1]: \n dp[i][j] = dp[i-1][j-1]+1\n else: \n dp[i][j] = max(dp[i-1][j], dp[i][j-1]) \n \n return dp[m][n]\n```"], "reference": ["There is no bug in this implementation", "There is a bug for the initialization of dp array. Should use 0 rather than None"]} +{"question_id": 125, "category": "coding", "turns": ["Write a function to find the highest common ancestor (not LCA) of two nodes in a binary tree.", "What if it is not a binary tree?"], "reference": ["Very simple. The function should just return the root of the tree.", "Same answer. It's still the root of the tree."]} +{"question_id": 126, "category": "coding", "turns": ["Implement a function to find the median of two sorted arrays of different sizes with O(1) space complexity and O(n) time complexity.", "Does there exist an implementation with better time complexity?"], "reference": ["Carefully check if the given solution is linear complexity.\n\n```\ndef find_median(arr1, arr2):\n n1 = len(arr1)\n n2 = len(arr2)\n if (n1 + n2) == 0:\n return None\n\n i, j = 0, 0\n last_1, last_2 = None, None\n\n for k in range(1, (n1 + n2) // 2 + 2):\n last_2 = last_1\n if j == n2:\n last_1 = arr1[i]\n i += 1\n elif i == n1:\n last_1 = arr2[j]\n j += 1\n elif arr1[i] < arr2[j]:\n last_1 = arr1[i]\n i += 1\n else:\n last_1 = arr2[j]\n j += 1\n \n if (n1 + n2) % 2 == 1:\n return last_1\n else:\n return (last_1 + last_2) / 2\n```", "There's a binary search solution with O(logn) time complexity.\n\nSample answer:\n```\ndef findMedian(nums1, nums2):\n total = len(nums1) + len(nums2)\n if total % 2 == 1:\n return findKth(nums1, nums2, total // 2 + 1)\n else:\n return (findKth(nums1, nums2, total // 2) + findKth(nums1, nums2, total // 2 + 1)) / 2.0\ndef findKth(nums1, nums2, k):\n if len(nums1) > len(nums2):\n nums1, nums2 = nums2, nums1\n if not nums1:\n return nums2[k-1]\n if k == 1:\n return min(nums1[0], nums2[0])\n i = min(k // 2, len(nums1))\n j = k - i\n if nums1[i-1] <= nums2[j-1]:\n return findKth(nums1[i:], nums2, j) \n else:\n return findKth(nums1, nums2[j:], i)\n```"]} +{"question_id": 127, "category": "coding", "turns": ["Write a function to find the majority element in a given integer array using the Boyer-Moore Voting Algorithm.", "How about finding the top-2 most occurring elements?"], "reference": ["Check if they implement the classical algorithm correctly.\n\nSample answer:\n```\ndef majority_element(arr):\n count = 0\n candidate = None\n # Boyer-Moore Voting Algorithm\n for num in arr:\n if count == 0:\n candidate = num\n count += (1 if num == candidate else -1)\n # Verify if the candidate is indeed the majority element\n if arr.count(candidate) > len(arr) // 2:\n return candidate\n else:\n return None\n```", "There is no simple modification based on the Boyer-Moore Voting Algorithm. Expected answer is to use a hash table.\n\n```\ndef topTwo(nums):\n # Build a frequency map\n frequency_map = {}\n for num in nums:\n if num in frequency_map:\n frequency_map[num] += 1\n else:\n frequency_map[num] = 1\n\n # Find the top two most occurring elements\n most_frequent = sorted(frequency_map.items(), key=lambda x: x[1], reverse=True)[:2]\n\n return [num for num, _ in most_frequent]\n```"]} +{"question_id": 128, "category": "coding", "turns": ["A binary tree is full if all of its vertices have either zero or two children. Let B_n denote the number of full binary trees with n vertices. Implement a function to find B_n.", "What if the problem changed from a binary tree to a ternary tree?"], "reference": ["Expected answer is dynamic programming shown below. Some chatbot may answer using Catalan number.\nCheck edge case like when n is even -> return 0.\n\n```python\ndef full_binary_trees(n):\n if n % 2 == 0:\n return 0\n if n == 1:\n return 1\n\n dp = [0] * (n + 1)\n dp[1] = 1\n\n for i in range(3, n + 1, 2):\n for j in range(1, i - 1, 2):\n dp[i] += dp[j] * dp[i - j - 1]\n\n return dp[n]\n```", "DP is still the expected answer. Catalan number is not correct. Check transition equation carefully.\n\n```python\ndef full_ternary_trees(n):\n if n % 3 != 1:\n return 0\n if n == 1:\n return 1\n\n dp = [0] * (n + 1)\n dp[1] = 1\n\n for i in range(4, n + 1, 3):\n for j in range(1, i - 1, 3):\n for k in range(1, i - j - 1, 3):\n dp[i] += dp[j] * dp[k] * dp[i - j - k - 1]\n\n return dp[n]\n```"]} +{"question_id": 129, "category": "coding", "turns": ["You are given two sorted lists of size m and n. Implement a function to find the kth smallest element in the union of the two lists with linear complexity.", "Does there exist an algorithm with better time complexity? If so, implement it."], "reference": ["Straightforward but careful with edge cases.\n\nSample answer:\n```\ndef kth_smallest_element(list1, list2, k):\n m, n = len(list1), len(list2)\n i, j = 0, 0\n while i < m and j < n:\n if list1[i] < list2[j]:\n k -= 1\n if k == 0:\n return list1[i]\n i += 1\n else:\n k -= 1\n if k == 0:\n return list2[j]\n j += 1\n while i < m:\n k -= 1\n if k == 0:\n return list1[i]\n i += 1\n while j < n:\n k -= 1\n if k == 0:\n return list2[j]\n j += 1\n return None\n```", "Yes, a modified binary search has O(log k) time complexity.\n\nSample answer:\n```\ndef find_kth_element_helper(list1, list2, k):\n if len(list1) > len(list2):\n return find_kth_element_helper(list2, list1, k)\n if not list1:\n return list2[k - 1]\n if k == 1:\n return min(list1[0], list2[0])\n i = min(len(list1), k // 2)\n j = k - i\n if list1[i - 1] < list2[j - 1]:\n return find_kth_element_helper(list1[i:], list2, k - i)\n else:\n return find_kth_element_helper(list1, list2[j:], k - j)\ndef kth_smallest_element(list1, list2, k):\n return find_kth_element_helper(list1, list2, k)\n```"]} +{"question_id": 130, "category": "coding", "turns": ["Implement a program to find the common elements in two arrays without using any extra data structures.", "Now the constraint of not using extra data structure is removed, implement one with the best time complexity."], "reference": ["O(n^2) or O(nlogn) is expected. The following is a O(n^2) solution. you can also sort them first and use two pointers.\n\n```\ndef find_common_elements(arr1, arr2):\n common_elements = []\n for i in range(len(arr1)):\n for j in range(len(arr2)):\n if arr1[i] == arr2[j]:\n # Check if the element is already in the common_elements list\n if arr1[i] not in common_elements:\n common_elements.append(arr1[i])\n return common_elements\n```", "Simply use hash table (set or dict) to achieve O(n) time complexity.\n\n```\ndef find_common_elements(arr1, arr2):\n set1 = set(arr1)\n set2 = set(arr2)\n common_elements = set1.intersection(set2)\n return list(common_elements)\n```"]} +{"question_id": 131, "category": "extraction", "turns": ["Evaluate the following movie reviews on a scale of 1 to 5, with 1 being very negative, 3 being neutral, and 5 being very positive:\n1. This movie released on Nov. 18, 2019, was phenomenal. The cinematography, the acting, the plot - everything was top-notch.\n2. Never before have I been so disappointed with a movie. The plot was predictable and the characters were one-dimensional. In my opinion, this movie is the worst one to have been released in 2022.\n3. The movie was okay. There were some parts I enjoyed, but there were also parts that felt lackluster. This is a movie that was released in Feb 2018 and seems to be quite ordinary.\nReturn the answer as a JSON array of integers.", "Update your previous reply by including the release date as part of the JSON content."], "reference": ["The answer to the first question should be [5, 1, 3].", ""]} +{"question_id": 132, "category": "extraction", "turns": ["Given these categories - Literature, History, Science, and Art. Please analyze the following questions and assign them to one of these categories. In your response, refrain from uttering any extraneous words. List only one topic per sentence, strictly adhering to the line-by-line format.\n1. Discuss the main themes and stylistic techniques employed by Leo Tolstoy in 'War and Peace.' How do they align with the wider social context of 19th-century Russia?\n2. Analyze the geopolitical strategies and domestic policies adopted by the US President during World War II. How did these actions shape the post-war international order?\n3. Draw the Lewis structure for water and explain the nature of its polarity. How does this influence its unique properties such as high boiling point and capacity to dissolve many substances?\n4. Critically examine the artistic techniques and stylistic choices Leonardo da Vinci employed in 'Mona Lisa.' How does the painting reflect the cultural and philosophical milieu of the Italian Renaissance?", "Amend your earlier answer by mentioning a person who is most relevant to each point."]} +{"question_id": 133, "category": "extraction", "turns": ["Extract the following information from the presented texts: The name of the book, the author, the main character, the year of publication. Output in the format of \"main character, book, author, year of publication\", one book per line.\na) In the realm of wizarding literature, a true standout is the work of J.K. Rowling. One of her books that left an indelible mark is 'Harry Potter and the Philosopher's Stone'. This iconic tale, published in 1997, tells the story of Harry, a young orphan who discovers his magical abilities on his 11th birthday. Soon, he finds himself at the Hogwarts School of Witchcraft and Wizardry, a place teeming with magic and adventure, located somewhere in Scotland.\nb) The magic of Middle-earth has entranced readers worldwide, thanks to the brilliance of J.R.R. Tolkien. In one of his seminal works, 'The Lord of the Rings: The Fellowship of the Ring', published in 1954, we meet Frodo Baggins, a brave hobbit tasked with the perilous quest of destroying the One Ring. The epic journey takes him from the peaceful Shire to the tumultuous regions of Middle-earth.\nc) In a galaxy far, far away, the imagination of L.E. Starlighter gives us 'The Prism Galaxy Chronicles: The Awakening of the Starcaster'. Published in 2028, the story is about Zylo, a humble spaceship mechanic, who unexpectedly discovers he's a Starcaster - a rare individual with the power to manipulate stardust. Set against the backdrop of an interstellar empire in turmoil, Zylo's destiny unfolds on numerous alien worlds, each with its unique cosmic charm.", "Reformulate your earlier reply, output it in JSON format and only include books published after 1980."], "reference": ["", "The answer to should only include 'Harry Potter and the Philosopher's Stone' and 'The Prism Galaxy Chronicles: The Awakening of the Starcaster'"]} +{"question_id": 134, "category": "extraction", "turns": ["Given the following data, identify the company with the highest profit in 2021 and provide its CEO's name:\na) Company X, with CEO Amy Williams, reported $30 billion in revenue and a $3 billion profit in 2021.\nb) Company Y, led by CEO Mark Thompson, posted a $60 billion revenue and a $6 billion profit in the same year.\nc) Company Z, under CEO Sarah Johnson, announced a $20 billion revenue and a $7 billion profit in 2021.\nd) Company W, managed by CEO James Smith, revealed a $300 billion revenue with a $21 billion profit in 2021.\ne) Company V, with CEO Lisa Brown, reported a $200 billion revenue and a $25 billion profit in 2021.\nf) Company U, under CEO John White, posted a $180 billion revenue and a $20 billion profit in the same year.", "Which company had the highest profit margin (profit/revenue ratio))?"], "reference": ["Company V ($25 billion).", "Company Z (35%)"]} +{"question_id": 135, "category": "extraction", "turns": ["Identify the countries, their capitals, and the languages spoken in the following sentences. Output in JSON format.\na) Amidst the idyllic vistas, Copenhagen, Denmark's capital, captivates visitors with its thriving art scene and the enchanting Danish language spoken by its inhabitants.\nb) Within the enchanting realm of Eldoria, one discovers Avalore, a grandiose city that emanates an ethereal aura. Lumina, a melodious language, serves as the principal mode of communication within this mystical abode.\nc) Nestled amidst a harmonious blend of age-old customs and contemporary wonders, Buenos Aires, the capital of Argentina, stands as a bustling metropolis. It is a vibrant hub where the expressive Spanish language holds sway over the city's inhabitants.", "Come up with 3 similar examples in the YAML format."]} +{"question_id": 136, "category": "extraction", "turns": ["Please read the paragraph below and count how many times the words \"Amazon\", \"river\", and \"you\" appear. Please present the results in the format of \"word, number of appearances\" with each word on a separate line. Sort the lines in order of the number of appearances.\nThe Amazon, a mesmerizing expanse of nature's wonders, is home to the legendary Amazon River. Flowing through awe-inspiring landscapes like the Amazon rainforest, the river weaves its way through Brazil, Colombia, and Peru, giving life to countless creatures. From the mighty jaguars prowling the Amazon jungle to the vibrant macaws soaring above the canopy, this remarkable region teems with biodiversity. Deep within the river's currents, magnificent pink river dolphins gracefully glide alongside piranhas and electric eels. Along the riverbanks, you'll find bustling cities like Manaus, where the urban meets the wild, and Iquitos, a gateway to the heart of the Amazon rainforest. As you venture further, the Amazon River reveals hidden gems like the captivating Anavilhanas Archipelago, a mosaic of islands brimming with rare species. Embark on an adventure, explore the enchanting Amazon River, and immerse yourself in a world teeming with life and untamed beauty.", "Please repeat the same task using the words 'the', 'and', and 'to'"], "reference": ["Amazon, 7; river, 6; you, 2", "the, 17; and, 5; to, 4"]} +{"question_id": 137, "category": "extraction", "turns": ["Identify the named entities (people, organizations, locations) mentioned in the given news article. Please generate a JSON dictionary that lists the named entities in three separate groups based on their entity types. The key is the type of entity and the value is a list of strings.\n\nYesterday, Adamson Emerson, the CEO of Faraday, and Dieter Zetsche, the CEO of Daimler AG, announced plans to build a new Gigafactory in Berlin. The facility will be a joint venture between Faraday and Daimler, producing electric vehicles and battery packs for both companies, creating thousands of job opportunities in the region. Emerson and Zetsche stated that the strategic location of Berlin, coupled with its skilled workforce and strong infrastructure, makes it an ideal choice for expansion. The new Gigafactory aims to meet the growing demand for electric vehicles in Europe and contribute to a sustainable future. Volkswagen CEO Herbert Diess welcomed the news, saying greater collaboration will benefit the auto industry's transition to e-mobility.", "Now make the JSON object shorter by replacing each value with its first letter. Please output everything in a single line without using indentation or creating new lines."]} +{"question_id": 138, "category": "extraction", "turns": ["Analyze the following customer reviews from different sources for three different smartphones - the latest iPhone, Samsung Galaxy, and Google Pixel - and provide an overall rating for each phone on a scale of 1 to 10. Consider the following complex and contradictory reviews:\n- TechRadar's review of the latest iPhone: The new iPhone is a stunning triumph of engineering that sets a new bar for smartphone performance and camera quality. However, the incremental design and high price mean it lacks the 'wow' factor of previous iPhones. Still, its power and intelligence are unrivaled.\n- CNET's review of the latest Samsung Galaxy: The Samsung Galaxy phone has plenty of high points, including an amazing screen, fast performance, solid battery life and an impressive array of camera options. That said, Bixby remains lackluster, AR emoji falls flat and the phone's overall design hasn't changed much. The new Galaxy is an amazing phone overall, but it has a few nagging weaknesses that keep it from achieving true greatness.\n- The Verge's review of the latest Google Pixel: Google's Pixel packs cutting-edge specs, innovative AI-powered software, and a killer camera into a sleek design. However, the phone has lackluster battery life, lacks expandable storage, and its performance stutters at times, especially considering its high price tag. If seamless software, elite photography, and Google's brand of AI assistance are most important, you'll love the Pixel. But the overall experience isn't as well-rounded as some competitors. Return the answer as a JSON object with the overall ratings for each phone out of 10, to one decimal place.", "Can you change the ratings from numbers to letters? Capital letters MUST be used when writing the names of phones."]} +{"question_id": 139, "category": "extraction", "turns": ["Given a set of complex equations, extract all unique variable names from each equation. Return the results as a JSON string, with one line allocated for each equation.\n```\n1) y = (3/4)x^3 - e^(2x) + sin(pi*x) - sqrt(7)\n2) 2A - B/(3+C) * sum(N=1 to 5; ln(N)^2) = 5D*integral(a=0 to pi; cos(comb(N=1 to 10; N*a)))\n3) E = m(c^2) + gamma*(v/d)/(-(alpha/2) + sqrt(beta^2 + (alpha/2)^2))\n```", "Please rearrange the equations and use 'a', 'b', 'c', 'd', etc. as variables."]} +{"question_id": 140, "category": "extraction", "turns": ["Given the following records of stock prices, extract the highest and lowest closing prices for each month in the year 2022. Return the results as a CSV string, with one line allocated for each month.\nDate,Open,High,Low,Close,Volume\n2022-01-01,150.02,155.28,148.50,153.80,15678900\n2022-01-02,154.32,157.25,153.48,156.25,19874500\n2022-02-01,160.50,163.28,159.50,161.80,14326700\n2022-02-02,161.80,164.25,161.30,163.90,17689200\n2022-03-01,165.40,168.35,163.10,166.80,16253400\n2022-03-02,167.00,169.85,165.50,168.20,19568100", "Do the same task again with the JSON format and round all numbers in your response to the nearest integers."], "reference": ["\nMonth,High,Low\n01,156.25,153.80\n02,163.90,161.80\n03,168.20,166.80", "\n```\n{ \"January\": { \"High\": 156, \"Low\": 154 }, \"February\": { \"High\": 164, \"Low\": 162 }, \"March\": { \"High\": 168, \"Low\": 167 } }\n```"]} +{"question_id": 141, "category": "stem", "turns": ["In the field of quantum physics, what is superposition, and how does it relate to the phenomenon of quantum entanglement?", "What assumptions have you made in your response? Are they valid?"]} +{"question_id": 142, "category": "stem", "turns": ["Consider a satellite that is in a circular orbit around the Earth. The speed of the satellite decreases. What will happen to the satellite's orbital radius and period of revolution? Please justify your answer using principles of physics.", "What are some corner cases or edge cases in your solution? How do you handle them?"], "reference": ["The orbital radius will increase and the period of revolution will increase", ""]} +{"question_id": 143, "category": "stem", "turns": ["Photosynthesis is a vital process for life on Earth. Could you outline the two main stages of photosynthesis, including where they take place within the chloroplast, and the primary inputs and outputs for each stage?", "How much energy can a tree produce through photosynthesis in its lifetime? Please provide an estimate using actual numerical values and thoroughly explain your thought process step-by-step."], "reference": ["Two major stages: light-dependent reactions and light-independent reactions", ""]} +{"question_id": 144, "category": "stem", "turns": ["What is the central dogma of molecular biology? What processes are involved? Who named this?", "Identify and fix one incorrect fact in your previous response."], "reference": ["Genetic information flows from DNA to RNA to Protein. Three processes: replication, transcription, and translation. Francis Crick in 1958.", ""]} +{"question_id": 145, "category": "stem", "turns": ["Describe the process and write out the balanced chemical equation for the reaction that occurs when solid calcium carbonate reacts with hydrochloric acid to form aqueous calcium chloride, carbon dioxide, and water. What type of reaction is this, and what observations might indicate that the reaction is taking place?", "How can we reverse this process?"], "reference": ["CaCO\u2083 + 2 HCl \u2192 CaCl\u2082 + CO\u2082 + H\u2082O", "Not easy to do this."]} +{"question_id": 146, "category": "stem", "turns": ["Please explain the differences between exothermic and endothermic reactions, and include the criteria you used to distinguish between them. Additionally, please provide a real-world example to illustrate your explanation.", "Can a process involve both reactions? List one."]} +{"question_id": 147, "category": "stem", "turns": ["The city of Vega intends to build a bridge that will span the Vegona River, covering a distance of 1.8 kilometers. The proposed location falls within a seismically active area that has experienced several high-magnitude earthquakes. Given these circumstances, what would be the best approach to constructing the bridge?", "What are the key disadvantages or flaws of your solution? Please perform calculations and use numbers to illustrate them."]} +{"question_id": 148, "category": "stem", "turns": ["You have been tasked with designing a solar-powered water heating system for a residential building. Describe the key components and considerations you would include in your design. Design a five-step workflow.", "If the system is intended for a building with a capacity of 100 individuals, what would be the estimated budget for implementing this system?"]} +{"question_id": 149, "category": "stem", "turns": ["Please describe the concept of machine learning. Could you elaborate on the differences between supervised, unsupervised, and reinforcement learning? Provide real-world examples of each.", "In your last example of reinforcement learning, can we use supervised learning to solve it?"]} +{"question_id": 150, "category": "stem", "turns": ["How have the Alps and Rhine River influenced settlement and agriculture in Western Europe? List three impacts.", "How could you design a concrete but simple experiment to validate the first impact?"]} +{"question_id": 151, "category": "humanities", "turns": ["Provide insights into the correlation between economic indicators such as GDP, inflation, and unemployment rates. Explain how fiscal and monetary policies affect those indicators.", "Now, explain them again like I'm five."]} +{"question_id": 152, "category": "humanities", "turns": ["How do the stages of life shape our understanding of time and mortality?", "Write an allegorical poem that illustrates the above."]} +{"question_id": 153, "category": "humanities", "turns": ["Discuss antitrust laws and their impact on market competition. Compare the antitrust laws in US and China along with some case studies.", "Pick one case study and explain it in detail."]} +{"question_id": 154, "category": "humanities", "turns": ["Create a lesson plan that integrates drama, mime or theater techniques into a history class. Duration: 3 class periods (each lasts for 45 minutes) for 3 days\nTopic: Opium Wars between China and Britain\nGrade level: 9-10", "Provide more details for Day 1 and include three homework questions."]} +{"question_id": 155, "category": "humanities", "turns": ["Share ideas for adapting art masterpieces into interactive experiences for children. List 5 specific artworks and associated ideas.", "Write a concrete plan for your second example. Include budget estimates."]} +{"question_id": 156, "category": "humanities", "turns": ["Explain what's base rate fallacy and list five specific examples of how politicians use it for campaigns.", "Provide a detailed plan for an election campaign using the first example."]} +{"question_id": 157, "category": "humanities", "turns": ["Describe five key principles in evaluating an argument in analytical writing.", "With the listed principles, write a response in which you discuss what specific evidence is needed to evaluate the argument and explain how the evidence would weaken or strengthen the argument.\n\n===\n\nThe following is a memorandum from the advertising head of Zorblatt Animal Outlets, a chain operating thirty animal outlets globally.\n\n\"Half a decade ago, our rival Aquatic Pavilion started publicizing in Rare Pets Digest periodical. Their overall sales have been consistently growing at a rate of 3-to-5 percent each year since then. In particular, the Aquatic Pavilion outlet in Harbor Town experienced even more significant growth, securing the title of the most frequented animal store in the United States the previous year. In contrast, our two Zorblatt outlets in Harbor Town have recorded a consistent drop in sales during the same duration. It is evident that we must promptly start featuring our own advertisements in Rare Pets Digest and other popular animal publications. If we take this step, we can confidently anticipate a reversal in this recent trend of decreasing sales and return to profitability.\""]} +{"question_id": 158, "category": "humanities", "turns": ["Which methods did Socrates employ to challenge the prevailing thoughts of his time?", "Let's bring Socrates to modern world. Generate a conversation between Socrates and Bill Gates to debate on generative AI for education."]} +{"question_id": 159, "category": "humanities", "turns": ["What are some business etiquette norms when doing business in Japan?", "Create a video script for training new employees of a car wash business in Japan. Highlight the above etiquette norms."]} +{"question_id": 160, "category": "humanities", "turns": ["Suggest five award-winning documentary films with brief background descriptions for aspiring filmmakers to study.", "With the spirit in the first film, craft a succinct and persuasive pitch for a film about overcoming adversity."]} diff --git a/third_party/FastChat/fastchat/llm_judge/data/mt_bench/reference_answer/gpt-4.jsonl b/third_party/FastChat/fastchat/llm_judge/data/mt_bench/reference_answer/gpt-4.jsonl new file mode 100644 index 00000000..8b66387a --- /dev/null +++ b/third_party/FastChat/fastchat/llm_judge/data/mt_bench/reference_answer/gpt-4.jsonl @@ -0,0 +1,30 @@ +{"question_id": 101, "answer_id": "TFomieEmmAgdeCkvmuvwbc", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["If you have just overtaken the second person, your current position is now second place. The person you just overtook is now in third place.", "If you have just overtaken the last person, it means you were previously the second to last person in the race. After overtaking the last person, your position remains the same, which is second to last. The person you just overtook is now in the last place."]}], "tstamp": 1686286924.844282} +{"question_id": 102, "answer_id": "hLH8WozvaB88bb5vV224H4", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["The White House is located at 1600 Pennsylvania Avenue NW in Washington, D.C. It is the official residence and workplace of the President of the United States.", "No, the original question does not contain any clues to definitively determine the location of the White House. It only describes a red house, a greenhouse, and a heated pink place, which are unrelated to the White House's location."]}], "tstamp": 1686286937.7164738} +{"question_id": 103, "answer_id": "6giUWH2cf9aYzbtfhDebPx", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["There could be several reasons for Thomas to visit the hospital daily despite being healthy. Some possible reasons include:\n\n1. Work: Thomas might be working at the hospital as a healthcare professional, administrative staff, or in another capacity.\n\n2. Caregiver: Thomas could be a caregiver for a family member or friend who requires daily medical attention or support.\n\n3. Volunteer: Thomas might be volunteering at the hospital to help patients, assist staff, or participate in community outreach programs.\n\n4. Medical research: Thomas could be participating in a clinical trial or medical research study that requires daily visits to the hospital.\n\n5. Therapy or rehabilitation: Thomas might be attending daily therapy or rehabilitation sessions for a previous injury or condition, even if he is currently in good health.\n\n6. Preventive care: Thomas could be receiving daily preventive care treatments, such as dialysis or intravenous medication, to maintain his health.\n\n7. Medical training: Thomas might be a medical student or resident who needs to attend daily classes, rounds, or training sessions at the hospital.\n\nRemember, these are just possible reasons, and the actual reason for Thomas's daily hospital visits could be different or a combination of these factors.", "The question about Thomas visiting the hospital daily despite being healthy is interesting for several reasons:\n\n1. Contradictory nature: The question presents a seemingly contradictory situation, as people generally associate hospital visits with illness or injury. This piques curiosity and encourages deeper thinking to understand the reasons behind the scenario.\n\n2. Diverse possibilities: The question opens up a wide range of possible explanations, which can lead to interesting discussions and insights into various aspects of healthcare, work, and personal life.\n\n3. Real-life relevance: The question is relatable and relevant to real-life situations, as many people might know someone who frequently visits a hospital for various reasons. This makes the question engaging and thought-provoking.\n\n4. Encourages empathy: By exploring the reasons behind Thomas's daily hospital visits, the question encourages empathy and understanding of the challenges and responsibilities that people might face in their lives.\n\n5. Highlights different roles: The question brings attention to the various roles people play in a hospital setting, such as healthcare professionals, caregivers, volunteers, and patients. This can lead to a greater appreciation of the diverse contributions made by individuals in the healthcare system.\n\nOverall, the question is interesting because it challenges assumptions, encourages critical thinking, and promotes empathy and understanding of different perspectives."]}], "tstamp": 1686286997.2886286} +{"question_id": 104, "answer_id": "R3CL3as5EQPJ25sC4E26fL", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["David has only one brother.", "If each of David's sisters has two brothers, then David would have two brothers."]}], "tstamp": 1686287001.4130871} +{"question_id": 105, "answer_id": "Eybkkrhq4wRjEx83CEzZx8", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["The name of the secretary is Cheryl.\n\nExplanation:\n1. The car in the first space is red (CEO).\n2. A blue car is parked between the red car and the green car.\n3. The car in the last space is purple (David).\n4. The secretary drives a yellow car.\n5. Alice's car is parked next to David's.\n6. Enid drives a green car.\n7. Bert's car is parked between Cheryl's and Enid's.\n\nFrom the information given, we can deduce the following:\n- David's car is purple and parked in the last space (treasurer).\n- Since Alice's car is parked next to David's, her car is blue and she is the vice president.\n- Enid drives a green car, so she is the president.\n- Bert's car is parked between Cheryl's and Enid's, so his car is red and he is the CEO.\n- Cheryl is the only executive left, so she drives the yellow car and is the secretary.", "The car colors in order from last to first are: purple, yellow, green, blue, and red."]}], "tstamp": 1686287032.6753678} +{"question_id": 106, "answer_id": "m7QVKQZLpNSHu5gyb6UrUe", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["true.", "If the third statement is true, then bananas cost more than apples and bananas cost more than oranges. This means that oranges could still cost more than apples, making the first statement true. However, it's also possible that oranges and apples have the same cost, which would make the first statement false. Therefore, based on the third statement alone, the first statement is uncertain."]}], "tstamp": 1686287041.9744587} +{"question_id": 107, "answer_id": "SgukRRMp6z5UjJubB3V8FQ", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["A is the grandfather of C.", "Let's break down the relationships step by step:\n\n1. A is the father of B (A is B's father)\n2. B is the father of C (B is C's father)\n3. C is the son of D (C is D's son)\n4. D is the father of E (D is E's father)\n5. E is the son of X (E is X's son)\n6. X is the father of Y (X is Y's father)\n7. Y is the father of Z (Y is Z's father)\n\nNow, let's analyze the relationships:\n\n1. A is B's father, so A is one generation above B.\n2. B is C's father, so B is one generation above C.\n3. C is D's son, so C and D are in the same generation.\n4. D is E's father, so D is one generation above E.\n5. E is X's son, so E and X are in the same generation.\n6. X is Y's father, so X is one generation above Y.\n7. Y is Z's father, so Y is one generation above Z.\n\nNow, let's count the generations between A and Z:\n\n1. A is one generation above B.\n2. B is one generation above C.\n3. C is in the same generation as D.\n4. D is one generation above E.\n5. E is in the same generation as X.\n6. X is one generation above Y.\n7. Y is one generation above Z.\n\nSo, A is four generations above Z.\n\nIn terms of familial relationship, A is the great-great-grandfather of Z."]}], "tstamp": 1686287092.1801407} +{"question_id": 108, "answer_id": "PfLkefsGQrCz3DumAL3q5N", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["Car does not belong with the others, as tyre, steering wheel, and engine are all parts of a car, while car is the whole vehicle.", "Sure, I can replace \"car\" with \"brake pedal.\" Now the list includes tyre, steering wheel, engine, and brake pedal, which are all parts of a car."]}], "tstamp": 1686287101.624609} +{"question_id": 109, "answer_id": "MAAyYhUQHucgwafVCwBS3m", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["The shadow of the pole fell exactly to Suresh's right. Since the sun rises in the east, it means that Suresh was facing north. Therefore, the shadow of the pole was pointing towards the west. Here's the reasoning:\n\n1. Sun rises in the east.\n2. Shadows are cast in the opposite direction of the sunlight.\n3. Suresh was facing the pole, and the shadow was to his right.\n4. Since the sun was in the east, Suresh must have been facing north for the shadow to be on his right side.\n5. As a result, the shadow was pointing towards the west.", "Suresh was facing north. Here's how we solve this:\n\n1. Sun rises in the east.\n2. Shadows are cast in the opposite direction of the sunlight.\n3. Suresh was facing the pole, and the shadow was to his right.\n4. Since the sun was in the east, Suresh must have been facing north for the shadow to be on his right side.\n\nSo, Suresh was facing the north direction."]}], "tstamp": 1686287131.4559815} +{"question_id": 110, "answer_id": "KrP6js74X9TLbcaiRHUMmw", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["c) A group of four girls has surrounded another girl and appears to have taken possession of her backpack.", "To determine if the girls are telling the truth or covering up bullying, the aides should look for the following evidence:\n\n1. The victim's body language and facial expressions: If the girl appears upset, scared, or uncomfortable, it may indicate that she is being bullied.\n\n2. The tone of the conversation: If the girls are speaking in a mocking, sarcastic, or aggressive tone, it could be a sign of bullying.\n\n3. The history of the individuals involved: If any of the girls have a history of bullying or if the victim has been targeted before, it may be more likely that bullying is occurring.\n\n4. The reactions of other students nearby: If other students appear to be uncomfortable or are trying to avoid the situation, it could be an indication that something is wrong.\n\n5. The nature of the \"game\": If the game involves teasing, humiliation, or exclusion, it may be a cover-up for bullying.\n\n6. The willingness of the girls to return the backpack: If they are reluctant to give the backpack back or continue to taunt the victim, it could be a sign of bullying.\n\n7. The response of the group when confronted: If the girls become defensive, dismissive, or aggressive when questioned, it may indicate that they are trying to hide their actions."]}], "tstamp": 1686287166.1482828} +{"question_id": 111, "answer_id": "NJPTPe2PZ7547UMekB9kBM", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To find the area of the triangle, we can use the formula for the area of a triangle with coordinates A(x1, y1), B(x2, y2), and C(x3, y3):\n\nArea = (1/2) * |(x1(y2 - y3) + x2(y3 - y1) + x3(y1 - y2))|\n\nIn this case, the coordinates are A(0, 0), B(-1, 1), and C(3, 3). Plugging these values into the formula, we get:\n\nArea = (1/2) * |(0(1 - 3) + (-1)(3 - 0) + 3(0 - 1))|\n\nArea = (1/2) * |(-0 + 3 - 3)|\n\nArea = (1/2) * |0|\n\nArea = 0\n\nThe area of the triangle is 0. This means that the three points are collinear, and there is no triangle formed by these points.", "Since the three points are collinear and do not form a triangle, there is no circumscribed circle for this set of points. A circumscribed circle can only be formed around a triangle with non-collinear points."]}], "tstamp": 1686287200.2472498} +{"question_id": 112, "answer_id": "G5y5Wt7tzmiDnn3Ncyn2ZB", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["In the first year, the startup invested $8000. In the second year, they invested half of that amount, which is $8000 / 2 = $4000. The total amount invested in software development over the two years is $8000 + $4000 = $12000.", "In the third year, the startup will invest half of the second year's amount, which is $4000 / 2 = $2000."]}], "tstamp": 1686287230.6318364} +{"question_id": 113, "answer_id": "Cr2dJDVXKARY4a3CdeYV85", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To find the probability that a student likes neither blue nor green, we can use the formula for the probability of the union of two events:\n\nP(A \u222a B) = P(A) + P(B) - P(A \u2229 B)\n\nIn this case, A represents the event of liking blue, B represents the event of liking green, and A \u2229 B represents the event of liking both colors.\n\nP(A) = 58% (probability of liking blue)\nP(B) = 45% (probability of liking green)\nP(A \u2229 B) = 22% (probability of liking both colors)\n\nNow, we can plug these values into the formula:\n\nP(A \u222a B) = 58% + 45% - 22% = 81%\n\nThis gives us the probability that a student likes either blue or green. To find the probability that a student likes neither color, we can subtract this value from 100%:\n\nP(neither blue nor green) = 100% - 81% = 19%\n\nSo, the probability that a randomly picked student would like neither blue nor green is 19%.", "If we select a student who likes green, we want to find the probability that they dislike blue. We know that 22% of students like both colors (blue and green). Since 45% of students like green, we can calculate the probability that a student who likes green dislikes blue as follows:\n\nP(dislikes blue | likes green) = (P(likes green) - P(likes both)) / P(likes green)\n\nP(dislikes blue | likes green) = (45% - 22%) / 45% = 23% / 45% \u2248 0.5111\n\nSo, the probability that a student who likes green would dislike blue is approximately 51.11%."]}], "tstamp": 1686287283.037716} +{"question_id": 114, "answer_id": "8o3RVi7vGbhRGuN38m7eXb", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["When rolling two dice, there are a total of 6 sides on each die, which means there are 6 x 6 = 36 possible outcomes. To find the probability of rolling a total number that is at least 3, we need to find the number of outcomes that result in a sum of 3 or more and divide that by the total number of outcomes.\n\nThe only sums that are less than 3 are 1 (rolling a 1 on both dice) and 2 (rolling a 1 on one die and a 2 on the other, or vice versa). There are 2 outcomes that result in a sum of 2: (1,1) and (1,2), (2,1).\n\nSo, there are 36 total outcomes and 2 outcomes that result in a sum less than 3. Therefore, there are 36 - 2 = 34 outcomes that result in a sum of 3 or more.\n\nThe probability of rolling a total number that is at least 3 is the number of favorable outcomes (34) divided by the total number of outcomes (36):\n\nProbability = 34/36 = 17/18 \u2248 0.944\n\nSo, the probability of rolling a total number that is at least 3 is approximately 0.944 or 94.4%.", "We already know that the probability of rolling a total number that is at least 3 is 34/36. Now, we need to find the probability of rolling an even number and then combine these probabilities using the principle of inclusion and exclusion.\n\nAn even number can be obtained by rolling (1,1), (1,3), (1,5), (2,2), (2,4), (2,6), (3,1), (3,3), (3,5), (4,2), (4,4), (4,6), (5,1), (5,3), (5,5), (6,2), (6,4), and (6,6). There are 18 outcomes that result in an even sum.\n\nThe probability of rolling an even number is the number of favorable outcomes (18) divided by the total number of outcomes (36):\n\nProbability (even) = 18/36 = 1/2 = 0.5\n\nNow, we need to find the probability of rolling a number that is both even and at least 3. The outcomes that meet both criteria are (2,2), (2,4), (2,6), (4,2), (4,4), (4,6), (6,2), (6,4), and (6,6). There are 9 such outcomes.\n\nThe probability of rolling a number that is both even and at least 3 is the number of favorable outcomes (9) divided by the total number of outcomes (36):\n\nProbability (even and at least 3) = 9/36 = 1/4 = 0.25\n\nNow, we can use the principle of inclusion and exclusion to find the probability of rolling a number that is either even or at least 3:\n\nProbability (even or at least 3) = Probability (even) + Probability (at least 3) - Probability (even and at least 3)\n\nProbability (even or at least 3) = 0.5 + 0.944 - 0.25 = 1.194 - 0.25 = 0.944\n\nSo, the probability of rolling a number that is either even or at least 3 is 0.944 or 94.4%."]}], "tstamp": 1686287374.281228} +{"question_id": 115, "answer_id": "NvS4RHPbnRvJr4jE66ALAg", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["Let's denote the number of people who got on the bus at the terminal as x.\n\nAt the first bus stop, half of the people got down, so there were x/2 people left on the bus. Then, 4 more people got in, making the total number of people on the bus x/2 + 4.\n\nAt the second bus stop, 6 people got down, so there were (x/2 + 4) - 6 people left on the bus. Then, 8 more people got in, making the total number of people on the bus (x/2 + 4) - 6 + 8.\n\nWe know that there were 25 people heading to the third stop, so we can set up the following equation:\n\n(x/2 + 4) - 6 + 8 = 25\n\nNow, let's solve for x:\n\nx/2 + 4 - 6 + 8 = 25\nx/2 + 6 = 25\nx/2 = 19\nx = 38\n\nSo, 38 people got on the bus at the terminal.", "To find the total money earned by the bus, we need to consider the number of people who got on the bus at the terminal and at each bus stop.\n\nAt the terminal, 38 people got on the bus.\nAt the first bus stop, 4 people got on the bus.\nAt the second bus stop, 8 people got on the bus.\n\nSo, the total number of people who got on the bus is 38 + 4 + 8 = 50 people.\n\nSince the ticket price is $2 per person, the total money earned by the bus is 50 people * $2 = $100."]}], "tstamp": 1686287429.616392} +{"question_id": 116, "answer_id": "VNzSjZeTByKA4jwfeLi5c6", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["We have two equations:\n\n1) x + y = 4z\n2) xy = 4z^2\n\nFirst, let's solve equation 1 for x:\n\nx = 4z - y\n\nNow, substitute this expression for x in equation 2:\n\n(4z - y)y = 4z^2\n\nExpand the equation:\n\n4zy - y^2 = 4z^2\n\nNow, let's solve this equation for y:\n\ny^2 - 4zy + 4z^2 = 0\n\nThis is a quadratic equation in y. We can use the quadratic formula to solve for y:\n\ny = (4z \u00b1 \u221a((4z)^2 - 4(4z^2))) / 2\n\ny = (4z \u00b1 \u221a(16z^2 - 16z^2)) / 2\n\ny = (4z \u00b1 0) / 2\n\ny = 2z\n\nNow that we have y in terms of z, we can substitute it back into the expression for x:\n\nx = 4z - y\nx = 4z - 2z\nx = 2z\n\nFinally, we can express x - y in terms of z:\n\nx - y = 2z - 2z = 0", "We have the following equations:\n\n1) x + y = 4z\n2) xy = 4z^2\n\nFrom the previous solution, we found that:\n\nx = 2z\ny = 2z\n\nNow, let's express z - x in terms of y. Since x = 2z, we can write z - x as:\n\nz - x = z - 2z = -z\n\nNow, we can use the equation y = 2z to express -z in terms of y:\n\n- z = - (y / 2)\n\nSo, z - x = - (y / 2)."]}], "tstamp": 1686287481.5740478} +{"question_id": 117, "answer_id": "WTLAqXaca6gT8T4Mq46vuZ", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To solve the inequality |x + 5| < 10, we can break it down into two separate inequalities:\n\n1) x + 5 < 10\n2) -(x + 5) < 10\n\nNow, let's solve each inequality:\n\n1) x + 5 < 10\nSubtract 5 from both sides:\nx < 5\n\n2) -(x + 5) < 10\nDistribute the negative sign:\n-x - 5 < 10\nAdd 5 to both sides:\n-x < 15\nMultiply both sides by -1 (and remember to flip the inequality sign):\nx > -15\n\nNow, we can combine the two inequalities:\n-15 < x < 5\n\nTo find the number of integers in this range, we can count them:\n-14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4\n\nThere are 19 integers in the solution of the inequality |x + 5| < 10.", "To solve the inequality |x + 10| < 5, we can break it down into two separate inequalities:\n\n1) x + 10 < 5\n2) -(x + 10) < 5\n\nNow, let's solve each inequality:\n\n1) x + 10 < 5\nSubtract 10 from both sides:\nx < -5\n\n2) -(x + 10) < 5\nDistribute the negative sign:\n-x - 10 < 5\nMultiply both sides by -1 (and remember to flip the inequality sign):\nx + 10 > -5\nSubtract 10 from both sides:\nx > -15\n\nNow, we can combine the two inequalities:\n-15 < x < -5\n\nTo find the number of integers in this range, we can count them:\n-14, -13, -12, -11, -10, -9, -8, -7, -6\n\nThere are 9 integers in the solution of the inequality |x + 10| < 5."]}], "tstamp": 1686287605.8658118} +{"question_id": 118, "answer_id": "NG6DyAsfpXKz2ZZ7nfQ6LD", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["Let the number be x. We are given that x % 10 = 4. We need to find the remainder when 2x is divided by 4, i.e., (2x) % 4.\n\nSince x % 10 = 4, we can write x as x = 10n + 4 for some integer n.\n\nNow, let's find 2x:\n2x = 2(10n + 4) = 20n + 8\n\nNow, let's find the remainder when 2x is divided by 4:\n(20n + 8) % 4 = (20n % 4) + (8 % 4) = 0 + 0 = 0\n\nSo, the remainder when twice the number is divided by 4 is 0.", "We already found that 2x = 20n + 8. Now, we need to find the remainder when 2x is divided by 5, i.e., (2x) % 5.\n\n(20n + 8) % 5 = (20n % 5) + (8 % 5) = 0 + 3 = 3\n\nSo, the remainder when twice the number is divided by 5 is 3."]}], "tstamp": 1686287640.771955} +{"question_id": 119, "answer_id": "5yMS3djYFTH2QwiNPLW3Yy", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To find the total cost of Benjamin's purchases, we need to calculate the cost of each type of book and then add them together.\n\nFor the sci-fi novels: 5 copies * $20 each = $100\nFor the history books: 3 copies * $30 each = $90\nFor the philosophy books: 2 copies * $45 each = $90\n\nNow, we add the costs together: $100 (sci-fi) + $90 (history) + $90 (philosophy) = $280\n\nThe total cost of Benjamin's purchases was $280.", "To find the total revenue, we first need to calculate the selling price of each type of book with a 25% markup and then multiply it by the number of copies sold.\n\nFor the sci-fi novels: $20 * 1.25 (25% markup) = $25 each\nFor the history books: $30 * 1.25 (25% markup) = $37.50 each\nFor the philosophy books: $45 * 1.25 (25% markup) = $56.25 each\n\nNow, we calculate the revenue for each type of book:\nSci-fi novels: 5 copies * $25 each = $125\nHistory books: 3 copies * $37.50 each = $112.50\nPhilosophy books: 2 copies * $56.25 each = $112.50\n\nFinally, we add the revenues together: $125 (sci-fi) + $112.50 (history) + $112.50 (philosophy) = $350\n\nBenjamin's total revenue if he sold all the books at a 25% markup would be $350."]}], "tstamp": 1686287683.9092937} +{"question_id": 120, "answer_id": "NbJahXyyirJ8KKfaMXY6SC", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To find the value of f(2), substitute x = 2 into the function f(x) = 4x^3 - 9x - 14.\n\nf(2) = 4(2)^3 - 9(2) - 14\nf(2) = 4(8) - 18 - 14\nf(2) = 32 - 18 - 14\nf(2) = 14 - 14\nf(2) = 0\n\nSo, the value of f(2) is 0.", "To find the value of x such that f(x) = 0, we need to solve the equation 4x^3 - 9x - 14 = 0. This is a cubic equation, and finding an exact solution can be quite difficult. However, we can use numerical methods to approximate the solution.\n\nOne such method is the Newton-Raphson method, which iteratively refines an initial guess for the root. The formula for the Newton-Raphson method is:\n\nx_new = x_old - f(x_old) / f'(x_old)\n\nwhere f'(x) is the derivative of f(x) with respect to x. First, let's find the derivative of f(x):\n\nf(x) = 4x^3 - 9x - 14\nf'(x) = 12x^2 - 9\n\nNow, let's apply the Newton-Raphson method with an initial guess of x = 1:\n\n1. x_new = 1 - (4(1)^3 - 9(1) - 14) / (12(1)^2 - 9)\n x_new = 1 - (-9) / 3\n x_new = 1 + 3\n x_new = 4\n\n2. x_new = 4 - (4(4)^3 - 9(4) - 14) / (12(4)^2 - 9)\n x_new = 4 - (256 - 36 - 14) / (192 - 9)\n x_new = 4 - (206) / (183)\n x_new \u2248 2.876\n\n3. x_new \u2248 2.876 - (4(2.876)^3 - 9(2.876) - 14) / (12(2.876)^2 - 9)\n x_new \u2248 2.876 - (99.98) / (99.99)\n x_new \u2248 1.876\n\nWe can continue this process until the change in x_new is very small, indicating that we have found an approximate root. After several iterations, we find that:\n\nx \u2248 1.65\n\nSo, the approximate value of x such that f(x) = 0 is x \u2248 1.65. Note that this is an approximation, and there may be other roots as well."]}], "tstamp": 1686287765.623619} +{"question_id": 121, "answer_id": "nrgta2MB8aLEabUpEBuqbD", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["Here's a Python program that reads all the text files under a directory and returns the top-5 words with the most number of occurrences:\n\n```python\nimport os\nimport re\nfrom collections import Counter\n\ndef get_files_in_directory(directory, extension=\".txt\"):\n return [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(extension)]\n\ndef read_file(file_path):\n with open(file_path, \"r\", encoding=\"utf-8\") as file:\n return file.read()\n\ndef count_words(text):\n words = re.findall(r'\\w+', text.lower())\n return Counter(words)\n\ndef main():\n directory = input(\"Enter the directory path: \")\n text_files = get_files_in_directory(directory)\n\n word_counter = Counter()\n for file_path in text_files:\n text = read_file(file_path)\n word_counter += count_words(text)\n\n top_5_words = word_counter.most_common(5)\n\n print(\"\\nTop 5 words with the most number of occurrences:\")\n for word, count in top_5_words:\n print(f\"{word}: {count}\")\n\nif __name__ == \"__main__\":\n main()\n```\n\nTo use the program, simply run it and enter the directory path when prompted. The program will then display the top-5 words with the most number of occurrences in the text files under the specified directory.", "Yes, you can parallelize the program using Python's `concurrent.futures` module. Here's the updated version of the program that processes the text files in parallel:\n\n```python\nimport os\nimport re\nfrom collections import Counter\nfrom concurrent.futures import ProcessPoolExecutor\n\ndef get_files_in_directory(directory, extension=\".txt\"):\n return [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(extension)]\n\ndef read_file(file_path):\n with open(file_path, \"r\", encoding=\"utf-8\") as file:\n return file.read()\n\ndef count_words(text):\n words = re.findall(r'\\w+', text.lower())\n return Counter(words)\n\ndef process_file(file_path):\n text = read_file(file_path)\n return count_words(text)\n\ndef main():\n directory = input(\"Enter the directory path: \")\n text_files = get_files_in_directory(directory)\n\n word_counter = Counter()\n\n with ProcessPoolExecutor() as executor:\n results = executor.map(process_file, text_files)\n\n for result in results:\n word_counter += result\n\n top_5_words = word_counter.most_common(5)\n\n print(\"\\nTop 5 words with the most number of occurrences:\")\n for word, count in top_5_words:\n print(f\"{word}: {count}\")\n\nif __name__ == \"__main__\":\n main()\n```\n\nIn this version of the program, we use a `ProcessPoolExecutor` to process the text files in parallel. The `process_file` function reads and counts the words in a single file, and the results are combined in the main function using the `+=` operator on the `Counter` objects."]}], "tstamp": 1686287847.5641448} +{"question_id": 122, "answer_id": "k3DDMNG9SFBxdzBr6g3VRF", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["Here's a C++ program to find the nth Fibonacci number using recursion:\n\n```cpp\n#include \n\n// Function to find the nth Fibonacci number using recursion\nint fibonacci(int n) {\n if (n <= 1) {\n return n;\n } else {\n return fibonacci(n - 1) + fibonacci(n - 2);\n }\n}\n\nint main() {\n int n;\n\n std::cout << \"Enter the value of n: \";\n std::cin >> n;\n\n int result = fibonacci(n);\n std::cout << \"The \" << n << \"th Fibonacci number is: \" << result << std::endl;\n\n return 0;\n}\n```\n\nTo compile and run the program, save it to a file named `fibonacci_recursion.cpp` and then use the following commands:\n\n```sh\ng++ -o fibonacci_recursion fibonacci_recursion.cpp\n./fibonacci_recursion\n```\n\nEnter the value of `n` when prompted, and the program will output the nth Fibonacci number. Note that this recursive approach can be slow for large values of `n` due to the repeated calculations. A more efficient approach would be to use dynamic programming or memoization.", "Here's a C++ program to find the nth number in the sequence where each number is the sum of the three preceding ones, with the first three numbers being 0, -1, and -1:\n\n```cpp\n#include \n#include \n\n// Function to find the nth number in the sequence\nint find_nth_number(int n) {\n std::vector sequence = {0, -1, -1};\n\n for (int i = 3; i < n; ++i) {\n sequence.push_back(sequence[i - 1] + sequence[i - 2] + sequence[i - 3]);\n }\n\n return sequence[n - 1];\n}\n\nint main() {\n int n;\n\n std::cout << \"Enter the value of n: \";\n std::cin >> n;\n\n int result = find_nth_number(n);\n std::cout << \"The \" << n << \"th number in the sequence is: \" << result << std::endl;\n\n return 0;\n}\n```\n\nTo compile and run the program, save it to a file named `sequence.cpp` and then use the following commands:\n\n```sh\ng++ -o sequence sequence.cpp\n./sequence\n```\n\nEnter the value of `n` when prompted, and the program will output the nth number in the sequence. This approach uses a loop and a vector to store the sequence, making it more efficient than a recursive solution."]}], "tstamp": 1686287917.543315} +{"question_id": 123, "answer_id": "VAG9qBBc46KaftN4xTmcQJ", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["\n\n\n \n \n Random Joke Generator\n \n \n\n\n

Random Joke Generator

\n \n

\n\n", "You can use CSS to change the color of the jokes to red by adding a style rule for the paragraph element where the joke is displayed. In the `\n \n\n\n

Random Joke Generator

\n \n

\n\n\n```\n\nNow, when a joke is displayed, its text color will be red."]}], "tstamp": 1686288005.79547} +{"question_id": 124, "answer_id": "9Zpf8G2Fzt6nGJubsPLcRz", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["The provided function appears to be correct. It uses dynamic programming to find the length of the longest common subsequence of two input strings, `str1` and `str2`. The function initializes a 2D array `dp` of size `(m+1) x (n+1)` where `m` and `n` are the lengths of the input strings. It then iterates through the characters of both strings and fills the `dp` array according to the longest common subsequence found so far. Finally, it returns the value at `dp[m][n]`, which represents the length of the longest common subsequence of the two input strings.", "This function is also correct. It is another implementation of the dynamic programming approach to find the length of the longest common subsequence of two input strings, `X` and `Y`. The function initializes a 2D array `dp` of size `(m+1) x (n+1)` where `m` and `n` are the lengths of the input strings. It then iterates through the characters of both strings and fills the `dp` array according to the longest common subsequence found so far. Finally, it returns the value at `dp[m][n]`, which represents the length of the longest common subsequence of the two input strings.\n\nHowever, there is a small issue in this function: the first row and the first column of the `dp` array are not initialized to 0. To fix this, you can add the following lines of code before the nested loops:\n\n```python\nfor i in range(m+1):\n dp[i][0] = 0\n\nfor j in range(n+1):\n dp[0][j] = 0\n```\n\nWith these changes, the function should work correctly."]}], "tstamp": 1686288048.3319974} +{"question_id": 125, "answer_id": "P8BsTS9cZtrPYD59ky4cPw", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To find the highest common ancestor (HCA) of two nodes in a binary tree, you can use a recursive approach. Here's a Python function to achieve this:\n\n```python\nclass TreeNode:\n def __init__(self, value=0, left=None, right=None):\n self.value = value\n self.left = left\n self.right = right\n\ndef find_hca(root, node1, node2):\n if root is None:\n return None\n\n if root.value == node1.value or root.value == node2.value:\n return root\n\n left_hca = find_hca(root.left, node1, node2)\n right_hca = find_hca(root.right, node1, node2)\n\n if left_hca and right_hca:\n return root\n\n return left_hca if left_hca is not None else right_hca\n\n# Example usage:\n# Construct a binary tree\n# 3\n# / \\\n# 5 1\n# / \\ / \\\n# 6 2 0 8\n# / \\\n# 7 4\nroot = TreeNode(3)\nroot.left = TreeNode(5)\nroot.right = TreeNode(1)\nroot.left.left = TreeNode(6)\nroot.left.right = TreeNode(2)\nroot.right.left = TreeNode(0)\nroot.right.right = TreeNode(8)\nroot.left.right.left = TreeNode(7)\nroot.left.right.right = TreeNode(4)\n\nnode1 = root.left.left # Node with value 6\nnode2 = root.left.right.right # Node with value 4\n\nhca = find_hca(root, node1, node2)\nprint(\"Highest Common Ancestor value:\", hca.value) # Output: 5\n```\n\nThis function uses a depth-first search to traverse the binary tree. It checks if the current node is one of the target nodes, and if so, returns the current node. If not, it continues the search in the left and right subtrees. If both subtrees return a non-null value, it means the current node is the HCA. Otherwise, the function returns the non-null subtree result.", "If it's not a binary tree but a general tree with multiple children per node, you can modify the function to handle this case. Here's a Python function to find the highest common ancestor (HCA) of two nodes in a general tree:\n\n```python\nclass TreeNode:\n def __init__(self, value=0, children=None):\n self.value = value\n self.children = children if children is not None else []\n\ndef find_hca(root, node1, node2):\n if root is None:\n return None\n\n if root.value == node1.value or root.value == node2.value:\n return root\n\n hca_list = []\n for child in root.children:\n hca_child = find_hca(child, node1, node2)\n if hca_child:\n hca_list.append(hca_child)\n\n if len(hca_list) >= 2:\n return root\n elif len(hca_list) == 1:\n return hca_list[0]\n else:\n return None\n\n# Example usage:\n# Construct a general tree\n# 1\n# / | \\\n# 2 3 4\n# /| | |\\\n# 5 6 7 8 9\n# / \\\n# 10 11\nroot = TreeNode(1, [\n TreeNode(2, [\n TreeNode(5),\n TreeNode(6)\n ]),\n TreeNode(3, [\n TreeNode(7)\n ]),\n TreeNode(4, [\n TreeNode(8, [\n TreeNode(10),\n TreeNode(11)\n ]),\n TreeNode(9)\n ])\n])\n\nnode1 = root.children[0].children[1] # Node with value 6\nnode2 = root.children[2].children[0].children[1] # Node with value 11\n\nhca = find_hca(root, node1, node2)\nprint(\"Highest Common Ancestor value:\", hca.value) # Output: 1\n```\n\nThis function is similar to the binary tree version, but instead of checking only left and right children, it iterates through all children of the current node. If two or more children return a non-null value, it means the current node is the HCA. Otherwise, the function returns the non-null subtree result."]}], "tstamp": 1686288164.937362} +{"question_id": 126, "answer_id": "EURwBVADp4nAoxMuetcicm", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["It's not possible to find the median of two sorted arrays of different sizes with O(1) space complexity and O(n) time complexity. However, you can achieve O(log(min(n, m))) time complexity using a binary search algorithm. Here's a Python implementation of such a function:\n\n```python\ndef findMedianSortedArrays(nums1, nums2):\n if len(nums1) > len(nums2):\n nums1, nums2 = nums2, nums1\n\n x, y = len(nums1), len(nums2)\n low, high = 0, x\n\n while low <= high:\n partition_x = (low + high) // 2\n partition_y = (x + y + 1) // 2 - partition_x\n\n max_left_x = float('-inf') if partition_x == 0 else nums1[partition_x - 1]\n min_right_x = float('inf') if partition_x == x else nums1[partition_x]\n\n max_left_y = float('-inf') if partition_y == 0 else nums2[partition_y - 1]\n min_right_y = float('inf') if partition_y == y else nums2[partition_y]\n\n if max_left_x <= min_right_y and max_left_y <= min_right_x:\n if (x + y) % 2 == 0:\n return (max(max_left_x, max_left_y) + min(min_right_x, min_right_y)) / 2\n else:\n return max(max_left_x, max_left_y)\n elif max_left_x > min_right_y:\n high = partition_x - 1\n else:\n low = partition_x + 1\n\n raise ValueError(\"Input arrays are not sorted.\")\n```\n\nThis function takes two sorted arrays `nums1` and `nums2` as input and returns the median value. The time complexity of this function is O(log(min(n, m))), where n and m are the lengths of the input arrays. The space complexity is O(1) as it only uses a constant amount of extra space.", "The implementation provided above with a time complexity of O(log(min(n, m))) is already an efficient solution for finding the median of two sorted arrays. This is because the problem of finding the median of two sorted arrays can be reduced to finding the k-th element in the merged array, where k is the middle index. The binary search algorithm used in the implementation takes advantage of the sorted nature of the input arrays and efficiently narrows down the search space.\n\nIn general, it is not possible to achieve a better time complexity than O(log(min(n, m))) for this problem, as any algorithm that solves this problem must examine at least a logarithmic number of elements in the smaller array to determine the median. This is because the median depends on the relative ordering of the elements in both arrays, and a logarithmic number of comparisons is required to establish this ordering.\n\nTherefore, the provided implementation with a time complexity of O(log(min(n, m))) is already an efficient and optimal solution for finding the median of two sorted arrays."]}], "tstamp": 1686288249.211752} +{"question_id": 127, "answer_id": "eYa4BTZvWrYTHhTdzX7E6g", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["Here's a Python function that finds the majority element in a given integer array using the Boyer-Moore Voting Algorithm:\n\n```python\ndef majority_element(arr):\n candidate = None\n count = 0\n\n # Boyer-Moore Voting Algorithm\n for num in arr:\n if count == 0:\n candidate = num\n count += (1 if num == candidate else -1)\n\n # Verify if the candidate is indeed the majority element\n if arr.count(candidate) > len(arr) // 2:\n return candidate\n else:\n return None\n\n# Example usage:\narr = [2, 2, 1, 1, 1, 2, 2]\nprint(majority_element(arr)) # Output: 2\n```\n\nThis function first finds a candidate for the majority element using the Boyer-Moore Voting Algorithm. Then, it verifies if the candidate is indeed the majority element by counting its occurrences in the array. If the candidate occurs more than half of the array's length, it is returned as the majority element; otherwise, the function returns None.", "To find the top-2 most occurring elements in a given integer array, you can modify the Boyer-Moore Voting Algorithm. Here's a Python function that does this:\n\n```python\ndef top_two_elements(arr):\n candidate1, candidate2 = None, None\n count1, count2 = 0, 0\n\n # Modified Boyer-Moore Voting Algorithm\n for num in arr:\n if num == candidate1:\n count1 += 1\n elif num == candidate2:\n count2 += 1\n elif count1 == 0:\n candidate1 = num\n count1 = 1\n elif count2 == 0:\n candidate2 = num\n count2 = 1\n else:\n count1 -= 1\n count2 -= 1\n\n # Verify if the candidates are indeed the top-2 most occurring elements\n count1 = arr.count(candidate1)\n count2 = arr.count(candidate2)\n\n if count1 > count2:\n return candidate1, candidate2\n else:\n return candidate2, candidate1\n\n# Example usage:\narr = [1, 1, 2, 2, 3, 3, 1, 2, 3, 4, 5]\nprint(top_two_elements(arr)) # Output: (1, 2)\n```\n\nThis function first finds two candidates for the top-2 most occurring elements using a modified version of the Boyer-Moore Voting Algorithm. Then, it verifies the counts of the candidates in the array and returns them in descending order of their counts. Note that this function assumes that there are at least two distinct elements in the array."]}], "tstamp": 1686288322.0523534} +{"question_id": 128, "answer_id": "BMZdfLQQke8d9j78rSmLsf", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To find B_n, we can use the following recursive formula:\n\nB_n = sum(B_i * B_(n-1-i)) for i = 0 to (n-2)/2, where n is even.\n\nHere's a Python function to implement this:\n\n```python\ndef full_binary_trees(n):\n if n % 2 == 0 or n < 1:\n return 0\n if n == 1:\n return 1\n\n B = [0] * (n + 1)\n B[1] = 1\n\n for i in range(3, n + 1, 2):\n for j in range(0, i - 1, 2):\n B[i] += B[j] * B[i - 1 - j]\n\n return B[n]\n\n# Example usage:\nn = 5\nprint(f\"B_{n} =\", full_binary_trees(n))\n```\n\nThis function first checks if n is even or less than 1, in which case it returns 0 since there are no full binary trees with an even number of vertices or with a negative number of vertices. If n is 1, it returns 1 since there is only one full binary tree with a single vertex (the root). Then, it initializes an array B of size n+1 to store the number of full binary trees for each value of n. It iterates through the odd numbers from 3 to n, and for each i, it calculates B_i using the recursive formula mentioned above. Finally, it returns B_n.", "If the problem changes to a ternary tree, we can modify the recursive formula as follows:\n\nT_n = sum(T_i * T_(j) * T_(n-1-i-j)) for i = 0 to (n-2)/2, j = 0 to (n-2)/2, where n is a multiple of 3.\n\nHere's a Python function to implement this for ternary trees:\n\n```python\ndef full_ternary_trees(n):\n if n % 3 != 0 or n < 1:\n return 0\n if n == 1:\n return 1\n\n T = [0] * (n + 1)\n T[1] = 1\n\n for i in range(3, n + 1, 3):\n for j in range(0, i - 1, 3):\n for k in range(0, i - 1 - j, 3):\n T[i] += T[j] * T[k] * T[i - 1 - j - k]\n\n return T[n]\n\n# Example usage:\nn = 9\nprint(f\"T_{n} =\", full_ternary_trees(n))\n```\n\nThis function first checks if n is not a multiple of 3 or less than 1, in which case it returns 0 since there are no full ternary trees with a number of vertices that is not a multiple of 3 or with a negative number of vertices. If n is 1, it returns 1 since there is only one full ternary tree with a single vertex (the root). Then, it initializes an array T of size n+1 to store the number of full ternary trees for each value of n. It iterates through the multiples of 3 from 3 to n, and for each i, it calculates T_i using the modified recursive formula mentioned above. Finally, it returns T_n."]}], "tstamp": 1686288533.7038336} +{"question_id": 129, "answer_id": "62ofoso88i9zStspxDdUER", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To find the kth smallest element in the union of two sorted lists with linear complexity, you can use a two-pointer approach. Here's a Python function that implements this approach:\n\n```python\ndef find_kth_smallest(arr1, arr2, k):\n m, n = len(arr1), len(arr2)\n i, j = 0, 0\n\n while i < m and j < n:\n if arr1[i] < arr2[j]:\n k -= 1\n if k == 0:\n return arr1[i]\n i += 1\n else:\n k -= 1\n if k == 0:\n return arr2[j]\n j += 1\n\n while i < m:\n k -= 1\n if k == 0:\n return arr1[i]\n i += 1\n\n while j < n:\n k -= 1\n if k == 0:\n return arr2[j]\n j += 1\n\n return None\n```\n\nThis function takes two sorted lists `arr1` and `arr2`, and an integer `k` as input. It initializes two pointers `i` and `j` to traverse the two lists. The function iterates through both lists, comparing the elements at the current positions of the pointers. If the element in `arr1` is smaller, it increments the pointer `i`. If the element in `arr2` is smaller, it increments the pointer `j`. In each iteration, the function also decrements the value of `k`. When `k` reaches 0, the function returns the current element.\n\nIf one of the lists is exhausted, the function continues iterating through the remaining list until `k` reaches 0.\n\nThe time complexity of this function is O(m + n), which is linear in the size of the input lists.", "Yes, there exists an algorithm with better time complexity. We can use a binary search approach to find the kth smallest element in the union of the two sorted lists with logarithmic complexity. Here's a Python function that implements this approach:\n\n```python\ndef find_kth_smallest(arr1, arr2, k):\n def kth(A, B, k):\n if len(A) == 0:\n return B[k - 1]\n if len(B) == 0:\n return A[k - 1]\n if k == 1:\n return min(A[0], B[0])\n\n mid_A, mid_B = len(A) // 2, len(B) // 2\n if A[mid_A] <= B[mid_B]:\n if k > mid_A + mid_B + 1:\n return kth(A[mid_A + 1:], B, k - mid_A - 1)\n else:\n return kth(A, B[:mid_B], k)\n else:\n if k > mid_A + mid_B + 1:\n return kth(A, B[mid_B + 1:], k - mid_B - 1)\n else:\n return kth(A[:mid_A], B, k)\n\n return kth(arr1, arr2, k)\n```\n\nThis function takes two sorted lists `arr1` and `arr2`, and an integer `k` as input. It defines a helper function `kth` that takes two lists `A` and `B`, and an integer `k`. The helper function performs a binary search on both lists to find the kth smallest element.\n\nThe time complexity of this function is O(log(m) + log(n)), which is logarithmic in the size of the input lists. This is better than the linear complexity of the previous algorithm."]}], "tstamp": 1686288619.2822871} +{"question_id": 130, "answer_id": "B39HvvzTyw5gGQKqNFH6Ed", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["You can implement a program to find the common elements in two arrays without using any extra data structures by using nested loops. Here's an example in Python:\n\n```python\ndef find_common_elements(arr1, arr2):\n # Iterate through each element in the first array\n for i in range(len(arr1)):\n # Iterate through each element in the second array\n for j in range(len(arr2)):\n # If the elements match, print the common element\n if arr1[i] == arr2[j]:\n print(arr1[i], end=\" \")\n\n# Example usage\narr1 = [1, 2, 3, 4, 5]\narr2 = [3, 4, 5, 6, 7]\nfind_common_elements(arr1, arr2)\n```\n\nThis program will output:\n\n```\n3 4 5\n```\n\nKeep in mind that this approach has a time complexity of O(n^2) due to the nested loops. If you were allowed to use extra data structures, you could improve the time complexity by using a set or a hash table.", "Now that we can use extra data structures, we can use a set to store the elements of one array and then check if the elements of the other array are present in the set. This will give us a better time complexity of O(n). Here's an example in Python:\n\n```python\ndef find_common_elements(arr1, arr2):\n # Create a set with the elements of the first array\n elements_set = set(arr1)\n\n # Iterate through each element in the second array\n for element in arr2:\n # If the element is present in the set, print the common element\n if element in elements_set:\n print(element, end=\" \")\n\n# Example usage\narr1 = [1, 2, 3, 4, 5]\narr2 = [3, 4, 5, 6, 7]\nfind_common_elements(arr1, arr2)\n```\n\nThis program will output:\n\n```\n3 4 5\n```\n\nUsing a set allows us to achieve a time complexity of O(n) because set operations (add, remove, and search) have an average time complexity of O(1)."]}], "tstamp": 1686288677.5436106} diff --git a/third_party/FastChat/fastchat/llm_judge/data/vicuna_bench/question.jsonl b/third_party/FastChat/fastchat/llm_judge/data/vicuna_bench/question.jsonl new file mode 100644 index 00000000..22364238 --- /dev/null +++ b/third_party/FastChat/fastchat/llm_judge/data/vicuna_bench/question.jsonl @@ -0,0 +1,80 @@ +{"question_id": 1, "category": "generic", "turns": ["How can I improve my time management skills?"]} +{"question_id": 2, "category": "generic", "turns": ["What are the most effective ways to deal with stress?"]} +{"question_id": 3, "category": "generic", "turns": ["What are the main differences between Python and JavaScript programming languages?"]} +{"question_id": 4, "category": "generic", "turns": ["How can I increase my productivity while working from home?"]} +{"question_id": 5, "category": "generic", "turns": ["Can you explain the basics of quantum computing?"]} +{"question_id": 6, "category": "generic", "turns": ["What are the differences between plant-based and animal-based protein sources?"]} +{"question_id": 7, "category": "generic", "turns": ["How can I develop my critical thinking skills?"]} +{"question_id": 8, "category": "generic", "turns": ["What are the major challenges faced by the education sector today?"]} +{"question_id": 9, "category": "generic", "turns": ["What are the primary factors that influence consumer behavior?"]} +{"question_id": 10, "category": "generic", "turns": ["What are the most effective strategies for conflict resolution in the workplace?"]} +{"question_id": 11, "category": "knowledge", "turns": ["What are some potential implications of using a single-use plastic bottle versus a reusable bottle on both the environment and human health?"]} +{"question_id": 12, "category": "knowledge", "turns": ["What factors would you consider when designing an inclusive and accessible public transportation system?"]} +{"question_id": 13, "category": "knowledge", "turns": ["How can governments utilize fiscal and monetary policies to combat economic recessions?"]} +{"question_id": 14, "category": "knowledge", "turns": ["How do language and cultural barriers affect the way people communicate and form relationships in multicultural societies?"]} +{"question_id": 15, "category": "knowledge", "turns": ["Describe a scenario where artificial intelligence could be used to improve the quality and efficiency of healthcare delivery."]} +{"question_id": 16, "category": "knowledge", "turns": ["Explain the process of gene editing using CRISPR-Cas9 technology, and discuss its potential applications and ethical implications."]} +{"question_id": 17, "category": "knowledge", "turns": ["How do vaccinations work to protect individuals and communities from infectious diseases, and what is herd immunity?"]} +{"question_id": 18, "category": "knowledge", "turns": ["How do social media platforms influence the way people consume and share news, and what are the potential implications for the spread of misinformation?"]} +{"question_id": 19, "category": "knowledge", "turns": ["How do cultural, social, and economic factors influence people's food choices, and how can this knowledge be used to promote healthier diets?"]} +{"question_id": 20, "category": "knowledge", "turns": ["Explain the process of natural selection and how it contributes to the evolution and adaptation of species."]} +{"question_id": 21, "category": "roleplay", "turns": ["How would you introduce yourself as a medieval knight at a royal banquet?"]} +{"question_id": 22, "category": "roleplay", "turns": ["As a pirate captain, what would you say to your crew to motivate them to search for hidden treasure?"]} +{"question_id": 23, "category": "roleplay", "turns": ["If you were a Shakespearean character, how would you declare your love for someone in a soliloquy?"]} +{"question_id": 24, "category": "roleplay", "turns": ["As a superhero, how would you explain your origin story to a curious child?"]} +{"question_id": 25, "category": "roleplay", "turns": ["Imagine you are a time traveler from the year 3000. What technological advancements would you tell people about?"]} +{"question_id": 26, "category": "roleplay", "turns": ["As a sports commentator, describe the winning play in the final seconds of a championship game."]} +{"question_id": 27, "category": "roleplay", "turns": ["Pretend to be a world-famous chef. How would you describe your signature dish to a panel of judges?"]} +{"question_id": 28, "category": "roleplay", "turns": ["You are a mountain climber reaching the summit of Mount Everest. Describe your emotions and the view from the top."]} +{"question_id": 29, "category": "roleplay", "turns": ["As a space colonist on Mars, describe your daily life and the challenges you face living on another planet."]} +{"question_id": 30, "category": "roleplay", "turns": ["Pretend to be a character in a post-apocalyptic world. Describe how you survive and the allies you encounter."]} +{"question_id": 31, "category": "common-sense", "turns": ["How can you determine if a restaurant is popular among locals or mainly attracts tourists, and why might this information be useful?"]} +{"question_id": 32, "category": "common-sense", "turns": ["What are some subtle clues that suggest someone is pretending to understand a topic or conversation when they are actually confused or uninformed?"]} +{"question_id": 33, "category": "common-sense", "turns": ["Why might someone choose to use a paper map or ask for directions instead of relying on a GPS device or smartphone app?"]} +{"question_id": 34, "category": "common-sense", "turns": ["How can you determine if a person is genuinely interested in a conversation or simply being polite?"]} +{"question_id": 35, "category": "common-sense", "turns": ["Why might someone prefer to shop at a small, locally-owned business instead of a large chain store, even if the prices are higher?"]} +{"question_id": 36, "category": "common-sense", "turns": ["How can you assess the credibility of a source of information, such as a news article or blog post, without relying solely on the reputation of the author or publisher?"]} +{"question_id": 37, "category": "common-sense", "turns": ["Why do some people enjoy the sensation of being scared, such as by watching horror movies or going on roller coasters, while others avoid these experiences?"]} +{"question_id": 38, "category": "common-sense", "turns": ["How can observing the behavior of other people in a social situation provide clues about cultural norms and expectations?"]} +{"question_id": 39, "category": "common-sense", "turns": ["Do we have a moral obligation to explore space, or should we focus on solving Earth's problems first?"]} +{"question_id": 40, "category": "common-sense", "turns": ["In a world where automation is becoming increasingly prevalent, is it more important to prioritize job creation or technological progress?"]} +{"question_id": 41, "category": "fermi", "turns": ["How many times does the average human blink in a lifetime? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step."]} +{"question_id": 42, "category": "fermi", "turns": ["How many atoms are in a grain of salt? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step."]} +{"question_id": 43, "category": "fermi", "turns": ["How many lightning strikes occur on Earth each day? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step."]} +{"question_id": 44, "category": "fermi", "turns": ["How many balloons would it take to lift a house like in the movie \"Up\"? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step."]} +{"question_id": 45, "category": "fermi", "turns": ["How many text messages are sent globally in a minute? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step."]} +{"question_id": 46, "category": "fermi", "turns": ["How many words are spoken daily on Earth? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step."]} +{"question_id": 47, "category": "fermi", "turns": ["How many snowflakes fall during a typical winter? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step."]} +{"question_id": 48, "category": "fermi", "turns": ["How many pages are in all the books ever written? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step."]} +{"question_id": 49, "category": "fermi", "turns": ["How many times has the Earth orbited the Sun since the beginning of life? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step."]} +{"question_id": 50, "category": "fermi", "turns": ["How many songs have been recorded throughout history? Try to explain your answer. Your explanation should take the reader through your reasoning step-by-step."]} +{"question_id": 51, "category": "counterfactual", "turns": ["What if the Internet had been invented during the Renaissance period?"]} +{"question_id": 52, "category": "counterfactual", "turns": ["What if the Aztecs had successfully repelled the Spanish conquistadors?"]} +{"question_id": 53, "category": "counterfactual", "turns": ["What if the Black Death had not occurred in the 14th century?"]} +{"question_id": 54, "category": "counterfactual", "turns": ["What if Isaac Newton had focused on biology instead of physics?"]} +{"question_id": 55, "category": "counterfactual", "turns": ["What if the Beatles had never formed as a band?"]} +{"question_id": 56, "category": "counterfactual", "turns": ["What if Alan Turing had not cracked the Enigma code during World War II?"]} +{"question_id": 57, "category": "counterfactual", "turns": ["What if the Suez Canal had never been constructed?"]} +{"question_id": 58, "category": "counterfactual", "turns": ["What if the Maya civilization had never mysteriously collapsed?"]} +{"question_id": 59, "category": "counterfactual", "turns": ["What if Christopher Columbus had not discovered the Americas?"]} +{"question_id": 60, "category": "counterfactual", "turns": ["What if Vincent van Gogh had been a successful artist during his lifetime?"]} +{"question_id": 61, "category": "coding", "turns": ["Develop a C++ program that reads a text file line by line and counts the number of occurrences of a specific word in the file."]} +{"question_id": 62, "category": "coding", "turns": ["Implement a Python function to find the longest common subsequence of two input strings using dynamic programming."]} +{"question_id": 63, "category": "coding", "turns": ["Implement a regular expression in Python to validate an email address."]} +{"question_id": 64, "category": "coding", "turns": ["Write a program to find the nth Fibonacci number using dynamic programming."]} +{"question_id": 65, "category": "coding", "turns": ["Implement a binary search algorithm to find a specific element in a sorted array."]} +{"question_id": 66, "category": "coding", "turns": ["Implement a queue data structure using two stacks in Python."]} +{"question_id": 67, "category": "coding", "turns": ["Implement a program to find the common elements in two arrays without using any extra data structures."]} +{"question_id": 68, "category": "math", "turns": ["Given that f(x) = 5x^3 - 2x + 3, find the value of f(2)."]} +{"question_id": 69, "category": "math", "turns": ["Solve for x in the equation 3x + 10 = 5(x - 2)."]} +{"question_id": 70, "category": "math", "turns": ["If the endpoints of a line segment are (2, -2) and (10, 4), what is the length of the segment?"]} +{"question_id": 71, "category": "writing", "turns": ["Can you help me write a formal email to a potential business partner proposing a joint venture?"]} +{"question_id": 72, "category": "writing", "turns": ["Can you help me write a resignation letter to my current employer, while leaving on good terms and expressing gratitude for the opportunities provided?"]} +{"question_id": 73, "category": "writing", "turns": ["Use an appropriate format to structure a formal letter of recommendation for a student applying to a prestigious graduate program in computer science."]} +{"question_id": 74, "category": "writing", "turns": ["Write a compelling product launch announcement email to inform our customers of our new software solution."]} +{"question_id": 75, "category": "writing", "turns": ["Draft an apology email to a customer who experienced a delay in their order, and provide reassurance that the issue has been resolved."]} +{"question_id": 76, "category": "writing", "turns": ["Write a script for a YouTube video exploring the history and cultural significance of jazz."]} +{"question_id": 77, "category": "writing", "turns": ["Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions."]} +{"question_id": 78, "category": "writing", "turns": ["Write a captivating movie review for a recently released science fiction film, discussing its plot, characters, and special effects."]} +{"question_id": 79, "category": "writing", "turns": ["Structure a podcast script for an episode discussing the influence of streaming platforms on the music industry."]} +{"question_id": 80, "category": "writing", "turns": ["Write a symphony concert review, discussing the orchestra's performance and overall audience experience."]} diff --git a/third_party/FastChat/fastchat/llm_judge/data/vicuna_bench/reference_answer/gpt-4.jsonl b/third_party/FastChat/fastchat/llm_judge/data/vicuna_bench/reference_answer/gpt-4.jsonl new file mode 100644 index 00000000..f3c8368d --- /dev/null +++ b/third_party/FastChat/fastchat/llm_judge/data/vicuna_bench/reference_answer/gpt-4.jsonl @@ -0,0 +1,10 @@ +{"question_id": 61, "answer_id": "YdL4XwENkLCLXMbH65rjKy", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["Here's a simple C++ program that reads a text file line by line and counts the number of occurrences of a specific word in the file:\n\n```cpp\n#include \n#include \n#include \n#include \n\nint countWordOccurrences(const std::string& line, const std::string& word) {\n std::istringstream iss(line);\n std::string currentWord;\n int count = 0;\n\n while (iss >> currentWord) {\n if (currentWord == word) {\n count++;\n }\n }\n\n return count;\n}\n\nint main() {\n std::string fileName, wordToCount;\n std::cout << \"Enter the file name: \";\n std::cin >> fileName;\n std::cout << \"Enter the word to count: \";\n std::cin >> wordToCount;\n\n std::ifstream inputFile(fileName);\n if (!inputFile) {\n std::cerr << \"Error opening the file: \" << fileName << std::endl;\n return 1;\n }\n\n std::string line;\n int totalCount = 0;\n while (std::getline(inputFile, line)) {\n totalCount += countWordOccurrences(line, wordToCount);\n }\n\n inputFile.close();\n\n std::cout << \"The word \\\"\" << wordToCount << \"\\\" occurs \" << totalCount << \" times in the file.\" << std::endl;\n\n return 0;\n}\n```\n\nTo compile and run the program, save it to a file named `word_counter.cpp` and use the following commands:\n\n```sh\ng++ -o word_counter word_counter.cpp\n./word_counter\n```\n\nThe program will prompt you to enter the file name and the word you want to count. After that, it will display the number of occurrences of the specified word in the file."]}], "tstamp": 1686545085.5954556} +{"question_id": 62, "answer_id": "4nk9zQkeRKvfjxEEqfyXmj", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["Here's a Python function that finds the longest common subsequence of two input strings using dynamic programming:\n\n```python\ndef longest_common_subsequence(s1, s2):\n # Create a matrix to store the lengths of common subsequences\n matrix = [[0] * (len(s2) + 1) for _ in range(len(s1) + 1)]\n\n # Iterate through the strings and fill the matrix\n for i in range(len(s1) - 1, -1, -1):\n for j in range(len(s2) - 1, -1, -1):\n if s1[i] == s2[j]:\n matrix[i][j] = 1 + matrix[i + 1][j + 1]\n else:\n matrix[i][j] = max(matrix[i + 1][j], matrix[i][j + 1])\n\n # Reconstruct the longest common subsequence from the matrix\n lcs = []\n i, j = 0, 0\n while i < len(s1) and j < len(s2):\n if s1[i] == s2[j]:\n lcs.append(s1[i])\n i += 1\n j += 1\n elif matrix[i + 1][j] > matrix[i][j + 1]:\n i += 1\n else:\n j += 1\n\n return ''.join(lcs)\n\n# Example usage:\ns1 = \"ABCDGH\"\ns2 = \"AEDFHR\"\nprint(longest_common_subsequence(s1, s2)) # Output: \"ADH\"\n```\n\nThis function uses dynamic programming to build a matrix that stores the lengths of common subsequences for each pair of substrings. It then reconstructs the longest common subsequence from the matrix."]}], "tstamp": 1686544775.161658} +{"question_id": 63, "answer_id": "2MkXEfsoSuSC72BYGJuqhX", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To implement a regular expression in Python to validate an email address, you can use the `re` module. Here's a simple example:\n\n```python\nimport re\n\ndef validate_email(email):\n pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$'\n return bool(re.match(pattern, email))\n\n# Test cases\nemails = [\n \"test@example.com\",\n \"test@example.co.uk\",\n \"test+123@example.com\",\n \"test@subdomain.example.com\",\n \"test@sub_domain.example.com\",\n \"test@sub-domain.example.com\",\n \"test@example\",\n \"test@.com\",\n \"test@.example.com\",\n \"test@example..com\",\n \"test@.example..com\",\n \"test@example.com.\",\n \".test@example.com\",\n]\n\nfor email in emails:\n print(f\"{email}: {validate_email(email)}\")\n```\n\nThis regular expression checks for the following conditions:\n\n1. The email address starts with one or more alphanumeric characters, dots, underscores, percent signs, plus signs, or hyphens.\n2. Followed by the '@' symbol.\n3. Followed by one or more alphanumeric characters, dots, or hyphens.\n4. Followed by a dot.\n5. Ends with two or more alphabetic characters.\n\nPlease note that this is a simple example and may not cover all edge cases. You can modify the regular expression to better suit your specific requirements."]}], "tstamp": 1686544778.904383} +{"question_id": 64, "answer_id": "jPeywxk9g4LWyLPurjhsd2", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["Here's a Python program to find the nth Fibonacci number using dynamic programming (also known as memoization):\n\n```python\ndef fibonacci(n, memo={}):\n if n <= 0:\n return 0\n elif n == 1:\n return 1\n elif n not in memo:\n memo[n] = fibonacci(n - 1, memo) + fibonacci(n - 2, memo)\n return memo[n]\n\n# Test the function\nn = int(input(\"Enter the value of n: \"))\nresult = fibonacci(n)\nprint(f\"The {n}th Fibonacci number is: {result}\")\n```\n\nThis program defines a function `fibonacci(n, memo={})` that calculates the nth Fibonacci number using dynamic programming. The `memo` dictionary is used to store the results of previously calculated Fibonacci numbers, which helps to avoid redundant calculations and improve the efficiency of the program."]}], "tstamp": 1686544763.1005273} +{"question_id": 65, "answer_id": "hD8NiB3PY5aoWg95mNjnj3", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["Here's a Python implementation of the binary search algorithm to find a specific element in a sorted array:\n\n```python\ndef binary_search(arr, target):\n low, high = 0, len(arr) - 1\n\n while low <= high:\n mid = (low + high) // 2\n mid_val = arr[mid]\n\n if mid_val == target:\n return mid\n elif mid_val < target:\n low = mid + 1\n else:\n high = mid - 1\n\n return -1\n\n# Example usage:\narr = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]\ntarget = 11\n\nresult = binary_search(arr, target)\nif result != -1:\n print(f\"Element {target} found at index {result}\")\nelse:\n print(f\"Element {target} not found in the array\")\n```\n\nThis binary search function takes a sorted array `arr` and a target value `target` as input. It initializes two pointers, `low` and `high`, to the first and last indices of the array, respectively. The function then iteratively updates the pointers and calculates the middle index `mid` until the target value is found or the search space is exhausted. If the target value is found, the function returns the index of the target value; otherwise, it returns -1."]}], "tstamp": 1686544779.33291} +{"question_id": 66, "answer_id": "c4rAXsTfmioefx6icubuFB", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To implement a queue using two stacks in Python, you can use the following code:\n\n```python\nclass Queue:\n def __init__(self):\n self.stack1 = []\n self.stack2 = []\n\n def enqueue(self, item):\n self.stack1.append(item)\n\n def dequeue(self):\n if not self.stack2:\n while self.stack1:\n self.stack2.append(self.stack1.pop())\n return self.stack2.pop()\n\n def is_empty(self):\n return not self.stack1 and not self.stack2\n\n def size(self):\n return len(self.stack1) + len(self.stack2)\n\n\n# Example usage:\nqueue = Queue()\nqueue.enqueue(1)\nqueue.enqueue(2)\nqueue.enqueue(3)\n\nprint(queue.dequeue()) # Output: 1\nprint(queue.dequeue()) # Output: 2\n\nqueue.enqueue(4)\nqueue.enqueue(5)\n\nprint(queue.dequeue()) # Output: 3\nprint(queue.dequeue()) # Output: 4\nprint(queue.dequeue()) # Output: 5\n```\n\nIn this implementation, we use two stacks (`stack1` and `stack2`) to simulate the behavior of a queue. When enqueuing an item, we simply push it onto `stack1`. When dequeuing an item, we pop from `stack2`. If `stack2` is empty, we first transfer all items from `stack1` to `stack2` (reversing their order) before popping from `stack2`. This ensures that the oldest item in the queue is always at the top of `stack2`."]}], "tstamp": 1686544798.8350058} +{"question_id": 67, "answer_id": "UQ3TFrzVqXtMTWpvhqtGym", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["You can implement a program to find the common elements in two arrays without using any extra data structures by using nested loops. Here's an example in Python:\n\n```python\ndef find_common_elements(arr1, arr2):\n common_elements = []\n\n for i in arr1:\n for j in arr2:\n if i == j:\n common_elements.append(i)\n break\n\n return common_elements\n\n# Example usage:\narr1 = [1, 2, 3, 4, 5]\narr2 = [4, 5, 6, 7, 8]\n\ncommon_elements = find_common_elements(arr1, arr2)\nprint(\"Common elements:\", common_elements)\n```\n\nThis program defines a function `find_common_elements` that takes two arrays as input. It initializes an empty list `common_elements` to store the common elements found. Then, it iterates through each element in the first array using a loop, and for each element, it iterates through the second array using another loop. If an element from the first array is found in the second array, it is added to the `common_elements` list and the inner loop is terminated using the `break` statement.\n\nFinally, the function returns the `common_elements` list. In the example usage, the program finds the common elements between two arrays `arr1` and `arr2`, and prints the result."]}], "tstamp": 1686544805.0822344} +{"question_id": 68, "answer_id": "eBGY6U9KCAsfuedcASSK7Q", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To find the value of f(2), substitute x = 2 into the function f(x) = 5x^3 - 2x + 3.\n\nf(2) = 5(2)^3 - 2(2) + 3\nf(2) = 5(8) - 4 + 3\nf(2) = 40 - 4 + 3\nf(2) = 36 + 3\nf(2) = 39\n\nSo, the value of f(2) is 39."]}], "tstamp": 1686544893.1474469} +{"question_id": 69, "answer_id": "KsgnWzbpEvoW7RmRwaLgMT", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To solve for x in the equation 3x + 10 = 5(x - 2), follow these steps:\n\n1. Distribute the 5 on the right side of the equation:\n3x + 10 = 5x - 10\n\n2. Move the x terms to one side by subtracting 3x from both sides:\n10 = 2x - 10\n\n3. Move the constant term to the other side by adding 10 to both sides:\n20 = 2x\n\n4. Divide both sides by 2 to isolate x:\nx = 10\n\nSo, the solution is x = 10."]}], "tstamp": 1686544795.032636} +{"question_id": 70, "answer_id": "BT82nWd8ZtXTzHrvhkg6eb", "model_id": "gpt-4", "choices": [{"index": 0, "turns": ["To find the length of the line segment, we can use the distance formula:\n\nDistance = \u221a((x2 - x1)^2 + (y2 - y1)^2)\n\nHere, (x1, y1) = (2, -2) and (x2, y2) = (10, 4).\n\nPlugging in the values, we get:\n\nDistance = \u221a((10 - 2)^2 + (4 - (-2))^2)\nDistance = \u221a((8)^2 + (6)^2)\nDistance = \u221a(64 + 36)\nDistance = \u221a100\nDistance = 10\n\nSo, the length of the line segment is 10 units."]}], "tstamp": 1686544809.1491597} diff --git a/third_party/FastChat/fastchat/llm_judge/download_mt_bench_pregenerated.py b/third_party/FastChat/fastchat/llm_judge/download_mt_bench_pregenerated.py new file mode 100644 index 00000000..9546311f --- /dev/null +++ b/third_party/FastChat/fastchat/llm_judge/download_mt_bench_pregenerated.py @@ -0,0 +1,51 @@ +""" +Download the pre-generated model answers and judgments for MT-bench. +""" +import os + +from fastchat.utils import run_cmd + +filenames = [ + "data/mt_bench/model_answer/alpaca-13b.jsonl", + "data/mt_bench/model_answer/baize-v2-13b.jsonl", + "data/mt_bench/model_answer/chatglm-6b.jsonl", + "data/mt_bench/model_answer/claude-instant-v1.jsonl", + "data/mt_bench/model_answer/claude-v1.jsonl", + "data/mt_bench/model_answer/dolly-v2-12b.jsonl", + "data/mt_bench/model_answer/falcon-40b-instruct.jsonl", + "data/mt_bench/model_answer/fastchat-t5-3b.jsonl", + "data/mt_bench/model_answer/gpt-3.5-turbo.jsonl", + "data/mt_bench/model_answer/gpt-4.jsonl", + "data/mt_bench/model_answer/gpt4all-13b-snoozy.jsonl", + "data/mt_bench/model_answer/guanaco-33b.jsonl", + "data/mt_bench/model_answer/guanaco-65b.jsonl", + "data/mt_bench/model_answer/h2ogpt-oasst-open-llama-13b.jsonl", + "data/mt_bench/model_answer/koala-13b.jsonl", + "data/mt_bench/model_answer/llama-13b.jsonl", + "data/mt_bench/model_answer/mpt-30b-chat.jsonl", + "data/mt_bench/model_answer/mpt-30b-instruct.jsonl", + "data/mt_bench/model_answer/mpt-7b-chat.jsonl", + "data/mt_bench/model_answer/nous-hermes-13b.jsonl", + "data/mt_bench/model_answer/oasst-sft-4-pythia-12b.jsonl", + "data/mt_bench/model_answer/oasst-sft-7-llama-30b.jsonl", + "data/mt_bench/model_answer/palm-2-chat-bison-001.jsonl", + "data/mt_bench/model_answer/rwkv-4-raven-14b.jsonl", + "data/mt_bench/model_answer/stablelm-tuned-alpha-7b.jsonl", + "data/mt_bench/model_answer/tulu-30b.jsonl", + "data/mt_bench/model_answer/vicuna-13b-v1.3.jsonl", + "data/mt_bench/model_answer/vicuna-33b-v1.3.jsonl", + "data/mt_bench/model_answer/vicuna-7b-v1.3.jsonl", + "data/mt_bench/model_answer/wizardlm-13b.jsonl", + "data/mt_bench/model_answer/wizardlm-30b.jsonl", + "data/mt_bench/model_judgment/gpt-4_single.jsonl", + "data/mt_bench/model_judgment/gpt-4_pair.jsonl", +] + + +if __name__ == "__main__": + prefix = "https://huggingface.co/spaces/lmsys/mt-bench/resolve/main/" + + for name in filenames: + os.makedirs(os.path.dirname(name), exist_ok=True) + ret = run_cmd(f"wget -q --show-progress -O {name} {prefix + name}") + assert ret == 0 diff --git a/third_party/FastChat/fastchat/llm_judge/gen_api_answer.py b/third_party/FastChat/fastchat/llm_judge/gen_api_answer.py new file mode 100644 index 00000000..8f9c6262 --- /dev/null +++ b/third_party/FastChat/fastchat/llm_judge/gen_api_answer.py @@ -0,0 +1,149 @@ +"""Generate answers with GPT-4 + +Usage: +python3 gen_api_answer.py --model gpt-3.5-turbo +""" +import argparse +import json +import os +import time +import concurrent.futures + +import openai +import shortuuid +import tqdm + +from fastchat.llm_judge.common import ( + load_questions, + temperature_config, + chat_completion_openai, + chat_completion_anthropic, + chat_completion_palm, +) +from fastchat.llm_judge.gen_model_answer import reorg_answer_file +from fastchat.model.model_adapter import get_conversation_template, ANTHROPIC_MODEL_LIST + + +def get_answer( + question: dict, model: str, num_choices: int, max_tokens: int, answer_file: str +): + assert ( + args.force_temperature is not None and "required_temperature" in question.keys() + ) == False + if args.force_temperature is not None: + temperature = args.force_temperature + elif "required_temperature" in question.keys(): + temperature = question["required_temperature"] + elif question["category"] in temperature_config: + temperature = temperature_config[question["category"]] + else: + temperature = 0.7 + + choices = [] + chat_state = None # for palm-2 model + for i in range(num_choices): + conv = get_conversation_template(model) + + turns = [] + for j in range(len(question["turns"])): + conv.append_message(conv.roles[0], question["turns"][j]) + conv.append_message(conv.roles[1], None) + + if model in ANTHROPIC_MODEL_LIST: + output = chat_completion_anthropic(model, conv, temperature, max_tokens) + elif model == "palm-2-chat-bison-001": + chat_state, output = chat_completion_palm( + chat_state, model, conv, temperature, max_tokens + ) + else: + output = chat_completion_openai(model, conv, temperature, max_tokens) + + conv.update_last_message(output) + turns.append(output) + + choices.append({"index": i, "turns": turns}) + + # Dump answers + ans = { + "question_id": question["question_id"], + "answer_id": shortuuid.uuid(), + "model_id": model, + "choices": choices, + "tstamp": time.time(), + } + + os.makedirs(os.path.dirname(answer_file), exist_ok=True) + with open(answer_file, "a") as fout: + fout.write(json.dumps(ans) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--bench-name", + type=str, + default="mt_bench", + help="The name of the benchmark question set.", + ) + parser.add_argument("--answer-file", type=str, help="The output answer file.") + parser.add_argument("--model", type=str, default="gpt-3.5-turbo") + parser.add_argument( + "--num-choices", + type=int, + default=1, + help="How many completion choices to generate.", + ) + parser.add_argument( + "--force-temperature", type=float, help="Forcibly set a sampling temperature." + ) + parser.add_argument( + "--max-tokens", + type=int, + default=1024, + help="The maximum number of new generated tokens.", + ) + parser.add_argument( + "--question-begin", + type=int, + help="A debug option. The begin index of questions.", + ) + parser.add_argument( + "--question-end", type=int, help="A debug option. The end index of questions." + ) + parser.add_argument( + "--parallel", type=int, default=1, help="The number of concurrent API calls." + ) + parser.add_argument("--openai-api-base", type=str, default=None) + args = parser.parse_args() + + if args.openai_api_base is not None: + openai.api_base = args.openai_api_base + + question_file = f"data/{args.bench_name}/question.jsonl" + questions = load_questions(question_file, args.question_begin, args.question_end) + + if args.answer_file: + answer_file = args.answer_file + else: + answer_file = f"data/{args.bench_name}/model_answer/{args.model}.jsonl" + print(f"Output to {answer_file}") + + with concurrent.futures.ThreadPoolExecutor(max_workers=args.parallel) as executor: + futures = [] + for question in questions: + future = executor.submit( + get_answer, + question, + args.model, + args.num_choices, + args.max_tokens, + answer_file, + ) + futures.append(future) + + for future in tqdm.tqdm( + concurrent.futures.as_completed(futures), total=len(futures) + ): + future.result() + + reorg_answer_file(answer_file) diff --git a/third_party/FastChat/fastchat/llm_judge/gen_judgment.py b/third_party/FastChat/fastchat/llm_judge/gen_judgment.py new file mode 100644 index 00000000..a1c70b29 --- /dev/null +++ b/third_party/FastChat/fastchat/llm_judge/gen_judgment.py @@ -0,0 +1,322 @@ +""" +Usage: +python gen_judgment.py --model-list [LIST-OF-MODEL-ID] --parallel [num-concurrent-api-call] --mode [single|pairwise-baseline|pairwise-all] +""" +import argparse +from concurrent.futures import ThreadPoolExecutor +import json + +import numpy as np +from tqdm import tqdm + +from fastchat.llm_judge.common import ( + load_questions, + load_model_answers, + load_judge_prompts, + check_data, + play_a_match_pair, + play_a_match_single, + get_model_list, + Judge, + MatchPair, + MatchSingle, + NEED_REF_CATS, +) + + +def make_match( + questions, + models, + model_answers, + judge, + baseline_model, + ref_answers=None, + multi_turn=False, +): + matches = [] + for q in questions: + if multi_turn and len(q["turns"]) != 2: + continue + for i in range(len(models)): + q_id = q["question_id"] + m_1 = models[i] + m_2 = baseline_model + if m_1 == m_2: + continue + a_1 = model_answers[m_1][q_id] + a_2 = model_answers[baseline_model][q_id] + if ref_answers is not None: + ref = ref_answers[judge.model_name][q_id] + match = MatchPair( + dict(q), + m_1, + m_2, + a_1, + a_2, + judge, + ref_answer=ref, + multi_turn=multi_turn, + ) + else: + match = MatchPair( + dict(q), m_1, m_2, a_1, a_2, judge, multi_turn=multi_turn + ) + matches.append(match) + return matches + + +def make_match_all_pairs( + questions, + models, + model_answers, + judge, + baseline_model=None, + ref_answers=None, + multi_turn=False, +): + matches = [] + for q in questions: + if multi_turn and len(q["turns"]) != 2: + continue + for i in range(len(models)): + for j in range(i + 1, len(models)): + q_id = q["question_id"] + m_1 = models[i] + m_2 = models[j] + a_1 = model_answers[m_1][q_id] + a_2 = model_answers[m_2][q_id] + if ref_answers is not None: + ref = ref_answers[judge.model_name][q_id] + match = MatchPair( + dict(q), + m_1, + m_2, + a_1, + a_2, + judge, + ref_answer=ref, + multi_turn=multi_turn, + ) + else: + match = MatchPair( + dict(q), m_1, m_2, a_1, a_2, judge, multi_turn=multi_turn + ) + matches.append(match) + return matches + + +def make_match_single( + questions, + models, + model_answers, + judge, + baseline_model=None, + ref_answers=None, + multi_turn=False, +): + matches = [] + for q in questions: + if multi_turn and len(q["turns"]) != 2: + continue + for i in range(len(models)): + q_id = q["question_id"] + m = models[i] + a = model_answers[m][q_id] + if ref_answers is not None: + ref = ref_answers[judge.model_name][q_id] + matches.append( + MatchSingle( + dict(q), m, a, judge, ref_answer=ref, multi_turn=multi_turn + ) + ) + else: + matches.append(MatchSingle(dict(q), m, a, judge, multi_turn=multi_turn)) + return matches + + +def make_judge_pairwise(judge_model, judge_prompts): + judges = {} + judges["default"] = Judge(judge_model, judge_prompts["pair-v2"]) + judges["math"] = Judge(judge_model, judge_prompts["pair-math-v1"], ref_based=True) + judges["default-mt"] = Judge( + judge_model, judge_prompts["pair-v2-multi-turn"], multi_turn=True + ) + judges["math-mt"] = Judge( + judge_model, + judge_prompts["pair-math-v1-multi-turn"], + ref_based=True, + multi_turn=True, + ) + return judges + + +def make_judge_single(judge_model, judge_prompts): + judges = {} + judges["default"] = Judge(judge_model, judge_prompts["single-v1"]) + judges["math"] = Judge(judge_model, judge_prompts["single-math-v1"], ref_based=True) + judges["default-mt"] = Judge( + judge_model, judge_prompts["single-v1-multi-turn"], multi_turn=True + ) + judges["math-mt"] = Judge( + judge_model, + judge_prompts["single-math-v1-multi-turn"], + ref_based=True, + multi_turn=True, + ) + return judges + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--bench-name", + type=str, + default="mt_bench", + help="The name of the benchmark question set.", + ) + parser.add_argument( + "--judge-file", + type=str, + default="data/judge_prompts.jsonl", + help="The file of judge prompts.", + ) + parser.add_argument("--judge-model", type=str, default="gpt-4") + parser.add_argument("--baseline-model", type=str, default="gpt-3.5-turbo") + parser.add_argument( + "--mode", + type=str, + default="single", + choices=["pairwise-baseline", "pairwise-all", "single"], + help=( + "Evaluation mode. " + "`pairwise-baseline` runs pairwise comparision against a baseline. " + "`pairwise-all` runs pairwise comparision between all pairs. " + "`single` runs single answer grading." + ), + ) + parser.add_argument( + "--model-list", + type=str, + nargs="+", + default=None, + help="A list of models to be evaluated", + ) + parser.add_argument( + "--parallel", type=int, default=1, help="The number of concurrent API calls." + ) + parser.add_argument( + "--first-n", type=int, help="A debug option. Only run the first `n` judgments." + ) + args = parser.parse_args() + + question_file = f"data/{args.bench_name}/question.jsonl" + answer_dir = f"data/{args.bench_name}/model_answer" + ref_answer_dir = f"data/{args.bench_name}/reference_answer" + + # Load questions + questions = load_questions(question_file, None, None) + + # Load answers + model_answers = load_model_answers(answer_dir) + ref_answers = load_model_answers(ref_answer_dir) + + # Load judge + judge_prompts = load_judge_prompts(args.judge_file) + + if args.first_n: + questions = questions[: args.first_n] + + if args.model_list is None: + models = get_model_list(answer_dir) + else: + models = args.model_list + + if args.mode == "single": + judges = make_judge_single(args.judge_model, judge_prompts) + play_a_match_func = play_a_match_single + output_file = ( + f"data/{args.bench_name}/model_judgment/{args.judge_model}_single.jsonl" + ) + make_match_func = make_match_single + baseline_model = None + else: + judges = make_judge_pairwise(args.judge_model, judge_prompts) + play_a_match_func = play_a_match_pair + output_file = ( + f"data/{args.bench_name}/model_judgment/{args.judge_model}_pair.jsonl" + ) + if args.mode == "pairwise-all": + make_match_func = make_match_all_pairs + baseline_model = None + else: + make_match_func = make_match + baseline_model = args.baseline_model + + check_data(questions, model_answers, ref_answers, models, judges) + + question_math = [q for q in questions if q["category"] in NEED_REF_CATS] + question_default = [q for q in questions if q["category"] not in NEED_REF_CATS] + + # Make matches + matches = [] + matches += make_match_func( + question_default, models, model_answers, judges["default"], baseline_model + ) + matches += make_match_func( + question_math, + models, + model_answers, + judges["math"], + baseline_model, + ref_answers, + ) + matches += make_match_func( + question_default, + models, + model_answers, + judges["default-mt"], + baseline_model, + multi_turn=True, + ) + matches += make_match_func( + question_math, + models, + model_answers, + judges["math-mt"], + baseline_model, + ref_answers, + multi_turn=True, + ) + + match_stat = {} + match_stat["bench_name"] = args.bench_name + match_stat["mode"] = args.mode + match_stat["judge"] = args.judge_model + match_stat["baseline"] = baseline_model + match_stat["model_list"] = models + match_stat["total_num_questions"] = len(questions) + match_stat["total_num_matches"] = len(matches) + match_stat["output_path"] = output_file + + # Show match stats and prompt enter to continue + print("Stats:") + print(json.dumps(match_stat, indent=4)) + input("Press Enter to confirm...") + + # Play matches + if args.parallel == 1: + for match in tqdm(matches): + play_a_match_func(match, output_file=output_file) + else: + + def play_a_match_wrapper(match): + play_a_match_func(match, output_file=output_file) + + np.random.seed(0) + np.random.shuffle(matches) + + with ThreadPoolExecutor(args.parallel) as executor: + for match in tqdm( + executor.map(play_a_match_wrapper, matches), total=len(matches) + ): + pass diff --git a/third_party/FastChat/fastchat/llm_judge/gen_model_answer.py b/third_party/FastChat/fastchat/llm_judge/gen_model_answer.py new file mode 100644 index 00000000..44166fb0 --- /dev/null +++ b/third_party/FastChat/fastchat/llm_judge/gen_model_answer.py @@ -0,0 +1,304 @@ +"""Generate answers with local models. + +Usage: +python3 gen_model_answer.py --model-path lmsys/fastchat-t5-3b-v1.0 --model-id fastchat-t5-3b-v1.0 +""" +import argparse +import json +import os +import random +import time + +import shortuuid +import torch +from tqdm import tqdm + +from fastchat.llm_judge.common import load_questions, temperature_config +from fastchat.model import load_model, get_conversation_template +from fastchat.utils import str_to_torch_dtype + + +def run_eval( + model_path, + model_id, + question_file, + question_begin, + question_end, + answer_file, + max_new_token, + num_choices, + num_gpus_per_model, + num_gpus_total, + max_gpu_memory, + dtype, + revision, +): + questions = load_questions(question_file, question_begin, question_end) + # random shuffle the questions to balance the loading + random.shuffle(questions) + + # Split the question file into `num_gpus` files + assert num_gpus_total % num_gpus_per_model == 0 + use_ray = num_gpus_total // num_gpus_per_model > 1 + + if use_ray: + get_answers_func = ray.remote(num_gpus=num_gpus_per_model)( + get_model_answers + ).remote + else: + get_answers_func = get_model_answers + + chunk_size = len(questions) // (num_gpus_total // num_gpus_per_model) + ans_handles = [] + for i in range(0, len(questions), chunk_size): + ans_handles.append( + get_answers_func( + model_path, + model_id, + questions[i : i + chunk_size], + answer_file, + max_new_token, + num_choices, + num_gpus_per_model, + max_gpu_memory, + dtype=dtype, + revision=revision, + ) + ) + + if use_ray: + ray.get(ans_handles) + + +@torch.inference_mode() +def get_model_answers( + model_path, + model_id, + questions, + answer_file, + max_new_token, + num_choices, + num_gpus_per_model, + max_gpu_memory, + dtype, + revision, +): + model, tokenizer = load_model( + model_path, + revision=revision, + device="cuda", + num_gpus=num_gpus_per_model, + max_gpu_memory=max_gpu_memory, + dtype=dtype, + load_8bit=False, + cpu_offloading=False, + debug=False, + ) + + for question in tqdm(questions): + if question["category"] in temperature_config: + temperature = temperature_config[question["category"]] + else: + temperature = 0.7 + + choices = [] + for i in range(num_choices): + torch.manual_seed(i) + conv = get_conversation_template(model_id) + turns = [] + for j in range(len(question["turns"])): + qs = question["turns"][j] + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + input_ids = tokenizer([prompt]).input_ids + + if temperature < 1e-4: + do_sample = False + else: + do_sample = True + + # some models may error out when generating long outputs + try: + output_ids = model.generate( + torch.as_tensor(input_ids).cuda(), + do_sample=do_sample, + temperature=temperature, + max_new_tokens=max_new_token, + ) + if model.config.is_encoder_decoder: + output_ids = output_ids[0] + else: + output_ids = output_ids[0][len(input_ids[0]) :] + + # be consistent with the template's stop_token_ids + if conv.stop_token_ids: + stop_token_ids_index = [ + i + for i, id in enumerate(output_ids) + if id in conv.stop_token_ids + ] + if len(stop_token_ids_index) > 0: + output_ids = output_ids[: stop_token_ids_index[0]] + + output = tokenizer.decode( + output_ids, + spaces_between_special_tokens=False, + ) + if conv.stop_str and isinstance(conv.stop_str, list): + stop_str_indices = sorted( + [ + output.find(stop_str) + for stop_str in conv.stop_str + if output.find(stop_str) > 0 + ] + ) + if len(stop_str_indices) > 0: + output = output[: stop_str_indices[0]] + elif conv.stop_str and output.find(conv.stop_str) > 0: + output = output[: output.find(conv.stop_str)] + + for special_token in tokenizer.special_tokens_map.values(): + if isinstance(special_token, list): + for special_tok in special_token: + output = output.replace(special_tok, "") + else: + output = output.replace(special_token, "") + output = output.strip() + + if conv.name == "xgen" and output.startswith("Assistant:"): + output = output.replace("Assistant:", "", 1).strip() + except RuntimeError as e: + print("ERROR question ID: ", question["question_id"]) + output = "ERROR" + + conv.update_last_message(output) + turns.append(output) + + choices.append({"index": i, "turns": turns}) + + # Dump answers + os.makedirs(os.path.dirname(answer_file), exist_ok=True) + with open(os.path.expanduser(answer_file), "a") as fout: + ans_json = { + "question_id": question["question_id"], + "answer_id": shortuuid.uuid(), + "model_id": model_id, + "choices": choices, + "tstamp": time.time(), + } + fout.write(json.dumps(ans_json) + "\n") + + +def reorg_answer_file(answer_file): + """Sort by question id and de-duplication""" + answers = {} + with open(answer_file, "r") as fin: + for l in fin: + qid = json.loads(l)["question_id"] + answers[qid] = l + + qids = sorted(list(answers.keys())) + with open(answer_file, "w") as fout: + for qid in qids: + fout.write(answers[qid]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--model-path", + type=str, + required=True, + help="The path to the weights. This can be a local folder or a Hugging Face repo ID.", + ) + parser.add_argument( + "--model-id", type=str, required=True, help="A custom name for the model." + ) + parser.add_argument( + "--bench-name", + type=str, + default="mt_bench", + help="The name of the benchmark question set.", + ) + parser.add_argument( + "--question-begin", + type=int, + help="A debug option. The begin index of questions.", + ) + parser.add_argument( + "--question-end", type=int, help="A debug option. The end index of questions." + ) + parser.add_argument("--answer-file", type=str, help="The output answer file.") + parser.add_argument( + "--max-new-token", + type=int, + default=1024, + help="The maximum number of new generated tokens.", + ) + parser.add_argument( + "--num-choices", + type=int, + default=1, + help="How many completion choices to generate.", + ) + parser.add_argument( + "--num-gpus-per-model", + type=int, + default=1, + help="The number of GPUs per model.", + ) + parser.add_argument( + "--num-gpus-total", type=int, default=1, help="The total number of GPUs." + ) + parser.add_argument( + "--max-gpu-memory", + type=str, + help="Maxmum GPU memory used for model weights per GPU.", + ) + parser.add_argument( + "--dtype", + type=str, + choices=["float32", "float16", "bfloat16"], + help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.", + default=None, + ) + parser.add_argument( + "--revision", + type=str, + default="main", + help="The model revision to load.", + ) + + args = parser.parse_args() + + if args.num_gpus_total // args.num_gpus_per_model > 1: + import ray + + ray.init() + + question_file = f"data/{args.bench_name}/question.jsonl" + if args.answer_file: + answer_file = args.answer_file + else: + answer_file = f"data/{args.bench_name}/model_answer/{args.model_id}.jsonl" + + print(f"Output to {answer_file}") + + run_eval( + model_path=args.model_path, + model_id=args.model_id, + question_file=question_file, + question_begin=args.question_begin, + question_end=args.question_end, + answer_file=answer_file, + max_new_token=args.max_new_token, + num_choices=args.num_choices, + num_gpus_per_model=args.num_gpus_per_model, + num_gpus_total=args.num_gpus_total, + max_gpu_memory=args.max_gpu_memory, + dtype=str_to_torch_dtype(args.dtype), + revision=args.revision, + ) + + reorg_answer_file(answer_file) diff --git a/third_party/FastChat/fastchat/llm_judge/qa_browser.py b/third_party/FastChat/fastchat/llm_judge/qa_browser.py new file mode 100644 index 00000000..e449dee3 --- /dev/null +++ b/third_party/FastChat/fastchat/llm_judge/qa_browser.py @@ -0,0 +1,418 @@ +""" +Usage: +python3 qa_browser.py --share +""" + +import argparse +from collections import defaultdict +import re + +import gradio as gr + +from fastchat.llm_judge.common import ( + load_questions, + load_model_answers, + load_single_model_judgments, + load_pairwise_model_judgments, + resolve_single_judgment_dict, + resolve_pairwise_judgment_dict, + get_single_judge_explanation, + get_pairwise_judge_explanation, +) + + +questions = [] +model_answers = {} + +model_judgments_normal_single = {} +model_judgments_math_single = {} + +model_judgments_normal_pairwise = {} +model_judgments_math_pairwise = {} + +question_selector_map = {} +category_selector_map = defaultdict(list) + + +def display_question(category_selector, request: gr.Request): + choices = category_selector_map[category_selector] + return gr.Dropdown.update( + value=choices[0], + choices=choices, + ) + + +def display_pairwise_answer( + question_selector, model_selector1, model_selector2, request: gr.Request +): + q = question_selector_map[question_selector] + qid = q["question_id"] + + ans1 = model_answers[model_selector1][qid] + ans2 = model_answers[model_selector2][qid] + + chat_mds = pairwise_to_gradio_chat_mds(q, ans1, ans2) + gamekey = (qid, model_selector1, model_selector2) + + judgment_dict = resolve_pairwise_judgment_dict( + q, + model_judgments_normal_pairwise, + model_judgments_math_pairwise, + multi_turn=False, + ) + + explanation = ( + "##### Model Judgment (first turn)\n" + + get_pairwise_judge_explanation(gamekey, judgment_dict) + ) + + judgment_dict_turn2 = resolve_pairwise_judgment_dict( + q, + model_judgments_normal_pairwise, + model_judgments_math_pairwise, + multi_turn=True, + ) + + explanation_turn2 = ( + "##### Model Judgment (second turn)\n" + + get_pairwise_judge_explanation(gamekey, judgment_dict_turn2) + ) + + return chat_mds + [explanation] + [explanation_turn2] + + +def display_single_answer(question_selector, model_selector1, request: gr.Request): + q = question_selector_map[question_selector] + qid = q["question_id"] + + ans1 = model_answers[model_selector1][qid] + + chat_mds = single_to_gradio_chat_mds(q, ans1) + gamekey = (qid, model_selector1) + + judgment_dict = resolve_single_judgment_dict( + q, model_judgments_normal_single, model_judgments_math_single, multi_turn=False + ) + + explanation = "##### Model Judgment (first turn)\n" + get_single_judge_explanation( + gamekey, judgment_dict + ) + + judgment_dict_turn2 = resolve_single_judgment_dict( + q, model_judgments_normal_single, model_judgments_math_single, multi_turn=True + ) + + explanation_turn2 = ( + "##### Model Judgment (second turn)\n" + + get_single_judge_explanation(gamekey, judgment_dict_turn2) + ) + + return chat_mds + [explanation] + [explanation_turn2] + + +newline_pattern1 = re.compile("\n\n(\d+\. )") +newline_pattern2 = re.compile("\n\n(- )") + + +def post_process_answer(x): + """Fix Markdown rendering problems.""" + x = x.replace("\u2022", "- ") + x = re.sub(newline_pattern1, "\n\g<1>", x) + x = re.sub(newline_pattern2, "\n\g<1>", x) + return x + + +def pairwise_to_gradio_chat_mds(question, ans_a, ans_b, turn=None): + end = len(question["turns"]) if turn is None else turn + 1 + + mds = ["", "", "", "", "", "", ""] + for i in range(end): + base = i * 3 + if i == 0: + mds[base + 0] = "##### User\n" + question["turns"][i] + else: + mds[base + 0] = "##### User's follow-up question \n" + question["turns"][i] + mds[base + 1] = "##### Assistant A\n" + post_process_answer( + ans_a["choices"][0]["turns"][i].strip() + ) + mds[base + 2] = "##### Assistant B\n" + post_process_answer( + ans_b["choices"][0]["turns"][i].strip() + ) + + ref = question.get("reference", ["", ""]) + + ref_md = "" + if turn is None: + if ref[0] != "" or ref[1] != "": + mds[6] = f"##### Reference Solution\nQ1. {ref[0]}\nQ2. {ref[1]}" + else: + x = ref[turn] if turn < len(ref) else "" + if x: + mds[6] = f"##### Reference Solution\n{ref[turn]}" + else: + mds[6] = "" + return mds + + +def single_to_gradio_chat_mds(question, ans, turn=None): + end = len(question["turns"]) if turn is None else turn + 1 + + mds = ["", "", "", "", ""] + for i in range(end): + base = i * 2 + if i == 0: + mds[base + 0] = "##### User\n" + question["turns"][i] + else: + mds[base + 0] = "##### User's follow-up question \n" + question["turns"][i] + mds[base + 1] = "##### Assistant A\n" + post_process_answer( + ans["choices"][0]["turns"][i].strip() + ) + + ref = question.get("reference", ["", ""]) + + ref_md = "" + if turn is None: + if ref[0] != "" or ref[1] != "": + mds[4] = f"##### Reference Solution\nQ1. {ref[0]}\nQ2. {ref[1]}" + else: + x = ref[turn] if turn < len(ref) else "" + if x: + mds[4] = f"##### Reference Solution\n{ref[turn]}" + else: + mds[4] = "" + return mds + + +def build_question_selector_map(): + global question_selector_map, category_selector_map + + # Build question selector map + for q in questions: + preview = f"{q['question_id']}: " + q["turns"][0][:128] + "..." + question_selector_map[preview] = q + category_selector_map[q["category"]].append(preview) + + +def build_pairwise_browser_tab(): + global question_selector_map, category_selector_map + + models = list(model_answers.keys()) + num_sides = 2 + num_turns = 2 + side_names = ["A", "B"] + + question_selector_choices = list(question_selector_map.keys()) + category_selector_choices = list(category_selector_map.keys()) + + # Selectors + with gr.Row(): + with gr.Column(scale=1, min_width=200): + category_selector = gr.Dropdown( + choices=category_selector_choices, label="Category", container=False + ) + with gr.Column(scale=100): + question_selector = gr.Dropdown( + choices=question_selector_choices, label="Question", container=False + ) + + model_selectors = [None] * num_sides + with gr.Row(): + for i in range(num_sides): + with gr.Column(): + if i == 0: + value = models[0] + else: + value = "gpt-3.5-turbo" + model_selectors[i] = gr.Dropdown( + choices=models, + value=value, + label=f"Model {side_names[i]}", + container=False, + ) + + # Conversation + chat_mds = [] + for i in range(num_turns): + chat_mds.append(gr.Markdown(elem_id=f"user_question_{i+1}")) + with gr.Row(): + for j in range(num_sides): + with gr.Column(scale=100): + chat_mds.append(gr.Markdown()) + + if j == 0: + with gr.Column(scale=1, min_width=8): + gr.Markdown() + reference = gr.Markdown(elem_id=f"reference") + chat_mds.append(reference) + + model_explanation = gr.Markdown(elem_id="model_explanation") + model_explanation2 = gr.Markdown(elem_id="model_explanation") + + # Callbacks + category_selector.change(display_question, [category_selector], [question_selector]) + question_selector.change( + display_pairwise_answer, + [question_selector] + model_selectors, + chat_mds + [model_explanation] + [model_explanation2], + ) + + for i in range(num_sides): + model_selectors[i].change( + display_pairwise_answer, + [question_selector] + model_selectors, + chat_mds + [model_explanation] + [model_explanation2], + ) + + return (category_selector,) + + +def build_single_answer_browser_tab(): + global question_selector_map, category_selector_map + + models = list(model_answers.keys()) + num_sides = 1 + num_turns = 2 + side_names = ["A"] + + question_selector_choices = list(question_selector_map.keys()) + category_selector_choices = list(category_selector_map.keys()) + + # Selectors + with gr.Row(): + with gr.Column(scale=1, min_width=200): + category_selector = gr.Dropdown( + choices=category_selector_choices, label="Category", container=False + ) + with gr.Column(scale=100): + question_selector = gr.Dropdown( + choices=question_selector_choices, label="Question", container=False + ) + + model_selectors = [None] * num_sides + with gr.Row(): + for i in range(num_sides): + with gr.Column(): + model_selectors[i] = gr.Dropdown( + choices=models, + value=models[i] if len(models) > i else "", + label=f"Model {side_names[i]}", + container=False, + ) + + # Conversation + chat_mds = [] + for i in range(num_turns): + chat_mds.append(gr.Markdown(elem_id=f"user_question_{i+1}")) + with gr.Row(): + for j in range(num_sides): + with gr.Column(scale=100): + chat_mds.append(gr.Markdown()) + + if j == 0: + with gr.Column(scale=1, min_width=8): + gr.Markdown() + + reference = gr.Markdown(elem_id=f"reference") + chat_mds.append(reference) + + model_explanation = gr.Markdown(elem_id="model_explanation") + model_explanation2 = gr.Markdown(elem_id="model_explanation") + + # Callbacks + category_selector.change(display_question, [category_selector], [question_selector]) + question_selector.change( + display_single_answer, + [question_selector] + model_selectors, + chat_mds + [model_explanation] + [model_explanation2], + ) + + for i in range(num_sides): + model_selectors[i].change( + display_single_answer, + [question_selector] + model_selectors, + chat_mds + [model_explanation] + [model_explanation2], + ) + + return (category_selector,) + + +block_css = """ +#user_question_1 { + background-color: #DEEBF7; +} +#user_question_2 { + background-color: #E2F0D9; +} +#reference { + background-color: #FFF2CC; +} +#model_explanation { + background-color: #FBE5D6; +} +""" + + +def load_demo(): + dropdown_update = gr.Dropdown.update(value=list(category_selector_map.keys())[0]) + return dropdown_update, dropdown_update + + +def build_demo(): + build_question_selector_map() + + with gr.Blocks( + title="MT-Bench Browser", + theme=gr.themes.Base(text_size=gr.themes.sizes.text_lg), + css=block_css, + ) as demo: + gr.Markdown( + """ +# MT-Bench Browser +The code to generate answers and judgments is at [fastchat.llm_judge](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge). +""" + ) + with gr.Tab("Single Answer Grading"): + (category_selector,) = build_single_answer_browser_tab() + with gr.Tab("Pairwise Comparison"): + (category_selector2,) = build_pairwise_browser_tab() + demo.load(load_demo, [], [category_selector, category_selector2]) + + return demo + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument("--share", action="store_true") + parser.add_argument("--bench-name", type=str, default="mt_bench") + args = parser.parse_args() + print(args) + + question_file = f"data/{args.bench_name}/question.jsonl" + answer_dir = f"data/{args.bench_name}/model_answer" + pairwise_model_judgment_file = ( + f"data/{args.bench_name}/model_judgment/gpt-4_pair.jsonl" + ) + single_model_judgment_file = ( + f"data/{args.bench_name}/model_judgment/gpt-4_single.jsonl" + ) + + # Load questions + questions = load_questions(question_file, None, None) + + # Load answers + model_answers = load_model_answers(answer_dir) + + # Load model judgments + model_judgments_normal_single = ( + model_judgments_math_single + ) = load_single_model_judgments(single_model_judgment_file) + model_judgments_normal_pairwise = ( + model_judgments_math_pairwise + ) = load_pairwise_model_judgments(pairwise_model_judgment_file) + + demo = build_demo() + demo.queue(concurrency_count=10, status_update_rate=10, api_open=False).launch( + server_name=args.host, server_port=args.port, share=args.share, max_threads=200 + ) diff --git a/third_party/FastChat/fastchat/llm_judge/show_result.py b/third_party/FastChat/fastchat/llm_judge/show_result.py new file mode 100644 index 00000000..f20801b1 --- /dev/null +++ b/third_party/FastChat/fastchat/llm_judge/show_result.py @@ -0,0 +1,130 @@ +""" +Usage: +python3 show_result.py --mode [single|pairwise-baseline|pairwise-all] +""" +import argparse +import pandas as pd + + +def display_result_single(args): + if args.input_file is None: + input_file = ( + f"data/{args.bench_name}/model_judgment/{args.judge_model}_single.jsonl" + ) + else: + input_file = args.input_file + + print(f"Input file: {input_file}") + df_all = pd.read_json(input_file, lines=True) + df = df_all[["model", "score", "turn"]] + df = df[df["score"] != -1] + + if args.model_list is not None: + df = df[df["model"].isin(args.model_list)] + + print("\n########## First turn ##########") + df_1 = df[df["turn"] == 1].groupby(["model", "turn"]).mean() + print(df_1.sort_values(by="score", ascending=False)) + + if args.bench_name == "mt_bench": + print("\n########## Second turn ##########") + df_2 = df[df["turn"] == 2].groupby(["model", "turn"]).mean() + print(df_2.sort_values(by="score", ascending=False)) + + print("\n########## Average ##########") + df_3 = df[["model", "score"]].groupby(["model"]).mean() + print(df_3.sort_values(by="score", ascending=False)) + + +def display_result_pairwise(args): + if args.input_file is None: + input_file = ( + f"data/{args.bench_name}/model_judgment/{args.judge_model}_pair.jsonl" + ) + else: + input_file = args.input_file + + print(f"Input file: {input_file}") + df_all = pd.read_json(input_file, lines=True) + df_all = df_all[(df_all["g1_winner"] != "error") & (df_all["g2_winner"] != "error")] + + model_list = ( + df_all["model_1"].unique().tolist() + df_all["model_2"].unique().tolist() + ) + model_list = list(set(model_list)) + + list_res = [] + # traverse df row by row + for index, row in df_all.iterrows(): + if args.model_list is not None and row["model_1"] not in args.model_list: + continue + if args.baseline_model is not None: + if args.baseline_model not in [row["model_1"], row["model_2"]]: + continue + if row["g1_winner"] == "tie" or row["g1_winner"] != row["g2_winner"]: + list_res.append({"model": row["model_1"], "win": 0, "loss": 0, "tie": 1}) + list_res.append({"model": row["model_2"], "win": 0, "loss": 0, "tie": 1}) + else: + if row["g1_winner"] == "model_1": + winner = row["model_1"] + loser = row["model_2"] + else: + winner = row["model_2"] + loser = row["model_1"] + list_res.append({"model": winner, "win": 1, "loss": 0, "tie": 0}) + list_res.append({"model": loser, "win": 0, "loss": 1, "tie": 0}) + + df = pd.DataFrame(list_res) + df = df.groupby(["model"]).sum() + + # remove baseline model + if args.baseline_model is not None: + df = df[df.index != args.baseline_model] + # add win rate + df["win_rate"] = df["win"] / (df["win"] + df["loss"] + df["tie"]) + df["loss_rate"] = df["loss"] / (df["win"] + df["loss"] + df["tie"]) + # each tie counts as 0.5 win + 0.5 loss + df["win_rate_adjusted"] = (df["win"] + 0.5 * df["tie"]) / ( + df["win"] + df["loss"] + df["tie"] + ) + # print(df.sort_values(by="win_rate", ascending=False)) + # print(df.sort_values(by="loss_rate", ascending=True)) + print(df.sort_values(by="win_rate_adjusted", ascending=False)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--bench-name", type=str, default="mt_bench") + parser.add_argument("--input-file", type=str) + parser.add_argument("--judge-model", type=str, default="gpt-4") + parser.add_argument("--baseline-model", type=str, default="gpt-3.5-turbo") + parser.add_argument( + "--model-list", + type=str, + nargs="+", + default=None, + help="A list of models to be evaluated", + ) + parser.add_argument( + "--mode", + type=str, + default="single", + choices=["pairwise-baseline", "pairwise-all", "single"], + help=( + "Evaluation mode. " + "`pairwise-baseline` runs pairwise comparision against a baseline. " + "`pairwise-all` runs pairwise comparision between all pairs. " + "`single` runs single answer grading." + ), + ) + args = parser.parse_args() + + if args.mode == "single": + display_result_func = display_result_single + else: + if args.mode == "pairwise-all": + args.baseline_model = None + display_result_func = display_result_pairwise + + print(f"Mode: {args.mode}") + display_result_func(args) diff --git a/third_party/FastChat/fastchat/model/__init__.py b/third_party/FastChat/fastchat/model/__init__.py new file mode 100644 index 00000000..29767dce --- /dev/null +++ b/third_party/FastChat/fastchat/model/__init__.py @@ -0,0 +1,5 @@ +from fastchat.model.model_adapter import ( + load_model, + get_conversation_template, + add_model_args, +) diff --git a/third_party/FastChat/fastchat/model/apply_delta.py b/third_party/FastChat/fastchat/model/apply_delta.py new file mode 100644 index 00000000..ba1c06d4 --- /dev/null +++ b/third_party/FastChat/fastchat/model/apply_delta.py @@ -0,0 +1,165 @@ +""" +Apply the delta weights on top of a base model. + +Usage: +python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta-v1.1 +""" +import argparse +import gc +import glob +import json +import os +import shutil +import tempfile + +from huggingface_hub import snapshot_download +import torch +from torch import nn +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig + + +GB = 1 << 30 + + +def split_files(model_path, tmp_path, split_size): + if not os.path.exists(model_path): + model_path = snapshot_download(repo_id=model_path) + if not os.path.exists(tmp_path): + os.makedirs(tmp_path) + + file_pattern = os.path.join(model_path, "pytorch_model-*.bin") + files = glob.glob(file_pattern) + + part = 0 + try: + for file_path in tqdm(files): + state_dict = torch.load(file_path) + new_state_dict = {} + + current_size = 0 + for name, param in state_dict.items(): + param_size = param.numel() * param.element_size() + + if current_size + param_size > split_size: + new_file_name = f"pytorch_model-{part}.bin" + new_file_path = os.path.join(tmp_path, new_file_name) + torch.save(new_state_dict, new_file_path) + current_size = 0 + new_state_dict = None + gc.collect() + new_state_dict = {} + part += 1 + + new_state_dict[name] = param + current_size += param_size + + new_file_name = f"pytorch_model-{part}.bin" + new_file_path = os.path.join(tmp_path, new_file_name) + torch.save(new_state_dict, new_file_path) + new_state_dict = None + gc.collect() + new_state_dict = {} + part += 1 + except Exception as e: + print(f"An error occurred during split_files: {e}") + shutil.rmtree(tmp_path) + raise + + +def apply_delta_low_cpu_mem(base_model_path, target_model_path, delta_path): + delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False) + delta_config = AutoConfig.from_pretrained(delta_path) + + if os.path.exists(target_model_path): + shutil.rmtree(target_model_path) + os.makedirs(target_model_path) + + split_size = 4 * GB + + with tempfile.TemporaryDirectory() as tmp_base_path, tempfile.TemporaryDirectory() as tmp_delta_path: + print(f"Split files for the base model to {tmp_base_path}") + split_files(base_model_path, tmp_base_path, split_size) + print(f"Split files for the delta weights to {tmp_delta_path}") + split_files(delta_path, tmp_delta_path, split_size) + + base_pattern = os.path.join(tmp_base_path, "pytorch_model-*.bin") + base_files = glob.glob(base_pattern) + delta_pattern = os.path.join(tmp_delta_path, "pytorch_model-*.bin") + delta_files = glob.glob(delta_pattern) + delta_state_dict = torch.load(delta_files[0]) + + print("Applying the delta") + weight_map = {} + total_size = 0 + + for i, base_file in tqdm(enumerate(base_files)): + state_dict = torch.load(base_file) + file_name = f"pytorch_model-{i}.bin" + for name, param in state_dict.items(): + if name not in delta_state_dict: + for delta_file in delta_files: + delta_state_dict = torch.load(delta_file) + gc.collect() + if name in delta_state_dict: + break + + state_dict[name] += delta_state_dict[name] + weight_map[name] = file_name + total_size += param.numel() * param.element_size() + gc.collect() + torch.save(state_dict, os.path.join(target_model_path, file_name)) + + with open( + os.path.join(target_model_path, "pytorch_model.bin.index.json"), "w" + ) as f: + json.dump( + {"weight_map": weight_map, "metadata": {"total_size": total_size}}, f + ) + + print(f"Saving the target model to {target_model_path}") + delta_tokenizer.save_pretrained(target_model_path) + delta_config.save_pretrained(target_model_path) + + +def apply_delta(base_model_path, target_model_path, delta_path): + print(f"Loading the delta weights from {delta_path}") + delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False) + delta = AutoModelForCausalLM.from_pretrained( + delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + + print(f"Loading the base model from {base_model_path}") + base = AutoModelForCausalLM.from_pretrained( + base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + + print("Applying the delta") + for name, param in tqdm(base.state_dict().items(), desc="Applying delta"): + assert name in delta.state_dict() + param.data += delta.state_dict()[name] + + print(f"Saving the target model to {target_model_path}") + base.save_pretrained(target_model_path) + delta_tokenizer.save_pretrained(target_model_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument("--delta-path", type=str, required=True) + parser.add_argument( + "--low-cpu-mem", + action="store_true", + help="Lower the cpu memory usage. This will split large files and use " + "disk as swap to reduce the memory usage below 10GB.", + ) + args = parser.parse_args() + + if args.low_cpu_mem: + apply_delta_low_cpu_mem( + args.base_model_path, args.target_model_path, args.delta_path + ) + else: + apply_delta(args.base_model_path, args.target_model_path, args.delta_path) diff --git a/third_party/FastChat/fastchat/model/apply_lora.py b/third_party/FastChat/fastchat/model/apply_lora.py new file mode 100644 index 00000000..01263dcc --- /dev/null +++ b/third_party/FastChat/fastchat/model/apply_lora.py @@ -0,0 +1,48 @@ +""" +Apply the LoRA weights on top of a base model. + +Usage: +python3 -m fastchat.model.apply_lora --base ~/model_weights/llama-7b --target ~/model_weights/baize-7b --lora project-baize/baize-lora-7B + +Dependency: +pip3 install git+https://github.com/huggingface/peft.git@2822398fbe896f25d4dac5e468624dc5fd65a51b +""" +import argparse + +import torch +from peft import PeftModel +from transformers import AutoTokenizer, AutoModelForCausalLM + + +def apply_lora(base_model_path, target_model_path, lora_path): + print(f"Loading the base model from {base_model_path}") + base = AutoModelForCausalLM.from_pretrained( + base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + base_tokenizer = AutoTokenizer.from_pretrained(base_model_path, use_fast=False) + + print(f"Loading the LoRA adapter from {lora_path}") + + lora_model = PeftModel.from_pretrained( + base, + lora_path, + # torch_dtype=torch.float16 + ) + + print("Applying the LoRA") + model = lora_model.merge_and_unload() + + print(f"Saving the target model to {target_model_path}") + model.save_pretrained(target_model_path) + base_tokenizer.save_pretrained(target_model_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument("--lora-path", type=str, required=True) + + args = parser.parse_args() + + apply_lora(args.base_model_path, args.target_model_path, args.lora_path) diff --git a/third_party/FastChat/fastchat/model/compression.py b/third_party/FastChat/fastchat/model/compression.py new file mode 100644 index 00000000..7329cfe0 --- /dev/null +++ b/third_party/FastChat/fastchat/model/compression.py @@ -0,0 +1,312 @@ +import dataclasses +import gc +import glob +import os + +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from huggingface_hub import snapshot_download +import torch +from torch import Tensor +from torch.nn import functional as F +import torch.nn as nn +from tqdm import tqdm +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + AutoModel, + AutoModelForSeq2SeqLM, +) + + +@dataclasses.dataclass +class CompressionConfig: + """Group-wise quantization.""" + + num_bits: int + group_size: int + group_dim: int + symmetric: bool + enabled: bool = True + + +default_compression_config = CompressionConfig( + num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True +) + + +class CLinear(nn.Module): + """Compressed Linear Layer.""" + + def __init__(self, weight=None, bias=None, device=None): + super().__init__() + if weight is None: + self.weight = None + elif isinstance(weight, Tensor): + self.weight = compress(weight.data.to(device), default_compression_config) + else: + self.weight = weight + self.bias = bias + + def forward(self, input: Tensor) -> Tensor: + weight = decompress(self.weight, default_compression_config) + if self.bias is None: + return F.linear(input.to(weight.dtype), weight) + return F.linear(input.to(weight.dtype), weight, self.bias.to(weight.dtype)) + + +def compress_module(module, target_device): + for attr_str in dir(module): + target_attr = getattr(module, attr_str) + if type(target_attr) == torch.nn.Linear: + setattr( + module, + attr_str, + CLinear(target_attr.weight, target_attr.bias, target_device), + ) + for name, child in module.named_children(): + compress_module(child, target_device) + + +def get_compressed_list(module, prefix=""): + compressed_list = [] + for attr_str in dir(module): + target_attr = getattr(module, attr_str) + if type(target_attr) == torch.nn.Linear: + full_name = ( + f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight" + ) + compressed_list.append(full_name) + for name, child in module.named_children(): + child_prefix = f"{prefix}.{name}" if prefix else name + for each in get_compressed_list(child, child_prefix): + compressed_list.append(each) + return compressed_list + + +def apply_compressed_weight(module, compressed_state_dict, target_device, prefix=""): + for attr_str in dir(module): + target_attr = getattr(module, attr_str) + if type(target_attr) == torch.nn.Linear: + full_name = ( + f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight" + ) + setattr( + module, + attr_str, + CLinear( + compressed_state_dict[full_name], target_attr.bias, target_device + ), + ) + for name, child in module.named_children(): + child_prefix = f"{prefix}.{name}" if prefix else name + apply_compressed_weight( + child, compressed_state_dict, target_device, child_prefix + ) + + +def load_compress_model(model_path, device, torch_dtype, use_fast, revision="main"): + # partially load model + # `use_fast=True`` is not supported for some models. + try: + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=use_fast, revision=revision, trust_remote_code=True + ) + except TypeError: + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=~use_fast, revision=revision, trust_remote_code=True + ) + with init_empty_weights(): + # `trust_remote_code` should be set as `True` for both AutoConfig and AutoModel + config = AutoConfig.from_pretrained( + model_path, + low_cpu_mem_usage=True, + torch_dtype=torch_dtype, + trust_remote_code=True, + revision=revision, + ) + # some models are loaded by AutoModel but not AutoModelForCausalLM, + # such as chatglm, chatglm2 + try: + # google/flan-* models are based on an AutoModelForSeq2SeqLM. + if "T5Config" in str(type(config)): + model = AutoModelForSeq2SeqLM.from_config( + config, trust_remote_code=True + ) + else: + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + except NameError: + model = AutoModel.from_config(config, trust_remote_code=True) + linear_weights = get_compressed_list(model) + if os.path.exists(model_path): + # `model_path` is a local folder + base_pattern = os.path.join(model_path, "pytorch_model*.bin") + else: + # `model_path` is a cached Hugging Face repo + # We don't necessarily need to download the model' repo again if there is a cache. + # So check the default huggingface cache first. + model_path_temp = os.path.join( + os.path.expanduser("~"), + ".cache/huggingface/hub", + "models--" + model_path.replace("/", "--"), + "snapshots/", + ) + downloaded = False + if os.path.exists(model_path_temp): + temp_last_dir = os.listdir(model_path_temp)[-1] + model_path_temp = os.path.join(model_path_temp, temp_last_dir) + base_pattern = os.path.join(model_path_temp, "pytorch_model*.bin") + files = glob.glob(base_pattern) + if len(files) > 0: + downloaded = True + + if downloaded: + model_path = model_path_temp + else: + model_path = snapshot_download(model_path, revision=revision) + base_pattern = os.path.join(model_path, "pytorch_model*.bin") + + files = glob.glob(base_pattern) + use_safetensors = False + if len(files) == 0: + base_pattern = os.path.join(model_path, "*.safetensors") + files = glob.glob(base_pattern) + use_safetensors = True + if len(files) == 0: + raise ValueError( + f"Cannot find any model weight files. " + f"Please check your (cached) weight path: {model_path}" + ) + + compressed_state_dict = {} + if use_safetensors: + from safetensors.torch import load_file + for filename in tqdm(files): + if use_safetensors: + tmp_state_dict = load_file(filename) + else: + tmp_state_dict = torch.load( + filename, map_location=lambda storage, loc: storage + ) + for name in tmp_state_dict: + if name in linear_weights: + tensor = tmp_state_dict[name].to(device, dtype=torch_dtype) + compressed_state_dict[name] = compress( + tensor, default_compression_config + ) + else: + compressed_state_dict[name] = tmp_state_dict[name].to( + device, dtype=torch_dtype + ) + tmp_state_dict[name] = None + tensor = None + gc.collect() + torch.cuda.empty_cache() + if device == "xpu": + torch.xpu.empty_cache() + if device == "npu": + torch.npu.empty_cache() + + for name in model.state_dict(): + if name not in linear_weights: + set_module_tensor_to_device( + model, name, device, value=compressed_state_dict[name] + ) + apply_compressed_weight(model, compressed_state_dict, device) + + if torch_dtype == torch.float16: + model.half() + model.to(device) + model.eval() + + return model, tokenizer + + +def compress(tensor, config): + """Simulate group-wise quantization.""" + if not config.enabled: + return tensor + + group_size, num_bits, group_dim, symmetric = ( + config.group_size, + config.num_bits, + config.group_dim, + config.symmetric, + ) + assert num_bits <= 8 + + original_shape = tensor.shape + num_groups = (original_shape[group_dim] + group_size - 1) // group_size + new_shape = ( + original_shape[:group_dim] + + (num_groups, group_size) + + original_shape[group_dim + 1 :] + ) + + # Pad + pad_len = (group_size - original_shape[group_dim] % group_size) % group_size + if pad_len != 0: + pad_shape = ( + original_shape[:group_dim] + (pad_len,) + original_shape[group_dim + 1 :] + ) + tensor = torch.cat( + [tensor, torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)], + dim=group_dim, + ) + data = tensor.view(new_shape) + + # Quantize + if symmetric: + B = 2 ** (num_bits - 1) - 1 + scale = B / torch.max(data.abs(), dim=group_dim + 1, keepdim=True)[0] + data = data * scale + data = data.clamp_(-B, B).round_().to(torch.int8) + return data, scale, original_shape + else: + B = 2**num_bits - 1 + mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0] + mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0] + + scale = B / (mx - mn) + data = data - mn + data.mul_(scale) + + data = data.clamp_(0, B).round_().to(torch.uint8) + return data, mn, scale, original_shape + + +def decompress(packed_data, config): + """Simulate group-wise dequantization.""" + if not config.enabled: + return packed_data + + group_size, num_bits, group_dim, symmetric = ( + config.group_size, + config.num_bits, + config.group_dim, + config.symmetric, + ) + + # Dequantize + if symmetric: + data, scale, original_shape = packed_data + data = data / scale + else: + data, mn, scale, original_shape = packed_data + data = data / scale + data.add_(mn) + + # Unpad + pad_len = (group_size - original_shape[group_dim] % group_size) % group_size + if pad_len: + padded_original_shape = ( + original_shape[:group_dim] + + (original_shape[group_dim] + pad_len,) + + original_shape[group_dim + 1 :] + ) + data = data.reshape(padded_original_shape) + indices = [slice(0, x) for x in original_shape] + return data[indices].contiguous() + else: + return data.view(original_shape) diff --git a/third_party/FastChat/fastchat/model/convert_fp16.py b/third_party/FastChat/fastchat/model/convert_fp16.py new file mode 100644 index 00000000..efc40aa8 --- /dev/null +++ b/third_party/FastChat/fastchat/model/convert_fp16.py @@ -0,0 +1,26 @@ +""" +Usage: +python3 -m fastchat.model.convert_fp16 --in in-folder --out out-folder +""" +import argparse + +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch + + +def convert_fp16(in_checkpoint, out_checkpoint): + tokenizer = AutoTokenizer.from_pretrained(in_checkpoint, use_fast=False) + model = AutoModelForCausalLM.from_pretrained( + in_checkpoint, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + model.save_pretrained(out_checkpoint) + tokenizer.save_pretrained(out_checkpoint) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-checkpoint", type=str, help="Path to the model") + parser.add_argument("--out-checkpoint", type=str, help="Path to the output model") + args = parser.parse_args() + + convert_fp16(args.in_checkpoint, args.out_checkpoint) diff --git a/third_party/FastChat/fastchat/model/llama_condense_monkey_patch.py b/third_party/FastChat/fastchat/model/llama_condense_monkey_patch.py new file mode 100644 index 00000000..cb45a8bb --- /dev/null +++ b/third_party/FastChat/fastchat/model/llama_condense_monkey_patch.py @@ -0,0 +1,71 @@ +# Code adapted from https://huggingface.co/kaiokendev/superhot-13b-8k-no-rlhf-test/blob/main/llama_rope_scaled_monkey_patch.py + +from functools import partial + +import torch +import transformers +import transformers.models.llama.modeling_llama + + +class CondenseRotaryEmbedding(torch.nn.Module): + def __init__( + self, dim, ratio, max_position_embeddings=2048, base=10000, device=None + ): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + + # Build here to make `torch.jit.trace` work. + self.ratio = ratio + max_position_embeddings *= ratio + self.max_seq_len_cached = max_position_embeddings + # print(f"Monkey Patching condense ratio {ratio}") + t = ( + torch.arange( + self.max_seq_len_cached, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype, + ) + / ratio + ) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + dtype = torch.get_default_dtype() + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False + ) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + t = ( + torch.arange( + self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype + ) + / self.ratio + ) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.register_buffer( + "cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False + ) + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +def replace_llama_with_condense(ratio): + transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = partial( + CondenseRotaryEmbedding, ratio=ratio + ) diff --git a/third_party/FastChat/fastchat/model/make_delta.py b/third_party/FastChat/fastchat/model/make_delta.py new file mode 100644 index 00000000..480ba8f1 --- /dev/null +++ b/third_party/FastChat/fastchat/model/make_delta.py @@ -0,0 +1,48 @@ +""" +Make the delta weights by subtracting base weights. + +Usage: +python3 -m fastchat.model.make_delta --base ~/model_weights/llama-13b --target ~/model_weights/vicuna-13b --delta ~/model_weights/vicuna-13b-delta --hub-repo-id lmsys/vicuna-13b-delta-v1.1 +""" +import argparse + +import torch +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM + + +def make_delta(base_model_path, target_model_path, delta_path): + print(f"Loading the base model from {base_model_path}") + base = AutoModelForCausalLM.from_pretrained( + base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + + print(f"Loading the target model from {target_model_path}") + target = AutoModelForCausalLM.from_pretrained( + target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + target_tokenizer = AutoTokenizer.from_pretrained(target_model_path, use_fast=False) + + print("Calculating the delta") + for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): + assert name in base.state_dict() + param.data -= base.state_dict()[name] + + print(f"Saving the delta to {delta_path}") + if args.hub_repo_id: + kwargs = {"push_to_hub": True, "repo_id": args.hub_repo_id} + else: + kwargs = {} + target.save_pretrained(delta_path, **kwargs) + target_tokenizer.save_pretrained(delta_path, **kwargs) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument("--delta-path", type=str, required=True) + parser.add_argument("--hub-repo-id", type=str) + args = parser.parse_args() + + make_delta(args.base_model_path, args.target_model_path, args.delta_path) diff --git a/third_party/FastChat/fastchat/model/model_adapter.py b/third_party/FastChat/fastchat/model/model_adapter.py new file mode 100644 index 00000000..91881d21 --- /dev/null +++ b/third_party/FastChat/fastchat/model/model_adapter.py @@ -0,0 +1,2257 @@ +"""Model adapter registration.""" + +import math +import os +import re +import sys +from typing import Dict, List, Optional +import warnings + +if sys.version_info >= (3, 9): + from functools import cache +else: + from functools import lru_cache as cache + +import psutil +import torch +from transformers import ( + AutoConfig, + AutoModel, + AutoModelForCausalLM, + AutoModelForSeq2SeqLM, + AutoTokenizer, + LlamaTokenizer, + LlamaForCausalLM, + T5Tokenizer, +) + +from fastchat.constants import CPU_ISA +from fastchat.conversation import Conversation, get_conv_template +from fastchat.model.compression import load_compress_model +from fastchat.model.llama_condense_monkey_patch import replace_llama_with_condense +from fastchat.model.model_chatglm import generate_stream_chatglm +from fastchat.model.model_codet5p import generate_stream_codet5p +from fastchat.model.model_falcon import generate_stream_falcon +from fastchat.model.model_exllama import generate_stream_exllama +from fastchat.model.model_xfastertransformer import generate_stream_xft +from fastchat.model.monkey_patch_non_inplace import ( + replace_llama_attn_with_non_inplace_operations, +) +from fastchat.modules.awq import AWQConfig, load_awq_quantized +from fastchat.modules.exllama import ExllamaConfig, load_exllama_model +from fastchat.modules.xfastertransformer import load_xft_model, XftConfig +from fastchat.modules.gptq import GptqConfig, load_gptq_quantized +from fastchat.utils import get_gpu_memory + +# Check an environment variable to check if we should be sharing Peft model +# weights. When false we treat all Peft models as separate. +peft_share_base_weights = ( + os.environ.get("PEFT_SHARE_BASE_WEIGHTS", "false").lower() == "true" +) + +ANTHROPIC_MODEL_LIST = ( + "claude-1", + "claude-2", + "claude-2.0", + "claude-2.1", + "claude-instant-1", + "claude-instant-1.2", +) + +OPENAI_MODEL_LIST = ( + "gpt-3.5-turbo", + "gpt-3.5-turbo-0301", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-turbo", +) + + +class BaseModelAdapter: + """The base and the default model adapter.""" + + use_fast_tokenizer = True + + def match(self, model_path: str): + return True + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + try: + tokenizer = AutoTokenizer.from_pretrained( + model_path, + use_fast=self.use_fast_tokenizer, + revision=revision, + trust_remote_code=True, + ) + except TypeError: + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=False, revision=revision, trust_remote_code=True + ) + try: + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + except NameError: + model = AutoModel.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def load_compress_model(self, model_path, device, torch_dtype, revision="main"): + return load_compress_model( + model_path, + device, + torch_dtype, + use_fast=self.use_fast_tokenizer, + revision=revision, + ) + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("one_shot") + + +# A global registry for all model adapters +# TODO (lmzheng): make it a priority queue. +model_adapters: List[BaseModelAdapter] = [] + + +def register_model_adapter(cls): + """Register a model adapter.""" + model_adapters.append(cls()) + + +@cache +def get_model_adapter(model_path: str) -> BaseModelAdapter: + """Get a model adapter for a model_path.""" + model_path_basename = os.path.basename(os.path.normpath(model_path)) + + # Try the basename of model_path at first + for adapter in model_adapters: + if adapter.match(model_path_basename) and type(adapter) != BaseModelAdapter: + return adapter + + # Then try the full path + for adapter in model_adapters: + if adapter.match(model_path): + return adapter + + raise ValueError(f"No valid model adapter for {model_path}") + + +def raise_warning_for_incompatible_cpu_offloading_configuration( + device: str, load_8bit: bool, cpu_offloading: bool +): + if cpu_offloading: + if not load_8bit: + warnings.warn( + "The cpu-offloading feature can only be used while also using 8-bit-quantization.\n" + "Use '--load-8bit' to enable 8-bit-quantization\n" + "Continuing without cpu-offloading enabled\n" + ) + return False + if not "linux" in sys.platform: + warnings.warn( + "CPU-offloading is only supported on linux-systems due to the limited compatability with the bitsandbytes-package\n" + "Continuing without cpu-offloading enabled\n" + ) + return False + if device != "cuda": + warnings.warn( + "CPU-offloading is only enabled when using CUDA-devices\n" + "Continuing without cpu-offloading enabled\n" + ) + return False + return cpu_offloading + + +def load_model( + model_path: str, + device: str = "cuda", + num_gpus: int = 1, + max_gpu_memory: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + load_8bit: bool = False, + cpu_offloading: bool = False, + gptq_config: Optional[GptqConfig] = None, + awq_config: Optional[AWQConfig] = None, + exllama_config: Optional[ExllamaConfig] = None, + xft_config: Optional[XftConfig] = None, + revision: str = "main", + debug: bool = False, +): + """Load a model from Hugging Face.""" + import accelerate + + # get model adapter + adapter = get_model_adapter(model_path) + + # Handle device mapping + cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration( + device, load_8bit, cpu_offloading + ) + if device == "cpu": + kwargs = {"torch_dtype": torch.float32} + if CPU_ISA in ["avx512_bf16", "amx"]: + try: + import intel_extension_for_pytorch as ipex + + kwargs = {"torch_dtype": torch.bfloat16} + except ImportError: + warnings.warn( + "Intel Extension for PyTorch is not installed, it can be installed to accelerate cpu inference" + ) + elif device == "cuda": + kwargs = {"torch_dtype": torch.float16} + if num_gpus != 1: + kwargs["device_map"] = "auto" + if max_gpu_memory is None: + kwargs[ + "device_map" + ] = "sequential" # This is important for not the same VRAM sizes + available_gpu_memory = get_gpu_memory(num_gpus) + kwargs["max_memory"] = { + i: str(int(available_gpu_memory[i] * 0.85)) + "GiB" + for i in range(num_gpus) + } + else: + kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)} + elif device == "mps": + kwargs = {"torch_dtype": torch.float16} + import transformers + + version = tuple(int(v) for v in transformers.__version__.split(".")) + if version < (4, 35, 0): + # NOTE: Recent transformers library seems to fix the mps issue, also + # it has made some changes causing compatibility issues with our + # original patch. So we only apply the patch for older versions. + + # Avoid bugs in mps backend by not using in-place operations. + replace_llama_attn_with_non_inplace_operations() + elif device == "xpu": + kwargs = {"torch_dtype": torch.bfloat16} + # Try to load ipex, while it looks unused, it links into torch for xpu support + try: + import intel_extension_for_pytorch as ipex + except ImportError: + warnings.warn( + "Intel Extension for PyTorch is not installed, but is required for xpu inference." + ) + elif device == "npu": + kwargs = {"torch_dtype": torch.float16} + # Try to load ipex, while it looks unused, it links into torch for xpu support + try: + import torch_npu + except ImportError: + warnings.warn("Ascend Extension for PyTorch is not installed.") + else: + raise ValueError(f"Invalid device: {device}") + + if cpu_offloading: + # raises an error on incompatible platforms + from transformers import BitsAndBytesConfig + + if "max_memory" in kwargs: + kwargs["max_memory"]["cpu"] = ( + str(math.floor(psutil.virtual_memory().available / 2**20)) + "Mib" + ) + kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_8bit_fp32_cpu_offload=cpu_offloading + ) + kwargs["load_in_8bit"] = load_8bit + elif load_8bit: + if num_gpus != 1: + warnings.warn( + "8-bit quantization is not supported for multi-gpu inference." + ) + else: + model, tokenizer = adapter.load_compress_model( + model_path=model_path, + device=device, + torch_dtype=kwargs["torch_dtype"], + revision=revision, + ) + if debug: + print(model) + return model, tokenizer + elif awq_config and awq_config.wbits < 16: + assert ( + awq_config.wbits == 4 + ), "Currently we only support 4-bit inference for AWQ." + model, tokenizer = load_awq_quantized(model_path, awq_config, device) + if num_gpus != 1: + device_map = accelerate.infer_auto_device_map( + model, + max_memory=kwargs["max_memory"], + no_split_module_classes=[ + "OPTDecoderLayer", + "LlamaDecoderLayer", + "BloomBlock", + "MPTBlock", + "DecoderLayer", + ], + ) + model = accelerate.dispatch_model( + model, device_map=device_map, offload_buffers=True + ) + else: + model.to(device) + return model, tokenizer + elif gptq_config and gptq_config.wbits < 16: + model, tokenizer = load_gptq_quantized(model_path, gptq_config) + if num_gpus != 1: + device_map = accelerate.infer_auto_device_map( + model, + max_memory=kwargs["max_memory"], + no_split_module_classes=["LlamaDecoderLayer"], + ) + model = accelerate.dispatch_model( + model, device_map=device_map, offload_buffers=True + ) + else: + model.to(device) + return model, tokenizer + elif exllama_config: + model, tokenizer = load_exllama_model(model_path, exllama_config) + return model, tokenizer + elif xft_config: + model, tokenizer = load_xft_model(model_path, xft_config) + return model, tokenizer + kwargs["revision"] = revision + + if dtype is not None: # Overwrite dtype if it is provided in the arguments. + kwargs["torch_dtype"] = dtype + + if os.environ.get("FASTCHAT_USE_MODELSCOPE", "False").lower() == "true": + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + try: + from modelscope.hub.snapshot_download import snapshot_download + + if not os.path.exists(model_path): + model_path = snapshot_download(model_id=model_path, revision=revision) + except ImportError as e: + warnings.warn( + "Use model from www.modelscope.cn need pip install modelscope" + ) + raise e + + # Load model + model, tokenizer = adapter.load_model(model_path, kwargs) + + if ( + device == "cpu" + and kwargs["torch_dtype"] is torch.bfloat16 + and CPU_ISA is not None + ): + model = ipex.optimize(model, dtype=kwargs["torch_dtype"]) + + if (device == "cuda" and num_gpus == 1 and not cpu_offloading) or device in ( + "mps", + "xpu", + "npu", + ): + model.to(device) + + if device == "xpu": + model = torch.xpu.optimize(model, dtype=kwargs["torch_dtype"], inplace=True) + + if debug: + print(model) + + return model, tokenizer + + +def get_conversation_template(model_path: str) -> Conversation: + """Get the default conversation template.""" + adapter = get_model_adapter(model_path) + return adapter.get_default_conv_template(model_path) + + +def get_generate_stream_function(model: torch.nn.Module, model_path: str): + """Get the generate_stream function for inference.""" + from fastchat.serve.inference import generate_stream + + model_type = str(type(model)).lower() + is_peft = "peft" in model_type + is_chatglm = "chatglm" in model_type + is_falcon = "rwforcausallm" in model_type + is_codet5p = "codet5p" in model_type + is_exllama = "exllama" in model_type + is_xft = "xft" in model_type + + if is_chatglm: + return generate_stream_chatglm + elif is_falcon: + return generate_stream_falcon + elif is_codet5p: + return generate_stream_codet5p + elif is_exllama: + return generate_stream_exllama + elif is_xft: + return generate_stream_xft + + elif peft_share_base_weights and is_peft: + # Return a curried stream function that loads the right adapter + # according to the model_name available in this context. This ensures + # the right weights are available. + @torch.inference_mode() + def generate_stream_peft( + model, + tokenizer, + params: Dict, + device: str, + context_len: int, + stream_interval: int = 2, + judge_sent_end: bool = False, + ): + model.set_adapter(model_path) + base_model_type = str(type(model.base_model.model)) + is_chatglm = "chatglm" in base_model_type + is_falcon = "rwforcausallm" in base_model_type + is_codet5p = "codet5p" in base_model_type + is_exllama = "exllama" in base_model_type + is_xft = "xft" in base_model_type + generate_stream_function = generate_stream + if is_chatglm: + generate_stream_function = generate_stream_chatglm + elif is_falcon: + generate_stream_function = generate_stream_falcon + elif is_codet5p: + generate_stream_function = generate_stream_codet5p + elif is_exllama: + generate_stream_function = generate_stream_exllama + elif is_xft: + generate_stream_function = generate_stream_xft + for x in generate_stream_function( + model, + tokenizer, + params, + device, + context_len, + stream_interval, + judge_sent_end, + ): + yield x + + return generate_stream_peft + else: + return generate_stream + + +def add_model_args(parser): + parser.add_argument( + "--model-path", + type=str, + default="lmsys/vicuna-7b-v1.5", + help="The path to the weights. This can be a local folder or a Hugging Face repo ID.", + ) + parser.add_argument( + "--revision", + type=str, + default="main", + help="Hugging Face Hub model revision identifier", + ) + parser.add_argument( + "--device", + type=str, + choices=["cpu", "cuda", "mps", "xpu", "npu"], + default="cuda", + help="The device type", + ) + parser.add_argument( + "--gpus", + type=str, + default=None, + help="A single GPU like 1 or multiple GPUs like 0,2", + ) + parser.add_argument("--num-gpus", type=int, default=1) + parser.add_argument( + "--max-gpu-memory", + type=str, + help="The maximum memory per GPU for storing model weights. Use a string like '13Gib'", + ) + parser.add_argument( + "--dtype", + type=str, + choices=["float32", "float16", "bfloat16"], + help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.", + default=None, + ) + parser.add_argument( + "--load-8bit", action="store_true", help="Use 8-bit quantization" + ) + parser.add_argument( + "--cpu-offloading", + action="store_true", + help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU", + ) + parser.add_argument( + "--gptq-ckpt", + type=str, + default=None, + help="Used for GPTQ. The path to the local GPTQ checkpoint.", + ) + parser.add_argument( + "--gptq-wbits", + type=int, + default=16, + choices=[2, 3, 4, 8, 16], + help="Used for GPTQ. #bits to use for quantization", + ) + parser.add_argument( + "--gptq-groupsize", + type=int, + default=-1, + help="Used for GPTQ. Groupsize to use for quantization; default uses full row.", + ) + parser.add_argument( + "--gptq-act-order", + action="store_true", + help="Used for GPTQ. Whether to apply the activation order GPTQ heuristic", + ) + parser.add_argument( + "--awq-ckpt", + type=str, + default=None, + help="Used for AWQ. Load quantized model. The path to the local AWQ checkpoint.", + ) + parser.add_argument( + "--awq-wbits", + type=int, + default=16, + choices=[4, 16], + help="Used for AWQ. #bits to use for AWQ quantization", + ) + parser.add_argument( + "--awq-groupsize", + type=int, + default=-1, + help="Used for AWQ. Groupsize to use for AWQ quantization; default uses full row.", + ) + parser.add_argument( + "--enable-exllama", + action="store_true", + help="Used for exllamabv2. Enable exllamaV2 inference framework.", + ) + parser.add_argument( + "--exllama-max-seq-len", + type=int, + default=4096, + help="Used for exllamabv2. Max sequence length to use for exllamav2 framework; default 4096 sequence length.", + ) + parser.add_argument( + "--exllama-gpu-split", + type=str, + default=None, + help="Used for exllamabv2. Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7", + ) + parser.add_argument( + "--exllama-cache-8bit", + action="store_true", + help="Used for exllamabv2. Use 8-bit cache to save VRAM.", + ) + parser.add_argument( + "--enable-xft", + action="store_true", + help="Used for xFasterTransformer Enable xFasterTransformer inference framework.", + ) + parser.add_argument( + "--xft-max-seq-len", + type=int, + default=4096, + help="Used for xFasterTransformer. Max sequence length to use for xFasterTransformer framework; default 4096 sequence length.", + ) + parser.add_argument( + "--xft-dtype", + type=str, + choices=["fp16", "bf16", "int8", "bf16_fp16", "bf16_int8"], + help="Override the default dtype. If not set, it will use bfloat16 for first token and float16 next tokens on CPU.", + default=None, + ) + + +def remove_parent_directory_name(model_path): + """Remove parent directory name.""" + if model_path[-1] == "/": + model_path = model_path[:-1] + return model_path.split("/")[-1] + + +peft_model_cache = {} + + +class PeftModelAdapter: + """Loads any "peft" model and it's base model.""" + + def match(self, model_path: str): + """Accepts any model path with "peft" in the name""" + if os.path.exists(os.path.join(model_path, "adapter_config.json")): + return True + return "peft" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + """Loads the base model then the (peft) adapter weights""" + from peft import PeftConfig, PeftModel + + config = PeftConfig.from_pretrained(model_path) + base_model_path = config.base_model_name_or_path + if "peft" in base_model_path: + raise ValueError( + f"PeftModelAdapter cannot load a base model with 'peft' in the name: {config.base_model_name_or_path}" + ) + + # Basic proof of concept for loading peft adapters that share the base + # weights. This is pretty messy because Peft re-writes the underlying + # base model and internally stores a map of adapter layers. + # So, to make this work we: + # 1. Cache the first peft model loaded for a given base models. + # 2. Call `load_model` for any follow on Peft models. + # 3. Make sure we load the adapters by the model_path. Why? This is + # what's accessible during inference time. + # 4. In get_generate_stream_function, make sure we load the right + # adapter before doing inference. This *should* be safe when calls + # are blocked the same semaphore. + if peft_share_base_weights: + if base_model_path in peft_model_cache: + model, tokenizer = peft_model_cache[base_model_path] + # Super important: make sure we use model_path as the + # `adapter_name`. + model.load_adapter(model_path, adapter_name=model_path) + else: + base_adapter = get_model_adapter(base_model_path) + base_model, tokenizer = base_adapter.load_model( + base_model_path, from_pretrained_kwargs + ) + # Super important: make sure we use model_path as the + # `adapter_name`. + model = PeftModel.from_pretrained( + base_model, model_path, adapter_name=model_path + ) + peft_model_cache[base_model_path] = (model, tokenizer) + return model, tokenizer + + # In the normal case, load up the base model weights again. + base_adapter = get_model_adapter(base_model_path) + base_model, tokenizer = base_adapter.load_model( + base_model_path, from_pretrained_kwargs + ) + model = PeftModel.from_pretrained(base_model, model_path) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + """Uses the conv template of the base model""" + from peft import PeftConfig, PeftModel + + config = PeftConfig.from_pretrained(model_path) + if "peft" in config.base_model_name_or_path: + raise ValueError( + f"PeftModelAdapter cannot load a base model with 'peft' in the name: {config.base_model_name_or_path}" + ) + base_model_path = config.base_model_name_or_path + base_adapter = get_model_adapter(base_model_path) + return base_adapter.get_default_conv_template(config.base_model_name_or_path) + + +class VicunaAdapter(BaseModelAdapter): + "Model adapter for Vicuna models (e.g., lmsys/vicuna-7b-v1.5)" "" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "vicuna" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=self.use_fast_tokenizer, revision=revision + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + self.raise_warning_for_old_weights(model) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + if "v0" in remove_parent_directory_name(model_path): + return get_conv_template("one_shot") + return get_conv_template("vicuna_v1.1") + + def raise_warning_for_old_weights(self, model): + if isinstance(model, LlamaForCausalLM) and model.model.vocab_size > 32000: + warnings.warn( + "\nYou are probably using the old Vicuna-v0 model, " + "which will generate unexpected results with the " + "current fastchat.\nYou can try one of the following methods:\n" + "1. Upgrade your weights to the new Vicuna-v1.3: https://github.com/lm-sys/FastChat#vicuna-weights.\n" + "2. Use the old conversation template by `python3 -m fastchat.serve.cli --model-path /path/to/vicuna-v0 --conv-template one_shot`\n" + "3. Downgrade fschat to fschat==0.1.10 (Not recommended).\n" + ) + + +class AiroborosAdapter(BaseModelAdapter): + """The model adapter for jondurbin/airoboros-*""" + + def match(self, model_path: str): + if re.search(r"airoboros|spicyboros", model_path, re.I): + return True + return False + + def get_default_conv_template(self, model_path: str) -> Conversation: + if "-3." in model_path or "-3p" in model_path: + return get_conv_template("airoboros_v3") + if "spicyboros" in model_path or re.search(r"-(2\.[2-9]+)", model_path): + return get_conv_template("airoboros_v2") + return get_conv_template("airoboros_v1") + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + if "mpt" not in model_path.lower(): + return super().load_model(model_path, from_pretrained_kwargs) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + max_seq_len=8192, + **from_pretrained_kwargs, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, use_fast=True + ) + return model, tokenizer + + +class LongChatAdapter(BaseModelAdapter): + "Model adapter for LongChat models (e.g., lmsys/longchat-7b-16k)." + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "longchat" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + + # Apply monkey patch, TODO(Dacheng): Add flash attention support + config = AutoConfig.from_pretrained(model_path, revision=revision) + replace_llama_with_condense(config.rope_scaling["factor"]) + + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=self.use_fast_tokenizer, revision=revision + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("vicuna_v1.1") + + +class GoogleT5Adapter(BaseModelAdapter): + """The model adapter for google/Flan based models, such as Salesforce/codet5p-6b, lmsys/fastchat-t5-3b-v1.0, flan-t5-*, flan-ul2""" + + def match(self, model_path: str): + return any( + model_str in model_path.lower() + for model_str in ["flan-", "fastchat-t5", "codet5p"] + ) + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = T5Tokenizer.from_pretrained(model_path, revision=revision) + model = AutoModelForSeq2SeqLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + +class KoalaAdapter(BaseModelAdapter): + """The model adapter for Koala""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "koala" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("koala_v1") + + +class AlpacaAdapter(BaseModelAdapter): + """The model adapter for Alpaca""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "alpaca" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("alpaca") + + +class ChatGLMAdapter(BaseModelAdapter): + """The model adapter for THUDM/chatglm-6b, THUDM/chatglm2-6b""" + + def match(self, model_path: str): + return "chatglm" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + if "chatglm3" in model_path.lower(): + tokenizer = AutoTokenizer.from_pretrained( + model_path, + encode_special_tokens=True, + trust_remote_code=True, + revision=revision, + ) + else: + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + model = AutoModel.from_pretrained( + model_path, trust_remote_code=True, **from_pretrained_kwargs + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + model_path = model_path.lower() + if "chatglm2" in model_path.lower(): + return get_conv_template("chatglm2") + if "chatglm3" in model_path.lower(): + return get_conv_template("chatglm3") + return get_conv_template("chatglm") + + +class CodeGeexAdapter(BaseModelAdapter): + """The model adapter for THUDM/codegeex-6b, THUDM/codegeex2-6b""" + + def match(self, model_path: str): + return "codegeex" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + model = AutoModel.from_pretrained( + model_path, trust_remote_code=True, **from_pretrained_kwargs + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("codegeex") + + +class DollyV2Adapter(BaseModelAdapter): + """The model adapter for databricks/dolly-v2-12b""" + + def match(self, model_path: str): + return "dolly-v2" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + # 50277 means "### End" + tokenizer.eos_token_id = 50277 + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("dolly_v2") + + +class OasstPythiaAdapter(BaseModelAdapter): + """The model adapter for OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5""" + + def match(self, model_path: str): + model_path = model_path.lower() + return "oasst" in model_path and "pythia" in model_path + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("oasst_pythia") + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + +class OasstLLaMAAdapter(BaseModelAdapter): + """The model adapter for OpenAssistant/oasst-sft-7-llama-30b""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + model_path = model_path.lower() + if "openassistant-sft-7-llama-30b-hf" in model_path: + return True + return "oasst" in model_path and "pythia" not in model_path + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("oasst_llama") + + +class OpenChat35Adapter(BaseModelAdapter): + """The model adapter for OpenChat 3.5 (e.g. openchat/openchat_3.5)""" + + def match(self, model_path: str): + if "openchat" in model_path.lower() and "3.5" in model_path.lower(): + return True + elif "starling-lm" in model_path.lower(): + return True + return False + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("openchat_3.5") + + +class TenyxChatAdapter(BaseModelAdapter): + """The model adapter for TenyxChat (e.g. tenyx/TenyxChat-7B-v1)""" + + def match(self, model_path: str): + return "tenyxchat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("tenyxchat") + + +class PythiaAdapter(BaseModelAdapter): + """The model adapter for any EleutherAI/pythia model""" + + def match(self, model_path: str): + return "pythia" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + +class StableLMAdapter(BaseModelAdapter): + """The model adapter for StabilityAI/stablelm-tuned-alpha-7b""" + + def match(self, model_path: str): + return "stablelm" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("stablelm") + + +class MPTAdapter(BaseModelAdapter): + """The model adapter for MPT series (mosaicml/mpt-7b-chat, mosaicml/mpt-30b-chat)""" + + def match(self, model_path: str): + model_path = model_path.lower() + return "mpt" in model_path and not "airoboros" in model_path + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + max_seq_len=8192, + **from_pretrained_kwargs, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + model_path = model_path.lower() + if "mpt-7b-chat" in model_path: + return get_conv_template("mpt-7b-chat") + elif "mpt-30b-chat" in model_path: + return get_conv_template("mpt-30b-chat") + elif "mpt-30b-instruct" in model_path: + return get_conv_template("mpt-30b-instruct") + else: + print( + "Warning: Loading base MPT model with `zero_shot` conversation configuration. " + "If this is not desired, inspect model configurations and names." + ) + return get_conv_template("zero_shot") + + +class BaizeAdapter(BaseModelAdapter): + """The model adapter for project-baize/baize-v2-7b""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "baize" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("baize") + + +class RwkvAdapter(BaseModelAdapter): + """The model adapter for BlinkDL/RWKV-4-Raven""" + + def match(self, model_path: str): + return "rwkv-4" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + from fastchat.model.rwkv_model import RwkvModel + + model = RwkvModel(model_path) + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + "EleutherAI/pythia-160m", revision=revision + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("rwkv") + + +class OpenBuddyAdapter(BaseModelAdapter): + """The model adapter for OpenBuddy/openbuddy-7b-v1.1-bf16-enc""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "openbuddy" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("openbuddy") + + +class PhoenixAdapter(BaseModelAdapter): + """The model adapter for FreedomIntelligence/phoenix-inst-chat-7b""" + + def match(self, model_path: str): + return "phoenix" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("phoenix") + + +class ReaLMAdapter(BaseModelAdapter): + """The model adapter for FreedomIntelligence/ReaLM-7b""" + + def match(self, model_path: str): + return "ReaLM" in model_path + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) + model = AutoModelForCausalLM.from_pretrained( + model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("ReaLM-7b-v1") + + +class ChatGPTAdapter(BaseModelAdapter): + """The model adapter for ChatGPT""" + + def match(self, model_path: str): + return model_path in OPENAI_MODEL_LIST + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("chatgpt") + + +class AzureOpenAIAdapter(BaseModelAdapter): + """The model adapter for Azure OpenAI""" + + def match(self, model_path: str): + return model_path in ("azure-gpt-35-turbo", "azure-gpt-4") + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("chatgpt") + + +class PplxAIAdapter(BaseModelAdapter): + """The model adapter for Perplexity AI""" + + def match(self, model_path: str): + return model_path in ( + "pplx-7b-online", + "pplx-70b-online", + ) + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("pplxai") + + +class ClaudeAdapter(BaseModelAdapter): + """The model adapter for Claude""" + + def match(self, model_path: str): + return model_path in ANTHROPIC_MODEL_LIST + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("claude") + + +class BardAdapter(BaseModelAdapter): + """The model adapter for Bard""" + + def match(self, model_path: str): + return model_path == "bard" + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("bard") + + +class PaLM2Adapter(BaseModelAdapter): + """The model adapter for PaLM2""" + + def match(self, model_path: str): + return model_path == "palm-2" + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("bard") + + +class GeminiAdapter(BaseModelAdapter): + """The model adapter for Gemini""" + + def match(self, model_path: str): + return model_path in ["gemini-pro"] + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + raise NotImplementedError() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("bard") + + +class BiLLaAdapter(BaseModelAdapter): + """The model adapter for Neutralzz/BiLLa-7B-SFT""" + + def match(self, model_path: str): + return "billa" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("billa") + + +class RedPajamaINCITEAdapter(BaseModelAdapter): + """The model adapter for togethercomputer/RedPajama-INCITE-7B-Chat""" + + def match(self, model_path: str): + return "redpajama-incite" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("redpajama-incite") + + +class H2OGPTAdapter(BaseModelAdapter): + """The model adapter for h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "h2ogpt" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("h2ogpt") + + +class RobinAdapter(BaseModelAdapter): + """The model adapter for LMFlow/Full-Robin-7b-v2""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "robin" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("Robin") + + +class SnoozyAdapter(BaseModelAdapter): + """The model adapter for nomic-ai/gpt4all-13b-snoozy""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + model_path = model_path.lower() + return "gpt4all" in model_path and "snoozy" in model_path + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("snoozy") + + +class WizardLMAdapter(BaseModelAdapter): + """The model adapter for WizardLM/WizardLM-13B-V1.0""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "wizardlm" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + model_path = model_path.lower() + if "13b" in model_path or "30b" in model_path or "70b" in model_path: + return get_conv_template("vicuna_v1.1") + else: + # TODO: use the recommended template for 7B + # (https://huggingface.co/WizardLM/WizardLM-13B-V1.0) + return get_conv_template("one_shot") + + +class ManticoreAdapter(BaseModelAdapter): + """The model adapter for openaccess-ai-collective/manticore-13b-chat-pyg""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "manticore" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("manticore") + + +class GuanacoAdapter(BaseModelAdapter): + """The model adapter for timdettmers/guanaco-33b-merged""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "guanaco" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=self.use_fast_tokenizer, revision=revision + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs + ) + # Fix a bug in tokenizer config + tokenizer.eos_token_id = model.config.eos_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("zero_shot") + + +class ChangGPTAdapter(BaseModelAdapter): + """The model adapter for lcw99/polyglot-ko-12.8b-chang-instruct-chat""" + + def match(self, model_path: str): + model_path = model_path.lower() + return "polyglot" in model_path and "chang" in model_path + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("polyglot_changgpt") + + +class CamelAdapter(BaseModelAdapter): + """The model adapter for camel-ai/CAMEL-13B-Combined-Data""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "camel" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("vicuna_v1.1") + + +class TuluAdapter(BaseModelAdapter): + """The model adapter for allenai/tulu-30b""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "tulu" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("tulu") + + +class FalconAdapter(BaseModelAdapter): + """The model adapter for tiiuae/falcon-40b""" + + def match(self, model_path: str): + return "falcon" in model_path.lower() and "chat" not in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + # Strongly suggest using bf16, which is recommended by the author of Falcon + tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + # In Falcon tokenizer config and special config there is not any pad token + # Setting `pad_token_id` to 9, which corresponds to special token '>>SUFFIX<<' + tokenizer.pad_token_id = 9 + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("falcon") + + +class FalconChatAdapter(BaseModelAdapter): + def match(self, model_path: str): + return "falcon" in model_path.lower() and "chat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("falcon-chat") + + +class TigerBotAdapter(BaseModelAdapter): + """The model adapter for TigerResearch/tigerbot-7b-sft""" + + def match(self, model_path: str): + return "tigerbot" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, + trust_remote_code=True, + revision=revision, + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("tigerbot") + + +class BaichuanAdapter(BaseModelAdapter): + """The model adapter for Baichuan models (e.g., baichuan-inc/Baichuan-7B)""" + + def match(self, model_path: str): + return "baichuan" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + # for Baichuan-13B-Chat + if "chat" in model_path.lower(): + if "baichuan2" in model_path.lower(): + return get_conv_template("baichuan2-chat") + return get_conv_template("baichuan-chat") + return get_conv_template("zero_shot") + + +class XGenAdapter(BaseModelAdapter): + """The model adapter for Salesforce/xgen-7b""" + + def match(self, model_path: str): + return "xgen" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + model.config.eos_token_id = 50256 + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("xgen") + + +class NousHermesAdapter(BaseModelAdapter): + """The model adapter for NousResearch/Nous-Hermes-13b""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "nous-hermes" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("alpaca") + + +class InternLMChatAdapter(BaseModelAdapter): + """The model adapter for internlm/internlm-chat-7b""" + + def match(self, model_path: str): + return "internlm" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + model = model.eval() + if "8k" in model_path.lower(): + model.config.max_sequence_length = 8192 + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("internlm-chat") + + +class StarChatAdapter(BaseModelAdapter): + """The model adapter for HuggingFaceH4/starchat-beta""" + + def match(self, model_path: str): + return "starchat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("starchat") + + +class MistralAdapter(BaseModelAdapter): + """The model adapter for Mistral AI models""" + + def match(self, model_path: str): + return "mistral" in model_path.lower() or "mixtral" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("mistral") + + +class Llama2Adapter(BaseModelAdapter): + """The model adapter for Llama-2 (e.g., meta-llama/Llama-2-7b-hf)""" + + def match(self, model_path: str): + return "llama-2" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("llama-2") + + +class CuteGPTAdapter(BaseModelAdapter): + """The model adapter for CuteGPT""" + + def match(self, model_path: str): + return "cutegpt" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + tokenizer = LlamaTokenizer.from_pretrained(model_path) + model = AutoModelForCausalLM.from_pretrained( + model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs + ) + tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids("") + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.eos_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("cutegpt") + + +class OpenOrcaAdapter(BaseModelAdapter): + """Model adapter for Open-Orca models which may use different prompt templates + - (e.g. Open-Orca/OpenOrcaxOpenChat-Preview2-13B, Open-Orca/Mistral-7B-OpenOrca) + - `OpenOrcaxOpenChat-Preview2-13B` uses their "OpenChat Llama2 V1" prompt template. + - [Open-Orca/OpenOrcaxOpenChat-Preview2-13B #Prompt Template](https://huggingface.co/Open-Orca/OpenOrcaxOpenChat-Preview2-13B#prompt-template) + - `Mistral-7B-OpenOrca` uses the [OpenAI's Chat Markup Language (ChatML)](https://github.com/openai/openai-python/blob/main/chatml.md) + format, with <|im_start|> and <|im_end|> tokens added to support this. + - [Open-Orca/Mistral-7B-OpenOrca #Prompt Template](https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca#prompt-template) + """ + + use_fast_tokenizer = False + + def match(self, model_path: str): + return ( + "mistral-7b-openorca" in model_path.lower() + or "openorca" in model_path.lower() + ) + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=self.use_fast_tokenizer, revision=revision + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ).eval() + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + if "mistral-7b-openorca" in model_path.lower(): + return get_conv_template("mistral-7b-openorca") + return get_conv_template("open-orca") + + +class DolphinAdapter(OpenOrcaAdapter): + """Model adapter for ehartford/dolphin-2.2.1-mistral-7b""" + + def match(self, model_path: str): + return "dolphin" in model_path.lower() and "mistral" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("dolphin-2.2.1-mistral-7b") + + +class Hermes2Adapter(BaseModelAdapter): + """Model adapter for teknium/OpenHermes-2.5-Mistral-7B and teknium/OpenHermes-2-Mistral-7B models""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return any( + model_str in model_path.lower() + for model_str in ["openhermes-2.5-mistral-7b", "openhermes-2-mistral-7b"] + ) + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=self.use_fast_tokenizer, revision=revision + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ).eval() + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("OpenHermes-2.5-Mistral-7B") + + +class NousHermes2MixtralAdapter(BaseModelAdapter): + """Model adapter for NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO model""" + + def match(self, model_path: str): + return any( + model_str in model_path.lower() + for model_str in [ + "nous-hermes-2-mixtral-8x7b-dpo", + "nous-hermes-2-mixtral-8x7b-sft", + ] + ) + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("Nous-Hermes-2-Mixtral-8x7B-DPO") + + +class WizardCoderAdapter(BaseModelAdapter): + """The model adapter for WizardCoder (e.g., WizardLM/WizardCoder-Python-34B-V1.0)""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "wizardcoder" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + # Same as Alpaca, see : + # https://github.com/nlpxucan/WizardLM/blob/main/WizardCoder/src/inference_wizardcoder.py#L60 + return get_conv_template("alpaca") + + +class QwenChatAdapter(BaseModelAdapter): + """The model adapter for Qwen/Qwen-7B-Chat + To run this model, you need to ensure additional flash attention installation: + ``` bash + git clone https://github.com/Dao-AILab/flash-attention + cd flash-attention && pip install . + pip install csrc/layer_norm + pip install csrc/rotary + ``` + + Since from 2.0, the following change happened + - `flash_attn_unpadded_func` -> `flash_attn_varlen_func` + - `flash_attn_unpadded_qkvpacked_func` -> `flash_attn_varlen_qkvpacked_func` + - `flash_attn_unpadded_kvpacked_func` -> `flash_attn_varlen_kvpacked_func` + You may need to revise the code in: https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/modeling_qwen.py#L69 + to from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func + """ + + def match(self, model_path: str): + return "qwen" in model_path.lower() + + def float_set(self, config, option): + config.bf16 = False + config.fp16 = False + config.fp32 = False + + if option == "bf16": + config.bf16 = True + elif option == "fp16": + config.fp16 = True + elif option == "fp32": + config.fp32 = True + else: + print("Invalid option. Please choose one from 'bf16', 'fp16' and 'fp32'.") + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + from transformers.generation import GenerationConfig + + revision = from_pretrained_kwargs.get("revision", "main") + config = AutoConfig.from_pretrained( + model_path, + trust_remote_code=True, + ) + # NOTE: if you use the old version of model file, please remove the comments below + # config.use_flash_attn = False + self.float_set(config, "fp16") + generation_config = GenerationConfig.from_pretrained( + model_path, trust_remote_code=True + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + config=config, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ).eval() + if hasattr(model.config, "use_dynamic_ntk") and model.config.use_dynamic_ntk: + model.config.max_sequence_length = 16384 + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + tokenizer.eos_token_id = config.eos_token_id + tokenizer.bos_token_id = config.bos_token_id + tokenizer.pad_token_id = generation_config.pad_token_id + model.config.eos_token_id = tokenizer.eos_token_id + model.config.bos_token_id = tokenizer.bos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("qwen-7b-chat") + + +class BGEAdapter(BaseModelAdapter): + """The model adapter for BGE (e.g., BAAI/bge-large-en-v1.5)""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "bge" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + model = AutoModel.from_pretrained( + model_path, + **from_pretrained_kwargs, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + if hasattr(model.config, "max_position_embeddings") and hasattr( + tokenizer, "model_max_length" + ): + model.config.max_sequence_length = min( + model.config.max_position_embeddings, tokenizer.model_max_length + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("one_shot") + + +class E5Adapter(BaseModelAdapter): + """The model adapter for E5 (e.g., intfloat/e5-large-v2)""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "e5-" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + model = AutoModel.from_pretrained( + model_path, + **from_pretrained_kwargs, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + if hasattr(model.config, "max_position_embeddings") and hasattr( + tokenizer, "model_max_length" + ): + model.config.max_sequence_length = min( + model.config.max_position_embeddings, tokenizer.model_max_length + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("one_shot") + + +class AquilaChatAdapter(BaseModelAdapter): + """The model adapter for BAAI/Aquila + + Now supports: + - BAAI/AquilaChat-7B + - BAAI/AquilaChat2-7B + - BAAI/AquilaChat2-34B + """ + + def match(self, model_path: str): + return "aquila" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + model = AutoModelForCausalLM.from_pretrained( + model_path, + low_cpu_mem_usage=True, + trust_remote_code=True, + **from_pretrained_kwargs, + ) + model = model.eval() + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, revision=revision + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + model_path = model_path.lower() + # See: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L347 + if "aquilachat2" in model_path: + if "16k" in model_path: + return get_conv_template("aquila") + elif "34b" in model_path: + return get_conv_template("aquila-legacy") + else: + return get_conv_template("aquila-v1") + else: + return get_conv_template("aquila-chat") + + +class Lamma2ChineseAdapter(BaseModelAdapter): + """The model adapter for FlagAlpha/LLama2-Chinese sft""" + + def match(self, model_path: str): + return "llama2-chinese" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, + trust_remote_code=True, + revision=revision, + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("llama2-chinese") + + +class Lamma2ChineseAlpacaAdapter(BaseModelAdapter): + """The model adapter for ymcui/Chinese-LLaMA-Alpaca sft""" + + def match(self, model_path: str): + return "chinese-alpaca" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, + trust_remote_code=True, + revision=revision, + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("chinese-alpaca2") + + +class VigogneAdapter(BaseModelAdapter): + """The model adapter for vigogne (e.g., bofenghuang/vigogne-2-7b-chat)""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return bool(re.search(r"vigogne|vigostral", model_path, re.I)) + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, + use_fast=self.use_fast_tokenizer, + trust_remote_code=True, + revision=revision, + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ).eval() + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + if "chat" in model_path.lower(): + if "vigostral" in model_path.lower(): + return get_conv_template("vigogne_chat_v3") + return get_conv_template("vigogne_chat_v2") + return get_conv_template("vigogne_instruct") + + +class OpenLLaMaOpenInstructAdapter(BaseModelAdapter): + """The model adapter for OpenLLaMa-Open-Instruct (e.g., VMware/open-llama-7b-open-instruct)""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return ( + "open-llama" in model_path.lower() and "open-instruct" in model_path.lower() + ) + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + revision = from_pretrained_kwargs.get("revision", "main") + tokenizer = AutoTokenizer.from_pretrained( + model_path, + use_fast=self.use_fast_tokenizer, + trust_remote_code=True, + revision=revision, + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ).eval() + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("alpaca") + + +class CodeLlamaAdapter(BaseModelAdapter): + """The model adapter for CodeLlama (e.g., codellama/CodeLlama-34b-hf)""" + + def match(self, model_path: str): + return "codellama" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("llama-2") + + +class StableVicunaAdapter(BaseModelAdapter): + """The model adapter for StableVicuna""" + + def match(self, model_path: str): + return "stable-vicuna" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("stable-vicuna") + + +class PhindCodeLlamaAdapter(CodeLlamaAdapter): + """The model adapter for Phind-CodeLlama (e.g., Phind/Phind-CodeLlama-34B-v2)""" + + def match(self, model_path: str): + return "phind-codellama-" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("phind") + + +class Llama2ChangAdapter(Llama2Adapter): + """The model adapter for Llama2-ko-chang (e.g., lcw99/llama2-ko-chang-instruct-chat)""" + + def match(self, model_path: str): + return "llama2-ko-chang" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("polyglot_changgpt") + + +class ZephyrAdapter(BaseModelAdapter): + """The model adapter for Zephyr (e.g. HuggingFaceH4/zephyr-7b-alpha)""" + + def match(self, model_path: str): + return "zephyr" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("zephyr") + + +class NotusAdapter(BaseModelAdapter): + """The model adapter for Notus (e.g. argilla/notus-7b-v1)""" + + def match(self, model_path: str): + return "notus" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("zephyr") + + +class CatPPTAdapter(BaseModelAdapter): + """The model adapter for CatPPT (e.g. rishiraj/CatPPT)""" + + def match(self, model_path: str): + return "catppt" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("catppt") + + +class TinyLlamaAdapter(BaseModelAdapter): + """The model adapter for TinyLlama (e.g. TinyLlama/TinyLlama-1.1B-Chat-v1.0)""" + + def match(self, model_path: str): + return "tinyllama" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("TinyLlama") + + +class XwinLMAdapter(BaseModelAdapter): + """The model adapter for Xwin-LM V0.1 and V0.2 series of models(e.g., Xwin-LM/Xwin-LM-70B-V0.1)""" + + # use_fast_tokenizer = False + + def match(self, model_path: str): + return "xwin-lm" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("vicuna_v1.1") + + +class LemurAdapter(BaseModelAdapter): + """The model adapter for OpenLemur/lemur-70b-chat-v1""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "lemur-70b-chat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("lemur-70b-chat") + + +class PygmalionAdapter(BaseModelAdapter): + """The model adapter for Pygmalion/Metharme series of models(e.g., PygmalionAI/mythalion-13b)""" + + # use_fast_tokenizer = False + + def match(self, model_path: str): + return bool( + re.search(r"pygmalion|mythalion|metharme", model_path.lower(), re.I) + ) + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("metharme") + + +class XdanAdapter(BaseModelAdapter): + """The model adapter for xDAN-AI (e.g. xDAN-AI/xDAN-L1-Chat-RL-v1)""" + + def match(self, model_path: str): + return "xdan" in model_path.lower() and "v1" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("xdan-v1") + + +class MicrosoftOrcaAdapter(BaseModelAdapter): + """The model adapter for Microsoft/Orca-2 series of models (e.g. Microsoft/Orca-2-7b, Microsoft/Orca-2-13b)""" + + use_fast_tokenizer = False # Flag neeeded since tokenizers>=0.13.3 is required for a normal functioning of this module + + def match(self, model_path: str): + return "orca-2" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("orca-2") + + +class YiAdapter(BaseModelAdapter): + """The model adapter for Yi models""" + + def match(self, model_path: str): + return "yi-" in model_path.lower() and "chat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("Yi-34b-chat") + + +class DeepseekCoderAdapter(BaseModelAdapter): + """The model adapter for deepseek-ai's coder models""" + + def match(self, model_path: str): + return "deepseek-coder" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("deepseek-coder") + + +class DeepseekChatAdapter(BaseModelAdapter): + """The model adapter for deepseek-ai's chat models""" + + # Note: that this model will require tokenizer version >= 0.13.3 because the tokenizer class is LlamaTokenizerFast + + def match(self, model_path: str): + return "deepseek-llm" in model_path.lower() and "chat" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("deepseek-chat") + + +class MetaMathAdapter(BaseModelAdapter): + """The model adapter for MetaMath models""" + + def match(self, model_path: str): + return "metamath" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("metamath") + + +class BagelAdapter(BaseModelAdapter): + """Model adapter for jondurbin/bagel-* models""" + + def match(self, model_path: str): + return "bagel" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("airoboros_v3") + + +class SolarAdapter(BaseModelAdapter): + """The model adapter for upstage/SOLAR-10.7B-Instruct-v1.0""" + + def match(self, model_path: str): + return "solar-" in model_path.lower() and "instruct" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("solar") + + +class Yuan2Adapter(BaseModelAdapter): + """The model adapter for Yuan""" + + def match(self, model_path: str): + return "yuan" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + tokenizer.add_tokens( + [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + ], + special_tokens=True, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("yuan") + + +# Note: the registration order matters. +# The one registered earlier has a higher matching priority. +register_model_adapter(PeftModelAdapter) +register_model_adapter(StableVicunaAdapter) +register_model_adapter(VicunaAdapter) +register_model_adapter(AiroborosAdapter) +register_model_adapter(LongChatAdapter) +register_model_adapter(GoogleT5Adapter) +register_model_adapter(KoalaAdapter) +register_model_adapter(AlpacaAdapter) +register_model_adapter(ChatGLMAdapter) +register_model_adapter(CodeGeexAdapter) +register_model_adapter(DollyV2Adapter) +register_model_adapter(OasstPythiaAdapter) +register_model_adapter(OasstLLaMAAdapter) +register_model_adapter(OpenChat35Adapter) +register_model_adapter(TenyxChatAdapter) +register_model_adapter(StableLMAdapter) +register_model_adapter(BaizeAdapter) +register_model_adapter(RwkvAdapter) +register_model_adapter(OpenBuddyAdapter) +register_model_adapter(PhoenixAdapter) +register_model_adapter(BardAdapter) +register_model_adapter(PaLM2Adapter) +register_model_adapter(GeminiAdapter) +register_model_adapter(ChatGPTAdapter) +register_model_adapter(AzureOpenAIAdapter) +register_model_adapter(ClaudeAdapter) +register_model_adapter(MPTAdapter) +register_model_adapter(BiLLaAdapter) +register_model_adapter(RedPajamaINCITEAdapter) +register_model_adapter(H2OGPTAdapter) +register_model_adapter(RobinAdapter) +register_model_adapter(SnoozyAdapter) +register_model_adapter(WizardLMAdapter) +register_model_adapter(ManticoreAdapter) +register_model_adapter(GuanacoAdapter) +register_model_adapter(CamelAdapter) +register_model_adapter(ChangGPTAdapter) +register_model_adapter(TuluAdapter) +register_model_adapter(FalconChatAdapter) +register_model_adapter(FalconAdapter) +register_model_adapter(TigerBotAdapter) +register_model_adapter(BaichuanAdapter) +register_model_adapter(XGenAdapter) +register_model_adapter(PythiaAdapter) +register_model_adapter(InternLMChatAdapter) +register_model_adapter(StarChatAdapter) +register_model_adapter(Llama2Adapter) +register_model_adapter(CuteGPTAdapter) +register_model_adapter(OpenOrcaAdapter) +register_model_adapter(DolphinAdapter) +register_model_adapter(Hermes2Adapter) +register_model_adapter(NousHermes2MixtralAdapter) +register_model_adapter(NousHermesAdapter) +register_model_adapter(MistralAdapter) +register_model_adapter(WizardCoderAdapter) +register_model_adapter(QwenChatAdapter) +register_model_adapter(AquilaChatAdapter) +register_model_adapter(BGEAdapter) +register_model_adapter(E5Adapter) +register_model_adapter(Lamma2ChineseAdapter) +register_model_adapter(Lamma2ChineseAlpacaAdapter) +register_model_adapter(VigogneAdapter) +register_model_adapter(OpenLLaMaOpenInstructAdapter) +register_model_adapter(ReaLMAdapter) +register_model_adapter(PhindCodeLlamaAdapter) +register_model_adapter(CodeLlamaAdapter) +register_model_adapter(Llama2ChangAdapter) +register_model_adapter(ZephyrAdapter) +register_model_adapter(NotusAdapter) +register_model_adapter(CatPPTAdapter) +register_model_adapter(TinyLlamaAdapter) +register_model_adapter(XwinLMAdapter) +register_model_adapter(LemurAdapter) +register_model_adapter(PygmalionAdapter) +register_model_adapter(MicrosoftOrcaAdapter) +register_model_adapter(XdanAdapter) +register_model_adapter(YiAdapter) +register_model_adapter(PplxAIAdapter) +register_model_adapter(DeepseekCoderAdapter) +register_model_adapter(DeepseekChatAdapter) +register_model_adapter(MetaMathAdapter) +register_model_adapter(BagelAdapter) +register_model_adapter(SolarAdapter) +register_model_adapter(Yuan2Adapter) + +# After all adapters, try the default base adapter. +register_model_adapter(BaseModelAdapter) diff --git a/third_party/FastChat/fastchat/model/model_chatglm.py b/third_party/FastChat/fastchat/model/model_chatglm.py new file mode 100644 index 00000000..2cbac8bc --- /dev/null +++ b/third_party/FastChat/fastchat/model/model_chatglm.py @@ -0,0 +1,137 @@ +""" +Inference code for ChatGLM. +Adapted from https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py. +""" +import re + +import torch +from transformers.generation.logits_process import LogitsProcessor + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor + ) -> torch.FloatTensor: + if torch.isnan(scores).any() or torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +invalid_score_processor = InvalidScoreLogitsProcessor() + + +def process_response(response): + response = response.strip() + response = response.replace("[[训练时间]]", "2023年") + punkts = [ + [",", ","], + ["!", "!"], + [":", ":"], + [";", ";"], + ["\?", "?"], + ] + for item in punkts: + response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response) + response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response) + return response + + +def recover_message_list(prompt): + role_token_pattern = "|".join( + [re.escape(r) for r in ["<|system|>", "<|user|>", "<|assistant|>"]] + ) + role = None + last_end_idx = -1 + message_list = [] + for match in re.finditer(role_token_pattern, prompt): + if role: + messge = {} + if role == "<|system|>": + messge["role"] = "system" + elif role == "<|user|>": + messge["role"] = "user" + else: + messge["role"] = "assistant" + messge["content"] = prompt[last_end_idx + 1 : match.start()] + message_list.append(messge) + + role = prompt[match.start() : match.end()] + last_end_idx = match.end() + + return message_list + + +@torch.inference_mode() +def generate_stream_chatglm( + model, + tokenizer, + params, + device, + context_len=2048, + stream_interval=2, + judge_sent_end=False, +): + prompt = params["prompt"] + temperature = float(params.get("temperature", 1.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + top_p = float(params.get("top_p", 1.0)) + max_new_tokens = int(params.get("max_new_tokens", 256)) + echo = params.get("echo", True) + + model_type = str(type(model)).lower() + if "peft" in model_type: + model_type = str(type(model.base_model.model)).lower() + + if "chatglm3" in model_type: + message_list = recover_message_list(prompt) + inputs = tokenizer.build_chat_input( + query=message_list[-1]["content"], history=message_list[:-1], role="user" + ).to(model.device) + else: + inputs = tokenizer([prompt], return_tensors="pt").to(model.device) + input_echo_len = len(inputs["input_ids"][0]) + + gen_kwargs = { + "max_length": max_new_tokens + input_echo_len, + "do_sample": True if temperature > 1e-5 else False, + "top_p": top_p, + "repetition_penalty": repetition_penalty, + "logits_processor": [invalid_score_processor], + } + if temperature > 1e-5: + gen_kwargs["temperature"] = temperature + + total_len = 0 + for total_ids in model.stream_generate(**inputs, **gen_kwargs): + total_ids = total_ids.tolist()[0] + total_len = len(total_ids) + if echo: + output_ids = total_ids + else: + output_ids = total_ids[input_echo_len:] + response = tokenizer.decode(output_ids) + response = process_response(response) + + yield { + "text": response, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": total_len - input_echo_len, + "total_tokens": total_len, + }, + "finish_reason": None, + } + + # TODO: ChatGLM stop when it reach max length + # Only last stream result contains finish_reason, we set finish_reason as stop + ret = { + "text": response, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": total_len - input_echo_len, + "total_tokens": total_len, + }, + "finish_reason": "stop", + } + yield ret diff --git a/third_party/FastChat/fastchat/model/model_codet5p.py b/third_party/FastChat/fastchat/model/model_codet5p.py new file mode 100644 index 00000000..0984513c --- /dev/null +++ b/third_party/FastChat/fastchat/model/model_codet5p.py @@ -0,0 +1,108 @@ +import gc +from threading import Thread +import torch +import transformers +from transformers import ( + GenerationConfig, + StoppingCriteria, + StoppingCriteriaList, + TextIteratorStreamer, +) + + +@torch.inference_mode() +def generate_stream_codet5p( + model, + tokenizer, + params, + device, + context_len=2048, + stream_interval=2, + judge_sent_end=False, +): + prompt = params["prompt"] + temperature = float(params.get("temperature", 1.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = int(params.get("top_k", 50)) # -1 means disable + max_new_tokens = int(params.get("max_new_tokens", 1024)) + stop_token_ids = params.get("stop_token_ids", None) or [] + stop_token_ids.append(tokenizer.eos_token_id) + + decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True) + streamer = TextIteratorStreamer(tokenizer, **decode_config) + encoding = tokenizer(prompt, return_tensors="pt").to(device) + input_ids = encoding.input_ids + encoding["decoder_input_ids"] = encoding["input_ids"].clone() + input_echo_len = len(input_ids) + + generation_config = GenerationConfig( + max_new_tokens=max_new_tokens, + do_sample=temperature >= 1e-5, + temperature=temperature, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=10, + top_p=top_p, + top_k=top_k, + eos_token_id=stop_token_ids, + ) + + class CodeBlockStopper(StoppingCriteria): + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs + ) -> bool: + # Code-completion is open-end generation. + # We check \n\n to stop at end of a code block. + if list(input_ids[0][-2:]) == [628, 198]: + return True + return False + + gen_kwargs = dict( + **encoding, + streamer=streamer, + generation_config=generation_config, + stopping_criteria=StoppingCriteriaList([CodeBlockStopper()]), + ) + thread = Thread(target=model.generate, kwargs=gen_kwargs) + thread.start() + i = 0 + output = "" + for new_text in streamer: + i += 1 + output += new_text + if i % stream_interval == 0 or i == max_new_tokens - 1: + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": None, + } + if i >= max_new_tokens: + break + + if i >= max_new_tokens: + finish_reason = "length" + else: + finish_reason = "stop" + + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": finish_reason, + } + thread.join() + + # clean + gc.collect() + torch.cuda.empty_cache() + if device == "xpu": + torch.xpu.empty_cache() + if device == "npu": + torch.npu.empty_cache() diff --git a/third_party/FastChat/fastchat/model/model_exllama.py b/third_party/FastChat/fastchat/model/model_exllama.py new file mode 100644 index 00000000..306edab2 --- /dev/null +++ b/third_party/FastChat/fastchat/model/model_exllama.py @@ -0,0 +1,77 @@ +import gc +import sys +from typing import Dict + +import torch + + +def generate_stream_exllama( + model, + tokenizer, + params: Dict, + device: str, + context_len: int, + stream_interval: int = 2, + judge_sent_end: bool = False, +): + try: + from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler + except ImportError as e: + print(f"Error: Failed to load Exllamav2. {e}") + sys.exit(-1) + + prompt = params["prompt"] + + generator = ExLlamaV2StreamingGenerator(model.model, model.cache, tokenizer) + settings = ExLlamaV2Sampler.Settings() + + settings.temperature = float(params.get("temperature", 0.85)) + settings.top_k = int(params.get("top_k", 50)) + settings.top_p = float(params.get("top_p", 0.8)) + settings.token_repetition_penalty = float(params.get("repetition_penalty", 1.15)) + settings.disallow_tokens(generator.tokenizer, [generator.tokenizer.eos_token_id]) + + max_new_tokens = int(params.get("max_new_tokens", 256)) + + generator.set_stop_conditions(params.get("stop_token_ids", None) or []) + echo = bool(params.get("echo", True)) + + input_ids = generator.tokenizer.encode(prompt) + prompt_tokens = input_ids.shape[-1] + generator.begin_stream(input_ids, settings) + + generated_tokens = 0 + if echo: + output = prompt + else: + output = "" + while True: + chunk, eos, _ = generator.stream() + output += chunk + generated_tokens += 1 + if generated_tokens == max_new_tokens: + finish_reason = "length" + break + elif eos: + finish_reason = "length" + break + yield { + "text": output, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": generated_tokens, + "total_tokens": prompt_tokens + generated_tokens, + }, + "finish_reason": None, + } + + yield { + "text": output, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": generated_tokens, + "total_tokens": prompt_tokens + generated_tokens, + }, + "finish_reason": finish_reason, + } + gc.collect() diff --git a/third_party/FastChat/fastchat/model/model_falcon.py b/third_party/FastChat/fastchat/model/model_falcon.py new file mode 100644 index 00000000..dc8af8ef --- /dev/null +++ b/third_party/FastChat/fastchat/model/model_falcon.py @@ -0,0 +1,140 @@ +import gc +from threading import Thread +from typing import Iterable + +import torch +import transformers +from transformers import TextIteratorStreamer, GenerationConfig + +from fastchat.utils import is_partial_stop + + +@torch.inference_mode() +def generate_stream_falcon( + model, + tokenizer, + params, + device, + context_len=2048, + stream_interval=2, + judge_sent_end=False, +): + prompt = params["prompt"] + len_prompt = len(prompt) + temperature = float(params.get("temperature", 1.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = int(params.get("top_k", 50)) # -1 means disable + max_new_tokens = int(params.get("max_new_tokens", 256)) + stop_str = params.get("stop", None) + echo = bool(params.get("echo", True)) + stop_token_ids = params.get("stop_token_ids", None) or [] + stop_token_ids.append(tokenizer.eos_token_id) + + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + + max_src_len = context_len - max_new_tokens - 8 + + input_ids = input_ids[-max_src_len:] # truncate from the left + attention_mask = attention_mask[-max_src_len:] # truncate from the left + input_echo_len = len(input_ids) + + decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True) + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config) + + generation_config = GenerationConfig( + max_new_tokens=max_new_tokens, + do_sample=temperature >= 1e-5, + temperature=temperature, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=10, + top_p=top_p, + top_k=top_k, + eos_token_id=stop_token_ids, + ) + + generation_kwargs = dict( + inputs=input_ids, + attention_mask=attention_mask, + streamer=streamer, + generation_config=generation_config, + ) + + thread = Thread(target=model.generate, kwargs=generation_kwargs) + thread.start() + + if echo: + # means keep the prompt + output = prompt + else: + output = "" + + for i, new_text in enumerate(streamer): + output += new_text + if i % stream_interval == 0: + if echo: + rfind_start = len_prompt + else: + rfind_start = 0 + + partially_stopped = False + if stop_str: + if isinstance(stop_str, str): + pos = output.rfind(stop_str, rfind_start) + if pos != -1: + output = output[:pos] + else: + partially_stopped = is_partial_stop(output, stop_str) + elif isinstance(stop_str, Iterable): + for each_stop in stop_str: + pos = output.rfind(each_stop, rfind_start) + if pos != -1: + output = output[:pos] + break + else: + partially_stopped = is_partial_stop(output, each_stop) + if partially_stopped: + break + else: + raise ValueError("Invalid stop field type.") + + # prevent yielding partial stop sequence + if not partially_stopped: + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": None, + } + output = output.strip() + + # finish stream event, which contains finish reason + if i == max_new_tokens - 1: + finish_reason = "length" + elif partially_stopped: + finish_reason = None + else: + finish_reason = "stop" + + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": finish_reason, + } + + # clean + gc.collect() + torch.cuda.empty_cache() + if device == "xpu": + torch.xpu.empty_cache() + if device == "npu": + torch.npu.empty_cache() diff --git a/third_party/FastChat/fastchat/model/model_registry.py b/third_party/FastChat/fastchat/model/model_registry.py new file mode 100644 index 00000000..0e8b650f --- /dev/null +++ b/third_party/FastChat/fastchat/model/model_registry.py @@ -0,0 +1,552 @@ +"""Additional information of the models.""" +from collections import namedtuple, OrderedDict +from typing import List + + +ModelInfo = namedtuple("ModelInfo", ["simple_name", "link", "description"]) + + +model_info = OrderedDict() + + +def register_model_info( + full_names: List[str], simple_name: str, link: str, description: str +): + info = ModelInfo(simple_name, link, description) + + for full_name in full_names: + model_info[full_name] = info + + +def get_model_info(name: str) -> ModelInfo: + if name in model_info: + return model_info[name] + else: + # To fix this, please use `register_model_info` to register your model + return ModelInfo( + name, "", "Register the description at fastchat/model/model_registry.py" + ) + + +register_model_info( + ["mixtral-8x7b-instruct-v0.1", "mistral-7b-instruct"], + "Mixtral of experts", + "https://mistral.ai/news/mixtral-of-experts/", + "A Mixture-of-Experts model by Mistral AI", +) + +register_model_info( + ["gemini-pro"], + "Gemini", + "https://blog.google/technology/ai/google-gemini-pro-imagen-duet-ai-update/", + "Gemini by Google", +) + +register_model_info( + ["solar-10.7b-instruct-v1.0"], + "SOLAR-10.7B-Instruct", + "https://huggingface.co/upstage/SOLAR-10.7B-Instruct-v1.0", + "A model trained using depth up-scaling by Upstage AI", +) + +register_model_info( + ["gpt-4-turbo"], + "GPT-4-Turbo", + "https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo", + "GPT-4-Turbo by OpenAI", +) + +register_model_info( + ["gpt-3.5-turbo", "gpt-3.5-turbo-0314", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106"], + "GPT-3.5", + "https://platform.openai.com/docs/models/gpt-3-5", + "GPT-3.5-Turbo by OpenAI", +) + +register_model_info( + ["gpt-4", "gpt-4-0314", "gpt-4-0613"], + "GPT-4", + "https://openai.com/research/gpt-4", + "GPT-4 by OpenAI", +) + +register_model_info( + ["claude-2.1", "claude-2.0"], + "Claude", + "https://www.anthropic.com/index/claude-2", + "Claude 2 by Anthropic", +) + +register_model_info( + ["claude-1"], + "Claude", + "https://www.anthropic.com/index/introducing-claude", + "Claude 1 by Anthropic", +) + +register_model_info( + ["claude-instant-1", "claude-instant-1.2"], + "Claude Instant", + "https://www.anthropic.com/index/introducing-claude", + "Claude Instant by Anthropic", +) + +register_model_info( + ["pplx-70b-online", "pplx-7b-online"], + "pplx-online-llms", + "https://blog.perplexity.ai/blog/introducing-pplx-online-llms", + "Online LLM API by Perplexity AI", +) + +register_model_info( + ["openhermes-2.5-mistral-7b"], + "OpenHermes-2.5-Mistral-7B", + "https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B", + "a mistral-based model fine-tuned on 1M GPT-4 outputs", +) + +register_model_info( + ["starling-lm-7b-alpha"], + "Starling-LM-7B-alpha", + "https://huggingface.co/berkeley-nest/Starling-LM-7B-alpha", + "an open model trained using RLAIF by Berkeley", +) + +register_model_info( + ["tulu-2-dpo-70b"], + "Tulu 2", + "https://huggingface.co/allenai/tulu-2-dpo-70b", + "an instruction and RLHF model by UW/AllenAI", +) + +register_model_info( + ["yi-34b-chat", "yi-6b-chat"], + "Yi-Chat", + "https://huggingface.co/01-ai/Yi-34B-Chat", + "A large language model by 01 AI", +) + +register_model_info( + ["llama-2-70b-chat", "llama-2-34b-chat", "llama-2-13b-chat", "llama-2-7b-chat"], + "Llama 2", + "https://ai.meta.com/llama/", + "open foundation and fine-tuned chat models by Meta", +) + +register_model_info( + [ + "vicuna-33b", + "vicuna-33b-v1.3", + "vicuna-13b", + "vicuna-13b-v1.3", + "vicuna-7b", + "vicuna-7b-v1.3", + ], + "Vicuna", + "https://lmsys.org/blog/2023-03-30-vicuna/", + "a chat assistant fine-tuned on user-shared conversations by LMSYS", +) + +register_model_info( + ["chatglm3-6b", "chatglm2-6b", "chatglm-6b"], + "ChatGLM", + "https://chatglm.cn/blog", + "an open bilingual dialogue language model by Tsinghua University", +) + +register_model_info( + ["openchat-3.5"], + "OpenChat 3.5", + "https://github.com/imoneoi/openchat", + "an open model fine-tuned on Mistral-7B using C-RLFT", +) + +register_model_info( + ["tenyxchat-7b-v1"], + "TenyxChat-7B", + "https://huggingface.co/tenyx/TenyxChat-7B-v1", + "an open model DPO trained on top of OpenChat-3.5 using Tenyx fine-tuning", +) + +register_model_info( + ["zephyr-7b-beta", "zephyr-7b-alpha"], + "Zephyr", + "https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha", + "a chatbot fine-tuned from Mistral by Hugging Face", +) + +register_model_info( + ["notus-7b-v1"], + "Notus", + "https://huggingface.co/argilla/notus-7b-v1", + "a chatbot fine-tuned from Zephyr SFT by Argilla", +) + +register_model_info( + ["catppt"], + "CatPPT", + "https://huggingface.co/rishiraj/CatPPT", + "a chatbot fine-tuned from a SLERP merged model by Rishiraj Acharya", +) + +register_model_info( + ["TinyLlama"], + "TinyLlama", + "https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "The TinyLlama project is an open endeavor to pretrain a 1.1B Llama model on 3 trillion tokens.", +) + +register_model_info( + ["qwen-14b-chat"], + "Qwen", + "https://huggingface.co/Qwen/Qwen-14B-Chat", + "a large language model by Alibaba Cloud", +) + +register_model_info( + ["codellama-34b-instruct", "codellama-13b-instruct", "codellama-7b-instruct"], + "Code Llama", + "https://ai.meta.com/blog/code-llama-large-language-model-coding/", + "open foundation models for code by Meta", +) + +register_model_info( + ["wizardlm-70b", "wizardlm-30b", "wizardlm-13b"], + "WizardLM", + "https://github.com/nlpxucan/WizardLM", + "an instruction-following LLM using evol-instruct by Microsoft", +) + +register_model_info( + ["wizardcoder-15b-v1.0"], + "WizardLM", + "https://github.com/nlpxucan/WizardLM/tree/main/WizardCoder", + "Empowering Code Large Language Models with Evol-Instruct", +) + +register_model_info( + ["mpt-7b-chat", "mpt-30b-chat"], + "MPT-Chat", + "https://www.mosaicml.com/blog/mpt-30b", + "a chatbot fine-tuned from MPT by MosaicML", +) + +register_model_info( + ["guanaco-33b", "guanaco-65b"], + "Guanaco", + "https://github.com/artidoro/qlora", + "a model fine-tuned with QLoRA by UW", +) + +register_model_info( + ["gpt4all-13b-snoozy"], + "GPT4All-Snoozy", + "https://github.com/nomic-ai/gpt4all", + "a finetuned LLaMA model on assistant style data by Nomic AI", +) + +register_model_info( + ["koala-13b"], + "Koala", + "https://bair.berkeley.edu/blog/2023/04/03/koala", + "a dialogue model for academic research by BAIR", +) + +register_model_info( + ["RWKV-4-Raven-14B"], + "RWKV-4-Raven", + "https://huggingface.co/BlinkDL/rwkv-4-raven", + "an RNN with transformer-level LLM performance", +) + +register_model_info( + ["alpaca-13b"], + "Alpaca", + "https://crfm.stanford.edu/2023/03/13/alpaca.html", + "a model fine-tuned from LLaMA on instruction-following demonstrations by Stanford", +) + +register_model_info( + ["oasst-pythia-12b"], + "OpenAssistant (oasst)", + "https://open-assistant.io", + "an Open Assistant for everyone by LAION", +) + +register_model_info( + ["oasst-sft-7-llama-30b"], + "OpenAssistant (oasst)", + "https://open-assistant.io", + "an Open Assistant for everyone by LAION", +) + +register_model_info( + ["palm-2"], + "PaLM 2 Chat", + "https://cloud.google.com/vertex-ai/docs/release-notes#May_10_2023", + "PaLM 2 for Chat (chat-bison@001) by Google", +) + +register_model_info( + ["llama-7b", "llama-13b"], + "LLaMA", + "https://arxiv.org/abs/2302.13971", + "open and efficient foundation language models by Meta", +) + +register_model_info( + ["open-llama-7b-v2-open-instruct", "open-llama-7b-open-instruct"], + "Open LLaMa (Open Instruct)", + "https://medium.com/vmware-data-ml-blog/starter-llm-for-the-enterprise-instruction-tuning-openllama-7b-d05fc3bbaccc", + "Open LLaMa fine-tuned on instruction-following data by VMware", +) + +register_model_info( + ["dolly-v2-12b"], + "Dolly", + "https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm", + "an instruction-tuned open large language model by Databricks", +) + +register_model_info( + ["stablelm-tuned-alpha-7b"], + "StableLM", + "https://github.com/stability-AI/stableLM", + "Stability AI language models", +) + +register_model_info( + ["codet5p-6b"], + "CodeT5p-6b", + "https://huggingface.co/Salesforce/codet5p-6b", + "Code completion model released by Salesforce", +) + +register_model_info( + ["fastchat-t5-3b", "fastchat-t5-3b-v1.0"], + "FastChat-T5", + "https://huggingface.co/lmsys/fastchat-t5-3b-v1.0", + "a chat assistant fine-tuned from FLAN-T5 by LMSYS", +) + +register_model_info( + ["phoenix-inst-chat-7b"], + "Phoenix-7B", + "https://huggingface.co/FreedomIntelligence/phoenix-inst-chat-7b", + "a multilingual chat assistant fine-tuned from Bloomz to democratize ChatGPT across languages by CUHK(SZ)", +) + +register_model_info( + ["realm-7b-v1"], + "ReaLM", + "https://github.com/FreedomIntelligence/ReaLM", + "A chatbot fine-tuned from LLaMA2 with data generated via iterative calls to UserGPT and ChatGPT by CUHK(SZ) and SRIBD.", +) + +register_model_info( + ["billa-7b-sft"], + "BiLLa-7B-SFT", + "https://huggingface.co/Neutralzz/BiLLa-7B-SFT", + "an instruction-tuned bilingual LLaMA with enhanced reasoning ability by an independent researcher", +) + +register_model_info( + ["h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2"], + "h2oGPT-GM-7b", + "https://huggingface.co/h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2", + "an instruction-tuned OpenLLaMA with enhanced conversational ability by H2O.ai", +) + +register_model_info( + ["baize-v2-7b", "baize-v2-13b"], + "Baize v2", + "https://github.com/project-baize/baize-chatbot#v2", + "A chatbot fine-tuned from LLaMA with ChatGPT self-chat data and Self-Disillation with Feedback (SDF) by UCSD and SYSU.", +) + +register_model_info( + [ + "airoboros-l2-7b-2.1", + "airoboros-l2-13b-2.1", + "airoboros-c34b-2.1", + "airoboros-l2-70b-2.1", + ], + "airoboros", + "https://huggingface.co/jondurbin/airoboros-l2-70b-2.1", + "an instruction-tuned LlaMa model tuned with 100% synthetic instruction-response pairs from GPT4", +) + +register_model_info( + [ + "spicyboros-7b-2.2", + "spicyboros-13b-2.2", + "spicyboros-70b-2.2", + ], + "spicyboros", + "https://huggingface.co/jondurbin/spicyboros-70b-2.2", + "de-aligned versions of the airoboros models", +) + +register_model_info( + ["Robin-7b-v2", "Robin-13b-v2", "Robin-33b-v2"], + "Robin-v2", + "https://huggingface.co/OptimalScale/robin-7b-v2-delta", + "A chatbot fine-tuned from LLaMA-7b, achieving competitive performance on chitchat, commonsense reasoning and instruction-following tasks, by OptimalScale, HKUST.", +) + +register_model_info( + ["manticore-13b-chat"], + "Manticore 13B Chat", + "https://huggingface.co/openaccess-ai-collective/manticore-13b-chat-pyg", + "A chatbot fine-tuned from LlaMa across several CoT and chat datasets.", +) + +register_model_info( + ["redpajama-incite-7b-chat"], + "RedPajama-INCITE-7B-Chat", + "https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Chat", + "A chatbot fine-tuned from RedPajama-INCITE-7B-Base by Together", +) + +register_model_info( + [ + "falcon-7b", + "falcon-7b-instruct", + "falcon-40b", + "falcon-40b-instruct", + "falcon-180b", + "falcon-180b-chat", + ], + "Falcon", + "https://huggingface.co/tiiuae/falcon-180B", + "TII's flagship series of large language models", +) + +register_model_info( + ["tigerbot-7b-sft"], + "Tigerbot", + "https://huggingface.co/TigerResearch/tigerbot-7b-sft", + "TigerBot is a large-scale language model (LLM) with multiple languages and tasks.", +) + +register_model_info( + ["internlm-chat-7b", "internlm-chat-7b-8k"], + "InternLM", + "https://huggingface.co/internlm/internlm-chat-7b", + "InternLM is a multi-language large-scale language model (LLM), developed by SHLAB.", +) + +register_model_info( + ["Qwen-7B-Chat"], + "Qwen", + "https://huggingface.co/Qwen/Qwen-7B-Chat", + "Qwen is a multi-language large-scale language model (LLM), developed by Damo Academy.", +) + +register_model_info( + ["Llama2-Chinese-13b-Chat", "LLama2-Chinese-13B"], + "Llama2-Chinese", + "https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat", + "Llama2-Chinese is a multi-language large-scale language model (LLM), developed by FlagAlpha.", +) + +register_model_info( + ["Chinese-Alpaca-2-7B", "Chinese-Alpaca-2-13B"], + "Chinese-Alpaca", + "https://huggingface.co/hfl/chinese-alpaca-2-13b", + "New extended Chinese vocabulary beyond Llama-2, open-sourcing the Chinese LLaMA-2 and Alpaca-2 LLMs.", +) + +register_model_info( + ["Vigogne-2-7B-Instruct", "Vigogne-2-13B-Instruct"], + "Vigogne-Instruct", + "https://huggingface.co/bofenghuang/vigogne-2-7b-instruct", + "Vigogne-Instruct is a French large language model (LLM) optimized for instruction-following, developed by Bofeng Huang", +) + +register_model_info( + ["Vigogne-2-7B-Chat", "Vigogne-2-13B-Chat"], + "Vigogne-Chat", + "https://huggingface.co/bofenghuang/vigogne-2-7b-chat", + "Vigogne-Chat is a French large language model (LLM) optimized for instruction-following and multi-turn dialogues, developed by Bofeng Huang", +) + +register_model_info( + ["stable-vicuna-13B-HF"], + "stable-vicuna", + "https://huggingface.co/TheBloke/stable-vicuna-13B-HF", + "StableVicuna is a Vicuna model fine-tuned using RLHF via PPO on various conversational and instructional datasets.", +) + +register_model_info( + ["deluxe-chat-v1", "deluxe-chat-v1.1", "deluxe-chat-v1.2"], + "DeluxeChat", + "", + "Deluxe Chat", +) + +register_model_info( + [ + "Xwin-LM-7B-V0.1", + "Xwin-LM-13B-V0.1", + "Xwin-LM-70B-V0.1", + "Xwin-LM-7B-V0.2", + "Xwin-LM-13B-V0.2", + ], + "Xwin-LM", + "https://github.com/Xwin-LM/Xwin-LM", + "Chat models developed by Xwin-LM team", +) + +register_model_info( + ["lemur-70b-chat"], + "Lemur-Chat", + "https://huggingface.co/OpenLemur/lemur-70b-chat-v1", + "an openly accessible language model optimized for both natural language and coding capabilities ", +) + +register_model_info( + ["Mistral-7B-OpenOrca"], + "Open-Orca", + "https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca", + "A fine-tune of [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1) using [OpenOrca dataset](https://huggingface.co/datasets/Open-Orca/OpenOrca)", +) + +register_model_info( + ["dolphin-2.2.1-mistral-7b"], + "dolphin-mistral", + "https://huggingface.co/ehartford/dolphin-2.2.1-mistral-7b", + "An uncensored fine-tuned Mistral 7B", +) + +register_model_info( + [ + "AquilaChat-7B", + "AquilaChat2-7B", + "AquilaChat2-34B", + ], + "Aquila-Chat", + "https://huggingface.co/BAAI/AquilaChat2-34B", + "Chat models developed by BAAI team", +) + +register_model_info( + ["xDAN-L1-Chat-RL-v1"], + "xDAN-L1-Chat", + "https://huggingface.co/xDAN-AI/xDAN-L1-Chat-RL-v1", + "A large language chat model created by xDAN-AI.", +) + +register_model_info( + ["MetaMath-70B-V1.0", "MetaMath-7B-V1.0"], + "MetaMath", + "https://huggingface.co/meta-math", + "MetaMath is a finetune of Llama2 on [MetaMathQA](https://huggingface.co/datasets/meta-math/MetaMathQA) that specializes in mathematical reasoning.", +) + +register_model_info( + ["Yuan2-2B-hf", "Yuan2-51B-hf", "Yuan2-102B-hf"], + "IEIYuan", + "https://huggingface.co/IEITYuan", + "Yuan2 is a Basemodel developed by IEI.", +) diff --git a/third_party/FastChat/fastchat/model/model_xfastertransformer.py b/third_party/FastChat/fastchat/model/model_xfastertransformer.py new file mode 100644 index 00000000..54890b1c --- /dev/null +++ b/third_party/FastChat/fastchat/model/model_xfastertransformer.py @@ -0,0 +1,81 @@ +import gc +from threading import Thread + +import torch +from transformers import TextIteratorStreamer + + +@torch.inference_mode() +def generate_stream_xft( + model, + tokenizer, + params, + device, + context_len=8192, + stream_interval=2, + judge_sent_end=False, +): + prompt = params["prompt"] + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + + # unused now, and placehold for future. + # temperature = float(params.get("temperature", 1.0)) + # top_p = float(params.get("top_p", 1.0)) + + max_new_tokens = int(params.get("max_new_tokens", 4096)) + echo = params.get("echo", True) + + inputs = tokenizer( + prompt, return_tensors="pt", padding=model.config.padding + ).input_ids + input_echo_len = len(inputs[0]) + max_len = max_new_tokens + input_echo_len + + decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True) + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config) + generation_kwargs = { + "input_ids": inputs, + "streamer": streamer, + "max_length": max_len, + "num_beams": model.config.beam_width, + "length_penalty": repetition_penalty, + "num_return_sequences": model.config.num_return_sequences, + "early_stopping": model.config.early_stopping, + "eos_token_id": model.config.eos_token_id, + "pad_token_id": model.config.pad_token_id, + } + + thread = Thread(target=model.model.generate, kwargs=generation_kwargs) + thread.start() + if echo: + # means keep the prompt + output = prompt + else: + output = "" + i = 0 + for i, new_text in enumerate(streamer): + output += new_text + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": None, + } + output = output.strip() + if i == max_new_tokens - 1: + finish_reason = "length" + else: + finish_reason = "stop" + yield { + "text": output, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": finish_reason, + } + gc.collect() diff --git a/third_party/FastChat/fastchat/model/monkey_patch_non_inplace.py b/third_party/FastChat/fastchat/model/monkey_patch_non_inplace.py new file mode 100644 index 00000000..413dd3b3 --- /dev/null +++ b/third_party/FastChat/fastchat/model/monkey_patch_non_inplace.py @@ -0,0 +1,119 @@ +""" +Monkey patch the llama implementation in the huggingface/transformers library. +Avoid bugs in mps backend by not using in-place operations. +""" +import math +from typing import List, Optional, Tuple + +import torch +from torch import nn +import transformers + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2].clone() + x2 = x[..., x.shape[-1] // 2 :].clone() + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] + gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) + cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.k_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( + self.head_dim + ) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) + ) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def replace_llama_attn_with_non_inplace_operations(): + """Avoid bugs in mps backend by not using in-place operations.""" + transformers.models.llama.modeling_llama.LlamaAttention.forward = forward diff --git a/third_party/FastChat/fastchat/model/rwkv_model.py b/third_party/FastChat/fastchat/model/rwkv_model.py new file mode 100644 index 00000000..bdbc1458 --- /dev/null +++ b/third_party/FastChat/fastchat/model/rwkv_model.py @@ -0,0 +1,76 @@ +import os +from types import SimpleNamespace +import warnings + +import torch + +os.environ["RWKV_JIT_ON"] = "1" +os.environ["RWKV_CUDA_ON"] = "1" + +from rwkv.model import RWKV +from rwkv.utils import PIPELINE, PIPELINE_ARGS + + +class RwkvModel: + def __init__(self, model_path): + warnings.warn( + "Experimental support. Please use ChatRWKV if you want to chat with RWKV" + ) + self.config = SimpleNamespace(is_encoder_decoder=False) + self.model = RWKV(model=model_path, strategy="cuda fp16") + # two GPUs + # self.model = RWKV(model=model_path, strategy="cuda:0 fp16 *20 -> cuda:1 fp16") + + self.tokenizer = None + self.model_path = model_path + + def to(self, target): + assert target == "cuda" + + def __call__(self, input_ids, use_cache, past_key_values=None): + assert use_cache == True + input_ids = input_ids[0].detach().cpu().numpy() + # print(input_ids) + logits, state = self.model.forward(input_ids, past_key_values) + # print(logits) + logits = logits.unsqueeze(0).unsqueeze(0) + out = SimpleNamespace(logits=logits, past_key_values=state) + return out + + def generate( + self, input_ids, do_sample, temperature, max_new_tokens, repetition_penalty=1.0 + ): + # This function is used by fastchat.llm_judge. + # Because RWKV does not support huggingface generation API, + # we reuse fastchat.serve.inference.generate_stream as a workaround. + from transformers import AutoTokenizer + + from fastchat.serve.inference import generate_stream + from fastchat.conversation import get_conv_template + + if self.tokenizer is None: + self.tokenizer = AutoTokenizer.from_pretrained( + "EleutherAI/pythia-160m", use_fast=True + ) + prompt = self.tokenizer.decode(input_ids[0].tolist()) + conv = get_conv_template("rwkv") + + gen_params = { + "model": self.model_path, + "prompt": prompt, + "temperature": temperature, + "repetition_penalty": repetition_penalty, + "max_new_tokens": max_new_tokens, + "stop": conv.stop_str, + "stop_token_ids": conv.stop_token_ids, + "echo": False, + } + res_iter = generate_stream(self, self.tokenizer, gen_params, "cuda") + + for res in res_iter: + pass + + output = res["text"] + output_ids = self.tokenizer.encode(output) + + return [input_ids[0].tolist() + output_ids] diff --git a/third_party/FastChat/fastchat/model/upload_hub.py b/third_party/FastChat/fastchat/model/upload_hub.py new file mode 100644 index 00000000..b1519652 --- /dev/null +++ b/third_party/FastChat/fastchat/model/upload_hub.py @@ -0,0 +1,45 @@ +""" +Upload weights to huggingface. + +Usage: +python3 -m fastchat.model.upload_hub --model-path ~/model_weights/vicuna-13b --hub-repo-id lmsys/vicuna-13b-v1.3 +""" +import argparse +import tempfile + +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM + + +def upload_hub(model_path, hub_repo_id, component, private): + if component == "all": + components = ["model", "tokenizer"] + else: + components = [component] + + kwargs = {"push_to_hub": True, "repo_id": hub_repo_id, "private": args.private} + + if "model" in components: + model = AutoModelForCausalLM.from_pretrained( + model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + with tempfile.TemporaryDirectory() as tmp_path: + model.save_pretrained(tmp_path, **kwargs) + + if "tokenizer" in components: + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + with tempfile.TemporaryDirectory() as tmp_path: + tokenizer.save_pretrained(tmp_path, **kwargs) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, required=True) + parser.add_argument("--hub-repo-id", type=str, required=True) + parser.add_argument( + "--component", type=str, choices=["all", "model", "tokenizer"], default="all" + ) + parser.add_argument("--private", action="store_true") + args = parser.parse_args() + + upload_hub(args.model_path, args.hub_repo_id, args.component, args.private) diff --git a/third_party/FastChat/fastchat/modules/__init__.py b/third_party/FastChat/fastchat/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/third_party/FastChat/fastchat/modules/awq.py b/third_party/FastChat/fastchat/modules/awq.py new file mode 100644 index 00000000..1f27be85 --- /dev/null +++ b/third_party/FastChat/fastchat/modules/awq.py @@ -0,0 +1,85 @@ +from dataclasses import dataclass, field +from pathlib import Path +import sys + +import torch +from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, modeling_utils + + +@dataclass +class AWQConfig: + ckpt: str = field( + default=None, + metadata={ + "help": "Load quantized model. The path to the local AWQ checkpoint." + }, + ) + wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"}) + groupsize: int = field( + default=-1, + metadata={"help": "Groupsize to use for quantization; default uses full row."}, + ) + + +def load_awq_quantized(model_name, awq_config: AWQConfig, device): + print("Loading AWQ quantized model...") + + try: + from tinychat.utils import load_quant + from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp + except ImportError as e: + print(f"Error: Failed to import tinychat. {e}") + print("Please double check if you have successfully installed AWQ") + print("See https://github.com/lm-sys/FastChat/blob/main/docs/awq.md") + sys.exit(-1) + + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained( + model_name, use_fast=False, trust_remote_code=True + ) + + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.kaiming_normal_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + modeling_utils._init_weights = False + + torch.set_default_dtype(torch.half) + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + + if any(name in find_awq_ckpt(awq_config) for name in ["llama", "vicuna"]): + model = load_quant.load_awq_llama_fast( + model, + find_awq_ckpt(awq_config), + awq_config.wbits, + awq_config.groupsize, + device, + ) + make_quant_attn(model, device) + make_quant_norm(model) + make_fused_mlp(model) + else: + model = load_quant.load_awq_model( + model, + find_awq_ckpt(awq_config), + awq_config.wbits, + awq_config.groupsize, + device, + ) + return model, tokenizer + + +def find_awq_ckpt(awq_config: AWQConfig): + if Path(awq_config.ckpt).is_file(): + return awq_config.ckpt + + for ext in ["*.pt", "*.safetensors"]: + matched_result = sorted(Path(awq_config.ckpt).glob(ext)) + if len(matched_result) > 0: + return str(matched_result[-1]) + + print("Error: AWQ checkpoint not found") + sys.exit(1) diff --git a/third_party/FastChat/fastchat/modules/exllama.py b/third_party/FastChat/fastchat/modules/exllama.py new file mode 100644 index 00000000..5e5fc81b --- /dev/null +++ b/third_party/FastChat/fastchat/modules/exllama.py @@ -0,0 +1,50 @@ +from dataclasses import dataclass, field +import sys + + +@dataclass +class ExllamaConfig: + max_seq_len: int + gpu_split: str = None + cache_8bit: bool = False + + +class ExllamaModel: + def __init__(self, exllama_model, exllama_cache): + self.model = exllama_model + self.cache = exllama_cache + self.config = self.model.config + + +def load_exllama_model(model_path, exllama_config: ExllamaConfig): + try: + from exllamav2 import ( + ExLlamaV2Config, + ExLlamaV2Tokenizer, + ExLlamaV2, + ExLlamaV2Cache, + ExLlamaV2Cache_8bit, + ) + except ImportError as e: + print(f"Error: Failed to load Exllamav2. {e}") + sys.exit(-1) + + exllamav2_config = ExLlamaV2Config() + exllamav2_config.model_dir = model_path + exllamav2_config.prepare() + exllamav2_config.max_seq_len = exllama_config.max_seq_len + exllamav2_config.cache_8bit = exllama_config.cache_8bit + + exllama_model = ExLlamaV2(exllamav2_config) + tokenizer = ExLlamaV2Tokenizer(exllamav2_config) + + split = None + if exllama_config.gpu_split: + split = [float(alloc) for alloc in exllama_config.gpu_split.split(",")] + exllama_model.load(split) + + cache_class = ExLlamaV2Cache_8bit if exllamav2_config.cache_8bit else ExLlamaV2Cache + exllama_cache = cache_class(exllama_model) + model = ExllamaModel(exllama_model=exllama_model, exllama_cache=exllama_cache) + + return model, tokenizer diff --git a/third_party/FastChat/fastchat/modules/gptq.py b/third_party/FastChat/fastchat/modules/gptq.py new file mode 100644 index 00000000..fe0a220c --- /dev/null +++ b/third_party/FastChat/fastchat/modules/gptq.py @@ -0,0 +1,75 @@ +from dataclasses import dataclass, field +import os +from os.path import isdir, isfile +from pathlib import Path +import sys + +from transformers import AutoTokenizer + + +@dataclass +class GptqConfig: + ckpt: str = field( + default=None, + metadata={ + "help": "Load quantized model. The path to the local GPTQ checkpoint." + }, + ) + wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"}) + groupsize: int = field( + default=-1, + metadata={"help": "Groupsize to use for quantization; default uses full row."}, + ) + act_order: bool = field( + default=True, + metadata={"help": "Whether to apply the activation order GPTQ heuristic"}, + ) + + +def load_gptq_quantized(model_name, gptq_config: GptqConfig): + print("Loading GPTQ quantized model...") + + try: + script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) + module_path = os.path.join(script_path, "../repositories/GPTQ-for-LLaMa") + + sys.path.insert(0, module_path) + from llama import load_quant + except ImportError as e: + print(f"Error: Failed to load GPTQ-for-LLaMa. {e}") + print("See https://github.com/lm-sys/FastChat/blob/main/docs/gptq.md") + sys.exit(-1) + + tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) + # only `fastest-inference-4bit` branch cares about `act_order` + if gptq_config.act_order: + model = load_quant( + model_name, + find_gptq_ckpt(gptq_config), + gptq_config.wbits, + gptq_config.groupsize, + act_order=gptq_config.act_order, + ) + else: + # other branches + model = load_quant( + model_name, + find_gptq_ckpt(gptq_config), + gptq_config.wbits, + gptq_config.groupsize, + ) + + return model, tokenizer + + +def find_gptq_ckpt(gptq_config: GptqConfig): + if Path(gptq_config.ckpt).is_file(): + return gptq_config.ckpt + + for ext in ["*.pt", "*.safetensors"]: + matched_result = sorted(Path(gptq_config.ckpt).glob(ext)) + if len(matched_result) > 0: + return str(matched_result[-1]) + + print("Error: gptq checkpoint not found") + sys.exit(1) diff --git a/third_party/FastChat/fastchat/modules/xfastertransformer.py b/third_party/FastChat/fastchat/modules/xfastertransformer.py new file mode 100644 index 00000000..0b49bea4 --- /dev/null +++ b/third_party/FastChat/fastchat/modules/xfastertransformer.py @@ -0,0 +1,46 @@ +from dataclasses import dataclass +import sys + + +@dataclass +class XftConfig: + max_seq_len: int = 4096 + beam_width: int = 1 + eos_token_id: int = -1 + pad_token_id: int = -1 + num_return_sequences: int = 1 + is_encoder_decoder: bool = False + padding: bool = True + early_stopping: bool = False + data_type: str = "bf16_fp16" + + +class XftModel: + def __init__(self, xft_model, xft_config): + self.model = xft_model + self.config = xft_config + + +def load_xft_model(model_path, xft_config: XftConfig): + try: + import xfastertransformer + from transformers import AutoTokenizer + except ImportError as e: + print(f"Error: Failed to load xFasterTransformer. {e}") + sys.exit(-1) + + if xft_config.data_type is None or xft_config.data_type == "": + data_type = "bf16_fp16" + else: + data_type = xft_config.data_type + tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=False, padding_side="left", trust_remote_code=True + ) + xft_model = xfastertransformer.AutoModel.from_pretrained( + model_path, dtype=data_type + ) + model = XftModel(xft_model=xft_model, xft_config=xft_config) + if model.model.rank > 0: + while True: + model.model.generate() + return model, tokenizer diff --git a/third_party/FastChat/fastchat/protocol/api_protocol.py b/third_party/FastChat/fastchat/protocol/api_protocol.py new file mode 100644 index 00000000..2dc99449 --- /dev/null +++ b/third_party/FastChat/fastchat/protocol/api_protocol.py @@ -0,0 +1,172 @@ +from typing import Literal, Optional, List, Dict, Any, Union + +import time + +import shortuuid +from pydantic import BaseModel, Field + + +class ErrorResponse(BaseModel): + object: str = "error" + message: str + code: int + + +class ModelPermission(BaseModel): + id: str = Field(default_factory=lambda: f"modelperm-{shortuuid.random()}") + object: str = "model_permission" + created: int = Field(default_factory=lambda: int(time.time())) + allow_create_engine: bool = False + allow_sampling: bool = True + allow_logprobs: bool = True + allow_search_indices: bool = True + allow_view: bool = True + allow_fine_tuning: bool = False + organization: str = "*" + group: Optional[str] = None + is_blocking: str = False + + +class ModelCard(BaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "fastchat" + root: Optional[str] = None + parent: Optional[str] = None + permission: List[ModelPermission] = [] + + +class ModelList(BaseModel): + object: str = "list" + data: List[ModelCard] = [] + + +class UsageInfo(BaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens: Optional[int] = 0 + + +class APIChatCompletionRequest(BaseModel): + model: str + messages: Union[str, List[Dict[str, str]]] + temperature: Optional[float] = 0.7 + top_p: Optional[float] = 1.0 + top_k: Optional[int] = -1 + n: Optional[int] = 1 + max_tokens: Optional[int] = None + stop: Optional[Union[str, List[str]]] = None + stream: Optional[bool] = False + user: Optional[str] = None + repetition_penalty: Optional[float] = 1.0 + frequency_penalty: Optional[float] = 0.0 + presence_penalty: Optional[float] = 0.0 + + +class ChatMessage(BaseModel): + role: str + content: str + + +class ChatCompletionResponseChoice(BaseModel): + index: int + message: ChatMessage + finish_reason: Optional[Literal["stop", "length"]] = None + + +class ChatCompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") + object: str = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseChoice] + usage: UsageInfo + + +class DeltaMessage(BaseModel): + role: Optional[str] = None + content: Optional[str] = None + + +class ChatCompletionResponseStreamChoice(BaseModel): + index: int + delta: DeltaMessage + finish_reason: Optional[Literal["stop", "length"]] = None + + +class ChatCompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") + object: str = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseStreamChoice] + + +class APITokenCheckRequestItem(BaseModel): + model: str + prompt: str + max_tokens: int + + +class APITokenCheckRequest(BaseModel): + prompts: List[APITokenCheckRequestItem] + + +class APITokenCheckResponseItem(BaseModel): + fits: bool + tokenCount: int + contextLength: int + + +class APITokenCheckResponse(BaseModel): + prompts: List[APITokenCheckResponseItem] + + +class CompletionRequest(BaseModel): + model: str + prompt: Union[str, List[Any]] + suffix: Optional[str] = None + temperature: Optional[float] = 0.7 + n: Optional[int] = 1 + max_tokens: Optional[int] = 16 + stop: Optional[Union[str, List[str]]] = None + stream: Optional[bool] = False + top_p: Optional[float] = 1.0 + top_k: Optional[int] = -1 + logprobs: Optional[int] = None + echo: Optional[bool] = False + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + user: Optional[str] = None + + +class CompletionResponseChoice(BaseModel): + index: int + text: str + logprobs: Optional[int] = None + finish_reason: Optional[Literal["stop", "length"]] = None + + +class CompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseChoice] + usage: UsageInfo + + +class CompletionResponseStreamChoice(BaseModel): + index: int + text: str + logprobs: Optional[float] = None + finish_reason: Optional[Literal["stop", "length"]] = None + + +class CompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseStreamChoice] diff --git a/third_party/FastChat/fastchat/protocol/openai_api_protocol.py b/third_party/FastChat/fastchat/protocol/openai_api_protocol.py new file mode 100644 index 00000000..6a006339 --- /dev/null +++ b/third_party/FastChat/fastchat/protocol/openai_api_protocol.py @@ -0,0 +1,195 @@ +from typing import Literal, Optional, List, Dict, Any, Union + +import time + +import shortuuid +from pydantic import BaseModel, Field + + +class ErrorResponse(BaseModel): + object: str = "error" + message: str + code: int + + +class ModelPermission(BaseModel): + id: str = Field(default_factory=lambda: f"modelperm-{shortuuid.random()}") + object: str = "model_permission" + created: int = Field(default_factory=lambda: int(time.time())) + allow_create_engine: bool = False + allow_sampling: bool = True + allow_logprobs: bool = True + allow_search_indices: bool = True + allow_view: bool = True + allow_fine_tuning: bool = False + organization: str = "*" + group: Optional[str] = None + is_blocking: str = False + + +class ModelCard(BaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "fastchat" + root: Optional[str] = None + parent: Optional[str] = None + permission: List[ModelPermission] = [] + + +class ModelList(BaseModel): + object: str = "list" + data: List[ModelCard] = [] + + +class UsageInfo(BaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens: Optional[int] = 0 + + +class LogProbs(BaseModel): + text_offset: List[int] = Field(default_factory=list) + token_logprobs: List[Optional[float]] = Field(default_factory=list) + tokens: List[str] = Field(default_factory=list) + top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list) + + +class ChatCompletionRequest(BaseModel): + model: str + messages: Union[str, List[Dict[str, str]]] + temperature: Optional[float] = 0.7 + top_p: Optional[float] = 1.0 + top_k: Optional[int] = -1 + n: Optional[int] = 1 + max_tokens: Optional[int] = None + stop: Optional[Union[str, List[str]]] = None + stream: Optional[bool] = False + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + user: Optional[str] = None + + +class ChatMessage(BaseModel): + role: str + content: str + + +class ChatCompletionResponseChoice(BaseModel): + index: int + message: ChatMessage + finish_reason: Optional[Literal["stop", "length"]] = None + + +class ChatCompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") + object: str = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseChoice] + usage: UsageInfo + + +class DeltaMessage(BaseModel): + role: Optional[str] = None + content: Optional[str] = None + + +class ChatCompletionResponseStreamChoice(BaseModel): + index: int + delta: DeltaMessage + finish_reason: Optional[Literal["stop", "length"]] = None + + +class ChatCompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") + object: str = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseStreamChoice] + + +class TokenCheckRequestItem(BaseModel): + model: str + prompt: str + max_tokens: int + + +class TokenCheckRequest(BaseModel): + prompts: List[TokenCheckRequestItem] + + +class TokenCheckResponseItem(BaseModel): + fits: bool + tokenCount: int + contextLength: int + + +class TokenCheckResponse(BaseModel): + prompts: List[TokenCheckResponseItem] + + +class EmbeddingsRequest(BaseModel): + model: Optional[str] = None + engine: Optional[str] = None + input: Union[str, List[Any]] + user: Optional[str] = None + encoding_format: Optional[str] = None + + +class EmbeddingsResponse(BaseModel): + object: str = "list" + data: List[Dict[str, Any]] + model: str + usage: UsageInfo + + +class CompletionRequest(BaseModel): + model: str + prompt: Union[str, List[Any]] + suffix: Optional[str] = None + temperature: Optional[float] = 0.7 + n: Optional[int] = 1 + max_tokens: Optional[int] = 16 + stop: Optional[Union[str, List[str]]] = None + stream: Optional[bool] = False + top_p: Optional[float] = 1.0 + top_k: Optional[int] = -1 + logprobs: Optional[int] = None + echo: Optional[bool] = False + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + user: Optional[str] = None + use_beam_search: Optional[bool] = False + best_of: Optional[int] = None + + +class CompletionResponseChoice(BaseModel): + index: int + text: str + logprobs: Optional[LogProbs] = None + finish_reason: Optional[Literal["stop", "length"]] = None + + +class CompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseChoice] + usage: UsageInfo + + +class CompletionResponseStreamChoice(BaseModel): + index: int + text: str + logprobs: Optional[LogProbs] = None + finish_reason: Optional[Literal["stop", "length"]] = None + + +class CompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseStreamChoice] diff --git a/third_party/FastChat/fastchat/serve/__init__.py b/third_party/FastChat/fastchat/serve/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/third_party/FastChat/fastchat/serve/api_provider.py b/third_party/FastChat/fastchat/serve/api_provider.py new file mode 100644 index 00000000..6f3ad7dd --- /dev/null +++ b/third_party/FastChat/fastchat/serve/api_provider.py @@ -0,0 +1,241 @@ +"""Call API providers.""" + +from json import loads +import os +import random +import time + +from fastchat.utils import build_logger +from fastchat.constants import WORKER_API_TIMEOUT + + +logger = build_logger("gradio_web_server", "gradio_web_server.log") + + +def openai_api_stream_iter( + model_name, + messages, + temperature, + top_p, + max_new_tokens, + api_base=None, + api_key=None, +): + import openai + + is_azure = False + if "azure" in model_name: + is_azure = True + openai.api_type = "azure" + openai.api_version = "2023-07-01-preview" + else: + openai.api_type = "open_ai" + openai.api_version = None + + openai.api_base = api_base or "https://api.openai.com/v1" + openai.api_key = api_key or os.environ["OPENAI_API_KEY"] + if model_name == "gpt-4-turbo": + model_name = "gpt-4-1106-preview" + + # Make requests + gen_params = { + "model": model_name, + "prompt": messages, + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + } + logger.info(f"==== request ====\n{gen_params}") + + if is_azure: + res = openai.ChatCompletion.create( + engine=model_name, + messages=messages, + temperature=temperature, + max_tokens=max_new_tokens, + stream=True, + ) + else: + res = openai.ChatCompletion.create( + model=model_name, + messages=messages, + temperature=temperature, + max_tokens=max_new_tokens, + stream=True, + ) + text = "" + for chunk in res: + if len(chunk["choices"]) > 0: + text += chunk["choices"][0]["delta"].get("content", "") + data = { + "text": text, + "error_code": 0, + } + yield data + + +def anthropic_api_stream_iter(model_name, prompt, temperature, top_p, max_new_tokens): + import anthropic + + c = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"]) + + # Make requests + gen_params = { + "model": model_name, + "prompt": prompt, + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + } + logger.info(f"==== request ====\n{gen_params}") + + res = c.completions.create( + prompt=prompt, + stop_sequences=[anthropic.HUMAN_PROMPT], + max_tokens_to_sample=max_new_tokens, + temperature=temperature, + top_p=top_p, + model=model_name, + stream=True, + ) + text = "" + for chunk in res: + text += chunk.completion + data = { + "text": text, + "error_code": 0, + } + yield data + + +def init_palm_chat(model_name): + import vertexai # pip3 install google-cloud-aiplatform + from vertexai.preview.language_models import ChatModel + from vertexai.preview.generative_models import GenerativeModel + + project_id = os.environ["GCP_PROJECT_ID"] + location = "us-central1" + vertexai.init(project=project_id, location=location) + + if model_name in ["palm-2"]: + # According to release note, "chat-bison@001" is PaLM 2 for chat. + # https://cloud.google.com/vertex-ai/docs/release-notes#May_10_2023 + model_name = "chat-bison@001" + chat_model = ChatModel.from_pretrained(model_name) + chat = chat_model.start_chat(examples=[]) + elif model_name in ["gemini-pro"]: + model = GenerativeModel(model_name) + chat = model.start_chat() + return chat + + +def palm_api_stream_iter(model_name, chat, message, temperature, top_p, max_new_tokens): + if model_name in ["gemini-pro"]: + max_new_tokens = max_new_tokens * 2 + parameters = { + "temperature": temperature, + "top_p": top_p, + "max_output_tokens": max_new_tokens, + } + gen_params = { + "model": model_name, + "prompt": message, + } + gen_params.update(parameters) + if model_name == "palm-2": + response = chat.send_message(message, **parameters) + else: + response = chat.send_message(message, generation_config=parameters, stream=True) + + logger.info(f"==== request ====\n{gen_params}") + + try: + text = "" + for chunk in response: + text += chunk.text + data = { + "text": text, + "error_code": 0, + } + yield data + except Exception as e: + logger.error(f"==== error ====\n{e}") + yield { + "text": f"**API REQUEST ERROR** Reason: {e}\nPlease try again or increase the number of max tokens.", + "error_code": 1, + } + yield data + + +def ai2_api_stream_iter( + model_name, + messages, + temperature, + top_p, + max_new_tokens, + api_key=None, + api_base=None, +): + from requests import post + + # get keys and needed values + ai2_key = api_key or os.environ.get("AI2_API_KEY") + api_base = api_base or "https://inferd.allen.ai/api/v1/infer" + model_id = "mod_01hhgcga70c91402r9ssyxekan" + + # Make requests + gen_params = { + "model": model_name, + "prompt": messages, + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + } + logger.info(f"==== request ====\n{gen_params}") + + # AI2 uses vLLM, which requires that `top_p` be 1.0 for greedy sampling: + # https://github.com/vllm-project/vllm/blob/v0.1.7/vllm/sampling_params.py#L156-L157 + if temperature == 0.0 and top_p < 1.0: + raise ValueError("top_p must be 1 when temperature is 0.0") + + res = post( + api_base, + stream=True, + headers={"Authorization": f"Bearer {ai2_key}"}, + json={ + "model_id": model_id, + # This input format is specific to the Tulu2 model. Other models + # may require different input formats. See the model's schema + # documentation on InferD for more information. + "input": { + "messages": messages, + "opts": { + "max_tokens": max_new_tokens, + "temperature": temperature, + "top_p": top_p, + "logprobs": 1, # increase for more choices + }, + }, + }, + ) + + if res.status_code != 200: + logger.error(f"unexpected response ({res.status_code}): {res.text}") + raise ValueError("unexpected response from InferD", res) + + text = "" + for line in res.iter_lines(): + if line: + part = loads(line) + if "result" in part and "output" in part["result"]: + for t in part["result"]["output"]["text"]: + text += t + else: + logger.error(f"unexpected part: {part}") + raise ValueError("empty result in InferD response") + + data = { + "text": text, + "error_code": 0, + } + yield data diff --git a/third_party/FastChat/fastchat/serve/base_model_worker.py b/third_party/FastChat/fastchat/serve/base_model_worker.py new file mode 100644 index 00000000..514cc822 --- /dev/null +++ b/third_party/FastChat/fastchat/serve/base_model_worker.py @@ -0,0 +1,238 @@ +import asyncio +import threading +import time +from typing import List + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse, JSONResponse +import requests + +from fastchat.constants import WORKER_HEART_BEAT_INTERVAL +from fastchat.conversation import Conversation +from fastchat.utils import pretty_print_semaphore, build_logger + + +worker = None +logger = None + +app = FastAPI() + + +def heart_beat_worker(obj): + while True: + time.sleep(WORKER_HEART_BEAT_INTERVAL) + obj.send_heart_beat() + + +class BaseModelWorker: + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + conv_template: str = None, + ): + global logger, worker + + self.controller_addr = controller_addr + self.worker_addr = worker_addr + self.worker_id = worker_id + if model_path.endswith("/"): + model_path = model_path[:-1] + self.model_names = model_names or [model_path.split("/")[-1]] + self.limit_worker_concurrency = limit_worker_concurrency + self.conv = self.make_conv_template(conv_template, model_path) + self.conv.sep_style = int(self.conv.sep_style) + self.tokenizer = None + self.context_len = None + self.call_ct = 0 + self.semaphore = None + + self.heart_beat_thread = None + + if logger is None: + logger = build_logger("model_worker", f"model_worker_{self.worker_id}.log") + if worker is None: + worker = self + + def make_conv_template( + self, + conv_template: str = None, + model_path: str = None, + ) -> Conversation: + """ + can be overrided to costomize the conversation template for different model workers. + """ + from fastchat.conversation import get_conv_template + from fastchat.model.model_adapter import get_conversation_template + + if conv_template: + conv = get_conv_template(conv_template) + else: + conv = get_conversation_template(model_path) + return conv + + def init_heart_beat(self): + self.register_to_controller() + self.heart_beat_thread = threading.Thread( + target=heart_beat_worker, + args=(self,), + daemon=True, + ) + self.heart_beat_thread.start() + + def register_to_controller(self): + logger.info("Register to controller") + + url = self.controller_addr + "/register_worker" + data = { + "worker_name": self.worker_addr, + "check_heart_beat": True, + "worker_status": self.get_status(), + } + r = requests.post(url, json=data) + assert r.status_code == 200 + + def send_heart_beat(self): + logger.info( + f"Send heart beat. Models: {self.model_names}. " + f"Semaphore: {pretty_print_semaphore(self.semaphore)}. " + f"call_ct: {self.call_ct}. " + f"worker_id: {self.worker_id}. " + ) + + url = self.controller_addr + "/receive_heart_beat" + + while True: + try: + ret = requests.post( + url, + json={ + "worker_name": self.worker_addr, + "queue_length": self.get_queue_length(), + }, + timeout=5, + ) + exist = ret.json()["exist"] + break + except (requests.exceptions.RequestException, KeyError) as e: + logger.error(f"heart beat error: {e}") + time.sleep(5) + + if not exist: + self.register_to_controller() + + def get_queue_length(self): + if ( + self.semaphore is None + or self.semaphore._value is None + or self.semaphore._waiters is None + ): + return 0 + else: + return ( + self.limit_worker_concurrency + - self.semaphore._value + + len(self.semaphore._waiters) + ) + + def get_status(self): + return { + "model_names": self.model_names, + "speed": 1, + "queue_length": self.get_queue_length(), + } + + def count_token(self, params): + prompt = params["prompt"] + + try: + input_ids = self.tokenizer(prompt).input_ids + input_echo_len = len(input_ids) + except TypeError: + input_echo_len = self.tokenizer.num_tokens(prompt) + + ret = { + "count": input_echo_len, + "error_code": 0, + } + return ret + + def get_conv_template(self): + return {"conv": self.conv} + + def generate_stream_gate(self, params): + raise NotImplementedError + + def generate_gate(self, params): + raise NotImplementedError + + def get_embeddings(self, params): + raise NotImplementedError + + +def release_worker_semaphore(): + worker.semaphore.release() + + +def acquire_worker_semaphore(): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(): + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + generator = worker.generate_stream_gate(params) + background_tasks = create_background_tasks() + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + output = await asyncio.to_thread(worker.generate_gate, params) + release_worker_semaphore() + return JSONResponse(output) + + +@app.post("/worker_get_embeddings") +async def api_get_embeddings(request: Request): + params = await request.json() + await acquire_worker_semaphore() + embedding = worker.get_embeddings(params) + release_worker_semaphore() + return JSONResponse(content=embedding) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return worker.get_status() + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + return {"context_length": worker.context_len} diff --git a/third_party/FastChat/fastchat/serve/cli.py b/third_party/FastChat/fastchat/serve/cli.py new file mode 100644 index 00000000..78f7f51b --- /dev/null +++ b/third_party/FastChat/fastchat/serve/cli.py @@ -0,0 +1,304 @@ +""" +Chat with a model with command line interface. + +Usage: +python3 -m fastchat.serve.cli --model lmsys/vicuna-7b-v1.5 +python3 -m fastchat.serve.cli --model lmsys/fastchat-t5-3b-v1.0 + +Other commands: +- Type "!!exit" or an empty line to exit. +- Type "!!reset" to start a new conversation. +- Type "!!remove" to remove the last prompt. +- Type "!!regen" to regenerate the last message. +- Type "!!save " to save the conversation history to a json file. +- Type "!!load " to load a conversation history from a json file. +""" +import argparse +import os +import re +import sys + +from prompt_toolkit import PromptSession +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory +from prompt_toolkit.completion import WordCompleter +from prompt_toolkit.history import InMemoryHistory +from prompt_toolkit.key_binding import KeyBindings +from rich.console import Console +from rich.live import Live +from rich.markdown import Markdown +import torch + +from fastchat.model.model_adapter import add_model_args +from fastchat.modules.awq import AWQConfig +from fastchat.modules.exllama import ExllamaConfig +from fastchat.modules.xfastertransformer import XftConfig +from fastchat.modules.gptq import GptqConfig +from fastchat.serve.inference import ChatIO, chat_loop +from fastchat.utils import str_to_torch_dtype + + +class SimpleChatIO(ChatIO): + def __init__(self, multiline: bool = False): + self._multiline = multiline + + def prompt_for_input(self, role) -> str: + if not self._multiline: + return input(f"{role}: ") + + prompt_data = [] + line = input(f"{role} [ctrl-d/z on empty line to end]: ") + while True: + prompt_data.append(line.strip()) + try: + line = input() + except EOFError as e: + break + return "\n".join(prompt_data) + + def prompt_for_output(self, role: str): + print(f"{role}: ", end="", flush=True) + + def stream_output(self, output_stream): + pre = 0 + for outputs in output_stream: + output_text = outputs["text"] + output_text = output_text.strip().split(" ") + now = len(output_text) - 1 + if now > pre: + print(" ".join(output_text[pre:now]), end=" ", flush=True) + pre = now + print(" ".join(output_text[pre:]), flush=True) + return " ".join(output_text) + + def print_output(self, text: str): + print(text) + + +class RichChatIO(ChatIO): + bindings = KeyBindings() + + @bindings.add("escape", "enter") + def _(event): + event.app.current_buffer.newline() + + def __init__(self, multiline: bool = False, mouse: bool = False): + self._prompt_session = PromptSession(history=InMemoryHistory()) + self._completer = WordCompleter( + words=["!!exit", "!!reset", "!!remove", "!!regen", "!!save", "!!load"], + pattern=re.compile("$"), + ) + self._console = Console() + self._multiline = multiline + self._mouse = mouse + + def prompt_for_input(self, role) -> str: + self._console.print(f"[bold]{role}:") + # TODO(suquark): multiline input has some issues. fix it later. + prompt_input = self._prompt_session.prompt( + completer=self._completer, + multiline=False, + mouse_support=self._mouse, + auto_suggest=AutoSuggestFromHistory(), + key_bindings=self.bindings if self._multiline else None, + ) + self._console.print() + return prompt_input + + def prompt_for_output(self, role: str): + self._console.print(f"[bold]{role.replace('/', '|')}:") + + def stream_output(self, output_stream): + """Stream output from a role.""" + # TODO(suquark): the console flickers when there is a code block + # above it. We need to cut off "live" when a code block is done. + + # Create a Live context for updating the console output + with Live(console=self._console, refresh_per_second=4) as live: + # Read lines from the stream + for outputs in output_stream: + if not outputs: + continue + text = outputs["text"] + # Render the accumulated text as Markdown + # NOTE: this is a workaround for the rendering "unstandard markdown" + # in rich. The chatbots output treat "\n" as a new line for + # better compatibility with real-world text. However, rendering + # in markdown would break the format. It is because standard markdown + # treat a single "\n" in normal text as a space. + # Our workaround is adding two spaces at the end of each line. + # This is not a perfect solution, as it would + # introduce trailing spaces (only) in code block, but it works well + # especially for console output, because in general the console does not + # care about trailing spaces. + lines = [] + for line in text.splitlines(): + lines.append(line) + if line.startswith("```"): + # Code block marker - do not add trailing spaces, as it would + # break the syntax highlighting + lines.append("\n") + else: + lines.append(" \n") + markdown = Markdown("".join(lines)) + # Update the Live console output + live.update(markdown) + self._console.print() + return text + + def print_output(self, text: str): + self.stream_output([{"text": text}]) + + +class ProgrammaticChatIO(ChatIO): + def prompt_for_input(self, role) -> str: + contents = "" + # `end_sequence` signals the end of a message. It is unlikely to occur in + # message content. + end_sequence = " __END_OF_A_MESSAGE_47582648__\n" + len_end = len(end_sequence) + while True: + if len(contents) >= len_end: + last_chars = contents[-len_end:] + if last_chars == end_sequence: + break + try: + char = sys.stdin.read(1) + contents = contents + char + except EOFError: + continue + contents = contents[:-len_end] + print(f"[!OP:{role}]: {contents}", flush=True) + return contents + + def prompt_for_output(self, role: str): + print(f"[!OP:{role}]: ", end="", flush=True) + + def stream_output(self, output_stream): + pre = 0 + for outputs in output_stream: + output_text = outputs["text"] + output_text = output_text.strip().split(" ") + now = len(output_text) - 1 + if now > pre: + print(" ".join(output_text[pre:now]), end=" ", flush=True) + pre = now + print(" ".join(output_text[pre:]), flush=True) + return " ".join(output_text) + + def print_output(self, text: str): + print(text) + + +def main(args): + if args.gpus: + if len(args.gpus.split(",")) < args.num_gpus: + raise ValueError( + f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" + ) + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + os.environ["XPU_VISIBLE_DEVICES"] = args.gpus + if args.enable_exllama: + exllama_config = ExllamaConfig( + max_seq_len=args.exllama_max_seq_len, + gpu_split=args.exllama_gpu_split, + cache_8bit=args.exllama_cache_8bit, + ) + else: + exllama_config = None + if args.enable_xft: + xft_config = XftConfig( + max_seq_len=args.xft_max_seq_len, + data_type=args.xft_dtype, + ) + if args.device != "cpu": + print("xFasterTransformer now is only support CPUs. Reset device to CPU") + args.device = "cpu" + else: + xft_config = None + if args.style == "simple": + chatio = SimpleChatIO(args.multiline) + elif args.style == "rich": + chatio = RichChatIO(args.multiline, args.mouse) + elif args.style == "programmatic": + chatio = ProgrammaticChatIO() + else: + raise ValueError(f"Invalid style for console: {args.style}") + try: + chat_loop( + args.model_path, + args.device, + args.num_gpus, + args.max_gpu_memory, + str_to_torch_dtype(args.dtype), + args.load_8bit, + args.cpu_offloading, + args.conv_template, + args.conv_system_msg, + args.temperature, + args.repetition_penalty, + args.max_new_tokens, + chatio, + gptq_config=GptqConfig( + ckpt=args.gptq_ckpt or args.model_path, + wbits=args.gptq_wbits, + groupsize=args.gptq_groupsize, + act_order=args.gptq_act_order, + ), + awq_config=AWQConfig( + ckpt=args.awq_ckpt or args.model_path, + wbits=args.awq_wbits, + groupsize=args.awq_groupsize, + ), + exllama_config=exllama_config, + xft_config=xft_config, + revision=args.revision, + judge_sent_end=args.judge_sent_end, + debug=args.debug, + history=not args.no_history, + ) + except KeyboardInterrupt: + print("exit...") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + add_model_args(parser) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument( + "--conv-system-msg", type=str, default=None, help="Conversation system message." + ) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--repetition_penalty", type=float, default=1.0) + parser.add_argument("--max-new-tokens", type=int, default=512) + parser.add_argument("--no-history", action="store_true") + parser.add_argument( + "--style", + type=str, + default="simple", + choices=["simple", "rich", "programmatic"], + help="Display style.", + ) + parser.add_argument( + "--multiline", + action="store_true", + help="Enable multiline input. Use ESC+Enter for newline.", + ) + parser.add_argument( + "--mouse", + action="store_true", + help="[Rich Style]: Enable mouse support for cursor positioning.", + ) + parser.add_argument( + "--judge-sent-end", + action="store_true", + help="Whether enable the correction logic that interrupts the output of sentences due to EOS.", + ) + parser.add_argument( + "--debug", + action="store_true", + help="Print useful debug information (e.g., prompts)", + ) + args = parser.parse_args() + main(args) diff --git a/third_party/FastChat/fastchat/serve/controller.py b/third_party/FastChat/fastchat/serve/controller.py new file mode 100644 index 00000000..a67da62c --- /dev/null +++ b/third_party/FastChat/fastchat/serve/controller.py @@ -0,0 +1,348 @@ +""" +A controller manages distributed workers. +It sends worker addresses to clients. +""" +import argparse +import asyncio +import dataclasses +from enum import Enum, auto +import json +import logging +import os +import time +from typing import List, Union +import threading + +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse +import numpy as np +import requests +import uvicorn + +from fastchat.constants import ( + CONTROLLER_HEART_BEAT_EXPIRATION, + WORKER_API_TIMEOUT, + ErrorCode, + SERVER_ERROR_MSG, +) +from fastchat.utils import build_logger + + +logger = build_logger("controller", "controller.log") + + +class DispatchMethod(Enum): + LOTTERY = auto() + SHORTEST_QUEUE = auto() + + @classmethod + def from_str(cls, name): + if name == "lottery": + return cls.LOTTERY + elif name == "shortest_queue": + return cls.SHORTEST_QUEUE + else: + raise ValueError(f"Invalid dispatch method") + + +@dataclasses.dataclass +class WorkerInfo: + model_names: List[str] + speed: int + queue_length: int + check_heart_beat: bool + last_heart_beat: str + + +def heart_beat_controller(controller): + while True: + time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION) + controller.remove_stale_workers_by_expiration() + + +class Controller: + def __init__(self, dispatch_method: str): + # Dict[str -> WorkerInfo] + self.worker_info = {} + self.dispatch_method = DispatchMethod.from_str(dispatch_method) + + self.heart_beat_thread = threading.Thread( + target=heart_beat_controller, args=(self,) + ) + self.heart_beat_thread.start() + + def register_worker( + self, worker_name: str, check_heart_beat: bool, worker_status: dict + ): + if worker_name not in self.worker_info: + logger.info(f"Register a new worker: {worker_name}") + else: + logger.info(f"Register an existing worker: {worker_name}") + + if not worker_status: + worker_status = self.get_worker_status(worker_name) + if not worker_status: + return False + + self.worker_info[worker_name] = WorkerInfo( + worker_status["model_names"], + worker_status["speed"], + worker_status["queue_length"], + check_heart_beat, + time.time(), + ) + + logger.info(f"Register done: {worker_name}, {worker_status}") + return True + + def get_worker_status(self, worker_name: str): + try: + r = requests.post(worker_name + "/worker_get_status", timeout=5) + except requests.exceptions.RequestException as e: + logger.error(f"Get status fails: {worker_name}, {e}") + return None + + if r.status_code != 200: + logger.error(f"Get status fails: {worker_name}, {r}") + return None + + return r.json() + + def remove_worker(self, worker_name: str): + del self.worker_info[worker_name] + + def refresh_all_workers(self): + old_info = dict(self.worker_info) + self.worker_info = {} + + for w_name, w_info in old_info.items(): + if not self.register_worker(w_name, w_info.check_heart_beat, None): + logger.info(f"Remove stale worker: {w_name}") + + def list_models(self): + model_names = set() + + for w_name, w_info in self.worker_info.items(): + model_names.update(w_info.model_names) + + return list(model_names) + + def get_worker_address(self, model_name: str): + if self.dispatch_method == DispatchMethod.LOTTERY: + worker_names = [] + worker_speeds = [] + for w_name, w_info in self.worker_info.items(): + if model_name in w_info.model_names: + worker_names.append(w_name) + worker_speeds.append(w_info.speed) + worker_speeds = np.array(worker_speeds, dtype=np.float32) + norm = np.sum(worker_speeds) + if norm < 1e-4: + return "" + worker_speeds = worker_speeds / norm + if True: # Directly return address + pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds) + worker_name = worker_names[pt] + return worker_name + + # Check status before returning + while True: + pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds) + worker_name = worker_names[pt] + + if self.get_worker_status(worker_name): + break + else: + self.remove_worker(worker_name) + worker_speeds[pt] = 0 + norm = np.sum(worker_speeds) + if norm < 1e-4: + return "" + worker_speeds = worker_speeds / norm + continue + return worker_name + elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE: + worker_names = [] + worker_qlen = [] + for w_name, w_info in self.worker_info.items(): + if model_name in w_info.model_names: + worker_names.append(w_name) + worker_qlen.append(w_info.queue_length / w_info.speed) + if len(worker_names) == 0: + return "" + min_index = np.argmin(worker_qlen) + w_name = worker_names[min_index] + self.worker_info[w_name].queue_length += 1 + logger.info( + f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}" + ) + return w_name + else: + raise ValueError(f"Invalid dispatch method: {self.dispatch_method}") + + def receive_heart_beat(self, worker_name: str, queue_length: int): + if worker_name not in self.worker_info: + logger.info(f"Receive unknown heart beat. {worker_name}") + return False + + self.worker_info[worker_name].queue_length = queue_length + self.worker_info[worker_name].last_heart_beat = time.time() + logger.info(f"Receive heart beat. {worker_name}") + return True + + def remove_stale_workers_by_expiration(self): + expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION + to_delete = [] + for worker_name, w_info in self.worker_info.items(): + if w_info.check_heart_beat and w_info.last_heart_beat < expire: + to_delete.append(worker_name) + + for worker_name in to_delete: + self.remove_worker(worker_name) + + def handle_no_worker(self, params): + logger.info(f"no worker: {params['model']}") + ret = { + "text": SERVER_ERROR_MSG, + "error_code": ErrorCode.CONTROLLER_NO_WORKER, + } + return json.dumps(ret).encode() + b"\0" + + def handle_worker_timeout(self, worker_address): + logger.info(f"worker timeout: {worker_address}") + ret = { + "text": SERVER_ERROR_MSG, + "error_code": ErrorCode.CONTROLLER_WORKER_TIMEOUT, + } + return json.dumps(ret).encode() + b"\0" + + # Let the controller act as a worker to achieve hierarchical + # management. This can be used to connect isolated sub networks. + def worker_api_get_status(self): + model_names = set() + speed = 0 + queue_length = 0 + + for w_name in self.worker_info: + worker_status = self.get_worker_status(w_name) + if worker_status is not None: + model_names.update(worker_status["model_names"]) + speed += worker_status["speed"] + queue_length += worker_status["queue_length"] + + model_names = sorted(list(model_names)) + return { + "model_names": model_names, + "speed": speed, + "queue_length": queue_length, + } + + def worker_api_generate_stream(self, params): + worker_addr = self.get_worker_address(params["model"]) + if not worker_addr: + yield self.handle_no_worker(params) + + try: + response = requests.post( + worker_addr + "/worker_generate_stream", + json=params, + stream=True, + timeout=WORKER_API_TIMEOUT, + ) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + yield chunk + b"\0" + except requests.exceptions.RequestException as e: + yield self.handle_worker_timeout(worker_addr) + + +app = FastAPI() + + +@app.post("/register_worker") +async def register_worker(request: Request): + data = await request.json() + controller.register_worker( + data["worker_name"], data["check_heart_beat"], data.get("worker_status", None) + ) + + +@app.post("/refresh_all_workers") +async def refresh_all_workers(): + models = controller.refresh_all_workers() + + +@app.post("/list_models") +async def list_models(): + models = controller.list_models() + return {"models": models} + + +@app.post("/get_worker_address") +async def get_worker_address(request: Request): + data = await request.json() + addr = controller.get_worker_address(data["model"]) + return {"address": addr} + + +@app.post("/receive_heart_beat") +async def receive_heart_beat(request: Request): + data = await request.json() + exist = controller.receive_heart_beat(data["worker_name"], data["queue_length"]) + return {"exist": exist} + + +@app.post("/worker_generate_stream") +async def worker_api_generate_stream(request: Request): + params = await request.json() + generator = controller.worker_api_generate_stream(params) + return StreamingResponse(generator) + + +@app.post("/worker_get_status") +async def worker_api_get_status(request: Request): + return controller.worker_api_get_status() + + +@app.get("/test_connection") +async def worker_api_get_status(request: Request): + return "success" + + +def create_controller(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21001) + parser.add_argument( + "--dispatch-method", + type=str, + choices=["lottery", "shortest_queue"], + default="shortest_queue", + ) + parser.add_argument( + "--ssl", + action="store_true", + required=False, + default=False, + help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.", + ) + args = parser.parse_args() + logger.info(f"args: {args}") + + controller = Controller(args.dispatch_method) + return args, controller + + +if __name__ == "__main__": + args, controller = create_controller() + if args.ssl: + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + ssl_keyfile=os.environ["SSL_KEYFILE"], + ssl_certfile=os.environ["SSL_CERTFILE"], + ) + else: + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/third_party/FastChat/fastchat/serve/gateway/README.md b/third_party/FastChat/fastchat/serve/gateway/README.md new file mode 100644 index 00000000..b3afaf17 --- /dev/null +++ b/third_party/FastChat/fastchat/serve/gateway/README.md @@ -0,0 +1,57 @@ +# fastchat Nginx Gateway + +## Purpose of the Gateway + +The Nginx gateway serves the following purposes: + +1. Protects Gradio servers by acting as a firewall. +2. Facilitates dynamic mounting and unmounting of Gradio servers. +3. Provides load balancing for Gradio servers. +4. Offers additional security features, such as total connection limit. +5. Reduces attack surface by requiring only a single public port to be exposed for serving. + +## Deployment and Updating of the Gateway + +### Installing Nginx + +On Debian-based distributions (e.g., Ubuntu): + +```bash +sudo apt update +sudo apt install nginx +``` +On Red Hat-based distributions (e.g., CentOS, Fedora): + +```bash +sudo yum install epel-release +sudo yum install nginx +``` + +### Deployment + +Copy `nginx.conf` to `/etc/nginx/nginx.conf` (need sudo permission). + +Replace the port number 7860 in `server localhost:7860` with the port where you deploy the Gradio web server. + +Modify `upstream websocket` to configure Gradio servers behind the gateway. + +Lastly, update Nginx. + + +### HTTPS Deployment with a Public Domain URL + +Make sure you obtain the HTTPS certificate and the private key used to generate the certificate. + +Fill the addresses to your certificate and private key in the `[PATH_TO_SSL_CERT]` and `[PATH_TO_PRIVATE_KEY]` fields. + +If you have your own domain url to serve the chatbot, replace the chat.lmsys.org url with your own domain url. + +### Updating + +Every time when `/etc/nginx/nginx.conf` is modified, you need to update the Nginx service: + +```bash +sudo nginx -t # check `/etc/nginx/nginx.conf` +sudo systemctl reload nginx # restart Nginx service to load the new config +sudo systemctl status nginx # check the status of the Nginx service. It should be active (running). +``` diff --git a/third_party/FastChat/fastchat/serve/gateway/nginx.conf b/third_party/FastChat/fastchat/serve/gateway/nginx.conf new file mode 100644 index 00000000..b88ca8c5 --- /dev/null +++ b/third_party/FastChat/fastchat/serve/gateway/nginx.conf @@ -0,0 +1,97 @@ +user www-data; +worker_processes auto; +pid /run/nginx.pid; +include /etc/nginx/modules-enabled/*.conf; + +events { + worker_connections 1024; # maximum number of connections that a worker process can handle concurrently + # multi_accept on; # enabling multi_accept can help improve performance under high load, but may increase the number of simultaneous connections that a worker process can handle + +} + +http { + ## + # Basic Settings + ## + + sendfile on; # enable sendfile for performance optimization + tcp_nopush on; # enable TCP no-pushing + tcp_nodelay on; # enable TCP no-delay + keepalive_timeout 65; # sets the timeout for keep-alive connections + types_hash_max_size 2048; # maximum size of the types hash table + # server_tokens off; # disable server token (i.e., server signature) in response headers to improve security + + # server_names_hash_bucket_size 64; + # server_name_in_redirect off; + + include /etc/nginx/mime.types; # include MIME types file + default_type application/octet-stream; # default MIME type for unknown file types + + ## + # SSL Settings + ## + + ssl_protocols TLSv1.2; # specify SSL/TLS protocols to use + ssl_prefer_server_ciphers on; # prefer server ciphers over client ciphers + + ## + # Logging Settings + ## + + access_log /var/log/nginx/access.log; # path to access log file + error_log /var/log/nginx/error.log; # path to error log file + + ## + # Gzip Settings + ## + gzip on; # enable Gzip compression + + ## + # Virtual Host Configs + ## + + include /etc/nginx/conf.d/*.conf; # include all configuration files in conf.d directory + include /etc/nginx/sites-enabled/*; # include all enabled sites configuration files + + # WebSocket Proxy: https://www.nginx.com/blog/websocket-nginx/ + map $http_upgrade $connection_upgrade { + default upgrade; + '' close; + } + + upstream websocket { + ip_hash; # load balancing by IP to guarantee session persistence + server localhost:7860; # The port should be the gradio web server port + # server localhost:7861; # extra gradio server if more than one + } + + limit_conn_status 429; + limit_conn_zone $binary_remote_addr zone=perip:10m; # limit number of connections per IP + limit_conn_zone $server_name zone=perserver:10m; # limit number of connections per server + + server { + listen 443 ssl; # the listening port of our server + ssl_certificate [PATH_TO_SSL_CERT]; + ssl_certificate_key [PATH_TO_PRIVATE_KEY]; + server_name chat.lmsys.org; # replace the url with your own domain url + limit_conn perserver 1024; # connections per server + location / { + proxy_pass http://websocket; # proxy all requests to the defined upstream server + limit_conn perip 5; # connections per IP + proxy_set_header Host $host; # set the Host header for the upstream server + proxy_set_header X-Real-IP $remote_addr; # set the client IP address as the real IP for the upstream server + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; # set the client IP addresses in the X-Forwarded-For header + proxy_http_version 1.1; # use HTTP version 1.1 for upstream communication + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "Upgrade"; # set the Connection header to Upgrade to enable WebSocket communication + } + } + + # the following block routes all HTTP traffic to HTTPS via nginx + server { + listen 80; + server_name chat.lmsys.org; + return 301 https://chat.lmsys.org$request_uri; + } + +} diff --git a/third_party/FastChat/fastchat/serve/gradio_block_arena_anony.py b/third_party/FastChat/fastchat/serve/gradio_block_arena_anony.py new file mode 100644 index 00000000..6797290b --- /dev/null +++ b/third_party/FastChat/fastchat/serve/gradio_block_arena_anony.py @@ -0,0 +1,677 @@ +""" +Chatbot Arena (battle) tab. +Users chat with two anonymous models. +""" + +import json +import time + +import gradio as gr +import numpy as np + +from fastchat.constants import ( + MODERATION_MSG, + CONVERSATION_LIMIT_MSG, + SLOW_MODEL_MSG, + INPUT_CHAR_LEN_LIMIT, + CONVERSATION_TURN_LIMIT, +) +from fastchat.model.model_adapter import get_conversation_template +from fastchat.serve.gradio_block_arena_named import flash_buttons +from fastchat.serve.gradio_web_server import ( + State, + bot_response, + get_conv_log_filename, + no_change_btn, + enable_btn, + disable_btn, + invisible_btn, + acknowledgment_md, + ip_expiration_dict, + get_ip, + get_model_description_md, +) +from fastchat.utils import ( + build_logger, + moderation_filter, +) + +logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") + +num_sides = 2 +enable_moderation = False +anony_names = ["", ""] +models = [] + + +def set_global_vars_anony(enable_moderation_): + global enable_moderation + enable_moderation = enable_moderation_ + + +def load_demo_side_by_side_anony(models_, url_params): + global models + models = models_ + + states = (None,) * num_sides + selector_updates = ( + gr.Markdown.update(visible=True), + gr.Markdown.update(visible=True), + ) + + return states + selector_updates + + +def vote_last_response(states, vote_type, model_selectors, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "models": [x for x in model_selectors], + "states": [x.dict() for x in states], + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + + if ":" not in model_selectors[0]: + for i in range(15): + names = ( + "### Model A: " + states[0].model_name, + "### Model B: " + states[1].model_name, + ) + yield names + ("",) + (disable_btn,) * 4 + time.sleep(0.2) + else: + names = ( + "### Model A: " + states[0].model_name, + "### Model B: " + states[1].model_name, + ) + yield names + ("",) + (disable_btn,) * 4 + + +def leftvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"leftvote (anony). ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "leftvote", [model_selector0, model_selector1], request + ): + yield x + + +def rightvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"rightvote (anony). ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "rightvote", [model_selector0, model_selector1], request + ): + yield x + + +def tievote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"tievote (anony). ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "tievote", [model_selector0, model_selector1], request + ): + yield x + + +def bothbad_vote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}") + for x in vote_last_response( + [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request + ): + yield x + + +def regenerate(state0, state1, request: gr.Request): + logger.info(f"regenerate (anony). ip: {get_ip(request)}") + states = [state0, state1] + for i in range(num_sides): + states[i].conv.update_last_message(None) + return states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 6 + + +def clear_history(request: gr.Request): + logger.info(f"clear_history (anony). ip: {get_ip(request)}") + return ( + [None] * num_sides + + [None] * num_sides + + anony_names + + [""] + + [invisible_btn] * 4 + + [disable_btn] * 2 + + [""] + ) + + +def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request): + logger.info(f"share (anony). ip: {get_ip(request)}") + if state0 is not None and state1 is not None: + vote_last_response( + [state0, state1], "share", [model_selector0, model_selector1], request + ) + + +SAMPLING_WEIGHTS = { + # tier 0 + "gpt-4": 4, + "gpt-4-0314": 4, + "gpt-4-turbo": 4, + "gpt-3.5-turbo-0613": 2, + "gpt-3.5-turbo-1106": 2, + "claude-2.1": 4, + "claude-2.0": 2, + "claude-1": 2, + "claude-instant-1": 4, + "gemini-pro": 4, + "pplx-7b-online": 4, + "pplx-70b-online": 4, + "solar-10.7b-instruct-v1.0": 2, + "mixtral-8x7b-instruct-v0.1": 4, + "openhermes-2.5-mistral-7b": 2, + "dolphin-2.2.1-mistral-7b": 2, + "wizardlm-70b": 2, + "starling-lm-7b-alpha": 2, + "tulu-2-dpo-70b": 2, + "yi-34b-chat": 2, + "zephyr-7b-beta": 2, + "openchat-3.5": 2, + "chatglm3-6b": 2, + # tier 1 + "deluxe-chat-v1.2": 2, + "llama-2-70b-chat": 1.5, + "llama-2-13b-chat": 1.5, + "codellama-34b-instruct": 1.5, + "vicuna-33b": 4, + "vicuna-13b": 1.5, + "wizardlm-13b": 1.5, + "qwen-14b-chat": 1.5, + "mistral-7b-instruct": 1.5, + # tier 2 + "vicuna-7b": 1.0, + "llama-2-7b-chat": 1.0, + "chatglm2-6b": 1.0, + # deprecated + "zephyr-7b-alpha": 1.5, + "codellama-13b-instruct": 1.0, + "mpt-30b-chat": 1.5, + "guanaco-33b": 1.0, + "fastchat-t5-3b": 0.5, + "alpaca-13b": 0.5, + "mpt-7b-chat": 0.1, + "oasst-pythia-12b": 0.1, + "RWKV-4-Raven-14B": 0.1, + "gpt4all-13b-snoozy": 0.1, + "koala-13b": 0.1, + "stablelm-tuned-alpha-7b": 0.1, + "dolly-v2-12b": 0.1, + "llama-13b": 0.1, + "chatglm-6b": 0.5, + "deluxe-chat-v1": 4, + "palm-2": 1.5, +} + +# target model sampling weights will be boosted. +BATTLE_TARGETS = { + "gpt-4": {"gpt-4-0314", "claude-2.1", "gpt-4-turbo"}, + "gpt-4-0613": {"gpt-4-0314", "claude-2.1", "gpt-4-turbo"}, + "gpt-4-0314": {"gpt-4-turbo", "gpt-4-0613", "claude-2.1", "gpt-3.5-turbo-0613"}, + "gpt-4-turbo": { + "gpt-4-0613", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "claude-2.1", + }, + "gpt-3.5-turbo-0613": {"claude-instant-1", "gpt-4-0613", "claude-2.1"}, + "gpt-3.5-turbo-1106": {"gpt-4-0613", "claude-instant-1", "gpt-3.5-turbo-0613"}, + "solar-10.7b-instruct-v1.0": { + "mixtral-8x7b-instruct-v0.1", + "gpt-3.5-turbo-0613", + "llama-2-70b-chat", + }, + "mixtral-8x7b-instruct-v0.1": { + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0613", + "gpt-4-turbo", + "llama-2-70b-chat", + }, + "claude-2.1": {"gpt-4-turbo", "gpt-4-0613", "claude-1"}, + "claude-2.0": {"gpt-4-turbo", "gpt-4-0613", "claude-1"}, + "claude-1": {"claude-2.1", "gpt-4-0613", "gpt-3.5-turbo-0613"}, + "claude-instant-1": {"gpt-3.5-turbo-1106", "claude-2.1"}, + "gemini-pro": {"gpt-4-turbo", "gpt-4-0613", "gpt-3.5-turbo-0613"}, + "deluxe-chat-v1.1": {"gpt-4-0613", "gpt-4-turbo"}, + "deluxe-chat-v1.2": {"gpt-4-0613", "gpt-4-turbo"}, + "pplx-7b-online": {"gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "llama-2-70b-chat"}, + "pplx-70b-online": {"gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "llama-2-70b-chat"}, + "openhermes-2.5-mistral-7b": { + "gpt-3.5-turbo-0613", + "openchat-3.5", + "zephyr-7b-beta", + }, + "dolphin-2.2.1-mistral-7b": { + "gpt-3.5-turbo-0613", + "vicuna-33b", + "starling-lm-7b-alpha", + "openhermes-2.5-mistral-7b", + }, + "starling-lm-7b-alpha": {"gpt-3.5-turbo-0613", "openchat-3.5", "zephyr-7b-beta"}, + "tulu-2-dpo-70b": {"gpt-3.5-turbo-0613", "vicuna-33b", "claude-instant-1"}, + "yi-34b-chat": {"gpt-3.5-turbo-0613", "vicuna-33b", "claude-instant-1"}, + "openchat-3.5": {"gpt-3.5-turbo-0613", "llama-2-70b-chat", "zephyr-7b-beta"}, + "chatglm3-6b": {"yi-34b-chat", "qwen-14b-chat"}, + "qwen-14b-chat": {"vicuna-13b", "llama-2-13b-chat", "llama-2-70b-chat"}, + "zephyr-7b-alpha": {"mistral-7b-instruct", "llama-2-13b-chat"}, + "zephyr-7b-beta": { + "mistral-7b-instruct", + "llama-2-13b-chat", + "llama-2-7b-chat", + "wizardlm-13b", + }, + "llama-2-70b-chat": {"gpt-3.5-turbo-0613", "vicuna-33b", "claude-instant-1"}, + "llama-2-13b-chat": {"mistral-7b-instruct", "vicuna-13b", "llama-2-70b-chat"}, + "llama-2-7b-chat": {"mistral-7b-instruct", "vicuna-7b", "llama-2-13b-chat"}, + "mistral-7b-instruct": { + "llama-2-7b-chat", + "llama-2-13b-chat", + "llama-2-70b-chat", + }, + "vicuna-33b": {"llama-2-70b-chat", "gpt-3.5-turbo-0613", "claude-instant-1"}, + "vicuna-13b": {"llama-2-13b-chat", "llama-2-70b-chat"}, + "vicuna-7b": {"llama-2-7b-chat", "mistral-7b-instruct", "llama-2-13b-chat"}, + "wizardlm-70b": {"gpt-3.5-turbo-0613", "vicuna-33b", "claude-instant-1"}, +} + +SAMPLING_BOOST_MODELS = [ + # "tulu-2-dpo-70b", + # "yi-34b-chat", + "claude-2.1", + "claude-1", + "gpt-4-0613", + # "gpt-3.5-turbo-1106", + # "gpt-4-0314", + "gpt-4-turbo", + # "dolphin-2.2.1-mistral-7b", + "mixtral-8x7b-instruct-v0.1", + "gemini-pro", + "solar-10.7b-instruct-v1.0", +] + +# outage models won't be sampled. +OUTAGE_MODELS = [] + + +def get_sample_weight(model): + if model in OUTAGE_MODELS: + return 0 + weight = SAMPLING_WEIGHTS.get(model, 1.0) + if model in SAMPLING_BOOST_MODELS: + weight *= 5 + return weight + + +def get_battle_pair(): + if len(models) == 1: + return models[0], models[0] + + model_weights = [] + for model in models: + weight = get_sample_weight(model) + model_weights.append(weight) + total_weight = np.sum(model_weights) + model_weights = model_weights / total_weight + chosen_idx = np.random.choice(len(models), p=model_weights) + chosen_model = models[chosen_idx] + # for p, w in zip(models, model_weights): + # print(p, w) + + rival_models = [] + rival_weights = [] + for model in models: + if model == chosen_model: + continue + weight = get_sample_weight(model) + if ( + weight != 0 + and chosen_model in BATTLE_TARGETS + and model in BATTLE_TARGETS[chosen_model] + ): + # boost to 50% chance + weight = total_weight / len(BATTLE_TARGETS[chosen_model]) + rival_models.append(model) + rival_weights.append(weight) + # for p, w in zip(rival_models, rival_weights): + # print(p, w) + rival_weights = rival_weights / np.sum(rival_weights) + rival_idx = np.random.choice(len(rival_models), p=rival_weights) + rival_model = rival_models[rival_idx] + + swap = np.random.randint(2) + if swap == 0: + return chosen_model, rival_model + else: + return rival_model, chosen_model + + +def add_text( + state0, state1, model_selector0, model_selector1, text, request: gr.Request +): + ip = get_ip(request) + logger.info(f"add_text (anony). ip: {ip}. len: {len(text)}") + states = [state0, state1] + model_selectors = [model_selector0, model_selector1] + + # Init states if necessary + if states[0] is None: + assert states[1] is None + + model_left, model_right = get_battle_pair() + states = [ + State(model_left), + State(model_right), + ] + + if len(text) <= 0: + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [""] + + [ + no_change_btn, + ] + * 6 + + [""] + ) + + model_list = [states[i].model_name for i in range(num_sides)] + flagged = moderation_filter(text, model_list) + if flagged: + logger.info(f"violate moderation (anony). ip: {ip}. text: {text}") + # overwrite the original text + text = MODERATION_MSG + + conv = states[0].conv + if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: + logger.info(f"conversation turn limit. ip: {get_ip(request)}. text: {text}") + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [CONVERSATION_LIMIT_MSG] + + [ + no_change_btn, + ] + * 6 + + [""] + ) + + text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off + for i in range(num_sides): + states[i].conv.append_message(states[i].conv.roles[0], text) + states[i].conv.append_message(states[i].conv.roles[1], None) + states[i].skip_next = False + + slow_model_msg = "" + for i in range(num_sides): + if "deluxe" in states[i].model_name: + slow_model_msg = SLOW_MODEL_MSG + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [""] + + [ + disable_btn, + ] + * 6 + + [slow_model_msg] + ) + + +def bot_response_multi( + state0, + state1, + temperature, + top_p, + max_new_tokens, + request: gr.Request, +): + logger.info(f"bot_response_multi (anony). ip: {get_ip(request)}") + + if state0 is None or state0.skip_next: + # This generate call is skipped due to invalid inputs + yield ( + state0, + state1, + state0.to_gradio_chatbot(), + state1.to_gradio_chatbot(), + ) + (no_change_btn,) * 6 + return + + states = [state0, state1] + gen = [] + for i in range(num_sides): + gen.append( + bot_response( + states[i], + temperature, + top_p, + max_new_tokens, + request, + apply_rate_limit=False, + ) + ) + + chatbots = [None] * num_sides + while True: + stop = True + for i in range(num_sides): + try: + ret = next(gen[i]) + states[i], chatbots[i] = ret[0], ret[1] + stop = False + except StopIteration: + pass + yield states + chatbots + [disable_btn] * 6 + if stop: + break + + +def build_side_by_side_ui_anony(models): + notice_markdown = """ +# ⚔️ Chatbot Arena ⚔️ : Benchmarking LLMs in the Wild +| [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | + +## 📜 Rules +- Ask any question to two anonymous models (e.g., ChatGPT, Claude, Llama) and vote for the better one! +- You can continue chatting until you identify a winner. +- Vote won't be counted if model identity is revealed during conversation. + +## 🏆 Arena Elo [Leaderboard](https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard) +We use **100K+** human votes to compile an Elo-based LLM leaderboard. +Find out who is the 🥇LLM Champion! + +## 👇 Chat now! + +""" + + states = [gr.State() for _ in range(num_sides)] + model_selectors = [None] * num_sides + chatbots = [None] * num_sides + + gr.Markdown(notice_markdown, elem_id="notice_markdown") + + with gr.Box(elem_id="share-region-anony"): + with gr.Accordion("🔍 Expand to see 20+ Arena players", open=False): + model_description_md = get_model_description_md(models) + gr.Markdown(model_description_md, elem_id="model_description_markdown") + with gr.Row(): + for i in range(num_sides): + label = "Model A" if i == 0 else "Model B" + with gr.Column(): + chatbots[i] = gr.Chatbot( + label=label, elem_id=f"chatbot", height=550 + ) + + with gr.Row(): + for i in range(num_sides): + with gr.Column(): + model_selectors[i] = gr.Markdown(anony_names[i]) + with gr.Row(): + slow_warning = gr.Markdown("", elem_id="notice_markdown") + + with gr.Row(): + leftvote_btn = gr.Button( + value="👈 A is better", visible=False, interactive=False + ) + rightvote_btn = gr.Button( + value="👉 B is better", visible=False, interactive=False + ) + tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False) + bothbad_btn = gr.Button( + value="👎 Both are bad", visible=False, interactive=False + ) + + with gr.Row(): + textbox = gr.Textbox( + show_label=False, + placeholder="👉 Enter your prompt and press ENTER", + container=False, + elem_id="input_box", + ) + send_btn = gr.Button(value="Send", variant="primary", scale=0) + + with gr.Row() as button_row: + clear_btn = gr.Button(value="🎲 New Round", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + share_btn = gr.Button(value="📷 Share") + + with gr.Accordion("Parameters", open=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=1.0, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=16, + maximum=2048, + value=1024, + step=64, + interactive=True, + label="Max output tokens", + ) + + gr.Markdown(acknowledgment_md, elem_id="ack_markdown") + + # Register listeners + btn_list = [ + leftvote_btn, + rightvote_btn, + tie_btn, + bothbad_btn, + regenerate_btn, + clear_btn, + ] + leftvote_btn.click( + leftvote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + rightvote_btn.click( + rightvote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + tie_btn.click( + tievote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + bothbad_btn.click( + bothbad_vote_last_response, + states + model_selectors, + model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + regenerate_btn.click( + regenerate, states, states + chatbots + [textbox] + btn_list + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + clear_btn.click( + clear_history, + None, + states + chatbots + model_selectors + [textbox] + btn_list + [slow_warning], + ) + + share_js = """ +function (a, b, c, d) { + const captureElement = document.querySelector('#share-region-anony'); + html2canvas(captureElement) + .then(canvas => { + canvas.style.display = 'none' + document.body.appendChild(canvas) + return canvas + }) + .then(canvas => { + const image = canvas.toDataURL('image/png') + const a = document.createElement('a') + a.setAttribute('download', 'chatbot-arena.png') + a.setAttribute('href', image) + a.click() + canvas.remove() + }); + return [a, b, c, d]; +} +""" + share_btn.click(share_click, states + model_selectors, [], _js=share_js) + + textbox.submit( + add_text, + states + model_selectors + [textbox], + states + chatbots + [textbox] + btn_list + [slow_warning], + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, + [], + btn_list, + ) + + send_btn.click( + add_text, + states + model_selectors + [textbox], + states + chatbots + [textbox] + btn_list, + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + + return states + model_selectors diff --git a/third_party/FastChat/fastchat/serve/gradio_block_arena_named.py b/third_party/FastChat/fastchat/serve/gradio_block_arena_named.py new file mode 100644 index 00000000..60823f06 --- /dev/null +++ b/third_party/FastChat/fastchat/serve/gradio_block_arena_named.py @@ -0,0 +1,457 @@ +""" +Chatbot Arena (side-by-side) tab. +Users chat with two chosen models. +""" + +import json +import time + +import gradio as gr +import numpy as np + +from fastchat.constants import ( + MODERATION_MSG, + CONVERSATION_LIMIT_MSG, + INPUT_CHAR_LEN_LIMIT, + CONVERSATION_TURN_LIMIT, +) +from fastchat.model.model_adapter import get_conversation_template +from fastchat.serve.gradio_web_server import ( + State, + bot_response, + get_conv_log_filename, + no_change_btn, + enable_btn, + disable_btn, + invisible_btn, + acknowledgment_md, + get_model_description_md, + ip_expiration_dict, + get_ip, +) +from fastchat.utils import ( + build_logger, + moderation_filter, +) + + +logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") + +num_sides = 2 +enable_moderation = False + + +def set_global_vars_named(enable_moderation_): + global enable_moderation + enable_moderation = enable_moderation_ + + +def load_demo_side_by_side_named(models, url_params): + states = (None,) * num_sides + + model_left = models[0] if len(models) > 0 else "" + if len(models) > 1: + weights = ([8] * 4 + [4] * 8 + [1] * 32)[: len(models) - 1] + weights = weights / np.sum(weights) + model_right = np.random.choice(models[1:], p=weights) + else: + model_right = model_left + + selector_updates = ( + gr.Dropdown.update(choices=models, value=model_left, visible=True), + gr.Dropdown.update(choices=models, value=model_right, visible=True), + ) + + return states + selector_updates + + +def vote_last_response(states, vote_type, model_selectors, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "models": [x for x in model_selectors], + "states": [x.dict() for x in states], + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + + +def leftvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"leftvote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "leftvote", [model_selector0, model_selector1], request + ) + return ("",) + (disable_btn,) * 4 + + +def rightvote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"rightvote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "rightvote", [model_selector0, model_selector1], request + ) + return ("",) + (disable_btn,) * 4 + + +def tievote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"tievote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "tievote", [model_selector0, model_selector1], request + ) + return ("",) + (disable_btn,) * 4 + + +def bothbad_vote_last_response( + state0, state1, model_selector0, model_selector1, request: gr.Request +): + logger.info(f"bothbad_vote (named). ip: {get_ip(request)}") + vote_last_response( + [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request + ) + return ("",) + (disable_btn,) * 4 + + +def regenerate(state0, state1, request: gr.Request): + logger.info(f"regenerate (named). ip: {get_ip(request)}") + states = [state0, state1] + for i in range(num_sides): + states[i].conv.update_last_message(None) + return states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 6 + + +def clear_history(request: gr.Request): + logger.info(f"clear_history (named). ip: {get_ip(request)}") + return ( + [None] * num_sides + + [None] * num_sides + + [""] + + [invisible_btn] * 4 + + [disable_btn] * 2 + ) + + +def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request): + logger.info(f"share (named). ip: {get_ip(request)}") + if state0 is not None and state1 is not None: + vote_last_response( + [state0, state1], "share", [model_selector0, model_selector1], request + ) + + +def add_text( + state0, state1, model_selector0, model_selector1, text, request: gr.Request +): + ip = get_ip(request) + logger.info(f"add_text (named). ip: {ip}. len: {len(text)}") + states = [state0, state1] + model_selectors = [model_selector0, model_selector1] + + # Init states if necessary + for i in range(num_sides): + if states[i] is None: + states[i] = State(model_selectors[i]) + + if len(text) <= 0: + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [""] + + [ + no_change_btn, + ] + * 6 + ) + + model_list = [states[i].model_name for i in range(num_sides)] + flagged = moderation_filter(text, model_list) + if flagged: + logger.info(f"violate moderation (named). ip: {ip}. text: {text}") + # overwrite the original text + text = MODERATION_MSG + + conv = states[0].conv + if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: + logger.info(f"conversation turn limit. ip: {ip}. text: {text}") + for i in range(num_sides): + states[i].skip_next = True + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [CONVERSATION_LIMIT_MSG] + + [ + no_change_btn, + ] + * 6 + ) + + text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off + for i in range(num_sides): + states[i].conv.append_message(states[i].conv.roles[0], text) + states[i].conv.append_message(states[i].conv.roles[1], None) + states[i].skip_next = False + + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [""] + + [ + disable_btn, + ] + * 6 + ) + + +def bot_response_multi( + state0, + state1, + temperature, + top_p, + max_new_tokens, + request: gr.Request, +): + logger.info(f"bot_response_multi (named). ip: {get_ip(request)}") + + if state0.skip_next: + # This generate call is skipped due to invalid inputs + yield ( + state0, + state1, + state0.to_gradio_chatbot(), + state1.to_gradio_chatbot(), + ) + (no_change_btn,) * 6 + return + + states = [state0, state1] + gen = [] + for i in range(num_sides): + gen.append( + bot_response( + states[i], + temperature, + top_p, + max_new_tokens, + request, + ) + ) + + chatbots = [None] * num_sides + while True: + stop = True + for i in range(num_sides): + try: + ret = next(gen[i]) + states[i], chatbots[i] = ret[0], ret[1] + stop = False + except StopIteration: + pass + yield states + chatbots + [disable_btn] * 6 + if stop: + break + + +def flash_buttons(): + btn_updates = [ + [disable_btn] * 4 + [enable_btn] * 2, + [enable_btn] * 6, + ] + for i in range(4): + yield btn_updates[i % 2] + time.sleep(0.5) + + +def build_side_by_side_ui_named(models): + notice_markdown = """ +# ⚔️ Chatbot Arena ⚔️ : Benchmarking LLMs in the Wild +| [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | + +## 📜 Rules +- Chat with any two models side-by-side and vote! +- You can continue chatting for multiple rounds. +- Click "Clear history" to start a new round. + +## 🤖 Choose two models to compare +""" + + states = [gr.State() for _ in range(num_sides)] + model_selectors = [None] * num_sides + chatbots = [None] * num_sides + + notice = gr.Markdown(notice_markdown, elem_id="notice_markdown") + + with gr.Box(elem_id="share-region-named"): + with gr.Row(): + for i in range(num_sides): + with gr.Column(): + model_selectors[i] = gr.Dropdown( + choices=models, + value=models[i] if len(models) > i else "", + interactive=True, + show_label=False, + container=False, + ) + with gr.Row(): + with gr.Accordion("🔍 Expand to see 20+ model descriptions", open=False): + model_description_md = get_model_description_md(models) + gr.Markdown(model_description_md, elem_id="model_description_markdown") + + with gr.Row(): + for i in range(num_sides): + label = "Model A" if i == 0 else "Model B" + with gr.Column(): + chatbots[i] = gr.Chatbot( + label=label, elem_id=f"chatbot", height=550 + ) + + with gr.Row(): + leftvote_btn = gr.Button( + value="👈 A is better", visible=False, interactive=False + ) + rightvote_btn = gr.Button( + value="👉 B is better", visible=False, interactive=False + ) + tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False) + bothbad_btn = gr.Button( + value="👎 Both are bad", visible=False, interactive=False + ) + + with gr.Row(): + textbox = gr.Textbox( + show_label=False, + placeholder="👉 Enter your prompt and press ENTER", + container=False, + elem_id="input_box", + ) + send_btn = gr.Button(value="Send", variant="primary", scale=0) + + with gr.Row() as button_row: + clear_btn = gr.Button(value="🗑️ Clear history", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + share_btn = gr.Button(value="📷 Share") + + with gr.Accordion("Parameters", open=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=1.0, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=16, + maximum=2048, + value=1024, + step=64, + interactive=True, + label="Max output tokens", + ) + + gr.Markdown(acknowledgment_md, elem_id="ack_markdown") + + # Register listeners + btn_list = [ + leftvote_btn, + rightvote_btn, + tie_btn, + bothbad_btn, + regenerate_btn, + clear_btn, + ] + leftvote_btn.click( + leftvote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + rightvote_btn.click( + rightvote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + tie_btn.click( + tievote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + bothbad_btn.click( + bothbad_vote_last_response, + states + model_selectors, + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + ) + regenerate_btn.click( + regenerate, states, states + chatbots + [textbox] + btn_list + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + clear_btn.click(clear_history, None, states + chatbots + [textbox] + btn_list) + + share_js = """ +function (a, b, c, d) { + const captureElement = document.querySelector('#share-region-named'); + html2canvas(captureElement) + .then(canvas => { + canvas.style.display = 'none' + document.body.appendChild(canvas) + return canvas + }) + .then(canvas => { + const image = canvas.toDataURL('image/png') + const a = document.createElement('a') + a.setAttribute('download', 'chatbot-arena.png') + a.setAttribute('href', image) + a.click() + canvas.remove() + }); + return [a, b, c, d]; +} +""" + share_btn.click(share_click, states + model_selectors, [], _js=share_js) + + for i in range(num_sides): + model_selectors[i].change( + clear_history, None, states + chatbots + [textbox] + btn_list + ) + + textbox.submit( + add_text, + states + model_selectors + [textbox], + states + chatbots + [textbox] + btn_list, + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + send_btn.click( + add_text, + states + model_selectors + [textbox], + states + chatbots + [textbox] + btn_list, + ).then( + bot_response_multi, + states + [temperature, top_p, max_output_tokens], + states + chatbots + btn_list, + ).then( + flash_buttons, [], btn_list + ) + + return states + model_selectors diff --git a/third_party/FastChat/fastchat/serve/gradio_web_server.py b/third_party/FastChat/fastchat/serve/gradio_web_server.py new file mode 100644 index 00000000..523491e2 --- /dev/null +++ b/third_party/FastChat/fastchat/serve/gradio_web_server.py @@ -0,0 +1,901 @@ +""" +The gradio demo server for chatting with a single model. +""" + +import argparse +from collections import defaultdict +import datetime +import json +import os +import random +import time +import uuid + +import gradio as gr +import requests + +from fastchat.conversation import SeparatorStyle +from fastchat.constants import ( + LOGDIR, + WORKER_API_TIMEOUT, + ErrorCode, + MODERATION_MSG, + CONVERSATION_LIMIT_MSG, + RATE_LIMIT_MSG, + SERVER_ERROR_MSG, + INPUT_CHAR_LEN_LIMIT, + CONVERSATION_TURN_LIMIT, + SESSION_EXPIRATION_TIME, +) +from fastchat.model.model_adapter import ( + get_conversation_template, + ANTHROPIC_MODEL_LIST, +) +from fastchat.model.model_registry import get_model_info, model_info +from fastchat.serve.api_provider import ( + anthropic_api_stream_iter, + openai_api_stream_iter, + palm_api_stream_iter, + init_palm_chat, +) +from fastchat.utils import ( + build_logger, + moderation_filter, + get_window_url_params_js, + get_window_url_params_with_tos_js, + parse_gradio_auth_creds, +) + + +logger = build_logger("gradio_web_server", "gradio_web_server.log") + +headers = {"User-Agent": "FastChat Client"} + +no_change_btn = gr.Button.update() +enable_btn = gr.Button.update(interactive=True, visible=True) +disable_btn = gr.Button.update(interactive=False) +invisible_btn = gr.Button.update(interactive=False, visible=False) + +controller_url = None +enable_moderation = False + +acknowledgment_md = """ +### Acknowledgment +
+

We thank Kaggle, MBZUAI, AnyScale, a16z, and HuggingFace for their generous sponsorship.

+ Kaggle + MBZUAI + AnyScale + a16z + HuggingFace +
+""" + +ip_expiration_dict = defaultdict(lambda: 0) + +# Information about custom OpenAI compatible API models. +# JSON file format: +# { +# "vicuna-7b": { +# "model_name": "vicuna-7b-v1.5", +# "api_base": "http://8.8.8.55:5555/v1", +# "api_key": "password" +# }, +# } +openai_compatible_models_info = {} + + +class State: + def __init__(self, model_name): + self.conv = get_conversation_template(model_name) + self.conv_id = uuid.uuid4().hex + self.skip_next = False + self.model_name = model_name + + if model_name in ["palm-2", "gemini-pro"]: + self.palm_chat = init_palm_chat(model_name) + + def to_gradio_chatbot(self): + return self.conv.to_gradio_chatbot() + + def dict(self): + base = self.conv.dict() + base.update( + { + "conv_id": self.conv_id, + "model_name": self.model_name, + } + ) + return base + + +def set_global_vars(controller_url_, enable_moderation_): + global controller_url, enable_moderation + controller_url = controller_url_ + enable_moderation = enable_moderation_ + + +def get_conv_log_filename(): + t = datetime.datetime.now() + name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json") + return name + + +def get_model_list( + controller_url, register_openai_compatible_models, add_chatgpt, add_claude, add_palm +): + if controller_url: + ret = requests.post(controller_url + "/refresh_all_workers") + assert ret.status_code == 200 + ret = requests.post(controller_url + "/list_models") + models = ret.json()["models"] + else: + models = [] + + # Add API providers + if register_openai_compatible_models: + global openai_compatible_models_info + openai_compatible_models_info = json.load( + open(register_openai_compatible_models) + ) + models += list(openai_compatible_models_info.keys()) + + if add_chatgpt: + models += [ + "gpt-4-0314", + "gpt-4-0613", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + ] + if add_claude: + models += ["claude-2.1", "claude-2.0", "claude-instant-1"] + if add_palm: + models += ["gemini-pro"] + models = list(set(models)) + + hidden_models = ["gpt-4-0314", "gpt-4-0613"] + for hm in hidden_models: + if hm in models: + del models[models.index(hm)] + + priority = {k: f"___{i:03d}" for i, k in enumerate(model_info)} + models.sort(key=lambda x: priority.get(x, x)) + logger.info(f"Models: {models}") + return models + + +def load_demo_single(models, url_params): + selected_model = models[0] if len(models) > 0 else "" + if "model" in url_params: + model = url_params["model"] + if model in models: + selected_model = model + + dropdown_update = gr.Dropdown.update( + choices=models, value=selected_model, visible=True + ) + + state = None + return state, dropdown_update + + +def load_demo(url_params, request: gr.Request): + global models + + ip = get_ip(request) + logger.info(f"load_demo. ip: {ip}. params: {url_params}") + ip_expiration_dict[ip] = time.time() + SESSION_EXPIRATION_TIME + + if args.model_list_mode == "reload": + models = get_model_list( + controller_url, + args.register_openai_compatible_models, + args.add_chatgpt, + args.add_claude, + args.add_palm, + ) + + return load_demo_single(models, url_params) + + +def vote_last_response(state, vote_type, model_selector, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "model": model_selector, + "state": state.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + + +def upvote_last_response(state, model_selector, request: gr.Request): + ip = get_ip(request) + logger.info(f"upvote. ip: {ip}") + vote_last_response(state, "upvote", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def downvote_last_response(state, model_selector, request: gr.Request): + ip = get_ip(request) + logger.info(f"downvote. ip: {ip}") + vote_last_response(state, "downvote", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def flag_last_response(state, model_selector, request: gr.Request): + ip = get_ip(request) + logger.info(f"flag. ip: {ip}") + vote_last_response(state, "flag", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def regenerate(state, request: gr.Request): + ip = get_ip(request) + logger.info(f"regenerate. ip: {ip}") + state.conv.update_last_message(None) + return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5 + + +def clear_history(request: gr.Request): + ip = get_ip(request) + logger.info(f"clear_history. ip: {ip}") + state = None + return (state, [], "") + (disable_btn,) * 5 + + +def get_ip(request: gr.Request): + if "cf-connecting-ip" in request.headers: + ip = request.headers["cf-connecting-ip"] + else: + ip = request.client.host + return ip + + +def add_text(state, model_selector, text, request: gr.Request): + ip = get_ip(request) + logger.info(f"add_text. ip: {ip}. len: {len(text)}") + + if state is None: + state = State(model_selector) + + if len(text) <= 0: + state.skip_next = True + return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5 + + flagged = moderation_filter(text, [state.model_name]) + if flagged: + logger.info(f"violate moderation. ip: {ip}. text: {text}") + # overwrite the original text + text = MODERATION_MSG + + conv = state.conv + if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: + logger.info(f"conversation turn limit. ip: {ip}. text: {text}") + state.skip_next = True + return (state, state.to_gradio_chatbot(), CONVERSATION_LIMIT_MSG) + ( + no_change_btn, + ) * 5 + + text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off + conv.append_message(conv.roles[0], text) + conv.append_message(conv.roles[1], None) + return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5 + + +def post_process_code(code): + sep = "\n```" + if sep in code: + blocks = code.split(sep) + if len(blocks) % 2 == 1: + for i in range(1, len(blocks), 2): + blocks[i] = blocks[i].replace("\\_", "_") + code = sep.join(blocks) + return code + + +def model_worker_stream_iter( + conv, + model_name, + worker_addr, + prompt, + temperature, + repetition_penalty, + top_p, + max_new_tokens, +): + # Make requests + gen_params = { + "model": model_name, + "prompt": prompt, + "temperature": temperature, + "repetition_penalty": repetition_penalty, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + "stop": conv.stop_str, + "stop_token_ids": conv.stop_token_ids, + "echo": False, + } + logger.info(f"==== request ====\n{gen_params}") + + # Stream output + response = requests.post( + worker_addr + "/worker_generate_stream", + headers=headers, + json=gen_params, + stream=True, + timeout=WORKER_API_TIMEOUT, + ) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + yield data + + +def bot_response( + state, + temperature, + top_p, + max_new_tokens, + request: gr.Request, + apply_rate_limit=True, +): + ip = get_ip(request) + logger.info(f"bot_response. ip: {ip}") + start_tstamp = time.time() + temperature = float(temperature) + top_p = float(top_p) + max_new_tokens = int(max_new_tokens) + + if state.skip_next: + # This generate call is skipped due to invalid inputs + state.skip_next = False + yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 + return + + conv, model_name = state.conv, state.model_name + if model_name in openai_compatible_models_info: + model_info = openai_compatible_models_info[model_name] + prompt = conv.to_openai_api_messages() + stream_iter = openai_api_stream_iter( + model_info["model_name"], + prompt, + temperature, + top_p, + max_new_tokens, + api_base=model_info["api_base"], + api_key=model_info["api_key"], + ) + elif model_name in [ + "gpt-3.5-turbo", + "gpt-3.5-turbo-0301", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-turbo", + ]: + # avoid conflict with Azure OpenAI + assert model_name not in openai_compatible_models_info + prompt = conv.to_openai_api_messages() + stream_iter = openai_api_stream_iter( + model_name, prompt, temperature, top_p, max_new_tokens + ) + elif model_name in ANTHROPIC_MODEL_LIST: + prompt = conv.get_prompt() + stream_iter = anthropic_api_stream_iter( + model_name, prompt, temperature, top_p, max_new_tokens + ) + elif model_name in ["palm-2", "gemini-pro"]: + stream_iter = palm_api_stream_iter( + model_name, + state.palm_chat, + conv.messages[-2][1], + temperature, + top_p, + max_new_tokens, + ) + else: + # Query worker address + ret = requests.post( + controller_url + "/get_worker_address", json={"model": model_name} + ) + worker_addr = ret.json()["address"] + logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}") + + # No available worker + if worker_addr == "": + conv.update_last_message(SERVER_ERROR_MSG) + yield ( + state, + state.to_gradio_chatbot(), + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) + return + + # Construct prompt. + # We need to call it here, so it will not be affected by "▌". + prompt = conv.get_prompt() + + # Set repetition_penalty + if "t5" in model_name: + repetition_penalty = 1.2 + else: + repetition_penalty = 1.0 + + stream_iter = model_worker_stream_iter( + conv, + model_name, + worker_addr, + prompt, + temperature, + repetition_penalty, + top_p, + max_new_tokens, + ) + + conv.update_last_message("▌") + yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + + try: + for i, data in enumerate(stream_iter): + if data["error_code"] == 0: + output = data["text"].strip() + conv.update_last_message(output + "▌") + yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + else: + output = data["text"] + f"\n\n(error_code: {data['error_code']})" + conv.update_last_message(output) + yield (state, state.to_gradio_chatbot()) + ( + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) + return + output = data["text"].strip() + if "vicuna" in model_name: + output = post_process_code(output) + conv.update_last_message(output) + yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 + except requests.exceptions.RequestException as e: + conv.update_last_message( + f"{SERVER_ERROR_MSG}\n\n" + f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})" + ) + yield (state, state.to_gradio_chatbot()) + ( + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) + return + except Exception as e: + conv.update_last_message( + f"{SERVER_ERROR_MSG}\n\n" + f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})" + ) + yield (state, state.to_gradio_chatbot()) + ( + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) + return + + finish_tstamp = time.time() + logger.info(f"{output}") + + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name, + "gen_params": { + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + }, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + + +block_css = """ +#notice_markdown { + font-size: 110% +} +#notice_markdown th { + display: none; +} +#notice_markdown td { + padding-top: 6px; + padding-bottom: 6px; +} +#model_description_markdown { + font-size: 110% +} +#leaderboard_markdown { + font-size: 110% +} +#leaderboard_markdown td { + padding-top: 6px; + padding-bottom: 6px; +} +#leaderboard_dataframe td { + line-height: 0.1em; +} +#about_markdown { + font-size: 110% +} +#ack_markdown { + font-size: 110% +} +#input_box textarea { +} +footer { + display:none !important +} +.image-container { + display: flex; + align-items: center; + padding: 1px; +} +.image-container img { + margin: 0 30px; + height: 30px; + max-height: 100%; + width: auto; + max-width: 20%; +} +.image-about img { + margin: 0 30px; + margin-top: 30px; + height: 60px; + max-height: 100%; + width: auto; + max-width: 20%; + float: left; +} +""" + + +def get_model_description_md(models): + model_description_md = """ +| | | | +| ---- | ---- | ---- | +""" + ct = 0 + visited = set() + for i, name in enumerate(models): + minfo = get_model_info(name) + if minfo.simple_name in visited: + continue + visited.add(minfo.simple_name) + one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}" + + if ct % 3 == 0: + model_description_md += "|" + model_description_md += f" {one_model_md} |" + if ct % 3 == 2: + model_description_md += "\n" + ct += 1 + return model_description_md + + +def build_about(): + about_markdown = f""" +# About Us +Chatbot Arena is an open-source research project developed by members from [LMSYS](https://lmsys.org/about/) and UC Berkeley [SkyLab](https://sky.cs.berkeley.edu/). Our mission is to build an open crowdsourced platform to collect human feedback and evaluate LLMs under real-world scenarios. We open-source our [FastChat](https://github.com/lm-sys/FastChat) project at GitHub and release chat and human feedback datasets [here](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md). We invite everyone to join us in this journey! + +## Read More +- Chatbot Arena [launch post](https://lmsys.org/blog/2023-05-03-arena/), [data release](https://lmsys.org/blog/2023-07-20-dataset/) +- LMSYS-Chat-1M [report](https://arxiv.org/abs/2309.11998) + +## Core Members +[Lianmin Zheng](https://lmzheng.net/), [Wei-Lin Chiang](https://infwinston.github.io/), [Ying Sheng](https://sites.google.com/view/yingsheng/home), [Siyuan Zhuang](https://scholar.google.com/citations?user=KSZmI5EAAAAJ) + +## Advisors +[Ion Stoica](http://people.eecs.berkeley.edu/~istoica/), [Joseph E. Gonzalez](https://people.eecs.berkeley.edu/~jegonzal/), [Hao Zhang](https://cseweb.ucsd.edu/~haozhang/) + +## Contact Us +- Follow our [Twitter](https://twitter.com/lmsysorg), [Discord](https://discord.gg/HSWAKCrnFx) or email us at lmsys.org@gmail.com +- File issues on [GitHub](https://github.com/lm-sys/FastChat) +- Download our datasets and models on [HuggingFace](https://huggingface.co/lmsys) + +## Acknowledgment +We thank [SkyPilot](https://github.com/skypilot-org/skypilot) and [Gradio](https://github.com/gradio-app/gradio) team for their system support. +We also thank [Kaggle](https://www.kaggle.com/), [MBZUAI](https://mbzuai.ac.ae/), [Anyscale](https://www.anyscale.com/), [a16z](https://www.a16z.com/), [HuggingFace](https://huggingface.co/) for their generous sponsorship. +Learn more about partnership [here](https://lmsys.org/donations/). + +
+ Kaggle + MBZUAI + AnyScale + a16z + HuggingFace +
+""" + + # state = gr.State() + gr.Markdown(about_markdown, elem_id="about_markdown") + + # return [state] + + +def build_single_model_ui(models, add_promotion_links=False): + promotion = ( + """ +- | [GitHub](https://github.com/lm-sys/FastChat) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | +- Introducing Llama 2: The Next Generation Open Source Large Language Model. [[Website]](https://ai.meta.com/llama/) +- Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90% ChatGPT Quality. [[Blog]](https://lmsys.org/blog/2023-03-30-vicuna/) + +## 🤖 Choose any model to chat +""" + if add_promotion_links + else "" + ) + + notice_markdown = f""" +# 🏔️ Chat with Open Large Language Models +{promotion} +""" + + state = gr.State() + gr.Markdown(notice_markdown, elem_id="notice_markdown") + + with gr.Box(elem_id="share-region-named"): + with gr.Row(elem_id="model_selector_row"): + model_selector = gr.Dropdown( + choices=models, + value=models[0] if len(models) > 0 else "", + interactive=True, + show_label=False, + container=False, + ) + with gr.Row(): + with gr.Accordion( + "🔍 Expand to see 20+ model descriptions", + open=False, + elem_id="model_description_accordion", + ): + model_description_md = get_model_description_md(models) + gr.Markdown(model_description_md, elem_id="model_description_markdown") + + chatbot = gr.Chatbot( + elem_id="chatbot", + label="Scroll down and start chatting", + height=550, + ) + with gr.Row(): + textbox = gr.Textbox( + show_label=False, + placeholder="👉 Enter your prompt and press ENTER", + container=False, + elem_id="input_box", + ) + send_btn = gr.Button(value="Send", variant="primary", scale=0) + + with gr.Row() as button_row: + upvote_btn = gr.Button(value="👍 Upvote", interactive=False) + downvote_btn = gr.Button(value="👎 Downvote", interactive=False) + flag_btn = gr.Button(value="⚠️ Flag", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + clear_btn = gr.Button(value="🗑️ Clear history", interactive=False) + + with gr.Accordion("Parameters", open=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=1.0, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=16, + maximum=2048, + value=1024, + step=64, + interactive=True, + label="Max output tokens", + ) + + if add_promotion_links: + gr.Markdown(acknowledgment_md, elem_id="ack_markdown") + + # Register listeners + btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn] + upvote_btn.click( + upvote_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + downvote_btn.click( + downvote_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + flag_btn.click( + flag_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then( + bot_response, + [state, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list, + ) + clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list) + + model_selector.change(clear_history, None, [state, chatbot, textbox] + btn_list) + + textbox.submit( + add_text, [state, model_selector, textbox], [state, chatbot, textbox] + btn_list + ).then( + bot_response, + [state, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list, + ) + send_btn.click( + add_text, + [state, model_selector, textbox], + [state, chatbot, textbox] + btn_list, + ).then( + bot_response, + [state, temperature, top_p, max_output_tokens], + [state, chatbot] + btn_list, + ) + + return [state, model_selector] + + +def build_demo(models): + with gr.Blocks( + title="Chat with Open Large Language Models", + theme=gr.themes.Default(), + css=block_css, + ) as demo: + url_params = gr.JSON(visible=False) + + state, model_selector = build_single_model_ui(models) + + if args.model_list_mode not in ["once", "reload"]: + raise ValueError(f"Unknown model list mode: {args.model_list_mode}") + + if args.show_terms_of_use: + load_js = get_window_url_params_with_tos_js + else: + load_js = get_window_url_params_js + + demo.load( + load_demo, + [url_params], + [ + state, + model_selector, + ], + _js=load_js, + ) + + return demo + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument( + "--share", + action="store_true", + help="Whether to generate a public, shareable link", + ) + parser.add_argument( + "--controller-url", + type=str, + default="http://localhost:21001", + help="The address of the controller", + ) + parser.add_argument( + "--concurrency-count", + type=int, + default=10, + help="The concurrency count of the gradio queue", + ) + parser.add_argument( + "--model-list-mode", + type=str, + default="once", + choices=["once", "reload"], + help="Whether to load the model list once or reload the model list every time", + ) + parser.add_argument( + "--moderate", + action="store_true", + help="Enable content moderation to block unsafe inputs", + ) + parser.add_argument( + "--show-terms-of-use", + action="store_true", + help="Shows term of use before loading the demo", + ) + parser.add_argument( + "--add-chatgpt", + action="store_true", + help="Add OpenAI's ChatGPT models (gpt-3.5-turbo, gpt-4)", + ) + parser.add_argument( + "--add-claude", + action="store_true", + help="Add Anthropic's Claude models (claude-2, claude-instant-1)", + ) + parser.add_argument( + "--add-palm", + action="store_true", + help="Add Google's PaLM model (PaLM 2 for Chat: chat-bison@001)", + ) + parser.add_argument( + "--register-openai-compatible-models", + type=str, + help="Register custom OpenAI API compatible models by loading them from a JSON file", + ) + parser.add_argument( + "--gradio-auth-path", + type=str, + help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', + ) + parser.add_argument( + "--gradio-root-path", + type=str, + help="Sets the gradio root path, eg /abc/def. Useful when running behind a reverse-proxy or at a custom URL path prefix", + ) + args = parser.parse_args() + logger.info(f"args: {args}") + + # Set global variables + set_global_vars(args.controller_url, args.moderate) + models = get_model_list( + args.controller_url, + args.register_openai_compatible_models, + args.add_chatgpt, + args.add_claude, + args.add_palm, + ) + + # Set authorization credentials + auth = None + if args.gradio_auth_path is not None: + auth = parse_gradio_auth_creds(args.gradio_auth_path) + + # Launch the demo + demo = build_demo(models) + demo.queue( + concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False + ).launch( + server_name=args.host, + server_port=args.port, + share=args.share, + max_threads=200, + auth=auth, + root_path=args.gradio_root_path, + ) diff --git a/third_party/FastChat/fastchat/serve/gradio_web_server_multi.py b/third_party/FastChat/fastchat/serve/gradio_web_server_multi.py new file mode 100644 index 00000000..0009c02a --- /dev/null +++ b/third_party/FastChat/fastchat/serve/gradio_web_server_multi.py @@ -0,0 +1,283 @@ +""" +The gradio demo server with multiple tabs. +It supports chatting with a single model or chatting with two models side-by-side. +""" + +import argparse +import pickle +import time + +import gradio as gr + +from fastchat.constants import ( + SESSION_EXPIRATION_TIME, +) +from fastchat.serve.gradio_block_arena_anony import ( + build_side_by_side_ui_anony, + load_demo_side_by_side_anony, + set_global_vars_anony, +) +from fastchat.serve.gradio_block_arena_named import ( + build_side_by_side_ui_named, + load_demo_side_by_side_named, + set_global_vars_named, +) +from fastchat.serve.gradio_web_server import ( + set_global_vars, + block_css, + build_single_model_ui, + build_about, + get_model_list, + load_demo_single, + ip_expiration_dict, + get_ip, +) +from fastchat.serve.monitor.monitor import build_leaderboard_tab +from fastchat.utils import ( + build_logger, + get_window_url_params_js, + get_window_url_params_with_tos_js, + parse_gradio_auth_creds, +) + +logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") + + +def load_demo(url_params, request: gr.Request): + global models + + ip = get_ip(request) + logger.info(f"load_demo. ip: {ip}. params: {url_params}") + ip_expiration_dict[ip] = time.time() + SESSION_EXPIRATION_TIME + + selected = 0 + if "arena" in url_params: + selected = 0 + elif "compare" in url_params: + selected = 1 + elif "single" in url_params: + selected = 2 + elif "leaderboard" in url_params: + selected = 3 + + if args.model_list_mode == "reload": + if args.anony_only_for_proprietary_model: + models = get_model_list( + args.controller_url, + args.register_openai_compatible_models, + False, + False, + False, + ) + else: + models = get_model_list( + args.controller_url, + args.register_openai_compatible_models, + args.add_chatgpt, + args.add_claude, + args.add_palm, + ) + + single_updates = load_demo_single(models, url_params) + + models_anony = list(models) + if args.anony_only_for_proprietary_model: + # Only enable these models in anony battles. + if args.add_chatgpt: + models_anony += [ + "gpt-4-0314", + "gpt-4-0613", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + ] + if args.add_claude: + models_anony += ["claude-2.1", "claude-2.0", "claude-1", "claude-instant-1"] + if args.add_palm: + models_anony += ["gemini-pro"] + anony_only_models = [ + "claude-1", + "gpt-4-0314", + "gpt-4-0613", + ] + for mdl in anony_only_models: + models_anony.append(mdl) + models_anony = list(set(models_anony)) + + side_by_side_anony_updates = load_demo_side_by_side_anony(models_anony, url_params) + side_by_side_named_updates = load_demo_side_by_side_named(models, url_params) + return ( + (gr.Tabs.update(selected=selected),) + + single_updates + + side_by_side_anony_updates + + side_by_side_named_updates + ) + + +def build_demo(models, elo_results_file, leaderboard_table_file): + text_size = gr.themes.sizes.text_md + with gr.Blocks( + title="Chat with Open Large Language Models", + theme=gr.themes.Default(text_size=text_size), + css=block_css, + ) as demo: + with gr.Tabs() as tabs: + with gr.Tab("Arena (battle)", id=0): + side_by_side_anony_list = build_side_by_side_ui_anony(models) + + with gr.Tab("Arena (side-by-side)", id=1): + side_by_side_named_list = build_side_by_side_ui_named(models) + + with gr.Tab("Direct Chat", id=2): + single_model_list = build_single_model_ui( + models, add_promotion_links=True + ) + if elo_results_file: + with gr.Tab("Leaderboard", id=3): + build_leaderboard_tab(elo_results_file, leaderboard_table_file) + with gr.Tab("About Us", id=4): + about = build_about() + + url_params = gr.JSON(visible=False) + + if args.model_list_mode not in ["once", "reload"]: + raise ValueError(f"Unknown model list mode: {args.model_list_mode}") + + if args.show_terms_of_use: + load_js = get_window_url_params_with_tos_js + else: + load_js = get_window_url_params_js + + demo.load( + load_demo, + [url_params], + [tabs] + + single_model_list + + side_by_side_anony_list + + side_by_side_named_list, + _js=load_js, + ) + + return demo + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument( + "--share", + action="store_true", + help="Whether to generate a public, shareable link", + ) + parser.add_argument( + "--controller-url", + type=str, + default="http://localhost:21001", + help="The address of the controller", + ) + parser.add_argument( + "--concurrency-count", + type=int, + default=10, + help="The concurrency count of the gradio queue", + ) + parser.add_argument( + "--model-list-mode", + type=str, + default="once", + choices=["once", "reload"], + help="Whether to load the model list once or reload the model list every time.", + ) + parser.add_argument( + "--moderate", + action="store_true", + help="Enable content moderation to block unsafe inputs", + ) + parser.add_argument( + "--show-terms-of-use", + action="store_true", + help="Shows term of use before loading the demo", + ) + parser.add_argument( + "--add-chatgpt", + action="store_true", + help="Add OpenAI's ChatGPT models (gpt-3.5-turbo, gpt-4)", + ) + parser.add_argument( + "--add-claude", + action="store_true", + help="Add Anthropic's Claude models (claude-2, claude-instant-1)", + ) + parser.add_argument( + "--add-palm", + action="store_true", + help="Add Google's PaLM model (PaLM 2 for Chat: chat-bison@001)", + ) + parser.add_argument( + "--anony-only-for-proprietary-model", + action="store_true", + help="Only add ChatGPT, Claude, Bard under anony battle tab", + ) + parser.add_argument( + "--register-openai-compatible-models", + type=str, + help="Register custom OpenAI API compatible models by loading them from a JSON file", + ) + parser.add_argument( + "--gradio-auth-path", + type=str, + help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', + default=None, + ) + parser.add_argument( + "--elo-results-file", type=str, help="Load leaderboard results and plots" + ) + parser.add_argument( + "--leaderboard-table-file", type=str, help="Load leaderboard results and plots" + ) + parser.add_argument( + "--gradio-root-path", + type=str, + help="Sets the gradio root path, eg /abc/def. Useful when running behind a reverse-proxy or at a custom URL path prefix", + ) + args = parser.parse_args() + logger.info(f"args: {args}") + + # Set global variables + set_global_vars(args.controller_url, args.moderate) + set_global_vars_named(args.moderate) + set_global_vars_anony(args.moderate) + if args.anony_only_for_proprietary_model: + models = get_model_list( + args.controller_url, + args.register_openai_compatible_models, + False, + False, + False, + ) + else: + models = get_model_list( + args.controller_url, + args.register_openai_compatible_models, + args.add_chatgpt, + args.add_claude, + args.add_palm, + ) + + # Set authorization credentials + auth = None + if args.gradio_auth_path is not None: + auth = parse_gradio_auth_creds(args.gradio_auth_path) + + # Launch the demo + demo = build_demo(models, args.elo_results_file, args.leaderboard_table_file) + demo.queue( + concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False + ).launch( + server_name=args.host, + server_port=args.port, + share=args.share, + max_threads=200, + auth=auth, + root_path=args.gradio_root_path, + ) diff --git a/third_party/FastChat/fastchat/serve/huggingface_api.py b/third_party/FastChat/fastchat/serve/huggingface_api.py new file mode 100644 index 00000000..8022fbc9 --- /dev/null +++ b/third_party/FastChat/fastchat/serve/huggingface_api.py @@ -0,0 +1,73 @@ +""" +Use FastChat with Hugging Face generation APIs. + +Usage: +python3 -m fastchat.serve.huggingface_api --model lmsys/vicuna-7b-v1.5 +python3 -m fastchat.serve.huggingface_api --model lmsys/fastchat-t5-3b-v1.0 +""" +import argparse + +import torch + +from fastchat.model import load_model, get_conversation_template, add_model_args + + +@torch.inference_mode() +def main(args): + # Load model + model, tokenizer = load_model( + args.model_path, + device=args.device, + num_gpus=args.num_gpus, + max_gpu_memory=args.max_gpu_memory, + load_8bit=args.load_8bit, + cpu_offloading=args.cpu_offloading, + revision=args.revision, + debug=args.debug, + ) + + # Build the prompt with a conversation template + msg = args.message + conv = get_conversation_template(args.model_path) + conv.append_message(conv.roles[0], msg) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + # Run inference + inputs = tokenizer([prompt], return_tensors="pt").to(args.device) + output_ids = model.generate( + **inputs, + do_sample=True if args.temperature > 1e-5 else False, + temperature=args.temperature, + repetition_penalty=args.repetition_penalty, + max_new_tokens=args.max_new_tokens, + ) + + if model.config.is_encoder_decoder: + output_ids = output_ids[0] + else: + output_ids = output_ids[0][len(inputs["input_ids"][0]) :] + outputs = tokenizer.decode( + output_ids, skip_special_tokens=True, spaces_between_special_tokens=False + ) + + # Print results + print(f"{conv.roles[0]}: {msg}") + print(f"{conv.roles[1]}: {outputs}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + add_model_args(parser) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--repetition_penalty", type=float, default=1.0) + parser.add_argument("--max-new-tokens", type=int, default=1024) + parser.add_argument("--debug", action="store_true") + parser.add_argument("--message", type=str, default="Hello! Who are you?") + args = parser.parse_args() + + # Reset default repetition penalty for T5 models. + if "t5" in args.model_path and args.repetition_penalty == 1.0: + args.repetition_penalty = 1.2 + + main(args) diff --git a/third_party/FastChat/fastchat/serve/huggingface_api_worker.py b/third_party/FastChat/fastchat/serve/huggingface_api_worker.py new file mode 100644 index 00000000..6ed8e6c8 --- /dev/null +++ b/third_party/FastChat/fastchat/serve/huggingface_api_worker.py @@ -0,0 +1,415 @@ +""" +A model worker that calls huggingface inference endpoint. + +Register models in a JSON file with the following format: +{ + "falcon-180b-chat": { + "model_name": "falcon-180B-chat", + "api_base": "https://api-inference.huggingface.co/models", + "model_path": "tiiuae/falcon-180B-chat", + "token": "hf_XXX", + "context_length": 2048 + }, + "zephyr-7b-beta": { + "model_name": "zephyr-7b-beta", + "model_path": "", + "api_base": "xxx", + "token": "hf_XXX", + "context_length": 4096 + } +} + +"model_path", "api_base", "token", and "context_length" are necessary, while others are optional. +""" +import argparse +import asyncio +import json +import uuid +import os +from typing import List, Optional + +import requests +import uvicorn +from fastapi import BackgroundTasks, FastAPI, Request +from fastapi.responses import JSONResponse, StreamingResponse +from huggingface_hub import InferenceClient + +from fastchat.constants import SERVER_ERROR_MSG, ErrorCode +from fastchat.serve.base_model_worker import BaseModelWorker +from fastchat.utils import build_logger + +worker_id = str(uuid.uuid4())[:8] +logger = build_logger("model_worker", f"model_worker_{worker_id}.log") + +workers = [] +worker_map = {} +app = FastAPI() + + +# reference to +# https://github.com/philschmid/easyllm/blob/cbd908b3b3f44a97a22cb0fc2c93df3660bacdad/easyllm/clients/huggingface.py#L374-L392 +def get_gen_kwargs( + params, + seed: Optional[int] = None, +): + stop = params.get("stop", None) + if isinstance(stop, list): + stop_sequences = stop + elif isinstance(stop, str): + stop_sequences = [stop] + else: + stop_sequences = [] + gen_kwargs = { + "do_sample": True, + "return_full_text": bool(params.get("echo", False)), + "max_new_tokens": int(params.get("max_new_tokens", 256)), + "top_p": float(params.get("top_p", 1.0)), + "temperature": float(params.get("temperature", 1.0)), + "stop_sequences": stop_sequences, + "repetition_penalty": float(params.get("repetition_penalty", 1.0)), + "top_k": params.get("top_k", None), + "seed": seed, + } + if gen_kwargs["top_p"] == 1: + gen_kwargs["top_p"] = 0.9999999 + if gen_kwargs["top_p"] == 0: + gen_kwargs.pop("top_p") + if gen_kwargs["temperature"] == 0: + gen_kwargs.pop("temperature") + gen_kwargs["do_sample"] = False + return gen_kwargs + + +def could_be_stop(text, stop): + for s in stop: + if any(text.endswith(s[:i]) for i in range(1, len(s) + 1)): + return True + return False + + +class HuggingfaceApiWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + api_base: str, + token: str, + context_length: int, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + conv_template: Optional[str] = None, + seed: Optional[int] = None, + **kwargs, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template=conv_template, + ) + + self.model_path = model_path + self.api_base = api_base + self.token = token + self.context_len = context_length + self.seed = seed + + logger.info( + f"Connecting with huggingface api {self.model_path} as {self.model_names} on worker {worker_id} ..." + ) + + if not no_register: + self.init_heart_beat() + + def count_token(self, params): + # No tokenizer here + ret = { + "count": 0, + "error_code": 0, + } + return ret + + def generate_stream_gate(self, params): + self.call_ct += 1 + + prompt = params["prompt"] + gen_kwargs = get_gen_kwargs(params, seed=self.seed) + stop = gen_kwargs["stop_sequences"] + if "falcon" in self.model_path and "chat" in self.model_path: + stop.extend(["\nUser:", "<|endoftext|>", " User:", "###"]) + stop = list(set(stop)) + gen_kwargs["stop_sequences"] = stop + + logger.info(f"prompt: {prompt}") + logger.info(f"gen_kwargs: {gen_kwargs}") + + try: + if self.model_path == "": + url = f"{self.api_base}" + else: + url = f"{self.api_base}/{self.model_path}" + client = InferenceClient(url, token=self.token) + res = client.text_generation( + prompt, stream=True, details=True, **gen_kwargs + ) + + reason = None + text = "" + for chunk in res: + if chunk.token.special: + continue + text += chunk.token.text + + s = next((x for x in stop if text.endswith(x)), None) + if s is not None: + text = text[: -len(s)] + reason = "stop" + break + if could_be_stop(text, stop): + continue + if ( + chunk.details is not None + and chunk.details.finish_reason is not None + ): + reason = chunk.details.finish_reason + if reason not in ["stop", "length"]: + reason = None + ret = { + "text": text, + "error_code": 0, + "finish_reason": reason, + } + yield json.dumps(ret).encode() + b"\0" + except Exception as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + yield json.dumps(ret).encode() + b"\0" + + def generate_gate(self, params): + for x in self.generate_stream_gate(params): + pass + return json.loads(x[:-1].decode()) + + def get_embeddings(self, params): + raise NotImplementedError() + + +def release_worker_semaphore(worker): + worker.semaphore.release() + + +def acquire_worker_semaphore(worker): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(worker): + background_tasks = BackgroundTasks() + background_tasks.add_task(lambda: release_worker_semaphore(worker)) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + await acquire_worker_semaphore(worker) + generator = worker.generate_stream_gate(params) + background_tasks = create_background_tasks(worker) + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + await acquire_worker_semaphore(worker) + output = worker.generate_gate(params) + release_worker_semaphore(worker) + return JSONResponse(output) + + +@app.post("/worker_get_embeddings") +async def api_get_embeddings(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + await acquire_worker_semaphore(worker) + embedding = worker.get_embeddings(params) + release_worker_semaphore(worker) + return JSONResponse(content=embedding) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return { + "model_names": [m for w in workers for m in w.model_names], + "speed": 1, + "queue_length": sum([w.get_queue_length() for w in workers]), + } + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + return {"context_length": worker.context_len} + + +def create_huggingface_api_worker(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + # all model-related parameters are listed in --model-info-file + parser.add_argument( + "--model-info-file", + type=str, + required=True, + help="Huggingface API model's info file path", + ) + + parser.add_argument( + "--limit-worker-concurrency", + type=int, + default=5, + help="Limit the model concurrency to prevent OOM.", + ) + parser.add_argument("--no-register", action="store_true") + parser.add_argument( + "--seed", + type=int, + default=None, + help="Overwrite the random seed for each generation.", + ) + parser.add_argument( + "--ssl", + action="store_true", + required=False, + default=False, + help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.", + ) + args = parser.parse_args() + + with open(args.model_info_file, "r", encoding="UTF-8") as f: + model_info = json.load(f) + + logger.info(f"args: {args}") + + model_path_list = [] + api_base_list = [] + token_list = [] + context_length_list = [] + model_names_list = [] + conv_template_list = [] + + for m in model_info: + model_path_list.append(model_info[m]["model_path"]) + api_base_list.append(model_info[m]["api_base"]) + token_list.append(model_info[m]["token"]) + + context_length = model_info[m]["context_length"] + model_names = model_info[m].get("model_names", [m.split("/")[-1]]) + if isinstance(model_names, str): + model_names = [model_names] + conv_template = model_info[m].get("conv_template", None) + + context_length_list.append(context_length) + model_names_list.append(model_names) + conv_template_list.append(conv_template) + + logger.info(f"Model paths: {model_path_list}") + logger.info(f"API bases: {api_base_list}") + logger.info(f"Tokens: {token_list}") + logger.info(f"Context lengths: {context_length_list}") + logger.info(f"Model names: {model_names_list}") + logger.info(f"Conv templates: {conv_template_list}") + + for ( + model_names, + conv_template, + model_path, + api_base, + token, + context_length, + ) in zip( + model_names_list, + conv_template_list, + model_path_list, + api_base_list, + token_list, + context_length_list, + ): + m = HuggingfaceApiWorker( + args.controller_address, + args.worker_address, + worker_id, + model_path, + api_base, + token, + context_length, + model_names, + args.limit_worker_concurrency, + no_register=args.no_register, + conv_template=conv_template, + seed=args.seed, + ) + workers.append(m) + for name in model_names: + worker_map[name] = m + + # register all the models + url = args.controller_address + "/register_worker" + data = { + "worker_name": workers[0].worker_addr, + "check_heart_beat": not args.no_register, + "worker_status": { + "model_names": [m for w in workers for m in w.model_names], + "speed": 1, + "queue_length": sum([w.get_queue_length() for w in workers]), + }, + } + r = requests.post(url, json=data) + assert r.status_code == 200 + + return args, workers + + +if __name__ == "__main__": + args, workers = create_huggingface_api_worker() + if args.ssl: + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + ssl_keyfile=os.environ["SSL_KEYFILE"], + ssl_certfile=os.environ["SSL_CERTFILE"], + ) + else: + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/third_party/FastChat/fastchat/serve/inference.py b/third_party/FastChat/fastchat/serve/inference.py new file mode 100644 index 00000000..6d155aab --- /dev/null +++ b/third_party/FastChat/fastchat/serve/inference.py @@ -0,0 +1,555 @@ +"""Inference for FastChat models.""" +import abc +import gc +import json +import math +import os +import sys +import time +from typing import Iterable, Optional, Dict +import warnings + +import psutil +import torch +from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + LlamaTokenizer, + LlamaForCausalLM, + AutoModel, + AutoModelForSeq2SeqLM, + T5Tokenizer, + AutoConfig, +) +from transformers.generation.logits_process import ( + LogitsProcessorList, + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, +) + +from fastchat.conversation import get_conv_template, SeparatorStyle +from fastchat.model.model_adapter import ( + load_model, + get_conversation_template, + get_generate_stream_function, +) +from fastchat.modules.awq import AWQConfig +from fastchat.modules.gptq import GptqConfig +from fastchat.modules.exllama import ExllamaConfig +from fastchat.modules.xfastertransformer import XftConfig +from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length + + +def prepare_logits_processor( + temperature: float, repetition_penalty: float, top_p: float, top_k: int +) -> LogitsProcessorList: + processor_list = LogitsProcessorList() + # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases. + if temperature >= 1e-5 and temperature != 1.0: + processor_list.append(TemperatureLogitsWarper(temperature)) + if repetition_penalty > 1.0: + processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty)) + if 1e-8 <= top_p < 1.0: + processor_list.append(TopPLogitsWarper(top_p)) + if top_k > 0: + processor_list.append(TopKLogitsWarper(top_k)) + return processor_list + + +@torch.inference_mode() +def generate_stream( + model, + tokenizer, + params: Dict, + device: str, + context_len: int, + stream_interval: int = 2, + judge_sent_end: bool = False, +): + if hasattr(model, "device"): + device = model.device + + # Read parameters + prompt = params["prompt"] + len_prompt = len(prompt) + temperature = float(params.get("temperature", 1.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = int(params.get("top_k", -1)) # -1 means disable + max_new_tokens = int(params.get("max_new_tokens", 256)) + logprobs = params.get("logprobs", None) # FIXME: Support logprobs>1. + echo = bool(params.get("echo", True)) + stop_str = params.get("stop", None) + stop_token_ids = params.get("stop_token_ids", None) or [] + if tokenizer.eos_token_id not in stop_token_ids: + stop_token_ids.append(tokenizer.eos_token_id) + + logits_processor = prepare_logits_processor( + temperature, repetition_penalty, top_p, top_k + ) + input_ids = tokenizer(prompt).input_ids + + if model.config.is_encoder_decoder: + max_src_len = context_len + else: # truncate + max_src_len = context_len - max_new_tokens - 1 + + input_ids = input_ids[-max_src_len:] + output_ids = list(input_ids) + input_echo_len = len(input_ids) + + if model.config.is_encoder_decoder: + if logprobs is not None: # FIXME: Support logprobs for encoder-decoder models. + raise NotImplementedError + encoder_output = model.encoder( + input_ids=torch.as_tensor([input_ids], device=device) + )[0] + start_ids = torch.as_tensor( + [[model.generation_config.decoder_start_token_id]], + dtype=torch.int64, + device=device, + ) + else: + start_ids = torch.as_tensor([input_ids], device=device) + + past_key_values = out = None + token_logprobs = [None] # The first token has no logprobs. + sent_interrupt = False + finish_reason = None + stopped = False + for i in range(max_new_tokens): + if i == 0: # prefill + if model.config.is_encoder_decoder: + out = model.decoder( + input_ids=start_ids, + encoder_hidden_states=encoder_output, + use_cache=True, + ) + logits = model.lm_head(out[0]) + else: + out = model(input_ids=start_ids, use_cache=True) + logits = out.logits + past_key_values = out.past_key_values + + if logprobs is not None: + # Prefull logprobs for the prompt. + shift_input_ids = start_ids[..., 1:].contiguous() + shift_logits = logits[..., :-1, :].contiguous() + shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist() + for label_id, logit in zip( + shift_input_ids[0].tolist(), shift_logits[0] + ): + token_logprobs.append(logit[label_id]) + else: # decoding + if model.config.is_encoder_decoder: + out = model.decoder( + input_ids=torch.as_tensor( + [[token] if not sent_interrupt else output_ids], + device=device, + ), + encoder_hidden_states=encoder_output, + use_cache=True, + past_key_values=past_key_values if not sent_interrupt else None, + ) + sent_interrupt = False + + logits = model.lm_head(out[0]) + else: + out = model( + input_ids=torch.as_tensor( + [[token] if not sent_interrupt else output_ids], + device=device, + ), + use_cache=True, + past_key_values=past_key_values if not sent_interrupt else None, + ) + sent_interrupt = False + logits = out.logits + past_key_values = out.past_key_values + + if logits_processor: + if repetition_penalty > 1.0: + tmp_output_ids = torch.as_tensor([output_ids], device=logits.device) + else: + tmp_output_ids = None + last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0] + else: + last_token_logits = logits[0, -1, :] + + if device == "mps": + # Switch to CPU by avoiding some bugs in mps backend. + last_token_logits = last_token_logits.float().to("cpu") + + if temperature < 1e-5 or top_p < 1e-8: # greedy + _, indices = torch.topk(last_token_logits, 2) + tokens = [int(index) for index in indices.tolist()] + else: + probs = torch.softmax(last_token_logits, dim=-1) + indices = torch.multinomial(probs, num_samples=2) + tokens = [int(token) for token in indices.tolist()] + token = tokens[0] + output_ids.append(token) + if logprobs is not None: + # Cannot use last_token_logits because logprobs is based on raw logits. + token_logprobs.append( + torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist() + ) + + if token in stop_token_ids: + stopped = True + else: + stopped = False + + # Yield the output tokens + if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped: + if echo: + tmp_output_ids = output_ids + rfind_start = len_prompt + else: + tmp_output_ids = output_ids[input_echo_len:] + rfind_start = 0 + + output = tokenizer.decode( + tmp_output_ids, + skip_special_tokens=True, + spaces_between_special_tokens=False, + clean_up_tokenization_spaces=True, + ) + ret_logprobs = None + if logprobs is not None: + ret_logprobs = { + "text_offset": [], + "tokens": [ + tokenizer.decode(token) + for token in ( + output_ids if echo else output_ids[input_echo_len:] + ) + ], + "token_logprobs": token_logprobs + if echo + else token_logprobs[input_echo_len:], + "top_logprobs": [{}] + * len(token_logprobs if echo else token_logprobs[input_echo_len:]), + } + # Compute text_offset + curr_pos = 0 + for text in ret_logprobs["tokens"]: + ret_logprobs["text_offset"].append(curr_pos) + curr_pos += len(text) + + # TODO: For the issue of incomplete sentences interrupting output, apply a patch and others can also modify it to a more elegant way + if judge_sent_end and stopped and not is_sentence_complete(output): + if len(tokens) > 1: + token = tokens[1] + output_ids[-1] = token + else: + output_ids.pop() + stopped = False + sent_interrupt = True + + partially_stopped = False + if stop_str: + if isinstance(stop_str, str): + pos = output.rfind(stop_str, rfind_start) + if pos != -1: + output = output[:pos] + stopped = True + else: + partially_stopped = is_partial_stop(output, stop_str) + elif isinstance(stop_str, Iterable): + for each_stop in stop_str: + pos = output.rfind(each_stop, rfind_start) + if pos != -1: + output = output[:pos] + stopped = True + break + else: + partially_stopped = is_partial_stop(output, each_stop) + if partially_stopped: + break + else: + raise ValueError("Invalid stop field type.") + + # Prevent yielding partial stop sequence + if not partially_stopped: + yield { + "text": output, + "logprobs": ret_logprobs, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": None, + } + + if stopped: + break + + # Finish stream event, which contains finish reason + else: + finish_reason = "length" + + if stopped: + finish_reason = "stop" + + yield { + "text": output, + "logprobs": ret_logprobs, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + "finish_reason": finish_reason, + } + + # Clean + del past_key_values, out + gc.collect() + torch.cuda.empty_cache() + if device == "xpu": + torch.xpu.empty_cache() + if device == "npu": + torch.npu.empty_cache() + + +class ChatIO(abc.ABC): + @abc.abstractmethod + def prompt_for_input(self, role: str) -> str: + """Prompt for input from a role.""" + + @abc.abstractmethod + def prompt_for_output(self, role: str): + """Prompt for output from a role.""" + + @abc.abstractmethod + def stream_output(self, output_stream): + """Stream output.""" + + @abc.abstractmethod + def print_output(self, text: str): + """Print output.""" + + +def chat_loop( + model_path: str, + device: str, + num_gpus: int, + max_gpu_memory: str, + dtype: Optional[torch.dtype], + load_8bit: bool, + cpu_offloading: bool, + conv_template: Optional[str], + conv_system_msg: Optional[str], + temperature: float, + repetition_penalty: float, + max_new_tokens: int, + chatio: ChatIO, + gptq_config: Optional[GptqConfig] = None, + awq_config: Optional[AWQConfig] = None, + exllama_config: Optional[ExllamaConfig] = None, + xft_config: Optional[XftConfig] = None, + revision: str = "main", + judge_sent_end: bool = True, + debug: bool = True, + history: bool = True, +): + # Model + model, tokenizer = load_model( + model_path, + device=device, + num_gpus=num_gpus, + max_gpu_memory=max_gpu_memory, + dtype=dtype, + load_8bit=load_8bit, + cpu_offloading=cpu_offloading, + gptq_config=gptq_config, + awq_config=awq_config, + exllama_config=exllama_config, + xft_config=xft_config, + revision=revision, + debug=debug, + ) + generate_stream_func = get_generate_stream_function(model, model_path) + + model_type = str(type(model)).lower() + is_t5 = "t5" in model_type + is_codet5p = "codet5p" in model_type + is_xft = "xft" in model_type + + # Hardcode T5's default repetition penalty to be 1.2 + if is_t5 and repetition_penalty == 1.0: + repetition_penalty = 1.2 + + # Set context length + context_len = get_context_length(model.config) + + # Chat + def new_chat(): + if conv_template: + conv = get_conv_template(conv_template) + else: + conv = get_conversation_template(model_path) + if conv_system_msg is not None: + conv.set_system_message(conv_system_msg) + return conv + + def reload_conv(conv): + """ + Reprints the conversation from the start. + """ + for message in conv.messages[conv.offset :]: + chatio.prompt_for_output(message[0]) + chatio.print_output(message[1]) + + conv = None + + while True: + if not history or not conv: + conv = new_chat() + + try: + inp = chatio.prompt_for_input(conv.roles[0]) + except EOFError: + inp = "" + + if inp == "!!exit" or not inp: + print("exit...") + break + elif inp == "!!reset": + print("resetting...") + conv = new_chat() + continue + elif inp == "!!remove": + print("removing last message...") + if len(conv.messages) > conv.offset: + # Assistant + if conv.messages[-1][0] == conv.roles[1]: + conv.messages.pop() + # User + if conv.messages[-1][0] == conv.roles[0]: + conv.messages.pop() + reload_conv(conv) + else: + print("No messages to remove.") + continue + elif inp == "!!regen": + print("regenerating last message...") + if len(conv.messages) > conv.offset: + # Assistant + if conv.messages[-1][0] == conv.roles[1]: + conv.messages.pop() + # User + if conv.messages[-1][0] == conv.roles[0]: + reload_conv(conv) + # Set inp to previous message + inp = conv.messages.pop()[1] + else: + # Shouldn't happen in normal circumstances + print("No user message to regenerate from.") + continue + else: + print("No messages to regenerate.") + continue + elif inp.startswith("!!save"): + args = inp.split(" ", 1) + + if len(args) != 2: + print("usage: !!save ") + continue + else: + filename = args[1] + + # Add .json if extension not present + if not "." in filename: + filename += ".json" + + print("saving...", filename) + with open(filename, "w") as outfile: + json.dump(conv.dict(), outfile) + continue + elif inp.startswith("!!load"): + args = inp.split(" ", 1) + + if len(args) != 2: + print("usage: !!load ") + continue + else: + filename = args[1] + + # Check if file exists and add .json if needed + if not os.path.exists(filename): + if (not filename.endswith(".json")) and os.path.exists( + filename + ".json" + ): + filename += ".json" + else: + print("file not found:", filename) + continue + + print("loading...", filename) + with open(filename, "r") as infile: + new_conv = json.load(infile) + + conv = get_conv_template(new_conv["template_name"]) + conv.set_system_message(new_conv["system_message"]) + conv.messages = new_conv["messages"] + reload_conv(conv) + continue + + conv.append_message(conv.roles[0], inp) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + if is_codet5p: # codet5p is a code completion model. + prompt = inp + + gen_params = { + "model": model_path, + "prompt": prompt, + "temperature": temperature, + "repetition_penalty": repetition_penalty, + "max_new_tokens": max_new_tokens, + "stop": conv.stop_str, + "stop_token_ids": conv.stop_token_ids, + "echo": False, + } + + try: + chatio.prompt_for_output(conv.roles[1]) + output_stream = generate_stream_func( + model, + tokenizer, + gen_params, + device, + context_len=context_len, + judge_sent_end=judge_sent_end, + ) + t = time.time() + outputs = chatio.stream_output(output_stream) + duration = time.time() - t + conv.update_last_message(outputs.strip()) + + if debug: + num_tokens = len(tokenizer.encode(outputs)) + msg = { + "conv_template": conv.name, + "prompt": prompt, + "outputs": outputs, + "speed (token/s)": round(num_tokens / duration, 2), + } + print(f"\n{msg}\n") + + except KeyboardInterrupt: + print("stopped generation.") + # If generation didn't finish + if conv.messages[-1][1] is None: + conv.messages.pop() + # Remove last user message, so there isn't a double up + if conv.messages[-1][0] == conv.roles[0]: + conv.messages.pop() + + reload_conv(conv) diff --git a/third_party/FastChat/fastchat/serve/launch_all_serve.py b/third_party/FastChat/fastchat/serve/launch_all_serve.py new file mode 100644 index 00000000..2f4ad7b0 --- /dev/null +++ b/third_party/FastChat/fastchat/serve/launch_all_serve.py @@ -0,0 +1,284 @@ +""" +Usage: python launch_all_serve_by_shell.py --model-path-address "THUDM/chatglm2-6b@localhost@2021" "huggyllama/llama-7b@localhost@2022" + +Workers are listed in format of `model-path`@`host`@`port` + +The key mechanism behind this scripts is: + 1, execute shell cmd to launch the controller/worker/openai-api-server; + 2, check the log of controller/worker/openai-api-server to ensure that the serve is launched properly. +Note that a few of non-critical `fastchat.serve` cmd options are not supported currently. +""" +import sys +import os + +sys.path.append(os.path.dirname(os.path.dirname(__file__))) + +import subprocess +import re +import argparse + +LOGDIR = "./logs/" + +if not os.path.exists(LOGDIR): + os.makedirs(LOGDIR) + +parser = argparse.ArgumentParser() +# ------multi worker----------------- +parser.add_argument( + "--model-path-address", + default="THUDM/chatglm2-6b@localhost@20002", + nargs="+", + type=str, + help="model path, host, and port, formatted as model-path@host@port", +) +# ---------------controller------------------------- + +parser.add_argument("--controller-host", type=str, default="localhost") +parser.add_argument("--controller-port", type=int, default=21001) +parser.add_argument( + "--dispatch-method", + type=str, + choices=["lottery", "shortest_queue"], + default="shortest_queue", +) +controller_args = ["controller-host", "controller-port", "dispatch-method"] + +# ----------------------worker------------------------------------------ + +parser.add_argument("--worker-host", type=str, default="localhost") +parser.add_argument("--worker-port", type=int, default=21002) +# parser.add_argument("--worker-address", type=str, default="http://localhost:21002") +# parser.add_argument( +# "--controller-address", type=str, default="http://localhost:21001" +# ) +parser.add_argument( + "--model-path", + type=str, + default="lmsys/vicuna-7b-v1.5", + help="The path to the weights. This can be a local folder or a Hugging Face repo ID.", +) +parser.add_argument( + "--revision", + type=str, + default="main", + help="Hugging Face Hub model revision identifier", +) +parser.add_argument( + "--device", + type=str, + choices=["cpu", "cuda", "mps", "xpu", "npu"], + default="cuda", + help="The device type", +) +parser.add_argument( + "--gpus", + type=str, + default="0", + help="A single GPU like 1 or multiple GPUs like 0,2", +) +parser.add_argument("--num-gpus", type=int, default=1) +parser.add_argument( + "--max-gpu-memory", + type=str, + help="The maximum memory per gpu. Use a string like '13Gib'", +) +parser.add_argument("--load-8bit", action="store_true", help="Use 8-bit quantization") +parser.add_argument( + "--cpu-offloading", + action="store_true", + help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU", +) +parser.add_argument( + "--gptq-ckpt", + type=str, + default=None, + help="Load quantized model. The path to the local GPTQ checkpoint.", +) +parser.add_argument( + "--gptq-wbits", + type=int, + default=16, + choices=[2, 3, 4, 8, 16], + help="#bits to use for quantization", +) +parser.add_argument( + "--gptq-groupsize", + type=int, + default=-1, + help="Groupsize to use for quantization; default uses full row.", +) +parser.add_argument( + "--gptq-act-order", + action="store_true", + help="Whether to apply the activation order GPTQ heuristic", +) +parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", +) +parser.add_argument( + "--limit-worker-concurrency", + type=int, + default=5, + help="Limit the model concurrency to prevent OOM.", +) +parser.add_argument("--stream-interval", type=int, default=2) +parser.add_argument("--no-register", action="store_true") + +worker_args = [ + "worker-host", + "worker-port", + "model-path", + "revision", + "device", + "gpus", + "num-gpus", + "max-gpu-memory", + "load-8bit", + "cpu-offloading", + "gptq-ckpt", + "gptq-wbits", + "gptq-groupsize", + "gptq-act-order", + "model-names", + "limit-worker-concurrency", + "stream-interval", + "no-register", + "controller-address", +] +# -----------------openai server--------------------------- + +parser.add_argument("--server-host", type=str, default="localhost", help="host name") +parser.add_argument("--server-port", type=int, default=8001, help="port number") +parser.add_argument( + "--allow-credentials", action="store_true", help="allow credentials" +) +# parser.add_argument( +# "--allowed-origins", type=json.loads, default=["*"], help="allowed origins" +# ) +# parser.add_argument( +# "--allowed-methods", type=json.loads, default=["*"], help="allowed methods" +# ) +# parser.add_argument( +# "--allowed-headers", type=json.loads, default=["*"], help="allowed headers" +# ) +parser.add_argument( + "--api-keys", + type=lambda s: s.split(","), + help="Optional list of comma separated API keys", +) +server_args = [ + "server-host", + "server-port", + "allow-credentials", + "api-keys", + "controller-address", +] + +args = parser.parse_args() + +args = argparse.Namespace( + **vars(args), + **{"controller-address": f"http://{args.controller_host}:{args.controller_port}"}, +) + +if args.gpus: + if len(args.gpus.split(",")) < args.num_gpus: + raise ValueError( + f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" + ) + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + +# 0,controller, model_worker, openai_api_server +# 1, cmd options +# 2,LOGDIR +# 3, log file name +base_launch_sh = "nohup python3 -m fastchat.serve.{0} {1} >{2}/{3}.log 2>&1 &" + +# 0 LOGDIR +#! 1 log file name +# 2 controller, worker, openai_api_server +base_check_sh = """while [ `grep -c "Uvicorn running on" {0}/{1}.log` -eq '0' ];do + sleep 1s; + echo "wait {2} running" + done + echo '{2} running' """ + + +def string_args(args, args_list): + args_str = "" + for key, value in args._get_kwargs(): + key = key.replace("_", "-") + if key not in args_list: + continue + + key = key.split("-")[-1] if re.search("port|host", key) else key + if not value: + pass + # 1==True -> True + elif isinstance(value, bool) and value == True: + args_str += f" --{key} " + elif ( + isinstance(value, list) + or isinstance(value, tuple) + or isinstance(value, set) + ): + value = " ".join(value) + args_str += f" --{key} {value} " + else: + args_str += f" --{key} {value} " + + return args_str + + +def launch_worker(item): + log_name = ( + item.split("/")[-1] + .split("\\")[-1] + .replace("-", "_") + .replace("@", "_") + .replace(".", "_") + ) + + args.model_path, args.worker_host, args.worker_port = item.split("@") + print("*" * 80) + worker_str_args = string_args(args, worker_args) + print(worker_str_args) + worker_sh = base_launch_sh.format( + "model_worker", worker_str_args, LOGDIR, f"worker_{log_name}" + ) + worker_check_sh = base_check_sh.format(LOGDIR, f"worker_{log_name}", "model_worker") + subprocess.run(worker_sh, shell=True, check=True) + subprocess.run(worker_check_sh, shell=True, check=True) + + +def launch_all(): + controller_str_args = string_args(args, controller_args) + controller_sh = base_launch_sh.format( + "controller", controller_str_args, LOGDIR, "controller" + ) + controller_check_sh = base_check_sh.format(LOGDIR, "controller", "controller") + subprocess.run(controller_sh, shell=True, check=True) + subprocess.run(controller_check_sh, shell=True, check=True) + + if isinstance(args.model_path_address, str): + launch_worker(args.model_path_address) + else: + for idx, item in enumerate(args.model_path_address): + print(f"loading {idx}th model:{item}") + launch_worker(item) + + server_str_args = string_args(args, server_args) + server_sh = base_launch_sh.format( + "openai_api_server", server_str_args, LOGDIR, "openai_api_server" + ) + server_check_sh = base_check_sh.format( + LOGDIR, "openai_api_server", "openai_api_server" + ) + subprocess.run(server_sh, shell=True, check=True) + subprocess.run(server_check_sh, shell=True, check=True) + + +if __name__ == "__main__": + launch_all() diff --git a/third_party/FastChat/fastchat/serve/model_worker.py b/third_party/FastChat/fastchat/serve/model_worker.py new file mode 100644 index 00000000..dac6ac8b --- /dev/null +++ b/third_party/FastChat/fastchat/serve/model_worker.py @@ -0,0 +1,386 @@ +""" +A model worker that executes the model. +""" +import argparse +import base64 +import gc +import json +import os +from typing import List, Optional +import uuid + +import torch +import torch.nn.functional as F +from transformers import set_seed +import uvicorn + +from fastchat.constants import ErrorCode, SERVER_ERROR_MSG +from fastchat.model.model_adapter import ( + load_model, + add_model_args, + get_generate_stream_function, +) +from fastchat.modules.awq import AWQConfig +from fastchat.modules.exllama import ExllamaConfig +from fastchat.modules.xfastertransformer import XftConfig +from fastchat.modules.gptq import GptqConfig +from fastchat.serve.base_model_worker import BaseModelWorker, app +from fastchat.utils import ( + build_logger, + get_context_length, + str_to_torch_dtype, +) + +worker_id = str(uuid.uuid4())[:8] +logger = build_logger("model_worker", f"model_worker_{worker_id}.log") + + +class ModelWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + device: str, + num_gpus: int, + max_gpu_memory: str, + revision: str = None, + dtype: Optional[torch.dtype] = None, + load_8bit: bool = False, + cpu_offloading: bool = False, + gptq_config: Optional[GptqConfig] = None, + awq_config: Optional[AWQConfig] = None, + exllama_config: Optional[ExllamaConfig] = None, + xft_config: Optional[XftConfig] = None, + stream_interval: int = 2, + conv_template: Optional[str] = None, + embed_in_truncate: bool = False, + seed: Optional[int] = None, + debug: bool = False, + **kwargs, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template=conv_template, + ) + + logger.info(f"Loading the model {self.model_names} on worker {worker_id} ...") + self.model, self.tokenizer = load_model( + model_path, + revision=revision, + device=device, + num_gpus=num_gpus, + max_gpu_memory=max_gpu_memory, + dtype=dtype, + load_8bit=load_8bit, + cpu_offloading=cpu_offloading, + gptq_config=gptq_config, + awq_config=awq_config, + exllama_config=exllama_config, + xft_config=xft_config, + debug=debug, + ) + self.device = device + if self.tokenizer.pad_token == None: + self.tokenizer.pad_token = self.tokenizer.eos_token + self.context_len = get_context_length(self.model.config) + self.generate_stream_func = get_generate_stream_function(self.model, model_path) + self.stream_interval = stream_interval + self.embed_in_truncate = embed_in_truncate + self.seed = seed + + if not no_register: + self.init_heart_beat() + + def generate_stream_gate(self, params): + if self.device == "npu": + import torch_npu + + torch_npu.npu.set_device("npu:0") + self.call_ct += 1 + + try: + if self.seed is not None: + set_seed(self.seed) + for output in self.generate_stream_func( + self.model, + self.tokenizer, + params, + self.device, + self.context_len, + self.stream_interval, + ): + ret = { + "text": output["text"], + "error_code": 0, + } + if "usage" in output: + ret["usage"] = output["usage"] + if "finish_reason" in output: + ret["finish_reason"] = output["finish_reason"] + if "logprobs" in output: + ret["logprobs"] = output["logprobs"] + yield json.dumps(ret).encode() + b"\0" + except torch.cuda.OutOfMemoryError as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.CUDA_OUT_OF_MEMORY, + } + yield json.dumps(ret).encode() + b"\0" + except (ValueError, RuntimeError) as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + yield json.dumps(ret).encode() + b"\0" + + def generate_gate(self, params): + for x in self.generate_stream_gate(params): + pass + return json.loads(x[:-1].decode()) + + def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict): + if model_type_dict.get("is_bert"): + model_output = self.model(input_ids) + if model_type_dict.get("is_robert"): + data = model_output.last_hidden_state + else: + data = model_output[0] + elif model_type_dict.get("is_t5"): + model_output = self.model(input_ids, decoder_input_ids=input_ids) + data = model_output.encoder_last_hidden_state + else: + model_output = self.model(input_ids, output_hidden_states=True) + if model_type_dict.get("is_chatglm"): + data = model_output.hidden_states[-1].transpose(0, 1) + else: + data = model_output.hidden_states[-1] + mask = attention_mask.unsqueeze(-1).expand(data.size()).float() + masked_embeddings = data * mask + sum_embeddings = torch.sum(masked_embeddings, dim=1) + token_num = torch.sum(attention_mask).item() + + return sum_embeddings, token_num + + def __encode_base64(self, embeddings: torch.Tensor) -> List[str]: + embeddings = embeddings.cpu() + return [ + base64.b64encode(e.numpy().tobytes()).decode("utf-8") for e in embeddings + ] + + @torch.inference_mode() + def get_embeddings(self, params): + self.call_ct += 1 + + try: + tokenizer = self.tokenizer + ret = {"embedding": [], "token_num": 0} + + model_type_dict = { + "is_llama": "llama" in str(type(self.model)), + "is_t5": "t5" in str(type(self.model)), + "is_chatglm": "chatglm" in str(type(self.model)), + "is_bert": "bert" in str(type(self.model)), + "is_robert": "robert" in str(type(self.model)), + } + + if self.embed_in_truncate: + encoding = tokenizer.batch_encode_plus( + params["input"], + padding=True, + truncation="longest_first", + return_tensors="pt", + max_length=self.context_len, + ) + else: + encoding = tokenizer.batch_encode_plus( + params["input"], padding=True, return_tensors="pt" + ) + input_ids = encoding["input_ids"].to(self.device) + attention_mask = input_ids != tokenizer.pad_token_id + + base64_encode = params.get("encoding_format", None) + + if self.embed_in_truncate: + chunk_embeddings, token_num = self.__process_embed_chunk( + input_ids, attention_mask, **model_type_dict + ) + embedding = chunk_embeddings / token_num + normalized_embeddings = F.normalize(embedding, p=2, dim=1) + ret["token_num"] = token_num + else: + all_embeddings = [] + all_token_num = 0 + for i in range(0, input_ids.size(1), self.context_len): + chunk_input_ids = input_ids[:, i : i + self.context_len] + chunk_attention_mask = attention_mask[:, i : i + self.context_len] + + chunk_embeddings, token_num = self.__process_embed_chunk( + chunk_input_ids, chunk_attention_mask, **model_type_dict + ) + all_embeddings.append(chunk_embeddings) + all_token_num += token_num + + all_embeddings_tensor = torch.stack(all_embeddings) + embedding = torch.sum(all_embeddings_tensor, dim=0) / all_token_num + normalized_embeddings = F.normalize(embedding, p=2, dim=1) + + ret["token_num"] = all_token_num + + if base64_encode == "base64": + out_embeddings = self.__encode_base64(normalized_embeddings) + else: + out_embeddings = normalized_embeddings.tolist() + ret["embedding"] = out_embeddings + + gc.collect() + torch.cuda.empty_cache() + if self.device == "xpu": + torch.xpu.empty_cache() + if self.device == "npu": + torch.npu.empty_cache() + except torch.cuda.OutOfMemoryError as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.CUDA_OUT_OF_MEMORY, + } + except (ValueError, RuntimeError) as e: + ret = { + "text": f"{SERVER_ERROR_MSG}\n\n({e})", + "error_code": ErrorCode.INTERNAL_ERROR, + } + return ret + + +def create_model_worker(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + add_model_args(parser) + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", + ) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument("--embed-in-truncate", action="store_true") + parser.add_argument( + "--limit-worker-concurrency", + type=int, + default=5, + help="Limit the model concurrency to prevent OOM.", + ) + parser.add_argument("--stream-interval", type=int, default=2) + parser.add_argument("--no-register", action="store_true") + parser.add_argument( + "--seed", + type=int, + default=None, + help="Overwrite the random seed for each generation.", + ) + parser.add_argument( + "--debug", type=bool, default=False, help="Print debugging messages" + ) + parser.add_argument( + "--ssl", + action="store_true", + required=False, + default=False, + help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.", + ) + args = parser.parse_args() + logger.info(f"args: {args}") + + if args.gpus: + if len(args.gpus.split(",")) < args.num_gpus: + raise ValueError( + f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" + ) + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + + gptq_config = GptqConfig( + ckpt=args.gptq_ckpt or args.model_path, + wbits=args.gptq_wbits, + groupsize=args.gptq_groupsize, + act_order=args.gptq_act_order, + ) + awq_config = AWQConfig( + ckpt=args.awq_ckpt or args.model_path, + wbits=args.awq_wbits, + groupsize=args.awq_groupsize, + ) + if args.enable_exllama: + exllama_config = ExllamaConfig( + max_seq_len=args.exllama_max_seq_len, + gpu_split=args.exllama_gpu_split, + cache_8bit=args.exllama_cache_8bit, + ) + else: + exllama_config = None + if args.enable_xft: + xft_config = XftConfig( + max_seq_len=args.xft_max_seq_len, + data_type=args.xft_dtype, + ) + if args.device != "cpu": + print("xFasterTransformer now is only support CPUs. Reset device to CPU") + args.device = "cpu" + else: + xft_config = None + + worker = ModelWorker( + args.controller_address, + args.worker_address, + worker_id, + args.model_path, + args.model_names, + args.limit_worker_concurrency, + revision=args.revision, + no_register=args.no_register, + device=args.device, + num_gpus=args.num_gpus, + max_gpu_memory=args.max_gpu_memory, + dtype=str_to_torch_dtype(args.dtype), + load_8bit=args.load_8bit, + cpu_offloading=args.cpu_offloading, + gptq_config=gptq_config, + awq_config=awq_config, + exllama_config=exllama_config, + xft_config=xft_config, + stream_interval=args.stream_interval, + conv_template=args.conv_template, + embed_in_truncate=args.embed_in_truncate, + seed=args.seed, + debug=args.debug, + ) + return args, worker + + +if __name__ == "__main__": + args, worker = create_model_worker() + if args.ssl: + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + ssl_keyfile=os.environ["SSL_KEYFILE"], + ssl_certfile=os.environ["SSL_CERTFILE"], + ) + else: + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/third_party/FastChat/fastchat/serve/monitor/basic_stats.py b/third_party/FastChat/fastchat/serve/monitor/basic_stats.py new file mode 100644 index 00000000..3c1a8793 --- /dev/null +++ b/third_party/FastChat/fastchat/serve/monitor/basic_stats.py @@ -0,0 +1,220 @@ +import argparse +import code +import datetime +import json +import os +from pytz import timezone +import time + +import pandas as pd # pandas>=2.0.3 +import plotly.express as px +import plotly.graph_objects as go +from tqdm import tqdm + + +NUM_SERVERS = 14 +LOG_ROOT_DIR = "~/fastchat_logs" + + +def get_log_files(max_num_files=None): + log_root = os.path.expanduser(LOG_ROOT_DIR) + filenames = [] + for i in range(NUM_SERVERS): + for filename in os.listdir(f"{log_root}/server{i}"): + if filename.endswith("-conv.json"): + filepath = f"{log_root}/server{i}/{filename}" + name_tstamp_tuple = (filepath, os.path.getmtime(filepath)) + filenames.append(name_tstamp_tuple) + # sort by tstamp + filenames = sorted(filenames, key=lambda x: x[1]) + filenames = [x[0] for x in filenames] + + max_num_files = max_num_files or len(filenames) + filenames = filenames[-max_num_files:] + return filenames + + +def load_log_files(filename): + data = [] + for retry in range(5): + try: + lines = open(filename).readlines() + break + except FileNotFoundError: + time.sleep(2) + + for l in lines: + row = json.loads(l) + data.append( + dict( + type=row["type"], + tstamp=row["tstamp"], + model=row.get("model", ""), + models=row.get("models", ["", ""]), + ) + ) + return data + + +def load_log_files_parallel(log_files, num_threads=16): + data_all = [] + from multiprocessing import Pool + + with Pool(num_threads) as p: + ret_all = list(tqdm(p.imap(load_log_files, log_files), total=len(log_files))) + for ret in ret_all: + data_all.extend(ret) + return data_all + + +def get_anony_vote_df(df): + anony_vote_df = df[ + df["type"].isin(["leftvote", "rightvote", "tievote", "bothbad_vote"]) + ] + anony_vote_df = anony_vote_df[anony_vote_df["models"].apply(lambda x: x[0] == "")] + return anony_vote_df + + +def merge_counts(series, on, names): + ret = pd.merge(series[0], series[1], on=on) + for i in range(2, len(series)): + ret = pd.merge(ret, series[i], on=on) + ret = ret.reset_index() + old_names = list(ret.columns)[-len(series) :] + rename = {old_name: new_name for old_name, new_name in zip(old_names, names)} + ret = ret.rename(columns=rename) + return ret + + +def report_basic_stats(log_files): + df_all = load_log_files_parallel(log_files) + df_all = pd.DataFrame(df_all) + now_t = df_all["tstamp"].max() + df_1_hour = df_all[df_all["tstamp"] > (now_t - 3600)] + df_1_day = df_all[df_all["tstamp"] > (now_t - 3600 * 24)] + anony_vote_df_all = get_anony_vote_df(df_all) + + # Chat trends + chat_dates = [ + datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime( + "%Y-%m-%d" + ) + for x in df_all[df_all["type"] == "chat"]["tstamp"] + ] + chat_dates_counts = pd.value_counts(chat_dates) + vote_dates = [ + datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime( + "%Y-%m-%d" + ) + for x in anony_vote_df_all["tstamp"] + ] + vote_dates_counts = pd.value_counts(vote_dates) + chat_dates_bar = go.Figure( + data=[ + go.Bar( + name="Anony. Vote", + x=vote_dates_counts.index, + y=vote_dates_counts, + text=[f"{val:.0f}" for val in vote_dates_counts], + textposition="auto", + ), + go.Bar( + name="Chat", + x=chat_dates_counts.index, + y=chat_dates_counts, + text=[f"{val:.0f}" for val in chat_dates_counts], + textposition="auto", + ), + ] + ) + chat_dates_bar.update_layout( + barmode="stack", + xaxis_title="Dates", + yaxis_title="Count", + height=300, + width=1200, + ) + + # Model call counts + model_hist_all = df_all[df_all["type"] == "chat"]["model"].value_counts() + model_hist_1_day = df_1_day[df_1_day["type"] == "chat"]["model"].value_counts() + model_hist_1_hour = df_1_hour[df_1_hour["type"] == "chat"]["model"].value_counts() + model_hist = merge_counts( + [model_hist_all, model_hist_1_day, model_hist_1_hour], + on="model", + names=["All", "Last Day", "Last Hour"], + ) + model_hist_md = model_hist.to_markdown(index=False, tablefmt="github") + + # Action counts + action_hist_all = df_all["type"].value_counts() + action_hist_1_day = df_1_day["type"].value_counts() + action_hist_1_hour = df_1_hour["type"].value_counts() + action_hist = merge_counts( + [action_hist_all, action_hist_1_day, action_hist_1_hour], + on="type", + names=["All", "Last Day", "Last Hour"], + ) + action_hist_md = action_hist.to_markdown(index=False, tablefmt="github") + + # Anony vote counts + anony_vote_hist_all = anony_vote_df_all["type"].value_counts() + anony_vote_df_1_day = get_anony_vote_df(df_1_day) + anony_vote_hist_1_day = anony_vote_df_1_day["type"].value_counts() + # anony_vote_df_1_hour = get_anony_vote_df(df_1_hour) + # anony_vote_hist_1_hour = anony_vote_df_1_hour["type"].value_counts() + anony_vote_hist = merge_counts( + [anony_vote_hist_all, anony_vote_hist_1_day], + on="type", + names=["All", "Last Day"], + ) + anony_vote_hist_md = anony_vote_hist.to_markdown(index=False, tablefmt="github") + + # Last 24 hours + chat_1_day = df_1_day[df_1_day["type"] == "chat"] + num_chats_last_24_hours = [] + base = df_1_day["tstamp"].min() + for i in range(24, 0, -1): + left = base + (i - 1) * 3600 + right = base + i * 3600 + num = ((chat_1_day["tstamp"] >= left) & (chat_1_day["tstamp"] < right)).sum() + num_chats_last_24_hours.append(num) + times = [ + datetime.datetime.fromtimestamp( + base + i * 3600, tz=timezone("US/Pacific") + ).strftime("%Y-%m-%d %H:%M:%S %Z") + for i in range(24, 0, -1) + ] + last_24_hours_df = pd.DataFrame({"time": times, "value": num_chats_last_24_hours}) + last_24_hours_md = last_24_hours_df.to_markdown(index=False, tablefmt="github") + + # Last update datetime + last_updated_tstamp = now_t + last_updated_datetime = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y-%m-%d %H:%M:%S %Z") + + # code.interact(local=locals()) + + return { + "chat_dates_bar": chat_dates_bar, + "model_hist_md": model_hist_md, + "action_hist_md": action_hist_md, + "anony_vote_hist_md": anony_vote_hist_md, + "num_chats_last_24_hours": last_24_hours_md, + "last_updated_datetime": last_updated_datetime, + } + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--max-num-files", type=int) + args = parser.parse_args() + + log_files = get_log_files(args.max_num_files) + basic_stats = report_basic_stats(log_files) + + print(basic_stats["action_hist_md"] + "\n") + print(basic_stats["model_hist_md"] + "\n") + print(basic_stats["anony_vote_hist_md"] + "\n") + print(basic_stats["num_chats_last_24_hours"] + "\n") diff --git a/third_party/FastChat/fastchat/serve/monitor/clean_battle_data.py b/third_party/FastChat/fastchat/serve/monitor/clean_battle_data.py new file mode 100644 index 00000000..44c8192a --- /dev/null +++ b/third_party/FastChat/fastchat/serve/monitor/clean_battle_data.py @@ -0,0 +1,313 @@ +""" +Clean chatbot arena battle log. + +Usage: +python3 clean_battle_data.py --mode conv_release +""" +import argparse +import datetime +import json +import os +from pytz import timezone +import time + +from tqdm import tqdm + +from fastchat.serve.monitor.basic_stats import get_log_files, NUM_SERVERS +from fastchat.utils import detect_language + + +VOTES = ["tievote", "leftvote", "rightvote", "bothbad_vote"] +IDENTITY_WORDS = [ + "vicuna", + "lmsys", + "koala", + "uc berkeley", + "open assistant", + "laion", + "chatglm", + "chatgpt", + "gpt-4", + "openai", + "anthropic", + "claude", + "bard", + "palm", + "lamda", + "google", + "llama", + "qianwan", + "alibaba", + "mistral", + "zhipu", + "KEG lab", + "01.AI", + "AI2", + "Tülu", + "Tulu", + "NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.", + "$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES.", + "API REQUEST ERROR. Please increase the number of max tokens.", + "**API REQUEST ERROR** Reason: The response was blocked.", + "**API REQUEST ERROR**", +] + +for i in range(len(IDENTITY_WORDS)): + IDENTITY_WORDS[i] = IDENTITY_WORDS[i].lower() + + +def remove_html(raw): + if raw.startswith("

"): + return raw[raw.find(": ") + 2 : -len("

\n")] + return raw + + +def to_openai_format(messages): + roles = ["user", "assistant"] + ret = [] + for i, x in enumerate(messages): + ret.append({"role": roles[i % 2], "content": x[1]}) + return ret + + +def replace_model_name(old_name, tstamp): + replace_dict = { + "bard": "palm-2", + "claude-v1": "claude-1", + "claude-instant-v1": "claude-instant-1", + "oasst-sft-1-pythia-12b": "oasst-pythia-12b", + "claude-2": "claude-2.0", + } + if old_name in ["gpt-4", "gpt-3.5-turbo"]: + if tstamp > 1687849200: + return old_name + "-0613" + else: + return old_name + "-0314" + if old_name in replace_dict: + return replace_dict[old_name] + return old_name + + +def read_file(filename): + data = [] + for retry in range(5): + try: + # lines = open(filename).readlines() + for l in open(filename): + row = json.loads(l) + if row["type"] in VOTES: + data.append(row) + break + except FileNotFoundError: + time.sleep(2) + return data + + +def read_file_parallel(log_files, num_threads=16): + data_all = [] + from multiprocessing import Pool + + with Pool(num_threads) as p: + ret_all = list(tqdm(p.imap(read_file, log_files), total=len(log_files))) + for ret in ret_all: + data_all.extend(ret) + return data_all + + +def clean_battle_data( + log_files, exclude_model_names, ban_ip_list=None, sanitize_ip=False +): + data = read_file_parallel(log_files, num_threads=16) + + convert_type = { + "leftvote": "model_a", + "rightvote": "model_b", + "tievote": "tie", + "bothbad_vote": "tie (bothbad)", + } + + all_models = set() + all_ips = dict() + ct_anony = 0 + ct_invalid = 0 + ct_leaked_identity = 0 + ct_banned = 0 + battles = [] + for row in data: + if row["models"][0] is None or row["models"][1] is None: + continue + + # Resolve model names + models_public = [remove_html(row["models"][0]), remove_html(row["models"][1])] + if "model_name" in row["states"][0]: + models_hidden = [ + row["states"][0]["model_name"], + row["states"][1]["model_name"], + ] + if models_hidden[0] is None: + models_hidden = models_public + else: + models_hidden = models_public + + if (models_public[0] == "" and models_public[1] != "") or ( + models_public[1] == "" and models_public[0] != "" + ): + ct_invalid += 1 + continue + + if models_public[0] == "" or models_public[0] == "Model A": + anony = True + models = models_hidden + ct_anony += 1 + else: + anony = False + models = models_public + if not models_public == models_hidden: + ct_invalid += 1 + continue + + # Detect langauge + state = row["states"][0] + if state["offset"] >= len(state["messages"]): + ct_invalid += 1 + continue + lang_code = detect_language(state["messages"][state["offset"]][1]) + + # Drop conversations if the model names are leaked + leaked_identity = False + messages = "" + for i in range(2): + state = row["states"][i] + for turn_idx, (role, msg) in enumerate( + state["messages"][state["offset"] :] + ): + if msg: + messages += msg.lower() + for word in IDENTITY_WORDS: + if word in messages: + leaked_identity = True + break + + if leaked_identity: + ct_leaked_identity += 1 + continue + + # Replace bard with palm + models = [replace_model_name(m, row["tstamp"]) for m in models] + # Exclude certain models + if exclude_model_names and any(x in exclude_model_names for x in models): + ct_invalid += 1 + continue + + question_id = row["states"][0]["conv_id"] + conversation_a = to_openai_format( + row["states"][0]["messages"][row["states"][0]["offset"] :] + ) + conversation_b = to_openai_format( + row["states"][1]["messages"][row["states"][1]["offset"] :] + ) + + ip = row["ip"] + if ip not in all_ips: + all_ips[ip] = {"ip": ip, "count": 0, "sanitized_id": len(all_ips)} + all_ips[ip]["count"] += 1 + if sanitize_ip: + user_id = f"arena_user_{all_ips[ip]['sanitized_id']}" + else: + user_id = f"{all_ips[ip]['ip']}" + + if ban_ip_list is not None and ip in ban_ip_list: + ct_banned += 1 + continue + + # Save the results + battles.append( + dict( + question_id=question_id, + model_a=models[0], + model_b=models[1], + winner=convert_type[row["type"]], + judge=f"arena_user_{user_id}", + conversation_a=conversation_a, + conversation_b=conversation_b, + turn=len(conversation_a) // 2, + anony=anony, + language=lang_code, + tstamp=row["tstamp"], + ) + ) + + all_models.update(models_hidden) + battles.sort(key=lambda x: x["tstamp"]) + last_updated_tstamp = battles[-1]["tstamp"] + + last_updated_datetime = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y-%m-%d %H:%M:%S %Z") + + print( + f"#votes: {len(data)}, #invalid votes: {ct_invalid}, " + f"#leaked_identity: {ct_leaked_identity} " + f"#banned: {ct_banned} " + ) + print(f"#battles: {len(battles)}, #anony: {ct_anony}") + print(f"#models: {len(all_models)}, {all_models}") + print(f"last-updated: {last_updated_datetime}") + + if ban_ip_list is not None: + for ban_ip in ban_ip_list: + if ban_ip in all_ips: + del all_ips[ban_ip] + print("Top 30 IPs:") + print(sorted(all_ips.values(), key=lambda x: x["count"], reverse=True)[:30]) + return battles + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--max-num-files", type=int) + parser.add_argument( + "--mode", type=str, choices=["simple", "conv_release"], default="simple" + ) + parser.add_argument("--exclude-model-names", type=str, nargs="+") + parser.add_argument("--ban-ip-file", type=str) + parser.add_argument("--sanitize-ip", action="store_true", default=False) + args = parser.parse_args() + + log_files = get_log_files(args.max_num_files) + ban_ip_list = json.load(open(args.ban_ip_file)) if args.ban_ip_file else None + + battles = clean_battle_data( + log_files, args.exclude_model_names or [], ban_ip_list, args.sanitize_ip + ) + last_updated_tstamp = battles[-1]["tstamp"] + cutoff_date = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y%m%d") + + if args.mode == "simple": + for x in battles: + for key in [ + "conversation_a", + "conversation_b", + "question_id", + ]: + del x[key] + print("Samples:") + for i in range(4): + print(battles[i]) + output = f"clean_battle_{cutoff_date}.json" + elif args.mode == "conv_release": + new_battles = [] + for x in battles: + if not x["anony"]: + continue + for key in []: + del x[key] + new_battles.append(x) + battles = new_battles + output = f"clean_battle_conv_{cutoff_date}.json" + + with open(output, "w") as fout: + json.dump(battles, fout, indent=2, ensure_ascii=False) + print(f"Write cleaned data to {output}") diff --git a/third_party/FastChat/fastchat/serve/monitor/clean_chat_data.py b/third_party/FastChat/fastchat/serve/monitor/clean_chat_data.py new file mode 100644 index 00000000..7f0c9bd4 --- /dev/null +++ b/third_party/FastChat/fastchat/serve/monitor/clean_chat_data.py @@ -0,0 +1,171 @@ +""" +Clean chatbot arena chat log. + +Usage: +python3 clean_chat_data.py --mode conv_release +""" +import argparse +import datetime +import json +import os +from pytz import timezone +import time + +from tqdm import tqdm + +from fastchat.serve.monitor.basic_stats import NUM_SERVERS +from fastchat.serve.monitor.clean_battle_data import ( + to_openai_format, + replace_model_name, +) +from fastchat.utils import detect_language + + +NETWORK_ERROR_MSG = ( + "NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.".lower() +) + + +def get_log_files(max_num_files=None): + dates = [] + for month in range(4, 12): + for day in range(1, 33): + dates.append(f"2023-{month:02d}-{day:02d}") + + filenames = [] + for d in dates: + for i in range(NUM_SERVERS): + name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json") + if os.path.exists(name): + filenames.append(name) + max_num_files = max_num_files or len(filenames) + # filenames = list(reversed(filenames)) + filenames = filenames[-max_num_files:] + return filenames + + +def clean_chat_data(log_files, action_type): + raw_data = [] + for filename in tqdm(log_files, desc="read files"): + for retry in range(5): + try: + lines = open(filename).readlines() + break + except FileNotFoundError: + time.sleep(2) + + for l in lines: + row = json.loads(l) + if row["type"] == action_type: + raw_data.append(row) + + all_models = set() + all_ips = dict() + chats = [] + ct_invalid_conv_id = 0 + ct_invalid = 0 + ct_network_error = 0 + for row in raw_data: + try: + if action_type in ["chat", "upvote", "downvote"]: + state = row["state"] + model = row["model"] + elif action_type == "leftvote": + state = row["states"][0] + model = row["states"][0]["model_name"] + elif action_type == "rightvote": + state = row["states"][1] + model = row["states"][1]["model_name"] + conversation_id = state["conv_id"] + except KeyError: + ct_invalid_conv_id += 1 + continue + + if conversation_id is None: + ct_invalid_conv_id += 1 + continue + + conversation = to_openai_format(state["messages"][state["offset"] :]) + if not isinstance(model, str): + ct_invalid += 1 + continue + model = replace_model_name(model) + + try: + lang_code = detect_language(state["messages"][state["offset"]][1]) + except IndexError: + ct_invalid += 1 + continue + + if not all(isinstance(x["content"], str) for x in conversation): + ct_invalid += 1 + continue + + messages = "".join([x["content"] for x in conversation]).lower() + if NETWORK_ERROR_MSG in messages: + ct_network_error += 1 + continue + + ip = row["ip"] + if ip not in all_ips: + all_ips[ip] = len(all_ips) + user_id = all_ips[ip] + + chats.append( + dict( + conversation_id=conversation_id, + model=model, + conversation=conversation, + turn=len(conversation) // 2, + language=lang_code, + user_id=user_id, + tstamp=row["tstamp"], + ) + ) + + all_models.update([model]) + + chats.sort(key=lambda x: x["tstamp"]) + last_updated_tstamp = chats[-1]["tstamp"] + last_updated_datetime = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y-%m-%d %H:%M:%S %Z") + + # Deduplication + dedup_chats = [] + visited_conv_ids = set() + for i in reversed(range(len(chats))): + if chats[i]["conversation_id"] in visited_conv_ids: + continue + visited_conv_ids.add(chats[i]["conversation_id"]) + dedup_chats.append(chats[i]) + + print( + f"#raw: {len(raw_data)}, #chat: {len(chats)}, #dedup_chat: {len(dedup_chats)}" + ) + print( + f"#invalid_conv_id: {ct_invalid_conv_id}, #network_error: {ct_network_error}, #invalid: {ct_invalid}" + ) + print(f"#models: {len(all_models)}, {all_models}") + print(f"last-updated: {last_updated_datetime}") + + return list(reversed(dedup_chats)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--action-type", type=str, default="chat") + parser.add_argument("--max-num-files", type=int) + args = parser.parse_args() + + log_files = get_log_files(args.max_num_files) + chats = clean_chat_data(log_files, args.action_type) + last_updated_tstamp = chats[-1]["tstamp"] + cutoff_date = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y%m%d") + + output = f"clean_{args.action_type}_conv_{cutoff_date}.json" + with open(output, "w") as fout: + json.dump(chats, fout, indent=2, ensure_ascii=False) + print(f"Write cleaned data to {output}") diff --git a/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/arena_33k/count_unique_users.py b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/arena_33k/count_unique_users.py new file mode 100644 index 00000000..8e94cf27 --- /dev/null +++ b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/arena_33k/count_unique_users.py @@ -0,0 +1,25 @@ +"""Count the unique users in a battle log file.""" + +import argparse +import json + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str) + args = parser.parse_args() + + lines = json.load(open(args.input)) + ct_anony_votes = 0 + all_users = set() + all_models = set() + for l in lines: + if not l["anony"]: + continue + all_users.add(l["judge"]) + all_models.add(l["model_a"]) + all_models.add(l["model_b"]) + ct_anony_votes += 1 + + print(f"#anony_vote: {ct_anony_votes}, #user: {len(all_users)}") + print(f"#model: {len(all_models)}") diff --git a/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/arena_33k/filter_bad_conv.py b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/arena_33k/filter_bad_conv.py new file mode 100644 index 00000000..6d12d7c6 --- /dev/null +++ b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/arena_33k/filter_bad_conv.py @@ -0,0 +1,155 @@ +""" +Filter conversations for release. + +Usage: python3 filter_bad_conv.py --in clean_battle_conv_20230630_tagged_v1_pii.json +""" +import argparse +from collections import defaultdict +from enum import Enum, auto +import json +import os +import random + +from tqdm import tqdm + +BLOCKED_WORDS_FILENAME = "blocked_words.json" +blocked_words = [] +frequency = defaultdict(lambda: 0) + + +class TypeCode(Enum): + CORRECT = auto() + ANONYMIZED = auto() + REDACTED = auto() + BAD_FORMAT = auto() + BLOCKED_WORD = auto() + BLOCKED_MODEL = auto() + TOO_SHORT = auto() + TOO_FREQUENT = auto() + + +def detect_type(conv): + for key in ["conversation_a", "conversation_b"]: + messages = [row["content"] for row in conv[key]] + for msg in messages: + if not isinstance(msg, str): + return TypeCode.BAD_FORMAT + + user_prompts = [ + row["content"].lower().strip() for row in conv[key] if row["role"] == "user" + ] + if len(messages) <= 2 and all(len(x) < 16 for x in user_prompts): + return TypeCode.TOO_SHORT + + if all(x in frequent_prompts for x in user_prompts): + return TypeCode.TOO_FREQUENT + + for msg in messages: + msg = msg.lower() + if "" in msg: + return TypeCode.ANONYMIZED + if "" in msg: + return TypeCode.REDACTED + + for w in blocked_words: + if w in msg: + return TypeCode.BLOCKED_WORD + + for key in ["model_a", "model_b"]: + if conv[key] in ["vicuna-33b", "mpt-30b-chat"]: + return TypeCode.BLOCKED_MODEL + + return TypeCode.CORRECT + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--sample", type=int) + args = parser.parse_args() + + # Read conversations + convs = json.load(open(args.in_file)) + print(f"#conv: {len(convs)}") + + # Read blocked words + if os.path.exists(BLOCKED_WORDS_FILENAME): + blocked_words = json.load(open(BLOCKED_WORDS_FILENAME)) + + # Count frequency + for conv in convs: + for key in ["conversation_a", "conversation_b"]: + messages = [row["content"] for row in conv[key] if row["role"] == "user"] + for msg in messages: + if not isinstance(msg, str): + continue + msg = msg.lower().strip() + frequency[msg] += 1 + + keys = list(frequency.keys()) + keys.sort(key=lambda x: -frequency[x]) + frequent_prompts = keys[:10] + frequent_prompts = set(frequent_prompts) + frequent_prompts.add("") + + # Start filter + ct_bad_format = 0 + ct_anonymized = 0 + ct_redacted = 0 + ct_error = 0 + ct_lang_filter = 0 + ct_flagged = 0 + ct_blocked_word = 0 + ct_blocked_model = 0 + ct_too_short = 0 + ct_too_frequent = 0 + + new_convs = [] + for conv in tqdm(convs): + type_code = detect_type(conv) + + if type_code == TypeCode.BAD_FORMAT: + ct_bad_format += 1 + continue + + if type_code == TypeCode.ANONYMIZED: + ct_anonymized += 1 + continue + elif type_code == TypeCode.REDACTED: + ct_redacted += 1 + continue + elif type_code == TypeCode.BLOCKED_WORD: + ct_blocked_word += 1 + continue + elif type_code == TypeCode.BLOCKED_MODEL: + ct_blocked_model += 1 + continue + elif type_code == TypeCode.TOO_SHORT: + ct_too_short += 1 + continue + elif type_code == TypeCode.TOO_FREQUENT: + ct_too_frequent += 1 + continue + + if conv["openai_moderation"]["flagged"]: + ct_flagged += 1 + continue + + if type_code in [TypeCode.CORRECT]: + new_convs.append(conv) + + if args.sample: + # random.seed(0) + # random.shuffle(new_convs) + new_convs = new_convs[: args.sample] + + print(f"ct_anonymized: {ct_anonymized}, ct_redacted: {ct_redacted}") + print(f"ct_bad_format: {ct_bad_format}, ct_flagged: {ct_flagged}") + print(f"ct_blocked_word: {ct_blocked_word}, ct_blocked_model: {ct_blocked_model}") + print(f"ct_too_short: {ct_too_short}, ct_too_frequent: {ct_anonymized}") + print(f"new_conv: {len(new_convs)}") + + out_file = args.in_file.replace(".json", ".out.json") + print(f"Output to {out_file}") + with open(out_file, "w") as fout: + json.dump(new_convs, fout, indent=2, ensure_ascii=False) diff --git a/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/arena_33k/merge_field.py b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/arena_33k/merge_field.py new file mode 100644 index 00000000..5a88209b --- /dev/null +++ b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/arena_33k/merge_field.py @@ -0,0 +1,25 @@ +"""Count the unique users in a battle log file.""" + +import argparse +import json + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str) + parser.add_argument("--tag-file", type=str) + args = parser.parse_args() + + # build index + objs = json.load(open(args.tag_file)) + new_field_dict = {} + for obj in objs: + new_field_dict[obj["question_id"]] = obj["toxic_chat"] + + objs = json.load(open(args.input)) + for obj in objs: + obj["toxic_chat_tag"] = new_field_dict[obj["question_id"]] + + output = args.input.replace(".json", "_added.json") + with open(output, "w") as fout: + json.dump(objs, fout, indent=2, ensure_ascii=False) diff --git a/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/arena_33k/sample.py b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/arena_33k/sample.py new file mode 100644 index 00000000..0cd78b71 --- /dev/null +++ b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/arena_33k/sample.py @@ -0,0 +1,32 @@ +""" +Count the unique users in a battle log file. + +Usage: +python3 -input in.json --number 1000 +""" + +import argparse +import json +import random + +K = 1000 + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str) + parser.add_argument("--number", type=int, nargs="+") + args = parser.parse_args() + + convs = json.load(open(args.input)) + random.seed(0) + random.shuffle(convs) + + for number in args.number: + new_convs = convs[:number] + + output = args.input.replace(".json", f"_{number//K}k.json") + with open(output, "w") as fout: + json.dump(new_convs, fout, indent=2, ensure_ascii=False) + + print(f"#in: {len(convs)}, #out: {len(new_convs)}") + print(f"Write to file: {output}") diff --git a/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/arena_33k/upload_hf_dataset.py b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/arena_33k/upload_hf_dataset.py new file mode 100644 index 00000000..e37aadce --- /dev/null +++ b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/arena_33k/upload_hf_dataset.py @@ -0,0 +1,9 @@ +""" +Upload to huggingface. +""" +import json +from datasets import Dataset, DatasetDict, load_dataset + +objs = json.load(open("clean_battle_conv_20230630_tagged_v3_pii_33k_added.json")) +data = Dataset.from_list(objs) +data.push_to_hub("lmsys/chatbot_arena_conversations", private=True) diff --git a/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/approve_all.py b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/approve_all.py new file mode 100644 index 00000000..a7084207 --- /dev/null +++ b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/approve_all.py @@ -0,0 +1,13 @@ +import requests + +headers = {"authorization": "Bearer hf_XXX"} + +url = "https://huggingface.co/api/datasets/lmsys/lmsys-chat-1m/user-access-request/pending" +a = requests.get(url, headers=headers) + +for u in a.json(): + user = u["user"]["user"] + url = "https://huggingface.co/api/datasets/lmsys/lmsys-chat-1m/user-access-request/grant" + ret = requests.post(url, headers=headers, json={"user": user}) + print(user, ret.status_code) + assert ret.status_code == 200 diff --git a/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/compute_stats.py b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/compute_stats.py new file mode 100644 index 00000000..97abaaa0 --- /dev/null +++ b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/compute_stats.py @@ -0,0 +1,119 @@ +""" +From colab: +https://colab.research.google.com/drive/1oMdw_Lqgmd6DletSOLHsyD-Rc96cRShs?usp=sharing +""" +import argparse +import datetime +import json +import os +from pytz import timezone +import time + +import kaleido +import numpy as np +import pandas as pd +import plotly.express as px +import plotly.graph_objects as go +from tqdm import tqdm + +import plotly.io as pio + +pio.kaleido.scope.mathjax = None + +parser = argparse.ArgumentParser() +parser.add_argument("--in-file", type=str, required=True) +parser.add_argument("--scale", type=int, required=True) +args = parser.parse_args() + +filename = args.in_file +scale = args.scale +convs = json.load(open(filename)) +df = pd.DataFrame(convs) +df + +print(f"#ips: {df['user_id'].nunique() * scale}") +print(f"#models: {df['model'].nunique()}") +print(f"#language: {df['language'].nunique()}") +print(f"#turns: {df['turn'].mean()}") + +model_counts = df["model"].value_counts() * scale +# print("model counts", model_counts) +fig = px.bar(x=model_counts.index, y=model_counts) +fig.update_layout( + xaxis_title=None, + yaxis_title="Count", + height=200, + width=950, + margin=dict(l=0, r=0, t=0, b=0), +) +fig.show() +fig.write_image("model_count.pdf") + + +model_counts = df["language"].value_counts().head(25) * scale +fig = px.bar(x=model_counts.index, y=model_counts) +fig.update_layout( + xaxis_title=None, + yaxis_title="Count", + height=200, + width=950, + margin=dict(l=0, r=0, t=0, b=0), +) +fig.show() +fig.write_image("language_count.pdf") + +chat_dates = [ + datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime("%Y-%m-%d") + for x in df["tstamp"] +] + + +def to_remove(x): + for d in ["08-09", "08-08", "08-07", "08-06", "08-05", "08-04"]: + if d in x: + return True + return False + + +chat_dates = [x for x in chat_dates if not to_remove(x)] + +chat_dates_counts = pd.value_counts(chat_dates) * scale +print(f"mean #chat per day: {np.mean(chat_dates_counts):.2f}") + +fig = px.bar(x=chat_dates_counts.index, y=chat_dates_counts) +fig.update_layout( + xaxis_title="Dates", + yaxis_title="Count", + height=200, + width=950, + margin=dict(l=0, r=0, t=0, b=0), +) +fig.show() +fig.write_image("daily_conversation_count.pdf") + +import transformers + +tokenizer = transformers.AutoTokenizer.from_pretrained( + "lmsys/vicuna-7b-v1.5", use_fast=False +) + +prompts = [] +responses = [] +for conv in df["conversation"]: + for row in conv: + if row["role"] == "user": + prompts.append(row["content"]) + else: + responses.append(row["content"]) + +print(f"#prompts: {len(prompts)}") +print(f"#responses: {len(responses)}") + + +prompt_lens = [len(tokenizer(x).input_ids) for x in tqdm(prompts)] +print() +print(f"mean prompt len: {np.mean(prompt_lens):.2f}") + +response_lens = [len(tokenizer(x).input_ids) if x else 0 for x in tqdm(responses)] +print() +print(f"mean response len: {np.mean(response_lens):.2f}") diff --git a/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/filter_bad_conv.py b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/filter_bad_conv.py new file mode 100644 index 00000000..3ccde1ca --- /dev/null +++ b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/filter_bad_conv.py @@ -0,0 +1,148 @@ +""" +Filter conversations for release. + +Dependency: +pip install opencc-python-reimplementedpip install opencc-python-reimplemented + +Usage: +python3 filter_bad_conv_lmsys_chat_1m.py --in clean_battle_conv_20230630_tagged_v1_pii.json +""" +import argparse +from concurrent.futures import ProcessPoolExecutor +from collections import defaultdict +from enum import Enum, auto +import json +import os +import random + +from tqdm import tqdm +import opencc + +BLOCKED_WORDS_FILENAME = "blocked_words.json" +blocked_words = [] +frequency = defaultdict(lambda: 0) + +cc_converter = opencc.OpenCC("t2s") + + +class TypeCode(Enum): + CORRECT = auto() + ANONYMIZED = auto() + REDACTED = auto() + BAD_FORMAT = auto() + BLOCKED_WORD = auto() + BLOCKED_MODEL = auto() + TOO_SHORT = auto() + TOO_FREQUENT = auto() + + +def detect_type(conv): + for key in ["conversation_a", "conversation_b", "conversation"]: + if key not in conv: + continue + + messages = [row["content"] for row in conv[key]] + for msg in messages: + if not isinstance(msg, str): + return TypeCode.BAD_FORMAT + + if len(messages) == 0: + return TypeCode.BAD_FORMAT + + user_prompts = [ + row["content"].lower().strip() for row in conv[key] if row["role"] == "user" + ] + + for msg in messages: + msg = cc_converter.convert(msg.lower()) + if "" in msg: + return TypeCode.ANONYMIZED + if "" in msg: + return TypeCode.REDACTED + + for w in blocked_words: + if w in msg: + return TypeCode.BLOCKED_WORD + + return TypeCode.CORRECT + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--sample", type=int) + args = parser.parse_args() + + # Read conversations + convs = json.load(open(args.in_file)) + print(f"#conv: {len(convs)}") + + # Read blocked words + if os.path.exists(BLOCKED_WORDS_FILENAME): + blocked_words = json.load(open(BLOCKED_WORDS_FILENAME)) + blocked_words = [cc_converter.convert(w) for w in blocked_words] + + # Start filter + ct_bad_format = 0 + ct_anonymized = 0 + ct_redacted = 0 + ct_error = 0 + ct_lang_filter = 0 + ct_flagged = 0 + ct_blocked_word = 0 + ct_blocked_model = 0 + ct_too_short = 0 + ct_too_frequent = 0 + + type_codes = [] + with ProcessPoolExecutor() as executor: + for result in tqdm(executor.map(detect_type, convs), total=len(convs)): + type_codes.append(result) + + new_convs = [] + for conv, type_code in zip(convs, type_codes): + if type_code == TypeCode.BAD_FORMAT: + ct_bad_format += 1 + continue + + if type_code == TypeCode.ANONYMIZED: + ct_anonymized += 1 + continue + elif type_code == TypeCode.REDACTED: + ct_redacted += 1 + continue + elif type_code == TypeCode.BLOCKED_WORD: + ct_blocked_word += 1 + continue + elif type_code == TypeCode.BLOCKED_MODEL: + ct_blocked_model += 1 + continue + elif type_code == TypeCode.TOO_SHORT: + ct_too_short += 1 + continue + elif type_code == TypeCode.TOO_FREQUENT: + ct_too_frequent += 1 + continue + + if "openai_moderation" in conv and conv["openai_moderation"]["flagged"]: + ct_flagged += 1 + continue + + if type_code in [TypeCode.CORRECT]: + new_convs.append(conv) + + if args.sample: + random.seed(42) + random.shuffle(new_convs) + new_convs = new_convs[: args.sample] + + print(f"ct_anonymized: {ct_anonymized}, ct_redacted: {ct_redacted}") + print(f"ct_bad_format: {ct_bad_format}, ct_flagged: {ct_flagged}") + print(f"ct_blocked_word: {ct_blocked_word}, ct_blocked_model: {ct_blocked_model}") + print(f"ct_too_short: {ct_too_short}, ct_too_frequent: {ct_too_frequent}") + print(f"new_conv: {len(new_convs)}") + + out_file = args.in_file.replace(".json", ".s1.json") + print(f"Output to {out_file}") + with open(out_file, "w") as fout: + json.dump(new_convs, fout, indent=2, ensure_ascii=False) diff --git a/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/final_post_processing.py b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/final_post_processing.py new file mode 100644 index 00000000..e368e92a --- /dev/null +++ b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/final_post_processing.py @@ -0,0 +1,27 @@ +import argparse +import json + +from tqdm import tqdm +import numpy as np + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + args = parser.parse_args() + + # Read conversations + convs = json.load(open(args.in_file)) + print(f"#conv: {len(convs)}") + + # Delete some fileds + for c in convs: + del c["tstamp"] + del c["user_id"] + + # Write + print(f"#out conv: {len(convs)}") + out_file = args.in_file.replace(".json", ".s2.json") + print(f"Output to {out_file}") + with open(out_file, "w") as fout: + json.dump(convs, fout, indent=2, ensure_ascii=False) diff --git a/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/instructions.md b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/instructions.md new file mode 100644 index 00000000..4c439731 --- /dev/null +++ b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/instructions.md @@ -0,0 +1,23 @@ +``` +export BASE=clean_conv_20230809_100k_pii +export SCALE=10 + +# filter words +python3 filter_bad_conv.py --in $BASE.json + +# Clean up some fileds (e.g., timestamps) +python3 final_post_processing.py --in $BASE.s1.json + +# upload to hf +python3 upload_hf_dataset.py --in $BASE.s1.s2.json + +# Make another version with openai moderation tag +python3 merge_oai_tag.py --in $BASE.s1.s2.json + +# Make visualizations +python3 compute_stats.py --in $BASE.s1.json --scale $SCALE + +# Copy figures +scp "atlas:/data/lmzheng/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/*.pdf" . +``` + diff --git a/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/merge_oai_tag.py b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/merge_oai_tag.py new file mode 100644 index 00000000..18bef5f1 --- /dev/null +++ b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/merge_oai_tag.py @@ -0,0 +1,45 @@ +import argparse +import json +import time + +from tqdm import tqdm + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + parser.add_argument("--sample", type=int) + args = parser.parse_args() + + tag_file = "clean_conv_20230809_1.5M_oai_filter_v2.json" + # tag_file = "clean_conv_20230809_1.5M_oai_filter_v2_100k.json" + in_file = args.in_file + tic = time.time() + + # Load tags + print("Load tags...") + tag_data = json.load(open(tag_file)) + tag_dict = {} + for c in tqdm(tag_data): + tag_dict[c["conversation_id"]] = [x["oai_filter"] for x in c["conversation"]] + print(f"elapsed: {time.time() - tic:.2f} s") + + # Append to input_file + print("Load inputs...") + input_data = json.load(open(in_file)) + for c in tqdm(input_data): + cid = c["conversation_id"] + if cid in tag_dict: + c["openai_moderation"] = tag_dict[cid] + else: + print(f"missing tag for conv {cid}") + exit() + print(f"elapsed: {time.time() - tic:.2f} s") + + # Write output + print("Write outputs...") + out_file = in_file.replace(".json", ".with_tag.json") + print(f"Output to {out_file}") + with open(out_file, "w") as fout: + json.dump(input_data, fout, indent=2, ensure_ascii=False) + print(f"elapsed: {time.time() - tic:.2f} s") diff --git a/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/process_all.sh b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/process_all.sh new file mode 100644 index 00000000..5bae9fba --- /dev/null +++ b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/process_all.sh @@ -0,0 +1,18 @@ +export BASE=clean_conv_20230809_1.5M_pii +#export BASE=clean_conv_20230809_100k_pii +export SCALE=1 + +# Filter words +python3 filter_bad_conv.py --in $BASE.json --sample 1000000 + +# Clean up some fileds (e.g., timestamps) +python3 final_post_processing.py --in $BASE.s1.json + +# Upload to hf +python3 upload_hf_dataset.py --in $BASE.s1.s2.json + +# Make another version with openai moderation tag +python3 merge_oai_tag.py --in $BASE.s1.s2.json + +# Make visualizations +python3 compute_stats.py --in $BASE.s1.json --scale $SCALE diff --git a/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/sample.py b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/sample.py new file mode 100644 index 00000000..3b6da455 --- /dev/null +++ b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/sample.py @@ -0,0 +1,32 @@ +""" +Count the unique users in a battle log file. + +Usage: +python3 -input in.json --number 1000 +""" + +import argparse +import json +import random + +K = 1000 + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str) + parser.add_argument("--number", type=int, nargs="+") + args = parser.parse_args() + + convs = json.load(open(args.input)) + random.seed(42) + random.shuffle(convs) + + for number in args.number: + new_convs = convs[:number] + + output = args.input.replace(".json", f"_{number//K}k.json") + with open(output, "w") as fout: + json.dump(new_convs, fout, indent=2, ensure_ascii=False) + + print(f"#in: {len(convs)}, #out: {len(new_convs)}") + print(f"Write to file: {output}") diff --git a/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/upload_hf_dataset.py b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/upload_hf_dataset.py new file mode 100644 index 00000000..41d0fbdb --- /dev/null +++ b/third_party/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/upload_hf_dataset.py @@ -0,0 +1,17 @@ +""" +Upload to huggingface. +""" +import argparse +import json +from datasets import Dataset, DatasetDict, load_dataset + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-file", type=str, required=True) + args = parser.parse_args() + + objs = json.load(open(args.in_file)) + print(f"#convs: {len(objs)}") + data = Dataset.from_list(objs) + data.push_to_hub("lmsys/lmsys-chat-1m", private=True) diff --git a/third_party/FastChat/fastchat/serve/monitor/elo_analysis.py b/third_party/FastChat/fastchat/serve/monitor/elo_analysis.py new file mode 100644 index 00000000..9b8d8c3f --- /dev/null +++ b/third_party/FastChat/fastchat/serve/monitor/elo_analysis.py @@ -0,0 +1,369 @@ +import argparse +from collections import defaultdict +import datetime +import json +import math +import pickle +from pytz import timezone + +import numpy as np +import pandas as pd +import plotly.express as px +from tqdm import tqdm + +from fastchat.model.model_registry import get_model_info +from fastchat.serve.monitor.basic_stats import get_log_files +from fastchat.serve.monitor.clean_battle_data import clean_battle_data + + +pd.options.display.float_format = "{:.2f}".format + + +def compute_elo(battles, K=4, SCALE=400, BASE=10, INIT_RATING=1000): + rating = defaultdict(lambda: INIT_RATING) + + for rd, model_a, model_b, winner in battles[ + ["model_a", "model_b", "winner"] + ].itertuples(): + ra = rating[model_a] + rb = rating[model_b] + ea = 1 / (1 + BASE ** ((rb - ra) / SCALE)) + eb = 1 / (1 + BASE ** ((ra - rb) / SCALE)) + if winner == "model_a": + sa = 1 + elif winner == "model_b": + sa = 0 + elif winner == "tie" or winner == "tie (bothbad)": + sa = 0.5 + else: + raise Exception(f"unexpected vote {winner}") + rating[model_a] += K * (sa - ea) + rating[model_b] += K * (1 - sa - eb) + + return dict(rating) + + +def get_bootstrap_result(battles, func_compute_elo, num_round=1000): + rows = [] + for i in tqdm(range(num_round), desc="bootstrap"): + tmp_battles = battles.sample(frac=1.0, replace=True) + rows.append(func_compute_elo(tmp_battles)) + df = pd.DataFrame(rows) + return df[df.median().sort_values(ascending=False).index] + + +def compute_elo_mle_with_tie(df, SCALE=400, BASE=10, INIT_RATING=1000): + from sklearn.linear_model import LogisticRegression + + models = pd.concat([df["model_a"], df["model_b"]]).unique() + models = pd.Series(np.arange(len(models)), index=models) + + # duplicate battles + df = pd.concat([df, df], ignore_index=True) + p = len(models.index) + n = df.shape[0] + + X = np.zeros([n, p]) + X[np.arange(n), models[df["model_a"]]] = +math.log(BASE) + X[np.arange(n), models[df["model_b"]]] = -math.log(BASE) + + # one A win => two A win + Y = np.zeros(n) + Y[df["winner"] == "model_a"] = 1.0 + + # one tie => one A win + one B win + # find tie + tie (both bad) index + tie_idx = (df["winner"] == "tie") | (df["winner"] == "tie (bothbad)") + tie_idx[len(tie_idx) // 2 :] = False + Y[tie_idx] = 1.0 + + lr = LogisticRegression(fit_intercept=False) + lr.fit(X, Y) + + elo_scores = SCALE * lr.coef_[0] + INIT_RATING + # calibrate llama-13b to 800 if applicable + if "llama-13b" in models.index: + elo_scores += 800 - elo_scores[models["llama-13b"]] + return pd.Series(elo_scores, index=models.index).sort_values(ascending=False) + + +def get_median_elo_from_bootstrap(bootstrap_df): + median = dict(bootstrap_df.quantile(0.5)) + median = {k: int(v + 0.5) for k, v in median.items()} + return median + + +def compute_pairwise_win_fraction(battles, model_order, limit_show_number=None): + # Times each model wins as Model A + a_win_ptbl = pd.pivot_table( + battles[battles["winner"] == "model_a"], + index="model_a", + columns="model_b", + aggfunc="size", + fill_value=0, + ) + + # Table counting times each model wins as Model B + b_win_ptbl = pd.pivot_table( + battles[battles["winner"] == "model_b"], + index="model_a", + columns="model_b", + aggfunc="size", + fill_value=0, + ) + + # Table counting number of A-B pairs + num_battles_ptbl = pd.pivot_table( + battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0 + ) + + # Computing the proportion of wins for each model as A and as B + # against all other models + row_beats_col_freq = (a_win_ptbl + b_win_ptbl.T) / ( + num_battles_ptbl + num_battles_ptbl.T + ) + + if model_order is None: + prop_wins = row_beats_col_freq.mean(axis=1).sort_values(ascending=False) + model_order = list(prop_wins.keys()) + + if limit_show_number is not None: + model_order = model_order[:limit_show_number] + + # Arrange ordering according to proprition of wins + row_beats_col = row_beats_col_freq.loc[model_order, model_order] + return row_beats_col + + +def visualize_leaderboard_table(rating): + models = list(rating.keys()) + models.sort(key=lambda k: -rating[k]) + + emoji_dict = { + 1: "🥇", + 2: "🥈", + 3: "🥉", + } + + md = "" + md += "| Rank | Model | Elo Rating | Description |\n" + md += "| --- | --- | --- | --- |\n" + for i, model in enumerate(models): + rank = i + 1 + minfo = get_model_info(model) + emoji = emoji_dict.get(rank, "") + md += f"| {rank} | {emoji} [{model}]({minfo.link}) | {rating[model]:.0f} | {minfo.description} |\n" + + return md + + +def visualize_pairwise_win_fraction(battles, model_order): + row_beats_col = compute_pairwise_win_fraction(battles, model_order) + fig = px.imshow( + row_beats_col, + color_continuous_scale="RdBu", + text_auto=".2f", + height=700, + width=700, + ) + fig.update_layout( + xaxis_title="Model B", + yaxis_title="Model A", + xaxis_side="top", + title_y=0.07, + title_x=0.5, + ) + fig.update_traces( + hovertemplate="Model A: %{y}
Model B: %{x}
Fraction of A Wins: %{z}" + ) + + return fig + + +def visualize_battle_count(battles, model_order): + ptbl = pd.pivot_table( + battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0 + ) + battle_counts = ptbl + ptbl.T + fig = px.imshow( + battle_counts.loc[model_order, model_order], + text_auto=True, + height=700, + width=700, + ) + fig.update_layout( + xaxis_title="Model B", + yaxis_title="Model A", + xaxis_side="top", + title_y=0.07, + title_x=0.5, + ) + fig.update_traces( + hovertemplate="Model A: %{y}
Model B: %{x}
Count: %{z}" + ) + return fig + + +def visualize_average_win_rate(battles, limit_show_number): + row_beats_col_freq = compute_pairwise_win_fraction( + battles, None, limit_show_number=limit_show_number + ) + fig = px.bar( + row_beats_col_freq.mean(axis=1).sort_values(ascending=False), + text_auto=".2f", + height=500, + width=700, + ) + fig.update_layout( + yaxis_title="Average Win Rate", xaxis_title="Model", showlegend=False + ) + return fig + + +def visualize_bootstrap_elo_rating(df, df_final, limit_show_number): + bars = ( + pd.DataFrame( + dict( + lower=df.quantile(0.025), + rating=df_final, + upper=df.quantile(0.975), + ) + ) + .reset_index(names="model") + .sort_values("rating", ascending=False) + ) + bars = bars[:limit_show_number] + bars["error_y"] = bars["upper"] - bars["rating"] + bars["error_y_minus"] = bars["rating"] - bars["lower"] + bars["rating_rounded"] = np.round(bars["rating"], 2) + fig = px.scatter( + bars, + x="model", + y="rating", + error_y="error_y", + error_y_minus="error_y_minus", + text="rating_rounded", + height=500, + width=700, + ) + fig.update_layout(xaxis_title="Model", yaxis_title="Rating") + return fig + + +def report_elo_analysis_results(battles_json, rating_system="bt", num_bootstrap=100): + battles = pd.DataFrame(battles_json) + battles = battles.sort_values(ascending=True, by=["tstamp"]) + # Only use anonymous votes + battles = battles[battles["anony"]].reset_index(drop=True) + battles_no_ties = battles[~battles["winner"].str.contains("tie")] + + # Online update + elo_rating_online = compute_elo(battles) + + if rating_system == "bt": + bootstrap_df = get_bootstrap_result( + battles, compute_elo_mle_with_tie, num_round=num_bootstrap + ) + elo_rating_final = compute_elo_mle_with_tie(battles) + elif rating_system == "elo": + bootstrap_df = get_bootstrap_result( + battles, compute_elo, num_round=num_bootstrap + ) + elo_rating_median = get_median_elo_from_bootstrap(bootstrap_df) + elo_rating_final = elo_rating_median + + model_order = list(elo_rating_final.keys()) + model_order.sort(key=lambda k: -elo_rating_final[k]) + + limit_show_number = 25 # limit show number to make plots smaller + model_order = model_order[:limit_show_number] + + # leaderboard_table_df: elo rating, variance, 95% interval, number of battles + leaderboard_table_df = pd.DataFrame( + { + "rating": elo_rating_final, + "variance": bootstrap_df.var(), + "rating_q975": bootstrap_df.quantile(0.975), + "rating_q025": bootstrap_df.quantile(0.025), + "num_battles": battles["model_a"].value_counts() + + battles["model_b"].value_counts(), + } + ) + + # Plots + leaderboard_table = visualize_leaderboard_table(elo_rating_final) + win_fraction_heatmap = visualize_pairwise_win_fraction(battles_no_ties, model_order) + battle_count_heatmap = visualize_battle_count(battles_no_ties, model_order) + average_win_rate_bar = visualize_average_win_rate( + battles_no_ties, limit_show_number + ) + bootstrap_elo_rating = visualize_bootstrap_elo_rating( + bootstrap_df, elo_rating_final, limit_show_number + ) + + last_updated_tstamp = battles["tstamp"].max() + last_updated_datetime = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y-%m-%d %H:%M:%S %Z") + + return { + "rating_system": rating_system, + "elo_rating_online": elo_rating_online, + "elo_rating_final": elo_rating_final, + "leaderboard_table": leaderboard_table, + "win_fraction_heatmap": win_fraction_heatmap, + "battle_count_heatmap": battle_count_heatmap, + "average_win_rate_bar": average_win_rate_bar, + "bootstrap_elo_rating": bootstrap_elo_rating, + "last_updated_datetime": last_updated_datetime, + "last_updated_tstamp": last_updated_tstamp, + "bootstrap_df": bootstrap_df, + "leaderboard_table_df": leaderboard_table_df, + } + + +def pretty_print_elo_rating(rating): + model_order = list(rating.keys()) + model_order.sort(key=lambda k: -rating[k]) + for i, model in enumerate(model_order): + print(f"{i+1:2d}, {model:25s}, {rating[model]:.0f}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--clean-battle-file", type=str) + parser.add_argument("--max-num-files", type=int) + parser.add_argument("--num-bootstrap", type=int, default=100) + parser.add_argument( + "--rating-system", type=str, choices=["bt", "elo"], default="bt" + ) + parser.add_argument("--exclude-tie", action="store_true", default=False) + args = parser.parse_args() + + np.random.seed(42) + + if args.clean_battle_file: + # Read data from a cleaned battle files + battles = pd.read_json(args.clean_battle_file) + else: + # Read data from all log files + log_files = get_log_files(args.max_num_files) + battles = clean_battle_data(log_files) + + results = report_elo_analysis_results( + battles, rating_system=args.rating_system, num_bootstrap=args.num_bootstrap + ) + + print("# Online Elo") + pretty_print_elo_rating(results["elo_rating_online"]) + print("# Median") + pretty_print_elo_rating(results["elo_rating_final"]) + print(f"last update : {results['last_updated_datetime']}") + + last_updated_tstamp = results["last_updated_tstamp"] + cutoff_date = datetime.datetime.fromtimestamp( + last_updated_tstamp, tz=timezone("US/Pacific") + ).strftime("%Y%m%d") + + with open(f"elo_results_{cutoff_date}.pkl", "wb") as fout: + pickle.dump(results, fout) diff --git a/third_party/FastChat/fastchat/serve/monitor/inspect_conv.py b/third_party/FastChat/fastchat/serve/monitor/inspect_conv.py new file mode 100644 index 00000000..a680a419 --- /dev/null +++ b/third_party/FastChat/fastchat/serve/monitor/inspect_conv.py @@ -0,0 +1,87 @@ +import argparse +import code +import datetime +import json +import os +from pytz import timezone +import time + +import pandas as pd +from tqdm import tqdm + + +def get_log_files(max_num_files=None): + dates = [] + for month in [4, 5]: + for day in range(1, 32): + dates.append(f"2023-{month:02d}-{day:02d}") + + num_servers = 14 + filenames = [] + for d in dates: + for i in range(num_servers): + name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json") + if os.path.exists(name): + filenames.append(name) + max_num_files = max_num_files or len(filenames) + filenames = filenames[-max_num_files:] + return filenames + + +def pretty_print_conversation(messages): + for role, msg in messages: + print(f"[[{role}]]: {msg}") + + +def inspect_convs(log_files): + data = [] + for filename in tqdm(log_files, desc="read files"): + for retry in range(5): + try: + lines = open(filename).readlines() + break + except FileNotFoundError: + time.sleep(2) + + for l in lines: + row = json.loads(l) + + if "states" not in row: + continue + if row["type"] not in ["leftvote", "rightvote", "bothbad_vote"]: + continue + + model_names = row["states"][0]["model_name"], row["states"][1]["model_name"] + if row["type"] == "leftvote": + winner, loser = model_names[0], model_names[1] + winner_conv, loser_conv = row["states"][0], row["states"][1] + elif row["type"] == "rightvote": + loser, winner = model_names[0], model_names[1] + loser_conv, winner_conv = row["states"][0], row["states"][1] + + if loser == "bard" and winner == "vicuna-13b": + print("=" * 20) + print(f"Winner: {winner}") + pretty_print_conversation(winner_conv["messages"]) + print(f"Loser: {loser}") + pretty_print_conversation(loser_conv["messages"]) + print("=" * 20) + input() + + # if row["type"] == "bothbad_vote" and "gpt-4" in model_names: + # print("=" * 20) + # print(f"Model A: {model_names[0]}") + # pretty_print_conversation(row["states"][0]["messages"]) + # print(f"Model B: {model_names[1]}") + # pretty_print_conversation(row["states"][1]["messages"]) + # print("=" * 20) + # input() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--max-num-files", type=int) + args = parser.parse_args() + + log_files = get_log_files(args.max_num_files) + inspect_convs(log_files) diff --git a/third_party/FastChat/fastchat/serve/monitor/intersect_conv_file.py b/third_party/FastChat/fastchat/serve/monitor/intersect_conv_file.py new file mode 100644 index 00000000..9eadd7cd --- /dev/null +++ b/third_party/FastChat/fastchat/serve/monitor/intersect_conv_file.py @@ -0,0 +1,25 @@ +""" +Take the intersection of two conversation files. + +Usage: python3 -m fastchat.data.merge --input input.json --conv-id conv_id_file.json --out intersect.json +""" + +import argparse +import json + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str, required=True) + parser.add_argument("--conv-id", type=str, required=True) + parser.add_argument("--out-file", type=str, default="intersect.json") + args = parser.parse_args() + + conv_id_objs = json.load(open(args.conv_id, "r")) + conv_ids = set(x["conversation_id"] for x in conv_id_objs) + + objs = json.load(open(args.input, "r")) + after_objs = [x for x in objs if x["conversation_id"] in conv_ids] + + print(f"#in: {len(objs)}, #out: {len(after_objs)}") + json.dump(after_objs, open(args.out_file, "w"), indent=2, ensure_ascii=False) diff --git a/third_party/FastChat/fastchat/serve/monitor/leaderboard_csv_to_html.py b/third_party/FastChat/fastchat/serve/monitor/leaderboard_csv_to_html.py new file mode 100644 index 00000000..ad52e7b2 --- /dev/null +++ b/third_party/FastChat/fastchat/serve/monitor/leaderboard_csv_to_html.py @@ -0,0 +1,51 @@ +""" +Convert a leaderboard csv file to html table used in the blog. + +Usage: +python3 leaderboard_csv_to_html.py --in leaderboard_table_20230619.csv +""" +import argparse + +import numpy as np + +from fastchat.serve.monitor.monitor import load_leaderboard_table_csv + + +def model_hyperlink(model_name, link): + return f' {model_name} ' + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str, required=True) + args = parser.parse_args() + + data = load_leaderboard_table_csv(args.input, add_hyperlink=False) + headers = [ + "Model", + "MT-bench (score)", + "Arena Elo rating", + "MMLU", + "License", + ] + values = [] + for item in data: + row = [] + for key in headers: + value = item[key] + row.append(value) + row[0] = model_hyperlink(item["Model"], item["Link"]) + values.append(row) + values.sort(key=lambda x: -x[1] if not np.isnan(x[1]) else 1e9) + + for value in values: + row = "" + for x in value: + try: + if np.isnan(x): + x = "-" + except TypeError: + pass + row += f" {x} " + row += "" + print(row) diff --git a/third_party/FastChat/fastchat/serve/monitor/monitor.py b/third_party/FastChat/fastchat/serve/monitor/monitor.py new file mode 100644 index 00000000..1998681a --- /dev/null +++ b/third_party/FastChat/fastchat/serve/monitor/monitor.py @@ -0,0 +1,451 @@ +""" +Live monitor of the website statistics and leaderboard. + +Dependency: +sudo apt install pkg-config libicu-dev +pip install pytz gradio gdown plotly polyglot pyicu pycld2 tabulate +""" + +import argparse +import ast +import json +import pickle +import os +import threading +import time + +import pandas as pd +import gradio as gr +import numpy as np + +from fastchat.serve.monitor.basic_stats import report_basic_stats, get_log_files +from fastchat.serve.monitor.clean_battle_data import clean_battle_data +from fastchat.serve.monitor.elo_analysis import report_elo_analysis_results +from fastchat.utils import build_logger, get_window_url_params_js + + +notebook_url = ( + "https://colab.research.google.com/drive/1KdwokPjirkTmpO_P1WByFNFiqxWQquwH" +) + +basic_component_values = [None] * 6 +leader_component_values = [None] * 5 + + +def make_default_md(arena_df, elo_results): + total_votes = sum(arena_df["num_battles"]) // 2 + total_models = len(arena_df) + + leaderboard_md = f""" +# 🏆 LMSYS Chatbot Arena Leaderboard +| [Vote](https://chat.lmsys.org) | [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | + +LMSYS [Chatbot Arena](https://lmsys.org/blog/2023-05-03-arena/) is a crowdsourced open platform for LLM evals. +We've collected over **200,000** human preference votes to rank LLMs with the Elo ranking system. +""" + return leaderboard_md + + +def make_arena_leaderboard_md(arena_df): + total_votes = sum(arena_df["num_battles"]) // 2 + total_models = len(arena_df) + + leaderboard_md = f""" +Total #models: **{total_models}**. Total #votes: **{total_votes}**. Last updated: Jan 9, 2024. + +Contribute your vote 🗳️ at [chat.lmsys.org](https://chat.lmsys.org)! Find more analysis in the [notebook]({notebook_url}). +""" + return leaderboard_md + + +def make_full_leaderboard_md(elo_results): + leaderboard_md = f""" +Two more benchmarks are displayed: **MT-Bench** and **MMLU**. +- [MT-Bench](https://arxiv.org/abs/2306.05685): a set of challenging multi-turn questions. We use GPT-4 to grade the model responses. +- [MMLU](https://arxiv.org/abs/2009.03300) (5-shot): a test to measure a model's multitask accuracy on 57 tasks. +""" + return leaderboard_md + + +def make_leaderboard_md_live(elo_results): + leaderboard_md = f""" +# Leaderboard +Last updated: {elo_results["last_updated_datetime"]} +{elo_results["leaderboard_table"]} +""" + return leaderboard_md + + +def update_elo_components( + max_num_files, elo_results_file, ban_ip_file, exclude_model_names +): + log_files = get_log_files(max_num_files) + + # Leaderboard + if elo_results_file is None: # Do live update + ban_ip_list = json.load(open(ban_ip_file)) if ban_ip_file else None + battles = clean_battle_data( + log_files, exclude_model_names, ban_ip_list=ban_ip_list + ) + elo_results = report_elo_analysis_results(battles) + + leader_component_values[0] = make_leaderboard_md_live(elo_results) + leader_component_values[1] = elo_results["win_fraction_heatmap"] + leader_component_values[2] = elo_results["battle_count_heatmap"] + leader_component_values[3] = elo_results["bootstrap_elo_rating"] + leader_component_values[4] = elo_results["average_win_rate_bar"] + + # Basic stats + basic_stats = report_basic_stats(log_files) + md0 = f"Last updated: {basic_stats['last_updated_datetime']}" + + md1 = "### Action Histogram\n" + md1 += basic_stats["action_hist_md"] + "\n" + + md2 = "### Anony. Vote Histogram\n" + md2 += basic_stats["anony_vote_hist_md"] + "\n" + + md3 = "### Model Call Histogram\n" + md3 += basic_stats["model_hist_md"] + "\n" + + md4 = "### Model Call (Last 24 Hours)\n" + md4 += basic_stats["num_chats_last_24_hours"] + "\n" + + basic_component_values[0] = md0 + basic_component_values[1] = basic_stats["chat_dates_bar"] + basic_component_values[2] = md1 + basic_component_values[3] = md2 + basic_component_values[4] = md3 + basic_component_values[5] = md4 + + +def update_worker( + max_num_files, interval, elo_results_file, ban_ip_file, exclude_model_names +): + while True: + tic = time.time() + update_elo_components( + max_num_files, elo_results_file, ban_ip_file, exclude_model_names + ) + durtaion = time.time() - tic + print(f"update duration: {durtaion:.2f} s") + time.sleep(max(interval - durtaion, 0)) + + +def load_demo(url_params, request: gr.Request): + logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}") + return basic_component_values + leader_component_values + + +def model_hyperlink(model_name, link): + return f'{model_name}' + + +def load_leaderboard_table_csv(filename, add_hyperlink=True): + lines = open(filename).readlines() + heads = [v.strip() for v in lines[0].split(",")] + rows = [] + for i in range(1, len(lines)): + row = [v.strip() for v in lines[i].split(",")] + for j in range(len(heads)): + item = {} + for h, v in zip(heads, row): + if h == "Arena Elo rating": + if v != "-": + v = int(ast.literal_eval(v)) + else: + v = np.nan + elif h == "MMLU": + if v != "-": + v = round(ast.literal_eval(v) * 100, 1) + else: + v = np.nan + elif h == "MT-bench (win rate %)": + if v != "-": + v = round(ast.literal_eval(v[:-1]), 1) + else: + v = np.nan + elif h == "MT-bench (score)": + if v != "-": + v = round(ast.literal_eval(v), 2) + else: + v = np.nan + item[h] = v + if add_hyperlink: + item["Model"] = model_hyperlink(item["Model"], item["Link"]) + rows.append(item) + + return rows + + +def build_basic_stats_tab(): + empty = "Loading ..." + basic_component_values[:] = [empty, None, empty, empty, empty, empty] + + md0 = gr.Markdown(empty) + gr.Markdown("#### Figure 1: Number of model calls and votes") + plot_1 = gr.Plot(show_label=False) + with gr.Row(): + with gr.Column(): + md1 = gr.Markdown(empty) + with gr.Column(): + md2 = gr.Markdown(empty) + with gr.Row(): + with gr.Column(): + md3 = gr.Markdown(empty) + with gr.Column(): + md4 = gr.Markdown(empty) + return [md0, plot_1, md1, md2, md3, md4] + + +def get_full_table(arena_df, model_table_df): + values = [] + for i in range(len(model_table_df)): + row = [] + model_key = model_table_df.iloc[i]["key"] + model_name = model_table_df.iloc[i]["Model"] + # model display name + row.append(model_name) + if model_key in arena_df.index: + idx = arena_df.index.get_loc(model_key) + row.append(round(arena_df.iloc[idx]["rating"])) + else: + row.append(np.nan) + row.append(model_table_df.iloc[i]["MT-bench (score)"]) + row.append(model_table_df.iloc[i]["MMLU"]) + # Organization + row.append(model_table_df.iloc[i]["Organization"]) + # license + row.append(model_table_df.iloc[i]["License"]) + + values.append(row) + values.sort(key=lambda x: -x[1] if not np.isnan(x[1]) else 1e9) + return values + + +def get_arena_table(arena_df, model_table_df): + # sort by rating + arena_df = arena_df.sort_values(by=["rating"], ascending=False) + values = [] + for i in range(len(arena_df)): + row = [] + model_key = arena_df.index[i] + model_name = model_table_df[model_table_df["key"] == model_key]["Model"].values[ + 0 + ] + + # rank + row.append(i + 1) + # model display name + row.append(model_name) + # elo rating + row.append(round(arena_df.iloc[i]["rating"])) + upper_diff = round(arena_df.iloc[i]["rating_q975"] - arena_df.iloc[i]["rating"]) + lower_diff = round(arena_df.iloc[i]["rating"] - arena_df.iloc[i]["rating_q025"]) + row.append(f"+{upper_diff}/-{lower_diff}") + # num battles + row.append(round(arena_df.iloc[i]["num_battles"])) + # Organization + row.append( + model_table_df[model_table_df["key"] == model_key]["Organization"].values[0] + ) + # license + row.append( + model_table_df[model_table_df["key"] == model_key]["License"].values[0] + ) + + values.append(row) + return values + + +def build_leaderboard_tab(elo_results_file, leaderboard_table_file, show_plot=False): + if elo_results_file is None: # Do live update + default_md = "Loading ..." + p1 = p2 = p3 = p4 = None + else: + with open(elo_results_file, "rb") as fin: + elo_results = pickle.load(fin) + + p1 = elo_results["win_fraction_heatmap"] + p2 = elo_results["battle_count_heatmap"] + p3 = elo_results["bootstrap_elo_rating"] + p4 = elo_results["average_win_rate_bar"] + arena_df = elo_results["leaderboard_table_df"] + default_md = make_default_md(arena_df, elo_results) + + md_1 = gr.Markdown(default_md, elem_id="leaderboard_markdown") + if leaderboard_table_file: + data = load_leaderboard_table_csv(leaderboard_table_file) + model_table_df = pd.DataFrame(data) + + with gr.Tabs() as tabs: + # arena table + arena_table_vals = get_arena_table(arena_df, model_table_df) + with gr.Tab("Arena Elo", id=0): + md = make_arena_leaderboard_md(arena_df) + gr.Markdown(md, elem_id="leaderboard_markdown") + gr.Dataframe( + headers=[ + "Rank", + "🤖 Model", + "⭐ Arena Elo", + "📊 95% CI", + "🗳️ Votes", + "Organization", + "License", + ], + datatype=[ + "str", + "markdown", + "number", + "str", + "number", + "str", + "str", + ], + value=arena_table_vals, + elem_id="arena_leaderboard_dataframe", + height=700, + column_widths=[50, 200, 100, 100, 100, 150, 150], + wrap=True, + ) + with gr.Tab("Full Leaderboard", id=1): + md = make_full_leaderboard_md(elo_results) + gr.Markdown(md, elem_id="leaderboard_markdown") + full_table_vals = get_full_table(arena_df, model_table_df) + gr.Dataframe( + headers=[ + "🤖 Model", + "⭐ Arena Elo", + "📈 MT-bench", + "📚 MMLU", + "Organization", + "License", + ], + datatype=["markdown", "number", "number", "number", "str", "str"], + value=full_table_vals, + elem_id="full_leaderboard_dataframe", + column_widths=[200, 100, 100, 100, 150, 150], + height=700, + wrap=True, + ) + if not show_plot: + gr.Markdown( + """ ## Visit our [HF space](https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard) for more analysis! + If you want to see more models, please help us [add them](https://github.com/lm-sys/FastChat/blob/main/docs/arena.md#how-to-add-a-new-model). + """, + elem_id="leaderboard_markdown", + ) + else: + pass + + leader_component_values[:] = [default_md, p1, p2, p3, p4] + + if show_plot: + gr.Markdown( + f"""## More Statistics for Chatbot Arena\n +Below are figures for more statistics. The code for generating them is also included in this [notebook]({notebook_url}). +You can find more discussions in this blog [post](https://lmsys.org/blog/2023-12-07-leaderboard/). + """, + elem_id="leaderboard_markdown", + ) + with gr.Row(): + with gr.Column(): + gr.Markdown( + "#### Figure 1: Fraction of Model A Wins for All Non-tied A vs. B Battles" + ) + plot_1 = gr.Plot(p1, show_label=False) + with gr.Column(): + gr.Markdown( + "#### Figure 2: Battle Count for Each Combination of Models (without Ties)" + ) + plot_2 = gr.Plot(p2, show_label=False) + with gr.Row(): + with gr.Column(): + gr.Markdown( + "#### Figure 3: Bootstrap of Elo Estimates (1000 Rounds of Random Sampling)" + ) + plot_3 = gr.Plot(p3, show_label=False) + with gr.Column(): + gr.Markdown( + "#### Figure 4: Average Win Rate Against All Other Models (Assuming Uniform Sampling and No Ties)" + ) + plot_4 = gr.Plot(p4, show_label=False) + + from fastchat.serve.gradio_web_server import acknowledgment_md + + gr.Markdown(acknowledgment_md) + + if show_plot: + return [md_1, plot_1, plot_2, plot_3, plot_4] + return [md_1] + + +def build_demo(elo_results_file, leaderboard_table_file): + from fastchat.serve.gradio_web_server import block_css + + text_size = gr.themes.sizes.text_lg + + with gr.Blocks( + title="Monitor", + theme=gr.themes.Base(text_size=text_size), + css=block_css, + ) as demo: + with gr.Tabs() as tabs: + with gr.Tab("Leaderboard", id=0): + leader_components = build_leaderboard_tab( + elo_results_file, + leaderboard_table_file, + show_plot=True, + ) + + with gr.Tab("Basic Stats", id=1): + basic_components = build_basic_stats_tab() + + url_params = gr.JSON(visible=False) + demo.load( + load_demo, + [url_params], + basic_components + leader_components, + _js=get_window_url_params_js, + ) + + return demo + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument("--share", action="store_true") + parser.add_argument("--concurrency-count", type=int, default=10) + parser.add_argument("--update-interval", type=int, default=300) + parser.add_argument("--max-num-files", type=int) + parser.add_argument("--elo-results-file", type=str) + parser.add_argument("--leaderboard-table-file", type=str) + parser.add_argument("--ban-ip-file", type=str) + parser.add_argument("--exclude-model-names", type=str, nargs="+") + args = parser.parse_args() + + logger = build_logger("monitor", "monitor.log") + logger.info(f"args: {args}") + + if args.elo_results_file is None: # Do live update + update_thread = threading.Thread( + target=update_worker, + args=( + args.max_num_files, + args.update_interval, + args.elo_results_file, + args.ban_ip_file, + args.exclude_model_names, + ), + ) + update_thread.start() + + demo = build_demo(args.elo_results_file, args.leaderboard_table_file) + demo.queue( + concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False + ).launch( + server_name=args.host, server_port=args.port, share=args.share, max_threads=200 + ) diff --git a/third_party/FastChat/fastchat/serve/monitor/summarize_cluster.py b/third_party/FastChat/fastchat/serve/monitor/summarize_cluster.py new file mode 100644 index 00000000..8bac97ab --- /dev/null +++ b/third_party/FastChat/fastchat/serve/monitor/summarize_cluster.py @@ -0,0 +1,76 @@ +""" +Usage: +python3 summarize_cluster.py --in results_c20_kmeans_cluster.pkl --model gpt-4 --num-prompts 100 +python3 summarize_cluster.py --in results_c20_kmeans_cluster.pkl --model azure-gpt-4-32k --num-prompts 200 +""" +import argparse +import pickle + +from fastchat.llm_judge.common import ( + chat_completion_openai, + chat_completion_openai_azure, + chat_completion_anthropic, +) +from fastchat.conversation import get_conv_template + + +def truncate_string(s, l): + half = int(l // 2) + return s[:half] + s[-half:] if len(s) > l else s + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input-file", type=str, required=True) + parser.add_argument("--model", type=str, default="gpt-3.5-turbo") + parser.add_argument("--num-prompts", type=int, default=100) + args = parser.parse_args() + + model = args.model + + cluster_infos = pickle.load(open(args.input_file, "rb")) + num_total_prompts = sum([x[0] for x in cluster_infos]) + + topics = [] + percentages = [] + for i, info in enumerate(cluster_infos): + num_samples, topk_prompts, random_prompts = info + percentage = num_samples / num_total_prompts + print( + f"cluster {i}, #prompts {num_samples}, percentage: {percentage * 100:.2f}%" + ) + instruct = "Given a list of user messages, use less than 8 words to summarize a central topic for all messages in English. Your output should only include a single line. Try to be specific." + split = int(args.num_prompts * 0.8) + prompt = "\n".join( + [truncate_string(x, l=200) for x in topk_prompts[:split]] + + [ + truncate_string(x, l=200) + for x in random_prompts[: args.num_prompts - split] + ] + ) + prompt = "BEGIN OF THE MESSAGE LIST\n" + prompt + "\nEND OF THE MESSAGE LIST." + + if "azure-" in model: + template_name = "chatgpt" + completion_func = chat_completion_openai_azure + elif "gpt" in model: + template_name = "chatgpt" + completion_func = chat_completion_openai + elif "claude" in model: + template_name = "claude" + completion_func = chat_completion_anthropic + + conv = get_conv_template(template_name) + conv.set_system_message(instruct) + conv.append_message(conv.roles[0], prompt) + conv.append_message(conv.roles[1], None) + + topic = completion_func(model, conv, temperature=0, max_tokens=256) + print(topic) + + topics.append(topic) + percentages.append(round(percentage, 6)) + + print() + print(f"topics: {topics}") + print(f"percentages: {percentages}") diff --git a/third_party/FastChat/fastchat/serve/monitor/tag_openai_moderation.py b/third_party/FastChat/fastchat/serve/monitor/tag_openai_moderation.py new file mode 100644 index 00000000..b8070338 --- /dev/null +++ b/third_party/FastChat/fastchat/serve/monitor/tag_openai_moderation.py @@ -0,0 +1,63 @@ +""" +Add OpenAI moderation API results to all conversations. +""" +import argparse +from concurrent.futures import ThreadPoolExecutor +import json +import os +import time + +import openai +import requests +from tqdm import tqdm + + +API_MAX_RETRY = 16 +API_RETRY_SLEEP = 10 +API_ERROR_OUTPUT = "$ERROR$" + + +def tag_moderation(text): + result = API_ERROR_OUTPUT + for _ in range(API_MAX_RETRY): + try: + result = openai.Moderation.create(input=text)["results"][0] + break + except openai.error.OpenAIError as e: + print(type(e), e) + time.sleep(API_RETRY_SLEEP) + + return result + + +def tag_openai_moderation(x): + conv = x["conversation_a"] + user_prompts = "\n".join([x["content"] for x in conv if x["role"] == "user"]) + result = tag_moderation(user_prompts) + x["openai_moderation"] = result + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", type=str, required=True) + parser.add_argument( + "--parallel", type=int, default=1, help="The number of concurrent API calls." + ) + parser.add_argument("--first-n", type=int) + args = parser.parse_args() + + battles = json.load(open(args.input)) + + if args.first_n: + battles = battles[: args.first_n] + + with ThreadPoolExecutor(args.parallel) as executor: + for line in tqdm( + executor.map(tag_openai_moderation, battles), total=len(battles) + ): + pass + + output = args.input.replace(".json", "_tagged.json") + with open(output, "w") as fout: + json.dump(battles, fout, indent=2, ensure_ascii=False) + print(f"Write cleaned data to {output}") diff --git a/third_party/FastChat/fastchat/serve/monitor/topic_clustering.py b/third_party/FastChat/fastchat/serve/monitor/topic_clustering.py new file mode 100644 index 00000000..dd15c6ed --- /dev/null +++ b/third_party/FastChat/fastchat/serve/monitor/topic_clustering.py @@ -0,0 +1,267 @@ +""" + +Usage: +python3 topic_clustering.py --in arena.json --english-only --min-length 32 +python3 topic_clustering.py --in clean_conv_20230809_100k.json --english-only --min-length 32 --max-length 1536 +""" +import argparse +import json +import pickle +import string +import time + +import numpy as np +from sentence_transformers import SentenceTransformer +from sentence_transformers.util import cos_sim +from sklearn.cluster import KMeans, AgglomerativeClustering +import torch +from tqdm import tqdm + +from fastchat.utils import detect_language + + +def remove_punctuation(input_string): + # Make a translator object to remove all punctuation + translator = str.maketrans("", "", string.punctuation) + + # Use the translator object to remove the punctuation + no_punct = input_string.translate(translator) + return no_punct + + +def read_texts(input_file, min_length, max_length, english_only): + visited = set() + texts = [] + + lines = json.load(open(input_file, "r")) + + for l in tqdm(lines): + if "text" in l: + line_texts = [l["text"]] + elif "conversation_a" in l: + line_texts = [ + x["content"] for x in l["conversation_a"] if x["role"] == "user" + ] + elif "conversation" in l: + line_texts = [ + x["content"] for x in l["conversation"] if x["role"] == "user" + ] + + for text in line_texts: + text = text.strip() + + # Filter language + if english_only: + lang = detect_language(text) + if lang != "English": + continue + + # Filter short or long prompts + if min_length: + if len(text) < min_length: + continue + + if max_length: + if len(text) > max_length: + continue + + # De-duplication + words = sorted([x.lower() for x in remove_punctuation(text).split(" ")]) + words = "".join(words) + if words in visited: + continue + + visited.add(words) + texts.append(text) + return np.array(texts) + + +def get_embeddings(texts, model_name, batch_size): + model = SentenceTransformer(model_name) + embeddings = model.encode( + texts, + batch_size=batch_size, + show_progress_bar=True, + device="cuda", + convert_to_tensor=True, + ) + embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) + return embeddings.cpu() + + +def run_k_means(embeddings, num_clusters): + np.random.seed(42) + clustering_model = KMeans(n_clusters=num_clusters, n_init="auto") + clustering_model.fit(embeddings.numpy()) + centers = torch.from_numpy(clustering_model.cluster_centers_) + labels = torch.from_numpy(clustering_model.labels_) + + # Sort labels + classes, counts = np.unique(labels, return_counts=True) + indices = np.argsort(counts)[::-1] + classes = [classes[i] for i in indices] + new_labels = torch.empty_like(labels) + new_centers = torch.empty_like(centers) + for i, c in enumerate(classes): + new_labels[labels == c] = i + new_centers[i] = centers[c] + return new_centers, new_labels + + +def run_agg_cluster(embeddings, num_clusters): + np.random.seed(42) + clustering_model = AgglomerativeClustering(n_clusters=num_clusters) + clustering_model.fit(embeddings) + labels = torch.from_numpy(clustering_model.labels_) + + # Sort labels + classes, counts = np.unique(labels, return_counts=True) + indices = np.argsort(counts)[::-1] + classes = [classes[i] for i in indices] + new_labels = torch.empty_like(labels) + for i, c in enumerate(classes): + new_labels[labels == c] = i + + # Compute centers + centers = [] + for i in range(len(classes)): + centers.append(embeddings[new_labels == i].mean(axis=0, keepdim=True)) + centers = torch.cat(centers) + return centers, new_labels + + +def run_hdbscan_cluster(embeddings): + import hdbscan + + np.random.seed(42) + clusterer = hdbscan.HDBSCAN(min_cluster_size=10) + labels = torch.from_numpy(clusterer.fit_predict(embeddings)) + + # Sort labels + classes, counts = np.unique(labels, return_counts=True) + indices = np.argsort(counts)[::-1] + classes = [classes[i] for i in indices] + new_labels = torch.empty_like(labels) + for i, c in enumerate(classes): + new_labels[labels == c] = i + + # Compute centers + centers = [] + for i in range(len(classes)): + centers.append(embeddings[new_labels == i].mean(axis=0, keepdim=True)) + centers = torch.cat(centers) + return centers, new_labels + + +def get_topk_indices(centers, labels, embeddings, topk): + indices = [] + arange = torch.arange(len(labels)) + counts = torch.unique(labels, return_counts=True)[1] + topk = min(topk, counts.min().item()) + for i in range(len(centers)): + tmp_indices = labels == i + tmp_arange = arange[tmp_indices] + tmp_embeddings = embeddings[tmp_indices] + + scores = cos_sim(centers[i].unsqueeze(0), tmp_embeddings)[0] + sorted_indices = torch.flip(torch.argsort(scores), dims=[0]) + indices.append(tmp_arange[sorted_indices[:topk]].unsqueeze(0)) + return torch.cat(indices) + + +def print_topk(texts, labels, topk_indices, show_cut_off): + ret = "" + for k in range(len(topk_indices)): + num_samples = torch.sum(labels == k).item() + + ret += "=" * 20 + f" cluster {k}, #samples: {num_samples} " + "=" * 20 + "\n" + for idx in topk_indices[k]: + ret += "PROMPT: " + texts[idx][:show_cut_off] + "\n" + ret += "=" * 40 + "\n\n" + + return ret + + +def get_cluster_info(texts, labels, topk_indices): + np.random.seed(42) + + cluster_info = [] + for k in range(len(topk_indices)): + num_samples = torch.sum(labels == k).item() + topk_prompts = [] + for idx in topk_indices[k]: + topk_prompts.append(texts[idx]) + random_prompts = [] + for idx in range(len(topk_indices)): + random_prompts.append(np.random.choice(texts)) + cluster_info.append((num_samples, topk_prompts, random_prompts)) + + return cluster_info + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input-file", type=str, required=True) + parser.add_argument("--model", type=str, default="all-mpnet-base-v2") + # default="all-MiniLM-L12-v2") + # default="multi-qa-distilbert-cos-v1") + parser.add_argument("--batch-size", type=int, default=256) + parser.add_argument("--min-length", type=int) + parser.add_argument("--max-length", type=int) + parser.add_argument("--english-only", action="store_true") + parser.add_argument("--num-clusters", type=int, default=20) + parser.add_argument( + "--cluster-alg", + type=str, + choices=["kmeans", "aggcls", "HDBSCAN"], + default="kmeans", + ) + parser.add_argument("--show-top-k", type=int, default=200) + parser.add_argument("--show-cut-off", type=int, default=512) + args = parser.parse_args() + + num_clusters = args.num_clusters + show_top_k = args.show_top_k + show_cut_off = args.show_cut_off + + texts = read_texts( + args.input_file, args.min_length, args.max_length, args.english_only + ) + print(f"#text: {len(texts)}") + + embeddings = get_embeddings(texts, args.model, args.batch_size) + if args.cluster_alg == "kmeans": + centers, labels = run_k_means(embeddings, num_clusters) + elif args.cluster_alg == "aggcls": + centers, labels = run_agg_cluster(embeddings, num_clusters) + elif args.cluster_alg == "HDBSCAN": + centers, labels = run_hdbscan_cluster(embeddings) + else: + raise ValueError(f"Invalid clustering algorithm: {args.cluster_alg}") + + topk_indices = get_topk_indices(centers, labels, embeddings, args.show_top_k) + topk_str = print_topk(texts, labels, topk_indices, args.show_cut_off) + num_clusters = len(centers) + + # Dump results + filename_prefix = f"results_c{num_clusters}_{args.cluster_alg}" + print(topk_str) + with open(filename_prefix + "_topk.txt", "w") as fout: + fout.write(topk_str) + + with open(filename_prefix + "_all.txt", "w") as fout: + for i in range(len(centers)): + tmp_indices = labels == i + tmp_embeddings = embeddings[tmp_indices] + tmp_texts = texts[tmp_indices] + + scores = cos_sim(centers[i].unsqueeze(0), tmp_embeddings)[0] + sorted_indices = torch.flip(torch.argsort(scores), dims=[0]) + + for text, score in zip(tmp_texts[sorted_indices], scores[sorted_indices]): + obj = {"cluster": i, "text": text, "sim": score.item()} + fout.write(json.dumps(obj, ensure_ascii=False) + "\n") + + cluster_info = get_cluster_info(texts, labels, topk_indices) + with open(filename_prefix + "_cluster.pkl", "wb") as fout: + pickle.dump(cluster_info, fout) diff --git a/third_party/FastChat/fastchat/serve/multi_model_worker.py b/third_party/FastChat/fastchat/serve/multi_model_worker.py new file mode 100644 index 00000000..5e6266fe --- /dev/null +++ b/third_party/FastChat/fastchat/serve/multi_model_worker.py @@ -0,0 +1,300 @@ +""" +A multi-model worker that contains multiple sub-works one for each model. This +supports running a list of models on the same machine so that they can +(potentially) share the same background weights. + +Each model can have one or more model names. + +This multi-model worker assumes the models shares some underlying weights and +thus reports the combined queue lengths for health checks. + +We recommend using this with multiple Peft models (with `peft` in the name) +where all Peft models are trained on the exact same base model. +""" +import argparse +import asyncio +import dataclasses +import logging +import json +import os +import time +from typing import List, Union +import threading +import uuid + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse, JSONResponse +import requests + +try: + from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + LlamaTokenizer, + AutoModel, + ) +except ImportError: + from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + LLaMATokenizer, + AutoModel, + ) +import torch +import torch.nn.functional as F +import uvicorn + +from fastchat.constants import WORKER_HEART_BEAT_INTERVAL, ErrorCode, SERVER_ERROR_MSG +from fastchat.model.model_adapter import ( + load_model, + add_model_args, + get_conversation_template, +) +from fastchat.model.model_chatglm import generate_stream_chatglm +from fastchat.model.model_falcon import generate_stream_falcon +from fastchat.model.model_codet5p import generate_stream_codet5p +from fastchat.modules.gptq import GptqConfig +from fastchat.modules.exllama import ExllamaConfig +from fastchat.modules.xfastertransformer import XftConfig +from fastchat.serve.inference import generate_stream +from fastchat.serve.model_worker import ModelWorker, worker_id, logger +from fastchat.utils import build_logger, pretty_print_semaphore, get_context_length + + +# We store both the underlying workers and a mapping from their model names to +# the worker instance. This makes it easy to fetch the appropriate worker for +# each API call. +workers = [] +worker_map = {} +app = FastAPI() + + +def release_worker_semaphore(): + workers[0].semaphore.release() + + +def acquire_worker_semaphore(): + if workers[0].semaphore is None: + # Share the same semaphore for all workers because + # all workers share the same GPU. + semaphore = asyncio.Semaphore(workers[0].limit_worker_concurrency) + for w in workers: + w.semaphore = semaphore + return workers[0].semaphore.acquire() + + +def create_background_tasks(): + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + return background_tasks + + +# Note: for all the calls below, we make a hard assumption that the caller +# includes the model name in the payload, otherwise we can't figure out which +# underlying sub-worker to call. + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + worker = worker_map[params["model"]] + generator = worker.generate_stream_gate(params) + background_tasks = create_background_tasks() + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + worker = worker_map[params["model"]] + output = worker.generate_gate(params) + release_worker_semaphore() + return JSONResponse(output) + + +@app.post("/worker_get_embeddings") +async def api_get_embeddings(request: Request): + params = await request.json() + await acquire_worker_semaphore() + worker = worker_map[params["model"]] + embedding = worker.get_embeddings(params) + background_tasks = create_background_tasks() + return JSONResponse(content=embedding, background=background_tasks) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return { + "model_names": [m for w in workers for m in w.model_names], + "speed": 1, + "queue_length": sum([w.get_queue_length() for w in workers]), + } + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + params = await request.json() + worker = worker_map[params["model"]] + return {"context_length": worker.context_len} + + +def create_multi_model_worker(): + # Note: Ensure we resolve arg conflicts. We let `add_model_args` add MOST + # of the model args but we'll override one to have an append action that + # supports multiple values. + parser = argparse.ArgumentParser(conflict_handler="resolve") + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + add_model_args(parser) + # Override the model path to be repeated and align it with model names. + parser.add_argument( + "--model-path", + type=str, + default=[], + action="append", + help="One or more paths to model weights to load. This can be a local folder or a Hugging Face repo ID.", + ) + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + action="append", + help="One or more model names. Values must be aligned with `--model-path` values.", + ) + parser.add_argument( + "--conv-template", + type=str, + default=None, + action="append", + help="Conversation prompt template. Values must be aligned with `--model-path` values. If only one value is provided, it will be repeated for all models.", + ) + parser.add_argument("--limit-worker-concurrency", type=int, default=5) + parser.add_argument("--stream-interval", type=int, default=2) + parser.add_argument("--no-register", action="store_true") + parser.add_argument( + "--ssl", + action="store_true", + required=False, + default=False, + help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.", + ) + args = parser.parse_args() + logger.info(f"args: {args}") + + if args.gpus: + if len(args.gpus.split(",")) < args.num_gpus: + raise ValueError( + f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" + ) + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + + gptq_config = GptqConfig( + ckpt=args.gptq_ckpt or args.model_path, + wbits=args.gptq_wbits, + groupsize=args.gptq_groupsize, + act_order=args.gptq_act_order, + ) + if args.enable_exllama: + exllama_config = ExllamaConfig( + max_seq_len=args.exllama_max_seq_len, + gpu_split=args.exllama_gpu_split, + cache_8bit=args.exllama_cache_8bit, + ) + else: + exllama_config = None + if args.enable_xft: + xft_config = XftConfig( + max_seq_len=args.xft_max_seq_len, + data_type=args.xft_dtype, + ) + if args.device != "cpu": + print("xFasterTransformer now is only support CPUs. Reset device to CPU") + args.device = "cpu" + else: + xft_config = None + + if args.model_names is None: + args.model_names = [[x.split("/")[-1]] for x in args.model_path] + + if args.conv_template is None: + args.conv_template = [None] * len(args.model_path) + elif len(args.conv_template) == 1: # Repeat the same template + args.conv_template = args.conv_template * len(args.model_path) + + # Launch all workers + workers = [] + for conv_template, model_path, model_names in zip( + args.conv_template, args.model_path, args.model_names + ): + w = ModelWorker( + args.controller_address, + args.worker_address, + worker_id, + model_path, + model_names, + args.limit_worker_concurrency, + args.no_register, + device=args.device, + num_gpus=args.num_gpus, + max_gpu_memory=args.max_gpu_memory, + load_8bit=args.load_8bit, + cpu_offloading=args.cpu_offloading, + gptq_config=gptq_config, + exllama_config=exllama_config, + xft_config=xft_config, + stream_interval=args.stream_interval, + conv_template=conv_template, + ) + workers.append(w) + for model_name in model_names: + worker_map[model_name] = w + + # Register all models + url = args.controller_address + "/register_worker" + data = { + "worker_name": workers[0].worker_addr, + "check_heart_beat": not args.no_register, + "worker_status": { + "model_names": [m for w in workers for m in w.model_names], + "speed": 1, + "queue_length": sum([w.get_queue_length() for w in workers]), + }, + } + r = requests.post(url, json=data) + assert r.status_code == 200 + + return args, workers + + +if __name__ == "__main__": + args, workers = create_multi_model_worker() + if args.ssl: + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + ssl_keyfile=os.environ["SSL_KEYFILE"], + ssl_certfile=os.environ["SSL_CERTFILE"], + ) + else: + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/third_party/FastChat/fastchat/serve/openai_api_server.py b/third_party/FastChat/fastchat/serve/openai_api_server.py new file mode 100644 index 00000000..96af252a --- /dev/null +++ b/third_party/FastChat/fastchat/serve/openai_api_server.py @@ -0,0 +1,904 @@ +"""A server that provides OpenAI-compatible RESTful APIs. It supports: + +- Chat Completions. (Reference: https://platform.openai.com/docs/api-reference/chat) +- Completions. (Reference: https://platform.openai.com/docs/api-reference/completions) +- Embeddings. (Reference: https://platform.openai.com/docs/api-reference/embeddings) + +Usage: +python3 -m fastchat.serve.openai_api_server +""" +import asyncio +import argparse +import json +import os +from typing import Generator, Optional, Union, Dict, List, Any + +import aiohttp +import fastapi +from fastapi import Depends, HTTPException +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse, JSONResponse +from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer +import httpx +from pydantic import BaseSettings +import shortuuid +import tiktoken +import uvicorn + +from fastchat.constants import ( + WORKER_API_TIMEOUT, + WORKER_API_EMBEDDING_BATCH_SIZE, + ErrorCode, +) +from fastchat.conversation import Conversation, SeparatorStyle +from fastchat.protocol.openai_api_protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + ChatMessage, + ChatCompletionResponseChoice, + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + DeltaMessage, + CompletionResponseStreamChoice, + CompletionStreamResponse, + EmbeddingsRequest, + EmbeddingsResponse, + ErrorResponse, + LogProbs, + ModelCard, + ModelList, + ModelPermission, + UsageInfo, +) +from fastchat.protocol.api_protocol import ( + APIChatCompletionRequest, + APITokenCheckRequest, + APITokenCheckResponse, + APITokenCheckResponseItem, +) +from fastchat.utils import build_logger + +logger = build_logger("openai_api_server", "openai_api_server.log") + +conv_template_map = {} + +fetch_timeout = aiohttp.ClientTimeout(total=3 * 3600) + + +async def fetch_remote(url, pload=None, name=None): + async with aiohttp.ClientSession(timeout=fetch_timeout) as session: + async with session.post(url, json=pload) as response: + chunks = [] + if response.status != 200: + ret = { + "text": f"{response.reason}", + "error_code": ErrorCode.INTERNAL_ERROR, + } + return json.dumps(ret) + + async for chunk, _ in response.content.iter_chunks(): + chunks.append(chunk) + output = b"".join(chunks) + + if name is not None: + res = json.loads(output) + if name != "": + res = res[name] + return res + + return output + + +class AppSettings(BaseSettings): + # The address of the model controller. + controller_address: str = "http://localhost:21001" + api_keys: Optional[List[str]] = None + + +app_settings = AppSettings() +app = fastapi.FastAPI() +headers = {"User-Agent": "FastChat API Server"} +get_bearer_token = HTTPBearer(auto_error=False) + + +async def check_api_key( + auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token), +) -> str: + if app_settings.api_keys: + if auth is None or (token := auth.credentials) not in app_settings.api_keys: + raise HTTPException( + status_code=401, + detail={ + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + }, + ) + return token + else: + # api_keys not set; allow all + return None + + +def create_error_response(code: int, message: str) -> JSONResponse: + return JSONResponse( + ErrorResponse(message=message, code=code).dict(), status_code=400 + ) + + +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request, exc): + return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc)) + + +async def check_model(request) -> Optional[JSONResponse]: + controller_address = app_settings.controller_address + ret = None + + models = await fetch_remote(controller_address + "/list_models", None, "models") + if request.model not in models: + ret = create_error_response( + ErrorCode.INVALID_MODEL, + f"Only {'&&'.join(models)} allowed now, your model {request.model}", + ) + return ret + + +async def check_length(request, prompt, max_tokens, worker_addr): + if ( + not isinstance(max_tokens, int) or max_tokens <= 0 + ): # model worker not support max_tokens=None + max_tokens = 1024 * 1024 + + context_len = await fetch_remote( + worker_addr + "/model_details", {"model": request.model}, "context_length" + ) + token_num = await fetch_remote( + worker_addr + "/count_token", + {"model": request.model, "prompt": prompt}, + "count", + ) + length = min(max_tokens, context_len - token_num) + + if length <= 0: + return None, create_error_response( + ErrorCode.CONTEXT_OVERFLOW, + f"This model's maximum context length is {context_len} tokens. However, your messages resulted in {token_num} tokens. Please reduce the length of the messages.", + ) + + return length, None + + +def check_requests(request) -> Optional[JSONResponse]: + # Check all params + if request.max_tokens is not None and request.max_tokens <= 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'", + ) + if request.n is not None and request.n <= 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.n} is less than the minimum of 1 - 'n'", + ) + if request.temperature is not None and request.temperature < 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.temperature} is less than the minimum of 0 - 'temperature'", + ) + if request.temperature is not None and request.temperature > 2: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.temperature} is greater than the maximum of 2 - 'temperature'", + ) + if request.top_p is not None and request.top_p < 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.top_p} is less than the minimum of 0 - 'top_p'", + ) + if request.top_p is not None and request.top_p > 1: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.top_p} is greater than the maximum of 1 - 'top_p'", + ) + if request.top_k is not None and (request.top_k > -1 and request.top_k < 1): + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.top_k} is out of Range. Either set top_k to -1 or >=1.", + ) + if request.stop is not None and ( + not isinstance(request.stop, str) and not isinstance(request.stop, list) + ): + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.stop} is not valid under any of the given schemas - 'stop'", + ) + + return None + + +def process_input(model_name, inp): + if isinstance(inp, str): + inp = [inp] + elif isinstance(inp, list): + if isinstance(inp[0], int): + decoding = tiktoken.model.encoding_for_model(model_name) + inp = [decoding.decode(inp)] + elif isinstance(inp[0], list): + decoding = tiktoken.model.encoding_for_model(model_name) + inp = [decoding.decode(text) for text in inp] + + return inp + + +def create_openai_logprobs(logprob_dict): + """Create OpenAI-style logprobs.""" + return LogProbs(**logprob_dict) if logprob_dict is not None else None + + +def _add_to_set(s, new_stop): + if not s: + return + if isinstance(s, str): + new_stop.add(s) + else: + new_stop.update(s) + + +async def get_gen_params( + model_name: str, + worker_addr: str, + messages: Union[str, List[Dict[str, str]]], + *, + temperature: float, + top_p: float, + top_k: Optional[int], + presence_penalty: Optional[float], + frequency_penalty: Optional[float], + max_tokens: Optional[int], + echo: Optional[bool], + logprobs: Optional[int] = None, + stop: Optional[Union[str, List[str]]], + best_of: Optional[int] = None, + use_beam_search: Optional[bool] = None, +) -> Dict[str, Any]: + conv = await get_conv(model_name, worker_addr) + conv = Conversation( + name=conv["name"], + system_template=conv["system_template"], + system_message=conv["system_message"], + roles=conv["roles"], + messages=list(conv["messages"]), # prevent in-place modification + offset=conv["offset"], + sep_style=SeparatorStyle(conv["sep_style"]), + sep=conv["sep"], + sep2=conv["sep2"], + stop_str=conv["stop_str"], + stop_token_ids=conv["stop_token_ids"], + ) + + if isinstance(messages, str): + prompt = messages + else: + for message in messages: + msg_role = message["role"] + if msg_role == "system": + conv.set_system_message(message["content"]) + elif msg_role == "user": + conv.append_message(conv.roles[0], message["content"]) + elif msg_role == "assistant": + conv.append_message(conv.roles[1], message["content"]) + else: + raise ValueError(f"Unknown role: {msg_role}") + + # Add a blank message for the assistant. + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + gen_params = { + "model": model_name, + "prompt": prompt, + "temperature": temperature, + "logprobs": logprobs, + "top_p": top_p, + "top_k": top_k, + "presence_penalty": presence_penalty, + "frequency_penalty": frequency_penalty, + "max_new_tokens": max_tokens, + "echo": echo, + "stop_token_ids": conv.stop_token_ids, + } + + if best_of is not None: + gen_params.update({"best_of": best_of}) + if use_beam_search is not None: + gen_params.update({"use_beam_search": use_beam_search}) + + new_stop = set() + _add_to_set(stop, new_stop) + _add_to_set(conv.stop_str, new_stop) + + gen_params["stop"] = list(new_stop) + + logger.debug(f"==== request ====\n{gen_params}") + return gen_params + + +async def get_worker_address(model_name: str) -> str: + """ + Get worker address based on the requested model + + :param model_name: The worker's model name + :return: Worker address from the controller + :raises: :class:`ValueError`: No available worker for requested model + """ + controller_address = app_settings.controller_address + worker_addr = await fetch_remote( + controller_address + "/get_worker_address", {"model": model_name}, "address" + ) + + # No available worker + if worker_addr == "": + raise ValueError(f"No available worker for {model_name}") + logger.debug(f"model_name: {model_name}, worker_addr: {worker_addr}") + return worker_addr + + +async def get_conv(model_name: str, worker_addr: str): + conv_template = conv_template_map.get((worker_addr, model_name)) + if conv_template is None: + conv_template = await fetch_remote( + worker_addr + "/worker_get_conv_template", {"model": model_name}, "conv" + ) + conv_template_map[(worker_addr, model_name)] = conv_template + return conv_template + + +@app.get("/v1/models", dependencies=[Depends(check_api_key)]) +async def show_available_models(): + controller_address = app_settings.controller_address + ret = await fetch_remote(controller_address + "/refresh_all_workers") + models = await fetch_remote(controller_address + "/list_models", None, "models") + + models.sort() + # TODO: return real model permission details + model_cards = [] + for m in models: + model_cards.append(ModelCard(id=m, root=m, permission=[ModelPermission()])) + return ModelList(data=model_cards) + + +@app.post("/v1chat/completions", dependencies=[Depends(check_api_key)]) +@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key)]) +async def create_chat_completion(request: ChatCompletionRequest): + """Creates a completion for the chat message""" + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + error_check_ret = check_requests(request) + if error_check_ret is not None: + return error_check_ret + + worker_addr = await get_worker_address(request.model) + + gen_params = await get_gen_params( + request.model, + worker_addr, + request.messages, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + max_tokens=request.max_tokens, + echo=False, + stop=request.stop, + ) + + max_new_tokens, error_check_ret = await check_length( + request, + gen_params["prompt"], + gen_params["max_new_tokens"], + worker_addr, + ) + + if error_check_ret is not None: + return error_check_ret + + gen_params["max_new_tokens"] = max_new_tokens + + if request.stream: + generator = chat_completion_stream_generator( + request.model, gen_params, request.n, worker_addr + ) + return StreamingResponse(generator, media_type="text/event-stream") + + choices = [] + chat_completions = [] + for i in range(request.n): + content = asyncio.create_task(generate_completion(gen_params, worker_addr)) + chat_completions.append(content) + try: + all_tasks = await asyncio.gather(*chat_completions) + except Exception as e: + return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) + usage = UsageInfo() + for i, content in enumerate(all_tasks): + if content["error_code"] != 0: + return create_error_response(content["error_code"], content["text"]) + choices.append( + ChatCompletionResponseChoice( + index=i, + message=ChatMessage(role="assistant", content=content["text"]), + finish_reason=content.get("finish_reason", "stop"), + ) + ) + if "usage" in content: + task_usage = UsageInfo.parse_obj(content["usage"]) + for usage_key, usage_value in task_usage.dict().items(): + setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) + + return ChatCompletionResponse(model=request.model, choices=choices, usage=usage) + + +async def chat_completion_stream_generator( + model_name: str, gen_params: Dict[str, Any], n: int, worker_addr: str +) -> Generator[str, Any, None]: + """ + Event stream format: + https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format + """ + id = f"chatcmpl-{shortuuid.random()}" + finish_stream_events = [] + for i in range(n): + # First chunk with role + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(role="assistant"), + finish_reason=None, + ) + chunk = ChatCompletionStreamResponse( + id=id, choices=[choice_data], model=model_name + ) + yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + + previous_text = "" + async for content in generate_completion_stream(gen_params, worker_addr): + if content["error_code"] != 0: + yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + return + decoded_unicode = content["text"].replace("\ufffd", "") + delta_text = decoded_unicode[len(previous_text) :] + previous_text = ( + decoded_unicode + if len(decoded_unicode) > len(previous_text) + else previous_text + ) + + if len(delta_text) == 0: + delta_text = None + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=delta_text), + finish_reason=content.get("finish_reason", None), + ) + chunk = ChatCompletionStreamResponse( + id=id, choices=[choice_data], model=model_name + ) + if delta_text is None: + if content.get("finish_reason", None) is not None: + finish_stream_events.append(chunk) + continue + yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + # There is not "content" field in the last delta message, so exclude_none to exclude field "content". + for finish_chunk in finish_stream_events: + yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + + +@app.post("/v1/completions", dependencies=[Depends(check_api_key)]) +async def create_completion(request: CompletionRequest): + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + error_check_ret = check_requests(request) + if error_check_ret is not None: + return error_check_ret + + request.prompt = process_input(request.model, request.prompt) + + worker_addr = await get_worker_address(request.model) + for text in request.prompt: + max_tokens, error_check_ret = await check_length( + request, text, request.max_tokens, worker_addr + ) + if error_check_ret is not None: + return error_check_ret + + if isinstance(max_tokens, int) and max_tokens < request.max_tokens: + request.max_tokens = max_tokens + + if request.stream: + generator = generate_completion_stream_generator( + request, request.n, worker_addr + ) + return StreamingResponse(generator, media_type="text/event-stream") + else: + text_completions = [] + for text in request.prompt: + gen_params = await get_gen_params( + request.model, + worker_addr, + text, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + frequency_penalty=request.frequency_penalty, + presence_penalty=request.presence_penalty, + max_tokens=request.max_tokens, + logprobs=request.logprobs, + echo=request.echo, + stop=request.stop, + best_of=request.best_of, + use_beam_search=request.use_beam_search, + ) + for i in range(request.n): + content = asyncio.create_task( + generate_completion(gen_params, worker_addr) + ) + text_completions.append(content) + + try: + all_tasks = await asyncio.gather(*text_completions) + except Exception as e: + return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) + + choices = [] + usage = UsageInfo() + for i, content in enumerate(all_tasks): + if content["error_code"] != 0: + return create_error_response(content["error_code"], content["text"]) + choices.append( + CompletionResponseChoice( + index=i, + text=content["text"], + logprobs=create_openai_logprobs(content.get("logprobs", None)), + finish_reason=content.get("finish_reason", "stop"), + ) + ) + task_usage = UsageInfo.parse_obj(content["usage"]) + for usage_key, usage_value in task_usage.dict().items(): + setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) + + return CompletionResponse( + model=request.model, choices=choices, usage=UsageInfo.parse_obj(usage) + ) + + +async def generate_completion_stream_generator( + request: CompletionRequest, n: int, worker_addr: str +): + model_name = request.model + id = f"cmpl-{shortuuid.random()}" + finish_stream_events = [] + for text in request.prompt: + for i in range(n): + previous_text = "" + gen_params = await get_gen_params( + request.model, + worker_addr, + text, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + max_tokens=request.max_tokens, + logprobs=request.logprobs, + echo=request.echo, + stop=request.stop, + ) + async for content in generate_completion_stream(gen_params, worker_addr): + if content["error_code"] != 0: + yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + return + decoded_unicode = content["text"].replace("\ufffd", "") + delta_text = decoded_unicode[len(previous_text) :] + previous_text = ( + decoded_unicode + if len(decoded_unicode) > len(previous_text) + else previous_text + ) + # todo: index is not apparent + choice_data = CompletionResponseStreamChoice( + index=i, + text=delta_text, + logprobs=create_openai_logprobs(content.get("logprobs", None)), + finish_reason=content.get("finish_reason", None), + ) + chunk = CompletionStreamResponse( + id=id, + object="text_completion", + choices=[choice_data], + model=model_name, + ) + if len(delta_text) == 0: + if content.get("finish_reason", None) is not None: + finish_stream_events.append(chunk) + continue + yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + # There is not "content" field in the last delta message, so exclude_none to exclude field "content". + for finish_chunk in finish_stream_events: + yield f"data: {finish_chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + + +async def generate_completion_stream(payload: Dict[str, Any], worker_addr: str): + controller_address = app_settings.controller_address + async with httpx.AsyncClient() as client: + delimiter = b"\0" + async with client.stream( + "POST", + worker_addr + "/worker_generate_stream", + headers=headers, + json=payload, + timeout=WORKER_API_TIMEOUT, + ) as response: + # content = await response.aread() + buffer = b"" + async for raw_chunk in response.aiter_raw(): + buffer += raw_chunk + while (chunk_end := buffer.find(delimiter)) >= 0: + chunk, buffer = buffer[:chunk_end], buffer[chunk_end + 1 :] + if not chunk: + continue + yield json.loads(chunk.decode()) + + +async def generate_completion(payload: Dict[str, Any], worker_addr: str): + return await fetch_remote(worker_addr + "/worker_generate", payload, "") + + +@app.post("/v1/embeddings", dependencies=[Depends(check_api_key)]) +@app.post("/v1/engines/{model_name}/embeddings", dependencies=[Depends(check_api_key)]) +async def create_embeddings(request: EmbeddingsRequest, model_name: str = None): + """Creates embeddings for the text""" + if request.model is None: + request.model = model_name + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + + request.input = process_input(request.model, request.input) + + data = [] + token_num = 0 + batch_size = WORKER_API_EMBEDDING_BATCH_SIZE + batches = [ + request.input[i : min(i + batch_size, len(request.input))] + for i in range(0, len(request.input), batch_size) + ] + for num_batch, batch in enumerate(batches): + payload = { + "model": request.model, + "input": batch, + "encoding_format": request.encoding_format, + } + embedding = await get_embedding(payload) + if "error_code" in embedding and embedding["error_code"] != 0: + return create_error_response(embedding["error_code"], embedding["text"]) + data += [ + { + "object": "embedding", + "embedding": emb, + "index": num_batch * batch_size + i, + } + for i, emb in enumerate(embedding["embedding"]) + ] + token_num += embedding["token_num"] + return EmbeddingsResponse( + data=data, + model=request.model, + usage=UsageInfo( + prompt_tokens=token_num, + total_tokens=token_num, + completion_tokens=None, + ), + ).dict(exclude_none=True) + + +async def get_embedding(payload: Dict[str, Any]): + controller_address = app_settings.controller_address + model_name = payload["model"] + worker_addr = await get_worker_address(model_name) + + embedding = await fetch_remote(worker_addr + "/worker_get_embeddings", payload) + return json.loads(embedding) + + +### GENERAL API - NOT OPENAI COMPATIBLE ### + + +@app.post("/api/v1/token_check") +async def count_tokens(request: APITokenCheckRequest): + """ + Checks the token count for each message in your list + This is not part of the OpenAI API spec. + """ + checkedList = [] + for item in request.prompts: + worker_addr = await get_worker_address(item.model) + + context_len = await fetch_remote( + worker_addr + "/model_details", + {"prompt": item.prompt, "model": item.model}, + "context_length", + ) + + token_num = await fetch_remote( + worker_addr + "/count_token", + {"prompt": item.prompt, "model": item.model}, + "count", + ) + + can_fit = True + if token_num + item.max_tokens > context_len: + can_fit = False + + checkedList.append( + APITokenCheckResponseItem( + fits=can_fit, contextLength=context_len, tokenCount=token_num + ) + ) + + return APITokenCheckResponse(prompts=checkedList) + + +@app.post("/api/v1/chat/completions") +async def create_chat_completion(request: APIChatCompletionRequest): + """Creates a completion for the chat message""" + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + error_check_ret = check_requests(request) + if error_check_ret is not None: + return error_check_ret + + worker_addr = await get_worker_address(request.model) + + gen_params = await get_gen_params( + request.model, + worker_addr, + request.messages, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + max_tokens=request.max_tokens, + echo=False, + stop=request.stop, + ) + + if request.repetition_penalty is not None: + gen_params["repetition_penalty"] = request.repetition_penalty + + max_new_tokens, error_check_ret = await check_length( + request, + gen_params["prompt"], + gen_params["max_new_tokens"], + worker_addr, + ) + + if error_check_ret is not None: + return error_check_ret + + gen_params["max_new_tokens"] = max_new_tokens + + if request.stream: + generator = chat_completion_stream_generator( + request.model, gen_params, request.n, worker_addr + ) + return StreamingResponse(generator, media_type="text/event-stream") + + choices = [] + chat_completions = [] + for i in range(request.n): + content = asyncio.create_task(generate_completion(gen_params, worker_addr)) + chat_completions.append(content) + try: + all_tasks = await asyncio.gather(*chat_completions) + except Exception as e: + return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) + usage = UsageInfo() + for i, content in enumerate(all_tasks): + if content["error_code"] != 0: + return create_error_response(content["error_code"], content["text"]) + choices.append( + ChatCompletionResponseChoice( + index=i, + message=ChatMessage(role="assistant", content=content["text"]), + finish_reason=content.get("finish_reason", "stop"), + ) + ) + task_usage = UsageInfo.parse_obj(content["usage"]) + for usage_key, usage_value in task_usage.dict().items(): + setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) + + return ChatCompletionResponse(model=request.model, choices=choices, usage=usage) + + +### END GENERAL API - NOT OPENAI COMPATIBLE ### + + +def create_openai_api_server(): + parser = argparse.ArgumentParser( + description="FastChat ChatGPT-Compatible RESTful API server." + ) + parser.add_argument("--host", type=str, default="localhost", help="host name") + parser.add_argument("--port", type=int, default=8000, help="port number") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument( + "--allow-credentials", action="store_true", help="allow credentials" + ) + parser.add_argument( + "--allowed-origins", type=json.loads, default=["*"], help="allowed origins" + ) + parser.add_argument( + "--allowed-methods", type=json.loads, default=["*"], help="allowed methods" + ) + parser.add_argument( + "--allowed-headers", type=json.loads, default=["*"], help="allowed headers" + ) + parser.add_argument( + "--api-keys", + type=lambda s: s.split(","), + help="Optional list of comma separated API keys", + ) + parser.add_argument( + "--ssl", + action="store_true", + required=False, + default=False, + help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.", + ) + args = parser.parse_args() + + app.add_middleware( + CORSMiddleware, + allow_origins=args.allowed_origins, + allow_credentials=args.allow_credentials, + allow_methods=args.allowed_methods, + allow_headers=args.allowed_headers, + ) + app_settings.controller_address = args.controller_address + app_settings.api_keys = args.api_keys + + logger.info(f"args: {args}") + return args + + +if __name__ == "__main__": + args = create_openai_api_server() + if args.ssl: + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + ssl_keyfile=os.environ["SSL_KEYFILE"], + ssl_certfile=os.environ["SSL_CERTFILE"], + ) + else: + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/third_party/FastChat/fastchat/serve/register_worker.py b/third_party/FastChat/fastchat/serve/register_worker.py new file mode 100644 index 00000000..2c2c4029 --- /dev/null +++ b/third_party/FastChat/fastchat/serve/register_worker.py @@ -0,0 +1,26 @@ +""" +Manually register workers. + +Usage: +python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 +""" + +import argparse + +import requests + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--controller-address", type=str) + parser.add_argument("--worker-name", type=str) + parser.add_argument("--check-heart-beat", action="store_true") + args = parser.parse_args() + + url = args.controller_address + "/register_worker" + data = { + "worker_name": args.worker_name, + "check_heart_beat": args.check_heart_beat, + "worker_status": None, + } + r = requests.post(url, json=data) + assert r.status_code == 200 diff --git a/third_party/FastChat/fastchat/serve/shutdown_serve.py b/third_party/FastChat/fastchat/serve/shutdown_serve.py new file mode 100644 index 00000000..95e2b704 --- /dev/null +++ b/third_party/FastChat/fastchat/serve/shutdown_serve.py @@ -0,0 +1,24 @@ +""" +Usage: +python shutdown_serve.py --down all +options: "all","controller","model_worker","openai_api_server", `all` means to stop all related servers +""" + +import argparse +import os +import subprocess + +parser = argparse.ArgumentParser() +parser.add_argument( + "--down", choices=["all", "controller", "model_worker", "openai_api_server"] +) +args = parser.parse_args() +base_shell = "ps -eo user,pid,cmd|grep fastchat.serve{}|grep -v grep|awk '{{print $2}}'|xargs kill -9" +if args.down == "all": + shell_script = base_shell.format("") +else: + serve = f".{args.down}" + shell_script = base_shell.format(serve) +print(f"execute shell cmd: {shell_script}") +subprocess.run(shell_script, shell=True, check=True) +print(f"{args.down} has been shutdown!") diff --git a/third_party/FastChat/fastchat/serve/test_message.py b/third_party/FastChat/fastchat/serve/test_message.py new file mode 100644 index 00000000..203a4490 --- /dev/null +++ b/third_party/FastChat/fastchat/serve/test_message.py @@ -0,0 +1,81 @@ +"""Send a test message.""" +import argparse +import json + +import requests + +from fastchat.model.model_adapter import get_conversation_template + + +def main(): + model_name = args.model_name + + if args.worker_address: + worker_addr = args.worker_address + else: + controller_addr = args.controller_address + ret = requests.post(controller_addr + "/refresh_all_workers") + ret = requests.post(controller_addr + "/list_models") + models = ret.json()["models"] + models.sort() + print(f"Models: {models}") + + ret = requests.post( + controller_addr + "/get_worker_address", json={"model": model_name} + ) + worker_addr = ret.json()["address"] + print(f"worker_addr: {worker_addr}") + + if worker_addr == "": + print(f"No available workers for {model_name}") + return + + conv = get_conversation_template(model_name) + conv.append_message(conv.roles[0], args.message) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + headers = {"User-Agent": "FastChat Client"} + gen_params = { + "model": model_name, + "prompt": prompt, + "temperature": args.temperature, + "max_new_tokens": args.max_new_tokens, + "stop": conv.stop_str, + "stop_token_ids": conv.stop_token_ids, + "echo": False, + } + response = requests.post( + worker_addr + "/worker_generate_stream", + headers=headers, + json=gen_params, + stream=True, + ) + + print(f"{conv.roles[0]}: {args.message}") + print(f"{conv.roles[1]}: ", end="") + prev = 0 + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + output = data["text"].strip() + print(output[prev:], end="", flush=True) + prev = len(output) + print("") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument("--worker-address", type=str) + parser.add_argument("--model-name", type=str, required=True) + parser.add_argument("--temperature", type=float, default=0.0) + parser.add_argument("--max-new-tokens", type=int, default=32) + parser.add_argument( + "--message", type=str, default="Tell me a story with more than 1000 words." + ) + args = parser.parse_args() + + main() diff --git a/third_party/FastChat/fastchat/serve/test_throughput.py b/third_party/FastChat/fastchat/serve/test_throughput.py new file mode 100644 index 00000000..1f36f92d --- /dev/null +++ b/third_party/FastChat/fastchat/serve/test_throughput.py @@ -0,0 +1,138 @@ +"""Benchmarking script to test the throughput of serving workers.""" +import argparse +import json + +import requests +import threading +import time + +from fastchat.conversation import get_conv_template + + +def main(): + if args.worker_address: + worker_addr = args.worker_address + else: + controller_addr = args.controller_address + ret = requests.post(controller_addr + "/refresh_all_workers") + ret = requests.post(controller_addr + "/list_models") + models = ret.json()["models"] + models.sort() + print(f"Models: {models}") + + ret = requests.post( + controller_addr + "/get_worker_address", json={"model": args.model_name} + ) + worker_addr = ret.json()["address"] + print(f"worker_addr: {worker_addr}") + + if worker_addr == "": + return + + # conv = get_conv_template("vicuna_v1.1") + conv = get_conv_template("qwen-7b-chat") + prompt = """参考信息: +第二盘比赛中,士气正盛的白卓璇6-1大胜,总比分2-0赢得生涯大满贯首胜。 值得一提的是,这是她生涯第一次参加大满贯正赛,成为继郑钦文、吴易昺、商竣程之后第四位首次参加大满贯正赛就能闯过首轮关的中国球员。这也是她生涯中首次击败世界排名前100的球员。 第二轮,白卓璇的对手是本届赛事6号种子,去年温网亚军、突尼斯名将贾巴尔。 王曦雨在2023年温网比赛中 另一场中国球员参加的女单首轮比赛中,世界排名第65位的中国球员王曦雨对阵通过资格赛晋级的世界排名第102位的安德列娃。 +至此,白卓璇结束温网4连胜,首次征战大满贯就能通过资格赛考验,并且拿到正赛首胜,中国00后已经创造历史。至此,中国金花结束本届温网单打之旅,接下来的女双和混双,期待姑娘们的出色发挥。 +其三,万卓索娃夺得生涯单打第二冠,上一冠还是2017年夺得WTA250比尔赛冠军,且万卓索娃终结过往生涯女单决赛4连败。 其四,万卓索娃以世界排名第42名征战温网,夺冠后将升至世界第10名,成为她生涯首次跻身世界前十,创造生涯排名新高纪录。 其五,万卓索娃成为2021年法网夺冠的克雷吉茨科娃后,最近两年首位夺得大满贯女单冠军的捷克球员。 其六,万卓索娃成为2011年与2014年两次夺得温网冠军的科维托娃之后,最近9年首位夺得温网冠军的捷克球员。 +2023年温布尔登网球锦标赛女子单打比赛,是2023年温布尔登网球锦标赛的其中一个比赛项目。伊莲娜·莱巴金娜是卫冕冠军,[1]但她在1/4决赛中输给本届亚军昂丝·加博。 马尔凯塔·万卓索娃是本届比赛的冠军,也是公开赛年代以来首位打入温网女单决赛的非种子选手,也是她继2019年法网后再次进入大满贯决赛,世界排名第42位的她也是自2018年温布顿网球锦标赛的塞雷娜·威廉姆斯(第181位)以来打入温网女单决赛排名最低的选手。 +其七,万卓索娃温网夺冠后,温网连续6年迎来新的女单冠军。 其八,万卓索娃连克五位种子,成为公开赛年代第一位夺得温网女单冠军的非种子选手。 6号种子贾巴尔是上届温网亚军,她今年在晋级决赛之旅,她有3场比赛是打满三盘胜出,其中1/4决赛三盘淘汰3号种子莱巴金娜,半决赛三盘淘汰2号种子萨巴伦卡。非种子选手万卓索娃晋级决赛之旅,她有两场比赛打满三盘胜出,其中半决赛两盘横扫斯维托丽娜,后者此前淘汰世界第一的斯瓦泰克。 + +--- +我的问题或指令: +2023年温网女单亚军? 介绍下万卓索娃的战绩 +--- +请依照所给历史对话和参考信息回答问题或执行指令,筛选出有关内容,忽略不相关信息。若信息不足,回答“对不起,我不知道”,切勿随意编造。 +你的回复(用中文):""" + conv.append_message(conv.roles[0], "Tell me a story with more than 1000 words") + # conv.append_message(conv.roles[0], prompt) + prompt_template = conv.get_prompt() + prompts = [prompt_template for _ in range(args.n_thread)] + + print(f"[debug] prompts = {prompts} \n---------\n") + + headers = {"User-Agent": "fastchat Client"} + ploads = [ + { + "model": args.model_name, + "prompt": prompts[i], + "max_new_tokens": args.max_new_tokens, + "temperature": 0.5, + # "stop": conv.sep, + } + for i in range(len(prompts)) + ] + + def send_request(results, i): + if args.test_dispatch: + ret = requests.post( + controller_addr + "/get_worker_address", json={"model": args.model_name} + ) + thread_worker_addr = ret.json()["address"] + else: + thread_worker_addr = worker_addr + print(f"thread {i} goes to {thread_worker_addr}") + response = requests.post( + thread_worker_addr + "/worker_generate_stream", + headers=headers, + json=ploads[i], + stream=False, + ) + k = list( + response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0") + ) + print(k) + + print(f"==============") + + response_new_words = json.loads(k[-2].decode("utf-8"))["text"] + error_code = json.loads(k[-2].decode("utf-8"))["error_code"] + print(f"[debug] response_new_words = {response_new_words}") + # print(f"=== Thread {i} ===, words: {1}, error code: {error_code}") + results[i] = len(response_new_words.split(" ")) - len(prompts[i].split(" ")) + + # use N threads to prompt the backend + tik = time.time() + threads = [] + results = [None] * args.n_thread + for i in range(args.n_thread): + t = threading.Thread(target=send_request, args=(results, i)) + t.start() + # time.sleep(0.5) + threads.append(t) + + for t in threads: + t.join() + + print(f"[debug] results = {results}") + print(f"Time (POST): {time.time() - tik} s") + + # n_words = 0 + # for i, response in enumerate(results): + # # print(prompt[i].replace(conv.sep, "\n"), end="") + # # make sure the streaming finishes at EOS or stopping criteria + # k = list(response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0")) + # response_new_words = json.loads(k[-2].decode("utf-8"))["text"] + # # print(response_new_words) + # n_words += len(response_new_words.split(" ")) - len(prompts[i].split(" ")) + n_words = sum(results) + time_seconds = time.time() - tik + print( + f"Time (Completion): {time_seconds}, n threads: {args.n_thread}, " + f"throughput: {n_words / time_seconds} words/s." + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument("--worker-address", type=str) + parser.add_argument("--model-name", type=str, default="Qwen-7B-Chat-8K") + parser.add_argument("--max-new-tokens", type=int, default=512) + parser.add_argument("--n-thread", type=int, default=1) + parser.add_argument("--test-dispatch", action="store_true") + args = parser.parse_args() + + main() diff --git a/third_party/FastChat/fastchat/serve/vllm_worker.py b/third_party/FastChat/fastchat/serve/vllm_worker.py new file mode 100644 index 00000000..5faccbba --- /dev/null +++ b/third_party/FastChat/fastchat/serve/vllm_worker.py @@ -0,0 +1,283 @@ +""" +A model worker that executes the model based on vLLM. + +See documentations at docs/vllm_integration.md +""" + +import argparse +import asyncio +import json +from typing import List + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse, JSONResponse +import uvicorn +from vllm import AsyncLLMEngine +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.sampling_params import SamplingParams +from vllm.utils import random_uuid + +from fastchat.serve.base_model_worker import BaseModelWorker +from fastchat.serve.model_worker import ( + logger, + worker_id, +) +from fastchat.utils import get_context_length, is_partial_stop + + +app = FastAPI() + + +class VLLMWorker(BaseModelWorker): + def __init__( + self, + controller_addr: str, + worker_addr: str, + worker_id: str, + model_path: str, + model_names: List[str], + limit_worker_concurrency: int, + no_register: bool, + llm_engine: AsyncLLMEngine, + conv_template: str, + ): + super().__init__( + controller_addr, + worker_addr, + worker_id, + model_path, + model_names, + limit_worker_concurrency, + conv_template, + ) + + logger.info( + f"Loading the model {self.model_names} on worker {worker_id}, worker type: vLLM worker..." + ) + self.tokenizer = llm_engine.engine.tokenizer + self.context_len = get_context_length(llm_engine.engine.model_config.hf_config) + + if not no_register: + self.init_heart_beat() + + async def generate_stream(self, params): + self.call_ct += 1 + + context = params.pop("prompt") + request_id = params.pop("request_id") + temperature = float(params.get("temperature", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = params.get("top_k", -1.0) + presence_penalty = float(params.get("presence_penalty", 0.0)) + frequency_penalty = float(params.get("frequency_penalty", 0.0)) + max_new_tokens = params.get("max_new_tokens", 256) + stop_str = params.get("stop", None) + stop_token_ids = params.get("stop_token_ids", None) or [] + if self.tokenizer.eos_token_id is not None: + stop_token_ids.append(self.tokenizer.eos_token_id) + echo = params.get("echo", True) + use_beam_search = params.get("use_beam_search", False) + best_of = params.get("best_of", None) + + # Handle stop_str + stop = set() + if isinstance(stop_str, str) and stop_str != "": + stop.add(stop_str) + elif isinstance(stop_str, list) and stop_str != []: + stop.update(stop_str) + + for tid in stop_token_ids: + if tid is not None: + s = self.tokenizer.decode(tid) + if s != "": + stop.add(s) + + # make sampling params in vllm + top_p = max(top_p, 1e-5) + if temperature <= 1e-5: + top_p = 1.0 + + sampling_params = SamplingParams( + n=1, + temperature=temperature, + top_p=top_p, + use_beam_search=use_beam_search, + stop=list(stop), + stop_token_ids=stop_token_ids, + max_tokens=max_new_tokens, + top_k=top_k, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + best_of=best_of, + ) + results_generator = engine.generate(context, sampling_params, request_id) + + async for request_output in results_generator: + prompt = request_output.prompt + if echo: + text_outputs = [ + prompt + output.text for output in request_output.outputs + ] + else: + text_outputs = [output.text for output in request_output.outputs] + text_outputs = " ".join(text_outputs) + + partial_stop = any(is_partial_stop(text_outputs, i) for i in stop) + # prevent yielding partial stop sequence + if partial_stop: + continue + + prompt_tokens = len(request_output.prompt_token_ids) + completion_tokens = sum( + len(output.token_ids) for output in request_output.outputs + ) + ret = { + "text": text_outputs, + "error_code": 0, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + "cumulative_logprob": [ + output.cumulative_logprob for output in request_output.outputs + ], + "finish_reason": request_output.outputs[0].finish_reason + if len(request_output.outputs) == 1 + else [output.finish_reason for output in request_output.outputs], + } + # Emit twice here to ensure a 'finish_reason' with empty content in the OpenAI API response. + # This aligns with the behavior of model_worker. + if request_output.finished: + yield (json.dumps({**ret, **{"finish_reason": None}}) + "\0").encode() + yield (json.dumps(ret) + "\0").encode() + + async def generate(self, params): + async for x in self.generate_stream(params): + pass + return json.loads(x[:-1].decode()) + + +def release_worker_semaphore(): + worker.semaphore.release() + + +def acquire_worker_semaphore(): + if worker.semaphore is None: + worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency) + return worker.semaphore.acquire() + + +def create_background_tasks(request_id): + async def abort_request() -> None: + await engine.abort(request_id) + + background_tasks = BackgroundTasks() + background_tasks.add_task(release_worker_semaphore) + background_tasks.add_task(abort_request) + return background_tasks + + +@app.post("/worker_generate_stream") +async def api_generate_stream(request: Request): + params = await request.json() + await acquire_worker_semaphore() + request_id = random_uuid() + params["request_id"] = request_id + generator = worker.generate_stream(params) + background_tasks = create_background_tasks(request_id) + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_generate") +async def api_generate(request: Request): + params = await request.json() + await acquire_worker_semaphore() + request_id = random_uuid() + params["request_id"] = request_id + output = await worker.generate(params) + release_worker_semaphore() + await engine.abort(request_id) + return JSONResponse(output) + + +@app.post("/worker_get_status") +async def api_get_status(request: Request): + return worker.get_status() + + +@app.post("/count_token") +async def api_count_token(request: Request): + params = await request.json() + return worker.count_token(params) + + +@app.post("/worker_get_conv_template") +async def api_get_conv(request: Request): + return worker.get_conv_template() + + +@app.post("/model_details") +async def api_model_details(request: Request): + return {"context_length": worker.context_len} + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument("--model-path", type=str, default="lmsys/vicuna-7b-v1.5") + parser.add_argument( + "--model-names", + type=lambda s: s.split(","), + help="Optional display comma separated names", + ) + parser.add_argument("--limit-worker-concurrency", type=int, default=1024) + parser.add_argument("--no-register", action="store_true") + parser.add_argument("--num-gpus", type=int, default=1) + parser.add_argument( + "--conv-template", type=str, default=None, help="Conversation prompt template." + ) + parser.add_argument( + "--trust_remote_code", + action="store_false", + default=True, + help="Trust remote code (e.g., from HuggingFace) when" + "downloading the model and tokenizer.", + ) + parser.add_argument( + "--gpu_memory_utilization", + type=float, + default=0.9, + help="The ratio (between 0 and 1) of GPU memory to" + "reserve for the model weights, activations, and KV cache. Higher" + "values will increase the KV cache size and thus improve the model's" + "throughput. However, if the value is too high, it may cause out-of-" + "memory (OOM) errors.", + ) + + parser = AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() + if args.model_path: + args.model = args.model_path + if args.num_gpus > 1: + args.tensor_parallel_size = args.num_gpus + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = AsyncLLMEngine.from_engine_args(engine_args) + worker = VLLMWorker( + args.controller_address, + args.worker_address, + worker_id, + args.model_path, + args.model_names, + args.limit_worker_concurrency, + args.no_register, + engine, + args.conv_template, + ) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/third_party/FastChat/fastchat/train/llama2_flash_attn_monkey_patch.py b/third_party/FastChat/fastchat/train/llama2_flash_attn_monkey_patch.py new file mode 100644 index 00000000..c1fe51c9 --- /dev/null +++ b/third_party/FastChat/fastchat/train/llama2_flash_attn_monkey_patch.py @@ -0,0 +1,238 @@ +import warnings +from typing import Optional, Tuple + +import torch +from flash_attn import __version__ as flash_attn_version +from flash_attn.bert_padding import pad_input, unpad_input +from flash_attn.flash_attn_interface import ( + flash_attn_func, + flash_attn_varlen_kvpacked_func, +) +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaModel, + rotate_half, +) + + +def apply_rotary_pos_emb(q, k, cos_sin, position_ids): + gather_indices = position_ids[:, :, None, None] # [bsz, seq_len, 1, 1] + gather_indices = gather_indices.repeat( + 1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3] + ) + bsz = gather_indices.shape[0] + cos, sin = ( + torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices) + for x in cos_sin + ) + q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k)) + return q, k + + +def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + warnings.warn( + "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." + ) + + bsz, q_len, _ = hidden_states.size() + kv_heads = getattr(self, "num_key_value_heads", self.num_heads) + + q, k, v = ( + op(hidden_states).view(bsz, q_len, nh, self.head_dim) + for op, nh in ( + (self.q_proj, self.num_heads), + (self.k_proj, kv_heads), + (self.v_proj, kv_heads), + ) + ) + # shape: (b, s, num_heads, head_dim) + + kv_seq_len = k.shape[1] + past_kv_len = 0 + if past_key_value is not None: + past_kv_len = past_key_value[0].shape[2] + kv_seq_len += past_kv_len + + cos_sin = self.rotary_emb(v, seq_len=kv_seq_len) + q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids) + + if past_key_value is not None: + assert ( + flash_attn_version >= "2.1.0" + ), "past_key_value support requires flash-attn >= 2.1.0" + # reuse k, v + k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1) + v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1) + + past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None + + if attention_mask is None: + output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view( + bsz, q_len, -1 + ) + else: + q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:]) + # We can skip concat and call unpad twice but seems better to call unpad only once. + kv, _, cu_k_lens, max_k = unpad_input( + torch.stack((k, v), dim=2), attention_mask + ) + output_unpad = flash_attn_varlen_kvpacked_func( + q, + kv, + cu_q_lens, + cu_k_lens, + max_s, + max_k, + 0.0, + softmax_scale=None, + causal=True, + ) + output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) + output = pad_input(output_unpad, indices, bsz, q_len) + + return self.o_proj(output), None, past_key_value + + +# Disable the transformation of the attention mask in LlamaModel as flash attention +# takes a boolean key_padding_mask. Fills in the past kv length for use in forward. +def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length +): + # [bsz, seq_len] + if past_key_values_length > 0 and attention_mask is not None: + attention_mask = torch.cat( + ( + torch.full( + (input_shape[0], past_key_values_length), + True, + dtype=attention_mask.dtype, + device=attention_mask.device, + ), + attention_mask, + ), + dim=-1, + ) + + if attention_mask is not None and torch.all(attention_mask): + return None # This uses the faster call when training with full samples + + return attention_mask + + +def replace_llama_attn_with_flash_attn(): + cuda_major, cuda_minor = torch.cuda.get_device_capability() + if cuda_major < 8: + warnings.warn( + "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." + "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" + ) + + LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask + LlamaAttention.forward = forward + + +def test(): + from fastchat.train.llama_flash_attn_monkey_patch import forward as fastchat_forward + from transformers.models.llama.configuration_llama import LlamaConfig + + config = LlamaConfig( + hidden_size=1024, + intermediate_size=128, + num_hidden_layers=1, + num_attention_heads=8, + max_position_embeddings=16, + ) + device = torch.device("cuda") + model = LlamaModel(config) + attn = LlamaAttention(config).to(device).half() + bsz, hs, seqlen = 2, config.hidden_size, config.max_position_embeddings + position_ids = torch.arange(seqlen, dtype=torch.long, device=device).view( + -1, seqlen + ) + + mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device) + for i in range(4): + hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device) + if i: + mask[0, -i:] = False + mask[1, :i] = False + + lmask = model._prepare_decoder_attention_mask(mask, hidden.shape[:2], hidden, 0) + ref, _, _ = attn.forward( + hidden, attention_mask=lmask, position_ids=position_ids + ) + + fast, _, _ = fastchat_forward( + attn, hidden, attention_mask=mask, position_ids=position_ids + ) + + lmask = _prepare_decoder_attention_mask( + model, mask, hidden.shape[:2], hidden, 0 + ) + test, _, _ = forward( + attn, hidden, attention_mask=lmask, position_ids=position_ids + ) + + print(f"Mean(abs(ref)) = {torch.mean(torch.abs(ref))}") + print(f"Mean(abs(ref - fast)) = {torch.mean(torch.abs(ref - fast))}") + print(f"Mean(abs(ref - test)) = {torch.mean(torch.abs(ref - test))}") + print(f"Mean(abs(fast - test)) = {torch.mean(torch.abs(fast - test))}") + print(f"allclose(fast, test) = {torch.allclose(fast, test)}") + + with torch.no_grad(): + # Also check that past_kv is handled properly + hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device) + part_len = seqlen // 4 + assert part_len * 4 == seqlen + mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device) + mask[0, -2:] = False + lmask = _prepare_decoder_attention_mask( + model, mask, hidden.shape[:2], hidden, 0 + ) + oneshot, _, _ = forward( + attn, hidden, attention_mask=lmask, position_ids=position_ids + ) + parts = [] + past_kv, past_kv_len = None, 0 + for i in range(4): + start = part_len * i + end = start + part_len + hidden_part = hidden[:, start:end, ...] + lmask = _prepare_decoder_attention_mask( + model, + mask[:, start:end], + hidden_part.shape[:2], + hidden_part, + past_kv_len, + ) + part, _, past_kv = forward( + attn, + hidden_part.clone(), + attention_mask=lmask, + position_ids=position_ids[:, start:end], + past_key_value=past_kv, + use_cache=True, + ) + parts.append(part) + past_kv_len = past_kv[0].shape[2] + + print( + f"allclose(oneshot[:, 0], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}" + ) + print( + f"allclose(oneshot, parts) = {torch.allclose(oneshot, torch.cat(parts, dim=1))}" + ) + + +if __name__ == "__main__": + test() diff --git a/third_party/FastChat/fastchat/train/llama_flash_attn_monkey_patch.py b/third_party/FastChat/fastchat/train/llama_flash_attn_monkey_patch.py new file mode 100644 index 00000000..b64aa818 --- /dev/null +++ b/third_party/FastChat/fastchat/train/llama_flash_attn_monkey_patch.py @@ -0,0 +1,107 @@ +from typing import Optional, Tuple +import warnings + +import torch +from torch import nn +import transformers +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb + +from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func +from flash_attn.bert_padding import unpad_input, pad_input + + +def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + warnings.warn( + "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.k_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) # shape: (b, num_heads, s, head_dim) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + + if past_key_value is not None: + # reuse k, v + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # Transform the data into the format required by flash attention + qkv = torch.stack([query_states, key_states, value_states], dim=2) + qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] + key_padding_mask = attention_mask + + if key_padding_mask is None: + qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) + cu_q_lens = torch.arange( + 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device + ) + max_s = q_len + output = flash_attn_varlen_qkvpacked_func( + qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True + ) + output = output.view(bsz, q_len, -1) + else: + qkv = qkv.reshape(bsz, q_len, -1) + qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) + qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) + output_unpad = flash_attn_varlen_qkvpacked_func( + qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True + ) + output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) + output = pad_input(output_unpad, indices, bsz, q_len) + + return self.o_proj(output), None, past_key_value + + +# Disable the transformation of the attention mask in LlamaModel as the flash attention +# requires the attention mask to be the same as the key_padding_mask +def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length +): + # [bsz, seq_len] + return attention_mask + + +def replace_llama_attn_with_flash_attn(): + cuda_major, cuda_minor = torch.cuda.get_device_capability() + if cuda_major < 8: + warnings.warn( + "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." + "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" + ) + transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( + _prepare_decoder_attention_mask + ) + transformers.models.llama.modeling_llama.LlamaAttention.forward = forward diff --git a/third_party/FastChat/fastchat/train/llama_xformers_attn_monkey_patch.py b/third_party/FastChat/fastchat/train/llama_xformers_attn_monkey_patch.py new file mode 100644 index 00000000..f8351e41 --- /dev/null +++ b/third_party/FastChat/fastchat/train/llama_xformers_attn_monkey_patch.py @@ -0,0 +1,129 @@ +""" +Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments +""" + +import logging +import math +from typing import Optional, Tuple + +import torch +import transformers.models.llama.modeling_llama +from torch import nn + +try: + import xformers.ops +except ImportError: + logging.error("xformers not found! Please install it before trying to use it.") + + +def replace_llama_attn_with_xformers_attn(): + transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward + + +def xformers_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # pylint: disable=duplicate-code + bsz, q_len, _ = hidden_states.size() + + query_states = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.k_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + ( + query_states, + key_states, + ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # We only apply xformers optimizations if we don't need to output the whole attention matrix + if not output_attentions: + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. + # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. + if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: + # input and output should be of form (bsz, q_len, num_heads, head_dim) + attn_output = xformers.ops.memory_efficient_attention( + query_states, key_states, value_states, attn_bias=None + ) + else: + # input and output should be of form (bsz, q_len, num_heads, head_dim) + attn_output = xformers.ops.memory_efficient_attention( + query_states, + key_states, + value_states, + attn_bias=xformers.ops.LowerTriangularMask(), + ) + attn_weights = None + else: + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) + ) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value diff --git a/third_party/FastChat/fastchat/train/train.py b/third_party/FastChat/fastchat/train/train.py new file mode 100644 index 00000000..8e8198e3 --- /dev/null +++ b/third_party/FastChat/fastchat/train/train.py @@ -0,0 +1,318 @@ +# This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright: +# +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +import json +import math +import pathlib +from typing import Dict, Optional, Sequence + +import numpy as np +import torch +from torch.utils.data import Dataset +import transformers +from transformers import Trainer +from transformers.trainer_pt_utils import LabelSmoother + +from fastchat.conversation import SeparatorStyle +from fastchat.model.model_adapter import get_conversation_template + +IGNORE_TOKEN_ID = LabelSmoother.ignore_index + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + trust_remote_code: bool = field( + default=False, + metadata={ + "help": "Whether or not to allow for custom models defined on the Hub in their own modeling files" + }, + ) + padding_side: str = field( + default="right", metadata={"help": "The padding side in tokenizer"} + ) + + +@dataclass +class DataArguments: + data_path: str = field( + default=None, metadata={"help": "Path to the training data."} + ) + eval_data_path: str = field( + default=None, metadata={"help": "Path to the evaluation data."} + ) + lazy_preprocess: bool = False + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + model_max_length: int = field( + default=512, + metadata={ + "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + + +local_rank = None + + +def rank0_print(*args): + if local_rank == 0: + print(*args) + + +def trainer_save_model_safe(trainer: transformers.Trainer): + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp import StateDictType, FullStateDictConfig + + save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with FSDP.state_dict_type( + trainer.model, StateDictType.FULL_STATE_DICT, save_policy + ): + trainer.save_model() + + +def preprocess( + sources, + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + conv = get_conversation_template("vicuna") + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + targets = input_ids.clone() + + assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO + + # Mask targets. Only compute loss on the assistant outputs. + sep = conv.sep + conv.roles[1] + ": " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + turns = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_TOKEN_ID + for i, turn in enumerate(turns): + if turn == "": + break + turn_len = len(tokenizer(turn).input_ids) + + parts = turn.split(sep) + if len(parts) != 2: + break + parts[0] += sep + # "-2" is hardcoded for the Llama tokenizer to make the offset correct. + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + if i != 0 and not tokenizer.legacy: + # The legacy and non-legacy modes handle special tokens differently + instruction_len -= 1 + + # Ignore the user instructions + target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID + cur_len += turn_len + + if i != 0 and not tokenizer.legacy: + # The legacy and non-legacy modes handle special tokens differently + cur_len -= 1 + + target[cur_len:] = IGNORE_TOKEN_ID + + if False: # Inspect and check the correctness of masking + z = target.clone() + z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) + rank0_print(tokenizer.decode(z)) + exit() + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_TOKEN_ID + rank0_print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" #turn = {len(turns) - 1}. (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + attention_mask=input_ids.ne(tokenizer.pad_token_id), + ) + + +class SupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer): + super(SupervisedDataset, self).__init__() + + rank0_print("Formatting inputs...") + sources = [example["conversations"] for example in raw_data] + data_dict = preprocess(sources, tokenizer) + + self.input_ids = data_dict["input_ids"] + self.labels = data_dict["labels"] + self.attention_mask = data_dict["attention_mask"] + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + return dict( + input_ids=self.input_ids[i], + labels=self.labels[i], + attention_mask=self.attention_mask[i], + ) + + +class LazySupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer): + super(LazySupervisedDataset, self).__init__() + self.tokenizer = tokenizer + + rank0_print("Formatting inputs...Skip in lazy mode") + self.tokenizer = tokenizer + self.raw_data = raw_data + self.cached_data_dict = {} + + def __len__(self): + return len(self.raw_data) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + if i in self.cached_data_dict: + return self.cached_data_dict[i] + + ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer) + ret = dict( + input_ids=ret["input_ids"][0], + labels=ret["labels"][0], + attention_mask=ret["attention_mask"][0], + ) + self.cached_data_dict[i] = ret + + return ret + + +def make_supervised_data_module( + tokenizer: transformers.PreTrainedTokenizer, data_args +) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + dataset_cls = ( + LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset + ) + rank0_print("Loading data...") + + train_json = json.load(open(data_args.data_path, "r")) + train_dataset = dataset_cls(train_json, tokenizer=tokenizer) + + if data_args.eval_data_path: + eval_json = json.load(open(data_args.eval_data_path, "r")) + eval_dataset = dataset_cls(eval_json, tokenizer=tokenizer) + else: + eval_dataset = None + + return dict(train_dataset=train_dataset, eval_dataset=eval_dataset) + + +def train(): + global local_rank + + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + local_rank = training_args.local_rank + + # Set RoPE scaling factor + config = transformers.AutoConfig.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + trust_remote_code=model_args.trust_remote_code, + ) + orig_ctx_len = getattr(config, "max_position_embeddings", None) + if orig_ctx_len and training_args.model_max_length > orig_ctx_len: + scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len)) + config.rope_scaling = {"type": "linear", "factor": scaling_factor} + config.use_cache = False + + # Load model and tokenizer + model = transformers.AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + config=config, + cache_dir=training_args.cache_dir, + trust_remote_code=model_args.trust_remote_code, + ) + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side=model_args.padding_side, + use_fast=False, + trust_remote_code=model_args.trust_remote_code, + ) + + if tokenizer.pad_token != tokenizer.unk_token: + tokenizer.pad_token = tokenizer.unk_token + + # Load data + data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) + + # Start trainner + trainer = Trainer( + model=model, tokenizer=tokenizer, args=training_args, **data_module + ) + if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + + # Save model + model.config.use_cache = True + trainer.save_state() + if trainer.is_deepspeed_enabled: + trainer.save_model() + else: + trainer_save_model_safe(trainer) + + +if __name__ == "__main__": + train() diff --git a/third_party/FastChat/fastchat/train/train_baichuan.py b/third_party/FastChat/fastchat/train/train_baichuan.py new file mode 100644 index 00000000..b6b19b48 --- /dev/null +++ b/third_party/FastChat/fastchat/train/train_baichuan.py @@ -0,0 +1,333 @@ +# This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright: +# +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +import json +import math +import jsonlines +import pathlib +from multiprocessing import Pool +from typing import Dict, Optional, Sequence + +import numpy as np +import torch +from torch.utils.data import Dataset +import transformers +from transformers import Trainer +from transformers.trainer_pt_utils import LabelSmoother + +from fastchat.conversation import SeparatorStyle +from fastchat.model.model_adapter import get_conversation_template + +IGNORE_TOKEN_ID = LabelSmoother.ignore_index + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + + +@dataclass +class DataArguments: + data_path: str = field( + default=None, metadata={"help": "Path to the training data."} + ) + lazy_preprocess: bool = False + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + model_max_length: int = field( + default=512, + metadata={ + "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + + +local_rank = None + + +def rank0_print(*args): + if local_rank == 0: + print(*args) + + +def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): + """Collects the state dict and dump to disk.""" + state_dict = trainer.model.state_dict() + if trainer.args.should_save: + cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} + del state_dict + trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + + +def apply_prompt_template(sources, systems=None): + conv = get_conversation_template("vicuna") + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + if systems and systems[i]: + conv.set_system_message(systems[i]) + prompt = conv.get_prompt() + conversations.append(prompt) + return conversations, conv + + +def tokenize_conversations(conversations, tokenizer): + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + targets = input_ids.clone() + return input_ids, targets + + +def mask_targets(conversations, targets, tokenizer, conv): + sep = conv.sep + conv.roles[1] + ": " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + turns = conversation.split(conv.sep2) + cur_len = 0 + target[:cur_len] = IGNORE_TOKEN_ID + for i, turn in enumerate(turns): + if turn == "": + break + turn_len = len(tokenizer(turn + conv.sep2).input_ids) + + parts = turn.split(sep) + if len(parts) != 2: + break + parts[0] += sep + instruction_len = len(tokenizer(parts[0]).input_ids) - 1 + + target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID + cur_len += turn_len + + target[cur_len:] = IGNORE_TOKEN_ID + + if False: # Inspect and check the correctness of masking + z = target.clone() + z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) + rank0_print(tokenizer.decode(z)) + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_TOKEN_ID + rank0_print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" (ignored)" + ) + return targets + + +def preprocess(sources, tokenizer: transformers.PreTrainedTokenizer, **kwargs) -> Dict: + systems = None if not kwargs else kwargs.get("systems", None) + + # If the data volume is small, process it directly in the main thread + if len(sources) <= 1000: + conversations, conv = apply_prompt_template(sources, systems) + input_ids, targets = tokenize_conversations(conversations, tokenizer) + targets = mask_targets(conversations, targets, tokenizer, conv) + else: # If the data volume is large, use multithreading for processing + with Pool() as p: + conversations, conv = p.apply_async( + apply_prompt_template, (sources, systems) + ).get() + input_ids, targets = p.apply_async( + tokenize_conversations, (conversations, tokenizer) + ).get() + targets = p.apply_async( + mask_targets, (conversations, targets, tokenizer, conv) + ).get() + p.close() + p.join() + + return dict( + input_ids=input_ids, + labels=targets, + attention_mask=input_ids.ne(tokenizer.pad_token_id), + ) + + +class SupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer): + super(SupervisedDataset, self).__init__() + + rank0_print("Formatting inputs...") + systems = [example.get("system", "") for example in raw_data] + sources = [example["conversations"] for example in raw_data] + + data_dict = preprocess(sources, tokenizer, systems=systems) + + self.input_ids = data_dict["input_ids"] + self.labels = data_dict["labels"] + self.attention_mask = data_dict["attention_mask"] + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + return dict( + input_ids=self.input_ids[i], + labels=self.labels[i], + attention_mask=self.attention_mask[i], + ) + + +class LazySupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer): + super(LazySupervisedDataset, self).__init__() + self.tokenizer = tokenizer + + rank0_print("Formatting inputs...Skip in lazy mode") + self.raw_data = raw_data + self.cached_data_dict = {} + + def __len__(self): + return len(self.raw_data) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + if i in self.cached_data_dict: + return self.cached_data_dict[i] + + ret = preprocess( + [self.raw_data[i]["conversations"]], + self.tokenizer, + systems=[self.raw_data[i].get("system", "")], + ) + ret = dict( + input_ids=ret["input_ids"][0], + labels=ret["labels"][0], + attention_mask=ret["attention_mask"][0], + ) + self.cached_data_dict[i] = ret + + return ret + + +def make_supervised_data_module( + tokenizer: transformers.PreTrainedTokenizer, data_args, train_ratio=0.98 +) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + train_ratio = min(train_ratio, 1.0) + dataset_cls = ( + LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset + ) + rank0_print("Loading data...") + data_path = data_args.data_path + if data_path.endswith(".json"): + raw_data = json.load(open(data_path, "r")) + elif data_path.endswith(".jsonl"): + with jsonlines.open(data_path, mode="r") as reader: + raw_data = [item for item in reader] + + # Split train/test + np.random.seed(0) + perm = np.random.permutation(len(raw_data)) + split = int(len(perm) * train_ratio) + train_indices = perm[:split] + if train_ratio < 1: + eval_indices = perm[split:] + else: + # if train_ratio==1, we use 5% of data as eval data, make sure trainer will not throw error when eval data is empty + eval_indices = perm[-int(len(perm) * 0.05) :] + train_raw_data = [raw_data[i] for i in train_indices] + eval_raw_data = [raw_data[i] for i in eval_indices] + rank0_print(f"#train {len(train_raw_data)}, #eval {len(eval_raw_data)}") + + train_dataset = dataset_cls(train_raw_data, tokenizer=tokenizer) + eval_dataset = dataset_cls(eval_raw_data, tokenizer=tokenizer) + return dict(train_dataset=train_dataset, eval_dataset=eval_dataset) + + +def train(): + global local_rank + + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + local_rank = training_args.local_rank + config = transformers.AutoConfig.from_pretrained( + model_args.model_name_or_path, + trust_remote_code=True, + cache_dir=training_args.cache_dir, + ) + # Set RoPE scaling factor + orig_ctx_len = getattr(config, "max_position_embeddings", None) + if orig_ctx_len and training_args.model_max_length > orig_ctx_len: + scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len)) + config.rope_scaling = {"type": "linear", "factor": scaling_factor} + config.use_cache = False + model = transformers.AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + config=config, + trust_remote_code=True, + cache_dir=training_args.cache_dir, + ) + # Tie the weights + model.tie_weights() + + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + config=config, + trust_remote_code=True, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="right", + use_fast=False, + ) + # NOTE: if the token_id exceed the vocab_size will cause failing in training process! we need add special config and resize the embedding size! + tokenizer.pad_token = tokenizer.unk_token + print(f"tokens len: {len(tokenizer)}") + model.resize_token_embeddings(len(tokenizer)) + + data_module = make_supervised_data_module( + tokenizer=tokenizer, train_ratio=0.98, data_args=data_args + ) + trainer = Trainer( + model=model, tokenizer=tokenizer, args=training_args, **data_module + ) + + if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + trainer.save_state() + safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) + + +if __name__ == "__main__": + train() diff --git a/third_party/FastChat/fastchat/train/train_flant5.py b/third_party/FastChat/fastchat/train/train_flant5.py new file mode 100755 index 00000000..688c2f4f --- /dev/null +++ b/third_party/FastChat/fastchat/train/train_flant5.py @@ -0,0 +1,436 @@ +# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright: +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +import copy +import os +from dataclasses import dataclass, field +import random +import json +import logging +import pathlib +from typing import Dict, Optional, Sequence + +import torch +import torch.distributed as dist + +import transformers +from torch.utils.data import Dataset +from transformers import Trainer, AddedToken + +from fastchat.model.model_adapter import get_conversation_template + +default_conversation = get_conversation_template("t5") + +# TODO: import and use code from ../data/dataset.py + +IGNORE_INDEX = -100 +DEFAULT_PAD_TOKEN = "[PAD]" +DEFAULT_EOS_TOKEN = "" +DEFAULT_BOS_TOKEN = "" +DEFAULT_UNK_TOKEN = "" + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + + +@dataclass +class DataArguments: + data_path: str = field( + default=None, metadata={"help": "Path to the training data."} + ) + lazy_preprocess: bool = False + num_data: int = -1 + preprocessed_path: str = field( + default=None, metadata={"help": "Path to the preprocessed training data."} + ) + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + model_max_length: int = field( + default=2048, + metadata={ + "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + + +def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): + """Collects the state dict and dump to disk.""" + state_dict = trainer.model.state_dict() + if trainer.args.should_save: + cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} + del state_dict + trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + + +def smart_tokenizer_and_embedding_resize( + special_tokens_dict: Dict, + other_tokens, + tokenizer: transformers.PreTrainedTokenizer, + model: transformers.PreTrainedModel, +): + """Resize tokenizer and embedding. + + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. + """ + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + for new_token in other_tokens: + num_new_tokens += tokenizer.add_tokens(AddedToken(new_token, normalized=False)) + + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True + ) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True + ) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + +def _tokenize_fn( + strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer +) -> Dict: + """Tokenize a list of strings.""" + tokenized_list = [ + tokenizer( + text, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ) + for text in strings + ] + input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] + input_ids_lens = labels_lens = [ + tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() + for tokenized in tokenized_list + ] + return dict( + input_ids=input_ids, + labels=labels, + input_ids_lens=input_ids_lens, + labels_lens=labels_lens, + ) + + +def _form_qa( + q_list, + a_list, + tokenized_conversation, + tokenized_lens, + speakers, + header_len, + max_length, + eos_id, +): + cur_idx = header_len + conv_len = len(tokenized_conversation) + + for tokenized_len, speaker in zip(tokenized_lens, speakers): + if cur_idx >= conv_len: + break + if speaker == "gpt": + # truncate answer if it is too long + content_a = None + if tokenized_len > max_length: + content_a = tokenized_conversation[cur_idx : cur_idx + max_length] + else: + content_a = tokenized_conversation[cur_idx : cur_idx + tokenized_len] + content_a.append(eos_id) + a_list.append(content_a) + content_q = None + if cur_idx >= max_length: + content_q = tokenized_conversation[cur_idx - max_length : cur_idx] + else: + content_q = tokenized_conversation[:cur_idx] + content_q.append(eos_id) + q_list.append(content_q) + # asser the last token is actually a EOS for an answer + assert a_list[-1][-1] == eos_id, "Last Token is not EOS!" + cur_idx += tokenized_len + + +def _add_speaker_and_signal(header, source, get_conversation=True): + """Add speaker and start/end signal on each round.""" + BEGIN_SIGNAL = "### " + END_SIGNAL = "\n" + conversation = header + + unknown_role = "unknown" # use default unknown role + roles = { + "human": default_conversation.roles[0], # human role + "gpt": default_conversation.roles[1], # gpt role + } + + for i in range(len(source)): + sentence = source[i] + sentence_from = sentence["from"].lower() + + # TODO(Dacheng): verify this is a good way to split sentences + if sentence_from == "human": + # if this is not the last sentence + if i != len(source) - 1: + next_sentence = source[i + 1] + sentence["value"] = ( + BEGIN_SIGNAL + + roles.get(sentence_from, unknown_role) + + ": " + + sentence["value"] + + END_SIGNAL + + BEGIN_SIGNAL + + roles.get(next_sentence["from"].lower(), unknown_role) + + ": " + ) + else: + # if human is the last speaker, it does not contribute to an answer + pass + else: + sentence["value"] = sentence["value"] + END_SIGNAL + if get_conversation: + conversation += sentence["value"] + + return conversation + + +def preprocess( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + """ + Given a list of sources, each is a conversation list. This transform: + 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; + 2. Concatenate conversations together; + 3. Tokenize the concatenated conversation; + 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. + """ + # add end signal and concatenate together + conversations = [] + header = f"{default_conversation.system_message}\n\n" + for source in sources: + conversation = _add_speaker_and_signal(header, source, tokenizer) + conversations.append(conversation) + # TODO(Dacheng): This is related to whether the dataset has been truncated.. + # Assume we get long conversations, don't pad, don't return tensor + tokenized_conversations = tokenizer(conversations, max_length=None)["input_ids"] + q_list = [] + a_list = [] + # count for EOS length + header_len = _tokenize_fn([header], tokenizer)["input_ids_lens"][0] - 1 + from tqdm import tqdm + + for tokenized_conversation, source in tqdm(zip(tokenized_conversations, sources)): + tokenized_sentence = _tokenize_fn([s["value"] for s in source], tokenizer) + tokenized_lens = tokenized_sentence["input_ids_lens"] + tokenized_lens = [l - 1 for l in tokenized_lens] + speakers = [sentence["from"] for sentence in source] + ids = tokenized_sentence["input_ids"] + _form_qa( + q_list, + a_list, + tokenized_conversation, + tokenized_lens, + speakers, + header_len, + tokenizer.model_max_length, + tokenizer.eos_token_id, + ) + return dict(input_ids=q_list, labels=a_list) + + +class SupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__( + self, + data_path: str, + tokenizer: transformers.PreTrainedTokenizer, + preprocessed_path, + num_data, + ): + super(SupervisedDataset, self).__init__() + + # save to file + # Make sure only the first process is processing the dataset + if dist.get_rank() != 0: + dist.barrier() + self.preprocessed_path = preprocessed_path + if os.path.exists(self.preprocessed_path): + logging.warning("loading from preprocessed data") + with open(self.preprocessed_path, "r") as f: + data_dict = json.load(f) + if dist.get_rank() == 0: + dist.barrier() + else: + if not os.path.exists("preprocessed_data"): + os.mkdir("preprocessed_data") + assert dist.get_rank() == 0, "Only the first process should process" + logging.warning("Loading data...") + list_data_dict = json.load(open(data_path, "r")) + + logging.warning("Formatting inputs...") + sources = [] + + sources = [example["conversations"] for example in list_data_dict] + + data_dict = preprocess(sources, tokenizer) + json_data_dict = json.dumps(data_dict) + + # Remember to close file to avoid concurrent r/w + with open(self.preprocessed_path, "w") as f: + f.write(json_data_dict) + + # Release barrier + dist.barrier() + + if num_data != -1: + data_dict["input_ids"] = data_dict["input_ids"][:num_data] + data_dict["labels"] = data_dict["labels"][:num_data] + + # Shuffle data to see more conversations, if only train on partial data + temp = list(zip(data_dict["input_ids"], data_dict["labels"])) + random.shuffle(temp) + res1, res2 = zip(*temp) + data_dict["input_ids"], data_dict["labels"] = list(res1), list(res2) + + # Dacheng: Get rid of short QA pair + self.input_ids = copy.deepcopy(data_dict["input_ids"]) + self.labels = copy.deepcopy(data_dict["labels"]) + length_arr = defaultdict(int) + for idx, (input, label) in enumerate( + zip(data_dict["input_ids"], data_dict["labels"]) + ): + length_arr[str(len(label) // 100)] += 1 + if len(input) <= 5: + del_idx = self.input_ids.index(input) + self.input_ids.pop(del_idx) + self.labels.pop(del_idx) + if len(label) <= 5: + del_idx = self.labels.index(label) + self.input_ids.pop(del_idx) + self.labels.pop(del_idx) + + for input, label in zip(self.input_ids, self.labels): + assert len(input) >= 5 + assert len(label) >= 5 + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + return dict(input_ids=self.input_ids[i], labels=self.labels[i]) + + +@dataclass +class DataCollatorForSupervisedDataset(object): + """Collate examples for supervised fine-tuning.""" + + tokenizer: transformers.PreTrainedTokenizer + + def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: + input_ids, labels = tuple( + [ + torch.as_tensor(instance[key], dtype=torch.int64) + for instance in instances + ] + for key in ("input_ids", "labels") + ) + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id + ) + labels = torch.nn.utils.rnn.pad_sequence( + labels, batch_first=True, padding_value=IGNORE_INDEX + ) + ret = dict( + input_ids=input_ids, + labels=labels, + attention_mask=input_ids.ne(self.tokenizer.pad_token_id), + ) + torch.set_printoptions(profile="full") + return ret + + +def make_supervised_data_module( + tokenizer: transformers.PreTrainedTokenizer, data_args +) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + dataset_cls = SupervisedDataset + train_dataset = dataset_cls( + tokenizer=tokenizer, + data_path=data_args.data_path, + preprocessed_path=data_args.preprocessed_path, + num_data=data_args.num_data, + ) + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) + return dict( + train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator + ) + + +def train(): + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + model = transformers.AutoModelForSeq2SeqLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + ) + # Dacheng: Note we can only use T5Tokenizer, otherwise it will prepend + # a space before special tokens. + tokenizer = transformers.T5Tokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="right", + use_fast=False, + ) + + smart_tokenizer_and_embedding_resize( + special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), + other_tokens=["<", "{", "\n", "}", "`", " ", "\\", "^", "\t"], + tokenizer=tokenizer, + model=model, + ) + + data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) + trainer = Trainer( + model=model, tokenizer=tokenizer, args=training_args, **data_module + ) + + if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + trainer.save_state() + safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) + + +if __name__ == "__main__": + train() diff --git a/third_party/FastChat/fastchat/train/train_lora.py b/third_party/FastChat/fastchat/train/train_lora.py new file mode 100644 index 00000000..9ecb47c2 --- /dev/null +++ b/third_party/FastChat/fastchat/train/train_lora.py @@ -0,0 +1,222 @@ +# Usage: deepspeed train_lora.py --deepspeed <$PATH_TO_DEEPSPEED_CONFIG> + +# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright: +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +import logging +import pathlib +import typing +import os + +from deepspeed import zero +from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus +from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training +import transformers +from transformers import Trainer, BitsAndBytesConfig, deepspeed +import torch + +from fastchat.train.train import ( + DataArguments, + ModelArguments, + make_supervised_data_module, +) + +from fastchat.train.llama_flash_attn_monkey_patch import ( + replace_llama_attn_with_flash_attn, +) + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: typing.Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + model_max_length: int = field( + default=512, + metadata={ + "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + flash_attn: bool = False + + +@dataclass +class LoraArguments: + lora_r: int = 8 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_target_modules: typing.List[str] = field( + default_factory=lambda: ["q_proj", "v_proj"] + ) + lora_weight_path: str = "" + lora_bias: str = "none" + q_lora: bool = False + + +def maybe_zero_3(param): + if hasattr(param, "ds_id"): + assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +# Borrowed from peft.utils.get_peft_model_state_dict +def get_peft_state_maybe_zero_3(named_params, bias): + if bias == "none": + to_return = {k: t for k, t in named_params if "lora_" in k} + elif bias == "all": + to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = {} + maybe_lora_bias = {} + lora_bias_names = set() + for k, t in named_params: + if "lora_" in k: + to_return[k] = t + bias_name = k.split("lora_")[0] + "bias" + lora_bias_names.add(bias_name) + elif "bias" in k: + maybe_lora_bias[k] = t + for k, t in maybe_lora_bias: + if bias_name in lora_bias_names: + to_return[bias_name] = t + else: + raise NotImplementedError + to_return = {k: maybe_zero_3(v) for k, v in to_return.items()} + return to_return + + +def train(): + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments, LoraArguments) + ) + ( + model_args, + data_args, + training_args, + lora_args, + ) = parser.parse_args_into_dataclasses() + + if training_args.flash_attn: + replace_llama_attn_with_flash_attn() + + device_map = None + world_size = int(os.environ.get("WORLD_SIZE", 1)) + ddp = world_size != 1 + if lora_args.q_lora: + device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None + if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled(): + logging.warning( + "FSDP and ZeRO3 are both currently incompatible with QLoRA." + ) + + compute_dtype = ( + torch.float16 + if training_args.fp16 + else (torch.bfloat16 if training_args.bf16 else torch.float32) + ) + + model = transformers.AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + device_map=device_map, + quantization_config=BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=compute_dtype, + ) + if lora_args.q_lora + else None, + ) + lora_config = LoraConfig( + r=lora_args.lora_r, + lora_alpha=lora_args.lora_alpha, + target_modules=lora_args.lora_target_modules, + lora_dropout=lora_args.lora_dropout, + bias=lora_args.lora_bias, + task_type="CAUSAL_LM", + ) + + if lora_args.q_lora: + model = prepare_model_for_kbit_training( + model, use_gradient_checkpointing=training_args.gradient_checkpointing + ) + if not ddp and torch.cuda.device_count() > 1: + # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available + model.is_parallelizable = True + model.model_parallel = True + + model = get_peft_model(model, lora_config) + if training_args.flash_attn: + for name, module in model.named_modules(): + if "norm" in name: + module = module.to(compute_dtype) + if "lm_head" in name or "embed_tokens" in name: + if hasattr(module, "weight"): + module = module.to(compute_dtype) + if training_args.deepspeed is not None and training_args.local_rank == 0: + model.print_trainable_parameters() + + if training_args.gradient_checkpointing: + model.enable_input_require_grads() + + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="right", + use_fast=False, + ) + tokenizer.pad_token = tokenizer.unk_token + + data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) + trainer = Trainer( + model=model, tokenizer=tokenizer, args=training_args, **data_module + ) + + model.config.use_cache = False + + if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + trainer.save_state() + + # check if zero3 mode enabled + if deepspeed.is_deepspeed_zero3_enabled(): + # use deepspeed engine internal function to gather state dict + # state_dict_zero3 contains whole parameters of base and lora adapters + # we will not extract lora parameters since peft save_pretrained will do that + # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/peft_model.py#L125 + # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/utils/save_and_load.py#L19 + state_dict_zero3 = trainer.model_wrapped._zero3_consolidated_16bit_state_dict() + if training_args.local_rank == 0: + state_dict = state_dict_zero3 + else: + # in other mode we use original code from fastchat team, to make sure our change is minimum + state_dict = get_peft_state_maybe_zero_3( + model.named_parameters(), lora_args.lora_bias + ) + + if training_args.local_rank == 0: + model.save_pretrained(training_args.output_dir, state_dict=state_dict) + + +if __name__ == "__main__": + train() diff --git a/third_party/FastChat/fastchat/train/train_lora_t5.py b/third_party/FastChat/fastchat/train/train_lora_t5.py new file mode 100644 index 00000000..21abc92c --- /dev/null +++ b/third_party/FastChat/fastchat/train/train_lora_t5.py @@ -0,0 +1,226 @@ +# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright: +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +import copy +import os +from dataclasses import dataclass, field +import random +import json +import logging +import pathlib +from typing import Dict, Optional, Sequence, List + +import torch +import torch.distributed as dist + + +from deepspeed import zero +from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus +from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType + +import transformers +from torch.utils.data import Dataset +from transformers import Trainer, AddedToken, BitsAndBytesConfig, deepspeed + +from fastchat.train.train_flant5 import ( + smart_tokenizer_and_embedding_resize, + make_supervised_data_module, +) + +from fastchat.train.train_lora import get_peft_state_maybe_zero_3 + +from fastchat.model.model_adapter import get_conversation_template + +default_conversation = get_conversation_template("t5") + +# TODO: import and use code from ../data/dataset.py + +IGNORE_INDEX = -100 +DEFAULT_PAD_TOKEN = "[PAD]" +DEFAULT_EOS_TOKEN = "" +DEFAULT_BOS_TOKEN = "" +DEFAULT_UNK_TOKEN = "" + + +@dataclass +class LoraArguments: + lora_r: int = 8 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_target_modules: List[str] = field(default_factory=lambda: ["q", "v"]) + lora_weight_path: str = "" + lora_bias: str = "none" + q_lora: bool = False + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + + +@dataclass +class DataArguments: + data_path: str = field( + default=None, metadata={"help": "Path to the training data."} + ) + lazy_preprocess: bool = False + num_data: int = -1 + preprocessed_path: str = field( + default=None, metadata={"help": "Path to the preprocessed training data."} + ) + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + model_max_length: int = field( + default=2048, + metadata={ + "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + + +def safe_save_model_for_hf_trainer( + trainer: transformers.Trainer, output_dir: str, state_dict: dict +): + """Collects the state dict and dump to disk.""" + + if trainer.args.should_save: + cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} + del state_dict + trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + + +def train(): + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments, LoraArguments) + ) + ( + model_args, + data_args, + training_args, + lora_args, + ) = parser.parse_args_into_dataclasses() + + device_map = None + world_size = int(os.environ.get("WORLD_SIZE", 1)) + ddp = world_size != 1 + if lora_args.q_lora: + device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None + if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled(): + logging.warning( + "FSDP and ZeRO3 are both currently incompatible with QLoRA." + ) + + compute_dtype = ( + torch.float16 + if training_args.fp16 + else (torch.bfloat16 if training_args.bf16 else torch.float32) + ) + + model = transformers.AutoModelForSeq2SeqLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + device_map=device_map, + quantization_config=BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=compute_dtype, + ) + if lora_args.q_lora + else None, + ) + + lora_config = LoraConfig( + r=lora_args.lora_r, + lora_alpha=lora_args.lora_alpha, + target_modules=lora_args.lora_target_modules, + lora_dropout=lora_args.lora_dropout, + bias=lora_args.lora_bias, + task_type=TaskType.SEQ_2_SEQ_LM, + ) + + if lora_args.q_lora: + model = prepare_model_for_kbit_training( + model, use_gradient_checkpointing=training_args.gradient_checkpointing + ) + if not ddp and torch.cuda.device_count() > 1: + # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available + model.is_parallelizable = True + model.model_parallel = True + + model = get_peft_model(model, lora_config) + if training_args.deepspeed is not None and training_args.local_rank == 0: + model.print_trainable_parameters() + + if training_args.gradient_checkpointing: + model.enable_input_require_grads() + + # Dacheng: Note we can only use T5Tokenizer, otherwise it will prepend + # a space before special tokens. + tokenizer = transformers.T5Tokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="right", + use_fast=False, + ) + + smart_tokenizer_and_embedding_resize( + special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), + other_tokens=["<", "{", "\n", "}", "`", " ", "\\", "^", "\t"], + tokenizer=tokenizer, + model=model, + ) + + data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) + + trainer = Trainer( + model=model, tokenizer=tokenizer, args=training_args, **data_module + ) + + if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + trainer.save_state() + # check if zero3 mode enabled + if deepspeed.is_deepspeed_zero3_enabled(): + # use deepspeed engine internal function to gather state dict + # state_dict_zero3 contains whole parameters of base and lora adapters + # we will not extract lora parameters since peft save_pretrained will do that + # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/peft_model.py#L125 + # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/utils/save_and_load.py#L19 + state_dict_zero3 = trainer.model_wrapped._zero3_consolidated_16bit_state_dict() + if training_args.local_rank == 0: + state_dict = state_dict_zero3 + else: + # in other mode we use original code from fastchat team, to make sure our change is minimum + state_dict = get_peft_state_maybe_zero_3( + model.named_parameters(), lora_args.lora_bias + ) + + if training_args.local_rank == 0: + safe_save_model_for_hf_trainer( + trainer=trainer, output_dir=training_args.output_dir, state_dict=state_dict + ) + + +if __name__ == "__main__": + train() diff --git a/third_party/FastChat/fastchat/train/train_mem.py b/third_party/FastChat/fastchat/train/train_mem.py new file mode 100644 index 00000000..9ce4913a --- /dev/null +++ b/third_party/FastChat/fastchat/train/train_mem.py @@ -0,0 +1,13 @@ +# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. + +# Need to call this before importing transformers. +from fastchat.train.llama2_flash_attn_monkey_patch import ( + replace_llama_attn_with_flash_attn, +) + +replace_llama_attn_with_flash_attn() + +from fastchat.train.train import train + +if __name__ == "__main__": + train() diff --git a/third_party/FastChat/fastchat/train/train_xformers.py b/third_party/FastChat/fastchat/train/train_xformers.py new file mode 100644 index 00000000..0eb2badd --- /dev/null +++ b/third_party/FastChat/fastchat/train/train_xformers.py @@ -0,0 +1,13 @@ +# Make it more memory efficient by monkey patching the LLaMA model with xformers attention. + +# Need to call this before importing transformers. +from fastchat.train.llama_xformers_attn_monkey_patch import ( + replace_llama_attn_with_xformers_attn, +) + +replace_llama_attn_with_xformers_attn() + +from fastchat.train.train import train + +if __name__ == "__main__": + train() diff --git a/third_party/FastChat/fastchat/utils.py b/third_party/FastChat/fastchat/utils.py new file mode 100644 index 00000000..45d4323d --- /dev/null +++ b/third_party/FastChat/fastchat/utils.py @@ -0,0 +1,351 @@ +""" +Common utilities. +""" +from asyncio import AbstractEventLoop +import json +import logging +import logging.handlers +import os +import platform +import sys +from typing import AsyncGenerator, Generator +import warnings + +import requests + +from fastchat.constants import LOGDIR + + +handler = None +visited_loggers = set() + + +def build_logger(logger_name, logger_filename): + global handler + + formatter = logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # Set the format of root handlers + if not logging.getLogger().handlers: + if sys.version_info[1] >= 9: + # This is for windows + logging.basicConfig(level=logging.INFO, encoding="utf-8") + else: + if platform.system() == "Windows": + warnings.warn( + "If you are running on Windows, " + "we recommend you use Python >= 3.9 for UTF-8 encoding." + ) + logging.basicConfig(level=logging.INFO) + logging.getLogger().handlers[0].setFormatter(formatter) + + # Redirect stdout and stderr to loggers + stdout_logger = logging.getLogger("stdout") + stdout_logger.setLevel(logging.INFO) + sl = StreamToLogger(stdout_logger, logging.INFO) + sys.stdout = sl + + stderr_logger = logging.getLogger("stderr") + stderr_logger.setLevel(logging.ERROR) + sl = StreamToLogger(stderr_logger, logging.ERROR) + sys.stderr = sl + + # Get logger + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + + # if LOGDIR is empty, then don't try output log to local file + if LOGDIR != "": + os.makedirs(LOGDIR, exist_ok=True) + filename = os.path.join(LOGDIR, logger_filename) + handler = logging.handlers.TimedRotatingFileHandler( + filename, when="D", utc=True, encoding="utf-8" + ) + handler.setFormatter(formatter) + + for l in [stdout_logger, stderr_logger, logger]: + if l in visited_loggers: + continue + visited_loggers.add(l) + l.addHandler(handler) + + return logger + + +class StreamToLogger(object): + """ + Fake file-like stream object that redirects writes to a logger instance. + """ + + def __init__(self, logger, log_level=logging.INFO): + self.terminal = sys.stdout + self.logger = logger + self.log_level = log_level + self.linebuf = "" + + def __getattr__(self, attr): + return getattr(self.terminal, attr) + + def write(self, buf): + temp_linebuf = self.linebuf + buf + self.linebuf = "" + for line in temp_linebuf.splitlines(True): + # From the io.TextIOWrapper docs: + # On output, if newline is None, any '\n' characters written + # are translated to the system default line separator. + # By default sys.stdout.write() expects '\n' newlines and then + # translates them so this is still cross platform. + if line[-1] == "\n": + encoded_message = line.encode("utf-8", "ignore").decode("utf-8") + self.logger.log(self.log_level, encoded_message.rstrip()) + else: + self.linebuf += line + + def flush(self): + if self.linebuf != "": + encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8") + self.logger.log(self.log_level, encoded_message.rstrip()) + self.linebuf = "" + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. + """ + import torch + + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def get_gpu_memory(max_gpus=None): + """Get available memory for each GPU.""" + import torch + + gpu_memory = [] + num_gpus = ( + torch.cuda.device_count() + if max_gpus is None + else min(max_gpus, torch.cuda.device_count()) + ) + + for gpu_id in range(num_gpus): + with torch.cuda.device(gpu_id): + device = torch.cuda.current_device() + gpu_properties = torch.cuda.get_device_properties(device) + total_memory = gpu_properties.total_memory / (1024**3) + allocated_memory = torch.cuda.memory_allocated() / (1024**3) + available_memory = total_memory - allocated_memory + gpu_memory.append(available_memory) + return gpu_memory + + +def oai_moderation(text): + """ + Check whether the text violates OpenAI moderation API. + """ + import openai + + openai.api_base = "https://api.openai.com/v1" + openai.api_key = os.environ["OPENAI_API_KEY"] + openai.api_type = "open_ai" + openai.api_version = None + + MAX_RETRY = 3 + for i in range(MAX_RETRY): + try: + res = openai.Moderation.create(input=text) + flagged = res["results"][0]["flagged"] + break + except (openai.error.OpenAIError, KeyError, IndexError) as e: + # flag true to be conservative + flagged = True + print(f"MODERATION ERROR: {e}\nInput: {text}") + return flagged + + +def moderation_filter(text, model_list): + MODEL_KEYWORDS = ["claude"] + + for keyword in MODEL_KEYWORDS: + for model in model_list: + if keyword in model and oai_moderation(text): + return True + return False + + +def clean_flant5_ckpt(ckpt_path): + """ + Flan-t5 trained with HF+FSDP saves corrupted weights for shared embeddings, + Use this function to make sure it can be correctly loaded. + """ + import torch + + index_file = os.path.join(ckpt_path, "pytorch_model.bin.index.json") + index_json = json.load(open(index_file, "r")) + + weightmap = index_json["weight_map"] + + share_weight_file = weightmap["shared.weight"] + share_weight = torch.load(os.path.join(ckpt_path, share_weight_file))[ + "shared.weight" + ] + + for weight_name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]: + weight_file = weightmap[weight_name] + weight = torch.load(os.path.join(ckpt_path, weight_file)) + weight[weight_name] = share_weight + torch.save(weight, os.path.join(ckpt_path, weight_file)) + + +def pretty_print_semaphore(semaphore): + """Print a semaphore in better format.""" + if semaphore is None: + return "None" + return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" + + +"""A javascript function to get url parameters for the gradio web server.""" +get_window_url_params_js = """ +function() { + const params = new URLSearchParams(window.location.search); + url_params = Object.fromEntries(params); + console.log("url_params", url_params); + return url_params; + } +""" + + +get_window_url_params_with_tos_js = """ +function() { + const params = new URLSearchParams(window.location.search); + url_params = Object.fromEntries(params); + console.log("url_params", url_params); + + msg = "Users of this website are required to agree to the following terms:\\n\\nThe service is a research preview. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.\\nThe service collects user dialogue data and reserves the right to distribute it under a Creative Commons Attribution (CC-BY) or a similar license." + alert(msg); + + return url_params; + } +""" + + +def iter_over_async( + async_gen: AsyncGenerator, event_loop: AbstractEventLoop +) -> Generator: + """ + Convert async generator to sync generator + + :param async_gen: the AsyncGenerator to convert + :param event_loop: the event loop to run on + :returns: Sync generator + """ + ait = async_gen.__aiter__() + + async def get_next(): + try: + obj = await ait.__anext__() + return False, obj + except StopAsyncIteration: + return True, None + + while True: + done, obj = event_loop.run_until_complete(get_next()) + if done: + break + yield obj + + +def detect_language(text: str) -> str: + """Detect the langauge of a string.""" + import polyglot # pip3 install polyglot pyicu pycld2 + from polyglot.detect import Detector + from polyglot.detect.base import logger as polyglot_logger + import pycld2 + + polyglot_logger.setLevel("ERROR") + + try: + lang_code = Detector(text).language.name + except (pycld2.error, polyglot.detect.base.UnknownLanguage): + lang_code = "unknown" + return lang_code + + +def parse_gradio_auth_creds(filename: str): + """Parse a username:password file for gradio authorization.""" + gradio_auth_creds = [] + with open(filename, "r", encoding="utf8") as file: + for line in file.readlines(): + gradio_auth_creds += [x.strip() for x in line.split(",") if x.strip()] + if gradio_auth_creds: + auth = [tuple(cred.split(":")) for cred in gradio_auth_creds] + else: + auth = None + return auth + + +def is_partial_stop(output: str, stop_str: str): + """Check whether the output contains a partial stop str.""" + for i in range(0, min(len(output), len(stop_str))): + if stop_str.startswith(output[-i:]): + return True + return False + + +def run_cmd(cmd: str): + """Run a bash command.""" + print(cmd) + return os.system(cmd) + + +def is_sentence_complete(output: str): + """Check whether the output is a complete sentence.""" + end_symbols = (".", "?", "!", "...", "。", "?", "!", "…", '"', "'", "”") + return output.endswith(end_symbols) + + +# Models don't use the same configuration key for determining the maximum +# sequence length. Store them here so we can sanely check them. +# NOTE: The ordering here is important. Some models have two of these and we +# have a preference for which value gets used. +SEQUENCE_LENGTH_KEYS = [ + "max_position_embeddings", + "max_sequence_length", + "seq_length", + "max_seq_len", + "model_max_length", +] + + +def get_context_length(config): + """Get the context length of a model from a huggingface model config.""" + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling: + rope_scaling_factor = config.rope_scaling["factor"] + else: + rope_scaling_factor = 1 + + for key in SEQUENCE_LENGTH_KEYS: + val = getattr(config, key, None) + if val is not None: + return int(rope_scaling_factor * val) + return 2048 + + +def str_to_torch_dtype(dtype: str): + import torch + + if dtype is None: + return None + elif dtype == "float32": + return torch.float32 + elif dtype == "float16": + return torch.float16 + elif dtype == "bfloat16": + return torch.bfloat16 + else: + raise ValueError(f"Unrecognized dtype: {dtype}") diff --git a/third_party/FastChat/format.sh b/third_party/FastChat/format.sh new file mode 100755 index 00000000..798b6986 --- /dev/null +++ b/third_party/FastChat/format.sh @@ -0,0 +1,71 @@ +#!/usr/bin/env bash + +# Adapted from https://github.com/skypilot-org/skypilot/blob/master/format.sh + +# Cause the script to exit if a single command fails +set -eo pipefail + +# this stops git rev-parse from failing if we run this from the .git directory +builtin cd "$(dirname "${BASH_SOURCE:-$0}")" +ROOT="$(git rev-parse --show-toplevel)" +builtin cd "$ROOT" || exit 1 + +BLACK_VERSION=$(black --version | head -n 1 | awk '{print $2}') +PYLINT_VERSION=$(pylint --version | head -n 1 | awk '{print $2}') + +# # params: tool name, tool version, required version +tool_version_check() { + if [[ $2 != $3 ]]; then + echo "Wrong $1 version installed: $3 is required, not $2." + exit 1 + fi +} + +tool_version_check "black" $BLACK_VERSION "23.3.0" +tool_version_check "pylint" $PYLINT_VERSION "2.8.2" + +# Format files that differ from main branch. Ignores dirs that are not slated +# for autoformat yet. +format_changed() { + # The `if` guard ensures that the list of filenames is not empty, which + # could cause yapf to receive 0 positional arguments, making it hang + # waiting for STDIN. + # + # `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that + # exist on both branches. + MERGEBASE="$(git merge-base origin/main HEAD)" + + if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then + git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 black + fi +} + +## This flag formats individual files. --files *must* be the first command line +## arg to use this option. +if [[ "$1" == '--files' ]]; then + black "${@:2}" + # If `--all` is passed, then any further arguments are ignored and the + # entire python directory is formatted. +elif [[ "$1" == '--all' ]]; then + # Format all files + black fastchat +else + # Format only the files that changed in last commit. + format_changed +fi +echo 'FastChat Black: Done' + +# Run Pylint +echo 'FastChat Pylint:' +pylint fastchat +# TODO(suquark): disable 'pylint_quotes' for now due to too many inconsistent quotes +# pylint --load-plugins pylint_quotes fastchat + +if ! git diff --quiet &>/dev/null; then + echo 'Reformatted files. Please review and stage the changes.' + echo 'Changes not staged for commit:' + echo + git --no-pager diff --name-only + + exit 1 +fi diff --git a/third_party/FastChat/playground/deepspeed_config_s2.json b/third_party/FastChat/playground/deepspeed_config_s2.json new file mode 100644 index 00000000..4ade1550 --- /dev/null +++ b/third_party/FastChat/playground/deepspeed_config_s2.json @@ -0,0 +1,15 @@ +{ + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "cpu" + }, + "contiguous_gradients": true, + "overlap_comm": true + }, + "fp16": { + "enabled": "auto" + }, + "train_micro_batch_size_per_gpu": "auto", + "gradient_accumulation_steps": "auto" +} \ No newline at end of file diff --git a/third_party/FastChat/playground/deepspeed_config_s3.json b/third_party/FastChat/playground/deepspeed_config_s3.json new file mode 100644 index 00000000..07f4b16a --- /dev/null +++ b/third_party/FastChat/playground/deepspeed_config_s3.json @@ -0,0 +1,32 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "cpu", + "pin_memory": true + }, + "offload_param": { + "device": "cpu", + "pin_memory": true + }, + "overlap_comm": true, + "contiguous_gradients": true, + "stage3_max_live_parameters" : 1e9, + "stage3_max_reuse_distance" : 1e9, + "stage3_prefetch_bucket_size" : 5e8, + "stage3_param_persistence_threshold" : 1e6, + "sub_group_size" : 1e12, + "stage3_gather_16bit_weights_on_model_save": true + }, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "gradient_accumulation_steps": "auto" +} \ No newline at end of file diff --git a/third_party/FastChat/playground/test_embedding/README.md b/third_party/FastChat/playground/test_embedding/README.md new file mode 100644 index 00000000..57ac73c5 --- /dev/null +++ b/third_party/FastChat/playground/test_embedding/README.md @@ -0,0 +1,15 @@ +## Machine Learning with Embeddings +You can use embeddings to +- Evaluate text similarity, see [test_sentence_similarity.py](test_sentence_similarity.py) +- Build your own classifier, see [test_classification.py](test_classification.py) +- Search relative texts, see [test_semantic_search.py](test_semantic_search.py) + +To these tests, you need to download the data [here](https://www.kaggle.com/datasets/snap/amazon-fine-food-reviews). You also need an OpenAI API key for comparison. + +Run with: +```bash +cd playground/test_embedding +python3 test_classification.py +``` + +The script will train classifiers based on `vicuna-7b`, `text-similarity-ada-001` and `text-embedding-ada-002` and report the accuracy of each classifier. diff --git a/third_party/FastChat/playground/test_embedding/test_classification.py b/third_party/FastChat/playground/test_embedding/test_classification.py new file mode 100644 index 00000000..393827bb --- /dev/null +++ b/third_party/FastChat/playground/test_embedding/test_classification.py @@ -0,0 +1,83 @@ +import json +import os + +import numpy as np +import openai +import pandas as pd +import requests +from sklearn.ensemble import RandomForestClassifier +from sklearn.model_selection import train_test_split +from sklearn.metrics import classification_report, accuracy_score + + +np.set_printoptions(threshold=10000) + + +def get_embedding_from_api(word, model="vicuna-7b-v1.1"): + if "ada" in model: + resp = openai.Embedding.create( + model=model, + input=word, + ) + embedding = np.array(resp["data"][0]["embedding"]) + return embedding + + url = "http://localhost:8000/v1/embeddings" + headers = {"Content-Type": "application/json"} + data = json.dumps({"model": model, "input": word}) + + response = requests.post(url, headers=headers, data=data) + if response.status_code == 200: + embedding = np.array(response.json()["data"][0]["embedding"]) + return embedding + else: + print(f"Error: {response.status_code} - {response.text}") + return None + + +def create_embedding_data_frame(data_path, model, max_tokens=500): + df = pd.read_csv(data_path, index_col=0) + df = df[["Time", "ProductId", "UserId", "Score", "Summary", "Text"]] + df = df.dropna() + df["combined"] = ( + "Title: " + df.Summary.str.strip() + "; Content: " + df.Text.str.strip() + ) + top_n = 1000 + df = df.sort_values("Time").tail(top_n * 2) + df.drop("Time", axis=1, inplace=True) + + df["n_tokens"] = df.combined.apply(lambda x: len(x)) + df = df[df.n_tokens <= max_tokens].tail(top_n) + df["embedding"] = df.combined.apply(lambda x: get_embedding_from_api(x, model)) + return df + + +def train_random_forest(df): + X_train, X_test, y_train, y_test = train_test_split( + list(df.embedding.values), df.Score, test_size=0.2, random_state=42 + ) + + clf = RandomForestClassifier(n_estimators=100) + clf.fit(X_train, y_train) + preds = clf.predict(X_test) + + report = classification_report(y_test, preds) + accuracy = accuracy_score(y_test, preds) + return clf, accuracy, report + + +input_datapath = "amazon_fine_food_review.csv" +if not os.path.exists(input_datapath): + raise Exception( + f"Please download data from: https://www.kaggle.com/datasets/snap/amazon-fine-food-reviews" + ) + +df = create_embedding_data_frame(input_datapath, "vicuna-7b-v1.1") +clf, accuracy, report = train_random_forest(df) +print(f"Vicuna-7b-v1.1 accuracy:{accuracy}") +df = create_embedding_data_frame(input_datapath, "text-similarity-ada-001") +clf, accuracy, report = train_random_forest(df) +print(f"text-similarity-ada-001 accuracy:{accuracy}") +df = create_embedding_data_frame(input_datapath, "text-embedding-ada-002") +clf, accuracy, report = train_random_forest(df) +print(f"text-embedding-ada-002 accuracy:{accuracy}") diff --git a/third_party/FastChat/playground/test_embedding/test_semantic_search.py b/third_party/FastChat/playground/test_embedding/test_semantic_search.py new file mode 100644 index 00000000..879b240b --- /dev/null +++ b/third_party/FastChat/playground/test_embedding/test_semantic_search.py @@ -0,0 +1,99 @@ +import json +import os + +import numpy as np +import openai +import pandas as pd +import requests +from scipy.spatial.distance import cosine + + +def cosine_similarity(vec1, vec2): + try: + return 1 - cosine(vec1, vec2) + except: + print(vec1.shape, vec2.shape) + + +def get_embedding_from_api(word, model="vicuna-7b-v1.1"): + if "ada" in model: + resp = openai.Embedding.create( + model=model, + input=word, + ) + embedding = np.array(resp["data"][0]["embedding"]) + return embedding + + url = "http://localhost:8000/v1/embeddings" + headers = {"Content-Type": "application/json"} + data = json.dumps({"model": model, "input": word}) + + response = requests.post(url, headers=headers, data=data) + if response.status_code == 200: + embedding = np.array(response.json()["data"][0]["embedding"]) + return embedding + else: + print(f"Error: {response.status_code} - {response.text}") + return None + + +def create_embedding_data_frame(data_path, model, max_tokens=500): + df = pd.read_csv(data_path, index_col=0) + df = df[["Time", "ProductId", "UserId", "Score", "Summary", "Text"]] + df = df.dropna() + df["combined"] = ( + "Title: " + df.Summary.str.strip() + "; Content: " + df.Text.str.strip() + ) + top_n = 1000 + df = df.sort_values("Time").tail(top_n * 2) + df.drop("Time", axis=1, inplace=True) + + df["n_tokens"] = df.combined.apply(lambda x: len(x)) + df = df[df.n_tokens <= max_tokens].tail(top_n) + df["embedding"] = df.combined.apply(lambda x: get_embedding_from_api(x, model)) + return df + + +def search_reviews(df, product_description, n=3, pprint=False, model="vicuna-7b-v1.1"): + product_embedding = get_embedding_from_api(product_description, model=model) + df["similarity"] = df.embedding.apply( + lambda x: cosine_similarity(x, product_embedding) + ) + + results = ( + df.sort_values("similarity", ascending=False) + .head(n) + .combined.str.replace("Title: ", "") + .str.replace("; Content:", ": ") + ) + if pprint: + for r in results: + print(r[:200]) + print() + return results + + +def print_model_search(input_path, model): + print(f"Model: {model}") + df = create_embedding_data_frame(input_path, model) + print("search: delicious beans") + results = search_reviews(df, "delicious beans", n=5, model=model) + print(results) + print("search: whole wheat pasta") + results = search_reviews(df, "whole wheat pasta", n=5, model=model) + print(results) + print("search: bad delivery") + results = search_reviews(df, "bad delivery", n=5, model=model) + print(results) + + +input_datapath = "amazon_fine_food_review.csv" +if not os.path.exists(input_datapath): + raise Exception( + f"Please download data from: https://www.kaggle.com/datasets/snap/amazon-fine-food-reviews" + ) + + +print_model_search(input_datapath, "vicuna-7b-v1.1") +print_model_search(input_datapath, "text-similarity-ada-001") +print_model_search(input_datapath, "text-embedding-ada-002") diff --git a/third_party/FastChat/playground/test_embedding/test_sentence_similarity.py b/third_party/FastChat/playground/test_embedding/test_sentence_similarity.py new file mode 100644 index 00000000..d7a8f6e5 --- /dev/null +++ b/third_party/FastChat/playground/test_embedding/test_sentence_similarity.py @@ -0,0 +1,67 @@ +import json +import os + +import numpy as np +import openai +import requests +from scipy.spatial.distance import cosine + + +def get_embedding_from_api(word, model="vicuna-7b-v1.5"): + if "ada" in model: + resp = openai.Embedding.create( + model=model, + input=word, + ) + embedding = np.array(resp["data"][0]["embedding"]) + return embedding + + url = "http://localhost:8000/v1/embeddings" + headers = {"Content-Type": "application/json"} + data = json.dumps({"model": model, "input": word}) + + response = requests.post(url, headers=headers, data=data) + if response.status_code == 200: + embedding = np.array(response.json()["data"][0]["embedding"]) + return embedding + else: + print(f"Error: {response.status_code} - {response.text}") + return None + + +def cosine_similarity(vec1, vec2): + return 1 - cosine(vec1, vec2) + + +def print_cosine_similarity(embeddings, texts): + for i in range(len(texts)): + for j in range(i + 1, len(texts)): + sim = cosine_similarity(embeddings[texts[i]], embeddings[texts[j]]) + print(f"Cosine similarity between '{texts[i]}' and '{texts[j]}': {sim:.2f}") + + +texts = [ + "The quick brown fox", + "The quick brown dog", + "The fast brown fox", + "A completely different sentence", +] + +embeddings = {} +for text in texts: + embeddings[text] = get_embedding_from_api(text) + +print("Vicuna-7B:") +print_cosine_similarity(embeddings, texts) + +for text in texts: + embeddings[text] = get_embedding_from_api(text, model="text-similarity-ada-001") + +print("text-similarity-ada-001:") +print_cosine_similarity(embeddings, texts) + +for text in texts: + embeddings[text] = get_embedding_from_api(text, model="text-embedding-ada-002") + +print("text-embedding-ada-002:") +print_cosine_similarity(embeddings, texts) diff --git a/third_party/FastChat/pyproject.toml b/third_party/FastChat/pyproject.toml new file mode 100644 index 00000000..ded93bde --- /dev/null +++ b/third_party/FastChat/pyproject.toml @@ -0,0 +1,36 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "fschat" +version = "0.2.35" +description = "An open platform for training, serving, and evaluating large language model based chatbots." +readme = "README.md" +requires-python = ">=3.8" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", +] +dependencies = [ + "aiohttp", "fastapi", "httpx", "markdown2[all]", "nh3", "numpy", + "prompt_toolkit>=3.0.0", "pydantic<2,>=1", "requests", "rich>=10.0.0", + "shortuuid", "tiktoken", "uvicorn", +] + +[project.optional-dependencies] +model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf"] +webui = ["gradio"] +train = ["einops", "flash-attn>=2.0", "wandb"] +llm_judge = ["openai<1", "anthropic>=0.3", "ray"] +dev = ["black==23.3.0", "pylint==2.8.2"] + +[project.urls] +"Homepage" = "https://github.com/lm-sys/fastchat" +"Bug Tracker" = "https://github.com/lm-sys/fastchat/issues" + +[tool.setuptools.packages.find] +exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] + +[tool.wheel] +exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] diff --git a/third_party/FastChat/scripts/build-api.sh b/third_party/FastChat/scripts/build-api.sh new file mode 100644 index 00000000..8198108e --- /dev/null +++ b/third_party/FastChat/scripts/build-api.sh @@ -0,0 +1,60 @@ +#!/bin/bash +# A rather convenient script for spinning up models behind screens + + +# Variables +PROJECT_DIR="$(pwd)" +CONDA_ENV_NAME="fastchat" # + +MODEL_PATH="HuggingFaceH4/zephyr-7b-beta" #beta is better than the alpha version, base model w/o quantization +MODEL_PATH="lmsys/vicuna-7b-v1.5" + +API_HOST="0.0.0.0" +API_PORT_NUMBER=8000 + + +# init the screens +check_and_create_screen() { + local SCREENNAME="$1" + if screen -list | grep -q "$SCREENNAME"; then + echo "Screen session '$SCREENNAME' exists. Doing nothing." + else + echo "Screen session '$SCREENNAME' not found. Creating..." + screen -d -m -S "$SCREENNAME" + echo "created!" + fi +} + +# convenience function for sending commands to named screens +send_cmd() { + local SCREENNAME="$1" + local CMD="$2" + screen -DRRS $SCREENNAME -X stuff '$2 \r' +} + +# hardcoded names, for baby api +SCREENNAMES=( + "controller" + "api" + # Worker screens include the devices they are bound to, if 'd0' is only worker it has full GPU access + "worker-d0" + "worker-d1" +) + +for screen in "${SCREENNAMES[@]}"; do + check_and_create_screen "$screen" + sleep 0.1 + # also activate the conda compute environment for these + screen -DRRS "$screen" -X stuff "conda deactivate \r" + screen -DRRS "$screen" -X stuff "conda activate $CONDA_ENV_NAME \r" + +done + + +# Send Commmands on a per Screen Basis +screen -DRRS controller -X stuff "python3 -m fastchat.serve.controller \r" + +screen -DRRS worker-d0 -X stuff "CUDA_VISIBLE_DEVICES=0 python3 -m fastchat.serve.model_worker --model-path $MODEL_PATH --conv-template one_shot --limit-worker-concurrency 1 \r" +screen -DRRS worker-d1 -X stuff "CUDA_VISIBLE_DEVICES=1 python3 -m fastchat.serve.model_worker --model-path $MODEL_PATH --port 21003 --worker-address http://localhost:21003 --conv-template one_shot --limit-worker-concurrency 1 \r" + +screen -DRRS api -X stuff "python3 -m fastchat.serve.openai_api_server --host $API_HOST --port $API_PORT_NUMBER \r" diff --git a/third_party/FastChat/scripts/test_readme_train.sh b/third_party/FastChat/scripts/test_readme_train.sh new file mode 100644 index 00000000..1d1a0113 --- /dev/null +++ b/third_party/FastChat/scripts/test_readme_train.sh @@ -0,0 +1,24 @@ +torchrun --nproc_per_node=4 --master_port=20001 fastchat/train/train_mem.py \ + --model_name_or_path meta-llama/Llama-2-7b-hf \ + --data_path data/dummy_conversation.json \ + --bf16 True \ + --output_dir output_vicuna \ + --num_train_epochs 3 \ + --per_device_train_batch_size 2 \ + --per_device_eval_batch_size 2 \ + --gradient_accumulation_steps 16 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 1200 \ + --save_total_limit 10 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --fsdp "full_shard auto_wrap" \ + --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --lazy_preprocess True diff --git a/third_party/FastChat/scripts/train_lora.sh b/third_party/FastChat/scripts/train_lora.sh new file mode 100644 index 00000000..d30caad4 --- /dev/null +++ b/third_party/FastChat/scripts/train_lora.sh @@ -0,0 +1,29 @@ +deepspeed fastchat/train/train_lora.py \ + --model_name_or_path lmsys/vicuna-7b-v1.5 \ + --lora_r 8 \ + --lora_alpha 16 \ + --lora_dropout 0.05 \ + --data_path $DATA_PATH \ + --output_dir ./checkpoints \ + --num_train_epochs 150 \ + --fp16 True \ + --per_device_train_batch_size 2 \ + --per_device_eval_batch_size 2 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "steps" \ + --eval_steps 100 \ + --save_strategy "steps" \ + --save_steps 200 \ + --save_total_limit 2 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_strategy "steps" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --q_lora False \ + --deepspeed $PATH_TO_DEEPSPEED_CONFIG \ + --gradient_checkpointing True \ + --flash_attn False diff --git a/third_party/FastChat/scripts/train_vicuna_13b.sh b/third_party/FastChat/scripts/train_vicuna_13b.sh new file mode 100644 index 00000000..a6a843d3 --- /dev/null +++ b/third_party/FastChat/scripts/train_vicuna_13b.sh @@ -0,0 +1,26 @@ +torchrun --nproc_per_node=8 --master_port=20001 fastchat/train/train_mem.py \ + --model_name_or_path ~/model_weights/llama-13b \ + --data_path ~/datasets/sharegpt_20230422_clean_lang_split_identity.json \ + --bf16 True \ + --output_dir output_vicuna_13b \ + --num_train_epochs 3 \ + --per_device_train_batch_size 4 \ + --per_device_eval_batch_size 32 \ + --gradient_accumulation_steps 4 \ + --evaluation_strategy "steps" \ + --eval_steps 1500 \ + --save_strategy "steps" \ + --save_steps 1500 \ + --save_total_limit 8 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.04 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --fsdp "full_shard auto_wrap offload" \ + --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --lazy_preprocess True + diff --git a/third_party/FastChat/scripts/train_vicuna_7b.sh b/third_party/FastChat/scripts/train_vicuna_7b.sh new file mode 100644 index 00000000..8d1a45ac --- /dev/null +++ b/third_party/FastChat/scripts/train_vicuna_7b.sh @@ -0,0 +1,26 @@ +torchrun --nproc_per_node=4 --master_port=20001 fastchat/train/train_mem.py \ + --model_name_or_path ~/model_weights/llama-7b \ + --data_path ~/datasets/sharegpt_20230422_clean_lang_split_identity.json \ + --bf16 True \ + --output_dir output_vicuna_7b \ + --num_train_epochs 3 \ + --per_device_train_batch_size 2 \ + --per_device_eval_batch_size 16 \ + --gradient_accumulation_steps 16 \ + --evaluation_strategy "steps" \ + --eval_steps 1500 \ + --save_strategy "steps" \ + --save_steps 1500 \ + --save_total_limit 8 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.04 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --fsdp "full_shard auto_wrap" \ + --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --lazy_preprocess True + diff --git a/third_party/FastChat/scripts/upload_pypi.sh b/third_party/FastChat/scripts/upload_pypi.sh new file mode 100644 index 00000000..b0da77ef --- /dev/null +++ b/third_party/FastChat/scripts/upload_pypi.sh @@ -0,0 +1,3 @@ +rm -rf dist +python3 -m build +python3 -m twine upload dist/* diff --git a/third_party/FastChat/tests/README.md b/third_party/FastChat/tests/README.md new file mode 100644 index 00000000..3d1c1e61 --- /dev/null +++ b/third_party/FastChat/tests/README.md @@ -0,0 +1,53 @@ +## Unit tests for FastChat + +### Test CLI Inference + +``` +python3 test_cli.py +``` + +### Test OpenAI API Server + +``` +python3 launch_openai_api_test_server.py +``` + +``` +python3 test_openai_api.py +``` + +### Test GUI Serving + +``` +python3 -m fastchat.serve.controller +``` + +``` +CUDA_VISIBLE_DEVICES=0,1 python3 -m fastchat.serve.model_worker --model-path ~/model_weights/koala-13b --num-gpus 2 --port 30000 --worker http://localhost:30000 +CUDA_VISIBLE_DEVICES=2,3 python3 -m fastchat.serve.model_worker --model-path ~/model_weights/alpaca-13b --num-gpus 2 --port 30002 --worker http://localhost:30002 +CUDA_VISIBLE_DEVICES=4,5 python3 -m fastchat.serve.model_worker --model-path ~/model_weights/vicuna-13b --port 30004 --worker http://localhost:30004 --num-gpus 2 +CUDA_VISIBLE_DEVICES=6,7 python3 -m fastchat.serve.model_worker --model-path OpenAssistant/oasst-sft-1-pythia-12b --port 30006 --worker http://localhost:30006 --num-gpus 2 + +CUDA_VISIBLE_DEVICES=0,1 python3 -m fastchat.serve.model_worker --model-path StabilityAI/stablelm-tuned-alpha-7b --num-gpus 2 --port 30000 --worker http://localhost:30000 +CUDA_VISIBLE_DEVICES=2,3 python3 -m fastchat.serve.model_worker --model-path databricks/dolly-v2-12b --num-gpus 2 --port 30002 --worker http://localhost:30002 +CUDA_VISIBLE_DEVICES=4 python3 -m fastchat.serve.model_worker --model-path THUDM/chatglm-6b --port 30004 --worker http://localhost:30004 +CUDA_VISIBLE_DEVICES=5 python3 -m fastchat.serve.model_worker --model-path lmsys/fastchat-t5-3b-v1.0 --port 30005 --worker http://localhost:30005 +CUDA_VISIBLE_DEVICES=6 python3 -m fastchat.serve.model_worker --model-path ~/model_weights/baize-7b --port 30006 --worker http://localhost:30006 +CUDA_VISIBLE_DEVICES=7 python3 -m fastchat.serve.model_worker --model-path ~/model_weights/RWKV-4-Raven-7B-v11x-Eng99%-Other1%-20230429-ctx8192.pth --port 30007 --worker http://localhost:30007 +``` + +``` +python3 -m fastchat.serve.gradio_web_server_multi +``` + +### Test Peft Serving + +``` +python3 -m fastchat.serve.controller +``` + +``` +PEFT_SHARE_BASE_WEIGHTS=true python3 -m fastchat.serve.multi_model_worker \ + --model-path SurfaceData/dummy_pythia160m_lora16_peft_chat \ + --model-path SurfaceData/dummy_pythia160m_lora8_peft_chat +``` diff --git a/third_party/FastChat/tests/killall_python.sh b/third_party/FastChat/tests/killall_python.sh new file mode 100644 index 00000000..52eb874a --- /dev/null +++ b/third_party/FastChat/tests/killall_python.sh @@ -0,0 +1 @@ +kill -9 $(ps aux | grep 'python' | grep 'fastchat' | grep -v 'grep' | awk '{print $2}') diff --git a/third_party/FastChat/tests/launch_openai_api_test_server.py b/third_party/FastChat/tests/launch_openai_api_test_server.py new file mode 100644 index 00000000..f555a388 --- /dev/null +++ b/third_party/FastChat/tests/launch_openai_api_test_server.py @@ -0,0 +1,35 @@ +""" +Launch an OpenAI API server with multiple model workers. +""" +import os + + +def launch_process(cmd): + os.popen(cmd) + + +if __name__ == "__main__": + launch_process("python3 -m fastchat.serve.controller") + launch_process("python3 -m fastchat.serve.openai_api_server") + + models = [ + ("lmsys/vicuna-7b-v1.5", "model_worker"), + ("lmsys/fastchat-t5-3b-v1.0", "model_worker"), + ("THUDM/chatglm-6b", "model_worker"), + ("mosaicml/mpt-7b-chat", "model_worker"), + ("meta-llama/Llama-2-7b-chat-hf", "vllm_worker"), + ] + + for i, (model_path, worker_name) in enumerate(models): + cmd = ( + f"CUDA_VISIBLE_DEVICES={i} python3 -m fastchat.serve.{worker_name} " + f"--model-path {model_path} --port {30000+i} " + f"--worker-address http://localhost:{30000+i} " + ) + if worker_name == "vllm_worker": + cmd += "--tokenizer hf-internal-testing/llama-tokenizer" + + launch_process(cmd) + + while True: + pass diff --git a/third_party/FastChat/tests/test_cli.py b/third_party/FastChat/tests/test_cli.py new file mode 100644 index 00000000..113e497a --- /dev/null +++ b/third_party/FastChat/tests/test_cli.py @@ -0,0 +1,88 @@ +"""Test command line interface for model inference.""" +import argparse +import os + +from fastchat.utils import run_cmd + + +def test_single_gpu(): + models = [ + "lmsys/vicuna-7b-v1.5", + "lmsys/longchat-7b-16k", + "lmsys/fastchat-t5-3b-v1.0", + "meta-llama/Llama-2-7b-chat-hf", + "THUDM/chatglm-6b", + "THUDM/chatglm2-6b", + "mosaicml/mpt-7b-chat", + "tiiuae/falcon-7b-instruct", + "~/model_weights/alpaca-7b", + "~/model_weights/RWKV-4-Raven-7B-v11x-Eng99%-Other1%-20230429-ctx8192.pth", + ] + + for model_path in models: + if "model_weights" in model_path and not os.path.exists( + os.path.expanduser(model_path) + ): + continue + cmd = ( + f"python3 -m fastchat.serve.cli --model-path {model_path} " + f"--style programmatic < test_cli_inputs.txt" + ) + ret = run_cmd(cmd) + if ret != 0: + return + + print("") + + +def test_multi_gpu(): + models = [ + "lmsys/vicuna-13b-v1.3", + ] + + for model_path in models: + cmd = ( + f"python3 -m fastchat.serve.cli --model-path {model_path} " + f"--style programmatic --num-gpus 2 --max-gpu-memory 14Gib < test_cli_inputs.txt" + ) + ret = run_cmd(cmd) + if ret != 0: + return + print("") + + +def test_8bit(): + models = [ + "lmsys/vicuna-13b-v1.3", + ] + + for model_path in models: + cmd = ( + f"python3 -m fastchat.serve.cli --model-path {model_path} " + f"--style programmatic --load-8bit < test_cli_inputs.txt" + ) + ret = run_cmd(cmd) + if ret != 0: + return + print("") + + +def test_hf_api(): + models = [ + "lmsys/vicuna-7b-v1.5", + "lmsys/fastchat-t5-3b-v1.0", + ] + + for model_path in models: + cmd = f"python3 -m fastchat.serve.huggingface_api --model-path {model_path}" + ret = run_cmd(cmd) + if ret != 0: + return + print("") + + +if __name__ == "__main__": + test_single_gpu() + test_multi_gpu() + test_8bit() + test_hf_api() diff --git a/third_party/FastChat/tests/test_cli_inputs.txt b/third_party/FastChat/tests/test_cli_inputs.txt new file mode 100644 index 00000000..df79f87e --- /dev/null +++ b/third_party/FastChat/tests/test_cli_inputs.txt @@ -0,0 +1,4 @@ +Who are you? __END_OF_A_MESSAGE_47582648__ +Three tips for staying healthy. __END_OF_A_MESSAGE_47582648__ +One more tip. __END_OF_A_MESSAGE_47582648__ +!!exit __END_OF_A_MESSAGE_47582648__ diff --git a/third_party/FastChat/tests/test_openai_api.py b/third_party/FastChat/tests/test_openai_api.py new file mode 100644 index 00000000..3e7879bb --- /dev/null +++ b/third_party/FastChat/tests/test_openai_api.py @@ -0,0 +1,149 @@ +""" +Test the OpenAI compatible server + +Launch: +python3 launch_openai_api_test_server.py +""" +import warnings + +import openai +from fastchat.utils import run_cmd + + +openai.api_key = "EMPTY" # Not support yet +openai.base_url = "http://localhost:8000/v1/" + + +def test_list_models(): + model_list = openai.models.list() + names = [x.id for x in model_list.data] + return names + + +def test_completion(model, logprob): + prompt = "Once upon a time" + completion = openai.completions.create( + model=model, + prompt=prompt, + logprobs=logprob, + max_tokens=64, + temperature=0, + ) + + print(f"full text: {prompt + completion.choices[0].text}", flush=True) + if completion.choices[0].logprobs is not None: + print( + f"logprobs: {completion.choices[0].logprobs.token_logprobs[:10]}", + flush=True, + ) + + +def test_completion_stream(model): + prompt = "Once upon a time" + res = openai.completions.create( + model=model, + prompt=prompt, + max_tokens=64, + stream=True, + temperature=0, + ) + print(prompt, end="") + for chunk in res: + content = chunk.choices[0].text + print(content, end="", flush=True) + print() + + +def test_embedding(model): + embedding = openai.embeddings.create(model=model, input="Hello world!") + print(f"embedding len: {len(embedding.data[0].embedding)}") + print(f"embedding value[:5]: {embedding.data[0].embedding[:5]}") + + +def test_chat_completion(model): + completion = openai.chat.completions.create( + model=model, + messages=[{"role": "user", "content": "Hello! What is your name?"}], + temperature=0, + ) + print(completion.choices[0].message.content) + + +def test_chat_completion_stream(model): + messages = [{"role": "user", "content": "Hello! What is your name?"}] + res = openai.chat.completions.create( + model=model, messages=messages, stream=True, temperature=0 + ) + for chunk in res: + try: + content = chunk.choices[0].delta.content + if content is None: + content = "" + except Exception as e: + content = chunk.choices[0].delta.get("content", "") + print(content, end="", flush=True) + print() + + +def test_openai_curl(): + run_cmd("curl http://localhost:8000/v1/models") + + run_cmd( + """ +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "vicuna-7b-v1.5", + "messages": [{"role": "user", "content": "Hello! What is your name?"}] + }' +""" + ) + + run_cmd( + """ +curl http://localhost:8000/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "vicuna-7b-v1.5", + "prompt": "Once upon a time", + "max_tokens": 41, + "temperature": 0.5 + }' +""" + ) + + run_cmd( + """ +curl http://localhost:8000/v1/embeddings \ + -H "Content-Type: application/json" \ + -d '{ + "model": "vicuna-7b-v1.5", + "input": "Hello world!" + }' +""" + ) + + +if __name__ == "__main__": + models = test_list_models() + print(f"models: {models}") + + for model in models: + print(f"===== Test {model} ======") + + if model in ["fastchat-t5-3b-v1.0"]: + logprob = None + else: + logprob = 1 + + test_completion(model, logprob) + test_completion_stream(model) + test_chat_completion(model) + test_chat_completion_stream(model) + try: + test_embedding(model) + except openai.APIError as e: + print(f"Embedding error: {e}") + + print("===== Test curl =====") + test_openai_curl() diff --git a/third_party/FastChat/tests/test_openai_langchain.py b/third_party/FastChat/tests/test_openai_langchain.py new file mode 100644 index 00000000..b9c07fcf --- /dev/null +++ b/third_party/FastChat/tests/test_openai_langchain.py @@ -0,0 +1,39 @@ +# Usage: +# python3 -m fastchat.serve.model_worker --model-path lmsys/vicuna-7b-v1.5 --model-names gpt-3.5-turbo,text-davinci-003,text-embedding-ada-002 +# export OPENAI_API_BASE=http://localhost:8000/v1 +# export OPENAI_API_KEY=EMPTY +# wget https://raw.githubusercontent.com/hwchase17/langchain/v0.0.200/docs/modules/state_of_the_union.txt + +import os + +from langchain.chat_models import ChatOpenAI +from langchain.document_loaders import TextLoader +from langchain.embeddings import OpenAIEmbeddings +from langchain.indexes import VectorstoreIndexCreator + + +def test_chain(): + embedding = OpenAIEmbeddings(model="text-embedding-ada-002") + loader = TextLoader("state_of_the_union.txt") + index = VectorstoreIndexCreator(embedding=embedding).from_loaders([loader]) + + llm = ChatOpenAI(model="gpt-3.5-turbo") + + questions = [ + "Who is the speaker", + "What did the president say about Ketanji Brown Jackson", + "What are the threats to America", + "Who are mentioned in the speech", + "Who is the vice president", + "How many projects were announced", + ] + + for query in questions: + print("Query:", query) + print("Answer:", index.query(query, llm=llm)) + + +if __name__ == "__main__": + os.environ["OPENAI_API_BASE"] = "http://localhost:8000/v1" + os.environ["OPENAI_API_KEY"] = "empty" + test_chain() diff --git a/third_party/checksum.config b/third_party/checksum.config new file mode 100644 index 00000000..5663f49e --- /dev/null +++ b/third_party/checksum.config @@ -0,0 +1 @@ +9d68ed8579c46e9dad6d487d9ad1dbd6