From d1614107e2c27d28433a11f69aacba273c2c0045 Mon Sep 17 00:00:00 2001 From: KevinHuSh Date: Fri, 17 May 2024 12:07:00 +0800 Subject: [PATCH] fix stream chat for ollama (#816) ### What problem does this PR solve? #709 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- api/apps/api_app.py | 47 ++++++++++++++++++++++++++++++++++++++++--- rag/llm/chat_model.py | 4 ++-- 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/api/apps/api_app.py b/api/apps/api_app.py index 9f80996d99..bc4fadf5e2 100644 --- a/api/apps/api_app.py +++ b/api/apps/api_app.py @@ -21,12 +21,15 @@ from flask_login import login_required, current_user from api.db import FileType, ParserType -from api.db.db_models import APIToken, API4Conversation +from api.db.db_models import APIToken, API4Conversation, Task from api.db.services import duplicate_name from api.db.services.api_service import APITokenService, API4ConversationService from api.db.services.dialog_service import DialogService, chat from api.db.services.document_service import DocumentService +from api.db.services.file2document_service import File2DocumentService +from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService +from api.db.services.task_service import queue_tasks, TaskService from api.db.services.user_service import UserTenantService from api.settings import RetCode from api.utils import get_uuid, current_timestamp, datetime_format @@ -267,6 +270,13 @@ def upload(): if file.filename == '': return get_json_result( data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR) + + root_folder = FileService.get_root_folder(tenant_id) + pf_id = root_folder["id"] + FileService.init_knowledgebase_docs(pf_id, tenant_id) + kb_root_folder = FileService.get_kb_folder(tenant_id) + kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"]) + try: if DocumentService.get_doc_count(kb.tenant_id) >= int(os.environ.get('MAX_FILE_NUM_PER_USER', 8192)): return get_data_error_result( @@ -298,11 +308,42 @@ def upload(): "size": len(blob), "thumbnail": thumbnail(filename, blob) } + + form_data=request.form + if "parser_id" in form_data.keys(): + if request.form.get("parser_id").strip() in list(vars(ParserType).values())[1:-3]: + doc["parser_id"] = request.form.get("parser_id").strip() if doc["type"] == FileType.VISUAL: doc["parser_id"] = ParserType.PICTURE.value if re.search(r"\.(ppt|pptx|pages)$", filename): doc["parser_id"] = ParserType.PRESENTATION.value - doc = DocumentService.insert(doc) - return get_json_result(data=doc.to_json()) + + doc_result = DocumentService.insert(doc) + FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id) except Exception as e: return server_error_response(e) + + if "run" in form_data.keys(): + if request.form.get("run").strip() == "1": + try: + info = {"run": 1, "progress": 0} + info["progress_msg"] = "" + info["chunk_num"] = 0 + info["token_num"] = 0 + DocumentService.update_by_id(doc["id"], info) + # if str(req["run"]) == TaskStatus.CANCEL.value: + tenant_id = DocumentService.get_tenant_id(doc["id"]) + if not tenant_id: + return get_data_error_result(retmsg="Tenant not found!") + + #e, doc = DocumentService.get_by_id(doc["id"]) + TaskService.filter_delete([Task.doc_id == doc["id"]]) + e, doc = DocumentService.get_by_id(doc["id"]) + doc = doc.to_dict() + doc["tenant_id"] = tenant_id + bucket, name = File2DocumentService.get_minio_address(doc_id=doc["id"]) + queue_tasks(doc, bucket, name) + except Exception as e: + return server_error_response(e) + + return get_json_result(data=doc_result.to_json()) diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index f8e741e764..0c53279f69 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -248,8 +248,8 @@ def chat_streamly(self, system, history, gen_conf): ) for resp in response: if resp["done"]: - return resp["prompt_eval_count"] + resp["eval_count"] - ans = resp["message"]["content"] + yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0) + ans += resp["message"]["content"] yield ans except Exception as e: yield ans + "\n**ERROR**: " + str(e)