feat: lazy import to speed up loading (#261)
wangyoucao577 authored Aug 29, 2024
1 parent 88c734e · commit 09fdef1
Showing 5 changed files with 34 additions and 29 deletions.
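
The change is the same across all five files: imports of slow-to-load packages move from module scope into the function or method that first needs them, so merely importing each extension module becomes fast. A minimal sketch of the before/after, where heavy_sdk is a hypothetical stand-in for a slow-to-import dependency, not a package used in this repo:

# before: the cost is paid whenever this file is imported
import heavy_sdk

def handle(request):
    return heavy_sdk.process(request)

# after: the first call pays the one-time cost; later calls find the
# module already cached in sys.modules
def handle(request):
    import heavy_sdk
    return heavy_sdk.process(request)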
File 1 of 5 — AliPGDBExtension:

@@ -4,8 +4,6 @@
 import asyncio
 import os
 import json
-from .client import AliGPDBClient
-from .model import Model
 from ten import (
     Extension,
     TenEnv,
@@ -20,12 +18,6 @@
 import threading
 from datetime import datetime
 
-from alibabacloud_gpdb20160503.client import Client as gpdb20160503Client
-from alibabacloud_tea_openapi import models as open_api_models
-from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-from alibabacloud_tea_util import models as util_models
-from alibabacloud_tea_util.client import Client as UtilClient
-
 
 class AliPGDBExtension(Extension):
     def __init__(self, name):
@@ -37,7 +29,7 @@ def __init__(self, name):
         self.region_id = os.environ.get("ADBPG_INSTANCE_REGION")
         self.dbinstance_id = os.environ.get("ADBPG_INSTANCE_ID")
         self.endpoint = f"gpdb.aliyuncs.com"
-        self.client = None
+        self.model = None
         self.account = os.environ.get("ADBPG_ACCOUNT")
         self.account_password = os.environ.get("ADBPG_ACCOUNT_PASSWORD")
         self.namespace = os.environ.get("ADBPG_NAMESPACE")
@@ -92,9 +84,15 @@ def on_start(self, ten: TenEnv) -> None:
             self.endpoint = "gpdb.aliyuncs.com"
         else:
             self.endpoint = f"gpdb.{self.region_id}.aliyuncs.com"
-        self.client = AliGPDBClient(
+
+        # lazy import packages which requires long time to load
+        from .client import AliGPDBClient
+        from .model import Model
+
+        client = AliGPDBClient(
             self.access_key_id, self.access_key_secret, self.endpoint
         )
+        self.model = Model(self.region_id, self.dbinstance_id, client)
         self.thread = threading.Thread(
             target=asyncio.run, args=(self.__thread_routine(ten),)
         )
@@ -141,19 +139,18 @@ def on_cmd(self, ten: TenEnv, cmd: Cmd) -> None:
             ten.return_result(CmdResult.create(StatusCode.ERROR), cmd)
 
     async def async_create_collection(self, ten: TenEnv, cmd: Cmd):
-        m = Model(self.region_id, self.dbinstance_id, self.client)
         collection = cmd.get_property_string("collection_name")
         dimension = 1024
         try:
             dimension = cmd.get_property_int("dimension")
         except Exception as e:
             logger.warning(f"Error: {e}")
 
-        err = await m.create_collection_async(
+        err = await self.model.create_collection_async(
             self.account, self.account_password, self.namespace, collection
         )
         if err is None:
-            await m.create_vector_index_async(
+            await self.model.create_vector_index_async(
                 self.account,
                 self.account_password,
                 self.namespace,
@@ -166,14 +163,13 @@ async def async_create_collection(self, ten: TenEnv, cmd: Cmd):
 
     async def async_upsert_vector(self, ten: TenEnv, cmd: Cmd):
         start_time = datetime.now()
-        m = Model(self.region_id, self.dbinstance_id, self.client)
         collection = cmd.get_property_string("collection_name")
         file = cmd.get_property_string("file_name")
         content = cmd.get_property_string("content")
         obj = json.loads(content)
         rows = [(file, item["text"], item["embedding"]) for item in obj]
 
-        err = await m.upsert_collection_data_async(
+        err = await self.model.upsert_collection_data_async(
             collection, self.namespace, self.namespace_password, rows
         )
         logger.info(
@@ -192,12 +188,11 @@ async def async_upsert_vector(self, ten: TenEnv, cmd: Cmd):
 
     async def async_query_vector(self, ten: TenEnv, cmd: Cmd):
         start_time = datetime.now()
-        m = Model(self.region_id, self.dbinstance_id, self.client)
         collection = cmd.get_property_string("collection_name")
         embedding = cmd.get_property_to_json("embedding")
         top_k = cmd.get_property_int("top_k")
         vector = json.loads(embedding)
-        response, error = await m.query_collection_data_async(
+        response, error = await self.model.query_collection_data_async(
             collection, self.namespace, self.namespace_password, vector, top_k=top_k
         )
         logger.info(
@@ -212,15 +207,14 @@ async def async_query_vector(self, ten: TenEnv, cmd: Cmd):
         if error:
             return ten.return_result(CmdResult.create(StatusCode.ERROR), cmd)
         else:
-            body = m.parse_collection_data(response.body)
+            body = self.model.parse_collection_data(response.body)
             ret = CmdResult.create(StatusCode.OK)
             ret.set_property_from_json("response", body)
             ten.return_result(ret, cmd)
 
     async def async_delete_collection(self, ten: TenEnv, cmd: Cmd):
-        m = Model(self.region_id, self.dbinstance_id, self.client)
         collection = cmd.get_property_string("collection_name")
-        err = await m.delete_collection_async(
+        err = await self.model.delete_collection_async(
             self.account, self.account_password, self.namespace, collection
         )
         if err is None:
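
Note that this file changes two things at once: the SDK imports become lazy, and a single Model instance is now built in on_start and cached on self.model, replacing the per-command construction of Model in each async handler. The lazy imports themselves stay cheap after the first call because Python caches loaded modules; a small stdlib-only demonstration of that cache (timings vary by machine):

import importlib
import sys
import time

def timed_import(name: str) -> float:
    # measure one pass through the import machinery for `name`
    start = time.perf_counter()
    importlib.import_module(name)
    return time.perf_counter() - start

first = timed_import("decimal")   # real work: module not yet loaded
second = timed_import("decimal")  # served from the sys.modules cache
print(f"first: {first * 1e6:.0f}us, second: {second * 1e6:.0f}us")
assert "decimal" in sys.modules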

File 2 of 5 — DashScope embedding extension:

@@ -6,7 +6,6 @@
     CmdResult,
 )
 
-import dashscope
 import json
 from typing import Generator, List
 from http import HTTPStatus
@@ -45,6 +44,11 @@ def on_start(self, ten: TenEnv) -> None:
         self.api_key = self.get_property_string(ten, "api_key", self.api_key)
         self.model = self.get_property_string(ten, "model", self.api_key)
 
+
+        # lazy import packages which requires long time to load
+        global dashscope
+        import dashscope
+
         dashscope.api_key = self.api_key
 
         for i in range(self.parallel):
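
The global statement above is load-bearing: a bare import dashscope inside on_start would bind the name only in the method's local scope, and any other function in the module that references dashscope would raise NameError. With global, the import publishes the binding at module level, just as a top-of-file import would. A condensed sketch of the mechanics, assuming the dashscope package is installed (EmbeddingSketch is illustrative, not this repo's class):

class EmbeddingSketch:
    def on_start(self):
        # lazy import, published to the module's global namespace
        global dashscope
        import dashscope
        dashscope.api_key = "sk-..."  # placeholder credential

    def embed_later(self):
        # resolves through the module globals set up by on_start;
        # without the `global` statement above, this is a NameError
        return dashscope.api_key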

File 3 of 5 — llama_index document chunker:

@@ -14,8 +14,6 @@
 )
 from typing import List, Any
 from .log import logger
-from llama_index.core import SimpleDirectoryReader
-from llama_index.core.node_parser import SentenceSplitter
 import json
 from datetime import datetime
 import uuid, math
@@ -63,6 +61,9 @@ def generate_collection_name(self) -> str:
         return "coll_" + uuid.uuid1().hex.lower()
 
     def split(self, path: str) -> List[Any]:
+        # lazy import packages which requires long time to load
+        from llama_index.core import SimpleDirectoryReader
+        from llama_index.core.node_parser import SentenceSplitter
 
         # load pdf file by path
         documents = SimpleDirectoryReader(
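
Because these imports sit inside split(), only the code path that actually chunks a document pays the llama_index import cost, and the imported names are locals of split() rather than module attributes. Any other method that needs them must import them again, which is cheap once sys.modules is warm. A sketch of the scoping rule (ChunkerSketch is illustrative and assumes llama_index is installed):

class ChunkerSketch:
    def split(self, path: str):
        # bound as locals of split(), not as module or class attributes
        from llama_index.core import SimpleDirectoryReader
        from llama_index.core.node_parser import SentenceSplitter

        docs = SimpleDirectoryReader(input_files=[path]).load_data()
        return SentenceSplitter().get_nodes_from_documents(docs)

    def elsewhere(self):
        # SimpleDirectoryReader is not visible here; re-import if needed
        from llama_index.core import SimpleDirectoryReader
        return SimpleDirectoryReader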

File 4 of 5 — GeminiLLMExtension:

@@ -14,7 +14,6 @@
     StatusCode,
     CmdResult,
 )
-from .gemini_llm import GeminiLLM, GeminiLLMConfig
 from .log import logger
 from .utils import get_micro_ts, parse_sentence
 
@@ -45,6 +44,10 @@ class GeminiLLMExtension(Extension):
 
     def on_start(self, ten: TenEnv) -> None:
         logger.info("GeminiLLMExtension on_start")
+
+        # lazy import packages which requires long time to load
+        from .gemini_llm import GeminiLLM, GeminiLLMConfig
+
         # Prepare configuration
         gemini_llm_config = GeminiLLMConfig.default_config()
 
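
One trade-off of deferring imports this way: a missing or broken dependency no longer fails when the process loads the module, but at the first on_start that touches it. Where fail-fast matters, the deferred import can be wrapped explicitly; a sketch under that assumption, not code from this commit:

    def on_start(self, ten: TenEnv) -> None:
        try:
            # dependency problems now surface here, at extension
            # startup, instead of at process launch
            from .gemini_llm import GeminiLLM, GeminiLLMConfig
        except ImportError:
            logger.exception("failed to load gemini_llm dependencies")
            return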

File 5 of 5 — llama_index chat engine (ASTRA LLM) extension:

@@ -14,13 +14,8 @@
     CmdResult,
 )
 from .log import logger
-from .astra_llm import ASTRALLM
-from .astra_retriever import ASTRARetriever
 import queue, threading
 from datetime import datetime
-from llama_index.core.chat_engine import SimpleChatEngine, ContextChatEngine
-from llama_index.core.storage.chat_store import SimpleChatStore
-from llama_index.core.memory import ChatMemoryBuffer
 
 PROPERTY_CHAT_MEMORY_TOKEN_LIMIT = "chat_memory_token_limit"
 PROPERTY_GREETING = "greeting"
@@ -79,6 +74,8 @@ def on_start(self, ten: TenEnv) -> None:
         self.thread.start()
 
         # enable chat memory
+        from llama_index.core.storage.chat_store import SimpleChatStore
+        from llama_index.core.memory import ChatMemoryBuffer
         self.chat_memory = ChatMemoryBuffer.from_defaults(
             token_limit=self.chat_memory_token_limit,
             chat_store=SimpleChatStore(),
@@ -204,9 +201,14 @@ def async_handle(self, ten: TenEnv):
 
         logger.info("process input text [%s] ts [%s]", input_text, ts)
 
+        # lazy import packages which requires long time to load
+        from .astra_llm import ASTRALLM
+        from .astra_retriever import ASTRARetriever
+
         # prepare chat engine
         chat_engine = None
         if len(self.collection_name) > 0:
+            from llama_index.core.chat_engine import ContextChatEngine
             chat_engine = ContextChatEngine.from_defaults(
                 llm=ASTRALLM(ten=ten),
                 retriever=ASTRARetriever(ten=ten, coll=self.collection_name),
@@ -229,6 +231,7 @@
                 ),
             )
         else:
+            from llama_index.core.chat_engine import SimpleChatEngine
            chat_engine = SimpleChatEngine.from_defaults(
                 llm=ASTRALLM(ten=ten),
                 system_prompt=(
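
To decide which imports are worth deferring in the first place, CPython 3.7+ can profile them: run the interpreter with -X importtime and it prints, to stderr, self and cumulative microseconds for every module imported. A rough way to use it (the module name is just an example):

python -X importtime -c "import llama_index.core" 2> importtime.log
sort -t '|' -k 2 -n importtime.log | tail  # biggest cumulative costs last

The packages this commit defers (dashscope, llama_index, and the alibabacloud SDK modules) are exactly the kind that tends to dominate such a trace.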
