diff --git a/streamlitui.py b/streamlitui.py index e5f2024..b40fc7f 100644 --- a/streamlitui.py +++ b/streamlitui.py @@ -4,7 +4,7 @@ from streamlit_chat import message from webquery import WebQuery -st.set_page_config(page_title="ChatPDF") +st.set_page_config(page_title="Website to Chatbot") def display_messages(): @@ -22,7 +22,12 @@ def process_input(): st.session_state["messages"].append((user_text, True)) st.session_state["messages"].append((query_text, False)) - + +def ingest_input(): + if st.session_state["input_url"] and len(st.session_state["input_url"].strip()) > 0: + url = st.session_state["input_url"].strip() + with st.session_state["thinking_spinner"], st.spinner(f"Thinking"): + ingest_text = st.session_state["webquery"].ingest(url) def is_openai_api_key_set() -> bool: return len(st.session_state["OPENAI_API_KEY"]) > 0 @@ -48,15 +53,11 @@ def main(): st.session_state["OPENAI_API_KEY"] = st.session_state["input_OPENAI_API_KEY"] st.session_state["messages"] = [] st.session_state["user_input"] = "" + st.session_state["input_url"] = "" st.session_state["webquery"] = WebQuery(st.session_state["OPENAI_API_KEY"]) st.subheader("Add a url") - if st.text_input("Input url", value=st.session_state["url"], key="input_url", type="default"): - if ( - len(st.session_state["input_url"]) > 0 - and st.session_state["input_url"] != st.session_state["url"] - ): - st.session_state["url"] = st.session_state["input_url"] + st.text_input("Input url", value=st.session_state["url"], key="input_url", disabled=not is_openai_api_key_set(), on_change=ingest_input) st.session_state["ingestion_spinner"] = st.empty() diff --git a/webquery.py b/webquery.py index 2777879..b4d05b4 100644 --- a/webquery.py +++ b/webquery.py @@ -23,12 +23,13 @@ def ask(self, question: str) -> str: response = self.chain.run(input_documents=docs, question=question) return response - def ingest(self, url: str) -> None: + def ingest(self, url: str) -> str: result = trafilatura.extract(trafilatura.fetch_url(url)) documents = [Document(page_content=result, metadata={"source": url})] splitted_documents = self.text_splitter.split_documents(documents) self.db = Chroma.from_documents(splitted_documents, self.embeddings).as_retriever() self.chain = load_qa_chain(OpenAI(temperature=0), chain_type="stuff") + return "Success" def forget(self) -> None: self.db = None