Skip to content

Commit

Permalink
Added streamlit implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
vadootvpeer committed May 29, 2023
1 parent a22789e commit b189cb6
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 0 deletions.
70 changes: 70 additions & 0 deletions streamlitui.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import os
import tempfile
import streamlit as st
from streamlit_chat import message
from webquery import WebQuery

st.set_page_config(page_title="ChatPDF")


def display_messages():
st.subheader("Chat")
for i, (msg, is_user) in enumerate(st.session_state["messages"]):
message(msg, is_user=is_user, key=str(i))
st.session_state["thinking_spinner"] = st.empty()


def process_input():
if st.session_state["user_input"] and len(st.session_state["user_input"].strip()) > 0:
user_text = st.session_state["user_input"].strip()
with st.session_state["thinking_spinner"], st.spinner(f"Thinking"):
query_text = st.session_state["webquery"].ask(user_text)

st.session_state["messages"].append((user_text, True))
st.session_state["messages"].append((query_text, False))


def is_openai_api_key_set() -> bool:
return len(st.session_state["OPENAI_API_KEY"]) > 0


def main():
if len(st.session_state) == 0:
st.session_state["messages"] = []
st.session_state["OPENAI_API_KEY"] = os.environ.get("OPENAI_API_KEY", "")
if is_openai_api_key_set():
st.session_state["webquery"] = WebQuery(st.session_state["OPENAI_API_KEY"])
else:
st.session_state["webquery"] = None

st.header("ChatPDF")

if st.text_input("OpenAI API Key", value=st.session_state["OPENAI_API_KEY"], key="input_OPENAI_API_KEY", type="password"):
if (
len(st.session_state["input_OPENAI_API_KEY"]) > 0
and st.session_state["input_OPENAI_API_KEY"] != st.session_state["OPENAI_API_KEY"]
):
st.session_state["OPENAI_API_KEY"] = st.session_state["input_OPENAI_API_KEY"]
st.session_state["messages"] = []
st.session_state["user_input"] = ""
st.session_state["webquery"] = WebQuery(st.session_state["OPENAI_API_KEY"])

st.subheader("Add a url")
if st.text_input("Input url", value=st.session_state["url"], key="input_url", type="text"):
if (
len(st.session_state["input_url"]) > 0
and st.session_state["input_url"] != st.session_state["text"]
):
st.session_state["url"] = st.session_state["input_url"]

st.session_state["ingestion_spinner"] = st.empty()

display_messages()
st.text_input("Message", key="user_input", disabled=not is_openai_api_key_set(), on_change=process_input)

st.divider()
st.markdown("Source code: [Github](https://github.com/Anil-matcha/Website-to-Chatbot)")


if __name__ == "__main__":
main()
35 changes: 35 additions & 0 deletions webquery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import os, trafilatura
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains.question_answering import load_qa_chain
from langchain.llms import OpenAI
from langchain.docstore.document import Document

class WebQuery:
def __init__(self, openai_api_key = None) -> None:
self.embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
os.environ["OPENAI_API_KEY"] = openai_api_key
self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
self.llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
self.chain = None
self.db = None

def ask(self, question: str) -> str:
if self.chain is None:
response = "Please, add a document."
else:
docs = self.db.get_relevant_documents(question)
response = self.chain.run(input_documents=docs, question=question)
return response

def ingest(self, url: str) -> None:
result = trafilatura.extract(trafilatura.fetch_url(url))
documents = [Document(page_content=result, metadata={"source": url})]
splitted_documents = self.text_splitter.split_documents(documents)
self.db = Chroma.from_documents(splitted_documents, self.embeddings).as_retriever()
self.chain = load_qa_chain(OpenAI(temperature=0), chain_type="stuff")

def forget(self) -> None:
self.db = None
self.chain = None

0 comments on commit b189cb6

Please sign in to comment.