From 758eb03ccb3002a0d0adcbaf1642e9999dceb732 Mon Sep 17 00:00:00 2001 From: KevinHuSh Date: Wed, 29 May 2024 19:38:57 +0800 Subject: [PATCH] fix jina adding issue and term weight refinement (#974) ### What problem does this PR solve? #724 #162 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] New Feature (non-breaking change which adds functionality) --- api/apps/llm_app.py | 14 ++++++++------ api/db/services/llm_service.py | 1 - rag/llm/__init__.py | 1 + rag/llm/embedding_model.py | 2 +- rag/llm/rerank_model.py | 2 +- rag/nlp/query.py | 2 +- rag/nlp/term_weight.py | 2 +- 7 files changed, 13 insertions(+), 11 deletions(-) diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 36fa5c3ccb..c4a245ffb3 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -39,17 +39,18 @@ def factories(): def set_api_key(): req = request.json # test if api key works - chat_passed = False + chat_passed, embd_passed, rerank_passed = False, False, False factory = req["llm_factory"] msg = "" for llm in LLMService.query(fid=factory): - if llm.model_type == LLMType.EMBEDDING.value: + if not embd_passed and llm.model_type == LLMType.EMBEDDING.value: mdl = EmbeddingModel[factory]( req["api_key"], llm.llm_name, base_url=req.get("base_url")) try: arr, tc = mdl.encode(["Test if the api key is available"]) if len(arr[0]) == 0 or tc == 0: raise Exception("Fail") + embd_passed = True except Exception as e: msg += f"\nFail to access embedding model({llm.llm_name}) using this api key." + str(e) elif not chat_passed and llm.model_type == LLMType.CHAT.value: mdl = ChatModel[factory]( req["api_key"], llm.llm_name, base_url=req.get("base_url")) try: m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], { @@ -60,20 +61,21 @@ def set_api_key(): "temperature": 0.9}) if not tc: raise Exception(m) - chat_passed = True except Exception as e: msg += f"\nFail to access model({llm.llm_name}) using this api key." 
+ str( e) - elif llm.model_type == LLMType.RERANK: + chat_passed = True + elif not rerank_passed and llm.model_type == LLMType.RERANK: mdl = RerankModel[factory]( req["api_key"], llm.llm_name, base_url=req.get("base_url")) try: - m, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"]) - if len(arr[0]) == 0 or tc == 0: + arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"]) + if len(arr) == 0 or tc == 0: raise Exception("Fail") except Exception as e: msg += f"\nFail to access model({llm.llm_name}) using this api key." + str( e) + rerank_passed = True if msg: return get_data_error_result(retmsg=msg) diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 5a71ea69ab..c484afcc3b 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -147,7 +147,6 @@ def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None): .where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\ .execute() except Exception as e: - print(e) pass return num diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index 25b08921aa..c2a99b2c10 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -28,6 +28,7 @@ "FastEmbed": FastEmbed, "Youdao": YoudaoEmbed, "BaiChuan": BaiChuanEmbed, + "Jina": JinaEmbed, "BAAI": DefaultEmbedding } diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index f3d0a87298..5083a6945f 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -291,7 +291,7 @@ def encode(self, texts: list, batch_size=None): "input": texts, 'encoding_type': 'float' } - res = requests.post(self.base_url, headers=self.headers, json=data) + res = requests.post(self.base_url, headers=self.headers, json=data).json() return np.array([d["embedding"] for d in res["data"]]), res["usage"]["total_tokens"] def encode_queries(self, text): diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index 0f4440c3fd..c1c12f1130 100644 --- 
a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -91,7 +91,7 @@ def similarity(self, query: str, texts: list): "documents": texts, "top_n": len(texts) } - res = requests.post(self.base_url, headers=self.headers, json=data) + res = requests.post(self.base_url, headers=self.headers, json=data).json() return np.array([d["relevance_score"] for d in res["results"]]), res["usage"]["total_tokens"] diff --git a/rag/nlp/query.py b/rag/nlp/query.py index 07bd96f4ed..3f6dfd0399 100644 --- a/rag/nlp/query.py +++ b/rag/nlp/query.py @@ -44,7 +44,7 @@ def rmWWW(txt): def question(self, txt, tbl="qa", min_match="60%"): txt = re.sub( - r"[ \r\n\t,,。??/`!!&\^%%]+", + r"[ :\r\n\t,,。??/`!!&\^%%]+", " ", rag_tokenizer.tradi2simp( rag_tokenizer.strQ2B( diff --git a/rag/nlp/term_weight.py b/rag/nlp/term_weight.py index 178dafe102..639896c883 100644 --- a/rag/nlp/term_weight.py +++ b/rag/nlp/term_weight.py @@ -104,7 +104,7 @@ def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t) while i < len(tks): j = i if i == 0 and oneTerm(tks[i]) and len( - tks) > 1 and len(tks[i + 1]) > 1: # 多 工位 + tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])): # 多 工位 res.append(" ".join(tks[0:2])) i = 2 continue