diff --git a/api/apps/document_app.py b/api/apps/document_app.py index 2d129ed660..6947ac0348 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -23,7 +23,7 @@ from flask import request from flask_login import login_required, current_user -from api.db.db_models import Task +from api.db.db_models import Task, File from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService from api.db.services.task_service import TaskService, queue_tasks @@ -33,7 +33,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils import get_uuid -from api.db import FileType, TaskStatus, ParserType +from api.db import FileType, TaskStatus, ParserType, FileSource from api.db.services.document_service import DocumentService from api.settings import RetCode from api.utils.api_utils import get_json_result @@ -59,12 +59,19 @@ def upload(): return get_json_result( data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR) + e, kb = KnowledgebaseService.get_by_id(kb_id) + if not e: + raise LookupError("Can't find this knowledgebase!") + + root_folder = FileService.get_root_folder(current_user.id) + pf_id = root_folder["id"] + FileService.init_knowledgebase_docs(pf_id, current_user.id) + kb_root_folder = FileService.get_kb_folder(current_user.id) + kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"]) + err = [] for file in file_objs: try: - e, kb = KnowledgebaseService.get_by_id(kb_id) - if not e: - raise LookupError("Can't find this knowledgebase!") MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0)) if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER: raise RuntimeError("Exceed the maximum file number of a free user!") @@ -99,6 +106,8 @@ def upload(): if re.search(r"\.(ppt|pptx|pages)$", filename): doc["parser_id"] = ParserType.PRESENTATION.value DocumentService.insert(doc) + + FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id) except Exception as e: err.append(file.filename + ": " + str(e)) if err: @@ -228,11 +237,13 @@ def rm(): req = request.json doc_ids = req["doc_id"] if isinstance(doc_ids, str): doc_ids = [doc_ids] + root_folder = FileService.get_root_folder(current_user.id) + pf_id = root_folder["id"] + FileService.init_knowledgebase_docs(pf_id, current_user.id) errors = "" for doc_id in doc_ids: try: e, doc = DocumentService.get_by_id(doc_id) - if not e: return get_data_error_result(retmsg="Document not found!") tenant_id = DocumentService.get_tenant_id(doc_id) @@ -241,21 +252,25 @@ def rm(): ELASTICSEARCH.deleteByQuery( Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id)) - DocumentService.increment_chunk_num( - doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, 0) + + DocumentService.clear_chunk_num(doc_id) + b, n = File2DocumentService.get_minio_address(doc_id=doc_id) + if not DocumentService.delete(doc): return get_data_error_result( retmsg="Database error (Document removal)!") - informs = File2DocumentService.get_by_document_id(doc_id) - if not informs: - MINIO.rm(doc.kb_id, doc.location) - else: - File2DocumentService.delete_by_document_id(doc_id) + f2d = File2DocumentService.get_by_document_id(doc_id) + FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id]) + File2DocumentService.delete_by_document_id(doc_id) + + MINIO.rm(b, n) except Exception as e: errors += str(e) - if errors: return server_error_response(e) + if errors: + return get_json_result(data=False, retmsg=errors, retcode=RetCode.SERVER_ERROR) + return get_json_result(data=True) diff --git a/api/apps/file_app.py b/api/apps/file_app.py index b94c15506f..5dd4220c96 100644 --- a/api/apps/file_app.py +++ b/api/apps/file_app.py @@ -26,7 +26,7 @@ from api.db.services.file2document_service import File2DocumentService from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils import get_uuid -from api.db import FileType +from api.db import FileType, FileSource from api.db.services import duplicate_name from api.db.services.file_service import FileService from api.settings import RetCode @@ -45,7 +45,7 @@ def upload(): if not pf_id: root_folder = FileService.get_root_folder(current_user.id) - pf_id = root_folder.id + pf_id = root_folder["id"] if 'file' not in request.files: return get_json_result( @@ -132,7 +132,7 @@ def create(): input_file_type = request.json.get("type") if not pf_id: root_folder = FileService.get_root_folder(current_user.id) - pf_id = root_folder.id + pf_id = root_folder["id"] try: if not FileService.is_parent_folder_exist(pf_id): @@ -176,7 +176,8 @@ def list(): desc = request.args.get("desc", True) if not pf_id: root_folder = FileService.get_root_folder(current_user.id) - pf_id = root_folder.id + pf_id = root_folder["id"] + FileService.init_knowledgebase_docs(pf_id, current_user.id) try: e, file = FileService.get_by_id(pf_id) if not e: @@ -199,7 +200,7 @@ def list(): def get_root_folder(): try: root_folder = FileService.get_root_folder(current_user.id) - return get_json_result(data={"root_folder": root_folder.to_json()}) + return get_json_result(data={"root_folder": root_folder}) except Exception as e: return server_error_response(e) @@ -250,6 +251,8 @@ def rm(): return get_data_error_result(retmsg="File or Folder not found!") if not file.tenant_id: return get_data_error_result(retmsg="Tenant not found!") + if file.source_type == FileSource.KNOWLEDGEBASE: + continue if file.type == FileType.FOLDER.value: file_id_list = FileService.get_all_innermost_file_ids(file_id, []) diff --git a/api/db/__init__.py b/api/db/__init__.py index e1a9ff2d82..0612754740 100644 --- a/api/db/__init__.py +++ b/api/db/__init__.py @@ -83,3 +83,11 @@ class ParserType(StrEnum): NAIVE = "naive" PICTURE = "picture" ONE = "one" + + +class FileSource(StrEnum): + LOCAL = "" + KNOWLEDGEBASE = "knowledgebase" + S3 = "s3" + +KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase" \ No newline at end of file diff --git a/api/db/db_models.py b/api/db/db_models.py index db73e84b97..ecd97b2c64 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -21,14 +21,13 @@ from functools import wraps from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer from flask_login import UserMixin - +from playhouse.migrate import MySQLMigrator, migrate from peewee import ( - BigAutoField, BigIntegerField, BooleanField, CharField, - CompositeKey, Insert, IntegerField, TextField, FloatField, DateTimeField, + BigIntegerField, BooleanField, CharField, + CompositeKey, IntegerField, TextField, FloatField, DateTimeField, Field, Model, Metadata ) from playhouse.pool import PooledMySQLDatabase - from api.db import SerializedType, ParserType from api.settings import DATABASE, stat_logger, SECRET_KEY from api.utils.log_utils import getLogger @@ -344,7 +343,7 @@ class Meta: @DB.connection_context() -def init_database_tables(): +def init_database_tables(alter_fields=[]): members = inspect.getmembers(sys.modules[__name__], inspect.isclass) table_objs = [] create_failed_list = [] @@ -361,6 +360,7 @@ def init_database_tables(): if create_failed_list: LOGGER.info(f"create tables failed: {create_failed_list}") raise Exception(f"create tables failed: {create_failed_list}") + migrate_db() def fill_db_model_object(model_object, human_model_dict): @@ -699,6 +699,11 @@ class File(DataBaseModel): help_text="where dose it store") size = IntegerField(default=0) type = CharField(max_length=32, null=False, help_text="file extension") + source_type = CharField( + max_length=128, + null=False, + default="", + help_text="where dose this document come from") class Meta: db_table = "file" @@ -817,3 +822,14 @@ class API4Conversation(DataBaseModel): class Meta: db_table = "api_4_conversation" + + +def migrate_db(): + try: + with DB.transaction(): + migrator = MySQLMigrator(DB) + migrate( + migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", help_text="where dose this document come from")) + ) + except Exception as e: + pass diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index 0e7a6e38ce..eca6877a9f 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -150,6 +150,22 @@ def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation): Knowledgebase.id == kb_id).execute() return num + @classmethod + @DB.connection_context() + def clear_chunk_num(cls, doc_id): + doc = cls.model.get_by_id(doc_id) + assert doc, "Can't fine document in database." + + num = Knowledgebase.update( + token_num=Knowledgebase.token_num - + doc.token_num, + chunk_num=Knowledgebase.chunk_num - + doc.chunk_num, + doc_num=Knowledgebase.doc_num-1 + ).where( + Knowledgebase.id == doc.kb_id).execute() + return num + @classmethod @DB.connection_context() def get_tenant_id(cls, doc_id): diff --git a/api/db/services/file2document_service.py b/api/db/services/file2document_service.py index 18ec03d316..71ae12539b 100644 --- a/api/db/services/file2document_service.py +++ b/api/db/services/file2document_service.py @@ -15,12 +15,12 @@ # from datetime import datetime +from api.db import FileSource from api.db.db_models import DB -from api.db.db_models import File, Document, File2Document +from api.db.db_models import File, File2Document from api.db.services.common_service import CommonService from api.db.services.document_service import DocumentService -from api.db.services.file_service import FileService -from api.utils import current_timestamp, datetime_format +from api.utils import current_timestamp, datetime_format, get_uuid class File2DocumentService(CommonService): @@ -71,13 +71,15 @@ def update_by_file_id(cls, file_id, obj): @DB.connection_context() def get_minio_address(cls, doc_id=None, file_id=None): if doc_id: - ids = File2DocumentService.get_by_document_id(doc_id) + f2d = cls.get_by_document_id(doc_id) else: - ids = File2DocumentService.get_by_file_id(file_id) - if ids: - e, file = FileService.get_by_id(ids[0].file_id) - return file.parent_id, file.location - else: - assert doc_id, "please specify doc_id" - e, doc = DocumentService.get_by_id(doc_id) - return doc.kb_id, doc.location + f2d = cls.get_by_file_id(file_id) + if f2d: + file = File.get_by_id(f2d[0].file_id) + if file.source_type == FileSource.LOCAL: + return file.parent_id, file.location + doc_id = f2d[0].document_id + + assert doc_id, "please specify doc_id" + e, doc = DocumentService.get_by_id(doc_id) + return doc.kb_id, doc.location diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py index 57948d4211..664a117739 100644 --- a/api/db/services/file_service.py +++ b/api/db/services/file_service.py @@ -16,10 +16,12 @@ from flask_login import current_user from peewee import fn -from api.db import FileType +from api.db import FileType, KNOWLEDGEBASE_FOLDER_NAME, FileSource from api.db.db_models import DB, File2Document, Knowledgebase from api.db.db_models import File, Document from api.db.services.common_service import CommonService +from api.db.services.document_service import DocumentService +from api.db.services.file2document_service import File2DocumentService from api.utils import get_uuid @@ -33,10 +35,15 @@ def get_by_pf_id(cls, tenant_id, pf_id, page_number, items_per_page, if keywords: files = cls.model.select().where( (cls.model.tenant_id == tenant_id) - & (cls.model.parent_id == pf_id), (fn.LOWER(cls.model.name).like(f"%%{keywords.lower()}%%"))) + (cls.model.parent_id == pf_id), + (fn.LOWER(cls.model.name).like(f"%%{keywords.lower()}%%")), + ~(cls.model.id == pf_id) + ) else: - files = cls.model.select().where((cls.model.tenant_id == tenant_id) - & (cls.model.parent_id == pf_id)) + files = cls.model.select().where((cls.model.tenant_id == tenant_id), + (cls.model.parent_id == pf_id), + ~(cls.model.id == pf_id) + ) count = files.count() if desc: files = files.order_by(cls.model.getter_by(orderby).desc()) @@ -135,29 +142,69 @@ def is_parent_folder_exist(cls, parent_id): @classmethod @DB.connection_context() def get_root_folder(cls, tenant_id): - file = cls.model.select().where(cls.model.tenant_id == tenant_id and - cls.model.parent_id == cls.model.id) - if not file: - file_id = get_uuid() - file = { - "id": file_id, - "parent_id": file_id, - "tenant_id": tenant_id, - "created_by": tenant_id, - "name": "/", - "type": FileType.FOLDER.value, - "size": 0, - "location": "", - } - cls.save(**file) - else: - file_id = file[0].id + for file in cls.model.select().where((cls.model.tenant_id == tenant_id), + (cls.model.parent_id == cls.model.id) + ): + return file.to_dict() - e, file = cls.get_by_id(file_id) - if not e: - raise RuntimeError("Database error (File retrieval)!") + file_id = get_uuid() + file = { + "id": file_id, + "parent_id": file_id, + "tenant_id": tenant_id, + "created_by": tenant_id, + "name": "/", + "type": FileType.FOLDER.value, + "size": 0, + "location": "", + } + cls.save(**file) + return file + + @classmethod + @DB.connection_context() + def get_kb_folder(cls, tenant_id): + for root in cls.model.select().where(cls.model.tenant_id == tenant_id and + cls.model.parent_id == cls.model.id): + for folder in cls.model.select().where(cls.model.tenant_id == tenant_id and + cls.model.parent_id == root.id and + cls.model.name == KNOWLEDGEBASE_FOLDER_NAME + ): + return folder.to_dict() + assert False, "Can't find the KB folder. Database init error." + + @classmethod + @DB.connection_context() + def new_a_file_from_kb(cls, tenant_id, name, parent_id, ty=FileType.FOLDER.value, size=0, location=""): + for file in cls.query(tenant_id=tenant_id, parent_id=parent_id, name=name): + return file.to_dict() + file = { + "id": get_uuid(), + "parent_id": parent_id, + "tenant_id": tenant_id, + "created_by": tenant_id, + "name": name, + "type": ty, + "size": size, + "location": location, + "source_type": FileSource.KNOWLEDGEBASE + } + cls.save(**file) return file + @classmethod + @DB.connection_context() + def init_knowledgebase_docs(cls, root_id, tenant_id): + for _ in cls.model.select().where((cls.model.name == KNOWLEDGEBASE_FOLDER_NAME)\ + & (cls.model.parent_id == root_id)): + return + folder = cls.new_a_file_from_kb(tenant_id, KNOWLEDGEBASE_FOLDER_NAME, root_id) + + for kb in Knowledgebase.select(*[Knowledgebase.id, Knowledgebase.name]).where(Knowledgebase.tenant_id==tenant_id): + kb_folder = cls.new_a_file_from_kb(tenant_id, kb.name, folder["id"]) + for doc in DocumentService.query(kb_id=kb.id): + FileService.add_file_from_kb(doc.to_dict(), kb_folder["id"], tenant_id) + @classmethod @DB.connection_context() def get_parent_folder(cls, file_id): @@ -241,3 +288,20 @@ def dfs(parent_id): dfs(folder_id) return size + @classmethod + @DB.connection_context() + def add_file_from_kb(cls, doc, kb_folder_id, tenant_id): + for _ in File2DocumentService.get_by_document_id(doc["id"]): return + file = { + "id": get_uuid(), + "parent_id": kb_folder_id, + "tenant_id": tenant_id, + "created_by": tenant_id, + "name": doc["name"], + "type": doc["type"], + "size": doc["size"], + "location": doc["location"], + "source_type": FileSource.KNOWLEDGEBASE + } + cls.save(**file) + File2DocumentService.save(**{"id": get_uuid(), "file_id": file["id"], "document_id": doc["id"]}) \ No newline at end of file diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index 07548dc10e..59fc6360b5 100644 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -8,14 +8,14 @@ PY=/root/miniconda3/envs/py11/bin/python function task_exe(){ while [ 1 -eq 1 ];do - $PY rag/svr/task_executor.py $1 $2; + $PY rag/svr/task_executor.py ; done } WS=1 for ((i=0;i