6 changes: 3 additions & 3 deletions backend/apps/ai/Makefile
@@ -1,6 +1,6 @@
-ai-run-rag-tool:
-	@echo "Running RAG tool"
-	@CMD="python manage.py ai_run_rag_tool" $(MAKE) exec-backend-command
+ai-run-agentic-rag:
+	@echo "Running agentic RAG"
+	@CMD="python manage.py ai_run_agentic_rag" $(MAKE) exec-backend-command
 
 ai-update-chapter-chunks:
 	@echo "Updating chapter chunks"
70 changes: 70 additions & 0 deletions backend/apps/ai/agent/agent.py
@@ -0,0 +1,70 @@
"""LangGraph-powered agent for iterative RAG answering."""

from __future__ import annotations

import logging
from typing import Any

from langgraph.graph import END, START, StateGraph

from apps.ai.agent.nodes import AgentNodes
from apps.ai.common.constants import (
DEFAULT_CHUNKS_RETRIEVAL_LIMIT,
DEFAULT_SIMILARITY_THRESHOLD,
)

logger = logging.getLogger(__name__)


class AgenticRAGAgent:
"""LangGraph-based controller for agentic RAG with self-correcting retrieval."""

def __init__(self) -> None:
"""Initialize the AgenticRAGAgent."""
self.nodes = AgentNodes()
self.graph = self.build_graph()

def run(
self,
query: str,
) -> dict[str, Any]:
"""Execute the full RAG loop."""
initial_state: dict[str, Any] = {
"query": query,
"iteration": 0,
"feedback": None,
"history": [],
"content_types": [],
"limit": DEFAULT_CHUNKS_RETRIEVAL_LIMIT,
"similarity_threshold": DEFAULT_SIMILARITY_THRESHOLD,
}

logger.info("Starting Agentic RAG workflow with metadata-aware retrieval")
final_state = self.graph.invoke(initial_state)

return {
"answer": final_state.get("answer", ""),
"iterations": final_state.get("iteration", 0),
"evaluation": final_state.get("evaluation", {}),
"context_chunks": final_state.get("context_chunks", []),
"history": final_state.get("history", []),
"extracted_metadata": final_state.get("extracted_metadata", {}),
}

def build_graph(self):
"""Build the LangGraph state machine for the RAG workflow."""
graph = StateGraph(dict)
graph.add_node("retrieve", self.nodes.retrieve)
graph.add_node("generate", self.nodes.generate)
graph.add_node("evaluate", self.nodes.evaluate)

graph.add_edge(START, "retrieve")
graph.add_edge("retrieve", "generate")
graph.add_edge("generate", "evaluate")
graph.add_conditional_edges(
"evaluate",
self.nodes.route_from_evaluation,
{"refine": "generate", "complete": END},
)

return graph.compile()
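
A minimal usage sketch, assuming a configured Django environment and the seeded Prompt rows; the ai_run_agentic_rag management command presumably drives the agent roughly like this:

    from apps.ai.agent.agent import AgenticRAGAgent

    agent = AgenticRAGAgent()
    result = agent.run(query="What is the OWASP Top 10?")
    print(result["answer"], result["iterations"])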
269 changes: 269 additions & 0 deletions backend/apps/ai/agent/nodes.py
@@ -0,0 +1,269 @@
"""LangGraph nodes for the Agentic RAG workflow."""

from __future__ import annotations

import json
import os
from typing import Any

import openai
from django.core.exceptions import ObjectDoesNotExist

from apps.ai.agent.tools.rag.generator import Generator
from apps.ai.agent.tools.rag.retriever import Retriever
from apps.ai.common.constants import (
DEFAULT_CHUNKS_RETRIEVAL_LIMIT,
DEFAULT_MAX_ITERATIONS,
DEFAULT_REASONING_MODEL,
DEFAULT_SIMILARITY_THRESHOLD,
)
from apps.core.models.prompt import Prompt


class AgentNodes:
"""Collection of LangGraph node functions with injected dependencies."""

def __init__(self) -> None:
"""Initialize AgentNodes."""
if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")):
error_msg = "DJANGO_OPEN_AI_SECRET_KEY environment variable not set"
raise ValueError(error_msg)

self.openai_client = openai.OpenAI(api_key=openai_api_key)

self.retriever = Retriever()
self.generator = Generator()

    def retrieve(self, state: dict[str, Any]) -> dict[str, Any]:
        """Retrieve context chunks based on the query."""
        if state.get("context_chunks"):
            return state

        limit = state.get("limit", DEFAULT_CHUNKS_RETRIEVAL_LIMIT)
        threshold = state.get("similarity_threshold", DEFAULT_SIMILARITY_THRESHOLD)
        query = state["query"]

        if "extracted_metadata" not in state:
            state["extracted_metadata"] = self.extract_query_metadata(query)

        metadata = state["extracted_metadata"]

        chunks = self.retriever.retrieve(
            query=query,
            limit=limit,
            similarity_threshold=threshold,
            content_types=metadata.get("entity_types"),
        )

        filtered_chunks = self.filter_chunks_by_metadata(chunks, metadata)

        state["context_chunks"] = filtered_chunks[:limit]
        return state

    def generate(self, state: dict[str, Any]) -> dict[str, Any]:
        """Generate an answer using the retrieved context."""
        iteration = state.get("iteration", 0) + 1
        feedback = state.get("feedback")
        query = state["query"]
        augmented_query = (
            query if not feedback else f"{query}\n\nRevise per feedback:\n{feedback}"
        )

        answer = self.generator.generate_answer(
            query=augmented_query,
            context_chunks=state.get("context_chunks", []),
        )

        history = state.get("history", [])
        history.append(
            {
                "iteration": iteration,
                "feedback": feedback,
                "query": augmented_query,
                "answer": answer,
            }
        )

        state.update(
            {"answer": answer, "iteration": iteration, "history": history, "feedback": None}
        )
        return state

    def evaluate(self, state: dict[str, Any]) -> dict[str, Any]:
        """Evaluate the generated answer and decide on the next step."""
        answer = state.get("answer", "")
        evaluation = self.call_evaluator(
            query=state["query"],
            answer=answer,
            context_chunks=state.get("context_chunks", []),
        )

        history = state.get("history", [])
        if history:
            history[-1]["evaluation"] = evaluation

        if "missing context" in evaluation.get("justification", "").lower():
            limit = state.get("limit", DEFAULT_CHUNKS_RETRIEVAL_LIMIT) * 2
            threshold = state.get("similarity_threshold", DEFAULT_SIMILARITY_THRESHOLD) * 0.95

            metadata = state.get("extracted_metadata", {})

            new_chunks = self.retriever.retrieve(
                query=state["query"],
                limit=limit,
                similarity_threshold=threshold,
                content_types=metadata.get("entity_types"),
            )

            filtered_chunks = self.filter_chunks_by_metadata(new_chunks, metadata)
            state["context_chunks"] = filtered_chunks[:limit]

            state["feedback"] = "Expand and refine answer using newly retrieved context."
        else:
            state["feedback"] = evaluation.get("feedback") or None

        state.update({"evaluation": evaluation, "history": history})
        return state

    def route_from_evaluation(self, state: dict[str, Any]) -> str:
        """Route the workflow based on the evaluation result."""
        evaluation = state.get("evaluation") or {}
        iteration = state.get("iteration", 0)
        if evaluation.get("complete") or iteration >= DEFAULT_MAX_ITERATIONS:
            return "complete"
        return "refine"

    def filter_chunks_by_metadata(
        self,
        retrieved_chunks: list[dict[str, Any]],
        query_metadata: dict[str, Any],
    ) -> list[dict[str, Any]]:
        """Rank and filter retrieved chunks using metadata and simple heuristics."""
        if not retrieved_chunks:
            return []

        requested_fields = query_metadata.get("requested_fields", [])
        query_filters = query_metadata.get("filters", {})

        if not requested_fields and not query_filters:
            return retrieved_chunks

        ranked_chunks: list[tuple[dict[str, Any], float]] = []
        for chunk in retrieved_chunks:
            relevance_score = 0.0
            chunk_metadata = chunk.get("additional_context", {})
            chunk_content = chunk.get("text", "").lower()

            for field_name in requested_fields:
                if chunk_metadata.get(field_name):
                    relevance_score += 2

            for filter_field, filter_value in query_filters.items():
                if filter_field in chunk_metadata:
                    metadata_value = chunk_metadata[filter_field]

                    if isinstance(metadata_value, str) and isinstance(filter_value, str):
                        if filter_value.lower() in metadata_value.lower():
                            relevance_score += 5

                    elif isinstance(metadata_value, list):
                        if any(
                            filter_value.lower() in str(item).lower() for item in metadata_value
                        ):
                            relevance_score += 5

                    elif metadata_value == filter_value:
                        relevance_score += 5

                if isinstance(filter_value, str) and filter_value.lower() in chunk_content:
                    relevance_score += 3

            if chunk_metadata:
                relevance_score += len(chunk_metadata) * 0.1

            ranked_chunks.append((chunk, relevance_score))

        ranked_chunks.sort(
            key=lambda entry: (entry[1], entry[0].get("similarity", 0)), reverse=True
        )

        return [chunk for chunk, _ in ranked_chunks[:DEFAULT_CHUNKS_RETRIEVAL_LIMIT]]
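    # Worked example of the scoring above: a chunk whose metadata value contains
    # a filter string (+5), whose text also mentions that string (+3), and which
    # carries two metadata keys (+0.2) scores 8.2; ties between equal scores are
    # broken by the chunk's vector similarity.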

    def extract_query_metadata(self, query: str) -> dict[str, Any]:
        """Extract metadata from the user's query using an LLM."""
        metadata_extractor_prompt = Prompt.get_metadata_extractor_prompt()

        if not metadata_extractor_prompt:
            error_msg = "Prompt with key 'metadata-extractor-prompt' not found."
            raise ObjectDoesNotExist(error_msg)

        try:
            response = self.openai_client.chat.completions.create(
                model=DEFAULT_REASONING_MODEL,
                messages=[
                    {"role": "system", "content": metadata_extractor_prompt},
                    {"role": "user", "content": f"Query: {query}"},
                ],
                max_tokens=500,
                temperature=0.7,
            )
            content = response.choices[0].message.content.strip()

            if "```json" in content:
                content = content.split("```json")[1].split("```")[0].strip()
            elif "```" in content:
                content = content.split("```")[1].split("```")[0].strip()

            return json.loads(content)

        except (openai.OpenAIError, json.JSONDecodeError, ValueError):
            return {
                "requested_fields": [],
                "entity_types": [],
                "filters": {},
                "intent": "general query",
            }
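    # Hypothetical output for "Who leads the OWASP chapter in Hamburg?" (shape
    # inferred from the keys consumed by retrieve and filter_chunks_by_metadata):
    #   {"requested_fields": ["leaders"], "entity_types": ["chapter"],
    #    "filters": {"location": "Hamburg"}, "intent": "find chapter leadership"}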

    def call_evaluator(
        self, *, query: str, answer: str, context_chunks: list[dict[str, Any]]
    ) -> dict[str, Any]:
        """Call the evaluator LLM to assess the quality of the generated answer."""
        formatted_context = self.generator.prepare_context(context_chunks)
        evaluation_prompt = (
            f"User Query:\n{query}\n\n"
            f"Candidate Answer:\n{answer}\n\n"
            f"Context Provided:\n{formatted_context}\n\n"
            "Respond with the mandated JSON object."
        )

        evaluator_system_prompt = Prompt.get_evaluator_system_prompt()

        if not evaluator_system_prompt:
            error_msg = "Prompt with key 'evaluator-system-prompt' not found."
            raise ObjectDoesNotExist(error_msg)

        try:
            response = self.openai_client.chat.completions.create(
                model=DEFAULT_REASONING_MODEL,
                messages=[
                    {"role": "system", "content": evaluator_system_prompt},
                    {"role": "user", "content": evaluation_prompt},
                ],
                max_tokens=2000,
                temperature=0.7,
            )
            content = response.choices[0].message.content.strip()

            if "```json" in content:
                content = content.split("```json")[1].split("```")[0].strip()
            elif "```" in content:
                content = content.split("```")[1].split("```")[0].strip()

            return json.loads(content)

        except (openai.OpenAIError, json.JSONDecodeError, ValueError):
            return {
                "complete": False,
                "feedback": "Evaluator error or invalid response.",
                "justification": "Evaluator error or invalid response.",
            }
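    # The evaluator is expected to return JSON of the shape read by evaluate and
    # route_from_evaluation, e.g. (values here are illustrative):
    #   {"complete": false, "feedback": "Add meeting dates.",
    #    "justification": "missing context: no event data retrieved"}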
12 changes: 1 addition & 11 deletions backend/apps/ai/agent/tools/rag/generator.py
@@ -16,7 +16,7 @@ class Generator:
     """Generates answers to user queries based on retrieved context."""
 
     MAX_TOKENS = 2000
-    TEMPERATURE = 0.4
+    TEMPERATURE = 0.8
 
     def __init__(self, chat_model: str = "gpt-4o"):
         """Initialize the Generator.
@@ -73,16 +73,6 @@ def generate_answer(self, query: str, context_chunks: list[dict[str, Any]]) -> s
         formatted_context = self.prepare_context(context_chunks)
 
         user_prompt = f"""
-        - You are an assistant for question-answering tasks related to OWASP.
-        - Use the following pieces of retrieved context to answer the question.
-        - If the question is related to OWASP then you can try to answer based on your knowledge, if you
-        don't know the answer, just say that you don't know.
-        - Try to give answer and keep the answer concise, but you really think that the response will be
-        longer and better you will provide more information.
-        - Ask for the current location if the query is related to location.
-        - Ask for the information you need if the query is very personalized or user-centric.
-        - Do not mention or refer to the word "context", "based on context", "provided information",
-        "Information given to me" or similar phrases in your responses.
         Question: {query}
         Context: {formatted_context}
         Answer: