-
Notifications
You must be signed in to change notification settings - Fork 1
/
local_rag_chain_simple.py
86 lines (73 loc) · 3.07 KB
/
local_rag_chain_simple.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from embedders.create_llm_emb_default import create_llm_emb_default
from generators.create_llm_gen_default import create_llm_gen_default
from setup import Config, logger, print_config
from tools import prompt_templates_generate
from vectorstores.get_vectorstore import get_vectorstore
def local_rag_simple():
    """Build and run a simple local RAG (retrieval-augmented generation) chain.

    Pipeline:
        1. Print the active configuration.
        2. Load the embedding model (``Config.HF_EMB_MODEL``).
        3. Load the generator LLM (``Config.HF_LLM_NAME``).
        4. Create or load the vectorstore (FAISS or Chroma) and wrap it in a
           similarity retriever returning the top-4 documents.
        5. Compose the chain: extract the question, retrieve context, fill the
           question/context prompt, generate, and parse the output to ``str``.
        6. Invoke the chain on ``Config.MYQ`` and print the answer.

    Returns:
        str: the generated answer (also printed to stdout).
    """
    ############## INITIAL SETUP ##############
    print_config()

    ############## EMBEDDING MODEL ##############
    # Load the model used to embed documents / queries.
    logger.info(f"LLM_EMB : {Config.HF_EMB_MODEL}")
    llm_emb = create_llm_emb_default()

    ############## GENERATOR MODEL ##############
    # Load the model used to generate the answer.
    logger.info(f"LLM : {Config.HF_LLM_NAME}")
    llm_gen = create_llm_gen_default()

    ############## VECTORSTORE FOR EMBEDDINGS ##############
    # Create or load the vectorstore (FAISS or Chroma).
    logger.info("VECTORSTORE")
    vectorstore = get_vectorstore(llm_emb)

    ############## RETRIEVER FROM EMBEDDING MODEL ##############
    logger.info("RETRIEVER")
    retriever = vectorstore.as_retriever(
        search_type="similarity", search_kwargs={"k": 4}
    )
    # The retriever keeps its own reference to the store; drop ours so the
    # local name can be garbage-collected early.
    del vectorstore

    ############## FULL RAG = RETRIEVER + GENERATOR ##############
    logger.info("Simple RETRIEVER and RAG Chain")

    ############## RETRIEVAL CHAIN ##############
    # retriever.invoke() takes a plain string, so extract the "question" key
    # from the input dict before piping it into the retriever.
    retrieval_chain = itemgetter("question") | retriever

    ############## GENERATOR CHAIN ##############
    # Prompt combining the retrieved context with the user question.
    prompt_generation = PromptTemplate(
        template=prompt_templates_generate.prompt_template_question_context,
        input_variables=["question", "context"],
    )

    # RAG chain: fan the input into {context, question}, fill the prompt,
    # generate with the LLM, and parse the model output to a string.
    rag_chain = (
        {
            "context": retrieval_chain,
            "question": itemgetter("question"),
        }
        | prompt_generation
        | llm_gen
        | StrOutputParser()
    )

    ############## RUN ALL CHAINS ##############
    result = rag_chain.invoke({"question": Config.MYQ})
    print(result)
    return result
# Script entry point: run the simple local RAG pipeline when executed directly.
if __name__ == "__main__":
    local_rag_simple()