Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bug of file management #565

Merged
merged 1 commit into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/apps/file2document_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def convert():
for file_id in file_ids:
e, file = FileService.get_by_id(file_id)
file_ids_list = [file_id]
if file.type == FileType.FOLDER:
if file.type == FileType.FOLDER.value:
file_ids_list = FileService.get_all_innermost_file_ids(file_id, [])
for id in file_ids_list:
informs = File2DocumentService.get_by_file_id(id)
Expand Down
8 changes: 4 additions & 4 deletions api/apps/file_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def upload():
return get_data_error_result(
retmsg="Can't find this folder!")
MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER:
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(current_user.id) >= MAX_FILE_NUM_PER_USER:
return get_data_error_result(
retmsg="Exceed the maximum file number of a free user!")

Expand Down Expand Up @@ -143,9 +143,9 @@ def create():
retmsg="Duplicated folder name in the same folder.")

if input_file_type == FileType.FOLDER.value:
file_type = FileType.FOLDER
file_type = FileType.FOLDER.value
else:
file_type = FileType.VIRTUAL
file_type = FileType.VIRTUAL.value

file = FileService.insert({
"id": get_uuid(),
Expand Down Expand Up @@ -251,7 +251,7 @@ def rm():
if not file.tenant_id:
return get_data_error_result(retmsg="Tenant not found!")

if file.type == FileType.FOLDER:
if file.type == FileType.FOLDER.value:
file_id_list = FileService.get_all_innermost_file_ids(file_id, [])
for inner_file_id in file_id_list:
e, file = FileService.get_by_id(inner_file_id)
Expand Down
4 changes: 2 additions & 2 deletions api/apps/user_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from api.db.services.llm_service import TenantLLMService, LLMService
from api.utils.api_utils import server_error_response, validate_request
from api.utils import get_uuid, get_format_time, decrypt, download_img, current_timestamp, datetime_format
from api.db import UserTenantRole, LLMType
from api.db import UserTenantRole, LLMType, FileType
from api.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, API_KEY, \
LLM_FACTORY, LLM_BASE_URL
from api.db.services.user_service import UserService, TenantService, UserTenantService
Expand Down Expand Up @@ -229,7 +229,7 @@ def user_register(user_id, user):
"tenant_id": user_id,
"created_by": user_id,
"name": "/",
"type": FileType.FOLDER,
"type": FileType.FOLDER.value,
"size": 0,
"location": "",
}
Expand Down
26 changes: 22 additions & 4 deletions api/db/services/file_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def create_folder(cls, file, parent_id, name, count):
"name": name[count],
"location": "",
"size": 0,
"type": FileType.FOLDER
"type": FileType.FOLDER.value
})
return cls.create_folder(file, file.id, name, count + 1)

Expand All @@ -138,7 +138,23 @@ def is_parent_folder_exist(cls, parent_id):
def get_root_folder(cls, tenant_id):
file = cls.model.select().where(cls.model.tenant_id == tenant_id and
cls.model.parent_id == cls.model.id)
e, file = cls.get_by_id(file[0].id)
if not file:
file_id = get_uuid()
file = {
"id": file_id,
"parent_id": file_id,
"tenant_id": tenant_id,
"created_by": tenant_id,
"name": "/",
"type": FileType.FOLDER.value,
"size": 0,
"location": "",
}
cls.save(**file)
else:
file_id = file[0].id

e, file = cls.get_by_id(file_id)
if not e:
raise RuntimeError("Database error (File retrieval)!")
return file
Expand Down Expand Up @@ -214,12 +230,14 @@ def get_file_count(cls, tenant_id):
@DB.connection_context()
def get_folder_size(cls, folder_id):
size = 0

def dfs(parent_id):
nonlocal size
for f in cls.model.select(*[cls.model.id, cls.model.size, cls.model.type]).where(cls.model.parent_id == parent_id):
for f in cls.model.select(*[cls.model.id, cls.model.size, cls.model.type]).where(
cls.model.parent_id == parent_id, cls.model.id != parent_id):
size += f.size
if f.type == FileType.FOLDER.value:
dfs(f.id)

dfs(folder_id)
return size
return size