Skip to content

Commit

Permalink
let file in knowledgebases visible in file manager (infiniflow#714)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Make files in knowledgebases visible in the file manager.
infiniflow#162 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
  • Loading branch information
KevinHuSh committed May 11, 2024
1 parent 41debb6 commit 0714e32
Show file tree
Hide file tree
Showing 10 changed files with 187 additions and 64 deletions.
43 changes: 29 additions & 14 deletions api/apps/document_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from flask import request
from flask_login import login_required, current_user

from api.db.db_models import Task
from api.db.db_models import Task, File
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.task_service import TaskService, queue_tasks
Expand All @@ -33,7 +33,7 @@
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
from api.db import FileType, TaskStatus, ParserType
from api.db import FileType, TaskStatus, ParserType, FileSource
from api.db.services.document_service import DocumentService
from api.settings import RetCode
from api.utils.api_utils import get_json_result
Expand All @@ -59,12 +59,19 @@ def upload():
return get_json_result(
data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR)

e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
raise LookupError("Can't find this knowledgebase!")

root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder["id"]
FileService.init_knowledgebase_docs(pf_id, current_user.id)
kb_root_folder = FileService.get_kb_folder(current_user.id)
kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])

err = []
for file in file_objs:
try:
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
raise LookupError("Can't find this knowledgebase!")
MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER:
raise RuntimeError("Exceed the maximum file number of a free user!")
Expand Down Expand Up @@ -99,6 +106,8 @@ def upload():
if re.search(r"\.(ppt|pptx|pages)$", filename):
doc["parser_id"] = ParserType.PRESENTATION.value
DocumentService.insert(doc)

FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
except Exception as e:
err.append(file.filename + ": " + str(e))
if err:
Expand Down Expand Up @@ -228,11 +237,13 @@ def rm():
req = request.json
doc_ids = req["doc_id"]
if isinstance(doc_ids, str): doc_ids = [doc_ids]
root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder["id"]
FileService.init_knowledgebase_docs(pf_id, current_user.id)
errors = ""
for doc_id in doc_ids:
try:
e, doc = DocumentService.get_by_id(doc_id)

if not e:
return get_data_error_result(retmsg="Document not found!")
tenant_id = DocumentService.get_tenant_id(doc_id)
Expand All @@ -241,21 +252,25 @@ def rm():

ELASTICSEARCH.deleteByQuery(
Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
DocumentService.increment_chunk_num(
doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, 0)

DocumentService.clear_chunk_num(doc_id)
b, n = File2DocumentService.get_minio_address(doc_id=doc_id)

if not DocumentService.delete(doc):
return get_data_error_result(
retmsg="Database error (Document removal)!")

informs = File2DocumentService.get_by_document_id(doc_id)
if not informs:
MINIO.rm(doc.kb_id, doc.location)
else:
File2DocumentService.delete_by_document_id(doc_id)
f2d = File2DocumentService.get_by_document_id(doc_id)
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
File2DocumentService.delete_by_document_id(doc_id)

MINIO.rm(b, n)
except Exception as e:
errors += str(e)

if errors: return server_error_response(e)
if errors:
return get_json_result(data=False, retmsg=errors, retcode=RetCode.SERVER_ERROR)

return get_json_result(data=True)


Expand Down
13 changes: 8 additions & 5 deletions api/apps/file_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from api.db.services.file2document_service import File2DocumentService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
from api.db import FileType
from api.db import FileType, FileSource
from api.db.services import duplicate_name
from api.db.services.file_service import FileService
from api.settings import RetCode
Expand All @@ -45,7 +45,7 @@ def upload():

if not pf_id:
root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder.id
pf_id = root_folder["id"]

if 'file' not in request.files:
return get_json_result(
Expand Down Expand Up @@ -132,7 +132,7 @@ def create():
input_file_type = request.json.get("type")
if not pf_id:
root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder.id
pf_id = root_folder["id"]

try:
if not FileService.is_parent_folder_exist(pf_id):
Expand Down Expand Up @@ -176,7 +176,8 @@ def list():
desc = request.args.get("desc", True)
if not pf_id:
root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder.id
pf_id = root_folder["id"]
FileService.init_knowledgebase_docs(pf_id, current_user.id)
try:
e, file = FileService.get_by_id(pf_id)
if not e:
Expand All @@ -199,7 +200,7 @@ def list():
def get_root_folder():
try:
root_folder = FileService.get_root_folder(current_user.id)
return get_json_result(data={"root_folder": root_folder.to_json()})
return get_json_result(data={"root_folder": root_folder})
except Exception as e:
return server_error_response(e)

Expand Down Expand Up @@ -250,6 +251,8 @@ def rm():
return get_data_error_result(retmsg="File or Folder not found!")
if not file.tenant_id:
return get_data_error_result(retmsg="Tenant not found!")
if file.source_type == FileSource.KNOWLEDGEBASE:
continue

if file.type == FileType.FOLDER.value:
file_id_list = FileService.get_all_innermost_file_ids(file_id, [])
Expand Down
8 changes: 8 additions & 0 deletions api/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,11 @@ class ParserType(StrEnum):
NAIVE = "naive"
PICTURE = "picture"
ONE = "one"


class FileSource(StrEnum):
LOCAL = ""
KNOWLEDGEBASE = "knowledgebase"
S3 = "s3"

# Folder name reserved for knowledgebase files.
# NOTE(review): presumably the file-manager folder under which
# knowledgebase documents are listed; confirm against FileService.
KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase"
26 changes: 21 additions & 5 deletions api/db/db_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,13 @@
from functools import wraps
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from flask_login import UserMixin

from playhouse.migrate import MySQLMigrator, migrate
from peewee import (
BigAutoField, BigIntegerField, BooleanField, CharField,
CompositeKey, Insert, IntegerField, TextField, FloatField, DateTimeField,
BigIntegerField, BooleanField, CharField,
CompositeKey, IntegerField, TextField, FloatField, DateTimeField,
Field, Model, Metadata
)
from playhouse.pool import PooledMySQLDatabase

from api.db import SerializedType, ParserType
from api.settings import DATABASE, stat_logger, SECRET_KEY
from api.utils.log_utils import getLogger
Expand Down Expand Up @@ -344,7 +343,7 @@ class Meta:


@DB.connection_context()
def init_database_tables():
def init_database_tables(alter_fields=[]):
members = inspect.getmembers(sys.modules[__name__], inspect.isclass)
table_objs = []
create_failed_list = []
Expand All @@ -361,6 +360,7 @@ def init_database_tables():
if create_failed_list:
LOGGER.info(f"create tables failed: {create_failed_list}")
raise Exception(f"create tables failed: {create_failed_list}")
migrate_db()


def fill_db_model_object(model_object, human_model_dict):
Expand Down Expand Up @@ -699,6 +699,11 @@ class File(DataBaseModel):
help_text="where dose it store")
size = IntegerField(default=0)
type = CharField(max_length=32, null=False, help_text="file extension")
source_type = CharField(
max_length=128,
null=False,
default="",
help_text="where dose this document come from")

class Meta:
db_table = "file"
Expand Down Expand Up @@ -817,3 +822,14 @@ class API4Conversation(DataBaseModel):

class Meta:
db_table = "api_4_conversation"


def migrate_db():
    """Apply in-place schema migrations to already-existing tables.

    Currently adds the ``source_type`` column to the ``file`` table. The
    migration runs inside a transaction and is best-effort: a failure is
    logged and otherwise ignored, so the caller is never interrupted.
    """
    try:
        with DB.transaction():
            migrator = MySQLMigrator(DB)
            migrate(
                migrator.add_column(
                    'file', 'source_type',
                    CharField(max_length=128, null=False, default="",
                              help_text="where dose this document come from"))
            )
    except Exception as e:
        # Deliberately non-fatal. Presumably the usual cause is that the
        # column was already added by an earlier run (not verified here).
        # Log it instead of discarding it so real failures stay visible.
        LOGGER.info(f"migrate_db: adding file.source_type skipped: {e}")
16 changes: 16 additions & 0 deletions api/db/services/document_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,22 @@ def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
Knowledgebase.id == kb_id).execute()
return num

