From 09fdef178fe4acfa065e5cd28d009388331fd3bd Mon Sep 17 00:00:00 2001 From: Jay Zhang Date: Fri, 30 Aug 2024 00:08:02 +0800 Subject: [PATCH] feat: lazy import to speed up loading (#261) --- .../vector_storage_extension.py | 34 ++++++++----------- .../embedding_extension.py | 6 +++- .../file_chunker/file_chunker_extension.py | 5 +-- .../gemini_llm_python/gemini_llm_extension.py | 5 ++- .../llama_index_chat_engine/extension.py | 13 ++++--- 5 files changed, 34 insertions(+), 29 deletions(-) diff --git a/agents/ten_packages/extension/aliyun_analyticdb_vector_storage/vector_storage_extension.py b/agents/ten_packages/extension/aliyun_analyticdb_vector_storage/vector_storage_extension.py index 65f6f72d..82475f08 100644 --- a/agents/ten_packages/extension/aliyun_analyticdb_vector_storage/vector_storage_extension.py +++ b/agents/ten_packages/extension/aliyun_analyticdb_vector_storage/vector_storage_extension.py @@ -4,8 +4,6 @@ import asyncio import os import json -from .client import AliGPDBClient -from .model import Model from ten import ( Extension, TenEnv, @@ -20,12 +18,6 @@ import threading from datetime import datetime -from alibabacloud_gpdb20160503.client import Client as gpdb20160503Client -from alibabacloud_tea_openapi import models as open_api_models -from alibabacloud_gpdb20160503 import models as gpdb_20160503_models -from alibabacloud_tea_util import models as util_models -from alibabacloud_tea_util.client import Client as UtilClient - class AliPGDBExtension(Extension): def __init__(self, name): @@ -37,7 +29,7 @@ def __init__(self, name): self.region_id = os.environ.get("ADBPG_INSTANCE_REGION") self.dbinstance_id = os.environ.get("ADBPG_INSTANCE_ID") self.endpoint = f"gpdb.aliyuncs.com" - self.client = None + self.model = None self.account = os.environ.get("ADBPG_ACCOUNT") self.account_password = os.environ.get("ADBPG_ACCOUNT_PASSWORD") self.namespace = os.environ.get("ADBPG_NAMESPACE") @@ -92,9 +84,15 @@ def on_start(self, ten: TenEnv) -> None: self.endpoint = "gpdb.aliyuncs.com" else: self.endpoint = f"gpdb.{self.region_id}.aliyuncs.com" - self.client = AliGPDBClient( + + # lazy import packages which requires long time to load + from .client import AliGPDBClient + from .model import Model + + client = AliGPDBClient( self.access_key_id, self.access_key_secret, self.endpoint ) + self.model = Model(self.region_id, self.dbinstance_id, client) self.thread = threading.Thread( target=asyncio.run, args=(self.__thread_routine(ten),) ) @@ -141,7 +139,6 @@ def on_cmd(self, ten: TenEnv, cmd: Cmd) -> None: ten.return_result(CmdResult.create(StatusCode.ERROR), cmd) async def async_create_collection(self, ten: TenEnv, cmd: Cmd): - m = Model(self.region_id, self.dbinstance_id, self.client) collection = cmd.get_property_string("collection_name") dimension = 1024 try: @@ -149,11 +146,11 @@ async def async_create_collection(self, ten: TenEnv, cmd: Cmd): except Exception as e: logger.warning(f"Error: {e}") - err = await m.create_collection_async( + err = await self.model.create_collection_async( self.account, self.account_password, self.namespace, collection ) if err is None: - await m.create_vector_index_async( + await self.model.create_vector_index_async( self.account, self.account_password, self.namespace, @@ -166,14 +163,13 @@ async def async_create_collection(self, ten: TenEnv, cmd: Cmd): async def async_upsert_vector(self, ten: TenEnv, cmd: Cmd): start_time = datetime.now() - m = Model(self.region_id, self.dbinstance_id, self.client) collection = cmd.get_property_string("collection_name") file = cmd.get_property_string("file_name") content = cmd.get_property_string("content") obj = json.loads(content) rows = [(file, item["text"], item["embedding"]) for item in obj] - err = await m.upsert_collection_data_async( + err = await self.model.upsert_collection_data_async( collection, self.namespace, self.namespace_password, rows ) logger.info( @@ -192,12 +188,11 @@ async def async_upsert_vector(self, ten: TenEnv, cmd: Cmd): async def async_query_vector(self, ten: TenEnv, cmd: Cmd): start_time = datetime.now() - m = Model(self.region_id, self.dbinstance_id, self.client) collection = cmd.get_property_string("collection_name") embedding = cmd.get_property_to_json("embedding") top_k = cmd.get_property_int("top_k") vector = json.loads(embedding) - response, error = await m.query_collection_data_async( + response, error = await self.model.query_collection_data_async( collection, self.namespace, self.namespace_password, vector, top_k=top_k ) logger.info( @@ -212,15 +207,14 @@ async def async_query_vector(self, ten: TenEnv, cmd: Cmd): if error: return ten.return_result(CmdResult.create(StatusCode.ERROR), cmd) else: - body = m.parse_collection_data(response.body) + body = self.model.parse_collection_data(response.body) ret = CmdResult.create(StatusCode.OK) ret.set_property_from_json("response", body) ten.return_result(ret, cmd) async def async_delete_collection(self, ten: TenEnv, cmd: Cmd): - m = Model(self.region_id, self.dbinstance_id, self.client) collection = cmd.get_property_string("collection_name") - err = await m.delete_collection_async( + err = await self.model.delete_collection_async( self.account, self.account_password, self.namespace, collection ) if err is None: diff --git a/agents/ten_packages/extension/aliyun_text_embedding/embedding_extension.py b/agents/ten_packages/extension/aliyun_text_embedding/embedding_extension.py index df8e411c..27653459 100644 --- a/agents/ten_packages/extension/aliyun_text_embedding/embedding_extension.py +++ b/agents/ten_packages/extension/aliyun_text_embedding/embedding_extension.py @@ -6,7 +6,6 @@ CmdResult, ) -import dashscope import json from typing import Generator, List from http import HTTPStatus @@ -45,6 +44,11 @@ def on_start(self, ten: TenEnv) -> None: self.api_key = self.get_property_string(ten, "api_key", self.api_key) self.model = self.get_property_string(ten, "model", self.api_key) + + # lazy import packages which requires long time to load + global dashscope + import dashscope + dashscope.api_key = self.api_key for i in range(self.parallel): diff --git a/agents/ten_packages/extension/file_chunker/file_chunker_extension.py b/agents/ten_packages/extension/file_chunker/file_chunker_extension.py index 4e2b4632..8676a5b2 100644 --- a/agents/ten_packages/extension/file_chunker/file_chunker_extension.py +++ b/agents/ten_packages/extension/file_chunker/file_chunker_extension.py @@ -14,8 +14,6 @@ ) from typing import List, Any from .log import logger -from llama_index.core import SimpleDirectoryReader -from llama_index.core.node_parser import SentenceSplitter import json from datetime import datetime import uuid, math @@ -63,6 +61,9 @@ def generate_collection_name(self) -> str: return "coll_" + uuid.uuid1().hex.lower() def split(self, path: str) -> List[Any]: + # lazy import packages which requires long time to load + from llama_index.core import SimpleDirectoryReader + from llama_index.core.node_parser import SentenceSplitter # load pdf file by path documents = SimpleDirectoryReader( diff --git a/agents/ten_packages/extension/gemini_llm_python/gemini_llm_extension.py b/agents/ten_packages/extension/gemini_llm_python/gemini_llm_extension.py index 0db28c57..9cbf248e 100644 --- a/agents/ten_packages/extension/gemini_llm_python/gemini_llm_extension.py +++ b/agents/ten_packages/extension/gemini_llm_python/gemini_llm_extension.py @@ -14,7 +14,6 @@ StatusCode, CmdResult, ) -from .gemini_llm import GeminiLLM, GeminiLLMConfig from .log import logger from .utils import get_micro_ts, parse_sentence @@ -45,6 +44,10 @@ class GeminiLLMExtension(Extension): def on_start(self, ten: TenEnv) -> None: logger.info("GeminiLLMExtension on_start") + + # lazy import packages which requires long time to load + from .gemini_llm import GeminiLLM, GeminiLLMConfig + # Prepare configuration gemini_llm_config = GeminiLLMConfig.default_config() diff --git a/agents/ten_packages/extension/llama_index_chat_engine/extension.py b/agents/ten_packages/extension/llama_index_chat_engine/extension.py index c4a62357..53319197 100644 --- a/agents/ten_packages/extension/llama_index_chat_engine/extension.py +++ b/agents/ten_packages/extension/llama_index_chat_engine/extension.py @@ -14,13 +14,8 @@ CmdResult, ) from .log import logger -from .astra_llm import ASTRALLM -from .astra_retriever import ASTRARetriever import queue, threading from datetime import datetime -from llama_index.core.chat_engine import SimpleChatEngine, ContextChatEngine -from llama_index.core.storage.chat_store import SimpleChatStore -from llama_index.core.memory import ChatMemoryBuffer PROPERTY_CHAT_MEMORY_TOKEN_LIMIT = "chat_memory_token_limit" PROPERTY_GREETING = "greeting" @@ -79,6 +74,8 @@ def on_start(self, ten: TenEnv) -> None: self.thread.start() # enable chat memory + from llama_index.core.storage.chat_store import SimpleChatStore + from llama_index.core.memory import ChatMemoryBuffer self.chat_memory = ChatMemoryBuffer.from_defaults( token_limit=self.chat_memory_token_limit, chat_store=SimpleChatStore(), @@ -204,9 +201,14 @@ def async_handle(self, ten: TenEnv): logger.info("process input text [%s] ts [%s]", input_text, ts) + # lazy import packages which requires long time to load + from .astra_llm import ASTRALLM + from .astra_retriever import ASTRARetriever + # prepare chat engine chat_engine = None if len(self.collection_name) > 0: + from llama_index.core.chat_engine import ContextChatEngine chat_engine = ContextChatEngine.from_defaults( llm=ASTRALLM(ten=ten), retriever=ASTRARetriever(ten=ten, coll=self.collection_name), @@ -229,6 +231,7 @@ def async_handle(self, ten: TenEnv): ), ) else: + from llama_index.core.chat_engine import SimpleChatEngine chat_engine = SimpleChatEngine.from_defaults( llm=ASTRALLM(ten=ten), system_prompt=(