From 56e15d09a9e94de963dd5fc38c7455a7aaee99f3 Mon Sep 17 00:00:00 2001
From: yihong
Date: Tue, 24 Dec 2024 18:38:51 +0800
Subject: [PATCH] feat: mypy for all type check (#10921)

---
 .github/workflows/api-tests.yml | 6 +
 api/commands.py | 13 +-
 api/configs/feature/__init__.py | 14 +-
 api/configs/middleware/__init__.py | 4 -
 .../remote_settings_sources/apollo/client.py | 5 +-
 api/constants/model_template.py | 3 +-
 api/controllers/common/fields.py | 2 +-
 api/controllers/console/__init__.py | 95 +++++++-
 api/controllers/console/admin.py | 2 +-
 api/controllers/console/apikey.py | 23 +-
 .../console/app/advanced_prompt_template.py | 2 +-
 api/controllers/console/app/agent.py | 2 +-
 api/controllers/console/app/annotation.py | 6 +-
 api/controllers/console/app/app.py | 4 +-
 api/controllers/console/app/app_import.py | 4 +-
 api/controllers/console/app/audio.py | 2 +-
 api/controllers/console/app/completion.py | 4 +-
 api/controllers/console/app/conversation.py | 15 +-
 .../console/app/conversation_variables.py | 2 +-
 api/controllers/console/app/generator.py | 4 +-
 api/controllers/console/app/message.py | 6 +-
 api/controllers/console/app/model_config.py | 13 +-
 api/controllers/console/app/ops_trace.py | 2 +-
 api/controllers/console/app/site.py | 6 +-
 api/controllers/console/app/statistic.py | 4 +-
 api/controllers/console/app/workflow.py | 2 +-
 .../console/app/workflow_app_log.py | 4 +-
 api/controllers/console/app/workflow_run.py | 4 +-
 .../console/app/workflow_statistic.py | 4 +-
 api/controllers/console/app/wraps.py | 2 +-
 api/controllers/console/auth/activate.py | 6 +-
 .../console/auth/data_source_bearer_auth.py | 4 +-
 .../console/auth/data_source_oauth.py | 8 +-
 .../console/auth/forgot_password.py | 6 +-
 api/controllers/console/auth/login.py | 4 +-
 api/controllers/console/auth/oauth.py | 7 +-
 api/controllers/console/billing/billing.py | 4 +-
 .../console/datasets/data_source.py | 4 +-
 api/controllers/console/datasets/datasets.py | 6 +-
 .../console/datasets/datasets_document.py | 10 +-
 .../console/datasets/datasets_segments.py | 4 +-
 api/controllers/console/datasets/external.py | 4 +-
 .../console/datasets/hit_testing.py | 2 +-
 .../console/datasets/hit_testing_base.py | 4 +-
 api/controllers/console/datasets/website.py | 2 +-
 api/controllers/console/explore/audio.py | 9 +-
 api/controllers/console/explore/completion.py | 23 +-
 .../console/explore/conversation.py | 32 +--
 .../console/explore/installed_app.py | 11 +-
 api/controllers/console/explore/message.py | 25 +-
 api/controllers/console/explore/parameter.py | 2 +-
 .../console/explore/recommended_app.py | 4 +-
 .../console/explore/saved_message.py | 6 +-
 api/controllers/console/explore/workflow.py | 9 +-
 api/controllers/console/explore/wraps.py | 4 +-
 api/controllers/console/extension.py | 4 +-
 api/controllers/console/feature.py | 4 +-
 api/controllers/console/files.py | 9 +-
 api/controllers/console/init_validate.py | 2 +-
 api/controllers/console/ping.py | 2 +-
 api/controllers/console/remote_files.py | 4 +-
 api/controllers/console/setup.py | 2 +-
 api/controllers/console/tag/tags.py | 6 +-
 api/controllers/console/version.py | 2 +-
 api/controllers/console/workspace/account.py | 4 +-
 .../workspace/load_balancing_config.py | 6 +-
 api/controllers/console/workspace/members.py | 31 +--
 .../console/workspace/model_providers.py | 9 +-
 api/controllers/console/workspace/models.py | 6 +-
 .../console/workspace/tool_providers.py | 4 +-
 .../console/workspace/workspace.py | 14 +-
 api/controllers/console/wraps.py | 6 +-
 api/controllers/files/image_preview.py | 2 +-
 api/controllers/files/tool_files.py | 2 +-
 .../inner_api/workspace/workspace.py | 2 +-
 api/controllers/inner_api/wraps.py | 6 +-
 api/controllers/service_api/app/app.py | 2 +-
 api/controllers/service_api/app/audio.py | 4 +-
 api/controllers/service_api/app/completion.py | 2 +-
 .../service_api/app/conversation.py | 4 +-
 api/controllers/service_api/app/file.py | 2 +-
 api/controllers/service_api/app/message.py | 4 +-
 api/controllers/service_api/app/workflow.py | 4 +-
 .../service_api/dataset/dataset.py | 2 +-
 .../service_api/dataset/document.py | 2 +-
 .../service_api/dataset/segment.py | 4 +-
 api/controllers/service_api/index.py | 2 +-
 api/controllers/service_api/wraps.py | 10 +-
 api/controllers/web/app.py | 2 +-
 api/controllers/web/audio.py | 4 +-
 api/controllers/web/completion.py | 2 +-
 api/controllers/web/conversation.py | 4 +-
 api/controllers/web/feature.py | 2 +-
 api/controllers/web/files.py | 4 +-
 api/controllers/web/message.py | 4 +-
 api/controllers/web/passport.py | 2 +-
 api/controllers/web/remote_files.py | 2 +-
 api/controllers/web/saved_message.py | 4 +-
 api/controllers/web/site.py | 2 +-
 api/controllers/web/workflow.py | 2 +-
 api/controllers/web/wraps.py | 2 +-
 api/core/agent/base_agent_runner.py | 47 ++--
 api/core/agent/cot_agent_runner.py | 82 ++++---
 api/core/agent/cot_chat_agent_runner.py | 6 +
 api/core/agent/cot_completion_agent_runner.py | 21 +-
 api/core/agent/entities.py | 2 +-
 api/core/agent/fc_agent_runner.py | 46 ++--
 .../agent/output_parser/cot_output_parser.py | 10 +-
 .../easy_ui_based_app/dataset/manager.py | 2 +
 .../easy_ui_based_app/model_config/manager.py | 2 +-
 .../features/opening_statement/manager.py | 4 +-
 .../app/apps/advanced_chat/app_generator.py | 5 +-
 .../app_generator_tts_publisher.py | 23 +-
 api/core/app/apps/advanced_chat/app_runner.py | 10 +-
 .../advanced_chat/generate_task_pipeline.py | 28 ++-
 .../app/apps/agent_chat/app_config_manager.py | 2 +-
 api/core/app/apps/agent_chat/app_generator.py | 13 +-
 api/core/app/apps/agent_chat/app_runner.py | 19 +-
 .../agent_chat/generate_response_converter.py | 14 +-
 api/core/app/apps/base_app_queue_manager.py | 2 +-
 api/core/app/apps/base_app_runner.py | 45 ++--
 api/core/app/apps/chat/app_generator.py | 11 +-
 .../apps/chat/generate_response_converter.py | 10 +-
 .../app/apps/completion/app_config_manager.py | 2 +-
 api/core/app/apps/completion/app_generator.py | 18 +-
 api/core/app/apps/completion/app_runner.py | 4 +-
 .../completion/generate_response_converter.py | 10 +-
 .../app/apps/message_based_app_generator.py | 12 +-
 api/core/app/apps/workflow/app_generator.py | 2 +-
 .../workflow/generate_response_converter.py | 12 +-
 api/core/app/apps/workflow_app_runner.py | 14 +-
 api/core/app/entities/app_invoke_entities.py | 4 +-
 api/core/app/entities/queue_entities.py | 2 +-
 api/core/app/entities/task_entities.py | 14 +-
 .../annotation_reply/annotation_reply.py | 2 +-
 .../app/features/rate_limiting/rate_limit.py | 2 +-
 .../based_generate_task_pipeline.py | 2 +
 .../easy_ui_based_generate_task_pipeline.py | 51 ++--
 .../app/task_pipeline/message_cycle_manage.py | 2 +-
 .../task_pipeline/workflow_cycle_manage.py | 22 +-
 .../agent_tool_callback_handler.py | 2 +-
 .../index_tool_callback_handler.py | 17 +-
 api/core/entities/model_entities.py | 3 +-
 api/core/entities/provider_configuration.py | 110 +++++----
 .../api_based_extension_requestor.py | 6 +-
 api/core/extension/extensible.py | 9 +-
 api/core/extension/extension.py | 10 +-
 api/core/external_data_tool/api/api.py | 3 +
 .../external_data_tool/external_data_fetch.py | 24 +-
 api/core/external_data_tool/factory.py | 10 +-
 api/core/file/file_manager.py | 5 +-
 api/core/file/tool_file_parser.py | 4 +-
 .../helper/code_executor/code_executor.py | 16 +-
 .../code_executor/jinja2/jinja2_formatter.py | 7 +-
 .../code_executor/template_transformer.py | 3 +-
 api/core/helper/lru_cache.py | 2 +-
 api/core/helper/model_provider_cache.py | 2 +-
 api/core/helper/moderation.py | 4 +-
 api/core/helper/module_import_helper.py | 13 +-
 api/core/helper/tool_parameter_cache.py | 2 +-
 api/core/helper/tool_provider_cache.py | 2 +-
 api/core/hosting_configuration.py | 14 +-
 api/core/indexing_runner.py | 66 ++---
 api/core/llm_generator/llm_generator.py | 89 ++++---
 api/core/memory/token_buffer_memory.py | 2 +-
 api/core/model_manager.py | 149 +++++++-----
 .../callbacks/logging_callback.py | 13 +-
 .../entities/message_entities.py | 3 +-
 .../model_providers/__base/ai_model.py | 13 +-
 .../__base/large_language_model.py | 14 +-
 .../model_providers/__base/model_provider.py | 3 +-
 .../__base/text_embedding_model.py | 6 +-
 .../__base/tokenizers/gpt2_tokenzier.py | 4 +-
 .../model_providers/__base/tts_model.py | 9 +-
 .../azure_openai/speech2text/speech2text.py | 4 +-
 .../model_providers/azure_openai/tts/tts.py | 2 +
 .../model_providers/bedrock/llm/llm.py | 6 +-
 .../model_providers/cohere/rerank/rerank.py | 4 +-
 .../model_providers/fireworks/_common.py | 4 +-
 .../text_embedding/text_embedding.py | 3 +-
 .../model_providers/gitee_ai/_common.py | 2 +-
 .../model_providers/gitee_ai/rerank/rerank.py | 4 +-
 .../gitee_ai/text_embedding/text_embedding.py | 2 +-
 .../model_providers/gitee_ai/tts/tts.py | 8 +-
 .../model_providers/google/llm/llm.py | 2 +-
 .../huggingface_hub/_common.py | 2 +-
 .../huggingface_hub/llm/llm.py | 6 +-
 .../text_embedding/text_embedding.py | 2 +-
 .../model_providers/hunyuan/llm/llm.py | 12 +-
 .../hunyuan/text_embedding/text_embedding.py | 10 +-
 .../jina/text_embedding/jina_tokenizer.py | 4 +-
 .../minimax/llm/chat_completion.py | 30 +--
 .../minimax/llm/chat_completion_pro.py | 26 +-
 .../model_providers/minimax/llm/types.py | 4 +-
 .../nomic/text_embedding/text_embedding.py | 4 +-
 .../model_providers/oci/llm/llm.py | 4 +-
 .../oci/text_embedding/text_embedding.py | 2 +-
 .../ollama/text_embedding/text_embedding.py | 1 +
 .../model_providers/openai/_common.py | 4 +-
 .../openai/moderation/moderation.py | 6 +-
 .../model_providers/openai/openai.py | 3 +-
 .../speech2text/speech2text.py | 1 +
 .../text_embedding/text_embedding.py | 1 +
 .../openai_api_compatible/tts/tts.py | 1 +
 .../openllm/llm/openllm_generate.py | 16 +-
 .../text_embedding/text_embedding.py | 7 +-
 .../model_providers/replicate/_common.py | 2 +-
 .../model_providers/replicate/llm/llm.py | 6 +-
 .../text_embedding/text_embedding.py | 6 +-
 .../model_providers/sagemaker/llm/llm.py | 8 +-
 .../sagemaker/rerank/rerank.py | 3 +-
 .../sagemaker/speech2text/speech2text.py | 3 +-
 .../text_embedding/text_embedding.py | 3 +-
 .../model_providers/sagemaker/tts/tts.py | 2 +-
 .../model_providers/siliconflow/llm/llm.py | 2 +-
 .../model_providers/spark/llm/llm.py | 4 +-
 .../model_providers/togetherai/llm/llm.py | 3 +-
 .../model_providers/tongyi/_common.py | 2 +-
 .../model_providers/tongyi/llm/llm.py | 6 +-
 .../model_providers/tongyi/rerank/rerank.py | 8 +-
 .../tongyi/text_embedding/text_embedding.py | 2 +-
 .../model_providers/tongyi/tts/tts.py | 8 +-
 .../model_providers/upstage/_common.py | 4 +-
 .../model_providers/upstage/llm/llm.py | 2 +-
 .../upstage/text_embedding/text_embedding.py | 5 +-
 .../model_providers/vertex_ai/_common.py | 2 +-
 .../model_providers/vertex_ai/llm/llm.py | 2 +-
 .../model_providers/vessl_ai/llm/llm.py | 4 +-
 .../model_providers/volcengine_maas/client.py | 12 +-
 .../volcengine_maas/legacy/errors.py | 3 +-
 .../volcengine_maas/llm/llm.py | 2 +-
 .../volcengine_maas/llm/models.py | 6 +-
 .../model_providers/wenxin/llm/ernie_bot.py | 5 +-
 .../wenxin/text_embedding/text_embedding.py | 9 +-
 .../model_providers/xinference/llm/llm.py | 2 +-
 .../xinference/rerank/rerank.py | 2 +-
 .../xinference/speech2text/speech2text.py | 2 +-
 .../text_embedding/text_embedding.py | 4 +-
 .../model_providers/xinference/tts/tts.py | 7 +-
 .../xinference/xinference_helper.py | 12 +-
 .../model_providers/yi/llm/llm.py | 2 +-
 .../model_providers/zhipuai/llm/llm.py | 6 +-
 .../zhipuai/text_embedding/text_embedding.py | 2 +-
 .../schema_validators/common_validator.py | 7 +-
 api/core/model_runtime/utils/encoders.py | 3 +-
 api/core/model_runtime/utils/helper.py | 3 +-
 api/core/moderation/api/api.py | 14 +-
 api/core/moderation/base.py | 4 +-
 api/core/moderation/factory.py | 3 +-
 api/core/moderation/input_moderation.py | 8 +-
 api/core/moderation/keywords/keywords.py | 6 +-
 .../openai_moderation/openai_moderation.py | 4 +
 api/core/moderation/output_moderation.py | 2 +-
 api/core/ops/entities/trace_entity.py | 5 +-
 api/core/ops/langfuse_trace/langfuse_trace.py | 12 +-
 .../entities/langsmith_trace_entity.py | 1 -
 .../ops/langsmith_trace/langsmith_trace.py | 155 ++++++++++--
 api/core/ops/ops_trace_manager.py | 52 ++--
 api/core/prompt/advanced_prompt_transform.py | 37 +--
 .../prompt/agent_history_prompt_transform.py | 2 +-
 api/core/prompt/prompt_transform.py | 22 +-
 api/core/prompt/simple_prompt_transform.py | 25 +-
 api/core/prompt/utils/prompt_message_util.py | 8 +-
 .../prompt/utils/prompt_template_parser.py | 3 +-
 api/core/provider_manager.py | 37 ++-
 .../rag/datasource/keyword/jieba/jieba.py | 37 ++-
 .../jieba/jieba_keyword_table_handler.py | 4 +-
 .../rag/datasource/keyword/keyword_base.py | 4 +-
 api/core/rag/datasource/retrieval_service.py | 23 +-
 .../vdb/analyticdb/analyticdb_vector.py | 35 +--
 .../analyticdb/analyticdb_vector_openapi.py | 37 +--
 .../vdb/analyticdb/analyticdb_vector_sql.py | 24 +-
 .../rag/datasource/vdb/baidu/baidu_vector.py | 28 ++-
 .../datasource/vdb/chroma/chroma_vector.py | 24 +-
 .../vdb/couchbase/couchbase_vector.py | 24 +-
 .../vdb/elasticsearch/elasticsearch_vector.py | 22 +-
 .../datasource/vdb/lindorm/lindorm_vector.py | 33 ++-
 .../datasource/vdb/milvus/milvus_vector.py | 22 +-
 .../datasource/vdb/myscale/myscale_vector.py | 19 +-
 .../vdb/oceanbase/oceanbase_vector.py | 4 +-
 .../vdb/opensearch/opensearch_vector.py | 4 +-
 .../rag/datasource/vdb/oracle/oraclevector.py | 42 ++--
 .../datasource/vdb/pgvecto_rs/pgvecto_rs.py | 14 +-
 .../rag/datasource/vdb/pgvector/pgvector.py | 31 +--
 .../datasource/vdb/qdrant/qdrant_vector.py | 21 +-
 .../rag/datasource/vdb/relyt/relyt_vector.py | 26 +-
 .../datasource/vdb/tencent/tencent_vector.py | 30 +--
 .../tidb_on_qdrant/tidb_on_qdrant_vector.py | 31 ++-
 .../vdb/tidb_on_qdrant/tidb_service.py | 11 +-
 .../datasource/vdb/tidb_vector/tidb_vector.py | 12 +-
 api/core/rag/datasource/vdb/vector_base.py | 11 +-
 api/core/rag/datasource/vdb/vector_factory.py | 9 +-
 .../vdb/vikingdb/vikingdb_vector.py | 11 +-
 .../vdb/weaviate/weaviate_vector.py | 14 +-
 api/core/rag/docstore/dataset_docstore.py | 9 +-
 api/core/rag/embedding/cached_embedding.py | 16 +-
 .../rag/extractor/entity/extract_setting.py | 2 +-
 api/core/rag/extractor/excel_extractor.py | 8 +-
 api/core/rag/extractor/extract_processor.py | 20 +-
 .../rag/extractor/firecrawl/firecrawl_app.py | 21 +-
 api/core/rag/extractor/html_extractor.py | 3 +-
 api/core/rag/extractor/notion_extractor.py | 14 +-
 api/core/rag/extractor/pdf_extractor.py | 6 +-
 .../unstructured_eml_extractor.py | 2 +-
 .../unstructured_epub_extractor.py | 3 +
 .../unstructured_ppt_extractor.py | 4 +-
 .../unstructured_pptx_extractor.py | 11 +-
 api/core/rag/extractor/word_extractor.py | 6 +
 .../index_processor/index_processor_base.py | 1 +
 .../index_processor_factory.py | 2 +-
 .../processor/paragraph_index_processor.py | 10 +-
 .../processor/qa_index_processor.py | 27 +-
 api/core/rag/rerank/rerank_model.py | 11 +-
 api/core/rag/rerank/weight_rerank.py | 16 +-
 api/core/rag/retrieval/dataset_retrieval.py | 111 +++++----
 .../multi_dataset_function_call_router.py | 16 +-
 .../router/multi_dataset_react_route.py | 22 +-
 api/core/rag/splitter/fixed_text_splitter.py | 4 +-
 api/core/rag/splitter/text_splitter.py | 4 +-
 api/core/tools/entities/api_entities.py | 2 +-
 api/core/tools/entities/tool_bundle.py | 2 +-
 api/core/tools/entities/tool_entities.py | 12 +-
 api/core/tools/provider/api_tool_provider.py | 70 +++---
 api/core/tools/provider/app_tool_provider.py | 15 +-
 api/core/tools/provider/builtin/_positions.py | 2 +-
 .../provider/builtin/aippt/tools/aippt.py | 36 +--
 .../builtin/arxiv/tools/arxiv_search.py | 2 +-
 .../tools/provider/builtin/audio/tools/tts.py | 20 +-
 .../builtin/aws/tools/apply_guardrail.py | 4 +-
 .../aws/tools/lambda_translate_utils.py | 2 +-
 .../builtin/aws/tools/lambda_yaml_to_json.py | 2 +-
 .../aws/tools/sagemaker_text_rerank.py | 6 +-
 .../builtin/aws/tools/sagemaker_tts.py | 4 +-
 .../builtin/cogview/tools/cogvideo.py | 2 +-
 .../builtin/cogview/tools/cogvideo_job.py | 2 +-
 .../builtin/cogview/tools/cogview3.py | 2 +-
 .../feishu_base/tools/search_records.py | 20 +-
 .../feishu_base/tools/update_records.py | 12 +-
 .../tools/add_event_attendees.py | 8 +-
 .../feishu_calendar/tools/delete_event.py | 6 +-
 .../tools/get_primary_calendar.py | 4 +
 .../feishu_calendar/tools/list_events.py | 12 +-
 .../feishu_calendar/tools/update_event.py | 14 +-
 .../feishu_document/tools/create_document.py | 10 +-
 .../tools/list_document_blocks.py | 6 +-
 .../builtin/json_process/tools/delete.py | 2 +-
 .../builtin/json_process/tools/insert.py | 2 +-
 .../builtin/json_process/tools/parse.py | 2 +-
 .../builtin/json_process/tools/replace.py | 2 +-
 .../builtin/maths/tools/eval_expression.py | 2 +-
 .../builtin/novitaai/_novita_tool_base.py | 2 +-
 .../novitaai/tools/novitaai_createtile.py | 2 +-
 .../novitaai/tools/novitaai_txt2img.py | 2 +-
 .../tools/podcast_audio_generator.py | 2 +-
 .../builtin/qrcode/tools/qrcode_generator.py | 8 +-
 .../builtin/transcript/tools/transcript.py | 2 +-
 .../builtin/twilio/tools/send_message.py | 2 +-
 .../tools/provider/builtin/twilio/twilio.py | 4 +-
 .../provider/builtin/vanna/tools/vanna.py | 5 +-
 .../wikipedia/tools/wikipedia_search.py | 2 +-
 .../provider/builtin/yahoo/tools/analytics.py | 2 +-
 .../provider/builtin/yahoo/tools/news.py | 2 +-
 .../provider/builtin/yahoo/tools/ticker.py | 2 +-
 .../provider/builtin/youtube/tools/videos.py | 2 +-
 .../tools/provider/builtin_tool_provider.py | 70 +++---
 api/core/tools/provider/tool_provider.py | 63 ++---
 .../tools/provider/workflow_tool_provider.py | 13 +-
 api/core/tools/tool/api_tool.py | 14 +-
 api/core/tools/tool/builtin_tool.py | 35 ++-
 .../dataset_multi_retriever_tool.py | 13 +-
 .../dataset_retriever_base_tool.py | 2 +-
 .../dataset_retriever_tool.py | 57 +++--
 api/core/tools/tool/dataset_retriever_tool.py | 11 +-
 api/core/tools/tool/tool.py | 17 +-
 api/core/tools/tool/workflow_tool.py | 14 +-
 api/core/tools/tool_engine.py | 24 +-
 api/core/tools/tool_label_manager.py | 8 +-
 api/core/tools/tool_manager.py | 126 ++++++----
 api/core/tools/utils/configuration.py | 18 +-
 api/core/tools/utils/feishu_api_utils.py | 179 ++++++++------
 api/core/tools/utils/lark_api_utils.py | 193 +++++++++------
 api/core/tools/utils/message_transformer.py | 12 +-
 .../tools/utils/model_invocation_utils.py | 23 +-
 api/core/tools/utils/parser.py | 17 +-
 api/core/tools/utils/web_reader_tool.py | 15 +-
 .../utils/workflow_configuration_sync.py | 4 +-
 api/core/tools/utils/yaml_utils.py | 2 +-
 api/core/variables/variables.py | 3 +-
 .../callbacks/workflow_logging_callback.py | 2 +-
 api/core/workflow/entities/node_entities.py | 4 +-
 .../condition_handlers/condition_handler.py | 2 +-
 .../workflow/graph_engine/entities/graph.py | 16 +-
 .../workflow/graph_engine/graph_engine.py | 55 +++--
 .../nodes/answer/answer_stream_processor.py | 4 +-
 .../nodes/answer/base_stream_processor.py | 8 +-
 api/core/workflow/nodes/base/entities.py | 5 +-
 api/core/workflow/nodes/code/code_node.py | 2 +-
 api/core/workflow/nodes/code/entities.py | 2 +-
 .../workflow/nodes/document_extractor/node.py | 7 +-
 .../nodes/end/end_stream_generate_router.py | 5 +-
 .../nodes/end/end_stream_processor.py | 2 +-
 api/core/workflow/nodes/event/event.py | 2 +-
 .../workflow/nodes/http_request/executor.py | 44 ++--
 api/core/workflow/nodes/http_request/node.py | 8 +-
 .../nodes/iteration/iteration_node.py | 15 +-
 .../knowledge_retrieval_node.py | 16 +-
 api/core/workflow/nodes/list_operator/node.py | 34 +--
 api/core/workflow/nodes/llm/node.py | 17 +-
 api/core/workflow/nodes/loop/loop_node.py | 6 +-
 .../nodes/parameter_extractor/entities.py | 4 +-
 .../parameter_extractor_node.py | 16 +-
 .../nodes/parameter_extractor/prompts.py | 4 +-
 .../question_classifier_node.py | 12 +-
 api/core/workflow/nodes/tool/tool_node.py | 9 +-
 .../nodes/variable_assigner/v1/node.py | 2 +
 .../nodes/variable_assigner/v2/node.py | 6 +-
 api/core/workflow/workflow_entry.py | 15 +-
 .../event_handlers/create_document_index.py | 2 +-
 .../create_site_record_when_app_created.py | 29 +--
 .../deduct_quota_when_message_created.py | 2 +-
 ...rameters_cache_when_sync_draft_workflow.py | 5 +-
 ...aset_join_when_app_model_config_updated.py | 10 +-
 ...oin_when_app_published_workflow_updated.py | 10 +-
 api/extensions/__init__.py | 0
 api/extensions/ext_app_metrics.py | 14 +-
 api/extensions/ext_celery.py | 6 +-
 api/extensions/ext_compress.py | 2 +-
 api/extensions/ext_logging.py | 5 +-
 api/extensions/ext_login.py | 2 +-
 api/extensions/ext_mail.py | 8 +-
 api/extensions/ext_migrate.py | 2 +-
 api/extensions/ext_proxy_fix.py | 2 +-
 api/extensions/ext_sentry.py | 2 +-
 api/extensions/ext_storage.py | 8 +-
 api/extensions/storage/aliyun_oss_storage.py | 12 +-
 api/extensions/storage/aws_s3_storage.py | 8 +-
 api/extensions/storage/azure_blob_storage.py | 8 +-
 api/extensions/storage/baidu_obs_storage.py | 9 +-
 .../storage/google_cloud_storage.py | 4 +-
 api/extensions/storage/huawei_obs_storage.py | 4 +-
 api/extensions/storage/opendal_storage.py | 8 +-
 api/extensions/storage/oracle_oci_storage.py | 6 +-
 api/extensions/storage/supabase_storage.py | 2 +-
 api/extensions/storage/tencent_cos_storage.py | 4 +-
 .../storage/volcengine_tos_storage.py | 4 +-
 api/factories/__init__.py | 0
 api/factories/file_factory.py | 5 +-
 api/factories/variable_factory.py | 21 +-
 api/fields/annotation_fields.py | 2 +-
 api/fields/api_based_extension_fields.py | 2 +-
 api/fields/app_fields.py | 2 +-
 api/fields/conversation_fields.py | 2 +-
 api/fields/conversation_variable_fields.py | 2 +-
 api/fields/data_source_fields.py | 2 +-
 api/fields/dataset_fields.py | 2 +-
 api/fields/document_fields.py | 2 +-
 api/fields/end_user_fields.py | 2 +-
 api/fields/external_dataset_fields.py | 2 +-
 api/fields/file_fields.py | 2 +-
 api/fields/hit_testing_fields.py | 2 +-
 api/fields/installed_app_fields.py | 2 +-
 api/fields/member_fields.py | 2 +-
 api/fields/message_fields.py | 2 +-
 api/fields/raws.py | 2 +-
 api/fields/segment_fields.py | 2 +-
 api/fields/tag_fields.py | 2 +-
 api/fields/workflow_app_log_fields.py | 2 +-
 api/fields/workflow_fields.py | 2 +-
 api/fields/workflow_run_fields.py | 2 +-
 api/libs/external_api.py | 7 +-
 api/libs/gmpy2_pkcs10aep_cipher.py | 8 +-
 api/libs/helper.py | 6 +-
 api/libs/json_in_md_parser.py | 1 +
 api/libs/login.py | 15 +-
 api/libs/oauth.py | 6 +-
 api/libs/oauth_data_source.py | 7 +-
 api/libs/threadings_utils.py | 4 +-
 api/models/account.py | 32 +--
 api/models/api_based_extension.py | 2 +-
 api/models/dataset.py | 33 +--
 api/models/model.py | 73 +++---
 api/models/provider.py | 14 +-
 api/models/source.py | 4 +-
 api/models/task.py | 6 +-
 api/models/tools.py | 22 +-
 api/models/web.py | 4 +-
 api/models/workflow.py | 19 +-
 api/mypy.ini | 10 +
 api/poetry.lock | 230 ++++++++++++------
 api/pyproject.toml | 3 +
 api/schedule/clean_messages.py | 3 +-
 api/schedule/clean_unused_datasets_task.py | 6 +-
 api/schedule/create_tidb_serverless_task.py | 15 +-
 .../update_tidb_serverless_status_task.py | 13 +-
 api/services/account_service.py | 31 ++-
 .../advanced_prompt_template_service.py | 4 +
 api/services/agent_service.py | 13 +-
 api/services/annotation_service.py | 14 +-
 api/services/app_dsl_service.py | 9 +-
 api/services/app_generate_service.py | 4 +-
 api/services/app_service.py | 24 +-
 api/services/audio_service.py | 6 +
 api/services/auth/firecrawl/firecrawl.py | 4 +-
 api/services/auth/jina.py | 2 +-
 api/services/auth/jina/jina.py | 2 +-
 api/services/billing_service.py | 6 +-
 api/services/conversation_service.py | 3 +-
 api/services/dataset_service.py | 42 +++-
 api/services/enterprise/base.py | 4 +-
 .../entities/model_provider_entities.py | 8 +-
 api/services/external_knowledge_service.py | 39 +--
 api/services/file_service.py | 6 +-
 api/services/hit_testing_service.py | 17 +-
 api/services/knowledge_service.py | 2 +-
 api/services/message_service.py | 6 +-
 api/services/model_load_balancing_service.py | 49 ++--
 api/services/model_provider_service.py | 25 +-
 api/services/moderation_service.py | 4 +-
 api/services/ops_service.py | 26 +-
 .../buildin/buildin_retrieval.py | 8 +-
 .../recommend_app/remote/remote_retrieval.py | 6 +-
 api/services/recommended_app_service.py | 2 +-
 api/services/saved_message_service.py | 6 +
 api/services/tag_service.py | 4 +-
 .../tools/api_tools_manage_service.py | 33 +--
 .../tools/builtin_tools_manage_service.py | 8 +-
 api/services/tools/tools_transform_service.py | 34 ++-
 .../tools/workflow_tools_manage_service.py | 63 +++--
 api/services/web_conversation_service.py | 6 +
 api/services/website_service.py | 33 ++-
 api/services/workflow/workflow_converter.py | 18 +-
 api/services/workflow_run_service.py | 4 +-
 api/services/workflow_service.py | 6 +-
 api/services/workspace_service.py | 3 +-
 api/tasks/__init__.py | 0
 api/tasks/add_document_to_index_task.py | 2 +-
 .../add_annotation_to_index_task.py | 2 +-
 .../batch_import_annotations_task.py | 2 +-
 .../delete_annotation_index_task.py | 2 +-
 .../disable_annotation_reply_task.py | 2 +-
 .../enable_annotation_reply_task.py | 2 +-
 .../update_annotation_to_index_task.py | 2 +-
 .../batch_create_segment_to_index_task.py | 10 +-
 api/tasks/clean_dataset_task.py | 4 +-
 api/tasks/clean_document_task.py | 4 +-
 api/tasks/clean_notion_document_task.py | 2 +-
 api/tasks/create_segment_to_index_task.py | 2 +-
 api/tasks/deal_dataset_vector_index_task.py | 2 +-
 api/tasks/delete_segment_from_index_task.py | 2 +-
 api/tasks/disable_segment_from_index_task.py | 2 +-
 api/tasks/document_indexing_sync_task.py | 2 +-
 api/tasks/document_indexing_task.py | 2 +-
 api/tasks/document_indexing_update_task.py | 2 +-
 api/tasks/duplicate_document_indexing_task.py | 4 +-
 api/tasks/enable_segment_to_index_task.py | 2 +-
 api/tasks/mail_email_code_login.py | 2 +-
 api/tasks/mail_invite_member_task.py | 2 +-
 api/tasks/mail_reset_password_task.py | 2 +-
 api/tasks/ops_trace_task.py | 2 +-
 api/tasks/recover_document_indexing_task.py | 2 +-
 api/tasks/remove_app_and_related_data_task.py | 2 +-
 api/tasks/remove_document_from_index_task.py | 2 +-
 api/tasks/retry_document_indexing_task.py | 45 ++--
 .../sync_website_document_indexing_task.py | 42 ++--
 .../dependencies/test_dependencies_sorted.py | 4 +-
 .../controllers/test_controllers.py | 2 +-
 .../model_runtime/__mock/google.py | 4 +-
 .../model_runtime/__mock/huggingface.py | 2 +-
 .../model_runtime/__mock/huggingface_chat.py | 6 +-
 .../model_runtime/__mock/nomic_embeddings.py | 2 +-
 .../model_runtime/__mock/xinference.py | 4 +-
 .../model_runtime/tongyi/test_rerank.py | 2 +-
 .../tools/__mock_server/openapi_todo.py | 2 +-
 .../vdb/__mock/baiduvectordb.py | 10 +-
 .../vdb/__mock/tcvectordb.py | 12 +-
 .../integration_tests/vdb/__mock/vikingdb.py | 2 +-
 api/tests/unit_tests/oss/__mock/aliyun_oss.py | 4 +-
 .../unit_tests/oss/__mock/tencent_cos.py | 4 +-
 .../unit_tests/oss/__mock/volcengine_tos.py | 4 +-
 .../aliyun_oss/aliyun_oss/test_aliyun_oss.py | 2 +-
 .../oss/tencent_cos/test_tencent_cos.py | 2 +-
 .../oss/volcengine_tos/test_volcengine_tos.py | 2 +-
 .../unit_tests/utils/yaml/test_yaml_utils.py | 2 +-
 sdks/python-client/dify_client/client.py | 27 +-
 584 files changed, 3980 insertions(+), 2831 deletions(-)
 create mode 100644 api/extensions/__init__.py
 create mode 100644 api/factories/__init__.py
 create mode 100644 api/mypy.ini
 create mode 100644 api/tasks/__init__.py

diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml
index 2cd0b2a7d430de..fd98db24b961b4 100644
--- a/.github/workflows/api-tests.yml
+++ b/.github/workflows/api-tests.yml
@@ -56,6 +56,12 @@ jobs:
       - name: Run Tool
         run: poetry run -C api bash dev/pytest/pytest_tools.sh

+      - name: Run mypy
+        run: |
+          pushd api
+          poetry run python -m mypy --install-types --non-interactive .
+ popd + - name: Set up dotenvs run: | cp docker/.env.example docker/.env diff --git a/api/commands.py b/api/commands.py index bf013cc77e0627..ad7ad972f3fd01 100644 --- a/api/commands.py +++ b/api/commands.py @@ -159,8 +159,7 @@ def migrate_annotation_vector_database(): try: # get apps info apps = ( - db.session.query(App) - .filter(App.status == "normal") + App.query.filter(App.status == "normal") .order_by(App.created_at.desc()) .paginate(page=page, per_page=50) ) @@ -285,8 +284,7 @@ def migrate_knowledge_vector_database(): while True: try: datasets = ( - db.session.query(Dataset) - .filter(Dataset.indexing_technique == "high_quality") + Dataset.query.filter(Dataset.indexing_technique == "high_quality") .order_by(Dataset.created_at.desc()) .paginate(page=page, per_page=50) ) @@ -450,7 +448,8 @@ def convert_to_agent_apps(): if app_id not in proceeded_app_ids: proceeded_app_ids.append(app_id) app = db.session.query(App).filter(App.id == app_id).first() - apps.append(app) + if app is not None: + apps.append(app) if len(apps) == 0: break @@ -621,6 +620,10 @@ def fix_app_site_missing(): try: app = db.session.query(App).filter(App.id == app_id).first() + if not app: + print(f"App {app_id} not found") + continue + tenant = app.tenant if tenant: accounts = tenant.get_accounts() diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 73f8a95989baaf..74cdf944865796 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -239,7 +239,6 @@ class HttpConfig(BaseSettings): ) @computed_field - @property def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]: return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(",") @@ -250,7 +249,6 @@ def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]: ) @computed_field - @property def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]: return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",") @@ -715,27 +713,27 @@ class PositionConfig(BaseSettings): default="", ) - @computed_field + @property def POSITION_PROVIDER_PINS_LIST(self) -> list[str]: return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(",") if item.strip() != ""] - @computed_field + @property def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]: return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(",") if item.strip() != ""} - @computed_field + @property def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]: return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(",") if item.strip() != ""} - @computed_field + @property def POSITION_TOOL_PINS_LIST(self) -> list[str]: return [item.strip() for item in self.POSITION_TOOL_PINS.split(",") if item.strip() != ""] - @computed_field + @property def POSITION_TOOL_INCLUDES_SET(self) -> set[str]: return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(",") if item.strip() != ""} - @computed_field + @property def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]: return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""} diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 9265a48d9bc53c..f6a44eaa471e62 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -130,7 +130,6 @@ class DatabaseConfig(BaseSettings): ) @computed_field - @property def SQLALCHEMY_DATABASE_URI(self) -> str: db_extras = ( f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else self.DB_EXTRAS @@ -168,7 +167,6 @@ def SQLALCHEMY_DATABASE_URI(self) -> str: ) @computed_field - 
@property def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]: return { "pool_size": self.SQLALCHEMY_POOL_SIZE, @@ -206,7 +204,6 @@ class CeleryConfig(DatabaseConfig): ) @computed_field - @property def CELERY_RESULT_BACKEND(self) -> str | None: return ( "db+{}".format(self.SQLALCHEMY_DATABASE_URI) @@ -214,7 +211,6 @@ def CELERY_RESULT_BACKEND(self) -> str | None: else self.CELERY_BROKER_URL ) - @computed_field @property def BROKER_USE_SSL(self) -> bool: return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False diff --git a/api/configs/remote_settings_sources/apollo/client.py b/api/configs/remote_settings_sources/apollo/client.py index d1f6781ed370dd..03c64ea00f0185 100644 --- a/api/configs/remote_settings_sources/apollo/client.py +++ b/api/configs/remote_settings_sources/apollo/client.py @@ -4,6 +4,7 @@ import os import threading import time +from collections.abc import Mapping from pathlib import Path from .python_3x import http_request, makedirs_wrapper @@ -255,8 +256,8 @@ def _listener(self): logger.info("stopped, long_poll") # add the need for endorsement to the header - def _sign_headers(self, url): - headers = {} + def _sign_headers(self, url: str) -> Mapping[str, str]: + headers: dict[str, str] = {} if self.secret == "": return headers uri = url[len(self.config_url) : len(url)] diff --git a/api/constants/model_template.py b/api/constants/model_template.py index 7e1a196356c4e2..c26d8c018610d0 100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -1,8 +1,9 @@ import json +from collections.abc import Mapping from models.model import AppMode -default_app_templates = { +default_app_templates: Mapping[AppMode, Mapping] = { # workflow default mode AppMode.WORKFLOW: { "app": { diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py index 79869916eda062..b1ebc444a51868 100644 --- a/api/controllers/common/fields.py +++ b/api/controllers/common/fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore parameters__system_parameters = { "image_file_size_limit": fields.Integer, diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index f46d5b6b138d59..cb6b0d097b1fc9 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -3,6 +3,25 @@ from libs.external_api import ExternalApi from .app.app_import import AppImportApi, AppImportConfirmApi +from .explore.audio import ChatAudioApi, ChatTextApi +from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi +from .explore.conversation import ( + ConversationApi, + ConversationListApi, + ConversationPinApi, + ConversationRenameApi, + ConversationUnPinApi, +) +from .explore.message import ( + MessageFeedbackApi, + MessageListApi, + MessageMoreLikeThisApi, + MessageSuggestedQuestionApi, +) +from .explore.workflow import ( + InstalledAppWorkflowRunApi, + InstalledAppWorkflowTaskStopApi, +) from .files import FileApi, FilePreviewApi, FileSupportTypeApi from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi @@ -66,15 +85,81 @@ # Import explore controllers from .explore import ( - audio, - completion, - conversation, installed_app, - message, parameter, recommended_app, saved_message, - workflow, +) + +# Explore Audio +api.add_resource(ChatAudioApi, "/installed-apps//audio-to-text", endpoint="installed_app_audio") +api.add_resource(ChatTextApi, "/installed-apps//text-to-audio", endpoint="installed_app_text") + +# 
Explore Completion +api.add_resource( + CompletionApi, "/installed-apps//completion-messages", endpoint="installed_app_completion" +) +api.add_resource( + CompletionStopApi, + "/installed-apps//completion-messages//stop", + endpoint="installed_app_stop_completion", +) +api.add_resource( + ChatApi, "/installed-apps//chat-messages", endpoint="installed_app_chat_completion" +) +api.add_resource( + ChatStopApi, + "/installed-apps//chat-messages//stop", + endpoint="installed_app_stop_chat_completion", +) + +# Explore Conversation +api.add_resource( + ConversationRenameApi, + "/installed-apps//conversations//name", + endpoint="installed_app_conversation_rename", +) +api.add_resource( + ConversationListApi, "/installed-apps//conversations", endpoint="installed_app_conversations" +) +api.add_resource( + ConversationApi, + "/installed-apps//conversations/", + endpoint="installed_app_conversation", +) +api.add_resource( + ConversationPinApi, + "/installed-apps//conversations//pin", + endpoint="installed_app_conversation_pin", +) +api.add_resource( + ConversationUnPinApi, + "/installed-apps//conversations//unpin", + endpoint="installed_app_conversation_unpin", +) + + +# Explore Message +api.add_resource(MessageListApi, "/installed-apps//messages", endpoint="installed_app_messages") +api.add_resource( + MessageFeedbackApi, + "/installed-apps//messages//feedbacks", + endpoint="installed_app_message_feedback", +) +api.add_resource( + MessageMoreLikeThisApi, + "/installed-apps//messages//more-like-this", + endpoint="installed_app_more_like_this", +) +api.add_resource( + MessageSuggestedQuestionApi, + "/installed-apps//messages//suggested-questions", + endpoint="installed_app_suggested_question", +) +# Explore Workflow +api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps//workflows/run") +api.add_resource( + InstalledAppWorkflowTaskStopApi, "/installed-apps//workflows/tasks//stop" ) # Import tag controllers diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 8c0bf8710d3964..52e0bb6c56bdc2 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -1,7 +1,7 @@ from functools import wraps from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import NotFound, Unauthorized from configs import dify_config diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 953770868904d3..ca8ddc32094ac5 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,5 +1,7 @@ -import flask_restful -from flask_login import current_user +from typing import Any + +import flask_restful # type: ignore +from flask_login import current_user # type: ignore from flask_restful import Resource, fields, marshal_with from werkzeug.exceptions import Forbidden @@ -35,14 +37,15 @@ def _get_resource(resource_id, tenant_id, resource_model): class BaseApiKeyListResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] - resource_type = None - resource_model = None - resource_id_field = None - token_prefix = None + resource_type: str | None = None + resource_model: Any = None + resource_id_field: str | None = None + token_prefix: str | None = None max_keys = 10 @marshal_with(api_key_list) def get(self, resource_id): + assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) _get_resource(resource_id, 
current_user.current_tenant_id, self.resource_model) keys = ( @@ -54,6 +57,7 @@ def get(self, resource_id): @marshal_with(api_key_fields) def post(self, resource_id): + assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) if not current_user.is_editor: @@ -86,11 +90,12 @@ def post(self, resource_id): class BaseApiKeyResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] - resource_type = None - resource_model = None - resource_id_field = None + resource_type: str | None = None + resource_model: Any = None + resource_id_field: str | None = None def delete(self, resource_id, api_key_id): + assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) api_key_id = str(api_key_id) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index c228743fa53591..8d0c5b84af5e37 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.wraps import account_initialization_required, setup_required diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py index d4334158945e16..920cae0d859354 100644 --- a/api/controllers/console/app/agent.py +++ b/api/controllers/console/app/agent.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index fd05cbc19bf04f..24f1020c18ec37 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,6 +1,6 @@ from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal, marshal_with, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api @@ -110,7 +110,7 @@ def get(self, app_id): page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) - keyword = request.args.get("keyword", default=None, type=str) + keyword = request.args.get("keyword", default="", type=str) app_id = str(app_id) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index da72b704c71bd7..9cd56cef0b7039 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,8 +1,8 @@ import uuid from typing import cast -from flask_login import current_user -from flask_restful import Resource, inputs, marshal, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, inputs, marshal, marshal_with, reqparse # type: ignore from sqlalchemy import select from sqlalchemy.orm 
import Session from werkzeug.exceptions import BadRequest, Forbidden, abort diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index 244dcd75de29bc..7e2888d71c79c8 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -1,7 +1,7 @@ from typing import cast -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 695b8890e30f5c..9d26af276d2fc3 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -1,7 +1,7 @@ import logging from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import InternalServerError import services diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 9896fcaab8ad36..dba41e5c47d24f 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -1,7 +1,7 @@ import logging -import flask_login -from flask_restful import Resource, reqparse +import flask_login # type: ignore +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index a25004be4d16ae..8827f129d99317 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -1,9 +1,9 @@ from datetime import UTC, datetime -import pytz -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +import pytz # pip install pytz +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from sqlalchemy import func, or_ from sqlalchemy.orm import joinedload from werkzeug.exceptions import Forbidden, NotFound @@ -77,8 +77,9 @@ def get(self, app_model): query = query.where(Conversation.created_at < end_datetime_utc) + # FIXME, the type ignore in this file if args["annotation_status"] == "annotated": - query = query.options(joinedload(Conversation.message_annotations)).join( + query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id ) elif args["annotation_status"] == "not_annotated": @@ -222,7 +223,7 @@ def get(self, app_model): query = query.where(Conversation.created_at <= end_datetime_utc) if args["annotation_status"] == "annotated": - query = query.options(joinedload(Conversation.message_annotations)).join( + query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id ) elif args["annotation_status"] == "not_annotated": @@ -234,7 +235,7 @@ def get(self, app_model): if args["message_count_gte"] and args["message_count_gte"] >= 1: query = ( - query.options(joinedload(Conversation.messages)) + 
query.options(joinedload(Conversation.messages)) # type: ignore .join(Message, Message.conversation_id == Conversation.id) .group_by(Conversation.id) .having(func.count(Message.id) >= args["message_count_gte"]) diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index d49f433ba1f575..c0a20b7160e719 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, marshal_with, reqparse +from flask_restful import Resource, marshal_with, reqparse # type: ignore from sqlalchemy import select from sqlalchemy.orm import Session diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 9c3cbe4e3e049e..8518d34a8e5af2 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,7 +1,7 @@ import os -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.app.error import ( diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index b7a4c31a156b80..b5828b6b4b08c4 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -1,8 +1,8 @@ import logging -from flask_login import current_user -from flask_restful import Resource, fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_login import current_user # type: ignore +from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from controllers.console import api diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index a46bc6a8a97606..8ecc8a9db5738d 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -1,8 +1,9 @@ import json +from typing import cast from flask import request -from flask_login import current_user -from flask_restful import Resource +from flask_login import current_user # type: ignore +from flask_restful import Resource # type: ignore from controllers.console import api from controllers.console.app.wraps import get_app_model @@ -26,7 +27,9 @@ def post(self, app_model): """Modify app model config""" # validate config model_configuration = AppModelConfigService.validate_configuration( - tenant_id=current_user.current_tenant_id, config=request.json, app_mode=AppMode.value_of(app_model.mode) + tenant_id=current_user.current_tenant_id, + config=cast(dict, request.json), + app_mode=AppMode.value_of(app_model.mode), ) new_app_model_config = AppModelConfig( @@ -38,9 +41,11 @@ def post(self, app_model): if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: # get original app model config - original_app_model_config: AppModelConfig = ( + original_app_model_config = ( db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() ) + if original_app_model_config is None: + raise ValueError("Original app model config not found") agent_mode = original_app_model_config.agent_mode_dict # decrypt agent tool parameters if it's secret-input parameter_map = {} diff 
--git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index 3f10215e702ac1..dd25af8ebf9312 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import BadRequest from controllers.console import api diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 407f6898199bae..db29b95c4140ff 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,7 +1,7 @@ from datetime import UTC, datetime -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore from werkzeug.exceptions import Forbidden, NotFound from constants.languages import supported_language @@ -50,7 +50,7 @@ def post(self, app_model): if not current_user.is_editor: raise Forbidden() - site = db.session.query(Site).filter(Site.app_id == app_model.id).one_or_404() + site = Site.query.filter(Site.app_id == app_model.id).one_or_404() for attr_name in [ "title", diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index db5e2824095ca0..3b21108ceaf76b 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -3,8 +3,8 @@ import pytz from flask import jsonify -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index f228c3ec4a0e07..26a3a022d401a4 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -2,7 +2,7 @@ import logging from flask import abort, request -from flask_restful import Resource, marshal_with, reqparse +from flask_restful import Resource, marshal_with, reqparse # type: ignore from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 2940556f84ef4e..882c53e4fb9972 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -1,5 +1,5 @@ -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import Resource, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 08ab61bbb9c97e..25a99c1e1594ae 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -1,5 +1,5 @@ -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import Resource, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from controllers.console import api from 
controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index 6c7c73707bb204..097bf7d1888cf5 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -3,8 +3,8 @@ import pytz from flask import jsonify -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index 63edb83079041e..9ad8c158473df9 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -8,7 +8,7 @@ from models import App, AppMode -def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None): +def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None): def decorator(view_func): @wraps(view_func) def decorated_view(*args, **kwargs): diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index d2aa7c903b046c..c56f551d49be8b 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -1,14 +1,14 @@ import datetime from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from constants.languages import supported_language from controllers.console import api from controllers.console.error import AlreadyActivateError from extensions.ext_database import db from libs.helper import StrLen, email, extract_remote_ip, timezone -from models.account import AccountStatus, Tenant +from models.account import AccountStatus from services.account_service import AccountService, RegisterService @@ -27,7 +27,7 @@ def get(self): invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token) if invitation: data = invitation.get("data", {}) - tenant: Tenant = invitation.get("tenant", None) + tenant = invitation.get("tenant", None) workspace_name = tenant.name if tenant else None workspace_id = tenant.id if tenant else None invitee_email = data.get("email") if data else None diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index 465c44e9b6dc2f..ea00c2b8c2272c 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -1,5 +1,5 @@ -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index faca67bb177f10..e911c9a5e5b5ea 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -2,8 +2,8 @@ import requests from flask import current_app, redirect, request -from flask_login import current_user -from flask_restful import Resource +from flask_login import current_user # type: ignore +from flask_restful import Resource # type: 
ignore from werkzeug.exceptions import Forbidden from configs import dify_config @@ -17,8 +17,8 @@ def get_oauth_providers(): with current_app.app_context(): notion_oauth = NotionOAuth( - client_id=dify_config.NOTION_CLIENT_ID, - client_secret=dify_config.NOTION_CLIENT_SECRET, + client_id=dify_config.NOTION_CLIENT_ID or "", + client_secret=dify_config.NOTION_CLIENT_SECRET or "", redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/data-source/callback/notion", ) diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index fb32bb2b60286d..140b9e145fa9cd 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -2,7 +2,7 @@ import secrets from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from constants.languages import languages from controllers.console import api @@ -122,8 +122,8 @@ def post(self): else: try: account = AccountService.create_account_and_tenant( - email=reset_data.get("email"), - name=reset_data.get("email"), + email=reset_data.get("email", ""), + name=reset_data.get("email", ""), password=password_confirm, interface_language=languages[0], ) diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index f4463ce9cb3f30..78a80fc8d7e075 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -1,8 +1,8 @@ from typing import cast -import flask_login +import flask_login # type: ignore from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore import services from constants.languages import languages diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index b9188aa0798ea2..333b24142727f0 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -4,7 +4,7 @@ import requests from flask import current_app, redirect, request -from flask_restful import Resource +from flask_restful import Resource # type: ignore from werkzeug.exceptions import Unauthorized from configs import dify_config @@ -77,7 +77,8 @@ def get(self, provider: str): token = oauth_provider.get_access_token(code) user_info = oauth_provider.get_user_info(token) except requests.exceptions.RequestException as e: - logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}") + error_text = e.response.text if e.response else str(e) + logging.exception(f"An error occurred during the OAuth process with {provider}: {error_text}") return {"error": "OAuth process failed"}, 400 if invite_token and RegisterService.is_valid_invite_token(invite_token): @@ -129,7 +130,7 @@ def get(self, provider: str): def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: - account = Account.get_by_openid(provider, user_info.id) + account: Optional[Account] = Account.get_by_openid(provider, user_info.id) if not account: account = Account.query.filter_by(email=user_info.email).first() diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 4b0c82ae6c90c2..fd7b7bd8cb3ddd 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -1,5 +1,5 @@ -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login 
import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 278295ca39a696..d7c431b95080da 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -2,8 +2,8 @@ import json from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore from werkzeug.exceptions import NotFound from controllers.console import api diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 95d4013e3a8f27..f3c3736b25acc5 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -1,7 +1,7 @@ -import flask_restful +import flask_restful # type: ignore from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal, marshal_with, reqparse # type: ignore from werkzeug.exceptions import Forbidden, NotFound import services diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index ad4768f51959ac..ca41e504be7eda 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1,12 +1,13 @@ import logging from argparse import ArgumentTypeError from datetime import UTC, datetime +from typing import cast from flask import request -from flask_login import current_user -from flask_restful import Resource, fields, marshal, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, fields, marshal, marshal_with, reqparse # type: ignore from sqlalchemy import asc, desc -from transformers.hf_argparser import string_to_bool +from transformers.hf_argparser import string_to_bool # type: ignore from werkzeug.exceptions import Forbidden, NotFound import services @@ -733,8 +734,7 @@ def put(self, dataset_id, document_id): if not isinstance(doc_metadata, dict): raise ValueError("doc_metadata must be a dictionary.") - - metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type] + metadata_schema: dict = cast(dict, DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]) document.doc_metadata = {} if doc_type == "others": diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 6f7ef86d2c3fd3..2d5933ca23609a 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -3,8 +3,8 @@ import pandas as pd from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal, reqparse # type: ignore from werkzeug.exceptions import Forbidden, NotFound import services
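
The recurring change in these controller files is the per-import `# type: ignore`. flask_login and flask_restful ship without type information, so a strict mypy run flags every import from them; the inline comment suppresses the error for that import only. A minimal sketch of what the suppression does and does not buy, using a hypothetical `PingApi` resource that is not part of this patch:

```python
from flask_login import current_user  # type: ignore
from flask_restful import Resource  # type: ignore


class PingApi(Resource):  # Resource is Any to mypy, so the subclass body is unchecked
    def get(self):
        # current_user is also Any; mypy cannot verify that .id exists
        return {"result": "pong", "user_id": current_user.id}
```

A single `ignore_missing_imports = True` entry under a `[mypy-flask_restful.*]` section of the mypy config would silence the same error project-wide; the inline form used here keeps the suppression visible at each call site.

diff --git a/api/controllers/console/datasets/external.py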
b/api/controllers/console/datasets/external.py index bc6e3687c1c99d..48f360dcd179bc 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -1,6 +1,6 @@ from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal, reqparse # type: ignore from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 495f511275b4b9..18b746f547287c 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restful import Resource # type: ignore from controllers.console import api from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index 3b4c07686361d0..bd944602c147cb 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -1,7 +1,7 @@ import logging -from flask_login import current_user -from flask_restful import marshal, reqparse +from flask_login import current_user # type: ignore +from flask_restful import marshal, reqparse # type: ignore from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services.dataset_service diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py index 9127c8af455f6c..da995537e74753 100644 --- a/api/controllers/console/datasets/website.py +++ b/api/controllers/console/datasets/website.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.datasets.error import WebsiteCrawlError diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 9690677f61b1c2..c7f9fec326945f 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -4,7 +4,6 @@ from werkzeug.exceptions import InternalServerError import services -from controllers.console import api from controllers.console.app.error import ( AppUnavailableError, AudioTooLargeError, @@ -67,7 +66,7 @@ def post(self, installed_app): class ChatTextApi(InstalledAppResource): def post(self, installed_app): - from flask_restful import reqparse + from flask_restful import reqparse # type: ignore app_model = installed_app.app try: @@ -118,9 +117,3 @@ def post(self, installed_app): except Exception as e: logging.exception("internal server error.") raise InternalServerError() - - -api.add_resource(ChatAudioApi, "/installed-apps/<uuid:installed_app_id>/audio-to-text", endpoint="installed_app_audio") -api.add_resource(ChatTextApi, "/installed-apps/<uuid:installed_app_id>/text-to-audio", endpoint="installed_app_text") -# api.add_resource(ChatTextApiWithMessageId, '/installed-apps/<uuid:installed_app_id>/text-to-audio/message-id', -# endpoint='installed_app_text_with_message_id')
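
The `api.add_resource` registrations deleted from this file and from the other explore controllers below used Flask URL converters such as `<uuid:installed_app_id>`, which validate the path segment and hand the view a `uuid.UUID` rather than a raw string. A standalone sketch of how such a registration behaves, with an illustrative stub resource rather than the real controller:

```python
import uuid

from flask import Flask
from flask_restful import Api, Resource  # type: ignore

app = Flask(__name__)
api = Api(app)


class ChatAudioStubApi(Resource):
    def post(self, installed_app_id: uuid.UUID):
        # the <uuid:...> converter has already parsed and validated the segment
        return {"installed_app_id": str(installed_app_id)}


api.add_resource(
    ChatAudioStubApi,
    "/installed-apps/<uuid:installed_app_id>/audio-to-text",
    endpoint="installed_app_audio_stub",
)
```

diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 85c43f8101028e..3331ded70f6620 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -1,12 +1,11 @@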
import logging from datetime import UTC, datetime -from flask_login import current_user -from flask_restful import reqparse +from flask_login import current_user # type: ignore +from flask_restful import reqparse # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services -from controllers.console import api from controllers.console.app.error import ( AppUnavailableError, CompletionRequestError, @@ -147,21 +146,3 @@ def post(self, installed_app, task_id): AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {"result": "success"}, 200 - - -api.add_resource( - CompletionApi, "/installed-apps/<uuid:installed_app_id>/completion-messages", endpoint="installed_app_completion" -) -api.add_resource( - CompletionStopApi, - "/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop", - endpoint="installed_app_stop_completion", -) -api.add_resource( - ChatApi, "/installed-apps/<uuid:installed_app_id>/chat-messages", endpoint="installed_app_chat_completion" -) -api.add_resource( - ChatStopApi, - "/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop", - endpoint="installed_app_stop_chat_completion", -) diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 5e7a3da017edd7..91916cbc1ed85f 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,10 +1,9 @@ -from flask_login import current_user -from flask_restful import marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_login import current_user # type: ignore +from flask_restful import marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound -from controllers.console import api from controllers.console.explore.error import NotChatAppError from controllers.console.explore.wraps import InstalledAppResource from core.app.entities.app_invoke_entities import InvokeFrom @@ -118,28 +117,3 @@ def patch(self, installed_app, c_id): WebConversationService.unpin(app_model, conversation_id, current_user) return {"result": "success"} - - -api.add_resource( - ConversationRenameApi, - "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name", - endpoint="installed_app_conversation_rename", -) -api.add_resource( - ConversationListApi, "/installed-apps/<uuid:installed_app_id>/conversations", endpoint="installed_app_conversations" -) -api.add_resource( - ConversationApi, - "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>", - endpoint="installed_app_conversation", -) -api.add_resource( - ConversationPinApi, - "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin", - endpoint="installed_app_conversation_pin", -) -api.add_resource( - ConversationUnPinApi, - "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin", - endpoint="installed_app_conversation_unpin", -) diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 3de179164de91d..86550b2bdf44b9 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -1,8 +1,9 @@ from datetime import UTC, datetime +from typing import Any from flask import request -from flask_login import current_user -from flask_restful import Resource, inputs, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore from sqlalchemy import and_ from werkzeug.exceptions import BadRequest, Forbidden, NotFound @@ -34,7 +35,7 @@ def get(self): installed_apps =
db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all() current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) - installed_apps = [ + installed_app_list: list[dict[str, Any]] = [ { "id": installed_app.id, "app": installed_app.app, @@ -47,7 +48,7 @@ def get(self): for installed_app in installed_apps if installed_app.app is not None ] - installed_apps.sort( + installed_app_list.sort( key=lambda app: ( -app["is_pinned"], app["last_used_at"] is None, @@ -55,7 +56,7 @@ def get(self): ) ) - return {"installed_apps": installed_apps} + return {"installed_apps": installed_app_list} @login_required @account_initialization_required diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 4e11d8005f138b..c3488de29929c9 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -1,12 +1,11 @@ import logging -from flask_login import current_user -from flask_restful import marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_login import current_user # type: ignore +from flask_restful import marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services -from controllers.console import api from controllers.console.app.error import ( AppMoreLikeThisDisabledError, CompletionRequestError, @@ -153,21 +152,3 @@ def get(self, installed_app, message_id): raise InternalServerError() return {"data": questions} - - -api.add_resource(MessageListApi, "/installed-apps/<uuid:installed_app_id>/messages", endpoint="installed_app_messages") -api.add_resource( - MessageFeedbackApi, - "/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks", - endpoint="installed_app_message_feedback", -) -api.add_resource( - MessageMoreLikeThisApi, - "/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this", - endpoint="installed_app_more_like_this", -) -api.add_resource( - MessageSuggestedQuestionApi, - "/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions", - endpoint="installed_app_suggested_question", -) diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index fee52248a698e0..5bc74d16e784af 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,4 +1,4 @@ -from flask_restful import marshal_with +from flask_restful import marshal_with # type: ignore from controllers.common import fields from controllers.common import helpers as controller_helpers diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index ce85f495aacd50..be6b1f5d215fb4 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,5 +1,5 @@ -from flask_login import current_user -from flask_restful import Resource, fields, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore from constants.languages import languages from controllers.console import api
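
The `installed_apps` to `installed_app_list` rename above is a pattern this patch applies repeatedly (see also `signature` to `signature_base64` in the inner-API wraps further down): mypy rejects rebinding a name to a value of a different type, so when a list of ORM rows is transformed into a list of dicts, the result gets a fresh, explicitly annotated name. A contrived sketch of the error and the fix, with made-up names:

```python
from typing import Any

rows: list[int] = [3, 1, 2]  # stand-in for the ORM query result

# rows = [{"value": r} for r in rows]  # mypy: incompatible types in assignment

# the fix used throughout the patch: a new, annotated name for the new shape
row_dicts: list[dict[str, Any]] = [{"value": r} for r in rows]
row_dicts.sort(key=lambda d: d["value"])
print(row_dicts)
```

diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 0fc963747981e1..9f0c4966457186 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -1,6 +1,6 @@ -from flask_login import current_user -from flask_restful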
import fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_login import current_user # type: ignore +from flask_restful import fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import NotFound from controllers.console import api diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 45f99b1db9fa9e..76d30299cd84a7 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -1,9 +1,8 @@ import logging -from flask_restful import reqparse +from flask_restful import reqparse # type: ignore from werkzeug.exceptions import InternalServerError -from controllers.console import api from controllers.console.app.error import ( CompletionRequestError, ProviderModelCurrentlyNotSupportError, @@ -73,9 +72,3 @@ def post(self, installed_app: InstalledApp, task_id: str): AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {"result": "success"} - - -api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps/<uuid:installed_app_id>/workflows/run") -api.add_resource( - InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop" -) diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index 49ea81a8a0f86d..b7ba81fba20f79 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -1,7 +1,7 @@ from functools import wraps -from flask_login import current_user -from flask_restful import Resource +from flask_login import current_user # type: ignore +from flask_restful import Resource # type: ignore from werkzeug.exceptions import NotFound from controllers.console.wraps import account_initialization_required diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 4ac0aa497e0866..ed6cedb220cf4b 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -1,5 +1,5 @@ -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore from constants import HIDDEN_VALUE from controllers.console import api diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 70ab4ff865cb48..da1171412fdb2d 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,5 +1,5 @@ -from flask_login import current_user -from flask_restful import Resource +from flask_login import current_user # type: ignore +from flask_restful import Resource # type: ignore from libs.login import login_required from services.feature_service import FeatureService
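
The files.py hunk that follows narrows a free-form form field into `Literal["datasets"] | None` before passing it on, presumably because the downstream upload service declares the literal rather than an arbitrary string. A standalone sketch of the narrowing pattern under that assumption (function names are illustrative):

```python
from typing import Literal, Optional


def upload(source: Optional[Literal["datasets"]]) -> str:
    return f"uploaded (source={source})"


def handle(form_value: Optional[str]) -> str:
    # the equality check is what lets mypy accept the assignment: each branch
    # of the conditional expression is a member of the annotated union
    source: Optional[Literal["datasets"]] = "datasets" if form_value == "datasets" else None
    return upload(source)


print(handle("datasets"))  # uploaded (source=datasets)
print(handle("web"))  # uploaded (source=None)
```

diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index ca32d29efaa474..8cf754bbd686fd 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -1,6 +1,8 @@ +from typing import Literal + from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal_with +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with # type: ignore from werkzeug.exceptions import Forbidden import services @@ -48,7 +50,8 @@ def get(self): @cloud_edition_billing_resource_check("documents") def post(self): file = request.files["file"] - source =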
request.form.get("source") + source_str = request.form.get("source") + source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None if "file" not in request.files: raise NoFileUploadedError() diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index ae759bb752a30e..d9ae5cf29fc626 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -1,7 +1,7 @@ import os from flask import session -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from configs import dify_config from libs.helper import StrLen diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py index cd28cc946ee288..2a116112a3227c 100644 --- a/api/controllers/console/ping.py +++ b/api/controllers/console/ping.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restful import Resource # type: ignore from controllers.console import api diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index b8cf019e4f068d..30afc930a8e980 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -2,8 +2,8 @@ from typing import cast import httpx -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore import services from controllers.common import helpers diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index e0b728d97739d3..aba6f0aad9ee54 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -1,5 +1,5 @@ from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from configs import dify_config from libs.helper import StrLen, email, extract_remote_ip diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index ccd3293a6266fc..da83f64019161b 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -1,6 +1,6 @@ from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api @@ -23,7 +23,7 @@ class TagListApi(Resource): @account_initialization_required @marshal_with(tag_fields) def get(self): - tag_type = request.args.get("type", type=str) + tag_type = request.args.get("type", type=str, default="") keyword = request.args.get("keyword", default=None, type=str) tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword) diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 7dea8e554edd7a..7773c99944e42c 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -2,7 +2,7 @@ import logging import requests -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from packaging import version from configs import dify_config diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index f704783cfff56b..96ed4b7a570256 100644 --- 
a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -2,8 +2,8 @@ import pytz from flask import request -from flask_login import current_user -from flask_restful import Resource, fields, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore from configs import dify_config from constants.languages import supported_language diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index d2b2092b75a9ff..7009343d9923da 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api @@ -37,7 +37,7 @@ def post(self, provider: str): model_load_balancing_service = ModelLoadBalancingService() result = True - error = None + error = "" try: model_load_balancing_service.validate_load_balancing_credentials( @@ -86,7 +86,7 @@ def post(self, provider: str, config_id: str): model_load_balancing_service = ModelLoadBalancingService() result = True - error = None + error = "" try: model_load_balancing_service.validate_load_balancing_credentials( diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 38ed2316a58935..1afb41ea87660c 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,7 +1,7 @@ from urllib import parse -from flask_login import current_user -from flask_restful import Resource, abort, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, abort, marshal_with, reqparse # type: ignore import services from configs import dify_config @@ -89,19 +89,19 @@ class MemberCancelInviteApi(Resource): @account_initialization_required def delete(self, member_id): member = db.session.query(Account).filter(Account.id == str(member_id)).first() - if not member: + if member is None: abort(404) - - try: - TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user) - except services.errors.account.CannotOperateSelfError as e: - return {"code": "cannot-operate-self", "message": str(e)}, 400 - except services.errors.account.NoPermissionError as e: - return {"code": "forbidden", "message": str(e)}, 403 - except services.errors.account.MemberNotInTenantError as e: - return {"code": "member-not-found", "message": str(e)}, 404 - except Exception as e: - raise ValueError(str(e)) + else: + try: + TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user) + except services.errors.account.CannotOperateSelfError as e: + return {"code": "cannot-operate-self", "message": str(e)}, 400 + except services.errors.account.NoPermissionError as e: + return {"code": "forbidden", "message": str(e)}, 403 + except services.errors.account.MemberNotInTenantError as e: + return {"code": "member-not-found", "message": str(e)}, 404 + except Exception as e: + raise ValueError(str(e)) return {"result": "success"}, 204
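
Both members.py hunks address the same mypy complaint: `db.session.get` and `.first()` return `Optional[Account]`, and flask_restful's `abort` is untyped, so mypy cannot see that execution stops on the None branch. The hunk above restructures the happy path into an explicit `else:`; the hunk below keeps the `abort` and adds an `assert` to narrow the type. A sketch of why the assert is needed, with simplified stand-ins for the real functions:

```python
from typing import NoReturn, Optional


def abort_untyped(code):
    # mimics flask_restful.abort: without a NoReturn annotation,
    # mypy assumes this call can fall through
    raise RuntimeError(code)


def abort_typed(code: int) -> NoReturn:
    # with NoReturn, mypy narrows after the call and no assert is needed
    raise RuntimeError(code)


def find_member(member_id: str) -> Optional[str]:
    return member_id or None


def update_role(member_id: str) -> str:
    member = find_member(member_id)
    if member is None:
        abort_untyped(404)
    assert member is not None  # mypy still sees Optional[str] here, hence the assert
    return member.upper()
```

@@ -122,10 +122,11 @@ def put(self, member_id): return {"code": "invalid-role", "message": "Invalid role"}, 400 member = db.session.get(Account, str(member_id)) - if not member: + if member is None: abort(404) try: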
+ assert member is not None, "Member not found" TenantService.update_member_role(current_user.current_tenant, member, new_role, current_user) except Exception as e: raise ValueError(str(e)) diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 0e54126063be75..2d11295b0fdf61 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -1,8 +1,8 @@ import io from flask import send_file -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api @@ -66,7 +66,7 @@ def post(self, provider: str): model_provider_service = ModelProviderService() result = True - error = None + error = "" try: model_provider_service.provider_credentials_validate( @@ -132,7 +132,8 @@ def get(self, provider: str, icon_type: str, lang: str): icon_type=icon_type, lang=lang, ) - + if icon is None: + raise ValueError(f"icon not found for provider {provider}, icon_type {icon_type}, lang {lang}") return send_file(io.BytesIO(icon), mimetype=mimetype) diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index f804285f008120..618262e502ab33 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -1,7 +1,7 @@ import logging -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api @@ -308,7 +308,7 @@ def post(self, provider: str): model_provider_service = ModelProviderService() result = True - error = None + error = "" try: model_provider_service.model_credentials_validate( diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 9e62a546997b12..964f3862291a2e 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,8 +1,8 @@ import io from flask import send_file -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 76d76f6b58fc3c..0f99bf62e3c251 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -1,8 +1,8 @@ import logging from flask import request -from flask_login import current_user -from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse # type: ignore from werkzeug.exceptions import Unauthorized import services @@ -82,11 +82,7 @@ def get(self): parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - tenants = ( - db.session.query(Tenant) - .order_by(Tenant.created_at.desc()) - 
.paginate(page=args["page"], per_page=args["limit"]) - ) + tenants = Tenant.query.order_by(Tenant.created_at.desc()).paginate(page=args["page"], per_page=args["limit"]) has_more = False if len(tenants.items) == args["limit"]: @@ -151,6 +147,8 @@ def post(self): raise AccountNotLinkTenantError("Account not link tenant") new_tenant = db.session.query(Tenant).get(args["tenant_id"]) # Get new tenant + if new_tenant is None: + raise ValueError("Tenant not found") return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)} @@ -166,7 +164,7 @@ def post(self): parser.add_argument("replace_webapp_logo", type=str, location="json") args = parser.parse_args() - tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404() + tenant = Tenant.query.filter(Tenant.id == current_user.current_tenant_id).one_or_404() custom_config_dict = { "remove_webapp_brand": args["remove_webapp_brand"], diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index d0df296c240686..111db7ccf2da04 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -3,7 +3,7 @@ from functools import wraps from flask import abort, request -from flask_login import current_user +from flask_login import current_user # type: ignore from configs import dify_config from controllers.console.workspace.error import AccountNotInitializedError @@ -121,8 +121,8 @@ def decorated(*args, **kwargs): utm_info = request.cookies.get("utm_info") if utm_info: - utm_info = json.loads(utm_info) - OperationService.record_utm(current_user.current_tenant_id, utm_info) + utm_info_dict: dict = json.loads(utm_info) + OperationService.record_utm(current_user.current_tenant_id, utm_info_dict) except Exception as e: pass return view(*args, **kwargs) diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py index 6b3ac93cdf3d8f..2357288a50ae36 100644 --- a/api/controllers/files/image_preview.py +++ b/api/controllers/files/image_preview.py @@ -1,5 +1,5 @@ from flask import Response, request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import NotFound import services diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index a298701a2f8b11..cfcce8124761f5 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -1,5 +1,5 @@ from flask import Response -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import Forbidden, NotFound from controllers.files import api diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index 99d32af593991f..d7346b13b10a90 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from controllers.console.wraps import setup_required from controllers.inner_api import api diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index 51ffe683ff40ad..d4587235f6aef8 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -45,14 +45,14 @@ def decorated(*args, **kwargs): if " " in user_id: user_id = user_id.split(" ")[1] - inner_api_key = 
request.headers.get("X-Inner-Api-Key") + inner_api_key = request.headers.get("X-Inner-Api-Key", "") data_to_sign = f"DIFY {user_id}" signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1) - signature = b64encode(signature.digest()).decode("utf-8") + signature_base64 = b64encode(signature.digest()).decode("utf-8") - if signature != token: + if signature_base64 != token: return view(*args, **kwargs) kwargs["user"] = db.session.query(EndUser).filter(EndUser.id == user_id).first() diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index ecff7d07e974d9..8388e2045dd34f 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, marshal_with +from flask_restful import Resource, marshal_with # type: ignore from controllers.common import fields from controllers.common import helpers as controller_helpers diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 5db41636471220..e6bcc0bfd25355 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -1,7 +1,7 @@ import logging from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import InternalServerError import services @@ -83,7 +83,7 @@ def post(self, app_model: App, end_user: EndUser): and app_model.workflow and app_model.workflow.features_dict ): - text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {}) voice = args.get("voice") or text_to_speech.get("voice") else: try: diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 8d8e356c4cb940..1be54b386bfe8c 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -1,6 +1,6 @@ import logging -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 32940cbc29f355..334f2c56206794 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,5 +1,5 @@ -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import Resource, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index b0fd8e65ef97df..27b21b9f505633 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -1,5 +1,5 @@ from flask import request -from flask_restful import Resource, marshal_with +from flask_restful import Resource, marshal_with # type: ignore import services from controllers.common.errors import FilenameNotExistsError diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 599401bc6f1821..522c7509b9849d 100644 --- a/api/controllers/service_api/app/message.py +++ 
b/api/controllers/service_api/app/message.py @@ -1,7 +1,7 @@ import logging -from flask_restful import Resource, fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import BadRequest, InternalServerError, NotFound import services diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 96d1337632826a..c7dd4de3452970 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -1,7 +1,7 @@ import logging -from flask_restful import Resource, fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import InternalServerError from controllers.service_api import api diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 799fccc228e21d..d6a3beb6b80b9d 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -1,5 +1,5 @@ from flask import request -from flask_restful import marshal, reqparse +from flask_restful import marshal, reqparse # type: ignore from werkzeug.exceptions import NotFound import services.dataset_service diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 5c3fc7b241175a..34afe2837f4ca5 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -1,7 +1,7 @@ import json from flask import request -from flask_restful import marshal, reqparse +from flask_restful import marshal, reqparse # type: ignore from sqlalchemy import desc from werkzeug.exceptions import NotFound diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index e68f6b4dc40a36..34904574a8b88d 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -1,5 +1,5 @@ -from flask_login import current_user -from flask_restful import marshal, reqparse +from flask_login import current_user # type: ignore +from flask_restful import marshal, reqparse # type: ignore from werkzeug.exceptions import NotFound from controllers.service_api import api diff --git a/api/controllers/service_api/index.py b/api/controllers/service_api/index.py index d24c4597e210fb..75d9141a6d0a3a 100644 --- a/api/controllers/service_api/index.py +++ b/api/controllers/service_api/index.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restful import Resource # type: ignore from configs import dify_config from controllers.service_api import api diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 2128c4c53f9909..740b92ef8e4faf 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -5,8 +5,8 @@ from typing import Optional from flask import current_app, request -from flask_login import user_logged_in -from flask_restful import Resource +from flask_login import user_logged_in # type: ignore +from flask_restful import Resource # type: ignore from pydantic import BaseModel from werkzeug.exceptions import Forbidden, Unauthorized @@ -49,6 +49,8 @@ def 
decorated_view(*args, **kwargs): raise Forbidden("The app's API service has been disabled.") tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first() + if tenant is None: + raise ValueError("Tenant does not exist.") if tenant.status == TenantStatus.ARCHIVE: raise Forbidden("The workspace's status is archived.") @@ -154,8 +156,8 @@ def decorated(*args, **kwargs): # Login admin if account: account.current_tenant = tenant - current_app.login_manager._update_request_context_with_user(account) - user_logged_in.send(current_app._get_current_object(), user=_get_user()) + current_app.login_manager._update_request_context_with_user(account) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore else: raise Unauthorized("Tenant owner account does not exist.") else: diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index cc8255ccf4e748..20e071c834ad5b 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,4 +1,4 @@ -from flask_restful import marshal_with +from flask_restful import marshal_with # type: ignore from controllers.common import fields from controllers.common import helpers as controller_helpers diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index e8521307ad357a..97d980d07c13a7 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -65,7 +65,7 @@ def post(self, app_model: App, end_user): class TextApi(WebApiResource): def post(self, app_model: App, end_user): - from flask_restful import reqparse + from flask_restful import reqparse # type: ignore try: parser = reqparse.RequestParser() @@ -82,7 +82,7 @@ def post(self, app_model: App, end_user): and app_model.workflow and app_model.workflow.features_dict ): - text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {}) voice = args.get("voice") or text_to_speech.get("voice") else: try: diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 45b890dfc4899d..761771a81a4bb3 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,6 +1,6 @@ import logging -from flask_restful import reqparse +from flask_restful import reqparse # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index fe0d7c74f32cff..28feb1ca47effd 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -1,5 +1,5 @@ -from flask_restful import marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound diff --git a/api/controllers/web/feature.py b/api/controllers/web/feature.py index 0563ed22382e6b..ce841a8814972d 100644 --- a/api/controllers/web/feature.py +++ b/api/controllers/web/feature.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restful import Resource # type: ignore from controllers.web import api from services.feature_service import FeatureService diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py index a282fc63a8b056..1d4474015ab648 100644 --- a/api/controllers/web/files.py +++ b/api/controllers/web/files.py @@ -1,5 +1,5 @@ from flask import 
request -from flask_restful import marshal_with +from flask_restful import marshal_with # type: ignore import services from controllers.common.errors import FilenameNotExistsError @@ -33,7 +33,7 @@ def post(self, app_model, end_user): content=file.read(), mimetype=file.mimetype, user=end_user, - source=source, + source="datasets" if source == "datasets" else None, ) except services.errors.file.FileTooLargeError as file_too_large_error: raise FileTooLargeError(file_too_large_error.description) diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index febaab5328e8b3..0f47e643708570 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -1,7 +1,7 @@ import logging -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index a01ffd861230a5..4625c1f43dfbd1 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -1,7 +1,7 @@ import uuid from flask import request -from flask_restful import Resource +from flask_restful import Resource # type: ignore from werkzeug.exceptions import NotFound, Unauthorized from controllers.web import api diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index ae68df6bdc4e48..d559ab8e07e736 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -1,7 +1,7 @@ import urllib.parse import httpx -from flask_restful import marshal_with, reqparse +from flask_restful import marshal_with, reqparse # type: ignore import services from controllers.common import helpers diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index b0492e6b6f0d31..6a9b8189076c3c 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -1,5 +1,5 @@ -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import NotFound from controllers.web import api diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 0564b15ea39855..e68dc7aa4afba5 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -1,4 +1,4 @@ -from flask_restful import fields, marshal_with +from flask_restful import fields, marshal_with # type: ignore from werkzeug.exceptions import Forbidden from configs import dify_config diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 55b0c3e2ab34c5..48d25e720c10c3 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -1,6 +1,6 @@ import logging -from flask_restful import reqparse +from flask_restful import reqparse # type: ignore from werkzeug.exceptions import InternalServerError from controllers.web import api diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index c327c3df18526c..1b4d263bee4401 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -1,7 +1,7 @@ from functools import wraps from flask import request -from flask_restful import Resource +from flask_restful import Resource # 
type: ignore from werkzeug.exceptions import BadRequest, NotFound, Unauthorized from controllers.web.error import WebSSOAuthRequiredError diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index ead293200ea3aa..8d69bdcec2c2ac 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -1,7 +1,6 @@ import json import logging import uuid -from collections.abc import Mapping, Sequence from datetime import UTC, datetime from typing import Optional, Union, cast @@ -53,6 +52,7 @@ class BaseAgentRunner(AppRunner): def __init__( self, + *, tenant_id: str, application_generate_entity: AgentChatAppGenerateEntity, conversation: Conversation, @@ -66,7 +66,7 @@ def __init__( prompt_messages: Optional[list[PromptMessage]] = None, variables_pool: Optional[ToolRuntimeVariablePool] = None, db_variables: Optional[ToolConversationVariables] = None, - model_instance: ModelInstance | None = None, + model_instance: ModelInstance, ) -> None: self.tenant_id = tenant_id self.application_generate_entity = application_generate_entity @@ -117,7 +117,7 @@ def __init__( features = model_schema.features if model_schema and model_schema.features else [] self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features self.files = application_generate_entity.files if ModelFeature.VISION in features else [] - self.query = None + self.query: Optional[str] = "" self._current_thoughts: list[PromptMessage] = [] def _repack_app_generate_entity( @@ -145,7 +145,7 @@ def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[P message_tool = PromptMessageTool( name=tool.tool_name, - description=tool_entity.description.llm, + description=tool_entity.description.llm if tool_entity.description else "", parameters={ "type": "object", "properties": {}, @@ -167,7 +167,7 @@ def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[P continue enum = [] if parameter.type == ToolParameter.ToolParameterType.SELECT: - enum = [option.value for option in parameter.options] + enum = [option.value for option in parameter.options] if parameter.options else [] message_tool.parameters["properties"][parameter.name] = { "type": parameter_type, @@ -187,8 +187,8 @@ def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRe convert dataset retriever tool to prompt message tool """ prompt_tool = PromptMessageTool( - name=tool.identity.name, - description=tool.description.llm, + name=tool.identity.name if tool.identity else "unknown", + description=tool.description.llm if tool.description else "", parameters={ "type": "object", "properties": {}, @@ -210,14 +210,14 @@ def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRe return prompt_tool - def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]: + def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]: """ Init tools """ tool_instances = {} prompt_messages_tools = [] - for tool in self.app_config.agent.tools if self.app_config.agent else []: + for tool in self.app_config.agent.tools or [] if self.app_config.agent else []: try: prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) except Exception: @@ -234,7 +234,8 @@ def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessage # save prompt tool prompt_messages_tools.append(prompt_tool) # save tool entity - tool_instances[dataset_tool.identity.name] = dataset_tool + if dataset_tool.identity is not 
None: + tool_instances[dataset_tool.identity.name] = dataset_tool return tool_instances, prompt_messages_tools @@ -258,7 +259,7 @@ def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) continue enum = [] if parameter.type == ToolParameter.ToolParameterType.SELECT: - enum = [option.value for option in parameter.options] + enum = [option.value for option in parameter.options] if parameter.options else [] prompt_tool.parameters["properties"][parameter.name] = { "type": parameter_type, @@ -322,16 +323,21 @@ def save_agent_thought( tool_name: str, tool_input: Union[str, dict], thought: str, - observation: Union[str, dict], - tool_invoke_meta: Union[str, dict], + observation: Union[str, dict, None], + tool_invoke_meta: Union[str, dict, None], answer: str, messages_ids: list[str], - llm_usage: LLMUsage = None, - ) -> MessageAgentThought: + llm_usage: LLMUsage | None = None, + ): """ Save agent thought """ - agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() + queried_thought = ( + db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() + ) + if not queried_thought: + raise ValueError(f"Agent thought {agent_thought.id} not found") + agent_thought = queried_thought if thought is not None: agent_thought.thought = thought @@ -404,7 +410,7 @@ def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variab """ convert tool variables to db variables """ - db_variables = ( + queried_variables = ( db.session.query(ToolConversationVariables) .filter( ToolConversationVariables.conversation_id == self.message.conversation_id, @@ -412,6 +418,11 @@ def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variab .first() ) + if not queried_variables: + return + + db_variables = queried_variables + db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None) db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool)) db.session.commit() @@ -421,7 +432,7 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P """ Organize agent history """ - result = [] + result: list[PromptMessage] = [] # check if there is a system message in the beginning of the conversation for prompt_message in prompt_messages: if isinstance(prompt_message, SystemPromptMessage): diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index d98ba5a3fad846..e936acb6055cb8 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -1,7 +1,7 @@ import json from abc import ABC, abstractmethod -from collections.abc import Generator -from typing import Optional, Union +from collections.abc import Generator, Mapping +from typing import Any, Optional from core.agent.base_agent_runner import BaseAgentRunner from core.agent.entities import AgentScratchpadUnit @@ -12,6 +12,7 @@ from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, + PromptMessageTool, ToolPromptMessage, UserPromptMessage, ) @@ -26,18 +27,18 @@ class CotAgentRunner(BaseAgentRunner, ABC): _is_first_iteration = True _ignore_observation_providers = ["wenxin"] - _historic_prompt_messages: list[PromptMessage] = None - _agent_scratchpad: list[AgentScratchpadUnit] = None - _instruction: str = None - _query: str = None - _prompt_messages_tools: list[PromptMessage] = None + _historic_prompt_messages: list[PromptMessage] | None = None + _agent_scratchpad: list[AgentScratchpadUnit] | 
None = None + _instruction: str = "" # FIXME this must be str for now + _query: str | None = None + _prompt_messages_tools: list[PromptMessageTool] = [] def run( self, message: Message, query: str, - inputs: dict[str, str], - ) -> Union[Generator, LLMResult]: + inputs: Mapping[str, str], + ) -> Generator: """ Run Cot agent application """ @@ -57,19 +58,19 @@ def run( # init instruction inputs = inputs or {} instruction = app_config.prompt_template.simple_prompt_template - self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) + self._instruction = self._fill_in_inputs_from_external_data_tools(instruction=instruction or "", inputs=inputs) iteration_step = 1 - max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 + max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1 # convert tools into ModelRuntime Tool format tool_instances, self._prompt_messages_tools = self._init_prompt_tools() function_call_state = True - llm_usage = {"usage": None} + llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} final_answer = "" - def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): + def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage): if not final_llm_usage_dict["usage"]: final_llm_usage_dict["usage"] = usage else: @@ -90,7 +91,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # the last iteration, remove all tools self._prompt_messages_tools = [] - message_file_ids = [] + message_file_ids: list[str] = [] agent_thought = self.create_agent_thought( message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids @@ -105,7 +106,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): prompt_messages = self._organize_prompt_messages() self.recalc_llm_max_tokens(self.model_config, prompt_messages) # invoke model - chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm( + chunks = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=app_generate_entity.model_conf.parameters, tools=[], @@ -115,11 +116,14 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): callbacks=[], ) + if not isinstance(chunks, Generator): + raise ValueError("Expected streaming response from LLM") + # check llm result if not chunks: raise ValueError("failed to invoke llm") - usage_dict = {} + usage_dict: dict[str, Optional[LLMUsage]] = {"usage": None} react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) scratchpad = AgentScratchpadUnit( agent_response="", @@ -139,25 +143,30 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): if isinstance(chunk, AgentScratchpadUnit.Action): action = chunk # detect action - scratchpad.agent_response += json.dumps(chunk.model_dump()) + if scratchpad.agent_response is not None: + scratchpad.agent_response += json.dumps(chunk.model_dump()) scratchpad.action_str = json.dumps(chunk.model_dump()) scratchpad.action = action else: - scratchpad.agent_response += chunk - scratchpad.thought += chunk + if scratchpad.agent_response is not None: + scratchpad.agent_response += chunk + if scratchpad.thought is not None: + scratchpad.thought += chunk yield LLMResultChunk( model=self.model_config.model, prompt_messages=prompt_messages, system_fingerprint="", delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None), ) - - 
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" - self._agent_scratchpad.append(scratchpad) + if scratchpad.thought is not None: + scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" + if self._agent_scratchpad is not None: + self._agent_scratchpad.append(scratchpad) # get llm usage if "usage" in usage_dict: - increase_usage(llm_usage, usage_dict["usage"]) + if usage_dict["usage"] is not None: + increase_usage(llm_usage, usage_dict["usage"]) else: usage_dict["usage"] = LLMUsage.empty_usage() @@ -166,9 +175,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): tool_name=scratchpad.action.action_name if scratchpad.action else "", tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {}, tool_invoke_meta={}, - thought=scratchpad.thought, + thought=scratchpad.thought or "", observation="", - answer=scratchpad.agent_response, + answer=scratchpad.agent_response or "", messages_ids=[], llm_usage=usage_dict["usage"], ) @@ -209,7 +218,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): agent_thought=agent_thought, tool_name=scratchpad.action.action_name, tool_input={scratchpad.action.action_name: scratchpad.action.action_input}, - thought=scratchpad.thought, + thought=scratchpad.thought or "", observation={scratchpad.action.action_name: tool_invoke_response}, tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()}, answer=scratchpad.agent_response, @@ -247,8 +256,8 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): answer=final_answer, messages_ids=[], ) - - self.update_db_variables(self.variables_pool, self.db_variables_pool) + if self.variables_pool is not None and self.db_variables_pool is not None: + self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event self.queue_manager.publish( QueueMessageEndEvent( @@ -307,8 +316,9 @@ def _handle_invoke_action( # publish files for message_file_id, save_as in message_files: - if save_as: - self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as) + if save_as is not None and self.variables_pool: + # FIXME the save_as type is confusing, it should be a string or not + self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=str(save_as)) # publish message file self.queue_manager.publish( @@ -325,7 +335,7 @@ def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action: """ return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"]) - def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str: + def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: Mapping[str, Any]) -> str: """ fill in inputs from external data tools """ @@ -376,11 +386,13 @@ def _organize_historic_prompt_messages( """ result: list[PromptMessage] = [] scratchpads: list[AgentScratchpadUnit] = [] - current_scratchpad: AgentScratchpadUnit = None + current_scratchpad: AgentScratchpadUnit | None = None for message in self.history_prompt_messages: if isinstance(message, AssistantPromptMessage): if not current_scratchpad: + if not isinstance(message.content, str | None): + raise NotImplementedError("expected str type") current_scratchpad = AgentScratchpadUnit( agent_response=message.content, thought=message.content or "I am thinking about how to help you", @@ 
-399,8 +411,12 @@ def _organize_historic_prompt_messages( except: pass elif isinstance(message, ToolPromptMessage): - if current_scratchpad: + if not current_scratchpad: + continue + if isinstance(message.content, str): current_scratchpad.observation = message.content + else: + raise NotImplementedError("expected str type") elif isinstance(message, UserPromptMessage): if scratchpads: result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index d8d047fe91cdbd..6a96c349b2611c 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -19,7 +19,12 @@ def _organize_system_prompt(self) -> SystemPromptMessage: """ Organize system prompt """ + if not self.app_config.agent: + raise ValueError("Agent configuration is not set") + prompt_entity = self.app_config.agent.prompt + if not prompt_entity: + raise ValueError("Agent prompt configuration is not set") first_prompt = prompt_entity.first_prompt system_prompt = ( @@ -75,6 +80,7 @@ def _organize_prompt_messages(self) -> list[PromptMessage]: assistant_messages = [] else: assistant_message = AssistantPromptMessage(content="") + assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str for unit in agent_scratchpad: if unit.is_final(): assistant_message.content += f"Final Answer: {unit.agent_response}" diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index 0563090537e62c..3a4d31e047f5ae 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -2,7 +2,12 @@ from typing import Optional from core.agent.cot_agent_runner import CotAgentRunner -from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) from core.model_runtime.utils.encoders import jsonable_encoder @@ -11,7 +16,11 @@ def _organize_instruction_prompt(self) -> str: """ Organize instruction prompt """ + if self.app_config.agent is None: + raise ValueError("Agent configuration is not set") prompt_entity = self.app_config.agent.prompt + if prompt_entity is None: + raise ValueError("prompt entity is not set") first_prompt = prompt_entity.first_prompt system_prompt = ( @@ -33,7 +42,13 @@ def _organize_historic_prompt(self, current_session_messages: Optional[list[Prom if isinstance(message, UserPromptMessage): historic_prompt += f"Question: {message.content}\n\n" elif isinstance(message, AssistantPromptMessage): - historic_prompt += message.content + "\n\n" + if isinstance(message.content, str): + historic_prompt += message.content + "\n\n" + elif isinstance(message.content, list): + for content in message.content: + if not isinstance(content, TextPromptMessageContent): + continue + historic_prompt += content.data return historic_prompt @@ -50,7 +65,7 @@ def _organize_prompt_messages(self) -> list[PromptMessage]: # organize current assistant messages agent_scratchpad = self._agent_scratchpad assistant_prompt = "" - for unit in agent_scratchpad: + for unit in agent_scratchpad or []: if unit.is_final(): assistant_prompt += f"Final Answer: {unit.agent_response}" else: diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 119a88fc7becbf..2ae87dca3f8cbd 
100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -78,5 +78,5 @@ class Strategy(Enum): model: str strategy: Strategy prompt: Optional[AgentPromptEntity] = None - tools: list[AgentToolEntity] = None + tools: list[AgentToolEntity] | None = None max_iteration: int = 5 diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index cd546dee124147..b862c96072aaa0 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -40,6 +40,8 @@ def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResul app_generate_entity = self.application_generate_entity app_config = self.app_config + assert app_config is not None, "app_config is required" + assert app_config.agent is not None, "app_config.agent is required" # convert tools into ModelRuntime Tool format tool_instances, prompt_messages_tools = self._init_prompt_tools() @@ -49,7 +51,7 @@ def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResul # continue to run until there is not any tool call function_call_state = True - llm_usage = {"usage": None} + llm_usage: dict[str, LLMUsage] = {"usage": LLMUsage.empty_usage()} final_answer = "" # get tracing instance @@ -75,7 +77,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # the last iteration, remove all tools prompt_messages_tools = [] - message_file_ids = [] + message_file_ids: list[str] = [] agent_thought = self.create_agent_thought( message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids ) @@ -105,7 +107,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): current_llm_usage = None - if self.stream_tool_call: + if self.stream_tool_call and isinstance(chunks, Generator): is_first_chunk = True for chunk in chunks: if is_first_chunk: @@ -116,7 +118,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # check if there is any tool call if self.check_tool_calls(chunk): function_call_state = True - tool_calls.extend(self.extract_tool_calls(chunk)) + tool_calls.extend(self.extract_tool_calls(chunk) or []) tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) try: tool_call_inputs = json.dumps( @@ -131,19 +133,19 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): for content in chunk.delta.message.content: response += content.data else: - response += chunk.delta.message.content + response += str(chunk.delta.message.content) if chunk.delta.usage: increase_usage(llm_usage, chunk.delta.usage) current_llm_usage = chunk.delta.usage yield chunk - else: - result: LLMResult = chunks + elif not self.stream_tool_call and isinstance(chunks, LLMResult): + result = chunks # check if there is any tool call if self.check_blocking_tool_calls(result): function_call_state = True - tool_calls.extend(self.extract_blocking_tool_calls(result)) + tool_calls.extend(self.extract_blocking_tool_calls(result) or []) tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) try: tool_call_inputs = json.dumps( @@ -162,7 +164,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): for content in result.message.content: response += content.data else: - response += result.message.content + response += str(result.message.content) if not result.message.content: result.message.content = "" @@ -181,6 +183,8 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): 
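Note: the cot_agent_runner and fc_agent_runner hunks above share one pattern: narrow an Optional before touching it, either with an explicit `is not None` branch or an `assert`. A minimal runnable sketch of that pattern, using an illustrative Scratchpad stand-in rather than the real AgentScratchpadUnit:

    from typing import Optional

    class Scratchpad:  # illustrative stand-in, not the real entity
        thought: Optional[str] = None

    def normalize_thought(pad: Scratchpad) -> None:
        if pad.thought is not None:
            # Inside this branch mypy narrows `str | None` to `str`,
            # so .strip() type-checks without a cast.
            pad.thought = pad.thought.strip() or "I am thinking about how to help you"

    pad = Scratchpad()
    pad.thought = "  plan the next step  "
    normalize_thought(pad)
    assert pad.thought == "plan the next step"

The entities.py change above is the declaration-side half of the same fix: a field that defaults to None must be typed `list[AgentToolEntity] | None`, not `list[AgentToolEntity]`, or mypy flags the default as incompatible.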
usage=result.usage, ), ) + else: + raise RuntimeError(f"invalid chunks type: {type(chunks)}") assistant_message = AssistantPromptMessage(content="", tool_calls=[]) if tool_calls: @@ -243,7 +247,10 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # publish files for message_file_id, save_as in message_files: if save_as: - self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as) + if self.variables_pool: + self.variables_pool.set_file( + tool_name=tool_call_name, value=message_file_id, name=save_as + ) # publish message file self.queue_manager.publish( @@ -263,7 +270,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): if tool_response["tool_response"] is not None: self._current_thoughts.append( ToolPromptMessage( - content=tool_response["tool_response"], + content=str(tool_response["tool_response"]), tool_call_id=tool_call_id, name=tool_call_name, ) @@ -273,9 +280,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # save agent thought self.save_agent_thought( agent_thought=agent_thought, - tool_name=None, - tool_input=None, - thought=None, + tool_name="", + tool_input="", + thought="", tool_invoke_meta={ tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses }, @@ -283,7 +290,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): tool_response["tool_call_name"]: tool_response["tool_response"] for tool_response in tool_responses }, - answer=None, + answer="", messages_ids=message_file_ids, ) self.queue_manager.publish( @@ -296,7 +303,8 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): iteration_step += 1 - self.update_db_variables(self.variables_pool, self.db_variables_pool) + if self.variables_pool and self.db_variables_pool: + self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event self.queue_manager.publish( QueueMessageEndEvent( @@ -389,9 +397,9 @@ def _init_system_message( if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template: prompt_messages.insert(0, SystemPromptMessage(content=prompt_template)) - return prompt_messages + return prompt_messages or [] - def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: + def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: """ Organize user query """ @@ -449,7 +457,7 @@ def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage] def _organize_prompt_messages(self): prompt_template = self.app_config.prompt_template.simple_prompt_template or "" self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages) - query_prompt_messages = self._organize_user_query(self.query, []) + query_prompt_messages = self._organize_user_query(self.query or "", []) self.history_prompt_messages = AgentHistoryPromptTransform( model_config=self.model_config, diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index 085bac8601b2da..61fa774ea5f390 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -38,7 +38,7 @@ def parse_action(json_str): except: return json_str or "" - def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]: + def 
extra_json_from_code_block(code_block) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]: code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL) if not code_blocks: return @@ -67,15 +67,15 @@ def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, for response in llm_response: if response.delta.usage: usage_dict["usage"] = response.delta.usage - response = response.delta.message.content - if not isinstance(response, str): + response_content = response.delta.message.content + if not isinstance(response_content, str): continue # stream index = 0 - while index < len(response): + while index < len(response_content): steps = 1 - delta = response[index : index + steps] + delta = response_content[index : index + steps] yield_delta = False if delta == "`": diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index b9aae7904f5e7c..646c4badb9f73a 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -66,6 +66,8 @@ def convert(cls, config: dict) -> Optional[DatasetEntity]: dataset_configs = config.get("dataset_configs") else: dataset_configs = {"retrieval_model": "multiple"} + if dataset_configs is None: + return None query_variable = config.get("dataset_query_variable") if dataset_configs["retrieval_model"] == "single": diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index 5adcf26f1486e8..6426865115126f 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -94,7 +94,7 @@ def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> config["model"]["completion_params"] ) - return config, ["model"] + return dict(config), ["model"] @classmethod def validate_model_completion_params(cls, cp: dict) -> dict: diff --git a/api/core/app/app_config/features/opening_statement/manager.py b/api/core/app/app_config/features/opening_statement/manager.py index b4dacbc409044a..92b4185abf0183 100644 --- a/api/core/app/app_config/features/opening_statement/manager.py +++ b/api/core/app/app_config/features/opening_statement/manager.py @@ -7,10 +7,10 @@ def convert(cls, config: dict) -> tuple[str, list]: :param config: model config args """ # opening statement - opening_statement = config.get("opening_statement") + opening_statement = config.get("opening_statement", "") # suggested questions - suggested_questions_list = config.get("suggested_questions") + suggested_questions_list = config.get("suggested_questions", []) return opening_statement, suggested_questions_list diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 6200299d21c869..a18b40712b7ce6 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -29,6 +29,7 @@ from models.account import Account from models.model import App, Conversation, EndUser, Message from models.workflow import Workflow +from services.errors.message import MessageNotExistsError logger = logging.getLogger(__name__) @@ -145,7 +146,7 @@ def generate( user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id ), query=query, - files=file_objs, + files=list(file_objs), 
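Note: the cot_output_parser hunk above binds the chunk payload to a new name, `response_content`, instead of rebinding `response`, because mypy assigns one type per name in a scope and the generator item and its string content have different types. A small sketch of the same discipline, with a hypothetical Chunk type standing in for the LLM chunk entity:

    from dataclasses import dataclass

    @dataclass
    class Chunk:  # hypothetical stand-in for the streamed LLM chunk
        content: str | list[str]  # delta content may be text or structured parts

    def stream_text(chunks: list[Chunk]) -> str:
        out = ""
        for chunk in chunks:
            chunk_content = chunk.content  # new name, not a rebind of `chunk`
            if not isinstance(chunk_content, str):
                continue  # skip structured content, as the parser hunk does
            out += chunk_content
        return out

    print(stream_text([Chunk("a"), Chunk(["part"]), Chunk("b")]))  # -> "ab"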
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, user_id=user.id, stream=streaming, @@ -313,6 +314,8 @@ def _generate_worker( # get conversation and message conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) + if message is None: + raise MessageNotExistsError("Message not exists") # chatbot app runner = AdvancedChatAppRunner( diff --git a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py index 29709914b7cfb8..a506447671abfb 100644 --- a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py +++ b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py @@ -5,6 +5,7 @@ import re import threading from collections.abc import Iterable +from typing import Optional from core.app.entities.queue_entities import ( MessageQueueMessage, @@ -15,6 +16,7 @@ WorkflowQueueMessage, ) from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.message_entities import TextPromptMessageContent from core.model_runtime.entities.model_entities import ModelType @@ -71,8 +73,9 @@ def __init__(self, tenant_id: str, voice: str): if not voice or voice not in values: self.voice = self.voices[0].get("value") self.MAX_SENTENCE = 2 - self._last_audio_event = None - self._runtime_thread = threading.Thread(target=self._runtime).start() + self._last_audio_event: Optional[AudioTrunk] = None + # FIXME better way to handle this threading.start + threading.Thread(target=self._runtime).start() self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3) def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /): @@ -92,10 +95,21 @@ def _runtime(self): future_queue.put(futures_result) break elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent): - self.msg_text += message.event.chunk.delta.message.content + message_content = message.event.chunk.delta.message.content + if not message_content: + continue + if isinstance(message_content, str): + self.msg_text += message_content + elif isinstance(message_content, list): + for content in message_content: + if not isinstance(content, TextPromptMessageContent): + continue + self.msg_text += content.data elif isinstance(message.event, QueueTextChunkEvent): self.msg_text += message.event.text elif isinstance(message.event, QueueNodeSucceededEvent): + if message.event.outputs is None: + continue self.msg_text += message.event.outputs.get("output", "") self.last_message = message sentence_arr, text_tmp = self._extract_sentence(self.msg_text) @@ -121,11 +135,10 @@ def check_and_get_audio(self): if self._last_audio_event and self._last_audio_event.status == "finish": if self.executor: self.executor.shutdown(wait=False) - return self.last_message + return self._last_audio_event audio = self._audio_queue.get_nowait() if audio and audio.status == "finish": self.executor.shutdown(wait=False) - self._runtime_thread = None if audio: self._last_audio_event = audio return audio diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index cf0c9d7593429a..6339d798984800 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -109,18 +109,18 @@ def run(self) -> None: ConversationVariable.conversation_id == self.conversation.id, ) with Session(db.engine) as session: - conversation_variables = session.scalars(stmt).all() - if 
not conversation_variables: + db_conversation_variables = session.scalars(stmt).all() + if not db_conversation_variables: # Create conversation variables if they don't exist. - conversation_variables = [ + db_conversation_variables = [ ConversationVariable.from_variable( app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable ) for variable in workflow.conversation_variables ] - session.add_all(conversation_variables) + session.add_all(db_conversation_variables) # Convert database entities to variables. - conversation_variables = [item.to_variable() for item in conversation_variables] + conversation_variables = [item.to_variable() for item in db_conversation_variables] session.commit() diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 635e482ad980ed..1073a0f2e4f706 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -2,6 +2,7 @@ import logging import time from collections.abc import Generator, Mapping +from threading import Thread from typing import Any, Optional, Union from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME @@ -64,6 +65,7 @@ from models.workflow import ( Workflow, WorkflowNodeExecution, + WorkflowRun, WorkflowRunStatus, ) @@ -81,6 +83,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc _user: Union[Account, EndUser] _workflow_system_variables: dict[SystemVariableKey, Any] _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] + _conversation_name_generate_thread: Optional[Thread] = None def __init__( self, @@ -131,7 +134,7 @@ def __init__( self._conversation_name_generate_thread = None self._recorded_files: list[Mapping[str, Any]] = [] - def process(self): + def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: """ Process generate task pipeline. 
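Note: the app_runner hunk above applies the one-type-per-name rule to ORM results. Previously `conversation_variables` held database rows and was then reassigned to the converted variables, which mypy rejects as an incompatible re-assignment; the patch gives each stage its own name. Condensed illustration with a stand-in row class:

    from dataclasses import dataclass

    @dataclass
    class ConversationVariableRow:  # stand-in for the ORM entity
        name: str
        value: str

        def to_variable(self) -> tuple[str, str]:
            return (self.name, self.value)

    db_conversation_variables = [ConversationVariableRow("lang", "en")]
    # A distinct name for the converted values keeps both types stable.
    conversation_variables = [row.to_variable() for row in db_conversation_variables]
    print(conversation_variables)  # [('lang', 'en')]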
:return: @@ -262,8 +265,8 @@ def _process_stream_response( :return: """ # init fake graph runtime state - graph_runtime_state = None - workflow_run = None + graph_runtime_state: Optional[GraphRuntimeState] = None + workflow_run: Optional[WorkflowRun] = None for queue_message in self._queue_manager.listen(): event = queue_message.event @@ -315,14 +318,14 @@ def _process_stream_response( workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) - response = self._workflow_node_start_to_stream_response( + response_start = self._workflow_node_start_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, ) - if response: - yield response + if response_start: + yield response_start elif isinstance(event, QueueNodeSucceededEvent): workflow_node_execution = self._handle_workflow_node_execution_success(event) @@ -330,18 +333,18 @@ def _process_stream_response( if event.node_type in [NodeType.ANSWER, NodeType.END]: self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {})) - response = self._workflow_node_finish_to_stream_response( + response_finish = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, ) - if response: - yield response + if response_finish: + yield response_finish elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): workflow_node_execution = self._handle_workflow_node_execution_failed(event) - response = self._workflow_node_finish_to_stream_response( + response_finish = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -609,7 +612,10 @@ def _message_end_to_stream_response(self) -> MessageEndStreamResponse: del extras["metadata"]["annotation_reply"] return MessageEndStreamResponse( - task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras + task_id=self._application_generate_entity.task_id, + id=self._message.id, + files=self._recorded_files, + metadata=extras.get("metadata", {}), ) def _handle_output_moderation_chunk(self, text: str) -> bool: diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py index 417d23eccfb553..55b6ee510f228c 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -61,7 +61,7 @@ def get_app_config( app_model_config_dict = app_model_config.to_dict() config_dict = app_model_config_dict.copy() else: - config_dict = override_config_dict + config_dict = override_config_dict or {} app_mode = AppMode.value_of(app_model.mode) app_config = AgentChatAppConfig( diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index b391169e3dbe5c..63e11bdaa27f74 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -23,6 +23,7 @@ from extensions.ext_database import db from factories import file_factory from models import Account, App, EndUser +from services.errors.message import MessageNotExistsError logger = logging.getLogger(__name__) @@ -97,7 +98,7 @@ def generate( # get conversation conversation = None if args.get("conversation_id"): - conversation = 
self._get_conversation_by_user(app_model, args.get("conversation_id"), user) + conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user) # get app model config app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) @@ -153,7 +154,7 @@ def generate( user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id ), query=query, - files=file_objs, + files=list(file_objs), parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, user_id=user.id, stream=streaming, @@ -180,7 +181,7 @@ def generate( worker_thread = threading.Thread( target=self._generate_worker, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "application_generate_entity": application_generate_entity, "queue_manager": queue_manager, "conversation_id": conversation.id, @@ -199,8 +200,8 @@ def generate( user=user, stream=streaming, ) - - return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + # FIXME: Type hinting issue here, ignore it for now, will fix it later + return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) # type: ignore def _generate_worker( self, @@ -224,6 +225,8 @@ def _generate_worker( # get conversation and message conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) + if message is None: + raise MessageNotExistsError("Message not exists") # chatbot app runner = AgentChatAppRunner() diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 45b1bf00934d35..ac71f02b6de03d 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -173,6 +173,8 @@ def run( return agent_entity = app_config.agent + if not agent_entity: + raise ValueError("Agent entity not found") # load tool variables tool_conversation_variables = self._load_tool_variables( @@ -200,14 +202,21 @@ def run( # change function call strategy based on LLM model llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) + if not model_schema or not model_schema.features: + raise ValueError("Model schema not found") if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []): agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING - conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() - message = db.session.query(Message).filter(Message.id == message.id).first() + conversation_result = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() + if conversation_result is None: + raise ValueError("Conversation not found") + message_result = db.session.query(Message).filter(Message.id == message.id).first() + if message_result is None: + raise ValueError("Message not found") db.session.close() + runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner] # start agent runner if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: # check LLM mode @@ -225,12 +234,12 @@ def run( runner = runner_cls( tenant_id=app_config.tenant_id, application_generate_entity=application_generate_entity, - conversation=conversation, + conversation=conversation_result, 
app_config=app_config, model_config=application_generate_entity.model_conf, config=agent_entity, queue_manager=queue_manager, - message=message, + message=message_result, user_id=application_generate_entity.user_id, memory=memory, prompt_messages=prompt_message, @@ -257,7 +266,7 @@ def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: st """ load tool variables from database """ - tool_variables: ToolConversationVariables = ( + tool_variables: ToolConversationVariables | None = ( db.session.query(ToolConversationVariables) .filter( ToolConversationVariables.conversation_id == conversation_id, diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py index 629c309c065458..ce331d904cc826 100644 --- a/api/core/app/apps/agent_chat/generate_response_converter.py +++ b/api/core/app/apps/agent_chat/generate_response_converter.py @@ -16,7 +16,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = ChatbotAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking full response. :param blocking_response: blocking response @@ -37,7 +37,7 @@ def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingRes return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking simple response. :param blocking_response: blocking response @@ -51,8 +51,9 @@ def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingR return response @classmethod - def convert_stream_full_response( - cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + def convert_stream_full_response( # type: ignore[override] + cls, + stream_response: Generator[ChatbotAppStreamResponse, None, None], ) -> Generator[str, None, None]: """ Convert stream full response. @@ -82,8 +83,9 @@ def convert_stream_full_response( yield json.dumps(response_chunk) @classmethod - def convert_stream_simple_response( - cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + def convert_stream_simple_response( # type: ignore[override] + cls, + stream_response: Generator[ChatbotAppStreamResponse, None, None], ) -> Generator[str, None, None]: """ Convert stream simple response. 
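Note: the converter hunks above narrow the parameter type of inherited classmethods (the base converter takes the generic blocking response, each subclass takes its app-specific one). mypy reports that as an unsafe override, since narrowing an argument breaks substitutability for callers that go through the base class; the patch silences exactly that error code per line rather than widening the signatures. A condensed sketch with stand-in class names:

    class BlockingResponse:
        pass

    class ChatbotBlockingResponse(BlockingResponse):
        answer: str = "hi"

    class BaseConverter:
        @classmethod
        def convert_blocking_full_response(cls, blocking_response: BlockingResponse) -> dict:
            return {}

    class ChatbotConverter(BaseConverter):
        @classmethod
        def convert_blocking_full_response(  # type: ignore[override]
            cls,
            blocking_response: ChatbotBlockingResponse,  # narrower than the base
        ) -> dict:
            return {"answer": blocking_response.answer}

    print(ChatbotConverter.convert_blocking_full_response(ChatbotBlockingResponse()))

Using the scoped `[override]` code keeps the rest of the signature checked, which a bare `# type: ignore` would not.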
diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 3725c6e6ddc4fd..1842fc43033ab8 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -50,7 +50,7 @@ def listen(self): # wait for APP_MAX_EXECUTION_TIME seconds to stop listen listen_timeout = dify_config.APP_MAX_EXECUTION_TIME start_time = time.time() - last_ping_time = 0 + last_ping_time: int | float = 0 while True: try: message = self._q.get(timeout=1) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 609fd03f229da8..07a248d77aee86 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -1,5 +1,5 @@ import time -from collections.abc import Generator, Mapping +from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Optional, Union from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity @@ -36,8 +36,8 @@ def get_pre_calculate_rest_tokens( app_record: App, model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, - inputs: dict[str, str], - files: list["File"], + inputs: Mapping[str, str], + files: Sequence["File"], query: Optional[str] = None, ) -> int: """ @@ -64,7 +64,7 @@ def get_pre_calculate_rest_tokens( ): max_tokens = ( model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template) + or model_config.parameters.get(parameter_rule.use_template or "") ) or 0 if model_context_tokens is None: @@ -85,7 +85,7 @@ def get_pre_calculate_rest_tokens( prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) - rest_tokens = model_context_tokens - max_tokens - prompt_tokens + rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens if rest_tokens < 0: raise InvokeBadRequestError( "Query or prefix prompt is too long, you can reduce the prefix prompt, " @@ -111,7 +111,7 @@ def recalc_llm_max_tokens( ): max_tokens = ( model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template) + or model_config.parameters.get(parameter_rule.use_template or "") ) or 0 if model_context_tokens is None: @@ -136,8 +136,8 @@ def organize_prompt_messages( app_record: App, model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, - inputs: dict[str, str], - files: list["File"], + inputs: Mapping[str, str], + files: Sequence["File"], query: Optional[str] = None, context: Optional[str] = None, memory: Optional[TokenBufferMemory] = None, @@ -156,6 +156,7 @@ def organize_prompt_messages( """ # get prompt without memory and context if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + prompt_transform: Union[SimplePromptTransform, AdvancedPromptTransform] prompt_transform = SimplePromptTransform() prompt_messages, stop = prompt_transform.get_prompt( app_mode=AppMode.value_of(app_record.mode), @@ -171,8 +172,11 @@ def organize_prompt_messages( memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) model_mode = ModelMode.value_of(model_config.mode) + prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]] if model_mode == ModelMode.COMPLETION: advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template + if not advanced_completion_prompt_template: + raise InvokeBadRequestError("Advanced completion prompt template 
is required.") prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt) if advanced_completion_prompt_template.role_prefix: @@ -181,6 +185,8 @@ def organize_prompt_messages( assistant=advanced_completion_prompt_template.role_prefix.assistant, ) else: + if not prompt_template_entity.advanced_chat_prompt_template: + raise InvokeBadRequestError("Advanced chat prompt template is required.") prompt_template = [] for message in prompt_template_entity.advanced_chat_prompt_template.messages: prompt_template.append(ChatModelMessage(text=message.text, role=message.role)) @@ -246,7 +252,7 @@ def direct_output( def _handle_invoke_result( self, - invoke_result: Union[LLMResult, Generator], + invoke_result: Union[LLMResult, Generator[Any, None, None]], queue_manager: AppQueueManager, stream: bool, agent: bool = False, @@ -259,10 +265,12 @@ def _handle_invoke_result( :param agent: agent :return: """ - if not stream: + if not stream and isinstance(invoke_result, LLMResult): self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) - else: + elif stream and isinstance(invoke_result, Generator): self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) + else: + raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}") def _handle_invoke_result_direct( self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool @@ -291,8 +299,8 @@ def _handle_invoke_result_stream( :param agent: agent :return: """ - model = None - prompt_messages = [] + model: str = "" + prompt_messages: list[PromptMessage] = [] text = "" usage = None for result in invoke_result: @@ -328,13 +336,14 @@ def _handle_invoke_result_stream( def moderation_for_inputs( self, + *, app_id: str, tenant_id: str, app_generate_entity: AppGenerateEntity, inputs: Mapping[str, Any], - query: str, + query: str | None = None, message_id: str, - ) -> tuple[bool, dict, str]: + ) -> tuple[bool, Mapping[str, Any], str]: """ Process sensitive_word_avoidance. :param app_id: app id @@ -350,7 +359,7 @@ def moderation_for_inputs( app_id=app_id, tenant_id=tenant_id, app_config=app_generate_entity.app_config, - inputs=inputs, + inputs=dict(inputs), query=query or "", message_id=message_id, trace_manager=app_generate_entity.trace_manager, @@ -390,9 +399,9 @@ def fill_in_inputs_from_external_data_tools( tenant_id: str, app_id: str, external_data_tools: list[ExternalDataVariableEntity], - inputs: dict, + inputs: Mapping[str, Any], query: str, - ) -> dict: + ) -> Mapping[str, Any]: """ Fill in variable inputs from external data tools if exists. 
diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 5b8debaaae6a56..6ed71fcd843083 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -24,6 +24,7 @@ from factories import file_factory from models.account import Account from models.model import App, EndUser +from services.errors.message import MessageNotExistsError logger = logging.getLogger(__name__) @@ -91,7 +92,7 @@ def generate( # get conversation conversation = None if args.get("conversation_id"): - conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user) + conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user) # get app model config app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) @@ -104,7 +105,7 @@ def generate( # validate config override_model_config_dict = ChatAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, config=args.get("model_config") + tenant_id=app_model.tenant_id, config=args.get("model_config", {}) ) # always enable retriever resource in debugger mode @@ -146,7 +147,7 @@ def generate( user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id ), query=query, - files=file_objs, + files=list(file_objs), parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, user_id=user.id, invoke_from=invoke_from, @@ -172,7 +173,7 @@ def generate( worker_thread = threading.Thread( target=self._generate_worker, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "application_generate_entity": application_generate_entity, "queue_manager": queue_manager, "conversation_id": conversation.id, @@ -216,6 +217,8 @@ def _generate_worker( # get conversation and message conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) + if message is None: + raise MessageNotExistsError("Message not exists") # chatbot app runner = ChatAppRunner() diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py index 0fa7af0a7fa36d..9024c3a98273d1 100644 --- a/api/core/app/apps/chat/generate_response_converter.py +++ b/api/core/app/apps/chat/generate_response_converter.py @@ -16,7 +16,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = ChatbotAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking full response. :param blocking_response: blocking response @@ -37,7 +37,7 @@ def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingRes return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking simple response. 
:param blocking_response: blocking response @@ -52,7 +52,8 @@ def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingR @classmethod def convert_stream_full_response( - cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + cls, + stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[str, None, None]: """ Convert stream full response. @@ -83,7 +84,8 @@ def convert_stream_full_response( @classmethod def convert_stream_simple_response( - cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + cls, + stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[str, None, None]: """ Convert stream simple response. diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py index 1193c4b7a43632..02e5d475684cdc 100644 --- a/api/core/app/apps/completion/app_config_manager.py +++ b/api/core/app/apps/completion/app_config_manager.py @@ -42,7 +42,7 @@ def get_app_config( app_model_config_dict = app_model_config.to_dict() config_dict = app_model_config_dict.copy() else: - config_dict = override_config_dict + config_dict = override_config_dict or {} app_mode = AppMode.value_of(app_model.mode) app_config = CompletionAppConfig( diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 14fd33dd398927..17d0d52497ceee 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -83,8 +83,6 @@ def generate( query = query.replace("\x00", "") inputs = args["inputs"] - extras = {} - # get conversation conversation = None @@ -99,7 +97,7 @@ def generate( # validate config override_model_config_dict = CompletionAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, config=args.get("model_config") + tenant_id=app_model.tenant_id, config=args.get("model_config", {}) ) # parse files @@ -132,11 +130,11 @@ def generate( user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id ), query=query, - files=file_objs, + files=list(file_objs), user_id=user.id, stream=streaming, invoke_from=invoke_from, - extras=extras, + extras={}, trace_manager=trace_manager, ) @@ -157,7 +155,7 @@ def generate( worker_thread = threading.Thread( target=self._generate_worker, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "application_generate_entity": application_generate_entity, "queue_manager": queue_manager, "message_id": message.id, @@ -197,6 +195,8 @@ def _generate_worker( try: # get message message = self._get_message(message_id) + if message is None: + raise MessageNotExistsError() # chatbot app runner = CompletionAppRunner() @@ -231,7 +231,7 @@ def generate_more_like_this( user: Union[Account, EndUser], invoke_from: InvokeFrom, stream: bool = True, - ) -> Union[dict, Generator[str, None, None]]: + ) -> Union[Mapping[str, Any], Generator[str, None, None]]: """ Generate App response. 
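Note: several generate() hunks above replace `args.get("key")` with `args.get("key", default)` so the value is never Optional at the call site and downstream signatures can stay `str` or `dict`. A simplified sketch of the idiom:

    from typing import Any

    def generate(args: dict[str, Any]) -> str:
        # With a default, .get() cannot produce None, so no Optional
        # propagates into the callees that expect concrete values.
        conversation_id: str = args.get("conversation_id", "")
        model_config: dict[str, Any] = args.get("model_config", {})
        return f"conversation={conversation_id!r}, overrides={len(model_config)}"

    print(generate({"conversation_id": "c1"}))
    print(generate({}))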
@@ -293,7 +293,7 @@ def generate_more_like_this( model_conf=ModelConfigConverter.convert(app_config), inputs=message.inputs, query=message.query, - files=file_objs, + files=list(file_objs), user_id=user.id, stream=stream, invoke_from=invoke_from, @@ -317,7 +317,7 @@ def generate_more_like_this( worker_thread = threading.Thread( target=self._generate_worker, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "application_generate_entity": application_generate_entity, "queue_manager": queue_manager, "message_id": message.id, diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 908d74ff539a5a..41278b75b42bf4 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -76,7 +76,7 @@ def run( tenant_id=app_config.tenant_id, app_generate_entity=application_generate_entity, inputs=inputs, - query=query, + query=query or "", message_id=message.id, ) except ModerationError as e: @@ -122,7 +122,7 @@ def run( tenant_id=app_record.tenant_id, model_config=application_generate_entity.model_conf, config=dataset_config, - query=query, + query=query or "", invoke_from=application_generate_entity.invoke_from, show_retrieve_source=app_config.additional_features.show_retrieve_source, hit_callback=hit_callback, diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py index 697f0273a5673e..73f38c3d0bcb96 100644 --- a/api/core/app/apps/completion/generate_response_converter.py +++ b/api/core/app/apps/completion/generate_response_converter.py @@ -16,7 +16,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = CompletionAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: + def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking full response. :param blocking_response: blocking response @@ -36,7 +36,7 @@ def convert_blocking_full_response(cls, blocking_response: CompletionAppBlocking return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: + def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking simple response. :param blocking_response: blocking response @@ -51,7 +51,8 @@ def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlocki @classmethod def convert_stream_full_response( - cls, stream_response: Generator[CompletionAppStreamResponse, None, None] + cls, + stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[str, None, None]: """ Convert stream full response. @@ -81,7 +82,8 @@ def convert_stream_full_response( @classmethod def convert_stream_simple_response( - cls, stream_response: Generator[CompletionAppStreamResponse, None, None] + cls, + stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[str, None, None]: """ Convert stream simple response. 
diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 95ae798ec1ac74..c2e35faf89ba15 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -2,11 +2,11 @@ import logging from collections.abc import Generator from datetime import UTC, datetime -from typing import Optional, Union +from typing import Optional, Union, cast from sqlalchemy import and_ -from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom +from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError from core.app.entities.app_invoke_entities import ( @@ -42,7 +42,7 @@ def _handle_response( ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, - AdvancedChatAppGenerateEntity, + AgentChatAppGenerateEntity, ], queue_manager: AppQueueManager, conversation: Conversation, @@ -144,7 +144,7 @@ def _init_generate_records( :conversation conversation :return: """ - app_config = application_generate_entity.app_config + app_config: EasyUIBasedAppConfig = cast(EasyUIBasedAppConfig, application_generate_entity.app_config) # get from source end_user_id = None @@ -267,7 +267,7 @@ def _get_conversation_introduction(self, application_generate_entity: AppGenerat except KeyError: pass - return introduction + return introduction or "" def _get_conversation(self, conversation_id: str): """ @@ -282,7 +282,7 @@ def _get_conversation(self, conversation_id: str): return conversation - def _get_message(self, message_id: str) -> Message: + def _get_message(self, message_id: str) -> Optional[Message]: """ Get message by message id :param message_id: message id diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index dc4ee9e566a2f3..1d5f21b9e0cc07 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -116,7 +116,7 @@ def generate( inputs=self._prepare_user_inputs( user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id ), - files=system_files, + files=list(system_files), user_id=user.id, stream=streaming, invoke_from=invoke_from, diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py index 08d00ee1805aa2..5cdac6ad28fdaa 100644 --- a/api/core/app/apps/workflow/generate_response_converter.py +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -17,16 +17,16 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = WorkflowAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: + def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking full response. :param blocking_response: blocking response :return: """ - return blocking_response.to_dict() + return dict(blocking_response.to_dict()) @classmethod - def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: + def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking simple response. 
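Note: message_based_app_generator.py above makes two related moves: `_get_message` now returns `Optional[Message]` so callers must handle a miss, and `typing.cast` records the author's knowledge that only easy-UI configs reach this generator. A sketch of both, with fake stand-in classes and an in-memory dict in place of the database:

    from typing import Optional, cast

    class AppConfig:
        pass

    class EasyUIBasedAppConfig(AppConfig):
        tenant_id = "t1"

    class Message:
        id = "m1"

    _FAKE_DB: dict[str, Message] = {"m1": Message()}

    def get_message(message_id: str) -> Optional[Message]:
        return _FAKE_DB.get(message_id)  # honest signature: a lookup can miss

    def tenant_of(app_config: AppConfig) -> str:
        # cast() is a no-op at runtime; it only asserts the subtype to mypy.
        easy_config = cast(EasyUIBasedAppConfig, app_config)
        return easy_config.tenant_id

    message = get_message("m1")
    if message is None:
        raise ValueError("message not found")  # callers now handle the miss
    print(message.id, tenant_of(EasyUIBasedAppConfig()))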
:param blocking_response: blocking response @@ -36,7 +36,8 @@ def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlocking @classmethod def convert_stream_full_response( - cls, stream_response: Generator[WorkflowAppStreamResponse, None, None] + cls, + stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[str, None, None]: """ Convert stream full response. @@ -65,7 +66,8 @@ def convert_stream_full_response( @classmethod def convert_stream_simple_response( - cls, stream_response: Generator[WorkflowAppStreamResponse, None, None] + cls, + stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[str, None, None]: """ Convert stream simple response. diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 885283504b4175..63f516bcc60682 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -24,6 +24,7 @@ QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) +from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, @@ -190,16 +191,15 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count)) elif isinstance(event, NodeRunRetryEvent): node_run_result = event.route_node_state.node_run_result + inputs: Mapping[str, Any] | None = {} + process_data: Mapping[str, Any] | None = {} + outputs: Mapping[str, Any] | None = {} + execution_metadata: Mapping[NodeRunMetadataKey, Any] | None = {} if node_run_result: inputs = node_run_result.inputs process_data = node_run_result.process_data outputs = node_run_result.outputs execution_metadata = node_run_result.metadata - else: - inputs = {} - process_data = {} - outputs = {} - execution_metadata = {} self._publish_event( QueueNodeRetryEvent( node_execution_id=event.id, @@ -289,7 +289,7 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) process_data=event.route_node_state.node_run_result.process_data if event.route_node_state.node_run_result else {}, - outputs=event.route_node_state.node_run_result.outputs + outputs=event.route_node_state.node_run_result.outputs or {} if event.route_node_state.node_run_result else {}, error=event.route_node_state.node_run_result.error @@ -349,7 +349,7 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) process_data=event.route_node_state.node_run_result.process_data if event.route_node_state.node_run_result else {}, - outputs=event.route_node_state.node_run_result.outputs + outputs=event.route_node_state.node_run_result.outputs or {} if event.route_node_state.node_run_result else {}, execution_metadata=event.route_node_state.node_run_result.metadata diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 31c3a996e19286..16dc91bb777a9b 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator from constants import UUID_NIL -from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig +from core.app.app_config.entities import 
EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle from core.file import File, FileUploadConfig from core.model_runtime.entities.model_entities import AIModelEntity @@ -79,7 +79,7 @@ class AppGenerateEntity(BaseModel): task_id: str # app config - app_config: AppConfig + app_config: Any file_upload_config: Optional[FileUploadConfig] = None inputs: Mapping[str, Any] diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index d73c2eb53bfcd7..a93e533ff45d26 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -308,7 +308,7 @@ class QueueNodeSucceededEvent(AppQueueEvent): inputs: Optional[Mapping[str, Any]] = None process_data: Optional[Mapping[str, Any]] = None outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None error: Optional[str] = None """single iteration duration map""" diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index dd088a897816ef..5e845eba2da1d3 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -70,7 +70,7 @@ class StreamResponse(BaseModel): event: StreamEvent task_id: str - def to_dict(self) -> dict: + def to_dict(self): return jsonable_encoder(self) @@ -474,8 +474,8 @@ class Data(BaseModel): title: str created_at: int extras: dict = {} - metadata: dict = {} - inputs: dict = {} + metadata: Mapping = {} + inputs: Mapping = {} parallel_id: Optional[str] = None parallel_start_node_id: Optional[str] = None @@ -526,15 +526,15 @@ class Data(BaseModel): node_id: str node_type: str title: str - outputs: Optional[dict] = None + outputs: Optional[Mapping] = None created_at: int extras: Optional[dict] = None - inputs: Optional[dict] = None + inputs: Optional[Mapping] = None status: WorkflowNodeExecutionStatus error: Optional[str] = None elapsed_time: float total_tokens: int - execution_metadata: Optional[dict] = None + execution_metadata: Optional[Mapping] = None finished_at: int steps: int parallel_id: Optional[str] = None @@ -628,7 +628,7 @@ class AppBlockingResponse(BaseModel): task_id: str - def to_dict(self) -> dict: + def to_dict(self): return jsonable_encoder(self) diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 77b6bb554c65ec..83fd3debad4cf1 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -58,7 +58,7 @@ def query( query=query, top_k=1, score_threshold=score_threshold, filter={"group_id": [dataset.id]} ) - if documents: + if documents and documents[0].metadata: annotation_id = documents[0].metadata["annotation_id"] score = documents[0].metadata["score"] annotation = AppAnnotationService.get_annotation_by_id(annotation_id) diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index 8fe1d96b37be0c..dcc2b4e55f6ae1 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -17,7 +17,7 @@ class RateLimit: _UNLIMITED_REQUEST_ID = "unlimited_request_id" _REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes - 
_instance_dict = {} + _instance_dict: dict[str, "RateLimit"] = {} def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int): if client_id not in cls._instance_dict: diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 51d610e2cbedc6..03a81353d02625 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -62,6 +62,7 @@ def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = Non """ logger.debug("error: %s", event.error) e = event.error + err: Exception if isinstance(e, InvokeAuthorizationError): err = InvokeAuthorizationError("Incorrect API key provided") @@ -130,6 +131,7 @@ def _init_output_moderation(self) -> Optional[OutputModeration]: rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config), queue_manager=self._queue_manager, ) + return None def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]: """ diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index e26b60c4d3043e..b9f8e7ca560ce7 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -2,6 +2,7 @@ import logging import time from collections.abc import Generator +from threading import Thread from typing import Optional, Union, cast from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME @@ -103,7 +104,7 @@ def __init__( ) ) - self._conversation_name_generate_thread = None + self._conversation_name_generate_thread: Optional[Thread] = None def process( self, @@ -123,7 +124,7 @@ def process( if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: # start generate conversation name thread self._conversation_name_generate_thread = self._generate_conversation_name( - self._conversation, self._application_generate_entity.query + self._conversation, self._application_generate_entity.query or "" ) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) @@ -146,7 +147,7 @@ def _to_blocking_response( extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)} if self._task_state.metadata: extras["metadata"] = self._task_state.metadata - + response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] if self._conversation.mode == AppMode.COMPLETION.value: response = CompletionAppBlockingResponse( task_id=self._application_generate_entity.task_id, @@ -154,7 +155,7 @@ def _to_blocking_response( id=self._message.id, mode=self._conversation.mode, message_id=self._message.id, - answer=self._task_state.llm_result.message.content, + answer=cast(str, self._task_state.llm_result.message.content), created_at=int(self._message.created_at.timestamp()), **extras, ), @@ -167,7 +168,7 @@ def _to_blocking_response( mode=self._conversation.mode, conversation_id=self._conversation.id, message_id=self._message.id, - answer=self._task_state.llm_result.message.content, + answer=cast(str, self._task_state.llm_result.message.content), created_at=int(self._message.created_at.timestamp()), **extras, ), @@ -252,7 +253,7 @@ def _wrapper_process_stream_response( yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( 
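Note: the rate_limit.py hunk above annotates the class-level registry because an empty `{}` literal gives mypy nothing to infer from (it reports "Need type annotation"). A runnable sketch of the same singleton-per-key idiom:

    class RateLimit:
        # Without the annotation, mypy cannot infer key/value types
        # from the empty literal.
        _instance_dict: dict[str, "RateLimit"] = {}

        def __new__(cls, client_id: str) -> "RateLimit":
            if client_id not in cls._instance_dict:
                cls._instance_dict[client_id] = super().__new__(cls)
            return cls._instance_dict[client_id]

    assert RateLimit("app-1") is RateLimit("app-1")
    assert RateLimit("app-1") is not RateLimit("app-2")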
- self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None + self, publisher: Optional[AppGeneratorTTSPublisher], trace_manager: Optional[TraceQueueManager] = None ) -> Generator[StreamResponse, None, None]: """ Process stream response. @@ -269,13 +270,14 @@ def _process_stream_response( break elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): if isinstance(event, QueueMessageEndEvent): - self._task_state.llm_result = event.llm_result + if event.llm_result: + self._task_state.llm_result = event.llm_result else: self._handle_stop(event) # handle output moderation output_moderation_answer = self._handle_output_moderation_when_task_finished( - self._task_state.llm_result.message.content + cast(str, self._task_state.llm_result.message.content) ) if output_moderation_answer: self._task_state.llm_result.message.content = output_moderation_answer @@ -292,7 +294,9 @@ def _process_stream_response( if annotation: self._task_state.llm_result.message.content = annotation.content elif isinstance(event, QueueAgentThoughtEvent): - yield self._agent_thought_to_stream_response(event) + agent_thought_response = self._agent_thought_to_stream_response(event) + if agent_thought_response is not None: + yield agent_thought_response elif isinstance(event, QueueMessageFileEvent): response = self._message_file_to_stream_response(event) if response: @@ -307,16 +311,18 @@ def _process_stream_response( self._task_state.llm_result.prompt_messages = chunk.prompt_messages # handle output moderation chunk - should_direct_answer = self._handle_output_moderation_chunk(delta_text) + should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text)) if should_direct_answer: continue - self._task_state.llm_result.message.content += delta_text + current_content = cast(str, self._task_state.llm_result.message.content) + current_content += cast(str, delta_text) + self._task_state.llm_result.message.content = current_content if isinstance(event, QueueLLMChunkEvent): - yield self._message_to_stream_response(delta_text, self._message.id) + yield self._message_to_stream_response(cast(str, delta_text), self._message.id) else: - yield self._agent_message_to_stream_response(delta_text, self._message.id) + yield self._agent_message_to_stream_response(cast(str, delta_text), self._message.id) elif isinstance(event, QueueMessageReplaceEvent): yield self._message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueuePingEvent): @@ -336,8 +342,14 @@ def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> No llm_result = self._task_state.llm_result usage = llm_result.usage - self._message = db.session.query(Message).filter(Message.id == self._message.id).first() - self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() + message = db.session.query(Message).filter(Message.id == self._message.id).first() + if not message: + raise Exception(f"Message {self._message.id} not found") + self._message = message + conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() + if not conversation: + raise Exception(f"Conversation {self._conversation.id} not found") + self._conversation = conversation self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( self._model_config.mode, self._task_state.llm_result.prompt_messages @@ -346,7 +358,7 @@ def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> No 
self._message.message_unit_price = usage.prompt_unit_price self._message.message_price_unit = usage.prompt_price_unit self._message.answer = ( - PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) + PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip()) if llm_result.message.content else "" ) @@ -374,6 +386,7 @@ def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> No application_generate_entity=self._application_generate_entity, conversation=self._conversation, is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT} + and hasattr(self._application_generate_entity, "conversation_id") and self._application_generate_entity.conversation_id is None, extras=self._application_generate_entity.extras, ) @@ -420,7 +433,9 @@ def _message_end_to_stream_response(self) -> MessageEndStreamResponse: extras["metadata"] = self._task_state.metadata return MessageEndStreamResponse( - task_id=self._application_generate_entity.task_id, id=self._message.id, **extras + task_id=self._application_generate_entity.task_id, + id=self._message.id, + metadata=extras.get("metadata", {}), ) def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse: @@ -440,7 +455,7 @@ def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Op :param event: agent thought event :return: """ - agent_thought: MessageAgentThought = ( + agent_thought: Optional[MessageAgentThought] = ( db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first() ) db.session.refresh(agent_thought) diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index e818a090ed7d0f..007543f6d0d1f2 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -128,7 +128,7 @@ def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Opti """ message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first() - if message_file: + if message_file and message_file.url is not None: # get tool file id tool_file_id = message_file.url.split("/")[-1] # trim extension diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index df7dbace0ef818..f581e564f224ce 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -93,7 +93,7 @@ def _handle_workflow_run_start(self) -> WorkflowRun: ) # handle special values - inputs = WorkflowEntry.handle_special_values(inputs) + inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) # init workflow run with Session(db.engine, expire_on_commit=False) as session: @@ -192,7 +192,7 @@ def _handle_workflow_run_partial_success( """ workflow_run = self._refetch_workflow_run(workflow_run.id) - outputs = WorkflowEntry.handle_special_values(outputs) + outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value workflow_run.outputs = json.dumps(outputs or {}) @@ -500,7 +500,7 @@ def _workflow_start_to_stream_response( id=workflow_run.id, workflow_id=workflow_run.workflow_id, sequence_number=workflow_run.sequence_number, - inputs=workflow_run.inputs_dict, + inputs=dict(workflow_run.inputs_dict or {}), 
created_at=int(workflow_run.created_at.timestamp()), ), ) @@ -545,7 +545,7 @@ def _workflow_finish_to_stream_response( workflow_id=workflow_run.workflow_id, sequence_number=workflow_run.sequence_number, status=workflow_run.status, - outputs=workflow_run.outputs_dict, + outputs=dict(workflow_run.outputs_dict) if workflow_run.outputs_dict else None, error=workflow_run.error, elapsed_time=workflow_run.elapsed_time, total_tokens=workflow_run.total_tokens, @@ -553,7 +553,7 @@ def _workflow_finish_to_stream_response( created_by=created_by, created_at=int(workflow_run.created_at.timestamp()), finished_at=int(workflow_run.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict), + files=self._fetch_files_from_node_outputs(dict(workflow_run.outputs_dict)), exceptions_count=workflow_run.exceptions_count, ), ) @@ -655,7 +655,7 @@ def _workflow_node_retry_to_stream_response( event: QueueNodeRetryEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, - ) -> Optional[NodeFinishStreamResponse]: + ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: """ Workflow node finish to stream response. :param event: queue node succeeded or failed event @@ -838,7 +838,7 @@ def _workflow_iteration_completed_to_stream_response( ), ) - def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping[str, Any]]: + def _fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any]) -> Sequence[Mapping[str, Any]]: """ Fetch files from node outputs :param outputs_dict: node outputs dict @@ -851,9 +851,11 @@ def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping # Remove None files = [file for file in files if file] # Flatten list - files = [file for sublist in files for file in sublist] + # Flatten the list of sequences into a single list of mappings + flattened_files = [file for sublist in files if sublist for file in sublist] - return files + # Convert to tuple to match Sequence type + return tuple(flattened_files) def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]: """ @@ -891,6 +893,8 @@ def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any elif isinstance(value, File): return value.to_dict() + return None + def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun: """ Refetch workflow run diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index d826edf6a0fc19..effc7eff9179ae 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -57,7 +57,7 @@ def on_tool_end( self, tool_name: str, tool_inputs: Mapping[str, Any], - tool_outputs: Sequence[ToolInvokeMessage], + tool_outputs: Sequence[ToolInvokeMessage] | str, message_id: Optional[str] = None, timer: Optional[Any] = None, trace_manager: Optional[TraceQueueManager] = None, diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 1481578630f63b..8f8aaa93d6f986 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -40,17 +40,18 @@ def on_query(self, query: str, dataset_id: str) -> None: def on_tool_end(self, documents: list[Document]) -> None: """Handle tool end.""" for document in documents: - query = db.session.query(DocumentSegment).filter( - 
DocumentSegment.index_node_id == document.metadata["doc_id"] - ) + if document.metadata is not None: + query = db.session.query(DocumentSegment).filter( + DocumentSegment.index_node_id == document.metadata["doc_id"] + ) - if "dataset_id" in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) + if "dataset_id" in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) - # add hit count to document segment - query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) + # add hit count to document segment + query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) - db.session.commit() + db.session.commit() def return_retriever_resource_info(self, resource: list): """Handle return_retriever_resource_info.""" diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 9ed5528e43b9b8..5017835565789c 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from enum import Enum from typing import Optional @@ -72,7 +73,7 @@ class DefaultModelProviderEntity(BaseModel): label: I18nObject icon_small: Optional[I18nObject] = None icon_large: Optional[I18nObject] = None - supported_model_types: list[ModelType] + supported_model_types: Sequence[ModelType] = [] class DefaultModelEntity(BaseModel): diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index d1b34db2fe7172..2e27b362d3092c 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -40,7 +40,7 @@ logger = logging.getLogger(__name__) -original_provider_configurate_methods = {} +original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {} class ProviderConfiguration(BaseModel): @@ -99,7 +99,8 @@ def get_current_credentials(self, model_type: ModelType, model: str) -> Optional continue restrict_models = quota_configuration.restrict_models - + if self.system_configuration.credentials is None: + return None copy_credentials = self.system_configuration.credentials.copy() if restrict_models: for restrict_model in restrict_models: @@ -124,7 +125,7 @@ def get_current_credentials(self, model_type: ModelType, model: str) -> Optional return credentials - def get_system_configuration_status(self) -> SystemConfigurationStatus: + def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]: """ Get system configuration status. :return: @@ -136,6 +137,8 @@ def get_system_configuration_status(self) -> SystemConfigurationStatus: current_quota_configuration = next( (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None ) + if current_quota_configuration is None: + return None return ( SystemConfigurationStatus.ACTIVE @@ -150,7 +153,7 @@ def is_custom_configuration_available(self) -> bool: """ return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0 - def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]: + def get_custom_credentials(self, obfuscated: bool = False): """ Get custom credentials. 
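A note on the recurring pattern in the hunks above: mypy only accepts attribute access on an Optional value once the None case has been ruled out, which is why the patch keeps adding early returns and `is not None` guards before touching `credentials` or `document.metadata`. A minimal illustrative sketch of that narrowing, outside the patch itself; the `Credentials` class and `obfuscated_key` helper below are hypothetical stand-ins, not Dify code:

from typing import Optional


class Credentials:  # hypothetical stand-in for a provider credentials object
    def __init__(self, api_key: str) -> None:
        self.api_key = api_key


def obfuscated_key(credentials: Optional[Credentials]) -> Optional[str]:
    # Narrow Optional[Credentials] to Credentials before attribute access,
    # mirroring the "if ... is None: return None" guards added in the hunks above.
    if credentials is None:
        return None
    return credentials.api_key[:4] + "****"
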
@@ -172,7 +175,7 @@ def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]: else [], ) - def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]: + def custom_credentials_validate(self, credentials: dict) -> tuple[Optional[Provider], dict]: """ Validate custom credentials. :param credentials: provider credentials @@ -324,7 +327,7 @@ def get_custom_model_credentials( def custom_model_credentials_validate( self, model_type: ModelType, model: str, credentials: dict - ) -> tuple[ProviderModel, dict]: + ) -> tuple[Optional[ProviderModel], dict]: """ Validate custom model credentials. @@ -740,10 +743,10 @@ def get_provider_models( if model_type: model_types.append(model_type) else: - model_types = provider_instance.get_provider_schema().supported_model_types + model_types = list(provider_instance.get_provider_schema().supported_model_types) # Group model settings by model type and model - model_setting_map = defaultdict(dict) + model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict) for model_setting in self.model_settings: model_setting_map[model_setting.model_type][model_setting.model] = model_setting @@ -822,54 +825,57 @@ def _get_system_provider_models( ]: # only customizable model for restrict_model in restrict_models: - copy_credentials = self.system_configuration.credentials.copy() - if restrict_model.base_model_name: - copy_credentials["base_model_name"] = restrict_model.base_model_name - - try: - custom_model_schema = provider_instance.get_model_instance( - restrict_model.model_type - ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials) - except Exception as ex: - logger.warning(f"get custom model schema failed, {ex}") - continue - - if not custom_model_schema: - continue - - if custom_model_schema.model_type not in model_types: - continue - - status = ModelStatus.ACTIVE - if ( - custom_model_schema.model_type in model_setting_map - and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] - ): - model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] - if model_setting.enabled is False: - status = ModelStatus.DISABLED - - provider_models.append( - ModelWithProviderEntity( - model=custom_model_schema.model, - label=custom_model_schema.label, - model_type=custom_model_schema.model_type, - features=custom_model_schema.features, - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties=custom_model_schema.model_properties, - deprecated=custom_model_schema.deprecated, - provider=SimpleModelProviderEntity(self.provider), - status=status, + if self.system_configuration.credentials is not None: + copy_credentials = self.system_configuration.credentials.copy() + if restrict_model.base_model_name: + copy_credentials["base_model_name"] = restrict_model.base_model_name + + try: + custom_model_schema = provider_instance.get_model_instance( + restrict_model.model_type + ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials) + except Exception as ex: + logger.warning(f"get custom model schema failed, {ex}") + continue + + if not custom_model_schema: + continue + + if custom_model_schema.model_type not in model_types: + continue + + status = ModelStatus.ACTIVE + if ( + custom_model_schema.model_type in model_setting_map + and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] + ): + model_setting = model_setting_map[custom_model_schema.model_type][ + 
custom_model_schema.model
+                        ]
+                        if model_setting.enabled is False:
+                            status = ModelStatus.DISABLED
+
+                    provider_models.append(
+                        ModelWithProviderEntity(
+                            model=custom_model_schema.model,
+                            label=custom_model_schema.label,
+                            model_type=custom_model_schema.model_type,
+                            features=custom_model_schema.features,
+                            fetch_from=FetchFrom.PREDEFINED_MODEL,
+                            model_properties=custom_model_schema.model_properties,
+                            deprecated=custom_model_schema.deprecated,
+                            provider=SimpleModelProviderEntity(self.provider),
+                            status=status,
+                        )
                     )
-                )
 
             # if llm name not in restricted llm list, remove it
             restrict_model_names = [rm.model for rm in restrict_models]
-            for m in provider_models:
-                if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
-                    m.status = ModelStatus.NO_PERMISSION
+            for model in provider_models:
+                if model.model_type == ModelType.LLM and model.model not in restrict_model_names:
+                    model.status = ModelStatus.NO_PERMISSION
                 elif not quota_configuration.is_valid:
-                    m.status = ModelStatus.QUOTA_EXCEEDED
+                    model.status = ModelStatus.QUOTA_EXCEEDED
 
         return provider_models
 
@@ -1043,7 +1049,7 @@ def __iter__(self):
         return iter(self.configurations)
 
     def values(self) -> Iterator[ProviderConfiguration]:
-        return self.configurations.values()
+        return iter(self.configurations.values())
 
     def get(self, key, default=None):
         return self.configurations.get(key, default)
diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py
index 38cebb6b6b1c36..3f4e20ec245302 100644
--- a/api/core/extension/api_based_extension_requestor.py
+++ b/api/core/extension/api_based_extension_requestor.py
@@ -1,3 +1,5 @@
+from typing import cast
+
 import requests
 
 from configs import dify_config
@@ -5,7 +7,7 @@
 
 
 class APIBasedExtensionRequestor:
-    timeout: (int, int) = (5, 60)
+    timeout: tuple[int, int] = (5, 60)
     """timeout for request connect and read"""
 
     def __init__(self, api_endpoint: str, api_key: str) -> None:
@@ -51,4 +53,4 @@ def request(self, point: APIBasedExtensionPoint, params: dict) -> dict:
                 "request error, status_code: {}, content: {}".format(response.status_code, response.text[:100])
             )
 
-        return response.json()
+        return cast(dict, response.json())
diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py
index 97dbaf2026e790..231743bf2a948c 100644
--- a/api/core/extension/extensible.py
+++ b/api/core/extension/extensible.py
@@ -38,8 +38,8 @@ def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None:
 
     @classmethod
     def scan_extensions(cls):
-        extensions: list[ModuleExtension] = []
-        position_map = {}
+        extensions = []
+        position_map: dict[str, int] = {}
 
         # get the path of the current class
         current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py")
@@ -58,7 +58,8 @@ def scan_extensions(cls):
 
                 # is builtin extension, builtin extension
                 # in the front-end page and business logic, there are special treatments.
builtin = False - position = None + # default position is 0 can not be None for sort_to_dict_by_position_map + position = 0 if "__builtin__" in file_names: builtin = True @@ -89,7 +90,7 @@ def scan_extensions(cls): logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.") continue - json_data = {} + json_data: dict[str, Any] = {} if not builtin: if "schema.json" not in file_names: logging.warning(f"Missing schema.json file in {subdir_path}, Skip.") diff --git a/api/core/extension/extension.py b/api/core/extension/extension.py index 3da170455e3398..9eb9e0306b577f 100644 --- a/api/core/extension/extension.py +++ b/api/core/extension/extension.py @@ -1,4 +1,6 @@ -from core.extension.extensible import ExtensionModule, ModuleExtension +from typing import cast + +from core.extension.extensible import Extensible, ExtensionModule, ModuleExtension from core.external_data_tool.base import ExternalDataTool from core.moderation.base import Moderation @@ -10,7 +12,8 @@ class Extension: def init(self): for module, module_class in self.module_classes.items(): - self.__module_extensions[module.value] = module_class.scan_extensions() + m = cast(Extensible, module_class) + self.__module_extensions[module.value] = m.scan_extensions() def module_extensions(self, module: str) -> list[ModuleExtension]: module_extensions = self.__module_extensions.get(module) @@ -35,7 +38,8 @@ def module_extension(self, module: ExtensionModule, extension_name: str) -> Modu def extension_class(self, module: ExtensionModule, extension_name: str) -> type: module_extension = self.module_extension(module, extension_name) - return module_extension.extension_class + t: type = module_extension.extension_class + return t def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None: module_extension = self.module_extension(module, extension_name) diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index 54ec97a4933a94..9989c8a09013bd 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -48,7 +48,10 @@ def query(self, inputs: dict, query: Optional[str] = None) -> str: :return: the tool query result """ # get params from config + if not self.config: + raise ValueError("config is required, config: {}".format(self.config)) api_based_extension_id = self.config.get("api_based_extension_id") + assert api_based_extension_id is not None, "api_based_extension_id is required" # get api_based_extension api_based_extension = ( diff --git a/api/core/external_data_tool/external_data_fetch.py b/api/core/external_data_tool/external_data_fetch.py index 84b94e117ff5f9..6a9703a569b308 100644 --- a/api/core/external_data_tool/external_data_fetch.py +++ b/api/core/external_data_tool/external_data_fetch.py @@ -1,7 +1,7 @@ -import concurrent import logging -from concurrent.futures import ThreadPoolExecutor -from typing import Optional +from collections.abc import Mapping +from concurrent.futures import Future, ThreadPoolExecutor, as_completed +from typing import Any, Optional from flask import Flask, current_app @@ -17,9 +17,9 @@ def fetch( tenant_id: str, app_id: str, external_data_tools: list[ExternalDataVariableEntity], - inputs: dict, + inputs: Mapping[str, Any], query: str, - ) -> dict: + ) -> Mapping[str, Any]: """ Fill in variable inputs from external data tools if exists. 
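The `api.py` hunk above combines the two narrowing tools the patch uses elsewhere: raising early when `self.config` is falsy, then asserting that the looked-up id is not None. A small self-contained sketch of that guard-then-assert style, again outside the patch; `resolve_extension_id` is a hypothetical helper, not part of the codebase:

from typing import Any, Optional


def resolve_extension_id(config: Optional[dict[str, Any]]) -> str:
    # Raising on a missing config narrows Optional[dict] to dict for mypy,
    # and the assert narrows the Optional result of .get() to a non-None value.
    if not config:
        raise ValueError(f"config is required, config: {config}")
    extension_id = config.get("api_based_extension_id")
    assert extension_id is not None, "api_based_extension_id is required"
    return str(extension_id)
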
@@ -30,13 +30,14 @@ def fetch( :param query: the query :return: the filled inputs """ - results = {} + results: dict[str, Any] = {} + inputs = dict(inputs) with ThreadPoolExecutor() as executor: futures = {} for tool in external_data_tools: - future = executor.submit( + future: Future[tuple[str | None, str | None]] = executor.submit( self._query_external_data_tool, - current_app._get_current_object(), + current_app._get_current_object(), # type: ignore tenant_id, app_id, tool, @@ -46,9 +47,10 @@ def fetch( futures[future] = tool - for future in concurrent.futures.as_completed(futures): + for future in as_completed(futures): tool_variable, result = future.result() - results[tool_variable] = result + if tool_variable is not None: + results[tool_variable] = result inputs.update(results) return inputs @@ -59,7 +61,7 @@ def _query_external_data_tool( tenant_id: str, app_id: str, external_data_tool: ExternalDataVariableEntity, - inputs: dict, + inputs: Mapping[str, Any], query: str, ) -> tuple[Optional[str], Optional[str]]: """ diff --git a/api/core/external_data_tool/factory.py b/api/core/external_data_tool/factory.py index 28721098594962..245507e17c7032 100644 --- a/api/core/external_data_tool/factory.py +++ b/api/core/external_data_tool/factory.py @@ -1,4 +1,5 @@ -from typing import Optional +from collections.abc import Mapping +from typing import Any, Optional, cast from core.extension.extensible import ExtensionModule from extensions.ext_code_based_extension import code_based_extension @@ -23,9 +24,10 @@ def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: """ code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config) extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) - extension_class.validate_config(tenant_id, config) + # FIXME mypy issue here, figure out how to fix it + extension_class.validate_config(tenant_id, config) # type: ignore - def query(self, inputs: dict, query: Optional[str] = None) -> str: + def query(self, inputs: Mapping[str, Any], query: Optional[str] = None) -> str: """ Query the external data tool. 
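For the `external_data_fetch.py` changes above, the explicit `Future[tuple[str | None, str | None]]` annotation is what lets mypy check the later `future.result()` unpacking. A runnable sketch of the same typed-executor pattern under simplified assumptions; `fetch_all` and its uppercasing worker are stand-ins, not the real `_query_external_data_tool`:

from concurrent.futures import Future, ThreadPoolExecutor, as_completed


def fetch_all(keys: list[str]) -> dict[str, str]:
    # Annotating the futures dict gives as_completed() and result() precise
    # types, as in ExternalDataFetch.fetch above.
    results: dict[str, str] = {}
    with ThreadPoolExecutor() as executor:
        futures: dict[Future[tuple[str, str]], str] = {}
        for key in keys:
            # The lambda default binds each key; the worker here just uppercases it.
            future: Future[tuple[str, str]] = executor.submit(lambda k=key: (k, k.upper()))
            futures[future] = key
        for future in as_completed(futures):
            name, value = future.result()
            results[name] = value
    return results
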
@@ -33,4 +35,4 @@ def query(self, inputs: dict, query: Optional[str] = None) -> str: :param query: the query of chat app :return: the tool query result """ - return self.__extension_instance.query(inputs, query) + return cast(str, self.__extension_instance.query(inputs, query)) diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index 15eb351a7ef309..4a50fb85c9cca3 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -1,4 +1,5 @@ import base64 +from collections.abc import Mapping from configs import dify_config from core.helper import ssrf_proxy @@ -55,7 +56,7 @@ def to_prompt_message_content( if f.type == FileType.IMAGE: params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW - prompt_class_map = { + prompt_class_map: Mapping[FileType, type[MultiModalPromptMessageContent]] = { FileType.IMAGE: ImagePromptMessageContent, FileType.AUDIO: AudioPromptMessageContent, FileType.VIDEO: VideoPromptMessageContent, @@ -63,7 +64,7 @@ def to_prompt_message_content( } try: - return prompt_class_map[f.type](**params) + return prompt_class_map[f.type].model_validate(params) except KeyError: raise ValueError(f"file type {f.type} is not supported") diff --git a/api/core/file/tool_file_parser.py b/api/core/file/tool_file_parser.py index a17b7be3675ab1..6fa101cf36192b 100644 --- a/api/core/file/tool_file_parser.py +++ b/api/core/file/tool_file_parser.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast if TYPE_CHECKING: from core.tools.tool_file_manager import ToolFileManager @@ -9,4 +9,4 @@ class ToolFileParser: @staticmethod def get_tool_file_manager() -> "ToolFileManager": - return tool_file_manager["manager"] + return cast("ToolFileManager", tool_file_manager["manager"]) diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 584e3e9698a88d..15b501780e766c 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -38,7 +38,7 @@ class CodeLanguage(StrEnum): class CodeExecutor: - dependencies_cache = {} + dependencies_cache: dict[str, str] = {} dependencies_cache_lock = Lock() code_template_transformers: dict[CodeLanguage, type[TemplateTransformer]] = { @@ -103,19 +103,19 @@ def execute_code(cls, language: CodeLanguage, preload: str, code: str) -> str: ) try: - response = response.json() + response_data = response.json() except: raise CodeExecutionError("Failed to parse response") - if (code := response.get("code")) != 0: - raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response.get('message')}") + if (code := response_data.get("code")) != 0: + raise CodeExecutionError(f"Got error code: {code}. 
Got error msg: {response_data.get('message')}") - response = CodeExecutionResponse(**response) + response_code = CodeExecutionResponse(**response_data) - if response.data.error: - raise CodeExecutionError(response.data.error) + if response_code.data.error: + raise CodeExecutionError(response_code.data.error) - return response.data.stdout or "" + return response_code.data.stdout or "" @classmethod def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]): diff --git a/api/core/helper/code_executor/jinja2/jinja2_formatter.py b/api/core/helper/code_executor/jinja2/jinja2_formatter.py index db2eb5ebb6b19a..264947b5686d0e 100644 --- a/api/core/helper/code_executor/jinja2/jinja2_formatter.py +++ b/api/core/helper/code_executor/jinja2/jinja2_formatter.py @@ -1,9 +1,11 @@ +from collections.abc import Mapping + from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage class Jinja2Formatter: @classmethod - def format(cls, template: str, inputs: dict) -> str: + def format(cls, template: str, inputs: Mapping[str, str]) -> str: """ Format template :param template: template @@ -11,5 +13,4 @@ def format(cls, template: str, inputs: dict) -> str: :return: """ result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs) - - return result["result"] + return str(result.get("result", "")) diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index 605719747a7b56..baa792b5bc6c41 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -29,8 +29,7 @@ def extract_result_str_from_response(cls, response: str): result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL) if not result: raise ValueError("Failed to parse result") - result = result.group(1) - return result + return result.group(1) @classmethod def transform_response(cls, response: str) -> Mapping[str, Any]: diff --git a/api/core/helper/lru_cache.py b/api/core/helper/lru_cache.py index 518962c1652df7..81501d2e4e23b2 100644 --- a/api/core/helper/lru_cache.py +++ b/api/core/helper/lru_cache.py @@ -4,7 +4,7 @@ class LRUCache: def __init__(self, capacity: int): - self.cache = OrderedDict() + self.cache: OrderedDict[Any, Any] = OrderedDict() self.capacity = capacity def get(self, key: Any) -> Any: diff --git a/api/core/helper/model_provider_cache.py b/api/core/helper/model_provider_cache.py index 5e274f8916869d..35349210bd53ab 100644 --- a/api/core/helper/model_provider_cache.py +++ b/api/core/helper/model_provider_cache.py @@ -30,7 +30,7 @@ def get(self) -> Optional[dict]: except JSONDecodeError: return None - return cached_provider_credentials + return dict(cached_provider_credentials) else: return None diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index da0fd0031cc6dc..543444463b9f1a 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -22,6 +22,7 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) provider_name = model_config.provider if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers: hosting_openai_config = hosting_configuration.provider_map["openai"] + assert hosting_openai_config is not None # 2000 text per chunk length = 2000 @@ -34,8 +35,9 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) try: 
model_type_instance = OpenAIModerationModel() + # FIXME, for type hint using assert or raise ValueError is better here? moderation_result = model_type_instance.invoke( - model="text-moderation-stable", credentials=hosting_openai_config.credentials, text=text_chunk + model="text-moderation-stable", credentials=hosting_openai_config.credentials or {}, text=text_chunk ) if moderation_result is True: diff --git a/api/core/helper/module_import_helper.py b/api/core/helper/module_import_helper.py index 1e2fefce88b632..9a041667e46df5 100644 --- a/api/core/helper/module_import_helper.py +++ b/api/core/helper/module_import_helper.py @@ -14,12 +14,13 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz if existed_spec: spec = existed_spec if not spec.loader: - raise Exception(f"Failed to load module {module_name} from {py_file_path}") + raise Exception(f"Failed to load module {module_name} from {py_file_path!r}") else: # Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly - spec = importlib.util.spec_from_file_location(module_name, py_file_path) + # FIXME: mypy does not support the type of spec.loader + spec = importlib.util.spec_from_file_location(module_name, py_file_path) # type: ignore if not spec or not spec.loader: - raise Exception(f"Failed to load module {module_name} from {py_file_path}") + raise Exception(f"Failed to load module {module_name} from {py_file_path!r}") if use_lazy_loader: # Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports spec.loader = importlib.util.LazyLoader(spec.loader) @@ -29,7 +30,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz spec.loader.exec_module(module) return module except Exception as e: - logging.exception(f"Failed to load module {module_name} from script file '{py_file_path}'") + logging.exception(f"Failed to load module {module_name} from script file '{py_file_path!r}'") raise e @@ -57,6 +58,6 @@ def load_single_subclass_from_source( case 1: return subclasses[0] case 0: - raise Exception(f"Missing subclass of {parent_type.__name__} in {script_path}") + raise Exception(f"Missing subclass of {parent_type.__name__} in {script_path!r}") case _: - raise Exception(f"Multiple subclasses of {parent_type.__name__} in {script_path}") + raise Exception(f"Multiple subclasses of {parent_type.__name__} in {script_path!r}") diff --git a/api/core/helper/tool_parameter_cache.py b/api/core/helper/tool_parameter_cache.py index e848b46c5633ab..3b67b3f84838d3 100644 --- a/api/core/helper/tool_parameter_cache.py +++ b/api/core/helper/tool_parameter_cache.py @@ -33,7 +33,7 @@ def get(self) -> Optional[dict]: except JSONDecodeError: return None - return cached_tool_parameter + return dict(cached_tool_parameter) else: return None diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py index 94b02cf98578b1..6de5e704abf4f5 100644 --- a/api/core/helper/tool_provider_cache.py +++ b/api/core/helper/tool_provider_cache.py @@ -28,7 +28,7 @@ def get(self) -> Optional[dict]: except JSONDecodeError: return None - return cached_provider_credentials + return dict(cached_provider_credentials) else: return None diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index b47ba67f2fa64f..f9fb7275f3624f 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -42,7 +42,7 @@ class HostedModerationConfig(BaseModel): class HostingConfiguration: provider_map: 
dict[str, HostingProvider] = {} - moderation_config: HostedModerationConfig = None + moderation_config: Optional[HostedModerationConfig] = None def init_app(self, app: Flask) -> None: if dify_config.EDITION != "CLOUD": @@ -67,7 +67,7 @@ def init_azure_openai() -> HostingProvider: "base_model_name": "gpt-35-turbo", } - quotas = [] + quotas: list[HostingQuota] = [] hosted_quota_limit = dify_config.HOSTED_AZURE_OPENAI_QUOTA_LIMIT trial_quota = TrialHostingQuota( quota_limit=hosted_quota_limit, @@ -123,7 +123,7 @@ def init_azure_openai() -> HostingProvider: def init_openai(self) -> HostingProvider: quota_unit = QuotaUnit.CREDITS - quotas = [] + quotas: list[HostingQuota] = [] if dify_config.HOSTED_OPENAI_TRIAL_ENABLED: hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT @@ -157,7 +157,7 @@ def init_openai(self) -> HostingProvider: @staticmethod def init_anthropic() -> HostingProvider: quota_unit = QuotaUnit.TOKENS - quotas = [] + quotas: list[HostingQuota] = [] if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED: hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT @@ -187,7 +187,7 @@ def init_anthropic() -> HostingProvider: def init_minimax() -> HostingProvider: quota_unit = QuotaUnit.TOKENS if dify_config.HOSTED_MINIMAX_ENABLED: - quotas = [FreeHostingQuota()] + quotas: list[HostingQuota] = [FreeHostingQuota()] return HostingProvider( enabled=True, @@ -205,7 +205,7 @@ def init_minimax() -> HostingProvider: def init_spark() -> HostingProvider: quota_unit = QuotaUnit.TOKENS if dify_config.HOSTED_SPARK_ENABLED: - quotas = [FreeHostingQuota()] + quotas: list[HostingQuota] = [FreeHostingQuota()] return HostingProvider( enabled=True, @@ -223,7 +223,7 @@ def init_spark() -> HostingProvider: def init_zhipuai() -> HostingProvider: quota_unit = QuotaUnit.TOKENS if dify_config.HOSTED_ZHIPUAI_ENABLED: - quotas = [FreeHostingQuota()] + quotas: list[HostingQuota] = [FreeHostingQuota()] return HostingProvider( enabled=True, diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 29e161cb747284..1f0a0d0ef1dda4 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -6,10 +6,10 @@ import threading import time import uuid -from typing import Optional, cast +from typing import Any, Optional, cast from flask import Flask, current_app -from flask_login import current_user +from flask_login import current_user # type: ignore from sqlalchemy.orm.exc import ObjectDeletedError from configs import dify_config @@ -62,6 +62,8 @@ def run(self, dataset_documents: list[DatasetDocument]): .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .first() ) + if not processing_rule: + raise ValueError("no process rule found") index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() # extract @@ -120,6 +122,8 @@ def run_in_splitting_status(self, dataset_document: DatasetDocument): .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .first() ) + if not processing_rule: + raise ValueError("no process rule found") index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -254,7 +258,7 @@ def indexing_estimate( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) - preview_texts = [] + preview_texts: list[str] = [] total_segments = 0 index_type = doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -285,7 +289,8 @@ def indexing_estimate( for upload_file_id in image_upload_file_ids: 
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() try: - storage.delete(image_file.key) + if image_file: + storage.delete(image_file.key) except Exception: logging.exception( "Delete image_files failed while indexing_estimate, \ @@ -379,8 +384,9 @@ def _extract( # replace doc id to document model id text_docs = cast(list[Document], text_docs) for text_doc in text_docs: - text_doc.metadata["document_id"] = dataset_document.id - text_doc.metadata["dataset_id"] = dataset_document.dataset_id + if text_doc.metadata is not None: + text_doc.metadata["document_id"] = dataset_document.id + text_doc.metadata["dataset_id"] = dataset_document.dataset_id return text_docs @@ -400,6 +406,7 @@ def _get_splitter( """ Get the NodeParser object according to the processing rule. """ + character_splitter: TextSplitter if processing_rule.mode == "custom": # The user-defined segmentation rule rules = json.loads(processing_rule.rules) @@ -426,9 +433,10 @@ def _get_splitter( ) else: # Automatic segmentation + automatic_rules: dict[str, Any] = dict(DatasetProcessRule.AUTOMATIC_RULES["segmentation"]) character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( - chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"], - chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"], + chunk_size=automatic_rules["max_tokens"], + chunk_overlap=automatic_rules["chunk_overlap"], separators=["\n\n", "。", ". ", " ", ""], embedding_model_instance=embedding_model_instance, ) @@ -497,8 +505,8 @@ def _split_to_documents( """ Split the text documents into nodes. """ - all_documents = [] - all_qa_documents = [] + all_documents: list[Document] = [] + all_qa_documents: list[Document] = [] for text_doc in text_docs: # document clean document_text = self._document_clean(text_doc.page_content, processing_rule) @@ -509,10 +517,11 @@ def _split_to_documents( split_documents = [] for document_node in documents: if document_node.page_content.strip(): - doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata["doc_id"] = doc_id - document_node.metadata["doc_hash"] = hash + if document_node.metadata is not None: + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(document_node.page_content) + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash # delete Splitter character page_content = document_node.page_content document_node.page_content = remove_leading_symbols(page_content) @@ -529,7 +538,7 @@ def _split_to_documents( document_format_thread = threading.Thread( target=self.format_qa_document, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "tenant_id": tenant_id, "document_node": doc, "all_qa_documents": all_qa_documents, @@ -557,11 +566,12 @@ def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, al qa_document = Document( page_content=result["question"], metadata=document_node.metadata.model_copy() ) - doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(result["question"]) - qa_document.metadata["answer"] = result["answer"] - qa_document.metadata["doc_id"] = doc_id - qa_document.metadata["doc_hash"] = hash + if qa_document.metadata is not None: + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(result["question"]) + qa_document.metadata["answer"] = result["answer"] + qa_document.metadata["doc_id"] = doc_id + 
qa_document.metadata["doc_hash"] = hash qa_documents.append(qa_document) format_documents.extend(qa_documents) except Exception as e: @@ -575,7 +585,7 @@ def _split_to_documents_for_estimate( """ Split the text documents into nodes. """ - all_documents = [] + all_documents: list[Document] = [] for text_doc in text_docs: # document clean document_text = self._document_clean(text_doc.page_content, processing_rule) @@ -588,11 +598,11 @@ def _split_to_documents_for_estimate( for document in documents: if document.page_content is None or not document.page_content.strip(): continue - doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(document.page_content) - - document.metadata["doc_id"] = doc_id - document.metadata["doc_hash"] = hash + if document.metadata is not None: + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(document.page_content) + document.metadata["doc_id"] = doc_id + document.metadata["doc_hash"] = hash split_documents.append(document) @@ -648,7 +658,7 @@ def _load( # create keyword index create_keyword_thread = threading.Thread( target=self._process_keyword_index, - args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), + args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore ) create_keyword_thread.start() if dataset.indexing_technique == "high_quality": @@ -659,7 +669,7 @@ def _load( futures.append( executor.submit( self._process_chunk, - current_app._get_current_object(), + current_app._get_current_object(), # type: ignore index_processor, chunk_documents, dataset, diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 3a92c8d9d22562..9fe3f68f2a8af5 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -1,7 +1,7 @@ import json import logging import re -from typing import Optional +from typing import Optional, cast from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser @@ -13,6 +13,7 @@ WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, ) from core.model_manager import ModelManager +from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError @@ -44,10 +45,13 @@ def generate_conversation_name( prompts = [UserPromptMessage(content=prompt)] with measure_time() as timer: - response = model_instance.invoke_llm( - prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False + ), ) - answer = response.message.content + answer = cast(str, response.message.content) cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL) if cleaned_answer is None: return "" @@ -94,11 +98,16 @@ def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: st prompt_messages = [UserPromptMessage(content=prompt)] try: - response = model_instance.invoke_llm( - prompt_messages=prompt_messages, model_parameters={"max_tokens": 256, "temperature": 0}, stream=False + response = cast( + LLMResult, + 
model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters={"max_tokens": 256, "temperature": 0}, + stream=False, + ), ) - questions = output_parser.parse(response.message.content) + questions = output_parser.parse(cast(str, response.message.content)) except InvokeError: questions = [] except Exception as e: @@ -138,11 +147,14 @@ def generate_rule_config( ) try: - response = model_instance.invoke_llm( - prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + ), ) - rule_config["prompt"] = response.message.content + rule_config["prompt"] = cast(str, response.message.content) except InvokeError as e: error = str(e) @@ -178,15 +190,18 @@ def generate_rule_config( model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider") if model_config else None, - model=model_config.get("name") if model_config else None, + provider=model_config.get("provider", ""), + model=model_config.get("name", ""), ) try: try: # the first step to generate the task prompt - prompt_content = model_instance.invoke_llm( - prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + prompt_content = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + ), ) except InvokeError as e: error = str(e) @@ -195,8 +210,10 @@ def generate_rule_config( return rule_config - rule_config["prompt"] = prompt_content.message.content + rule_config["prompt"] = cast(str, prompt_content.message.content) + if not isinstance(prompt_content.message.content, str): + raise NotImplementedError("prompt content is not a string") parameter_generate_prompt = parameter_template.format( inputs={ "INPUT_TEXT": prompt_content.message.content, @@ -216,19 +233,25 @@ def generate_rule_config( statement_messages = [UserPromptMessage(content=statement_generate_prompt)] try: - parameter_content = model_instance.invoke_llm( - prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False + parameter_content = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False + ), ) - rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.content) + rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content)) except InvokeError as e: error = str(e) error_step = "generate variables" try: - statement_content = model_instance.invoke_llm( - prompt_messages=statement_messages, model_parameters=model_parameters, stream=False + statement_content = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=statement_messages, model_parameters=model_parameters, stream=False + ), ) - rule_config["opening_statement"] = statement_content.message.content + rule_config["opening_statement"] = cast(str, statement_content.message.content) except InvokeError as e: error = str(e) error_step = "generate conversation opener" @@ -267,19 +290,22 @@ def generate_code( model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider") if model_config else None, - model=model_config.get("name") if model_config else None, + provider=model_config.get("provider", ""), + 
model=model_config.get("name", ""), ) prompt_messages = [UserPromptMessage(content=prompt)] model_parameters = {"max_tokens": max_tokens, "temperature": 0.01} try: - response = model_instance.invoke_llm( - prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + ), ) - generated_code = response.message.content + generated_code = cast(str, response.message.content) return {"code": generated_code, "language": code_language, "error": ""} except InvokeError as e: @@ -303,9 +329,14 @@ def generate_qa_document(cls, tenant_id: str, query, document_language: str): prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)] - response = model_instance.invoke_llm( - prompt_messages=prompt_messages, model_parameters={"temperature": 0.01, "max_tokens": 2000}, stream=False + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters={"temperature": 0.01, "max_tokens": 2000}, + stream=False, + ), ) - answer = response.message.content + answer = cast(str, response.message.content) return answer.strip() diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 81d08dc8854f80..003a0c85b1f12e 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -68,7 +68,7 @@ def get_history_prompt_messages( messages = list(reversed(thread_messages)) - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] for message in messages: files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() if files: diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 1986688551b601..d1e71148cd6023 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -124,17 +124,20 @@ def invoke_llm( raise Exception("Model type instance is not LargeLanguageModel") self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model, - credentials=self.credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - callbacks=callbacks, + return cast( + Union[LLMResult, Generator], + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + callbacks=callbacks, + ), ) def get_llm_num_tokens( @@ -151,12 +154,15 @@ def get_llm_num_tokens( raise Exception("Model type instance is not LargeLanguageModel") self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.get_num_tokens, - model=self.model, - credentials=self.credentials, - prompt_messages=prompt_messages, - tools=tools, + return cast( + int, + self._round_robin_invoke( + function=self.model_type_instance.get_num_tokens, + model=self.model, + credentials=self.credentials, + prompt_messages=prompt_messages, + tools=tools, + ), ) def invoke_text_embedding( @@ -174,13 +180,16 @@ def invoke_text_embedding( raise Exception("Model type instance is not TextEmbeddingModel") self.model_type_instance = cast(TextEmbeddingModel, 
self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model, - credentials=self.credentials, - texts=texts, - user=user, - input_type=input_type, + return cast( + TextEmbeddingResult, + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + texts=texts, + user=user, + input_type=input_type, + ), ) def get_text_embedding_num_tokens(self, texts: list[str]) -> int: @@ -194,11 +203,14 @@ def get_text_embedding_num_tokens(self, texts: list[str]) -> int: raise Exception("Model type instance is not TextEmbeddingModel") self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.get_num_tokens, - model=self.model, - credentials=self.credentials, - texts=texts, + return cast( + int, + self._round_robin_invoke( + function=self.model_type_instance.get_num_tokens, + model=self.model, + credentials=self.credentials, + texts=texts, + ), ) def invoke_rerank( @@ -223,15 +235,18 @@ def invoke_rerank( raise Exception("Model type instance is not RerankModel") self.model_type_instance = cast(RerankModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model, - credentials=self.credentials, - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - user=user, + return cast( + RerankResult, + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, + user=user, + ), ) def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool: @@ -246,12 +261,15 @@ def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool: raise Exception("Model type instance is not ModerationModel") self.model_type_instance = cast(ModerationModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model, - credentials=self.credentials, - text=text, - user=user, + return cast( + bool, + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + text=text, + user=user, + ), ) def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str: @@ -266,12 +284,15 @@ def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str raise Exception("Model type instance is not Speech2TextModel") self.model_type_instance = cast(Speech2TextModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model, - credentials=self.credentials, - file=file, - user=user, + return cast( + str, + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + file=file, + user=user, + ), ) def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]: @@ -288,17 +309,20 @@ def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Option raise Exception("Model type instance is not TTSModel") self.model_type_instance = cast(TTSModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model, - credentials=self.credentials, - content_text=content_text, - user=user, - 
tenant_id=tenant_id, - voice=voice, + return cast( + Iterable[bytes], + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + content_text=content_text, + user=user, + tenant_id=tenant_id, + voice=voice, + ), ) - def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs): + def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs) -> Any: """ Round-robin invoke :param function: function to invoke @@ -309,7 +333,7 @@ def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs): if not self.load_balancing_manager: return function(*args, **kwargs) - last_exception = None + last_exception: Union[InvokeRateLimitError, InvokeAuthorizationError, InvokeConnectionError, None] = None while True: lb_config = self.load_balancing_manager.fetch_next() if not lb_config: @@ -463,7 +487,7 @@ def fetch_next(self) -> Optional[ModelLoadBalancingConfiguration]: if real_index > max_index: real_index = 0 - config = self._load_balancing_configs[real_index] + config: ModelLoadBalancingConfiguration = self._load_balancing_configs[real_index] if self.in_cooldown(config): cooldown_load_balancing_configs.append(config) @@ -507,8 +531,7 @@ def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool: self._tenant_id, self._provider, self._model_type.value, self._model, config.id ) - res = redis_client.exists(cooldown_cache_key) - res = cast(bool, res) + res: bool = redis_client.exists(cooldown_cache_key) return res @staticmethod diff --git a/api/core/model_runtime/callbacks/logging_callback.py b/api/core/model_runtime/callbacks/logging_callback.py index 3b6b825244dfdc..1f21a2d3763c4a 100644 --- a/api/core/model_runtime/callbacks/logging_callback.py +++ b/api/core/model_runtime/callbacks/logging_callback.py @@ -1,7 +1,8 @@ import json import logging import sys -from typing import Optional +from collections.abc import Sequence +from typing import Optional, cast from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk @@ -20,7 +21,7 @@ def on_before_invoke( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> None: @@ -76,7 +77,7 @@ def on_new_chunk( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ): @@ -94,7 +95,7 @@ def on_new_chunk( :param stream: is stream response :param user: unique user id """ - sys.stdout.write(chunk.delta.message.content) + sys.stdout.write(cast(str, chunk.delta.message.content)) sys.stdout.flush() def on_after_invoke( @@ -106,7 +107,7 @@ def on_after_invoke( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> None: @@ -147,7 +148,7 @@ def on_invoke_error( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> None: diff --git 
a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index 0efe46f87d6de9..2f682ceef578dc 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -3,7 +3,7 @@ from enum import Enum, StrEnum from typing import Optional -from pydantic import BaseModel, Field, computed_field, field_validator +from pydantic import BaseModel, Field, field_validator class PromptMessageRole(Enum): @@ -89,7 +89,6 @@ class MultiModalPromptMessageContent(PromptMessageContent): url: str = Field(default="", description="the url of multi-modal file") mime_type: str = Field(default=..., description="the mime type of multi-modal file") - @computed_field(return_type=str) @property def data(self): return self.url or f"data:{self.mime_type};base64,{self.base64_data}" diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 79a1d28ebe637e..e2b95603379348 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -1,7 +1,6 @@ import decimal import os from abc import ABC, abstractmethod -from collections.abc import Mapping from typing import Optional from pydantic import ConfigDict @@ -36,7 +35,7 @@ class AIModel(ABC): model_config = ConfigDict(protected_namespaces=()) @abstractmethod - def validate_credentials(self, model: str, credentials: Mapping) -> None: + def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate model credentials @@ -214,7 +213,7 @@ def predefined_models(self) -> list[AIModelEntity]: return model_schemas - def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> Optional[AIModelEntity]: + def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]: """ Get model schema by model name and credentials @@ -236,9 +235,7 @@ def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> return None - def get_customizable_model_schema_from_credentials( - self, model: str, credentials: Mapping - ) -> Optional[AIModelEntity]: + def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ Get customizable model schema from credentials @@ -248,7 +245,7 @@ def get_customizable_model_schema_from_credentials( """ return self._get_customizable_model_schema(model, credentials) - def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]: + def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ Get customizable model schema and fill in the template """ @@ -301,7 +298,7 @@ def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Op return schema - def get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]: + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ Get customizable model schema diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 8faeffa872b40f..402a30376b7546 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -2,7 +2,7 @@ import re import time from 
abc import abstractmethod -from collections.abc import Generator, Mapping, Sequence +from collections.abc import Generator, Sequence from typing import Optional, Union from pydantic import ConfigDict @@ -48,7 +48,7 @@ def invoke( prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -291,12 +291,12 @@ def _code_block_mode_stream_processor( content = piece.delta.message.content piece.delta.message.content = "" yield piece - piece = content + content_piece = content else: yield piece continue new_piece: str = "" - for char in piece: + for char in content_piece: char = str(char) if state == "normal": if char == "`": @@ -350,7 +350,7 @@ def _code_block_mode_stream_processor_with_backtick( piece.delta.message.content = "" # Yield a piece with cleared content before processing it to maintain the generator structure yield piece - piece = content + content_piece = content else: # Yield pieces without content directly yield piece @@ -360,7 +360,7 @@ continue new_piece: str = "" - for char in piece: + for char in content_piece: if state == "search_start": if char == "`": backtick_count += 1 @@ -535,7 +535,7 @@ def get_parameter_rules(self, model: str, credentials: dict) -> list[ParameterRu return [] - def get_model_mode(self, model: str, credentials: Optional[Mapping] = None) -> LLMMode: + def get_model_mode(self, model: str, credentials: Optional[dict] = None) -> LLMMode: """ Get model mode diff --git a/api/core/model_runtime/model_providers/__base/model_provider.py b/api/core/model_runtime/model_providers/__base/model_provider.py index 4374093de4ab38..36e3e7bd557163 100644 --- a/api/core/model_runtime/model_providers/__base/model_provider.py +++ b/api/core/model_runtime/model_providers/__base/model_provider.py @@ -104,9 +104,10 @@ def get_model_instance(self, model_type: ModelType) -> AIModel: mod = import_module_from_source( module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path ) + # FIXME: mypy reports that "type" has no attribute "__abstractmethods__"; ignore it for now and fix it later model_class = next( filter( - lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__, + lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__, # type: ignore get_subclasses_from_module(mod, AIModel), ), None, diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index 2d38fba955fb86..33135129082b1d 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -89,7 +89,8 @@ def _get_context_size(self, model: str, credentials: dict) -> int: model_schema = self.get_model_schema(model, credentials) if model_schema and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties: - return model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE] + context_size: int = model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE] + return context_size return 1000 @@ -104,6 +105,7 @@ def _get_max_chunks(self, model: str, credentials: dict) -> int: model_schema = self.get_model_schema(model, credentials) if model_schema and ModelPropertyKey.MAX_CHUNKS in
model_schema.model_properties: - return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] + max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] + return max_chunks return 1 diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py index 5fe6dda6ad5d79..6dab0aaf2d41e7 100644 --- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py +++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py @@ -2,9 +2,9 @@ from threading import Lock from typing import Any -from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer +from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore -_tokenizer = None +_tokenizer: Any = None _lock = Lock() diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/core/model_runtime/model_providers/__base/tts_model.py index b394ea4e9d22fe..6ce316b137abb4 100644 --- a/api/core/model_runtime/model_providers/__base/tts_model.py +++ b/api/core/model_runtime/model_providers/__base/tts_model.py @@ -127,7 +127,8 @@ def _get_model_audio_type(self, model: str, credentials: dict) -> str: if not model_schema or ModelPropertyKey.AUDIO_TYPE not in model_schema.model_properties: raise ValueError("this model does not support audio type") - return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE] + audio_type: str = model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE] + return audio_type def _get_model_word_limit(self, model: str, credentials: dict) -> int: """ @@ -138,8 +139,9 @@ def _get_model_word_limit(self, model: str, credentials: dict) -> int: if not model_schema or ModelPropertyKey.WORD_LIMIT not in model_schema.model_properties: raise ValueError("this model does not support word limit") + word_limit: int = model_schema.model_properties[ModelPropertyKey.WORD_LIMIT] - return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT] + return word_limit def _get_model_workers_limit(self, model: str, credentials: dict) -> int: """ @@ -150,8 +152,9 @@ def _get_model_workers_limit(self, model: str, credentials: dict) -> int: if not model_schema or ModelPropertyKey.MAX_WORKERS not in model_schema.model_properties: raise ValueError("this model does not support max workers") + workers_limit: int = model_schema.model_properties[ModelPropertyKey.MAX_WORKERS] - return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS] + return workers_limit @staticmethod def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"): diff --git a/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py index a2b14cf3dbe6d4..4aa09e61fd3599 100644 --- a/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py @@ -64,10 +64,12 @@ def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model) + if not ai_model_entity: + return None return ai_model_entity.entity @staticmethod - def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel: + def _get_ai_model_entity(base_model_name: str, model: str) -> 
Optional[AzureBaseModel]: for ai_model_entity in SPEECH2TEXT_BASE_MODELS: if ai_model_entity.base_model_name == base_model_name: ai_model_entity_copy = copy.deepcopy(ai_model_entity) diff --git a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py index 173b9d250c1743..6d50ba9163984f 100644 --- a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py @@ -114,6 +114,8 @@ def _process_sentence(self, sentence: str, model: str, voice, credentials: dict) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model) + if not ai_model_entity: + return None return ai_model_entity.entity @staticmethod diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index 75ed7ad62404cb..29bd673d576fc9 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -6,9 +6,9 @@ from typing import Optional, Union, cast # 3rd import -import boto3 -from botocore.config import Config -from botocore.exceptions import ( +import boto3 # type: ignore +from botocore.config import Config # type: ignore +from botocore.exceptions import ( # type: ignore ClientError, EndpointConnectionError, NoRegionError, diff --git a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py index aba8fedbc097e5..3a0a241f7ea0c0 100644 --- a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py @@ -44,7 +44,7 @@ def _invoke( :return: rerank result """ if len(docs) == 0: - return RerankResult(model=model, docs=docs) + return RerankResult(model=model, docs=[]) # initialize client client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) @@ -62,7 +62,7 @@ def _invoke( # format document rerank_document = RerankDocument( index=result.index, - text=result.document.text, + text=result.document.text if result.document else "", score=result.relevance_score, ) diff --git a/api/core/model_runtime/model_providers/fireworks/_common.py b/api/core/model_runtime/model_providers/fireworks/_common.py index 378ced3a4019ba..38d0a9dfbcadee 100644 --- a/api/core/model_runtime/model_providers/fireworks/_common.py +++ b/api/core/model_runtime/model_providers/fireworks/_common.py @@ -1,5 +1,3 @@ -from collections.abc import Mapping - import openai from core.model_runtime.errors.invoke import ( @@ -13,7 +11,7 @@ class _CommonFireworks: - def _to_credential_kwargs(self, credentials: Mapping) -> dict: + def _to_credential_kwargs(self, credentials: dict) -> dict: """ Transform credentials to kwargs for model instance diff --git a/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py index c745a7e978f4be..4c036283893fcc 100644 --- a/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py @@ -1,5 +1,4 @@ import time -from collections.abc import Mapping from typing import Optional, Union import numpy as np @@ -93,7 +92,7 @@ def get_num_tokens(self, model: 
str, credentials: dict, texts: list[str]) -> int """ return sum(self._get_num_tokens_by_gpt2(text) for text in texts) - def validate_credentials(self, model: str, credentials: Mapping) -> None: + def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate model credentials diff --git a/api/core/model_runtime/model_providers/gitee_ai/_common.py b/api/core/model_runtime/model_providers/gitee_ai/_common.py index 0750f3b75d0542..ad6600faf7bc15 100644 --- a/api/core/model_runtime/model_providers/gitee_ai/_common.py +++ b/api/core/model_runtime/model_providers/gitee_ai/_common.py @@ -1,4 +1,4 @@ -from dashscope.common.error import ( +from dashscope.common.error import ( # type: ignore AuthenticationError, InvalidParameter, RequestFailure, diff --git a/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py b/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py index 832ba927406c4c..737d3d5c931221 100644 --- a/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Optional import httpx @@ -51,7 +51,7 @@ def _invoke( base_url = base_url.removesuffix("/") try: - body = {"model": model, "query": query, "documents": docs} + body: dict[str, Any] = {"model": model, "query": query, "documents": docs} if top_n is not None: body["top_n"] = top_n response = httpx.post( diff --git a/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py index b833c5652c650a..a1fa89c5b34af6 100644 --- a/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py @@ -24,7 +24,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: super().validate_credentials(model, credentials) @staticmethod - def _add_custom_parameters(credentials: dict, model: str) -> None: + def _add_custom_parameters(credentials: dict, model: Optional[str]) -> None: if model is None: model = "bge-m3" diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py b/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py index 36dcea405d0974..dc91257daf9d4e 100644 --- a/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py +++ b/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Optional import requests @@ -13,9 +13,10 @@ class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel): Model class for OpenAI text2speech model. 
""" + # FIXME this Any return will be better type def _invoke( self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None - ) -> any: + ) -> Any: """ _invoke text2speech model @@ -47,7 +48,8 @@ def validate_credentials(self, model: str, credentials: dict) -> None: except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: + # FIXME this Any return will be better type + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> Any: """ _tts_invoke_streaming text2speech model :param model: model name diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index 7d19ccbb74a011..98273f60a41190 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -7,7 +7,7 @@ from typing import Optional, Union import google.ai.generativelanguage as glm -import google.generativeai as genai +import google.generativeai as genai # type: ignore import requests from google.api_core import exceptions from google.generativeai.types import ContentType, File, GenerateContentResponse diff --git a/api/core/model_runtime/model_providers/huggingface_hub/_common.py b/api/core/model_runtime/model_providers/huggingface_hub/_common.py index 3c4020b6eedf24..d8a09265e21059 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/_common.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/_common.py @@ -1,4 +1,4 @@ -from huggingface_hub.utils import BadRequestError, HfHubHTTPError +from huggingface_hub.utils import BadRequestError, HfHubHTTPError # type: ignore from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError diff --git a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py index 9d29237fdde573..cdb4103cd83712 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py @@ -1,9 +1,9 @@ from collections.abc import Generator from typing import Optional, Union -from huggingface_hub import InferenceClient -from huggingface_hub.hf_api import HfApi -from huggingface_hub.utils import BadRequestError +from huggingface_hub import InferenceClient # type: ignore +from huggingface_hub.hf_api import HfApi # type: ignore +from huggingface_hub.utils import BadRequestError # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE diff --git a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py index 8278d1e64def89..4ca5379405f4e6 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ import numpy as np import requests -from huggingface_hub import HfApi, InferenceClient +from huggingface_hub import HfApi, InferenceClient # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject diff --git 
a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py index 2014de8516bc11..2dd45f065d5e26 100644 --- a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py +++ b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py @@ -3,11 +3,11 @@ from collections.abc import Generator from typing import cast -from tencentcloud.common import credential -from tencentcloud.common.exception import TencentCloudSDKException -from tencentcloud.common.profile.client_profile import ClientProfile -from tencentcloud.common.profile.http_profile import HttpProfile -from tencentcloud.hunyuan.v20230901 import hunyuan_client, models +from tencentcloud.common import credential # type: ignore +from tencentcloud.common.exception import TencentCloudSDKException # type: ignore +from tencentcloud.common.profile.client_profile import ClientProfile # type: ignore +from tencentcloud.common.profile.http_profile import HttpProfile # type: ignore +from tencentcloud.hunyuan.v20230901 import hunyuan_client, models # type: ignore from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -305,7 +305,7 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str: elif isinstance(message, ToolPromptMessage): message_text = f"{tool_prompt} {content}" elif isinstance(message, SystemPromptMessage): - message_text = content + message_text = content if isinstance(content, str) else "" else: raise ValueError(f"Got unknown type {message}") diff --git a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py index b6d857cb37cba0..856cda90d35a22 100644 --- a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py @@ -3,11 +3,11 @@ import time from typing import Optional -from tencentcloud.common import credential -from tencentcloud.common.exception import TencentCloudSDKException -from tencentcloud.common.profile.client_profile import ClientProfile -from tencentcloud.common.profile.http_profile import HttpProfile -from tencentcloud.hunyuan.v20230901 import hunyuan_client, models +from tencentcloud.common import credential # type: ignore +from tencentcloud.common.exception import TencentCloudSDKException # type: ignore +from tencentcloud.common.profile.client_profile import ClientProfile # type: ignore +from tencentcloud.common.profile.http_profile import HttpProfile # type: ignore +from tencentcloud.hunyuan.v20230901 import hunyuan_client, models # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py index d80cbfa83d6425..1fc0f8c028ba92 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py @@ -1,11 +1,11 @@ from os.path import abspath, dirname, join from threading import Lock -from transformers import AutoTokenizer +from transformers import AutoTokenizer # type: ignore class JinaTokenizer: - _tokenizer = None + _tokenizer: AutoTokenizer | None = None _lock = Lock() @classmethod 
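The jina_tokenizer.py hunk above only shows the re-annotated class attributes (`_tokenizer: AutoTokenizer | None` plus a `Lock`); the lazy, lock-guarded initialization those annotations are typed for sits outside the diff. As a minimal sketch of that pattern, assuming a hypothetical `_get_tokenizer` classmethod and a bundled `tokenizer` directory (the `abspath`/`dirname`/`join` imports in the hunk suggest a local path, but this is not necessarily Dify's actual implementation):

```python
from os.path import abspath, dirname, join
from threading import Lock

# transformers ships without type stubs, hence the "# type: ignore" used throughout this patch
from transformers import AutoTokenizer  # type: ignore


class LazyTokenizer:
    # Annotating the cache as "AutoTokenizer | None" lets mypy track the not-yet-loaded state.
    _tokenizer: AutoTokenizer | None = None
    _lock = Lock()

    @classmethod
    def _get_tokenizer(cls) -> AutoTokenizer:
        # Double-checked locking: a cheap unlocked read on the hot path,
        # then a re-check under the lock so the tokenizer loads only once per process.
        if cls._tokenizer is None:
            with cls._lock:
                if cls._tokenizer is None:
                    tokenizer_path = join(dirname(abspath(__file__)), "tokenizer")  # assumed location
                    cls._tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        return cls._tokenizer
```

Without the `| None` in the attribute annotation, mypy would infer the attribute's type from the `None` default and flag the later assignment; the explicit union is what makes this lazy-cache idiom type-check cleanly.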
diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py index 88cc0e8e0f32d0..357631b2dba0b9 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py @@ -40,7 +40,7 @@ def generate( url = f"https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}" - extra_kwargs = {} + extra_kwargs: dict[str, Any] = {} if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int: extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"] @@ -117,19 +117,19 @@ def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage: """ handle chat generate response """ - response = response.json() - if "base_resp" in response and response["base_resp"]["status_code"] != 0: - code = response["base_resp"]["status_code"] - msg = response["base_resp"]["status_msg"] + response_data = response.json() + if "base_resp" in response_data and response_data["base_resp"]["status_code"] != 0: + code = response_data["base_resp"]["status_code"] + msg = response_data["base_resp"]["status_msg"] self._handle_error(code, msg) - message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value) + message = MinimaxMessage(content=response_data["reply"], role=MinimaxMessage.Role.ASSISTANT.value) message.usage = { "prompt_tokens": 0, - "completion_tokens": response["usage"]["total_tokens"], - "total_tokens": response["usage"]["total_tokens"], + "completion_tokens": response_data["usage"]["total_tokens"], + "total_tokens": response_data["usage"]["total_tokens"], } - message.stop_reason = response["choices"][0]["finish_reason"] + message.stop_reason = response_data["choices"][0]["finish_reason"] return message def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]: @@ -139,10 +139,10 @@ def _handle_stream_chat_generate_response(self, response: Response) -> Generator for line in response.iter_lines(): if not line: continue - line: str = line.decode("utf-8") - if line.startswith("data: "): - line = line[6:].strip() - data = loads(line) + line_str: str = line.decode("utf-8") + if line_str.startswith("data: "): + line_str = line_str[6:].strip() + data = loads(line_str) if "base_resp" in data and data["base_resp"]["status_code"] != 0: code = data["base_resp"]["status_code"] @@ -162,5 +162,5 @@ def _handle_stream_chat_generate_response(self, response: Response) -> Generator continue for choice in choices: - message = choice["delta"] - yield MinimaxMessage(content=message, role=MinimaxMessage.Role.ASSISTANT.value) + message_choice = choice["delta"] + yield MinimaxMessage(content=message_choice, role=MinimaxMessage.Role.ASSISTANT.value) diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py index 8b8fdbb6bdf558..284b61829f9729 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py @@ -41,7 +41,7 @@ def generate( url = f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}" - extra_kwargs = {} + extra_kwargs: dict[str, Any] = {} if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int: extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"] @@ 
-122,19 +122,19 @@ def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage: """ handle chat generate response """ - response = response.json() - if "base_resp" in response and response["base_resp"]["status_code"] != 0: - code = response["base_resp"]["status_code"] - msg = response["base_resp"]["status_msg"] + response_data = response.json() + if "base_resp" in response_data and response_data["base_resp"]["status_code"] != 0: + code = response_data["base_resp"]["status_code"] + msg = response_data["base_resp"]["status_msg"] self._handle_error(code, msg) - message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value) + message = MinimaxMessage(content=response_data["reply"], role=MinimaxMessage.Role.ASSISTANT.value) message.usage = { "prompt_tokens": 0, - "completion_tokens": response["usage"]["total_tokens"], - "total_tokens": response["usage"]["total_tokens"], + "completion_tokens": response_data["usage"]["total_tokens"], + "total_tokens": response_data["usage"]["total_tokens"], } - message.stop_reason = response["choices"][0]["finish_reason"] + message.stop_reason = response_data["choices"][0]["finish_reason"] return message def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]: @@ -144,10 +144,10 @@ def _handle_stream_chat_generate_response(self, response: Response) -> Generator for line in response.iter_lines(): if not line: continue - line: str = line.decode("utf-8") - if line.startswith("data: "): - line = line[6:].strip() - data = loads(line) + line_str: str = line.decode("utf-8") + if line_str.startswith("data: "): + line_str = line_str[6:].strip() + data = loads(line_str) if "base_resp" in data and data["base_resp"]["status_code"] != 0: code = data["base_resp"]["status_code"] diff --git a/api/core/model_runtime/model_providers/minimax/llm/types.py b/api/core/model_runtime/model_providers/minimax/llm/types.py index 88ebe5e2e00e7a..c248db374a2504 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/types.py +++ b/api/core/model_runtime/model_providers/minimax/llm/types.py @@ -11,9 +11,9 @@ class Role(Enum): role: str = Role.USER.value content: str - usage: dict[str, int] = None + usage: dict[str, int] | None = None stop_reason: str = "" - function_call: dict[str, Any] = None + function_call: dict[str, Any] | None = None def to_dict(self) -> dict[str, Any]: if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value: diff --git a/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py index 56a707333c40e9..8a4c19d4d8f71b 100644 --- a/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py @@ -2,8 +2,8 @@ from functools import wraps from typing import Optional -from nomic import embed -from nomic import login as nomic_login +from nomic import embed # type: ignore +from nomic import login as nomic_login # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType diff --git a/api/core/model_runtime/model_providers/oci/llm/llm.py b/api/core/model_runtime/model_providers/oci/llm/llm.py index 1e1fc5b3ea89aa..9f676573fc2ece 100644 --- a/api/core/model_runtime/model_providers/oci/llm/llm.py +++ b/api/core/model_runtime/model_providers/oci/llm/llm.py @@ -5,8 +5,8 @@ from 
collections.abc import Generator from typing import Optional, Union -import oci -from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse +import oci # type: ignore +from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse # type: ignore from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py index 50fa63768c241b..5a428c9fed0466 100644 --- a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional import numpy as np -import oci +import oci # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py index 83c4facc8db76c..3543fe58bb68d2 100644 --- a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py @@ -61,6 +61,7 @@ def _invoke( headers = {"Content-Type": "application/json"} endpoint_url = credentials.get("base_url") + assert endpoint_url is not None, "Base URL is required for Ollama API" if not endpoint_url.endswith("/"): endpoint_url += "/" diff --git a/api/core/model_runtime/model_providers/openai/_common.py b/api/core/model_runtime/model_providers/openai/_common.py index 2181bb4f08fd8f..ac2b3e6881c740 100644 --- a/api/core/model_runtime/model_providers/openai/_common.py +++ b/api/core/model_runtime/model_providers/openai/_common.py @@ -1,5 +1,3 @@ -from collections.abc import Mapping - import openai from httpx import Timeout @@ -14,7 +12,7 @@ class _CommonOpenAI: - def _to_credential_kwargs(self, credentials: Mapping) -> dict: + def _to_credential_kwargs(self, credentials: dict) -> dict: """ Transform credentials to kwargs for model instance diff --git a/api/core/model_runtime/model_providers/openai/moderation/moderation.py b/api/core/model_runtime/model_providers/openai/moderation/moderation.py index 619044d808cdf6..227e4b0c152a05 100644 --- a/api/core/model_runtime/model_providers/openai/moderation/moderation.py +++ b/api/core/model_runtime/model_providers/openai/moderation/moderation.py @@ -93,7 +93,8 @@ def _get_max_characters_per_chunk(self, model: str, credentials: dict) -> int: model_schema = self.get_model_schema(model, credentials) if model_schema and ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK in model_schema.model_properties: - return model_schema.model_properties[ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK] + max_characters_per_chunk: int = model_schema.model_properties[ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK] + return max_characters_per_chunk return 2000 @@ -108,6 +109,7 @@ def _get_max_chunks(self, model: str, credentials: dict) -> int: model_schema = self.get_model_schema(model, credentials) if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties: - return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] + max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] + return max_chunks return 1 diff --git 
a/api/core/model_runtime/model_providers/openai/openai.py b/api/core/model_runtime/model_providers/openai/openai.py index aa6f38ce9fae5a..c546441af61d9b 100644 --- a/api/core/model_runtime/model_providers/openai/openai.py +++ b/api/core/model_runtime/model_providers/openai/openai.py @@ -1,5 +1,4 @@ import logging -from collections.abc import Mapping from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError @@ -9,7 +8,7 @@ class OpenAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: Mapping) -> None: + def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials if validate failed, raise exception diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py index a490537e51a6ad..74229a089aa45e 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py @@ -33,6 +33,7 @@ def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional headers["Authorization"] = f"Bearer {api_key}" endpoint_url = credentials.get("endpoint_url") + assert endpoint_url is not None, "endpoint_url is required in credentials" if not endpoint_url.endswith("/"): endpoint_url += "/" endpoint_url = urljoin(endpoint_url, "audio/transcriptions") diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py index 9da8f55d0a7ed9..b4d6c6c6ca9942 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py @@ -55,6 +55,7 @@ def _invoke( headers["Authorization"] = f"Bearer {api_key}" endpoint_url = credentials.get("endpoint_url") + assert endpoint_url is not None, "endpoint_url is required in credentials" if not endpoint_url.endswith("/"): endpoint_url += "/" diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/tts/tts.py b/api/core/model_runtime/model_providers/openai_api_compatible/tts/tts.py index 8239c625f7ada8..53e895b0ecb376 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/tts/tts.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/tts/tts.py @@ -44,6 +44,7 @@ def _invoke( # Construct endpoint URL endpoint_url = credentials.get("endpoint_url") + assert endpoint_url is not None, "endpoint_url is required in credentials" if not endpoint_url.endswith("/"): endpoint_url += "/" endpoint_url = urljoin(endpoint_url, "audio/speech") diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py index 2789a9250a1d35..e9509b544d9f4e 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py +++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py @@ -1,7 +1,7 @@ from collections.abc import Generator from enum import Enum from json import dumps, loads -from typing import Any, Union +from typing import Any, Optional, Union from requests import Response, post from requests.exceptions import ConnectionError, InvalidSchema, 
MissingSchema @@ -20,7 +20,7 @@ class Role(Enum): role: str = Role.USER.value content: str - usage: dict[str, int] = None + usage: Optional[dict[str, int]] = None stop_reason: str = "" def to_dict(self) -> dict[str, Any]: @@ -165,17 +165,17 @@ def _handle_chat_stream_generate_response( if not line: continue - line: str = line.decode("utf-8") - if line.startswith("data: "): - line = line[6:].strip() + line_str: str = line.decode("utf-8") + if line_str.startswith("data: "): + line_str = line_str[6:].strip() - if line == "[DONE]": + if line_str == "[DONE]": return try: - data = loads(line) + data = loads(line_str) except Exception as e: - raise InternalServerError(f"Failed to convert response to json: {e} with text: {line}") + raise InternalServerError(f"Failed to convert response to json: {e} with text: {line_str}") output = data["outputs"] diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py index 7bbd31e87c595d..40ea4dc0118026 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py @@ -53,14 +53,16 @@ def _invoke( api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - + endpoint_url: Optional[str] if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": endpoint_url = "https://cloud.perfxlab.cn/v1/" else: endpoint_url = credentials.get("endpoint_url") + assert endpoint_url is not None, "endpoint_url is required in credentials" if not endpoint_url.endswith("/"): endpoint_url += "/" + assert isinstance(endpoint_url, str) endpoint_url = urljoin(endpoint_url, "embeddings") extra_model_kwargs = {} @@ -142,13 +144,16 @@ def validate_credentials(self, model: str, credentials: dict) -> None: if api_key: headers["Authorization"] = f"Bearer {api_key}" + endpoint_url: Optional[str] if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": endpoint_url = "https://cloud.perfxlab.cn/v1/" else: endpoint_url = credentials.get("endpoint_url") + assert endpoint_url is not None, "endpoint_url is required in credentials" if not endpoint_url.endswith("/"): endpoint_url += "/" + assert isinstance(endpoint_url, str) endpoint_url = urljoin(endpoint_url, "embeddings") payload = {"input": "ping", "model": model} diff --git a/api/core/model_runtime/model_providers/replicate/_common.py b/api/core/model_runtime/model_providers/replicate/_common.py index 915f6e0eefcd08..3e2cf2adb306db 100644 --- a/api/core/model_runtime/model_providers/replicate/_common.py +++ b/api/core/model_runtime/model_providers/replicate/_common.py @@ -1,4 +1,4 @@ -from replicate.exceptions import ModelError, ReplicateError +from replicate.exceptions import ModelError, ReplicateError # type: ignore from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError diff --git a/api/core/model_runtime/model_providers/replicate/llm/llm.py b/api/core/model_runtime/model_providers/replicate/llm/llm.py index 3641b35dc02a39..1e7858100b0429 100644 --- a/api/core/model_runtime/model_providers/replicate/llm/llm.py +++ b/api/core/model_runtime/model_providers/replicate/llm/llm.py @@ -1,9 +1,9 @@ from collections.abc import Generator from typing import Optional, Union -from replicate import Client as ReplicateClient -from replicate.exceptions import ReplicateError -from replicate.prediction import Prediction +from 
replicate import Client as ReplicateClient # type: ignore +from replicate.exceptions import ReplicateError # type: ignore +from replicate.prediction import Prediction # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta diff --git a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py index 41759fe07d0cac..aaf825388a9043 100644 --- a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py @@ -2,11 +2,11 @@ import time from typing import Optional -from replicate import Client as ReplicateClient +from replicate import Client as ReplicateClient # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel @@ -86,7 +86,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> Option label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, - model_properties={"context_size": 4096, "max_chunks": 1}, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 4096, ModelPropertyKey.MAX_CHUNKS: 1}, ) return entity diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py index 5ff00f008eb621..b8c979b1f53ce9 100644 --- a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py +++ b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py @@ -4,7 +4,7 @@ from collections.abc import Generator, Iterator from typing import Any, Optional, Union, cast -import boto3 +import boto3 # type: ignore from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -83,7 +83,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): sagemaker_session: Any = None predictor: Any = None - sagemaker_endpoint: str = None + sagemaker_endpoint: str | None = None def _handle_chat_generate_response( self, @@ -209,8 +209,8 @@ def _invoke( :param user: unique user id :return: full response or stream response chunk generator result """ - from sagemaker import Predictor, serializers - from sagemaker.session import Session + from sagemaker import Predictor, serializers # type: ignore + from sagemaker.session import Session # type: ignore if not self.sagemaker_session: access_key = credentials.get("aws_access_key_id") diff --git a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py index df797bae265825..7daab6d8653d33 100644 --- a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py @@ -3,7 
+3,7 @@ import operator from typing import Any, Optional -import boto3 +import boto3 # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType @@ -114,6 +114,7 @@ def _invoke( except Exception as e: logger.exception(f"Failed to invoke rerank model, model: {model}") + raise InvokeError(f"Failed to invoke rerank model, model: {model}, error: {str(e)}") def validate_credentials(self, model: str, credentials: dict) -> None: """ diff --git a/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py index 2d50e9c7b4c28a..a6aca130456063 100644 --- a/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py @@ -2,7 +2,7 @@ import logging from typing import IO, Any, Optional -import boto3 +import boto3 # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType @@ -67,6 +67,7 @@ def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional s3_prefix = "dify/speech2text/" sagemaker_endpoint = credentials.get("sagemaker_endpoint") bucket = credentials.get("audio_s3_cache_bucket") + assert bucket is not None, "audio_s3_cache_bucket is required in credentials" s3_presign_url = generate_presigned_url(self.s3_client, file, bucket, s3_prefix) payload = {"audio_s3_presign_uri": s3_presign_url} diff --git a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py index ef4ddcd6a72847..e7eccd997d11c1 100644 --- a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ import time from typing import Any, Optional -import boto3 +import boto3 # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject @@ -118,6 +118,7 @@ def _invoke( except Exception as e: logger.exception(f"Failed to invoke text embedding model, model: {model}, line: {line}") + raise InvokeError(str(e)) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ diff --git a/api/core/model_runtime/model_providers/sagemaker/tts/tts.py b/api/core/model_runtime/model_providers/sagemaker/tts/tts.py index 6a5946453be07f..62231c518deef1 100644 --- a/api/core/model_runtime/model_providers/sagemaker/tts/tts.py +++ b/api/core/model_runtime/model_providers/sagemaker/tts/tts.py @@ -5,7 +5,7 @@ from enum import Enum from typing import Any, Optional -import boto3 +import boto3 # type: ignore import requests from core.model_runtime.entities.common_entities import I18nObject diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/llm.py b/api/core/model_runtime/model_providers/siliconflow/llm/llm.py index e3a323a4965bc7..f61e8b82e4db99 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/llm.py +++ b/api/core/model_runtime/model_providers/siliconflow/llm/llm.py @@ -43,7 +43,7 @@ def _add_custom_parameters(cls, credentials: dict) -> None: credentials["mode"] = "chat" credentials["endpoint_url"] = "https://api.siliconflow.cn/v1" - def 
get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: return AIModelEntity( model=model, label=I18nObject(en_US=model, zh_Hans=model), diff --git a/api/core/model_runtime/model_providers/spark/llm/llm.py b/api/core/model_runtime/model_providers/spark/llm/llm.py index 1181ba699af886..cb6f28b6c27fa9 100644 --- a/api/core/model_runtime/model_providers/spark/llm/llm.py +++ b/api/core/model_runtime/model_providers/spark/llm/llm.py @@ -1,6 +1,6 @@ import threading from collections.abc import Generator -from typing import Optional, Union +from typing import Optional, Union, cast from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -270,7 +270,7 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str: elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" elif isinstance(message, SystemPromptMessage): - message_text = content + message_text = cast(str, content) else: raise ValueError(f"Got unknown type {message}") diff --git a/api/core/model_runtime/model_providers/togetherai/llm/llm.py b/api/core/model_runtime/model_providers/togetherai/llm/llm.py index b96d43979ef54a..03eac194235e83 100644 --- a/api/core/model_runtime/model_providers/togetherai/llm/llm.py +++ b/api/core/model_runtime/model_providers/togetherai/llm/llm.py @@ -12,6 +12,7 @@ AIModelEntity, DefaultParameterName, FetchFrom, + ModelFeature, ModelPropertyKey, ModelType, ParameterRule, @@ -67,7 +68,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode cred_with_endpoint = self._update_endpoint_url(credentials=credentials) REPETITION_PENALTY = "repetition_penalty" TOP_K = "top_k" - features = [] + features: list[ModelFeature] = [] entity = AIModelEntity( model=model, diff --git a/api/core/model_runtime/model_providers/tongyi/_common.py b/api/core/model_runtime/model_providers/tongyi/_common.py index 8a50c7aa05f38c..bb68319555007f 100644 --- a/api/core/model_runtime/model_providers/tongyi/_common.py +++ b/api/core/model_runtime/model_providers/tongyi/_common.py @@ -1,4 +1,4 @@ -from dashscope.common.error import ( +from dashscope.common.error import ( # type: ignore AuthenticationError, InvalidParameter, RequestFailure, diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py index 0c1f6518811aa8..61ebd45ed64a6d 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -7,9 +7,9 @@ from pathlib import Path from typing import Optional, Union, cast -from dashscope import Generation, MultiModalConversation, get_tokenizer -from dashscope.api_entities.dashscope_response import GenerationResponse -from dashscope.common.error import ( +from dashscope import Generation, MultiModalConversation, get_tokenizer # type: ignore +from dashscope.api_entities.dashscope_response import GenerationResponse # type: ignore +from dashscope.common.error import ( # type: ignore AuthenticationError, InvalidParameter, RequestFailure, diff --git a/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py b/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py index a5ce9ead6ee3be..ed682cb0f3c1e4 100644 --- a/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py +++ 
b/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py @@ -1,7 +1,7 @@ from typing import Optional -import dashscope -from dashscope.common.error import ( +import dashscope # type: ignore +from dashscope.common.error import ( # type: ignore AuthenticationError, InvalidParameter, RequestFailure, @@ -51,7 +51,7 @@ def _invoke( :return: rerank result """ if len(docs) == 0: - return RerankResult(model=model, docs=docs) + return RerankResult(model=model, docs=[]) # initialize client dashscope.api_key = credentials["dashscope_api_key"] @@ -64,7 +64,7 @@ def _invoke( return_documents=True, ) - rerank_documents = [] + rerank_documents: list[RerankDocument] = [] if not response.output: return RerankResult(model=model, docs=rerank_documents) for _, result in enumerate(response.output.results): diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py index 2ef7f3f5774481..8c53be413002a9 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py @@ -1,7 +1,7 @@ import time from typing import Optional -import dashscope +import dashscope # type: ignore import numpy as np from core.entities.embedding_type import EmbeddingInputType diff --git a/api/core/model_runtime/model_providers/tongyi/tts/tts.py b/api/core/model_runtime/model_providers/tongyi/tts/tts.py index ca3b9fbc1c3c00..a654e2d760d7c4 100644 --- a/api/core/model_runtime/model_providers/tongyi/tts/tts.py +++ b/api/core/model_runtime/model_providers/tongyi/tts/tts.py @@ -2,10 +2,10 @@ from queue import Queue from typing import Any, Optional -import dashscope -from dashscope import SpeechSynthesizer -from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse -from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult +import dashscope # type: ignore +from dashscope import SpeechSynthesizer # type: ignore +from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse # type: ignore +from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult # type: ignore from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/core/model_runtime/model_providers/upstage/_common.py b/api/core/model_runtime/model_providers/upstage/_common.py index 47ebaccd84ab8a..f6609bba77129b 100644 --- a/api/core/model_runtime/model_providers/upstage/_common.py +++ b/api/core/model_runtime/model_providers/upstage/_common.py @@ -1,5 +1,3 @@ -from collections.abc import Mapping - import openai from httpx import Timeout @@ -14,7 +12,7 @@ class _CommonUpstage: - def _to_credential_kwargs(self, credentials: Mapping) -> dict: + def _to_credential_kwargs(self, credentials: dict) -> dict: """ Transform credentials to kwargs for model instance diff --git a/api/core/model_runtime/model_providers/upstage/llm/llm.py b/api/core/model_runtime/model_providers/upstage/llm/llm.py index a18ee906248a49..2bf6796ca5cf45 100644 --- a/api/core/model_runtime/model_providers/upstage/llm/llm.py +++ b/api/core/model_runtime/model_providers/upstage/llm/llm.py @@ -6,7 +6,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall from openai.types.chat.chat_completion_message 
diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py index 2ef7f3f5774481..8c53be413002a9 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py @@ -1,7 +1,7 @@ import time from typing import Optional -import dashscope +import dashscope # type: ignore import numpy as np from core.entities.embedding_type import EmbeddingInputType diff --git a/api/core/model_runtime/model_providers/tongyi/tts/tts.py b/api/core/model_runtime/model_providers/tongyi/tts/tts.py index ca3b9fbc1c3c00..a654e2d760d7c4 100644 --- a/api/core/model_runtime/model_providers/tongyi/tts/tts.py +++ b/api/core/model_runtime/model_providers/tongyi/tts/tts.py @@ -2,10 +2,10 @@ from queue import Queue from typing import Any, Optional -import dashscope -from dashscope import SpeechSynthesizer -from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse -from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult +import dashscope # type: ignore +from dashscope import SpeechSynthesizer # type: ignore +from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse # type: ignore +from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult # type: ignore from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/core/model_runtime/model_providers/upstage/_common.py b/api/core/model_runtime/model_providers/upstage/_common.py index 47ebaccd84ab8a..f6609bba77129b 100644 --- a/api/core/model_runtime/model_providers/upstage/_common.py +++ b/api/core/model_runtime/model_providers/upstage/_common.py @@ -1,5 +1,3 @@ -from collections.abc import Mapping - import openai from httpx import Timeout @@ -14,7 +12,7 @@ class _CommonUpstage: - def _to_credential_kwargs(self, credentials: Mapping) -> dict: + def _to_credential_kwargs(self, credentials: dict) -> dict: """ Transform credentials to kwargs for model instance diff --git a/api/core/model_runtime/model_providers/upstage/llm/llm.py b/api/core/model_runtime/model_providers/upstage/llm/llm.py index a18ee906248a49..2bf6796ca5cf45 100644 --- a/api/core/model_runtime/model_providers/upstage/llm/llm.py +++ b/api/core/model_runtime/model_providers/upstage/llm/llm.py @@ -6,7 +6,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall from openai.types.chat.chat_completion_message import FunctionCall -from tokenizers import Tokenizer +from tokenizers import Tokenizer # type: ignore from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta diff --git a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py index 5b340e53bbc543..87693eca768dfd 100644 --- a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py @@ -1,11 +1,10 @@ import base64 import time -from collections.abc import Mapping from typing import Union import numpy as np from openai import OpenAI -from tokenizers import Tokenizer +from tokenizers import Tokenizer # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType @@ -132,7 +131,7 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int return total_num_tokens - def validate_credentials(self, model: str, credentials: Mapping) -> None: + def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate model credentials diff --git a/api/core/model_runtime/model_providers/vertex_ai/_common.py b/api/core/model_runtime/model_providers/vertex_ai/_common.py index 8f7c859e3803c0..4e3df7574e9ce8 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/_common.py +++ b/api/core/model_runtime/model_providers/vertex_ai/_common.py @@ -12,4 +12,4 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ - pass + raise NotImplementedError diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py index c50e0f794616b3..85be34f3f0fe7f 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py @@ -6,7 +6,7 @@ from collections.abc import Generator from typing import TYPE_CHECKING, Optional, Union, cast -import google.auth.transport.requests +import google.auth.transport.requests # type: ignore import requests from anthropic import AnthropicVertex, Stream from anthropic.types import ( diff --git a/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py b/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py index 034c066ab5f071..782e4fd6232a3b 100644 --- a/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py @@ -17,14 +17,12 @@ class VesslAILargeLanguageModel(OAIAPICompatLargeLanguageModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: - features = [] - entity = AIModelEntity( model=model, label=I18nObject(en_US=model), model_type=ModelType.LLM, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - features=features, + features=[], model_properties={ ModelPropertyKey.MODE: credentials.get("mode"), }, diff --git a/api/core/model_runtime/model_providers/volcengine_maas/client.py b/api/core/model_runtime/model_providers/volcengine_maas/client.py index 1cffd902c7a25d..a8a015167e3227 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/client.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/client.py @@ -1,8 +1,8 @@ from collections.abc import Generator from typing import
Optional, cast -from volcenginesdkarkruntime import Ark -from volcenginesdkarkruntime.types.chat import ( +from volcenginesdkarkruntime import Ark # type: ignore +from volcenginesdkarkruntime.types.chat import ( # type: ignore ChatCompletion, ChatCompletionAssistantMessageParam, ChatCompletionChunk, @@ -15,10 +15,10 @@ ChatCompletionToolParam, ChatCompletionUserMessageParam, ) -from volcenginesdkarkruntime.types.chat.chat_completion_content_part_image_param import ImageURL -from volcenginesdkarkruntime.types.chat.chat_completion_message_tool_call_param import Function -from volcenginesdkarkruntime.types.create_embedding_response import CreateEmbeddingResponse -from volcenginesdkarkruntime.types.shared_params import FunctionDefinition +from volcenginesdkarkruntime.types.chat.chat_completion_content_part_image_param import ImageURL # type: ignore +from volcenginesdkarkruntime.types.chat.chat_completion_message_tool_call_param import Function # type: ignore +from volcenginesdkarkruntime.types.create_embedding_response import CreateEmbeddingResponse # type: ignore +from volcenginesdkarkruntime.types.shared_params import FunctionDefinition # type: ignore from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py index 91dbe21a616195..aa837b8318873d 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py @@ -152,5 +152,6 @@ class ServiceNotOpenError(MaasError): def wrap_error(e: MaasError) -> Exception: if ErrorCodeMap.get(e.code): - return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id) + # FIXME: mypy type error, try to fix it instead of using type: ignore + return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id) # type: ignore return e
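The FIXME in wrap_error above can likely be resolved without the suppression by binding the dict lookup to a local variable, letting mypy narrow the Optional result before it is called. An untested sketch, assuming ErrorCodeMap maps codes to MaasError subclasses that share this constructor signature:

    def wrap_error(e: MaasError) -> Exception:
        # ErrorCodeMap.get() returns an Optional class, and looking it up twice
        # hides the None case from mypy; binding it once allows narrowing.
        error_cls = ErrorCodeMap.get(e.code)
        if error_cls is not None:
            return error_cls(e.code_n, e.code, e.message, e.req_id)
        return e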
diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py index 9e19b7dedaa5a7..f0b2b101b7be9d 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py @@ -2,7 +2,7 @@ from collections.abc import Generator from typing import Optional -from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk +from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py index cf3cf23cfb9cef..7c37368086e0e6 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py @@ -1,3 +1,5 @@ +from typing import Any + from pydantic import BaseModel from core.model_runtime.entities.llm_entities import LLMMode @@ -102,7 +104,7 @@ def get_model_config(credentials: dict) -> ModelConfig: def get_v2_req_params(credentials: dict, model_parameters: dict, stop: list[str] | None = None): - req_params = {} + req_params: dict[str, Any] = {} # predefined properties model_configs = get_model_config(credentials) if model_configs: @@ -130,7 +132,7 @@ def get_v2_req_params(credentials: dict, model_parameters: dict, stop: list[str] def get_v3_req_params(credentials: dict, model_parameters: dict, stop: list[str] | None = None): - req_params = {} + req_params: dict[str, Any] = {} # predefined properties model_configs = get_model_config(credentials) if model_configs: diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py index 07b970f8104c8f..d2899795696aa4 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py @@ -1,7 +1,7 @@ from collections.abc import Generator from enum import Enum from json import dumps, loads -from typing import Any, Union +from typing import Any, Optional, Union from requests import Response, post @@ -22,7 +22,7 @@ class Role(Enum): role: str = Role.USER.value content: str - usage: dict[str, int] = None + usage: Optional[dict[str, int]] = None stop_reason: str = "" def to_dict(self) -> dict[str, Any]: @@ -135,6 +135,7 @@ def _build_function_calling_request_body( """ TODO: implement function calling """ + raise NotImplementedError("Function calling is not supported yet.") def _build_chat_request_body( self, diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py index 19135deb27380d..816b3b98c4b8c5 100644 --- a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py @@ -1,6 +1,5 @@ import time from abc import abstractmethod -from collections.abc import Mapping from json import dumps from typing import Any, Optional @@ -23,12 +22,12 @@ class TextEmbedding: @abstractmethod - def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): + def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]: raise NotImplementedError class WenxinTextEmbedding(_CommonWenxin, TextEmbedding): - def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): + def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]: access_token = self._get_access_token() url = f"{self.api_bases[model]}?access_token={access_token}" body = self._build_embed_request_body(model, texts, user) @@ -50,7 +49,7 @@ def _build_embed_request_body(self, model: str, texts: list[str], user: str) -> } return body - def _handle_embed_response(self, model: str, response: Response) -> (list[list[float]], int, int): + def _handle_embed_response(self, model: str, response: Response) -> tuple[list[list[float]], int, int]: data = response.json() if "error_code" in data: code = data["error_code"] @@ -147,7 +146,7 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int return total_num_tokens - def validate_credentials(self, model: str, credentials: Mapping) -> None: + def validate_credentials(self, model: str, credentials: dict) -> None: api_key = credentials["api_key"] secret_key = credentials["secret_key"] try: diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index 8d86d6937d8ac9..7db1203641cad2 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++
b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -17,7 +17,7 @@ from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall from openai.types.chat.chat_completion_message import FunctionCall from openai.types.completion import Completion -from xinference_client.client.restful.restful_client import ( +from xinference_client.client.restful.restful_client import ( # type: ignore Client, RESTfulChatModelHandle, RESTfulGenerateModelHandle, diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index efaf114854b5c1..078ec0537a37f4 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -1,6 +1,6 @@ from typing import Optional -from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle +from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType diff --git a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py index 3d7aefeb6dd89a..5f330ece1a5750 100644 --- a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py @@ -1,6 +1,6 @@ from typing import IO, Optional -from xinference_client.client.restful.restful_client import Client, RESTfulAudioModelHandle +from xinference_client.client.restful.restful_client import Client, RESTfulAudioModelHandle # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index e51e6a941c5413..9054aabab2dd05 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -1,7 +1,7 @@ import time from typing import Optional -from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle +from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject @@ -134,7 +134,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: try: handle = client.get_model(model_uid=model_uid) except RuntimeError as e: - raise InvokeAuthorizationError(e) + raise InvokeAuthorizationError(str(e)) if not isinstance(handle, RESTfulEmbeddingModelHandle): raise InvokeBadRequestError( diff --git a/api/core/model_runtime/model_providers/xinference/tts/tts.py b/api/core/model_runtime/model_providers/xinference/tts/tts.py index ad7b64efb5d2e7..8aa39d4de0d2cb 100644 --- a/api/core/model_runtime/model_providers/xinference/tts/tts.py +++ b/api/core/model_runtime/model_providers/xinference/tts/tts.py @@ -1,7 +1,7 @@ import concurrent.futures from typing import Any, Optional -from 
xinference_client.client.restful.restful_client import RESTfulAudioModelHandle +from xinference_client.client.restful.restful_client import RESTfulAudioModelHandle # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType @@ -74,11 +74,14 @@ def validate_credentials(self, model: str, credentials: dict) -> None: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") credentials["server_url"] = credentials["server_url"].removesuffix("/") + api_key = credentials.get("api_key") + if api_key is None: + raise CredentialsValidateFailedError("api_key is required") extra_param = XinferenceHelper.get_xinference_extra_parameter( server_url=credentials["server_url"], model_uid=credentials["model_uid"], - api_key=credentials.get("api_key"), + api_key=api_key, ) if "text-to-audio" not in extra_param.model_ability: diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py index baa3ccbe8adbc0..b51423f4eda2e6 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -1,6 +1,6 @@ from threading import Lock from time import time -from typing import Optional +from typing import Any, Optional from requests.adapters import HTTPAdapter from requests.exceptions import ConnectionError, MissingSchema, Timeout @@ -39,13 +39,15 @@ def __init__( self.model_family = model_family -cache = {} +cache: dict[str, dict[str, Any]] = {} cache_lock = Lock() class XinferenceHelper: @staticmethod - def get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: + def get_xinference_extra_parameter( + server_url: str, model_uid: str, api_key: str | None + ) -> XinferenceModelExtraParameter: XinferenceHelper._clean_cache() with cache_lock: if model_uid not in cache: @@ -66,7 +68,9 @@ def _clean_cache() -> None: pass @staticmethod - def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: + def _get_xinference_extra_parameter( + server_url: str, model_uid: str, api_key: str | None + ) -> XinferenceModelExtraParameter: """ get xinference model extra parameter like model_format and model_handle_type """ diff --git a/api/core/model_runtime/model_providers/yi/llm/llm.py b/api/core/model_runtime/model_providers/yi/llm/llm.py index 0642e72ed500e1..f5b61e207635bc 100644 --- a/api/core/model_runtime/model_providers/yi/llm/llm.py +++ b/api/core/model_runtime/model_providers/yi/llm/llm.py @@ -136,7 +136,7 @@ def _add_custom_parameters(credentials: dict) -> None: parsed_url = urlparse(credentials["endpoint_url"]) credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}" - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: return AIModelEntity( model=model, label=I18nObject(en_US=model, zh_Hans=model), diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index 59861507e45cd6..eef86cc52c36e8 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -1,9 +1,9 @@ from collections.abc import 
Generator from typing import Optional, Union -from zhipuai import ZhipuAI -from zhipuai.types.chat.chat_completion import Completion -from zhipuai.types.chat.chat_completion_chunk import ChatCompletionChunk +from zhipuai import ZhipuAI # type: ignore +from zhipuai.types.chat.chat_completion import Completion # type: ignore +from zhipuai.types.chat.chat_completion_chunk import ChatCompletionChunk # type: ignore from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index 2428284ba9a8ff..a700304db7b6f3 100644 --- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -1,7 +1,7 @@ import time from typing import Optional -from zhipuai import ZhipuAI +from zhipuai import ZhipuAI # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType diff --git a/api/core/model_runtime/schema_validators/common_validator.py b/api/core/model_runtime/schema_validators/common_validator.py index 029ec1a581b2e9..8cc8adfc3656ea 100644 --- a/api/core/model_runtime/schema_validators/common_validator.py +++ b/api/core/model_runtime/schema_validators/common_validator.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Union, cast from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType @@ -38,7 +38,7 @@ def _validate_and_filter_credential_form_schemas( def _validate_credential_form_schema( self, credential_form_schema: CredentialFormSchema, credentials: dict - ) -> Optional[str]: + ) -> Union[str, bool, None]: """ Validate credential form schema @@ -47,6 +47,7 @@ def _validate_credential_form_schema( :return: validated credential form schema value """ # If the variable does not exist in credentials + value: Union[str, bool, None] = None if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]: # If required is True, an exception is thrown if credential_form_schema.required: @@ -61,7 +62,7 @@ def _validate_credential_form_schema( return None # Get the value corresponding to the variable from credentials - value = credentials[credential_form_schema.variable] + value = cast(str, credentials[credential_form_schema.variable]) # If max_length=0, no validation is performed if credential_form_schema.max_length: diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py index ec1bad5698f2eb..03e350627140cf 100644 --- a/api/core/model_runtime/utils/encoders.py +++ b/api/core/model_runtime/utils/encoders.py @@ -129,7 +129,8 @@ def jsonable_encoder( sqlalchemy_safe=sqlalchemy_safe, ) if dataclasses.is_dataclass(obj): - obj_dict = dataclasses.asdict(obj) + # FIXME: mypy error, try to fix it instead of using type: ignore + obj_dict = dataclasses.asdict(obj) # type: ignore return jsonable_encoder( obj_dict, by_alias=by_alias, diff --git a/api/core/model_runtime/utils/helper.py b/api/core/model_runtime/utils/helper.py index 2067092d80f582..5e8a723ec7c510 100644 --- a/api/core/model_runtime/utils/helper.py +++ b/api/core/model_runtime/utils/helper.py @@ -4,6 +4,7 @@ def dump_model(model: BaseModel) -> dict: if hasattr(pydantic, 
"model_dump"): - return pydantic.model_dump(model) + # FIXME mypy error, try to fix it instead of using type: ignore + return pydantic.model_dump(model) # type: ignore else: return model.model_dump() diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index 094ad7863603dc..c65a3885fd1eb9 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -1,3 +1,5 @@ +from typing import Optional + from pydantic import BaseModel from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor @@ -43,6 +45,8 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: flagged = False preset_response = "" + if self.config is None: + raise ValueError("The config is not set.") if self.config["inputs_config"]["enabled"]: params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query) @@ -57,6 +61,8 @@ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInpu def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" + if self.config is None: + raise ValueError("The config is not set.") if self.config["outputs_config"]["enabled"]: params = ModerationOutputParams(app_id=self.app_id, text=text) @@ -69,14 +75,18 @@ def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: ) def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict: - extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id")) + if self.config is None: + raise ValueError("The config is not set.") + extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", "")) + if not extension: + raise ValueError("API-based Extension not found. 
Please check it again.") requestor = APIBasedExtensionRequestor(extension.api_endpoint, decrypt_token(self.tenant_id, extension.api_key)) result = requestor.request(extension_point, params) return result @staticmethod - def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: + def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]: extension = ( db.session.query(APIBasedExtension) .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 60898d5547ae3b..d8c392d0970e19 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -100,14 +100,14 @@ def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_re if not inputs_config.get("preset_response"): raise ValueError("inputs_config.preset_response is required") - if len(inputs_config.get("preset_response")) > 100: + if len(inputs_config.get("preset_response", "")) > 100: raise ValueError("inputs_config.preset_response must be less than 100 characters") if outputs_config_enabled: if not outputs_config.get("preset_response"): raise ValueError("outputs_config.preset_response is required") - if len(outputs_config.get("preset_response")) > 100: + if len(outputs_config.get("preset_response", "")) > 100: raise ValueError("outputs_config.preset_response must be less than 100 characters") diff --git a/api/core/moderation/factory.py b/api/core/moderation/factory.py index 96bf2ab54b41eb..0ad4438c143870 100644 --- a/api/core/moderation/factory.py +++ b/api/core/moderation/factory.py @@ -22,7 +22,8 @@ def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: """ code_based_extension.validate_form_schema(ExtensionModule.MODERATION, name, config) extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name) - extension_class.validate_config(tenant_id, config) + # FIXME: mypy error, try to fix it instead of using type: ignore + extension_class.validate_config(tenant_id, config) # type: ignore def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: """
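The recurring guard this patch adds across the moderation classes (if self.config is None: raise ValueError("The config is not set.")) is the standard mypy idiom for narrowing an Optional attribute: once the explicit None check raises, every later subscript is checked against dict rather than Optional[dict]. A minimal sketch of the pattern, using a hypothetical class rather than the patch's own:

    from typing import Optional

    class ModerationLike:
        def __init__(self, config: Optional[dict] = None) -> None:
            self.config = config

        def moderation_for_inputs(self) -> None:
            if self.config is None:
                raise ValueError("The config is not set.")
            # mypy narrows self.config from Optional[dict] to dict here, so the
            # subscripts below type-check without a type: ignore.
            if self.config["inputs_config"]["enabled"]:
                ...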
diff --git a/api/core/moderation/input_moderation.py b/api/core/moderation/input_moderation.py index 46d3963bd07f5a..3ac33966cb14bf 100644 --- a/api/core/moderation/input_moderation.py +++ b/api/core/moderation/input_moderation.py @@ -1,5 +1,6 @@ import logging -from typing import Optional +from collections.abc import Mapping +from typing import Any, Optional from core.app.app_config.entities import AppConfig from core.moderation.base import ModerationAction, ModerationError @@ -17,11 +18,11 @@ def check( app_id: str, tenant_id: str, app_config: AppConfig, - inputs: dict, + inputs: Mapping[str, Any], query: str, message_id: str, trace_manager: Optional[TraceQueueManager] = None, - ) -> tuple[bool, dict, str]: + ) -> tuple[bool, Mapping[str, Any], str]: """ Process sensitive_word_avoidance. :param app_id: app id @@ -33,6 +34,7 @@ def check( :param trace_manager: trace manager :return: """ + inputs = dict(inputs) if not app_config.sensitive_word_avoidance: return False, inputs, query diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index 00b3c56c03602d..9dd2665c3bf3d3 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -21,7 +21,7 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: if not config.get("keywords"): raise ValueError("keywords is required") - if len(config.get("keywords")) > 10000: + if len(config.get("keywords", "")) > 10000: raise ValueError("keywords length must be less than 10000") keywords_row_len = config["keywords"].split("\n") @@ -31,6 +31,8 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: flagged = False preset_response = "" + if self.config is None: + raise ValueError("The config is not set.") if self.config["inputs_config"]["enabled"]: preset_response = self.config["inputs_config"]["preset_response"] @@ -50,6 +52,8 @@ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInpu def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" + if self.config is None: + raise ValueError("The config is not set.") if self.config["outputs_config"]["enabled"]: # Filter out empty values diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index 6465de23b9a2de..d64f17b383e0b5 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -20,6 +20,8 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: flagged = False preset_response = "" + if self.config is None: + raise ValueError("The config is not set.") if self.config["inputs_config"]["enabled"]: preset_response = self.config["inputs_config"]["preset_response"] @@ -35,6 +37,8 @@ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInpu def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" + if self.config is None: + raise ValueError("The config is not set.") if self.config["outputs_config"]["enabled"]: flagged = self._is_violated({"text": text}) diff --git a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py index 4635bd9c251851..e595be126c7824 100644 --- a/api/core/moderation/output_moderation.py +++ b/api/core/moderation/output_moderation.py @@ -70,7 +70,7 @@ def start_thread(self) -> threading.Thread: thread = threading.Thread( target=self.worker, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "buffer_size": buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE, }, ) diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 71ff03b6ef5160..f0e34c0cd71241 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from datetime import datetime from enum import StrEnum from typing import Any, Optional, Union @@ -38,8 +39,8 @@ class
WorkflowTraceInfo(BaseTraceInfo): workflow_run_id: str workflow_run_elapsed_time: Union[int, float] workflow_run_status: str - workflow_run_inputs: dict[str, Any] - workflow_run_outputs: dict[str, Any] + workflow_run_inputs: Mapping[str, Any] + workflow_run_outputs: Mapping[str, Any] workflow_run_version: str error: Optional[str] = None total_tokens: int diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 29fdebd8feaeb8..b9ba068b19936d 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -77,8 +77,8 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): id=trace_id, user_id=user_id, name=name, - input=trace_info.workflow_run_inputs, - output=trace_info.workflow_run_outputs, + input=dict(trace_info.workflow_run_inputs), + output=dict(trace_info.workflow_run_outputs), metadata=metadata, session_id=trace_info.conversation_id, tags=["message", "workflow"], @@ -87,8 +87,8 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): workflow_span_data = LangfuseSpan( id=trace_info.workflow_run_id, name=TraceTaskName.WORKFLOW_TRACE.value, - input=trace_info.workflow_run_inputs, - output=trace_info.workflow_run_outputs, + input=dict(trace_info.workflow_run_inputs), + output=dict(trace_info.workflow_run_outputs), trace_id=trace_id, start_time=trace_info.start_time, end_time=trace_info.end_time, @@ -102,8 +102,8 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): id=trace_id, user_id=user_id, name=TraceTaskName.WORKFLOW_TRACE.value, - input=trace_info.workflow_run_inputs, - output=trace_info.workflow_run_outputs, + input=dict(trace_info.workflow_run_inputs), + output=dict(trace_info.workflow_run_outputs), metadata=metadata, session_id=trace_info.conversation_id, tags=["workflow"], diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py index 99221d669b3193..348b7ba5012b6b 100644 --- a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py +++ b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py @@ -49,7 +49,6 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run") input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run") output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run") - dotted_order: Optional[str] = Field(None, description="Dotted order of the run") @field_validator("inputs", "outputs") @classmethod diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 672843e5a8f986..4ffd888bddf8a3 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -3,6 +3,7 @@ import os import uuid from datetime import datetime, timedelta +from typing import Optional, cast from langsmith import Client from langsmith.schemas import RunBase @@ -63,6 +64,8 @@ def trace(self, trace_info: BaseTraceInfo): def workflow_trace(self, trace_info: WorkflowTraceInfo): trace_id = trace_info.message_id or trace_info.workflow_run_id + if trace_info.start_time is None: + trace_info.start_time = datetime.now() message_dotted_order = ( generate_dotted_order(trace_info.message_id, trace_info.start_time) if trace_info.message_id else None ) @@ -78,8 +81,8 @@ def 
workflow_trace(self, trace_info: WorkflowTraceInfo): message_run = LangSmithRunModel( id=trace_info.message_id, name=TraceTaskName.MESSAGE_TRACE.value, - inputs=trace_info.workflow_run_inputs, - outputs=trace_info.workflow_run_outputs, + inputs=dict(trace_info.workflow_run_inputs), + outputs=dict(trace_info.workflow_run_outputs), run_type=LangSmithRunType.chain, start_time=trace_info.start_time, end_time=trace_info.end_time, @@ -90,6 +93,15 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): error=trace_info.error, trace_id=trace_id, dotted_order=message_dotted_order, + file_list=[], + serialized=None, + parent_run_id=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, ) self.add_run(message_run) @@ -98,11 +110,11 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): total_tokens=trace_info.total_tokens, id=trace_info.workflow_run_id, name=TraceTaskName.WORKFLOW_TRACE.value, - inputs=trace_info.workflow_run_inputs, + inputs=dict(trace_info.workflow_run_inputs), run_type=LangSmithRunType.tool, start_time=trace_info.workflow_data.created_at, end_time=trace_info.workflow_data.finished_at, - outputs=trace_info.workflow_run_outputs, + outputs=dict(trace_info.workflow_run_outputs), extra={ "metadata": metadata, }, @@ -111,6 +123,13 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): parent_run_id=trace_info.message_id or None, trace_id=trace_id, dotted_order=workflow_dotted_order, + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, ) self.add_run(langsmith_run) @@ -211,25 +230,35 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): id=node_execution_id, trace_id=trace_id, dotted_order=node_dotted_order, + error="", + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, ) self.add_run(langsmith_run) def message_trace(self, trace_info: MessageTraceInfo): # get message file data - file_list = trace_info.file_list - message_file_data: MessageFile = trace_info.message_file_data + file_list = cast(list[str], trace_info.file_list) or [] + message_file_data: Optional[MessageFile] = trace_info.message_file_data file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" file_list.append(file_url) metadata = trace_info.metadata message_data = trace_info.message_data + if message_data is None: + return message_id = message_data.id user_id = message_data.from_account_id metadata["user_id"] = user_id if message_data.from_end_user_id: - end_user_data: EndUser = ( + end_user_data: Optional[EndUser] = ( db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: @@ -247,12 +276,20 @@ def message_trace(self, trace_info: MessageTraceInfo): start_time=trace_info.start_time, end_time=trace_info.end_time, outputs=message_data.answer, - extra={ - "metadata": metadata, - }, + extra={"metadata": metadata}, tags=["message", str(trace_info.conversation_mode)], error=trace_info.error, file_list=file_list, + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + parent_run_id=None, ) self.add_run(message_run) @@ -267,17 +304,27 @@ def message_trace(self, trace_info: MessageTraceInfo): 
start_time=trace_info.start_time, end_time=trace_info.end_time, outputs=message_data.answer, - extra={ - "metadata": metadata, - }, + extra={"metadata": metadata}, parent_run_id=message_id, tags=["llm", str(trace_info.conversation_mode)], error=trace_info.error, file_list=file_list, + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + id=str(uuid.uuid4()), ) self.add_run(llm_run) def moderation_trace(self, trace_info: ModerationTraceInfo): + if trace_info.message_data is None: + return langsmith_run = LangSmithRunModel( name=TraceTaskName.MODERATION_TRACE.value, inputs=trace_info.inputs, @@ -288,48 +335,82 @@ def moderation_trace(self, trace_info: ModerationTraceInfo): "inputs": trace_info.inputs, }, run_type=LangSmithRunType.tool, - extra={ - "metadata": trace_info.metadata, - }, + extra={"metadata": trace_info.metadata}, tags=["moderation"], parent_run_id=trace_info.message_id, start_time=trace_info.start_time or trace_info.message_data.created_at, end_time=trace_info.end_time or trace_info.message_data.updated_at, + id=str(uuid.uuid4()), + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + error="", + file_list=[], ) self.add_run(langsmith_run) def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo): message_data = trace_info.message_data + if message_data is None: + return suggested_question_run = LangSmithRunModel( name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value, inputs=trace_info.inputs, outputs=trace_info.suggested_question, run_type=LangSmithRunType.tool, - extra={ - "metadata": trace_info.metadata, - }, + extra={"metadata": trace_info.metadata}, tags=["suggested_question"], parent_run_id=trace_info.message_id, start_time=trace_info.start_time or message_data.created_at, end_time=trace_info.end_time or message_data.updated_at, + id=str(uuid.uuid4()), + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + error="", + file_list=[], ) self.add_run(suggested_question_run) def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo): + if trace_info.message_data is None: + return dataset_retrieval_run = LangSmithRunModel( name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value, inputs=trace_info.inputs, outputs={"documents": trace_info.documents}, run_type=LangSmithRunType.retriever, - extra={ - "metadata": trace_info.metadata, - }, + extra={"metadata": trace_info.metadata}, tags=["dataset_retrieval"], parent_run_id=trace_info.message_id, start_time=trace_info.start_time or trace_info.message_data.created_at, end_time=trace_info.end_time or trace_info.message_data.updated_at, + id=str(uuid.uuid4()), + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + error="", + file_list=[], ) self.add_run(dataset_retrieval_run) @@ -347,7 +428,18 @@ def tool_trace(self, trace_info: ToolTraceInfo): parent_run_id=trace_info.message_id, start_time=trace_info.start_time, end_time=trace_info.end_time, - file_list=[trace_info.file_url], + file_list=[cast(str, trace_info.file_url)], + id=str(uuid.uuid4()), + serialized=None, + events=[], + 
session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + error=trace_info.error or "", ) self.add_run(tool_run) @@ -358,12 +450,23 @@ def generate_name_trace(self, trace_info: GenerateNameTraceInfo): inputs=trace_info.inputs, outputs=trace_info.outputs, run_type=LangSmithRunType.tool, - extra={ - "metadata": trace_info.metadata, - }, + extra={"metadata": trace_info.metadata}, tags=["generate_name"], start_time=trace_info.start_time or datetime.now(), end_time=trace_info.end_time or datetime.now(), + id=str(uuid.uuid4()), + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + error="", + file_list=[], + parent_run_id=None, ) self.add_run(name_run) diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 4f41b6ed97047f..f538eaef5bd570 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -33,11 +33,11 @@ from core.ops.utils import get_message_data from extensions.ext_database import db from extensions.ext_storage import storage -from models.model import App, AppModelConfig, Conversation, Message, MessageAgentThought, MessageFile, TraceAppConfig +from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig from models.workflow import WorkflowAppLog, WorkflowRun from tasks.ops_trace_task import process_trace_tasks -provider_config_map = { +provider_config_map: dict[str, dict[str, Any]] = { TracingProviderEnum.LANGFUSE.value: { "config_class": LangfuseConfig, "secret_keys": ["public_key", "secret_key"], @@ -145,7 +145,7 @@ def get_decrypted_tracing_config(cls, app_id: str, tracing_provider: str): :param tracing_provider: tracing provider :return: """ - trace_config_data: TraceAppConfig = ( + trace_config_data: Optional[TraceAppConfig] = ( db.session.query(TraceAppConfig) .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() @@ -155,7 +155,11 @@ def get_decrypted_tracing_config(cls, app_id: str, tracing_provider: str): return None # decrypt_token - tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id + app = db.session.query(App).filter(App.id == app_id).first() + if not app: + raise ValueError("App not found") + + tenant_id = app.tenant_id decrypt_tracing_config = cls.decrypt_tracing_config( tenant_id, tracing_provider, trace_config_data.tracing_config ) @@ -178,7 +182,7 @@ def get_ops_trace_instance( if app_id is None: return None - app: App = db.session.query(App).filter(App.id == app_id).first() + app: Optional[App] = db.session.query(App).filter(App.id == app_id).first() if app is None: return None @@ -209,8 +213,12 @@ def get_ops_trace_instance( def get_app_config_through_message_id(cls, message_id: str): app_model_config = None message_data = db.session.query(Message).filter(Message.id == message_id).first() + if not message_data: + return None conversation_id = message_data.conversation_id conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() + if not conversation_data: + return None if conversation_data.app_model_config_id: app_model_config = ( @@ -236,7 +244,9 @@ def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: if tracing_provider not in provider_config_map and tracing_provider is not None: raise 
ValueError(f"Invalid tracing provider: {tracing_provider}") - app_config: App = db.session.query(App).filter(App.id == app_id).first() + app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first() + if not app_config: + raise ValueError("App not found") app_config.tracing = json.dumps( { "enabled": enabled, @@ -252,7 +262,9 @@ def get_app_tracing_config(cls, app_id: str): :param app_id: app id :return: """ - app: App = db.session.query(App).filter(App.id == app_id).first() + app: Optional[App] = db.session.query(App).filter(App.id == app_id).first() + if not app: + raise ValueError("App not found") if not app.tracing: return {"enabled": False, "tracing_provider": None} app_trace_config = json.loads(app.tracing) @@ -483,6 +495,8 @@ def message_trace(self, message_id): def moderation_trace(self, message_id, timer, **kwargs): moderation_result = kwargs.get("moderation_result") + if not moderation_result: + return {} inputs = kwargs.get("inputs") message_data = get_message_data(message_id) if not message_data: @@ -518,7 +532,7 @@ def moderation_trace(self, message_id, timer, **kwargs): return moderation_trace_info def suggested_question_trace(self, message_id, timer, **kwargs): - suggested_question = kwargs.get("suggested_question") + suggested_question = kwargs.get("suggested_question", []) message_data = get_message_data(message_id) if not message_data: return {} @@ -586,7 +600,7 @@ def dataset_retrieval_trace(self, message_id, timer, **kwargs): dataset_retrieval_trace_info = DatasetRetrievalTraceInfo( message_id=message_id, inputs=message_data.query or message_data.inputs, - documents=[doc.model_dump() for doc in documents], + documents=[doc.model_dump() for doc in documents] if documents else [], start_time=timer.get("start"), end_time=timer.get("end"), metadata=metadata, @@ -596,9 +610,9 @@ def dataset_retrieval_trace(self, message_id, timer, **kwargs): return dataset_retrieval_trace_info def tool_trace(self, message_id, timer, **kwargs): - tool_name = kwargs.get("tool_name") - tool_inputs = kwargs.get("tool_inputs") - tool_outputs = kwargs.get("tool_outputs") + tool_name = kwargs.get("tool_name", "") + tool_inputs = kwargs.get("tool_inputs", {}) + tool_outputs = kwargs.get("tool_outputs", {}) message_data = get_message_data(message_id) if not message_data: return {} @@ -608,7 +622,7 @@ def tool_trace(self, message_id, timer, **kwargs): tool_parameters = {} created_time = message_data.created_at end_time = message_data.updated_at - agent_thoughts: list[MessageAgentThought] = message_data.agent_thoughts + agent_thoughts = message_data.agent_thoughts for agent_thought in agent_thoughts: if tool_name in agent_thought.tools: created_time = agent_thought.created_at @@ -672,6 +686,8 @@ def generate_name_trace(self, conversation_id, timer, **kwargs): generate_conversation_name = kwargs.get("generate_conversation_name") inputs = kwargs.get("inputs") tenant_id = kwargs.get("tenant_id") + if not tenant_id: + return {} start_time = timer.get("start") end_time = timer.get("end") @@ -693,8 +709,8 @@ def generate_name_trace(self, conversation_id, timer, **kwargs): return generate_name_trace_info -trace_manager_timer = None -trace_manager_queue = queue.Queue() +trace_manager_timer: Optional[threading.Timer] = None +trace_manager_queue: queue.Queue = queue.Queue() trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5)) trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100)) @@ -706,7 +722,7 @@ def __init__(self, app_id=None, 
user_id=None): self.app_id = app_id self.user_id = user_id self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id) - self.flask_app = current_app._get_current_object() + self.flask_app = current_app._get_current_object() # type: ignore if trace_manager_timer is None: self.start_timer() @@ -723,7 +739,7 @@ def add_trace_task(self, trace_task: TraceTask): def collect_tasks(self): global trace_manager_queue - tasks = [] + tasks: list[TraceTask] = [] while len(tasks) < trace_manager_batch_size and not trace_manager_queue.empty(): task = trace_manager_queue.get_nowait() tasks.append(task) @@ -749,6 +765,8 @@ def start_timer(self): def send_to_celery(self, tasks: list[TraceTask]): with self.flask_app.app_context(): for task in tasks: + if task.app_id is None: + continue file_id = uuid4().hex trace_info = task.execute() task_data = TaskData( diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 0f3f8249661bf0..87c7a79fb01201 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,5 +1,5 @@ -from collections.abc import Sequence -from typing import Optional +from collections.abc import Mapping, Sequence +from typing import Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import file_manager @@ -39,7 +39,7 @@ def get_prompt( self, *, prompt_template: Sequence[ChatModelMessage] | CompletionModelPromptTemplate, - inputs: dict[str, str], + inputs: Mapping[str, str], query: str, files: Sequence[File], context: Optional[str], @@ -77,7 +77,7 @@ def get_prompt( def _get_completion_model_prompt_messages( self, prompt_template: CompletionModelPromptTemplate, - inputs: dict, + inputs: Mapping[str, str], query: Optional[str], files: Sequence[File], context: Optional[str], @@ -90,15 +90,15 @@ def _get_completion_model_prompt_messages( """ raw_prompt = prompt_template.text - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] if prompt_template.edition_type == "basic" or not prompt_template.edition_type: parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) - prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} + prompt_inputs: Mapping[str, str] = {k: inputs[k] for k in parser.variable_keys if k in inputs} prompt_inputs = self._set_context_variable(context, parser, prompt_inputs) - if memory and memory_config: + if memory and memory_config and memory_config.role_prefix: role_prefix = memory_config.role_prefix prompt_inputs = self._set_histories_variable( memory=memory, @@ -135,7 +135,7 @@ def _get_completion_model_prompt_messages( def _get_chat_model_prompt_messages( self, prompt_template: list[ChatModelMessage], - inputs: dict, + inputs: Mapping[str, str], query: Optional[str], files: Sequence[File], context: Optional[str], @@ -146,7 +146,7 @@ def _get_chat_model_prompt_messages( """ Get chat model prompt messages. 
""" - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] for prompt_item in prompt_template: raw_prompt = prompt_item.text @@ -160,7 +160,7 @@ def _get_chat_model_prompt_messages( prompt = vp.convert_template(raw_prompt).text else: parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) - prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} + prompt_inputs: Mapping[str, str] = {k: inputs[k] for k in parser.variable_keys if k in inputs} prompt_inputs = self._set_context_variable( context=context, parser=parser, prompt_inputs=prompt_inputs ) @@ -207,7 +207,7 @@ def _get_chat_model_prompt_messages( last_message = prompt_messages[-1] if prompt_messages else None if last_message and last_message.role == PromptMessageRole.USER: # get last user message content and add files - prompt_message_contents = [TextPromptMessageContent(data=last_message.content)] + prompt_message_contents = [TextPromptMessageContent(data=cast(str, last_message.content))] for file in files: prompt_message_contents.append(file_manager.to_prompt_message_content(file)) @@ -229,7 +229,10 @@ def _get_chat_model_prompt_messages( return prompt_messages - def _set_context_variable(self, context: str | None, parser: PromptTemplateParser, prompt_inputs: dict) -> dict: + def _set_context_variable( + self, context: str | None, parser: PromptTemplateParser, prompt_inputs: Mapping[str, str] + ) -> Mapping[str, str]: + prompt_inputs = dict(prompt_inputs) if "#context#" in parser.variable_keys: if context: prompt_inputs["#context#"] = context @@ -238,7 +241,10 @@ def _set_context_variable(self, context: str | None, parser: PromptTemplateParse return prompt_inputs - def _set_query_variable(self, query: str, parser: PromptTemplateParser, prompt_inputs: dict) -> dict: + def _set_query_variable( + self, query: str, parser: PromptTemplateParser, prompt_inputs: Mapping[str, str] + ) -> Mapping[str, str]: + prompt_inputs = dict(prompt_inputs) if "#query#" in parser.variable_keys: if query: prompt_inputs["#query#"] = query @@ -254,9 +260,10 @@ def _set_histories_variable( raw_prompt: str, role_prefix: MemoryConfig.RolePrefix, parser: PromptTemplateParser, - prompt_inputs: dict, + prompt_inputs: Mapping[str, str], model_config: ModelConfigWithCredentialsEntity, - ) -> dict: + ) -> Mapping[str, str]: + prompt_inputs = dict(prompt_inputs) if "#histories#" in parser.variable_keys: if memory: inputs = {"#histories#": "", **prompt_inputs} diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index caa1793ea8c039..09f017a7db0d3a 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -31,7 +31,7 @@ def __init__( self.memory = memory def get_prompt(self) -> list[PromptMessage]: - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] num_system = 0 for prompt_message in self.history_messages: if isinstance(prompt_message, SystemPromptMessage): diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 87acdb3c49cc01..1f040599be6dac 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Optional from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -42,7 +42,7 @@ def _calculate_rest_token( ): max_tokens = ( 
model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template) + or model_config.parameters.get(parameter_rule.use_template or "") ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens @@ -59,7 +59,7 @@ def _get_history_messages_from_memory( ai_prefix: Optional[str] = None, ) -> str: """Get memory messages.""" - kwargs = {"max_token_limit": max_token_limit} + kwargs: dict[str, Any] = {"max_token_limit": max_token_limit} if human_prefix: kwargs["human_prefix"] = human_prefix @@ -76,11 +76,15 @@ def _get_history_messages_list_from_memory( self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int ) -> list[PromptMessage]: """Get memory messages.""" - return memory.get_history_prompt_messages( - max_token_limit=max_token_limit, - message_limit=memory_config.window.size - if ( - memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0 + return list( + memory.get_history_prompt_messages( + max_token_limit=max_token_limit, + message_limit=memory_config.window.size + if ( + memory_config.window.enabled + and memory_config.window.size is not None + and memory_config.window.size > 0 + ) + else None, ) - else None, ) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 93dd92f188a9c6..e75877de9b695c 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -1,7 +1,8 @@ import enum import json import os -from typing import TYPE_CHECKING, Optional +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Any, Optional, cast from core.app.app_config.entities import PromptTemplateEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity @@ -41,7 +42,7 @@ def value_of(cls, value: str) -> "ModelMode": raise ValueError(f"invalid mode value {value}") -prompt_file_contents = {} +prompt_file_contents: dict[str, Any] = {} class SimplePromptTransform(PromptTransform): @@ -53,9 +54,9 @@ def get_prompt( self, app_mode: AppMode, prompt_template_entity: PromptTemplateEntity, - inputs: dict, + inputs: Mapping[str, str], query: str, - files: list["File"], + files: Sequence["File"], context: Optional[str], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, @@ -66,7 +67,7 @@ def get_prompt( if model_mode == ModelMode.CHAT: prompt_messages, stops = self._get_chat_model_prompt_messages( app_mode=app_mode, - pre_prompt=prompt_template_entity.simple_prompt_template, + pre_prompt=prompt_template_entity.simple_prompt_template or "", inputs=inputs, query=query, files=files, @@ -77,7 +78,7 @@ def get_prompt( else: prompt_messages, stops = self._get_completion_model_prompt_messages( app_mode=app_mode, - pre_prompt=prompt_template_entity.simple_prompt_template, + pre_prompt=prompt_template_entity.simple_prompt_template or "", inputs=inputs, query=query, files=files, @@ -171,11 +172,11 @@ def _get_chat_model_prompt_messages( inputs: dict, query: str, context: Optional[str], - files: list["File"], + files: Sequence["File"], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, ) -> tuple[list[PromptMessage], Optional[list[str]]]: - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] # get prompt prompt, _ = self.get_prompt_str_and_rules( @@ -216,7 +217,7 @@ def _get_completion_model_prompt_messages( inputs: dict, query: str, context: Optional[str], - files: 
list["File"], + files: Sequence["File"], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, ) -> tuple[list[PromptMessage], Optional[list[str]]]: @@ -263,7 +264,7 @@ def _get_completion_model_prompt_messages( return [self.get_last_user_message(prompt, files)], stops - def get_last_user_message(self, prompt: str, files: list["File"]) -> UserPromptMessage: + def get_last_user_message(self, prompt: str, files: Sequence["File"]) -> UserPromptMessage: if files: prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents.append(TextPromptMessageContent(data=prompt)) @@ -288,7 +289,7 @@ def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict # Check if the prompt file is already loaded if prompt_file_name in prompt_file_contents: - return prompt_file_contents[prompt_file_name] + return cast(dict, prompt_file_contents[prompt_file_name]) # Get the absolute path of the subdirectory prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates") @@ -301,7 +302,7 @@ def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict # Store the content of the prompt file prompt_file_contents[prompt_file_name] = content - return content + return cast(dict, content) def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str: # baichuan diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index aa175153bc633f..2f4e65146131be 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import cast +from typing import Any, cast from core.model_runtime.entities import ( AssistantPromptMessage, @@ -72,7 +72,7 @@ def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Seque } ) else: - text = prompt_message.content + text = cast(str, prompt_message.content) prompt = {"role": role, "text": text, "files": files} @@ -99,9 +99,9 @@ def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Seque } ) else: - text = prompt_message.content + text = cast(str, prompt_message.content) - params = { + params: dict[str, Any] = { "role": "user", "text": text, } diff --git a/api/core/prompt/utils/prompt_template_parser.py b/api/core/prompt/utils/prompt_template_parser.py index 0fd08c5d3c1a3e..8e40674bc193e0 100644 --- a/api/core/prompt/utils/prompt_template_parser.py +++ b/api/core/prompt/utils/prompt_template_parser.py @@ -1,4 +1,5 @@ import re +from collections.abc import Mapping REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{0,29}|#histories#|#query#|#context#)\}\}") WITH_VARIABLE_TMPL_REGEX = re.compile( @@ -28,7 +29,7 @@ def extract(self) -> list: # Regular expression to match the template rules return re.findall(self.regex, self.template) - def format(self, inputs: dict, remove_template_variables: bool = True) -> str: + def format(self, inputs: Mapping[str, str], remove_template_variables: bool = True) -> str: def replacer(match): key = match.group(1) value = inputs.get(key, match.group(0)) # return original matched string if key not found diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 3a1fe300dfd311..010abd12d275cd 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -1,7 +1,7 @@ import json from collections import defaultdict from json import JSONDecodeError -from typing import Optional +from typing import Optional, 
cast from sqlalchemy.exc import IntegrityError @@ -15,6 +15,7 @@ ModelLoadBalancingConfiguration, ModelSettings, QuotaConfiguration, + QuotaUnit, SystemConfiguration, ) from core.helper import encrypter @@ -116,8 +117,8 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations: for provider_entity in provider_entities: # handle include, exclude if is_filtered( - include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, - exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET, + include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET), + exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET), data=provider_entity, name_func=lambda x: x.provider, ): @@ -490,12 +491,13 @@ def _init_trial_provider_records( # Init trial provider records if not exists if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict: try: + # FIXME: ignore the type error; only TrialHostingQuota has a limit, the logic needs to change provider_record = Provider( tenant_id=tenant_id, provider_name=provider_name, provider_type=ProviderType.SYSTEM.value, quota_type=ProviderQuotaType.TRIAL.value, - quota_limit=quota.quota_limit, + quota_limit=quota.quota_limit, # type: ignore quota_used=0, is_valid=True, ) @@ -589,7 +591,9 @@ def _to_custom_configuration( if variable in provider_credentials: try: provider_credentials[variable] = encrypter.decrypt_token_with_decoding( - provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa + provider_credentials.get(variable) or "", # type: ignore + self.decoding_rsa_key, + self.decoding_cipher_rsa, ) except ValueError: pass @@ -671,13 +675,9 @@ def _to_system_configuration( # Get hosting configuration hosting_configuration = ext_hosting_provider.hosting_configuration - if ( - provider_entity.provider not in hosting_configuration.provider_map - or not hosting_configuration.provider_map.get(provider_entity.provider).enabled - ): - return SystemConfiguration(enabled=False) - provider_hosting_configuration = hosting_configuration.provider_map.get(provider_entity.provider) + if provider_hosting_configuration is None or not provider_hosting_configuration.enabled: + return SystemConfiguration(enabled=False) # Convert provider_records to dict quota_type_to_provider_records_dict = {} @@ -688,14 +688,13 @@ def _to_system_configuration( quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = ( provider_record ) - quota_configurations = [] for provider_quota in provider_hosting_configuration.quotas: if provider_quota.quota_type not in quota_type_to_provider_records_dict: if provider_quota.quota_type == ProviderQuotaType.FREE: quota_configuration = QuotaConfiguration( quota_type=provider_quota.quota_type, - quota_unit=provider_hosting_configuration.quota_unit, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, quota_used=0, quota_limit=0, is_valid=False, ) @@ -708,7 +707,7 @@ def _to_system_configuration( quota_configuration = QuotaConfiguration( quota_type=provider_quota.quota_type, - quota_unit=provider_hosting_configuration.quota_unit, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, quota_used=provider_record.quota_used, quota_limit=provider_record.quota_limit, is_valid=provider_record.quota_limit > provider_record.quota_used @@ -725,12 +724,12 @@ def _to_system_configuration( current_using_credentials = provider_hosting_configuration.credentials if current_quota_type == ProviderQuotaType.FREE: - provider_record = 
quota_type_to_provider_records_dict.get(current_quota_type) + provider_record_quota_free = quota_type_to_provider_records_dict.get(current_quota_type) - if provider_record: + if provider_record_quota_free: provider_credentials_cache = ProviderCredentialsCache( tenant_id=tenant_id, - identity_id=provider_record.id, + identity_id=provider_record_quota_free.id, cache_type=ProviderCredentialsCacheType.PROVIDER, ) @@ -763,7 +762,7 @@ def _to_system_configuration( except ValueError: pass - current_using_credentials = provider_credentials + current_using_credentials = provider_credentials or {} # cache provider credentials provider_credentials_cache.set(credentials=current_using_credentials) @@ -842,7 +841,7 @@ def _to_model_settings( else [] ) - model_settings = [] + model_settings: list[ModelSettings] = [] if not provider_model_settings: return model_settings diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index a0153c1e58a1a8..95a2316f1da4dd 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -32,8 +32,11 @@ def create(self, texts: list[Document], **kwargs) -> BaseKeyword: keywords = keyword_table_handler.extract_keywords( text.page_content, self._config.max_keywords_per_chunk ) - self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) - keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords)) + if text.metadata is not None: + self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) + keyword_table = self._add_text_to_keyword_table( + keyword_table or {}, text.metadata["doc_id"], list(keywords) + ) self._save_dataset_keyword_table(keyword_table) @@ -58,20 +61,26 @@ def add_texts(self, texts: list[Document], **kwargs): keywords = keyword_table_handler.extract_keywords( text.page_content, self._config.max_keywords_per_chunk ) - self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) - keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords)) + if text.metadata is not None: + self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) + keyword_table = self._add_text_to_keyword_table( + keyword_table or {}, text.metadata["doc_id"], list(keywords) + ) self._save_dataset_keyword_table(keyword_table) def text_exists(self, id: str) -> bool: keyword_table = self._get_dataset_keyword_table() + if keyword_table is None: + return False return id in set.union(*keyword_table.values()) def delete_by_ids(self, ids: list[str]) -> None: lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) with redis_client.lock(lock_name, timeout=600): keyword_table = self._get_dataset_keyword_table() - keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) + if keyword_table is not None: + keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) self._save_dataset_keyword_table(keyword_table) @@ -80,7 +89,7 @@ def search(self, query: str, **kwargs: Any) -> list[Document]: k = kwargs.get("top_k", 4) - sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k) + sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k) documents = [] for chunk_index in sorted_chunk_indices: @@ -137,7 +146,7 @@ def _get_dataset_keyword_table(self) -> Optional[dict]: if dataset_keyword_table: keyword_table_dict = 
dataset_keyword_table.keyword_table_dict if keyword_table_dict: - return keyword_table_dict["__data__"]["table"] + return dict(keyword_table_dict["__data__"]["table"]) else: keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE dataset_keyword_table = DatasetKeywordTable( @@ -188,8 +197,8 @@ def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4): # go through text chunks in order of most matching keywords chunk_indices_count: dict[str, int] = defaultdict(int) - keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())] - for keyword in keywords: + keywords_list = [keyword for keyword in keywords if keyword in set(keyword_table.keys())] + for keyword in keywords_list: for node_id in keyword_table[keyword]: chunk_indices_count[node_id] += 1 @@ -215,7 +224,7 @@ def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list def create_segment_keywords(self, node_id: str, keywords: list[str]): keyword_table = self._get_dataset_keyword_table() self._update_segment_keywords(self.dataset.id, node_id, keywords) - keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords) + keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords) self._save_dataset_keyword_table(keyword_table) def multi_create_segment_keywords(self, pre_segment_data_list: list): @@ -226,17 +235,19 @@ def multi_create_segment_keywords(self, pre_segment_data_list: list): if pre_segment_data["keywords"]: segment.keywords = pre_segment_data["keywords"] keyword_table = self._add_text_to_keyword_table( - keyword_table, segment.index_node_id, pre_segment_data["keywords"] + keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"] ) else: keywords = keyword_table_handler.extract_keywords(segment.content, self._config.max_keywords_per_chunk) segment.keywords = list(keywords) - keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords)) + keyword_table = self._add_text_to_keyword_table( + keyword_table or {}, segment.index_node_id, list(keywords) + ) self._save_dataset_keyword_table(keyword_table) def update_segment_keywords_index(self, node_id: str, keywords: list[str]): keyword_table = self._get_dataset_keyword_table() - keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords) + keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords) self._save_dataset_keyword_table(keyword_table) diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py index ec809cf325306e..8b17e8dc0a3762 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py +++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py @@ -4,7 +4,7 @@ class JiebaKeywordTableHandler: def __init__(self): - import jieba.analyse + import jieba.analyse # type: ignore from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS @@ -12,7 +12,7 @@ def __init__(self): def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]: """Extract keywords with JIEBA tfidf.""" - import jieba + import jieba # type: ignore keywords = jieba.analyse.extract_tags( sentence=text, diff --git a/api/core/rag/datasource/keyword/keyword_base.py b/api/core/rag/datasource/keyword/keyword_base.py index be00687abd5025..b261b40b728692 100644 --- a/api/core/rag/datasource/keyword/keyword_base.py +++ 
b/api/core/rag/datasource/keyword/keyword_base.py @@ -37,6 +37,8 @@ def search(self, query: str, **kwargs: Any) -> list[Document]: def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: for text in texts.copy(): + if text.metadata is None: + continue doc_id = text.metadata["doc_id"] exists_duplicate_node = self.text_exists(doc_id) if exists_duplicate_node: @@ -45,4 +47,4 @@ def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: return texts def _get_uuids(self, texts: list[Document]) -> list[str]: - return [text.metadata["doc_id"] for text in texts] + return [text.metadata["doc_id"] for text in texts if text.metadata] diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 18f8d4e8392302..34343ad60ea4c1 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -6,6 +6,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db @@ -31,7 +32,7 @@ def retrieve( top_k: int, score_threshold: Optional[float] = 0.0, reranking_model: Optional[dict] = None, - reranking_mode: Optional[str] = "reranking_model", + reranking_mode: str = "reranking_model", weights: Optional[dict] = None, ): if not query: @@ -42,15 +43,15 @@ def retrieve( if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0: return [] - all_documents = [] - threads = [] - exceptions = [] + all_documents: list[Document] = [] + threads: list[threading.Thread] = [] + exceptions: list[str] = [] # retrieval_model source with keyword if retrieval_method == "keyword_search": keyword_thread = threading.Thread( target=RetrievalService.keyword_search, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "dataset_id": dataset_id, "query": query, "top_k": top_k, @@ -65,7 +66,7 @@ def retrieve( embedding_thread = threading.Thread( target=RetrievalService.embedding_search, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "dataset_id": dataset_id, "query": query, "top_k": top_k, @@ -84,7 +85,7 @@ def retrieve( full_text_index_thread = threading.Thread( target=RetrievalService.full_text_index_search, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "dataset_id": dataset_id, "query": query, "retrieval_method": retrieval_method, @@ -124,7 +125,7 @@ def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model if not dataset: return [] all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( - dataset.tenant_id, dataset_id, query, external_retrieval_model + dataset.tenant_id, dataset_id, query, external_retrieval_model or {} ) return all_documents @@ -135,6 +136,8 @@ def keyword_search( with flask_app.app_context(): try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + raise ValueError("dataset not found") keyword = Keyword(dataset=dataset) @@ -159,6 +162,8 @@ def embedding_search( with flask_app.app_context(): try: dataset 
= db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + raise ValueError("dataset not found") vector = Vector(dataset=dataset) @@ -209,6 +214,8 @@ def full_text_index_search( with flask_app.app_context(): try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + raise ValueError("dataset not found") vector_processor = Vector( dataset=dataset, diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py index 09104ae4223443..603d3fdbcdf1ab 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py @@ -17,12 +17,19 @@ class AnalyticdbVector(BaseVector): def __init__( - self, collection_name: str, api_config: AnalyticdbVectorOpenAPIConfig, sql_config: AnalyticdbVectorBySqlConfig + self, + collection_name: str, + api_config: AnalyticdbVectorOpenAPIConfig | None, + sql_config: AnalyticdbVectorBySqlConfig | None, ): super().__init__(collection_name) if api_config is not None: - self.analyticdb_vector = AnalyticdbVectorOpenAPI(collection_name, api_config) + self.analyticdb_vector: AnalyticdbVectorOpenAPI | AnalyticdbVectorBySql = AnalyticdbVectorOpenAPI( + collection_name, api_config + ) else: + if sql_config is None: + raise ValueError("Either api_config or sql_config must be provided") self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config) def get_type(self) -> str: @@ -33,8 +40,8 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) self.analyticdb_vector._create_collection_if_not_exists(dimension) self.analyticdb_vector.add_texts(texts, embeddings) - def add_texts(self, texts: list[Document], embeddings: list[list[float]], **kwargs): - self.analyticdb_vector.add_texts(texts, embeddings) + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + self.analyticdb_vector.add_texts(documents, embeddings) def text_exists(self, id: str) -> bool: return self.analyticdb_vector.text_exists(id) @@ -68,13 +75,13 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings if dify_config.ANALYTICDB_HOST is None: # implemented through OpenAPI apiConfig = AnalyticdbVectorOpenAPIConfig( - access_key_id=dify_config.ANALYTICDB_KEY_ID, - access_key_secret=dify_config.ANALYTICDB_KEY_SECRET, - region_id=dify_config.ANALYTICDB_REGION_ID, - instance_id=dify_config.ANALYTICDB_INSTANCE_ID, - account=dify_config.ANALYTICDB_ACCOUNT, - account_password=dify_config.ANALYTICDB_PASSWORD, - namespace=dify_config.ANALYTICDB_NAMESPACE, + access_key_id=dify_config.ANALYTICDB_KEY_ID or "", + access_key_secret=dify_config.ANALYTICDB_KEY_SECRET or "", + region_id=dify_config.ANALYTICDB_REGION_ID or "", + instance_id=dify_config.ANALYTICDB_INSTANCE_ID or "", + account=dify_config.ANALYTICDB_ACCOUNT or "", + account_password=dify_config.ANALYTICDB_PASSWORD or "", + namespace=dify_config.ANALYTICDB_NAMESPACE or "", namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD, ) sqlConfig = None @@ -83,11 +90,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings sqlConfig = AnalyticdbVectorBySqlConfig( host=dify_config.ANALYTICDB_HOST, port=dify_config.ANALYTICDB_PORT, - account=dify_config.ANALYTICDB_ACCOUNT, - account_password=dify_config.ANALYTICDB_PASSWORD, + account=dify_config.ANALYTICDB_ACCOUNT or "", + account_password=dify_config.ANALYTICDB_PASSWORD or "", 
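
Note: the `or ""` fallbacks above are the narrowing pattern this patch applies across the vector-store factories: the dify_config fields are declared Optional, while the config models require plain str. A minimal sketch of the pattern, with illustrative names rather than Dify's real classes:

    from typing import Optional

    from pydantic import BaseModel


    class VectorStoreConfig(BaseModel):
        # the target model requires concrete strings
        account: str
        account_password: str


    def build_config(account: Optional[str], password: Optional[str]) -> VectorStoreConfig:
        # Passing the Optional values straight through makes mypy report
        # 'incompatible type "Optional[str]"; expected "str"', so they are
        # narrowed with a fallback first.
        return VectorStoreConfig(account=account or "", account_password=password or "")

The trade-off is that `or ""` also swallows empty strings and defers the failure to connect time; an explicit None check that raises would be the stricter alternative.
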
min_connection=dify_config.ANALYTICDB_MIN_CONNECTION, max_connection=dify_config.ANALYTICDB_MAX_CONNECTION, - namespace=dify_config.ANALYTICDB_NAMESPACE, + namespace=dify_config.ANALYTICDB_NAMESPACE or "", ) apiConfig = None return AnalyticdbVector( diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py index 05e0ebc54f7c4c..095752ea8eaa42 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py @@ -1,5 +1,5 @@ import json -from typing import Any +from typing import Any, Optional from pydantic import BaseModel, model_validator @@ -20,7 +20,7 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel): account: str account_password: str namespace: str = "dify" - namespace_password: str = (None,) + namespace_password: Optional[str] = None metrics: str = "cosine" read_timeout: int = 60000 @@ -55,8 +55,8 @@ def to_analyticdb_client_params(self): class AnalyticdbVectorOpenAPI: def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig): try: - from alibabacloud_gpdb20160503.client import Client - from alibabacloud_tea_openapi import models as open_api_models + from alibabacloud_gpdb20160503.client import Client # type: ignore + from alibabacloud_tea_openapi import models as open_api_models # type: ignore except: raise ImportError(_import_err_msg) self._collection_name = collection_name.lower() @@ -77,7 +77,7 @@ def _initialize(self) -> None: redis_client.set(database_exist_cache_key, 1, ex=3600) def _initialize_vector_database(self) -> None: - from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + from alibabacloud_gpdb20160503 import models as gpdb_20160503_models # type: ignore request = gpdb_20160503_models.InitVectorDatabaseRequest( dbinstance_id=self.config.instance_id, @@ -89,7 +89,7 @@ def _initialize_vector_database(self) -> None: def _create_namespace_if_not_exists(self) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - from Tea.exceptions import TeaException + from Tea.exceptions import TeaException # type: ignore try: request = gpdb_20160503_models.DescribeNamespaceRequest( @@ -159,17 +159,18 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = [] for doc, embedding in zip(documents, embeddings, strict=True): - metadata = { - "ref_doc_id": doc.metadata["doc_id"], - "page_content": doc.page_content, - "metadata_": json.dumps(doc.metadata), - } - rows.append( - gpdb_20160503_models.UpsertCollectionDataRequestRows( - vector=embedding, - metadata=metadata, + if doc.metadata is not None: + metadata = { + "ref_doc_id": doc.metadata["doc_id"], + "page_content": doc.page_content, + "metadata_": json.dumps(doc.metadata), + } + rows.append( + gpdb_20160503_models.UpsertCollectionDataRequestRows( + vector=embedding, + metadata=metadata, + ) ) - ) request = gpdb_20160503_models.UpsertCollectionDataRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -258,7 +259,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc metadata=metadata, ) documents.append(doc) - documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) + documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) return documents def search_by_full_text(self, query: 
str, **kwargs: Any) -> list[Document]: @@ -290,7 +291,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: metadata=metadata, ) documents.append(doc) - documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) + documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) return documents def delete(self) -> None: diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py index e474db5cb21971..4d8f7929413cf2 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py @@ -3,8 +3,8 @@ from contextlib import contextmanager from typing import Any -import psycopg2.extras -import psycopg2.pool +import psycopg2.extras # type: ignore +import psycopg2.pool # type: ignore from pydantic import BaseModel, model_validator from core.rag.models.document import Document @@ -75,6 +75,7 @@ def _create_connection_pool(self): @contextmanager def _get_cursor(self): + assert self.pool is not None, "Connection pool is not initialized" conn = self.pool.getconn() cur = conn.cursor() try: @@ -156,16 +157,17 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s)); """ for i, doc in enumerate(documents): - values.append( - ( - id_prefix + str(i), - doc.metadata.get("doc_id", str(uuid.uuid4())), - embeddings[i], - doc.page_content, - json.dumps(doc.metadata), - doc.page_content, + if doc.metadata is not None: + values.append( + ( + id_prefix + str(i), + doc.metadata.get("doc_id", str(uuid.uuid4())), + embeddings[i], + doc.page_content, + json.dumps(doc.metadata), + doc.page_content, + ) ) - ) with self._get_cursor() as cur: psycopg2.extras.execute_batch(cur, sql, values) diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index eb78e8aa698b9b..85596ad20e099a 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -5,13 +5,13 @@ import numpy as np from pydantic import BaseModel, model_validator -from pymochow import MochowClient -from pymochow.auth.bce_credentials import BceCredentials -from pymochow.configuration import Configuration -from pymochow.exception import ServerError -from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState -from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex -from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row +from pymochow import MochowClient # type: ignore +from pymochow.auth.bce_credentials import BceCredentials # type: ignore +from pymochow.configuration import Configuration # type: ignore +from pymochow.exception import ServerError # type: ignore +from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore +from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex # type: ignore +from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row # type: ignore from configs import dify_config from core.rag.datasource.vdb.vector_base import BaseVector @@ -75,7 +75,7 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): texts = 
[doc.page_content for doc in documents] - metadatas = [doc.metadata for doc in documents] + metadatas = [doc.metadata for doc in documents if doc.metadata is not None] total_count = len(documents) batch_size = 1000 @@ -84,6 +84,8 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** for start in range(0, total_count, batch_size): end = min(start + batch_size, total_count) rows = [] + assert len(metadatas) == total_count, "metadatas length should be equal to total_count" + # FIXME: is this assert actually needed? for i in range(start, end, 1): row = Row( id=metadatas[i].get("doc_id", str(uuid.uuid4())), @@ -136,7 +138,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: # baidu vector database doesn't support bm25 search on current version return [] - def _get_search_res(self, res, score_threshold): + def _get_search_res(self, res, score_threshold) -> list[Document]: docs = [] for row in res.rows: row_data = row.get("row", {}) @@ -276,11 +278,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return BaiduVector( collection_name=collection_name, config=BaiduConfig( - endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT, + endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT or "", connection_timeout_in_mills=dify_config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS, - account=dify_config.BAIDU_VECTOR_DB_ACCOUNT, - api_key=dify_config.BAIDU_VECTOR_DB_API_KEY, - database=dify_config.BAIDU_VECTOR_DB_DATABASE, + account=dify_config.BAIDU_VECTOR_DB_ACCOUNT or "", + api_key=dify_config.BAIDU_VECTOR_DB_API_KEY or "", + database=dify_config.BAIDU_VECTOR_DB_DATABASE or "", shard=dify_config.BAIDU_VECTOR_DB_SHARD, replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS, ), diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index a9e1486edd25f1..0eab01b507dc94 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -71,11 +71,13 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** metadatas = [d.metadata for d in documents] collection = self._client.get_or_create_collection(self._collection_name) - collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) + # FIXME: chromadb uses numpy arrays; fix the type error later + collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) # type: ignore def delete_by_metadata_field(self, key: str, value: str): collection = self._client.get_or_create_collection(self._collection_name) - collection.delete(where={key: {"$eq": value}}) + # FIXME: fix the type error later + collection.delete(where={key: {"$eq": value}}) # type: ignore def delete(self): self._client.delete_collection(self._collection_name) @@ -94,15 +96,19 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) score_threshold = float(kwargs.get("score_threshold") or 0.0) - ids: list[str] = results["ids"][0] - documents: list[str] = results["documents"][0] - metadatas: dict[str, Any] = results["metadatas"][0] - distances: list[float] = results["distances"][0] + # Check if results contain data + if not results["ids"] or not results["documents"] or not results["metadatas"] or not results["distances"]: + return [] + + ids = results["ids"][0] + documents = results["documents"][0] + metadatas = 
results["metadatas"][0] + distances = results["distances"][0] docs = [] for index in range(len(ids)): distance = distances[index] - metadata = metadatas[index] + metadata = dict(metadatas[index]) if distance >= score_threshold: metadata["score"] = distance doc = Document( @@ -111,7 +117,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc ) docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -133,7 +139,7 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return ChromaVector( collection_name=collection_name, config=ChromaConfig( - host=dify_config.CHROMA_HOST, + host=dify_config.CHROMA_HOST or "", port=dify_config.CHROMA_PORT, tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT, database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE, diff --git a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py index d26726e86438bd..68a9952789e5b6 100644 --- a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py +++ b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py @@ -5,14 +5,14 @@ from datetime import timedelta from typing import Any -from couchbase import search -from couchbase.auth import PasswordAuthenticator -from couchbase.cluster import Cluster -from couchbase.management.search import SearchIndex +from couchbase import search # type: ignore +from couchbase.auth import PasswordAuthenticator # type: ignore +from couchbase.cluster import Cluster # type: ignore +from couchbase.management.search import SearchIndex # type: ignore # needed for options -- cluster, timeout, SQL++ (N1QL) query, etc. 
-from couchbase.options import ClusterOptions, SearchOptions -from couchbase.vector_search import VectorQuery, VectorSearch +from couchbase.options import ClusterOptions, SearchOptions # type: ignore +from couchbase.vector_search import VectorQuery, VectorSearch # type: ignore from flask import current_app from pydantic import BaseModel, model_validator @@ -231,7 +231,7 @@ def text_exists(self, id: str) -> bool: # Pass the id as a parameter to the query result = self._cluster.query(query, named_parameters={"doc_id": id}).execute() for row in result: - return row["count"] > 0 + return bool(row["count"] > 0) return False # Return False if no rows are returned def delete_by_ids(self, ids: list[str]) -> None: @@ -369,10 +369,10 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return CouchbaseVector( collection_name=collection_name, config=CouchbaseConfig( - connection_string=config.get("COUCHBASE_CONNECTION_STRING"), - user=config.get("COUCHBASE_USER"), - password=config.get("COUCHBASE_PASSWORD"), - bucket_name=config.get("COUCHBASE_BUCKET_NAME"), - scope_name=config.get("COUCHBASE_SCOPE_NAME"), + connection_string=config.get("COUCHBASE_CONNECTION_STRING", ""), + user=config.get("COUCHBASE_USER", ""), + password=config.get("COUCHBASE_PASSWORD", ""), + bucket_name=config.get("COUCHBASE_BUCKET_NAME", ""), + scope_name=config.get("COUCHBASE_SCOPE_NAME", ""), ), ) diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index b08811a02181d2..8661828dc2aa52 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -1,7 +1,7 @@ import json import logging import math -from typing import Any, Optional +from typing import Any, Optional, cast from urllib.parse import urlparse import requests @@ -70,7 +70,7 @@ def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: def _get_version(self) -> str: info = self._client.info() - return info["version"]["number"] + return cast(str, info["version"]["number"]) def _check_version(self): if self._version < "8.0.0": @@ -135,7 +135,8 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc for doc, score in docs_and_scores: score_threshold = float(kwargs.get("score_threshold") or 0.0) if score > score_threshold: - doc.metadata["score"] = score + if doc.metadata is not None: + doc.metadata["score"] = score docs.append(doc) return docs @@ -156,12 +157,15 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return docs def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): - metadatas = [d.metadata for d in texts] + metadatas = [d.metadata if d.metadata is not None else {} for d in texts] self.create_collection(embeddings, metadatas) self.add_texts(texts, embeddings, **kwargs) def create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + self, + embeddings: list[list[float]], + metadatas: Optional[list[dict[Any, Any]]] = None, + index_params: Optional[dict] = None, ): lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): @@ -208,10 +212,10 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return ElasticSearchVector( index_name=collection_name, config=ElasticSearchConfig( - host=config.get("ELASTICSEARCH_HOST"), 
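
Note: the defaults added to these config.get() calls do type work as well as runtime work. On a typed mapping, .get(key) types as Optional[str] while .get(key, fallback) types as str; Flask's config is Any-typed, so there the runtime default is the main gain, but the pattern in miniature (hypothetical helper, assuming a Mapping[str, str]):

    from collections.abc import Mapping


    def es_host(config: Mapping[str, str]) -> str:
        # Without the fallback, mypy reports: Incompatible return value type
        # (got "Optional[str]", expected "str").
        return config.get("ELASTICSEARCH_HOST", "localhost")
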
- port=config.get("ELASTICSEARCH_PORT"), - username=config.get("ELASTICSEARCH_USERNAME"), - password=config.get("ELASTICSEARCH_PASSWORD"), + host=config.get("ELASTICSEARCH_HOST", "localhost"), + port=config.get("ELASTICSEARCH_PORT", 9200), + username=config.get("ELASTICSEARCH_USERNAME", ""), + password=config.get("ELASTICSEARCH_PASSWORD", ""), ), attributes=[], ) diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py index 8646e52cf493ca..d7a14207e9375a 100644 --- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py +++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py @@ -42,7 +42,7 @@ def validate_config(cls, values: dict) -> dict: return values def to_opensearch_params(self) -> dict[str, Any]: - params = {"hosts": self.hosts} + params: dict[str, Any] = {"hosts": self.hosts} if self.username and self.password: params["http_auth"] = (self.username, self.password) return params @@ -53,7 +53,7 @@ def __init__(self, collection_name: str, config: LindormVectorStoreConfig, using self._routing = None self._routing_field = None if using_ugc: - routing_value: str = kwargs.get("routing_value") + routing_value: str | None = kwargs.get("routing_value") if routing_value is None: raise ValueError("UGC index should init vector with valid 'routing_value' parameter value") self._routing = routing_value.lower() @@ -87,14 +87,15 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** "_id": uuids[i], } } - action_values = { + action_values: dict[str, Any] = { Field.CONTENT_KEY.value: documents[i].page_content, Field.VECTOR.value: embeddings[i], # Make sure you pass an array here Field.METADATA_KEY.value: documents[i].metadata, } if self._using_ugc: action_header["index"]["routing"] = self._routing - action_values[self._routing_field] = self._routing + if self._routing_field is not None: + action_values[self._routing_field] = self._routing actions.append(action_header) actions.append(action_values) response = self._client.bulk(actions) @@ -105,7 +106,9 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** self.refresh() def get_ids_by_metadata_field(self, key: str, value: str): - query = {"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}} + query: dict[str, Any] = { + "query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}} + } if self._using_ugc: query["query"]["bool"]["must"].append({"term": {f"{self._routing_field}.keyword": self._routing}}) response = self._client.search(index=self._collection_name, body=query) @@ -191,7 +194,8 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc for doc, score in docs_and_scores: score_threshold = kwargs.get("score_threshold", 0.0) or 0.0 if score > score_threshold: - doc.metadata["score"] = score + if doc.metadata is not None: + doc.metadata["score"] = score docs.append(doc) return docs @@ -366,6 +370,7 @@ def default_text_search_query( routing_field: Optional[str] = None, **kwargs, ) -> dict: + query_clause: dict[str, Any] = {} if routing is not None: query_clause = { "bool": {"must": [{"match": {text_field: query_text}}, {"term": {f"{routing_field}.keyword": routing}}]} @@ -386,7 +391,7 @@ def default_text_search_query( else: must = [query_clause] - boolean_query = {"must": must} + boolean_query: dict[str, Any] = {"must": must} if must_not: if not isinstance(must_not, list): @@ -426,7 +431,7 @@ def 
default_vector_search_query( filter_type = "post_filter" if filter_type is None else filter_type if not isinstance(filters, list): raise RuntimeError(f"unexpected filter with {type(filters)}") - final_ext = {"lvector": {}} + final_ext: dict[str, Any] = {"lvector": {}} if min_score != "0.0": final_ext["lvector"]["min_score"] = min_score if ef_search: @@ -438,7 +443,7 @@ def default_vector_search_query( if client_refactor: final_ext["lvector"]["client_refactor"] = client_refactor - search_query = { + search_query: dict[str, Any] = { "size": k, "_source": True, # force return '_source' "query": {"knn": {vector_field: {"vector": query_vector, "k": k}}}, @@ -446,8 +451,8 @@ def default_vector_search_query( if filters is not None: # when using filter, transform filter from List[Dict] to Dict as valid format - filters = {"bool": {"must": filters}} if len(filters) > 1 else filters[0] - search_query["query"]["knn"][vector_field]["filter"] = filters # filter should be Dict + filter_dict = {"bool": {"must": filters}} if len(filters) > 1 else filters[0] + search_query["query"]["knn"][vector_field]["filter"] = filter_dict # filter should be Dict if filter_type: final_ext["lvector"]["filter_type"] = filter_type @@ -459,17 +464,19 @@ def default_vector_search_query( class LindormVectorStoreFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore: lindorm_config = LindormVectorStoreConfig( - hosts=dify_config.LINDORM_URL, + hosts=dify_config.LINDORM_URL or "", username=dify_config.LINDORM_USERNAME, password=dify_config.LINDORM_PASSWORD, using_ugc=dify_config.USING_UGC_INDEX, ) using_ugc = dify_config.USING_UGC_INDEX + if using_ugc is None: + raise ValueError("USING_UGC_INDEX is not set") routing_value = None if dataset.index_struct: # if an existed record's index_struct_dict doesn't contain using_ugc field, # it actually stores in the normal index format - stored_in_ugc = dataset.index_struct_dict.get("using_ugc", False) + stored_in_ugc: bool = dataset.index_struct_dict.get("using_ugc", False) using_ugc = stored_in_ugc if stored_in_ugc: dimension = dataset.index_struct_dict["dimension"] diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 5a263d6e78c3bd..9b029ffc193cc0 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -3,8 +3,8 @@ from typing import Any, Optional from pydantic import BaseModel, model_validator -from pymilvus import MilvusClient, MilvusException -from pymilvus.milvus_client import IndexParams +from pymilvus import MilvusClient, MilvusException # type: ignore +from pymilvus.milvus_client import IndexParams # type: ignore from configs import dify_config from core.rag.datasource.vdb.field import Field @@ -54,14 +54,14 @@ def __init__(self, collection_name: str, config: MilvusConfig): self._client_config = config self._client = self._init_client(config) self._consistency_level = "Session" - self._fields = [] + self._fields: list[str] = [] def get_type(self) -> str: return VectorType.MILVUS def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}} - metadatas = [d.metadata for d in texts] + metadatas = [d.metadata if d.metadata is not None else {} for d in texts] self.create_collection(embeddings, metadatas, index_params) self.add_texts(texts, 
embeddings) @@ -161,8 +161,8 @@ def create_collection( return # Grab the existing collection if it exists if not self._client.has_collection(self._collection_name): - from pymilvus import CollectionSchema, DataType, FieldSchema - from pymilvus.orm.types import infer_dtype_bydata + from pymilvus import CollectionSchema, DataType, FieldSchema # type: ignore + from pymilvus.orm.types import infer_dtype_bydata # type: ignore # Determine embedding dim dim = len(embeddings[0]) @@ -217,10 +217,10 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return MilvusVector( collection_name=collection_name, config=MilvusConfig( - uri=dify_config.MILVUS_URI, - token=dify_config.MILVUS_TOKEN, - user=dify_config.MILVUS_USER, - password=dify_config.MILVUS_PASSWORD, - database=dify_config.MILVUS_DATABASE, + uri=dify_config.MILVUS_URI or "", + token=dify_config.MILVUS_TOKEN or "", + user=dify_config.MILVUS_USER or "", + password=dify_config.MILVUS_PASSWORD or "", + database=dify_config.MILVUS_DATABASE or "", ), ) diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index b7b6b803ad20af..e63e1f522b3812 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -74,15 +74,16 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** columns = ["id", "text", "vector", "metadata"] values = [] for i, doc in enumerate(documents): - doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) - row = ( - doc_id, - self.escape_str(doc.page_content), - embeddings[i], - json.dumps(doc.metadata) if doc.metadata else {}, - ) - values.append(str(row)) - ids.append(doc_id) + if doc.metadata is not None: + doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) + row = ( + doc_id, + self.escape_str(doc.page_content), + embeddings[i], + json.dumps(doc.metadata) if doc.metadata else {}, + ) + values.append(str(row)) + ids.append(doc_id) sql = f""" INSERT INTO {self._config.database}.{self._collection_name} ({",".join(columns)}) VALUES {",".join(values)} diff --git a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py index c44338d42a591a..957c799a60cbfe 100644 --- a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py +++ b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py @@ -4,7 +4,7 @@ from typing import Any from pydantic import BaseModel, model_validator -from pyobvector import VECTOR, ObVecClient +from pyobvector import VECTOR, ObVecClient # type: ignore from sqlalchemy import JSON, Column, String, func from sqlalchemy.dialects.mysql import LONGTEXT @@ -131,7 +131,7 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** def text_exists(self, id: str) -> bool: cur = self._client.get(table_name=self._collection_name, id=id) - return cur.rowcount != 0 + return bool(cur.rowcount != 0) def delete_by_ids(self, ids: list[str]) -> None: self._client.delete(table_name=self._collection_name, ids=ids) diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 7a976d7c3c8955..72a15022052f0a 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -66,7 +66,7 @@ def get_type(self) -> str: return VectorType.OPENSEARCH def create(self, texts: list[Document], embeddings: 
list[list[float]], **kwargs): - metadatas = [d.metadata for d in texts] + metadatas = [d.metadata if d.metadata is not None else {} for d in texts] self.create_collection(embeddings, metadatas) self.add_texts(texts, embeddings) @@ -244,7 +244,7 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name)) open_search_config = OpenSearchConfig( - host=dify_config.OPENSEARCH_HOST, + host=dify_config.OPENSEARCH_HOST or "localhost", port=dify_config.OPENSEARCH_PORT, user=dify_config.OPENSEARCH_USER, password=dify_config.OPENSEARCH_PASSWORD, diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index 74608f1e1a3b05..dfff3563c3bb28 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -5,7 +5,7 @@ from contextlib import contextmanager from typing import Any -import jieba.posseg as pseg +import jieba.posseg as pseg # type: ignore import numpy import oracledb from pydantic import BaseModel, model_validator @@ -88,12 +88,11 @@ def input_type_handler(self, cursor, value, arraysize): def numpy_converter_out(self, value): if value.typecode == "b": - dtype = numpy.int8 + return numpy.array(value, copy=False, dtype=numpy.int8) elif value.typecode == "f": - dtype = numpy.float32 + return numpy.array(value, copy=False, dtype=numpy.float32) else: - dtype = numpy.float64 - return numpy.array(value, copy=False, dtype=dtype) + return numpy.array(value, copy=False, dtype=numpy.float64) def output_type_handler(self, cursor, metadata): if metadata.type_code is oracledb.DB_TYPE_VECTOR: @@ -135,17 +134,18 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** values = [] pks = [] for i, doc in enumerate(documents): - doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) - pks.append(doc_id) - values.append( - ( - doc_id, - doc.page_content, - json.dumps(doc.metadata), - # array.array("f", embeddings[i]), - numpy.array(embeddings[i]), + if doc.metadata is not None: + doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) + pks.append(doc_id) + values.append( + ( + doc_id, + doc.page_content, + json.dumps(doc.metadata), + # array.array("f", embeddings[i]), + numpy.array(embeddings[i]), + ) ) - ) # print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)") with self._get_cursor() as cur: cur.executemany( @@ -201,8 +201,8 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: # lazy import - import nltk - from nltk.corpus import stopwords + import nltk # type: ignore + from nltk.corpus import stopwords # type: ignore top_k = kwargs.get("top_k", 5) # just not implement fetch by score_threshold now, may be later @@ -285,10 +285,10 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return OracleVector( collection_name=collection_name, config=OracleVectorConfig( - host=dify_config.ORACLE_HOST, + host=dify_config.ORACLE_HOST or "localhost", port=dify_config.ORACLE_PORT, - user=dify_config.ORACLE_USER, - password=dify_config.ORACLE_PASSWORD, - database=dify_config.ORACLE_DATABASE, + user=dify_config.ORACLE_USER or "system", + password=dify_config.ORACLE_PASSWORD or "oracle", + database=dify_config.ORACLE_DATABASE or "orcl", ), ) diff --git 
a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index 7cbbdcc81f6039..221bc68d68a6f7 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -4,7 +4,7 @@ from uuid import UUID, uuid4 from numpy import ndarray -from pgvecto_rs.sqlalchemy import VECTOR +from pgvecto_rs.sqlalchemy import VECTOR # type: ignore from pydantic import BaseModel, model_validator from sqlalchemy import Float, String, create_engine, insert, select, text from sqlalchemy import text as sql_text @@ -58,7 +58,7 @@ def __init__(self, collection_name: str, config: PgvectoRSConfig, dim: int): with Session(self._client) as session: session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors")) session.commit() - self._fields = [] + self._fields: list[str] = [] class _Table(CollectionORM): __tablename__ = collection_name @@ -222,11 +222,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return PGVectoRS( collection_name=collection_name, config=PgvectoRSConfig( - host=dify_config.PGVECTO_RS_HOST, - port=dify_config.PGVECTO_RS_PORT, - user=dify_config.PGVECTO_RS_USER, - password=dify_config.PGVECTO_RS_PASSWORD, - database=dify_config.PGVECTO_RS_DATABASE, + host=dify_config.PGVECTO_RS_HOST or "localhost", + port=dify_config.PGVECTO_RS_PORT or 5432, + user=dify_config.PGVECTO_RS_USER or "postgres", + password=dify_config.PGVECTO_RS_PASSWORD or "", + database=dify_config.PGVECTO_RS_DATABASE or "postgres", ), dim=dim, ) diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 40a9cdd136b404..271281ca7e939f 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -3,8 +3,8 @@ from contextlib import contextmanager from typing import Any -import psycopg2.extras -import psycopg2.pool +import psycopg2.extras # type: ignore +import psycopg2.pool # type: ignore from pydantic import BaseModel, model_validator from configs import dify_config @@ -98,16 +98,17 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** values = [] pks = [] for i, doc in enumerate(documents): - doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) - pks.append(doc_id) - values.append( - ( - doc_id, - doc.page_content, - json.dumps(doc.metadata), - embeddings[i], + if doc.metadata is not None: + doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) + pks.append(doc_id) + values.append( + ( + doc_id, + doc.page_content, + json.dumps(doc.metadata), + embeddings[i], + ) ) - ) with self._get_cursor() as cur: psycopg2.extras.execute_values( cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values @@ -216,11 +217,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return PGVector( collection_name=collection_name, config=PGVectorConfig( - host=dify_config.PGVECTOR_HOST, + host=dify_config.PGVECTOR_HOST or "localhost", port=dify_config.PGVECTOR_PORT, - user=dify_config.PGVECTOR_USER, - password=dify_config.PGVECTOR_PASSWORD, - database=dify_config.PGVECTOR_DATABASE, + user=dify_config.PGVECTOR_USER or "postgres", + password=dify_config.PGVECTOR_PASSWORD or "", + database=dify_config.PGVECTOR_DATABASE or "postgres", min_connection=dify_config.PGVECTOR_MIN_CONNECTION, max_connection=dify_config.PGVECTOR_MAX_CONNECTION, ), diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py 
b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 3811458e02957c..6e94cb69db309d 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -51,6 +51,8 @@ def to_qdrant_params(self): if self.endpoint and self.endpoint.startswith("path:"): path = self.endpoint.replace("path:", "") if not os.path.isabs(path): + if not self.root_path: + raise ValueError("Root path is not set") path = os.path.join(self.root_path, path) return {"path": path} @@ -149,9 +151,12 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** uuids = self._get_uuids(documents) texts = [d.page_content for d in documents] metadatas = [d.metadata for d in documents] - added_ids = [] - for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id): + # Filter out None values from metadatas list to match expected type + filtered_metadatas = [m for m in metadatas if m is not None] + for batch_ids, points in self._generate_rest_batches( + texts, embeddings, filtered_metadatas, uuids, 64, self._group_id + ): self._client.upsert(collection_name=self._collection_name, points=points) added_ids.extend(batch_ids) @@ -194,7 +199,7 @@ def _generate_rest_batches( batch_metadatas, Field.CONTENT_KEY.value, Field.METADATA_KEY.value, - group_id, + group_id or "", # Ensure group_id is never None Field.GROUP_KEY.value, ), ) @@ -337,18 +342,20 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc ) docs = [] for result in results: + if result.payload is None: + continue metadata = result.payload.get(Field.METADATA_KEY.value) or {} # duplicate check score threshold score_threshold = float(kwargs.get("score_threshold") or 0.0) if result.score > score_threshold: metadata["score"] = result.score doc = Document( - page_content=result.payload.get(Field.CONTENT_KEY.value), + page_content=result.payload.get(Field.CONTENT_KEY.value, ""), metadata=metadata, ) docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -432,9 +439,9 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings collection_name=collection_name, group_id=dataset.id, config=QdrantConfig( - endpoint=dify_config.QDRANT_URL, + endpoint=dify_config.QDRANT_URL or "", api_key=dify_config.QDRANT_API_KEY, - root_path=current_app.config.root_path, + root_path=str(current_app.config.root_path), timeout=dify_config.QDRANT_CLIENT_TIMEOUT, grpc_port=dify_config.QDRANT_GRPC_PORT, prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index f373dcfeabef92..a3a20448ff7a0a 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -3,7 +3,7 @@ from typing import Any, Optional from pydantic import BaseModel, model_validator -from sqlalchemy import Column, Sequence, String, Table, create_engine, insert +from sqlalchemy import Column, String, Table, create_engine, insert from sqlalchemy import text as sql_text from sqlalchemy.dialects.postgresql import JSON, TEXT from sqlalchemy.orm import Session @@ -58,14 +58,14 @@ def __init__(self, collection_name: str, 
config: RelytConfig, group_id: str): f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" ) self.client = create_engine(self._url) - self._fields = [] + self._fields: list[str] = [] self._group_id = group_id def get_type(self) -> str: return VectorType.RELYT - def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): - index_params = {} + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> None: + index_params: dict[str, Any] = {} metadatas = [d.metadata for d in texts] self.create_collection(len(embeddings[0])) self.embedding_dimension = len(embeddings[0]) @@ -107,10 +107,10 @@ def create_collection(self, dimension: int): redis_client.set(collection_exist_cache_key, 1, ex=3600) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): - from pgvecto_rs.sqlalchemy import VECTOR + from pgvecto_rs.sqlalchemy import VECTOR # type: ignore ids = [str(uuid.uuid1()) for _ in documents] - metadatas = [d.metadata for d in documents] + metadatas = [d.metadata for d in documents if d.metadata is not None] for metadata in metadatas: metadata["group_id"] = self._group_id texts = [d.page_content for d in documents] @@ -242,10 +242,6 @@ def similarity_search_with_score_by_vector( filter: Optional[dict] = None, ) -> list[tuple[Document, float]]: # Add the filter if provided - try: - from sqlalchemy.engine import Row - except ImportError: - raise ImportError("Could not import Row from sqlalchemy.engine. Please 'pip install sqlalchemy>=1.4'.") filter_condition = "" if filter is not None: @@ -275,7 +271,7 @@ def similarity_search_with_score_by_vector( # Execute the query and fetch the results with self.client.connect() as conn: - results: Sequence[Row] = conn.execute(sql_text(sql_query), params).fetchall() + results = conn.execute(sql_text(sql_query), params).fetchall() documents_with_scores = [ ( @@ -307,11 +303,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return RelytVector( collection_name=collection_name, config=RelytConfig( - host=dify_config.RELYT_HOST, + host=dify_config.RELYT_HOST or "localhost", port=dify_config.RELYT_PORT, - user=dify_config.RELYT_USER, - password=dify_config.RELYT_PASSWORD, - database=dify_config.RELYT_DATABASE, + user=dify_config.RELYT_USER or "", + password=dify_config.RELYT_PASSWORD or "", + database=dify_config.RELYT_DATABASE or "default", ), group_id=dataset.id, ) diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index f971a9c5eb1696..c15f4b229f81c3 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -2,10 +2,10 @@ from typing import Any, Optional from pydantic import BaseModel -from tcvectordb import VectorDBClient -from tcvectordb.model import document, enum -from tcvectordb.model import index as vdb_index -from tcvectordb.model.document import Filter +from tcvectordb import VectorDBClient # type: ignore +from tcvectordb.model import document, enum # type: ignore +from tcvectordb.model import index as vdb_index # type: ignore +from tcvectordb.model.document import Filter # type: ignore from configs import dify_config from core.rag.datasource.vdb.vector_base import BaseVector @@ -25,8 +25,8 @@ class TencentConfig(BaseModel): database: Optional[str] index_type: str = "HNSW" metric_type: str = "L2" - shard: int = (1,) - replicas: int = (2,) + shard: int 
= 1 + replicas: int = 2 def to_tencent_params(self): return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout} @@ -120,15 +120,15 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** metadatas = [doc.metadata for doc in documents] total_count = len(embeddings) docs = [] - for id in range(0, total_count): + for i in range(0, total_count): if metadatas is None: continue - metadata = json.dumps(metadatas[id]) + metadata = metadatas[i] or {} doc = document.Document( - id=metadatas[id]["doc_id"], - vector=embeddings[id], - text=texts[id], - metadata=metadata, + id=metadata.get("doc_id"), + vector=embeddings[i], + text=texts[i], + metadata=json.dumps(metadata), ) docs.append(doc) self._db.collection(self._collection_name).upsert(docs, self._client_config.timeout) @@ -159,8 +159,8 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return [] - def _get_search_res(self, res, score_threshold): - docs = [] + def _get_search_res(self, res: list | None, score_threshold: float) -> list[Document]: + docs: list[Document] = [] if res is None or len(res) == 0: return docs @@ -193,7 +193,7 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return TencentVector( collection_name=collection_name, config=TencentConfig( - url=dify_config.TENCENT_VECTOR_DB_URL, + url=dify_config.TENCENT_VECTOR_DB_URL or "", api_key=dify_config.TENCENT_VECTOR_DB_API_KEY, timeout=dify_config.TENCENT_VECTOR_DB_TIMEOUT, username=dify_config.TENCENT_VECTOR_DB_USERNAME, diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index cfd47aac5ba05b..19c5579a688f5a 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -54,7 +54,10 @@ def to_qdrant_params(self): if self.endpoint and self.endpoint.startswith("path:"): path = self.endpoint.replace("path:", "") if not os.path.isabs(path): - path = os.path.join(self.root_path, path) + if self.root_path: + path = os.path.join(self.root_path, path) + else: + raise ValueError("root_path is required") return {"path": path} else: @@ -157,7 +160,7 @@ def create_collection(self, collection_name: str, vector_size: int): def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): uuids = self._get_uuids(documents) texts = [d.page_content for d in documents] - metadatas = [d.metadata for d in documents] + metadatas = [d.metadata for d in documents if d.metadata is not None] added_ids = [] for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id): @@ -203,7 +206,7 @@ def _generate_rest_batches( batch_metadatas, Field.CONTENT_KEY.value, Field.METADATA_KEY.value, - group_id, + group_id or "", Field.GROUP_KEY.value, ), ) @@ -334,18 +337,20 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc ) docs = [] for result in results: + if result.payload is None: + continue metadata = result.payload.get(Field.METADATA_KEY.value) or {} # duplicate check score threshold score_threshold = kwargs.get("score_threshold") or 0.0 if result.score > score_threshold: metadata["score"] = result.score doc = Document( - page_content=result.payload.get(Field.CONTENT_KEY.value), + 
page_content=result.payload.get(Field.CONTENT_KEY.value, ""), metadata=metadata, ) docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -427,12 +432,12 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings else: new_cluster = TidbService.create_tidb_serverless_cluster( - dify_config.TIDB_PROJECT_ID, - dify_config.TIDB_API_URL, - dify_config.TIDB_IAM_API_URL, - dify_config.TIDB_PUBLIC_KEY, - dify_config.TIDB_PRIVATE_KEY, - dify_config.TIDB_REGION, + dify_config.TIDB_PROJECT_ID or "", + dify_config.TIDB_API_URL or "", + dify_config.TIDB_IAM_API_URL or "", + dify_config.TIDB_PUBLIC_KEY or "", + dify_config.TIDB_PRIVATE_KEY or "", + dify_config.TIDB_REGION or "", ) new_tidb_auth_binding = TidbAuthBinding( cluster_id=new_cluster["cluster_id"], @@ -464,9 +469,9 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings collection_name=collection_name, group_id=dataset.id, config=TidbOnQdrantConfig( - endpoint=dify_config.TIDB_ON_QDRANT_URL, + endpoint=dify_config.TIDB_ON_QDRANT_URL or "", api_key=TIDB_ON_QDRANT_API_KEY, - root_path=config.root_path, + root_path=str(config.root_path), timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT, grpc_port=dify_config.TIDB_ON_QDRANT_GRPC_PORT, prefer_grpc=dify_config.TIDB_ON_QDRANT_GRPC_ENABLED, diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py index 8dd5922ad0171d..0a48c79511bf26 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -146,7 +146,7 @@ def batch_update_tidb_serverless_cluster_status( iam_url: str, public_key: str, private_key: str, - ) -> list[dict]: + ): """ Update the status of a new TiDB Serverless cluster. :param project_id: The project ID of the TiDB Cloud project (required). @@ -159,7 +159,6 @@ def batch_update_tidb_serverless_cluster_status( :return: The response from the API. 
""" - clusters = [] tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list} cluster_ids = [item.cluster_id for item in tidb_serverless_list] params = {"clusterIds": cluster_ids, "view": "BASIC"} @@ -169,7 +168,6 @@ def batch_update_tidb_serverless_cluster_status( if response.status_code == 200: response_data = response.json() - cluster_infos = [] for item in response_data["clusters"]: state = item["state"] userPrefix = item["userPrefix"] @@ -236,16 +234,17 @@ def batch_create_tidb_serverless_cluster( cluster_infos = [] for item in response_data["clusters"]: cache_key = f"tidb_serverless_cluster_password:{item['displayName']}" - password = redis_client.get(cache_key) - if not password: + cached_password = redis_client.get(cache_key) + if not cached_password: continue cluster_info = { "cluster_id": item["clusterId"], "cluster_name": item["displayName"], "account": "root", - "password": password.decode("utf-8"), + "password": cached_password.decode("utf-8"), } cluster_infos.append(cluster_info) return cluster_infos else: response.raise_for_status() + return [] # FIXME for mypy, This line will not be reached as raise_for_status() will raise an exception diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index 39ab6ea71e9485..be3a417390e802 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -49,7 +49,7 @@ def get_type(self) -> str: return VectorType.TIDB_VECTOR def _table(self, dim: int) -> Table: - from tidb_vector.sqlalchemy import VectorType + from tidb_vector.sqlalchemy import VectorType # type: ignore return Table( self._collection_name, @@ -241,11 +241,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return TiDBVector( collection_name=collection_name, config=TiDBVectorConfig( - host=dify_config.TIDB_VECTOR_HOST, - port=dify_config.TIDB_VECTOR_PORT, - user=dify_config.TIDB_VECTOR_USER, - password=dify_config.TIDB_VECTOR_PASSWORD, - database=dify_config.TIDB_VECTOR_DATABASE, + host=dify_config.TIDB_VECTOR_HOST or "", + port=dify_config.TIDB_VECTOR_PORT or 0, + user=dify_config.TIDB_VECTOR_USER or "", + password=dify_config.TIDB_VECTOR_PASSWORD or "", + database=dify_config.TIDB_VECTOR_DATABASE or "", program_name=dify_config.APPLICATION_NAME, ), ) diff --git a/api/core/rag/datasource/vdb/vector_base.py b/api/core/rag/datasource/vdb/vector_base.py index 22e191340d3a47..edfce2edd896ee 100644 --- a/api/core/rag/datasource/vdb/vector_base.py +++ b/api/core/rag/datasource/vdb/vector_base.py @@ -51,15 +51,16 @@ def delete(self) -> None: def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: for text in texts.copy(): - doc_id = text.metadata["doc_id"] - exists_duplicate_node = self.text_exists(doc_id) - if exists_duplicate_node: - texts.remove(text) + if text.metadata and "doc_id" in text.metadata: + doc_id = text.metadata["doc_id"] + exists_duplicate_node = self.text_exists(doc_id) + if exists_duplicate_node: + texts.remove(text) return texts def _get_uuids(self, texts: list[Document]) -> list[str]: - return [text.metadata["doc_id"] for text in texts] + return [text.metadata["doc_id"] for text in texts if text.metadata and "doc_id" in text.metadata] @property def collection_name(self): diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 6d2e04fc020ab5..523fa80f124b0c 100644 --- 
a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -193,10 +193,13 @@ def _get_embeddings(self) -> Embeddings: def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: for text in texts.copy(): + if text.metadata is None: + continue doc_id = text.metadata["doc_id"] - exists_duplicate_node = self.text_exists(doc_id) - if exists_duplicate_node: - texts.remove(text) + if doc_id: + exists_duplicate_node = self.text_exists(doc_id) + if exists_duplicate_node: + texts.remove(text) return texts diff --git a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py index 4f927f28995613..9de8761a91ca68 100644 --- a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py +++ b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py @@ -2,7 +2,7 @@ from typing import Any from pydantic import BaseModel -from volcengine.viking_db import ( +from volcengine.viking_db import ( # type: ignore Data, DistanceType, Field, @@ -121,11 +121,12 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** for i, page_content in enumerate(page_contents): metadata = {} if metadatas is not None: - for key, val in metadatas[i].items(): + for key, val in (metadatas[i] or {}).items(): metadata[key] = val + # FIXME: fix the type of metadata later doc = Data( { - vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"], + vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"], # type: ignore vdb_Field.VECTOR.value: embeddings[i] if embeddings else None, vdb_Field.CONTENT_KEY.value: page_content, vdb_Field.METADATA_KEY.value: json.dumps(metadata), @@ -178,7 +179,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc score_threshold = float(kwargs.get("score_threshold") or 0.0) return self._get_search_res(results, score_threshold) - def _get_search_res(self, results, score_threshold): + def _get_search_res(self, results, score_threshold) -> list[Document]: if len(results) == 0: return [] @@ -191,7 +192,7 @@ def _get_search_res(self, results, score_threshold): metadata["score"] = result.score doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata) docs.append(doc) - docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 649cfbfea8253c..68d043a19f171f 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -3,7 +3,7 @@ from typing import Any, Optional import requests -import weaviate +import weaviate # type: ignore from pydantic import BaseModel, model_validator from configs import dify_config @@ -107,7 +107,8 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** for i, text in enumerate(texts): data_properties = {Field.TEXT_KEY.value: text} if metadatas is not None: - for key, val in metadatas[i].items(): + # metadata may be None + for key, val in (metadatas[i] or {}).items(): data_properties[key] = self._json_serializable(val) batch.add_data_object( @@ -208,10 +209,11 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc score_threshold =
float(kwargs.get("score_threshold") or 0.0) # check score threshold if score > score_threshold: - doc.metadata["score"] = score - docs.append(doc) + if doc.metadata is not None: + doc.metadata["score"] = score + docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -275,7 +277,7 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return WeaviateVector( collection_name=collection_name, config=WeaviateConfig( - endpoint=dify_config.WEAVIATE_ENDPOINT, + endpoint=dify_config.WEAVIATE_ENDPOINT or "", api_key=dify_config.WEAVIATE_API_KEY, batch_size=dify_config.WEAVIATE_BATCH_SIZE, ), diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 319a2612c7ecb8..35becaa0c7bea7 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -83,6 +83,9 @@ def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> if not isinstance(doc, Document): raise ValueError("doc must be a Document") + if doc.metadata is None: + raise ValueError("doc.metadata must be a dict") + segment_document = self.get_document_segment(doc_id=doc.metadata["doc_id"]) # NOTE: doc could already exist in the store, but we overwrite it @@ -179,10 +182,10 @@ def get_document_hash(self, doc_id: str) -> Optional[str]: if document_segment is None: return None + data: Optional[str] = document_segment.index_node_hash + return data - return document_segment.index_node_hash - - def get_document_segment(self, doc_id: str) -> DocumentSegment: + def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]: document_segment = ( db.session.query(DocumentSegment) .filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id) diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 8ddda7e9832d97..a2c8737da79198 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -1,6 +1,6 @@ import base64 import logging -from typing import Optional, cast +from typing import Any, Optional, cast import numpy as np from sqlalchemy.exc import IntegrityError @@ -27,7 +27,7 @@ def __init__(self, model_instance: ModelInstance, user: Optional[str] = None) -> def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs in batches of 10.""" # use doc embedding cache or store if not exists - text_embeddings = [None for _ in range(len(texts))] + text_embeddings: list[Any] = [None for _ in range(len(texts))] embedding_queue_indices = [] for i, text in enumerate(texts): hash = helper.generate_text_hash(text) @@ -64,7 +64,8 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]: for vector in embedding_result.embeddings: try: - normalized_embedding = (vector / np.linalg.norm(vector)).tolist() + # FIXME: type ignore for numpy here + normalized_embedding = (vector / np.linalg.norm(vector)).tolist() # type: ignore # stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan if np.isnan(normalized_embedding).any(): # for issue #11827 float values are not json compliant @@ -77,8 +78,8 @@ def embed_documents(self, texts: list[str]) -> 
list[list[float]]: logging.exception("Failed transform embedding") cache_embeddings = [] try: - for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings): - text_embeddings[i] = embedding + for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings): + text_embeddings[i] = n_embedding hash = helper.generate_text_hash(texts[i]) if hash not in cache_embeddings: embedding_cache = Embedding( @@ -86,7 +87,7 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]: hash=hash, provider_name=self._model_instance.provider, ) - embedding_cache.set_embedding(embedding) + embedding_cache.set_embedding(n_embedding) db.session.add(embedding_cache) cache_embeddings.append(hash) db.session.commit() @@ -115,7 +116,8 @@ def embed_query(self, text: str) -> list[float]: ) embedding_results = embedding_result.embeddings[0] - embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() + # FIXME: type ignore for numpy here + embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() # type: ignore if np.isnan(embedding_results).any(): raise ValueError("Normalized embedding is nan please try again") except Exception as ex: diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py index 3692b5d19dfb65..7c00c668dd49a3 100644 --- a/api/core/rag/extractor/entity/extract_setting.py +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -14,7 +14,7 @@ class NotionInfo(BaseModel): notion_workspace_id: str notion_obj_id: str notion_page_type: str - document: Document = None + document: Optional[Document] = None tenant_id: str model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py index fc331657195454..c444105bb59443 100644 --- a/api/core/rag/extractor/excel_extractor.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -1,7 +1,7 @@ """Abstract interface for document loader implementations.""" import os -from typing import Optional +from typing import Optional, cast import pandas as pd from openpyxl import load_workbook @@ -47,7 +47,7 @@ def extract(self) -> list[Document]: for col_index, (k, v) in enumerate(row.items()): if pd.notna(v): cell = sheet.cell( - row=index + 2, column=col_index + 1 + row=cast(int, index) + 2, column=col_index + 1 ) # +2 to account for header and 1-based index if cell.hyperlink: value = f"[{v}]({cell.hyperlink.target})" @@ -60,8 +60,8 @@ def extract(self) -> list[Document]: elif file_extension == ".xls": excel_file = pd.ExcelFile(self._file_path, engine="xlrd") - for sheet_name in excel_file.sheet_names: - df = excel_file.parse(sheet_name=sheet_name) + for excel_sheet_name in excel_file.sheet_names: + df = excel_file.parse(sheet_name=excel_sheet_name) df.dropna(how="all", inplace=True) for _, row in df.iterrows(): diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 69659e31080da6..a473b3dfa78a90 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -10,6 +10,7 @@ from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.excel_extractor import ExcelExtractor +from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor from 
core.rag.extractor.html_extractor import HtmlExtractor from core.rag.extractor.jina_reader_extractor import JinaReaderWebExtractor @@ -66,9 +67,13 @@ def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Docume filename_match = re.search(r'filename="([^"]+)"', content_disposition) if filename_match: filename = unquote(filename_match.group(1)) - suffix = "." + re.search(r"\.(\w+)$", filename).group(1) - - file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" + match = re.search(r"\.(\w+)$", filename) + if match: + suffix = "." + match.group(1) + else: + suffix = "" + # FIXME mypy: cannot determine the type of 'tempfile._get_candidate_names'; better not to use it here + file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore Path(file_path).write_bytes(response.content) extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model") if return_text: @@ -89,15 +94,20 @@ def extract( if extract_setting.datasource_type == DatasourceType.FILE.value: with tempfile.TemporaryDirectory() as temp_dir: if not file_path: + assert extract_setting.upload_file is not None, "upload_file is required" upload_file: UploadFile = extract_setting.upload_file suffix = Path(upload_file.key).suffix - file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" + # FIXME mypy: cannot determine the type of 'tempfile._get_candidate_names'; better not to use it here + file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore storage.download(upload_file.key, file_path) input_file = Path(file_path) file_extension = input_file.suffix.lower() etl_type = dify_config.ETL_TYPE unstructured_api_url = dify_config.UNSTRUCTURED_API_URL unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY + assert unstructured_api_url is not None, "unstructured_api_url is required" + assert unstructured_api_key is not None, "unstructured_api_key is required" + extractor: Optional[BaseExtractor] = None if etl_type == "Unstructured": if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) @@ -156,6 +166,7 @@ def extract( extractor = TextExtractor(file_path, autodetect_encoding=True) return extractor.extract() elif extract_setting.datasource_type == DatasourceType.NOTION.value: + assert extract_setting.notion_info is not None, "notion_info is required" extractor = NotionExtractor( notion_workspace_id=extract_setting.notion_info.notion_workspace_id, notion_obj_id=extract_setting.notion_info.notion_obj_id, @@ -165,6 +176,7 @@ def extract( ) return extractor.extract() elif extract_setting.datasource_type == DatasourceType.WEBSITE.value: + assert extract_setting.website_info is not None, "website_info is required" if extract_setting.website_info.provider == "firecrawl": extractor = FirecrawlWebExtractor( url=extract_setting.website_info.url, diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index 17c2087a0ab575..8ae4579c7cf93f 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -1,5 +1,6 @@ import json import time +from typing import cast import requests @@ -20,9 +21,9 @@ def scrape_url(self, url, params=None) -> dict: json_data.update(params) response = requests.post(f"{self.base_url}/v0/scrape", headers=headers, json=json_data) if response.status_code == 200: - response = response.json() - if response["success"] == True: - data = response["data"] + response_data =
response.json() + if response_data["success"] == True: + data = response_data["data"] return { "title": data.get("metadata").get("title"), "description": data.get("metadata").get("description"), @@ -30,7 +31,7 @@ def scrape_url(self, url, params=None) -> dict: "markdown": data.get("markdown"), } else: - raise Exception(f'Failed to scrape URL. Error: {response["error"]}') + raise Exception(f'Failed to scrape URL. Error: {response_data["error"]}') elif response.status_code in {402, 409, 500}: error_message = response.json().get("error", "Unknown error occurred") @@ -46,9 +47,11 @@ def crawl_url(self, url, params=None) -> str: response = self._post_request(f"{self.base_url}/v0/crawl", json_data, headers) if response.status_code == 200: job_id = response.json().get("jobId") - return job_id + return cast(str, job_id) else: self._handle_error(response, "start crawl job") + # FIXME: unreachable code for mypy + return "" # unreachable def check_crawl_status(self, job_id) -> dict: headers = self._prepare_headers() @@ -64,9 +67,9 @@ def check_crawl_status(self, job_id) -> dict: for item in data: if isinstance(item, dict) and "metadata" in item and "markdown" in item: url_data = { - "title": item.get("metadata").get("title"), - "description": item.get("metadata").get("description"), - "source_url": item.get("metadata").get("sourceURL"), + "title": item.get("metadata", {}).get("title"), + "description": item.get("metadata", {}).get("description"), + "source_url": item.get("metadata", {}).get("sourceURL"), "markdown": item.get("markdown"), } url_data_list.append(url_data) @@ -92,6 +95,8 @@ def check_crawl_status(self, job_id) -> dict: else: self._handle_error(response, "check crawl status") + # FIXME: unreachable code for mypy + return {} # unreachable def _prepare_headers(self): return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} diff --git a/api/core/rag/extractor/html_extractor.py b/api/core/rag/extractor/html_extractor.py index 560c2d1d84b04e..350b522347b09d 100644 --- a/api/core/rag/extractor/html_extractor.py +++ b/api/core/rag/extractor/html_extractor.py @@ -1,6 +1,6 @@ """Abstract interface for document loader implementations.""" -from bs4 import BeautifulSoup +from bs4 import BeautifulSoup # type: ignore from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -23,6 +23,7 @@ def extract(self) -> list[Document]: return [Document(page_content=self._load_as_text())] def _load_as_text(self) -> str: + text: str = "" with open(self._file_path, "rb") as fp: soup = BeautifulSoup(fp, "html.parser") text = soup.get_text() diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 87a4ce08bf3f89..fdc2e46d141d07 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, Optional +from typing import Any, Optional, cast import requests @@ -78,6 +78,7 @@ def _load_data_as_documents(self, notion_obj_id: str, notion_page_type: str) -> def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] = {}) -> list[Document]: """Get all the pages from a Notion database.""" + assert self._notion_access_token is not None, "Notion access token is required" res = requests.post( DATABASE_URL_TMPL.format(database_id=database_id), headers={ @@ -96,6 +97,7 @@ def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] for result in 
data["results"]: properties = result["properties"] data = {} + value: Any for property_name, property_value in properties.items(): type = property_value["type"] if type == "multi_select": @@ -130,6 +132,7 @@ def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] return [Document(page_content="\n".join(database_content))] def _get_notion_block_data(self, page_id: str) -> list[str]: + assert self._notion_access_token is not None, "Notion access token is required" result_lines_arr = [] start_cursor = None block_url = BLOCK_CHILD_URL_TMPL.format(block_id=page_id) @@ -184,6 +187,7 @@ def _get_notion_block_data(self, page_id: str) -> list[str]: def _read_block(self, block_id: str, num_tabs: int = 0) -> str: """Read a block.""" + assert self._notion_access_token is not None, "Notion access token is required" result_lines_arr = [] start_cursor = None block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id) @@ -242,6 +246,7 @@ def _read_block(self, block_id: str, num_tabs: int = 0) -> str: def _read_table_rows(self, block_id: str) -> str: """Read table rows.""" + assert self._notion_access_token is not None, "Notion access token is required" done = False result_lines_arr = [] start_cursor = None @@ -296,7 +301,7 @@ def _read_table_rows(self, block_id: str) -> str: result_lines = "\n".join(result_lines_arr) return result_lines - def update_last_edited_time(self, document_model: DocumentModel): + def update_last_edited_time(self, document_model: Optional[DocumentModel]): if not document_model: return @@ -309,6 +314,7 @@ def update_last_edited_time(self, document_model: DocumentModel): db.session.commit() def get_notion_last_edited_time(self) -> str: + assert self._notion_access_token is not None, "Notion access token is required" obj_id = self._notion_obj_id page_type = self._notion_page_type if page_type == "database": @@ -330,7 +336,7 @@ def get_notion_last_edited_time(self) -> str: ) data = res.json() - return data["last_edited_time"] + return cast(str, data["last_edited_time"]) @classmethod def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: @@ -349,4 +355,4 @@ def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: f"and notion workspace {notion_workspace_id}" ) - return data_source_binding.access_token + return cast(str, data_source_binding.access_token) diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py index 57cb9610ba267e..89a7061c26accc 100644 --- a/api/core/rag/extractor/pdf_extractor.py +++ b/api/core/rag/extractor/pdf_extractor.py @@ -1,7 +1,7 @@ """Abstract interface for document loader implementations.""" from collections.abc import Iterator -from typing import Optional +from typing import Optional, cast from core.rag.extractor.blob.blob import Blob from core.rag.extractor.extractor_base import BaseExtractor @@ -27,7 +27,7 @@ def extract(self) -> list[Document]: plaintext_file_exists = False if self._file_cache_key: try: - text = storage.load(self._file_cache_key).decode("utf-8") + text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8") plaintext_file_exists = True return [Document(page_content=text)] except FileNotFoundError: @@ -53,7 +53,7 @@ def load( def parse(self, blob: Blob) -> Iterator[Document]: """Lazily parse the blob.""" - import pypdfium2 + import pypdfium2 # type: ignore with blob.as_bytes_io() as file_path: pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True) diff --git 
a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py index bd669bbad36873..9647dedfff8516 100644 --- a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py @@ -1,7 +1,7 @@ import base64 import logging -from bs4 import BeautifulSoup +from bs4 import BeautifulSoup # type: ignore from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document diff --git a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py index 35220b558afab9..80c29157aaf529 100644 --- a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py @@ -30,6 +30,9 @@ def extract(self) -> list[Document]: if self._api_url: from unstructured.partition.api import partition_via_api + if self._api_key is None: + raise ValueError("api_key is required") + elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key) else: from unstructured.partition.epub import partition_epub diff --git a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py index 0fdcd58b2e569b..e504d4bc23014c 100644 --- a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py @@ -27,9 +27,11 @@ def extract(self) -> list[Document]: elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key) else: raise NotImplementedError("Unstructured API Url is not configured") - text_by_page = {} + text_by_page: dict[int, str] = {} for element in elements: page = element.metadata.page_number + if page is None: + continue text = element.text if page in text_by_page: text_by_page[page] += "\n" + text diff --git a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py index ab41290fbc4537..cefe72b29052a1 100644 --- a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py @@ -29,14 +29,15 @@ def extract(self) -> list[Document]: from unstructured.partition.pptx import partition_pptx elements = partition_pptx(filename=self._file_path) - text_by_page = {} + text_by_page: dict[int, str] = {} for element in elements: page = element.metadata.page_number text = element.text - if page in text_by_page: - text_by_page[page] += "\n" + text - else: - text_by_page[page] = text + if page is not None: + if page in text_by_page: + text_by_page[page] += "\n" + text + else: + text_by_page[page] = text combined_texts = list(text_by_page.values()) documents = [] diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 0c38a9c0762130..c3161bc812cb73 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -89,6 +89,8 @@ def _extract_images_from_docx(self, doc, image_folder): response = ssrf_proxy.get(url) if response.status_code == 200: image_ext = mimetypes.guess_extension(response.headers["Content-Type"]) + if image_ext is None: + continue file_uuid = str(uuid.uuid4()) file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." 
+ image_ext mime_type, _ = mimetypes.guess_type(file_key) @@ -97,6 +99,8 @@ def _extract_images_from_docx(self, doc, image_folder): continue else: image_ext = rel.target_ref.split(".")[-1] + if image_ext is None: + continue # user uuid as file name file_uuid = str(uuid.uuid4()) file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext @@ -226,6 +230,8 @@ def parse_docx(self, docx_path, image_folder): if x_child is None: continue if x.tag.endswith("instrText"): + if x.text is None: + continue for i in url_pattern.findall(x.text): hyperlinks_url = str(i) except Exception as e: diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index be857bd12215fd..7e5efdc66ed533 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -49,6 +49,7 @@ def _get_splitter(self, processing_rule: dict, embedding_model_instance: Optiona """ Get the NodeParser object according to the processing rule. """ + character_splitter: TextSplitter if processing_rule["mode"] == "custom": # The user-defined segmentation rule rules = processing_rule["rules"] diff --git a/api/core/rag/index_processor/index_processor_factory.py b/api/core/rag/index_processor/index_processor_factory.py index 9b855ece2c3512..c5ba6295f32f84 100644 --- a/api/core/rag/index_processor/index_processor_factory.py +++ b/api/core/rag/index_processor/index_processor_factory.py @@ -9,7 +9,7 @@ class IndexProcessorFactory: """IndexProcessorInit.""" - def __init__(self, index_type: str): + def __init__(self, index_type: str | None): self._index_type = index_type def init_index_processor(self) -> BaseIndexProcessor: diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index a631f953ce2191..c66fa54d503e9f 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -27,12 +27,13 @@ def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: def transform(self, documents: list[Document], **kwargs) -> list[Document]: # Split the text documents into nodes. 
splitter = self._get_splitter( - processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance") + processing_rule=kwargs.get("process_rule", {}), + embedding_model_instance=kwargs.get("embedding_model_instance"), ) all_documents = [] for document in documents: # document clean - document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule")) + document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule", {})) document.page_content = document_text # parse document to nodes document_nodes = splitter.split_documents([document]) @@ -41,8 +42,9 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]: if document_node.page_content.strip(): doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata["doc_id"] = doc_id - document_node.metadata["doc_hash"] = hash + if document_node.metadata is not None: + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash # delete Splitter character page_content = remove_leading_symbols(document_node.page_content).strip() if len(page_content) > 0: diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 320f0157a10049..20fd16e8f39b65 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -32,15 +32,16 @@ def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: def transform(self, documents: list[Document], **kwargs) -> list[Document]: splitter = self._get_splitter( - processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance") + processing_rule=kwargs.get("process_rule") or {}, + embedding_model_instance=kwargs.get("embedding_model_instance"), ) # Split the text documents into nodes. 
- all_documents = [] - all_qa_documents = [] + all_documents: list[Document] = [] + all_qa_documents: list[Document] = [] for document in documents: # document clean - document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule")) + document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule") or {}) document.page_content = document_text # parse document to nodes @@ -50,8 +51,9 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]: if document_node.page_content.strip(): doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata["doc_id"] = doc_id - document_node.metadata["doc_hash"] = hash + if document_node.metadata is not None: + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash # delete Splitter character page_content = document_node.page_content document_node.page_content = remove_leading_symbols(page_content) @@ -64,7 +66,7 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]: document_format_thread = threading.Thread( target=self._format_qa_document, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "tenant_id": kwargs.get("tenant_id"), "document_node": doc, "all_qa_documents": all_qa_documents, @@ -148,11 +150,12 @@ def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, a qa_documents = [] for result in document_qa_list: qa_document = Document(page_content=result["question"], metadata=document_node.metadata.copy()) - doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(result["question"]) - qa_document.metadata["answer"] = result["answer"] - qa_document.metadata["doc_id"] = doc_id - qa_document.metadata["doc_hash"] = hash + if qa_document.metadata is not None: + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(result["question"]) + qa_document.metadata["answer"] = result["answer"] + qa_document.metadata["doc_id"] = doc_id + qa_document.metadata["doc_hash"] = hash qa_documents.append(qa_document) format_documents.extend(qa_documents) except Exception as e: diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index 6ae432a526b169..ac7a3f8bb857e4 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -30,7 +30,11 @@ def run( doc_ids = set() unique_documents = [] for document in documents: - if document.provider == "dify" and document.metadata["doc_id"] not in doc_ids: + if ( + document.provider == "dify" + and document.metadata is not None + and document.metadata["doc_id"] not in doc_ids + ): doc_ids.add(document.metadata["doc_id"]) docs.append(document.page_content) unique_documents.append(document) @@ -54,7 +58,8 @@ def run( metadata=documents[result.index].metadata, provider=documents[result.index].provider, ) - rerank_document.metadata["score"] = result.score - rerank_documents.append(rerank_document) + if rerank_document.metadata is not None: + rerank_document.metadata["score"] = result.score + rerank_documents.append(rerank_document) return rerank_documents diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index 4719be012f99cc..cbc96037bf2cc0 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -39,7 +39,7 @@ def run( unique_documents = [] doc_ids = set() for document in documents: - if document.metadata["doc_id"] not in 
doc_ids: + if document.metadata is not None and document.metadata["doc_id"] not in doc_ids: doc_ids.add(document.metadata["doc_id"]) unique_documents.append(document) @@ -56,10 +56,11 @@ def run( ) if score_threshold and score < score_threshold: continue - document.metadata["score"] = score - rerank_documents.append(document) + if document.metadata is not None: + document.metadata["score"] = score + rerank_documents.append(document) - rerank_documents.sort(key=lambda x: x.metadata["score"], reverse=True) + rerank_documents.sort(key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) return rerank_documents[:top_n] if top_n else rerank_documents def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]: @@ -76,8 +77,9 @@ def _calculate_keyword_score(self, query: str, documents: list[Document]) -> lis for document in documents: # get the document keywords document_keywords = keyword_table_handler.extract_keywords(document.page_content, None) - document.metadata["keywords"] = document_keywords - documents_keywords.append(document_keywords) + if document.metadata is not None: + document.metadata["keywords"] = document_keywords + documents_keywords.append(document_keywords) # Counter query keywords(TF) query_keyword_counts = Counter(query_keywords) @@ -162,7 +164,7 @@ def _calculate_cosine( query_vector = cache_embedding.embed_query(query) for document in documents: # calculate cosine similarity - if "score" in document.metadata: + if document.metadata and "score" in document.metadata: query_vector_scores.append(document.metadata["score"]) else: # transform to NumPy diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 7a5bf39fa63f48..a265f36671b04b 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -1,7 +1,7 @@ import math import threading from collections import Counter -from typing import Optional, cast +from typing import Any, Optional, cast from flask import Flask, current_app @@ -34,7 +34,7 @@ from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService -default_retrieval_model = { +default_retrieval_model: dict[str, Any] = { "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, @@ -140,12 +140,12 @@ def retrieve( user_from, available_datasets, query, - retrieve_config.top_k, - retrieve_config.score_threshold, - retrieve_config.rerank_mode, + retrieve_config.top_k or 0, + retrieve_config.score_threshold or 0, + retrieve_config.rerank_mode or "reranking_model", retrieve_config.reranking_model, retrieve_config.weights, - retrieve_config.reranking_enabled, + True if retrieve_config.reranking_enabled is None else retrieve_config.reranking_enabled, message_id, ) @@ -300,10 +300,11 @@ def single_retrieve( metadata=external_document.get("metadata"), provider="external", ) - document.metadata["score"] = external_document.get("score") - document.metadata["title"] = external_document.get("title") - document.metadata["dataset_id"] = dataset_id - document.metadata["dataset_name"] = dataset.name + if document.metadata is not None: + document.metadata["score"] = external_document.get("score") + document.metadata["title"] = external_document.get("title") + document.metadata["dataset_id"] = dataset_id + document.metadata["dataset_name"] = dataset.name results.append(document) else: retrieval_model_config =
dataset.retrieval_model or default_retrieval_model @@ -325,7 +326,7 @@ def single_retrieve( score_threshold = 0.0 score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") if score_threshold_enabled: - score_threshold = retrieval_model_config.get("score_threshold") + score_threshold = retrieval_model_config.get("score_threshold", 0.0) with measure_time() as timer: results = RetrievalService.retrieve( @@ -358,14 +359,14 @@ def multiple_retrieve( score_threshold: float, reranking_mode: str, reranking_model: Optional[dict] = None, - weights: Optional[dict] = None, + weights: Optional[dict[str, Any]] = None, reranking_enable: bool = True, message_id: Optional[str] = None, ): if not available_datasets: return [] threads = [] - all_documents = [] + all_documents: list[Document] = [] dataset_ids = [dataset.id for dataset in available_datasets] index_type_check = all( item.indexing_technique == available_datasets[0].indexing_technique for item in available_datasets @@ -392,15 +393,18 @@ def multiple_retrieve( "The configured knowledge base list have different embedding model, please set reranking model." ) if reranking_enable and reranking_mode == RerankMode.WEIGHTED_SCORE: - weights["vector_setting"]["embedding_provider_name"] = available_datasets[0].embedding_model_provider - weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model + if weights is not None: + weights["vector_setting"]["embedding_provider_name"] = available_datasets[ + 0 + ].embedding_model_provider + weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model for dataset in available_datasets: index_type = dataset.indexing_technique retrieval_thread = threading.Thread( target=self._retriever, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "dataset_id": dataset.id, "query": query, "top_k": top_k, @@ -439,21 +443,22 @@ def _on_retrieval_end( """Handle retrieval end.""" dify_documents = [document for document in documents if document.provider == "dify"] for document in dify_documents: - query = db.session.query(DocumentSegment).filter( - DocumentSegment.index_node_id == document.metadata["doc_id"] - ) + if document.metadata is not None: + query = db.session.query(DocumentSegment).filter( + DocumentSegment.index_node_id == document.metadata["doc_id"] + ) - # if 'dataset_id' in document.metadata: - if "dataset_id" in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) + # if 'dataset_id' in document.metadata: + if "dataset_id" in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) - # add hit count to document segment - query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) + # add hit count to document segment + query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) - db.session.commit() + db.session.commit() # get tracing instance - trace_manager: TraceQueueManager = ( + trace_manager: Optional[TraceQueueManager] = ( self.application_generate_entity.trace_manager if self.application_generate_entity else None ) if trace_manager: @@ -504,10 +509,11 @@ def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, metadata=external_document.get("metadata"), provider="external", ) - document.metadata["score"] = external_document.get("score") - 
document.metadata["title"] = external_document.get("title") - document.metadata["dataset_id"] = dataset_id - document.metadata["dataset_name"] = dataset.name + if document.metadata is not None: + document.metadata["score"] = external_document.get("score") + document.metadata["title"] = external_document.get("title") + document.metadata["dataset_id"] = dataset_id + document.metadata["dataset_name"] = dataset.name all_documents.append(document) else: # get retrieval model , if the model is not setting , using default @@ -607,19 +613,20 @@ def to_dataset_retriever_tool( tools.append(tool) elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: - tool = DatasetMultiRetrieverTool.from_dataset( - dataset_ids=[dataset.id for dataset in available_datasets], - tenant_id=tenant_id, - top_k=retrieve_config.top_k or 2, - score_threshold=retrieve_config.score_threshold, - hit_callbacks=[hit_callback], - return_resource=return_resource, - retriever_from=invoke_from.to_source(), - reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"), - reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"), - ) + if retrieve_config.reranking_model is not None: + tool = DatasetMultiRetrieverTool.from_dataset( + dataset_ids=[dataset.id for dataset in available_datasets], + tenant_id=tenant_id, + top_k=retrieve_config.top_k or 2, + score_threshold=retrieve_config.score_threshold, + hit_callbacks=[hit_callback], + return_resource=return_resource, + retriever_from=invoke_from.to_source(), + reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"), + reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"), + ) - tools.append(tool) + tools.append(tool) return tools @@ -635,10 +642,11 @@ def calculate_keyword_score(self, query: str, documents: list[Document], top_k: query_keywords = keyword_table_handler.extract_keywords(query, None) documents_keywords = [] for document in documents: - # get the document keywords - document_keywords = keyword_table_handler.extract_keywords(document.page_content, None) - document.metadata["keywords"] = document_keywords - documents_keywords.append(document_keywords) + if document.metadata is not None: + # get the document keywords + document_keywords = keyword_table_handler.extract_keywords(document.page_content, None) + document.metadata["keywords"] = document_keywords + documents_keywords.append(document_keywords) # Counter query keywords(TF) query_keyword_counts = Counter(query_keywords) @@ -696,8 +704,9 @@ def cosine_similarity(vec1, vec2): for document, score in zip(documents, similarities): # format document - document.metadata["score"] = score - documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) + if document.metadata is not None: + document.metadata["score"] = score + documents = sorted(documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True) return documents[:top_k] if top_k else documents def calculate_vector_score( @@ -705,10 +714,12 @@ def calculate_vector_score( ) -> list[Document]: filter_documents = [] for document in all_documents: - if score_threshold is None or document.metadata["score"] >= score_threshold: + if score_threshold is None or (document.metadata and document.metadata.get("score", 0) >= score_threshold): filter_documents.append(document) if not filter_documents: return [] - filter_documents = sorted(filter_documents, key=lambda x: x.metadata["score"], 
reverse=True) + filter_documents = sorted( + filter_documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True + ) return filter_documents[:top_k] if top_k else filter_documents diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index 06147fe7b56544..b008d0df9c2f0e 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -1,7 +1,8 @@ -from typing import Union +from typing import Union, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_manager import ModelInstance +from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage @@ -27,11 +28,14 @@ def invoke( SystemPromptMessage(content="You are a helpful AI assistant."), UserPromptMessage(content=query), ] - result = model_instance.invoke_llm( - prompt_messages=prompt_messages, - tools=dataset_tools, - stream=False, - model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500}, + result = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, + tools=dataset_tools, + stream=False, + model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500}, + ), ) if result.message.tool_calls: # get retrieval model config diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index 68fab0c127a253..05e8d043dfe741 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -1,9 +1,9 @@ from collections.abc import Generator, Sequence -from typing import Union +from typing import Union, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate @@ -92,6 +92,7 @@ def _react_invoke( suffix: str = SUFFIX, format_instructions: str = FORMAT_INSTRUCTIONS, ) -> Union[str, None]: + prompt: Union[list[ChatModelMessage], CompletionModelPromptTemplate] if model_config.mode == "chat": prompt = self.create_chat_prompt( query=query, @@ -149,12 +150,15 @@ def _invoke_llm( :param stop: stop :return: """ - invoke_result = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=completion_param, - stop=stop, - stream=True, - user=user_id, + invoke_result = cast( + Generator[LLMResult, None, None], + model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=completion_param, + stop=stop, + stream=True, + user=user_id, + ), ) # handle invoke result @@ -172,7 +176,7 @@ def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage :return: """ model = None - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] full_text = "" usage = None for result in invoke_result: diff --git 
diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py
index 53032b34d570c7..3376bd7f75dd96 100644
--- a/api/core/rag/splitter/fixed_text_splitter.py
+++ b/api/core/rag/splitter/fixed_text_splitter.py
@@ -26,8 +26,8 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
     def from_encoder(
         cls: type[TS],
         embedding_model_instance: Optional[ModelInstance],
-        allowed_special: Union[Literal[all], Set[str]] = set(),
-        disallowed_special: Union[Literal[all], Collection[str]] = "all",
+        allowed_special: Union[Literal["all"], Set[str]] = set(),  # noqa: UP037
+        disallowed_special: Union[Literal["all"], Collection[str]] = "all",  # noqa: UP037
         **kwargs: Any,
     ):
         def _token_encoder(text: str) -> int:
diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py
index 7dd62f8de18a15..4bfa541fd454ad 100644
--- a/api/core/rag/splitter/text_splitter.py
+++ b/api/core/rag/splitter/text_splitter.py
@@ -92,7 +92,7 @@ def split_documents(self, documents: Iterable[Document]) -> list[Document]:
         texts, metadatas = [], []
         for doc in documents:
             texts.append(doc.page_content)
-            metadatas.append(doc.metadata)
+            metadatas.append(doc.metadata or {})
         return self.create_documents(texts, metadatas=metadatas)

     def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:
@@ -143,7 +143,7 @@ def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int
     def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
         """Text splitter that uses HuggingFace tokenizer to count length."""
         try:
-            from transformers import PreTrainedTokenizerBase
+            from transformers import PreTrainedTokenizerBase  # type: ignore

             if not isinstance(tokenizer, PreTrainedTokenizerBase):
                 raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase")
diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py
index ddb1481276df67..975c374cae8356 100644
--- a/api/core/tools/entities/api_entities.py
+++ b/api/core/tools/entities/api_entities.py
@@ -14,7 +14,7 @@ class UserTool(BaseModel):
     label: I18nObject  # label
     description: I18nObject
     parameters: Optional[list[ToolParameter]] = None
-    labels: list[str] = None
+    labels: list[str] | None = None


 UserToolProviderTypeLiteral = Optional[Literal["builtin", "api", "workflow"]]
diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py
index 0c15b2a3711f11..7c365dc69d3b39 100644
--- a/api/core/tools/entities/tool_bundle.py
+++ b/api/core/tools/entities/tool_bundle.py
@@ -18,7 +18,7 @@ class ApiToolBundle(BaseModel):
     # summary
     summary: Optional[str] = None
     # operation_id
-    operation_id: str = None
+    operation_id: str | None = None
     # parameters
     parameters: Optional[list[ToolParameter]] = None
     # author
diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py
index 4fc383f91baeba..260e4e457f083e 100644
--- a/api/core/tools/entities/tool_entities.py
+++ b/api/core/tools/entities/tool_entities.py
@@ -244,18 +244,19 @@ def get_simple_instance(
         """
         # convert options to ToolParameterOption
        if options:
-            options = [
+            options_tool_parameter = [
                 ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options
             ]
         return cls(
             name=name,
             label=I18nObject(en_US="", zh_Hans=""),
             human_description=I18nObject(en_US="", zh_Hans=""),
+            placeholder=None,
             type=type,
             form=cls.ToolParameterForm.LLM,
             llm_description=llm_description,
             required=required,
-            options=options,
+            options=options_tool_parameter,
         )

@@ -331,7 +332,7 @@ def to_dict(self) -> dict:
             "default": self.default,
             "options": self.options,
             "help": self.help.to_dict() if self.help else None,
-            "label": self.label.to_dict(),
+            "label": self.label.to_dict() if self.label else None,
             "url": self.url,
             "placeholder": self.placeholder.to_dict() if self.placeholder else None,
         }
@@ -374,7 +375,10 @@ def __init__(self, **data: Any):
                 pool[index] = ToolRuntimeImageVariable(**variable)
         super().__init__(**data)

-    def dict(self) -> dict:
+    def dict(self) -> dict:  # type: ignore
+        """
+        FIXME: just ignore the type check for now
+        """
         return {
             "conversation_id": self.conversation_id,
             "user_id": self.user_id,
diff --git a/api/core/tools/provider/api_tool_provider.py b/api/core/tools/provider/api_tool_provider.py
index d99314e33a3204..f451edbf2ee969 100644
--- a/api/core/tools/provider/api_tool_provider.py
+++ b/api/core/tools/provider/api_tool_provider.py
@@ -1,9 +1,14 @@
+from typing import Optional
+
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_bundle import ApiToolBundle
 from core.tools.entities.tool_entities import (
     ApiProviderAuthType,
     ToolCredentialsOption,
+    ToolDescription,
+    ToolIdentity,
     ToolProviderCredentials,
+    ToolProviderIdentity,
     ToolProviderType,
 )
 from core.tools.provider.tool_provider import ToolProviderController
@@ -64,21 +69,18 @@ def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "Ap
             pass
         else:
             raise ValueError(f"invalid auth type {auth_type}")
-
-        user_name = db_provider.user.name if db_provider.user_id else ""
-
+        user_name = db_provider.user.name if db_provider.user_id and db_provider.user is not None else ""
         return ApiToolProviderController(
-            **{
-                "identity": {
-                    "author": user_name,
-                    "name": db_provider.name,
-                    "label": {"en_US": db_provider.name, "zh_Hans": db_provider.name},
-                    "description": {"en_US": db_provider.description, "zh_Hans": db_provider.description},
-                    "icon": db_provider.icon,
-                },
-                "credentials_schema": credentials_schema,
-                "provider_id": db_provider.id or "",
-            }
+            identity=ToolProviderIdentity(
+                author=user_name,
+                name=db_provider.name,
+                label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
+                description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
+                icon=db_provider.icon,
+            ),
+            credentials_schema=credentials_schema,
+            provider_id=db_provider.id or "",
+            tools=None,
         )

     @property
@@ -93,24 +95,22 @@ def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool:
         :return: the tool
         """
         return ApiTool(
-            **{
-                "api_bundle": tool_bundle,
-                "identity": {
-                    "author": tool_bundle.author,
-                    "name": tool_bundle.operation_id,
-                    "label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id},
-                    "icon": self.identity.icon,
-                    "provider": self.provider_id,
-                },
-                "description": {
-                    "human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""},
-                    "llm": tool_bundle.summary or "",
-                },
-                "parameters": tool_bundle.parameters or [],
-            }
+            api_bundle=tool_bundle,
+            identity=ToolIdentity(
+                author=tool_bundle.author,
+                name=tool_bundle.operation_id or "",
+                label=I18nObject(en_US=tool_bundle.operation_id, zh_Hans=tool_bundle.operation_id),
+                icon=self.identity.icon if self.identity else None,
+                provider=self.provider_id,
+            ),
+            description=ToolDescription(
+                human=I18nObject(en_US=tool_bundle.summary or "", zh_Hans=tool_bundle.summary or ""),
+                llm=tool_bundle.summary or "",
+            ),
+
parameters=tool_bundle.parameters or [], ) - def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]: + def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[Tool]: """ load bundled tools @@ -121,7 +121,7 @@ def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]: return self.tools - def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]: + def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]: """ fetch tools from database @@ -131,6 +131,8 @@ def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]: """ if self.tools is not None: return self.tools + if self.identity is None: + return None tools: list[Tool] = [] @@ -151,7 +153,7 @@ def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]: self.tools = tools return tools - def get_tool(self, tool_name: str) -> ApiTool: + def get_tool(self, tool_name: str) -> Tool: """ get tool by name @@ -161,7 +163,9 @@ def get_tool(self, tool_name: str) -> ApiTool: if self.tools is None: self.get_tools() - for tool in self.tools: + for tool in self.tools or []: + if tool.identity is None: + continue if tool.identity.name == tool_name: return tool diff --git a/api/core/tools/provider/app_tool_provider.py b/api/core/tools/provider/app_tool_provider.py index 582ad636b1953a..fc29920acd40dc 100644 --- a/api/core/tools/provider/app_tool_provider.py +++ b/api/core/tools/provider/app_tool_provider.py @@ -1,9 +1,10 @@ import logging -from typing import Any +from typing import Any, Optional from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolParameter, ToolParameterOption, ToolProviderType from core.tools.provider.tool_provider import ToolProviderController +from core.tools.tool.api_tool import ApiTool from core.tools.tool.tool import Tool from extensions.ext_database import db from models.model import App, AppModelConfig @@ -20,10 +21,10 @@ def provider_type(self) -> ToolProviderType: def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None: pass - def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None: + def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None: pass - def get_tools(self, user_id: str) -> list[Tool]: + def get_tools(self, user_id: str = "", tenant_id: str = "") -> list[Tool]: db_tools: list[PublishedAppTool] = ( db.session.query(PublishedAppTool) .filter( @@ -38,7 +39,7 @@ def get_tools(self, user_id: str) -> list[Tool]: tools: list[Tool] = [] for db_tool in db_tools: - tool = { + tool: dict[str, Any] = { "identity": { "author": db_tool.author, "name": db_tool.tool_name, @@ -52,7 +53,7 @@ def get_tools(self, user_id: str) -> list[Tool]: "parameters": [], } # get app from db - app: App = db_tool.app + app: Optional[App] = db_tool.app if not app: logger.error(f"app {db_tool.app_id} not found") @@ -79,6 +80,7 @@ def get_tools(self, user_id: str) -> list[Tool]: type=ToolParameter.ToolParameterType.STRING, required=required, default=default, + placeholder=I18nObject(en_US="", zh_Hans=""), ) ) elif form_type == "select": @@ -92,6 +94,7 @@ def get_tools(self, user_id: str) -> list[Tool]: type=ToolParameter.ToolParameterType.SELECT, required=required, default=default, + placeholder=I18nObject(en_US="", zh_Hans=""), options=[ ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options @@ -99,5 +102,5 @@ def get_tools(self, user_id: str) -> 
list[Tool]: ) ) - tools.append(Tool(**tool)) + tools.append(ApiTool(**tool)) return tools diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py index 5c10f72fdaed01..99a062f8c366aa 100644 --- a/api/core/tools/provider/builtin/_positions.py +++ b/api/core/tools/provider/builtin/_positions.py @@ -5,7 +5,7 @@ class BuiltinToolProviderSort: - _position = {} + _position: dict[str, int] = {} @classmethod def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: diff --git a/api/core/tools/provider/builtin/aippt/tools/aippt.py b/api/core/tools/provider/builtin/aippt/tools/aippt.py index 38123f125ae974..cf10f5d2556edd 100644 --- a/api/core/tools/provider/builtin/aippt/tools/aippt.py +++ b/api/core/tools/provider/builtin/aippt/tools/aippt.py @@ -4,7 +4,7 @@ from json import loads as json_loads from threading import Lock from time import sleep, time -from typing import Any +from typing import Any, Union from httpx import get, post from requests import get as requests_get @@ -21,23 +21,25 @@ class AIPPTGenerateToolAdapter: """ _api_base_url = URL("https://co.aippt.cn/api") - _api_token_cache = {} - _style_cache = {} + _api_token_cache: dict[str, dict[str, Union[str, float]]] = {} + _style_cache: dict[str, dict[str, Union[list[dict[str, Any]], float]]] = {} - _api_token_cache_lock = Lock() - _style_cache_lock = Lock() + _api_token_cache_lock: Lock = Lock() + _style_cache_lock: Lock = Lock() - _task = {} + _task: dict[str, Any] = {} _task_type_map = { "auto": 1, "markdown": 7, } - _tool: BuiltinTool + _tool: BuiltinTool | None - def __init__(self, tool: BuiltinTool = None): + def __init__(self, tool: BuiltinTool | None = None): self._tool = tool - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invokes the AIPPT generate tool with the given user ID and tool parameters. 
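The class-attribute annotations above (for example _position: dict[str, int] = {} and the aippt token and style caches) address mypy's var-annotated error: an empty literal by itself gives the checker nothing to infer a type from. A small illustration with an invented class:

class ProviderSort:
    # Without this annotation mypy reports:
    #   Need type annotation for "_position" (hint: "_position: dict[<type>, <type>] = ...")
    _position: dict[str, int] = {}

    @classmethod
    def weight(cls, name: str) -> int:
        return cls._position.get(name, 0)

print(ProviderSort.weight("builtin"))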
@@ -68,8 +70,8 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe ) # get suit - color: str = tool_parameters.get("color") - style: str = tool_parameters.get("style") + color: str = tool_parameters.get("color", "") + style: str = tool_parameters.get("style", "") if color == "__default__": color_id = "" @@ -226,7 +228,7 @@ def _generate_content(self, task_id: str, model: str, user_id: str) -> str: return "" - def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]: + def _generate_ppt(self, task_id: str, suit_id: int, user_id: str) -> tuple[str, str]: """ Generate a ppt @@ -362,7 +364,9 @@ def _calculate_sign(access_key: str, secret_key: str, timestamp: int) -> str: ).decode("utf-8") @classmethod - def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]: + def _get_styles( + cls, credentials: dict[str, str], user_id: str + ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: """ Get styles """ @@ -415,7 +419,7 @@ def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[di return colors, styles - def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]: + def get_styles(self, user_id: str) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: """ Get styles @@ -507,7 +511,9 @@ class AIPPTGenerateTool(BuiltinTool): def __init__(self, **kwargs: Any): super().__init__(**kwargs) - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: return AIPPTGenerateToolAdapter(self)._invoke(user_id, tool_parameters) def get_runtime_parameters(self) -> list[ToolParameter]: diff --git a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py index 2d65ba2d6f4389..8bd16050ecf0a6 100644 --- a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py +++ b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py @@ -1,7 +1,7 @@ import logging from typing import Any, Optional -import arxiv +import arxiv # type: ignore from pydantic import BaseModel, Field from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/provider/builtin/audio/tools/tts.py b/api/core/tools/provider/builtin/audio/tools/tts.py index f83a64d041faab..8a33ac405bd4c3 100644 --- a/api/core/tools/provider/builtin/audio/tools/tts.py +++ b/api/core/tools/provider/builtin/audio/tools/tts.py @@ -11,19 +11,21 @@ class TTSTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]: - provider, model = tool_parameters.get("model").split("#") - voice = tool_parameters.get(f"voice#{provider}#{model}") + provider, model = tool_parameters.get("model", "").split("#") + voice = tool_parameters.get(f"voice#{provider}#{model}", "") model_manager = ModelManager() + if not self.runtime: + raise ValueError("Runtime is required") model_instance = model_manager.get_model_instance( - tenant_id=self.runtime.tenant_id, + tenant_id=self.runtime.tenant_id or "", provider=provider, model_type=ModelType.TTS, model=model, ) tts = model_instance.invoke_tts( - content_text=tool_parameters.get("text"), + content_text=tool_parameters.get("text", ""), user=user_id, - tenant_id=self.runtime.tenant_id, + tenant_id=self.runtime.tenant_id or "", voice=voice, ) buffer = io.BytesIO() @@ -41,8 +43,11 @@ def 
_invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInv ] def get_available_models(self) -> list[tuple[str, str, list[Any]]]: + if not self.runtime: + raise ValueError("Runtime is required") model_provider_service = ModelProviderService() - models = model_provider_service.get_models_by_model_type(tenant_id=self.runtime.tenant_id, model_type="tts") + tid: str = self.runtime.tenant_id or "" + models = model_provider_service.get_models_by_model_type(tenant_id=tid, model_type="tts") items = [] for provider_model in models: provider = provider_model.provider @@ -62,6 +67,8 @@ def get_runtime_parameters(self) -> list[ToolParameter]: ToolParameter( name=f"voice#{provider}#{model}", label=I18nObject(en_US=f"Voice of {model}({provider})"), + human_description=I18nObject(en_US=f"Select a voice for {model} model"), + placeholder=I18nObject(en_US="Select a voice"), type=ToolParameter.ToolParameterType.SELECT, form=ToolParameter.ToolParameterForm.FORM, options=[ @@ -83,6 +90,7 @@ def get_runtime_parameters(self) -> list[ToolParameter]: type=ToolParameter.ToolParameterType.SELECT, form=ToolParameter.ToolParameterForm.FORM, required=True, + placeholder=I18nObject(en_US="Select a model", zh_Hans="选择模型"), options=options, ), ) diff --git a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py index a04f5c0fe9f1af..b224ff5258c879 100644 --- a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py +++ b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py @@ -2,8 +2,8 @@ import logging from typing import Any, Union -import boto3 -from botocore.exceptions import BotoCoreError +import boto3 # type: ignore +from botocore.exceptions import BotoCoreError # type: ignore from pydantic import BaseModel, Field from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py index 989608122185c8..b6d16d2759c30e 100644 --- a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py +++ b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py @@ -1,7 +1,7 @@ import json from typing import Any, Union -import boto3 +import boto3 # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py index f43f3b6fe05694..01bc596346c231 100644 --- a/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py +++ b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py @@ -2,7 +2,7 @@ import logging from typing import Any, Union -import boto3 +import boto3 # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py index bffcd058b509bf..715b1ddeddcae5 100644 --- a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py @@ -2,7 +2,7 @@ import operator from typing import Any, Union -import boto3 +import boto3 # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool 
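The audio hunks apply the same two guards this patch uses across most builtin tools: raise early when the tool's optional runtime is missing, and give dict lookups an explicit default so an absent key cannot leak None into later calls. A compressed sketch with stand-in classes (the real BuiltinTool carries much more state):

from typing import Any, Optional

class Runtime:
    def __init__(self, tenant_id: Optional[str] = None) -> None:
        self.tenant_id = tenant_id

class Tool:
    def __init__(self, runtime: Optional[Runtime] = None) -> None:
        self.runtime = runtime

    def invoke(self, tool_parameters: dict[str, Any]) -> str:
        if not self.runtime:  # narrows Optional[Runtime] for the lines below
            raise ValueError("Runtime is required")
        tenant_id = self.runtime.tenant_id or ""
        # With dict[str, Any] the value is Any either way; the default is a
        # runtime safeguard so a missing key yields "" instead of None.
        text = tool_parameters.get("text", "")
        return f"{tenant_id}:{text}"

print(Tool(Runtime("tenant-1")).invoke({"text": "hello"}))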
@@ -10,8 +10,8 @@ class SageMakerReRankTool(BuiltinTool): sagemaker_client: Any = None - sagemaker_endpoint: str = None - topk: int = None + sagemaker_endpoint: str | None = None + topk: int | None = None def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str): inputs = [query_input] * len(docs) diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py index 1fafe09b4d96bf..55cff89798a4eb 100644 --- a/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py @@ -2,7 +2,7 @@ from enum import Enum from typing import Any, Optional, Union -import boto3 +import boto3 # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool @@ -17,7 +17,7 @@ class TTSModelType(Enum): class SageMakerTTSTool(BuiltinTool): sagemaker_client: Any = None - sagemaker_endpoint: str = None + sagemaker_endpoint: str | None = None s3_client: Any = None comprehend_client: Any = None diff --git a/api/core/tools/provider/builtin/cogview/tools/cogvideo.py b/api/core/tools/provider/builtin/cogview/tools/cogvideo.py index 7f69e833cb9046..a60062ca66abbf 100644 --- a/api/core/tools/provider/builtin/cogview/tools/cogvideo.py +++ b/api/core/tools/provider/builtin/cogview/tools/cogvideo.py @@ -1,6 +1,6 @@ from typing import Any, Union -from zhipuai import ZhipuAI +from zhipuai import ZhipuAI # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.py b/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.py index a521f1c28a41b6..3e24b74d2598a7 100644 --- a/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.py +++ b/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.py @@ -1,7 +1,7 @@ from typing import Any, Union import httpx -from zhipuai import ZhipuAI +from zhipuai import ZhipuAI # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/cogview/tools/cogview3.py b/api/core/tools/provider/builtin/cogview/tools/cogview3.py index 12b4173fa40270..9aa781709a726c 100644 --- a/api/core/tools/provider/builtin/cogview/tools/cogview3.py +++ b/api/core/tools/provider/builtin/cogview/tools/cogview3.py @@ -1,7 +1,7 @@ import random from typing import Any, Union -from zhipuai import ZhipuAI +from zhipuai import ZhipuAI # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/feishu_base/tools/search_records.py b/api/core/tools/provider/builtin/feishu_base/tools/search_records.py index c959496735e747..d58b42b82029ce 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/search_records.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/search_records.py @@ -7,18 +7,22 @@ class SearchRecordsTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = 
FeishuRequest(app_id, app_secret) - app_token = tool_parameters.get("app_token") - table_id = tool_parameters.get("table_id") - table_name = tool_parameters.get("table_name") - view_id = tool_parameters.get("view_id") - field_names = tool_parameters.get("field_names") - sort = tool_parameters.get("sort") - filters = tool_parameters.get("filter") - page_token = tool_parameters.get("page_token") + app_token = tool_parameters.get("app_token", "") + table_id = tool_parameters.get("table_id", "") + table_name = tool_parameters.get("table_name", "") + view_id = tool_parameters.get("view_id", "") + field_names = tool_parameters.get("field_names", "") + sort = tool_parameters.get("sort", "") + filters = tool_parameters.get("filter", "") + page_token = tool_parameters.get("page_token", "") automatic_fields = tool_parameters.get("automatic_fields", False) user_id_type = tool_parameters.get("user_id_type", "open_id") page_size = tool_parameters.get("page_size", 20) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/update_records.py b/api/core/tools/provider/builtin/feishu_base/tools/update_records.py index a7b036387500b0..31cf8e18d85b8d 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/update_records.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/update_records.py @@ -7,14 +7,18 @@ class UpdateRecordsTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) - app_token = tool_parameters.get("app_token") - table_id = tool_parameters.get("table_id") - table_name = tool_parameters.get("table_name") - records = tool_parameters.get("records") + app_token = tool_parameters.get("app_token", "") + table_id = tool_parameters.get("table_id", "") + table_name = tool_parameters.get("table_name", "") + records = tool_parameters.get("records", "") user_id_type = tool_parameters.get("user_id_type", "open_id") res = client.update_records(app_token, table_id, table_name, records, user_id_type) diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.py b/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.py index 8f83aea5abbe3d..80287feca176e1 100644 --- a/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.py +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.py @@ -7,12 +7,16 @@ class AddEventAttendeesTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) - event_id = tool_parameters.get("event_id") - attendee_phone_or_email = tool_parameters.get("attendee_phone_or_email") + event_id = tool_parameters.get("event_id", "") + attendee_phone_or_email = tool_parameters.get("attendee_phone_or_email", "") need_notification = tool_parameters.get("need_notification", True) res = client.add_event_attendees(event_id, 
attendee_phone_or_email, need_notification) diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.py b/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.py index 144889692f9055..02e9b445219ac8 100644 --- a/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.py +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.py @@ -7,11 +7,15 @@ class DeleteEventTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) - event_id = tool_parameters.get("event_id") + event_id = tool_parameters.get("event_id", "") need_notification = tool_parameters.get("need_notification", True) res = client.delete_event(event_id, need_notification) diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.py b/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.py index a2cd5a8b17d0af..4dafe4b3baf0cd 100644 --- a/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.py +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.py @@ -7,8 +7,12 @@ class GetPrimaryCalendarTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) user_id_type = tool_parameters.get("user_id_type", "open_id") diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.py b/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.py index 8815b4c9c871cd..2e8ca968b3cc42 100644 --- a/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.py +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.py @@ -7,14 +7,18 @@ class ListEventsTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) - start_time = tool_parameters.get("start_time") - end_time = tool_parameters.get("end_time") - page_token = tool_parameters.get("page_token") - page_size = tool_parameters.get("page_size") + start_time = tool_parameters.get("start_time", "") + end_time = tool_parameters.get("end_time", "") + page_token = tool_parameters.get("page_token", "") + page_size = tool_parameters.get("page_size", 50) res = client.list_events(start_time, end_time, page_token, page_size) diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.py b/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.py index 85bcb1d3f63847..b20eb6c31828e4 100644 --- 
a/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.py +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.py @@ -7,16 +7,20 @@ class UpdateEventTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) - event_id = tool_parameters.get("event_id") - summary = tool_parameters.get("summary") - description = tool_parameters.get("description") + event_id = tool_parameters.get("event_id", "") + summary = tool_parameters.get("summary", "") + description = tool_parameters.get("description", "") need_notification = tool_parameters.get("need_notification", True) - start_time = tool_parameters.get("start_time") - end_time = tool_parameters.get("end_time") + start_time = tool_parameters.get("start_time", "") + end_time = tool_parameters.get("end_time", "") auto_record = tool_parameters.get("auto_record", False) res = client.update_event(event_id, summary, description, need_notification, start_time, end_time, auto_record) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/create_document.py b/api/core/tools/provider/builtin/feishu_document/tools/create_document.py index 090a0828e89bbf..1533f594172878 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/create_document.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/create_document.py @@ -7,13 +7,17 @@ class CreateDocumentTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) - title = tool_parameters.get("title") - content = tool_parameters.get("content") - folder_token = tool_parameters.get("folder_token") + title = tool_parameters.get("title", "") + content = tool_parameters.get("content", "") + folder_token = tool_parameters.get("folder_token", "") res = client.create_document(title, content, folder_token) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.py b/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.py index dd57c6870d0ba9..8ea68a2ed87855 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.py @@ -7,11 +7,15 @@ class ListDocumentBlockTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) - document_id = tool_parameters.get("document_id") + document_id = tool_parameters.get("document_id", "") page_token = 
tool_parameters.get("page_token", "") user_id_type = tool_parameters.get("user_id_type", "open_id") page_size = tool_parameters.get("page_size", 500) diff --git a/api/core/tools/provider/builtin/json_process/tools/delete.py b/api/core/tools/provider/builtin/json_process/tools/delete.py index fcab3d71a93cf9..06f6cacd5d6126 100644 --- a/api/core/tools/provider/builtin/json_process/tools/delete.py +++ b/api/core/tools/provider/builtin/json_process/tools/delete.py @@ -1,7 +1,7 @@ import json from typing import Any, Union -from jsonpath_ng import parse +from jsonpath_ng import parse # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/json_process/tools/insert.py b/api/core/tools/provider/builtin/json_process/tools/insert.py index 793c74e5f9df51..e825329a6d8f61 100644 --- a/api/core/tools/provider/builtin/json_process/tools/insert.py +++ b/api/core/tools/provider/builtin/json_process/tools/insert.py @@ -1,7 +1,7 @@ import json from typing import Any, Union -from jsonpath_ng import parse +from jsonpath_ng import parse # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/json_process/tools/parse.py b/api/core/tools/provider/builtin/json_process/tools/parse.py index f91432ee77f488..193017ba9a7c53 100644 --- a/api/core/tools/provider/builtin/json_process/tools/parse.py +++ b/api/core/tools/provider/builtin/json_process/tools/parse.py @@ -1,7 +1,7 @@ import json from typing import Any, Union -from jsonpath_ng import parse +from jsonpath_ng import parse # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/json_process/tools/replace.py b/api/core/tools/provider/builtin/json_process/tools/replace.py index 383825c2d0b259..feca0d8a7c2783 100644 --- a/api/core/tools/provider/builtin/json_process/tools/replace.py +++ b/api/core/tools/provider/builtin/json_process/tools/replace.py @@ -1,7 +1,7 @@ import json from typing import Any, Union -from jsonpath_ng import parse +from jsonpath_ng import parse # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/maths/tools/eval_expression.py b/api/core/tools/provider/builtin/maths/tools/eval_expression.py index 0c5b5e41cbe1e1..d3a497d1cd5c54 100644 --- a/api/core/tools/provider/builtin/maths/tools/eval_expression.py +++ b/api/core/tools/provider/builtin/maths/tools/eval_expression.py @@ -1,7 +1,7 @@ import logging from typing import Any, Union -import numexpr as ne +import numexpr as ne # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py index db4adfd4ad4629..6473c509e1f4c2 100644 --- a/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py +++ b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py @@ -1,4 +1,4 @@ -from novita_client import ( +from novita_client import ( # type: ignore Txt2ImgV3Embedding, Txt2ImgV3HiresFix, Txt2ImgV3LoRA, diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py 
b/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py index 0b4f2edff3607f..097b234bd50640 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py @@ -2,7 +2,7 @@ from copy import deepcopy from typing import Any, Union -from novita_client import ( +from novita_client import ( # type: ignore NovitaClient, ) diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py index 9c61eab9f95784..297a27abba667a 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py @@ -2,7 +2,7 @@ from copy import deepcopy from typing import Any, Union -from novita_client import ( +from novita_client import ( # type: ignore NovitaClient, ) diff --git a/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py b/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py index 165e93956eff38..704e0015d961a3 100644 --- a/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py +++ b/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py @@ -13,7 +13,7 @@ with warnings.catch_warnings(): warnings.simplefilter("ignore") - from pydub import AudioSegment + from pydub import AudioSegment # type: ignore class PodcastAudioGeneratorTool(BuiltinTool): diff --git a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py index d8ca20bde6ffc9..4a47c4211f4fd4 100644 --- a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py +++ b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py @@ -2,10 +2,10 @@ import logging from typing import Any, Union -from qrcode.constants import ERROR_CORRECT_H, ERROR_CORRECT_L, ERROR_CORRECT_M, ERROR_CORRECT_Q -from qrcode.image.base import BaseImage -from qrcode.image.pure import PyPNGImage -from qrcode.main import QRCode +from qrcode.constants import ERROR_CORRECT_H, ERROR_CORRECT_L, ERROR_CORRECT_M, ERROR_CORRECT_Q # type: ignore +from qrcode.image.base import BaseImage # type: ignore +from qrcode.image.pure import PyPNGImage # type: ignore +from qrcode.main import QRCode # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/transcript/tools/transcript.py b/api/core/tools/provider/builtin/transcript/tools/transcript.py index 27f700efbd6936..ac7565d9eef5b8 100644 --- a/api/core/tools/provider/builtin/transcript/tools/transcript.py +++ b/api/core/tools/provider/builtin/transcript/tools/transcript.py @@ -1,7 +1,7 @@ from typing import Any, Union from urllib.parse import parse_qs, urlparse -from youtube_transcript_api import YouTubeTranscriptApi +from youtube_transcript_api import YouTubeTranscriptApi # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/twilio/tools/send_message.py b/api/core/tools/provider/builtin/twilio/tools/send_message.py index 5ee839baa56f02..98a108f4ec7e93 100644 --- a/api/core/tools/provider/builtin/twilio/tools/send_message.py +++ b/api/core/tools/provider/builtin/twilio/tools/send_message.py @@ -37,7 +37,7 @@ class 
TwilioAPIWrapper(BaseModel): def set_validator(cls, values: dict) -> dict: """Validate that api key and python package exists in environment.""" try: - from twilio.rest import Client + from twilio.rest import Client # type: ignore except ImportError: raise ImportError("Could not import twilio python package. Please install it with `pip install twilio`.") account_sid = values.get("account_sid") diff --git a/api/core/tools/provider/builtin/twilio/twilio.py b/api/core/tools/provider/builtin/twilio/twilio.py index b1d100aad93dba..649e03d185121c 100644 --- a/api/core/tools/provider/builtin/twilio/twilio.py +++ b/api/core/tools/provider/builtin/twilio/twilio.py @@ -1,7 +1,7 @@ from typing import Any -from twilio.base.exceptions import TwilioRestException -from twilio.rest import Client +from twilio.base.exceptions import TwilioRestException # type: ignore +from twilio.rest import Client # type: ignore from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController diff --git a/api/core/tools/provider/builtin/vanna/tools/vanna.py b/api/core/tools/provider/builtin/vanna/tools/vanna.py index 1c7cb39c92b40b..a6afd2dddfc63a 100644 --- a/api/core/tools/provider/builtin/vanna/tools/vanna.py +++ b/api/core/tools/provider/builtin/vanna/tools/vanna.py @@ -1,6 +1,6 @@ from typing import Any, Union -from vanna.remote import VannaDefault +from vanna.remote import VannaDefault # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.errors import ToolProviderCredentialValidationError @@ -14,6 +14,9 @@ def _invoke( """ invoke tools """ + # Ensure runtime and credentials + if not self.runtime or not self.runtime.credentials: + raise ToolProviderCredentialValidationError("Tool runtime or credentials are missing") api_key = self.runtime.credentials.get("api_key", None) if not api_key: raise ToolProviderCredentialValidationError("Please input api key") diff --git a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py index cb88e9519a4346..edb96e722f7f33 100644 --- a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py +++ b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py @@ -1,6 +1,6 @@ from typing import Any, Optional, Union -import wikipedia +import wikipedia # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/yahoo/tools/analytics.py b/api/core/tools/provider/builtin/yahoo/tools/analytics.py index f044fbe5404b0a..95a65ba22fc8af 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/analytics.py +++ b/api/core/tools/provider/builtin/yahoo/tools/analytics.py @@ -3,7 +3,7 @@ import pandas as pd from requests.exceptions import HTTPError, ReadTimeout -from yfinance import download +from yfinance import download # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/yahoo/tools/news.py b/api/core/tools/provider/builtin/yahoo/tools/news.py index ff820430f9f366..c9ae0c4ca7fcc6 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/news.py +++ b/api/core/tools/provider/builtin/yahoo/tools/news.py @@ -1,6 +1,6 @@ from typing import Any, Union -import yfinance +import yfinance # type: ignore from requests.exceptions import HTTPError, 
ReadTimeout from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/provider/builtin/yahoo/tools/ticker.py b/api/core/tools/provider/builtin/yahoo/tools/ticker.py index dfc7e460473c33..74d0d25addf04b 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/ticker.py +++ b/api/core/tools/provider/builtin/yahoo/tools/ticker.py @@ -1,7 +1,7 @@ from typing import Any, Union from requests.exceptions import HTTPError, ReadTimeout -from yfinance import Ticker +from yfinance import Ticker # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/youtube/tools/videos.py b/api/core/tools/provider/builtin/youtube/tools/videos.py index 95dec2eac9a752..a24fe89679b29b 100644 --- a/api/core/tools/provider/builtin/youtube/tools/videos.py +++ b/api/core/tools/provider/builtin/youtube/tools/videos.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import Any, Union -from googleapiclient.discovery import build +from googleapiclient.discovery import build # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index 955a0add3b4513..61de75ac5e2ccd 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -1,6 +1,6 @@ from abc import abstractmethod from os import listdir, path -from typing import Any +from typing import Any, Optional from core.helper.module_import_helper import load_single_subclass_from_source from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType @@ -50,6 +50,8 @@ def _get_builtin_tools(self) -> list[Tool]: """ if self.tools: return self.tools + if not self.identity: + return [] provider = self.identity.name tool_path = path.join(path.dirname(path.realpath(__file__)), "builtin", provider, "tools") @@ -86,7 +88,7 @@ def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]: return self.credentials_schema.copy() - def get_tools(self) -> list[Tool]: + def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]: """ returns a list of tools that the provider can provide @@ -94,11 +96,14 @@ def get_tools(self) -> list[Tool]: """ return self._get_builtin_tools() - def get_tool(self, tool_name: str) -> Tool: + def get_tool(self, tool_name: str) -> Optional[Tool]: """ returns the tool that the provider can provide """ - return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) + tools = self.get_tools() + if tools is None: + raise ValueError("tools not found") + return next((t for t in tools if t.identity and t.identity.name == tool_name), None) def get_parameters(self, tool_name: str) -> list[ToolParameter]: """ @@ -107,10 +112,13 @@ def get_parameters(self, tool_name: str) -> list[ToolParameter]: :param tool_name: the name of the tool, defined in `get_tools` :return: list of parameters """ - tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) + tools = self.get_tools() + if tools is None: + raise ToolNotFoundError(f"tool {tool_name} not found") + tool = next((t for t in tools if t.identity and t.identity.name == tool_name), None) if tool is None: raise ToolNotFoundError(f"tool {tool_name} not found") - return tool.parameters + return tool.parameters or 
[] @property def need_credentials(self) -> bool: @@ -144,6 +152,8 @@ def _get_tool_labels(self) -> list[ToolLabelEnum]: """ returns the labels of the provider """ + if self.identity is None: + return [] return self.identity.tags or [] def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None: @@ -159,56 +169,56 @@ def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dic for parameter in tool_parameters_schema: tool_parameters_need_to_validate[parameter.name] = parameter - for parameter in tool_parameters: - if parameter not in tool_parameters_need_to_validate: - raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}") + for parameter_name in tool_parameters: + if parameter_name not in tool_parameters_need_to_validate: + raise ToolParameterValidationError(f"parameter {parameter_name} not found in tool {tool_name}") # check type - parameter_schema = tool_parameters_need_to_validate[parameter] + parameter_schema = tool_parameters_need_to_validate[parameter_name] if parameter_schema.type == ToolParameter.ToolParameterType.STRING: - if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f"parameter {parameter} should be string") + if not isinstance(tool_parameters[parameter_name], str): + raise ToolParameterValidationError(f"parameter {parameter_name} should be string") elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER: - if not isinstance(tool_parameters[parameter], int | float): - raise ToolParameterValidationError(f"parameter {parameter} should be number") + if not isinstance(tool_parameters[parameter_name], int | float): + raise ToolParameterValidationError(f"parameter {parameter_name} should be number") - if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min: + if parameter_schema.min is not None and tool_parameters[parameter_name] < parameter_schema.min: raise ToolParameterValidationError( - f"parameter {parameter} should be greater than {parameter_schema.min}" + f"parameter {parameter_name} should be greater than {parameter_schema.min}" ) - if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max: + if parameter_schema.max is not None and tool_parameters[parameter_name] > parameter_schema.max: raise ToolParameterValidationError( - f"parameter {parameter} should be less than {parameter_schema.max}" + f"parameter {parameter_name} should be less than {parameter_schema.max}" ) elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN: - if not isinstance(tool_parameters[parameter], bool): - raise ToolParameterValidationError(f"parameter {parameter} should be boolean") + if not isinstance(tool_parameters[parameter_name], bool): + raise ToolParameterValidationError(f"parameter {parameter_name} should be boolean") elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT: - if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f"parameter {parameter} should be string") + if not isinstance(tool_parameters[parameter_name], str): + raise ToolParameterValidationError(f"parameter {parameter_name} should be string") options = parameter_schema.options if not isinstance(options, list): - raise ToolParameterValidationError(f"parameter {parameter} options should be list") + raise ToolParameterValidationError(f"parameter {parameter_name} options should be list") - if tool_parameters[parameter] not in [x.value for x in options]: - 
raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}") + if tool_parameters[parameter_name] not in [x.value for x in options]: + raise ToolParameterValidationError(f"parameter {parameter_name} should be one of {options}") - tool_parameters_need_to_validate.pop(parameter) + tool_parameters_need_to_validate.pop(parameter_name) - for parameter in tool_parameters_need_to_validate: - parameter_schema = tool_parameters_need_to_validate[parameter] + for parameter_name in tool_parameters_need_to_validate: + parameter_schema = tool_parameters_need_to_validate[parameter_name] if parameter_schema.required: - raise ToolParameterValidationError(f"parameter {parameter} is required") + raise ToolParameterValidationError(f"parameter {parameter_name} is required") # the parameter is not set currently, set the default value if needed if parameter_schema.default is not None: default_value = parameter_schema.type.cast_value(parameter_schema.default) - tool_parameters[parameter] = default_value + tool_parameters[parameter_name] = default_value def validate_credentials(self, credentials: dict[str, Any]) -> None: """ diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py index bc05a11562b717..e35207e4f06404 100644 --- a/api/core/tools/provider/tool_provider.py +++ b/api/core/tools/provider/tool_provider.py @@ -24,10 +24,12 @@ def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]: :return: the credentials schema """ + if self.credentials_schema is None: + return {} return self.credentials_schema.copy() @abstractmethod - def get_tools(self) -> list[Tool]: + def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]: """ returns a list of tools that the provider can provide @@ -36,7 +38,7 @@ def get_tools(self) -> list[Tool]: pass @abstractmethod - def get_tool(self, tool_name: str) -> Tool: + def get_tool(self, tool_name: str) -> Optional[Tool]: """ returns a tool that the provider can provide @@ -51,10 +53,13 @@ def get_parameters(self, tool_name: str) -> list[ToolParameter]: :param tool_name: the name of the tool, defined in `get_tools` :return: list of parameters """ - tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) + tools = self.get_tools() + if tools is None: + raise ToolNotFoundError(f"tool {tool_name} not found") + tool = next((t for t in tools if t.identity and t.identity.name == tool_name), None) if tool is None: raise ToolNotFoundError(f"tool {tool_name} not found") - return tool.parameters + return tool.parameters or [] @property def provider_type(self) -> ToolProviderType: @@ -78,55 +83,55 @@ def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dic for parameter in tool_parameters_schema: tool_parameters_need_to_validate[parameter.name] = parameter - for parameter in tool_parameters: - if parameter not in tool_parameters_need_to_validate: - raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}") + for tool_parameter in tool_parameters: + if tool_parameter not in tool_parameters_need_to_validate: + raise ToolParameterValidationError(f"parameter {tool_parameter} not found in tool {tool_name}") # check type - parameter_schema = tool_parameters_need_to_validate[parameter] + parameter_schema = tool_parameters_need_to_validate[tool_parameter] if parameter_schema.type == ToolParameter.ToolParameterType.STRING: - if not isinstance(tool_parameters[parameter], str): - raise 
ToolParameterValidationError(f"parameter {parameter} should be string") + if not isinstance(tool_parameters[tool_parameter], str): + raise ToolParameterValidationError(f"parameter {tool_parameter} should be string") elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER: - if not isinstance(tool_parameters[parameter], int | float): - raise ToolParameterValidationError(f"parameter {parameter} should be number") + if not isinstance(tool_parameters[tool_parameter], int | float): + raise ToolParameterValidationError(f"parameter {tool_parameter} should be number") - if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min: + if parameter_schema.min is not None and tool_parameters[tool_parameter] < parameter_schema.min: raise ToolParameterValidationError( - f"parameter {parameter} should be greater than {parameter_schema.min}" + f"parameter {tool_parameter} should be greater than {parameter_schema.min}" ) - if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max: + if parameter_schema.max is not None and tool_parameters[tool_parameter] > parameter_schema.max: raise ToolParameterValidationError( - f"parameter {parameter} should be less than {parameter_schema.max}" + f"parameter {tool_parameter} should be less than {parameter_schema.max}" ) elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN: - if not isinstance(tool_parameters[parameter], bool): - raise ToolParameterValidationError(f"parameter {parameter} should be boolean") + if not isinstance(tool_parameters[tool_parameter], bool): + raise ToolParameterValidationError(f"parameter {tool_parameter} should be boolean") elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT: - if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f"parameter {parameter} should be string") + if not isinstance(tool_parameters[tool_parameter], str): + raise ToolParameterValidationError(f"parameter {tool_parameter} should be string") options = parameter_schema.options if not isinstance(options, list): - raise ToolParameterValidationError(f"parameter {parameter} options should be list") + raise ToolParameterValidationError(f"parameter {tool_parameter} options should be list") - if tool_parameters[parameter] not in [x.value for x in options]: - raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}") + if tool_parameters[tool_parameter] not in [x.value for x in options]: + raise ToolParameterValidationError(f"parameter {tool_parameter} should be one of {options}") - tool_parameters_need_to_validate.pop(parameter) + tool_parameters_need_to_validate.pop(tool_parameter) - for parameter in tool_parameters_need_to_validate: - parameter_schema = tool_parameters_need_to_validate[parameter] + for tool_parameter_validate in tool_parameters_need_to_validate: + parameter_schema = tool_parameters_need_to_validate[tool_parameter_validate] if parameter_schema.required: - raise ToolParameterValidationError(f"parameter {parameter} is required") + raise ToolParameterValidationError(f"parameter {tool_parameter_validate} is required") # the parameter is not set currently, set the default value if needed if parameter_schema.default is not None: - tool_parameters[parameter] = parameter_schema.type.cast_value(parameter_schema.default) + tool_parameters[tool_parameter_validate] = parameter_schema.type.cast_value(parameter_schema.default) def validate_credentials_format(self, credentials: dict[str, Any]) -> None: """ 
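The provider lookups above swap next(filter(lambda ...)) for a generator expression that tests identity before touching identity.name; with identity typed as Optional, the old lambda dereferenced a possible None. The same lookup in isolation, with stand-in classes:

from typing import Optional

class Identity:
    def __init__(self, name: str) -> None:
        self.name = name

class Tool:
    def __init__(self, identity: Optional[Identity] = None) -> None:
        self.identity = identity

def get_tool(tools: list[Tool], tool_name: str) -> Optional[Tool]:
    # The truthiness check narrows t.identity before .name is read, which
    # satisfies mypy and skips half-initialized tools at runtime.
    return next((t for t in tools if t.identity and t.identity.name == tool_name), None)

tools = [Tool(), Tool(Identity("search"))]
assert get_tool(tools, "search") is tools[1]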
@@ -144,6 +149,8 @@ def validate_credentials_format(self, credentials: dict[str, Any]) -> None: for credential_name in credentials: if credential_name not in credentials_need_to_validate: + if self.identity is None: + raise ValueError("identity is not set") raise ToolProviderCredentialValidationError( f"credential {credential_name} not found in provider {self.identity.name}" ) diff --git a/api/core/tools/provider/workflow_tool_provider.py b/api/core/tools/provider/workflow_tool_provider.py index 5656dd09ab8c94..17fe2e20cf282e 100644 --- a/api/core/tools/provider/workflow_tool_provider.py +++ b/api/core/tools/provider/workflow_tool_provider.py @@ -11,6 +11,7 @@ ToolProviderType, ) from core.tools.provider.tool_provider import ToolProviderController +from core.tools.tool.tool import Tool from core.tools.tool.workflow_tool import WorkflowTool from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from extensions.ext_database import db @@ -116,6 +117,7 @@ def fetch_workflow_variable(variable_name: str): llm_description=parameter.description, required=variable.required, options=options, + placeholder=I18nObject(en_US="", zh_Hans=""), ) ) elif features.file_upload: @@ -128,6 +130,7 @@ def fetch_workflow_variable(variable_name: str): llm_description=parameter.description, required=False, form=parameter.form, + placeholder=I18nObject(en_US="", zh_Hans=""), ) ) else: @@ -157,7 +160,7 @@ def fetch_workflow_variable(variable_name: str): label=db_provider.label, ) - def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]: + def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]: """ fetch tools from database @@ -168,7 +171,7 @@ def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]: if self.tools is not None: return self.tools - db_providers: WorkflowToolProvider = ( + db_providers: Optional[WorkflowToolProvider] = ( db.session.query(WorkflowToolProvider) .filter( WorkflowToolProvider.tenant_id == tenant_id, @@ -179,12 +182,14 @@ def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]: if not db_providers: return [] + if not db_providers.app: + raise ValueError("app not found") self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)] return self.tools - def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: + def get_tool(self, tool_name: str) -> Optional[Tool]: """ get tool by name @@ -195,6 +200,8 @@ def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: return None for tool in self.tools: + if tool.identity is None: + continue if tool.identity.name == tool_name: return tool diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index 48aac75dbb4115..9a00450290a660 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -32,11 +32,13 @@ def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool": :param meta: the meta data of a tool call processing, tenant_id is required :return: the new tool """ + if self.api_bundle is None: + raise ValueError("api_bundle is required") return self.__class__( identity=self.identity.model_copy() if self.identity else None, parameters=self.parameters.copy() if self.parameters else None, description=self.description.model_copy() if self.description else None, - api_bundle=self.api_bundle.model_copy() if self.api_bundle else None, + api_bundle=self.api_bundle.model_copy(), runtime=Tool.Runtime(**runtime), ) @@ -61,6 +63,8 @@ def tool_provider_type(self) -> ToolProviderType: 
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]: headers = {} + if self.runtime is None: + raise ValueError("runtime is required") credentials = self.runtime.credentials or {} if "auth_type" not in credentials: @@ -88,7 +92,7 @@ def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]: headers[api_key_header] = credentials["api_key_value"] - needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required] + needed_parameters = [parameter for parameter in (self.api_bundle.parameters or []) if parameter.required] for parameter in needed_parameters: if parameter.required and parameter.name not in parameters: raise ToolParameterValidationError(f"Missing required parameter {parameter.name}") @@ -137,7 +141,8 @@ def do_http_request( params = {} path_params = {} - body = {} + # FIXME: body should be a dict[str, Any] but it changed a lot in this function + body: Any = {} cookies = {} files = [] @@ -198,7 +203,7 @@ def do_http_request( body = body if method in {"get", "head", "post", "put", "delete", "patch"}: - response = getattr(ssrf_proxy, method)( + response: httpx.Response = getattr(ssrf_proxy, method)( url, params=params, headers=headers, @@ -288,6 +293,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe """ invoke http request """ + response: httpx.Response | str = "" # assemble request headers = self.assembling_request(tool_parameters) diff --git a/api/core/tools/tool/builtin_tool.py b/api/core/tools/tool/builtin_tool.py index e2a81ed0a36edd..adda4297f38e8a 100644 --- a/api/core/tools/tool/builtin_tool.py +++ b/api/core/tools/tool/builtin_tool.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, cast from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage @@ -32,9 +32,12 @@ def invoke_model(self, user_id: str, prompt_messages: list[PromptMessage], stop: :return: the model result """ # invoke model + if self.runtime is None or self.identity is None: + raise ValueError("runtime and identity are required") + return ModelInvocationUtils.invoke( user_id=user_id, - tenant_id=self.runtime.tenant_id, + tenant_id=self.runtime.tenant_id or "", tool_type="builtin", tool_name=self.identity.name, prompt_messages=prompt_messages, @@ -50,8 +53,11 @@ def get_max_tokens(self) -> int: :param model_config: the model config :return: the max tokens """ + if self.runtime is None: + raise ValueError("runtime is required") + return ModelInvocationUtils.get_max_llm_context_tokens( - tenant_id=self.runtime.tenant_id, + tenant_id=self.runtime.tenant_id or "", ) def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int: @@ -61,7 +67,12 @@ def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int: :param prompt_messages: the prompt messages :return: the tokens """ - return ModelInvocationUtils.calculate_tokens(tenant_id=self.runtime.tenant_id, prompt_messages=prompt_messages) + if self.runtime is None: + raise ValueError("runtime is required") + + return ModelInvocationUtils.calculate_tokens( + tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages + ) def summary(self, user_id: str, content: str) -> str: max_tokens = self.get_max_tokens() @@ -81,7 +92,7 @@ def summarize(content: str) -> str: stop=[], ) - return summary.message.content + return cast(str, summary.message.content) lines = content.split("\n") new_lines = [] 
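The api_tool.py and builtin_tool.py hunks above use two moves that recur throughout this patch: raise early when an Optional attribute (runtime, identity, api_bundle) is None instead of letting attribute access trip the union-attr check, and reach for typing.cast when the author knows more than the declared union (message.content here). A hedged sketch with made-up names, not the real classes:

    from typing import Optional, cast

    class Runtime:  # illustrative stand-in for Tool.Runtime
        def __init__(self, tenant_id: Optional[str] = None) -> None:
            self.tenant_id = tenant_id

    class ExampleTool:
        runtime: Optional[Runtime] = None

        def tenant_id(self) -> str:
            if self.runtime is None:  # fail fast; mypy then sees Runtime, not Optional[Runtime]
                raise ValueError("runtime is required")
            return self.runtime.tenant_id or ""  # `or ""` collapses Optional[str] to str

    def content_as_str(content: object) -> str:
        # cast() is a no-op at runtime; it only asserts the expected union member to mypy
        return cast(str, content)

The `or ""` fallback trades a possible None for an empty string, which is why the patch pairs it with the explicit guard above it rather than using it alone.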
@@ -102,16 +113,16 @@ def summarize(content: str) -> str: # merge lines into messages with max tokens messages: list[str] = [] - for i in new_lines: + for j in new_lines: if len(messages) == 0: - messages.append(i) + messages.append(j) else: - if len(messages[-1]) + len(i) < max_tokens * 0.5: - messages[-1] += i - if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7: - messages.append(i) + if len(messages[-1]) + len(j) < max_tokens * 0.5: + messages[-1] += j + if get_prompt_tokens(messages[-1] + j) > max_tokens * 0.7: + messages.append(j) else: - messages[-1] += i + messages[-1] += j summaries = [] for i in range(len(messages)): diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index ab7b40a2536db8..a4afea4b9df429 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -1,4 +1,5 @@ import threading +from typing import Any from flask import Flask, current_app from pydantic import BaseModel, Field @@ -7,13 +8,14 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.models.document import Document as RagDocument from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment -default_retrieval_model = { +default_retrieval_model: dict[str, Any] = { "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, @@ -44,12 +46,12 @@ def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs): def _run(self, query: str) -> str: threads = [] - all_documents = [] + all_documents: list[RagDocument] = [] for dataset_id in self.dataset_ids: retrieval_thread = threading.Thread( target=self._retriever, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "dataset_id": dataset_id, "query": query, "all_documents": all_documents, @@ -77,11 +79,11 @@ def _run(self, query: str) -> str: document_score_list = {} for item in all_documents: - if item.metadata.get("score"): + if item.metadata and item.metadata.get("score"): document_score_list[item.metadata["doc_id"]] = item.metadata["score"] document_context_list = [] - index_node_ids = [document.metadata["doc_id"] for document in all_documents] + index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata] segments = DocumentSegment.query.filter( DocumentSegment.dataset_id.in_(self.dataset_ids), DocumentSegment.completed_at.isnot(None), @@ -139,6 +141,7 @@ def _run(self, query: str) -> str: hit_callback.return_retriever_resource_info(context_list) return str("\n".join(document_context_list)) + return "" def _retriever( self, diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py index dad8c773579099..a4d2de3b1c8ef3 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py +++ 
b/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py @@ -1,7 +1,7 @@ from abc import abstractmethod from typing import Any, Optional -from msal_extensions.persistence import ABC +from msal_extensions.persistence import ABC # type: ignore from pydantic import BaseModel, ConfigDict from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index 987f94a35046e9..b382016473055d 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -1,3 +1,5 @@ +from typing import Any + from pydantic import BaseModel, Field from core.rag.datasource.retrieval_service import RetrievalService @@ -69,25 +71,27 @@ def _run(self, query: str) -> str: metadata=external_document.get("metadata"), provider="external", ) - document.metadata["score"] = external_document.get("score") - document.metadata["title"] = external_document.get("title") - document.metadata["dataset_id"] = dataset.id - document.metadata["dataset_name"] = dataset.name - results.append(document) + if document.metadata is not None: + document.metadata["score"] = external_document.get("score") + document.metadata["title"] = external_document.get("title") + document.metadata["dataset_id"] = dataset.id + document.metadata["dataset_name"] = dataset.name + results.append(document) # deal with external documents context_list = [] for position, item in enumerate(results, start=1): - source = { - "position": position, - "dataset_id": item.metadata.get("dataset_id"), - "dataset_name": item.metadata.get("dataset_name"), - "document_name": item.metadata.get("title"), - "data_source_type": "external", - "retriever_from": self.retriever_from, - "score": item.metadata.get("score"), - "title": item.metadata.get("title"), - "content": item.page_content, - } + if item.metadata is not None: + source = { + "position": position, + "dataset_id": item.metadata.get("dataset_id"), + "dataset_name": item.metadata.get("dataset_name"), + "document_name": item.metadata.get("title"), + "data_source_type": "external", + "retriever_from": self.retriever_from, + "score": item.metadata.get("score"), + "title": item.metadata.get("title"), + "content": item.page_content, + } context_list.append(source) for hit_callback in self.hit_callbacks: hit_callback.return_retriever_resource_info(context_list) @@ -95,7 +99,7 @@ def _run(self, query: str) -> str: return str("\n".join([item.page_content for item in results])) else: # get retrieval model , if the model is not setting , using default - retrieval_model = dataset.retrieval_model or default_retrieval_model + retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model if dataset.indexing_technique == "economy": # use keyword table query documents = RetrievalService.retrieve( @@ -113,11 +117,11 @@ def _run(self, query: str) -> str: score_threshold=retrieval_model.get("score_threshold", 0.0) if retrieval_model["score_threshold_enabled"] else 0.0, - reranking_model=retrieval_model.get("reranking_model", None) + reranking_model=retrieval_model.get("reranking_model") if retrieval_model["reranking_enable"] else None, reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", - weights=retrieval_model.get("weights", None), + weights=retrieval_model.get("weights"), ) else: documents = [] @@ -127,7 +131,7 @@ def _run(self, 
query: str) -> str: document_score_list = {} if dataset.indexing_technique != "economy": for item in documents: - if item.metadata.get("score"): + if item.metadata is not None and item.metadata.get("score"): document_score_list[item.metadata["doc_id"]] = item.metadata["score"] document_context_list = [] index_node_ids = [document.metadata["doc_id"] for document in documents] @@ -155,20 +159,21 @@ def _run(self, query: str) -> str: context_list = [] resource_number = 1 for segment in sorted_segments: - context = {} - document = Document.query.filter( + document_segment = Document.query.filter( Document.id == segment.document_id, Document.enabled == True, Document.archived == False, ).first() - if dataset and document: + if not document_segment: + continue + if dataset and document_segment: source = { "position": resource_number, "dataset_id": dataset.id, "dataset_name": dataset.name, - "document_id": document.id, - "document_name": document.name, - "data_source_type": document.data_source_type, + "document_id": document_segment.id, + "document_name": document_segment.name, + "data_source_type": document_segment.data_source_type, "segment_id": segment.id, "retriever_from": self.retriever_from, "score": document_score_list.get(segment.index_node_id, None), diff --git a/api/core/tools/tool/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever_tool.py index 3c9295c493c470..2d7e193e152645 100644 --- a/api/core/tools/tool/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever_tool.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Optional from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import InvokeFrom @@ -23,7 +23,7 @@ class DatasetRetrieverTool(Tool): def get_dataset_tools( tenant_id: str, dataset_ids: list[str], - retrieve_config: DatasetRetrieveConfigEntity, + retrieve_config: Optional[DatasetRetrieveConfigEntity], return_resource: bool, invoke_from: InvokeFrom, hit_callback: DatasetIndexToolCallbackHandler, @@ -51,6 +51,8 @@ def get_dataset_tools( invoke_from=invoke_from, hit_callback=hit_callback, ) + if retrieval_tools is None: + return [] # restore retrieve strategy retrieve_config.retrieve_strategy = original_retriever_mode @@ -83,6 +85,7 @@ def get_runtime_parameters(self) -> list[ToolParameter]: llm_description="Query for the dataset to be used to retrieve the dataset.", required=True, default="", + placeholder=I18nObject(en_US="", zh_Hans=""), ), ] @@ -102,7 +105,9 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe return self.create_text_message(text=result) - def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: + def validate_credentials( + self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False + ) -> str | None: """ validate the credentials for dataset retriever tool """ diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 8d4045038171a6..55f94d7619635b 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -91,7 +91,7 @@ def tool_provider_type(self) -> ToolProviderType: :return: the tool provider type """ - def load_variables(self, variables: ToolRuntimeVariablePool): + def load_variables(self, variables: ToolRuntimeVariablePool | None) -> None: """ load variables from database @@ -105,6 +105,8 @@ def set_image_variable(self, variable_name: str, image_key: str) -> None: """ if not self.variables: return + if 
self.identity is None: + return self.variables.set_file(self.identity.name, variable_name, image_key) @@ -114,6 +116,8 @@ def set_text_variable(self, variable_name: str, text: str) -> None: """ if not self.variables: return + if self.identity is None: + return self.variables.set_text(self.identity.name, variable_name, text) @@ -200,7 +204,11 @@ def list_default_image_variables(self) -> list[ToolRuntimeVariable]: def invoke(self, user_id: str, tool_parameters: Mapping[str, Any]) -> list[ToolInvokeMessage]: # update tool_parameters # TODO: Fix type error. + if self.runtime is None: + return [] if self.runtime.runtime_parameters: + # Convert Mapping to dict before updating + tool_parameters = dict(tool_parameters) tool_parameters.update(self.runtime.runtime_parameters) # try parse tool parameters into the correct type @@ -221,7 +229,7 @@ def _transform_tool_parameters_type(self, tool_parameters: Mapping[str, Any]) -> Transform tool parameters type """ # Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials - result = deepcopy(tool_parameters) + result: dict[str, Any] = deepcopy(dict(tool_parameters)) for parameter in self.parameters or []: if parameter.name in tool_parameters: result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name]) @@ -234,12 +242,15 @@ def _invoke( ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: pass - def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: + def validate_credentials( + self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False + ) -> str | None: """ validate the credentials :param credentials: the credentials :param parameters: the parameters + :param format_only: only return the formatted """ pass diff --git a/api/core/tools/tool/workflow_tool.py b/api/core/tools/tool/workflow_tool.py index 33b4ad021a5e7f..edff4a2d07cca2 100644 --- a/api/core/tools/tool/workflow_tool.py +++ b/api/core/tools/tool/workflow_tool.py @@ -68,20 +68,20 @@ def _invoke( if data.get("error"): raise Exception(data.get("error")) - result = [] + r = [] outputs = data.get("outputs") if outputs == None: outputs = {} else: - outputs, files = self._extract_files(outputs) - for file in files: - result.append(self.create_file_message(file)) + outputs, extracted_files = self._extract_files(outputs) + for f in extracted_files: + r.append(self.create_file_message(f)) - result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False))) - result.append(self.create_json_message(outputs)) + r.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False))) + r.append(self.create_json_message(outputs)) - return result + return r def _get_user(self, user_id: str) -> Union[EndUser, Account]: """ diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index f92b43608ed935..425a892527daa4 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -3,7 +3,7 @@ from copy import deepcopy from datetime import UTC, datetime from mimetypes import guess_type -from typing import Any, Optional, Union +from typing import Any, Optional, Union, cast from yarl import URL @@ -46,7 +46,7 @@ def agent_invoke( invoke_from: InvokeFrom, agent_tool_callback: DifyAgentCallbackHandler, trace_manager: Optional[TraceQueueManager] = None, - ) -> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]: + ) -> tuple[str, list[tuple[MessageFile, str]], ToolInvokeMeta]: """ Agent invokes the tool with the 
given arguments. """ @@ -69,6 +69,8 @@ def agent_invoke( raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") # invoke the tool + if tool.identity is None: + raise ValueError("tool identity is not set") try: # hit the callback handler agent_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters) @@ -163,6 +165,8 @@ def _invoke(tool: Tool, tool_parameters: dict, user_id: str) -> tuple[ToolInvoke """ Invoke the tool with the given arguments. """ + if tool.identity is None: + raise ValueError("tool identity is not set") started_at = datetime.now(UTC) meta = ToolInvokeMeta( time_cost=0.0, @@ -171,7 +175,7 @@ def _invoke(tool: Tool, tool_parameters: dict, user_id: str) -> tuple[ToolInvoke "tool_name": tool.identity.name, "tool_provider": tool.identity.provider, "tool_provider_type": tool.tool_provider_type().value, - "tool_parameters": deepcopy(tool.runtime.runtime_parameters), + "tool_parameters": deepcopy(tool.runtime.runtime_parameters) if tool.runtime else {}, "tool_icon": tool.identity.icon, }, ) @@ -194,9 +198,9 @@ def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str result = "" for response in tool_response: if response.type == ToolInvokeMessage.MessageType.TEXT: - result += response.message + result += str(response.message) if response.message is not None else "" elif response.type == ToolInvokeMessage.MessageType.LINK: - result += f"result link: {response.message}. please tell user to check it." + result += f"result link: {response.message!r}. please tell user to check it." elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: result += ( "image has been created and sent to user already, you do not need to create it," @@ -205,7 +209,7 @@ def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str elif response.type == ToolInvokeMessage.MessageType.JSON: result += f"tool response: {json.dumps(response.message, ensure_ascii=False)}." else: - result += f"tool response: {response.message}." + result += f"tool response: {response.message!r}." 
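In _convert_tool_response_to_str above, response.message is not a plain str (it can also carry JSON payloads), so bare concatenation fails mypy's operator check; the patch coerces with str(...) and !r instead. A small sketch of the same idea; the Message alias is an assumption about the union's shape for illustration, not the real ToolInvokeMessage type:

    from typing import Union

    Message = Union[str, dict, None]  # assumed shape for illustration

    def render(message: Message) -> str:
        if isinstance(message, str):
            return message
        # !r embeds a repr, which is defined for every member of the union
        return f"tool response: {message!r}."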
return result @@ -223,7 +227,7 @@ def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> lis mimetype = response.meta.get("mime_type") else: try: - url = URL(response.message) + url = URL(cast(str, response.message)) extension = url.suffix guess_type_result, _ = guess_type(f"a{extension}") if guess_type_result: @@ -237,7 +241,7 @@ def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> lis result.append( ToolInvokeMessageBinary( mimetype=response.meta.get("mime_type", "image/jpeg"), - url=response.message, + url=cast(str, response.message), save_as=response.save_as, ) ) @@ -245,7 +249,7 @@ def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> lis result.append( ToolInvokeMessageBinary( mimetype=response.meta.get("mime_type", "octet/stream"), - url=response.message, + url=cast(str, response.message), save_as=response.save_as, ) ) @@ -257,7 +261,7 @@ def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> lis mimetype=response.meta.get("mime_type", "octet/stream") if response.meta else "octet/stream", - url=response.message, + url=cast(str, response.message), save_as=response.save_as, ) ) diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 2a5a2944ef8471..e53985951b0627 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -84,13 +84,17 @@ def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[ if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): raise ValueError("Unsupported tool type") - provider_ids = [controller.provider_id for controller in tool_providers] + provider_ids = [ + controller.provider_id + for controller in tool_providers + if isinstance(controller, (ApiToolProviderController, WorkflowToolProviderController)) + ] labels: list[ToolLabelBinding] = ( db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all() ) - tool_labels = {label.tool_id: [] for label in labels} + tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels} for label in labels: tool_labels[label.tool_id].append(label.label_name) diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index ac333162b6bb1c..5b2173a4d0ad69 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -4,7 +4,7 @@ from collections.abc import Generator from os import listdir, path from threading import Lock, Thread -from typing import Any, Optional, Union +from typing import Any, Optional, Union, cast from configs import dify_config from core.agent.entities import AgentToolEntity @@ -15,15 +15,18 @@ from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter -from core.tools.errors import ToolProviderNotFoundError +from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError from core.tools.provider.api_tool_provider import ApiToolProviderController from core.tools.provider.builtin._positions import BuiltinToolProviderSort from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.provider.tool_provider import ToolProviderController +from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController from core.tools.tool.api_tool import ApiTool from 
core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.tool import Tool from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager +from core.workflow.nodes.tool.entities import ToolEntity from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -33,9 +36,9 @@ class ToolManager: _builtin_provider_lock = Lock() - _builtin_providers = {} + _builtin_providers: dict[str, BuiltinToolProviderController] = {} _builtin_providers_loaded = False - _builtin_tools_labels = {} + _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {} @classmethod def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController: @@ -55,7 +58,7 @@ def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController: return cls._builtin_providers[provider] @classmethod - def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool: + def get_builtin_tool(cls, provider: str, tool_name: str) -> Union[BuiltinTool, Tool]: """ get the builtin tool @@ -66,13 +69,15 @@ def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool: """ provider_controller = cls.get_builtin_provider(provider) tool = provider_controller.get_tool(tool_name) + if tool is None: + raise ToolNotFoundError(f"tool {tool_name} not found") return tool @classmethod def get_tool( cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: Optional[str] = None - ) -> Union[BuiltinTool, ApiTool]: + ) -> Union[BuiltinTool, ApiTool, Tool]: """ get the tool @@ -103,7 +108,7 @@ def get_tool_runtime( tenant_id: str, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, - ) -> Union[BuiltinTool, ApiTool]: + ) -> Union[BuiltinTool, ApiTool, Tool]: """ get the tool runtime @@ -113,6 +118,7 @@ def get_tool_runtime( :return: the tool """ + controller: Union[BuiltinToolProviderController, ApiToolProviderController, WorkflowToolProviderController] if provider_type == "builtin": builtin_tool = cls.get_builtin_tool(provider_id, tool_name) @@ -129,7 +135,7 @@ def get_tool_runtime( ) # get credentials - builtin_provider: BuiltinToolProvider = ( + builtin_provider: Optional[BuiltinToolProvider] = ( db.session.query(BuiltinToolProvider) .filter( BuiltinToolProvider.tenant_id == tenant_id, @@ -177,7 +183,7 @@ def get_tool_runtime( } ) elif provider_type == "workflow": - workflow_provider = ( + workflow_provider: Optional[WorkflowToolProvider] = ( db.session.query(WorkflowToolProvider) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) .first() @@ -187,8 +193,13 @@ def get_tool_runtime( raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) + controller_tools: Optional[list[Tool]] = controller.get_tools( + user_id="", tenant_id=workflow_provider.tenant_id + ) + if controller_tools is None or len(controller_tools) == 0: + raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( + return controller_tools[0].fork_tool_runtime( runtime={ "tenant_id": tenant_id, "credentials": {}, @@ -215,7 +226,7 @@ def _init_runtime_parameter(cls, 
parameter_rule: ToolParameter, parameters: dict if parameter_rule.type == ToolParameter.ToolParameterType.SELECT: # check if tool_parameter_config in options - options = [x.value for x in parameter_rule.options] + options = [x.value for x in parameter_rule.options or []] if parameter_value is not None and parameter_value not in options: raise ValueError( f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}" @@ -267,6 +278,8 @@ def get_agent_tool_runtime( identity_id=f"AGENT.{app_id}", ) runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) + if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None: + raise ValueError("runtime not found or runtime parameters not found") tool_entity.runtime.runtime_parameters.update(runtime_parameters) return tool_entity @@ -312,6 +325,9 @@ def get_workflow_tool_runtime( if runtime_parameters: runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) + if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None: + raise ValueError("runtime not found or runtime parameters not found") + tool_entity.runtime.runtime_parameters.update(runtime_parameters) return tool_entity @@ -326,6 +342,8 @@ def get_builtin_provider_icon(cls, provider: str) -> tuple[str, str]: """ # get provider provider_controller = cls.get_builtin_provider(provider) + if provider_controller.identity is None: + raise ToolProviderNotFoundError(f"builtin provider {provider} not found") absolute_path = path.join( path.dirname(path.realpath(__file__)), @@ -381,11 +399,15 @@ def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, Non ), parent_type=BuiltinToolProviderController, ) - provider: BuiltinToolProviderController = provider_class() - cls._builtin_providers[provider.identity.name] = provider - for tool in provider.get_tools(): + provider_controller: BuiltinToolProviderController = provider_class() + if provider_controller.identity is None: + continue + cls._builtin_providers[provider_controller.identity.name] = provider_controller + for tool in provider_controller.get_tools() or []: + if tool.identity is None: + continue cls._builtin_tools_labels[tool.identity.name] = tool.identity.label - yield provider + yield provider_controller except Exception as e: logger.exception(f"load builtin provider {provider}") @@ -449,9 +471,11 @@ def user_list_providers( # append builtin providers for provider in builtin_providers: # handle include, exclude + if provider.identity is None: + continue if is_filtered( - include_set=dify_config.POSITION_TOOL_INCLUDES_SET, - exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, + include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET), + exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET), data=provider, name_func=lambda x: x.identity.name, ): @@ -472,7 +496,7 @@ def user_list_providers( db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() ) - api_provider_controllers = [ + api_provider_controllers: list[dict[str, Any]] = [ {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} for provider in db_api_providers ] @@ -495,7 +519,7 @@ def user_list_providers( db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() ) - workflow_provider_controllers = [] + workflow_provider_controllers: list[WorkflowToolProviderController] = [] for provider in workflow_providers: try: 
workflow_provider_controllers.append( @@ -505,7 +529,9 @@ def user_list_providers( # app has been deleted pass - labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers) + labels = ToolLabelManager.get_tools_labels( + [cast(ToolProviderController, controller) for controller in workflow_provider_controllers] + ) for provider_controller in workflow_provider_controllers: user_provider = ToolTransformService.workflow_provider_to_user_provider( @@ -527,7 +553,7 @@ def get_api_provider_controller( :return: the provider controller, the credentials """ - provider: ApiToolProvider = ( + provider: Optional[ApiToolProvider] = ( db.session.query(ApiToolProvider) .filter( ApiToolProvider.id == provider_id, @@ -556,7 +582,7 @@ def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict: get tool provider """ provider_name = provider - provider: ApiToolProvider = ( + provider_tool: Optional[ApiToolProvider] = ( db.session.query(ApiToolProvider) .filter( ApiToolProvider.tenant_id == tenant_id, @@ -565,17 +591,18 @@ def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict: .first() ) - if provider is None: + if provider_tool is None: raise ValueError(f"you have not added provider {provider_name}") try: - credentials = json.loads(provider.credentials_str) or {} + credentials = json.loads(provider_tool.credentials_str) or {} except: credentials = {} # package tool provider controller controller = ApiToolProviderController.from_db( - provider, ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE + provider_tool, + ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, ) # init tool configuration tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) @@ -584,25 +611,28 @@ def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict: masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) try: - icon = json.loads(provider.icon) + icon = json.loads(provider_tool.icon) except: icon = {"background": "#252525", "content": "\ud83d\ude01"} # add tool labels labels = ToolLabelManager.get_tool_labels(controller) - return jsonable_encoder( - { - "schema_type": provider.schema_type, - "schema": provider.schema, - "tools": provider.tools, - "icon": icon, - "description": provider.description, - "credentials": masked_credentials, - "privacy_policy": provider.privacy_policy, - "custom_disclaimer": provider.custom_disclaimer, - "labels": labels, - } + return cast( + dict, + jsonable_encoder( + { + "schema_type": provider_tool.schema_type, + "schema": provider_tool.schema, + "tools": provider_tool.tools, + "icon": icon, + "description": provider_tool.description, + "credentials": masked_credentials, + "privacy_policy": provider_tool.privacy_policy, + "custom_disclaimer": provider_tool.custom_disclaimer, + "labels": labels, + } + ), ) @classmethod @@ -617,6 +647,7 @@ def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> """ provider_type = provider_type provider_id = provider_id + provider: Optional[Union[BuiltinToolProvider, ApiToolProvider, WorkflowToolProvider]] = None if provider_type == "builtin": return ( dify_config.CONSOLE_API_URL @@ -626,16 +657,21 @@ def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> ) elif provider_type == "api": try: - provider: ApiToolProvider = ( + provider = ( db.session.query(ApiToolProvider) .filter(ApiToolProvider.tenant_id == 
tenant_id, ApiToolProvider.id == provider_id) .first() ) - return json.loads(provider.icon) + if provider is None: + raise ToolProviderNotFoundError(f"api provider {provider_id} not found") + icon = json.loads(provider.icon) + if isinstance(icon, (str, dict)): + return icon + return {"background": "#252525", "content": "\ud83d\ude01"} except: return {"background": "#252525", "content": "\ud83d\ude01"} elif provider_type == "workflow": - provider: WorkflowToolProvider = ( + provider = ( db.session.query(WorkflowToolProvider) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) .first() @@ -643,7 +679,13 @@ def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> if provider is None: raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - return json.loads(provider.icon) + try: + icon = json.loads(provider.icon) + if isinstance(icon, (str, dict)): + return icon + return {"background": "#252525", "content": "\ud83d\ude01"} + except: + return {"background": "#252525", "content": "\ud83d\ude01"} else: raise ValueError(f"provider type {provider_type} not found") diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 8b5e27f5382ee7..d7720928644701 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -72,9 +72,13 @@ def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str return a deep copy of credentials with decrypted values """ + identity_id = "" + if self.provider_controller.identity: + identity_id = f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}" + cache = ToolProviderCredentialsCache( tenant_id=self.tenant_id, - identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}", + identity_id=identity_id, cache_type=ToolProviderCredentialsCacheType.PROVIDER, ) cached_credentials = cache.get() @@ -95,9 +99,13 @@ def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str return credentials def delete_tool_credentials_cache(self): + identity_id = "" + if self.provider_controller.identity: + identity_id = f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}" + cache = ToolProviderCredentialsCache( tenant_id=self.tenant_id, - identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}", + identity_id=identity_id, cache_type=ToolProviderCredentialsCacheType.PROVIDER, ) cache.delete() @@ -199,6 +207,9 @@ def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: return a deep copy of parameters with decrypted values """ + if self.tool_runtime is None or self.tool_runtime.identity is None: + raise ValueError("tool_runtime is required") + cache = ToolParameterCache( tenant_id=self.tenant_id, provider=f"{self.provider_type}.{self.provider_name}", @@ -232,6 +243,9 @@ def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: return parameters def delete_tool_parameters_cache(self): + if self.tool_runtime is None or self.tool_runtime.identity is None: + raise ValueError("tool_runtime is required") + cache = ToolParameterCache( tenant_id=self.tenant_id, provider=f"{self.provider_type}.{self.provider_name}", diff --git a/api/core/tools/utils/feishu_api_utils.py b/api/core/tools/utils/feishu_api_utils.py index ea28037df03720..ecf60045aa8dc5 100644 --- 
a/api/core/tools/utils/feishu_api_utils.py +++ b/api/core/tools/utils/feishu_api_utils.py @@ -1,5 +1,5 @@ import json -from typing import Optional +from typing import Any, Optional, cast import httpx @@ -101,7 +101,7 @@ def get_tenant_access_token(self, app_id: str, app_secret: str) -> dict: """ url = f"{self.API_BASE_URL}/access_token/get_tenant_access_token" payload = {"app_id": app_id, "app_secret": app_secret} - res = self._send_request(url, require_token=False, payload=payload) + res: dict = self._send_request(url, require_token=False, payload=payload) return res def create_document(self, title: str, content: str, folder_token: str) -> dict: @@ -126,15 +126,16 @@ def create_document(self, title: str, content: str, folder_token: str) -> dict: "content": content, "folder_token": folder_token, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def write_document(self, document_id: str, content: str, position: str = "end") -> dict: url = f"{self.API_BASE_URL}/document/write_document" payload = {"document_id": document_id, "content": content, "position": position} - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) return res def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str: @@ -155,9 +156,9 @@ def get_document_content(self, document_id: str, mode: str = "markdown", lang: s "lang": lang, } url = f"{self.API_BASE_URL}/document/get_document_content" - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data").get("content") + return cast(str, res.get("data", {}).get("content")) return "" def list_document_blocks( @@ -173,9 +174,10 @@ def list_document_blocks( "page_token": page_token, } url = f"{self.API_BASE_URL}/document/list_document_blocks" - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict: @@ -191,9 +193,10 @@ def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, "msg_type": msg_type, "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"), } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict: @@ -203,7 +206,7 @@ def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dic "msg_type": msg_type, "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"), } - res = self._send_request(url, require_token=False, payload=payload) + res: dict = self._send_request(url, require_token=False, payload=payload) return res def get_chat_messages( @@ -227,9 +230,10 @@ def get_chat_messages( "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = 
res.get("data", {}) + return data return res def get_thread_messages( @@ -245,9 +249,10 @@ def get_thread_messages( "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict: @@ -260,9 +265,10 @@ def create_task(self, summary: str, start_time: str, end_time: str, completed_ti "completed_at": completed_time, "description": description, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def update_task( @@ -278,9 +284,10 @@ def update_task( "completed_time": completed_time, "description": description, } - res = self._send_request(url, method="PATCH", payload=payload) + res: dict = self._send_request(url, method="PATCH", payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def delete_task(self, task_guid: str) -> dict: @@ -289,7 +296,7 @@ def delete_task(self, task_guid: str) -> dict: payload = { "task_guid": task_guid, } - res = self._send_request(url, method="DELETE", payload=payload) + res: dict = self._send_request(url, method="DELETE", payload=payload) return res def add_members(self, task_guid: str, member_phone_or_email: str, member_role: str) -> dict: @@ -300,7 +307,7 @@ def add_members(self, task_guid: str, member_phone_or_email: str, member_role: s "member_phone_or_email": member_phone_or_email, "member_role": member_role, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) return res def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, page_size: int = 20) -> dict: @@ -312,9 +319,10 @@ def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_primary_calendar(self, user_id_type: str = "open_id") -> dict: @@ -322,9 +330,10 @@ def get_primary_calendar(self, user_id_type: str = "open_id") -> dict: params = { "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_event( @@ -347,9 +356,10 @@ def create_event( "auto_record": auto_record, "attendee_ability": attendee_ability, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def update_event( @@ -363,7 +373,7 @@ def update_event( auto_record: bool, ) -> dict: url = f"{self.API_BASE_URL}/calendar/update_event/{event_id}" - payload = {} + payload: dict[str, Any] = {} if summary: payload["summary"] = summary if description: @@ -376,7 +386,7 @@ def update_event( payload["need_notification"] = need_notification if auto_record: payload["auto_record"] = 
auto_record - res = self._send_request(url, method="PATCH", payload=payload) + res: dict = self._send_request(url, method="PATCH", payload=payload) return res def delete_event(self, event_id: str, need_notification: bool = True) -> dict: @@ -384,7 +394,7 @@ def delete_event(self, event_id: str, need_notification: bool = True) -> dict: params = { "need_notification": need_notification, } - res = self._send_request(url, method="DELETE", params=params) + res: dict = self._send_request(url, method="DELETE", params=params) return res def list_events(self, start_time: str, end_time: str, page_token: str, page_size: int = 50) -> dict: @@ -395,9 +405,10 @@ def list_events(self, start_time: str, end_time: str, page_token: str, page_size "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def search_events( @@ -418,9 +429,10 @@ def search_events( "user_id_type": user_id_type, "page_size": page_size, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict: @@ -431,9 +443,10 @@ def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_ "attendee_phone_or_email": attendee_phone_or_email, "need_notification": need_notification, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_spreadsheet( @@ -447,9 +460,10 @@ def create_spreadsheet( "title": title, "folder_token": folder_token, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_spreadsheet( @@ -463,9 +477,10 @@ def get_spreadsheet( "spreadsheet_token": spreadsheet_token, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def list_spreadsheet_sheets( @@ -477,9 +492,10 @@ def list_spreadsheet_sheets( params = { "spreadsheet_token": spreadsheet_token, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_rows( @@ -499,9 +515,10 @@ def add_rows( "length": length, "values": values, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_cols( @@ -521,9 +538,10 @@ def add_cols( "length": length, "values": values, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_rows( @@ -545,9 +563,10 @@ def read_rows( 
"num_rows": num_rows, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_cols( @@ -569,9 +588,10 @@ def read_cols( "num_cols": num_cols, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_table( @@ -593,9 +613,10 @@ def read_table( "query": query, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_base( @@ -609,9 +630,10 @@ def create_base( "name": name, "folder_token": folder_token, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_records( @@ -633,9 +655,10 @@ def add_records( payload = { "records": convert_add_records(records), } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def update_records( @@ -657,9 +680,10 @@ def update_records( payload = { "records": convert_update_records(records), } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def delete_records( @@ -686,9 +710,10 @@ def delete_records( payload = { "records": record_id_list, } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def search_record( @@ -740,7 +765,7 @@ def search_record( except json.JSONDecodeError: raise ValueError("The input string is not valid JSON") - payload = {} + payload: dict[str, Any] = {} if view_id: payload["view_id"] = view_id @@ -752,10 +777,11 @@ def search_record( payload["filter"] = filter_dict if automatic_fields: payload["automatic_fields"] = automatic_fields - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_base_info( @@ -767,9 +793,10 @@ def get_base_info( params = { "app_token": app_token, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_table( @@ -797,9 +824,10 @@ def create_table( } if default_view_name: payload["default_view_name"] = default_view_name - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: 
dict = res.get("data", {}) + return data return res def delete_tables( @@ -834,9 +862,10 @@ def delete_tables( "table_names": table_name_list, } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def list_tables( @@ -852,9 +881,10 @@ def list_tables( "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_records( @@ -882,7 +912,8 @@ def read_records( "record_ids": record_id_list, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params, payload=payload) + res: dict = self._send_request(url, method="GET", params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res diff --git a/api/core/tools/utils/lark_api_utils.py b/api/core/tools/utils/lark_api_utils.py index 30cb0cb141d9a6..de394a39bf5a00 100644 --- a/api/core/tools/utils/lark_api_utils.py +++ b/api/core/tools/utils/lark_api_utils.py @@ -1,5 +1,5 @@ import json -from typing import Optional +from typing import Any, Optional, cast import httpx @@ -62,12 +62,10 @@ def convert_update_records(self, json_str): def tenant_access_token(self) -> str: feishu_tenant_access_token = f"tools:{self.app_id}:feishu_tenant_access_token" if redis_client.exists(feishu_tenant_access_token): - return redis_client.get(feishu_tenant_access_token).decode() - res = self.get_tenant_access_token(self.app_id, self.app_secret) + return str(redis_client.get(feishu_tenant_access_token).decode()) + res: dict[str, str] = self.get_tenant_access_token(self.app_id, self.app_secret) redis_client.setex(feishu_tenant_access_token, res.get("expire"), res.get("tenant_access_token")) - if "tenant_access_token" in res: - return res.get("tenant_access_token") - return "" + return res.get("tenant_access_token", "") def _send_request( self, @@ -91,7 +89,7 @@ def _send_request( def get_tenant_access_token(self, app_id: str, app_secret: str) -> dict: url = f"{self.API_BASE_URL}/access_token/get_tenant_access_token" payload = {"app_id": app_id, "app_secret": app_secret} - res = self._send_request(url, require_token=False, payload=payload) + res: dict = self._send_request(url, require_token=False, payload=payload) return res def create_document(self, title: str, content: str, folder_token: str) -> dict: @@ -101,15 +99,16 @@ def create_document(self, title: str, content: str, folder_token: str) -> dict: "content": content, "folder_token": folder_token, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def write_document(self, document_id: str, content: str, position: str = "end") -> dict: url = f"{self.API_BASE_URL}/document/write_document" payload = {"document_id": document_id, "content": content, "position": position} - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) return res def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str | dict: @@ -119,9 +118,9 @@ def get_document_content(self, document_id: str, mode: 
str = "markdown", lang: s "lang": lang, } url = f"{self.API_BASE_URL}/document/get_document_content" - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data").get("content") + return cast(dict, res.get("data", {}).get("content")) return "" def list_document_blocks( @@ -134,9 +133,10 @@ def list_document_blocks( "page_token": page_token, } url = f"{self.API_BASE_URL}/document/list_document_blocks" - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict: @@ -149,9 +149,10 @@ def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, "msg_type": msg_type, "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"), } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict: @@ -161,7 +162,7 @@ def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dic "msg_type": msg_type, "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"), } - res = self._send_request(url, require_token=False, payload=payload) + res: dict = self._send_request(url, require_token=False, payload=payload) return res def get_chat_messages( @@ -182,9 +183,10 @@ def get_chat_messages( "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_thread_messages( @@ -197,9 +199,10 @@ def get_thread_messages( "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict: @@ -211,9 +214,10 @@ def create_task(self, summary: str, start_time: str, end_time: str, completed_ti "completed_at": completed_time, "description": description, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def update_task( @@ -228,9 +232,10 @@ def update_task( "completed_time": completed_time, "description": description, } - res = self._send_request(url, method="PATCH", payload=payload) + res: dict = self._send_request(url, method="PATCH", payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def delete_task(self, task_guid: str) -> dict: @@ -238,9 +243,10 @@ def delete_task(self, task_guid: str) -> dict: payload = { "task_guid": task_guid, } - res = self._send_request(url, method="DELETE", payload=payload) + res: dict = self._send_request(url, method="DELETE", 
payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_members(self, task_guid: str, member_phone_or_email: str, member_role: str) -> dict: @@ -250,9 +256,10 @@ def add_members(self, task_guid: str, member_phone_or_email: str, member_role: s "member_phone_or_email": member_phone_or_email, "member_role": member_role, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, page_size: int = 20) -> dict: @@ -263,9 +270,10 @@ def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_primary_calendar(self, user_id_type: str = "open_id") -> dict: @@ -273,9 +281,10 @@ def get_primary_calendar(self, user_id_type: str = "open_id") -> dict: params = { "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_event( @@ -298,9 +307,10 @@ def create_event( "auto_record": auto_record, "attendee_ability": attendee_ability, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def update_event( @@ -314,7 +324,7 @@ def update_event( auto_record: bool, ) -> dict: url = f"{self.API_BASE_URL}/calendar/update_event/{event_id}" - payload = {} + payload: dict[str, Any] = {} if summary: payload["summary"] = summary if description: @@ -327,7 +337,7 @@ def update_event( payload["need_notification"] = need_notification if auto_record: payload["auto_record"] = auto_record - res = self._send_request(url, method="PATCH", payload=payload) + res: dict = self._send_request(url, method="PATCH", payload=payload) return res def delete_event(self, event_id: str, need_notification: bool = True) -> dict: @@ -335,7 +345,7 @@ def delete_event(self, event_id: str, need_notification: bool = True) -> dict: params = { "need_notification": need_notification, } - res = self._send_request(url, method="DELETE", params=params) + res: dict = self._send_request(url, method="DELETE", params=params) return res def list_events(self, start_time: str, end_time: str, page_token: str, page_size: int = 50) -> dict: @@ -346,9 +356,10 @@ def list_events(self, start_time: str, end_time: str, page_token: str, page_size "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def search_events( @@ -369,9 +380,10 @@ def search_events( "user_id_type": user_id_type, "page_size": page_size, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data 
return res def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict: @@ -381,9 +393,10 @@ def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_ "attendee_phone_or_email": attendee_phone_or_email, "need_notification": need_notification, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_spreadsheet( @@ -396,9 +409,10 @@ def create_spreadsheet( "title": title, "folder_token": folder_token, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_spreadsheet( @@ -411,9 +425,10 @@ def get_spreadsheet( "spreadsheet_token": spreadsheet_token, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def list_spreadsheet_sheets( @@ -424,9 +439,10 @@ def list_spreadsheet_sheets( params = { "spreadsheet_token": spreadsheet_token, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_rows( @@ -445,9 +461,10 @@ def add_rows( "length": length, "values": values, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_cols( @@ -466,9 +483,10 @@ def add_cols( "length": length, "values": values, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_rows( @@ -489,9 +507,10 @@ def read_rows( "num_rows": num_rows, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_cols( @@ -512,9 +531,10 @@ def read_cols( "num_cols": num_cols, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_table( @@ -535,9 +555,10 @@ def read_table( "query": query, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_base( @@ -550,9 +571,10 @@ def create_base( "name": name, "folder_token": folder_token, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_records( @@ -573,9 +595,10 @@ def add_records( payload = { 
"records": self.convert_add_records(records), } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def update_records( @@ -596,9 +619,10 @@ def update_records( payload = { "records": self.convert_update_records(records), } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def delete_records( @@ -624,9 +648,10 @@ def delete_records( payload = { "records": record_id_list, } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def search_record( @@ -678,7 +703,7 @@ def search_record( except json.JSONDecodeError: raise ValueError("The input string is not valid JSON") - payload = {} + payload: dict[str, Any] = {} if view_id: payload["view_id"] = view_id @@ -690,9 +715,10 @@ def search_record( payload["filter"] = filter_dict if automatic_fields: payload["automatic_fields"] = automatic_fields - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_base_info( @@ -703,9 +729,10 @@ def get_base_info( params = { "app_token": app_token, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_table( @@ -732,9 +759,10 @@ def create_table( } if default_view_name: payload["default_view_name"] = default_view_name - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def delete_tables( @@ -767,9 +795,10 @@ def delete_tables( "table_ids": table_id_list, "table_names": table_name_list, } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def list_tables( @@ -784,9 +813,10 @@ def list_tables( "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_records( @@ -814,7 +844,8 @@ def read_records( "record_ids": record_id_list, "user_id_type": user_id_type, } - res = self._send_request(url, method="POST", params=params, payload=payload) + res: dict = self._send_request(url, method="POST", params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index e30c903a4b1146..3509f1e6e59f77 100644 --- 
a/api/core/tools/utils/message_transformer.py
+++ b/api/core/tools/utils/message_transformer.py
@@ -90,12 +90,12 @@ def transform_tool_invoke_messages(
                 )
             elif message.type == ToolInvokeMessage.MessageType.FILE:
                 assert message.meta is not None
-                file = message.meta.get("file")
-                if isinstance(file, File):
-                    if file.transfer_method == FileTransferMethod.TOOL_FILE:
-                        assert file.related_id is not None
-                        url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension)
-                        if file.type == FileType.IMAGE:
+                file_meta = message.meta.get("file")
+                if isinstance(file_meta, File):
+                    if file_meta.transfer_method == FileTransferMethod.TOOL_FILE:
+                        assert file_meta.related_id is not None
+                        url = cls.get_tool_file_url(tool_file_id=file_meta.related_id, extension=file_meta.extension)
+                        if file_meta.type == FileType.IMAGE:
                             result.append(
                                 ToolInvokeMessage(
                                     type=ToolInvokeMessage.MessageType.IMAGE_LINK,
diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py
index 4e226810d6ac90..3689dcc9e5ebfd 100644
--- a/api/core/tools/utils/model_invocation_utils.py
+++ b/api/core/tools/utils/model_invocation_utils.py
@@ -5,7 +5,7 @@
 """
 
 import json
-from typing import cast
+from typing import Optional, cast
 
 from core.model_manager import ModelManager
 from core.model_runtime.entities.llm_entities import LLMResult
@@ -51,7 +51,7 @@ def get_max_llm_context_tokens(
         if not schema:
             raise InvokeModelError("No model schema found")
 
-        max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
+        max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
         if max_tokens is None:
             return 2048
 
@@ -133,14 +133,17 @@ def invoke(
         db.session.commit()
 
         try:
-            response: LLMResult = model_instance.invoke_llm(
-                prompt_messages=prompt_messages,
-                model_parameters=model_parameters,
-                tools=[],
-                stop=[],
-                stream=False,
-                user=user_id,
-                callbacks=[],
+            response: LLMResult = cast(
+                LLMResult,
+                model_instance.invoke_llm(
+                    prompt_messages=prompt_messages,
+                    model_parameters=model_parameters,
+                    tools=[],
+                    stop=[],
+                    stream=False,
+                    user=user_id,
+                    callbacks=[],
+                ),
             )
         except InvokeRateLimitError as e:
             raise InvokeModelError(f"Invoke rate limit error: {e}")
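The `cast()` wrapper above is the typing-only escape hatch for call sites where a function's declared return type is a union wider than what this particular call can produce: `invoke_llm` can also return a streaming generator, so with `stream=False` the result is narrowed by hand. A minimal sketch of the pattern (the names below are hypothetical stand-ins, not the real Dify API):

    from collections.abc import Generator
    from typing import Union, cast


    def invoke(stream: bool) -> Union[str, Generator[str, None, None]]:
        # Stand-in for a model call: yields chunks when streaming,
        # returns the full text otherwise.
        if stream:
            return (chunk for chunk in ("a", "b"))
        return "ab"


    # The caller knows stream=False selects the non-generator branch, but mypy
    # only sees the declared union, so the result is narrowed with cast().
    result: str = cast(str, invoke(stream=False))

Note that `cast()` is purely a static assertion with no runtime check, so it should only encode invariants the surrounding code already guarantees.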
diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py
index ae44b1b99d447a..f1dc1123b9935f 100644
--- a/api/core/tools/utils/parser.py
+++ b/api/core/tools/utils/parser.py
@@ -6,7 +6,7 @@
 from typing import Optional
 
 from requests import get
-from yaml import YAMLError, safe_load
+from yaml import YAMLError, safe_load  # type: ignore
 
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_bundle import ApiToolBundle
@@ -64,6 +64,9 @@ def parse_openapi_to_tool_bundle(
                         default=parameter["schema"]["default"]
                         if "schema" in parameter and "default" in parameter["schema"]
                         else None,
+                        placeholder=I18nObject(
+                            en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
+                        ),
                     )
 
                     # check if there is a type
@@ -108,6 +111,9 @@ def parse_openapi_to_tool_bundle(
                             form=ToolParameter.ToolParameterForm.LLM,
                             llm_description=property.get("description", ""),
                             default=property.get("default", None),
+                            placeholder=I18nObject(
+                                en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
+                            ),
                         )
 
                         # check if there is a type
@@ -158,9 +164,9 @@ def parse_openapi_to_tool_bundle(
         return bundles
 
     @staticmethod
-    def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType:
+    def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]:
         parameter = parameter or {}
-        typ = None
+        typ: Optional[str] = None
 
         if parameter.get("format") == "binary":
             return ToolParameter.ToolParameterType.FILE
@@ -175,6 +181,8 @@ def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType
             return ToolParameter.ToolParameterType.BOOLEAN
         elif typ == "string":
             return ToolParameter.ToolParameterType.STRING
+        else:
+            return None
 
     @staticmethod
     def parse_openapi_yaml_to_tool_bundle(
@@ -236,7 +244,8 @@ def parse_swagger_to_openapi(swagger: dict, extra_info: Optional[dict], warning:
             if ("summary" not in operation or len(operation["summary"]) == 0) and (
                 "description" not in operation or len(operation["description"]) == 0
             ):
-                warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
+                if warning is not None:
+                    warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
 
             openapi["paths"][path][method] = {
                 "operationId": operation["operationId"],
diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py
index 3aae31e93a1304..d42fd99fce5e80 100644
--- a/api/core/tools/utils/web_reader_tool.py
+++ b/api/core/tools/utils/web_reader_tool.py
@@ -9,13 +9,13 @@
 import unicodedata
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Optional
+from typing import Any, Literal, Optional, cast
 from urllib.parse import unquote
 
 import chardet
-import cloudscraper
-from bs4 import BeautifulSoup, CData, Comment, NavigableString
-from regex import regex
+import cloudscraper  # type: ignore
+from bs4 import BeautifulSoup, CData, Comment, NavigableString  # type: ignore
+from regex import regex  # type: ignore
 
 from core.helper import ssrf_proxy
 from core.rag.extractor import extract_processor
@@ -68,7 +68,7 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str:
             return "Unsupported content-type [{}] of URL.".format(main_content_type)
 
         if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
-            return ExtractProcessor.load_from_url(url, return_text=True)
+            return cast(str, ExtractProcessor.load_from_url(url, return_text=True))
 
         response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
     elif response.status_code == 403:
@@ -125,7 +125,7 @@ def extract_using_readabilipy(html):
     os.unlink(article_json_path)
     os.unlink(html_path)
 
-    article_json = {
+    article_json: dict[str, Any] = {
         "title": None,
         "byline": None,
         "date": None,
@@ -300,7 +300,7 @@ def strip_control_characters(text):
 
 
 def normalize_unicode(text):
     """Normalize unicode such that things that are visually equivalent map to the same unicode string where possible."""
-    normal_form = "NFKC"
+    normal_form: Literal["NFC", "NFD", "NFKC", "NFKD"] = "NFKC"
     text = unicodedata.normalize(normal_form, text)
     return text
@@ -332,6 +332,7 @@ def add_content_digest(element):
 
 
 def content_digest(element):
+    digest: Any
     if is_text(element):
         # Hash
         trimmed_string = element.string.strip()
diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py
index d92bfb9b90a9aa..08a112cfdb2b91 100644
--- a/api/core/tools/utils/workflow_configuration_sync.py
+++ b/api/core/tools/utils/workflow_configuration_sync.py
@@ -7,7 +7,7 @@
 class WorkflowToolConfigurationUtils:
     @classmethod
-    def check_parameter_configurations(cls, configurations: Mapping[str, Any]):
+    def check_parameter_configurations(cls,
configurations: list[Mapping[str, Any]]): for configuration in configurations: WorkflowToolParameterConfiguration.model_validate(configuration) @@ -27,7 +27,7 @@ def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[Vari @classmethod def check_is_synced( cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration] - ) -> None: + ) -> bool: """ check is synced diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py index 42c7f85bc6daeb..ee7ca11e056625 100644 --- a/api/core/tools/utils/yaml_utils.py +++ b/api/core/tools/utils/yaml_utils.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Any -import yaml +import yaml # type: ignore from yaml import YAMLError logger = logging.getLogger(__name__) diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index 973e420961bb46..c32815b24d02ed 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -1,4 +1,5 @@ from collections.abc import Sequence +from typing import cast from uuid import uuid4 from pydantic import Field @@ -78,7 +79,7 @@ class SecretVariable(StringVariable): @property def log(self) -> str: - return encrypter.obfuscated_token(self.value) + return cast(str, encrypter.obfuscated_token(self.value)) class NoneVariable(NoneSegment, Variable): diff --git a/api/core/workflow/callbacks/workflow_logging_callback.py b/api/core/workflow/callbacks/workflow_logging_callback.py index ed737e7316973c..b9c6b35ad3476a 100644 --- a/api/core/workflow/callbacks/workflow_logging_callback.py +++ b/api/core/workflow/callbacks/workflow_logging_callback.py @@ -33,7 +33,7 @@ class WorkflowLoggingCallback(WorkflowCallback): def __init__(self) -> None: - self.current_node_id = None + self.current_node_id: Optional[str] = None def on_event(self, event: GraphEngineEvent) -> None: if isinstance(event, GraphRunStartedEvent): diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index ca01dcd7d8d4a8..ae5f117bf9b121 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -36,9 +36,9 @@ class NodeRunResult(BaseModel): status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING inputs: Optional[Mapping[str, Any]] = None # node inputs - process_data: Optional[dict[str, Any]] = None # process data + process_data: Optional[Mapping[str, Any]] = None # process data outputs: Optional[Mapping[str, Any]] = None # node outputs - metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata + metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # node metadata llm_usage: Optional[LLMUsage] = None # llm usage edge_source_handle: Optional[str] = None # source handle id of node with multiple branches diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py index bc3a15bd004ace..b8470aecbd83a2 100644 --- a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py +++ b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py @@ -5,7 +5,7 @@ class ConditionRunConditionHandlerHandler(RunConditionHandler): - def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool: + def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState): """ Check if the condition can be executed diff --git 
a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 800dd136afb57f..b3bcc3b2ccc309 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -1,4 +1,5 @@ import uuid +from collections import defaultdict from collections.abc import Mapping from typing import Any, Optional, cast @@ -310,26 +311,17 @@ def _recursively_add_parallels( parallel = None if len(target_node_edges) > 1: # fetch all node ids in current parallels - parallel_branch_node_ids = {} - condition_edge_mappings = {} + parallel_branch_node_ids = defaultdict(list) + condition_edge_mappings = defaultdict(list) for graph_edge in target_node_edges: if graph_edge.run_condition is None: - if "default" not in parallel_branch_node_ids: - parallel_branch_node_ids["default"] = [] - parallel_branch_node_ids["default"].append(graph_edge.target_node_id) else: condition_hash = graph_edge.run_condition.hash - if condition_hash not in condition_edge_mappings: - condition_edge_mappings[condition_hash] = [] - condition_edge_mappings[condition_hash].append(graph_edge) for condition_hash, graph_edges in condition_edge_mappings.items(): if len(graph_edges) > 1: - if condition_hash not in parallel_branch_node_ids: - parallel_branch_node_ids[condition_hash] = [] - for graph_edge in graph_edges: parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id) @@ -418,7 +410,7 @@ def _recursively_add_parallels( if condition_edge_mappings: for condition_hash, graph_edges in condition_edge_mappings.items(): for graph_edge in graph_edges: - current_parallel: GraphParallel | None = cls._get_current_parallel( + current_parallel = cls._get_current_parallel( parallel_mapping=parallel_mapping, graph_edge=graph_edge, parallel=condition_parallels.get(condition_hash), diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 854036b2c13212..db1e01f14fda59 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -40,6 +40,7 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.nodes import NodeType from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor +from core.workflow.nodes.answer.base_stream_processor import StreamProcessor from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base.entities import BaseNodeData from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor @@ -66,7 +67,7 @@ def __init__( self.max_submit_count = max_submit_count self.submit_count = 0 - def submit(self, fn, *args, **kwargs): + def submit(self, fn, /, *args, **kwargs): self.submit_count += 1 self.check_is_full() @@ -140,7 +141,8 @@ def __init__( def run(self) -> Generator[GraphEngineEvent, None, None]: # trigger graph run start event yield GraphRunStartedEvent() - handle_exceptions = [] + handle_exceptions: list[str] = [] + stream_processor: StreamProcessor try: if self.init_params.workflow_type == WorkflowType.CHAT: @@ -168,7 +170,7 @@ def run(self) -> Generator[GraphEngineEvent, None, None]: elif isinstance(item, NodeRunSucceededEvent): if item.node_type == NodeType.END: self.graph_runtime_state.outputs = ( - item.route_node_state.node_run_result.outputs + dict(item.route_node_state.node_run_result.outputs) if item.route_node_state.node_run_result and item.route_node_state.node_run_result.outputs else {} @@ -350,7 +352,7 @@ def 
_run( if any(edge.run_condition for edge in edge_mappings): # if nodes has run conditions, get node id which branch to take based on the run condition results - condition_edge_mappings = {} + condition_edge_mappings: dict[str, list[GraphEdge]] = {} for edge in edge_mappings: if edge.run_condition: run_condition_hash = edge.run_condition.hash @@ -364,6 +366,9 @@ def _run( continue edge = cast(GraphEdge, sub_edge_mappings[0]) + if edge.run_condition is None: + logger.warning(f"Edge {edge.target_node_id} run condition is None") + continue result = ConditionManager.get_condition_handler( init_params=self.init_params, @@ -387,11 +392,11 @@ def _run( handle_exceptions=handle_exceptions, ) - for item in parallel_generator: - if isinstance(item, str): - final_node_id = item + for parallel_result in parallel_generator: + if isinstance(parallel_result, str): + final_node_id = parallel_result else: - yield item + yield parallel_result break @@ -413,11 +418,11 @@ def _run( handle_exceptions=handle_exceptions, ) - for item in parallel_generator: - if isinstance(item, str): - final_node_id = item + for generated_item in parallel_generator: + if isinstance(generated_item, str): + final_node_id = generated_item else: - yield item + yield generated_item if not final_node_id: break @@ -653,7 +658,7 @@ def _run_node( parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - error=run_result.error, + error=run_result.error or "Unknown error", retry_index=retries, start_at=retry_start_at, ) @@ -732,20 +737,20 @@ def _run_node( variable_value=variable_value, ) - # add parallel info to run result metadata - if parallel_id and parallel_start_node_id: - if not run_result.metadata: - run_result.metadata = {} + # When setting metadata, convert to dict first + if not run_result.metadata: + run_result.metadata = {} - run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id - run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = ( - parallel_start_node_id - ) + if parallel_id and parallel_start_node_id: + metadata_dict = dict(run_result.metadata) + metadata_dict[NodeRunMetadataKey.PARALLEL_ID] = parallel_id + metadata_dict[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id if parent_parallel_id and parent_parallel_start_node_id: - run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id - run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = ( + metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id + metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = ( parent_parallel_start_node_id ) + run_result.metadata = metadata_dict yield NodeRunSucceededEvent( id=node_instance.id, @@ -869,8 +874,8 @@ def _handle_continue_on_error( variable_pool.add([node_instance.node_id, "error_message"], error_result.error) variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type) # add error message to handle_exceptions - handle_exceptions.append(error_result.error) - node_error_args = { + handle_exceptions.append(error_result.error or "") + node_error_args: dict[str, Any] = { "status": WorkflowNodeExecutionStatus.EXCEPTION, "error": error_result.error, "inputs": error_result.inputs, diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index ed033e7f283961..40213bd151f7af 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py 
+++ b/api/core/workflow/nodes/answer/answer_stream_processor.py
@@ -63,7 +63,7 @@ def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generat
                     self._remove_unreachable_nodes(event)
 
                 # generate stream outputs
-                yield from self._generate_stream_outputs_when_node_finished(event)
+                yield from self._generate_stream_outputs_when_node_finished(cast(NodeRunSucceededEvent, event))
             else:
                 yield event
 
@@ -130,7 +130,7 @@ def _generate_stream_outputs_when_node_finished(
                     node_type=event.node_type,
                     node_data=event.node_data,
                     chunk_content=text,
-                    from_variable_selector=value_selector,
+                    from_variable_selector=list(value_selector),
                     route_node_state=event.route_node_state,
                     parallel_id=event.parallel_id,
                     parallel_start_node_id=event.parallel_start_node_id,
diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py
index d785397e130565..8ffb487ec108f8 100644
--- a/api/core/workflow/nodes/answer/base_stream_processor.py
+++ b/api/core/workflow/nodes/answer/base_stream_processor.py
@@ -3,7 +3,7 @@
 from collections.abc import Generator
 
 from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent
+from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent
 from core.workflow.graph_engine.entities.graph import Graph
 
 logger = logging.getLogger(__name__)
@@ -19,7 +19,7 @@ def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
     def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]:
         raise NotImplementedError
 
-    def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
+    def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None:
         finished_node_id = event.route_node_state.node_id
         if finished_node_id not in self.rest_node_ids:
             return
@@ -32,8 +32,8 @@ def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
             return
 
         if run_result.edge_source_handle:
-            reachable_node_ids = []
-            unreachable_first_node_ids = []
+            reachable_node_ids: list[str] = []
+            unreachable_first_node_ids: list[str] = []
             if finished_node_id not in self.graph.edge_mapping:
                 logger.warning(f"node {finished_node_id} has no edge mapping")
                 return
diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py
index 529fd7be74e9a1..6bf8899f5d698b 100644
--- a/api/core/workflow/nodes/base/entities.py
+++ b/api/core/workflow/nodes/base/entities.py
@@ -38,7 +38,8 @@ def _parse_json(value: str) -> Any:
     @staticmethod
     def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
         """Unified array type validation"""
-        return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
+        # FIXME: type ignore here because it is unclear why mypy complains; if the root cause is found, please fix it
+        return isinstance(value, list) and all(isinstance(x, element_type) for x in value)  # type: ignore
 
     @staticmethod
     def _convert_number(value: str) -> float:
@@ -84,7 +85,7 @@ def validate_value_type(self) -> "DefaultValue":
             },
         }
 
-        validator = type_validators.get(self.type)
+        validator: dict[str, Any] = type_validators.get(self.type, {})
         if not validator:
             if self.type == DefaultValueType.ARRAY_FILES:
                 # Handle files type
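The surviving `# type: ignore` in `_validate_array` is typical of `isinstance()` checks whose second argument has a static type mypy cannot prove to be a class: the check works at runtime, but the checker only accepts it when that argument is typed as `type` (or a tuple of types). A short illustration of the distinction, with hypothetical names:

    from typing import Any


    def validate_array(value: Any, element_type: type) -> bool:
        # With `element_type: type`, mypy accepts the isinstance() call; under a
        # looser annotation (as in the hunk above) the same line needs an ignore.
        return isinstance(value, list) and all(isinstance(x, element_type) for x in value)


    print(validate_array([1, 2, 3], int))  # True
    print(validate_array([1, "a"], int))   # False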
diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py
index 4e371ca43645a5..2f82bf8c382b55 100644
--- a/api/core/workflow/nodes/code/code_node.py
+++ b/api/core/workflow/nodes/code/code_node.py
@@ -125,7 +125,7 @@ def _transform_result(
         if depth > dify_config.CODE_MAX_DEPTH:
             raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")
 
-        transformed_result = {}
+        transformed_result: dict[str, Any] = {}
         if output_schema is None:
             # validate output thought instance type
             for output_name, output_value in result.items():
diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py
index e78183baf12389..a4540358883210 100644
--- a/api/core/workflow/nodes/code/entities.py
+++ b/api/core/workflow/nodes/code/entities.py
@@ -14,7 +14,7 @@ class CodeNodeData(BaseNodeData):
 
     class Output(BaseModel):
         type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
-        children: Optional[dict[str, "Output"]] = None
+        children: Optional[dict[str, "CodeNodeData.Output"]] = None
 
     class Dependency(BaseModel):
         name: str
diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py
index 6d82dbe6d70da3..0b1dc611c59da2 100644
--- a/api/core/workflow/nodes/document_extractor/node.py
+++ b/api/core/workflow/nodes/document_extractor/node.py
@@ -4,6 +4,7 @@
 import logging
 import os
 import tempfile
+from typing import cast
 
 import docx
 import pandas as pd
@@ -159,7 +160,7 @@ def _extract_text_from_yaml(file_content: bytes) -> str:
     """Extract the content from yaml file"""
     try:
         yaml_data = yaml.safe_load_all(file_content.decode("utf-8", "ignore"))
-        return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
+        return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
     except (UnicodeDecodeError, yaml.YAMLError) as e:
         raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e
 
@@ -229,9 +230,9 @@ def _download_file_content(file: File) -> bytes:
                 raise FileDownloadError("Missing URL for remote file")
             response = ssrf_proxy.get(file.remote_url)
             response.raise_for_status()
-            return response.content
+            return cast(bytes, response.content)
         else:
-            return file_manager.download(file)
+            return cast(bytes, file_manager.download(file))
     except Exception as e:
         raise FileDownloadError(f"Error downloading file: {str(e)}") from e
 
diff --git a/api/core/workflow/nodes/end/end_stream_generate_router.py b/api/core/workflow/nodes/end/end_stream_generate_router.py
index 0db1ba9f09d36e..b3678a82b73959 100644
--- a/api/core/workflow/nodes/end/end_stream_generate_router.py
+++ b/api/core/workflow/nodes/end/end_stream_generate_router.py
@@ -67,7 +67,7 @@ def extract_stream_variable_selector_from_node_data(
                 and node_type == NodeType.LLM.value
                 and variable_selector.value_selector[1] == "text"
             ):
-                value_selectors.append(variable_selector.value_selector)
+                value_selectors.append(list(variable_selector.value_selector))
 
         return value_selectors
 
@@ -119,8 +119,7 @@ def _recursive_fetch_end_dependencies(
         current_node_id: str,
         end_node_id: str,
         node_id_config_mapping: dict[str, dict],
-        reverse_edge_mapping: dict[str, list["GraphEdge"]],
-        # type: ignore[name-defined]
+        reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
         end_dependencies: dict[str, list[str]],
     ) -> None:
         """
diff --git a/api/core/workflow/nodes/end/end_stream_processor.py b/api/core/workflow/nodes/end/end_stream_processor.py
index 1aecf863ac5fb9..a770eb951f6c8c 100644
--- a/api/core/workflow/nodes/end/end_stream_processor.py
+++ b/api/core/workflow/nodes/end/end_stream_processor.py @@ -23,7 +23,7 @@ def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: self.route_position[end_node_id] = 0 self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {} self.has_output = False - self.output_node_ids = set() + self.output_node_ids: set[str] = set() def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]: for event in generator: diff --git a/api/core/workflow/nodes/event/event.py b/api/core/workflow/nodes/event/event.py index 137b47655102af..9fea3fbda3141f 100644 --- a/api/core/workflow/nodes/event/event.py +++ b/api/core/workflow/nodes/event/event.py @@ -42,6 +42,6 @@ class RunRetryEvent(BaseModel): class SingleStepRetryEvent(NodeRunResult): """Single step retry event""" - status: str = WorkflowNodeExecutionStatus.RETRY.value + status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RETRY elapsed_time: float = Field(..., description="elapsed time") diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 575db15d365efb..cdfdc6e6d51b77 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -107,9 +107,9 @@ def _init_params(self): if not (key := key.strip()): continue - value = value[0].strip() if value else "" + value_str = value[0].strip() if value else "" result.append( - (self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value).text) + (self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text) ) self.params = result @@ -182,9 +182,10 @@ def _init_body(self): self.variable_pool.convert_template(item.key).text: item.file for item in filter(lambda item: item.type == "file", data) } + files: dict[str, Any] = {} files = {k: self.variable_pool.get_file(selector) for k, selector in file_selectors.items()} files = {k: v for k, v in files.items() if v is not None} - files = {k: variable.value for k, variable in files.items()} + files = {k: variable.value for k, variable in files.items() if variable is not None} files = { k: (v.filename, file_manager.download(v), v.mime_type or "application/octet-stream") for k, v in files.items() @@ -258,7 +259,8 @@ def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: response = getattr(ssrf_proxy, self.method)(**request_args) except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e: raise HttpRequestNodeError(str(e)) - return response + # FIXME: fix type ignore, this maybe httpx type issue + return response # type: ignore def invoke(self) -> Response: # assemble headers @@ -300,37 +302,37 @@ def to_log(self): continue raw += f"{k}: {v}\r\n" - body = "" + body_string = "" if self.files: for k, v in self.files.items(): - body += f"--{boundary}\r\n" - body += f'Content-Disposition: form-data; name="{k}"\r\n\r\n' - body += f"{v[1]}\r\n" - body += f"--{boundary}--\r\n" + body_string += f"--{boundary}\r\n" + body_string += f'Content-Disposition: form-data; name="{k}"\r\n\r\n' + body_string += f"{v[1]}\r\n" + body_string += f"--{boundary}--\r\n" elif self.node_data.body: if self.content: if isinstance(self.content, str): - body = self.content + body_string = self.content elif isinstance(self.content, bytes): - body = self.content.decode("utf-8", errors="replace") + body_string = self.content.decode("utf-8", errors="replace") elif self.data and self.node_data.body.type == 
"x-www-form-urlencoded": - body = urlencode(self.data) + body_string = urlencode(self.data) elif self.data and self.node_data.body.type == "form-data": for key, value in self.data.items(): - body += f"--{boundary}\r\n" - body += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' - body += f"{value}\r\n" - body += f"--{boundary}--\r\n" + body_string += f"--{boundary}\r\n" + body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' + body_string += f"{value}\r\n" + body_string += f"--{boundary}--\r\n" elif self.json: - body = json.dumps(self.json) + body_string = json.dumps(self.json) elif self.node_data.body.type == "raw-text": if len(self.node_data.body.data) != 1: raise RequestBodyError("raw-text body type should have exactly one item") - body = self.node_data.body.data[0].value - if body: - raw += f"Content-Length: {len(body)}\r\n" + body_string = self.node_data.body.data[0].value + if body_string: + raw += f"Content-Length: {len(body_string)}\r\n" raw += "\r\n" # Empty line between headers and body - raw += body + raw += body_string return raw diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index ebed690f6f3ffb..861119f26cb088 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -1,7 +1,7 @@ import logging import mimetypes from collections.abc import Mapping, Sequence -from typing import Any +from typing import Any, Optional from configs import dify_config from core.file import File, FileTransferMethod @@ -36,7 +36,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): _node_type = NodeType.HTTP_REQUEST @classmethod - def get_default_config(cls, filters: dict | None = None) -> dict: + def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict: return { "type": "http-request", "config": { @@ -160,8 +160,8 @@ def _extract_variable_selector_to_variable_mapping( ) mapping = {} - for selector in selectors: - mapping[node_id + "." + selector.variable] = selector.value_selector + for selector_iter in selectors: + mapping[node_id + "." 
+ selector_iter.variable] = selector_iter.value_selector return mapping diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 6a89cbfad61684..f1289558fffa82 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -361,13 +361,16 @@ def _handle_event_metadata( metadata = event.route_node_state.node_run_result.metadata if not metadata: metadata = {} - if NodeRunMetadataKey.ITERATION_ID not in metadata: - metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id - if self.node_data.is_parallel: - metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id - else: - metadata[NodeRunMetadataKey.ITERATION_INDEX] = iter_run_index + metadata = { + **metadata, + NodeRunMetadataKey.ITERATION_ID: self.node_id, + NodeRunMetadataKey.PARALLEL_MODE_RUN_ID + if self.node_data.is_parallel + else NodeRunMetadataKey.ITERATION_INDEX: parallel_mode_run_id + if self.node_data.is_parallel + else iter_run_index, + } event.route_node_state.node_run_result.metadata = metadata return event diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 4f9e415f4b83a3..bfd93c074dd6d5 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -147,6 +147,8 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: planning_strategy=planning_strategy, ) elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: + if node_data.multiple_retrieval_config is None: + raise ValueError("multiple_retrieval_config is required") if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": if node_data.multiple_retrieval_config.reranking_model: reranking_model = { @@ -157,6 +159,8 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: reranking_model = None weights = None elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score": + if node_data.multiple_retrieval_config.weights is None: + raise ValueError("weights is required") reranking_model = None vector_setting = node_data.multiple_retrieval_config.weights.vector_setting weights = { @@ -180,7 +184,9 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: available_datasets=available_datasets, query=query, top_k=node_data.multiple_retrieval_config.top_k, - score_threshold=node_data.multiple_retrieval_config.score_threshold, + score_threshold=node_data.multiple_retrieval_config.score_threshold + if node_data.multiple_retrieval_config.score_threshold is not None + else 0.0, reranking_mode=node_data.multiple_retrieval_config.reranking_mode, reranking_model=reranking_model, weights=weights, @@ -205,7 +211,7 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: "content": item.page_content, } retrieval_resource_list.append(source) - document_score_list = {} + document_score_list: dict[str, float] = {} # deal with dify documents if dify_documents: document_score_list = {} @@ -260,7 +266,9 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: retrieval_resource_list.append(source) if retrieval_resource_list: retrieval_resource_list = sorted( - retrieval_resource_list, key=lambda x: x.get("metadata").get("score") or 0.0, 
reverse=True + retrieval_resource_list, + key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0, + reverse=True, ) position = 1 for item in retrieval_resource_list: @@ -295,6 +303,8 @@ def _fetch_model_config( :param node_data: node data :return: """ + if node_data.single_retrieval_config is None: + raise ValueError("single_retrieval_config is required") model_name = node_data.single_retrieval_config.model.name provider_name = node_data.single_retrieval_config.model.provider diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index 79066cece4f93c..432c57294ecbe9 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -1,5 +1,5 @@ from collections.abc import Callable, Sequence -from typing import Literal, Union +from typing import Any, Literal, Union from core.file import File from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment @@ -17,9 +17,9 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): _node_type = NodeType.LIST_OPERATOR def _run(self): - inputs = {} - process_data = {} - outputs = {} + inputs: dict[str, list] = {} + process_data: dict[str, list] = {} + outputs: dict[str, Any] = {} variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable) if variable is None: @@ -93,6 +93,8 @@ def _run(self): def _apply_filter( self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: + filter_func: Callable[[Any], bool] + result: list[Any] = [] for condition in self.node_data.filter_by.conditions: if isinstance(variable, ArrayStringSegment): if not isinstance(condition.value, str): @@ -236,6 +238,7 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]: + extract_func: Callable[[File], Any] if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str): extract_func = _get_file_extract_string_func(key=key) return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x)) @@ -249,47 +252,47 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str raise InvalidKeyError(f"Invalid key: {key}") -def _contains(value: str): +def _contains(value: str) -> Callable[[str], bool]: return lambda x: value in x -def _startswith(value: str): +def _startswith(value: str) -> Callable[[str], bool]: return lambda x: x.startswith(value) -def _endswith(value: str): +def _endswith(value: str) -> Callable[[str], bool]: return lambda x: x.endswith(value) -def _is(value: str): +def _is(value: str) -> Callable[[str], bool]: return lambda x: x is value -def _in(value: str | Sequence[str]): +def _in(value: str | Sequence[str]) -> Callable[[str], bool]: return lambda x: x in value -def _eq(value: int | float): +def _eq(value: int | float) -> Callable[[int | float], bool]: return lambda x: x == value -def _ne(value: int | float): +def _ne(value: int | float) -> Callable[[int | float], bool]: return lambda x: x != value -def _lt(value: int | float): +def _lt(value: int | float) -> Callable[[int | float], bool]: return lambda x: x < value -def _le(value: int | float): +def _le(value: int | float) -> Callable[[int | float], bool]: return lambda x: x <= value -def _gt(value: int | float): +def _gt(value: int | float) -> Callable[[int | 
float], bool]:
     return lambda x: x > value
 
 
-def _ge(value: int | float):
+def _ge(value: int | float) -> Callable[[int | float], bool]:
     return lambda x: x >= value
 
 
@@ -302,6 +305,7 @@ def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]):
 
 
 def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]):
+    extract_func: Callable[[File], Any]
     if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}:
         extract_func = _get_file_extract_string_func(key=order_by)
         return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index 55fac45576c821..6909b30c9e82ca 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -88,8 +88,8 @@ class LLMNode(BaseNode[LLMNodeData]):
     _node_data_cls = LLMNodeData
     _node_type = NodeType.LLM
 
-    def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]:
-        node_inputs = None
+    def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
+        node_inputs: Optional[dict[str, Any]] = None
         process_data = None
 
         try:
@@ -196,7 +196,6 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]
                     error_type=type(e).__name__,
                 )
             )
-            return
         except Exception as e:
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
@@ -206,7 +205,6 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]
                     process_data=process_data,
                 )
             )
-            return
 
         outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
 
@@ -302,7 +300,7 @@ def _transform_chat_messages(
         return messages
 
     def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
-        variables = {}
+        variables: dict[str, Any] = {}
 
         if not node_data.prompt_config:
             return variables
@@ -319,7 +317,7 @@ def parse_dict(input_dict: Mapping[str, Any]) -> str:
             """
             # check if it's a context structure
             if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict:
-                return input_dict["content"]
+                return str(input_dict["content"])
 
             # else, parse the dict
             try:
@@ -557,7 +555,8 @@ def _fetch_prompt_messages(
         variable_pool: VariablePool,
         jinja2_variables: Sequence[VariableSelector],
     ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
-        prompt_messages = []
+        # FIXME: fix the type error; prompt_messages is re-typed quite a few times
+        prompt_messages: list[Any] = []
 
         if isinstance(prompt_template, list):
             # For chat model
@@ -783,7 +782,7 @@ def _extract_variable_selector_to_variable_mapping(
         else:
             raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}")
 
-        variable_mapping = {}
+        variable_mapping: dict[str, Any] = {}
         for variable_selector in variable_selectors:
             variable_mapping[variable_selector.variable] = variable_selector.value_selector
 
@@ -981,7 +980,7 @@ def _handle_memory_chat_mode(
     memory_config: MemoryConfig | None,
     model_config: ModelConfigWithCredentialsEntity,
 ) -> Sequence[PromptMessage]:
-    memory_messages = []
+    memory_messages: Sequence[PromptMessage] = []
     # Get messages from memory for chat model
     if memory and memory_config:
         rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py
index 6fdff966026b63..a366c287c2ac56 100644
--- a/api/core/workflow/nodes/loop/loop_node.py
+++ b/api/core/workflow/nodes/loop/loop_node.py
@@ -14,8
+14,8 @@ class LoopNode(BaseNode[LoopNodeData]): _node_data_cls = LoopNodeData _node_type = NodeType.LOOP - def _run(self) -> LoopState: - return super()._run() + def _run(self) -> LoopState: # type: ignore + return super()._run() # type: ignore @classmethod def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]: @@ -28,7 +28,7 @@ def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]: # TODO waiting for implementation return [ - Condition( + Condition( # type: ignore variable_selector=[node_id, "index"], comparison_operator="≤", value_type="value_selector", diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index a001b44dc7dfee..369eb13b04e8c4 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -25,7 +25,7 @@ def validate_name(cls, value) -> str: raise ValueError("Parameter name is required") if value in {"__reason", "__is_success"}: raise ValueError("Invalid parameter name, __reason and __is_success are reserved") - return value + return str(value) class ParameterExtractorNodeData(BaseNodeData): @@ -52,7 +52,7 @@ def get_parameter_json_schema(self) -> dict: :return: parameter json schema """ - parameters = {"type": "object", "properties": {}, "required": []} + parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []} for parameter in self.parameters: parameter_schema: dict[str, Any] = {"description": parameter.description} diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index c8c854a43b3269..9c88047f2c8e57 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -63,7 +63,8 @@ class ParameterExtractorNode(LLMNode): Parameter Extractor Node. """ - _node_data_cls = ParameterExtractorNodeData + # FIXME: figure out why here is different from super class + _node_data_cls = ParameterExtractorNodeData # type: ignore _node_type = NodeType.PARAMETER_EXTRACTOR _model_instance: Optional[ModelInstance] = None @@ -253,6 +254,9 @@ def _invoke( # deduct quota self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) + if text is None: + text = "" + return text, usage, tool_call def _generate_function_call_prompt( @@ -605,9 +609,10 @@ def extract_json(text): json_str = extract_json(result[idx:]) if json_str: try: - return json.loads(json_str) + return cast(dict, json.loads(json_str)) except Exception: pass + return None def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]: """ @@ -616,13 +621,13 @@ def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCal if not tool_call or not tool_call.function.arguments: return None - return json.loads(tool_call.function.arguments) + return cast(dict, json.loads(tool_call.function.arguments)) def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict: """ Generate default result. 
""" - result = {} + result: dict[str, Any] = {} for parameter in data.parameters: if parameter.type == "number": result[parameter.name] = 0 @@ -772,7 +777,7 @@ def _extract_variable_selector_to_variable_mapping( *, graph_config: Mapping[str, Any], node_id: str, - node_data: ParameterExtractorNodeData, + node_data: ParameterExtractorNodeData, # type: ignore ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -781,6 +786,7 @@ def _extract_variable_selector_to_variable_mapping( :param node_data: node data :return: """ + # FIXME: fix the type error later variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query} if node_data.instruction: diff --git a/api/core/workflow/nodes/parameter_extractor/prompts.py b/api/core/workflow/nodes/parameter_extractor/prompts.py index e603add1704544..6c3155ac9a54e3 100644 --- a/api/core/workflow/nodes/parameter_extractor/prompts.py +++ b/api/core/workflow/nodes/parameter_extractor/prompts.py @@ -1,3 +1,5 @@ +from typing import Any + FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters" FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy. @@ -35,7 +37,7 @@ """ # noqa: E501 -FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [ +FUNCTION_CALLING_EXTRACTOR_EXAMPLE: list[dict[str, Any]] = [ { "user": { "query": "What is the weather today in SF?", diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 31f8368d590ea9..0ec44eefacf52f 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,6 +1,6 @@ import json from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import Any, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -34,12 +34,9 @@ QUESTION_CLASSIFIER_USER_PROMPT_3, ) -if TYPE_CHECKING: - from core.file import File - class QuestionClassifierNode(LLMNode): - _node_data_cls = QuestionClassifierNodeData + _node_data_cls = QuestionClassifierNodeData # type: ignore _node_type = NodeType.QUESTION_CLASSIFIER def _run(self): @@ -61,7 +58,7 @@ def _run(self): node_data.instruction = node_data.instruction or "" node_data.instruction = variable_pool.convert_template(node_data.instruction).text - files: Sequence[File] = ( + files = ( self._fetch_files( selector=node_data.vision.configs.variable_selector, ) @@ -168,7 +165,7 @@ def _extract_variable_selector_to_variable_mapping( *, graph_config: Mapping[str, Any], node_id: str, - node_data: QuestionClassifierNodeData, + node_data: Any, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -177,6 +174,7 @@ def _extract_variable_selector_to_variable_mapping( :param node_data: node data :return: """ + node_data = cast(QuestionClassifierNodeData, node_data) variable_mapping = {"query": node_data.query_variable_selector} variable_selectors = [] if node_data.instruction: diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 983fa7e623177a..01d07e494944b4 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ 
-9,7 +9,6 @@ from core.file import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.tool_engine import ToolEngine -from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.entities.variable_pool import VariablePool @@ -46,6 +45,8 @@ def _run(self) -> NodeRunResult: # get tool runtime try: + from core.tools.tool_manager import ToolManager + tool_runtime = ToolManager.get_workflow_tool_runtime( self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from ) @@ -142,7 +143,7 @@ def _generate_parameters( """ tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters} - result = {} + result: dict[str, Any] = {} for parameter_name in node_data.tool_parameters: parameter = tool_parameters_dictionary.get(parameter_name) if not parameter: @@ -264,9 +265,9 @@ def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> """ return "\n".join( [ - f"{message.message}" + str(message.message) if message.type == ToolInvokeMessage.MessageType.TEXT - else f"Link: {message.message}" + else f"Link: {str(message.message)}" for message in tool_response if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK} ] diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index 8eb4bd5c2da573..9acc76f326eec9 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -36,6 +36,8 @@ def _run(self) -> NodeRunResult: case WriteMode.CLEAR: income_value = get_zero_value(original_variable.value_type) + if income_value is None: + raise VariableOperatorNodeError("income value not found") updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) case _: diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index d73c7442029225..0c4aae827c0a0f 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -1,5 +1,5 @@ import json -from typing import Any +from typing import Any, cast from core.variables import SegmentType, Variable from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID @@ -29,7 +29,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): def _run(self) -> NodeRunResult: inputs = self.node_data.model_dump() - process_data = {} + process_data: dict[str, Any] = {} # NOTE: This node has no outputs updated_variables: list[Variable] = [] @@ -119,7 +119,7 @@ def _run(self) -> NodeRunResult: else: conversation_id = conversation_id.value common_helpers.update_conversation_variable( - conversation_id=conversation_id, + conversation_id=cast(str, conversation_id), variable=variable, ) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 811e40c11e5407..b14c6fafbd9fdc 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -129,11 +129,11 @@ def single_step_run( :return: """ # fetch node info from workflow graph - graph = workflow.graph_dict - if not graph: + workflow_graph = workflow.graph_dict + if not workflow_graph: raise ValueError("workflow graph not found") - nodes = 
graph.get("nodes") + nodes = workflow_graph.get("nodes") if not nodes: raise ValueError("nodes not found in workflow graph") @@ -196,7 +196,8 @@ def single_step_run( @staticmethod def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None: - return WorkflowEntry._handle_special_values(value) + result = WorkflowEntry._handle_special_values(value) + return result if isinstance(result, Mapping) or result is None else dict(result) @staticmethod def _handle_special_values(value: Any) -> Any: @@ -208,10 +209,10 @@ def _handle_special_values(value: Any) -> Any: res[k] = WorkflowEntry._handle_special_values(v) return res if isinstance(value, list): - res = [] + res_list = [] for item in value: - res.append(WorkflowEntry._handle_special_values(item)) - return res + res_list.append(WorkflowEntry._handle_special_values(item)) + return res_list if isinstance(value, File): return value.to_dict() return value diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 24fa013697994c..8a677f6b6fc017 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -14,7 +14,7 @@ @document_index_created.connect def handle(sender, **kwargs): dataset_id = sender - document_ids = kwargs.get("document_ids") + document_ids = kwargs.get("document_ids", []) documents = [] start_at = time.perf_counter() for document_id in document_ids: diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py index 1515661b2d45b8..5e7caf8cbed71e 100644 --- a/api/events/event_handlers/create_site_record_when_app_created.py +++ b/api/events/event_handlers/create_site_record_when_app_created.py @@ -8,18 +8,19 @@ def handle(sender, **kwargs): """Create site record when an app is created.""" app = sender account = kwargs.get("account") - site = Site( - app_id=app.id, - title=app.name, - icon_type=app.icon_type, - icon=app.icon, - icon_background=app.icon_background, - default_language=account.interface_language, - customize_token_strategy="not_allow", - code=Site.generate_code(16), - created_by=app.created_by, - updated_by=app.updated_by, - ) + if account is not None: + site = Site( + app_id=app.id, + title=app.name, + icon_type=app.icon_type, + icon=app.icon, + icon_background=app.icon_background, + default_language=account.interface_language, + customize_token_strategy="not_allow", + code=Site.generate_code(16), + created_by=app.created_by, + updated_by=app.updated_by, + ) - db.session.add(site) - db.session.commit() + db.session.add(site) + db.session.commit() diff --git a/api/events/event_handlers/deduct_quota_when_message_created.py b/api/events/event_handlers/deduct_quota_when_message_created.py index 843a2320968ced..1ed37efba0b3be 100644 --- a/api/events/event_handlers/deduct_quota_when_message_created.py +++ b/api/events/event_handlers/deduct_quota_when_message_created.py @@ -44,7 +44,7 @@ def handle(sender, **kwargs): else: used_quota = 1 - if used_quota is not None: + if used_quota is not None and system_configuration.current_quota_type is not None: db.session.query(Provider).filter( Provider.tenant_id == application_generate_entity.app_config.tenant_id, Provider.provider_name == model_config.provider, diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index 
9c5955c8c5a1a5..f89fae24a56378 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -8,7 +8,10 @@ @app_draft_workflow_was_synced.connect def handle(sender, **kwargs): app = sender - for node_data in kwargs.get("synced_draft_workflow").graph_dict.get("nodes", []): + synced_draft_workflow = kwargs.get("synced_draft_workflow") + if synced_draft_workflow is None: + return + for node_data in synced_draft_workflow.graph_dict.get("nodes", []): if node_data.get("data", {}).get("type") == NodeType.TOOL.value: try: tool_entity = ToolEntity(**node_data["data"]) diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py index de7c0f4dfeb74f..408ed31096d2a0 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py @@ -8,16 +8,18 @@ def handle(sender, **kwargs): app = sender app_model_config = kwargs.get("app_model_config") + if app_model_config is None: + return dataset_ids = get_dataset_ids_from_model_config(app_model_config) app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() - removed_dataset_ids = [] + removed_dataset_ids: set[str] = set() if not app_dataset_joins: added_dataset_ids = dataset_ids else: - old_dataset_ids = set() + old_dataset_ids: set[str] = set() old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins) added_dataset_ids = dataset_ids - old_dataset_ids @@ -37,8 +39,8 @@ def handle(sender, **kwargs): db.session.commit() -def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set: - dataset_ids = set() +def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set[str]: + dataset_ids: set[str] = set() if not app_model_config: return dataset_ids diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 453395e8d7dc1c..7a31c82f6adbc2 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -17,11 +17,11 @@ def handle(sender, **kwargs): dataset_ids = get_dataset_ids_from_workflow(published_workflow) app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() - removed_dataset_ids = [] + removed_dataset_ids: set[str] = set() if not app_dataset_joins: added_dataset_ids = dataset_ids else: - old_dataset_ids = set() + old_dataset_ids: set[str] = set() old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins) added_dataset_ids = dataset_ids - old_dataset_ids @@ -41,8 +41,8 @@ def handle(sender, **kwargs): db.session.commit() -def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set: - dataset_ids = set() +def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[str]: + dataset_ids: set[str] = set() graph = published_workflow.graph_dict if not graph: return dataset_ids @@ -60,7 +60,7 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set: for node in knowledge_retrieval_nodes: try: node_data =
KnowledgeRetrievalNodeData(**node.get("data", {})) - dataset_ids.update(node_data.dataset_ids) + dataset_ids.update(str(dataset_id) for dataset_id in node_data.dataset_ids) except Exception as e: continue diff --git a/api/extensions/__init__.py b/api/extensions/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/extensions/ext_app_metrics.py b/api/extensions/ext_app_metrics.py index de1cdfeb984e86..b7d412d68deda1 100644 --- a/api/extensions/ext_app_metrics.py +++ b/api/extensions/ext_app_metrics.py @@ -54,12 +54,14 @@ def pool_stat(): from extensions.ext_database import db engine = db.engine + # TODO: Fix the type error + # FIXME: maybe it's a SQLAlchemy issue return { "pid": os.getpid(), - "pool_size": engine.pool.size(), - "checked_in_connections": engine.pool.checkedin(), - "checked_out_connections": engine.pool.checkedout(), - "overflow_connections": engine.pool.overflow(), - "connection_timeout": engine.pool.timeout(), - "recycle_time": db.engine.pool._recycle, + "pool_size": engine.pool.size(), # type: ignore + "checked_in_connections": engine.pool.checkedin(), # type: ignore + "checked_out_connections": engine.pool.checkedout(), # type: ignore + "overflow_connections": engine.pool.overflow(), # type: ignore + "connection_timeout": engine.pool.timeout(), # type: ignore + "recycle_time": db.engine.pool._recycle, # type: ignore } diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 9dbc4b93d46266..30f216ff95612b 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -1,8 +1,8 @@ from datetime import timedelta import pytz -from celery import Celery, Task -from celery.schedules import crontab +from celery import Celery, Task # type: ignore +from celery.schedules import crontab # type: ignore from configs import dify_config from dify_app import DifyApp @@ -47,7 +47,7 @@ def __call__(self, *args: object, **kwargs: object) -> object: worker_log_format=dify_config.LOG_FORMAT, worker_task_log_format=dify_config.LOG_FORMAT, worker_hijack_root_logger=False, - timezone=pytz.timezone(dify_config.LOG_TZ), + timezone=pytz.timezone(dify_config.LOG_TZ or "UTC"), ) if dify_config.BROKER_USE_SSL: diff --git a/api/extensions/ext_compress.py b/api/extensions/ext_compress.py index 9c3a663af417ae..26ff6427bef1cc 100644 --- a/api/extensions/ext_compress.py +++ b/api/extensions/ext_compress.py @@ -7,7 +7,7 @@ def is_enabled() -> bool: def init_app(app: DifyApp): - from flask_compress import Compress + from flask_compress import Compress # type: ignore compress = Compress() compress.init_app(app) diff --git a/api/extensions/ext_logging.py b/api/extensions/ext_logging.py index 9fc29b4eb17212..e1c459e8c17fd0 100644 --- a/api/extensions/ext_logging.py +++ b/api/extensions/ext_logging.py @@ -11,7 +11,7 @@ def init_app(app: DifyApp): - log_handlers = [] + log_handlers: list[logging.Handler] = [] log_file = dify_config.LOG_FILE if log_file: log_dir = os.path.dirname(log_file) @@ -49,7 +49,8 @@ def time_converter(seconds): return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple() for handler in logging.root.handlers: - handler.formatter.converter = time_converter + if handler.formatter: + handler.formatter.converter = time_converter def get_request_id(): diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index b2955307144d67..10fb89eb7370ee 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -1,6 +1,6 @@ import json -import flask_login +import flask_login # type: ignore
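A note on the "# type: ignore" pragmas that dominate the extension hunks above: flask_login, celery, flask_compress, and most of the storage SDKs ship no type stubs, so mypy rejects their imports outright. The inline pragma is the narrowest escape hatch; a minimal sketch of what it does and does not buy (assuming flask_login is installed):

    import flask_login  # type: ignore

    # The pragma silences only the missing-stubs error on this line; the
    # module is treated as Any afterwards, so nothing reached through it
    # is actually checked.
    user = flask_login.current_user  # inferred as Any, not as a UserMixin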
from flask import Response, request from flask_login import user_loaded_from_request, user_logged_in from werkzeug.exceptions import Unauthorized diff --git a/api/extensions/ext_mail.py b/api/extensions/ext_mail.py index 468aedd47ea90b..9240ebe7fcba73 100644 --- a/api/extensions/ext_mail.py +++ b/api/extensions/ext_mail.py @@ -26,7 +26,7 @@ def init_app(self, app: Flask): match mail_type: case "resend": - import resend + import resend # type: ignore api_key = dify_config.RESEND_API_KEY if not api_key: @@ -48,9 +48,9 @@ def init_app(self, app: Flask): self._client = SMTPClient( server=dify_config.SMTP_SERVER, port=dify_config.SMTP_PORT, - username=dify_config.SMTP_USERNAME, - password=dify_config.SMTP_PASSWORD, - _from=dify_config.MAIL_DEFAULT_SEND_FROM, + username=dify_config.SMTP_USERNAME or "", + password=dify_config.SMTP_PASSWORD or "", + _from=dify_config.MAIL_DEFAULT_SEND_FROM or "", use_tls=dify_config.SMTP_USE_TLS, opportunistic_tls=dify_config.SMTP_OPPORTUNISTIC_TLS, ) diff --git a/api/extensions/ext_migrate.py b/api/extensions/ext_migrate.py index 6d8f35c30d9c65..5f862181fa8540 100644 --- a/api/extensions/ext_migrate.py +++ b/api/extensions/ext_migrate.py @@ -2,7 +2,7 @@ def init_app(app: DifyApp): - import flask_migrate + import flask_migrate # type: ignore from extensions.ext_database import db diff --git a/api/extensions/ext_proxy_fix.py b/api/extensions/ext_proxy_fix.py index 3b895ac95b5029..514e0658257293 100644 --- a/api/extensions/ext_proxy_fix.py +++ b/api/extensions/ext_proxy_fix.py @@ -6,4 +6,4 @@ def init_app(app: DifyApp): if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED: from werkzeug.middleware.proxy_fix import ProxyFix - app.wsgi_app = ProxyFix(app.wsgi_app) + app.wsgi_app = ProxyFix(app.wsgi_app) # type: ignore diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 3ec8ae6e1dc14e..3a74aace6a34cf 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -6,7 +6,7 @@ def init_app(app: DifyApp): if dify_config.SENTRY_DSN: import openai import sentry_sdk - from langfuse import parse_error + from langfuse import parse_error # type: ignore from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index 42422263c4dd03..588bdb2d2717e0 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -1,6 +1,6 @@ import logging from collections.abc import Callable, Generator -from typing import Union +from typing import Literal, Union, overload from flask import Flask @@ -79,6 +79,12 @@ def save(self, filename, data): logger.exception(f"Failed to save file {filename}") raise e + @overload + def load(self, filename: str, /, *, stream: Literal[False] = False) -> bytes: ... + + @overload + def load(self, filename: str, /, *, stream: Literal[True]) -> Generator: ... 
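The pair of @overload declarations just added is what lets call sites keep a precise return type even though the single implementation returns a union. A standalone sketch of the same pattern (placeholder class and body, not the actual Storage code):

    from collections.abc import Generator
    from typing import Literal, Union, overload


    class Loader:
        @overload
        def load(self, filename: str, /, *, stream: Literal[False] = False) -> bytes: ...

        @overload
        def load(self, filename: str, /, *, stream: Literal[True]) -> Generator: ...

        def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Generator]:
            def _chunks() -> Generator[bytes, None, None]:
                yield b"chunk"  # placeholder for real streaming reads

            return _chunks() if stream else b"data"


    data = Loader().load("a.txt")                  # mypy infers bytes
    chunks = Loader().load("a.txt", stream=True)   # mypy infers Generator

Mypy matches each call against the overloads by the Literal value of stream, so callers no longer need isinstance checks or casts on the result.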
+ def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Generator]: try: if stream: diff --git a/api/extensions/storage/aliyun_oss_storage.py b/api/extensions/storage/aliyun_oss_storage.py index 58c917dbd386bc..00bf5d4f93ae3b 100644 --- a/api/extensions/storage/aliyun_oss_storage.py +++ b/api/extensions/storage/aliyun_oss_storage.py @@ -1,7 +1,7 @@ import posixpath from collections.abc import Generator -import oss2 as aliyun_s3 +import oss2 as aliyun_s3 # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -33,7 +33,7 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: obj = self.client.get_object(self.__wrapper_folder_filename(filename)) - data = obj.read() + data: bytes = obj.read() return data def load_stream(self, filename: str) -> Generator: @@ -41,14 +41,14 @@ def load_stream(self, filename: str) -> Generator: while chunk := obj.read(4096): yield chunk - def download(self, filename, target_filepath): + def download(self, filename: str, target_filepath): self.client.get_object_to_file(self.__wrapper_folder_filename(filename), target_filepath) - def exists(self, filename): + def exists(self, filename: str): return self.client.object_exists(self.__wrapper_folder_filename(filename)) - def delete(self, filename): + def delete(self, filename: str): self.client.delete_object(self.__wrapper_folder_filename(filename)) - def __wrapper_folder_filename(self, filename) -> str: + def __wrapper_folder_filename(self, filename: str) -> str: return posixpath.join(self.folder, filename) if self.folder else filename diff --git a/api/extensions/storage/aws_s3_storage.py b/api/extensions/storage/aws_s3_storage.py index ce36c2e7deeeda..7b6b2eedd62bf2 100644 --- a/api/extensions/storage/aws_s3_storage.py +++ b/api/extensions/storage/aws_s3_storage.py @@ -1,9 +1,9 @@ import logging from collections.abc import Generator -import boto3 -from botocore.client import Config -from botocore.exceptions import ClientError +import boto3 # type: ignore +from botocore.client import Config # type: ignore +from botocore.exceptions import ClientError # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -53,7 +53,7 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: try: - data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() + data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() except ClientError as ex: if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py index b26caa8671b6df..2f8532f4f8f653 100644 --- a/api/extensions/storage/azure_blob_storage.py +++ b/api/extensions/storage/azure_blob_storage.py @@ -27,7 +27,7 @@ def load_once(self, filename: str) -> bytes: client = self._sync_client() blob = client.get_container_client(container=self.bucket_name) blob = blob.get_blob_client(blob=filename) - data = blob.download_blob().readall() + data: bytes = blob.download_blob().readall() return data def load_stream(self, filename: str) -> Generator: @@ -63,11 +63,11 @@ def _sync_client(self): sas_token = cache_result.decode("utf-8") else: sas_token = generate_account_sas( - account_name=self.account_name, - account_key=self.account_key, + account_name=self.account_name or "", + account_key=self.account_key or "", 
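Several hunks in this region (the SMTP client above, the Azure SAS call here) narrow Optional[str] config values with `or ""` so that APIs demanding a plain str accept them. That keeps mypy quiet but silently substitutes an empty string for a missing setting; a sketch of the trade-off, with a hypothetical config field:

    from typing import Optional

    SMTP_USERNAME: Optional[str] = None  # stands in for a dify_config field


    def connect(username: str) -> None:
        """Placeholder for a client that requires a plain str."""


    # The patch's approach: quietest, but "" stands in for a missing value.
    connect(username=SMTP_USERNAME or "")

    # Stricter alternative: fail fast when the setting is truly required.
    if SMTP_USERNAME is None:
        raise ValueError("SMTP_USERNAME must be configured")
    connect(username=SMTP_USERNAME)  # narrowed to str by the check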
resource_types=ResourceTypes(service=True, container=True, object=True), permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True), expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1), ) redis_client.set(cache_key, sas_token, ex=3000) - return BlobServiceClient(account_url=self.account_url, credential=sas_token) + return BlobServiceClient(account_url=self.account_url or "", credential=sas_token) diff --git a/api/extensions/storage/baidu_obs_storage.py b/api/extensions/storage/baidu_obs_storage.py index e0d2140e91272c..b94efa08be7613 100644 --- a/api/extensions/storage/baidu_obs_storage.py +++ b/api/extensions/storage/baidu_obs_storage.py @@ -2,9 +2,9 @@ import hashlib from collections.abc import Generator -from baidubce.auth.bce_credentials import BceCredentials -from baidubce.bce_client_configuration import BceClientConfiguration -from baidubce.services.bos.bos_client import BosClient +from baidubce.auth.bce_credentials import BceCredentials # type: ignore +from baidubce.bce_client_configuration import BceClientConfiguration # type: ignore +from baidubce.services.bos.bos_client import BosClient # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -36,7 +36,8 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: response = self.client.get_object(bucket_name=self.bucket_name, key=filename) - return response.data.read() + data: bytes = response.data.read() + return data def load_stream(self, filename: str) -> Generator: response = self.client.get_object(bucket_name=self.bucket_name, key=filename).data diff --git a/api/extensions/storage/google_cloud_storage.py b/api/extensions/storage/google_cloud_storage.py index 26b662d2f04daf..705639f42e716f 100644 --- a/api/extensions/storage/google_cloud_storage.py +++ b/api/extensions/storage/google_cloud_storage.py @@ -3,7 +3,7 @@ import json from collections.abc import Generator -from google.cloud import storage as google_cloud_storage +from google.cloud import storage as google_cloud_storage # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -35,7 +35,7 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: bucket = self.client.get_bucket(self.bucket_name) blob = bucket.get_blob(filename) - data = blob.download_as_bytes() + data: bytes = blob.download_as_bytes() return data def load_stream(self, filename: str) -> Generator: diff --git a/api/extensions/storage/huawei_obs_storage.py b/api/extensions/storage/huawei_obs_storage.py index 20be70ef83dd7a..07f1d199701be4 100644 --- a/api/extensions/storage/huawei_obs_storage.py +++ b/api/extensions/storage/huawei_obs_storage.py @@ -1,6 +1,6 @@ from collections.abc import Generator -from obs import ObsClient +from obs import ObsClient # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -23,7 +23,7 @@ def save(self, filename, data): self.client.putObject(bucketName=self.bucket_name, objectKey=filename, content=data) def load_once(self, filename: str) -> bytes: - data = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read() + data: bytes = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read() return data def load_stream(self, filename: str) -> Generator: diff --git a/api/extensions/storage/opendal_storage.py b/api/extensions/storage/opendal_storage.py index 
e671eff059ba21..b78fc94dae7843 100644 --- a/api/extensions/storage/opendal_storage.py +++ b/api/extensions/storage/opendal_storage.py @@ -3,7 +3,7 @@ from collections.abc import Generator from pathlib import Path -import opendal +import opendal # type: ignore[import] from dotenv import dotenv_values from extensions.storage.base_storage import BaseStorage @@ -18,7 +18,7 @@ def _get_opendal_kwargs(*, scheme: str, env_file_path: str = ".env", prefix: str if key.startswith(config_prefix): kwargs[key[len(config_prefix) :].lower()] = value - file_env_vars = dotenv_values(env_file_path) + file_env_vars: dict = dotenv_values(env_file_path) or {} for key, value in file_env_vars.items(): if key.startswith(config_prefix) and key[len(config_prefix) :].lower() not in kwargs and value: kwargs[key[len(config_prefix) :].lower()] = value @@ -48,7 +48,7 @@ def load_once(self, filename: str) -> bytes: if not self.exists(filename): raise FileNotFoundError("File not found") - content = self.op.read(path=filename) + content: bytes = self.op.read(path=filename) logger.debug(f"file {filename} loaded") return content @@ -75,7 +75,7 @@ def exists(self, filename: str) -> bool: # error handler here when opendal python-binding has a exists method, we should use it # more https://github.com/apache/opendal/blob/main/bindings/python/src/operator.rs try: - res = self.op.stat(path=filename).mode.is_file() + res: bool = self.op.stat(path=filename).mode.is_file() logger.debug(f"file {filename} checked") return res except Exception: diff --git a/api/extensions/storage/oracle_oci_storage.py b/api/extensions/storage/oracle_oci_storage.py index b59f83b8de90bf..82829f7fd50d65 100644 --- a/api/extensions/storage/oracle_oci_storage.py +++ b/api/extensions/storage/oracle_oci_storage.py @@ -1,7 +1,7 @@ from collections.abc import Generator -import boto3 -from botocore.exceptions import ClientError +import boto3 # type: ignore +from botocore.exceptions import ClientError # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -27,7 +27,7 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: try: - data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() + data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() except ClientError as ex: if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") diff --git a/api/extensions/storage/supabase_storage.py b/api/extensions/storage/supabase_storage.py index 9f7c69a9ae6312..711c3f72117c86 100644 --- a/api/extensions/storage/supabase_storage.py +++ b/api/extensions/storage/supabase_storage.py @@ -32,7 +32,7 @@ def save(self, filename, data): self.client.storage.from_(self.bucket_name).upload(filename, data) def load_once(self, filename: str) -> bytes: - content = self.client.storage.from_(self.bucket_name).download(filename) + content: bytes = self.client.storage.from_(self.bucket_name).download(filename) return content def load_stream(self, filename: str) -> Generator: diff --git a/api/extensions/storage/tencent_cos_storage.py b/api/extensions/storage/tencent_cos_storage.py index 13a6c9239c2d1e..9cdd3e67f75aab 100644 --- a/api/extensions/storage/tencent_cos_storage.py +++ b/api/extensions/storage/tencent_cos_storage.py @@ -1,6 +1,6 @@ from collections.abc import Generator -from qcloud_cos import CosConfig, CosS3Client +from qcloud_cos import CosConfig, CosS3Client # type: ignore from configs import dify_config from 
extensions.storage.base_storage import BaseStorage @@ -25,7 +25,7 @@ def save(self, filename, data): self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename) def load_once(self, filename: str) -> bytes: - data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read() + data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read() return data def load_stream(self, filename: str) -> Generator: diff --git a/api/extensions/storage/volcengine_tos_storage.py b/api/extensions/storage/volcengine_tos_storage.py index de82be04ea87b7..55fe6545ec3d2d 100644 --- a/api/extensions/storage/volcengine_tos_storage.py +++ b/api/extensions/storage/volcengine_tos_storage.py @@ -1,6 +1,6 @@ from collections.abc import Generator -import tos +import tos # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -24,6 +24,8 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: data = self.client.get_object(bucket=self.bucket_name, key=filename).read() + if not isinstance(data, bytes): + raise TypeError("Expected bytes, got {}".format(type(data).__name__)) return data def load_stream(self, filename: str) -> Generator: diff --git a/api/factories/__init__.py b/api/factories/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 13034f5cf5688b..856cf62e3ed243 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -64,7 +64,7 @@ def build_from_mapping( if not build_func: raise ValueError(f"Invalid file transfer method: {transfer_method}") - file = build_func( + file: File = build_func( mapping=mapping, tenant_id=tenant_id, transfer_method=transfer_method, @@ -72,7 +72,7 @@ def build_from_mapping( if config and not _is_file_valid_with_config( input_file_type=mapping.get("type", FileType.CUSTOM), - file_extension=file.extension, + file_extension=file.extension or "", file_transfer_method=file.transfer_method, config=config, ): @@ -281,6 +281,7 @@ def _get_file_type_by_extension(extension: str) -> FileType | None: return FileType.AUDIO elif extension in DOCUMENT_EXTENSIONS: return FileType.DOCUMENT + return None def _get_file_type_by_mimetype(mime_type: str) -> FileType | None: diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 16a578728aa16e..bbca8448ec0662 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any +from typing import Any, cast from uuid import uuid4 from configs import dify_config @@ -84,6 +84,8 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen raise VariableError("missing value type") if (value := mapping.get("value")) is None: raise VariableError("missing value") + # FIXME: using Any here, fix it later + result: Any match value_type: case SegmentType.STRING: result = StringVariable.model_validate(mapping) @@ -109,7 +111,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}") if not result.selector: result = result.model_copy(update={"selector": selector}) - return result + return cast(Variable, result) def build_segment(value: Any, /) -> Segment: @@ -164,10 +166,13 @@ def segment_to_variable( raise 
UnsupportedSegmentTypeError(f"not supported segment type {segment_type}") variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] - return variable_class( - id=id, - name=name, - description=description, - value=segment.value, - selector=selector, + return cast( + Variable, + variable_class( + id=id, + name=name, + description=description, + value=segment.value, + selector=selector, + ), ) diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index 379dcc6d16fe56..1c58b3a2579087 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/api_based_extension_fields.py b/api/fields/api_based_extension_fields.py index a85d4a34dbe7b1..d40407bfcc6193 100644 --- a/api/fields/api_based_extension_fields.py +++ b/api/fields/api_based_extension_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index abb27fdad17d63..73800eab853cd3 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.workflow_fields import workflow_partial_fields from libs.helper import AppIconUrlField, TimestampField diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 6a9e347b1e04b4..c54554a6de8405 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.member_fields import simple_account_fields from libs.helper import TimestampField diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index 983e50e73ceb9f..c6385efb5a3cf1 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/data_source_fields.py b/api/fields/data_source_fields.py index 071071376fe6c8..608672121e2b50 100644 --- a/api/fields/data_source_fields.py +++ b/api/fields/data_source_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 533e3a0837b815..a74e6f54fb3858 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py index a83ec7bc97adee..2b2ac6243f4da5 100644 --- a/api/fields/document_fields.py +++ b/api/fields/document_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.dataset_fields import dataset_fields from libs.helper import TimestampField diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py index 99e529f9d1c076..aefa0b27580ca7 100644 --- a/api/fields/end_user_fields.py +++ b/api/fields/end_user_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields 
+from flask_restful import fields # type: ignore simple_end_user_fields = { "id": fields.String, diff --git a/api/fields/external_dataset_fields.py b/api/fields/external_dataset_fields.py index 2281460fe22146..9cc4e14a0575d7 100644 --- a/api/fields/external_dataset_fields.py +++ b/api/fields/external_dataset_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index afaacc0568ea0c..f896c15f0fec70 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index f36e80f8d493d5..aaafcab8ab6ba0 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py index e0b3e340f67b8c..16f265b9bb6d07 100644 --- a/api/fields/installed_app_fields.py +++ b/api/fields/installed_app_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import AppIconUrlField, TimestampField diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 1cf8e408d13d32..0c854c640c3f98 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 5f6e7884a69c5e..0571faab08c134 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.conversation_fields import message_file_fields from libs.helper import TimestampField diff --git a/api/fields/raws.py b/api/fields/raws.py index 15ec16ab13e4a8..493d4b6cce7d31 100644 --- a/api/fields/raws.py +++ b/api/fields/raws.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from core.file import File diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py index 2dd4cb45be409b..4413af31607897 100644 --- a/api/fields/segment_fields.py +++ b/api/fields/segment_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py index 9af4fc57dd061c..986cd725f70910 100644 --- a/api/fields/tag_fields.py +++ b/api/fields/tag_fields.py @@ -1,3 +1,3 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String, "binding_count": fields.String} diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index a53b54624915c2..c45b33597b3978 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.end_user_fields import simple_end_user_fields from 
fields.member_fields import simple_account_fields diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 0d860d6f406502..bd093d4063bc2e 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from core.helper import encrypter from core.variables import SecretVariable, SegmentType, Variable diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 74fdf8bd97b23a..ef59c57ec37957 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields diff --git a/api/libs/external_api.py b/api/libs/external_api.py index 179617ac0a6588..922d2d9cd33324 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -1,8 +1,9 @@ import re import sys +from typing import Any from flask import current_app, got_request_exception -from flask_restful import Api, http_status_message +from flask_restful import Api, http_status_message # type: ignore from werkzeug.datastructures import Headers from werkzeug.exceptions import HTTPException @@ -84,7 +85,7 @@ def handle_error(self, e): # record the exception in the logs when we have a server error of status code: 500 if status_code and status_code >= 500: - exc_info = sys.exc_info() + exc_info: Any = sys.exc_info() if exc_info[1] is None: exc_info = None current_app.log_exception(exc_info) @@ -100,7 +101,7 @@ def handle_error(self, e): resp = self.make_response(data, status_code, headers, fallback_mediatype=fallback_mediatype) elif status_code == 400: if isinstance(data.get("message"), dict): - param_key, param_value = list(data.get("message").items())[0] + param_key, param_value = list(data.get("message", {}).items())[0] data = {"code": "invalid_param", "message": param_value, "params": param_key} else: if "code" not in data: diff --git a/api/libs/gmpy2_pkcs10aep_cipher.py b/api/libs/gmpy2_pkcs10aep_cipher.py index 83f9c74e339e17..2dae87e1710bf6 100644 --- a/api/libs/gmpy2_pkcs10aep_cipher.py +++ b/api/libs/gmpy2_pkcs10aep_cipher.py @@ -23,7 +23,7 @@ import Crypto.Hash.SHA1 import Crypto.Util.number -import gmpy2 +import gmpy2 # type: ignore from Crypto import Random from Crypto.Signature.pss import MGF1 from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes @@ -191,12 +191,12 @@ def decrypt(self, ciphertext): # Step 3g one_pos = hLen + db[hLen:].find(b"\x01") lHash1 = db[:hLen] - invalid = bord(y) | int(one_pos < hLen) + invalid = bord(y) | int(one_pos < hLen) # type: ignore hash_compare = strxor(lHash1, lHash) for x in hash_compare: - invalid |= bord(x) + invalid |= bord(x) # type: ignore for x in db[hLen:one_pos]: - invalid |= bord(x) + invalid |= bord(x) # type: ignore if invalid != 0: raise ValueError("Incorrect decryption.") # Step 4 diff --git a/api/libs/helper.py b/api/libs/helper.py index 91b1d1fe173d6f..eaa4efdb714355 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -13,7 +13,7 @@ from zoneinfo import available_timezones from flask import Response, stream_with_context -from flask_restful import fields +from flask_restful import fields # type: ignore from configs import dify_config from core.app.features.rate_limiting.rate_limit import RateLimitGenerator @@ -248,13 +248,13 @@ def get_token_data(cls, token: str, 
token_type: str) -> Optional[dict[str, Any]] if token_data_json is None: logging.warning(f"{token_type} token {token} not found with key {key}") return None - token_data = json.loads(token_data_json) + token_data: Optional[dict[str, Any]] = json.loads(token_data_json) return token_data @classmethod def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Optional[str]: key = cls._get_account_token_key(account_id, token_type) - current_token = redis_client.get(key) + current_token: Optional[str] = redis_client.get(key) return current_token @classmethod diff --git a/api/libs/json_in_md_parser.py b/api/libs/json_in_md_parser.py index 267af611f5e8cb..9ab53b6294db93 100644 --- a/api/libs/json_in_md_parser.py +++ b/api/libs/json_in_md_parser.py @@ -10,6 +10,7 @@ def parse_json_markdown(json_string: str) -> dict: ends = ["```", "``", "`", "}"] end_index = -1 start_index = 0 + parsed: dict = {} for s in starts: start_index = json_string.find(s) if start_index != -1: diff --git a/api/libs/login.py b/api/libs/login.py index 0ea191a185785d..5395534a6df93a 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -1,8 +1,9 @@ from functools import wraps +from typing import Any from flask import current_app, g, has_request_context, request -from flask_login import user_logged_in -from flask_login.config import EXEMPT_METHODS +from flask_login import user_logged_in # type: ignore +from flask_login.config import EXEMPT_METHODS # type: ignore from werkzeug.exceptions import Unauthorized from werkzeug.local import LocalProxy @@ -12,7 +13,7 @@ #: A proxy for the current user. If no user is logged in, this will be an #: anonymous user -current_user = LocalProxy(lambda: _get_user()) +current_user: Any = LocalProxy(lambda: _get_user()) def login_required(func): @@ -79,12 +80,12 @@ def decorated_view(*args, **kwargs): # Login admin if account: account.current_tenant = tenant - current_app.login_manager._update_request_context_with_user(account) - user_logged_in.send(current_app._get_current_object(), user=_get_user()) + current_app.login_manager._update_request_context_with_user(account) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: pass elif not current_user.is_authenticated: - return current_app.login_manager.unauthorized() + return current_app.login_manager.unauthorized() # type: ignore # flask 1.x compatibility # current_app.ensure_sync is only available in Flask >= 2.0 @@ -98,7 +99,7 @@ def decorated_view(*args, **kwargs): def _get_user(): if has_request_context(): if "_login_user" not in g: - current_app.login_manager._load_user() + current_app.login_manager._load_user() # type: ignore return g._login_user diff --git a/api/libs/oauth.py b/api/libs/oauth.py index 6b6919de24f90f..df75b550195298 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -77,9 +77,9 @@ def get_raw_user_info(self, token: str): email_response = requests.get(self._EMAIL_INFO_URL, headers=headers) email_info = email_response.json() - primary_email = next((email for email in email_info if email["primary"] == True), None) + primary_email: dict = next((email for email in email_info if email["primary"] == True), {}) - return {**user_info, "email": primary_email["email"]} + return {**user_info, "email": primary_email.get("email", "")} def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: email = raw_info.get("email") @@ -130,4 +130,4 @@ def get_raw_user_info(self, token: str): 
return response.json() def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: - return OAuthUserInfo(id=str(raw_info["sub"]), name=None, email=raw_info["email"]) + return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"]) diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 1d39abd8fa7886..0c872a0066d127 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -1,8 +1,9 @@ import datetime import urllib.parse +from typing import Any import requests -from flask_login import current_user +from flask_login import current_user # type: ignore from extensions.ext_database import db from models.source import DataSourceOauthBinding @@ -226,7 +227,7 @@ def notion_page_search(self, access_token: str): has_more = True while has_more: - data = { + data: dict[str, Any] = { "filter": {"value": "page", "property": "object"}, **({"start_cursor": next_cursor} if next_cursor else {}), } @@ -281,7 +282,7 @@ def notion_database_search(self, access_token: str): has_more = True while has_more: - data = { + data: dict[str, Any] = { "filter": {"value": "database", "property": "object"}, **({"start_cursor": next_cursor} if next_cursor else {}), } diff --git a/api/libs/threadings_utils.py b/api/libs/threadings_utils.py index d356def418ab1d..e4d63fd3142ce2 100644 --- a/api/libs/threadings_utils.py +++ b/api/libs/threadings_utils.py @@ -9,8 +9,8 @@ def apply_gevent_threading_patch(): :return: """ if not dify_config.DEBUG: - from gevent import monkey - from grpc.experimental import gevent as grpc_gevent + from gevent import monkey # type: ignore + from grpc.experimental import gevent as grpc_gevent # type: ignore # gevent monkey.patch_all() diff --git a/api/models/account.py b/api/models/account.py index a8602d10a97308..88c96da1a149d5 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -1,7 +1,7 @@ import enum import json -from flask_login import UserMixin +from flask_login import UserMixin # type: ignore from sqlalchemy import func from .engine import db @@ -16,7 +16,7 @@ class AccountStatus(enum.StrEnum): CLOSED = "closed" -class Account(UserMixin, db.Model): +class Account(UserMixin, db.Model): # type: ignore[name-defined] __tablename__ = "accounts" __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) @@ -43,7 +43,8 @@ def is_password_set(self): @property def current_tenant(self): - return self._current_tenant + # FIXME: fix the type error later; the type is important and ignoring it may cause bugs + return self._current_tenant # type: ignore @current_tenant.setter def current_tenant(self, value: "Tenant"): @@ -52,7 +53,8 @@ def current_tenant(self, value: "Tenant"): if ta: tenant.current_role = ta.role else: - tenant = None + # FIXME: fix the type error later; the type is important and ignoring it may cause bugs + tenant = None # type: ignore self._current_tenant = tenant @property @@ -89,7 +91,7 @@ def get_status(self) -> AccountStatus: return AccountStatus(status_str) @classmethod - def get_by_openid(cls, provider: str, open_id: str) -> db.Model: + def get_by_openid(cls, provider: str, open_id: str): account_integrate = ( db.session.query(AccountIntegrate) .filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id) @@ -134,7 +136,7 @@ class TenantAccountRole(enum.StrEnum): @staticmethod def is_valid_role(role: str) -> bool: - return role and role in { + return role in { TenantAccountRole.OWNER, TenantAccountRole.ADMIN,
TenantAccountRole.EDITOR, @@ -144,15 +146,15 @@ def is_valid_role(role: str) -> bool: @staticmethod def is_privileged_role(role: str) -> bool: - return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN} + return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN} @staticmethod def is_admin_role(role: str) -> bool: - return role and role == TenantAccountRole.ADMIN + return role == TenantAccountRole.ADMIN @staticmethod def is_non_owner_role(role: str) -> bool: - return role and role in { + return role in { TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, TenantAccountRole.NORMAL, @@ -161,11 +163,11 @@ def is_non_owner_role(role: str) -> bool: @staticmethod def is_editing_role(role: str) -> bool: - return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR} + return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR} @staticmethod def is_dataset_edit_role(role: str) -> bool: - return role and role in { + return role in { TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, @@ -173,7 +175,7 @@ def is_dataset_edit_role(role: str) -> bool: } -class Tenant(db.Model): +class Tenant(db.Model): # type: ignore[name-defined] __tablename__ = "tenants" __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) @@ -209,7 +211,7 @@ class TenantAccountJoinRole(enum.Enum): DATASET_OPERATOR = "dataset_operator" -class TenantAccountJoin(db.Model): +class TenantAccountJoin(db.Model): # type: ignore[name-defined] __tablename__ = "tenant_account_joins" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), @@ -228,7 +230,7 @@ class TenantAccountJoin(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class AccountIntegrate(db.Model): +class AccountIntegrate(db.Model): # type: ignore[name-defined] __tablename__ = "account_integrates" __table_args__ = ( db.PrimaryKeyConstraint("id", name="account_integrate_pkey"), @@ -245,7 +247,7 @@ class AccountIntegrate(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class InvitationCode(db.Model): +class InvitationCode(db.Model): # type: ignore[name-defined] __tablename__ = "invitation_codes" __table_args__ = ( db.PrimaryKeyConstraint("id", name="invitation_code_pkey"), diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index fbffe7a3b2ee9d..6b6d808710afc0 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -13,7 +13,7 @@ class APIBasedExtensionPoint(enum.Enum): APP_MODERATION_OUTPUT = "app.moderation.output" -class APIBasedExtension(db.Model): +class APIBasedExtension(db.Model): # type: ignore[name-defined] __tablename__ = "api_based_extensions" __table_args__ = ( db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"), diff --git a/api/models/dataset.py b/api/models/dataset.py index 7279e8d5b3394a..b9b41dcf475bb1 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -9,6 +9,7 @@ import re import time from json import JSONDecodeError +from typing import Any, cast from sqlalchemy import func from sqlalchemy.dialects.postgresql import JSONB @@ -29,7 +30,7 @@ class DatasetPermissionEnum(enum.StrEnum): PARTIAL_TEAM = "partial_members" -class Dataset(db.Model): +class Dataset(db.Model): # type: ignore[name-defined] __tablename__ = "datasets" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_pkey"), @@ -200,7 
+201,7 @@ def gen_collection_name_by_id(dataset_id: str) -> str: return f"Vector_index_{normalized_dataset_id}_Node" -class DatasetProcessRule(db.Model): +class DatasetProcessRule(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_process_rules" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"), @@ -216,7 +217,7 @@ class DatasetProcessRule(db.Model): MODES = ["automatic", "custom"] PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] - AUTOMATIC_RULES = { + AUTOMATIC_RULES: dict[str, Any] = { "pre_processing_rules": [ {"id": "remove_extra_spaces", "enabled": True}, {"id": "remove_urls_emails", "enabled": False}, @@ -242,7 +243,7 @@ def rules_dict(self): return None -class Document(db.Model): +class Document(db.Model): # type: ignore[name-defined] __tablename__ = "documents" __table_args__ = ( db.PrimaryKeyConstraint("id", name="document_pkey"), @@ -492,7 +493,7 @@ def from_dict(cls, data: dict): ) -class DocumentSegment(db.Model): +class DocumentSegment(db.Model): # type: ignore[name-defined] __tablename__ = "document_segments" __table_args__ = ( db.PrimaryKeyConstraint("id", name="document_segment_pkey"), @@ -604,7 +605,7 @@ def get_sign_content(self): return text -class AppDatasetJoin(db.Model): +class AppDatasetJoin(db.Model): # type: ignore[name-defined] __tablename__ = "app_dataset_joins" __table_args__ = ( db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"), @@ -621,7 +622,7 @@ def app(self): return db.session.get(App, self.app_id) -class DatasetQuery(db.Model): +class DatasetQuery(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_queries" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_query_pkey"), @@ -638,7 +639,7 @@ class DatasetQuery(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) -class DatasetKeywordTable(db.Model): +class DatasetKeywordTable(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_keyword_tables" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"), @@ -683,7 +684,7 @@ def object_hook(self, dct): return None -class Embedding(db.Model): +class Embedding(db.Model): # type: ignore[name-defined] __tablename__ = "embeddings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="embedding_pkey"), @@ -704,10 +705,10 @@ def set_embedding(self, embedding_data: list[float]): self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL) def get_embedding(self) -> list[float]: - return pickle.loads(self.embedding) + return cast(list[float], pickle.loads(self.embedding)) -class DatasetCollectionBinding(db.Model): +class DatasetCollectionBinding(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_collection_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"), @@ -722,7 +723,7 @@ class DatasetCollectionBinding(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class TidbAuthBinding(db.Model): +class TidbAuthBinding(db.Model): # type: ignore[name-defined] __tablename__ = "tidb_auth_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"), @@ -742,7 +743,7 @@ class TidbAuthBinding(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class Whitelist(db.Model): +class Whitelist(db.Model): # type: ignore[name-defined] __tablename__ = "whitelists" 
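All of the model classes above need "# type: ignore[name-defined]" because Flask-SQLAlchemy builds db.Model at runtime, where mypy cannot see it. The Mapped[...] column added to App further down shows the SQLAlchemy 2.0 style that sidesteps the problem; a sketch of that style on a hypothetical model:

    from typing import Optional

    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


    class Base(DeclarativeBase):
        """A statically visible base, unlike the runtime-built db.Model."""


    class Example(Base):
        __tablename__ = "examples"

        id: Mapped[int] = mapped_column(primary_key=True)
        # mypy reads this as Optional[int]; no ignore pragma needed
        max_active_requests: Mapped[Optional[int]] = mapped_column(nullable=True)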
__table_args__ = ( db.PrimaryKeyConstraint("id", name="whitelists_pkey"), @@ -754,7 +755,7 @@ class Whitelist(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class DatasetPermission(db.Model): +class DatasetPermission(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_permissions" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"), @@ -771,7 +772,7 @@ class DatasetPermission(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class ExternalKnowledgeApis(db.Model): +class ExternalKnowledgeApis(db.Model): # type: ignore[name-defined] __tablename__ = "external_knowledge_apis" __table_args__ = ( db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"), @@ -824,7 +825,7 @@ def dataset_bindings(self): return dataset_bindings -class ExternalKnowledgeBindings(db.Model): +class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined] __tablename__ = "external_knowledge_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"), diff --git a/api/models/model.py b/api/models/model.py index 1417298c79c0a2..2a593f08298199 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -4,11 +4,11 @@ from collections.abc import Mapping from datetime import datetime from enum import Enum, StrEnum -from typing import TYPE_CHECKING, Any, Literal, Optional +from typing import TYPE_CHECKING, Any, Literal, Optional, cast import sqlalchemy as sa from flask import request -from flask_login import UserMixin +from flask_login import UserMixin # type: ignore from sqlalchemy import Float, func, text from sqlalchemy.orm import Mapped, mapped_column @@ -28,7 +28,7 @@ from .workflow import Workflow -class DifySetup(db.Model): +class DifySetup(db.Model): # type: ignore[name-defined] __tablename__ = "dify_setups" __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) @@ -63,7 +63,7 @@ class IconType(Enum): EMOJI = "emoji" -class App(db.Model): +class App(db.Model): # type: ignore[name-defined] __tablename__ = "apps" __table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id")) @@ -86,7 +86,7 @@ class App(db.Model): is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) tracing = db.Column(db.Text, nullable=True) - max_active_requests = db.Column(db.Integer, nullable=True) + max_active_requests: Mapped[Optional[int]] = mapped_column(nullable=True) created_by = db.Column(StringUUID, nullable=True) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) @@ -154,7 +154,7 @@ def mode_compatible_with_agent(self) -> str: if self.mode == AppMode.CHAT.value and self.is_agent: return AppMode.AGENT_CHAT.value - return self.mode + return str(self.mode) @property def deleted_tools(self) -> list: @@ -219,7 +219,7 @@ def tags(self): return tags or [] -class AppModelConfig(db.Model): +class AppModelConfig(db.Model): # type: ignore[name-defined] __tablename__ = "app_model_configs" __table_args__ = (db.PrimaryKeyConstraint("id", name="app_model_config_pkey"), db.Index("app_app_id_idx", "app_id")) @@ -322,7 +322,7 @@ def external_data_tools_list(self) -> list[dict]: return json.loads(self.external_data_tools) if self.external_data_tools else [] @property - def 
user_input_form_list(self) -> dict: + def user_input_form_list(self) -> list[dict]: return json.loads(self.user_input_form) if self.user_input_form else [] @property @@ -344,7 +344,7 @@ def completion_prompt_config_dict(self) -> dict: @property def dataset_configs_dict(self) -> dict: if self.dataset_configs: - dataset_configs = json.loads(self.dataset_configs) + dataset_configs: dict = json.loads(self.dataset_configs) if "retrieval_model" not in dataset_configs: return {"retrieval_model": "single"} else: @@ -466,7 +466,7 @@ def copy(self): return new_app_model_config -class RecommendedApp(db.Model): +class RecommendedApp(db.Model): # type: ignore[name-defined] __tablename__ = "recommended_apps" __table_args__ = ( db.PrimaryKeyConstraint("id", name="recommended_app_pkey"), @@ -494,7 +494,7 @@ def app(self): return app -class InstalledApp(db.Model): +class InstalledApp(db.Model): # type: ignore[name-defined] __tablename__ = "installed_apps" __table_args__ = ( db.PrimaryKeyConstraint("id", name="installed_app_pkey"), @@ -523,7 +523,7 @@ def tenant(self): return tenant -class Conversation(db.Model): +class Conversation(db.Model): # type: ignore[name-defined] __tablename__ = "conversations" __table_args__ = ( db.PrimaryKeyConstraint("id", name="conversation_pkey"), @@ -602,6 +602,8 @@ def inputs(self, value: Mapping[str, Any]): @property def model_config(self): model_config = {} + app_model_config: Optional[AppModelConfig] = None + if self.mode == AppMode.ADVANCED_CHAT.value: if self.override_model_configs: override_model_configs = json.loads(self.override_model_configs) @@ -613,6 +615,7 @@ def model_config(self): if "model" in override_model_configs: app_model_config = AppModelConfig() app_model_config = app_model_config.from_model_config_dict(override_model_configs) + assert app_model_config is not None, "app model config not found" model_config = app_model_config.to_dict() else: model_config["configs"] = override_model_configs @@ -755,7 +758,7 @@ def in_debug_mode(self): return self.override_model_configs is not None -class Message(db.Model): +class Message(db.Model): # type: ignore[name-defined] __tablename__ = "messages" __table_args__ = ( db.PrimaryKeyConstraint("id", name="message_pkey"), @@ -995,7 +998,7 @@ def message_files(self): if not current_app: raise ValueError(f"App {self.app_id} not found") - files: list[File] = [] + files = [] for message_file in message_files: if message_file.transfer_method == "local_file": if message_file.upload_file_id is None: @@ -1102,7 +1105,7 @@ def from_dict(cls, data: dict): ) -class MessageFeedback(db.Model): +class MessageFeedback(db.Model): # type: ignore[name-defined] __tablename__ = "message_feedbacks" __table_args__ = ( db.PrimaryKeyConstraint("id", name="message_feedback_pkey"), @@ -1129,7 +1132,7 @@ def from_account(self): return account -class MessageFile(db.Model): +class MessageFile(db.Model): # type: ignore[name-defined] __tablename__ = "message_files" __table_args__ = ( db.PrimaryKeyConstraint("id", name="message_file_pkey"), @@ -1170,7 +1173,7 @@ def __init__( created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class MessageAnnotation(db.Model): +class MessageAnnotation(db.Model): # type: ignore[name-defined] __tablename__ = "message_annotations" __table_args__ = ( db.PrimaryKeyConstraint("id", name="message_annotation_pkey"), @@ -1201,7 +1204,7 @@ def annotation_create_account(self): return account -class AppAnnotationHitHistory(db.Model): +class 
AppAnnotationHitHistory(db.Model): # type: ignore[name-defined] __tablename__ = "app_annotation_hit_histories" __table_args__ = ( db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"), @@ -1239,7 +1242,7 @@ def annotation_create_account(self): return account -class AppAnnotationSetting(db.Model): +class AppAnnotationSetting(db.Model): # type: ignore[name-defined] __tablename__ = "app_annotation_settings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"), @@ -1287,7 +1290,7 @@ def collection_binding_detail(self): return collection_binding_detail -class OperationLog(db.Model): +class OperationLog(db.Model): # type: ignore[name-defined] __tablename__ = "operation_logs" __table_args__ = ( db.PrimaryKeyConstraint("id", name="operation_log_pkey"), @@ -1304,7 +1307,7 @@ class OperationLog(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class EndUser(UserMixin, db.Model): +class EndUser(UserMixin, db.Model): # type: ignore[name-defined] __tablename__ = "end_users" __table_args__ = ( db.PrimaryKeyConstraint("id", name="end_user_pkey"), @@ -1324,7 +1327,7 @@ class EndUser(UserMixin, db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class Site(db.Model): +class Site(db.Model): # type: ignore[name-defined] __tablename__ = "sites" __table_args__ = ( db.PrimaryKeyConstraint("id", name="site_pkey"), @@ -1381,7 +1384,7 @@ def app_base_url(self): return dify_config.APP_WEB_URL or request.url_root.rstrip("/") -class ApiToken(db.Model): +class ApiToken(db.Model): # type: ignore[name-defined] __tablename__ = "api_tokens" __table_args__ = ( db.PrimaryKeyConstraint("id", name="api_token_pkey"), @@ -1408,7 +1411,7 @@ def generate_api_key(prefix, n): return result -class UploadFile(db.Model): +class UploadFile(db.Model): # type: ignore[name-defined] __tablename__ = "upload_files" __table_args__ = ( db.PrimaryKeyConstraint("id", name="upload_file_pkey"), @@ -1470,7 +1473,7 @@ def __init__( self.source_url = source_url -class ApiRequest(db.Model): +class ApiRequest(db.Model): # type: ignore[name-defined] __tablename__ = "api_requests" __table_args__ = ( db.PrimaryKeyConstraint("id", name="api_request_pkey"), @@ -1487,7 +1490,7 @@ class ApiRequest(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class MessageChain(db.Model): +class MessageChain(db.Model): # type: ignore[name-defined] __tablename__ = "message_chains" __table_args__ = ( db.PrimaryKeyConstraint("id", name="message_chain_pkey"), @@ -1502,7 +1505,7 @@ class MessageChain(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) -class MessageAgentThought(db.Model): +class MessageAgentThought(db.Model): # type: ignore[name-defined] __tablename__ = "message_agent_thoughts" __table_args__ = ( db.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"), @@ -1542,7 +1545,7 @@ class MessageAgentThought(db.Model): @property def files(self) -> list: if self.message_files: - return json.loads(self.message_files) + return cast(list[Any], json.loads(self.message_files)) else: return [] @@ -1554,7 +1557,7 @@ def tools(self) -> list[str]: def tool_labels(self) -> dict: try: if self.tool_labels_str: - return json.loads(self.tool_labels_str) + return cast(dict, json.loads(self.tool_labels_str)) else: return {} except Exception as e: @@ -1564,7 +1567,7 @@ def tool_labels(self) -> dict: def 
tool_meta(self) -> dict: try: if self.tool_meta_str: - return json.loads(self.tool_meta_str) + return cast(dict, json.loads(self.tool_meta_str)) else: return {} except Exception as e: @@ -1612,9 +1615,11 @@ def tool_outputs_dict(self) -> dict: except Exception as e: if self.observation: return dict.fromkeys(tools, self.observation) + else: + return {} -class DatasetRetrieverResource(db.Model): +class DatasetRetrieverResource(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_retriever_resources" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"), @@ -1641,7 +1646,7 @@ class DatasetRetrieverResource(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) -class Tag(db.Model): +class Tag(db.Model): # type: ignore[name-defined] __tablename__ = "tags" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tag_pkey"), @@ -1659,7 +1664,7 @@ class Tag(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class TagBinding(db.Model): +class TagBinding(db.Model): # type: ignore[name-defined] __tablename__ = "tag_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tag_binding_pkey"), @@ -1675,7 +1680,7 @@ class TagBinding(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class TraceAppConfig(db.Model): +class TraceAppConfig(db.Model): # type: ignore[name-defined] __tablename__ = "trace_app_config" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"), diff --git a/api/models/provider.py b/api/models/provider.py index fdd3e802d79211..abe673975c1ccc 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -36,7 +36,7 @@ def value_of(value): raise ValueError(f"No matching enum found for value '{value}'") -class Provider(db.Model): +class Provider(db.Model): # type: ignore[name-defined] """ Provider model representing the API providers and their configurations. """ @@ -89,7 +89,7 @@ def is_enabled(self): return self.is_valid and self.token_is_set -class ProviderModel(db.Model): +class ProviderModel(db.Model): # type: ignore[name-defined] """ Provider model representing the API provider_models and their configurations. 
""" @@ -114,7 +114,7 @@ class ProviderModel(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class TenantDefaultModel(db.Model): +class TenantDefaultModel(db.Model): # type: ignore[name-defined] __tablename__ = "tenant_default_models" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"), @@ -130,7 +130,7 @@ class TenantDefaultModel(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class TenantPreferredModelProvider(db.Model): +class TenantPreferredModelProvider(db.Model): # type: ignore[name-defined] __tablename__ = "tenant_preferred_model_providers" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"), @@ -145,7 +145,7 @@ class TenantPreferredModelProvider(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class ProviderOrder(db.Model): +class ProviderOrder(db.Model): # type: ignore[name-defined] __tablename__ = "provider_orders" __table_args__ = ( db.PrimaryKeyConstraint("id", name="provider_order_pkey"), @@ -170,7 +170,7 @@ class ProviderOrder(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class ProviderModelSetting(db.Model): +class ProviderModelSetting(db.Model): # type: ignore[name-defined] """ Provider model settings for record the model enabled status and load balancing status. """ @@ -192,7 +192,7 @@ class ProviderModelSetting(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class LoadBalancingModelConfig(db.Model): +class LoadBalancingModelConfig(db.Model): # type: ignore[name-defined] """ Configurations for load balancing models. 
""" diff --git a/api/models/source.py b/api/models/source.py index 114db8e1100e5d..881cfaac7d3998 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -7,7 +7,7 @@ from .types import StringUUID -class DataSourceOauthBinding(db.Model): +class DataSourceOauthBinding(db.Model): # type: ignore[name-defined] __tablename__ = "data_source_oauth_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="source_binding_pkey"), @@ -25,7 +25,7 @@ class DataSourceOauthBinding(db.Model): disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) -class DataSourceApiKeyAuthBinding(db.Model): +class DataSourceApiKeyAuthBinding(db.Model): # type: ignore[name-defined] __tablename__ = "data_source_api_key_auth_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"), diff --git a/api/models/task.py b/api/models/task.py index 27571e24746fe7..0db1c632299fcb 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -1,11 +1,11 @@ from datetime import UTC, datetime -from celery import states +from celery import states # type: ignore from .engine import db -class CeleryTask(db.Model): +class CeleryTask(db.Model): # type: ignore[name-defined] """Task result/status.""" __tablename__ = "celery_taskmeta" @@ -29,7 +29,7 @@ class CeleryTask(db.Model): queue = db.Column(db.String(155), nullable=True) -class CeleryTaskSet(db.Model): +class CeleryTaskSet(db.Model): # type: ignore[name-defined] """TaskSet result.""" __tablename__ = "celery_tasksetmeta" diff --git a/api/models/tools.py b/api/models/tools.py index e90ab669c66f1e..4151a2e9f636a0 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -14,7 +14,7 @@ from .types import StringUUID -class BuiltinToolProvider(db.Model): +class BuiltinToolProvider(db.Model): # type: ignore[name-defined] """ This table stores the tool provider information for built-in tools for each tenant. """ @@ -41,10 +41,10 @@ class BuiltinToolProvider(db.Model): @property def credentials(self) -> dict: - return json.loads(self.encrypted_credentials) + return dict(json.loads(self.encrypted_credentials)) -class PublishedAppTool(db.Model): +class PublishedAppTool(db.Model): # type: ignore[name-defined] """ The table stores the apps published as a tool for each person. """ @@ -86,7 +86,7 @@ def app(self): return db.session.query(App).filter(App.id == self.app_id).first() -class ApiToolProvider(db.Model): +class ApiToolProvider(db.Model): # type: ignore[name-defined] """ The table stores the api providers. """ @@ -133,7 +133,7 @@ def tools(self) -> list[ApiToolBundle]: @property def credentials(self) -> dict: - return json.loads(self.credentials_str) + return dict(json.loads(self.credentials_str)) @property def user(self) -> Account | None: @@ -144,7 +144,7 @@ def tenant(self) -> Tenant | None: return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() -class ToolLabelBinding(db.Model): +class ToolLabelBinding(db.Model): # type: ignore[name-defined] """ The table stores the labels for tools. """ @@ -164,7 +164,7 @@ class ToolLabelBinding(db.Model): label_name = db.Column(db.String(40), nullable=False) -class WorkflowToolProvider(db.Model): +class WorkflowToolProvider(db.Model): # type: ignore[name-defined] """ The table stores the workflow providers. 
""" @@ -218,7 +218,7 @@ def app(self) -> App | None: return db.session.query(App).filter(App.id == self.app_id).first() -class ToolModelInvoke(db.Model): +class ToolModelInvoke(db.Model): # type: ignore[name-defined] """ store the invoke logs from tool invoke """ @@ -255,7 +255,7 @@ class ToolModelInvoke(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class ToolConversationVariables(db.Model): +class ToolConversationVariables(db.Model): # type: ignore[name-defined] """ store the conversation variables from tool invoke """ @@ -283,10 +283,10 @@ class ToolConversationVariables(db.Model): @property def variables(self) -> dict: - return json.loads(self.variables_str) + return dict(json.loads(self.variables_str)) -class ToolFile(db.Model): +class ToolFile(db.Model): # type: ignore[name-defined] __tablename__ = "tool_files" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tool_file_pkey"), diff --git a/api/models/web.py b/api/models/web.py index 028a768519d99a..864428fe0931b6 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -6,7 +6,7 @@ from .types import StringUUID -class SavedMessage(db.Model): +class SavedMessage(db.Model): # type: ignore[name-defined] __tablename__ = "saved_messages" __table_args__ = ( db.PrimaryKeyConstraint("id", name="saved_message_pkey"), @@ -25,7 +25,7 @@ def message(self): return db.session.query(Message).filter(Message.id == self.message_id).first() -class PinnedConversation(db.Model): +class PinnedConversation(db.Model): # type: ignore[name-defined] __tablename__ = "pinned_conversations" __table_args__ = ( db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"), diff --git a/api/models/workflow.py b/api/models/workflow.py index d5be949bf44f2a..880e044d073a67 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence from datetime import UTC, datetime from enum import Enum, StrEnum -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import sqlalchemy as sa from sqlalchemy import func @@ -20,6 +20,9 @@ from .engine import db from .types import StringUUID +if TYPE_CHECKING: + from models.model import AppMode, Message + class WorkflowType(Enum): """ @@ -56,7 +59,7 @@ def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType": return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT -class Workflow(db.Model): +class Workflow(db.Model): # type: ignore[name-defined] """ Workflow, for `Workflow App` and `Chat App workflow mode`. 
@@ -182,7 +185,7 @@ def features(self, value: str) -> None: self._features = value @property - def features_dict(self) -> Mapping[str, Any]: + def features_dict(self) -> dict[str, Any]: return json.loads(self.features) if self.features else {} def user_input_form(self, to_old_structure: bool = False) -> list: @@ -199,7 +202,7 @@ def user_input_form(self, to_old_structure: bool = False) -> list: return [] # get user_input_form from start node - variables = start_node.get("data", {}).get("variables", []) + variables: list[Any] = start_node.get("data", {}).get("variables", []) if to_old_structure: old_structure_variables = [] @@ -344,7 +347,7 @@ def value_of(cls, value: str) -> "WorkflowRunStatus": raise ValueError(f"invalid workflow run status value {value}") -class WorkflowRun(db.Model): +class WorkflowRun(db.Model): # type: ignore[name-defined] """ Workflow Run @@ -546,7 +549,7 @@ def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus": raise ValueError(f"invalid workflow node execution status value {value}") -class WorkflowNodeExecution(db.Model): +class WorkflowNodeExecution(db.Model): # type: ignore[name-defined] """ Workflow Node Execution @@ -712,7 +715,7 @@ def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom": raise ValueError(f"invalid workflow app log created from value {value}") -class WorkflowAppLog(db.Model): +class WorkflowAppLog(db.Model): # type: ignore[name-defined] """ Workflow App execution log, excluding workflow debugging records. @@ -774,7 +777,7 @@ def created_by_end_user(self): return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None -class ConversationVariable(db.Model): +class ConversationVariable(db.Model): # type: ignore[name-defined] __tablename__ = "workflow_conversation_variables" id: Mapped[str] = db.Column(StringUUID, primary_key=True) diff --git a/api/mypy.ini b/api/mypy.ini new file mode 100644 index 00000000000000..2c754f9fcd7c63 --- /dev/null +++ b/api/mypy.ini @@ -0,0 +1,10 @@ +[mypy] +warn_return_any = True +warn_unused_configs = True +check_untyped_defs = True +exclude = (?x)( + core/tools/provider/builtin/ + | core/model_runtime/model_providers/ + | tests/ + | migrations/ + ) \ No newline at end of file diff --git a/api/poetry.lock b/api/poetry.lock index 35fda9b36fa42a..b42eb22dd40b8a 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -5643,6 +5643,58 @@ files = [ {file = "multitasking-0.0.11.tar.gz", hash = "sha256:4d6bc3cc65f9b2dca72fb5a787850a88dae8f620c2b36ae9b55248e51bcd6026"}, ] +[[package]] +name = "mypy" +version = "1.13.0" +description = "Optional static typing for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "mypy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6607e0f1dd1fb7f0aca14d936d13fd19eba5e17e1cd2a14f808fa5f8f6d8f60a"}, + {file = "mypy-1.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8a21be69bd26fa81b1f80a61ee7ab05b076c674d9b18fb56239d72e21d9f4c80"}, + {file = "mypy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b2353a44d2179846a096e25691d54d59904559f4232519d420d64da6828a3a7"}, + {file = "mypy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0730d1c6a2739d4511dc4253f8274cdd140c55c32dfb0a4cf8b7a43f40abfa6f"}, + {file = "mypy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:c5fc54dbb712ff5e5a0fca797e6e0aa25726c7e72c6a5850cfd2adbc1eb0a372"}, + {file = "mypy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = 
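Editorial aside (not part of the patch): the new api/mypy.ini above enables `warn_return_any`, `warn_unused_configs`, and `check_untyped_defs`, and uses a verbose `(?x)` regex to exclude the builtin tool providers, the model-runtime providers, tests, and migrations. mypy matches `exclude` against file paths with `re.search`-style semantics, which can be sanity-checked in plain Python:

    # Minimal sketch (assumes paths are reported relative to api/): how the
    # (?x) exclude pattern in mypy.ini decides which files are skipped.
    import re

    EXCLUDE = re.compile(
        r"""(?x)(
            core/tools/provider/builtin/
          | core/model_runtime/model_providers/
          | tests/
          | migrations/
        )"""
    )

    print(bool(EXCLUDE.search("core/tools/provider/builtin/google/google.py")))  # True: skipped
    print(bool(EXCLUDE.search("models/dataset.py")))                             # False: checked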
"sha256:581665e6f3a8a9078f28d5502f4c334c0c8d802ef55ea0e7276a6e409bc0d82d"}, + {file = "mypy-1.13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3ddb5b9bf82e05cc9a627e84707b528e5c7caaa1c55c69e175abb15a761cec2d"}, + {file = "mypy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:20c7ee0bc0d5a9595c46f38beb04201f2620065a93755704e141fcac9f59db2b"}, + {file = "mypy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3790ded76f0b34bc9c8ba4def8f919dd6a46db0f5a6610fb994fe8efdd447f73"}, + {file = "mypy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:51f869f4b6b538229c1d1bcc1dd7d119817206e2bc54e8e374b3dfa202defcca"}, + {file = "mypy-1.13.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5c7051a3461ae84dfb5dd15eff5094640c61c5f22257c8b766794e6dd85e72d5"}, + {file = "mypy-1.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:39bb21c69a5d6342f4ce526e4584bc5c197fd20a60d14a8624d8743fffb9472e"}, + {file = "mypy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:164f28cb9d6367439031f4c81e84d3ccaa1e19232d9d05d37cb0bd880d3f93c2"}, + {file = "mypy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a4c1bfcdbce96ff5d96fc9b08e3831acb30dc44ab02671eca5953eadad07d6d0"}, + {file = "mypy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:a0affb3a79a256b4183ba09811e3577c5163ed06685e4d4b46429a271ba174d2"}, + {file = "mypy-1.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a7b44178c9760ce1a43f544e595d35ed61ac2c3de306599fa59b38a6048e1aa7"}, + {file = "mypy-1.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5d5092efb8516d08440e36626f0153b5006d4088c1d663d88bf79625af3d1d62"}, + {file = "mypy-1.13.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:de2904956dac40ced10931ac967ae63c5089bd498542194b436eb097a9f77bc8"}, + {file = "mypy-1.13.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:7bfd8836970d33c2105562650656b6846149374dc8ed77d98424b40b09340ba7"}, + {file = "mypy-1.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:9f73dba9ec77acb86457a8fc04b5239822df0c14a082564737833d2963677dbc"}, + {file = "mypy-1.13.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:100fac22ce82925f676a734af0db922ecfea991e1d7ec0ceb1e115ebe501301a"}, + {file = "mypy-1.13.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7bcb0bb7f42a978bb323a7c88f1081d1b5dee77ca86f4100735a6f541299d8fb"}, + {file = "mypy-1.13.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bde31fc887c213e223bbfc34328070996061b0833b0a4cfec53745ed61f3519b"}, + {file = "mypy-1.13.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:07de989f89786f62b937851295ed62e51774722e5444a27cecca993fc3f9cd74"}, + {file = "mypy-1.13.0-cp38-cp38-win_amd64.whl", hash = "sha256:4bde84334fbe19bad704b3f5b78c4abd35ff1026f8ba72b29de70dda0916beb6"}, + {file = "mypy-1.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0246bcb1b5de7f08f2826451abd947bf656945209b140d16ed317f65a17dc7dc"}, + {file = "mypy-1.13.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7f5b7deae912cf8b77e990b9280f170381fdfbddf61b4ef80927edd813163732"}, + {file = "mypy-1.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7029881ec6ffb8bc233a4fa364736789582c738217b133f1b55967115288a2bc"}, + {file = "mypy-1.13.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3e38b980e5681f28f033f3be86b099a247b13c491f14bb8b1e1e134d23bb599d"}, + 
{file = "mypy-1.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:a6789be98a2017c912ae6ccb77ea553bbaf13d27605d2ca20a76dfbced631b24"}, + {file = "mypy-1.13.0-py3-none-any.whl", hash = "sha256:9c250883f9fd81d212e0952c92dbfcc96fc237f4b7c92f56ac81fd48460b3e5a"}, + {file = "mypy-1.13.0.tar.gz", hash = "sha256:0291a61b6fbf3e6673e3405cfcc0e7650bebc7939659fdca2702958038bd835e"}, +] + +[package.dependencies] +mypy-extensions = ">=1.0.0" +typing-extensions = ">=4.6.0" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +faster-cache = ["orjson"] +install-types = ["pip"] +mypyc = ["setuptools (>=50)"] +reports = ["lxml"] + [[package]] name = "mypy-extensions" version = "1.0.0" @@ -6537,6 +6589,21 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.9.2)"] +[[package]] +name = "pandas-stubs" +version = "2.2.3.241126" +description = "Type annotations for pandas" +optional = false +python-versions = ">=3.10" +files = [ + {file = "pandas_stubs-2.2.3.241126-py3-none-any.whl", hash = "sha256:74aa79c167af374fe97068acc90776c0ebec5266a6e5c69fe11e9c2cf51f2267"}, + {file = "pandas_stubs-2.2.3.241126.tar.gz", hash = "sha256:cf819383c6d9ae7d4dabf34cd47e1e45525bb2f312e6ad2939c2c204cb708acd"}, +] + +[package.dependencies] +numpy = ">=1.23.5" +types-pytz = ">=2022.1.1" + [[package]] name = "pathos" version = "0.3.3" @@ -9255,13 +9322,13 @@ sqlcipher = ["sqlcipher3_binary"] [[package]] name = "sqlparse" -version = "0.5.2" +version = "0.5.3" description = "A non-validating SQL parser." optional = false python-versions = ">=3.8" files = [ - {file = "sqlparse-0.5.2-py3-none-any.whl", hash = "sha256:e99bc85c78160918c3e1d9230834ab8d80fc06c59d03f8db2618f65f65dda55e"}, - {file = "sqlparse-0.5.2.tar.gz", hash = "sha256:9e37b35e16d1cc652a2545f0997c1deb23ea28fa1f3eefe609eee3063c3b105f"}, + {file = "sqlparse-0.5.3-py3-none-any.whl", hash = "sha256:cf2196ed3418f3ba5de6af7e82c694a9fbdbfecccdfc72e281548517081f16ca"}, + {file = "sqlparse-0.5.3.tar.gz", hash = "sha256:09f67787f56a0b16ecdbde1bfc7f5d9c3371ca683cfeaa8e6ff60b4807ec9272"}, ] [package.extras] @@ -9847,6 +9914,17 @@ rich = ">=10.11.0" shellingham = ">=1.3.0" typing-extensions = ">=3.7.4.3" +[[package]] +name = "types-pytz" +version = "2024.2.0.20241003" +description = "Typing stubs for pytz" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-pytz-2024.2.0.20241003.tar.gz", hash = "sha256:575dc38f385a922a212bac00a7d6d2e16e141132a3c955078f4a4fd13ed6cb44"}, + {file = "types_pytz-2024.2.0.20241003-py3-none-any.whl", hash = "sha256:3e22df1336c0c6ad1d29163c8fda82736909eb977281cb823c57f8bae07118b7"}, +] + [[package]] name = "types-requests" version = "2.32.0.20241016" @@ -10313,82 +10391,82 @@ ark = ["anyio (>=3.5.0,<5)", "cached-property", "httpx (>=0.23.0,<1)", "pydantic [[package]] name = "watchfiles" -version = "1.0.0" +version = "1.0.3" description = "Simple, modern and high performance file watching and code reload in python." 
optional = false python-versions = ">=3.9" files = [ - {file = "watchfiles-1.0.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:1d19df28f99d6a81730658fbeb3ade8565ff687f95acb59665f11502b441be5f"}, - {file = "watchfiles-1.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:28babb38cf2da8e170b706c4b84aa7e4528a6fa4f3ee55d7a0866456a1662041"}, - {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:12ab123135b2f42517f04e720526d41448667ae8249e651385afb5cda31fedc0"}, - {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:13a4f9ee0cd25682679eea5c14fc629e2eaa79aab74d963bc4e21f43b8ea1877"}, - {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e1d9284cc84de7855fcf83472e51d32daf6f6cecd094160192628bc3fee1b78"}, - {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ee5edc939f53466b329bbf2e58333a5461e6c7b50c980fa6117439e2c18b42d"}, - {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5dccfc70480087567720e4e36ec381bba1ed68d7e5f368fe40c93b3b1eba0105"}, - {file = "watchfiles-1.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c83a6d33a9eda0af6a7470240d1af487807adc269704fe76a4972dd982d16236"}, - {file = "watchfiles-1.0.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:905f69aad276639eff3893759a07d44ea99560e67a1cf46ff389cd62f88872a2"}, - {file = "watchfiles-1.0.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:09551237645d6bff3972592f2aa5424df9290e7a2e15d63c5f47c48cde585935"}, - {file = "watchfiles-1.0.0-cp310-none-win32.whl", hash = "sha256:d2b39aa8edd9e5f56f99a2a2740a251dc58515398e9ed5a4b3e5ff2827060755"}, - {file = "watchfiles-1.0.0-cp310-none-win_amd64.whl", hash = "sha256:2de52b499e1ab037f1a87cb8ebcb04a819bf087b1015a4cf6dcf8af3c2a2613e"}, - {file = "watchfiles-1.0.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:fbd0ab7a9943bbddb87cbc2bf2f09317e74c77dc55b1f5657f81d04666c25269"}, - {file = "watchfiles-1.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:774ef36b16b7198669ce655d4f75b4c3d370e7f1cbdfb997fb10ee98717e2058"}, - {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b4fb98100267e6a5ebaff6aaa5d20aea20240584647470be39fe4823012ac96"}, - {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0fc3bf0effa2d8075b70badfdd7fb839d7aa9cea650d17886982840d71fdeabf"}, - {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:648e2b6db53eca6ef31245805cd528a16f56fa4cc15aeec97795eaf713c11435"}, - {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fa13d604fcb9417ae5f2e3de676e66aa97427d888e83662ad205bed35a313176"}, - {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:936f362e7ff28311b16f0b97ec51e8f2cc451763a3264640c6ed40fb252d1ee4"}, - {file = "watchfiles-1.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:245fab124b9faf58430da547512d91734858df13f2ddd48ecfa5e493455ffccb"}, - {file = "watchfiles-1.0.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:4ff9c7e84e8b644a8f985c42bcc81457240316f900fc72769aaedec9d088055a"}, - {file = "watchfiles-1.0.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = 
"sha256:9c9a8d8fd97defe935ef8dd53d562e68942ad65067cd1c54d6ed8a088b1d931d"}, - {file = "watchfiles-1.0.0-cp311-none-win32.whl", hash = "sha256:a0abf173975eb9dd17bb14c191ee79999e650997cc644562f91df06060610e62"}, - {file = "watchfiles-1.0.0-cp311-none-win_amd64.whl", hash = "sha256:2a825ba4b32c214e3855b536eb1a1f7b006511d8e64b8215aac06eb680642d84"}, - {file = "watchfiles-1.0.0-cp311-none-win_arm64.whl", hash = "sha256:a5a7a06cfc65e34fd0a765a7623c5ba14707a0870703888e51d3d67107589817"}, - {file = "watchfiles-1.0.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:28fb64b5843d94e2c2483f7b024a1280662a44409bedee8f2f51439767e2d107"}, - {file = "watchfiles-1.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e3750434c83b61abb3163b49c64b04180b85b4dabb29a294513faec57f2ffdb7"}, - {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bedf84835069f51c7b026b3ca04e2e747ea8ed0a77c72006172c72d28c9f69fc"}, - {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:90004553be36427c3d06ec75b804233f8f816374165d5225b93abd94ba6e7234"}, - {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b46e15c34d4e401e976d6949ad3a74d244600d5c4b88c827a3fdf18691a46359"}, - {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:487d15927f1b0bd24e7df921913399bb1ab94424c386bea8b267754d698f8f0e"}, - {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1ff236d7a3f4b0a42f699a22fc374ba526bc55048a70cbb299661158e1bb5e1f"}, - {file = "watchfiles-1.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c01446626574561756067f00b37e6b09c8622b0fc1e9fdbc7cbcea328d4e514"}, - {file = "watchfiles-1.0.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b551c465a59596f3d08170bd7e1c532c7260dd90ed8135778038e13c5d48aa81"}, - {file = "watchfiles-1.0.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e1ed613ee107269f66c2df631ec0fc8efddacface85314d392a4131abe299f00"}, - {file = "watchfiles-1.0.0-cp312-none-win32.whl", hash = "sha256:5f75cd42e7e2254117cf37ff0e68c5b3f36c14543756b2da621408349bd9ca7c"}, - {file = "watchfiles-1.0.0-cp312-none-win_amd64.whl", hash = "sha256:cf517701a4a872417f4e02a136e929537743461f9ec6cdb8184d9a04f4843545"}, - {file = "watchfiles-1.0.0-cp312-none-win_arm64.whl", hash = "sha256:8a2127cd68950787ee36753e6d401c8ea368f73beaeb8e54df5516a06d1ecd82"}, - {file = "watchfiles-1.0.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:95de85c254f7fe8cbdf104731f7f87f7f73ae229493bebca3722583160e6b152"}, - {file = "watchfiles-1.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:533a7cbfe700e09780bb31c06189e39c65f06c7f447326fee707fd02f9a6e945"}, - {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a2218e78e2c6c07b1634a550095ac2a429026b2d5cbcd49a594f893f2bb8c936"}, - {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9122b8fdadc5b341315d255ab51d04893f417df4e6c1743b0aac8bf34e96e025"}, - {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9272fdbc0e9870dac3b505bce1466d386b4d8d6d2bacf405e603108d50446940"}, - {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4a3b33c3aefe9067ebd87846806cd5fc0b017ab70d628aaff077ab9abf4d06b3"}, - {file = 
"watchfiles-1.0.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bc338ce9f8846543d428260fa0f9a716626963148edc937d71055d01d81e1525"}, - {file = "watchfiles-1.0.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ac778a460ea22d63c7e6fb0bc0f5b16780ff0b128f7f06e57aaec63bd339285"}, - {file = "watchfiles-1.0.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:53ae447f06f8f29f5ab40140f19abdab822387a7c426a369eb42184b021e97eb"}, - {file = "watchfiles-1.0.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:1f73c2147a453315d672c1ad907abe6d40324e34a185b51e15624bc793f93cc6"}, - {file = "watchfiles-1.0.0-cp313-none-win32.whl", hash = "sha256:eba98901a2eab909dbd79681190b9049acc650f6111fde1845484a4450761e98"}, - {file = "watchfiles-1.0.0-cp313-none-win_amd64.whl", hash = "sha256:d562a6114ddafb09c33246c6ace7effa71ca4b6a2324a47f4b09b6445ea78941"}, - {file = "watchfiles-1.0.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:3d94fd83ed54266d789f287472269c0def9120a2022674990bd24ad989ebd7a0"}, - {file = "watchfiles-1.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:48051d1c504448b2fcda71c5e6e3610ae45de6a0b8f5a43b961f250be4bdf5a8"}, - {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29cf884ad4285d23453c702ed03d689f9c0e865e3c85d20846d800d4787de00f"}, - {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d3572d4c34c4e9c33d25b3da47d9570d5122f8433b9ac6519dca49c2740d23cd"}, - {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c2696611182c85eb0e755b62b456f48debff484b7306b56f05478b843ca8ece"}, - {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:550109001920a993a4383b57229c717fa73627d2a4e8fcb7ed33c7f1cddb0c85"}, - {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b555a93c15bd2c71081922be746291d776d47521a00703163e5fbe6d2a402399"}, - {file = "watchfiles-1.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:947ccba18a38b85c366dafeac8df2f6176342d5992ca240a9d62588b214d731f"}, - {file = "watchfiles-1.0.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ffd98a299b0a74d1b704ef0ed959efb753e656a4e0425c14e46ae4c3cbdd2919"}, - {file = "watchfiles-1.0.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:f8c4f3a1210ed099a99e6a710df4ff2f8069411059ffe30fa5f9467ebed1256b"}, - {file = "watchfiles-1.0.0-cp39-none-win32.whl", hash = "sha256:1e176b6b4119b3f369b2b4e003d53a226295ee862c0962e3afd5a1c15680b4e3"}, - {file = "watchfiles-1.0.0-cp39-none-win_amd64.whl", hash = "sha256:2d9c0518fabf4a3f373b0a94bb9e4ea7a1df18dec45e26a4d182aa8918dee855"}, - {file = "watchfiles-1.0.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f159ac795785cde4899e0afa539f4c723fb5dd336ce5605bc909d34edd00b79b"}, - {file = "watchfiles-1.0.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:c3d258d78341d5d54c0c804a5b7faa66cd30ba50b2756a7161db07ce15363b8d"}, - {file = "watchfiles-1.0.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5bbd0311588c2de7f9ea5cf3922ccacfd0ec0c1922870a2be503cc7df1ca8be7"}, - {file = "watchfiles-1.0.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9a13ac46b545a7d0d50f7641eefe47d1597e7d1783a5d89e09d080e6dff44b0"}, - {file = "watchfiles-1.0.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = 
"sha256:b2bca898c1dc073912d3db7fa6926cc08be9575add9e84872de2c99c688bac4e"}, - {file = "watchfiles-1.0.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:06d828fe2adc4ac8a64b875ca908b892a3603d596d43e18f7948f3fef5fc671c"}, - {file = "watchfiles-1.0.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:074c7618cd6c807dc4eaa0982b4a9d3f8051cd0b72793511848fd64630174b17"}, - {file = "watchfiles-1.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95dc785bc284552d044e561b8f4fe26d01ab5ca40d35852a6572d542adfeb4bc"}, - {file = "watchfiles-1.0.0.tar.gz", hash = "sha256:37566c844c9ce3b5deb964fe1a23378e575e74b114618d211fbda8f59d7b5dab"}, + {file = "watchfiles-1.0.3-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:1da46bb1eefb5a37a8fb6fd52ad5d14822d67c498d99bda8754222396164ae42"}, + {file = "watchfiles-1.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2b961b86cd3973f5822826017cad7f5a75795168cb645c3a6b30c349094e02e3"}, + {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:34e87c7b3464d02af87f1059fedda5484e43b153ef519e4085fe1a03dd94801e"}, + {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d9dd2b89a16cf7ab9c1170b5863e68de6bf83db51544875b25a5f05a7269e678"}, + {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2b4691234d31686dca133c920f94e478b548a8e7c750f28dbbc2e4333e0d3da9"}, + {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:90b0fe1fcea9bd6e3084b44875e179b4adcc4057a3b81402658d0eb58c98edf8"}, + {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0b90651b4cf9e158d01faa0833b073e2e37719264bcee3eac49fc3c74e7d304b"}, + {file = "watchfiles-1.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2e9fe695ff151b42ab06501820f40d01310fbd58ba24da8923ace79cf6d702d"}, + {file = "watchfiles-1.0.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62691f1c0894b001c7cde1195c03b7801aaa794a837bd6eef24da87d1542838d"}, + {file = "watchfiles-1.0.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:275c1b0e942d335fccb6014d79267d1b9fa45b5ac0639c297f1e856f2f532552"}, + {file = "watchfiles-1.0.3-cp310-cp310-win32.whl", hash = "sha256:06ce08549e49ba69ccc36fc5659a3d0ff4e3a07d542b895b8a9013fcab46c2dc"}, + {file = "watchfiles-1.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:f280b02827adc9d87f764972fbeb701cf5611f80b619c20568e1982a277d6146"}, + {file = "watchfiles-1.0.3-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:ffe709b1d0bc2e9921257569675674cafb3a5f8af689ab9f3f2b3f88775b960f"}, + {file = "watchfiles-1.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:418c5ce332f74939ff60691e5293e27c206c8164ce2b8ce0d9abf013003fb7fe"}, + {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2f492d2907263d6d0d52f897a68647195bc093dafed14508a8d6817973586b6b"}, + {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:48c9f3bc90c556a854f4cab6a79c16974099ccfa3e3e150673d82d47a4bc92c9"}, + {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:75d3bcfa90454dba8df12adc86b13b6d85fda97d90e708efc036c2760cc6ba44"}, + {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:5691340f259b8f76b45fb31b98e594d46c36d1dc8285efa7975f7f50230c9093"}, + {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1e263cc718545b7f897baeac1f00299ab6fabe3e18caaacacb0edf6d5f35513c"}, + {file = "watchfiles-1.0.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c6cf7709ed3e55704cc06f6e835bf43c03bc8e3cb8ff946bf69a2e0a78d9d77"}, + {file = "watchfiles-1.0.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:703aa5e50e465be901e0e0f9d5739add15e696d8c26c53bc6fc00eb65d7b9469"}, + {file = "watchfiles-1.0.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:bfcae6aecd9e0cb425f5145afee871465b98b75862e038d42fe91fd753ddd780"}, + {file = "watchfiles-1.0.3-cp311-cp311-win32.whl", hash = "sha256:6a76494d2c5311584f22416c5a87c1e2cb954ff9b5f0988027bc4ef2a8a67181"}, + {file = "watchfiles-1.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:cf745cbfad6389c0e331786e5fe9ae3f06e9d9c2ce2432378e1267954793975c"}, + {file = "watchfiles-1.0.3-cp311-cp311-win_arm64.whl", hash = "sha256:2dcc3f60c445f8ce14156854a072ceb36b83807ed803d37fdea2a50e898635d6"}, + {file = "watchfiles-1.0.3-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:93436ed550e429da007fbafb723e0769f25bae178fbb287a94cb4ccdf42d3af3"}, + {file = "watchfiles-1.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c18f3502ad0737813c7dad70e3e1cc966cc147fbaeef47a09463bbffe70b0a00"}, + {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a5bc3ca468bb58a2ef50441f953e1f77b9a61bd1b8c347c8223403dc9b4ac9a"}, + {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0d1ec043f02ca04bf21b1b32cab155ce90c651aaf5540db8eb8ad7f7e645cba8"}, + {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f58d3bfafecf3d81c15d99fc0ecf4319e80ac712c77cf0ce2661c8cf8bf84066"}, + {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1df924ba82ae9e77340101c28d56cbaff2c991bd6fe8444a545d24075abb0a87"}, + {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:632a52dcaee44792d0965c17bdfe5dc0edad5b86d6a29e53d6ad4bf92dc0ff49"}, + {file = "watchfiles-1.0.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bf4b459d94a0387617a1b499f314aa04d8a64b7a0747d15d425b8c8b151da0"}, + {file = "watchfiles-1.0.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ca94c85911601b097d53caeeec30201736ad69a93f30d15672b967558df02885"}, + {file = "watchfiles-1.0.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:65ab1fb635476f6170b07e8e21db0424de94877e4b76b7feabfe11f9a5fc12b5"}, + {file = "watchfiles-1.0.3-cp312-cp312-win32.whl", hash = "sha256:49bc1bc26abf4f32e132652f4b3bfeec77d8f8f62f57652703ef127e85a3e38d"}, + {file = "watchfiles-1.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:48681c86f2cb08348631fed788a116c89c787fdf1e6381c5febafd782f6c3b44"}, + {file = "watchfiles-1.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:9e080cf917b35b20c889225a13f290f2716748362f6071b859b60b8847a6aa43"}, + {file = "watchfiles-1.0.3-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:e153a690b7255c5ced17895394b4f109d5dcc2a4f35cb809374da50f0e5c456a"}, + {file = "watchfiles-1.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ac1be85fe43b4bf9a251978ce5c3bb30e1ada9784290441f5423a28633a958a7"}, + {file = 
"watchfiles-1.0.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a2ec98e31e1844eac860e70d9247db9d75440fc8f5f679c37d01914568d18721"}, + {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0179252846be03fa97d4d5f8233d1c620ef004855f0717712ae1c558f1974a16"}, + {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:995c374e86fa82126c03c5b4630c4e312327ecfe27761accb25b5e1d7ab50ec8"}, + {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:29b9cb35b7f290db1c31fb2fdf8fc6d3730cfa4bca4b49761083307f441cac5a"}, + {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6f8dc09ae69af50bead60783180f656ad96bd33ffbf6e7a6fce900f6d53b08f1"}, + {file = "watchfiles-1.0.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:489b80812f52a8d8c7b0d10f0d956db0efed25df2821c7a934f6143f76938bd6"}, + {file = "watchfiles-1.0.3-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:228e2247de583475d4cebf6b9af5dc9918abb99d1ef5ee737155bb39fb33f3c0"}, + {file = "watchfiles-1.0.3-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:1550be1a5cb3be08a3fb84636eaafa9b7119b70c71b0bed48726fd1d5aa9b868"}, + {file = "watchfiles-1.0.3-cp313-cp313-win32.whl", hash = "sha256:16db2d7e12f94818cbf16d4c8938e4d8aaecee23826344addfaaa671a1527b07"}, + {file = "watchfiles-1.0.3-cp313-cp313-win_amd64.whl", hash = "sha256:160eff7d1267d7b025e983ca8460e8cc67b328284967cbe29c05f3c3163711a3"}, + {file = "watchfiles-1.0.3-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:c05b021f7b5aa333124f2a64d56e4cb9963b6efdf44e8d819152237bbd93ba15"}, + {file = "watchfiles-1.0.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:310505ad305e30cb6c5f55945858cdbe0eb297fc57378f29bacceb534ac34199"}, + {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ddff3f8b9fa24a60527c137c852d0d9a7da2a02cf2151650029fdc97c852c974"}, + {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:46e86ed457c3486080a72bc837300dd200e18d08183f12b6ca63475ab64ed651"}, + {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f79fe7993e230a12172ce7d7c7db061f046f672f2b946431c81aff8f60b2758b"}, + {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ea2b51c5f38bad812da2ec0cd7eec09d25f521a8b6b6843cbccedd9a1d8a5c15"}, + {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0fe4e740ea94978b2b2ab308cbf9270a246bcbb44401f77cc8740348cbaeac3d"}, + {file = "watchfiles-1.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9af037d3df7188ae21dc1c7624501f2f90d81be6550904e07869d8d0e6766655"}, + {file = "watchfiles-1.0.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:52bb50a4c4ca2a689fdba84ba8ecc6a4e6210f03b6af93181bb61c4ec3abaf86"}, + {file = "watchfiles-1.0.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c14a07bdb475eb696f85c715dbd0f037918ccbb5248290448488a0b4ef201aad"}, + {file = "watchfiles-1.0.3-cp39-cp39-win32.whl", hash = "sha256:be37f9b1f8934cd9e7eccfcb5612af9fb728fecbe16248b082b709a9d1b348bf"}, + {file = "watchfiles-1.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:ef9ec8068cf23458dbf36a08e0c16f0a2df04b42a8827619646637be1769300a"}, + {file = "watchfiles-1.0.3-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = 
"sha256:84fac88278f42d61c519a6c75fb5296fd56710b05bbdcc74bdf85db409a03780"}, + {file = "watchfiles-1.0.3-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:c68be72b1666d93b266714f2d4092d78dc53bd11cf91ed5a3c16527587a52e29"}, + {file = "watchfiles-1.0.3-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:889a37e2acf43c377b5124166bece139b4c731b61492ab22e64d371cce0e6e80"}, + {file = "watchfiles-1.0.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ca05cacf2e5c4a97d02a2878a24020daca21dbb8823b023b978210a75c79098"}, + {file = "watchfiles-1.0.3-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:8af4b582d5fc1b8465d1d2483e5e7b880cc1a4e99f6ff65c23d64d070867ac58"}, + {file = "watchfiles-1.0.3-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:127de3883bdb29dbd3b21f63126bb8fa6e773b74eaef46521025a9ce390e1073"}, + {file = "watchfiles-1.0.3-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:713f67132346bdcb4c12df185c30cf04bdf4bf6ea3acbc3ace0912cab6b7cb8c"}, + {file = "watchfiles-1.0.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:abd85de513eb83f5ec153a802348e7a5baa4588b818043848247e3e8986094e8"}, + {file = "watchfiles-1.0.3.tar.gz", hash = "sha256:f3ff7da165c99a5412fe5dd2304dd2dbaaaa5da718aad942dcb3a178eaa70c56"}, ] [package.dependencies] @@ -11095,4 +11173,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "14476bf95504a4df4b8d5a5c6608c6aa3dae7499d27d1e41ef39d761cc7c693d" +content-hash = "f4accd01805cbf080c4c5295f97a06c8e4faec7365d2c43d0435e56b46461732" diff --git a/api/pyproject.toml b/api/pyproject.toml index da9eabecf55ccf..28e0305406a18b 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -60,6 +60,7 @@ oci = "~2.135.1" openai = "~1.52.0" openpyxl = "~3.1.5" pandas = { version = "~2.2.2", extras = ["performance", "excel"] } +pandas-stubs = "~2.2.3.241009" psycopg2-binary = "~2.9.6" pycryptodome = "3.19.1" pydantic = "~2.9.2" @@ -84,6 +85,7 @@ tencentcloud-sdk-python-hunyuan = "~3.0.1158" tiktoken = "~0.8.0" tokenizers = "~0.15.0" transformers = "~4.35.0" +types-pytz = "~2024.2.0.20241003" unstructured = { version = "~0.16.1", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] } validators = "0.21.0" volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"} @@ -173,6 +175,7 @@ optional = true [tool.poetry.group.dev.dependencies] coverage = "~7.2.4" faker = "~32.1.0" +mypy = "~1.13.0" pytest = "~8.3.2" pytest-benchmark = "~4.0.0" pytest-env = "~1.1.3" diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py index 97e5c77e95361a..48bdc872f41e5c 100644 --- a/api/schedule/clean_messages.py +++ b/api/schedule/clean_messages.py @@ -32,8 +32,9 @@ def clean_messages(): while True: try: # Main query with join and filter + # FIXME:for mypy no paginate method error messages = ( - db.session.query(Message) + db.session.query(Message) # type: ignore .filter(Message.created_at < plan_sandbox_clean_message_day) .order_by(Message.created_at.desc()) .limit(100) diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index e12be649e4d02d..f66b3c47979435 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -52,8 +52,7 @@ def clean_unused_datasets_task(): # Main query with join and filter datasets = ( - db.session.query(Dataset) - .outerjoin(document_subquery_new, Dataset.id 
== document_subquery_new.c.dataset_id) + Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) .filter( Dataset.created_at < plan_sandbox_clean_day, @@ -120,8 +119,7 @@ def clean_unused_datasets_task(): # Main query with join and filter datasets = ( - db.session.query(Dataset) - .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) + Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) .filter( Dataset.created_at < plan_pro_clean_day, diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py index a20b500308a4d6..1c985461c6aa2e 100644 --- a/api/schedule/create_tidb_serverless_task.py +++ b/api/schedule/create_tidb_serverless_task.py @@ -36,14 +36,15 @@ def create_tidb_serverless_task(): def create_clusters(batch_size): try: + # TODO: maybe we can set the default value for the following parameters in the config file new_clusters = TidbService.batch_create_tidb_serverless_cluster( - batch_size, - dify_config.TIDB_PROJECT_ID, - dify_config.TIDB_API_URL, - dify_config.TIDB_IAM_API_URL, - dify_config.TIDB_PUBLIC_KEY, - dify_config.TIDB_PRIVATE_KEY, - dify_config.TIDB_REGION, + batch_size=batch_size, + project_id=dify_config.TIDB_PROJECT_ID or "", + api_url=dify_config.TIDB_API_URL or "", + iam_url=dify_config.TIDB_IAM_API_URL or "", + public_key=dify_config.TIDB_PUBLIC_KEY or "", + private_key=dify_config.TIDB_PRIVATE_KEY or "", + region=dify_config.TIDB_REGION or "", ) for new_cluster in new_clusters: tidb_auth_binding = TidbAuthBinding( diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py index b2d8746f9ca8f4..11a39e60ee4ce5 100644 --- a/api/schedule/update_tidb_serverless_status_task.py +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -36,13 +36,14 @@ def update_clusters(tidb_serverless_list: list[TidbAuthBinding]): # batch 20 for i in range(0, len(tidb_serverless_list), 20): items = tidb_serverless_list[i : i + 20] + # TODO: maybe we can set the default value for the following parameters in the config file TidbService.batch_update_tidb_serverless_cluster_status( - items, - dify_config.TIDB_PROJECT_ID, - dify_config.TIDB_API_URL, - dify_config.TIDB_IAM_API_URL, - dify_config.TIDB_PUBLIC_KEY, - dify_config.TIDB_PRIVATE_KEY, + tidb_serverless_list=items, + project_id=dify_config.TIDB_PROJECT_ID or "", + api_url=dify_config.TIDB_API_URL or "", + iam_url=dify_config.TIDB_IAM_API_URL or "", + public_key=dify_config.TIDB_PUBLIC_KEY or "", + private_key=dify_config.TIDB_PRIVATE_KEY or "", ) except Exception as e: click.echo(click.style(f"Error: {e}", fg="red")) diff --git a/api/services/account_service.py b/api/services/account_service.py index 22b54a3ab87473..91075ec46b16bf 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -6,7 +6,7 @@ import uuid from datetime import UTC, datetime, timedelta from hashlib import sha256 -from typing import Any, Optional +from typing import Any, Optional, cast from pydantic import BaseModel from sqlalchemy import func @@ -119,7 +119,7 @@ def load_user(user_id: str) -> None | Account: account.last_active_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() - return account + return cast(Account, account) @staticmethod def 
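Editorial aside (not part of the patch): the scheduler changes above switch the TidbService calls to keyword arguments and coerce each `Optional[str]` config value with `or ""`, because mypy rejects passing `str | None` where a plain `str` is declared. A minimal sketch of the trade-off, with hypothetical names:

    # Minimal sketch (hypothetical names): the `value or ""` coercion used by
    # the scheduled tasks above.
    from typing import Optional

    def connect(api_url: str) -> None:
        # Stand-in for a callee that requires a plain str.
        print(f"connecting to {api_url!r}")

    TIDB_API_URL: Optional[str] = None  # e.g. not set in the environment

    # `TIDB_API_URL or ""` narrows str | None to str, so mypy accepts the call;
    # the cost is that a missing setting degrades to "" instead of failing fast.
    connect(api_url=TIDB_API_URL or "")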
get_account_jwt_token(account: Account) -> str: @@ -132,7 +132,7 @@ def get_account_jwt_token(account: Account) -> str: "sub": "Console API Passport", } - token = PassportService().issue(payload) + token: str = PassportService().issue(payload) return token @staticmethod @@ -164,7 +164,7 @@ def authenticate(email: str, password: str, invite_token: Optional[str] = None) db.session.commit() - return account + return cast(Account, account) @staticmethod def update_account_password(account, password, new_password): @@ -347,6 +347,8 @@ def send_reset_password_email( language: Optional[str] = "en-US", ): account_email = account.email if account else email + if account_email is None: + raise ValueError("Email must be provided.") if cls.reset_password_rate_limiter.is_rate_limited(account_email): from controllers.console.auth.error import PasswordResetRateLimitExceededError @@ -377,6 +379,8 @@ def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]: def send_email_code_login_email( cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US" ): + if email is None: + raise ValueError("Email must be provided.") if cls.email_code_login_rate_limiter.is_rate_limited(email): from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError @@ -669,7 +673,7 @@ def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountJoi @staticmethod def get_tenant_count() -> int: """Get tenant count""" - return db.session.query(func.count(Tenant.id)).scalar() + return cast(int, db.session.query(func.count(Tenant.id)).scalar()) @staticmethod def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str) -> None: @@ -733,10 +737,10 @@ def dissolve_tenant(tenant: Tenant, operator: Account) -> None: db.session.commit() @staticmethod - def get_custom_config(tenant_id: str) -> None: - tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).one_or_404() + def get_custom_config(tenant_id: str) -> dict: + tenant = Tenant.query.filter(Tenant.id == tenant_id).one_or_404() - return tenant.custom_config_dict + return cast(dict, tenant.custom_config_dict) class RegisterService: @@ -807,7 +811,7 @@ def register( account.status = AccountStatus.ACTIVE.value if not status else status.value account.initialized_at = datetime.now(UTC).replace(tzinfo=None) - if open_id is not None or provider is not None: + if open_id is not None and provider is not None: AccountService.link_account_integrate(provider, open_id, account) if FeatureService.get_system_features().is_allow_create_workspace: @@ -828,10 +832,11 @@ def register( @classmethod def invite_new_member( - cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account = None + cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Optional[Account] = None ) -> str: """Invite new member""" account = Account.query.filter_by(email=email).first() + assert inviter is not None, "Inviter must be provided." 
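The account_service changes above keep repeating two narrowing idioms: `cast()` on values SQLAlchemy types as `Any`, and an explicit raise on `Optional` arguments before use. A minimal standalone sketch of both, with `scalar_from_db` as a hypothetical stand-in for `db.session.query(func.count(...)).scalar()`:

    from typing import Any, Optional, cast

    def scalar_from_db() -> Any:
        # hypothetical stand-in for db.session.query(func.count(...)).scalar(),
        # which SQLAlchemy types as Any / Optional
        return 42

    def get_tenant_count() -> int:
        # cast() only informs mypy; it performs no runtime conversion or check
        return cast(int, scalar_from_db())

    def send_reset_password_email(email: Optional[str]) -> None:
        # an explicit None guard narrows Optional[str] to str for mypy,
        # matching the "Email must be provided." checks in the patch
        if email is None:
            raise ValueError("Email must be provided.")
        print(f"sending reset email to {email}")

    if __name__ == "__main__":
        assert get_tenant_count() == 42
        send_reset_password_email("user@example.com")
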
if not account: TenantService.check_member_permission(tenant, inviter, None, "add") @@ -894,7 +899,9 @@ def revoke_token(cls, workspace_id: str, email: str, token: str): redis_client.delete(cls._get_invitation_token_key(token)) @classmethod - def get_invitation_if_token_valid(cls, workspace_id: str, email: str, token: str) -> Optional[dict[str, Any]]: + def get_invitation_if_token_valid( + cls, workspace_id: Optional[str], email: str, token: str + ) -> Optional[dict[str, Any]]: invitation_data = cls._get_invitation_by_token(token, workspace_id, email) if not invitation_data: return None @@ -953,7 +960,7 @@ def _get_invitation_by_token( if not data: return None - invitation = json.loads(data) + invitation: dict = json.loads(data) return invitation diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index d2cd7bea67c5b6..6dc1affa11d036 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -48,6 +48,8 @@ def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> return cls.get_chat_prompt( copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt ) + # default return empty dict + return {} @classmethod def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict: @@ -91,3 +93,5 @@ def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) - return cls.get_chat_prompt( copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt ) + # default return empty dict + return {} diff --git a/api/services/agent_service.py b/api/services/agent_service.py index c8819535f11a39..b02f762ad267b8 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -1,5 +1,7 @@ +from typing import Optional + import pytz -from flask_login import current_user +from flask_login import current_user # type: ignore from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager from core.tools.tool_manager import ToolManager @@ -14,7 +16,7 @@ def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) - """ Service to get agent logs """ - conversation: Conversation = ( + conversation: Optional[Conversation] = ( db.session.query(Conversation) .filter( Conversation.id == conversation_id, @@ -26,7 +28,7 @@ def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) - if not conversation: raise ValueError(f"Conversation not found: {conversation_id}") - message: Message = ( + message: Optional[Message] = ( db.session.query(Message) .filter( Message.id == message_id, @@ -72,7 +74,10 @@ def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) - } agent_config = AgentConfigManager.convert(app_model.app_model_config.to_dict()) - agent_tools = agent_config.tools + if not agent_config: + return result + + agent_tools = agent_config.tools or [] def find_agent_tool(tool_name: str): for agent_tool in agent_tools: diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index f45c21cb18f5e3..a946405c955cec 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -1,8 +1,9 @@ import datetime import uuid +from typing import cast import pandas as pd -from flask_login import current_user +from flask_login import current_user # type: ignore from sqlalchemy import or_ from werkzeug.datastructures import FileStorage from 
werkzeug.exceptions import NotFound @@ -71,7 +72,7 @@ def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> Messa app_id, annotation_setting.collection_binding_id, ) - return annotation + return cast(MessageAnnotation, annotation) @classmethod def enable_app_annotation(cls, args: dict, app_id: str) -> dict: @@ -124,8 +125,7 @@ def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keywo raise NotFound("App not found") if keyword: annotations = ( - db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) + MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id) .filter( or_( MessageAnnotation.question.ilike("%{}%".format(keyword)), @@ -137,8 +137,7 @@ def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keywo ) else: annotations = ( - db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) + MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id) .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) ) @@ -327,8 +326,7 @@ def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, lim raise NotFound("Annotation not found") annotation_hit_histories = ( - db.session.query(AppAnnotationHitHistory) - .filter( + AppAnnotationHitHistory.query.filter( AppAnnotationHitHistory.app_id == app_id, AppAnnotationHitHistory.annotation_id == annotation_id, ) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 7c1a175988071e..b191fa2397fa9e 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -1,7 +1,7 @@ import logging import uuid from enum import StrEnum -from typing import Optional +from typing import Optional, cast from uuid import uuid4 import yaml @@ -103,7 +103,7 @@ def import_app( raise ValueError(f"Invalid import_mode: {import_mode}") # Get YAML content - content = "" + content: bytes | str = b"" if mode == ImportMode.YAML_URL: if not yaml_url: return Import( @@ -136,7 +136,7 @@ def import_app( ) try: - content = content.decode("utf-8") + content = cast(bytes, content).decode("utf-8") except UnicodeDecodeError as e: return Import( id=import_id, @@ -362,6 +362,9 @@ def _create_or_update_app( app.icon_background = icon_background or app_data.get("icon_background", app.icon_background) app.updated_by = account.id else: + if account.current_tenant_id is None: + raise ValueError("Current tenant is not set") + # Create new app app = App() app.id = str(uuid4()) diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 9def7d15e928d4..51aef7ccab9a0c 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -118,7 +118,7 @@ def generate( @staticmethod def _get_max_active_requests(app_model: App) -> int: max_active_requests = app_model.max_active_requests - if app_model.max_active_requests is None: + if max_active_requests is None: max_active_requests = int(dify_config.APP_MAX_ACTIVE_REQUESTS) return max_active_requests @@ -150,7 +150,7 @@ def generate_more_like_this( message_id: str, invoke_from: InvokeFrom, streaming: bool = True, - ) -> Union[dict, Generator]: + ) -> Union[Mapping, Generator]: """ Generate more like this :param app_model: app model diff --git a/api/services/app_service.py b/api/services/app_service.py index 8d8ba735ecfa71..41c15bbf0a330b 100644 --- a/api/services/app_service.py +++ 
b/api/services/app_service.py @@ -1,9 +1,9 @@ import json import logging from datetime import UTC, datetime -from typing import cast +from typing import Optional, cast -from flask_login import current_user +from flask_login import current_user # type: ignore from flask_sqlalchemy.pagination import Pagination from configs import dify_config @@ -83,7 +83,7 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App: # get default model instance try: model_instance = model_manager.get_default_model_instance( - tenant_id=account.current_tenant_id, model_type=ModelType.LLM + tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM ) except (ProviderTokenNotInitError, LLMBadRequestError): model_instance = None @@ -100,6 +100,8 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App: else: llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) + if model_schema is None: + raise ValueError(f"model schema not found for model {model_instance.model}") default_model_dict = { "provider": model_instance.provider, @@ -109,7 +111,7 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App: } else: provider, model = model_manager.get_default_provider_model_name( - tenant_id=account.current_tenant_id, model_type=ModelType.LLM + tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM ) default_model_config["model"]["provider"] = provider default_model_config["model"]["name"] = model @@ -314,7 +316,7 @@ def get_app_meta(self, app_model: App) -> dict: """ app_mode = AppMode.value_of(app_model.mode) - meta = {"tool_icons": {}} + meta: dict = {"tool_icons": {}} if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow @@ -336,7 +338,7 @@ def get_app_meta(self, app_model: App) -> dict: } ) else: - app_model_config: AppModelConfig = app_model.app_model_config + app_model_config: Optional[AppModelConfig] = app_model.app_model_config if not app_model_config: return meta @@ -352,16 +354,18 @@ def get_app_meta(self, app_model: App) -> dict: keys = list(tool.keys()) if len(keys) >= 4: # current tool standard - provider_type = tool.get("provider_type") - provider_id = tool.get("provider_id") - tool_name = tool.get("tool_name") + provider_type = tool.get("provider_type", "") + provider_id = tool.get("provider_id", "") + tool_name = tool.get("tool_name", "") if provider_type == "builtin": meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon" elif provider_type == "api": try: - provider: ApiToolProvider = ( + provider: Optional[ApiToolProvider] = ( db.session.query(ApiToolProvider).filter(ApiToolProvider.id == provider_id).first() ) + if provider is None: + raise ValueError(f"provider not found for tool {tool_name}") meta["tool_icons"][tool_name] = json.loads(provider.icon) except: meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"} diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 7a0cd5725b2a96..973110f5156523 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -110,6 +110,8 @@ def invoke_tts(text_content: str, app_model, voice: Optional[str] = None): voices = model_instance.get_tts_voices() if voices: voice = voices[0].get("value") + if not voice: + raise ValueError("Sorry, no voice available.") else: raise ValueError("Sorry, no voice available.") @@ -121,6 +123,8 @@ def invoke_tts(text_content: 
str, app_model, voice: Optional[str] = None): if message_id: message = db.session.query(Message).filter(Message.id == message_id).first() + if message is None: + return None if message.answer == "" and message.status == "normal": return None @@ -130,6 +134,8 @@ def invoke_tts(text_content: str, app_model, voice: Optional[str] = None): return Response(stream_with_context(response), content_type="audio/mpeg") return response else: + if not text: + raise ValueError("Text is required") response = invoke_tts(text, app_model, voice) if isinstance(response, Generator): return Response(stream_with_context(response), content_type="audio/mpeg") diff --git a/api/services/auth/firecrawl/firecrawl.py b/api/services/auth/firecrawl/firecrawl.py index afc491398f25f3..50e4edff140346 100644 --- a/api/services/auth/firecrawl/firecrawl.py +++ b/api/services/auth/firecrawl/firecrawl.py @@ -11,8 +11,8 @@ def __init__(self, credentials: dict): auth_type = credentials.get("auth_type") if auth_type != "bearer": raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer") - self.api_key = credentials.get("config").get("api_key", None) - self.base_url = credentials.get("config").get("base_url", "https://api.firecrawl.dev") + self.api_key = credentials.get("config", {}).get("api_key", None) + self.base_url = credentials.get("config", {}).get("base_url", "https://api.firecrawl.dev") if not self.api_key: raise ValueError("No API key provided") diff --git a/api/services/auth/jina.py b/api/services/auth/jina.py index de898a1f94b763..6100e9afc8f278 100644 --- a/api/services/auth/jina.py +++ b/api/services/auth/jina.py @@ -11,7 +11,7 @@ def __init__(self, credentials: dict): auth_type = credentials.get("auth_type") if auth_type != "bearer": raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer") - self.api_key = credentials.get("config").get("api_key", None) + self.api_key = credentials.get("config", {}).get("api_key", None) if not self.api_key: raise ValueError("No API key provided") diff --git a/api/services/auth/jina/jina.py b/api/services/auth/jina/jina.py index de898a1f94b763..6100e9afc8f278 100644 --- a/api/services/auth/jina/jina.py +++ b/api/services/auth/jina/jina.py @@ -11,7 +11,7 @@ def __init__(self, credentials: dict): auth_type = credentials.get("auth_type") if auth_type != "bearer": raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer") - self.api_key = credentials.get("config").get("api_key", None) + self.api_key = credentials.get("config", {}).get("api_key", None) if not self.api_key: raise ValueError("No API key provided") diff --git a/api/services/billing_service.py b/api/services/billing_service.py index edc51682179cc5..d98018648839a9 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -1,4 +1,5 @@ import os +from typing import Optional import httpx from tenacity import retry, retry_if_not_exception_type, stop_before_delay, wait_fixed @@ -58,11 +59,14 @@ def _send_request(cls, method, endpoint, json=None, params=None): def is_tenant_owner_or_admin(current_user): tenant_id = current_user.current_tenant_id - join = ( + join: Optional[TenantAccountJoin] = ( db.session.query(TenantAccountJoin) .filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id) .first() ) + if not join: + raise ValueError("Tenant account join not found") + if not TenantAccountRole.is_privileged_role(join.role): raise ValueError("Only team owner or team admin can perform this action") diff --git 
a/api/services/conversation_service.py b/api/services/conversation_service.py index 456dc3ebebaa28..6485cbf37d5b7f 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -72,8 +72,7 @@ def pagination_by_last_id( sort_direction=sort_direction, reference_conversation=current_page_last_conversation, ) - count_stmt = stmt.where(rest_filter_condition) - count_stmt = select(func.count()).select_from(count_stmt.subquery()) + count_stmt = select(func.count()).select_from(stmt.where(rest_filter_condition).subquery()) rest_count = session.scalar(count_stmt) or 0 if rest_count > 0: has_more = True diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 4e99c73ad4787a..d2d8a718d55c8a 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -6,7 +6,7 @@ import uuid from typing import Any, Optional -from flask_login import current_user +from flask_login import current_user # type: ignore from sqlalchemy import func from werkzeug.exceptions import NotFound @@ -186,8 +186,9 @@ def create_empty_dataset( return dataset @staticmethod - def get_dataset(dataset_id) -> Dataset: - return Dataset.query.filter_by(id=dataset_id).first() + def get_dataset(dataset_id) -> Optional[Dataset]: + dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first() + return dataset @staticmethod def check_dataset_model_setting(dataset): @@ -228,6 +229,8 @@ def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, @staticmethod def update_dataset(dataset_id, data, user): dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise ValueError("Dataset not found") DatasetService.check_dataset_permission(dataset, user) if dataset.provider == "external": @@ -371,7 +374,13 @@ def check_dataset_permission(dataset, user): raise NoPermissionError("You do not have permission to access this dataset.") @staticmethod - def check_dataset_operator_permission(user: Account = None, dataset: Dataset = None): + def check_dataset_operator_permission(user: Optional[Account] = None, dataset: Optional[Dataset] = None): + if not dataset: + raise ValueError("Dataset not found") + + if not user: + raise ValueError("User not found") + if dataset.permission == DatasetPermissionEnum.ONLY_ME: if dataset.created_by != user.id: raise NoPermissionError("You do not have permission to access this dataset.") @@ -765,6 +774,11 @@ def save_document_with_dataset_id( rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), created_by=account.id, ) + else: + logging.warning( + f"Invalid process rule mode: {process_rule['mode']}, cannot find dataset process rule" + ) + return db.session.add(dataset_process_rule) db.session.commit() lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) @@ -1009,9 +1023,10 @@ def update_document_with_dataset_id( rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), created_by=account.id, ) - db.session.add(dataset_process_rule) - db.session.commit() - document.dataset_process_rule_id = dataset_process_rule.id + if dataset_process_rule is not None: + db.session.add(dataset_process_rule) + db.session.commit() + document.dataset_process_rule_id = dataset_process_rule.id # update document data source if document_data.get("data_source"): file_name = "" @@ -1554,7 +1569,7 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document segment.word_count = len(content) if document.doc_form == "qa_model": segment.answer = segment_update_entity.answer - segment.word_count +=
len(segment_update_entity.answer) + segment.word_count += len(segment_update_entity.answer or "") word_count_change = segment.word_count - word_count_change if segment_update_entity.keywords: segment.keywords = segment_update_entity.keywords @@ -1569,7 +1584,8 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document db.session.add(document) # update segment index task if segment_update_entity.enabled: - VectorService.create_segments_vector([segment_update_entity.keywords], [segment], dataset) + keywords = segment_update_entity.keywords or [] + VectorService.create_segments_vector([keywords], [segment], dataset) else: segment_hash = helper.generate_text_hash(content) tokens = 0 @@ -1601,7 +1617,7 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document segment.disabled_by = None if document.doc_form == "qa_model": segment.answer = segment_update_entity.answer - segment.word_count += len(segment_update_entity.answer) + segment.word_count += len(segment_update_entity.answer or "") word_count_change = segment.word_count - word_count_change # update document word count if word_count_change != 0: @@ -1619,8 +1635,8 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document segment.status = "error" segment.error = str(e) db.session.commit() - segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first() - return segment + new_segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first() + return new_segment @classmethod def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset): @@ -1680,6 +1696,8 @@ def get_dataset_collection_binding_by_id_and_type( .order_by(DatasetCollectionBinding.created_at) .first() ) + if not dataset_collection_binding: + raise ValueError("Dataset collection binding not found") return dataset_collection_binding diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index 92098f06cca538..3c3f9704440342 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -8,8 +8,8 @@ class EnterpriseRequest: secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY") proxies = { - "http": None, - "https": None, + "http": "", + "https": "", } @classmethod diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index c519f0b0e51b68..334d009ee5f79f 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -4,7 +4,11 @@ from pydantic import BaseModel, ConfigDict from configs import dify_config -from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity +from core.entities.model_entities import ( + ModelWithProviderEntity, + ProviderModelWithStatusEntity, + SimpleModelProviderEntity, +) from core.entities.provider_entities import QuotaConfiguration from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ModelType @@ -148,7 +152,7 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity): Model with provider entity. 
""" - provider: SimpleProviderEntityResponse + provider: SimpleModelProviderEntity def __init__(self, model: ModelWithProviderEntity) -> None: super().__init__(**model.model_dump()) diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 7be20301a74b78..898624066bef7e 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -1,7 +1,7 @@ import json from copy import deepcopy from datetime import UTC, datetime -from typing import Any, Optional, Union +from typing import Any, Optional, Union, cast import httpx import validators @@ -45,7 +45,10 @@ def validate_api_list(cls, api_settings: dict): @staticmethod def create_external_knowledge_api(tenant_id: str, user_id: str, args: dict) -> ExternalKnowledgeApis: - ExternalDatasetService.check_endpoint_and_api_key(args.get("settings")) + settings = args.get("settings") + if settings is None: + raise ValueError("settings is required") + ExternalDatasetService.check_endpoint_and_api_key(settings) external_knowledge_api = ExternalKnowledgeApis( tenant_id=tenant_id, created_by=user_id, @@ -86,11 +89,16 @@ def check_endpoint_and_api_key(settings: dict): @staticmethod def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis: - return ExternalKnowledgeApis.query.filter_by(id=external_knowledge_api_id).first() + external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by( + id=external_knowledge_api_id + ).first() + if external_knowledge_api is None: + raise ValueError("api template not found") + return external_knowledge_api @staticmethod def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis: - external_knowledge_api = ExternalKnowledgeApis.query.filter_by( + external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by( id=external_knowledge_api_id, tenant_id=tenant_id ).first() if external_knowledge_api is None: @@ -127,7 +135,7 @@ def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bo @staticmethod def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings: - external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by( + external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ExternalKnowledgeBindings.query.filter_by( dataset_id=dataset_id, tenant_id=tenant_id ).first() if not external_knowledge_binding: @@ -163,8 +171,9 @@ def process_external_api( "follow_redirects": True, } - response = getattr(ssrf_proxy, settings.request_method)(data=json.dumps(settings.params), files=files, **kwargs) - + response: httpx.Response = getattr(ssrf_proxy, settings.request_method)( + data=json.dumps(settings.params), files=files, **kwargs + ) return response @staticmethod @@ -265,15 +274,15 @@ def fetch_external_knowledge_retrieval( "knowledge_id": external_knowledge_binding.external_knowledge_id, } - external_knowledge_api_setting = { - "url": f"{settings.get('endpoint')}/retrieval", - "request_method": "post", - "headers": headers, - "params": request_params, - } response = ExternalDatasetService.process_external_api( - ExternalKnowledgeApiSetting(**external_knowledge_api_setting), None + ExternalKnowledgeApiSetting( + url=f"{settings.get('endpoint')}/retrieval", + request_method="post", + headers=headers, + params=request_params, + ), + None, ) if response.status_code == 200: - return 
response.json().get("records", []) + return cast(list[Any], response.json().get("records", [])) return [] diff --git a/api/services/file_service.py b/api/services/file_service.py index b12b95ca13558c..d417e81734c8af 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -3,7 +3,7 @@ import uuid from typing import Any, Literal, Union -from flask_login import current_user +from flask_login import current_user # type: ignore from werkzeug.exceptions import NotFound from configs import dify_config @@ -61,14 +61,14 @@ def upload_file( # end_user current_tenant_id = user.tenant_id - file_key = "upload_files/" + current_tenant_id + "/" + file_uuid + "." + extension + file_key = "upload_files/" + (current_tenant_id or "") + "/" + file_uuid + "." + extension # save file to storage storage.save(file_key, content) # save file to db upload_file = UploadFile( - tenant_id=current_tenant_id, + tenant_id=current_tenant_id or "", storage_type=dify_config.STORAGE_TYPE, key=file_key, name=filename, diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 7957b4dc82dfd4..41b4e1ec46374a 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -1,5 +1,6 @@ import logging import time +from typing import Any from core.rag.datasource.retrieval_service import RetrievalService from core.rag.models.document import Document @@ -24,7 +25,7 @@ def retrieve( dataset: Dataset, query: str, account: Account, - retrieval_model: dict, + retrieval_model: Any, # FIXME drop this any external_retrieval_model: dict, limit: int = 10, ) -> dict: @@ -68,7 +69,7 @@ def retrieve( db.session.add(dataset_query) db.session.commit() - return cls.compact_retrieve_response(dataset, query, all_documents) + return dict(cls.compact_retrieve_response(dataset, query, all_documents)) @classmethod def external_retrieve( @@ -102,13 +103,16 @@ def external_retrieve( db.session.add(dataset_query) db.session.commit() - return cls.compact_external_retrieve_response(dataset, query, all_documents) + return dict(cls.compact_external_retrieve_response(dataset, query, all_documents)) @classmethod def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]): records = [] for document in documents: + if document.metadata is None: + continue + index_node_id = document.metadata["doc_id"] segment = ( @@ -140,7 +144,7 @@ def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list } @classmethod - def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list): + def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> dict[Any, Any]: records = [] if dataset.provider == "external": for document in documents: @@ -152,11 +156,10 @@ def compact_external_retrieve_response(cls, dataset: Dataset, query: str, docume } records.append(record) return { - "query": { - "content": query, - }, + "query": {"content": query}, "records": records, } + return {"query": {"content": query}, "records": []} @classmethod def hit_testing_args_check(cls, args): diff --git a/api/services/knowledge_service.py b/api/services/knowledge_service.py index 02fe1d19bc42be..8df1a6ba144d4e 100644 --- a/api/services/knowledge_service.py +++ b/api/services/knowledge_service.py @@ -1,4 +1,4 @@ -import boto3 +import boto3 # type: ignore from configs import dify_config diff --git a/api/services/message_service.py b/api/services/message_service.py index be2922f4c58e76..c4447a84da5e09 100644 --- 
a/api/services/message_service.py +++ b/api/services/message_service.py @@ -157,7 +157,7 @@ def create_feedback( user: Optional[Union[Account, EndUser]], rating: Optional[str], content: Optional[str], - ) -> MessageFeedback: + ): if not user: raise ValueError("user cannot be None") @@ -264,6 +264,8 @@ def get_suggested_questions_after_answer( ) app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs) + if not app_model_config: + raise ValueError("did not find app model config") suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict if suggested_questions_after_answer.get("enabled", False) is False: @@ -285,7 +287,7 @@ def get_suggested_questions_after_answer( ) with measure_time() as timer: - questions = LLMGenerator.generate_suggested_questions_after_answer( + questions: list[Message] = LLMGenerator.generate_suggested_questions_after_answer( tenant_id=app_model.tenant_id, histories=histories ) diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index b20bda87551ca9..bacd3a8ec3d04f 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -2,7 +2,7 @@ import json import logging from json import JSONDecodeError -from typing import Optional +from typing import Optional, Union from constants import HIDDEN_VALUE from core.entities.provider_configuration import ProviderConfiguration @@ -88,11 +88,11 @@ def get_load_balancing_configs( raise ValueError(f"Provider {provider} does not exist.") # Convert model type to ModelType - model_type = ModelType.value_of(model_type) + model_type_enum = ModelType.value_of(model_type) # Get provider model setting provider_model_setting = provider_configuration.get_provider_model_setting( - model_type=model_type, + model_type=model_type_enum, model=model, ) @@ -106,7 +106,7 @@ def get_load_balancing_configs( .filter( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, ) .order_by(LoadBalancingModelConfig.created_at) @@ -124,7 +124,7 @@ def get_load_balancing_configs( if not inherit_config_exists: # Initialize the inherit configuration - inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type) + inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type_enum) # prepend the inherit configuration load_balancing_configs.insert(0, inherit_config) @@ -148,7 +148,7 @@ def get_load_balancing_configs( tenant_id=tenant_id, provider=provider, model=model, - model_type=model_type, + model_type=model_type_enum, config_id=load_balancing_config.id, ) @@ -214,7 +214,7 @@ def get_load_balancing_config( raise ValueError(f"Provider {provider} does not exist.") # Convert model type to ModelType - model_type = ModelType.value_of(model_type) + model_type_enum = ModelType.value_of(model_type) # Get load balancing configurations load_balancing_model_config = ( @@ -222,7 +222,7 @@ def get_load_balancing_config( .filter( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_type == 
model_type_enum.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, LoadBalancingModelConfig.id == config_id, ) @@ -300,7 +300,7 @@ def update_load_balancing_configs( raise ValueError(f"Provider {provider} does not exist.") # Convert model type to ModelType - model_type = ModelType.value_of(model_type) + model_type_enum = ModelType.value_of(model_type) if not isinstance(configs, list): raise ValueError("Invalid load balancing configs") @@ -310,7 +310,7 @@ def update_load_balancing_configs( .filter( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, ) .all() @@ -359,7 +359,7 @@ def update_load_balancing_configs( credentials = self._custom_credentials_validate( tenant_id=tenant_id, provider_configuration=provider_configuration, - model_type=model_type, + model_type=model_type_enum, model=model, credentials=credentials, load_balancing_model_config=load_balancing_config, @@ -395,7 +395,7 @@ def update_load_balancing_configs( credentials = self._custom_credentials_validate( tenant_id=tenant_id, provider_configuration=provider_configuration, - model_type=model_type, + model_type=model_type_enum, model=model, credentials=credentials, validate=False, @@ -405,7 +405,7 @@ def update_load_balancing_configs( load_balancing_model_config = LoadBalancingModelConfig( tenant_id=tenant_id, provider_name=provider_configuration.provider.provider, - model_type=model_type.to_origin_model_type(), + model_type=model_type_enum.to_origin_model_type(), model_name=model, name=name, encrypted_config=json.dumps(credentials), @@ -450,7 +450,7 @@ def validate_load_balancing_credentials( raise ValueError(f"Provider {provider} does not exist.") # Convert model type to ModelType - model_type = ModelType.value_of(model_type) + model_type_enum = ModelType.value_of(model_type) load_balancing_model_config = None if config_id: @@ -460,7 +460,7 @@ def validate_load_balancing_credentials( .filter( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, LoadBalancingModelConfig.id == config_id, ) @@ -474,7 +474,7 @@ def validate_load_balancing_credentials( self._custom_credentials_validate( tenant_id=tenant_id, provider_configuration=provider_configuration, - model_type=model_type, + model_type=model_type_enum, model=model, credentials=credentials, load_balancing_model_config=load_balancing_model_config, @@ -547,19 +547,14 @@ def _custom_credentials_validate( def _get_credential_schema( self, provider_configuration: ProviderConfiguration - ) -> ModelCredentialSchema | ProviderCredentialSchema: - """ - Get form schemas. 
- :param provider_configuration: provider configuration - :return: - """ - # Get credential form schemas from model credential schema or provider credential schema + ) -> Union[ModelCredentialSchema, ProviderCredentialSchema]: + """Get form schemas.""" if provider_configuration.provider.model_credential_schema: - credential_schema = provider_configuration.provider.model_credential_schema + return provider_configuration.provider.model_credential_schema + elif provider_configuration.provider.provider_credential_schema: + return provider_configuration.provider.provider_credential_schema else: - credential_schema = provider_configuration.provider.provider_credential_schema - - return credential_schema + raise ValueError("No credential schema found") def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None: """ diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 384a072b371fdd..b10c5ad2d616e9 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -7,7 +7,7 @@ import requests from flask import current_app -from core.entities.model_entities import ModelStatus, ProviderModelWithStatusEntity +from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity from core.model_runtime.entities.model_entities import ModelType, ParameterRule from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -100,23 +100,15 @@ def get_models_by_provider(self, tenant_id: str, provider: str) -> list[ModelWit ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(provider=provider) ] - def get_provider_credentials(self, tenant_id: str, provider: str) -> dict: + def get_provider_credentials(self, tenant_id: str, provider: str): """ get provider credentials. - - :param tenant_id: - :param provider: - :return: """ - # Get all provider configurations of the current workspace provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration provider_configuration = provider_configurations.get(provider) if not provider_configuration: raise ValueError(f"Provider {provider} does not exist.") - # Get provider custom credentials from workspace return provider_configuration.get_custom_credentials(obfuscated=True) def provider_credentials_validate(self, tenant_id: str, provider: str, credentials: dict) -> None: @@ -176,7 +168,7 @@ def remove_provider_credentials(self, tenant_id: str, provider: str) -> None: # Remove custom provider credentials. provider_configuration.delete_custom_credentials() - def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> dict: + def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str): """ get model credentials. 
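The `model_type` → `model_type_enum` renames above all serve one mypy rule: a variable may not change type when rebound. A minimal sketch of the pattern, assuming a toy `ModelType` enum in place of the real one:

    from enum import Enum

    class ModelType(Enum):
        # toy stand-in for core.model_runtime.entities.model_entities.ModelType
        LLM = "llm"

        @staticmethod
        def value_of(value: str) -> "ModelType":
            return ModelType(value)

    def get_provider_model_setting(model_type: str) -> str:
        # rebinding `model_type = ModelType.value_of(model_type)` would change
        # the variable's type from str to ModelType mid-function, which mypy
        # rejects; binding the converted value to a new name keeps both stable
        model_type_enum = ModelType.value_of(model_type)
        return model_type_enum.value

    if __name__ == "__main__":
        assert get_provider_model_setting("llm") == "llm"
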
@@ -287,7 +279,7 @@ def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[Prov models = provider_configurations.get_models(model_type=ModelType.value_of(model_type)) # Group models by provider - provider_models = {} + provider_models: dict[str, list[ModelWithProviderEntity]] = {} for model in models: if model.provider.provider not in provider_models: provider_models[model.provider.provider] = [] @@ -362,7 +354,7 @@ def get_model_parameter_rules(self, tenant_id: str, provider: str, model: str) - return [] # Call get_parameter_rules method of model instance to get model parameter rules - return model_type_instance.get_parameter_rules(model=model, credentials=credentials) + return list(model_type_instance.get_parameter_rules(model=model, credentials=credentials)) def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]: """ @@ -422,6 +414,7 @@ def get_model_provider_icon( """ provider_instance = model_provider_factory.get_provider_instance(provider) provider_schema = provider_instance.get_provider_schema() + file_name: str | None = None if icon_type.lower() == "icon_small": if not provider_schema.icon_small: @@ -439,6 +432,8 @@ def get_model_provider_icon( file_name = provider_schema.icon_large.zh_Hans else: file_name = provider_schema.icon_large.en_US + if not file_name: + return None, None root_path = current_app.root_path provider_instance_path = os.path.dirname( @@ -524,7 +519,7 @@ def disable_model(self, tenant_id: str, provider: str, model: str, model_type: s def free_quota_submit(self, tenant_id: str, provider: str): api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") - api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") + api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL", "") api_url = api_base_url + "/api/v1/providers/apply" headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} @@ -545,7 +540,7 @@ def free_quota_submit(self, tenant_id: str, provider: str): def free_quota_qualification_verify(self, tenant_id: str, provider: str, token: Optional[str]): api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") - api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") + api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL", "") api_url = api_base_url + "/api/v1/providers/qualification-verify" headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} diff --git a/api/services/moderation_service.py b/api/services/moderation_service.py index dfb21e767fc9b9..082afeed89a5e4 100644 --- a/api/services/moderation_service.py +++ b/api/services/moderation_service.py @@ -1,3 +1,5 @@ +from typing import Optional + from core.moderation.factory import ModerationFactory, ModerationOutputsResult from extensions.ext_database import db from models.model import App, AppModelConfig @@ -5,7 +7,7 @@ class ModerationService: def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult: - app_model_config: AppModelConfig = None + app_model_config: Optional[AppModelConfig] = None app_model_config = ( db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() diff --git a/api/services/ops_service.py b/api/services/ops_service.py index 1160a1f2751d74..fc1e08518b1945 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -1,3 +1,5 @@ +from typing import Optional + from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map from extensions.ext_database 
import db from models.model import App, TraceAppConfig @@ -12,7 +14,7 @@ def get_tracing_app_config(cls, app_id: str, tracing_provider: str): :param tracing_provider: tracing provider :return: """ - trace_config_data: TraceAppConfig = ( + trace_config_data: Optional[TraceAppConfig] = ( db.session.query(TraceAppConfig) .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() @@ -22,7 +24,10 @@ def get_tracing_app_config(cls, app_id: str, tracing_provider: str): return None # decrypt_token and obfuscated_token - tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id + tenant = db.session.query(App).filter(App.id == app_id).first() + if not tenant: + return None + tenant_id = tenant.tenant_id decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config( tenant_id, tracing_provider, trace_config_data.tracing_config ) @@ -73,8 +78,9 @@ def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c provider_config_map[tracing_provider]["config_class"], provider_config_map[tracing_provider]["other_keys"], ) - default_config_instance = config_class(**tracing_config) - for key in other_keys: + # FIXME: ignore type error + default_config_instance = config_class(**tracing_config) # type: ignore + for key in other_keys: # type: ignore if key in tracing_config and tracing_config[key] == "": tracing_config[key] = getattr(default_config_instance, key, None) @@ -92,7 +98,7 @@ def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c project_url = None # check if trace config already exists - trace_config_data: TraceAppConfig = ( + trace_config_data: Optional[TraceAppConfig] = ( db.session.query(TraceAppConfig) .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() @@ -102,7 +108,10 @@ def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c return None # get tenant id - tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id + tenant = db.session.query(App).filter(App.id == app_id).first() + if not tenant: + return None + tenant_id = tenant.tenant_id tracing_config = OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config) if project_url: tracing_config["project_url"] = project_url @@ -139,7 +148,10 @@ def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c return None # get tenant id - tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id + tenant = db.session.query(App).filter(App.id == app_id).first() + if not tenant: + return None + tenant_id = tenant.tenant_id tracing_config = OpsTraceManager.encrypt_tracing_config( tenant_id, tracing_provider, tracing_config, current_trace_config.tracing_config ) diff --git a/api/services/recommend_app/buildin/buildin_retrieval.py b/api/services/recommend_app/buildin/buildin_retrieval.py index 4704d533a950ed..523aebeed52a4e 100644 --- a/api/services/recommend_app/buildin/buildin_retrieval.py +++ b/api/services/recommend_app/buildin/buildin_retrieval.py @@ -41,7 +41,7 @@ def _get_builtin_data(cls) -> dict: Path(path.join(root_path, "constants", "recommended_apps.json")).read_text(encoding="utf-8") ) - return cls.builtin_data + return cls.builtin_data or {} @classmethod def fetch_recommended_apps_from_builtin(cls, language: str) -> dict: @@ -50,8 +50,8 @@ def fetch_recommended_apps_from_builtin(cls, language: str) -> dict: :param language: language :return: """ - builtin_data = 
cls._get_builtin_data() - return builtin_data.get("recommended_apps", {}).get(language) + builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() + return builtin_data.get("recommended_apps", {}).get(language, {}) @classmethod def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict]: @@ -60,5 +60,5 @@ def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict :param app_id: App ID :return: """ - builtin_data = cls._get_builtin_data() + builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() return builtin_data.get("app_details", {}).get(app_id) diff --git a/api/services/recommend_app/remote/remote_retrieval.py b/api/services/recommend_app/remote/remote_retrieval.py index b0607a21323acb..80e1aefc01da85 100644 --- a/api/services/recommend_app/remote/remote_retrieval.py +++ b/api/services/recommend_app/remote/remote_retrieval.py @@ -47,8 +47,8 @@ def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> Optiona response = requests.get(url, timeout=(3, 10)) if response.status_code != 200: return None - - return response.json() + data: dict = response.json() + return data @classmethod def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict: @@ -63,7 +63,7 @@ def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict: if response.status_code != 200: raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}") - result = response.json() + result: dict = response.json() if "categories" in result: result["categories"] = sorted(result["categories"]) diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index 4660316fcfcf71..54c58455155c03 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -33,5 +33,5 @@ def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]: """ mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)() - result = retrieval_instance.get_recommend_app_detail(app_id) + result: dict = retrieval_instance.get_recommend_app_detail(app_id) return result diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py index 9fe3cecce7546d..4cb8700117e6f3 100644 --- a/api/services/saved_message_service.py +++ b/api/services/saved_message_service.py @@ -13,6 +13,8 @@ class SavedMessageService: def pagination_by_last_id( cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int ) -> InfiniteScrollPagination: + if not user: + raise ValueError("User is required") saved_messages = ( db.session.query(SavedMessage) .filter( @@ -31,6 +33,8 @@ def pagination_by_last_id( @classmethod def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): + if not user: + return saved_message = ( db.session.query(SavedMessage) .filter( @@ -59,6 +63,8 @@ def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_i @classmethod def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): + if not user: + return saved_message = ( db.session.query(SavedMessage) .filter( diff --git a/api/services/tag_service.py b/api/services/tag_service.py index a374bdcf002bef..9600601633cddb 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -1,7 +1,7 @@ import uuid from typing import Optional -from flask_login import current_user +from flask_login 
import current_user # type: ignore from sqlalchemy import func from werkzeug.exceptions import NotFound @@ -21,7 +21,7 @@ def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = Non if keyword: query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%"))) query = query.group_by(Tag.id) - results = query.order_by(Tag.created_at.desc()).all() + results: list = query.order_by(Tag.created_at.desc()).all() return results @staticmethod diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 78a80f70ab6b00..0e3bd3a7b83c68 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -1,6 +1,7 @@ import json import logging -from typing import Optional +from collections.abc import Mapping +from typing import Any, Optional, cast from httpx import get @@ -28,12 +29,12 @@ class ApiToolManageService: @staticmethod - def parser_api_schema(schema: str) -> list[ApiToolBundle]: + def parser_api_schema(schema: str) -> Mapping[str, Any]: """ parse api schema to tool bundle """ try: - warnings = {} + warnings: dict[str, str] = {} try: tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings) except Exception as e: @@ -68,13 +69,16 @@ def parser_api_schema(schema: str) -> list[ApiToolBundle]: ), ] - return jsonable_encoder( - { - "schema_type": schema_type, - "parameters_schema": tool_bundles, - "credentials_schema": credentials_schema, - "warning": warnings, - } + return cast( + Mapping, + jsonable_encoder( + { + "schema_type": schema_type, + "parameters_schema": tool_bundles, + "credentials_schema": credentials_schema, + "warning": warnings, + } + ), ) except Exception as e: raise ValueError(f"invalid schema: {str(e)}") @@ -129,7 +133,7 @@ def create_api_tool_provider( raise ValueError(f"provider {provider_name} already exists") # parse openapi to tool bundle - extra_info = {} + extra_info: dict[str, str] = {} # extra info like description will be set here tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) @@ -262,9 +266,8 @@ def update_api_tool_provider( if provider is None: raise ValueError(f"api provider {provider_name} does not exists") - # parse openapi to tool bundle - extra_info = {} + extra_info: dict[str, str] = {} # extra info like description will be set here tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) @@ -416,7 +419,7 @@ def test_api_tool_preview( provider_controller.validate_credentials_format(credentials) # get tool tool = provider_controller.get_tool(tool_name) - tool = tool.fork_tool_runtime( + runtime_tool = tool.fork_tool_runtime( runtime={ "credentials": credentials, "tenant_id": tenant_id, @@ -454,7 +457,7 @@ def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id) - for tool in tools: + for tool in tools or []: user_provider.tools.append( ToolTransformService.tool_to_user_tool( tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index fada881fdeb741..21adbb0074724e 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -50,8 +50,8 @@ def list_builtin_tool_provider_tools(user_id: str, 
tenant_id: str, provider: str credentials = builtin_provider.credentials credentials = tool_provider_configurations.decrypt_tool_credentials(credentials) - result = [] - for tool in tools: + result: list[UserTool] = [] + for tool in tools or []: result.append( ToolTransformService.tool_to_user_tool( tool=tool, @@ -217,6 +217,8 @@ def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: name_func=lambda x: x.identity.name, ): continue + if provider_controller.identity is None: + continue # convert provider controller to user provider user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( @@ -229,7 +231,7 @@ def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: ToolTransformService.repack_provider(user_builtin_provider) tools = provider_controller.get_tools() - for tool in tools: + for tool in tools or []: user_builtin_provider.tools.append( ToolTransformService.tool_to_user_tool( tenant_id=tenant_id, diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index a4aa870dc80352..b501554bcd091d 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -1,6 +1,6 @@ import json import logging -from typing import Optional, Union +from typing import Optional, Union, cast from configs import dify_config from core.tools.entities.api_entities import UserTool, UserToolProvider @@ -35,7 +35,7 @@ def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str return url_prefix + "builtin/" + provider_name + "/icon" elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}: try: - return json.loads(icon) + return cast(dict, json.loads(icon)) except: return {"background": "#252525", "content": "\ud83d\ude01"} @@ -53,8 +53,11 @@ def repack_provider(provider: Union[dict, UserToolProvider]): provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"] ) elif isinstance(provider, UserToolProvider): - provider.icon = ToolTransformService.get_tool_provider_icon_url( - provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon + provider.icon = cast( + str, + ToolTransformService.get_tool_provider_icon_url( + provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon + ), ) @staticmethod @@ -66,6 +69,9 @@ def builtin_provider_to_user_provider( """ convert provider controller to user provider """ + if provider_controller.identity is None: + raise ValueError("provider identity is None") + result = UserToolProvider( id=provider_controller.identity.name, author=provider_controller.identity.author, @@ -93,7 +99,8 @@ def builtin_provider_to_user_provider( # get credentials schema schema = provider_controller.get_credentials_schema() for name, value in schema.items(): - result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type) + assert result.masked_credentials is not None, "masked credentials is None" + result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(str(value.type)) # check if the provider need credentials if not provider_controller.need_credentials: @@ -149,6 +156,9 @@ def workflow_provider_to_user_provider( """ convert provider controller to user provider """ + if provider_controller.identity is None: + raise ValueError("provider identity is None") + return UserToolProvider( id=provider_controller.provider_id, author=provider_controller.identity.author, 
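Several hunks above guard `provider_controller.identity` before building the response; raising early is what lets mypy narrow the `Optional` attribute. A self-contained sketch with toy dataclasses standing in for the real provider types:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ToolProviderIdentity:
        # toy stand-in for the real provider identity entity
        name: str
        author: str

    @dataclass
    class ProviderController:
        identity: Optional[ToolProviderIdentity] = None

    def builtin_provider_to_user_provider(controller: ProviderController) -> dict:
        # raising before any attribute access narrows Optional[...] to the
        # concrete type, so the accesses below pass mypy without # type: ignore
        if controller.identity is None:
            raise ValueError("provider identity is None")
        return {"id": controller.identity.name, "author": controller.identity.author}

    if __name__ == "__main__":
        print(builtin_provider_to_user_provider(ProviderController(ToolProviderIdentity("demo", "alice"))))
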
@@ -180,6 +190,8 @@ def api_provider_to_user_provider( convert provider controller to user provider """ username = "Anonymous" + if db_provider.user is None: + raise ValueError(f"user is None for api provider {db_provider.id}") try: username = db_provider.user.name except Exception as e: @@ -256,19 +268,25 @@ def tool_to_user_tool( if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: current_parameters.append(runtime_parameter) + if tool.identity is None: + raise ValueError("tool identity is None") + return UserTool( author=tool.identity.author, name=tool.identity.name, label=tool.identity.label, - description=tool.description.human, + description=I18nObject( + en_US=tool.description.human if tool.description else "", + zh_Hans=tool.description.human if tool.description else "", + ), parameters=current_parameters, labels=labels, ) if isinstance(tool, ApiToolBundle): return UserTool( author=tool.author, - name=tool.operation_id, - label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id), + name=tool.operation_id or "", + label=I18nObject(en_US=tool.operation_id or "", zh_Hans=tool.operation_id or ""), description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""), parameters=tool.parameters, labels=labels, diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 318107bebb5eb6..69430de432b143 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -6,8 +6,10 @@ from sqlalchemy import or_ from core.model_runtime.utils.encoders import jsonable_encoder -from core.tools.entities.api_entities import UserToolProvider +from core.tools.entities.api_entities import UserTool, UserToolProvider +from core.tools.provider.tool_provider import ToolProviderController from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController +from core.tools.tool.tool import Tool from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from extensions.ext_database import db @@ -32,7 +34,7 @@ def create_workflow_tool( label: str, icon: dict, description: str, - parameters: Mapping[str, Any], + parameters: list[Mapping[str, Any]], privacy_policy: str = "", labels: Optional[list[str]] = None, ) -> dict: @@ -97,7 +99,7 @@ def update_workflow_tool( label: str, icon: dict, description: str, - parameters: list[dict], + parameters: list[Mapping[str, Any]], privacy_policy: str = "", labels: Optional[list[str]] = None, ) -> dict: @@ -131,7 +133,7 @@ def update_workflow_tool( if existing_workflow_tool_provider is not None: raise ValueError(f"Tool with name {name} already exists") - workflow_tool_provider: WorkflowToolProvider = ( + workflow_tool_provider: Optional[WorkflowToolProvider] = ( db.session.query(WorkflowToolProvider) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .first() @@ -140,14 +142,14 @@ def update_workflow_tool( if workflow_tool_provider is None: raise ValueError(f"Tool {workflow_tool_id} not found") - app: App = ( + app: Optional[App] = ( db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first() ) if app is None: raise ValueError(f"App {workflow_tool_provider.app_id} not found") - workflow: Workflow = app.workflow + workflow: Optional[Workflow] = app.workflow if workflow is None: raise ValueError(f"Workflow not 
found for app {workflow_tool_provider.app_id}") @@ -193,7 +195,7 @@ def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserTo # skip deleted tools pass - labels = ToolLabelManager.get_tools_labels(tools) + labels = ToolLabelManager.get_tools_labels([t for t in tools if isinstance(t, ToolProviderController)]) result = [] @@ -202,10 +204,11 @@ def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserTo provider_controller=tool, labels=labels.get(tool.provider_id, []) ) ToolTransformService.repack_provider(user_tool_provider) + to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id) + if to_user_tool is None or len(to_user_tool) == 0: + continue user_tool_provider.tools = [ - ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], labels=labels.get(tool.provider_id, []) - ) + ToolTransformService.tool_to_user_tool(to_user_tool[0], labels=labels.get(tool.provider_id, [])) ] result.append(user_tool_provider) @@ -236,7 +239,7 @@ def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_too :param workflow_app_id: the workflow app id :return: the tool """ - db_tool: WorkflowToolProvider = ( + db_tool: Optional[WorkflowToolProvider] = ( db.session.query(WorkflowToolProvider) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .first() @@ -245,13 +248,19 @@ def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_too if db_tool is None: raise ValueError(f"Tool {workflow_tool_id} not found") - workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() + workflow_app: Optional[App] = ( + db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() + ) if workflow_app is None: raise ValueError(f"App {db_tool.app_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) + to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id) + if to_user_tool is None or len(to_user_tool) == 0: + raise ValueError(f"Tool {workflow_tool_id} not found") + return { "name": db_tool.name, "label": db_tool.label, @@ -261,9 +270,9 @@ def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_too "description": db_tool.description, "parameters": jsonable_encoder(db_tool.parameter_configurations), "tool": ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) + to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool) ), - "synced": workflow_app.workflow.version == db_tool.version, + "synced": workflow_app.workflow.version == db_tool.version if workflow_app.workflow else False, "privacy_policy": db_tool.privacy_policy, } @@ -276,7 +285,7 @@ def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_ :param workflow_app_id: the workflow app id :return: the tool """ - db_tool: WorkflowToolProvider = ( + db_tool: Optional[WorkflowToolProvider] = ( db.session.query(WorkflowToolProvider) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) .first() @@ -285,12 +294,17 @@ def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_ if db_tool is None: raise ValueError(f"Tool {workflow_app_id} not found") - workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() + workflow_app: Optional[App] = ( + 
db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() + ) if workflow_app is None: raise ValueError(f"App {db_tool.app_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) + to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id) + if to_user_tool is None or len(to_user_tool) == 0: + raise ValueError(f"Tool {workflow_app_id} not found") return { "name": db_tool.name, @@ -301,14 +315,14 @@ def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_ "description": db_tool.description, "parameters": jsonable_encoder(db_tool.parameter_configurations), "tool": ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) + to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool) ), - "synced": workflow_app.workflow.version == db_tool.version, + "synced": workflow_app.workflow.version == db_tool.version if workflow_app.workflow else False, "privacy_policy": db_tool.privacy_policy, } @classmethod - def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]: + def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[UserTool]: """ List workflow tool provider tools. :param user_id: the user id @@ -316,7 +330,7 @@ def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_ :param workflow_app_id: the workflow app id :return: the list of tools """ - db_tool: WorkflowToolProvider = ( + db_tool: Optional[WorkflowToolProvider] = ( db.session.query(WorkflowToolProvider) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .first() @@ -326,9 +340,8 @@ def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_ raise ValueError(f"Tool {workflow_tool_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) + to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id) + if to_user_tool is None or len(to_user_tool) == 0: + raise ValueError(f"Tool {workflow_tool_id} not found") - return [ - ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) - ) - ] + return [ToolTransformService.tool_to_user_tool(to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool))] diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index 508fe20970a703..f698ed3084bdac 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -26,6 +26,8 @@ def pagination_by_last_id( pinned: Optional[bool] = None, sort_by="-updated_at", ) -> InfiniteScrollPagination: + if not user: + raise ValueError("User is required") include_ids = None exclude_ids = None if pinned is not None and user: @@ -59,6 +61,8 @@ def pagination_by_last_id( @classmethod def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): + if not user: + return pinned_conversation = ( db.session.query(PinnedConversation) .filter( @@ -89,6 +93,8 @@ def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, @classmethod def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): + if not user: + return pinned_conversation = ( db.session.query(PinnedConversation) .filter( diff --git a/api/services/website_service.py b/api/services/website_service.py 
index 230f5d78152f39..1ad7d0399d6edf 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -1,8 +1,9 @@ import datetime import json +from typing import Any import requests -from flask_login import current_user +from flask_login import current_user # type: ignore from core.helper import encrypter from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp @@ -23,9 +24,9 @@ def document_create_args_validate(cls, args: dict): @classmethod def crawl_url(cls, args: dict) -> dict: - provider = args.get("provider") + provider = args.get("provider", "") url = args.get("url") - options = args.get("options") + options = args.get("options", "") credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider) if provider == "firecrawl": # decrypt api_key @@ -164,16 +165,18 @@ def get_crawl_status(cls, job_id: str, provider: str) -> dict: return crawl_status_data @classmethod - def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None: + def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[Any, Any] | None: credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) # decrypt api_key api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) + # FIXME: `data` is redefined too many times here; use Any to ease the type checking for now and fix it later + data: Any if provider == "firecrawl": file_key = "website_files/" + job_id + ".txt" if storage.exists(file_key): - data = storage.load_once(file_key) - if data: - data = json.loads(data.decode("utf-8")) + d = storage.load_once(file_key) + if d: + data = json.loads(d.decode("utf-8")) else: firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) result = firecrawl_app.check_crawl_status(job_id) @@ -183,22 +186,17 @@ def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str if data: for item in data: if item.get("source_url") == url: - return item + return dict(item) return None elif provider == "jinareader": - file_key = "website_files/" + job_id + ".txt" - if storage.exists(file_key): - data = storage.load_once(file_key) - if data: - data = json.loads(data.decode("utf-8")) - elif not job_id: + if not job_id: response = requests.get( f"https://r.jina.ai/{url}", headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, ) if response.json().get("code") != 200: raise ValueError("Failed to crawl") - return response.json().get("data") + return dict(response.json().get("data", {})) else: api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) response = requests.post( @@ -218,12 +216,13 @@ def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str data = response.json().get("data", {}) for item in data.get("processed", {}).values(): if item.get("data", {}).get("url") == url: - return item.get("data", {}) + return dict(item.get("data", {})) + return None else: raise ValueError("Invalid provider") @classmethod - def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None: + def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict: credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) if provider == "firecrawl": # decrypt api_key diff --git 
a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 90b5cc48362f3b..2b0d57bdfdeda3 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -1,5 +1,5 @@ import json -from typing import Optional +from typing import Any, Optional from core.app.app_config.entities import ( DatasetEntity, @@ -101,7 +101,7 @@ def convert_app_model_config_to_workflow(self, app_model: App, app_model_config: app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config) # init workflow graph - graph = {"nodes": [], "edges": []} + graph: dict[str, Any] = {"nodes": [], "edges": []} # Convert list: # - variables -> start @@ -118,7 +118,7 @@ def convert_app_model_config_to_workflow(self, app_model: App, app_model_config: graph["nodes"].append(start_node) # convert to http request node - external_data_variable_node_mapping = {} + external_data_variable_node_mapping: dict[str, str] = {} if app_config.external_data_variables: http_request_nodes, external_data_variable_node_mapping = self._convert_to_http_request_node( app_model=app_model, @@ -199,15 +199,16 @@ def convert_app_model_config_to_workflow(self, app_model: App, app_model_config: return workflow def _convert_to_app_config(self, app_model: App, app_model_config: AppModelConfig) -> EasyUIBasedAppConfig: - app_mode = AppMode.value_of(app_model.mode) - if app_mode == AppMode.AGENT_CHAT or app_model.is_agent: + app_mode_enum = AppMode.value_of(app_model.mode) + app_config: EasyUIBasedAppConfig + if app_mode_enum == AppMode.AGENT_CHAT or app_model.is_agent: app_model.mode = AppMode.AGENT_CHAT.value app_config = AgentChatAppConfigManager.get_app_config( app_model=app_model, app_model_config=app_model_config ) - elif app_mode == AppMode.CHAT: + elif app_mode_enum == AppMode.CHAT: app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config) - elif app_mode == AppMode.COMPLETION: + elif app_mode_enum == AppMode.COMPLETION: app_config = CompletionAppConfigManager.get_app_config( app_model=app_model, app_model_config=app_model_config ) @@ -302,7 +303,7 @@ def _convert_to_http_request_node( nodes.append(http_request_node) # append code node for response body parsing - code_node = { + code_node: dict[str, Any] = { "id": f"code_{index}", "position": None, "data": { @@ -401,6 +402,7 @@ def _convert_to_llm_node( ) role_prefix = None + prompts: Any = None # Chat Model if model_config.mode == LLMMode.CHAT.value: diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index d8ee323908a844..4343596a236f5f 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -1,3 +1,5 @@ +from typing import Optional + from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom @@ -92,7 +94,7 @@ def get_paginate_workflow_runs(self, app_model: App, args: dict) -> InfiniteScro return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) - def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun: + def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]: """ Get workflow run detail diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 84768d5af053e4..ea8192edde35cc 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -2,7 +2,7 @@ 
import time from collections.abc import Sequence from datetime import UTC, datetime -from typing import Optional, cast +from typing import Any, Optional, cast from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager @@ -242,7 +242,7 @@ def run_draft_workflow_node( raise ValueError("Node run failed with no run result") # single step debug mode error handling return if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error: - node_error_args = { + node_error_args: dict[str, Any] = { "status": WorkflowNodeExecutionStatus.EXCEPTION, "error": node_run_result.error, "inputs": node_run_result.inputs, @@ -338,7 +338,7 @@ def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> A raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.") # convert to workflow - new_app = workflow_converter.convert_to_workflow( + new_app: App = workflow_converter.convert_to_workflow( app_model=app_model, account=account, name=args.get("name", "Default Name"), diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 8fcb12b1cb9664..7637b31454e556 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -1,4 +1,4 @@ -from flask_login import current_user +from flask_login import current_user # type: ignore from configs import dify_config from extensions.ext_database import db @@ -29,6 +29,7 @@ def get_tenant_info(cls, tenant: Tenant): .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id) .first() ) + assert tenant_account_join is not None, "TenantAccountJoin not found" tenant_info["role"] = tenant_account_join.role can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo diff --git a/api/tasks/__init__.py b/api/tasks/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index 09be6612160471..50bb2b6e634fba 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.index_processor.index_processor_factory import IndexProcessorFactory diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index 25c55bcfafe11c..aab21a44109975 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index fa7e5ac9190f3c..06162b02d60f8b 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.datasource.vdb.vector_factory import Vector diff --git 
a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index f0f6b32b06c78c..a6a598ce4b6bca 100644 --- a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.datasource.vdb.vector_factory import Vector from models.dataset import Dataset diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index a2f49135139b08..26bf1c7c9fa32e 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.datasource.vdb.vector_factory import Vector diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index 0bdcd0eccd7f72..b42af0c7faf67e 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.datasource.vdb.vector_factory import Vector diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index b685d84d07ad28..8c675feaa6e06f 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index dcb7009e44b938..26ae9f8736d79a 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -4,7 +4,7 @@ import uuid import click -from celery import shared_task +from celery import shared_task # type: ignore from sqlalchemy import func from core.indexing_runner import IndexingRunner @@ -58,12 +58,13 @@ def batch_create_segment_to_index_task( model=dataset.embedding_model, ) word_count_change = 0 + segments_to_insert: list[str] = [] # explicit annotation so mypy can type the empty list for segment in content: - content = segment["content"] + content_str = segment["content"] doc_id = str(uuid.uuid4()) - segment_hash = helper.generate_text_hash(content) + segment_hash = helper.generate_text_hash(content_str) # calc embedding use tokens - tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) if embedding_model else 0 + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content_str]) if embedding_model else 0 max_position = ( db.session.query(func.max(DocumentSegment.position)) .filter(DocumentSegment.document_id == dataset_document.id) @@ -90,6 +91,7 @@ def batch_create_segment_to_index_task( word_count_change += segment_document.word_count db.session.add(segment_document) document_segments.append(segment_document) + segments_to_insert.append(str(segment)) # stringify the raw segment for the typed list # update document word count 
dataset_document.word_count += word_count_change db.session.add(dataset_document) diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index a555fb28746697..d9278c03793877 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.tools.utils.web_reader_tool import get_image_upload_file_ids @@ -71,6 +71,8 @@ def clean_dataset_task( image_upload_file_ids = get_image_upload_file_ids(segment.content) for upload_file_id in image_upload_file_ids: image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() + if image_file is None: + continue try: storage.delete(image_file.key) except Exception: diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 4d328643bfa165..3e80dd13771802 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -3,7 +3,7 @@ from typing import Optional import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.tools.utils.web_reader_tool import get_image_upload_file_ids @@ -44,6 +44,8 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i image_upload_file_ids = get_image_upload_file_ids(segment.content) for upload_file_id in image_upload_file_ids: image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() + if image_file is None: + continue try: storage.delete(image_file.key) except Exception: diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index 75d9e031306381..f5d6406d9cc04f 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index 315b01f157bf13..dfa053a43cbc61 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -4,7 +4,7 @@ from typing import Optional import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.index_processor.index_processor_factory import IndexProcessorFactory diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index cfc54920e23caa..b025509aebe674 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import Document diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index c3e0ea5d9fbb77..45a612c74550cd 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -2,7 +2,7 @@ import time import click -from celery import 
shared_task +from celery import shared_task # type: ignore from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py index 15e1e50076e8c9..f30a1cc7acfd6c 100644 --- a/api/tasks/disable_segment_from_index_task.py +++ b/api/tasks/disable_segment_from_index_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.index_processor.index_processor_factory import IndexProcessorFactory diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 18316913932874..ac4e81f95d127e 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.indexing_runner import DocumentIsPausedError, IndexingRunner diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 734dd2478a9847..21b571b6cb5bd4 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from configs import dify_config from core.indexing_runner import DocumentIsPausedError, IndexingRunner diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 1a52a6636b1d17..5f1e9a892f54e3 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.indexing_runner import DocumentIsPausedError, IndexingRunner diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index f4c3dbd2e2860c..6db2620eb6eef0 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from configs import dify_config from core.indexing_runner import DocumentIsPausedError, IndexingRunner @@ -26,6 +26,8 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): start_at = time.perf_counter() dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if dataset is None: + raise ValueError("Dataset not found") # check document limit features = FeatureService.get_features(dataset.tenant_id) diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 12639db9392677..2f6eb7b82a0633 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.index_processor.index_processor_factory import IndexProcessorFactory diff --git a/api/tasks/mail_email_code_login.py b/api/tasks/mail_email_code_login.py index d78fc2b8915520..5dc935548f90b8 100644 --- a/api/tasks/mail_email_code_login.py +++ 
b/api/tasks/mail_email_code_login.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from flask import render_template from extensions.ext_mail import mail diff --git a/api/tasks/mail_invite_member_task.py b/api/tasks/mail_invite_member_task.py index c7dfb9bf6063ff..3094527fd40945 100644 --- a/api/tasks/mail_invite_member_task.py +++ b/api/tasks/mail_invite_member_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from flask import render_template from configs import dify_config diff --git a/api/tasks/mail_reset_password_task.py b/api/tasks/mail_reset_password_task.py index 8596ca07cfcee3..d5be94431b6221 100644 --- a/api/tasks/mail_reset_password_task.py +++ b/api/tasks/mail_reset_password_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from flask import render_template from extensions.ext_mail import mail diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py index 34c62dc9237fc0..bb3b9e17ead6d2 100644 --- a/api/tasks/ops_trace_task.py +++ b/api/tasks/ops_trace_task.py @@ -1,7 +1,7 @@ import json import logging -from celery import shared_task +from celery import shared_task # type: ignore from flask import current_app from core.ops.entities.config_entity import OPS_FILE_PATH, OPS_TRACE_FAILED_KEY diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index 934eb7430c90c3..b603d689ba9d8e 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.indexing_runner import DocumentIsPausedError, IndexingRunner diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 66f78636ecca60..c3910e2be3a499 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -3,7 +3,7 @@ from collections.abc import Callable import click -from celery import shared_task +from celery import shared_task # type: ignore from sqlalchemy import delete from sqlalchemy.exc import SQLAlchemyError diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index 1909eaf3418517..4ba6d1a83e32ae 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.index_processor.index_processor_factory import IndexProcessorFactory diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index 73471fd6e77c9b..485caa5152ea78 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.indexing_runner import IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -22,10 +22,13 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): Usage: retry_document_indexing_task.delay(dataset_id, 
document_id) """ - documents = [] + documents: list[Document] = [] start_at = time.perf_counter() dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + raise ValueError("Dataset not found") + for document_id in document_ids: retry_indexing_cache_key = "document_{}_is_retried".format(document_id) # check document limit @@ -55,29 +58,31 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): document = ( db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() ) + if not document: + logging.info(click.style("Document not found: {}".format(document_id), fg="yellow")) + return try: - if document: - # clean old data - index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() + # clean old data + index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() - if segments: - index_node_ids = [segment.index_node_id for segment in segments] - # delete from vector index - index_processor.clean(dataset, index_node_ids) + segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + # delete from vector index + index_processor.clean(dataset, index_node_ids) - for segment in segments: - db.session.delete(segment) - db.session.commit() - - document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.utcnow() - db.session.add(document) + for segment in segments: + db.session.delete(segment) db.session.commit() - indexing_runner = IndexingRunner() - indexing_runner.run([document]) - redis_client.delete(retry_indexing_cache_key) + document.indexing_status = "parsing" + document.processing_started_at = datetime.datetime.utcnow() + db.session.add(document) + db.session.commit() + + indexing_runner = IndexingRunner() + indexing_runner.run([document]) + redis_client.delete(retry_indexing_cache_key) except Exception as ex: document.indexing_status = "error" document.error = str(ex) diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py index 1d2a338c831764..5d6b069cf44919 100644 --- a/api/tasks/sync_website_document_indexing_task.py +++ b/api/tasks/sync_website_document_indexing_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.indexing_runner import IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -25,6 +25,8 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): start_at = time.perf_counter() dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if dataset is None: + raise ValueError("Dataset not found") sync_indexing_cache_key = "document_{}_is_sync".format(document_id) # check document limit @@ -52,29 +54,31 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): logging.info(click.style("Start sync website document: {}".format(document_id), fg="green")) document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + if not document: + logging.info(click.style("Document not found: {}".format(document_id), fg="yellow")) + return try: - if document: - # clean old data - 
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() + # clean old data + index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() - if segments: - index_node_ids = [segment.index_node_id for segment in segments] - # delete from vector index - index_processor.clean(dataset, index_node_ids) + segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + # delete from vector index + index_processor.clean(dataset, index_node_ids) - for segment in segments: - db.session.delete(segment) - db.session.commit() - - document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.utcnow() - db.session.add(document) + for segment in segments: + db.session.delete(segment) db.session.commit() - indexing_runner = IndexingRunner() - indexing_runner.run([document]) - redis_client.delete(sync_indexing_cache_key) + document.indexing_status = "parsing" + document.processing_started_at = datetime.datetime.utcnow() + db.session.add(document) + db.session.commit() + + indexing_runner = IndexingRunner() + indexing_runner.run([document]) + redis_client.delete(sync_indexing_cache_key) except Exception as ex: document.indexing_status = "error" document.error = str(ex) diff --git a/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py b/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py index 64f2884c4b828c..57fba317638de8 100644 --- a/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py +++ b/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py @@ -1,6 +1,6 @@ from typing import Any -import toml +import toml # type: ignore def load_api_poetry_configs() -> dict[str, Any]: @@ -38,7 +38,7 @@ def test_group_dependencies_version_operator(): ) -def test_duplicated_dependency_crossing_groups(): +def test_duplicated_dependency_crossing_groups() -> None: all_dependency_names: list[str] = [] for dependencies in load_all_dependency_groups().values(): dependency_names = list(dependencies.keys()) diff --git a/api/tests/integration_tests/controllers/test_controllers.py b/api/tests/integration_tests/controllers/test_controllers.py index 6371694694653e..5e3ee6bedc7ebb 100644 --- a/api/tests/integration_tests/controllers/test_controllers.py +++ b/api/tests/integration_tests/controllers/test_controllers.py @@ -1,6 +1,6 @@ from unittest.mock import patch -from app_fixture import app, mock_user +from app_fixture import mock_user # type: ignore def test_post_requires_login(app): diff --git a/api/tests/integration_tests/model_runtime/__mock/google.py b/api/tests/integration_tests/model_runtime/__mock/google.py index 5ea86baa83dd4b..b90f8b444477d5 100644 --- a/api/tests/integration_tests/model_runtime/__mock/google.py +++ b/api/tests/integration_tests/model_runtime/__mock/google.py @@ -1,7 +1,7 @@ from collections.abc import Generator from unittest.mock import MagicMock -import google.generativeai.types.generation_types as generation_config_types +import google.generativeai.types.generation_types as generation_config_types # type: ignore import pytest from _pytest.monkeypatch import MonkeyPatch from google.ai import generativelanguage as glm @@ -45,7 +45,7 @@ def generate_content_sync() -> GenerateContentResponse: return GenerateContentResponse(done=True, iterator=None, 
result=glm.GenerateContentResponse({}), chunks=[]) @staticmethod - def generate_content_stream() -> Generator[GenerateContentResponse, None, None]: + def generate_content_stream() -> MockGoogleResponseClass: return MockGoogleResponseClass() def generate_content( diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface.py b/api/tests/integration_tests/model_runtime/__mock/huggingface.py index 97038ef5963e87..4de52514408a06 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface.py @@ -2,7 +2,7 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from huggingface_hub import InferenceClient +from huggingface_hub import InferenceClient # type: ignore from tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py index 9ee76c935c9873..77c7e7f5e4089c 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py @@ -3,15 +3,15 @@ from typing import Any, Literal, Optional, Union from _pytest.monkeypatch import MonkeyPatch -from huggingface_hub import InferenceClient -from huggingface_hub.inference._text_generation import ( +from huggingface_hub import InferenceClient # type: ignore +from huggingface_hub.inference._text_generation import ( # type: ignore Details, StreamDetails, TextGenerationResponse, TextGenerationStreamResponse, Token, ) -from huggingface_hub.utils import BadRequestError +from huggingface_hub.utils import BadRequestError # type: ignore class MockHuggingfaceChatClass: diff --git a/api/tests/integration_tests/model_runtime/__mock/nomic_embeddings.py b/api/tests/integration_tests/model_runtime/__mock/nomic_embeddings.py index 6a25398cbf069a..4e00660a29162f 100644 --- a/api/tests/integration_tests/model_runtime/__mock/nomic_embeddings.py +++ b/api/tests/integration_tests/model_runtime/__mock/nomic_embeddings.py @@ -6,7 +6,7 @@ # import monkeypatch from _pytest.monkeypatch import MonkeyPatch -from nomic import embed +from nomic import embed # type: ignore def create_embedding(texts: list[str], model: str, **kwargs: Any) -> dict: diff --git a/api/tests/integration_tests/model_runtime/__mock/xinference.py b/api/tests/integration_tests/model_runtime/__mock/xinference.py index 794f4b0585632e..e2abaa52b939a6 100644 --- a/api/tests/integration_tests/model_runtime/__mock/xinference.py +++ b/api/tests/integration_tests/model_runtime/__mock/xinference.py @@ -6,14 +6,14 @@ from _pytest.monkeypatch import MonkeyPatch from requests import Response from requests.sessions import Session -from xinference_client.client.restful.restful_client import ( +from xinference_client.client.restful.restful_client import ( # type: ignore Client, RESTfulChatModelHandle, RESTfulEmbeddingModelHandle, RESTfulGenerateModelHandle, RESTfulRerankModelHandle, ) -from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage +from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage # type: ignore class MockXinferenceClass: diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py b/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py index 2dcfb92c63fee2..d37fcf897fc3a8 100644 --- a/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py +++ 
b/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py @@ -1,6 +1,6 @@ import os -import dashscope +import dashscope # type: ignore import pytest from core.model_runtime.entities.rerank_entities import RerankResult diff --git a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py index 83f4d70ce9ac2f..2860739f0e30b3 100644 --- a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py +++ b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py @@ -1,5 +1,5 @@ from flask import Flask, request -from flask_restful import Api, Resource +from flask_restful import Api, Resource # type: ignore app = Flask(__name__) api = Api(app) diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py index 0ea61369c0304e..4af35a8befcaf8 100644 --- a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py @@ -4,11 +4,11 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from pymochow import MochowClient -from pymochow.model.database import Database -from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState -from pymochow.model.schema import HNSWParams, VectorIndex -from pymochow.model.table import Table +from pymochow import MochowClient # type: ignore +from pymochow.model.database import Database # type: ignore +from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState # type: ignore +from pymochow.model.schema import HNSWParams, VectorIndex # type: ignore +from pymochow.model.table import Table # type: ignore from requests.adapters import HTTPAdapter diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index 61d6ed16560c09..68a1e290adc120 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -4,12 +4,12 @@ import pytest from _pytest.monkeypatch import MonkeyPatch from requests.adapters import HTTPAdapter -from tcvectordb import VectorDBClient -from tcvectordb.model.database import Collection, Database -from tcvectordb.model.document import Document, Filter -from tcvectordb.model.enum import ReadConsistency -from tcvectordb.model.index import Index -from xinference_client.types import Embedding +from tcvectordb import VectorDBClient # type: ignore +from tcvectordb.model.database import Collection, Database # type: ignore +from tcvectordb.model.document import Document, Filter # type: ignore +from tcvectordb.model.enum import ReadConsistency # type: ignore +from tcvectordb.model.index import Index # type: ignore +from xinference_client.types import Embedding # type: ignore class MockTcvectordbClass: diff --git a/api/tests/integration_tests/vdb/__mock/vikingdb.py b/api/tests/integration_tests/vdb/__mock/vikingdb.py index 0f40337feba6ee..3ad72e55501f58 100644 --- a/api/tests/integration_tests/vdb/__mock/vikingdb.py +++ b/api/tests/integration_tests/vdb/__mock/vikingdb.py @@ -4,7 +4,7 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from volcengine.viking_db import ( +from volcengine.viking_db import ( # type: ignore Collection, Data, DistanceType, diff --git a/api/tests/unit_tests/oss/__mock/aliyun_oss.py b/api/tests/unit_tests/oss/__mock/aliyun_oss.py index 27e1c0ad85029b..4f6d8a2f54a4fd 100644 --- a/api/tests/unit_tests/oss/__mock/aliyun_oss.py 
+++ b/api/tests/unit_tests/oss/__mock/aliyun_oss.py @@ -4,8 +4,8 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from oss2 import Bucket -from oss2.models import GetObjectResult, PutObjectResult +from oss2 import Bucket # type: ignore +from oss2.models import GetObjectResult, PutObjectResult # type: ignore from tests.unit_tests.oss.__mock.base import ( get_example_bucket, diff --git a/api/tests/unit_tests/oss/__mock/tencent_cos.py b/api/tests/unit_tests/oss/__mock/tencent_cos.py index 5189b68e87132a..c77c5b08f37d15 100644 --- a/api/tests/unit_tests/oss/__mock/tencent_cos.py +++ b/api/tests/unit_tests/oss/__mock/tencent_cos.py @@ -3,8 +3,8 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from qcloud_cos import CosS3Client -from qcloud_cos.streambody import StreamBody +from qcloud_cos import CosS3Client # type: ignore +from qcloud_cos.streambody import StreamBody # type: ignore from tests.unit_tests.oss.__mock.base import ( get_example_bucket, diff --git a/api/tests/unit_tests/oss/__mock/volcengine_tos.py b/api/tests/unit_tests/oss/__mock/volcengine_tos.py index 649d93a20261d3..88df59f91c3071 100644 --- a/api/tests/unit_tests/oss/__mock/volcengine_tos.py +++ b/api/tests/unit_tests/oss/__mock/volcengine_tos.py @@ -4,8 +4,8 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from tos import TosClientV2 -from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput +from tos import TosClientV2 # type: ignore +from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput # type: ignore from tests.unit_tests.oss.__mock.base import ( get_example_bucket, diff --git a/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py b/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py index 65d31352bd3437..380134bc46d02e 100644 --- a/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py +++ b/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py @@ -1,7 +1,7 @@ from unittest.mock import MagicMock, patch import pytest -from oss2 import Auth +from oss2 import Auth # type: ignore from extensions.storage.aliyun_oss_storage import AliyunOssStorage from tests.unit_tests.oss.__mock.aliyun_oss import setup_aliyun_oss_mock diff --git a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py index 303f0493bda42f..d289751800633a 100644 --- a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py +++ b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py @@ -1,7 +1,7 @@ from unittest.mock import patch import pytest -from qcloud_cos import CosConfig +from qcloud_cos import CosConfig # type: ignore from extensions.storage.tencent_cos_storage import TencentCosStorage from tests.unit_tests.oss.__mock.base import ( diff --git a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py index 5afbc9e8b4cb18..04988e85d85881 100644 --- a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py +++ b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py @@ -1,5 +1,5 @@ import pytest -from tos import TosClientV2 +from tos import TosClientV2 # type: ignore from extensions.storage.volcengine_tos_storage import VolcengineTosStorage from tests.unit_tests.oss.__mock.base import ( diff --git a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py index 95b93651d57f80..8d645487278a5f 100644 
--- a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py +++ b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py @@ -1,7 +1,7 @@ from textwrap import dedent import pytest -from yaml import YAMLError +from yaml import YAMLError # type: ignore from core.tools.utils.yaml_utils import load_yaml_file diff --git a/sdks/python-client/dify_client/client.py b/sdks/python-client/dify_client/client.py index e6644883018769..ee1b5c57e1d1d0 100644 --- a/sdks/python-client/dify_client/client.py +++ b/sdks/python-client/dify_client/client.py @@ -160,7 +160,10 @@ def get_result(self, workflow_run_id): class KnowledgeBaseClient(DifyClient): def __init__( - self, api_key, base_url: str = "https://api.dify.ai/v1", dataset_id: str = None + self, + api_key, + base_url: str = "https://api.dify.ai/v1", + dataset_id: str | None = None, ): """ Construct a KnowledgeBaseClient object. @@ -187,7 +190,9 @@ def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs): "GET", f"/datasets?page={page}&limit={page_size}", **kwargs ) - def create_document_by_text(self, name, text, extra_params: dict = None, **kwargs): + def create_document_by_text( + self, name, text, extra_params: dict | None = None, **kwargs + ): """ Create a document by text. @@ -225,7 +230,7 @@ def create_document_by_text(self, name, text, extra_params: dict = None, **kwarg return self._send_request("POST", url, json=data, **kwargs) def update_document_by_text( - self, document_id, name, text, extra_params: dict = None, **kwargs + self, document_id, name, text, extra_params: dict | None = None, **kwargs ): """ Update a document by text. @@ -262,7 +267,7 @@ def update_document_by_text( return self._send_request("POST", url, json=data, **kwargs) def create_document_by_file( - self, file_path, original_document_id=None, extra_params: dict = None + self, file_path, original_document_id=None, extra_params: dict | None = None ): """ Create a document by file. @@ -304,7 +309,7 @@ def create_document_by_file( ) def update_document_by_file( - self, document_id, file_path, extra_params: dict = None + self, document_id, file_path, extra_params: dict | None = None ): """ Update a document by file. @@ -372,7 +377,11 @@ def delete_document(self, document_id): return self._send_request("DELETE", url) def list_documents( - self, page: int = None, page_size: int = None, keyword: str = None, **kwargs + self, + page: int | None = None, + page_size: int | None = None, + keyword: str | None = None, + **kwargs, ): """ Get a list of documents in this dataset. @@ -402,7 +411,11 @@ def add_segments(self, document_id, segments, **kwargs): return self._send_request("POST", url, json=data, **kwargs) def query_segments( - self, document_id, keyword: str = None, status: str = None, **kwargs + self, + document_id, + keyword: str | None = None, + status: str | None = None, + **kwargs, ): """ Query segments in this document.
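A minimal sketch of the implicit-Optional fix applied throughout the SDK client above: under mypy's default no_implicit_optional behavior, a parameter like `extra_params: dict = None` is rejected, so the default has to be spelled `dict | None = None`. The function name and payload keys below are illustrative, not part of the SDK:

def create_document(name: str, extra_params: dict | None = None) -> dict:
    # Fold a missing mapping into an empty one instead of passing None along.
    payload: dict = {"name": name, **(extra_params or {})}
    return payload

print(create_document("doc", {"indexing_technique": "high_quality"}))

Spelling the Optional out keeps the call sites unchanged (callers may still omit the argument) while making the None case visible to the type checker.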