Skip to content

Commit

Permalink
fix bug of file management (#565)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
  • Loading branch information
KevinHuSh authored Apr 26, 2024
1 parent 6329339 commit ab06f50
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 11 deletions.
2 changes: 1 addition & 1 deletion api/apps/file2document_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def convert():
for file_id in file_ids:
e, file = FileService.get_by_id(file_id)
file_ids_list = [file_id]
if file.type == FileType.FOLDER:
if file.type == FileType.FOLDER.value:
file_ids_list = FileService.get_all_innermost_file_ids(file_id, [])
for id in file_ids_list:
informs = File2DocumentService.get_by_file_id(id)
Expand Down
8 changes: 4 additions & 4 deletions api/apps/file_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def upload():
return get_data_error_result(
retmsg="Can't find this folder!")
MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER:
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(current_user.id) >= MAX_FILE_NUM_PER_USER:
return get_data_error_result(
retmsg="Exceed the maximum file number of a free user!")

Expand Down Expand Up @@ -143,9 +143,9 @@ def create():
retmsg="Duplicated folder name in the same folder.")

if input_file_type == FileType.FOLDER.value:
file_type = FileType.FOLDER
file_type = FileType.FOLDER.value
else:
file_type = FileType.VIRTUAL
file_type = FileType.VIRTUAL.value

file = FileService.insert({
"id": get_uuid(),
Expand Down Expand Up @@ -251,7 +251,7 @@ def rm():
if not file.tenant_id:
return get_data_error_result(retmsg="Tenant not found!")

if file.type == FileType.FOLDER:
if file.type == FileType.FOLDER.value:
file_id_list = FileService.get_all_innermost_file_ids(file_id, [])
for inner_file_id in file_id_list:
e, file = FileService.get_by_id(inner_file_id)
Expand Down
4 changes: 2 additions & 2 deletions api/apps/user_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from api.db.services.llm_service import TenantLLMService, LLMService
from api.utils.api_utils import server_error_response, validate_request
from api.utils import get_uuid, get_format_time, decrypt, download_img, current_timestamp, datetime_format
from api.db import UserTenantRole, LLMType
from api.db import UserTenantRole, LLMType, FileType
from api.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, API_KEY, \
LLM_FACTORY, LLM_BASE_URL
from api.db.services.user_service import UserService, TenantService, UserTenantService
Expand Down Expand Up @@ -229,7 +229,7 @@ def user_register(user_id, user):
"tenant_id": user_id,
"created_by": user_id,
"name": "/",
"type": FileType.FOLDER,
"type": FileType.FOLDER.value,
"size": 0,
"location": "",
}
Expand Down
26 changes: 22 additions & 4 deletions api/db/services/file_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def create_folder(cls, file, parent_id, name, count):
"name": name[count],
"location": "",
"size": 0,
"type": FileType.FOLDER
"type": FileType.FOLDER.value
})
return cls.create_folder(file, file.id, name, count + 1)

Expand All @@ -138,7 +138,23 @@ def is_parent_folder_exist(cls, parent_id):
def get_root_folder(cls, tenant_id):
file = cls.model.select().where(cls.model.tenant_id == tenant_id and
cls.model.parent_id == cls.model.id)
e, file = cls.get_by_id(file[0].id)
if not file:
file_id = get_uuid()
file = {
"id": file_id,
"parent_id": file_id,
"tenant_id": tenant_id,
"created_by": tenant_id,
"name": "/",
"type": FileType.FOLDER.value,
"size": 0,
"location": "",
}
cls.save(**file)
else:
file_id = file[0].id

e, file = cls.get_by_id(file_id)
if not e:
raise RuntimeError("Database error (File retrieval)!")
return file
Expand Down Expand Up @@ -214,12 +230,14 @@ def get_file_count(cls, tenant_id):
@DB.connection_context()
def get_folder_size(cls, folder_id):
size = 0

def dfs(parent_id):
nonlocal size
for f in cls.model.select(*[cls.model.id, cls.model.size, cls.model.type]).where(cls.model.parent_id == parent_id):
for f in cls.model.select(*[cls.model.id, cls.model.size, cls.model.type]).where(
cls.model.parent_id == parent_id, cls.model.id != parent_id):
size += f.size
if f.type == FileType.FOLDER.value:
dfs(f.id)

dfs(folder_id)
return size
return size

0 comments on commit ab06f50

Please sign in to comment.