add rerank model (#969)
### What problem does this PR solve?

feat: add rerank models to the project (#724, #162); a sketch of the resulting retrieval call pattern follows the checklist below.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
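
To orient review, here is a minimal sketch (not part of the diff) of the pattern the changed call sites share: build an optional rerank model from a `rerank_id` and hand it to `retrievaler.retrieval` via the new `rerank_mdl` keyword. The helper name and the `llm_service` import path are assumptions; the service calls and argument order follow the hunks below (dialog_service.py does the same via `LLMBundle`).

```python
from api.db import LLMType
from api.db.services.llm_service import TenantLLMService  # import path assumed from the project layout


def retrieve_with_optional_rerank(retrievaler, question, embd_mdl, tenant_id, kb_ids,
                                  page, size, similarity_threshold,
                                  vector_similarity_weight, top, doc_ids, rerank_id=""):
    """Sketch of the pattern used in chunk_app.py after this PR."""
    rerank_mdl = None
    if rerank_id:  # an empty rerank_id keeps the pre-PR behaviour: no reranking step
        rerank_mdl = TenantLLMService.model_instance(
            tenant_id, LLMType.RERANK.value, llm_name=rerank_id)
    return retrievaler.retrieval(question, embd_mdl, tenant_id, kb_ids, page, size,
                                 similarity_threshold, vector_similarity_weight, top,
                                 doc_ids, rerank_mdl=rerank_mdl)
```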
KevinHuSh committed May 29, 2024
1 parent e1f0644 commit 614defe
Showing 17 changed files with 437 additions and 64 deletions.
11 changes: 9 additions & 2 deletions api/apps/chunk_app.py
@@ -257,8 +257,15 @@ def retrieval_test():

embd_mdl = TenantLLMService.model_instance(
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, similarity_threshold,
vector_similarity_weight, top, doc_ids)

rerank_mdl = None
if req.get("rerank_id"):
rerank_mdl = TenantLLMService.model_instance(
kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])

ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl)
for c in ranks["chunks"]:
if "vector" in c:
del c["vector"]
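For manual testing, a hedged example of a retrieval_test request body once reranking is available. Only `rerank_id` is confirmed by this hunk; the other field names are inferred from the handler's local variables and may differ, and any route path is an assumption.

```python
# Illustrative payload only: "rerank_id" is the field added by this PR; the remaining
# keys are inferred from the handler's variables (kb_id, page, size, ...) and may differ.
retrieval_test_payload = {
    "kb_id": "<knowledge-base-id>",
    "question": "What is the warranty period?",
    "page": 1,
    "size": 30,
    "similarity_threshold": 0.1,
    "vector_similarity_weight": 0.3,
    "rerank_id": "BAAI/bge-reranker-v2-m3",  # omit or leave empty to skip reranking
}
```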
5 changes: 5 additions & 0 deletions api/apps/dialog_app.py
@@ -33,6 +33,9 @@ def set_dialog():
name = req.get("name", "New Dialog")
description = req.get("description", "A helpful Dialog")
top_n = req.get("top_n", 6)
top_k = req.get("top_k", 1024)
rerank_id = req.get("rerank_id", "")
if not rerank_id: req["rerank_id"] = ""
similarity_threshold = req.get("similarity_threshold", 0.1)
vector_similarity_weight = req.get("vector_similarity_weight", 0.3)
llm_setting = req.get("llm_setting", {})
@@ -83,6 +86,8 @@ def set_dialog():
"llm_setting": llm_setting,
"prompt_config": prompt_config,
"top_n": top_n,
"top_k": top_k,
"rerank_id": rerank_id,
"similarity_threshold": similarity_threshold,
"vector_similarity_weight": vector_similarity_weight
}
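Putting the two new fields in context, a set_dialog body could now look like the sketch below; the values mirror the `req.get(...)` fallbacks above, and the `prompt_config` content is a placeholder.

```python
# Example set_dialog body after this PR; defaults mirror the req.get(...) fallbacks above.
dialog_settings = {
    "name": "New Dialog",
    "description": "A helpful Dialog",
    "top_n": 6,
    "top_k": 1024,                            # new field (default 1024)
    "rerank_id": "BAAI/bge-reranker-v2-m3",   # new field; empty string disables reranking
    "similarity_threshold": 0.1,
    "vector_similarity_weight": 0.3,
    "llm_setting": {},
    "prompt_config": {"system": "...", "parameters": []},  # placeholder content
}
```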
16 changes: 13 additions & 3 deletions api/apps/llm_app.py
@@ -20,15 +20,15 @@
from api.db import StatusEnum, LLMType
from api.db.db_models import TenantLLM
from api.utils.api_utils import get_json_result
from rag.llm import EmbeddingModel, ChatModel
from rag.llm import EmbeddingModel, ChatModel, RerankModel


@manager.route('/factories', methods=['GET'])
@login_required
def factories():
try:
fac = LLMFactoriesService.get_all()
return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed"]])
return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed", "BAAI"]])
except Exception as e:
return server_error_response(e)

@@ -64,6 +64,16 @@ def set_api_key():
except Exception as e:
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
e)
elif llm.model_type == LLMType.RERANK:
mdl = RerankModel[factory](
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
try:
m, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
if len(m) == 0 or tc == 0:
raise Exception("Fail")
except Exception as e:
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
e)

if msg:
return get_data_error_result(retmsg=msg)
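
The new RERANK branch in set_api_key implies the contract a rerank wrapper in rag.llm is expected to satisfy: constructed as `RerankModel[factory](api_key, model_name, base_url=...)` and exposing `similarity(query, texts)` returning per-candidate scores plus a token count. A minimal sketch of that inferred interface follows; it is an assumption drawn from this call site, not the project's actual base class.

```python
# Interface sketch inferred from the set_api_key check above, not RAGFlow's real class:
# similarity(query, texts) must return (scores, token_count); an empty result or a zero
# token count is treated as a failed key/model check.
from typing import List, Optional, Sequence, Tuple


class ExampleRerank:
    def __init__(self, key: str, model_name: str, base_url: Optional[str] = None):
        self.key = key
        self.model_name = model_name
        self.base_url = base_url

    def similarity(self, query: str, texts: Sequence[str]) -> Tuple[List[float], int]:
        # A real wrapper would call the provider's rerank endpoint here; this stub only
        # shows the expected shape: one relevance score per candidate plus a token count.
        scores = [0.0 for _ in texts]
        token_count = len(query.split()) + sum(len(t.split()) for t in texts)
        return scores, token_count
```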
@@ -199,7 +209,7 @@ def list_app():
llms = [m.to_dict()
for m in llms if m.status == StatusEnum.VALID.value]
for m in llms:
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed"]
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed", "BAAI"]

llm_set = set([m["llm_name"] for m in llms])
for o in objs:
8 changes: 5 additions & 3 deletions api/apps/user_app.py
@@ -26,8 +26,9 @@
from api.utils.api_utils import server_error_response, validate_request
from api.utils import get_uuid, get_format_time, decrypt, download_img, current_timestamp, datetime_format
from api.db import UserTenantRole, LLMType, FileType
from api.settings import RetCode, GITHUB_OAUTH, FEISHU_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, API_KEY, \
LLM_FACTORY, LLM_BASE_URL
from api.settings import RetCode, GITHUB_OAUTH, FEISHU_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, \
API_KEY, \
LLM_FACTORY, LLM_BASE_URL, RERANK_MDL
from api.db.services.user_service import UserService, TenantService, UserTenantService
from api.db.services.file_service import FileService
from api.settings import stat_logger
@@ -288,7 +289,8 @@ def user_register(user_id, user):
"embd_id": EMBEDDING_MDL,
"asr_id": ASR_MDL,
"parser_ids": PARSERS,
"img2txt_id": IMAGE2TEXT_MDL
"img2txt_id": IMAGE2TEXT_MDL,
"rerank_id": RERANK_MDL
}
usr_tenant = {
"tenant_id": user_id,
1 change: 1 addition & 0 deletions api/db/__init__.py
@@ -54,6 +54,7 @@ class LLMType(StrEnum):
EMBEDDING = 'embedding'
SPEECH2TEXT = 'speech2text'
IMAGE2TEXT = 'image2text'
RERANK = 'rerank'


class ChatStyle(StrEnum):
39 changes: 33 additions & 6 deletions api/db/db_models.py
@@ -437,6 +437,10 @@ class Tenant(DataBaseModel):
max_length=128,
null=False,
help_text="default image to text model ID")
rerank_id = CharField(
max_length=128,
null=False,
help_text="default rerank model ID")
parser_ids = CharField(
max_length=256,
null=False,
@@ -771,11 +775,16 @@ class Dialog(DataBaseModel):
similarity_threshold = FloatField(default=0.2)
vector_similarity_weight = FloatField(default=0.3)
top_n = IntegerField(default=6)
top_k = IntegerField(default=1024)
do_refer = CharField(
max_length=1,
null=False,
help_text="it needs to insert reference index into answer or not",
default="1")
rerank_id = CharField(
max_length=128,
null=False,
help_text="default rerank model ID")

kb_ids = JSONField(null=False, default=[])
status = CharField(
@@ -825,11 +834,29 @@ class Meta:


def migrate_db():
try:
with DB.transaction():
migrator = MySQLMigrator(DB)
migrate(
migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", help_text="where does this document come from"))
)
except Exception as e:
pass
try:
migrate(
migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", help_text="where does this document come from"))
)
except Exception as e:
pass
try:
migrate(
migrator.add_column('tenant', 'rerank_id', CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3", help_text="default rerank model ID"))
)
except Exception as e:
pass
try:
migrate(
migrator.add_column('dialog', 'rerank_id', CharField(max_length=128, null=False, default="", help_text="default rerank model ID"))
)
except Exception as e:
pass
try:
migrate(
migrator.add_column('dialog', 'top_k', IntegerField(default=1024))
)
except Exception as e:
pass
98 changes: 97 additions & 1 deletion api/db/init_data.py
@@ -142,7 +142,17 @@ def init_superuser():
"logo": "",
"tags": "LLM,TEXT EMBEDDING",
"status": "1",
},
},{
"name": "Jina",
"logo": "",
"tags": "TEXT EMBEDDING, TEXT RE-RANK",
"status": "1",
},{
"name": "BAAI",
"logo": "",
"tags": "TEXT EMBEDDING, TEXT RE-RANK",
"status": "1",
}
# {
# "name": "文心一言",
# "logo": "",
@@ -367,6 +377,13 @@ def init_llm_factory():
"max_tokens": 512,
"model_type": LLMType.EMBEDDING.value
},
{
"fid": factory_infos[7]["name"],
"llm_name": "maidalun1020/bce-reranker-base_v1",
"tags": "RE-RANK, 8K",
"max_tokens": 8196,
"model_type": LLMType.RERANK.value
},
# ------------------------ DeepSeek -----------------------
{
"fid": factory_infos[8]["name"],
@@ -440,6 +457,85 @@ def init_llm_factory():
"max_tokens": 512,
"model_type": LLMType.EMBEDDING.value
},
# ------------------------ Jina -----------------------
{
"fid": factory_infos[11]["name"],
"llm_name": "jina-reranker-v1-base-en",
"tags": "RE-RANK,8k",
"max_tokens": 8196,
"model_type": LLMType.RERANK.value
},
{
"fid": factory_infos[11]["name"],
"llm_name": "jina-reranker-v1-turbo-en",
"tags": "RE-RANK,8k",
"max_tokens": 8196,
"model_type": LLMType.RERANK.value
},
{
"fid": factory_infos[11]["name"],
"llm_name": "jina-reranker-v1-tiny-en",
"tags": "RE-RANK,8k",
"max_tokens": 8196,
"model_type": LLMType.RERANK.value
},
{
"fid": factory_infos[11]["name"],
"llm_name": "jina-colbert-v1-en",
"tags": "RE-RANK,8k",
"max_tokens": 8196,
"model_type": LLMType.RERANK.value
},
{
"fid": factory_infos[11]["name"],
"llm_name": "jina-embeddings-v2-base-en",
"tags": "TEXT EMBEDDING",
"max_tokens": 8196,
"model_type": LLMType.EMBEDDING.value
},
{
"fid": factory_infos[11]["name"],
"llm_name": "jina-embeddings-v2-base-de",
"tags": "TEXT EMBEDDING",
"max_tokens": 8196,
"model_type": LLMType.EMBEDDING.value
},
{
"fid": factory_infos[11]["name"],
"llm_name": "jina-embeddings-v2-base-es",
"tags": "TEXT EMBEDDING",
"max_tokens": 8196,
"model_type": LLMType.EMBEDDING.value
},
{
"fid": factory_infos[11]["name"],
"llm_name": "jina-embeddings-v2-base-code",
"tags": "TEXT EMBEDDING",
"max_tokens": 8196,
"model_type": LLMType.EMBEDDING.value
},
{
"fid": factory_infos[11]["name"],
"llm_name": "jina-embeddings-v2-base-zh",
"tags": "TEXT EMBEDDING",
"max_tokens": 8196,
"model_type": LLMType.EMBEDDING.value
},
# ------------------------ BAAI -----------------------
{
"fid": factory_infos[12]["name"],
"llm_name": "BAAI/bge-large-zh-v1.5",
"tags": "TEXT EMBEDDING,",
"max_tokens": 1024,
"model_type": LLMType.EMBEDDING.value
},
{
"fid": factory_infos[12]["name"],
"llm_name": "BAAI/bge-reranker-v2-m3",
"tags": "LLM,CHAT,",
"max_tokens": 16385,
"model_type": LLMType.RERANK.value
},
]
for info in factory_infos:
try:
7 changes: 5 additions & 2 deletions api/db/services/dialog_service.py
@@ -115,11 +115,14 @@ def chat(dialog, messages, stream=True, **kwargs):
if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
else:
rerank_mdl = None
if dialog.rerank_id:
rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight,
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
top=1024, aggs=False)
top=1024, aggs=False, rerank_mdl=rerank_mdl)
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
chat_logger.info(
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
@@ -130,7 +133,7 @@ def chat(dialog, messages, stream=True, **kwargs):

kwargs["knowledge"] = "\n".join(knowledges)
gen_conf = dialog.llm_setting

msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
msg.extend([{"role": m["role"], "content": m["content"]}
for m in messages if m["role"] != "system"])