-
Notifications
You must be signed in to change notification settings - Fork 1
/
combine_simple_RAG_chains.py
75 lines (67 loc) · 2.64 KB
/
combine_simple_RAG_chains.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from setup import Config, logger
from tools import prompt_templates_generate
def simple_retriever_generator_chain1(retriever, llm_gen, *, question=None, run_demo=True):
    """Build a simple RAG chain (retriever + generator) around a separate retrieval sub-chain.

    The retrieval step is assembled as its own runnable (``retrieval_chain``). It is
    then plugged into the generation chain under the ``"context"`` key.

    Args:
        retriever: Retriever runnable invoked with the plain question string
            (e.g. ``FAISS(...).as_retriever()`` or ``Chroma(...).as_retriever()``).
        llm_gen: LLM runnable that receives the filled-in generation prompt.
        question: Question used for the demo run. Defaults to ``Config.MYQ``,
            which is read at call time.
        run_demo: If True (the default, same as the original behaviour), invoke the
            chain once on ``question`` and print the answer before returning.
            Pass False to only build the chain and skip the LLM call.

    Returns:
        The RAG chain. Invoke it with ``{"question": "<text>"}``. It returns the
        generator's answer as a ``str`` (via ``StrOutputParser``).
    """
    ############## FULL RAG = RETRIEVER + GENERATOR ##############
    logger.info("Simple RETRIEVER and RAG Chain")

    ############## RETRIEVAL CHAIN ##############
    # Single-query retrieval. The chain input is a dict, but retriever.invoke()
    # takes a str, so the "question" value is extracted before reaching the retriever.
    # To inspect the retrieved documents, invoke this sub-chain directly with
    # {"question": Config.MYQ} and print the result.
    retrieval_chain = itemgetter("question") | retriever

    ############## GENERATOR CHAIN ##############
    # Generation prompt exposing {question} and {context} placeholders.
    prompt_generation = PromptTemplate(
        template=prompt_templates_generate.prompt_template_question_context,
        input_variables=["question", "context"],
    )

    # RAG chain: the leading dict fans the same input out to both keys.
    # "context" receives the retrieval sub-chain output; "question" passes the text through.
    # NOTE(review): "context" is the retriever's raw output (presumably a list of
    # Document objects), which the prompt renders via its string form, metadata
    # included. Consider joining the documents' page_content before the prompt.
    rag_chain = (
        {
            "context": retrieval_chain,
            "question": itemgetter("question"),
        }
        | prompt_generation
        | llm_gen
        | StrOutputParser()
    )

    ############## RUN ALL CHAINS ##############
    if run_demo:
        if question is None:
            question = Config.MYQ
        result = rag_chain.invoke({"question": question})
        print(result)
    return rag_chain
def simple_retriever_generator_chain2(retriever, llm_gen, *, question=None, run_demo=True):
    """Build a simple RAG chain (retriever + generator) with the retrieval step inlined.

    Unlike ``simple_retriever_generator_chain1``, no separate retrieval runnable is
    created. The question extraction and the retriever are composed directly inside
    the ``"context"`` entry of the chain.

    Args:
        retriever: The saved retriever, invoked with the plain question string
            (e.g. ``FAISS(...).as_retriever()`` or ``Chroma(...).as_retriever()``).
        llm_gen: LLM runnable that receives the filled-in generation prompt.
        question: Question used for the demo run. Defaults to ``Config.MYQ``,
            which is read at call time.
        run_demo: If True (the default, same as the original behaviour), invoke the
            chain once on ``question`` and print the answer before returning.
            Pass False to only build the chain and skip the LLM call.

    Returns:
        The RAG chain. Invoke it with ``{"question": "<text>"}``. It returns the
        generator's answer as a ``str`` (via ``StrOutputParser``).
    """
    ############## FULL RAG = RETRIEVER + GENERATOR ##############
    logger.info("Simple RETRIEVER and RAG Chain")

    ############## NO RETRIEVAL CHAIN ##############
    # No separate retrieval chain: the saved retriever is used as-is.
    # To inspect the retrieved documents, call retriever.invoke(Config.MYQ) and print the result.

    ############## GENERATOR CHAIN ##############
    # Generation prompt exposing {question} and {context} placeholders.
    prompt_generation = PromptTemplate(
        template=prompt_templates_generate.prompt_template_question_context,
        input_variables=["question", "context"],
    )

    # RAG chain: the leading dict fans the same input out to both keys.
    # NOTE(review): "context" is the retriever's raw output (presumably a list of
    # Document objects), which the prompt renders via its string form, metadata
    # included. Consider joining the documents' page_content before the prompt.
    rag_chain = (
        {
            # Retrieval happens here: pull the question string, then query the retriever.
            "context": itemgetter("question") | retriever,
            "question": itemgetter("question"),
        }
        | prompt_generation
        | llm_gen
        | StrOutputParser()
    )

    ############## RUN ALL CHAINS ##############
    if run_demo:
        if question is None:
            question = Config.MYQ
        result = rag_chain.invoke({"question": question})
        print(result)
    return rag_chain