-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathgraph.py
40 lines (35 loc) · 1.28 KB
/
graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from langgraph.graph import END, StateGraph
from typing_extensions import TypedDict
from typing import List
from nodes.transform_node import transform_query
from nodes.retrieve_node import retrieve
from nodes.search_node import web_search
from nodes.grade_node import grade_documents
from nodes.generate_node import generate
from nodes.decision_node import decide_to_generate
def workflow_compiler():
    """Build the retrieval/grading/generation LangGraph workflow and compile it.

    Graph topology:
        retrieve -> grade_documents
        grade_documents --decide_to_generate--> transform_query | generate
        transform_query -> web_search_node -> generate
        generate -> END

    Returns:
        The object produced by ``StateGraph.compile()`` for this graph.
    """

    class GraphState(TypedDict):
        # Shared state passed between nodes. The meanings below are inferred
        # from the field names only; confirm them against the node implementations.
        question: str  # presumably the user's query
        generation: str  # presumably the generated answer
        web_search: str  # presumably a flag set by grading; TODO confirm values
        documents: List[str]  # retrieved/filtered documents

    # (node name, callable) pairs, registered in this exact order.
    # NOTE: the web-search step is named "web_search_node" rather than
    # "web_search" because "web_search" is already a GraphState key, and
    # LangGraph rejects node names that collide with state keys.
    node_table = (
        ("retrieve", retrieve),
        ("grade_documents", grade_documents),
        ("generate", generate),
        ("transform_query", transform_query),
        ("web_search_node", web_search),
    )

    graph = StateGraph(GraphState)
    for node_name, action in node_table:
        graph.add_node(node_name, action)

    graph.set_entry_point("retrieve")
    graph.add_edge("retrieve", "grade_documents")

    # decide_to_generate is expected to return one of these labels; each
    # label maps to the node of the same name.
    routes = {label: label for label in ("transform_query", "generate")}
    graph.add_conditional_edges("grade_documents", decide_to_generate, routes)

    # Fixed (unconditional) transitions after the routing decision.
    fixed_edges = (
        ("transform_query", "web_search_node"),
        ("web_search_node", "generate"),
        ("generate", END),
    )
    for source, target in fixed_edges:
        graph.add_edge(source, target)

    return graph.compile()