@classmethod
@DB.connection_context()
def clear_chunk_num(cls, doc_id):
    """Remove one document's contribution from its knowledgebase counters.

    Loads the document identified by ``doc_id`` and decrements the owning
    knowledgebase's ``token_num`` and ``chunk_num`` by the document's
    values, and its ``doc_num`` by one.

    Args:
        doc_id: Primary key of the document being removed.

    Returns:
        The result of executing the update query on ``Knowledgebase``.

    Raises:
        LookupError: If the lookup yields no document.
    """
    doc = cls.model.get_by_id(doc_id)
    # Explicit check instead of ``assert``: assertions are stripped when
    # Python runs with -O, which would silently disable this guard.
    if not doc:
        raise LookupError("Can't find document in database.")

    num = Knowledgebase.update(
        token_num=Knowledgebase.token_num - doc.token_num,
        chunk_num=Knowledgebase.chunk_num - doc.chunk_num,
        doc_num=Knowledgebase.doc_num - 1
    ).where(
        Knowledgebase.id == doc.kb_id
    ).execute()
    return num

@classmethod
@DB.connection_context()
def get_tenant_id(cls, doc_id):
Expand Down
26 changes: 14 additions & 12 deletions api/db/services/file2document_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
#
from datetime import datetime

from api.db import FileSource
from api.db.db_models import DB
from api.db.db_models import File, Document, File2Document
from api.db.db_models import File, File2Document
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.db.services.file_service import FileService
from api.utils import current_timestamp, datetime_format
from api.utils import current_timestamp, datetime_format, get_uuid


class File2DocumentService(CommonService):
Expand Down Expand Up @@ -71,13 +71,15 @@ def update_by_file_id(cls, file_id, obj):
@DB.connection_context()
def get_minio_address(cls, doc_id=None, file_id=None):
if doc_id:
ids = File2DocumentService.get_by_document_id(doc_id)
f2d = cls.get_by_document_id(doc_id)
else:
ids = File2DocumentService.get_by_file_id(file_id)
if ids:
e, file = FileService.get_by_id(ids[0].file_id)
return file.parent_id, file.location
else:
assert doc_id, "please specify doc_id"
e, doc = DocumentService.get_by_id(doc_id)
return doc.kb_id, doc.location
f2d = cls.get_by_file_id(file_id)
if f2d:
file = File.get_by_id(f2d[0].file_id)
if file.source_type == FileSource.LOCAL:
return file.parent_id, file.location
doc_id = f2d[0].document_id

assert doc_id, "please specify doc_id"
e, doc = DocumentService.get_by_id(doc_id)
return doc.kb_id, doc.location
Loading

0 comments on commit 0714e32

Please sign in to comment.