diff --git a/api/commands.py b/api/commands.py
index 59dfce68e0c92f..334e7daab57997 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -587,7 +587,7 @@ def upgrade_db():
             click.echo(click.style("Starting database migration.", fg="green"))
 
             # run db migration
-            import flask_migrate
+            import flask_migrate  # type: ignore
 
             flask_migrate.upgrade()
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index 3ea77fadb23344..5a3c6f843290b8 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -413,7 +413,7 @@ def get(self, dataset_id, document_id):
         indexing_runner = IndexingRunner()
 
         try:
-            response = indexing_runner.indexing_estimate(
+            estimate_response = indexing_runner.indexing_estimate(
                 current_user.current_tenant_id,
                 [extract_setting],
                 data_process_rule_dict,
@@ -421,6 +421,7 @@ def get(self, dataset_id, document_id):
                 "English",
                 dataset_id,
             )
+            return estimate_response.model_dump(), 200
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
                 "No Embedding Model available. Please configure a valid provider "
@@ -431,7 +432,7 @@ def get(self, dataset_id, document_id):
         except Exception as e:
             raise IndexingEstimateError(str(e))
 
-        return response.model_dump(), 200
+        return response, 200
 
 
 class DocumentBatchIndexingEstimateApi(DocumentResource):
@@ -521,6 +522,7 @@ def get(self, dataset_id, batch):
                 "English",
                 dataset_id,
             )
+            return response.model_dump(), 200
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
                 "No Embedding Model available. Please configure a valid provider "
@@ -530,7 +532,6 @@ def get(self, dataset_id, batch):
             raise ProviderNotInitializeError(ex.description)
         except Exception as e:
             raise IndexingEstimateError(str(e))
-        return response.model_dump(), 200
 
 
 class DocumentBatchIndexingStatusApi(DocumentResource):
diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index 34afe2837f4ca5..84c58c62df5b3c 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -22,6 +22,7 @@
 from libs.login import current_user
 from models.dataset import Dataset, Document, DocumentSegment
 from services.dataset_service import DocumentService
+from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
 from services.file_service import FileService
 
 
@@ -67,13 +68,14 @@ def post(self, tenant_id, dataset_id):
             "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
         }
         args["data_source"] = data_source
+        knowledge_config = KnowledgeConfig(**args)
         # validate args
-        DocumentService.document_create_args_validate(args)
+        DocumentService.document_create_args_validate(knowledge_config)
 
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
-                document_data=args,
+                knowledge_config=knowledge_config,
                 account=current_user,
                 dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
                 created_from="api",
@@ -122,12 +124,13 @@ def post(self, tenant_id, dataset_id, document_id):
         args["data_source"] = data_source
         # validate args
         args["original_document_id"] = str(document_id)
-        DocumentService.document_create_args_validate(args)
+        knowledge_config = KnowledgeConfig(**args)
+        DocumentService.document_create_args_validate(knowledge_config)
 
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
-                document_data=args,
+                knowledge_config=knowledge_config,
                 account=current_user,
                 dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
                 created_from="api",
@@ -186,12 +189,13 @@ def post(self, tenant_id, dataset_id):
         data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
         args["data_source"] = data_source
         # validate args
-        DocumentService.document_create_args_validate(args)
+        knowledge_config = KnowledgeConfig(**args)
+        DocumentService.document_create_args_validate(knowledge_config)
 
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
-                document_data=args,
+                knowledge_config=knowledge_config,
                 account=dataset.created_by_account,
                 dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
                 created_from="api",
@@ -245,12 +249,14 @@ def post(self, tenant_id, dataset_id, document_id):
         args["data_source"] = data_source
         # validate args
         args["original_document_id"] = str(document_id)
-        DocumentService.document_create_args_validate(args)
+
+        knowledge_config = KnowledgeConfig(**args)
+        DocumentService.document_create_args_validate(knowledge_config)
 
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
-                document_data=args,
+                knowledge_config=knowledge_config,
                 account=dataset.created_by_account,
                 dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
                 created_from="api",
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index c51dca79efb513..05000c5400ff9e 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -276,7 +276,7 @@ def indexing_estimate(
                 tenant_id=tenant_id,
                 model_type=ModelType.TEXT_EMBEDDING,
             )
