Skip to content

Commit

Permalink
refine admin initialization (#75)
Browse files Browse the repository at this point in the history
  • Loading branch information
KevinHuSh committed Feb 27, 2024
1 parent d1c600d commit 4568a4b
Show file tree
Hide file tree
Showing 13 changed files with 91 additions and 34 deletions.
4 changes: 2 additions & 2 deletions api/apps/chunk_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@
from elasticsearch_dsl import Q

from rag.app.qa import rmPrefix, beAdoc
from rag.nlp import search, huqie, retrievaler
from rag.nlp import search, huqie
from rag.utils import ELASTICSEARCH, rmSpace
from api.db import LLMType, ParserType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService
from api.db.services.user_service import UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.db.services.document_service import DocumentService
from api.settings import RetCode
from api.settings import RetCode, retrievaler
from api.utils.api_utils import get_json_result
import hashlib
import re
Expand Down
4 changes: 1 addition & 3 deletions api/apps/conversation_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,11 @@
from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, LLMBundle
from api.settings import access_logger, stat_logger
from api.settings import access_logger, stat_logger, retrievaler
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
from api.utils.api_utils import get_json_result
from rag.app.resume import forbidden_select_fields4resume
from rag.llm import ChatModel
from rag.nlp import retrievaler
from rag.nlp.search import index_name
from rag.utils import num_tokens_from_string, encoder, rmSpace

Expand Down
46 changes: 42 additions & 4 deletions api/db/init_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
import time
import uuid

from api.db import LLMType
from api.db import LLMType, UserTenantRole
from api.db.db_models import init_database_tables as init_web_db
from api.db.services import UserService
from api.db.services.llm_service import LLMFactoriesService, LLMService
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
from api.db.services.user_service import TenantService, UserTenantService
from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY


def init_superuser():
Expand All @@ -32,8 +34,44 @@ def init_superuser():
"creator": "system",
"status": "1",
}
tenant = {
"id": user_info["id"],
"name": user_info["nickname"] + "‘s Kingdom",
"llm_id": CHAT_MDL,
"embd_id": EMBEDDING_MDL,
"asr_id": ASR_MDL,
"parser_ids": PARSERS,
"img2txt_id": IMAGE2TEXT_MDL
}
usr_tenant = {
"tenant_id": user_info["id"],
"user_id": user_info["id"],
"invited_by": user_info["id"],
"role": UserTenantRole.OWNER
}
tenant_llm = []
for llm in LLMService.query(fid=LLM_FACTORY):
tenant_llm.append(
{"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type,
"api_key": API_KEY})

if not UserService.save(**user_info):
print("【ERROR】can't init admin.")
return
TenantService.save(**tenant)
UserTenantService.save(**usr_tenant)
TenantLLMService.insert_many(tenant_llm)
UserService.save(**user_info)

chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
if msg.find("ERROR: ") == 0:
print("【ERROR】: '{}' dosen't work. {}".format(tenant["llm_id"]), msg)
embd_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["embd_id"])
v,c = embd_mdl.encode(["Hello!"])
if c == 0:
print("【ERROR】: '{}' dosen't work...".format(tenant["embd_id"]))


def init_llm_factory():
factory_infos = [{
Expand Down Expand Up @@ -171,10 +209,10 @@ def init_llm_factory():

def init_web_data():
start_time = time.time()
if not UserService.get_all().count():
init_superuser()

if not LLMService.get_all().count():init_llm_factory()
if not UserService.get_all().count():
init_superuser()

print("init web data success:{}".format(time.time() - start_time))

Expand Down
6 changes: 5 additions & 1 deletion api/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
from api.utils.file_utils import get_project_base_directory
from api.utils.log_utils import LoggerFactory, getLogger

from rag.nlp import search
from rag.utils import ELASTICSEARCH


# Server
API_VERSION = "v1"
RAG_FLOW_SERVICE_NAME = "ragflow"
SERVER_MODULE = "rag_flow_server.py"
Expand Down Expand Up @@ -116,6 +118,8 @@
PRIVILEGE_COMMAND_WHITELIST = []
CHECK_NODES_IDENTITY = False

retrievaler = search.Dealer(ELASTICSEARCH)

class CustomEnum(Enum):
@classmethod
def valid(cls, value):
Expand Down
2 changes: 1 addition & 1 deletion deepdoc/parser/pdf_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def gather(kwd, fzy=10, ption=0.6):
b["H_right"] = headers[ii]["x1"]
b["H"] = ii

ii = Recognizer.find_overlapped_with_threashold(b, clmns, thr=0.3)
ii = Recognizer.find_horizontally_tightest_fit(b, clmns)
if ii is not None:
b["C"] = ii
b["C_left"] = clmns[ii]["x0"]
Expand Down
2 changes: 1 addition & 1 deletion deepdoc/vision/layout_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, domain):
super().__init__(self.labels, domain,
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))

