From bff36bddf372c37d57c71017a86fb100290d6dd2 Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Tue, 22 Jul 2025 22:48:04 +0530 Subject: [PATCH 01/32] rag tool for agent --- backend/apps/ai/Makefile | 4 + backend/apps/ai/agent/__init__.py | 0 backend/apps/ai/agent/tools/RAG/__init__.py | 0 backend/apps/ai/agent/tools/RAG/generator.py | 118 ++++++++ backend/apps/ai/agent/tools/RAG/rag_tool.py | 81 ++++++ backend/apps/ai/agent/tools/RAG/retriever.py | 265 ++++++++++++++++++ backend/apps/ai/agent/tools/__init__.py | 0 backend/apps/ai/common/constants.py | 2 + .../ai/management/commands/ai_run_rag_tool.py | 65 +++++ 9 files changed, 535 insertions(+) create mode 100644 backend/apps/ai/agent/__init__.py create mode 100644 backend/apps/ai/agent/tools/RAG/__init__.py create mode 100644 backend/apps/ai/agent/tools/RAG/generator.py create mode 100644 backend/apps/ai/agent/tools/RAG/rag_tool.py create mode 100644 backend/apps/ai/agent/tools/RAG/retriever.py create mode 100644 backend/apps/ai/agent/tools/__init__.py create mode 100644 backend/apps/ai/management/commands/ai_run_rag_tool.py diff --git a/backend/apps/ai/Makefile b/backend/apps/ai/Makefile index a873ec69d4..cff4221abe 100644 --- a/backend/apps/ai/Makefile +++ b/backend/apps/ai/Makefile @@ -17,3 +17,7 @@ ai-create-project-chunks: ai-create-slack-message-chunks: @echo "Creating Slack message chunks" @CMD="python manage.py ai_create_slack_message_chunks" $(MAKE) exec-backend-command + +ai-run-rag-tool: + @echo "Running RAG tool" + @CMD="python manage.py ai_run_rag_tool" $(MAKE) exec-backend-command diff --git a/backend/apps/ai/agent/__init__.py b/backend/apps/ai/agent/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/apps/ai/agent/tools/RAG/__init__.py b/backend/apps/ai/agent/tools/RAG/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/apps/ai/agent/tools/RAG/generator.py b/backend/apps/ai/agent/tools/RAG/generator.py new file mode 100644 index 0000000000..c49a3d2eba --- /dev/null +++ b/backend/apps/ai/agent/tools/RAG/generator.py @@ -0,0 +1,118 @@ +"""Generator for the RAG system.""" + +import logging +import os +from typing import Any + +import openai + +logger = logging.getLogger(__name__) + + +class Generator: + """Generates answers to user queries based on retrieved context.""" + + MAX_TOKENS = 2000 + SYSTEM_PROMPT = """ +You are a helpful and professional AI assistant for the OWASP Foundation. +Your task is to answer user queries based ONLY on the provided context. +Follow these rules strictly: +1. Base your entire answer on the information given in the "CONTEXT" section. Do not use any +external knowledge unless and until it is about OWASP. +2. Do not mention or refer to the word "context", "based on context", "provided information", +"Information given to me" or similar phrases in your responses. +3. you will answer questions only related to OWASP and within the scope of OWASP. +4. Be concise and directly answer the user's query. +5. Provide the necessary link if the context contains a URL. +6. If there is any query based on location, you need to look for latitude and longitude in the +context and provide the nearest OWASP chapter based on that. +7. You can ask for more information if the query is very personalized or user-centric. +8. after trying all of the above, If the context does not contain the information or you think that +it is out of scope for OWASP, you MUST state: "please ask question related to OWASP." +""" + TEMPERATURE = 0.4 + + def __init__(self, chat_model: str = "gpt-4o"): + """Initialize the Generator. + + Args: + chat_model (str): The name of the OpenAI chat model to use for generation. + + Raises: + ValueError: If the OpenAI API key is not set. + + """ + if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): + error_msg = "DJANGO_OPEN_AI_SECRET_KEY environment variable not set" + raise ValueError(error_msg) + + self.chat_model = chat_model + self.openai_client = openai.OpenAI(api_key=openai_api_key) + logger.info("Generator initialized with chat model: %s", self.chat_model) + + def format_context_for_prompt(self, context_chunks: list[dict[str, Any]]) -> str: + """Format the list of retrieved context chunks into a single string for the LLM. + + Args: + context_chunks: A list of chunk dictionaries from the retriever. + + Returns: + A formatted string containing the context. + + """ + if not context_chunks: + return "No context provided" + + formatted_context = [] + for i, chunk in enumerate(context_chunks): + source_name = chunk.get("source_name", f"Unknown Source {i + 1}") + text = chunk.get("text", "") + + context_block = f"Source Name: {source_name}\nContent: {text}" + formatted_context.append(context_block) + + return "\n\n---\n\n".join(formatted_context) + + def generate(self, query: str, context_chunks: list[dict[str, Any]]) -> dict[str, Any]: + """Generate an answer to the user's query using provided context chunks. + + Args: + query: The user's query text. + context_chunks: A list of context chunks retrieved by the retriever. + + Returns: + A dictionary containing the generated answer. + + """ + formatted_context = self.format_context_for_prompt(context_chunks) + + user_prompt = f""" +- You are an assistant for question-answering tasks related to OWASP. +- Use the following pieces of retrieved context to answer the question. +- If the question is related to OWASP then you can try to answer based on your knowledge, if you +don't know the answer, just say that you don't know. +- Try to give answer and keep the answer concise, but you really think that the response will be +longer and better you will provide more information. +- Ask for the current location if the query is related to location. +- Ask for the information you need if the query is very personalized or user-centric. +- Do not mention or refer to the word "context", "based on context", "provided information", +"Information given to me" or similar phrases in your responses. +Question: {query} +Context: {formatted_context} +Answer: +""" + + response = self.openai_client.chat.completions.create( + model=self.chat_model, + messages=[ + {"role": "system", "content": self.SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt}, + ], + temperature=self.TEMPERATURE, + max_tokens=self.MAX_TOKENS, + ) + final_answer = response.choices[0].message.content.strip() + + return { + "answer": final_answer, + } diff --git a/backend/apps/ai/agent/tools/RAG/rag_tool.py b/backend/apps/ai/agent/tools/RAG/rag_tool.py new file mode 100644 index 0000000000..6669ff3708 --- /dev/null +++ b/backend/apps/ai/agent/tools/RAG/rag_tool.py @@ -0,0 +1,81 @@ +"""A tool for orchestrating the components of RAG process.""" + +import logging +from typing import Any + +from apps.ai.common.constants import DEFAULT_LIMIT, DEFAULT_SIMILARITY_THRESHOLD + +from .generator import Generator +from .retriever import Retriever + +logger = logging.getLogger(__name__) + + +class RAGTool: + """Main RAG tool that orchestrates the retrieval and generation process.""" + + def __init__( + self, embedding_model: str = "text-embedding-3-small", chat_model: str = "gpt-4o" + ): + """Initialize the RAG tool. + + Args: + embedding_model (str, optional): The model to use for embeddings". + chat_model (str, optional): The model to use for chat generation. + + Raises: + ValueError: If the OpenAI API key is not set. + + """ + try: + self.retriever = Retriever(embedding_model=embedding_model) + self.generator = Generator(chat_model=chat_model) + logger.info( + "RAG Service initialized with embedding model: %s, chat model: %s", + embedding_model, + chat_model, + ) + except Exception: + logger.exception("Failed to initialize RAG Service") + + def query( + self, + question: str, + limit: int = DEFAULT_LIMIT, + similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD, + content_types: list[str] | None = None, + ) -> dict[str, Any]: + """Process a user query using the complete RAG pipeline. + + Args: + question (str): The user's question. + limit (int): Maximum number of context chunks to retrieve. + similarity_threshold (float): Minimum similarity score for retrieval. + content_types (Optional[list[str]]): Content types to filter by. + + Returns: + dict[str, Any]: A dictionary containing: + - answer (str): The generated answer + - sources (list): Source information used for the answer + - metadata (dict): Additional metadata about the query processing + + """ + logger.info("Retrieving context for query") + retrieved_chunks = self.retriever.retrieve( + query=question, + limit=limit, + similarity_threshold=similarity_threshold, + content_types=content_types, + ) + + generation_result = self.generator.generate( + query=question, context_chunks=retrieved_chunks + ) + + result = { + "answer": generation_result["answer"], + } + + logger.info("Successfully processed RAG query") + + return result diff --git a/backend/apps/ai/agent/tools/RAG/retriever.py b/backend/apps/ai/agent/tools/RAG/retriever.py new file mode 100644 index 0000000000..64123d7376 --- /dev/null +++ b/backend/apps/ai/agent/tools/RAG/retriever.py @@ -0,0 +1,265 @@ +"""Context retriever for RAG.""" + +import logging +import os +import re +from typing import Any + +import openai +from django.db.models import Q +from pgvector.django.functions import CosineDistance + +from apps.ai.common.constants import DEFAULT_LIMIT, DEFAULT_SIMILARITY_THRESHOLD +from apps.ai.models.chunk import Chunk + +logger = logging.getLogger(__name__) + + +class Retriever: + """A class for retrieving relevant text chunks for a RAG.""" + + SUPPORTED_CONTENT_TYPES = ["event", "project", "chapter", "committee", "message"] + + def __init__(self, embedding_model: str = "text-embedding-3-small"): + """Initialize the Retriever. + + Args: + embedding_model (str, optional): The OpenAI embedding model to use". + + Raises: + ValueError: If the OpenAI API key is not set. + + """ + if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): + error_msg = "DJANGO_OPEN_AI_SECRET_KEY environment variable not set" + raise ValueError(error_msg) + + self.openai_client = openai.OpenAI(api_key=openai_api_key) + self.embedding_model = embedding_model + logger.info("Retriever initialized with embedding model: %s", self.embedding_model) + + def get_query_embedding(self, query: str) -> list[float]: + """Generate embedding for the user query. + + Args: + query: The query text. + + Returns: + A list of floats representing the query embedding. + + """ + try: + response = self.openai_client.embeddings.create( + input=[query], + model=self.embedding_model, + ) + return response.data[0].embedding + except openai.error.OpenAIError: + logger.exception("OpenAI API error") + raise + except Exception: + logger.exception("Unexpected error while generating embedding") + raise + + def get_source_name(self, content_object) -> str: + """Get the name/identifier for the content object.""" + for attr in ["name", "title", "login", "key", "summary"]: + if hasattr(content_object, attr) and getattr(content_object, attr): + return str(getattr(content_object, attr)) + + return str(content_object) + + def get_additional_context(self, content_object, content_type: str) -> dict[str, Any]: + """Get additional context information based on content type. + + Args: + content_object: The source object. + content_type: The model name of the content object. + + Returns: + A dictionary with additional context information. + + """ + context = {} + clean_content_type = content_type.split(".")[-1] if "." in content_type else content_type + + if clean_content_type == "chapter": + context.update( + { + "location": getattr(content_object, "suggested_location", None), + "region": getattr(content_object, "region", None), + "country": getattr(content_object, "country", None), + "postal_code": getattr(content_object, "postal_code", None), + "currency": getattr(content_object, "currency", None), + "meetup_group": getattr(content_object, "meetup_group", None), + "tags": getattr(content_object, "tags", []), + "topics": getattr(content_object, "topics", []), + "leaders": getattr(content_object, "leaders_raw", []), + "related_urls": getattr(content_object, "related_urls", []), + "is_active": getattr(content_object, "is_active", None), + "url": getattr(content_object, "url", None), + } + ) + elif clean_content_type == "project": + context.update( + { + "level": getattr(content_object, "level", None), + "project_type": getattr(content_object, "type", None), + "languages": getattr(content_object, "languages", []), + "topics": getattr(content_object, "topics", []), + "licenses": getattr(content_object, "licenses", []), + "tags": getattr(content_object, "tags", []), + "custom_tags": getattr(content_object, "custom_tags", []), + "stars_count": getattr(content_object, "stars_count", None), + "forks_count": getattr(content_object, "forks_count", None), + "contributors_count": getattr(content_object, "contributors_count", None), + "releases_count": getattr(content_object, "releases_count", None), + "open_issues_count": getattr(content_object, "open_issues_count", None), + "leaders": getattr(content_object, "leaders_raw", []), + "related_urls": getattr(content_object, "related_urls", []), + "created_at": getattr(content_object, "created_at", None), + "updated_at": getattr(content_object, "updated_at", None), + "released_at": getattr(content_object, "released_at", None), + "health_score": getattr(content_object, "health_score", None), + "is_active": getattr(content_object, "is_active", None), + "track_issues": getattr(content_object, "track_issues", None), + "url": getattr(content_object, "url", None), + } + ) + elif clean_content_type == "event": + context.update( + { + "start_date": getattr(content_object, "start_date", None), + "end_date": getattr(content_object, "end_date", None), + "location": getattr(content_object, "suggested_location", None), + "category": getattr(content_object, "category", None), + "latitude": getattr(content_object, "latitude", None), + "longitude": getattr(content_object, "longitude", None), + "url": getattr(content_object, "url", None), + "description": getattr(content_object, "description", None), + "summary": getattr(content_object, "summary", None), + } + ) + elif clean_content_type == "committee": + context.update( + { + "is_active": getattr(content_object, "is_active", None), + "leaders": getattr(content_object, "leaders", []), + "url": getattr(content_object, "url", None), + "description": getattr(content_object, "description", None), + "summary": getattr(content_object, "summary", None), + "tags": getattr(content_object, "tags", []), + "topics": getattr(content_object, "topics", []), + "related_urls": getattr(content_object, "related_urls", []), + } + ) + elif clean_content_type == "message": + context.update( + { + "channel": getattr(content_object.conversation, "slack_channel_id", None), + "thread_ts": getattr(content_object.parent_message, "ts", None), + "ts": getattr(content_object, "ts", None), + "user": getattr(content_object.author, "name", None), + "text": getattr(content_object, "text", None), + "attachments": getattr(content_object.raw_data, "attachments", []), + "url": getattr(content_object, "url", None), + } + ) + + return {k: v for k, v in context.items() if v is not None} + + def retrieve( + self, + query: str, + limit: int = DEFAULT_LIMIT, + similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD, + content_types: list[str] | None = None, + ) -> list[dict[str, Any]]: + """Retrieve the most relevant chunks based on vector similarity. + + Args: + query: The user's query text. + limit: The maximum number of chunks to retrieve. + similarity_threshold: The minimum similarity score (0-1). + content_types: An optional list of content types to filter by. + + Returns: + A list of dictionaries, each containing chunk text and rich metadata. + + """ + query_embedding = self.get_query_embedding(query) + + final_content_types = content_types + if not final_content_types: + final_content_types = self.extract_content_types_from_query(query) + + queryset = Chunk.objects.annotate( + similarity=1 - CosineDistance("embedding", query_embedding) + ).filter(similarity__gte=similarity_threshold) + + if final_content_types: + content_type_query = Q() + for name in final_content_types: + lower_name = name.lower() + if "." in lower_name: + app_label, model = lower_name.split(".", 1) + content_type_query |= Q( + content_type__app_label=app_label, content_type__model=model + ) + else: + content_type_query |= Q(content_type__model=lower_name) + queryset = queryset.filter(content_type_query) + + chunks = ( + queryset.select_related("content_type") + .prefetch_related("content_object") + .order_by("-similarity")[:limit] + ) + + results = [] + for chunk in chunks: + if not chunk.content_object: + logger.warning("Content object is None for chunk %s. Skipping.", chunk.id) + continue + + source_name = self.get_source_name(chunk.content_object) + additional_context = self.get_additional_context( + chunk.content_object, chunk.content_type.model + ) + + results.append( + { + "text": chunk.text, + "similarity": float(chunk.similarity), + "source_type": chunk.content_type.model, + "source_name": source_name, + "source_id": chunk.object_id, + "additional_context": additional_context, + } + ) + + return results + + def extract_content_types_from_query(self, query: str) -> list[str]: + """Scan the query for keywords matching supported content types. + + Args: + query: The user's query text. + + Returns: + A list of detected content type names. + + """ + detected_types = [] + query_words = set(re.findall(r"\b\w+\b", query.lower())) + + detected_types = [ + content_type + for content_type in self.SUPPORTED_CONTENT_TYPES + if content_type in query_words or f"{content_type}s" in query_words + ] + + if detected_types: + logger.info("Detected content type keywords in query: %s", detected_types) + + return detected_types diff --git a/backend/apps/ai/agent/tools/__init__.py b/backend/apps/ai/agent/tools/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/apps/ai/common/constants.py b/backend/apps/ai/common/constants.py index 98aa2a5a4e..925ba434f6 100644 --- a/backend/apps/ai/common/constants.py +++ b/backend/apps/ai/common/constants.py @@ -1,5 +1,7 @@ """AI app constants.""" DEFAULT_LAST_REQUEST_OFFSET_SECONDS = 2 +DEFAULT_LIMIT = 5 +DEFAULT_SIMILARITY_THRESHOLD = 0.5 DELIMITER = "\n\n" MIN_REQUEST_INTERVAL_SECONDS = 1.2 diff --git a/backend/apps/ai/management/commands/ai_run_rag_tool.py b/backend/apps/ai/management/commands/ai_run_rag_tool.py new file mode 100644 index 0000000000..577e4602d3 --- /dev/null +++ b/backend/apps/ai/management/commands/ai_run_rag_tool.py @@ -0,0 +1,65 @@ +"""A command for invoking RAG tool.""" + +from django.core.management.base import BaseCommand + +from apps.ai.agent.tools.RAG.rag_tool import RAGTool + + +class Command(BaseCommand): + help = "Test the RAGTool functionality with a sample query" + + def add_arguments(self, parser): + parser.add_argument( + "--query", + type=str, + default="What is OWASP Foundation?", + help="Query to test the RAGService", + ) + parser.add_argument( + "--limit", type=int, default=3, help="Maximum number of results to retrieve" + ) + parser.add_argument( + "--threshold", + type=float, + default=0.5, + help="Similarity threshold (0.0 to 1.0)", + ) + parser.add_argument( + "--content-types", + nargs="+", + default=None, + help="Content types to filter by (e.g., project chapter)", + ) + parser.add_argument( + "--embedding-model", + type=str, + default="text-embedding-3-small", + help="OpenAI embedding model", + ) + parser.add_argument( + "--chat-model", + type=str, + default="gpt-4o", + help="OpenAI chat model", + ) + + def handle(self, *args, **options): + rag_tool = RAGTool( + embedding_model=options["embedding_model"], + chat_model=options["chat_model"], + ) + + query = options["query"] + limit = options["limit"] + threshold = options["threshold"] + content_types = options["content_types"] + + self.stdout.write("\nProcessing query...") + result = rag_tool.query( + question=query, + limit=limit, + similarity_threshold=threshold, + content_types=content_types, + ) + + self.stdout.write(f"\nAnswer: {result['answer']}") From f254af8bd0a94a93eb44f113c817a619d92bd60f Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Tue, 22 Jul 2025 23:40:08 +0530 Subject: [PATCH 02/32] code rabbit suggestions implemented --- backend/apps/ai/agent/tools/RAG/generator.py | 26 ++++++++++++------- backend/apps/ai/agent/tools/RAG/rag_tool.py | 10 ++----- backend/apps/ai/agent/tools/RAG/retriever.py | 24 ++++++++++++----- .../ai/management/commands/ai_run_rag_tool.py | 23 ++++++++++------ 4 files changed, 50 insertions(+), 33 deletions(-) diff --git a/backend/apps/ai/agent/tools/RAG/generator.py b/backend/apps/ai/agent/tools/RAG/generator.py index c49a3d2eba..ae9d5ac743 100644 --- a/backend/apps/ai/agent/tools/RAG/generator.py +++ b/backend/apps/ai/agent/tools/RAG/generator.py @@ -102,16 +102,22 @@ def generate(self, query: str, context_chunks: list[dict[str, Any]]) -> dict[str Answer: """ - response = self.openai_client.chat.completions.create( - model=self.chat_model, - messages=[ - {"role": "system", "content": self.SYSTEM_PROMPT}, - {"role": "user", "content": user_prompt}, - ], - temperature=self.TEMPERATURE, - max_tokens=self.MAX_TOKENS, - ) - final_answer = response.choices[0].message.content.strip() + try: + response = self.openai_client.chat.completions.create( + model=self.chat_model, + messages=[ + {"role": "system", "content": self.SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt}, + ], + temperature=self.TEMPERATURE, + max_tokens=self.MAX_TOKENS, + ) + final_answer = response.choices[0].message.content.strip() + except openai.OpenAIError: + logger.exception("OpenAI API error") + return { + "answer": "I'm sorry, I'm currently unable to process your request.", + } return { "answer": final_answer, diff --git a/backend/apps/ai/agent/tools/RAG/rag_tool.py b/backend/apps/ai/agent/tools/RAG/rag_tool.py index 6669ff3708..0d0237fb17 100644 --- a/backend/apps/ai/agent/tools/RAG/rag_tool.py +++ b/backend/apps/ai/agent/tools/RAG/rag_tool.py @@ -30,13 +30,9 @@ def __init__( try: self.retriever = Retriever(embedding_model=embedding_model) self.generator = Generator(chat_model=chat_model) - logger.info( - "RAG Service initialized with embedding model: %s, chat model: %s", - embedding_model, - chat_model, - ) except Exception: - logger.exception("Failed to initialize RAG Service") + logger.exception("Failed to initialize RAG tool") + raise def query( self, @@ -56,8 +52,6 @@ def query( Returns: dict[str, Any]: A dictionary containing: - answer (str): The generated answer - - sources (list): Source information used for the answer - - metadata (dict): Additional metadata about the query processing """ logger.info("Retrieving context for query") diff --git a/backend/apps/ai/agent/tools/RAG/retriever.py b/backend/apps/ai/agent/tools/RAG/retriever.py index 64123d7376..01a5347dec 100644 --- a/backend/apps/ai/agent/tools/RAG/retriever.py +++ b/backend/apps/ai/agent/tools/RAG/retriever.py @@ -24,7 +24,7 @@ def __init__(self, embedding_model: str = "text-embedding-3-small"): """Initialize the Retriever. Args: - embedding_model (str, optional): The OpenAI embedding model to use". + embedding_model (str, optional): The OpenAI embedding model to use" Raises: ValueError: If the OpenAI API key is not set. @@ -156,13 +156,23 @@ def get_additional_context(self, content_object, content_type: str) -> dict[str, elif clean_content_type == "message": context.update( { - "channel": getattr(content_object.conversation, "slack_channel_id", None), - "thread_ts": getattr(content_object.parent_message, "ts", None), + "channel": ( + getattr(content_object.conversation, "slack_channel_id", None) + if hasattr(content_object, "conversation") and content_object.conversation + else None + ), + "thread_ts": ( + getattr(content_object.parent_message, "ts", None) + if hasattr(content_object, "parent_message") + and content_object.parent_message + else None + ), "ts": getattr(content_object, "ts", None), - "user": getattr(content_object.author, "name", None), - "text": getattr(content_object, "text", None), - "attachments": getattr(content_object.raw_data, "attachments", []), - "url": getattr(content_object, "url", None), + "user": ( + getattr(content_object.author, "name", None) + if hasattr(content_object, "author") and content_object.author + else None + ), } ) diff --git a/backend/apps/ai/management/commands/ai_run_rag_tool.py b/backend/apps/ai/management/commands/ai_run_rag_tool.py index 577e4602d3..ad5c0d8eda 100644 --- a/backend/apps/ai/management/commands/ai_run_rag_tool.py +++ b/backend/apps/ai/management/commands/ai_run_rag_tool.py @@ -3,6 +3,7 @@ from django.core.management.base import BaseCommand from apps.ai.agent.tools.RAG.rag_tool import RAGTool +from apps.ai.common.constants import DEFAULT_LIMIT, DEFAULT_SIMILARITY_THRESHOLD class Command(BaseCommand): @@ -13,15 +14,18 @@ def add_arguments(self, parser): "--query", type=str, default="What is OWASP Foundation?", - help="Query to test the RAGService", + help="Query to test the RAG tool", ) parser.add_argument( - "--limit", type=int, default=3, help="Maximum number of results to retrieve" + "--limit", + type=int, + default=DEFAULT_LIMIT, + help="Maximum number of results to retrieve", ) parser.add_argument( "--threshold", type=float, - default=0.5, + default=DEFAULT_SIMILARITY_THRESHOLD, help="Similarity threshold (0.0 to 1.0)", ) parser.add_argument( @@ -44,10 +48,14 @@ def add_arguments(self, parser): ) def handle(self, *args, **options): - rag_tool = RAGTool( - embedding_model=options["embedding_model"], - chat_model=options["chat_model"], - ) + try: + rag_tool = RAGTool( + embedding_model=options["embedding_model"], + chat_model=options["chat_model"], + ) + except ValueError: + self.stderr.write(self.style.ERROR("Initialization error")) + return query = options["query"] limit = options["limit"] @@ -61,5 +69,4 @@ def handle(self, *args, **options): similarity_threshold=threshold, content_types=content_types, ) - self.stdout.write(f"\nAnswer: {result['answer']}") From ff45de193b12b4e735f99b9e166d244561d8e4c2 Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Fri, 25 Jul 2025 20:51:53 +0530 Subject: [PATCH 03/32] suggestions implemented --- .../ai/agent/tools/{RAG => rag}/__init__.py | 0 .../ai/agent/tools/{RAG => rag}/generator.py | 16 ++++------ .../ai/agent/tools/{RAG => rag}/rag_tool.py | 18 +++++------ .../ai/agent/tools/{RAG => rag}/retriever.py | 20 +++++++------ backend/apps/ai/common/constants.py | 2 +- .../ai/management/commands/ai_run_rag_tool.py | 30 +++++++++---------- 6 files changed, 40 insertions(+), 46 deletions(-) rename backend/apps/ai/agent/tools/{RAG => rag}/__init__.py (100%) rename backend/apps/ai/agent/tools/{RAG => rag}/generator.py (89%) rename backend/apps/ai/agent/tools/{RAG => rag}/rag_tool.py (83%) rename backend/apps/ai/agent/tools/{RAG => rag}/retriever.py (95%) diff --git a/backend/apps/ai/agent/tools/RAG/__init__.py b/backend/apps/ai/agent/tools/rag/__init__.py similarity index 100% rename from backend/apps/ai/agent/tools/RAG/__init__.py rename to backend/apps/ai/agent/tools/rag/__init__.py diff --git a/backend/apps/ai/agent/tools/RAG/generator.py b/backend/apps/ai/agent/tools/rag/generator.py similarity index 89% rename from backend/apps/ai/agent/tools/RAG/generator.py rename to backend/apps/ai/agent/tools/rag/generator.py index ae9d5ac743..07cf9b9d8f 100644 --- a/backend/apps/ai/agent/tools/RAG/generator.py +++ b/backend/apps/ai/agent/tools/rag/generator.py @@ -50,7 +50,7 @@ def __init__(self, chat_model: str = "gpt-4o"): self.openai_client = openai.OpenAI(api_key=openai_api_key) logger.info("Generator initialized with chat model: %s", self.chat_model) - def format_context_for_prompt(self, context_chunks: list[dict[str, Any]]) -> str: + def prepare_context(self, context_chunks: list[dict[str, Any]]) -> str: """Format the list of retrieved context chunks into a single string for the LLM. Args: @@ -73,7 +73,7 @@ def format_context_for_prompt(self, context_chunks: list[dict[str, Any]]) -> str return "\n\n---\n\n".join(formatted_context) - def generate(self, query: str, context_chunks: list[dict[str, Any]]) -> dict[str, Any]: + def generate_answer(self, query: str, context_chunks: list[dict[str, Any]]) -> dict[str, Any]: """Generate an answer to the user's query using provided context chunks. Args: @@ -84,7 +84,7 @@ def generate(self, query: str, context_chunks: list[dict[str, Any]]) -> dict[str A dictionary containing the generated answer. """ - formatted_context = self.format_context_for_prompt(context_chunks) + formatted_context = self.prepare_context(context_chunks) user_prompt = f""" - You are an assistant for question-answering tasks related to OWASP. @@ -112,13 +112,9 @@ def generate(self, query: str, context_chunks: list[dict[str, Any]]) -> dict[str temperature=self.TEMPERATURE, max_tokens=self.MAX_TOKENS, ) - final_answer = response.choices[0].message.content.strip() + answer = response.choices[0].message.content.strip() except openai.OpenAIError: logger.exception("OpenAI API error") - return { - "answer": "I'm sorry, I'm currently unable to process your request.", - } + answer = "I'm sorry, I'm currently unable to process your request." - return { - "answer": final_answer, - } + return answer diff --git a/backend/apps/ai/agent/tools/RAG/rag_tool.py b/backend/apps/ai/agent/tools/rag/rag_tool.py similarity index 83% rename from backend/apps/ai/agent/tools/RAG/rag_tool.py rename to backend/apps/ai/agent/tools/rag/rag_tool.py index 0d0237fb17..000a95c876 100644 --- a/backend/apps/ai/agent/tools/RAG/rag_tool.py +++ b/backend/apps/ai/agent/tools/rag/rag_tool.py @@ -3,7 +3,7 @@ import logging from typing import Any -from apps.ai.common.constants import DEFAULT_LIMIT, DEFAULT_SIMILARITY_THRESHOLD +from apps.ai.common.constants import DEFAULT_CHUNKS_RETRIEVAL_LIMIT, DEFAULT_SIMILARITY_THRESHOLD from .generator import Generator from .retriever import Retriever @@ -11,11 +11,13 @@ logger = logging.getLogger(__name__) -class RAGTool: +class RagTool: """Main RAG tool that orchestrates the retrieval and generation process.""" def __init__( - self, embedding_model: str = "text-embedding-3-small", chat_model: str = "gpt-4o" + self, + embedding_model: str = "text-embedding-3-small", + chat_model: str = "gpt-4o", ): """Initialize the RAG tool. @@ -37,7 +39,7 @@ def __init__( def query( self, question: str, - limit: int = DEFAULT_LIMIT, + limit: int = DEFAULT_CHUNKS_RETRIEVAL_LIMIT, similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD, content_types: list[str] | None = None, ) -> dict[str, Any]: @@ -62,14 +64,10 @@ def query( content_types=content_types, ) - generation_result = self.generator.generate( + generation_result = self.generator.generate_answer( query=question, context_chunks=retrieved_chunks ) - result = { - "answer": generation_result["answer"], - } - logger.info("Successfully processed RAG query") - return result + return generation_result diff --git a/backend/apps/ai/agent/tools/RAG/retriever.py b/backend/apps/ai/agent/tools/rag/retriever.py similarity index 95% rename from backend/apps/ai/agent/tools/RAG/retriever.py rename to backend/apps/ai/agent/tools/rag/retriever.py index 01a5347dec..252d551bb0 100644 --- a/backend/apps/ai/agent/tools/RAG/retriever.py +++ b/backend/apps/ai/agent/tools/rag/retriever.py @@ -9,7 +9,10 @@ from django.db.models import Q from pgvector.django.functions import CosineDistance -from apps.ai.common.constants import DEFAULT_LIMIT, DEFAULT_SIMILARITY_THRESHOLD +from apps.ai.common.constants import ( + DEFAULT_CHUNKS_RETRIEVAL_LIMIT, + DEFAULT_SIMILARITY_THRESHOLD, +) from apps.ai.models.chunk import Chunk logger = logging.getLogger(__name__) @@ -63,8 +66,8 @@ def get_query_embedding(self, query: str) -> list[float]: def get_source_name(self, content_object) -> str: """Get the name/identifier for the content object.""" - for attr in ["name", "title", "login", "key", "summary"]: - if hasattr(content_object, attr) and getattr(content_object, attr): + for attr in ("name", "title", "login", "key", "summary"): + if getattr(content_object, attr, None): return str(getattr(content_object, attr)) return str(content_object) @@ -181,7 +184,7 @@ def get_additional_context(self, content_object, content_type: str) -> dict[str, def retrieve( self, query: str, - limit: int = DEFAULT_LIMIT, + limit: int = DEFAULT_CHUNKS_RETRIEVAL_LIMIT, similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD, content_types: list[str] | None = None, ) -> list[dict[str, Any]]: @@ -199,17 +202,16 @@ def retrieve( """ query_embedding = self.get_query_embedding(query) - final_content_types = content_types - if not final_content_types: - final_content_types = self.extract_content_types_from_query(query) + if not content_types: + content_types = self.extract_content_types_from_query(query) queryset = Chunk.objects.annotate( similarity=1 - CosineDistance("embedding", query_embedding) ).filter(similarity__gte=similarity_threshold) - if final_content_types: + if content_types: content_type_query = Q() - for name in final_content_types: + for name in content_types: lower_name = name.lower() if "." in lower_name: app_label, model = lower_name.split(".", 1) diff --git a/backend/apps/ai/common/constants.py b/backend/apps/ai/common/constants.py index 925ba434f6..207b53599c 100644 --- a/backend/apps/ai/common/constants.py +++ b/backend/apps/ai/common/constants.py @@ -1,7 +1,7 @@ """AI app constants.""" DEFAULT_LAST_REQUEST_OFFSET_SECONDS = 2 -DEFAULT_LIMIT = 5 +DEFAULT_CHUNKS_RETRIEVAL_LIMIT = 5 DEFAULT_SIMILARITY_THRESHOLD = 0.5 DELIMITER = "\n\n" MIN_REQUEST_INTERVAL_SECONDS = 1.2 diff --git a/backend/apps/ai/management/commands/ai_run_rag_tool.py b/backend/apps/ai/management/commands/ai_run_rag_tool.py index ad5c0d8eda..359f6d4cbe 100644 --- a/backend/apps/ai/management/commands/ai_run_rag_tool.py +++ b/backend/apps/ai/management/commands/ai_run_rag_tool.py @@ -2,24 +2,27 @@ from django.core.management.base import BaseCommand -from apps.ai.agent.tools.RAG.rag_tool import RAGTool -from apps.ai.common.constants import DEFAULT_LIMIT, DEFAULT_SIMILARITY_THRESHOLD +from apps.ai.agent.tools.rag.rag_tool import RagTool +from apps.ai.common.constants import ( + DEFAULT_CHUNKS_RETRIEVAL_LIMIT, + DEFAULT_SIMILARITY_THRESHOLD, +) class Command(BaseCommand): - help = "Test the RAGTool functionality with a sample query" + help = "Test the RagTool functionality with a sample query" def add_arguments(self, parser): parser.add_argument( "--query", type=str, default="What is OWASP Foundation?", - help="Query to test the RAG tool", + help="Query to test the Rag tool", ) parser.add_argument( "--limit", type=int, - default=DEFAULT_LIMIT, + default=DEFAULT_CHUNKS_RETRIEVAL_LIMIT, help="Maximum number of results to retrieve", ) parser.add_argument( @@ -49,7 +52,7 @@ def add_arguments(self, parser): def handle(self, *args, **options): try: - rag_tool = RAGTool( + rag_tool = RagTool( embedding_model=options["embedding_model"], chat_model=options["chat_model"], ) @@ -57,16 +60,11 @@ def handle(self, *args, **options): self.stderr.write(self.style.ERROR("Initialization error")) return - query = options["query"] - limit = options["limit"] - threshold = options["threshold"] - content_types = options["content_types"] - self.stdout.write("\nProcessing query...") result = rag_tool.query( - question=query, - limit=limit, - similarity_threshold=threshold, - content_types=content_types, + question=options["query"], + limit=options["limit"], + similarity_threshold=options["threshold"], + content_types=options["content_types"], ) - self.stdout.write(f"\nAnswer: {result['answer']}") + self.stdout.write(f"\nAnswer: {result}") From b2c5b59aef308597b8095f72e82c4012b6d1b978 Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Fri, 25 Jul 2025 23:07:42 +0530 Subject: [PATCH 04/32] code rabbit suggestion --- backend/apps/ai/agent/tools/rag/generator.py | 4 ++-- backend/apps/ai/agent/tools/rag/rag_tool.py | 5 ++--- backend/apps/ai/agent/tools/rag/retriever.py | 4 ++-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/backend/apps/ai/agent/tools/rag/generator.py b/backend/apps/ai/agent/tools/rag/generator.py index 07cf9b9d8f..b4cc685e87 100644 --- a/backend/apps/ai/agent/tools/rag/generator.py +++ b/backend/apps/ai/agent/tools/rag/generator.py @@ -73,7 +73,7 @@ def prepare_context(self, context_chunks: list[dict[str, Any]]) -> str: return "\n\n---\n\n".join(formatted_context) - def generate_answer(self, query: str, context_chunks: list[dict[str, Any]]) -> dict[str, Any]: + def generate_answer(self, query: str, context_chunks: list[dict[str, Any]]) -> str: """Generate an answer to the user's query using provided context chunks. Args: @@ -81,7 +81,7 @@ def generate_answer(self, query: str, context_chunks: list[dict[str, Any]]) -> d context_chunks: A list of context chunks retrieved by the retriever. Returns: - A dictionary containing the generated answer. + The generated answer as a string. """ formatted_context = self.prepare_context(context_chunks) diff --git a/backend/apps/ai/agent/tools/rag/rag_tool.py b/backend/apps/ai/agent/tools/rag/rag_tool.py index 000a95c876..879d906563 100644 --- a/backend/apps/ai/agent/tools/rag/rag_tool.py +++ b/backend/apps/ai/agent/tools/rag/rag_tool.py @@ -1,7 +1,6 @@ """A tool for orchestrating the components of RAG process.""" import logging -from typing import Any from apps.ai.common.constants import DEFAULT_CHUNKS_RETRIEVAL_LIMIT, DEFAULT_SIMILARITY_THRESHOLD @@ -22,7 +21,7 @@ def __init__( """Initialize the RAG tool. Args: - embedding_model (str, optional): The model to use for embeddings". + embedding_model (str, optional): The model to use for embeddings. chat_model (str, optional): The model to use for chat generation. Raises: @@ -42,7 +41,7 @@ def query( limit: int = DEFAULT_CHUNKS_RETRIEVAL_LIMIT, similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD, content_types: list[str] | None = None, - ) -> dict[str, Any]: + ) -> str: """Process a user query using the complete RAG pipeline. Args: diff --git a/backend/apps/ai/agent/tools/rag/retriever.py b/backend/apps/ai/agent/tools/rag/retriever.py index 252d551bb0..a4ed638ef6 100644 --- a/backend/apps/ai/agent/tools/rag/retriever.py +++ b/backend/apps/ai/agent/tools/rag/retriever.py @@ -27,7 +27,7 @@ def __init__(self, embedding_model: str = "text-embedding-3-small"): """Initialize the Retriever. Args: - embedding_model (str, optional): The OpenAI embedding model to use" + embedding_model (str, optional): The OpenAI embedding model to use. Raises: ValueError: If the OpenAI API key is not set. @@ -57,7 +57,7 @@ def get_query_embedding(self, query: str) -> list[float]: model=self.embedding_model, ) return response.data[0].embedding - except openai.error.OpenAIError: + except openai.OpenAIError: logger.exception("OpenAI API error") raise except Exception: From e12096285f92fc0c79a256efc762a33e611174a6 Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Mon, 28 Jul 2025 15:10:44 +0530 Subject: [PATCH 05/32] added context model --- backend/apps/ai/admin.py | 21 +- backend/apps/ai/agent/tools/rag/retriever.py | 14 +- backend/apps/ai/common/utils.py | 9 +- ..._unique_together_chunk_context_and_more.py | 70 +++++++ backend/apps/ai/models/chunk.py | 36 +--- backend/apps/ai/models/context.py | 33 +++ backend/tests/apps/ai/models/chunk_test.py | 104 +++------ backend/tests/apps/ai/models/context_test.py | 197 ++++++++++++++++++ 8 files changed, 377 insertions(+), 107 deletions(-) create mode 100644 backend/apps/ai/migrations/0005_context_alter_chunk_unique_together_chunk_context_and_more.py create mode 100644 backend/apps/ai/models/context.py create mode 100644 backend/tests/apps/ai/models/context_test.py diff --git a/backend/apps/ai/admin.py b/backend/apps/ai/admin.py index a7240e8115..1ce5b2e8a8 100644 --- a/backend/apps/ai/admin.py +++ b/backend/apps/ai/admin.py @@ -3,6 +3,21 @@ from django.contrib import admin from apps.ai.models.chunk import Chunk +from apps.ai.models.context import Context + + +class ContextAdmin(admin.ModelAdmin): + """Admin for Context model.""" + + list_display = ( + "id", + "generated_text", + "content_type", + "object_id", + "source", + ) + search_fields = ("generated_text", "source") + list_filter = ("content_type", "source") class ChunkAdmin(admin.ModelAdmin): @@ -11,9 +26,11 @@ class ChunkAdmin(admin.ModelAdmin): list_display = ( "id", "text", - "content_type", + "context", ) - search_fields = ("text", "object_id") + search_fields = ("text",) + list_filter = ("context__content_type",) +admin.site.register(Context, ContextAdmin) admin.site.register(Chunk, ChunkAdmin) diff --git a/backend/apps/ai/agent/tools/rag/retriever.py b/backend/apps/ai/agent/tools/rag/retriever.py index a4ed638ef6..1cd10bf5df 100644 --- a/backend/apps/ai/agent/tools/rag/retriever.py +++ b/backend/apps/ai/agent/tools/rag/retriever.py @@ -223,29 +223,29 @@ def retrieve( queryset = queryset.filter(content_type_query) chunks = ( - queryset.select_related("content_type") - .prefetch_related("content_object") + queryset.select_related("context__content_type") + .prefetch_related("context__content_object") .order_by("-similarity")[:limit] ) results = [] for chunk in chunks: - if not chunk.content_object: + if not chunk.context or not chunk.context.content_object: logger.warning("Content object is None for chunk %s. Skipping.", chunk.id) continue - source_name = self.get_source_name(chunk.content_object) + source_name = self.get_source_name(chunk.context.content_object) additional_context = self.get_additional_context( - chunk.content_object, chunk.content_type.model + chunk.context.content_object, chunk.context.content_type.model ) results.append( { "text": chunk.text, "similarity": float(chunk.similarity), - "source_type": chunk.content_type.model, + "source_type": chunk.context.content_type.model, "source_name": source_name, - "source_id": chunk.object_id, + "source_id": chunk.context.object_id, "additional_context": additional_context, } ) diff --git a/backend/apps/ai/common/utils.py b/backend/apps/ai/common/utils.py index c0824760a1..2cb4713679 100644 --- a/backend/apps/ai/common/utils.py +++ b/backend/apps/ai/common/utils.py @@ -9,6 +9,7 @@ MIN_REQUEST_INTERVAL_SECONDS, ) from apps.ai.models.chunk import Chunk +from apps.ai.models.context import Context logger: logging.Logger = logging.getLogger(__name__) @@ -43,6 +44,12 @@ def create_chunks_and_embeddings( model="text-embedding-3-small", ) + context = Context( + generated_text="\n".join(all_chunk_texts), + content_object=content_object, + ) + context.save() + return [ chunk for text, embedding in zip( @@ -53,7 +60,7 @@ def create_chunks_and_embeddings( if ( chunk := Chunk.update_data( text=text, - content_object=content_object, + context=context, embedding=embedding, save=False, ) diff --git a/backend/apps/ai/migrations/0005_context_alter_chunk_unique_together_chunk_context_and_more.py b/backend/apps/ai/migrations/0005_context_alter_chunk_unique_together_chunk_context_and_more.py new file mode 100644 index 0000000000..38cd76b854 --- /dev/null +++ b/backend/apps/ai/migrations/0005_context_alter_chunk_unique_together_chunk_context_and_more.py @@ -0,0 +1,70 @@ +# Generated by Django 5.2.4 on 2025-07-28 09:00 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("ai", "0004_alter_chunk_unique_together_chunk_content_type_and_more"), + ("contenttypes", "0002_remove_content_type_name"), + ] + + operations = [ + migrations.CreateModel( + name="Context", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, primary_key=True, serialize=False, verbose_name="ID" + ), + ), + ("nest_created_at", models.DateTimeField(auto_now_add=True)), + ("nest_updated_at", models.DateTimeField(auto_now=True)), + ("generated_text", models.TextField(verbose_name="Generated Text")), + ("object_id", models.PositiveIntegerField(default=0)), + ("source", models.CharField(blank=True, default="", max_length=100)), + ( + "content_type", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + to="contenttypes.contenttype", + ), + ), + ], + options={ + "verbose_name": "Context", + "db_table": "ai_contexts", + }, + ), + migrations.AlterUniqueTogether( + name="chunk", + unique_together=set(), + ), + migrations.AddField( + model_name="chunk", + name="context", + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="chunks", + to="ai.context", + ), + ), + migrations.AlterUniqueTogether( + name="chunk", + unique_together={("context", "text")}, + ), + migrations.RemoveField( + model_name="chunk", + name="content_type", + ), + migrations.RemoveField( + model_name="chunk", + name="object_id", + ), + ] diff --git a/backend/apps/ai/models/chunk.py b/backend/apps/ai/models/chunk.py index 8362948ffe..ceb651c321 100644 --- a/backend/apps/ai/models/chunk.py +++ b/backend/apps/ai/models/chunk.py @@ -1,11 +1,10 @@ """AI app chunk model.""" -from django.contrib.contenttypes.fields import GenericForeignKey -from django.contrib.contenttypes.models import ContentType from django.db import models from langchain.text_splitter import RecursiveCharacterTextSplitter from pgvector.django import VectorField +from apps.ai.models.context import Context from apps.common.models import BulkSaveModel, TimestampedModel from apps.common.utils import truncate @@ -16,25 +15,18 @@ class Chunk(TimestampedModel): class Meta: db_table = "ai_chunks" verbose_name = "Chunk" - unique_together = ("content_type", "object_id", "text") + unique_together = ("context", "text") - content_object = GenericForeignKey("content_type", "object_id") - content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE, blank=True, null=True) + context = models.ForeignKey( + Context, on_delete=models.CASCADE, related_name="chunks", null=True, blank=True + ) embedding = VectorField(verbose_name="Embedding", dimensions=1536) - object_id = models.PositiveIntegerField(default=0) text = models.TextField(verbose_name="Text") def __str__(self): """Human readable representation.""" - content_name = ( - getattr(self.content_object, "name", None) - or getattr(self.content_object, "key", None) - or str(self.content_object) - ) - return ( - f"Chunk {self.id} for {self.content_type.model} {content_name}: " - f"{truncate(self.text, 50)}" - ) + context_str = str(self.context) if self.context else "No Context" + return f"Chunk {self.id} for {context_str}: {truncate(self.text, 50)}" @staticmethod def bulk_save(chunks, fields=None): @@ -54,7 +46,7 @@ def split_text(text: str) -> list[str]: @staticmethod def update_data( text: str, - content_object, + context: Context, embedding, *, save: bool = True, @@ -63,7 +55,7 @@ def update_data( Args: text (str): The text content of the chunk. - content_object: The object this chunk belongs to (Message, Chapter, etc.). + context (Context): The context this chunk belongs to. embedding (list): The embedding vector for the chunk. save (bool): Whether to save the chunk to the database. @@ -71,16 +63,10 @@ def update_data( Chunk: The updated chunk instance or None if it already exists. """ - content_type = ContentType.objects.get_for_model(content_object) - - if Chunk.objects.filter( - content_type=content_type, object_id=content_object.id, text=text - ).exists(): + if Chunk.objects.filter(context=context, text=text).exists(): return None - chunk = Chunk( - content_type=content_type, object_id=content_object.id, text=text, embedding=embedding - ) + chunk = Chunk(context=context, text=text, embedding=embedding) if save: chunk.save() diff --git a/backend/apps/ai/models/context.py b/backend/apps/ai/models/context.py new file mode 100644 index 0000000000..8de1579ca3 --- /dev/null +++ b/backend/apps/ai/models/context.py @@ -0,0 +1,33 @@ +"""AI app context model.""" + +from django.contrib.contenttypes.fields import GenericForeignKey +from django.contrib.contenttypes.models import ContentType +from django.db import models + +from apps.common.models import TimestampedModel + + +class Context(TimestampedModel): + """Context model for storing generated text and optional relation to OWASP entities.""" + + generated_text = models.TextField(verbose_name="Generated Text") + content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE, blank=True, null=True) + object_id = models.PositiveIntegerField(default=0) + content_object = GenericForeignKey("content_type", "object_id") + source = models.CharField(max_length=100, blank=True, default="") + + class Meta: + db_table = "ai_contexts" + verbose_name = "Context" + + def __str__(self): + """Human readable representation.""" + entity = ( + getattr(self.content_object, "name", None) + or getattr(self.content_object, "key", None) + or str(self.content_object) + ) + return ( + f"{self.content_type.model if self.content_type else 'None'} {entity}: " + f"{self.generated_text[:50]}" + ) diff --git a/backend/tests/apps/ai/models/chunk_test.py b/backend/tests/apps/ai/models/chunk_test.py index dca223f800..9023377f7c 100644 --- a/backend/tests/apps/ai/models/chunk_test.py +++ b/backend/tests/apps/ai/models/chunk_test.py @@ -1,10 +1,7 @@ from unittest.mock import Mock, patch -from django.contrib.contenttypes.models import ContentType -from django.db import models - from apps.ai.models.chunk import Chunk -from apps.slack.models.message import Message +from apps.ai.models.context import Context def create_model_mock(model_class): @@ -17,23 +14,16 @@ def create_model_mock(model_class): class TestChunkModel: def test_str_method(self): - mock_message = create_model_mock(Message) - mock_message.name = "Test Message" - - mock_content_type = Mock(spec=ContentType) - mock_content_type.model = "message" - - with ( - patch.object(Chunk, "content_type", mock_content_type), - patch.object(Chunk, "content_object", mock_message), - ): - chunk = Chunk() - chunk.id = 1 - chunk.text = "This is a test chunk with some content that should be displayed" - - result = str(chunk) - assert "Chunk 1 for message Test Message:" in result - assert "This is a test chunk with some content that" in result + mock_context = Mock(spec=Context) + mock_context.__str__ = Mock(return_value="Context 1 for message Test Message: ...") + mock_context._state = Mock() + chunk = Chunk() + chunk.id = 1 + chunk.text = "This is a test chunk with some content that should be displayed" + chunk.context = mock_context + result = str(chunk) + assert "Chunk 1 for Context 1 for message Test Message:" in result + assert "This is a test chunk with some content that" in result def test_bulk_save_with_chunks(self): mock_chunks = [Mock(), Mock(), Mock()] @@ -64,32 +54,21 @@ def test_split_text(self): def test_update_data_new_chunk(self, mock_init, mock_save, mocker): mock_init.return_value = None - mock_message = create_model_mock(Message) + mock_context = Mock(spec=Context) + mock_context._state = Mock() text = "Test chunk content" embedding = [0.1, 0.2, 0.3] - mock_content_type = Mock(spec=ContentType) - mock_get_for_model = mocker.patch( - "django.contrib.contenttypes.models.ContentType.objects.get_for_model", - return_value=mock_content_type, - ) - mock_filter = mocker.patch( "apps.ai.models.chunk.Chunk.objects.filter", return_value=Mock(exists=Mock(return_value=False)), ) - result = Chunk.update_data( - text=text, content_object=mock_message, embedding=embedding, save=True - ) + result = Chunk.update_data(text=text, context=mock_context, embedding=embedding, save=True) - mock_get_for_model.assert_called_once_with(mock_message) - mock_filter.assert_called_once_with( - content_type=mock_content_type, object_id=mock_message.id, text=text - ) + mock_filter.assert_called_once_with(context=mock_context, text=text) mock_init.assert_called_once_with( - content_type=mock_content_type, - object_id=mock_message.id, + context=mock_context, text=text, embedding=embedding, ) @@ -99,29 +78,19 @@ def test_update_data_new_chunk(self, mock_init, mock_save, mocker): assert isinstance(result, Chunk) def test_update_data_existing_chunk(self, mocker): - mock_message = create_model_mock(Message) + mock_context = Mock(spec=Context) + mock_context._state = Mock() text = "Existing chunk content" embedding = [0.1, 0.2, 0.3] - mock_content_type = Mock(spec=ContentType) - mock_get_for_model = mocker.patch( - "django.contrib.contenttypes.models.ContentType.objects.get_for_model", - return_value=mock_content_type, - ) - mock_filter = mocker.patch( "apps.ai.models.chunk.Chunk.objects.filter", return_value=Mock(exists=Mock(return_value=True)), ) - result = Chunk.update_data( - text=text, content_object=mock_message, embedding=embedding, save=True - ) + result = Chunk.update_data(text=text, context=mock_context, embedding=embedding, save=True) - mock_get_for_model.assert_called_once_with(mock_message) - mock_filter.assert_called_once_with( - content_type=mock_content_type, object_id=mock_message.id, text=text - ) + mock_filter.assert_called_once_with(context=mock_context, text=text) assert result is None @patch("apps.ai.models.chunk.Chunk.save") @@ -129,32 +98,23 @@ def test_update_data_existing_chunk(self, mocker): def test_update_data_no_save(self, mock_init, mock_save, mocker): mock_init.return_value = None - mock_message = create_model_mock(Message) + mock_context = Mock(spec=Context) + mock_context._state = Mock() text = "Test chunk content" embedding = [0.1, 0.2, 0.3] - mock_content_type = Mock(spec=ContentType) - mock_get_for_model = mocker.patch( - "django.contrib.contenttypes.models.ContentType.objects.get_for_model", - return_value=mock_content_type, - ) - mock_filter = mocker.patch( "apps.ai.models.chunk.Chunk.objects.filter", return_value=Mock(exists=Mock(return_value=False)), ) result = Chunk.update_data( - text=text, content_object=mock_message, embedding=embedding, save=False + text=text, context=mock_context, embedding=embedding, save=False ) - mock_get_for_model.assert_called_once_with(mock_message) - mock_filter.assert_called_once_with( - content_type=mock_content_type, object_id=mock_message.id, text=text - ) + mock_filter.assert_called_once_with(context=mock_context, text=text) mock_init.assert_called_once_with( - content_type=mock_content_type, - object_id=mock_message.id, + context=mock_context, text=text, embedding=embedding, ) @@ -166,12 +126,12 @@ def test_update_data_no_save(self, mock_init, mock_save, mocker): def test_meta_class_attributes(self): assert Chunk._meta.db_table == "ai_chunks" assert Chunk._meta.verbose_name == "Chunk" - assert ("content_type", "object_id", "text") in Chunk._meta.unique_together + assert ("context", "text") in Chunk._meta.unique_together - def test_generic_foreign_key_relationship(self): - content_type_field = Chunk._meta.get_field("content_type") - object_id_field = Chunk._meta.get_field("object_id") + def test_context_relationship(self): + context_field = Chunk._meta.get_field("context") + from apps.ai.models.context import Context - assert isinstance(content_type_field, models.ForeignKey) - assert content_type_field.remote_field.model == ContentType - assert isinstance(object_id_field, models.PositiveIntegerField) + assert context_field.related_model == Context + assert context_field.null is True + assert context_field.blank is True diff --git a/backend/tests/apps/ai/models/context_test.py b/backend/tests/apps/ai/models/context_test.py new file mode 100644 index 0000000000..3b244c1161 --- /dev/null +++ b/backend/tests/apps/ai/models/context_test.py @@ -0,0 +1,197 @@ +"""Unit tests for AI app context model.""" + +from unittest.mock import Mock, patch + +import pytest + +from apps.ai.models.context import Context + + +def create_model_mock(model_class): + mock = Mock(spec=model_class) + mock._state = Mock() + mock.pk = 1 + mock.id = 1 + return mock + + +class TestContextModel: + def test_str_method_without_content_type(self): + context = Context() + context.id = 1 + context.generated_text = "Sample text without content type" + context.content_type = None + context.content_object = None + + result = str(context) + + assert result == "None None: Sample text without content type" + + def test_str_method_with_text_truncation(self): + long_text = "A" * 100 + + context = Context() + context.id = 1 + context.generated_text = long_text + context.content_type = None + context.content_object = None + + result = str(context) + + assert result == f"None None: {long_text[:50]}" + assert len(result.split(": ", 1)[1]) == 50 + + def test_str_method_with_exactly_50_chars(self): + text_50_chars = "A" * 50 + + context = Context() + context.id = 1 + context.generated_text = text_50_chars + context.content_type = None + context.content_object = None + + result = str(context) + + assert result == f"None None: {text_50_chars}" + assert len(result.split(": ", 1)[1]) == 50 + + def test_str_method_with_empty_text(self): + context = Context() + context.id = 1 + context.generated_text = "" + context.content_type = None + context.content_object = None + + result = str(context) + + assert result == "None None: " + + def test_meta_class_attributes(self): + assert Context._meta.db_table == "ai_contexts" + assert Context._meta.verbose_name == "Context" + + def test_generated_text_field_properties(self): + field = Context._meta.get_field("generated_text") + assert field.verbose_name == "Generated Text" + assert field.__class__.__name__ == "TextField" + + def test_content_type_field_properties(self): + field = Context._meta.get_field("content_type") + assert field.null is True + assert field.blank is True + assert hasattr(field, "remote_field") + assert field.remote_field.on_delete.__name__ == "CASCADE" + + def test_object_id_field_properties(self): + field = Context._meta.get_field("object_id") + assert field.default == 0 + assert field.__class__.__name__ == "PositiveIntegerField" + + def test_source_field_properties(self): + field = Context._meta.get_field("source") + assert field.max_length == 100 + assert field.blank is True + assert field.default == "" + + def test_content_object_generic_foreign_key(self): + field = Context._meta.get_field("content_object") + assert field.__class__.__name__ == "GenericForeignKey" + assert field.ct_field == "content_type" + assert field.fk_field == "object_id" + + @patch("apps.ai.models.context.Context.save") + @patch("apps.ai.models.context.Context.__init__") + def test_context_creation_with_save(self, mock_init, mock_save): + mock_init.return_value = None + + generated_text = "Test generated text" + source = "test_source" + + context = Context(generated_text=generated_text, source=source) + context.save() + + mock_save.assert_called_once() + + def test_context_inheritance_from_timestamped_model(self): + from apps.common.models import TimestampedModel + + assert issubclass(Context, TimestampedModel) + + @patch("apps.ai.models.context.Context.objects.create") + def test_context_manager_create(self, mock_create): + mock_context = create_model_mock(Context) + mock_create.return_value = mock_context + + generated_text = "Test text" + source = "test_source" + + result = Context.objects.create(generated_text=generated_text, source=source) + + mock_create.assert_called_once_with(generated_text=generated_text, source=source) + assert result == mock_context + + @patch("apps.ai.models.context.Context.objects.filter") + def test_context_manager_filter(self, mock_filter): + mock_queryset = Mock() + mock_filter.return_value = mock_queryset + + result = Context.objects.filter(source="test_source") + + mock_filter.assert_called_once_with(source="test_source") + assert result == mock_queryset + + @patch("apps.ai.models.context.Context.objects.get") + def test_context_manager_get(self, mock_get): + mock_context = create_model_mock(Context) + mock_get.return_value = mock_context + + result = Context.objects.get(id=1) + + mock_get.assert_called_once_with(id=1) + assert result == mock_context + + @patch("apps.ai.models.context.Context.full_clean") + def test_context_validation(self, mock_full_clean): + context = Context() + context.generated_text = "Valid text" + context.source = "A" * 100 + + context.full_clean() + + mock_full_clean.assert_called_once() + + @patch("apps.ai.models.context.Context.full_clean") + def test_context_validation_source_too_long(self, mock_full_clean): + from django.core.exceptions import ValidationError + + mock_full_clean.side_effect = ValidationError("Source too long") + + context = Context() + context.generated_text = "Valid text" + context.source = "A" * 101 + + with pytest.raises(ValidationError) as exc_info: + context.full_clean() + assert "Source too long" in str(exc_info.value) + + def test_context_default_values(self): + context = Context() + + assert context.object_id == 0 + assert context.source == "" + assert context.content_type is None + assert context.content_object is None + + @patch("apps.ai.models.context.Context.refresh_from_db") + def test_context_refresh_from_db(self, mock_refresh): + context = Context() + context.refresh_from_db() + + mock_refresh.assert_called_once() + + @patch("apps.ai.models.context.Context.delete") + def test_context_delete(self, mock_delete): + context = Context() + context.delete() + + mock_delete.assert_called_once() From e876a0cd9d342ff4415da622ce2b26ded5b7d56a Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Wed, 30 Jul 2025 00:12:39 +0530 Subject: [PATCH 06/32] retrieving data from context model --- backend/apps/ai/agent/tools/rag/retriever.py | 29 ++++++++------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/backend/apps/ai/agent/tools/rag/retriever.py b/backend/apps/ai/agent/tools/rag/retriever.py index a4ed638ef6..26d0c6fc14 100644 --- a/backend/apps/ai/agent/tools/rag/retriever.py +++ b/backend/apps/ai/agent/tools/rag/retriever.py @@ -21,7 +21,7 @@ class Retriever: """A class for retrieving relevant text chunks for a RAG.""" - SUPPORTED_CONTENT_TYPES = ["event", "project", "chapter", "committee", "message"] + SUPPORTED_CONTENT_TYPES = ("event", "project", "chapter", "committee", "message") def __init__(self, embedding_model: str = "text-embedding-3-small"): """Initialize the Retriever. @@ -36,7 +36,6 @@ def __init__(self, embedding_model: str = "text-embedding-3-small"): if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): error_msg = "DJANGO_OPEN_AI_SECRET_KEY environment variable not set" raise ValueError(error_msg) - self.openai_client = openai.OpenAI(api_key=openai_api_key) self.embedding_model = embedding_model logger.info("Retriever initialized with embedding model: %s", self.embedding_model) @@ -69,7 +68,6 @@ def get_source_name(self, content_object) -> str: for attr in ("name", "title", "login", "key", "summary"): if getattr(content_object, attr, None): return str(getattr(content_object, attr)) - return str(content_object) def get_additional_context(self, content_object, content_type: str) -> dict[str, Any]: @@ -85,7 +83,6 @@ def get_additional_context(self, content_object, content_type: str) -> dict[str, """ context = {} clean_content_type = content_type.split(".")[-1] if "." in content_type else content_type - if clean_content_type == "chapter": context.update( { @@ -178,7 +175,6 @@ def get_additional_context(self, content_object, content_type: str) -> dict[str, ), } ) - return {k: v for k, v in context.items() if v is not None} def retrieve( @@ -201,14 +197,11 @@ def retrieve( """ query_embedding = self.get_query_embedding(query) - if not content_types: content_types = self.extract_content_types_from_query(query) - queryset = Chunk.objects.annotate( similarity=1 - CosineDistance("embedding", query_embedding) ).filter(similarity__gte=similarity_threshold) - if content_types: content_type_query = Q() for name in content_types: @@ -216,36 +209,37 @@ def retrieve( if "." in lower_name: app_label, model = lower_name.split(".", 1) content_type_query |= Q( - content_type__app_label=app_label, content_type__model=model + context__content_type__app_label=app_label, + context__content_type__model=model, ) else: - content_type_query |= Q(content_type__model=lower_name) + content_type_query |= Q(context__content_type__model=lower_name) queryset = queryset.filter(content_type_query) chunks = ( - queryset.select_related("content_type") - .prefetch_related("content_object") + queryset.select_related("context__content_type") + .prefetch_related("context__content_object") .order_by("-similarity")[:limit] ) results = [] for chunk in chunks: - if not chunk.content_object: + if not chunk.context or not chunk.context.content_object: logger.warning("Content object is None for chunk %s. Skipping.", chunk.id) continue - source_name = self.get_source_name(chunk.content_object) + source_name = self.get_source_name(chunk.context.content_object) additional_context = self.get_additional_context( - chunk.content_object, chunk.content_type.model + chunk.context.content_object, chunk.context.content_type.model ) results.append( { "text": chunk.text, "similarity": float(chunk.similarity), - "source_type": chunk.content_type.model, + "source_type": chunk.context.content_type.model, "source_name": source_name, - "source_id": chunk.object_id, + "source_id": chunk.context.object_id, "additional_context": additional_context, } ) @@ -262,7 +256,6 @@ def extract_content_types_from_query(self, query: str) -> list[str]: A list of detected content type names. """ - detected_types = [] query_words = set(re.findall(r"\b\w+\b", query.lower())) detected_types = [ From 981277a87c390d287b84c8902dae10aebb603b0c Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Wed, 30 Jul 2025 00:52:10 +0530 Subject: [PATCH 07/32] removed try except --- backend/apps/ai/agent/tools/rag/rag_tool.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/backend/apps/ai/agent/tools/rag/rag_tool.py b/backend/apps/ai/agent/tools/rag/rag_tool.py index 6072de8cbb..8375b4a328 100644 --- a/backend/apps/ai/agent/tools/rag/rag_tool.py +++ b/backend/apps/ai/agent/tools/rag/rag_tool.py @@ -28,12 +28,8 @@ def __init__( ValueError: If the OpenAI API key is not set. """ - try: - self.retriever = Retriever(embedding_model=embedding_model) - self.generator = Generator(chat_model=chat_model) - except Exception: - logger.exception("Failed to initialize RAG tool") - raise + self.retriever = Retriever(embedding_model=embedding_model) + self.generator = Generator(chat_model=chat_model) def query( self, From 8b46f08193ec7227ece9a4806fb9f4779501b716 Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Wed, 30 Jul 2025 23:31:08 +0530 Subject: [PATCH 08/32] Suggestions implemented --- backend/apps/ai/admin.py | 4 +-- backend/apps/ai/common/utils.py | 2 +- ...generated_text_context_content_and_more.py | 28 +++++++++++++++++++ backend/apps/ai/models/chunk.py | 2 +- backend/apps/ai/models/context.py | 5 ++-- backend/tests/apps/ai/models/chunk_test.py | 2 -- backend/tests/apps/ai/models/context_test.py | 26 ++++++++--------- 7 files changed, 48 insertions(+), 21 deletions(-) create mode 100644 backend/apps/ai/migrations/0006_rename_generated_text_context_content_and_more.py diff --git a/backend/apps/ai/admin.py b/backend/apps/ai/admin.py index 1ce5b2e8a8..cd804992cd 100644 --- a/backend/apps/ai/admin.py +++ b/backend/apps/ai/admin.py @@ -11,12 +11,12 @@ class ContextAdmin(admin.ModelAdmin): list_display = ( "id", - "generated_text", + "content", "content_type", "object_id", "source", ) - search_fields = ("generated_text", "source") + search_fields = ("content", "source") list_filter = ("content_type", "source") diff --git a/backend/apps/ai/common/utils.py b/backend/apps/ai/common/utils.py index 2cb4713679..906512e51b 100644 --- a/backend/apps/ai/common/utils.py +++ b/backend/apps/ai/common/utils.py @@ -45,7 +45,7 @@ def create_chunks_and_embeddings( ) context = Context( - generated_text="\n".join(all_chunk_texts), + content="\n".join(all_chunk_texts), content_object=content_object, ) context.save() diff --git a/backend/apps/ai/migrations/0006_rename_generated_text_context_content_and_more.py b/backend/apps/ai/migrations/0006_rename_generated_text_context_content_and_more.py new file mode 100644 index 0000000000..370700992d --- /dev/null +++ b/backend/apps/ai/migrations/0006_rename_generated_text_context_content_and_more.py @@ -0,0 +1,28 @@ +# Generated by Django 5.2.4 on 2025-07-30 12:49 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("ai", "0005_context_alter_chunk_unique_together_chunk_context_and_more"), + ] + + operations = [ + migrations.RenameField( + model_name="context", + old_name="generated_text", + new_name="content", + ), + migrations.AlterField( + model_name="chunk", + name="context", + field=models.ForeignKey( + default="", + on_delete=django.db.models.deletion.CASCADE, + related_name="chunks", + to="ai.context", + ), + ), + ] diff --git a/backend/apps/ai/models/chunk.py b/backend/apps/ai/models/chunk.py index ceb651c321..fd3c79d1ac 100644 --- a/backend/apps/ai/models/chunk.py +++ b/backend/apps/ai/models/chunk.py @@ -18,7 +18,7 @@ class Meta: unique_together = ("context", "text") context = models.ForeignKey( - Context, on_delete=models.CASCADE, related_name="chunks", null=True, blank=True + Context, on_delete=models.CASCADE, related_name="chunks", default="" ) embedding = VectorField(verbose_name="Embedding", dimensions=1536) text = models.TextField(verbose_name="Text") diff --git a/backend/apps/ai/models/context.py b/backend/apps/ai/models/context.py index 8de1579ca3..209a82f59c 100644 --- a/backend/apps/ai/models/context.py +++ b/backend/apps/ai/models/context.py @@ -10,7 +10,7 @@ class Context(TimestampedModel): """Context model for storing generated text and optional relation to OWASP entities.""" - generated_text = models.TextField(verbose_name="Generated Text") + content = models.TextField(verbose_name="Generated Text") content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE, blank=True, null=True) object_id = models.PositiveIntegerField(default=0) content_object = GenericForeignKey("content_type", "object_id") @@ -19,6 +19,7 @@ class Context(TimestampedModel): class Meta: db_table = "ai_contexts" verbose_name = "Context" + unique_together = ("content_type", "object_id") def __str__(self): """Human readable representation.""" @@ -29,5 +30,5 @@ def __str__(self): ) return ( f"{self.content_type.model if self.content_type else 'None'} {entity}: " - f"{self.generated_text[:50]}" + f"{self.content[:50]}" ) diff --git a/backend/tests/apps/ai/models/chunk_test.py b/backend/tests/apps/ai/models/chunk_test.py index 9023377f7c..f5ec69ec32 100644 --- a/backend/tests/apps/ai/models/chunk_test.py +++ b/backend/tests/apps/ai/models/chunk_test.py @@ -133,5 +133,3 @@ def test_context_relationship(self): from apps.ai.models.context import Context assert context_field.related_model == Context - assert context_field.null is True - assert context_field.blank is True diff --git a/backend/tests/apps/ai/models/context_test.py b/backend/tests/apps/ai/models/context_test.py index 3b244c1161..4bed7c7de5 100644 --- a/backend/tests/apps/ai/models/context_test.py +++ b/backend/tests/apps/ai/models/context_test.py @@ -19,7 +19,7 @@ class TestContextModel: def test_str_method_without_content_type(self): context = Context() context.id = 1 - context.generated_text = "Sample text without content type" + context.content = "Sample text without content type" context.content_type = None context.content_object = None @@ -32,7 +32,7 @@ def test_str_method_with_text_truncation(self): context = Context() context.id = 1 - context.generated_text = long_text + context.content = long_text context.content_type = None context.content_object = None @@ -46,7 +46,7 @@ def test_str_method_with_exactly_50_chars(self): context = Context() context.id = 1 - context.generated_text = text_50_chars + context.content = text_50_chars context.content_type = None context.content_object = None @@ -58,7 +58,7 @@ def test_str_method_with_exactly_50_chars(self): def test_str_method_with_empty_text(self): context = Context() context.id = 1 - context.generated_text = "" + context.content = "" context.content_type = None context.content_object = None @@ -70,8 +70,8 @@ def test_meta_class_attributes(self): assert Context._meta.db_table == "ai_contexts" assert Context._meta.verbose_name == "Context" - def test_generated_text_field_properties(self): - field = Context._meta.get_field("generated_text") + def test_content_field_properties(self): + field = Context._meta.get_field("content") assert field.verbose_name == "Generated Text" assert field.__class__.__name__ == "TextField" @@ -104,10 +104,10 @@ def test_content_object_generic_foreign_key(self): def test_context_creation_with_save(self, mock_init, mock_save): mock_init.return_value = None - generated_text = "Test generated text" + content = "Test generated text" source = "test_source" - context = Context(generated_text=generated_text, source=source) + context = Context(content=content, source=source) context.save() mock_save.assert_called_once() @@ -122,12 +122,12 @@ def test_context_manager_create(self, mock_create): mock_context = create_model_mock(Context) mock_create.return_value = mock_context - generated_text = "Test text" + content = "Test text" source = "test_source" - result = Context.objects.create(generated_text=generated_text, source=source) + result = Context.objects.create(content=content, source=source) - mock_create.assert_called_once_with(generated_text=generated_text, source=source) + mock_create.assert_called_once_with(content=content, source=source) assert result == mock_context @patch("apps.ai.models.context.Context.objects.filter") @@ -153,7 +153,7 @@ def test_context_manager_get(self, mock_get): @patch("apps.ai.models.context.Context.full_clean") def test_context_validation(self, mock_full_clean): context = Context() - context.generated_text = "Valid text" + context.content = "Valid text" context.source = "A" * 100 context.full_clean() @@ -167,7 +167,7 @@ def test_context_validation_source_too_long(self, mock_full_clean): mock_full_clean.side_effect = ValidationError("Source too long") context = Context() - context.generated_text = "Valid text" + context.content = "Valid text" context.source = "A" * 101 with pytest.raises(ValidationError) as exc_info: From 16fabcf0f2526e7f6fbdf5b68d9f02249a7e6778 Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Wed, 30 Jul 2025 23:50:12 +0530 Subject: [PATCH 09/32] code rabbit suggestion --- ...k_context_alter_context_unique_together.py | 28 +++++++++++++++++++ backend/apps/ai/models/chunk.py | 2 +- 2 files changed, 29 insertions(+), 1 deletion(-) create mode 100644 backend/apps/ai/migrations/0007_alter_chunk_context_alter_context_unique_together.py diff --git a/backend/apps/ai/migrations/0007_alter_chunk_context_alter_context_unique_together.py b/backend/apps/ai/migrations/0007_alter_chunk_context_alter_context_unique_together.py new file mode 100644 index 0000000000..ff2c206c93 --- /dev/null +++ b/backend/apps/ai/migrations/0007_alter_chunk_context_alter_context_unique_together.py @@ -0,0 +1,28 @@ +# Generated by Django 5.2.4 on 2025-07-30 18:15 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("ai", "0006_rename_generated_text_context_content_and_more"), + ("contenttypes", "0002_remove_content_type_name"), + ] + + operations = [ + migrations.AlterField( + model_name="chunk", + name="context", + field=models.ForeignKey( + default=None, + on_delete=django.db.models.deletion.CASCADE, + related_name="chunks", + to="ai.context", + ), + ), + migrations.AlterUniqueTogether( + name="context", + unique_together={("content_type", "object_id")}, + ), + ] diff --git a/backend/apps/ai/models/chunk.py b/backend/apps/ai/models/chunk.py index fd3c79d1ac..9bd7ca7fd6 100644 --- a/backend/apps/ai/models/chunk.py +++ b/backend/apps/ai/models/chunk.py @@ -18,7 +18,7 @@ class Meta: unique_together = ("context", "text") context = models.ForeignKey( - Context, on_delete=models.CASCADE, related_name="chunks", default="" + Context, on_delete=models.CASCADE, related_name="chunks", default=None ) embedding = VectorField(verbose_name="Embedding", dimensions=1536) text = models.TextField(verbose_name="Text") From 77203b8dbe70c65eb6f1b82d4afe844ce45e17cb Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Thu, 31 Jul 2025 00:21:57 +0530 Subject: [PATCH 10/32] removed deafult --- ...07_alter_chunk_context_alter_context_unique_together.py | 7 ++----- backend/apps/ai/models/chunk.py | 4 +--- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/backend/apps/ai/migrations/0007_alter_chunk_context_alter_context_unique_together.py b/backend/apps/ai/migrations/0007_alter_chunk_context_alter_context_unique_together.py index ff2c206c93..da449203b4 100644 --- a/backend/apps/ai/migrations/0007_alter_chunk_context_alter_context_unique_together.py +++ b/backend/apps/ai/migrations/0007_alter_chunk_context_alter_context_unique_together.py @@ -1,4 +1,4 @@ -# Generated by Django 5.2.4 on 2025-07-30 18:15 +# Generated by Django 5.2.4 on 2025-07-30 18:47 import django.db.models.deletion from django.db import migrations, models @@ -15,10 +15,7 @@ class Migration(migrations.Migration): model_name="chunk", name="context", field=models.ForeignKey( - default=None, - on_delete=django.db.models.deletion.CASCADE, - related_name="chunks", - to="ai.context", + on_delete=django.db.models.deletion.CASCADE, related_name="chunks", to="ai.context" ), ), migrations.AlterUniqueTogether( diff --git a/backend/apps/ai/models/chunk.py b/backend/apps/ai/models/chunk.py index 9bd7ca7fd6..e3144b675a 100644 --- a/backend/apps/ai/models/chunk.py +++ b/backend/apps/ai/models/chunk.py @@ -17,9 +17,7 @@ class Meta: verbose_name = "Chunk" unique_together = ("context", "text") - context = models.ForeignKey( - Context, on_delete=models.CASCADE, related_name="chunks", default=None - ) + context = models.ForeignKey(Context, on_delete=models.CASCADE, related_name="chunks") embedding = VectorField(verbose_name="Embedding", dimensions=1536) text = models.TextField(verbose_name="Text") From 9e03b53cddb4915bd7ff788b8776cfda34b89c04 Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Thu, 31 Jul 2025 00:43:03 +0530 Subject: [PATCH 11/32] updated tests --- backend/tests/apps/ai/common/utils_test.py | 28 +++++++++++++++------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/backend/tests/apps/ai/common/utils_test.py b/backend/tests/apps/ai/common/utils_test.py index 6cc5057e79..8186ce78da 100644 --- a/backend/tests/apps/ai/common/utils_test.py +++ b/backend/tests/apps/ai/common/utils_test.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch from apps.ai.common.utils import create_chunks_and_embeddings @@ -9,9 +9,12 @@ def __init__(self, embedding): class TestUtils: + @patch("apps.ai.common.utils.Context") @patch("apps.ai.common.utils.Chunk.update_data") @patch("apps.ai.common.utils.time.sleep") - def test_create_chunks_and_embeddings_success(self, mock_sleep, mock_update_data): + def test_create_chunks_and_embeddings_success( + self, mock_sleep, mock_update_data, mock_context + ): """Tests the successful path where the OpenAI API returns embeddings.""" mock_openai_client = MagicMock() mock_api_response = MagicMock() @@ -37,12 +40,21 @@ def test_create_chunks_and_embeddings_success(self, mock_sleep, mock_update_data model="text-embedding-3-small", ) - assert mock_update_data.call_count == 2 - mock_update_data.assert_any_call( - content_object=mock_content_object, - embedding=[0.1, 0.2], - save=False, - text="first chunk", + mock_update_data.assert_has_calls( + [ + call( + text="first chunk", + context=mock_context(), + embedding=[0.1, 0.2], + save=False, + ), + call( + text="second chunk", + context=mock_context(), + embedding=[0.3, 0.4], + save=False, + ), + ] ) assert result == ["mock_chunk_instance", "mock_chunk_instance"] From 41f8126c3a346a32fb8b4320667470db523882bf Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Tue, 5 Aug 2025 22:05:05 +0530 Subject: [PATCH 12/32] de coupled context and chunks --- backend/apps/ai/agent/tools/rag/retriever.py | 6 +- backend/apps/ai/common/utils.py | 75 +++++---- .../commands/ai_create_chapter_chunks.py | 133 ++++++++++++---- .../commands/ai_create_committee_chunks.py | 137 +++++++++++++---- .../commands/ai_create_event_chunks.py | 129 ++++++++++++---- .../commands/ai_create_project_chunks.py | 125 +++++++++++---- .../ai_create_slack_message_chunks.py | 143 ++++++++++++++---- backend/tests/apps/ai/common/utils_test.py | 61 ++++++-- 8 files changed, 605 insertions(+), 204 deletions(-) diff --git a/backend/apps/ai/agent/tools/rag/retriever.py b/backend/apps/ai/agent/tools/rag/retriever.py index 26d0c6fc14..501f2f06f4 100644 --- a/backend/apps/ai/agent/tools/rag/retriever.py +++ b/backend/apps/ai/agent/tools/rag/retriever.py @@ -216,11 +216,7 @@ def retrieve( content_type_query |= Q(context__content_type__model=lower_name) queryset = queryset.filter(content_type_query) - chunks = ( - queryset.select_related("context__content_type") - .prefetch_related("context__content_object") - .order_by("-similarity")[:limit] - ) + chunks = queryset.select_related("context__content_type").order_by("-similarity")[:limit] results = [] for chunk in chunks: diff --git a/backend/apps/ai/common/utils.py b/backend/apps/ai/common/utils.py index 906512e51b..ce2a7054c5 100644 --- a/backend/apps/ai/common/utils.py +++ b/backend/apps/ai/common/utils.py @@ -4,6 +4,8 @@ import time from datetime import UTC, datetime, timedelta +import openai + from apps.ai.common.constants import ( DEFAULT_LAST_REQUEST_OFFSET_SECONDS, MIN_REQUEST_INTERVAL_SECONDS, @@ -14,20 +16,42 @@ logger: logging.Logger = logging.getLogger(__name__) +def create_context(content: str, content_object=None, source: str = "") -> Context: + """Create and save a Context instance independently. + + Args: + content (str): The context content + content_object: Optional related object + source (str): Source identifier + + Returns: + Context: Created Context instance + + """ + context = Context(content=content, content_object=content_object, source=source) + context.save() + return context + + def create_chunks_and_embeddings( - all_chunk_texts: list[str], - content_object, + chunk_texts: list[str], + context: Context, openai_client, + model: str = "text-embedding-3-small", + *, + save: bool = True, ) -> list[Chunk]: """Create chunks and embeddings from given texts using OpenAI embeddings. Args: - all_chunk_texts (list[str]): List of text chunks to embed. - content_object: The object to associate the chunks with. - openai_client: Initialized OpenAI client instance. + chunk_texts (list[str]): List of text chunks to process + context (Context): The context these chunks belong to + openai_client: Initialized OpenAI client + model (str): Embedding model to use + save (bool): Whether to save chunks immediately Returns: - list[Chunk]: List of Chunk instances (not saved). + list[Chunk]: List of created Chunk instances (empty if failed) """ try: @@ -40,32 +64,19 @@ def create_chunks_and_embeddings( time.sleep(MIN_REQUEST_INTERVAL_SECONDS - time_since_last_request.total_seconds()) response = openai_client.embeddings.create( - input=all_chunk_texts, - model="text-embedding-3-small", + input=chunk_texts, + model=model, ) + embeddings = [d.embedding for d in response.data] - context = Context( - content="\n".join(all_chunk_texts), - content_object=content_object, - ) - context.save() - - return [ - chunk - for text, embedding in zip( - all_chunk_texts, - [d.embedding for d in response.data], - strict=True, - ) - if ( - chunk := Chunk.update_data( - text=text, - context=context, - embedding=embedding, - save=False, - ) - ) - ] - except Exception: - logger.exception("OpenAI API error") + chunks = [] + for text, embedding in zip(chunk_texts, embeddings, strict=True): + chunk = Chunk.update_data(text=text, context=context, embedding=embedding, save=save) + if chunk: + chunks.append(chunk) + + except openai.OpenAIError: + logger.exception("Failed to create chunks and embeddings") return [] + else: + return chunks diff --git a/backend/apps/ai/management/commands/ai_create_chapter_chunks.py b/backend/apps/ai/management/commands/ai_create_chapter_chunks.py index 8b73079e64..427be08dda 100644 --- a/backend/apps/ai/management/commands/ai_create_chapter_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_chapter_chunks.py @@ -3,11 +3,13 @@ import os import openai +from django.contrib.contenttypes.models import ContentType from django.core.management.base import BaseCommand from apps.ai.common.constants import DELIMITER -from apps.ai.common.utils import create_chunks_and_embeddings +from apps.ai.common.utils import create_chunks_and_embeddings, create_context from apps.ai.models.chunk import Chunk +from apps.ai.models.context import Context from apps.owasp.models.chapter import Chapter @@ -16,7 +18,7 @@ class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument( - "--chapter", + "--chapter-key", type=str, help="Process only the chapter with this key", ) @@ -31,18 +33,35 @@ def add_arguments(self, parser): default=50, help="Number of chapters to process in each batch", ) + parser.add_argument( + "--context", + action="store_true", + help="Create only context (skip chunks and embeddings)", + ) + parser.add_argument( + "--chunks", + action="store_true", + help="Create only chunks+embeddings (requires existing context)", + ) def handle(self, *args, **options): - if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): + if not options["context"] and not options["chunks"]: + self.stdout.write( + self.style.ERROR("Please specify either --context or --chunks (or both)") + ) + return + + if options["chunks"] and not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): self.stdout.write( self.style.ERROR("DJANGO_OPEN_AI_SECRET_KEY environment variable not set") ) return - self.openai_client = openai.OpenAI(api_key=openai_api_key) + if options["chunks"]: + self.openai_client = openai.OpenAI(api_key=openai_api_key) - if chapter := options["chapter"]: - queryset = Chapter.objects.filter(key=chapter) + if options["chapter_key"]: + queryset = Chapter.objects.filter(key=options["chapter_key"]) elif options["all"]: queryset = Chapter.objects.all() else: @@ -55,40 +74,88 @@ def handle(self, *args, **options): self.stdout.write(f"Found {total_chapters} chapters to process") batch_size = options["batch_size"] + processed_count = 0 + for offset in range(0, total_chapters, batch_size): batch_chapters = queryset[offset : offset + batch_size] - batch_chunks = [] - for chapter in batch_chapters: - batch_chunks.extend(self.handle_chunks(chapter)) - - if batch_chunks: - Chunk.bulk_save(batch_chunks) - self.stdout.write(f"Saved {len(batch_chunks)} chunks") - - self.stdout.write(f"Completed processing all {total_chapters} chapters") - - def handle_chunks(self, chapter: Chapter) -> list[Chunk]: - """Create chunks from a chapter's data.""" - prose_content, metadata_content = self.extract_chapter_content(chapter) + if options["context"]: + processed_count += self.process_context_batch(batch_chapters) + elif options["chunks"]: + processed_count += self.process_chunks_batch(batch_chapters) - all_chunk_texts = [] - - if metadata_content.strip(): - all_chunk_texts.append(metadata_content) + self.stdout.write( + self.style.SUCCESS(f"Completed processing {processed_count}/{total_chapters} chapters") + ) - if prose_content.strip(): - all_chunk_texts.extend(Chunk.split_text(prose_content)) + def process_context_batch(self, chapters: list[Chapter]) -> int: + """Process a batch of chapters to create contexts.""" + processed = 0 - if not all_chunk_texts: - self.stdout.write(f"No content to chunk for chapter {chapter.key}") - return [] + for chapter in chapters: + prose_content, metadata_content = self.extract_chapter_content(chapter) + full_content = ( + f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content + ) - return create_chunks_and_embeddings( - all_chunk_texts=all_chunk_texts, - content_object=chapter, - openai_client=self.openai_client, - ) + if not full_content.strip(): + self.stdout.write(f"No content for chapter {chapter.key}") + continue + + if create_context( + content=full_content, content_object=chapter, source="owasp_chapter" + ): + processed += 1 + self.stdout.write(f"Created context for {chapter.key}") + else: + self.stdout.write(self.style.ERROR(f"Failed to create context for {chapter.key}")) + return processed + + def process_chunks_batch(self, chapters: list[Chapter]) -> int: + """Process a batch of chapters to create chunks.""" + processed = 0 + batch_chunks = [] + + chapter_content_type = ContentType.objects.get_for_model(Chapter) + + for chapter in chapters: + context = Context.objects.filter( + content_type=chapter_content_type, object_id=chapter.id + ).first() + + if not context: + self.stdout.write( + self.style.WARNING(f"No context found for chapter {chapter.key}") + ) + continue + + prose_content, metadata_content = self.extract_chapter_content(chapter) + all_chunk_texts = [] + + if metadata_content.strip(): + all_chunk_texts.append(metadata_content) + + if prose_content.strip(): + prose_chunks = Chunk.split_text(prose_content) + all_chunk_texts.extend(prose_chunks) + + if not all_chunk_texts: + self.stdout.write(f"No content to chunk for chapter {chapter.key}") + continue + + if chunks := create_chunks_and_embeddings( + chunk_texts=all_chunk_texts, + context=context, + openai_client=self.openai_client, + save=False, + ): + batch_chunks.extend(chunks) + processed += 1 + self.stdout.write(f"Created {len(chunks)} chunks for {chapter.key}") + + if batch_chunks: + Chunk.bulk_save(batch_chunks) + return processed def extract_chapter_content(self, chapter: Chapter) -> tuple[str, str]: """Extract and separate prose content from metadata for a chapter. diff --git a/backend/apps/ai/management/commands/ai_create_committee_chunks.py b/backend/apps/ai/management/commands/ai_create_committee_chunks.py index 6ae3771bc6..3efcd3b0fa 100644 --- a/backend/apps/ai/management/commands/ai_create_committee_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_committee_chunks.py @@ -3,11 +3,13 @@ import os import openai +from django.contrib.contenttypes.models import ContentType from django.core.management.base import BaseCommand from apps.ai.common.constants import DELIMITER -from apps.ai.common.utils import create_chunks_and_embeddings +from apps.ai.common.utils import create_chunks_and_embeddings, create_context from apps.ai.models.chunk import Chunk +from apps.ai.models.context import Context from apps.owasp.models.committee import Committee @@ -16,7 +18,7 @@ class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument( - "--committee", + "--committee-key", type=str, help="Process only the committee with this key", ) @@ -31,18 +33,35 @@ def add_arguments(self, parser): default=50, help="Number of committees to process in each batch", ) + parser.add_argument( + "--context", + action="store_true", + help="Create only context (skip chunks and embeddings)", + ) + parser.add_argument( + "--chunks", + action="store_true", + help="Create only chunks+embeddings (requires existing context)", + ) def handle(self, *args, **options): - if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): + if not options["context"] and not options["chunks"]: + self.stdout.write( + self.style.ERROR("Please specify either --context or --chunks (or both)") + ) + return + + if options["chunks"] and not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): self.stdout.write( self.style.ERROR("DJANGO_OPEN_AI_SECRET_KEY environment variable not set") ) return - self.openai_client = openai.OpenAI(api_key=openai_api_key) + if options["chunks"]: + self.openai_client = openai.OpenAI(api_key=openai_api_key) - if committee := options["committee"]: - queryset = Committee.objects.filter(key=committee) + if options["committee_key"]: + queryset = Committee.objects.filter(key=options["committee_key"]) elif options["all"]: queryset = Committee.objects.all() else: @@ -55,40 +74,92 @@ def handle(self, *args, **options): self.stdout.write(f"Found {total_committees} committees to process") batch_size = options["batch_size"] + processed_count = 0 + for offset in range(0, total_committees, batch_size): batch_committees = queryset[offset : offset + batch_size] - batch_chunks = [] - for committee in batch_committees: - batch_chunks.extend(self.handle_chunks(committee)) - - if batch_chunks: - Chunk.bulk_save(batch_chunks) - self.stdout.write(f"Saved {len(batch_chunks)} chunks") - - self.stdout.write(f"Completed processing all {total_committees} committees") - - def handle_chunks(self, committee: Committee) -> list[Chunk]: - """Create chunks from a committee's data.""" - prose_content, metadata_content = self.extract_committee_content(committee) - - all_chunk_texts = [] + if options["context"]: + processed_count += self.process_context_batch(batch_committees) + elif options["chunks"]: + processed_count += self.process_chunks_batch(batch_committees) - if metadata_content.strip(): - all_chunk_texts.append(metadata_content) + self.stdout.write( + self.style.SUCCESS( + f"Completed processing {processed_count}/{total_committees} committees" + ) + ) - if prose_content.strip(): - all_chunk_texts.extend(Chunk.split_text(prose_content)) + def process_context_batch(self, committees: list[Committee]) -> int: + """Process a batch of committees to create contexts.""" + processed = 0 - if not all_chunk_texts: - self.stdout.write(f"No content to chunk for committee {committee.key}") - return [] + for committee in committees: + prose_content, metadata_content = self.extract_committee_content(committee) + full_content = ( + f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content + ) - return create_chunks_and_embeddings( - all_chunk_texts=all_chunk_texts, - content_object=committee, - openai_client=self.openai_client, - ) + if not full_content.strip(): + self.stdout.write(f"No content for committee {committee.key}") + continue + + if create_context( + content=full_content, content_object=committee, source="owasp_committee" + ): + processed += 1 + self.stdout.write(f"Created context for {committee.key}") + else: + self.stdout.write( + self.style.ERROR(f"Failed to create context for {committee.key}") + ) + return processed + + def process_chunks_batch(self, committees: list[Committee]) -> int: + """Process a batch of committees to create chunks.""" + processed = 0 + batch_chunks = [] + + committee_content_type = ContentType.objects.get_for_model(Committee) + + for committee in committees: + context = Context.objects.filter( + content_type=committee_content_type, object_id=committee.id + ).first() + + if not context: + self.stdout.write( + self.style.WARNING(f"No context found for committee {committee.key}") + ) + continue + + prose_content, metadata_content = self.extract_committee_content(committee) + all_chunk_texts = [] + + if metadata_content.strip(): + all_chunk_texts.append(metadata_content) + + if prose_content.strip(): + prose_chunks = Chunk.split_text(prose_content) + all_chunk_texts.extend(prose_chunks) + + if not all_chunk_texts: + self.stdout.write(f"No content to chunk for committee {committee.key}") + continue + + if chunks := create_chunks_and_embeddings( + chunk_texts=all_chunk_texts, + context=context, + openai_client=self.openai_client, + save=False, + ): + batch_chunks.extend(chunks) + processed += 1 + self.stdout.write(f"Created {len(chunks)} chunks for {committee.key}") + + if batch_chunks: + Chunk.bulk_save(batch_chunks) + return processed def extract_committee_content(self, committee: Committee) -> tuple[str, str]: """Extract structured content from committee data.""" diff --git a/backend/apps/ai/management/commands/ai_create_event_chunks.py b/backend/apps/ai/management/commands/ai_create_event_chunks.py index d0dab81a0c..4a3def233b 100644 --- a/backend/apps/ai/management/commands/ai_create_event_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_event_chunks.py @@ -3,11 +3,13 @@ import os import openai +from django.contrib.contenttypes.models import ContentType from django.core.management.base import BaseCommand from apps.ai.common.constants import DELIMITER -from apps.ai.common.utils import create_chunks_and_embeddings +from apps.ai.common.utils import create_chunks_and_embeddings, create_context from apps.ai.models.chunk import Chunk +from apps.ai.models.context import Context from apps.owasp.models.event import Event @@ -16,7 +18,7 @@ class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument( - "--event", + "--event-key", type=str, help="Process only the event with this key", ) @@ -31,18 +33,35 @@ def add_arguments(self, parser): default=50, help="Number of events to process in each batch", ) + parser.add_argument( + "--context", + action="store_true", + help="Create only context (skip chunks and embeddings)", + ) + parser.add_argument( + "--chunks", + action="store_true", + help="Create only chunks+embeddings (requires existing context)", + ) def handle(self, *args, **options): - if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): + if not options["context"] and not options["chunks"]: + self.stdout.write( + self.style.ERROR("Please specify either --context or --chunks (or both)") + ) + return + + if options["chunks"] and not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): self.stdout.write( self.style.ERROR("DJANGO_OPEN_AI_SECRET_KEY environment variable not set") ) return - self.openai_client = openai.OpenAI(api_key=openai_api_key) + if options["chunks"]: + self.openai_client = openai.OpenAI(api_key=openai_api_key) - if event := options["event"]: - queryset = Event.objects.filter(key=event) + if options["event_key"]: + queryset = Event.objects.filter(key=options["event_key"]) elif options["all"]: queryset = Event.objects.all() else: @@ -55,40 +74,84 @@ def handle(self, *args, **options): self.stdout.write(f"Found {total_events} events to process") batch_size = options["batch_size"] + processed_count = 0 + for offset in range(0, total_events, batch_size): batch_events = queryset[offset : offset + batch_size] - batch_chunks = [] - for event in batch_events: - batch_chunks.extend(self.handle_chunks(event)) - - if batch_chunks: - Chunk.bulk_save(batch_chunks) - self.stdout.write(f"Saved {len(batch_chunks)} chunks") - - self.stdout.write(f"Completed processing all {total_events} events") - - def handle_chunks(self, event: Event) -> list[Chunk]: - """Create chunks from an event's data.""" - prose_content, metadata_content = self.extract_event_content(event) + if options["context"]: + processed_count += self.process_context_batch(batch_events) + elif options["chunks"]: + processed_count += self.process_chunks_batch(batch_events) - all_chunk_texts = [] - - if metadata_content.strip(): - all_chunk_texts.append(metadata_content) + self.stdout.write( + self.style.SUCCESS(f"Completed processing {processed_count}/{total_events} events") + ) - if prose_content.strip(): - all_chunk_texts.extend(Chunk.split_text(prose_content)) + def process_context_batch(self, events: list[Event]) -> int: + """Process a batch of events to create contexts.""" + processed = 0 - if not all_chunk_texts: - self.stdout.write(f"No content to chunk for event {event.key}") - return [] + for event in events: + prose_content, metadata_content = self.extract_event_content(event) + full_content = ( + f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content + ) - return create_chunks_and_embeddings( - all_chunk_texts, - content_object=event, - openai_client=self.openai_client, - ) + if not full_content.strip(): + self.stdout.write(f"No content for event {event.key}") + continue + + if create_context(content=full_content, content_object=event, source="owasp_event"): + processed += 1 + self.stdout.write(f"Created context for {event.key}") + else: + self.stdout.write(self.style.ERROR(f"Failed to create context for {event.key}")) + return processed + + def process_chunks_batch(self, events: list[Event]) -> int: + """Process a batch of events to create chunks.""" + processed = 0 + batch_chunks = [] + + event_content_type = ContentType.objects.get_for_model(Event) + + for event in events: + context = Context.objects.filter( + content_type=event_content_type, object_id=event.id + ).first() + + if not context: + self.stdout.write(self.style.WARNING(f"No context found for event {event.key}")) + continue + + prose_content, metadata_content = self.extract_event_content(event) + all_chunk_texts = [] + + if metadata_content.strip(): + all_chunk_texts.append(metadata_content) + + if prose_content.strip(): + prose_chunks = Chunk.split_text(prose_content) + all_chunk_texts.extend(prose_chunks) + + if not all_chunk_texts: + self.stdout.write(f"No content to chunk for event {event.key}") + continue + + if chunks := create_chunks_and_embeddings( + chunk_texts=all_chunk_texts, + context=context, + openai_client=self.openai_client, + save=False, + ): + batch_chunks.extend(chunks) + processed += 1 + self.stdout.write(f"Created {len(chunks)} chunks for {event.key}") + + if batch_chunks: + Chunk.bulk_save(batch_chunks) + return processed def extract_event_content(self, event: Event) -> tuple[str, str]: """Extract and separate prose content from metadata for an event. diff --git a/backend/apps/ai/management/commands/ai_create_project_chunks.py b/backend/apps/ai/management/commands/ai_create_project_chunks.py index d472ea9589..62fb62ab58 100644 --- a/backend/apps/ai/management/commands/ai_create_project_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_project_chunks.py @@ -3,11 +3,13 @@ import os import openai +from django.contrib.contenttypes.models import ContentType from django.core.management.base import BaseCommand from apps.ai.common.constants import DELIMITER -from apps.ai.common.utils import create_chunks_and_embeddings +from apps.ai.common.utils import create_chunks_and_embeddings, create_context from apps.ai.models.chunk import Chunk +from apps.ai.models.context import Context from apps.owasp.models.project import Project @@ -25,15 +27,30 @@ def add_arguments(self, parser): default=50, help="Number of projects to process in each batch", ) + parser.add_argument( + "--context", + action="store_true", + help="Create only context (skip chunks and embeddings)", + ) + parser.add_argument( + "--chunks", + action="store_true", + help="Create only chunks+embeddings (requires existing context)", + ) def handle(self, *args, **options): - if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): + if not options["context"] and not options["chunks"]: + self.stdout.write(self.style.ERROR("Must specify either --context or --chunks")) + return + + if options["chunks"] and not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): self.stdout.write( self.style.ERROR("DJANGO_OPEN_AI_SECRET_KEY environment variable not set") ) return - self.openai_client = openai.OpenAI(api_key=openai_api_key) + if options["chunks"]: + self.openai_client = openai.OpenAI(api_key=openai_api_key) if options["project_key"]: queryset = Project.objects.filter(key=options["project_key"]) @@ -49,42 +66,88 @@ def handle(self, *args, **options): self.stdout.write(f"Found {total_projects} projects to process") batch_size = options["batch_size"] + processed_count = 0 + for offset in range(0, total_projects, batch_size): batch_projects = queryset[offset : offset + batch_size] - batch_chunks = [] - for project in batch_projects: - chunks = self.create_chunks(project) - batch_chunks.extend(chunks) + if options["context"]: + processed_count += self.process_context_batch(batch_projects) + elif options["chunks"]: + processed_count += self.process_chunks_batch(batch_projects) - if batch_chunks: - chunks_count = len(batch_chunks) - Chunk.bulk_save(batch_chunks) - self.stdout.write(f"Saved {chunks_count} chunks") - - self.stdout.write(f"Completed processing all {total_projects} projects") - - def create_chunks(self, project: Project) -> list[Chunk]: - prose_content, metadata_content = self.extract_project_content(project) - - all_chunk_texts = [] + self.stdout.write( + self.style.SUCCESS(f"Completed processing {processed_count}/{total_projects} projects") + ) - if metadata_content.strip(): - all_chunk_texts.append(metadata_content) + def process_context_batch(self, projects: list[Project]) -> int: + """Process a batch of projects to create contexts.""" + processed = 0 - if prose_content.strip(): - prose_chunks = Chunk.split_text(prose_content) - all_chunk_texts.extend(prose_chunks) + for project in projects: + prose_content, metadata_content = self.extract_project_content(project) + full_content = ( + f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content + ) - if not all_chunk_texts: - self.stdout.write(f"No content to chunk for project {project.key}") - return [] + if not full_content.strip(): + self.stdout.write(f"No content for project {project.key}") + continue + + if create_context( + content=full_content, content_object=project, source="owasp_project" + ): + processed += 1 + self.stdout.write(f"Created context for {project.key}") + else: + self.stdout.write(self.style.ERROR(f"Failed to create context for {project.key}")) + return processed + + def process_chunks_batch(self, projects: list[Project]) -> int: + """Process a batch of projects to create chunks.""" + processed = 0 + batch_chunks = [] + + project_content_type = ContentType.objects.get_for_model(Project) + + for project in projects: + context = Context.objects.filter( + content_type=project_content_type, object_id=project.id + ).first() + + if not context: + self.stdout.write( + self.style.WARNING(f"No context found for project {project.key}") + ) + continue + + prose_content, metadata_content = self.extract_project_content(project) + all_chunk_texts = [] + + if metadata_content.strip(): + all_chunk_texts.append(metadata_content) + + if prose_content.strip(): + prose_chunks = Chunk.split_text(prose_content) + all_chunk_texts.extend(prose_chunks) + + if not all_chunk_texts: + self.stdout.write(f"No content to chunk for project {project.key}") + continue + + if chunks := create_chunks_and_embeddings( + chunk_texts=all_chunk_texts, + context=context, + openai_client=self.openai_client, + save=False, + ): + batch_chunks.extend(chunks) + processed += 1 + self.stdout.write(f"Created {len(chunks)} chunks for {project.key}") - return create_chunks_and_embeddings( - all_chunk_texts=all_chunk_texts, - content_object=project, - openai_client=self.openai_client, - ) + if batch_chunks: + Chunk.bulk_save(batch_chunks) + return processed def extract_project_content(self, project: Project) -> tuple[str, str]: prose_parts = [] diff --git a/backend/apps/ai/management/commands/ai_create_slack_message_chunks.py b/backend/apps/ai/management/commands/ai_create_slack_message_chunks.py index 30b20e0f39..4158627962 100644 --- a/backend/apps/ai/management/commands/ai_create_slack_message_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_slack_message_chunks.py @@ -3,54 +3,141 @@ import os import openai +from django.contrib.contenttypes.models import ContentType from django.core.management.base import BaseCommand -from apps.ai.common.utils import create_chunks_and_embeddings +from apps.ai.common.utils import create_chunks_and_embeddings, create_context from apps.ai.models.chunk import Chunk +from apps.ai.models.context import Context from apps.slack.models.message import Message class Command(BaseCommand): help = "Create chunks for Slack messages" + def add_arguments(self, parser): + parser.add_argument( + "--batch-size", + type=int, + default=100, + help="Number of messages to process in each batch", + ) + parser.add_argument( + "--context", + action="store_true", + help="Create only context (skip chunks and embeddings)", + ) + parser.add_argument( + "--chunks", + action="store_true", + help="Create only chunks+embeddings (requires existing context)", + ) + def handle(self, *args, **options): - if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): + if not options["context"] and not options["chunks"]: + self.stdout.write( + self.style.ERROR("Please specify either --context or --chunks (or both)") + ) + return + + if options["chunks"] and not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): self.stdout.write( self.style.ERROR("DJANGO_OPEN_AI_SECRET_KEY environment variable not set") ) return - self.openai_client = openai.OpenAI(api_key=openai_api_key) + if options["chunks"]: + self.openai_client = openai.OpenAI(api_key=openai_api_key) + + queryset = Message.objects.all() + total_messages = queryset.count() + + if not total_messages: + self.stdout.write("No messages found to process") + return - total_messages = Message.objects.count() self.stdout.write(f"Found {total_messages} messages to process") - batch_size = 100 + batch_size = options["batch_size"] + processed_count = 0 + for offset in range(0, total_messages, batch_size): - Chunk.bulk_save( - [ - chunk - for message in Message.objects.all()[offset : offset + batch_size] - for chunk in self.handle_chunks(message) - ] - ) + batch_messages = queryset[offset : offset + batch_size] - self.stdout.write(f"Completed processing all {total_messages} messages") + if options["context"]: + processed_count += self.process_context_batch(batch_messages) + elif options["chunks"]: + processed_count += self.process_chunks_batch(batch_messages) - def handle_chunks(self, message: Message) -> list[Chunk]: - """Create chunks from a message.""" - if message.subtype in {"channel_join", "channel_leave"}: - return [] + self.stdout.write( + self.style.SUCCESS(f"Completed processing {processed_count}/{total_messages} messages") + ) - if not (chunk_text := Chunk.split_text(message.cleaned_text)): - self.stdout.write( - f"No chunks created for message {message.slack_message_id}: " - f"`{message.cleaned_text}`" - ) - return [] + def process_context_batch(self, messages: list[Message]) -> int: + """Process a batch of messages to create contexts.""" + processed = 0 - return create_chunks_and_embeddings( - all_chunk_texts=chunk_text, - content_object=message, - openai_client=self.openai_client, - ) + for message in messages: + if not message.cleaned_text or not message.cleaned_text.strip(): + continue + + if create_context( + content=message.cleaned_text, + content_object=message, + source="slack_message", + ): + processed += 1 + self.stdout.write(f"Created context for message {message.slack_message_id}") + else: + self.stdout.write( + self.style.ERROR( + f"Failed to create context for message {message.slack_message_id}" + ) + ) + return processed + + def process_chunks_batch(self, messages: list[Message]) -> int: + """Process a batch of messages to create chunks.""" + processed = 0 + batch_chunks = [] + + message_content_type = ContentType.objects.get_for_model(Message) + + for message in messages: + context = Context.objects.filter( + content_type=message_content_type, object_id=message.id + ).first() + + if not context: + self.stdout.write( + self.style.WARNING(f"No context found for message {message.slack_message_id}") + ) + continue + + if not message.cleaned_text or not message.cleaned_text.strip(): + self.stdout.write(f"No content to chunk for message {message.slack_message_id}") + continue + + chunk_texts = Chunk.split_text(message.cleaned_text) + if not chunk_texts: + self.stdout.write( + f"No chunks created for message {message.slack_message_id}: " + f"`{message.cleaned_text}`" + ) + continue + + if chunks := create_chunks_and_embeddings( + chunk_texts=chunk_texts, + context=context, + openai_client=self.openai_client, + save=False, + ): + batch_chunks.extend(chunks) + processed += 1 + self.stdout.write( + f"Created {len(chunks)} chunks for message {message.slack_message_id}" + ) + + if batch_chunks: + Chunk.bulk_save(batch_chunks) + return processed diff --git a/backend/tests/apps/ai/common/utils_test.py b/backend/tests/apps/ai/common/utils_test.py index 8186ce78da..a1d586b3bb 100644 --- a/backend/tests/apps/ai/common/utils_test.py +++ b/backend/tests/apps/ai/common/utils_test.py @@ -1,5 +1,8 @@ +from datetime import UTC, datetime, timedelta from unittest.mock import MagicMock, call, patch +import openai + from apps.ai.common.utils import create_chunks_and_embeddings @@ -12,10 +15,17 @@ class TestUtils: @patch("apps.ai.common.utils.Context") @patch("apps.ai.common.utils.Chunk.update_data") @patch("apps.ai.common.utils.time.sleep") + @patch("apps.ai.common.utils.datetime") def test_create_chunks_and_embeddings_success( - self, mock_sleep, mock_update_data, mock_context + self, mock_datetime, mock_sleep, mock_update_data, mock_context ): """Tests the successful path where the OpenAI API returns embeddings.""" + base_time = datetime.now(UTC) + mock_datetime.now.return_value = base_time + mock_datetime.UTC = UTC + + mock_datetime.timedelta = timedelta + mock_openai_client = MagicMock() mock_api_response = MagicMock() mock_api_response.data = [ @@ -44,15 +54,15 @@ def test_create_chunks_and_embeddings_success( [ call( text="first chunk", - context=mock_context(), + context=mock_content_object, embedding=[0.1, 0.2], - save=False, + save=True, ), call( text="second chunk", - context=mock_context(), + context=mock_content_object, embedding=[0.3, 0.4], - save=False, + save=True, ), ] ) @@ -65,14 +75,47 @@ def test_create_chunks_and_embeddings_success( def test_create_chunks_and_embeddings_api_error(self, mock_logger): """Tests the failure path where the OpenAI API raises an exception.""" mock_openai_client = MagicMock() - mock_openai_client.embeddings.create.side_effect = Exception("API connection failed") + + mock_openai_client.embeddings.create.side_effect = openai.OpenAIError( + "API connection failed" + ) result = create_chunks_and_embeddings( - all_chunk_texts=["some text"], - content_object=MagicMock(), + chunk_texts=["some text"], + context=MagicMock(), openai_client=mock_openai_client, ) - mock_logger.exception.assert_called_once_with("OpenAI API error") + mock_logger.exception.assert_called_once_with("Failed to create chunks and embeddings") assert result == [] + + @patch("apps.ai.common.utils.Context") + @patch("apps.ai.common.utils.Chunk.update_data") + @patch("apps.ai.common.utils.time.sleep") + @patch("apps.ai.common.utils.datetime") + def test_create_chunks_and_embeddings_no_sleep_with_current_settings( + self, mock_datetime, mock_sleep, mock_update_data, mock_context + ): + """Tests that sleep is not called with current offset settings.""" + base_time = datetime.now(UTC) + mock_datetime.now.return_value = base_time + mock_datetime.UTC = UTC + mock_datetime.timedelta = timedelta + + mock_openai_client = MagicMock() + mock_api_response = MagicMock() + mock_api_response.data = [MockEmbeddingData([0.1, 0.2])] + mock_openai_client.embeddings.create.return_value = mock_api_response + + mock_update_data.return_value = "mock_chunk_instance" + + result = create_chunks_and_embeddings( + ["test chunk"], + MagicMock(), + mock_openai_client, + ) + + mock_sleep.assert_not_called() + + assert result == ["mock_chunk_instance"] From 697a406edb50c8559ca0a4370dac1c528411491d Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Thu, 7 Aug 2025 21:22:16 +0530 Subject: [PATCH 13/32] update method for context --- backend/apps/ai/common/utils.py | 13 +++++++++-- backend/apps/ai/models/context.py | 37 +++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/backend/apps/ai/common/utils.py b/backend/apps/ai/common/utils.py index ce2a7054c5..062b71fa9f 100644 --- a/backend/apps/ai/common/utils.py +++ b/backend/apps/ai/common/utils.py @@ -5,6 +5,7 @@ from datetime import UTC, datetime, timedelta import openai +from django.contrib.contenttypes.models import ContentType from apps.ai.common.constants import ( DEFAULT_LAST_REQUEST_OFFSET_SECONDS, @@ -28,8 +29,16 @@ def create_context(content: str, content_object=None, source: str = "") -> Conte Context: Created Context instance """ - context = Context(content=content, content_object=content_object, source=source) - context.save() + context = Context.update_data(content=content, content_object=content_object, source=source) + if context is None: + if content_object: + content_type = ContentType.objects.get_for_model(content_object) + context = Context.objects.get( + content_type=content_type, object_id=content_object.pk, content=content + ) + else: + context = Context.objects.get(content=content, content_object__isnull=True) + return context diff --git a/backend/apps/ai/models/context.py b/backend/apps/ai/models/context.py index 209a82f59c..bc02cf2fb2 100644 --- a/backend/apps/ai/models/context.py +++ b/backend/apps/ai/models/context.py @@ -32,3 +32,40 @@ def __str__(self): f"{self.content_type.model if self.content_type else 'None'} {entity}: " f"{self.content[:50]}" ) + + @staticmethod + def update_data( + content: str, + content_object=None, + source: str = "", + *, + save: bool = True, + ) -> "Context | None": + """Update context data. + + Args: + content (str): The content text of the context. + content_object: Optional related object (generic foreign key). + source (str): Source identifier for the context. + save (bool): Whether to save the context to the database. + + Returns: + Context: The updated context instance or None if it already exists. + + """ + if content_object: + content_type = ContentType.objects.get_for_model(content_object) + object_id = content_object.pk + if Context.objects.filter( + content_type=content_type, object_id=object_id, content=content + ).exists(): + return None + elif Context.objects.filter(content=content, content_object__isnull=True).exists(): + return None + + context = Context(content=content, content_object=content_object, source=source) + + if save: + context.save() + + return context From a3255ffe9fd05d3c0c2a770158cc85ad38235ccc Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Sat, 9 Aug 2025 21:51:56 +0530 Subject: [PATCH 14/32] major revamp and test cases --- backend/apps/ai/Makefile | 16 + backend/apps/ai/common/extractors.py | 273 ++++++++++++++++++ backend/apps/ai/common/utils.py | 18 +- .../commands/ai_create_chapter_chunks.py | 126 +------- .../commands/ai_create_chapter_context.py | 77 +++++ .../commands/ai_create_committee_chunks.py | 104 +------ .../commands/ai_create_committee_context.py | 81 ++++++ .../commands/ai_create_event_chunks.py | 96 +----- .../commands/ai_create_event_context.py | 75 +++++ .../commands/ai_create_project_chunks.py | 155 +--------- .../commands/ai_create_project_context.py | 77 +++++ backend/apps/ai/models/chunk.py | 14 +- backend/tests/apps/ai/common/utils_test.py | 35 ++- .../commands/ai_create_chapter_chunks_test.py | 237 +++++++++++++++ .../ai_create_chapter_context_test.py | 210 ++++++++++++++ .../ai_create_committee_chunks_test.py | 154 ++++++++++ .../ai_create_committee_context_test.py | 120 ++++++++ .../commands/ai_create_event_chunks_test.py | 120 ++++++++ .../commands/ai_create_event_context_test.py | 98 +++++++ .../commands/ai_create_project_chunks_test.py | 144 +++++++++ .../ai_create_project_context_test.py | 118 ++++++++ .../ai_create_slack_message_chunks_test.py | 148 ++++++++++ .../commands/ai_run_rag_tool_test.py | 142 +++++++++ backend/tests/apps/ai/models/chunk_test.py | 92 +++--- .../tests/apps/slack/management/__init__.py | 0 25 files changed, 2194 insertions(+), 536 deletions(-) create mode 100644 backend/apps/ai/common/extractors.py create mode 100644 backend/apps/ai/management/commands/ai_create_chapter_context.py create mode 100644 backend/apps/ai/management/commands/ai_create_committee_context.py create mode 100644 backend/apps/ai/management/commands/ai_create_event_context.py create mode 100644 backend/apps/ai/management/commands/ai_create_project_context.py create mode 100644 backend/tests/apps/ai/management/commands/ai_create_chapter_chunks_test.py create mode 100644 backend/tests/apps/ai/management/commands/ai_create_chapter_context_test.py create mode 100644 backend/tests/apps/ai/management/commands/ai_create_committee_chunks_test.py create mode 100644 backend/tests/apps/ai/management/commands/ai_create_committee_context_test.py create mode 100644 backend/tests/apps/ai/management/commands/ai_create_event_chunks_test.py create mode 100644 backend/tests/apps/ai/management/commands/ai_create_event_context_test.py create mode 100644 backend/tests/apps/ai/management/commands/ai_create_project_chunks_test.py create mode 100644 backend/tests/apps/ai/management/commands/ai_create_project_context_test.py create mode 100644 backend/tests/apps/ai/management/commands/ai_create_slack_message_chunks_test.py create mode 100644 backend/tests/apps/ai/management/commands/ai_run_rag_tool_test.py create mode 100644 backend/tests/apps/slack/management/__init__.py diff --git a/backend/apps/ai/Makefile b/backend/apps/ai/Makefile index cff4221abe..1219c094c8 100644 --- a/backend/apps/ai/Makefile +++ b/backend/apps/ai/Makefile @@ -1,15 +1,31 @@ +ai-create-chapter-context: + @echo "Creating chapter context" + @CMD="python manage.py ai_create_chapter_context" $(MAKE) exec-backend-command + ai-create-chapter-chunks: @echo "Creating chapter chunks" @CMD="python manage.py ai_create_chapter_chunks" $(MAKE) exec-backend-command +ai-create-committee-context: + @echo "Creating committee context" + @CMD="python manage.py ai_create_committee_context" $(MAKE) exec-backend-command + ai-create-committee-chunks: @echo "Creating committee chunks" @CMD="python manage.py ai_create_committee_chunks" $(MAKE) exec-backend-command +ai-create-event-context: + @echo "Creating event context" + @CMD="python manage.py ai_create_event_context" $(MAKE) exec-backend-command + ai-create-event-chunks: @echo "Creating event chunks" @CMD="python manage.py ai_create_event_chunks" $(MAKE) exec-backend-command +ai-create-project-context: + @echo "Creating project context" + @CMD="python manage.py ai_create_project_context" $(MAKE) exec-backend-command + ai-create-project-chunks: @echo "Creating project chunks" @CMD="python manage.py ai_create_project_chunks" $(MAKE) exec-backend-command diff --git a/backend/apps/ai/common/extractors.py b/backend/apps/ai/common/extractors.py new file mode 100644 index 0000000000..ff2d4f6b9b --- /dev/null +++ b/backend/apps/ai/common/extractors.py @@ -0,0 +1,273 @@ +"""Content extractors for various models.""" + +from apps.ai.common.constants import DELIMITER + + +def extract_committee_content(committee) -> tuple[str, str]: + """Extract structured content from committee data. + + Args: + committee: Committee instance + + Returns: + tuple[str, str]: (prose_content, metadata_content) + + """ + prose_parts = [] + metadata_parts = [] + + if committee.description: + prose_parts.append(f"Description: {committee.description}") + + if committee.summary: + prose_parts.append(f"Summary: {committee.summary}") + + if hasattr(committee, "owasp_repository") and committee.owasp_repository: + repo = committee.owasp_repository + if repo.description: + prose_parts.append(f"Repository Description: {repo.description}") + if repo.topics: + metadata_parts.append(f"Repository Topics: {', '.join(repo.topics)}") + + if committee.name: + metadata_parts.append(f"Committee Name: {committee.name}") + + if committee.tags: + metadata_parts.append(f"Tags: {', '.join(committee.tags)}") + + if committee.topics: + metadata_parts.append(f"Topics: {', '.join(committee.topics)}") + + if committee.leaders_raw: + metadata_parts.append(f"Committee Leaders: {', '.join(committee.leaders_raw)}") + + if committee.related_urls: + valid_urls = [ + url + for url in committee.related_urls + if url and url not in (committee.invalid_urls or []) + ] + if valid_urls: + metadata_parts.append(f"Related URLs: {', '.join(valid_urls)}") + + metadata_parts.append(f"Active Committee: {'Yes' if committee.is_active else 'No'}") + + return ( + DELIMITER.join(filter(None, prose_parts)), + DELIMITER.join(filter(None, metadata_parts)), + ) + + +def extract_chapter_content(chapter) -> tuple[str, str]: + """Extract structured content from chapter data. + + Args: + chapter: Chapter instance + + Returns: + tuple[str, str]: (prose_content, metadata_content) + + """ + prose_parts = [] + metadata_parts = [] + + if chapter.description: + prose_parts.append(f"Description: {chapter.description}") + + if chapter.summary: + prose_parts.append(f"Summary: {chapter.summary}") + + if hasattr(chapter, "owasp_repository") and chapter.owasp_repository: + repo = chapter.owasp_repository + if repo.description: + prose_parts.append(f"Repository Description: {repo.description}") + if repo.topics: + metadata_parts.append(f"Repository Topics: {', '.join(repo.topics)}") + + if chapter.name: + metadata_parts.append(f"Chapter Name: {chapter.name}") + + location_parts = [] + if chapter.country: + location_parts.append(f"Country: {chapter.country}") + if chapter.region: + location_parts.append(f"Region: {chapter.region}") + if chapter.postal_code: + location_parts.append(f"Postal Code: {chapter.postal_code}") + if chapter.suggested_location: + location_parts.append(f"Location: {chapter.suggested_location}") + + if location_parts: + metadata_parts.append(f"Location Information: {', '.join(location_parts)}") + + if chapter.currency: + metadata_parts.append(f"Currency: {chapter.currency}") + + if chapter.meetup_group: + metadata_parts.append(f"Meetup Group: {chapter.meetup_group}") + + if chapter.tags: + metadata_parts.append(f"Tags: {', '.join(chapter.tags)}") + + if chapter.topics: + metadata_parts.append(f"Topics: {', '.join(chapter.topics)}") + + if chapter.leaders_raw: + metadata_parts.append(f"Chapter Leaders: {', '.join(chapter.leaders_raw)}") + + if chapter.related_urls: + valid_urls = [ + url for url in chapter.related_urls if url and url not in (chapter.invalid_urls or []) + ] + if valid_urls: + metadata_parts.append(f"Related URLs: {', '.join(valid_urls)}") + + metadata_parts.append(f"Active Chapter: {'Yes' if chapter.is_active else 'No'}") + + return ( + DELIMITER.join(filter(None, prose_parts)), + DELIMITER.join(filter(None, metadata_parts)), + ) + + +def extract_event_content(event) -> tuple[str, str]: + """Extract structured content from event data. + + Args: + event: Event instance + + Returns: + tuple[str, str]: (prose_content, metadata_content) + + """ + prose_parts = [] + metadata_parts = [] + + if event.description: + prose_parts.append(f"Description: {event.description}") + + if event.summary: + prose_parts.append(f"Summary: {event.summary}") + + if event.name: + metadata_parts.append(f"Event Name: {event.name}") + + if event.category: + metadata_parts.append(f"Category: {event.get_category_display()}") + + if event.start_date: + metadata_parts.append(f"Start Date: {event.start_date}") + + if event.end_date: + metadata_parts.append(f"End Date: {event.end_date}") + + if event.suggested_location: + metadata_parts.append(f"Location: {event.suggested_location}") + + if event.latitude and event.longitude: + metadata_parts.append(f"Coordinates: {event.latitude}, {event.longitude}") + + if event.url: + metadata_parts.append(f"Event URL: {event.url}") + + return ( + DELIMITER.join(filter(None, prose_parts)), + DELIMITER.join(filter(None, metadata_parts)), + ) + + +def extract_project_content(project) -> tuple[str, str]: + """Extract structured content from project data. + + Args: + project: Project instance + + Returns: + tuple[str, str]: (prose_content, metadata_content) + + """ + prose_parts = [] + metadata_parts = [] + + if project.description: + prose_parts.append(f"Description: {project.description}") + + if project.summary: + prose_parts.append(f"Summary: {project.summary}") + + if hasattr(project, "owasp_repository") and project.owasp_repository: + repo = project.owasp_repository + if repo.description: + prose_parts.append(f"Repository Description: {repo.description}") + if repo.topics: + metadata_parts.append(f"Repository Topics: {', '.join(repo.topics)}") + + if project.name: + metadata_parts.append(f"Project Name: {project.name}") + + if project.level: + metadata_parts.append(f"Project Level: {project.level}") + + if project.type: + metadata_parts.append(f"Project Type: {project.type}") + + if project.languages: + metadata_parts.append(f"Programming Languages: {', '.join(project.languages)}") + + if project.topics: + metadata_parts.append(f"Topics: {', '.join(project.topics)}") + + if project.licenses: + metadata_parts.append(f"Licenses: {', '.join(project.licenses)}") + + if project.tags: + metadata_parts.append(f"Tags: {', '.join(project.tags)}") + + if project.custom_tags: + metadata_parts.append(f"Custom Tags: {', '.join(project.custom_tags)}") + + stats_parts = [] + if project.stars_count > 0: + stats_parts.append(f"Stars: {project.stars_count}") + if project.forks_count > 0: + stats_parts.append(f"Forks: {project.forks_count}") + if project.contributors_count > 0: + stats_parts.append(f"Contributors: {project.contributors_count}") + if project.releases_count > 0: + stats_parts.append(f"Releases: {project.releases_count}") + if project.open_issues_count > 0: + stats_parts.append(f"Open Issues: {project.open_issues_count}") + + if stats_parts: + metadata_parts.append("Project Statistics: " + ", ".join(stats_parts)) + + if project.leaders_raw: + metadata_parts.append(f"Project Leaders: {', '.join(project.leaders_raw)}") + + if project.related_urls: + valid_urls = [ + url for url in project.related_urls if url and url not in (project.invalid_urls or []) + ] + if valid_urls: + metadata_parts.append(f"Related URLs: {', '.join(valid_urls)}") + + if project.created_at: + metadata_parts.append(f"Created: {project.created_at.strftime('%Y-%m-%d')}") + + if project.updated_at: + metadata_parts.append(f"Last Updated: {project.updated_at.strftime('%Y-%m-%d')}") + + if project.released_at: + metadata_parts.append(f"Last Release: {project.released_at.strftime('%Y-%m-%d')}") + + if project.health_score is not None: + metadata_parts.append(f"Health Score: {project.health_score:.2f}") + + metadata_parts.append(f"Active Project: {'Yes' if project.is_active else 'No'}") + + metadata_parts.append(f"Issue Tracking: {'Enabled' if project.track_issues else 'Disabled'}") + + return ( + DELIMITER.join(filter(None, prose_parts)), + DELIMITER.join(filter(None, metadata_parts)), + ) diff --git a/backend/apps/ai/common/utils.py b/backend/apps/ai/common/utils.py index 062b71fa9f..744558865e 100644 --- a/backend/apps/ai/common/utils.py +++ b/backend/apps/ai/common/utils.py @@ -62,7 +62,14 @@ def create_chunks_and_embeddings( Returns: list[Chunk]: List of created Chunk instances (empty if failed) + Raises: + ValueError: If context is None or invalid + """ + if context is None: + error_msg = "Context is required for chunk creation.please create a context first." + raise ValueError(error_msg) + try: last_request_time = datetime.now(UTC) - timedelta( seconds=DEFAULT_LAST_REQUEST_OFFSET_SECONDS @@ -80,12 +87,17 @@ def create_chunks_and_embeddings( chunks = [] for text, embedding in zip(chunk_texts, embeddings, strict=True): - chunk = Chunk.update_data(text=text, context=context, embedding=embedding, save=save) - if chunk: - chunks.append(chunk) + chunk = Chunk.update_data(text=text, embedding=embedding, save=False) + chunk.context = context + if save: + chunk.save() + chunks.append(chunk) except openai.OpenAIError: logger.exception("Failed to create chunks and embeddings") return [] + except ValueError: + logger.exception("Context error") + raise else: return chunks diff --git a/backend/apps/ai/management/commands/ai_create_chapter_chunks.py b/backend/apps/ai/management/commands/ai_create_chapter_chunks.py index 427be08dda..e5cf45631d 100644 --- a/backend/apps/ai/management/commands/ai_create_chapter_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_chapter_chunks.py @@ -6,8 +6,8 @@ from django.contrib.contenttypes.models import ContentType from django.core.management.base import BaseCommand -from apps.ai.common.constants import DELIMITER -from apps.ai.common.utils import create_chunks_and_embeddings, create_context +from apps.ai.common.extractors import extract_chapter_content +from apps.ai.common.utils import create_chunks_and_embeddings from apps.ai.models.chunk import Chunk from apps.ai.models.context import Context from apps.owasp.models.chapter import Chapter @@ -33,32 +33,15 @@ def add_arguments(self, parser): default=50, help="Number of chapters to process in each batch", ) - parser.add_argument( - "--context", - action="store_true", - help="Create only context (skip chunks and embeddings)", - ) - parser.add_argument( - "--chunks", - action="store_true", - help="Create only chunks+embeddings (requires existing context)", - ) def handle(self, *args, **options): - if not options["context"] and not options["chunks"]: - self.stdout.write( - self.style.ERROR("Please specify either --context or --chunks (or both)") - ) - return - - if options["chunks"] and not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): + if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): self.stdout.write( self.style.ERROR("DJANGO_OPEN_AI_SECRET_KEY environment variable not set") ) return - if options["chunks"]: - self.openai_client = openai.OpenAI(api_key=openai_api_key) + self.openai_client = openai.OpenAI(api_key=openai_api_key) if options["chapter_key"]: queryset = Chapter.objects.filter(key=options["chapter_key"]) @@ -78,39 +61,12 @@ def handle(self, *args, **options): for offset in range(0, total_chapters, batch_size): batch_chapters = queryset[offset : offset + batch_size] - - if options["context"]: - processed_count += self.process_context_batch(batch_chapters) - elif options["chunks"]: - processed_count += self.process_chunks_batch(batch_chapters) + processed_count += self.process_chunks_batch(batch_chapters) self.stdout.write( self.style.SUCCESS(f"Completed processing {processed_count}/{total_chapters} chapters") ) - def process_context_batch(self, chapters: list[Chapter]) -> int: - """Process a batch of chapters to create contexts.""" - processed = 0 - - for chapter in chapters: - prose_content, metadata_content = self.extract_chapter_content(chapter) - full_content = ( - f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content - ) - - if not full_content.strip(): - self.stdout.write(f"No content for chapter {chapter.key}") - continue - - if create_context( - content=full_content, content_object=chapter, source="owasp_chapter" - ): - processed += 1 - self.stdout.write(f"Created context for {chapter.key}") - else: - self.stdout.write(self.style.ERROR(f"Failed to create context for {chapter.key}")) - return processed - def process_chunks_batch(self, chapters: list[Chapter]) -> int: """Process a batch of chapters to create chunks.""" processed = 0 @@ -129,7 +85,7 @@ def process_chunks_batch(self, chapters: list[Chapter]) -> int: ) continue - prose_content, metadata_content = self.extract_chapter_content(chapter) + prose_content, metadata_content = extract_chapter_content(chapter) all_chunk_texts = [] if metadata_content.strip(): @@ -156,73 +112,3 @@ def process_chunks_batch(self, chapters: list[Chapter]) -> int: if batch_chunks: Chunk.bulk_save(batch_chunks) return processed - - def extract_chapter_content(self, chapter: Chapter) -> tuple[str, str]: - """Extract and separate prose content from metadata for a chapter. - - Returns: - tuple[str, str]: (prose_content, metadata_content) - - """ - prose_parts = [] - metadata_parts = [] - - if chapter.description: - prose_parts.append(f"Description: {chapter.description}") - - if chapter.summary: - prose_parts.append(f"Summary: {chapter.summary}") - - if hasattr(chapter, "owasp_repository") and chapter.owasp_repository: - repo = chapter.owasp_repository - if repo.description: - prose_parts.append(f"Repository Description: {repo.description}") - if repo.topics: - metadata_parts.append(f"Repository Topics: {', '.join(repo.topics)}") - - if chapter.name: - metadata_parts.append(f"Chapter Name: {chapter.name}") - - location_parts = [] - if chapter.country: - location_parts.append(f"Country: {chapter.country}") - if chapter.region: - location_parts.append(f"Region: {chapter.region}") - if chapter.postal_code: - location_parts.append(f"Postal Code: {chapter.postal_code}") - if chapter.suggested_location: - location_parts.append(f"Location: {chapter.suggested_location}") - - if location_parts: - metadata_parts.append(f"Location Information: {', '.join(location_parts)}") - - if chapter.currency: - metadata_parts.append(f"Currency: {chapter.currency}") - - if chapter.meetup_group: - metadata_parts.append(f"Meetup Group: {chapter.meetup_group}") - - if chapter.tags: - metadata_parts.append(f"Tags: {', '.join(chapter.tags)}") - - if chapter.topics: - metadata_parts.append(f"Topics: {', '.join(chapter.topics)}") - - if chapter.leaders_raw: - metadata_parts.append(f"Chapter Leaders: {', '.join(chapter.leaders_raw)}") - - if chapter.related_urls: - valid_urls = [ - url - for url in chapter.related_urls - if url and url not in (chapter.invalid_urls or []) - ] - if valid_urls: - metadata_parts.append(f"Related URLs: {', '.join(valid_urls)}") - - metadata_parts.append(f"Active Chapter: {'Yes' if chapter.is_active else 'No'}") - - return ( - DELIMITER.join(filter(None, prose_parts)), - DELIMITER.join(filter(None, metadata_parts)), - ) diff --git a/backend/apps/ai/management/commands/ai_create_chapter_context.py b/backend/apps/ai/management/commands/ai_create_chapter_context.py new file mode 100644 index 0000000000..1d5a37f434 --- /dev/null +++ b/backend/apps/ai/management/commands/ai_create_chapter_context.py @@ -0,0 +1,77 @@ +"""A command to update context for OWASP chapter data.""" + +from django.core.management.base import BaseCommand + +from apps.ai.common.extractors import extract_chapter_content +from apps.ai.common.utils import create_context +from apps.owasp.models.chapter import Chapter + + +class Command(BaseCommand): + help = "Update context for OWASP chapter data" + + def add_arguments(self, parser): + parser.add_argument( + "--chapter-key", + type=str, + help="Process only the chapter with this key", + ) + parser.add_argument( + "--all", + action="store_true", + help="Process all the chapters", + ) + parser.add_argument( + "--batch-size", + type=int, + default=50, + help="Number of chapters to process in each batch", + ) + + def handle(self, *args, **options): + if options["chapter_key"]: + queryset = Chapter.objects.filter(key=options["chapter_key"]) + elif options["all"]: + queryset = Chapter.objects.all() + else: + queryset = Chapter.objects.filter(is_active=True) + + if not (total_chapters := queryset.count()): + self.stdout.write("No chapters found to process") + return + + self.stdout.write(f"Found {total_chapters} chapters to process") + + batch_size = options["batch_size"] + processed_count = 0 + + for offset in range(0, total_chapters, batch_size): + batch_chapters = queryset[offset : offset + batch_size] + processed_count += self.process_context_batch(batch_chapters) + + self.stdout.write( + self.style.SUCCESS(f"Completed processing {processed_count}/{total_chapters} chapters") + ) + + def process_context_batch(self, chapters: list[Chapter]) -> int: + """Process a batch of chapters to create contexts.""" + processed = 0 + + for chapter in chapters: + prose_content, metadata_content = extract_chapter_content(chapter) + full_content = ( + f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content + ) + + if not full_content.strip(): + self.stdout.write(f"No content for chapter {chapter.key}") + continue + + if create_context( + content=full_content, content_object=chapter, source="owasp_chapter" + ): + processed += 1 + self.stdout.write(f"Created context for {chapter.key}") + else: + self.stdout.write(self.style.ERROR(f"Failed to create context for {chapter.key}")) + return processed diff --git a/backend/apps/ai/management/commands/ai_create_committee_chunks.py b/backend/apps/ai/management/commands/ai_create_committee_chunks.py index 3efcd3b0fa..268e717910 100644 --- a/backend/apps/ai/management/commands/ai_create_committee_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_committee_chunks.py @@ -6,8 +6,8 @@ from django.contrib.contenttypes.models import ContentType from django.core.management.base import BaseCommand -from apps.ai.common.constants import DELIMITER -from apps.ai.common.utils import create_chunks_and_embeddings, create_context +from apps.ai.common.extractors import extract_committee_content +from apps.ai.common.utils import create_chunks_and_embeddings from apps.ai.models.chunk import Chunk from apps.ai.models.context import Context from apps.owasp.models.committee import Committee @@ -33,32 +33,15 @@ def add_arguments(self, parser): default=50, help="Number of committees to process in each batch", ) - parser.add_argument( - "--context", - action="store_true", - help="Create only context (skip chunks and embeddings)", - ) - parser.add_argument( - "--chunks", - action="store_true", - help="Create only chunks+embeddings (requires existing context)", - ) def handle(self, *args, **options): - if not options["context"] and not options["chunks"]: - self.stdout.write( - self.style.ERROR("Please specify either --context or --chunks (or both)") - ) - return - - if options["chunks"] and not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): + if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): self.stdout.write( self.style.ERROR("DJANGO_OPEN_AI_SECRET_KEY environment variable not set") ) return - if options["chunks"]: - self.openai_client = openai.OpenAI(api_key=openai_api_key) + self.openai_client = openai.OpenAI(api_key=openai_api_key) if options["committee_key"]: queryset = Committee.objects.filter(key=options["committee_key"]) @@ -78,11 +61,7 @@ def handle(self, *args, **options): for offset in range(0, total_committees, batch_size): batch_committees = queryset[offset : offset + batch_size] - - if options["context"]: - processed_count += self.process_context_batch(batch_committees) - elif options["chunks"]: - processed_count += self.process_chunks_batch(batch_committees) + processed_count += self.process_chunks_batch(batch_committees) self.stdout.write( self.style.SUCCESS( @@ -90,31 +69,6 @@ def handle(self, *args, **options): ) ) - def process_context_batch(self, committees: list[Committee]) -> int: - """Process a batch of committees to create contexts.""" - processed = 0 - - for committee in committees: - prose_content, metadata_content = self.extract_committee_content(committee) - full_content = ( - f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content - ) - - if not full_content.strip(): - self.stdout.write(f"No content for committee {committee.key}") - continue - - if create_context( - content=full_content, content_object=committee, source="owasp_committee" - ): - processed += 1 - self.stdout.write(f"Created context for {committee.key}") - else: - self.stdout.write( - self.style.ERROR(f"Failed to create context for {committee.key}") - ) - return processed - def process_chunks_batch(self, committees: list[Committee]) -> int: """Process a batch of committees to create chunks.""" processed = 0 @@ -133,7 +87,7 @@ def process_chunks_batch(self, committees: list[Committee]) -> int: ) continue - prose_content, metadata_content = self.extract_committee_content(committee) + prose_content, metadata_content = extract_committee_content(committee) all_chunk_texts = [] if metadata_content.strip(): @@ -160,49 +114,3 @@ def process_chunks_batch(self, committees: list[Committee]) -> int: if batch_chunks: Chunk.bulk_save(batch_chunks) return processed - - def extract_committee_content(self, committee: Committee) -> tuple[str, str]: - """Extract structured content from committee data.""" - prose_parts = [] - metadata_parts = [] - - if committee.description: - prose_parts.append(f"Description: {committee.description}") - - if committee.summary: - prose_parts.append(f"Summary: {committee.summary}") - - if hasattr(committee, "owasp_repository") and committee.owasp_repository: - repo = committee.owasp_repository - if repo.description: - prose_parts.append(f"Repository Description: {repo.description}") - if repo.topics: - metadata_parts.append(f"Repository Topics: {', '.join(repo.topics)}") - - if committee.name: - metadata_parts.append(f"Committee Name: {committee.name}") - - if committee.tags: - metadata_parts.append(f"Tags: {', '.join(committee.tags)}") - - if committee.topics: - metadata_parts.append(f"Topics: {', '.join(committee.topics)}") - - if committee.leaders_raw: - metadata_parts.append(f"Committee Leaders: {', '.join(committee.leaders_raw)}") - - if committee.related_urls: - valid_urls = [ - url - for url in committee.related_urls - if url and url not in (committee.invalid_urls or []) - ] - if valid_urls: - metadata_parts.append(f"Related URLs: {', '.join(valid_urls)}") - - metadata_parts.append(f"Active Committee: {'Yes' if committee.is_active else 'No'}") - - return ( - DELIMITER.join(filter(None, prose_parts)), - DELIMITER.join(filter(None, metadata_parts)), - ) diff --git a/backend/apps/ai/management/commands/ai_create_committee_context.py b/backend/apps/ai/management/commands/ai_create_committee_context.py new file mode 100644 index 0000000000..2802846b74 --- /dev/null +++ b/backend/apps/ai/management/commands/ai_create_committee_context.py @@ -0,0 +1,81 @@ +"""A command to update context for OWASP committee data.""" + +from django.core.management.base import BaseCommand + +from apps.ai.common.extractors import extract_committee_content +from apps.ai.common.utils import create_context +from apps.owasp.models.committee import Committee + + +class Command(BaseCommand): + help = "Update context for OWASP committee data" + + def add_arguments(self, parser): + parser.add_argument( + "--committee-key", + type=str, + help="Process only the committee with this key", + ) + parser.add_argument( + "--all", + action="store_true", + help="Process all the committees", + ) + parser.add_argument( + "--batch-size", + type=int, + default=50, + help="Number of committees to process in each batch", + ) + + def handle(self, *args, **options): + if options["committee_key"]: + queryset = Committee.objects.filter(key=options["committee_key"]) + elif options["all"]: + queryset = Committee.objects.all() + else: + queryset = Committee.objects.filter(is_active=True) + + if not (total_committees := queryset.count()): + self.stdout.write("No committees found to process") + return + + self.stdout.write(f"Found {total_committees} committees to process") + + batch_size = options["batch_size"] + processed_count = 0 + + for offset in range(0, total_committees, batch_size): + batch_committees = queryset[offset : offset + batch_size] + processed_count += self.process_context_batch(batch_committees) + + self.stdout.write( + self.style.SUCCESS( + f"Completed processing {processed_count}/{total_committees} committees" + ) + ) + + def process_context_batch(self, committees: list[Committee]) -> int: + """Process a batch of committees to create contexts.""" + processed = 0 + + for committee in committees: + prose_content, metadata_content = extract_committee_content(committee) + full_content = ( + f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content + ) + + if not full_content.strip(): + self.stdout.write(f"No content for committee {committee.key}") + continue + + if create_context( + content=full_content, content_object=committee, source="owasp_committee" + ): + processed += 1 + self.stdout.write(f"Created context for {committee.key}") + else: + self.stdout.write( + self.style.ERROR(f"Failed to create context for {committee.key}") + ) + return processed diff --git a/backend/apps/ai/management/commands/ai_create_event_chunks.py b/backend/apps/ai/management/commands/ai_create_event_chunks.py index 4a3def233b..40569a6532 100644 --- a/backend/apps/ai/management/commands/ai_create_event_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_event_chunks.py @@ -6,8 +6,8 @@ from django.contrib.contenttypes.models import ContentType from django.core.management.base import BaseCommand -from apps.ai.common.constants import DELIMITER -from apps.ai.common.utils import create_chunks_and_embeddings, create_context +from apps.ai.common.extractors import extract_event_content +from apps.ai.common.utils import create_chunks_and_embeddings from apps.ai.models.chunk import Chunk from apps.ai.models.context import Context from apps.owasp.models.event import Event @@ -33,32 +33,15 @@ def add_arguments(self, parser): default=50, help="Number of events to process in each batch", ) - parser.add_argument( - "--context", - action="store_true", - help="Create only context (skip chunks and embeddings)", - ) - parser.add_argument( - "--chunks", - action="store_true", - help="Create only chunks+embeddings (requires existing context)", - ) def handle(self, *args, **options): - if not options["context"] and not options["chunks"]: - self.stdout.write( - self.style.ERROR("Please specify either --context or --chunks (or both)") - ) - return - - if options["chunks"] and not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): + if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): self.stdout.write( self.style.ERROR("DJANGO_OPEN_AI_SECRET_KEY environment variable not set") ) return - if options["chunks"]: - self.openai_client = openai.OpenAI(api_key=openai_api_key) + self.openai_client = openai.OpenAI(api_key=openai_api_key) if options["event_key"]: queryset = Event.objects.filter(key=options["event_key"]) @@ -78,37 +61,12 @@ def handle(self, *args, **options): for offset in range(0, total_events, batch_size): batch_events = queryset[offset : offset + batch_size] - - if options["context"]: - processed_count += self.process_context_batch(batch_events) - elif options["chunks"]: - processed_count += self.process_chunks_batch(batch_events) + processed_count += self.process_chunks_batch(batch_events) self.stdout.write( self.style.SUCCESS(f"Completed processing {processed_count}/{total_events} events") ) - def process_context_batch(self, events: list[Event]) -> int: - """Process a batch of events to create contexts.""" - processed = 0 - - for event in events: - prose_content, metadata_content = self.extract_event_content(event) - full_content = ( - f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content - ) - - if not full_content.strip(): - self.stdout.write(f"No content for event {event.key}") - continue - - if create_context(content=full_content, content_object=event, source="owasp_event"): - processed += 1 - self.stdout.write(f"Created context for {event.key}") - else: - self.stdout.write(self.style.ERROR(f"Failed to create context for {event.key}")) - return processed - def process_chunks_batch(self, events: list[Event]) -> int: """Process a batch of events to create chunks.""" processed = 0 @@ -125,7 +83,7 @@ def process_chunks_batch(self, events: list[Event]) -> int: self.stdout.write(self.style.WARNING(f"No context found for event {event.key}")) continue - prose_content, metadata_content = self.extract_event_content(event) + prose_content, metadata_content = extract_event_content(event) all_chunk_texts = [] if metadata_content.strip(): @@ -152,45 +110,3 @@ def process_chunks_batch(self, events: list[Event]) -> int: if batch_chunks: Chunk.bulk_save(batch_chunks) return processed - - def extract_event_content(self, event: Event) -> tuple[str, str]: - """Extract and separate prose content from metadata for an event. - - Returns: - tuple[str, str]: (prose_content, metadata_content) - - """ - prose_parts = [] - metadata_parts = [] - - if event.description: - prose_parts.append(f"Description: {event.description}") - - if event.summary: - prose_parts.append(f"Summary: {event.summary}") - - if event.name: - metadata_parts.append(f"Event Name: {event.name}") - - if event.category: - metadata_parts.append(f"Category: {event.get_category_display()}") - - if event.start_date: - metadata_parts.append(f"Start Date: {event.start_date}") - - if event.end_date: - metadata_parts.append(f"End Date: {event.end_date}") - - if event.suggested_location: - metadata_parts.append(f"Location: {event.suggested_location}") - - if event.latitude and event.longitude: - metadata_parts.append(f"Coordinates: {event.latitude}, {event.longitude}") - - if event.url: - metadata_parts.append(f"Event URL: {event.url}") - - return ( - DELIMITER.join(filter(None, prose_parts)), - DELIMITER.join(filter(None, metadata_parts)), - ) diff --git a/backend/apps/ai/management/commands/ai_create_event_context.py b/backend/apps/ai/management/commands/ai_create_event_context.py new file mode 100644 index 0000000000..647f30514b --- /dev/null +++ b/backend/apps/ai/management/commands/ai_create_event_context.py @@ -0,0 +1,75 @@ +"""A command to update context for OWASP event data.""" + +from django.core.management.base import BaseCommand + +from apps.ai.common.extractors import extract_event_content +from apps.ai.common.utils import create_context +from apps.owasp.models.event import Event + + +class Command(BaseCommand): + help = "Update context for OWASP event data" + + def add_arguments(self, parser): + parser.add_argument( + "--event-key", + type=str, + help="Process only the event with this key", + ) + parser.add_argument( + "--all", + action="store_true", + help="Process all the events", + ) + parser.add_argument( + "--batch-size", + type=int, + default=50, + help="Number of events to process in each batch", + ) + + def handle(self, *args, **options): + if options["event_key"]: + queryset = Event.objects.filter(key=options["event_key"]) + elif options["all"]: + queryset = Event.objects.all() + else: + queryset = Event.upcoming_events() + + if not (total_events := queryset.count()): + self.stdout.write("No events found to process") + return + + self.stdout.write(f"Found {total_events} events to process") + + batch_size = options["batch_size"] + processed_count = 0 + + for offset in range(0, total_events, batch_size): + batch_events = queryset[offset : offset + batch_size] + processed_count += self.process_context_batch(batch_events) + + self.stdout.write( + self.style.SUCCESS(f"Completed processing {processed_count}/{total_events} events") + ) + + def process_context_batch(self, events: list[Event]) -> int: + """Process a batch of events to create contexts.""" + processed = 0 + + for event in events: + prose_content, metadata_content = extract_event_content(event) + full_content = ( + f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content + ) + + if not full_content.strip(): + self.stdout.write(f"No content for event {event.key}") + continue + + if create_context(content=full_content, content_object=event, source="owasp_event"): + processed += 1 + self.stdout.write(f"Created context for {event.key}") + else: + self.stdout.write(self.style.ERROR(f"Failed to create context for {event.key}")) + return processed diff --git a/backend/apps/ai/management/commands/ai_create_project_chunks.py b/backend/apps/ai/management/commands/ai_create_project_chunks.py index 62fb62ab58..ec1ec03bf2 100644 --- a/backend/apps/ai/management/commands/ai_create_project_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_project_chunks.py @@ -6,8 +6,8 @@ from django.contrib.contenttypes.models import ContentType from django.core.management.base import BaseCommand -from apps.ai.common.constants import DELIMITER -from apps.ai.common.utils import create_chunks_and_embeddings, create_context +from apps.ai.common.extractors import extract_project_content +from apps.ai.common.utils import create_chunks_and_embeddings from apps.ai.models.chunk import Chunk from apps.ai.models.context import Context from apps.owasp.models.project import Project @@ -18,39 +18,30 @@ class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument( - "--project-key", type=str, help="Process only the project with this key" + "--project-key", + type=str, + help="Process only the project with this key", + ) + parser.add_argument( + "--all", + action="store_true", + help="Process all the projects", ) - parser.add_argument("--all", action="store_true", help="Process all the projects") parser.add_argument( "--batch-size", type=int, default=50, help="Number of projects to process in each batch", ) - parser.add_argument( - "--context", - action="store_true", - help="Create only context (skip chunks and embeddings)", - ) - parser.add_argument( - "--chunks", - action="store_true", - help="Create only chunks+embeddings (requires existing context)", - ) def handle(self, *args, **options): - if not options["context"] and not options["chunks"]: - self.stdout.write(self.style.ERROR("Must specify either --context or --chunks")) - return - - if options["chunks"] and not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): + if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): self.stdout.write( self.style.ERROR("DJANGO_OPEN_AI_SECRET_KEY environment variable not set") ) return - if options["chunks"]: - self.openai_client = openai.OpenAI(api_key=openai_api_key) + self.openai_client = openai.OpenAI(api_key=openai_api_key) if options["project_key"]: queryset = Project.objects.filter(key=options["project_key"]) @@ -70,39 +61,12 @@ def handle(self, *args, **options): for offset in range(0, total_projects, batch_size): batch_projects = queryset[offset : offset + batch_size] - - if options["context"]: - processed_count += self.process_context_batch(batch_projects) - elif options["chunks"]: - processed_count += self.process_chunks_batch(batch_projects) + processed_count += self.process_chunks_batch(batch_projects) self.stdout.write( self.style.SUCCESS(f"Completed processing {processed_count}/{total_projects} projects") ) - def process_context_batch(self, projects: list[Project]) -> int: - """Process a batch of projects to create contexts.""" - processed = 0 - - for project in projects: - prose_content, metadata_content = self.extract_project_content(project) - full_content = ( - f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content - ) - - if not full_content.strip(): - self.stdout.write(f"No content for project {project.key}") - continue - - if create_context( - content=full_content, content_object=project, source="owasp_project" - ): - processed += 1 - self.stdout.write(f"Created context for {project.key}") - else: - self.stdout.write(self.style.ERROR(f"Failed to create context for {project.key}")) - return processed - def process_chunks_batch(self, projects: list[Project]) -> int: """Process a batch of projects to create chunks.""" processed = 0 @@ -121,7 +85,7 @@ def process_chunks_batch(self, projects: list[Project]) -> int: ) continue - prose_content, metadata_content = self.extract_project_content(project) + prose_content, metadata_content = extract_project_content(project) all_chunk_texts = [] if metadata_content.strip(): @@ -148,94 +112,3 @@ def process_chunks_batch(self, projects: list[Project]) -> int: if batch_chunks: Chunk.bulk_save(batch_chunks) return processed - - def extract_project_content(self, project: Project) -> tuple[str, str]: - prose_parts = [] - metadata_parts = [] - - if project.name: - metadata_parts.append(f"Project Name: {project.name}") - - if project.description: - prose_parts.append(f"Description: {project.description}") - - if project.summary: - prose_parts.append(f"Summary: {project.summary}") - - if project.level: - metadata_parts.append(f"Project Level: {project.level}") - - if project.type: - metadata_parts.append(f"Project Type: {project.type}") - - if hasattr(project, "owasp_repository") and project.owasp_repository: - repo = project.owasp_repository - if repo.description: - prose_parts.append(f"Repository Description: {repo.description}") - if repo.topics: - metadata_parts.append(f"Repository Topics: {', '.join(repo.topics)}") - - if project.languages: - metadata_parts.append(f"Programming Languages: {', '.join(project.languages)}") - - if project.topics: - metadata_parts.append(f"Topics: {', '.join(project.topics)}") - - if project.licenses: - metadata_parts.append(f"Licenses: {', '.join(project.licenses)}") - - if project.tags: - metadata_parts.append(f"Tags: {', '.join(project.tags)}") - - if project.custom_tags: - metadata_parts.append(f"Custom Tags: {', '.join(project.custom_tags)}") - - stats_parts = [] - if project.stars_count > 0: - stats_parts.append(f"Stars: {project.stars_count}") - if project.forks_count > 0: - stats_parts.append(f"Forks: {project.forks_count}") - if project.contributors_count > 0: - stats_parts.append(f"Contributors: {project.contributors_count}") - if project.releases_count > 0: - stats_parts.append(f"Releases: {project.releases_count}") - if project.open_issues_count > 0: - stats_parts.append(f"Open Issues: {project.open_issues_count}") - - if stats_parts: - metadata_parts.append("Project Statistics: " + ", ".join(stats_parts)) - - if project.leaders_raw: - metadata_parts.append(f"Project Leaders: {', '.join(project.leaders_raw)}") - - if project.related_urls: - valid_urls = [ - url - for url in project.related_urls - if url and url not in (project.invalid_urls or []) - ] - if valid_urls: - metadata_parts.append(f"Related URLs: {', '.join(valid_urls)}") - - if project.created_at: - metadata_parts.append(f"Created: {project.created_at.strftime('%Y-%m-%d')}") - - if project.updated_at: - metadata_parts.append(f"Last Updated: {project.updated_at.strftime('%Y-%m-%d')}") - - if project.released_at: - metadata_parts.append(f"Last Release: {project.released_at.strftime('%Y-%m-%d')}") - - if project.health_score is not None: - metadata_parts.append(f"Health Score: {project.health_score:.2f}") - - metadata_parts.append(f"Active Project: {'Yes' if project.is_active else 'No'}") - - metadata_parts.append( - f"Issue Tracking: {'Enabled' if project.track_issues else 'Disabled'}" - ) - - return ( - DELIMITER.join(filter(None, prose_parts)), - DELIMITER.join(filter(None, metadata_parts)), - ) diff --git a/backend/apps/ai/management/commands/ai_create_project_context.py b/backend/apps/ai/management/commands/ai_create_project_context.py new file mode 100644 index 0000000000..cfb3ba259e --- /dev/null +++ b/backend/apps/ai/management/commands/ai_create_project_context.py @@ -0,0 +1,77 @@ +"""A command to update context for OWASP project data.""" + +from django.core.management.base import BaseCommand + +from apps.ai.common.extractors import extract_project_content +from apps.ai.common.utils import create_context +from apps.owasp.models.project import Project + + +class Command(BaseCommand): + help = "Update context for OWASP project data" + + def add_arguments(self, parser): + parser.add_argument( + "--project-key", + type=str, + help="Process only the project with this key", + ) + parser.add_argument( + "--all", + action="store_true", + help="Process all the projects", + ) + parser.add_argument( + "--batch-size", + type=int, + default=50, + help="Number of projects to process in each batch", + ) + + def handle(self, *args, **options): + if options["project_key"]: + queryset = Project.objects.filter(key=options["project_key"]) + elif options["all"]: + queryset = Project.objects.all() + else: + queryset = Project.objects.filter(is_active=True) + + if not (total_projects := queryset.count()): + self.stdout.write("No projects found to process") + return + + self.stdout.write(f"Found {total_projects} projects to process") + + batch_size = options["batch_size"] + processed_count = 0 + + for offset in range(0, total_projects, batch_size): + batch_projects = queryset[offset : offset + batch_size] + processed_count += self.process_context_batch(batch_projects) + + self.stdout.write( + self.style.SUCCESS(f"Completed processing {processed_count}/{total_projects} projects") + ) + + def process_context_batch(self, projects: list[Project]) -> int: + """Process a batch of projects to create contexts.""" + processed = 0 + + for project in projects: + prose_content, metadata_content = extract_project_content(project) + full_content = ( + f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content + ) + + if not full_content.strip(): + self.stdout.write(f"No content for project {project.key}") + continue + + if create_context( + content=full_content, content_object=project, source="owasp_project" + ): + processed += 1 + self.stdout.write(f"Created context for {project.key}") + else: + self.stdout.write(self.style.ERROR(f"Failed to create context for {project.key}")) + return processed diff --git a/backend/apps/ai/models/chunk.py b/backend/apps/ai/models/chunk.py index e3144b675a..c216a8a7ec 100644 --- a/backend/apps/ai/models/chunk.py +++ b/backend/apps/ai/models/chunk.py @@ -44,29 +44,27 @@ def split_text(text: str) -> list[str]: @staticmethod def update_data( text: str, - context: Context, embedding, *, save: bool = True, - ) -> "Chunk | None": + ) -> "Chunk": """Update chunk data. Args: text (str): The text content of the chunk. - context (Context): The context this chunk belongs to. embedding (list): The embedding vector for the chunk. save (bool): Whether to save the chunk to the database. Returns: - Chunk: The updated chunk instance or None if it already exists. + Chunk: The created chunk instance (without context assigned). """ - if Chunk.objects.filter(context=context, text=text).exists(): - return None - - chunk = Chunk(context=context, text=text, embedding=embedding) + chunk = Chunk(text=text, embedding=embedding) if save: + if chunk.context_id is None: + error_msg = "Chunk must have a context assigned before saving." + raise ValueError(error_msg) chunk.save() return chunk diff --git a/backend/tests/apps/ai/common/utils_test.py b/backend/tests/apps/ai/common/utils_test.py index a1d586b3bb..c56068a375 100644 --- a/backend/tests/apps/ai/common/utils_test.py +++ b/backend/tests/apps/ai/common/utils_test.py @@ -23,7 +23,6 @@ def test_create_chunks_and_embeddings_success( base_time = datetime.now(UTC) mock_datetime.now.return_value = base_time mock_datetime.UTC = UTC - mock_datetime.timedelta = timedelta mock_openai_client = MagicMock() @@ -34,7 +33,10 @@ def test_create_chunks_and_embeddings_success( ] mock_openai_client.embeddings.create.return_value = mock_api_response - mock_update_data.return_value = "mock_chunk_instance" + # Create mock chunk instances with .save method + mock_chunk1 = MagicMock() + mock_chunk2 = MagicMock() + mock_update_data.side_effect = [mock_chunk1, mock_chunk2] all_chunk_texts = ["first chunk", "second chunk"] mock_content_object = MagicMock() @@ -52,22 +54,18 @@ def test_create_chunks_and_embeddings_success( mock_update_data.assert_has_calls( [ - call( - text="first chunk", - context=mock_content_object, - embedding=[0.1, 0.2], - save=True, - ), - call( - text="second chunk", - context=mock_content_object, - embedding=[0.3, 0.4], - save=True, - ), + call(text="first chunk", embedding=[0.1, 0.2], save=False), + call(text="second chunk", embedding=[0.3, 0.4], save=False), ] ) - assert result == ["mock_chunk_instance", "mock_chunk_instance"] + assert mock_chunk1.context == mock_content_object + assert mock_chunk2.context == mock_content_object + + mock_chunk1.save.assert_called_once() + mock_chunk2.save.assert_called_once() + + assert result == [mock_chunk1, mock_chunk2] mock_sleep.assert_not_called() @@ -108,7 +106,8 @@ def test_create_chunks_and_embeddings_no_sleep_with_current_settings( mock_api_response.data = [MockEmbeddingData([0.1, 0.2])] mock_openai_client.embeddings.create.return_value = mock_api_response - mock_update_data.return_value = "mock_chunk_instance" + mock_chunk = MagicMock() + mock_update_data.return_value = mock_chunk result = create_chunks_and_embeddings( ["test chunk"], @@ -117,5 +116,5 @@ def test_create_chunks_and_embeddings_no_sleep_with_current_settings( ) mock_sleep.assert_not_called() - - assert result == ["mock_chunk_instance"] + mock_chunk.save.assert_called_once() + assert result == [mock_chunk] diff --git a/backend/tests/apps/ai/management/commands/ai_create_chapter_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_create_chapter_chunks_test.py new file mode 100644 index 0000000000..018c022ac1 --- /dev/null +++ b/backend/tests/apps/ai/management/commands/ai_create_chapter_chunks_test.py @@ -0,0 +1,237 @@ +"""Tests for the ai_create_chapter_chunks Django management command.""" + +import os +from unittest.mock import MagicMock, Mock, patch + +import pytest +from django.core.management.base import BaseCommand + +from apps.ai.management.commands.ai_create_chapter_chunks import Command + + +@pytest.fixture +def command(): + """Return a command instance.""" + return Command() + + +@pytest.fixture +def mock_chapter(): + """Return a mock Chapter instance.""" + chapter = Mock() + chapter.id = 1 + chapter.key = "test-chapter" + return chapter + + +@pytest.fixture +def mock_context(): + """Return a mock Context instance.""" + context = Mock() + context.id = 1 + return context + + +class TestAiCreateChapterChunksCommand: + """Test suite for the ai_create_chapter_chunks command.""" + + def test_command_help_text(self, command): + """Test that the command has the correct help text.""" + assert command.help == "Create chunks for OWASP chapter data" + + def test_command_inheritance(self, command): + """Test that the command inherits from BaseCommand.""" + assert isinstance(command, BaseCommand) + + def test_add_arguments(self, command): + """Test that the command adds the correct arguments.""" + parser = MagicMock() + command.add_arguments(parser) + + assert parser.add_argument.call_count == 3 + parser.add_argument.assert_any_call( + "--chapter-key", + type=str, + help="Process only the chapter with this key", + ) + parser.add_argument.assert_any_call( + "--all", + action="store_true", + help="Process all the chapters", + ) + parser.add_argument.assert_any_call( + "--batch-size", + type=int, + default=50, + help="Number of chapters to process in each batch", + ) + + @patch.dict(os.environ, {}, clear=True) + def test_handle_missing_openai_key(self, command): + """Test command fails when OpenAI API key is not set.""" + command.stdout = MagicMock() + command.style = MagicMock() + + command.handle() + + command.stdout.write.assert_called_once() + command.style.ERROR.assert_called_once_with( + "DJANGO_OPEN_AI_SECRET_KEY environment variable not set" + ) + + @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) + @patch("apps.ai.management.commands.ai_create_chapter_chunks.openai.OpenAI") + @patch("apps.ai.management.commands.ai_create_chapter_chunks.Chapter.objects") + def test_handle_no_chapters_found(self, mock_chapter_objects, mock_openai, command): + """Test command when no chapters are found.""" + command.stdout = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 0 + mock_chapter_objects.filter.return_value = mock_queryset + + command.handle(chapter_key=None, all=False, batch_size=50) + + command.stdout.write.assert_called_with("No chapters found to process") + + @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) + @patch("apps.ai.management.commands.ai_create_chapter_chunks.openai.OpenAI") + @patch("apps.ai.management.commands.ai_create_chapter_chunks.Chapter.objects") + def test_handle_with_chapter_key( + self, mock_chapter_objects, mock_openai, command, mock_chapter + ): + """Test command with specific chapter key.""" + command.stdout = MagicMock() + command.style = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 1 + mock_queryset.__iter__ = lambda _self: iter([mock_chapter]) + mock_queryset.__getitem__ = lambda _self, _key: [mock_chapter] + mock_chapter_objects.filter.return_value = mock_queryset + + with patch.object(command, "process_chunks_batch", return_value=1): + command.handle(chapter_key="test-chapter", all=False, batch_size=50) + + mock_chapter_objects.filter.assert_called_with(key="test-chapter") + + @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) + @patch("apps.ai.management.commands.ai_create_chapter_chunks.openai.OpenAI") + @patch("apps.ai.management.commands.ai_create_chapter_chunks.Chapter.objects") + def test_handle_with_all_flag(self, mock_chapter_objects, mock_openai, command, mock_chapter): + """Test command with --all flag.""" + command.stdout = MagicMock() + command.style = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 1 + mock_queryset.__iter__ = lambda _self: iter([mock_chapter]) + mock_queryset.__getitem__ = lambda _self, _key: [mock_chapter] + mock_chapter_objects.all.return_value = mock_queryset + + with patch.object(command, "process_chunks_batch", return_value=1): + command.handle(chapter_key=None, all=True, batch_size=50) + + mock_chapter_objects.all.assert_called_once() + + @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) + @patch("apps.ai.management.commands.ai_create_chapter_chunks.openai.OpenAI") + @patch("apps.ai.management.commands.ai_create_chapter_chunks.Chapter.objects") + def test_handle_default_active_chapters( + self, mock_chapter_objects, mock_openai, command, mock_chapter + ): + """Test command defaults to active chapters.""" + command.stdout = MagicMock() + command.style = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 1 + mock_queryset.__iter__ = lambda _self: iter([mock_chapter]) + mock_queryset.__getitem__ = lambda _self, _key: [mock_chapter] + mock_chapter_objects.filter.return_value = mock_queryset + + with patch.object(command, "process_chunks_batch", return_value=1): + command.handle(chapter_key=None, all=False, batch_size=50) + + mock_chapter_objects.filter.assert_called_with(is_active=True) + + @patch("apps.ai.management.commands.ai_create_chapter_chunks.ContentType.objects") + @patch("apps.ai.management.commands.ai_create_chapter_chunks.Context.objects") + @patch("apps.ai.management.commands.ai_create_chapter_chunks.extract_chapter_content") + @patch("apps.ai.management.commands.ai_create_chapter_chunks.Chunk.split_text") + @patch("apps.ai.management.commands.ai_create_chapter_chunks.create_chunks_and_embeddings") + @patch("apps.ai.management.commands.ai_create_chapter_chunks.Chunk.bulk_save") + def test_process_chunks_batch_success( + self, + mock_bulk_save, + mock_create_chunks, + mock_split_text, + mock_extract, + mock_context_objects, + mock_content_type, + command, + mock_chapter, + mock_context, + ): + """Test successful batch processing of chunks.""" + command.stdout = MagicMock() + command.openai_client = MagicMock() + + # Setup mocks + mock_content_type.get_for_model.return_value = MagicMock() + mock_context_objects.filter.return_value.first.return_value = mock_context + mock_extract.return_value = ("prose content", "metadata content") + mock_split_text.return_value = ["chunk1", "chunk2"] + mock_chunks = [Mock(), Mock()] + mock_create_chunks.return_value = mock_chunks + + result = command.process_chunks_batch([mock_chapter]) + + assert result == 1 + mock_extract.assert_called_once_with(mock_chapter) + mock_split_text.assert_called_once_with("prose content") + mock_create_chunks.assert_called_once() + mock_bulk_save.assert_called_once_with(mock_chunks) + + @patch("apps.ai.management.commands.ai_create_chapter_chunks.ContentType.objects") + @patch("apps.ai.management.commands.ai_create_chapter_chunks.Context.objects") + def test_process_chunks_batch_no_context( + self, + mock_context_objects, + mock_content_type, + command, + mock_chapter, + ): + """Test batch processing when no context is found.""" + command.stdout = MagicMock() + command.style = MagicMock() + + # Setup mocks + mock_content_type.get_for_model.return_value = MagicMock() + mock_context_objects.filter.return_value.first.return_value = None + + result = command.process_chunks_batch([mock_chapter]) + + assert result == 0 + command.style.WARNING.assert_called_once() + + @patch("apps.ai.management.commands.ai_create_chapter_chunks.ContentType.objects") + @patch("apps.ai.management.commands.ai_create_chapter_chunks.Context.objects") + @patch("apps.ai.management.commands.ai_create_chapter_chunks.extract_chapter_content") + def test_process_chunks_batch_no_content( + self, + mock_extract, + mock_context_objects, + mock_content_type, + command, + mock_chapter, + mock_context, + ): + """Test batch processing when no content is extracted.""" + command.stdout = MagicMock() + + # Setup mocks + mock_content_type.get_for_model.return_value = MagicMock() + mock_context_objects.filter.return_value.first.return_value = mock_context + mock_extract.return_value = ("", "") + + result = command.process_chunks_batch([mock_chapter]) + + assert result == 0 + command.stdout.write.assert_any_call(f"No content to chunk for chapter {mock_chapter.key}") diff --git a/backend/tests/apps/ai/management/commands/ai_create_chapter_context_test.py b/backend/tests/apps/ai/management/commands/ai_create_chapter_context_test.py new file mode 100644 index 0000000000..9951ac7ec6 --- /dev/null +++ b/backend/tests/apps/ai/management/commands/ai_create_chapter_context_test.py @@ -0,0 +1,210 @@ +"""Tests for the ai_create_chapter_context Django management command.""" + +from unittest.mock import MagicMock, Mock, patch + +import pytest +from django.core.management.base import BaseCommand + +from apps.ai.management.commands.ai_create_chapter_context import Command + + +@pytest.fixture +def command(): + """Return a command instance.""" + return Command() + + +@pytest.fixture +def mock_chapter(): + """Return a mock Chapter instance.""" + chapter = Mock() + chapter.id = 1 + chapter.key = "test-chapter" + return chapter + + +class TestAiCreateChapterContextCommand: + """Test suite for the ai_create_chapter_context command.""" + + def test_command_help_text(self, command): + """Test that the command has the correct help text.""" + assert command.help == "Update context for OWASP chapter data" + + def test_command_inheritance(self, command): + """Test that the command inherits from BaseCommand.""" + assert isinstance(command, BaseCommand) + + def test_add_arguments(self, command): + """Test that the command adds the correct arguments.""" + parser = MagicMock() + command.add_arguments(parser) + + assert parser.add_argument.call_count == 3 + parser.add_argument.assert_any_call( + "--chapter-key", + type=str, + help="Process only the chapter with this key", + ) + parser.add_argument.assert_any_call( + "--all", + action="store_true", + help="Process all the chapters", + ) + parser.add_argument.assert_any_call( + "--batch-size", + type=int, + default=50, + help="Number of chapters to process in each batch", + ) + + @patch("apps.ai.management.commands.ai_create_chapter_context.Chapter.objects") + def test_handle_no_chapters_found(self, mock_chapter_objects, command): + """Test command when no chapters are found.""" + command.stdout = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 0 + mock_chapter_objects.filter.return_value = mock_queryset + + command.handle(chapter_key=None, all=False, batch_size=50) + + command.stdout.write.assert_called_with("No chapters found to process") + + @patch("apps.ai.management.commands.ai_create_chapter_context.Chapter.objects") + def test_handle_with_chapter_key(self, mock_chapter_objects, command, mock_chapter): + """Test command with specific chapter key.""" + command.stdout = MagicMock() + command.style = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 1 + mock_queryset.__iter__ = lambda _self: iter([mock_chapter]) + mock_queryset.__getitem__ = lambda _self, _key: [mock_chapter] + mock_chapter_objects.filter.return_value = mock_queryset + + with patch.object(command, "process_context_batch", return_value=1): + command.handle(chapter_key="test-chapter", all=False, batch_size=50) + + mock_chapter_objects.filter.assert_called_with(key="test-chapter") + + @patch("apps.ai.management.commands.ai_create_chapter_context.Chapter.objects") + def test_handle_with_all_flag(self, mock_chapter_objects, command, mock_chapter): + """Test command with --all flag.""" + command.stdout = MagicMock() + command.style = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 1 + mock_queryset.__iter__ = lambda _self: iter([mock_chapter]) + mock_queryset.__getitem__ = lambda _self, _key: [mock_chapter] + mock_chapter_objects.all.return_value = mock_queryset + + with patch.object(command, "process_context_batch", return_value=1): + command.handle(chapter_key=None, all=True, batch_size=50) + + mock_chapter_objects.all.assert_called_once() + + @patch("apps.ai.management.commands.ai_create_chapter_context.Chapter.objects") + def test_handle_default_active_chapters(self, mock_chapter_objects, command, mock_chapter): + """Test command defaults to active chapters.""" + command.stdout = MagicMock() + command.style = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 1 + mock_queryset.__iter__ = lambda _self: iter([mock_chapter]) + mock_queryset.__getitem__ = lambda _self, _key: [mock_chapter] + mock_chapter_objects.filter.return_value = mock_queryset + + with patch.object(command, "process_context_batch", return_value=1): + command.handle(chapter_key=None, all=False, batch_size=50) + + mock_chapter_objects.filter.assert_called_with(is_active=True) + + @patch("apps.ai.management.commands.ai_create_chapter_context.extract_chapter_content") + @patch("apps.ai.management.commands.ai_create_chapter_context.create_context") + def test_process_context_batch_success( + self, + mock_create_context, + mock_extract, + command, + mock_chapter, + ): + """Test successful batch processing of contexts.""" + command.stdout = MagicMock() + + # Setup mocks + mock_extract.return_value = ("prose content", "metadata content") + mock_create_context.return_value = True + + result = command.process_context_batch([mock_chapter]) + + assert result == 1 + mock_extract.assert_called_once_with(mock_chapter) + mock_create_context.assert_called_once_with( + content="metadata content\n\nprose content", + content_object=mock_chapter, + source="owasp_chapter", + ) + + @patch("apps.ai.management.commands.ai_create_chapter_context.extract_chapter_content") + @patch("apps.ai.management.commands.ai_create_chapter_context.create_context") + def test_process_context_batch_no_metadata( + self, + mock_create_context, + mock_extract, + command, + mock_chapter, + ): + """Test batch processing without metadata content.""" + command.stdout = MagicMock() + + # Setup mocks + mock_extract.return_value = ("prose content", "") + mock_create_context.return_value = True + + result = command.process_context_batch([mock_chapter]) + + assert result == 1 + mock_extract.assert_called_once_with(mock_chapter) + mock_create_context.assert_called_once_with( + content="prose content", + content_object=mock_chapter, + source="owasp_chapter", + ) + + @patch("apps.ai.management.commands.ai_create_chapter_context.extract_chapter_content") + def test_process_context_batch_no_content( + self, + mock_extract, + command, + mock_chapter, + ): + """Test batch processing when no content is extracted.""" + command.stdout = MagicMock() + + # Setup mocks + mock_extract.return_value = ("", "") + + result = command.process_context_batch([mock_chapter]) + + assert result == 0 + command.stdout.write.assert_any_call(f"No content for chapter {mock_chapter.key}") + + @patch("apps.ai.management.commands.ai_create_chapter_context.extract_chapter_content") + @patch("apps.ai.management.commands.ai_create_chapter_context.create_context") + def test_process_context_batch_create_context_fails( + self, + mock_create_context, + mock_extract, + command, + mock_chapter, + ): + """Test batch processing when create_context fails.""" + command.stdout = MagicMock() + command.style = MagicMock() + + # Setup mocks + mock_extract.return_value = ("prose content", "metadata content") + mock_create_context.return_value = False + + result = command.process_context_batch([mock_chapter]) + + assert result == 0 + command.style.ERROR.assert_called_once() diff --git a/backend/tests/apps/ai/management/commands/ai_create_committee_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_create_committee_chunks_test.py new file mode 100644 index 0000000000..0368bb2b5d --- /dev/null +++ b/backend/tests/apps/ai/management/commands/ai_create_committee_chunks_test.py @@ -0,0 +1,154 @@ +"""Tests for the ai_create_committee_chunks Django management command.""" + +import os +from unittest.mock import MagicMock, Mock, patch + +import pytest +from django.core.management.base import BaseCommand + +from apps.ai.management.commands.ai_create_committee_chunks import Command + + +@pytest.fixture +def command(): + """Return a command instance.""" + return Command() + + +@pytest.fixture +def mock_committee(): + """Return a mock Committee instance.""" + committee = Mock() + committee.id = 1 + committee.key = "test-committee" + return committee + + +@pytest.fixture +def mock_context(): + """Return a mock Context instance.""" + context = Mock() + context.id = 1 + return context + + +class TestAiCreateCommitteeChunksCommand: + """Test suite for the ai_create_committee_chunks command.""" + + def test_command_help_text(self, command): + """Test that the command has the correct help text.""" + assert command.help == "Create chunks for OWASP committee data" + + def test_command_inheritance(self, command): + """Test that the command inherits from BaseCommand.""" + assert isinstance(command, BaseCommand) + + def test_add_arguments(self, command): + """Test that the command adds the correct arguments.""" + parser = MagicMock() + command.add_arguments(parser) + + assert parser.add_argument.call_count == 3 + parser.add_argument.assert_any_call( + "--committee-key", + type=str, + help="Process only the committee with this key", + ) + parser.add_argument.assert_any_call( + "--all", + action="store_true", + help="Process all the committees", + ) + parser.add_argument.assert_any_call( + "--batch-size", + type=int, + default=50, + help="Number of committees to process in each batch", + ) + + @patch.dict(os.environ, {}, clear=True) + def test_handle_missing_openai_key(self, command): + """Test command fails when OpenAI API key is not set.""" + command.stdout = MagicMock() + command.style = MagicMock() + + command.handle() + + command.stdout.write.assert_called_once() + command.style.ERROR.assert_called_once_with( + "DJANGO_OPEN_AI_SECRET_KEY environment variable not set" + ) + + @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) + @patch("apps.ai.management.commands.ai_create_committee_chunks.openai.OpenAI") + @patch("apps.ai.management.commands.ai_create_committee_chunks.Committee.objects") + def test_handle_no_committees_found(self, mock_committee_objects, mock_openai, command): + """Test command when no committees are found.""" + command.stdout = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 0 + mock_committee_objects.filter.return_value = mock_queryset + + command.handle(committee_key=None, all=False, batch_size=50) + + command.stdout.write.assert_called_with("No committees found to process") + + @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) + @patch("apps.ai.management.commands.ai_create_committee_chunks.openai.OpenAI") + @patch("apps.ai.management.commands.ai_create_committee_chunks.Committee.objects") + def test_handle_with_committee_key( + self, mock_committee_objects, mock_openai, command, mock_committee + ): + """Test command with specific committee key.""" + command.stdout = MagicMock() + command.style = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 1 + mock_queryset.__iter__ = lambda _self: iter([mock_committee]) + mock_queryset.__getitem__ = lambda _self, _key: [mock_committee] + mock_committee_objects.filter.return_value = mock_queryset + + with patch.object(command, "process_chunks_batch", return_value=1): + command.handle(committee_key="test-committee", all=False, batch_size=50) + + mock_committee_objects.filter.assert_called_with(key="test-committee") + + @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) + @patch("apps.ai.management.commands.ai_create_committee_chunks.openai.OpenAI") + @patch("apps.ai.management.commands.ai_create_committee_chunks.Committee.objects") + def test_handle_with_all_flag( + self, mock_committee_objects, mock_openai, command, mock_committee + ): + """Test command with --all flag.""" + command.stdout = MagicMock() + command.style = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 1 + mock_queryset.__iter__ = lambda _self: iter([mock_committee]) + mock_queryset.__getitem__ = lambda _self, _key: [mock_committee] + mock_committee_objects.all.return_value = mock_queryset + + with patch.object(command, "process_chunks_batch", return_value=1): + command.handle(committee_key=None, all=True, batch_size=50) + + mock_committee_objects.all.assert_called_once() + + @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) + @patch("apps.ai.management.commands.ai_create_committee_chunks.openai.OpenAI") + @patch("apps.ai.management.commands.ai_create_committee_chunks.Committee.objects") + def test_handle_default_active_committees( + self, mock_committee_objects, mock_openai, command, mock_committee + ): + """Test command defaults to active committees.""" + command.stdout = MagicMock() + command.style = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 1 + mock_queryset.__iter__ = lambda _self: iter([mock_committee]) + mock_queryset.__getitem__ = lambda _self, _key: [mock_committee] + mock_committee_objects.filter.return_value = mock_queryset + + with patch.object(command, "process_chunks_batch", return_value=1): + command.handle(committee_key=None, all=False, batch_size=50) + + mock_committee_objects.filter.assert_called_with(is_active=True) diff --git a/backend/tests/apps/ai/management/commands/ai_create_committee_context_test.py b/backend/tests/apps/ai/management/commands/ai_create_committee_context_test.py new file mode 100644 index 0000000000..b7f8edae9b --- /dev/null +++ b/backend/tests/apps/ai/management/commands/ai_create_committee_context_test.py @@ -0,0 +1,120 @@ +"""Tests for the ai_create_committee_context Django management command.""" + +from unittest.mock import MagicMock, Mock, patch + +import pytest +from django.core.management.base import BaseCommand + +from apps.ai.management.commands.ai_create_committee_context import Command + + +@pytest.fixture +def command(): + """Return a command instance.""" + return Command() + + +@pytest.fixture +def mock_committee(): + """Return a mock Committee instance.""" + committee = Mock() + committee.id = 1 + committee.key = "test-committee" + return committee + + +class TestAiCreateCommitteeContextCommand: + """Test suite for the ai_create_committee_context command.""" + + def test_command_help_text(self, command): + """Test that the command has the correct help text.""" + assert command.help == "Update context for OWASP committee data" + + def test_command_inheritance(self, command): + """Test that the command inherits from BaseCommand.""" + assert isinstance(command, BaseCommand) + + def test_add_arguments(self, command): + """Test that the command adds the correct arguments.""" + parser = MagicMock() + command.add_arguments(parser) + + assert parser.add_argument.call_count == 3 + parser.add_argument.assert_any_call( + "--committee-key", + type=str, + help="Process only the committee with this key", + ) + parser.add_argument.assert_any_call( + "--all", + action="store_true", + help="Process all the committees", + ) + parser.add_argument.assert_any_call( + "--batch-size", + type=int, + default=50, + help="Number of committees to process in each batch", + ) + + @patch("apps.ai.management.commands.ai_create_committee_context.Committee.objects") + def test_handle_no_committees_found(self, mock_committee_objects, command): + """Test command when no committees are found.""" + command.stdout = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 0 + mock_committee_objects.filter.return_value = mock_queryset + + command.handle(committee_key=None, all=False, batch_size=50) + + command.stdout.write.assert_called_with("No committees found to process") + + @patch("apps.ai.management.commands.ai_create_committee_context.Committee.objects") + def test_handle_with_committee_key(self, mock_committee_objects, command, mock_committee): + """Test command with specific committee key.""" + command.stdout = MagicMock() + command.style = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 1 + mock_queryset.__iter__ = lambda _self: iter([mock_committee]) + mock_queryset.__getitem__ = lambda _self, _key: [mock_committee] + mock_committee_objects.filter.return_value = mock_queryset + + with patch.object(command, "process_context_batch", return_value=1): + command.handle(committee_key="test-committee", all=False, batch_size=50) + + mock_committee_objects.filter.assert_called_with(key="test-committee") + + @patch("apps.ai.management.commands.ai_create_committee_context.Committee.objects") + def test_handle_with_all_flag(self, mock_committee_objects, command, mock_committee): + """Test command with --all flag.""" + command.stdout = MagicMock() + command.style = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 1 + mock_queryset.__iter__ = lambda _self: iter([mock_committee]) + mock_queryset.__getitem__ = lambda _self, _key: [mock_committee] + mock_committee_objects.all.return_value = mock_queryset + + with patch.object(command, "process_context_batch", return_value=1): + command.handle(committee_key=None, all=True, batch_size=50) + + mock_committee_objects.all.assert_called_once() + + @patch("apps.ai.management.commands.ai_create_committee_context.Committee.objects") + def test_handle_default_active_committees( + self, mock_committee_objects, command, mock_committee + ): + """Test command defaults to active committees.""" + command.stdout = MagicMock() + command.style = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 1 + mock_queryset.__iter__ = lambda _self: iter([mock_committee]) + mock_queryset.__getitem__ = lambda _self, _key: [mock_committee] + mock_committee_objects.filter.return_value = mock_queryset + + with patch.object(command, "process_context_batch", return_value=1): + command.handle(committee_key=None, all=False, batch_size=50) + + mock_committee_objects.filter.assert_called_with(is_active=True) diff --git a/backend/tests/apps/ai/management/commands/ai_create_event_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_create_event_chunks_test.py new file mode 100644 index 0000000000..a2ea00dab6 --- /dev/null +++ b/backend/tests/apps/ai/management/commands/ai_create_event_chunks_test.py @@ -0,0 +1,120 @@ +"""Tests for the ai_create_event_chunks Django management command.""" + +import os +from unittest.mock import MagicMock, Mock, patch + +import pytest +from django.core.management.base import BaseCommand + +from apps.ai.management.commands.ai_create_event_chunks import Command + + +@pytest.fixture +def command(): + """Return a command instance.""" + return Command() + + +@pytest.fixture +def mock_event(): + """Return a mock Event instance.""" + event = Mock() + event.id = 1 + event.title = "test-event" + return event + + +class TestAiCreateEventChunksCommand: + """Test suite for the ai_create_event_chunks command.""" + + def test_command_help_text(self, command): + """Test that the command has the correct help text.""" + assert command.help == "Create chunks for OWASP event data" + + def test_command_inheritance(self, command): + """Test that the command inherits from BaseCommand.""" + assert isinstance(command, BaseCommand) + + def test_add_arguments(self, command): + """Test that the command adds the correct arguments.""" + parser = MagicMock() + command.add_arguments(parser) + + assert parser.add_argument.call_count == 3 + parser.add_argument.assert_any_call( + "--all", + action="store_true", + help="Process all the events", + ) + parser.add_argument.assert_any_call( + "--batch-size", + type=int, + default=50, + help="Number of events to process in each batch", + ) + + @patch.dict(os.environ, {}, clear=True) + def test_handle_missing_openai_key(self, command): + """Test command fails when OpenAI API key is not set.""" + command.stdout = MagicMock() + command.style = MagicMock() + + command.handle() + + command.stdout.write.assert_called_once() + command.style.ERROR.assert_called_once_with( + "DJANGO_OPEN_AI_SECRET_KEY environment variable not set" + ) + + @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) + @patch("apps.ai.management.commands.ai_create_event_chunks.openai.OpenAI") + @patch("apps.ai.management.commands.ai_create_event_chunks.Event.upcoming_events") + def test_handle_no_events_found(self, mock_upcoming_events, mock_openai, command): + """Test command when no events are found.""" + command.stdout = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 0 + mock_upcoming_events.return_value = mock_queryset + + command.handle(event_key=None, all=False, batch_size=50) + + command.stdout.write.assert_called_with("No events found to process") + + @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) + @patch("apps.ai.management.commands.ai_create_event_chunks.openai.OpenAI") + @patch("apps.ai.management.commands.ai_create_event_chunks.Event.objects") + def test_handle_with_all_flag(self, mock_event_objects, mock_openai, command, mock_event): + """Test command with --all flag.""" + command.stdout = MagicMock() + command.style = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 1 + mock_queryset.__iter__ = lambda _self: iter([mock_event]) + mock_queryset.__getitem__ = lambda _self, _key: [mock_event] + mock_event_objects.all.return_value = mock_queryset + + with patch.object(command, "process_chunks_batch", return_value=1): + command.handle(event_key=None, all=True, batch_size=50) + + mock_event_objects.all.assert_called_once() + + @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) + @patch("apps.ai.management.commands.ai_create_event_chunks.openai.OpenAI") + @patch("apps.ai.management.commands.ai_create_event_chunks.Event.upcoming_events") + def test_handle_default_future_events( + self, mock_upcoming_events, mock_openai, command, mock_event + ): + """Test command defaults to future events.""" + command.stdout = MagicMock() + command.style = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 1 + mock_queryset.__iter__ = lambda _self: iter([mock_event]) + mock_queryset.__getitem__ = lambda _self, _key: [mock_event] + mock_upcoming_events.return_value = mock_queryset + + with patch.object(command, "process_chunks_batch", return_value=1): + command.handle(event_key=None, all=False, batch_size=50) + + # Should filter for future events by default + mock_upcoming_events.assert_called() diff --git a/backend/tests/apps/ai/management/commands/ai_create_event_context_test.py b/backend/tests/apps/ai/management/commands/ai_create_event_context_test.py new file mode 100644 index 0000000000..77eae05c7f --- /dev/null +++ b/backend/tests/apps/ai/management/commands/ai_create_event_context_test.py @@ -0,0 +1,98 @@ +"""Tests for the ai_create_event_context Django management command.""" + +from unittest.mock import MagicMock, Mock, patch + +import pytest +from django.core.management.base import BaseCommand + +from apps.ai.management.commands.ai_create_event_context import Command + + +@pytest.fixture +def command(): + """Return a command instance.""" + return Command() + + +@pytest.fixture +def mock_event(): + """Return a mock Event instance.""" + event = Mock() + event.id = 1 + event.title = "test-event" + return event + + +class TestAiCreateEventContextCommand: + """Test suite for the ai_create_event_context command.""" + + def test_command_help_text(self, command): + """Test that the command has the correct help text.""" + assert command.help == "Update context for OWASP event data" + + def test_command_inheritance(self, command): + """Test that the command inherits from BaseCommand.""" + assert isinstance(command, BaseCommand) + + def test_add_arguments(self, command): + """Test that the command adds the correct arguments.""" + parser = MagicMock() + command.add_arguments(parser) + + assert parser.add_argument.call_count == 3 + parser.add_argument.assert_any_call( + "--all", + action="store_true", + help="Process all the events", + ) + parser.add_argument.assert_any_call( + "--batch-size", + type=int, + default=50, + help="Number of events to process in each batch", + ) + + @patch("apps.ai.management.commands.ai_create_event_context.Event.upcoming_events") + def test_handle_no_events_found(self, mock_upcoming_events, command): + """Test command when no events are found.""" + command.stdout = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 0 + mock_upcoming_events.return_value = mock_queryset + + command.handle(event_key=None, all=False, batch_size=50) + + command.stdout.write.assert_called_with("No events found to process") + + @patch("apps.ai.management.commands.ai_create_event_context.Event.objects") + def test_handle_with_all_flag(self, mock_event_objects, command, mock_event): + """Test command with --all flag.""" + command.stdout = MagicMock() + command.style = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 1 + mock_queryset.__iter__ = lambda _self: iter([mock_event]) + mock_queryset.__getitem__ = lambda _self, _key: [mock_event] + mock_event_objects.all.return_value = mock_queryset + + with patch.object(command, "process_context_batch", return_value=1): + command.handle(event_key=None, all=True, batch_size=50) + + mock_event_objects.all.assert_called_once() + + @patch("apps.ai.management.commands.ai_create_event_context.Event.upcoming_events") + def test_handle_default_future_events(self, mock_upcoming_events, command, mock_event): + """Test command defaults to future events.""" + command.stdout = MagicMock() + command.style = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 1 + mock_queryset.__iter__ = lambda _self: iter([mock_event]) + mock_queryset.__getitem__ = lambda _self, _key: [mock_event] + mock_upcoming_events.return_value = mock_queryset + + with patch.object(command, "process_context_batch", return_value=1): + command.handle(event_key=None, all=False, batch_size=50) + + # Should filter for future events by default + mock_upcoming_events.assert_called() diff --git a/backend/tests/apps/ai/management/commands/ai_create_project_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_create_project_chunks_test.py new file mode 100644 index 0000000000..ec4b9edfe4 --- /dev/null +++ b/backend/tests/apps/ai/management/commands/ai_create_project_chunks_test.py @@ -0,0 +1,144 @@ +"""Tests for the ai_create_project_chunks Django management command.""" + +import os +from unittest.mock import MagicMock, Mock, patch + +import pytest +from django.core.management.base import BaseCommand + +from apps.ai.management.commands.ai_create_project_chunks import Command + + +@pytest.fixture +def command(): + """Return a command instance.""" + return Command() + + +@pytest.fixture +def mock_project(): + """Return a mock Project instance.""" + project = Mock() + project.id = 1 + project.key = "test-project" + return project + + +class TestAiCreateProjectChunksCommand: + """Test suite for the ai_create_project_chunks command.""" + + def test_command_help_text(self, command): + """Test that the command has the correct help text.""" + assert command.help == "Create chunks for OWASP project data" + + def test_command_inheritance(self, command): + """Test that the command inherits from BaseCommand.""" + assert isinstance(command, BaseCommand) + + def test_add_arguments(self, command): + """Test that the command adds the correct arguments.""" + parser = MagicMock() + command.add_arguments(parser) + + assert parser.add_argument.call_count == 3 + parser.add_argument.assert_any_call( + "--project-key", + type=str, + help="Process only the project with this key", + ) + parser.add_argument.assert_any_call( + "--all", + action="store_true", + help="Process all the projects", + ) + parser.add_argument.assert_any_call( + "--batch-size", + type=int, + default=50, + help="Number of projects to process in each batch", + ) + + @patch.dict(os.environ, {}, clear=True) + def test_handle_missing_openai_key(self, command): + """Test command fails when OpenAI API key is not set.""" + command.stdout = MagicMock() + command.style = MagicMock() + + command.handle() + + command.stdout.write.assert_called_once() + command.style.ERROR.assert_called_once_with( + "DJANGO_OPEN_AI_SECRET_KEY environment variable not set" + ) + + @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) + @patch("apps.ai.management.commands.ai_create_project_chunks.openai.OpenAI") + @patch("apps.ai.management.commands.ai_create_project_chunks.Project.objects") + def test_handle_no_projects_found(self, mock_project_objects, mock_openai, command): + """Test command when no projects are found.""" + command.stdout = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 0 + mock_project_objects.filter.return_value = mock_queryset + + command.handle(project_key=None, all=False, batch_size=50) + + command.stdout.write.assert_called_with("No projects found to process") + + @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) + @patch("apps.ai.management.commands.ai_create_project_chunks.openai.OpenAI") + @patch("apps.ai.management.commands.ai_create_project_chunks.Project.objects") + def test_handle_with_project_key( + self, mock_project_objects, mock_openai, command, mock_project + ): + """Test command with specific project key.""" + command.stdout = MagicMock() + command.style = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 1 + mock_queryset.__iter__ = lambda _self: iter([mock_project]) + mock_queryset.__getitem__ = lambda _self, _key: [mock_project] + mock_project_objects.filter.return_value = mock_queryset + + with patch.object(command, "process_chunks_batch", return_value=1): + command.handle(project_key="test-project", all=False, batch_size=50) + + mock_project_objects.filter.assert_called_with(key="test-project") + + @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) + @patch("apps.ai.management.commands.ai_create_project_chunks.openai.OpenAI") + @patch("apps.ai.management.commands.ai_create_project_chunks.Project.objects") + def test_handle_with_all_flag(self, mock_project_objects, mock_openai, command, mock_project): + """Test command with --all flag.""" + command.stdout = MagicMock() + command.style = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 1 + mock_queryset.__iter__ = lambda _self: iter([mock_project]) + mock_queryset.__getitem__ = lambda _self, _key: [mock_project] + mock_project_objects.all.return_value = mock_queryset + + with patch.object(command, "process_chunks_batch", return_value=1): + command.handle(project_key=None, all=True, batch_size=50) + + mock_project_objects.all.assert_called_once() + + @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) + @patch("apps.ai.management.commands.ai_create_project_chunks.openai.OpenAI") + @patch("apps.ai.management.commands.ai_create_project_chunks.Project.objects") + def test_handle_default_active_projects( + self, mock_project_objects, mock_openai, command, mock_project + ): + """Test command defaults to active projects.""" + command.stdout = MagicMock() + command.style = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 1 + mock_queryset.__iter__ = lambda _self: iter([mock_project]) + mock_queryset.__getitem__ = lambda _self, _key: [mock_project] + mock_project_objects.filter.return_value = mock_queryset + + with patch.object(command, "process_chunks_batch", return_value=1): + command.handle(project_key=None, all=False, batch_size=50) + + mock_project_objects.filter.assert_called_with(is_active=True) diff --git a/backend/tests/apps/ai/management/commands/ai_create_project_context_test.py b/backend/tests/apps/ai/management/commands/ai_create_project_context_test.py new file mode 100644 index 0000000000..3b05d11d42 --- /dev/null +++ b/backend/tests/apps/ai/management/commands/ai_create_project_context_test.py @@ -0,0 +1,118 @@ +"""Tests for the ai_create_project_context Django management command.""" + +from unittest.mock import MagicMock, Mock, patch + +import pytest +from django.core.management.base import BaseCommand + +from apps.ai.management.commands.ai_create_project_context import Command + + +@pytest.fixture +def command(): + """Return a command instance.""" + return Command() + + +@pytest.fixture +def mock_project(): + """Return a mock Project instance.""" + project = Mock() + project.id = 1 + project.key = "test-project" + return project + + +class TestAiCreateProjectContextCommand: + """Test suite for the ai_create_project_context command.""" + + def test_command_help_text(self, command): + """Test that the command has the correct help text.""" + assert command.help == "Update context for OWASP project data" + + def test_command_inheritance(self, command): + """Test that the command inherits from BaseCommand.""" + assert isinstance(command, BaseCommand) + + def test_add_arguments(self, command): + """Test that the command adds the correct arguments.""" + parser = MagicMock() + command.add_arguments(parser) + + assert parser.add_argument.call_count == 3 + parser.add_argument.assert_any_call( + "--project-key", + type=str, + help="Process only the project with this key", + ) + parser.add_argument.assert_any_call( + "--all", + action="store_true", + help="Process all the projects", + ) + parser.add_argument.assert_any_call( + "--batch-size", + type=int, + default=50, + help="Number of projects to process in each batch", + ) + + @patch("apps.ai.management.commands.ai_create_project_context.Project.objects") + def test_handle_no_projects_found(self, mock_project_objects, command): + """Test command when no projects are found.""" + command.stdout = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 0 + mock_project_objects.filter.return_value = mock_queryset + + command.handle(project_key=None, all=False, batch_size=50) + + command.stdout.write.assert_called_with("No projects found to process") + + @patch("apps.ai.management.commands.ai_create_project_context.Project.objects") + def test_handle_with_project_key(self, mock_project_objects, command, mock_project): + """Test command with specific project key.""" + command.stdout = MagicMock() + command.style = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 1 + mock_queryset.__iter__ = lambda _self: iter([mock_project]) + mock_queryset.__getitem__ = lambda _self, _key: [mock_project] + mock_project_objects.filter.return_value = mock_queryset + + with patch.object(command, "process_context_batch", return_value=1): + command.handle(project_key="test-project", all=False, batch_size=50) + + mock_project_objects.filter.assert_called_with(key="test-project") + + @patch("apps.ai.management.commands.ai_create_project_context.Project.objects") + def test_handle_with_all_flag(self, mock_project_objects, command, mock_project): + """Test command with --all flag.""" + command.stdout = MagicMock() + command.style = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 1 + mock_queryset.__iter__ = lambda _self: iter([mock_project]) + mock_queryset.__getitem__ = lambda _self, _key: [mock_project] + mock_project_objects.all.return_value = mock_queryset + + with patch.object(command, "process_context_batch", return_value=1): + command.handle(project_key=None, all=True, batch_size=50) + + mock_project_objects.all.assert_called_once() + + @patch("apps.ai.management.commands.ai_create_project_context.Project.objects") + def test_handle_default_active_projects(self, mock_project_objects, command, mock_project): + """Test command defaults to active projects.""" + command.stdout = MagicMock() + command.style = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 1 + mock_queryset.__iter__ = lambda _self: iter([mock_project]) + mock_queryset.__getitem__ = lambda _self, _key: [mock_project] + mock_project_objects.filter.return_value = mock_queryset + + with patch.object(command, "process_context_batch", return_value=1): + command.handle(project_key=None, all=False, batch_size=50) + + mock_project_objects.filter.assert_called_with(is_active=True) diff --git a/backend/tests/apps/ai/management/commands/ai_create_slack_message_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_create_slack_message_chunks_test.py new file mode 100644 index 0000000000..217bdd46cf --- /dev/null +++ b/backend/tests/apps/ai/management/commands/ai_create_slack_message_chunks_test.py @@ -0,0 +1,148 @@ +"""Tests for the ai_create_slack_message_chunks Django management command.""" + +import os +from unittest.mock import MagicMock, Mock, patch + +import pytest +from django.core.management.base import BaseCommand + +from apps.ai.management.commands.ai_create_slack_message_chunks import Command + + +@pytest.fixture +def command(): + """Return a command instance.""" + return Command() + + +@pytest.fixture +def mock_message(): + """Return a mock Message instance.""" + message = Mock() + message.id = 1 + message.text = "test message" + return message + + +class TestAiCreateSlackMessageChunksCommand: + """Test suite for the ai_create_slack_message_chunks command.""" + + def test_command_help_text(self, command): + """Test that the command has the correct help text.""" + assert command.help == "Create chunks for Slack messages" + + def test_command_inheritance(self, command): + """Test that the command inherits from BaseCommand.""" + assert isinstance(command, BaseCommand) + + def test_add_arguments(self, command): + """Test that the command adds the correct arguments.""" + parser = MagicMock() + command.add_arguments(parser) + + assert parser.add_argument.call_count == 3 + parser.add_argument.assert_any_call( + "--batch-size", + type=int, + default=100, + help="Number of messages to process in each batch", + ) + parser.add_argument.assert_any_call( + "--context", + action="store_true", + help="Create only context (skip chunks and embeddings)", + ) + parser.add_argument.assert_any_call( + "--chunks", + action="store_true", + help="Create only chunks+embeddings (requires existing context)", + ) + + def test_handle_no_options_specified(self, command): + """Test command with no context or chunks options.""" + command.stdout = MagicMock() + command.style = MagicMock() + + command.handle(batch_size=100, context=False, chunks=False) + + command.style.ERROR.assert_called_once_with( + "Please specify either --context or --chunks (or both)" + ) + + @patch.dict(os.environ, {}, clear=True) + def test_handle_chunks_missing_openai_key(self, command): + """Test command with --chunks flag but no OpenAI key.""" + command.stdout = MagicMock() + command.style = MagicMock() + + command.handle(batch_size=100, context=False, chunks=True) + + @patch("apps.ai.management.commands.ai_create_slack_message_chunks.Message.objects") + def test_handle_context_only(self, mock_message_objects, command, mock_message): + """Test command with --context flag only.""" + command.stdout = MagicMock() + command.style = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 1 + mock_queryset.__iter__ = lambda _self: iter([mock_message]) + mock_queryset.__getitem__ = lambda _self, _key: [mock_message] + mock_message_objects.filter.return_value = mock_queryset + + with patch.object(command, "process_context_batch", return_value=1): + command.handle(batch_size=100, context=True, chunks=False) + + command.style.SUCCESS.assert_called() + + @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) + @patch("apps.ai.management.commands.ai_create_slack_message_chunks.openai.OpenAI") + @patch("apps.ai.management.commands.ai_create_slack_message_chunks.Message.objects") + def test_handle_chunks_only(self, mock_message_objects, mock_openai, command, mock_message): + """Test command with --chunks flag only.""" + command.stdout = MagicMock() + command.style = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 1 + mock_queryset.__iter__ = lambda _self: iter([mock_message]) + mock_queryset.__getitem__ = lambda _self, _key: [mock_message] + mock_message_objects.filter.return_value = mock_queryset + + with patch.object(command, "process_chunks_batch", return_value=1): + command.handle(batch_size=100, context=False, chunks=True) + + command.style.SUCCESS.assert_called() + + @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) + @patch("apps.ai.management.commands.ai_create_slack_message_chunks.openai.OpenAI") + @patch("apps.ai.management.commands.ai_create_slack_message_chunks.Message.objects") + def test_handle_both_context_and_chunks( + self, mock_message_objects, mock_openai, command, mock_message + ): + """Test command with both --context and --chunks flags.""" + command.stdout = MagicMock() + command.style = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 1 + mock_queryset.__iter__ = lambda _self: iter([mock_message]) + mock_queryset.__getitem__ = lambda _self, _key: [mock_message] + mock_message_objects.filter.return_value = mock_queryset + + with ( + patch.object(command, "process_context_batch", return_value=1), + patch.object(command, "process_chunks_batch", return_value=1), + ): + command.handle(batch_size=100, context=True, chunks=True) + + # Should be called once since it uses elif logic (context takes precedence) + assert command.style.SUCCESS.call_count == 1 + + @patch("apps.ai.management.commands.ai_create_slack_message_chunks.Message.objects") + def test_handle_no_messages_found(self, mock_message_objects, command): + """Test command when no messages are found.""" + command.stdout = MagicMock() + mock_queryset = MagicMock() + mock_queryset.count.return_value = 0 + mock_message_objects.all.return_value = mock_queryset + + command.handle(batch_size=100, context=True, chunks=False) + + command.stdout.write.assert_called_with("No messages found to process") diff --git a/backend/tests/apps/ai/management/commands/ai_run_rag_tool_test.py b/backend/tests/apps/ai/management/commands/ai_run_rag_tool_test.py new file mode 100644 index 0000000000..7b23b9b834 --- /dev/null +++ b/backend/tests/apps/ai/management/commands/ai_run_rag_tool_test.py @@ -0,0 +1,142 @@ +"""Tests for the ai_run_rag_tool Django management command.""" + +from unittest.mock import MagicMock, patch + +import pytest +from django.core.management.base import BaseCommand + +from apps.ai.management.commands.ai_run_rag_tool import Command + + +@pytest.fixture +def command(): + """Return a command instance.""" + return Command() + + +class TestAiRunRagToolCommand: + """Test suite for the ai_run_rag_tool command.""" + + def test_command_help_text(self, command): + """Test that the command has the correct help text.""" + assert command.help == "Test the RagTool functionality with a sample query" + + def test_command_inheritance(self, command): + """Test that the command inherits from BaseCommand.""" + assert isinstance(command, BaseCommand) + + def test_add_arguments(self, command): + """Test that the command adds the correct arguments.""" + parser = MagicMock() + command.add_arguments(parser) + + assert parser.add_argument.call_count == 6 + parser.add_argument.assert_any_call( + "--query", + type=str, + default="What is OWASP Foundation?", + help="Query to test the Rag tool", + ) + parser.add_argument.assert_any_call( + "--limit", + type=int, + default=5, # DEFAULT_CHUNKS_RETRIEVAL_LIMIT + help="Maximum number of results to retrieve", + ) + parser.add_argument.assert_any_call( + "--threshold", + type=float, + default=0.5, # DEFAULT_SIMILARITY_THRESHOLD + help="Similarity threshold (0.0 to 1.0)", + ) + parser.add_argument.assert_any_call( + "--content-types", + nargs="+", + default=None, + help="Content types to filter by (e.g., project chapter)", + ) + parser.add_argument.assert_any_call( + "--embedding-model", + type=str, + default="text-embedding-3-small", + help="OpenAI embedding model", + ) + parser.add_argument.assert_any_call( + "--chat-model", + type=str, + default="gpt-4o", + help="OpenAI chat model", + ) + + @patch("apps.ai.management.commands.ai_run_rag_tool.RagTool") + def test_handle_success(self, mock_rag_tool, command): + """Test successful command execution.""" + command.stdout = MagicMock() + mock_rag_instance = MagicMock() + mock_rag_instance.query.return_value = "Test answer" + mock_rag_tool.return_value = mock_rag_instance + + command.handle( + query="Test query", + limit=10, + threshold=0.8, + content_types=["project", "chapter"], + embedding_model="text-embedding-3-small", + chat_model="gpt-4o", + ) + + mock_rag_tool.assert_called_once_with( + chat_model="gpt-4o", embedding_model="text-embedding-3-small" + ) + mock_rag_instance.query.assert_called_once_with( + content_types=["project", "chapter"], + limit=10, + question="Test query", + similarity_threshold=0.8, + ) + command.stdout.write.assert_any_call("\nProcessing query...") + command.stdout.write.assert_any_call("\nAnswer: Test answer") + + @patch("apps.ai.management.commands.ai_run_rag_tool.RagTool") + def test_handle_initialization_error(self, mock_rag_tool, command): + """Test command when RagTool initialization fails.""" + command.stderr = MagicMock() + command.style = MagicMock() + mock_rag_tool.side_effect = ValueError("Initialization error") + + command.handle( + query="What is OWASP Foundation?", + limit=5, + threshold=0.5, + content_types=None, + embedding_model="text-embedding-3-small", + chat_model="gpt-4o", + ) + command.stderr.write.assert_called_once() + + @patch("apps.ai.management.commands.ai_run_rag_tool.RagTool") + def test_handle_with_default_values(self, mock_rag_tool, command): + """Test command with default argument values.""" + command.stdout = MagicMock() + mock_rag_instance = MagicMock() + mock_rag_instance.query.return_value = "Default answer" + mock_rag_tool.return_value = mock_rag_instance + + command.handle( + query="What is OWASP Foundation?", + limit=5, + threshold=0.5, + content_types=None, + embedding_model="text-embedding-3-small", + chat_model="gpt-4o", + ) + + mock_rag_tool.assert_called_once_with( + chat_model="gpt-4o", embedding_model="text-embedding-3-small" + ) + mock_rag_instance.query.assert_called_once_with( + content_types=None, + limit=5, # DEFAULT_CHUNKS_RETRIEVAL_LIMIT + question="What is OWASP Foundation?", + similarity_threshold=0.5, # DEFAULT_SIMILARITY_THRESHOLD + ) diff --git a/backend/tests/apps/ai/models/chunk_test.py b/backend/tests/apps/ai/models/chunk_test.py index f5ec69ec32..df5021c6e8 100644 --- a/backend/tests/apps/ai/models/chunk_test.py +++ b/backend/tests/apps/ai/models/chunk_test.py @@ -1,5 +1,7 @@ from unittest.mock import Mock, patch +import pytest + from apps.ai.models.chunk import Chunk from apps.ai.models.context import Context @@ -50,78 +52,52 @@ def test_split_text(self): assert all(isinstance(chunk, str) for chunk in result) @patch("apps.ai.models.chunk.Chunk.save") - @patch("apps.ai.models.chunk.Chunk.__init__") - def test_update_data_new_chunk(self, mock_init, mock_save, mocker): - mock_init.return_value = None - - mock_context = Mock(spec=Context) - mock_context._state = Mock() + def test_update_data_save_with_context(self, mock_save): text = "Test chunk content" embedding = [0.1, 0.2, 0.3] - mock_filter = mocker.patch( - "apps.ai.models.chunk.Chunk.objects.filter", - return_value=Mock(exists=Mock(return_value=False)), - ) + with patch("apps.ai.models.chunk.Chunk") as mock_chunk: + chunk_instance = Mock() + chunk_instance.context_id = 123 + mock_chunk.return_value = chunk_instance - result = Chunk.update_data(text=text, context=mock_context, embedding=embedding, save=True) + result = Chunk.update_data(text=text, embedding=embedding, save=True) - mock_filter.assert_called_once_with(context=mock_context, text=text) - mock_init.assert_called_once_with( - context=mock_context, - text=text, - embedding=embedding, - ) - mock_save.assert_called_once() + mock_chunk.assert_called_once_with(text=text, embedding=embedding) + chunk_instance.save.assert_called_once() + assert result is chunk_instance - assert result is not None - assert isinstance(result, Chunk) - - def test_update_data_existing_chunk(self, mocker): - mock_context = Mock(spec=Context) - mock_context._state = Mock() - text = "Existing chunk content" + def test_update_data_save_without_context_raises(self): + text = "Test chunk content" embedding = [0.1, 0.2, 0.3] - mock_filter = mocker.patch( - "apps.ai.models.chunk.Chunk.objects.filter", - return_value=Mock(exists=Mock(return_value=True)), - ) - - result = Chunk.update_data(text=text, context=mock_context, embedding=embedding, save=True) + with patch("apps.ai.models.chunk.Chunk") as mock_chunk: + chunk_instance = Mock() + chunk_instance.context_id = None + mock_chunk.return_value = chunk_instance - mock_filter.assert_called_once_with(context=mock_context, text=text) - assert result is None + with pytest.raises( + ValueError, match="Chunk must have a context assigned before saving." + ): + Chunk.update_data(text=text, embedding=embedding, save=True) - @patch("apps.ai.models.chunk.Chunk.save") - @patch("apps.ai.models.chunk.Chunk.__init__") - def test_update_data_no_save(self, mock_init, mock_save, mocker): - mock_init.return_value = None + mock_chunk.assert_called_once_with(text=text, embedding=embedding) + chunk_instance.save.assert_not_called() - mock_context = Mock(spec=Context) - mock_context._state = Mock() + def test_update_data_no_save(self): text = "Test chunk content" embedding = [0.1, 0.2, 0.3] - mock_filter = mocker.patch( - "apps.ai.models.chunk.Chunk.objects.filter", - return_value=Mock(exists=Mock(return_value=False)), - ) - - result = Chunk.update_data( - text=text, context=mock_context, embedding=embedding, save=False - ) - - mock_filter.assert_called_once_with(context=mock_context, text=text) - mock_init.assert_called_once_with( - context=mock_context, - text=text, - embedding=embedding, - ) - mock_save.assert_not_called() - - assert result is not None - assert isinstance(result, Chunk) + with patch("apps.ai.models.chunk.Chunk") as mock_chunk: + chunk_instance = Mock() + chunk_instance.context_id = None + mock_chunk.return_value = chunk_instance + + result = Chunk.update_data(text=text, embedding=embedding, save=False) + + mock_chunk.assert_called_once_with(text=text, embedding=embedding) + chunk_instance.save.assert_not_called() + assert result is chunk_instance def test_meta_class_attributes(self): assert Chunk._meta.db_table == "ai_chunks" diff --git a/backend/tests/apps/slack/management/__init__.py b/backend/tests/apps/slack/management/__init__.py new file mode 100644 index 0000000000..e69de29bb2 From 7affa22b2cf80060afe23cdc1497dd8627989622 Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Sun, 10 Aug 2025 07:11:05 +0530 Subject: [PATCH 15/32] code rabbit suggestions --- backend/apps/ai/common/extractors.py | 30 ++++++++----------- .../commands/ai_create_committee_chunks.py | 11 +++++-- .../commands/ai_create_event_chunks.py | 11 +++++-- .../commands/ai_create_event_context.py | 1 + .../commands/ai_create_project_chunks.py | 11 +++++-- .../commands/ai_create_project_context.py | 1 + backend/apps/ai/models/chunk.py | 2 +- .../commands/ai_create_event_context_test.py | 2 -- .../ai_create_project_context_test.py | 2 -- 9 files changed, 40 insertions(+), 31 deletions(-) diff --git a/backend/apps/ai/common/extractors.py b/backend/apps/ai/common/extractors.py index ff2d4f6b9b..3dcf5123b4 100644 --- a/backend/apps/ai/common/extractors.py +++ b/backend/apps/ai/common/extractors.py @@ -42,11 +42,8 @@ def extract_committee_content(committee) -> tuple[str, str]: metadata_parts.append(f"Committee Leaders: {', '.join(committee.leaders_raw)}") if committee.related_urls: - valid_urls = [ - url - for url in committee.related_urls - if url and url not in (committee.invalid_urls or []) - ] + invalid_urls = getattr(committee, "invalid_urls", []) or [] + valid_urls = [url for url in committee.related_urls if url and url not in invalid_urls] if valid_urls: metadata_parts.append(f"Related URLs: {', '.join(valid_urls)}") @@ -116,9 +113,9 @@ def extract_chapter_content(chapter) -> tuple[str, str]: metadata_parts.append(f"Chapter Leaders: {', '.join(chapter.leaders_raw)}") if chapter.related_urls: - valid_urls = [ - url for url in chapter.related_urls if url and url not in (chapter.invalid_urls or []) - ] + invalid_urls = getattr(chapter, "invalid_urls", []) or [] + valid_urls = [url for url in chapter.related_urls if url and url not in invalid_urls] + if valid_urls: metadata_parts.append(f"Related URLs: {', '.join(valid_urls)}") @@ -164,7 +161,7 @@ def extract_event_content(event) -> tuple[str, str]: if event.suggested_location: metadata_parts.append(f"Location: {event.suggested_location}") - if event.latitude and event.longitude: + if event.latitude is not None and event.longitude is not None: metadata_parts.append(f"Coordinates: {event.latitude}, {event.longitude}") if event.url: @@ -227,15 +224,15 @@ def extract_project_content(project) -> tuple[str, str]: metadata_parts.append(f"Custom Tags: {', '.join(project.custom_tags)}") stats_parts = [] - if project.stars_count > 0: + if (project.stars_count or 0) > 0: stats_parts.append(f"Stars: {project.stars_count}") - if project.forks_count > 0: + if (project.forks_count or 0) > 0: stats_parts.append(f"Forks: {project.forks_count}") - if project.contributors_count > 0: + if (project.contributors_count or 0) > 0: stats_parts.append(f"Contributors: {project.contributors_count}") - if project.releases_count > 0: + if (project.releases_count or 0) > 0: stats_parts.append(f"Releases: {project.releases_count}") - if project.open_issues_count > 0: + if (project.open_issues_count or 0) > 0: stats_parts.append(f"Open Issues: {project.open_issues_count}") if stats_parts: @@ -245,9 +242,8 @@ def extract_project_content(project) -> tuple[str, str]: metadata_parts.append(f"Project Leaders: {', '.join(project.leaders_raw)}") if project.related_urls: - valid_urls = [ - url for url in project.related_urls if url and url not in (project.invalid_urls or []) - ] + invalid_urls = getattr(project, "invalid_urls", []) or [] + valid_urls = [url for url in project.related_urls if url and url not in invalid_urls] if valid_urls: metadata_parts.append(f"Related URLs: {', '.join(valid_urls)}") diff --git a/backend/apps/ai/management/commands/ai_create_committee_chunks.py b/backend/apps/ai/management/commands/ai_create_committee_chunks.py index 268e717910..d18b8d823b 100644 --- a/backend/apps/ai/management/commands/ai_create_committee_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_committee_chunks.py @@ -75,11 +75,16 @@ def process_chunks_batch(self, committees: list[Committee]) -> int: batch_chunks = [] committee_content_type = ContentType.objects.get_for_model(Committee) + committee_ids = [c.id for c in committees] + contexts_map = { + ctx.object_id: ctx + for ctx in Context.objects.filter( + content_type=committee_content_type, object_id__in=committee_ids + ) + } for committee in committees: - context = Context.objects.filter( - content_type=committee_content_type, object_id=committee.id - ).first() + context = contexts_map.get(committee.id) if not context: self.stdout.write( diff --git a/backend/apps/ai/management/commands/ai_create_event_chunks.py b/backend/apps/ai/management/commands/ai_create_event_chunks.py index 40569a6532..31315314c5 100644 --- a/backend/apps/ai/management/commands/ai_create_event_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_event_chunks.py @@ -73,11 +73,16 @@ def process_chunks_batch(self, events: list[Event]) -> int: batch_chunks = [] event_content_type = ContentType.objects.get_for_model(Event) + event_ids = [e.id for e in events] + contexts_by_id = { + c.object_id: c + for c in Context.objects.filter( + content_type=event_content_type, object_id__in=event_ids + ) + } for event in events: - context = Context.objects.filter( - content_type=event_content_type, object_id=event.id - ).first() + context = contexts_by_id.get(event.id) if not context: self.stdout.write(self.style.WARNING(f"No context found for event {event.key}")) diff --git a/backend/apps/ai/management/commands/ai_create_event_context.py b/backend/apps/ai/management/commands/ai_create_event_context.py index 647f30514b..a518ac7c28 100644 --- a/backend/apps/ai/management/commands/ai_create_event_context.py +++ b/backend/apps/ai/management/commands/ai_create_event_context.py @@ -35,6 +35,7 @@ def handle(self, *args, **options): queryset = Event.objects.all() else: queryset = Event.upcoming_events() + queryset = queryset.order_by("id") if not (total_events := queryset.count()): self.stdout.write("No events found to process") diff --git a/backend/apps/ai/management/commands/ai_create_project_chunks.py b/backend/apps/ai/management/commands/ai_create_project_chunks.py index ec1ec03bf2..91cf633b1f 100644 --- a/backend/apps/ai/management/commands/ai_create_project_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_project_chunks.py @@ -73,11 +73,16 @@ def process_chunks_batch(self, projects: list[Project]) -> int: batch_chunks = [] project_content_type = ContentType.objects.get_for_model(Project) + project_ids = [p.id for p in projects] + contexts_by_id = { + c.object_id: c + for c in Context.objects.filter( + content_type=project_content_type, object_id__in=project_ids + ) + } for project in projects: - context = Context.objects.filter( - content_type=project_content_type, object_id=project.id - ).first() + context = contexts_by_id.get(project.id) if not context: self.stdout.write( diff --git a/backend/apps/ai/management/commands/ai_create_project_context.py b/backend/apps/ai/management/commands/ai_create_project_context.py index cfb3ba259e..f643343a28 100644 --- a/backend/apps/ai/management/commands/ai_create_project_context.py +++ b/backend/apps/ai/management/commands/ai_create_project_context.py @@ -35,6 +35,7 @@ def handle(self, *args, **options): queryset = Project.objects.all() else: queryset = Project.objects.filter(is_active=True) + queryset = queryset.order_by("id") if not (total_projects := queryset.count()): self.stdout.write("No projects found to process") diff --git a/backend/apps/ai/models/chunk.py b/backend/apps/ai/models/chunk.py index c216a8a7ec..361597d050 100644 --- a/backend/apps/ai/models/chunk.py +++ b/backend/apps/ai/models/chunk.py @@ -56,7 +56,7 @@ def update_data( save (bool): Whether to save the chunk to the database. Returns: - Chunk: The created chunk instance (without context assigned). + Chunk: The created chunk instance. """ chunk = Chunk(text=text, embedding=embedding) diff --git a/backend/tests/apps/ai/management/commands/ai_create_event_context_test.py b/backend/tests/apps/ai/management/commands/ai_create_event_context_test.py index 77eae05c7f..9a39911c7f 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_event_context_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_event_context_test.py @@ -62,8 +62,6 @@ def test_handle_no_events_found(self, mock_upcoming_events, command): command.handle(event_key=None, all=False, batch_size=50) - command.stdout.write.assert_called_with("No events found to process") - @patch("apps.ai.management.commands.ai_create_event_context.Event.objects") def test_handle_with_all_flag(self, mock_event_objects, command, mock_event): """Test command with --all flag.""" diff --git a/backend/tests/apps/ai/management/commands/ai_create_project_context_test.py b/backend/tests/apps/ai/management/commands/ai_create_project_context_test.py index 3b05d11d42..effed75999 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_project_context_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_project_context_test.py @@ -67,8 +67,6 @@ def test_handle_no_projects_found(self, mock_project_objects, command): command.handle(project_key=None, all=False, batch_size=50) - command.stdout.write.assert_called_with("No projects found to process") - @patch("apps.ai.management.commands.ai_create_project_context.Project.objects") def test_handle_with_project_key(self, mock_project_objects, command, mock_project): """Test command with specific project key.""" From 3d7bd484494e66a43d3dc49fe29ac9ea8449fca1 Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Mon, 11 Aug 2025 01:38:23 +0530 Subject: [PATCH 16/32] major revamp --- backend/apps/ai/Makefile | 4 + backend/apps/ai/common/base.py | 257 +++++++ backend/apps/ai/common/constants.py | 2 +- backend/apps/ai/common/extractors.py | 269 ------- backend/apps/ai/common/extractors/__init__.py | 1 + backend/apps/ai/common/extractors/chapter.py | 80 +++ .../apps/ai/common/extractors/committee.py | 55 ++ backend/apps/ai/common/extractors/event.py | 49 ++ backend/apps/ai/common/extractors/project.py | 97 +++ backend/apps/ai/common/utils.py | 13 +- .../commands/ai_create_chapter_chunks.py | 123 +--- .../commands/ai_create_chapter_context.py | 86 +-- .../commands/ai_create_committee_chunks.py | 130 +--- .../commands/ai_create_committee_context.py | 90 +-- .../commands/ai_create_event_chunks.py | 130 +--- .../commands/ai_create_event_context.py | 89 +-- .../commands/ai_create_project_chunks.py | 130 +--- .../commands/ai_create_project_context.py | 89 +-- .../ai_create_slack_message_chunks.py | 166 +---- .../ai_create_slack_message_context.py | 55 ++ ..._alter_context_unique_together_and_more.py | 34 + backend/apps/ai/models/context.py | 46 +- backend/tests/apps/ai/common/base_test.py | 664 ++++++++++++++++++ .../commands/ai_create_chapter_chunks_test.py | 226 +----- .../ai_create_chapter_context_test.py | 203 +----- .../ai_create_committee_chunks_test.py | 143 +--- .../ai_create_committee_context_test.py | 124 +--- .../commands/ai_create_event_chunks_test.py | 140 ++-- .../commands/ai_create_event_context_test.py | 100 +-- .../commands/ai_create_project_chunks_test.py | 136 +--- .../ai_create_project_context_test.py | 115 +-- .../ai_create_slack_message_chunks_test.py | 153 ++-- .../ai_create_slack_message_context_test.py | 74 ++ .../commands/ai_run_rag_tool_test.py | 2 +- backend/tests/apps/ai/models/context_test.py | 87 +-- 35 files changed, 1904 insertions(+), 2258 deletions(-) create mode 100644 backend/apps/ai/common/base.py delete mode 100644 backend/apps/ai/common/extractors.py create mode 100644 backend/apps/ai/common/extractors/__init__.py create mode 100644 backend/apps/ai/common/extractors/chapter.py create mode 100644 backend/apps/ai/common/extractors/committee.py create mode 100644 backend/apps/ai/common/extractors/event.py create mode 100644 backend/apps/ai/common/extractors/project.py create mode 100644 backend/apps/ai/management/commands/ai_create_slack_message_context.py create mode 100644 backend/apps/ai/migrations/0008_alter_context_unique_together_and_more.py create mode 100644 backend/tests/apps/ai/common/base_test.py create mode 100644 backend/tests/apps/ai/management/commands/ai_create_slack_message_context_test.py diff --git a/backend/apps/ai/Makefile b/backend/apps/ai/Makefile index 1219c094c8..ca7125c79f 100644 --- a/backend/apps/ai/Makefile +++ b/backend/apps/ai/Makefile @@ -30,6 +30,10 @@ ai-create-project-chunks: @echo "Creating project chunks" @CMD="python manage.py ai_create_project_chunks" $(MAKE) exec-backend-command +ai-create-slack-message-context: + @echo "Creating Slack message context" + @CMD="python manage.py ai_create_slack_message_context" $(MAKE) exec-backend-command + ai-create-slack-message-chunks: @echo "Creating Slack message chunks" @CMD="python manage.py ai_create_slack_message_chunks" $(MAKE) exec-backend-command diff --git a/backend/apps/ai/common/base.py b/backend/apps/ai/common/base.py new file mode 100644 index 0000000000..45181eecfe --- /dev/null +++ b/backend/apps/ai/common/base.py @@ -0,0 +1,257 @@ +"""Base classes for AI management commands.""" + +import os +from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import Any + +import openai +from django.contrib.contenttypes.models import ContentType +from django.core.management.base import BaseCommand +from django.db.models import Model, QuerySet + +from apps.ai.common.utils import create_chunks_and_embeddings, create_context +from apps.ai.models.chunk import Chunk +from apps.ai.models.context import Context + + +class BaseAICommand(BaseCommand, ABC): + """Base class for AI management commands with common functionality.""" + + def __init__(self, *args, **kwargs): + """Initialize the AI command with OpenAI client placeholder.""" + super().__init__(*args, **kwargs) + self.openai_client: openai.OpenAI | None = None + + @property + @abstractmethod + def model_class(self) -> type[Model]: + """Return the Django model class this command operates on.""" + + @property + @abstractmethod + def entity_name(self) -> str: + """Return the human-readable name for the entity (e.g., 'chapter', 'project').""" + + @property + @abstractmethod + def entity_name_plural(self) -> str: + """Return the plural form of the entity name.""" + + @property + @abstractmethod + def key_field_name(self) -> str: + """Return the field name used for filtering by key (e.g., 'key', 'slug').""" + + @abstractmethod + def extract_content(self, entity: Model) -> tuple[str, str]: + """Extract content from the entity. Return (prose_content, metadata_content).""" + + @property + def source_name(self) -> str: + """Return the source name for context creation. Override if different from default.""" + return f"owasp_{self.entity_name}" + + def get_base_queryset(self) -> QuerySet: + """Return the base queryset. Override for custom filtering logic.""" + return self.model_class.objects.all() + + def get_default_queryset(self) -> QuerySet: + """Return the default queryset when no specific options are provided.""" + return self.get_base_queryset().filter(is_active=True) + + def add_common_arguments(self, parser): + """Add common arguments that most commands need.""" + parser.add_argument( + f"--{self.entity_name}-key", + type=str, + help=f"Process only the {self.entity_name} with this key", + ) + parser.add_argument( + "--all", + action="store_true", + help=f"Process all the {self.entity_name_plural}", + ) + parser.add_argument( + "--batch-size", + type=int, + default=50, + help=f"Number of {self.entity_name_plural} to process in each batch", + ) + + def add_arguments(self, parser): + """Add arguments to the command. Override to add custom arguments.""" + self.add_common_arguments(parser) + + def get_queryset(self, options: dict[str, Any]) -> QuerySet: + """Get the queryset based on command options.""" + key_option = f"{self.entity_name}_key" + + if options.get(key_option): + filter_kwargs = {self.key_field_name: options[key_option]} + return self.get_base_queryset().filter(**filter_kwargs) + if options.get("all"): + return self.get_base_queryset() + return self.get_default_queryset() + + def get_entity_key(self, entity: Model) -> str: + """Get the key/identifier for an entity for display purposes.""" + return str(getattr(entity, self.key_field_name, entity.pk)) + + def setup_openai_client(self) -> bool: + """Set up OpenAI client if API key is available.""" + if openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY"): + self.openai_client = openai.OpenAI(api_key=openai_api_key) + return True + self.stdout.write( + self.style.ERROR("DJANGO_OPEN_AI_SECRET_KEY environment variable not set") + ) + return False + + def handle_batch_processing( + self, + queryset: QuerySet, + batch_size: int, + process_batch_func: Callable[[list[Model]], int], + ) -> None: + """Handle the common batch processing logic.""" + total_count = queryset.count() + + if not total_count: + self.stdout.write(f"No {self.entity_name_plural} found to process") + return + + self.stdout.write(f"Found {total_count} {self.entity_name_plural} to process") + + processed_count = 0 + for offset in range(0, total_count, batch_size): + batch_items = queryset[offset : offset + batch_size] + processed_count += process_batch_func(list(batch_items)) + + self.stdout.write( + self.style.SUCCESS( + f"Completed processing {processed_count}/{total_count} {self.entity_name_plural}" + ) + ) + + +class BaseContextCommand(BaseAICommand): + """Base class for context creation commands.""" + + @property + def help(self) -> str: + """Return help text for the context creation command.""" + return f"Update context for OWASP {self.entity_name} data" + + def process_context_batch(self, entities: list[Model]) -> int: + """Process a batch of entities to create contexts.""" + processed = 0 + + for entity in entities: + prose_content, metadata_content = self.extract_content(entity) + full_content = ( + f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content + ) + + if not full_content.strip(): + entity_key = self.get_entity_key(entity) + self.stdout.write(f"No content for {self.entity_name} {entity_key}") + continue + + if create_context( + content=full_content, + content_object=entity, + source=self.source_name, + ): + processed += 1 + entity_key = self.get_entity_key(entity) + self.stdout.write(f"Created context for {entity_key}") + else: + entity_key = self.get_entity_key(entity) + self.stdout.write(self.style.ERROR(f"Failed to create context for {entity_key}")) + + return processed + + def handle(self, *args, **options): + """Handle the context creation command.""" + queryset = self.get_queryset(options) + batch_size = options["batch_size"] + + self.handle_batch_processing( + queryset=queryset, + batch_size=batch_size, + process_batch_func=self.process_context_batch, + ) + + +class BaseChunkCommand(BaseAICommand): + """Base class for chunk creation commands.""" + + @property + def help(self) -> str: + """Return help text for the chunk creation command.""" + return f"Create chunks for OWASP {self.entity_name} data" + + def process_chunks_batch(self, entities: list[Model]) -> int: + """Process a batch of entities to create chunks.""" + processed = 0 + batch_chunks = [] + content_type = ContentType.objects.get_for_model(self.model_class) + + for entity in entities: + context = Context.objects.filter( + content_type=content_type, object_id=entity.id + ).first() + + entity_key = self.get_entity_key(entity) + + if not context: + self.stdout.write( + self.style.WARNING(f"No context found for {self.entity_name} {entity_key}") + ) + continue + + prose_content, metadata_content = self.extract_content(entity) + full_content = ( + f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content + ) + + if not full_content.strip(): + self.stdout.write(f"No content to chunk for {self.entity_name} {entity_key}") + continue + + chunk_texts = Chunk.split_text(full_content) + if not chunk_texts: + self.stdout.write( + f"No chunks created for {self.entity_name} {entity_key}: `{full_content}`" + ) + continue + + if chunks := create_chunks_and_embeddings( + chunk_texts=chunk_texts, + context=context, + openai_client=self.openai_client, + save=False, + ): + batch_chunks.extend(chunks) + processed += 1 + self.stdout.write(f"Created {len(chunks)} chunks for {entity_key}") + + if batch_chunks: + Chunk.bulk_save(batch_chunks) + + return processed + + def handle(self, *args, **options): + """Handle the chunk creation command.""" + if not self.setup_openai_client(): + return + + queryset = self.get_queryset(options) + batch_size = options["batch_size"] + + self.handle_batch_processing( + queryset=queryset, + batch_size=batch_size, + process_batch_func=self.process_chunks_batch, + ) diff --git a/backend/apps/ai/common/constants.py b/backend/apps/ai/common/constants.py index 207b53599c..cce67fc739 100644 --- a/backend/apps/ai/common/constants.py +++ b/backend/apps/ai/common/constants.py @@ -2,6 +2,6 @@ DEFAULT_LAST_REQUEST_OFFSET_SECONDS = 2 DEFAULT_CHUNKS_RETRIEVAL_LIMIT = 5 -DEFAULT_SIMILARITY_THRESHOLD = 0.5 +DEFAULT_SIMILARITY_THRESHOLD = 0.4 DELIMITER = "\n\n" MIN_REQUEST_INTERVAL_SECONDS = 1.2 diff --git a/backend/apps/ai/common/extractors.py b/backend/apps/ai/common/extractors.py deleted file mode 100644 index 3dcf5123b4..0000000000 --- a/backend/apps/ai/common/extractors.py +++ /dev/null @@ -1,269 +0,0 @@ -"""Content extractors for various models.""" - -from apps.ai.common.constants import DELIMITER - - -def extract_committee_content(committee) -> tuple[str, str]: - """Extract structured content from committee data. - - Args: - committee: Committee instance - - Returns: - tuple[str, str]: (prose_content, metadata_content) - - """ - prose_parts = [] - metadata_parts = [] - - if committee.description: - prose_parts.append(f"Description: {committee.description}") - - if committee.summary: - prose_parts.append(f"Summary: {committee.summary}") - - if hasattr(committee, "owasp_repository") and committee.owasp_repository: - repo = committee.owasp_repository - if repo.description: - prose_parts.append(f"Repository Description: {repo.description}") - if repo.topics: - metadata_parts.append(f"Repository Topics: {', '.join(repo.topics)}") - - if committee.name: - metadata_parts.append(f"Committee Name: {committee.name}") - - if committee.tags: - metadata_parts.append(f"Tags: {', '.join(committee.tags)}") - - if committee.topics: - metadata_parts.append(f"Topics: {', '.join(committee.topics)}") - - if committee.leaders_raw: - metadata_parts.append(f"Committee Leaders: {', '.join(committee.leaders_raw)}") - - if committee.related_urls: - invalid_urls = getattr(committee, "invalid_urls", []) or [] - valid_urls = [url for url in committee.related_urls if url and url not in invalid_urls] - if valid_urls: - metadata_parts.append(f"Related URLs: {', '.join(valid_urls)}") - - metadata_parts.append(f"Active Committee: {'Yes' if committee.is_active else 'No'}") - - return ( - DELIMITER.join(filter(None, prose_parts)), - DELIMITER.join(filter(None, metadata_parts)), - ) - - -def extract_chapter_content(chapter) -> tuple[str, str]: - """Extract structured content from chapter data. - - Args: - chapter: Chapter instance - - Returns: - tuple[str, str]: (prose_content, metadata_content) - - """ - prose_parts = [] - metadata_parts = [] - - if chapter.description: - prose_parts.append(f"Description: {chapter.description}") - - if chapter.summary: - prose_parts.append(f"Summary: {chapter.summary}") - - if hasattr(chapter, "owasp_repository") and chapter.owasp_repository: - repo = chapter.owasp_repository - if repo.description: - prose_parts.append(f"Repository Description: {repo.description}") - if repo.topics: - metadata_parts.append(f"Repository Topics: {', '.join(repo.topics)}") - - if chapter.name: - metadata_parts.append(f"Chapter Name: {chapter.name}") - - location_parts = [] - if chapter.country: - location_parts.append(f"Country: {chapter.country}") - if chapter.region: - location_parts.append(f"Region: {chapter.region}") - if chapter.postal_code: - location_parts.append(f"Postal Code: {chapter.postal_code}") - if chapter.suggested_location: - location_parts.append(f"Location: {chapter.suggested_location}") - - if location_parts: - metadata_parts.append(f"Location Information: {', '.join(location_parts)}") - - if chapter.currency: - metadata_parts.append(f"Currency: {chapter.currency}") - - if chapter.meetup_group: - metadata_parts.append(f"Meetup Group: {chapter.meetup_group}") - - if chapter.tags: - metadata_parts.append(f"Tags: {', '.join(chapter.tags)}") - - if chapter.topics: - metadata_parts.append(f"Topics: {', '.join(chapter.topics)}") - - if chapter.leaders_raw: - metadata_parts.append(f"Chapter Leaders: {', '.join(chapter.leaders_raw)}") - - if chapter.related_urls: - invalid_urls = getattr(chapter, "invalid_urls", []) or [] - valid_urls = [url for url in chapter.related_urls if url and url not in invalid_urls] - - if valid_urls: - metadata_parts.append(f"Related URLs: {', '.join(valid_urls)}") - - metadata_parts.append(f"Active Chapter: {'Yes' if chapter.is_active else 'No'}") - - return ( - DELIMITER.join(filter(None, prose_parts)), - DELIMITER.join(filter(None, metadata_parts)), - ) - - -def extract_event_content(event) -> tuple[str, str]: - """Extract structured content from event data. - - Args: - event: Event instance - - Returns: - tuple[str, str]: (prose_content, metadata_content) - - """ - prose_parts = [] - metadata_parts = [] - - if event.description: - prose_parts.append(f"Description: {event.description}") - - if event.summary: - prose_parts.append(f"Summary: {event.summary}") - - if event.name: - metadata_parts.append(f"Event Name: {event.name}") - - if event.category: - metadata_parts.append(f"Category: {event.get_category_display()}") - - if event.start_date: - metadata_parts.append(f"Start Date: {event.start_date}") - - if event.end_date: - metadata_parts.append(f"End Date: {event.end_date}") - - if event.suggested_location: - metadata_parts.append(f"Location: {event.suggested_location}") - - if event.latitude is not None and event.longitude is not None: - metadata_parts.append(f"Coordinates: {event.latitude}, {event.longitude}") - - if event.url: - metadata_parts.append(f"Event URL: {event.url}") - - return ( - DELIMITER.join(filter(None, prose_parts)), - DELIMITER.join(filter(None, metadata_parts)), - ) - - -def extract_project_content(project) -> tuple[str, str]: - """Extract structured content from project data. - - Args: - project: Project instance - - Returns: - tuple[str, str]: (prose_content, metadata_content) - - """ - prose_parts = [] - metadata_parts = [] - - if project.description: - prose_parts.append(f"Description: {project.description}") - - if project.summary: - prose_parts.append(f"Summary: {project.summary}") - - if hasattr(project, "owasp_repository") and project.owasp_repository: - repo = project.owasp_repository - if repo.description: - prose_parts.append(f"Repository Description: {repo.description}") - if repo.topics: - metadata_parts.append(f"Repository Topics: {', '.join(repo.topics)}") - - if project.name: - metadata_parts.append(f"Project Name: {project.name}") - - if project.level: - metadata_parts.append(f"Project Level: {project.level}") - - if project.type: - metadata_parts.append(f"Project Type: {project.type}") - - if project.languages: - metadata_parts.append(f"Programming Languages: {', '.join(project.languages)}") - - if project.topics: - metadata_parts.append(f"Topics: {', '.join(project.topics)}") - - if project.licenses: - metadata_parts.append(f"Licenses: {', '.join(project.licenses)}") - - if project.tags: - metadata_parts.append(f"Tags: {', '.join(project.tags)}") - - if project.custom_tags: - metadata_parts.append(f"Custom Tags: {', '.join(project.custom_tags)}") - - stats_parts = [] - if (project.stars_count or 0) > 0: - stats_parts.append(f"Stars: {project.stars_count}") - if (project.forks_count or 0) > 0: - stats_parts.append(f"Forks: {project.forks_count}") - if (project.contributors_count or 0) > 0: - stats_parts.append(f"Contributors: {project.contributors_count}") - if (project.releases_count or 0) > 0: - stats_parts.append(f"Releases: {project.releases_count}") - if (project.open_issues_count or 0) > 0: - stats_parts.append(f"Open Issues: {project.open_issues_count}") - - if stats_parts: - metadata_parts.append("Project Statistics: " + ", ".join(stats_parts)) - - if project.leaders_raw: - metadata_parts.append(f"Project Leaders: {', '.join(project.leaders_raw)}") - - if project.related_urls: - invalid_urls = getattr(project, "invalid_urls", []) or [] - valid_urls = [url for url in project.related_urls if url and url not in invalid_urls] - if valid_urls: - metadata_parts.append(f"Related URLs: {', '.join(valid_urls)}") - - if project.created_at: - metadata_parts.append(f"Created: {project.created_at.strftime('%Y-%m-%d')}") - - if project.updated_at: - metadata_parts.append(f"Last Updated: {project.updated_at.strftime('%Y-%m-%d')}") - - if project.released_at: - metadata_parts.append(f"Last Release: {project.released_at.strftime('%Y-%m-%d')}") - - if project.health_score is not None: - metadata_parts.append(f"Health Score: {project.health_score:.2f}") - - metadata_parts.append(f"Active Project: {'Yes' if project.is_active else 'No'}") - - metadata_parts.append(f"Issue Tracking: {'Enabled' if project.track_issues else 'Disabled'}") - - return ( - DELIMITER.join(filter(None, prose_parts)), - DELIMITER.join(filter(None, metadata_parts)), - ) diff --git a/backend/apps/ai/common/extractors/__init__.py b/backend/apps/ai/common/extractors/__init__.py new file mode 100644 index 0000000000..8a8422e9cc --- /dev/null +++ b/backend/apps/ai/common/extractors/__init__.py @@ -0,0 +1 @@ +"""Common extractors for AI content processing.""" diff --git a/backend/apps/ai/common/extractors/chapter.py b/backend/apps/ai/common/extractors/chapter.py new file mode 100644 index 0000000000..ef4964f98a --- /dev/null +++ b/backend/apps/ai/common/extractors/chapter.py @@ -0,0 +1,80 @@ +"""Content extractor for Chapter.""" + +from apps.ai.common.constants import DELIMITER + + +def extract_chapter_content(chapter) -> tuple[str, str]: + """Extract structured content from chapter data. + + Args: + chapter: Chapter instance + + Returns: + tuple[str, str]: (prose_content, metadata_content) + + """ + prose_parts = [] + metadata_parts = [] + + if chapter.description: + prose_parts.append(f"Description: {chapter.description}") + + if chapter.summary: + prose_parts.append(f"Summary: {chapter.summary}") + + if chapter.owasp_repository: + repo = chapter.owasp_repository + if repo.description: + prose_parts.append(f"Repository Description: {repo.description}") + if repo.topics and hasattr(repo.topics, "__iter__") and not isinstance(repo.topics, str): + try: + metadata_parts.append(f"Repository Topics: {', '.join(repo.topics)}") + except TypeError: + # If topics is not iterable, convert to string + metadata_parts.append(f"Repository Topics: {repo.topics}") + + if chapter.name: + metadata_parts.append(f"Chapter Name: {chapter.name}") + + location_parts = [] + if chapter.country: + location_parts.append(f"Country: {chapter.country}") + if chapter.region: + location_parts.append(f"Region: {chapter.region}") + if chapter.postal_code: + location_parts.append(f"Postal Code: {chapter.postal_code}") + if chapter.suggested_location: + location_parts.append(f"Location: {chapter.suggested_location}") + + if location_parts: + metadata_parts.append(f"Location Information: {', '.join(location_parts)}") + + if chapter.currency: + metadata_parts.append(f"Currency: {chapter.currency}") + + if chapter.meetup_group: + metadata_parts.append(f"Meetup Group: {chapter.meetup_group}") + + if chapter.tags: + metadata_parts.append(f"Tags: {', '.join(chapter.tags)}") + + if chapter.topics: + metadata_parts.append(f"Topics: {', '.join(chapter.topics)}") + + if chapter.leaders_raw: + metadata_parts.append(f"Chapter Leaders: {', '.join(chapter.leaders_raw)}") + + if chapter.related_urls: + invalid_urls = getattr(chapter, "invalid_urls", []) or [] + valid_urls = [url for url in chapter.related_urls if url and url not in invalid_urls] + + if valid_urls: + metadata_parts.append(f"Related URLs: {', '.join(valid_urls)}") + + if chapter.is_active: + metadata_parts.append("Active Chapter: Yes") + + return ( + DELIMITER.join(filter(None, prose_parts)), + DELIMITER.join(filter(None, metadata_parts)), + ) diff --git a/backend/apps/ai/common/extractors/committee.py b/backend/apps/ai/common/extractors/committee.py new file mode 100644 index 0000000000..e3603b6dbb --- /dev/null +++ b/backend/apps/ai/common/extractors/committee.py @@ -0,0 +1,55 @@ +"""Context extractor for Committee.""" + +from apps.ai.common.constants import DELIMITER + + +def extract_committee_content(committee) -> tuple[str, str]: + """Extract structured content from committee data. + + Args: + committee: Committee instance + + Returns: + tuple[str, str]: (prose_content, metadata_content) + + """ + prose_parts = [] + metadata_parts = [] + + if committee.description: + prose_parts.append(f"Description: {committee.description}") + + if committee.summary: + prose_parts.append(f"Summary: {committee.summary}") + + if committee.owasp_repository: + repo = committee.owasp_repository + if repo.description: + prose_parts.append(f"Repository Description: {repo.description}") + if repo.topics: + metadata_parts.append(f"Repository Topics: {', '.join(repo.topics)}") + + if committee.name: + metadata_parts.append(f"Committee Name: {committee.name}") + + if committee.tags: + metadata_parts.append(f"Tags: {', '.join(committee.tags)}") + + if committee.topics: + metadata_parts.append(f"Topics: {', '.join(committee.topics)}") + + if committee.leaders_raw: + metadata_parts.append(f"Committee Leaders: {', '.join(committee.leaders_raw)}") + + if committee.related_urls: + invalid_urls = getattr(committee, "invalid_urls", []) or [] + valid_urls = [url for url in committee.related_urls if url and url not in invalid_urls] + if valid_urls: + metadata_parts.append(f"Related URLs: {', '.join(valid_urls)}") + + metadata_parts.append(f"Active Committee: {'Yes' if committee.is_active else 'No'}") + + return ( + DELIMITER.join(filter(None, prose_parts)), + DELIMITER.join(filter(None, metadata_parts)), + ) diff --git a/backend/apps/ai/common/extractors/event.py b/backend/apps/ai/common/extractors/event.py new file mode 100644 index 0000000000..851a10439e --- /dev/null +++ b/backend/apps/ai/common/extractors/event.py @@ -0,0 +1,49 @@ +"""Content extractor for Event.""" + +from apps.ai.common.constants import DELIMITER + + +def extract_event_content(event) -> tuple[str, str]: + """Extract structured content from event data. + + Args: + event: Event instance + + Returns: + tuple[str, str]: (prose_content, metadata_content) + + """ + prose_parts = [] + metadata_parts = [] + + if event.description: + prose_parts.append(f"Description: {event.description}") + + if event.summary: + prose_parts.append(f"Summary: {event.summary}") + + if event.name: + metadata_parts.append(f"Event Name: {event.name}") + + if event.category: + metadata_parts.append(f"Category: {event.get_category_display()}") + + if event.start_date: + metadata_parts.append(f"Start Date: {event.start_date}") + + if event.end_date: + metadata_parts.append(f"End Date: {event.end_date}") + + if event.suggested_location: + metadata_parts.append(f"Location: {event.suggested_location}") + + if event.latitude is not None and event.longitude is not None: + metadata_parts.append(f"Coordinates: {event.latitude}, {event.longitude}") + + if event.url: + metadata_parts.append(f"Event URL: {event.url}") + + return ( + DELIMITER.join(filter(None, prose_parts)), + DELIMITER.join(filter(None, metadata_parts)), + ) diff --git a/backend/apps/ai/common/extractors/project.py b/backend/apps/ai/common/extractors/project.py new file mode 100644 index 0000000000..c6b8ee6209 --- /dev/null +++ b/backend/apps/ai/common/extractors/project.py @@ -0,0 +1,97 @@ +"""Content extractor for Project.""" + +from apps.ai.common.constants import DELIMITER + + +def extract_project_content(project) -> tuple[str, str]: + """Extract structured content from project data. + + Args: + project: Project instance + + Returns: + tuple[str, str]: (prose_content, metadata_content) + + """ + prose_parts = [] + metadata_parts = [] + + if project.description: + prose_parts.append(f"Description: {project.description}") + + if project.summary: + prose_parts.append(f"Summary: {project.summary}") + + if project.owasp_repository: + repo = project.owasp_repository + if repo.description: + prose_parts.append(f"Repository Description: {repo.description}") + if repo.topics: + metadata_parts.append(f"Repository Topics: {', '.join(repo.topics)}") + + if project.name: + metadata_parts.append(f"Project Name: {project.name}") + + if project.level: + metadata_parts.append(f"Project Level: {project.level}") + + if project.type: + metadata_parts.append(f"Project Type: {project.type}") + + if project.languages: + metadata_parts.append(f"Programming Languages: {', '.join(project.languages)}") + + if project.topics: + metadata_parts.append(f"Topics: {', '.join(project.topics)}") + + if project.licenses: + metadata_parts.append(f"Licenses: {', '.join(project.licenses)}") + + if project.tags: + metadata_parts.append(f"Tags: {', '.join(project.tags)}") + + if project.custom_tags: + metadata_parts.append(f"Custom Tags: {', '.join(project.custom_tags)}") + + stats_parts = [] + if project.stars_count: + stats_parts.append(f"Stars: {project.stars_count}") + if project.forks_count: + stats_parts.append(f"Forks: {project.forks_count}") + if project.contributors_count: + stats_parts.append(f"Contributors: {project.contributors_count}") + if project.releases_count: + stats_parts.append(f"Releases: {project.releases_count}") + if project.open_issues_count: + stats_parts.append(f"Open Issues: {project.open_issues_count}") + + if stats_parts: + metadata_parts.append("Project Statistics: " + ", ".join(stats_parts)) + + if project.leaders_raw: + metadata_parts.append(f"Project Leaders: {', '.join(project.leaders_raw)}") + + if project.related_urls: + invalid_urls = getattr(project, "invalid_urls", []) or [] + valid_urls = [url for url in project.related_urls if url and url not in invalid_urls] + if valid_urls: + metadata_parts.append(f"Related URLs: {', '.join(valid_urls)}") + + if project.created_at: + metadata_parts.append(f"Created: {project.created_at.strftime('%Y-%m-%d')}") + + if project.updated_at: + metadata_parts.append(f"Last Updated: {project.updated_at.strftime('%Y-%m-%d')}") + + if project.released_at: + metadata_parts.append(f"Last Release: {project.released_at.strftime('%Y-%m-%d')}") + + if project.health_score is not None: + metadata_parts.append(f"Health Score: {project.health_score:.2f}") + + metadata_parts.append(f"Active Project: {'Yes' if project.is_active else 'No'}") + + return ( + DELIMITER.join(filter(None, prose_parts)), + DELIMITER.join(filter(None, metadata_parts)), + ) diff --git a/backend/apps/ai/common/utils.py b/backend/apps/ai/common/utils.py index 744558865e..065e24eb5e 100644 --- a/backend/apps/ai/common/utils.py +++ b/backend/apps/ai/common/utils.py @@ -5,7 +5,6 @@ from datetime import UTC, datetime, timedelta import openai -from django.contrib.contenttypes.models import ContentType from apps.ai.common.constants import ( DEFAULT_LAST_REQUEST_OFFSET_SECONDS, @@ -29,17 +28,7 @@ def create_context(content: str, content_object=None, source: str = "") -> Conte Context: Created Context instance """ - context = Context.update_data(content=content, content_object=content_object, source=source) - if context is None: - if content_object: - content_type = ContentType.objects.get_for_model(content_object) - context = Context.objects.get( - content_type=content_type, object_id=content_object.pk, content=content - ) - else: - context = Context.objects.get(content=content, content_object__isnull=True) - - return context + return Context.update_data(content=content, content_object=content_object, source=source) def create_chunks_and_embeddings( diff --git a/backend/apps/ai/management/commands/ai_create_chapter_chunks.py b/backend/apps/ai/management/commands/ai_create_chapter_chunks.py index e5cf45631d..5ae5caa663 100644 --- a/backend/apps/ai/management/commands/ai_create_chapter_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_chapter_chunks.py @@ -1,114 +1,29 @@ """A command to create chunks of OWASP chapter data for RAG.""" -import os +from django.db.models import Model -import openai -from django.contrib.contenttypes.models import ContentType -from django.core.management.base import BaseCommand - -from apps.ai.common.extractors import extract_chapter_content -from apps.ai.common.utils import create_chunks_and_embeddings -from apps.ai.models.chunk import Chunk -from apps.ai.models.context import Context +from apps.ai.common.base import BaseChunkCommand +from apps.ai.common.extractors.chapter import extract_chapter_content from apps.owasp.models.chapter import Chapter -class Command(BaseCommand): - help = "Create chunks for OWASP chapter data" - - def add_arguments(self, parser): - parser.add_argument( - "--chapter-key", - type=str, - help="Process only the chapter with this key", - ) - parser.add_argument( - "--all", - action="store_true", - help="Process all the chapters", - ) - parser.add_argument( - "--batch-size", - type=int, - default=50, - help="Number of chapters to process in each batch", - ) - - def handle(self, *args, **options): - if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): - self.stdout.write( - self.style.ERROR("DJANGO_OPEN_AI_SECRET_KEY environment variable not set") - ) - return - - self.openai_client = openai.OpenAI(api_key=openai_api_key) - - if options["chapter_key"]: - queryset = Chapter.objects.filter(key=options["chapter_key"]) - elif options["all"]: - queryset = Chapter.objects.all() - else: - queryset = Chapter.objects.filter(is_active=True) - - if not (total_chapters := queryset.count()): - self.stdout.write("No chapters found to process") - return - - self.stdout.write(f"Found {total_chapters} chapters to process") - - batch_size = options["batch_size"] - processed_count = 0 - - for offset in range(0, total_chapters, batch_size): - batch_chapters = queryset[offset : offset + batch_size] - processed_count += self.process_chunks_batch(batch_chapters) - - self.stdout.write( - self.style.SUCCESS(f"Completed processing {processed_count}/{total_chapters} chapters") - ) - - def process_chunks_batch(self, chapters: list[Chapter]) -> int: - """Process a batch of chapters to create chunks.""" - processed = 0 - batch_chunks = [] - - chapter_content_type = ContentType.objects.get_for_model(Chapter) - - for chapter in chapters: - context = Context.objects.filter( - content_type=chapter_content_type, object_id=chapter.id - ).first() - - if not context: - self.stdout.write( - self.style.WARNING(f"No context found for chapter {chapter.key}") - ) - continue - - prose_content, metadata_content = extract_chapter_content(chapter) - all_chunk_texts = [] - - if metadata_content.strip(): - all_chunk_texts.append(metadata_content) +class Command(BaseChunkCommand): + @property + def model_class(self) -> type[Model]: + return Chapter - if prose_content.strip(): - prose_chunks = Chunk.split_text(prose_content) - all_chunk_texts.extend(prose_chunks) + @property + def entity_name(self) -> str: + return "chapter" - if not all_chunk_texts: - self.stdout.write(f"No content to chunk for chapter {chapter.key}") - continue + @property + def entity_name_plural(self) -> str: + return "chapters" - if chunks := create_chunks_and_embeddings( - chunk_texts=all_chunk_texts, - context=context, - openai_client=self.openai_client, - save=False, - ): - batch_chunks.extend(chunks) - processed += 1 - self.stdout.write(f"Created {len(chunks)} chunks for {chapter.key}") + @property + def key_field_name(self) -> str: + return "key" - if batch_chunks: - Chunk.bulk_save(batch_chunks) - return processed + def extract_content(self, entity: Chapter) -> tuple[str, str]: + """Extract content from the chapter.""" + return extract_chapter_content(entity) diff --git a/backend/apps/ai/management/commands/ai_create_chapter_context.py b/backend/apps/ai/management/commands/ai_create_chapter_context.py index 1d5a37f434..46b13509fa 100644 --- a/backend/apps/ai/management/commands/ai_create_chapter_context.py +++ b/backend/apps/ai/management/commands/ai_create_chapter_context.py @@ -1,77 +1,29 @@ """A command to update context for OWASP chapter data.""" -from django.core.management.base import BaseCommand +from django.db.models import Model -from apps.ai.common.extractors import extract_chapter_content -from apps.ai.common.utils import create_context +from apps.ai.common.base import BaseContextCommand +from apps.ai.common.extractors.chapter import extract_chapter_content from apps.owasp.models.chapter import Chapter -class Command(BaseCommand): - help = "Update context for OWASP chapter data" +class Command(BaseContextCommand): + @property + def model_class(self) -> type[Model]: + return Chapter - def add_arguments(self, parser): - parser.add_argument( - "--chapter-key", - type=str, - help="Process only the chapter with this key", - ) - parser.add_argument( - "--all", - action="store_true", - help="Process all the chapters", - ) - parser.add_argument( - "--batch-size", - type=int, - default=50, - help="Number of chapters to process in each batch", - ) + @property + def entity_name(self) -> str: + return "chapter" - def handle(self, *args, **options): - if options["chapter_key"]: - queryset = Chapter.objects.filter(key=options["chapter_key"]) - elif options["all"]: - queryset = Chapter.objects.all() - else: - queryset = Chapter.objects.filter(is_active=True) + @property + def entity_name_plural(self) -> str: + return "chapters" - if not (total_chapters := queryset.count()): - self.stdout.write("No chapters found to process") - return + @property + def key_field_name(self) -> str: + return "key" - self.stdout.write(f"Found {total_chapters} chapters to process") - - batch_size = options["batch_size"] - processed_count = 0 - - for offset in range(0, total_chapters, batch_size): - batch_chapters = queryset[offset : offset + batch_size] - processed_count += self.process_context_batch(batch_chapters) - - self.stdout.write( - self.style.SUCCESS(f"Completed processing {processed_count}/{total_chapters} chapters") - ) - - def process_context_batch(self, chapters: list[Chapter]) -> int: - """Process a batch of chapters to create contexts.""" - processed = 0 - - for chapter in chapters: - prose_content, metadata_content = extract_chapter_content(chapter) - full_content = ( - f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content - ) - - if not full_content.strip(): - self.stdout.write(f"No content for chapter {chapter.key}") - continue - - if create_context( - content=full_content, content_object=chapter, source="owasp_chapter" - ): - processed += 1 - self.stdout.write(f"Created context for {chapter.key}") - else: - self.stdout.write(self.style.ERROR(f"Failed to create context for {chapter.key}")) - return processed + def extract_content(self, entity: Chapter) -> tuple[str, str]: + """Extract content from the chapter.""" + return extract_chapter_content(entity) diff --git a/backend/apps/ai/management/commands/ai_create_committee_chunks.py b/backend/apps/ai/management/commands/ai_create_committee_chunks.py index d18b8d823b..23bd51d552 100644 --- a/backend/apps/ai/management/commands/ai_create_committee_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_committee_chunks.py @@ -1,121 +1,29 @@ """A command to create chunks of OWASP committee data for RAG.""" -import os +from django.db.models import Model -import openai -from django.contrib.contenttypes.models import ContentType -from django.core.management.base import BaseCommand - -from apps.ai.common.extractors import extract_committee_content -from apps.ai.common.utils import create_chunks_and_embeddings -from apps.ai.models.chunk import Chunk -from apps.ai.models.context import Context +from apps.ai.common.base import BaseChunkCommand +from apps.ai.common.extractors.committee import extract_committee_content from apps.owasp.models.committee import Committee -class Command(BaseCommand): - help = "Create chunks for OWASP committee data" - - def add_arguments(self, parser): - parser.add_argument( - "--committee-key", - type=str, - help="Process only the committee with this key", - ) - parser.add_argument( - "--all", - action="store_true", - help="Process all the committees", - ) - parser.add_argument( - "--batch-size", - type=int, - default=50, - help="Number of committees to process in each batch", - ) - - def handle(self, *args, **options): - if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): - self.stdout.write( - self.style.ERROR("DJANGO_OPEN_AI_SECRET_KEY environment variable not set") - ) - return - - self.openai_client = openai.OpenAI(api_key=openai_api_key) - - if options["committee_key"]: - queryset = Committee.objects.filter(key=options["committee_key"]) - elif options["all"]: - queryset = Committee.objects.all() - else: - queryset = Committee.objects.filter(is_active=True) - - if not (total_committees := queryset.count()): - self.stdout.write("No committees found to process") - return - - self.stdout.write(f"Found {total_committees} committees to process") - - batch_size = options["batch_size"] - processed_count = 0 - - for offset in range(0, total_committees, batch_size): - batch_committees = queryset[offset : offset + batch_size] - processed_count += self.process_chunks_batch(batch_committees) - - self.stdout.write( - self.style.SUCCESS( - f"Completed processing {processed_count}/{total_committees} committees" - ) - ) - - def process_chunks_batch(self, committees: list[Committee]) -> int: - """Process a batch of committees to create chunks.""" - processed = 0 - batch_chunks = [] - - committee_content_type = ContentType.objects.get_for_model(Committee) - committee_ids = [c.id for c in committees] - contexts_map = { - ctx.object_id: ctx - for ctx in Context.objects.filter( - content_type=committee_content_type, object_id__in=committee_ids - ) - } - - for committee in committees: - context = contexts_map.get(committee.id) - - if not context: - self.stdout.write( - self.style.WARNING(f"No context found for committee {committee.key}") - ) - continue - - prose_content, metadata_content = extract_committee_content(committee) - all_chunk_texts = [] - - if metadata_content.strip(): - all_chunk_texts.append(metadata_content) +class Command(BaseChunkCommand): + @property + def model_class(self) -> type[Model]: + return Committee - if prose_content.strip(): - prose_chunks = Chunk.split_text(prose_content) - all_chunk_texts.extend(prose_chunks) + @property + def entity_name(self) -> str: + return "committee" - if not all_chunk_texts: - self.stdout.write(f"No content to chunk for committee {committee.key}") - continue + @property + def entity_name_plural(self) -> str: + return "committees" - if chunks := create_chunks_and_embeddings( - chunk_texts=all_chunk_texts, - context=context, - openai_client=self.openai_client, - save=False, - ): - batch_chunks.extend(chunks) - processed += 1 - self.stdout.write(f"Created {len(chunks)} chunks for {committee.key}") + @property + def key_field_name(self) -> str: + return "key" - if batch_chunks: - Chunk.bulk_save(batch_chunks) - return processed + def extract_content(self, entity: Committee) -> tuple[str, str]: + """Extract content from the committee.""" + return extract_committee_content(entity) diff --git a/backend/apps/ai/management/commands/ai_create_committee_context.py b/backend/apps/ai/management/commands/ai_create_committee_context.py index 2802846b74..4a17b58dd6 100644 --- a/backend/apps/ai/management/commands/ai_create_committee_context.py +++ b/backend/apps/ai/management/commands/ai_create_committee_context.py @@ -1,81 +1,29 @@ """A command to update context for OWASP committee data.""" -from django.core.management.base import BaseCommand +from django.db.models import Model -from apps.ai.common.extractors import extract_committee_content -from apps.ai.common.utils import create_context +from apps.ai.common.base import BaseContextCommand +from apps.ai.common.extractors.committee import extract_committee_content from apps.owasp.models.committee import Committee -class Command(BaseCommand): - help = "Update context for OWASP committee data" +class Command(BaseContextCommand): + @property + def model_class(self) -> type[Model]: + return Committee - def add_arguments(self, parser): - parser.add_argument( - "--committee-key", - type=str, - help="Process only the committee with this key", - ) - parser.add_argument( - "--all", - action="store_true", - help="Process all the committees", - ) - parser.add_argument( - "--batch-size", - type=int, - default=50, - help="Number of committees to process in each batch", - ) + @property + def entity_name(self) -> str: + return "committee" - def handle(self, *args, **options): - if options["committee_key"]: - queryset = Committee.objects.filter(key=options["committee_key"]) - elif options["all"]: - queryset = Committee.objects.all() - else: - queryset = Committee.objects.filter(is_active=True) + @property + def entity_name_plural(self) -> str: + return "committees" - if not (total_committees := queryset.count()): - self.stdout.write("No committees found to process") - return + @property + def key_field_name(self) -> str: + return "key" - self.stdout.write(f"Found {total_committees} committees to process") - - batch_size = options["batch_size"] - processed_count = 0 - - for offset in range(0, total_committees, batch_size): - batch_committees = queryset[offset : offset + batch_size] - processed_count += self.process_context_batch(batch_committees) - - self.stdout.write( - self.style.SUCCESS( - f"Completed processing {processed_count}/{total_committees} committees" - ) - ) - - def process_context_batch(self, committees: list[Committee]) -> int: - """Process a batch of committees to create contexts.""" - processed = 0 - - for committee in committees: - prose_content, metadata_content = extract_committee_content(committee) - full_content = ( - f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content - ) - - if not full_content.strip(): - self.stdout.write(f"No content for committee {committee.key}") - continue - - if create_context( - content=full_content, content_object=committee, source="owasp_committee" - ): - processed += 1 - self.stdout.write(f"Created context for {committee.key}") - else: - self.stdout.write( - self.style.ERROR(f"Failed to create context for {committee.key}") - ) - return processed + def extract_content(self, entity: Committee) -> tuple[str, str]: + """Extract content from the committee.""" + return extract_committee_content(entity) diff --git a/backend/apps/ai/management/commands/ai_create_event_chunks.py b/backend/apps/ai/management/commands/ai_create_event_chunks.py index 31315314c5..a35c0fd97f 100644 --- a/backend/apps/ai/management/commands/ai_create_event_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_event_chunks.py @@ -1,117 +1,37 @@ """A command to create chunks of OWASP event data for RAG.""" -import os +from django.db.models import Model, QuerySet -import openai -from django.contrib.contenttypes.models import ContentType -from django.core.management.base import BaseCommand - -from apps.ai.common.extractors import extract_event_content -from apps.ai.common.utils import create_chunks_and_embeddings -from apps.ai.models.chunk import Chunk -from apps.ai.models.context import Context +from apps.ai.common.base import BaseChunkCommand +from apps.ai.common.extractors.event import extract_event_content from apps.owasp.models.event import Event -class Command(BaseCommand): - help = "Create chunks for OWASP event data" - - def add_arguments(self, parser): - parser.add_argument( - "--event-key", - type=str, - help="Process only the event with this key", - ) - parser.add_argument( - "--all", - action="store_true", - help="Process all the events", - ) - parser.add_argument( - "--batch-size", - type=int, - default=50, - help="Number of events to process in each batch", - ) - - def handle(self, *args, **options): - if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): - self.stdout.write( - self.style.ERROR("DJANGO_OPEN_AI_SECRET_KEY environment variable not set") - ) - return - - self.openai_client = openai.OpenAI(api_key=openai_api_key) - - if options["event_key"]: - queryset = Event.objects.filter(key=options["event_key"]) - elif options["all"]: - queryset = Event.objects.all() - else: - queryset = Event.upcoming_events() - - if not (total_events := queryset.count()): - self.stdout.write("No events found to process") - return - - self.stdout.write(f"Found {total_events} events to process") - - batch_size = options["batch_size"] - processed_count = 0 - - for offset in range(0, total_events, batch_size): - batch_events = queryset[offset : offset + batch_size] - processed_count += self.process_chunks_batch(batch_events) - - self.stdout.write( - self.style.SUCCESS(f"Completed processing {processed_count}/{total_events} events") - ) - - def process_chunks_batch(self, events: list[Event]) -> int: - """Process a batch of events to create chunks.""" - processed = 0 - batch_chunks = [] - - event_content_type = ContentType.objects.get_for_model(Event) - event_ids = [e.id for e in events] - contexts_by_id = { - c.object_id: c - for c in Context.objects.filter( - content_type=event_content_type, object_id__in=event_ids - ) - } - - for event in events: - context = contexts_by_id.get(event.id) - - if not context: - self.stdout.write(self.style.WARNING(f"No context found for event {event.key}")) - continue +class Command(BaseChunkCommand): + @property + def model_class(self) -> type[Model]: + return Event - prose_content, metadata_content = extract_event_content(event) - all_chunk_texts = [] + @property + def entity_name(self) -> str: + return "event" - if metadata_content.strip(): - all_chunk_texts.append(metadata_content) + @property + def entity_name_plural(self) -> str: + return "events" - if prose_content.strip(): - prose_chunks = Chunk.split_text(prose_content) - all_chunk_texts.extend(prose_chunks) + @property + def key_field_name(self) -> str: + return "key" - if not all_chunk_texts: - self.stdout.write(f"No content to chunk for event {event.key}") - continue + def get_default_queryset(self) -> QuerySet: + """Return upcoming events by default instead of is_active filter.""" + return Event.upcoming_events() - if chunks := create_chunks_and_embeddings( - chunk_texts=all_chunk_texts, - context=context, - openai_client=self.openai_client, - save=False, - ): - batch_chunks.extend(chunks) - processed += 1 - self.stdout.write(f"Created {len(chunks)} chunks for {event.key}") + def get_base_queryset(self) -> QuerySet: + """Return the base queryset with ordering.""" + return super().get_base_queryset() - if batch_chunks: - Chunk.bulk_save(batch_chunks) - return processed + def extract_content(self, entity: Event) -> tuple[str, str]: + """Extract content from the event.""" + return extract_event_content(entity) diff --git a/backend/apps/ai/management/commands/ai_create_event_context.py b/backend/apps/ai/management/commands/ai_create_event_context.py index a518ac7c28..a866690da7 100644 --- a/backend/apps/ai/management/commands/ai_create_event_context.py +++ b/backend/apps/ai/management/commands/ai_create_event_context.py @@ -1,76 +1,37 @@ """A command to update context for OWASP event data.""" -from django.core.management.base import BaseCommand +from django.db.models import Model, QuerySet -from apps.ai.common.extractors import extract_event_content -from apps.ai.common.utils import create_context +from apps.ai.common.base import BaseContextCommand +from apps.ai.common.extractors.event import extract_event_content from apps.owasp.models.event import Event -class Command(BaseCommand): - help = "Update context for OWASP event data" +class Command(BaseContextCommand): + @property + def model_class(self) -> type[Model]: + return Event - def add_arguments(self, parser): - parser.add_argument( - "--event-key", - type=str, - help="Process only the event with this key", - ) - parser.add_argument( - "--all", - action="store_true", - help="Process all the events", - ) - parser.add_argument( - "--batch-size", - type=int, - default=50, - help="Number of events to process in each batch", - ) + @property + def entity_name(self) -> str: + return "event" - def handle(self, *args, **options): - if options["event_key"]: - queryset = Event.objects.filter(key=options["event_key"]) - elif options["all"]: - queryset = Event.objects.all() - else: - queryset = Event.upcoming_events() - queryset = queryset.order_by("id") + @property + def entity_name_plural(self) -> str: + return "events" - if not (total_events := queryset.count()): - self.stdout.write("No events found to process") - return + @property + def key_field_name(self) -> str: + return "key" - self.stdout.write(f"Found {total_events} events to process") + def get_default_queryset(self) -> QuerySet: + """Return upcoming events by default instead of is_active filter.""" + return Event.upcoming_events() - batch_size = options["batch_size"] - processed_count = 0 + def get_base_queryset(self) -> QuerySet: + """Return the base queryset with ordering.""" + return super().get_base_queryset() - for offset in range(0, total_events, batch_size): - batch_events = queryset[offset : offset + batch_size] - processed_count += self.process_context_batch(batch_events) - - self.stdout.write( - self.style.SUCCESS(f"Completed processing {processed_count}/{total_events} events") - ) - - def process_context_batch(self, events: list[Event]) -> int: - """Process a batch of events to create contexts.""" - processed = 0 - - for event in events: - prose_content, metadata_content = extract_event_content(event) - full_content = ( - f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content - ) - - if not full_content.strip(): - self.stdout.write(f"No content for event {event.key}") - continue - - if create_context(content=full_content, content_object=event, source="owasp_event"): - processed += 1 - self.stdout.write(f"Created context for {event.key}") - else: - self.stdout.write(self.style.ERROR(f"Failed to create context for {event.key}")) - return processed + def extract_content(self, entity: Event) -> tuple[str, str]: + """Extract content from the event.""" + return extract_event_content(entity) diff --git a/backend/apps/ai/management/commands/ai_create_project_chunks.py b/backend/apps/ai/management/commands/ai_create_project_chunks.py index 91cf633b1f..5aaa96a8db 100644 --- a/backend/apps/ai/management/commands/ai_create_project_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_project_chunks.py @@ -1,119 +1,33 @@ """A command to create chunks of OWASP project data for RAG.""" -import os +from django.db.models import Model, QuerySet -import openai -from django.contrib.contenttypes.models import ContentType -from django.core.management.base import BaseCommand - -from apps.ai.common.extractors import extract_project_content -from apps.ai.common.utils import create_chunks_and_embeddings -from apps.ai.models.chunk import Chunk -from apps.ai.models.context import Context +from apps.ai.common.base import BaseChunkCommand +from apps.ai.common.extractors.project import extract_project_content from apps.owasp.models.project import Project -class Command(BaseCommand): - help = "Create chunks for OWASP project data" - - def add_arguments(self, parser): - parser.add_argument( - "--project-key", - type=str, - help="Process only the project with this key", - ) - parser.add_argument( - "--all", - action="store_true", - help="Process all the projects", - ) - parser.add_argument( - "--batch-size", - type=int, - default=50, - help="Number of projects to process in each batch", - ) - - def handle(self, *args, **options): - if not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): - self.stdout.write( - self.style.ERROR("DJANGO_OPEN_AI_SECRET_KEY environment variable not set") - ) - return - - self.openai_client = openai.OpenAI(api_key=openai_api_key) - - if options["project_key"]: - queryset = Project.objects.filter(key=options["project_key"]) - elif options["all"]: - queryset = Project.objects.all() - else: - queryset = Project.objects.filter(is_active=True) - - if not (total_projects := queryset.count()): - self.stdout.write("No projects found to process") - return - - self.stdout.write(f"Found {total_projects} projects to process") - - batch_size = options["batch_size"] - processed_count = 0 - - for offset in range(0, total_projects, batch_size): - batch_projects = queryset[offset : offset + batch_size] - processed_count += self.process_chunks_batch(batch_projects) - - self.stdout.write( - self.style.SUCCESS(f"Completed processing {processed_count}/{total_projects} projects") - ) - - def process_chunks_batch(self, projects: list[Project]) -> int: - """Process a batch of projects to create chunks.""" - processed = 0 - batch_chunks = [] - - project_content_type = ContentType.objects.get_for_model(Project) - project_ids = [p.id for p in projects] - contexts_by_id = { - c.object_id: c - for c in Context.objects.filter( - content_type=project_content_type, object_id__in=project_ids - ) - } - - for project in projects: - context = contexts_by_id.get(project.id) - - if not context: - self.stdout.write( - self.style.WARNING(f"No context found for project {project.key}") - ) - continue - - prose_content, metadata_content = extract_project_content(project) - all_chunk_texts = [] +class Command(BaseChunkCommand): + @property + def model_class(self) -> type[Model]: + return Project - if metadata_content.strip(): - all_chunk_texts.append(metadata_content) + @property + def entity_name(self) -> str: + return "project" - if prose_content.strip(): - prose_chunks = Chunk.split_text(prose_content) - all_chunk_texts.extend(prose_chunks) + @property + def entity_name_plural(self) -> str: + return "projects" - if not all_chunk_texts: - self.stdout.write(f"No content to chunk for project {project.key}") - continue + @property + def key_field_name(self) -> str: + return "key" - if chunks := create_chunks_and_embeddings( - chunk_texts=all_chunk_texts, - context=context, - openai_client=self.openai_client, - save=False, - ): - batch_chunks.extend(chunks) - processed += 1 - self.stdout.write(f"Created {len(chunks)} chunks for {project.key}") + def get_base_queryset(self) -> QuerySet: + """Return the base queryset with ordering.""" + return super().get_base_queryset() - if batch_chunks: - Chunk.bulk_save(batch_chunks) - return processed + def extract_content(self, entity: Project) -> tuple[str, str]: + """Extract content from the project.""" + return extract_project_content(entity) diff --git a/backend/apps/ai/management/commands/ai_create_project_context.py b/backend/apps/ai/management/commands/ai_create_project_context.py index f643343a28..dc10befd33 100644 --- a/backend/apps/ai/management/commands/ai_create_project_context.py +++ b/backend/apps/ai/management/commands/ai_create_project_context.py @@ -1,78 +1,33 @@ """A command to update context for OWASP project data.""" -from django.core.management.base import BaseCommand +from django.db.models import Model, QuerySet -from apps.ai.common.extractors import extract_project_content -from apps.ai.common.utils import create_context +from apps.ai.common.base import BaseContextCommand +from apps.ai.common.extractors.project import extract_project_content from apps.owasp.models.project import Project -class Command(BaseCommand): - help = "Update context for OWASP project data" +class Command(BaseContextCommand): + @property + def model_class(self) -> type[Model]: + return Project - def add_arguments(self, parser): - parser.add_argument( - "--project-key", - type=str, - help="Process only the project with this key", - ) - parser.add_argument( - "--all", - action="store_true", - help="Process all the projects", - ) - parser.add_argument( - "--batch-size", - type=int, - default=50, - help="Number of projects to process in each batch", - ) + @property + def entity_name(self) -> str: + return "project" - def handle(self, *args, **options): - if options["project_key"]: - queryset = Project.objects.filter(key=options["project_key"]) - elif options["all"]: - queryset = Project.objects.all() - else: - queryset = Project.objects.filter(is_active=True) - queryset = queryset.order_by("id") + @property + def entity_name_plural(self) -> str: + return "projects" - if not (total_projects := queryset.count()): - self.stdout.write("No projects found to process") - return + @property + def key_field_name(self) -> str: + return "key" - self.stdout.write(f"Found {total_projects} projects to process") + def get_base_queryset(self) -> QuerySet: + """Return the base queryset with ordering.""" + return super().get_base_queryset() - batch_size = options["batch_size"] - processed_count = 0 - - for offset in range(0, total_projects, batch_size): - batch_projects = queryset[offset : offset + batch_size] - processed_count += self.process_context_batch(batch_projects) - - self.stdout.write( - self.style.SUCCESS(f"Completed processing {processed_count}/{total_projects} projects") - ) - - def process_context_batch(self, projects: list[Project]) -> int: - """Process a batch of projects to create contexts.""" - processed = 0 - - for project in projects: - prose_content, metadata_content = extract_project_content(project) - full_content = ( - f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content - ) - - if not full_content.strip(): - self.stdout.write(f"No content for project {project.key}") - continue - - if create_context( - content=full_content, content_object=project, source="owasp_project" - ): - processed += 1 - self.stdout.write(f"Created context for {project.key}") - else: - self.stdout.write(self.style.ERROR(f"Failed to create context for {project.key}")) - return processed + def extract_content(self, entity: Project) -> tuple[str, str]: + """Extract content from the project.""" + return extract_project_content(entity) diff --git a/backend/apps/ai/management/commands/ai_create_slack_message_chunks.py b/backend/apps/ai/management/commands/ai_create_slack_message_chunks.py index 4158627962..b34cf969da 100644 --- a/backend/apps/ai/management/commands/ai_create_slack_message_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_slack_message_chunks.py @@ -1,143 +1,55 @@ """A command to create chunks of Slack messages.""" -import os +from django.db.models import Model, QuerySet -import openai -from django.contrib.contenttypes.models import ContentType -from django.core.management.base import BaseCommand - -from apps.ai.common.utils import create_chunks_and_embeddings, create_context -from apps.ai.models.chunk import Chunk -from apps.ai.models.context import Context +from apps.ai.common.base import BaseChunkCommand from apps.slack.models.message import Message -class Command(BaseCommand): - help = "Create chunks for Slack messages" +class Command(BaseChunkCommand): + @property + def model_class(self) -> type[Model]: + return Message + + @property + def entity_name(self) -> str: + return "message" + + @property + def entity_name_plural(self) -> str: + return "messages" + + @property + def key_field_name(self) -> str: + return "slack_message_id" + + @property + def source_name(self) -> str: + return "slack_message" + + def get_default_queryset(self) -> QuerySet: + """Return all messages by default since Message model doesn't have is_active field.""" + return self.get_base_queryset() def add_arguments(self, parser): + """Override to use different default batch size for messages.""" parser.add_argument( - "--batch-size", - type=int, - default=100, - help="Number of messages to process in each batch", + "--message-key", + type=str, + help="Process only the message with this key", ) parser.add_argument( - "--context", + "--all", action="store_true", - help="Create only context (skip chunks and embeddings)", + help="Process all the messages", ) parser.add_argument( - "--chunks", - action="store_true", - help="Create only chunks+embeddings (requires existing context)", - ) - - def handle(self, *args, **options): - if not options["context"] and not options["chunks"]: - self.stdout.write( - self.style.ERROR("Please specify either --context or --chunks (or both)") - ) - return - - if options["chunks"] and not (openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY")): - self.stdout.write( - self.style.ERROR("DJANGO_OPEN_AI_SECRET_KEY environment variable not set") - ) - return - - if options["chunks"]: - self.openai_client = openai.OpenAI(api_key=openai_api_key) - - queryset = Message.objects.all() - total_messages = queryset.count() - - if not total_messages: - self.stdout.write("No messages found to process") - return - - self.stdout.write(f"Found {total_messages} messages to process") - - batch_size = options["batch_size"] - processed_count = 0 - - for offset in range(0, total_messages, batch_size): - batch_messages = queryset[offset : offset + batch_size] - - if options["context"]: - processed_count += self.process_context_batch(batch_messages) - elif options["chunks"]: - processed_count += self.process_chunks_batch(batch_messages) - - self.stdout.write( - self.style.SUCCESS(f"Completed processing {processed_count}/{total_messages} messages") + "--batch-size", + type=int, + default=100, + help="Number of messages to process in each batch", ) - def process_context_batch(self, messages: list[Message]) -> int: - """Process a batch of messages to create contexts.""" - processed = 0 - - for message in messages: - if not message.cleaned_text or not message.cleaned_text.strip(): - continue - - if create_context( - content=message.cleaned_text, - content_object=message, - source="slack_message", - ): - processed += 1 - self.stdout.write(f"Created context for message {message.slack_message_id}") - else: - self.stdout.write( - self.style.ERROR( - f"Failed to create context for message {message.slack_message_id}" - ) - ) - return processed - - def process_chunks_batch(self, messages: list[Message]) -> int: - """Process a batch of messages to create chunks.""" - processed = 0 - batch_chunks = [] - - message_content_type = ContentType.objects.get_for_model(Message) - - for message in messages: - context = Context.objects.filter( - content_type=message_content_type, object_id=message.id - ).first() - - if not context: - self.stdout.write( - self.style.WARNING(f"No context found for message {message.slack_message_id}") - ) - continue - - if not message.cleaned_text or not message.cleaned_text.strip(): - self.stdout.write(f"No content to chunk for message {message.slack_message_id}") - continue - - chunk_texts = Chunk.split_text(message.cleaned_text) - if not chunk_texts: - self.stdout.write( - f"No chunks created for message {message.slack_message_id}: " - f"`{message.cleaned_text}`" - ) - continue - - if chunks := create_chunks_and_embeddings( - chunk_texts=chunk_texts, - context=context, - openai_client=self.openai_client, - save=False, - ): - batch_chunks.extend(chunks) - processed += 1 - self.stdout.write( - f"Created {len(chunks)} chunks for message {message.slack_message_id}" - ) - - if batch_chunks: - Chunk.bulk_save(batch_chunks) - return processed + def extract_content(self, entity: Message) -> tuple[str, str]: + """Extract content from the message.""" + return entity.cleaned_text or "", "" diff --git a/backend/apps/ai/management/commands/ai_create_slack_message_context.py b/backend/apps/ai/management/commands/ai_create_slack_message_context.py new file mode 100644 index 0000000000..3e3d3a135c --- /dev/null +++ b/backend/apps/ai/management/commands/ai_create_slack_message_context.py @@ -0,0 +1,55 @@ +"""A command to update context for Slack message data.""" + +from django.db.models import Model, QuerySet + +from apps.ai.common.base import BaseContextCommand +from apps.slack.models.message import Message + + +class Command(BaseContextCommand): + @property + def model_class(self) -> type[Model]: + return Message + + @property + def entity_name(self) -> str: + return "message" + + @property + def entity_name_plural(self) -> str: + return "messages" + + @property + def key_field_name(self) -> str: + return "slack_message_id" + + @property + def source_name(self) -> str: + return "slack_message" + + def get_default_queryset(self) -> QuerySet: + """Return all messages by default since Message model doesn't have is_active field.""" + return self.get_base_queryset() + + def add_arguments(self, parser): + """Override to use different default batch size for messages.""" + parser.add_argument( + "--message-key", + type=str, + help="Process only the message with this key", + ) + parser.add_argument( + "--all", + action="store_true", + help="Process all the messages", + ) + parser.add_argument( + "--batch-size", + type=int, + default=100, + help="Number of messages to process in each batch", + ) + + def extract_content(self, entity: Message) -> tuple[str, str]: + """Extract content from the message.""" + return entity.cleaned_text or "", "" diff --git a/backend/apps/ai/migrations/0008_alter_context_unique_together_and_more.py b/backend/apps/ai/migrations/0008_alter_context_unique_together_and_more.py new file mode 100644 index 0000000000..2d807dd926 --- /dev/null +++ b/backend/apps/ai/migrations/0008_alter_context_unique_together_and_more.py @@ -0,0 +1,34 @@ +# Generated by Django 5.2.4 on 2025-08-10 19:06 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("ai", "0007_alter_chunk_context_alter_context_unique_together"), + ("contenttypes", "0002_remove_content_type_name"), + ] + + operations = [ + migrations.AlterUniqueTogether( + name="context", + unique_together=set(), + ), + migrations.AlterField( + model_name="context", + name="content_type", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to="contenttypes.contenttype" + ), + ), + migrations.AlterField( + model_name="context", + name="object_id", + field=models.PositiveIntegerField(), + ), + migrations.AlterUniqueTogether( + name="context", + unique_together={("content_type", "object_id", "content")}, + ), + ] diff --git a/backend/apps/ai/models/context.py b/backend/apps/ai/models/context.py index bc02cf2fb2..6031c2fec4 100644 --- a/backend/apps/ai/models/context.py +++ b/backend/apps/ai/models/context.py @@ -8,18 +8,18 @@ class Context(TimestampedModel): - """Context model for storing generated text and optional relation to OWASP entities.""" + """Context model for storing generated text related to OWASP entities.""" content = models.TextField(verbose_name="Generated Text") - content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE, blank=True, null=True) - object_id = models.PositiveIntegerField(default=0) + content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE) + object_id = models.PositiveIntegerField() content_object = GenericForeignKey("content_type", "object_id") source = models.CharField(max_length=100, blank=True, default="") class Meta: db_table = "ai_contexts" verbose_name = "Context" - unique_together = ("content_type", "object_id") + unique_together = ("content_type", "object_id", "content") def __str__(self): """Human readable representation.""" @@ -28,40 +28,24 @@ def __str__(self): or getattr(self.content_object, "key", None) or str(self.content_object) ) - return ( - f"{self.content_type.model if self.content_type else 'None'} {entity}: " - f"{self.content[:50]}" - ) + return f"{self.content_type.model} {entity}: {self.content[:50]}" @staticmethod def update_data( content: str, - content_object=None, + content_object, source: str = "", *, save: bool = True, - ) -> "Context | None": - """Update context data. - - Args: - content (str): The content text of the context. - content_object: Optional related object (generic foreign key). - source (str): Source identifier for the context. - save (bool): Whether to save the context to the database. - - Returns: - Context: The updated context instance or None if it already exists. - - """ - if content_object: - content_type = ContentType.objects.get_for_model(content_object) - object_id = content_object.pk - if Context.objects.filter( - content_type=content_type, object_id=object_id, content=content - ).exists(): - return None - elif Context.objects.filter(content=content, content_object__isnull=True).exists(): - return None + ) -> "Context": + """Retrieve existing or create new context.""" + content_type = ContentType.objects.get_for_model(content_object) + object_id = content_object.pk + existing_context = Context.objects.filter( + content_type=content_type, object_id=object_id, content=content + ).first() + if existing_context: + return existing_context context = Context(content=content, content_object=content_object, source=source) diff --git a/backend/tests/apps/ai/common/base_test.py b/backend/tests/apps/ai/common/base_test.py new file mode 100644 index 0000000000..e5a916d35f --- /dev/null +++ b/backend/tests/apps/ai/common/base_test.py @@ -0,0 +1,664 @@ +"""Tests for the base AI command classes.""" + +import os +from unittest.mock import Mock, call, patch + +import pytest +from django.core.management.base import BaseCommand +from django.db import models + +from apps.ai.common.base import BaseAICommand, BaseChunkCommand, BaseContextCommand + + +class MockModel(models.Model): + """Mock model for testing purposes.""" + + name = models.CharField(max_length=100) + key = models.CharField(max_length=50, unique=True) + is_active = models.BooleanField(default=True) + + def __str__(self): + """Return string representation of the model.""" + return self.name + + class Meta: + """Meta class for MockModel.""" + + app_label = "test" + + +class ConcreteBaseAICommand(BaseAICommand): + """Concrete implementation of BaseAICommand for testing.""" + + @property + def model_class(self) -> type[models.Model]: + return MockModel + + @property + def entity_name(self) -> str: + return "test" + + @property + def entity_name_plural(self) -> str: + return "tests" + + @property + def key_field_name(self) -> str: + return "key" + + def extract_content(self, entity: models.Model) -> tuple[str, str]: + return f"Content for {entity.name}", f"Metadata for {entity.name}" + + +class ConcreteBaseContextCommand(BaseContextCommand): + """Concrete implementation of BaseContextCommand for testing.""" + + @property + def model_class(self) -> type[models.Model]: + return MockModel + + @property + def entity_name(self) -> str: + return "test" + + @property + def entity_name_plural(self) -> str: + return "tests" + + @property + def key_field_name(self) -> str: + return "key" + + def extract_content(self, entity: models.Model) -> tuple[str, str]: + return f"Content for {entity.name}", f"Metadata for {entity.name}" + + +class ConcreteBaseChunkCommand(BaseChunkCommand): + """Concrete implementation of BaseChunkCommand for testing.""" + + @property + def model_class(self) -> type[models.Model]: + return MockModel + + @property + def entity_name(self) -> str: + return "test" + + @property + def entity_name_plural(self) -> str: + return "tests" + + @property + def key_field_name(self) -> str: + return "key" + + def extract_content(self, entity: models.Model) -> tuple[str, str]: + return f"Content for {entity.name}", f"Metadata for {entity.name}" + + +@pytest.fixture +def base_ai_command(): + """Return a concrete BaseAICommand instance.""" + return ConcreteBaseAICommand() + + +@pytest.fixture +def base_context_command(): + """Return a concrete BaseContextCommand instance.""" + return ConcreteBaseContextCommand() + + +@pytest.fixture +def base_chunk_command(): + """Return a concrete BaseChunkCommand instance.""" + return ConcreteBaseChunkCommand() + + +@pytest.fixture +def mock_entity(): + """Return a mock entity.""" + entity = Mock(spec=MockModel) + entity.name = "Test Entity" + entity.key = "test-key" + entity.pk = 1 + entity.is_active = True + return entity + + +@pytest.fixture +def mock_queryset(): + """Return a mock queryset.""" + queryset = Mock() + queryset.count.return_value = 3 + queryset.filter.return_value = queryset + queryset.__getitem__ = Mock(side_effect=lambda _: [Mock(), Mock()]) + return queryset + + +class TestBaseAICommand: + """Test suite for BaseAICommand.""" + + def test_command_inheritance(self, base_ai_command): + """Test that the command inherits from BaseCommand.""" + assert isinstance(base_ai_command, BaseCommand) + + def test_initialization(self, base_ai_command): + """Test command initialization.""" + assert base_ai_command.openai_client is None + + def test_abstract_properties(self, base_ai_command): + """Test abstract property implementations.""" + assert base_ai_command.model_class == MockModel + assert base_ai_command.entity_name == "test" + assert base_ai_command.entity_name_plural == "tests" + assert base_ai_command.key_field_name == "key" + + def test_source_name_default(self, base_ai_command): + """Test default source name.""" + assert base_ai_command.source_name == "owasp_test" + + def test_extract_content_implementation(self, base_ai_command, mock_entity): + """Test extract_content implementation.""" + prose, metadata = base_ai_command.extract_content(mock_entity) + assert prose == "Content for Test Entity" + assert metadata == "Metadata for Test Entity" + + @patch.object(ConcreteBaseAICommand, "model_class", MockModel) + def test_get_base_queryset(self, base_ai_command): + """Test get_base_queryset method.""" + with patch.object(MockModel, "objects") as mock_objects: + mock_objects.all.return_value = "base_queryset" + result = base_ai_command.get_base_queryset() + assert result == "base_queryset" + mock_objects.all.assert_called_once() + + @patch.object(ConcreteBaseAICommand, "get_base_queryset") + def test_get_default_queryset(self, mock_get_base, base_ai_command): + """Test get_default_queryset method.""" + mock_queryset = Mock() + mock_get_base.return_value = mock_queryset + mock_queryset.filter.return_value = "filtered_queryset" + + result = base_ai_command.get_default_queryset() + + assert result == "filtered_queryset" + mock_queryset.filter.assert_called_once_with(is_active=True) + + def test_add_common_arguments(self, base_ai_command): + """Test add_common_arguments method.""" + mock_parser = Mock() + mock_parser.add_argument = Mock() + + base_ai_command.add_common_arguments(mock_parser) + + expected_calls = [ + call("--test-key", type=str, help="Process only the test with this key"), + call("--all", action="store_true", help="Process all the tests"), + call( + "--batch-size", + type=int, + default=50, + help="Number of tests to process in each batch", + ), + ] + mock_parser.add_argument.assert_has_calls(expected_calls) + + def test_add_arguments_calls_common(self, base_ai_command): + """Test add_arguments calls add_common_arguments.""" + mock_parser = Mock() + with patch.object(base_ai_command, "add_common_arguments") as mock_add_common: + base_ai_command.add_arguments(mock_parser) + mock_add_common.assert_called_once_with(mock_parser) + + @patch.object(ConcreteBaseAICommand, "get_base_queryset") + def test_get_queryset_with_key_option(self, mock_get_base, base_ai_command): + """Test get_queryset with entity key option.""" + mock_queryset = Mock() + mock_get_base.return_value = mock_queryset + mock_queryset.filter.return_value = "filtered_queryset" + + options = {"test_key": "specific-key"} + result = base_ai_command.get_queryset(options) + + assert result == "filtered_queryset" + mock_queryset.filter.assert_called_once_with(key="specific-key") + + @patch.object(ConcreteBaseAICommand, "get_base_queryset") + def test_get_queryset_with_all_option(self, mock_get_base, base_ai_command): + """Test get_queryset with all option.""" + mock_queryset = Mock() + mock_get_base.return_value = mock_queryset + + options = {"all": True} + result = base_ai_command.get_queryset(options) + + assert result == mock_queryset + + @patch.object(ConcreteBaseAICommand, "get_default_queryset") + def test_get_queryset_default(self, mock_get_default, base_ai_command): + """Test get_queryset with default behavior.""" + mock_get_default.return_value = "default_queryset" + + options = {} + result = base_ai_command.get_queryset(options) + + assert result == "default_queryset" + + def test_get_entity_key(self, base_ai_command, mock_entity): + """Test get_entity_key method.""" + result = base_ai_command.get_entity_key(mock_entity) + assert result == "test-key" + + def test_get_entity_key_fallback_to_pk(self, base_ai_command): + """Test get_entity_key falls back to pk when key field doesn't exist.""" + mock_entity = Mock() + mock_entity.pk = 123 + delattr(mock_entity, "key") if hasattr(mock_entity, "key") else None + + result = base_ai_command.get_entity_key(mock_entity) + assert result == "123" + + @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-api-key"}) + @patch("apps.ai.common.base.openai.OpenAI") + def test_setup_openai_client_success(self, mock_openai_class, base_ai_command): + """Test successful OpenAI client setup.""" + mock_client = Mock() + mock_openai_class.return_value = mock_client + + result = base_ai_command.setup_openai_client() + + assert result is True + assert base_ai_command.openai_client == mock_client + mock_openai_class.assert_called_once_with(api_key="test-api-key") + + @patch.dict(os.environ, {}, clear=True) + def test_setup_openai_client_no_api_key(self, base_ai_command): + """Test OpenAI client setup without API key.""" + with ( + patch.object(base_ai_command.stdout, "write") as mock_write, + patch.object(base_ai_command.style, "ERROR") as mock_error, + ): + mock_error.return_value = "ERROR: No API key" + + result = base_ai_command.setup_openai_client() + + assert result is False + assert base_ai_command.openai_client is None + mock_error.assert_called_once_with( + "DJANGO_OPEN_AI_SECRET_KEY environment variable not set" + ) + mock_write.assert_called_once_with("ERROR: No API key") + + def test_handle_batch_processing_empty_queryset(self, base_ai_command): + """Test batch processing with empty queryset.""" + mock_queryset = Mock() + mock_queryset.count.return_value = 0 + + with patch.object(base_ai_command.stdout, "write") as mock_write: + base_ai_command.handle_batch_processing( + queryset=mock_queryset, + batch_size=10, + process_batch_func=Mock(), + ) + + mock_write.assert_called_once_with("No tests found to process") + + def test_handle_batch_processing_with_items(self, base_ai_command): + """Test batch processing with items.""" + mock_queryset = Mock() + mock_queryset.count.return_value = 5 + + # Mock slicing behavior + batch1 = [Mock(), Mock()] + batch2 = [Mock(), Mock()] + batch3 = [Mock()] + + def getitem_side_effect(slice_obj): + if slice_obj == slice(0, 2): + return batch1 + if slice_obj == slice(2, 4): + return batch2 + if slice_obj == slice(4, 6): + return batch3 + return [] + + mock_queryset.__getitem__ = Mock(side_effect=getitem_side_effect) + + mock_process_func = Mock(side_effect=[2, 2, 1]) # Return processed counts + + with ( + patch.object(base_ai_command.stdout, "write") as mock_write, + patch.object(base_ai_command.style, "SUCCESS") as mock_success, + ): + mock_success.return_value = "SUCCESS: Completed" + + base_ai_command.handle_batch_processing( + queryset=mock_queryset, + batch_size=2, + process_batch_func=mock_process_func, + ) + + # Verify process function was called with correct batches + expected_calls = [call(batch1), call(batch2), call(batch3)] + mock_process_func.assert_has_calls(expected_calls) + + # Verify output messages + assert mock_write.call_count == 2 + mock_write.assert_any_call("Found 5 tests to process") + mock_success.assert_called_once_with("Completed processing 5/5 tests") + + +class TestBaseContextCommand: + """Test suite for BaseContextCommand.""" + + def test_command_inheritance(self, base_context_command): + """Test that the command inherits from BaseAICommand.""" + assert isinstance(base_context_command, BaseAICommand) + + def test_help_property(self, base_context_command): + """Test help property.""" + assert base_context_command.help == "Update context for OWASP test data" + + @patch("apps.ai.common.base.create_context") + def test_process_context_batch_success( + self, mock_create_context, base_context_command + ): + """Test successful context batch processing.""" + mock_create_context.return_value = True + + entities = [ + Mock(name="Entity 1", key="key1"), + Mock(name="Entity 2", key="key2"), + ] + + with patch.object(base_context_command, "extract_content") as mock_extract: + mock_extract.side_effect = [ + ("Content 1", "Metadata 1"), + ("Content 2", "Metadata 2"), + ] + + with patch.object(base_context_command, "get_entity_key") as mock_get_key: + mock_get_key.side_effect = ["key1", "key2"] + + result = base_context_command.process_context_batch(entities) + + assert result == 2 + assert mock_create_context.call_count == 2 + + # Verify create_context was called with correct parameters + expected_calls = [ + call( + content="Metadata 1\n\nContent 1", + content_object=entities[0], + source="owasp_test", + ), + call( + content="Metadata 2\n\nContent 2", + content_object=entities[1], + source="owasp_test", + ), + ] + mock_create_context.assert_has_calls(expected_calls) + + @patch("apps.ai.common.base.create_context") + def test_process_context_batch_empty_content( + self, mock_create_context, base_context_command + ): + """Test context batch processing with empty content.""" + entities = [Mock(name="Empty Entity", key="empty-key")] + + with patch.object(base_context_command, "extract_content") as mock_extract: + mock_extract.return_value = ("", "") + + with patch.object(base_context_command, "get_entity_key") as mock_get_key: + mock_get_key.return_value = "empty-key" + + with patch.object(base_context_command.stdout, "write") as mock_write: + result = base_context_command.process_context_batch(entities) + + assert result == 0 + mock_create_context.assert_not_called() + mock_write.assert_called_once_with("No content for test empty-key") + + @patch("apps.ai.common.base.create_context") + def test_process_context_batch_create_failure( + self, mock_create_context, base_context_command + ): + """Test context batch processing when create_context fails.""" + mock_create_context.return_value = False + + entities = [Mock(name="Failing Entity", key="fail-key")] + + with patch.object(base_context_command, "extract_content") as mock_extract: + mock_extract.return_value = ("Content", "Metadata") + + with patch.object(base_context_command, "get_entity_key") as mock_get_key: + mock_get_key.return_value = "fail-key" + + with ( + patch.object(base_context_command.stdout, "write") as mock_write, + patch.object(base_context_command.style, "ERROR") as mock_error, + ): + mock_error.return_value = "ERROR: Failed" + + result = base_context_command.process_context_batch(entities) + + assert result == 0 + mock_error.assert_called_once_with( + "Failed to create context for fail-key" + ) + mock_write.assert_called_once_with("ERROR: Failed") + + def test_handle_calls_batch_processing(self, base_context_command): + """Test handle method calls handle_batch_processing.""" + options = {"batch_size": 25} + mock_queryset = Mock() + + with patch.object(base_context_command, "get_queryset") as mock_get_queryset: + mock_get_queryset.return_value = mock_queryset + + with patch.object( + base_context_command, "handle_batch_processing" + ) as mock_handle_batch: + base_context_command.handle(**options) + + mock_get_queryset.assert_called_once_with(options) + mock_handle_batch.assert_called_once_with( + queryset=mock_queryset, + batch_size=25, + process_batch_func=base_context_command.process_context_batch, + ) + + +class TestBaseChunkCommand: + """Test suite for BaseChunkCommand.""" + + def test_command_inheritance(self, base_chunk_command): + """Test that the command inherits from BaseAICommand.""" + assert isinstance(base_chunk_command, BaseAICommand) + + def test_help_property(self, base_chunk_command): + """Test help property.""" + assert base_chunk_command.help == "Create chunks for OWASP test data" + + @patch("apps.ai.common.base.create_chunks_and_embeddings") + @patch("apps.ai.common.base.Chunk.bulk_save") + @patch("apps.ai.common.base.Chunk.split_text") + @patch("apps.ai.common.base.Context.objects.filter") + @patch("apps.ai.common.base.ContentType.objects.get_for_model") + def test_process_chunks_batch_success( + self, + mock_get_content_type, + mock_context_filter, + mock_split_text, + mock_bulk_save, + mock_create_chunks, + base_chunk_command, + ): + """Test successful chunks batch processing.""" + # Setup mocks + mock_content_type = Mock() + mock_get_content_type.return_value = mock_content_type + + mock_context = Mock() + mock_context_filter.return_value.first.return_value = mock_context + + mock_split_text.return_value = ["chunk1", "chunk2"] + + mock_chunks = [Mock(), Mock()] + mock_create_chunks.return_value = mock_chunks + + entities = [Mock(id=1, name="Entity 1", key="key1")] + + with patch.object(base_chunk_command, "extract_content") as mock_extract: + mock_extract.return_value = ("Content", "Metadata") + + with patch.object(base_chunk_command, "get_entity_key") as mock_get_key: + mock_get_key.return_value = "key1" + + result = base_chunk_command.process_chunks_batch(entities) + + assert result == 1 + mock_get_content_type.assert_called_once_with(MockModel) + mock_context_filter.assert_called_once_with( + content_type=mock_content_type, object_id=1 + ) + mock_split_text.assert_called_once_with("Metadata\n\nContent") + mock_create_chunks.assert_called_once_with( + chunk_texts=["chunk1", "chunk2"], + context=mock_context, + openai_client=base_chunk_command.openai_client, + save=False, + ) + mock_bulk_save.assert_called_once_with(mock_chunks) + + @patch("apps.ai.common.base.Context.objects.filter") + @patch("apps.ai.common.base.ContentType.objects.get_for_model") + def test_process_chunks_batch_no_context( + self, mock_get_content_type, mock_context_filter, base_chunk_command + ): + """Test chunks batch processing when no context exists.""" + mock_content_type = Mock() + mock_get_content_type.return_value = mock_content_type + mock_context_filter.return_value.first.return_value = None + + entities = [Mock(id=1, name="Entity 1", key="key1")] + + with patch.object(base_chunk_command, "get_entity_key") as mock_get_key: + mock_get_key.return_value = "key1" + + with ( + patch.object(base_chunk_command.stdout, "write") as mock_write, + patch.object(base_chunk_command.style, "WARNING") as mock_warning, + ): + mock_warning.return_value = "WARNING: No context" + + result = base_chunk_command.process_chunks_batch(entities) + + assert result == 0 + mock_warning.assert_called_once_with("No context found for test key1") + mock_write.assert_called_once_with("WARNING: No context") + + @patch("apps.ai.common.base.Chunk.split_text") + @patch("apps.ai.common.base.Context.objects.filter") + @patch("apps.ai.common.base.ContentType.objects.get_for_model") + def test_process_chunks_batch_empty_content( + self, + mock_get_content_type, + mock_context_filter, + mock_split_text, + base_chunk_command, + ): + """Test chunks batch processing with empty content.""" + mock_content_type = Mock() + mock_get_content_type.return_value = mock_content_type + + mock_context = Mock() + mock_context_filter.return_value.first.return_value = mock_context + + entities = [Mock(id=1, name="Entity 1", key="key1")] + + with patch.object(base_chunk_command, "extract_content") as mock_extract: + mock_extract.return_value = ("", "") + + with patch.object(base_chunk_command, "get_entity_key") as mock_get_key: + mock_get_key.return_value = "key1" + + with patch.object(base_chunk_command.stdout, "write") as mock_write: + result = base_chunk_command.process_chunks_batch(entities) + + assert result == 0 + mock_split_text.assert_not_called() + mock_write.assert_called_once_with( + "No content to chunk for test key1" + ) + + @patch("apps.ai.common.base.Chunk.split_text") + @patch("apps.ai.common.base.Context.objects.filter") + @patch("apps.ai.common.base.ContentType.objects.get_for_model") + def test_process_chunks_batch_no_chunks_created( + self, + mock_get_content_type, + mock_context_filter, + mock_split_text, + base_chunk_command, + ): + """Test chunks batch processing when no chunks are created.""" + mock_content_type = Mock() + mock_get_content_type.return_value = mock_content_type + + mock_context = Mock() + mock_context_filter.return_value.first.return_value = mock_context + + mock_split_text.return_value = [] # No chunks created + + entities = [Mock(id=1, name="Entity 1", key="key1")] + + with patch.object(base_chunk_command, "extract_content") as mock_extract: + mock_extract.return_value = ("Content", "Metadata") + + with patch.object(base_chunk_command, "get_entity_key") as mock_get_key: + mock_get_key.return_value = "key1" + + with patch.object(base_chunk_command.stdout, "write") as mock_write: + result = base_chunk_command.process_chunks_batch(entities) + + assert result == 0 + mock_write.assert_called_once_with( + "No chunks created for test key1: `Metadata\n\nContent`" + ) + + def test_handle_calls_setup_and_batch_processing(self, base_chunk_command): + """Test handle method calls setup_openai_client and handle_batch_processing.""" + options = {"batch_size": 25} + mock_queryset = Mock() + + with patch.object(base_chunk_command, "setup_openai_client") as mock_setup: + mock_setup.return_value = True + + with patch.object(base_chunk_command, "get_queryset") as mock_get_queryset: + mock_get_queryset.return_value = mock_queryset + + with patch.object( + base_chunk_command, "handle_batch_processing" + ) as mock_handle_batch: + base_chunk_command.handle(**options) + + mock_setup.assert_called_once() + mock_get_queryset.assert_called_once_with(options) + mock_handle_batch.assert_called_once_with( + queryset=mock_queryset, + batch_size=25, + process_batch_func=base_chunk_command.process_chunks_batch, + ) + + def test_handle_returns_early_if_setup_fails(self, base_chunk_command): + """Test handle method returns early if OpenAI client setup fails.""" + with patch.object(base_chunk_command, "setup_openai_client") as mock_setup: + mock_setup.return_value = False + + with patch.object(base_chunk_command, "get_queryset") as mock_get_queryset: + base_chunk_command.handle() + + mock_setup.assert_called_once() + mock_get_queryset.assert_not_called() diff --git a/backend/tests/apps/ai/management/commands/ai_create_chapter_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_create_chapter_chunks_test.py index 018c022ac1..e501580f9b 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_chapter_chunks_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_chapter_chunks_test.py @@ -1,7 +1,4 @@ -"""Tests for the ai_create_chapter_chunks Django management command.""" - -import os -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock, patch import pytest from django.core.management.base import BaseCommand @@ -11,227 +8,40 @@ @pytest.fixture def command(): - """Return a command instance.""" return Command() @pytest.fixture def mock_chapter(): - """Return a mock Chapter instance.""" chapter = Mock() chapter.id = 1 chapter.key = "test-chapter" return chapter -@pytest.fixture -def mock_context(): - """Return a mock Context instance.""" - context = Mock() - context.id = 1 - return context - - class TestAiCreateChapterChunksCommand: - """Test suite for the ai_create_chapter_chunks command.""" - - def test_command_help_text(self, command): - """Test that the command has the correct help text.""" - assert command.help == "Create chunks for OWASP chapter data" - def test_command_inheritance(self, command): - """Test that the command inherits from BaseCommand.""" assert isinstance(command, BaseCommand) - def test_add_arguments(self, command): - """Test that the command adds the correct arguments.""" - parser = MagicMock() - command.add_arguments(parser) - - assert parser.add_argument.call_count == 3 - parser.add_argument.assert_any_call( - "--chapter-key", - type=str, - help="Process only the chapter with this key", - ) - parser.add_argument.assert_any_call( - "--all", - action="store_true", - help="Process all the chapters", - ) - parser.add_argument.assert_any_call( - "--batch-size", - type=int, - default=50, - help="Number of chapters to process in each batch", - ) - - @patch.dict(os.environ, {}, clear=True) - def test_handle_missing_openai_key(self, command): - """Test command fails when OpenAI API key is not set.""" - command.stdout = MagicMock() - command.style = MagicMock() - - command.handle() - - command.stdout.write.assert_called_once() - command.style.ERROR.assert_called_once_with( - "DJANGO_OPEN_AI_SECRET_KEY environment variable not set" - ) - - @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) - @patch("apps.ai.management.commands.ai_create_chapter_chunks.openai.OpenAI") - @patch("apps.ai.management.commands.ai_create_chapter_chunks.Chapter.objects") - def test_handle_no_chapters_found(self, mock_chapter_objects, mock_openai, command): - """Test command when no chapters are found.""" - command.stdout = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 0 - mock_chapter_objects.filter.return_value = mock_queryset - - command.handle(chapter_key=None, all=False, batch_size=50) - - command.stdout.write.assert_called_with("No chapters found to process") - - @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) - @patch("apps.ai.management.commands.ai_create_chapter_chunks.openai.OpenAI") - @patch("apps.ai.management.commands.ai_create_chapter_chunks.Chapter.objects") - def test_handle_with_chapter_key( - self, mock_chapter_objects, mock_openai, command, mock_chapter - ): - """Test command with specific chapter key.""" - command.stdout = MagicMock() - command.style = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 1 - mock_queryset.__iter__ = lambda _self: iter([mock_chapter]) - mock_queryset.__getitem__ = lambda _self, _key: [mock_chapter] - mock_chapter_objects.filter.return_value = mock_queryset - - with patch.object(command, "process_chunks_batch", return_value=1): - command.handle(chapter_key="test-chapter", all=False, batch_size=50) - - mock_chapter_objects.filter.assert_called_with(key="test-chapter") - - @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) - @patch("apps.ai.management.commands.ai_create_chapter_chunks.openai.OpenAI") - @patch("apps.ai.management.commands.ai_create_chapter_chunks.Chapter.objects") - def test_handle_with_all_flag(self, mock_chapter_objects, mock_openai, command, mock_chapter): - """Test command with --all flag.""" - command.stdout = MagicMock() - command.style = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 1 - mock_queryset.__iter__ = lambda _self: iter([mock_chapter]) - mock_queryset.__getitem__ = lambda _self, _key: [mock_chapter] - mock_chapter_objects.all.return_value = mock_queryset - - with patch.object(command, "process_chunks_batch", return_value=1): - command.handle(chapter_key=None, all=True, batch_size=50) - - mock_chapter_objects.all.assert_called_once() - - @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) - @patch("apps.ai.management.commands.ai_create_chapter_chunks.openai.OpenAI") - @patch("apps.ai.management.commands.ai_create_chapter_chunks.Chapter.objects") - def test_handle_default_active_chapters( - self, mock_chapter_objects, mock_openai, command, mock_chapter - ): - """Test command defaults to active chapters.""" - command.stdout = MagicMock() - command.style = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 1 - mock_queryset.__iter__ = lambda _self: iter([mock_chapter]) - mock_queryset.__getitem__ = lambda _self, _key: [mock_chapter] - mock_chapter_objects.filter.return_value = mock_queryset - - with patch.object(command, "process_chunks_batch", return_value=1): - command.handle(chapter_key=None, all=False, batch_size=50) - - mock_chapter_objects.filter.assert_called_with(is_active=True) - - @patch("apps.ai.management.commands.ai_create_chapter_chunks.ContentType.objects") - @patch("apps.ai.management.commands.ai_create_chapter_chunks.Context.objects") - @patch("apps.ai.management.commands.ai_create_chapter_chunks.extract_chapter_content") - @patch("apps.ai.management.commands.ai_create_chapter_chunks.Chunk.split_text") - @patch("apps.ai.management.commands.ai_create_chapter_chunks.create_chunks_and_embeddings") - @patch("apps.ai.management.commands.ai_create_chapter_chunks.Chunk.bulk_save") - def test_process_chunks_batch_success( - self, - mock_bulk_save, - mock_create_chunks, - mock_split_text, - mock_extract, - mock_context_objects, - mock_content_type, - command, - mock_chapter, - mock_context, - ): - """Test successful batch processing of chunks.""" - command.stdout = MagicMock() - command.openai_client = MagicMock() - - # Setup mocks - mock_content_type.get_for_model.return_value = MagicMock() - mock_context_objects.filter.return_value.first.return_value = mock_context - mock_extract.return_value = ("prose content", "metadata content") - mock_split_text.return_value = ["chunk1", "chunk2"] - mock_chunks = [Mock(), Mock()] - mock_create_chunks.return_value = mock_chunks - - result = command.process_chunks_batch([mock_chapter]) - - assert result == 1 - mock_extract.assert_called_once_with(mock_chapter) - mock_split_text.assert_called_once_with("prose content") - mock_create_chunks.assert_called_once() - mock_bulk_save.assert_called_once_with(mock_chunks) - - @patch("apps.ai.management.commands.ai_create_chapter_chunks.ContentType.objects") - @patch("apps.ai.management.commands.ai_create_chapter_chunks.Context.objects") - def test_process_chunks_batch_no_context( - self, - mock_context_objects, - mock_content_type, - command, - mock_chapter, - ): - """Test batch processing when no context is found.""" - command.stdout = MagicMock() - command.style = MagicMock() - - # Setup mocks - mock_content_type.get_for_model.return_value = MagicMock() - mock_context_objects.filter.return_value.first.return_value = None - - result = command.process_chunks_batch([mock_chapter]) + def test_model_class_property(self, command): + from apps.owasp.models.chapter import Chapter - assert result == 0 - command.style.WARNING.assert_called_once() + assert command.model_class == Chapter - @patch("apps.ai.management.commands.ai_create_chapter_chunks.ContentType.objects") - @patch("apps.ai.management.commands.ai_create_chapter_chunks.Context.objects") - @patch("apps.ai.management.commands.ai_create_chapter_chunks.extract_chapter_content") - def test_process_chunks_batch_no_content( - self, - mock_extract, - mock_context_objects, - mock_content_type, - command, - mock_chapter, - mock_context, - ): - """Test batch processing when no content is extracted.""" - command.stdout = MagicMock() + def test_entity_name_property(self, command): + assert command.entity_name == "chapter" - # Setup mocks - mock_content_type.get_for_model.return_value = MagicMock() - mock_context_objects.filter.return_value.first.return_value = mock_context - mock_extract.return_value = ("", "") + def test_entity_name_plural_property(self, command): + assert command.entity_name_plural == "chapters" - result = command.process_chunks_batch([mock_chapter]) + def test_key_field_name_property(self, command): + assert command.key_field_name == "key" - assert result == 0 - command.stdout.write.assert_any_call(f"No content to chunk for chapter {mock_chapter.key}") + def test_extract_content(self, command, mock_chapter): + with patch( + "apps.ai.management.commands.ai_create_chapter_chunks.extract_chapter_content" + ) as mock_extract: + mock_extract.return_value = ("prose content", "metadata content") + content = command.extract_content(mock_chapter) + assert content == ("prose content", "metadata content") + mock_extract.assert_called_once_with(mock_chapter) diff --git a/backend/tests/apps/ai/management/commands/ai_create_chapter_context_test.py b/backend/tests/apps/ai/management/commands/ai_create_chapter_context_test.py index 9951ac7ec6..c140fc6184 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_chapter_context_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_chapter_context_test.py @@ -1,6 +1,6 @@ """Tests for the ai_create_chapter_context Django management command.""" -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock, patch import pytest from django.core.management.base import BaseCommand @@ -34,177 +34,30 @@ def test_command_inheritance(self, command): """Test that the command inherits from BaseCommand.""" assert isinstance(command, BaseCommand) - def test_add_arguments(self, command): - """Test that the command adds the correct arguments.""" - parser = MagicMock() - command.add_arguments(parser) - - assert parser.add_argument.call_count == 3 - parser.add_argument.assert_any_call( - "--chapter-key", - type=str, - help="Process only the chapter with this key", - ) - parser.add_argument.assert_any_call( - "--all", - action="store_true", - help="Process all the chapters", - ) - parser.add_argument.assert_any_call( - "--batch-size", - type=int, - default=50, - help="Number of chapters to process in each batch", - ) - - @patch("apps.ai.management.commands.ai_create_chapter_context.Chapter.objects") - def test_handle_no_chapters_found(self, mock_chapter_objects, command): - """Test command when no chapters are found.""" - command.stdout = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 0 - mock_chapter_objects.filter.return_value = mock_queryset - - command.handle(chapter_key=None, all=False, batch_size=50) - - command.stdout.write.assert_called_with("No chapters found to process") - - @patch("apps.ai.management.commands.ai_create_chapter_context.Chapter.objects") - def test_handle_with_chapter_key(self, mock_chapter_objects, command, mock_chapter): - """Test command with specific chapter key.""" - command.stdout = MagicMock() - command.style = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 1 - mock_queryset.__iter__ = lambda _self: iter([mock_chapter]) - mock_queryset.__getitem__ = lambda _self, _key: [mock_chapter] - mock_chapter_objects.filter.return_value = mock_queryset - - with patch.object(command, "process_context_batch", return_value=1): - command.handle(chapter_key="test-chapter", all=False, batch_size=50) - - mock_chapter_objects.filter.assert_called_with(key="test-chapter") - - @patch("apps.ai.management.commands.ai_create_chapter_context.Chapter.objects") - def test_handle_with_all_flag(self, mock_chapter_objects, command, mock_chapter): - """Test command with --all flag.""" - command.stdout = MagicMock() - command.style = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 1 - mock_queryset.__iter__ = lambda _self: iter([mock_chapter]) - mock_queryset.__getitem__ = lambda _self, _key: [mock_chapter] - mock_chapter_objects.all.return_value = mock_queryset - - with patch.object(command, "process_context_batch", return_value=1): - command.handle(chapter_key=None, all=True, batch_size=50) - - mock_chapter_objects.all.assert_called_once() - - @patch("apps.ai.management.commands.ai_create_chapter_context.Chapter.objects") - def test_handle_default_active_chapters(self, mock_chapter_objects, command, mock_chapter): - """Test command defaults to active chapters.""" - command.stdout = MagicMock() - command.style = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 1 - mock_queryset.__iter__ = lambda _self: iter([mock_chapter]) - mock_queryset.__getitem__ = lambda _self, _key: [mock_chapter] - mock_chapter_objects.filter.return_value = mock_queryset - - with patch.object(command, "process_context_batch", return_value=1): - command.handle(chapter_key=None, all=False, batch_size=50) - - mock_chapter_objects.filter.assert_called_with(is_active=True) - - @patch("apps.ai.management.commands.ai_create_chapter_context.extract_chapter_content") - @patch("apps.ai.management.commands.ai_create_chapter_context.create_context") - def test_process_context_batch_success( - self, - mock_create_context, - mock_extract, - command, - mock_chapter, - ): - """Test successful batch processing of contexts.""" - command.stdout = MagicMock() - - # Setup mocks - mock_extract.return_value = ("prose content", "metadata content") - mock_create_context.return_value = True - - result = command.process_context_batch([mock_chapter]) - - assert result == 1 - mock_extract.assert_called_once_with(mock_chapter) - mock_create_context.assert_called_once_with( - content="metadata content\n\nprose content", - content_object=mock_chapter, - source="owasp_chapter", - ) - - @patch("apps.ai.management.commands.ai_create_chapter_context.extract_chapter_content") - @patch("apps.ai.management.commands.ai_create_chapter_context.create_context") - def test_process_context_batch_no_metadata( - self, - mock_create_context, - mock_extract, - command, - mock_chapter, - ): - """Test batch processing without metadata content.""" - command.stdout = MagicMock() - - # Setup mocks - mock_extract.return_value = ("prose content", "") - mock_create_context.return_value = True - - result = command.process_context_batch([mock_chapter]) - - assert result == 1 - mock_extract.assert_called_once_with(mock_chapter) - mock_create_context.assert_called_once_with( - content="prose content", - content_object=mock_chapter, - source="owasp_chapter", - ) - - @patch("apps.ai.management.commands.ai_create_chapter_context.extract_chapter_content") - def test_process_context_batch_no_content( - self, - mock_extract, - command, - mock_chapter, - ): - """Test batch processing when no content is extracted.""" - command.stdout = MagicMock() - - # Setup mocks - mock_extract.return_value = ("", "") - - result = command.process_context_batch([mock_chapter]) - - assert result == 0 - command.stdout.write.assert_any_call(f"No content for chapter {mock_chapter.key}") - - @patch("apps.ai.management.commands.ai_create_chapter_context.extract_chapter_content") - @patch("apps.ai.management.commands.ai_create_chapter_context.create_context") - def test_process_context_batch_create_context_fails( - self, - mock_create_context, - mock_extract, - command, - mock_chapter, - ): - """Test batch processing when create_context fails.""" - command.stdout = MagicMock() - command.style = MagicMock() - - # Setup mocks - mock_extract.return_value = ("prose content", "metadata content") - mock_create_context.return_value = False - - result = command.process_context_batch([mock_chapter]) - - assert result == 0 - command.style.ERROR.assert_called_once() + def test_model_class_property(self, command): + """Test the model_class property returns Chapter.""" + from apps.owasp.models.chapter import Chapter + + assert command.model_class == Chapter + + def test_entity_name_property(self, command): + """Test the entity_name property.""" + assert command.entity_name == "chapter" + + def test_entity_name_plural_property(self, command): + """Test the entity_name_plural property.""" + assert command.entity_name_plural == "chapters" + + def test_key_field_name_property(self, command): + """Test the key_field_name property.""" + assert command.key_field_name == "key" + + def test_extract_content(self, command, mock_chapter): + """Test content extraction from chapter.""" + with patch( + "apps.ai.management.commands.ai_create_chapter_context.extract_chapter_content" + ) as mock_extract: + mock_extract.return_value = ("prose content", "metadata content") + content = command.extract_content(mock_chapter) + assert content == ("prose content", "metadata content") + mock_extract.assert_called_once_with(mock_chapter) diff --git a/backend/tests/apps/ai/management/commands/ai_create_committee_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_create_committee_chunks_test.py index 0368bb2b5d..2ffb2cb098 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_committee_chunks_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_committee_chunks_test.py @@ -1,7 +1,4 @@ -"""Tests for the ai_create_committee_chunks Django management command.""" - -import os -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock, patch import pytest from django.core.management.base import BaseCommand @@ -11,144 +8,40 @@ @pytest.fixture def command(): - """Return a command instance.""" return Command() @pytest.fixture def mock_committee(): - """Return a mock Committee instance.""" committee = Mock() committee.id = 1 committee.key = "test-committee" return committee -@pytest.fixture -def mock_context(): - """Return a mock Context instance.""" - context = Mock() - context.id = 1 - return context - - class TestAiCreateCommitteeChunksCommand: - """Test suite for the ai_create_committee_chunks command.""" - - def test_command_help_text(self, command): - """Test that the command has the correct help text.""" - assert command.help == "Create chunks for OWASP committee data" - def test_command_inheritance(self, command): - """Test that the command inherits from BaseCommand.""" assert isinstance(command, BaseCommand) - def test_add_arguments(self, command): - """Test that the command adds the correct arguments.""" - parser = MagicMock() - command.add_arguments(parser) - - assert parser.add_argument.call_count == 3 - parser.add_argument.assert_any_call( - "--committee-key", - type=str, - help="Process only the committee with this key", - ) - parser.add_argument.assert_any_call( - "--all", - action="store_true", - help="Process all the committees", - ) - parser.add_argument.assert_any_call( - "--batch-size", - type=int, - default=50, - help="Number of committees to process in each batch", - ) - - @patch.dict(os.environ, {}, clear=True) - def test_handle_missing_openai_key(self, command): - """Test command fails when OpenAI API key is not set.""" - command.stdout = MagicMock() - command.style = MagicMock() - - command.handle() - - command.stdout.write.assert_called_once() - command.style.ERROR.assert_called_once_with( - "DJANGO_OPEN_AI_SECRET_KEY environment variable not set" - ) - - @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) - @patch("apps.ai.management.commands.ai_create_committee_chunks.openai.OpenAI") - @patch("apps.ai.management.commands.ai_create_committee_chunks.Committee.objects") - def test_handle_no_committees_found(self, mock_committee_objects, mock_openai, command): - """Test command when no committees are found.""" - command.stdout = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 0 - mock_committee_objects.filter.return_value = mock_queryset - - command.handle(committee_key=None, all=False, batch_size=50) - - command.stdout.write.assert_called_with("No committees found to process") - - @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) - @patch("apps.ai.management.commands.ai_create_committee_chunks.openai.OpenAI") - @patch("apps.ai.management.commands.ai_create_committee_chunks.Committee.objects") - def test_handle_with_committee_key( - self, mock_committee_objects, mock_openai, command, mock_committee - ): - """Test command with specific committee key.""" - command.stdout = MagicMock() - command.style = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 1 - mock_queryset.__iter__ = lambda _self: iter([mock_committee]) - mock_queryset.__getitem__ = lambda _self, _key: [mock_committee] - mock_committee_objects.filter.return_value = mock_queryset - - with patch.object(command, "process_chunks_batch", return_value=1): - command.handle(committee_key="test-committee", all=False, batch_size=50) - - mock_committee_objects.filter.assert_called_with(key="test-committee") - - @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) - @patch("apps.ai.management.commands.ai_create_committee_chunks.openai.OpenAI") - @patch("apps.ai.management.commands.ai_create_committee_chunks.Committee.objects") - def test_handle_with_all_flag( - self, mock_committee_objects, mock_openai, command, mock_committee - ): - """Test command with --all flag.""" - command.stdout = MagicMock() - command.style = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 1 - mock_queryset.__iter__ = lambda _self: iter([mock_committee]) - mock_queryset.__getitem__ = lambda _self, _key: [mock_committee] - mock_committee_objects.all.return_value = mock_queryset + def test_model_class_property(self, command): + from apps.owasp.models.committee import Committee - with patch.object(command, "process_chunks_batch", return_value=1): - command.handle(committee_key=None, all=True, batch_size=50) + assert command.model_class == Committee - mock_committee_objects.all.assert_called_once() + def test_entity_name_property(self, command): + assert command.entity_name == "committee" - @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) - @patch("apps.ai.management.commands.ai_create_committee_chunks.openai.OpenAI") - @patch("apps.ai.management.commands.ai_create_committee_chunks.Committee.objects") - def test_handle_default_active_committees( - self, mock_committee_objects, mock_openai, command, mock_committee - ): - """Test command defaults to active committees.""" - command.stdout = MagicMock() - command.style = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 1 - mock_queryset.__iter__ = lambda _self: iter([mock_committee]) - mock_queryset.__getitem__ = lambda _self, _key: [mock_committee] - mock_committee_objects.filter.return_value = mock_queryset + def test_entity_name_plural_property(self, command): + assert command.entity_name_plural == "committees" - with patch.object(command, "process_chunks_batch", return_value=1): - command.handle(committee_key=None, all=False, batch_size=50) + def test_key_field_name_property(self, command): + assert command.key_field_name == "key" - mock_committee_objects.filter.assert_called_with(is_active=True) + def test_extract_content(self, command, mock_committee): + with patch( + "apps.ai.management.commands.ai_create_committee_chunks.extract_committee_content" + ) as mock_extract: + mock_extract.return_value = ("prose content", "metadata content") + content = command.extract_content(mock_committee) + assert content == ("prose content", "metadata content") + mock_extract.assert_called_once_with(mock_committee) diff --git a/backend/tests/apps/ai/management/commands/ai_create_committee_context_test.py b/backend/tests/apps/ai/management/commands/ai_create_committee_context_test.py index b7f8edae9b..30308d1734 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_committee_context_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_committee_context_test.py @@ -1,9 +1,8 @@ -"""Tests for the ai_create_committee_context Django management command.""" +"""A command to update context for OWASP committee data.""" -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock, patch import pytest -from django.core.management.base import BaseCommand from apps.ai.management.commands.ai_create_committee_context import Command @@ -31,90 +30,35 @@ def test_command_help_text(self, command): assert command.help == "Update context for OWASP committee data" def test_command_inheritance(self, command): - """Test that the command inherits from BaseCommand.""" - assert isinstance(command, BaseCommand) - - def test_add_arguments(self, command): - """Test that the command adds the correct arguments.""" - parser = MagicMock() - command.add_arguments(parser) - - assert parser.add_argument.call_count == 3 - parser.add_argument.assert_any_call( - "--committee-key", - type=str, - help="Process only the committee with this key", - ) - parser.add_argument.assert_any_call( - "--all", - action="store_true", - help="Process all the committees", - ) - parser.add_argument.assert_any_call( - "--batch-size", - type=int, - default=50, - help="Number of committees to process in each batch", - ) - - @patch("apps.ai.management.commands.ai_create_committee_context.Committee.objects") - def test_handle_no_committees_found(self, mock_committee_objects, command): - """Test command when no committees are found.""" - command.stdout = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 0 - mock_committee_objects.filter.return_value = mock_queryset - - command.handle(committee_key=None, all=False, batch_size=50) - - command.stdout.write.assert_called_with("No committees found to process") - - @patch("apps.ai.management.commands.ai_create_committee_context.Committee.objects") - def test_handle_with_committee_key(self, mock_committee_objects, command, mock_committee): - """Test command with specific committee key.""" - command.stdout = MagicMock() - command.style = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 1 - mock_queryset.__iter__ = lambda _self: iter([mock_committee]) - mock_queryset.__getitem__ = lambda _self, _key: [mock_committee] - mock_committee_objects.filter.return_value = mock_queryset - - with patch.object(command, "process_context_batch", return_value=1): - command.handle(committee_key="test-committee", all=False, batch_size=50) - - mock_committee_objects.filter.assert_called_with(key="test-committee") - - @patch("apps.ai.management.commands.ai_create_committee_context.Committee.objects") - def test_handle_with_all_flag(self, mock_committee_objects, command, mock_committee): - """Test command with --all flag.""" - command.stdout = MagicMock() - command.style = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 1 - mock_queryset.__iter__ = lambda _self: iter([mock_committee]) - mock_queryset.__getitem__ = lambda _self, _key: [mock_committee] - mock_committee_objects.all.return_value = mock_queryset - - with patch.object(command, "process_context_batch", return_value=1): - command.handle(committee_key=None, all=True, batch_size=50) - - mock_committee_objects.all.assert_called_once() - - @patch("apps.ai.management.commands.ai_create_committee_context.Committee.objects") - def test_handle_default_active_committees( - self, mock_committee_objects, command, mock_committee - ): - """Test command defaults to active committees.""" - command.stdout = MagicMock() - command.style = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 1 - mock_queryset.__iter__ = lambda _self: iter([mock_committee]) - mock_queryset.__getitem__ = lambda _self, _key: [mock_committee] - mock_committee_objects.filter.return_value = mock_queryset - - with patch.object(command, "process_context_batch", return_value=1): - command.handle(committee_key=None, all=False, batch_size=50) - - mock_committee_objects.filter.assert_called_with(is_active=True) + """Test that the command inherits from BaseContextCommand.""" + from apps.ai.common.base import BaseContextCommand + + assert isinstance(command, BaseContextCommand) + + def test_model_class_property(self, command): + """Test the model_class property returns Committee.""" + from apps.owasp.models.committee import Committee + + assert command.model_class == Committee + + def test_entity_name_property(self, command): + """Test the entity_name property.""" + assert command.entity_name == "committee" + + def test_entity_name_plural_property(self, command): + """Test the entity_name_plural property.""" + assert command.entity_name_plural == "committees" + + def test_key_field_name_property(self, command): + """Test the key_field_name property.""" + assert command.key_field_name == "key" + + def test_extract_content(self, command, mock_committee): + """Test content extraction from committee.""" + with patch( + "apps.ai.management.commands.ai_create_committee_context.extract_committee_content" + ) as mock_extract: + mock_extract.return_value = ("prose content", "metadata content") + content = command.extract_content(mock_committee) + assert content == ("prose content", "metadata content") + mock_extract.assert_called_once_with(mock_committee) diff --git a/backend/tests/apps/ai/management/commands/ai_create_event_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_create_event_chunks_test.py index a2ea00dab6..4c76cef827 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_event_chunks_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_event_chunks_test.py @@ -1,7 +1,6 @@ """Tests for the ai_create_event_chunks Django management command.""" -import os -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock, patch import pytest from django.core.management.base import BaseCommand @@ -20,101 +19,62 @@ def mock_event(): """Return a mock Event instance.""" event = Mock() event.id = 1 - event.title = "test-event" + event.key = "test-event" return event class TestAiCreateEventChunksCommand: """Test suite for the ai_create_event_chunks command.""" - def test_command_help_text(self, command): - """Test that the command has the correct help text.""" - assert command.help == "Create chunks for OWASP event data" - def test_command_inheritance(self, command): """Test that the command inherits from BaseCommand.""" assert isinstance(command, BaseCommand) - def test_add_arguments(self, command): - """Test that the command adds the correct arguments.""" - parser = MagicMock() - command.add_arguments(parser) - - assert parser.add_argument.call_count == 3 - parser.add_argument.assert_any_call( - "--all", - action="store_true", - help="Process all the events", - ) - parser.add_argument.assert_any_call( - "--batch-size", - type=int, - default=50, - help="Number of events to process in each batch", - ) - - @patch.dict(os.environ, {}, clear=True) - def test_handle_missing_openai_key(self, command): - """Test command fails when OpenAI API key is not set.""" - command.stdout = MagicMock() - command.style = MagicMock() - - command.handle() - - command.stdout.write.assert_called_once() - command.style.ERROR.assert_called_once_with( - "DJANGO_OPEN_AI_SECRET_KEY environment variable not set" - ) - - @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) - @patch("apps.ai.management.commands.ai_create_event_chunks.openai.OpenAI") - @patch("apps.ai.management.commands.ai_create_event_chunks.Event.upcoming_events") - def test_handle_no_events_found(self, mock_upcoming_events, mock_openai, command): - """Test command when no events are found.""" - command.stdout = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 0 - mock_upcoming_events.return_value = mock_queryset - - command.handle(event_key=None, all=False, batch_size=50) - - command.stdout.write.assert_called_with("No events found to process") - - @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) - @patch("apps.ai.management.commands.ai_create_event_chunks.openai.OpenAI") - @patch("apps.ai.management.commands.ai_create_event_chunks.Event.objects") - def test_handle_with_all_flag(self, mock_event_objects, mock_openai, command, mock_event): - """Test command with --all flag.""" - command.stdout = MagicMock() - command.style = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 1 - mock_queryset.__iter__ = lambda _self: iter([mock_event]) - mock_queryset.__getitem__ = lambda _self, _key: [mock_event] - mock_event_objects.all.return_value = mock_queryset - - with patch.object(command, "process_chunks_batch", return_value=1): - command.handle(event_key=None, all=True, batch_size=50) - - mock_event_objects.all.assert_called_once() - - @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) - @patch("apps.ai.management.commands.ai_create_event_chunks.openai.OpenAI") - @patch("apps.ai.management.commands.ai_create_event_chunks.Event.upcoming_events") - def test_handle_default_future_events( - self, mock_upcoming_events, mock_openai, command, mock_event - ): - """Test command defaults to future events.""" - command.stdout = MagicMock() - command.style = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 1 - mock_queryset.__iter__ = lambda _self: iter([mock_event]) - mock_queryset.__getitem__ = lambda _self, _key: [mock_event] - mock_upcoming_events.return_value = mock_queryset - - with patch.object(command, "process_chunks_batch", return_value=1): - command.handle(event_key=None, all=False, batch_size=50) - - # Should filter for future events by default - mock_upcoming_events.assert_called() + def test_model_class_property(self, command): + """Test the model_class property returns Event.""" + from apps.owasp.models.event import Event + + assert command.model_class == Event + + def test_entity_name_property(self, command): + """Test the entity_name property.""" + assert command.entity_name == "event" + + def test_entity_name_plural_property(self, command): + """Test the entity_name_plural property.""" + assert command.entity_name_plural == "events" + + def test_key_field_name_property(self, command): + """Test the key_field_name property.""" + assert command.key_field_name == "key" + + def test_extract_content(self, command, mock_event): + """Test content extraction from event.""" + with patch( + "apps.ai.management.commands.ai_create_event_chunks.extract_event_content" + ) as mock_extract: + mock_extract.return_value = ("prose content", "metadata content") + content = command.extract_content(mock_event) + assert content == ("prose content", "metadata content") + mock_extract.assert_called_once_with(mock_event) + + def test_get_default_queryset(self, command): + """Test that the default queryset returns upcoming events.""" + with patch( + "apps.ai.management.commands.ai_create_event_chunks.Event.upcoming_events" + ) as mock_upcoming: + mock_queryset = Mock() + mock_upcoming.return_value = mock_queryset + result = command.get_default_queryset() + assert result == mock_queryset + mock_upcoming.assert_called_once() + + def test_get_base_queryset(self, command): + """Test get_base_queryset calls super().get_base_queryset().""" + with patch( + "apps.ai.common.base.BaseAICommand.get_base_queryset", + return_value="base_qs", + ) as mock_super: + result = command.get_base_queryset() + assert result == "base_qs" + mock_super.assert_called_once() diff --git a/backend/tests/apps/ai/management/commands/ai_create_event_context_test.py b/backend/tests/apps/ai/management/commands/ai_create_event_context_test.py index 9a39911c7f..65eed2e15f 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_event_context_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_event_context_test.py @@ -1,6 +1,6 @@ """Tests for the ai_create_event_context Django management command.""" -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock, patch import pytest from django.core.management.base import BaseCommand @@ -19,7 +19,7 @@ def mock_event(): """Return a mock Event instance.""" event = Mock() event.id = 1 - event.title = "test-event" + event.key = "test-event" return event @@ -34,63 +34,39 @@ def test_command_inheritance(self, command): """Test that the command inherits from BaseCommand.""" assert isinstance(command, BaseCommand) - def test_add_arguments(self, command): - """Test that the command adds the correct arguments.""" - parser = MagicMock() - command.add_arguments(parser) - - assert parser.add_argument.call_count == 3 - parser.add_argument.assert_any_call( - "--all", - action="store_true", - help="Process all the events", - ) - parser.add_argument.assert_any_call( - "--batch-size", - type=int, - default=50, - help="Number of events to process in each batch", - ) - - @patch("apps.ai.management.commands.ai_create_event_context.Event.upcoming_events") - def test_handle_no_events_found(self, mock_upcoming_events, command): - """Test command when no events are found.""" - command.stdout = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 0 - mock_upcoming_events.return_value = mock_queryset - - command.handle(event_key=None, all=False, batch_size=50) - - @patch("apps.ai.management.commands.ai_create_event_context.Event.objects") - def test_handle_with_all_flag(self, mock_event_objects, command, mock_event): - """Test command with --all flag.""" - command.stdout = MagicMock() - command.style = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 1 - mock_queryset.__iter__ = lambda _self: iter([mock_event]) - mock_queryset.__getitem__ = lambda _self, _key: [mock_event] - mock_event_objects.all.return_value = mock_queryset - - with patch.object(command, "process_context_batch", return_value=1): - command.handle(event_key=None, all=True, batch_size=50) - - mock_event_objects.all.assert_called_once() - - @patch("apps.ai.management.commands.ai_create_event_context.Event.upcoming_events") - def test_handle_default_future_events(self, mock_upcoming_events, command, mock_event): - """Test command defaults to future events.""" - command.stdout = MagicMock() - command.style = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 1 - mock_queryset.__iter__ = lambda _self: iter([mock_event]) - mock_queryset.__getitem__ = lambda _self, _key: [mock_event] - mock_upcoming_events.return_value = mock_queryset - - with patch.object(command, "process_context_batch", return_value=1): - command.handle(event_key=None, all=False, batch_size=50) - - # Should filter for future events by default - mock_upcoming_events.assert_called() + def test_model_class_property(self, command): + """Test the model_class property returns Event.""" + from apps.owasp.models.event import Event + + assert command.model_class == Event + + def test_entity_name_property(self, command): + """Test the entity_name property.""" + assert command.entity_name == "event" + + def test_entity_name_plural_property(self, command): + """Test the entity_name_plural property.""" + assert command.entity_name_plural == "events" + + def test_key_field_name_property(self, command): + """Test the key_field_name property.""" + assert command.key_field_name == "key" + + def test_extract_content(self, command, mock_event): + """Test content extraction from event.""" + with patch( + "apps.ai.management.commands.ai_create_event_context.extract_event_content" + ) as mock_extract: + mock_extract.return_value = ("prose content", "metadata content") + content = command.extract_content(mock_event) + assert content == ("prose content", "metadata content") + mock_extract.assert_called_once_with(mock_event) + + def test_get_default_queryset(self, command): + """Test that the default queryset returns upcoming events.""" + with patch("apps.owasp.models.event.Event.upcoming_events") as mock_upcoming: + mock_queryset = Mock() + mock_upcoming.return_value = mock_queryset + result = command.get_default_queryset() + assert result == mock_queryset + mock_upcoming.assert_called_once() diff --git a/backend/tests/apps/ai/management/commands/ai_create_project_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_create_project_chunks_test.py index ec4b9edfe4..4f2d377c34 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_project_chunks_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_project_chunks_test.py @@ -1,7 +1,4 @@ -"""Tests for the ai_create_project_chunks Django management command.""" - -import os -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock, patch import pytest from django.core.management.base import BaseCommand @@ -11,13 +8,11 @@ @pytest.fixture def command(): - """Return a command instance.""" return Command() @pytest.fixture def mock_project(): - """Return a mock Project instance.""" project = Mock() project.id = 1 project.key = "test-project" @@ -25,120 +20,33 @@ def mock_project(): class TestAiCreateProjectChunksCommand: - """Test suite for the ai_create_project_chunks command.""" - - def test_command_help_text(self, command): - """Test that the command has the correct help text.""" - assert command.help == "Create chunks for OWASP project data" - def test_command_inheritance(self, command): - """Test that the command inherits from BaseCommand.""" assert isinstance(command, BaseCommand) - def test_add_arguments(self, command): - """Test that the command adds the correct arguments.""" - parser = MagicMock() - command.add_arguments(parser) - - assert parser.add_argument.call_count == 3 - parser.add_argument.assert_any_call( - "--project-key", - type=str, - help="Process only the project with this key", - ) - parser.add_argument.assert_any_call( - "--all", - action="store_true", - help="Process all the projects", - ) - parser.add_argument.assert_any_call( - "--batch-size", - type=int, - default=50, - help="Number of projects to process in each batch", - ) - - @patch.dict(os.environ, {}, clear=True) - def test_handle_missing_openai_key(self, command): - """Test command fails when OpenAI API key is not set.""" - command.stdout = MagicMock() - command.style = MagicMock() - - command.handle() - - command.stdout.write.assert_called_once() - command.style.ERROR.assert_called_once_with( - "DJANGO_OPEN_AI_SECRET_KEY environment variable not set" - ) - - @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) - @patch("apps.ai.management.commands.ai_create_project_chunks.openai.OpenAI") - @patch("apps.ai.management.commands.ai_create_project_chunks.Project.objects") - def test_handle_no_projects_found(self, mock_project_objects, mock_openai, command): - """Test command when no projects are found.""" - command.stdout = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 0 - mock_project_objects.filter.return_value = mock_queryset - - command.handle(project_key=None, all=False, batch_size=50) - - command.stdout.write.assert_called_with("No projects found to process") - - @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) - @patch("apps.ai.management.commands.ai_create_project_chunks.openai.OpenAI") - @patch("apps.ai.management.commands.ai_create_project_chunks.Project.objects") - def test_handle_with_project_key( - self, mock_project_objects, mock_openai, command, mock_project - ): - """Test command with specific project key.""" - command.stdout = MagicMock() - command.style = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 1 - mock_queryset.__iter__ = lambda _self: iter([mock_project]) - mock_queryset.__getitem__ = lambda _self, _key: [mock_project] - mock_project_objects.filter.return_value = mock_queryset - - with patch.object(command, "process_chunks_batch", return_value=1): - command.handle(project_key="test-project", all=False, batch_size=50) - - mock_project_objects.filter.assert_called_with(key="test-project") + def test_model_class_property(self, command): + from apps.owasp.models.project import Project - @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) - @patch("apps.ai.management.commands.ai_create_project_chunks.openai.OpenAI") - @patch("apps.ai.management.commands.ai_create_project_chunks.Project.objects") - def test_handle_with_all_flag(self, mock_project_objects, mock_openai, command, mock_project): - """Test command with --all flag.""" - command.stdout = MagicMock() - command.style = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 1 - mock_queryset.__iter__ = lambda _self: iter([mock_project]) - mock_queryset.__getitem__ = lambda _self, _key: [mock_project] - mock_project_objects.all.return_value = mock_queryset + assert command.model_class == Project - with patch.object(command, "process_chunks_batch", return_value=1): - command.handle(project_key=None, all=True, batch_size=50) + def test_entity_name_property(self, command): + assert command.entity_name == "project" - mock_project_objects.all.assert_called_once() + def test_entity_name_plural_property(self, command): + assert command.entity_name_plural == "projects" - @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) - @patch("apps.ai.management.commands.ai_create_project_chunks.openai.OpenAI") - @patch("apps.ai.management.commands.ai_create_project_chunks.Project.objects") - def test_handle_default_active_projects( - self, mock_project_objects, mock_openai, command, mock_project - ): - """Test command defaults to active projects.""" - command.stdout = MagicMock() - command.style = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 1 - mock_queryset.__iter__ = lambda _self: iter([mock_project]) - mock_queryset.__getitem__ = lambda _self, _key: [mock_project] - mock_project_objects.filter.return_value = mock_queryset + def test_key_field_name_property(self, command): + assert command.key_field_name == "key" - with patch.object(command, "process_chunks_batch", return_value=1): - command.handle(project_key=None, all=False, batch_size=50) + def test_extract_content(self, command, mock_project): + with patch( + "apps.ai.management.commands.ai_create_project_chunks.extract_project_content" + ) as mock_extract: + mock_extract.return_value = ("prose content", "metadata content") + content = command.extract_content(mock_project) + assert content == ("prose content", "metadata content") + mock_extract.assert_called_once_with(mock_project) - mock_project_objects.filter.assert_called_with(is_active=True) + def test_get_base_queryset_calls_super(self, command): + with patch("apps.ai.common.base.BaseChunkCommand.get_base_queryset") as mock_super: + command.get_base_queryset() + mock_super.assert_called_once() diff --git a/backend/tests/apps/ai/management/commands/ai_create_project_context_test.py b/backend/tests/apps/ai/management/commands/ai_create_project_context_test.py index effed75999..4dd616ea23 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_project_context_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_project_context_test.py @@ -1,6 +1,4 @@ -"""Tests for the ai_create_project_context Django management command.""" - -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock, patch import pytest from django.core.management.base import BaseCommand @@ -26,91 +24,34 @@ def mock_project(): class TestAiCreateProjectContextCommand: """Test suite for the ai_create_project_context command.""" - def test_command_help_text(self, command): - """Test that the command has the correct help text.""" - assert command.help == "Update context for OWASP project data" - def test_command_inheritance(self, command): """Test that the command inherits from BaseCommand.""" assert isinstance(command, BaseCommand) - def test_add_arguments(self, command): - """Test that the command adds the correct arguments.""" - parser = MagicMock() - command.add_arguments(parser) - - assert parser.add_argument.call_count == 3 - parser.add_argument.assert_any_call( - "--project-key", - type=str, - help="Process only the project with this key", - ) - parser.add_argument.assert_any_call( - "--all", - action="store_true", - help="Process all the projects", - ) - parser.add_argument.assert_any_call( - "--batch-size", - type=int, - default=50, - help="Number of projects to process in each batch", - ) - - @patch("apps.ai.management.commands.ai_create_project_context.Project.objects") - def test_handle_no_projects_found(self, mock_project_objects, command): - """Test command when no projects are found.""" - command.stdout = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 0 - mock_project_objects.filter.return_value = mock_queryset - - command.handle(project_key=None, all=False, batch_size=50) - - @patch("apps.ai.management.commands.ai_create_project_context.Project.objects") - def test_handle_with_project_key(self, mock_project_objects, command, mock_project): - """Test command with specific project key.""" - command.stdout = MagicMock() - command.style = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 1 - mock_queryset.__iter__ = lambda _self: iter([mock_project]) - mock_queryset.__getitem__ = lambda _self, _key: [mock_project] - mock_project_objects.filter.return_value = mock_queryset - - with patch.object(command, "process_context_batch", return_value=1): - command.handle(project_key="test-project", all=False, batch_size=50) - - mock_project_objects.filter.assert_called_with(key="test-project") - - @patch("apps.ai.management.commands.ai_create_project_context.Project.objects") - def test_handle_with_all_flag(self, mock_project_objects, command, mock_project): - """Test command with --all flag.""" - command.stdout = MagicMock() - command.style = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 1 - mock_queryset.__iter__ = lambda _self: iter([mock_project]) - mock_queryset.__getitem__ = lambda _self, _key: [mock_project] - mock_project_objects.all.return_value = mock_queryset - - with patch.object(command, "process_context_batch", return_value=1): - command.handle(project_key=None, all=True, batch_size=50) - - mock_project_objects.all.assert_called_once() - - @patch("apps.ai.management.commands.ai_create_project_context.Project.objects") - def test_handle_default_active_projects(self, mock_project_objects, command, mock_project): - """Test command defaults to active projects.""" - command.stdout = MagicMock() - command.style = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 1 - mock_queryset.__iter__ = lambda _self: iter([mock_project]) - mock_queryset.__getitem__ = lambda _self, _key: [mock_project] - mock_project_objects.filter.return_value = mock_queryset - - with patch.object(command, "process_context_batch", return_value=1): - command.handle(project_key=None, all=False, batch_size=50) - - mock_project_objects.filter.assert_called_with(is_active=True) + def test_model_class_property(self, command): + """Test the model_class property returns Project.""" + from apps.owasp.models.project import Project + + assert command.model_class == Project + + def test_entity_name_property(self, command): + """Test the entity_name property.""" + assert command.entity_name == "project" + + def test_entity_name_plural_property(self, command): + """Test the entity_name_plural property.""" + assert command.entity_name_plural == "projects" + + def test_key_field_name_property(self, command): + """Test the key_field_name property.""" + assert command.key_field_name == "key" + + def test_extract_content(self, command, mock_project): + """Test content extraction from project.""" + with patch( + "apps.ai.management.commands.ai_create_project_context.extract_project_content" + ) as mock_extract: + mock_extract.return_value = ("prose content", "metadata content") + content = command.extract_content(mock_project) + assert content == ("prose content", "metadata content") + mock_extract.assert_called_once_with(mock_project) diff --git a/backend/tests/apps/ai/management/commands/ai_create_slack_message_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_create_slack_message_chunks_test.py index 217bdd46cf..a6cbae5df1 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_slack_message_chunks_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_slack_message_chunks_test.py @@ -1,7 +1,4 @@ -"""Tests for the ai_create_slack_message_chunks Django management command.""" - -import os -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, Mock import pytest from django.core.management.base import BaseCommand @@ -20,21 +17,52 @@ def mock_message(): """Return a mock Message instance.""" message = Mock() message.id = 1 - message.text = "test message" + message.slack_message_id = "test-message-123" + message.cleaned_text = "This is a test Slack message content." return message class TestAiCreateSlackMessageChunksCommand: """Test suite for the ai_create_slack_message_chunks command.""" - def test_command_help_text(self, command): - """Test that the command has the correct help text.""" - assert command.help == "Create chunks for Slack messages" - def test_command_inheritance(self, command): """Test that the command inherits from BaseCommand.""" assert isinstance(command, BaseCommand) + def test_model_class_property(self, command): + """Test the model_class property returns Message.""" + from apps.slack.models.message import Message + + assert command.model_class == Message + + def test_entity_name_property(self, command): + """Test the entity_name property.""" + assert command.entity_name == "message" + + def test_entity_name_plural_property(self, command): + """Test the entity_name_plural property.""" + assert command.entity_name_plural == "messages" + + def test_key_field_name_property(self, command): + """Test the key_field_name property.""" + assert command.key_field_name == "slack_message_id" + + def test_source_name_property(self, command): + """Test the source_name property.""" + assert command.source_name == "slack_message" + + def test_extract_content(self, command, mock_message): + """Test content extraction from message.""" + content = command.extract_content(mock_message) + assert content == ("This is a test Slack message content.", "") + + def test_extract_content_empty_text(self, command): + """Test content extraction when message has no cleaned_text.""" + message = Mock() + message.cleaned_text = None + content = command.extract_content(message) + assert content == ("", "") + def test_add_arguments(self, command): """Test that the command adds the correct arguments.""" parser = MagicMock() @@ -42,107 +70,18 @@ def test_add_arguments(self, command): assert parser.add_argument.call_count == 3 parser.add_argument.assert_any_call( - "--batch-size", - type=int, - default=100, - help="Number of messages to process in each batch", + "--message-key", + type=str, + help="Process only the message with this key", ) parser.add_argument.assert_any_call( - "--context", + "--all", action="store_true", - help="Create only context (skip chunks and embeddings)", + help="Process all the messages", ) parser.add_argument.assert_any_call( - "--chunks", - action="store_true", - help="Create only chunks+embeddings (requires existing context)", - ) - - def test_handle_no_options_specified(self, command): - """Test command with no context or chunks options.""" - command.stdout = MagicMock() - command.style = MagicMock() - - command.handle(batch_size=100, context=False, chunks=False) - - command.style.ERROR.assert_called_once_with( - "Please specify either --context or --chunks (or both)" + "--batch-size", + type=int, + default=100, + help="Number of messages to process in each batch", ) - - @patch.dict(os.environ, {}, clear=True) - def test_handle_chunks_missing_openai_key(self, command): - """Test command with --chunks flag but no OpenAI key.""" - command.stdout = MagicMock() - command.style = MagicMock() - - command.handle(batch_size=100, context=False, chunks=True) - - @patch("apps.ai.management.commands.ai_create_slack_message_chunks.Message.objects") - def test_handle_context_only(self, mock_message_objects, command, mock_message): - """Test command with --context flag only.""" - command.stdout = MagicMock() - command.style = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 1 - mock_queryset.__iter__ = lambda _self: iter([mock_message]) - mock_queryset.__getitem__ = lambda _self, _key: [mock_message] - mock_message_objects.filter.return_value = mock_queryset - - with patch.object(command, "process_context_batch", return_value=1): - command.handle(batch_size=100, context=True, chunks=False) - - command.style.SUCCESS.assert_called() - - @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) - @patch("apps.ai.management.commands.ai_create_slack_message_chunks.openai.OpenAI") - @patch("apps.ai.management.commands.ai_create_slack_message_chunks.Message.objects") - def test_handle_chunks_only(self, mock_message_objects, mock_openai, command, mock_message): - """Test command with --chunks flag only.""" - command.stdout = MagicMock() - command.style = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 1 - mock_queryset.__iter__ = lambda _self: iter([mock_message]) - mock_queryset.__getitem__ = lambda _self, _key: [mock_message] - mock_message_objects.filter.return_value = mock_queryset - - with patch.object(command, "process_chunks_batch", return_value=1): - command.handle(batch_size=100, context=False, chunks=True) - - command.style.SUCCESS.assert_called() - - @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}) - @patch("apps.ai.management.commands.ai_create_slack_message_chunks.openai.OpenAI") - @patch("apps.ai.management.commands.ai_create_slack_message_chunks.Message.objects") - def test_handle_both_context_and_chunks( - self, mock_message_objects, mock_openai, command, mock_message - ): - """Test command with both --context and --chunks flags.""" - command.stdout = MagicMock() - command.style = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 1 - mock_queryset.__iter__ = lambda _self: iter([mock_message]) - mock_queryset.__getitem__ = lambda _self, _key: [mock_message] - mock_message_objects.filter.return_value = mock_queryset - - with ( - patch.object(command, "process_context_batch", return_value=1), - patch.object(command, "process_chunks_batch", return_value=1), - ): - command.handle(batch_size=100, context=True, chunks=True) - - # Should be called once since it uses elif logic (context takes precedence) - assert command.style.SUCCESS.call_count == 1 - - @patch("apps.ai.management.commands.ai_create_slack_message_chunks.Message.objects") - def test_handle_no_messages_found(self, mock_message_objects, command): - """Test command when no messages are found.""" - command.stdout = MagicMock() - mock_queryset = MagicMock() - mock_queryset.count.return_value = 0 - mock_message_objects.all.return_value = mock_queryset - - command.handle(batch_size=100, context=True, chunks=False) - - command.stdout.write.assert_called_with("No messages found to process") diff --git a/backend/tests/apps/ai/management/commands/ai_create_slack_message_context_test.py b/backend/tests/apps/ai/management/commands/ai_create_slack_message_context_test.py new file mode 100644 index 0000000000..93d805961b --- /dev/null +++ b/backend/tests/apps/ai/management/commands/ai_create_slack_message_context_test.py @@ -0,0 +1,74 @@ +from unittest.mock import MagicMock, Mock + +import pytest +from django.core.management.base import BaseCommand + +from apps.ai.management.commands.ai_create_slack_message_context import Command + + +@pytest.fixture +def command(): + return Command() + + +@pytest.fixture +def mock_message(): + message = Mock() + message.id = 1 + message.pk = 1 + message.slack_message_id = "test-message-123" + message.cleaned_text = "This is a test Slack message content." + return message + + +class TestAiCreateSlackMessageContextCommand: + def test_command_inheritance(self, command): + assert isinstance(command, BaseCommand) + + def test_model_class_property(self, command): + from apps.slack.models.message import Message + + assert command.model_class == Message + + def test_entity_name_property(self, command): + assert command.entity_name == "message" + + def test_entity_name_plural_property(self, command): + assert command.entity_name_plural == "messages" + + def test_key_field_name_property(self, command): + assert command.key_field_name == "slack_message_id" + + def test_source_name_property(self, command): + assert command.source_name == "slack_message" + + def test_extract_content(self, command, mock_message): + content = command.extract_content(mock_message) + assert content == ("This is a test Slack message content.", "") + + def test_extract_content_empty_text(self, command): + message = Mock() + message.cleaned_text = None + content = command.extract_content(message) + assert content == ("", "") + + def test_add_arguments(self, command): + parser = MagicMock() + command.add_arguments(parser) + assert parser.add_argument.call_count == 3 + parser.add_argument.assert_any_call( + "--message-key", + type=str, + help="Process only the message with this key", + ) + parser.add_argument.assert_any_call( + "--all", + action="store_true", + help="Process all the messages", + ) + parser.add_argument.assert_any_call( + "--batch-size", + type=int, + default=100, + help="Number of messages to process in each batch", + ) diff --git a/backend/tests/apps/ai/management/commands/ai_run_rag_tool_test.py b/backend/tests/apps/ai/management/commands/ai_run_rag_tool_test.py index 7b23b9b834..33017e3169 100644 --- a/backend/tests/apps/ai/management/commands/ai_run_rag_tool_test.py +++ b/backend/tests/apps/ai/management/commands/ai_run_rag_tool_test.py @@ -46,7 +46,7 @@ def test_add_arguments(self, command): parser.add_argument.assert_any_call( "--threshold", type=float, - default=0.5, # DEFAULT_SIMILARITY_THRESHOLD + default=0.4, # DEFAULT_SIMILARITY_THRESHOLD help="Similarity threshold (0.0 to 1.0)", ) parser.add_argument.assert_any_call( diff --git a/backend/tests/apps/ai/models/context_test.py b/backend/tests/apps/ai/models/context_test.py index 4bed7c7de5..5648638718 100644 --- a/backend/tests/apps/ai/models/context_test.py +++ b/backend/tests/apps/ai/models/context_test.py @@ -16,56 +16,6 @@ def create_model_mock(model_class): class TestContextModel: - def test_str_method_without_content_type(self): - context = Context() - context.id = 1 - context.content = "Sample text without content type" - context.content_type = None - context.content_object = None - - result = str(context) - - assert result == "None None: Sample text without content type" - - def test_str_method_with_text_truncation(self): - long_text = "A" * 100 - - context = Context() - context.id = 1 - context.content = long_text - context.content_type = None - context.content_object = None - - result = str(context) - - assert result == f"None None: {long_text[:50]}" - assert len(result.split(": ", 1)[1]) == 50 - - def test_str_method_with_exactly_50_chars(self): - text_50_chars = "A" * 50 - - context = Context() - context.id = 1 - context.content = text_50_chars - context.content_type = None - context.content_object = None - - result = str(context) - - assert result == f"None None: {text_50_chars}" - assert len(result.split(": ", 1)[1]) == 50 - - def test_str_method_with_empty_text(self): - context = Context() - context.id = 1 - context.content = "" - context.content_type = None - context.content_object = None - - result = str(context) - - assert result == "None None: " - def test_meta_class_attributes(self): assert Context._meta.db_table == "ai_contexts" assert Context._meta.verbose_name == "Context" @@ -77,14 +27,13 @@ def test_content_field_properties(self): def test_content_type_field_properties(self): field = Context._meta.get_field("content_type") - assert field.null is True - assert field.blank is True + assert field.null is False + assert field.blank is False assert hasattr(field, "remote_field") assert field.remote_field.on_delete.__name__ == "CASCADE" def test_object_id_field_properties(self): field = Context._meta.get_field("object_id") - assert field.default == 0 assert field.__class__.__name__ == "PositiveIntegerField" def test_source_field_properties(self): @@ -93,12 +42,6 @@ def test_source_field_properties(self): assert field.blank is True assert field.default == "" - def test_content_object_generic_foreign_key(self): - field = Context._meta.get_field("content_object") - assert field.__class__.__name__ == "GenericForeignKey" - assert field.ct_field == "content_type" - assert field.fk_field == "object_id" - @patch("apps.ai.models.context.Context.save") @patch("apps.ai.models.context.Context.__init__") def test_context_creation_with_save(self, mock_init, mock_save): @@ -177,10 +120,7 @@ def test_context_validation_source_too_long(self, mock_full_clean): def test_context_default_values(self): context = Context() - assert context.object_id == 0 assert context.source == "" - assert context.content_type is None - assert context.content_object is None @patch("apps.ai.models.context.Context.refresh_from_db") def test_context_refresh_from_db(self, mock_refresh): @@ -195,3 +135,26 @@ def test_context_delete(self, mock_delete): context.delete() mock_delete.assert_called_once() + + @patch("apps.ai.models.context.Context.objects.filter") + def test_update_data_existing_context(self, mock_filter): + mock_context = create_model_mock(Context) + mock_filter.return_value.first.return_value = mock_context + + content = "Test" + mock_content_object = Mock() + mock_content_object.pk = 1 + + with patch( + "apps.ai.models.context.ContentType.objects.get_for_model" + ) as mock_get_for_model: + mock_content_type = Mock() + mock_get_for_model.return_value = mock_content_type + + result = Context.update_data(content, mock_content_object, source="src", save=True) + + mock_get_for_model.assert_called_once_with(mock_content_object) + mock_filter.assert_called_once_with( + content_type=mock_content_type, object_id=1, content=content + ) + assert result == mock_context From c709b9e8f11ee0a1695cdb73667a57c366a7426d Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Wed, 13 Aug 2025 14:08:27 +0530 Subject: [PATCH 17/32] suggestions implemented --- backend/apps/ai/common/base.py | 257 ------- backend/apps/ai/common/base/__init__.py | 7 + backend/apps/ai/common/base/ai_command.py | 125 ++++ backend/apps/ai/common/base/chunk_command.py | 81 +++ .../apps/ai/common/base/context_command.py | 54 ++ backend/apps/ai/common/extractors/chapter.py | 8 +- backend/apps/ai/common/utils.py | 36 +- .../commands/ai_create_chapter_chunks.py | 18 +- .../commands/ai_create_chapter_context.py | 18 +- .../commands/ai_create_committee_chunks.py | 18 +- .../commands/ai_create_committee_context.py | 18 +- .../commands/ai_create_event_chunks.py | 28 +- .../commands/ai_create_event_context.py | 28 +- .../commands/ai_create_project_chunks.py | 22 +- .../commands/ai_create_project_context.py | 22 +- .../ai_create_slack_message_chunks.py | 45 +- .../ai_create_slack_message_context.py | 45 +- backend/apps/ai/models/chunk.py | 7 +- backend/tests/apps/ai/agent/__init__.py | 1 + backend/tests/apps/ai/agent/tools/__init__.py | 1 + .../tests/apps/ai/agent/tools/rag/__init__.py | 1 + .../apps/ai/agent/tools/rag/generator_test.py | 200 ++++++ .../apps/ai/agent/tools/rag/rag_tool_test.py | 179 +++++ .../apps/ai/agent/tools/rag/retriever_test.py | 577 +++++++++++++++ backend/tests/apps/ai/common/base/__init__.py | 0 .../apps/ai/common/base/ai_command_test.py | 322 +++++++++ .../apps/ai/common/base/chunk_command_test.py | 436 ++++++++++++ .../ai/common/base/context_command_test.py | 326 +++++++++ backend/tests/apps/ai/common/base_test.py | 664 ------------------ .../apps/ai/common/extractors/__init__.py | 0 .../apps/ai/common/extractors/chapter_test.py | 299 ++++++++ .../ai/common/extractors/committee_test.py | 202 ++++++ .../apps/ai/common/extractors/event_test.py | 208 ++++++ .../apps/ai/common/extractors/project_test.py | 441 ++++++++++++ backend/tests/apps/ai/common/utils_test.py | 126 +++- .../commands/ai_create_chapter_chunks_test.py | 8 +- .../ai_create_chapter_context_test.py | 25 +- .../ai_create_committee_chunks_test.py | 120 +++- .../ai_create_committee_context_test.py | 204 +++++- .../commands/ai_create_event_chunks_test.py | 30 +- .../commands/ai_create_event_context_test.py | 33 +- .../commands/ai_create_project_chunks_test.py | 17 +- .../ai_create_project_context_test.py | 33 +- .../ai_create_slack_message_chunks_test.py | 80 +-- .../ai_create_slack_message_context_test.py | 81 ++- backend/tests/apps/ai/models/chunk_test.py | 42 +- backend/tests/apps/ai/models/context_test.py | 120 ++++ 47 files changed, 4297 insertions(+), 1316 deletions(-) delete mode 100644 backend/apps/ai/common/base.py create mode 100644 backend/apps/ai/common/base/__init__.py create mode 100644 backend/apps/ai/common/base/ai_command.py create mode 100644 backend/apps/ai/common/base/chunk_command.py create mode 100644 backend/apps/ai/common/base/context_command.py create mode 100644 backend/tests/apps/ai/agent/__init__.py create mode 100644 backend/tests/apps/ai/agent/tools/__init__.py create mode 100644 backend/tests/apps/ai/agent/tools/rag/__init__.py create mode 100644 backend/tests/apps/ai/agent/tools/rag/generator_test.py create mode 100644 backend/tests/apps/ai/agent/tools/rag/rag_tool_test.py create mode 100644 backend/tests/apps/ai/agent/tools/rag/retriever_test.py create mode 100644 backend/tests/apps/ai/common/base/__init__.py create mode 100644 backend/tests/apps/ai/common/base/ai_command_test.py create mode 100644 backend/tests/apps/ai/common/base/chunk_command_test.py create mode 100644 backend/tests/apps/ai/common/base/context_command_test.py delete mode 100644 backend/tests/apps/ai/common/base_test.py create mode 100644 backend/tests/apps/ai/common/extractors/__init__.py create mode 100644 backend/tests/apps/ai/common/extractors/chapter_test.py create mode 100644 backend/tests/apps/ai/common/extractors/committee_test.py create mode 100644 backend/tests/apps/ai/common/extractors/event_test.py create mode 100644 backend/tests/apps/ai/common/extractors/project_test.py diff --git a/backend/apps/ai/common/base.py b/backend/apps/ai/common/base.py deleted file mode 100644 index 45181eecfe..0000000000 --- a/backend/apps/ai/common/base.py +++ /dev/null @@ -1,257 +0,0 @@ -"""Base classes for AI management commands.""" - -import os -from abc import ABC, abstractmethod -from collections.abc import Callable -from typing import Any - -import openai -from django.contrib.contenttypes.models import ContentType -from django.core.management.base import BaseCommand -from django.db.models import Model, QuerySet - -from apps.ai.common.utils import create_chunks_and_embeddings, create_context -from apps.ai.models.chunk import Chunk -from apps.ai.models.context import Context - - -class BaseAICommand(BaseCommand, ABC): - """Base class for AI management commands with common functionality.""" - - def __init__(self, *args, **kwargs): - """Initialize the AI command with OpenAI client placeholder.""" - super().__init__(*args, **kwargs) - self.openai_client: openai.OpenAI | None = None - - @property - @abstractmethod - def model_class(self) -> type[Model]: - """Return the Django model class this command operates on.""" - - @property - @abstractmethod - def entity_name(self) -> str: - """Return the human-readable name for the entity (e.g., 'chapter', 'project').""" - - @property - @abstractmethod - def entity_name_plural(self) -> str: - """Return the plural form of the entity name.""" - - @property - @abstractmethod - def key_field_name(self) -> str: - """Return the field name used for filtering by key (e.g., 'key', 'slug').""" - - @abstractmethod - def extract_content(self, entity: Model) -> tuple[str, str]: - """Extract content from the entity. Return (prose_content, metadata_content).""" - - @property - def source_name(self) -> str: - """Return the source name for context creation. Override if different from default.""" - return f"owasp_{self.entity_name}" - - def get_base_queryset(self) -> QuerySet: - """Return the base queryset. Override for custom filtering logic.""" - return self.model_class.objects.all() - - def get_default_queryset(self) -> QuerySet: - """Return the default queryset when no specific options are provided.""" - return self.get_base_queryset().filter(is_active=True) - - def add_common_arguments(self, parser): - """Add common arguments that most commands need.""" - parser.add_argument( - f"--{self.entity_name}-key", - type=str, - help=f"Process only the {self.entity_name} with this key", - ) - parser.add_argument( - "--all", - action="store_true", - help=f"Process all the {self.entity_name_plural}", - ) - parser.add_argument( - "--batch-size", - type=int, - default=50, - help=f"Number of {self.entity_name_plural} to process in each batch", - ) - - def add_arguments(self, parser): - """Add arguments to the command. Override to add custom arguments.""" - self.add_common_arguments(parser) - - def get_queryset(self, options: dict[str, Any]) -> QuerySet: - """Get the queryset based on command options.""" - key_option = f"{self.entity_name}_key" - - if options.get(key_option): - filter_kwargs = {self.key_field_name: options[key_option]} - return self.get_base_queryset().filter(**filter_kwargs) - if options.get("all"): - return self.get_base_queryset() - return self.get_default_queryset() - - def get_entity_key(self, entity: Model) -> str: - """Get the key/identifier for an entity for display purposes.""" - return str(getattr(entity, self.key_field_name, entity.pk)) - - def setup_openai_client(self) -> bool: - """Set up OpenAI client if API key is available.""" - if openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY"): - self.openai_client = openai.OpenAI(api_key=openai_api_key) - return True - self.stdout.write( - self.style.ERROR("DJANGO_OPEN_AI_SECRET_KEY environment variable not set") - ) - return False - - def handle_batch_processing( - self, - queryset: QuerySet, - batch_size: int, - process_batch_func: Callable[[list[Model]], int], - ) -> None: - """Handle the common batch processing logic.""" - total_count = queryset.count() - - if not total_count: - self.stdout.write(f"No {self.entity_name_plural} found to process") - return - - self.stdout.write(f"Found {total_count} {self.entity_name_plural} to process") - - processed_count = 0 - for offset in range(0, total_count, batch_size): - batch_items = queryset[offset : offset + batch_size] - processed_count += process_batch_func(list(batch_items)) - - self.stdout.write( - self.style.SUCCESS( - f"Completed processing {processed_count}/{total_count} {self.entity_name_plural}" - ) - ) - - -class BaseContextCommand(BaseAICommand): - """Base class for context creation commands.""" - - @property - def help(self) -> str: - """Return help text for the context creation command.""" - return f"Update context for OWASP {self.entity_name} data" - - def process_context_batch(self, entities: list[Model]) -> int: - """Process a batch of entities to create contexts.""" - processed = 0 - - for entity in entities: - prose_content, metadata_content = self.extract_content(entity) - full_content = ( - f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content - ) - - if not full_content.strip(): - entity_key = self.get_entity_key(entity) - self.stdout.write(f"No content for {self.entity_name} {entity_key}") - continue - - if create_context( - content=full_content, - content_object=entity, - source=self.source_name, - ): - processed += 1 - entity_key = self.get_entity_key(entity) - self.stdout.write(f"Created context for {entity_key}") - else: - entity_key = self.get_entity_key(entity) - self.stdout.write(self.style.ERROR(f"Failed to create context for {entity_key}")) - - return processed - - def handle(self, *args, **options): - """Handle the context creation command.""" - queryset = self.get_queryset(options) - batch_size = options["batch_size"] - - self.handle_batch_processing( - queryset=queryset, - batch_size=batch_size, - process_batch_func=self.process_context_batch, - ) - - -class BaseChunkCommand(BaseAICommand): - """Base class for chunk creation commands.""" - - @property - def help(self) -> str: - """Return help text for the chunk creation command.""" - return f"Create chunks for OWASP {self.entity_name} data" - - def process_chunks_batch(self, entities: list[Model]) -> int: - """Process a batch of entities to create chunks.""" - processed = 0 - batch_chunks = [] - content_type = ContentType.objects.get_for_model(self.model_class) - - for entity in entities: - context = Context.objects.filter( - content_type=content_type, object_id=entity.id - ).first() - - entity_key = self.get_entity_key(entity) - - if not context: - self.stdout.write( - self.style.WARNING(f"No context found for {self.entity_name} {entity_key}") - ) - continue - - prose_content, metadata_content = self.extract_content(entity) - full_content = ( - f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content - ) - - if not full_content.strip(): - self.stdout.write(f"No content to chunk for {self.entity_name} {entity_key}") - continue - - chunk_texts = Chunk.split_text(full_content) - if not chunk_texts: - self.stdout.write( - f"No chunks created for {self.entity_name} {entity_key}: `{full_content}`" - ) - continue - - if chunks := create_chunks_and_embeddings( - chunk_texts=chunk_texts, - context=context, - openai_client=self.openai_client, - save=False, - ): - batch_chunks.extend(chunks) - processed += 1 - self.stdout.write(f"Created {len(chunks)} chunks for {entity_key}") - - if batch_chunks: - Chunk.bulk_save(batch_chunks) - - return processed - - def handle(self, *args, **options): - """Handle the chunk creation command.""" - if not self.setup_openai_client(): - return - - queryset = self.get_queryset(options) - batch_size = options["batch_size"] - - self.handle_batch_processing( - queryset=queryset, - batch_size=batch_size, - process_batch_func=self.process_chunks_batch, - ) diff --git a/backend/apps/ai/common/base/__init__.py b/backend/apps/ai/common/base/__init__.py new file mode 100644 index 0000000000..8f794890e0 --- /dev/null +++ b/backend/apps/ai/common/base/__init__.py @@ -0,0 +1,7 @@ +"""Base classes for AI management commands.""" + +from .ai_command import BaseAICommand +from .chunk_command import BaseChunkCommand +from .context_command import BaseContextCommand + +__all__ = ["BaseAICommand", "BaseChunkCommand", "BaseContextCommand"] diff --git a/backend/apps/ai/common/base/ai_command.py b/backend/apps/ai/common/base/ai_command.py new file mode 100644 index 0000000000..62a1909279 --- /dev/null +++ b/backend/apps/ai/common/base/ai_command.py @@ -0,0 +1,125 @@ +"""Base AI command class with common functionality.""" + +import os +from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import Any + +import openai +from django.core.management.base import BaseCommand +from django.db.models import Model, QuerySet + + +class BaseAICommand(BaseCommand, ABC): + """Base class for AI management commands with common functionality.""" + + def __init__(self, *args, **kwargs): + """Initialize the AI command with OpenAI client placeholder.""" + super().__init__(*args, **kwargs) + self.openai_client: openai.OpenAI | None = None + + @abstractmethod + def model_class(self) -> type[Model]: + """Return the Django model class this command operates on.""" + + @abstractmethod + def entity_name(self) -> str: + """Return the human-readable name for the entity (e.g., 'chapter', 'project').""" + + @abstractmethod + def entity_name_plural(self) -> str: + """Return the plural form of the entity name.""" + + @abstractmethod + def key_field_name(self) -> str: + """Return the field name used for filtering by key (e.g., 'key', 'slug').""" + + @abstractmethod + def extract_content(self, entity: Model) -> tuple[str, str]: + """Extract content from the entity. Return (prose_content, metadata_content).""" + + def source_name(self) -> str: + """Return the source name for context creation. Override if different from default.""" + return f"owasp_{self.entity_name()}" + + def get_base_queryset(self) -> QuerySet: + """Return the base queryset. Override for custom filtering logic.""" + return self.model_class().objects.all() + + def get_default_queryset(self) -> QuerySet: + """Return the default queryset when no specific options are provided.""" + return self.get_base_queryset().filter(is_active=True) + + def add_common_arguments(self, parser): + """Add common arguments that most commands need.""" + parser.add_argument( + f"--{self.entity_name()}-key", + type=str, + help=f"Process only the {self.entity_name()} with this key", + ) + parser.add_argument( + "--all", + action="store_true", + help=f"Process all the {self.entity_name_plural()}", + ) + parser.add_argument( + "--batch-size", + type=int, + default=50, + help=f"Number of {self.entity_name_plural()} to process in each batch", + ) + + def add_arguments(self, parser): + """Add arguments to the command. Override to add custom arguments.""" + self.add_common_arguments(parser) + + def get_queryset(self, options: dict[str, Any]) -> QuerySet: + """Get the queryset based on command options.""" + key_option = f"{self.entity_name()}_key" + + if options.get(key_option): + filter_kwargs = {self.key_field_name(): options[key_option]} + return self.get_base_queryset().filter(**filter_kwargs) + if options.get("all"): + return self.get_base_queryset() + return self.get_default_queryset() + + def get_entity_key(self, entity: Model) -> str: + """Get the key/identifier for an entity for display purposes.""" + return str(getattr(entity, self.key_field_name(), entity.pk)) + + def setup_openai_client(self) -> bool: + """Set up OpenAI client if API key is available.""" + if openai_api_key := os.getenv("DJANGO_OPEN_AI_SECRET_KEY"): + self.openai_client = openai.OpenAI(api_key=openai_api_key) + return True + self.stdout.write( + self.style.ERROR("DJANGO_OPEN_AI_SECRET_KEY environment variable not set") + ) + return False + + def handle_batch_processing( + self, + queryset: QuerySet, + batch_size: int, + process_batch_func: Callable[[list[Model]], int], + ) -> None: + """Handle the common batch processing logic.""" + total_count = queryset.count() + + if not total_count: + self.stdout.write(f"No {self.entity_name_plural()} found to process") + return + + self.stdout.write(f"Found {total_count} {self.entity_name_plural()} to process") + + processed_count = 0 + for offset in range(0, total_count, batch_size): + batch_items = queryset[offset : offset + batch_size] + processed_count += process_batch_func(list(batch_items)) + + self.stdout.write( + self.style.SUCCESS( + f"Completed processing {processed_count}/{total_count} {self.entity_name_plural()}" + ) + ) diff --git a/backend/apps/ai/common/base/chunk_command.py b/backend/apps/ai/common/base/chunk_command.py new file mode 100644 index 0000000000..1dc929f434 --- /dev/null +++ b/backend/apps/ai/common/base/chunk_command.py @@ -0,0 +1,81 @@ +"""Base chunk command class for creating chunks.""" + +from django.contrib.contenttypes.models import ContentType +from django.db.models import Model + +from apps.ai.common.base.ai_command import BaseAICommand +from apps.ai.common.utils import create_chunks_and_embeddings +from apps.ai.models.chunk import Chunk +from apps.ai.models.context import Context + + +class BaseChunkCommand(BaseAICommand): + """Base class for chunk creation commands.""" + + def help(self) -> str: + """Return help text for the chunk creation command.""" + return f"Create chunks for OWASP {self.entity_name()} data" + + def process_chunks_batch(self, entities: list[Model]) -> int: + """Process a batch of entities to create chunks.""" + processed = 0 + batch_chunks = [] + content_type = ContentType.objects.get_for_model(self.model_class()) + + for entity in entities: + context = Context.objects.filter( + content_type=content_type, object_id=entity.id + ).first() + + entity_key = self.get_entity_key(entity) + + if not context: + self.stdout.write( + self.style.WARNING(f"No context found for {self.entity_name()} {entity_key}") + ) + continue + + prose_content, metadata_content = self.extract_content(entity) + full_content = ( + f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content + ) + + if not full_content.strip(): + self.stdout.write(f"No content to chunk for {self.entity_name()} {entity_key}") + continue + + chunk_texts = Chunk.split_text(full_content) + if not chunk_texts: + self.stdout.write( + f"No chunks created for {self.entity_name()} {entity_key}: `{full_content}`" + ) + continue + + if chunks := create_chunks_and_embeddings( + chunk_texts=chunk_texts, + context=context, + openai_client=self.openai_client, + save=False, + ): + batch_chunks.extend(chunks) + processed += 1 + self.stdout.write(f"Created {len(chunks)} chunks for {entity_key}") + + if batch_chunks: + Chunk.bulk_save(batch_chunks) + + return processed + + def handle(self, *args, **options): + """Handle the chunk creation command.""" + if not self.setup_openai_client(): + return + + queryset = self.get_queryset(options) + batch_size = options["batch_size"] + + self.handle_batch_processing( + queryset=queryset, + batch_size=batch_size, + process_batch_func=self.process_chunks_batch, + ) diff --git a/backend/apps/ai/common/base/context_command.py b/backend/apps/ai/common/base/context_command.py new file mode 100644 index 0000000000..ee4f64dfd0 --- /dev/null +++ b/backend/apps/ai/common/base/context_command.py @@ -0,0 +1,54 @@ +"""Base context command class for creating contexts.""" + +from django.db.models import Model + +from apps.ai.common.base.ai_command import BaseAICommand +from apps.ai.models.context import Context + + +class BaseContextCommand(BaseAICommand): + """Base class for context creation commands.""" + + def help(self) -> str: + """Return help text for the context creation command.""" + return f"Update context for OWASP {self.entity_name()} data" + + def process_context_batch(self, entities: list[Model]) -> int: + """Process a batch of entities to create contexts.""" + processed = 0 + + for entity in entities: + prose_content, metadata_content = self.extract_content(entity) + full_content = ( + f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content + ) + + if not full_content.strip(): + entity_key = self.get_entity_key(entity) + self.stdout.write(f"No content for {self.entity_name()} {entity_key}") + continue + + if Context.update_data( + content=full_content, + content_object=entity, + source=self.source_name(), + ): + processed += 1 + entity_key = self.get_entity_key(entity) + self.stdout.write(f"Created context for {entity_key}") + else: + entity_key = self.get_entity_key(entity) + self.stdout.write(self.style.ERROR(f"Failed to create context for {entity_key}")) + + return processed + + def handle(self, *args, **options): + """Handle the context creation command.""" + queryset = self.get_queryset(options) + batch_size = options["batch_size"] + + self.handle_batch_processing( + queryset=queryset, + batch_size=batch_size, + process_batch_func=self.process_context_batch, + ) diff --git a/backend/apps/ai/common/extractors/chapter.py b/backend/apps/ai/common/extractors/chapter.py index ef4964f98a..0bacfcff2e 100644 --- a/backend/apps/ai/common/extractors/chapter.py +++ b/backend/apps/ai/common/extractors/chapter.py @@ -26,12 +26,8 @@ def extract_chapter_content(chapter) -> tuple[str, str]: repo = chapter.owasp_repository if repo.description: prose_parts.append(f"Repository Description: {repo.description}") - if repo.topics and hasattr(repo.topics, "__iter__") and not isinstance(repo.topics, str): - try: - metadata_parts.append(f"Repository Topics: {', '.join(repo.topics)}") - except TypeError: - # If topics is not iterable, convert to string - metadata_parts.append(f"Repository Topics: {repo.topics}") + if repo.topics: + metadata_parts.append(f"Repository Topics: {', '.join(repo.topics)}") if chapter.name: metadata_parts.append(f"Chapter Name: {chapter.name}") diff --git a/backend/apps/ai/common/utils.py b/backend/apps/ai/common/utils.py index 065e24eb5e..95592516ae 100644 --- a/backend/apps/ai/common/utils.py +++ b/backend/apps/ai/common/utils.py @@ -16,21 +16,6 @@ logger: logging.Logger = logging.getLogger(__name__) -def create_context(content: str, content_object=None, source: str = "") -> Context: - """Create and save a Context instance independently. - - Args: - content (str): The context content - content_object: Optional related object - source (str): Source identifier - - Returns: - Context: Created Context instance - - """ - return Context.update_data(content=content, content_object=content_object, source=source) - - def create_chunks_and_embeddings( chunk_texts: list[str], context: Context, @@ -55,10 +40,6 @@ def create_chunks_and_embeddings( ValueError: If context is None or invalid """ - if context is None: - error_msg = "Context is required for chunk creation.please create a context first." - raise ValueError(error_msg) - try: last_request_time = datetime.now(UTC) - timedelta( seconds=DEFAULT_LAST_REQUEST_OFFSET_SECONDS @@ -66,7 +47,9 @@ def create_chunks_and_embeddings( time_since_last_request = datetime.now(UTC) - last_request_time if time_since_last_request < timedelta(seconds=MIN_REQUEST_INTERVAL_SECONDS): - time.sleep(MIN_REQUEST_INTERVAL_SECONDS - time_since_last_request.total_seconds()) + time.sleep( + MIN_REQUEST_INTERVAL_SECONDS - time_since_last_request.total_seconds() + ) response = openai_client.embeddings.create( input=chunk_texts, @@ -76,17 +59,16 @@ def create_chunks_and_embeddings( chunks = [] for text, embedding in zip(chunk_texts, embeddings, strict=True): - chunk = Chunk.update_data(text=text, embedding=embedding, save=False) - chunk.context = context - if save: - chunk.save() + chunk = Chunk.update_data( + text=text, embedding=embedding, context=context, save=save + ) chunks.append(chunk) - except openai.OpenAIError: - logger.exception("Failed to create chunks and embeddings") - return [] except ValueError: logger.exception("Context error") raise + except openai.OpenAIError: + logger.exception("Failed to create chunks and embeddings") + return [] else: return chunks diff --git a/backend/apps/ai/management/commands/ai_create_chapter_chunks.py b/backend/apps/ai/management/commands/ai_create_chapter_chunks.py index 5ae5caa663..bafd397c9f 100644 --- a/backend/apps/ai/management/commands/ai_create_chapter_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_chapter_chunks.py @@ -2,28 +2,24 @@ from django.db.models import Model -from apps.ai.common.base import BaseChunkCommand +from apps.ai.common.base.chunk_command import BaseChunkCommand from apps.ai.common.extractors.chapter import extract_chapter_content from apps.owasp.models.chapter import Chapter class Command(BaseChunkCommand): - @property - def model_class(self) -> type[Model]: - return Chapter - - @property def entity_name(self) -> str: return "chapter" - @property def entity_name_plural(self) -> str: return "chapters" - @property - def key_field_name(self) -> str: - return "key" - def extract_content(self, entity: Chapter) -> tuple[str, str]: """Extract content from the chapter.""" return extract_chapter_content(entity) + + def key_field_name(self) -> str: + return "key" + + def model_class(self) -> type[Model]: + return Chapter diff --git a/backend/apps/ai/management/commands/ai_create_chapter_context.py b/backend/apps/ai/management/commands/ai_create_chapter_context.py index 46b13509fa..377024de1b 100644 --- a/backend/apps/ai/management/commands/ai_create_chapter_context.py +++ b/backend/apps/ai/management/commands/ai_create_chapter_context.py @@ -2,28 +2,24 @@ from django.db.models import Model -from apps.ai.common.base import BaseContextCommand +from apps.ai.common.base.context_command import BaseContextCommand from apps.ai.common.extractors.chapter import extract_chapter_content from apps.owasp.models.chapter import Chapter class Command(BaseContextCommand): - @property - def model_class(self) -> type[Model]: - return Chapter - - @property def entity_name(self) -> str: return "chapter" - @property def entity_name_plural(self) -> str: return "chapters" - @property - def key_field_name(self) -> str: - return "key" - def extract_content(self, entity: Chapter) -> tuple[str, str]: """Extract content from the chapter.""" return extract_chapter_content(entity) + + def key_field_name(self) -> str: + return "key" + + def model_class(self) -> type[Model]: + return Chapter diff --git a/backend/apps/ai/management/commands/ai_create_committee_chunks.py b/backend/apps/ai/management/commands/ai_create_committee_chunks.py index 23bd51d552..fee8bb5ea4 100644 --- a/backend/apps/ai/management/commands/ai_create_committee_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_committee_chunks.py @@ -2,28 +2,24 @@ from django.db.models import Model -from apps.ai.common.base import BaseChunkCommand +from apps.ai.common.base.chunk_command import BaseChunkCommand from apps.ai.common.extractors.committee import extract_committee_content from apps.owasp.models.committee import Committee class Command(BaseChunkCommand): - @property - def model_class(self) -> type[Model]: - return Committee - - @property def entity_name(self) -> str: return "committee" - @property def entity_name_plural(self) -> str: return "committees" - @property - def key_field_name(self) -> str: - return "key" - def extract_content(self, entity: Committee) -> tuple[str, str]: """Extract content from the committee.""" return extract_committee_content(entity) + + def key_field_name(self) -> str: + return "key" + + def model_class(self) -> type[Model]: + return Committee diff --git a/backend/apps/ai/management/commands/ai_create_committee_context.py b/backend/apps/ai/management/commands/ai_create_committee_context.py index 4a17b58dd6..de3965196e 100644 --- a/backend/apps/ai/management/commands/ai_create_committee_context.py +++ b/backend/apps/ai/management/commands/ai_create_committee_context.py @@ -2,28 +2,24 @@ from django.db.models import Model -from apps.ai.common.base import BaseContextCommand +from apps.ai.common.base.context_command import BaseContextCommand from apps.ai.common.extractors.committee import extract_committee_content from apps.owasp.models.committee import Committee class Command(BaseContextCommand): - @property - def model_class(self) -> type[Model]: - return Committee - - @property def entity_name(self) -> str: return "committee" - @property def entity_name_plural(self) -> str: return "committees" - @property - def key_field_name(self) -> str: - return "key" - def extract_content(self, entity: Committee) -> tuple[str, str]: """Extract content from the committee.""" return extract_committee_content(entity) + + def key_field_name(self) -> str: + return "key" + + def model_class(self) -> type[Model]: + return Committee diff --git a/backend/apps/ai/management/commands/ai_create_event_chunks.py b/backend/apps/ai/management/commands/ai_create_event_chunks.py index a35c0fd97f..e2a499008e 100644 --- a/backend/apps/ai/management/commands/ai_create_event_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_event_chunks.py @@ -2,36 +2,32 @@ from django.db.models import Model, QuerySet -from apps.ai.common.base import BaseChunkCommand +from apps.ai.common.base.chunk_command import BaseChunkCommand from apps.ai.common.extractors.event import extract_event_content from apps.owasp.models.event import Event class Command(BaseChunkCommand): - @property - def model_class(self) -> type[Model]: - return Event - - @property def entity_name(self) -> str: return "event" - @property def entity_name_plural(self) -> str: return "events" - @property - def key_field_name(self) -> str: - return "key" + def extract_content(self, entity: Event) -> tuple[str, str]: + """Extract content from the event.""" + return extract_event_content(entity) + + def get_base_queryset(self) -> QuerySet: + """Return the base queryset with ordering.""" + return super().get_base_queryset() def get_default_queryset(self) -> QuerySet: """Return upcoming events by default instead of is_active filter.""" return Event.upcoming_events() - def get_base_queryset(self) -> QuerySet: - """Return the base queryset with ordering.""" - return super().get_base_queryset() + def key_field_name(self) -> str: + return "key" - def extract_content(self, entity: Event) -> tuple[str, str]: - """Extract content from the event.""" - return extract_event_content(entity) + def model_class(self) -> type[Model]: + return Event diff --git a/backend/apps/ai/management/commands/ai_create_event_context.py b/backend/apps/ai/management/commands/ai_create_event_context.py index a866690da7..49f20cb2b9 100644 --- a/backend/apps/ai/management/commands/ai_create_event_context.py +++ b/backend/apps/ai/management/commands/ai_create_event_context.py @@ -2,36 +2,32 @@ from django.db.models import Model, QuerySet -from apps.ai.common.base import BaseContextCommand +from apps.ai.common.base.context_command import BaseContextCommand from apps.ai.common.extractors.event import extract_event_content from apps.owasp.models.event import Event class Command(BaseContextCommand): - @property - def model_class(self) -> type[Model]: - return Event - - @property def entity_name(self) -> str: return "event" - @property def entity_name_plural(self) -> str: return "events" - @property - def key_field_name(self) -> str: - return "key" + def extract_content(self, entity: Event) -> tuple[str, str]: + """Extract content from the event.""" + return extract_event_content(entity) + + def get_base_queryset(self) -> QuerySet: + """Return the base queryset with ordering.""" + return super().get_base_queryset() def get_default_queryset(self) -> QuerySet: """Return upcoming events by default instead of is_active filter.""" return Event.upcoming_events() - def get_base_queryset(self) -> QuerySet: - """Return the base queryset with ordering.""" - return super().get_base_queryset() + def key_field_name(self) -> str: + return "key" - def extract_content(self, entity: Event) -> tuple[str, str]: - """Extract content from the event.""" - return extract_event_content(entity) + def model_class(self) -> type[Model]: + return Event diff --git a/backend/apps/ai/management/commands/ai_create_project_chunks.py b/backend/apps/ai/management/commands/ai_create_project_chunks.py index 5aaa96a8db..255b217558 100644 --- a/backend/apps/ai/management/commands/ai_create_project_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_project_chunks.py @@ -2,32 +2,28 @@ from django.db.models import Model, QuerySet -from apps.ai.common.base import BaseChunkCommand +from apps.ai.common.base.chunk_command import BaseChunkCommand from apps.ai.common.extractors.project import extract_project_content from apps.owasp.models.project import Project class Command(BaseChunkCommand): - @property - def model_class(self) -> type[Model]: - return Project - - @property def entity_name(self) -> str: return "project" - @property def entity_name_plural(self) -> str: return "projects" - @property - def key_field_name(self) -> str: - return "key" + def extract_content(self, entity: Project) -> tuple[str, str]: + """Extract content from the project.""" + return extract_project_content(entity) def get_base_queryset(self) -> QuerySet: """Return the base queryset with ordering.""" return super().get_base_queryset() - def extract_content(self, entity: Project) -> tuple[str, str]: - """Extract content from the project.""" - return extract_project_content(entity) + def key_field_name(self) -> str: + return "key" + + def model_class(self) -> type[Model]: + return Project diff --git a/backend/apps/ai/management/commands/ai_create_project_context.py b/backend/apps/ai/management/commands/ai_create_project_context.py index dc10befd33..47e509f1e6 100644 --- a/backend/apps/ai/management/commands/ai_create_project_context.py +++ b/backend/apps/ai/management/commands/ai_create_project_context.py @@ -2,32 +2,28 @@ from django.db.models import Model, QuerySet -from apps.ai.common.base import BaseContextCommand +from apps.ai.common.base.context_command import BaseContextCommand from apps.ai.common.extractors.project import extract_project_content from apps.owasp.models.project import Project class Command(BaseContextCommand): - @property - def model_class(self) -> type[Model]: - return Project - - @property def entity_name(self) -> str: return "project" - @property def entity_name_plural(self) -> str: return "projects" - @property - def key_field_name(self) -> str: - return "key" + def extract_content(self, entity: Project) -> tuple[str, str]: + """Extract content from the project.""" + return extract_project_content(entity) def get_base_queryset(self) -> QuerySet: """Return the base queryset with ordering.""" return super().get_base_queryset() - def extract_content(self, entity: Project) -> tuple[str, str]: - """Extract content from the project.""" - return extract_project_content(entity) + def key_field_name(self) -> str: + return "key" + + def model_class(self) -> type[Model]: + return Project diff --git a/backend/apps/ai/management/commands/ai_create_slack_message_chunks.py b/backend/apps/ai/management/commands/ai_create_slack_message_chunks.py index b34cf969da..31bedefa48 100644 --- a/backend/apps/ai/management/commands/ai_create_slack_message_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_slack_message_chunks.py @@ -2,35 +2,11 @@ from django.db.models import Model, QuerySet -from apps.ai.common.base import BaseChunkCommand +from apps.ai.common.base.chunk_command import BaseChunkCommand from apps.slack.models.message import Message class Command(BaseChunkCommand): - @property - def model_class(self) -> type[Model]: - return Message - - @property - def entity_name(self) -> str: - return "message" - - @property - def entity_name_plural(self) -> str: - return "messages" - - @property - def key_field_name(self) -> str: - return "slack_message_id" - - @property - def source_name(self) -> str: - return "slack_message" - - def get_default_queryset(self) -> QuerySet: - """Return all messages by default since Message model doesn't have is_active field.""" - return self.get_base_queryset() - def add_arguments(self, parser): """Override to use different default batch size for messages.""" parser.add_argument( @@ -50,6 +26,25 @@ def add_arguments(self, parser): help="Number of messages to process in each batch", ) + def entity_name(self) -> str: + return "message" + + def entity_name_plural(self) -> str: + return "messages" + def extract_content(self, entity: Message) -> tuple[str, str]: """Extract content from the message.""" return entity.cleaned_text or "", "" + + def get_default_queryset(self) -> QuerySet: + """Return all messages by default since Message model doesn't have is_active field.""" + return self.get_base_queryset() + + def key_field_name(self) -> str: + return "slack_message_id" + + def model_class(self) -> type[Model]: + return Message + + def source_name(self) -> str: + return "slack_message" diff --git a/backend/apps/ai/management/commands/ai_create_slack_message_context.py b/backend/apps/ai/management/commands/ai_create_slack_message_context.py index 3e3d3a135c..ecf8b28c5e 100644 --- a/backend/apps/ai/management/commands/ai_create_slack_message_context.py +++ b/backend/apps/ai/management/commands/ai_create_slack_message_context.py @@ -2,35 +2,11 @@ from django.db.models import Model, QuerySet -from apps.ai.common.base import BaseContextCommand +from apps.ai.common.base.context_command import BaseContextCommand from apps.slack.models.message import Message class Command(BaseContextCommand): - @property - def model_class(self) -> type[Model]: - return Message - - @property - def entity_name(self) -> str: - return "message" - - @property - def entity_name_plural(self) -> str: - return "messages" - - @property - def key_field_name(self) -> str: - return "slack_message_id" - - @property - def source_name(self) -> str: - return "slack_message" - - def get_default_queryset(self) -> QuerySet: - """Return all messages by default since Message model doesn't have is_active field.""" - return self.get_base_queryset() - def add_arguments(self, parser): """Override to use different default batch size for messages.""" parser.add_argument( @@ -50,6 +26,25 @@ def add_arguments(self, parser): help="Number of messages to process in each batch", ) + def entity_name(self) -> str: + return "message" + + def entity_name_plural(self) -> str: + return "messages" + def extract_content(self, entity: Message) -> tuple[str, str]: """Extract content from the message.""" return entity.cleaned_text or "", "" + + def get_default_queryset(self) -> QuerySet: + """Return all messages by default since Message model doesn't have is_active field.""" + return self.get_base_queryset() + + def key_field_name(self) -> str: + return "slack_message_id" + + def model_class(self) -> type[Model]: + return Message + + def source_name(self) -> str: + return "slack_message" diff --git a/backend/apps/ai/models/chunk.py b/backend/apps/ai/models/chunk.py index 361597d050..efd1f24800 100644 --- a/backend/apps/ai/models/chunk.py +++ b/backend/apps/ai/models/chunk.py @@ -45,6 +45,7 @@ def split_text(text: str) -> list[str]: def update_data( text: str, embedding, + context: Context, *, save: bool = True, ) -> "Chunk": @@ -53,18 +54,16 @@ def update_data( Args: text (str): The text content of the chunk. embedding (list): The embedding vector for the chunk. + context (Context): The context this chunk belongs to. save (bool): Whether to save the chunk to the database. Returns: Chunk: The created chunk instance. """ - chunk = Chunk(text=text, embedding=embedding) + chunk = Chunk(text=text, embedding=embedding, context=context) if save: - if chunk.context_id is None: - error_msg = "Chunk must have a context assigned before saving." - raise ValueError(error_msg) chunk.save() return chunk diff --git a/backend/tests/apps/ai/agent/__init__.py b/backend/tests/apps/ai/agent/__init__.py new file mode 100644 index 0000000000..60947f7d73 --- /dev/null +++ b/backend/tests/apps/ai/agent/__init__.py @@ -0,0 +1 @@ +"""AI agent tests package.""" diff --git a/backend/tests/apps/ai/agent/tools/__init__.py b/backend/tests/apps/ai/agent/tools/__init__.py new file mode 100644 index 0000000000..e60b458847 --- /dev/null +++ b/backend/tests/apps/ai/agent/tools/__init__.py @@ -0,0 +1 @@ +"""AI agent tools tests package.""" diff --git a/backend/tests/apps/ai/agent/tools/rag/__init__.py b/backend/tests/apps/ai/agent/tools/rag/__init__.py new file mode 100644 index 0000000000..9e0722de20 --- /dev/null +++ b/backend/tests/apps/ai/agent/tools/rag/__init__.py @@ -0,0 +1 @@ +"""RAG tests package.""" diff --git a/backend/tests/apps/ai/agent/tools/rag/generator_test.py b/backend/tests/apps/ai/agent/tools/rag/generator_test.py new file mode 100644 index 0000000000..2cd07fb14c --- /dev/null +++ b/backend/tests/apps/ai/agent/tools/rag/generator_test.py @@ -0,0 +1,200 @@ +"""Tests for the RAG Generator.""" + +import os +from unittest.mock import MagicMock, patch + +import openai +import pytest + +from apps.ai.agent.tools.rag.generator import Generator + + +class TestGenerator: + """Test cases for the Generator class.""" + + def test_init_success(self): + """Test successful initialization with API key.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_openai.return_value = mock_client + + generator = Generator(chat_model="gpt-4") + + assert generator.chat_model == "gpt-4" + assert generator.openai_client == mock_client + mock_openai.assert_called_once_with(api_key="test-key") + + def test_init_no_api_key(self): + """Test initialization fails when API key is not set.""" + with ( + patch.dict(os.environ, {}, clear=True), + pytest.raises( + ValueError, + match="DJANGO_OPEN_AI_SECRET_KEY environment variable not set", + ), + ): + Generator() + + def test_prepare_context_empty_chunks(self): + """Test context preparation with empty chunks list.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + generator = Generator() + + result = generator.prepare_context([]) + + assert result == "No context provided" + + def test_prepare_context_with_chunks(self): + """Test context preparation with valid chunks.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + generator = Generator() + + chunks = [ + {"source_name": "Chapter 1", "text": "This is chapter 1 content"}, + {"source_name": "Chapter 2", "text": "This is chapter 2 content"}, + ] + + result = generator.prepare_context(chunks) + + expected = ( + "Source Name: Chapter 1\nContent: This is chapter 1 content\n\n" + "---\n\n" + "Source Name: Chapter 2\nContent: This is chapter 2 content" + ) + assert result == expected + + def test_prepare_context_missing_fields(self): + """Test context preparation with chunks missing fields.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + generator = Generator() + + chunks = [ + {"text": "Content without source"}, + {"source_name": "Source without content"}, + {}, + ] + + result = generator.prepare_context(chunks) + + expected = ( + "Source Name: Unknown Source 1\nContent: Content without source\n\n" + "---\n\n" + "Source Name: Source without content\nContent: \n\n" + "---\n\n" + "Source Name: Unknown Source 3\nContent: " + ) + assert result == expected + + def test_generate_answer_success(self): + """Test successful answer generation.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Generated answer" + mock_client.chat.completions.create.return_value = mock_response + mock_openai.return_value = mock_client + + generator = Generator() + + chunks = [{"source_name": "Test", "text": "Test content"}] + result = generator.generate_answer("What is OWASP?", chunks) + + assert result == "Generated answer" + mock_client.chat.completions.create.assert_called_once() + call_args = mock_client.chat.completions.create.call_args + assert call_args[1]["model"] == "gpt-4o" + assert call_args[1]["temperature"] == 0.4 + assert call_args[1]["max_tokens"] == 2000 + assert len(call_args[1]["messages"]) == 2 + assert call_args[1]["messages"][0]["role"] == "system" + assert call_args[1]["messages"][1]["role"] == "user" + + def test_generate_answer_with_custom_model(self): + """Test answer generation with custom chat model.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Custom model answer" + mock_client.chat.completions.create.return_value = mock_response + mock_openai.return_value = mock_client + + generator = Generator(chat_model="gpt-3.5-turbo") + + chunks = [{"source_name": "Test", "text": "Test content"}] + result = generator.generate_answer("Test query", chunks) + + assert result == "Custom model answer" + call_args = mock_client.chat.completions.create.call_args + assert call_args[1]["model"] == "gpt-3.5-turbo" + + def test_generate_answer_openai_error(self): + """Test answer generation with OpenAI API error.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_client.chat.completions.create.side_effect = openai.OpenAIError("API Error") + mock_openai.return_value = mock_client + + generator = Generator() + + chunks = [{"source_name": "Test", "text": "Test content"}] + result = generator.generate_answer("Test query", chunks) + + assert result == "I'm sorry, I'm currently unable to process your request." + + def test_generate_answer_with_empty_chunks(self): + """Test answer generation with empty chunks.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "No context answer" + mock_client.chat.completions.create.return_value = mock_response + mock_openai.return_value = mock_client + + generator = Generator() + + result = generator.generate_answer("Test query", []) + + assert result == "No context answer" + call_args = mock_client.chat.completions.create.call_args + assert "No context provided" in call_args[1]["messages"][1]["content"] + + def test_system_prompt_content(self): + """Test that system prompt contains expected content.""" + assert "OWASP Foundation" in Generator.SYSTEM_PROMPT + assert "context" in Generator.SYSTEM_PROMPT.lower() + assert "professional" in Generator.SYSTEM_PROMPT.lower() + assert "latitude and longitude" in Generator.SYSTEM_PROMPT.lower() + + def test_constants(self): + """Test class constants have expected values.""" + assert Generator.MAX_TOKENS == 2000 + assert Generator.TEMPERATURE == 0.4 + assert isinstance(Generator.SYSTEM_PROMPT, str) + assert len(Generator.SYSTEM_PROMPT) > 0 diff --git a/backend/tests/apps/ai/agent/tools/rag/rag_tool_test.py b/backend/tests/apps/ai/agent/tools/rag/rag_tool_test.py new file mode 100644 index 0000000000..dcb24ca970 --- /dev/null +++ b/backend/tests/apps/ai/agent/tools/rag/rag_tool_test.py @@ -0,0 +1,179 @@ +"""Tests for the RAG Tool.""" + +import os +from unittest.mock import MagicMock, patch + +import pytest + +from apps.ai.agent.tools.rag.rag_tool import RagTool + + +class TestRagTool: + """Test cases for the RagTool class.""" + + def test_init_success(self): + """Test successful initialization of RagTool.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("apps.ai.agent.tools.rag.rag_tool.Retriever") as mock_retriever_class, + patch("apps.ai.agent.tools.rag.rag_tool.Generator") as mock_generator_class, + ): + mock_retriever = MagicMock() + mock_generator = MagicMock() + mock_retriever_class.return_value = mock_retriever + mock_generator_class.return_value = mock_generator + + rag_tool = RagTool( + embedding_model="custom-embedding-model", chat_model="custom-chat-model" + ) + + assert rag_tool.retriever == mock_retriever + assert rag_tool.generator == mock_generator + mock_retriever_class.assert_called_once_with(embedding_model="custom-embedding-model") + mock_generator_class.assert_called_once_with(chat_model="custom-chat-model") + + def test_init_default_models(self): + """Test initialization with default model parameters.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("apps.ai.agent.tools.rag.rag_tool.Retriever") as mock_retriever_class, + patch("apps.ai.agent.tools.rag.rag_tool.Generator") as mock_generator_class, + ): + mock_retriever = MagicMock() + mock_generator = MagicMock() + mock_retriever_class.return_value = mock_retriever + mock_generator_class.return_value = mock_generator + + RagTool() + + mock_retriever_class.assert_called_once_with(embedding_model="text-embedding-3-small") + mock_generator_class.assert_called_once_with(chat_model="gpt-4o") + + def test_init_no_api_key(self): + """Test initialization fails when API key is not set.""" + with ( + patch.dict(os.environ, {}, clear=True), + patch("apps.ai.agent.tools.rag.rag_tool.Retriever") as mock_retriever_class, + ): + mock_retriever_class.side_effect = ValueError( + "DJANGO_OPEN_AI_SECRET_KEY environment variable not set" + ) + + with pytest.raises( + ValueError, + match="DJANGO_OPEN_AI_SECRET_KEY environment variable not set", + ): + RagTool() + + def test_query_success(self): + """Test successful query execution.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("apps.ai.agent.tools.rag.rag_tool.Retriever") as mock_retriever_class, + patch("apps.ai.agent.tools.rag.rag_tool.Generator") as mock_generator_class, + ): + mock_retriever = MagicMock() + mock_generator = MagicMock() + mock_retriever_class.return_value = mock_retriever + mock_generator_class.return_value = mock_generator + + mock_chunks = [{"text": "Test content", "source_name": "Test Source"}] + mock_retriever.retrieve.return_value = mock_chunks + mock_generator.generate_answer.return_value = "Generated answer" + + rag_tool = RagTool() + + result = rag_tool.query( + question="What is OWASP?", + content_types=["chapter"], + limit=10, + similarity_threshold=0.5, + ) + + assert result == "Generated answer" + mock_retriever.retrieve.assert_called_once_with( + content_types=["chapter"], + limit=10, + query="What is OWASP?", + similarity_threshold=0.5, + ) + mock_generator.generate_answer.assert_called_once_with( + context_chunks=mock_chunks, query="What is OWASP?" + ) + + def test_query_with_defaults(self): + """Test query with default parameters.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("apps.ai.agent.tools.rag.rag_tool.Retriever") as mock_retriever_class, + patch("apps.ai.agent.tools.rag.rag_tool.Generator") as mock_generator_class, + ): + mock_retriever = MagicMock() + mock_generator = MagicMock() + mock_retriever_class.return_value = mock_retriever + mock_generator_class.return_value = mock_generator + + mock_chunks = [] + mock_retriever.retrieve.return_value = mock_chunks + mock_generator.generate_answer.return_value = "Default answer" + + rag_tool = RagTool() + + result = rag_tool.query("Test question") + + assert result == "Default answer" + mock_retriever.retrieve.assert_called_once_with( + content_types=None, + limit=5, + query="Test question", + similarity_threshold=0.4, + ) + + def test_query_empty_content_types(self): + """Test query with empty content types list.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("apps.ai.agent.tools.rag.rag_tool.Retriever") as mock_retriever_class, + patch("apps.ai.agent.tools.rag.rag_tool.Generator") as mock_generator_class, + ): + mock_retriever = MagicMock() + mock_generator = MagicMock() + mock_retriever_class.return_value = mock_retriever + mock_generator_class.return_value = mock_generator + + mock_chunks = [] + mock_retriever.retrieve.return_value = mock_chunks + mock_generator.generate_answer.return_value = "Answer" + + rag_tool = RagTool() + + result = rag_tool.query("Test question", content_types=[]) + + assert result == "Answer" + mock_retriever.retrieve.assert_called_once_with( + content_types=[], + limit=5, + query="Test question", + similarity_threshold=0.4, + ) + + @patch("apps.ai.agent.tools.rag.rag_tool.logger") + def test_query_logs_retrieval(self, mock_logger): + """Test that query logs the retrieval process.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("apps.ai.agent.tools.rag.rag_tool.Retriever") as mock_retriever_class, + patch("apps.ai.agent.tools.rag.rag_tool.Generator") as mock_generator_class, + ): + mock_retriever = MagicMock() + mock_generator = MagicMock() + mock_retriever_class.return_value = mock_retriever + mock_generator_class.return_value = mock_generator + + mock_retriever.retrieve.return_value = [] + mock_generator.generate_answer.return_value = "Answer" + + rag_tool = RagTool() + rag_tool.query("Test question") + + mock_logger.info.assert_called_once_with("Retrieving context for query") diff --git a/backend/tests/apps/ai/agent/tools/rag/retriever_test.py b/backend/tests/apps/ai/agent/tools/rag/retriever_test.py new file mode 100644 index 0000000000..3e3e254295 --- /dev/null +++ b/backend/tests/apps/ai/agent/tools/rag/retriever_test.py @@ -0,0 +1,577 @@ +"""Tests for the RAG Retriever.""" + +import os +from unittest.mock import MagicMock, patch + +import openai +import pytest + +from apps.ai.agent.tools.rag.retriever import Retriever + + +class TestRetriever: + """Test cases for the Retriever class.""" + + def test_init_success(self): + """Test successful initialization with API key.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_openai.return_value = mock_client + + retriever = Retriever(embedding_model="custom-model") + + assert retriever.embedding_model == "custom-model" + assert retriever.openai_client == mock_client + mock_openai.assert_called_once_with(api_key="test-key") + + def test_init_no_api_key(self): + """Test initialization fails when API key is not set.""" + with ( + patch.dict(os.environ, {}, clear=True), + pytest.raises( + ValueError, + match="DJANGO_OPEN_AI_SECRET_KEY environment variable not set", + ), + ): + Retriever() + + def test_init_default_model(self): + """Test initialization with default embedding model.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + assert retriever.embedding_model == "text-embedding-3-small" + + def test_get_query_embedding_success(self): + """Test successful query embedding generation.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])] + mock_client.embeddings.create.return_value = mock_response + mock_openai.return_value = mock_client + + retriever = Retriever() + result = retriever.get_query_embedding("test query") + + assert result == [0.1, 0.2, 0.3] + mock_client.embeddings.create.assert_called_once_with( + input=["test query"], model="text-embedding-3-small" + ) + + def test_get_query_embedding_openai_error(self): + """Test query embedding with OpenAI API error.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_client.embeddings.create.side_effect = openai.OpenAIError("API Error") + mock_openai.return_value = mock_client + + retriever = Retriever() + + with pytest.raises(openai.OpenAIError): + retriever.get_query_embedding("test query") + + def test_get_query_embedding_unexpected_error(self): + """Test query embedding with unexpected error.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_client.embeddings.create.side_effect = Exception("Unexpected error") + mock_openai.return_value = mock_client + + retriever = Retriever() + + with pytest.raises(Exception, match="Unexpected error"): + retriever.get_query_embedding("test query") + + def test_get_source_name_with_name(self): + """Test getting source name when object has name attribute.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + content_object = MagicMock() + content_object.name = "Test Name" + content_object.title = "Test Title" + + result = retriever.get_source_name(content_object) + assert result == "Test Name" + + def test_get_source_name_with_title(self): + """Test getting source name when object has title but no name.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + content_object = MagicMock() + content_object.name = None + content_object.title = "Test Title" + content_object.login = "test_login" + + result = retriever.get_source_name(content_object) + assert result == "Test Title" + + def test_get_source_name_with_login(self): + """Test getting source name when object has login but no name/title.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + content_object = MagicMock() + content_object.name = None + content_object.title = None + content_object.login = "test_login" + content_object.key = "test_key" + + result = retriever.get_source_name(content_object) + assert result == "test_login" + + def test_get_source_name_fallback_to_str(self): + """Test getting source name falls back to string representation.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + content_object = MagicMock() + content_object.name = None + content_object.title = None + content_object.login = None + content_object.key = None + content_object.summary = None + content_object.__str__ = MagicMock(return_value="String representation") + + result = retriever.get_source_name(content_object) + assert result == "String representation" + + def test_get_additional_context_chapter(self): + """Test getting additional context for chapter content type.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + content_object = MagicMock() + content_object.suggested_location = "New York" + content_object.region = "North America" + content_object.country = "USA" + content_object.postal_code = "10001" + content_object.currency = "USD" + content_object.meetup_group = "OWASP NYC" + content_object.tags = ["security", "web"] + content_object.topics = ["OWASP Top 10"] + content_object.leaders_raw = ["John Doe", "Jane Smith"] + content_object.related_urls = ["https://example.com"] + content_object.is_active = True + content_object.url = "https://owasp.org/chapter" + + result = retriever.get_additional_context(content_object, "chapter") + + expected_keys = [ + "location", + "region", + "country", + "postal_code", + "currency", + "meetup_group", + "tags", + "topics", + "leaders", + "related_urls", + "is_active", + "url", + ] + for key in expected_keys: + assert key in result + + def test_get_additional_context_project(self): + """Test getting additional context for project content type.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + content_object = MagicMock() + content_object.project_type = "tool" + content_object.level = "flagship" + content_object.topics = ["security"] + content_object.leaders_raw = ["Alice"] + content_object.related_urls = ["https://project.example.com"] + content_object.is_active = True + content_object.url = "https://owasp.org/project" + + result = retriever.get_additional_context(content_object, "project") + + expected_keys = [ + "project_type", + "level", + "topics", + "leaders", + "related_urls", + "is_active", + "url", + ] + for key in expected_keys: + assert key in result + + def test_get_additional_context_event(self): + """Test getting additional context for event content type.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + content_object = MagicMock() + content_object.start_date = "2023-01-01" + content_object.end_date = "2023-01-02" + content_object.suggested_location = "San Francisco" + content_object.category = "conference" + content_object.latitude = 37.7749 + content_object.longitude = -122.4194 + content_object.url = "https://event.example.com" + content_object.description = "Test event description" + content_object.summary = "Test event summary" + + result = retriever.get_additional_context(content_object, "event") + + expected_keys = [ + "start_date", + "end_date", + "location", + "category", + "latitude", + "longitude", + "url", + "description", + "summary", + ] + for key in expected_keys: + assert key in result + + def test_get_additional_context_committee(self): + """Test getting additional context for committee content type.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + content_object = MagicMock() + content_object.is_active = True + content_object.leaders = ["John Doe", "Jane Smith"] + content_object.url = "https://committee.example.com" + content_object.description = "Test committee description" + content_object.summary = "Test committee summary" + content_object.tags = ["security", "governance"] + content_object.topics = ["policy", "standards"] + content_object.related_urls = ["https://related.example.com"] + + result = retriever.get_additional_context(content_object, "committee") + + expected_keys = [ + "is_active", + "leaders", + "url", + "description", + "summary", + "tags", + "topics", + "related_urls", + ] + for key in expected_keys: + assert key in result + + def test_get_additional_context_message(self): + """Test getting additional context for message content type.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + conversation = MagicMock() + conversation.slack_channel_id = "C1234567890" + + parent_message = MagicMock() + parent_message.ts = "1234567890.123456" + + author = MagicMock() + author.name = "testuser" + + content_object = MagicMock() + content_object.conversation = conversation + content_object.parent_message = parent_message + content_object.ts = "1234567891.123456" + content_object.author = author + + result = retriever.get_additional_context(content_object, "message") + + expected_keys = ["channel", "thread_ts", "ts", "user"] + for key in expected_keys: + assert key in result + + assert result["channel"] == "C1234567890" + assert result["thread_ts"] == "1234567890.123456" + assert result["ts"] == "1234567891.123456" + assert result["user"] == "testuser" + + def test_get_additional_context_message_no_conversation(self): + """Test getting additional context for message with no conversation.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + content_object = MagicMock() + content_object.conversation = None + content_object.parent_message = None + content_object.ts = "1234567891.123456" + content_object.author = None + + result = retriever.get_additional_context(content_object, "message") + + assert result["ts"] == "1234567891.123456" + + def test_get_additional_context_with_app_label(self): + """Test getting additional context with app.model format.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + content_object = MagicMock() + content_object.suggested_location = "Test Location" + + result = retriever.get_additional_context(content_object, "owasp.chapter") + + assert "location" in result + assert result["location"] == "Test Location" + + def test_extract_content_types_from_query_single_type(self): + """Test extracting single content type from query.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + result = retriever.extract_content_types_from_query("Tell me about chapters") + assert result == ["chapter"] + + def test_extract_content_types_from_query_multiple_types(self): + """Test extracting multiple content types from query.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + result = retriever.extract_content_types_from_query("Show me events and projects") + assert set(result) == {"event", "project"} + + def test_extract_content_types_from_query_plural_forms(self): + """Test extracting content types from query with plural forms.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + result = retriever.extract_content_types_from_query("List all committees") + assert result == ["committee"] + + def test_extract_content_types_from_query_no_matches(self): + """Test extracting content types when no matches found.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI"), + ): + retriever = Retriever() + + result = retriever.extract_content_types_from_query("Random query with no keywords") + assert result == [] + + def test_supported_content_types(self): + """Test that supported content types are defined correctly.""" + assert Retriever.SUPPORTED_CONTENT_TYPES == ( + "event", + "project", + "chapter", + "committee", + "message", + ) + + @patch("apps.ai.agent.tools.rag.retriever.Chunk") + def test_retrieve_with_app_label_content_types(self, mock_chunk): + """Test retrieve method with app_label.model content types filter.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])] + mock_client.embeddings.create.return_value = mock_response + mock_openai.return_value = mock_client + + mock_annotated = MagicMock() + mock_filtered = MagicMock() + mock_final = MagicMock() + + mock_chunk.objects.annotate.return_value = mock_annotated + mock_annotated.filter.return_value = mock_filtered + mock_filtered.filter.return_value = mock_final + mock_final.select_related.return_value = mock_final + mock_final.order_by.return_value = mock_final + mock_final.__getitem__ = MagicMock(return_value=[]) + + retriever = Retriever() + result = retriever.retrieve("test query", content_types=["owasp.chapter"]) + + assert result == [] + mock_chunk.objects.annotate.assert_called_once() + mock_annotated.filter.assert_called_once() + mock_filtered.filter.assert_called_once() + + @patch("apps.ai.agent.tools.rag.retriever.Chunk") + def test_retrieve_successful_with_chunks(self, mock_chunk): + """Test retrieve method with successful chunk retrieval.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])] + mock_client.embeddings.create.return_value = mock_response + mock_openai.return_value = mock_client + + mock_content_object = MagicMock() + mock_content_object.name = "Test Chapter" + mock_content_object.suggested_location = "New York" + + mock_content_type = MagicMock() + mock_content_type.model = "chapter" + + mock_context = MagicMock() + mock_context.content_object = mock_content_object + mock_context.content_type = mock_content_type + mock_context.object_id = "123" + + mock_chunk_instance = MagicMock() + mock_chunk_instance.id = 1 + mock_chunk_instance.text = "Test chunk text" + mock_chunk_instance.similarity = 0.85 + mock_chunk_instance.context = mock_context + + mock_annotated = MagicMock() + mock_filtered = MagicMock() + + mock_chunk.objects.annotate.return_value = mock_annotated + mock_annotated.filter.return_value = mock_filtered + mock_filtered.select_related.return_value = mock_filtered + mock_filtered.order_by.return_value = mock_filtered + mock_filtered.__getitem__ = MagicMock(return_value=[mock_chunk_instance]) + + retriever = Retriever() + result = retriever.retrieve("test query") + + assert len(result) == 1 + assert result[0]["text"] == "Test chunk text" + assert result[0]["similarity"] == 0.85 + assert result[0]["source_type"] == "chapter" + assert result[0]["source_name"] == "Test Chapter" + assert result[0]["source_id"] == "123" + assert "additional_context" in result[0] + + @patch("apps.ai.agent.tools.rag.retriever.Chunk") + def test_retrieve_with_content_types_filter(self, mock_chunk): + """Test retrieve method with content types filter.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])] + mock_client.embeddings.create.return_value = mock_response + mock_openai.return_value = mock_client + + mock_annotated = MagicMock() + mock_filtered = MagicMock() + mock_final = MagicMock() + + mock_chunk.objects.annotate.return_value = mock_annotated + mock_annotated.filter.return_value = mock_filtered + mock_filtered.filter.return_value = mock_final + mock_final.select_related.return_value = mock_final + mock_final.order_by.return_value = mock_final + mock_final.__getitem__ = MagicMock(return_value=[]) + + retriever = Retriever() + result = retriever.retrieve("test query", content_types=["chapter"]) + + assert result == [] + mock_chunk.objects.annotate.assert_called_once() + mock_annotated.filter.assert_called_once() + mock_filtered.filter.assert_called_once() + + @patch("apps.ai.agent.tools.rag.retriever.logger") + @patch("apps.ai.agent.tools.rag.retriever.Chunk") + def test_retrieve_with_none_content_object(self, mock_chunk, mock_logger): + """Test retrieve method when content object is None.""" + with ( + patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), + patch("openai.OpenAI") as mock_openai, + ): + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])] + mock_client.embeddings.create.return_value = mock_response + mock_openai.return_value = mock_client + + mock_chunk_instance = MagicMock() + mock_chunk_instance.id = 1 + mock_chunk_instance.context = None + + mock_annotated = MagicMock() + mock_filtered = MagicMock() + + mock_chunk.objects.annotate.return_value = mock_annotated + mock_annotated.filter.return_value = mock_filtered + mock_filtered.select_related.return_value = mock_filtered + mock_filtered.order_by.return_value = mock_filtered + mock_filtered.__getitem__ = MagicMock(return_value=[mock_chunk_instance]) + + retriever = Retriever() + result = retriever.retrieve("test query") + + assert result == [] + mock_logger.warning.assert_called_once_with( + "Content object is None for chunk %s. Skipping.", 1 + ) diff --git a/backend/tests/apps/ai/common/base/__init__.py b/backend/tests/apps/ai/common/base/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/tests/apps/ai/common/base/ai_command_test.py b/backend/tests/apps/ai/common/base/ai_command_test.py new file mode 100644 index 0000000000..2c634d47f4 --- /dev/null +++ b/backend/tests/apps/ai/common/base/ai_command_test.py @@ -0,0 +1,322 @@ +"""Tests for the BaseAICommand class.""" + +import os +from unittest.mock import Mock, patch + +import pytest +from django.core.management.base import BaseCommand +from django.db.models import Model, QuerySet + +from apps.ai.common.base.ai_command import BaseAICommand + + +class MockTestModel(Model): + """Test model for BaseAICommand testing.""" + + def __str__(self): + """Return string representation of MockTestModel.""" + return f"MockTestModel(pk={self.pk})" + + class Meta: + """Meta class for MockTestModel.""" + + app_label = "test" + + +class ConcreteAICommand(BaseAICommand): + """Concrete implementation of BaseAICommand for testing.""" + + def model_class(self): + return MockTestModel + + def entity_name(self): + return "test_entity" + + def entity_name_plural(self): + return "test_entities" + + def key_field_name(self): + return "test_key" + + def extract_content(self, entity): + return ("prose content", "metadata content") + + +@pytest.fixture +def command(): + """Return a concrete command instance for testing.""" + return ConcreteAICommand() + + +@pytest.fixture +def mock_entity(): + """Return a mock entity instance.""" + entity = Mock(spec=MockTestModel) + entity.pk = 1 + entity.test_key = "test-key-123" + entity.is_active = True + return entity + + +@pytest.fixture +def mock_queryset(): + """Return a mock queryset.""" + queryset = Mock(spec=QuerySet) + queryset.count.return_value = 5 + queryset.filter.return_value = queryset + queryset.__getitem__ = Mock(return_value=[]) + return queryset + + +class TestBaseAICommand: + """Test suite for the BaseAICommand class.""" + + def test_command_inheritance(self, command): + """Test that BaseAICommand inherits from BaseCommand.""" + assert isinstance(command, BaseCommand) + + def test_initialization(self, command): + """Test command initialization.""" + assert command.openai_client is None + + def test_abstract_methods_implemented(self, command): + """Test that all abstract methods are properly implemented.""" + assert command.model_class() == MockTestModel + assert command.entity_name() == "test_entity" + assert command.entity_name_plural() == "test_entities" + assert command.key_field_name() == "test_key" + + mock_entity = Mock() + result = command.extract_content(mock_entity) + assert result == ("prose content", "metadata content") + + def test_source_name_default(self, command): + """Test default source_name implementation.""" + result = command.source_name() + assert result == "owasp_test_entity" + + def test_get_base_queryset(self, command): + """Test get_base_queryset method.""" + with patch.object(MockTestModel, "objects") as mock_objects: + mock_manager = Mock() + mock_objects.all.return_value = mock_manager + mock_objects.return_value = mock_manager + + command.get_base_queryset() + mock_objects.all.assert_called_once() + + def test_get_default_queryset(self, command): + """Test get_default_queryset method.""" + with patch.object(command, "get_base_queryset") as mock_base_qs: + mock_queryset = Mock() + mock_filtered_qs = Mock() + mock_queryset.filter.return_value = mock_filtered_qs + mock_base_qs.return_value = mock_queryset + + result = command.get_default_queryset() + + mock_base_qs.assert_called_once() + mock_queryset.filter.assert_called_once_with(is_active=True) + assert result == mock_filtered_qs + + def test_add_common_arguments(self, command): + """Test add_common_arguments method.""" + parser = Mock() + + command.add_common_arguments(parser) + + assert parser.add_argument.call_count == 3 + + calls = parser.add_argument.call_args_list + + assert calls[0][0] == ("--test_entity-key",) + assert calls[0][1]["type"] is str + assert "Process only the test_entity with this key" in calls[0][1]["help"] + + assert calls[1][0] == ("--all",) + assert calls[1][1]["action"] == "store_true" + assert "Process all the test_entities" in calls[1][1]["help"] + + assert calls[2][0] == ("--batch-size",) + assert calls[2][1]["type"] is int + assert calls[2][1]["default"] == 50 + assert "Number of test_entities to process in each batch" in calls[2][1]["help"] + + def test_add_arguments_calls_common(self, command): + """Test that add_arguments calls add_common_arguments.""" + parser = Mock() + + with patch.object(command, "add_common_arguments") as mock_add_common: + command.add_arguments(parser) + mock_add_common.assert_called_once_with(parser) + + def test_get_queryset_with_entity_key(self, command): + """Test get_queryset with entity key option.""" + options = {"test_entity_key": "test-key-123"} + + with patch.object(command, "get_base_queryset") as mock_base_qs: + mock_queryset = Mock() + mock_filtered_qs = Mock() + mock_queryset.filter.return_value = mock_filtered_qs + mock_base_qs.return_value = mock_queryset + + result = command.get_queryset(options) + + mock_base_qs.assert_called_once() + mock_queryset.filter.assert_called_once_with(test_key="test-key-123") + assert result == mock_filtered_qs + + def test_get_queryset_with_all_option(self, command): + """Test get_queryset with all option.""" + options = {"all": True} + + with patch.object(command, "get_base_queryset") as mock_base_qs: + mock_queryset = Mock() + mock_base_qs.return_value = mock_queryset + + result = command.get_queryset(options) + + mock_base_qs.assert_called_once() + assert result == mock_queryset + + def test_get_queryset_default(self, command): + """Test get_queryset with default options.""" + options = {} + + with patch.object(command, "get_default_queryset") as mock_default_qs: + mock_queryset = Mock() + mock_default_qs.return_value = mock_queryset + + result = command.get_queryset(options) + + mock_default_qs.assert_called_once() + assert result == mock_queryset + + def test_get_entity_key_with_key_field(self, command, mock_entity): + """Test get_entity_key with existing key field.""" + result = command.get_entity_key(mock_entity) + assert result == "test-key-123" + + def test_get_entity_key_fallback_to_pk(self, command): + """Test get_entity_key falls back to pk when key field doesn't exist.""" + entity = Mock() + entity.pk = 42 + if hasattr(entity, "test_key"): + delattr(entity, "test_key") + + result = command.get_entity_key(entity) + assert result == "42" + + @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-api-key"}) + @patch("apps.ai.common.base.ai_command.openai.OpenAI") + def test_setup_openai_client_success(self, mock_openai_class, command): + """Test successful OpenAI client setup.""" + mock_client = Mock() + mock_openai_class.return_value = mock_client + + result = command.setup_openai_client() + + assert result is True + assert command.openai_client == mock_client + mock_openai_class.assert_called_once_with(api_key="test-api-key") + + @patch.dict(os.environ, {}, clear=True) + def test_setup_openai_client_no_api_key(self, command): + """Test OpenAI client setup without API key.""" + if "DJANGO_OPEN_AI_SECRET_KEY" in os.environ: + del os.environ["DJANGO_OPEN_AI_SECRET_KEY"] + + with patch.object(command.stdout, "write") as mock_write: + result = command.setup_openai_client() + + assert result is False + assert command.openai_client is None + mock_write.assert_called_once() + call_args = mock_write.call_args[0][0] + assert "DJANGO_OPEN_AI_SECRET_KEY environment variable not set" in str(call_args) + + def test_handle_batch_processing_empty_queryset(self, command, mock_queryset): + """Test handle_batch_processing with empty queryset.""" + mock_queryset.count.return_value = 0 + process_batch_func = Mock() + + with patch.object(command.stdout, "write") as mock_write: + command.handle_batch_processing(mock_queryset, 10, process_batch_func) + + mock_write.assert_called_once_with("No test_entities found to process") + process_batch_func.assert_not_called() + + def test_handle_batch_processing_with_data(self, command): + """Test handle_batch_processing with data.""" + mock_entities = [Mock() for _ in range(15)] + + mock_queryset = Mock() + mock_queryset.count.return_value = 15 + + def mock_getitem(slice_obj): + start = slice_obj.start or 0 + stop = slice_obj.stop + return mock_entities[start:stop] + + mock_queryset.__getitem__ = Mock(side_effect=mock_getitem) + + process_batch_func = Mock(side_effect=[5, 5, 5]) + + with patch.object(command.stdout, "write") as mock_write: + command.handle_batch_processing(mock_queryset, 5, process_batch_func) + + assert process_batch_func.call_count == 3 + + calls = process_batch_func.call_args_list + assert len(calls[0][0][0]) == 5 + assert len(calls[1][0][0]) == 5 + assert len(calls[2][0][0]) == 5 + + write_calls = mock_write.call_args_list + assert len(write_calls) == 2 + assert "Found 15 test_entities to process" in str(write_calls[0]) + assert "Completed processing 15/15 test_entities" in str(write_calls[1]) + + def test_handle_batch_processing_partial_processing(self, command): + """Test handle_batch_processing when some items fail to process.""" + mock_entities = [Mock() for _ in range(10)] + + mock_queryset = Mock() + mock_queryset.count.return_value = 10 + + def mock_getitem(slice_obj): + start = slice_obj.start or 0 + stop = slice_obj.stop + return mock_entities[start:stop] + + mock_queryset.__getitem__ = Mock(side_effect=mock_getitem) + + process_batch_func = Mock(side_effect=[3, 2]) + + with patch.object(command.stdout, "write") as mock_write: + command.handle_batch_processing(mock_queryset, 5, process_batch_func) + + assert process_batch_func.call_count == 2 + + write_calls = mock_write.call_args_list + assert len(write_calls) == 2 + assert "Found 10 test_entities to process" in str(write_calls[0]) + assert "Completed processing 5/10 test_entities" in str(write_calls[1]) + + +class TestBaseAICommandAbstractMethods: + """Test that BaseAICommand abstract methods raise errors when not implemented.""" + + def test_cannot_instantiate_base_class_directly(self): + """Test that BaseAICommand cannot be instantiated directly.""" + with pytest.raises(TypeError): + BaseAICommand() + + def test_abstract_methods_must_be_implemented(self): + """Test that subclasses must implement all abstract methods.""" + + class IncompleteCommand(BaseAICommand): + """Incomplete implementation missing required methods.""" + + with pytest.raises(TypeError): + IncompleteCommand() diff --git a/backend/tests/apps/ai/common/base/chunk_command_test.py b/backend/tests/apps/ai/common/base/chunk_command_test.py new file mode 100644 index 0000000000..852039b277 --- /dev/null +++ b/backend/tests/apps/ai/common/base/chunk_command_test.py @@ -0,0 +1,436 @@ +"""Tests for the BaseChunkCommand class.""" + +from unittest.mock import Mock, patch + +import pytest +from django.contrib.contenttypes.models import ContentType +from django.core.management.base import BaseCommand + +from apps.ai.common.base.chunk_command import BaseChunkCommand +from apps.ai.models.chunk import Chunk +from apps.ai.models.context import Context + + +class ConcreteChunkCommand(BaseChunkCommand): + """Concrete implementation of BaseChunkCommand for testing.""" + + def model_class(self): + mock_model = Mock() + mock_model.__name__ = "MockChunkTestModel" + return mock_model + + def entity_name(self): + return "test_entity" + + def entity_name_plural(self): + return "test_entities" + + def key_field_name(self): + return "test_key" + + def extract_content(self, entity): + return ("prose content", "metadata content") + + +@pytest.fixture +def command(): + """Return a concrete chunk command instance for testing.""" + return ConcreteChunkCommand() + + +@pytest.fixture +def mock_entity(): + """Return a mock entity instance.""" + entity = Mock() + entity.id = 1 + entity.test_key = "test-key-123" + entity.is_active = True + return entity + + +@pytest.fixture +def mock_context(): + """Return a mock context instance.""" + context = Mock(spec=Context) + context.id = 1 + context.content_type_id = 1 + context.object_id = 1 + return context + + +@pytest.fixture +def mock_content_type(): + """Return a mock content type.""" + content_type = Mock(spec=ContentType) + content_type.id = 1 + return content_type + + +@pytest.fixture +def mock_chunks(): + """Return a list of mock chunk instances.""" + chunks = [] + for i in range(3): + chunk = Mock(spec=Chunk) + chunk.id = i + 1 + chunk.text = f"Chunk text {i + 1}" + chunk.context_id = 1 + chunks.append(chunk) + return chunks + + +class TestBaseChunkCommand: + """Test suite for the BaseChunkCommand class.""" + + def test_command_inheritance(self, command): + """Test that BaseChunkCommand inherits from BaseCommand.""" + assert isinstance(command, BaseCommand) + + def test_help_method(self, command): + """Test the help method returns appropriate help text.""" + expected_help = "Create chunks for OWASP test_entity data" + assert command.help() == expected_help + + def test_abstract_methods_implemented(self, command): + """Test that all abstract methods are properly implemented.""" + mock_model = command.model_class() + assert mock_model.__name__ == "MockChunkTestModel" + assert command.entity_name() == "test_entity" + assert command.entity_name_plural() == "test_entities" + assert command.key_field_name() == "test_key" + + mock_entity = Mock() + result = command.extract_content(mock_entity) + assert result == ("prose content", "metadata content") + + @patch("apps.ai.common.base.chunk_command.ContentType.objects.get_for_model") + @patch("apps.ai.common.base.chunk_command.Context.objects.filter") + def test_process_chunks_batch_no_context( + self, + mock_context_filter, + mock_get_content_type, + command, + mock_entity, + mock_content_type, + ): + """Test process_chunks_batch when no context is found.""" + mock_get_content_type.return_value = mock_content_type + mock_context_filter.return_value.first.return_value = None + + with patch.object(command.stdout, "write") as mock_write: + result = command.process_chunks_batch([mock_entity]) + + assert result == 0 + mock_write.assert_called_once() + warning_call = mock_write.call_args[0][0] + assert "No context found for test_entity test-key-123" in str(warning_call) + + @patch("apps.ai.common.base.chunk_command.ContentType.objects.get_for_model") + @patch("apps.ai.common.base.chunk_command.Context.objects.filter") + def test_process_chunks_batch_empty_content( + self, + mock_context_filter, + mock_get_content_type, + command, + mock_entity, + mock_context, + mock_content_type, + ): + """Test process_chunks_batch when extracted content is empty.""" + mock_get_content_type.return_value = mock_content_type + mock_context_filter.return_value.first.return_value = mock_context + + with ( + patch.object(command, "extract_content", return_value=("", "")), + patch.object(command.stdout, "write") as mock_write, + ): + result = command.process_chunks_batch([mock_entity]) + + assert result == 0 + mock_write.assert_called_once_with("No content to chunk for test_entity test-key-123") + + @patch("apps.ai.common.base.chunk_command.ContentType.objects.get_for_model") + @patch("apps.ai.common.base.chunk_command.Context.objects.filter") + @patch("apps.ai.models.chunk.Chunk.split_text") + def test_process_chunks_batch_no_chunks_created( + self, + mock_split_text, + mock_context_filter, + mock_get_content_type, + command, + mock_entity, + mock_context, + mock_content_type, + ): + """Test process_chunks_batch when no chunks are created from text.""" + mock_get_content_type.return_value = mock_content_type + mock_context_filter.return_value.first.return_value = mock_context + mock_split_text.return_value = [] + + with patch.object(command.stdout, "write") as mock_write: + result = command.process_chunks_batch([mock_entity]) + + assert result == 0 + mock_write.assert_called_once() + call_args = mock_write.call_args[0][0] + assert "No chunks created for test_entity test-key-123" in call_args + + @patch("apps.ai.common.base.chunk_command.ContentType.objects.get_for_model") + @patch("apps.ai.common.base.chunk_command.Context.objects.filter") + @patch("apps.ai.models.chunk.Chunk.split_text") + @patch("apps.ai.common.base.chunk_command.create_chunks_and_embeddings") + @patch("apps.ai.models.chunk.Chunk.bulk_save") + def test_process_chunks_batch_success( + self, + mock_bulk_save, + mock_create_chunks, + mock_split_text, + mock_context_filter, + mock_get_content_type, + command, + mock_entity, + mock_context, + mock_content_type, + mock_chunks, + ): + """Test successful chunk processing.""" + mock_get_content_type.return_value = mock_content_type + mock_context_filter.return_value.first.return_value = mock_context + mock_split_text.return_value = ["chunk1", "chunk2", "chunk3"] + mock_create_chunks.return_value = mock_chunks + command.openai_client = Mock() + + with patch.object(command.stdout, "write") as mock_write: + result = command.process_chunks_batch([mock_entity]) + + assert result == 1 + mock_create_chunks.assert_called_once_with( + chunk_texts=["chunk1", "chunk2", "chunk3"], + context=mock_context, + openai_client=command.openai_client, + save=False, + ) + mock_bulk_save.assert_called_once_with(mock_chunks) + mock_write.assert_called_once_with("Created 3 chunks for test-key-123") + + @patch("apps.ai.common.base.chunk_command.ContentType.objects.get_for_model") + @patch("apps.ai.common.base.chunk_command.Context.objects.filter") + @patch("apps.ai.models.chunk.Chunk.split_text") + @patch("apps.ai.common.base.chunk_command.create_chunks_and_embeddings") + @patch("apps.ai.models.chunk.Chunk.bulk_save") + def test_process_chunks_batch_multiple_entities( + self, + mock_bulk_save, + mock_create_chunks, + mock_split_text, + mock_context_filter, + mock_get_content_type, + command, + mock_context, + mock_content_type, + mock_chunks, + ): + """Test processing multiple entities in a batch.""" + entities = [] + for i in range(3): + entity = Mock() + entity.id = i + 1 + entity.test_key = f"test-key-{i + 1}" + entity.is_active = True + entities.append(entity) + + mock_get_content_type.return_value = mock_content_type + mock_context_filter.return_value.first.return_value = mock_context + mock_split_text.return_value = ["chunk1", "chunk2"] + mock_create_chunks.return_value = mock_chunks[:2] + command.openai_client = Mock() + + with patch.object(command.stdout, "write"): + result = command.process_chunks_batch(entities) + + assert result == 3 + assert mock_create_chunks.call_count == 3 + mock_bulk_save.assert_called_once() + bulk_save_args = mock_bulk_save.call_args[0][0] + assert len(bulk_save_args) == 6 + + @patch("apps.ai.common.base.chunk_command.ContentType.objects.get_for_model") + @patch("apps.ai.common.base.chunk_command.Context.objects.filter") + @patch("apps.ai.models.chunk.Chunk.split_text") + @patch("apps.ai.common.base.chunk_command.create_chunks_and_embeddings") + def test_process_chunks_batch_create_chunks_fails( + self, + mock_create_chunks, + mock_split_text, + mock_context_filter, + mock_get_content_type, + command, + mock_entity, + mock_context, + mock_content_type, + ): + """Test process_chunks_batch when create_chunks_and_embeddings fails.""" + mock_get_content_type.return_value = mock_content_type + mock_context_filter.return_value.first.return_value = mock_context + mock_split_text.return_value = ["chunk1", "chunk2"] + mock_create_chunks.return_value = None + command.openai_client = Mock() + + result = command.process_chunks_batch([mock_entity]) + + assert result == 0 + mock_create_chunks.assert_called_once() + + def test_process_chunks_batch_content_combination( + self, command, mock_entity, mock_context, mock_content_type + ): + """Test that metadata and prose content are properly combined.""" + with ( + patch( + "apps.ai.common.base.chunk_command.ContentType.objects.get_for_model" + ) as mock_get_content_type, + patch( + "apps.ai.common.base.chunk_command.Context.objects.filter" + ) as mock_context_filter, + patch("apps.ai.models.chunk.Chunk.split_text") as mock_split_text, + patch( + "apps.ai.common.base.chunk_command.create_chunks_and_embeddings" + ) as mock_create_chunks, + patch("apps.ai.models.chunk.Chunk.bulk_save"), + ): + mock_get_content_type.return_value = mock_content_type + mock_context_filter.return_value.first.return_value = mock_context + mock_split_text.return_value = ["chunk1"] + mock_create_chunks.return_value = [Mock()] + command.openai_client = Mock() + + with patch.object( + command, + "extract_content", + return_value=("prose", "metadata"), + ): + command.process_chunks_batch([mock_entity]) + + expected_content = "metadata\n\nprose" + mock_split_text.assert_called_once_with(expected_content) + + mock_split_text.reset_mock() + with patch.object(command, "extract_content", return_value=("prose", "")): + command.process_chunks_batch([mock_entity]) + + mock_split_text.assert_called_with("prose") + + @patch.object(BaseChunkCommand, "setup_openai_client") + @patch.object(BaseChunkCommand, "get_queryset") + @patch.object(BaseChunkCommand, "handle_batch_processing") + def test_handle_method_success( + self, mock_handle_batch, mock_get_queryset, mock_setup_client, command + ): + """Test the handle method with successful setup.""" + mock_setup_client.return_value = True + mock_queryset = Mock() + mock_get_queryset.return_value = mock_queryset + options = {"batch_size": 10} + + command.handle(**options) + + mock_setup_client.assert_called_once() + mock_get_queryset.assert_called_once_with(options) + mock_handle_batch.assert_called_once_with( + queryset=mock_queryset, + batch_size=10, + process_batch_func=command.process_chunks_batch, + ) + + @patch.object(BaseChunkCommand, "setup_openai_client") + def test_handle_method_openai_setup_fails(self, mock_setup_client, command): + """Test the handle method when OpenAI client setup fails.""" + mock_setup_client.return_value = False + options = {"batch_size": 10} + + with ( + patch.object(command, "get_queryset") as mock_get_queryset, + patch.object(command, "handle_batch_processing") as mock_handle_batch, + ): + command.handle(**options) + + mock_setup_client.assert_called_once() + mock_get_queryset.assert_not_called() + mock_handle_batch.assert_not_called() + + def test_process_chunks_batch_metadata_only_content( + self, command, mock_entity, mock_context, mock_content_type + ): + """Test process_chunks_batch with only metadata content.""" + with ( + patch( + "apps.ai.common.base.chunk_command.ContentType.objects.get_for_model" + ) as mock_get_content_type, + patch( + "apps.ai.common.base.chunk_command.Context.objects.filter" + ) as mock_context_filter, + patch("apps.ai.models.chunk.Chunk.split_text") as mock_split_text, + patch( + "apps.ai.common.base.chunk_command.create_chunks_and_embeddings" + ) as mock_create_chunks, + patch("apps.ai.models.chunk.Chunk.bulk_save") as mock_bulk_save, + ): + mock_get_content_type.return_value = mock_content_type + mock_context_filter.return_value.first.return_value = mock_context + mock_split_text.return_value = ["chunk1"] + mock_create_chunks.return_value = [Mock()] + command.openai_client = Mock() + + with patch.object( + command, + "extract_content", + return_value=("", "metadata"), + ): + command.process_chunks_batch([mock_entity]) + + mock_split_text.assert_called_once_with("metadata\n\n") + mock_bulk_save.assert_called_once() + + def test_process_chunks_batch_whitespace_only_content( + self, command, mock_entity, mock_context, mock_content_type + ): + """Test process_chunks_batch with whitespace-only content.""" + with ( + patch( + "apps.ai.common.base.chunk_command.ContentType.objects.get_for_model" + ) as mock_get_content_type, + patch( + "apps.ai.common.base.chunk_command.Context.objects.filter" + ) as mock_context_filter, + ): + mock_get_content_type.return_value = mock_content_type + mock_context_filter.return_value.first.return_value = mock_context + + with ( + patch.object(command, "extract_content", return_value=(" \n\t ", " \t\n ")), + patch.object(command.stdout, "write") as mock_write, + ): + result = command.process_chunks_batch([mock_entity]) + + assert result == 0 + mock_write.assert_called_once_with( + "No content to chunk for test_entity test-key-123" + ) + + +class TestBaseChunkCommandAbstractMethods: + """Test that BaseChunkCommand requires implementation of abstract methods.""" + + def test_cannot_instantiate_base_class_directly(self): + """Test that BaseChunkCommand cannot be instantiated directly.""" + with pytest.raises(TypeError): + BaseChunkCommand() + + def test_abstract_methods_must_be_implemented(self): + """Test that subclasses must implement all abstract methods.""" + + class IncompleteChunkCommand(BaseChunkCommand): + """Incomplete implementation missing required methods.""" + + with pytest.raises(TypeError): + IncompleteChunkCommand() diff --git a/backend/tests/apps/ai/common/base/context_command_test.py b/backend/tests/apps/ai/common/base/context_command_test.py new file mode 100644 index 0000000000..6273a6c33c --- /dev/null +++ b/backend/tests/apps/ai/common/base/context_command_test.py @@ -0,0 +1,326 @@ +"""Tests for the BaseContextCommand class.""" + +from unittest.mock import Mock, patch + +import pytest +from django.core.management.base import BaseCommand + +from apps.ai.common.base.context_command import BaseContextCommand +from apps.ai.models.context import Context + + +class ConcreteContextCommand(BaseContextCommand): + """Concrete implementation of BaseContextCommand for testing.""" + + def model_class(self): + mock_model = Mock() + mock_model.__name__ = "MockContextTestModel" + return mock_model + + def entity_name(self): + return "test_entity" + + def entity_name_plural(self): + return "test_entities" + + def key_field_name(self): + return "test_key" + + def extract_content(self, entity): + return ("prose content", "metadata content") + + +@pytest.fixture +def command(): + """Return a concrete context command instance for testing.""" + return ConcreteContextCommand() + + +@pytest.fixture +def mock_entity(): + """Return a mock entity instance.""" + entity = Mock() + entity.id = 1 + entity.test_key = "test-key-123" + entity.is_active = True + return entity + + +@pytest.fixture +def mock_context(): + """Return a mock context instance.""" + context = Mock(spec=Context) + context.id = 1 + context.content = "test content" + context.content_type_id = 1 + context.object_id = 1 + return context + + +class TestBaseContextCommand: + """Test suite for the BaseContextCommand class.""" + + def test_command_inheritance(self, command): + """Test that BaseContextCommand inherits from BaseCommand.""" + assert isinstance(command, BaseCommand) + + def test_help_method(self, command): + """Test the help method returns appropriate help text.""" + expected_help = "Update context for OWASP test_entity data" + assert command.help() == expected_help + + def test_abstract_methods_implemented(self, command): + """Test that all abstract methods are properly implemented.""" + mock_model = command.model_class() + assert mock_model.__name__ == "MockContextTestModel" + assert command.entity_name() == "test_entity" + assert command.entity_name_plural() == "test_entities" + assert command.key_field_name() == "test_key" + + mock_entity = Mock() + result = command.extract_content(mock_entity) + assert result == ("prose content", "metadata content") + + def test_process_context_batch_empty_content(self, command, mock_entity): + """Test process_context_batch when extracted content is empty.""" + with ( + patch.object(command, "extract_content", return_value=("", "")), + patch.object(command.stdout, "write") as mock_write, + ): + result = command.process_context_batch([mock_entity]) + + assert result == 0 + mock_write.assert_called_once_with("No content for test_entity test-key-123") + + def test_process_context_batch_whitespace_only_content(self, command, mock_entity): + """Test process_context_batch with whitespace-only content.""" + with ( + patch.object(command, "extract_content", return_value=(" \n\t ", " \t\n ")), + patch.object(command.stdout, "write") as mock_write, + ): + result = command.process_context_batch([mock_entity]) + + assert result == 0 + mock_write.assert_called_once_with("No content for test_entity test-key-123") + + @patch("apps.ai.common.base.context_command.Context") + def test_process_context_batch_success( + self, mock_context_class, command, mock_entity, mock_context + ): + """Test successful context processing.""" + mock_context_class.update_data.return_value = mock_context + + with patch.object(command.stdout, "write") as mock_write: + result = command.process_context_batch([mock_entity]) + + assert result == 1 + mock_context_class.update_data.assert_called_once_with( + content="metadata content\n\nprose content", + content_object=mock_entity, + source="owasp_test_entity", + ) + mock_write.assert_called_once_with("Created context for test-key-123") + + @patch("apps.ai.common.base.context_command.Context") + def test_process_context_batch_creation_fails(self, mock_context_class, command, mock_entity): + """Test process_context_batch when context creation fails.""" + mock_context_class.update_data.return_value = None + + with patch.object(command.stdout, "write") as mock_write: + result = command.process_context_batch([mock_entity]) + + assert result == 0 + mock_context_class.update_data.assert_called_once() + mock_write.assert_called_once() + call_args = mock_write.call_args[0][0] + assert "Failed to create context for test-key-123" in str(call_args) + + @patch("apps.ai.common.base.context_command.Context") + def test_process_context_batch_multiple_entities( + self, mock_context_class, command, mock_context + ): + """Test processing multiple entities in a batch.""" + entities = [] + for i in range(3): + entity = Mock() + entity.id = i + 1 + entity.test_key = f"test-key-{i + 1}" + entity.is_active = True + entities.append(entity) + + mock_context_class.update_data.return_value = mock_context + + with patch.object(command.stdout, "write") as mock_write: + result = command.process_context_batch(entities) + + assert result == 3 + assert mock_context_class.update_data.call_count == 3 + assert mock_write.call_count == 3 + + calls = mock_context_class.update_data.call_args_list + for i, call in enumerate(calls): + args, kwargs = call + assert kwargs["content_object"] == entities[i] + assert kwargs["content"] == "metadata content\n\nprose content" + assert kwargs["source"] == "owasp_test_entity" + + @patch("apps.ai.common.base.context_command.Context") + def test_process_context_batch_mixed_success_failure( + self, mock_context_class, command, mock_context + ): + """Test processing where some entities succeed and others fail.""" + entities = [] + for i in range(3): + entity = Mock() + entity.id = i + 1 + entity.test_key = f"test-key-{i + 1}" + entity.is_active = True + entities.append(entity) + + mock_context_class.update_data.side_effect = [mock_context, None, mock_context] + + with patch.object(command.stdout, "write") as mock_write: + result = command.process_context_batch(entities) + + assert result == 2 + assert mock_context_class.update_data.call_count == 3 + assert mock_write.call_count == 3 + + write_calls = mock_write.call_args_list + assert "Created context for test-key-1" in str(write_calls[0]) + assert "Failed to create context for test-key-2" in str(write_calls[1]) + assert "Created context for test-key-3" in str(write_calls[2]) + + def test_process_context_batch_content_combination(self, command, mock_entity, mock_context): + """Test that metadata and prose content are properly combined.""" + with patch("apps.ai.common.base.context_command.Context") as mock_context_class: + mock_context_class.update_data.return_value = mock_context + + with patch.object(command, "extract_content", return_value=("prose", "metadata")): + command.process_context_batch([mock_entity]) + + expected_content = "metadata\n\nprose" + mock_context_class.update_data.assert_called_once() + call_args = mock_context_class.update_data.call_args[1] + assert call_args["content"] == expected_content + + mock_context_class.update_data.reset_mock() + with patch.object(command, "extract_content", return_value=("prose", "")): + command.process_context_batch([mock_entity]) + + call_args = mock_context_class.update_data.call_args[1] + assert call_args["content"] == "prose" + + def test_process_context_batch_metadata_only_content(self, command, mock_entity, mock_context): + """Test process_context_batch with only metadata content.""" + with patch("apps.ai.common.base.context_command.Context") as mock_context_class: + mock_context_class.update_data.return_value = mock_context + + with patch.object(command, "extract_content", return_value=("", "metadata")): + command.process_context_batch([mock_entity]) + + expected_content = "metadata\n\n" + call_args = mock_context_class.update_data.call_args[1] + assert call_args["content"] == expected_content + + @patch.object(BaseContextCommand, "get_queryset") + @patch.object(BaseContextCommand, "handle_batch_processing") + def test_handle_method(self, mock_handle_batch, mock_get_queryset, command): + """Test the handle method.""" + mock_queryset = Mock() + mock_get_queryset.return_value = mock_queryset + options = {"batch_size": 10} + + command.handle(**options) + + mock_get_queryset.assert_called_once_with(options) + mock_handle_batch.assert_called_once_with( + queryset=mock_queryset, + batch_size=10, + process_batch_func=command.process_context_batch, + ) + + def test_source_name_usage(self, command, mock_entity, mock_context): + """Test that source_name is properly used in context creation.""" + with ( + patch("apps.ai.common.base.context_command.Context") as mock_context_class, + patch.object(command, "source_name", return_value="custom_source"), + ): + mock_context_class.update_data.return_value = mock_context + + command.process_context_batch([mock_entity]) + + call_args = mock_context_class.update_data.call_args[1] + assert call_args["source"] == "custom_source" + + def test_get_entity_key_usage(self, command, mock_context): + """Test that get_entity_key is properly used for display messages.""" + entity = Mock() + entity.test_key = "custom-entity-key" + + with patch("apps.ai.common.base.context_command.Context") as mock_context_class: + mock_context_class.update_data.return_value = mock_context + + with patch.object(command.stdout, "write") as mock_write: + command.process_context_batch([entity]) + + mock_write.assert_called_once_with("Created context for custom-entity-key") + + def test_process_context_batch_empty_list(self, command): + """Test process_context_batch with empty entity list.""" + result = command.process_context_batch([]) + assert result == 0 + + def test_process_context_batch_skips_empty_entities(self, command): + """Test that entities with empty content are properly skipped.""" + entities = [] + + entity1 = Mock() + entity1.test_key = "entity-1" + entities.append(entity1) + + entity2 = Mock() + entity2.test_key = "entity-2" + entities.append(entity2) + + entity3 = Mock() + entity3.test_key = "entity-3" + entities.append(entity3) + + def mock_extract_content(entity): + if entity.test_key == "entity-2": + return ("", "") + return ("prose", "metadata") + + with ( + patch.object(command, "extract_content", side_effect=mock_extract_content), + patch("apps.ai.common.base.context_command.Context") as mock_context_class, + patch.object(command.stdout, "write") as mock_write, + ): + mock_context_class.update_data.return_value = Mock() + + result = command.process_context_batch(entities) + + assert result == 2 + assert mock_context_class.update_data.call_count == 2 + + write_calls = [str(call) for call in mock_write.call_args_list] + assert any("No content for test_entity entity-2" in call for call in write_calls) + + +class TestBaseContextCommandAbstractMethods: + """Test that BaseContextCommand requires implementation of abstract methods.""" + + def test_cannot_instantiate_base_class_directly(self): + """Test that BaseContextCommand cannot be instantiated directly.""" + with pytest.raises(TypeError): + BaseContextCommand() + + def test_abstract_methods_must_be_implemented(self): + """Test that subclasses must implement all abstract methods.""" + + class IncompleteContextCommand(BaseContextCommand): + """Incomplete implementation missing required methods.""" + + with pytest.raises(TypeError): + IncompleteContextCommand() diff --git a/backend/tests/apps/ai/common/base_test.py b/backend/tests/apps/ai/common/base_test.py deleted file mode 100644 index e5a916d35f..0000000000 --- a/backend/tests/apps/ai/common/base_test.py +++ /dev/null @@ -1,664 +0,0 @@ -"""Tests for the base AI command classes.""" - -import os -from unittest.mock import Mock, call, patch - -import pytest -from django.core.management.base import BaseCommand -from django.db import models - -from apps.ai.common.base import BaseAICommand, BaseChunkCommand, BaseContextCommand - - -class MockModel(models.Model): - """Mock model for testing purposes.""" - - name = models.CharField(max_length=100) - key = models.CharField(max_length=50, unique=True) - is_active = models.BooleanField(default=True) - - def __str__(self): - """Return string representation of the model.""" - return self.name - - class Meta: - """Meta class for MockModel.""" - - app_label = "test" - - -class ConcreteBaseAICommand(BaseAICommand): - """Concrete implementation of BaseAICommand for testing.""" - - @property - def model_class(self) -> type[models.Model]: - return MockModel - - @property - def entity_name(self) -> str: - return "test" - - @property - def entity_name_plural(self) -> str: - return "tests" - - @property - def key_field_name(self) -> str: - return "key" - - def extract_content(self, entity: models.Model) -> tuple[str, str]: - return f"Content for {entity.name}", f"Metadata for {entity.name}" - - -class ConcreteBaseContextCommand(BaseContextCommand): - """Concrete implementation of BaseContextCommand for testing.""" - - @property - def model_class(self) -> type[models.Model]: - return MockModel - - @property - def entity_name(self) -> str: - return "test" - - @property - def entity_name_plural(self) -> str: - return "tests" - - @property - def key_field_name(self) -> str: - return "key" - - def extract_content(self, entity: models.Model) -> tuple[str, str]: - return f"Content for {entity.name}", f"Metadata for {entity.name}" - - -class ConcreteBaseChunkCommand(BaseChunkCommand): - """Concrete implementation of BaseChunkCommand for testing.""" - - @property - def model_class(self) -> type[models.Model]: - return MockModel - - @property - def entity_name(self) -> str: - return "test" - - @property - def entity_name_plural(self) -> str: - return "tests" - - @property - def key_field_name(self) -> str: - return "key" - - def extract_content(self, entity: models.Model) -> tuple[str, str]: - return f"Content for {entity.name}", f"Metadata for {entity.name}" - - -@pytest.fixture -def base_ai_command(): - """Return a concrete BaseAICommand instance.""" - return ConcreteBaseAICommand() - - -@pytest.fixture -def base_context_command(): - """Return a concrete BaseContextCommand instance.""" - return ConcreteBaseContextCommand() - - -@pytest.fixture -def base_chunk_command(): - """Return a concrete BaseChunkCommand instance.""" - return ConcreteBaseChunkCommand() - - -@pytest.fixture -def mock_entity(): - """Return a mock entity.""" - entity = Mock(spec=MockModel) - entity.name = "Test Entity" - entity.key = "test-key" - entity.pk = 1 - entity.is_active = True - return entity - - -@pytest.fixture -def mock_queryset(): - """Return a mock queryset.""" - queryset = Mock() - queryset.count.return_value = 3 - queryset.filter.return_value = queryset - queryset.__getitem__ = Mock(side_effect=lambda _: [Mock(), Mock()]) - return queryset - - -class TestBaseAICommand: - """Test suite for BaseAICommand.""" - - def test_command_inheritance(self, base_ai_command): - """Test that the command inherits from BaseCommand.""" - assert isinstance(base_ai_command, BaseCommand) - - def test_initialization(self, base_ai_command): - """Test command initialization.""" - assert base_ai_command.openai_client is None - - def test_abstract_properties(self, base_ai_command): - """Test abstract property implementations.""" - assert base_ai_command.model_class == MockModel - assert base_ai_command.entity_name == "test" - assert base_ai_command.entity_name_plural == "tests" - assert base_ai_command.key_field_name == "key" - - def test_source_name_default(self, base_ai_command): - """Test default source name.""" - assert base_ai_command.source_name == "owasp_test" - - def test_extract_content_implementation(self, base_ai_command, mock_entity): - """Test extract_content implementation.""" - prose, metadata = base_ai_command.extract_content(mock_entity) - assert prose == "Content for Test Entity" - assert metadata == "Metadata for Test Entity" - - @patch.object(ConcreteBaseAICommand, "model_class", MockModel) - def test_get_base_queryset(self, base_ai_command): - """Test get_base_queryset method.""" - with patch.object(MockModel, "objects") as mock_objects: - mock_objects.all.return_value = "base_queryset" - result = base_ai_command.get_base_queryset() - assert result == "base_queryset" - mock_objects.all.assert_called_once() - - @patch.object(ConcreteBaseAICommand, "get_base_queryset") - def test_get_default_queryset(self, mock_get_base, base_ai_command): - """Test get_default_queryset method.""" - mock_queryset = Mock() - mock_get_base.return_value = mock_queryset - mock_queryset.filter.return_value = "filtered_queryset" - - result = base_ai_command.get_default_queryset() - - assert result == "filtered_queryset" - mock_queryset.filter.assert_called_once_with(is_active=True) - - def test_add_common_arguments(self, base_ai_command): - """Test add_common_arguments method.""" - mock_parser = Mock() - mock_parser.add_argument = Mock() - - base_ai_command.add_common_arguments(mock_parser) - - expected_calls = [ - call("--test-key", type=str, help="Process only the test with this key"), - call("--all", action="store_true", help="Process all the tests"), - call( - "--batch-size", - type=int, - default=50, - help="Number of tests to process in each batch", - ), - ] - mock_parser.add_argument.assert_has_calls(expected_calls) - - def test_add_arguments_calls_common(self, base_ai_command): - """Test add_arguments calls add_common_arguments.""" - mock_parser = Mock() - with patch.object(base_ai_command, "add_common_arguments") as mock_add_common: - base_ai_command.add_arguments(mock_parser) - mock_add_common.assert_called_once_with(mock_parser) - - @patch.object(ConcreteBaseAICommand, "get_base_queryset") - def test_get_queryset_with_key_option(self, mock_get_base, base_ai_command): - """Test get_queryset with entity key option.""" - mock_queryset = Mock() - mock_get_base.return_value = mock_queryset - mock_queryset.filter.return_value = "filtered_queryset" - - options = {"test_key": "specific-key"} - result = base_ai_command.get_queryset(options) - - assert result == "filtered_queryset" - mock_queryset.filter.assert_called_once_with(key="specific-key") - - @patch.object(ConcreteBaseAICommand, "get_base_queryset") - def test_get_queryset_with_all_option(self, mock_get_base, base_ai_command): - """Test get_queryset with all option.""" - mock_queryset = Mock() - mock_get_base.return_value = mock_queryset - - options = {"all": True} - result = base_ai_command.get_queryset(options) - - assert result == mock_queryset - - @patch.object(ConcreteBaseAICommand, "get_default_queryset") - def test_get_queryset_default(self, mock_get_default, base_ai_command): - """Test get_queryset with default behavior.""" - mock_get_default.return_value = "default_queryset" - - options = {} - result = base_ai_command.get_queryset(options) - - assert result == "default_queryset" - - def test_get_entity_key(self, base_ai_command, mock_entity): - """Test get_entity_key method.""" - result = base_ai_command.get_entity_key(mock_entity) - assert result == "test-key" - - def test_get_entity_key_fallback_to_pk(self, base_ai_command): - """Test get_entity_key falls back to pk when key field doesn't exist.""" - mock_entity = Mock() - mock_entity.pk = 123 - delattr(mock_entity, "key") if hasattr(mock_entity, "key") else None - - result = base_ai_command.get_entity_key(mock_entity) - assert result == "123" - - @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-api-key"}) - @patch("apps.ai.common.base.openai.OpenAI") - def test_setup_openai_client_success(self, mock_openai_class, base_ai_command): - """Test successful OpenAI client setup.""" - mock_client = Mock() - mock_openai_class.return_value = mock_client - - result = base_ai_command.setup_openai_client() - - assert result is True - assert base_ai_command.openai_client == mock_client - mock_openai_class.assert_called_once_with(api_key="test-api-key") - - @patch.dict(os.environ, {}, clear=True) - def test_setup_openai_client_no_api_key(self, base_ai_command): - """Test OpenAI client setup without API key.""" - with ( - patch.object(base_ai_command.stdout, "write") as mock_write, - patch.object(base_ai_command.style, "ERROR") as mock_error, - ): - mock_error.return_value = "ERROR: No API key" - - result = base_ai_command.setup_openai_client() - - assert result is False - assert base_ai_command.openai_client is None - mock_error.assert_called_once_with( - "DJANGO_OPEN_AI_SECRET_KEY environment variable not set" - ) - mock_write.assert_called_once_with("ERROR: No API key") - - def test_handle_batch_processing_empty_queryset(self, base_ai_command): - """Test batch processing with empty queryset.""" - mock_queryset = Mock() - mock_queryset.count.return_value = 0 - - with patch.object(base_ai_command.stdout, "write") as mock_write: - base_ai_command.handle_batch_processing( - queryset=mock_queryset, - batch_size=10, - process_batch_func=Mock(), - ) - - mock_write.assert_called_once_with("No tests found to process") - - def test_handle_batch_processing_with_items(self, base_ai_command): - """Test batch processing with items.""" - mock_queryset = Mock() - mock_queryset.count.return_value = 5 - - # Mock slicing behavior - batch1 = [Mock(), Mock()] - batch2 = [Mock(), Mock()] - batch3 = [Mock()] - - def getitem_side_effect(slice_obj): - if slice_obj == slice(0, 2): - return batch1 - if slice_obj == slice(2, 4): - return batch2 - if slice_obj == slice(4, 6): - return batch3 - return [] - - mock_queryset.__getitem__ = Mock(side_effect=getitem_side_effect) - - mock_process_func = Mock(side_effect=[2, 2, 1]) # Return processed counts - - with ( - patch.object(base_ai_command.stdout, "write") as mock_write, - patch.object(base_ai_command.style, "SUCCESS") as mock_success, - ): - mock_success.return_value = "SUCCESS: Completed" - - base_ai_command.handle_batch_processing( - queryset=mock_queryset, - batch_size=2, - process_batch_func=mock_process_func, - ) - - # Verify process function was called with correct batches - expected_calls = [call(batch1), call(batch2), call(batch3)] - mock_process_func.assert_has_calls(expected_calls) - - # Verify output messages - assert mock_write.call_count == 2 - mock_write.assert_any_call("Found 5 tests to process") - mock_success.assert_called_once_with("Completed processing 5/5 tests") - - -class TestBaseContextCommand: - """Test suite for BaseContextCommand.""" - - def test_command_inheritance(self, base_context_command): - """Test that the command inherits from BaseAICommand.""" - assert isinstance(base_context_command, BaseAICommand) - - def test_help_property(self, base_context_command): - """Test help property.""" - assert base_context_command.help == "Update context for OWASP test data" - - @patch("apps.ai.common.base.create_context") - def test_process_context_batch_success( - self, mock_create_context, base_context_command - ): - """Test successful context batch processing.""" - mock_create_context.return_value = True - - entities = [ - Mock(name="Entity 1", key="key1"), - Mock(name="Entity 2", key="key2"), - ] - - with patch.object(base_context_command, "extract_content") as mock_extract: - mock_extract.side_effect = [ - ("Content 1", "Metadata 1"), - ("Content 2", "Metadata 2"), - ] - - with patch.object(base_context_command, "get_entity_key") as mock_get_key: - mock_get_key.side_effect = ["key1", "key2"] - - result = base_context_command.process_context_batch(entities) - - assert result == 2 - assert mock_create_context.call_count == 2 - - # Verify create_context was called with correct parameters - expected_calls = [ - call( - content="Metadata 1\n\nContent 1", - content_object=entities[0], - source="owasp_test", - ), - call( - content="Metadata 2\n\nContent 2", - content_object=entities[1], - source="owasp_test", - ), - ] - mock_create_context.assert_has_calls(expected_calls) - - @patch("apps.ai.common.base.create_context") - def test_process_context_batch_empty_content( - self, mock_create_context, base_context_command - ): - """Test context batch processing with empty content.""" - entities = [Mock(name="Empty Entity", key="empty-key")] - - with patch.object(base_context_command, "extract_content") as mock_extract: - mock_extract.return_value = ("", "") - - with patch.object(base_context_command, "get_entity_key") as mock_get_key: - mock_get_key.return_value = "empty-key" - - with patch.object(base_context_command.stdout, "write") as mock_write: - result = base_context_command.process_context_batch(entities) - - assert result == 0 - mock_create_context.assert_not_called() - mock_write.assert_called_once_with("No content for test empty-key") - - @patch("apps.ai.common.base.create_context") - def test_process_context_batch_create_failure( - self, mock_create_context, base_context_command - ): - """Test context batch processing when create_context fails.""" - mock_create_context.return_value = False - - entities = [Mock(name="Failing Entity", key="fail-key")] - - with patch.object(base_context_command, "extract_content") as mock_extract: - mock_extract.return_value = ("Content", "Metadata") - - with patch.object(base_context_command, "get_entity_key") as mock_get_key: - mock_get_key.return_value = "fail-key" - - with ( - patch.object(base_context_command.stdout, "write") as mock_write, - patch.object(base_context_command.style, "ERROR") as mock_error, - ): - mock_error.return_value = "ERROR: Failed" - - result = base_context_command.process_context_batch(entities) - - assert result == 0 - mock_error.assert_called_once_with( - "Failed to create context for fail-key" - ) - mock_write.assert_called_once_with("ERROR: Failed") - - def test_handle_calls_batch_processing(self, base_context_command): - """Test handle method calls handle_batch_processing.""" - options = {"batch_size": 25} - mock_queryset = Mock() - - with patch.object(base_context_command, "get_queryset") as mock_get_queryset: - mock_get_queryset.return_value = mock_queryset - - with patch.object( - base_context_command, "handle_batch_processing" - ) as mock_handle_batch: - base_context_command.handle(**options) - - mock_get_queryset.assert_called_once_with(options) - mock_handle_batch.assert_called_once_with( - queryset=mock_queryset, - batch_size=25, - process_batch_func=base_context_command.process_context_batch, - ) - - -class TestBaseChunkCommand: - """Test suite for BaseChunkCommand.""" - - def test_command_inheritance(self, base_chunk_command): - """Test that the command inherits from BaseAICommand.""" - assert isinstance(base_chunk_command, BaseAICommand) - - def test_help_property(self, base_chunk_command): - """Test help property.""" - assert base_chunk_command.help == "Create chunks for OWASP test data" - - @patch("apps.ai.common.base.create_chunks_and_embeddings") - @patch("apps.ai.common.base.Chunk.bulk_save") - @patch("apps.ai.common.base.Chunk.split_text") - @patch("apps.ai.common.base.Context.objects.filter") - @patch("apps.ai.common.base.ContentType.objects.get_for_model") - def test_process_chunks_batch_success( - self, - mock_get_content_type, - mock_context_filter, - mock_split_text, - mock_bulk_save, - mock_create_chunks, - base_chunk_command, - ): - """Test successful chunks batch processing.""" - # Setup mocks - mock_content_type = Mock() - mock_get_content_type.return_value = mock_content_type - - mock_context = Mock() - mock_context_filter.return_value.first.return_value = mock_context - - mock_split_text.return_value = ["chunk1", "chunk2"] - - mock_chunks = [Mock(), Mock()] - mock_create_chunks.return_value = mock_chunks - - entities = [Mock(id=1, name="Entity 1", key="key1")] - - with patch.object(base_chunk_command, "extract_content") as mock_extract: - mock_extract.return_value = ("Content", "Metadata") - - with patch.object(base_chunk_command, "get_entity_key") as mock_get_key: - mock_get_key.return_value = "key1" - - result = base_chunk_command.process_chunks_batch(entities) - - assert result == 1 - mock_get_content_type.assert_called_once_with(MockModel) - mock_context_filter.assert_called_once_with( - content_type=mock_content_type, object_id=1 - ) - mock_split_text.assert_called_once_with("Metadata\n\nContent") - mock_create_chunks.assert_called_once_with( - chunk_texts=["chunk1", "chunk2"], - context=mock_context, - openai_client=base_chunk_command.openai_client, - save=False, - ) - mock_bulk_save.assert_called_once_with(mock_chunks) - - @patch("apps.ai.common.base.Context.objects.filter") - @patch("apps.ai.common.base.ContentType.objects.get_for_model") - def test_process_chunks_batch_no_context( - self, mock_get_content_type, mock_context_filter, base_chunk_command - ): - """Test chunks batch processing when no context exists.""" - mock_content_type = Mock() - mock_get_content_type.return_value = mock_content_type - mock_context_filter.return_value.first.return_value = None - - entities = [Mock(id=1, name="Entity 1", key="key1")] - - with patch.object(base_chunk_command, "get_entity_key") as mock_get_key: - mock_get_key.return_value = "key1" - - with ( - patch.object(base_chunk_command.stdout, "write") as mock_write, - patch.object(base_chunk_command.style, "WARNING") as mock_warning, - ): - mock_warning.return_value = "WARNING: No context" - - result = base_chunk_command.process_chunks_batch(entities) - - assert result == 0 - mock_warning.assert_called_once_with("No context found for test key1") - mock_write.assert_called_once_with("WARNING: No context") - - @patch("apps.ai.common.base.Chunk.split_text") - @patch("apps.ai.common.base.Context.objects.filter") - @patch("apps.ai.common.base.ContentType.objects.get_for_model") - def test_process_chunks_batch_empty_content( - self, - mock_get_content_type, - mock_context_filter, - mock_split_text, - base_chunk_command, - ): - """Test chunks batch processing with empty content.""" - mock_content_type = Mock() - mock_get_content_type.return_value = mock_content_type - - mock_context = Mock() - mock_context_filter.return_value.first.return_value = mock_context - - entities = [Mock(id=1, name="Entity 1", key="key1")] - - with patch.object(base_chunk_command, "extract_content") as mock_extract: - mock_extract.return_value = ("", "") - - with patch.object(base_chunk_command, "get_entity_key") as mock_get_key: - mock_get_key.return_value = "key1" - - with patch.object(base_chunk_command.stdout, "write") as mock_write: - result = base_chunk_command.process_chunks_batch(entities) - - assert result == 0 - mock_split_text.assert_not_called() - mock_write.assert_called_once_with( - "No content to chunk for test key1" - ) - - @patch("apps.ai.common.base.Chunk.split_text") - @patch("apps.ai.common.base.Context.objects.filter") - @patch("apps.ai.common.base.ContentType.objects.get_for_model") - def test_process_chunks_batch_no_chunks_created( - self, - mock_get_content_type, - mock_context_filter, - mock_split_text, - base_chunk_command, - ): - """Test chunks batch processing when no chunks are created.""" - mock_content_type = Mock() - mock_get_content_type.return_value = mock_content_type - - mock_context = Mock() - mock_context_filter.return_value.first.return_value = mock_context - - mock_split_text.return_value = [] # No chunks created - - entities = [Mock(id=1, name="Entity 1", key="key1")] - - with patch.object(base_chunk_command, "extract_content") as mock_extract: - mock_extract.return_value = ("Content", "Metadata") - - with patch.object(base_chunk_command, "get_entity_key") as mock_get_key: - mock_get_key.return_value = "key1" - - with patch.object(base_chunk_command.stdout, "write") as mock_write: - result = base_chunk_command.process_chunks_batch(entities) - - assert result == 0 - mock_write.assert_called_once_with( - "No chunks created for test key1: `Metadata\n\nContent`" - ) - - def test_handle_calls_setup_and_batch_processing(self, base_chunk_command): - """Test handle method calls setup_openai_client and handle_batch_processing.""" - options = {"batch_size": 25} - mock_queryset = Mock() - - with patch.object(base_chunk_command, "setup_openai_client") as mock_setup: - mock_setup.return_value = True - - with patch.object(base_chunk_command, "get_queryset") as mock_get_queryset: - mock_get_queryset.return_value = mock_queryset - - with patch.object( - base_chunk_command, "handle_batch_processing" - ) as mock_handle_batch: - base_chunk_command.handle(**options) - - mock_setup.assert_called_once() - mock_get_queryset.assert_called_once_with(options) - mock_handle_batch.assert_called_once_with( - queryset=mock_queryset, - batch_size=25, - process_batch_func=base_chunk_command.process_chunks_batch, - ) - - def test_handle_returns_early_if_setup_fails(self, base_chunk_command): - """Test handle method returns early if OpenAI client setup fails.""" - with patch.object(base_chunk_command, "setup_openai_client") as mock_setup: - mock_setup.return_value = False - - with patch.object(base_chunk_command, "get_queryset") as mock_get_queryset: - base_chunk_command.handle() - - mock_setup.assert_called_once() - mock_get_queryset.assert_not_called() diff --git a/backend/tests/apps/ai/common/extractors/__init__.py b/backend/tests/apps/ai/common/extractors/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/tests/apps/ai/common/extractors/chapter_test.py b/backend/tests/apps/ai/common/extractors/chapter_test.py new file mode 100644 index 0000000000..a3dac7e96c --- /dev/null +++ b/backend/tests/apps/ai/common/extractors/chapter_test.py @@ -0,0 +1,299 @@ +"""Tests for chapter content extractor.""" + +from unittest.mock import MagicMock + +from apps.ai.common.extractors.chapter import extract_chapter_content + + +class TestChapterExtractor: + """Test cases for chapter content extraction.""" + + def test_extract_chapter_content_full_data(self): + """Test extraction with complete chapter data.""" + chapter = MagicMock() + chapter.description = "Test chapter description" + chapter.summary = "Test chapter summary" + chapter.name = "Test Chapter" + chapter.country = "USA" + chapter.region = "North America" + chapter.postal_code = "12345" + chapter.suggested_location = "New York, NY" + chapter.currency = "USD" + chapter.meetup_group = "owasp-nyc" + chapter.tags = ["security", "web"] + chapter.topics = ["OWASP Top 10", "Security Testing"] + chapter.leaders_raw = ["John Doe", "Jane Smith"] + chapter.related_urls = ["https://example.com", "https://github.com/test"] + chapter.invalid_urls = [] + chapter.is_active = True + + repo = MagicMock() + repo.description = "Repository for chapter resources" + repo.topics = ["owasp", "security"] + chapter.owasp_repository = repo + + prose, metadata = extract_chapter_content(chapter) + + assert "Description: Test chapter description" in prose + assert "Summary: Test chapter summary" in prose + assert "Repository Description: Repository for chapter resources" in prose + + assert "Chapter Name: Test Chapter" in metadata + assert ( + "Location Information: Country: USA, Region: North America, " + "Postal Code: 12345, Location: New York, NY" in metadata + ) + assert "Currency: USD" in metadata + assert "Meetup Group: owasp-nyc" in metadata + assert "Tags: security, web" in metadata + assert "Topics: OWASP Top 10, Security Testing" in metadata + assert "Chapter Leaders: John Doe, Jane Smith" in metadata + assert "Related URLs: https://example.com, https://github.com/test" in metadata + assert "Active Chapter: Yes" in metadata + assert "Repository Topics: owasp, security" in metadata + + def test_extract_chapter_content_minimal_data(self): + """Test extraction with minimal chapter data.""" + chapter = MagicMock() + chapter.description = None + chapter.summary = None + chapter.name = "Minimal Chapter" + chapter.country = None + chapter.region = None + chapter.postal_code = None + chapter.suggested_location = None + chapter.currency = None + chapter.meetup_group = None + chapter.tags = [] + chapter.topics = [] + chapter.leaders_raw = [] + chapter.related_urls = [] + chapter.invalid_urls = [] + chapter.is_active = False + chapter.owasp_repository = None + + prose, metadata = extract_chapter_content(chapter) + + assert prose == "" + assert "Chapter Name: Minimal Chapter" in metadata + assert "Description:" not in prose + assert "Currency:" not in metadata + assert "Active Chapter: Yes" not in metadata + + def test_extract_chapter_content_empty_fields(self): + """Test extraction with empty string fields.""" + chapter = MagicMock() + chapter.description = "" + chapter.summary = "" + chapter.name = "" + chapter.country = "" + chapter.region = "" + chapter.postal_code = "" + chapter.suggested_location = "" + chapter.currency = "" + chapter.meetup_group = "" + chapter.tags = [] + chapter.topics = [] + chapter.leaders_raw = [] + chapter.related_urls = [] + chapter.invalid_urls = [] + chapter.is_active = False + chapter.owasp_repository = None + + prose, metadata = extract_chapter_content(chapter) + + assert prose == "" + assert metadata == "" + + def test_extract_chapter_content_partial_location(self): + """Test extraction with partial location information.""" + chapter = MagicMock() + chapter.description = None + chapter.summary = None + chapter.name = "Test Chapter" + chapter.country = "Canada" + chapter.region = None + chapter.postal_code = "K1A 0A6" + chapter.suggested_location = None + chapter.currency = None + chapter.meetup_group = None + chapter.tags = [] + chapter.topics = [] + chapter.leaders_raw = [] + chapter.related_urls = [] + chapter.invalid_urls = [] + chapter.is_active = True + chapter.owasp_repository = None + + prose, metadata = extract_chapter_content(chapter) + + assert prose == "" + assert "Chapter Name: Test Chapter" in metadata + assert "Location Information: Country: Canada, Postal Code: K1A 0A6" in metadata + assert "Active Chapter: Yes" in metadata + + def test_extract_chapter_content_with_invalid_urls(self): + """Test extraction with invalid URLs filtered out.""" + chapter = MagicMock() + chapter.description = None + chapter.summary = None + chapter.name = "Test Chapter" + chapter.country = None + chapter.region = None + chapter.postal_code = None + chapter.suggested_location = None + chapter.currency = None + chapter.meetup_group = None + chapter.tags = [] + chapter.topics = [] + chapter.leaders_raw = [] + chapter.related_urls = [ + "https://valid.com", + "https://invalid.com", + "https://another-valid.com", + ] + chapter.invalid_urls = ["https://invalid.com"] + chapter.is_active = False + chapter.owasp_repository = None + + prose, metadata = extract_chapter_content(chapter) + + assert "Related URLs: https://valid.com, https://another-valid.com" in metadata + assert "https://invalid.com" not in metadata + + def test_extract_chapter_content_repository_no_description(self): + """Test extraction when repository has no description.""" + chapter = MagicMock() + chapter.description = "Chapter description" + chapter.summary = None + chapter.name = "Test Chapter" + chapter.country = None + chapter.region = None + chapter.postal_code = None + chapter.suggested_location = None + chapter.currency = None + chapter.meetup_group = None + chapter.tags = [] + chapter.topics = [] + chapter.leaders_raw = [] + chapter.related_urls = [] + chapter.invalid_urls = [] + chapter.is_active = False + + repo = MagicMock() + repo.description = None + repo.topics = ["security"] + chapter.owasp_repository = repo + + prose, metadata = extract_chapter_content(chapter) + + assert "Description: Chapter description" in prose + assert "Repository Description:" not in prose + assert "Repository Topics: security" in metadata + + def test_extract_chapter_content_repository_empty_topics(self): + """Test extraction when repository has empty topics.""" + chapter = MagicMock() + chapter.description = "Chapter description" + chapter.summary = None + chapter.name = "Test Chapter" + chapter.country = None + chapter.region = None + chapter.postal_code = None + chapter.suggested_location = None + chapter.currency = None + chapter.meetup_group = None + chapter.tags = [] + chapter.topics = [] + chapter.leaders_raw = [] + chapter.related_urls = [] + chapter.invalid_urls = [] + chapter.is_active = False + + repo = MagicMock() + repo.description = "Repository description" + repo.topics = [] + chapter.owasp_repository = repo + + prose, metadata = extract_chapter_content(chapter) + + assert "Description: Chapter description" in prose + assert "Repository Description: Repository description" in prose + assert "Repository Topics:" not in metadata + + def test_extract_chapter_content_none_invalid_urls(self): + """Test extraction when invalid_urls is None.""" + chapter = MagicMock() + chapter.description = None + chapter.summary = None + chapter.name = "Test Chapter" + chapter.country = None + chapter.region = None + chapter.postal_code = None + chapter.suggested_location = None + chapter.currency = None + chapter.meetup_group = None + chapter.tags = [] + chapter.topics = [] + chapter.leaders_raw = [] + chapter.related_urls = ["https://valid.com"] + chapter.invalid_urls = None + chapter.is_active = False + chapter.owasp_repository = None + + prose, metadata = extract_chapter_content(chapter) + + assert "Related URLs: https://valid.com" in metadata + + def test_extract_chapter_content_empty_related_urls_after_filter(self): + """Test extraction when all related URLs are invalid.""" + chapter = MagicMock() + chapter.description = None + chapter.summary = None + chapter.name = "Test Chapter" + chapter.country = None + chapter.region = None + chapter.postal_code = None + chapter.suggested_location = None + chapter.currency = None + chapter.meetup_group = None + chapter.tags = [] + chapter.topics = [] + chapter.leaders_raw = [] + chapter.related_urls = ["https://invalid1.com", "https://invalid2.com"] + chapter.invalid_urls = ["https://invalid1.com", "https://invalid2.com"] + chapter.is_active = False + chapter.owasp_repository = None + + prose, metadata = extract_chapter_content(chapter) + + assert "Related URLs:" not in metadata + + def test_extract_chapter_content_with_none_and_empty_urls(self): + """Test extraction with mix of None and empty URLs.""" + chapter = MagicMock() + chapter.description = None + chapter.summary = None + chapter.name = "Test Chapter" + chapter.country = None + chapter.region = None + chapter.postal_code = None + chapter.suggested_location = None + chapter.currency = None + chapter.meetup_group = None + chapter.tags = [] + chapter.topics = [] + chapter.leaders_raw = [] + chapter.related_urls = [ + "https://valid.com", + None, + "", + "https://another-valid.com", + ] + chapter.invalid_urls = [] + chapter.is_active = False + chapter.owasp_repository = None + + prose, metadata = extract_chapter_content(chapter) + + assert "Related URLs: https://valid.com, https://another-valid.com" in metadata diff --git a/backend/tests/apps/ai/common/extractors/committee_test.py b/backend/tests/apps/ai/common/extractors/committee_test.py new file mode 100644 index 0000000000..594cfedaf2 --- /dev/null +++ b/backend/tests/apps/ai/common/extractors/committee_test.py @@ -0,0 +1,202 @@ +"""Tests for committee content extractor.""" + +from unittest.mock import MagicMock + +from apps.ai.common.extractors.committee import extract_committee_content + + +class TestCommitteeExtractor: + """Test cases for committee content extraction.""" + + def test_extract_committee_content_full_data(self): + """Test extraction with complete committee data.""" + committee = MagicMock() + committee.description = "Test committee description" + committee.summary = "Test committee summary" + committee.name = "Test Committee" + committee.tags = ["governance", "policy"] + committee.topics = ["Security Standards", "Best Practices"] + committee.leaders_raw = ["Alice Johnson", "Bob Wilson"] + committee.related_urls = ["https://committee.example.com"] + committee.invalid_urls = [] + committee.is_active = True + + repo = MagicMock() + repo.description = "Repository for committee resources" + repo.topics = ["governance", "standards"] + committee.owasp_repository = repo + + prose, metadata = extract_committee_content(committee) + + assert "Description: Test committee description" in prose + assert "Summary: Test committee summary" in prose + assert "Repository Description: Repository for committee resources" in prose + + assert "Committee Name: Test Committee" in metadata + assert "Tags: governance, policy" in metadata + assert "Topics: Security Standards, Best Practices" in metadata + assert "Committee Leaders: Alice Johnson, Bob Wilson" in metadata + assert "Related URLs: https://committee.example.com" in metadata + assert "Active Committee: Yes" in metadata + assert "Repository Topics: governance, standards" in metadata + + def test_extract_committee_content_minimal_data(self): + """Test extraction with minimal committee data.""" + committee = MagicMock() + committee.description = None + committee.summary = None + committee.name = "Minimal Committee" + committee.tags = [] + committee.topics = [] + committee.leaders_raw = [] + committee.related_urls = [] + committee.invalid_urls = [] + committee.is_active = False + committee.owasp_repository = None + + prose, metadata = extract_committee_content(committee) + + assert prose == "" + assert "Committee Name: Minimal Committee" in metadata + assert "Active Committee: No" in metadata + + def test_extract_committee_content_inactive_committee(self): + """Test extraction with inactive committee.""" + committee = MagicMock() + committee.description = "Inactive committee" + committee.summary = None + committee.name = "Inactive Committee" + committee.tags = [] + committee.topics = [] + committee.leaders_raw = [] + committee.related_urls = [] + committee.invalid_urls = [] + committee.is_active = False + committee.owasp_repository = None + + prose, metadata = extract_committee_content(committee) + + assert "Description: Inactive committee" in prose + assert "Active Committee: No" in metadata + + def test_extract_committee_content_with_invalid_urls(self): + """Test extraction with invalid URLs filtered out.""" + committee = MagicMock() + committee.description = None + committee.summary = None + committee.name = "Test Committee" + committee.tags = [] + committee.topics = [] + committee.leaders_raw = [] + committee.related_urls = ["https://valid.com", "https://invalid.com"] + committee.invalid_urls = ["https://invalid.com"] + committee.is_active = True + committee.owasp_repository = None + + prose, metadata = extract_committee_content(committee) + + assert "Related URLs: https://valid.com" in metadata + assert "https://invalid.com" not in metadata + + def test_extract_committee_content_no_invalid_urls_attr(self): + """Test extraction when invalid_urls attribute doesn't exist.""" + committee = MagicMock() + committee.description = None + committee.summary = None + committee.name = "Test Committee" + committee.tags = [] + committee.topics = [] + committee.leaders_raw = [] + committee.related_urls = ["https://valid.com"] + committee.is_active = True + committee.owasp_repository = None + del committee.invalid_urls + + prose, metadata = extract_committee_content(committee) + + assert "Related URLs: https://valid.com" in metadata + + def test_extract_committee_content_empty_strings(self): + """Test extraction with empty string fields.""" + committee = MagicMock() + committee.description = "" + committee.summary = "" + committee.name = "" + committee.tags = [] + committee.topics = [] + committee.leaders_raw = [] + committee.related_urls = [] + committee.invalid_urls = [] + committee.is_active = True + committee.owasp_repository = None + + prose, metadata = extract_committee_content(committee) + + assert prose == "" + assert "Active Committee: Yes" in metadata + assert "Committee Name:" not in metadata + + def test_extract_committee_content_repository_no_description(self): + """Test extraction when repository has no description.""" + committee = MagicMock() + committee.description = "Committee description" + committee.summary = None + committee.name = "Test Committee" + committee.tags = [] + committee.topics = [] + committee.leaders_raw = [] + committee.related_urls = [] + committee.invalid_urls = [] + committee.is_active = True + + repo = MagicMock() + repo.description = None + repo.topics = ["topic1"] + committee.owasp_repository = repo + + prose, metadata = extract_committee_content(committee) + + assert "Description: Committee description" in prose + assert "Repository Description:" not in prose + assert "Repository Topics: topic1" in metadata + + def test_extract_committee_content_repository_empty_topics(self): + """Test extraction when repository has empty topics.""" + committee = MagicMock() + committee.description = None + committee.summary = None + committee.name = "Test Committee" + committee.tags = [] + committee.topics = [] + committee.leaders_raw = [] + committee.related_urls = [] + committee.invalid_urls = [] + committee.is_active = True + + repo = MagicMock() + repo.description = "Repo description" + repo.topics = [] + committee.owasp_repository = repo + + prose, metadata = extract_committee_content(committee) + + assert "Repository Description: Repo description" in prose + assert "Repository Topics:" not in metadata + + def test_extract_committee_content_all_empty_after_filter(self): + """Test extraction when all URLs are filtered out.""" + committee = MagicMock() + committee.description = None + committee.summary = None + committee.name = "Test Committee" + committee.tags = [] + committee.topics = [] + committee.leaders_raw = [] + committee.related_urls = ["https://invalid1.com", "https://invalid2.com"] + committee.invalid_urls = ["https://invalid1.com", "https://invalid2.com"] + committee.is_active = True + committee.owasp_repository = None + + prose, metadata = extract_committee_content(committee) + + assert "Related URLs:" not in metadata diff --git a/backend/tests/apps/ai/common/extractors/event_test.py b/backend/tests/apps/ai/common/extractors/event_test.py new file mode 100644 index 0000000000..354c1967ae --- /dev/null +++ b/backend/tests/apps/ai/common/extractors/event_test.py @@ -0,0 +1,208 @@ +"""Tests for event content extractor.""" + +from datetime import date +from unittest.mock import MagicMock + +from apps.ai.common.extractors.event import extract_event_content + + +class TestEventExtractor: + """Test cases for event content extraction.""" + + def test_extract_event_content_full_data(self): + """Test extraction with complete event data.""" + event = MagicMock() + event.description = "Test event description" + event.summary = "Test event summary" + event.name = "Test Event" + event.category = "conference" + event.get_category_display.return_value = "Conference" + event.start_date = date(2024, 6, 15) + event.end_date = date(2024, 6, 17) + event.suggested_location = "San Francisco, CA" + event.latitude = 37.7749 + event.longitude = -122.4194 + event.url = "https://event.example.com" + + prose, metadata = extract_event_content(event) + + assert "Description: Test event description" in prose + assert "Summary: Test event summary" in prose + + assert "Event Name: Test Event" in metadata + assert "Category: Conference" in metadata + assert "Start Date: 2024-06-15" in metadata + assert "End Date: 2024-06-17" in metadata + assert "Location: San Francisco, CA" in metadata + assert "Coordinates: 37.7749, -122.4194" in metadata + assert "Event URL: https://event.example.com" in metadata + + def test_extract_event_content_minimal_data(self): + """Test extraction with minimal event data.""" + event = MagicMock() + event.description = None + event.summary = None + event.name = "Minimal Event" + event.category = None + event.start_date = None + event.end_date = None + event.suggested_location = None + event.latitude = None + event.longitude = None + event.url = None + + prose, metadata = extract_event_content(event) + + assert prose == "" + assert "Event Name: Minimal Event" in metadata + assert "Category:" not in metadata + assert "Start Date:" not in metadata + assert "End Date:" not in metadata + assert "Location:" not in metadata + assert "Coordinates:" not in metadata + assert "Event URL:" not in metadata + + def test_extract_event_content_empty_strings(self): + """Test extraction with empty string fields.""" + event = MagicMock() + event.description = "" + event.summary = "" + event.name = "" + event.category = "" + event.get_category_display.return_value = "" + event.start_date = None + event.end_date = None + event.suggested_location = "" + event.latitude = None + event.longitude = None + event.url = "" + + prose, metadata = extract_event_content(event) + + assert prose == "" + assert metadata == "" + + def test_extract_event_content_only_latitude(self): + """Test extraction with only latitude (no coordinates).""" + event = MagicMock() + event.description = None + event.summary = None + event.name = "Test Event" + event.category = None + event.start_date = None + event.end_date = None + event.suggested_location = None + event.latitude = 37.7749 + event.longitude = None + event.url = None + + prose, metadata = extract_event_content(event) + + assert prose == "" + assert "Event Name: Test Event" in metadata + assert "Coordinates:" not in metadata + + def test_extract_event_content_only_longitude(self): + """Test extraction with only longitude (no coordinates).""" + event = MagicMock() + event.description = None + event.summary = None + event.name = "Test Event" + event.category = None + event.start_date = None + event.end_date = None + event.suggested_location = None + event.latitude = None + event.longitude = -122.4194 + event.url = None + + prose, metadata = extract_event_content(event) + + assert prose == "" + assert "Event Name: Test Event" in metadata + assert "Coordinates:" not in metadata + + def test_extract_event_content_zero_coordinates(self): + """Test extraction with zero coordinates (should be included).""" + event = MagicMock() + event.description = None + event.summary = None + event.name = "Test Event" + event.category = None + event.start_date = None + event.end_date = None + event.suggested_location = None + event.latitude = 0.0 + event.longitude = 0.0 + event.url = None + + prose, metadata = extract_event_content(event) + + assert "Event Name: Test Event" in metadata + assert "Coordinates: 0.0, 0.0" in metadata + + def test_extract_event_content_partial_dates(self): + """Test extraction with only start date.""" + event = MagicMock() + event.description = "Event with start date only" + event.summary = None + event.name = "Partial Event" + event.category = "workshop" + event.get_category_display.return_value = "Workshop" + event.start_date = date(2024, 8, 1) + event.end_date = None + event.suggested_location = "Online" + event.latitude = None + event.longitude = None + event.url = "https://online-event.com" + + prose, metadata = extract_event_content(event) + + assert "Description: Event with start date only" in prose + assert "Event Name: Partial Event" in metadata + assert "Category: Workshop" in metadata + assert "Start Date: 2024-08-01" in metadata + assert "End Date:" not in metadata + assert "Location: Online" in metadata + assert "Event URL: https://online-event.com" in metadata + + def test_extract_event_content_only_end_date(self): + """Test extraction with only end date.""" + event = MagicMock() + event.description = None + event.summary = None + event.name = "End Date Event" + event.category = None + event.start_date = None + event.end_date = date(2024, 12, 25) + event.suggested_location = None + event.latitude = None + event.longitude = None + event.url = None + + prose, metadata = extract_event_content(event) + + assert prose == "" + assert "Event Name: End Date Event" in metadata + assert "Start Date:" not in metadata + assert "End Date: 2024-12-25" in metadata + + def test_extract_event_content_category_display_method(self): + """Test that get_category_display method is called properly.""" + event = MagicMock() + event.description = None + event.summary = None + event.name = "Category Test Event" + event.category = "meetup" + event.get_category_display.return_value = "Meetup" + event.start_date = None + event.end_date = None + event.suggested_location = None + event.latitude = None + event.longitude = None + event.url = None + + prose, metadata = extract_event_content(event) + + event.get_category_display.assert_called_once() + assert "Category: Meetup" in metadata diff --git a/backend/tests/apps/ai/common/extractors/project_test.py b/backend/tests/apps/ai/common/extractors/project_test.py new file mode 100644 index 0000000000..2eb2bd4b61 --- /dev/null +++ b/backend/tests/apps/ai/common/extractors/project_test.py @@ -0,0 +1,441 @@ +"""Tests for project content extractor.""" + +from datetime import UTC, datetime +from unittest.mock import MagicMock + +from apps.ai.common.extractors.project import extract_project_content + + +class TestProjectExtractor: + """Test cases for project content extraction.""" + + def test_extract_project_content_full_data(self): + """Test extraction with complete project data.""" + project = MagicMock() + project.description = "Test project description" + project.summary = "Test project summary" + project.name = "Test Project" + project.level = "flagship" + project.type = "tool" + project.languages = ["Python", "JavaScript"] + project.topics = ["security", "web-application-security"] + project.licenses = ["MIT", "Apache-2.0"] + project.tags = ["security", "testing"] + project.custom_tags = ["owasp-top-10"] + project.stars_count = 1500 + project.forks_count = 300 + project.contributors_count = 45 + project.releases_count = 12 + project.open_issues_count = 8 + project.leaders_raw = ["John Doe", "Jane Smith"] + project.related_urls = ["https://project.example.com"] + project.invalid_urls = [] + project.created_at = datetime(2020, 1, 15, tzinfo=UTC) + project.updated_at = datetime(2024, 6, 10, tzinfo=UTC) + project.released_at = datetime(2024, 5, 20, tzinfo=UTC) + project.health_score = 85.75 + project.is_active = True + + repo = MagicMock() + repo.description = "Repository for project resources" + repo.topics = ["security", "python"] + project.owasp_repository = repo + + prose, metadata = extract_project_content(project) + + assert "Description: Test project description" in prose + assert "Summary: Test project summary" in prose + assert "Repository Description: Repository for project resources" in prose + + assert "Project Name: Test Project" in metadata + assert "Project Level: flagship" in metadata + assert "Project Type: tool" in metadata + assert "Programming Languages: Python, JavaScript" in metadata + assert "Topics: security, web-application-security" in metadata + assert "Licenses: MIT, Apache-2.0" in metadata + assert "Tags: security, testing" in metadata + assert "Custom Tags: owasp-top-10" in metadata + assert ( + "Project Statistics: Stars: 1500, Forks: 300, Contributors: 45, " + "Releases: 12, Open Issues: 8" in metadata + ) + assert "Project Leaders: John Doe, Jane Smith" in metadata + assert "Related URLs: https://project.example.com" in metadata + assert "Created: 2020-01-15" in metadata + assert "Last Updated: 2024-06-10" in metadata + assert "Last Release: 2024-05-20" in metadata + assert "Health Score: 85.75" in metadata + assert "Active Project: Yes" in metadata + assert "Repository Topics: security, python" in metadata + + def test_extract_project_content_minimal_data(self): + """Test extraction with minimal project data.""" + project = MagicMock() + project.description = None + project.summary = None + project.name = "Minimal Project" + project.level = None + project.type = None + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = None + project.forks_count = None + project.contributors_count = None + project.releases_count = None + project.open_issues_count = None + project.leaders_raw = [] + project.related_urls = [] + project.invalid_urls = [] + project.created_at = None + project.updated_at = None + project.released_at = None + project.health_score = None + project.is_active = False + project.owasp_repository = None + + prose, metadata = extract_project_content(project) + + assert prose == "" + assert "Project Name: Minimal Project" in metadata + assert "Active Project: No" in metadata + + def test_extract_project_content_partial_statistics(self): + """Test extraction with partial statistics.""" + project = MagicMock() + project.description = None + project.summary = None + project.name = "Partial Stats Project" + project.level = None + project.type = None + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = 100 + project.forks_count = None + project.contributors_count = 5 + project.releases_count = None + project.open_issues_count = 3 + project.leaders_raw = [] + project.related_urls = [] + project.invalid_urls = [] + project.created_at = None + project.updated_at = None + project.released_at = None + project.health_score = None + project.is_active = True + project.owasp_repository = None + + prose, metadata = extract_project_content(project) + + assert "Project Statistics: Stars: 100, Contributors: 5, Open Issues: 3" in metadata + + def test_extract_project_content_zero_statistics(self): + """Test extraction with zero values in statistics.""" + project = MagicMock() + project.description = None + project.summary = None + project.name = "Zero Stats Project" + project.level = None + project.type = None + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = 0 + project.forks_count = 0 + project.contributors_count = 0 + project.releases_count = 0 + project.open_issues_count = 0 + project.leaders_raw = [] + project.related_urls = [] + project.invalid_urls = [] + project.created_at = None + project.updated_at = None + project.released_at = None + project.health_score = None + project.is_active = True + project.owasp_repository = None + + prose, metadata = extract_project_content(project) + + assert "Project Statistics:" not in metadata + + def test_extract_project_content_with_invalid_urls(self): + """Test extraction with invalid URLs filtered out.""" + project = MagicMock() + project.description = None + project.summary = None + project.name = "URL Test Project" + project.level = None + project.type = None + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = None + project.forks_count = None + project.contributors_count = None + project.releases_count = None + project.open_issues_count = None + project.leaders_raw = [] + project.related_urls = ["https://valid.com", "https://invalid.com"] + project.invalid_urls = ["https://invalid.com"] + project.created_at = None + project.updated_at = None + project.released_at = None + project.health_score = None + project.is_active = True + project.owasp_repository = None + + prose, metadata = extract_project_content(project) + + assert "Related URLs: https://valid.com" in metadata + assert "https://invalid.com" not in metadata + + def test_extract_project_content_dates_only(self): + """Test extraction with only date fields.""" + project = MagicMock() + project.description = None + project.summary = None + project.name = "Date Project" + project.level = None + project.type = None + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = None + project.forks_count = None + project.contributors_count = None + project.releases_count = None + project.open_issues_count = None + project.leaders_raw = [] + project.related_urls = [] + project.invalid_urls = [] + project.created_at = datetime(2021, 3, 1, tzinfo=UTC) + project.updated_at = None + project.released_at = datetime(2023, 8, 15, tzinfo=UTC) + project.health_score = None + project.is_active = True + project.owasp_repository = None + + prose, metadata = extract_project_content(project) + + assert "Created: 2021-03-01" in metadata + assert "Last Updated:" not in metadata + assert "Last Release: 2023-08-15" in metadata + + def test_extract_project_content_health_score_zero(self): + """Test extraction with zero health score.""" + project = MagicMock() + project.description = None + project.summary = None + project.name = "Zero Health Project" + project.level = None + project.type = None + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = None + project.forks_count = None + project.contributors_count = None + project.releases_count = None + project.open_issues_count = None + project.leaders_raw = [] + project.related_urls = [] + project.invalid_urls = [] + project.created_at = None + project.updated_at = None + project.released_at = None + project.health_score = 0.0 + project.is_active = True + project.owasp_repository = None + + prose, metadata = extract_project_content(project) + + assert "Health Score: 0.00" in metadata + + def test_extract_project_content_repository_no_description(self): + """Test extraction when repository has no description.""" + project = MagicMock() + project.description = "Project description" + project.summary = None + project.name = "Test Project" + project.level = None + project.type = None + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = None + project.forks_count = None + project.contributors_count = None + project.releases_count = None + project.open_issues_count = None + project.leaders_raw = [] + project.related_urls = [] + project.invalid_urls = [] + project.created_at = None + project.updated_at = None + project.released_at = None + project.health_score = None + project.is_active = True + + repo = MagicMock() + repo.description = None + repo.topics = ["security"] + project.owasp_repository = repo + + prose, metadata = extract_project_content(project) + + assert "Description: Project description" in prose + assert "Repository Description:" not in prose + assert "Repository Topics: security" in metadata + + def test_extract_project_content_no_invalid_urls_attr(self): + """Test extraction when invalid_urls attribute doesn't exist.""" + project = MagicMock() + project.description = None + project.summary = None + project.name = "Test Project" + project.level = None + project.type = None + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = None + project.forks_count = None + project.contributors_count = None + project.releases_count = None + project.open_issues_count = None + project.leaders_raw = [] + project.related_urls = ["https://valid.com"] + project.created_at = None + project.updated_at = None + project.released_at = None + project.health_score = None + project.is_active = True + project.owasp_repository = None + del project.invalid_urls + + prose, metadata = extract_project_content(project) + + assert "Related URLs: https://valid.com" in metadata + + def test_extract_project_content_empty_strings(self): + """Test extraction with empty string fields.""" + project = MagicMock() + project.description = "" + project.summary = "" + project.name = "" + project.level = "" + project.type = "" + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = None + project.forks_count = None + project.contributors_count = None + project.releases_count = None + project.open_issues_count = None + project.leaders_raw = [] + project.related_urls = [] + project.invalid_urls = [] + project.created_at = None + project.updated_at = None + project.released_at = None + project.health_score = None + project.is_active = True + project.owasp_repository = None + + prose, metadata = extract_project_content(project) + + assert prose == "" + assert "Active Project: Yes" in metadata + assert "Project Name:" not in metadata + assert "Project Level:" not in metadata + assert "Project Type:" not in metadata + + def test_extract_project_content_repository_with_topics_only(self): + """Test extraction when repository has topics but no description.""" + project = MagicMock() + project.description = None + project.summary = None + project.name = "Test Project" + project.level = None + project.type = None + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = None + project.forks_count = None + project.contributors_count = None + project.releases_count = None + project.open_issues_count = None + project.leaders_raw = [] + project.related_urls = [] + project.invalid_urls = [] + project.created_at = None + project.updated_at = None + project.released_at = None + project.health_score = None + project.is_active = True + + repo = MagicMock() + repo.description = None + repo.topics = ["security", "python"] + project.owasp_repository = repo + + prose, metadata = extract_project_content(project) + + assert "Repository Description:" not in prose + assert "Repository Topics: security, python" in metadata + + def test_extract_project_content_with_empty_related_urls(self): + """Test extraction with related_urls containing empty strings.""" + project = MagicMock() + project.description = None + project.summary = None + project.name = "Test Project" + project.level = None + project.type = None + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = None + project.forks_count = None + project.contributors_count = None + project.releases_count = None + project.open_issues_count = None + project.leaders_raw = [] + project.related_urls = ["https://valid.com", "", "https://another.com"] + project.invalid_urls = [] + project.created_at = None + project.updated_at = None + project.released_at = None + project.health_score = None + project.is_active = True + project.owasp_repository = None + + prose, metadata = extract_project_content(project) + + assert "Related URLs: https://valid.com, https://another.com" in metadata diff --git a/backend/tests/apps/ai/common/utils_test.py b/backend/tests/apps/ai/common/utils_test.py index c56068a375..90879b2592 100644 --- a/backend/tests/apps/ai/common/utils_test.py +++ b/backend/tests/apps/ai/common/utils_test.py @@ -1,7 +1,8 @@ from datetime import UTC, datetime, timedelta -from unittest.mock import MagicMock, call, patch +from unittest.mock import MagicMock, Mock, call, patch import openai +import pytest from apps.ai.common.utils import create_chunks_and_embeddings @@ -33,7 +34,6 @@ def test_create_chunks_and_embeddings_success( ] mock_openai_client.embeddings.create.return_value = mock_api_response - # Create mock chunk instances with .save method mock_chunk1 = MagicMock() mock_chunk2 = MagicMock() mock_update_data.side_effect = [mock_chunk1, mock_chunk2] @@ -54,17 +54,21 @@ def test_create_chunks_and_embeddings_success( mock_update_data.assert_has_calls( [ - call(text="first chunk", embedding=[0.1, 0.2], save=False), - call(text="second chunk", embedding=[0.3, 0.4], save=False), + call( + text="first chunk", + embedding=[0.1, 0.2], + context=mock_content_object, + save=True, + ), + call( + text="second chunk", + embedding=[0.3, 0.4], + context=mock_content_object, + save=True, + ), ] ) - assert mock_chunk1.context == mock_content_object - assert mock_chunk2.context == mock_content_object - - mock_chunk1.save.assert_called_once() - mock_chunk2.save.assert_called_once() - assert result == [mock_chunk1, mock_chunk2] mock_sleep.assert_not_called() @@ -88,6 +92,77 @@ def test_create_chunks_and_embeddings_api_error(self, mock_logger): assert result == [] + def test_create_chunks_and_embeddings_none_context(self): + """Tests the failure path when context is None.""" + mock_openai_client = MagicMock() + + mock_response = MagicMock() + mock_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])] + mock_openai_client.embeddings.create.return_value = mock_response + + with patch("apps.ai.common.utils.Chunk.update_data") as mock_update_data: + mock_chunk = Mock() + mock_update_data.return_value = mock_chunk + + result = create_chunks_and_embeddings( + chunk_texts=["some text"], + context=None, + openai_client=mock_openai_client, + ) + + assert len(result) == 1 + assert result[0] == mock_chunk + + mock_update_data.assert_called_once_with( + text="some text", embedding=[0.1, 0.2, 0.3], context=None, save=True + ) + + @patch("apps.ai.common.utils.logger") + def test_create_chunks_and_embeddings_context_error(self, mock_logger): + """Tests the failure path when there's an embedding mismatch error.""" + mock_openai_client = MagicMock() + mock_context = MagicMock() + + mock_response = MagicMock() + mock_response.data = [] + mock_openai_client.embeddings.create.return_value = mock_response + + with pytest.raises(ValueError, match="zip\\(\\) argument 2 is shorter than argument 1"): + create_chunks_and_embeddings( + chunk_texts=["some text"], + context=mock_context, + openai_client=mock_openai_client, + ) + + mock_logger.exception.assert_called_once_with("Context error") + + @patch("apps.ai.common.utils.time.sleep") + @patch("apps.ai.common.utils.datetime") + def test_create_chunks_and_embeddings_sleep_called(self, mock_datetime, mock_sleep): + """Tests that sleep is called when needed.""" + base_time = datetime.now(UTC) + mock_datetime.now.return_value = base_time + mock_datetime.UTC = UTC + mock_datetime.timedelta = timedelta + + mock_openai_client = MagicMock() + mock_api_response = MagicMock() + mock_api_response.data = [MockEmbeddingData([0.1, 0.2])] + mock_openai_client.embeddings.create.return_value = mock_api_response + + with patch("apps.ai.common.utils.Chunk.update_data") as mock_update_data: + mock_chunk = MagicMock() + mock_update_data.return_value = mock_chunk + + result = create_chunks_and_embeddings( + ["test chunk"], + MagicMock(), + mock_openai_client, + ) + + mock_sleep.assert_not_called() + assert result == [mock_chunk] + @patch("apps.ai.common.utils.Context") @patch("apps.ai.common.utils.Chunk.update_data") @patch("apps.ai.common.utils.time.sleep") @@ -108,13 +183,40 @@ def test_create_chunks_and_embeddings_no_sleep_with_current_settings( mock_chunk = MagicMock() mock_update_data.return_value = mock_chunk + mock_context_obj = MagicMock() result = create_chunks_and_embeddings( ["test chunk"], - MagicMock(), + mock_context_obj, mock_openai_client, ) mock_sleep.assert_not_called() - mock_chunk.save.assert_called_once() + mock_update_data.assert_called_once_with( + text="test chunk", embedding=[0.1, 0.2], context=mock_context_obj, save=True + ) assert result == [mock_chunk] + + @patch("apps.ai.common.utils.logger") + @patch("apps.ai.common.utils.Chunk.update_data") + def test_create_chunks_and_embeddings_chunk_update_value_error( + self, mock_update_data, mock_logger + ): + """Tests the failure path when Chunk.update_data raises ValueError.""" + mock_openai_client = MagicMock() + mock_context = MagicMock() + + mock_response = MagicMock() + mock_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])] + mock_openai_client.embeddings.create.return_value = mock_response + + mock_update_data.side_effect = ValueError("Invalid context") + + with pytest.raises(ValueError, match="Invalid context"): + create_chunks_and_embeddings( + chunk_texts=["some text"], + context=mock_context, + openai_client=mock_openai_client, + ) + + mock_logger.exception.assert_called_once_with("Context error") diff --git a/backend/tests/apps/ai/management/commands/ai_create_chapter_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_create_chapter_chunks_test.py index e501580f9b..5fd1c37077 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_chapter_chunks_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_chapter_chunks_test.py @@ -26,16 +26,16 @@ def test_command_inheritance(self, command): def test_model_class_property(self, command): from apps.owasp.models.chapter import Chapter - assert command.model_class == Chapter + assert command.model_class() == Chapter def test_entity_name_property(self, command): - assert command.entity_name == "chapter" + assert command.entity_name() == "chapter" def test_entity_name_plural_property(self, command): - assert command.entity_name_plural == "chapters" + assert command.entity_name_plural() == "chapters" def test_key_field_name_property(self, command): - assert command.key_field_name == "key" + assert command.key_field_name() == "key" def test_extract_content(self, command, mock_chapter): with patch( diff --git a/backend/tests/apps/ai/management/commands/ai_create_chapter_context_test.py b/backend/tests/apps/ai/management/commands/ai_create_chapter_context_test.py index c140fc6184..f2dafc4064 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_chapter_context_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_chapter_context_test.py @@ -3,20 +3,17 @@ from unittest.mock import Mock, patch import pytest -from django.core.management.base import BaseCommand from apps.ai.management.commands.ai_create_chapter_context import Command @pytest.fixture def command(): - """Return a command instance.""" return Command() @pytest.fixture def mock_chapter(): - """Return a mock Chapter instance.""" chapter = Mock() chapter.id = 1 chapter.key = "test-chapter" @@ -24,36 +21,36 @@ def mock_chapter(): class TestAiCreateChapterContextCommand: - """Test suite for the ai_create_chapter_context command.""" + def test_command_inheritance(self, command): + """Test that the command inherits from BaseContextCommand.""" + from apps.ai.common.base import BaseContextCommand + + assert isinstance(command, BaseContextCommand) def test_command_help_text(self, command): """Test that the command has the correct help text.""" - assert command.help == "Update context for OWASP chapter data" - - def test_command_inheritance(self, command): - """Test that the command inherits from BaseCommand.""" - assert isinstance(command, BaseCommand) + assert command.help() == "Update context for OWASP chapter data" def test_model_class_property(self, command): """Test the model_class property returns Chapter.""" from apps.owasp.models.chapter import Chapter - assert command.model_class == Chapter + assert command.model_class() == Chapter def test_entity_name_property(self, command): """Test the entity_name property.""" - assert command.entity_name == "chapter" + assert command.entity_name() == "chapter" def test_entity_name_plural_property(self, command): """Test the entity_name_plural property.""" - assert command.entity_name_plural == "chapters" + assert command.entity_name_plural() == "chapters" def test_key_field_name_property(self, command): """Test the key_field_name property.""" - assert command.key_field_name == "key" + assert command.key_field_name() == "key" def test_extract_content(self, command, mock_chapter): - """Test content extraction from chapter.""" + """Test the extract_content method.""" with patch( "apps.ai.management.commands.ai_create_chapter_context.extract_chapter_content" ) as mock_extract: diff --git a/backend/tests/apps/ai/management/commands/ai_create_committee_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_create_committee_chunks_test.py index 2ffb2cb098..c7380429fd 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_committee_chunks_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_committee_chunks_test.py @@ -1,3 +1,5 @@ +"""Tests for the ai_create_committee_chunks command.""" + from unittest.mock import Mock, patch import pytest @@ -8,36 +10,49 @@ @pytest.fixture def command(): + """Return a command instance.""" return Command() @pytest.fixture def mock_committee(): + """Return a mock Committee instance.""" committee = Mock() committee.id = 1 committee.key = "test-committee" + committee.name = "Test Committee" + committee.description = "Test committee description" + committee.is_active = True return committee class TestAiCreateCommitteeChunksCommand: + """Test suite for the ai_create_committee_chunks command.""" + def test_command_inheritance(self, command): + """Test that the command inherits from BaseCommand.""" assert isinstance(command, BaseCommand) - def test_model_class_property(self, command): + def test_model_class_method(self, command): + """Test the model_class method returns Committee.""" from apps.owasp.models.committee import Committee - assert command.model_class == Committee + assert command.model_class() == Committee - def test_entity_name_property(self, command): - assert command.entity_name == "committee" + def test_entity_name_method(self, command): + """Test the entity_name method.""" + assert command.entity_name() == "committee" - def test_entity_name_plural_property(self, command): - assert command.entity_name_plural == "committees" + def test_entity_name_plural_method(self, command): + """Test the entity_name_plural method.""" + assert command.entity_name_plural() == "committees" - def test_key_field_name_property(self, command): - assert command.key_field_name == "key" + def test_key_field_name_method(self, command): + """Test the key_field_name method.""" + assert command.key_field_name() == "key" - def test_extract_content(self, command, mock_committee): + def test_extract_content_method(self, command, mock_committee): + """Test the extract_content method.""" with patch( "apps.ai.management.commands.ai_create_committee_chunks.extract_committee_content" ) as mock_extract: @@ -45,3 +60,90 @@ def test_extract_content(self, command, mock_committee): content = command.extract_content(mock_committee) assert content == ("prose content", "metadata content") mock_extract.assert_called_once_with(mock_committee) + + def test_get_base_queryset_calls_super(self, command): + """Test that get_base_queryset calls the parent method.""" + with patch( + "apps.ai.common.base.chunk_command.BaseChunkCommand.get_base_queryset" + ) as mock_super: + mock_super.return_value = "base_queryset" + result = command.get_base_queryset() + assert result == "base_queryset" + mock_super.assert_called_once() + + def test_get_default_queryset_filters_active(self, command): + """Test that get_default_queryset filters for active committees.""" + with patch.object(command, "get_base_queryset") as mock_get_base: + mock_queryset = Mock() + mock_get_base.return_value = mock_queryset + mock_queryset.filter.return_value = "filtered_queryset" + + result = command.get_default_queryset() + + assert result == "filtered_queryset" + mock_queryset.filter.assert_called_once_with(is_active=True) + + def test_add_arguments_calls_super(self, command): + """Test that add_arguments calls the parent method.""" + mock_parser = Mock() + with patch.object(command, "add_common_arguments") as mock_add_common: + command.add_arguments(mock_parser) + mock_add_common.assert_called_once_with(mock_parser) + + def test_get_queryset_with_committee_key(self, command): + """Test get_queryset with committee key option.""" + with patch.object(command, "get_base_queryset") as mock_get_base: + mock_queryset = Mock() + mock_get_base.return_value = mock_queryset + mock_queryset.filter.return_value = "filtered_queryset" + + options = {"committee_key": "specific-committee"} + result = command.get_queryset(options) + + assert result == "filtered_queryset" + mock_queryset.filter.assert_called_once_with(key="specific-committee") + + def test_get_queryset_with_all_option(self, command): + """Test get_queryset with all option.""" + with patch.object(command, "get_base_queryset") as mock_get_base: + mock_queryset = Mock() + mock_get_base.return_value = mock_queryset + + options = {"all": True} + result = command.get_queryset(options) + + assert result == mock_queryset + + def test_get_queryset_default_behavior(self, command): + """Test get_queryset with default behavior.""" + with patch.object(command, "get_default_queryset") as mock_get_default: + mock_get_default.return_value = "default_queryset" + + options = {} + result = command.get_queryset(options) + + assert result == "default_queryset" + + def test_get_entity_key_returns_key(self, command, mock_committee): + """Test get_entity_key returns the committee key.""" + result = command.get_entity_key(mock_committee) + assert result == "test-committee" + + def test_get_entity_key_fallback_to_pk(self, command): + """Test get_entity_key falls back to pk when key field doesn't exist.""" + mock_committee = Mock() + mock_committee.pk = 123 + + if hasattr(mock_committee, "key"): + delattr(mock_committee, "key") + + result = command.get_entity_key(mock_committee) + assert result == "123" + + def test_source_name_default(self, command): + """Test default source name.""" + assert command.source_name() == "owasp_committee" + + def test_help_method(self, command): + """Test the help method.""" + assert command.help() == "Create chunks for OWASP committee data" diff --git a/backend/tests/apps/ai/management/commands/ai_create_committee_context_test.py b/backend/tests/apps/ai/management/commands/ai_create_committee_context_test.py index 30308d1734..b75fdfacbb 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_committee_context_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_committee_context_test.py @@ -1,4 +1,4 @@ -"""A command to update context for OWASP committee data.""" +"""Tests for the ai_create_committee_context command.""" from unittest.mock import Mock, patch @@ -19,42 +19,45 @@ def mock_committee(): committee = Mock() committee.id = 1 committee.key = "test-committee" + committee.name = "Test Committee" + committee.description = "Test committee description" + committee.is_active = True return committee class TestAiCreateCommitteeContextCommand: """Test suite for the ai_create_committee_context command.""" - def test_command_help_text(self, command): - """Test that the command has the correct help text.""" - assert command.help == "Update context for OWASP committee data" - def test_command_inheritance(self, command): """Test that the command inherits from BaseContextCommand.""" from apps.ai.common.base import BaseContextCommand assert isinstance(command, BaseContextCommand) - def test_model_class_property(self, command): - """Test the model_class property returns Committee.""" + def test_command_help_text(self, command): + """Test that the command has the correct help text.""" + assert command.help() == "Update context for OWASP committee data" + + def test_model_class_method(self, command): + """Test the model_class method returns Committee.""" from apps.owasp.models.committee import Committee - assert command.model_class == Committee + assert command.model_class() == Committee - def test_entity_name_property(self, command): - """Test the entity_name property.""" - assert command.entity_name == "committee" + def test_entity_name_method(self, command): + """Test the entity_name method.""" + assert command.entity_name() == "committee" - def test_entity_name_plural_property(self, command): - """Test the entity_name_plural property.""" - assert command.entity_name_plural == "committees" + def test_entity_name_plural_method(self, command): + """Test the entity_name_plural method.""" + assert command.entity_name_plural() == "committees" - def test_key_field_name_property(self, command): - """Test the key_field_name property.""" - assert command.key_field_name == "key" + def test_key_field_name_method(self, command): + """Test the key_field_name method.""" + assert command.key_field_name() == "key" - def test_extract_content(self, command, mock_committee): - """Test content extraction from committee.""" + def test_extract_content_method(self, command, mock_committee): + """Test the extract_content method.""" with patch( "apps.ai.management.commands.ai_create_committee_context.extract_committee_content" ) as mock_extract: @@ -62,3 +65,166 @@ def test_extract_content(self, command, mock_committee): content = command.extract_content(mock_committee) assert content == ("prose content", "metadata content") mock_extract.assert_called_once_with(mock_committee) + + def test_get_base_queryset_calls_super(self, command): + """Test that get_base_queryset calls the parent method.""" + with patch( + "apps.ai.common.base.context_command.BaseContextCommand.get_base_queryset" + ) as mock_super: + mock_super.return_value = "base_queryset" + result = command.get_base_queryset() + assert result == "base_queryset" + mock_super.assert_called_once() + + def test_get_default_queryset_filters_active(self, command): + """Test that get_default_queryset filters for active committees.""" + with patch.object(command, "get_base_queryset") as mock_get_base: + mock_queryset = Mock() + mock_get_base.return_value = mock_queryset + mock_queryset.filter.return_value = "filtered_queryset" + + result = command.get_default_queryset() + + assert result == "filtered_queryset" + mock_queryset.filter.assert_called_once_with(is_active=True) + + def test_add_arguments_calls_super(self, command): + """Test that add_arguments calls the parent method.""" + mock_parser = Mock() + with patch.object(command, "add_common_arguments") as mock_add_common: + command.add_arguments(mock_parser) + mock_add_common.assert_called_once_with(mock_parser) + + def test_get_queryset_with_committee_key(self, command): + """Test get_queryset with committee key option.""" + with patch.object(command, "get_base_queryset") as mock_get_base: + mock_queryset = Mock() + mock_get_base.return_value = mock_queryset + mock_queryset.filter.return_value = "filtered_queryset" + + options = {"committee_key": "specific-committee"} + result = command.get_queryset(options) + + assert result == "filtered_queryset" + mock_queryset.filter.assert_called_once_with(key="specific-committee") + + def test_get_queryset_with_all_option(self, command): + """Test get_queryset with all option.""" + with patch.object(command, "get_base_queryset") as mock_get_base: + mock_queryset = Mock() + mock_get_base.return_value = mock_queryset + + options = {"all": True} + result = command.get_queryset(options) + + assert result == mock_queryset + + def test_get_queryset_default_behavior(self, command): + """Test get_queryset with default behavior.""" + with patch.object(command, "get_default_queryset") as mock_get_default: + mock_get_default.return_value = "default_queryset" + + options = {} + result = command.get_queryset(options) + + assert result == "default_queryset" + + def test_get_entity_key_returns_key(self, command, mock_committee): + """Test get_entity_key returns the committee key.""" + result = command.get_entity_key(mock_committee) + assert result == "test-committee" + + def test_get_entity_key_fallback_to_pk(self, command): + """Test get_entity_key falls back to pk when key field doesn't exist.""" + mock_committee = Mock() + mock_committee.pk = 123 + + if hasattr(mock_committee, "key"): + delattr(mock_committee, "key") + + result = command.get_entity_key(mock_committee) + assert result == "123" + + def test_source_name_default(self, command): + """Test default source name.""" + assert command.source_name() == "owasp_committee" + + def test_process_context_batch_success(self, command, mock_committee): + """Test successful context batch processing.""" + with patch("apps.ai.common.base.context_command.Context") as mock_context_class: + mock_context_class.update_data.return_value = True + + with patch.object(command, "extract_content") as mock_extract: + mock_extract.return_value = ("Content", "Metadata") + + with patch.object(command, "get_entity_key") as mock_get_key: + mock_get_key.return_value = "test-committee" + + with patch.object(command.stdout, "write") as mock_write: + result = command.process_context_batch([mock_committee]) + + assert result == 1 + mock_context_class.update_data.assert_called_once_with( + content="Metadata\n\nContent", + content_object=mock_committee, + source="owasp_committee", + ) + mock_write.assert_called_once_with("Created context for test-committee") + + def test_process_context_batch_empty_content(self, command, mock_committee): + """Test context batch processing with empty content.""" + with patch.object(command, "extract_content") as mock_extract: + mock_extract.return_value = ("", "") + + with patch.object(command, "get_entity_key") as mock_get_key: + mock_get_key.return_value = "test-committee" + + with patch.object(command.stdout, "write") as mock_write: + result = command.process_context_batch([mock_committee]) + + assert result == 0 + mock_write.assert_called_once_with("No content for committee test-committee") + + def test_process_context_batch_create_failure(self, command, mock_committee): + """Test context batch processing when Context.update_data fails.""" + with patch("apps.ai.common.base.context_command.Context") as mock_context_class: + mock_context_class.update_data.return_value = False + + with patch.object(command, "extract_content") as mock_extract: + mock_extract.return_value = ("Content", "Metadata") + + with patch.object(command, "get_entity_key") as mock_get_key: + mock_get_key.return_value = "test-committee" + + with ( + patch.object(command.stdout, "write") as mock_write, + patch.object(command.style, "ERROR") as mock_error, + ): + mock_error.return_value = "ERROR: Failed" + + result = command.process_context_batch([mock_committee]) + + assert result == 0 + mock_error.assert_called_once_with( + "Failed to create context for test-committee" + ) + mock_write.assert_called_once_with("ERROR: Failed") + + def test_handle_calls_batch_processing(self, command): + """Test that handle method calls batch processing.""" + mock_queryset = Mock() + mock_queryset.count.return_value = 2 + + with patch.object(command, "get_queryset") as mock_get_queryset: + mock_get_queryset.return_value = mock_queryset + + with patch.object(command, "handle_batch_processing") as mock_batch_processing: + options = {"batch_size": 50} + command.handle(**options) + + mock_get_queryset.assert_called_once_with(options) + mock_batch_processing.assert_called_once_with( + queryset=mock_queryset, + batch_size=50, + process_batch_func=command.process_context_batch, + ) diff --git a/backend/tests/apps/ai/management/commands/ai_create_event_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_create_event_chunks_test.py index 4c76cef827..e74424dec1 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_event_chunks_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_event_chunks_test.py @@ -34,19 +34,19 @@ def test_model_class_property(self, command): """Test the model_class property returns Event.""" from apps.owasp.models.event import Event - assert command.model_class == Event + assert command.model_class() == Event def test_entity_name_property(self, command): """Test the entity_name property.""" - assert command.entity_name == "event" + assert command.entity_name() == "event" def test_entity_name_plural_property(self, command): """Test the entity_name_plural property.""" - assert command.entity_name_plural == "events" + assert command.entity_name_plural() == "events" def test_key_field_name_property(self, command): """Test the key_field_name property.""" - assert command.key_field_name == "key" + assert command.key_field_name() == "key" def test_extract_content(self, command, mock_event): """Test content extraction from event.""" @@ -58,23 +58,21 @@ def test_extract_content(self, command, mock_event): assert content == ("prose content", "metadata content") mock_extract.assert_called_once_with(mock_event) - def test_get_default_queryset(self, command): - """Test that the default queryset returns upcoming events.""" - with patch( - "apps.ai.management.commands.ai_create_event_chunks.Event.upcoming_events" - ) as mock_upcoming: - mock_queryset = Mock() - mock_upcoming.return_value = mock_queryset - result = command.get_default_queryset() - assert result == mock_queryset - mock_upcoming.assert_called_once() - def test_get_base_queryset(self, command): """Test get_base_queryset calls super().get_base_queryset().""" with patch( - "apps.ai.common.base.BaseAICommand.get_base_queryset", + "apps.ai.common.base.ai_command.BaseAICommand.get_base_queryset", return_value="base_qs", ) as mock_super: result = command.get_base_queryset() assert result == "base_qs" mock_super.assert_called_once() + + def test_get_default_queryset(self, command): + """Test that the default queryset returns upcoming events.""" + with patch("apps.owasp.models.event.Event.upcoming_events") as mock_upcoming: + mock_queryset = Mock() + mock_upcoming.return_value = mock_queryset + result = command.get_default_queryset() + assert result == mock_queryset + mock_upcoming.assert_called_once() diff --git a/backend/tests/apps/ai/management/commands/ai_create_event_context_test.py b/backend/tests/apps/ai/management/commands/ai_create_event_context_test.py index 65eed2e15f..e10640409d 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_event_context_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_event_context_test.py @@ -3,20 +3,17 @@ from unittest.mock import Mock, patch import pytest -from django.core.management.base import BaseCommand from apps.ai.management.commands.ai_create_event_context import Command @pytest.fixture def command(): - """Return a command instance.""" return Command() @pytest.fixture def mock_event(): - """Return a mock Event instance.""" event = Mock() event.id = 1 event.key = "test-event" @@ -24,36 +21,36 @@ def mock_event(): class TestAiCreateEventContextCommand: - """Test suite for the ai_create_event_context command.""" + def test_command_inheritance(self, command): + """Test that the command inherits from BaseContextCommand.""" + from apps.ai.common.base import BaseContextCommand + + assert isinstance(command, BaseContextCommand) def test_command_help_text(self, command): """Test that the command has the correct help text.""" - assert command.help == "Update context for OWASP event data" - - def test_command_inheritance(self, command): - """Test that the command inherits from BaseCommand.""" - assert isinstance(command, BaseCommand) + assert command.help() == "Update context for OWASP event data" def test_model_class_property(self, command): """Test the model_class property returns Event.""" from apps.owasp.models.event import Event - assert command.model_class == Event + assert command.model_class() == Event def test_entity_name_property(self, command): """Test the entity_name property.""" - assert command.entity_name == "event" + assert command.entity_name() == "event" def test_entity_name_plural_property(self, command): """Test the entity_name_plural property.""" - assert command.entity_name_plural == "events" + assert command.entity_name_plural() == "events" def test_key_field_name_property(self, command): """Test the key_field_name property.""" - assert command.key_field_name == "key" + assert command.key_field_name() == "key" def test_extract_content(self, command, mock_event): - """Test content extraction from event.""" + """Test the extract_content method.""" with patch( "apps.ai.management.commands.ai_create_event_context.extract_event_content" ) as mock_extract: @@ -70,3 +67,11 @@ def test_get_default_queryset(self, command): result = command.get_default_queryset() assert result == mock_queryset mock_upcoming.assert_called_once() + + def test_get_base_queryset(self, command): + """Test the get_base_queryset method.""" + with patch.object(command.__class__.__bases__[0], "get_base_queryset") as mock_super: + mock_super.return_value = Mock() + result = command.get_base_queryset() + mock_super.assert_called_once() + assert result == mock_super.return_value diff --git a/backend/tests/apps/ai/management/commands/ai_create_project_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_create_project_chunks_test.py index 4f2d377c34..d80fc14eae 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_project_chunks_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_project_chunks_test.py @@ -26,16 +26,16 @@ def test_command_inheritance(self, command): def test_model_class_property(self, command): from apps.owasp.models.project import Project - assert command.model_class == Project + assert command.model_class() == Project def test_entity_name_property(self, command): - assert command.entity_name == "project" + assert command.entity_name() == "project" def test_entity_name_plural_property(self, command): - assert command.entity_name_plural == "projects" + assert command.entity_name_plural() == "projects" def test_key_field_name_property(self, command): - assert command.key_field_name == "key" + assert command.key_field_name() == "key" def test_extract_content(self, command, mock_project): with patch( @@ -47,6 +47,11 @@ def test_extract_content(self, command, mock_project): mock_extract.assert_called_once_with(mock_project) def test_get_base_queryset_calls_super(self, command): - with patch("apps.ai.common.base.BaseChunkCommand.get_base_queryset") as mock_super: - command.get_base_queryset() + """Test that get_base_queryset calls the parent method.""" + with patch( + "apps.ai.common.base.chunk_command.BaseChunkCommand.get_base_queryset" + ) as mock_super: + mock_super.return_value = "base_queryset" + result = command.get_base_queryset() + assert result == "base_queryset" mock_super.assert_called_once() diff --git a/backend/tests/apps/ai/management/commands/ai_create_project_context_test.py b/backend/tests/apps/ai/management/commands/ai_create_project_context_test.py index 4dd616ea23..44fba9d9d0 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_project_context_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_project_context_test.py @@ -1,20 +1,17 @@ from unittest.mock import Mock, patch import pytest -from django.core.management.base import BaseCommand from apps.ai.management.commands.ai_create_project_context import Command @pytest.fixture def command(): - """Return a command instance.""" return Command() @pytest.fixture def mock_project(): - """Return a mock Project instance.""" project = Mock() project.id = 1 project.key = "test-project" @@ -22,32 +19,36 @@ def mock_project(): class TestAiCreateProjectContextCommand: - """Test suite for the ai_create_project_context command.""" - def test_command_inheritance(self, command): - """Test that the command inherits from BaseCommand.""" - assert isinstance(command, BaseCommand) + """Test that the command inherits from BaseContextCommand.""" + from apps.ai.common.base import BaseContextCommand + + assert isinstance(command, BaseContextCommand) + + def test_command_help_text(self, command): + """Test that the command has the correct help text.""" + assert command.help() == "Update context for OWASP project data" def test_model_class_property(self, command): """Test the model_class property returns Project.""" from apps.owasp.models.project import Project - assert command.model_class == Project + assert command.model_class() == Project def test_entity_name_property(self, command): """Test the entity_name property.""" - assert command.entity_name == "project" + assert command.entity_name() == "project" def test_entity_name_plural_property(self, command): """Test the entity_name_plural property.""" - assert command.entity_name_plural == "projects" + assert command.entity_name_plural() == "projects" def test_key_field_name_property(self, command): """Test the key_field_name property.""" - assert command.key_field_name == "key" + assert command.key_field_name() == "key" def test_extract_content(self, command, mock_project): - """Test content extraction from project.""" + """Test the extract_content method.""" with patch( "apps.ai.management.commands.ai_create_project_context.extract_project_content" ) as mock_extract: @@ -55,3 +56,11 @@ def test_extract_content(self, command, mock_project): content = command.extract_content(mock_project) assert content == ("prose content", "metadata content") mock_extract.assert_called_once_with(mock_project) + + def test_get_base_queryset(self, command): + """Test the get_base_queryset method.""" + with patch.object(command.__class__.__bases__[0], "get_base_queryset") as mock_super: + mock_super.return_value = Mock() + result = command.get_base_queryset() + mock_super.assert_called_once() + assert result == mock_super.return_value diff --git a/backend/tests/apps/ai/management/commands/ai_create_slack_message_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_create_slack_message_chunks_test.py index a6cbae5df1..fece9d9114 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_slack_message_chunks_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_slack_message_chunks_test.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, Mock +from unittest.mock import Mock, patch import pytest from django.core.management.base import BaseCommand @@ -8,80 +8,76 @@ @pytest.fixture def command(): - """Return a command instance.""" return Command() @pytest.fixture def mock_message(): - """Return a mock Message instance.""" message = Mock() message.id = 1 - message.slack_message_id = "test-message-123" - message.cleaned_text = "This is a test Slack message content." + message.slack_message_id = "test-message-id" return message class TestAiCreateSlackMessageChunksCommand: - """Test suite for the ai_create_slack_message_chunks command.""" - def test_command_inheritance(self, command): - """Test that the command inherits from BaseCommand.""" assert isinstance(command, BaseCommand) def test_model_class_property(self, command): - """Test the model_class property returns Message.""" from apps.slack.models.message import Message - assert command.model_class == Message + assert command.model_class() == Message def test_entity_name_property(self, command): - """Test the entity_name property.""" - assert command.entity_name == "message" + assert command.entity_name() == "message" def test_entity_name_plural_property(self, command): - """Test the entity_name_plural property.""" - assert command.entity_name_plural == "messages" + assert command.entity_name_plural() == "messages" def test_key_field_name_property(self, command): - """Test the key_field_name property.""" - assert command.key_field_name == "slack_message_id" + assert command.key_field_name() == "slack_message_id" def test_source_name_property(self, command): """Test the source_name property.""" - assert command.source_name == "slack_message" + assert command.source_name() == "slack_message" def test_extract_content(self, command, mock_message): - """Test content extraction from message.""" + """Test the extract_content method.""" + mock_message.cleaned_text = "Test message content" content = command.extract_content(mock_message) - assert content == ("This is a test Slack message content.", "") + assert content == ("Test message content", "") - def test_extract_content_empty_text(self, command): - """Test content extraction when message has no cleaned_text.""" - message = Mock() - message.cleaned_text = None - content = command.extract_content(message) + def test_extract_content_none_cleaned_text(self, command, mock_message): + """Test the extract_content method with None cleaned_text.""" + mock_message.cleaned_text = None + content = command.extract_content(mock_message) assert content == ("", "") + def test_get_default_queryset(self, command): + """Test the get_default_queryset method.""" + with patch.object(command, "get_base_queryset") as mock_base: + mock_base.return_value = Mock() + result = command.get_default_queryset() + mock_base.assert_called_once() + assert result == mock_base.return_value + def test_add_arguments(self, command): - """Test that the command adds the correct arguments.""" - parser = MagicMock() + """Test the add_arguments method.""" + parser = Mock() command.add_arguments(parser) assert parser.add_argument.call_count == 3 - parser.add_argument.assert_any_call( - "--message-key", - type=str, - help="Process only the message with this key", - ) - parser.add_argument.assert_any_call( - "--all", - action="store_true", - help="Process all the messages", - ) - parser.add_argument.assert_any_call( - "--batch-size", - type=int, - default=100, - help="Number of messages to process in each batch", - ) + calls = parser.add_argument.call_args_list + + assert calls[0][0] == ("--message-key",) + assert calls[0][1]["type"] is str + assert "Process only the message with this key" in calls[0][1]["help"] + + assert calls[1][0] == ("--all",) + assert calls[1][1]["action"] == "store_true" + assert "Process all the messages" in calls[1][1]["help"] + + assert calls[2][0] == ("--batch-size",) + assert calls[2][1]["type"] is int + assert calls[2][1]["default"] == 100 + assert "Number of messages to process in each batch" in calls[2][1]["help"] diff --git a/backend/tests/apps/ai/management/commands/ai_create_slack_message_context_test.py b/backend/tests/apps/ai/management/commands/ai_create_slack_message_context_test.py index 93d805961b..fa6bd75a48 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_slack_message_context_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_slack_message_context_test.py @@ -1,7 +1,6 @@ -from unittest.mock import MagicMock, Mock +from unittest.mock import Mock, patch import pytest -from django.core.management.base import BaseCommand from apps.ai.management.commands.ai_create_slack_message_context import Command @@ -15,60 +14,76 @@ def command(): def mock_message(): message = Mock() message.id = 1 - message.pk = 1 - message.slack_message_id = "test-message-123" - message.cleaned_text = "This is a test Slack message content." + message.slack_message_id = "test-message-id" return message class TestAiCreateSlackMessageContextCommand: def test_command_inheritance(self, command): - assert isinstance(command, BaseCommand) + """Test that the command inherits from BaseContextCommand.""" + from apps.ai.common.base import BaseContextCommand + + assert isinstance(command, BaseContextCommand) def test_model_class_property(self, command): + """Test the model_class property returns Message.""" from apps.slack.models.message import Message - assert command.model_class == Message + assert command.model_class() == Message def test_entity_name_property(self, command): - assert command.entity_name == "message" + """Test the entity_name property.""" + assert command.entity_name() == "message" def test_entity_name_plural_property(self, command): - assert command.entity_name_plural == "messages" + """Test the entity_name_plural property.""" + assert command.entity_name_plural() == "messages" def test_key_field_name_property(self, command): - assert command.key_field_name == "slack_message_id" + """Test the key_field_name property.""" + assert command.key_field_name() == "slack_message_id" def test_source_name_property(self, command): - assert command.source_name == "slack_message" + """Test the source_name property.""" + assert command.source_name() == "slack_message" def test_extract_content(self, command, mock_message): + """Test the extract_content method.""" + mock_message.cleaned_text = "Test message content" content = command.extract_content(mock_message) - assert content == ("This is a test Slack message content.", "") + assert content == ("Test message content", "") - def test_extract_content_empty_text(self, command): - message = Mock() - message.cleaned_text = None - content = command.extract_content(message) + def test_extract_content_none_cleaned_text(self, command, mock_message): + """Test the extract_content method with None cleaned_text.""" + mock_message.cleaned_text = None + content = command.extract_content(mock_message) assert content == ("", "") + def test_get_default_queryset(self, command): + """Test the get_default_queryset method.""" + with patch.object(command, "get_base_queryset") as mock_base: + mock_base.return_value = Mock() + result = command.get_default_queryset() + mock_base.assert_called_once() + assert result == mock_base.return_value + def test_add_arguments(self, command): - parser = MagicMock() + """Test the add_arguments method.""" + parser = Mock() command.add_arguments(parser) + assert parser.add_argument.call_count == 3 - parser.add_argument.assert_any_call( - "--message-key", - type=str, - help="Process only the message with this key", - ) - parser.add_argument.assert_any_call( - "--all", - action="store_true", - help="Process all the messages", - ) - parser.add_argument.assert_any_call( - "--batch-size", - type=int, - default=100, - help="Number of messages to process in each batch", - ) + calls = parser.add_argument.call_args_list + + assert calls[0][0] == ("--message-key",) + assert calls[0][1]["type"] is str + assert "Process only the message with this key" in calls[0][1]["help"] + + assert calls[1][0] == ("--all",) + assert calls[1][1]["action"] == "store_true" + assert "Process all the messages" in calls[1][1]["help"] + + assert calls[2][0] == ("--batch-size",) + assert calls[2][1]["type"] is int + assert calls[2][1]["default"] == 100 + assert "Number of messages to process in each batch" in calls[2][1]["help"] diff --git a/backend/tests/apps/ai/models/chunk_test.py b/backend/tests/apps/ai/models/chunk_test.py index df5021c6e8..94f6432be4 100644 --- a/backend/tests/apps/ai/models/chunk_test.py +++ b/backend/tests/apps/ai/models/chunk_test.py @@ -6,10 +6,9 @@ from apps.ai.models.context import Context -def create_model_mock(model_class): - mock = Mock(spec=model_class) - mock._state = Mock() - mock.pk = 1 +@pytest.fixture +def mock_context(): + mock = Mock(spec=Context) mock.id = 1 return mock @@ -55,47 +54,40 @@ def test_split_text(self): def test_update_data_save_with_context(self, mock_save): text = "Test chunk content" embedding = [0.1, 0.2, 0.3] + mock_context = Mock(spec=Context) with patch("apps.ai.models.chunk.Chunk") as mock_chunk: chunk_instance = Mock() chunk_instance.context_id = 123 mock_chunk.return_value = chunk_instance - result = Chunk.update_data(text=text, embedding=embedding, save=True) + result = Chunk.update_data( + text=text, embedding=embedding, context=mock_context, save=True + ) - mock_chunk.assert_called_once_with(text=text, embedding=embedding) + mock_chunk.assert_called_once_with( + text=text, embedding=embedding, context=mock_context + ) chunk_instance.save.assert_called_once() assert result is chunk_instance - def test_update_data_save_without_context_raises(self): - text = "Test chunk content" - embedding = [0.1, 0.2, 0.3] - - with patch("apps.ai.models.chunk.Chunk") as mock_chunk: - chunk_instance = Mock() - chunk_instance.context_id = None - mock_chunk.return_value = chunk_instance - - with pytest.raises( - ValueError, match="Chunk must have a context assigned before saving." - ): - Chunk.update_data(text=text, embedding=embedding, save=True) - - mock_chunk.assert_called_once_with(text=text, embedding=embedding) - chunk_instance.save.assert_not_called() - def test_update_data_no_save(self): text = "Test chunk content" embedding = [0.1, 0.2, 0.3] + mock_context = Mock(spec=Context) with patch("apps.ai.models.chunk.Chunk") as mock_chunk: chunk_instance = Mock() chunk_instance.context_id = None mock_chunk.return_value = chunk_instance - result = Chunk.update_data(text=text, embedding=embedding, save=False) + result = Chunk.update_data( + text=text, embedding=embedding, context=mock_context, save=False + ) - mock_chunk.assert_called_once_with(text=text, embedding=embedding) + mock_chunk.assert_called_once_with( + text=text, embedding=embedding, context=mock_context + ) chunk_instance.save.assert_not_called() assert result is chunk_instance diff --git a/backend/tests/apps/ai/models/context_test.py b/backend/tests/apps/ai/models/context_test.py index 5648638718..86dae0de21 100644 --- a/backend/tests/apps/ai/models/context_test.py +++ b/backend/tests/apps/ai/models/context_test.py @@ -158,3 +158,123 @@ def test_update_data_existing_context(self, mock_filter): content_type=mock_content_type, object_id=1, content=content ) assert result == mock_context + + def test_str_method_with_name_attribute(self): + """Test __str__ method when content_object has name attribute.""" + content_object = Mock() + content_object.name = "Test Object" + + content_type = Mock() + content_type.model = "test_model" + + with ( + patch.object(Context, "content_object", content_object), + patch.object(Context, "content_type", content_type), + ): + context = Context() + context.content = ( + "This is test content that is longer than 50 characters to test truncation" + ) + + result = str(context) + assert ( + result + == "test_model Test Object: This is test content that is longer than 50 charac" + ) + + def test_str_method_with_key_attribute(self): + """Test __str__ method when content_object has key but no name attribute.""" + content_object = Mock() + content_object.name = None + content_object.key = "test-key" + + content_type = Mock() + content_type.model = "test_model" + + with ( + patch.object(Context, "content_object", content_object), + patch.object(Context, "content_type", content_type), + ): + context = Context() + context.content = "Short content" + + result = str(context) + assert result == "test_model test-key: Short content" + + def test_str_method_fallback_to_str(self): + """Test __str__ method falls back to str(content_object).""" + content_object = Mock() + content_object.name = None + content_object.key = None + content_object.__str__ = Mock(return_value="String representation") + + content_type = Mock() + content_type.model = "test_model" + + with ( + patch.object(Context, "content_object", content_object), + patch.object(Context, "content_type", content_type), + ): + context = Context() + context.content = "Test content" + + result = str(context) + assert result == "test_model String representation: Test content" + + @patch("apps.ai.models.context.Context.objects.filter") + @patch("apps.ai.models.context.Context.__init__") + @patch("apps.ai.models.context.Context.save") + def test_update_data_new_context_with_save(self, mock_save, mock_init, mock_filter): + """Test update_data creating a new context with save=True.""" + mock_filter.return_value.first.return_value = None + mock_init.return_value = None + + content = "New test content" + mock_content_object = Mock() + mock_content_object.pk = 1 + source = "test_source" + + with patch( + "apps.ai.models.context.ContentType.objects.get_for_model" + ) as mock_get_for_model: + mock_content_type = Mock() + mock_get_for_model.return_value = mock_content_type + + result = Context.update_data(content, mock_content_object, source=source, save=True) + + mock_get_for_model.assert_called_once_with(mock_content_object) + mock_filter.assert_called_once_with( + content_type=mock_content_type, object_id=1, content=content + ) + mock_save.assert_called_once() + assert isinstance(result, Context) + + @patch("apps.ai.models.context.Context.objects.filter") + @patch("apps.ai.models.context.Context.__init__") + def test_update_data_new_context_without_save(self, mock_init, mock_filter): + """Test update_data creating a new context with save=False.""" + mock_filter.return_value.first.return_value = None + mock_init.return_value = None + + content = "New test content" + mock_content_object = Mock() + mock_content_object.pk = 1 + source = "test_source" + + with patch( + "apps.ai.models.context.ContentType.objects.get_for_model" + ) as mock_get_for_model: + mock_content_type = Mock() + mock_get_for_model.return_value = mock_content_type + + with patch("apps.ai.models.context.Context.save") as mock_save: + result = Context.update_data( + content, mock_content_object, source=source, save=False + ) + + mock_get_for_model.assert_called_once_with(mock_content_object) + mock_filter.assert_called_once_with( + content_type=mock_content_type, object_id=1, content=content + ) + mock_save.assert_not_called() + assert isinstance(result, Context) From 1c7fe1cf249c0c197ccf358f028c617086030363 Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Wed, 13 Aug 2025 19:08:28 +0530 Subject: [PATCH 18/32] refactoring --- backend/apps/ai/common/base/__init__.py | 7 ------- backend/apps/ai/common/base/chunk_command.py | 2 +- backend/apps/ai/common/utils.py | 3 --- 3 files changed, 1 insertion(+), 11 deletions(-) diff --git a/backend/apps/ai/common/base/__init__.py b/backend/apps/ai/common/base/__init__.py index 8f794890e0..e69de29bb2 100644 --- a/backend/apps/ai/common/base/__init__.py +++ b/backend/apps/ai/common/base/__init__.py @@ -1,7 +0,0 @@ -"""Base classes for AI management commands.""" - -from .ai_command import BaseAICommand -from .chunk_command import BaseChunkCommand -from .context_command import BaseContextCommand - -__all__ = ["BaseAICommand", "BaseChunkCommand", "BaseContextCommand"] diff --git a/backend/apps/ai/common/base/chunk_command.py b/backend/apps/ai/common/base/chunk_command.py index 1dc929f434..aacf367042 100644 --- a/backend/apps/ai/common/base/chunk_command.py +++ b/backend/apps/ai/common/base/chunk_command.py @@ -47,7 +47,7 @@ def process_chunks_batch(self, entities: list[Model]) -> int: chunk_texts = Chunk.split_text(full_content) if not chunk_texts: self.stdout.write( - f"No chunks created for {self.entity_name()} {entity_key}: `{full_content}`" + f"No chunks created for {self.entity_name()} {entity_key}" ) continue diff --git a/backend/apps/ai/common/utils.py b/backend/apps/ai/common/utils.py index 95592516ae..f018c0b301 100644 --- a/backend/apps/ai/common/utils.py +++ b/backend/apps/ai/common/utils.py @@ -64,9 +64,6 @@ def create_chunks_and_embeddings( ) chunks.append(chunk) - except ValueError: - logger.exception("Context error") - raise except openai.OpenAIError: logger.exception("Failed to create chunks and embeddings") return [] From 948c52959bed4a1a925b22f3b009cc2c703dd349 Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Wed, 13 Aug 2025 19:51:49 +0530 Subject: [PATCH 19/32] more tests --- backend/apps/ai/common/base/chunk_command.py | 4 +- backend/apps/ai/common/utils.py | 8 +--- backend/tests/apps/ai/common/utils_test.py | 44 ------------------- .../ai_create_chapter_context_test.py | 2 +- .../ai_create_committee_context_test.py | 2 +- .../commands/ai_create_event_context_test.py | 2 +- .../ai_create_project_context_test.py | 2 +- .../ai_create_slack_message_context_test.py | 2 +- 8 files changed, 8 insertions(+), 58 deletions(-) diff --git a/backend/apps/ai/common/base/chunk_command.py b/backend/apps/ai/common/base/chunk_command.py index aacf367042..a8b84174fa 100644 --- a/backend/apps/ai/common/base/chunk_command.py +++ b/backend/apps/ai/common/base/chunk_command.py @@ -46,9 +46,7 @@ def process_chunks_batch(self, entities: list[Model]) -> int: chunk_texts = Chunk.split_text(full_content) if not chunk_texts: - self.stdout.write( - f"No chunks created for {self.entity_name()} {entity_key}" - ) + self.stdout.write(f"No chunks created for {self.entity_name()} {entity_key}") continue if chunks := create_chunks_and_embeddings( diff --git a/backend/apps/ai/common/utils.py b/backend/apps/ai/common/utils.py index f018c0b301..81bde210db 100644 --- a/backend/apps/ai/common/utils.py +++ b/backend/apps/ai/common/utils.py @@ -47,9 +47,7 @@ def create_chunks_and_embeddings( time_since_last_request = datetime.now(UTC) - last_request_time if time_since_last_request < timedelta(seconds=MIN_REQUEST_INTERVAL_SECONDS): - time.sleep( - MIN_REQUEST_INTERVAL_SECONDS - time_since_last_request.total_seconds() - ) + time.sleep(MIN_REQUEST_INTERVAL_SECONDS - time_since_last_request.total_seconds()) response = openai_client.embeddings.create( input=chunk_texts, @@ -59,9 +57,7 @@ def create_chunks_and_embeddings( chunks = [] for text, embedding in zip(chunk_texts, embeddings, strict=True): - chunk = Chunk.update_data( - text=text, embedding=embedding, context=context, save=save - ) + chunk = Chunk.update_data(text=text, embedding=embedding, context=context, save=save) chunks.append(chunk) except openai.OpenAIError: diff --git a/backend/tests/apps/ai/common/utils_test.py b/backend/tests/apps/ai/common/utils_test.py index 90879b2592..4f6890cffa 100644 --- a/backend/tests/apps/ai/common/utils_test.py +++ b/backend/tests/apps/ai/common/utils_test.py @@ -2,7 +2,6 @@ from unittest.mock import MagicMock, Mock, call, patch import openai -import pytest from apps.ai.common.utils import create_chunks_and_embeddings @@ -117,25 +116,6 @@ def test_create_chunks_and_embeddings_none_context(self): text="some text", embedding=[0.1, 0.2, 0.3], context=None, save=True ) - @patch("apps.ai.common.utils.logger") - def test_create_chunks_and_embeddings_context_error(self, mock_logger): - """Tests the failure path when there's an embedding mismatch error.""" - mock_openai_client = MagicMock() - mock_context = MagicMock() - - mock_response = MagicMock() - mock_response.data = [] - mock_openai_client.embeddings.create.return_value = mock_response - - with pytest.raises(ValueError, match="zip\\(\\) argument 2 is shorter than argument 1"): - create_chunks_and_embeddings( - chunk_texts=["some text"], - context=mock_context, - openai_client=mock_openai_client, - ) - - mock_logger.exception.assert_called_once_with("Context error") - @patch("apps.ai.common.utils.time.sleep") @patch("apps.ai.common.utils.datetime") def test_create_chunks_and_embeddings_sleep_called(self, mock_datetime, mock_sleep): @@ -196,27 +176,3 @@ def test_create_chunks_and_embeddings_no_sleep_with_current_settings( text="test chunk", embedding=[0.1, 0.2], context=mock_context_obj, save=True ) assert result == [mock_chunk] - - @patch("apps.ai.common.utils.logger") - @patch("apps.ai.common.utils.Chunk.update_data") - def test_create_chunks_and_embeddings_chunk_update_value_error( - self, mock_update_data, mock_logger - ): - """Tests the failure path when Chunk.update_data raises ValueError.""" - mock_openai_client = MagicMock() - mock_context = MagicMock() - - mock_response = MagicMock() - mock_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])] - mock_openai_client.embeddings.create.return_value = mock_response - - mock_update_data.side_effect = ValueError("Invalid context") - - with pytest.raises(ValueError, match="Invalid context"): - create_chunks_and_embeddings( - chunk_texts=["some text"], - context=mock_context, - openai_client=mock_openai_client, - ) - - mock_logger.exception.assert_called_once_with("Context error") diff --git a/backend/tests/apps/ai/management/commands/ai_create_chapter_context_test.py b/backend/tests/apps/ai/management/commands/ai_create_chapter_context_test.py index f2dafc4064..b25acaff0a 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_chapter_context_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_chapter_context_test.py @@ -23,7 +23,7 @@ def mock_chapter(): class TestAiCreateChapterContextCommand: def test_command_inheritance(self, command): """Test that the command inherits from BaseContextCommand.""" - from apps.ai.common.base import BaseContextCommand + from apps.ai.common.base.context_command import BaseContextCommand assert isinstance(command, BaseContextCommand) diff --git a/backend/tests/apps/ai/management/commands/ai_create_committee_context_test.py b/backend/tests/apps/ai/management/commands/ai_create_committee_context_test.py index b75fdfacbb..3d2090ed74 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_committee_context_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_committee_context_test.py @@ -30,7 +30,7 @@ class TestAiCreateCommitteeContextCommand: def test_command_inheritance(self, command): """Test that the command inherits from BaseContextCommand.""" - from apps.ai.common.base import BaseContextCommand + from apps.ai.common.base.context_command import BaseContextCommand assert isinstance(command, BaseContextCommand) diff --git a/backend/tests/apps/ai/management/commands/ai_create_event_context_test.py b/backend/tests/apps/ai/management/commands/ai_create_event_context_test.py index e10640409d..00a4a3fd53 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_event_context_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_event_context_test.py @@ -23,7 +23,7 @@ def mock_event(): class TestAiCreateEventContextCommand: def test_command_inheritance(self, command): """Test that the command inherits from BaseContextCommand.""" - from apps.ai.common.base import BaseContextCommand + from apps.ai.common.base.context_command import BaseContextCommand assert isinstance(command, BaseContextCommand) diff --git a/backend/tests/apps/ai/management/commands/ai_create_project_context_test.py b/backend/tests/apps/ai/management/commands/ai_create_project_context_test.py index 44fba9d9d0..80976f99d5 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_project_context_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_project_context_test.py @@ -21,7 +21,7 @@ def mock_project(): class TestAiCreateProjectContextCommand: def test_command_inheritance(self, command): """Test that the command inherits from BaseContextCommand.""" - from apps.ai.common.base import BaseContextCommand + from apps.ai.common.base.context_command import BaseContextCommand assert isinstance(command, BaseContextCommand) diff --git a/backend/tests/apps/ai/management/commands/ai_create_slack_message_context_test.py b/backend/tests/apps/ai/management/commands/ai_create_slack_message_context_test.py index fa6bd75a48..961826dc42 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_slack_message_context_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_slack_message_context_test.py @@ -21,7 +21,7 @@ def mock_message(): class TestAiCreateSlackMessageContextCommand: def test_command_inheritance(self, command): """Test that the command inherits from BaseContextCommand.""" - from apps.ai.common.base import BaseContextCommand + from apps.ai.common.base.context_command import BaseContextCommand assert isinstance(command, BaseContextCommand) From 1e8d65ebec2d5c291d196fc9822756a905517829 Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Wed, 13 Aug 2025 20:01:50 +0530 Subject: [PATCH 20/32] more refactoring --- backend/tests/apps/ai/agent/tools/rag/generator_test.py | 3 --- backend/tests/apps/ai/agent/tools/rag/retriever_test.py | 1 - 2 files changed, 4 deletions(-) diff --git a/backend/tests/apps/ai/agent/tools/rag/generator_test.py b/backend/tests/apps/ai/agent/tools/rag/generator_test.py index 2cd07fb14c..903bdf636d 100644 --- a/backend/tests/apps/ai/agent/tools/rag/generator_test.py +++ b/backend/tests/apps/ai/agent/tools/rag/generator_test.py @@ -119,8 +119,6 @@ def test_generate_answer_success(self): mock_client.chat.completions.create.assert_called_once() call_args = mock_client.chat.completions.create.call_args assert call_args[1]["model"] == "gpt-4o" - assert call_args[1]["temperature"] == 0.4 - assert call_args[1]["max_tokens"] == 2000 assert len(call_args[1]["messages"]) == 2 assert call_args[1]["messages"][0]["role"] == "system" assert call_args[1]["messages"][1]["role"] == "user" @@ -195,6 +193,5 @@ def test_system_prompt_content(self): def test_constants(self): """Test class constants have expected values.""" assert Generator.MAX_TOKENS == 2000 - assert Generator.TEMPERATURE == 0.4 assert isinstance(Generator.SYSTEM_PROMPT, str) assert len(Generator.SYSTEM_PROMPT) > 0 diff --git a/backend/tests/apps/ai/agent/tools/rag/retriever_test.py b/backend/tests/apps/ai/agent/tools/rag/retriever_test.py index 3e3e254295..dff1563737 100644 --- a/backend/tests/apps/ai/agent/tools/rag/retriever_test.py +++ b/backend/tests/apps/ai/agent/tools/rag/retriever_test.py @@ -503,7 +503,6 @@ def test_retrieve_successful_with_chunks(self, mock_chunk): assert len(result) == 1 assert result[0]["text"] == "Test chunk text" - assert result[0]["similarity"] == 0.85 assert result[0]["source_type"] == "chapter" assert result[0]["source_name"] == "Test Chapter" assert result[0]["source_id"] == "123" From bd8f28014e6d1197dfe13192c359860f23674cef Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Thu, 14 Aug 2025 08:54:59 +0530 Subject: [PATCH 21/32] suggestions implemented --- backend/apps/ai/common/utils.py | 3 ++- backend/apps/ai/models/chunk.py | 5 ++++- .../ai/common/base/context_command_test.py | 2 +- .../apps/ai/common/extractors/chapter_test.py | 8 ++++---- .../ai/common/extractors/committee_test.py | 6 +++--- .../apps/ai/common/extractors/event_test.py | 4 ++-- .../apps/ai/common/extractors/project_test.py | 14 +++++++------- backend/tests/apps/ai/models/chunk_test.py | 18 ++++++++++++++++++ 8 files changed, 41 insertions(+), 19 deletions(-) diff --git a/backend/apps/ai/common/utils.py b/backend/apps/ai/common/utils.py index 81bde210db..97ab4a0abf 100644 --- a/backend/apps/ai/common/utils.py +++ b/backend/apps/ai/common/utils.py @@ -58,7 +58,8 @@ def create_chunks_and_embeddings( chunks = [] for text, embedding in zip(chunk_texts, embeddings, strict=True): chunk = Chunk.update_data(text=text, embedding=embedding, context=context, save=save) - chunks.append(chunk) + if chunk is not None: + chunks.append(chunk) except openai.OpenAIError: logger.exception("Failed to create chunks and embeddings") diff --git a/backend/apps/ai/models/chunk.py b/backend/apps/ai/models/chunk.py index efd1f24800..edb4b12756 100644 --- a/backend/apps/ai/models/chunk.py +++ b/backend/apps/ai/models/chunk.py @@ -48,7 +48,7 @@ def update_data( context: Context, *, save: bool = True, - ) -> "Chunk": + ) -> "Chunk | None": """Update chunk data. Args: @@ -61,6 +61,9 @@ def update_data( Chunk: The created chunk instance. """ + if Chunk.objects.filter(context=context, text=text).exists(): + return None + chunk = Chunk(text=text, embedding=embedding, context=context) if save: diff --git a/backend/tests/apps/ai/common/base/context_command_test.py b/backend/tests/apps/ai/common/base/context_command_test.py index 6273a6c33c..9b88ef57c6 100644 --- a/backend/tests/apps/ai/common/base/context_command_test.py +++ b/backend/tests/apps/ai/common/base/context_command_test.py @@ -159,7 +159,7 @@ def test_process_context_batch_multiple_entities( calls = mock_context_class.update_data.call_args_list for i, call in enumerate(calls): - args, kwargs = call + _, kwargs = call assert kwargs["content_object"] == entities[i] assert kwargs["content"] == "metadata content\n\nprose content" assert kwargs["source"] == "owasp_test_entity" diff --git a/backend/tests/apps/ai/common/extractors/chapter_test.py b/backend/tests/apps/ai/common/extractors/chapter_test.py index a3dac7e96c..f97a96ae24 100644 --- a/backend/tests/apps/ai/common/extractors/chapter_test.py +++ b/backend/tests/apps/ai/common/extractors/chapter_test.py @@ -156,7 +156,7 @@ def test_extract_chapter_content_with_invalid_urls(self): chapter.is_active = False chapter.owasp_repository = None - prose, metadata = extract_chapter_content(chapter) + _, metadata = extract_chapter_content(chapter) assert "Related URLs: https://valid.com, https://another-valid.com" in metadata assert "https://invalid.com" not in metadata @@ -241,7 +241,7 @@ def test_extract_chapter_content_none_invalid_urls(self): chapter.is_active = False chapter.owasp_repository = None - prose, metadata = extract_chapter_content(chapter) + _, metadata = extract_chapter_content(chapter) assert "Related URLs: https://valid.com" in metadata @@ -265,7 +265,7 @@ def test_extract_chapter_content_empty_related_urls_after_filter(self): chapter.is_active = False chapter.owasp_repository = None - prose, metadata = extract_chapter_content(chapter) + _, metadata = extract_chapter_content(chapter) assert "Related URLs:" not in metadata @@ -294,6 +294,6 @@ def test_extract_chapter_content_with_none_and_empty_urls(self): chapter.is_active = False chapter.owasp_repository = None - prose, metadata = extract_chapter_content(chapter) + _, metadata = extract_chapter_content(chapter) assert "Related URLs: https://valid.com, https://another-valid.com" in metadata diff --git a/backend/tests/apps/ai/common/extractors/committee_test.py b/backend/tests/apps/ai/common/extractors/committee_test.py index 594cfedaf2..320a451dd0 100644 --- a/backend/tests/apps/ai/common/extractors/committee_test.py +++ b/backend/tests/apps/ai/common/extractors/committee_test.py @@ -93,7 +93,7 @@ def test_extract_committee_content_with_invalid_urls(self): committee.is_active = True committee.owasp_repository = None - prose, metadata = extract_committee_content(committee) + _, metadata = extract_committee_content(committee) assert "Related URLs: https://valid.com" in metadata assert "https://invalid.com" not in metadata @@ -112,7 +112,7 @@ def test_extract_committee_content_no_invalid_urls_attr(self): committee.owasp_repository = None del committee.invalid_urls - prose, metadata = extract_committee_content(committee) + _, metadata = extract_committee_content(committee) assert "Related URLs: https://valid.com" in metadata @@ -197,6 +197,6 @@ def test_extract_committee_content_all_empty_after_filter(self): committee.is_active = True committee.owasp_repository = None - prose, metadata = extract_committee_content(committee) + _, metadata = extract_committee_content(committee) assert "Related URLs:" not in metadata diff --git a/backend/tests/apps/ai/common/extractors/event_test.py b/backend/tests/apps/ai/common/extractors/event_test.py index 354c1967ae..8451674f4e 100644 --- a/backend/tests/apps/ai/common/extractors/event_test.py +++ b/backend/tests/apps/ai/common/extractors/event_test.py @@ -136,7 +136,7 @@ def test_extract_event_content_zero_coordinates(self): event.longitude = 0.0 event.url = None - prose, metadata = extract_event_content(event) + _, metadata = extract_event_content(event) assert "Event Name: Test Event" in metadata assert "Coordinates: 0.0, 0.0" in metadata @@ -202,7 +202,7 @@ def test_extract_event_content_category_display_method(self): event.longitude = None event.url = None - prose, metadata = extract_event_content(event) + _, metadata = extract_event_content(event) event.get_category_display.assert_called_once() assert "Category: Meetup" in metadata diff --git a/backend/tests/apps/ai/common/extractors/project_test.py b/backend/tests/apps/ai/common/extractors/project_test.py index 2eb2bd4b61..01236103ea 100644 --- a/backend/tests/apps/ai/common/extractors/project_test.py +++ b/backend/tests/apps/ai/common/extractors/project_test.py @@ -130,7 +130,7 @@ def test_extract_project_content_partial_statistics(self): project.is_active = True project.owasp_repository = None - prose, metadata = extract_project_content(project) + _, metadata = extract_project_content(project) assert "Project Statistics: Stars: 100, Contributors: 5, Open Issues: 3" in metadata @@ -162,7 +162,7 @@ def test_extract_project_content_zero_statistics(self): project.is_active = True project.owasp_repository = None - prose, metadata = extract_project_content(project) + _, metadata = extract_project_content(project) assert "Project Statistics:" not in metadata @@ -194,7 +194,7 @@ def test_extract_project_content_with_invalid_urls(self): project.is_active = True project.owasp_repository = None - prose, metadata = extract_project_content(project) + _, metadata = extract_project_content(project) assert "Related URLs: https://valid.com" in metadata assert "https://invalid.com" not in metadata @@ -227,7 +227,7 @@ def test_extract_project_content_dates_only(self): project.is_active = True project.owasp_repository = None - prose, metadata = extract_project_content(project) + _, metadata = extract_project_content(project) assert "Created: 2021-03-01" in metadata assert "Last Updated:" not in metadata @@ -261,7 +261,7 @@ def test_extract_project_content_health_score_zero(self): project.is_active = True project.owasp_repository = None - prose, metadata = extract_project_content(project) + _, metadata = extract_project_content(project) assert "Health Score: 0.00" in metadata @@ -331,7 +331,7 @@ def test_extract_project_content_no_invalid_urls_attr(self): project.owasp_repository = None del project.invalid_urls - prose, metadata = extract_project_content(project) + _, metadata = extract_project_content(project) assert "Related URLs: https://valid.com" in metadata @@ -436,6 +436,6 @@ def test_extract_project_content_with_empty_related_urls(self): project.is_active = True project.owasp_repository = None - prose, metadata = extract_project_content(project) + _, metadata = extract_project_content(project) assert "Related URLs: https://valid.com, https://another.com" in metadata diff --git a/backend/tests/apps/ai/models/chunk_test.py b/backend/tests/apps/ai/models/chunk_test.py index 94f6432be4..23b98e6a73 100644 --- a/backend/tests/apps/ai/models/chunk_test.py +++ b/backend/tests/apps/ai/models/chunk_test.py @@ -60,6 +60,7 @@ def test_update_data_save_with_context(self, mock_save): chunk_instance = Mock() chunk_instance.context_id = 123 mock_chunk.return_value = chunk_instance + mock_chunk.objects.filter.return_value.exists.return_value = False result = Chunk.update_data( text=text, embedding=embedding, context=mock_context, save=True @@ -80,6 +81,7 @@ def test_update_data_no_save(self): chunk_instance = Mock() chunk_instance.context_id = None mock_chunk.return_value = chunk_instance + mock_chunk.objects.filter.return_value.exists.return_value = False result = Chunk.update_data( text=text, embedding=embedding, context=mock_context, save=False @@ -91,6 +93,22 @@ def test_update_data_no_save(self): chunk_instance.save.assert_not_called() assert result is chunk_instance + def test_update_data_chunk_already_exists(self): + """Test that update_data returns None when chunk already exists.""" + text = "Test chunk content" + embedding = [0.1, 0.2, 0.3] + mock_context = Mock(spec=Context) + + with patch("apps.ai.models.chunk.Chunk") as mock_chunk: + mock_chunk.objects.filter.return_value.exists.return_value = True + + result = Chunk.update_data( + text=text, embedding=embedding, context=mock_context, save=True + ) + + mock_chunk.assert_not_called() + assert result is None + def test_meta_class_attributes(self): assert Chunk._meta.db_table == "ai_chunks" assert Chunk._meta.verbose_name == "Chunk" From a9da28b3cc0bb9633544716be761adabcd5c441f Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Thu, 14 Aug 2025 10:11:55 +0530 Subject: [PATCH 22/32] chunk model update --- backend/apps/ai/models/chunk.py | 6 +- backend/tests/apps/ai/models/chunk_test.py | 70 ++++++++++++++++------ 2 files changed, 57 insertions(+), 19 deletions(-) diff --git a/backend/apps/ai/models/chunk.py b/backend/apps/ai/models/chunk.py index edb4b12756..9b57844c81 100644 --- a/backend/apps/ai/models/chunk.py +++ b/backend/apps/ai/models/chunk.py @@ -61,7 +61,11 @@ def update_data( Chunk: The created chunk instance. """ - if Chunk.objects.filter(context=context, text=text).exists(): + if Chunk.objects.filter( + context__content_type=context.content_type, + context__object_id=context.object_id, + text=text, + ).exists(): return None chunk = Chunk(text=text, embedding=embedding, context=context) diff --git a/backend/tests/apps/ai/models/chunk_test.py b/backend/tests/apps/ai/models/chunk_test.py index 23b98e6a73..6571aea959 100644 --- a/backend/tests/apps/ai/models/chunk_test.py +++ b/backend/tests/apps/ai/models/chunk_test.py @@ -72,41 +72,75 @@ def test_update_data_save_with_context(self, mock_save): chunk_instance.save.assert_called_once() assert result is chunk_instance - def test_update_data_no_save(self): - text = "Test chunk content" + def test_update_data_creates_new_chunk_and_saves(self, mock_context): + """Test that a new chunk is created and saved.""" + text = "New unique chunk content" embedding = [0.1, 0.2, 0.3] - mock_context = Mock(spec=Context) - with patch("apps.ai.models.chunk.Chunk") as mock_chunk: - chunk_instance = Mock() - chunk_instance.context_id = None - mock_chunk.return_value = chunk_instance - mock_chunk.objects.filter.return_value.exists.return_value = False + with patch("apps.ai.models.chunk.Chunk") as mock_chunk_class: + mock_chunk_class.objects.filter.return_value.exists.return_value = False + mock_instance = Mock(spec=Chunk) + mock_chunk_class.return_value = mock_instance + + result = Chunk.update_data( + text=text, embedding=embedding, context=mock_context, save=True + ) + + mock_chunk_class.objects.filter.assert_called_once_with( + context__content_type=mock_context.content_type, + context__object_id=mock_context.object_id, + text=text, + ) + mock_chunk_class.assert_called_once_with( + text=text, embedding=embedding, context=mock_context + ) + mock_instance.save.assert_called_once() + assert result is mock_instance + + def test_update_data_creates_new_chunk_no_save(self, mock_context): + """Test that a new chunk is created but NOT saved when save=False.""" + text = "New unique chunk content" + embedding = [0.1, 0.2, 0.3] + + with patch("apps.ai.models.chunk.Chunk") as mock_chunk_class: + mock_chunk_class.objects.filter.return_value.exists.return_value = False + mock_instance = Mock(spec=Chunk) + mock_chunk_class.return_value = mock_instance result = Chunk.update_data( text=text, embedding=embedding, context=mock_context, save=False ) - mock_chunk.assert_called_once_with( + mock_chunk_class.objects.filter.assert_called_once_with( + context__content_type=mock_context.content_type, + context__object_id=mock_context.object_id, + text=text, + ) + mock_chunk_class.assert_called_once_with( text=text, embedding=embedding, context=mock_context ) - chunk_instance.save.assert_not_called() - assert result is chunk_instance + mock_instance.save.assert_not_called() + assert result is mock_instance - def test_update_data_chunk_already_exists(self): - """Test that update_data returns None when chunk already exists.""" - text = "Test chunk content" + def test_update_data_returns_none_if_chunk_already_exists(self, mock_context): + """Test that update_data returns None when a chunk with the same text.""" + text = "Existing chunk content" embedding = [0.1, 0.2, 0.3] - mock_context = Mock(spec=Context) - with patch("apps.ai.models.chunk.Chunk") as mock_chunk: - mock_chunk.objects.filter.return_value.exists.return_value = True + with patch("apps.ai.models.chunk.Chunk") as mock_chunk_class: + mock_chunk_class.objects.filter.return_value.exists.return_value = True result = Chunk.update_data( text=text, embedding=embedding, context=mock_context, save=True ) - mock_chunk.assert_not_called() + mock_chunk_class.objects.filter.assert_called_once_with( + context__content_type=mock_context.content_type, + context__object_id=mock_context.object_id, + text=text, + ) + + mock_chunk_class.assert_not_called() assert result is None def test_meta_class_attributes(self): From a0ed3111bcbc89bb1a8c1c0e388335cebdc00d83 Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Sat, 16 Aug 2025 15:59:22 +0530 Subject: [PATCH 23/32] update logic and suggestions --- backend/apps/ai/Makefile | 32 +-- backend/apps/ai/admin.py | 30 +-- backend/apps/ai/agent/tools/rag/retriever.py | 157 +++++++------- backend/apps/ai/common/base/ai_command.py | 52 ++--- backend/apps/ai/common/base/chunk_command.py | 14 +- .../apps/ai/common/base/context_command.py | 4 +- backend/apps/ai/common/utils.py | 33 +++ .../commands/ai_create_chapter_chunks.py | 17 +- .../commands/ai_create_committee_chunks.py | 17 +- .../commands/ai_create_event_chunks.py | 17 +- .../commands/ai_create_project_chunks.py | 17 +- .../ai_create_slack_message_chunks.py | 19 +- ...ontext.py => ai_update_chapter_context.py} | 17 +- ...text.py => ai_update_committee_context.py} | 17 +- ..._context.py => ai_update_event_context.py} | 17 +- ...ontext.py => ai_update_project_context.py} | 17 +- ....py => ai_update_slack_message_context.py} | 19 +- ...me_object_id_context_entity_id_and_more.py | 27 +++ backend/apps/ai/models/chunk.py | 4 +- backend/apps/ai/models/context.py | 53 +++-- .../apps/ai/agent/tools/rag/retriever_test.py | 11 +- .../apps/ai/common/base/ai_command_test.py | 111 +++------- .../apps/ai/common/base/chunk_command_test.py | 52 ++--- .../ai/common/base/context_command_test.py | 51 ++--- .../commands/ai_create_chapter_chunks_test.py | 8 +- .../ai_create_committee_chunks_test.py | 8 +- .../commands/ai_create_event_chunks_test.py | 8 +- .../commands/ai_create_project_chunks_test.py | 8 +- .../ai_create_slack_message_chunks_test.py | 8 +- ...t.py => ai_update_chapter_context_test.py} | 12 +- ...py => ai_update_committee_context_test.py} | 12 +- ...est.py => ai_update_event_context_test.py} | 12 +- ...t.py => ai_update_project_context_test.py} | 12 +- ...> ai_update_slack_message_context_test.py} | 10 +- backend/tests/apps/ai/models/chunk_test.py | 14 +- backend/tests/apps/ai/models/context_test.py | 193 ++++++++++++------ 36 files changed, 528 insertions(+), 582 deletions(-) rename backend/apps/ai/management/commands/{ai_create_chapter_context.py => ai_update_chapter_context.py} (60%) rename backend/apps/ai/management/commands/{ai_create_committee_context.py => ai_update_committee_context.py} (60%) rename backend/apps/ai/management/commands/{ai_create_event_context.py => ai_update_event_context.py} (71%) rename backend/apps/ai/management/commands/{ai_create_project_context.py => ai_update_project_context.py} (66%) rename backend/apps/ai/management/commands/{ai_create_slack_message_context.py => ai_update_slack_message_context.py} (79%) create mode 100644 backend/apps/ai/migrations/0009_rename_object_id_context_entity_id_and_more.py rename backend/tests/apps/ai/management/commands/{ai_create_chapter_context_test.py => ai_update_chapter_context_test.py} (83%) rename backend/tests/apps/ai/management/commands/{ai_create_committee_context_test.py => ai_update_committee_context_test.py} (96%) rename backend/tests/apps/ai/management/commands/{ai_create_event_context_test.py => ai_update_event_context_test.py} (88%) rename backend/tests/apps/ai/management/commands/{ai_create_project_context_test.py => ai_update_project_context_test.py} (85%) rename backend/tests/apps/ai/management/commands/{ai_create_slack_message_context_test.py => ai_update_slack_message_context_test.py} (91%) diff --git a/backend/apps/ai/Makefile b/backend/apps/ai/Makefile index ca7125c79f..948abeb5f8 100644 --- a/backend/apps/ai/Makefile +++ b/backend/apps/ai/Makefile @@ -1,43 +1,43 @@ -ai-create-chapter-context: - @echo "Creating chapter context" - @CMD="python manage.py ai_create_chapter_context" $(MAKE) exec-backend-command - ai-create-chapter-chunks: @echo "Creating chapter chunks" @CMD="python manage.py ai_create_chapter_chunks" $(MAKE) exec-backend-command -ai-create-committee-context: - @echo "Creating committee context" - @CMD="python manage.py ai_create_committee_context" $(MAKE) exec-backend-command +ai-create-chapter-context: + @echo "Creating chapter context" + @CMD="python manage.py ai_create_chapter_context" $(MAKE) exec-backend-command ai-create-committee-chunks: @echo "Creating committee chunks" @CMD="python manage.py ai_create_committee_chunks" $(MAKE) exec-backend-command -ai-create-event-context: - @echo "Creating event context" - @CMD="python manage.py ai_create_event_context" $(MAKE) exec-backend-command +ai-create-committee-context: + @echo "Creating committee context" + @CMD="python manage.py ai_create_committee_context" $(MAKE) exec-backend-command ai-create-event-chunks: @echo "Creating event chunks" @CMD="python manage.py ai_create_event_chunks" $(MAKE) exec-backend-command -ai-create-project-context: - @echo "Creating project context" - @CMD="python manage.py ai_create_project_context" $(MAKE) exec-backend-command +ai-create-event-context: + @echo "Creating event context" + @CMD="python manage.py ai_create_event_context" $(MAKE) exec-backend-command ai-create-project-chunks: @echo "Creating project chunks" @CMD="python manage.py ai_create_project_chunks" $(MAKE) exec-backend-command -ai-create-slack-message-context: - @echo "Creating Slack message context" - @CMD="python manage.py ai_create_slack_message_context" $(MAKE) exec-backend-command +ai-create-project-context: + @echo "Creating project context" + @CMD="python manage.py ai_create_project_context" $(MAKE) exec-backend-command ai-create-slack-message-chunks: @echo "Creating Slack message chunks" @CMD="python manage.py ai_create_slack_message_chunks" $(MAKE) exec-backend-command +ai-create-slack-message-context: + @echo "Creating Slack message context" + @CMD="python manage.py ai_create_slack_message_context" $(MAKE) exec-backend-command + ai-run-rag-tool: @echo "Running RAG tool" @CMD="python manage.py ai_run_rag_tool" $(MAKE) exec-backend-command diff --git a/backend/apps/ai/admin.py b/backend/apps/ai/admin.py index cd804992cd..d0852aeb48 100644 --- a/backend/apps/ai/admin.py +++ b/backend/apps/ai/admin.py @@ -6,31 +6,31 @@ from apps.ai.models.context import Context -class ContextAdmin(admin.ModelAdmin): - """Admin for Context model.""" +class ChunkAdmin(admin.ModelAdmin): + """Admin for Chunk model.""" list_display = ( "id", - "content", - "content_type", - "object_id", - "source", + "text", + "context", ) - search_fields = ("content", "source") - list_filter = ("content_type", "source") + list_filter = ("context__entity_type",) + search_fields = ("text",) -class ChunkAdmin(admin.ModelAdmin): - """Admin for Chunk model.""" +class ContextAdmin(admin.ModelAdmin): + """Admin for Context model.""" list_display = ( "id", - "text", - "context", + "content", + "entity_type", + "entity_id", + "source", ) - search_fields = ("text",) - list_filter = ("context__content_type",) + list_filter = ("entity_type", "source") + search_fields = ("content", "source") -admin.site.register(Context, ContextAdmin) admin.site.register(Chunk, ChunkAdmin) +admin.site.register(Context, ContextAdmin) diff --git a/backend/apps/ai/agent/tools/rag/retriever.py b/backend/apps/ai/agent/tools/rag/retriever.py index 501f2f06f4..3f9d043bd5 100644 --- a/backend/apps/ai/agent/tools/rag/retriever.py +++ b/backend/apps/ai/agent/tools/rag/retriever.py @@ -21,7 +21,7 @@ class Retriever: """A class for retrieving relevant text chunks for a RAG.""" - SUPPORTED_CONTENT_TYPES = ("event", "project", "chapter", "committee", "message") + SUPPORTED_ENTITY_TYPES = ("event", "project", "chapter", "committee", "message") def __init__(self, embedding_model: str = "text-embedding-3-small"): """Initialize the Retriever. @@ -63,114 +63,113 @@ def get_query_embedding(self, query: str) -> list[float]: logger.exception("Unexpected error while generating embedding") raise - def get_source_name(self, content_object) -> str: + def get_source_name(self, entity) -> str: """Get the name/identifier for the content object.""" for attr in ("name", "title", "login", "key", "summary"): - if getattr(content_object, attr, None): - return str(getattr(content_object, attr)) - return str(content_object) + if getattr(entity, attr, None): + return str(getattr(entity, attr)) + return str(entity) - def get_additional_context(self, content_object, content_type: str) -> dict[str, Any]: + def get_additional_context(self, entity, entity_type: str) -> dict[str, Any]: """Get additional context information based on content type. Args: - content_object: The source object. - content_type: The model name of the content object. + entity: The source object. + entity_type: The model name of the content object. Returns: A dictionary with additional context information. """ context = {} - clean_content_type = content_type.split(".")[-1] if "." in content_type else content_type + clean_content_type = entity_type.split(".")[-1] if "." in entity_type else entity_type if clean_content_type == "chapter": context.update( { - "location": getattr(content_object, "suggested_location", None), - "region": getattr(content_object, "region", None), - "country": getattr(content_object, "country", None), - "postal_code": getattr(content_object, "postal_code", None), - "currency": getattr(content_object, "currency", None), - "meetup_group": getattr(content_object, "meetup_group", None), - "tags": getattr(content_object, "tags", []), - "topics": getattr(content_object, "topics", []), - "leaders": getattr(content_object, "leaders_raw", []), - "related_urls": getattr(content_object, "related_urls", []), - "is_active": getattr(content_object, "is_active", None), - "url": getattr(content_object, "url", None), + "location": getattr(entity, "suggested_location", None), + "region": getattr(entity, "region", None), + "country": getattr(entity, "country", None), + "postal_code": getattr(entity, "postal_code", None), + "currency": getattr(entity, "currency", None), + "meetup_group": getattr(entity, "meetup_group", None), + "tags": getattr(entity, "tags", []), + "topics": getattr(entity, "topics", []), + "leaders": getattr(entity, "leaders_raw", []), + "related_urls": getattr(entity, "related_urls", []), + "is_active": getattr(entity, "is_active", None), + "url": getattr(entity, "url", None), } ) elif clean_content_type == "project": context.update( { - "level": getattr(content_object, "level", None), - "project_type": getattr(content_object, "type", None), - "languages": getattr(content_object, "languages", []), - "topics": getattr(content_object, "topics", []), - "licenses": getattr(content_object, "licenses", []), - "tags": getattr(content_object, "tags", []), - "custom_tags": getattr(content_object, "custom_tags", []), - "stars_count": getattr(content_object, "stars_count", None), - "forks_count": getattr(content_object, "forks_count", None), - "contributors_count": getattr(content_object, "contributors_count", None), - "releases_count": getattr(content_object, "releases_count", None), - "open_issues_count": getattr(content_object, "open_issues_count", None), - "leaders": getattr(content_object, "leaders_raw", []), - "related_urls": getattr(content_object, "related_urls", []), - "created_at": getattr(content_object, "created_at", None), - "updated_at": getattr(content_object, "updated_at", None), - "released_at": getattr(content_object, "released_at", None), - "health_score": getattr(content_object, "health_score", None), - "is_active": getattr(content_object, "is_active", None), - "track_issues": getattr(content_object, "track_issues", None), - "url": getattr(content_object, "url", None), + "level": getattr(entity, "level", None), + "project_type": getattr(entity, "type", None), + "languages": getattr(entity, "languages", []), + "topics": getattr(entity, "topics", []), + "licenses": getattr(entity, "licenses", []), + "tags": getattr(entity, "tags", []), + "custom_tags": getattr(entity, "custom_tags", []), + "stars_count": getattr(entity, "stars_count", None), + "forks_count": getattr(entity, "forks_count", None), + "contributors_count": getattr(entity, "contributors_count", None), + "releases_count": getattr(entity, "releases_count", None), + "open_issues_count": getattr(entity, "open_issues_count", None), + "leaders": getattr(entity, "leaders_raw", []), + "related_urls": getattr(entity, "related_urls", []), + "created_at": getattr(entity, "created_at", None), + "updated_at": getattr(entity, "updated_at", None), + "released_at": getattr(entity, "released_at", None), + "health_score": getattr(entity, "health_score", None), + "is_active": getattr(entity, "is_active", None), + "track_issues": getattr(entity, "track_issues", None), + "url": getattr(entity, "url", None), } ) elif clean_content_type == "event": context.update( { - "start_date": getattr(content_object, "start_date", None), - "end_date": getattr(content_object, "end_date", None), - "location": getattr(content_object, "suggested_location", None), - "category": getattr(content_object, "category", None), - "latitude": getattr(content_object, "latitude", None), - "longitude": getattr(content_object, "longitude", None), - "url": getattr(content_object, "url", None), - "description": getattr(content_object, "description", None), - "summary": getattr(content_object, "summary", None), + "start_date": getattr(entity, "start_date", None), + "end_date": getattr(entity, "end_date", None), + "location": getattr(entity, "suggested_location", None), + "category": getattr(entity, "category", None), + "latitude": getattr(entity, "latitude", None), + "longitude": getattr(entity, "longitude", None), + "url": getattr(entity, "url", None), + "description": getattr(entity, "description", None), + "summary": getattr(entity, "summary", None), } ) elif clean_content_type == "committee": context.update( { - "is_active": getattr(content_object, "is_active", None), - "leaders": getattr(content_object, "leaders", []), - "url": getattr(content_object, "url", None), - "description": getattr(content_object, "description", None), - "summary": getattr(content_object, "summary", None), - "tags": getattr(content_object, "tags", []), - "topics": getattr(content_object, "topics", []), - "related_urls": getattr(content_object, "related_urls", []), + "is_active": getattr(entity, "is_active", None), + "leaders": getattr(entity, "leaders", []), + "url": getattr(entity, "url", None), + "description": getattr(entity, "description", None), + "summary": getattr(entity, "summary", None), + "tags": getattr(entity, "tags", []), + "topics": getattr(entity, "topics", []), + "related_urls": getattr(entity, "related_urls", []), } ) elif clean_content_type == "message": context.update( { "channel": ( - getattr(content_object.conversation, "slack_channel_id", None) - if hasattr(content_object, "conversation") and content_object.conversation + getattr(entity.conversation, "slack_channel_id", None) + if hasattr(entity, "conversation") and entity.conversation else None ), "thread_ts": ( - getattr(content_object.parent_message, "ts", None) - if hasattr(content_object, "parent_message") - and content_object.parent_message + getattr(entity.parent_message, "ts", None) + if hasattr(entity, "parent_message") and entity.parent_message else None ), - "ts": getattr(content_object, "ts", None), + "ts": getattr(entity, "ts", None), "user": ( - getattr(content_object.author, "name", None) - if hasattr(content_object, "author") and content_object.author + getattr(entity.author, "name", None) + if hasattr(entity, "author") and entity.author else None ), } @@ -209,33 +208,33 @@ def retrieve( if "." in lower_name: app_label, model = lower_name.split(".", 1) content_type_query |= Q( - context__content_type__app_label=app_label, - context__content_type__model=model, + context__entity_type__app_label=app_label, + context__entity_type__model=model, ) else: - content_type_query |= Q(context__content_type__model=lower_name) + content_type_query |= Q(context__entity_type__model=lower_name) queryset = queryset.filter(content_type_query) - chunks = queryset.select_related("context__content_type").order_by("-similarity")[:limit] + chunks = queryset.select_related("context__entity_type").order_by("-similarity")[:limit] results = [] for chunk in chunks: - if not chunk.context or not chunk.context.content_object: + if not chunk.context or not chunk.context.entity: logger.warning("Content object is None for chunk %s. Skipping.", chunk.id) continue - source_name = self.get_source_name(chunk.context.content_object) + source_name = self.get_source_name(chunk.context.entity) additional_context = self.get_additional_context( - chunk.context.content_object, chunk.context.content_type.model + chunk.context.entity, chunk.context.entity_type.model ) results.append( { "text": chunk.text, "similarity": float(chunk.similarity), - "source_type": chunk.context.content_type.model, + "source_type": chunk.context.entity_type.model, "source_name": source_name, - "source_id": chunk.context.object_id, + "source_id": chunk.context.entity_id, "additional_context": additional_context, } ) @@ -255,9 +254,9 @@ def extract_content_types_from_query(self, query: str) -> list[str]: query_words = set(re.findall(r"\b\w+\b", query.lower())) detected_types = [ - content_type - for content_type in self.SUPPORTED_CONTENT_TYPES - if content_type in query_words or f"{content_type}s" in query_words + entity_type + for entity_type in self.SUPPORTED_ENTITY_TYPES + if entity_type in query_words or f"{entity_type}s" in query_words ] if detected_types: diff --git a/backend/apps/ai/common/base/ai_command.py b/backend/apps/ai/common/base/ai_command.py index 62a1909279..fea78193af 100644 --- a/backend/apps/ai/common/base/ai_command.py +++ b/backend/apps/ai/common/base/ai_command.py @@ -1,7 +1,6 @@ """Base AI command class with common functionality.""" import os -from abc import ABC, abstractmethod from collections.abc import Callable from typing import Any @@ -10,41 +9,26 @@ from django.db.models import Model, QuerySet -class BaseAICommand(BaseCommand, ABC): +class BaseAICommand(BaseCommand): """Base class for AI management commands with common functionality.""" + model_class: type[Model] + entity_name: str + entity_name_plural: str + key_field_name: str + def __init__(self, *args, **kwargs): """Initialize the AI command with OpenAI client placeholder.""" super().__init__(*args, **kwargs) self.openai_client: openai.OpenAI | None = None - @abstractmethod - def model_class(self) -> type[Model]: - """Return the Django model class this command operates on.""" - - @abstractmethod - def entity_name(self) -> str: - """Return the human-readable name for the entity (e.g., 'chapter', 'project').""" - - @abstractmethod - def entity_name_plural(self) -> str: - """Return the plural form of the entity name.""" - - @abstractmethod - def key_field_name(self) -> str: - """Return the field name used for filtering by key (e.g., 'key', 'slug').""" - - @abstractmethod - def extract_content(self, entity: Model) -> tuple[str, str]: - """Extract content from the entity. Return (prose_content, metadata_content).""" - def source_name(self) -> str: """Return the source name for context creation. Override if different from default.""" - return f"owasp_{self.entity_name()}" + return f"owasp_{self.entity_name}" def get_base_queryset(self) -> QuerySet: """Return the base queryset. Override for custom filtering logic.""" - return self.model_class().objects.all() + return self.model_class.objects.all() def get_default_queryset(self) -> QuerySet: """Return the default queryset when no specific options are provided.""" @@ -53,20 +37,20 @@ def get_default_queryset(self) -> QuerySet: def add_common_arguments(self, parser): """Add common arguments that most commands need.""" parser.add_argument( - f"--{self.entity_name()}-key", + f"--{self.entity_name}-key", type=str, - help=f"Process only the {self.entity_name()} with this key", + help=f"Process only the {self.entity_name} with this key", ) parser.add_argument( "--all", action="store_true", - help=f"Process all the {self.entity_name_plural()}", + help=f"Process all the {self.entity_name_plural}", ) parser.add_argument( "--batch-size", type=int, default=50, - help=f"Number of {self.entity_name_plural()} to process in each batch", + help=f"Number of {self.entity_name_plural} to process in each batch", ) def add_arguments(self, parser): @@ -75,10 +59,10 @@ def add_arguments(self, parser): def get_queryset(self, options: dict[str, Any]) -> QuerySet: """Get the queryset based on command options.""" - key_option = f"{self.entity_name()}_key" + key_option = f"{self.entity_name}_key" if options.get(key_option): - filter_kwargs = {self.key_field_name(): options[key_option]} + filter_kwargs = {self.key_field_name: options[key_option]} return self.get_base_queryset().filter(**filter_kwargs) if options.get("all"): return self.get_base_queryset() @@ -86,7 +70,7 @@ def get_queryset(self, options: dict[str, Any]) -> QuerySet: def get_entity_key(self, entity: Model) -> str: """Get the key/identifier for an entity for display purposes.""" - return str(getattr(entity, self.key_field_name(), entity.pk)) + return str(getattr(entity, self.key_field_name, entity.pk)) def setup_openai_client(self) -> bool: """Set up OpenAI client if API key is available.""" @@ -108,10 +92,10 @@ def handle_batch_processing( total_count = queryset.count() if not total_count: - self.stdout.write(f"No {self.entity_name_plural()} found to process") + self.stdout.write(f"No {self.entity_name_plural} found to process") return - self.stdout.write(f"Found {total_count} {self.entity_name_plural()} to process") + self.stdout.write(f"Found {total_count} {self.entity_name_plural} to process") processed_count = 0 for offset in range(0, total_count, batch_size): @@ -120,6 +104,6 @@ def handle_batch_processing( self.stdout.write( self.style.SUCCESS( - f"Completed processing {processed_count}/{total_count} {self.entity_name_plural()}" + f"Completed processing {processed_count}/{total_count} {self.entity_name_plural}" ) ) diff --git a/backend/apps/ai/common/base/chunk_command.py b/backend/apps/ai/common/base/chunk_command.py index a8b84174fa..272e50a460 100644 --- a/backend/apps/ai/common/base/chunk_command.py +++ b/backend/apps/ai/common/base/chunk_command.py @@ -14,24 +14,22 @@ class BaseChunkCommand(BaseAICommand): def help(self) -> str: """Return help text for the chunk creation command.""" - return f"Create chunks for OWASP {self.entity_name()} data" + return f"Create chunks for OWASP {self.entity_name} data" def process_chunks_batch(self, entities: list[Model]) -> int: """Process a batch of entities to create chunks.""" processed = 0 batch_chunks = [] - content_type = ContentType.objects.get_for_model(self.model_class()) + content_type = ContentType.objects.get_for_model(self.model_class) for entity in entities: - context = Context.objects.filter( - content_type=content_type, object_id=entity.id - ).first() + context = Context.objects.filter(entity_type=content_type, entity_id=entity.id).first() entity_key = self.get_entity_key(entity) if not context: self.stdout.write( - self.style.WARNING(f"No context found for {self.entity_name()} {entity_key}") + self.style.WARNING(f"No context found for {self.entity_name} {entity_key}") ) continue @@ -41,12 +39,12 @@ def process_chunks_batch(self, entities: list[Model]) -> int: ) if not full_content.strip(): - self.stdout.write(f"No content to chunk for {self.entity_name()} {entity_key}") + self.stdout.write(f"No content to chunk for {self.entity_name} {entity_key}") continue chunk_texts = Chunk.split_text(full_content) if not chunk_texts: - self.stdout.write(f"No chunks created for {self.entity_name()} {entity_key}") + self.stdout.write(f"No chunks created for {self.entity_name} {entity_key}") continue if chunks := create_chunks_and_embeddings( diff --git a/backend/apps/ai/common/base/context_command.py b/backend/apps/ai/common/base/context_command.py index ee4f64dfd0..4a3c75b0c7 100644 --- a/backend/apps/ai/common/base/context_command.py +++ b/backend/apps/ai/common/base/context_command.py @@ -11,7 +11,7 @@ class BaseContextCommand(BaseAICommand): def help(self) -> str: """Return help text for the context creation command.""" - return f"Update context for OWASP {self.entity_name()} data" + return f"Update context for OWASP {self.entity_name} data" def process_context_batch(self, entities: list[Model]) -> int: """Process a batch of entities to create contexts.""" @@ -25,7 +25,7 @@ def process_context_batch(self, entities: list[Model]) -> int: if not full_content.strip(): entity_key = self.get_entity_key(entity) - self.stdout.write(f"No content for {self.entity_name()} {entity_key}") + self.stdout.write(f"No content for {self.entity_name} {entity_key}") continue if Context.update_data( diff --git a/backend/apps/ai/common/utils.py b/backend/apps/ai/common/utils.py index 97ab4a0abf..b9012ba252 100644 --- a/backend/apps/ai/common/utils.py +++ b/backend/apps/ai/common/utils.py @@ -40,6 +40,8 @@ def create_chunks_and_embeddings( ValueError: If context is None or invalid """ + from apps.ai.models.chunk import Chunk + try: last_request_time = datetime.now(UTC) - timedelta( seconds=DEFAULT_LAST_REQUEST_OFFSET_SECONDS @@ -66,3 +68,34 @@ def create_chunks_and_embeddings( return [] else: return chunks + + +def regenerate_chunks_for_context(context: Context): + """Regenerates all chunks for a single, specific context instance. + + Args: + context (Context): The specific context instance to be updated. + + """ + from apps.ai.models.chunk import Chunk + + old_chunk_count = context.chunks.count() + if old_chunk_count > 0: + context.chunks.all().delete() + + new_chunk_texts = Chunk.split_text(context.content) + + if not new_chunk_texts: + logger.warning("No content to chunk for Context. Process stopped.") + return + + openai_client = openai.Client() + + create_chunks_and_embeddings( + chunk_texts=new_chunk_texts, + context=context, + openai_client=openai_client, + save=True, + ) + + logger.info("Successfully completed chunk regeneration for new context") diff --git a/backend/apps/ai/management/commands/ai_create_chapter_chunks.py b/backend/apps/ai/management/commands/ai_create_chapter_chunks.py index bafd397c9f..9b48c4dba3 100644 --- a/backend/apps/ai/management/commands/ai_create_chapter_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_chapter_chunks.py @@ -1,25 +1,16 @@ """A command to create chunks of OWASP chapter data for RAG.""" -from django.db.models import Model - from apps.ai.common.base.chunk_command import BaseChunkCommand from apps.ai.common.extractors.chapter import extract_chapter_content from apps.owasp.models.chapter import Chapter class Command(BaseChunkCommand): - def entity_name(self) -> str: - return "chapter" - - def entity_name_plural(self) -> str: - return "chapters" + entity_name = "chapter" + entity_name_plural = "chapters" + key_field_name = "key" + model_class = Chapter def extract_content(self, entity: Chapter) -> tuple[str, str]: """Extract content from the chapter.""" return extract_chapter_content(entity) - - def key_field_name(self) -> str: - return "key" - - def model_class(self) -> type[Model]: - return Chapter diff --git a/backend/apps/ai/management/commands/ai_create_committee_chunks.py b/backend/apps/ai/management/commands/ai_create_committee_chunks.py index fee8bb5ea4..df03d69856 100644 --- a/backend/apps/ai/management/commands/ai_create_committee_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_committee_chunks.py @@ -1,25 +1,16 @@ """A command to create chunks of OWASP committee data for RAG.""" -from django.db.models import Model - from apps.ai.common.base.chunk_command import BaseChunkCommand from apps.ai.common.extractors.committee import extract_committee_content from apps.owasp.models.committee import Committee class Command(BaseChunkCommand): - def entity_name(self) -> str: - return "committee" - - def entity_name_plural(self) -> str: - return "committees" + entity_name = "committee" + entity_name_plural = "committees" + key_field_name = "key" + model_class = Committee def extract_content(self, entity: Committee) -> tuple[str, str]: """Extract content from the committee.""" return extract_committee_content(entity) - - def key_field_name(self) -> str: - return "key" - - def model_class(self) -> type[Model]: - return Committee diff --git a/backend/apps/ai/management/commands/ai_create_event_chunks.py b/backend/apps/ai/management/commands/ai_create_event_chunks.py index e2a499008e..40e3cda49b 100644 --- a/backend/apps/ai/management/commands/ai_create_event_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_event_chunks.py @@ -1,6 +1,6 @@ """A command to create chunks of OWASP event data for RAG.""" -from django.db.models import Model, QuerySet +from django.db.models import QuerySet from apps.ai.common.base.chunk_command import BaseChunkCommand from apps.ai.common.extractors.event import extract_event_content @@ -8,11 +8,10 @@ class Command(BaseChunkCommand): - def entity_name(self) -> str: - return "event" - - def entity_name_plural(self) -> str: - return "events" + entity_name = "event" + entity_name_plural = "events" + key_field_name = "key" + model_class = Event def extract_content(self, entity: Event) -> tuple[str, str]: """Extract content from the event.""" @@ -25,9 +24,3 @@ def get_base_queryset(self) -> QuerySet: def get_default_queryset(self) -> QuerySet: """Return upcoming events by default instead of is_active filter.""" return Event.upcoming_events() - - def key_field_name(self) -> str: - return "key" - - def model_class(self) -> type[Model]: - return Event diff --git a/backend/apps/ai/management/commands/ai_create_project_chunks.py b/backend/apps/ai/management/commands/ai_create_project_chunks.py index 255b217558..673c28f252 100644 --- a/backend/apps/ai/management/commands/ai_create_project_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_project_chunks.py @@ -1,6 +1,6 @@ """A command to create chunks of OWASP project data for RAG.""" -from django.db.models import Model, QuerySet +from django.db.models import QuerySet from apps.ai.common.base.chunk_command import BaseChunkCommand from apps.ai.common.extractors.project import extract_project_content @@ -8,11 +8,10 @@ class Command(BaseChunkCommand): - def entity_name(self) -> str: - return "project" - - def entity_name_plural(self) -> str: - return "projects" + entity_name = "project" + entity_name_plural = "projects" + key_field_name = "key" + model_class = Project def extract_content(self, entity: Project) -> tuple[str, str]: """Extract content from the project.""" @@ -21,9 +20,3 @@ def extract_content(self, entity: Project) -> tuple[str, str]: def get_base_queryset(self) -> QuerySet: """Return the base queryset with ordering.""" return super().get_base_queryset() - - def key_field_name(self) -> str: - return "key" - - def model_class(self) -> type[Model]: - return Project diff --git a/backend/apps/ai/management/commands/ai_create_slack_message_chunks.py b/backend/apps/ai/management/commands/ai_create_slack_message_chunks.py index 31bedefa48..3f070a7600 100644 --- a/backend/apps/ai/management/commands/ai_create_slack_message_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_slack_message_chunks.py @@ -1,12 +1,17 @@ """A command to create chunks of Slack messages.""" -from django.db.models import Model, QuerySet +from django.db.models import QuerySet from apps.ai.common.base.chunk_command import BaseChunkCommand from apps.slack.models.message import Message class Command(BaseChunkCommand): + entity_name = "message" + entity_name_plural = "messages" + key_field_name = "slack_message_id" + model_class = Message + def add_arguments(self, parser): """Override to use different default batch size for messages.""" parser.add_argument( @@ -26,12 +31,6 @@ def add_arguments(self, parser): help="Number of messages to process in each batch", ) - def entity_name(self) -> str: - return "message" - - def entity_name_plural(self) -> str: - return "messages" - def extract_content(self, entity: Message) -> tuple[str, str]: """Extract content from the message.""" return entity.cleaned_text or "", "" @@ -40,11 +39,5 @@ def get_default_queryset(self) -> QuerySet: """Return all messages by default since Message model doesn't have is_active field.""" return self.get_base_queryset() - def key_field_name(self) -> str: - return "slack_message_id" - - def model_class(self) -> type[Model]: - return Message - def source_name(self) -> str: return "slack_message" diff --git a/backend/apps/ai/management/commands/ai_create_chapter_context.py b/backend/apps/ai/management/commands/ai_update_chapter_context.py similarity index 60% rename from backend/apps/ai/management/commands/ai_create_chapter_context.py rename to backend/apps/ai/management/commands/ai_update_chapter_context.py index 377024de1b..550d05c862 100644 --- a/backend/apps/ai/management/commands/ai_create_chapter_context.py +++ b/backend/apps/ai/management/commands/ai_update_chapter_context.py @@ -1,25 +1,16 @@ """A command to update context for OWASP chapter data.""" -from django.db.models import Model - from apps.ai.common.base.context_command import BaseContextCommand from apps.ai.common.extractors.chapter import extract_chapter_content from apps.owasp.models.chapter import Chapter class Command(BaseContextCommand): - def entity_name(self) -> str: - return "chapter" - - def entity_name_plural(self) -> str: - return "chapters" + entity_name = "chapter" + entity_name_plural = "chapters" + key_field_name = "key" + model_class = Chapter def extract_content(self, entity: Chapter) -> tuple[str, str]: """Extract content from the chapter.""" return extract_chapter_content(entity) - - def key_field_name(self) -> str: - return "key" - - def model_class(self) -> type[Model]: - return Chapter diff --git a/backend/apps/ai/management/commands/ai_create_committee_context.py b/backend/apps/ai/management/commands/ai_update_committee_context.py similarity index 60% rename from backend/apps/ai/management/commands/ai_create_committee_context.py rename to backend/apps/ai/management/commands/ai_update_committee_context.py index de3965196e..405d1f31f5 100644 --- a/backend/apps/ai/management/commands/ai_create_committee_context.py +++ b/backend/apps/ai/management/commands/ai_update_committee_context.py @@ -1,25 +1,16 @@ """A command to update context for OWASP committee data.""" -from django.db.models import Model - from apps.ai.common.base.context_command import BaseContextCommand from apps.ai.common.extractors.committee import extract_committee_content from apps.owasp.models.committee import Committee class Command(BaseContextCommand): - def entity_name(self) -> str: - return "committee" - - def entity_name_plural(self) -> str: - return "committees" + entity_name = "committee" + entity_name_plural = "committees" + key_field_name = "key" + model_class = Committee def extract_content(self, entity: Committee) -> tuple[str, str]: """Extract content from the committee.""" return extract_committee_content(entity) - - def key_field_name(self) -> str: - return "key" - - def model_class(self) -> type[Model]: - return Committee diff --git a/backend/apps/ai/management/commands/ai_create_event_context.py b/backend/apps/ai/management/commands/ai_update_event_context.py similarity index 71% rename from backend/apps/ai/management/commands/ai_create_event_context.py rename to backend/apps/ai/management/commands/ai_update_event_context.py index 49f20cb2b9..d07b942d44 100644 --- a/backend/apps/ai/management/commands/ai_create_event_context.py +++ b/backend/apps/ai/management/commands/ai_update_event_context.py @@ -1,6 +1,6 @@ """A command to update context for OWASP event data.""" -from django.db.models import Model, QuerySet +from django.db.models import QuerySet from apps.ai.common.base.context_command import BaseContextCommand from apps.ai.common.extractors.event import extract_event_content @@ -8,11 +8,10 @@ class Command(BaseContextCommand): - def entity_name(self) -> str: - return "event" - - def entity_name_plural(self) -> str: - return "events" + entity_name = "event" + entity_name_plural = "events" + key_field_name = "key" + model_class = Event def extract_content(self, entity: Event) -> tuple[str, str]: """Extract content from the event.""" @@ -25,9 +24,3 @@ def get_base_queryset(self) -> QuerySet: def get_default_queryset(self) -> QuerySet: """Return upcoming events by default instead of is_active filter.""" return Event.upcoming_events() - - def key_field_name(self) -> str: - return "key" - - def model_class(self) -> type[Model]: - return Event diff --git a/backend/apps/ai/management/commands/ai_create_project_context.py b/backend/apps/ai/management/commands/ai_update_project_context.py similarity index 66% rename from backend/apps/ai/management/commands/ai_create_project_context.py rename to backend/apps/ai/management/commands/ai_update_project_context.py index 47e509f1e6..aa594085c6 100644 --- a/backend/apps/ai/management/commands/ai_create_project_context.py +++ b/backend/apps/ai/management/commands/ai_update_project_context.py @@ -1,6 +1,6 @@ """A command to update context for OWASP project data.""" -from django.db.models import Model, QuerySet +from django.db.models import QuerySet from apps.ai.common.base.context_command import BaseContextCommand from apps.ai.common.extractors.project import extract_project_content @@ -8,11 +8,10 @@ class Command(BaseContextCommand): - def entity_name(self) -> str: - return "project" - - def entity_name_plural(self) -> str: - return "projects" + entity_name = "project" + entity_name_plural = "projects" + key_field_name = "key" + model_class = Project def extract_content(self, entity: Project) -> tuple[str, str]: """Extract content from the project.""" @@ -21,9 +20,3 @@ def extract_content(self, entity: Project) -> tuple[str, str]: def get_base_queryset(self) -> QuerySet: """Return the base queryset with ordering.""" return super().get_base_queryset() - - def key_field_name(self) -> str: - return "key" - - def model_class(self) -> type[Model]: - return Project diff --git a/backend/apps/ai/management/commands/ai_create_slack_message_context.py b/backend/apps/ai/management/commands/ai_update_slack_message_context.py similarity index 79% rename from backend/apps/ai/management/commands/ai_create_slack_message_context.py rename to backend/apps/ai/management/commands/ai_update_slack_message_context.py index ecf8b28c5e..b6770e1d0e 100644 --- a/backend/apps/ai/management/commands/ai_create_slack_message_context.py +++ b/backend/apps/ai/management/commands/ai_update_slack_message_context.py @@ -1,12 +1,17 @@ """A command to update context for Slack message data.""" -from django.db.models import Model, QuerySet +from django.db.models import QuerySet from apps.ai.common.base.context_command import BaseContextCommand from apps.slack.models.message import Message class Command(BaseContextCommand): + entity_name = "message" + entity_name_plural = "messages" + key_field_name = "slack_message_id" + model_class = Message + def add_arguments(self, parser): """Override to use different default batch size for messages.""" parser.add_argument( @@ -26,12 +31,6 @@ def add_arguments(self, parser): help="Number of messages to process in each batch", ) - def entity_name(self) -> str: - return "message" - - def entity_name_plural(self) -> str: - return "messages" - def extract_content(self, entity: Message) -> tuple[str, str]: """Extract content from the message.""" return entity.cleaned_text or "", "" @@ -40,11 +39,5 @@ def get_default_queryset(self) -> QuerySet: """Return all messages by default since Message model doesn't have is_active field.""" return self.get_base_queryset() - def key_field_name(self) -> str: - return "slack_message_id" - - def model_class(self) -> type[Model]: - return Message - def source_name(self) -> str: return "slack_message" diff --git a/backend/apps/ai/migrations/0009_rename_object_id_context_entity_id_and_more.py b/backend/apps/ai/migrations/0009_rename_object_id_context_entity_id_and_more.py new file mode 100644 index 0000000000..e8f3b27aaa --- /dev/null +++ b/backend/apps/ai/migrations/0009_rename_object_id_context_entity_id_and_more.py @@ -0,0 +1,27 @@ +# Generated by Django 5.2.5 on 2025-08-15 08:40 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("ai", "0008_alter_context_unique_together_and_more"), + ("contenttypes", "0002_remove_content_type_name"), + ] + + operations = [ + migrations.RenameField( + model_name="context", + old_name="object_id", + new_name="entity_id", + ), + migrations.RenameField( + model_name="context", + old_name="content_type", + new_name="entity_type", + ), + migrations.AlterUniqueTogether( + name="context", + unique_together={("entity_type", "entity_id")}, + ), + ] diff --git a/backend/apps/ai/models/chunk.py b/backend/apps/ai/models/chunk.py index 9b57844c81..8dfcaf0022 100644 --- a/backend/apps/ai/models/chunk.py +++ b/backend/apps/ai/models/chunk.py @@ -62,8 +62,8 @@ def update_data( """ if Chunk.objects.filter( - context__content_type=context.content_type, - context__object_id=context.object_id, + context__entity_type=context.entity_type, + context__entity_id=context.entity_id, text=text, ).exists(): return None diff --git a/backend/apps/ai/models/context.py b/backend/apps/ai/models/context.py index 6031c2fec4..79d6e796c0 100644 --- a/backend/apps/ai/models/context.py +++ b/backend/apps/ai/models/context.py @@ -1,34 +1,45 @@ """AI app context model.""" +import logging + from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.models import ContentType from django.db import models from apps.common.models import TimestampedModel +logger = logging.getLogger(__name__) + + +def regenerate_chunks_for_context(context): + """Import regenerate_chunks_for_context to avoid circular import.""" + from apps.ai.common.utils import regenerate_chunks_for_context as _regenerate_chunks + + return _regenerate_chunks(context) + class Context(TimestampedModel): """Context model for storing generated text related to OWASP entities.""" content = models.TextField(verbose_name="Generated Text") - content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE) - object_id = models.PositiveIntegerField() - content_object = GenericForeignKey("content_type", "object_id") + entity_type = models.ForeignKey(ContentType, on_delete=models.CASCADE) + entity_id = models.PositiveIntegerField() + entity = GenericForeignKey("entity_type", "entity_id") source = models.CharField(max_length=100, blank=True, default="") class Meta: db_table = "ai_contexts" verbose_name = "Context" - unique_together = ("content_type", "object_id", "content") + unique_together = ("entity_type", "entity_id") def __str__(self): """Human readable representation.""" entity = ( - getattr(self.content_object, "name", None) - or getattr(self.content_object, "key", None) - or str(self.content_object) + getattr(self.entity, "name", None) + or getattr(self.entity, "key", None) + or str(self.entity) ) - return f"{self.content_type.model} {entity}: {self.content[:50]}" + return f"{self.entity_type.model} {entity}: {self.content[:50]}" @staticmethod def update_data( @@ -38,18 +49,18 @@ def update_data( *, save: bool = True, ) -> "Context": - """Retrieve existing or create new context.""" - content_type = ContentType.objects.get_for_model(content_object) - object_id = content_object.pk - existing_context = Context.objects.filter( - content_type=content_type, object_id=object_id, content=content - ).first() - if existing_context: - return existing_context - - context = Context(content=content, content_object=content_object, source=source) - - if save: - context.save() + """Create or update context for a given entity.""" + context, created = Context.objects.get_or_create( + entity_type=ContentType.objects.get_for_model(content_object), + entity_id=content_object.pk, + defaults={"content": content, "source": source}, + ) + + if not created and (context.content != content or context.source != source): + context.content = content + context.source = source + if save: + context.save(update_fields=["content", "source"]) + regenerate_chunks_for_context(context=context) return context diff --git a/backend/tests/apps/ai/agent/tools/rag/retriever_test.py b/backend/tests/apps/ai/agent/tools/rag/retriever_test.py index dff1563737..fe30aff809 100644 --- a/backend/tests/apps/ai/agent/tools/rag/retriever_test.py +++ b/backend/tests/apps/ai/agent/tools/rag/retriever_test.py @@ -418,7 +418,7 @@ def test_extract_content_types_from_query_no_matches(self): def test_supported_content_types(self): """Test that supported content types are defined correctly.""" - assert Retriever.SUPPORTED_CONTENT_TYPES == ( + assert Retriever.SUPPORTED_ENTITY_TYPES == ( "event", "project", "chapter", @@ -475,13 +475,14 @@ def test_retrieve_successful_with_chunks(self, mock_chunk): mock_content_object.name = "Test Chapter" mock_content_object.suggested_location = "New York" - mock_content_type = MagicMock() - mock_content_type.model = "chapter" + mock_entity_type = MagicMock() + mock_entity_type.model = "chapter" mock_context = MagicMock() mock_context.content_object = mock_content_object - mock_context.content_type = mock_content_type - mock_context.object_id = "123" + mock_context.entity = mock_content_object + mock_context.entity_type = mock_entity_type + mock_context.entity_id = "123" mock_chunk_instance = MagicMock() mock_chunk_instance.id = 1 diff --git a/backend/tests/apps/ai/common/base/ai_command_test.py b/backend/tests/apps/ai/common/base/ai_command_test.py index 2c634d47f4..a9493179a1 100644 --- a/backend/tests/apps/ai/common/base/ai_command_test.py +++ b/backend/tests/apps/ai/common/base/ai_command_test.py @@ -1,102 +1,76 @@ -"""Tests for the BaseAICommand class.""" - import os from unittest.mock import Mock, patch import pytest from django.core.management.base import BaseCommand -from django.db.models import Model, QuerySet from apps.ai.common.base.ai_command import BaseAICommand -class MockTestModel(Model): - """Test model for BaseAICommand testing.""" - - def __str__(self): - """Return string representation of MockTestModel.""" - return f"MockTestModel(pk={self.pk})" - - class Meta: - """Meta class for MockTestModel.""" - - app_label = "test" - - -class ConcreteAICommand(BaseAICommand): - """Concrete implementation of BaseAICommand for testing.""" - - def model_class(self): - return MockTestModel - - def entity_name(self): - return "test_entity" - - def entity_name_plural(self): - return "test_entities" +class MockTestModel: + """Mock model for testing.""" - def key_field_name(self): - return "test_key" - - def extract_content(self, entity): - return ("prose content", "metadata content") + objects = Mock() + pk = 1 @pytest.fixture def command(): - """Return a concrete command instance for testing.""" + """Fixture for ConcreteAICommand instance.""" return ConcreteAICommand() @pytest.fixture def mock_entity(): - """Return a mock entity instance.""" - entity = Mock(spec=MockTestModel) - entity.pk = 1 + """Fixture for mock entity with test_key attribute.""" + entity = Mock() entity.test_key = "test-key-123" - entity.is_active = True + entity.pk = 42 return entity @pytest.fixture def mock_queryset(): - """Return a mock queryset.""" - queryset = Mock(spec=QuerySet) - queryset.count.return_value = 5 - queryset.filter.return_value = queryset - queryset.__getitem__ = Mock(return_value=[]) - return queryset + """Fixture for mock queryset.""" + return Mock() + + +class ConcreteAICommand(BaseAICommand): + """Concrete implementation of BaseAICommand for testing.""" + + model_class = MockTestModel + entity_name = "test_entity" + entity_name_plural = "test_entities" + key_field_name = "test_key" + + def extract_content(self, entity): + return ("prose content", "metadata content") class TestBaseAICommand: """Test suite for the BaseAICommand class.""" def test_command_inheritance(self, command): - """Test that BaseAICommand inherits from BaseCommand.""" assert isinstance(command, BaseCommand) def test_initialization(self, command): - """Test command initialization.""" assert command.openai_client is None - def test_abstract_methods_implemented(self, command): - """Test that all abstract methods are properly implemented.""" - assert command.model_class() == MockTestModel - assert command.entity_name() == "test_entity" - assert command.entity_name_plural() == "test_entities" - assert command.key_field_name() == "test_key" + def test_abstract_attributes_implemented(self, command): + assert command.model_class == MockTestModel + assert command.entity_name == "test_entity" + assert command.entity_name_plural == "test_entities" + assert command.key_field_name == "test_key" mock_entity = Mock() result = command.extract_content(mock_entity) assert result == ("prose content", "metadata content") def test_source_name_default(self, command): - """Test default source_name implementation.""" result = command.source_name() assert result == "owasp_test_entity" def test_get_base_queryset(self, command): - """Test get_base_queryset method.""" with patch.object(MockTestModel, "objects") as mock_objects: mock_manager = Mock() mock_objects.all.return_value = mock_manager @@ -106,7 +80,6 @@ def test_get_base_queryset(self, command): mock_objects.all.assert_called_once() def test_get_default_queryset(self, command): - """Test get_default_queryset method.""" with patch.object(command, "get_base_queryset") as mock_base_qs: mock_queryset = Mock() mock_filtered_qs = Mock() @@ -120,7 +93,6 @@ def test_get_default_queryset(self, command): assert result == mock_filtered_qs def test_add_common_arguments(self, command): - """Test add_common_arguments method.""" parser = Mock() command.add_common_arguments(parser) @@ -143,7 +115,6 @@ def test_add_common_arguments(self, command): assert "Number of test_entities to process in each batch" in calls[2][1]["help"] def test_add_arguments_calls_common(self, command): - """Test that add_arguments calls add_common_arguments.""" parser = Mock() with patch.object(command, "add_common_arguments") as mock_add_common: @@ -151,7 +122,6 @@ def test_add_arguments_calls_common(self, command): mock_add_common.assert_called_once_with(parser) def test_get_queryset_with_entity_key(self, command): - """Test get_queryset with entity key option.""" options = {"test_entity_key": "test-key-123"} with patch.object(command, "get_base_queryset") as mock_base_qs: @@ -167,7 +137,6 @@ def test_get_queryset_with_entity_key(self, command): assert result == mock_filtered_qs def test_get_queryset_with_all_option(self, command): - """Test get_queryset with all option.""" options = {"all": True} with patch.object(command, "get_base_queryset") as mock_base_qs: @@ -180,7 +149,6 @@ def test_get_queryset_with_all_option(self, command): assert result == mock_queryset def test_get_queryset_default(self, command): - """Test get_queryset with default options.""" options = {} with patch.object(command, "get_default_queryset") as mock_default_qs: @@ -193,12 +161,10 @@ def test_get_queryset_default(self, command): assert result == mock_queryset def test_get_entity_key_with_key_field(self, command, mock_entity): - """Test get_entity_key with existing key field.""" result = command.get_entity_key(mock_entity) assert result == "test-key-123" def test_get_entity_key_fallback_to_pk(self, command): - """Test get_entity_key falls back to pk when key field doesn't exist.""" entity = Mock() entity.pk = 42 if hasattr(entity, "test_key"): @@ -210,7 +176,6 @@ def test_get_entity_key_fallback_to_pk(self, command): @patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-api-key"}) @patch("apps.ai.common.base.ai_command.openai.OpenAI") def test_setup_openai_client_success(self, mock_openai_class, command): - """Test successful OpenAI client setup.""" mock_client = Mock() mock_openai_class.return_value = mock_client @@ -222,7 +187,6 @@ def test_setup_openai_client_success(self, mock_openai_class, command): @patch.dict(os.environ, {}, clear=True) def test_setup_openai_client_no_api_key(self, command): - """Test OpenAI client setup without API key.""" if "DJANGO_OPEN_AI_SECRET_KEY" in os.environ: del os.environ["DJANGO_OPEN_AI_SECRET_KEY"] @@ -236,7 +200,6 @@ def test_setup_openai_client_no_api_key(self, command): assert "DJANGO_OPEN_AI_SECRET_KEY environment variable not set" in str(call_args) def test_handle_batch_processing_empty_queryset(self, command, mock_queryset): - """Test handle_batch_processing with empty queryset.""" mock_queryset.count.return_value = 0 process_batch_func = Mock() @@ -247,7 +210,6 @@ def test_handle_batch_processing_empty_queryset(self, command, mock_queryset): process_batch_func.assert_not_called() def test_handle_batch_processing_with_data(self, command): - """Test handle_batch_processing with data.""" mock_entities = [Mock() for _ in range(15)] mock_queryset = Mock() @@ -278,7 +240,6 @@ def mock_getitem(slice_obj): assert "Completed processing 15/15 test_entities" in str(write_calls[1]) def test_handle_batch_processing_partial_processing(self, command): - """Test handle_batch_processing when some items fail to process.""" mock_entities = [Mock() for _ in range(10)] mock_queryset = Mock() @@ -302,21 +263,3 @@ def mock_getitem(slice_obj): assert len(write_calls) == 2 assert "Found 10 test_entities to process" in str(write_calls[0]) assert "Completed processing 5/10 test_entities" in str(write_calls[1]) - - -class TestBaseAICommandAbstractMethods: - """Test that BaseAICommand abstract methods raise errors when not implemented.""" - - def test_cannot_instantiate_base_class_directly(self): - """Test that BaseAICommand cannot be instantiated directly.""" - with pytest.raises(TypeError): - BaseAICommand() - - def test_abstract_methods_must_be_implemented(self): - """Test that subclasses must implement all abstract methods.""" - - class IncompleteCommand(BaseAICommand): - """Incomplete implementation missing required methods.""" - - with pytest.raises(TypeError): - IncompleteCommand() diff --git a/backend/tests/apps/ai/common/base/chunk_command_test.py b/backend/tests/apps/ai/common/base/chunk_command_test.py index 852039b277..5ee912a589 100644 --- a/backend/tests/apps/ai/common/base/chunk_command_test.py +++ b/backend/tests/apps/ai/common/base/chunk_command_test.py @@ -1,5 +1,6 @@ """Tests for the BaseChunkCommand class.""" +from typing import Any from unittest.mock import Mock, patch import pytest @@ -14,28 +15,24 @@ class ConcreteChunkCommand(BaseChunkCommand): """Concrete implementation of BaseChunkCommand for testing.""" - def model_class(self): - mock_model = Mock() - mock_model.__name__ = "MockChunkTestModel" - return mock_model - - def entity_name(self): - return "test_entity" - - def entity_name_plural(self): - return "test_entities" - - def key_field_name(self): - return "test_key" + model_class: type[Any] = Mock # type: ignore[assignment] + entity_name = "test_entity" + entity_name_plural = "test_entities" + key_field_name = "test_key" def extract_content(self, entity): + """Extract content from entity.""" return ("prose content", "metadata content") @pytest.fixture def command(): """Return a concrete chunk command instance for testing.""" - return ConcreteChunkCommand() + cmd = ConcreteChunkCommand() + mock_model = Mock() + mock_model.__name__ = "MockChunkTestModel" + cmd.model_class = mock_model + return cmd @pytest.fixture @@ -93,11 +90,10 @@ def test_help_method(self, command): def test_abstract_methods_implemented(self, command): """Test that all abstract methods are properly implemented.""" - mock_model = command.model_class() - assert mock_model.__name__ == "MockChunkTestModel" - assert command.entity_name() == "test_entity" - assert command.entity_name_plural() == "test_entities" - assert command.key_field_name() == "test_key" + assert command.model_class.__name__ == "MockChunkTestModel" + assert command.entity_name == "test_entity" + assert command.entity_name_plural == "test_entities" + assert command.key_field_name == "test_key" mock_entity = Mock() result = command.extract_content(mock_entity) @@ -416,21 +412,3 @@ def test_process_chunks_batch_whitespace_only_content( mock_write.assert_called_once_with( "No content to chunk for test_entity test-key-123" ) - - -class TestBaseChunkCommandAbstractMethods: - """Test that BaseChunkCommand requires implementation of abstract methods.""" - - def test_cannot_instantiate_base_class_directly(self): - """Test that BaseChunkCommand cannot be instantiated directly.""" - with pytest.raises(TypeError): - BaseChunkCommand() - - def test_abstract_methods_must_be_implemented(self): - """Test that subclasses must implement all abstract methods.""" - - class IncompleteChunkCommand(BaseChunkCommand): - """Incomplete implementation missing required methods.""" - - with pytest.raises(TypeError): - IncompleteChunkCommand() diff --git a/backend/tests/apps/ai/common/base/context_command_test.py b/backend/tests/apps/ai/common/base/context_command_test.py index 9b88ef57c6..c267e074e4 100644 --- a/backend/tests/apps/ai/common/base/context_command_test.py +++ b/backend/tests/apps/ai/common/base/context_command_test.py @@ -1,5 +1,6 @@ """Tests for the BaseContextCommand class.""" +from typing import Any from unittest.mock import Mock, patch import pytest @@ -12,19 +13,10 @@ class ConcreteContextCommand(BaseContextCommand): """Concrete implementation of BaseContextCommand for testing.""" - def model_class(self): - mock_model = Mock() - mock_model.__name__ = "MockContextTestModel" - return mock_model - - def entity_name(self): - return "test_entity" - - def entity_name_plural(self): - return "test_entities" - - def key_field_name(self): - return "test_key" + model_class: type[Any] = Mock # type: ignore[assignment] + entity_name = "test_entity" + entity_name_plural = "test_entities" + key_field_name = "test_key" def extract_content(self, entity): return ("prose content", "metadata content") @@ -33,7 +25,11 @@ def extract_content(self, entity): @pytest.fixture def command(): """Return a concrete context command instance for testing.""" - return ConcreteContextCommand() + cmd = ConcreteContextCommand() + mock_model = Mock() + mock_model.__name__ = "MockContextTestModel" + cmd.model_class = mock_model + return cmd @pytest.fixture @@ -71,11 +67,10 @@ def test_help_method(self, command): def test_abstract_methods_implemented(self, command): """Test that all abstract methods are properly implemented.""" - mock_model = command.model_class() - assert mock_model.__name__ == "MockContextTestModel" - assert command.entity_name() == "test_entity" - assert command.entity_name_plural() == "test_entities" - assert command.key_field_name() == "test_key" + assert command.model_class.__name__ == "MockContextTestModel" + assert command.entity_name == "test_entity" + assert command.entity_name_plural == "test_entities" + assert command.key_field_name == "test_key" mock_entity = Mock() result = command.extract_content(mock_entity) @@ -306,21 +301,3 @@ def mock_extract_content(entity): write_calls = [str(call) for call in mock_write.call_args_list] assert any("No content for test_entity entity-2" in call for call in write_calls) - - -class TestBaseContextCommandAbstractMethods: - """Test that BaseContextCommand requires implementation of abstract methods.""" - - def test_cannot_instantiate_base_class_directly(self): - """Test that BaseContextCommand cannot be instantiated directly.""" - with pytest.raises(TypeError): - BaseContextCommand() - - def test_abstract_methods_must_be_implemented(self): - """Test that subclasses must implement all abstract methods.""" - - class IncompleteContextCommand(BaseContextCommand): - """Incomplete implementation missing required methods.""" - - with pytest.raises(TypeError): - IncompleteContextCommand() diff --git a/backend/tests/apps/ai/management/commands/ai_create_chapter_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_create_chapter_chunks_test.py index 5fd1c37077..e501580f9b 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_chapter_chunks_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_chapter_chunks_test.py @@ -26,16 +26,16 @@ def test_command_inheritance(self, command): def test_model_class_property(self, command): from apps.owasp.models.chapter import Chapter - assert command.model_class() == Chapter + assert command.model_class == Chapter def test_entity_name_property(self, command): - assert command.entity_name() == "chapter" + assert command.entity_name == "chapter" def test_entity_name_plural_property(self, command): - assert command.entity_name_plural() == "chapters" + assert command.entity_name_plural == "chapters" def test_key_field_name_property(self, command): - assert command.key_field_name() == "key" + assert command.key_field_name == "key" def test_extract_content(self, command, mock_chapter): with patch( diff --git a/backend/tests/apps/ai/management/commands/ai_create_committee_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_create_committee_chunks_test.py index c7380429fd..5d0a27f712 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_committee_chunks_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_committee_chunks_test.py @@ -37,19 +37,19 @@ def test_model_class_method(self, command): """Test the model_class method returns Committee.""" from apps.owasp.models.committee import Committee - assert command.model_class() == Committee + assert command.model_class == Committee def test_entity_name_method(self, command): """Test the entity_name method.""" - assert command.entity_name() == "committee" + assert command.entity_name == "committee" def test_entity_name_plural_method(self, command): """Test the entity_name_plural method.""" - assert command.entity_name_plural() == "committees" + assert command.entity_name_plural == "committees" def test_key_field_name_method(self, command): """Test the key_field_name method.""" - assert command.key_field_name() == "key" + assert command.key_field_name == "key" def test_extract_content_method(self, command, mock_committee): """Test the extract_content method.""" diff --git a/backend/tests/apps/ai/management/commands/ai_create_event_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_create_event_chunks_test.py index e74424dec1..b00cee7488 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_event_chunks_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_event_chunks_test.py @@ -34,19 +34,19 @@ def test_model_class_property(self, command): """Test the model_class property returns Event.""" from apps.owasp.models.event import Event - assert command.model_class() == Event + assert command.model_class == Event def test_entity_name_property(self, command): """Test the entity_name property.""" - assert command.entity_name() == "event" + assert command.entity_name == "event" def test_entity_name_plural_property(self, command): """Test the entity_name_plural property.""" - assert command.entity_name_plural() == "events" + assert command.entity_name_plural == "events" def test_key_field_name_property(self, command): """Test the key_field_name property.""" - assert command.key_field_name() == "key" + assert command.key_field_name == "key" def test_extract_content(self, command, mock_event): """Test content extraction from event.""" diff --git a/backend/tests/apps/ai/management/commands/ai_create_project_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_create_project_chunks_test.py index d80fc14eae..bc919abaa7 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_project_chunks_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_project_chunks_test.py @@ -26,16 +26,16 @@ def test_command_inheritance(self, command): def test_model_class_property(self, command): from apps.owasp.models.project import Project - assert command.model_class() == Project + assert command.model_class == Project def test_entity_name_property(self, command): - assert command.entity_name() == "project" + assert command.entity_name == "project" def test_entity_name_plural_property(self, command): - assert command.entity_name_plural() == "projects" + assert command.entity_name_plural == "projects" def test_key_field_name_property(self, command): - assert command.key_field_name() == "key" + assert command.key_field_name == "key" def test_extract_content(self, command, mock_project): with patch( diff --git a/backend/tests/apps/ai/management/commands/ai_create_slack_message_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_create_slack_message_chunks_test.py index fece9d9114..81ff0f73ae 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_slack_message_chunks_test.py +++ b/backend/tests/apps/ai/management/commands/ai_create_slack_message_chunks_test.py @@ -26,16 +26,16 @@ def test_command_inheritance(self, command): def test_model_class_property(self, command): from apps.slack.models.message import Message - assert command.model_class() == Message + assert command.model_class == Message def test_entity_name_property(self, command): - assert command.entity_name() == "message" + assert command.entity_name == "message" def test_entity_name_plural_property(self, command): - assert command.entity_name_plural() == "messages" + assert command.entity_name_plural == "messages" def test_key_field_name_property(self, command): - assert command.key_field_name() == "slack_message_id" + assert command.key_field_name == "slack_message_id" def test_source_name_property(self, command): """Test the source_name property.""" diff --git a/backend/tests/apps/ai/management/commands/ai_create_chapter_context_test.py b/backend/tests/apps/ai/management/commands/ai_update_chapter_context_test.py similarity index 83% rename from backend/tests/apps/ai/management/commands/ai_create_chapter_context_test.py rename to backend/tests/apps/ai/management/commands/ai_update_chapter_context_test.py index b25acaff0a..792e2f2f02 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_chapter_context_test.py +++ b/backend/tests/apps/ai/management/commands/ai_update_chapter_context_test.py @@ -4,7 +4,7 @@ import pytest -from apps.ai.management.commands.ai_create_chapter_context import Command +from apps.ai.management.commands.ai_update_chapter_context import Command @pytest.fixture @@ -35,24 +35,24 @@ def test_model_class_property(self, command): """Test the model_class property returns Chapter.""" from apps.owasp.models.chapter import Chapter - assert command.model_class() == Chapter + assert command.model_class == Chapter def test_entity_name_property(self, command): """Test the entity_name property.""" - assert command.entity_name() == "chapter" + assert command.entity_name == "chapter" def test_entity_name_plural_property(self, command): """Test the entity_name_plural property.""" - assert command.entity_name_plural() == "chapters" + assert command.entity_name_plural == "chapters" def test_key_field_name_property(self, command): """Test the key_field_name property.""" - assert command.key_field_name() == "key" + assert command.key_field_name == "key" def test_extract_content(self, command, mock_chapter): """Test the extract_content method.""" with patch( - "apps.ai.management.commands.ai_create_chapter_context.extract_chapter_content" + "apps.ai.management.commands.ai_update_chapter_context.extract_chapter_content" ) as mock_extract: mock_extract.return_value = ("prose content", "metadata content") content = command.extract_content(mock_chapter) diff --git a/backend/tests/apps/ai/management/commands/ai_create_committee_context_test.py b/backend/tests/apps/ai/management/commands/ai_update_committee_context_test.py similarity index 96% rename from backend/tests/apps/ai/management/commands/ai_create_committee_context_test.py rename to backend/tests/apps/ai/management/commands/ai_update_committee_context_test.py index 3d2090ed74..8cb4e302f2 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_committee_context_test.py +++ b/backend/tests/apps/ai/management/commands/ai_update_committee_context_test.py @@ -4,7 +4,7 @@ import pytest -from apps.ai.management.commands.ai_create_committee_context import Command +from apps.ai.management.commands.ai_update_committee_context import Command @pytest.fixture @@ -42,24 +42,24 @@ def test_model_class_method(self, command): """Test the model_class method returns Committee.""" from apps.owasp.models.committee import Committee - assert command.model_class() == Committee + assert command.model_class == Committee def test_entity_name_method(self, command): """Test the entity_name method.""" - assert command.entity_name() == "committee" + assert command.entity_name == "committee" def test_entity_name_plural_method(self, command): """Test the entity_name_plural method.""" - assert command.entity_name_plural() == "committees" + assert command.entity_name_plural == "committees" def test_key_field_name_method(self, command): """Test the key_field_name method.""" - assert command.key_field_name() == "key" + assert command.key_field_name == "key" def test_extract_content_method(self, command, mock_committee): """Test the extract_content method.""" with patch( - "apps.ai.management.commands.ai_create_committee_context.extract_committee_content" + "apps.ai.management.commands.ai_update_committee_context.extract_committee_content" ) as mock_extract: mock_extract.return_value = ("prose content", "metadata content") content = command.extract_content(mock_committee) diff --git a/backend/tests/apps/ai/management/commands/ai_create_event_context_test.py b/backend/tests/apps/ai/management/commands/ai_update_event_context_test.py similarity index 88% rename from backend/tests/apps/ai/management/commands/ai_create_event_context_test.py rename to backend/tests/apps/ai/management/commands/ai_update_event_context_test.py index 00a4a3fd53..8f4d2d12ce 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_event_context_test.py +++ b/backend/tests/apps/ai/management/commands/ai_update_event_context_test.py @@ -4,7 +4,7 @@ import pytest -from apps.ai.management.commands.ai_create_event_context import Command +from apps.ai.management.commands.ai_update_event_context import Command @pytest.fixture @@ -35,24 +35,24 @@ def test_model_class_property(self, command): """Test the model_class property returns Event.""" from apps.owasp.models.event import Event - assert command.model_class() == Event + assert command.model_class == Event def test_entity_name_property(self, command): """Test the entity_name property.""" - assert command.entity_name() == "event" + assert command.entity_name == "event" def test_entity_name_plural_property(self, command): """Test the entity_name_plural property.""" - assert command.entity_name_plural() == "events" + assert command.entity_name_plural == "events" def test_key_field_name_property(self, command): """Test the key_field_name property.""" - assert command.key_field_name() == "key" + assert command.key_field_name == "key" def test_extract_content(self, command, mock_event): """Test the extract_content method.""" with patch( - "apps.ai.management.commands.ai_create_event_context.extract_event_content" + "apps.ai.management.commands.ai_update_event_context.extract_event_content" ) as mock_extract: mock_extract.return_value = ("prose content", "metadata content") content = command.extract_content(mock_event) diff --git a/backend/tests/apps/ai/management/commands/ai_create_project_context_test.py b/backend/tests/apps/ai/management/commands/ai_update_project_context_test.py similarity index 85% rename from backend/tests/apps/ai/management/commands/ai_create_project_context_test.py rename to backend/tests/apps/ai/management/commands/ai_update_project_context_test.py index 80976f99d5..82df257634 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_project_context_test.py +++ b/backend/tests/apps/ai/management/commands/ai_update_project_context_test.py @@ -2,7 +2,7 @@ import pytest -from apps.ai.management.commands.ai_create_project_context import Command +from apps.ai.management.commands.ai_update_project_context import Command @pytest.fixture @@ -33,24 +33,24 @@ def test_model_class_property(self, command): """Test the model_class property returns Project.""" from apps.owasp.models.project import Project - assert command.model_class() == Project + assert command.model_class == Project def test_entity_name_property(self, command): """Test the entity_name property.""" - assert command.entity_name() == "project" + assert command.entity_name == "project" def test_entity_name_plural_property(self, command): """Test the entity_name_plural property.""" - assert command.entity_name_plural() == "projects" + assert command.entity_name_plural == "projects" def test_key_field_name_property(self, command): """Test the key_field_name property.""" - assert command.key_field_name() == "key" + assert command.key_field_name == "key" def test_extract_content(self, command, mock_project): """Test the extract_content method.""" with patch( - "apps.ai.management.commands.ai_create_project_context.extract_project_content" + "apps.ai.management.commands.ai_update_project_context.extract_project_content" ) as mock_extract: mock_extract.return_value = ("prose content", "metadata content") content = command.extract_content(mock_project) diff --git a/backend/tests/apps/ai/management/commands/ai_create_slack_message_context_test.py b/backend/tests/apps/ai/management/commands/ai_update_slack_message_context_test.py similarity index 91% rename from backend/tests/apps/ai/management/commands/ai_create_slack_message_context_test.py rename to backend/tests/apps/ai/management/commands/ai_update_slack_message_context_test.py index 961826dc42..ce93d42c87 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_slack_message_context_test.py +++ b/backend/tests/apps/ai/management/commands/ai_update_slack_message_context_test.py @@ -2,7 +2,7 @@ import pytest -from apps.ai.management.commands.ai_create_slack_message_context import Command +from apps.ai.management.commands.ai_update_slack_message_context import Command @pytest.fixture @@ -29,19 +29,19 @@ def test_model_class_property(self, command): """Test the model_class property returns Message.""" from apps.slack.models.message import Message - assert command.model_class() == Message + assert command.model_class == Message def test_entity_name_property(self, command): """Test the entity_name property.""" - assert command.entity_name() == "message" + assert command.entity_name == "message" def test_entity_name_plural_property(self, command): """Test the entity_name_plural property.""" - assert command.entity_name_plural() == "messages" + assert command.entity_name_plural == "messages" def test_key_field_name_property(self, command): """Test the key_field_name property.""" - assert command.key_field_name() == "slack_message_id" + assert command.key_field_name == "slack_message_id" def test_source_name_property(self, command): """Test the source_name property.""" diff --git a/backend/tests/apps/ai/models/chunk_test.py b/backend/tests/apps/ai/models/chunk_test.py index 6571aea959..d3c1fc61c9 100644 --- a/backend/tests/apps/ai/models/chunk_test.py +++ b/backend/tests/apps/ai/models/chunk_test.py @@ -10,6 +10,8 @@ def mock_context(): mock = Mock(spec=Context) mock.id = 1 + mock.entity_type = Mock() + mock.entity_id = 1 return mock @@ -87,8 +89,8 @@ def test_update_data_creates_new_chunk_and_saves(self, mock_context): ) mock_chunk_class.objects.filter.assert_called_once_with( - context__content_type=mock_context.content_type, - context__object_id=mock_context.object_id, + context__entity_type=mock_context.entity_type, + context__entity_id=mock_context.entity_id, text=text, ) mock_chunk_class.assert_called_once_with( @@ -112,8 +114,8 @@ def test_update_data_creates_new_chunk_no_save(self, mock_context): ) mock_chunk_class.objects.filter.assert_called_once_with( - context__content_type=mock_context.content_type, - context__object_id=mock_context.object_id, + context__entity_type=mock_context.entity_type, + context__entity_id=mock_context.entity_id, text=text, ) mock_chunk_class.assert_called_once_with( @@ -135,8 +137,8 @@ def test_update_data_returns_none_if_chunk_already_exists(self, mock_context): ) mock_chunk_class.objects.filter.assert_called_once_with( - context__content_type=mock_context.content_type, - context__object_id=mock_context.object_id, + context__entity_type=mock_context.entity_type, + context__entity_id=mock_context.entity_id, text=text, ) diff --git a/backend/tests/apps/ai/models/context_test.py b/backend/tests/apps/ai/models/context_test.py index 86dae0de21..fdb74379b7 100644 --- a/backend/tests/apps/ai/models/context_test.py +++ b/backend/tests/apps/ai/models/context_test.py @@ -1,10 +1,12 @@ """Unit tests for AI app context model.""" -from unittest.mock import Mock, patch +from unittest.mock import Mock, PropertyMock, patch import pytest +from django.contrib.contenttypes.models import ContentType from apps.ai.models.context import Context +from apps.common.models import TimestampedModel def create_model_mock(model_class): @@ -12,6 +14,8 @@ def create_model_mock(model_class): mock._state = Mock() mock.pk = 1 mock.id = 1 + mock.chunks = Mock() + mock.chunks.count.return_value = 0 return mock @@ -26,14 +30,14 @@ def test_content_field_properties(self): assert field.__class__.__name__ == "TextField" def test_content_type_field_properties(self): - field = Context._meta.get_field("content_type") + field = Context._meta.get_field("entity_type") assert field.null is False assert field.blank is False assert hasattr(field, "remote_field") assert field.remote_field.on_delete.__name__ == "CASCADE" def test_object_id_field_properties(self): - field = Context._meta.get_field("object_id") + field = Context._meta.get_field("entity_id") assert field.__class__.__name__ == "PositiveIntegerField" def test_source_field_properties(self): @@ -56,8 +60,6 @@ def test_context_creation_with_save(self, mock_init, mock_save): mock_save.assert_called_once() def test_context_inheritance_from_timestamped_model(self): - from apps.common.models import TimestampedModel - assert issubclass(Context, TimestampedModel) @patch("apps.ai.models.context.Context.objects.create") @@ -137,9 +139,11 @@ def test_context_delete(self, mock_delete): mock_delete.assert_called_once() @patch("apps.ai.models.context.Context.objects.filter") - def test_update_data_existing_context(self, mock_filter): + @patch("apps.ai.models.context.Context.objects.get_or_create") + @patch("apps.ai.models.context.regenerate_chunks_for_context") + def test_update_data_existing_context(self, mock_regenerate, mock_get_or_create, mock_filter): mock_context = create_model_mock(Context) - mock_filter.return_value.first.return_value = mock_context + mock_get_or_create.return_value = (mock_context, False) content = "Test" mock_content_object = Mock() @@ -154,28 +158,43 @@ def test_update_data_existing_context(self, mock_filter): result = Context.update_data(content, mock_content_object, source="src", save=True) mock_get_for_model.assert_called_once_with(mock_content_object) - mock_filter.assert_called_once_with( - content_type=mock_content_type, object_id=1, content=content + mock_get_or_create.assert_called_once_with( + entity_type=mock_content_type, + entity_id=1, + defaults={"content": content, "source": "src"}, ) + mock_regenerate.assert_called_once_with(context=mock_context) assert result == mock_context def test_str_method_with_name_attribute(self): - """Test __str__ method when content_object has name attribute.""" + """Test __str__ method when entity has name attribute.""" content_object = Mock() content_object.name = "Test Object" - content_type = Mock() - content_type.model = "test_model" + entity_type = Mock(spec=ContentType) + entity_type.model = "test_model" + + context = Context() + context.content = ( + "This is test content that is longer than 50 characters to test truncation" + ) + context.entity_type_id = 1 + context.entity_id = "123" with ( - patch.object(Context, "content_object", content_object), - patch.object(Context, "content_type", content_type), + patch.object( + type(context), + "entity", + new_callable=PropertyMock, + return_value=content_object, + ), + patch.object( + type(context), + "entity_type", + new_callable=PropertyMock, + return_value=entity_type, + ), ): - context = Context() - context.content = ( - "This is test content that is longer than 50 characters to test truncation" - ) - result = str(context) assert ( result @@ -183,51 +202,105 @@ def test_str_method_with_name_attribute(self): ) def test_str_method_with_key_attribute(self): - """Test __str__ method when content_object has key but no name attribute.""" + """Test __str__ method when entity has key but no name attribute.""" content_object = Mock() content_object.name = None content_object.key = "test-key" - content_type = Mock() - content_type.model = "test_model" + entity_type = Mock(spec=ContentType) + entity_type.model = "test_model" + + context = Context() + context.content = "Short content" + context.entity_type_id = 1 + context.entity_id = "123" with ( - patch.object(Context, "content_object", content_object), - patch.object(Context, "content_type", content_type), + patch.object( + type(context), + "entity", + new_callable=PropertyMock, + return_value=content_object, + ), + patch.object( + type(context), + "entity_type", + new_callable=PropertyMock, + return_value=entity_type, + ), ): - context = Context() - context.content = "Short content" - result = str(context) assert result == "test_model test-key: Short content" + def test_str_method_with_neither_name_nor_key(self): + """Test __str__ method when entity has neither name nor key attribute.""" + content_object = Mock() + content_object.name = None + content_object.key = None + content_object.__str__ = Mock(return_value="Unknown") + + entity_type = Mock(spec=ContentType) + entity_type.model = "test_model" + + context = Context() + context.content = "Another test content" + context.entity_type_id = 1 + context.entity_id = "456" + + with ( + patch.object( + type(context), + "entity", + new_callable=PropertyMock, + return_value=content_object, + ), + patch.object( + type(context), + "entity_type", + new_callable=PropertyMock, + return_value=entity_type, + ), + ): + result = str(context) + assert result == "test_model Unknown: Another test content" + def test_str_method_fallback_to_str(self): - """Test __str__ method falls back to str(content_object).""" + """Test __str__ method falls back to str(entity).""" content_object = Mock() content_object.name = None content_object.key = None content_object.__str__ = Mock(return_value="String representation") - content_type = Mock() - content_type.model = "test_model" + entity_type = Mock(spec=ContentType) + entity_type.model = "test_model" + + context = Context() + context.content = "Test content" + context.entity_type_id = 1 + context.entity_id = "123" with ( - patch.object(Context, "content_object", content_object), - patch.object(Context, "content_type", content_type), + patch.object( + type(context), + "entity", + new_callable=PropertyMock, + return_value=content_object, + ), + patch.object( + type(context), + "entity_type", + new_callable=PropertyMock, + return_value=entity_type, + ), ): - context = Context() - context.content = "Test content" - result = str(context) assert result == "test_model String representation: Test content" - @patch("apps.ai.models.context.Context.objects.filter") - @patch("apps.ai.models.context.Context.__init__") - @patch("apps.ai.models.context.Context.save") - def test_update_data_new_context_with_save(self, mock_save, mock_init, mock_filter): + @patch("apps.ai.models.context.Context.objects.get_or_create") + def test_update_data_new_context_with_save(self, mock_get_or_create): """Test update_data creating a new context with save=True.""" - mock_filter.return_value.first.return_value = None - mock_init.return_value = None + mock_context = create_model_mock(Context) + mock_get_or_create.return_value = (mock_context, True) content = "New test content" mock_content_object = Mock() @@ -243,18 +316,18 @@ def test_update_data_new_context_with_save(self, mock_save, mock_init, mock_filt result = Context.update_data(content, mock_content_object, source=source, save=True) mock_get_for_model.assert_called_once_with(mock_content_object) - mock_filter.assert_called_once_with( - content_type=mock_content_type, object_id=1, content=content + mock_get_or_create.assert_called_once_with( + entity_type=mock_content_type, + entity_id=1, + defaults={"content": content, "source": source}, ) - mock_save.assert_called_once() - assert isinstance(result, Context) + assert result == mock_context - @patch("apps.ai.models.context.Context.objects.filter") - @patch("apps.ai.models.context.Context.__init__") - def test_update_data_new_context_without_save(self, mock_init, mock_filter): + @patch("apps.ai.models.context.Context.objects.get_or_create") + def test_update_data_new_context_without_save(self, mock_get_or_create): """Test update_data creating a new context with save=False.""" - mock_filter.return_value.first.return_value = None - mock_init.return_value = None + mock_context = create_model_mock(Context) + mock_get_or_create.return_value = (mock_context, True) content = "New test content" mock_content_object = Mock() @@ -267,14 +340,12 @@ def test_update_data_new_context_without_save(self, mock_init, mock_filter): mock_content_type = Mock() mock_get_for_model.return_value = mock_content_type - with patch("apps.ai.models.context.Context.save") as mock_save: - result = Context.update_data( - content, mock_content_object, source=source, save=False - ) - - mock_get_for_model.assert_called_once_with(mock_content_object) - mock_filter.assert_called_once_with( - content_type=mock_content_type, object_id=1, content=content - ) - mock_save.assert_not_called() - assert isinstance(result, Context) + result = Context.update_data(content, mock_content_object, source=source, save=False) + + mock_get_for_model.assert_called_once_with(mock_content_object) + mock_get_or_create.assert_called_once_with( + entity_type=mock_content_type, + entity_id=1, + defaults={"content": content, "source": source}, + ) + assert result == mock_context From 2d86dcbe20cd9e90f213c7ebbbcfa8351b296307 Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Sat, 16 Aug 2025 17:56:56 +0530 Subject: [PATCH 24/32] code rabbit suggestions --- backend/apps/ai/Makefile | 40 +++++++++---------- backend/apps/ai/common/utils.py | 6 +-- .../ai_update_slack_message_context.py | 1 + .../ai_update_slack_message_context_test.py | 19 ++++++++- 4 files changed, 41 insertions(+), 25 deletions(-) diff --git a/backend/apps/ai/Makefile b/backend/apps/ai/Makefile index 948abeb5f8..d4d918aa81 100644 --- a/backend/apps/ai/Makefile +++ b/backend/apps/ai/Makefile @@ -2,42 +2,42 @@ ai-create-chapter-chunks: @echo "Creating chapter chunks" @CMD="python manage.py ai_create_chapter_chunks" $(MAKE) exec-backend-command -ai-create-chapter-context: - @echo "Creating chapter context" - @CMD="python manage.py ai_create_chapter_context" $(MAKE) exec-backend-command - ai-create-committee-chunks: @echo "Creating committee chunks" @CMD="python manage.py ai_create_committee_chunks" $(MAKE) exec-backend-command -ai-create-committee-context: - @echo "Creating committee context" - @CMD="python manage.py ai_create_committee_context" $(MAKE) exec-backend-command - ai-create-event-chunks: @echo "Creating event chunks" @CMD="python manage.py ai_create_event_chunks" $(MAKE) exec-backend-command -ai-create-event-context: - @echo "Creating event context" - @CMD="python manage.py ai_create_event_context" $(MAKE) exec-backend-command - ai-create-project-chunks: @echo "Creating project chunks" @CMD="python manage.py ai_create_project_chunks" $(MAKE) exec-backend-command -ai-create-project-context: - @echo "Creating project context" - @CMD="python manage.py ai_create_project_context" $(MAKE) exec-backend-command - ai-create-slack-message-chunks: @echo "Creating Slack message chunks" @CMD="python manage.py ai_create_slack_message_chunks" $(MAKE) exec-backend-command -ai-create-slack-message-context: - @echo "Creating Slack message context" - @CMD="python manage.py ai_create_slack_message_context" $(MAKE) exec-backend-command - ai-run-rag-tool: @echo "Running RAG tool" @CMD="python manage.py ai_run_rag_tool" $(MAKE) exec-backend-command + +ai-update-chapter-context: + @echo "Updating chapter context" + @CMD="python manage.py ai_update_chapter_context" $(MAKE) exec-backend-command + +ai-update-committee-context: + @echo "Updating committee context" + @CMD="python manage.py ai_update_committee_context" $(MAKE) exec-backend-command + +ai-update-event-context: + @echo "Updating event context" + @CMD="python manage.py ai_update_event_context" $(MAKE) exec-backend-command + +ai-update-project-context: + @echo "Updating project context" + @CMD="python manage.py ai_update_project_context" $(MAKE) exec-backend-command + +ai-update-slack-message-context: + @echo "Updating Slack message context" + @CMD="python manage.py ai_update_slack_message_context" $(MAKE) exec-backend-command diff --git a/backend/apps/ai/common/utils.py b/backend/apps/ai/common/utils.py index b9012ba252..9d02f7ede2 100644 --- a/backend/apps/ai/common/utils.py +++ b/backend/apps/ai/common/utils.py @@ -4,7 +4,7 @@ import time from datetime import UTC, datetime, timedelta -import openai +from openai import OpenAI, OpenAIError from apps.ai.common.constants import ( DEFAULT_LAST_REQUEST_OFFSET_SECONDS, @@ -63,7 +63,7 @@ def create_chunks_and_embeddings( if chunk is not None: chunks.append(chunk) - except openai.OpenAIError: + except OpenAIError: logger.exception("Failed to create chunks and embeddings") return [] else: @@ -89,7 +89,7 @@ def regenerate_chunks_for_context(context: Context): logger.warning("No content to chunk for Context. Process stopped.") return - openai_client = openai.Client() + openai_client = OpenAI() create_chunks_and_embeddings( chunk_texts=new_chunk_texts, diff --git a/backend/apps/ai/management/commands/ai_update_slack_message_context.py b/backend/apps/ai/management/commands/ai_update_slack_message_context.py index b6770e1d0e..67bb26d687 100644 --- a/backend/apps/ai/management/commands/ai_update_slack_message_context.py +++ b/backend/apps/ai/management/commands/ai_update_slack_message_context.py @@ -14,6 +14,7 @@ class Command(BaseContextCommand): def add_arguments(self, parser): """Override to use different default batch size for messages.""" + super().add_arguments(parser) parser.add_argument( "--message-key", type=str, diff --git a/backend/tests/apps/ai/management/commands/ai_update_slack_message_context_test.py b/backend/tests/apps/ai/management/commands/ai_update_slack_message_context_test.py index ce93d42c87..2d6ea5fec3 100644 --- a/backend/tests/apps/ai/management/commands/ai_update_slack_message_context_test.py +++ b/backend/tests/apps/ai/management/commands/ai_update_slack_message_context_test.py @@ -72,9 +72,10 @@ def test_add_arguments(self, command): parser = Mock() command.add_arguments(parser) - assert parser.add_argument.call_count == 3 + assert parser.add_argument.call_count == 6 calls = parser.add_argument.call_args_list + # First 3 calls are from parent class (BaseAICommand) assert calls[0][0] == ("--message-key",) assert calls[0][1]["type"] is str assert "Process only the message with this key" in calls[0][1]["help"] @@ -85,5 +86,19 @@ def test_add_arguments(self, command): assert calls[2][0] == ("--batch-size",) assert calls[2][1]["type"] is int - assert calls[2][1]["default"] == 100 + assert calls[2][1]["default"] == 50 # Default from parent class assert "Number of messages to process in each batch" in calls[2][1]["help"] + + # Next 3 calls are from the command itself (duplicates with different defaults) + assert calls[3][0] == ("--message-key",) + assert calls[3][1]["type"] is str + assert "Process only the message with this key" in calls[3][1]["help"] + + assert calls[4][0] == ("--all",) + assert calls[4][1]["action"] == "store_true" + assert "Process all the messages" in calls[4][1]["help"] + + assert calls[5][0] == ("--batch-size",) + assert calls[5][1]["type"] is int + assert calls[5][1]["default"] == 100 # Overridden default from command + assert "Number of messages to process in each batch" in calls[5][1]["help"] From 011e843ab84d601f1b3d708d1d1aeb7102ece3cb Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Sun, 17 Aug 2025 17:16:13 +0530 Subject: [PATCH 25/32] before tests and question --- backend/apps/ai/common/base/ai_command.py | 4 ++-- .../apps/ai/management/commands/ai_create_chapter_chunks.py | 2 -- .../apps/ai/management/commands/ai_create_committee_chunks.py | 2 -- backend/apps/ai/management/commands/ai_create_event_chunks.py | 2 -- .../apps/ai/management/commands/ai_create_project_chunks.py | 2 -- .../ai/management/commands/ai_create_slack_message_chunks.py | 2 -- .../apps/ai/management/commands/ai_update_chapter_context.py | 2 -- .../ai/management/commands/ai_update_committee_context.py | 2 -- .../apps/ai/management/commands/ai_update_event_context.py | 2 -- .../apps/ai/management/commands/ai_update_project_context.py | 2 -- .../ai/management/commands/ai_update_slack_message_context.py | 2 -- backend/apps/ai/models/context.py | 1 - 12 files changed, 2 insertions(+), 23 deletions(-) diff --git a/backend/apps/ai/common/base/ai_command.py b/backend/apps/ai/common/base/ai_command.py index fea78193af..30337603e6 100644 --- a/backend/apps/ai/common/base/ai_command.py +++ b/backend/apps/ai/common/base/ai_command.py @@ -13,14 +13,14 @@ class BaseAICommand(BaseCommand): """Base class for AI management commands with common functionality.""" model_class: type[Model] - entity_name: str - entity_name_plural: str key_field_name: str def __init__(self, *args, **kwargs): """Initialize the AI command with OpenAI client placeholder.""" super().__init__(*args, **kwargs) self.openai_client: openai.OpenAI | None = None + self.entity_name = self.model_class.__name__.lower() + self.entity_name_plural = self.model_class.__name__.lower() + "s" def source_name(self) -> str: """Return the source name for context creation. Override if different from default.""" diff --git a/backend/apps/ai/management/commands/ai_create_chapter_chunks.py b/backend/apps/ai/management/commands/ai_create_chapter_chunks.py index 9b48c4dba3..cee3d355c5 100644 --- a/backend/apps/ai/management/commands/ai_create_chapter_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_chapter_chunks.py @@ -6,8 +6,6 @@ class Command(BaseChunkCommand): - entity_name = "chapter" - entity_name_plural = "chapters" key_field_name = "key" model_class = Chapter diff --git a/backend/apps/ai/management/commands/ai_create_committee_chunks.py b/backend/apps/ai/management/commands/ai_create_committee_chunks.py index df03d69856..611dba01fb 100644 --- a/backend/apps/ai/management/commands/ai_create_committee_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_committee_chunks.py @@ -6,8 +6,6 @@ class Command(BaseChunkCommand): - entity_name = "committee" - entity_name_plural = "committees" key_field_name = "key" model_class = Committee diff --git a/backend/apps/ai/management/commands/ai_create_event_chunks.py b/backend/apps/ai/management/commands/ai_create_event_chunks.py index 40e3cda49b..fa5bcbf5c4 100644 --- a/backend/apps/ai/management/commands/ai_create_event_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_event_chunks.py @@ -8,8 +8,6 @@ class Command(BaseChunkCommand): - entity_name = "event" - entity_name_plural = "events" key_field_name = "key" model_class = Event diff --git a/backend/apps/ai/management/commands/ai_create_project_chunks.py b/backend/apps/ai/management/commands/ai_create_project_chunks.py index 673c28f252..132e8fad0b 100644 --- a/backend/apps/ai/management/commands/ai_create_project_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_project_chunks.py @@ -8,8 +8,6 @@ class Command(BaseChunkCommand): - entity_name = "project" - entity_name_plural = "projects" key_field_name = "key" model_class = Project diff --git a/backend/apps/ai/management/commands/ai_create_slack_message_chunks.py b/backend/apps/ai/management/commands/ai_create_slack_message_chunks.py index 3f070a7600..51985a2adf 100644 --- a/backend/apps/ai/management/commands/ai_create_slack_message_chunks.py +++ b/backend/apps/ai/management/commands/ai_create_slack_message_chunks.py @@ -7,8 +7,6 @@ class Command(BaseChunkCommand): - entity_name = "message" - entity_name_plural = "messages" key_field_name = "slack_message_id" model_class = Message diff --git a/backend/apps/ai/management/commands/ai_update_chapter_context.py b/backend/apps/ai/management/commands/ai_update_chapter_context.py index 550d05c862..2c72f89c84 100644 --- a/backend/apps/ai/management/commands/ai_update_chapter_context.py +++ b/backend/apps/ai/management/commands/ai_update_chapter_context.py @@ -6,8 +6,6 @@ class Command(BaseContextCommand): - entity_name = "chapter" - entity_name_plural = "chapters" key_field_name = "key" model_class = Chapter diff --git a/backend/apps/ai/management/commands/ai_update_committee_context.py b/backend/apps/ai/management/commands/ai_update_committee_context.py index 405d1f31f5..4b3bf29cda 100644 --- a/backend/apps/ai/management/commands/ai_update_committee_context.py +++ b/backend/apps/ai/management/commands/ai_update_committee_context.py @@ -6,8 +6,6 @@ class Command(BaseContextCommand): - entity_name = "committee" - entity_name_plural = "committees" key_field_name = "key" model_class = Committee diff --git a/backend/apps/ai/management/commands/ai_update_event_context.py b/backend/apps/ai/management/commands/ai_update_event_context.py index d07b942d44..15232a773a 100644 --- a/backend/apps/ai/management/commands/ai_update_event_context.py +++ b/backend/apps/ai/management/commands/ai_update_event_context.py @@ -8,8 +8,6 @@ class Command(BaseContextCommand): - entity_name = "event" - entity_name_plural = "events" key_field_name = "key" model_class = Event diff --git a/backend/apps/ai/management/commands/ai_update_project_context.py b/backend/apps/ai/management/commands/ai_update_project_context.py index aa594085c6..d1aede6d98 100644 --- a/backend/apps/ai/management/commands/ai_update_project_context.py +++ b/backend/apps/ai/management/commands/ai_update_project_context.py @@ -8,8 +8,6 @@ class Command(BaseContextCommand): - entity_name = "project" - entity_name_plural = "projects" key_field_name = "key" model_class = Project diff --git a/backend/apps/ai/management/commands/ai_update_slack_message_context.py b/backend/apps/ai/management/commands/ai_update_slack_message_context.py index 67bb26d687..c89b253692 100644 --- a/backend/apps/ai/management/commands/ai_update_slack_message_context.py +++ b/backend/apps/ai/management/commands/ai_update_slack_message_context.py @@ -7,8 +7,6 @@ class Command(BaseContextCommand): - entity_name = "message" - entity_name_plural = "messages" key_field_name = "slack_message_id" model_class = Message diff --git a/backend/apps/ai/models/context.py b/backend/apps/ai/models/context.py index 79d6e796c0..feef8c97a0 100644 --- a/backend/apps/ai/models/context.py +++ b/backend/apps/ai/models/context.py @@ -61,6 +61,5 @@ def update_data( context.source = source if save: context.save(update_fields=["content", "source"]) - regenerate_chunks_for_context(context=context) return context From 466bca3595efe821310a695402a45a8df0c941a8 Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Mon, 18 Aug 2025 05:41:02 +0530 Subject: [PATCH 26/32] sugesstions and decoupling with tests --- backend/apps/ai/Makefile | 40 ++-- backend/apps/ai/common/base/chunk_command.py | 72 ++++--- ..._chunks.py => ai_update_chapter_chunks.py} | 0 ...hunks.py => ai_update_committee_chunks.py} | 0 ...nt_chunks.py => ai_update_event_chunks.py} | 0 ..._chunks.py => ai_update_project_chunks.py} | 0 ...s.py => ai_update_slack_message_chunks.py} | 0 backend/apps/ai/models/context.py | 2 +- .../apps/ai/common/base/ai_command_test.py | 6 +- .../apps/ai/common/base/chunk_command_test.py | 42 +++- .../ai/common/base/context_command_test.py | 6 +- .../apps/ai/common/extractors/project_test.py | 72 +++++++ backend/tests/apps/ai/common/utils_test.py | 196 +++++++++++++++++- ...st.py => ai_update_chapter_chunks_test.py} | 4 +- ....py => ai_update_committee_chunks_test.py} | 6 +- ...test.py => ai_update_event_chunks_test.py} | 4 +- ...st.py => ai_update_project_chunks_test.py} | 4 +- ...=> ai_update_slack_message_chunks_test.py} | 2 +- backend/tests/apps/ai/models/context_test.py | 4 +- backend/tests/apps/common/open_ai_test.py | 3 - 20 files changed, 383 insertions(+), 80 deletions(-) rename backend/apps/ai/management/commands/{ai_create_chapter_chunks.py => ai_update_chapter_chunks.py} (100%) rename backend/apps/ai/management/commands/{ai_create_committee_chunks.py => ai_update_committee_chunks.py} (100%) rename backend/apps/ai/management/commands/{ai_create_event_chunks.py => ai_update_event_chunks.py} (100%) rename backend/apps/ai/management/commands/{ai_create_project_chunks.py => ai_update_project_chunks.py} (100%) rename backend/apps/ai/management/commands/{ai_create_slack_message_chunks.py => ai_update_slack_message_chunks.py} (100%) rename backend/tests/apps/ai/management/commands/{ai_create_chapter_chunks_test.py => ai_update_chapter_chunks_test.py} (91%) rename backend/tests/apps/ai/management/commands/{ai_create_committee_chunks_test.py => ai_update_committee_chunks_test.py} (96%) rename backend/tests/apps/ai/management/commands/{ai_create_event_chunks_test.py => ai_update_event_chunks_test.py} (95%) rename backend/tests/apps/ai/management/commands/{ai_create_project_chunks_test.py => ai_update_project_chunks_test.py} (93%) rename backend/tests/apps/ai/management/commands/{ai_create_slack_message_chunks_test.py => ai_update_slack_message_chunks_test.py} (97%) diff --git a/backend/apps/ai/Makefile b/backend/apps/ai/Makefile index d4d918aa81..3243269378 100644 --- a/backend/apps/ai/Makefile +++ b/backend/apps/ai/Makefile @@ -1,43 +1,43 @@ -ai-create-chapter-chunks: - @echo "Creating chapter chunks" - @CMD="python manage.py ai_create_chapter_chunks" $(MAKE) exec-backend-command - -ai-create-committee-chunks: - @echo "Creating committee chunks" - @CMD="python manage.py ai_create_committee_chunks" $(MAKE) exec-backend-command - -ai-create-event-chunks: - @echo "Creating event chunks" - @CMD="python manage.py ai_create_event_chunks" $(MAKE) exec-backend-command - -ai-create-project-chunks: - @echo "Creating project chunks" - @CMD="python manage.py ai_create_project_chunks" $(MAKE) exec-backend-command - -ai-create-slack-message-chunks: - @echo "Creating Slack message chunks" - @CMD="python manage.py ai_create_slack_message_chunks" $(MAKE) exec-backend-command - ai-run-rag-tool: @echo "Running RAG tool" @CMD="python manage.py ai_run_rag_tool" $(MAKE) exec-backend-command +ai-update-chapter-chunks: + @echo "Updating chapter chunks" + @CMD="python manage.py ai_update_chapter_chunks" $(MAKE) exec-backend-command + ai-update-chapter-context: @echo "Updating chapter context" @CMD="python manage.py ai_update_chapter_context" $(MAKE) exec-backend-command +ai-update-committee-chunks: + @echo "Updating committee chunks" + @CMD="python manage.py ai_update_committee_chunks" $(MAKE) exec-backend-command + ai-update-committee-context: @echo "Updating committee context" @CMD="python manage.py ai_update_committee_context" $(MAKE) exec-backend-command +ai-update-event-chunks: + @echo "Updating event chunks" + @CMD="python manage.py ai_update_event_chunks" $(MAKE) exec-backend-command + ai-update-event-context: @echo "Updating event context" @CMD="python manage.py ai_update_event_context" $(MAKE) exec-backend-command +ai-update-project-chunks: + @echo "Updating project chunks" + @CMD="python manage.py ai_update_project_chunks" $(MAKE) exec-backend-command + ai-update-project-context: @echo "Updating project context" @CMD="python manage.py ai_update_project_context" $(MAKE) exec-backend-command +ai-update-slack-message-chunks: + @echo "Updating Slack message chunks" + @CMD="python manage.py ai_update_slack_message_chunks" $(MAKE) exec-backend-command + ai-update-slack-message-context: @echo "Updating Slack message context" @CMD="python manage.py ai_update_slack_message_context" $(MAKE) exec-backend-command diff --git a/backend/apps/ai/common/base/chunk_command.py b/backend/apps/ai/common/base/chunk_command.py index 272e50a460..6a29005d12 100644 --- a/backend/apps/ai/common/base/chunk_command.py +++ b/backend/apps/ai/common/base/chunk_command.py @@ -1,7 +1,7 @@ """Base chunk command class for creating chunks.""" from django.contrib.contenttypes.models import ContentType -from django.db.models import Model +from django.db.models import Max, Model from apps.ai.common.base.ai_command import BaseAICommand from apps.ai.common.utils import create_chunks_and_embeddings @@ -14,18 +14,17 @@ class BaseChunkCommand(BaseAICommand): def help(self) -> str: """Return help text for the chunk creation command.""" - return f"Create chunks for OWASP {self.entity_name} data" + return f"Create or update chunks for OWASP {self.entity_name} data" def process_chunks_batch(self, entities: list[Model]) -> int: - """Process a batch of entities to create chunks.""" + """Process a batch of entities to create or update chunks.""" processed = 0 - batch_chunks = [] + batch_chunks_to_create = [] content_type = ContentType.objects.get_for_model(self.model_class) for entity in entities: - context = Context.objects.filter(entity_type=content_type, entity_id=entity.id).first() - entity_key = self.get_entity_key(entity) + context = Context.objects.filter(entity_type=content_type, entity_id=entity.id).first() if not context: self.stdout.write( @@ -33,32 +32,47 @@ def process_chunks_batch(self, entities: list[Model]) -> int: ) continue - prose_content, metadata_content = self.extract_content(entity) - full_content = ( - f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content - ) + latest_chunk_timestamp = context.chunks.aggregate( + latest_created=Max("nest_created_at") + )["latest_created"] - if not full_content.strip(): - self.stdout.write(f"No content to chunk for {self.entity_name} {entity_key}") - continue + if not latest_chunk_timestamp or context.nest_updated_at > latest_chunk_timestamp: + self.stdout.write(f"Context for {entity_key} requires chunk creation/update") - chunk_texts = Chunk.split_text(full_content) - if not chunk_texts: - self.stdout.write(f"No chunks created for {self.entity_name} {entity_key}") - continue + if latest_chunk_timestamp: + count, _ = context.chunks.all().delete() + self.stdout.write(f"Deleted {count} stale chunks for {entity_key}") + + prose_content, metadata_content = self.extract_content(entity) + full_content = ( + f"{metadata_content}\n\n{prose_content}" if metadata_content else prose_content + ) - if chunks := create_chunks_and_embeddings( - chunk_texts=chunk_texts, - context=context, - openai_client=self.openai_client, - save=False, - ): - batch_chunks.extend(chunks) - processed += 1 - self.stdout.write(f"Created {len(chunks)} chunks for {entity_key}") - - if batch_chunks: - Chunk.bulk_save(batch_chunks) + if not full_content.strip(): + self.stdout.write(f"No content to chunk for {self.entity_name} {entity_key}") + continue + + chunk_texts = Chunk.split_text(full_content) + if not chunk_texts: + self.stdout.write(f"No chunks created for {self.entity_name} {entity_key}") + continue + + if chunks := create_chunks_and_embeddings( + chunk_texts=chunk_texts, + context=context, + openai_client=self.openai_client, + save=False, + ): + batch_chunks_to_create.extend(chunks) + processed += 1 + self.stdout.write( + self.style.SUCCESS(f"Created {len(chunks)} new chunks for {entity_key}") + ) + else: + self.stdout.write(f"Chunks for {entity_key} are already up to date.") + + if batch_chunks_to_create: + Chunk.bulk_save(batch_chunks_to_create) return processed diff --git a/backend/apps/ai/management/commands/ai_create_chapter_chunks.py b/backend/apps/ai/management/commands/ai_update_chapter_chunks.py similarity index 100% rename from backend/apps/ai/management/commands/ai_create_chapter_chunks.py rename to backend/apps/ai/management/commands/ai_update_chapter_chunks.py diff --git a/backend/apps/ai/management/commands/ai_create_committee_chunks.py b/backend/apps/ai/management/commands/ai_update_committee_chunks.py similarity index 100% rename from backend/apps/ai/management/commands/ai_create_committee_chunks.py rename to backend/apps/ai/management/commands/ai_update_committee_chunks.py diff --git a/backend/apps/ai/management/commands/ai_create_event_chunks.py b/backend/apps/ai/management/commands/ai_update_event_chunks.py similarity index 100% rename from backend/apps/ai/management/commands/ai_create_event_chunks.py rename to backend/apps/ai/management/commands/ai_update_event_chunks.py diff --git a/backend/apps/ai/management/commands/ai_create_project_chunks.py b/backend/apps/ai/management/commands/ai_update_project_chunks.py similarity index 100% rename from backend/apps/ai/management/commands/ai_create_project_chunks.py rename to backend/apps/ai/management/commands/ai_update_project_chunks.py diff --git a/backend/apps/ai/management/commands/ai_create_slack_message_chunks.py b/backend/apps/ai/management/commands/ai_update_slack_message_chunks.py similarity index 100% rename from backend/apps/ai/management/commands/ai_create_slack_message_chunks.py rename to backend/apps/ai/management/commands/ai_update_slack_message_chunks.py diff --git a/backend/apps/ai/models/context.py b/backend/apps/ai/models/context.py index feef8c97a0..f0fce30350 100644 --- a/backend/apps/ai/models/context.py +++ b/backend/apps/ai/models/context.py @@ -60,6 +60,6 @@ def update_data( context.content = content context.source = source if save: - context.save(update_fields=["content", "source"]) + context.save() return context diff --git a/backend/tests/apps/ai/common/base/ai_command_test.py b/backend/tests/apps/ai/common/base/ai_command_test.py index a9493179a1..83d73dfc97 100644 --- a/backend/tests/apps/ai/common/base/ai_command_test.py +++ b/backend/tests/apps/ai/common/base/ai_command_test.py @@ -12,12 +12,16 @@ class MockTestModel: objects = Mock() pk = 1 + __name__ = "TestEntity" @pytest.fixture def command(): """Fixture for ConcreteAICommand instance.""" - return ConcreteAICommand() + cmd = ConcreteAICommand() + cmd.entity_name = "test_entity" + cmd.entity_name_plural = "test_entities" + return cmd @pytest.fixture diff --git a/backend/tests/apps/ai/common/base/chunk_command_test.py b/backend/tests/apps/ai/common/base/chunk_command_test.py index 5ee912a589..627acf36f0 100644 --- a/backend/tests/apps/ai/common/base/chunk_command_test.py +++ b/backend/tests/apps/ai/common/base/chunk_command_test.py @@ -1,7 +1,7 @@ """Tests for the BaseChunkCommand class.""" from typing import Any -from unittest.mock import Mock, patch +from unittest.mock import Mock, call, patch import pytest from django.contrib.contenttypes.models import ContentType @@ -30,8 +30,10 @@ def command(): """Return a concrete chunk command instance for testing.""" cmd = ConcreteChunkCommand() mock_model = Mock() - mock_model.__name__ = "MockChunkTestModel" + mock_model.__name__ = "TestEntity" cmd.model_class = mock_model + cmd.entity_name = "test_entity" + cmd.entity_name_plural = "test_entities" return cmd @@ -52,6 +54,9 @@ def mock_context(): context.id = 1 context.content_type_id = 1 context.object_id = 1 + context.chunks.aggregate.return_value = {"latest_created": None} + context.chunks.all.return_value.delete.return_value = (0, {}) + context.nest_updated_at = Mock() return context @@ -85,12 +90,12 @@ def test_command_inheritance(self, command): def test_help_method(self, command): """Test the help method returns appropriate help text.""" - expected_help = "Create chunks for OWASP test_entity data" + expected_help = "Create or update chunks for OWASP test_entity data" assert command.help() == expected_help def test_abstract_methods_implemented(self, command): """Test that all abstract methods are properly implemented.""" - assert command.model_class.__name__ == "MockChunkTestModel" + assert command.model_class.__name__ == "TestEntity" assert command.entity_name == "test_entity" assert command.entity_name_plural == "test_entities" assert command.key_field_name == "test_key" @@ -143,7 +148,12 @@ def test_process_chunks_batch_empty_content( result = command.process_chunks_batch([mock_entity]) assert result == 0 - mock_write.assert_called_once_with("No content to chunk for test_entity test-key-123") + # Check that it wrote the initial message and the empty content message + expected_calls = [ + call("Context for test-key-123 requires chunk creation/update"), + call("No content to chunk for test_entity test-key-123"), + ] + mock_write.assert_has_calls(expected_calls) @patch("apps.ai.common.base.chunk_command.ContentType.objects.get_for_model") @patch("apps.ai.common.base.chunk_command.Context.objects.filter") @@ -167,7 +177,12 @@ def test_process_chunks_batch_no_chunks_created( result = command.process_chunks_batch([mock_entity]) assert result == 0 - mock_write.assert_called_once() + # Check that both messages were written + expected_calls = [ + call("Context for test-key-123 requires chunk creation/update"), + call("No chunks created for test_entity test-key-123"), + ] + mock_write.assert_has_calls(expected_calls) call_args = mock_write.call_args[0][0] assert "No chunks created for test_entity test-key-123" in call_args @@ -207,7 +222,12 @@ def test_process_chunks_batch_success( save=False, ) mock_bulk_save.assert_called_once_with(mock_chunks) - mock_write.assert_called_once_with("Created 3 chunks for test-key-123") + mock_write.assert_has_calls( + [ + call("Context for test-key-123 requires chunk creation/update"), + call(command.style.SUCCESS("Created 3 new chunks for test-key-123")), + ] + ) @patch("apps.ai.common.base.chunk_command.ContentType.objects.get_for_model") @patch("apps.ai.common.base.chunk_command.Context.objects.filter") @@ -409,6 +429,8 @@ def test_process_chunks_batch_whitespace_only_content( result = command.process_chunks_batch([mock_entity]) assert result == 0 - mock_write.assert_called_once_with( - "No content to chunk for test_entity test-key-123" - ) + expected_calls = [ + call("Context for test-key-123 requires chunk creation/update"), + call("No content to chunk for test_entity test-key-123"), + ] + mock_write.assert_has_calls(expected_calls) diff --git a/backend/tests/apps/ai/common/base/context_command_test.py b/backend/tests/apps/ai/common/base/context_command_test.py index c267e074e4..48237a1fc4 100644 --- a/backend/tests/apps/ai/common/base/context_command_test.py +++ b/backend/tests/apps/ai/common/base/context_command_test.py @@ -27,8 +27,10 @@ def command(): """Return a concrete context command instance for testing.""" cmd = ConcreteContextCommand() mock_model = Mock() - mock_model.__name__ = "MockContextTestModel" + mock_model.__name__ = "TestEntity" cmd.model_class = mock_model + cmd.entity_name = "test_entity" + cmd.entity_name_plural = "test_entities" return cmd @@ -67,7 +69,7 @@ def test_help_method(self, command): def test_abstract_methods_implemented(self, command): """Test that all abstract methods are properly implemented.""" - assert command.model_class.__name__ == "MockContextTestModel" + assert command.model_class.__name__ == "TestEntity" assert command.entity_name == "test_entity" assert command.entity_name_plural == "test_entities" assert command.key_field_name == "test_key" diff --git a/backend/tests/apps/ai/common/extractors/project_test.py b/backend/tests/apps/ai/common/extractors/project_test.py index 01236103ea..db5cbfde47 100644 --- a/backend/tests/apps/ai/common/extractors/project_test.py +++ b/backend/tests/apps/ai/common/extractors/project_test.py @@ -439,3 +439,75 @@ def test_extract_project_content_with_empty_related_urls(self): _, metadata = extract_project_content(project) assert "Related URLs: https://valid.com, https://another.com" in metadata + + def test_extract_project_content_repository_no_description_no_topics(self): + """Test extraction when repository exists but has neither description nor topics.""" + project = MagicMock() + project.description = "Project description" + project.summary = None + project.name = "Test Project" + project.level = None + project.type = None + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = None + project.forks_count = None + project.contributors_count = None + project.releases_count = None + project.open_issues_count = None + project.leaders_raw = [] + project.related_urls = [] + project.invalid_urls = [] + project.created_at = None + project.updated_at = None + project.released_at = None + project.health_score = None + project.is_active = True + + repo = MagicMock() + repo.description = None + repo.topics = None + project.owasp_repository = repo + + prose, metadata = extract_project_content(project) + + assert "Description: Project description" in prose + assert "Repository Description:" not in prose + assert "Repository Topics:" not in metadata + + def test_extract_project_content_created_and_released_only(self): + """Test extraction with created_at and released_at but no updated_at.""" + project = MagicMock() + project.description = None + project.summary = None + project.name = "Date Test Project" + project.level = None + project.type = None + project.languages = [] + project.topics = [] + project.licenses = [] + project.tags = [] + project.custom_tags = [] + project.stars_count = None + project.forks_count = None + project.contributors_count = None + project.releases_count = None + project.open_issues_count = None + project.leaders_raw = [] + project.related_urls = [] + project.invalid_urls = [] + project.created_at = datetime(2021, 3, 1, tzinfo=UTC) + project.updated_at = None + project.released_at = datetime(2023, 8, 15, tzinfo=UTC) + project.health_score = None + project.is_active = True + project.owasp_repository = None + + _, metadata = extract_project_content(project) + + assert "Created: 2021-03-01" in metadata + assert "Last Updated:" not in metadata + assert "Last Release: 2023-08-15" in metadata diff --git a/backend/tests/apps/ai/common/utils_test.py b/backend/tests/apps/ai/common/utils_test.py index 4f6890cffa..a9e1b19618 100644 --- a/backend/tests/apps/ai/common/utils_test.py +++ b/backend/tests/apps/ai/common/utils_test.py @@ -3,7 +3,10 @@ import openai -from apps.ai.common.utils import create_chunks_and_embeddings +from apps.ai.common.utils import ( + create_chunks_and_embeddings, + regenerate_chunks_for_context, +) class MockEmbeddingData: @@ -176,3 +179,194 @@ def test_create_chunks_and_embeddings_no_sleep_with_current_settings( text="test chunk", embedding=[0.1, 0.2], context=mock_context_obj, save=True ) assert result == [mock_chunk] + + @patch("apps.ai.common.utils.time.sleep") + @patch("apps.ai.common.utils.datetime") + def test_create_chunks_and_embeddings_sleep_when_rate_limited(self, mock_datetime, mock_sleep): + """Tests that sleep is called when rate limiting is needed.""" + base_time = datetime.now(UTC) + + mock_datetime.now.side_effect = [ + base_time, + base_time, + ] + mock_datetime.UTC = UTC + mock_datetime.timedelta = timedelta + + mock_openai_client = MagicMock() + mock_api_response = MagicMock() + mock_api_response.data = [MockEmbeddingData([0.1, 0.2])] + mock_openai_client.embeddings.create.return_value = mock_api_response + + with patch("apps.ai.common.utils.Chunk.update_data") as mock_update_data: + mock_chunk = MagicMock() + mock_update_data.return_value = mock_chunk + + with ( + patch("apps.ai.common.utils.DEFAULT_LAST_REQUEST_OFFSET_SECONDS", 5), + patch("apps.ai.common.utils.MIN_REQUEST_INTERVAL_SECONDS", 10), + ): + result = create_chunks_and_embeddings( + ["test chunk"], + MagicMock(), + mock_openai_client, + ) + + mock_sleep.assert_called_once() + assert result == [mock_chunk] + + @patch("apps.ai.common.utils.Chunk.update_data") + def test_create_chunks_and_embeddings_filter_none_chunks(self, mock_update_data): + """Tests that None chunks are filtered out from results.""" + mock_openai_client = MagicMock() + mock_api_response = MagicMock() + mock_api_response.data = [ + MockEmbeddingData([0.1, 0.2]), + MockEmbeddingData([0.3, 0.4]), + ] + mock_openai_client.embeddings.create.return_value = mock_api_response + + mock_chunk = MagicMock() + mock_update_data.side_effect = [mock_chunk, None] + + result = create_chunks_and_embeddings( + ["first chunk", "second chunk"], + MagicMock(), + mock_openai_client, + ) + + assert len(result) == 1 + assert result[0] == mock_chunk + + +class TestRegenerateChunksForContext: + """Test cases for regenerate_chunks_for_context function.""" + + @patch("apps.ai.common.utils.logger") + @patch("apps.ai.common.utils.OpenAI") + @patch("apps.ai.common.utils.create_chunks_and_embeddings") + @patch("apps.ai.models.chunk.Chunk") + def test_regenerate_chunks_for_context_success( + self, mock_chunk_class, mock_create_chunks, mock_openai_class, mock_logger + ): + """Test successful regeneration of chunks for context.""" + # Setup context mock + context = MagicMock() + context.content = "This is test content for chunking" + + # Setup existing chunks + mock_existing_chunks = MagicMock() + mock_existing_chunks.count.return_value = 3 + context.chunks = mock_existing_chunks + + # Setup chunk splitting + new_chunk_texts = ["chunk1", "chunk2"] + mock_chunk_class.split_text.return_value = new_chunk_texts + + # Setup OpenAI client + mock_openai_client = MagicMock() + mock_openai_class.return_value = mock_openai_client + + regenerate_chunks_for_context(context) + + # Verify old chunks were deleted + mock_existing_chunks.all.assert_called_once() + mock_existing_chunks.all().delete.assert_called_once() + + # Verify content was split + mock_chunk_class.split_text.assert_called_once_with(context.content) + + # Verify OpenAI client was created + mock_openai_class.assert_called_once() + + # Verify new chunks were created + mock_create_chunks.assert_called_once_with( + chunk_texts=new_chunk_texts, + context=context, + openai_client=mock_openai_client, + save=True, + ) + + # Verify success log + mock_logger.info.assert_called_once_with( + "Successfully completed chunk regeneration for new context" + ) + + @patch("apps.ai.common.utils.logger") + @patch("apps.ai.models.chunk.Chunk") + def test_regenerate_chunks_for_context_no_content(self, mock_chunk_class, mock_logger): + """Test regeneration when there's no content to chunk.""" + # Setup context mock + context = MagicMock() + context.content = "Some content" + + # Setup existing chunks + mock_existing_chunks = MagicMock() + mock_existing_chunks.count.return_value = 2 + context.chunks = mock_existing_chunks + + # Setup chunk splitting to return empty list + mock_chunk_class.split_text.return_value = [] + + regenerate_chunks_for_context(context) + + # Verify old chunks were deleted + mock_existing_chunks.all.assert_called_once() + mock_existing_chunks.all().delete.assert_called_once() + + # Verify content was split + mock_chunk_class.split_text.assert_called_once_with(context.content) + + # Verify warning was logged and process stopped + mock_logger.warning.assert_called_once_with( + "No content to chunk for Context. Process stopped." + ) + + # Verify success log was not called + mock_logger.info.assert_not_called() + + @patch("apps.ai.common.utils.logger") + @patch("apps.ai.common.utils.OpenAI") + @patch("apps.ai.common.utils.create_chunks_and_embeddings") + @patch("apps.ai.models.chunk.Chunk") + def test_regenerate_chunks_for_context_no_existing_chunks( + self, mock_chunk_class, mock_create_chunks, mock_openai_class, mock_logger + ): + """Test regeneration when there are no existing chunks.""" + # Setup context mock + context = MagicMock() + context.content = "This is test content for chunking" + + # Setup no existing chunks + mock_existing_chunks = MagicMock() + mock_existing_chunks.count.return_value = 0 + context.chunks = mock_existing_chunks + + # Setup chunk splitting + new_chunk_texts = ["chunk1", "chunk2"] + mock_chunk_class.split_text.return_value = new_chunk_texts + + # Setup OpenAI client + mock_openai_client = MagicMock() + mock_openai_class.return_value = mock_openai_client + + regenerate_chunks_for_context(context) + + # Verify delete was not called since count is 0 + mock_existing_chunks.all.assert_not_called() + + # Verify content was split + mock_chunk_class.split_text.assert_called_once_with(context.content) + + # Verify new chunks were created + mock_create_chunks.assert_called_once_with( + chunk_texts=new_chunk_texts, + context=context, + openai_client=mock_openai_client, + save=True, + ) + + # Verify success log + mock_logger.info.assert_called_once_with( + "Successfully completed chunk regeneration for new context" + ) diff --git a/backend/tests/apps/ai/management/commands/ai_create_chapter_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_update_chapter_chunks_test.py similarity index 91% rename from backend/tests/apps/ai/management/commands/ai_create_chapter_chunks_test.py rename to backend/tests/apps/ai/management/commands/ai_update_chapter_chunks_test.py index e501580f9b..7df18a89bd 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_chapter_chunks_test.py +++ b/backend/tests/apps/ai/management/commands/ai_update_chapter_chunks_test.py @@ -3,7 +3,7 @@ import pytest from django.core.management.base import BaseCommand -from apps.ai.management.commands.ai_create_chapter_chunks import Command +from apps.ai.management.commands.ai_update_chapter_chunks import Command @pytest.fixture @@ -39,7 +39,7 @@ def test_key_field_name_property(self, command): def test_extract_content(self, command, mock_chapter): with patch( - "apps.ai.management.commands.ai_create_chapter_chunks.extract_chapter_content" + "apps.ai.management.commands.ai_update_chapter_chunks.extract_chapter_content" ) as mock_extract: mock_extract.return_value = ("prose content", "metadata content") content = command.extract_content(mock_chapter) diff --git a/backend/tests/apps/ai/management/commands/ai_create_committee_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_update_committee_chunks_test.py similarity index 96% rename from backend/tests/apps/ai/management/commands/ai_create_committee_chunks_test.py rename to backend/tests/apps/ai/management/commands/ai_update_committee_chunks_test.py index 5d0a27f712..e889fe5b11 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_committee_chunks_test.py +++ b/backend/tests/apps/ai/management/commands/ai_update_committee_chunks_test.py @@ -5,7 +5,7 @@ import pytest from django.core.management.base import BaseCommand -from apps.ai.management.commands.ai_create_committee_chunks import Command +from apps.ai.management.commands.ai_update_committee_chunks import Command @pytest.fixture @@ -54,7 +54,7 @@ def test_key_field_name_method(self, command): def test_extract_content_method(self, command, mock_committee): """Test the extract_content method.""" with patch( - "apps.ai.management.commands.ai_create_committee_chunks.extract_committee_content" + "apps.ai.management.commands.ai_update_committee_chunks.extract_committee_content" ) as mock_extract: mock_extract.return_value = ("prose content", "metadata content") content = command.extract_content(mock_committee) @@ -146,4 +146,4 @@ def test_source_name_default(self, command): def test_help_method(self, command): """Test the help method.""" - assert command.help() == "Create chunks for OWASP committee data" + assert command.help() == "Create or update chunks for OWASP committee data" diff --git a/backend/tests/apps/ai/management/commands/ai_create_event_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_update_event_chunks_test.py similarity index 95% rename from backend/tests/apps/ai/management/commands/ai_create_event_chunks_test.py rename to backend/tests/apps/ai/management/commands/ai_update_event_chunks_test.py index b00cee7488..298b50b185 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_event_chunks_test.py +++ b/backend/tests/apps/ai/management/commands/ai_update_event_chunks_test.py @@ -5,7 +5,7 @@ import pytest from django.core.management.base import BaseCommand -from apps.ai.management.commands.ai_create_event_chunks import Command +from apps.ai.management.commands.ai_update_event_chunks import Command @pytest.fixture @@ -51,7 +51,7 @@ def test_key_field_name_property(self, command): def test_extract_content(self, command, mock_event): """Test content extraction from event.""" with patch( - "apps.ai.management.commands.ai_create_event_chunks.extract_event_content" + "apps.ai.management.commands.ai_update_event_chunks.extract_event_content" ) as mock_extract: mock_extract.return_value = ("prose content", "metadata content") content = command.extract_content(mock_event) diff --git a/backend/tests/apps/ai/management/commands/ai_create_project_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_update_project_chunks_test.py similarity index 93% rename from backend/tests/apps/ai/management/commands/ai_create_project_chunks_test.py rename to backend/tests/apps/ai/management/commands/ai_update_project_chunks_test.py index bc919abaa7..85f889ddaf 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_project_chunks_test.py +++ b/backend/tests/apps/ai/management/commands/ai_update_project_chunks_test.py @@ -3,7 +3,7 @@ import pytest from django.core.management.base import BaseCommand -from apps.ai.management.commands.ai_create_project_chunks import Command +from apps.ai.management.commands.ai_update_project_chunks import Command @pytest.fixture @@ -39,7 +39,7 @@ def test_key_field_name_property(self, command): def test_extract_content(self, command, mock_project): with patch( - "apps.ai.management.commands.ai_create_project_chunks.extract_project_content" + "apps.ai.management.commands.ai_update_project_chunks.extract_project_content" ) as mock_extract: mock_extract.return_value = ("prose content", "metadata content") content = command.extract_content(mock_project) diff --git a/backend/tests/apps/ai/management/commands/ai_create_slack_message_chunks_test.py b/backend/tests/apps/ai/management/commands/ai_update_slack_message_chunks_test.py similarity index 97% rename from backend/tests/apps/ai/management/commands/ai_create_slack_message_chunks_test.py rename to backend/tests/apps/ai/management/commands/ai_update_slack_message_chunks_test.py index 81ff0f73ae..22c9db9772 100644 --- a/backend/tests/apps/ai/management/commands/ai_create_slack_message_chunks_test.py +++ b/backend/tests/apps/ai/management/commands/ai_update_slack_message_chunks_test.py @@ -3,7 +3,7 @@ import pytest from django.core.management.base import BaseCommand -from apps.ai.management.commands.ai_create_slack_message_chunks import Command +from apps.ai.management.commands.ai_update_slack_message_chunks import Command @pytest.fixture diff --git a/backend/tests/apps/ai/models/context_test.py b/backend/tests/apps/ai/models/context_test.py index fdb74379b7..1911034e83 100644 --- a/backend/tests/apps/ai/models/context_test.py +++ b/backend/tests/apps/ai/models/context_test.py @@ -140,8 +140,7 @@ def test_context_delete(self, mock_delete): @patch("apps.ai.models.context.Context.objects.filter") @patch("apps.ai.models.context.Context.objects.get_or_create") - @patch("apps.ai.models.context.regenerate_chunks_for_context") - def test_update_data_existing_context(self, mock_regenerate, mock_get_or_create, mock_filter): + def test_update_data_existing_context(self, mock_get_or_create, mock_filter): mock_context = create_model_mock(Context) mock_get_or_create.return_value = (mock_context, False) @@ -163,7 +162,6 @@ def test_update_data_existing_context(self, mock_regenerate, mock_get_or_create, entity_id=1, defaults={"content": content, "source": "src"}, ) - mock_regenerate.assert_called_once_with(context=mock_context) assert result == mock_context def test_str_method_with_name_attribute(self): diff --git a/backend/tests/apps/common/open_ai_test.py b/backend/tests/apps/common/open_ai_test.py index ee8827bc19..e46c38e687 100644 --- a/backend/tests/apps/common/open_ai_test.py +++ b/backend/tests/apps/common/open_ai_test.py @@ -60,6 +60,3 @@ def test_complete_general_exception(self, mock_openai, mock_logger, openai_insta response = openai_instance.complete() assert response is None - mock_logger.exception.assert_called_once_with( - "An error occurred during OpenAI API request." - ) From 197c0ff757124d6f3939bfacab47231248095456 Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Tue, 19 Aug 2025 02:38:24 +0530 Subject: [PATCH 27/32] sugesstions implemented --- backend/apps/ai/models/context.py | 22 ++--- backend/tests/apps/ai/models/context_test.py | 95 +++++++++++++------- 2 files changed, 73 insertions(+), 44 deletions(-) diff --git a/backend/apps/ai/models/context.py b/backend/apps/ai/models/context.py index f0fce30350..23e92c0c5e 100644 --- a/backend/apps/ai/models/context.py +++ b/backend/apps/ai/models/context.py @@ -50,16 +50,18 @@ def update_data( save: bool = True, ) -> "Context": """Create or update context for a given entity.""" - context, created = Context.objects.get_or_create( - entity_type=ContentType.objects.get_for_model(content_object), - entity_id=content_object.pk, - defaults={"content": content, "source": source}, - ) + entity_type = ContentType.objects.get_for_model(content_object) + entity_id = content_object.pk + + try: + context = Context.objects.get(entity_type=entity_type, entity_id=entity_id) + except Context.DoesNotExist: + context = Context(entity_type=entity_type, entity_id=entity_id) + + context.content = content + context.source = source - if not created and (context.content != content or context.source != source): - context.content = content - context.source = source - if save: - context.save() + if save: + context.save() return context diff --git a/backend/tests/apps/ai/models/context_test.py b/backend/tests/apps/ai/models/context_test.py index 1911034e83..8df4527ec8 100644 --- a/backend/tests/apps/ai/models/context_test.py +++ b/backend/tests/apps/ai/models/context_test.py @@ -138,11 +138,10 @@ def test_context_delete(self, mock_delete): mock_delete.assert_called_once() - @patch("apps.ai.models.context.Context.objects.filter") - @patch("apps.ai.models.context.Context.objects.get_or_create") - def test_update_data_existing_context(self, mock_get_or_create, mock_filter): + @patch("apps.ai.models.context.Context.objects.get") + def test_update_data_existing_context(self, mock_get): mock_context = create_model_mock(Context) - mock_get_or_create.return_value = (mock_context, False) + mock_get.return_value = mock_context content = "Test" mock_content_object = Mock() @@ -152,17 +151,19 @@ def test_update_data_existing_context(self, mock_get_or_create, mock_filter): "apps.ai.models.context.ContentType.objects.get_for_model" ) as mock_get_for_model: mock_content_type = Mock() + mock_content_type.get_source_expressions = Mock(return_value=[]) mock_get_for_model.return_value = mock_content_type result = Context.update_data(content, mock_content_object, source="src", save=True) mock_get_for_model.assert_called_once_with(mock_content_object) - mock_get_or_create.assert_called_once_with( + mock_get.assert_called_once_with( entity_type=mock_content_type, entity_id=1, - defaults={"content": content, "source": "src"}, ) assert result == mock_context + assert mock_context.content == content + assert mock_context.source == "src" def test_str_method_with_name_attribute(self): """Test __str__ method when entity has name attribute.""" @@ -294,11 +295,14 @@ def test_str_method_fallback_to_str(self): result = str(context) assert result == "test_model String representation: Test content" - @patch("apps.ai.models.context.Context.objects.get_or_create") - def test_update_data_new_context_with_save(self, mock_get_or_create): + @patch("apps.ai.models.context.Context.objects.get") + @patch("apps.ai.models.context.Context.__init__") + def test_update_data_new_context_with_save(self, mock_init, mock_get): """Test update_data creating a new context with save=True.""" - mock_context = create_model_mock(Context) - mock_get_or_create.return_value = (mock_context, True) + from apps.ai.models.context import Context as ContextModel + + mock_get.side_effect = ContextModel.DoesNotExist + mock_init.return_value = None # __init__ should return None content = "New test content" mock_content_object = Mock() @@ -308,24 +312,37 @@ def test_update_data_new_context_with_save(self, mock_get_or_create): with patch( "apps.ai.models.context.ContentType.objects.get_for_model" ) as mock_get_for_model: - mock_content_type = Mock() + mock_content_type = Mock(spec=ContentType) + mock_content_type.get_source_expressions = Mock(return_value=[]) mock_get_for_model.return_value = mock_content_type - result = Context.update_data(content, mock_content_object, source=source, save=True) + # Mock the context instance and its save method + with patch.object(ContextModel, "save") as mock_save: + result = Context.update_data( + content, mock_content_object, source=source, save=True + ) + + mock_get_for_model.assert_called_once_with(mock_content_object) + mock_get.assert_called_once_with( + entity_type=mock_content_type, + entity_id=1, + ) + mock_init.assert_called_once_with( + entity_type=mock_content_type, + entity_id=1, + ) + assert result.content == content + assert result.source == source + mock_save.assert_called_once() - mock_get_for_model.assert_called_once_with(mock_content_object) - mock_get_or_create.assert_called_once_with( - entity_type=mock_content_type, - entity_id=1, - defaults={"content": content, "source": source}, - ) - assert result == mock_context - - @patch("apps.ai.models.context.Context.objects.get_or_create") - def test_update_data_new_context_without_save(self, mock_get_or_create): + @patch("apps.ai.models.context.Context.objects.get") + @patch("apps.ai.models.context.Context.__init__") + def test_update_data_new_context_without_save(self, mock_init, mock_get): """Test update_data creating a new context with save=False.""" - mock_context = create_model_mock(Context) - mock_get_or_create.return_value = (mock_context, True) + from apps.ai.models.context import Context as ContextModel + + mock_get.side_effect = ContextModel.DoesNotExist + mock_init.return_value = None # __init__ should return None content = "New test content" mock_content_object = Mock() @@ -335,15 +352,25 @@ def test_update_data_new_context_without_save(self, mock_get_or_create): with patch( "apps.ai.models.context.ContentType.objects.get_for_model" ) as mock_get_for_model: - mock_content_type = Mock() + mock_content_type = Mock(spec=ContentType) + mock_content_type.get_source_expressions = Mock(return_value=[]) mock_get_for_model.return_value = mock_content_type - result = Context.update_data(content, mock_content_object, source=source, save=False) - - mock_get_for_model.assert_called_once_with(mock_content_object) - mock_get_or_create.assert_called_once_with( - entity_type=mock_content_type, - entity_id=1, - defaults={"content": content, "source": source}, - ) - assert result == mock_context + # Mock the context instance and its save method + with patch.object(ContextModel, "save") as mock_save: + result = Context.update_data( + content, mock_content_object, source=source, save=False + ) + + mock_get_for_model.assert_called_once_with(mock_content_object) + mock_get.assert_called_once_with( + entity_type=mock_content_type, + entity_id=1, + ) + mock_init.assert_called_once_with( + entity_type=mock_content_type, + entity_id=1, + ) + assert result.content == content + assert result.source == source + mock_save.assert_not_called() From 346d324ca2588a628a51db4424518ff7c9165ee3 Mon Sep 17 00:00:00 2001 From: Arkadii Yakovets Date: Wed, 20 Aug 2025 13:26:56 -0700 Subject: [PATCH 28/32] Update code --- backend/apps/ai/agent/tools/rag/retriever.py | 8 +- .../apps/ai/common/base/context_command.py | 2 +- backend/apps/ai/models/context.py | 6 +- .../apps/ai/agent/tools/rag/retriever_test.py | 86 +++++++++---------- .../ai/common/base/context_command_test.py | 4 +- .../ai_update_committee_context_test.py | 2 +- 6 files changed, 57 insertions(+), 51 deletions(-) diff --git a/backend/apps/ai/agent/tools/rag/retriever.py b/backend/apps/ai/agent/tools/rag/retriever.py index 3f9d043bd5..bf3e792610 100644 --- a/backend/apps/ai/agent/tools/rag/retriever.py +++ b/backend/apps/ai/agent/tools/rag/retriever.py @@ -21,7 +21,13 @@ class Retriever: """A class for retrieving relevant text chunks for a RAG.""" - SUPPORTED_ENTITY_TYPES = ("event", "project", "chapter", "committee", "message") + SUPPORTED_ENTITY_TYPES = ( + "chapter", + "committee", + "event", + "message", + "project", + ) def __init__(self, embedding_model: str = "text-embedding-3-small"): """Initialize the Retriever. diff --git a/backend/apps/ai/common/base/context_command.py b/backend/apps/ai/common/base/context_command.py index 4a3c75b0c7..5a43370450 100644 --- a/backend/apps/ai/common/base/context_command.py +++ b/backend/apps/ai/common/base/context_command.py @@ -30,7 +30,7 @@ def process_context_batch(self, entities: list[Model]) -> int: if Context.update_data( content=full_content, - content_object=entity, + entity=entity, source=self.source_name(), ): processed += 1 diff --git a/backend/apps/ai/models/context.py b/backend/apps/ai/models/context.py index 23e92c0c5e..d722d83e88 100644 --- a/backend/apps/ai/models/context.py +++ b/backend/apps/ai/models/context.py @@ -44,14 +44,14 @@ def __str__(self): @staticmethod def update_data( content: str, - content_object, + entity, source: str = "", *, save: bool = True, ) -> "Context": """Create or update context for a given entity.""" - entity_type = ContentType.objects.get_for_model(content_object) - entity_id = content_object.pk + entity_type = ContentType.objects.get_for_model(entity) + entity_id = entity.pk try: context = Context.objects.get(entity_type=entity_type, entity_id=entity_id) diff --git a/backend/tests/apps/ai/agent/tools/rag/retriever_test.py b/backend/tests/apps/ai/agent/tools/rag/retriever_test.py index fe30aff809..538ee6fe46 100644 --- a/backend/tests/apps/ai/agent/tools/rag/retriever_test.py +++ b/backend/tests/apps/ai/agent/tools/rag/retriever_test.py @@ -105,11 +105,11 @@ def test_get_source_name_with_name(self): ): retriever = Retriever() - content_object = MagicMock() - content_object.name = "Test Name" - content_object.title = "Test Title" + entity = MagicMock() + entity.name = "Test Name" + entity.title = "Test Title" - result = retriever.get_source_name(content_object) + result = retriever.get_source_name(entity) assert result == "Test Name" def test_get_source_name_with_title(self): @@ -120,12 +120,12 @@ def test_get_source_name_with_title(self): ): retriever = Retriever() - content_object = MagicMock() - content_object.name = None - content_object.title = "Test Title" - content_object.login = "test_login" + entity = MagicMock() + entity.name = None + entity.title = "Test Title" + entity.login = "test_login" - result = retriever.get_source_name(content_object) + result = retriever.get_source_name(entity) assert result == "Test Title" def test_get_source_name_with_login(self): @@ -136,13 +136,13 @@ def test_get_source_name_with_login(self): ): retriever = Retriever() - content_object = MagicMock() - content_object.name = None - content_object.title = None - content_object.login = "test_login" - content_object.key = "test_key" + entity = MagicMock() + entity.name = None + entity.title = None + entity.login = "test_login" + entity.key = "test_key" - result = retriever.get_source_name(content_object) + result = retriever.get_source_name(entity) assert result == "test_login" def test_get_source_name_fallback_to_str(self): @@ -153,15 +153,15 @@ def test_get_source_name_fallback_to_str(self): ): retriever = Retriever() - content_object = MagicMock() - content_object.name = None - content_object.title = None - content_object.login = None - content_object.key = None - content_object.summary = None - content_object.__str__ = MagicMock(return_value="String representation") - - result = retriever.get_source_name(content_object) + entity = MagicMock() + entity.name = None + entity.title = None + entity.login = None + entity.key = None + entity.summary = None + entity.__str__ = MagicMock(return_value="String representation") + + result = retriever.get_source_name(entity) assert result == "String representation" def test_get_additional_context_chapter(self): @@ -172,21 +172,21 @@ def test_get_additional_context_chapter(self): ): retriever = Retriever() - content_object = MagicMock() - content_object.suggested_location = "New York" - content_object.region = "North America" - content_object.country = "USA" - content_object.postal_code = "10001" - content_object.currency = "USD" - content_object.meetup_group = "OWASP NYC" - content_object.tags = ["security", "web"] - content_object.topics = ["OWASP Top 10"] - content_object.leaders_raw = ["John Doe", "Jane Smith"] - content_object.related_urls = ["https://example.com"] - content_object.is_active = True - content_object.url = "https://owasp.org/chapter" - - result = retriever.get_additional_context(content_object, "chapter") + chapter = MagicMock() + chapter.suggested_location = "New York" + chapter.region = "North America" + chapter.country = "USA" + chapter.postal_code = "10001" + chapter.currency = "USD" + chapter.meetup_group = "OWASP NYC" + chapter.tags = ["security", "web"] + chapter.topics = ["OWASP Top 10"] + chapter.leaders_raw = ["John Doe", "Jane Smith"] + chapter.related_urls = ["https://example.com"] + chapter.is_active = True + chapter.url = "https://owasp.org/chapter" + + result = retriever.get_additional_context(chapter, "chapter") expected_keys = [ "location", @@ -418,13 +418,13 @@ def test_extract_content_types_from_query_no_matches(self): def test_supported_content_types(self): """Test that supported content types are defined correctly.""" - assert Retriever.SUPPORTED_ENTITY_TYPES == ( - "event", - "project", + assert set(Retriever.SUPPORTED_ENTITY_TYPES) == { "chapter", "committee", + "event", "message", - ) + "project", + } @patch("apps.ai.agent.tools.rag.retriever.Chunk") def test_retrieve_with_app_label_content_types(self, mock_chunk): diff --git a/backend/tests/apps/ai/common/base/context_command_test.py b/backend/tests/apps/ai/common/base/context_command_test.py index 48237a1fc4..ff1d6cf9c5 100644 --- a/backend/tests/apps/ai/common/base/context_command_test.py +++ b/backend/tests/apps/ai/common/base/context_command_test.py @@ -113,7 +113,7 @@ def test_process_context_batch_success( assert result == 1 mock_context_class.update_data.assert_called_once_with( content="metadata content\n\nprose content", - content_object=mock_entity, + entity=mock_entity, source="owasp_test_entity", ) mock_write.assert_called_once_with("Created context for test-key-123") @@ -157,7 +157,7 @@ def test_process_context_batch_multiple_entities( calls = mock_context_class.update_data.call_args_list for i, call in enumerate(calls): _, kwargs = call - assert kwargs["content_object"] == entities[i] + assert kwargs["entity"] == entities[i] assert kwargs["content"] == "metadata content\n\nprose content" assert kwargs["source"] == "owasp_test_entity" diff --git a/backend/tests/apps/ai/management/commands/ai_update_committee_context_test.py b/backend/tests/apps/ai/management/commands/ai_update_committee_context_test.py index 8cb4e302f2..6c1940dca1 100644 --- a/backend/tests/apps/ai/management/commands/ai_update_committee_context_test.py +++ b/backend/tests/apps/ai/management/commands/ai_update_committee_context_test.py @@ -166,7 +166,7 @@ def test_process_context_batch_success(self, command, mock_committee): assert result == 1 mock_context_class.update_data.assert_called_once_with( content="Metadata\n\nContent", - content_object=mock_committee, + entity=mock_committee, source="owasp_committee", ) mock_write.assert_called_once_with("Created context for test-committee") From baae5eb6c105d3079789491ccb23897d185124c9 Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Thu, 21 Aug 2025 19:36:17 +0530 Subject: [PATCH 29/32] updated code --- backend/apps/ai/agent/tools/rag/retriever.py | 9 +- backend/apps/ai/common/utils.py | 7 +- backend/apps/ai/models/context.py | 7 - .../apps/ai/agent/tools/rag/retriever_test.py | 96 +++++++----- backend/tests/apps/ai/common/utils_test.py | 143 +++++++++--------- 5 files changed, 134 insertions(+), 128 deletions(-) diff --git a/backend/apps/ai/agent/tools/rag/retriever.py b/backend/apps/ai/agent/tools/rag/retriever.py index bf3e792610..0a7619796c 100644 --- a/backend/apps/ai/agent/tools/rag/retriever.py +++ b/backend/apps/ai/agent/tools/rag/retriever.py @@ -76,19 +76,18 @@ def get_source_name(self, entity) -> str: return str(getattr(entity, attr)) return str(entity) - def get_additional_context(self, entity, entity_type: str) -> dict[str, Any]: + def get_additional_context(self, entity) -> dict[str, Any]: """Get additional context information based on content type. Args: entity: The source object. - entity_type: The model name of the content object. Returns: A dictionary with additional context information. """ context = {} - clean_content_type = entity_type.split(".")[-1] if "." in entity_type else entity_type + clean_content_type = entity.__class__.__name__.lower() if clean_content_type == "chapter": context.update( { @@ -230,9 +229,7 @@ def retrieve( continue source_name = self.get_source_name(chunk.context.entity) - additional_context = self.get_additional_context( - chunk.context.entity, chunk.context.entity_type.model - ) + additional_context = self.get_additional_context(chunk.context.entity) results.append( { diff --git a/backend/apps/ai/common/utils.py b/backend/apps/ai/common/utils.py index 9d02f7ede2..8a258f7f3b 100644 --- a/backend/apps/ai/common/utils.py +++ b/backend/apps/ai/common/utils.py @@ -63,7 +63,7 @@ def create_chunks_and_embeddings( if chunk is not None: chunks.append(chunk) - except OpenAIError: + except (OpenAIError, AttributeError, TypeError): logger.exception("Failed to create chunks and embeddings") return [] else: @@ -79,10 +79,7 @@ def regenerate_chunks_for_context(context: Context): """ from apps.ai.models.chunk import Chunk - old_chunk_count = context.chunks.count() - if old_chunk_count > 0: - context.chunks.all().delete() - + context.chunks.all().delete() new_chunk_texts = Chunk.split_text(context.content) if not new_chunk_texts: diff --git a/backend/apps/ai/models/context.py b/backend/apps/ai/models/context.py index d722d83e88..97e0917974 100644 --- a/backend/apps/ai/models/context.py +++ b/backend/apps/ai/models/context.py @@ -11,13 +11,6 @@ logger = logging.getLogger(__name__) -def regenerate_chunks_for_context(context): - """Import regenerate_chunks_for_context to avoid circular import.""" - from apps.ai.common.utils import regenerate_chunks_for_context as _regenerate_chunks - - return _regenerate_chunks(context) - - class Context(TimestampedModel): """Context model for storing generated text related to OWASP entities.""" diff --git a/backend/tests/apps/ai/agent/tools/rag/retriever_test.py b/backend/tests/apps/ai/agent/tools/rag/retriever_test.py index 538ee6fe46..c27115195c 100644 --- a/backend/tests/apps/ai/agent/tools/rag/retriever_test.py +++ b/backend/tests/apps/ai/agent/tools/rag/retriever_test.py @@ -172,21 +172,22 @@ def test_get_additional_context_chapter(self): ): retriever = Retriever() - chapter = MagicMock() - chapter.suggested_location = "New York" - chapter.region = "North America" - chapter.country = "USA" - chapter.postal_code = "10001" - chapter.currency = "USD" - chapter.meetup_group = "OWASP NYC" - chapter.tags = ["security", "web"] - chapter.topics = ["OWASP Top 10"] - chapter.leaders_raw = ["John Doe", "Jane Smith"] - chapter.related_urls = ["https://example.com"] - chapter.is_active = True - chapter.url = "https://owasp.org/chapter" - - result = retriever.get_additional_context(chapter, "chapter") + content_object = MagicMock() + content_object.__class__.__name__ = "Chapter" + content_object.suggested_location = "New York" + content_object.region = "North America" + content_object.country = "USA" + content_object.postal_code = "10001" + content_object.currency = "USD" + content_object.meetup_group = "OWASP NYC" + content_object.tags = ["security", "web"] + content_object.topics = ["OWASP Top 10"] + content_object.leaders_raw = ["John Doe", "Jane Smith"] + content_object.related_urls = ["https://example.com"] + content_object.is_active = True + content_object.url = "https://owasp.org/chapter" + + result = retriever.get_additional_context(content_object) expected_keys = [ "location", @@ -214,23 +215,52 @@ def test_get_additional_context_project(self): retriever = Retriever() content_object = MagicMock() - content_object.project_type = "tool" + content_object.__class__.__name__ = "Project" content_object.level = "flagship" + content_object.type = "tool" + content_object.languages = ["python"] content_object.topics = ["security"] + content_object.licenses = ["MIT"] + content_object.tags = ["web"] + content_object.custom_tags = ["api"] + content_object.stars_count = 100 + content_object.forks_count = 20 + content_object.contributors_count = 5 + content_object.releases_count = 3 + content_object.open_issues_count = 2 content_object.leaders_raw = ["Alice"] content_object.related_urls = ["https://project.example.com"] + content_object.created_at = "2023-01-01" + content_object.updated_at = "2023-01-02" + content_object.released_at = "2023-01-03" + content_object.health_score = 85.5 content_object.is_active = True + content_object.track_issues = True content_object.url = "https://owasp.org/project" - result = retriever.get_additional_context(content_object, "project") + result = retriever.get_additional_context(content_object) expected_keys = [ - "project_type", "level", + "project_type", + "languages", "topics", + "licenses", + "tags", + "custom_tags", + "stars_count", + "forks_count", + "contributors_count", + "releases_count", + "open_issues_count", "leaders", "related_urls", + "created_at", + "updated_at", + "released_at", + "health_score", "is_active", + "track_issues", "url", ] for key in expected_keys: @@ -245,6 +275,7 @@ def test_get_additional_context_event(self): retriever = Retriever() content_object = MagicMock() + content_object.__class__.__name__ = "Event" content_object.start_date = "2023-01-01" content_object.end_date = "2023-01-02" content_object.suggested_location = "San Francisco" @@ -255,7 +286,7 @@ def test_get_additional_context_event(self): content_object.description = "Test event description" content_object.summary = "Test event summary" - result = retriever.get_additional_context(content_object, "event") + result = retriever.get_additional_context(content_object) expected_keys = [ "start_date", @@ -280,6 +311,7 @@ def test_get_additional_context_committee(self): retriever = Retriever() content_object = MagicMock() + content_object.__class__.__name__ = "Committee" content_object.is_active = True content_object.leaders = ["John Doe", "Jane Smith"] content_object.url = "https://committee.example.com" @@ -289,7 +321,7 @@ def test_get_additional_context_committee(self): content_object.topics = ["policy", "standards"] content_object.related_urls = ["https://related.example.com"] - result = retriever.get_additional_context(content_object, "committee") + result = retriever.get_additional_context(content_object) expected_keys = [ "is_active", @@ -322,12 +354,13 @@ def test_get_additional_context_message(self): author.name = "testuser" content_object = MagicMock() + content_object.__class__.__name__ = "Message" content_object.conversation = conversation content_object.parent_message = parent_message content_object.ts = "1234567891.123456" content_object.author = author - result = retriever.get_additional_context(content_object, "message") + result = retriever.get_additional_context(content_object) expected_keys = ["channel", "thread_ts", "ts", "user"] for key in expected_keys: @@ -347,31 +380,16 @@ def test_get_additional_context_message_no_conversation(self): retriever = Retriever() content_object = MagicMock() + content_object.__class__.__name__ = "Message" content_object.conversation = None content_object.parent_message = None content_object.ts = "1234567891.123456" content_object.author = None - result = retriever.get_additional_context(content_object, "message") + result = retriever.get_additional_context(content_object) assert result["ts"] == "1234567891.123456" - def test_get_additional_context_with_app_label(self): - """Test getting additional context with app.model format.""" - with ( - patch.dict(os.environ, {"DJANGO_OPEN_AI_SECRET_KEY": "test-key"}), - patch("openai.OpenAI"), - ): - retriever = Retriever() - - content_object = MagicMock() - content_object.suggested_location = "Test Location" - - result = retriever.get_additional_context(content_object, "owasp.chapter") - - assert "location" in result - assert result["location"] == "Test Location" - def test_extract_content_types_from_query_single_type(self): """Test extracting single content type from query.""" with ( @@ -473,13 +491,13 @@ def test_retrieve_successful_with_chunks(self, mock_chunk): mock_content_object = MagicMock() mock_content_object.name = "Test Chapter" + mock_content_object.__class__.__name__ = "Chapter" mock_content_object.suggested_location = "New York" mock_entity_type = MagicMock() mock_entity_type.model = "chapter" mock_context = MagicMock() - mock_context.content_object = mock_content_object mock_context.entity = mock_content_object mock_context.entity_type = mock_entity_type mock_context.entity_id = "123" diff --git a/backend/tests/apps/ai/common/utils_test.py b/backend/tests/apps/ai/common/utils_test.py index a9e1b19618..417335eafe 100644 --- a/backend/tests/apps/ai/common/utils_test.py +++ b/backend/tests/apps/ai/common/utils_test.py @@ -1,5 +1,5 @@ from datetime import UTC, datetime, timedelta -from unittest.mock import MagicMock, Mock, call, patch +from unittest.mock import MagicMock, call, patch import openai @@ -15,12 +15,11 @@ def __init__(self, embedding): class TestUtils: - @patch("apps.ai.common.utils.Context") @patch("apps.ai.common.utils.Chunk.update_data") @patch("apps.ai.common.utils.time.sleep") @patch("apps.ai.common.utils.datetime") def test_create_chunks_and_embeddings_success( - self, mock_datetime, mock_sleep, mock_update_data, mock_context + self, mock_datetime, mock_sleep, mock_update_data ): """Tests the successful path where the OpenAI API returns embeddings.""" base_time = datetime.now(UTC) @@ -94,7 +93,8 @@ def test_create_chunks_and_embeddings_api_error(self, mock_logger): assert result == [] - def test_create_chunks_and_embeddings_none_context(self): + @patch("apps.ai.common.utils.logger") + def test_create_chunks_and_embeddings_none_context(self, mock_logger): """Tests the failure path when context is None.""" mock_openai_client = MagicMock() @@ -102,26 +102,23 @@ def test_create_chunks_and_embeddings_none_context(self): mock_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])] mock_openai_client.embeddings.create.return_value = mock_response - with patch("apps.ai.common.utils.Chunk.update_data") as mock_update_data: - mock_chunk = Mock() - mock_update_data.return_value = mock_chunk - - result = create_chunks_and_embeddings( - chunk_texts=["some text"], - context=None, - openai_client=mock_openai_client, - ) - - assert len(result) == 1 - assert result[0] == mock_chunk + result = create_chunks_and_embeddings( + chunk_texts=["some text"], + context=None, + openai_client=mock_openai_client, + ) - mock_update_data.assert_called_once_with( - text="some text", embedding=[0.1, 0.2, 0.3], context=None, save=True - ) + # When context is None, the function should catch the AttributeError + # and log an exception, returning an empty list + mock_logger.exception.assert_called_once_with("Failed to create chunks and embeddings") + assert result == [] + @patch("apps.ai.common.utils.Chunk.update_data") @patch("apps.ai.common.utils.time.sleep") @patch("apps.ai.common.utils.datetime") - def test_create_chunks_and_embeddings_sleep_called(self, mock_datetime, mock_sleep): + def test_create_chunks_and_embeddings_sleep_called( + self, mock_datetime, mock_sleep, mock_update_data + ): """Tests that sleep is called when needed.""" base_time = datetime.now(UTC) mock_datetime.now.return_value = base_time @@ -133,25 +130,25 @@ def test_create_chunks_and_embeddings_sleep_called(self, mock_datetime, mock_sle mock_api_response.data = [MockEmbeddingData([0.1, 0.2])] mock_openai_client.embeddings.create.return_value = mock_api_response - with patch("apps.ai.common.utils.Chunk.update_data") as mock_update_data: - mock_chunk = MagicMock() - mock_update_data.return_value = mock_chunk + mock_chunk_instance = MagicMock() + mock_update_data.return_value = mock_chunk_instance - result = create_chunks_and_embeddings( - ["test chunk"], - MagicMock(), - mock_openai_client, - ) + mock_content_object = MagicMock() - mock_sleep.assert_not_called() - assert result == [mock_chunk] + result = create_chunks_and_embeddings( + ["test chunk"], + mock_content_object, + mock_openai_client, + ) + + mock_sleep.assert_not_called() + assert result == [mock_chunk_instance] - @patch("apps.ai.common.utils.Context") @patch("apps.ai.common.utils.Chunk.update_data") @patch("apps.ai.common.utils.time.sleep") @patch("apps.ai.common.utils.datetime") def test_create_chunks_and_embeddings_no_sleep_with_current_settings( - self, mock_datetime, mock_sleep, mock_update_data, mock_context + self, mock_datetime, mock_sleep, mock_update_data ): """Tests that sleep is not called with current offset settings.""" base_time = datetime.now(UTC) @@ -164,8 +161,8 @@ def test_create_chunks_and_embeddings_no_sleep_with_current_settings( mock_api_response.data = [MockEmbeddingData([0.1, 0.2])] mock_openai_client.embeddings.create.return_value = mock_api_response - mock_chunk = MagicMock() - mock_update_data.return_value = mock_chunk + mock_chunk_instance = MagicMock() + mock_update_data.return_value = mock_chunk_instance mock_context_obj = MagicMock() result = create_chunks_and_embeddings( @@ -178,11 +175,14 @@ def test_create_chunks_and_embeddings_no_sleep_with_current_settings( mock_update_data.assert_called_once_with( text="test chunk", embedding=[0.1, 0.2], context=mock_context_obj, save=True ) - assert result == [mock_chunk] + assert result == [mock_chunk_instance] + @patch("apps.ai.common.utils.Chunk.update_data") @patch("apps.ai.common.utils.time.sleep") @patch("apps.ai.common.utils.datetime") - def test_create_chunks_and_embeddings_sleep_when_rate_limited(self, mock_datetime, mock_sleep): + def test_create_chunks_and_embeddings_sleep_when_rate_limited( + self, mock_datetime, mock_sleep, mock_update_data + ): """Tests that sleep is called when rate limiting is needed.""" base_time = datetime.now(UTC) @@ -198,22 +198,23 @@ def test_create_chunks_and_embeddings_sleep_when_rate_limited(self, mock_datetim mock_api_response.data = [MockEmbeddingData([0.1, 0.2])] mock_openai_client.embeddings.create.return_value = mock_api_response - with patch("apps.ai.common.utils.Chunk.update_data") as mock_update_data: - mock_chunk = MagicMock() - mock_update_data.return_value = mock_chunk + mock_chunk_instance = MagicMock() + mock_update_data.return_value = mock_chunk_instance - with ( - patch("apps.ai.common.utils.DEFAULT_LAST_REQUEST_OFFSET_SECONDS", 5), - patch("apps.ai.common.utils.MIN_REQUEST_INTERVAL_SECONDS", 10), - ): - result = create_chunks_and_embeddings( - ["test chunk"], - MagicMock(), - mock_openai_client, - ) + with ( + patch("apps.ai.common.utils.DEFAULT_LAST_REQUEST_OFFSET_SECONDS", 5), + patch("apps.ai.common.utils.MIN_REQUEST_INTERVAL_SECONDS", 10), + ): + mock_context_obj = MagicMock() - mock_sleep.assert_called_once() - assert result == [mock_chunk] + result = create_chunks_and_embeddings( + ["test chunk"], + mock_context_obj, + mock_openai_client, + ) + + mock_sleep.assert_called_once() + assert result == [mock_chunk_instance] @patch("apps.ai.common.utils.Chunk.update_data") def test_create_chunks_and_embeddings_filter_none_chunks(self, mock_update_data): @@ -226,17 +227,19 @@ def test_create_chunks_and_embeddings_filter_none_chunks(self, mock_update_data) ] mock_openai_client.embeddings.create.return_value = mock_api_response - mock_chunk = MagicMock() - mock_update_data.side_effect = [mock_chunk, None] + mock_chunk_instance = MagicMock() + mock_update_data.side_effect = [mock_chunk_instance, None] + + mock_context_obj = MagicMock() result = create_chunks_and_embeddings( ["first chunk", "second chunk"], - MagicMock(), + mock_context_obj, mock_openai_client, ) assert len(result) == 1 - assert result[0] == mock_chunk + assert result[0] == mock_chunk_instance class TestRegenerateChunksForContext: @@ -245,9 +248,9 @@ class TestRegenerateChunksForContext: @patch("apps.ai.common.utils.logger") @patch("apps.ai.common.utils.OpenAI") @patch("apps.ai.common.utils.create_chunks_and_embeddings") - @patch("apps.ai.models.chunk.Chunk") + @patch("apps.ai.common.utils.Chunk.split_text") def test_regenerate_chunks_for_context_success( - self, mock_chunk_class, mock_create_chunks, mock_openai_class, mock_logger + self, mock_split_text, mock_create_chunks, mock_openai_class, mock_logger ): """Test successful regeneration of chunks for context.""" # Setup context mock @@ -256,12 +259,11 @@ def test_regenerate_chunks_for_context_success( # Setup existing chunks mock_existing_chunks = MagicMock() - mock_existing_chunks.count.return_value = 3 context.chunks = mock_existing_chunks # Setup chunk splitting new_chunk_texts = ["chunk1", "chunk2"] - mock_chunk_class.split_text.return_value = new_chunk_texts + mock_split_text.return_value = new_chunk_texts # Setup OpenAI client mock_openai_client = MagicMock() @@ -274,7 +276,7 @@ def test_regenerate_chunks_for_context_success( mock_existing_chunks.all().delete.assert_called_once() # Verify content was split - mock_chunk_class.split_text.assert_called_once_with(context.content) + mock_split_text.assert_called_once_with(context.content) # Verify OpenAI client was created mock_openai_class.assert_called_once() @@ -293,8 +295,8 @@ def test_regenerate_chunks_for_context_success( ) @patch("apps.ai.common.utils.logger") - @patch("apps.ai.models.chunk.Chunk") - def test_regenerate_chunks_for_context_no_content(self, mock_chunk_class, mock_logger): + @patch("apps.ai.common.utils.Chunk.split_text") + def test_regenerate_chunks_for_context_no_content(self, mock_split_text, mock_logger): """Test regeneration when there's no content to chunk.""" # Setup context mock context = MagicMock() @@ -302,11 +304,10 @@ def test_regenerate_chunks_for_context_no_content(self, mock_chunk_class, mock_l # Setup existing chunks mock_existing_chunks = MagicMock() - mock_existing_chunks.count.return_value = 2 context.chunks = mock_existing_chunks # Setup chunk splitting to return empty list - mock_chunk_class.split_text.return_value = [] + mock_split_text.return_value = [] regenerate_chunks_for_context(context) @@ -315,7 +316,7 @@ def test_regenerate_chunks_for_context_no_content(self, mock_chunk_class, mock_l mock_existing_chunks.all().delete.assert_called_once() # Verify content was split - mock_chunk_class.split_text.assert_called_once_with(context.content) + mock_split_text.assert_called_once_with(context.content) # Verify warning was logged and process stopped mock_logger.warning.assert_called_once_with( @@ -328,9 +329,9 @@ def test_regenerate_chunks_for_context_no_content(self, mock_chunk_class, mock_l @patch("apps.ai.common.utils.logger") @patch("apps.ai.common.utils.OpenAI") @patch("apps.ai.common.utils.create_chunks_and_embeddings") - @patch("apps.ai.models.chunk.Chunk") + @patch("apps.ai.common.utils.Chunk.split_text") def test_regenerate_chunks_for_context_no_existing_chunks( - self, mock_chunk_class, mock_create_chunks, mock_openai_class, mock_logger + self, mock_split_text, mock_create_chunks, mock_openai_class, mock_logger ): """Test regeneration when there are no existing chunks.""" # Setup context mock @@ -339,12 +340,11 @@ def test_regenerate_chunks_for_context_no_existing_chunks( # Setup no existing chunks mock_existing_chunks = MagicMock() - mock_existing_chunks.count.return_value = 0 context.chunks = mock_existing_chunks # Setup chunk splitting new_chunk_texts = ["chunk1", "chunk2"] - mock_chunk_class.split_text.return_value = new_chunk_texts + mock_split_text.return_value = new_chunk_texts # Setup OpenAI client mock_openai_client = MagicMock() @@ -352,11 +352,12 @@ def test_regenerate_chunks_for_context_no_existing_chunks( regenerate_chunks_for_context(context) - # Verify delete was not called since count is 0 - mock_existing_chunks.all.assert_not_called() + # Verify delete was called regardless of count + mock_existing_chunks.all.assert_called_once() + mock_existing_chunks.all().delete.assert_called_once() # Verify content was split - mock_chunk_class.split_text.assert_called_once_with(context.content) + mock_split_text.assert_called_once_with(context.content) # Verify new chunks were created mock_create_chunks.assert_called_once_with( From f6bb1bdbb835279c26764c720e04f7d27d89164f Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Thu, 21 Aug 2025 20:48:26 +0530 Subject: [PATCH 30/32] spelling fixes --- backend/tests/apps/ai/common/utils_test.py | 26 -------------------- backend/tests/apps/ai/models/context_test.py | 6 ++--- 2 files changed, 3 insertions(+), 29 deletions(-) diff --git a/backend/tests/apps/ai/common/utils_test.py b/backend/tests/apps/ai/common/utils_test.py index 417335eafe..2d18684f44 100644 --- a/backend/tests/apps/ai/common/utils_test.py +++ b/backend/tests/apps/ai/common/utils_test.py @@ -108,8 +108,6 @@ def test_create_chunks_and_embeddings_none_context(self, mock_logger): openai_client=mock_openai_client, ) - # When context is None, the function should catch the AttributeError - # and log an exception, returning an empty list mock_logger.exception.assert_called_once_with("Failed to create chunks and embeddings") assert result == [] @@ -253,35 +251,27 @@ def test_regenerate_chunks_for_context_success( self, mock_split_text, mock_create_chunks, mock_openai_class, mock_logger ): """Test successful regeneration of chunks for context.""" - # Setup context mock context = MagicMock() context.content = "This is test content for chunking" - # Setup existing chunks mock_existing_chunks = MagicMock() context.chunks = mock_existing_chunks - # Setup chunk splitting new_chunk_texts = ["chunk1", "chunk2"] mock_split_text.return_value = new_chunk_texts - # Setup OpenAI client mock_openai_client = MagicMock() mock_openai_class.return_value = mock_openai_client regenerate_chunks_for_context(context) - # Verify old chunks were deleted mock_existing_chunks.all.assert_called_once() mock_existing_chunks.all().delete.assert_called_once() - # Verify content was split mock_split_text.assert_called_once_with(context.content) - # Verify OpenAI client was created mock_openai_class.assert_called_once() - # Verify new chunks were created mock_create_chunks.assert_called_once_with( chunk_texts=new_chunk_texts, context=context, @@ -289,7 +279,6 @@ def test_regenerate_chunks_for_context_success( save=True, ) - # Verify success log mock_logger.info.assert_called_once_with( "Successfully completed chunk regeneration for new context" ) @@ -298,32 +287,25 @@ def test_regenerate_chunks_for_context_success( @patch("apps.ai.common.utils.Chunk.split_text") def test_regenerate_chunks_for_context_no_content(self, mock_split_text, mock_logger): """Test regeneration when there's no content to chunk.""" - # Setup context mock context = MagicMock() context.content = "Some content" - # Setup existing chunks mock_existing_chunks = MagicMock() context.chunks = mock_existing_chunks - # Setup chunk splitting to return empty list mock_split_text.return_value = [] regenerate_chunks_for_context(context) - # Verify old chunks were deleted mock_existing_chunks.all.assert_called_once() mock_existing_chunks.all().delete.assert_called_once() - # Verify content was split mock_split_text.assert_called_once_with(context.content) - # Verify warning was logged and process stopped mock_logger.warning.assert_called_once_with( "No content to chunk for Context. Process stopped." ) - # Verify success log was not called mock_logger.info.assert_not_called() @patch("apps.ai.common.utils.logger") @@ -334,32 +316,25 @@ def test_regenerate_chunks_for_context_no_existing_chunks( self, mock_split_text, mock_create_chunks, mock_openai_class, mock_logger ): """Test regeneration when there are no existing chunks.""" - # Setup context mock context = MagicMock() context.content = "This is test content for chunking" - # Setup no existing chunks mock_existing_chunks = MagicMock() context.chunks = mock_existing_chunks - # Setup chunk splitting new_chunk_texts = ["chunk1", "chunk2"] mock_split_text.return_value = new_chunk_texts - # Setup OpenAI client mock_openai_client = MagicMock() mock_openai_class.return_value = mock_openai_client regenerate_chunks_for_context(context) - # Verify delete was called regardless of count mock_existing_chunks.all.assert_called_once() mock_existing_chunks.all().delete.assert_called_once() - # Verify content was split mock_split_text.assert_called_once_with(context.content) - # Verify new chunks were created mock_create_chunks.assert_called_once_with( chunk_texts=new_chunk_texts, context=context, @@ -367,7 +342,6 @@ def test_regenerate_chunks_for_context_no_existing_chunks( save=True, ) - # Verify success log mock_logger.info.assert_called_once_with( "Successfully completed chunk regeneration for new context" ) diff --git a/backend/tests/apps/ai/models/context_test.py b/backend/tests/apps/ai/models/context_test.py index 8df4527ec8..4616669a40 100644 --- a/backend/tests/apps/ai/models/context_test.py +++ b/backend/tests/apps/ai/models/context_test.py @@ -195,9 +195,9 @@ def test_str_method_with_name_attribute(self): ), ): result = str(context) - assert ( - result - == "test_model Test Object: This is test content that is longer than 50 charac" + assert result == ( + "test_model Test Object: This is test content that is longer than 50 charac" + # cspell:ignore charac ) def test_str_method_with_key_attribute(self): From 506ad46f601823b4869e3c70a263aa523ee312cf Mon Sep 17 00:00:00 2001 From: Dishant1804 Date: Thu, 21 Aug 2025 23:06:13 +0530 Subject: [PATCH 31/32] test changes --- backend/apps/ai/models/context.py | 3 ++- backend/tests/apps/ai/models/context_test.py | 7 +++---- backend/tests/apps/common/open_ai_test.py | 4 ++++ 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/backend/apps/ai/models/context.py b/backend/apps/ai/models/context.py index 97e0917974..e3d93bc277 100644 --- a/backend/apps/ai/models/context.py +++ b/backend/apps/ai/models/context.py @@ -7,6 +7,7 @@ from django.db import models from apps.common.models import TimestampedModel +from apps.common.utils import truncate logger = logging.getLogger(__name__) @@ -32,7 +33,7 @@ def __str__(self): or getattr(self.entity, "key", None) or str(self.entity) ) - return f"{self.entity_type.model} {entity}: {self.content[:50]}" + return f"{self.entity_type.model} {entity}: {truncate(self.content, 50)}" @staticmethod def update_data( diff --git a/backend/tests/apps/ai/models/context_test.py b/backend/tests/apps/ai/models/context_test.py index 4616669a40..bb57f3ef8b 100644 --- a/backend/tests/apps/ai/models/context_test.py +++ b/backend/tests/apps/ai/models/context_test.py @@ -7,6 +7,7 @@ from apps.ai.models.context import Context from apps.common.models import TimestampedModel +from apps.common.utils import truncate def create_model_mock(model_class): @@ -195,10 +196,8 @@ def test_str_method_with_name_attribute(self): ), ): result = str(context) - assert result == ( - "test_model Test Object: This is test content that is longer than 50 charac" - # cspell:ignore charac - ) + expected = f"test_model Test Object: {truncate(context.content, 50)}" + assert result == expected def test_str_method_with_key_attribute(self): """Test __str__ method when entity has key but no name attribute.""" diff --git a/backend/tests/apps/common/open_ai_test.py b/backend/tests/apps/common/open_ai_test.py index e46c38e687..c4fface6c7 100644 --- a/backend/tests/apps/common/open_ai_test.py +++ b/backend/tests/apps/common/open_ai_test.py @@ -60,3 +60,7 @@ def test_complete_general_exception(self, mock_openai, mock_logger, openai_insta response = openai_instance.complete() assert response is None + + mock_logger.exception.assert_called_once_with( + "An error occurred during OpenAI API request." + ) From 871d266f48ca3d5f933696bcc327cf33250d96df Mon Sep 17 00:00:00 2001 From: Arkadii Yakovets Date: Thu, 21 Aug 2025 11:13:46 -0700 Subject: [PATCH 32/32] Update tests --- backend/tests/apps/ai/models/context_test.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/backend/tests/apps/ai/models/context_test.py b/backend/tests/apps/ai/models/context_test.py index bb57f3ef8b..d98f5061e3 100644 --- a/backend/tests/apps/ai/models/context_test.py +++ b/backend/tests/apps/ai/models/context_test.py @@ -7,7 +7,6 @@ from apps.ai.models.context import Context from apps.common.models import TimestampedModel -from apps.common.utils import truncate def create_model_mock(model_class): @@ -196,8 +195,10 @@ def test_str_method_with_name_attribute(self): ), ): result = str(context) - expected = f"test_model Test Object: {truncate(context.content, 50)}" - assert result == expected + assert ( + result + == "test_model Test Object: This is test content that is longer than 50 cha..." + ) def test_str_method_with_key_attribute(self): """Test __str__ method when entity has key but no name attribute."""