Merge pull request #383 from TEN-framework/fix_astra
fix(): replace astra
sunshinexcode authored Nov 7, 2024
2 parents c188473 + 4f09b19 commit 9a168f5
Showing 5 changed files with 39 additions and 39 deletions.
agents/property.json (6 changes: 3 additions & 3 deletions)
@@ -132,7 +132,7 @@
   "agora_asr_vendor_name": "microsoft",
   "agora_asr_vendor_region": "$AZURE_STT_REGION",
   "app_id": "$AGORA_APP_ID",
-  "channel": "astra_agents_test",
+  "channel": "ten_agent_test",
   "enable_agora_asr": true,
   "publish_audio": true,
   "publish_data": true,
@@ -189,7 +189,7 @@
   "name": "fashionai",
   "property": {
     "app_id": "$AGORA_APP_ID",
-    "channel": "astra_agents_test",
+    "channel": "ten_agents_test",
     "stream_id": 12345,
     "token": "",
     "service_id": "agora"
@@ -693,7 +693,7 @@
   "property": {
     "app_id": "${env:AGORA_APP_ID}",
     "token": "<agora_token>",
-    "channel": "astra_agents_test",
+    "channel": "ten_agents_test",
     "stream_id": 1234,
     "remote_stream_id": 123,
     "subscribe_audio": true,
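All three hunks in this file retire the old astra_agents_test channel name. Note the two placeholder syntaxes in play ($AGORA_APP_ID and ${env:AGORA_APP_ID}), both presumably resolved from the environment at startup. As a rough illustration of how such substitution could work, a minimal sketch (hypothetical helper, not the TEN framework's actual loader):

    # Hypothetical loader sketch -- NOT TEN's real implementation. Expands
    # "$VAR" and "${env:VAR}" placeholders found in string values.
    import json
    import os
    import re

    _PLACEHOLDER = re.compile(r"\$\{env:(\w+)\}|\$(\w+)")

    def _resolve(value: str) -> str:
        # Replace each placeholder with the matching environment variable,
        # defaulting to an empty string when the variable is unset.
        return _PLACEHOLDER.sub(
            lambda m: os.environ.get(m.group(1) or m.group(2), ""), value
        )

    def load_property(path: str) -> dict:
        with open(path, encoding="utf-8") as f:
            data = json.load(f)

        def walk(node):
            # Substitute placeholders in every string leaf of the JSON tree.
            if isinstance(node, dict):
                return {k: walk(v) for k, v in node.items()}
            if isinstance(node, list):
                return [walk(v) for v in node]
            return _resolve(node) if isinstance(node, str) else node

        return walk(data)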
@@ -202,16 +202,16 @@ def async_handle(self, ten: TenEnv):
         logger.info("process input text [%s] ts [%s]", input_text, ts)
 
         # lazy import packages which requires long time to load
-        from .astra_llm import ASTRALLM
-        from .astra_retriever import ASTRARetriever
+        from .llama_llm import LlamaLLM
+        from .llama_retriever import LlamaRetriever
 
         # prepare chat engine
         chat_engine = None
         if len(self.collection_name) > 0:
             from llama_index.core.chat_engine import ContextChatEngine
             chat_engine = ContextChatEngine.from_defaults(
-                llm=ASTRALLM(ten=ten),
-                retriever=ASTRARetriever(ten=ten, coll=self.collection_name),
+                llm=LlamaLLM(ten=ten),
+                retriever=LlamaRetriever(ten=ten, coll=self.collection_name),
                 memory=self.chat_memory,
                 system_prompt=(
                     # "You are an expert Q&A system that is trusted around the world.\n"
@@ -233,7 +233,7 @@
         else:
             from llama_index.core.chat_engine import SimpleChatEngine
             chat_engine = SimpleChatEngine.from_defaults(
-                llm=ASTRALLM(ten=ten),
+                llm=LlamaLLM(ten=ten),
                 system_prompt=(
                     "You are a voice assistant who talks in a conversational way and can chat with me like my friends. \n"
                     "I will speak to you in English or Chinese, and you will answer in the corrected and improved version of my text with the language I use. \n"
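The renamed classes plug into llama_index's standard ChatEngine interface, so the call site does not change. For orientation, a typical streaming use of the engine built above might look like this (illustrative only; the surrounding extension code is collapsed in this view):

    # Illustrative only: stream_chat/response_gen are standard llama_index
    # ChatEngine APIs, but this exact call site is not shown in the diff.
    response = chat_engine.stream_chat(input_text)
    for delta in response.response_gen:
        print(delta, end="", flush=True)  # e.g. forward each token to TTS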
astra_embedding.py → llama_embedding.py
@@ -16,17 +16,17 @@ def embed_from_resp(cmd_result: CmdResult) -> List[float]:
     return json.loads(embedding_output_json)
 
 
-class ASTRAEmbedding(BaseEmbedding):
+class LlamaEmbedding(BaseEmbedding):
     ten: Any
 
     def __init__(self, ten):
-        """Creates a new ASTRA embedding interface."""
+        """Creates a new Llama embedding interface."""
         super().__init__()
         self.ten = ten
 
     @classmethod
     def class_name(cls) -> str:
-        return "astra_embedding"
+        return "llama_embedding"
 
     async def _aget_query_embedding(self, query: str) -> List[float]:
         return self._get_query_embedding(query)
@@ -36,7 +36,7 @@ async def _aget_text_embedding(self, text: str) -> List[float]:
 
     def _get_query_embedding(self, query: str) -> List[float]:
         logger.info(
-            "ASTRAEmbedding generate embeddings for the query: {}".format(query)
+            "LlamaEmbedding generate embeddings for the query: {}".format(query)
         )
         wait_event = threading.Event()
         resp: List[float]
@@ -45,7 +45,7 @@ def callback(_, result):
             nonlocal resp
             nonlocal wait_event
 
-            logger.debug("ASTRAEmbedding embedding received")
+            logger.debug("LlamaEmbedding embedding received")
             resp = embed_from_resp(result)
             wait_event.set()
 
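The callback above is one half of a blocking request/response round trip; the send/wait tail of _get_query_embedding is collapsed in this view. A sketch of its likely shape, assuming the TEN Cmd.create/send_cmd pattern used elsewhere in this diff (the command name "embed" and the "input" key are guesses):

    # Sketch of the collapsed tail of _get_query_embedding; the command and
    # property names here are assumptions, not taken from the diff.
    cmd = Cmd.create("embed")
    cmd.set_property_string("input", query)
    self.ten.send_cmd(cmd, callback)  # callback fills `resp`, sets the event
    wait_event.wait()                 # block this thread until the reply
    return resp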
astra_llm.py → llama_llm.py
@@ -19,7 +19,7 @@
 from ten import Cmd, StatusCode, CmdResult
 
 
-def chat_from_astra_response(cmd_result: CmdResult) -> ChatResponse:
+def chat_from_llama_response(cmd_result: CmdResult) -> ChatResponse:
     status = cmd_result.get_status_code()
     if status != StatusCode.OK:
         return None
@@ -36,11 +36,11 @@ def _messages_str_from_chat_messages(messages: Sequence[ChatMessage]) -> str:
     return json.dumps(messages_list, ensure_ascii=False)
 
 
-class ASTRALLM(CustomLLM):
+class LlamaLLM(CustomLLM):
     ten: Any
 
     def __init__(self, ten):
-        """Creates a new ASTRA model interface."""
+        """Creates a new Llama model interface."""
         super().__init__()
         self.ten = ten
 
@@ -50,22 +50,22 @@ def metadata(self) -> LLMMetadata:
             # TODO: fix metadata
             context_window=1024,
             num_output=512,
-            model_name="astra_llm",
+            model_name="llama_llm",
             is_chat_model=True,
         )
 
     @llm_chat_callback()
     def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
-        logger.debug("ASTRALLM chat start")
+        logger.debug("LlamaLLM chat start")
 
         resp: ChatResponse
         wait_event = threading.Event()
 
         def callback(_, result):
-            logger.debug("ASTRALLM chat callback done")
+            logger.debug("LlamaLLM chat callback done")
             nonlocal resp
             nonlocal wait_event
-            resp = chat_from_astra_response(result)
+            resp = chat_from_llama_response(result)
             wait_event.set()
 
         messages_str = _messages_str_from_chat_messages(messages)
@@ -74,7 +74,7 @@ def callback(_, result):
         cmd.set_property_string("messages", messages_str)
         cmd.set_property_bool("stream", False)
         logger.info(
-            "ASTRALLM chat send_cmd {}, messages {}".format(
+            "LlamaLLM chat send_cmd {}, messages {}".format(
                 cmd.get_name(), messages_str
             )
         )
@@ -87,13 +87,13 @@ def callback(_, result):
     def complete(
         self, prompt: str, formatted: bool = False, **kwargs: Any
     ) -> CompletionResponse:
-        logger.warning("ASTRALLM complete hasn't been implemented yet")
+        logger.warning("LlamaLLM complete hasn't been implemented yet")
 
     @llm_chat_callback()
     def stream_chat(
         self, messages: Sequence[ChatMessage], **kwargs: Any
     ) -> ChatResponseGen:
-        logger.debug("ASTRALLM stream_chat start")
+        logger.debug("LlamaLLM stream_chat start")
 
         cur_tokens = ""
         resp_queue = queue.Queue()
@@ -115,12 +115,12 @@ def callback(_, result):
 
             status = result.get_status_code()
             if status != StatusCode.OK:
-                logger.warn("ASTRALLM stream_chat callback status {}".format(status))
+                logger.warn("LlamaLLM stream_chat callback status {}".format(status))
                 resp_queue.put(None)
                 return
 
             cur_tokens = result.get_property_string("text")
-            logger.debug("ASTRALLM stream_chat callback text [{}]".format(cur_tokens))
+            logger.debug("LlamaLLM stream_chat callback text [{}]".format(cur_tokens))
             resp_queue.put(cur_tokens)
             if result.get_is_final():
                 resp_queue.put(None)
@@ -131,7 +131,7 @@ def callback(_, result):
         cmd.set_property_string("messages", messages_str)
         cmd.set_property_bool("stream", True)
         logger.info(
-            "ASTRALLM stream_chat send_cmd {}, messages {}".format(
+            "LlamaLLM stream_chat send_cmd {}, messages {}".format(
                 cmd.get_name(), messages_str
            )
        )
@@ -141,8 +141,8 @@ def callback(_, result):
     def stream_complete(
         self, prompt: str, formatted: bool = False, **kwargs: Any
     ) -> CompletionResponseGen:
-        logger.warning("ASTRALLM stream_complete hasn't been implemented yet")
+        logger.warning("LlamaLLM stream_complete hasn't been implemented yet")
 
     @classmethod
     def class_name(cls) -> str:
-        return "astra_llm"
+        return "llama_llm"
astra_retriever.py → llama_retriever.py
@@ -5,7 +5,7 @@
 from llama_index.core.retrievers import BaseRetriever
 
 from .log import logger
-from .astra_embedding import ASTRAEmbedding
+from .llama_embedding import LlamaEmbedding
 from ten import (
     TenEnv,
     Cmd,
@@ -15,7 +15,7 @@
 
 
 def format_node_result(cmd_result: CmdResult) -> List[NodeWithScore]:
-    logger.info("ASTRARetriever retrieve response {}".format(cmd_result.to_json()))
+    logger.info("LlamaRetriever retrieve response {}".format(cmd_result.to_json()))
     status = cmd_result.get_status_code()
     try:
         contents_json = cmd_result.get_property_to_json("response")
@@ -45,21 +45,21 @@ def format_node_result(cmd_result: CmdResult) -> List[NodeWithScore]:
     return nodes
 
 
-class ASTRARetriever(BaseRetriever):
+class LlamaRetriever(BaseRetriever):
     ten: Any
-    embed_model: ASTRAEmbedding
+    embed_model: LlamaEmbedding
 
     def __init__(self, ten: TenEnv, coll: str):
         super().__init__()
         try:
             self.ten = ten
-            self.embed_model = ASTRAEmbedding(ten=ten)
+            self.embed_model = LlamaEmbedding(ten=ten)
             self.collection_name = coll
         except Exception as e:
-            logger.error(f"Failed to initialize ASTRARetriever: {e}")
+            logger.error(f"Failed to initialize LlamaRetriever: {e}")
 
     def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
-        logger.info("ASTRARetriever retrieve: {}".format(query_bundle.to_json))
+        logger.info("LlamaRetriever retrieve: {}".format(query_bundle.to_json))
 
         wait_event = threading.Event()
         resp: List[NodeWithScore] = []
@@ -69,7 +69,7 @@ def cmd_callback(_, result):
             nonlocal wait_event
             resp = format_node_result(result)
             wait_event.set()
-            logger.debug("ASTRARetriever callback done")
+            logger.debug("LlamaRetriever callback done")
 
         embedding = self.embed_model.get_query_embedding(query=query_bundle.query_str)
 
@@ -78,7 +78,7 @@ def cmd_callback(_, result):
         query_cmd.set_property_int("top_k", 3)  # TODO: configable
         query_cmd.set_property_from_json("embedding", json.dumps(embedding))
         logger.info(
-            "ASTRARetriever send_cmd, collection_name: {}, embedding len: {}".format(
+            "LlamaRetriever send_cmd, collection_name: {}, embedding len: {}".format(
                 self.collection_name, len(embedding)
             )
         )
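Like the other wrappers, _retrieve blocks on wait_event after sending the command. The collapsed tail presumably reads roughly as follows; the "query" command name and the "collection_name" key are inferred from the log message, not shown in the diff:

    # Assumed tail of _retrieve; only the top_k/embedding setup and the log
    # call are visible above. Command and property names are inferences.
    query_cmd = Cmd.create("query")
    query_cmd.set_property_string("collection_name", self.collection_name)
    query_cmd.set_property_int("top_k", 3)  # TODO: configable
    query_cmd.set_property_from_json("embedding", json.dumps(embedding))
    self.ten.send_cmd(query_cmd, cmd_callback)  # cmd_callback fills `resp`
    wait_event.wait()                           # block until results arrive
    return resp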