def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.7, batch_size=16):
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
def __is_garbage(b):
patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
Expand Down
5 changes: 2 additions & 3 deletions deepdoc/vision/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import numpy as np
import cv2
import paddle
from shapely.geometry import Polygon
import pyclipper

Expand Down Expand Up @@ -215,7 +214,7 @@ def box_score_slow(self, bitmap, contour):

def __call__(self, outs_dict, shape_list):
pred = outs_dict['maps']
if isinstance(pred, paddle.Tensor):
if not isinstance(pred, np.ndarray):
pred = pred.numpy()
pred = pred[:, 0, :, :]
segmentation = pred > self.thresh
Expand Down Expand Up @@ -339,7 +338,7 @@ def __init__(self, character_dict_path=None, use_space_char=False,
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, tuple) or isinstance(preds, list):
preds = preds[-1]
if isinstance(preds, paddle.Tensor):
if not isinstance(preds, np.ndarray):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
Expand Down
12 changes: 12 additions & 0 deletions deepdoc/vision/recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,18 @@ def find_overlapped(box, boxes_sorted_by_y, naive=False):

return max_overlaped_i

@staticmethod
def find_horizontally_tightest_fit(box, boxes):
    """Return the index of the entry in ``boxes`` horizontally closest to ``box``.

    Closeness is the smallest of three gaps between ``box`` and a candidate:
    the gap between left edges (``x0``), the gap between right edges
    (``x1``), and the gap between horizontal centres.

    Args:
        box: Mapping with numeric ``x0`` and ``x1`` keys.
        boxes: Sequence of mappings with the same keys.

    Returns:
        The index into ``boxes`` of the closest candidate (the first one on
        a tie), or ``None`` when ``boxes`` is empty.
    """
    if not boxes:
        return None
    # Start from infinity instead of a fixed 1000000 so a non-empty
    # candidate list always yields an index, whatever the coordinate scale.
    min_dis, min_i = float("inf"), None
    for i, b in enumerate(boxes):
        dis = min(
            abs(box["x0"] - b["x0"]),
            abs(box["x1"] - b["x1"]),
            # Half the difference of the edge sums equals the centre gap.
            abs(box["x0"] + box["x1"] - b["x1"] - b["x0"]) / 2,
        )
        if dis < min_dis:
            min_i = i
            min_dis = dis
    return min_i

@staticmethod
def find_overlapped_with_threashold(box, boxes, thr=0.3):
if not boxes:
Expand Down
4 changes: 3 additions & 1 deletion deepdoc/vision/t_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def gather(kwd, fzy=10, ption=0.6):
clmns = sorted([r for r in tb_cpns if re.match(
r"table column$", r["label"])], key=lambda x: x["x0"])
clmns = Recognizer.layouts_cleanup(boxes, clmns, 5, 0.5)

for b in boxes:
ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3)
if ii is not None:
Expand All @@ -89,7 +90,7 @@ def gather(kwd, fzy=10, ption=0.6):
b["H_right"] = headers[ii]["x1"]
b["H"] = ii

ii = Recognizer.find_overlapped_with_threashold(b, clmns, thr=0.3)
ii = Recognizer.find_horizontally_tightest_fit(b, clmns)
if ii is not None:
b["C"] = ii
b["C_left"] = clmns[ii]["x0"]
Expand All @@ -102,6 +103,7 @@ def gather(kwd, fzy=10, ption=0.6):
b["H_left"] = spans[ii]["x0"]
b["H_right"] = spans[ii]["x1"]
b["SP"] = ii

