
Commit f669f57

agentic rag
1 parent 28932a2 commit f669f57

File tree

19 files changed: +812 −513 lines


backend/apps/ai/Makefile

Lines changed: 3 additions & 3 deletions
@@ -1,6 +1,6 @@
-ai-run-rag-tool:
-	@echo "Running RAG tool"
-	@CMD="python manage.py ai_run_rag_tool" $(MAKE) exec-backend-command
+ai-run-agentic-rag:
+	@echo "Running agentic RAG"
+	@CMD="python manage.py ai_run_agentic_rag" $(MAKE) exec-backend-command
 
 ai-update-chapter-chunks:
 	@echo "Updating chapter chunks"
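
Usage note (not part of the diff): with this rename in place, the workflow is started with "make ai-run-agentic-rag", which runs the ai_run_agentic_rag Django management command inside the backend container via the existing exec-backend-command helper.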

backend/apps/ai/agent/agent.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+"""LangGraph-powered agent for iterative RAG answering."""
+
+from __future__ import annotations
+
+import logging
+from typing import Any
+
+from langgraph.graph import END, START, StateGraph
+
+from apps.ai.agent.nodes import AgentNodes
+from apps.ai.common.constants import (
+    DEFAULT_CHUNKS_RETRIEVAL_LIMIT,
+    DEFAULT_SIMILARITY_THRESHOLD,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class AgenticRAGAgent:
+    """LangGraph-based controller for agentic RAG with self-correcting retrieval."""
+
+    def __init__(self) -> None:
+        """Initialize the AgenticRAGAgent."""
+        self.nodes = AgentNodes()
+        self.graph = self.build_graph()
+
+    def run(
+        self,
+        query: str,
+    ) -> dict[str, Any]:
+        """Execute the full RAG loop."""
+        initial_state: dict[str, Any] = {
+            "query": query,
+            "iteration": 0,
+            "feedback": None,
+            "history": [],
+            "content_types": [],
+            "limit": DEFAULT_CHUNKS_RETRIEVAL_LIMIT,
+            "similarity_threshold": DEFAULT_SIMILARITY_THRESHOLD,
+        }
+
+        logger.info("Starting Agentic RAG workflow with metadata-aware retrieval")
+        final_state = self.graph.invoke(initial_state)
+
+        return {
+            "answer": final_state.get("answer", ""),
+            "iterations": final_state.get("iteration", 0),
+            "evaluation": final_state.get("evaluation", {}),
+            "context_chunks": final_state.get("context_chunks", []),
+            "history": final_state.get("history", []),
+            "extracted_metadata": final_state.get("extracted_metadata", {}),
+        }
+
+    def build_graph(self):
+        """Build the LangGraph state machine for the RAG workflow."""
+        graph = StateGraph(dict)
+        graph.add_node("retrieve", self.nodes.retrieve)
+        graph.add_node("generate", self.nodes.generate)
+        graph.add_node("evaluate", self.nodes.evaluate)
+
+        graph.add_edge(START, "retrieve")
+        graph.add_edge("retrieve", "generate")
+        graph.add_edge("generate", "evaluate")
+        graph.add_conditional_edges(
+            "evaluate",
+            self.nodes.route_from_evaluation,
+            {"refine": "generate", "complete": END},
+        )
+
+        return graph.compile()
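
A minimal usage sketch of the new agent (the query string is hypothetical; the result keys mirror the dict returned by run() above):

    from apps.ai.agent.agent import AgenticRAGAgent

    agent = AgenticRAGAgent()
    result = agent.run("Which OWASP chapters are active in India?")  # hypothetical query

    print(result["answer"])      # final generated answer
    print(result["iterations"])  # number of generate/evaluate passes taken
    print(result["evaluation"])  # evaluator verdict from the last iteration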

backend/apps/ai/agent/nodes.py

Lines changed: 269 additions & 0 deletions
@@ -0,0 +1,269 @@
+"""LangGraph nodes for the Agentic RAG workflow."""
+
+from __future__ import annotations
+
+import json
+import os
+from typing import Any
+
+import openai
+from django.core.exceptions import ObjectDoesNotExist
+
+from apps.ai.agent.tools.rag.generator import Generator
+from apps.ai.agent.tools.rag.retriever import Retriever
+from apps.ai.common.constants import (
+    DEFAULT_CHUNKS_RETRIEVAL_LIMIT,
+    DEFAULT_MAX_ITERATIONS,
+    DEFAULT_REASONING_MODEL,
+    DEFAULT_SIMILARITY_THRESHOLD,
+)
+from apps.core.models.prompt import Prompt
+
+
+class AgentNodes:
+    """Collection of LangGraph node functions with injected dependencies."""
+
+    def __init__(self) -> None:
+        """Initialize AgentNodes."""
+        if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")):
+            error_msg = "DJANGO_OPEN_AI_SECRET_KEY environment variable not set"
+            raise ValueError(error_msg)
+
+        self.openai_client = openai.OpenAI(api_key=openai_api_key)
+
+        self.retriever = Retriever()
+        self.generator = Generator()
+
+    def retrieve(self, state: dict[str, Any]) -> dict[str, Any]:
+        """Retrieve context chunks based on the query."""
+        if state.get("context_chunks"):
+            return state
+
+        limit = state.get("limit", DEFAULT_CHUNKS_RETRIEVAL_LIMIT)
+        threshold = state.get("similarity_threshold", DEFAULT_SIMILARITY_THRESHOLD)
+        query = state["query"]
+
+        if "extracted_metadata" not in state:
+            state["extracted_metadata"] = self.extract_query_metadata(query)
+
+        metadata = state["extracted_metadata"]
+
+        chunks = self.retriever.retrieve(
+            query=query,
+            limit=limit,
+            similarity_threshold=threshold,
+            content_types=metadata.get("entity_types"),
+        )
+
+        filtered_chunks = self.filter_chunks_by_metadata(chunks, metadata)
+
+        state["context_chunks"] = filtered_chunks[:limit]
+        return state
+
+    def generate(self, state: dict[str, Any]) -> dict[str, Any]:
+        """Generate an answer using the retrieved context."""
+        iteration = state.get("iteration", 0) + 1
+        feedback = state.get("feedback")
+        query = state["query"]
+        augmented_query = (
+            query if not feedback else f"{query}\n\nRevise per feedback:\n{feedback}"
+        )
+
+        answer = self.generator.generate_answer(
+            query=augmented_query,
+            context_chunks=state.get("context_chunks", []),
+        )
+
+        history = state.get("history", [])
+        history.append(
+            {
+                "iteration": iteration,
+                "feedback": feedback,
+                "query": augmented_query,
+                "answer": answer,
+            }
+        )
+
+        state.update(
+            {"answer": answer, "iteration": iteration, "history": history, "feedback": None}
+        )
+        return state
+
+    def evaluate(self, state: dict[str, Any]) -> dict[str, Any]:
+        """Evaluate the generated answer and decide on the next step."""
+        answer = state.get("answer", "")
+        evaluation = self.call_evaluator(
+            query=state["query"],
+            answer=answer,
+            context_chunks=state.get("context_chunks", []),
+        )
+
+        history = state.get("history", [])
+        if history:
+            history[-1]["evaluation"] = evaluation
+
+        if "missing context" in evaluation.get("justification", "").lower():
+            limit = state.get("limit", DEFAULT_CHUNKS_RETRIEVAL_LIMIT) * 2
+            threshold = state.get("similarity_threshold", DEFAULT_SIMILARITY_THRESHOLD) * 0.95
+
+            metadata = state.get("extracted_metadata", {})
+
+            new_chunks = self.retriever.retrieve(
+                query=state["query"],
+                limit=limit,
+                similarity_threshold=threshold,
+                content_types=metadata.get("entity_types"),
+            )
+
+            filtered_chunks = self.filter_chunks_by_metadata(new_chunks, metadata)
+            state["context_chunks"] = filtered_chunks[:limit]
+
+            state["feedback"] = "Expand and refine answer using newly retrieved context."
+        else:
+            state["feedback"] = evaluation.get("feedback") or None
+
+        state.update({"evaluation": evaluation, "history": history})
+        return state
+
+    def route_from_evaluation(self, state: dict[str, Any]) -> str:
+        """Route the workflow based on the evaluation result."""
+        evaluation = state.get("evaluation") or {}
+        iteration = state.get("iteration", 0)
+        if evaluation.get("complete") or iteration >= DEFAULT_MAX_ITERATIONS:
+            return "complete"
+        return "refine"
+
+    def filter_chunks_by_metadata(
+        self,
+        retrieved_chunks: list[dict[str, Any]],
+        query_metadata: dict[str, Any],
+    ) -> list[dict[str, Any]]:
+        """Rank and filter retrieved chunks using metadata and simple heuristics."""
+        if not retrieved_chunks:
+            return []
+
+        requested_fields = query_metadata.get("requested_fields", [])
+        query_filters = query_metadata.get("filters", {})
+
+        if not requested_fields and not query_filters:
+            return retrieved_chunks
+
+        ranked_chunks: list[tuple[dict[str, Any], float]] = []
+        for chunk in retrieved_chunks:
+            relevance_score = 0.0
+            chunk_metadata = chunk.get("additional_context", {})
+            chunk_content = chunk.get("text", "").lower()
+
+            for field_name in requested_fields:
+                if chunk_metadata.get(field_name):
+                    relevance_score += 2
+
+            for filter_field, filter_value in query_filters.items():
+                if filter_field in chunk_metadata:
+                    metadata_value = chunk_metadata[filter_field]
+
+                    if isinstance(metadata_value, str) and isinstance(filter_value, str):
+                        if filter_value.lower() in metadata_value.lower():
+                            relevance_score += 5
+
+                    elif isinstance(metadata_value, list):
+                        if any(
+                            filter_value.lower() in str(item).lower() for item in metadata_value
+                        ):
+                            relevance_score += 5
+
+                    elif metadata_value == filter_value:
+                        relevance_score += 5
+
+                if isinstance(filter_value, str) and filter_value.lower() in chunk_content:
+                    relevance_score += 3
+
+            if chunk_metadata:
+                relevance_score += len(chunk_metadata) * 0.1
+
+            ranked_chunks.append((chunk, relevance_score))
+
+        ranked_chunks.sort(
+            key=lambda entry: (entry[1], entry[0].get("similarity", 0)), reverse=True
+        )
+
+        return [chunk for chunk, _ in ranked_chunks[:DEFAULT_CHUNKS_RETRIEVAL_LIMIT]]
+
+    def extract_query_metadata(self, query: str) -> dict[str, Any]:
+        """Extract metadata from the user's query using an LLM."""
+        metadata_extractor_prompt = Prompt.get_metadata_extractor_prompt()
+
+        if not metadata_extractor_prompt:
+            error_msg = "Prompt with key 'metadata-extractor-prompt' not found."
+            raise ObjectDoesNotExist(error_msg)
+
+        try:
+            response = self.openai_client.chat.completions.create(
+                model=DEFAULT_REASONING_MODEL,
+                messages=[
+                    {"role": "system", "content": metadata_extractor_prompt},
+                    {"role": "user", "content": f"Query: {query}"},
+                ],
+                max_tokens=500,
+                temperature=0.7,
+            )
+            content = response.choices[0].message.content.strip()
+
+            if "```json" in content:
+                content = content.split("```json")[1].split("```")[0].strip()
+            elif "```" in content:
+                content = content.split("```")[1].split("```")[0].strip()
+
+            return json.loads(content)
+
+        except (openai.OpenAIError, json.JSONDecodeError, ValueError):
+            return {
+                "requested_fields": [],
+                "entity_types": [],
+                "filters": {},
+                "intent": "general query",
+            }
+
+    def call_evaluator(
+        self, *, query: str, answer: str, context_chunks: list[dict[str, Any]]
+    ) -> dict[str, Any]:
+        """Call the evaluator LLM to assess the quality of the generated answer."""
+        formatted_context = self.generator.prepare_context(context_chunks)
+        evaluation_prompt = (
+            f"User Query:\n{query}\n\n"
+            f"Candidate Answer:\n{answer}\n\n"
+            f"Context Provided:\n{formatted_context}\n\n"
+            "Respond with the mandated JSON object."
+        )
+
+        evaluator_system_prompt = Prompt.get_evaluator_system_prompt()
+
+        if not evaluator_system_prompt:
+            error_msg = "Prompt with key 'evaluator-system-prompt' not found."
+            raise ObjectDoesNotExist(error_msg)
+
+        try:
+            response = self.openai_client.chat.completions.create(
+                model=DEFAULT_REASONING_MODEL,
+                messages=[
+                    {"role": "system", "content": evaluator_system_prompt},
+                    {"role": "user", "content": evaluation_prompt},
+                ],
+                max_tokens=2000,
+                temperature=0.7,
+            )
+            content = response.choices[0].message.content.strip()
+
+            if "```json" in content:
+                content = content.split("```json")[1].split("```")[0].strip()
+            elif "```" in content:
+                content = content.split("```")[1].split("```")[0].strip()
+
+            return json.loads(content)
+
+        except (openai.OpenAIError, json.JSONDecodeError, ValueError):
+            return {
+                "complete": False,
+                "feedback": "Evaluator error or invalid response.",
+                "justification": "Evaluator error or invalid response.",
+            }
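
The fallback dictionaries in extract_query_metadata and call_evaluator imply the JSON shapes the two prompts are expected to emit. A hypothetical illustration (all concrete values are invented):

    # Shape returned by the metadata-extractor prompt (values invented):
    extracted_metadata = {
        "requested_fields": ["leaders"],      # metadata keys the user asked about
        "entity_types": ["chapter"],          # forwarded to Retriever as content_types
        "filters": {"country": "India"},      # matched against chunk additional_context
        "intent": "find chapter leaders",
    }

    # Shape returned by the evaluator prompt (values invented):
    evaluation = {
        "complete": False,                    # False routes the graph back to "generate"
        "feedback": "Name the chapter leaders explicitly.",
        "justification": "Missing context about chapter leadership.",
    }

Note that the literal phrase "missing context" in the justification is what triggers the expanded re-retrieval in evaluate (limit doubled, similarity threshold scaled by 0.95).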

backend/apps/ai/agent/tools/rag/generator.py

Lines changed: 1 addition & 11 deletions
@@ -16,7 +16,7 @@ class Generator:
     """Generates answers to user queries based on retrieved context."""
 
     MAX_TOKENS = 2000
-    TEMPERATURE = 0.4
+    TEMPERATURE = 0.8
 
     def __init__(self, chat_model: str = "gpt-4o"):
         """Initialize the Generator.
@@ -73,16 +73,6 @@ def generate_answer(self, query: str, context_chunks: list[dict[str, Any]]) -> s
         formatted_context = self.prepare_context(context_chunks)
 
         user_prompt = f"""
-        You are an assistant for question-answering tasks related to OWASP.
-        Use the following pieces of retrieved context to answer the question.
-        If the question is related to OWASP then you can try to answer based on your knowledge, if you
-        don't know the answer, just say that you don't know.
-        Try to give answer and keep the answer concise, but you really think that the response will be
-        longer and better you will provide more information.
-        Ask for the current location if the query is related to location.
-        Ask for the information you need if the query is very personalized or user-centric.
-        Do not mention or refer to the word "context", "based on context", "provided information",
-        "Information given to me" or similar phrases in your responses.
         Question: {query}
         Context: {formatted_context}
         Answer:
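
With the behavioral instructions stripped from the user prompt (they presumably live in the system prompt after this change), the Generator call surface is unchanged; a minimal sketch, with hypothetical query and chunk contents:

    from apps.ai.agent.tools.rag.generator import Generator

    generator = Generator(chat_model="gpt-4o")  # default model per __init__ above
    answer = generator.generate_answer(
        query="What is the OWASP Top 10?",  # hypothetical query
        context_chunks=[
            {"text": "The OWASP Top 10 is a standard awareness document for developers."},
        ],
    )

The TEMPERATURE bump from 0.4 to 0.8 trades determinism for more varied phrasing across refinement iterations.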
