-
Notifications
You must be signed in to change notification settings - Fork 122
/
Copy pathchat_utils.py
134 lines (93 loc) · 4.26 KB
/
chat_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import OpenAI
from fuzzywuzzy import fuzz
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from my_prompts import chat_prompt, hypothetical_prompt
from dotenv import load_dotenv
from summary_utils import doc_loader, remove_special_tokens, directory_loader
# Fetch the NLTK corpora required by helpers below: 'stopwords' for
# filter_stopwords() and 'punkt' for word_tokenize(). Runs at import time;
# nltk.download is a no-op when the data is already present locally.
nltk.download('stopwords')
nltk.download('punkt')
def create_and_save_directory_embeddings(directory_path, name):
    """Embed every document under *directory_path* and persist a FAISS index.

    Documents are chunked (500 chars, 50 overlap), stripped of special
    tokens, embedded with OpenAI embeddings, and saved locally under
    'directory_embeddings/<name>'.

    :param directory_path: Directory whose documents should be indexed.
    :param name: Index name used for the saved FAISS files.
    :return: The in-memory FAISS vector store that was saved.
    """
    chunker = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    raw_docs = directory_loader(directory_path)
    chunks = remove_special_tokens(chunker.split_documents(raw_docs))
    index = FAISS.from_documents(chunks, OpenAIEmbeddings())
    index.save_local(folder_path='directory_embeddings', index_name=name)
    return index
def create_and_save_chat_embeddings(file_path):
    """Embed a single document and persist its FAISS index for chat use.

    The index name is the file's base name up to the first dot, and the
    index is saved under 'embeddings/<name>'. Chunking matches the
    directory pipeline (500 chars, 50 overlap).

    :param file_path: Path to the document to embed.
    """
    # Index name: base filename, truncated at the first '.'.
    index_name = os.path.split(file_path)[1].split('.')[0]
    chunker = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    loaded = doc_loader(file_path)
    chunks = remove_special_tokens(chunker.split_documents(loaded))
    store = FAISS.from_documents(chunks, OpenAIEmbeddings())
    store.save_local(folder_path='embeddings', index_name=index_name)
def load_chat_embeddings(file_path):
    """Load a previously saved FAISS index for *file_path* from 'embeddings/'.

    The index name must match the one produced by
    create_and_save_chat_embeddings (base filename up to the first dot).

    :param file_path: Path of the original document.
    :return: The loaded FAISS vector store.
    """
    index_name = os.path.split(file_path)[1].split('.')[0]
    return FAISS.load_local(
        folder_path='embeddings',
        index_name=index_name,
        embeddings=OpenAIEmbeddings(),
    )
def results_from_db(db:FAISS, question, num_results=10):
    """Return the *num_results* chunks most similar to *question* from *db*."""
    return db.similarity_search(question, k=num_results)
def rerank_fuzzy_matching(question, results, num_results=5):
    """Re-rank similarity-search *results* by fuzzy match against *question*.

    Each result's page content is scored with fuzz.partial_ratio against
    the raw question, and the top *num_results* are returned best-first.

    :param question: The user's question (raw, unfiltered).
    :param results: Documents returned by the vector store.
    :param num_results: Number of re-ranked documents to keep.
    :return: The *num_results* best-scoring documents.
    """
    filtered_question = filter_stopwords(question)
    if filtered_question == '':
        # Question is all stopwords — fuzzy scores would be meaningless.
        # Fall back to the tail of the original ranking. (Bug fix: this
        # previously returned a hard-coded 5 results, ignoring num_results.)
        return results[-num_results:]
    scores_and_results = []
    for result in results:
        score = fuzz.partial_ratio(question, result.page_content)
        scores_and_results.append((score, result))
    scores_and_results.sort(key=lambda x: x[0], reverse=True)
    reranked = [result for score, result in scores_and_results]
    return reranked[:num_results]
def filter_stopwords(question):
    """Return *question* with English stopwords removed, words space-joined.

    :param question: Free-text question to filter.
    :return: The question's non-stopword tokens joined by single spaces;
        '' when every token is a stopword.
    """
    # Build the stopword set once. The original called
    # stopwords.words('english') inside the comprehension, re-fetching the
    # corpus list and doing an O(n) list scan for every single token.
    stop_set = set(stopwords.words('english'))
    words = word_tokenize(question)
    filtered_words = [word for word in words if word not in stop_set]
    return ' '.join(filtered_words)
def qa_from_db(question, db, llm_name, hypothetical):
    """Answer *question* from vector store *db* and return (answer, sources).

    :param question: User question.
    :param db: FAISS vector store to search.
    :param llm_name: Either an OpenAI model name (str) or an already-built
        LLM object, which is used as-is.
    :param hypothetical: When truthy, use HyDE — generate a hypothetical
        answer first and search the store with it instead of the raw
        question (an answer's phrasing tends to match documents better).
    :return: Tuple of (LLM answer text, formatted source listing).
    """
    llm = create_llm(llm_name)
    if hypothetical:
        # Reuse the same LLM for the hypothetical step; the original built
        # a second identical instance for no benefit.
        hypothetical_answer = hypothetical_document_embeddings(question, llm)
        results = results_from_db(db, hypothetical_answer)
    else:
        results = results_from_db(db, question)
    reranked_results = rerank_fuzzy_matching(question, results)
    reranked_content = [result.page_content for result in reranked_results]
    if not isinstance(llm_name, str):
        # Caller-supplied LLM object: use a minimal prompt and only the top
        # two context chunks (presumably to fit a smaller context window —
        # TODO confirm with callers).
        message = f'Answer the user question based on the context. Question: {question} Context: {reranked_content[:2]} Answer:'
    else:
        message = f'{chat_prompt} ---------- Context: {reranked_content} -------- User Question: {question} ---------- Response:'
    formatted_sources = source_formatter(reranked_results)
    output = llm(message)
    return output, formatted_sources
def source_formatter(sources):
    """Format retrieved documents into a human-readable source listing.

    Each document contributes one line naming its source file and showing
    its content (newlines flattened to spaces); entries are separated by
    blank lines.

    :param sources: Documents with .metadata['source'] and .page_content.
    :return: One string listing every source.
    """
    formatted_strings = []
    for doc in sources:
        # Normalize separators before taking the basename so both Windows
        # ('\\') and POSIX ('/') paths work; the original split only on
        # backslashes and left POSIX paths intact.
        source_name = doc.metadata['source'].replace('\\', '/').split('/')[-1]
        source_content = doc.page_content.replace('\n', ' ')  # keep each entry on one visual line
        formatted_string = f"Source name: {source_name} | Source content: '{source_content}' - end of content"
        formatted_strings.append(formatted_string)
    return '\n\n\n'.join(formatted_strings)
def create_llm(llm_name):
    """Return an LLM for *llm_name*.

    :param llm_name: An OpenAI model name (str), or an already-constructed
        LLM object which is returned unchanged.
    :return: An OpenAI LLM built from the name, or *llm_name* itself.
    """
    # isinstance instead of the original `type(llm_name) != str` check.
    if isinstance(llm_name, str):
        return OpenAI(model_name=llm_name)
    # Anything non-string is assumed to already be a usable LLM object.
    return llm_name
def hypothetical_document_embeddings(question, llm):
    """Generate a hypothetical answer to *question* for HyDE-style retrieval.

    :param question: User question.
    :param llm: Callable LLM used to produce the hypothetical answer.
    :return: The LLM's generated answer text.
    """
    message = f'{hypothetical_prompt} {question} :'
    # Bug fix: removed a stray debug print of the LLM output.
    return llm(message)