-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
97 lines (79 loc) · 2.69 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from langchain.callbacks.base import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import (
ConversationalRetrievalChain,
LLMChain
)
from langchain.chains.question_answering import load_qa_chain
from langchain.llms import OpenAI
from langchain.prompts.prompt import PromptTemplate
from redis_reader import RedisProductRetriever
from redis_store import getRedisStore
from create_meta import createMeta
import json
template = """Given the following chat history and a follow up question, rephrase the follow up input question to be a standalone question.
Or end the conversation if it seems like it's done.
Chat History:\"""
{chat_history}
\"""
Follow Up Input: \"""
{question}
\"""
Standalone question:"""
condense_question_prompt = PromptTemplate.from_template(template)
template = """You are a friendly, conversational retail shopping assistant. Use the following context including product names, descriptions, and keywords to show the shopper whats available, help find what they want, and answer any questions.
It's ok if you don't know the answer.
Context:\"""
{context}
\"""
Question:\"
\"""
Helpful Answer:"""
qa_prompt= PromptTemplate.from_template(template)
# define two LLM models from OpenAI
llm = OpenAI(temperature=0)
streaming_llm = OpenAI(
streaming=True,
callback_manager=CallbackManager([
StreamingStdOutCallbackHandler()
]),
verbose=True,
max_tokens=150,
temperature=0.2
)
# use the LLM Chain to create a question creation chain
question_generator = LLMChain(
llm=llm,
prompt=condense_question_prompt
)
# use the streaming LLM to create a question answering chain
doc_chain = load_qa_chain(
llm=streaming_llm,
chain_type="stuff",
prompt=qa_prompt
)
# chatbot = ConversationalRetrievalChain(
# retriever=vectorstore.as_retriever(),
# combine_docs_chain=doc_chain,
# question_generator=question_generator
# )
# check product_metadata.json in local directory
vectorstore = getRedisStore('product:embedding')
redis_product_retriever = RedisProductRetriever(vectorstore=vectorstore)
chatbot = ConversationalRetrievalChain(
retriever=redis_product_retriever,
combine_docs_chain=doc_chain,
question_generator=question_generator
)
# create a chat history buffer
chat_history = []
# gather user input for the first question to kick off the bot
question = input("Hi! What are you looking for today?")
# keep the bot running in a loop to simulate a conversation
while True:
result = chatbot(
{"question": question, "chat_history": chat_history}
)
print("\n")
chat_history.append((result["question"], result["answer"]))
question = input()