Lazy import slow python modules to speed up loading #261

Merged: 1 commit, Aug 29, 2024

Changes from all commits
File 1 of 5: AliPGDBExtension (Aliyun AnalyticDB vector storage)
@@ -4,8 +4,6 @@
 import asyncio
 import os
 import json
-from .client import AliGPDBClient
-from .model import Model
 from ten import (
     Extension,
     TenEnv,
@@ -20,12 +18,6 @@
 import threading
 from datetime import datetime
 
-from alibabacloud_gpdb20160503.client import Client as gpdb20160503Client
-from alibabacloud_tea_openapi import models as open_api_models
-from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-from alibabacloud_tea_util import models as util_models
-from alibabacloud_tea_util.client import Client as UtilClient
-
 
 class AliPGDBExtension(Extension):
     def __init__(self, name):
@@ -37,7 +29,7 @@ def __init__(self, name):
         self.region_id = os.environ.get("ADBPG_INSTANCE_REGION")
         self.dbinstance_id = os.environ.get("ADBPG_INSTANCE_ID")
         self.endpoint = f"gpdb.aliyuncs.com"
-        self.client = None
+        self.model = None
         self.account = os.environ.get("ADBPG_ACCOUNT")
         self.account_password = os.environ.get("ADBPG_ACCOUNT_PASSWORD")
         self.namespace = os.environ.get("ADBPG_NAMESPACE")
@@ -92,9 +84,15 @@ def on_start(self, ten: TenEnv) -> None:
             self.endpoint = "gpdb.aliyuncs.com"
         else:
             self.endpoint = f"gpdb.{self.region_id}.aliyuncs.com"
-        self.client = AliGPDBClient(
+
+        # lazy import packages which requires long time to load
+        from .client import AliGPDBClient
+        from .model import Model
+
+        client = AliGPDBClient(
             self.access_key_id, self.access_key_secret, self.endpoint
         )
+        self.model = Model(self.region_id, self.dbinstance_id, client)
         self.thread = threading.Thread(
             target=asyncio.run, args=(self.__thread_routine(ten),)
         )
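
The hunk above is the heart of the PR: the local client/model wrappers (and, transitively, the alibabacloud_* SDK) are now imported inside `on_start()` rather than at module scope, so merely loading the extension module no longer pays their import cost. A minimal, self-contained sketch of the pattern, assuming nothing beyond the stdlib (the class name is hypothetical, and `json` stands in for the slow packages):

```python
class LazyExtension:
    def __init__(self) -> None:
        self.model = None  # __init__ stays import-free, so construction is instant

    def on_start(self) -> None:
        # Heavy dependencies are imported on first use instead of at module
        # import time; stdlib json stands in for the alibabacloud_* packages.
        import json

        self.model = json  # stand-in for Model(region_id, dbinstance_id, client)

    def handle(self) -> str:
        assert self.model is not None, "call on_start() first"
        return self.model.dumps({"ok": True})


ext = LazyExtension()  # cheap: no heavy imports have run yet
ext.on_start()         # the import cost is paid exactly once, here
print(ext.handle())
```

Because Python caches every imported module in `sys.modules`, calling `on_start()` again (or importing the same module elsewhere) would not repeat the work.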
@@ -141,19 +139,18 @@ def on_cmd(self, ten: TenEnv, cmd: Cmd) -> None:
             ten.return_result(CmdResult.create(StatusCode.ERROR), cmd)
 
     async def async_create_collection(self, ten: TenEnv, cmd: Cmd):
-        m = Model(self.region_id, self.dbinstance_id, self.client)
         collection = cmd.get_property_string("collection_name")
         dimension = 1024
         try:
             dimension = cmd.get_property_int("dimension")
         except Exception as e:
             logger.warning(f"Error: {e}")
 
-        err = await m.create_collection_async(
+        err = await self.model.create_collection_async(
             self.account, self.account_password, self.namespace, collection
         )
         if err is None:
-            await m.create_vector_index_async(
+            await self.model.create_vector_index_async(
                 self.account,
                 self.account_password,
                 self.namespace,
@@ -166,14 +163,13 @@ async def async_create_collection(self, ten: TenEnv, cmd: Cmd):
 
     async def async_upsert_vector(self, ten: TenEnv, cmd: Cmd):
         start_time = datetime.now()
-        m = Model(self.region_id, self.dbinstance_id, self.client)
         collection = cmd.get_property_string("collection_name")
         file = cmd.get_property_string("file_name")
         content = cmd.get_property_string("content")
         obj = json.loads(content)
         rows = [(file, item["text"], item["embedding"]) for item in obj]
 
-        err = await m.upsert_collection_data_async(
+        err = await self.model.upsert_collection_data_async(
             collection, self.namespace, self.namespace_password, rows
         )
         logger.info(
@@ -192,12 +188,11 @@ async def async_upsert_vector(self, ten: TenEnv, cmd: Cmd):
 
     async def async_query_vector(self, ten: TenEnv, cmd: Cmd):
         start_time = datetime.now()
-        m = Model(self.region_id, self.dbinstance_id, self.client)
         collection = cmd.get_property_string("collection_name")
         embedding = cmd.get_property_to_json("embedding")
         top_k = cmd.get_property_int("top_k")
         vector = json.loads(embedding)
-        response, error = await m.query_collection_data_async(
+        response, error = await self.model.query_collection_data_async(
             collection, self.namespace, self.namespace_password, vector, top_k=top_k
         )
         logger.info(
@@ -212,15 +207,14 @@ async def async_query_vector(self, ten: TenEnv, cmd: Cmd):
         if error:
             return ten.return_result(CmdResult.create(StatusCode.ERROR), cmd)
         else:
-            body = m.parse_collection_data(response.body)
+            body = self.model.parse_collection_data(response.body)
             ret = CmdResult.create(StatusCode.OK)
             ret.set_property_from_json("response", body)
             ten.return_result(ret, cmd)
 
     async def async_delete_collection(self, ten: TenEnv, cmd: Cmd):
-        m = Model(self.region_id, self.dbinstance_id, self.client)
         collection = cmd.get_property_string("collection_name")
-        err = await m.delete_collection_async(
+        err = await self.model.delete_collection_async(
             self.account, self.account_password, self.namespace, collection
         )
         if err is None:
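
A second change runs through the rest of this file: each handler used to build a fresh `Model` per request (`m = Model(...)`), and now a single instance is created in `on_start()` and shared as `self.model`. A runnable sketch of why that matters, with hypothetical stand-in classes whose constructor sleeps to simulate setup cost:

```python
import time


class Model:
    """Stand-in for the real client/model; construction is not free."""

    def __init__(self):
        time.sleep(0.01)  # simulate connection/config setup


class Extension:
    def __init__(self):
        self.model = None  # populated once in on_start(), as in the PR

    def on_start(self):
        self.model = Model()

    def handle(self):
        return self.model  # every handler reuses the shared instance


ext = Extension()
ext.on_start()

start = time.perf_counter()
for _ in range(100):
    ext.handle()
print(f"100 calls, shared model:   {time.perf_counter() - start:.3f}s")

start = time.perf_counter()
for _ in range(100):
    Model()  # the old per-request construction
print(f"100 calls, per-call model: {time.perf_counter() - start:.3f}s")
```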
File 2 of 5: extension using dashscope
@@ -6,7 +6,6 @@
     CmdResult,
 )
 
-import dashscope
 import json
 from typing import Generator, List
 from http import HTTPStatus
@@ -45,6 +44,11 @@ def on_start(self, ten: TenEnv) -> None:
         self.api_key = self.get_property_string(ten, "api_key", self.api_key)
         self.model = self.get_property_string(ten, "model", self.api_key)
 
+
+        # lazy import packages which requires long time to load
+        global dashscope
+        import dashscope
+
         dashscope.api_key = self.api_key
 
         for i in range(self.parallel):
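
The `global dashscope` line is load-bearing: a bare `import dashscope` inside `on_start()` would bind the name only in that function, and any other method in the module that references `dashscope` would fail with `NameError`. Declaring the name `global` first makes the lazy import rebind the module-level name, so code written against a top-level import keeps working unchanged. A sketch of the trick, with stdlib `json` standing in for `dashscope`:

```python
def on_start(api_key: str) -> None:
    # Without the global declaration, the import below would create a
    # function-local name, invisible to later_call().
    global json
    import json

    print(json.dumps({"api_key_set": bool(api_key)}))


def later_call() -> None:
    # Relies on the module-level name bound inside on_start().
    print(json.loads('{"ok": true}'))


on_start("dummy-key")
later_call()
```

One caveat of the pattern: any code path that touches the module-level name before `on_start()` runs will raise `NameError`, so the lazy import must be guaranteed to run first.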
File 3 of 5: document chunker using llama_index
@@ -14,8 +14,6 @@
 )
 from typing import List, Any
 from .log import logger
-from llama_index.core import SimpleDirectoryReader
-from llama_index.core.node_parser import SentenceSplitter
 import json
 from datetime import datetime
 import uuid, math
@@ -63,6 +61,9 @@ def generate_collection_name(self) -> str:
         return "coll_" + uuid.uuid1().hex.lower()
 
     def split(self, path: str) -> List[Any]:
+        # lazy import packages which requires long time to load
+        from llama_index.core import SimpleDirectoryReader
+        from llama_index.core.node_parser import SentenceSplitter
 
         # load pdf file by path
         documents = SimpleDirectoryReader(
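
Here the llama_index imports move into `split()`, so the cost is paid by the first call to `split()` instead of at extension load, and only if splitting is ever needed; repeat calls resolve from the `sys.modules` cache and are effectively free. A runnable sketch of that first-call trade-off (`csv` stands in for `llama_index.core`):

```python
import sys
import time


def split_once():
    # The import below runs when this function is first called, not when
    # the module loads.
    print(f"csv already in sys.modules? {'csv' in sys.modules}")
    start = time.perf_counter()
    import csv

    print(f"import statement took {time.perf_counter() - start:.6f}s")
    return csv


split_once()  # first call pays the real import cost
split_once()  # second call hits the sys.modules cache
```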
File 4 of 5: GeminiLLMExtension
@@ -14,7 +14,6 @@
     StatusCode,
     CmdResult,
 )
-from .gemini_llm import GeminiLLM, GeminiLLMConfig
 from .log import logger
 from .utils import get_micro_ts, parse_sentence
 
@@ -45,6 +44,10 @@ class GeminiLLMExtension(Extension):
 
     def on_start(self, ten: TenEnv) -> None:
         logger.info("GeminiLLMExtension on_start")
+
+        # lazy import packages which requires long time to load
+        from .gemini_llm import GeminiLLM, GeminiLLMConfig
+
         # Prepare configuration
         gemini_llm_config = GeminiLLMConfig.default_config()
 
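
To decide which imports are worth deferring like this, one option is CPython's `-X importtime` flag (Python 3.7+), which prints self and cumulative import time per module to stderr. A small sketch that shells out to it; the module being measured (`json`) is just a placeholder for whichever dependency you suspect is slow:

```python
import subprocess
import sys

result = subprocess.run(
    [sys.executable, "-X", "importtime", "-c", "import json"],
    capture_output=True,
    text=True,
)
# The report is written to stderr; show the last few (usually slowest) entries.
for line in result.stderr.strip().splitlines()[-5:]:
    print(line)
```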
File 5 of 5: llama_index chat engine extension
@@ -14,13 +14,8 @@
     CmdResult,
 )
 from .log import logger
-from .astra_llm import ASTRALLM
-from .astra_retriever import ASTRARetriever
 import queue, threading
 from datetime import datetime
-from llama_index.core.chat_engine import SimpleChatEngine, ContextChatEngine
-from llama_index.core.storage.chat_store import SimpleChatStore
-from llama_index.core.memory import ChatMemoryBuffer
 
 PROPERTY_CHAT_MEMORY_TOKEN_LIMIT = "chat_memory_token_limit"
 PROPERTY_GREETING = "greeting"
@@ -79,6 +74,8 @@ def on_start(self, ten: TenEnv) -> None:
         self.thread.start()
 
         # enable chat memory
+        from llama_index.core.storage.chat_store import SimpleChatStore
+        from llama_index.core.memory import ChatMemoryBuffer
         self.chat_memory = ChatMemoryBuffer.from_defaults(
             token_limit=self.chat_memory_token_limit,
             chat_store=SimpleChatStore(),
@@ -204,9 +201,14 @@ def async_handle(self, ten: TenEnv):
 
         logger.info("process input text [%s] ts [%s]", input_text, ts)
 
+        # lazy import packages which requires long time to load
+        from .astra_llm import ASTRALLM
+        from .astra_retriever import ASTRARetriever
+
         # prepare chat engine
         chat_engine = None
         if len(self.collection_name) > 0:
+            from llama_index.core.chat_engine import ContextChatEngine
             chat_engine = ContextChatEngine.from_defaults(
                 llm=ASTRALLM(ten=ten),
                 retriever=ASTRARetriever(ten=ten, coll=self.collection_name),
@@ -229,6 +231,7 @@
                 ),
             )
         else:
+            from llama_index.core.chat_engine import SimpleChatEngine
             chat_engine = SimpleChatEngine.from_defaults(
                 llm=ASTRALLM(ten=ten),
                 system_prompt=(
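
The last two hunks push laziness one level further: `ContextChatEngine` is imported only on the retrieval branch and `SimpleChatEngine` only on the fallback branch, so a process that never takes a branch never imports that branch's dependencies. A toy sketch of branch-local imports, with stdlib modules standing in for the llama_index engines:

```python
def make_engine(use_context: bool):
    if use_context:
        # Imported only when this branch runs; json and csv stand in for
        # ContextChatEngine and SimpleChatEngine.
        import json as engine
    else:
        import csv as engine
    return engine


print(make_engine(True).__name__)   # only the taken branch's module is imported
print(make_engine(False).__name__)
```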