html = """
<html>
<head>
Expand Down
10 changes: 5 additions & 5 deletions deepdoc/vision/table_structure_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import os
import re
from collections import Counter
from copy import deepcopy

import numpy as np

Expand All @@ -37,7 +36,7 @@ def __init__(self):
super().__init__(self.labels, "tsr",
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))

def __call__(self, images, thr=0.5):
def __call__(self, images, thr=0.2):
tbls = super().__call__(images, thr)
res = []
# align left&right for rows, align top&bottom for columns
Expand All @@ -56,8 +55,8 @@ def __call__(self, images, thr=0.5):
"row") > 0 or b["label"].find("header") > 0]
if not left:
continue
left = np.median(left) if len(left) > 4 else np.min(left)
right = np.median(right) if len(right) > 4 else np.max(right)
left = np.mean(left) if len(left) > 4 else np.min(left)
right = np.mean(right) if len(right) > 4 else np.max(right)
for b in lts:
if b["label"].find("row") > 0 or b["label"].find("header") > 0:
if b["x0"] > left:
Expand Down Expand Up @@ -129,6 +128,7 @@ def construct_table(boxes, is_english=False, html=False):
i = 0
while i < len(boxes):
if TableStructureRecognizer.is_caption(boxes[i]):
if is_english: cap + " "
cap += boxes[i]["text"]
boxes.pop(i)
i -= 1
Expand Down Expand Up @@ -398,7 +398,7 @@ def __desc_table(cap, hdr_rowno, tbl, is_english):
for i in range(clmno):
if not tbl[r][i]:
continue
txt = "".join([a["text"].strip() for a in tbl[r][i]])
txt = " ".join([a["text"].strip() for a in tbl[r][i]])
headers[r][i] = txt
hdrset.add(txt)
if all([not t for t in headers[r]]):
Expand Down
19 changes: 11 additions & 8 deletions rag/llm/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#
from abc import ABC
from openai import OpenAI
import os
import openai


class Base(ABC):
Expand All @@ -33,11 +33,14 @@ def __init__(self, key, model_name="gpt-3.5-turbo"):

def chat(self, system, history, gen_conf):
if system: history.insert(0, {"role": "system", "content": system})
res = self.client.chat.completions.create(
model=self.model_name,
messages=history,
**gen_conf)
return res.choices[0].message.content.strip(), res.usage.completion_tokens
try:
res = self.client.chat.completions.create(
model=self.model_name,
messages=history,
**gen_conf)
return res.choices[0].message.content.strip(), res.usage.completion_tokens
except openai.APIError as e:
return "ERROR: "+str(e), 0


from dashscope import Generation
Expand All @@ -58,7 +61,7 @@ def chat(self, system, history, gen_conf):
)
if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content'], response.usage.output_tokens
return response.message, 0
return "ERROR: " + response.message, 0


from zhipuai import ZhipuAI
Expand All @@ -77,4 +80,4 @@ def chat(self, system, history, gen_conf):
)
if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content'], response.usage.completion_tokens
return response.message, 0
return "ERROR: " + response.message, 0
5 changes: 2 additions & 3 deletions rag/nlp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
from . import search
from rag.utils import ELASTICSEARCH

retrievaler = search.Dealer(ELASTICSEARCH)

from nltk.stem import PorterStemmer
stemmer = PorterStemmer()
Expand Down Expand Up @@ -39,10 +36,12 @@
]
]


def random_choices(arr, k):
    """Draw at most ``k`` items from ``arr`` using ``random.choices``.

    The requested count is capped at ``len(arr)``. Because
    ``random.choices`` samples with replacement, the result may contain
    repeated elements.
    """
    return random.choices(arr, k=min(len(arr), k))


def bullets_category(sections):
global BULLET_PATTERN
hits = [0] * len(BULLET_PATTERN)
Expand Down
6 changes: 4 additions & 2 deletions rag/nlp/search.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
import json
import re
from elasticsearch_dsl import Q, Search, A
from elasticsearch_dsl import Q, Search
from typing import List, Optional, Dict, Union
from dataclasses import dataclass

Expand Down Expand Up @@ -183,6 +183,7 @@ def trans2floats(txt):

def insert_citations(self, answer, chunks, chunk_v,
embd_mdl, tkweight=0.3, vtweight=0.7):
assert len(chunks) == len(chunk_v)
pieces = re.split(r"([;。?!!\n]|[a-z][.?;!][ \n])", answer)
for i in range(1, len(pieces)):
if re.match(r"[a-z][.?;!][ \n]", pieces[i]):
Expand Down Expand Up @@ -216,7 +217,7 @@ def insert_citations(self, answer, chunks, chunk_v,
if mx < 0.55:
continue
cites[idx[i]] = list(
set([str(i) for i in range(len(chunk_v)) if sim[i] > mx]))[:4]
set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]

res = ""
for i, p in enumerate(pieces):
Expand All @@ -225,6 +226,7 @@ def insert_citations(self, answer, chunks, chunk_v,
continue
if i not in cites:
continue
assert int(cites[i]) < len(chunk_v)
res += "##%s$$" % "$".join(cites[i])

return res
Expand Down

0 comments on commit 4568a4b

Please sign in to comment.