"""LangGraph nodes for the Agentic RAG workflow."""

from __future__ import annotations

import json
import os
from typing import Any

import openai
from django.core.exceptions import ObjectDoesNotExist

from apps.ai.agent.tools.rag.generator import Generator
from apps.ai.agent.tools.rag.retriever import Retriever
from apps.ai.common.constants import (
    DEFAULT_CHUNKS_RETRIEVAL_LIMIT,
    DEFAULT_MAX_ITERATIONS,
    DEFAULT_REASONING_MODEL,
    DEFAULT_SIMILARITY_THRESHOLD,
)
from apps.core.models.prompt import Prompt


class AgentNodes:
    """Collection of LangGraph node functions with injected dependencies."""
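
    # Shared state keys used across these nodes (as read/written below):
    #   query, limit, similarity_threshold   -- caller-supplied inputs
    #   extracted_metadata, context_chunks   -- written by retrieve()
    #   answer, iteration, history           -- written by generate()
    #   evaluation, feedback                 -- written by evaluate()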

    def __init__(self) -> None:
        """Initialize AgentNodes."""
        if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")):
            error_msg = "DJANGO_OPEN_AI_SECRET_KEY environment variable not set"
            raise ValueError(error_msg)

        self.openai_client = openai.OpenAI(api_key=openai_api_key)

        self.retriever = Retriever()
        self.generator = Generator()

    def retrieve(self, state: dict[str, Any]) -> dict[str, Any]:
        """Retrieve context chunks based on the query."""
        if state.get("context_chunks"):
            return state

        limit = state.get("limit", DEFAULT_CHUNKS_RETRIEVAL_LIMIT)
        threshold = state.get("similarity_threshold", DEFAULT_SIMILARITY_THRESHOLD)
        query = state["query"]

        if "extracted_metadata" not in state:
            state["extracted_metadata"] = self.extract_query_metadata(query)

        metadata = state["extracted_metadata"]

        chunks = self.retriever.retrieve(
            query=query,
            limit=limit,
            similarity_threshold=threshold,
            content_types=metadata.get("entity_types"),
        )

        filtered_chunks = self.filter_chunks_by_metadata(chunks, metadata)

        state["context_chunks"] = filtered_chunks[:limit]
        return state

    def generate(self, state: dict[str, Any]) -> dict[str, Any]:
        """Generate an answer using the retrieved context."""
        iteration = state.get("iteration", 0) + 1
        feedback = state.get("feedback")
        query = state["query"]
        augmented_query = (
            query if not feedback else f"{query}\n\nRevise per feedback:\n{feedback}"
        )

        answer = self.generator.generate_answer(
            query=augmented_query,
            context_chunks=state.get("context_chunks", []),
        )

        history = state.get("history", [])
        history.append(
            {
                "iteration": iteration,
                "feedback": feedback,
                "query": augmented_query,
                "answer": answer,
            }
        )

        state.update(
            {"answer": answer, "iteration": iteration, "history": history, "feedback": None}
        )
        return state

    def evaluate(self, state: dict[str, Any]) -> dict[str, Any]:
        """Evaluate the generated answer and decide on the next step."""
        answer = state.get("answer", "")
        evaluation = self.call_evaluator(
            query=state["query"],
            answer=answer,
            context_chunks=state.get("context_chunks", []),
        )

        history = state.get("history", [])
        if history:
            history[-1]["evaluation"] = evaluation

        if "missing context" in evaluation.get("justification", "").lower():
            # Evaluator flagged missing context: broaden retrieval (more chunks,
            # slightly lower similarity threshold) before the next refinement pass.
            limit = state.get("limit", DEFAULT_CHUNKS_RETRIEVAL_LIMIT) * 2
            threshold = state.get("similarity_threshold", DEFAULT_SIMILARITY_THRESHOLD) * 0.95

            metadata = state.get("extracted_metadata", {})

            new_chunks = self.retriever.retrieve(
                query=state["query"],
                limit=limit,
                similarity_threshold=threshold,
                content_types=metadata.get("entity_types"),
            )

            filtered_chunks = self.filter_chunks_by_metadata(new_chunks, metadata)
            state["context_chunks"] = filtered_chunks[:limit]

            state["feedback"] = "Expand and refine answer using newly retrieved context."
        else:
            state["feedback"] = evaluation.get("feedback") or None

        state.update({"evaluation": evaluation, "history": history})
        return state

    def route_from_evaluation(self, state: dict[str, Any]) -> str:
        """Route the workflow based on the evaluation result."""
        evaluation = state.get("evaluation") or {}
        iteration = state.get("iteration", 0)
        if evaluation.get("complete") or iteration >= DEFAULT_MAX_ITERATIONS:
            return "complete"
        return "refine"

    def filter_chunks_by_metadata(
        self,
        retrieved_chunks: list[dict[str, Any]],
        query_metadata: dict[str, Any],
    ) -> list[dict[str, Any]]:
        """Rank and filter retrieved chunks using metadata and simple heuristics."""
        if not retrieved_chunks:
            return []

        requested_fields = query_metadata.get("requested_fields", [])
        query_filters = query_metadata.get("filters", {})

        if not requested_fields and not query_filters:
            return retrieved_chunks

        ranked_chunks: list[tuple[dict[str, Any], float]] = []
        for chunk in retrieved_chunks:
            relevance_score = 0.0
            chunk_metadata = chunk.get("additional_context", {})
            chunk_content = chunk.get("text", "").lower()

            # +2 for each requested field present in the chunk's metadata.
            for field_name in requested_fields:
                if chunk_metadata.get(field_name):
                    relevance_score += 2

            # +5 for each query filter matched against the chunk's metadata.
            for filter_field, filter_value in query_filters.items():
                if filter_field in chunk_metadata:
                    metadata_value = chunk_metadata[filter_field]

                    if isinstance(metadata_value, str) and isinstance(filter_value, str):
                        if filter_value.lower() in metadata_value.lower():
                            relevance_score += 5

                    elif isinstance(metadata_value, list):
                        if any(
                            str(filter_value).lower() in str(item).lower()
                            for item in metadata_value
                        ):
                            relevance_score += 5

                    elif metadata_value == filter_value:
                        relevance_score += 5

                # +3 if the filter value is mentioned in the chunk text itself.
                if isinstance(filter_value, str) and filter_value.lower() in chunk_content:
                    relevance_score += 3

            # Small bonus for metadata-rich chunks, mainly as a tie-breaker.
            if chunk_metadata:
                relevance_score += len(chunk_metadata) * 0.1

            ranked_chunks.append((chunk, relevance_score))

        # Sort by heuristic score first, then by embedding similarity.
        ranked_chunks.sort(
            key=lambda entry: (entry[1], entry[0].get("similarity", 0)), reverse=True
        )

        return [chunk for chunk, _ in ranked_chunks[:DEFAULT_CHUNKS_RETRIEVAL_LIMIT]]

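    # Illustrative (hypothetical) scoring example for filter_chunks_by_metadata above:
    # with requested_fields ["leaders"] and filters {"level": "flagship"}, a chunk whose
    # additional_context has a "leaders" entry (+2), whose "level" metadata contains
    # "flagship" (+5), and whose text mentions "flagship" (+3) scores 10 plus a small
    # metadata-richness bonus, so it ranks ahead of chunks scored by similarity alone.
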
    def extract_query_metadata(self, query: str) -> dict[str, Any]:
        """Extract metadata from the user's query using an LLM."""
        metadata_extractor_prompt = Prompt.get_metadata_extractor_prompt()

        if not metadata_extractor_prompt:
            error_msg = "Prompt with key 'metadata-extractor-prompt' not found."
            raise ObjectDoesNotExist(error_msg)

        try:
            response = self.openai_client.chat.completions.create(
                model=DEFAULT_REASONING_MODEL,
                messages=[
                    {"role": "system", "content": metadata_extractor_prompt},
                    {"role": "user", "content": f"Query: {query}"},
                ],
                max_tokens=500,
                temperature=0.7,
            )
            content = response.choices[0].message.content.strip()

            # Strip optional markdown code fences around the JSON payload.
            if "```json" in content:
                content = content.split("```json")[1].split("```")[0].strip()
            elif "```" in content:
                content = content.split("```")[1].split("```")[0].strip()

            return json.loads(content)

        except (openai.OpenAIError, json.JSONDecodeError, ValueError):
            # Fall back to an empty metadata structure on any extraction failure.
            return {
                "requested_fields": [],
                "entity_types": [],
                "filters": {},
                "intent": "general query",
            }

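    # Illustrative (hypothetical) shape of the metadata returned by
    # extract_query_metadata above, mirroring the fallback structure; actual keys and
    # values are determined by the metadata-extractor prompt:
    #   {
    #       "requested_fields": ["leaders"],
    #       "entity_types": ["project"],
    #       "filters": {"level": "flagship"},
    #       "intent": "find flagship projects and their leaders",
    #   }
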
    def call_evaluator(
        self, *, query: str, answer: str, context_chunks: list[dict[str, Any]]
    ) -> dict[str, Any]:
        """Call the evaluator LLM to assess the quality of the generated answer."""
        formatted_context = self.generator.prepare_context(context_chunks)
        evaluation_prompt = (
            f"User Query:\n{query}\n\n"
            f"Candidate Answer:\n{answer}\n\n"
            f"Context Provided:\n{formatted_context}\n\n"
            "Respond with the mandated JSON object."
        )

        evaluator_system_prompt = Prompt.get_evaluator_system_prompt()

        if not evaluator_system_prompt:
            error_msg = "Prompt with key 'evaluator-system-prompt' not found."
            raise ObjectDoesNotExist(error_msg)

        try:
            response = self.openai_client.chat.completions.create(
                model=DEFAULT_REASONING_MODEL,
                messages=[
                    {"role": "system", "content": evaluator_system_prompt},
                    {"role": "user", "content": evaluation_prompt},
                ],
                max_tokens=2000,
                temperature=0.7,
            )
            content = response.choices[0].message.content.strip()

            # Strip optional markdown code fences around the JSON payload.
            if "```json" in content:
                content = content.split("```json")[1].split("```")[0].strip()
            elif "```" in content:
                content = content.split("```")[1].split("```")[0].strip()

            return json.loads(content)

        except (openai.OpenAIError, json.JSONDecodeError, ValueError):
            # Treat evaluator failures as "not complete" so the workflow can retry.
            return {
                "complete": False,
                "feedback": "Evaluator error or invalid response.",
                "justification": "Evaluator error or invalid response.",
            }
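

# A minimal wiring sketch (illustrative only, not part of this module): how these node
# functions could be composed into a LangGraph StateGraph. The `langgraph.graph` import,
# the node/edge names, the "refine" -> "generate" mapping, and the plain-dict state
# schema are assumptions for illustration; the actual graph construction lives elsewhere.
#
#   from langgraph.graph import END, StateGraph
#
#   nodes = AgentNodes()
#   graph = StateGraph(dict)
#   graph.add_node("retrieve", nodes.retrieve)
#   graph.add_node("generate", nodes.generate)
#   graph.add_node("evaluate", nodes.evaluate)
#   graph.set_entry_point("retrieve")
#   graph.add_edge("retrieve", "generate")
#   graph.add_edge("generate", "evaluate")
#   graph.add_conditional_edges(
#       "evaluate",
#       nodes.route_from_evaluation,
#       {"refine": "generate", "complete": END},
#   )
#   result = graph.compile().invoke({"query": "..."})
#   print(result["answer"])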