This repository has been archived by the owner on Feb 3, 2025. It is now read-only.

refactor: refactored chat.py #100

Open · wants to merge 2 commits into base: main

Changes from all commits
97 changes: 35 additions & 62 deletions sage/chat.py
@@ -1,8 +1,3 @@
-"""A gradio app that enables users to chat with their codebase.
-
-You must run `sage-index $GITHUB_REPO` first in order to index the codebase into a vector store.
-"""
-
 import logging

 import configargparse
@@ -25,52 +20,44 @@ def build_rag_chain(args):
     llm = build_llm_via_langchain(args.llm_provider, args.llm_model)
     retriever = build_retriever_from_args(args)

-    # Prompt to contextualize the latest query based on the chat history.
-    contextualize_q_system_prompt = (
-        "Given a chat history and the latest user question which might reference context in the chat history, "
-        "formulate a standalone question which can be understood without the chat history. Do NOT answer the question, "
-        "just reformulate it if needed and otherwise return it as is."
-    )
-    contextualize_q_prompt = ChatPromptTemplate.from_messages(
-        [
-            ("system", contextualize_q_system_prompt),
-            MessagesPlaceholder("chat_history"),
-            ("human", "{input}"),
-        ]
-    )
-    contextualize_q_llm = llm.with_config(tags=["contextualize_q_llm"])
-    history_aware_retriever = create_history_aware_retriever(contextualize_q_llm, retriever, contextualize_q_prompt)
+    contextualize_q_prompt = ChatPromptTemplate.from_messages([
+        ("system", (
+            "Given a chat history and the latest user question which might reference context in the chat history, "
+            "formulate a standalone question which can be understood without the chat history. "
+            "Do NOT answer the question, just reformulate it if needed and otherwise return it as is."
+        )),
+        MessagesPlaceholder("chat_history"),
+        ("human", "{input}")
+    ])
+
+    history_aware_retriever = create_history_aware_retriever(
+        llm.with_config(tags=["contextualize_q_llm"]),
+        retriever,
+        contextualize_q_prompt
+    )

     qa_system_prompt = (
-        f"You are my coding buddy, helping me quickly understand a GitHub repository called {args.repo_id}."
+        f"You are my coding buddy, helping me quickly understand a GitHub repository called {args.repo_id}. "
         "Assume I am an advanced developer and answer my questions in the most succinct way possible."
-        "\n\n"
-        "Here are some snippets from the codebase."
-        "\n\n"
-        "{context}"
+        "\n\nHere are some snippets from the codebase.\n\n{context}"
     )
-    qa_prompt = ChatPromptTemplate.from_messages(
-        [
+
+    question_answer_chain = create_stuff_documents_chain(
+        llm, ChatPromptTemplate.from_messages([
             ("system", qa_system_prompt),
             MessagesPlaceholder("chat_history"),
-            ("human", "{input}"),
-        ]
+            ("human", "{input}")
+        ])
     )

-    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
-    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
-    return rag_chain
+    return create_retrieval_chain(history_aware_retriever, question_answer_chain)


 def main():
     parser = configargparse.ArgParser(
         description="Batch-embeds a GitHub repository and its issues.", ignore_unknown_config_file_keys=True
     )
-    parser.add(
-        "--share",
-        default=False,
-        help="Whether to make the gradio app publicly accessible.",
-    )
+    parser.add("--share", default=False, help="Make the Gradio app publicly accessible.")

     validator = sage_config.add_all_args(parser)
     args = parser.parse_args()
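Not part of the diff, but for orientation: the object returned by build_rag_chain is a regular LangChain retrieval chain, so it can also be called synchronously. A minimal sketch, assuming create_retrieval_chain's usual output keys ("context" for the retrieved documents, "answer" for the model reply):

    # Illustrative only; mirrors what the streaming path in _predict (next hunk) consumes.
    chain = build_rag_chain(args)
    result = chain.invoke({"input": "What does this repo do?", "chat_history": []})
    print(result["answer"])
    for doc in result["context"]:
        print(doc.metadata["file_path"], doc.metadata["url"])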
@@ -79,48 +66,34 @@ def main():
     rag_chain = build_rag_chain(args)

     def source_md(file_path: str, url: str) -> str:
-        """Formats a context source in Markdown."""
+        """Format a context source in Markdown."""
         return f"[{file_path}]({url})"

     async def _predict(message, history):
-        """Performs one RAG operation."""
-        history_langchain_format = []
-        for human, ai in history:
-            history_langchain_format.append(HumanMessage(content=human))
-            history_langchain_format.append(AIMessage(content=ai))
-        history_langchain_format.append(HumanMessage(content=message))
-
-        query_rewrite = ""
-        response = ""
-        async for event in rag_chain.astream_events(
-            {
-                "input": message,
-                "chat_history": history_langchain_format,
-            },
-            version="v1",
-        ):
-            if event["name"] == "retrieve_documents" and "output" in event["data"]:
-                sources = [(doc.metadata["file_path"], doc.metadata["url"]) for doc in event["data"]["output"]]
-                # Deduplicate while preserving the order.
-                sources = list(dict.fromkeys(sources))
-                response += "## Sources:\n" + "\n".join([source_md(s[0], s[1]) for s in sources]) + "\n## Response:\n"
+        """Perform one RAG operation."""
+        # Flatten Gradio's (user, assistant) pairs into the flat message list
+        # that MessagesPlaceholder("chat_history") expects.
+        history_langchain_format = [
+            msg for human, ai in history for msg in (HumanMessage(content=human), AIMessage(content=ai))
+        ] + [HumanMessage(content=message)]
+
+        response, query_rewrite = "", ""
+        async for event in rag_chain.astream_events(
+            {"input": message, "chat_history": history_langchain_format}, version="v1"
+        ):
+            if event["name"] == "retrieve_documents" and "output" in event["data"]:
+                # Deduplicate while preserving the retrieval order.
+                sources = list(dict.fromkeys((doc.metadata["file_path"], doc.metadata["url"]) for doc in event["data"]["output"]))
+                response += "## Sources:\n" + "\n".join(source_md(*s) for s in sources) + "\n## Response:\n"
             elif event["event"] == "on_chat_model_stream":
                 chunk = event["data"]["chunk"].content

                 if "contextualize_q_llm" in event["tags"]:
                     query_rewrite += chunk
                 else:
                     # This is the actual response to the user query.
                     if not response:
                         logging.info(f"Query rewrite: {query_rewrite}")
                     response += chunk
                     yield response

     gr.ChatInterface(
         _predict,
-        title=args.repo_id,
+        title=args.repo_id if args.repo_id else "GitHub Repo Chat",
         examples=["What does this repo do?", "Give me some sample code."],
     ).launch(share=args.share)
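An illustration (not part of the diff) of the history format handled above: Gradio's ChatInterface passes _predict the conversation as (user, assistant) string pairs, while MessagesPlaceholder("chat_history") expects a flat list of message objects, so the pairs are flattened before the chain is invoked. The import path for the message classes is assumed to be the usual langchain_core one, and the turn contents are made up for the example:

    from langchain_core.messages import AIMessage, HumanMessage

    history = [("What does this repo do?", "It lets you chat with an indexed codebase.")]
    message = "Show me the retriever setup."

    chat_history = [
        msg for human, ai in history for msg in (HumanMessage(content=human), AIMessage(content=ai))
    ] + [HumanMessage(content=message)]
    # chat_history is now [HumanMessage, AIMessage, HumanMessage], ready to be passed as
    # {"input": message, "chat_history": chat_history} to the chain's astream_events call.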