-        preview_texts = []
+        preview_texts = []  # type: ignore
         total_segments = 0
         index_type = doc_form
 
@@ -300,13 +300,13 @@ def indexing_estimate(
                     if len(preview_texts) < 10:
                         if doc_form and doc_form == "qa_model":
                             preview_detail = QAPreviewDetail(
-                                question=document.page_content, answer=document.metadata.get("answer")
+                                question=document.page_content, answer=document.metadata.get("answer") or ""
                             )
                             preview_texts.append(preview_detail)
                         else:
-                            preview_detail = PreviewDetail(content=document.page_content)
+                            preview_detail = PreviewDetail(content=document.page_content)  # type: ignore
                             if document.children:
-                                preview_detail.child_chunks = [child.page_content for child in document.children]
+                                preview_detail.child_chunks = [child.page_content for child in document.children]  # type: ignore
                             preview_texts.append(preview_detail)
 
             # delete image files and related db records
@@ -325,7 +325,7 @@ def indexing_estimate(
 
         if doc_form and doc_form == "qa_model":
             return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[])
-        return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
+        return IndexingEstimate(total_segments=total_segments, preview=preview_texts)  # type: ignore
 
     def _extract(
         self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
@@ -454,7 +454,7 @@ def _get_splitter(
                 embedding_model_instance=embedding_model_instance,
             )
 
-        return character_splitter
+        return character_splitter  # type: ignore
 
     def _split_to_documents_for_estimate(
         self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule
@@ -535,7 +535,7 @@ def _load(
             # create keyword index
             create_keyword_thread = threading.Thread(
                 target=self._process_keyword_index,
-                args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),
+                args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),  # type: ignore
             )
             create_keyword_thread.start()
 
diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index 568517c0ea6d36..3a8200bc7b5650 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -258,78 +258,79 @@ def format_retrieval_documents(documents: list[Document]) -> list[RetrievalSegme
         include_segment_ids = []
         segment_child_map = {}
         for document in documents:
-            document_id = document.metadata["document_id"]
+            document_id = document.metadata.get("document_id")
             dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
-            if dataset_document and dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
-                child_index_node_id = document.metadata["doc_id"]
-                result = (
-                    db.session.query(ChildChunk, DocumentSegment)
-                    .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
-                    .filter(
-                        ChildChunk.index_node_id == child_index_node_id,
-                        DocumentSegment.dataset_id == dataset_document.dataset_id,
-                        DocumentSegment.enabled == True,
-                        DocumentSegment.status == "completed",
+            if dataset_document:
+                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                    child_index_node_id = document.metadata.get("doc_id")
+                    result = (
+                        db.session.query(ChildChunk, DocumentSegment)
+                        .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
+                        .filter(
+                            ChildChunk.index_node_id == child_index_node_id,
+                            DocumentSegment.dataset_id == dataset_document.dataset_id,
+                            DocumentSegment.enabled == True,
+                            DocumentSegment.status == "completed",
+                        )
+                        .first()
                     )
-                    .first()
-                )
-                if result:
-                    child_chunk, segment = result
-                    if not segment:
-                        continue
-                    if segment.id not in include_segment_ids:
-                        include_segment_ids.append(segment.id)
-                        child_chunk_detail = {
-                            "id": child_chunk.id,
-                            "content": child_chunk.content,
-                            "position": child_chunk.position,
-                            "score": document.metadata.get("score", 0.0),
-                        }
-                        map_detail = {
-                            "max_score": document.metadata.get("score", 0.0),
-                            "child_chunks": [child_chunk_detail],
-                        }
-                        segment_child_map[segment.id] = map_detail
-                        record = {
-                            "segment": segment,
-                        }
-                        records.append(record)
+                    if result:
+                        child_chunk, segment = result
+                        if not segment:
+                            continue
+                        if segment.id not in include_segment_ids:
+                            include_segment_ids.append(segment.id)
+                            child_chunk_detail = {
+                                "id": child_chunk.id,
+                                "content": child_chunk.content,
+                                "position": child_chunk.position,
+                                "score": document.metadata.get("score", 0.0),
+                            }
+                            map_detail = {
+                                "max_score": document.metadata.get("score", 0.0),
+                                "child_chunks": [child_chunk_detail],
+                            }
+                            segment_child_map[segment.id] = map_detail
+                            record = {
+                                "segment": segment,
+                            }
+                            records.append(record)
+                        else:
+                            child_chunk_detail = {
+                                "id": child_chunk.id,
+                                "content": child_chunk.content,
+                                "position": child_chunk.position,
+                                "score": document.metadata.get("score", 0.0),
+                            }
+                            segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
+                            segment_child_map[segment.id]["max_score"] = max(
+                                segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
+                            )
                     else:
-                        child_chunk_detail = {
-                            "id": child_chunk.id,
-                            "content": child_chunk.content,
-                            "position": child_chunk.position,
-                            "score": document.metadata.get("score", 0.0),
-                        }
-                        segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
-                        segment_child_map[segment.id]["max_score"] = max(
-                            segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
-                        )
+                        continue
                 else:
-                    continue
-            else:
-                index_node_id = document.metadata["doc_id"]
+                    index_node_id = document.metadata["doc_id"]
 
-                segment = (
-                    db.session.query(DocumentSegment)
-                    .filter(
-                        DocumentSegment.dataset_id == dataset_document.dataset_id,
-                        DocumentSegment.enabled == True,
-                        DocumentSegment.status == "completed",
-                        DocumentSegment.index_node_id == index_node_id,
+                    segment = (
+                        db.session.query(DocumentSegment)
+                        .filter(
+                            DocumentSegment.dataset_id == dataset_document.dataset_id,
+                            DocumentSegment.enabled == True,
+                            DocumentSegment.status == "completed",
+                            DocumentSegment.index_node_id == index_node_id,
+                        )
+                        .first()
                     )
-                    .first()
-                )
-                if not segment:
-                    continue
-                include_segment_ids.append(segment.id)
-                record = {
-                    "segment": segment,
-                    "score": document.metadata.get("score", None),
-                }
+                    if not segment:
+                        continue
+                    include_segment_ids.append(segment.id)
+                    record = {
+                        "segment": segment,
+                        "score": document.metadata.get("score", None),
+                    }
 
-                records.append(record)
+                    records.append(record)
         for record in records:
             if record["segment"].id in segment_child_map:
                 record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None)
diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py
index 8dfc60184c2fb6..8b95d81cc1124b 100644
--- a/api/core/rag/docstore/dataset_docstore.py
+++ b/api/core/rag/docstore/dataset_docstore.py
@@ -122,26 +122,27 @@ def add_documents(self, docs: Sequence[Document], allow_update: bool = True, sav
                 db.session.add(segment_document)
                 db.session.flush()
                 if save_child:
-                    for postion, child in enumerate(doc.children, start=1):
-                        child_segment = ChildChunk(
-                            tenant_id=self._dataset.tenant_id,
-                            dataset_id=self._dataset.id,
-                            document_id=self._document_id,
-                            segment_id=segment_document.id,
-                            position=postion,
-                            index_node_id=child.metadata["doc_id"],
-                            index_node_hash=child.metadata["doc_hash"],
-                            content=child.page_content,
-                            word_count=len(child.page_content),
-                            type="automatic",
-                            created_by=self._user_id,
-                        )
-                        db.session.add(child_segment)
+                    if doc.children:
+                        for postion, child in enumerate(doc.children, start=1):
+                            child_segment = ChildChunk(
+                                tenant_id=self._dataset.tenant_id,
+                                dataset_id=self._dataset.id,
+                                document_id=self._document_id,
+                                segment_id=segment_document.id,
+                                position=postion,
+                                index_node_id=child.metadata.get("doc_id"),
+                                index_node_hash=child.metadata.get("doc_hash"),
+                                content=child.page_content,
+                                word_count=len(child.page_content),
+                                type="automatic",
+                                created_by=self._user_id,
+                            )
+                            db.session.add(child_segment)
             else:
                 segment_document.content = doc.page_content
                 if doc.metadata.get("answer"):
                     segment_document.answer = doc.metadata.pop("answer", "")
-                segment_document.index_node_hash = doc.metadata["doc_hash"]
+                segment_document.index_node_hash = doc.metadata.get("doc_hash")
                 segment_document.word_count = len(doc.page_content)
                 segment_document.tokens = tokens
                 if save_child and doc.children:
@@ -160,8 +161,8 @@ def add_documents(self, docs: Sequence[Document], allow_update: bool = True, sav
                             document_id=self._document_id,
                             segment_id=segment_document.id,
                             position=position,
-                            index_node_id=child.metadata["doc_id"],
-                            index_node_hash=child.metadata["doc_hash"],
+                            index_node_id=child.metadata.get("doc_id"),
+                            index_node_hash=child.metadata.get("doc_hash"),
                             content=child.page_content,
                             word_count=len(child.page_content),
                             type="automatic",
diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py
index c444105bb59443..a3b35458df9ab0 100644
--- a/api/core/rag/extractor/excel_extractor.py
+++ b/api/core/rag/extractor/excel_extractor.py
@@ -4,7 +4,7 @@
 from typing import Optional, cast
 
 import pandas as pd
-from openpyxl import load_workbook
+from openpyxl import load_workbook  # type: ignore
 
 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py
index 6d7aa0f7df172e..2bcd1c79bb5dd4 100644
--- a/api/core/rag/index_processor/index_processor_base.py
+++ b/api/core/rag/index_processor/index_processor_base.py
@@ -81,4 +81,4 @@ def _get_splitter(
                 embedding_model_instance=embedding_model_instance,
             )
 
-        return character_splitter
+        return character_splitter  # type: ignore
diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py
index ec7126159021ea..dca84b90416e0d 100644
--- a/api/core/rag/index_processor/processor/paragraph_index_processor.py
+++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py
@@ -30,12 +30,18 @@ def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
 
     def transform(self, documents: list[Document], **kwargs) -> list[Document]:
         process_rule = kwargs.get("process_rule")
+        if not process_rule:
+            raise ValueError("No process rule found.")
         if process_rule.get("mode") == "automatic":
             automatic_rule = DatasetProcessRule.AUTOMATIC_RULES
             rules = Rule(**automatic_rule)
         else:
+            if not process_rule.get("rules"):
+                raise ValueError("No rules found in process rule.")
             rules = Rule(**process_rule.get("rules"))
         # Split the text documents into nodes.
+        if not rules.segmentation:
+            raise ValueError("No segmentation found in rules.")
         splitter = self._get_splitter(
             processing_rule_mode=process_rule.get("mode"),
             max_tokens=rules.segmentation.max_tokens,
diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py
index 7ff15b9f4c86d5..e8423e2b777b15 100644
--- a/api/core/rag/index_processor/processor/parent_child_index_processor.py
+++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py
@@ -30,8 +30,12 @@ def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
 
     def transform(self, documents: list[Document], **kwargs) -> list[Document]:
         process_rule = kwargs.get("process_rule")
+        if not process_rule:
+            raise ValueError("No process rule found.")
+        if not process_rule.get("rules"):
+            raise ValueError("No rules found in process rule.")
         rules = Rule(**process_rule.get("rules"))
-        all_documents = []
+        all_documents = []  # type: ignore
         if rules.parent_mode == ParentMode.PARAGRAPH:
             # Split the text documents into nodes.
             splitter = self._get_splitter(
@@ -161,6 +165,8 @@ def _split_child_nodes(
         process_rule_mode: str,
         embedding_model_instance: Optional[ModelInstance],
     ) -> list[ChildDocument]:
+        if not rules.subchunk_segmentation:
+            raise ValueError("No subchunk segmentation found in rules.")
         child_splitter = self._get_splitter(
             processing_rule_mode=process_rule_mode,
             max_tokens=rules.subchunk_segmentation.max_tokens,
diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py
index 6535d4626117f8..58b50a9fcbc67e 100644
--- a/api/core/rag/index_processor/processor/qa_index_processor.py
+++ b/api/core/rag/index_processor/processor/qa_index_processor.py
@@ -37,12 +37,16 @@ def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
     def transform(self, documents: list[Document], **kwargs) -> list[Document]:
         preview = kwargs.get("preview")
         process_rule = kwargs.get("process_rule")
+        if not process_rule:
+            raise ValueError("No process rule found.")
+        if not process_rule.get("rules"):
+            raise ValueError("No rules found in process rule.")
         rules = Rule(**process_rule.get("rules"))
         splitter = self._get_splitter(
             processing_rule_mode=process_rule.get("mode"),
-            max_tokens=rules.segmentation.max_tokens,
-            chunk_overlap=rules.segmentation.chunk_overlap,
-            separator=rules.segmentation.separator,
+            max_tokens=rules.segmentation.max_tokens if rules.segmentation else 0,
+            chunk_overlap=rules.segmentation.chunk_overlap if rules.segmentation else 0,
+            separator=rules.segmentation.separator if rules.segmentation else "",
             embedding_model_instance=kwargs.get("embedding_model_instance"),
         )
@@ -71,8 +75,8 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]:
                 all_documents.extend(split_documents)
         if preview:
             self._format_qa_document(
-                current_app._get_current_object(),
-                kwargs.get("tenant_id"),
+                current_app._get_current_object(),  # type: ignore
+                kwargs.get("tenant_id"),  # type: ignore
                 all_documents[0],
                 all_qa_documents,
                 kwargs.get("doc_language", "English"),
@@ -85,8 +89,8 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]:
                 document_format_thread = threading.Thread(
                     target=self._format_qa_document,
                     kwargs={
-                        "flask_app": current_app._get_current_object(),
-                        "tenant_id": kwargs.get("tenant_id"),
+                        "flask_app": current_app._get_current_object(),  # type: ignore
+                        "tenant_id": kwargs.get("tenant_id"),  # type: ignore
                         "document_node": doc,
                         "all_qa_documents": all_qa_documents,
                         "document_language": kwargs.get("doc_language", "English"),
diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py
index a34afc7bd75fda..421cdc05df7cc0 100644
--- a/api/core/rag/models/document.py
+++ b/api/core/rag/models/document.py
@@ -2,7 +2,7 @@
 from collections.abc import Sequence
 from typing import Any, Optional
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
 
 
 class ChildDocument(BaseModel):
@@ -15,7 +15,7 @@ class ChildDocument(BaseModel):
     """Arbitrary metadata about the page content (e.g., source, relationships to other
         documents, etc.).
    """
-    metadata: Optional[dict] = Field(default_factory=dict)
+    metadata: dict = {}
 
 
 class Document(BaseModel):
@@ -28,7 +28,7 @@ class Document(BaseModel):
     """Arbitrary metadata about the page content (e.g., source, relationships to other
        documents, etc.).
    """
-    metadata: Optional[dict] = Field(default_factory=dict)
+    metadata: dict = {}
 
     provider: Optional[str] = "dify"
diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py
index fcd1547a2fc492..316be12f5c14b8 100644
--- a/api/extensions/ext_blueprints.py
+++ b/api/extensions/ext_blueprints.py
@@ -5,7 +5,7 @@
 def init_app(app: DifyApp):
     # register blueprint routers
 
-    from flask_cors import CORS
+    from flask_cors import CORS  # type: ignore
 
     from controllers.console import bp as console_app_bp
     from controllers.files import bp as files_bp
diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py
index 4861e33d585781..766954a257371f 100644
--- a/api/schedule/mail_clean_document_notify_task.py
+++ b/api/schedule/mail_clean_document_notify_task.py
@@ -1,9 +1,9 @@
 import logging
 import time
+from collections import defaultdict
 
 import click
 from celery import shared_task  # type: ignore
-from flask import render_template
 
 from extensions.ext_mail import mail
 from models.account import Account, Tenant, TenantAccountJoin
@@ -27,7 +27,7 @@ def send_document_clean_notify_task():
     try:
         dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all()
         # group by tenant_id
-        dataset_auto_disable_logs_map = {}
+        dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
         for dataset_auto_disable_log in dataset_auto_disable_logs:
             dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)
 
@@ -37,11 +37,13 @@ def send_document_clean_notify_task():
             if not tenant:
                 continue
             current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
+            if not current_owner_join:
+                continue
             account = Account.query.filter(Account.id == current_owner_join.account_id).first()
             if not account:
                 continue
 
-            dataset_auto_dataset_map = {}
+            dataset_auto_dataset_map = {}  # type: ignore
             for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
                 dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
                     dataset_auto_disable_log.document_id
@@ -53,14 +55,9 @@ def send_document_clean_notify_task():
                 document_count = len(document_ids)
                 knowledge_details.append(f"