diff --git a/api/commands.py b/api/commands.py index 59dfce68e0c92f..334e7daab57997 100644 --- a/api/commands.py +++ b/api/commands.py @@ -587,7 +587,7 @@ def upgrade_db(): click.echo(click.style("Starting database migration.", fg="green")) # run db migration - import flask_migrate + import flask_migrate # type: ignore flask_migrate.upgrade() diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 3ea77fadb23344..5a3c6f843290b8 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -413,7 +413,7 @@ def get(self, dataset_id, document_id): indexing_runner = IndexingRunner() try: - response = indexing_runner.indexing_estimate( + estimate_response = indexing_runner.indexing_estimate( current_user.current_tenant_id, [extract_setting], data_process_rule_dict, @@ -421,6 +421,7 @@ def get(self, dataset_id, document_id): "English", dataset_id, ) + return estimate_response.model_dump(), 200 except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " @@ -431,7 +432,7 @@ def get(self, dataset_id, document_id): except Exception as e: raise IndexingEstimateError(str(e)) - return response.model_dump(), 200 + return response, 200 class DocumentBatchIndexingEstimateApi(DocumentResource): @@ -521,6 +522,7 @@ def get(self, dataset_id, batch): "English", dataset_id, ) + return response.model_dump(), 200 except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " @@ -530,7 +532,6 @@ def get(self, dataset_id, batch): raise ProviderNotInitializeError(ex.description) except Exception as e: raise IndexingEstimateError(str(e)) - return response.model_dump(), 200 class DocumentBatchIndexingStatusApi(DocumentResource): diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 34afe2837f4ca5..84c58c62df5b3c 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -22,6 +22,7 @@ from libs.login import current_user from models.dataset import Dataset, Document, DocumentSegment from services.dataset_service import DocumentService +from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.file_service import FileService @@ -67,13 +68,14 @@ def post(self, tenant_id, dataset_id): "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, } args["data_source"] = data_source + knowledge_config = KnowledgeConfig(**args) # validate args - DocumentService.document_create_args_validate(args) + DocumentService.document_create_args_validate(knowledge_config) try: documents, batch = DocumentService.save_document_with_dataset_id( dataset=dataset, - document_data=args, + knowledge_config=knowledge_config, account=current_user, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, created_from="api", @@ -122,12 +124,13 @@ def post(self, tenant_id, dataset_id, document_id): args["data_source"] = data_source # validate args args["original_document_id"] = str(document_id) - DocumentService.document_create_args_validate(args) + knowledge_config = KnowledgeConfig(**args) + DocumentService.document_create_args_validate(knowledge_config) try: documents, batch = DocumentService.save_document_with_dataset_id( dataset=dataset, - 
document_data=args, + knowledge_config=knowledge_config, account=current_user, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, created_from="api", @@ -186,12 +189,13 @@ def post(self, tenant_id, dataset_id): data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} args["data_source"] = data_source # validate args - DocumentService.document_create_args_validate(args) + knowledge_config = KnowledgeConfig(**args) + DocumentService.document_create_args_validate(knowledge_config) try: documents, batch = DocumentService.save_document_with_dataset_id( dataset=dataset, - document_data=args, + knowledge_config=knowledge_config, account=dataset.created_by_account, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, created_from="api", @@ -245,12 +249,14 @@ def post(self, tenant_id, dataset_id, document_id): args["data_source"] = data_source # validate args args["original_document_id"] = str(document_id) - DocumentService.document_create_args_validate(args) + + knowledge_config = KnowledgeConfig(**args) + DocumentService.document_create_args_validate(knowledge_config) try: documents, batch = DocumentService.save_document_with_dataset_id( dataset=dataset, - document_data=args, + knowledge_config=knowledge_config, account=dataset.created_by_account, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, created_from="api", diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index c51dca79efb513..05000c5400ff9e 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -276,7 +276,7 @@ def indexing_estimate( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) - preview_texts = [] + preview_texts = [] # type: ignore total_segments = 0 index_type = doc_form @@ -300,13 +300,13 @@ def indexing_estimate( if len(preview_texts) < 10: if doc_form and doc_form == "qa_model": preview_detail = QAPreviewDetail( - question=document.page_content, answer=document.metadata.get("answer") + question=document.page_content, answer=document.metadata.get("answer") or "" ) preview_texts.append(preview_detail) else: - preview_detail = PreviewDetail(content=document.page_content) + preview_detail = PreviewDetail(content=document.page_content) # type: ignore if document.children: - preview_detail.child_chunks = [child.page_content for child in document.children] + preview_detail.child_chunks = [child.page_content for child in document.children] # type: ignore preview_texts.append(preview_detail) # delete image files and related db records @@ -325,7 +325,7 @@ def indexing_estimate( if doc_form and doc_form == "qa_model": return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[]) - return IndexingEstimate(total_segments=total_segments, preview=preview_texts) + return IndexingEstimate(total_segments=total_segments, preview=preview_texts) # type: ignore def _extract( self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict @@ -454,7 +454,7 @@ def _get_splitter( embedding_model_instance=embedding_model_instance, ) - return character_splitter + return character_splitter # type: ignore def _split_to_documents_for_estimate( self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule @@ -535,7 +535,7 @@ def _load( # create keyword index create_keyword_thread = threading.Thread( target=self._process_keyword_index, - 
args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), + args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore ) create_keyword_thread.start() diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 568517c0ea6d36..3a8200bc7b5650 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -258,78 +258,79 @@ def format_retrieval_documents(documents: list[Document]) -> list[RetrievalSegme include_segment_ids = [] segment_child_map = {} for document in documents: - document_id = document.metadata["document_id"] + document_id = document.metadata.get("document_id") dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() - if dataset_document and dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: - child_index_node_id = document.metadata["doc_id"] - result = ( - db.session.query(ChildChunk, DocumentSegment) - .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) - .filter( - ChildChunk.index_node_id == child_index_node_id, - DocumentSegment.dataset_id == dataset_document.dataset_id, - DocumentSegment.enabled == True, - DocumentSegment.status == "completed", + if dataset_document: + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_index_node_id = document.metadata.get("doc_id") + result = ( + db.session.query(ChildChunk, DocumentSegment) + .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) + .filter( + ChildChunk.index_node_id == child_index_node_id, + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + ) + .first() ) - .first() - ) - if result: - child_chunk, segment = result - if not segment: - continue - if segment.id not in include_segment_ids: - include_segment_ids.append(segment.id) - child_chunk_detail = { - "id": child_chunk.id, - "content": child_chunk.content, - "position": child_chunk.position, - "score": document.metadata.get("score", 0.0), - } - map_detail = { - "max_score": document.metadata.get("score", 0.0), - "child_chunks": [child_chunk_detail], - } - segment_child_map[segment.id] = map_detail - record = { - "segment": segment, - } - records.append(record) + if result: + child_chunk, segment = result + if not segment: + continue + if segment.id not in include_segment_ids: + include_segment_ids.append(segment.id) + child_chunk_detail = { + "id": child_chunk.id, + "content": child_chunk.content, + "position": child_chunk.position, + "score": document.metadata.get("score", 0.0), + } + map_detail = { + "max_score": document.metadata.get("score", 0.0), + "child_chunks": [child_chunk_detail], + } + segment_child_map[segment.id] = map_detail + record = { + "segment": segment, + } + records.append(record) + else: + child_chunk_detail = { + "id": child_chunk.id, + "content": child_chunk.content, + "position": child_chunk.position, + "score": document.metadata.get("score", 0.0), + } + segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) + segment_child_map[segment.id]["max_score"] = max( + segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0) + ) else: - child_chunk_detail = { - "id": child_chunk.id, - "content": child_chunk.content, - "position": child_chunk.position, - "score": document.metadata.get("score", 0.0), - } - 
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) - segment_child_map[segment.id]["max_score"] = max( - segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0) - ) + continue else: - continue - else: - index_node_id = document.metadata["doc_id"] + index_node_id = document.metadata["doc_id"] - segment = ( - db.session.query(DocumentSegment) - .filter( - DocumentSegment.dataset_id == dataset_document.dataset_id, - DocumentSegment.enabled == True, - DocumentSegment.status == "completed", - DocumentSegment.index_node_id == index_node_id, + segment = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.index_node_id == index_node_id, + ) + .first() ) - .first() - ) - if not segment: - continue - include_segment_ids.append(segment.id) - record = { - "segment": segment, - "score": document.metadata.get("score", None), - } + if not segment: + continue + include_segment_ids.append(segment.id) + record = { + "segment": segment, + "score": document.metadata.get("score", None), + } - records.append(record) + records.append(record) for record in records: if record["segment"].id in segment_child_map: record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None) diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 8dfc60184c2fb6..8b95d81cc1124b 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -122,26 +122,27 @@ def add_documents(self, docs: Sequence[Document], allow_update: bool = True, sav db.session.add(segment_document) db.session.flush() if save_child: - for postion, child in enumerate(doc.children, start=1): - child_segment = ChildChunk( - tenant_id=self._dataset.tenant_id, - dataset_id=self._dataset.id, - document_id=self._document_id, - segment_id=segment_document.id, - position=postion, - index_node_id=child.metadata["doc_id"], - index_node_hash=child.metadata["doc_hash"], - content=child.page_content, - word_count=len(child.page_content), - type="automatic", - created_by=self._user_id, - ) - db.session.add(child_segment) + if doc.children: + for postion, child in enumerate(doc.children, start=1): + child_segment = ChildChunk( + tenant_id=self._dataset.tenant_id, + dataset_id=self._dataset.id, + document_id=self._document_id, + segment_id=segment_document.id, + position=postion, + index_node_id=child.metadata.get("doc_id"), + index_node_hash=child.metadata.get("doc_hash"), + content=child.page_content, + word_count=len(child.page_content), + type="automatic", + created_by=self._user_id, + ) + db.session.add(child_segment) else: segment_document.content = doc.page_content if doc.metadata.get("answer"): segment_document.answer = doc.metadata.pop("answer", "") - segment_document.index_node_hash = doc.metadata["doc_hash"] + segment_document.index_node_hash = doc.metadata.get("doc_hash") segment_document.word_count = len(doc.page_content) segment_document.tokens = tokens if save_child and doc.children: @@ -160,8 +161,8 @@ def add_documents(self, docs: Sequence[Document], allow_update: bool = True, sav document_id=self._document_id, segment_id=segment_document.id, position=position, - index_node_id=child.metadata["doc_id"], - index_node_hash=child.metadata["doc_hash"], + index_node_id=child.metadata.get("doc_id"), + index_node_hash=child.metadata.get("doc_hash"), 
content=child.page_content, word_count=len(child.page_content), type="automatic", diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py index c444105bb59443..a3b35458df9ab0 100644 --- a/api/core/rag/extractor/excel_extractor.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -4,7 +4,7 @@ from typing import Optional, cast import pandas as pd -from openpyxl import load_workbook +from openpyxl import load_workbook # type: ignore from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 6d7aa0f7df172e..2bcd1c79bb5dd4 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -81,4 +81,4 @@ def _get_splitter( embedding_model_instance=embedding_model_instance, ) - return character_splitter + return character_splitter # type: ignore diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index ec7126159021ea..dca84b90416e0d 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -30,12 +30,18 @@ def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: def transform(self, documents: list[Document], **kwargs) -> list[Document]: process_rule = kwargs.get("process_rule") + if not process_rule: + raise ValueError("No process rule found.") if process_rule.get("mode") == "automatic": automatic_rule = DatasetProcessRule.AUTOMATIC_RULES rules = Rule(**automatic_rule) else: + if not process_rule.get("rules"): + raise ValueError("No rules found in process rule.") rules = Rule(**process_rule.get("rules")) # Split the text documents into nodes. + if not rules.segmentation: + raise ValueError("No segmentation found in rules.") splitter = self._get_splitter( processing_rule_mode=process_rule.get("mode"), max_tokens=rules.segmentation.max_tokens, diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 7ff15b9f4c86d5..e8423e2b777b15 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -30,8 +30,12 @@ def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: def transform(self, documents: list[Document], **kwargs) -> list[Document]: process_rule = kwargs.get("process_rule") + if not process_rule: + raise ValueError("No process rule found.") + if not process_rule.get("rules"): + raise ValueError("No rules found in process rule.") rules = Rule(**process_rule.get("rules")) - all_documents = [] + all_documents = [] # type: ignore if rules.parent_mode == ParentMode.PARAGRAPH: # Split the text documents into nodes. 
splitter = self._get_splitter( @@ -161,6 +165,8 @@ def _split_child_nodes( process_rule_mode: str, embedding_model_instance: Optional[ModelInstance], ) -> list[ChildDocument]: + if not rules.subchunk_segmentation: + raise ValueError("No subchunk segmentation found in rules.") child_splitter = self._get_splitter( processing_rule_mode=process_rule_mode, max_tokens=rules.subchunk_segmentation.max_tokens, diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 6535d4626117f8..58b50a9fcbc67e 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -37,12 +37,16 @@ def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: def transform(self, documents: list[Document], **kwargs) -> list[Document]: preview = kwargs.get("preview") process_rule = kwargs.get("process_rule") + if not process_rule: + raise ValueError("No process rule found.") + if not process_rule.get("rules"): + raise ValueError("No rules found in process rule.") rules = Rule(**process_rule.get("rules")) splitter = self._get_splitter( processing_rule_mode=process_rule.get("mode"), - max_tokens=rules.segmentation.max_tokens, - chunk_overlap=rules.segmentation.chunk_overlap, - separator=rules.segmentation.separator, + max_tokens=rules.segmentation.max_tokens if rules.segmentation else 0, + chunk_overlap=rules.segmentation.chunk_overlap if rules.segmentation else 0, + separator=rules.segmentation.separator if rules.segmentation else "", embedding_model_instance=kwargs.get("embedding_model_instance"), ) @@ -71,8 +75,8 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]: all_documents.extend(split_documents) if preview: self._format_qa_document( - current_app._get_current_object(), - kwargs.get("tenant_id"), + current_app._get_current_object(), # type: ignore + kwargs.get("tenant_id"), # type: ignore all_documents[0], all_qa_documents, kwargs.get("doc_language", "English"), @@ -85,8 +89,8 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]: document_format_thread = threading.Thread( target=self._format_qa_document, kwargs={ - "flask_app": current_app._get_current_object(), - "tenant_id": kwargs.get("tenant_id"), + "flask_app": current_app._get_current_object(), # type: ignore + "tenant_id": kwargs.get("tenant_id"), # type: ignore "document_node": doc, "all_qa_documents": all_qa_documents, "document_language": kwargs.get("doc_language", "English"), diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index a34afc7bd75fda..421cdc05df7cc0 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from typing import Any, Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel class ChildDocument(BaseModel): @@ -15,7 +15,7 @@ class ChildDocument(BaseModel): """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). """ - metadata: Optional[dict] = Field(default_factory=dict) + metadata: dict = {} class Document(BaseModel): @@ -28,7 +28,7 @@ class Document(BaseModel): """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). 
""" - metadata: Optional[dict] = Field(default_factory=dict) + metadata: dict = {} provider: Optional[str] = "dify" diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py index fcd1547a2fc492..316be12f5c14b8 100644 --- a/api/extensions/ext_blueprints.py +++ b/api/extensions/ext_blueprints.py @@ -5,7 +5,7 @@ def init_app(app: DifyApp): # register blueprint routers - from flask_cors import CORS + from flask_cors import CORS # type: ignore from controllers.console import bp as console_app_bp from controllers.files import bp as files_bp diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py index 4861e33d585781..766954a257371f 100644 --- a/api/schedule/mail_clean_document_notify_task.py +++ b/api/schedule/mail_clean_document_notify_task.py @@ -1,9 +1,9 @@ import logging import time +from collections import defaultdict import click from celery import shared_task # type: ignore -from flask import render_template from extensions.ext_mail import mail from models.account import Account, Tenant, TenantAccountJoin @@ -27,7 +27,7 @@ def send_document_clean_notify_task(): try: dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all() # group by tenant_id - dataset_auto_disable_logs_map = {} + dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) for dataset_auto_disable_log in dataset_auto_disable_logs: dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log) @@ -37,11 +37,13 @@ def send_document_clean_notify_task(): if not tenant: continue current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first() + if not current_owner_join: + continue account = Account.query.filter(Account.id == current_owner_join.account_id).first() if not account: continue - dataset_auto_dataset_map = {} + dataset_auto_dataset_map = {} # type: ignore for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append( dataset_auto_disable_log.document_id @@ -53,14 +55,9 @@ def send_document_clean_notify_task(): document_count = len(document_ids) knowledge_details.append(f"
  • Knowledge base {dataset.name}: {document_count} documents
  • ") - html_content = render_template( - "clean_document_job_mail_template-US.html", - ) - mail.send(to=to, subject="立即加入 Dify 工作空间", html=html_content) - end_at = time.perf_counter() logging.info( click.style("Send document clean notify mail succeeded: latency: {}".format(end_at - start_at), fg="green") ) except Exception: - logging.exception("Send invite member mail to {} failed".format(to)) + logging.exception("Send invite member mail to failed") diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index b191fa2397fa9e..528a0dbcd39d9e 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -4,7 +4,7 @@ from typing import Optional, cast from uuid import uuid4 -import yaml +import yaml # type: ignore from packaging import version from pydantic import BaseModel from sqlalchemy import select @@ -465,7 +465,7 @@ def export_dsl(cls, app_model: App, include_secret: bool = False) -> str: else: cls._append_model_config_export_data(export_data, app_model) - return yaml.dump(export_data, allow_unicode=True) + return yaml.dump(export_data, allow_unicode=True) # type: ignore @classmethod def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, include_secret: bool) -> None: diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 8de28085d45457..1fd18568f54126 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -41,6 +41,7 @@ from services.entities.knowledge_entities.knowledge_entities import ( ChildChunkUpdateArgs, KnowledgeConfig, + RerankingModel, RetrievalModel, SegmentUpdateArgs, ) @@ -548,12 +549,14 @@ class DocumentService: } @staticmethod - def get_document(dataset_id: str, document_id: str) -> Optional[Document]: - document = ( - db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() - ) - - return document + def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]: + if document_id: + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) + return document + else: + return None @staticmethod def get_document_by_id(document_id: str) -> Optional[Document]: @@ -744,25 +747,26 @@ def save_document_with_dataset_id( if features.billing.enabled: if not knowledge_config.original_document_id: count = 0 - if knowledge_config.data_source.info_list.data_source_type == "upload_file": - upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids - count = len(upload_file_list) - elif knowledge_config.data_source.info_list.data_source_type == "notion_import": - notion_info_list = knowledge_config.data_source.info_list.notion_info_list - for notion_info in notion_info_list: - count = count + len(notion_info.pages) - elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": - website_info = knowledge_config.data_source.info_list.website_info_list - count = len(website_info.urls) - batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) - if count > batch_upload_limit: - raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") - - DocumentService.check_documents_upload_quota(count, features) + if knowledge_config.data_source: + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore + count = len(upload_file_list) + elif 
knowledge_config.data_source.info_list.data_source_type == "notion_import": + notion_info_list = knowledge_config.data_source.info_list.notion_info_list + for notion_info in notion_info_list: # type: ignore + count = count + len(notion_info.pages) + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + website_info = knowledge_config.data_source.info_list.website_info_list + count = len(website_info.urls) # type: ignore + batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) + if count > batch_upload_limit: + raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") + + DocumentService.check_documents_upload_quota(count, features) # if dataset is empty, update dataset data_source_type if not dataset.data_source_type: - dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type + dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore if not dataset.indexing_technique: if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: @@ -789,7 +793,7 @@ def save_document_with_dataset_id( "score_threshold_enabled": False, } - dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model + dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model # type: ignore documents = [] if knowledge_config.original_document_id: @@ -801,34 +805,35 @@ def save_document_with_dataset_id( # save process rule if not dataset_process_rule: process_rule = knowledge_config.process_rule - if process_rule.mode in ("custom", "hierarchical"): - dataset_process_rule = DatasetProcessRule( - dataset_id=dataset.id, - mode=process_rule.mode, - rules=process_rule.rules.model_dump_json(), - created_by=account.id, - ) - elif process_rule.mode == "automatic": - dataset_process_rule = DatasetProcessRule( - dataset_id=dataset.id, - mode=process_rule.mode, - rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), - created_by=account.id, - ) - else: - logging.warn( - f"Invalid process rule mode: {process_rule['mode']}, can not find dataset process rule" - ) - return - db.session.add(dataset_process_rule) - db.session.commit() + if process_rule: + if process_rule.mode in ("custom", "hierarchical"): + dataset_process_rule = DatasetProcessRule( + dataset_id=dataset.id, + mode=process_rule.mode, + rules=process_rule.rules.model_dump_json() if process_rule.rules else None, + created_by=account.id, + ) + elif process_rule.mode == "automatic": + dataset_process_rule = DatasetProcessRule( + dataset_id=dataset.id, + mode=process_rule.mode, + rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), + created_by=account.id, + ) + else: + logging.warn( + f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule" + ) + return + db.session.add(dataset_process_rule) + db.session.commit() lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) with redis_client.lock(lock_name, timeout=600): position = DocumentService.get_documents_position(dataset.id) document_ids = [] duplicate_document_ids = [] if knowledge_config.data_source.info_list.data_source_type == "upload_file": - upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids + upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore for file_id in upload_file_list: file = ( db.session.query(UploadFile) @@ -854,7 +859,7 @@ def save_document_with_dataset_id( name=file_name, ).first() 
if document: - document.dataset_process_rule_id = dataset_process_rule.id + document.dataset_process_rule_id = dataset_process_rule.id # type: ignore document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) document.created_from = created_from document.doc_form = knowledge_config.doc_form @@ -868,7 +873,7 @@ def save_document_with_dataset_id( continue document = DocumentService.build_document( dataset, - dataset_process_rule.id, + dataset_process_rule.id, # type: ignore knowledge_config.data_source.info_list.data_source_type, knowledge_config.doc_form, knowledge_config.doc_language, @@ -886,6 +891,8 @@ def save_document_with_dataset_id( position += 1 elif knowledge_config.data_source.info_list.data_source_type == "notion_import": notion_info_list = knowledge_config.data_source.info_list.notion_info_list + if not notion_info_list: + raise ValueError("No notion info list found.") exist_page_ids = [] exist_document = {} documents = Document.query.filter_by( @@ -921,7 +928,7 @@ def save_document_with_dataset_id( } document = DocumentService.build_document( dataset, - dataset_process_rule.id, + dataset_process_rule.id, # type: ignore knowledge_config.data_source.info_list.data_source_type, knowledge_config.doc_form, knowledge_config.doc_language, @@ -944,6 +951,8 @@ def save_document_with_dataset_id( clean_notion_document_task.delay(list(exist_document.values()), dataset.id) elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": website_info = knowledge_config.data_source.info_list.website_info_list + if not website_info: + raise ValueError("No website info list found.") urls = website_info.urls for url in urls: data_source_info = { @@ -959,7 +968,7 @@ def save_document_with_dataset_id( document_name = url document = DocumentService.build_document( dataset, - dataset_process_rule.id, + dataset_process_rule.id, # type: ignore knowledge_config.data_source.info_list.data_source_type, knowledge_config.doc_form, knowledge_config.doc_language, @@ -1054,7 +1063,7 @@ def update_document_with_dataset_id( dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, mode=process_rule.mode, - rules=process_rule.rules.model_dump_json(), + rules=process_rule.rules.model_dump_json() if process_rule.rules else None, created_by=account.id, ) elif process_rule.mode == "automatic": @@ -1073,6 +1082,8 @@ def update_document_with_dataset_id( file_name = "" data_source_info = {} if document_data.data_source.info_list.data_source_type == "upload_file": + if not document_data.data_source.info_list.file_info_list: + raise ValueError("No file info list found.") upload_file_list = document_data.data_source.info_list.file_info_list.file_ids for file_id in upload_file_list: file = ( @@ -1090,6 +1101,8 @@ def update_document_with_dataset_id( "upload_file_id": file_id, } elif document_data.data_source.info_list.data_source_type == "notion_import": + if not document_data.data_source.info_list.notion_info_list: + raise ValueError("No notion info list found.") notion_info_list = document_data.data_source.info_list.notion_info_list for notion_info in notion_info_list: workspace_id = notion_info.workspace_id @@ -1107,20 +1120,21 @@ def update_document_with_dataset_id( data_source_info = { "notion_workspace_id": workspace_id, "notion_page_id": page.page_id, - "notion_page_icon": page.page_icon, + "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, # type: ignore "type": page.type, } elif document_data.data_source.info_list.data_source_type == 
"website_crawl": website_info = document_data.data_source.info_list.website_info_list - urls = website_info.urls - for url in urls: - data_source_info = { - "url": url, - "provider": website_info.provider, - "job_id": website_info.job_id, - "only_main_content": website_info.only_main_content, - "mode": "crawl", - } + if website_info: + urls = website_info.urls + for url in urls: + data_source_info = { + "url": url, + "provider": website_info.provider, + "job_id": website_info.job_id, + "only_main_content": website_info.only_main_content, # type: ignore + "mode": "crawl", + } document.data_source_type = document_data.data_source.info_list.data_source_type document.data_source_info = json.dumps(data_source_info) document.name = file_name @@ -1155,15 +1169,21 @@ def save_document_without_dataset_id(tenant_id: str, knowledge_config: Knowledge if features.billing.enabled: count = 0 if knowledge_config.data_source.info_list.data_source_type == "upload_file": - upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids + upload_file_list = ( + knowledge_config.data_source.info_list.file_info_list.file_ids + if knowledge_config.data_source.info_list.file_info_list + else [] + ) count = len(upload_file_list) elif knowledge_config.data_source.info_list.data_source_type == "notion_import": notion_info_list = knowledge_config.data_source.info_list.notion_info_list - for notion_info in notion_info_list: - count = count + len(notion_info.pages) + if notion_info_list: + for notion_info in notion_info_list: + count = count + len(notion_info.pages) elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": website_info = knowledge_config.data_source.info_list.website_info_list - count = len(website_info.urls) + if website_info: + count = len(website_info.urls) batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") @@ -1174,20 +1194,20 @@ def save_document_without_dataset_id(tenant_id: str, knowledge_config: Knowledge retrieval_model = None if knowledge_config.indexing_technique == "high_quality": dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - knowledge_config.embedding_model_provider, knowledge_config.embedding_model + knowledge_config.embedding_model_provider, # type: ignore + knowledge_config.embedding_model, # type: ignore ) dataset_collection_binding_id = dataset_collection_binding.id if knowledge_config.retrieval_model: retrieval_model = knowledge_config.retrieval_model else: - default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, - "reranking_enable": False, - "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, - "top_k": 2, - "score_threshold_enabled": False, - } - retrieval_model = RetrievalModel(**default_retrieval_model) + retrieval_model = RetrievalModel( + search_method=RetrievalMethod.SEMANTIC_SEARCH.value, + reranking_enable=False, + reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""), + top_k=2, + score_threshold_enabled=False, + ) # save dataset dataset = Dataset( tenant_id=tenant_id, @@ -1557,12 +1577,12 @@ def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, docum raise ValueError("Can't update disabled segment") try: word_count_change = segment.word_count - content = args.content + content = args.content or segment.content if segment.content == content: 
segment.word_count = len(content) if document.doc_form == "qa_model": segment.answer = args.answer - segment.word_count += len(args.answer) + segment.word_count += len(args.answer) if args.answer else 0 word_count_change = segment.word_count - word_count_change if args.keywords: segment.keywords = args.keywords @@ -1577,7 +1597,12 @@ def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, docum db.session.add(document) # update segment index task if args.enabled: - VectorService.create_segments_vector([args.keywords], [segment], dataset) + VectorService.create_segments_vector( + [args.keywords] if args.keywords else None, + [segment], + dataset, + document.doc_form, + ) if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # regenerate child chunks # get embedding model instance @@ -1605,6 +1630,8 @@ def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, docum .filter(DatasetProcessRule.id == document.dataset_process_rule_id) .first() ) + if not processing_rule: + raise ValueError("No processing rule found.") VectorService.generate_child_chunks( segment, document, dataset, embedding_model_instance, processing_rule, True ) @@ -1639,7 +1666,7 @@ def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, docum segment.disabled_by = None if document.doc_form == "qa_model": segment.answer = args.answer - segment.word_count += len(args.answer) + segment.word_count += len(args.answer) if args.answer else 0 word_count_change = segment.word_count - word_count_change # update document word count if word_count_change != 0: @@ -1673,6 +1700,8 @@ def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, docum .filter(DatasetProcessRule.id == document.dataset_process_rule_id) .first() ) + if not processing_rule: + raise ValueError("No processing rule found.") VectorService.generate_child_chunks( segment, document, dataset, embedding_model_instance, processing_rule, True ) diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index 8d6a246b6428d0..76d9c28812eaf4 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -97,7 +97,7 @@ class KnowledgeConfig(BaseModel): original_document_id: Optional[str] = None duplicate: bool = True indexing_technique: Literal["high_quality", "economy"] - data_source: Optional[DataSource] = None + data_source: DataSource process_rule: Optional[ProcessRule] = None retrieval_model: Optional[RetrievalModel] = None doc_form: str = "text_model" diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 0e61beaa90ef20..e9176fc1c6015c 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -69,7 +69,7 @@ def retrieve( db.session.add(dataset_query) db.session.commit() - return cls.compact_retrieve_response(query, all_documents) + return cls.compact_retrieve_response(query, all_documents) # type: ignore @classmethod def external_retrieve( diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 6698e6e7188223..92422bf29dc121 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -29,6 +29,8 @@ def create_segments_vector( .filter(DatasetProcessRule.id == document.dataset_process_rule_id) .first() ) + if not processing_rule: + raise ValueError("No processing rule found.") # get 
embedding model instance if dataset.indexing_technique == "high_quality": # check embedding model setting @@ -98,7 +100,7 @@ def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentS def generate_child_chunks( cls, segment: DocumentSegment, - dataset_document: Document, + dataset_document: DatasetDocument, dataset: Dataset, embedding_model_instance: ModelInstance, processing_rule: DatasetProcessRule, @@ -130,7 +132,7 @@ def generate_child_chunks( doc_language=dataset_document.doc_language, ) # save child chunks - if len(documents) > 0 and len(documents[0].children) > 0: + if documents and documents[0].children: index_processor.load(dataset, documents) for position, child_chunk in enumerate(documents[0].children, start=1): diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py index f52440e6438ad0..3bae82a5e3fff9 100644 --- a/api/tasks/batch_clean_document_task.py +++ b/api/tasks/batch_clean_document_task.py @@ -44,7 +44,8 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form for upload_file_id in image_upload_file_ids: image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() try: - storage.delete(image_file.key) + if image_file and image_file.key: + storage.delete(image_file.key) except Exception: logging.exception( "Delete image_files failed when storage deleted, \
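Notes on the recurring patterns in this diff follow, each with a small illustrative sketch; the sketches are hypothetical code written to explain the pattern, not excerpts from the repository.

Most of the added `# type: ignore` comments (flask_migrate, flask_cors, openpyxl, yaml, current_app._get_current_object()) silence mypy without changing behavior. Where the complaint is only about a missing-stubs import, a narrower option — assuming the project's mypy configuration allows it — is an error-code-qualified ignore or a per-module override:

import yaml  # type: ignore[import-untyped]  # error code on recent mypy; older releases use [import]

# Per-module override, assumed pyproject.toml layout:
# [[tool.mypy.overrides]]
# module = ["yaml", "flask_migrate", "flask_cors", "openpyxl"]
# ignore_missing_imports = true

print(yaml.safe_load("a: 1"))  # {'a': 1}

Either form keeps all other checks active on the affected lines.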
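In DocumentIndexingEstimateApi.get and DocumentBatchIndexingEstimateApi.get the successful return now happens inside the try block, and the trailing `return response, 200` suggests a plain-dict fallback bound earlier in the method, outside the hunks shown. A generic sketch of that split (assumed, simplified types): reusing one name for both the fallback dict and the pydantic estimate forces `.model_dump()` onto a value that is sometimes a dict, which is what the type checker objects to.

from pydantic import BaseModel

class IndexingEstimateSketch(BaseModel):
    total_segments: int = 0
    preview: list[str] = []

def get_estimate(has_file: bool) -> tuple[dict, int]:
    response = {"total_segments": 0, "preview": []}      # fallback is already a dict
    if has_file:
        estimate_response = IndexingEstimateSketch(total_segments=3, preview=["chunk"])
        return estimate_response.model_dump(), 200       # model path returns where the model is in scope
    return response, 200                                 # fallback path returns the dict unchanged

assert get_estimate(True)[0]["total_segments"] == 3
assert get_estimate(False)[0]["preview"] == []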
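The service_api document endpoints and DocumentService now pass a KnowledgeConfig object instead of the raw `args` dict. A minimal sketch of the pattern with assumed, simplified field names (the real model lives in services/entities/knowledge_entities): parse once at the boundary, then work with typed attributes and explicit guards rather than string-keyed lookups.

from typing import Optional
from pydantic import BaseModel

class FileInfoListSketch(BaseModel):
    file_ids: list[str] = []

class InfoListSketch(BaseModel):
    data_source_type: str
    file_info_list: Optional[FileInfoListSketch] = None

class DataSourceSketch(BaseModel):
    type: str
    info_list: InfoListSketch

class KnowledgeConfigSketch(BaseModel):
    data_source: DataSourceSketch
    original_document_id: Optional[str] = None

def create_document(raw_args: dict) -> int:
    config = KnowledgeConfigSketch(**raw_args)           # validation happens here, once
    info_list = config.data_source.info_list
    if info_list.data_source_type == "upload_file":
        if not info_list.file_info_list:
            raise ValueError("No file info list found.")
        return len(info_list.file_info_list.file_ids)
    return 0

assert create_document(
    {"data_source": {"type": "upload_file",
                     "info_list": {"data_source_type": "upload_file",
                                   "file_info_list": {"file_ids": ["f1", "f2"]}}}}
) == 2

The same diff also makes KnowledgeConfig.data_source required (it was Optional), while some call sites still guard it defensively with `if knowledge_config.data_source:` or a `# type: ignore`.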
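indexing_estimate builds `preview_texts` as either PreviewDetail or QAPreviewDetail entries, and the diff papers over the mixed types with `# type: ignore`. A sketch of an alternative, using assumed minimal versions of those two models: a union-typed list lets both shapes coexist without suppressing the checker.

from typing import Union
from pydantic import BaseModel

class PreviewDetailSketch(BaseModel):
    content: str
    child_chunks: list[str] = []

class QAPreviewDetailSketch(BaseModel):
    question: str
    answer: str

preview_texts: list[Union[PreviewDetailSketch, QAPreviewDetailSketch]] = []
preview_texts.append(PreviewDetailSketch(content="parent chunk", child_chunks=["child 1"]))
preview_texts.append(QAPreviewDetailSketch(question="What is indexed?", answer="The parent chunk."))
assert len(preview_texts) == 2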
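The retrieval_service rewrite is mostly an indentation change (everything now sits under `if dataset_document:`), but the parent-child branch is worth restating: child-chunk hits are grouped under their parent segment, collecting every child hit and keeping the best score per segment. A simplified sketch with hypothetical record shapes:

from collections import defaultdict

hits = [
    {"segment_id": "s1", "chunk_id": "c1", "score": 0.82},
    {"segment_id": "s1", "chunk_id": "c2", "score": 0.74},
    {"segment_id": "s2", "chunk_id": "c3", "score": 0.61},
]

grouped: dict[str, dict] = defaultdict(lambda: {"max_score": 0.0, "child_chunks": []})
for hit in hits:
    entry = grouped[hit["segment_id"]]
    entry["child_chunks"].append({"id": hit["chunk_id"], "score": hit["score"]})
    entry["max_score"] = max(entry["max_score"], hit["score"])

assert grouped["s1"]["max_score"] == 0.82
assert [c["id"] for c in grouped["s1"]["child_chunks"]] == ["c1", "c2"]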
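The index processors, DocumentService, and VectorService all gain the same guard-clause shape: check the Optional kwarg or query result and raise a descriptive ValueError before touching its attributes, instead of letting None surface later as an AttributeError. A simplified sketch of the transform() version, with assumed dict-shaped rules:

from typing import Any, Optional

def transform(documents: list[str], **kwargs: Any) -> list[str]:
    process_rule: Optional[dict] = kwargs.get("process_rule")
    if not process_rule:
        raise ValueError("No process rule found.")
    rules = process_rule.get("rules")
    if not rules:
        raise ValueError("No rules found in process rule.")
    segmentation = rules.get("segmentation")
    if not segmentation:
        raise ValueError("No segmentation found in rules.")
    max_tokens = segmentation.get("max_tokens", 500)
    return [doc[:max_tokens] for doc in documents]

assert transform(["some long text"],
                 process_rule={"mode": "custom", "rules": {"segmentation": {"max_tokens": 4}}}) == ["some"]

qa_index_processor takes the other route for segmentation (`rules.segmentation.max_tokens if rules.segmentation else 0`); a fallback of 0 tokens is a questionable default for a splitter, so the raising form above may be the safer pattern to converge on.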
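models/document.py drops `Optional[dict] = Field(default_factory=dict)` in favor of `metadata: dict = {}`. On a plain class a shared mutable default like that would be a bug, but pydantic copies mutable field defaults per instance; making the field non-Optional removes the `metadata could be None` checks, while the switch from `metadata["doc_id"]` to `metadata.get("doc_id")` elsewhere in the diff guards against missing keys. A small sketch with an assumed minimal model:

from pydantic import BaseModel

class DocSketch(BaseModel):
    page_content: str
    metadata: dict = {}

a, b = DocSketch(page_content="a"), DocSketch(page_content="b")
a.metadata["doc_id"] = "1"
assert b.metadata == {}                      # defaults are not shared between instances
assert a.metadata.get("missing") is None     # .get() avoids KeyError for absent keys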
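In mail_clean_document_notify_task the tenant grouping switches from a plain dict to `defaultdict(list)`, which is what makes the `append` on a first-seen tenant_id work:

from collections import defaultdict

rows = [("tenant_a", "log1"), ("tenant_b", "log2"), ("tenant_a", "log3")]

grouped: dict[str, list[str]] = defaultdict(list)
for tenant_id, log_id in rows:
    grouped[tenant_id].append(log_id)        # no KeyError the first time a tenant appears

assert dict(grouped) == {"tenant_a": ["log1", "log3"], "tenant_b": ["log2"]}

The second map in the same function, `dataset_auto_dataset_map = {}  # type: ignore`, appears to keep the plain-dict-plus-append shape, so it looks like a candidate for the same defaultdict treatment.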