update document sdk #2445

Merged (17 commits) on Sep 18, 2024
209 changes: 194 additions & 15 deletions api/apps/sdk/doc.py
@@ -84,15 +84,28 @@ def upload(dataset_id, tenant_id):
@token_required
def docinfos(tenant_id):
req = request.args
if "id" not in req and "name" not in req:
return get_data_error_result(
retmsg="Id or name should be provided")
doc_id = None
if "id" in req:
doc_id = req["id"]
e, doc = DocumentService.get_by_id(doc_id)
return get_json_result(data=doc.to_json())
if "name" in req:
doc_name = req["name"]
doc_id = DocumentService.get_doc_id_by_doc_name(doc_name)
e, doc = DocumentService.get_by_id(doc_id)
return get_json_result(data=doc.to_json())
e, doc = DocumentService.get_by_id(doc_id)
# rename keys
key_mapping = {
"chunk_num": "chunk_count",
"kb_id": "knowledgebase_id",
"token_num": "token_count",
}
renamed_doc = {}
for key, value in doc.to_dict().items():
new_key = key_mapping.get(key, key)
renamed_doc[new_key] = value

return get_json_result(data=renamed_doc)
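
The same key-renaming loop recurs throughout this PR (here and again in `list_docs`, `list_chunk`, `create`, and `retrieval_test` below). A minimal standalone sketch of the pattern; the helper name `rename_keys` is hypothetical, the PR inlines the loop at each site:

```python
# Minimal sketch of the key-renaming pattern used throughout this PR.
# The helper name `rename_keys` is hypothetical; the PR inlines this loop.
def rename_keys(record: dict, key_mapping: dict) -> dict:
    """Return a copy of `record` with keys renamed per `key_mapping`."""
    return {key_mapping.get(k, k): v for k, v in record.items()}

doc = {"chunk_num": 42, "kb_id": "kb-1", "token_num": 1337, "name": "a.pdf"}
mapping = {"chunk_num": "chunk_count", "kb_id": "knowledgebase_id",
           "token_num": "token_count"}
print(rename_keys(doc, mapping))
# {'chunk_count': 42, 'knowledgebase_id': 'kb-1', 'token_count': 1337, 'name': 'a.pdf'}
```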


@manager.route('/save', methods=['POST'])
@@ -246,7 +259,7 @@ def rename():
req["doc_id"], {"name": req["name"]}):
return get_data_error_result(
retmsg="Database error (Document rename)!")

informs = File2DocumentService.get_by_document_id(req["doc_id"])
if informs:
e, file = FileService.get_by_id(informs[0].file_id)
@@ -259,7 +272,7 @@ def rename():

@manager.route("/<document_id>", methods=["GET"])
@token_required
def download_document(dataset_id, document_id):
def download_document(dataset_id, document_id, tenant_id):
try:
# Check whether there is this document
exist, document = DocumentService.get_by_id(document_id)
@@ -313,7 +326,21 @@ def list_docs(dataset_id, tenant_id):
try:
docs, tol = DocumentService.get_by_kb_id(
kb_id, page_number, items_per_page, orderby, desc, keywords)
return get_json_result(data={"total": tol, "docs": docs})

# rename keys
renamed_doc_list = []
for doc in docs:
key_mapping = {
"chunk_num": "chunk_count",
"kb_id": "knowledgebase_id",
"token_num": "token_count",
}
renamed_doc = {}
for key, value in doc.items():
new_key = key_mapping.get(key, key)
renamed_doc[new_key] = value
renamed_doc_list.append(renamed_doc)
return get_json_result(data={"total": tol, "docs": renamed_doc_list})
except Exception as e:
return server_error_response(e)

@@ -436,6 +463,8 @@ def list_chunk(tenant_id):
query["available_int"] = int(req["available_int"])
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}

origin_chunks = []
for id in sres.ids:
d = {
"chunk_id": id,
@@ -455,7 +484,21 @@
poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]),
float(d["positions"][i + 3]), float(d["positions"][i + 4])])
d["positions"] = poss
res["chunks"].append(d)

origin_chunks.append(d)
# rename keys
for chunk in origin_chunks:
key_mapping = {
"chunk_id": "id",
"content_with_weight": "content",
"doc_id": "document_id",
"important_kwd": "important_keywords",
}
renamed_chunk = {}
for key, value in chunk.items():
new_key = key_mapping.get(key, key)
renamed_chunk[new_key] = value
res["chunks"].append(renamed_chunk)
return get_json_result(data=res)
except Exception as e:
if str(e).find("not_found") > 0:
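
For context, `list_chunk` above also folds the flat `positions` list into rows of five floats before the keys are renamed. A standalone sketch of that re-grouping; the semantic order of the five values (page, left, right, top, bottom) is an assumption, not something this diff states:

```python
# Sketch of the positions re-grouping in list_chunk: a flat float list is
# folded into rows of five. The meaning of each position value is assumed.
flat = [1, 10.0, 120.0, 30.0, 55.0, 1, 130.0, 240.0, 30.0, 55.0]
poss = [[float(x) for x in flat[i:i + 5]] for i in range(0, len(flat), 5)]
print(poss)
# [[1.0, 10.0, 120.0, 30.0, 55.0], [1.0, 130.0, 240.0, 30.0, 55.0]]
```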
@@ -471,8 +514,9 @@ def create(tenant_id):
req = request.json
md5 = hashlib.md5()
md5.update((req["content_with_weight"] + req["doc_id"]).encode("utf-8"))
chunck_id = md5.hexdigest()
d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),

chunk_id = md5.hexdigest()
d = {"id": chunk_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]),
"content_with_weight": req["content_with_weight"]}
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
d["important_kwd"] = req.get("important_kwd", [])
@@ -503,20 +547,33 @@ def create(tenant_id):

DocumentService.increment_chunk_num(
doc.id, doc.kb_id, c, 1, 0)
return get_json_result(data={"chunk": d})
# return get_json_result(data={"chunk_id": chunck_id})
d["chunk_id"] = chunk_id
# rename keys
key_mapping = {
"chunk_id": "id",
"content_with_weight": "content",
"doc_id": "document_id",
"important_kwd": "important_keywords",
"kb_id":"knowledge_base_id",
}
renamed_chunk = {}
for key, value in d.items():
new_key = key_mapping.get(key, key)
renamed_chunk[new_key] = value

return get_json_result(data={"chunk": renamed_chunk})
# return get_json_result(data={"chunk_id": chunk_id})
except Exception as e:
return server_error_response(e)
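
The chunk id in `create` is the MD5 of the chunk text concatenated with its document id, so posting identical content to the same document yields the same id. A minimal reproduction:

```python
import hashlib

# Reproduction of the chunk-id scheme in create(): MD5 over content + doc id.
content_with_weight = "RAGFlow stores chunk text alongside its tokens."
doc_id = "doc-123"  # hypothetical document id

md5 = hashlib.md5()
md5.update((content_with_weight + doc_id).encode("utf-8"))
chunk_id = md5.hexdigest()
print(chunk_id)  # deterministic 32-character hex digest
```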


@manager.route('/chunk/rm', methods=['POST'])
@token_required
@validate_request("chunk_ids", "doc_id")
def rm_chunk():
def rm_chunk(tenant_id):
req = request.json
try:
if not ELASTICSEARCH.deleteByQuery(
Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)):
Q("ids", values=req["chunk_ids"]), search.index_name(tenant_id)):
return get_data_error_result(retmsg="Index updating failure")
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
@@ -526,4 +583,126 @@ def rm_chunk():
DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
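
With `rm_chunk` now resolving the index from the token's `tenant_id` instead of `current_user`, the endpoint can be exercised with a plain HTTP client. A hedged sketch; the base URL, the `/doc` prefix, and the bearer-token header are assumptions drawn from the SDK code further down, not guarantees of this PR:

```python
import requests

BASE_URL = "http://127.0.0.1:9380/api/v1"  # assumed default RAGFlow address
API_KEY = "ragflow-..."                    # placeholder API token

res = requests.post(
    f"{BASE_URL}/doc/chunk/rm",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={"doc_id": "doc-123", "chunk_ids": ["<chunk-id>"]},
)
print(res.json().get("retmsg"))  # "success" once the index update completes
```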

@manager.route('/chunk/set', methods=['POST'])
@token_required
@validate_request("doc_id", "chunk_id", "content_with_weight",
"important_kwd")
def set(tenant_id):
req = request.json
d = {
"id": req["chunk_id"],
"content_with_weight": req["content_with_weight"]}
d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
d["important_kwd"] = req["important_kwd"]
d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
if "available_int" in req:
d["available_int"] = req["available_int"]

try:
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(retmsg="Tenant not found!")

embd_id = DocumentService.get_embd_id(req["doc_id"])
embd_mdl = TenantLLMService.model_instance(
tenant_id, LLMType.EMBEDDING.value, embd_id)

e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(retmsg="Document not found!")

if doc.parser_id == ParserType.QA:
arr = [
t for t in re.split(
r"[\n\t]",
req["content_with_weight"]) if len(t) > 1]
if len(arr) != 2:
return get_data_error_result(
retmsg="Q&A must be separated by TAB/ENTER key.")
q, a = rmPrefix(arr[0]), rmPrefix(arr[1])
d = beAdoc(d, arr[0], arr[1], not any(
[rag_tokenizer.is_chinese(t) for t in q + a]))

v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist()
ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
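
For non-QA parsers, `set()` re-embeds the edited chunk by blending the document-name vector with the content vector at a 1:9 ratio, then stores it under a dimension-keyed field. A numpy sketch with stand-in vectors; real vectors come from the tenant's embedding model:

```python
import numpy as np

# Stand-ins for embd_mdl.encode([doc.name, content]).
v = np.array([
    [0.1, 0.2, 0.3],  # v[0]: embedding of the document name
    [0.9, 0.8, 0.7],  # v[1]: embedding of the edited chunk content
])
is_qa = False  # QA-parsed documents keep the content vector unchanged
vec = v[1] if is_qa else 0.1 * v[0] + 0.9 * v[1]
field = "q_%d_vec" % len(vec)  # field name keyed by dimensionality
print(field, vec.tolist())     # -> q_3_vec [~0.82, ~0.74, ~0.66]
```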

@manager.route('/retrieval_test', methods=['POST'])
@token_required
@validate_request("kb_id", "question")
def retrieval_test(tenant_id):
req = request.json
page = int(req.get("page", 1))
size = int(req.get("size", 30))
question = req["question"]
kb_id = req["kb_id"]
if isinstance(kb_id, str): kb_id = [kb_id]
doc_ids = req.get("doc_ids", [])
similarity_threshold = float(req.get("similarity_threshold", 0.2))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
top = int(req.get("top_k", 1024))

try:
tenants = UserTenantService.query(user_id=tenant_id)
for kid in kb_id:
for tenant in tenants:
if KnowledgebaseService.query(
tenant_id=tenant.tenant_id, id=kid):
break
else:
return get_json_result(
data=False, retmsg='Only the owner of the knowledgebase is authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)

e, kb = KnowledgebaseService.get_by_id(kb_id[0])
if not e:
return get_data_error_result(retmsg="Knowledgebase not found!")

embd_mdl = TenantLLMService.model_instance(
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)

rerank_mdl = None
if req.get("rerank_id"):
rerank_mdl = TenantLLMService.model_instance(
kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])

if req.get("keyword", False):
chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)

retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_id, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
for c in ranks["chunks"]:
if "vector" in c:
del c["vector"]

# rename keys
renamed_chunks = []
for chunk in ranks["chunks"]:
key_mapping = {
"chunk_id": "id",
"content_with_weight": "content",
"doc_id": "document_id",
"important_kwd": "important_keywords",
}
renamed_chunk = {}
for key, value in chunk.items():
new_key = key_mapping.get(key, key)
renamed_chunk[new_key] = value
renamed_chunks.append(renamed_chunk)
ranks["chunks"] = renamed_chunks
return get_json_result(data=ranks)
except Exception as e:
if str(e).find("not_found") > 0:
return get_json_result(data=False, retmsg='No chunk found! Please check the chunk status.',
retcode=RetCode.DATA_ERROR)
return server_error_response(e)
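
The ownership check in `retrieval_test` relies on Python's for/else, where the else branch runs only if the inner loop finishes without hitting `break`. A minimal reproduction with stand-in data:

```python
# for/else as used in the ownership check: else fires only on no-break exit.
tenants = ["t1", "t2"]          # stand-ins for the caller's tenants
authorized = {("t2", "kb-1")}   # stand-in for KnowledgebaseService.query hits
for kid in ["kb-1"]:
    for tenant in tenants:
        if (tenant, kid) in authorized:
            break
    else:
        raise PermissionError("Only the owner of the knowledgebase is authorized.")
print("authorized")
```
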
50 changes: 33 additions & 17 deletions sdk/python/ragflow/modules/chunk.py
@@ -3,32 +3,48 @@

class Chunk(Base):
def __init__(self, rag, res_dict):
# Initialize the class attributes
self.id = ""
self.content_with_weight = ""
self.content_ltks = []
self.content_sm_ltks = []
self.important_kwd = []
self.important_tks = []
self.content = ""
self.important_keywords = []
self.create_time = ""
self.create_timestamp_flt = 0.0
self.kb_id = None
self.docnm_kwd = ""
self.doc_id = ""
self.q_vec = []
self.knowledgebase_id = None
self.document_name = ""
self.document_id = ""
self.status = "1"
for k, v in res_dict.items():
if hasattr(self, k):
setattr(self, k, v)

for k in list(res_dict.keys()):
if k not in self.__dict__:
res_dict.pop(k)
super().__init__(rag, res_dict)
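
The reworked constructor keeps only the response keys that match declared attributes and strips the rest before delegating to `Base`. A self-contained illustration; the `Base` stub below is an assumption, the real class lives elsewhere in the SDK:

```python
# Standalone illustration of the constructor's filtering, with a stub Base.
class Base:
    def __init__(self, rag, res_dict):
        self.rag = rag
        self.res_dict = res_dict

class Chunk(Base):
    def __init__(self, rag, res_dict):
        self.id = ""
        self.content = ""
        for k, v in res_dict.items():
            if hasattr(self, k):
                setattr(self, k, v)
        for k in list(res_dict.keys()):  # list() copy allows popping mid-walk
            if k not in self.__dict__:
                res_dict.pop(k)
        super().__init__(rag, res_dict)

c = Chunk(rag=None, res_dict={"id": "abc", "content": "hi", "vector": [0.1]})
print(c.res_dict)  # {'id': 'abc', 'content': 'hi'}; unknown 'vector' dropped
```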

def delete(self) -> bool:
"""
Delete the chunk in the document.
"""
res = self.rm('/doc/chunk/rm',
{"doc_id": [self.id],""})
res = self.post('/doc/chunk/rm',
{"doc_id": self.document_id, 'chunk_ids': [self.id]})
res = res.json()
if res.get("retmsg") == "success":
return True
raise Exception(res["retmsg"])

def save(self) -> bool:
"""
Save the document details to the server.
"""
res = self.post('/doc/chunk/set',
{"chunk_id": self.id,
"kb_id": self.knowledgebase_id,
"name": self.document_name,
"content_with_weight": self.content,
"important_kwd": self.important_keywords,
"create_time": self.create_time,
"create_timestamp_flt": self.create_timestamp_flt,
"doc_id": self.document_id,
"status": self.status,
})
res = res.json()
if res.get("retmsg") == "success":
return True
raise Exception(res["retmsg"])
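
`save()` translates the SDK-side attribute names (`content`, `important_keywords`, `knowledgebase_id`, `document_id`) back into the server's field names. A hedged sketch of the request it issues; the base URL and auth header are assumptions mirroring the SDK's defaults:

```python
import requests

BASE_URL = "http://127.0.0.1:9380/api/v1"  # assumed default RAGFlow address
API_KEY = "ragflow-..."                    # placeholder API token

payload = {
    "chunk_id": "1b2a...",                  # placeholder md5-style id
    "kb_id": "kb-1",                        # from chunk.knowledgebase_id
    "name": "a.pdf",                        # from chunk.document_name
    "content_with_weight": "Edited text.",  # from chunk.content
    "important_kwd": ["ragflow", "sdk"],    # from chunk.important_keywords
    "create_time": "",
    "create_timestamp_flt": 0.0,
    "doc_id": "doc-123",                    # from chunk.document_id
    "status": "1",
}
res = requests.post(f"{BASE_URL}/doc/chunk/set",
                    headers={"Authorization": f"Bearer {API_KEY}"},
                    json=payload)
print(res.json().get("retmsg"))  # "success" when the update is indexed
```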
