From 4f09b19af36d2060bdb09615add68bbe71aebc77 Mon Sep 17 00:00:00 2001 From: sunshinexcode <24xinhui@163.com> Date: Thu, 7 Nov 2024 03:10:41 +0000 Subject: [PATCH] fix(): replace astra --- agents/property.json | 6 ++-- .../llama_index_chat_engine/extension.py | 14 ++++----- ...{astra_embedding.py => llama_embedding.py} | 10 +++---- .../{astra_llm.py => llama_llm.py} | 30 +++++++++---------- ...{astra_retriever.py => llama_retriever.py} | 18 +++++------ 5 files changed, 39 insertions(+), 39 deletions(-) rename agents/ten_packages/extension/llama_index_chat_engine/{astra_embedding.py => llama_embedding.py} (86%) rename agents/ten_packages/extension/llama_index_chat_engine/{astra_llm.py => llama_llm.py} (82%) rename agents/ten_packages/extension/llama_index_chat_engine/{astra_retriever.py => llama_retriever.py} (82%) diff --git a/agents/property.json b/agents/property.json index 860a3407..0c921e84 100644 --- a/agents/property.json +++ b/agents/property.json @@ -132,7 +132,7 @@ "agora_asr_vendor_name": "microsoft", "agora_asr_vendor_region": "$AZURE_STT_REGION", "app_id": "$AGORA_APP_ID", - "channel": "astra_agents_test", + "channel": "ten_agent_test", "enable_agora_asr": true, "publish_audio": true, "publish_data": true, @@ -189,7 +189,7 @@ "name": "fashionai", "property": { "app_id": "$AGORA_APP_ID", - "channel": "astra_agents_test", + "channel": "ten_agents_test", "stream_id": 12345, "token": "", "service_id": "agora" @@ -693,7 +693,7 @@ "property": { "app_id": "${env:AGORA_APP_ID}", "token": "", - "channel": "astra_agents_test", + "channel": "ten_agents_test", "stream_id": 1234, "remote_stream_id": 123, "subscribe_audio": true, diff --git a/agents/ten_packages/extension/llama_index_chat_engine/extension.py b/agents/ten_packages/extension/llama_index_chat_engine/extension.py index 53319197..c28871b1 100644 --- a/agents/ten_packages/extension/llama_index_chat_engine/extension.py +++ b/agents/ten_packages/extension/llama_index_chat_engine/extension.py @@ -202,16 +202,16 @@ def async_handle(self, ten: TenEnv): logger.info("process input text [%s] ts [%s]", input_text, ts) # lazy import packages which requires long time to load - from .astra_llm import ASTRALLM - from .astra_retriever import ASTRARetriever - + from .llama_llm import LlamaLLM + from .llama_retriever import LlamaRetriever + # prepare chat engine chat_engine = None if len(self.collection_name) > 0: - from llama_index.core.chat_engine import ContextChatEngine + from llama_index.core.chat_engine import ContextChatEngine chat_engine = ContextChatEngine.from_defaults( - llm=ASTRALLM(ten=ten), - retriever=ASTRARetriever(ten=ten, coll=self.collection_name), + llm=LlamaLLM(ten=ten), + retriever=LlamaRetriever(ten=ten, coll=self.collection_name), memory=self.chat_memory, system_prompt=( # "You are an expert Q&A system that is trusted around the world.\n" @@ -233,7 +233,7 @@ def async_handle(self, ten: TenEnv): else: from llama_index.core.chat_engine import SimpleChatEngine chat_engine = SimpleChatEngine.from_defaults( - llm=ASTRALLM(ten=ten), + llm=LlamaLLM(ten=ten), system_prompt=( "You are a voice assistant who talks in a conversational way and can chat with me like my friends. \n" "I will speak to you in English or Chinese, and you will answer in the corrected and improved version of my text with the language I use. \n" diff --git a/agents/ten_packages/extension/llama_index_chat_engine/astra_embedding.py b/agents/ten_packages/extension/llama_index_chat_engine/llama_embedding.py similarity index 86% rename from agents/ten_packages/extension/llama_index_chat_engine/astra_embedding.py rename to agents/ten_packages/extension/llama_index_chat_engine/llama_embedding.py index 0a60c4fb..7ed928ff 100644 --- a/agents/ten_packages/extension/llama_index_chat_engine/astra_embedding.py +++ b/agents/ten_packages/extension/llama_index_chat_engine/llama_embedding.py @@ -16,17 +16,17 @@ def embed_from_resp(cmd_result: CmdResult) -> List[float]: return json.loads(embedding_output_json) -class ASTRAEmbedding(BaseEmbedding): +class LlamaEmbedding(BaseEmbedding): ten: Any def __init__(self, ten): - """Creates a new ASTRA embedding interface.""" + """Creates a new Llama embedding interface.""" super().__init__() self.ten = ten @classmethod def class_name(cls) -> str: - return "astra_embedding" + return "llama_embedding" async def _aget_query_embedding(self, query: str) -> List[float]: return self._get_query_embedding(query) @@ -36,7 +36,7 @@ async def _aget_text_embedding(self, text: str) -> List[float]: def _get_query_embedding(self, query: str) -> List[float]: logger.info( - "ASTRAEmbedding generate embeddings for the query: {}".format(query) + "LlamaEmbedding generate embeddings for the query: {}".format(query) ) wait_event = threading.Event() resp: List[float] @@ -45,7 +45,7 @@ def callback(_, result): nonlocal resp nonlocal wait_event - logger.debug("ASTRAEmbedding embedding received") + logger.debug("LlamaEmbedding embedding received") resp = embed_from_resp(result) wait_event.set() diff --git a/agents/ten_packages/extension/llama_index_chat_engine/astra_llm.py b/agents/ten_packages/extension/llama_index_chat_engine/llama_llm.py similarity index 82% rename from agents/ten_packages/extension/llama_index_chat_engine/astra_llm.py rename to agents/ten_packages/extension/llama_index_chat_engine/llama_llm.py index 50f223da..9a5f83e9 100644 --- a/agents/ten_packages/extension/llama_index_chat_engine/astra_llm.py +++ b/agents/ten_packages/extension/llama_index_chat_engine/llama_llm.py @@ -19,7 +19,7 @@ from ten import Cmd, StatusCode, CmdResult -def chat_from_astra_response(cmd_result: CmdResult) -> ChatResponse: +def chat_from_llama_response(cmd_result: CmdResult) -> ChatResponse: status = cmd_result.get_status_code() if status != StatusCode.OK: return None @@ -36,11 +36,11 @@ def _messages_str_from_chat_messages(messages: Sequence[ChatMessage]) -> str: return json.dumps(messages_list, ensure_ascii=False) -class ASTRALLM(CustomLLM): +class LlamaLLM(CustomLLM): ten: Any def __init__(self, ten): - """Creates a new ASTRA model interface.""" + """Creates a new Llama model interface.""" super().__init__() self.ten = ten @@ -50,22 +50,22 @@ def metadata(self) -> LLMMetadata: # TODO: fix metadata context_window=1024, num_output=512, - model_name="astra_llm", + model_name="llama_llm", is_chat_model=True, ) @llm_chat_callback() def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: - logger.debug("ASTRALLM chat start") + logger.debug("LlamaLLM chat start") resp: ChatResponse wait_event = threading.Event() def callback(_, result): - logger.debug("ASTRALLM chat callback done") + logger.debug("LlamaLLM chat callback done") nonlocal resp nonlocal wait_event - resp = chat_from_astra_response(result) + resp = chat_from_llama_response(result) wait_event.set() messages_str = _messages_str_from_chat_messages(messages) @@ -74,7 +74,7 @@ def callback(_, result): cmd.set_property_string("messages", messages_str) cmd.set_property_bool("stream", False) logger.info( - "ASTRALLM chat send_cmd {}, messages {}".format( + "LlamaLLM chat send_cmd {}, messages {}".format( cmd.get_name(), messages_str ) ) @@ -87,13 +87,13 @@ def callback(_, result): def complete( self, prompt: str, formatted: bool = False, **kwargs: Any ) -> CompletionResponse: - logger.warning("ASTRALLM complete hasn't been implemented yet") + logger.warning("LlamaLLM complete hasn't been implemented yet") @llm_chat_callback() def stream_chat( self, messages: Sequence[ChatMessage], **kwargs: Any ) -> ChatResponseGen: - logger.debug("ASTRALLM stream_chat start") + logger.debug("LlamaLLM stream_chat start") cur_tokens = "" resp_queue = queue.Queue() @@ -115,12 +115,12 @@ def callback(_, result): status = result.get_status_code() if status != StatusCode.OK: - logger.warn("ASTRALLM stream_chat callback status {}".format(status)) + logger.warn("LlamaLLM stream_chat callback status {}".format(status)) resp_queue.put(None) return cur_tokens = result.get_property_string("text") - logger.debug("ASTRALLM stream_chat callback text [{}]".format(cur_tokens)) + logger.debug("LlamaLLM stream_chat callback text [{}]".format(cur_tokens)) resp_queue.put(cur_tokens) if result.get_is_final(): resp_queue.put(None) @@ -131,7 +131,7 @@ def callback(_, result): cmd.set_property_string("messages", messages_str) cmd.set_property_bool("stream", True) logger.info( - "ASTRALLM stream_chat send_cmd {}, messages {}".format( + "LlamaLLM stream_chat send_cmd {}, messages {}".format( cmd.get_name(), messages_str ) ) @@ -141,8 +141,8 @@ def callback(_, result): def stream_complete( self, prompt: str, formatted: bool = False, **kwargs: Any ) -> CompletionResponseGen: - logger.warning("ASTRALLM stream_complete hasn't been implemented yet") + logger.warning("LlamaLLM stream_complete hasn't been implemented yet") @classmethod def class_name(cls) -> str: - return "astra_llm" + return "llama_llm" diff --git a/agents/ten_packages/extension/llama_index_chat_engine/astra_retriever.py b/agents/ten_packages/extension/llama_index_chat_engine/llama_retriever.py similarity index 82% rename from agents/ten_packages/extension/llama_index_chat_engine/astra_retriever.py rename to agents/ten_packages/extension/llama_index_chat_engine/llama_retriever.py index 95790cba..5163f533 100644 --- a/agents/ten_packages/extension/llama_index_chat_engine/astra_retriever.py +++ b/agents/ten_packages/extension/llama_index_chat_engine/llama_retriever.py @@ -5,7 +5,7 @@ from llama_index.core.retrievers import BaseRetriever from .log import logger -from .astra_embedding import ASTRAEmbedding +from .llama_embedding import LlamaEmbedding from ten import ( TenEnv, Cmd, @@ -15,7 +15,7 @@ def format_node_result(cmd_result: CmdResult) -> List[NodeWithScore]: - logger.info("ASTRARetriever retrieve response {}".format(cmd_result.to_json())) + logger.info("LlamaRetriever retrieve response {}".format(cmd_result.to_json())) status = cmd_result.get_status_code() try: contents_json = cmd_result.get_property_to_json("response") @@ -45,21 +45,21 @@ def format_node_result(cmd_result: CmdResult) -> List[NodeWithScore]: return nodes -class ASTRARetriever(BaseRetriever): +class LlamaRetriever(BaseRetriever): ten: Any - embed_model: ASTRAEmbedding + embed_model: LlamaEmbedding def __init__(self, ten: TenEnv, coll: str): super().__init__() try: self.ten = ten - self.embed_model = ASTRAEmbedding(ten=ten) + self.embed_model = LlamaEmbedding(ten=ten) self.collection_name = coll except Exception as e: - logger.error(f"Failed to initialize ASTRARetriever: {e}") + logger.error(f"Failed to initialize LlamaRetriever: {e}") def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: - logger.info("ASTRARetriever retrieve: {}".format(query_bundle.to_json)) + logger.info("LlamaRetriever retrieve: {}".format(query_bundle.to_json)) wait_event = threading.Event() resp: List[NodeWithScore] = [] @@ -69,7 +69,7 @@ def cmd_callback(_, result): nonlocal wait_event resp = format_node_result(result) wait_event.set() - logger.debug("ASTRARetriever callback done") + logger.debug("LlamaRetriever callback done") embedding = self.embed_model.get_query_embedding(query=query_bundle.query_str) @@ -78,7 +78,7 @@ def cmd_callback(_, result): query_cmd.set_property_int("top_k", 3) # TODO: configable query_cmd.set_property_from_json("embedding", json.dumps(embedding)) logger.info( - "ASTRARetriever send_cmd, collection_name: {}, embedding len: {}".format( + "LlamaRetriever send_cmd, collection_name: {}, embedding len: {}".format( self.collection_name, len(embedding) ) )