Skip to content

Commit

Permalink
Merge pull request #297 from netease-youdao/develop_for_v1.3.1
Browse files Browse the repository at this point in the history
Develop for v1.3.1
  • Loading branch information
successren authored Apr 26, 2024
2 parents 2f7893b + ac0ea24 commit 3004bab
Show file tree
Hide file tree
Showing 56 changed files with 1,295 additions and 508 deletions.
124 changes: 121 additions & 3 deletions qanything_kernel/connector/database/mysql/mysql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from qanything_kernel.configs.model_config import SQLITE_DATABASE
from qanything_kernel.utils.custom_log import debug_logger
import uuid
import json
from datetime import datetime, timedelta


class KnowledgeBaseManager:
Expand Down Expand Up @@ -109,6 +111,34 @@ def create_tables_(self):
"""
self.execute_query_(query, (), commit=True)

# query = 'DROP TABLE IF EXISTS QaLogs'
# self.execute_query_(query, (), commit=True)

query = """
CREATE TABLE IF NOT EXISTS QaLogs (
qa_id VARCHAR(64) PRIMARY KEY,
user_id VARCHAR(255) NOT NULL,
bot_id VARCHAR(255),
kb_ids LONGTEXT NOT NULL,
query VARCHAR(512) NOT NULL,
model VARCHAR(64) NOT NULL,
product_source VARCHAR(64) NOT NULL,
time_record LONGTEXT NOT NULL,
history LONGTEXT NOT NULL,
condense_question VARCHAR(1024) NOT NULL,
prompt LONGTEXT NOT NULL,
result VARCHAR(1024) NOT NULL,
retrieval_documents LONGTEXT NOT NULL,
source_documents LONGTEXT NOT NULL,
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
"""
self.execute_query_(query, (), commit=True)

# 如果不存在则创建索引
create_index_query = "CREATE INDEX IF NOT EXISTS index_bot_id ON QaLogs (bot_id);"
self.execute_query_(create_index_query, (), commit=True)

# 旧的File不存在file_path,补上默认值:'UNK'
# 如果存在File表,但是没有file_path字段,那么添加file_path字段
query = "PRAGMA table_info(File)"
Expand Down Expand Up @@ -282,6 +312,94 @@ def add_file(self, user_id, kb_id, file_name, timestamp, status="gray"):
debug_logger.info("add_file: {}".format(file_id))
return file_id, "success"

def add_qalog(self, user_id, bot_id, kb_ids, query, model, product_source, time_record, history, condense_question,
              prompt, result, retrieval_documents, source_documents):
    """Persist one question-answer interaction into the QaLogs table.

    List/dict-valued fields (kb_ids, time_record, history,
    retrieval_documents, source_documents) are JSON-serialized before
    insertion; a fresh hex UUID is generated as the primary key.
    """
    debug_logger.info(f"add_qalog: {query}")
    # Column -> value mapping in the table's column order; JSON fields
    # are serialized with ensure_ascii=False to keep CJK text readable.
    row = {
        "qa_id": uuid.uuid4().hex,
        "user_id": user_id,
        "bot_id": bot_id,
        "kb_ids": json.dumps(kb_ids, ensure_ascii=False),
        "query": query,
        "model": model,
        "product_source": product_source,
        "time_record": json.dumps(time_record, ensure_ascii=False),
        "history": json.dumps(history, ensure_ascii=False),
        "condense_question": condense_question,
        "prompt": prompt,
        "result": result,
        "retrieval_documents": json.dumps(retrieval_documents, ensure_ascii=False),
        "source_documents": json.dumps(source_documents, ensure_ascii=False),
    }
    insert_query = "INSERT INTO QaLogs ({}) VALUES ({})".format(
        ", ".join(row), ", ".join("?" * len(row)))
    self.execute_query_(insert_query, tuple(row.values()), commit=True)

def get_qalog_by_bot_id(self, bot_id, time_range=None):
    """Fetch QA log rows for *bot_id*, optionally restricted to a time range.

    *time_range* is a (start, end) tuple of "YYYY-MM-DD" or
    "YYYY-MM-DD HH:MM:SS" strings; date-only bounds are widened to cover
    the whole day. Any non-tuple value means "no time filter".
    """
    select_clause = ", ".join(
        ("qa_id", "user_id", "bot_id", "query", "model", "result", "timestamp"))
    if not isinstance(time_range, tuple):
        sql = f"SELECT {select_clause} FROM QaLogs WHERE bot_id = ?"
        return self.execute_query_(sql, (bot_id,), fetch=True)
    time_start, time_end = time_range
    # A bare date (length 10, "YYYY-MM-DD") is expanded to day boundaries.
    if len(time_start) == 10:
        time_start = time_start + " 00:00:00"
    if len(time_end) == 10:
        time_end = time_end + " 23:59:59"
    sql = f"SELECT {select_clause} FROM QaLogs WHERE bot_id = ? AND timestamp BETWEEN ? AND ?"
    return self.execute_query_(sql, (bot_id, time_start, time_end), fetch=True)

def get_qalog_by_ids(self, ids):
    """Fetch QA log rows whose qa_id is in *ids*.

    Returns an empty list for an empty *ids* sequence instead of issuing
    the syntactically invalid SQL "IN ()".
    """
    if not ids:
        # Guard: ",".join over zero placeholders would produce "IN ()",
        # which is a SQL syntax error.
        return []
    placeholders = ','.join(['?'] * len(ids))
    query = ("SELECT qa_id, user_id, bot_id, kb_ids, query, model, history, result, timestamp "
             "FROM QaLogs WHERE qa_id IN ({})".format(placeholders))
    return self.execute_query_(query, ids, fetch=True)

def get_qalog_by_filter(self, need_info, user_id=None, kb_ids=None, query=None, bot_id=None, time_range=None):
    """Query QaLogs with optional filters and return rows as dicts.

    :param need_info: sequence of column names to select.
    :param user_id/kb_ids/query/bot_id: optional equality filters
        (kb_ids is matched against its JSON-serialized form).
    :param time_range: (start, end) strings; defaults to the last 30 days.
        Date-only bounds ("YYYY-MM-DD") are widened to whole days.
    :return: list of dicts keyed by *need_info*, with JSON columns decoded
        and timestamps formatted; sorted newest-first when 'timestamp'
        is among the selected columns.
    """
    columns = list(need_info)
    if kb_ids:
        # kb_ids is stored JSON-serialized, so serialize the filter the same way.
        kb_ids = json.dumps(kb_ids, ensure_ascii=False)
    if not time_range:
        # Default time_range is the last 30 days (dates like 2024-02-12).
        time_start = (datetime.now() - timedelta(days=30)).strftime("%Y-%m-%d 00:00:00")
        time_end = datetime.now().strftime("%Y-%m-%d 23:59:59")
        time_range = (time_start, time_end)
    if isinstance(time_range, tuple):
        time_start, time_end = time_range
        # Bare dates (length 10) are expanded to day boundaries.
        if len(time_start) == 10:
            time_start = time_start + " 00:00:00"
        if len(time_end) == 10:
            time_end = time_end + " 23:59:59"
        time_range = (time_start, time_end)
    # Build the search query from whichever filters are present.
    mysql_query = f"SELECT {', '.join(columns)} FROM QaLogs WHERE timestamp BETWEEN ? AND ?"
    params = list(time_range)
    for column, value in (("user_id", user_id), ("kb_ids", kb_ids),
                          ("bot_id", bot_id), ("query", query)):
        if value:
            mysql_query += f" AND {column} = ?"
            params.append(value)
    debug_logger.info("get_qalog_by_filter: {}".format(params))
    rows = self.execute_query_(mysql_query, params, fetch=True)
    # Zip each row against the requested column names.
    qa_infos = [dict(zip(columns, row)) for row in rows]
    # Columns stored as JSON text; decode them back into Python objects.
    json_columns = ("kb_ids", "time_record", "retrieval_documents",
                    "source_documents", "history")
    for qa_info in qa_infos:
        if 'timestamp' in qa_info:
            # NOTE(review): assumes the driver returns a datetime here — confirm.
            qa_info['timestamp'] = qa_info['timestamp'].strftime("%Y-%m-%d %H:%M:%S")
        for column in json_columns:
            if column in qa_info:
                qa_info[column] = json.loads(qa_info[column])
    # Fix: test exact column membership; the old `'timestamp' in joined_string`
    # substring check could misfire on column names containing "timestamp".
    if 'timestamp' in columns:
        qa_infos.sort(key=lambda x: x["timestamp"], reverse=True)
    return qa_infos

def add_faq(self, faq_id, user_id, kb_id, question, answer, nos_keys):
# debug_logger.info(f"add_faq: {faq_id}, {user_id}, {kb_id}, {question}, {answer}, {nos_keys}")
query = "INSERT INTO Faqs (faq_id, user_id, kb_id, question, answer, nos_keys) VALUES (?, ?, ?, ?, ?, ?)"
Expand Down Expand Up @@ -334,13 +452,13 @@ def delete_files(self, kb_id, file_ids):

def get_bot(self, user_id, bot_id):
    """Fetch non-deleted QanythingBot rows by user_id, bot_id, or both.

    Missing (falsy) user_id or bot_id simply drops that filter; at least
    one of the two is expected to be provided by callers.
    """
    base = ("SELECT bot_id, bot_name, description, head_image, prompt_setting, "
            "welcome_message, model, kb_ids_str, update_time, user_id "
            "FROM QanythingBot WHERE {} AND deleted = 0")
    if not bot_id:
        return self.execute_query_(base.format("user_id = ?"), (user_id,), fetch=True)
    if not user_id:
        return self.execute_query_(base.format("bot_id = ?"), (bot_id,), fetch=True)
    return self.execute_query_(base.format("user_id = ? AND bot_id = ?"),
                               (user_id, bot_id), fetch=True)

def update_bot(self, user_id, bot_id, bot_name, description, head_image, prompt_setting, welcome_message, model,
Expand Down
6 changes: 0 additions & 6 deletions qanything_kernel/connector/llm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +0,0 @@
# Select the local-LLM backend by operating system at import time.
import platform

# Linux: local models are served through FastChat.
if platform.system() == "Linux":
    from .llm_for_fastchat import OpenAICustomLLM
# macOS (Darwin): local models run via llama.cpp instead.
elif platform.system() == "Darwin":
    from .llm_for_llamacpp import LlamaCPPCustomLLM
# The OpenAI-compatible API client is available on every platform.
from .llm_for_openai_api import OpenAILLM
79 changes: 66 additions & 13 deletions qanything_kernel/core/local_doc_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
PROMPT_TEMPLATE, STREAMING
from typing import List
import time
from qanything_kernel.connector.llm import OpenAILLM
from qanything_kernel.connector.llm.llm_for_openai_api import OpenAILLM
from langchain.schema import Document
from qanything_kernel.connector.database.mysql.mysql_client import KnowledgeBaseManager
from qanything_kernel.connector.database.faiss.faiss_client import FaissClient
Expand All @@ -11,12 +11,14 @@
import easyocr
from easyocr import Reader
from qanything_kernel.utils.custom_log import debug_logger, qa_logger
from qanything_kernel.core.tools.web_search_tool import duckduckgo_search
from .local_file import LocalFile
import traceback
import base64
import numpy as np
import platform

from concurrent.futures import ThreadPoolExecutor, as_completed
import asyncio

class LocalDocQA:
def __init__(self):
Expand All @@ -31,6 +33,7 @@ def __init__(self):
self.ocr_reader: Reader = None
self.mode: str = None
self.use_cpu: bool = True
self.model: str = None

def get_ocr_result(self, input: dict):
img_file = input['img64']
Expand All @@ -46,11 +49,15 @@ def get_ocr_result(self, input: dict):
def init_cfg(self, args=None):
self.rerank_top_k = int(args.model_size[0])
self.use_cpu = args.use_cpu
if args.use_openai_api:
self.model = args.openai_api_model_name
else:
self.model = args.model.split('/')[-1]
if platform.system() == 'Linux':
if args.use_openai_api:
self.llm: OpenAILLM = OpenAILLM(args)
else:
from qanything_kernel.connector.llm import OpenAICustomLLM
from qanything_kernel.connector.llm.llm_for_fastchat import OpenAICustomLLM
self.llm: OpenAICustomLLM = OpenAICustomLLM(args)
from qanything_kernel.connector.rerank.rerank_onnx_backend import RerankOnnxBackend
from qanything_kernel.connector.embedding.embedding_onnx_backend import EmbeddingOnnxBackend
Expand All @@ -60,7 +67,7 @@ def init_cfg(self, args=None):
if args.use_openai_api:
self.llm: OpenAILLM = OpenAILLM(args)
else:
from qanything_kernel.connector.llm import LlamaCPPCustomLLM
from qanything_kernel.connector.llm.llm_for_llamacpp import LlamaCPPCustomLLM
self.llm: LlamaCPPCustomLLM = LlamaCPPCustomLLM(args)
from qanything_kernel.connector.rerank.rerank_torch_backend import RerankTorchBackend
from qanything_kernel.connector.embedding.embedding_torch_backend import EmbeddingTorchBackend
Expand Down Expand Up @@ -111,6 +118,46 @@ def deduplicate_documents(self, source_docs):
deduplicated_docs.append(doc)
return deduplicated_docs

async def local_doc_search(self, query, kb_ids, score_threshold=0.35):
    """Retrieve, deduplicate, rerank, and score-filter documents for *query*.

    :param query: the (already condensed/rewritten) user query.
    :param kb_ids: knowledge-base ids to search.
    :param score_threshold: documents with a score at or below this are
        dropped after reranking; falsy disables the filter.
    :return: at most self.rerank_top_k documents, best first.
    """
    source_documents = await self.get_source_documents(query, kb_ids)
    deduplicated_docs = self.deduplicate_documents(source_documents)
    # Sort candidates by retrieval score, best first.
    retrieval_documents = sorted(deduplicated_docs, key=lambda x: x.metadata['score'], reverse=True)
    if len(retrieval_documents) > 1:
        debug_logger.info(f"use rerank, rerank docs num: {len(retrieval_documents)}")
        # Rerank must use the rewritten query, otherwise information is lost.
        retrieval_documents = self.rerank_documents(query, retrieval_documents)
        # Drop documents whose (reranked) score falls below the threshold.
        # NOTE(review): original indentation was ambiguous; the filter is kept
        # inside the rerank branch — confirm against the repository.
        if score_threshold:
            retrieval_documents = [item for item in retrieval_documents if float(item.metadata['score']) > score_threshold]

    # Keep only the top rerank_top_k documents.
    retrieval_documents = retrieval_documents[: self.rerank_top_k]
    debug_logger.info(f"local doc search retrieval_documents: {retrieval_documents}")
    return retrieval_documents

def get_web_search(self, queries, top_k=None):
    """Run a DuckDuckGo web search for the first query in *queries*.

    :return: (web_content, documents) where each returned document is
        tagged in its metadata with the query that produced it.
    """
    # NOTE(review): top_k is resolved here but never forwarded to the
    # search call — confirm whether duckduckgo_search should receive it.
    if not top_k:
        top_k = self.top_k
    query = queries[0]
    web_content, web_documents = duckduckgo_search(query)
    for document in web_documents:
        # Record the originating query in each document's metadata.
        document.metadata['retrieval_query'] = query
    return web_content, list(web_documents)



def web_page_search(self, query, top_k=None):
    """Best-effort web search: returns documents, or [] if the search fails.

    get_web_search can raise (network, backend), so failures are logged
    and swallowed rather than propagated.
    """
    try:
        _, source_documents = self.get_web_search([query], top_k)
    except Exception as e:
        debug_logger.error(f"web search error: {e}")
        return []
    return source_documents


async def get_source_documents(self, query, kb_ids, cosine_thresh=None, top_k=None):
if not top_k:
top_k = self.top_k
Expand Down Expand Up @@ -193,18 +240,24 @@ def rerank_documents(self, query, source_documents):
source_documents = sorted(source_documents, key=lambda x: x.metadata['score'], reverse=True)
return source_documents

async def get_knowledge_based_answer(self, custom_prompt, query, kb_ids, chat_history=None, streaming: bool = STREAMING,
rerank: bool = False):
async def retrieve(self, query, kb_ids, need_web_search=False):
    """Gather documents from local knowledge bases (and optionally the web),
    then rerank the combined candidate set against *query*."""
    candidates = await self.local_doc_search(query, kb_ids)
    if need_web_search:
        candidates += self.web_page_search(query, top_k=3)
    debug_logger.info(f"retrieval_documents: {candidates}")
    reranked = self.rerank_documents(query, candidates)
    debug_logger.info(f"reranked retrieval_documents: {reranked}")
    return reranked

async def get_knowledge_based_answer(self, custom_prompt, query, kb_ids, chat_history=None,
streaming: bool = STREAMING,
rerank: bool = False,
need_web_search: bool = False):
if chat_history is None:
chat_history = []

source_documents = await self.get_source_documents(query, kb_ids)

deduplicated_docs = self.deduplicate_documents(source_documents)
retrieval_documents = sorted(deduplicated_docs, key=lambda x: x.metadata['score'], reverse=True)
if rerank and len(retrieval_documents) > 1:
debug_logger.info(f"use rerank, rerank docs num: {len(retrieval_documents)}")
retrieval_documents = self.rerank_documents(query, retrieval_documents)[: self.rerank_top_k]
#retrieval_queries = [query]
retrieval_documents = await self.retrieve(query, kb_ids, need_web_search=need_web_search)

if custom_prompt is None:
prompt_template = PROMPT_TEMPLATE
Expand Down
6 changes: 5 additions & 1 deletion qanything_kernel/core/local_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from langchain.text_splitter import RecursiveCharacterTextSplitter
from qanything_kernel.utils.custom_log import debug_logger, qa_logger
from qanything_kernel.utils.splitter import ChineseTextSplitter
from qanything_kernel.utils.loader import UnstructuredPaddleImageLoader, UnstructuredPaddlePDFLoader
from qanything_kernel.utils.loader import UnstructuredPaddleImageLoader, UnstructuredPaddlePDFLoader, UnstructuredPaddleAudioLoader
from qanything_kernel.utils.splitter import zh_title_enhance
from sanic.request import File
import pandas as pd
Expand Down Expand Up @@ -111,7 +111,11 @@ def split_file_to_docs(self, ocr_engine: Callable, sentence_size=SENTENCE_SIZE,
elif self.file_path.lower().endswith(".csv"):
loader = CSVLoader(self.file_path, csv_args={"delimiter": ",", "quotechar": '"'})
docs = loader.load()
elif self.file_path.lower().endswith(".mp3") or self.file_path.lower().endswith(".wav"):
loader = UnstructuredPaddleAudioLoader(self.file_path, self.use_cpu)
docs = loader.load()
else:
debug_logger.info("file_path: {}".format(self.file_path))
raise TypeError("文件类型不支持,目前仅支持:[md,txt,pdf,jpg,png,jpeg,docx,xlsx,pptx,eml,csv]")
if using_zh_title_enhance:
debug_logger.info("using_zh_title_enhance %s", using_zh_title_enhance)
Expand Down
Loading

0 comments on commit 3004bab

Please sign in to comment.