fix(): replace astra #383

Merged 1 commit on Nov 7, 2024
agents/property.json (6 changes: 3 additions & 3 deletions)
@@ -132,7 +132,7 @@
 "agora_asr_vendor_name": "microsoft",
 "agora_asr_vendor_region": "$AZURE_STT_REGION",
 "app_id": "$AGORA_APP_ID",
-"channel": "astra_agents_test",
+"channel": "ten_agent_test",
 "enable_agora_asr": true,
 "publish_audio": true,
 "publish_data": true,
@@ -189,7 +189,7 @@
 "name": "fashionai",
 "property": {
 "app_id": "$AGORA_APP_ID",
-"channel": "astra_agents_test",
+"channel": "ten_agents_test",
 "stream_id": 12345,
 "token": "",
 "service_id": "agora"
@@ -693,7 +693,7 @@
 "property": {
 "app_id": "${env:AGORA_APP_ID}",
 "token": "<agora_token>",
-"channel": "astra_agents_test",
+"channel": "ten_agents_test",
 "stream_id": 1234,
 "remote_stream_id": 123,
 "subscribe_audio": true,
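These three hunks repoint the default channel from astra_agents_test to the new TEN naming: ten_agent_test in the first hunk, ten_agents_test in the other two. The surrounding config mixes two placeholder styles, $AGORA_APP_ID and ${env:AGORA_APP_ID}, both resolved from the environment at startup. As a rough, hypothetical sketch of how such placeholders could be expanded when loading property.json (this helper is illustrative, not code from this repo):

import json
import os
import re

# Hypothetical helper, not part of this repo: expands "$VAR" and
# "${env:VAR}" placeholders in a property.json-style config tree.
_PLACEHOLDER = re.compile(r"\$\{env:(\w+)\}|\$(\w+)")

def _resolve(value: str) -> str:
    def sub(match: re.Match) -> str:
        name = match.group(1) or match.group(2)
        return os.environ.get(name, "")  # unset vars become empty strings
    return _PLACEHOLDER.sub(sub, value)

def load_properties(path: str) -> dict:
    with open(path) as f:
        raw = json.load(f)

    def walk(node):
        if isinstance(node, dict):
            return {key: walk(val) for key, val in node.items()}
        if isinstance(node, list):
            return [walk(val) for val in node]
        return _resolve(node) if isinstance(node, str) else node

    return walk(raw)  # literal values like "ten_agent_test" pass through

The remaining files in the PR rename the ASTRA-prefixed classes themselves; the first set of hunks below comes from the chat-engine extension that imports them.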
@@ -202,16 +202,16 @@ def async_handle(self, ten: TenEnv):
         logger.info("process input text [%s] ts [%s]", input_text, ts)
 
         # lazy import packages which requires long time to load
-        from .astra_llm import ASTRALLM
-        from .astra_retriever import ASTRARetriever
+        from .llama_llm import LlamaLLM
+        from .llama_retriever import LlamaRetriever
 
         # prepare chat engine
         chat_engine = None
         if len(self.collection_name) > 0:
             from llama_index.core.chat_engine import ContextChatEngine
             chat_engine = ContextChatEngine.from_defaults(
-                llm=ASTRALLM(ten=ten),
-                retriever=ASTRARetriever(ten=ten, coll=self.collection_name),
+                llm=LlamaLLM(ten=ten),
+                retriever=LlamaRetriever(ten=ten, coll=self.collection_name),
                 memory=self.chat_memory,
                 system_prompt=(
                     # "You are an expert Q&A system that is trusted around the world.\n"
@@ -233,7 +233,7 @@
         else:
             from llama_index.core.chat_engine import SimpleChatEngine
             chat_engine = SimpleChatEngine.from_defaults(
-                llm=ASTRALLM(ten=ten),
+                llm=LlamaLLM(ten=ten),
                 system_prompt=(
                     "You are a voice assistant who talks in a conversational way and can chat with me like my friends. \n"
                     "I will speak to you in English or Chinese, and you will answer in the corrected and improved version of my text with the language I use. \n"
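Only the class and module names change in this file; the engine-selection logic is untouched: a ContextChatEngine backed by a retriever when a collection is configured, otherwise a plain SimpleChatEngine. A condensed sketch of that unchanged branch, assuming the renamed modules introduced by this PR and a ten handle (not a verbatim copy of the extension code):

# Sketch only: mirrors the unchanged selection logic above, using the
# classes as renamed in this PR. Assumes it lives in the same package
# as llama_llm.py / llama_retriever.py.
from llama_index.core.chat_engine import ContextChatEngine, SimpleChatEngine

from .llama_llm import LlamaLLM
from .llama_retriever import LlamaRetriever

def build_chat_engine(ten, collection_name: str, chat_memory, system_prompt: str):
    if len(collection_name) > 0:
        # RAG path: ground answers in the configured collection.
        return ContextChatEngine.from_defaults(
            llm=LlamaLLM(ten=ten),
            retriever=LlamaRetriever(ten=ten, coll=collection_name),
            memory=chat_memory,
            system_prompt=system_prompt,
        )
    # No collection configured: plain conversational engine.
    return SimpleChatEngine.from_defaults(
        llm=LlamaLLM(ten=ten),
        system_prompt=system_prompt,
    )

The next three files are the renamed modules themselves.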
llama_embedding.py

@@ -16,17 +16,17 @@ def embed_from_resp(cmd_result: CmdResult) -> List[float]:
     return json.loads(embedding_output_json)
 
 
-class ASTRAEmbedding(BaseEmbedding):
+class LlamaEmbedding(BaseEmbedding):
     ten: Any
 
     def __init__(self, ten):
-        """Creates a new ASTRA embedding interface."""
+        """Creates a new Llama embedding interface."""
         super().__init__()
         self.ten = ten
 
     @classmethod
     def class_name(cls) -> str:
-        return "astra_embedding"
+        return "llama_embedding"
 
     async def _aget_query_embedding(self, query: str) -> List[float]:
         return self._get_query_embedding(query)
@@ -36,7 +36,7 @@ async def _aget_text_embedding(self, text: str) -> List[float]:
 
     def _get_query_embedding(self, query: str) -> List[float]:
         logger.info(
-            "ASTRAEmbedding generate embeddings for the query: {}".format(query)
+            "LlamaEmbedding generate embeddings for the query: {}".format(query)
         )
         wait_event = threading.Event()
         resp: List[float]
@@ -45,7 +45,7 @@ def callback(_, result):
             nonlocal resp
             nonlocal wait_event
 
-            logger.debug("ASTRAEmbedding embedding received")
+            logger.debug("LlamaEmbedding embedding received")
             resp = embed_from_resp(result)
             wait_event.set()
 
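Past the rename, llama_embedding.py keeps the same concurrency shape: _get_query_embedding sends an asynchronous cmd to the runtime and parks on a threading.Event until the callback delivers the vector. A self-contained sketch of that callback-to-synchronous bridge; the timer stands in for the ten runtime, and all names here are illustrative:

import threading
from typing import Callable, List

def fetch_blocking(start_async: Callable[[Callable[[List[float]], None]], None]) -> List[float]:
    # Bridge a callback-style API into a synchronous return value,
    # mirroring the Event pattern in _get_query_embedding.
    wait_event = threading.Event()
    resp: List[float] = []

    def callback(result: List[float]) -> None:
        nonlocal resp
        resp = result      # hand the payload to the waiting thread
        wait_event.set()   # release the blocked caller

    start_async(callback)
    wait_event.wait()      # block until the callback fires
    return resp

# Usage: a timer plays the role of the runtime delivering an embedding.
fake_runtime = lambda cb: threading.Timer(0.05, cb, args=([0.1, 0.2, 0.3],)).start()
print(fetch_blocking(fake_runtime))  # [0.1, 0.2, 0.3]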
llama_llm.py

@@ -19,7 +19,7 @@
 from ten import Cmd, StatusCode, CmdResult
 
 
-def chat_from_astra_response(cmd_result: CmdResult) -> ChatResponse:
+def chat_from_llama_response(cmd_result: CmdResult) -> ChatResponse:
     status = cmd_result.get_status_code()
     if status != StatusCode.OK:
         return None
@@ -36,11 +36,11 @@ def _messages_str_from_chat_messages(messages: Sequence[ChatMessage]) -> str:
     return json.dumps(messages_list, ensure_ascii=False)
 
 
-class ASTRALLM(CustomLLM):
+class LlamaLLM(CustomLLM):
     ten: Any
 
     def __init__(self, ten):
-        """Creates a new ASTRA model interface."""
+        """Creates a new Llama model interface."""
         super().__init__()
         self.ten = ten
 
@@ -50,22 +50,22 @@ def metadata(self) -> LLMMetadata:
             # TODO: fix metadata
             context_window=1024,
             num_output=512,
-            model_name="astra_llm",
+            model_name="llama_llm",
             is_chat_model=True,
         )
 
     @llm_chat_callback()
     def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
-        logger.debug("ASTRALLM chat start")
+        logger.debug("LlamaLLM chat start")
 
         resp: ChatResponse
         wait_event = threading.Event()
 
         def callback(_, result):
-            logger.debug("ASTRALLM chat callback done")
+            logger.debug("LlamaLLM chat callback done")
             nonlocal resp
             nonlocal wait_event
-            resp = chat_from_astra_response(result)
+            resp = chat_from_llama_response(result)
             wait_event.set()
 
         messages_str = _messages_str_from_chat_messages(messages)
@@ -74,7 +74,7 @@ def callback(_, result):
         cmd.set_property_string("messages", messages_str)
         cmd.set_property_bool("stream", False)
         logger.info(
-            "ASTRALLM chat send_cmd {}, messages {}".format(
+            "LlamaLLM chat send_cmd {}, messages {}".format(
                 cmd.get_name(), messages_str
             )
         )
@@ -87,13 +87,13 @@ def callback(_, result):
     def complete(
         self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
-        logger.warning("ASTRALLM complete hasn't been implemented yet")
+        logger.warning("LlamaLLM complete hasn't been implemented yet")
 
     @llm_chat_callback()
     def stream_chat(
         self, messages: Sequence[ChatMessage], **kwargs: Any
     ) -> ChatResponseGen:
-        logger.debug("ASTRALLM stream_chat start")
+        logger.debug("LlamaLLM stream_chat start")
 
         cur_tokens = ""
         resp_queue = queue.Queue()
@@ -115,12 +115,12 @@ def callback(_, result):
 
             status = result.get_status_code()
             if status != StatusCode.OK:
-                logger.warn("ASTRALLM stream_chat callback status {}".format(status))
+                logger.warn("LlamaLLM stream_chat callback status {}".format(status))
                 resp_queue.put(None)
                 return
 
             cur_tokens = result.get_property_string("text")
-            logger.debug("ASTRALLM stream_chat callback text [{}]".format(cur_tokens))
+            logger.debug("LlamaLLM stream_chat callback text [{}]".format(cur_tokens))
             resp_queue.put(cur_tokens)
             if result.get_is_final():
                 resp_queue.put(None)
@@ -131,7 +131,7 @@ def callback(_, result):
         cmd.set_property_string("messages", messages_str)
         cmd.set_property_bool("stream", True)
         logger.info(
-            "ASTRALLM stream_chat send_cmd {}, messages {}".format(
+            "LlamaLLM stream_chat send_cmd {}, messages {}".format(
                 cmd.get_name(), messages_str
             )
         )
@@ -141,8 +141,8 @@ def callback(_, result):
     def stream_complete(
         self, prompt: str, formatted: bool = False, **kwargs: Any
     ) -> CompletionResponseGen:
-        logger.warning("ASTRALLM stream_complete hasn't been implemented yet")
+        logger.warning("LlamaLLM stream_complete hasn't been implemented yet")
 
     @classmethod
     def class_name(cls) -> str:
-        return "astra_llm"
+        return "llama_llm"
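llama_llm.py pairs that Event bridge with a second pattern for streaming: the callback pushes each partial text chunk onto a queue.Queue, and stream_chat drains the queue as a generator, using None as the end-of-stream sentinel (also pushed on error). A runnable sketch of that producer/consumer shape, with a plain thread standing in for the runtime callback:

import queue
import threading
from typing import Callable, Iterator, Optional

def stream_tokens(start_stream: Callable[[Callable[[Optional[str]], None]], None]) -> Iterator[str]:
    # The callback side enqueues each partial result, then None once the
    # final chunk has arrived (or on error), mirroring stream_chat.
    resp_queue: "queue.Queue[Optional[str]]" = queue.Queue()
    start_stream(resp_queue.put)
    while True:
        token = resp_queue.get()  # blocks until the producer pushes
        if token is None:         # sentinel: stream finished
            return
        yield token

# Usage: a producer thread plays the role of the streaming callback.
def fake_stream(emit: Callable[[Optional[str]], None]) -> None:
    def run() -> None:
        for chunk in ("Hello", ", ", "world"):
            emit(chunk)
        emit(None)
    threading.Thread(target=run).start()

print("".join(stream_tokens(fake_stream)))  # Hello, world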
llama_retriever.py

@@ -5,7 +5,7 @@
 from llama_index.core.retrievers import BaseRetriever
 
 from .log import logger
-from .astra_embedding import ASTRAEmbedding
+from .llama_embedding import LlamaEmbedding
 from ten import (
     TenEnv,
     Cmd,
@@ -15,7 +15,7 @@
 
 
 def format_node_result(cmd_result: CmdResult) -> List[NodeWithScore]:
-    logger.info("ASTRARetriever retrieve response {}".format(cmd_result.to_json()))
+    logger.info("LlamaRetriever retrieve response {}".format(cmd_result.to_json()))
     status = cmd_result.get_status_code()
     try:
         contents_json = cmd_result.get_property_to_json("response")
@@ -45,21 +45,21 @@ def format_node_result(cmd_result: CmdResult) -> List[NodeWithScore]:
     return nodes
 
 
-class ASTRARetriever(BaseRetriever):
+class LlamaRetriever(BaseRetriever):
     ten: Any
-    embed_model: ASTRAEmbedding
+    embed_model: LlamaEmbedding
 
     def __init__(self, ten: TenEnv, coll: str):
         super().__init__()
         try:
             self.ten = ten
-            self.embed_model = ASTRAEmbedding(ten=ten)
+            self.embed_model = LlamaEmbedding(ten=ten)
             self.collection_name = coll
         except Exception as e:
-            logger.error(f"Failed to initialize ASTRARetriever: {e}")
+            logger.error(f"Failed to initialize LlamaRetriever: {e}")
 
     def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
-        logger.info("ASTRARetriever retrieve: {}".format(query_bundle.to_json))
+        logger.info("LlamaRetriever retrieve: {}".format(query_bundle.to_json))
 
         wait_event = threading.Event()
         resp: List[NodeWithScore] = []
@@ -69,7 +69,7 @@ def cmd_callback(_, result):
             nonlocal wait_event
             resp = format_node_result(result)
             wait_event.set()
-            logger.debug("ASTRARetriever callback done")
+            logger.debug("LlamaRetriever callback done")
 
         embedding = self.embed_model.get_query_embedding(query=query_bundle.query_str)
 
@@ -78,7 +78,7 @@ def cmd_callback(_, result):
         query_cmd.set_property_int("top_k", 3)  # TODO: configable
         query_cmd.set_property_from_json("embedding", json.dumps(embedding))
         logger.info(
-            "ASTRARetriever send_cmd, collection_name: {}, embedding len: {}".format(
+            "LlamaRetriever send_cmd, collection_name: {}, embedding len: {}".format(
                 self.collection_name, len(embedding)
             )
         )
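Finally, in llama_retriever.py, format_node_result converts the vector store's JSON reply into llama_index NodeWithScore objects, and _retrieve reuses the same Event bridge around a query with top_k hard-coded to 3 (per the TODO). A rough sketch of the parsing step, assuming a simple list-of-records payload; the actual wire format is defined by the vector-store extension, not shown in this diff:

import json
from typing import List

from llama_index.core.schema import NodeWithScore, TextNode

def nodes_from_response(response_json: str) -> List[NodeWithScore]:
    # Assumed payload shape: [{"text": ..., "score": ...}, ...]. The real
    # fields come from the vector-store extension's reply, not this sketch.
    return [
        NodeWithScore(node=TextNode(text=item["text"]), score=item["score"])
        for item in json.loads(response_json)
    ]

print(nodes_from_response('[{"text": "TEN doc chunk", "score": 0.87}]'))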